diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index bad535efe6f..f674e4b5156 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -150,7 +150,6 @@ from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec from executorch.backends.arm.common.pipeline_config import ( ArmPassPipelineConfig, - FuseDuplicateUsersConfig, SoftmaxDecompositionConfig, ) from executorch.backends.arm.tosa.specification import ( @@ -238,9 +237,6 @@ def configure_skip_passes( case SoftmaxDecompositionConfig.STABLE: skip_set.add(DecomposeMaskedFillPass) - if config.fuse_duplicate_users is FuseDuplicateUsersConfig.DISABLED: - skip_set.add(FuseDuplicateUsersPass) - self._skip_pass_types = tuple(skip_set) skip_names = [skipped_pass.__name__ for skipped_pass in self._skip_pass_types] logger.debug(f"Passes in skip list: {skip_names}") diff --git a/backends/arm/common/pipeline_config.py b/backends/arm/common/pipeline_config.py index e27aae5f432..7da4e6ae5a1 100644 --- a/backends/arm/common/pipeline_config.py +++ b/backends/arm/common/pipeline_config.py @@ -14,24 +14,12 @@ class SoftmaxDecompositionConfig(Enum): STABLE = auto() # Stable softmax, no masked fill decomposition -class FuseDuplicateUsersConfig(Enum): - ENABLED = auto() - DISABLED = auto() - - @dataclass class ArmPassPipelineConfig: softmax: SoftmaxDecompositionConfig = SoftmaxDecompositionConfig.MASKED - fuse_duplicate_users: FuseDuplicateUsersConfig = FuseDuplicateUsersConfig.ENABLED - - def disable_fuse_duplicate_users(self) -> None: - self.fuse_duplicate_users = FuseDuplicateUsersConfig.DISABLED def is_default(self) -> bool: - return ( - self.softmax is SoftmaxDecompositionConfig.MASKED - and self.fuse_duplicate_users is FuseDuplicateUsersConfig.ENABLED - ) + return self.softmax is SoftmaxDecompositionConfig.MASKED def to_dict(self) -> dict[str, str]: return {f.name: getattr(self, f.name).name for f in fields(self)} diff --git a/backends/arm/test/misc/test_compile_spec.py b/backends/arm/test/misc/test_compile_spec.py index d9c24cd796a..f29b8851208 100644 --- a/backends/arm/test/misc/test_compile_spec.py +++ b/backends/arm/test/misc/test_compile_spec.py @@ -5,10 +5,7 @@ import warnings -from executorch.backends.arm.common.pipeline_config import ( - FuseDuplicateUsersConfig, - SoftmaxDecompositionConfig, -) +from executorch.backends.arm.common.pipeline_config import SoftmaxDecompositionConfig from executorch.backends.arm.ethosu import EthosUCompileSpec from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec from executorch.backends.arm.vgf import VgfCompileSpec @@ -66,11 +63,11 @@ def test_compile_spec_vgf_no_quant(): EthosUCompileSpec._from_list(spec_list) -def test_compile_spec_vgf_defaults_to_enabled_fuse_duplicate_users(): +def test_compile_spec_vgf_uses_default_pipeline_config(): compile_spec = VgfCompileSpec() pipeline_config = compile_spec._get_pass_pipeline_config() - assert pipeline_config.fuse_duplicate_users == FuseDuplicateUsersConfig.ENABLED + assert pipeline_config.is_default() def test_compile_spec_tosa_INT(): diff --git a/backends/arm/test/misc/test_pass_pipeline_config.py b/backends/arm/test/misc/test_pass_pipeline_config.py index 84575fb04fa..2f737b65d4a 100644 --- a/backends/arm/test/misc/test_pass_pipeline_config.py +++ b/backends/arm/test/misc/test_pass_pipeline_config.py @@ -29,14 +29,13 @@ def test_pipeline_config_override_outside_compile_spec(): override_compile_spec = TosaCompileSpec( TosaSpecification.create_from_string("TOSA-1.00+INT") ) - override_config = ArmPassPipelineConfig() - override_config.disable_fuse_duplicate_users() + override_config = ArmPassPipelineConfig(softmax=SoftmaxDecompositionConfig.STABLE) override_compile_spec.set_pass_pipeline_config(override_config) override_manager = ArmPassManager(override_compile_spec) skip_passes = override_manager._skip_pass_types - assert FuseDuplicateUsersPass in skip_passes - assert DecomposeSoftmaxPass not in skip_passes + assert FuseDuplicateUsersPass not in skip_passes + assert DecomposeMaskedFillPass in skip_passes def test_softmax_config_masked_no_target(): diff --git a/backends/arm/test/misc/test_tosa_dialect_conv2d.py b/backends/arm/test/misc/test_tosa_dialect_conv2d.py index 6f481d15ffd..9c0fa72b094 100644 --- a/backends/arm/test/misc/test_tosa_dialect_conv2d.py +++ b/backends/arm/test/misc/test_tosa_dialect_conv2d.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. -import executorch.backends.arm.tosa.dialect # noqa: unused +import executorch.backends.arm.tosa.dialect # noqa: F401 import pytest import torch from executorch.backends.arm.tosa.dialect.lib import TosaValueError diff --git a/backends/arm/test/misc/test_tosa_dialect_dw_conv2d.py b/backends/arm/test/misc/test_tosa_dialect_dw_conv2d.py index c680f1bd7e3..56d2d6bc69a 100644 --- a/backends/arm/test/misc/test_tosa_dialect_dw_conv2d.py +++ b/backends/arm/test/misc/test_tosa_dialect_dw_conv2d.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. -import executorch.backends.arm.tosa.dialect # noqa: unused +import executorch.backends.arm.tosa.dialect # noqa: F401 import pytest import torch from executorch.backends.arm.tosa.dialect.lib import TosaValueError diff --git a/backends/arm/tosa/compile_spec.py b/backends/arm/tosa/compile_spec.py index 2f4da1c85ac..d4dda6bec6d 100644 --- a/backends/arm/tosa/compile_spec.py +++ b/backends/arm/tosa/compile_spec.py @@ -4,9 +4,6 @@ # LICENSE file in the root directory of this source tree. from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec -from executorch.backends.arm.common.pipeline_config import ( # noqa: unused - ArmPassPipelineConfig, -) from executorch.backends.arm.tosa import TosaSpecification diff --git a/backends/arm/vgf/compile_spec.py b/backends/arm/vgf/compile_spec.py index fabf0ea19c0..b53a1e2f27b 100644 --- a/backends/arm/vgf/compile_spec.py +++ b/backends/arm/vgf/compile_spec.py @@ -6,9 +6,6 @@ import logging from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec -from executorch.backends.arm.common.pipeline_config import ( # noqa: unused - ArmPassPipelineConfig, -) from executorch.backends.arm.tosa import ( # type: ignore[import-not-found] TosaSpecification, )