Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/cortex_m/passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def _ensure_cortex_m_dependencies() -> None:
from .activation_fusion_pass import ActivationFusionPass # noqa
from .clamp_hardswish_pass import ClampHardswishPass # noqa
from .convert_to_cortex_m_pass import ConvertToCortexMPass # noqa
from .cortex_m_configuration import CortexMConfiguration # noqa
from .decompose_hardswish_pass import DecomposeHardswishPass # noqa
from .decompose_mean_pass import DecomposeMeanPass # noqa
from .quantized_clamp_activation_pass import QuantizedClampActivationPass # noqa
Expand Down
37 changes: 37 additions & 0 deletions backends/cortex_m/passes/cortex_m_configuration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from enum import auto, Enum
from typing import ClassVar, Mapping

import cmsis_nn # type: ignore[import-not-found, import-untyped]


class CortexMConfiguration(Enum):
M0 = auto()
M0PLUS = auto()
M3 = auto()
M4 = auto()
M7 = auto()
M23 = auto()
M33 = auto()
M35P = auto()
M55 = auto()
M85 = auto()
ANY = auto() # Guaranteed to work on any Cortex-M.
__members__: ClassVar[Mapping[str, "CortexMConfiguration"]]

@property
def backend(self) -> cmsis_nn.Backend:
if self == CortexMConfiguration.ANY:
# Currently, MVE is all we support. We can just return the MVE backend.
return cmsis_nn.Backend.MVE

cmsis_nn_cortex_m = cmsis_nn.CortexM.__members__.get(self.name, None)
if cmsis_nn_cortex_m is None:
raise ValueError(
f"CortexM configuration {self.name} is not supported by cmsis_nn."
)
return cmsis_nn.resolve_backend(cmsis_nn_cortex_m)
32 changes: 32 additions & 0 deletions backends/cortex_m/passes/cortex_m_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from executorch.exir.pass_base import ExportPass
from torch.export import ExportedProgram

from .cortex_m_configuration import CortexMConfiguration


class CortexMPass(ExportPass):
"""
An abstract interface for CortexM backend passes.
"""

def __init__(
self, exported_program: ExportedProgram, cortex_m_config: CortexMConfiguration
) -> None:
super().__init__()
self._exported_program = exported_program
self._cortex_m_config = cortex_m_config

@property
def exported_program(self) -> ExportedProgram:
return self._exported_program

@property
def cortex_m_config(self) -> CortexMConfiguration:
return self._cortex_m_config
51 changes: 37 additions & 14 deletions backends/cortex_m/passes/cortex_m_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@


import inspect
from typing import Callable, cast, Optional, Type
from typing import Any, Optional, Type

import cmsis_nn # type: ignore[import-not-found, import-untyped]

from executorch.backends.arm._passes import (
FoldAndAnnotateQParamsPass,
Expand All @@ -19,13 +21,11 @@
from executorch.exir.pass_manager import PassManager
from executorch.exir.program._program import _transform, lift_constant_tensor_pass
from torch.export import ExportedProgram
from torch.fx.passes.infra.pass_base import PassResult

from torch.nn import Module

from .activation_fusion_pass import ActivationFusionPass
from .clamp_hardswish_pass import ClampHardswishPass
from .convert_to_cortex_m_pass import ConvertToCortexMPass
from .cortex_m_configuration import CortexMConfiguration
from .decompose_hardswish_pass import DecomposeHardswishPass
from .decompose_mean_pass import DecomposeMeanPass
from .quantized_clamp_activation_pass import QuantizedClampActivationPass
Expand Down Expand Up @@ -57,34 +57,57 @@ class CortexMPassManager(PassManager):
]

def __init__(
self, exported_program, passes: Optional[list[PassClass]] = None
self,
exported_program: ExportedProgram | None,
passes: Optional[list[PassClass]] = None,
cortex_m: CortexMConfiguration = CortexMConfiguration.ANY,
) -> None:
super().__init__(passes=[])
self.exported_program = exported_program
self.cortex_m_config = cortex_m
if self.cortex_m_config.backend != cmsis_nn.Backend.MVE:
raise NotImplementedError(
Comment thread
Erik-Lundell marked this conversation as resolved.
"Currently, the Cortex-M pass manager only supports MVE."
f"Got {self.cortex_m_config.name} with {self.cortex_m_config.backend.name}"
)

# PassManager.passes is typed as callables; this manager stores pass classes which are initialized at transform time with the exported_program.
self.passes: list[PassClass] = ( # type: ignore[assignment]
passes if passes is not None else self.pass_list # type: ignore[assignment]
)

def transform_for_annotation(self, model):

passes = self.pass_list_transform_for_annotation
for p in passes:
model = p().call(model).graph_module
return model

def transform(self) -> ExportedProgram:
ep = self.exported_program
exported_program = self.exported_program
if not isinstance(exported_program, ExportedProgram):
raise ValueError(
f"{self.__class__.__name__} needs an exported_program to run transform, got {exported_program=}"
)

for pass_cls in self.passes:
if not isinstance(pass_cls, type):
raise ValueError(
f"{self.__class__.__name__} can't have instansiated passes in pass list, got {pass_cls}."
)

signature = inspect.signature(pass_cls)
kwargs: dict[str, Any] = {}
if "exported_program" in signature.parameters:
ep_pass_ctor = cast(Callable[[ExportedProgram], ExportPass], pass_cls)
transform_pass = ep_pass_ctor(ep)
else:
transform_pass = pass_cls()
pass_callable = cast(Callable[[Module], PassResult], transform_pass)
ep = _transform(ep, pass_callable)
kwargs["exported_program"] = exported_program
if "cortex_m_config" in signature.parameters:
kwargs["cortex_m_config"] = self.cortex_m_config

transform_pass = pass_cls(**kwargs)
exported_program = _transform(exported_program, transform_pass)

# All constant tensors should be lifted to buffers at this point, re-run
# lift_constant_tensor_pass in case new ones have been introduced by the passes above.
ep = lift_constant_tensor_pass(ep)
return ep
exported_program = lift_constant_tensor_pass(exported_program)

return exported_program
Loading