diff --git a/.ci/docker/build.sh b/.ci/docker/build.sh index 123680e5275..673b5b4fd4b 100755 --- a/.ci/docker/build.sh +++ b/.ci/docker/build.sh @@ -89,6 +89,11 @@ case "${IMAGE_NAME}" in OS_VERSION=24.04 GCC_VERSION=14 ;; + executorch-ubuntu-26.04-gcc15) + LINTRUNNER="" + OS_VERSION=26.04 + GCC_VERSION=15 + ;; *) echo "Invalid image name ${IMAGE_NAME}" exit 1 diff --git a/.ci/docker/common/install_docs_reqs.sh b/.ci/docker/common/install_docs_reqs.sh index 3b6d10c5c2b..ea54d90523e 100755 --- a/.ci/docker/common/install_docs_reqs.sh +++ b/.ci/docker/common/install_docs_reqs.sh @@ -15,8 +15,8 @@ if [ -n "$BUILD_DOCS" ]; then curl --retry 3 --retry-all-errors -sL https://deb.nodesource.com/setup_16.x | sudo -E bash - sudo apt-get install -y nodejs - curl --retry 3 --retry-all-errors -sS https://dl.yarnpkg.com/debian/pubkey.gpg | sudo apt-key add - - echo "deb https://dl.yarnpkg.com/debian/ stable main" | sudo tee /etc/apt/sources.list.d/yarn.list + curl --retry 3 --retry-all-errors -sS https://dl.yarnpkg.com/debian/pubkey.gpg | sudo gpg --dearmor -o /usr/share/keyrings/yarn-archive-keyring.gpg + echo "deb [signed-by=/usr/share/keyrings/yarn-archive-keyring.gpg] https://dl.yarnpkg.com/debian/ stable main" | sudo tee /etc/apt/sources.list.d/yarn.list apt-get update apt-get install -y --no-install-recommends yarn diff --git a/.ci/docker/common/install_linter.sh b/.ci/docker/common/install_linter.sh index 52d2d262685..4a796a72d54 100755 --- a/.ci/docker/common/install_linter.sh +++ b/.ci/docker/common/install_linter.sh @@ -13,7 +13,3 @@ source "$(dirname "${BASH_SOURCE[0]}")/utils.sh" # NB: Install all linter dependencies, the caching of lintrunner init could be # done after Executorch becomes public pip_install -r requirements-lintrunner.txt - -# Install google-java-format -curl -L --retry 3 --retry-all-errors https://github.com/google/google-java-format/releases/download/v1.23.0/google-java-format_linux-x86-64 > /opt/google-java-format -chmod +x /opt/google-java-format diff --git a/.ci/scripts/test_riscv_qemu.sh b/.ci/scripts/test_riscv_qemu.sh index 2842542aa3a..0e5b44d97c2 100755 --- a/.ci/scripts/test_riscv_qemu.sh +++ b/.ci/scripts/test_riscv_qemu.sh @@ -4,10 +4,9 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# CI wrapper: install RISC-V cross-compile + qemu-user tooling, then run the -# RISC-V smoke test (export, cross-compile, qemu-user execution) via -# examples/riscv/run.sh. The bundled-IO comparison and Test_result: PASS -# check are done by run.sh. +# CI wrapper: install riscv32/64 cross-compile + qemu tooling, then drive +# examples/riscv/run.sh which does the export, cross-compile, qemu run, and +# bundled-IO PASS check. set -eu @@ -15,29 +14,41 @@ script_dir=$(realpath "$(dirname "${BASH_SOURCE[0]}")") et_root_dir=$(realpath "${script_dir}/../..") model="add" -xnnpack=false +backend="portable" quantize=false +os="linux" +arch="rv64" +qemu_cpu_ext="" verbose_xnnpack=false debug_xnnpack=false +build_dir= usage() { cat < Which model to export and run (default: add) - --xnnpack Enable the XNNPACK backend (AOT partitioner + runtime) - --quantize Produce an 8-bit quantized model - --verbose-xnnpack Build XNNPACK with XNN_LOG_LEVEL=4 to log microkernel dispatch - --debug-xnnpack Enable XNNPACK partitioner DEBUG logging and dump the lowered graph - -h, --help Show this help + --model= Which model to export and run (default: ${model}) + --quantize Produce an 8-bit quantized model + --backend= AOT backend (portable|xnnpack) (default: ${backend}) + --os= Target OS (linux|baremetal) (default: ${os}) + --arch= Target arch (rv32|rv64) (default: ${arch}) + --qemu-cpu-ext= QEMU -cpu extensions (no rv32/rv64 prefix, default: none) + --build-dir= Build/output directory for this configuration (required) + --verbose-xnnpack Build XNNPACK with XNN_LOG_LEVEL=4 to log microkernel dispatch + --debug-xnnpack Enable XNNPACK partitioner DEBUG logging and dump the lowered graph + -h, --help Show this help EOF } for arg in "$@"; do case $arg in --model=*) model="${arg#*=}" ;; - --xnnpack) xnnpack=true ;; --quantize) quantize=true ;; + --backend=*) backend="${arg#*=}" ;; + --os=*) os="${arg#*=}" ;; + --arch=*) arch="${arg#*=}" ;; + --qemu-cpu-ext=*) qemu_cpu_ext="${arg#*=}" ;; + --build-dir=*) build_dir="${arg#*=}" ;; --debug-xnnpack) debug_xnnpack=true ;; --verbose-xnnpack) verbose_xnnpack=true ;; -h|--help) usage; exit 0 ;; @@ -45,9 +56,13 @@ for arg in "$@"; do esac done +if [[ -z "${build_dir}" ]]; then + echo "[test_riscv_qemu.sh] --build-dir is required" >&2; usage; exit 1 +fi + run_extra_args=() -if ${xnnpack}; then - run_extra_args+=(--xnnpack) +if [ -n "${qemu_cpu_ext}" ]; then + run_extra_args+=(--qemu-cpu-ext="${qemu_cpu_ext}") fi if ${quantize}; then run_extra_args+=(--quantize) @@ -59,5 +74,8 @@ if ${verbose_xnnpack}; then run_extra_args+=(--verbose-xnnpack) fi -bash "${et_root_dir}/examples/riscv/setup.sh" -bash "${et_root_dir}/examples/riscv/run.sh" --model="${model}" "${run_extra_args[@]}" +bash "${et_root_dir}/examples/riscv/setup-${os}.sh" +bash "${et_root_dir}/examples/riscv/run.sh" \ + --model="${model}" --backend="${backend}" --os="${os}" --arch="${arch}" \ + --build-dir="${build_dir}" \ + "${run_extra_args[@]}" diff --git a/.ci/scripts/unittest-macos-cmake.sh b/.ci/scripts/unittest-macos-cmake.sh index 43eb1f21c3c..48f072a0cc1 100755 --- a/.ci/scripts/unittest-macos-cmake.sh +++ b/.ci/scripts/unittest-macos-cmake.sh @@ -12,8 +12,19 @@ set -eux export TORCHINDUCTOR_CACHE_DIR="$(mktemp -d "${RUNNER_TEMP:-/tmp}/torchinductor_cache_XXXXXX")" trap 'rm -rf "${TORCHINDUCTOR_CACHE_DIR}"' EXIT -# Run pytest with coverage -${CONDA_RUN} pytest -n auto --cov=./ --cov-report=xml +# TODO(SS-JIA): AOTI tests hang on macOS CI runners — the thread blocks in +# native C/C++ code (dlopen / inductor compilation) so faulthandler cannot +# even produce a traceback. Diagnosis ongoing in #19886. +AOTI_SKIPS=( + --ignore=examples/models/llama3_2_vision/preprocess/test_preprocess.py + --ignore=examples/models/llama3_2_vision/vision_encoder/test/test_vision_encoder.py + --ignore=examples/models/llama3_2_vision/text_decoder/test/test_text_decoder.py + --deselect=extension/llm/modules/test/test_position_embeddings.py::TilePositionalEmbeddingTest::test_tile_positional_embedding_aoti + --deselect=extension/llm/modules/test/test_position_embeddings.py::TiledTokenPositionalEmbeddingTest::test_tiled_token_positional_embedding_aoti + --deselect=extension/llm/modules/test/test_attention.py::AttentionTest::test_attention_aoti +) + +${CONDA_RUN} pytest -n auto --cov=./ --cov-report=xml "${AOTI_SKIPS[@]}" # Run gtest LLVM_PROFDATA="xcrun llvm-profdata" LLVM_COV="xcrun llvm-cov" \ ${CONDA_RUN} test/run_oss_cpp_tests.sh diff --git a/.github/scripts/propose_ghstack_orig_pr.py b/.github/scripts/propose_ghstack_orig_pr.py index 3abcc6cdcf9..f41e03f18ff 100644 --- a/.github/scripts/propose_ghstack_orig_pr.py +++ b/.github/scripts/propose_ghstack_orig_pr.py @@ -52,12 +52,9 @@ def extract_stack_from_body(pr_body: str) -> List[int]: """ prs = [] - ghstack_begin = ( - "Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):" - ) ghstack_begin_seen = False for line in pr_body.splitlines(): - if ghstack_begin in line: + if line.startswith("Stack from [ghstack]"): ghstack_begin_seen = True if not ghstack_begin_seen: continue diff --git a/.github/workflows/_test_riscv.yml b/.github/workflows/_test_riscv.yml index 223a146e3d8..0b7d8472d8b 100644 --- a/.github/workflows/_test_riscv.yml +++ b/.github/workflows/_test_riscv.yml @@ -13,35 +13,44 @@ on: type: number default: 30 model: - description: 'Which model to run. Possible values are: add, mv2 (mobilenetv2)' + description: 'Which model to run (add, mv2, mobilebert, llama2, resnet18, yolo26)' required: false type: string default: 'add' - xnnpack: - description: 'Whether to enable XNNPACK' - required: false - type: boolean - default: false quantize: description: 'Produce an 8-bit quantized model' required: false type: boolean default: false - qemu-cpu: - description: 'Configuration(s) for the CPU to emulate with QEMU, expecting a JSON array' - required: true + backend: + description: 'AOT backend to lower to (portable|xnnpack)' + required: false type: string - docker-image: - description: 'The docker image to use for this job' + default: 'portable' + os: + description: 'Target OS for the runner (linux|baremetal)' required: false type: string + default: 'linux' + arch: + description: 'Target architecture (rv32|rv64)' + required: false + type: string + default: 'rv64' + qemu-cpu-ext: + description: >- + JSON array of QEMU -cpu *extension* strings (no rv32/rv64 prefix). + The script splices each entry with `arch` to form the final -cpu + value. Use [""] for plain base-ISA runs. + required: true + type: string jobs: run: uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main with: runner: linux.2xlarge - docker-image: ci-image:executorch-ubuntu-24.04-gcc14 + docker-image: ${{ inputs.os == 'linux' && 'ci-image:executorch-ubuntu-24.04-gcc14' || 'ci-image:executorch-ubuntu-26.04-gcc15' }} submodules: 'recursive' ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} timeout: ${{ inputs.timeout }} @@ -55,20 +64,26 @@ jobs: # Allows failure in `echo | jq | while read` pipeline to bubble up and fail the workflow set -o pipefail - echo '${{ inputs.qemu-cpu }}' | jq -r '.[]' | while IFS= read -r qemu_cpu; do - export QEMU_CPU="${qemu_cpu}" - export GCC_VERSION=14 + echo '${{ inputs.qemu-cpu-ext }}' | jq -r '.[]' | while IFS= read -r qemu_cpu_ext; do + variant_slug="${qemu_cpu_ext//,/_}"; variant_slug="${variant_slug//=/_}"; variant_slug="${variant_slug:-base}" + build_dir="riscv_test/${{ inputs.model }}${{ inputs.quantize && '_q' || '' }}/${{ inputs.backend }}/${{ inputs.os }}-${{ inputs.arch }}-${variant_slug}" + bash .ci/scripts/test_riscv_qemu.sh \ --model="${{ inputs.model }}" \ - ${{ inputs.xnnpack && '--xnnpack --verbose-xnnpack' || '' }} \ + --backend="${{ inputs.backend }}" \ + --os="${{ inputs.os }}" \ + --arch="${{ inputs.arch }}" \ + --qemu-cpu-ext="${qemu_cpu_ext}" \ + --build-dir="${build_dir}" \ + ${{ inputs.backend == 'xnnpack' && '--verbose-xnnpack' || '' }} \ ${{ inputs.quantize && '--quantize' || '' }} - # We only generate riscv_test/${{ inputs.model }}_riscv.etdump.json from `--verbose-xnnpack`. - if ${{ inputs.xnnpack }}; then - # Generate markdown table from riscv_test/${{ inputs.model }}_riscv.etdump.json, sorted by sum_ms + # We only generate run.etdump.json from `--verbose-xnnpack`. + if [[ "${{ inputs.backend }}" == "xnnpack" ]]; then + # Generate markdown table from ${build_dir}/run.etdump.json, sorted by sum_ms ( - etdump_json="riscv_test/${{ inputs.model }}_riscv.etdump.json" - echo "### Model=${{ inputs.model }} XNNPACK=${{ inputs.xnnpack }} Quantize=${{ inputs.quantize }} QEMU_CPU='${QEMU_CPU}'" + etdump_json="${build_dir}/run.etdump.json" + echo "### Model=${{ inputs.model }} Quantize=${{ inputs.quantize }} Backend=${{ inputs.backend }} OS=${{ inputs.os }} Arch=${{ inputs.arch }}${qemu_cpu_ext:+,${qemu_cpu_ext}}" jq -r ' def r3: (. * 1000 | round) / 1000; ["Section","Op","Count","Sum (ms)","Avg (ms)","Max (ms)","Microkernels"], diff --git a/.github/workflows/_unittest.yml b/.github/workflows/_unittest.yml index 15c87bd79e4..a253857d2c0 100644 --- a/.github/workflows/_unittest.yml +++ b/.github/workflows/_unittest.yml @@ -49,6 +49,7 @@ jobs: python-version: '3.11' submodules: 'recursive' ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + timeout: 120 script: | set -eux # This is needed to get the prebuilt PyTorch wheel from S3 diff --git a/.github/workflows/docker-builds.yml b/.github/workflows/docker-builds.yml index b77e5497f79..d11b2e9e6d9 100644 --- a/.github/workflows/docker-builds.yml +++ b/.github/workflows/docker-builds.yml @@ -43,6 +43,7 @@ jobs: executorch-ubuntu-22.04-mediatek-sdk, executorch-ubuntu-22.04-clang12-android, executorch-ubuntu-24.04-gcc14, + executorch-ubuntu-26.04-gcc15, ] include: - docker-image-name: executorch-ubuntu-22.04-gcc11-aarch64 diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index b26247d2333..b21cc527b8d 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -125,49 +125,3 @@ jobs: uses: ./.github/workflows/_link_check.yml with: ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} - - android-java-format: - runs-on: ubuntu-latest - permissions: - contents: read - steps: - - uses: actions/checkout@v4 - with: - ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} - - - uses: actions/setup-java@v4 - with: - distribution: 'temurin' - java-version: '17' - - - name: Check Java formatting - run: | - GOOGLE_JAVA_FORMAT_VERSION="1.24.0" - curl -sSfL "https://github.com/google/google-java-format/releases/download/v${GOOGLE_JAVA_FORMAT_VERSION}/google-java-format-${GOOGLE_JAVA_FORMAT_VERSION}-all-deps.jar" \ - -o /tmp/google-java-format.jar - - FILES_NEEDS_FORMAT=$(find extension/android/executorch_android/src/main/java/org/pytorch/executorch \ - extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm \ - extension/android/executorch_android/src/main/java/org/pytorch/executorch/annotations \ - extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch \ - extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench \ - extension/benchmark/android/benchmark/app/src/androidTest/java/org/pytorch/minibench \ - -type f -name "*.java" 2>/dev/null | \ - xargs -r java -jar /tmp/google-java-format.jar -n) - - if [ -n "$FILES_NEEDS_FORMAT" ]; then - echo "Warning: The following files need formatting:" - echo "$FILES_NEEDS_FORMAT" - echo "" - echo "Please use google-java-format from https://github.com/google/google-java-format/releases/" - echo "" - echo "To fix, run one of these commands:" - echo " # Using xargs (recommended):" - echo " find -type f -name '*.java' | xargs google-java-format -i" - echo "" - echo " # Or format specific files:" - echo "$FILES_NEEDS_FORMAT" | while IFS= read -r file; do - echo " google-java-format -i \"$file\"" - done - exit 1 - fi diff --git a/.github/workflows/mlx.yml b/.github/workflows/mlx.yml index c4be146f862..c51f126dbe6 100644 --- a/.github/workflows/mlx.yml +++ b/.github/workflows/mlx.yml @@ -47,6 +47,10 @@ jobs: ${CONDA_RUN} pip list + echo "::group::Install Python test requirements" + ${CONDA_RUN} pip install gguf + echo "::endgroup::" + echo "::group::Build test runners" ${CONDA_RUN} cmake --build cmake-out --target op_test_runner multi_thread_test_runner -j$(( $(sysctl -n hw.ncpu) - 1 )) echo "::endgroup::" @@ -76,6 +80,18 @@ jobs: ${CONDA_RUN} python -m executorch.backends.mlx.model_ops.test_gated_delta_rule run -v echo "::endgroup::" + echo "::group::Run tq_norm op tests" + ${CONDA_RUN} python -m executorch.backends.mlx.model_ops.test_tq_norm run -v + echo "::endgroup::" + + echo "::group::Run tq4_compress op tests" + ${CONDA_RUN} python -m executorch.backends.mlx.model_ops.test_tq4_compress run -v + echo "::endgroup::" + + echo "::group::Run tq_dequant op tests" + ${CONDA_RUN} python -m executorch.backends.mlx.model_ops.test_tq_dequant run -v + echo "::endgroup::" + test-mlx-qwen35-moe: uses: pytorch/test-infra/.github/workflows/macos_job.yml@main with: diff --git a/.github/workflows/riscv64.yml b/.github/workflows/riscv64.yml index 14b9ad62047..f2010b86fe5 100644 --- a/.github/workflows/riscv64.yml +++ b/.github/workflows/riscv64.yml @@ -10,8 +10,9 @@ on: pull_request: paths: - .github/workflows/riscv64.yml + - .github/workflows/_test_riscv.yml - .ci/scripts/test_riscv_qemu.sh - - tools/cmake/preset/riscv64_linux.cmake + - tools/cmake/preset/riscv_*.cmake - examples/riscv/** workflow_dispatch: schedule: @@ -28,39 +29,52 @@ jobs: strategy: fail-fast: false matrix: - include: - - { model: add, xnnpack: false, quantize: false } - - { model: add, xnnpack: true, quantize: false } - - { model: mv2, xnnpack: false, quantize: false } - - { model: mv2, xnnpack: true, quantize: false } - - { model: mv2, xnnpack: true, quantize: true } - - { model: mobilebert, xnnpack: false, quantize: false } - - { model: mobilebert, xnnpack: true, quantize: false } - - { model: mobilebert, xnnpack: true, quantize: true } - - { model: llama2, xnnpack: false, quantize: false } - - { model: llama2, xnnpack: true, quantize: false } - - { model: llama2, xnnpack: true, quantize: true } - - { model: resnet18, xnnpack: false, quantize: false } - - { model: resnet18, xnnpack: true, quantize: false } - - { model: resnet18, xnnpack: true, quantize: true } + model: + - add + - mv2 + - mobilebert + - llama2 + - resnet18 + - yolo26 + quantize: [true, false] + backend: [portable, xnnpack] + os: [linux, baremetal] + arch: [rv64, rv32] + exclude: + # Disable quantization testing with Portable Kernels + - { backend: portable, quantize: true } + # XNNPACK needs pthreads + dynamic loading (no baremetal) + - { backend: xnnpack, os: baremetal } + # No quantization recipe for Yolo26. + - { model: yolo26, quantize: true } + # No riscv32-linux-gnu cross is packaged on Ubuntu. + - { os: linux, arch: rv32 } permissions: id-token: write contents: read with: model: ${{ matrix.model }} - xnnpack: ${{ matrix.xnnpack }} quantize: ${{ matrix.quantize }} - # If XNNPACK, test with multiple RVV length, disabled otherwise - qemu-cpu: >- + backend: ${{ matrix.backend }} + os: ${{ matrix.os }} + arch: ${{ matrix.arch }} + # JSON array of QEMU -cpu *extension* strings (no rv32/rv64 prefix - that + # comes from `arch`). The script splices them as `,`. xnnpack + # benefits from RVV so it sweeps multiple vlen; everything else just uses + # the plain base ISA. + qemu-cpu-ext: >- ${{ case( - matrix.xnnpack, '[ - "rv64,zba=true,zbb=true,zbs=true,v=true,vlen=128,elen=64,vext_spec=v1.0", - "rv64,zba=true,zbb=true,zbs=true,v=true,vlen=256,elen=64,vext_spec=v1.0", - "rv64,zba=true,zbb=true,zbs=true,v=true,vlen=512,elen=64,vext_spec=v1.0" + matrix.backend == 'xnnpack', '[ + "v=true,vext_spec=v1.0,vlen=128", + "v=true,vext_spec=v1.0,vlen=256", + "v=true,vext_spec=v1.0,vlen=512" ]', '[ - "rv64,zba=true,zbb=true,zbs=true,v=false" + "v=false", + "v=true,vext_spec=v1.0,vlen=128", + "v=true,vext_spec=v1.0,vlen=256", + "v=true,vext_spec=v1.0,vlen=512" ]' ) }} diff --git a/.github/workflows/trunk.yml b/.github/workflows/trunk.yml index 5a6720cdfad..cca1fe5fe45 100644 --- a/.github/workflows/trunk.yml +++ b/.github/workflows/trunk.yml @@ -278,6 +278,7 @@ jobs: matrix: include: - test_arm_backend: test_pytest_ops_vkml + - test_arm_backend: test_pytest_models_vkml - test_arm_backend: test_ootb_tests_vgf fail-fast: false with: diff --git a/.gitmodules b/.gitmodules index 917e755da27..0f4d09aa998 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,6 +1,3 @@ -[submodule "backends/arm/third-party/ethos-u-core-driver"] - path = backends/arm/third-party/ethos-u-core-driver - url = https://git.gitlab.arm.com/artificial-intelligence/ethos-u/ethos-u-core-driver.git [submodule "backends/vulkan/third-party/Vulkan-Headers"] path = backends/vulkan/third-party/Vulkan-Headers url = https://github.com/KhronosGroup/Vulkan-Headers diff --git a/.lintrunner.toml b/.lintrunner.toml index 3ee436f61e8..75608704110 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -112,6 +112,8 @@ include_patterns = [ 'backends/arm/**/*.cpp', 'backends/arm/**/*.h', 'backends/arm/**/*.hpp', + 'backends/cortex_m/**/*.cpp', + 'backends/cortex_m/**/*.h', 'examples/arm/**/*.cpp', 'examples/arm/**/*.h', 'examples/arm/**/*.hpp', @@ -132,6 +134,8 @@ command = [ '--extra-arg=--inconclusive', '--extra-arg=--suppress=unusedStructMember', '--extra-arg=--suppress=toomanyconfigs', + '--extra-arg=--suppress=unusedFunction:*.h', + '--extra-arg=--suppress=unusedFunction:*.hpp', '--', '@{{PATHSFILE}}' ] diff --git a/CMakePresets.json b/CMakePresets.json index 91848565067..15d005cbede 100644 --- a/CMakePresets.json +++ b/CMakePresets.json @@ -318,7 +318,7 @@ "displayName": "Build ExecuTorch for riscv64 Linux (cross-compile)", "inherits": ["common"], "cacheVariables": { - "EXECUTORCH_BUILD_PRESET_FILE": "${sourceDir}/tools/cmake/preset/riscv64_linux.cmake", + "EXECUTORCH_BUILD_PRESET_FILE": "${sourceDir}/tools/cmake/preset/riscv_linux.cmake", "CMAKE_TOOLCHAIN_FILE": "${sourceDir}/examples/riscv/riscv64-linux-gnu-toolchain.cmake" }, "condition": { @@ -327,6 +327,24 @@ "rhs": "Linux" } }, + { + "name": "riscv64-baremetal", + "displayName": "Build ExecuTorch for riscv64 baremetal (cross-compile)", + "inherits": ["common"], + "cacheVariables": { + "EXECUTORCH_BUILD_PRESET_FILE": "${sourceDir}/tools/cmake/preset/riscv_baremetal.cmake", + "CMAKE_TOOLCHAIN_FILE": "${sourceDir}/examples/riscv/riscv64-unknown-elf-toolchain.cmake" + } + }, + { + "name": "riscv32-baremetal", + "displayName": "Build ExecuTorch for riscv32 baremetal (cross-compile)", + "inherits": ["common"], + "cacheVariables": { + "EXECUTORCH_BUILD_PRESET_FILE": "${sourceDir}/tools/cmake/preset/riscv_baremetal.cmake", + "CMAKE_TOOLCHAIN_FILE": "${sourceDir}/examples/riscv/riscv32-unknown-elf-toolchain.cmake" + } + }, { "name": "mlx", "displayName": "Build MLX delegate", diff --git a/backends/aoti/slim/core/storage.h b/backends/aoti/slim/core/storage.h index 73c4d32d955..a3d17a89903 100644 --- a/backends/aoti/slim/core/storage.h +++ b/backends/aoti/slim/core/storage.h @@ -13,6 +13,7 @@ #ifdef CUDA_AVAILABLE #include #include +#include #endif #include @@ -107,9 +108,6 @@ struct DeviceTraits { /// @param device The target CUDA device (used to get the stream). /// @return Pointer to allocated device memory. static void* allocate(size_t nbytes, const c10::Device& device) { - // Get the current stream for this device (set by CUDAStreamGuard if any) - // This follows PyTorch's pattern where the allocator assumes the caller - // has already set the correct device via CUDAStreamGuard. auto stream_result = executorch::backends::cuda::getCurrentCUDAStream(device.index()); ET_CHECK_MSG( @@ -118,31 +116,23 @@ struct DeviceTraits { static_cast(device.index())); cudaStream_t stream = stream_result.get(); - void* data = nullptr; - ET_CUDA_CHECK(cudaMallocAsync(&data, nbytes, stream)); - return data; + auto result = executorch::backends::cuda::CudaAllocator::allocate_async( + nbytes, device.index(), stream); + ET_CHECK_MSG( + result.ok(), + "CudaAllocator::allocate_async failed for %zu bytes on device %d", + nbytes, + static_cast(device.index())); + return result.get(); } - /// Frees CUDA device memory on the current stream. - /// @param ptr Pointer to device memory to free. static void free(void* ptr) { - // Get the current stream for the current device - // Currently all cuda slimtensors should be on the same device same stream, - // so we can just use the stream on current device. - // TODO(gasoonjia): add cuda stream as a member of MaybeOwningStorage to - // support multiple devices. auto stream_result = executorch::backends::cuda::getCurrentCUDAStream(-1); ET_CHECK_MSG(stream_result.ok(), "Failed to get current CUDA stream"); - ET_CUDA_LOG_WARN(cudaFreeAsync(ptr, stream_result.get())); + executorch::backends::cuda::CudaAllocator::deallocate_async( + ptr, -1, stream_result.get()); } - /// Copies memory between CPU and CUDA or CUDA and CUDA asynchronously. - /// @param dst Destination pointer. - /// @param src Source pointer. - /// @param nbytes Number of bytes to copy. - /// @param dst_device Destination device. - /// @param src_device Source device. - /// @param stream CUDA stream for async copy. static void memcpy_async( void* dst, const void* src, @@ -151,7 +141,6 @@ struct DeviceTraits { const c10::Device& src_device, cudaStream_t stream) { cudaMemcpyKind direction = cudaMemcpyDeviceToDevice; - if (src_device.is_cpu()) { direction = cudaMemcpyHostToDevice; } else if (dst_device.is_cpu()) { @@ -164,15 +153,11 @@ struct DeviceTraits { static_cast(dst_device.index())); } - ET_CUDA_CHECK(cudaMemcpyAsync(dst, src, nbytes, direction, stream)); + auto err = executorch::backends::cuda::CudaAllocator::memcpy_async( + dst, src, nbytes, direction, stream); + ET_CHECK_MSG(err == executorch::runtime::Error::Ok, "memcpy_async failed"); } - /// Copies memory between CPU and CUDA or CUDA and CUDA synchronously. - /// @param dst Destination pointer. - /// @param src Source pointer. - /// @param nbytes Number of bytes to copy. - /// @param dst_device Destination device. - /// @param src_device Source device. static void memcpy( void* dst, const void* src, @@ -180,7 +165,6 @@ struct DeviceTraits { const c10::Device& dst_device, const c10::Device& src_device) { cudaMemcpyKind direction = cudaMemcpyDeviceToDevice; - if (src_device.is_cpu()) { direction = cudaMemcpyHostToDevice; } else if (dst_device.is_cpu()) { diff --git a/backends/aoti/slim/core/targets.bzl b/backends/aoti/slim/core/targets.bzl index b9148305c91..42a7b79da6e 100644 --- a/backends/aoti/slim/core/targets.bzl +++ b/backends/aoti/slim/core/targets.bzl @@ -19,6 +19,7 @@ def define_common_targets(): "//executorch/runtime/platform:platform", "//executorch/backends/aoti/slim/c10/cuda:exception", "//executorch/backends/aoti/slim/cuda:guard", + "//executorch/backends/cuda/runtime:cuda_allocator", ], ) diff --git a/backends/apple/coreml/BUCK b/backends/apple/coreml/BUCK index 792adcf4d70..688ca64b990 100644 --- a/backends/apple/coreml/BUCK +++ b/backends/apple/coreml/BUCK @@ -171,6 +171,7 @@ runtime.cxx_library( "format/{}.pb.h".format(name): "fbsource//third-party/pypi/coremltools:exported-cpp-protoc[{}.pb.h]".format(name) for name in _PROTOS }, + header_namespace = "", compiler_flags = [ "-Wno-global-constructors", ], diff --git a/backends/arm/CMakeLists.txt b/backends/arm/CMakeLists.txt index d8a6c1afce7..726fcfcd0d3 100644 --- a/backends/arm/CMakeLists.txt +++ b/backends/arm/CMakeLists.txt @@ -39,6 +39,11 @@ set(ETHOSU_LINUX_DRIVER_SOURCE_DIR PATH "Optional local path to an existing ethos-u-linux-driver stack checkout" ) +set(ETHOS_SDK_PATH + "${EXECUTORCH_ROOT}/examples/arm/arm-scratch/ethos-u" + CACHE PATH "Path to Ethos-U bare metal driver/env" +) +option(FETCH_ETHOS_U_CONTENT "Fetch ethos_u dependencies" ON) if(EXECUTORCH_BUILD_ARM_BAREMETAL AND EXECUTORCH_BUILD_ARM_ETHOSU_LINUX) message( @@ -52,8 +57,6 @@ if(EXECUTORCH_BUILD_ARM_BAREMETAL OR EXECUTORCH_BUILD_ARM_ETHOSU_LINUX) add_compile_options("-Wall" "-Werror") - set(THIRD_PARTY_ROOT "${CMAKE_CURRENT_SOURCE_DIR}/third-party") - set(_arm_backend_sources backends/arm/runtime/EthosUBackend.cpp backends/arm/runtime/EthosUBackend_IoMemcpy.cpp @@ -72,11 +75,22 @@ if(EXECUTORCH_BUILD_ARM_BAREMETAL OR EXECUTORCH_BUILD_ARM_ETHOSU_LINUX) executorch_delegate_ethos_u PRIVATE ${EXECUTORCH_ROOT}/backends/arm/runtime/EthosUBackend_Cortex_M.cpp ) - set(_ethosu_core_driver_include - "${THIRD_PARTY_ROOT}/ethos-u-core-driver/include" + include(${EXECUTORCH_ROOT}/backends/arm/scripts/corstone_utils.cmake) + if(FETCH_ETHOS_U_CONTENT) + fetch_ethos_u_content(${ETHOS_SDK_PATH} ${EXECUTORCH_ROOT}) + endif() + set(DRIVER_ETHOSU_INCLUDE_DIR + "${ETHOS_SDK_PATH}/core_software/core_driver/include" ) + if(NOT EXISTS "${DRIVER_ETHOSU_INCLUDE_DIR}/ethosu_driver.h") + message( + FATAL_ERROR + "Ethos-U core driver headers were not found in ${DRIVER_ETHOSU_INCLUDE_DIR}." + " Run examples/arm/setup.sh or enable FETCH_ETHOS_U_CONTENT." + ) + endif() target_include_directories( - executorch_delegate_ethos_u PRIVATE ${_ethosu_core_driver_include} + executorch_delegate_ethos_u PRIVATE ${DRIVER_ETHOSU_INCLUDE_DIR} ) target_link_libraries(executorch_delegate_ethos_u PUBLIC ethosu_core_driver) elseif(EXECUTORCH_BUILD_ARM_ETHOSU_LINUX) diff --git a/backends/arm/README.md b/backends/arm/README.md index f822077e170..8edd3665d44 100644 --- a/backends/arm/README.md +++ b/backends/arm/README.md @@ -61,8 +61,6 @@ backends/arm/ │ ├── models/ # Model level unit tests │ └── tester/ # Testing harnesses and utilities │ -├── third-party/ # External dependencies -│ ├── tosa/ # Shared TOSA backend implementation and dialect │ └── vgf/ # Implementations of VgfPartitioner and VgfBackend @@ -138,8 +136,10 @@ The delegated Python API flow is: For complete examples of that flow, including quantization and target-specific compile specs, see: -- `docs/source/backends/arm-ethos-u/tutorials/ethos-u-getting-started.md` -- `docs/source/backends/arm-vgf/tutorials/vgf-getting-started.md` +- [Arm Ethos-U tutorial](../../docs/source/backends/arm-ethos-u/tutorials/ethos-u-getting-started.md) +- [Arm VGF tutorial](../../docs/source/backends/arm-vgf/tutorials/vgf-getting-started.md) +- [Arm Cortex-M backend overview](../../docs/source/backends/arm-cortex-m/arm-cortex-m-overview.md) +- [Ethos-U porting guide](../../examples/arm/ethos-u-porting-guide.md) Additional examples are available in `examples/arm`. diff --git a/backends/arm/TARGETS b/backends/arm/TARGETS index c3e2251bb11..a63237fe2c9 100644 --- a/backends/arm/TARGETS +++ b/backends/arm/TARGETS @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -15,6 +15,31 @@ runtime.python_library( "//executorch/exir/dialects:lib", ], ) +runtime.python_library( + name = "ao_ext", + srcs = glob([ + "ao_ext/*.py", + "ao_ext/ops/*.py", + ]), + deps = [ + "//caffe2:torch", + "//executorch/exir:_warnings", + "//pytorch/ao:torchao", + ], +) + +runtime.python_library( + name = "lib", + srcs = [ + "__init__.py", + ], + deps = [ + ":ao_ext", + ":ethosu", + ":vgf", + "//executorch/backends/arm/quantizer:lib", + ], +) runtime.python_library( name = "common", srcs = glob(["common/*.py"]), diff --git a/backends/arm/__init__.py b/backends/arm/__init__.py index fcbafa717ce..7c0b61457d0 100644 --- a/backends/arm/__init__.py +++ b/backends/arm/__init__.py @@ -14,6 +14,10 @@ import importlib from typing import Any +# Register Arm-specific torch.library ops and MXFP transforms at package +# import time. +import executorch.backends.arm.ao_ext # noqa: F401 + # Public for tooling (manifest generation and API validation). LAZY_IMPORTS = { "EthosUBackend": ("executorch.backends.arm.ethosu", "EthosUBackend"), @@ -32,6 +36,8 @@ "executorch.backends.arm.quantizer", "get_symmetric_a16w8_quantization_config", ), + "MXFPOpConfig": ("executorch.backends.arm.ao_ext.mxfp", "MXFPOpConfig"), + "to_mxfp": ("executorch.backends.arm.ao_ext.mxfp", "to_mxfp"), } diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py index 20bddf17793..3e881fdb9ef 100644 --- a/backends/arm/_passes/__init__.py +++ b/backends/arm/_passes/__init__.py @@ -5,7 +5,7 @@ from . import arm_pass_utils # noqa -from .arm_pass import ArmPass # noqa # usort: skip +from .arm_pass import ArmOpTargetedPass, ArmPass # noqa # usort: skip from .accumulate_index_put_pass import AccumulateIndexPutPass # noqa from .broadcast_args_pass import BroadcastArgsPass # noqa from .canonicalize_gather_pass import CanonicalizeGatherPass # noqa diff --git a/backends/arm/_passes/accumulate_index_put_pass.py b/backends/arm/_passes/accumulate_index_put_pass.py index 1194e08e2d8..9aa0457b0c7 100644 --- a/backends/arm/_passes/accumulate_index_put_pass.py +++ b/backends/arm/_passes/accumulate_index_put_pass.py @@ -6,7 +6,7 @@ import torch -from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes import ArmOpTargetedPass from executorch.backends.arm._passes.decompose_index_tensor_to_gather_pass import ( DecomposeIndexTensorToGatherPass, ) @@ -32,7 +32,7 @@ def get_ops(op): raise RuntimeError(f"Can't get index_put decomposition for op {op}") -class AccumulateIndexPutPass(ArmPass): +class AccumulateIndexPutPass(ArmOpTargetedPass): """This pass adjusts the values arg when the accumulate arg is set to true for the index_put op. """ @@ -41,9 +41,11 @@ class AccumulateIndexPutPass(ArmPass): DecomposeIndexTensorToGatherPass, RewriteIndexPutPass, } + target_ops = aten_ops + edge_ops + check_allowed_to_transform = True def call_operator(self, op, args, kwargs, meta): - if op not in (aten_ops + edge_ops) or not self.allowed_to_transform(meta): + if op not in self.target_ops or not self.allowed_to_transform(meta): return super().call_operator(op, args, kwargs, meta) source, indices, values = args[:3] diff --git a/backends/arm/_passes/arm_pass.py b/backends/arm/_passes/arm_pass.py index add0f3aeb20..1b4fc677d18 100644 --- a/backends/arm/_passes/arm_pass.py +++ b/backends/arm/_passes/arm_pass.py @@ -7,6 +7,7 @@ import copy import traceback from abc import abstractmethod +from collections.abc import Collection from typing import Any, List, Optional, Set, Type import torch @@ -14,7 +15,7 @@ from executorch.backends.arm.tosa.mapping import TosaSpecialDtype from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue -from torch.fx import GraphModule +from torch.fx import GraphModule, Node from torch.fx.passes.infra.pass_base import PassResult from torch.utils import _pytree as pytree @@ -191,3 +192,99 @@ def call_scalar(self, value: int | float, meta: NodeMetadata | dict[str, Any]): meta=meta, updated=True, ) + + def should_run_pass(self, graph_module: GraphModule) -> bool: + """Return whether this pass should run on the graph module. + + Subclasses can override this to cheaply skip the pass before + ``call()`` starts the normal ``ExportPass`` retracing path. + + Args: + graph_module (GraphModule): The graph module to inspect. + + Returns: + bool: True when the pass should run. + + """ + return True + + def __call__(self, graph_module: GraphModule) -> PassResult | None: + self.requires(graph_module) + if not self.should_run_pass(graph_module): + self.ensures(graph_module) + return PassResult(graph_module, False) + res = self.call(graph_module) + self.ensures(graph_module) + return res + + +class ArmOpTargetedPass(ArmPass): + """Base class for passes that only transform selected operators. + + Subclasses set ``target_ops`` to the call_function targets they can + transform. If the current graph and nested control-flow subgraphs do not + contain any target, the pass returns immediately without paying the default + ExportPass retracing cost. + + Set ``check_allowed_to_transform`` to ``True`` when the target pre-scan + should also apply ``allowed_to_transform()`` to matching target nodes. This + is useful for TFA passes whose ``call_operator()`` leaves disallowed target + nodes unchanged. If all matching targets are disallowed, the pass can + return before entering the normal ``ExportPass`` path. + + """ + + target_ops: Collection[Any] = () + check_allowed_to_transform = False + + def has_target_node(self, graph_module: GraphModule) -> bool: + """Return whether the graph module tree contains a target node. + + Args: + graph_module (GraphModule): The graph module tree to inspect. + + Returns: + bool: True if a matching call_function node is present. + + """ + visited_graph_modules = set() + + def target_node_can_trigger_pass(node: Node) -> bool: + if not self.check_allowed_to_transform: + return True + if self.allowed_to_transform(node.meta): + return True + return False + + def graph_has_target(module: GraphModule) -> bool: + if id(module) in visited_graph_modules: + return False + visited_graph_modules.add(id(module)) + + for target in self.target_ops: + for node in module.graph.find_nodes( + op="call_function", + target=target, + sort=False, + ): + if target_node_can_trigger_pass(node): + return True + + return any( + isinstance(child, GraphModule) and graph_has_target(child) + for child in module.children() + ) + + return graph_has_target(graph_module) + + def should_run_pass(self, graph_module: GraphModule) -> bool: + """Return whether this pass has a target node to transform. + + Args: + graph_module (GraphModule): The graph module tree to inspect. + + Returns: + bool: True when a matching target node is present. + + """ + return self.has_target_node(graph_module) diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index 5a135696463..8a02f7393de 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -481,9 +481,6 @@ def _tosa_pipeline( ConvertFullLikeToFullPass(), MatchArgDtypePass(), UnsqueezeScalarPlaceholdersPass(exported_program), - # TODO: Move DecomposeNotEqualPass to before or after this block of - # passes. Ticket: MLETORCH-1540 - DecomposeNotEqualPass(), MatchArgRanksPass(exported_program), ] ) @@ -491,6 +488,7 @@ def _tosa_pipeline( # Node transformation passes (post scalar-removal) self.add_passes( [ + DecomposeNotEqualPass(), NormalizeIndexPutNoneIndicesPass(), NormalizeIndexPutBoolIndexTensorPass(), RewriteIndexPutPass(), diff --git a/backends/arm/_passes/arm_pass_utils.py b/backends/arm/_passes/arm_pass_utils.py index 000f92135eb..f66b17b9da2 100644 --- a/backends/arm/_passes/arm_pass_utils.py +++ b/backends/arm/_passes/arm_pass_utils.py @@ -9,7 +9,7 @@ import operator import traceback from inspect import isclass -from typing import cast, List, Optional, Sequence, Tuple +from typing import cast, Optional, Sequence import torch import torch.fx @@ -19,10 +19,6 @@ from executorch.exir import ExportedProgram from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.dialects.edge._ops import EdgeOpOverload -from executorch.exir.graph_module import ( - _get_control_flow_submodules, - get_control_flow_submodules, -) from executorch.exir.pass_base import NodeMetadata from torch._export.utils import ( @@ -36,7 +32,6 @@ from torch._ops import OpOverload from torch._subclasses.fake_tensor import FakeTensor from torch.export.graph_signature import InputKind -from torch.fx import GraphModule, Node def is_submodule_node(node: torch.fx.Node): @@ -364,48 +359,6 @@ def set_node_arg(node: torch.fx.Node, i: int | str, value): raise RuntimeError("Invalid type") -def is_nested_control_flow_graph(graph_module: GraphModule) -> bool: - """Returns True if graph_module is a nested control-flow graph.""" - - # Find all top-level control-flow submodules - top_cf = get_control_flow_submodules(graph_module) - # For each submodule, see if it itself has control-flow inside - for _, submod, _ in top_cf: - if get_control_flow_submodules(submod): - return True - return False - - -def get_cond_while_submodules_nested( - graph_module: GraphModule, - apply_quantization: bool = False, -) -> List[Tuple[str, GraphModule, Node]]: - """Recursively find cond/while_loop submodules in an GraphModule. - - In nested control flow graphs, FX records the submodule functions - (true/false or cond/body) in reverse order compared to top-level graphs. We - must swap the indices when nested so that cond (first) and body/true_fn - (second) are consistently identified across all nesting levels. - - """ - - # Determine arg indices based on nesting and whether only cond branch is needed - nested = is_nested_control_flow_graph(graph_module) - # cond: [true_fn, false_fn] or swapped if nested - cond_indices = [2, 1] if nested else [1, 2] - # while_loop: [cond_fn, body_fn] or swapped if nested - while_indices = [1, 0] if nested else [0, 1] - if apply_quantization: - # only keep the cond_fn for while_loop (first index) when quantizing. - while_indices = [while_indices[0]] - mapping = { - torch.ops.higher_order.cond: cond_indices, - torch.ops.higher_order.while_loop: while_indices, - } - # collect cond/while submodules (using mapping indices) - return _get_control_flow_submodules(graph_module, mapping) - - def to_2tuple(value): """Normalizes scalars, and 1-element sequences to a tuple of length 2.""" if isinstance(value, int): diff --git a/backends/arm/_passes/canonicalize_gather_pass.py b/backends/arm/_passes/canonicalize_gather_pass.py index 23886111b18..aaa77ce4002 100644 --- a/backends/arm/_passes/canonicalize_gather_pass.py +++ b/backends/arm/_passes/canonicalize_gather_pass.py @@ -6,12 +6,12 @@ from typing import Set, Type import torch -from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes import ArmOpTargetedPass from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass -class CanonicalizeGatherPass(ArmPass): +class CanonicalizeGatherPass(ArmOpTargetedPass): """Canonicalize gather so it can be lowered to TOSA.GATHER via the backend dialect. @@ -40,10 +40,10 @@ class CanonicalizeGatherPass(ArmPass): _passes_required_after: Set[Type[ExportPass]] = set() - _TARGET_OPS = {exir_ops.edge.aten.gather.default} + target_ops = {exir_ops.edge.aten.gather.default} def call_operator(self, op, args, kwargs, meta): - if op not in self._TARGET_OPS: + if op not in self.target_ops: return super().call_operator(op, args, kwargs, meta) # edge.aten.gather.default: (x, dim, index) with kw-only sparse_grad diff --git a/backends/arm/_passes/control_flow_const_inline.py b/backends/arm/_passes/control_flow_const_inline.py index cc76e5d9957..177ad30754e 100644 --- a/backends/arm/_passes/control_flow_const_inline.py +++ b/backends/arm/_passes/control_flow_const_inline.py @@ -7,12 +7,10 @@ import torch from executorch.backends.arm._passes.arm_pass import ArmPass -from executorch.backends.arm._passes.arm_pass_utils import ( - get_cond_while_submodules_nested, - is_submodule_node, -) +from executorch.backends.arm._passes.arm_pass_utils import is_submodule_node from executorch.backends.transforms.utils import is_get_attr_node from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.graph_module import get_cond_while_submodules from executorch.exir.pass_base import ExportPass, PassResult from torch.fx import GraphModule @@ -37,7 +35,7 @@ class ControlFlowConstInlinePass(ArmPass): def _convert_getattr(self, graph_module): modified = False - for _, submodule, _ in get_cond_while_submodules_nested(graph_module): + for _, submodule, _ in get_cond_while_submodules(graph_module): for submodule_node in submodule.graph.nodes: if submodule_node.target in self._targeted_ops: self._convert_getattr(submodule) diff --git a/backends/arm/_passes/conv1d_unsqueeze_pass.py b/backends/arm/_passes/conv1d_unsqueeze_pass.py index cf1e884e05b..f81ef33e2d1 100644 --- a/backends/arm/_passes/conv1d_unsqueeze_pass.py +++ b/backends/arm/_passes/conv1d_unsqueeze_pass.py @@ -8,7 +8,7 @@ from typing import Set, Type -from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes import ArmOpTargetedPass from executorch.backends.arm._passes.rewrite_conv_pass import RewriteConvPass from executorch.backends.arm._passes.size_adjust_input_pass import SizeAdjustInputPass @@ -17,7 +17,7 @@ from executorch.exir.pass_base import ExportPass -class Conv1dUnsqueezePass(ArmPass): +class Conv1dUnsqueezePass(ArmOpTargetedPass): """This pass is used to change conv1d ops into conv2d since TOSA only supports 2d and 3d convolution. @@ -34,9 +34,10 @@ class Conv1dUnsqueezePass(ArmPass): RewriteConvPass, SizeAdjustInputPass, } + target_ops = (exir_ops.edge.aten.convolution.default,) def call_operator(self, op, args, kwargs, meta): - if op != exir_ops.edge.aten.convolution.default: + if op not in self.target_ops: return super().call_operator(op, args, kwargs, meta) stride = list(args[3]) if len(stride) != 1: diff --git a/backends/arm/_passes/convert_expand_copy_to_repeat.py b/backends/arm/_passes/convert_expand_copy_to_repeat.py index 69056cb47f4..430dc70bd0c 100644 --- a/backends/arm/_passes/convert_expand_copy_to_repeat.py +++ b/backends/arm/_passes/convert_expand_copy_to_repeat.py @@ -9,7 +9,7 @@ import torch -from executorch.backends.arm._passes.arm_pass import ArmPass +from executorch.backends.arm._passes.arm_pass import ArmOpTargetedPass from executorch.backends.arm._passes.unsqueeze_before_repeat_pass import ( UnsqueezeBeforeRepeatPass, ) @@ -51,7 +51,7 @@ def calculate_multiples(args): return multiples, expanded_rank != len(input_shape) -class ConvertExpandCopyToRepeatPass(ArmPass): +class ConvertExpandCopyToRepeatPass(ArmOpTargetedPass): """Replace expand copy with repeat since it is a repeat that can only repeat singleton dimensions. """ @@ -60,9 +60,10 @@ class ConvertExpandCopyToRepeatPass(ArmPass): expand_copy = exir_ops.edge.aten.expand_copy.default repeat = exir_ops.edge.aten.repeat.default + target_ops = (expand_copy,) def call_operator(self, op, args, kwargs, meta): - if op != self.expand_copy: + if op not in self.target_ops: return super().call_operator(op, args, kwargs, meta) multiples, changes_rank = calculate_multiples(args) diff --git a/backends/arm/_passes/convert_full_like_to_full_pass.py b/backends/arm/_passes/convert_full_like_to_full_pass.py index 1e26f24250a..f7a94424228 100644 --- a/backends/arm/_passes/convert_full_like_to_full_pass.py +++ b/backends/arm/_passes/convert_full_like_to_full_pass.py @@ -5,7 +5,7 @@ from typing import Set, Type -from executorch.backends.arm._passes.arm_pass import ArmPass +from executorch.backends.arm._passes.arm_pass import ArmOpTargetedPass from executorch.backends.arm._passes.fuse_constant_ops_pass import ( ComputeConstantOpsAOTPass, ) @@ -14,7 +14,7 @@ from executorch.exir.pass_base import ExportPass -class ConvertFullLikeToFullPass(ArmPass): +class ConvertFullLikeToFullPass(ArmOpTargetedPass): """Convert edge aten full_like to full. As per the full_like PyTorch documentation, `torch.full_like(input, @@ -35,11 +35,10 @@ class ConvertFullLikeToFullPass(ArmPass): """ _passes_required_after: Set[Type[ExportPass]] = {ComputeConstantOpsAOTPass} + target_ops = (exir_ops.edge.aten.full_like.default,) def call_operator(self, op, args, kwargs, meta): - if op not in [ - exir_ops.edge.aten.full_like.default, - ]: + if op not in self.target_ops: return super().call_operator(op, args, kwargs, meta) tensor = args[0].data diff --git a/backends/arm/_passes/convert_permute_singleton_to_view_pass.py b/backends/arm/_passes/convert_permute_singleton_to_view_pass.py index 7447cf037bc..0ed5f92f91d 100644 --- a/backends/arm/_passes/convert_permute_singleton_to_view_pass.py +++ b/backends/arm/_passes/convert_permute_singleton_to_view_pass.py @@ -6,7 +6,7 @@ from typing import Sequence, Set, Tuple, Type -from executorch.backends.arm._passes.arm_pass import ArmPass +from executorch.backends.arm._passes.arm_pass import ArmOpTargetedPass from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -20,7 +20,7 @@ ) -class ConvertPermuteSingletonToViewPass(ArmPass): +class ConvertPermuteSingletonToViewPass(ArmOpTargetedPass): """Replace permutations that only move singleton axes with a reshape. Examples: @@ -34,9 +34,10 @@ class ConvertPermuteSingletonToViewPass(ArmPass): """ _passes_required_after: Set[Type[ExportPass]] = set() + target_ops = _PERMUTE_TARGETS def call_operator(self, op, args, kwargs, meta): - if op not in _PERMUTE_TARGETS: + if op not in self.target_ops: return super().call_operator(op, args, kwargs, meta) input_tensor = args[0].data diff --git a/backends/arm/_passes/convert_squeezes_to_view.py b/backends/arm/_passes/convert_squeezes_to_view.py index 2058c3407e3..b79e38cdf10 100644 --- a/backends/arm/_passes/convert_squeezes_to_view.py +++ b/backends/arm/_passes/convert_squeezes_to_view.py @@ -6,7 +6,7 @@ from typing import Set, Type -from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes import ArmOpTargetedPass from executorch.backends.arm._passes.fuse_view_copy_transform_pass import ( FuseViewCopyTransformPass, ) @@ -14,7 +14,7 @@ from executorch.exir.pass_base import ExportPass -class ConvertSqueezesToViewPass(ArmPass): +class ConvertSqueezesToViewPass(ArmOpTargetedPass): """Replaces squeeze/unsqueeze operators with view. These are simply special cases of the view op, so removing them gives us @@ -23,12 +23,13 @@ class ConvertSqueezesToViewPass(ArmPass): """ _passes_required_after: Set[Type[ExportPass]] = {FuseViewCopyTransformPass} + target_ops = ( + exir_ops.edge.aten.squeeze_copy.dims, + exir_ops.edge.aten.unsqueeze_copy.default, + ) def call_operator(self, op, args, kwargs, meta): - if op not in [ - exir_ops.edge.aten.squeeze_copy.dims, - exir_ops.edge.aten.unsqueeze_copy.default, - ]: + if op not in self.target_ops: return super().call_operator(op, args, kwargs, meta) x = args[0] diff --git a/backends/arm/_passes/convert_to_clamp_pass.py b/backends/arm/_passes/convert_to_clamp_pass.py index effb46f25c4..6273759aa55 100644 --- a/backends/arm/_passes/convert_to_clamp_pass.py +++ b/backends/arm/_passes/convert_to_clamp_pass.py @@ -1,11 +1,11 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. from typing import Set, Tuple, Type -from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes import ArmOpTargetedPass from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( QuantizeClampArgumentsPass, @@ -29,11 +29,13 @@ def get_clamp_params(op, args) -> Tuple[float | None, float | None]: raise ValueError(f"Getting clamp parameters for op {op} is not implemented.") -class ConvertToClampPass(ArmPass): +class ConvertToClampPass(ArmOpTargetedPass): _passes_required_after: Set[Type[ExportPass]] = {QuantizeClampArgumentsPass} + target_ops = edge_operators + check_allowed_to_transform = True def call_operator(self, op, args, kwargs, meta): - if op not in edge_operators or not self.allowed_to_transform(meta): + if op not in self.target_ops or not self.allowed_to_transform(meta): return super().call_operator(op, args, kwargs, meta) return super().call_operator( diff --git a/backends/arm/_passes/decompose_acosh_pass.py b/backends/arm/_passes/decompose_acosh_pass.py index 3ce6d73abc3..3c2cac45e75 100644 --- a/backends/arm/_passes/decompose_acosh_pass.py +++ b/backends/arm/_passes/decompose_acosh_pass.py @@ -6,7 +6,7 @@ from typing import Set, Type -from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes import ArmOpTargetedPass from executorch.backends.arm._passes.decompose_sqrt_pass import DecomposeSqrtPass from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass # noqa from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass @@ -21,7 +21,7 @@ edge_acosh_op = exir_ops.edge.aten.acosh.default -class DecomposeAcoshPass(ArmPass): +class DecomposeAcoshPass(ArmOpTargetedPass): """Decomposes acosh to supported TOSA-operations. This decomposition is based on the mathematical identity: @@ -36,10 +36,11 @@ class DecomposeAcoshPass(ArmPass): ReplaceScalarWithTensorByProfilePass, MatchArgDtypePass, } + target_ops = (edge_acosh_op,) def call_operator(self, op, args, kwargs, meta, updated=False): - if op is not edge_acosh_op: + if op not in self.target_ops: return super().call_operator(op, args, kwargs, meta, updated) if self._is_quantized_meta(meta): diff --git a/backends/arm/_passes/decompose_adaptive_avg_pool2d_pass.py b/backends/arm/_passes/decompose_adaptive_avg_pool2d_pass.py index eda9dd28bf9..58fcf69cd8f 100644 --- a/backends/arm/_passes/decompose_adaptive_avg_pool2d_pass.py +++ b/backends/arm/_passes/decompose_adaptive_avg_pool2d_pass.py @@ -8,7 +8,7 @@ import torch -from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes import ArmOpTargetedPass from executorch.backends.arm._passes.decompose_avg_pool2d_pass import ( DecomposeAvgPool2dPass, ) @@ -36,7 +36,7 @@ def _get_decomposition(op) -> tuple: raise RuntimeError(f"Unable to get decomposition for op {op}") -class DecomposeAdaptiveAvgPool2dPass(ArmPass): +class DecomposeAdaptiveAvgPool2dPass(ArmOpTargetedPass): """Decomposes AdaptiveAvgPool2d into AvgPool2d operations. An input tensor of shape (N, C, H, W) is transformed into an output tensor @@ -47,9 +47,11 @@ class DecomposeAdaptiveAvgPool2dPass(ArmPass): """ _passes_required_after: Set[Type[ExportPass]] = {DecomposeAvgPool2dPass} + target_ops = edge_ops + aten_ops + check_allowed_to_transform = True def call_operator(self, op, args, kwargs, meta, updated=False): - if op not in (edge_ops + aten_ops) or not self.allowed_to_transform(meta): + if op not in self.target_ops or not self.allowed_to_transform(meta): return super().call_operator(op, args, kwargs, meta, updated) avg_pool2d_op, slice_op, cat_op = _get_decomposition(op) diff --git a/backends/arm/_passes/decompose_add_sub_alpha_pass.py b/backends/arm/_passes/decompose_add_sub_alpha_pass.py index d7db9c5bcf9..30903fbd3d8 100644 --- a/backends/arm/_passes/decompose_add_sub_alpha_pass.py +++ b/backends/arm/_passes/decompose_add_sub_alpha_pass.py @@ -9,7 +9,7 @@ from typing import Set, Type import torch -from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes import ArmOpTargetedPass from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -55,13 +55,14 @@ def _should_decompose(alpha) -> bool: return False -class DecomposeAddSubAlphaPass(ArmPass): +class DecomposeAddSubAlphaPass(ArmOpTargetedPass): """Rewrite add/sub with alpha into a mul followed by add/sub.""" _passes_required_after: Set[Type[ExportPass]] = set() + target_ops = _ADD_OPS + _SUB_OPS def call_operator(self, op, args, kwargs, meta, updated: bool | None = False): - if op not in _ADD_OPS + _SUB_OPS: + if op not in self.target_ops: return super().call_operator(op, args, kwargs, meta, updated) alpha = kwargs.get("alpha", 1) diff --git a/backends/arm/_passes/decompose_addmm_pass.py b/backends/arm/_passes/decompose_addmm_pass.py index d1368602d5d..d198e1a3b64 100644 --- a/backends/arm/_passes/decompose_addmm_pass.py +++ b/backends/arm/_passes/decompose_addmm_pass.py @@ -7,7 +7,7 @@ import torch -from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes import ArmOpTargetedPass from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass from executorch.backends.arm._passes.mm_to_bmm_pass import ConvertMmToBmmPass # noqa @@ -41,7 +41,7 @@ def get_ops(op): raise ValueError(f"Unsupported operator: {op}") -class DecomposeAddmmPass(ArmPass): +class DecomposeAddmmPass(ArmOpTargetedPass): """Decomposes the addmm operator into tensor multiplication and addition.""" _passes_required_after: Set[Type[ExportPass]] = { @@ -49,9 +49,10 @@ class DecomposeAddmmPass(ArmPass): MatchArgRanksPass, MatchArgDtypePass, } + target_ops = (edge_addmm, aten_addmm) def call_operator(self, op, args, kwargs, meta): - if op not in [edge_addmm, aten_addmm] or not self.allowed_to_transform(meta): + if op not in self.target_ops or not self.allowed_to_transform(meta): return super().call_operator(op, args, kwargs, meta) input, mat1, mat2 = args diff --git a/backends/arm/_passes/decompose_as_strided_copy_pass.py b/backends/arm/_passes/decompose_as_strided_copy_pass.py index a60d1b19fd9..c8c2a200bd8 100644 --- a/backends/arm/_passes/decompose_as_strided_copy_pass.py +++ b/backends/arm/_passes/decompose_as_strided_copy_pass.py @@ -7,7 +7,7 @@ import torch -from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes import ArmOpTargetedPass from executorch.backends.arm.common.as_strided_utils import ( contiguous_strides, maybe_static_sequence, @@ -18,7 +18,7 @@ from executorch.exir.pass_base import ExportPass -class DecomposeAsStridedCopyPass(ArmPass): +class DecomposeAsStridedCopyPass(ArmOpTargetedPass): """Replace contiguous `aten.as_strided_copy` with `aten.view_copy`. The TOSA backend only supports the contiguous-as-strided case where the stride matches @@ -31,6 +31,7 @@ class DecomposeAsStridedCopyPass(ArmPass): _EDGE_OPS = (exir_ops.edge.aten.as_strided_copy.default,) _ATEN_OPS = (torch.ops.aten.as_strided_copy.default,) + target_ops = _EDGE_OPS + _ATEN_OPS def _extract_args( self, args: Tuple[object, ...], kwargs: dict @@ -76,7 +77,7 @@ def _extract_args( return size_tuple, stride_tuple, storage_offset def call_operator(self, op, args, kwargs, meta, updated: Optional[bool] = False): - if op not in (*self._EDGE_OPS, *self._ATEN_OPS): + if op not in self.target_ops: return super().call_operator(op, args, kwargs, meta, updated) extracted = self._extract_args(args, kwargs) diff --git a/backends/arm/_passes/decompose_asin_and_acos_pass.py b/backends/arm/_passes/decompose_asin_and_acos_pass.py index 707e6ec070d..5e0cfd66c32 100644 --- a/backends/arm/_passes/decompose_asin_and_acos_pass.py +++ b/backends/arm/_passes/decompose_asin_and_acos_pass.py @@ -10,7 +10,7 @@ import torch -from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes import ArmOpTargetedPass from executorch.backends.arm._passes.convert_full_like_to_full_pass import ( ConvertFullLikeToFullPass, ) @@ -48,7 +48,7 @@ def get_decomposition(op) -> tuple: raise RuntimeError(f"Can't get decomposition for op {op}") -class DecomposeAsinAndAcosPass(ArmPass): +class DecomposeAsinAndAcosPass(ArmOpTargetedPass): """This pass decomposes asin and acos into a rational approximation for small values and a transformed rational approximation for large values. @@ -71,6 +71,7 @@ class DecomposeAsinAndAcosPass(ArmPass): MatchArgDtypePass, ReplaceScalarWithTensorByProfilePass, } + target_ops = edge_asin_op + edge_acos_op def _build_polynomial( self, coefficients: list[float], variable: torch.Tensor, meta: dict[str, str] @@ -116,7 +117,7 @@ def _combine_branches( ) def call_operator(self, op, args, kwargs, meta): - if op not in (edge_asin_op + edge_acos_op): + if op not in self.target_ops: return super().call_operator(op, args, kwargs, meta) if self._is_quantized_meta(meta): diff --git a/backends/arm/_passes/decompose_asinh_pass.py b/backends/arm/_passes/decompose_asinh_pass.py index 822b793d203..5f31c5efedc 100644 --- a/backends/arm/_passes/decompose_asinh_pass.py +++ b/backends/arm/_passes/decompose_asinh_pass.py @@ -6,7 +6,7 @@ from typing import Set, Type -from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes import ArmOpTargetedPass from executorch.backends.arm._passes.decompose_sqrt_pass import DecomposeSqrtPass from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass @@ -21,7 +21,7 @@ edge_asinh_op = (exir_ops.edge.aten.asinh.default,) -class DecomposeAsinhPass(ArmPass): +class DecomposeAsinhPass(ArmOpTargetedPass): """Decomposes asinh to supported TOSA-operations. This decomposition is based on the mathematical identity: @@ -36,9 +36,10 @@ class DecomposeAsinhPass(ArmPass): ReplaceScalarWithTensorByProfilePass, MatchArgDtypePass, } + target_ops = edge_asinh_op def call_operator(self, op, args, kwargs, meta): - if op not in edge_asinh_op: + if op not in self.target_ops: return super().call_operator(op, args, kwargs, meta) if self._is_quantized_meta(meta): diff --git a/backends/arm/_passes/decompose_atan_pass.py b/backends/arm/_passes/decompose_atan_pass.py index a7ca90e7b43..cd33504c972 100644 --- a/backends/arm/_passes/decompose_atan_pass.py +++ b/backends/arm/_passes/decompose_atan_pass.py @@ -7,7 +7,7 @@ from math import pi from typing import Set, Type -from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes import ArmOpTargetedPass from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass @@ -40,7 +40,7 @@ def _get_atan_ops(op): ) -class DecomposeAtanPass(ArmPass): +class DecomposeAtanPass(ArmOpTargetedPass): """Decomposes the atan operator into a rational (Padé) approximation.""" _passes_required_after: Set[Type[ExportPass]] = { @@ -49,6 +49,7 @@ class DecomposeAtanPass(ArmPass): MatchArgDtypePass, ReplaceScalarWithTensorByProfilePass, } + target_ops = (edge_atan,) def _rational_approximation(self, z, ops, meta): """Creates a (2,1) Padé approximation for atan(x) on [-1, 1].""" @@ -77,7 +78,7 @@ def _rational_approximation(self, z, ops, meta): return super().call_operator(op_mul, (z, prod), {}, meta, updated=True) def call_operator(self, op, args, kwargs, meta): - if op is not edge_atan: + if op not in self.target_ops: return super().call_operator(op, args, kwargs, meta, updated=False) if self._is_quantized_meta(meta): diff --git a/backends/arm/_passes/decompose_atanh_pass.py b/backends/arm/_passes/decompose_atanh_pass.py index 014da39d7bd..c542b94f30d 100644 --- a/backends/arm/_passes/decompose_atanh_pass.py +++ b/backends/arm/_passes/decompose_atanh_pass.py @@ -5,7 +5,7 @@ from typing import Set, Type -from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes import ArmOpTargetedPass from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass @@ -33,7 +33,7 @@ def _get_atanh_ops(op): ) -class DecomposeAtanhPass(ArmPass): +class DecomposeAtanhPass(ArmOpTargetedPass): """Decomposes the atanh operator into primitive ops. atanh(x) = 0.5 * log((1 + x) / (1 - x)) @@ -46,9 +46,10 @@ class DecomposeAtanhPass(ArmPass): MatchArgDtypePass, ReplaceScalarWithTensorByProfilePass, } + target_ops = (edge_atanh,) def call_operator(self, op, args, kwargs, meta): - if op is not edge_atanh: + if op not in self.target_ops: return super().call_operator(op, args, kwargs, meta, updated=False) if self._is_quantized_meta(meta): diff --git a/backends/arm/_passes/decompose_avg_pool2d_pass.py b/backends/arm/_passes/decompose_avg_pool2d_pass.py index 8fcbcd35b5e..eb30a7600d8 100644 --- a/backends/arm/_passes/decompose_avg_pool2d_pass.py +++ b/backends/arm/_passes/decompose_avg_pool2d_pass.py @@ -7,7 +7,7 @@ from typing import Any, Set, Type import torch -from executorch.backends.arm._passes.arm_pass import ArmPass +from executorch.backends.arm._passes.arm_pass import ArmOpTargetedPass from executorch.backends.arm._passes.fuse_constant_ops_pass import ( ComputeConstantOpsAOTPass, ) @@ -96,13 +96,13 @@ def _get_avgpool_post_pad( return [pad_w, post_w, pad_h, post_h], [0, 0] -class DecomposeAvgPool2dPass(ArmPass): +class DecomposeAvgPool2dPass(ArmOpTargetedPass): _passes_required_after: Set[Type[ExportPass]] = {ComputeConstantOpsAOTPass} + target_ops = edge_avg_pool2d + aten_avg_pool2d + check_allowed_to_transform = True def call_operator(self, op, args, kwargs, meta): - if op not in ( - edge_avg_pool2d + aten_avg_pool2d - ) or not self.allowed_to_transform(meta): + if op not in self.target_ops or not self.allowed_to_transform(meta): return super().call_operator(op, args, kwargs, meta) pad_op, avgpool_op, mul_op = get_decomposition(op) diff --git a/backends/arm/_passes/decompose_cosh_pass.py b/backends/arm/_passes/decompose_cosh_pass.py index 70d4247d9e0..96c73b6cdf2 100644 --- a/backends/arm/_passes/decompose_cosh_pass.py +++ b/backends/arm/_passes/decompose_cosh_pass.py @@ -5,7 +5,7 @@ from typing import Set, Type -from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes import ArmOpTargetedPass from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass @@ -19,7 +19,7 @@ edge_cosh = exir_ops.edge.aten.cosh.default -class DecomposeCoshPass(ArmPass): +class DecomposeCoshPass(ArmOpTargetedPass): """ This pass replaces the cosh operator with a sequence of TOSA-equivalent operations that compute the hyperbolic cosine using the formula: @@ -34,9 +34,10 @@ class DecomposeCoshPass(ArmPass): ReplaceScalarWithTensorByProfilePass, MatchArgDtypePass, } + target_ops = (edge_cosh,) def call_operator(self, op, args, kwargs, meta, updated=False): - if op is not edge_cosh: + if op not in self.target_ops: return super().call_operator(op, args, kwargs, meta, updated) if self._is_quantized_meta(meta): diff --git a/backends/arm/_passes/decompose_cosine_similarity_pass.py b/backends/arm/_passes/decompose_cosine_similarity_pass.py index 6ceb50fdf55..b9e11a68174 100644 --- a/backends/arm/_passes/decompose_cosine_similarity_pass.py +++ b/backends/arm/_passes/decompose_cosine_similarity_pass.py @@ -6,7 +6,7 @@ from typing import Set, Type import torch -from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes import ArmOpTargetedPass from executorch.backends.arm._passes.convert_full_like_to_full_pass import ( ConvertFullLikeToFullPass, ) @@ -19,7 +19,7 @@ torch_cosine_similarity = (torch.ops.aten.cosine_similarity.default,) -class DecomposeCosineSimilarityPass(ArmPass): +class DecomposeCosineSimilarityPass(ArmOpTargetedPass): """Decomposition of aten.cosine_similarity. Example: @@ -42,9 +42,11 @@ class DecomposeCosineSimilarityPass(ArmPass): ConvertFullLikeToFullPass, InsertTableOpsPass, } + target_ops = torch_cosine_similarity + check_allowed_to_transform = True def call_operator(self, op, args, kwargs, meta): - if op not in torch_cosine_similarity or not self.allowed_to_transform(meta): + if op not in self.target_ops or not self.allowed_to_transform(meta): return super().call_operator(op, args, kwargs, meta) x1, x2 = args[0], args[1] diff --git a/backends/arm/_passes/decompose_div_pass.py b/backends/arm/_passes/decompose_div_pass.py index 651e58a563c..be4d91cd30c 100644 --- a/backends/arm/_passes/decompose_div_pass.py +++ b/backends/arm/_passes/decompose_div_pass.py @@ -8,7 +8,7 @@ from typing import Set, Type import torch -from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes import ArmOpTargetedPass from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -28,7 +28,7 @@ def get_div_decomposition(op) -> tuple: raise RuntimeError(f"Can't get div decomposition for op {op}") -class DecomposeDivPass(ArmPass): +class DecomposeDivPass(ArmOpTargetedPass): """This pass decomposes div into a mul and a reciprocal node. Example: @@ -40,11 +40,10 @@ class DecomposeDivPass(ArmPass): """ _passes_required_after: Set[Type[ExportPass]] = {InsertTableOpsPass} + target_ops = edge_div_ops + aten_div_ops def call_operator(self, op, args, kwargs, meta): - if op not in (edge_div_ops + aten_div_ops) or not self.allowed_to_transform( - meta - ): + if op not in self.target_ops or not self.allowed_to_transform(meta): return super().call_operator(op, args, kwargs, meta) reciprocal_op, mul_op = get_div_decomposition(op) diff --git a/backends/arm/_passes/decompose_div_tensor_mode.py b/backends/arm/_passes/decompose_div_tensor_mode.py index 774557b816f..cc5440b4e5b 100644 --- a/backends/arm/_passes/decompose_div_tensor_mode.py +++ b/backends/arm/_passes/decompose_div_tensor_mode.py @@ -7,7 +7,7 @@ from typing import Set, Type import torch -from executorch.backends.arm._passes.arm_pass import ArmPass +from executorch.backends.arm._passes.arm_pass import ArmOpTargetedPass from executorch.backends.arm._passes.decompose_div_pass import DecomposeDivPass from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -42,7 +42,7 @@ def _get_opset(op): raise RuntimeError(f"div.Tensor_mode not supported for op {op}") -class DecomposeDivTensorModePass(ArmPass): +class DecomposeDivTensorModePass(ArmOpTargetedPass): """Rewrites aten.div.Tensor_mode into. Example: @@ -57,11 +57,11 @@ class DecomposeDivTensorModePass(ArmPass): """ _passes_required_after: Set[Type[ExportPass]] = {DecomposeDivPass} + target_ops = edge_div_mode_ops + aten_div_mode_ops + check_allowed_to_transform = True def call_operator(self, op, args, kwargs, meta): - if op not in ( - edge_div_mode_ops + aten_div_mode_ops - ) or not self.allowed_to_transform(meta): + if op not in self.target_ops or not self.allowed_to_transform(meta): return super().call_operator(op, args, kwargs, meta) opset = _get_opset(op) diff --git a/backends/arm/_passes/decompose_elu_pass.py b/backends/arm/_passes/decompose_elu_pass.py index 548a508d914..5f94968ad79 100644 --- a/backends/arm/_passes/decompose_elu_pass.py +++ b/backends/arm/_passes/decompose_elu_pass.py @@ -6,7 +6,7 @@ from typing import Set, Type import torch -from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes import ArmOpTargetedPass from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -71,13 +71,15 @@ def _get_elu_parameters(op, args, kwargs): return alpha, scale, input_scale -class ConvertEluFamilyToEluPass(ArmPass): +class ConvertEluFamilyToEluPass(ArmOpTargetedPass): """Convert SELU/CELU ops to equivalent parameterized ELU ops.""" _passes_required_after: Set[Type[ExportPass]] = set() + target_ops = selu_ops + celu_ops + check_allowed_to_transform = True def call_operator(self, op, args, kwargs, meta): - if op not in selu_ops + celu_ops or not self.allowed_to_transform(meta): + if op not in self.target_ops or not self.allowed_to_transform(meta): return super().call_operator(op, args, kwargs, meta, updated=False) input_ = args[0] @@ -96,7 +98,7 @@ def call_operator(self, op, args, kwargs, meta): ) -class DecomposeEluPass(ArmPass): +class DecomposeEluPass(ArmOpTargetedPass): """A transformation pass that decomposes unsupported 'aten.elu' operations into a combination of supported TOSA-equivalent operations. @@ -119,9 +121,10 @@ class DecomposeEluPass(ArmPass): """ _passes_required_after: Set[Type[ExportPass]] = set() + target_ops = edge_elu_family_ops def call_operator(self, op, args, kwargs, meta): - if op not in edge_elu_family_ops: + if op not in self.target_ops: return super().call_operator(op, args, kwargs, meta, updated=False) if self._is_quantized_meta(meta): diff --git a/backends/arm/_passes/decompose_erfinv_pass.py b/backends/arm/_passes/decompose_erfinv_pass.py index 747209d943e..07f874f9d97 100644 --- a/backends/arm/_passes/decompose_erfinv_pass.py +++ b/backends/arm/_passes/decompose_erfinv_pass.py @@ -5,7 +5,7 @@ from typing import Set, Type -from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes import ArmOpTargetedPass from executorch.backends.arm._passes.convert_full_like_to_full_pass import ( ConvertFullLikeToFullPass, ) @@ -48,7 +48,7 @@ def get_erfinv_decomposition(op) -> tuple: raise RuntimeError(f"Can't get erfinv decomposition for op {op}") -class DecomposeErfinvPass(ArmPass): +class DecomposeErfinvPass(ArmOpTargetedPass): """Decomposes `aten.erfinv` using the same *initial-guess* approximation as the PyTorch CPU scalar `calc_erfinv`, with a guarded Newton refinement step to improve numerical accuracy (especially for fp16). @@ -127,9 +127,10 @@ class DecomposeErfinvPass(ArmPass): MatchArgDtypePass, ReplaceScalarWithTensorByProfilePass, } + target_ops = edge_erfinv_ops def call_operator(self, op, args, kwargs, meta): - if op not in edge_erfinv_ops: + if op not in self.target_ops: return super().call_operator(op, args, kwargs, meta, updated=False) if self._is_quantized_meta(meta): diff --git a/backends/arm/_passes/decompose_expm1_pass.py b/backends/arm/_passes/decompose_expm1_pass.py index c1cb0b83166..6898b9fafb2 100644 --- a/backends/arm/_passes/decompose_expm1_pass.py +++ b/backends/arm/_passes/decompose_expm1_pass.py @@ -5,7 +5,7 @@ from typing import Set, Type -from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes import ArmOpTargetedPass from executorch.backends.arm._passes.decompose_div_pass import DecomposeDivPass from executorch.backends.arm._passes.decompose_int_pow_pass import DecomposeIntPowPass from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass @@ -55,7 +55,7 @@ def _get_expm1_decomposition(op) -> tuple: raise RuntimeError(f"Can't get expm1 decomposition for op {op}") -class DecomposeExpm1Pass(ArmPass): +class DecomposeExpm1Pass(ArmOpTargetedPass): """A transformation pass that decomposes unsupported 'aten.expm1' operations into a combination of supported TOSA-equivalent operations. @@ -87,9 +87,10 @@ class DecomposeExpm1Pass(ArmPass): MatchArgDtypePass, MatchArgRanksPass, } + target_ops = edge_expm1_ops def call_operator(self, op, args, kwargs, meta): - if op not in edge_expm1_ops: + if op not in self.target_ops: return super().call_operator(op, args, kwargs, meta, updated=False) if self._is_quantized_meta(meta): diff --git a/backends/arm/_passes/decompose_floor_divide_pass.py b/backends/arm/_passes/decompose_floor_divide_pass.py index 20e63f48023..d8f451f8af6 100644 --- a/backends/arm/_passes/decompose_floor_divide_pass.py +++ b/backends/arm/_passes/decompose_floor_divide_pass.py @@ -6,7 +6,7 @@ from typing import Set, Type import torch -from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes import ArmOpTargetedPass from executorch.backends.arm._passes.decompose_div_tensor_mode import ( DecomposeDivTensorModePass, ) @@ -47,15 +47,16 @@ def get_floor_divide_decomposition(op) -> tuple: raise RuntimeError(f"Can't get floor_div decomposition for op {op}") -class DecomposeFloorDividePass(ArmPass): +class DecomposeFloorDividePass(ArmOpTargetedPass): """Decomposes aten.floor_divide into aten.div.Tensor_mode with rounding_mode="floor". """ _passes_required_after: Set[Type[ExportPass]] = {DecomposeDivTensorModePass} + target_ops = edge_floor_divide_ops + aten_floor_divide_ops def call_operator(self, op, args, kwargs, meta): - if op not in (edge_floor_divide_ops + aten_floor_divide_ops): + if op not in self.target_ops: return super().call_operator(op, args, kwargs, meta, updated=False) (div_op, full_op) = get_floor_divide_decomposition(op) diff --git a/backends/arm/_passes/decompose_gelu_pass.py b/backends/arm/_passes/decompose_gelu_pass.py index 7815b5fa44f..85f0b77df21 100644 --- a/backends/arm/_passes/decompose_gelu_pass.py +++ b/backends/arm/_passes/decompose_gelu_pass.py @@ -6,7 +6,7 @@ from typing import Set, Type import torch -from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes import ArmOpTargetedPass from executorch.backends.arm._passes.arm_pass_utils import get_node_arg from executorch.backends.arm._passes.fuse_constant_ops_pass import ( ComputeConstantOpsAOTPass, @@ -42,7 +42,7 @@ def _get_gelu_ops(op) -> tuple: raise RuntimeError(f"Can't get GeLU decomposition ops for op {op}") -class DecomposeGeluPass(ArmPass): +class DecomposeGeluPass(ArmOpTargetedPass): """This pass decomposes the GELU operator into primitive ops. Aiming to adhere closely to the reference implementations built into ExecuTorch. Including using the same pre-calculated constants. @@ -88,9 +88,10 @@ class DecomposeGeluPass(ArmPass): MatchArgDtypePass, MatchArgRanksPass, } + target_ops = torch_gelu + edge_gelu def call_operator(self, op, args, kwargs, meta): - if op not in torch_gelu + edge_gelu: + if op not in self.target_ops: return super().call_operator(op, args, kwargs, meta) if self._is_quantized_meta(meta): # If quantized, node should be replace by table op diff --git a/backends/arm/_passes/decompose_glu_pass.py b/backends/arm/_passes/decompose_glu_pass.py index 68efaedd784..5927174a776 100644 --- a/backends/arm/_passes/decompose_glu_pass.py +++ b/backends/arm/_passes/decompose_glu_pass.py @@ -6,7 +6,7 @@ from typing import Set, Type import torch -from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes import ArmOpTargetedPass from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -39,13 +39,14 @@ def get_ops(op): raise ValueError(f"Unsupported operator: {op}") -class DecomposeGluPass(ArmPass): +class DecomposeGluPass(ArmOpTargetedPass): """Decomposes the GLU operator into hadamard product and sigmoid.""" _passes_required_after: Set[Type[ExportPass]] = {InsertTableOpsPass} + target_ops = (edge_glu, aten_glu) def call_operator(self, op, args, kwargs, meta): - if op not in [edge_glu, aten_glu] or not self.allowed_to_transform(meta): + if op not in self.target_ops or not self.allowed_to_transform(meta): return super().call_operator(op, args, kwargs, meta) hadamard_prod, sigmoid, slice_op = get_ops(op) diff --git a/backends/arm/_passes/decompose_grouped_conv_pass.py b/backends/arm/_passes/decompose_grouped_conv_pass.py index ed0adbe83d7..3fb68bc5aef 100644 --- a/backends/arm/_passes/decompose_grouped_conv_pass.py +++ b/backends/arm/_passes/decompose_grouped_conv_pass.py @@ -7,7 +7,7 @@ from typing import Literal, Protocol, Set, Type, TypeGuard import torch -from executorch.backends.arm._passes.arm_pass import ArmPass +from executorch.backends.arm._passes.arm_pass import ArmOpTargetedPass from executorch.backends.arm._passes.conv1d_unsqueeze_pass import Conv1dUnsqueezePass from executorch.backends.arm._passes.quant_args import QuantArgs from executorch.exir.dialects._ops import ops as exir_ops @@ -24,7 +24,7 @@ class _PerChannelQuantArgs(Protocol): per_channel: Literal[True] -class DecomposeGroupedConvPass(ArmPass): +class DecomposeGroupedConvPass(ArmOpTargetedPass): """Splits a grouped convolution which is not supported by TOSA into multiple convolutions using slice->conv->cat. @@ -47,6 +47,11 @@ class DecomposeGroupedConvPass(ArmPass): """ _passes_required_after: Set[Type[ExportPass]] = {Conv1dUnsqueezePass} + target_ops = ( + exir_ops.edge.aten.convolution.default, + torch.ops.aten.conv_transpose2d.input, + torch.ops.aten.conv2d.default, + ) @staticmethod def _get_decomposition(op): diff --git a/backends/arm/_passes/decompose_index_select_to_gather_pass.py b/backends/arm/_passes/decompose_index_select_to_gather_pass.py index 5947e8c5499..be0d4dbb07c 100644 --- a/backends/arm/_passes/decompose_index_select_to_gather_pass.py +++ b/backends/arm/_passes/decompose_index_select_to_gather_pass.py @@ -8,7 +8,7 @@ import torch -from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes import ArmOpTargetedPass from executorch.backends.arm._passes.convert_expand_copy_to_repeat import ( ConvertExpandCopyToRepeatPass, ) @@ -38,7 +38,7 @@ def _get_index_select_decomposition(op): raise RuntimeError(f"Can't get index_select decomposition for op {op}") -class DecomposeIndexSelectToGatherPass(ArmPass): +class DecomposeIndexSelectToGatherPass(ArmOpTargetedPass): """Decompose edge index_select into a single backend TOSA gather. index_select(x, dim, index) semantics: @@ -67,12 +67,12 @@ class DecomposeIndexSelectToGatherPass(ArmPass): ConvertSqueezesToViewPass, } - _TARGET_OPS = { + target_ops = { exir_ops.edge.aten.index_select.default, } def call_operator(self, op, args, kwargs, meta): - if op not in self._TARGET_OPS: + if op not in self.target_ops: return super().call_operator(op, args, kwargs, meta) x, dim, index = args diff --git a/backends/arm/_passes/decompose_index_tensor_to_gather_pass.py b/backends/arm/_passes/decompose_index_tensor_to_gather_pass.py index 037c9977fa6..93db9f9d434 100644 --- a/backends/arm/_passes/decompose_index_tensor_to_gather_pass.py +++ b/backends/arm/_passes/decompose_index_tensor_to_gather_pass.py @@ -9,7 +9,7 @@ import torch -from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes import ArmOpTargetedPass from executorch.backends.arm._passes.arm_pass_utils import meta_without_qparams from executorch.backends.arm._passes.convert_expand_copy_to_repeat import ( ConvertExpandCopyToRepeatPass, @@ -75,7 +75,7 @@ def _broadcast_shape( return out -class DecomposeIndexTensorToGatherPass(ArmPass): +class DecomposeIndexTensorToGatherPass(ArmOpTargetedPass): """Decompose edge.aten.index.Tensor into backend TOSA gather (+ basic arith). @@ -165,7 +165,7 @@ class DecomposeIndexTensorToGatherPass(ArmPass): ReplaceScalarWithTensorByProfilePass, } - _TARGET_OPS = { + target_ops = { exir_ops.edge.aten.index.Tensor, } @@ -246,7 +246,7 @@ def _compute_index_tensor_params(self, x, m, index_shapes): return x_data, S, W, K, C, trailing, lin_scales def call_operator(self, op, args, kwargs, meta): - if op not in self._TARGET_OPS: + if op not in self.target_ops: return super().call_operator(op, args, kwargs, meta) assert ( diff --git a/backends/arm/_passes/decompose_int_pow_pass.py b/backends/arm/_passes/decompose_int_pow_pass.py index a31a9415e23..5147d23b68c 100644 --- a/backends/arm/_passes/decompose_int_pow_pass.py +++ b/backends/arm/_passes/decompose_int_pow_pass.py @@ -6,12 +6,12 @@ from typing import Optional, Set, Type -from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes import ArmOpTargetedPass from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass -class DecomposeIntPowPass(ArmPass): +class DecomposeIntPowPass(ArmOpTargetedPass): """Replaces pow with integer exponent with a series of multiplications. Only handles pow.Tensor_Scalar and not pow.Tensor_Tensor. Needs to be run @@ -20,6 +20,7 @@ class DecomposeIntPowPass(ArmPass): """ _passes_required_after: Set[Type[ExportPass]] = set() + target_ops = (exir_ops.edge.aten.pow.Tensor_Scalar,) @staticmethod def _get_decomposable_integer_exponent(exp) -> Optional[int]: @@ -34,7 +35,7 @@ def _get_decomposable_integer_exponent(exp) -> Optional[int]: return None def call_operator(self, op, args, kwargs, meta): - if op != exir_ops.edge.aten.pow.Tensor_Scalar: + if op not in self.target_ops: return super().call_operator(op, args, kwargs, meta) if self._is_quantized_meta(meta): diff --git a/backends/arm/_passes/decompose_leaky_relu_pass.py b/backends/arm/_passes/decompose_leaky_relu_pass.py index eb8b5bda61a..e2f9852d7f9 100644 --- a/backends/arm/_passes/decompose_leaky_relu_pass.py +++ b/backends/arm/_passes/decompose_leaky_relu_pass.py @@ -8,7 +8,7 @@ from typing import Set, Type import torch -from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes import ArmOpTargetedPass from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -33,7 +33,7 @@ def _get_leaky_relu_ops(op) -> tuple: raise RuntimeError(f"Can't get decomposition ops for op {op}") -class DecomposeLeakyReLUPass(ArmPass): +class DecomposeLeakyReLUPass(ArmOpTargetedPass): """This pass decomposes Leaky ReLU into primitive operations. LeakyReLU(x,slope) = max(0,x) + slope * min(0,x) @@ -47,9 +47,11 @@ class DecomposeLeakyReLUPass(ArmPass): """ _passes_required_after: Set[Type[ExportPass]] = set() + target_ops = edge_ops + torch_ops + check_allowed_to_transform = True def call_operator(self, op, args, kwargs, meta): - if op not in (edge_ops + torch_ops) or not self.allowed_to_transform(meta): + if op not in self.target_ops or not self.allowed_to_transform(meta): return super().call_operator(op, args, kwargs, meta) x = args[0] diff --git a/backends/arm/_passes/decompose_linalg_vector_norm_pass.py b/backends/arm/_passes/decompose_linalg_vector_norm_pass.py index 8b165658c37..1604d861030 100644 --- a/backends/arm/_passes/decompose_linalg_vector_norm_pass.py +++ b/backends/arm/_passes/decompose_linalg_vector_norm_pass.py @@ -6,13 +6,13 @@ from typing import Set, Type import torch -from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes import ArmOpTargetedPass from executorch.backends.arm._passes.decompose_sqrt_pass import DecomposeSqrtPass from executorch.backends.arm._passes.decompose_sum_pass import DecomposeSumPass from executorch.exir.pass_base import ExportPass -class DecomposeLinalgVectorNormPass(ArmPass): +class DecomposeLinalgVectorNormPass(ArmOpTargetedPass): """This pass decomposes aten.linalg_vector_norm.default into more primitive ops. We need to add this pass before quantization for graph annotation. By default, aten.linalg_vector_norm op is decomposed during legalization to @@ -40,11 +40,11 @@ class DecomposeLinalgVectorNormPass(ArmPass): } torch_linalg_vector_norm = (torch.ops.aten.linalg_vector_norm.default,) + target_ops = torch_linalg_vector_norm + check_allowed_to_transform = True def call_operator(self, op, args, kwargs, meta): - if op not in self.torch_linalg_vector_norm or not self.allowed_to_transform( - meta - ): + if op not in self.target_ops or not self.allowed_to_transform(meta): return super().call_operator(op, args, kwargs, meta) # Extract inputs and optional arguments. diff --git a/backends/arm/_passes/decompose_log1p_pass.py b/backends/arm/_passes/decompose_log1p_pass.py index b5cb8659140..7cc5f8cec9c 100644 --- a/backends/arm/_passes/decompose_log1p_pass.py +++ b/backends/arm/_passes/decompose_log1p_pass.py @@ -6,7 +6,7 @@ import logging from typing import Set, Type -from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes import ArmOpTargetedPass from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass @@ -17,7 +17,7 @@ from executorch.exir.pass_base import ExportPass -class DecomposeLog1pPass(ArmPass): +class DecomposeLog1pPass(ArmOpTargetedPass): """Decompose log1p into a small polynomial with a log fallback for larger inputs. """ @@ -32,6 +32,7 @@ class DecomposeLog1pPass(ArmPass): _supported_ops = { exir_ops.edge.aten.log1p.default, } + target_ops = _supported_ops def _poly(self, x, meta): # 6-term Taylor: x - x^2/2 + x^3/3 - x^4/4 + x^5/5 - x^6/6 @@ -63,7 +64,7 @@ def _poly(self, x, meta): return acc def call_operator(self, op, args, kwargs, meta): - if op not in self._supported_ops: + if op not in self.target_ops: return super().call_operator(op, args, kwargs, meta, updated=False) if self._is_quantized_meta(meta): diff --git a/backends/arm/_passes/decompose_logit_pass.py b/backends/arm/_passes/decompose_logit_pass.py index fa82ff4f579..9f9f4744fd0 100644 --- a/backends/arm/_passes/decompose_logit_pass.py +++ b/backends/arm/_passes/decompose_logit_pass.py @@ -7,7 +7,7 @@ import torch -from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes import ArmOpTargetedPass from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass @@ -50,7 +50,7 @@ def get_ops(op): raise ValueError(f"Unsupported operator: {op}") -class DecomposeLogitPass(ArmPass): +class DecomposeLogitPass(ArmOpTargetedPass): """Decomposes the `logit` operator into a sequence of primitive operations. If `eps` is provided, the input tensor `x` is first clamped to the range @@ -78,15 +78,13 @@ class DecomposeLogitPass(ArmPass): ReplaceScalarWithTensorByProfilePass, } - _TARGET_OPS = { + target_ops = { edge_logit, aten_logit, } def call_operator(self, op, args, kwargs, meta): - if op not in DecomposeLogitPass._TARGET_OPS or not self.allowed_to_transform( - meta - ): + if op not in self.target_ops or not self.allowed_to_transform(meta): return super().call_operator(op, args, kwargs, meta) X = args[0] diff --git a/backends/arm/_passes/decompose_masked_fill_pass.py b/backends/arm/_passes/decompose_masked_fill_pass.py index 748aee3fc49..dfb85da7742 100644 --- a/backends/arm/_passes/decompose_masked_fill_pass.py +++ b/backends/arm/_passes/decompose_masked_fill_pass.py @@ -8,7 +8,7 @@ import torch -from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes import ArmOpTargetedPass from executorch.backends.arm._passes.convert_full_like_to_full_pass import ( ConvertFullLikeToFullPass, ) @@ -34,7 +34,7 @@ def _get_decomposition(op) -> tuple: raise RuntimeError(f"Unable to get decomposition for op {op}") -class DecomposeMaskedFillPass(ArmPass): +class DecomposeMaskedFillPass(ArmOpTargetedPass): """Masked fill takes in a boolean mask, a tensor and a scalar value. Fills the tensor with the scalar value according to the boolean mask. @@ -43,9 +43,10 @@ class DecomposeMaskedFillPass(ArmPass): """ _passes_required_after: Set[Type[ExportPass]] = {ConvertFullLikeToFullPass} + target_ops = aten_ops + edge_ops def call_operator(self, op, args, kwargs, meta, updated=False): - if op not in (*aten_ops, *edge_ops): + if op not in self.target_ops: return super().call_operator(op, args, kwargs, meta, updated) x, mask, scalar = args diff --git a/backends/arm/_passes/decompose_maxpool2d_with_dilation_pass.py b/backends/arm/_passes/decompose_maxpool2d_with_dilation_pass.py index 72fe53d57b9..7729b755113 100644 --- a/backends/arm/_passes/decompose_maxpool2d_with_dilation_pass.py +++ b/backends/arm/_passes/decompose_maxpool2d_with_dilation_pass.py @@ -9,7 +9,7 @@ import torch -from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes import ArmOpTargetedPass from executorch.backends.arm._passes.size_adjust_input_pass import SizeAdjustInputPass from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -47,7 +47,7 @@ def _pack_dimension( return packed_dim_size, padding + extra_padding, output_size -class DecomposeMaxPool2dPass(ArmPass): +class DecomposeMaxPool2dPass(ArmOpTargetedPass): """Decompose dilated max_pool2d (EXIR edge ops) into space-to-batch -> maxpool -> batch-to-space. """ @@ -55,10 +55,11 @@ class DecomposeMaxPool2dPass(ArmPass): _passes_required_after: Set[Type[ExportPass]] = { SizeAdjustInputPass, } + target_ops = EDGE_MAXPOOL2D def call_operator(self, op, args, kwargs, meta): # Only intercept EXIR edge max_pool2d ops - if op not in EDGE_MAXPOOL2D: + if op not in self.target_ops: return super().call_operator(op, args, kwargs, meta) # detect whether indices variant diff --git a/backends/arm/_passes/decompose_meandim_pass.py b/backends/arm/_passes/decompose_meandim_pass.py index c7d3bc0a04d..e1175d5ba1b 100644 --- a/backends/arm/_passes/decompose_meandim_pass.py +++ b/backends/arm/_passes/decompose_meandim_pass.py @@ -8,7 +8,7 @@ from typing import Set, Type import torch -from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes import ArmOpTargetedPass from executorch.backends.arm._passes.arm_pass_utils import get_node_arg from executorch.backends.arm._passes.decompose_sum_pass import DecomposeSumPass from executorch.backends.arm._passes.fuse_constant_ops_pass import ( @@ -69,7 +69,7 @@ def get_quantization(op): return None -class DecomposeMeanDimPass(ArmPass): +class DecomposeMeanDimPass(ArmOpTargetedPass): """Decomposes a meandim into sum + mul (1/N). Each reduction dimension is handled via REDUCE_SUM followed by @@ -94,6 +94,13 @@ class DecomposeMeanDimPass(ArmPass): DecomposeSumPass, SizeAdjustInputPass, } + target_ops = ( + exir_ops.edge.aten.mean.dim, + torch.ops.aten.mean.dim, + exir_ops.edge.aten.mean.default, + torch.ops.aten.mean.default, + ) + check_allowed_to_transform = True def __init__(self, graph_module, tosa_spec, *args, **kwargs): super().__init__(*args, **kwargs) @@ -101,12 +108,7 @@ def __init__(self, graph_module, tosa_spec, *args, **kwargs): self._tosa_spec = tosa_spec def call_operator(self, op, args, kwargs, meta, updated=False): - if op not in ( - exir_ops.edge.aten.mean.dim, - torch.ops.aten.mean.dim, - exir_ops.edge.aten.mean.default, - torch.ops.aten.mean.default, - ) or not self.allowed_to_transform(meta): + if op not in self.target_ops or not self.allowed_to_transform(meta): return super().call_operator(op, args, kwargs, meta, updated) x = get_node_arg(args, 0) diff --git a/backends/arm/_passes/decompose_ne_pass.py b/backends/arm/_passes/decompose_ne_pass.py index 95dfc0e1179..4dfcf6ad934 100644 --- a/backends/arm/_passes/decompose_ne_pass.py +++ b/backends/arm/_passes/decompose_ne_pass.py @@ -6,7 +6,7 @@ from typing import Set, Type import torch -from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes import ArmOpTargetedPass from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -38,7 +38,7 @@ def get_ne_decomposition(op) -> tuple: raise RuntimeError(f"Can't get ne decomposition for op {op}") -class DecomposeNotEqualPass(ArmPass): +class DecomposeNotEqualPass(ArmOpTargetedPass): """A transformation pass that decomposes unsupported `aten.ne` operations into a combination of supported TOSA-equivalent operations. @@ -57,9 +57,10 @@ class DecomposeNotEqualPass(ArmPass): """ _passes_required_after: Set[Type[ExportPass]] = set() + target_ops = edge_ne_ops + aten_ne_ops def call_operator(self, op, args, kwargs, meta): - if op not in (edge_ne_ops + aten_ne_ops) or not self.allowed_to_transform(meta): + if op not in self.target_ops or not self.allowed_to_transform(meta): return super().call_operator(op, args, kwargs, meta) lhs, rhs = args diff --git a/backends/arm/_passes/decompose_permute_for_u55_pass.py b/backends/arm/_passes/decompose_permute_for_u55_pass.py index ceed25f97ec..a9e8beef1cd 100644 --- a/backends/arm/_passes/decompose_permute_for_u55_pass.py +++ b/backends/arm/_passes/decompose_permute_for_u55_pass.py @@ -11,7 +11,7 @@ import torch import tosa_serializer as ts -from executorch.backends.arm._passes.arm_pass import ArmPass +from executorch.backends.arm._passes.arm_pass import ArmOpTargetedPass from executorch.backends.arm._passes.rewrite_slice import RewriteSlicePass from executorch.backends.arm.arm_vela import vela_compile from executorch.backends.arm.tosa.mapping import map_dtype @@ -20,7 +20,7 @@ from executorch.exir.pass_base import ExportPass -class DecomposePermuteForU55Pass(ArmPass): +class DecomposePermuteForU55Pass(ArmOpTargetedPass): """Decompose U55 permutes into shape-safe permutes for large tensor shapes. Ethos-U55 has transpose shape constraints based on rank-dependent @@ -36,6 +36,7 @@ class DecomposePermuteForU55Pass(ArmPass): exir_ops.edge.aten.permute.default, exir_ops.edge.aten.permute_copy.default, ) + target_ops = _PERMUTE_OPS _SLICE_OP = exir_ops.edge.aten.slice_copy.Tensor _CAT_OP = exir_ops.edge.aten.cat.default _MAX_PRODUCT = 2**16 @@ -323,7 +324,7 @@ def recurse(current, depth: int): return recurse(input_node, 0) def call_operator(self, op, args, kwargs, meta): - if op not in self._PERMUTE_OPS: + if op not in self.target_ops: return super().call_operator(op, args, kwargs, meta) spec = get_context_spec() diff --git a/backends/arm/_passes/decompose_remainder_pass.py b/backends/arm/_passes/decompose_remainder_pass.py index 38185b85149..af22cad1624 100644 --- a/backends/arm/_passes/decompose_remainder_pass.py +++ b/backends/arm/_passes/decompose_remainder_pass.py @@ -6,7 +6,7 @@ from typing import Dict, Set, Type import torch -from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes import ArmOpTargetedPass from executorch.backends.arm._passes.decompose_div_tensor_mode import ( DecomposeDivTensorModePass, ) @@ -41,7 +41,7 @@ } -class DecomposeRemainderPass(ArmPass): +class DecomposeRemainderPass(ArmOpTargetedPass): """ Decompose the remainder operation into primitive arithmetic: remainder(x, y) -> x - floor_div(x, y) * y @@ -49,15 +49,10 @@ class DecomposeRemainderPass(ArmPass): """ _passes_required_after: Set[Type[ExportPass]] = {DecomposeDivTensorModePass} + target_ops = tuple(_decomposition_ops) def call_operator(self, op, args, kwargs, meta, updated=False): - supported_ops = ( - exir_ops.edge.aten.remainder.Scalar, - exir_ops.edge.aten.remainder.Tensor, - torch.ops.aten.remainder.Scalar, - torch.ops.aten.remainder.Tensor, - ) - if op not in supported_ops: + if op not in self.target_ops: return super().call_operator(op, args, kwargs, meta, updated) # Keep scalar remainder opaque during transform-for-annotation so the # quantizer can wrap the original op directly. In the backend pipeline, diff --git a/backends/arm/_passes/decompose_round_pass.py b/backends/arm/_passes/decompose_round_pass.py index 9319394d986..476f75d6b56 100644 --- a/backends/arm/_passes/decompose_round_pass.py +++ b/backends/arm/_passes/decompose_round_pass.py @@ -6,7 +6,7 @@ from typing import Set, Type import torch -from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes import ArmOpTargetedPass from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.dialects.edge._ops import EdgeOpOverload from executorch.exir.pass_base import ExportPass @@ -46,7 +46,7 @@ def _get_round_decomposition_ops(op) -> tuple[Op, Op, Op, Op, Op, Op, Op]: raise RuntimeError(f"Can't get round decomposition ops for op {op}") -class DecomposeRoundPass(ArmPass): +class DecomposeRoundPass(ArmOpTargetedPass): """ For inputs >= 0, round(x) is equivalent to floor(x + 0.5), and for inputs < 0, round(x) is equivalent to ceil(x - 0.5). This pass decomposes the round operation into @@ -63,15 +63,13 @@ class DecomposeRoundPass(ArmPass): _passes_required_after: Set[Type[ExportPass]] = set() - _TARGET_OPS = { + target_ops = { exir_ops.edge.aten.round.default, torch.ops.aten.round.default, } def call_operator(self, op, args, kwargs, meta, updated=False): - if op not in DecomposeRoundPass._TARGET_OPS or not self.allowed_to_transform( - meta - ): + if op not in self.target_ops or not self.allowed_to_transform(meta): return super().call_operator(op, args, kwargs, meta, updated) x = args[0] input_dtype = x.node.meta["val"].dtype diff --git a/backends/arm/_passes/decompose_select_scatter_pass.py b/backends/arm/_passes/decompose_select_scatter_pass.py index 4b4db8d208c..129e9f05961 100644 --- a/backends/arm/_passes/decompose_select_scatter_pass.py +++ b/backends/arm/_passes/decompose_select_scatter_pass.py @@ -7,7 +7,7 @@ import torch -from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes import ArmOpTargetedPass from executorch.backends.arm._passes.convert_int64_const_ops_to_int32 import ( ConvertInt64ConstOpsToInt32Pass, ) @@ -44,7 +44,7 @@ def get_select_scatter_decomposition(op) -> tuple: raise RuntimeError(f"Can't get select_scatter decomposition for op {op}") -class DecomposeSelectScatterPass(ArmPass): +class DecomposeSelectScatterPass(ArmOpTargetedPass): """select_scatter is decomposed into other ops during export, however this is only suppported for the fp profile and for the int profile we need to decompose it here. @@ -65,9 +65,10 @@ class DecomposeSelectScatterPass(ArmPass): ReplaceScalarWithTensorByProfilePass, ConvertInt64ConstOpsToInt32Pass, } + target_ops = edge_scatter_ops + aten_scatter_ops def call_operator(self, op, args, kwargs, meta): - if op not in (edge_scatter_ops + aten_scatter_ops): + if op not in self.target_ops: return super().call_operator(op, args, kwargs, meta, updated=False) ( diff --git a/backends/arm/_passes/decompose_sign_pass.py b/backends/arm/_passes/decompose_sign_pass.py index 111d1ca5ee3..8f7fda8729b 100644 --- a/backends/arm/_passes/decompose_sign_pass.py +++ b/backends/arm/_passes/decompose_sign_pass.py @@ -7,7 +7,7 @@ import torch -from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes import ArmOpTargetedPass from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -44,15 +44,16 @@ def get_ops(op): raise ValueError(f"Unsupported operator: {op}") -class DecomposeSignPass(ArmPass): +class DecomposeSignPass(ArmOpTargetedPass): """Decomposes the sign operator into a sequence of operations that are supported by the Arm backend. """ _passes_required_after: Set[Type[ExportPass]] = set() + target_ops = (edge_sign, aten_sign) def call_operator(self, op, args, kwargs, meta): - if op not in (edge_sign, aten_sign) or not self.allowed_to_transform(meta): + if op not in self.target_ops or not self.allowed_to_transform(meta): return super().call_operator(op, args, kwargs, meta) gt_op, lt_op, where_op, neg_op, mul_op, add_op = get_ops(op) diff --git a/backends/arm/_passes/decompose_sinh_pass.py b/backends/arm/_passes/decompose_sinh_pass.py index 71ac0a34f08..053b378af83 100644 --- a/backends/arm/_passes/decompose_sinh_pass.py +++ b/backends/arm/_passes/decompose_sinh_pass.py @@ -6,7 +6,7 @@ from typing import Set, Type -from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes import ArmOpTargetedPass from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass @@ -21,7 +21,7 @@ edge_sinh = exir_ops.edge.aten.sinh.default -class DecomposeSinhPass(ArmPass): +class DecomposeSinhPass(ArmOpTargetedPass): """A decomposition pass that decomposes Sinh operations into a combination of supported TOSA-equivalent operations (MI). @@ -39,9 +39,10 @@ class DecomposeSinhPass(ArmPass): ReplaceScalarWithTensorByProfilePass, MatchArgDtypePass, } + target_ops = (edge_sinh,) def call_operator(self, op, args, kwargs, meta): - if op is not edge_sinh: + if op not in self.target_ops: return super().call_operator(op, args, kwargs, meta) if self._is_quantized_meta(meta): diff --git a/backends/arm/_passes/decompose_slice_scatter_pass.py b/backends/arm/_passes/decompose_slice_scatter_pass.py index 24cdfeb96a5..edf030f9701 100644 --- a/backends/arm/_passes/decompose_slice_scatter_pass.py +++ b/backends/arm/_passes/decompose_slice_scatter_pass.py @@ -7,7 +7,7 @@ import torch -from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes import ArmOpTargetedPass from executorch.backends.arm._passes.accumulate_index_put_pass import ( AccumulateIndexPutPass, ) @@ -53,7 +53,7 @@ def _fixup_end(end, dim_size: int) -> int: return max(0, min(e, dim_size)) -class DecomposeSliceScatterPass(ArmPass): +class DecomposeSliceScatterPass(ArmOpTargetedPass): """ Decompose slice_scatter into: - Fast path (step == 1): slice_copy + cat (contiguous update), or @@ -71,9 +71,10 @@ class DecomposeSliceScatterPass(ArmPass): AccumulateIndexPutPass, RewriteIndexPutPass, } + target_ops = edge_slice_scatter_ops + aten_slice_scatter_ops def call_operator(self, op, args, kwargs, meta): - if op not in (edge_slice_scatter_ops + aten_slice_scatter_ops): + if op not in self.target_ops: return super().call_operator(op, args, kwargs, meta) ( diff --git a/backends/arm/_passes/decompose_softmax_pass.py b/backends/arm/_passes/decompose_softmax_pass.py index cb05b7c4b0c..d30137c0460 100644 --- a/backends/arm/_passes/decompose_softmax_pass.py +++ b/backends/arm/_passes/decompose_softmax_pass.py @@ -7,7 +7,7 @@ from typing import Set, Type import torch -from executorch.backends.arm._passes.arm_pass import ArmPass +from executorch.backends.arm._passes.arm_pass import ArmOpTargetedPass from executorch.backends.arm._passes.decompose_sum_pass import DecomposeSumPass from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass from executorch.exir.dialects._ops import ops as exir_ops @@ -56,7 +56,7 @@ def _get_logsoftmax_ops(op) -> tuple: raise RuntimeError(f"Can't get logsoftmax decomposition ops for op {op}") -class DecomposeSoftmaxPass(ArmPass): +class DecomposeSoftmaxPass(ArmOpTargetedPass): """This pass decomposes log_softmax or softmax into more primitive ops. Example: @@ -77,6 +77,7 @@ class DecomposeSoftmaxPass(ArmPass): DecomposeSumPass, InsertTableOpsPass, } + target_ops = torch_softmax + edge_softmax def __init__(self, skip_safe_softmax: bool = False, **kwargs): super().__init__(**kwargs) @@ -84,9 +85,7 @@ def __init__(self, skip_safe_softmax: bool = False, **kwargs): self._warned_safe_softmax = False def call_operator(self, op, args, kwargs, meta): - if op not in torch_softmax + edge_softmax or not self.allowed_to_transform( - meta - ): + if op not in self.target_ops or not self.allowed_to_transform(meta): return super().call_operator(op, args, kwargs, meta) if self._skip_safe_softmax and op == torch.ops.aten._safe_softmax.default: diff --git a/backends/arm/_passes/decompose_sqrt_pass.py b/backends/arm/_passes/decompose_sqrt_pass.py index 86e5d6681bd..ce5a5b6d2a4 100644 --- a/backends/arm/_passes/decompose_sqrt_pass.py +++ b/backends/arm/_passes/decompose_sqrt_pass.py @@ -6,7 +6,7 @@ from typing import Set, Tuple, Type, Union import torch -from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes import ArmOpTargetedPass from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -27,15 +27,14 @@ def get_sqrt_decomposition(op) -> Union[Tuple, torch._ops.OpOverload]: raise RuntimeError(f"Can't get sqrt decomposition for op {op}") -class DecomposeSqrtPass(ArmPass): +class DecomposeSqrtPass(ArmOpTargetedPass): _passes_required_after: Set[Type[ExportPass]] = {InsertTableOpsPass} + target_ops = edge_sqrt_ops + aten_sqrt_ops def call_operator(self, op, args, kwargs, meta): """Decomposes `sqrt(x)` into `pow(x, 0.5)` for backend support.""" - if op not in (edge_sqrt_ops + aten_sqrt_ops) or not self.allowed_to_transform( - meta - ): + if op not in self.target_ops or not self.allowed_to_transform(meta): return super().call_operator(op, args, kwargs, meta) if self._is_quantized_meta(meta): diff --git a/backends/arm/_passes/decompose_strided_slice_copy_pass.py b/backends/arm/_passes/decompose_strided_slice_copy_pass.py index 71cc618ed9c..91606dd0bd6 100644 --- a/backends/arm/_passes/decompose_strided_slice_copy_pass.py +++ b/backends/arm/_passes/decompose_strided_slice_copy_pass.py @@ -6,7 +6,7 @@ from typing import Set, Type import torch -from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes import ArmOpTargetedPass from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -42,7 +42,7 @@ def _fixup_end(end, dim_size): return max(0, min(e, dim_size)) -class DecomposeStridedSliceCopyPass(ArmPass): +class DecomposeStridedSliceCopyPass(ArmOpTargetedPass): """Decompose edge.aten.slice_copy.Tensor with non-unit step into supported ops. @@ -61,10 +61,10 @@ class DecomposeStridedSliceCopyPass(ArmPass): """ _passes_required_after: Set[Type[ExportPass]] = set() - _TARGET_OPS = {exir_ops.edge.aten.slice_copy.Tensor} + target_ops = {exir_ops.edge.aten.slice_copy.Tensor} def call_operator(self, op, args, kwargs, meta): - if op not in self._TARGET_OPS: + if op not in self.target_ops: return super().call_operator(op, args, kwargs, meta) # Only handle the non-unit-step case; leave unit-step to existing lowering. diff --git a/backends/arm/_passes/decompose_sum_pass.py b/backends/arm/_passes/decompose_sum_pass.py index 3076510533e..e134ea6abc7 100644 --- a/backends/arm/_passes/decompose_sum_pass.py +++ b/backends/arm/_passes/decompose_sum_pass.py @@ -6,7 +6,7 @@ from typing import Set, Type import torch -from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes import ArmOpTargetedPass from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -24,7 +24,7 @@ def _get_sum_decomp(op): raise RuntimeError("Unvalid op in DecomposeSumPass") -class DecomposeSumPass(ArmPass): +class DecomposeSumPass(ArmOpTargetedPass): """In Pytorch, the default behaviour of for example Tensor.sum is to squeeze the dimension that is summed (keep_dim = False). However, in TOSA, REDUCE_SUM always preserves the rank of the input (keep_dim = True). To get @@ -44,12 +44,13 @@ class DecomposeSumPass(ArmPass): """ _passes_required_after: Set[Type[ExportPass]] = set() + target_ops = ( + exir_ops.edge.aten.sum.dim_IntList, + torch.ops.aten.sum.dim_IntList, + ) def call_operator(self, op, args, kwargs, meta): - if op not in [ - exir_ops.edge.aten.sum.dim_IntList, - torch.ops.aten.sum.dim_IntList, - ]: + if op not in self.target_ops: return super().call_operator(op, args, kwargs, meta) match len(args): diff --git a/backends/arm/_passes/decompose_tan_pass.py b/backends/arm/_passes/decompose_tan_pass.py index 87b347dbbad..2d655a9937d 100644 --- a/backends/arm/_passes/decompose_tan_pass.py +++ b/backends/arm/_passes/decompose_tan_pass.py @@ -5,7 +5,7 @@ from typing import Set, Type -from executorch.backends.arm._passes import ArmPass, DecomposeDivPass +from executorch.backends.arm._passes import ArmOpTargetedPass, DecomposeDivPass from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -13,13 +13,14 @@ edge_tan_op = exir_ops.edge.aten.tan.default -class DecomposeTanPass(ArmPass): +class DecomposeTanPass(ArmOpTargetedPass): """Decomposes tan to sin/cos.""" _passes_required_after: Set[Type[ExportPass]] = {DecomposeDivPass} + target_ops = (edge_tan_op,) def call_operator(self, op, args, kwargs, meta, updated=False): - if op != edge_tan_op: + if op not in self.target_ops: return super().call_operator(op, args, kwargs, meta, updated) # Skip quantized tan - it is decomposed as one single table op if self._is_quantized_meta(meta): diff --git a/backends/arm/_passes/decompose_tosa_unsupported_clamp_pass.py b/backends/arm/_passes/decompose_tosa_unsupported_clamp_pass.py index 2410ce503a7..12dcd06388c 100644 --- a/backends/arm/_passes/decompose_tosa_unsupported_clamp_pass.py +++ b/backends/arm/_passes/decompose_tosa_unsupported_clamp_pass.py @@ -6,12 +6,12 @@ from typing import Set, Type import torch -from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes import ArmOpTargetedPass from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass -class DecomposeTOSAUnsupportedClampPass(ArmPass): +class DecomposeTOSAUnsupportedClampPass(ArmOpTargetedPass): """Rewrite TOSA unsupported clamp into min/max chain since TOSA lacks int32 clamp support and only supports scalar min/max values. """ @@ -23,6 +23,7 @@ class DecomposeTOSAUnsupportedClampPass(ArmPass): torch.ops.aten.clamp.default, torch.ops.aten.clamp.Tensor, } + target_ops = _supported_ops def _ensure_tensor( self, @@ -54,7 +55,7 @@ def call_operator(self, op, args, kwargs, meta): torch.ops.aten.clamp.Tensor, } - if op not in self._supported_ops: + if op not in self.target_ops: return super().call_operator(op, args, kwargs, meta) # Only rewrite scalar clamp for int32 diff --git a/backends/arm/_passes/decompose_tril_pass.py b/backends/arm/_passes/decompose_tril_pass.py index 3101b24e95b..9108208e73d 100644 --- a/backends/arm/_passes/decompose_tril_pass.py +++ b/backends/arm/_passes/decompose_tril_pass.py @@ -6,7 +6,7 @@ from typing import Set, Type import torch -from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes import ArmOpTargetedPass from executorch.backends.arm._passes.arm_pass_utils import get_node_arg from executorch.backends.arm._passes.fuse_constant_ops_pass import ( ComputeConstantOpsAOTPass, @@ -44,7 +44,7 @@ def _get_ops(op): raise RuntimeError(f"Unable to get decomposition ops for {op}") -class DecomposeTrilPass(ArmPass): +class DecomposeTrilPass(ArmOpTargetedPass): """Tril decomposition. Decomposition: @@ -54,11 +54,10 @@ class DecomposeTrilPass(ArmPass): """ _passes_required_after: Set[Type[ExportPass]] = {ComputeConstantOpsAOTPass} + target_ops = (torch.ops.aten.tril.default,) def call_operator(self, op, args, kwargs, meta): - handled_ops = [torch.ops.aten.tril.default] - - if op not in handled_ops: + if op not in self.target_ops: return super().call_operator(op, args, kwargs, meta) x = args[0] diff --git a/backends/arm/_passes/decompose_unfold_to_gather_pass.py b/backends/arm/_passes/decompose_unfold_to_gather_pass.py index d0e3897080a..950290b3b83 100644 --- a/backends/arm/_passes/decompose_unfold_to_gather_pass.py +++ b/backends/arm/_passes/decompose_unfold_to_gather_pass.py @@ -9,7 +9,7 @@ import torch -from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes import ArmOpTargetedPass from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import ( ReplaceScalarWithTensorByProfilePass, ) @@ -29,7 +29,7 @@ def _get_unfold_copy_decomposition(op) -> tuple: """ - if op in DecomposeUnfoldToGatherPass._TARGET_OPS: + if op in DecomposeUnfoldToGatherPass.target_ops: return ( exir_ops.edge.dim_order_ops._to_dim_order_copy.default, exir_ops.edge.aten.view_copy.default, @@ -45,7 +45,7 @@ def _get_unfold_copy_decomposition(op) -> tuple: raise RuntimeError(f"Can't get unfold_copy decomposition for op {op}") -class DecomposeUnfoldToGatherPass(ArmPass): +class DecomposeUnfoldToGatherPass(ArmOpTargetedPass): """Decompose unfold_copy with backend tosa.GATHER as the core op, plus other TOSA-supported ops to build indices and materialize the output layout. @@ -93,7 +93,7 @@ class DecomposeUnfoldToGatherPass(ArmPass): ReplaceScalarWithTensorByProfilePass, } - _TARGET_OPS = { + target_ops = { exir_ops.edge.aten.unfold_copy.default, } @@ -147,7 +147,7 @@ def _compute_unfold_copy_params( return (x_val, C, S, K, U, UC, pre, post, P, Q, needs_bool_cast) def call_operator(self, op, args, kwargs, meta): - if op not in self._TARGET_OPS: + if op not in self.target_ops: return super().call_operator(op, args, kwargs, meta) x, dim, size, step = args diff --git a/backends/arm/_passes/decompose_var_pass.py b/backends/arm/_passes/decompose_var_pass.py index fcf61cf5129..90ea80b6b47 100644 --- a/backends/arm/_passes/decompose_var_pass.py +++ b/backends/arm/_passes/decompose_var_pass.py @@ -8,7 +8,7 @@ from typing import Set, Type import torch -from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes import ArmOpTargetedPass from executorch.backends.arm._passes.arm_pass_utils import get_node_arg from executorch.backends.arm._passes.decompose_meandim_pass import DecomposeMeanDimPass from executorch.backends.arm._passes.decompose_sum_pass import DecomposeSumPass @@ -37,7 +37,7 @@ def get_var_decomposition(op) -> tuple: raise RuntimeError(f"Can't get var decomposition for op {op}") -class DecomposeVarPass(ArmPass): +class DecomposeVarPass(ArmOpTargetedPass): """ This pass decomposes var.correction and var.dim into smaller ops (see https://pytorch.org/docs/stable/generated/torch.var.html) @@ -56,13 +56,15 @@ class DecomposeVarPass(ArmPass): DecomposeMeanDimPass, DecomposeSumPass, } + target_ops = ( + exir_ops.edge.aten.var.correction, + torch.ops.aten.var.correction, + torch.ops.aten.var.dim, + ) + check_allowed_to_transform = True def call_operator(self, op, args, kwargs, meta): - if op not in ( - exir_ops.edge.aten.var.correction, - torch.ops.aten.var.correction, - torch.ops.aten.var.dim, - ) or not self.allowed_to_transform(meta): + if op not in self.target_ops or not self.allowed_to_transform(meta): return super().call_operator(op, args, kwargs, meta) x = args[0] diff --git a/backends/arm/_passes/decompose_where_scalar_other_pass.py b/backends/arm/_passes/decompose_where_scalar_other_pass.py index a125a6355cb..8b4b27c8ce2 100644 --- a/backends/arm/_passes/decompose_where_scalar_other_pass.py +++ b/backends/arm/_passes/decompose_where_scalar_other_pass.py @@ -5,7 +5,7 @@ from typing import Set, Type -from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes import ArmOpTargetedPass from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -27,20 +27,18 @@ def _get_where_scalar_other_decomposition(op): raise RuntimeError(f"Can't get where.ScalarOther decomposition for op {op}") -class DecomposeWhereScalarOtherPass(ArmPass): +class DecomposeWhereScalarOtherPass(ArmOpTargetedPass): """Decompose where.ScalarOther into where.self with a tensorized scalar.""" _passes_required_after: Set[Type[ExportPass]] = set() - _TARGET_OPS = { + target_ops = { exir_ops.edge.aten.where.ScalarOther, } + check_allowed_to_transform = True def call_operator(self, op, args, kwargs, meta, updated=False): - if ( - op not in DecomposeWhereScalarOtherPass._TARGET_OPS - or not self.allowed_to_transform(meta) - ): + if op not in self.target_ops or not self.allowed_to_transform(meta): return super().call_operator(op, args, kwargs, meta, updated) condition, self_tensor, other_scalar = args diff --git a/backends/arm/_passes/decorate_fp32_to_int32_casting_pass.py b/backends/arm/_passes/decorate_fp32_to_int32_casting_pass.py index b856df8e060..3ddd1358035 100644 --- a/backends/arm/_passes/decorate_fp32_to_int32_casting_pass.py +++ b/backends/arm/_passes/decorate_fp32_to_int32_casting_pass.py @@ -7,7 +7,7 @@ from typing import Set, Type import torch -from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes import ArmOpTargetedPass from executorch.backends.arm._passes.arm_pass_utils import get_node_arg from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -26,7 +26,7 @@ def _get_decorated_ops(op): raise RuntimeError(f"Can't get decorated ops for op {op}") -class DecorateFp32toInt32CastingPass(ArmPass): +class DecorateFp32toInt32CastingPass(ArmOpTargetedPass): """To lower pytorch fp32 -> int32 casting to TOSA, we need to transform the value with Ceil, Floor, and Where. @@ -47,9 +47,10 @@ class DecorateFp32toInt32CastingPass(ArmPass): targets = [ exir_ops.edge.dim_order_ops._to_dim_order_copy.default, ] + target_ops = targets def call_operator(self, op, args, kwargs, meta): - if op not in self.targets: + if op not in self.target_ops: return super().call_operator(op, args, kwargs, meta) input = get_node_arg(args, 0) diff --git a/backends/arm/_passes/fuse_consecutive_concat_shapes.py b/backends/arm/_passes/fuse_consecutive_concat_shapes.py index 8a02697d57c..fc2d46d3c12 100644 --- a/backends/arm/_passes/fuse_consecutive_concat_shapes.py +++ b/backends/arm/_passes/fuse_consecutive_concat_shapes.py @@ -6,12 +6,12 @@ from typing import Any import torch -from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes import ArmOpTargetedPass from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import NodeMetadata, ProxyValue -class FuseConsecutiveConcatShapesPass(ArmPass): +class FuseConsecutiveConcatShapesPass(ArmOpTargetedPass): """This pass fuses consecutive tosa.CONCAT_SHAPE operations into a single tosa.CONCAT_SHAPE operation with a flattened list of input shapes. E.g. tosa.CONCAT_SHAPE([shape1, tosa.CONCAT_SHAPE([shape2, shape3]), shape4]) @@ -24,6 +24,7 @@ class FuseConsecutiveConcatShapesPass(ArmPass): """ _passes_required_after = set() + target_ops = (exir_ops.backend.tosa.CONCAT_SHAPE.default,) def _to_proxy_value( self, arg: ProxyValue | torch.fx.Node | Any @@ -42,7 +43,7 @@ def call_operator( meta: NodeMetadata, updated: bool | None = False, ) -> ProxyValue: - if op != exir_ops.backend.tosa.CONCAT_SHAPE.default: + if op not in self.target_ops: return super().call_operator(op, args, kwargs, meta) arg_list = args[0] new_arg_list: list[Any] = [] diff --git a/backends/arm/_passes/insert_const_shapes.py b/backends/arm/_passes/insert_const_shapes.py index b03394379d9..c916438eb09 100644 --- a/backends/arm/_passes/insert_const_shapes.py +++ b/backends/arm/_passes/insert_const_shapes.py @@ -5,12 +5,12 @@ from typing import Any, Optional -from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes import ArmOpTargetedPass from executorch.backends.arm.tosa.dialect.shape import meta_has_shape_mark from executorch.exir.dialects._ops import ops as exir_ops -class InsertConstShapesPass(ArmPass): +class InsertConstShapesPass(ArmOpTargetedPass): """Materialize literal shape arguments as CONST_SHAPE nodes. This pass targets ops such as `aten.view_copy` and `aten.repeat` whose shape @@ -21,11 +21,15 @@ class InsertConstShapesPass(ArmPass): """ _passes_required_after = set() - targeted_ops = { + target_ops = { exir_ops.edge.aten.view_copy.default, exir_ops.edge.aten.repeat.default, } + def __init__(self) -> None: + super().__init__() + self._const_shape_cache: dict[tuple[int, ...], Any] = {} + @staticmethod def _is_shape_arg(arg: Any) -> bool: """Return True when `arg` looks like a literal shape list/tuple.""" @@ -37,7 +41,7 @@ def _is_shape_arg(arg: Any) -> bool: ) def call_operator(self, op, args, kwargs, meta, updated: Optional[bool] = False): - if op not in self.targeted_ops: + if op not in self.target_ops: return super().call_operator(op, args, kwargs, meta, updated) if any(InsertConstShapesPass._is_shape_arg(arg) for arg in args): new_args = [] @@ -46,13 +50,17 @@ def call_operator(self, op, args, kwargs, meta, updated: Optional[bool] = False) # Insert a const node for the shape argument if op == exir_ops.edge.aten.view_copy.default: arg = meta.data["val"].shape - const_node = super().call_shape_operator( - exir_ops.backend.tosa.CONST_SHAPE.default, - (arg,), - {}, - meta, - True, - ) + shape = tuple(arg) + const_node = self._const_shape_cache.get(shape) + if const_node is None: + const_node = super().call_shape_operator( + exir_ops.backend.tosa.CONST_SHAPE.default, + (arg,), + {}, + meta, + True, + ) + self._const_shape_cache[shape] = const_node new_args.append(const_node) updated = True else: diff --git a/backends/arm/_passes/insert_data_layout_casts_pass.py b/backends/arm/_passes/insert_data_layout_casts_pass.py index b760baef6e8..07a2d186895 100644 --- a/backends/arm/_passes/insert_data_layout_casts_pass.py +++ b/backends/arm/_passes/insert_data_layout_casts_pass.py @@ -6,13 +6,13 @@ from typing import Set, Type import torch -from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes import ArmOpTargetedPass from executorch.backends.arm.tosa.specification import get_context_spec from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, NodeMetadata -class InsertDataLayoutCastsPass(ArmPass): +class InsertDataLayoutCastsPass(ArmOpTargetedPass): """Insert casts around data layout operators when their dtype is not supported by the active TOSA specification. @@ -45,7 +45,7 @@ class InsertDataLayoutCastsPass(ArmPass): exir_ops.edge.aten.slice_copy.Tensor, exir_ops.edge.aten.flip.default, } - targeted_ops = _concat_ops | _single_input_ops + target_ops = _concat_ops | _single_input_ops _fp_to_int_map = { torch.float16: torch.int16, @@ -60,7 +60,7 @@ class InsertDataLayoutCastsPass(ArmPass): } def call_operator(self, op, args, kwargs, meta): - if op not in self.targeted_ops: + if op not in self.target_ops: return super().call_operator(op, args, kwargs, meta) if op in self._concat_ops: diff --git a/backends/arm/_passes/insert_dynamic_padding.py b/backends/arm/_passes/insert_dynamic_padding.py index ea03e231ae8..61a5ebd09ca 100644 --- a/backends/arm/_passes/insert_dynamic_padding.py +++ b/backends/arm/_passes/insert_dynamic_padding.py @@ -7,14 +7,14 @@ import torch -from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes import ArmOpTargetedPass from executorch.backends.arm.tosa.dialect.shape import is_shape_op_node from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, ProxyValue -class InsertDynamicPaddingPass(ArmPass): +class InsertDynamicPaddingPass(ArmOpTargetedPass): """This pass rewrites conv operations with padding to use an explicit pad operator before the conv2d operation and setting the padding to zero in the conv2d operator. E.g. conv2d(x, weight, bias, stride, padding, dilation) @@ -27,6 +27,10 @@ class InsertDynamicPaddingPass(ArmPass): """ _passes_required_after: Set[Type[ExportPass]] = set() + target_ops = ( + exir_ops.backend.tosa.CONV2D.default, + exir_ops.backend.tosa.DEPTHWISE_CONV2D.default, + ) def _is_dynamic_padding( self, padding: ProxyValue | list[int] | tuple[int, ...] @@ -39,10 +43,7 @@ def _is_dynamic_padding( ) def call_operator(self, op, args, kwargs, meta, updated=False) -> ProxyValue: - if op not in ( - exir_ops.backend.tosa.CONV2D.default, - exir_ops.backend.tosa.DEPTHWISE_CONV2D.default, - ): + if op not in self.target_ops: return super().call_operator(op, args, kwargs, meta, updated) padding = args[4] if not self._is_dynamic_padding(padding): diff --git a/backends/arm/_passes/insert_rescales_pass.py b/backends/arm/_passes/insert_rescales_pass.py index 06c27005440..45374c12c3b 100644 --- a/backends/arm/_passes/insert_rescales_pass.py +++ b/backends/arm/_passes/insert_rescales_pass.py @@ -509,7 +509,13 @@ def _rescale_submodule_inputs( input_node = input_nodes[qargs_index] if len(input_node.users) == 0: continue - if len(out_qparams_map := input_node.meta.get("output_qparams", {})) != 1: + out_qparams_map = input_node.meta.get("output_qparams", {}) + if len(out_qparams_map) == 0: + # Nested control-flow submodules may also expose frozen captured + # values as placeholders. Those are not control-flow boundary + # inputs, so there is no qparam pair to bridge with a RESCALE. + continue + if len(out_qparams_map) != 1: raise ValueError( f"Expected submodule input {input_node} to have exactly one output qparam, got {out_qparams_map}" ) diff --git a/backends/arm/_passes/match_arg_ranks_pass.py b/backends/arm/_passes/match_arg_ranks_pass.py index 905286e39b0..199eafe0cfb 100644 --- a/backends/arm/_passes/match_arg_ranks_pass.py +++ b/backends/arm/_passes/match_arg_ranks_pass.py @@ -57,6 +57,7 @@ def __init__(self, exported_program: ExportedProgram, *args, **kwargs) -> None: exir_ops.edge.aten.ge.Tensor, exir_ops.edge.aten.lt.Tensor, exir_ops.edge.aten.le.Tensor, + exir_ops.edge.aten.ne.Tensor, exir_ops.edge.aten.pow.Tensor_Tensor, exir_ops.edge.aten.remainder.Tensor, exir_ops.edge.aten.where.self, diff --git a/backends/arm/_passes/normalize_index_put_bool_index_tensor_pass.py b/backends/arm/_passes/normalize_index_put_bool_index_tensor_pass.py index 9377eaec2fe..badc58b06fb 100644 --- a/backends/arm/_passes/normalize_index_put_bool_index_tensor_pass.py +++ b/backends/arm/_passes/normalize_index_put_bool_index_tensor_pass.py @@ -6,13 +6,13 @@ import torch -from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes import ArmOpTargetedPass from executorch.backends.arm._passes.rewrite_index_put_pass import RewriteIndexPutPass from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass -class NormalizeIndexPutBoolIndexTensorPass(ArmPass): +class NormalizeIndexPutBoolIndexTensorPass(ArmOpTargetedPass): """Normalize single boolean mask index_put scalar to where. In the general case, boolean masks are complex and data dependent. The simple case x[mask] = scalar @@ -30,6 +30,7 @@ class NormalizeIndexPutBoolIndexTensorPass(ArmPass): """ _passes_required_after: Set[Type[ExportPass]] = {RewriteIndexPutPass} + target_ops = (exir_ops.edge.aten.index_put.default,) def __init__(self): super().__init__() @@ -57,7 +58,7 @@ def _is_valid_bool_mask( return True def call_operator(self, op, args, kwargs, meta, updated: bool | None = False): - if op not in (exir_ops.edge.aten.index_put.default,): + if op not in self.target_ops: return super().call_operator(op, args, kwargs, meta, updated) destination, indices_tensor_list, data = args[:3] diff --git a/backends/arm/_passes/normalize_index_put_none_indices_pass.py b/backends/arm/_passes/normalize_index_put_none_indices_pass.py index 7aaace641b0..3afc9732b02 100644 --- a/backends/arm/_passes/normalize_index_put_none_indices_pass.py +++ b/backends/arm/_passes/normalize_index_put_none_indices_pass.py @@ -4,13 +4,13 @@ # LICENSE file in the root directory of this source tree. from typing import Set, Type -from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes import ArmOpTargetedPass from executorch.backends.arm._passes.rewrite_index_put_pass import RewriteIndexPutPass from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass -class NormalizeIndexPutNoneIndicesPass(ArmPass): +class NormalizeIndexPutNoneIndicesPass(ArmOpTargetedPass): """Normalize index_put with None:s in the indices_tensor list by moving None-indexed dims to the channel dimensions (*C_j in RewriteIndexPutPass teminology) by permutating the destination and data tensors. A None-index @@ -41,6 +41,7 @@ class NormalizeIndexPutNoneIndicesPass(ArmPass): """ _passes_required_after: Set[Type[ExportPass]] = {RewriteIndexPutPass} + target_ops = (exir_ops.edge.aten.index_put.default,) def __init__(self): super().__init__() @@ -67,7 +68,7 @@ def _get_data_dim_order( return destination_dim_order def call_operator(self, op, args, kwargs, meta, updated: bool | None = False): - if op not in (exir_ops.edge.aten.index_put.default,): + if op not in self.target_ops: return super().call_operator(op, args, kwargs, meta) destination, indices_tensor_list, data = args[:3] diff --git a/backends/arm/_passes/promote_bool_operands_pass.py b/backends/arm/_passes/promote_bool_operands_pass.py index 4d02646e30a..8e162ded1bd 100644 --- a/backends/arm/_passes/promote_bool_operands_pass.py +++ b/backends/arm/_passes/promote_bool_operands_pass.py @@ -11,19 +11,19 @@ import torch -from executorch.backends.arm._passes.arm_pass import ArmPass +from executorch.backends.arm._passes.arm_pass import ArmOpTargetedPass from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass -class PromoteBoolOperandsPass(ArmPass): +class PromoteBoolOperandsPass(ArmOpTargetedPass): """Promote boolean operands to the appropriate integer dtype for unsupported ops. """ _passes_required_after: Set[Type[ExportPass]] = set() - targeted_ops = { + target_ops = { exir_ops.edge.aten.bitwise_and.Tensor, exir_ops.edge.aten.bitwise_or.Tensor, exir_ops.edge.aten.bitwise_xor.Tensor, @@ -31,7 +31,7 @@ class PromoteBoolOperandsPass(ArmPass): } def call_operator(self, op, args, kwargs, meta): - if op not in self.targeted_ops: + if op not in self.target_ops: return super().call_operator(op, args, kwargs, meta) original_dtypes = [arg.data.dtype for arg in args] diff --git a/backends/arm/_passes/remove_noop_pass.py b/backends/arm/_passes/remove_noop_pass.py index c7fe469c8b8..5fafc848003 100644 --- a/backends/arm/_passes/remove_noop_pass.py +++ b/backends/arm/_passes/remove_noop_pass.py @@ -8,7 +8,7 @@ import logging from typing import Set, Type -from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes import ArmOpTargetedPass from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -16,19 +16,20 @@ logger = logging.getLogger(__name__) -class RemoveNoopPass(ArmPass): +class RemoveNoopPass(ArmOpTargetedPass): """Remove no-ops from graph_module.""" _passes_required_after: Set[Type[ExportPass]] = set() + target_ops = ( + exir_ops.edge.dim_order_ops._clone_dim_order.default, + exir_ops.edge.dim_order_ops._to_dim_order_copy.default, + exir_ops.edge.aten.alias_copy.default, + exir_ops.edge.aten.copy.default, + exir_ops.edge.aten.detach_copy.default, + ) def call_operator(self, op, args, kwargs, meta): - if op not in ( - exir_ops.edge.dim_order_ops._clone_dim_order.default, - exir_ops.edge.dim_order_ops._to_dim_order_copy.default, - exir_ops.edge.aten.alias_copy.default, - exir_ops.edge.aten.copy.default, - exir_ops.edge.aten.detach_copy.default, - ): + if op not in self.target_ops: return super().call_operator(op, args, kwargs, meta) input_dtype = args[0].data.dtype diff --git a/backends/arm/_passes/replace_scalar_with_tensor_pass.py b/backends/arm/_passes/replace_scalar_with_tensor_pass.py index edd5fc97213..53f0e517a7f 100644 --- a/backends/arm/_passes/replace_scalar_with_tensor_pass.py +++ b/backends/arm/_passes/replace_scalar_with_tensor_pass.py @@ -126,4 +126,4 @@ def call_operator(self, op, args, kwargs, meta): return super().call_operator(op, args, kwargs, meta) else: # Do not handle; forward unchanged. - return ExportPass.call_operator(self, op, args, kwargs, meta) + return ArmPass.call_operator(self, op, args, kwargs, meta) diff --git a/backends/arm/_passes/rewrite_avg_pool2d_pass.py b/backends/arm/_passes/rewrite_avg_pool2d_pass.py index bf81505d923..6427b571218 100644 --- a/backends/arm/_passes/rewrite_avg_pool2d_pass.py +++ b/backends/arm/_passes/rewrite_avg_pool2d_pass.py @@ -6,7 +6,7 @@ from typing import Set, Type import torch -from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes import ArmOpTargetedPass from executorch.backends.arm._passes.arm_pass_utils import to_2tuple from executorch.backends.arm.constants import NHWC_INVERSE_ORDER, NHWC_ORDER from executorch.backends.arm.operators.operator_validation_utils import ( @@ -18,11 +18,11 @@ from .fuse_constant_ops_pass import ComputeConstantOpsAOTPass -class RewriteAvgPool2dPass(ArmPass): +class RewriteAvgPool2dPass(ArmOpTargetedPass): """Rewrite aten.avg_pool2d calls to TOSA AVG_POOL2D op.""" # Target the original avg_pool2d operator - targeted_ops = {exir_ops.edge.aten.avg_pool2d.default} + target_ops = {exir_ops.edge.aten.avg_pool2d.default} _passes_required_after: Set[Type[ExportPass]] = { ComputeConstantOpsAOTPass, } @@ -30,7 +30,7 @@ class RewriteAvgPool2dPass(ArmPass): def call_operator(self, op, args, kwargs, meta, updated=False): # Only rewrite avg_pool2d - if op not in self.targeted_ops: + if op not in self.target_ops: return super().call_operator(op, args, kwargs, meta, updated) x = args[0] diff --git a/backends/arm/_passes/rewrite_bool_bitwise_to_logical_pass.py b/backends/arm/_passes/rewrite_bool_bitwise_to_logical_pass.py index 8c6bf6f39ec..962bdbbaf6e 100644 --- a/backends/arm/_passes/rewrite_bool_bitwise_to_logical_pass.py +++ b/backends/arm/_passes/rewrite_bool_bitwise_to_logical_pass.py @@ -7,12 +7,12 @@ from typing import Set, Type import torch -from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes import ArmOpTargetedPass from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass -class RewriteBoolBitwiseToLogicalPass(ArmPass): +class RewriteBoolBitwiseToLogicalPass(ArmOpTargetedPass): """Rewrites ``aten.bitwise_*`` on boolean tensors to ``aten.logical_*``. TOSA ``bitwise_*`` does not support boolean inputs. On boolean tensors, @@ -32,9 +32,10 @@ class RewriteBoolBitwiseToLogicalPass(ArmPass): exir_ops.edge.aten.bitwise_xor.Tensor: exir_ops.edge.aten.logical_xor.default, exir_ops.edge.aten.bitwise_xor.Scalar: exir_ops.edge.aten.logical_xor.default, } + target_ops = tuple(_TARGET_TO_LOGICAL) def call_operator(self, op, args, kwargs, meta): - if op not in self._TARGET_TO_LOGICAL: + if op not in self.target_ops: return super().call_operator(op, args, kwargs, meta) if meta["val"].dtype == torch.bool: diff --git a/backends/arm/_passes/rewrite_high_rank_singleton_permute_pass.py b/backends/arm/_passes/rewrite_high_rank_singleton_permute_pass.py index 1c0bac0ba9c..40a7935f050 100644 --- a/backends/arm/_passes/rewrite_high_rank_singleton_permute_pass.py +++ b/backends/arm/_passes/rewrite_high_rank_singleton_permute_pass.py @@ -5,12 +5,12 @@ from typing import Sequence, Set, Type -from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes import ArmOpTargetedPass from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass -class RewriteHighRankSingletonPermutePass(ArmPass): +class RewriteHighRankSingletonPermutePass(ArmOpTargetedPass): """Rewrite high-rank permute via a lower-rank permute when singleton dims allow it. @@ -30,6 +30,7 @@ class RewriteHighRankSingletonPermutePass(ArmPass): exir_ops.edge.aten.permute.default, exir_ops.edge.aten.permute_copy.default, ) + target_ops = _PERMUTE_OPS @staticmethod def _extract_permutation(permutation_arg: object) -> tuple[int, ...] | None: @@ -46,7 +47,7 @@ def _normalize_permutation( return tuple(dim % rank for dim in permutation) def call_operator(self, op, args, kwargs, meta): - if op not in self._PERMUTE_OPS: + if op not in self.target_ops: return super().call_operator(op, args, kwargs, meta) if len(args) < 2: return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/rewrite_index_put_pass.py b/backends/arm/_passes/rewrite_index_put_pass.py index c0898673fd7..8f2ab4bb830 100644 --- a/backends/arm/_passes/rewrite_index_put_pass.py +++ b/backends/arm/_passes/rewrite_index_put_pass.py @@ -7,7 +7,7 @@ import torch -from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes import ArmOpTargetedPass from executorch.backends.arm._passes.convert_expand_copy_to_repeat import ( ConvertExpandCopyToRepeatPass, ) @@ -31,7 +31,7 @@ def calculate_data_stride(destination_shape: list[int]) -> list[int]: return data_strides -class RewriteIndexPutPass(ArmPass): +class RewriteIndexPutPass(ArmOpTargetedPass): """ This pass transforms index_put with arguments - destination, of shape (*K_i, *C_j) @@ -69,6 +69,7 @@ def __init__(self): FuseViewCopyTransformPass, ConvertExpandCopyToRepeatPass, } + target_ops = (exir_ops.edge.aten.index_put.default,) def _calculate_flat_indices( self, @@ -121,7 +122,7 @@ def _calculate_flat_indices( ) def call_operator(self, op, args, kwargs, meta, updated: bool | None = False): - if op not in (exir_ops.edge.aten.index_put.default,): + if op not in self.target_ops: return super().call_operator(op, args, kwargs, meta) destination, indices_tensor_list, data = args[:3] diff --git a/backends/arm/_passes/rewrite_inplace_arithmetic_pass.py b/backends/arm/_passes/rewrite_inplace_arithmetic_pass.py index f5a484343c5..72683b353ce 100644 --- a/backends/arm/_passes/rewrite_inplace_arithmetic_pass.py +++ b/backends/arm/_passes/rewrite_inplace_arithmetic_pass.py @@ -7,7 +7,7 @@ import torch -from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes import ArmOpTargetedPass from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -23,10 +23,12 @@ } -class RewriteInplaceArithmeticPass(ArmPass): +class RewriteInplaceArithmeticPass(ArmOpTargetedPass): """Rewrite inplace arithmetic ops into functional equivalents.""" _passes_required_after: Set[Type[ExportPass]] = set() + target_ops = tuple(OP_MAP) + check_allowed_to_transform = True def call_operator(self, op, args, kwargs, meta): if not self.allowed_to_transform(meta): diff --git a/backends/arm/_passes/rewrite_le_lt_to_ge_gt_pass.py b/backends/arm/_passes/rewrite_le_lt_to_ge_gt_pass.py index 9119567b7aa..c73279e65d0 100644 --- a/backends/arm/_passes/rewrite_le_lt_to_ge_gt_pass.py +++ b/backends/arm/_passes/rewrite_le_lt_to_ge_gt_pass.py @@ -7,7 +7,7 @@ import torch -from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes import ArmOpTargetedPass from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -19,10 +19,12 @@ } -class RewriteLeLtToGeGtPass(ArmPass): +class RewriteLeLtToGeGtPass(ArmOpTargetedPass): """Rewrite le/lt into ge/gt with swapped inputs.""" _passes_required_after: Set[Type[ExportPass]] = set() + target_ops = tuple(OP_MAP) + check_allowed_to_transform = True def call_operator(self, op, args, kwargs, meta): if not self.allowed_to_transform(meta): diff --git a/backends/arm/_passes/rewrite_max_pool2d_pass.py b/backends/arm/_passes/rewrite_max_pool2d_pass.py index 8a59f2bd4ac..8debb322a6d 100644 --- a/backends/arm/_passes/rewrite_max_pool2d_pass.py +++ b/backends/arm/_passes/rewrite_max_pool2d_pass.py @@ -5,7 +5,7 @@ from typing import Set, Type -from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes import ArmOpTargetedPass from executorch.backends.arm._passes.arm_pass_utils import to_2tuple from executorch.backends.arm.constants import NHWC_INVERSE_ORDER, NHWC_ORDER from executorch.backends.arm.operators.operator_validation_utils import ( @@ -17,13 +17,14 @@ edge_max_pool2d_ops = (exir_ops.edge.aten.max_pool2d.default,) -class RewriteMaxPool2dPass(ArmPass): +class RewriteMaxPool2dPass(ArmOpTargetedPass): """Rewrite max_pool2d ops to TOSA MAX_POOL2D.""" _passes_required_after: Set[Type[ExportPass]] = set() + target_ops = edge_max_pool2d_ops def call_operator(self, op, args, kwargs, meta): - if op not in edge_max_pool2d_ops: + if op not in self.target_ops: return super().call_operator(op, args, kwargs, meta) x = args[0] diff --git a/backends/arm/_passes/rewrite_pad.py b/backends/arm/_passes/rewrite_pad.py index 40523fb559a..250fccab38b 100644 --- a/backends/arm/_passes/rewrite_pad.py +++ b/backends/arm/_passes/rewrite_pad.py @@ -8,18 +8,18 @@ import torch -from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes import ArmOpTargetedPass from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass -class RewritePadPass(ArmPass): +class RewritePadPass(ArmOpTargetedPass): """Rewrite constant_pad_nd operator to TOSA Pad operator with constant mode. """ _passes_required_after: Set[Type[ExportPass]] = set() - targeted_ops = { + target_ops = { exir_ops.edge.aten.constant_pad_nd.default, exir_ops.edge.aten.pad.default, } @@ -145,7 +145,7 @@ def _rewrite_non_constant_pad( return output def call_operator(self, op, args, kwargs, meta, updated=False): - if op not in self.targeted_ops: + if op not in self.target_ops: return super().call_operator(op, args, kwargs, meta) if op == exir_ops.edge.aten.constant_pad_nd.default: diff --git a/backends/arm/_passes/rewrite_slice.py b/backends/arm/_passes/rewrite_slice.py index c0f6e1b6573..2aab2e16539 100644 --- a/backends/arm/_passes/rewrite_slice.py +++ b/backends/arm/_passes/rewrite_slice.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. from typing import Set, Type -from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes import ArmOpTargetedPass from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, ProxyValue @@ -12,10 +12,11 @@ from torch import SymInt -class RewriteSlicePass(ArmPass): +class RewriteSlicePass(ArmOpTargetedPass): """Rewrite slice operations with step of 1 to TOSA slice operators.""" _passes_required_after: Set[Type[ExportPass]] = set() + target_ops = (exir_ops.edge.aten.slice_copy.Tensor,) def _fixup_start(self, start, input_shape, dim) -> int: """Convert negative and out-of-bounds start indices to valid positive @@ -29,7 +30,7 @@ def _fixup_start(self, start, input_shape, dim) -> int: return idx def call_operator(self, op, args, kwargs, meta, updated=False) -> ProxyValue: - if op not in (exir_ops.edge.aten.slice_copy.Tensor,): + if op not in self.target_ops: return super().call_operator(op, args, kwargs, meta, updated) if len(args) == 5 and args[4] != 1: diff --git a/backends/arm/_passes/scalars_to_attribute_pass.py b/backends/arm/_passes/scalars_to_attribute_pass.py index 0473caf91e7..63a38b8cb2f 100644 --- a/backends/arm/_passes/scalars_to_attribute_pass.py +++ b/backends/arm/_passes/scalars_to_attribute_pass.py @@ -8,11 +8,9 @@ import torch from executorch.backends.arm._passes import ArmPass -from executorch.backends.arm._passes.arm_pass_utils import ( - get_cond_while_submodules_nested, - get_first_fake_tensor, -) +from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass +from executorch.exir.graph_module import get_cond_while_submodules from executorch.exir.pass_base import ExportPass, PassResult from torch.fx import GraphModule, Node from torchao.quantization.pt2e.utils import get_new_attr_name_with_prefix @@ -98,7 +96,7 @@ def handle_control_nodes(self, graph_module: GraphModule) -> None: """Apply scalar argument conversion on subgraphs of control-flow nodes. """ - for _, submodule, _ in get_cond_while_submodules_nested(graph_module): + for _, submodule, _ in get_cond_while_submodules(graph_module): for submodule_node in submodule.graph.nodes: self._convert_scalar_args(submodule, submodule_node) diff --git a/backends/arm/ao_ext/__init__.py b/backends/arm/ao_ext/__init__.py new file mode 100644 index 00000000000..fef05a9f6ae --- /dev/null +++ b/backends/arm/ao_ext/__init__.py @@ -0,0 +1,12 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# Import mxfp_transform to trigger registration of the MXFP transforms. +from . import mxfp_transform # noqa: F401 + +from .mxfp import MXFPOpConfig, to_mxfp + + +__all__ = ["MXFPOpConfig", "to_mxfp"] diff --git a/backends/arm/ao_ext/mxfp.py b/backends/arm/ao_ext/mxfp.py new file mode 100644 index 00000000000..783da92590e --- /dev/null +++ b/backends/arm/ao_ext/mxfp.py @@ -0,0 +1,64 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass +from typing import Callable, Optional + +import torch +from executorch.exir._warnings import experimental +from torchao.core.config import AOBaseConfig +from torchao.prototype.mx_formats.config import ScaleCalculationMode +from torchao.quantization import quantize_ + + +def _match_supported_modules(module: torch.nn.Module, _name: str) -> bool: + """Default filter function that matches supported modules.""" + return isinstance(module, torch.nn.Linear) + + +@experimental("This API is experimental and may change without notice.") +@dataclass +class MXFPOpConfig(AOBaseConfig): + """Configuration for Arm MXFP source transforms.""" + + weight_dtype: torch.dtype = torch.float8_e4m3fn + weight_scaling_mode: ScaleCalculationMode = ScaleCalculationMode.RCEIL + + # Only block size of 32 is currently supported for now, so we hardcode it here. + @property + def block_size(self) -> int: + return 32 + + def __post_init__(self) -> None: + if self.weight_dtype not in (torch.float8_e4m3fn, torch.float8_e5m2): + raise ValueError(f"Unsupported weight_dtype: {self.weight_dtype}") + if not isinstance(self.weight_scaling_mode, ScaleCalculationMode): + raise ValueError( + f"Unsupported weight_scaling_mode: {self.weight_scaling_mode}" + ) + + +@experimental("This API is experimental and may change without notice.") +def to_mxfp( + model: torch.nn.Module, + config: MXFPOpConfig, + filter_fn: Optional[Callable[[torch.nn.Module, str], bool]] = None, +) -> None: + """Convert matching modules in ``model`` to Arm MXFP modules in-place. + + Args: + model (torch.nn.Module): Module to transform. Matching submodules are + replaced in-place. + config (MXFPOpConfig): Configuration controlling the MXFP conversion + behavior. + filter_fn (Optional[Callable[[torch.nn.Module, str], bool]]): Optional + predicate that receives a module and its fully qualified name. When + omitted, all modules supported by the MXFP transform are matched. + + """ + if filter_fn is None: + filter_fn = _match_supported_modules + + quantize_(model, config, filter_fn) diff --git a/backends/arm/ao_ext/mxfp_tosa_lib.py b/backends/arm/ao_ext/mxfp_tosa_lib.py new file mode 100644 index 00000000000..4459ec59126 --- /dev/null +++ b/backends/arm/ao_ext/mxfp_tosa_lib.py @@ -0,0 +1,11 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from torch.library import Library + +# MXFP TOSA library definition for the Arm backend containing. +# This library will generate custom ops like the following example: +# torch.ops.tosa_mxfp.linear.default +MXFP_TOSA_LIB = Library("tosa_mxfp", "DEF") diff --git a/backends/arm/ao_ext/mxfp_transform.py b/backends/arm/ao_ext/mxfp_transform.py new file mode 100644 index 00000000000..b7823524475 --- /dev/null +++ b/backends/arm/ao_ext/mxfp_transform.py @@ -0,0 +1,24 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch + +from executorch.backends.arm.ao_ext.mxfp import MXFPOpConfig +from executorch.backends.arm.ao_ext.ops.mxfp_linear_op import transform_linear_to_mxfp +from torchao.quantization.transform_module import register_quantize_module_handler + + +@register_quantize_module_handler(MXFPOpConfig) # type: ignore[misc] +def _transform_to_mxfp( + module: torch.nn.Module, + config: MXFPOpConfig, +) -> torch.nn.Module: + """Transforms a given module to use MXFP operations based on the provided + MXFPOpConfig configuration. + """ + if isinstance(module, torch.nn.Linear): + return transform_linear_to_mxfp(module, config) + else: + return module diff --git a/backends/arm/ao_ext/ops/__init__.py b/backends/arm/ao_ext/ops/__init__.py new file mode 100644 index 00000000000..a690c4b7b02 --- /dev/null +++ b/backends/arm/ao_ext/ops/__init__.py @@ -0,0 +1,10 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from .mxfp_linear_op import MXFPLinearOp + +__all__ = [ + "MXFPLinearOp", +] diff --git a/backends/arm/ao_ext/ops/mxfp_linear_op.py b/backends/arm/ao_ext/ops/mxfp_linear_op.py new file mode 100644 index 00000000000..5238f85a847 --- /dev/null +++ b/backends/arm/ao_ext/ops/mxfp_linear_op.py @@ -0,0 +1,179 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +"""MXFP Linear transform for the Arm backend. + +TorchAO extension for MXFP linear. It replaces ``nn.Linear`` with a wrapper +module that stores precomputed MXFP weights and emits a backend-internal custom +op during export. + +""" + +import torch +import torch.nn.functional as F +from executorch.backends.arm.ao_ext.mxfp import MXFPOpConfig +from executorch.backends.arm.ao_ext.mxfp_tosa_lib import MXFP_TOSA_LIB +from torchao.prototype.mx_formats.config import ScaleCalculationMode +from torchao.prototype.mx_formats.mx_tensor import to_dtype, to_mx + +MXFP_TOSA_LIB.define( + "linear(Tensor input, Tensor weight_qdata, Tensor weight_scale, " + "Tensor? bias=None, SymInt block_size=32) -> Tensor" +) + + +@torch.library.register_fake("tosa_mxfp::linear", lib=MXFP_TOSA_LIB) # type: ignore[misc] +def _mxfp_linear_fake( + input: torch.Tensor, + weight_qdata: torch.Tensor, + weight_scale: torch.Tensor, + bias: torch.Tensor | None = None, + block_size: int = 32, +) -> torch.Tensor: + if weight_qdata.ndim != 3: + raise ValueError( + f"Expected weight_qdata to be rank 3 for linear, got {weight_qdata.ndim}" + ) + if weight_qdata.shape[0] != 1: + raise ValueError( + f"Expected weight_qdata batch dim to be 1, got {weight_qdata.shape[0]}" + ) + if input.shape[-1] != weight_qdata.shape[-1]: + raise ValueError( + f"Input last dim {input.shape[-1]} must match linear in_features " + f"{weight_qdata.shape[-1]}" + ) + expected_scale_shape = ( + 1, + weight_qdata.shape[1], + weight_qdata.shape[-1] // block_size, + ) + if tuple(weight_scale.shape) != expected_scale_shape: + raise ValueError( + f"Expected weight_scale shape {expected_scale_shape}, got " + f"{tuple(weight_scale.shape)}" + ) + output_shape = (*input.shape[:-1], weight_qdata.shape[1]) + return input.new_empty(output_shape, dtype=torch.float32) + + +def _cast_to_block_scaled_cpu_ref( + input: torch.Tensor, + output_dtype: torch.dtype, + block_size: int, +) -> torch.Tensor: + """Emulate the current TOSA activation cast in eager mode.""" + input_scale, input_qdata = to_mx( + input.to(torch.float32).contiguous(), + elem_dtype=output_dtype, + block_size=block_size, + scaling_mode=ScaleCalculationMode.RCEIL, + ) + return to_dtype( + input_qdata, + input_scale, + output_dtype, + block_size, + torch.float32, + ) + + +@torch.library.impl("tosa_mxfp::linear", "cpu", lib=MXFP_TOSA_LIB) +def _mxfp_linear_cpu( + input: torch.Tensor, + weight_qdata: torch.Tensor, + weight_scale: torch.Tensor, + bias: torch.Tensor | None = None, + block_size: int = 32, +) -> torch.Tensor: + """CPU reference implementation of the MXFP linear op.""" + + if weight_qdata.ndim != 3 or weight_scale.ndim != 3: + raise ValueError("Expected rank-3 weight tensors for MXFP linear") + + # Cast the input to block-scaled format and back again to match the + # expected input format of the TOSA + dequantized_input = _cast_to_block_scaled_cpu_ref( + input, + weight_qdata.dtype, + block_size, + ) + dequantized_weight = to_dtype( + weight_qdata, + weight_scale, + weight_qdata.dtype, + block_size, + torch.float32, + ) + dequantized_weight = dequantized_weight.squeeze(0) + if bias is not None: + bias = bias.to(torch.float32) + return F.linear(dequantized_input, dequantized_weight, bias) + + +class MXFPLinearOp(torch.nn.Module): + """Linear wrapper that stores MXFP weights and emits a custom op.""" + + def __init__( + self, + weight_qdata: torch.Tensor, + weight_scale: torch.Tensor, + bias: torch.Tensor | None, + config: MXFPOpConfig, + ) -> None: + super().__init__() + self.config = config + + self.register_buffer("weight_qdata", weight_qdata, persistent=True) + self.register_buffer("weight_scale", weight_scale, persistent=True) + + self.bias: torch.nn.Parameter | None + bias_param = ( + torch.nn.Parameter(bias.detach(), requires_grad=False) + if bias is not None + else None + ) + self.register_parameter( + "bias", + bias_param, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.ops.tosa_mxfp.linear.default( + x, + self.weight_qdata, + self.weight_scale, + self.bias, + self.config.block_size, + ) + + +def transform_linear_to_mxfp( + module: torch.nn.Module, + config: MXFPOpConfig, +) -> torch.nn.Module: + assert isinstance(module, torch.nn.Linear) + + weight = module.weight.detach().contiguous() + if weight.shape[-1] % config.block_size != 0: + raise ValueError( + f"Linear in_features={weight.shape[-1]} must be divisible by " + f"block_size={config.block_size}" + ) + + weight_scale, weight_qdata = to_mx( + weight, + elem_dtype=config.weight_dtype, + block_size=config.block_size, + scaling_mode=config.weight_scaling_mode, + ) + + # The resulting TOSA op MATMUL_T_BLOCK_SCALED only works with tensors of + # rank 3, therefore we prepend a batch dimension of 1 to the weight tensors + # here. + weight_qdata = weight_qdata.unsqueeze(0) + weight_scale = weight_scale.unsqueeze(0) + + bias = module.bias.detach().to(torch.float32) if module.bias is not None else None + return MXFPLinearOp(weight_qdata, weight_scale, bias, config) diff --git a/backends/arm/operator_support/TARGETS b/backends/arm/operator_support/TARGETS index 8f6721bd911..a2fd054d472 100644 --- a/backends/arm/operator_support/TARGETS +++ b/backends/arm/operator_support/TARGETS @@ -6,6 +6,7 @@ runtime.python_library( deps = [ "//executorch/backends/arm:constants", "//executorch/backends/arm/_passes:passes", + "//executorch/backends/arm/tosa:resize_utils", "//executorch/backends/arm/tosa:tosa", "//executorch/backends/transforms:remove_getitem_op", "//executorch/backends/xnnpack/_passes:xnnpack_passes", diff --git a/backends/arm/operator_support/control_flow_support.py b/backends/arm/operator_support/control_flow_support.py index b34ebeaece0..f5251357cd3 100644 --- a/backends/arm/operator_support/control_flow_support.py +++ b/backends/arm/operator_support/control_flow_support.py @@ -19,6 +19,13 @@ from torch.fx.passes.operator_support import OperatorSupportBase +def _owning_graph_module(node: fx.Node) -> fx.GraphModule: + graph_module = getattr(node.graph, "owning_module", None) + if not isinstance(graph_module, fx.GraphModule): + raise RuntimeError(f"Could not resolve owning GraphModule for node {node}") + return graph_module + + def _fully_partitioned(submodule: fx.GraphModule) -> bool: """Check that all nested control-flow ops within this submodule are also fully partitioned. @@ -27,8 +34,8 @@ def _fully_partitioned(submodule: fx.GraphModule) -> bool: for submodule_node in submodule.graph.nodes: if submodule_node.target in ControlFlowOpSupported._targeted_ops: - if _submodules_fully_partitioned(submodule_node, submodule): - return True + if not _submodules_fully_partitioned(submodule_node, submodule): + return False if submodule_node.op != "call_function": continue @@ -56,13 +63,18 @@ def _fully_partitioned(submodule: fx.GraphModule) -> bool: return True -def _submodules_fully_partitioned(node: fx.Node, graph_module: fx.GraphModule) -> bool: +def _submodules_fully_partitioned( + node: fx.Node, graph_module: fx.GraphModule | None = None +) -> bool: """Returns whether the submodule arguments to a cond node were fully partitioned. Updates "val" meta of the submodules if they are. """ + if graph_module is None: + graph_module = _owning_graph_module(node) + match node.target: case torch.ops.higher_order.cond: submodule_args = node.args[1:3] @@ -129,9 +141,7 @@ def is_node_supported( node, f"Submodule had unsupported user {user}" ) return False - if not _submodules_fully_partitioned( - user, self.exported_program.graph_module - ): + if not _submodules_fully_partitioned(user): self.reporter.report_reject( node, "One submodule was not fully partitioned" ) @@ -174,9 +184,7 @@ def is_node_supported( ) return False - if not _submodules_fully_partitioned( - node, self.exported_program.graph_module - ): + if not _submodules_fully_partitioned(node): self.reporter.report_reject( node, "Submodule was not fully partitioned." ) diff --git a/backends/arm/operator_support/gather_support.py b/backends/arm/operator_support/gather_support.py index 651727cd8b6..6d923c0441c 100644 --- a/backends/arm/operator_support/gather_support.py +++ b/backends/arm/operator_support/gather_support.py @@ -49,7 +49,7 @@ class GatherSupported(SupportedTOSAOperatorCheck): targets = [exir_ops.edge.aten.gather.default] - def is_node_tosa_supported( + def is_node_tosa_supported( # noqa: C901 self, node: fx.Node, tosa_spec: TosaSpecification ) -> bool: # type: ignore[override, misc] if len(node.args) != 3: @@ -115,8 +115,14 @@ def is_node_tosa_supported( f"{node.target}: dtype {values_dtype} requires INT profile.", ) return False - # fp16/fp32/bf16: either FP profile, or INT profile (via quantization) - elif values_dtype in (torch.float16, torch.float32, torch.bfloat16): + # fp16/fp32/bf16/fp8: either FP profile, or INT profile (via quantization) + elif values_dtype in ( + torch.float16, + torch.float32, + torch.bfloat16, + torch.float8_e4m3fn, + torch.float8_e5m2, + ): if values_dtype == torch.bfloat16 and not tosa_spec.support_extension( "bf16" ): @@ -125,6 +131,22 @@ def is_node_tosa_supported( f"{node.target}: dtype {values_dtype} requires bf16 extension.", ) return False + if values_dtype == torch.float8_e4m3fn and not tosa_spec.support_extension( + "fp8e4m3" + ): + self.reporter.report_reject( + node, + f"{node.target}: dtype {values_dtype} requires fp8e4m3 extension.", + ) + return False + if values_dtype == torch.float8_e5m2 and not tosa_spec.support_extension( + "fp8e5m2" + ): + self.reporter.report_reject( + node, + f"{node.target}: dtype {values_dtype} requires fp8e5m2 extension.", + ) + return False if not (tosa_spec.support_float() or tosa_spec.support_integer()): self.reporter.report_reject( node, @@ -136,7 +158,8 @@ def is_node_tosa_supported( self.reporter.report_reject( node, f"{node.target}: unsupported values dtype {values_dtype}; " - "expected bool/int8/int16/int32/float16/bfloat16/float32.", + "expected bool/int8/int16/int32/float16/bfloat16/float32/" + "float8_e4m3fn/float8_e5m2.", ) return False diff --git a/backends/arm/operator_support/index_select_support.py b/backends/arm/operator_support/index_select_support.py index a3188e739c7..285b2cfe79f 100644 --- a/backends/arm/operator_support/index_select_support.py +++ b/backends/arm/operator_support/index_select_support.py @@ -77,8 +77,16 @@ def is_node_tosa_supported( f"{node.target}: dtype {values_dtype} requires INT profile.", ) return False - # fp16/fp32: either FP profile, or INT profile (via quantization) - elif values_dtype in (torch.float16, torch.float32): + # fp16/fp32/bf16: either FP profile, or INT profile (via quantization) + elif values_dtype in (torch.float16, torch.float32, torch.bfloat16): + if values_dtype == torch.bfloat16 and not tosa_spec.support_extension( + "bf16" + ): + self.reporter.report_reject( + node, + f"{node.target}: dtype {values_dtype} requires bf16 extension.", + ) + return False if not (tosa_spec.support_float() or tosa_spec.support_integer()): self.reporter.report_reject( node, @@ -90,7 +98,7 @@ def is_node_tosa_supported( self.reporter.report_reject( node, f"{node.target}: unsupported values dtype {values_dtype}; " - "expected bool/int8/int16/int32/float16/float32.", + "expected bool/int8/int16/int32/float16/bfloat16/float32.", ) return False diff --git a/backends/arm/operator_support/slice_copy_support.py b/backends/arm/operator_support/slice_copy_support.py index bcc3ddfbbbb..c9ef4a85bdf 100644 --- a/backends/arm/operator_support/slice_copy_support.py +++ b/backends/arm/operator_support/slice_copy_support.py @@ -53,7 +53,13 @@ def is_node_tosa_supported( values_dtype = node.args[0].meta["val"].dtype # type: ignore[union-attr] SUPPORTED_INT_DTYPES = (torch.int8, torch.int16, torch.int32) - SUPPORTED_FLOAT_DTYPES = (torch.float16, torch.float32, torch.bfloat16) + SUPPORTED_FLOAT_DTYPES = ( + torch.float16, + torch.float32, + torch.bfloat16, + torch.float8_e4m3fn, + torch.float8_e5m2, + ) SUPPORTED_DTYPES = (torch.bool,) + SUPPORTED_INT_DTYPES + SUPPORTED_FLOAT_DTYPES # bool is supported in both INT and FP profiles @@ -68,7 +74,7 @@ def is_node_tosa_supported( ) return False - # fp16/fp32/bf16: either FP profile, or INT profile (via quantization) + # fp16/fp32/bf16/fp8: either FP profile, or INT profile (via quantization) elif values_dtype in SUPPORTED_FLOAT_DTYPES: if values_dtype == torch.bfloat16 and not tosa_spec.support_extension( "bf16" @@ -78,6 +84,22 @@ def is_node_tosa_supported( f"{node.target}: dtype {values_dtype} requires bf16 extension.", ) return False + if values_dtype == torch.float8_e4m3fn and not tosa_spec.support_extension( + "fp8e4m3" + ): + self.reporter.report_reject( + node, + f"{node.target}: dtype {values_dtype} requires fp8e4m3 extension.", + ) + return False + if values_dtype == torch.float8_e5m2 and not tosa_spec.support_extension( + "fp8e5m2" + ): + self.reporter.report_reject( + node, + f"{node.target}: dtype {values_dtype} requires fp8e5m2 extension.", + ) + return False if not (tosa_spec.support_float() or tosa_spec.support_integer()): self.reporter.report_reject( node, diff --git a/backends/arm/operator_support/unfold_copy_support.py b/backends/arm/operator_support/unfold_copy_support.py index bf6c1cad22e..ac9fc7d0ee3 100644 --- a/backends/arm/operator_support/unfold_copy_support.py +++ b/backends/arm/operator_support/unfold_copy_support.py @@ -84,8 +84,16 @@ def is_node_tosa_supported( f"{node.target}: dtype {values_dtype} requires INT profile.", ) return False - # fp16/fp32: either FP profile, or INT profile (via quantization) - elif values_dtype in (torch.float16, torch.float32): + # fp16/fp32/bf16: either FP profile, or INT profile (via quantization) + elif values_dtype in (torch.float16, torch.float32, torch.bfloat16): + if values_dtype == torch.bfloat16 and not tosa_spec.support_extension( + "bf16" + ): + self.reporter.report_reject( + node, + f"{node.target}: dtype {values_dtype} requires bf16 extension.", + ) + return False if not (tosa_spec.support_float() or tosa_spec.support_integer()): self.reporter.report_reject( node, @@ -97,7 +105,7 @@ def is_node_tosa_supported( self.reporter.report_reject( node, f"{node.target}: unsupported values dtype {values_dtype}; " - "expected bool/int8/int16/int32/float16/float32.", + "expected bool/int8/int16/int32/float16/bfloat16/float32.", ) return False diff --git a/backends/arm/operator_support/upsample_support.py b/backends/arm/operator_support/upsample_support.py index bd03a4d2b4f..42e88f08521 100644 --- a/backends/arm/operator_support/upsample_support.py +++ b/backends/arm/operator_support/upsample_support.py @@ -13,9 +13,53 @@ SupportedTOSAOperatorCheck, ) from executorch.backends.arm.tosa import TosaSpecification +from executorch.backends.arm.tosa.resize_utils import get_tosa_resize_validation_error from executorch.exir.dialects._ops import ops as exir_ops +def _is_upsample_node_tosa_supported( + support_check: SupportedTOSAOperatorCheck, + node: fx.Node, + tosa_spec: TosaSpecification, + *, + align_corners: bool, +) -> bool: + input_node = ensure_type(fx.Node, node.args[0]) + input_size_yx = get_first_fake_tensor(input_node).shape[2:] + output_size_yx = get_first_fake_tensor(node).shape[2:] + + try: + scale_y_n, scale_y_d, offset_y, border_y = ( + RewriteUpsamplePass.get_resize_parameters_1d( + input_size_yx[0], output_size_yx[0], align_corners + ) + ) + scale_x_n, scale_x_d, offset_x, border_x = ( + RewriteUpsamplePass.get_resize_parameters_1d( + input_size_yx[1], output_size_yx[1], align_corners + ) + ) + except RuntimeError as err: + support_check.reporter.report_reject(node, str(err)) + return False + + # Validate the exact TOSA RESIZE parameters that RewriteUpsamplePass will + # emit so support checks and fake-op validation reject the same cases. + validation_error = get_tosa_resize_validation_error( + input_hw=input_size_yx, + output_hw=output_size_yx, + scale=[scale_y_n, scale_y_d, scale_x_n, scale_x_d], + offset=[offset_y, offset_x], + border=[border_y, border_x], + tosa_spec=tosa_spec, + ) + if validation_error is not None: + support_check.reporter.report_reject(node, validation_error) + return False + + return True + + @register_tosa_support_check class UpsampleNearest2dSupported(SupportedTOSAOperatorCheck): """Provide the explicit TOSA support gate for nearest upsample.""" @@ -23,9 +67,11 @@ class UpsampleNearest2dSupported(SupportedTOSAOperatorCheck): targets = [exir_ops.edge.aten.upsample_nearest2d.vec] def is_node_tosa_supported( - self, _node: fx.Node, _tosa_spec: TosaSpecification + self, node: fx.Node, tosa_spec: TosaSpecification ) -> bool: # type: ignore[override, misc] - return True + return _is_upsample_node_tosa_supported( + self, node, tosa_spec, align_corners=False + ) @register_tosa_support_check @@ -37,33 +83,9 @@ class UpsampleBilinear2dSupported(SupportedTOSAOperatorCheck): targets = [exir_ops.edge.aten.upsample_bilinear2d.vec] def is_node_tosa_supported( - self, node: fx.Node, _tosa_spec: TosaSpecification + self, node: fx.Node, tosa_spec: TosaSpecification ) -> bool: # type: ignore[override, misc] - input_node = ensure_type(fx.Node, node.args[0]) align_corners = ensure_type(bool, node.args[2]) - input_size_yx = get_first_fake_tensor(input_node).shape[2:] - output_size_yx = get_first_fake_tensor(node).shape[2:] - - try: - scale_y_n, scale_y_d, _, _ = RewriteUpsamplePass.get_resize_parameters_1d( - input_size_yx[0], output_size_yx[0], align_corners - ) - scale_x_n, scale_x_d, _, _ = RewriteUpsamplePass.get_resize_parameters_1d( - input_size_yx[1], output_size_yx[1], align_corners - ) - except RuntimeError as err: - self.reporter.report_reject(node, str(err)) - return False - - # get_resize_parameters_1d() returns the TOSA RESIZE scale fraction for - # each spatial dimension. For align_corners=False, this is the effective - # output_size / input_size ratio, so the 1/16 boundary is checked - # directly in the same representation that RESIZE lowering will use. - if scale_y_d >= 16 * scale_y_n or scale_x_d >= 16 * scale_x_n: - self.reporter.report_reject( - node, - "Bilinear RESIZE downscale must be strictly greater than 1/16", - ) - return False - - return True + return _is_upsample_node_tosa_supported( + self, node, tosa_spec, align_corners=align_corners + ) diff --git a/backends/arm/operators/op_cat.py b/backends/arm/operators/op_cat.py index 544beefadf9..97ea651cb12 100644 --- a/backends/arm/operators/op_cat.py +++ b/backends/arm/operators/op_cat.py @@ -44,6 +44,10 @@ def define_node( supported_dtypes.extend([ts.DType.FP16, ts.DType.FP32]) if self.tosa_spec.support_extension("bf16"): supported_dtypes.append(ts.DType.BF16) + if self.tosa_spec.support_extension("fp8e4m3"): + supported_dtypes.append(ts.DType.FP8E4M3) + if self.tosa_spec.support_extension("fp8e5m2"): + supported_dtypes.append(ts.DType.FP8E5M2) validate_num_inputs(self.target, inputs, [1, 2]) input_tosa_args = [TosaArg(arg, self.tosa_spec) for arg in inputs[0].special] validate_same_dtype(self.target, [*input_tosa_args, output], ts) diff --git a/backends/arm/operators/op_cond_if.py b/backends/arm/operators/op_cond_if.py index 05d38e2a1f0..513100c2b15 100644 --- a/backends/arm/operators/op_cond_if.py +++ b/backends/arm/operators/op_cond_if.py @@ -17,7 +17,11 @@ validate_num_inputs, validate_valid_dtype, ) -from executorch.backends.arm.tosa.mapping import TosaArg # type: ignore +from executorch.backends.arm.tosa.mapping import ( # type: ignore + TOSA_CONTROL_FLOW_REGION_NAME_META, + TOSA_TENSOR_NAME_META, + TosaArg, +) from torch.fx import Node @@ -38,7 +42,12 @@ def define_node( validate_cf_extension(self.target, self.tosa_spec) attr = ts.TosaSerializerAttribute() - if_graph, else_graph = (cast(Node, arg).target for arg in node.args[1:3]) + if_graph, else_graph = ( + cast(Node, arg).meta.get( + TOSA_CONTROL_FLOW_REGION_NAME_META, str(cast(Node, arg).target) + ) + for arg in node.args[1:3] + ) attr.CondIfAttribute(if_graph, else_graph) self._serialize_operator( @@ -47,7 +56,11 @@ def define_node( ts.Op.COND_IF, [ inputs[0].name, - *(subgraph_input.name for subgraph_input in inputs[-1].special), + *( + subgraph_input.name + + subgraph_input.meta.get(TOSA_TENSOR_NAME_META, "") + for subgraph_input in inputs[-1].special + ), ], output.multiple_output_names, attr, diff --git a/backends/arm/operators/op_permute.py b/backends/arm/operators/op_permute.py index e200478d7b3..2418131af3e 100644 --- a/backends/arm/operators/op_permute.py +++ b/backends/arm/operators/op_permute.py @@ -43,6 +43,10 @@ def define_node( supported_dtypes.extend([ts.DType.FP16, ts.DType.FP32]) if self.tosa_spec.support_extension("bf16"): supported_dtypes.append(ts.DType.BF16) + if self.tosa_spec.support_extension("fp8e4m3"): + supported_dtypes.append(ts.DType.FP8E4M3) + if self.tosa_spec.support_extension("fp8e5m2"): + supported_dtypes.append(ts.DType.FP8E5M2) validate_num_inputs(self.target, inputs, 2) validate_same_dtype(self.target, [inputs[0], output], ts) diff --git a/backends/arm/operators/op_repeat.py b/backends/arm/operators/op_repeat.py index 9b95c902847..f990dbef64b 100644 --- a/backends/arm/operators/op_repeat.py +++ b/backends/arm/operators/op_repeat.py @@ -42,6 +42,10 @@ def define_node( supported_dtypes.extend([ts.DType.FP16, ts.DType.FP32]) if self.tosa_spec.support_extension("bf16"): supported_dtypes.append(ts.DType.BF16) + if self.tosa_spec.support_extension("fp8e4m3"): + supported_dtypes.append(ts.DType.FP8E4M3) + if self.tosa_spec.support_extension("fp8e5m2"): + supported_dtypes.append(ts.DType.FP8E5M2) validate_num_inputs(self.target, inputs, 2) validate_same_dtype(self.target, [inputs[0], output], ts) diff --git a/backends/arm/operators/op_tosa_gather.py b/backends/arm/operators/op_tosa_gather.py index c242d351c06..913e2cc02b3 100644 --- a/backends/arm/operators/op_tosa_gather.py +++ b/backends/arm/operators/op_tosa_gather.py @@ -63,6 +63,16 @@ def define_node( ts.DType.FP16, ts.DType.FP32, ts.DType.BF16, + *( + [ts.DType.FP8E4M3] + if self.tosa_spec.support_extension("fp8e4m3") + else [] + ), + *( + [ts.DType.FP8E5M2] + if self.tosa_spec.support_extension("fp8e5m2") + else [] + ), ], self.tosa_spec, ) diff --git a/backends/arm/operators/op_tosa_pad.py b/backends/arm/operators/op_tosa_pad.py index 6f1cd488469..6e93adde55b 100644 --- a/backends/arm/operators/op_tosa_pad.py +++ b/backends/arm/operators/op_tosa_pad.py @@ -41,6 +41,10 @@ def define_node( supported_dtypes.extend([ts.DType.FP16, ts.DType.FP32]) if self.tosa_spec.support_extension("bf16"): supported_dtypes.append(ts.DType.BF16) + if self.tosa_spec.support_extension("fp8e4m3"): + supported_dtypes.append(ts.DType.FP8E4M3) + if self.tosa_spec.support_extension("fp8e5m2"): + supported_dtypes.append(ts.DType.FP8E5M2) validate_num_inputs(self.target, inputs, 2) validate_same_dtype(self.target, [inputs[0], output], ts) @@ -50,7 +54,6 @@ def define_node( supported_dtypes, self.tosa_spec, ) - pad_const = tosa_graph.addConst( [1], output.dtype, diff --git a/backends/arm/operators/op_tosa_scatter.py b/backends/arm/operators/op_tosa_scatter.py index b87a2598993..63c44f91fac 100644 --- a/backends/arm/operators/op_tosa_scatter.py +++ b/backends/arm/operators/op_tosa_scatter.py @@ -36,7 +36,13 @@ def define_node( validate_same_dtype(self.target, [inputs[0], inputs[2], output], ts) validate_valid_dtype( self.target, - [inputs[0], inputs[1], inputs[2], output], + [inputs[1]], + [ts.DType.INT32], + self.tosa_spec, + ) + validate_valid_dtype( + self.target, + [inputs[0], inputs[2], output], [ ts.DType.INT8, ts.DType.INT16, @@ -44,6 +50,16 @@ def define_node( ts.DType.FP32, ts.DType.FP16, ts.DType.BF16, + *( + [ts.DType.FP8E4M3] + if self.tosa_spec.support_extension("fp8e4m3") + else [] + ), + *( + [ts.DType.FP8E5M2] + if self.tosa_spec.support_extension("fp8e5m2") + else [] + ), ], self.tosa_spec, ) diff --git a/backends/arm/operators/op_tosa_slice.py b/backends/arm/operators/op_tosa_slice.py index 11ce95df466..818657642a8 100644 --- a/backends/arm/operators/op_tosa_slice.py +++ b/backends/arm/operators/op_tosa_slice.py @@ -42,6 +42,10 @@ def define_node( supported_dtypes.extend([ts.DType.FP16, ts.DType.FP32]) if self.tosa_spec.support_extension("bf16"): supported_dtypes.append(ts.DType.BF16) + if self.tosa_spec.support_extension("fp8e4m3"): + supported_dtypes.append(ts.DType.FP8E4M3) + if self.tosa_spec.support_extension("fp8e5m2"): + supported_dtypes.append(ts.DType.FP8E5M2) validate_num_inputs(self.target, inputs, 3) validate_same_dtype(self.target, [inputs[0], output], ts) diff --git a/backends/arm/operators/op_view.py b/backends/arm/operators/op_view.py index 94ed23e2446..6d399b65801 100644 --- a/backends/arm/operators/op_view.py +++ b/backends/arm/operators/op_view.py @@ -35,20 +35,26 @@ def define_node( inputs: List[TosaArg], output: TosaArg, ) -> None: - supported_dtypes = [ts.DType.BOOL] + supported_dtypes = {ts.DType.BOOL} if self.tosa_spec.support_integer(): - supported_dtypes.extend([ts.DType.INT8, ts.DType.INT16, ts.DType.INT32]) + supported_dtypes.update([ts.DType.INT8, ts.DType.INT16, ts.DType.INT32]) if self.tosa_spec.support_float(): - supported_dtypes.extend([ts.DType.FP16, ts.DType.FP32]) + supported_dtypes.update([ts.DType.FP16, ts.DType.FP32]) if self.tosa_spec.support_extension("bf16"): - supported_dtypes.append(ts.DType.BF16) + supported_dtypes.add(ts.DType.BF16) + if self.tosa_spec.support_extension("fp8e4m3"): + supported_dtypes.add(ts.DType.FP8E4M3) + if self.tosa_spec.support_extension("fp8e5m2"): + supported_dtypes.add(ts.DType.FP8E5M2) + if self.tosa_spec.support_extension("mxfp"): + supported_dtypes.update([ts.DType.FP8E4M3, ts.DType.FP8E5M2]) validate_num_inputs(self.target, inputs, 2) validate_same_dtype(self.target, [inputs[0], output], ts) validate_valid_dtype( self.target, [inputs[0], output], - supported_dtypes, + list(supported_dtypes), self.tosa_spec, ) diff --git a/backends/arm/operators/op_while.py b/backends/arm/operators/op_while.py index 2b6314d3454..58501dd3ba0 100644 --- a/backends/arm/operators/op_while.py +++ b/backends/arm/operators/op_while.py @@ -15,8 +15,14 @@ validate_cf_extension, validate_num_inputs, ) -from executorch.backends.arm.tosa.mapping import map_dtype, TosaArg +from executorch.backends.arm.tosa.mapping import ( + map_dtype, + TOSA_CONTROL_FLOW_REGION_NAME_META, + TOSA_TENSOR_NAME_META, + TosaArg, +) from executorch.backends.arm.tosa.utils import normalize_symint + from torch.fx import Node @@ -46,7 +52,12 @@ def define_node( ) attr = ts.TosaSerializerAttribute() - cond_graph, body_graph = (str(cast(Node, arg).target) for arg in node.args[:2]) + cond_graph, body_graph = ( + cast(Node, arg).meta.get( + TOSA_CONTROL_FLOW_REGION_NAME_META, str(cast(Node, arg).target) + ) + for arg in node.args[:2] + ) attr.WhileLoopAttribute(cond_graph, body_graph) input_names: list[str] = [] @@ -55,7 +66,9 @@ def define_node( raise ValueError( f"{self.target}: Unsupported carried input type {type(loop_input)}." ) - input_names.append(loop_input.name) + input_names.append( + loop_input.name + loop_input.meta.get(TOSA_TENSOR_NAME_META, "") + ) num_inputs = len(input_names) num_outputs = len(output.multiple_output_names) diff --git a/backends/arm/quantizer/arm_quantizer.py b/backends/arm/quantizer/arm_quantizer.py index f1dfb5f1323..3508410509c 100644 --- a/backends/arm/quantizer/arm_quantizer.py +++ b/backends/arm/quantizer/arm_quantizer.py @@ -40,6 +40,10 @@ from executorch.backends.cortex_m.quantizer.pattern_matcher import PatternMatcher from executorch.backends.cortex_m.quantizer_reporter import QuantizerReporter +from executorch.exir.graph_module import ( + _get_control_flow_submodules, + get_cond_while_submodules, +) from torch._ops import OpOverload @@ -52,10 +56,6 @@ from executorch.backends.arm.common.arm_compile_spec import ( ArmCompileSpec, ) # isort: skip -from executorch.backends.arm._passes.arm_pass_utils import ( - get_cond_while_submodules_nested, - is_submodule_node, -) from executorch.backends.arm.quantizer.arm_quantizer_utils import ( _get_int32_bias_qspec, @@ -107,6 +107,29 @@ logger = logging.getLogger(__name__) +def get_cond_while_submodules_ao( + graph_module: GraphModule, + apply_quantization: bool = False, +) -> list[tuple[str, GraphModule, Node]]: + """Return cond/while submodules for the current graph module. + + Quantization handles ``while_loop`` body functions natively in torchao, so + only the ``while_loop`` cond function is processed explicitly there. + + """ + + if not apply_quantization: + return get_cond_while_submodules(graph_module) + + return _get_control_flow_submodules( + graph_module, + { + torch.ops.higher_order.cond: [1, 2], + torch.ops.higher_order.while_loop: [0], + }, + ) + + @functools.lru_cache def get_symmetric_quantization_config( is_per_channel: bool = True, @@ -810,42 +833,56 @@ def _quantize_with_submodules( prepare_fn = prepare_qat_pt2e if is_qat else prepare_pt2e prepared = prepare_fn(model, self) - # Prepare conditional submodules (e.g., if/while bodies) - # prepare only cond branches and while_loop cond_fn - for name, submodule, _ in get_cond_while_submodules_nested( - prepared, apply_quantization=True - ): - prepared.set_submodule(name, prepare_fn(submodule, self), strict=True) - for submodule_node in submodule.graph.nodes: - if is_submodule_node(submodule_node): - for nested_name, nested_sub, _ in get_cond_while_submodules_nested( - submodule, apply_quantization=True - ): - prepared.set_submodule( - nested_name, prepare_fn(nested_sub, self), strict=True - ) + + def _prepare_control_flow_submodules( + source_graph_module: GraphModule, prefix: str = "" + ) -> None: + for name, submodule, _ in get_cond_while_submodules_ao( + source_graph_module, apply_quantization=True + ): + qualified_name = f"{prefix}.{name}" if prefix else name + prepared.set_submodule( + qualified_name, prepare_fn(submodule, self), strict=True + ) + _prepare_control_flow_submodules(submodule, qualified_name) + + _prepare_control_flow_submodules(prepared) for inp in calibration_samples: prepared(*inp) - # Prepare conditional submodules (e.g., if/while bodies) - # convert only cond branches and while_loop cond_fn - for _, submodule, _ in get_cond_while_submodules_nested( - prepared, apply_quantization=True + def _convert_control_flow_submodule( + graph_module: GraphModule, + ) -> GraphModule: + converted_submodules: list[tuple[str, GraphModule]] = [] + for name, submodule, _ in get_cond_while_submodules_ao( + graph_module, apply_quantization=True + ): + converted_submodules.append( + (name, _convert_control_flow_submodule(submodule)) + ) + converted_graph_module = convert_pt2e( + graph_module, fold_quantize=fold_quantize + ) + for name, converted_submodule in converted_submodules: + converted_graph_module.set_submodule( + name, converted_submodule, strict=True + ) + return converted_graph_module + + converted_top_level_submodules: list[tuple[str, GraphModule]] = [] + for name, submodule, _ in list( + get_cond_while_submodules_ao(prepared, apply_quantization=True) ): - converted = convert_pt2e(submodule, fold_quantize=fold_quantize) - for submodule_node in submodule.graph.nodes: - if is_submodule_node(submodule_node): - for nested_name, nested_sub, _ in get_cond_while_submodules_nested( - submodule, apply_quantization=True - ): - converted.set_submodule( - nested_name, - convert_pt2e(nested_sub, fold_quantize=fold_quantize), - strict=True, - ) + converted_top_level_submodules.append( + (name, _convert_control_flow_submodule(submodule)) + ) + + converted = convert_pt2e(prepared, fold_quantize=fold_quantize) + for name, converted_submodule in converted_top_level_submodules: + converted.set_submodule(name, converted_submodule, strict=True) - return convert_pt2e(prepared, fold_quantize=fold_quantize) + return converted class _TOSAQuantizerV1(Quantizer): diff --git a/backends/arm/runtime/VGFSetup.cpp b/backends/arm/runtime/VGFSetup.cpp index b62a6b2ec23..307d0ab266e 100644 --- a/backends/arm/runtime/VGFSetup.cpp +++ b/backends/arm/runtime/VGFSetup.cpp @@ -793,9 +793,14 @@ bool VgfRepr::process_vgf( return false; } - vector - bind_point_requirements; - bind_point_requirements.resize(bind_point_count); + vector bind_point_requirements( + bind_point_count, + { + .sType = + VK_STRUCTURE_TYPE_DATA_GRAPH_PIPELINE_SESSION_BIND_POINT_REQUIREMENT_ARM, + .pNext = nullptr, + }); + result = vkGetDataGraphPipelineSessionBindPointRequirementsARM( vk_device, &bind_point_requirements_info, diff --git a/backends/arm/scripts/build_executorch.sh b/backends/arm/scripts/build_executorch.sh index 54d2091d1f4..362fc4d40bf 100755 --- a/backends/arm/scripts/build_executorch.sh +++ b/backends/arm/scripts/build_executorch.sh @@ -7,6 +7,7 @@ # Optional parameter: # --build_type= "Release" | "Debug" | "RelWithDebInfo" | "UndefinedSanitizer" | "AddressSanitizer" # --etdump build with devtools-etdump support +# --cmake-args= Additional arguments passed to cmake configure set -eu @@ -24,6 +25,7 @@ build_type="Release" build_devtools=OFF build_with_etdump=OFF is_linux_musl=0 +extra_cmake_args=() target_cpu="" help() { @@ -33,6 +35,7 @@ help() { echo " --build_type= Build with Release, Debug, RelWithDebInfo, UndefinedSanitizer or AddressSanitizer, default is ${build_type}" echo " --devtools Build Devtools libs" echo " --etdump Adds Devtools etdump support to track timing, etdump area will be base64 encoded in the log" + echo " --cmake-args= Additional arguments passed to cmake configure" echo " --toolchain= Toolchain can be specified (arm-none-eabi-gcc, arm-zephyr-eabi-gcc, aarch64-linux-musl-gcc). Default: ${toolchain}" echo " --target_cpu= Override the toolchain's default TARGET_CPU (e.g. cortex-m4). Switching target_cpu reuses the same cmake-out dir, so clear ${et_build_root}/cmake-out first to avoid stale per-CPU artifacts. Default: unset (toolchain default)." exit 0 @@ -45,6 +48,10 @@ for arg in "$@"; do --build_type=*) build_type="${arg#*=}";; --devtools) build_devtools=ON ;; --etdump) build_with_etdump=ON ;; + --cmake-args=*) + # shellcheck disable=SC2206 + extra_cmake_args=(${arg#*=}) + ;; --toolchain=*) toolchain="${arg#*=}";; --target_cpu=*) target_cpu="${arg#*=}";; *) @@ -90,6 +97,13 @@ cmake_args=( -DEXECUTORCH_BUILD_ARM_ETDUMP=${build_with_etdump} -DEXECUTORCH_BAREMETAL_SKIP_INSTALL=OFF ) +if ((${#extra_cmake_args[@]})); then + cmake_args+=("${extra_cmake_args[@]}") +fi + +if [[ ${#extra_cmake_args[@]} -gt 0 ]]; then + cmake_args+=("${extra_cmake_args[@]}") +fi if [[ -n "${target_cpu}" ]]; then cmake_args+=(-DTARGET_CPU=${target_cpu}) diff --git a/backends/arm/scripts/corstone_utils.cmake b/backends/arm/scripts/corstone_utils.cmake index 58ce4f9a919..0ed1e4aea0f 100644 --- a/backends/arm/scripts/corstone_utils.cmake +++ b/backends/arm/scripts/corstone_utils.cmake @@ -8,6 +8,7 @@ function(fetch_ethos_u_content ETHOS_SDK_PATH ET_DIR_PATH) file(MAKE_DIRECTORY ${ETHOS_SDK_PATH}/../ethos_u) include(FetchContent) + find_package(Python3 REQUIRED COMPONENTS Interpreter) set(ethos_u_base_tag "26.02") FetchContent_Declare( ethos_u @@ -33,10 +34,13 @@ function(fetch_ethos_u_content ETHOS_SDK_PATH ET_DIR_PATH) "source backends/arm/scripts/utils.sh && patch_repo ${ETHOS_SDK_PATH} ${ethos_u_base_rev} ${patch_dir}" WORKING_DIRECTORY ${ET_DIR_PATH} ) - # Get ethos_u externals only if core_platform folder does not already exist. - if(NOT EXISTS "${ETHOS_SDK_PATH}/core_platform") + + # Get ethos_u externals only if core driver headers do not already exist. + if(NOT EXISTS + "${ETHOS_SDK_PATH}/core_software/core_driver/include/ethosu_driver.h" + ) execute_process( - COMMAND ${PYTHON_EXECUTABLE} fetch_externals.py -c + COMMAND ${Python3_EXECUTABLE} fetch_externals.py -c ${ethos_u_base_tag}.json fetch WORKING_DIRECTORY ${ETHOS_SDK_PATH} ) @@ -50,11 +54,12 @@ function(fetch_ethos_u_content ETHOS_SDK_PATH ET_DIR_PATH) WORKING_DIRECTORY ${ET_DIR_PATH} ) # Always patch the core_platform repo since this is fast enough. TODO: - # examples/arm/ethos-u-setup/core_platform/0002-*.patch is a transient bridge - # that guards Armv8-M-only MPU init so the source compiles for non-Armv8-M - # Cortex-M cores. Once the same guard lands upstream in ethos-u/core_platform - # and ${core_platform_base_rev} is bumped past that commit, delete the 0002 - # patch. + # examples/arm/ethos-u-setup/core_platform/0002-*.patch and 0003-*.patch are + # transient bridges that guard Armv8-M-only MPU init and the Armv7-M-and-newer + # HardFault handler so the Corstone-300 target source compiles for older + # Cortex-M cores. Once the equivalent guards land upstream in + # ethos-u/core_platform and ${core_platform_base_rev} is bumped past those + # commits, delete the 0002 and 0003 patches. set(core_platform_base_rev "26.02") execute_process( COMMAND diff --git a/backends/arm/scripts/pre-push b/backends/arm/scripts/pre-push index 8e26463cd94..6aa32d07286 100755 --- a/backends/arm/scripts/pre-push +++ b/backends/arm/scripts/pre-push @@ -177,7 +177,7 @@ for COMMIT in ${COMMITS}; do for committed_file in "${license_files[@]}"; do # Skip files with certain extensions case "$committed_file" in - *.md|*.md.in|*.json|*.yml|*.yaml|*.cmake|*.patch|.gitignore|*.bzl) + *.md|*.md.in|*.json|*.yml|*.yaml|*.cmake|*.patch|.gitignore|*.bzl|BUCK|*/BUCK|TARGETS|*/TARGETS) echo -e "${INFO} Skipping license check for ${committed_file} (excluded extension)" continue ;; diff --git a/backends/arm/scripts/vulkan_utils.sh b/backends/arm/scripts/vulkan_utils.sh index c8b169c0c3d..520c244c6fb 100644 --- a/backends/arm/scripts/vulkan_utils.sh +++ b/backends/arm/scripts/vulkan_utils.sh @@ -71,6 +71,9 @@ function install_vulkan_sdk_macos() { fi log_step "vulkan" "Extracting Vulkan SDK installer" + rm -rf \ + "vulkansdk-macOS-${vulkan_sdk_version}.app" \ + "vulkansdk-macos-${vulkan_sdk_version}.app" unzip -q -o "${vulkan_sdk_zip_file}" local vulkan_sdk_app_path="" @@ -91,15 +94,33 @@ function install_vulkan_sdk_macos() { local install_root="$(cd "${root_dir}" && pwd)/${vulkan_sdk_base_dir}/${vulkan_sdk_version}" mkdir -p "${install_root}" - local vulkan_sdk_root="${root_dir}/${vulkan_sdk_base_dir}" log_step "vulkan" "Installing Vulkan SDK (${vulkan_sdk_version}) to ${install_root}" - ${vulkan_sdk_installer} --root "${install_root}" --accept-licenses --default-answer --confirm-command install + "${vulkan_sdk_installer}" --root "${install_root}" --accept-licenses --default-answer --confirm-command install +} + +function validate_vulkan_sdk_installation() { + if [[ ! -d "${root_dir}/${vulkan_sdk_bin_dir}" ]]; then + return 1 + fi + + vulkan_sdk_bin_path="$(cd "${root_dir}/${vulkan_sdk_bin_dir}" && pwd)" + if [[ ! -x "${vulkan_sdk_bin_path}/glslc" ]]; then + return 1 + fi + + "${vulkan_sdk_bin_path}/glslc" --version > /dev/null 2>&1 } function setup_vulkan_sdk() { cd "${root_dir}" + if validate_vulkan_sdk_installation; then + log_step "vulkan" "Reusing Vulkan SDK at ${root_dir}/${vulkan_sdk_base_dir}/${vulkan_sdk_version}" + log_step "vulkan" "Vulkan SDK validation (glslc) succeeded" + return + fi + if [[ "${os_name}" == "Darwin" ]]; then install_vulkan_sdk_macos else @@ -117,11 +138,11 @@ function setup_vulkan_sdk() { exit 1 fi - if ${vulkan_sdk_bin_path}/glslc --version > /dev/null 2>&1; then + if "${vulkan_sdk_bin_path}/glslc" --version > /dev/null 2>&1; then log_step "vulkan" "Vulkan SDK validation (glslc) succeeded" else log_step "vulkan" "Error: Vulkan SDK validation failed" - ${vulkan_sdk_bin_path}/glslc --version + "${vulkan_sdk_bin_path}/glslc" --version exit 1 fi } @@ -143,7 +164,7 @@ function setup_path_vulkan() { vulkan_sdk_arch_root="$(cd "${vulkan_sdk_arch_root}" && pwd)" vulkan_sdk_bin_path="$(cd "${vulkan_sdk_bin_dir}" && pwd)" - append_env_in_setup_path PATH ${vulkan_sdk_bin_path} + append_env_in_setup_path PATH "${vulkan_sdk_bin_path}" if [[ "${OS:-}" == "Darwin" ]]; then prepend_env_in_setup_path DYLD_LIBRARY_PATH "${vulkan_sdk_arch_root}/lib" local moltenvk_icd_path="${vulkan_sdk_arch_root}/share/vulkan/icd.d/MoltenVK_icd.json" diff --git a/backends/arm/test/misc/test_mxfp_linear_ao.py b/backends/arm/test/misc/test_mxfp_linear_ao.py new file mode 100644 index 00000000000..0f2b6b9198c --- /dev/null +++ b/backends/arm/test/misc/test_mxfp_linear_ao.py @@ -0,0 +1,46 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from executorch.backends.arm.ao_ext import MXFPOpConfig, to_mxfp +from executorch.backends.arm.ao_ext.ops import MXFPLinearOp + +from torch.export import export + + +class LinearModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(32, 8, bias=True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(x) + + +def test_mxfp_linear_quantize_swaps_module() -> None: + model = LinearModule().eval() + + to_mxfp(model, MXFPOpConfig()) + + assert isinstance(model.linear, MXFPLinearOp) + assert model.linear.weight_qdata.dtype == torch.float8_e4m3fn + assert model.linear.weight_scale.dtype == torch.float8_e8m0fnu + assert tuple(model.linear.weight_qdata.shape) == (1, 8, 32) + assert tuple(model.linear.weight_scale.shape) == (1, 8, 1) + + +def test_mxfp_linear_export_preserves_custom_op() -> None: + model = LinearModule().eval() + to_mxfp(model, MXFPOpConfig()) + + exported = export(model, (torch.randn(4, 32),), strict=False) + + targets = [ + node.target + for node in exported.graph_module.graph.nodes + if node.op == "call_function" + ] + + assert torch.ops.tosa_mxfp.linear.default in targets diff --git a/backends/arm/test/misc/test_tosa_dialect_scatter.py b/backends/arm/test/misc/test_tosa_dialect_scatter.py new file mode 100644 index 00000000000..dc75df60df9 --- /dev/null +++ b/backends/arm/test/misc/test_tosa_dialect_scatter.py @@ -0,0 +1,38 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import executorch.backends.arm.tosa.dialect # noqa: F401 +import pytest +import torch +from executorch.backends.arm.tosa.specification import ( + TosaLoweringContext, + TosaSpecification, +) +from executorch.exir.dialects._ops import ops as exir_ops +from torch._subclasses.fake_tensor import FakeTensorMode + + +@pytest.mark.parametrize( + "dtype, extension", + [ + (torch.float8_e4m3fn, "fp8e4m3"), + (torch.float8_e5m2, "fp8e5m2"), + ], +) +def test_scatter_tosa_FP_fp8(dtype: torch.dtype, extension: str): + with TosaLoweringContext( + TosaSpecification.create_from_string(f"TOSA-1.0+FP+{extension}") + ), FakeTensorMode() as mode: + values_in = mode.from_tensor( + torch.rand((1, 5, 3), dtype=torch.float32).to(dtype) + ) + indices = mode.from_tensor(torch.tensor([[1, 3]], dtype=torch.int32)) + input_tensor = mode.from_tensor( + torch.rand((1, 2, 3), dtype=torch.float32).to(dtype) + ) + output = exir_ops.backend.tosa.SCATTER.default(values_in, indices, input_tensor) + + assert output.dtype == dtype + assert tuple(output.shape) == (1, 5, 3) diff --git a/backends/arm/test/misc/tosa_dialect/test_tosa_resize.py b/backends/arm/test/misc/tosa_dialect/test_tosa_resize.py index d9d8b89feb6..eddb69a8caf 100644 --- a/backends/arm/test/misc/tosa_dialect/test_tosa_resize.py +++ b/backends/arm/test/misc/tosa_dialect/test_tosa_resize.py @@ -33,13 +33,14 @@ def _expr(sym: torch.SymInt) -> sympy.Expr: return sympy.sympify(getattr(sym.node, "expr", sym.node._expr)) -def test_bilinear_resize_rejects_exact_one_sixteenth_downscale(): +@pytest.mark.parametrize("resize_mode", ("nearest", "bilinear")) +def test_resize_rejects_exact_one_sixteenth_downscale(resize_mode: str): with TosaLoweringContext( TosaSpecification.create_from_string("TOSA-1.0+INT") ), FakeTensorMode() as mode: with pytest.raises( TosaValueError, - match="Bilinear RESIZE downscale must be strictly greater than 1/16", + match="RESIZE downscale must be strictly greater than 1/16", ): exir_ops.backend.tosa.RESIZE.default( mode.from_tensor( @@ -48,7 +49,50 @@ def test_bilinear_resize_rejects_exact_one_sixteenth_downscale(): [2, 32, 2, 32], [15, 15], [-15, -15], - resize_mode="bilinear", + resize_mode=resize_mode, + ) + + +def test_resize_rejects_scale_numerator_over_tosa_limit(): + with TosaLoweringContext( + TosaSpecification.create_from_string("TOSA-1.0+INT") + ), FakeTensorMode() as mode: + with pytest.raises( + TosaValueError, + match="RESIZE scale numerator must be <= 2048", + ): + exir_ops.backend.tosa.RESIZE.default( + mode.from_tensor(torch.randint(0, 10, (1, 3, 4, 2), dtype=torch.int8)), + # 2049 violates scale_n <= 1 << 11, while 2049/2 still stays + # within MAX_SCALE so this test isolates the numerator rule. + [2049, 2, 4, 2], + [0, 0], + [0, 0], + resize_mode="nearest", + ) + + +@pytest.mark.parametrize( + "offset,border", + ( + ([1, 0], [-1, 0]), + ([0, 1], [0, -1]), + ), +) +def test_resize_rejects_non_positive_output_dimensions(offset, border): + with TosaLoweringContext( + TosaSpecification.create_from_string("TOSA-1.0+INT") + ), FakeTensorMode() as mode: + with pytest.raises( + TosaValueError, + match="RESIZE output dimensions must be positive", + ): + exir_ops.backend.tosa.RESIZE.default( + mode.from_tensor(torch.randint(0, 10, (1, 1, 1, 1), dtype=torch.int8)), + [1, 1, 1, 1], + offset, + border, + resize_mode="nearest", ) diff --git a/backends/arm/test/models/Qwen3_VL/test_qwen3_vl_layers.py b/backends/arm/test/models/Qwen3_VL/test_qwen3_vl_layers.py index 77b2739167a..f1ffe35b14e 100644 --- a/backends/arm/test/models/Qwen3_VL/test_qwen3_vl_layers.py +++ b/backends/arm/test/models/Qwen3_VL/test_qwen3_vl_layers.py @@ -33,7 +33,7 @@ Qwen3VLVisionRotaryEmbedding, ) -input_t = Tuple[torch.Tensor, ...] +input_t = Tuple[torch.Tensor | int, ...] def _make_qwen3_vl_2b_instruct_layer_config(): @@ -99,6 +99,19 @@ def prepare_model_and_inputs(cls): raise NotImplementedError +def _to_bfloat16( + model: torch.nn.Module, inputs: input_t +) -> tuple[torch.nn.Module, input_t]: + return model.to(torch.bfloat16), tuple( + ( + x.to(torch.bfloat16) + if isinstance(x, torch.Tensor) and x.is_floating_point() + else x + ) + for x in inputs + ) + + class Qwen3VLVisionMLPModel(Qwen3VLTestModule): def __init__(self, config) -> None: super().__init__() @@ -442,6 +455,18 @@ class Qwen3VLTestCase: VGF_NO_QUANT_TEST_CASES: dict[str, Qwen3VLTestCase] = TOSA_FP_TEST_CASES +TOSA_BF16_TEST_CASES: dict[str, Qwen3VLTestCase] = { + "vision_mlp": TOSA_FP_TEST_CASES["vision_mlp"], + "vision_patch_embed": TOSA_FP_TEST_CASES["vision_patch_embed"], + "vision_rotary_embedding": TOSA_FP_TEST_CASES["vision_rotary_embedding"], + "vision_rotary_apply": TOSA_FP_TEST_CASES["vision_rotary_apply"], + "vision_attention": TOSA_FP_TEST_CASES["vision_attention"], + "vision_block": TOSA_FP_TEST_CASES["vision_block"], + "vision_patch_merger": TOSA_FP_TEST_CASES["vision_patch_merger"], + "text_rms_norm": TOSA_FP_TEST_CASES["text_rms_norm"], + "qk_norm": TOSA_FP_TEST_CASES["qk_norm"], +} + @common.parametrize( "test_case", @@ -460,6 +485,27 @@ def test_qwen3_vl_tosa_FP(test_case: Qwen3VLTestCase): pipeline.run() +@common.parametrize( + "test_case", + TOSA_BF16_TEST_CASES, +) +def test_qwen3_vl_tosa_FP_bf16(test_case: Qwen3VLTestCase): + model, inputs = test_case.model_cls.prepare_model_and_inputs() + model, inputs = _to_bfloat16(model, inputs) + with torch.no_grad(): + pipeline = TosaPipelineFP[input_t]( + model, + inputs, + aten_op=[], + exir_op=[], + transform_passes=list(test_case.transform_passes), + tosa_extensions=["bf16"], + atol=1e-2, + rtol=1e-2, + ) + pipeline.run() + + @common.SkipIfNoModelConverter @common.parametrize( "test_case", diff --git a/backends/arm/test/models/test_swin2sr_arm.py b/backends/arm/test/models/test_swin2sr_arm.py index 6bf9b2a18d5..5fd29943b94 100644 --- a/backends/arm/test/models/test_swin2sr_arm.py +++ b/backends/arm/test/models/test_swin2sr_arm.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import sys from typing import Tuple import torch @@ -17,7 +18,7 @@ input_t = Tuple[torch.Tensor] -exir_ops = [ +ops_expected_absent_after_lowering = [ "executorch_exir_dialects_edge__ops_aten_add_Tensor", "executorch_exir_dialects_edge__ops_aten_convolution_default", "executorch_exir_dialects_edge__ops_aten_layer_norm_default", @@ -27,6 +28,24 @@ "executorch_exir_dialects_edge__ops_aten_softmax_int", ] +# TODO/MLETORCH-2163: Investigate Swin2SR delegation gaps around index/view +# in FP and Q/DQ, clamp, and expand_copy in INT. +swin2sr_fp_lowered_outer_graph_ops = { + "torch.ops.higher_order.executorch_call_delegate": 2, + "executorch_exir_dialects_edge__ops_aten_index_Tensor": 2, + "executorch_exir_dialects_edge__ops_aten_view_copy_default": 2, +} +swin2sr_int_lowered_outer_graph_ops = { + "torch.ops.higher_order.executorch_call_delegate": 3, + "executorch_exir_dialects_edge__ops_aten_clamp_default": 4, + "executorch_exir_dialects_edge__ops_aten_expand_copy_default": 4, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 5, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 6, +} +swin2sr_vgf_quant_lowered_outer_graph_ops = { + "torch.ops.higher_order.executorch_call_delegate": 1, +} + class TinySwin2SR(torch.nn.Module): def __init__(self): @@ -62,12 +81,10 @@ def test_swin2sr_tosa_FP(): model, model_inputs, aten_op=[], - exir_op=exir_ops, + exir_op=ops_expected_absent_after_lowering, use_to_edge_transform_and_lower=True, ) - pipeline.pop_stage("check_count.exir") - # TODO: MLETORCH-2134 re-enable once Swin2SR runs on the TOSA ref model. - pipeline.pop_stage("run_method_and_compare_outputs") + pipeline.change_args("check_count.exir", swin2sr_fp_lowered_outer_graph_ops) pipeline.run() @@ -77,12 +94,10 @@ def test_swin2sr_tosa_INT(): model, model_inputs, aten_op=[], - exir_op=exir_ops, + exir_op=ops_expected_absent_after_lowering, use_to_edge_transform_and_lower=True, ) - pipeline.pop_stage("check_count.exir") - # TODO: MLETORCH-2134 re-enable once Swin2SR runs on the TOSA ref model. - pipeline.pop_stage("run_method_and_compare_outputs") + pipeline.change_args("check_count.exir", swin2sr_int_lowered_outer_graph_ops) pipeline.run() @@ -93,13 +108,12 @@ def test_swin2sr_vgf_quant(): model, model_inputs, aten_op=[], - exir_op=exir_ops, + exir_op=ops_expected_absent_after_lowering, use_to_edge_transform_and_lower=True, quantize=True, + run_on_vulkan_runtime=sys.platform == "linux", ) - pipeline.pop_stage("check_count.exir") - # TODO: MLETORCH-2134 re-enable once Swin2SR runs on the TOSA ref model. - pipeline.pop_stage("run_method_and_compare_outputs") + pipeline.change_args("check_count.exir", swin2sr_vgf_quant_lowered_outer_graph_ops) pipeline.run() @@ -110,9 +124,9 @@ def test_swin2sr_vgf_no_quant(): model, model_inputs, aten_op=[], - exir_op=exir_ops, + exir_op=ops_expected_absent_after_lowering, use_to_edge_transform_and_lower=True, quantize=False, ) - pipeline.pop_stage("check_count.exir") + pipeline.change_args("check_count.exir", swin2sr_fp_lowered_outer_graph_ops) pipeline.run() diff --git a/backends/arm/test/models/test_torch_functions.py b/backends/arm/test/models/test_torch_functions.py index 0ca8d3ac091..c6a4c5580dc 100644 --- a/backends/arm/test/models/test_torch_functions.py +++ b/backends/arm/test/models/test_torch_functions.py @@ -97,8 +97,6 @@ def forward(self, *args): "test_data", test_parameters, xfails={ - "nonzero": "torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not guard on data-dependent expression Eq(u4, 0). " - "Requires dynamic output shape.", "topk": "NotImplementedError: No registered serialization name for found", "sort": "NotImplementedError: No registered serialization name for found", }, @@ -124,8 +122,6 @@ def test_torch_functions_tosa_FP(test_data): "test_data", test_parameters, xfails={ - "nonzero": "torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not guard on data-dependent expression Eq(u4, 0). " - "Requires dynamic output shape.", "topk": "NotImplementedError: No registered serialization name for found", "sort": "NotImplementedError: No registered serialization name for found", }, diff --git a/backends/arm/test/ops/test_cat.py b/backends/arm/test/ops/test_cat.py index 1e145ef5485..29738ddbe32 100644 --- a/backends/arm/test/ops/test_cat.py +++ b/backends/arm/test/ops/test_cat.py @@ -98,6 +98,24 @@ class Cat(torch.nn.Module): 0, ), } + test_parameters_fp8 = { + "cat_rand_two_tensors_fp8e4m3": lambda: ( + ( + torch.randn(1, 2, 4, 4, dtype=torch.float32).to(torch.float8_e4m3fn), + torch.randn(1, 2, 4, 1, dtype=torch.float32).to(torch.float8_e4m3fn), + ), + 3, + "fp8e4m3", + ), + "cat_rand_dim0_fp8e5m2": lambda: ( + ( + torch.randn(1, 2, 4, 4, dtype=torch.float32).to(torch.float8_e5m2), + torch.randn(1, 2, 4, 4, dtype=torch.float32).to(torch.float8_e5m2), + ), + 0, + "fp8e5m2", + ), + } def __init__(self): super().__init__() @@ -135,6 +153,19 @@ def test_cat_tosa_FP_4d(): pipeline.run() +@common.parametrize("test_data", Cat.test_parameters_fp8) +def test_cat_tosa_FP_fp8(test_data: Tuple): + tensors, dim, tosa_extension = test_data() + pipeline = TosaPipelineFP[input_t1]( + Cat(), + (tensors, dim), + aten_op, + exir_op, + tosa_extensions=[tosa_extension], + ) + pipeline.run() + + @common.parametrize("test_data", Cat.test_parameters) def test_cat_tosa_INT(test_data: Tuple): pipeline = TosaPipelineINT[input_t1]( diff --git a/backends/arm/test/ops/test_cond.py b/backends/arm/test/ops/test_cond.py index 8c6d9ef329c..6f489f0ab01 100644 --- a/backends/arm/test/ops/test_cond.py +++ b/backends/arm/test/ops/test_cond.py @@ -250,8 +250,6 @@ def test_cond_tosa_INT(case: Callable[[], tuple[torch.nn.Module, tuple]]): example_inputs, aten_op, tosa_extensions=["cf"], - frobenius_threshold=0.8, - cosine_threshold=0.8, # MLETORCH-1808 ) _set_branch_calibration_samples(pipeline, module, example_inputs) # Make sure no cond ops are left after partitioning. diff --git a/backends/arm/test/ops/test_constant_pad_nd.py b/backends/arm/test/ops/test_constant_pad_nd.py index 3742f710494..96d829851ed 100644 --- a/backends/arm/test/ops/test_constant_pad_nd.py +++ b/backends/arm/test/ops/test_constant_pad_nd.py @@ -128,6 +128,22 @@ "constant", ), } +test_data_suite_fp8 = { + "4dim_last1dim_fp8e4m3": lambda: ( + torch.rand(1, 1, 8, 8, dtype=torch.float32).to(torch.float8_e4m3fn), + (1, 1, 0, 0, 0, 0, 0, 0), + 1.0, + "constant", + "fp8e4m3", + ), + "3dim_last1dim_fp8e5m2": lambda: ( + torch.rand(1, 1, 8, dtype=torch.float32).to(torch.float8_e5m2), + (1, 0, 1, 0, 0, 0), + -0.5, + "constant", + "fp8e5m2", + ), +} class ConstantPadND(torch.nn.Module): @@ -289,6 +305,19 @@ def test_constant_pad_nd_tosa_FP(test_data: Tuple): pipeline.run() +@common.parametrize("test_data", test_data_suite_fp8) +def test_constant_pad_nd_tosa_FP_fp8(test_data: Tuple): + test_data, padding, value, mode, tosa_extension = test_data() + pipeline = TosaPipelineFP[input_t1]( + ConstantPadND(padding, value, mode), + (test_data,), + aten_op, + exir_op, + tosa_extensions=[tosa_extension], + ) + pipeline.run() + + @common.parametrize("test_data", test_data_suite) def test_constant_pad_nd_tosa_INT(test_data: Tuple): test_data, padding, value, mode = test_data() diff --git a/backends/arm/test/ops/test_gather.py b/backends/arm/test/ops/test_gather.py index 1439210373d..66cb9508c73 100644 --- a/backends/arm/test/ops/test_gather.py +++ b/backends/arm/test/ops/test_gather.py @@ -87,6 +87,36 @@ def forward(self, input_: torch.Tensor, dim_, index_: torch.Tensor): ), # Shape: [N=2, W=2, C=2] ), } +test_data_fp_fp8: dict[str, tuple[input_params, str]] = { + "test_fp8e4m3_2d": ( + ( + torch.tensor( + [[0.5, 1.25, 2.5], [3.5, 4.25, 5.75]], + dtype=torch.float8_e4m3fn, + ), + 1, + torch.tensor( + [[1, 0], [2, 1]], + dtype=torch.int64, + ), + ), + "fp8e4m3", + ), + "test_fp8e5m2_3d": ( + ( + torch.tensor( + [[[0.5, 1.5], [2.5, 3.5]], [[4.5, 5.5], [6.5, 7.5]]], + dtype=torch.float8_e5m2, + ), + 1, + torch.tensor( + [[[0, 1], [1, 0]], [[1, 0], [0, 1]]], + dtype=torch.int64, + ), + ), + "fp8e5m2", + ), +} # INT profile: integer inputs + bool (bool is supported via casts in @@ -145,6 +175,23 @@ def test_gather_tosa_FP(test_data: input_params): pipeline.run() +@common.parametrize("test_data", test_data_fp_fp8) +def test_gather_tosa_FP_fp8(test_data: tuple[input_params, str]): + input_data, tosa_extension = test_data + pipeline = TosaPipelineFP[input_params]( + Gather(), + input_data, + aten_op=Gather.aten_op, + exir_op=Gather.exir_op, + transform_passes=[ + InsertInt32CastsAfterInt64PlaceholdersPass(), + ], # int64 index are not currently supported and need to be cast to int32 + run_on_tosa_ref_model=False, # torch.gather() has no eager CPU FP8 implementation here, so eager reference execution fails. + tosa_extensions=[tosa_extension], + ) + pipeline.run() + + @common.parametrize("test_data", test_data_int | test_data_fp) def test_gather_tosa_INT(test_data: input_params): pipeline = TosaPipelineINT[input_params]( diff --git a/backends/arm/test/ops/test_index_select.py b/backends/arm/test/ops/test_index_select.py index bb5f0a92c51..4de19d30daf 100644 --- a/backends/arm/test/ops/test_index_select.py +++ b/backends/arm/test/ops/test_index_select.py @@ -61,6 +61,26 @@ def forward(self, input_: torch.Tensor, dim: int, index_: torch.Tensor): torch.tensor([3, 1], dtype=torch.int32), # [W=2] ), } +test_data_fp_bf16: dict[str, input_params] = { + # Rank-2: [K, C] -> index_select dim=0 => [W, C] + "test_bf16_rank2_dim0": ( + torch.tensor( + [[0.5, 1.25, 2.5], [3.5, 4.25, 5.75], [6.5, 7.25, 8.75]], + dtype=torch.bfloat16, + ), # [K=3, C=3] + 0, + torch.tensor([2, 0], dtype=torch.int32), # [W=2] + ), + # Rank-3: [N, K, C] -> index_select dim=-1 => [N, K, W] + "test_bf16_rank3_dim_neg1": ( + torch.tensor( + [[[0.5, 1.5], [2.5, 3.5]], [[4.5, 5.5], [6.5, 7.5]]], + dtype=torch.bfloat16, + ), # [N=2, K=2, C=2] + -1, + torch.tensor([1, 0], dtype=torch.int32), # [W=2] + ), +} # ---- INT profile: integer inputs + bool ---- test_data_int: dict[str, input_params] = { @@ -104,6 +124,18 @@ def test_index_select_tosa_FP(test_data: input_params): pipeline.run() +@common.parametrize("test_data", test_data_fp_bf16) +def test_index_select_tosa_FP_bf16(test_data: input_params): + pipeline = TosaPipelineFP[input_params]( + IndexSelect(), + test_data, + aten_op=IndexSelect.aten_op, + exir_op=IndexSelect.exir_op, + tosa_extensions=["bf16"], + ) + pipeline.run() + + @common.parametrize("test_data", test_data_int | test_data_fp) def test_index_select_tosa_INT(test_data: input_params): # INT profile runs quantized, so we test both int inputs and float inputs here. diff --git a/backends/arm/test/ops/test_mxfp_linear.py b/backends/arm/test/ops/test_mxfp_linear.py new file mode 100644 index 00000000000..da1bbec3b83 --- /dev/null +++ b/backends/arm/test/ops/test_mxfp_linear.py @@ -0,0 +1,226 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import copy + +import torch +from executorch.backends.arm.ao_ext import MXFPOpConfig, to_mxfp +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.analyze_output_utils import ( + compare_rel_frobenius_and_cosine_similarity, +) + + +def _block_input_rank1() -> torch.Tensor: + """Create a rank-1 input with distinct MXFP activation block scales.""" + + return torch.cat( + ( + 1e-3 * torch.randn(32), + 100.0 * torch.randn(32), + ) + ) + + +def _block_input_rank2() -> torch.Tensor: + """Create a rank-2 input with per-row activation block scale changes.""" + + return torch.stack( + ( + _block_input_rank1(), + torch.cat( + ( + 100.0 * torch.randn(32), + 1e-3 * torch.randn(32), + ) + ), + ) + ) + + +_test_data_rank1_fp = { + "mxfp_linear_rank1_zeros": lambda: ( + torch.zeros(32 * 8), + 5, + True, + False, + ), + "mxfp_linear_rank1_rand": lambda: ( + torch.rand(32), + 16, + False, + False, + ), +} + +_test_data_rank2_fp = { + "mxfp_linear_rank2_zeros": lambda: ( + torch.zeros(4, 32), + 16, + True, + False, + ), + "mxfp_linear_rank2_rand": lambda: ( + torch.rand(4, 32 * 6), + 13, + True, + False, + ), +} + +_test_data_rank3_fp = { + "mxfp_linear_rank3_zeros": lambda: ( + torch.zeros(2, 4, 32 * 3), + 1, + True, + False, + ), + "mxfp_linear_rank3_rand": lambda: ( + torch.rand(2, 4, 32), + 20, + True, + False, + ), +} + +_test_data_rank4_fp = { + "mxfp_linear_rank4_zeros": lambda: ( + torch.zeros(2, 3, 4, 32 * 24), + 8, + True, + False, + ), + "mxfp_linear_rank4_rand": lambda: ( + torch.rand(2, 3, 4, 32 * 32), + 64, + False, + False, + ), +} + +_test_data_block_fp = { + "mxfp_linear_rank1_block_weights": lambda: ( + torch.ones(64), + 4, + False, + True, + ), + "mxfp_linear_rank1_block_weights_block_activations": lambda: ( + _block_input_rank1(), + 4, + False, + True, + ), + "mxfp_linear_rank2_block_weights_block_activations": lambda: ( + _block_input_rank2(), + 4, + False, + True, + ), +} + +test_data_fp = ( + _test_data_rank1_fp + | _test_data_rank2_fp + | _test_data_rank3_fp + | _test_data_rank4_fp + | _test_data_block_fp +) + + +class Linear(torch.nn.Module): + def __init__( + self, + in_features: int, + out_features: int = 8, + bias: bool = True, + ) -> None: + super().__init__() + self.fc = torch.nn.Linear( + in_features=in_features, + out_features=out_features, + bias=bias, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.fc(x) + + def set_block_test_weights(self) -> None: + """Set weights to exercise separate MXFP weight block scales. + + The first two logical 32-wide input blocks use different magnitudes so + tests can verify block scaling does not share one scale across blocks. + + """ + if self.fc.weight.shape[1] < 64: + raise ValueError( + "Block test weights require at least 64 input features (2 blocks), got " + f"{tuple(self.fc.weight.shape)}" + ) + + with torch.no_grad(): + self.fc.weight.zero_() + for row in range(self.fc.weight.shape[0]): + # Small values in the first block. + self.fc.weight[row, 0:32] = 1e-3 + # Large values in the next block to require a different scale. + self.fc.weight[row, 32:64] = 100.0 + if self.fc.bias is not None: + self.fc.bias.zero_() + + +def _is_linear(module: torch.nn.Module, _fqn: str) -> bool: + return isinstance(module, torch.nn.Linear) + + +def _test_mxfp_linear_eager_cpu( + test_data: torch.Tensor, + config: MXFPOpConfig, + frobenius_threshold: float, + cosine_threshold: float, +) -> None: + test_input, out_features, has_bias, set_block_weights = test_data() + in_features = test_input.shape[-1] + ref_model = Linear( + in_features=in_features, + out_features=out_features, + bias=has_bias, + ).eval() + if set_block_weights: + ref_model.set_block_test_weights() + test_model = copy.deepcopy(ref_model).eval() + + to_mxfp(test_model, config, filter_fn=_is_linear) + + test_output = test_model(test_input) + ref_output = ref_model(test_input) + + compare_rel_frobenius_and_cosine_similarity( + ref_output, + test_output, + quantization_parameters=None, + frobenius_threshold=frobenius_threshold, + cosine_threshold=cosine_threshold, + clean_reference=False, + ) + + +@common.parametrize("test_data", test_data_fp) +def test_mxfp_linear_eager_cpu(test_data: torch.Tensor) -> None: + """Check eager MXFP implementation. + + The Arm lowering tests compare lowered output against the eager CPU + implementation, so the eager implementation must be accurate for it to be + used as a reference in other tests. + + """ + _test_mxfp_linear_eager_cpu( + test_data, + MXFPOpConfig(), + frobenius_threshold=0.06, + cosine_threshold=0.995, + ) diff --git a/backends/arm/test/ops/test_repeat.py b/backends/arm/test/ops/test_repeat.py index 1a2f71183bb..3368864564d 100644 --- a/backends/arm/test/ops/test_repeat.py +++ b/backends/arm/test/ops/test_repeat.py @@ -85,6 +85,18 @@ def forward(self, x: torch.Tensor): (torch.randn(1, 1, 2, 2, dtype=torch.float16),), ), } +test_data_suite_fp8 = { + "2_x_2_fp8e4m3": lambda: ( + Repeat((2, 1)), + (torch.randn(3, 4, dtype=torch.float32).to(torch.float8_e4m3fn),), + "fp8e4m3", + ), + "4_x_4_fp8e5m2": lambda: ( + Repeat((1, 2, 3, 2)), + (torch.randn(1, 1, 2, 2, dtype=torch.float32).to(torch.float8_e5m2),), + "fp8e5m2", + ), +} @common.parametrize( @@ -102,6 +114,19 @@ def test_repeat_tosa_FP(test_data: Tuple): pipeline.run() +@common.parametrize("test_data", test_data_suite_fp8) +def test_repeat_tosa_FP_fp8(test_data: Tuple): + module, test_data, tosa_extension = test_data() + pipeline = TosaPipelineFP[input_t1]( + module, + test_data, + module.aten_op, + exir_op=[], + tosa_extensions=[tosa_extension], + ) + pipeline.run() + + @common.parametrize("test_data", test_data_suite) def test_repeat_tosa_INT(test_data: Tuple): module, test_data = test_data() diff --git a/backends/arm/test/ops/test_slice.py b/backends/arm/test/ops/test_slice.py index 090d8abb56a..28c9731a6aa 100644 --- a/backends/arm/test/ops/test_slice.py +++ b/backends/arm/test/ops/test_slice.py @@ -50,6 +50,18 @@ [(0, 1), (0, 5), (3, 5), (4, 10)], ), } +test_data_suite_fp8 = { + "ones_slice_4_fp8e4m3": lambda: ( + torch.ones((1, 12, 10, 10), dtype=torch.float32).to(torch.float8_e4m3fn), + [(0, 1), (0, 5), (3, 5), (4, 10)], + "fp8e4m3", + ), + "ones_slice_4_fp8e5m2": lambda: ( + torch.ones((1, 12, 10, 10), dtype=torch.float32).to(torch.float8_e5m2), + [(0, 1), (0, 5), (3, 5), (4, 10)], + "fp8e5m2", + ), +} class Slice(torch.nn.Module): @@ -72,6 +84,20 @@ def test_slice_tensor_tosa_FP_bf16(test_data: torch.Tensor): pipeline.run() +@common.parametrize("test_data", test_data_suite_fp8) +def test_slice_tensor_tosa_FP_fp8(test_data): + input_data, slices, tosa_extension = test_data() + pipeline = TosaPipelineFP[input_t1]( + Slice(), + (input_data, slices), + aten_op, + exir_op, + tosa_extensions=[tosa_extension], + ) + pipeline.count_tosa_ops({"SLICE": 3}) + pipeline.run() + + @common.parametrize("test_data", test_data_suite) def test_slice_tensor_tosa_INT_nchw(test_data: torch.Tensor): pipeline = TosaPipelineINT[input_t1]( diff --git a/backends/arm/test/ops/test_unfold_copy.py b/backends/arm/test/ops/test_unfold_copy.py index 2b502a9be10..baa4b7f64bc 100644 --- a/backends/arm/test/ops/test_unfold_copy.py +++ b/backends/arm/test/ops/test_unfold_copy.py @@ -120,6 +120,18 @@ def forward(self, input_: torch.Tensor, dim_: int, size_: int, step_: int): ), } +test_data_bf16: dict[str, input_params] = { + "test_bf16_2d_dim1": ( + torch.tensor( + [[0.1, 0.2, 0.3, 0.4, 0.5], [1.1, 1.2, 1.3, 1.4, 1.5]], + dtype=torch.bfloat16, + ), # [B=2, T=5] + 1, + 3, + 2, # U=(5-3)//2+1=2 -> [B=2, U=2, C=3] + ), +} + @common.parametrize("test_data", test_data_fp) def test_unfold_copy_tosa_FP(test_data: input_params): @@ -132,6 +144,18 @@ def test_unfold_copy_tosa_FP(test_data: input_params): pipeline.run() +@common.parametrize("test_data", test_data_bf16) +def test_unfold_copy_tosa_FP_bf16(test_data: input_params): + pipeline = TosaPipelineFP[input_params]( + UnfoldCopy(), + test_data, + aten_op=UnfoldCopy.aten_op, + exir_op=UnfoldCopy.exir_op, + tosa_extensions=["bf16"], + ) + pipeline.run() + + @common.parametrize("test_data", test_data_int | test_data_fp) def test_unfold_copy_tosa_INT(test_data: input_params): pipeline = TosaPipelineINT[input_params]( diff --git a/backends/arm/test/ops/test_upsample_nearest2d.py b/backends/arm/test/ops/test_upsample_nearest2d.py index 5781e4ed29d..d8bf4d7dbd5 100644 --- a/backends/arm/test/ops/test_upsample_nearest2d.py +++ b/backends/arm/test/ops/test_upsample_nearest2d.py @@ -198,6 +198,17 @@ def test_upsample_nearest2d_vec_tosa_FP_interpolate(test_data: torch.Tensor): pipeline.run() +def test_upsample_nearest2d_vec_tosa_does_not_delegate_exact_one_sixteenth_downscale(): + pipeline = OpNotSupportedPipeline[input_t1]( + Interpolate(size=None, scale_factor=1.0 / 16.0), + (torch.randn(1, 3, 256, 448),), + {exir_op: 1}, + n_expected_delegates=0, + ) + + pipeline.run() + + @common.parametrize("test_data", test_data_suite) def test_upsample_nearest2d_vec_tosa_INT(test_data: torch.Tensor): test_data, size, scale_factor, compare_outputs = test_data() diff --git a/backends/arm/test/ops/test_view.py b/backends/arm/test/ops/test_view.py index b1e62c3efef..ce5bf13f2b8 100644 --- a/backends/arm/test/ops/test_view.py +++ b/backends/arm/test/ops/test_view.py @@ -86,6 +86,48 @@ def test_view_tosa_FP(test_data: Tuple): pipeline.run() +class ViewPermuteFP8(torch.nn.Module): + def __init__(self, new_shape: tuple[int, ...], dims: tuple[int, ...]): + super().__init__() + self.new_shape = new_shape + self.dims = dims + + def forward(self, x: torch.Tensor): + # Use permute to keep the graph lowerable for FP8 tests, + # since the mul used in View is not supported with FP8. + return x.view(self.new_shape).permute(self.dims) + + +@common.parametrize( + "test_data", + { + "view_permute_fp8e4m3": lambda: ( + torch.rand((2, 3, 4), dtype=torch.float32).to(torch.float8_e4m3fn), + (2, 4, 3), + (0, 2, 1), + "fp8e4m3", + ), + "view_permute_fp8e5m2": lambda: ( + torch.rand((2, 3, 4), dtype=torch.float32).to(torch.float8_e5m2), + (2, 4, 3), + (0, 2, 1), + "fp8e5m2", + ), + }, +) +def test_view_tosa_FP_fp8_permute(test_data: Tuple): + test_tensor, new_shape, dims, tosa_extension = test_data() + pipeline = TosaPipelineFP[input_t1]( + ViewPermuteFP8(new_shape, dims), + (test_tensor,), + ["torch.ops.aten.view.default", "torch.ops.aten.permute.default"], + exir_op=[], + tosa_extensions=[tosa_extension], + ) + pipeline.count_tosa_ops({"RESHAPE": 1, "TRANSPOSE": 1}) + pipeline.run() + + @common.parametrize("test_data", View.test_suite) def test_view_tosa_INT(test_data: Tuple): test_tensor, new_shape = test_data() diff --git a/backends/arm/test/passes/test_arm_op_targeted_pass.py b/backends/arm/test/passes/test_arm_op_targeted_pass.py new file mode 100644 index 00000000000..5c213d4c4b9 --- /dev/null +++ b/backends/arm/test/passes/test_arm_op_targeted_pass.py @@ -0,0 +1,150 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import operator +from typing import Set, Type + +import torch +from executorch.backends.arm._passes.arm_pass import ArmOpTargetedPass +from executorch.backends.arm._passes.arm_pass_manager import ArmPassManager +from executorch.backends.arm.constants import DISALLOW_TFA_META_KEY +from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec +from executorch.backends.arm.tosa.specification import TosaSpecification +from executorch.exir.pass_base import ExportPass +from torch.fx import Graph, GraphModule +from torch.fx.passes.infra.pass_base import PassResult + + +TARGET_OP = torch.ops.aten.add.Tensor +OTHER_OP = operator.add + + +def create_graph_module(target=OTHER_OP, disallow_tfa: bool = False) -> GraphModule: + graph = Graph() + lhs = graph.placeholder("lhs") + rhs = graph.placeholder("rhs") + lhs.meta["val"] = torch.randn(2, 3) + rhs.meta["val"] = torch.randn(2, 3) + node = graph.call_function(target, (lhs, rhs)) + node.meta["val"] = torch.randn(2, 3) + if disallow_tfa: + node.meta[DISALLOW_TFA_META_KEY] = True + graph.output(node) + return GraphModule(torch.nn.Module(), graph) + + +def create_test_pass_manager() -> ArmPassManager: + compile_spec = TosaCompileSpec( + TosaSpecification.create_from_string("TOSA-1.00+INT") + ) + return ArmPassManager(compile_spec) + + +def run_single_pass(graph_module: GraphModule, test_pass: ExportPass) -> PassResult: + pass_manager = create_test_pass_manager() + pass_manager.add_pass(test_pass) + return pass_manager(graph_module) + + +class DummyTargetedPass(ArmOpTargetedPass): + _passes_required_after: Set[Type[ExportPass]] = set() + target_ops = (TARGET_OP,) + check_allowed_to_transform = True + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.call_operator_count = 0 + + def call_operator(self, op, args, kwargs, meta): + self.call_operator_count += 1 + return super().call_operator(op, args, kwargs, meta) + + +class InsertTargetPass(ExportPass): + def call(self, graph_module: GraphModule) -> PassResult: + graph = graph_module.graph + placeholders = [node for node in graph.nodes if node.op == "placeholder"] + output = next(node for node in graph.nodes if node.op == "output") + + with graph.inserting_before(output): + target_node = graph.call_function( + TARGET_OP, + (placeholders[0], placeholders[1]), + ) + target_node.meta["val"] = torch.randn(2, 3) + output.args = (target_node,) + graph.lint() + graph_module.recompile() + return PassResult(graph_module, True) + + +class CondModule(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + def true_branch(arg: torch.Tensor) -> torch.Tensor: + return arg + 1 + + def false_branch(arg: torch.Tensor) -> torch.Tensor: + return arg - 1 + + return torch.cond(x.sum() > 0, true_branch, false_branch, [x]) + + +def test_skips_when_target_is_absent() -> None: + graph_module = create_graph_module() + targeted_pass = DummyTargetedPass() + + result = run_single_pass(graph_module, targeted_pass) + + assert result is not None + assert result.graph_module is graph_module + assert not result.modified + assert targeted_pass.call_operator_count == 0 + + +def test_runs_when_target_is_present() -> None: + graph_module = create_graph_module(TARGET_OP) + targeted_pass = DummyTargetedPass() + + result = run_single_pass(graph_module, targeted_pass) + + assert result is not None + assert result.modified + assert targeted_pass.call_operator_count == 1 + + +def test_skips_tfa_disallowed_target() -> None: + graph_module = create_graph_module(TARGET_OP, disallow_tfa=True) + targeted_pass = DummyTargetedPass(tfa_pass=True) + + result = run_single_pass(graph_module, targeted_pass) + + assert result is not None + assert result.graph_module is graph_module + assert not result.modified + assert targeted_pass.call_operator_count == 0 + + +def test_runs_when_previous_pass_creates_target() -> None: + graph_module = create_graph_module() + pass_manager = create_test_pass_manager() + targeted_pass = DummyTargetedPass() + pass_manager.add_pass(InsertTargetPass()) + pass_manager.add_pass(targeted_pass) + result = pass_manager(graph_module) + + assert result.modified + assert targeted_pass.call_operator_count == 1 + + +def test_runs_when_target_is_present_in_nested_submodule() -> None: + exported_program = torch.export.export(CondModule(), (torch.randn(2, 3),)) + graph_module = exported_program.graph_module + targeted_pass = DummyTargetedPass() + + result = run_single_pass(graph_module, targeted_pass) + + assert result is not None + assert result.modified + assert targeted_pass.call_operator_count > 0 diff --git a/backends/arm/test/targets.bzl b/backends/arm/test/targets.bzl index 0a3faa6a074..78b0c6a8533 100644 --- a/backends/arm/test/targets.bzl +++ b/backends/arm/test/targets.bzl @@ -25,6 +25,7 @@ def define_arm_tests(): "ops/test_log10.py", "ops/test_max_pool1d.py", "ops/test_mul.py", + "ops/test_mxfp_linear.py", "ops/test_permute.py", "ops/test_rsqrt.py", "ops/test_slice.py", @@ -62,6 +63,7 @@ def define_arm_tests(): "misc/test_bn_relu_folding_qat.py", "misc/test_custom_partition.py", "misc/test_debug_hook.py", + "misc/test_mxfp_linear_ao.py", "misc/test_post_quant_device_switch.py", # "misc/test_dim_order.py", (TODO - T238390249) ] @@ -104,6 +106,7 @@ def define_arm_tests(): "//executorch/backends/arm/test:arm_tester" if runtime.is_oss else "//executorch/backends/arm/test/tester/fb:arm_tester_fb", "//executorch/backends/arm/test:conftest", "//executorch/backends/arm/test/misc:dw_convs_shared_weights_module", + "//executorch/backends/arm:ao_ext", "//executorch/backends/arm:ethosu", "//executorch/backends/arm/tosa:compile_spec", "//executorch/backends/arm/tosa:partitioner", diff --git a/backends/arm/test/test_arm_backend.sh b/backends/arm/test/test_arm_backend.sh index be48d7ad234..1cb9e135d00 100755 --- a/backends/arm/test/test_arm_backend.sh +++ b/backends/arm/test/test_arm_backend.sh @@ -302,11 +302,41 @@ test_deit_e2e_ethos_u() { test_model_smollm2_135M() { echo "${TEST_SUITE_NAME}: Test SmolLM2-135M on Ethos-U85" - # Build common libs once - python3 backends/arm/test/test_model.py --test_output=arm_test/test_model --build_libs - - python3 backends/arm/test/test_model.py --test_output=arm_test/test_model --target=ethos-u85-128 --model=smollm2 --extra_flags="-DEXECUTORCH_SELECT_OPS_LIST=dim_order_ops::_to_dim_order_copy.out" --specify_ethosu_scratch + backends/arm/scripts/build_executorch.sh + # Build pte for smollm2 + python3 -m extension.llm.export.export_llm \ + base.model_class=smollm2 \ + base.params=examples/models/smollm2/135M_config.json \ + debug.verbose=True model.enable_dynamic_shape=False quantization.pt2e_quantize="ethosu_8a8w" \ + backend.ethosu.enabled=True backend.ethosu.target="ethos-u85-256" backend.ethosu.memory_mode=Dedicated_Sram_384KB + + # Build the arm_executor_runner application, pre-loading the pte in the DDR for faster linking + local pte_addr="0x76000000" + backends/arm/scripts/build_executor_runner.sh \ + --et_build_root="${et_root_dir}/arm_test" \ + --pte="${pte_addr}" \ + --build_type=Release \ + --target=ethos-u85-256 \ + --system_config=Ethos_U85_SYS_DRAM_Mid \ + --memory_mode=Dedicated_Sram_384KB \ + --ethosu_tools_dir="${scratch_dir}" \ + --toolchain=arm-none-eabi-gcc \ + --extra_build_flags="-DET_ARM_BAREMETAL_SCRATCH_TEMP_ALLOCATOR_POOL_SIZE=0x20000" \ + --select_ops_list="dim_order_ops::_to_dim_order_copy.out" + + + # Deploy the application on the FVP in fast mode + FVP_Corstone_SSE-320 -C mps4_board.subsystem.ethosu.num_macs=256 \ + -C mps4_board.visualisation.disable-visualisation=1 \ + -C vis_hdlcd.disable_visualisation=1 \ + -C mps4_board.telnetterminal0.start_telnet=0 \ + -C mps4_board.uart0.out_file='-' \ + -C mps4_board.uart0.shutdown_on_eot=1 \ + -a "${et_root_dir}"/arm_test/ethos-u85-256_${pte_addr}/cmake-out/arm_executor_runner \ + -C mps4_board.subsystem.ethosu.extra_args="--fast" \ + --data smollm2.pte@"${pte_addr}" + echo "${TEST_SUITE_NAME}: PASS" } diff --git a/backends/arm/test/tester/analyze_output_utils.py b/backends/arm/test/tester/analyze_output_utils.py index 6a3bbd4d686..c68811eedad 100644 --- a/backends/arm/test/tester/analyze_output_utils.py +++ b/backends/arm/test/tester/analyze_output_utils.py @@ -337,6 +337,24 @@ def dump_error_output( logger.error(f"{atol=}, {rtol=}, {qtol=}") +def calculate_rel_frobenius_and_cosine_similarity( + reference_output: torch.Tensor, + test_output: torch.Tensor, +) -> tuple[float, float]: + reference_output = reference_output.to(torch.float32) + test_output = test_output.to(torch.float32) + + reference_frobenius_norm = torch.linalg.norm(reference_output).item() + error_frobenius_norm = torch.linalg.norm(test_output - reference_output).item() + + relative_frobenius_error = error_frobenius_norm / (reference_frobenius_norm + 1e-8) + cosine_similarity = torch.nn.functional.cosine_similarity( + test_output.flatten(), reference_output.flatten(), dim=0 + ).item() + + return relative_frobenius_error, cosine_similarity + + def compare_rel_frobenius_and_cosine_similarity( reference_output: torch.Tensor, test_output: torch.Tensor, @@ -394,15 +412,11 @@ def compare_rel_frobenius_and_cosine_similarity( if reference_all_zeros: return - reference_output = reference_output.to(torch.float32) - test_output = test_output.to(torch.float32) - - reference_frobenius_norm = torch.linalg.norm(reference_output).item() - error_frobenius_norm = torch.linalg.norm(test_output - reference_output).item() - - relative_frobenius_error = error_frobenius_norm / (reference_frobenius_norm + 1e-8) - cosine_similarity = torch.nn.functional.cosine_similarity( - test_output.flatten(), reference_output.flatten(), dim=0 + relative_frobenius_error, cosine_similarity = ( + calculate_rel_frobenius_and_cosine_similarity(reference_output, test_output) + ) + reference_frobenius_norm = torch.linalg.norm( + reference_output.to(torch.float32) ).item() # Relative Frobenius is unstable when the reference norm is at quantization-noise scale. diff --git a/backends/arm/test/tester/test_pipeline.py b/backends/arm/test/tester/test_pipeline.py index 7e7f576e35c..86a5f857e58 100644 --- a/backends/arm/test/tester/test_pipeline.py +++ b/backends/arm/test/tester/test_pipeline.py @@ -48,7 +48,7 @@ from executorch.backends.arm.vgf.compile_spec import VgfCompileSpec from executorch.backends.test.harness.stages import StageType from executorch.exir.pass_base import ExportPass -from torch._export.pass_base import PassType +from executorch.exir.pass_manager import PassType from torch.export.graph_signature import InputKind, OutputKind from torchao.quantization.pt2e.quantizer import QuantizationSpec diff --git a/backends/arm/third-party/ethos-u-core-driver b/backends/arm/third-party/ethos-u-core-driver deleted file mode 160000 index 03567073fe2..00000000000 --- a/backends/arm/third-party/ethos-u-core-driver +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 03567073fe2b9802c0bd73f9534da6f8a03924d1 diff --git a/backends/arm/tosa/BUCK b/backends/arm/tosa/BUCK index 46ff6648c54..81d1f62437f 100644 --- a/backends/arm/tosa/BUCK +++ b/backends/arm/tosa/BUCK @@ -41,6 +41,17 @@ fbcode_target(_kind = runtime.python_library, ], ) +fbcode_target(_kind = runtime.python_library, + name = "resize_utils", + srcs = [ + "resize_utils.py", + ], + deps = [ + "//caffe2:torch", + ":specification", + ], +) + fbcode_target(_kind = runtime.python_library, name = "tosa", srcs = [ diff --git a/backends/arm/tosa/backend.py b/backends/arm/tosa/backend.py index 6b864e284b1..b0cae15022d 100644 --- a/backends/arm/tosa/backend.py +++ b/backends/arm/tosa/backend.py @@ -23,9 +23,6 @@ import tosa_serializer as ts -from executorch.backends.arm._passes.arm_pass_utils import ( - get_cond_while_submodules_nested, -) from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec from executorch.backends.arm.common.debug import debug_fail, debug_tosa_dump from executorch.backends.arm.debug.schema import DebugHook @@ -35,9 +32,13 @@ process_placeholder, ) from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec -from executorch.backends.arm.tosa.mapping import TOSA_TENSOR_NAME_META +from executorch.backends.arm.tosa.mapping import ( + TOSA_CONTROL_FLOW_REGION_NAME_META, + TOSA_TENSOR_NAME_META, +) from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult from executorch.exir.backend.compile_spec_schema import CompileSpec +from executorch.exir.graph_module import get_cond_while_submodules from torch.export.exported_program import ExportedProgram from torch.fx import Graph, GraphModule, Node @@ -45,6 +46,15 @@ logger = logging.getLogger(__name__) +def _qualify_control_flow_region_name( + parent_region_name: str | None, child_region_name: str +) -> str: + """Return a globally unique TOSA region name for nested control flow.""" + if parent_region_name is None: + return child_region_name + return f"{parent_region_name}__{child_region_name}" + + def _annotate_external_ids(ep_graph: Graph) -> Dict[str, int]: """Assign deterministic output IDs to leaf outputs. @@ -325,6 +335,43 @@ def _preprocess_module( # noqa: C901 RuntimeError: If an FX node with an unsupported op kind is found. """ + + def _annotate_control_flow_region_names( + graph_module: GraphModule, parent_region_name: str | None + ) -> None: + for node in graph_module.graph.nodes: + if node.op != "call_function": + continue + + match node.target: + case torch.ops.higher_order.cond: + arg_indices = [1, 2] + case torch.ops.higher_order.while_loop: + arg_indices = [0, 1] + case _: + continue + + for arg_index in arg_indices: + submodule_node = node.args[arg_index] + if not isinstance(submodule_node, Node): + raise RuntimeError( + f"Expected control flow submodule arg {arg_index} to be a Node." + ) + if submodule_node.op != "get_attr": + raise RuntimeError( + f"Expected control flow submodule arg {arg_index} to be a get_attr node." + ) + if not isinstance(submodule_node.target, str): + raise RuntimeError( + "Expected control flow submodule target to be a string." + ) + + submodule_node.meta[TOSA_CONTROL_FLOW_REGION_NAME_META] = ( + _qualify_control_flow_region_name( + parent_region_name, submodule_node.target + ) + ) + tosa_spec = compile_spec.tosa_spec node_to_id_map = _annotate_external_ids(graph_module.graph) artifact_path = compile_spec._get_intermediate_path() @@ -348,6 +395,8 @@ def _preprocess_module( # noqa: C901 else: logger.debug("No re-sorting outputs (workaround) during TOSA lowering.") + _annotate_control_flow_region_names(graph_module, submodule_name) + if submodule_name is not None: tosa_graph.startRegion(submodule_name) tosa_graph.currRegion.addBasicBlock(submodule_name) @@ -396,7 +445,7 @@ def _preprocess_module( # noqa: C901 raise # Recursively preprocess controlflow submodules. - for name, submodule, control_flow_node in get_cond_while_submodules_nested( + for name, submodule, control_flow_node in get_cond_while_submodules( graph_module ): TOSABackend._regularize_submodule(submodule, control_flow_node) @@ -406,7 +455,7 @@ def _preprocess_module( # noqa: C901 compile_spec, tosa_graph, debug_hook, - submodule_name=name, + submodule_name=_qualify_control_flow_region_name(submodule_name, name), containing_graph_module=graph_module, ) diff --git a/backends/arm/tosa/dialect/BUCK b/backends/arm/tosa/dialect/BUCK index 4e7f5837766..5081f5d6945 100644 --- a/backends/arm/tosa/dialect/BUCK +++ b/backends/arm/tosa/dialect/BUCK @@ -22,6 +22,7 @@ fbcode_target(_kind = runtime.python_library, deps = [ ":core", "//caffe2:torch", + "//executorch/backends/arm/tosa:resize_utils", "//executorch/backends/arm/tosa:tosa", ], ) diff --git a/backends/arm/tosa/dialect/ops/gather.py b/backends/arm/tosa/dialect/ops/gather.py index 1e1982adae3..49374142cd6 100644 --- a/backends/arm/tosa/dialect/ops/gather.py +++ b/backends/arm/tosa/dialect/ops/gather.py @@ -42,6 +42,8 @@ def GATHER(values: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: torch.float16, torch.float32, torch.bfloat16, + torch.float8_e4m3fn, + torch.float8_e5m2, ) if values.dtype not in allowed_values_dtypes: raise TosaValueError( @@ -57,6 +59,16 @@ def GATHER(values: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: op="GATHER", ) else: + required_extension = { + torch.bfloat16: "bf16", + torch.float8_e4m3fn: "fp8e4m3", + torch.float8_e5m2: "fp8e5m2", + }.get(values.dtype) + if required_extension and not tosa_spec.support_extension(required_extension): + raise TosaValueError( + f"dtype {values.dtype} requires {required_extension} extension.", + op="GATHER", + ) # Support in FP profile, or INT profile via quantization if not (tosa_spec.support_float() or tosa_spec.support_integer()): raise TosaValueError( diff --git a/backends/arm/tosa/dialect/ops/pad.py b/backends/arm/tosa/dialect/ops/pad.py index db2cab6fcfc..3b5628b0ede 100644 --- a/backends/arm/tosa/dialect/ops/pad.py +++ b/backends/arm/tosa/dialect/ops/pad.py @@ -33,6 +33,10 @@ def PAD(a: torch.Tensor, padding: List[int | torch.SymInt], *, value): supported_dtypes.update({torch.float16, torch.float32}) if tosa_spec.support_extension("bf16"): supported_dtypes.add(torch.bfloat16) + if tosa_spec.support_extension("fp8e4m3"): + supported_dtypes.add(torch.float8_e4m3fn) + if tosa_spec.support_extension("fp8e5m2"): + supported_dtypes.add(torch.float8_e5m2) if a.dtype not in supported_dtypes: raise TosaValueError( f"Input tensor dtype {a.dtype} is not supported by the target TOSA specification." diff --git a/backends/arm/tosa/dialect/ops/resize.py b/backends/arm/tosa/dialect/ops/resize.py index c48ff508afc..0d06253ccd8 100644 --- a/backends/arm/tosa/dialect/ops/resize.py +++ b/backends/arm/tosa/dialect/ops/resize.py @@ -8,6 +8,11 @@ import torch from executorch.backends.arm.tosa.dialect.lib import TosaValueError from executorch.backends.arm.tosa.dialect.ops_registration import register_fake_tosa_op +from executorch.backends.arm.tosa.resize_utils import ( + calculate_tosa_resize_output_hw, + get_tosa_resize_output_hw_validation_error, + get_tosa_resize_validation_error, +) from executorch.backends.arm.tosa.specification import ( get_context_spec, @@ -50,23 +55,17 @@ def _get_output_dtype( return output_dtype -def _validate_resize_parameters(scale, border, resize_mode): - def in_int16_range(values): - return all( - (x >= -(2**15)) and (x <= 2**15 - 1) for x in values if isinstance(x, int) - ) - - if not in_int16_range(scale): - raise TosaValueError("scale is out of the int16 range", op="RESIZE") - if not in_int16_range(border): - raise TosaValueError("border is out of the int16 range", op="RESIZE") - if resize_mode == "bilinear": - scale_y_n, scale_y_d, scale_x_n, scale_x_d = scale - if scale_y_d >= 16 * scale_y_n or scale_x_d >= 16 * scale_x_n: - raise TosaValueError( - "Bilinear RESIZE downscale must be strictly greater than 1/16", - op="RESIZE", - ) +def _validate_resize_parameters(input_hw, output_hw, scale, offset, border, tosa_spec): + validation_error = get_tosa_resize_validation_error( + input_hw=input_hw, + output_hw=output_hw, + scale=scale, + offset=offset, + border=border, + tosa_spec=tosa_spec, + ) + if validation_error is not None: + raise TosaValueError(validation_error, op="RESIZE") @register_fake_tosa_op( @@ -88,24 +87,28 @@ def RESIZE( f"Input tensor must be 4D, but got {x.dim()}D", op="RESIZE" ) _validate_resize_mode(resize_mode) - _validate_resize_parameters(scale, border, resize_mode) output_dtype = _get_output_dtype(x.dtype, tosa_spec, resize_mode) input_shape = x.shape - scale_y_n, scale_y_d, scale_x_n, scale_x_d = scale - offset_y, offset_x = offset - border_y, border_x = border H, W = input_shape[1], input_shape[2] - # RESIZE first upscales the input by an integer value, to "upscale space". - H_upscaled = (H - 1) * scale_y_n - # offset and border are provided in this scale, therefore adjust for these while in this space. - H_shifted = H_upscaled - offset_y + border_y - # Then, complete the RESIZE by downscaling with another integer value, approximating multplication with a fraction. - OH = (H_shifted // scale_y_d) + 1 - # Mirror the same computation horizontally for the output width. - W_upscaled = (W - 1) * scale_x_n - W_shifted = W_upscaled - offset_x + border_x - OW = (W_shifted // scale_x_d) + 1 + _validate_resize_parameters((H, W), None, scale, offset, border, tosa_spec) + output_hw = calculate_tosa_resize_output_hw((H, W), scale, offset, border) + validation_error = get_tosa_resize_output_hw_validation_error(output_hw) + if validation_error is not None: + raise TosaValueError(validation_error, op="RESIZE") + if output_hw is None: + scale_y_n, scale_y_d, scale_x_n, scale_x_d = scale + offset_y, offset_x = offset + border_y, border_x = border + # RESIZE first upscales the input by an integer value to "upscale + # space". Offset and border are encoded in that space, then RESIZE + # completes by downscaling with another integer value, approximating + # multiplication by a fraction. + OH = ((H - 1) * scale_y_n - offset_y + border_y) // scale_y_d + 1 + OW = ((W - 1) * scale_x_n - offset_x + border_x) // scale_x_d + 1 + else: + OH, OW = output_hw + fake_aten_tensor = torch.empty( size=(input_shape[0], OH, OW, input_shape[3]), dtype=output_dtype ) diff --git a/backends/arm/tosa/dialect/ops/slice.py b/backends/arm/tosa/dialect/ops/slice.py index 553c8dd489e..3406ccf911b 100644 --- a/backends/arm/tosa/dialect/ops/slice.py +++ b/backends/arm/tosa/dialect/ops/slice.py @@ -52,6 +52,10 @@ def SLICE(a, start, size): supported_dtypes += [torch.float16, torch.float32] if tosa_spec.support_extension("bf16"): supported_dtypes += [torch.bfloat16] + if tosa_spec.support_extension("fp8e4m3"): + supported_dtypes += [torch.float8_e4m3fn] + if tosa_spec.support_extension("fp8e5m2"): + supported_dtypes += [torch.float8_e5m2] if a.dtype not in supported_dtypes: raise TosaValueError( diff --git a/backends/arm/tosa/mapping.py b/backends/arm/tosa/mapping.py index b37c41a070b..0e91120c3b8 100644 --- a/backends/arm/tosa/mapping.py +++ b/backends/arm/tosa/mapping.py @@ -17,6 +17,7 @@ import tosa_serializer as ts from executorch.backends.arm.tosa.specification import TosaSpecification +TOSA_CONTROL_FLOW_REGION_NAME_META = "tosa_control_flow_region_name" TOSA_TENSOR_NAME_META = "tosa_tensor_name" UNSUPPORTED_DTYPES = ( diff --git a/backends/arm/tosa/partitioner.py b/backends/arm/tosa/partitioner.py index bd900f4cc81..37b9cd7cc2a 100644 --- a/backends/arm/tosa/partitioner.py +++ b/backends/arm/tosa/partitioner.py @@ -21,10 +21,7 @@ from typing import Callable, cast, List, Optional, Sequence, Tuple import torch -from executorch.backends.arm._passes.arm_pass_utils import ( - get_cond_while_submodules_nested, - get_first_fake_tensor, -) +from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor from executorch.backends.arm._passes.convert_expand_copy_to_repeat import ( calculate_multiples, ) @@ -43,6 +40,7 @@ ) from executorch.exir.backend.utils import tag_constant_data, WhyNoPartitionReporter from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.graph_module import get_cond_while_submodules from torch.export.exported_program import ExportedProgram from torch.fx import GraphModule from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition @@ -400,7 +398,7 @@ def _tag_module( # noqa tags: set[str] = set() if tag_iterator is None: tag_iterator = count(0) - for _, submodule, _ in get_cond_while_submodules_nested(module): + for _, submodule, _ in get_cond_while_submodules(module): submodule_tags = self._tag_module( submodule, containing_program, reporter, tag_iterator ) @@ -552,7 +550,10 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult: partition_tags = {tag: self.delegation_spec for tag in tags} tag_constant_data(exported_program) - if self.intermediate_path is not None and logger.level <= logging.INFO: + if ( + self.intermediate_path is not None + and logger.getEffectiveLevel() <= logging.INFO + ): intermediate_path = Path(self.intermediate_path) intermediate_path.mkdir(parents=True, exist_ok=True) file_handler = logging.FileHandler( diff --git a/backends/arm/tosa/resize_utils.py b/backends/arm/tosa/resize_utils.py new file mode 100644 index 00000000000..23be6ff42fc --- /dev/null +++ b/backends/arm/tosa/resize_utils.py @@ -0,0 +1,278 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Sequence + +import torch + +from executorch.backends.arm.tosa.specification import TosaSpecification + +_MAX_RESIZE_DIMENSION = 16384 +_MAX_RESIZE_SCALE_NUMERATOR = 1 << 11 +_MAX_SCALE = 2048 +_MAX_SCALE_LEVEL_8K = 256 +_INT16_MIN = -(2**15) +_INT16_MAX = 2**15 - 1 + + +def _as_concrete_ints(values: Sequence[int | torch.SymInt]) -> list[int] | None: + if all(isinstance(value, int) for value in values): + return [int(value) for value in values] + return None + + +def _concrete_int_values(values: Sequence[int | torch.SymInt]) -> list[int]: + return [int(value) for value in values if isinstance(value, int)] + + +def _first_outside_range( + values: Sequence[int], min_value: int, max_value: int +) -> int | None: + return next( + (value for value in values if value < min_value or value > max_value), None + ) + + +def _max_scale(tosa_spec: TosaSpecification) -> int: + return _MAX_SCALE_LEVEL_8K if getattr(tosa_spec, "level_8k", False) else _MAX_SCALE + + +def _validate_dimensions( + input_hw: Sequence[int | torch.SymInt], + output_hw: Sequence[int | torch.SymInt] | None, +) -> str | None: + concrete_dimensions: list[int] = [] + input_hw_ints = _as_concrete_ints(input_hw) + output_hw_ints = _as_concrete_ints(output_hw) if output_hw is not None else None + if input_hw_ints is not None: + concrete_dimensions.extend(input_hw_ints) + if output_hw_ints is not None: + concrete_dimensions.extend(output_hw_ints) + + invalid_dimension = next( + ( + dimension + for dimension in concrete_dimensions + if dimension >= _MAX_RESIZE_DIMENSION + ), + None, + ) + if invalid_dimension is not None: + return ( + "RESIZE dimensions must be less than " + f"{_MAX_RESIZE_DIMENSION}; got {invalid_dimension}" + ) + return None + + +def get_tosa_resize_output_hw_validation_error( + output_hw: Sequence[int | torch.SymInt] | None, +) -> str | None: + if output_hw is None: + return None + + output_hw_ints = _as_concrete_ints(output_hw) + if output_hw_ints is None: + return None + + invalid_dimension = next( + (dimension for dimension in output_hw_ints if dimension <= 0), None + ) + if invalid_dimension is not None: + return f"RESIZE output dimensions must be positive; got {invalid_dimension}" + + return _validate_dimensions((), output_hw) + + +def _validate_scale( + scale: Sequence[int | torch.SymInt], + tosa_spec: TosaSpecification, +) -> str | None: + invalid_scale = _first_outside_range( + _concrete_int_values(scale), _INT16_MIN, _INT16_MAX + ) + if invalid_scale is not None: + return ( + "RESIZE scale must be in int16 range " + f"[{_INT16_MIN}, {_INT16_MAX}]; got {invalid_scale}" + ) + + scale_ints = _as_concrete_ints(scale) + if scale_ints is None: + return None + + scale_y_n, scale_y_d, scale_x_n, scale_x_d = scale_ints + if min(scale_y_n, scale_y_d, scale_x_n, scale_x_d) <= 0: + return f"RESIZE scale values must be positive; got {scale_ints}" + + max_scale = _max_scale(tosa_spec) + if scale_y_n > max_scale * scale_y_d or scale_x_n > max_scale * scale_x_d: + return ( + f"RESIZE scale ratio must be <= MAX_SCALE ({max_scale}); " + f"got y={scale_y_n}/{scale_y_d}, x={scale_x_n}/{scale_x_d}" + ) + + if ( + scale_y_n > _MAX_RESIZE_SCALE_NUMERATOR + or scale_x_n > _MAX_RESIZE_SCALE_NUMERATOR + ): + return ( + "RESIZE scale numerator must be <= " + f"{_MAX_RESIZE_SCALE_NUMERATOR}; got y={scale_y_n}, x={scale_x_n}" + ) + + # The scale values are already in the doubled rational representation that + # TOSA RESIZE lowering emits, so the lower-bound downscale rule can be + # checked directly against them. + if scale_y_d >= 16 * scale_y_n or scale_x_d >= 16 * scale_x_n: + return ( + "RESIZE downscale must be strictly greater than 1/16; " + f"got y={scale_y_n}/{scale_y_d}, x={scale_x_n}/{scale_x_d}" + ) + return None + + +def _validate_offset( + offset: Sequence[int | torch.SymInt], + scale_ints: list[int], +) -> str | None: + offset_ints = _as_concrete_ints(offset) + if offset_ints is None: + return None + + scale_y_n, _, scale_x_n, _ = scale_ints + offset_y, offset_x = offset_ints + if offset_y < -scale_y_n or offset_y >= 16 * scale_y_n: + return ( + f"RESIZE offset_y must be in [{-scale_y_n}, {16 * scale_y_n}); " + f"got {offset_y}" + ) + if offset_x < -scale_x_n or offset_x >= 16 * scale_x_n: + return ( + f"RESIZE offset_x must be in [{-scale_x_n}, {16 * scale_x_n}); " + f"got {offset_x}" + ) + return None + + +def _validate_border( + border: Sequence[int | torch.SymInt], + scale_ints: list[int], +) -> str | None: + invalid_border = _first_outside_range( + _concrete_int_values(border), _INT16_MIN, _INT16_MAX + ) + if invalid_border is not None: + return ( + "RESIZE border must be in int16 range " + f"[{_INT16_MIN}, {_INT16_MAX}]; got {invalid_border}" + ) + + border_ints = _as_concrete_ints(border) + if border_ints is None: + return None + + scale_y_n, _, scale_x_n, _ = scale_ints + border_y, border_x = border_ints + if border_y < -16 * scale_y_n or border_y >= scale_y_n: + return ( + f"RESIZE border_y must be in [{-16 * scale_y_n}, {scale_y_n}); " + f"got {border_y}" + ) + if border_x < -16 * scale_x_n or border_x >= scale_x_n: + return ( + f"RESIZE border_x must be in [{-16 * scale_x_n}, {scale_x_n}); " + f"got {border_x}" + ) + return None + + +def _validate_output_shape( + input_hw: Sequence[int | torch.SymInt], + output_hw: Sequence[int | torch.SymInt] | None, + scale: Sequence[int | torch.SymInt], + offset: Sequence[int | torch.SymInt], + border: Sequence[int | torch.SymInt], +) -> str | None: + if output_hw is None: + return None + + output_hw_ints = _as_concrete_ints(output_hw) + expected_output_hw = calculate_tosa_resize_output_hw( + input_hw, scale, offset, border + ) + if ( + output_hw_ints is not None + and expected_output_hw is not None + and tuple(output_hw_ints) != expected_output_hw + ): + return ( + "RESIZE output shape is inconsistent with input and parameters; " + f"expected {expected_output_hw}, got {tuple(output_hw_ints)}" + ) + return None + + +def calculate_tosa_resize_output_hw( + input_hw: Sequence[int | torch.SymInt], + scale: Sequence[int | torch.SymInt], + offset: Sequence[int | torch.SymInt], + border: Sequence[int | torch.SymInt], +) -> tuple[int, int] | None: + input_hw_ints = _as_concrete_ints(input_hw) + scale_ints = _as_concrete_ints(scale) + offset_ints = _as_concrete_ints(offset) + border_ints = _as_concrete_ints(border) + if ( + input_hw_ints is None + or scale_ints is None + or offset_ints is None + or border_ints is None + ): + return None + + input_h, input_w = input_hw_ints + scale_y_n, scale_y_d, scale_x_n, scale_x_d = scale_ints + offset_y, offset_x = offset_ints + border_y, border_x = border_ints + + # RESIZE first upscales the input by an integer value to "upscale space". + # Offset and border are encoded in that space, then RESIZE completes by + # downscaling with another integer value, approximating multiplication by a + # fraction. + return ( + ((input_h - 1) * scale_y_n - offset_y + border_y) // scale_y_d + 1, + ((input_w - 1) * scale_x_n - offset_x + border_x) // scale_x_d + 1, + ) + + +def get_tosa_resize_validation_error( + *, + input_hw: Sequence[int | torch.SymInt], + output_hw: Sequence[int | torch.SymInt] | None, + scale: Sequence[int | torch.SymInt], + offset: Sequence[int | torch.SymInt], + border: Sequence[int | torch.SymInt], + tosa_spec: TosaSpecification, +) -> str | None: + scale_ints = _as_concrete_ints(scale) + + validation_error = _validate_dimensions(input_hw, output_hw) + if validation_error is not None: + return validation_error + validation_error = _validate_scale(scale, tosa_spec) + if validation_error is not None: + return validation_error + if scale_ints is None: + return None + + for validation_error in ( + _validate_offset(offset, scale_ints), + _validate_border(border, scale_ints), + _validate_output_shape(input_hw, output_hw, scale, offset, border), + ): + if validation_error is not None: + return validation_error + return None diff --git a/backends/cadence/aot/BUCK b/backends/cadence/aot/BUCK index 7d8ff3cffd2..57b8194c7f8 100644 --- a/backends/cadence/aot/BUCK +++ b/backends/cadence/aot/BUCK @@ -44,7 +44,6 @@ fbcode_target(_kind = runtime.python_library, ":compiler_funcs", ":utils", "//caffe2:torch", - "//executorch/backends/cadence/aot/quantizer:fusion_pass", "//executorch/backends/cadence/aot/quantizer/passes:fuse_ops", "//executorch/backends/cadence/aot/quantizer:quantizer", "//executorch/backends/transforms:decompose_sdpa", @@ -65,7 +64,6 @@ fbcode_target(_kind = runtime.python_library, ":replace_ops", ":utils", "//caffe2:torch", - "//executorch/backends/cadence/aot/quantizer:fusion_pass", "//executorch/backends/cadence/aot/quantizer:quantizer", "//executorch/backends/cadence/runtime:runtime", "//executorch/backends/transforms:decompose_sdpa", diff --git a/backends/cadence/aot/compiler.py b/backends/cadence/aot/compiler.py index 5c66c9eb62b..0b1b8dac361 100644 --- a/backends/cadence/aot/compiler.py +++ b/backends/cadence/aot/compiler.py @@ -14,6 +14,7 @@ import torch from executorch.backends.cadence.aot.compiler_funcs import ( prepare as prepare_fn, + QuantFusionPass, QuantizedInputWrapper, trace as trace_fn, ) @@ -21,7 +22,6 @@ CadenceMemoryPlanning, print_memory_planning_info, ) -from executorch.backends.cadence.aot.quantizer.fusion_pass import QuantFusion from executorch.backends.cadence.aot.quantizer.passes.fuse_ops import FuseQATConvBN from executorch.backends.cadence.aot.quantizer.quantizer import ( CadenceDefaultQuantizer, @@ -154,9 +154,9 @@ def apply_pre_edge_transform_passes( quantizer: CadenceQuantizer, ) -> ExportedProgram: """ - Apply pre-edge transform passes including QuantFusion and torch ops passes. + Apply pre-edge transform passes including QuantFusionPass and torch ops passes. This mirrors the Cadence AOT compiler flow: - 1. QuantFusion - fuses dq->op->q patterns + 1. QuantFusionPass - fuses dq->op->q patterns 2. apply_torch_ops_passes - applied just before to_edge() The quantizer must be the same as the one used to convert the model. @@ -169,7 +169,7 @@ def apply_pre_edge_transform_passes( PassManager( [ FuseQATConvBN(converted_program), - QuantFusion(patterns), + QuantFusionPass(patterns), ] )(converted_program.graph_module) diff --git a/backends/cadence/aot/compiler_funcs.py b/backends/cadence/aot/compiler_funcs.py index 02dcde7fd39..cec3cb7d016 100644 --- a/backends/cadence/aot/compiler_funcs.py +++ b/backends/cadence/aot/compiler_funcs.py @@ -14,6 +14,7 @@ import torch from torch._inductor.decomposition import remove_decompositions from torch.fx import GraphModule +from torch.fx.passes.infra.pass_base import PassBase, PassResult from torchao.quantization.pt2e.quantize_pt2e import prepare_pt2e, prepare_qat_pt2e from torchao.quantization.pt2e.quantizer import Quantizer @@ -607,3 +608,32 @@ def sink_input_dequant_through_transparent_ops( graph_module.recompile() return modified + + +class QuantFusionPass(PassBase): + """ + Iterates patterns, finds anchor ops in the converted graph, and calls + pattern.fuse() to replace dq-op-q subgraphs with fused ops. + """ + + def __init__(self, patterns: Sequence[object]) -> None: + super().__init__() + self.patterns = patterns + + def call(self, graph_module: GraphModule) -> Optional[PassResult]: + changed = False + for pattern in self.patterns: + pattern_changed = False + for target in pattern.anchor_ops(): # pyre-ignore[16] + for node in graph_module.graph.find_nodes( + op="call_function", target=target + ): + result = pattern.fuse(graph_module, node) # pyre-ignore[16] + if result is not None: + changed = True + pattern_changed = True + if pattern_changed: + graph_module.graph.eliminate_dead_code() + if changed: + graph_module.recompile() + return PassResult(graph_module, changed) diff --git a/backends/cadence/aot/pass_utils.py b/backends/cadence/aot/pass_utils.py index ab42ef43d56..091605e94ec 100644 --- a/backends/cadence/aot/pass_utils.py +++ b/backends/cadence/aot/pass_utils.py @@ -212,3 +212,20 @@ def nodes_not_adjacent_in_gm( def none_throws(x: Optional[PassResult]) -> PassResult: assert x is not None return x + + +def replace_with_op( + gm: torch.fx.GraphModule, + insert_after: torch.fx.Node, + replacement_op: torch._ops.OpOverload, + args: tuple, # pyre-ignore[2] + kwargs: dict, # pyre-ignore[2] + node_to_replace: torch.fx.Node, +) -> torch.fx.Node: + """Insert ``replacement_op`` after ``insert_after`` and replace all uses of + ``node_to_replace`` with the new node.""" + with gm.graph.inserting_after(insert_after): + new_node = gm.graph.call_function(replacement_op, args, kwargs) + new_node.meta = node_to_replace.meta + node_to_replace.replace_all_uses_with(new_node) + return new_node diff --git a/backends/cadence/aot/quantizer/BUCK b/backends/cadence/aot/quantizer/BUCK index 34fec2556f8..956bf700bd7 100644 --- a/backends/cadence/aot/quantizer/BUCK +++ b/backends/cadence/aot/quantizer/BUCK @@ -14,6 +14,21 @@ fbcode_target(_kind = runtime.python_library, ], ) +fbcode_target(_kind = runtime.python_library, + name = "pattern_utils", + srcs = [ + "pattern_utils.py", + ], + typing = True, + deps = [ + ":utils", + "//caffe2:torch", + "//executorch/backends/cadence/aot:compiler_utils", + "//executorch/backends/cadence/aot:pass_utils", + "//executorch/backends/cadence/aot:utils", + ], +) + fbcode_target(_kind = runtime.python_library, name = "patterns", srcs = [ @@ -21,8 +36,10 @@ fbcode_target(_kind = runtime.python_library, ], typing = True, deps = [ + ":pattern_utils", ":utils", "//caffe2:torch", + "//executorch/backends/cadence/aot:pass_utils", ], ) diff --git a/backends/cadence/aot/quantizer/pattern_utils.py b/backends/cadence/aot/quantizer/pattern_utils.py new file mode 100644 index 00000000000..25ff363ecc9 --- /dev/null +++ b/backends/cadence/aot/quantizer/pattern_utils.py @@ -0,0 +1,207 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import operator +from typing import Any + +import torch +from executorch.backends.cadence.aot.pass_utils import get_arg, replace_with_op +from executorch.backends.cadence.aot.quantizer.utils import ( + copy_node_metadata, + create_zero_bias_int32, + quantize_tensor_multiplier, +) +from executorch.backends.cadence.aot.utils import is_depthwise_conv +from torch import fx +from torch._ops import OpOverload + +DQ_PER_TENSOR: OpOverload = torch.ops.quantized_decomposed.dequantize_per_tensor.default +Q_PER_TENSOR: OpOverload = torch.ops.quantized_decomposed.quantize_per_tensor.default + + +def insert_node_with_meta( + gm: fx.GraphModule, + op: OpOverload, + args: tuple[Any, ...], + kwargs: dict[str, Any] | None, + insert_before: fx.Node, + like_node: fx.Node, +) -> fx.Node: + """Create a new node and populate its FakeTensor metadata. + + Inserts ``op(*args, **kwargs)`` before ``insert_before``, runs the op + under ``like_node``'s fake_mode to compute ``meta["val"]``, and copies + remaining metadata from ``like_node``. + """ + with gm.graph.inserting_before(insert_before): + node = gm.graph.call_function(op, args, kwargs or {}) + assert "val" in like_node.meta + fake_mode = like_node.meta["val"].fake_mode + assert fake_mode is not None + + def _resolve(x: Any) -> Any: + return x.meta["val"] if isinstance(x, fx.Node) else x + + fake_args = tuple(_resolve(a) for a in args) + fake_kwargs = {k: _resolve(v) for k, v in (kwargs or {}).items()} + with fake_mode: + node.meta["val"] = op(*fake_args, **fake_kwargs) + copy_node_metadata(node, like_node) + return node + + +def find_quant_user(node: fx.Node) -> fx.Node | None: + """Find the first quantize_per_tensor user of ``node``, traversing through getitem.""" + users = list(node.users) + if not users: + return None + user = users[0] + if user.target is operator.getitem: + if user.args[1] == 0: + users = list(user.users) + if not users: + return None + user = users[0] + else: + return None + if user.target == Q_PER_TENSOR: + return user + return None + + +def fuse_conv( + pattern: object, + gm: fx.GraphModule, + conv_node: fx.Node, + dq_input: fx.Node, + dq_weight: fx.Node, + quant_node: fx.Node, +) -> fx.Node: + """Fuse a dq->conv->q chain into a single quantized conv op.""" + dq_bias = None + if len(conv_node.args) > 2 and conv_node.args[2] is not None: + bias_arg = conv_node.args[2] + assert isinstance(bias_arg, fx.Node) + dq_bias = bias_arg if bias_arg.target == DQ_PER_TENSOR else None + weight_scale = get_arg(dq_weight, "scale", float) + input_scale = get_arg(dq_input, "scale", float) + bias_scale = input_scale * weight_scale + if dq_bias is not None: + bias_q = get_arg(dq_bias, "input", fx.Node) + else: + # Cadence quantized conv ops require a non-optional bias argument. + weight_node = get_arg(dq_weight, "input", fx.Node) + with gm.graph.inserting_before(conv_node): + bias_q = create_zero_bias_int32(gm, weight_node, bias_scale) + requantize_scale = bias_scale / get_arg(quant_node, "scale", float) + requantize_scale_t = torch.tensor([requantize_scale]) + out_multiplier, out_shift = quantize_tensor_multiplier(requantize_scale_t) + args = ( + get_arg(dq_input, "input", fx.Node), + get_arg(dq_weight, "input", fx.Node), + bias_q, + ) + groups = get_arg(conv_node, "groups", int) + kwargs = { + "stride": get_arg(conv_node, "stride", list[int]), + "padding": get_arg(conv_node, "padding", list[int]), + "dilation": get_arg(conv_node, "dilation", list[int]), + "groups": groups, + "input_zero_point": get_arg(dq_input, "zero_point", int), + "weight_zero_point": get_arg(dq_weight, "zero_point", int), + "bias_scale": bias_scale, + "out_scale": get_arg(quant_node, "scale", float), + "out_zero_point": get_arg(quant_node, "zero_point", int), + "out_multiplier": out_multiplier[0].item(), + "out_shift": out_shift[0].item(), + } + replacement_op = pattern.replacement_op() # pyre-ignore[16] + if replacement_op == torch.ops.cadence.quantized_conv1d_ncl.per_tensor: + input_node = get_arg(dq_input, "input", fx.Node) + assert len(input_node.meta["val"].shape) >= 2 + in_channels = input_node.meta["val"].shape[1] + if is_depthwise_conv(groups, in_channels): + replacement_op = torch.ops.cadence.quantized_depthwise_conv1d_ncl.per_tensor + return replace_with_op(gm, conv_node, replacement_op, args, kwargs, quant_node) + + +def fuse_linear( + gm: fx.GraphModule, + dq_input: fx.Node, + dq_weight: fx.Node, + dq_bias: fx.Node | None, + quant_node: fx.Node, + op_node: fx.Node, + replacement_op: OpOverload, + weight_q: fx.Node | None = None, +) -> fx.Node: + """Fuse a dq->linear->q chain into a single quantized linear op.""" + assert op_node.target in ( + torch.ops.aten.linear.default, + torch.ops.aten.addmm.default, + ), f"Expected linear/addmm, got {op_node.target}" + weight_scale = get_arg(dq_weight, "scale", float) + input_scale = get_arg(dq_input, "scale", float) + bias_scale = input_scale * weight_scale + requantize_scale = bias_scale / get_arg(quant_node, "scale", float) + requantize_scale_t = torch.tensor([requantize_scale]) + out_multiplier, out_shift = quantize_tensor_multiplier(requantize_scale_t) + if dq_bias is not None: + bias_q = get_arg(dq_bias, "input", fx.Node) + else: + # Cadence quantized linear ops require a non-optional bias argument. + weight_node = get_arg(dq_weight, "input", fx.Node) + with gm.graph.inserting_before(op_node): + bias_q = create_zero_bias_int32(gm, weight_node, bias_scale) + final_weight = ( + weight_q if weight_q is not None else get_arg(dq_weight, "input", fx.Node) + ) + args = (get_arg(dq_input, "input", fx.Node), final_weight, bias_q) + kwargs = { + "src_zero_point": get_arg(dq_input, "zero_point", int), + "weight_zero_point": get_arg(dq_weight, "zero_point", int), + "out_multiplier": out_multiplier[0].item(), + "out_shift": out_shift[0].item(), + "out_zero_point": get_arg(quant_node, "zero_point", int), + "offset": None, + } + return replace_with_op(gm, op_node, replacement_op, args, kwargs, quant_node) + + +def fuse_matmul( + gm: fx.GraphModule, + anchor_node: fx.Node, + dq0: fx.Node, + dq1: fx.Node, + quant_node: fx.Node, + replacement_op: OpOverload, +) -> fx.Node: + """Fuse a dq->matmul->q chain into a single quantized matmul op.""" + assert anchor_node.target in ( + torch.ops.aten.bmm.default, + torch.ops.aten.matmul.default, + ), f"Expected bmm/matmul, got {anchor_node.target}" + scale0 = get_arg(dq0, "scale", float) + scale1 = get_arg(dq1, "scale", float) + requantize_scale = (scale0 * scale1) / get_arg(quant_node, "scale", float) + requantize_scale_t = torch.tensor([requantize_scale]) + out_multiplier, out_shift = quantize_tensor_multiplier(requantize_scale_t) + args = ( + get_arg(dq0, "input", fx.Node), + get_arg(dq0, "zero_point", int), + get_arg(dq1, "input", fx.Node), + get_arg(dq1, "zero_point", int), + None, + ) + kwargs = { + "out_multiplier": out_multiplier[0].item(), + "out_shift": out_shift[0].item(), + "out_zero_point": get_arg(quant_node, "zero_point", int), + "transposed": False, + } + return replace_with_op(gm, anchor_node, replacement_op, args, kwargs, quant_node) diff --git a/backends/cadence/aot/quantizer/patterns.py b/backends/cadence/aot/quantizer/patterns.py index 54c01227d07..9897d443725 100644 --- a/backends/cadence/aot/quantizer/patterns.py +++ b/backends/cadence/aot/quantizer/patterns.py @@ -9,11 +9,24 @@ import operator from abc import ABC, abstractmethod from dataclasses import dataclass, field -from typing import List, Tuple, Union +from typing import List, Optional, Tuple, Union import torch -from executorch.backends.cadence.aot.quantizer.utils import get_bias_qparams - +from executorch.backends.cadence.aot.compiler_utils import get_shape +from executorch.backends.cadence.aot.pass_utils import get_arg, replace_with_op +from executorch.backends.cadence.aot.quantizer.pattern_utils import ( + DQ_PER_TENSOR, + find_quant_user, + fuse_conv, + fuse_linear, + fuse_matmul, + insert_node_with_meta, +) +from executorch.backends.cadence.aot.quantizer.utils import ( + check_out_zero_point_is_min_range, + get_bias_qparams, + quantize_tensor_multiplier, +) from torch import fx from torch._ops import OpOverload from torchao.quantization.pt2e.quantizer import ( @@ -79,6 +92,22 @@ def replacement_op(self) -> OpOverload: """ pass + def anchor_ops(self) -> tuple[OpOverload, ...]: + return tuple(self.partition_types()) + + def fuse( + self, + gm: fx.GraphModule, + anchor_node: fx.Node, + ) -> Optional[fx.Node]: + """Replace the dq→op→q subgraph around ``anchor_node`` with a fused op. + + Called by ``QuantFusionPass`` for each node matching ``anchor_ops()``. + Returns the new fused node on success, or ``None`` to skip this match. + Subclasses override to implement pattern-specific fusion logic. + """ + return None + class AddmmPattern(QuantizationPattern): def partition_types(self) -> List[OpOverload]: @@ -115,6 +144,41 @@ def get_anchors( def replacement_op(self) -> OpOverload: return torch.ops.cadence.quantized_linear.per_tensor + def fuse(self, gm: fx.GraphModule, anchor_node: fx.Node) -> fx.Node | None: + assert anchor_node.target == torch.ops.aten.addmm.default + # addmm(bias, input, weight) + bias_node = anchor_node.args[0] + assert isinstance(bias_node, fx.Node) + dq_input = get_arg(anchor_node, "mat1", fx.Node) + if dq_input.target != DQ_PER_TENSOR: + return None + dq_weight = get_arg(anchor_node, "mat2", fx.Node) + if dq_weight.target != DQ_PER_TENSOR: + return None + quant_node = find_quant_user(anchor_node) + if quant_node is None: + return None + dq_bias = bias_node if bias_node.target == DQ_PER_TENSOR else None + weight_q = get_arg(dq_weight, "input", fx.Node) + transposed = insert_node_with_meta( + gm, + torch.ops.aten.transpose.int, + (weight_q, 0, 1), + None, + anchor_node, + weight_q, + ) + return fuse_linear( + gm, + dq_input, + dq_weight, + dq_bias, + quant_node, + anchor_node, + self.replacement_op(), + weight_q=transposed, + ) + class AddPattern(QuantizationPattern): def partition_types(self) -> List[OpOverload]: @@ -153,6 +217,33 @@ def get_anchors( def replacement_op(self) -> OpOverload: return torch.ops.cadence.quantized_add.per_tensor + def fuse(self, gm: fx.GraphModule, anchor_node: fx.Node) -> fx.Node | None: + # Skip if alpha kwarg is present — changes add semantics. + if anchor_node.kwargs: + return None + dq0 = anchor_node.args[0] + if not isinstance(dq0, fx.Node) or dq0.target != DQ_PER_TENSOR: + return None + dq1 = anchor_node.args[1] + if not isinstance(dq1, fx.Node) or dq1.target != DQ_PER_TENSOR: + return None + quant_node = find_quant_user(anchor_node) + if quant_node is None: + return None + args = ( + get_arg(dq0, "input", fx.Node), + get_arg(dq0, "scale", float), + get_arg(dq0, "zero_point", int), + get_arg(dq1, "input", fx.Node), + get_arg(dq1, "scale", float), + get_arg(dq1, "zero_point", int), + get_arg(quant_node, "scale", float), + get_arg(quant_node, "zero_point", int), + ) + return replace_with_op( + gm, anchor_node, self.replacement_op(), args, {}, quant_node + ) + # This is a base class for Add+ReLU fusion, since it can be used with two different relu aten ops class AddReluBasePattern(QuantizationPattern): @@ -196,6 +287,46 @@ def get_anchors( def replacement_op(self) -> OpOverload: return torch.ops.cadence.quantized_add.per_tensor + def anchor_ops(self) -> tuple[OpOverload, ...]: + return (torch.ops.aten.add.Tensor,) + + def fuse(self, gm: fx.GraphModule, anchor_node: fx.Node) -> fx.Node | None: + add_users = list(anchor_node.users) + if len(add_users) != 1: + return None + relu_node = add_users[0] + if relu_node.target != self.partition_types()[1]: + return None + if len(anchor_node.kwargs) > 0: + return None + dq0 = anchor_node.args[0] + if not isinstance(dq0, fx.Node) or dq0.target != DQ_PER_TENSOR: + return None + dq1 = anchor_node.args[1] + if not isinstance(dq1, fx.Node) or dq1.target != DQ_PER_TENSOR: + return None + quant_node = find_quant_user(relu_node) + if quant_node is None: + return None + if not check_out_zero_point_is_min_range( + get_arg(quant_node, "zero_point", int), + get_arg(quant_node, "dtype", torch.dtype), + ): + return None + args = ( + get_arg(dq0, "input", fx.Node), + get_arg(dq0, "scale", float), + get_arg(dq0, "zero_point", int), + get_arg(dq1, "input", fx.Node), + get_arg(dq1, "scale", float), + get_arg(dq1, "zero_point", int), + get_arg(quant_node, "scale", float), + get_arg(quant_node, "zero_point", int), + ) + return replace_with_op( + gm, anchor_node, self.replacement_op(), args, {}, quant_node + ) + # Add + regular relu op fusion class AddReluPattern0(AddReluBasePattern): @@ -234,6 +365,18 @@ def replacement_op(self) -> OpOverload: # we just need to change the name of the op return torch.ops.cadence.quantized_matmul.default + def fuse(self, gm: fx.GraphModule, anchor_node: fx.Node) -> fx.Node | None: + dq0 = anchor_node.args[0] + if not isinstance(dq0, fx.Node) or dq0.target != DQ_PER_TENSOR: + return None + dq1 = anchor_node.args[1] + if not isinstance(dq1, fx.Node) or dq1.target != DQ_PER_TENSOR: + return None + quant_node = find_quant_user(anchor_node) + if quant_node is None: + return None + return fuse_matmul(gm, anchor_node, dq0, dq1, quant_node, self.replacement_op()) + class CatPattern(QuantizationPattern): def partition_types(self) -> List[OpOverload]: @@ -283,6 +426,25 @@ def get_anchors( def replacement_op(self) -> OpOverload: return torch.ops.aten.cat.default + def fuse(self, gm: fx.GraphModule, anchor_node: fx.Node) -> fx.Node | None: + cat_inputs = anchor_node.args[0] + if not isinstance(cat_inputs, (list, tuple)) or not cat_inputs: + return None + inputs_q = [] + for inp in cat_inputs: + if not isinstance(inp, fx.Node) or inp.target != DQ_PER_TENSOR: + return None + inputs_q.append(get_arg(inp, "input", fx.Node)) + quant_node = find_quant_user(anchor_node) + if quant_node is None: + return None + dim = get_arg(anchor_node, "dim", int) + args = (inputs_q,) + kwargs = {"dim": dim} + return replace_with_op( + gm, anchor_node, self.replacement_op(), args, kwargs, quant_node + ) + class Conv1dPattern(QuantizationPattern): def partition_types(self) -> List[OpOverload]: @@ -325,6 +487,18 @@ def get_anchors( def replacement_op(self) -> OpOverload: return torch.ops.cadence.quantized_conv1d_ncl.per_tensor + def fuse(self, gm: fx.GraphModule, anchor_node: fx.Node) -> fx.Node | None: + dq_input = anchor_node.args[0] + if not isinstance(dq_input, fx.Node) or dq_input.target != DQ_PER_TENSOR: + return None + dq_weight = anchor_node.args[1] + if not isinstance(dq_weight, fx.Node) or dq_weight.target != DQ_PER_TENSOR: + return None + quant_node = find_quant_user(anchor_node) + if quant_node is None: + return None + return fuse_conv(self, gm, anchor_node, dq_input, dq_weight, quant_node) + class Conv2dPattern(QuantizationPattern): def partition_types(self) -> List[OpOverload]: @@ -367,6 +541,18 @@ def get_anchors( def replacement_op(self) -> OpOverload: return torch.ops.cadence.quantized_conv2d_nchw.per_tensor + def fuse(self, gm: fx.GraphModule, anchor_node: fx.Node) -> fx.Node | None: + dq_input = anchor_node.args[0] + if not isinstance(dq_input, fx.Node) or dq_input.target != DQ_PER_TENSOR: + return None + dq_weight = anchor_node.args[1] + if not isinstance(dq_weight, fx.Node) or dq_weight.target != DQ_PER_TENSOR: + return None + quant_node = find_quant_user(anchor_node) + if quant_node is None: + return None + return fuse_conv(self, gm, anchor_node, dq_input, dq_weight, quant_node) + class LayerNormPattern(QuantizationPattern): def partition_types(self) -> List[OpOverload]: @@ -405,6 +591,61 @@ def get_anchors( def replacement_op(self) -> OpOverload: return torch.ops.cadence.quantized_layer_norm.per_tensor + def fuse(self, gm: fx.GraphModule, anchor_node: fx.Node) -> fx.Node | None: + dq_input = anchor_node.args[0] + if not isinstance(dq_input, fx.Node) or dq_input.target != DQ_PER_TENSOR: + return None + quant_node = find_quant_user(anchor_node) + if quant_node is None: + return None + scale = get_arg(dq_input, "scale", float) + zero_point = get_arg(dq_input, "zero_point", int) + normalized_shape = anchor_node.args[1] + assert isinstance(normalized_shape, list) + weight = ( + anchor_node.args[2] + if len(anchor_node.args) > 2 and anchor_node.args[2] + else None + ) + bias = ( + anchor_node.args[3] + if len(anchor_node.args) > 3 and anchor_node.args[3] + else None + ) + input_q = get_arg(dq_input, "input", fx.Node) + # Default weight=1 and bias=0 must be float32 — cadence::quantized_layer_norm + # expects float affine parameters, not quantized values. + if not weight: + weight = insert_node_with_meta( + gm, + torch.ops.aten.full.default, + (normalized_shape, 1), + {"dtype": torch.float32}, + anchor_node, + input_q, + ) + if not bias: + bias = insert_node_with_meta( + gm, + torch.ops.aten.full.default, + (normalized_shape, 0), + {"dtype": torch.float32}, + anchor_node, + input_q, + ) + args = (input_q, scale, zero_point) + kwargs = { + "normalized_shape": normalized_shape, + "weight": weight, + "bias": bias, + "eps": get_arg(anchor_node, "eps", float), + "output_scale": get_arg(quant_node, "scale", float), + "output_zero_point": get_arg(quant_node, "zero_point", int), + } + return replace_with_op( + gm, anchor_node, self.replacement_op(), args, kwargs, quant_node + ) + class LinearPattern(QuantizationPattern): def partition_types(self) -> List[OpOverload]: @@ -447,6 +688,31 @@ def get_anchors( def replacement_op(self) -> OpOverload: return torch.ops.cadence.quantized_linear.per_tensor + def fuse(self, gm: fx.GraphModule, anchor_node: fx.Node) -> fx.Node | None: + dq_input = anchor_node.args[0] + if not isinstance(dq_input, fx.Node) or dq_input.target != DQ_PER_TENSOR: + return None + dq_weight = anchor_node.args[1] + if not isinstance(dq_weight, fx.Node) or dq_weight.target != DQ_PER_TENSOR: + return None + quant_node = find_quant_user(anchor_node) + if quant_node is None: + return None + dq_bias: fx.Node | None = None + if len(anchor_node.args) > 2: + bias_arg = anchor_node.args[2] + if isinstance(bias_arg, fx.Node) and bias_arg.target == DQ_PER_TENSOR: + dq_bias = bias_arg + return fuse_linear( + gm, + dq_input, + dq_weight, + dq_bias, + quant_node, + anchor_node, + self.replacement_op(), + ) + class MatmulPattern(QuantizationPattern): def partition_types(self) -> List[OpOverload]: @@ -472,6 +738,18 @@ def replacement_op(self) -> OpOverload: # TODO: T240804887 This is actually a per-tensor variant, we just need to change the name of the op return torch.ops.cadence.quantized_matmul.default + def fuse(self, gm: fx.GraphModule, anchor_node: fx.Node) -> fx.Node | None: + dq0 = anchor_node.args[0] + if not isinstance(dq0, fx.Node) or dq0.target != DQ_PER_TENSOR: + return None + dq1 = anchor_node.args[1] + if not isinstance(dq1, fx.Node) or dq1.target != DQ_PER_TENSOR: + return None + quant_node = find_quant_user(anchor_node) + if quant_node is None: + return None + return fuse_matmul(gm, anchor_node, dq0, dq1, quant_node, self.replacement_op()) + class MaxPool2dPattern(QuantizationPattern): """ @@ -530,6 +808,40 @@ def get_anchors( def replacement_op(self) -> OpOverload: return torch.ops.cadence.quantized_max_pool2d_nchw.default + def fuse(self, gm: fx.GraphModule, anchor_node: fx.Node) -> fx.Node | None: + return _fuse_max_pool2d(gm, anchor_node) + + +def _fuse_max_pool2d(gm: fx.GraphModule, anchor_node: fx.Node) -> fx.Node | None: + """Shared fuse logic for both MaxPool2d variants.""" + dq_input = anchor_node.args[0] + if not isinstance(dq_input, fx.Node) or dq_input.target != DQ_PER_TENSOR: + return None + quant_node = find_quant_user(anchor_node) + if quant_node is None: + return None + kernel_size = get_arg(anchor_node, "kernel_size", list[int]) + stride = get_arg(anchor_node, "stride", list[int]) + padding = get_arg(anchor_node, "padding", list[int]) + dilation = get_arg(anchor_node, "dilation", list[int]) + ceil_mode = get_arg(anchor_node, "ceil_mode", bool) + args = (get_arg(dq_input, "input", fx.Node),) + kwargs = { + "kernel_size": kernel_size, + "stride": stride, + "padding": padding, + "dilation": dilation, + "ceil_mode": ceil_mode, + } + return replace_with_op( + gm, + anchor_node, + torch.ops.cadence.quantized_max_pool2d_nchw.default, + args, + kwargs, + quant_node, + ) + class MaxPool2dWithoutIndicesPattern(QuantizationPattern): """ @@ -569,8 +881,8 @@ def get_anchors( def replacement_op(self) -> OpOverload: return torch.ops.cadence.quantized_max_pool2d_nchw.default - -# This is a base class for ReLU + def fuse(self, gm: fx.GraphModule, anchor_node: fx.Node) -> fx.Node | None: + return _fuse_max_pool2d(gm, anchor_node) # This is a base class for ReLU, since it can be used with two different aten ops @@ -598,6 +910,28 @@ def get_anchors( def replacement_op(self) -> OpOverload: return torch.ops.cadence.quantized_relu.per_tensor + def fuse(self, gm: fx.GraphModule, anchor_node: fx.Node) -> fx.Node | None: + dq_input = anchor_node.args[0] + if not isinstance(dq_input, fx.Node) or dq_input.target != DQ_PER_TENSOR: + return None + quant_node = find_quant_user(anchor_node) + if quant_node is None: + return None + input_scale = get_arg(dq_input, "scale", float) + requantize_scale = input_scale / get_arg(quant_node, "scale", float) + requantize_scale_t = torch.tensor([requantize_scale]) + out_multiplier, out_shift = quantize_tensor_multiplier(requantize_scale_t) + args = (get_arg(dq_input, "input", fx.Node),) + kwargs = { + "X_zero_point": get_arg(dq_input, "zero_point", int), + "out_zero_point": get_arg(quant_node, "zero_point", int), + "out_multiplier": out_multiplier[0].item(), + "out_shift": out_shift[0].item(), + } + return replace_with_op( + gm, anchor_node, self.replacement_op(), args, kwargs, quant_node + ) + # Regular relu op class ReluPattern0(ReluBasePattern): @@ -657,6 +991,39 @@ def get_anchors( def replacement_op(self) -> OpOverload: return torch.ops.cadence.quantized_conv2d_nchw.per_tensor + def anchor_ops(self) -> tuple[OpOverload, ...]: + return (self.partition_types()[0],) + + def fuse(self, gm: fx.GraphModule, anchor_node: fx.Node) -> fx.Node | None: + conv_users = list(anchor_node.users) + if len(conv_users) != 1: + return None + relu_node = conv_users[0] + if relu_node.target != self.partition_types()[1]: + return None + _arg0 = anchor_node.args[0] + dq_input = ( + _arg0 + if isinstance(_arg0, fx.Node) and _arg0.target == DQ_PER_TENSOR + else None + ) + _arg1 = anchor_node.args[1] + dq_weight = ( + _arg1 + if isinstance(_arg1, fx.Node) and _arg1.target == DQ_PER_TENSOR + else None + ) + if dq_input is None or dq_weight is None: + return None + quant_node = find_quant_user(relu_node) + if quant_node is None: + return None + check_out_zero_point_is_min_range( + get_arg(quant_node, "zero_point", int), + get_arg(quant_node, "dtype", torch.dtype), + ) + return fuse_conv(self, gm, anchor_node, dq_input, dq_weight, quant_node) + # Conv1d + regular relu op fusion class Conv1dReluPattern0(ConvReluBasePattern): @@ -711,6 +1078,53 @@ def get_anchors( def replacement_op(self) -> OpOverload: return torch.ops.cadence.quantized_softmax.per_tensor + def fuse(self, gm: fx.GraphModule, anchor_node: fx.Node) -> fx.Node | None: + dq_input = anchor_node.args[0] + if not isinstance(dq_input, fx.Node) or dq_input.target != DQ_PER_TENSOR: + return None + quant_node = find_quant_user(anchor_node) + if quant_node is None: + return None + input_q = get_arg(dq_input, "input", fx.Node) + quant_input = get_arg(quant_node, "input", fx.Node) + mask_shape = get_shape(gm, quant_input) + if not mask_shape: + return None + mask_shape = list(mask_shape) + # Softmax mask is packed 16 elements per int32 word. + mask_shape[-1] = mask_shape[-1] // 16 + mask_tensor = insert_node_with_meta( + gm, + torch.ops.aten.full.default, + (mask_shape, 0.0), + {"dtype": torch.int32}, + anchor_node, + input_q, + ) + # Initial position for streaming softmax (unused, set to 0). + pos_tensor = insert_node_with_meta( + gm, + torch.ops.aten.full.default, + ([1], 0), + {"dtype": torch.int64}, + anchor_node, + input_q, + ) + args = ( + input_q, + mask_tensor, + get_arg(anchor_node, "dim", int), + 0, + pos_tensor, + get_arg(dq_input, "scale", float), + get_arg(dq_input, "zero_point", int), + get_arg(quant_node, "scale", float), + get_arg(quant_node, "zero_point", int), + ) + return replace_with_op( + gm, anchor_node, self.replacement_op(), args, {}, quant_node + ) + class MixedW8A32LinearPattern(QuantizationPattern): def partition_types(self) -> List[OpOverload]: @@ -765,6 +1179,36 @@ def get_anchors( def replacement_op(self) -> OpOverload: return torch.ops.cadence.quantized_w8a32_linear.default + def fuse(self, gm: fx.GraphModule, anchor_node: fx.Node) -> fx.Node | None: + if len(anchor_node.args) != 3 or len(anchor_node.kwargs) > 0: + return None + _arg1 = anchor_node.args[1] + dq_weight = ( + _arg1 + if isinstance(_arg1, fx.Node) and _arg1.target == DQ_PER_TENSOR + else None + ) + _arg2 = anchor_node.args[2] + dq_bias = ( + _arg2 + if isinstance(_arg2, fx.Node) and _arg2.target == DQ_PER_TENSOR + else None + ) + if dq_weight is None or dq_bias is None: + return None + input_node = anchor_node.args[0] + assert isinstance(input_node, fx.Node) + args = ( + input_node, + get_arg(dq_weight, "input", fx.Node), + get_arg(dq_weight, "scale", float), + get_arg(dq_bias, "input", fx.Node), + get_arg(dq_bias, "scale", float), + ) + return replace_with_op( + gm, anchor_node, self.replacement_op(), args, {}, anchor_node + ) + class MixedW8A32ConvPattern(QuantizationPattern): def partition_types(self) -> List[OpOverload]: @@ -839,6 +1283,57 @@ def get_anchors( def replacement_op(self) -> OpOverload: return torch.ops.cadence.quantized_w8a32_conv.default + def fuse(self, gm: fx.GraphModule, anchor_node: fx.Node) -> fx.Node | None: + if len(anchor_node.args) != 3 or len(anchor_node.kwargs) > 0: + return None + _arg1 = anchor_node.args[1] + dq_weight = ( + _arg1 + if isinstance(_arg1, fx.Node) and _arg1.target == DQ_PER_TENSOR + else None + ) + _arg2 = anchor_node.args[2] + dq_bias = ( + _arg2 + if isinstance(_arg2, fx.Node) and _arg2.target == DQ_PER_TENSOR + else None + ) + if dq_weight is None or dq_bias is None: + return None + input_node = anchor_node.args[0] + assert isinstance(input_node, fx.Node) + assert get_arg(anchor_node, "stride", list[int]) == [1] + assert get_arg(anchor_node, "padding", list[int]) == [0] + assert get_arg(anchor_node, "dilation", list[int]) == [1] + assert get_arg(anchor_node, "groups", int) == 1 + weight_q = get_arg(dq_weight, "input", fx.Node) + transposed_inputs = insert_node_with_meta( + gm, + torch.ops.aten.permute.default, + (input_node, [0, 2, 1]), + None, + anchor_node, + input_node, + ) + transposed_weights = insert_node_with_meta( + gm, + torch.ops.aten.permute.default, + (weight_q, [2, 0, 1]), + None, + anchor_node, + weight_q, + ) + args = ( + transposed_inputs, + transposed_weights, + get_arg(dq_weight, "scale", float), + get_arg(dq_bias, "input", fx.Node), + get_arg(dq_bias, "scale", float), + ) + return replace_with_op( + gm, anchor_node, self.replacement_op(), args, {}, anchor_node + ) + class MixedW8A32GruPattern(QuantizationPattern): def partition_types(self) -> List[OpOverload]: @@ -911,6 +1406,42 @@ def __init__(self, args, meta): def replacement_op(self) -> OpOverload: return torch.ops.cadence.quantized_w8a32_gru.default + def fuse(self, gm: fx.GraphModule, anchor_node: fx.Node) -> fx.Node | None: + if len(anchor_node.kwargs) > 0: + return None + params = anchor_node.args[2] + # GRU requires 4 weight/bias params: w_ih, w_hh, b_ih, b_hh + if not isinstance(params, (list, tuple)) or len(params) < 4: + return None + dq_w_ih = params[0] + if not isinstance(dq_w_ih, fx.Node) or dq_w_ih.target != DQ_PER_TENSOR: + return None + dq_w_hh = params[1] + if not isinstance(dq_w_hh, fx.Node) or dq_w_hh.target != DQ_PER_TENSOR: + return None + dq_b_ih = params[2] + if not isinstance(dq_b_ih, fx.Node) or dq_b_ih.target != DQ_PER_TENSOR: + return None + dq_b_hh = params[3] + if not isinstance(dq_b_hh, fx.Node) or dq_b_hh.target != DQ_PER_TENSOR: + return None + input_node = anchor_node.args[0] + hidden_node = anchor_node.args[1] + args = ( + input_node, + hidden_node, + get_arg(dq_w_ih, "input", fx.Node), + get_arg(dq_w_ih, "scale", float), + get_arg(dq_w_hh, "input", fx.Node), + get_arg(dq_w_hh, "scale", float), + get_arg(dq_b_ih, "input", fx.Node), + get_arg(dq_b_ih, "scale", float), + get_arg(dq_b_hh, "input", fx.Node), + ) + return replace_with_op( + gm, anchor_node, self.replacement_op(), args, {}, anchor_node + ) + class RmsNormPattern(QuantizationPattern): """Pattern that preserves rms_norm from decomposition without matching anything.""" diff --git a/backends/cadence/aot/quantizer/utils.py b/backends/cadence/aot/quantizer/utils.py index 51182a4ce92..f5773938f0a 100644 --- a/backends/cadence/aot/quantizer/utils.py +++ b/backends/cadence/aot/quantizer/utils.py @@ -118,7 +118,9 @@ def create_zero_bias_int32( bias_scale: float, ) -> fx.Node: """ - Creates a zero bias tensor with the shape of weight[0] + Creates a zero bias tensor with the shape of weight[0]. + Caller is responsible for setting the graph insertion point + (e.g. ``with gm.graph.inserting_before(node):``). """ try: attr_node = getattr(graph_module, weight_node.target) diff --git a/backends/cadence/aot/replace_ops.py b/backends/cadence/aot/replace_ops.py index 4b60feb2121..50112a4eb66 100644 --- a/backends/cadence/aot/replace_ops.py +++ b/backends/cadence/aot/replace_ops.py @@ -162,14 +162,31 @@ def targets(self) -> list[EdgeOpOverload]: def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: ns = exir_ops.edge if isinstance(node.target, EdgeOpOverload) else torch.ops + out_dtype = node.kwargs.get("out_dtype") + kwargs = {k: v for k, v in node.kwargs.items() if k != "out_dtype"} with node.graph.inserting_before(node): new_node = node.graph.call_function( ns.cadence.dequantize_per_tensor.default, args=node.args, - kwargs=node.kwargs, + kwargs=kwargs, ) - new_node.meta = node.meta - node.replace_all_uses_with(new_node) + new_node.meta = node.meta.copy() + if ( + out_dtype is not None + and out_dtype != torch.float32 + and "val" in new_node.meta + ): + new_node.meta["val"] = new_node.meta["val"].to(torch.float32) + if out_dtype is not None and out_dtype != torch.float32: + with node.graph.inserting_after(new_node): + cast_node = node.graph.call_function( + ns.aten.to.dtype, + args=(new_node, out_dtype), + ) + cast_node.meta = node.meta.copy() + node.replace_all_uses_with(cast_node) + else: + node.replace_all_uses_with(new_node) return True diff --git a/backends/cadence/aot/tests/test_replace_ops_passes.py b/backends/cadence/aot/tests/test_replace_ops_passes.py index 170da6deb09..a73ef02c996 100644 --- a/backends/cadence/aot/tests/test_replace_ops_passes.py +++ b/backends/cadence/aot/tests/test_replace_ops_passes.py @@ -1250,6 +1250,7 @@ def test_replace_conv1d_with_linear(self) -> None: inputs, "ReplaceTrivialConvWithLinear", rtol=2e-5, + atol=5e-6, ) # Assert that conv1d is trivially converted to linear @@ -1294,6 +1295,7 @@ def test_replace_conv2d_with_linear(self) -> None: inputs, "ReplaceTrivialConvWithLinear", rtol=2e-5, + atol=5e-6, ) # Assert that conv2d is trivially converted to linear diff --git a/backends/cortex_m/CMakeLists.txt b/backends/cortex_m/CMakeLists.txt index 876c65982e6..627406c1935 100644 --- a/backends/cortex_m/CMakeLists.txt +++ b/backends/cortex_m/CMakeLists.txt @@ -30,6 +30,10 @@ set(CMSIS_NN_LOCAL_PATH "" CACHE PATH "Path to existing local CMSIS-NN installation" ) +option(CORTEX_M_ENABLE_RUNTIME_CHECKS + "Enable additional Cortex-M runtime assertions and validation checks" + OFF +) # Try to find existing / local CMSIS-NN installation. This is useful for # debugging and testing with local changes. This is not common, as the CMSIS-NN @@ -107,6 +111,11 @@ target_link_libraries( PRIVATE executorch PRIVATE kernels_util_all_deps ) +target_compile_definitions( + cortex_m_kernels + PRIVATE + $<$:CORTEX_M_ENABLE_RUNTIME_CHECKS> +) # Include directories for cortex_m_kernels target_include_directories( diff --git a/backends/cortex_m/ops/cmsis_scratch_buffer_context.h b/backends/cortex_m/ops/cmsis_scratch_buffer_context.h index 4672f05e777..656309abcee 100644 --- a/backends/cortex_m/ops/cmsis_scratch_buffer_context.h +++ b/backends/cortex_m/ops/cmsis_scratch_buffer_context.h @@ -1,3 +1,4 @@ +// cppcheck-suppress-file unusedFunction /* * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. diff --git a/backends/cortex_m/ops/cortex_m_ops_common.h b/backends/cortex_m/ops/cortex_m_ops_common.h index 4c0f83d6eb6..2e3f49dd861 100644 --- a/backends/cortex_m/ops/cortex_m_ops_common.h +++ b/backends/cortex_m/ops/cortex_m_ops_common.h @@ -113,8 +113,7 @@ inline void validate_quantization_params( const int64_t shift2, const int64_t output_zero_point, const int64_t output_multiplier, - const int64_t output_shift, - Tensor& output) { + const int64_t output_shift) { validate_single_quant_params( zero_point1, multiplier1, shift1, "Single quant Input1"); validate_single_quant_params( @@ -346,6 +345,7 @@ inline bool prepare_cmsis_pool2d_config( // https://github.com/ARM-software/CMSIS-NN/blob/main/Include/arm_nnsupportfunctions.h#L1625 // multiplier: Range {ARM_NN_Q31_MIN + 1, Q32_MAX} // shift : Range {-31, 30} +// cppcheck-suppress unusedFunction inline bool validate_per_channel_quant_params( const Int64ArrayRef multipliers, const Int64ArrayRef shifts, diff --git a/backends/cortex_m/ops/op_dequantize_per_tensor.cpp b/backends/cortex_m/ops/op_dequantize_per_tensor.cpp index ca648f74695..136bce297b0 100644 --- a/backends/cortex_m/ops/op_dequantize_per_tensor.cpp +++ b/backends/cortex_m/ops/op_dequantize_per_tensor.cpp @@ -100,6 +100,7 @@ F dequantize_val(float scale, int32_t zero_point, Q qvalue) { } // namespace Tensor& dequantize_per_tensor_out( + // cppcheck-suppress constParameterReference KernelRuntimeContext& context, const Tensor& input, double scale, diff --git a/backends/cortex_m/ops/op_maximum.cpp b/backends/cortex_m/ops/op_maximum.cpp index fc76f5c8c48..936ef273684 100644 --- a/backends/cortex_m/ops/op_maximum.cpp +++ b/backends/cortex_m/ops/op_maximum.cpp @@ -1,5 +1,5 @@ /* - * Copyright 2025 Arm Limited and/or its affiliates. + * Copyright 2025-2026 Arm Limited and/or its affiliates. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. @@ -12,6 +12,7 @@ namespace native { using KernelRuntimeContext = torch::executor::KernelRuntimeContext; +// cppcheck-suppress unusedFunction Tensor& maximum_out( KernelRuntimeContext& context, const Tensor& input1, diff --git a/backends/cortex_m/ops/op_minimum.cpp b/backends/cortex_m/ops/op_minimum.cpp index 5a75cb8a1dc..3324a4e39d7 100644 --- a/backends/cortex_m/ops/op_minimum.cpp +++ b/backends/cortex_m/ops/op_minimum.cpp @@ -1,7 +1,7 @@ /* * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. - * Copyright 2025 Arm Limited and/or its affiliates. + * Copyright 2025-2026 Arm Limited and/or its affiliates. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. @@ -14,6 +14,7 @@ namespace native { using KernelRuntimeContext = torch::executor::KernelRuntimeContext; +// cppcheck-suppress unusedFunction Tensor& minimum_out( KernelRuntimeContext& context, const Tensor& input1, diff --git a/backends/cortex_m/ops/op_pad.cpp b/backends/cortex_m/ops/op_pad.cpp index e59f986c37d..57b5257873e 100644 --- a/backends/cortex_m/ops/op_pad.cpp +++ b/backends/cortex_m/ops/op_pad.cpp @@ -19,6 +19,7 @@ constexpr size_t kMaxSupportedDims = 4; } // namespace +// cppcheck-suppress unusedFunction Tensor& pad_out( KernelRuntimeContext& context, const Tensor& input, diff --git a/backends/cortex_m/ops/op_quantize_per_tensor.cpp b/backends/cortex_m/ops/op_quantize_per_tensor.cpp index 7809db379c7..d8bb34c6eb4 100644 --- a/backends/cortex_m/ops/op_quantize_per_tensor.cpp +++ b/backends/cortex_m/ops/op_quantize_per_tensor.cpp @@ -97,6 +97,7 @@ Q quantize_val( } // namespace Tensor& quantize_per_tensor_out( + // cppcheck-suppress constParameterReference KernelRuntimeContext& context, const Tensor& input, double scale, diff --git a/backends/cortex_m/ops/op_quantized_add.cpp b/backends/cortex_m/ops/op_quantized_add.cpp index f607977aa48..f93bb6c1be9 100644 --- a/backends/cortex_m/ops/op_quantized_add.cpp +++ b/backends/cortex_m/ops/op_quantized_add.cpp @@ -13,6 +13,7 @@ namespace cortex_m { namespace native { using KernelRuntimeContext = torch::executor::KernelRuntimeContext; +// cppcheck-suppress unusedFunction Tensor& quantized_add_out( KernelRuntimeContext& context, const Tensor& input1_int8, @@ -49,8 +50,7 @@ Tensor& quantized_add_out( input2_shift, output_zero_point, output_multiplier, - output_shift, - out); + output_shift); ET_LOG( Debug, diff --git a/backends/cortex_m/ops/op_quantized_avg_pool2d.cpp b/backends/cortex_m/ops/op_quantized_avg_pool2d.cpp index fc04edcc82b..0d22971f89b 100644 --- a/backends/cortex_m/ops/op_quantized_avg_pool2d.cpp +++ b/backends/cortex_m/ops/op_quantized_avg_pool2d.cpp @@ -12,6 +12,7 @@ namespace native { using KernelRuntimeContext = torch::executor::KernelRuntimeContext; +// cppcheck-suppress unusedFunction Tensor& quantized_avg_pool2d_out( KernelRuntimeContext& context, const Tensor& input, diff --git a/backends/cortex_m/ops/op_quantized_batch_matmul.cpp b/backends/cortex_m/ops/op_quantized_batch_matmul.cpp index e6bc5a949ce..fd0859e8b00 100644 --- a/backends/cortex_m/ops/op_quantized_batch_matmul.cpp +++ b/backends/cortex_m/ops/op_quantized_batch_matmul.cpp @@ -1,6 +1,7 @@ /* * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. + * Copyright 2026 Arm Limited and/or its affiliates. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. @@ -62,6 +63,7 @@ bool validate_batch_matmul_arguments( } // namespace +// cppcheck-suppress unusedFunction Tensor& quantized_batch_matmul_out( KernelRuntimeContext& context, const Tensor& lhs, @@ -71,6 +73,7 @@ Tensor& quantized_batch_matmul_out( int64_t output_offset, int64_t output_multiplier, int64_t output_shift, + const Tensor& scratch, Tensor& out) { if (!validate_batch_matmul_arguments(context, lhs, rhs_transposed, out)) { return out; @@ -100,25 +103,26 @@ Tensor& quantized_batch_matmul_out( quant_params.multiplier = static_cast(output_multiplier); quant_params.shift = static_cast(output_shift); - const int32_t buf_size = arm_fully_connected_s8_get_buffer_size(&out_dims); - cmsis_nn_context ctx; ctx.buf = nullptr; - ctx.size = 0; - - if (buf_size > 0) { - auto buffer_or_error = context.allocate_temp(buf_size); - if (!buffer_or_error.ok()) { - ET_LOG( - Error, - "quantized_batch_matmul: failed to allocate scratch buffer (%d bytes)", - buf_size); - context.fail(buffer_or_error.error()); - return out; - } - ctx.buf = buffer_or_error.get(); - ctx.size = buf_size; + ctx.size = scratch.nbytes(); + if (ctx.size > 0) { + ctx.buf = scratch.mutable_data_ptr(); + } + +#ifdef CORTEX_M_ENABLE_RUNTIME_CHECKS + const int32_t runtime_buffer_bytes = + arm_fully_connected_s8_get_buffer_size(&out_dims); + if (ctx.size != static_cast(runtime_buffer_bytes)) { + ET_LOG( + Error, + "quantized_batch_matmul: scratch buffer size incorrect - actual: (%d) needed: (%d)", + static_cast(ctx.size), + runtime_buffer_bytes); + context.fail(Error::Internal); + return out; } +#endif const arm_cmsis_nn_status status = arm_batch_matmul_s8( &ctx, diff --git a/backends/cortex_m/ops/op_quantized_conv2d.cpp b/backends/cortex_m/ops/op_quantized_conv2d.cpp index 7d4433690f6..3d4f19e10d0 100644 --- a/backends/cortex_m/ops/op_quantized_conv2d.cpp +++ b/backends/cortex_m/ops/op_quantized_conv2d.cpp @@ -98,6 +98,7 @@ bool validate_conv2d_arguments( } } // namespace +// cppcheck-suppress unusedFunction Tensor& quantized_conv2d_out( KernelRuntimeContext& context, const Tensor& input, @@ -112,6 +113,7 @@ Tensor& quantized_conv2d_out( const Tensor& requantize_shifts, const int64_t activation_min, const int64_t activation_max, + const Tensor& scratch, Tensor& out) { if (!validate_conv2d_arguments( context, @@ -182,31 +184,30 @@ Tensor& quantized_conv2d_out( cmsis_nn_context cmsis_context; cmsis_context.buf = nullptr; - cmsis_context.size = 0; + cmsis_context.size = scratch.nbytes(); + if (cmsis_context.size > 0) { + cmsis_context.buf = scratch.mutable_data_ptr(); + } - const int32_t buffer_bytes = arm_convolve_wrapper_s8_get_buffer_size( +#ifdef CORTEX_M_ENABLE_RUNTIME_CHECKS + const int32_t runtime_buffer_bytes = arm_convolve_wrapper_s8_get_buffer_size( &conv_params, &input_dims, &filter_dims, &output_dims); - if (buffer_bytes < 0) { + if (runtime_buffer_bytes < 0) { ET_LOG( Error, "quantized_conv2d_out: CMSIS-NN buffer size calculation failed"); context.fail(Error::Internal); return out; } - if (buffer_bytes > 0) { - auto buffer_or_error = - context.allocate_temp(buffer_bytes, kCortexMMveAlignment); - if (!buffer_or_error.ok()) { - ET_LOG( - Error, - "quantized_conv2d_out: failed to allocate scratch buffer (%d bytes, error %d)", - static_cast(buffer_bytes), - static_cast(buffer_or_error.error())); - context.fail(buffer_or_error.error()); - return out; - } - cmsis_context.buf = buffer_or_error.get(); - cmsis_context.size = buffer_bytes; + if (scratch.nbytes() != static_cast(runtime_buffer_bytes)) { + ET_LOG( + Error, + "quantized_conv2d_out: scratch buffer size incorrect - actual: (%d) needed: (%d)", + static_cast(scratch.nbytes()), + static_cast(runtime_buffer_bytes)); + context.fail(Error::Internal); + return out; } +#endif const arm_cmsis_nn_status status = arm_convolve_wrapper_s8( &cmsis_context, diff --git a/backends/cortex_m/ops/op_quantized_depthwise_conv2d.cpp b/backends/cortex_m/ops/op_quantized_depthwise_conv2d.cpp index 8dec61e0af1..a8e1fc21ed7 100644 --- a/backends/cortex_m/ops/op_quantized_depthwise_conv2d.cpp +++ b/backends/cortex_m/ops/op_quantized_depthwise_conv2d.cpp @@ -135,6 +135,7 @@ bool validate_depthwise_conv2d_arguments( } } // namespace +// cppcheck-suppress unusedFunction Tensor& quantized_depthwise_conv2d_out( KernelRuntimeContext& context, const Tensor& input, @@ -150,6 +151,7 @@ Tensor& quantized_depthwise_conv2d_out( const Tensor& requantize_shifts, const int64_t activation_min, const int64_t activation_max, + const Tensor& scratch, Tensor& out) { if (!validate_depthwise_conv2d_arguments( context, @@ -220,32 +222,32 @@ Tensor& quantized_depthwise_conv2d_out( cmsis_nn_context cmsis_context; cmsis_context.buf = nullptr; - cmsis_context.size = 0; + cmsis_context.size = scratch.nbytes(); + if (cmsis_context.size > 0) { + cmsis_context.buf = scratch.mutable_data_ptr(); + } - const int32_t buffer_bytes = arm_depthwise_conv_wrapper_s8_get_buffer_size( - &dw_conv_params, &input_dims, &filter_dims, &output_dims); - if (buffer_bytes < 0) { +#ifdef CORTEX_M_ENABLE_RUNTIME_CHECKS + const int32_t runtime_buffer_bytes = + arm_depthwise_conv_wrapper_s8_get_buffer_size( + &dw_conv_params, &input_dims, &filter_dims, &output_dims); + if (runtime_buffer_bytes < 0) { ET_LOG( Error, "quantized_depthwise_conv2d_out: CMSIS-NN buffer size calculation failed"); context.fail(Error::Internal); return out; } - - auto buffer_or_error = context.allocate_temp( - static_cast(buffer_bytes), kCortexMMveAlignment); - if (!buffer_or_error.ok()) { + if (scratch.nbytes() != static_cast(runtime_buffer_bytes)) { ET_LOG( Error, - "quantized_depthwise_conv2d_out: failed to allocate scratch buffer (%d bytes, error %d)", - static_cast(buffer_bytes), - static_cast(buffer_or_error.error())); - context.fail(buffer_or_error.error()); + "quantized_depthwise_conv2d_out: scratch buffer size incorrect - actual: (%d) needed: (%d)", + static_cast(scratch.nbytes()), + static_cast(runtime_buffer_bytes)); + context.fail(Error::Internal); return out; } - cmsis_context.buf = buffer_or_error.get(); - cmsis_context.size = buffer_bytes; - +#endif const arm_cmsis_nn_status status = arm_depthwise_conv_wrapper_s8( &cmsis_context, &dw_conv_params, diff --git a/backends/cortex_m/ops/op_quantized_linear.cpp b/backends/cortex_m/ops/op_quantized_linear.cpp index 5d018cbc0c4..7448058de8e 100644 --- a/backends/cortex_m/ops/op_quantized_linear.cpp +++ b/backends/cortex_m/ops/op_quantized_linear.cpp @@ -13,6 +13,7 @@ namespace cortex_m { namespace native { using KernelRuntimeContext = torch::executor::KernelRuntimeContext; +// cppcheck-suppress unusedFunction Tensor& quantized_linear_out( KernelRuntimeContext& context, const Tensor& input, diff --git a/backends/cortex_m/ops/op_quantized_max_pool2d.cpp b/backends/cortex_m/ops/op_quantized_max_pool2d.cpp index 181a29c1b65..ca1b00ff340 100644 --- a/backends/cortex_m/ops/op_quantized_max_pool2d.cpp +++ b/backends/cortex_m/ops/op_quantized_max_pool2d.cpp @@ -10,6 +10,7 @@ namespace cortex_m { namespace native { +// cppcheck-suppress unusedFunction Tensor& quantized_max_pool2d_out( KernelRuntimeContext& context, const Tensor& input, diff --git a/backends/cortex_m/ops/op_quantized_mul.cpp b/backends/cortex_m/ops/op_quantized_mul.cpp index 524e74a6b9f..93ce2303d64 100644 --- a/backends/cortex_m/ops/op_quantized_mul.cpp +++ b/backends/cortex_m/ops/op_quantized_mul.cpp @@ -18,6 +18,7 @@ constexpr int32_t kInt8ActivationMax = std::numeric_limits::max(); using KernelRuntimeContext = torch::executor::KernelRuntimeContext; +// cppcheck-suppress unusedFunction Tensor& quantized_mul_out( KernelRuntimeContext& context, const Tensor& input1_int8, @@ -50,8 +51,7 @@ Tensor& quantized_mul_out( kZeroShift, output_zero_point, output_multiplier, - output_shift, - out); + output_shift); // Extract quantization parameters int8_t* input1_ptr = input1_int8.data_ptr(); diff --git a/backends/cortex_m/ops/op_quantized_transpose_conv2d.cpp b/backends/cortex_m/ops/op_quantized_transpose_conv2d.cpp index e3f6135c7b9..e7ecbc7c7b4 100644 --- a/backends/cortex_m/ops/op_quantized_transpose_conv2d.cpp +++ b/backends/cortex_m/ops/op_quantized_transpose_conv2d.cpp @@ -1,6 +1,7 @@ /* * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. + * Copyright 2026 Arm Limited and/or its affiliates. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. @@ -82,6 +83,7 @@ bool validate_transpose_conv2d_arguments( } } // namespace +// cppcheck-suppress unusedFunction Tensor& quantized_transpose_conv2d_out( KernelRuntimeContext& context, const Tensor& input, @@ -97,6 +99,8 @@ Tensor& quantized_transpose_conv2d_out( const Tensor& requantize_shifts, const int64_t activation_min, const int64_t activation_max, + const Tensor& scratch, + const Tensor& output_scratch, Tensor& out) { if (!validate_transpose_conv2d_arguments( context, @@ -179,44 +183,43 @@ Tensor& quantized_transpose_conv2d_out( cmsis_nn_context cmsis_context; cmsis_context.buf = nullptr; - cmsis_context.size = 0; + cmsis_context.size = scratch.nbytes(); + if (cmsis_context.size > 0) { + cmsis_context.buf = scratch.mutable_data_ptr(); + } cmsis_nn_context output_context; output_context.buf = nullptr; - output_context.size = 0; - + output_context.size = output_scratch.nbytes(); + if (output_context.size > 0) { + output_context.buf = output_scratch.mutable_data_ptr(); + } +#ifdef CORTEX_M_ENABLE_RUNTIME_CHECKS const int32_t buffer_bytes = arm_transpose_conv_s8_get_buffer_size( &transpose_conv_params, &input_dims, &filter_dims, &output_dims); - auto buffer_or_error = context.allocate_temp( - static_cast(buffer_bytes), kCortexMMveAlignment); - if (!buffer_or_error.ok()) { + if (scratch.nbytes() != static_cast(buffer_bytes)) { ET_LOG( Error, - "quantized_transpose_conv2d_out: failed to allocate scratch buffer (%d bytes, error %d)", - buffer_bytes, - static_cast(buffer_or_error.error())); - context.fail(buffer_or_error.error()); + "quantized_transpose_conv2d_out: scratch buffer size incorrect - actual: (%d) needed: (%d)", + static_cast(scratch.nbytes()), + buffer_bytes); + context.fail(Error::Internal); return out; } - cmsis_context.buf = buffer_or_error.get(); - cmsis_context.size = buffer_bytes; const int32_t output_buffer_bytes = arm_transpose_conv_s8_get_reverse_conv_buffer_size( &transpose_conv_params, &input_dims, &filter_dims); - auto output_buffer_or_error = context.allocate_temp( - static_cast(output_buffer_bytes), kCortexMMveAlignment); - if (!output_buffer_or_error.ok()) { + if (output_scratch.nbytes() != static_cast(output_buffer_bytes)) { ET_LOG( Error, - "quantized_transpose_conv2d_out: failed to allocate output scratch buffer (%d bytes, error %d)", - output_buffer_bytes, - static_cast(output_buffer_or_error.error())); - context.fail(output_buffer_or_error.error()); + "quantized_transpose_conv2d_out: output scratch buffer size incorrect - actual: (%d) needed: (%d)", + static_cast(output_scratch.nbytes()), + output_buffer_bytes); + context.fail(Error::Internal); return out; } - output_context.buf = output_buffer_or_error.get(); - output_context.size = output_buffer_bytes; +#endif const arm_cmsis_nn_status status = arm_transpose_conv_wrapper_s8( &cmsis_context, diff --git a/backends/cortex_m/ops/op_softmax.cpp b/backends/cortex_m/ops/op_softmax.cpp index c07a538db84..97d78d07a05 100644 --- a/backends/cortex_m/ops/op_softmax.cpp +++ b/backends/cortex_m/ops/op_softmax.cpp @@ -36,6 +36,7 @@ inline int64_t normalize_dim(const Tensor& tensor, int64_t dim) { } // namespace +// cppcheck-suppress unusedFunction Tensor& softmax_out( KernelRuntimeContext& context, const Tensor& input, diff --git a/backends/cortex_m/ops/op_transpose.cpp b/backends/cortex_m/ops/op_transpose.cpp index 7fcbc034283..9ef144296b7 100644 --- a/backends/cortex_m/ops/op_transpose.cpp +++ b/backends/cortex_m/ops/op_transpose.cpp @@ -22,6 +22,7 @@ constexpr size_t kMaxSupportedDims = 4; } // namespace +// cppcheck-suppress unusedFunction Tensor& transpose_out( KernelRuntimeContext& context, const Tensor& input, diff --git a/backends/cortex_m/ops/operators.py b/backends/cortex_m/ops/operators.py index 2c35ed8730b..d4393bc7ada 100644 --- a/backends/cortex_m/ops/operators.py +++ b/backends/cortex_m/ops/operators.py @@ -271,13 +271,15 @@ def quantized_mul_impl( "quantized_batch_matmul(" "Tensor lhs, int lhs_zero_point, " "Tensor rhs_transposed, int rhs_zero_point, " - "int output_zero_point, int output_multiplier, int output_shift) -> Tensor" + "int output_zero_point, int output_multiplier, int output_shift, " + "Tensor scratch) -> Tensor" ) lib.define( "quantized_batch_matmul.out(" "Tensor lhs, int lhs_zero_point, " "Tensor rhs_transposed, int rhs_zero_point, " "int output_zero_point, int output_multiplier, int output_shift, " + "Tensor scratch, " "*, Tensor(a!) out) -> Tensor(a!)" ) @@ -291,6 +293,7 @@ def quantized_batch_matmul_meta( output_zero_point: int, output_multiplier: int, output_shift: int, + scratch: torch.Tensor, ) -> torch.Tensor: batch, lhs_rows, inner = lhs.shape batch_rhs, rhs_cols, inner_rhs = rhs_transposed.shape @@ -307,6 +310,7 @@ def quantized_batch_matmul_impl( output_zero_point: int, output_multiplier: int, output_shift: int, + scratch: torch.Tensor, ) -> torch.Tensor: # Offsets are negated zero points (CMSIS-NN convention) lhs_fp = lhs.to(torch.float32) + float(lhs_zero_point) @@ -638,7 +642,8 @@ def pad_impl( "Tensor requantize_multipliers, " "Tensor requantize_shifts, " "int activation_min, " - "int activation_max" + "int activation_max, " + "Tensor scratch" ") -> Tensor" ) @@ -657,6 +662,7 @@ def pad_impl( "Tensor requantize_shifts, " "int activation_min, " "int activation_max, " + "Tensor scratch, " "*, Tensor(a!) out" ") -> Tensor(a!)" ) @@ -733,6 +739,7 @@ def quantized_conv2d_meta( requantize_shifts: torch.Tensor, activation_min: int, activation_max: int, + scratch: torch.Tensor, ) -> torch.Tensor: stride_vals = list(stride) padding_vals = list(padding) @@ -762,6 +769,7 @@ def quantized_conv2d_impl( requantize_shifts: torch.Tensor, activation_min: int, activation_max: int, + scratch: torch.Tensor, ) -> torch.Tensor: if input.dim() != 4 or weight.dim() != 4: raise RuntimeError("quantized_conv2d expects 4D input and weight tensors") @@ -830,7 +838,8 @@ def quantized_conv2d_impl( "Tensor requantize_multipliers, " "Tensor requantize_shifts, " "int activation_min, " - "int activation_max" + "int activation_max, " + "Tensor scratch" ") -> Tensor" ) @@ -850,6 +859,7 @@ def quantized_conv2d_impl( "Tensor requantize_shifts, " "int activation_min, " "int activation_max, " + "Tensor scratch, " "*, Tensor(a!) out" ") -> Tensor(a!)" ) @@ -870,6 +880,7 @@ def quantized_depthwise_conv2d_meta( requantize_shifts: torch.Tensor, activation_min: int, activation_max: int, + scratch: torch.Tensor, ) -> torch.Tensor: stride_vals = list(stride) padding_vals = list(padding) @@ -900,6 +911,7 @@ def quantized_depthwise_conv2d_impl( requantize_shifts: torch.Tensor, activation_min: int, activation_max: int, + scratch: torch.Tensor, ) -> torch.Tensor: if input.dim() != 4 or weight.dim() != 4: raise RuntimeError( @@ -973,7 +985,9 @@ def quantized_depthwise_conv2d_impl( "Tensor requantize_multipliers, " "Tensor requantize_shifts, " "int activation_min, " - "int activation_max" + "int activation_max, " + "Tensor scratch, " + "Tensor output_scratch" ") -> Tensor" ) @@ -992,6 +1006,8 @@ def quantized_depthwise_conv2d_impl( "Tensor requantize_shifts, " "int activation_min, " "int activation_max, " + "Tensor scratch, " + "Tensor output_scratch, " "*, Tensor(a!) out) -> Tensor(a!)" ) @@ -1057,6 +1073,8 @@ def quantized_transpose_conv2d_meta( requantize_shifts: torch.Tensor, activation_min: int, activation_max: int, + scratch: torch.Tensor, + output_scratch: torch.Tensor, ) -> torch.Tensor: stride_vals = list(stride) padding_vals = list(padding) @@ -1095,6 +1113,8 @@ def quantized_transpose_conv2d_impl( requantize_shifts: torch.Tensor, activation_min: int, activation_max: int, + scratch: torch.Tensor, + output_scratch: torch.Tensor, ) -> torch.Tensor: """ Reference implementation of quantized transposed convolution. diff --git a/backends/cortex_m/ops/operators.yaml b/backends/cortex_m/ops/operators.yaml index e0ebbfab868..8db109dea43 100644 --- a/backends/cortex_m/ops/operators.yaml +++ b/backends/cortex_m/ops/operators.yaml @@ -65,19 +65,20 @@ - arg_meta: null kernel_name: cortex_m::pad_out -- func: cortex_m::quantized_conv2d.out(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int input_offset, int output_offset, Tensor requantize_multipliers, Tensor requantize_shifts, int activation_min, int activation_max, *, Tensor(a!) out) -> Tensor(a!) +- func: cortex_m::quantized_conv2d.out(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int input_offset, int output_offset, Tensor requantize_multipliers, Tensor requantize_shifts, int activation_min, int activation_max, Tensor scratch, *, Tensor(a!) out) -> Tensor(a!) variants: function kernels: - arg_meta: null kernel_name: cortex_m::quantized_conv2d_out -- func: cortex_m::quantized_depthwise_conv2d.out(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int depth_multiplier, int input_offset, int output_offset, Tensor requantize_multipliers, Tensor requantize_shifts, int activation_min, int activation_max, *, Tensor(a!) out) -> Tensor(a!) + +- func: cortex_m::quantized_depthwise_conv2d.out(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int depth_multiplier, int input_offset, int output_offset, Tensor requantize_multipliers, Tensor requantize_shifts, int activation_min, int activation_max, Tensor scratch, *, Tensor(a!) out) -> Tensor(a!) variants: function kernels: - arg_meta: null kernel_name: cortex_m::quantized_depthwise_conv2d_out -- func: cortex_m::quantized_transpose_conv2d.out(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] output_padding, int[] dilation, int input_offset, int output_offset, Tensor requantize_multipliers, Tensor requantize_shifts, int activation_min, int activation_max, *, Tensor(a!) out) -> Tensor(a!) +- func: cortex_m::quantized_transpose_conv2d.out(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] output_padding, int[] dilation, int input_offset, int output_offset, Tensor requantize_multipliers, Tensor requantize_shifts, int activation_min, int activation_max, Tensor scratch, Tensor output_scratch, *, Tensor(a!) out) -> Tensor(a!) variants: function kernels: - arg_meta: null @@ -94,7 +95,7 @@ - arg_meta: null kernel_name: cortex_m::quantized_max_pool2d_out -- func: cortex_m::quantized_batch_matmul.out(Tensor lhs, int lhs_zero_point, Tensor rhs_transposed, int rhs_zero_point, int output_zero_point, int output_multiplier, int output_shift, *, Tensor(a!) out) -> Tensor(a!) +- func: cortex_m::quantized_batch_matmul.out(Tensor lhs, int lhs_zero_point, Tensor rhs_transposed, int rhs_zero_point, int output_zero_point, int output_multiplier, int output_shift, Tensor scratch, *, Tensor(a!) out) -> Tensor(a!) variants: function kernels: - arg_meta: null diff --git a/backends/cortex_m/passes/BUCK b/backends/cortex_m/passes/BUCK index 4e49c8cd319..f1b7b9a201d 100644 --- a/backends/cortex_m/passes/BUCK +++ b/backends/cortex_m/passes/BUCK @@ -36,6 +36,7 @@ fbcode_target(_kind = runtime.python_library, "decompose_hardswish_pass.py", "decompose_mean_pass.py", "quantized_clamp_activation_pass.py", + "scratch_buffer_sizes.py", ], deps=[ "//caffe2:torch", diff --git a/backends/cortex_m/passes/__init__.py b/backends/cortex_m/passes/__init__.py index 92179ec6654..c379461949f 100644 --- a/backends/cortex_m/passes/__init__.py +++ b/backends/cortex_m/passes/__init__.py @@ -33,6 +33,7 @@ def _ensure_cortex_m_dependencies() -> None: _ensure_cortex_m_dependencies() +from .cortex_m_pass import CortexMPass # noqa # usort: skip from .activation_fusion_pass import ActivationFusionPass # noqa from .clamp_hardswish_pass import ClampHardswishPass # noqa from .convert_to_cortex_m_pass import ConvertToCortexMPass # noqa diff --git a/backends/cortex_m/passes/convert_to_cortex_m_pass.py b/backends/cortex_m/passes/convert_to_cortex_m_pass.py index 418f6cd63ff..5704645caf8 100644 --- a/backends/cortex_m/passes/convert_to_cortex_m_pass.py +++ b/backends/cortex_m/passes/convert_to_cortex_m_pass.py @@ -6,25 +6,32 @@ # LICENSE file in the root directory of this source tree. import executorch.backends.cortex_m.ops.operators # noqa +import executorch.exir as exir import torch import torch.fx from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor + +from executorch.backends.cortex_m.passes.cortex_m_pass import CortexMPass from executorch.backends.cortex_m.passes.passes_utils import quantize_multiplier_aot +from executorch.backends.cortex_m.passes.scratch_buffer_sizes import ( + required_cmsis_nn_buffer_sizes, +) from executorch.backends.transforms.utils import ( create_constant_placeholder, get_param_tensor, is_param_node, ) - -from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.passes import make_alloc_node +from torch._subclasses.fake_tensor import FakeTensorMode + from torch.export.graph_signature import InputKind from torch.fx.passes.infra.pass_manager import PassResult -class ConvertToCortexMPass(XNNPACKPass): +class ConvertToCortexMPass(CortexMPass): """ Cortex-M backend pass for replacing supported quantized kernels with Cortex-M accelerated kernels. @@ -33,6 +40,15 @@ class ConvertToCortexMPass(XNNPACKPass): by call_operator. """ + def _create_uninitialized_alloc_node(self): + """Create an unitialized alloc node to be initialize at a later point.""" + with FakeTensorMode() as mode: + return make_alloc_node( + self.exported_program.graph_module, + mode.from_tensor(torch.empty(0)), + None, + ) + def _compute_kernel_sum(self, weights, bias, input_offset, weight_offset): """ Computes the precomputed kernel sum term (bias optional) @@ -238,6 +254,9 @@ def _get_convolution_replacement(self, node): torch.tensor(quantized_shifts, dtype=torch.int32), ) + with node.graph.inserting_before(node): + scratch = self._create_uninitialized_alloc_node() + if use_depthwise_conv: # Compute depth_multiplier for depthwise convolution # For depthwise: output_channels = input_channels * depth_multiplier @@ -263,6 +282,7 @@ def _get_convolution_replacement(self, node): quantized_shift_tensor, output_qmin, output_qmax, + scratch, ) return exir_ops.edge.cortex_m.quantized_depthwise_conv2d.default, new_args else: @@ -280,9 +300,36 @@ def _get_convolution_replacement(self, node): quantized_shift_tensor, output_qmin, output_qmax, + scratch, ) return exir_ops.edge.cortex_m.quantized_conv2d.default, new_args + def _initialize_alloc_node_size(self, node: torch.fx.Node) -> None: + """For nodes with a registered buffer size function for node.target, set the buffer sizes + of the last n args, which should be exir.memory.alloc nodes. For nodes without a + registered function, do nothing. + """ + + scratch_buffer_sizes = required_cmsis_nn_buffer_sizes( + node, self.target_config.backend + ) + if scratch_buffer_sizes is None: + return + + # Assume that scratch_buffer_sizes are given from left to right in the call signature of node.target. + for i, scratch_buffer_size in enumerate(reversed(scratch_buffer_sizes)): + scratch_arg = node.args[-(i + 1)] + if ( + not isinstance(scratch_arg, torch.fx.Node) + or scratch_arg.target != exir.memory.alloc + ): + raise RuntimeError( + f"Expected scratch alloc node as final argument(s) for {node.target}, got {scratch_arg}." + ) + + # buffer size is given in bytes, always use uint8 as dtype. + scratch_arg.args = (((scratch_buffer_size,), torch.uint8),) + def _get_transpose_conv2d_replacement(self, node): """ Transform aten.convolution with transposed=True to cortex_m.quantized_transpose_conv2d @@ -363,6 +410,10 @@ def _get_transpose_conv2d_replacement(self, node): torch.tensor(quantized_shifts, dtype=torch.int32), ) + with node.graph.inserting_before(node): + scratch = self._create_uninitialized_alloc_node() + output_scratch = self._create_uninitialized_alloc_node() + new_args = ( x, weight_nhwc, @@ -377,6 +428,8 @@ def _get_transpose_conv2d_replacement(self, node): quantized_shift_tensor, output_qmin, output_qmax, + scratch, + output_scratch, ) return exir_ops.edge.cortex_m.quantized_transpose_conv2d.default, new_args @@ -415,6 +468,9 @@ def _get_bmm_replacement(self, node): args=(rhs_node, [0, 2, 1]), ) + with node.graph.inserting_before(node): + scratch = self._create_uninitialized_alloc_node() + args = ( lhs_node, -lhs_zp, @@ -423,6 +479,7 @@ def _get_bmm_replacement(self, node): output_zp, output_mult, output_shift, + scratch, ) return exir_ops.edge.cortex_m.quantized_batch_matmul.default, args @@ -459,6 +516,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: args=args, kwargs={}, ) + self._initialize_alloc_node_size(cortex_m_op) node.replace_all_uses_with(cortex_m_op) graph_module.graph.erase_node(node) diff --git a/backends/cortex_m/passes/scratch_buffer_sizes.py b/backends/cortex_m/passes/scratch_buffer_sizes.py new file mode 100644 index 00000000000..36f3f8bbc17 --- /dev/null +++ b/backends/cortex_m/passes/scratch_buffer_sizes.py @@ -0,0 +1,266 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from collections.abc import Callable +from typing import Any, cast + +import cmsis_nn # type: ignore[import-not-found, import-untyped] +import executorch.backends.cortex_m.ops.operators # noqa + +import torch +import torch.fx + +from executorch.exir.dialects._ops import ops as exir_ops + +BufferSizeFunction = Callable[[cmsis_nn.Backend, torch.fx.Node], list[int]] + + +def _tensor_from_node(node: torch.fx.Node) -> torch.Tensor: + if "val" in node.meta: + return node.meta["val"] + elif node.op == "call_function": + args = ( + _tensor_from_node(arg) if isinstance(arg, torch.fx.Node) else arg + for arg in node.args + ) + return node.target(*args, **node.kwargs) # type: ignore[operator] + else: + raise RuntimeError("Encountered non-call_function without 'val' meta.") + + +def _shape_from_node(node: torch.fx.Node) -> torch.Size: + return _tensor_from_node(node).shape + + +def _get_common_conv_buffer_size_inputs( + conv_node: torch.fx.Node, + *, + stride_arg_idx: int = 3, + padding_arg_idx: int = 4, + dilation_arg_idx: int = 5, +) -> tuple[ + list[int], + list[int], + list[int], + list[int], + list[int], + list[int], +]: + x = cast(torch.fx.Node, conv_node.args[0]) + weight = cast(torch.fx.Node, conv_node.args[1]) + stride = cast(list[int], conv_node.args[stride_arg_idx]) + padding = cast(list[int], conv_node.args[padding_arg_idx]) + dilation = cast(list[int], conv_node.args[dilation_arg_idx]) + + # Input is NCHW (PyTorch); CMSIS-NN wants NHWC dims. + n, c_in, height, width = _shape_from_node(x) + + weight_shape = _shape_from_node(weight) + + # Output is NCHW; convert to NHWC dims. + out_n, out_c, out_h, out_w = _shape_from_node(conv_node) + + input_nhwc = [n, height, width, c_in] + output_nhwc = [out_n, out_h, out_w, out_c] + stride_hw = [int(stride[0]), int(stride[1])] + padding_hw = [int(padding[0]), int(padding[1])] + dilation_hw = [int(dilation[0]), int(dilation[1])] + + return ( + input_nhwc, + list(weight_shape), + output_nhwc, + stride_hw, + padding_hw, + dilation_hw, + ) + + +def cmsis_nn_conv_buffer_size( + backend: cmsis_nn.Backend, + conv_node: torch.fx.Node, +) -> list[int]: + ( + input_nhwc, + weight_shape, + output_nhwc, + stride_hw, + padding_hw, + dilation_hw, + ) = _get_common_conv_buffer_size_inputs(conv_node=conv_node) + input_offset = cast(int, conv_node.args[6]) + output_offset = cast(int, conv_node.args[7]) + output_qmin = cast(int, conv_node.args[10]) + output_qmax = cast(int, conv_node.args[11]) + + # Weight is in OHWI layout after conversion. + c_out, kernel_h, kernel_w, c_in = weight_shape + filter_nhwc = [c_out, kernel_h, kernel_w, c_in] + + return [ + int( + cmsis_nn.convolve_wrapper_buffer_size( + backend, + cmsis_nn.DataType.A8W8, + input_nhwc=input_nhwc, + filter_nhwc=filter_nhwc, + output_nhwc=output_nhwc, + padding_hw=padding_hw, + stride_hw=stride_hw, + dilation_hw=dilation_hw, + input_offset=input_offset, + output_offset=output_offset, + activation_min=output_qmin, + activation_max=output_qmax, + ) + ) + ] + + +def cmsis_nn_depthwise_conv_buffer_size( + backend: cmsis_nn.Backend, + conv_node: torch.fx.Node, +) -> list[int]: + ( + input_nhwc, + weight_shape, + output_nhwc, + stride_hw, + padding_hw, + dilation_hw, + ) = _get_common_conv_buffer_size_inputs(conv_node=conv_node) + depth_multiplier = cast(int, conv_node.args[6]) + input_offset = cast(int, conv_node.args[7]) + output_offset = cast(int, conv_node.args[8]) + output_qmin = cast(int, conv_node.args[11]) + output_qmax = cast(int, conv_node.args[12]) + + # Weight is in IHWO layout after conversion. + _, kernel_h, kernel_w, c_out = weight_shape + filter_nhwc = [c_out, kernel_h, kernel_w, 1] + + return [ + int( + cmsis_nn.depthwise_conv_wrapper_buffer_size( + backend, + cmsis_nn.DataType.A8W8, + input_nhwc=input_nhwc, + filter_nhwc=filter_nhwc, + output_nhwc=output_nhwc, + padding_hw=padding_hw, + stride_hw=stride_hw, + dilation_hw=dilation_hw, + ch_mult=depth_multiplier, + input_offset=input_offset, + output_offset=output_offset, + activation_min=output_qmin, + activation_max=output_qmax, + ) + ) + ] + + +def cmsis_nn_batch_matmul_buffer_size( + backend: cmsis_nn.Backend, + matmul_node: torch.fx.Node, +) -> list[int]: + rhs_transposed = cast(torch.fx.Node, matmul_node.args[2]) + rhs_shape = _shape_from_node(rhs_transposed) + + _, rhs_cols, inner = rhs_shape + + return [ + int( + cmsis_nn.fully_connected_buffer_size( + backend, + cmsis_nn.DataType.A8W8, + filter_nhwc=[inner, -1, -1, rhs_cols], # H and W values are unused. + ) + ) + ] + + +def cmsis_nn_transpose_conv_buffer_size( + backend: cmsis_nn.Backend, + conv_node: torch.fx.Node, +) -> list[int]: + ( + input_nhwc, + weight_shape, + output_nhwc, + stride_hw, + padding_hw, + dilation_hw, + ) = _get_common_conv_buffer_size_inputs( + conv_node=conv_node, + stride_arg_idx=3, + padding_arg_idx=4, + dilation_arg_idx=6, + ) + output_padding = cast(list[int], conv_node.args[5]) + input_offset = cast(int, conv_node.args[7]) + output_offset = cast(int, conv_node.args[8]) + output_qmin = cast(int, conv_node.args[11]) + output_qmax = cast(int, conv_node.args[12]) + c_out, kernel_h, kernel_w, kernel_c_in = weight_shape + filter_nhwc = [c_out, kernel_h, kernel_w, kernel_c_in] + padding_offsets_hw = [int(output_padding[0]), int(output_padding[1])] + + return [ + int( + cmsis_nn.transpose_conv_buffer_size( + backend, + cmsis_nn.DataType.A8W8, + input_nhwc=input_nhwc, + filter_nhwc=filter_nhwc, + output_nhwc=output_nhwc, + padding_hw=padding_hw, + stride_hw=stride_hw, + dilation_hw=dilation_hw, + padding_offsets_hw=padding_offsets_hw, + input_offset=input_offset, + output_offset=output_offset, + activation_min=output_qmin, + activation_max=output_qmax, + ) + ), + int( + cmsis_nn.transpose_conv_reverse_conv_buffer_size( + backend, + cmsis_nn.DataType.A8W8, + input_nhwc=input_nhwc, + filter_nhwc=filter_nhwc, + padding_hw=padding_hw, + stride_hw=stride_hw, + dilation_hw=dilation_hw, + padding_offsets_hw=padding_offsets_hw, + input_offset=input_offset, + output_offset=output_offset, + activation_min=output_qmin, + activation_max=output_qmax, + ) + ), + ] + + +_target_to_buffer_sizes_registry: dict[Any, BufferSizeFunction] = { + exir_ops.edge.cortex_m.quantized_conv2d.default: cmsis_nn_conv_buffer_size, + exir_ops.edge.cortex_m.quantized_depthwise_conv2d.default: cmsis_nn_depthwise_conv_buffer_size, + exir_ops.edge.cortex_m.quantized_batch_matmul.default: cmsis_nn_batch_matmul_buffer_size, + exir_ops.edge.cortex_m.quantized_transpose_conv2d.default: cmsis_nn_transpose_conv_buffer_size, +} + + +def required_cmsis_nn_buffer_sizes( + node: torch.fx.Node, backend: cmsis_nn.Backend +) -> list[int] | None: + """Returns a sequence of scratch buffer sizes required by node, in bytes. + If no function is registered to compute this for the target of the node, return None. + """ + if node.target not in _target_to_buffer_sizes_registry: + return None + + buffer_size_function = _target_to_buffer_sizes_registry[node.target] + return buffer_size_function(backend, node) diff --git a/backends/cortex_m/test/build_test_runner.sh b/backends/cortex_m/test/build_test_runner.sh index bdca1a21e7c..a67c5a907a4 100755 --- a/backends/cortex_m/test/build_test_runner.sh +++ b/backends/cortex_m/test/build_test_runner.sh @@ -28,7 +28,7 @@ fi script_dir=$(realpath "$(dirname "${BASH_SOURCE[0]}")") et_root_dir=$(realpath "${script_dir}/../../..") build_executorch="${et_root_dir}/backends/arm/scripts/build_executorch.sh" -${build_executorch} --devtools --target_cpu="${target_cpu}" +${build_executorch} --devtools --target_cpu="${target_cpu}" --cmake-args="-DCORTEX_M_ENABLE_RUNTIME_CHECKS=ON" # Build executor runner with selected aten ops and semi hosting build_dir="${et_root_dir}/arm_test" @@ -48,4 +48,4 @@ aten::unsqueeze_copy.out,\ aten::select_copy.int_out,\ aten::amax.out" -${build_executor_runner} --pte=semihosting --bundleio --target="${target}" --output="${build_root_test_dir}" --select_ops_list="${select_ops_list}" --extra_build_flags="-DET_ATOL=5.0 -DET_RTOL=1.0" +${build_executor_runner} --pte=semihosting --bundleio --target="${target}" --output="${build_root_test_dir}" --select_ops_list="${select_ops_list}" --extra_build_flags="-DET_ATOL=5.0 -DET_RTOL=1.0 -DET_ARM_BAREMETAL_SCRATCH_TEMP_ALLOCATOR_POOL_SIZE=0" diff --git a/backends/cortex_m/test/models/test_silero_vad.py b/backends/cortex_m/test/models/test_silero_vad.py new file mode 100644 index 00000000000..27b958627bb --- /dev/null +++ b/backends/cortex_m/test/models/test_silero_vad.py @@ -0,0 +1,94 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from executorch.backends.arm.test.common import parametrize +from executorch.backends.cortex_m.test.tester import CortexMTester, McuTestCase +from executorch.examples.models.silero_vad.export_silero_vad import ( + CONTEXT_SIZE, + HIDDEN_DIM, + SileroVAD16k, + WINDOW_SIZE, +) + + +ops_before_transforms: dict[str, int] = { + "executorch_exir_dialects_edge__ops_aten_abs_default": 2, + "executorch_exir_dialects_edge__ops_aten_add_Tensor": 3, + "executorch_exir_dialects_edge__ops_aten_arange_start_step": 1, + "executorch_exir_dialects_edge__ops_aten_cat_default": 1, + "executorch_exir_dialects_edge__ops_aten_convolution_default": 6, + "executorch_exir_dialects_edge__ops_aten_index_Tensor": 1, + "executorch_exir_dialects_edge__ops_aten_linear_default": 2, + "executorch_exir_dialects_edge__ops_aten_mean_dim": 1, + "executorch_exir_dialects_edge__ops_aten_mul_Tensor": 3, + "executorch_exir_dialects_edge__ops_aten_pow_Tensor_Scalar": 2, + "executorch_exir_dialects_edge__ops_aten_relu_default": 5, + "executorch_exir_dialects_edge__ops_aten_select_copy_int": 2, + "executorch_exir_dialects_edge__ops_aten_sigmoid_default": 4, + "executorch_exir_dialects_edge__ops_aten_slice_copy_Tensor": 2, + "executorch_exir_dialects_edge__ops_aten_split_with_sizes_copy_default": 1, + "executorch_exir_dialects_edge__ops_aten_sqrt_default": 1, + "executorch_exir_dialects_edge__ops_aten_squeeze_copy_dims": 2, + "executorch_exir_dialects_edge__ops_aten_sub_Tensor": 2, + "executorch_exir_dialects_edge__ops_aten_tanh_default": 2, + "executorch_exir_dialects_edge__ops_aten_unsqueeze_copy_default": 2, + "executorch_exir_dialects_edge__ops_aten_view_copy_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 12, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 11, +} +ops_after_transforms: dict[str, int] = { + "executorch_exir_dialects_edge__ops_aten_abs_default": 2, + "executorch_exir_dialects_edge__ops_aten_add_Tensor": 2, + "executorch_exir_dialects_edge__ops_aten_arange_start_step": 1, + "executorch_exir_dialects_edge__ops_aten_cat_default": 1, + "executorch_exir_dialects_edge__ops_aten_convolution_default": 6, + "executorch_exir_dialects_edge__ops_aten_index_Tensor": 1, + "executorch_exir_dialects_edge__ops_aten_linear_default": 2, + "executorch_exir_dialects_edge__ops_aten_mean_dim": 1, + "executorch_exir_dialects_edge__ops_aten_mul_Tensor": 3, + "executorch_exir_dialects_edge__ops_aten_pow_Tensor_Scalar": 2, + "executorch_exir_dialects_edge__ops_aten_relu_default": 5, + "executorch_exir_dialects_edge__ops_aten_select_copy_int": 2, + "executorch_exir_dialects_edge__ops_aten_sigmoid_default": 4, + "executorch_exir_dialects_edge__ops_aten_slice_copy_Tensor": 2, + "executorch_exir_dialects_edge__ops_aten_split_with_sizes_copy_default": 1, + "executorch_exir_dialects_edge__ops_aten_sqrt_default": 1, + "executorch_exir_dialects_edge__ops_aten_squeeze_copy_dims": 2, + "executorch_exir_dialects_edge__ops_aten_sub_Tensor": 2, + "executorch_exir_dialects_edge__ops_aten_tanh_default": 2, + "executorch_exir_dialects_edge__ops_aten_unsqueeze_copy_default": 2, + "executorch_exir_dialects_edge__ops_aten_view_copy_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 6, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 6, + "executorch_exir_dialects_edge__ops_cortex_m_quantized_add_default": 1, +} + + +pt_model = SileroVAD16k().eval() + +x = torch.randn( + 1, CONTEXT_SIZE + WINDOW_SIZE +) # (1, 576) — 64 context + 512 audio samples +state = torch.zeros(2, 1, HIDDEN_DIM) # (2, 1, 128) — [h, c] LSTM state + +test_cases = { + "silero_vad_16k": McuTestCase( + model=pt_model, + example_inputs=lambda: (x, state), + ), +} + + +@parametrize("test_case", test_cases) +def test_dialect_silero_vad_16k(test_case): + """This model currently does largely not lower to accelerated kernels due to missing LSTM and conv1d support, this test is to track development progress.""" + inputs = test_case.get_example_inputs() + tester = CortexMTester(test_case.model, inputs) + tester.test_dialect( + ops_before_transforms, + ops_after_transforms, + qtol=10, + ) diff --git a/backends/cortex_m/test/models/test_wav2letter.py b/backends/cortex_m/test/models/test_wav2letter.py new file mode 100644 index 00000000000..ddc5354293c --- /dev/null +++ b/backends/cortex_m/test/models/test_wav2letter.py @@ -0,0 +1,34 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from executorch.backends.arm.test.common import parametrize +from executorch.backends.cortex_m.test.tester import CortexMTester, McuTestCase +from executorch.examples.models.wav2letter.model import Wav2LetterModel + + +ops_before_transforms: dict[str, int] = {} +ops_after_transforms: dict[str, int] = {} + +model = Wav2LetterModel() +pt_model = model.get_eager_model() + +test_cases = { + "wav2letter": McuTestCase( + model=pt_model, + example_inputs=lambda: model.get_example_inputs(), + ), +} + + +@parametrize("test_case", test_cases) +def test_dialect_wav2letter(test_case): + """This model currently does largely not lower to accelerated kernels due to missing conv1d support, this test is to track development progress.""" + inputs = test_case.get_example_inputs() + tester = CortexMTester(test_case.model, inputs) + tester.test_dialect( + ops_before_transforms, + ops_after_transforms, + qtol=10, + ) diff --git a/backends/cortex_m/test/models/test_yolo11.py b/backends/cortex_m/test/models/test_yolo11.py new file mode 100644 index 00000000000..f17c5ced331 --- /dev/null +++ b/backends/cortex_m/test/models/test_yolo11.py @@ -0,0 +1,45 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import pytest +import torch +from executorch.backends.arm.test.common import parametrize + +from executorch.backends.cortex_m.test.tester import CortexMTester, McuTestCase + +YOLO = pytest.importorskip( + "ultralytics", + reason="ultralytics is optional; install it locally to run YOLO tests.", +).YOLO + + +ops_before_transforms: dict[str, int] = {} +ops_after_transforms: dict[str, int] = {} + + +WEIGHTS = "yolo11n.pt" +yolo = YOLO(WEIGHTS) +pt_model = yolo.model.eval() + +test_cases = { + "yolo11n": McuTestCase( + model=pt_model, + example_inputs=lambda: ( + torch.randn(1, 3, 640, 640).to(memory_format=torch.channels_last), + ), + ), +} + + +@parametrize("test_case", test_cases) +def test_dialect_yolo11(test_case): + """This model currently does not lower in the cortex-m backend, this test is to track development progress.""" + inputs = test_case.get_example_inputs() + tester = CortexMTester(test_case.model, inputs) + tester.test_dialect( + ops_before_transforms, + ops_after_transforms, + qtol=10, + ) diff --git a/backends/cuda/CMakeLists.txt b/backends/cuda/CMakeLists.txt index 217c893efe5..d56e994eab4 100644 --- a/backends/cuda/CMakeLists.txt +++ b/backends/cuda/CMakeLists.txt @@ -103,7 +103,7 @@ install( ) # CUDA-specific AOTI shim symbols (dynamically linked) -set(_aoti_cuda_shim_sources runtime/shims/memory.cpp +set(_aoti_cuda_shim_sources runtime/cuda_allocator.cpp runtime/shims/memory.cpp runtime/shims/cuda_guard.cpp ) @@ -180,8 +180,12 @@ install( # CUDA backend implementation set(_aoti_cuda_backend_sources runtime/cuda_backend.cpp) +if(_cuda_is_msvc_toolchain) + # MSVC links aoti_cuda_backend into portable_lib without relying on C++ + # symbols exported from aoti_cuda_shims.dll. + list(APPEND _aoti_cuda_backend_sources runtime/cuda_allocator.cpp) +endif() -# CUDA backend implementation add_library(aoti_cuda_backend STATIC ${_aoti_cuda_backend_sources}) target_include_directories( diff --git a/backends/cuda/runtime/TARGETS b/backends/cuda/runtime/TARGETS index f13f41ab8b7..c8449a95718 100644 --- a/backends/cuda/runtime/TARGETS +++ b/backends/cuda/runtime/TARGETS @@ -74,6 +74,33 @@ runtime.cxx_library( ], ) +runtime.cxx_library( + name = "cuda_allocator", + srcs = [ + "cuda_allocator.cpp", + ], + headers = [ + "cuda_allocator.h", + ], + # @lint-ignore BUCKLINT: Avoid `link_whole=True` (https://fburl.com/avoid-link-whole) + link_whole = True, + supports_python_dlopen = True, + visibility = ["PUBLIC"], + exported_deps = [ + "//executorch/runtime/core:device_allocator", + ], + deps = [ + "//executorch/runtime/platform:platform", + ], + nvcc_flags = get_nvcc_arch_args() + [ + "-_NVCC_HOST_COMPILER_FLAG_", + "gcc", + ], + external_deps = [ + ("cuda", None, "cuda-lazy"), + ], +) + runtime.cxx_library( name = "cuda_backend", srcs = [ @@ -92,6 +119,8 @@ runtime.cxx_library( deps = [ ":cuda_platform", ":runtime_shims", + ":cuda_allocator", + ":cuda_platform", "//executorch/backends/aoti:aoti_common_slim", "//executorch/backends/aoti/slim/core:slimtensor", "//executorch/backends/aoti/slim/factory:empty", diff --git a/backends/cuda/runtime/cuda_allocator.cpp b/backends/cuda/runtime/cuda_allocator.cpp new file mode 100644 index 00000000000..94294b08fa0 --- /dev/null +++ b/backends/cuda/runtime/cuda_allocator.cpp @@ -0,0 +1,258 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#include + +namespace executorch::backends::cuda { + +using executorch::runtime::Error; +using executorch::runtime::Result; +using executorch::runtime::etensor::DeviceIndex; +using executorch::runtime::etensor::DeviceType; + +Result +CudaAllocator::allocate(size_t nbytes, DeviceIndex index, size_t alignment) { + // index == -1 means "use the current CUDA device"; any value < -1 is invalid. + ET_CHECK_OR_RETURN_ERROR( + index >= -1, + InvalidArgument, + "CudaAllocator::allocate: invalid device index %d (must be >= -1)", + static_cast(index)); + + // Alignment must be a non-zero power of 2. + ET_CHECK_OR_RETURN_ERROR( + alignment != 0 && (alignment & (alignment - 1)) == 0, + InvalidArgument, + "CudaAllocator::allocate: alignment must be a power of 2, got %zu", + alignment); + + // cudaMalloc is documented to return memory aligned to at least 256 bytes, + // which trivially satisfies kDefaultAlignment (alignof(void*)). For any + // requested alignment <= 256 bytes, the returned pointer is already aligned. + // Stricter alignment would require over-allocation plus bookkeeping that + // deallocate() does not currently support, so reject that case. + constexpr size_t kCudaMallocAlignment = 256; + ET_CHECK_OR_RETURN_ERROR( + alignment <= kCudaMallocAlignment, + NotSupported, + "CudaAllocator::allocate: requested alignment %zu exceeds cudaMalloc's " + "guaranteed alignment of %zu bytes; stricter alignment is not supported", + alignment, + kCudaMallocAlignment); + + void* ptr = nullptr; + int prev_device = 0; + cudaError_t prev_device_err = cudaGetDevice(&prev_device); + + // If index == -1, fall back to the current device returned by cudaGetDevice + // and skip the set/restore round-trip. + const bool switch_device = index >= 0 && prev_device_err == cudaSuccess && + static_cast(index) != prev_device; + if (switch_device) { + cudaSetDevice(index); + } + + cudaError_t err = cudaMalloc(&ptr, nbytes); + + if (switch_device) { + cudaSetDevice(prev_device); + } + + if (err != cudaSuccess) { + ET_LOG( + Error, + "cudaMalloc failed: %s (requested %zu bytes on device %d)", + cudaGetErrorString(err), + nbytes, + static_cast(index)); + return Error::MemoryAllocationFailed; + } + + // Sanity check: the pointer returned by cudaMalloc should already meet the + // requested alignment. If a future CUDA runtime weakens this guarantee, we + // want to fail loudly rather than silently return a misaligned pointer. + if ((reinterpret_cast(ptr) & (alignment - 1)) != 0) { + ET_LOG( + Error, + "cudaMalloc returned pointer %p not aligned to %zu bytes", + ptr, + alignment); + cudaFree(ptr); + return Error::MemoryAllocationFailed; + } + + return ptr; +} + +void CudaAllocator::deallocate(void* ptr, DeviceIndex index) { + if (ptr == nullptr) { + return; + } + + int prev_device = 0; + cudaError_t prev_device_err = cudaSuccess; + + if (index >= 0) { + prev_device_err = cudaGetDevice(&prev_device); + if (prev_device_err == cudaSuccess) { + cudaSetDevice(index); + } + } + + cudaError_t err = cudaFree(ptr); + + if (index >= 0 && prev_device_err == cudaSuccess) { + cudaSetDevice(prev_device); + } + + if (err != cudaSuccess) { + ET_LOG( + Error, + "cudaFree failed: %s (ptr=%p, device %d)", + cudaGetErrorString(err), + ptr, + static_cast(index)); + } +} + +// TODO(gasoonjia): Add support for async copy +Error CudaAllocator::copy_host_to_device( + void* dst, + const void* src, + size_t nbytes, + DeviceIndex index) { + int prev_device = 0; + cudaError_t prev_device_err = cudaSuccess; + + if (index >= 0) { + prev_device_err = cudaGetDevice(&prev_device); + if (prev_device_err == cudaSuccess) { + cudaSetDevice(index); + } + } + + cudaError_t err = cudaMemcpy(dst, src, nbytes, cudaMemcpyHostToDevice); + + if (index >= 0 && prev_device_err == cudaSuccess) { + cudaSetDevice(prev_device); + } + + if (err != cudaSuccess) { + ET_LOG( + Error, + "cudaMemcpy H2D failed: %s (%zu bytes, device %d)", + cudaGetErrorString(err), + nbytes, + static_cast(index)); + return Error::Internal; + } + return Error::Ok; +} + +// TODO(gasoonjia): Add support for async copy +Error CudaAllocator::copy_device_to_host( + void* dst, + const void* src, + size_t nbytes, + DeviceIndex index) { + int prev_device = 0; + cudaError_t prev_device_err = cudaSuccess; + + if (index >= 0) { + prev_device_err = cudaGetDevice(&prev_device); + if (prev_device_err == cudaSuccess) { + cudaSetDevice(index); + } + } + + cudaError_t err = cudaMemcpy(dst, src, nbytes, cudaMemcpyDeviceToHost); + + if (index >= 0 && prev_device_err == cudaSuccess) { + cudaSetDevice(prev_device); + } + + if (err != cudaSuccess) { + ET_LOG( + Error, + "cudaMemcpy D2H failed: %s (%zu bytes, device %d)", + cudaGetErrorString(err), + nbytes, + static_cast(index)); + return Error::Internal; + } + return Error::Ok; +} + +DeviceType CudaAllocator::device_type() const { + return DeviceType::CUDA; +} + +CudaAllocator& CudaAllocator::instance() { + static CudaAllocator allocator; + return allocator; +} + +Result CudaAllocator::allocate_async( + size_t nbytes, + DeviceIndex index, + cudaStream_t stream) { + void* ptr = nullptr; + cudaError_t err = cudaMallocAsync(&ptr, nbytes, stream); + if (err != cudaSuccess) { + ET_LOG( + Error, + "cudaMallocAsync failed: %s (requested %zu bytes on device %d)", + cudaGetErrorString(err), + nbytes, + static_cast(index)); + return Error::MemoryAllocationFailed; + } + return ptr; +} + +void CudaAllocator::deallocate_async( + void* ptr, + DeviceIndex index, + cudaStream_t stream) { + if (ptr == nullptr) { + return; + } + cudaError_t err = cudaFreeAsync(ptr, stream); + if (err != cudaSuccess) { + ET_LOG( + Error, + "cudaFreeAsync failed: %s (ptr=%p, device %d)", + cudaGetErrorString(err), + ptr, + static_cast(index)); + } +} + +Error CudaAllocator::memcpy_async( + void* dst, + const void* src, + size_t nbytes, + cudaMemcpyKind direction, + cudaStream_t stream) { + cudaError_t err = cudaMemcpyAsync(dst, src, nbytes, direction, stream); + if (err != cudaSuccess) { + ET_LOG( + Error, + "cudaMemcpyAsync failed: %s (%zu bytes)", + cudaGetErrorString(err), + nbytes); + return Error::Internal; + } + return Error::Ok; +} + +} // namespace executorch::backends::cuda diff --git a/backends/cuda/runtime/cuda_allocator.h b/backends/cuda/runtime/cuda_allocator.h new file mode 100644 index 00000000000..fcd8224305a --- /dev/null +++ b/backends/cuda/runtime/cuda_allocator.h @@ -0,0 +1,84 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +#include + +namespace executorch::backends::cuda { + +/** + * CUDA implementation of DeviceAllocator. + * + * Uses cudaMalloc/cudaFree for allocation and cudaMemcpy for host-device + * transfers. This allocator is automatically registered as a singleton + * with the DeviceAllocatorRegistry when the CUDA backend library is linked. + * + * All CUDA memory operations in the CUDA backend should go through this + * allocator for consistent memory management. + */ +class CudaAllocator final : public executorch::runtime::DeviceAllocator { + public: + executorch::runtime::Result allocate( + size_t nbytes, + executorch::runtime::etensor::DeviceIndex index, + size_t alignment = kDefaultAlignment) override; + + void deallocate(void* ptr, executorch::runtime::etensor::DeviceIndex index) + override; + + executorch::runtime::Error copy_host_to_device( + void* dst, + const void* src, + size_t nbytes, + executorch::runtime::etensor::DeviceIndex index) override; + + executorch::runtime::Error copy_device_to_host( + void* dst, + const void* src, + size_t nbytes, + executorch::runtime::etensor::DeviceIndex index) override; + + executorch::runtime::etensor::DeviceType device_type() const override; + + /// Returns the global CudaAllocator singleton. + static CudaAllocator& instance(); + + // --- Async (stream-based) operations for SlimTensor/Storage layer --- + + /** + * Allocate device memory asynchronously on the given CUDA stream. + */ + static executorch::runtime::Result allocate_async( + size_t nbytes, + executorch::runtime::etensor::DeviceIndex index, + cudaStream_t stream); + + /** + * Deallocate device memory asynchronously on the given CUDA stream. + */ + static void deallocate_async( + void* ptr, + executorch::runtime::etensor::DeviceIndex index, + cudaStream_t stream); + + /** + * Copy memory asynchronously on the given CUDA stream. + * Supports H2D, D2H, and D2D based on src/dst device types. + */ + static executorch::runtime::Error memcpy_async( + void* dst, + const void* src, + size_t nbytes, + cudaMemcpyKind direction, + cudaStream_t stream); +}; + +} // namespace executorch::backends::cuda diff --git a/backends/cuda/runtime/cuda_backend.cpp b/backends/cuda/runtime/cuda_backend.cpp index 1497ba1e376..d2738f7a976 100644 --- a/backends/cuda/runtime/cuda_backend.cpp +++ b/backends/cuda/runtime/cuda_backend.cpp @@ -40,6 +40,7 @@ // Include our shim layer headers #include #include +#include #include #include #include @@ -1273,5 +1274,13 @@ auto cls = cuda::CudaBackend(); executorch::runtime::Backend backend{"CudaBackend", &cls}; static executorch::runtime::Error success_with_compiler = register_backend(backend); + +// Auto-register the CudaAllocator so that DeviceMemoryBuffer::create(CUDA) +// works whenever the CUDA backend library is linked. +static bool cuda_allocator_registered = [] { + executorch::runtime::register_device_allocator( + &cuda::CudaAllocator::instance()); + return true; +}(); } // namespace } // namespace executorch::backends diff --git a/backends/cuda/runtime/shims/tests/targets.bzl b/backends/cuda/runtime/shims/tests/targets.bzl index b68043f7feb..a54c47e979d 100644 --- a/backends/cuda/runtime/shims/tests/targets.bzl +++ b/backends/cuda/runtime/shims/tests/targets.bzl @@ -42,3 +42,27 @@ def define_common_targets(): cuda_shim_cpp_unittest("aoti_torch_new_tensor_handle") cuda_shim_cpp_unittest("aoti_torch_item_bool") cuda_shim_cpp_unittest("aoti_torch_assign_tensors_out") + + cpp_unittest( + name = "test_op__device_copy", + srcs = ["test_op__device_copy.cpp"], + deps = [ + "//executorch/backends/cuda/runtime:cuda_backend", + "//executorch/kernels/portable:generated_lib", + "//executorch/kernels/portable:generated_lib_headers", + "//executorch/kernels/portable/cpu:op__device_copy", + "//executorch/runtime/core:device_allocator", + "//executorch/runtime/core/exec_aten:lib", + "//executorch/runtime/core/portable_type:portable_type", + "//executorch/runtime/kernel:kernel_runtime_context", + "//executorch/runtime/platform:platform", + ], + external_deps = [ + ("cuda", None, "cuda-lazy"), + ], + preprocessor_flags = ["-DCUDA_AVAILABLE=1"], + keep_gpu_sections = True, + remote_execution = re_test_utils.remote_execution( + platform = "gpu-remote-execution", + ), + ) diff --git a/backends/cuda/runtime/shims/tests/test_op__device_copy.cpp b/backends/cuda/runtime/shims/tests/test_op__device_copy.cpp new file mode 100644 index 00000000000..4e5c5a099b7 --- /dev/null +++ b/backends/cuda/runtime/shims/tests/test_op__device_copy.cpp @@ -0,0 +1,195 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#if (defined(__has_feature) && __has_feature(address_sanitizer)) || \ + defined(__SANITIZE_ADDRESS__) +#include +#define EXECUTORCH_CUDA_DEVICE_COPY_HAS_LSAN_INTERFACE 1 +#else +#define EXECUTORCH_CUDA_DEVICE_COPY_HAS_LSAN_INTERFACE 0 +#endif + +#include +#include +#include + +using executorch::aten::ScalarType; +using executorch::aten::Tensor; +using executorch::aten::TensorImpl; +using executorch::runtime::Error; +using executorch::runtime::get_device_allocator; +using executorch::runtime::KernelRuntimeContext; +using executorch::runtime::TensorShapeDynamism; +using executorch::runtime::etensor::DeviceIndex; +using executorch::runtime::etensor::DeviceType; + +namespace { + +struct CudaDeleter { + void operator()(void* ptr) const { + if (ptr != nullptr) { + cudaFree(ptr); + } + } +}; + +using CudaPtr = std::unique_ptr; + +CudaPtr allocate_cuda(size_t nbytes) { + void* ptr = nullptr; + const cudaError_t err = cudaMalloc(&ptr, nbytes); + EXPECT_EQ(err, cudaSuccess) << "cudaMalloc failed"; + return CudaPtr(ptr); +} + +bool is_cuda_available() { +#if EXECUTORCH_CUDA_DEVICE_COPY_HAS_LSAN_INTERFACE + __lsan_disable(); +#endif + int device_count = 0; + const cudaError_t err = cudaGetDeviceCount(&device_count); +#if EXECUTORCH_CUDA_DEVICE_COPY_HAS_LSAN_INTERFACE + __lsan_enable(); +#endif + return err == cudaSuccess && device_count > 0; +} + +std::vector copy_cuda_to_host(const void* device_ptr, size_t numel) { + std::vector host(numel); + const cudaError_t err = cudaMemcpy( + host.data(), device_ptr, numel * sizeof(float), cudaMemcpyDeviceToHost); + EXPECT_EQ(err, cudaSuccess) << "cudaMemcpy D2H failed"; + return host; +} + +void copy_host_to_cuda(const std::vector& host, void* device_ptr) { + const cudaError_t err = cudaMemcpy( + device_ptr, + host.data(), + host.size() * sizeof(float), + cudaMemcpyHostToDevice); + EXPECT_EQ(err, cudaSuccess) << "cudaMemcpy H2D failed"; +} + +class CudaDeviceCopyOpTest : public ::testing::Test { + protected: + static void SetUpTestSuite() { + executorch::runtime::runtime_init(); + ASSERT_NE(get_device_allocator(DeviceType::CUDA), nullptr) + << "Linking cuda_backend should auto-register the CUDA allocator"; + } + + void SetUp() override { + if (!is_cuda_available()) { + GTEST_SKIP() << "CUDA not available, skipping CUDA device copy op tests"; + } + } + + Tensor& op_h2d_copy_out(const Tensor& self, Tensor& out) { + return torch::executor::et_copy::_h2d_copy_outf(context_, self, out); + } + + Tensor& op_d2h_copy_out(const Tensor& self, Tensor& out) { + return torch::executor::et_copy::_d2h_copy_outf(context_, self, out); + } + + KernelRuntimeContext context_; +}; + +} // namespace + +TEST_F(CudaDeviceCopyOpTest, H2dCopyUsesRegisteredCudaAllocator) { + std::vector src_data = {1.0f, 2.0f, 3.0f, 4.0f}; + auto device_data = allocate_cuda(src_data.size() * sizeof(float)); + ASSERT_NE(device_data.get(), nullptr); + + int32_t sizes[] = {static_cast(src_data.size())}; + uint8_t dim_order[] = {0}; + int32_t strides[] = {1}; + + TensorImpl src_impl( + ScalarType::Float, + 1, + sizes, + src_data.data(), + dim_order, + strides, + TensorShapeDynamism::STATIC, + DeviceType::CPU, + 0); + Tensor src(&src_impl); + + TensorImpl dst_impl( + ScalarType::Float, + 1, + sizes, + device_data.get(), + dim_order, + strides, + TensorShapeDynamism::STATIC, + DeviceType::CUDA, + 0); + Tensor dst(&dst_impl); + + Tensor& result = op_h2d_copy_out(src, dst); + + EXPECT_EQ(context_.failure_state(), Error::Ok); + EXPECT_EQ(&result, &dst); + EXPECT_EQ(copy_cuda_to_host(device_data.get(), src_data.size()), src_data); +} + +TEST_F(CudaDeviceCopyOpTest, D2hCopyUsesRegisteredCudaAllocator) { + const std::vector expected = {5.0f, 6.0f, 7.0f, 8.0f}; + auto device_data = allocate_cuda(expected.size() * sizeof(float)); + ASSERT_NE(device_data.get(), nullptr); + copy_host_to_cuda(expected, device_data.get()); + + std::vector dst_data(expected.size(), 0.0f); + int32_t sizes[] = {static_cast(expected.size())}; + uint8_t dim_order[] = {0}; + int32_t strides[] = {1}; + + TensorImpl src_impl( + ScalarType::Float, + 1, + sizes, + device_data.get(), + dim_order, + strides, + TensorShapeDynamism::STATIC, + DeviceType::CUDA, + 0); + Tensor src(&src_impl); + + TensorImpl dst_impl( + ScalarType::Float, + 1, + sizes, + dst_data.data(), + dim_order, + strides, + TensorShapeDynamism::STATIC, + DeviceType::CPU, + 0); + Tensor dst(&dst_impl); + + Tensor& result = op_d2h_copy_out(src, dst); + + EXPECT_EQ(context_.failure_state(), Error::Ok); + EXPECT_EQ(&result, &dst); + EXPECT_EQ(dst_data, expected); +} diff --git a/backends/mlx/builder/op_helpers.py b/backends/mlx/builder/op_helpers.py index 40e71e0bdab..be199f75340 100644 --- a/backends/mlx/builder/op_helpers.py +++ b/backends/mlx/builder/op_helpers.py @@ -17,6 +17,7 @@ if TYPE_CHECKING: from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder + from executorch.backends.mlx.serialization.mlx_graph_schema import IntOrVid # When True, always serialize the biases tensor for quantized ops. # When False, use init-time computation when zero_point is all zeros, @@ -173,6 +174,117 @@ def emit_lifted_constant(P: "MLXProgramBuilder", value, dtype: torch.dtype) -> S return slot +def emit_shape( + P: "MLXProgramBuilder", + node: Node, + slot: Slot, + *, + end_dim: "Optional[int]" = None, +) -> "list[IntOrVid]": + """Return the shape of ``node`` as a list of ``IntOrVid``. + + Each static dim becomes a literal ``IntOrVid``; each dynamic dim + emits a ``SymSizeNode`` against ``slot`` and is wrapped via + ``P.to_int_or_vid``. + + Args: + P: program builder. + node: FX node whose shape to walk (must have ``meta['val']``). + slot: slot corresponding to ``node`` (used as the + ``SymSize`` source for any dynamic dim). + end_dim: stop index (exclusive). ``None`` means the full ndim. + Negative values index from the end (e.g. ``-1`` is "all + leading dims, drop the last"). + + Returns: + ``list[IntOrVid]`` of length ``end_dim`` (after normalization). + """ + from executorch.backends.mlx.serialization.mlx_graph_schema import ( + IntOrVid, + SymSizeNode, + ) + + shape = node.meta["val"].shape + ndim = len(shape) + if end_dim is None: + end_dim = ndim + elif end_dim < 0: + end_dim += ndim + + out: "list[IntOrVid]" = [] + for dim_idx in range(end_dim): + s = shape[dim_idx] + if isinstance(s, int): + out.append(IntOrVid.from_literal(int(s))) + else: + _, d_val = P.make_tmp_value_slot() + P.emit( + SymSizeNode( + a=P.slot_to_tid(slot), + dim=dim_idx, + out=P.slot_to_vid(d_val), + ) + ) + out.append(P.to_int_or_vid(d_val)) + return out + + +def emit_product( + P: "MLXProgramBuilder", + dims: "list[IntOrVid]", +) -> "IntOrVid": + """Multiplicative reduction over a list of ``IntOrVid`` values. + + Folds all literal entries AOT into a single static product, then + emits ``MultiplyIntNode`` only for the dynamic entries (and one + final node combining the static product with the dynamic accumulator + when both contribute). + + Args: + P: program builder. + dims: list of ``IntOrVid``. May be empty (returns + ``IntOrVid.from_literal(1)``), all literals, or a mix. + + Returns: + An ``IntOrVid`` representing the product. Always literal when + every entry is literal (or ``dims`` is empty). + """ + from executorch.backends.mlx.serialization.mlx_graph_schema import ( + IntOrVid, + MultiplyIntNode, + ) + + static_product = 1 + dynamic_dims: "list[IntOrVid]" = [] + for d in dims: + if d.is_vid: + dynamic_dims.append(d) + else: + static_product *= d.literal + + if not dynamic_dims: + return IntOrVid.from_literal(static_product) + + acc = dynamic_dims[0] + for d in dynamic_dims[1:]: + _, acc_val = P.make_tmp_value_slot() + P.emit(MultiplyIntNode(a=acc, b=d, out=P.slot_to_vid(acc_val))) + acc = P.to_int_or_vid(acc_val) + + if static_product == 1: + return acc + + _, final_val = P.make_tmp_value_slot() + P.emit( + MultiplyIntNode( + a=IntOrVid.from_literal(static_product), + b=acc, + out=P.slot_to_vid(final_val), + ) + ) + return P.to_int_or_vid(final_val) + + def emit_quantized_biases( P: "MLXProgramBuilder", zero_point_key: str, @@ -334,7 +446,7 @@ def parse_dequant_node( if len(non_one) != 1: return None quantized_dim, group_size = non_one[0] - if group_size not in [32, 64, 128]: + if group_size not in [16, 32, 64, 128]: return None # TODO: MLX supports 3, 5, and 7, but we need to figure out the diff --git a/backends/mlx/llm/turboquant_cache.py b/backends/mlx/llm/turboquant_cache.py new file mode 100644 index 00000000000..7f2109ba074 --- /dev/null +++ b/backends/mlx/llm/turboquant_cache.py @@ -0,0 +1,243 @@ +#!/usr/bin/env python3 +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +TurboQuant TQ4 KV cache for the MLX backend. + +Subclass of the backend-agnostic +``extension/llm/modules/turboquant/kv_cache.py::TurboQuantKVCache``. + +The cache stores K and V in **rotated space** (post-multiplied by R^T) +as nibble-packed uint8 codebook indices plus per-vector bf16 norms. +SDPA runs in rotated space and undoes the rotation on the output side +(both Q and output rotations are ``T_q × D²``, much smaller than +applying the inverse rotation to K/V which would be ``T_kv × D²``). + +Reference: + TurboQuant: Online Vector Quantization with Near-optimal + Distortion Rate. arXiv:2504.19874 (ICLR 2026). +""" + +from typing import Optional, Tuple + +# Register the MLX custom ops used by this cache. +import executorch.backends.mlx.custom_ops # noqa: F401 mlx::custom_sdpa, mlx::kv_cache_update +import executorch.backends.mlx.model_ops.tq4_compress # noqa: F401 mlx::tq4_compress +import executorch.backends.mlx.model_ops.tq_dequant # noqa: F401 mlx::tq_dequant +import executorch.backends.mlx.model_ops.tq_norm # noqa: F401 mlx::tq_norm + +import torch + +from executorch.extension.llm.modules.turboquant.kv_cache import ( + TurboQuantKVCache as _SharedTurboQuantKVCache, +) + + +class TurboQuantKVCache(_SharedTurboQuantKVCache): + """ + TurboQuant TQ4 KV cache, MLX-backend variant. + + Drop-in replacement for ``backends/mlx/llm/cache.py::KVCache``. + + Args: + max_batch_size: Must be 1 (TQ4 is batch=1 only). + max_context_length: Maximum sequence length. + n_heads: Number of KV heads. + head_dim: Per-head dimension. Must be even and a multiple of 64. + enable_dynamic_shape: Accepted for interface parity; ignored. + dtype: Compute dtype (bf16). Used for pre-cast buffers. + bits: Quantization bits (must be 4). + seed: RNG seed for the orthogonal rotation matrix. + """ + + def __init__( + self, + max_batch_size: int, + max_context_length: int, + n_heads: int, + head_dim: int, + enable_dynamic_shape: bool, + dtype: torch.dtype = torch.bfloat16, + bits: int = 4, + seed: int = 42, + ): + if max_batch_size != 1: + raise ValueError( + f"TurboQuantKVCache only supports max_batch_size=1, " + f"got {max_batch_size}" + ) + if bits != 4: + raise ValueError( + f"TurboQuantKVCache only supports bits=4 " + f"(16-entry codebook), got bits={bits}" + ) + # MLX-backend Metal kernels need ``head_dim % 64 == 0``: ``tq_norm`` + # uses 32 SIMD lanes (so D must be a multiple of 32), and + # ``tq_dequant`` packs 2 dims per byte across 32 lanes (so D must + # be a multiple of 64). Take the stricter constraint here. + if head_dim % 64 != 0: + raise ValueError( + f"TurboQuantKVCache requires head_dim to be " + f"a multiple of 64 (Metal SIMD + 4-bit pack constraint), " + f"got {head_dim}" + ) + super().__init__( + n_heads=n_heads, + head_dim=head_dim, + max_seq_len=max_context_length, + bits=bits, + seed=seed, + ) + self.max_batch_size = max_batch_size + self.max_context_length = max_context_length + self.enable_dynamic_shape = enable_dynamic_shape + + # Replace parent's fp32 ``rotation`` and ``centroids`` buffers + # with compute-dtype versions in-place. Avoids a per-call + # ``_to_copy`` cast in the lowered graph at every use site. + # Parent's ``_decompress`` (testing-only) is the sole consumer + # of these as fp32 and is not called at runtime. + self.register_buffer( + "rotation", + self.rotation.to(dtype).contiguous(), + persistent=False, + ) + self.register_buffer( + "centroids", + self.centroids.to(dtype).contiguous(), + persistent=False, + ) + # Pre-cast eps for the divide-by-zero guard in _compress. + self.register_buffer( + "norm_eps", + torch.tensor(1e-10, dtype=dtype), + persistent=False, + ) + + def _compress(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Compress ``(1, H, T, D)`` → packed ``(1, H, T, D//2)`` u8 + + norms ``(1, H, T, 1)`` bf16. + + The L2-norm reduction uses ``mlx::tq_norm`` (one Metal kernel + with fp32 sum-of-squares in registers via ``simd_sum``); the + bucketize + nibble-pack tail uses ``mlx::tq4_compress`` (one + Metal kernel for both steps). + """ + orig_shape = x.shape + flat = x.reshape(-1, self.head_dim) + + norms = torch.ops.mlx.tq_norm(flat) + normalized = flat / (norms + self.norm_eps) + rotated = normalized @ self.rotation_T + packed = torch.ops.mlx.tq4_compress(rotated, self.boundaries) + + return ( + packed.reshape(*orig_shape[:-1], self.half_dim), + norms.reshape(*orig_shape[:-1], 1), + ) + + def update( + self, + input_pos, + k_val: torch.Tensor, + v_val: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Compress + write K/V at ``input_pos``, return the full + compressed cache buffers. + + Accepts ``input_pos`` as either a ``(T,)`` LongTensor of + positions or a Python int / SymInt ``start_pos``. Writes go + through ``mlx::kv_cache_update`` (matching the non-TQ + ``MLXKVCache`` path) which lowers to a tighter in-place + scatter than ``index_copy_`` would. + """ + if isinstance(input_pos, torch.Tensor): + start_pos = input_pos[0].item() + seq_len = k_val.size(2) + torch._check(seq_len == v_val.size(2)) + torch._check(start_pos >= 0) + torch._check(start_pos + seq_len <= self.max_context_length) + else: + start_pos = input_pos + + k_packed, k_norms = self._compress(k_val) + v_packed, v_norms = self._compress(v_val) + + torch.ops.mlx.kv_cache_update(self.k_packed, k_packed, start_pos) + torch.ops.mlx.kv_cache_update(self.k_norms, k_norms, start_pos) + torch.ops.mlx.kv_cache_update(self.v_packed, v_packed, start_pos) + torch.ops.mlx.kv_cache_update(self.v_norms, v_norms, start_pos) + + # Slices on the return create new graph nodes so the same node + # is not both BUFFER_MUTATION and USER_OUTPUT. + return ( + self.k_packed[:, :, :, :], + self.k_norms[:, :, :, :], + self.v_packed[:, :, :, :], + self.v_norms[:, :, :, :], + ) + + # forward() is inherited from the parent (delegates to update). + + def sdpa( + self, + query: torch.Tensor, + start_pos, + scale: Optional[float] = None, + ) -> torch.Tensor: + """SDPA over the compressed cache. + + Runs attention in rotated space: + 1. Q_rot = Q @ R^T (T_q x D^2) + 2. K_rot, V_rot = tq_dequant(...) (rotated-space K/V) + 3. out_rot = custom_sdpa(Q_rot, K_rot, V_rot, ...) + 4. out = out_rot @ R (T_q x D^2) + + Since R is orthogonal, score = (Q·R^T)·(K·R^T)^T = Q·K^T, so + attention is invariant under matched rotation of Q and K. The + ``T_kv x D^2`` inverse-rotation matmul on K/V is replaced with + two ``T_q x D^2`` matmuls (Q and output). + + Args: + query: ``(B, H_q, T_q, D)`` bf16. + start_pos: int or SymInt — absolute position of the first + query token. + scale: 1/sqrt(D) if None. + + Returns: + ``(B, H_q, T_q, D)`` bf16 attention output, in original + (un-rotated) space. + """ + seq_len = query.size(2) + end_pos = start_pos + seq_len + torch._check(start_pos >= 0) + torch._check(end_pos <= self.max_context_length) + + q_rot = query @ self.rotation_T + + k_packed_live = self.k_packed[:, :, :end_pos, :] + k_norms_live = self.k_norms[:, :, :end_pos, :] + v_packed_live = self.v_packed[:, :, :end_pos, :] + v_norms_live = self.v_norms[:, :, :end_pos, :] + + # TODO: optimize with a fused dequant + SDPA + k_rot = torch.ops.mlx.tq_dequant(k_packed_live, k_norms_live, self.centroids) + v_rot = torch.ops.mlx.tq_dequant(v_packed_live, v_norms_live, self.centroids) + + out_rot = torch.ops.mlx.custom_sdpa( + q_rot, + k_rot, + v_rot, + start_pos, + None, # attn_mask + 0.0, # dropout_p + True, # is_causal + scale, + ) + + return out_rot @ self.rotation diff --git a/backends/mlx/model_ops/test_tq4_compress.py b/backends/mlx/model_ops/test_tq4_compress.py new file mode 100644 index 00000000000..c2aaa13afa7 --- /dev/null +++ b/backends/mlx/model_ops/test_tq4_compress.py @@ -0,0 +1,183 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Tests for ``mlx::tq4_compress``. + +Verifies the fused Metal kernel produces byte-exact output vs the +eager Python implementation across head_dim values used by TurboQuant. + +Usage:: + + python -m executorch.backends.mlx.model_ops.test_tq4_compress run + python -m executorch.backends.mlx.model_ops.test_tq4_compress run -v + python -m executorch.backends.mlx.model_ops.test_tq4_compress run --rebuild +""" + +from typing import List, Tuple + +import executorch.backends.mlx.model_ops.tq4_compress # noqa: F401 + +import torch +import torch.nn as nn + +from executorch.backends.mlx.test.test_utils import OpTestCase + + +class TQ4CompressModel(nn.Module): + """``values → packed`` via ``mlx::tq4_compress``. + + Boundaries are stored as a buffer so the model is exportable + without feeding them as a graph input. + """ + + def __init__(self, head_dim: int, dtype: torch.dtype = torch.bfloat16): + super().__init__() + # 15 sorted thresholds (4-bit codebook). + self.register_buffer( + "boundaries", + torch.linspace(-0.2, 0.2, 15, dtype=dtype), + ) + + def forward(self, values: torch.Tensor) -> torch.Tensor: + return torch.ops.mlx.tq4_compress(values, self.boundaries) + + +class TQ4CompressTest(OpTestCase): + """Byte-exact comparison vs eager bucketize + nibble-pack.""" + + name = "tq4_compress" + rtol = 0.0 + atol = 0.0 + + def __init__( + self, + batch_size: int = 1, + n_heads: int = 8, + seq_len: int = 4, + head_dim: int = 128, + dtype: torch.dtype = torch.bfloat16, + ): + self.batch_size = batch_size + self.n_heads = n_heads + self.seq_len = seq_len + self.head_dim = head_dim + self.dtype = dtype + + parts = [ + "tq4_compress", + f"b{batch_size}", + f"h{n_heads}", + f"t{seq_len}", + f"d{head_dim}", + ] + if dtype != torch.bfloat16: + parts.append(str(dtype).split(".")[-1]) + self.name = "_".join(parts) + + @classmethod + def get_test_configs(cls) -> List["TQ4CompressTest"]: + return [ + # head_dim=128 (Qwen3.5 MoE / Gemma 4 sliding) + cls(seq_len=1, head_dim=128), + cls(seq_len=8, head_dim=128), + cls(seq_len=64, head_dim=128), + cls(n_heads=1, seq_len=1, head_dim=128), + # head_dim=256 (Gemma 4 sliding-attention) + cls(head_dim=256), + cls(seq_len=16, head_dim=256), + # head_dim=512 (Gemma 4 31B full-attention) + cls(n_heads=4, seq_len=4, head_dim=512), + cls(n_heads=4, seq_len=64, head_dim=512), + # Smaller D for sanity + cls(head_dim=64, n_heads=2, seq_len=4), + ] + + def create_model(self) -> nn.Module: + return TQ4CompressModel(head_dim=self.head_dim, dtype=self.dtype).to(self.dtype) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + # Activation-scale values; the kernel is byte-exact regardless + # of magnitude as long as values fall within the bucketize + # comparison range. + values = torch.randn( + self.batch_size, + self.n_heads, + self.seq_len, + self.head_dim, + dtype=self.dtype, + ) * (1.0 / (self.head_dim**0.5)) + return (values,) + + +if __name__ == "__main__": # noqa: C901 + import argparse + import sys + + from executorch.backends.mlx.test.test_utils import rebuild_op_test_runner + + parser = argparse.ArgumentParser(description="Test mlx::tq4_compress op") + parser.add_argument( + "action", + choices=["generate", "compare", "run", "list"], + help="Action: generate (export), compare (check outputs), run (full), list (show configs)", + ) + parser.add_argument("--verbose", "-v", action="store_true") + parser.add_argument( + "--rebuild", action="store_true", help="Rebuild C++ runner first" + ) + parser.add_argument( + "--config", type=str, default=None, help="Run specific config by name" + ) + args = parser.parse_args() + + if args.rebuild and not rebuild_op_test_runner(verbose=args.verbose): + sys.exit(1) + + configs = TQ4CompressTest.get_test_configs() + + if args.action == "list": + for cfg in configs: + print(f" {cfg.name}") + sys.exit(0) + + if args.config: + configs = [c for c in configs if c.name == args.config] + if not configs: + print(f"No config matching '{args.config}'") + sys.exit(1) + + passed = 0 + failed = 0 + failed_names: List[str] = [] + + for test in configs: + if args.action == "generate": + pte_path, _, _ = test.generate_test_files(verbose=args.verbose) + print(f"Generated: {pte_path}") + elif args.action == "compare": + actual_path = test.get_test_dir() / "actual_output.bin" + ok, msg = test.compare_with_actual(actual_path) + print(f"{'✓' if ok else '✗'} {test.name}: {msg}") + if ok: + passed += 1 + else: + failed += 1 + failed_names.append(test.name) + elif args.action == "run": + ok = test.run_test(verbose=args.verbose) + if ok: + passed += 1 + else: + failed += 1 + failed_names.append(test.name) + + if args.action in ("run", "compare"): + print(f"\nPassed: {passed}, Failed: {failed}") + if failed_names: + print(f"Failed: {', '.join(failed_names)}") + sys.exit(0 if failed == 0 else 1) diff --git a/backends/mlx/model_ops/test_tq_dequant.py b/backends/mlx/model_ops/test_tq_dequant.py new file mode 100644 index 00000000000..07d9deb895a --- /dev/null +++ b/backends/mlx/model_ops/test_tq_dequant.py @@ -0,0 +1,166 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Tests for ``mlx::tq_dequant``. + +Verifies the fused unpack + gather + multiply Metal kernel matches +the eager reference at head_dim values used by TurboQuant +(D ∈ {128, 256, 512}). Output is byte-exact — no fp32 promotion in +either path. + +Usage:: + + python -m executorch.backends.mlx.model_ops.test_tq_dequant run + python -m executorch.backends.mlx.model_ops.test_tq_dequant run -v + python -m executorch.backends.mlx.model_ops.test_tq_dequant run --rebuild +""" + +from typing import List, Tuple + +import executorch.backends.mlx.model_ops.tq_dequant # noqa: F401 + +import torch +import torch.nn as nn + +from executorch.backends.mlx.test.test_utils import OpTestCase + + +class TQDequantModel(nn.Module): + """``packed, norms, centroids → unrotated``.""" + + def forward( + self, + packed: torch.Tensor, + norms: torch.Tensor, + centroids: torch.Tensor, + ) -> torch.Tensor: + return torch.ops.mlx.tq_dequant(packed, norms, centroids) + + +class TQDequantTest(OpTestCase): + """Byte-exact comparison vs eager unpack + gather + multiply.""" + + name = "tq_dequant" + rtol = 0.0 + atol = 0.0 + + def __init__( + self, + batch_size: int = 1, + n_heads: int = 8, + seq_len: int = 4, + head_dim: int = 128, + ): + self.batch_size = batch_size + self.n_heads = n_heads + self.seq_len = seq_len + self.head_dim = head_dim + self.half_dim = head_dim // 2 + self.name = f"tq_dequant_b{batch_size}_h{n_heads}_t{seq_len}_d{head_dim}" + + @classmethod + def get_test_configs(cls) -> List["TQDequantTest"]: + return [ + # head_dim=128 (Qwen3.5 MoE / Gemma 4 sliding) + cls(seq_len=1, head_dim=128), + cls(seq_len=8, head_dim=128), + cls(seq_len=64, head_dim=128), + cls(n_heads=1, seq_len=1, head_dim=128), + # head_dim=256 (Gemma 4 sliding-attention) + cls(seq_len=4, head_dim=256), + cls(seq_len=16, head_dim=256), + # head_dim=512 (Gemma 4 31B full-attention) + cls(n_heads=4, seq_len=4, head_dim=512), + cls(n_heads=4, seq_len=64, head_dim=512), + ] + + def create_model(self) -> nn.Module: + return TQDequantModel() + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + # Random packed bytes exercise every codebook entry. + packed = torch.randint( + 0, + 256, + (self.batch_size, self.n_heads, self.seq_len, self.half_dim), + dtype=torch.uint8, + ) + norms = ( + torch.randn( + self.batch_size, + self.n_heads, + self.seq_len, + 1, + dtype=torch.bfloat16, + ).abs() + + 0.1 + ) + # Deterministic codebook covering [-1, 1]. + centroids = torch.linspace(-1.0, 1.0, 16, dtype=torch.bfloat16) + return (packed, norms, centroids) + + +if __name__ == "__main__": # noqa: C901 + import argparse + import sys + + from executorch.backends.mlx.test.test_utils import rebuild_op_test_runner + + parser = argparse.ArgumentParser(description="Test mlx::tq_dequant op") + parser.add_argument("action", choices=["generate", "compare", "run", "list"]) + parser.add_argument("--verbose", "-v", action="store_true") + parser.add_argument("--rebuild", action="store_true") + parser.add_argument("--config", type=str, default=None) + args = parser.parse_args() + + if args.rebuild and not rebuild_op_test_runner(verbose=args.verbose): + sys.exit(1) + + configs = TQDequantTest.get_test_configs() + + if args.action == "list": + for cfg in configs: + print(f" {cfg.name}") + sys.exit(0) + + if args.config: + configs = [c for c in configs if c.name == args.config] + if not configs: + print(f"No config matching '{args.config}'") + sys.exit(1) + + passed = 0 + failed = 0 + failed_names: List[str] = [] + + for test in configs: + if args.action == "generate": + pte_path, _, _ = test.generate_test_files(verbose=args.verbose) + print(f"Generated: {pte_path}") + elif args.action == "compare": + actual_path = test.get_test_dir() / "actual_output.bin" + ok, msg = test.compare_with_actual(actual_path) + print(f"{'✓' if ok else '✗'} {test.name}: {msg}") + if ok: + passed += 1 + else: + failed += 1 + failed_names.append(test.name) + elif args.action == "run": + ok = test.run_test(verbose=args.verbose) + if ok: + passed += 1 + else: + failed += 1 + failed_names.append(test.name) + + if args.action in ("run", "compare"): + print(f"\nPassed: {passed}, Failed: {failed}") + if failed_names: + print(f"Failed: {', '.join(failed_names)}") + sys.exit(0 if failed == 0 else 1) diff --git a/backends/mlx/model_ops/test_tq_norm.py b/backends/mlx/model_ops/test_tq_norm.py new file mode 100644 index 00000000000..35c4491d8ae --- /dev/null +++ b/backends/mlx/model_ops/test_tq_norm.py @@ -0,0 +1,150 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Tests for ``mlx::tq_norm``. + +Verifies the fused L2-norm Metal kernel matches eager ``vector_norm`` +at head_dim values used by TurboQuant (D ∈ {128, 256, 512}). + +Usage:: + + python -m executorch.backends.mlx.model_ops.test_tq_norm run + python -m executorch.backends.mlx.model_ops.test_tq_norm run -v + python -m executorch.backends.mlx.model_ops.test_tq_norm run --rebuild +""" + +from typing import List, Tuple + +import executorch.backends.mlx.model_ops.tq_norm # noqa: F401 + +import torch +import torch.nn as nn + +from executorch.backends.mlx.test.test_utils import OpTestCase + + +class TQNormModel(nn.Module): + """``x → ||x||₂`` over the last dim.""" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.ops.mlx.tq_norm(x) + + +class TQNormTest(OpTestCase): + """Compare ``mlx::tq_norm`` to eager ``vector_norm`` within bf16 ULPs.""" + + name = "tq_norm" + rtol = 1e-2 + atol = 1e-2 + + def __init__( + self, + batch_size: int = 1, + n_heads: int = 8, + seq_len: int = 4, + head_dim: int = 128, + ): + self.batch_size = batch_size + self.n_heads = n_heads + self.seq_len = seq_len + self.head_dim = head_dim + self.name = f"tq_norm_b{batch_size}_h{n_heads}_t{seq_len}_d{head_dim}" + + @classmethod + def get_test_configs(cls) -> List["TQNormTest"]: + return [ + # head_dim=128 (Qwen3.5 MoE / Gemma 4 sliding) + cls(seq_len=1, head_dim=128), + cls(seq_len=8, head_dim=128), + cls(seq_len=64, head_dim=128), + cls(n_heads=1, seq_len=1, head_dim=128), + # head_dim=256 (Gemma 4 sliding-attention) + cls(seq_len=4, head_dim=256), + cls(seq_len=16, head_dim=256), + # head_dim=512 (Gemma 4 31B full-attention) + cls(n_heads=4, seq_len=4, head_dim=512), + cls(n_heads=4, seq_len=64, head_dim=512), + ] + + def create_model(self) -> nn.Module: + return TQNormModel().to(torch.bfloat16) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + # Activation-scale bf16 inputs. + x = torch.randn( + self.batch_size, + self.n_heads, + self.seq_len, + self.head_dim, + dtype=torch.bfloat16, + ) * (1.0 / (self.head_dim**0.5)) + return (x,) + + +if __name__ == "__main__": # noqa: C901 + import argparse + import sys + + from executorch.backends.mlx.test.test_utils import rebuild_op_test_runner + + parser = argparse.ArgumentParser(description="Test mlx::tq_norm op") + parser.add_argument( + "action", + choices=["generate", "compare", "run", "list"], + ) + parser.add_argument("--verbose", "-v", action="store_true") + parser.add_argument("--rebuild", action="store_true") + parser.add_argument("--config", type=str, default=None) + args = parser.parse_args() + + if args.rebuild and not rebuild_op_test_runner(verbose=args.verbose): + sys.exit(1) + + configs = TQNormTest.get_test_configs() + + if args.action == "list": + for cfg in configs: + print(f" {cfg.name}") + sys.exit(0) + + if args.config: + configs = [c for c in configs if c.name == args.config] + if not configs: + print(f"No config matching '{args.config}'") + sys.exit(1) + + passed = 0 + failed = 0 + failed_names: List[str] = [] + + for test in configs: + if args.action == "generate": + pte_path, _, _ = test.generate_test_files(verbose=args.verbose) + print(f"Generated: {pte_path}") + elif args.action == "compare": + actual_path = test.get_test_dir() / "actual_output.bin" + ok, msg = test.compare_with_actual(actual_path) + print(f"{'✓' if ok else '✗'} {test.name}: {msg}") + if ok: + passed += 1 + else: + failed += 1 + failed_names.append(test.name) + elif args.action == "run": + ok = test.run_test(verbose=args.verbose) + if ok: + passed += 1 + else: + failed += 1 + failed_names.append(test.name) + + if args.action in ("run", "compare"): + print(f"\nPassed: {passed}, Failed: {failed}") + if failed_names: + print(f"Failed: {', '.join(failed_names)}") + sys.exit(0 if failed == 0 else 1) diff --git a/backends/mlx/model_ops/tq4_compress.py b/backends/mlx/model_ops/tq4_compress.py new file mode 100644 index 00000000000..f08d47b9a11 --- /dev/null +++ b/backends/mlx/model_ops/tq4_compress.py @@ -0,0 +1,189 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# + +""" +``mlx::tq4_compress``: TurboQuant TQ4 quantize + nibble-pack. + +Maps ``(..., D)`` floats to ``(..., D/2)`` uint8 by: + 1. Bucketizing each value against ``boundaries`` (15 sorted thresholds). + 2. Packing pairs of 4-bit indices into one byte: high nibble holds + the even-position index, low nibble holds the odd-position index. + +Constraints: + * ``boundaries`` must be 1-D length 15 (4-bit codebook). + * Last dim of ``values`` must be even and statically known. + +Usage:: + + import executorch.backends.mlx.model_ops.tq4_compress # noqa: F401 + + packed = torch.ops.mlx.tq4_compress(rotated, boundaries) + # rotated: (..., D) float + # boundaries: (15,) same dtype as rotated + # packed: (..., D/2) uint8 +""" + +from __future__ import annotations + +import torch +from torch import Tensor +from torch.fx.node import Node + + +@torch.library.custom_op("mlx::tq4_compress", mutates_args=()) +def tq4_compress(values: Tensor, boundaries: Tensor) -> Tensor: + """TurboQuant TQ4 quantize + nibble-pack. + + Args: + values: ``(..., D)`` float, last dim must be even. + boundaries: ``(15,)`` 1-D sorted, same dtype as ``values``. + + Returns: + ``(..., D/2)`` uint8. Each byte holds two 4-bit indices: high + nibble is the even-position index, low nibble is the odd. + """ + if boundaries.dim() != 1 or boundaries.shape[0] != 15: + raise ValueError( + f"mlx::tq4_compress: boundaries must be 1-D length 15; " + f"got shape {tuple(boundaries.shape)}" + ) + if values.shape[-1] % 2 != 0: + raise ValueError( + f"mlx::tq4_compress: input last dim must be even; got " + f"{values.shape[-1]}" + ) + + indices = torch.bucketize(values, boundaries).to(torch.uint8) + packed = (indices[..., 0::2] << 4) | indices[..., 1::2] + return packed + + +@torch.library.register_fake("mlx::tq4_compress") +def tq4_compress_fake(values: Tensor, boundaries: Tensor) -> Tensor: + out_shape = list(values.shape) + out_shape[-1] = out_shape[-1] // 2 + return values.new_empty(out_shape, dtype=torch.uint8) + + +# --------------------------------------------------------------------------- +# MLX handler +# --------------------------------------------------------------------------- + +from executorch.backends.mlx.builder.op_helpers import ( + emit_product, + emit_shape, + torch_dtype_to_scalar_type, +) +from executorch.backends.mlx.builder.op_registry import REGISTRY +from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder +from executorch.backends.mlx.builder.slot_manager import Slot +from executorch.backends.mlx.serialization.mlx_graph_schema import ( + IntOrVid, + MetalKernelNode, +) + + +# One thread per output byte: reads ``values[2*gid]``, ``values[2*gid+1]``, +# bucketizes against the 15 boundaries (loop unrolled, ``B`` is a template +# constant), and packs the two 4-bit indices into one byte. +_TQ4_COMPRESS_SOURCE = """ + uint gid = thread_position_in_grid.x; + float v_hi = float(values[2 * gid]); + float v_lo = float(values[2 * gid + 1]); + uchar idx_hi = 0; + uchar idx_lo = 0; + #pragma unroll + for (uint i = 0; i < B; ++i) { + float bnd = float(boundaries[i]); + idx_hi += (uchar)(v_hi > bnd); + idx_lo += (uchar)(v_lo > bnd); + } + out[gid] = (idx_hi << 4) | idx_lo; +""" + + +@REGISTRY.register(target=[torch.ops.mlx.tq4_compress.default]) +def _tq4_compress_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Lower ``mlx::tq4_compress`` to a fused Metal kernel.""" + args = P.args(n) + if len(args) != 2: + raise ValueError( + f"mlx::tq4_compress: expected 2 args (values, boundaries), " + f"got {len(args)}" + ) + + values_slot, boundaries_slot = args + values_node = n.args[0] + boundaries_node = n.args[1] + + values_meta = values_node.meta["val"] + boundaries_meta = boundaries_node.meta["val"] + + # Validate boundaries length: must be 15 for 4-bit nibble pack. + bnd_shape = boundaries_meta.shape + if ( + len(bnd_shape) != 1 + or not isinstance(bnd_shape[0], int) + or int(bnd_shape[0]) != 15 + ): + raise ValueError( + f"mlx::tq4_compress: boundaries must be 1-D length 15; " + f"got shape {tuple(bnd_shape)}" + ) + + last_dim = values_meta.shape[-1] + if not isinstance(last_dim, int): + raise NotImplementedError( + "mlx::tq4_compress: last dim must be statically known" + ) + if int(last_dim) % 2 != 0: + raise ValueError(f"mlx::tq4_compress: last dim must be even; got {last_dim}") + half_last = int(last_dim) // 2 + + in_dtype_int = torch_dtype_to_scalar_type(values_meta.dtype) + + out = P.make_or_get_slot(n) + leading = emit_shape(P, values_node, values_slot, end_dim=-1) + half_last_iov = IntOrVid.from_literal(half_last) + out_shape_flat = leading + [half_last_iov] + + # One thread per output byte, so the grid size is the output numel + # (product of leading dims times the halved last dim). + n_out_iov = emit_product(P, leading + [half_last_iov]) + + P.emit( + MetalKernelNode( + name="tq4_compress", + source=_TQ4_COMPRESS_SOURCE, + inputs=[ + P.slot_to_tid(values_slot), + P.slot_to_tid(boundaries_slot), + ], + outputs=[P.slot_to_tid(out)], + grid=[n_out_iov, IntOrVid.from_literal(1), IntOrVid.from_literal(1)], + # 32 threads per threadgroup so each TG fills one Apple-GPU SIMD group + threadgroup=[ + IntOrVid.from_literal(32), + IntOrVid.from_literal(1), + IntOrVid.from_literal(1), + ], + input_names=["values", "boundaries"], + output_names=["out"], + output_shapes_flat=out_shape_flat, + output_shape_lengths=[len(out_shape_flat)], + output_dtypes=[torch_dtype_to_scalar_type(torch.uint8)], + template_arg_names=["InT", "B"], + template_arg_kinds=[2, 0], # 2=dtype, 0=int + template_arg_values=[ + in_dtype_int, + 15, + ], + ) + ) + + return out diff --git a/backends/mlx/model_ops/tq_dequant.py b/backends/mlx/model_ops/tq_dequant.py new file mode 100644 index 00000000000..28a168e9be0 --- /dev/null +++ b/backends/mlx/model_ops/tq_dequant.py @@ -0,0 +1,216 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# + +""" +``mlx::tq_dequant``: TurboQuant TQ4 unpack + centroid gather + multiply-by-norm. + + indices = unpack 4-bit nibbles from packed bytes (..., D) + centvals = centroids[indices] (..., D) + out = centvals * norms (..., D) + +Output is in **rotated space** — the inverse rotation, if needed, is +left to the caller (typically MLX's tuned bf16 GEMM). + +Constraints: + * ``D`` (= ``packed.shape[-1] * 2``) must be a multiple of 64. + * ``centroids`` must be a 1-D tensor of length 16. + * Output dtype matches ``norms.dtype``. + +Usage:: + + import executorch.backends.mlx.model_ops.tq_dequant # noqa: F401 + + out = torch.ops.mlx.tq_dequant(packed, norms, centroids) + # packed: (..., D/2) uint8 + # norms: (..., 1) bf16 + # centroids: (16,) bf16 + # out: (..., D) bf16 (in rotated space) +""" + +from __future__ import annotations + +import torch +from torch import Tensor +from torch.fx.node import Node + + +# --------------------------------------------------------------------------- +# Custom op + eager fallback +# --------------------------------------------------------------------------- + + +@torch.library.custom_op("mlx::tq_dequant", mutates_args=()) +def tq_dequant( + packed: Tensor, + norms: Tensor, + centroids: Tensor, +) -> Tensor: + """Fused unpack + centroid gather + multiply-by-norm. + + Args: + packed: ``(..., D/2)`` uint8. High nibble = even-position index, + low nibble = odd-position index. + norms: ``(..., 1)`` of compute dtype, broadcasts over D. + centroids: ``(16,)`` of compute dtype. + + Returns: + ``(..., D)`` of compute dtype, in rotated space. + """ + if centroids.dim() != 1 or centroids.shape[0] != 16: + raise ValueError( + f"mlx::tq_dequant: centroids must be 1-D length 16; got " + f"shape {tuple(centroids.shape)}" + ) + high = (packed >> 4).long() + low = (packed & 0x0F).long() + indices = torch.stack([high, low], dim=-1).reshape( + *packed.shape[:-1], packed.shape[-1] * 2 + ) + return centroids[indices] * norms + + +@torch.library.register_fake("mlx::tq_dequant") +def tq_dequant_fake(packed: Tensor, norms: Tensor, centroids: Tensor) -> Tensor: + out_shape = list(packed.shape) + out_shape[-1] = out_shape[-1] * 2 + return packed.new_empty(out_shape, dtype=norms.dtype) + + +# --------------------------------------------------------------------------- +# MLX handler +# --------------------------------------------------------------------------- + +from executorch.backends.mlx.builder.op_helpers import ( + emit_product, + emit_shape, + torch_dtype_to_scalar_type, +) +from executorch.backends.mlx.builder.op_registry import REGISTRY +from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder +from executorch.backends.mlx.builder.slot_manager import Slot +from executorch.backends.mlx.serialization.mlx_graph_schema import ( + IntOrVid, + MetalKernelNode, +) + + +_TQ_DEQUANT_HEADER = """ +#include +using namespace metal; +""" + + +# Per-vector decompress: +# * Grid (32, 1, M), threadgroup (32, 1, 1): one simdgroup per vector. +# * Each lane handles DIMS_PER_LANE = D/32 output values, sourced +# from BYTES_PER_LANE = DIMS_PER_LANE/2 packed bytes. +# * The 16-entry codebook is preloaded into per-lane registers. +_TQ_DEQUANT_SOURCE = """ + constexpr uint DIMS_PER_LANE = D / 32; + constexpr uint BYTES_PER_LANE = DIMS_PER_LANE / 2; + + uint vec_id = thread_position_in_grid.z; + uint lane_id = thread_position_in_threadgroup.x; + + InT cent[16]; + for (uint c = 0; c < 16; ++c) { + cent[c] = centroids[c]; + } + + InT norm = norms[vec_id]; + + uint packed_base = vec_id * (D / 2) + lane_id * BYTES_PER_LANE; + uint out_base = vec_id * D + lane_id * DIMS_PER_LANE; + + for (uint i = 0; i < BYTES_PER_LANE; ++i) { + uchar byte = packed[packed_base + i]; + uchar idx_hi = (byte >> 4) & 0x0F; + uchar idx_lo = byte & 0x0F; + out[out_base + 2 * i + 0] = cent[idx_hi] * norm; + out[out_base + 2 * i + 1] = cent[idx_lo] * norm; + } +""" + + +@REGISTRY.register(target=[torch.ops.mlx.tq_dequant.default]) +def _tq_dequant_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Lower ``mlx::tq_dequant`` to a single fused Metal kernel.""" + args = P.args(n) + if len(args) != 3: + raise ValueError( + f"mlx::tq_dequant: expected 3 args (packed, norms, centroids); " + f"got {len(args)}" + ) + packed_slot, norms_slot, centroids_slot = args + packed_node = n.args[0] + norms_node = n.args[1] + centroids_node = n.args[2] + + packed_meta = packed_node.meta["val"] + norms_meta = norms_node.meta["val"] + centroids_meta = centroids_node.meta["val"] + + if centroids_meta.dim() != 1 or int(centroids_meta.shape[0]) != 16: + raise ValueError( + f"mlx::tq_dequant: centroids must be 1-D length 16; got " + f"shape {tuple(centroids_meta.shape)}" + ) + + last_dim_packed = packed_meta.shape[-1] + if not isinstance(last_dim_packed, int): + raise NotImplementedError( + "mlx::tq_dequant: packed last dim must be statically known" + ) + half_D = int(last_dim_packed) + D = half_D * 2 + if D % 64 != 0: + raise NotImplementedError( + f"mlx::tq_dequant: unpacked dim must be a multiple of 64 " + f"(2 dims per packed byte, 32 SIMD lanes); got D={D}" + ) + + out_dtype_int = torch_dtype_to_scalar_type(norms_meta.dtype) + + out = P.make_or_get_slot(n) + leading = emit_shape(P, packed_node, packed_slot, end_dim=-1) + out_shape_flat = leading + [IntOrVid.from_literal(D)] + M_iov = emit_product(P, leading) + + P.emit( + MetalKernelNode( + name="tq_dequant", + source=_TQ_DEQUANT_SOURCE, + header=_TQ_DEQUANT_HEADER, + inputs=[ + P.slot_to_tid(packed_slot), + P.slot_to_tid(norms_slot), + P.slot_to_tid(centroids_slot), + ], + outputs=[P.slot_to_tid(out)], + grid=[ + IntOrVid.from_literal(32), + IntOrVid.from_literal(1), + M_iov, + ], + threadgroup=[ + IntOrVid.from_literal(32), + IntOrVid.from_literal(1), + IntOrVid.from_literal(1), + ], + input_names=["packed", "norms", "centroids"], + output_names=["out"], + output_shapes_flat=out_shape_flat, + output_shape_lengths=[len(out_shape_flat)], + output_dtypes=[out_dtype_int], + template_arg_names=["InT", "D"], + template_arg_kinds=[2, 0], # 2=dtype, 0=int + template_arg_values=[out_dtype_int, D], + ) + ) + + return out diff --git a/backends/mlx/model_ops/tq_norm.py b/backends/mlx/model_ops/tq_norm.py new file mode 100644 index 00000000000..7e6a4d657f3 --- /dev/null +++ b/backends/mlx/model_ops/tq_norm.py @@ -0,0 +1,170 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# + +""" +``mlx::tq_norm``: L2 norm along the last dim, lowered to a single Metal kernel. + + norms[..., 0] = sqrt(sum_i x[..., i]^2) + +Reads / writes ``x.dtype`` directly (no graph-level dtype casts). +Reduces in fp32 inside Metal registers via ``simd_sum`` for precision +on large ``D`` (bf16 sum-of-squares loses too much for D>=128). + +Constraints: + * Last dim ``D`` must be statically known and a multiple of 32. + +Usage:: + + import executorch.backends.mlx.model_ops.tq_norm # noqa: F401 + + norms = torch.ops.mlx.tq_norm(x) + # x: (..., D) bf16 + # norms: (..., 1) bf16, equal to vector_norm(x, dim=-1, keepdim=True) +""" + +from __future__ import annotations + +import torch +from torch import Tensor +from torch.fx.node import Node + + +# --------------------------------------------------------------------------- +# Custom op + eager fallback +# --------------------------------------------------------------------------- + + +@torch.library.custom_op("mlx::tq_norm", mutates_args=()) +def tq_norm(x: Tensor) -> Tensor: + """L2 norm along last dim. + + Args: + x: ``(..., D)``. For MLX lowering, ``D`` must be a multiple of 32. + + Returns: + ``(..., 1)`` of the same dtype as ``x``. + """ + return torch.linalg.vector_norm(x, dim=-1, keepdim=True).to(x.dtype) + + +@torch.library.register_fake("mlx::tq_norm") +def tq_norm_fake(x: Tensor) -> Tensor: + out_shape = list(x.shape) + out_shape[-1] = 1 + return x.new_empty(out_shape, dtype=x.dtype) + + +# --------------------------------------------------------------------------- +# MLX handler +# --------------------------------------------------------------------------- + +from executorch.backends.mlx.builder.op_helpers import ( + emit_product, + emit_shape, + torch_dtype_to_scalar_type, +) +from executorch.backends.mlx.builder.op_registry import REGISTRY +from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder +from executorch.backends.mlx.builder.slot_manager import Slot +from executorch.backends.mlx.serialization.mlx_graph_schema import ( + IntOrVid, + MetalKernelNode, +) + + +_TQ_NORM_HEADER = """ +#include +using namespace metal; +""" + + +# Per-vector reduction: +# * Grid (32, 1, M), threadgroup (32, 1, 1): one simdgroup per vector. +# * Each lane covers DIMS_PER_LANE = D/32 elements; partial sums are +# accumulated in an fp32 register. +# * ``simd_sum`` reduces across the 32 lanes; lane 0 sqrts and writes. +_TQ_NORM_SOURCE = """ + constexpr uint DIMS_PER_LANE = D / 32; + + uint vec_id = thread_position_in_grid.z; + uint lane_id = thread_position_in_threadgroup.x; + + uint base = vec_id * D + lane_id * DIMS_PER_LANE; + + float local_sum_sq = 0.0f; + for (uint i = 0; i < DIMS_PER_LANE; ++i) { + float v = float(x[base + i]); + local_sum_sq += v * v; + } + + float total_sum_sq = simd_sum(local_sum_sq); + + if (lane_id == 0) { + norms[vec_id] = (InT)sqrt(total_sum_sq); + } +""" + + +@REGISTRY.register(target=[torch.ops.mlx.tq_norm.default]) +def _tq_norm_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Lower ``mlx::tq_norm`` to a single fused Metal kernel.""" + args = P.args(n) + if len(args) != 1: + raise ValueError(f"mlx::tq_norm: expected 1 arg (x), got {len(args)}") + + (x_slot,) = args + x_node = n.args[0] + + x_meta = x_node.meta["val"] + + last_dim = x_meta.shape[-1] + if not isinstance(last_dim, int): + raise NotImplementedError("mlx::tq_norm: last dim must be statically known") + D = int(last_dim) + if D % 32 != 0: + raise NotImplementedError( + f"mlx::tq_norm: last dim must be a multiple of 32 (one per " + f"SIMD lane); got D={D}" + ) + + in_dtype_int = torch_dtype_to_scalar_type(x_meta.dtype) + + out = P.make_or_get_slot(n) + leading = emit_shape(P, x_node, x_slot, end_dim=-1) + out_shape_flat = leading + [IntOrVid.from_literal(1)] + M_iov = emit_product(P, leading) + + P.emit( + MetalKernelNode( + name="tq_norm", + source=_TQ_NORM_SOURCE, + header=_TQ_NORM_HEADER, + inputs=[P.slot_to_tid(x_slot)], + outputs=[P.slot_to_tid(out)], + grid=[ + IntOrVid.from_literal(32), + IntOrVid.from_literal(1), + M_iov, + ], + threadgroup=[ + IntOrVid.from_literal(32), + IntOrVid.from_literal(1), + IntOrVid.from_literal(1), + ], + input_names=["x"], + output_names=["norms"], + output_shapes_flat=out_shape_flat, + output_shape_lengths=[len(out_shape_flat)], + output_dtypes=[in_dtype_int], + template_arg_names=["InT", "D"], + template_arg_kinds=[2, 0], # 2=dtype, 0=int + template_arg_values=[in_dtype_int, D], + ) + ) + + return out diff --git a/backends/mlx/patterns.py b/backends/mlx/patterns.py index 29e5e326c69..5f74cbea643 100644 --- a/backends/mlx/patterns.py +++ b/backends/mlx/patterns.py @@ -15,6 +15,7 @@ from __future__ import annotations +import os from typing import Any, List, Optional, Tuple import torch @@ -37,6 +38,7 @@ ) from executorch.backends.mlx.serialization.mlx_graph_schema import ( AddIntNode, + AddmmNode, AddNode, AsTypeNode, DequantizeNode, @@ -52,6 +54,7 @@ SubtractIntNode, SymSizeNode, TakeNode, + TransposeNode, ) from torch.export.exported_program import ExportedProgram from torch.fx.node import Node @@ -883,6 +886,18 @@ def maybe_create( out_dtype=out_dtype, ) + # MLX's quantized_matmul Metal kernels are only instantiated for + # group_size in {32, 64, 128}. For smaller group sizes (e.g. GGUF + # Q6_K with group_size=16), emit DequantizeNode + matmul instead. + # Weights stay packed in the .pte file; dequantized on-device. + # This non-fused path is significantly slower and must be opted in + # via ET_MLX_ALLOW_NON_FUSED_QUANTIZED_OPS=1. + _MIN_FUSED_GROUP_SIZE = 32 + + @staticmethod + def _allow_non_fused() -> bool: + return os.environ.get("ET_MLX_ALLOW_NON_FUSED_QUANTIZED_OPS", "0") == "1" + def __call__(self, P: MLXProgramBuilder, n: Node) -> Slot: assert n == self.head @@ -908,19 +923,59 @@ def __call__(self, P: MLXProgramBuilder, n: Node) -> Slot: x_dtype = x_node.meta["val"].dtype needs_cast = self.out_dtype != x_dtype - P.emit( - QuantizedMatmulNode( - x=P.slot_to_tid(x_slot), - w=P.slot_to_tid(w), - scales=P.slot_to_tid(scale_slot), - out=P.slot_to_tid(out), - biases=P.slot_to_tid(biases), - group_size=self.group_size, - bits=self.bits, - mode="affine", - transpose=True, + if self.group_size >= self._MIN_FUSED_GROUP_SIZE: + P.emit( + QuantizedMatmulNode( + x=P.slot_to_tid(x_slot), + w=P.slot_to_tid(w), + scales=P.slot_to_tid(scale_slot), + out=P.slot_to_tid(out), + biases=P.slot_to_tid(biases), + group_size=self.group_size, + bits=self.bits, + mode="affine", + transpose=True, + ) ) - ) + else: + if not self._allow_non_fused(): + raise ValueError( + f"Quantized linear with group_size={self.group_size} requires " + f"the non-fused dequantize+matmul path, which is significantly " + f"slower than the fused QuantizedMatmulNode (group_size >= 32). " + f"Set ET_MLX_ALLOW_NON_FUSED_QUANTIZED_OPS=1 to allow this." + ) + out_scalar_type = torch_dtype_to_scalar_type(self.out_dtype) + _, w_deq = P.make_tmp_slot() + P.emit( + DequantizeNode( + w=P.slot_to_tid(w), + scales=P.slot_to_tid(scale_slot), + out=P.slot_to_tid(w_deq), + biases=P.slot_to_tid(biases), + group_size=self.group_size, + bits=self.bits, + mode="affine", + dtype=out_scalar_type, + ) + ) + _, w_t = P.make_tmp_slot() + P.emit( + TransposeNode( + x=P.slot_to_tid(w_deq), + out=P.slot_to_tid(w_t), + perm=[1, 0], + ) + ) + P.emit( + AddmmNode( + mat1=P.slot_to_tid(x_slot), + mat2=P.slot_to_tid(w_t), + out=P.slot_to_tid(out), + ) + ) + # DequantizeNode already produces the correct dtype. + needs_cast = False if has_bias: P.emit( diff --git a/backends/mlx/test/op_test_runner.cpp b/backends/mlx/test/op_test_runner.cpp index 6bed13d7a56..925ff410f42 100644 --- a/backends/mlx/test/op_test_runner.cpp +++ b/backends/mlx/test/op_test_runner.cpp @@ -58,6 +58,7 @@ enum class DType : uint32_t { Int64 = 3, BFloat16 = 4, Bool = 5, + UInt8 = 6, }; size_t dtype_size(DType dtype) { @@ -74,6 +75,8 @@ size_t dtype_size(DType dtype) { return 2; case DType::Bool: return 1; + case DType::UInt8: + return 1; default: return 4; } @@ -93,6 +96,8 @@ exec_aten::ScalarType dtype_to_scalar_type(DType dtype) { return exec_aten::ScalarType::BFloat16; case DType::Bool: return exec_aten::ScalarType::Bool; + case DType::UInt8: + return exec_aten::ScalarType::Byte; default: return exec_aten::ScalarType::Float; } @@ -112,6 +117,8 @@ DType scalar_type_to_dtype(exec_aten::ScalarType stype) { return DType::BFloat16; case exec_aten::ScalarType::Bool: return DType::Bool; + case exec_aten::ScalarType::Byte: + return DType::UInt8; default: return DType::Float32; } @@ -316,6 +323,11 @@ int main(int argc, char* argv[]) { std::memcpy(data.data(), t.data.data(), t.data.size()); tensor_ptr = make_tensor_ptr( sizes, std::move(data), {}, {}, exec_aten::ScalarType::Bool); + } else if (t.dtype == DType::UInt8) { + std::vector data(t.data.size()); + std::memcpy(data.data(), t.data.data(), t.data.size()); + tensor_ptr = make_tensor_ptr( + sizes, std::move(data), {}, {}, exec_aten::ScalarType::Byte); } else { std::cerr << "Unsupported dtype: " << static_cast(t.dtype) << std::endl; diff --git a/backends/mlx/test/test_ops.py b/backends/mlx/test/test_ops.py index 4471610519e..ec80b1d3911 100644 --- a/backends/mlx/test/test_ops.py +++ b/backends/mlx/test/test_ops.py @@ -24,6 +24,7 @@ See README.md in this directory for full documentation. """ +import os from typing import Callable, Dict, List, Optional, Tuple import torch @@ -2235,6 +2236,402 @@ def get_dynamic_shapes(self) -> Optional[Dict[str, any]]: } +from executorch.backends.mlx.llm.turboquant_cache import TurboQuantKVCache + + +class TurboQuantKVCacheModel(nn.Module): + """ + Test model wrapping TurboQuantKVCache.update(). + + TurboQuantKVCache stores K/V in rotated 4-bit packed form. ``update`` + returns the four cache buffers (k_packed, k_norms, v_packed, v_norms) + rather than uncompressed K/V. + """ + + def __init__( + self, + max_batch_size: int, + max_context_length: int, + n_heads: int, + head_dim: int, + enable_dynamic_shape: bool = True, + ): + super().__init__() + self.cache = TurboQuantKVCache( + max_batch_size=max_batch_size, + max_context_length=max_context_length, + n_heads=n_heads, + head_dim=head_dim, + enable_dynamic_shape=enable_dynamic_shape, + ) + + def forward( + self, + input_pos: torch.Tensor, + k_val: torch.Tensor, + v_val: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + return self.cache.update(input_pos, k_val, v_val) + + +@register_test +class TurboQuantKVCacheTest(OpTestCase): + """ + Test case for TurboQuantKVCache with tensor input_pos. + + Verifies eager-vs-MLX consistency for the compress + write path + (``mlx::tq_norm``, ``mlx::tq4_compress``, ``mlx::kv_cache_update``). + The packed cache is uint8 (byte-exact), norms are bf16 (loose tol). + """ + + name = "turboquant_kv_cache" + # uint8 packed cache stays effectively exact under atol<1; bf16 + # norms need ~1e-1 absolute slack for the eager-vs-MLX bf16 path. + rtol = 1e-5 + atol = 1e-1 + + def __init__( + self, + n_heads: int = 4, + head_dim: int = 64, + max_context_length: int = 128, + seq_step: int = 8, + enable_dynamic_shape: bool = True, + ): + # TurboQuantKVCache requires batch=1. + self.max_batch_size = 1 + self.n_heads = n_heads + self.head_dim = head_dim + self.max_context_length = max_context_length + self.seq_step = seq_step + self.enable_dynamic_shape = enable_dynamic_shape + + @classmethod + def get_test_configs(cls) -> List["TurboQuantKVCacheTest"]: + return [ + cls(), # default: head_dim=64 (smallest valid) + cls(head_dim=128), + cls(enable_dynamic_shape=False), + ] + + def create_model(self) -> nn.Module: + return TurboQuantKVCacheModel( + max_batch_size=self.max_batch_size, + max_context_length=self.max_context_length, + n_heads=self.n_heads, + head_dim=self.head_dim, + enable_dynamic_shape=self.enable_dynamic_shape, + ) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + input_pos = torch.tensor([0], dtype=torch.int64) + k_val = torch.randn( + self.max_batch_size, + self.n_heads, + self.seq_step, + self.head_dim, + dtype=torch.bfloat16, + ) + v_val = torch.randn( + self.max_batch_size, + self.n_heads, + self.seq_step, + self.head_dim, + dtype=torch.bfloat16, + ) + return (input_pos, k_val, v_val) + + def create_test_inputs(self) -> Tuple[torch.Tensor, ...]: + # With static shape, test inputs must match the exported seq length. + test_seq_step = ( + self.seq_step if not self.enable_dynamic_shape else self.seq_step + 4 + ) + input_pos = torch.tensor([16], dtype=torch.int64) + k_val = torch.randn( + self.max_batch_size, + self.n_heads, + test_seq_step, + self.head_dim, + dtype=torch.bfloat16, + ) + v_val = torch.randn( + self.max_batch_size, + self.n_heads, + test_seq_step, + self.head_dim, + dtype=torch.bfloat16, + ) + return (input_pos, k_val, v_val) + + def get_dynamic_shapes(self) -> Optional[Dict[str, any]]: + if not self.enable_dynamic_shape: + return None + seq_dim = Dim("seq_step", min=1, max=self.max_context_length) + return { + "input_pos": None, + "k_val": {2: seq_dim}, + "v_val": {2: seq_dim}, + } + + +class TurboQuantKVCacheIntModel(nn.Module): + """ + Test model that passes int/SymInt (not tensor) to + ``TurboQuantKVCache.update`` — the multi-layer pattern. + """ + + def __init__( + self, + max_batch_size: int, + max_context_length: int, + n_heads: int, + head_dim: int, + enable_dynamic_shape: bool = True, + ): + super().__init__() + self.cache = TurboQuantKVCache( + max_batch_size=max_batch_size, + max_context_length=max_context_length, + n_heads=n_heads, + head_dim=head_dim, + enable_dynamic_shape=enable_dynamic_shape, + ) + + def forward( + self, + input_pos: torch.Tensor, + k_val: torch.Tensor, + v_val: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + start_pos = input_pos[0].item() + return self.cache.update(start_pos, k_val, v_val) + + +@register_test +class TurboQuantKVCacheIntTest(OpTestCase): + """Test case for TurboQuantKVCache with int/SymInt input_pos.""" + + name = "turboquant_kv_cache_int" + rtol = 1e-5 + atol = 1e-1 + + def __init__( + self, + n_heads: int = 4, + head_dim: int = 64, + max_context_length: int = 128, + seq_step: int = 8, + enable_dynamic_shape: bool = True, + ): + self.max_batch_size = 1 + self.n_heads = n_heads + self.head_dim = head_dim + self.max_context_length = max_context_length + self.seq_step = seq_step + self.enable_dynamic_shape = enable_dynamic_shape + + @classmethod + def get_test_configs(cls) -> List["TurboQuantKVCacheIntTest"]: + return [ + cls(), + cls(head_dim=128), + ] + + def create_model(self) -> nn.Module: + return TurboQuantKVCacheIntModel( + max_batch_size=self.max_batch_size, + max_context_length=self.max_context_length, + n_heads=self.n_heads, + head_dim=self.head_dim, + enable_dynamic_shape=self.enable_dynamic_shape, + ) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + input_pos = torch.tensor([0], dtype=torch.int64) + k_val = torch.randn( + self.max_batch_size, + self.n_heads, + self.seq_step, + self.head_dim, + dtype=torch.bfloat16, + ) + v_val = torch.randn( + self.max_batch_size, + self.n_heads, + self.seq_step, + self.head_dim, + dtype=torch.bfloat16, + ) + return (input_pos, k_val, v_val) + + def create_test_inputs(self) -> Tuple[torch.Tensor, ...]: + test_seq_step = self.seq_step + 4 + input_pos = torch.tensor([16], dtype=torch.int64) + k_val = torch.randn( + self.max_batch_size, + self.n_heads, + test_seq_step, + self.head_dim, + dtype=torch.bfloat16, + ) + v_val = torch.randn( + self.max_batch_size, + self.n_heads, + test_seq_step, + self.head_dim, + dtype=torch.bfloat16, + ) + return (input_pos, k_val, v_val) + + def get_dynamic_shapes(self) -> Optional[Dict[str, any]]: + if not self.enable_dynamic_shape: + return None + seq_dim = Dim("seq_step", min=1, max=self.max_context_length) + return { + "input_pos": None, + "k_val": {2: seq_dim}, + "v_val": {2: seq_dim}, + } + + +class TurboQuantKVCacheSdpaModel(nn.Module): + """ + Test model wrapping ``TurboQuantKVCache.update + .sdpa`` — the full + prefill/decode flow (compress, dequant, attention in rotated space, + un-rotate output). + """ + + def __init__( + self, + max_batch_size: int, + max_context_length: int, + n_heads: int, + head_dim: int, + enable_dynamic_shape: bool = True, + ): + super().__init__() + self.max_context_length = max_context_length + self.cache = TurboQuantKVCache( + max_batch_size=max_batch_size, + max_context_length=max_context_length, + n_heads=n_heads, + head_dim=head_dim, + enable_dynamic_shape=enable_dynamic_shape, + ) + + def forward( + self, + input_pos: torch.Tensor, + k_val: torch.Tensor, + v_val: torch.Tensor, + query: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + start_pos = input_pos[0].item() + seq_len = k_val.size(2) + torch._check(start_pos >= 0) + torch._check(start_pos + seq_len <= self.max_context_length) + + k_packed, k_norms, v_packed, v_norms = self.cache.update( + start_pos, k_val, v_val + ) + out = self.cache.sdpa(query, start_pos) + return out, k_packed, k_norms, v_packed, v_norms + + +@register_test +class TurboQuantKVCacheSdpaTest(OpTestCase): + """ + Test case for ``TurboQuantKVCache.update`` + ``.sdpa``. + + Exercises the full forward path: compress + write through + ``mlx::tq_norm`` / ``mlx::tq4_compress`` / ``mlx::kv_cache_update``, + then dequantize and attend via ``mlx::tq_dequant`` / + ``mlx::custom_sdpa`` with Q rotated in and output rotated back. + Looser tolerance is needed because attention runs in bf16. + """ + + name = "turboquant_kv_cache_sdpa" + rtol = 1e-5 + atol = 5e-2 # bf16 SDPA output + + def __init__( + self, + n_heads: int = 4, + head_dim: int = 64, + max_context_length: int = 128, + seq_step: int = 8, + enable_dynamic_shape: bool = True, + ): + self.max_batch_size = 1 + self.n_heads = n_heads + self.head_dim = head_dim + self.max_context_length = max_context_length + self.seq_step = seq_step + self.enable_dynamic_shape = enable_dynamic_shape + + @classmethod + def get_test_configs(cls) -> List["TurboQuantKVCacheSdpaTest"]: + return [ + cls(), + cls(head_dim=128), + ] + + def create_model(self) -> nn.Module: + return TurboQuantKVCacheSdpaModel( + max_batch_size=self.max_batch_size, + max_context_length=self.max_context_length, + n_heads=self.n_heads, + head_dim=self.head_dim, + enable_dynamic_shape=self.enable_dynamic_shape, + ) + + def _make_inputs( + self, start: int, q_len: int, kv_len: int + ) -> Tuple[torch.Tensor, ...]: + input_pos = torch.tensor([start], dtype=torch.int64) + k_val = torch.randn( + self.max_batch_size, + self.n_heads, + kv_len, + self.head_dim, + dtype=torch.bfloat16, + ) + v_val = torch.randn( + self.max_batch_size, + self.n_heads, + kv_len, + self.head_dim, + dtype=torch.bfloat16, + ) + query = torch.randn( + self.max_batch_size, + self.n_heads, + q_len, + self.head_dim, + dtype=torch.bfloat16, + ) + return (input_pos, k_val, v_val, query) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + # Prefill-style: start=0, q_len == kv_len. + return self._make_inputs(start=0, q_len=self.seq_step, kv_len=self.seq_step) + + def create_test_inputs(self) -> Tuple[torch.Tensor, ...]: + # Decode-style: write a single token into the existing cache. + return self._make_inputs(start=16, q_len=1, kv_len=1) + + def get_dynamic_shapes(self) -> Optional[Dict[str, any]]: + if not self.enable_dynamic_shape: + return None + seq_dim = Dim("seq_step", min=1, max=self.max_context_length) + return { + "input_pos": None, + "k_val": {2: seq_dim}, + "v_val": {2: seq_dim}, + "query": {2: seq_dim}, + } + + class RingBufferKVCacheModel(nn.Module): """ Test model wrapping RingBufferKVCache from cache.py. @@ -5621,8 +6018,21 @@ def get_test_configs(cls) -> List["QuantizedLinearTest"]: cls(group_size=128), cls(qdtype=torch.int2), cls(qdtype=torch.int8), + # group_size=16: exercises the non-fused dequantize+matmul path + # (requires ET_MLX_ALLOW_NON_FUSED_QUANTIZED_OPS=1). + cls(qdtype=torch.int8, group_size=16), + cls(qdtype=torch.int4, group_size=16), + cls(qdtype=torch.int8, group_size=16, bias=False), ] + def generate_test_files(self, verbose=False): + if self.group_size < 32: + os.environ["ET_MLX_ALLOW_NON_FUSED_QUANTIZED_OPS"] = "1" + try: + return super().generate_test_files(verbose=verbose) + finally: + os.environ.pop("ET_MLX_ALLOW_NON_FUSED_QUANTIZED_OPS", None) + def create_model(self) -> nn.Module: model = LinearModel(self.in_features, self.out_features, bias=self.bias) model = model.to(self.dtype) diff --git a/backends/mlx/test/test_utils.py b/backends/mlx/test/test_utils.py index 660968195b7..5dbc35b824d 100644 --- a/backends/mlx/test/test_utils.py +++ b/backends/mlx/test/test_utils.py @@ -44,6 +44,7 @@ class TestTimeoutError(Exception): DTYPE_INT64 = 3 DTYPE_BFLOAT16 = 4 DTYPE_BOOL = 5 +DTYPE_UINT8 = 6 # Default tolerance presets for different data types. @@ -110,6 +111,7 @@ def torch_dtype_to_bin_dtype(dtype: torch.dtype) -> int: torch.int64: DTYPE_INT64, torch.bfloat16: DTYPE_BFLOAT16, torch.bool: DTYPE_BOOL, + torch.uint8: DTYPE_UINT8, } if dtype not in mapping: raise ValueError(f"Unsupported dtype: {dtype}") @@ -125,6 +127,7 @@ def bin_dtype_to_torch_dtype(dtype_val: int) -> torch.dtype: DTYPE_INT64: torch.int64, DTYPE_BFLOAT16: torch.bfloat16, DTYPE_BOOL: torch.bool, + DTYPE_UINT8: torch.uint8, } if dtype_val not in mapping: raise ValueError(f"Unknown dtype value: {dtype_val}") @@ -208,6 +211,7 @@ def load_tensors_from_bin(path: Union[str, Path]) -> List[torch.Tensor]: torch.int32: np.int32, torch.int64: np.int64, torch.bool: np.bool_, + torch.uint8: np.uint8, # bfloat16 needs special handling - read as uint16 } @@ -219,6 +223,7 @@ def load_tensors_from_bin(path: Union[str, Path]) -> List[torch.Tensor]: torch.int64: 8, torch.bfloat16: 2, torch.bool: 1, + torch.uint8: 1, } tensors = [] diff --git a/backends/nxp/backend/edge_helper.py b/backends/nxp/backend/edge_helper.py index 957b673bb6a..1ea86f589ac 100644 --- a/backends/nxp/backend/edge_helper.py +++ b/backends/nxp/backend/edge_helper.py @@ -318,7 +318,7 @@ def is_no_op_on_neutron(node: Node, parameters_mapping: dict[str, Parameter]) -> input_data = torch.rand(val.shape, dtype=val.dtype) * 10 - 5 args_with_random_data.append(input_data) - case list(): + case list() if any(isinstance(a, Node) for a in arg): # Lists of input nodes are not supported to keep the code simple. It is not crucial to support this # case as the affected operators are either not supported on Neutron, or are extremely unlikely to # be no-ops (e.g. GRU). One exception is `aten.cat`, which is explicitly supported above. diff --git a/backends/nxp/backend/ir/converter/node_converters/ops_converters/add_tensor_converter.py b/backends/nxp/backend/ir/converter/node_converters/ops_converters/add_tensor_converter.py index fd28b077b8a..673af19310f 100644 --- a/backends/nxp/backend/ir/converter/node_converters/ops_converters/add_tensor_converter.py +++ b/backends/nxp/backend/ir/converter/node_converters/ops_converters/add_tensor_converter.py @@ -3,6 +3,9 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import torch + +from executorch.backends.nxp.backend.data_format import NXP_NODE_FORMAT from executorch.backends.nxp.backend.ir.converter.node_converter import ( CustomDelegationOptions, NodeConverter, @@ -23,11 +26,33 @@ def _is_supported_on_target( parameters_mapping: dict[str, Parameter], custom_delegation_options: CustomDelegationOptions, ) -> bool: - if NodeConverter.uses_shape_broadcasting(node): - # Shape broadcasting may require the addition of `Transpose` ops during conversion. - return False + if custom_delegation_options.use_new_flow_neutron_c: + if not NodeConverter.at_least_one_input_shape_matches_the_output_shape( + node + ): + return False - return True + # If one input is in channel first and ranks of input tensors are not equal, we need to add Transposes + # Transpose is currently not supported for new flow + if any( + input_node.meta[NXP_NODE_FORMAT].is_channels_first() + for input_node in node.all_input_nodes + ) and NodeConverter._node_inputs_ranks_not_equal(node): + return False + + supported_types = [torch.int8, torch.uint8] + if not NodeConverter.uses_quantization_type_for_io( + node, supported_types, [0, 1], [0] + ): + return False + + return True + else: + if NodeConverter.uses_shape_broadcasting(node): + # Shape broadcasting may require the addition of `Transpose` ops during conversion. + return False + + return True @staticmethod def _is_supported_in_IR( @@ -43,12 +68,13 @@ def _is_supported_in_IR( return True - # add.Tensor Node format: (Tensor self, Tensor other, *, Scalar alpha=1) def convert(self, node: Node): - """Convert 'add_tensor' operator to TFLite 'add'.""" + """Convert 'add_tensor' operator to NeutronIR 'Add'. + The ExecuTorch schema is: + add.Tensor(Tensor self, Tensor other, Scalar alpha=1) + """ self.assert_convertible(node) - t_op = self._create_tflite_op_with_io_tensors(node) - t_op.builtin_options = add_options.Add() + self.builder.append_operators([t_op]) diff --git a/backends/nxp/backend/ir/converter/node_converters/ops_converters/max_pool2d_with_indices_converter.py b/backends/nxp/backend/ir/converter/node_converters/ops_converters/max_pool2d_with_indices_converter.py index 975aaf57625..b7e761c45e6 100644 --- a/backends/nxp/backend/ir/converter/node_converters/ops_converters/max_pool2d_with_indices_converter.py +++ b/backends/nxp/backend/ir/converter/node_converters/ops_converters/max_pool2d_with_indices_converter.py @@ -152,9 +152,7 @@ def _get_node_args( :return: Tuple of (kernel_size, stride, padding, dilation, ceil_mode). """ kernel_size = node.args[1] - stride = node.args[ - 2 - ] # The default value is equal to the kernel_size, so it is never empty here. + stride = try_get_arg(node, 2) or kernel_size padding = try_get_arg(node, 3) or (0, 0) dilation = try_get_arg(node, 4) or (1, 1) ceil_mode = try_get_arg(node, 5) or False diff --git a/backends/nxp/backend/ir/converter/node_converters/ops_converters/mean_dim_converter.py b/backends/nxp/backend/ir/converter/node_converters/ops_converters/mean_dim_converter.py index c4b828df39f..4ba56a6b755 100644 --- a/backends/nxp/backend/ir/converter/node_converters/ops_converters/mean_dim_converter.py +++ b/backends/nxp/backend/ir/converter/node_converters/ops_converters/mean_dim_converter.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. import torch + from executorch.backends.nxp.backend.data_format import NXP_NODE_FORMAT from executorch.backends.nxp.backend.ir.converter.conversion.translator import ( @@ -11,6 +12,7 @@ ) from executorch.backends.nxp.backend.ir.converter.node_converter import ( CustomDelegationOptions, + is_not_qdq_node, NodeConverter, ) from executorch.backends.nxp.backend.ir.converter.node_converters.shared.reduce_utils import ( @@ -21,10 +23,40 @@ ) from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec from torch.fx import Node +from torch.fx.passes.infra.partitioner import Partition from torch.nn import Parameter class MeanDimConverter(NodeConverter): + + @classmethod + def supports_partitioning_result( + cls, + node: Node, + partition_list: list[Partition], + custom_delegation_options: CustomDelegationOptions, + neutron_target_spec: NeutronTargetSpec, + parameters_mapping: dict[str, Parameter], + ) -> bool: + if custom_delegation_options.use_new_flow_neutron_c: + dim, keepdim = MeanDimConverter._get_attrs(node) + input_shape = node.args[0].meta["val"].shape + + is_alone_in_partition = cls.is_node_alone_in_partition( + node, partition_list, filter_fn=is_not_qdq_node + ) + + if ( + is_alone_in_partition + and keepdim + and all(input_shape[d] == 1 for d in dim) + ): + # The operator is a no-op, so the Neutron Converter will skip it. If it's the only node in the + # partition, the graph would end up empty. + return False + + return True + @staticmethod def _is_supported_on_target( node: Node, @@ -32,34 +64,49 @@ def _is_supported_on_target( parameters_mapping: dict[str, Parameter], custom_delegation_options: CustomDelegationOptions, ) -> bool: - keepdim = node.args[2] if len(node.args) >= 3 else False - rank = len(node.args[0].meta["val"].shape) - dim = [MeanDimConverter._to_pos_dim(d, rank) for d in node.args[1]] + if custom_delegation_options.use_new_flow_neutron_c: + # Requirements specified by the new Neutron flow documentation. + + if not NodeConverter.uses_quantization_type_for_io( + node, + supported_types=[torch.int8, torch.uint8], + input_indices=[0], + output_indices=[0], + ): + return False - if rank != 4 or not keepdim: - # neutron-converter/src/OperatorC/GlobalAvgPoolPlugin.cpp#74-77 - return False + return True - # The `mean.dim` gets converted to AveragePool by the NeutronConverter, so the channels must be a - # multiple of `num_macs`. - # neutron-converter/src/OperatorC/GlobalAvgPoolPlugin.cpp#59-85 - num_macs = neutron_target_spec.get_num_macs() - channels_dim = 1 if node.meta[NXP_NODE_FORMAT].is_channels_first() else -1 - if (node.meta["val"].shape[channels_dim] % num_macs) != 0: - return False + else: + # Requirements of the old Neutron flow. + rank = len(node.args[0].meta["val"].shape) + dim, keepdim = MeanDimConverter._get_attrs(node) + dim = [MeanDimConverter._to_pos_dim(d, rank) for d in dim] - # Neutron only supports reduction over the spatial dimensions H, W. - if node.meta[NXP_NODE_FORMAT].is_channels_first(): - # The input is NCHW. H and W are at indices 2 and 3. - if dim not in [[2, 3], [3, 2]]: + if rank != 4 or not keepdim: + # neutron-converter/src/OperatorC/GlobalAvgPoolPlugin.cpp#74-77 return False - else: - # The input is formatless. It can be considered as NHWC, as this is the way Neutron will look at - # the dimensions. So H and W are the middle dimensions. - if dim not in [[1, 2], [2, 1]]: + + # The `mean.dim` gets converted to AveragePool by the NeutronConverter, so the channels must be a + # multiple of `num_macs`. + # neutron-converter/src/OperatorC/GlobalAvgPoolPlugin.cpp#59-85 + num_macs = neutron_target_spec.get_num_macs() + channels_dim = 1 if node.meta[NXP_NODE_FORMAT].is_channels_first() else -1 + if (node.meta["val"].shape[channels_dim] % num_macs) != 0: return False - return True + # Neutron only supports reduction over the spatial dimensions H, W. + if node.meta[NXP_NODE_FORMAT].is_channels_first(): + # The input is NCHW. H and W are at indices 2 and 3. + if dim not in [[2, 3], [3, 2]]: + return False + else: + # The input is formatless. It can be considered as NHWC, as this is the way Neutron will look at + # the dimensions. So H and W are the middle dimensions. + if dim not in [[1, 2], [2, 1]]: + return False + + return True @staticmethod def _is_supported_in_IR( @@ -91,15 +138,29 @@ def _normalize_and_to_channel_last_dim(dim: list[int], rank: int) -> list[int]: perm = create_channels_last_to_channels_first_permutation(rank, True) dim = [perm[d] for d in dim] + # noinspection PyTypeChecker return dim - # Mean Dim Node format: (Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) + @staticmethod + def _get_attrs(node: Node) -> tuple[list[int], bool]: + dim = node.args[1] + keepdim = node.args[2] if len(node.args) >= 3 else False + return dim, keepdim + def convert(self, node: Node): - """Convert 'mean.dim' operator to TFLite 'Mean'.""" + """Convert the 'mean.dim' operator to NeutronIR 'Mean'. + The ExecuTorch schema is: + mean.dim( + Tensor self, + int[1]? dim, + bool keepdim=False, + *, + ScalarType? dtype=None + ) -> Tensor + """ self.assert_convertible(node) - dim = node.args[1] - keepdim = node.args[2] if len(node.args) >= 3 else False + dim, keepdim = self._get_attrs(node) t_op = self._create_tflite_op_with_io_tensors(node) t_op.builtin_options = mean_options.Mean(keepdim) diff --git a/backends/nxp/backend/ir/converter/node_converters/ops_converters/slice_tensor_converter.py b/backends/nxp/backend/ir/converter/node_converters/ops_converters/slice_tensor_converter.py index f2002cc311c..f5df822b6ad 100644 --- a/backends/nxp/backend/ir/converter/node_converters/ops_converters/slice_tensor_converter.py +++ b/backends/nxp/backend/ir/converter/node_converters/ops_converters/slice_tensor_converter.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. import numpy as np +import torch from executorch.backends.nxp.backend.data_format import NXP_NODE_FORMAT from executorch.backends.nxp.backend.edge_helper import input_tensor from executorch.backends.nxp.backend.ir.converter.conversion import translator @@ -31,6 +32,15 @@ def _is_supported_on_target( parameters_mapping: dict[str, Parameter], custom_delegation_options: CustomDelegationOptions, ) -> bool: + if custom_delegation_options.use_new_flow_neutron_c: + supported_types = [torch.int8, torch.uint8] + if not NodeConverter.uses_quantization_type_for_io( + node, supported_types, [0], [0] + ): + return False + + return True + input_shape = input_tensor(node, 0).shape dim = node.args[1] if node.args[0].meta[NXP_NODE_FORMAT].is_channels_first(): @@ -94,6 +104,23 @@ def _convert_to_slice(self, t_op, main_input, input_rank, dim, start, end) -> No size[dim] = max(end - start, 0) begin[dim] = start + # In the new Neutron flow, slicing can be done along any dim, so + # no additional `transpose` ops have to be added. + if self.context.custom_delegation_options.use_new_flow_neutron_c: + begin_tensor = self.builder.create_tensor_for_data( + np.asarray(begin, np.int32), "begin" + ) + size_tensor = self.builder.create_tensor_for_data( + np.asarray(size, np.int32), "size" + ) + + t_op.tmp_inputs = [main_input, begin_tensor, size_tensor] + t_op.builtin_options = slice_options.Slice() + ops = OpsList(middle_op=t_op) + + self.builder.append_operators(ops.flatten()) + return None + # We can slice only the channels dimension # So we swap the sliced dimension with the channels dimension begin[-1], begin[dim] = begin[dim], begin[-1] @@ -131,6 +158,10 @@ def _get_clipped_slice_args(node: Node) -> tuple[Dim, Start, End]: _, dim, start, end = node.args sliced_tensor_rank = input_shape[dim] + # convert numbering `from the end` to `from the beginning`, ie. normalize + end = end + sliced_tensor_rank if end < 0 else end + start = start + sliced_tensor_rank if start < 0 else start + end = int(np.clip(end, 0, sliced_tensor_rank)) start = int(np.clip(start, 0, sliced_tensor_rank)) diff --git a/backends/nxp/backend/ir/converter/node_converters/ops_converters/sub_tensor_converter.py b/backends/nxp/backend/ir/converter/node_converters/ops_converters/sub_tensor_converter.py index e97f4bf63c2..79dbcbcc012 100644 --- a/backends/nxp/backend/ir/converter/node_converters/ops_converters/sub_tensor_converter.py +++ b/backends/nxp/backend/ir/converter/node_converters/ops_converters/sub_tensor_converter.py @@ -3,6 +3,9 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import torch + +from executorch.backends.nxp.backend.data_format import NXP_NODE_FORMAT from executorch.backends.nxp.backend.ir.converter.node_converter import ( CustomDelegationOptions, NodeConverter, @@ -23,11 +26,33 @@ def _is_supported_on_target( parameters_mapping: dict[str, Parameter], custom_delegation_options: CustomDelegationOptions, ) -> bool: - if NodeConverter.uses_shape_broadcasting(node): - # Shape broadcasting may require the addition of `Transpose` ops during conversion. - return False + if custom_delegation_options.use_new_flow_neutron_c: + if not NodeConverter.at_least_one_input_shape_matches_the_output_shape( + node + ): + return False - return True + # If one input is in channel first and ranks of input tensors are not equal, we need to add Transposes + # Transpose is currently not supported for new flow + if any( + input_node.meta[NXP_NODE_FORMAT].is_channels_first() + for input_node in node.all_input_nodes + ) and NodeConverter._node_inputs_ranks_not_equal(node): + return False + + supported_types = [torch.int8, torch.uint8] + if not NodeConverter.uses_quantization_type_for_io( + node, supported_types, [0, 1], [0] + ): + return False + + return True + else: + if NodeConverter.uses_shape_broadcasting(node): + # Shape broadcasting may require the addition of `Transpose` ops during conversion. + return False + + return True @staticmethod def _is_supported_in_IR( @@ -45,9 +70,12 @@ def _is_supported_in_IR( return True - # sub.Tensor Node format: (Tensor self, Tensor other, *, Scalar alpha=1) def convert(self, node: Node): - """Convert 'sub_tensor' operator to NeutronIR 'Sub'.""" + """Convert 'sub_tensor' operator to NeutronIR 'Sub'. + The ExecuTorch schema is: + sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) + """ + self.assert_convertible(node) t_op = self._create_tflite_op_with_io_tensors(node) diff --git a/backends/nxp/backend/ir/converter/node_converters/ops_converters/tanh_converter.py b/backends/nxp/backend/ir/converter/node_converters/ops_converters/tanh_converter.py index 427865f8ee7..54192628e24 100644 --- a/backends/nxp/backend/ir/converter/node_converters/ops_converters/tanh_converter.py +++ b/backends/nxp/backend/ir/converter/node_converters/ops_converters/tanh_converter.py @@ -1,8 +1,10 @@ -# Copyright 2025 NXP +# Copyright 2025-2026 NXP # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import torch + from executorch.backends.nxp.backend.custom_delegation_options import ( CustomDelegationOptions, ) @@ -10,6 +12,8 @@ from executorch.backends.nxp.backend.ir.lib.tflite.BuiltinOperator import ( BuiltinOperator, ) + +from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec from torch.fx import Node from torch.nn import Parameter @@ -24,7 +28,33 @@ def _is_supported_in_IR( ) -> bool: return True + @staticmethod + def _is_supported_on_target( + node: Node, + neutron_target_spec: NeutronTargetSpec, + parameters_mapping: dict[str, Parameter], + custom_delegation_options: CustomDelegationOptions, + ) -> bool: + if custom_delegation_options.use_new_flow_neutron_c: + # Requirements specified by the new Neutron flow documentation. + + if not NodeConverter.uses_quantization_type_for_io( + node, + supported_types=[torch.int8, torch.uint8], + input_indices=[0], + output_indices=[0], + ): + return False + + return True + def convert(self, node: Node): + """Convert the `aten.tanh` operator to NeutronIR `Tanh`. + The ExecuTorch schema is: + tanh( + Tensor self + ) -> Tensor + """ self.assert_convertible(node) t_op = self._create_tflite_op_with_io_tensors(node) diff --git a/backends/nxp/backend/ir/converter/node_converters/ops_converters/upsample_bilinear2d_converter.py b/backends/nxp/backend/ir/converter/node_converters/ops_converters/upsample_bilinear2d_converter.py index 33d97dff642..1183ef494b5 100644 --- a/backends/nxp/backend/ir/converter/node_converters/ops_converters/upsample_bilinear2d_converter.py +++ b/backends/nxp/backend/ir/converter/node_converters/ops_converters/upsample_bilinear2d_converter.py @@ -4,11 +4,13 @@ # LICENSE file in the root directory of this source tree. import numpy as np +import torch from executorch.backends.nxp.backend.data_format import DataFormat, NXP_NODE_FORMAT from executorch.backends.nxp.backend.edge_helper import node_has_well_defined_shape from executorch.backends.nxp.backend.ir.converter.node_converter import ( CustomDelegationOptions, + is_not_qdq_node, NodeConverter, ) from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options.resize_bilinear_options import ( @@ -16,12 +18,35 @@ ) from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec from torch.fx import Node +from torch.fx.passes.infra.partitioner import Partition from torch.nn import Parameter # noinspection SpellCheckingInspection class UpsampleBilinear2DConverter(NodeConverter): + @classmethod + def supports_partitioning_result( + cls, + node: Node, + partition_list: list[Partition], + custom_delegation_options: CustomDelegationOptions, + neutron_target_spec: NeutronTargetSpec, + parameters_mapping: dict[str, Parameter], + ) -> bool: + input_shape = node.all_input_nodes[0].meta["val"].shape + output_shape = node.meta["val"].shape + is_alone_in_partition = cls.is_node_alone_in_partition( + node, partition_list, filter_fn=is_not_qdq_node + ) + + if is_alone_in_partition and input_shape == output_shape: + # The operator is a no-op, so the Neutron Converter will skip it. If it's the only node in the + # partition, the graph would end up empty. + return False + + return True + @staticmethod def _is_supported_in_IR( node: Node, @@ -36,6 +61,14 @@ def _is_supported_in_IR( " format. Please report this." ) + # The conversion requires the output shape to be known and static. + if not node_has_well_defined_shape(node): + return False + + if len(node.meta["val"].shape) != 4: + # Unexpected case. The input should always be 4D. + return False + return True @staticmethod @@ -45,38 +78,58 @@ def _is_supported_on_target( parameters_mapping: dict[str, Parameter], custom_delegation_options: CustomDelegationOptions, ) -> bool: - # Neutron requires static shapes. - # neutron-converter/src/OperatorC/UpsamplePlugin.cpp?at=NEUTRON_SOFTWARE_2.2.3#74 - if not node_has_well_defined_shape(node): - return False - - if len(node.meta["val"].shape) != 4: - # Unexpected case. The input should always be 4D. - return False - - # The tensors here use the channels first format (NCHW). + # The tensors are always 4D and use the channels first format (NCHW). _, in_c, in_h, in_w = node.all_input_nodes[0].meta["val"].shape _, _, out_h, out_w = node.meta["val"].shape - # Neutron supports only the doubling and quadrupleing of both height and width at the same time. - # neutron-library/src/utils/NeutronLibraryInterrogation.cpp?at=refs%2Ftags%2FNEUTRON_SOFTWARE_2.2.3#778 - supported_scales = [2, 4] - if not any( - in_h * scale == out_h and in_w * scale == out_w - for scale in supported_scales - ): - return False - - # Neutron requires the input channels to be a multiple of `num_macs`. - # neutron-library/src/utils/NeutronLibraryInterrogation.cpp?at=refs%2Ftags%2FNEUTRON_SOFTWARE_2.2.3#777 - if in_c % neutron_target_spec.get_num_macs() != 0: - return False + if custom_delegation_options.use_new_flow_neutron_c: + # Requirements specified by the new Neutron flow documentation. + + if not NodeConverter.uses_quantization_type_for_io( + node, + supported_types=[torch.int8, torch.uint8], + input_indices=[0], + output_indices=[0], + ): + return False + + supported_scales = [1, 2, 4, 8] + align_corners = node.args[2] + if align_corners: + if in_h == 1 or in_w == 1: + return False # Avoid division by 0. + h_scale = (out_h - 1) / (in_h - 1) + w_scale = (out_w - 1) / (in_w - 1) + else: + h_scale = out_h / in_h + w_scale = out_w / in_w + + # The H and W scales don't need to be equal, but both must be supported. + if (h_scale not in supported_scales) or (w_scale not in supported_scales): + return False + + else: + # Requirements of the old Neutron flow. + + # Neutron supports only the doubling and quadrupleing of both height and width at the same time. + # neutron-library/src/utils/NeutronLibraryInterrogation.cpp?at=refs%2Ftags%2FNEUTRON_SOFTWARE_2.2.3#778 + supported_scales = [2, 4] + if not any( + in_h * scale == out_h and in_w * scale == out_w + for scale in supported_scales + ): + return False + + # Neutron requires the input channels to be a multiple of `num_macs`. + # neutron-library/src/utils/NeutronLibraryInterrogation.cpp?at=refs%2Ftags%2FNEUTRON_SOFTWARE_2.2.3#777 + if in_c % neutron_target_spec.get_num_macs() != 0: + return False return True def convert(self, node: Node): """Convert the `aten.upsample_bilinear2d.vec` operator to Neutron IR `ResizeBilinear`. - The schema is: + The ExecuTorch schema is: aten::upsample_bilinear2d.vec( Tensor input, SymInt[]? output_size, @@ -109,6 +162,7 @@ def convert(self, node: Node): # and the second one is what NeutronIR uses when `align_corners == False and half_pixel_centers == True`. # https://github.com/tensorflow/tensorflow/blob/v2.20.0/tensorflow/lite/kernels/internal/reference/resize_bilinear.h#L82-L88 # https://github.com/tensorflow/tensorflow/blob/v2.20.0/tensorflow/lite/kernels/internal/reference/resize_bilinear.h#L172-L180 + # Also, the new Neutron flow requires that `align_corners` and `half_pixel_centers` are not True simultainiously. align_corners = node.args[2] half_pixel_centers = not align_corners t_op.builtin_options = ResizeBilinear(align_corners, half_pixel_centers) diff --git a/backends/nxp/backend/ir/converter/node_converters/ops_converters/upsample_nearest2d_converter.py b/backends/nxp/backend/ir/converter/node_converters/ops_converters/upsample_nearest2d_converter.py index 1ddc71425ef..6e18a7bfe67 100644 --- a/backends/nxp/backend/ir/converter/node_converters/ops_converters/upsample_nearest2d_converter.py +++ b/backends/nxp/backend/ir/converter/node_converters/ops_converters/upsample_nearest2d_converter.py @@ -4,11 +4,13 @@ # LICENSE file in the root directory of this source tree. import numpy as np +import torch from executorch.backends.nxp.backend.data_format import DataFormat, NXP_NODE_FORMAT from executorch.backends.nxp.backend.edge_helper import node_has_well_defined_shape from executorch.backends.nxp.backend.ir.converter.node_converter import ( CustomDelegationOptions, + is_not_qdq_node, NodeConverter, ) from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options.resize_nearest_neighbor_options import ( @@ -16,12 +18,37 @@ ) from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec from torch.fx import Node +from torch.fx.passes.infra.partitioner import Partition from torch.nn import Parameter +HeightScale = float +WidthScale = float + # noinspection SpellCheckingInspection class UpsampleNearest2DConverter(NodeConverter): + @classmethod + def supports_partitioning_result( + cls, + node: Node, + partition_list: list[Partition], + custom_delegation_options: CustomDelegationOptions, + neutron_target_spec: NeutronTargetSpec, + parameters_mapping: dict[str, Parameter], + ) -> bool: + h_scale, w_scale = cls._get_effective_scales(node) + is_alone_in_partition = cls.is_node_alone_in_partition( + node, partition_list, filter_fn=is_not_qdq_node + ) + + if is_alone_in_partition and h_scale == w_scale == 1: + # The operator is a no-op, so the Neutron Converter will skip it. If it's the only node in the + # partition, the graph would end up empty. + return False + + return True + @staticmethod def _is_supported_in_IR( node: Node, @@ -36,6 +63,14 @@ def _is_supported_in_IR( " format. Please report this." ) + # The conversion requires the output shape to be known and static. + if not node_has_well_defined_shape(node): + return False + + if len(node.meta["val"].shape) != 4: + # Unexpected case. The input should always be 4D. + return False + return True @staticmethod @@ -45,39 +80,62 @@ def _is_supported_on_target( parameters_mapping: dict[str, Parameter], custom_delegation_options: CustomDelegationOptions, ) -> bool: - # Neutron requires static shapes. - # neutron-converter/src/OperatorC/UpsamplePlugin.cpp?at=NEUTRON_SOFTWARE_2.2.3#74 - if not node_has_well_defined_shape(node): - return False - - if len(node.meta["val"].shape) != 4: - # Unexpected case. The input should always be 4D. - return False - - # The tensors here use the channels first format (NCHW). + # The tensors are always 4D and use the channels first format (NCHW). _, in_c, in_h, in_w = node.all_input_nodes[0].meta["val"].shape _, _, out_h, out_w = node.meta["val"].shape - # Neutron supports only the doubling and quadrupleing of both height and width at the same time. - # neutron-library/src/utils/NeutronLibraryInterrogation.cpp?at=refs%2Ftags%2FNEUTRON_SOFTWARE_2.2.3#768 - # neutron-library/src/utils/NeutronLibraryInterrogation.cpp?at=refs%2Ftags%2FNEUTRON_SOFTWARE_2.2.3#778 - supported_scales = [2, 4] - if not any( - in_h * scale == out_h and in_w * scale == out_w - for scale in supported_scales - ): - return False - - # Neutron requires the input channels to be a multiple of `num_macs`. - # neutron-library/src/utils/NeutronLibraryInterrogation.cpp?at=refs%2Ftags%2FNEUTRON_SOFTWARE_2.2.3#767 - if in_c % neutron_target_spec.get_num_macs() != 0: - return False + if custom_delegation_options.use_new_flow_neutron_c: + # Requirements specified by the new Neutron flow documentation. + + if not NodeConverter.uses_quantization_type_for_io( + node, + supported_types=[torch.int8, torch.uint8], + input_indices=[0], + output_indices=[0], + ): + return False + + supported_scales = [1, 2, 4, 8] + h_scale, w_scale = UpsampleNearest2DConverter._get_effective_scales(node) + # The H and W scales don't need to be equal but both must be supported. + if (h_scale not in supported_scales) or (w_scale not in supported_scales): + return False + + else: + # Requirements of the old Neutron flow. + + # Neutron supports only the doubling and quadrupleing of both height and width at the same time. + # neutron-library/src/utils/NeutronLibraryInterrogation.cpp?at=refs%2Ftags%2FNEUTRON_SOFTWARE_2.2.3#768 + # neutron-library/src/utils/NeutronLibraryInterrogation.cpp?at=refs%2Ftags%2FNEUTRON_SOFTWARE_2.2.3#778 + supported_scales = [2, 4] + if not any( + in_h * scale == out_h and in_w * scale == out_w + for scale in supported_scales + ): + return False + + # Neutron requires the input channels to be a multiple of `num_macs`. + # neutron-library/src/utils/NeutronLibraryInterrogation.cpp?at=refs%2Ftags%2FNEUTRON_SOFTWARE_2.2.3#767 + if in_c % neutron_target_spec.get_num_macs() != 0: + return False return True + @staticmethod + def _get_effective_scales(node: Node) -> tuple[HeightScale, WidthScale]: + # Neutron supports variants where `align_corners=False` and `align_corners=True`. ExecuTorch doesn't have this + # parameter. Its behavior is equivalent to `align_corners=False`. Hence, the scale calculation corresponds to + # the `align_corners=False` case in the Neutron documentation. + _, _, in_h, in_w = node.all_input_nodes[0].meta["val"].shape + _, _, out_h, out_w = node.meta["val"].shape + h_scale = out_h / in_h + w_scale = out_w / in_w + + return h_scale, w_scale + def convert(self, node: Node): """Convert the `aten.upsample_nearest2d.vec` operator to Neutron IR `ResizeNearestNeighbor`. - The schema is: + The ExecuTorch schema is: aten::upsample_nearest2d.vec( Tensor input, SymInt[]? output_size, @@ -90,6 +148,8 @@ def convert(self, node: Node): x = t_op.tmp_inputs[0] y = t_op.tmp_outputs[0] + # Neutron supports variants where `align_corners=False` and `align_corners=True`. ExecuTorch doesn't have this + # parameter. Its behavior is equivalent to `align_corners=False` and `half_pixel_centers=False`. t_op.builtin_options = ResizeNearestNeighbor(False, False) # The `aten.upsample_nearest2d` can use either the `size` attribute or the `scale_factor` to define the output diff --git a/backends/nxp/tests/dataset_creator.py b/backends/nxp/tests/dataset_creator.py index eaf267f4fcf..fdfd363c257 100644 --- a/backends/nxp/tests/dataset_creator.py +++ b/backends/nxp/tests/dataset_creator.py @@ -8,6 +8,7 @@ import shutil from collections import OrderedDict from copy import deepcopy +from dataclasses import dataclass from os import mkdir from random import sample, seed @@ -19,6 +20,7 @@ ) from executorch.backends.nxp.tests.calibration_dataset import CalibrationDataset from executorch.backends.nxp.tests.executorch_pipeline import ModelInputSpec +from executorch.exir.scalar_type import ScalarType from torch import Tensor @@ -33,6 +35,72 @@ def _get_calibration_and_testing_dataset_directory_names( return calibration_path, test_path +@dataclass +class InputQuantizationSpec: + name: str + scale: float + zp: int + dtype: ScalarType + + +def _replace_input_binary_tensor_with_quantized_variant( + input_bin_tensor_path: str, + input_spec: ModelInputSpec, + q_params: InputQuantizationSpec, +): + tensor = np.fromfile( + input_bin_tensor_path, dtype=torch_type_to_numpy_type(input_spec.dtype) + ) + if q_params.dtype == ScalarType.CHAR: + tensor = np.add(np.round(np.divide(tensor, [q_params.scale])), [q_params.zp]) + tensor = np.clip(tensor, -128, 127).astype(np.int8) + else: + raise ValueError(f"Unknown quantization type: '{q_params.dtype}.") + tensor.tofile(input_bin_tensor_path) + + +def create_quantized_variant_of_dataset( + dataset_dir: str, + dataset_dir_quant: str, + input_quant_spec: list[InputQuantizationSpec], + input_spec: list[ModelInputSpec], +): + """ + Create quantized dataset from provided quantization spec. Dataset is cloned from directory 'dataset_dir'. + + :param dataset_dir: Original (float) dataset directory. + :param dataset_dir_quant: Quantized dataset directory. + :param input_quant_spec: Quantization parameters used for dataset quantization. + :param input_spec: Model inputs specification. + """ + assert len(input_quant_spec) > 0 + + shutil.copytree(dataset_dir, dataset_dir_quant, dirs_exist_ok=True) + + if len(input_quant_spec) == 1: + # Single input dataset - quantize only files in dataset's root dir with first input_quant_spec + input_spec = input_spec[0] + input_quant_spec = input_quant_spec[0] + + for file in os.listdir(dataset_dir_quant): + input_bin_tensor_path = os.path.join(dataset_dir_quant, file) + _replace_input_binary_tensor_with_quantized_variant( + input_bin_tensor_path, input_spec, input_quant_spec + ) + else: + # Iterate over samples (subfolders) + for dir_ in os.listdir(dataset_dir_quant): + # Iterate over each input in sample + sample_dir = os.path.join(dataset_dir_quant, dir_) + + for idx, input_ in enumerate(sorted(os.listdir(sample_dir))): + _replace_input_binary_tensor_with_quantized_variant( + os.path.join(sample_dir, input_), + input_spec[idx], + input_quant_spec[idx], + ) + + class DatasetCreator(abc.ABC): @abc.abstractmethod diff --git a/backends/nxp/tests/executorch_pipeline.py b/backends/nxp/tests/executorch_pipeline.py index 8f588be621d..e85a5de4d1b 100644 --- a/backends/nxp/tests/executorch_pipeline.py +++ b/backends/nxp/tests/executorch_pipeline.py @@ -276,6 +276,8 @@ def to_quantized_executorch_program( dataset_dir: str | None = None, delegate_to_npu=True, use_new_flow_neutron_c: bool = False, + operators_not_to_delegate: list[str] = None, + remove_quant_io_ops: bool = False, ) -> ExecutorchProgramManager: if dataset_dir: # Extract calibration data from a directory. @@ -295,6 +297,8 @@ def to_quantized_executorch_program( use_neutron_for_format_conversion=use_neutron_for_format_conversion, delegate_to_npu=delegate_to_npu, use_new_flow_neutron_c=use_new_flow_neutron_c, + operators_not_to_delegate=operators_not_to_delegate, + remove_quant_io_ops=remove_quant_io_ops, **get_calibration_inputs_fn, ) diff --git a/backends/nxp/tests/generic_tests/test_convert_div_to_mul.py b/backends/nxp/tests/generic_tests/test_convert_div_to_mul.py index ee89d5d5619..9201f32349f 100644 --- a/backends/nxp/tests/generic_tests/test_convert_div_to_mul.py +++ b/backends/nxp/tests/generic_tests/test_convert_div_to_mul.py @@ -6,6 +6,7 @@ import numpy as np import pytest import torch + from executorch.backends.nxp.aten_passes.neutron_aten_pass_manager import ( ConvertDivToMulPass, NeutronAtenPassManager, @@ -13,6 +14,7 @@ from executorch.backends.nxp.backend.edge_program_converter import ( EdgeProgramToIRConverter, ) +from executorch.backends.nxp.tests.dataset_creator import RandomDatasetCreator from executorch.backends.nxp.tests.executorch_pipeline import ( neutron_target_spec, to_quantized_edge_program, @@ -21,11 +23,13 @@ convert_run_compare, graph_contains_any_of_ops, ) - +from executorch.backends.nxp.tests.graph_verifier import DetailedGraphVerifier from executorch.backends.nxp.tests.models import ( NonstaticDivLinearModel, StaticDivLinearModel, ) +from executorch.backends.nxp.tests.nsys_testing import lower_run_compare +from executorch.backends.nxp.tests.ops_aliases import MulTensor from executorch.exir.dialects._ops import ops as exir_ops from torch.export import ExportedProgram @@ -248,3 +252,59 @@ def test_convert_div_to_mul_full_pipeline(mocker, input_shape, is_scalar): input_data=example_input, tfl_model=neutron_ir_model, ) + + +class StaticDivModel(torch.nn.Module): + def __init__(self, divisor): + super().__init__() + self.divisor = divisor + + def forward(self, x): + return x / self.divisor + + +class TestConvertDivToMulNewNeutronFlow: + + @pytest.mark.parametrize( + "input_shape", + [ + (23,), + (3, 7), + (2, 3, 4), + (1, 2, 3, 4), + (1, 2, 3, 2, 1), + ], + ids=lambda shape: f"{len(shape)}D", + ) + @pytest.mark.parametrize( + "is_scalar", + [False, True], + ids=lambda is_scalar: "scalar" if is_scalar else "tensor", + ) + def test__static__full_pipeline( + self, mocker, input_shape: tuple[int, ...], is_scalar: bool + ): + if is_scalar: + divisor = np.random.uniform(0.01, 15) + model = StaticDivModel(divisor) + else: + divisor = torch.rand(input_shape) + 0.01 + model = StaticDivModel(divisor) + + graph_verifier = DetailedGraphVerifier( + mocker, + # By the time `DetailedGraphVerifier` checks for operators, the `div` has already been replaced by `mul`. + expected_delegated_ops={MulTensor: 1}, + expected_non_delegated_ops={}, + ) + + # Cover also negative values to thoroughly test the operator. + dataset_creator = RandomDatasetCreator(low=-2, high=2) + + lower_run_compare( + model, + input_shape, + graph_verifier, + dataset_creator, + use_new_flow_neutron_c=True, # Use the new flow. + ) diff --git a/backends/nxp/tests/generic_tests/test_quantized_input_data.py b/backends/nxp/tests/generic_tests/test_quantized_input_data.py new file mode 100644 index 00000000000..4d2188816dc --- /dev/null +++ b/backends/nxp/tests/generic_tests/test_quantized_input_data.py @@ -0,0 +1,130 @@ +# Copyright 2026 NXP +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import executorch.backends.nxp.tests.nsys_testing as nsys_testing +import torch + +from executorch.backends.nxp.tests.executorch_pipeline import ModelInputSpec +from executorch.backends.nxp.tests.graph_verifier import DetailedGraphVerifier +from executorch.backends.nxp.tests.models import AvgPool2dModule, MulTensorModule +from executorch.backends.nxp.tests.nsys_testing import ( + lower_run_compare, + OUTPUTS_DIR, + ReferenceModel, +) +from executorch.backends.nxp.tests.ops_aliases import AvgPool2D, MulTensor + + +def test__single_quantized_inputs(mocker): + input_spec = ModelInputSpec((2, 4, 6, 7)) + model = AvgPool2dModule(False, 0) + graph_verifier = DetailedGraphVerifier( + mocker, expected_delegated_ops={AvgPool2D: 1}, expected_non_delegated_ops={} + ) + output_tensor_spec_spy = mocker.spy(nsys_testing, "_get_program_output_spec") + + lower_run_compare( + model, + [input_spec], + graph_verifier, + use_new_flow_neutron_c=True, + remove_quant_io_ops=True, + ) + + assert ( + OUTPUTS_DIR / "test__single_quantized_inputs" / "dataset_quant" / "0000.bin" + ).exists() + + # Check outputs are in quantized int8 format + output_tensor_spec = output_tensor_spec_spy.spy_return + assert output_tensor_spec[0].dtype == torch.int8 + + +def test__single_quantized_inputs_edge_python_reference(mocker): + input_spec = ModelInputSpec((2, 4, 6, 7)) + model = AvgPool2dModule(False, 0) + graph_verifier = DetailedGraphVerifier( + mocker, expected_delegated_ops={AvgPool2D: 1}, expected_non_delegated_ops={} + ) + output_tensor_spec_spy = mocker.spy(nsys_testing, "_get_program_output_spec") + + lower_run_compare( + model, + [input_spec], + graph_verifier, + reference_model=ReferenceModel.QUANTIZED_EDGE_PYTHON, + use_new_flow_neutron_c=True, + remove_quant_io_ops=True, + ) + + assert ( + OUTPUTS_DIR + / "test__single_quantized_inputs_edge_python_reference" + / "dataset_quant" + / "0000.bin" + ).exists() + + # Check outputs are in quantized int8 format + output_tensor_spec = output_tensor_spec_spy.spy_return + assert output_tensor_spec[0].dtype == torch.int8 + + +def test__multiple_quantized_inputs(mocker): + x_input_spec = ModelInputSpec((1, 4, 8, 8)) + model = MulTensorModule() + graph_verifier = DetailedGraphVerifier( + mocker, expected_delegated_ops={MulTensor: 1}, expected_non_delegated_ops={} + ) + output_tensor_spec_spy = mocker.spy(nsys_testing, "_get_program_output_spec") + + lower_run_compare( + model, + [x_input_spec, x_input_spec], + graph_verifier, + use_new_flow_neutron_c=True, + remove_quant_io_ops=True, + ) + + assert ( + OUTPUTS_DIR + / "test__multiple_quantized_inputs" + / "dataset_quant" + / "0000" + / "00.bin" + ).exists() + + # Check outputs are in quantized int8 format + output_tensor_spec = output_tensor_spec_spy.spy_return + assert output_tensor_spec[0].dtype == torch.int8 + + +def test__multiple_quantized_inputs_edge_python_reference(mocker): + x_input_spec = ModelInputSpec((1, 4, 8, 8)) + model = MulTensorModule() + graph_verifier = DetailedGraphVerifier( + mocker, expected_delegated_ops={MulTensor: 1}, expected_non_delegated_ops={} + ) + output_tensor_spec_spy = mocker.spy(nsys_testing, "_get_program_output_spec") + + lower_run_compare( + model, + [x_input_spec, x_input_spec], + graph_verifier, + reference_model=ReferenceModel.QUANTIZED_EDGE_PYTHON, + use_new_flow_neutron_c=True, + remove_quant_io_ops=True, + ) + + assert ( + OUTPUTS_DIR + / "test__multiple_quantized_inputs_edge_python_reference" + / "dataset_quant" + / "0000" + / "00.bin" + ).exists() + + # Check outputs are in quantized int8 format + output_tensor_spec = output_tensor_spec_spy.spy_return + assert output_tensor_spec[0].dtype == torch.int8 diff --git a/backends/nxp/tests/ir/converter/node_converter/test_add_tensor_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_add_tensor_converter.py index 1aa58ab5d95..4a656eb9517 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_add_tensor_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_add_tensor_converter.py @@ -1,7 +1,8 @@ -# Copyright 2025 NXP +# Copyright 2025-2026 NXP # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. + import numpy as np import pytest import torch @@ -9,17 +10,29 @@ from executorch.backends.nxp.backend.edge_program_converter import ( EdgeProgramToIRConverter, ) -from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program +from executorch.backends.nxp.tests.dataset_creator import RandomDatasetCreator +from executorch.backends.nxp.tests.executorch_pipeline import ( + ModelInputSpec, + to_quantized_edge_program, +) from executorch.backends.nxp.tests.executors import ( convert_run_compare, + graph_contains_any_of_ops, ToChannelFirstPreprocess, ToChannelLastPreprocess, ) +from executorch.backends.nxp.tests.graph_verifier import DetailedGraphVerifier from executorch.backends.nxp.tests.models import ( AddTensorConvModule, AddTensorModule, AddTensorOneInputModule, ) +from executorch.backends.nxp.tests.nsys_testing import lower_run_compare +from executorch.backends.nxp.tests.ops_aliases import ( + AddTensor, + Convolution, + ExecutorchDelegateCall, +) from torch.export import ExportedProgram from executorch.backends.nxp.tests.use_qat import * # noqa F403 @@ -92,20 +105,26 @@ def test_add_tensor_one_input_quant_conversion(mocker, input_shape, use_qat): @pytest.mark.parametrize( - "input_shape", + "x_input_shape", [ pytest.param((1, 4, 8, 8), id="4D."), pytest.param((1, 4, 5, 5), id="4D, product of dims is not a multiple of 8."), ], ) -def test_add_tensor_w_conv_quant_conversion(mocker, input_shape, use_qat): +def test_add_tensor_w_conv_quant_conversion(mocker, x_input_shape, use_qat): model = AddTensorConvModule() converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") + n, c, h, w = x_input_shape + y_input_shape = (n, 8, h, w) + # Run conversion _ = to_quantized_edge_program( - model, input_shape, use_qat=use_qat, use_neutron_for_format_conversion=False + model, + [x_input_shape, y_input_shape], + use_qat=use_qat, + use_neutron_for_format_conversion=False, ) # Capture generated model @@ -114,7 +133,13 @@ def test_add_tensor_w_conv_quant_conversion(mocker, input_shape, use_qat): # Capture converted program exported_program: ExportedProgram = converter_spy.call_args.args[1] - input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype(np.int8) + input_data_1 = (np.random.random(x_input_shape).astype(np.float32) * 50).astype( + np.int8 + ) + input_data_2 = (np.random.random(y_input_shape).astype(np.float32) * 50).astype( + np.int8 + ) + input_data = {0: input_data_1, 1: input_data_2} convert_run_compare( exported_program, @@ -149,7 +174,7 @@ def test_add_tensor_broadcasting_unsupported_quant_conversion( nodes = list(edge_program.graph.nodes) # Broadcast is not supported, node is not converted - assert nodes[6].target.__name__ == "aten.add.Tensor" # Add Tensor is not delegated. + assert nodes[6].target == AddTensor # Add Tensor is not delegated. # Capture converted program # exported_program: ExportedProgram = converter_spy.call_args.args[1] @@ -159,3 +184,227 @@ def test_add_tensor_broadcasting_unsupported_quant_conversion( # input_data = {0: x_input_data, 1: y_input_data} # # convert_run_compare(exported_program, tfl_model=tflite_flatbuffers_model, input_data=input_data) + + +class TestAddTensorNewNeutronFlow: + @pytest.mark.parametrize( + "x_input_shape", + [ + pytest.param((1,), id="1D."), + pytest.param((6, 5), id="2D."), + pytest.param((1, 4, 7), id="3D."), + pytest.param((2, 4, 3, 15), id="4D."), + pytest.param( + (6, 82), + id="2D incorrect.", + marks=pytest.mark.xfail(reason="AIR-14602: incorrect results"), + ), + pytest.param( + (1, 68, 7), + id="3D incorrect.", + marks=pytest.mark.xfail(reason="AIR-14602: incorrect results"), + ), + pytest.param( + (1, 4, 9, 11, 4), + id="5D incorrect.", + marks=pytest.mark.xfail(reason="AIR-14602: incorrect results"), + ), + ], + ) + def test__basic_nsys_inference(self, x_input_shape, mocker): + x_input_spec = ModelInputSpec(x_input_shape) + model = AddTensorModule() + graph_verifier = DetailedGraphVerifier( + mocker, expected_delegated_ops={AddTensor: 1}, expected_non_delegated_ops={} + ) + dataset_creator = RandomDatasetCreator(low=-1.0, high=1.0) + + lower_run_compare( + model, + [x_input_spec, x_input_spec], + graph_verifier, + dataset_creator, + use_new_flow_neutron_c=True, + ) + + @pytest.mark.parametrize( + "x_input_shape", + [ + pytest.param((1,), id="1D."), + pytest.param((6, 5), id="2D."), + pytest.param((1, 4, 7), id="3D."), + pytest.param((2, 4, 3, 15), id="4D."), + pytest.param( + (1, 4, 9, 11, 4), + id="5D.", + marks=pytest.mark.xfail(reason="AIR-14602: incorrect results"), + ), + ], + ) + def test__basic_nsys_inference_qat(self, x_input_shape, mocker): + x_input_spec = ModelInputSpec(x_input_shape) + model = AddTensorModule() + graph_verifier = DetailedGraphVerifier( + mocker, expected_delegated_ops={AddTensor: 1}, expected_non_delegated_ops={} + ) + dataset_creator = RandomDatasetCreator(low=-1.0, high=1.0) + + lower_run_compare( + model, + [x_input_spec, x_input_spec], + graph_verifier, + dataset_creator, + use_new_flow_neutron_c=True, + use_qat=True, + ) + + @pytest.mark.parametrize( + "input_spec", + [ + pytest.param( + [ModelInputSpec((4, 6)), ModelInputSpec((1, 6))], id="2 inputs 2D." + ), + pytest.param( + [ModelInputSpec((5, 3, 4)), ModelInputSpec((1, 3, 1))], + id="2 inputs 3D.", + ), + pytest.param( + [ModelInputSpec((4,)), ModelInputSpec((4, 4))], id="2 inputs 1D + 2D." + ), + pytest.param( + [ModelInputSpec((69, 73)), ModelInputSpec((1, 73))], + id="2 inputs 2D incorrect.", + marks=pytest.mark.xfail(reason="AIR-14602: incorrect results"), + ), + ], + ) + def test__broadcast(self, input_spec, mocker): + model = AddTensorModule() + graph_verifier = DetailedGraphVerifier( + mocker, expected_delegated_ops={AddTensor: 1}, expected_non_delegated_ops={} + ) + dataset_creator = RandomDatasetCreator(low=-1.0, high=1.0) + + lower_run_compare( + model, + input_spec, + graph_verifier, + dataset_creator, + use_new_flow_neutron_c=True, + ) + + @pytest.mark.parametrize( + "input_spec", + [ + pytest.param( + [ModelInputSpec((4, 1)), ModelInputSpec((1, 6))], id="2 inputs 2D." + ), + pytest.param( + [ModelInputSpec((1, 3, 4)), ModelInputSpec((5, 3, 1))], + id="2 inputs 3D.", + ), + pytest.param( + [ModelInputSpec((6, 4)), ModelInputSpec((6, 6, 1))], + id="2 inputs 2D + 3D.", + ), + ], + ) + def test__broadcast_unsupported(self, input_spec): + # Broadcast where at least one of the inputs is not equal to output is not supported + model = AddTensorModule() + + delegated_ep = to_quantized_edge_program( + model, input_spec, use_new_flow_neutron_c=True + ).exported_program() + + # Make sure the `add.Tensor` was NOT delegated. + assert not graph_contains_any_of_ops( + delegated_ep.graph, [ExecutorchDelegateCall] + ) + assert graph_contains_any_of_ops(delegated_ep.graph, [AddTensor]) + + @pytest.mark.parametrize( + "x_input_shape", + [ + pytest.param( + (1, 4, 5, 5), id="4D, product of dims is not a multiple of 8." + ), + ], + ) + def test__w_conv(self, x_input_shape, mocker): + model = AddTensorConvModule() + + n, c, h, w = x_input_shape + y_input_spec = ModelInputSpec((n, 8, h, w)) + x_input_spec = ModelInputSpec(x_input_shape) + + graph_verifier = DetailedGraphVerifier( + mocker, + expected_delegated_ops={AddTensor: 1, Convolution: 1}, + expected_non_delegated_ops={}, + ) + dataset_creator = RandomDatasetCreator(low=-1.0, high=1.0) + + lower_run_compare( + model, + [x_input_spec, y_input_spec], + graph_verifier, + dataset_creator, + use_new_flow_neutron_c=True, + ) + + @pytest.mark.parametrize( + "input_spec", + [ + pytest.param( + [ModelInputSpec((1, 4, 5, 5)), ModelInputSpec((1, 8, 5, 1))], + id="2 inputs 4D + 4D.", + ), + pytest.param( + [ModelInputSpec((1, 4, 5, 67)), ModelInputSpec((1, 8, 5, 1))], + id="2 inputs 4D + 4D incorrect.", + marks=pytest.mark.xfail(reason="AIR-14602: incorrect results"), + ), + ], + ) + def test__w_conv_broadcast(self, input_spec, mocker): + model = AddTensorConvModule() + + graph_verifier = DetailedGraphVerifier( + mocker, + expected_delegated_ops={AddTensor: 1, Convolution: 1}, + expected_non_delegated_ops={}, + ) + dataset_creator = RandomDatasetCreator(low=-1.0, high=1.0) + + lower_run_compare( + model, + input_spec, + graph_verifier, + dataset_creator, + use_new_flow_neutron_c=True, + ) + + @pytest.mark.parametrize( + "input_spec", + [ + pytest.param( + [ModelInputSpec((1, 4, 5, 5)), ModelInputSpec((1, 5))], + id="2 inputs 4D + 2D.", + ), + pytest.param( + [ModelInputSpec((1, 4, 4, 10)), ModelInputSpec((1, 4, 1))], + id="2 inputs 4D + 3D.", + ), + ], + ) + def test__w_conv_unsupported(self, input_spec): + model = AddTensorConvModule() + + delegated_ep = to_quantized_edge_program( + model, input_spec, use_new_flow_neutron_c=True + ).exported_program() + + # Make sure the `add.Tensor` was NOT delegated. + assert graph_contains_any_of_ops(delegated_ep.graph, [ExecutorchDelegateCall]) + assert graph_contains_any_of_ops(delegated_ep.graph, [AddTensor]) diff --git a/backends/nxp/tests/ir/converter/node_converter/test_avg_pool2d_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_avg_pool2d_converter.py index 2c73ccd8092..193b7ecf9ab 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_avg_pool2d_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_avg_pool2d_converter.py @@ -6,6 +6,7 @@ import numpy as np import pytest import torch + from executorch.backends.nxp.backend.edge_program_converter import ( EdgeProgramToIRConverter, ) @@ -29,13 +30,8 @@ ToNHWCPreprocess, ) from executorch.backends.nxp.tests.graph_verifier import DetailedGraphVerifier -from executorch.backends.nxp.tests.model_output_comparator import ( - NumericalStatsOutputComparator, -) from executorch.backends.nxp.tests.models import AvgPool2dConvModule, AvgPool2dModule - from executorch.backends.nxp.tests.nsys_testing import lower_run_compare - from executorch.backends.nxp.tests.ops_aliases import ( AvgPool2D, ExecutorchDelegateCall, @@ -45,6 +41,7 @@ Unsqueeze, ViewCopy, ) + from torch.export import ExportedProgram from executorch.backends.nxp.tests.use_qat import * # noqa F403 @@ -320,7 +317,6 @@ def test__basic_nsys_inference(self, mocker): def test__basic_nsys_inference_qat(self, mocker): input_shape = (2, 9, 6, 15) model = AvgPool2dModule(False, 0) - comparator = NumericalStatsOutputComparator() graph_verifier = DetailedGraphVerifier( mocker, expected_delegated_ops={AvgPool2D: 1}, expected_non_delegated_ops={} ) @@ -329,7 +325,6 @@ def test__basic_nsys_inference_qat(self, mocker): model, input_shape, graph_verifier, - output_comparator=comparator, use_new_flow_neutron_c=True, use_qat=True, ) diff --git a/backends/nxp/tests/ir/converter/node_converter/test_convert_upsample_bilinear2d.py b/backends/nxp/tests/ir/converter/node_converter/test_convert_upsample_bilinear2d.py index 5663eea9cc3..2d2f9845fa3 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_convert_upsample_bilinear2d.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_convert_upsample_bilinear2d.py @@ -4,12 +4,15 @@ # LICENSE file in the root directory of this source tree. import numpy as np + +# noinspection PyUnusedImports import pytest import torch from executorch.backends.nxp.backend.edge_program_converter import ( EdgeProgramToIRConverter, ) +from executorch.backends.nxp.tests.dataset_creator import RandomDatasetCreator from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program from executorch.backends.nxp.tests.executors import ( convert_run_compare, @@ -17,7 +20,17 @@ ToChannelFirstPreprocess, ToChannelLastPreprocess, ) -from executorch.exir.dialects._ops import ops as exir_ops +from executorch.backends.nxp.tests.graph_verifier import DetailedGraphVerifier +from executorch.backends.nxp.tests.model_output_comparator import ( + AllCloseOutputComparator, +) +from executorch.backends.nxp.tests.nsys_testing import lower_run_compare +from executorch.backends.nxp.tests.ops_aliases import ( + AddTensor, + ExecutorchDelegateCall, + UpsampleBilinear2D, +) +from executorch.backends.nxp.tests.use_qat import * # noqa F403 @pytest.fixture(autouse=True) @@ -26,23 +39,25 @@ def reseed_model_per_test_run(): np.random.seed(23) -# noinspection PyProtectedMember -ExecutorchDelegateCall = torch.ops.higher_order.executorch_call_delegate -UpsampleBilinear2D = exir_ops.edge.aten.upsample_bilinear2d.vec - - class UpsampleBilinearModule(torch.nn.Module): - def __init__(self, size=None, scale=None): + def __init__(self, size=None, scale=None, **kwargs): super().__init__() self.upsample = torch.nn.Upsample( - size=size, scale_factor=scale, mode="bilinear" + size=size, scale_factor=scale, mode="bilinear", **kwargs ) def forward(self, x): return self.upsample(x) +class UpsampleBilinearAddModule(UpsampleBilinearModule): + + def forward(self, x): + x = super().forward(x) + return x + x + + @pytest.mark.parametrize( "input_shape, size", [ @@ -185,3 +200,255 @@ def test_convert_upsample_bilinear2d__no_delegation__unsupported_size( # Make sure the `upsample` was NOT delegated (size != double of input). assert not graph_contains_any_of_ops(delegated_ep.graph, [ExecutorchDelegateCall]) assert graph_contains_any_of_ops(delegated_ep.graph, [UpsampleBilinear2D]) + + +class TestUpsampleBilinear2DNewNeutronFlow: + # TODO Use quantized dataset and `atol=1` in the tests. + + # noinspection PyMethodMayBeStatic + def assert_delegated( + self, + model, + input_shape, + mocker, + use_qat=False, + atol=None, + expected_delegated_ops=None, + ): + if expected_delegated_ops is None: + expected_delegated_ops = {UpsampleBilinear2D: 1} + + graph_verifier = DetailedGraphVerifier( + mocker, + expected_delegated_ops=expected_delegated_ops, + expected_non_delegated_ops={}, + ) + + # Cover also negative values to thoroughly test the operator. + dataset_creator = RandomDatasetCreator(low=-2, high=2) + + kwargs = {"atol": atol} if atol is not None else {} + output_comparator = AllCloseOutputComparator(**kwargs) + + lower_run_compare( + model, + input_shape, + graph_verifier, + dataset_creator, + output_comparator, + use_qat=use_qat, + use_new_flow_neutron_c=True, # Use the new flow. + ) + + # noinspection PyMethodMayBeStatic + def assert_not_delegated(self, model, input_shape): + delegated_ep = to_quantized_edge_program( + model, input_shape, use_new_flow_neutron_c=True + ).exported_program() + + assert not graph_contains_any_of_ops( + delegated_ep.graph, [ExecutorchDelegateCall] + ) + assert graph_contains_any_of_ops(delegated_ep.graph, [UpsampleBilinear2D]) + + def test__qat__align_corners(self, mocker, use_qat): + align_corners = True + input_shape = (1, 2, 3, 4) + output_size = (5, 7) + model = UpsampleBilinearModule(size=output_size, align_corners=align_corners) + atol = 0.015 # ~= output scale -> single bit error. + self.assert_delegated(model, input_shape, mocker, use_qat=use_qat, atol=atol) + + def test__qat__not_align_corners(self, mocker, use_qat): + align_corners = False + input_shape = (1, 2, 3, 4) + output_size = (6, 8) + model = UpsampleBilinearModule(size=output_size, align_corners=align_corners) + atol = 0.015 # ~= output scale -> single bit error. + self.assert_delegated(model, input_shape, mocker, use_qat=use_qat, atol=atol) + + @pytest.mark.parametrize( + "input_shape, output_size", + [ + pytest.param((1, 2, 3, 4), (6, 8), id="batch=1, scale_h=scale_w=2"), + pytest.param( + (3, 3, 3, 5), + (6, 5), + id="batch=3, scale_h=2, scale_w=1 (no num_macs multiples)", + ), + pytest.param((2, 2, 3, 4), (3, 16), id="batch=2, scale_h=1, scale_w=4"), + pytest.param((2, 2, 3, 4), (24, 8), id="batch=2, scale_h=8, scale_w=2"), + ], + ) + def test__not_align_corners__output_size(self, mocker, input_shape, output_size): + align_corners = False + model = UpsampleBilinearModule(size=output_size, align_corners=align_corners) + atol = 0.016 # ~= output scale -> single bit error. + self.assert_delegated(model, input_shape, mocker, atol=atol) + + def test__not_align_corners__output_size__unsupported(self): + align_corners = False + input_shape = (1, 2, 3, 4) + output_size = (9, 12) # scale = (3, 3) + model = UpsampleBilinearModule(size=output_size, align_corners=align_corners) + self.assert_not_delegated(model, input_shape) + + @pytest.mark.parametrize( + "input_shape, scale", + [ + pytest.param((1, 2, 3, 4), (2, 2), id="batch=1, scale_h=scale_w=2"), + pytest.param( + (3, 3, 3, 5), + (2, 1), + id="batch=3, scale_h=2, scale_w=1 (no num_macs multiples)", + ), + pytest.param((2, 2, 3, 4), (4, 1), id="batch=2, scale_h=4, scale_w=1"), + pytest.param((2, 2, 3, 4), (2, 8), id="batch=2, scale_h=2, scale_w=8"), + ], + ) + def test__not_align_corners__scales(self, mocker, input_shape, scale): + align_corners = False + model = UpsampleBilinearModule(scale=scale, align_corners=align_corners) + atol = 0.016 # ~= output scale -> single bit error. + self.assert_delegated(model, input_shape, mocker, atol=atol) + + def test__not_align_corners__scales__unsupported(self): + align_corners = False + input_shape = (1, 2, 3, 4) + scale = (3, 3) + model = UpsampleBilinearModule(scale=scale, align_corners=align_corners) + self.assert_not_delegated(model, input_shape) + + @pytest.mark.parametrize( + "input_shape, output_size", + [ + pytest.param((1, 2, 4, 5), (7, 9), id="batch=1, scale_h=scale_w=2"), + pytest.param( + (1, 3, 3, 5), + (5, 5), + id="batch=1, scale_h=2, scale_w=1 (no num_macs multiples)", + ), + pytest.param((2, 2, 4, 5), (4, 17), id="batch=2, scale_h=1, scale_w=4"), + pytest.param((1, 2, 4, 5), (25, 9), id="batch=1, scale_h=8, scale_w=2"), + ], + ) + def test__align_corners__output_size(self, mocker, input_shape, output_size): + align_corners = True + model = UpsampleBilinearModule(size=output_size, align_corners=align_corners) + atol = 0.016 # ~= output scale -> single bit error. + self.assert_delegated(model, input_shape, mocker, atol=atol) + + @pytest.mark.parametrize( + "input_shape, output_size", + [ + pytest.param( + (2, 2, 4, 5), (25, 9), id="batch=2, scale_h=8, scale_w=2" + ), # Error ~= 0.47 + pytest.param( + (3, 3, 3, 5), + (5, 5), + id="batch=3, scale_h=2, scale_w=1 (no num_macs multiples)", + ), # Error ~= 3.7 + ], + ) + def test__align_corners__output_size__incorrect_output( + self, mocker, input_shape, output_size + ): + align_corners = True + model = UpsampleBilinearModule(size=output_size, align_corners=align_corners) + atol = 0.45 # Huge tolerance (still not enough to pass). + with pytest.raises(AssertionError): + self.assert_delegated(model, input_shape, mocker, atol=atol) + + def test__align_corners__output_size__unsupported(self): + align_corners = True + input_shape = (1, 2, 3, 4) + output_size = (6, 8) # Neutron scale = (5/2, 7/3) + model = UpsampleBilinearModule(size=output_size, align_corners=align_corners) + self.assert_not_delegated(model, input_shape) + + def test__align_corners__output_size__input_size_equal_to_one(self): + align_corners = True + input_shape = (1, 2, 1, 1) # Neutron scale computation would divide by zero. + output_size = (2, 2) + model = UpsampleBilinearModule(size=output_size, align_corners=align_corners) + self.assert_not_delegated(model, input_shape) + + @pytest.mark.parametrize( + "input_shape, scale", + [ + # The PyTorch scales are "weird" because the "Neutron scales" are computed differently. + # The fractions correspond to "nice" Neutron scales (1, 2, 4, or 8). + pytest.param( + (1, 2, 4, 5), + (7 / 4, 9 / 5), + id="batch=1, scale_h=7/4, scale_w=9/5 (Neutron scales = (2, 2)", + ), + pytest.param( + (1, 3, 3, 5), + (5 / 3, 1), + id="batch=1, scale_h=5/3, scale_w=1 (Neutron scales = (2, 1))", + ), + pytest.param( + (2, 2, 4, 5), + (1, 17 / 5), + id="batch=2, scale_h=1, scale_w=17/5 (Neutron scales = (1, 4))", + ), + pytest.param( + (1, 2, 4, 5), + (25 / 4, 9 / 5), + id="batch=1, scale_h=25/4, scale_w=9/5 (Neutron scales = (8, 2))", + ), + ], + ) + def test__align_corners__scales(self, mocker, input_shape, scale): + align_corners = True + model = UpsampleBilinearModule(scale=scale, align_corners=align_corners) + atol = 0.016 # ~= output scale -> single bit error. + self.assert_delegated(model, input_shape, mocker, atol=atol) + + @pytest.mark.parametrize( + "input_shape, scale", + [ + pytest.param( + (2, 2, 4, 5), + (25 / 4, 9 / 5), + id="batch=3, scale_h=25/4, scale_w=9/5 (Neutron scales = (8, 2))", + ), # Error ~= 0.47 + pytest.param( + (3, 3, 3, 5), + (5 / 3, 1), + id="batch=3, scale_h=5/3, scale_w=1 (Neutron scales = (2, 1))", + ), # Error ~= 3.7 + ], + ) + def test__align_corners__scales__incorrect_output(self, mocker, input_shape, scale): + align_corners = True + model = UpsampleBilinearModule(scale=scale, align_corners=align_corners) + atol = 0.45 # Huge tolerance (still not enough to pass). + with pytest.raises(AssertionError): + self.assert_delegated(model, input_shape, mocker, atol=atol) + + def test__align_corners__scales__unsupported(self): + align_corners = True + input_shape = (1, 2, 3, 4) + scale = (2, 2) # Neutron scale = (5/2, 7/3) + model = UpsampleBilinearModule(scale=scale, align_corners=align_corners) + self.assert_not_delegated(model, input_shape) + + def test__noop__alone_in_partition__not_delegated(self): + input_shape = (1, 2, 3, 4) + scale = 1 + model = UpsampleBilinearModule(scale=scale) + self.assert_not_delegated(model, input_shape) + + def test__noop__not_alone_in_partition__delegated(self, mocker): + input_shape = (1, 2, 3, 4) + scale = 1 + model = UpsampleBilinearAddModule(scale=scale) + self.assert_delegated( + model, + input_shape, + mocker, + expected_delegated_ops={UpsampleBilinear2D: 1, AddTensor: 1}, + ) diff --git a/backends/nxp/tests/ir/converter/node_converter/test_convert_upsample_nearest2d.py b/backends/nxp/tests/ir/converter/node_converter/test_convert_upsample_nearest2d.py index 3d9ec84dec9..27d1ac718a0 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_convert_upsample_nearest2d.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_convert_upsample_nearest2d.py @@ -4,12 +4,15 @@ # LICENSE file in the root directory of this source tree. import numpy as np + +# noinspection PyUnusedImports import pytest import torch from executorch.backends.nxp.backend.edge_program_converter import ( EdgeProgramToIRConverter, ) +from executorch.backends.nxp.tests.dataset_creator import RandomDatasetCreator from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program from executorch.backends.nxp.tests.executors import ( convert_run_compare, @@ -17,7 +20,14 @@ ToChannelFirstPreprocess, ToChannelLastPreprocess, ) -from executorch.exir.dialects._ops import ops as exir_ops +from executorch.backends.nxp.tests.graph_verifier import DetailedGraphVerifier +from executorch.backends.nxp.tests.nsys_testing import lower_run_compare +from executorch.backends.nxp.tests.ops_aliases import ( + AddTensor, + ExecutorchDelegateCall, + UpsampleNearest2D, +) +from executorch.backends.nxp.tests.use_qat import * # noqa F403 @pytest.fixture(autouse=True) @@ -26,11 +36,6 @@ def reseed_model_per_test_run(): np.random.seed(23) -# noinspection PyProtectedMember -ExecutorchDelegateCall = torch.ops.higher_order.executorch_call_delegate -UpsampleNearest2D = exir_ops.edge.aten.upsample_nearest2d.vec - - class UpsampleNearestModule(torch.nn.Module): def __init__(self, size=None, scale=None): @@ -41,6 +46,13 @@ def forward(self, x): return self.upsample(x) +class UpsampleNearestAddModule(UpsampleNearestModule): + + def forward(self, x): + x = super().forward(x) + return x + x + + @pytest.mark.parametrize( "input_shape, size", [ @@ -181,3 +193,120 @@ def test_convert_upsample_nearest2d__no_delegation__unsupported_size(input_shape # Make sure the `upsample` was NOT delegated (size != double of input). assert not graph_contains_any_of_ops(delegated_ep.graph, [ExecutorchDelegateCall]) assert graph_contains_any_of_ops(delegated_ep.graph, [UpsampleNearest2D]) + + +class TestUpsampleNearest2DNewNeutronFlow: + + # noinspection PyMethodMayBeStatic + def assert_delegated( + self, + model, + input_shape, + mocker, + use_qat=False, + expected_delegated_ops=None, + ): + if expected_delegated_ops is None: + expected_delegated_ops = {UpsampleNearest2D: 1} + + graph_verifier = DetailedGraphVerifier( + mocker, + expected_delegated_ops=expected_delegated_ops, + expected_non_delegated_ops={}, + ) + + # Cover also negative values to thoroughly test the operator. + dataset_creator = RandomDatasetCreator(low=-2, high=2) + + lower_run_compare( + model, + input_shape, + graph_verifier, + dataset_creator, + use_qat=use_qat, + use_new_flow_neutron_c=True, # Use the new flow. + ) + + # noinspection PyMethodMayBeStatic + def assert_not_delegated(self, model, input_shape): + delegated_ep = to_quantized_edge_program( + model, input_shape, use_new_flow_neutron_c=True + ).exported_program() + + assert not graph_contains_any_of_ops( + delegated_ep.graph, [ExecutorchDelegateCall] + ) + assert graph_contains_any_of_ops(delegated_ep.graph, [UpsampleNearest2D]) + + def test__qat(self, mocker, use_qat): + input_shape = (1, 2, 3, 4) + output_size = (6, 8) + model = UpsampleNearestModule(size=output_size) + self.assert_delegated(model, input_shape, mocker, use_qat=use_qat) + + @pytest.mark.parametrize( + "input_shape, output_size", + [ + pytest.param((1, 2, 3, 4), (6, 8), id="batch=1, scale_h=scale_w=2"), + pytest.param((1, 2, 3, 3), 6, id="batch=1, scale_h=scale_w=2, scalar size"), + pytest.param( + (3, 3, 3, 5), + (6, 5), + id="batch=3, scale_h=2, scale_w=1 (no num_macs multiples)", + ), + pytest.param((2, 2, 3, 4), (3, 16), id="batch=2, scale_h=1, scale_w=4"), + pytest.param((2, 2, 3, 4), (24, 8), id="batch=2, scale_h=8, scale_w=2"), + ], + ) + def test__output_size(self, mocker, input_shape, output_size): + model = UpsampleNearestModule(size=output_size) + self.assert_delegated(model, input_shape, mocker) + + def test__output_size__unsupported(self): + input_shape = (1, 2, 3, 4) + output_size = (9, 12) # scale = (3, 3) + model = UpsampleNearestModule(size=output_size) + self.assert_not_delegated(model, input_shape) + + @pytest.mark.parametrize( + "input_shape, scale", + [ + pytest.param((1, 2, 3, 4), (2, 2), id="batch=1, scale_h=scale_w=2"), + pytest.param( + (1, 2, 3, 4), 4, id="batch=1, scale_h=scale_w=4, scalar scale" + ), + pytest.param( + (3, 3, 3, 5), + (2, 1), + id="batch=3, scale_h=2, scale_w=1 (no num_macs multiples)", + ), + pytest.param((2, 2, 3, 4), (4, 1), id="batch=2, scale_h=4, scale_w=1"), + pytest.param((2, 2, 3, 4), (2, 8), id="batch=2, scale_h=2, scale_w=8"), + ], + ) + def test__scales(self, mocker, input_shape, scale): + model = UpsampleNearestModule(scale=scale) + self.assert_delegated(model, input_shape, mocker) + + def test__scales__unsupported(self): + input_shape = (1, 2, 3, 4) + scale = (3, 3) + model = UpsampleNearestModule(scale=scale) + self.assert_not_delegated(model, input_shape) + + def test__noop__alone_in_partition__not_delegated(self): + input_shape = (1, 2, 3, 4) + scale = 1 + model = UpsampleNearestModule(scale=scale) + self.assert_not_delegated(model, input_shape) + + def test__noop__not_alone_in_partition__delegated(self, mocker): + input_shape = (1, 2, 3, 4) + scale = 1 + model = UpsampleNearestAddModule(scale=scale) + self.assert_delegated( + model, + input_shape, + mocker, + expected_delegated_ops={UpsampleNearest2D: 1, AddTensor: 1}, + ) diff --git a/backends/nxp/tests/ir/converter/node_converter/test_max_pool_2d_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_max_pool_2d_converter.py index 583dc2bfd04..9062d5efbfc 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_max_pool_2d_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_max_pool_2d_converter.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. import numpy as np +import pytest import torch from executorch.backends.nxp.backend.edge_program_converter import ( @@ -17,9 +18,6 @@ ToChannelLastPreprocess, ) from executorch.backends.nxp.tests.graph_verifier import DetailedGraphVerifier -from executorch.backends.nxp.tests.model_output_comparator import ( - NumericalStatsOutputComparator, -) from executorch.backends.nxp.tests.nsys_testing import lower_run_compare from executorch.backends.nxp.tests.ops_aliases import ( ExecutorchDelegateCall, @@ -32,7 +30,6 @@ ViewCopy, ) from executorch.backends.nxp.tests.use_qat import * # noqa F403 -import pytest class MaxPool1DModule(torch.nn.Module): @@ -286,7 +283,6 @@ def test__basic_nsys_inference(self, mocker): def test__basic_nsys_inference_qat(self, mocker): input_shape = (2, 11, 7, 16) # The old flow limited the batch size to 1. model = MaxPool2dModule() - comparator = NumericalStatsOutputComparator() graph_verifier = DetailedGraphVerifier( mocker, expected_delegated_ops={MaxPool2DWithIndices: 1, GetItem: 1}, @@ -297,7 +293,6 @@ def test__basic_nsys_inference_qat(self, mocker): model, input_shape, graph_verifier, - output_comparator=comparator, use_new_flow_neutron_c=True, use_qat=True, ) diff --git a/backends/nxp/tests/ir/converter/node_converter/test_mean_dim_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_mean_dim_converter.py index 7c0a5e8ffcf..a265ca557c9 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_mean_dim_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_mean_dim_converter.py @@ -1,15 +1,18 @@ -# Copyright 2025 NXP +# Copyright 2025-2026 NXP # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. import numpy as np + +# noinspection PyUnusedImports import pytest import torch from executorch.backends.nxp.backend.edge_program_converter import ( EdgeProgramToIRConverter, ) +from executorch.backends.nxp.tests.dataset_creator import RandomDatasetCreator from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program from executorch.backends.nxp.tests.executors import ( convert_run_compare, @@ -17,10 +20,21 @@ ToChannelFirstPreprocess, ToChannelLastPreprocess, ) +from executorch.backends.nxp.tests.graph_verifier import DetailedGraphVerifier +from executorch.backends.nxp.tests.model_output_comparator import ( + AllCloseOutputComparator, +) from executorch.backends.nxp.tests.models import MeanDimConvModule, MeanDimLinearModule -from executorch.backends.nxp.tests.use_qat import * # noqa F403 -from executorch.exir.dialects._ops import ops as exir_ops +from executorch.backends.nxp.tests.nsys_testing import lower_run_compare +from executorch.backends.nxp.tests.ops_aliases import ( + AddTensor, + ExecutorchDelegateCall, + GetItem, + MaxPool2DWithIndices, + MeanDim, +) from torch.export import ExportedProgram +from executorch.backends.nxp.tests.use_qat import * # noqa F403 @pytest.fixture(autouse=True) @@ -39,6 +53,12 @@ def forward(self, x): return torch.mean(x, dim=self.dim, keepdim=self.keepdim) +class MeanDimAddModule(MeanDimModule): + def forward(self, x): + x = super().forward(x) + return x + x + + @pytest.mark.parametrize( "input_shape, dim", [ @@ -60,7 +80,7 @@ def test_mean_dim_conv_quant_conversion( model, input_shape, use_qat=use_qat, use_neutron_for_format_conversion=False ).exported_program() # Make sure the `mean.dim` was delegated. - assert not graph_contains_any_of_ops(ep.graph, [exir_ops.edge.aten.mean.dim]) + assert not graph_contains_any_of_ops(ep.graph, [MeanDim]) assert any("lowered_module" in n.name for n in ep.graph.nodes) # Capture generated model @@ -109,7 +129,7 @@ def test_mean_dim_linear_unsupported_quant_conversion( nodes = list(edge_program.graph.nodes) # Last 2 dimensions are not used or keepdim is False, cannot be converted to MeanDim, node is not delegated - assert nodes[6].target.__name__ == "aten.mean.dim" + assert nodes[6].target == MeanDim # Capture generated model tflite_flatbuffers_model, io_formats = converter_spy.spy_return @@ -157,7 +177,7 @@ def test_mean_dim_conv_unsupported_quant_conversion( nodes = list(edge_program.graph.nodes) # Last 2 dimensions are not used or keepdim is False, cannot be converted to MeanDim, node is not delegated - assert nodes[6].target.__name__ == "aten.mean.dim" + assert nodes[6].target == MeanDim # Capture generated model tflite_flatbuffers_model, io_formats = converter_spy.spy_return @@ -197,7 +217,7 @@ def test_mean_dim__formatless__supported( ).exported_program() # Make sure the `mean.dim` was delegated. - assert not graph_contains_any_of_ops(ep.graph, [exir_ops.edge.aten.mean.dim]) + assert not graph_contains_any_of_ops(ep.graph, [MeanDim]) assert any("lowered_module" in n.name for n in ep.graph.nodes) # Capture generated model @@ -230,7 +250,7 @@ def test_mean_dim__formatless__unsupported(input_shape, dim, use_qat, keepdim=Tr ).exported_program() # Make sure the `mean.dim` was NOT delegated. - assert graph_contains_any_of_ops(ep.graph, [exir_ops.edge.aten.mean.dim]) + assert graph_contains_any_of_ops(ep.graph, [MeanDim]) assert not any("lowered_module" in n.name for n in ep.graph.nodes) @@ -252,7 +272,7 @@ def test_mean_dim__formatless__unsupported_channels( ).exported_program() # Make sure the `mean.dim` was NOT delegated. - assert graph_contains_any_of_ops(ep.graph, [exir_ops.edge.aten.mean.dim]) + assert graph_contains_any_of_ops(ep.graph, [MeanDim]) assert not any("lowered_module" in n.name for n in ep.graph.nodes) @@ -277,4 +297,181 @@ def test_mean_dim__channels_first__unsupported_channels( ).exported_program() # Make sure the `mean.dim` was NOT delegated. - assert graph_contains_any_of_ops(ep.graph, [exir_ops.edge.aten.mean.dim]) + assert graph_contains_any_of_ops(ep.graph, [MeanDim]) + + +class MaxPoolMeanDimModule(torch.nn.Module): + def __init__(self, dim, keepdim): + super().__init__() + self.dim, self.keepdim = dim, keepdim + + def forward(self, x): + x = torch.max_pool2d( + x, kernel_size=1 + ) # NoOp, but it enforces the channels first format. + return torch.mean(x, dim=self.dim, keepdim=self.keepdim) + + +class TestMeanDimNewNeutronFlow: + + # noinspection PyMethodMayBeStatic + def assert_delegated( + self, + model, + input_shape, + mocker, + use_qat=False, + atol=None, + expected_delegated_ops=None, + ): + if expected_delegated_ops is None: + expected_delegated_ops = {MeanDim: 1} + + graph_verifier = DetailedGraphVerifier( + mocker, + expected_delegated_ops=expected_delegated_ops, + expected_non_delegated_ops={}, + ) + + # Cover also negative values to thoroughly test the operator. + dataset_creator = RandomDatasetCreator(low=-2, high=2) + + kwargs = {"atol": atol} if atol is not None else {} + output_comparator = AllCloseOutputComparator(**kwargs) + + lower_run_compare( + model, + input_shape, + graph_verifier, + dataset_creator, + output_comparator, + use_qat=use_qat, + use_new_flow_neutron_c=True, # Use the new flow. + ) + + # noinspection PyMethodMayBeStatic + def assert_not_delegated(self, model, input_shape): + delegated_ep = to_quantized_edge_program( + model, input_shape, use_new_flow_neutron_c=True + ).exported_program() + + # Make sure the `mean` was NOT delegated. + assert not graph_contains_any_of_ops( + delegated_ep.graph, [ExecutorchDelegateCall] + ) + assert graph_contains_any_of_ops(delegated_ep.graph, [MeanDim]) + + @pytest.fixture(params=[True, False], ids=lambda keep_dim: f"keep_dim = {keep_dim}") + def keep_dim(self, request): + return request.param + + def test__basic_nsys_inference__qat(self, mocker, use_qat, keep_dim): + input_shape = (23,) + model = MeanDimModule(0, keep_dim) + self.assert_delegated(model, input_shape, mocker, use_qat=use_qat) + + @pytest.mark.parametrize( + "input_shape, dim", + [ + pytest.param((5,), 0, id="1D, dim = 0."), + pytest.param((4, 2), 0, id="2D, dim = 0."), + pytest.param((4, 2), -1, id="2D, dim = -1."), + pytest.param((3, 1, 4), 2, id="3D, dim = 2."), + pytest.param((1, 3, 3, 7), 3, id="4D, dim = 3."), + pytest.param((3, 1, 4, 1, 5), -1, id="5D, dim = -1."), + pytest.param((3, 1, 4, 1, 5), 0, id="5D, dim = 0."), + ], + ) + def test__single_dims(self, mocker, input_shape, dim, keep_dim): + model = MeanDimModule(dim, keep_dim) + # Relatively large error, but it is actually equal to the output scale, so it is a single bit error. + # TODO Replace with quantized dataset testing and `atol = 1`. + atol = 0.014 + self.assert_delegated(model, input_shape, mocker, atol=atol) + + @pytest.mark.parametrize( + "input_shape, dim", + [ + pytest.param((4, 2), (-2,), id="2D, dim = (-2,)."), + pytest.param((2, 3, 4), (0, 2), id="3D, dim = (0, 2,)."), + pytest.param((1, 3, 3, 7), (2, -3), id="4D, dim = (2, -3)."), + pytest.param((3, 1, 4, 1, 5), (3, -5, -4), id="5D, dim = (3, -5 ,-4)."), + ], + ) + def test__tuple_dims(self, mocker, input_shape, dim, keep_dim): + model = MeanDimModule(dim, keep_dim) + # Relatively large error, but it is actually equal to the output scale, so it is a single bit error. + # TODO Replace with quantized dataset testing and `atol = 1`. + atol = 0.015 + self.assert_delegated(model, input_shape, mocker, atol=atol) + + def test__compute_error(self, mocker, keep_dim): + input_shape, dim = (1, 3, 3, 7), -2 + model = MeanDimModule(dim, keep_dim) + + # Neutron produces an incorrect result in this case (maximum absolute error ~= 0.0607 (more than 2 * scale)). + # This test detects the failure to alert us once the bug is fixed. It should be fixed in Neutron 3.1.2. + with pytest.raises(AssertionError): + self.assert_delegated(model, input_shape, mocker, atol=0.06) + + @pytest.mark.parametrize( + "input_shape, dim", + [ + pytest.param((3, 1, 4), 1, id="3D, dim = 1."), + pytest.param((3, 1, 4, 1, 5), -2, id="5D, dim = -2."), + ], + ) + def test__noop__only_node__not_delegated(self, input_shape, dim): + keep_dim = True # Reduction over a dimension of size `1` with `keep_dim=True` is a no-op. + model = MeanDimModule(dim, keep_dim) + self.assert_not_delegated(model, input_shape) + + @pytest.mark.parametrize( + "input_shape, dim", + [ + pytest.param((3, 1, 4), 1, id="3D, dim = 1."), + pytest.param((3, 1, 4, 1, 5), -2, id="5D, dim = -2."), + ], + ) + def test__noop__not_only_node__delegated(self, mocker, input_shape, dim): + keep_dim = True # Reduction over a dimension of size `1` with `keep_dim=True` is a no-op. + model = MeanDimAddModule(dim, keep_dim) + self.assert_delegated( + model, + input_shape, + mocker, + expected_delegated_ops={MeanDim: 1, AddTensor: 1}, + ) + + @pytest.mark.parametrize( + "input_shape, dim", + [ + pytest.param((3, 1, 4), 1, id="3D, dim = 1."), + pytest.param((3, 1, 4, 1, 5), -2, id="5D, dim = -2."), + ], + ) + def test__no_reduction__keepdim_false__delegated(self, mocker, input_shape, dim): + # These cases reduce over a dimension of size 1. + # When `keep_dim=True` the node is a noop, and it's not delegated (see `test__noop__only_node__not_delegated`), + # but with `keep_dim=False` it changes the shape so it's not a noop and is therefore delegated successfully. + keep_dim = False + model = MeanDimModule(dim, keep_dim) + self.assert_delegated(model, input_shape, mocker) + + @pytest.mark.parametrize( + "input_shape, dim", + [((1, 7, 3, 3), 1)], + ids=lambda val: f"shape={val}" if isinstance(val, tuple) else f"dim={val}", + ) + def test__channels_first(self, mocker, input_shape, dim, keep_dim): + # Just 1 test case to verify correct handling of the `dim`. + # Most cases fall into the single bit error case, and since this test uses 2 operators, the error accumulates + # and the final error is larger. We cannot with 100% certainty say that the error is only caused by the single + # bit errors and not related to the format. That's why only this 1 case with no errors is used. + model = MaxPoolMeanDimModule(dim, keep_dim) + self.assert_delegated( + model, + input_shape, + mocker, + expected_delegated_ops={MaxPool2DWithIndices: 1, GetItem: 1, MeanDim: 1}, + ) diff --git a/backends/nxp/tests/ir/converter/node_converter/test_mul_tensor_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_mul_tensor_converter.py index 927af47bbf5..90113f484ad 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_mul_tensor_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_mul_tensor_converter.py @@ -21,9 +21,6 @@ ToChannelLastPreprocess, ) from executorch.backends.nxp.tests.graph_verifier import DetailedGraphVerifier -from executorch.backends.nxp.tests.model_output_comparator import ( - NumericalStatsOutputComparator, -) from executorch.backends.nxp.tests.models import ( MulTensorConvModule, MulTensorModule, @@ -256,7 +253,6 @@ def test__basic_nsys_inference(self, x_input_shape, mocker): def test__basic_nsys_inference_qat(self, x_input_shape, mocker): x_input_spec = ModelInputSpec(x_input_shape) model = MulTensorModule() - comparator = NumericalStatsOutputComparator() graph_verifier = DetailedGraphVerifier( mocker, expected_delegated_ops={MulTensor: 1}, expected_non_delegated_ops={} ) @@ -265,7 +261,6 @@ def test__basic_nsys_inference_qat(self, x_input_shape, mocker): model, [x_input_spec, x_input_spec], graph_verifier, - output_comparator=comparator, use_new_flow_neutron_c=True, use_qat=True, ) diff --git a/backends/nxp/tests/ir/converter/node_converter/test_slice_tensor_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_slice_tensor_converter.py index 78886558ba2..39fa900ca55 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_slice_tensor_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_slice_tensor_converter.py @@ -8,6 +8,7 @@ from executorch.backends.nxp.backend.edge_program_converter import ( EdgeProgramToIRConverter, ) +from executorch.backends.nxp.tests.dataset_creator import RandomDatasetCreator from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program from executorch.backends.nxp.tests.executors import ( convert_run_compare, @@ -15,12 +16,22 @@ ToChannelFirstPreprocess, ToChannelLastPreprocess, ) +from executorch.backends.nxp.tests.graph_verifier import DetailedGraphVerifier +from executorch.backends.nxp.tests.model_output_comparator import ( + AllCloseOutputComparator, +) from executorch.backends.nxp.tests.models import ( SliceTensorConvModule, SliceTensorModule, ) -from executorch.exir.dialects._ops import ops as exir_ops +from executorch.backends.nxp.tests.nsys_testing import lower_run_compare +from executorch.backends.nxp.tests.ops_aliases import ( + Convolution, + ExecutorchDelegateCall, + Slice, + SliceCopy, +) from torch.export import ExportedProgram @@ -30,11 +41,6 @@ def reseed_model_per_test_run(): np.random.seed(23) -ExecutorchDelegateCall = torch.ops.higher_order.executorch_call_delegate -Slice = exir_ops.edge.aten.slice.Tensor -SliceCopy = exir_ops.edge.aten.slice_copy.Tensor - - passing_cases = [ pytest.param((24, 32), (0, 1), (0, 16), (24, 32), id="2D, no transpose"), pytest.param( @@ -238,7 +244,7 @@ def test_slice_tensor_w_conv_quant_conversion( (24, 32), (0, 1), (0, 32), (24, 32), id="2D, start is equal to size" ), pytest.param( - (24, 32), (0, 1), (0, 0), (24, -5), id="2D, clipped end equal to zero" + (24, 32), (0, 1), (0, 0), (24, -35), id="2D, clipped end equal to zero" ), pytest.param( (24, 32), (0, 1), (64, 0), (24, 32), id="2D, clipped start equal to size" @@ -298,3 +304,353 @@ def test_slice_not_delegated(mocker, x_input_shape, dims, starts, ends): for i in range(0, num_slice_ops): slice_idx = (i + 1) * 3 assert nodes[slice_idx].target in [Slice, SliceCopy] + + +class TestSliceTensorConverterNewNeutronFlow: + @staticmethod + def _slice_id(prefix, input_shape, dims, starts, ends): + return f"{prefix}rank={len(input_shape)}_dims={str(dims)}_starts={str(starts)}_ends={str(ends)}" + + @staticmethod + def assert_delegated_and_correct(model, input_shape, num_slices, mocker, use_qat): + graph_verifier = DetailedGraphVerifier( + mocker, + expected_delegated_ops={SliceCopy: num_slices}, + expected_non_delegated_ops={}, + ) + dataset = RandomDatasetCreator(low=-255.0, high=255.0) + comparator = AllCloseOutputComparator() + + lower_run_compare( + model, + input_shape, + graph_verifier, + dataset, + comparator, + use_new_flow_neutron_c=True, + use_qat=use_qat, + ) + + @staticmethod + def assert_model_without_slices(model, input_shape): + delegated_ep = to_quantized_edge_program( + model, input_shape, use_new_flow_neutron_c=True + ).exported_program() + + # Check there are no slices and nothing is delegated + assert not graph_contains_any_of_ops( + delegated_ep.graph, [ExecutorchDelegateCall] + ) + assert not graph_contains_any_of_ops(delegated_ep.graph, [Slice, SliceCopy]) + + @staticmethod + def assert_not_delegated(model, input_shape): + delegated_ep = to_quantized_edge_program( + model, input_shape, use_new_flow_neutron_c=True + ).exported_program() + + # Make sure the `slice` was NOT delegated. + assert not graph_contains_any_of_ops( + delegated_ep.graph, [ExecutorchDelegateCall] + ) + assert graph_contains_any_of_ops(delegated_ep.graph, [Slice, SliceCopy]) + + @pytest.mark.parametrize( + "input_shape, dims, starts, ends", + [ + pytest.param( + ins := (5, 2, 3, 4), + d := (0,), + s := (1,), + e := (4,), + id=_slice_id("basic, left and right trimmed:", ins, d, s, e), + ), + pytest.param( + ins := (5, 5, 3, 4), + d := (0, 1), + s := (1, 1), + e := (4, 3), + id=_slice_id("basic, left and right trimmed:", ins, d, s, e), + ), + pytest.param( + ins := (7, 13, 5, 15), + d := (0, 1, 2, 3), + s := (4, 3, 1, 8), + e := (5, 10, 4, 11), + id=_slice_id("basic, left and right trimmed:", ins, d, s, e), + ), + pytest.param( + ins := (5, 13, 5, 13), + d := (0, 1, 2, 3), + s := (0, 0, 0, 0), + e := (4, 11, 4, 11), + id=_slice_id("basic, right trimmed:", ins, d, s, e), + ), + pytest.param( + ins := (7, 13, 3, 15), + d := (0, 1, 2, 3), + s := (2, 5, 1, 4), + e := ins, + id=_slice_id("basic, left trimmed:", ins, d, s, e), + ), + pytest.param( + ins := (7, 4, 7), + d := (0, 1, 2), + s := (1, 1, 3), + e := (6, 3, 5), + id=_slice_id("basic, left and right trimmed:", ins, d, s, e), + ), + pytest.param( + ins := (4, 5, 9), + d := (0, 1, 2), + s := (0, 0, 0), + e := (3, 4, 7), + id=_slice_id("basic, right trimmed:", ins, d, s, e), + ), + pytest.param( + ins := (4, 7, 9), + d := (0, 1, 2), + s := (3, 2, 2), + e := ins, + id=_slice_id("basic, left trimmed:", ins, d, s, e), + ), + pytest.param( + ins := (4, 5), + d := (0, 1), + s := (1, 1), + e := (2, 4), + id=_slice_id("basic, left and right trimmed:", ins, d, s, e), + ), + pytest.param( + ins := (4, 5), + d := (0, 1), + s := (0, 0), + e := (2, 4), + id=_slice_id("basic, right trimmed:", ins, d, s, e), + ), + pytest.param( + ins := (4, 5), + d := (0, 1), + s := (1, 2), + e := ins, + id=_slice_id("basic, left trimmed:", ins, d, s, e), + ), + pytest.param( + ins := (5,), + d := (0,), + s := (1,), + e := (4,), + id=_slice_id("basic, left and right trimmed:", ins, d, s, e), + ), + pytest.param( + ins := (5,), + d := (0,), + s := (0,), + e := (4,), + id=_slice_id("basic, right trimmed:", ins, d, s, e), + ), + pytest.param( + ins := (5,), + d := (0,), + s := (1,), + e := ins, + id=_slice_id("basic, left trimmed:", ins, d, s, e), + ), + ], + ) + def test_nsys_inference__basic(self, input_shape, dims, starts, ends, mocker): + model = SliceTensorModule(dims, starts, ends) + + num_slices = len(dims) + self.assert_delegated_and_correct( + model, input_shape, num_slices, mocker, use_qat=False + ) + + @pytest.mark.parametrize( + "input_shape, dims, starts, ends", + [ + pytest.param( + ins := (4, 2, 7, 4), + d := (2,), + s := (5,), + e := (6,), + id=_slice_id("edge case, dimension reduced to 1:", ins, d, s, e), + ), + pytest.param( + ins := (11, 2, 7, 5), + d := (2,), + s := (6,), + e := (6,), + id=_slice_id("edge case, dimension reduced to 0:", ins, d, s, e), + ), + ], + ) + def test_nsys_inference__reduction(self, input_shape, dims, starts, ends, mocker): + model = SliceTensorModule(dims, starts, ends) + + slice_lengths = [e - s for s, e in zip(starts, ends)] + if all(sl == 0 for sl in slice_lengths): + # reductions to 0 are disabled in the backend + self.assert_not_delegated(model, input_shape) + else: + num_slices = len(dims) + self.assert_delegated_and_correct( + model, input_shape, num_slices, mocker, use_qat=False + ) + + @pytest.mark.parametrize( + "input_shape, dims, starts, ends", + [ + pytest.param( + ins := (5, 2, 3, 4), + d := (0,), + s := (-12,), + e := (2,), + id=_slice_id("edge case, `start` clipped:", ins, d, s, e), + ), + pytest.param( + ins := (5, 7, 5, 7), + d := (0,), + s := (1,), + e := (12,), + id=_slice_id("edge case, `end` clipped:", ins, d, s, e), + ), + ], + ) + def test_nsys_inference__clipped(self, input_shape, dims, starts, ends, mocker): + model = SliceTensorModule(dims, starts, ends) + + num_slices = len(dims) + self.assert_delegated_and_correct( + model, input_shape, num_slices, mocker, use_qat=False + ) + + @pytest.mark.parametrize( + "input_shape, dims, starts, ends", + [ + pytest.param( + ins := (5, 11, 13, 3), + d := (1,), + s := (-5,), + e := (10,), + id=_slice_id("edge case, `start` normalized:", ins, d, s, e), + ), + pytest.param( + ins := (7, 15, 5, 7), + d := (1,), + s := (2,), + e := (-2,), + id=_slice_id("edge case, `end` normalized:", ins, d, s, e), + ), + ], + ) + def test_nsys_inference__normalization( + self, input_shape, dims, starts, ends, mocker + ): + model = SliceTensorModule(dims, starts, ends) + + num_slices = len(dims) + self.assert_delegated_and_correct( + model, input_shape, num_slices, mocker, use_qat=False + ) + + @pytest.mark.parametrize( + "input_shape, dims, starts, ends", + [ + pytest.param( + ins := (5000, 3, 5, 3), + d := (0,), + s := (1250,), + e := (2500,), + id=_slice_id("big args, left and right trimmed:", ins, d, s, e), + ), + pytest.param( + ins := (2, 5000, 5, 3), + d := (1,), + s := (0,), + e := (4999,), + id=_slice_id("big args, right trimmed:", ins, d, s, e), + ), + pytest.param( + ins := (2, 3, 5000, 3), + d := (2,), + s := (1,), + e := (5000,), + id=_slice_id("big args, left trimmed:", ins, d, s, e), + ), + ], + ) + def test_nsys_inference__big(self, input_shape, dims, starts, ends, mocker): + model = SliceTensorModule(dims, starts, ends) + + num_slices = len(dims) + self.assert_delegated_and_correct( + model, input_shape, num_slices, mocker, use_qat=False + ) + + @pytest.mark.parametrize( + "input_shape, dims, starts, ends", + [ + pytest.param( + ins := (5, 2, 3, 4), + d := (2,), + s := (0,), + e := (3,), + id=_slice_id("edge case, one dimension identity:", ins, d, s, e), + ), + pytest.param( + ins := (5, 2, 3, 4), + d := (0, 1, 2, 3), + s := (0, 0, 0, 0), + e := ins, + id=_slice_id("edge case, all dimensions identity:", ins, d, s, e), + ), + ], + ) + def test_nsys_inference__identity(self, input_shape, dims, starts, ends): + model = SliceTensorModule(dims, starts, ends) + + self.assert_model_without_slices(model, input_shape) + + def test_nsys_inference__with_conv(self, mocker): + input_shape = (11, 13, 5, 7) + in_channels = input_shape[1] + out_channels = 19 + + # we test functionality on `channels` dim + dims = (1,) + starts = (2,) + ends = (out_channels - 2,) + model = SliceTensorConvModule(dims, starts, ends, in_channels, out_channels) + + num_slices = len(dims) + graph_verifier = DetailedGraphVerifier( + mocker, + expected_delegated_ops={SliceCopy: num_slices}, + expected_non_delegated_ops={Convolution: 1}, + ) + dataset = RandomDatasetCreator(low=-255.0, high=255.0) + comparator = AllCloseOutputComparator() + + lower_run_compare( + model, + input_shape, + graph_verifier, + dataset, + comparator, + use_new_flow_neutron_c=True, + use_qat=False, + ) + + def test_nsys_inference__qat(self, mocker): + input_shape = (7, 13, 7, 9) + dims = (0, 1, 2, 3) + starts = (1, 2, 3, 2) + ends = (6, 10, 5, 8) + + model = SliceTensorModule(dims, starts, ends) + + num_slices = len(dims) + self.assert_delegated_and_correct( + model, input_shape, num_slices, mocker, use_qat=True + ) diff --git a/backends/nxp/tests/ir/converter/node_converter/test_sub_tensor_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_sub_tensor_converter.py index 9ce3e93f39b..2734e89bc5d 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_sub_tensor_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_sub_tensor_converter.py @@ -1,7 +1,8 @@ -# Copyright 2025 NXP +# Copyright 2025-2026 NXP # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. + import numpy as np import pytest import torch @@ -9,18 +10,29 @@ from executorch.backends.nxp.backend.edge_program_converter import ( EdgeProgramToIRConverter, ) -from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program +from executorch.backends.nxp.tests.dataset_creator import RandomDatasetCreator +from executorch.backends.nxp.tests.executorch_pipeline import ( + ModelInputSpec, + to_quantized_edge_program, +) from executorch.backends.nxp.tests.executors import ( convert_run_compare, + graph_contains_any_of_ops, ToChannelFirstPreprocess, ToChannelLastPreprocess, ) +from executorch.backends.nxp.tests.graph_verifier import DetailedGraphVerifier from executorch.backends.nxp.tests.models import ( SubTensorConvModule, SubTensorModule, SubTensorOneInputModule, ) -from executorch.exir.dialects._ops import ops as exir_ops +from executorch.backends.nxp.tests.nsys_testing import lower_run_compare +from executorch.backends.nxp.tests.ops_aliases import ( + Convolution, + ExecutorchDelegateCall, + SubTensor, +) from torch.export import ExportedProgram from executorch.backends.nxp.tests.use_qat import * # noqa F403 @@ -63,7 +75,7 @@ def test_sub_tensor_quant_conversion(mocker, input_shape, use_qat): input_data = {0: input_data_1, 1: input_data_2} nodes = list(exported_program.graph.nodes) - assert nodes[4].target == exir_ops.edge.aten.sub.Tensor + assert nodes[4].target == SubTensor convert_run_compare( exported_program, tfl_model=tflite_flatbuffers_model, input_data=input_data @@ -96,7 +108,7 @@ def test_sub_tensor_one_input_quant_conversion(mocker, input_shape, use_qat): input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype(np.int8) nodes = list(exported_program.graph.nodes) - assert nodes[2].target == exir_ops.edge.aten.sub.Tensor + assert nodes[2].target == SubTensor convert_run_compare( exported_program, tfl_model=tflite_flatbuffers_model, input_data=input_data @@ -141,7 +153,7 @@ def test_sub_tensor_w_conv_quant_conversion(mocker, x_input_shape, use_qat): input_data = {0: input_data_1, 1: input_data_2} nodes = list(exported_program.graph.nodes) - assert nodes[15].target == exir_ops.edge.aten.sub.Tensor + assert nodes[15].target == SubTensor convert_run_compare( exported_program, @@ -176,6 +188,236 @@ def test_sub_tensor_broadcasting_unsupported_quant_conversion( nodes = list(edge_program.graph.nodes) # Broadcast is not supported, node is not converted - assert ( - nodes[6].target == exir_ops.edge.aten.sub.Tensor - ) # Sub Tensor is not delegated. + assert nodes[6].target == SubTensor # Sub Tensor is not delegated. + + +class TestSubTensorNewNeutronFlow: + @pytest.mark.parametrize( + "x_input_shape", + [ + pytest.param((1,), id="1D."), + pytest.param((6, 5), id="2D."), + pytest.param((1, 4, 7), id="3D."), + pytest.param( + (6, 82), + id="2D incorrect.", + marks=pytest.mark.xfail(reason="AIR-14602: incorrect results"), + ), + pytest.param( + (1, 68, 7), + id="3D incorrect.", + marks=pytest.mark.xfail(reason="AIR-14602: incorrect results"), + ), + pytest.param( + (2, 4, 3, 15), + id="4D incorrect.", + marks=pytest.mark.xfail(reason="AIR-14602: incorrect results"), + ), + pytest.param( + (1, 4, 9, 11, 4), + id="5D incorrect.", + marks=pytest.mark.xfail(reason="AIR-14602: incorrect results"), + ), + ], + ) + def test__basic_nsys_inference(self, x_input_shape, mocker): + x_input_spec = ModelInputSpec(x_input_shape) + model = SubTensorModule() + graph_verifier = DetailedGraphVerifier( + mocker, expected_delegated_ops={SubTensor: 1}, expected_non_delegated_ops={} + ) + dataset_creator = RandomDatasetCreator(low=-1.0, high=1.0) + + lower_run_compare( + model, + [x_input_spec, x_input_spec], + graph_verifier, + dataset_creator, + use_new_flow_neutron_c=True, + ) + + @pytest.mark.parametrize( + "x_input_shape", + [ + pytest.param((1,), id="1D."), + pytest.param((6, 5), id="2D."), + pytest.param((2, 4, 3, 15), id="4D."), + pytest.param( + (1, 4, 7), + id="3D incorrect.", + marks=pytest.mark.xfail(reason="AIR-14602: incorrect results"), + ), + pytest.param( + (1, 4, 9, 11, 4), + id="5D incorrect.", + marks=pytest.mark.xfail(reason="AIR-14602: incorrect results"), + ), + ], + ) + def test__basic_nsys_inference_qat(self, x_input_shape, mocker): + x_input_spec = ModelInputSpec(x_input_shape) + model = SubTensorModule() + graph_verifier = DetailedGraphVerifier( + mocker, expected_delegated_ops={SubTensor: 1}, expected_non_delegated_ops={} + ) + dataset_creator = RandomDatasetCreator(low=-1.0, high=1.0) + + lower_run_compare( + model, + [x_input_spec, x_input_spec], + graph_verifier, + dataset_creator, + use_new_flow_neutron_c=True, + use_qat=True, + ) + + @pytest.mark.parametrize( + "input_spec", + [ + pytest.param( + [ModelInputSpec((4, 6)), ModelInputSpec((1, 6))], id="2 inputs 2D." + ), + pytest.param( + [ModelInputSpec((4,)), ModelInputSpec((4, 4))], id="2 inputs 1D + 2D." + ), + pytest.param( + [ModelInputSpec((5, 3, 4)), ModelInputSpec((1, 3, 1))], + id="2 inputs 3D incorrect.", + marks=pytest.mark.xfail(reason="AIR-14602: incorrect results"), + ), + pytest.param( + [ModelInputSpec((69, 73)), ModelInputSpec((1, 73))], + id="2 inputs 2D incorrect.", + marks=pytest.mark.xfail(reason="AIR-14602: incorrect results"), + ), + ], + ) + def test__broadcast(self, input_spec, mocker): + model = SubTensorModule() + graph_verifier = DetailedGraphVerifier( + mocker, expected_delegated_ops={SubTensor: 1}, expected_non_delegated_ops={} + ) + dataset_creator = RandomDatasetCreator(low=-1.0, high=1.0) + + lower_run_compare( + model, + input_spec, + graph_verifier, + dataset_creator, + use_new_flow_neutron_c=True, + ) + + @pytest.mark.parametrize( + "input_spec", + [ + pytest.param( + [ModelInputSpec((4, 1)), ModelInputSpec((1, 6))], id="2 inputs 2D." + ), + pytest.param( + [ModelInputSpec((1, 3, 4)), ModelInputSpec((5, 3, 1))], + id="2 inputs 3D.", + ), + pytest.param( + [ModelInputSpec((6, 4)), ModelInputSpec((6, 6, 1))], + id="2 inputs 2D+3D.", + ), + ], + ) + def test__broadcast_unsupported(self, input_spec): + # Broadcast where at least one of the inputs is not equal to output is not supported + model = SubTensorModule() + + delegated_ep = to_quantized_edge_program( + model, input_spec, use_new_flow_neutron_c=True + ).exported_program() + + # Make sure the `sub.Tensor` was NOT delegated. + assert not graph_contains_any_of_ops( + delegated_ep.graph, [ExecutorchDelegateCall] + ) + assert graph_contains_any_of_ops(delegated_ep.graph, [SubTensor]) + + @pytest.mark.parametrize( + "x_input_shape", + [ + pytest.param( + (1, 4, 5, 5), id="4D, product of dims is not a multiple of 8." + ), + ], + ) + def test__w_conv(self, x_input_shape, mocker): + model = SubTensorConvModule() + + n, c, h, w = x_input_shape + y_input_spec = ModelInputSpec((n, 8, h, w)) + x_input_spec = ModelInputSpec(x_input_shape) + + graph_verifier = DetailedGraphVerifier( + mocker, + expected_delegated_ops={SubTensor: 1, Convolution: 1}, + expected_non_delegated_ops={}, + ) + dataset_creator = RandomDatasetCreator(low=-1.0, high=1.0) + + lower_run_compare( + model, + [x_input_spec, y_input_spec], + graph_verifier, + dataset_creator, + use_new_flow_neutron_c=True, + ) + + @pytest.mark.parametrize( + "input_spec", + [ + pytest.param( + [ModelInputSpec((1, 4, 7, 1)), ModelInputSpec((1, 8, 1, 1))], + id="2 inputs 4D + 4D.", + ), + pytest.param( + [ModelInputSpec((1, 4, 5, 5)), ModelInputSpec((1, 8, 5, 1))], + id="2 inputs 4D + 4D incorrect.", + marks=pytest.mark.xfail(reason="AIR-14602: incorrect results"), + ), + ], + ) + def test__w_conv_broadcast(self, input_spec, mocker): + model = SubTensorConvModule() + graph_verifier = DetailedGraphVerifier( + mocker, + expected_delegated_ops={SubTensor: 1, Convolution: 1}, + expected_non_delegated_ops={}, + ) + dataset_creator = RandomDatasetCreator(low=-1.0, high=1.0) + + lower_run_compare( + model, + input_spec, + graph_verifier, + dataset_creator, + use_new_flow_neutron_c=True, + ) + + @pytest.mark.parametrize( + "input_spec", + [ + pytest.param( + [ModelInputSpec((1, 4, 5, 5)), ModelInputSpec((1, 5))], + id="2 inputs 4D + 2D.", + ), + pytest.param( + [ModelInputSpec((1, 4, 4, 10)), ModelInputSpec((1, 4, 1))], + id="2 inputs 4D + 3D.", + ), + ], + ) + def test__w_conv_unsupported(self, input_spec): + model = SubTensorConvModule() + + delegated_ep = to_quantized_edge_program( + model, input_spec, use_new_flow_neutron_c=True + ).exported_program() + + # Make sure the `sub.Tensor` was NOT delegated. + assert graph_contains_any_of_ops(delegated_ep.graph, [ExecutorchDelegateCall]) + assert graph_contains_any_of_ops(delegated_ep.graph, [SubTensor]) diff --git a/backends/nxp/tests/ir/converter/node_converter/test_tanh_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_tanh_converter.py index 10892d28e38..ba2f5bf07d1 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_tanh_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_tanh_converter.py @@ -1,4 +1,4 @@ -# Copyright 2025 NXP +# Copyright 2025-2026 NXP # All rights reserved. # # This source code is licensed under the BSD-style license found in the @@ -8,9 +8,13 @@ import kgb import numpy as np + +# noinspection PyUnusedImports +import pytest import torch from executorch.backends.nxp.nxp_backend import EdgeProgramToIRConverter +from executorch.backends.nxp.tests.dataset_creator import RandomDatasetCreator from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program from executorch.backends.nxp.tests.executors import ( convert_run_compare, @@ -18,10 +22,13 @@ ToChannelFirstPreprocess, ToChannelLastPreprocess, ) +from executorch.backends.nxp.tests.graph_verifier import DetailedGraphVerifier from executorch.backends.nxp.tests.models import Conv2dWithActivation -from executorch.exir.dialects._ops import ops as exir_ops +from executorch.backends.nxp.tests.nsys_testing import lower_run_compare +from executorch.backends.nxp.tests.ops_aliases import Convolution, Tanh, Tanh_ from parameterized import parameterized from torch.export import ExportedProgram +from executorch.backends.nxp.tests.use_qat import * # noqa F403 class TestTanhConverter(unittest.TestCase): @@ -73,10 +80,7 @@ def test_conv_tanh( lowered_module_graph = ( quantized_program.graph_module.lowered_module_0.original_module.graph ) - tanh_ops = [ - exir_ops.edge.aten.tanh.default, - exir_ops.edge.aten.tanh_.default, - ] + tanh_ops = [Tanh, Tanh_] assert graph_contains_any_of_ops(graph=lowered_module_graph, ops=tanh_ops) input_data = (np.random.random(input_shape) * 50).astype(np.int8) @@ -88,3 +92,82 @@ def test_conv_tanh( input_data=input_data, atol=2.0, ) + + +class TanhModule(torch.nn.Module): + def __init__(self, inplace: bool = False): + super().__init__() + self.inplace = inplace + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.inplace: + return torch.tanh_(x) + else: + return torch.tanh(x) + + +class TestTanhNewNeutronFlow: + + # noinspection PyMethodMayBeStatic + def assert_delegated( + self, + model, + input_shape, + mocker, + use_qat=False, + expected_delegated_ops=None, + ): + if expected_delegated_ops is None: + expected_delegated_ops = {Tanh: 1} + + graph_verifier = DetailedGraphVerifier( + mocker, + expected_delegated_ops=expected_delegated_ops, + expected_non_delegated_ops={}, + ) + + # Cover also negative values to thoroughly test the operator. + dataset_creator = RandomDatasetCreator(low=-2, high=2) + + lower_run_compare( + model, + input_shape, + graph_verifier, + dataset_creator, + use_qat=use_qat, + use_new_flow_neutron_c=True, # Use the new flow. + ) + + @pytest.fixture(params=[True, False], ids=lambda inplace: f"inplace = {inplace}") + def inplace(self, request): + return request.param + + def test__qat__inplace(self, mocker, use_qat, inplace): + shape = (23,) + model = TanhModule(inplace) + self.assert_delegated(model, shape, mocker, use_qat=use_qat) + + @pytest.mark.parametrize( + "shape", + [ + (16,), + (3, 5), + (2, 3, 4), + (2, 3, 4, 5), + (2, 3, 2, 3, 2), + ], + ids=lambda shape: f"{len(shape)}D", + ) + def test__shapes(self, mocker, shape): + model = TanhModule() + self.assert_delegated(model, shape, mocker) + + def test__with_convolution(self, mocker): + input_shape = (1, 3, 12, 16) + channels = input_shape[1] + model = Conv2dWithActivation( + activation=torch.tanh, in_channels=channels, out_channels=channels + ) + self.assert_delegated( + model, input_shape, mocker, expected_delegated_ops={Tanh: 1, Convolution: 1} + ) diff --git a/backends/nxp/tests/models.py b/backends/nxp/tests/models.py index 045dcfaba40..0383734b4dd 100644 --- a/backends/nxp/tests/models.py +++ b/backends/nxp/tests/models.py @@ -456,11 +456,16 @@ def forward(self, x): class Conv2dWithActivation(torch.nn.Module): - def __init__(self, activation: torch.nn.Module | Callable, in_channels: int = 3): + def __init__( + self, + activation: torch.nn.Module | Callable, + in_channels: int = 3, + out_channels: int = 64, + ): super().__init__() self.conv = torch.nn.Conv2d( - in_channels=in_channels, out_channels=64, kernel_size=(3, 3) + in_channels=in_channels, out_channels=out_channels, kernel_size=(3, 3) ) self.activation = activation @@ -656,9 +661,9 @@ def __init__(self): super().__init__() self.conv = Conv2dModule(padding=1, stride=1) - def forward(self, x): + def forward(self, x, y): x = self.conv(x) - return x + x + return x + y class AddTensorOneInputModule(torch.nn.Module): diff --git a/backends/nxp/tests/nsys_testing.py b/backends/nxp/tests/nsys_testing.py index 636e1a28a44..ab5a583ede0 100644 --- a/backends/nxp/tests/nsys_testing.py +++ b/backends/nxp/tests/nsys_testing.py @@ -23,7 +23,11 @@ ) from executorch.backends.nxp.neutron_partitioner import NeutronPartitioner from executorch.backends.nxp.tests.config_importer import test_config -from executorch.backends.nxp.tests.dataset_creator import RandomDatasetCreator +from executorch.backends.nxp.tests.dataset_creator import ( + create_quantized_variant_of_dataset, + InputQuantizationSpec, + RandomDatasetCreator, +) from executorch.backends.nxp.tests.executorch_pipeline import ( get_calibration_inputs_fn_from_dataset_dir, ModelInputSpec, @@ -61,20 +65,7 @@ class ReferenceModel(Enum): FLOAT_PYTORCH_PYTHON = 4 -def _run_delegated_executorch_program( - model, - test_dir, - test_name, - calibration_dataset_dir, - testing_dataset_dir, - input_spec, - dlg_model_verifier, - npu_results_dir, - mocker, - use_qat: bool = False, - train_fn: Callable[[torch.fx.GraphModule], None] | None = None, - use_new_flow_neutron_c: bool = False, -) -> ExportedProgram: +def _get_dataset_cli_args(input_spec: list[ModelInputSpec], testing_dataset_dir): if len(input_spec) == 1: # Single input, use --dataset dataset_cli = "--dataset" @@ -90,14 +81,25 @@ def _run_delegated_executorch_program( ] ) ) + return dataset_cli, dataset_or_inputs - # Run nxp_executor_runner with program delegated to NPU - delegated_model_path = os.path.abspath( - os.path.join(test_dir, f"{test_name}_delegated.pte") - ) - delegated_cmd = f"{NEUTRON_TEST_PATH} --model {delegated_model_path} {dataset_cli} {dataset_or_inputs} \ - --output {npu_results_dir} --firmware {NSYS_FIRMWARE_PATH} --nsys {NSYS_PATH} --nsys_config {NSYS_CONFIG_PATH}" +def _run_delegated_executorch_program( + model, + test_dir, + test_name, + calibration_dataset_dir, + testing_dataset_dir, + input_spec, + dlg_model_verifier, + npu_results_dir, + mocker, + use_qat: bool = False, + train_fn: Callable[[torch.fx.GraphModule], None] | None = None, + use_new_flow_neutron_c: bool = False, + operators_not_to_delegate: list[str] = None, + remove_quant_io_ops: bool = False, +) -> tuple[ExportedProgram, str]: try: if mocker: method = getattr(NeutronPartitioner, "partition") # noqa B009 @@ -123,6 +125,8 @@ def wrapper(*args, **kwargs): use_qat=use_qat, train_fn=train_fn, use_new_flow_neutron_c=use_new_flow_neutron_c, + operators_not_to_delegate=operators_not_to_delegate, + remove_quant_io_ops=remove_quant_io_ops, ) except RuntimeError as e: if "Model converted with neutron-converter has" in str(e) and hasattr( @@ -139,9 +143,30 @@ def wrapper(*args, **kwargs): dlg_model_verifier.verify_graph(exported_program.graph) save_pte_program(delegated_program, test_name + "_delegated", test_dir) + + # Preparation of quantized dataset, requires quantization parameters from converted delegated model + if remove_quant_io_ops: + dataset_dir_quant = os.path.join(test_dir, "dataset_quant") + input_quant_spec = _parse_input_quant_params(input_spec, delegated_program) + create_quantized_variant_of_dataset( + testing_dataset_dir, dataset_dir_quant, input_quant_spec, input_spec + ) + testing_dataset_dir = dataset_dir_quant + + dataset_cli, dataset_or_inputs = _get_dataset_cli_args( + input_spec, testing_dataset_dir + ) + + # Run nxp_executor_runner with program delegated to NPU + delegated_model_path = os.path.abspath( + os.path.join(test_dir, f"{test_name}_delegated.pte") + ) + + delegated_cmd = f"{NEUTRON_TEST_PATH} --model {delegated_model_path} {dataset_cli} {dataset_or_inputs} \ + --output {npu_results_dir} --firmware {NSYS_FIRMWARE_PATH} --nsys {NSYS_PATH} --nsys_config {NSYS_CONFIG_PATH}" execute_cmd(delegated_cmd) - return exported_program + return exported_program, testing_dataset_dir def _run_non_delegated_executorch_program( @@ -154,31 +179,12 @@ def _run_non_delegated_executorch_program( cpu_results_dir, use_qat: bool = False, train_fn: Callable[[torch.fx.GraphModule], None] | None = None, + remove_quant_io_ops: bool = False, ) -> ExportedProgram: - if len(input_spec) == 1: - # Single input, use --dataset - dataset_cli = "--dataset" - dataset_or_inputs = testing_dataset_dir - else: - # Multiple input, use --inputs with subdirectories - dataset_cli = "--inputs" - dataset_or_inputs = ",".join( - sorted( - [ - os.path.join(testing_dataset_dir, d) - for d in os.listdir(testing_dataset_dir) - ] - ) - ) - - # Run program via nxp_executor_runner on CPU - non_delegated_model_path = os.path.abspath( - os.path.join(test_dir, f"{test_name}_non_delegated.pte") + dataset_cli, dataset_or_inputs = _get_dataset_cli_args( + input_spec, testing_dataset_dir ) - non_delegated_cmd = f"{NEUTRON_TEST_PATH} --model {non_delegated_model_path} {dataset_cli} {dataset_or_inputs} \ - --output {cpu_results_dir} --firmware {NSYS_FIRMWARE_PATH} --nsys {NSYS_PATH} --nsys_config {NSYS_CONFIG_PATH}" - non_delegated_program = to_quantized_executorch_program( model, input_spec, @@ -186,6 +192,7 @@ def _run_non_delegated_executorch_program( delegate_to_npu=False, use_qat=use_qat, train_fn=train_fn, + remove_quant_io_ops=remove_quant_io_ops, ) nodes = list(non_delegated_program.exported_program().graph.nodes) @@ -194,6 +201,14 @@ def _run_non_delegated_executorch_program( ), "Delegated parts found in program executed on CPU!" save_pte_program(non_delegated_program, test_name + "_non_delegated", test_dir) + + # Run program via nxp_executor_runner on CPU + non_delegated_model_path = os.path.abspath( + os.path.join(test_dir, f"{test_name}_non_delegated.pte") + ) + + non_delegated_cmd = f"{NEUTRON_TEST_PATH} --model {non_delegated_model_path} {dataset_cli} {dataset_or_inputs} \ + --output {cpu_results_dir} --firmware {NSYS_FIRMWARE_PATH} --nsys {NSYS_PATH} --nsys_config {NSYS_CONFIG_PATH}" execute_cmd(non_delegated_cmd) return non_delegated_program.exported_program() @@ -229,9 +244,9 @@ def read_prepared_samples( bin_file_path = os.path.join( sample_dir, f"{str(spec_idx).zfill(2)}.bin" ) - sample_vector = np.fromfile(bin_file_path, dtype=spec.type).reshape( - spec.shape - ) + sample_vector = np.fromfile( + bin_file_path, dtype=torch_type_to_numpy_type(spec.dtype) + ).reshape(spec.shape) current_samples.append(sample_vector) all_samples.append(tuple(current_samples)) @@ -385,6 +400,8 @@ def lower_run_compare( use_qat: bool = False, train_fn: Callable[[torch.fx.GraphModule], None] | None = None, use_new_flow_neutron_c: bool = False, + operators_not_to_delegate: list[str] = None, + remove_quant_io_ops: bool = False, ): """ Run provided program twice with neutron-test and check if results correspond. At first, @@ -402,6 +419,10 @@ def lower_run_compare( :param use_qat: If True, applies quantization-aware training before conversion (without the QAT training). :param train_fn: Train/finetune function for QAT training. Is used only when `use_qat=True`. :param use_new_flow_neutron_c: Enable experimental MLIR-based flow for Neutron-C with improved INT8 operator support. + :param operators_not_to_delegate: list of operators not to delegate. + :param remove_quant_io_ops: If true, IO q-ops are removed and verification is done on quantized + version of dataset (quantized INT8 input samples). + """ assert_NSYS() @@ -430,7 +451,7 @@ def lower_run_compare( cpu_results_dir = os.path.join(test_dir, "results_cpu") npu_results_dir = os.path.join(test_dir, "results_npu") - delegated_program = _run_delegated_executorch_program( + delegated_program, testing_dataset_dir = _run_delegated_executorch_program( model_to_delegate, test_dir, test_name, @@ -443,6 +464,8 @@ def lower_run_compare( use_qat=use_qat, train_fn=train_fn, use_new_flow_neutron_c=use_new_flow_neutron_c, + operators_not_to_delegate=operators_not_to_delegate, + remove_quant_io_ops=remove_quant_io_ops, ) output_spec = _get_program_output_spec(delegated_program) @@ -461,6 +484,7 @@ def lower_run_compare( cpu_results_dir, use_qat=use_qat, train_fn=train_fn, + remove_quant_io_ops=remove_quant_io_ops, ) case ReferenceModel.QUANTIZED_EDGE_PYTHON: @@ -475,10 +499,19 @@ def lower_run_compare( delegate_to_npu=False, use_qat=use_qat, train_fn=train_fn, + remove_quant_io_ops=remove_quant_io_ops, ) .exported_program() .module() ) + # Switch input spec dtype to quantized int8 if run with remove_quant_io_ops flag + # The input spec has to still have float32 dtype during edge program lowering to correctly calibrate the + # model. When running in Python, the testing data are loaded from numpy tensors according to input spec. + # There the testing data are in quantized int8 dtype. + if remove_quant_io_ops: + for spec in input_spec: + spec.dtype = torch.int8 + _run_python_program( non_delegated_edge_program, testing_dataset_dir, @@ -489,6 +522,12 @@ def lower_run_compare( ) case ReferenceModel.FLOAT_PYTORCH_PYTHON: + if remove_quant_io_ops: + raise ValueError( + "Flag remove_quant_io_ops is not applicable to FLOAT_PYTORCH_PYTHON reference model" + "as it works with float data only. Run with remove_quant_io_ops=False." + ) + # Run the PyTorch nn.Module directly in Python. _run_python_program( model_to_not_delegate, @@ -561,7 +600,7 @@ def lower_run_compare_ptq_qat( ptq_results_dir = os.path.join(test_dir, "results_ptq") qat_results_dir = os.path.join(test_dir, "results_qat") - delegated_program_ptq = _run_delegated_executorch_program( + delegated_program_ptq, _ = _run_delegated_executorch_program( model_ptq, test_dir, test_name, @@ -597,12 +636,39 @@ def lower_run_compare_ptq_qat( ) +def _parse_input_quant_params( + input_spec: tuple[ModelInputSpec, ...], exported_program_manager +) -> list[InputQuantizationSpec]: + """ + Parse input quantization params from provided exported program manager. + + :param input_spec: Model inputs specification. + :param exported_program_manager: Exported program manager of parsed model. + :return: List of input quantization specification. + """ + if (config_methods := exported_program_manager._config_methods) is None: + raise ValueError("Attempt to parse q-params for not fully quantized model") + + q_params = [] + + for idx in range(len(input_spec)): + input_name = f"input{idx}" + scale = config_methods[f"{input_name}_scale"] + zp = config_methods[f"{input_name}_zp"] + dtype = config_methods[f"{input_name}_dtype"] + + q_params.append(InputQuantizationSpec(input_name, scale, zp, dtype)) + + return q_params + + def _get_caller_name(): test_function_names = ["lower_run_compare", "lower_run_compare_ptq_qat"] for idx, frame in enumerate(inspect.stack()): if frame.function in test_function_names: # Look one index above to get caller return inspect.stack()[idx + 1].function + return None def execute_cmd(cmd, cwd="."): diff --git a/backends/nxp/tests/ops_aliases.py b/backends/nxp/tests/ops_aliases.py index ec58072658d..78a2ac10f55 100644 --- a/backends/nxp/tests/ops_aliases.py +++ b/backends/nxp/tests/ops_aliases.py @@ -13,6 +13,7 @@ Abs = exir_ops.edge.aten.abs.default AdaptiveAvgPool2D = exir_ops.edge.aten._adaptive_avg_pool2d.default +AddTensor = exir_ops.edge.aten.add.Tensor AvgPool2D = exir_ops.edge.aten.avg_pool2d.default Bmm = exir_ops.edge.aten.bmm.default ConstantPadND = exir_ops.edge.aten.constant_pad_nd.default @@ -25,6 +26,7 @@ HardTanh_ = exir_ops.edge.aten.hardtanh_.default LeakyRelu = exir_ops.edge.aten.leaky_relu.default MaxPool2DWithIndices = exir_ops.edge.aten.max_pool2d_with_indices.default +MeanDim = exir_ops.edge.aten.mean.dim MulTensor = exir_ops.edge.aten.mul.Tensor QuantizePerChannel = exir_ops.edge.quantized_decomposed.quantize_per_channel.default QuantizePerTensor = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default @@ -36,6 +38,9 @@ Squeeze = exir_ops.edge.aten.squeeze.default SqueezeDim = exir_ops.edge.aten.squeeze.dim SqueezeDims = exir_ops.edge.aten.squeeze.dims +SubTensor = exir_ops.edge.aten.sub.Tensor +Tanh = exir_ops.edge.aten.tanh.default +Tanh_ = exir_ops.edge.aten.tanh_.default Unsqueeze = exir_ops.edge.aten.unsqueeze.default UpsampleBilinear2D = exir_ops.edge.aten.upsample_bilinear2d.vec UpsampleNearest2D = exir_ops.edge.aten.upsample_nearest2d.vec diff --git a/backends/qualcomm/_passes/build_quant_io.py b/backends/qualcomm/_passes/build_quant_io.py index d43842e84a5..057dcc0f864 100644 --- a/backends/qualcomm/_passes/build_quant_io.py +++ b/backends/qualcomm/_passes/build_quant_io.py @@ -5,11 +5,10 @@ # LICENSE file in the root directory of this source tree. import torch from executorch.backends.qualcomm.utils.constants import QCOM_QUANTIZED_IO -from executorch.exir.delegate import executorch_call_delegate -from executorch.exir.pass_base import ExportPass, ProxyValue +from executorch.exir.delegate import executorch_call_delegate +from executorch.exir.pass_base import ExportPass, PassResult from executorch.exir.tensor import TensorSpec -from torch.utils import _pytree as pytree class BuildQuantIo(ExportPass): @@ -28,22 +27,27 @@ def _make_spec(self, x): else: return None - def placeholder(self, name: str, arg, meta): - if quantized_dtype := meta.data.get(QCOM_QUANTIZED_IO, None): - arg = arg.to(dtype=quantized_dtype) - meta["spec"] = self._make_spec(arg) - return super().placeholder(name, arg, meta) - - def call_getitem(self, value, key: int, meta): - meta["spec"] = value.node.meta["spec"][key] - return super().call_getitem(value, key, meta) - - def call_delegate(self, lowered_module, args, kwargs, meta): - args_data, _ = pytree.tree_map_only( - ProxyValue, lambda x: x.data, (args, kwargs) - ) - meta["spec"] = pytree.tree_map( - self._make_spec, - executorch_call_delegate(lowered_module, *args_data), - ) - return super().call_delegate(lowered_module, args, kwargs, meta) + def _build(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule: + # Forcedly update delegate node's meta['spec'] to get correct output + # tensor size in runtime + call_delegates = [ + node + for node in graph_module.graph.nodes + if node.op == "call_function" and node.target == executorch_call_delegate + ] + for n in graph_module.graph.nodes: + if QCOM_QUANTIZED_IO in n.meta: + n.meta["val"] = n.meta["val"].to(dtype=n.meta[QCOM_QUANTIZED_IO]) + n.meta["spec"] = self._make_spec(n.meta["val"]) + + for call_delegate in call_delegates: + spec = [] + for user in list(call_delegate.users): + spec.append(self._make_spec(user.meta["val"])) + call_delegate.meta["spec"] = tuple(spec) + + def call(self, graph_module: torch.fx.GraphModule): + self._build(graph_module) + graph_module.graph.eliminate_dead_code() + graph_module.recompile() + return PassResult(graph_module, True) diff --git a/backends/qualcomm/_passes/recompose_pad_maxpool2d.py b/backends/qualcomm/_passes/recompose_pad_maxpool2d.py index 81b4836f251..6a8374cb66a 100644 --- a/backends/qualcomm/_passes/recompose_pad_maxpool2d.py +++ b/backends/qualcomm/_passes/recompose_pad_maxpool2d.py @@ -13,12 +13,8 @@ from executorch.exir.pass_base import ExportPass, PassResult from executorch.exir.passes import dead_code_elimination_pass -from torch._subclasses.fake_tensor import FakeTensorMode - - -def add_fake_tensor_to_node(padding_node, input_shape, padding_args, dtype): - fake_mode = FakeTensorMode() +def add_fake_tensor_to_node(padding_node, input_shape, padding_args, dtype, fake_mode): with fake_mode: batch, channels, height, width = input_shape pad_left, pad_right, pad_top, pad_bottom = padding_args @@ -114,6 +110,7 @@ def call(self, graph_module: torch.fx.GraphModule): # noqa C901 input_node.meta["val"].shape, padding, input_node.meta["val"].dtype, + input_node.meta["val"].fake_mode, ) if quant_attrs: padding_node.meta["quant_attrs"] = node.meta["quant_attrs"] diff --git a/backends/qualcomm/_passes/utils.py b/backends/qualcomm/_passes/utils.py index 542fa1115a6..91a7cfdc69a 100755 --- a/backends/qualcomm/_passes/utils.py +++ b/backends/qualcomm/_passes/utils.py @@ -137,7 +137,23 @@ def copy_nn_module_stack(src, target): target.meta["nn_module_stack"] = value -def merge_decomposed_graph( +def _unify_fake_mode(node: torch.fx.Node, fake_mode) -> None: + val = node.meta.get("val") + if val is None: + return + if isinstance(val, FakeTensor) and val.fake_mode is not fake_mode: + node.meta["val"] = fake_mode.from_tensor(val) + elif isinstance(val, (list, tuple)): + unified = [] + for v in val: + if isinstance(v, FakeTensor) and v.fake_mode is not fake_mode: + unified.append(fake_mode.from_tensor(v)) + else: + unified.append(v) + node.meta["val"] = type(val)(unified) + + +def merge_decomposed_graph( # noqa: C901 remap: Dict[str, torch.fx.Node], target_node: torch.fx.Node, target_graph: torch.fx.GraphModule, @@ -148,6 +164,16 @@ def merge_decomposed_graph( [torch.fx.Node, torch.fx.Node, Dict[str, torch.fx.Node]], None ] = None, ) -> None: + target_fake_mode = None + target_val = target_node.meta.get("val") + if isinstance(target_val, FakeTensor): + target_fake_mode = target_val.fake_mode + elif isinstance(target_val, (list, tuple)): + for v in target_val: + if isinstance(v, FakeTensor): + target_fake_mode = v.fake_mode + break + def default_output_process(node): for user in node.users.copy(): # remap @@ -170,10 +196,13 @@ def default_output_process(node): # replace node map from string to graph node remap[decomposed_node] = remap.pop(decomposed_node.name) else: - remap[decomposed_node] = target_graph.node_copy( + copied = target_graph.node_copy( decomposed_node, arg_transform=lambda x, remap=remap: remap[x], ) + if target_fake_mode is not None: + _unify_fake_mode(copied, target_fake_mode) + remap[decomposed_node] = copied def is_float_tensor(node: torch.fx.Node) -> bool: diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index 6d5b44d7a35..08f5c1f67de 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -7730,8 +7730,9 @@ def test_llama_stories_110m(self): "--max_context_len", "128", ] + if self.use_fp16: + cmds.append("--use_fp16") self.add_default_cmds(cmds) - golden_start_with = "Once upon a time," p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL) with Listener((self.ip, self.port)) as listener: @@ -7750,7 +7751,10 @@ def test_llama_stories_110m(self): # x86 does not allow weight sharing, so we don't check pte size if not self.enable_x86_64: pte_size = msg["pte_size"] - self.assertLessEqual(pte_size, 135_000_000) # 135MB + if self.use_fp16: + self.assertLessEqual(pte_size, 275_000_000) # 275MB + else: + self.assertLessEqual(pte_size, 135_000_000) # 135MB if not self.compile_only and not self.enable_x86_64: self.assertGreaterEqual(msg["inference_speed"], 220) # Lanai @@ -10087,6 +10091,13 @@ def setup_environment(): choices=["wikitext_ppl", "hellaswag_acc_norm", "sqnr"], type=str, ) + parser.add_argument( + "-F", + "--use_fp16", + help="If specified, will run in fp16 precision and discard ptq setting", + action="store_true", + default=False, + ) args, ns_args = parser.parse_known_args(namespace=unittest) TestQNN.host = args.host @@ -10114,6 +10125,7 @@ def setup_environment(): TestQNN.backend = args.backend TestQNN.static_llm_eval_method = args.static_llm_eval_method TestQNN.direct_build_folder = args.direct_build_folder + TestQNN.use_fp16 = args.use_fp16 return sys.argv[:1] + ns_args diff --git a/backends/qualcomm/tests/utils.py b/backends/qualcomm/tests/utils.py index d8802f74e68..c22ee8371e0 100644 --- a/backends/qualcomm/tests/utils.py +++ b/backends/qualcomm/tests/utils.py @@ -221,6 +221,7 @@ class TestQNN(unittest.TestCase): static_llm_eval_method = "" direct_build_folder: str = "" dsp_heap_profile_filename = "htp_heap_usage.txt" + use_fp16 = False @classmethod def setUpClass(cls): diff --git a/backends/transforms/aten_to_dialect_pass.py b/backends/transforms/aten_to_dialect_pass.py new file mode 100644 index 00000000000..f31df73bc58 --- /dev/null +++ b/backends/transforms/aten_to_dialect_pass.py @@ -0,0 +1,138 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import traceback +from collections.abc import Callable +from dataclasses import dataclass +from typing import ClassVar, TypeAlias + +import torch + +from executorch.backends.xnnpack._passes.xnnpack_pass import ExportPass + +from executorch.exir import ExportedProgram +from torch.fx.node import Target +from torch.fx.passes.infra.pass_manager import PassResult + + +# Expected type to be returned by substitution functions. +@dataclass +class DialectNodeSpec: + op: Target + args: tuple + kwargs: dict = None + + +# Expected type to be used for substitution functions +SubstitutionFn: TypeAlias = Callable[ + [torch.fx.Node, torch.export.ExportedProgram], DialectNodeSpec | None +] + + +class AtenToDialectPass(ExportPass): + """ + General pass to convert ops 1-1 from ATen to a specific dialect. + + Usage: + 1. Subclass the pass for a specific dialect + 2. For each ATen target to be substituted, implement a function returning a DialectNodeSpec defining the + corresponding dialect op, or None if the substitution does not apply. + 3. Register each substitution function for the subclass using the decorator register_dialect_substitution + + Only one substitution function can be registered for a given target. + + The pass must be initialized with an exported_program to allow substitution functions to modify placeholders, + e.g. if the dialect ops require additional scratch buffers. + """ + + _DIALECT_SUBSTITUTIONS: ClassVar[dict[Target, SubstitutionFn]] = {} + + def __init__(self, exported_program: ExportedProgram): + super().__init__() + self.exported_program: ExportedProgram = exported_program + + # Ensure each subclass has its own substitution registry. + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + cls._DIALECT_SUBSTITUTIONS = {} + + @classmethod + def register_dialect_substitution( + cls, target: Target + ) -> Callable[[SubstitutionFn], SubstitutionFn]: + + def decorator(func: SubstitutionFn) -> SubstitutionFn: + if target in cls._DIALECT_SUBSTITUTIONS: + raise RuntimeError( + f"Multiple substitutions registered for the same target in {cls.__name__} are not allowed." + ) + else: + cls._DIALECT_SUBSTITUTIONS[target] = func + return func + + return decorator + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + modified = False + + for node in graph_module.graph.nodes: + if node.op != "call_function": + continue + + substitution_func = self._DIALECT_SUBSTITUTIONS.get(node.target, None) + if substitution_func is None: + continue + + dialect_node_spec = substitution_func(node, self.exported_program) + if dialect_node_spec is None: + continue + + modified = True + with graph_module.graph.inserting_before(node): + dialect_node = graph_module.graph.create_node( + "call_function", + target=dialect_node_spec.op, + args=dialect_node_spec.args, + kwargs=dialect_node_spec.kwargs or {}, + ) + + node.replace_all_uses_with(dialect_node) + + # Keep same meta dict for new node and append new trace + dialect_node.meta = node.meta + old_stack_trace = dialect_node.meta.get("stack_trace", "") + dialect_node.meta["stack_trace"] = ( + f"{old_stack_trace}\n{traceback.format_stack()[-2]}" + ) + + graph_module.graph.erase_node(node) + + if modified: + graph_module.graph.eliminate_dead_code() + graph_module.recompile() + graph_module = super().call(graph_module).graph_module + + return PassResult(graph_module, modified) + + def requires(self, graph_module): + self.ops_before = sum( + 1 for node in graph_module.graph.nodes if node.op == "call_function" + ) + return super().requires(graph_module) + + def ensures(self, graph_module: torch.fx.GraphModule) -> bool: + """Ensure that there has only been 1-1 substitution of call_function nodes, i.e. that the number of call_function nodes is preserved after the pass.""" + + self.ops_after = sum( + 1 for node in graph_module.graph.nodes if node.op == "call_function" + ) + if self.ops_after != self.ops_before: + raise RuntimeError( + f"{self.__class__.__name__} did not preserve the number of call_function nodes: " + f"before={self.ops_before}, after={self.ops_after}" + ) + + return super().ensures(graph_module) diff --git a/backends/transforms/targets.bzl b/backends/transforms/targets.bzl index 8c3603e293d..36466ec4aa0 100644 --- a/backends/transforms/targets.bzl +++ b/backends/transforms/targets.bzl @@ -176,6 +176,21 @@ def define_common_targets(): ], ) + runtime.python_library( + name = "aten_to_dialect_pass", + srcs = [ + "aten_to_dialect_pass.py", + ], + visibility = [ + "//executorch/backends/...", + ], + deps = [ + "//caffe2:torch", + "//executorch/backends/xnnpack/_passes:xnnpack_passes", + "//executorch/exir:lib", + ], + ) + runtime.python_library( name = "rank_0_to_rank_1", srcs = [ @@ -243,6 +258,16 @@ def define_common_targets(): ], ) + runtime.python_test( + name = "test_aten_to_dialect_pass", + srcs = [ + "test/test_aten_to_dialect_pass.py", + ], + deps = [ + "//caffe2:torch", + ":aten_to_dialect_pass", + ], + ) runtime.python_test( name = "test_rank_0_to_rank_1", diff --git a/backends/transforms/test/test_aten_to_dialect_pass.py b/backends/transforms/test/test_aten_to_dialect_pass.py new file mode 100644 index 00000000000..80dbf210d72 --- /dev/null +++ b/backends/transforms/test/test_aten_to_dialect_pass.py @@ -0,0 +1,239 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import pytest +import torch +from executorch.backends.transforms.aten_to_dialect_pass import ( + AtenToDialectPass, + DialectNodeSpec, +) +from executorch.backends.transforms.utils import create_constant_placeholder +from torch.export import ExportedProgram +from torch.export.graph_signature import InputKind +from torch.fx import Node + + +class AddModel(torch.nn.Module): + def forward(self, x, y): + return torch.ops.aten.add.Tensor(x, y) + + +class AddAlphaModel(torch.nn.Module): + def forward(self, x, y): + return torch.ops.aten.add.Tensor(x, y, alpha=2) + + +def _count_target(graph_module: torch.fx.GraphModule, target) -> int: + return sum( + 1 + for node in graph_module.graph.nodes + if node.op == "call_function" and node.target == target + ) + + +def _get_target_node(graph_module: torch.fx.GraphModule, target) -> Node: + nodes = [ + node + for node in graph_module.graph.nodes + if node.op == "call_function" and node.target == target + ] + assert len(nodes) == 1 + return nodes[0] + + +def _export_add_model() -> ExportedProgram: + return torch.export.export( + AddModel().eval(), (torch.randn(2, 3), torch.randn(2, 3)), strict=True + ) + + +def _export_add_alpha_model() -> ExportedProgram: + return torch.export.export( + AddAlphaModel().eval(), (torch.randn(2, 3), torch.randn(2, 3)), strict=True + ) + + +def test_rewrites_node_when_substitution_matches() -> None: + class _TestAtenToDialectPass(AtenToDialectPass): + pass + + @_TestAtenToDialectPass.register_dialect_substitution(torch.ops.aten.add.Tensor) + def replace_add_with_sub( + node: Node, exported_program: ExportedProgram + ) -> DialectNodeSpec | None: + del exported_program + return DialectNodeSpec(torch.ops.aten.sub.Tensor, node.args) + + exported_program = _export_add_model() + result = _TestAtenToDialectPass(exported_program=exported_program).call( + exported_program.graph_module + ) + + assert result.modified + assert _count_target(result.graph_module, torch.ops.aten.add.Tensor) == 0 + assert _count_target(result.graph_module, torch.ops.aten.sub.Tensor) == 1 + + +def test_substitution_can_add_state_dict_placeholder() -> None: + class _TestAtenToDialectPass(AtenToDialectPass): + pass + + @_TestAtenToDialectPass.register_dialect_substitution(torch.ops.aten.add.Tensor) + def replace_add_rhs_with_constant( + node: Node, exported_program: ExportedProgram + ) -> DialectNodeSpec | None: + first_placeholder = next( + graph_node + for graph_node in node.graph.nodes + if graph_node.op == "placeholder" + ) + with node.graph.inserting_before(first_placeholder): + const_node = create_constant_placeholder( + exp_program=exported_program, + graph=node.graph, + name="test_constant", + kind=InputKind.PARAMETER, + data=torch.ones(2, 3), + ) + return DialectNodeSpec(torch.ops.aten.add.Tensor, (node.args[0], const_node)) + + exported_program = _export_add_model() + result = _TestAtenToDialectPass(exported_program=exported_program).call( + exported_program.graph_module + ) + + assert result.modified + assert "test_constant" in exported_program.state_dict + assert torch.equal(exported_program.state_dict["test_constant"], torch.ones(2, 3)) + assert ( + exported_program.graph_signature.inputs_to_parameters["test_constant"] + == "test_constant" + ) + add_node = _get_target_node(result.graph_module, torch.ops.aten.add.Tensor) + assert add_node.args[1].name == "test_constant" + + x = torch.full((2, 3), 2.0) + y = torch.full((2, 3), 5.0) + torch.testing.assert_close(exported_program.module()(x, y), x + torch.ones_like(x)) + + +def test_substitution_can_change_kwargs() -> None: + class _TestAtenToDialectPass(AtenToDialectPass): + pass + + @_TestAtenToDialectPass.register_dialect_substitution(torch.ops.aten.add.Tensor) + def replace_add_alpha( + node: Node, exported_program: ExportedProgram + ) -> DialectNodeSpec | None: + del exported_program + return DialectNodeSpec(torch.ops.aten.add.Tensor, node.args, {"alpha": 3}) + + exported_program = _export_add_alpha_model() + result = _TestAtenToDialectPass(exported_program=exported_program).call( + exported_program.graph_module + ) + + assert result.modified + add_node = _get_target_node(result.graph_module, torch.ops.aten.add.Tensor) + assert add_node.kwargs["alpha"] == 3 + + x = torch.full((2, 3), 2.0) + y = torch.full((2, 3), 5.0) + torch.testing.assert_close(exported_program.module()(x, y), x + 3 * y) + + +def test_preserves_meta_when_substitution_matches() -> None: + class _TestAtenToDialectPass(AtenToDialectPass): + pass + + @_TestAtenToDialectPass.register_dialect_substitution(torch.ops.aten.add.Tensor) + def replace_add_with_sub( + node: Node, exported_program: ExportedProgram + ) -> DialectNodeSpec | None: + del exported_program + return DialectNodeSpec(torch.ops.aten.sub.Tensor, node.args) + + exported_program = _export_add_model() + add_node = _get_target_node( + exported_program.graph_module, torch.ops.aten.add.Tensor + ) + add_node.meta["test_sentinel"] = "kept" + add_node.meta["stack_trace"] = "original stack" + + result = _TestAtenToDialectPass(exported_program=exported_program).call( + exported_program.graph_module + ) + + sub_node = _get_target_node(result.graph_module, torch.ops.aten.sub.Tensor) + assert sub_node.meta["test_sentinel"] == "kept" + assert sub_node.meta["stack_trace"].startswith("original stack\n") + assert sub_node.meta["stack_trace"] != "original stack" + + +def test_keeps_node_when_substitution_returns_none() -> None: + class _TestAtenToDialectPass(AtenToDialectPass): + pass + + @_TestAtenToDialectPass.register_dialect_substitution(torch.ops.aten.add.Tensor) + def do_not_replace( + node: Node, exported_program: ExportedProgram + ) -> DialectNodeSpec | None: + del node, exported_program + return None + + exported_program = _export_add_model() + result = _TestAtenToDialectPass(exported_program=exported_program).call( + exported_program.graph_module + ) + + assert not result.modified + assert _count_target(result.graph_module, torch.ops.aten.add.Tensor) == 1 + assert _count_target(result.graph_module, torch.ops.aten.sub.Tensor) == 0 + + +def test_raises_when_duplicate_substitution_is_registered() -> None: + class _TestAtenToDialectPass(AtenToDialectPass): + pass + + @_TestAtenToDialectPass.register_dialect_substitution(torch.ops.aten.add.Tensor) + def first_replace( + node: Node, exported_program: ExportedProgram + ) -> DialectNodeSpec | None: + del exported_program + return DialectNodeSpec(torch.ops.aten.sub.Tensor, node.args) + + with pytest.raises(RuntimeError, match="Multiple substitutions registered"): + + @_TestAtenToDialectPass.register_dialect_substitution(torch.ops.aten.add.Tensor) + def second_replace( + node: Node, exported_program: ExportedProgram + ) -> DialectNodeSpec | None: + del exported_program + return DialectNodeSpec(torch.ops.aten.mul.Tensor, node.args) + + +def test_ensures_raises_when_call_function_count_changes() -> None: + class _TestAtenToDialectPass(AtenToDialectPass): + pass + + exported_program = _export_add_model() + graph_module = exported_program.graph_module + test_pass = _TestAtenToDialectPass(exported_program=exported_program) + test_pass.requires(graph_module) + + placeholders = [ + node for node in graph_module.graph.nodes if node.op == "placeholder" + ] + output_node = next(node for node in graph_module.graph.nodes if node.op == "output") + with graph_module.graph.inserting_before(output_node): + graph_module.graph.create_node( + "call_function", + target=torch.ops.aten.sub.Tensor, + args=tuple(placeholders), + kwargs={}, + ) + + with pytest.raises(RuntimeError, match="did not preserve"): + test_pass.ensures(graph_module) diff --git a/backends/vulkan/test/op_tests/utils/gen_computegraph.py b/backends/vulkan/test/op_tests/utils/gen_computegraph.py index a09b4d36b18..507719b8555 100644 --- a/backends/vulkan/test/op_tests/utils/gen_computegraph.py +++ b/backends/vulkan/test/op_tests/utils/gen_computegraph.py @@ -286,7 +286,7 @@ def create_aten_fn_call(self) -> str: def create_aten_method_call(self) -> str: # For functions with only Method variant, we fallback to the function # declared in MethodOperators.h - cpp_sig = gen_static_dispatch_backend_call_signature(self.f_sig, self.f) + cpp_sig = gen_static_dispatch_backend_call_signature(self.f) exprs = translate_args(self.f_sig, cpp_sig) func_call = f"at::_ops::{self.f_sig.name()}::call({exprs});" return func_call diff --git a/backends/vulkan/test/test_vulkan_delegate.py b/backends/vulkan/test/test_vulkan_delegate.py index 7c9f31b720c..ff709618259 100644 --- a/backends/vulkan/test/test_vulkan_delegate.py +++ b/backends/vulkan/test/test_vulkan_delegate.py @@ -7,6 +7,7 @@ # pyre-unsafe import ctypes +import functools import unittest from typing import Tuple @@ -42,6 +43,24 @@ pass +def disable_test(reason): + """Disable a test while still reporting it as executed. + + Some test runners do not handle skipped results consistently, so this keeps + disabled tests visible in logs without using unittest.skip. + """ + + def decorator(fn): + @functools.wraps(fn) + def wrapper(*args, **kwargs): + print(f"DISABLED_TEST: {fn.__qualname__}: {reason}") + return None + + return wrapper + + return decorator + + def lower_module( model: torch.nn.Module, sample_inputs: Tuple[torch.Tensor], dynamic_shapes=None ) -> EdgeProgramManager: @@ -743,7 +762,7 @@ def forward(self, x): self.lower_module_and_test_output(model, sample_inputs) - @unittest.skip( + @disable_test( "Currently this test is failing due to weird partitioning because the eq scalar" "operator is not supported yet. Re-enable when the operator is supported." ) @@ -810,7 +829,7 @@ def forward(self, x): self.lower_module_and_test_output(module, sample_inputs) - @unittest.skip( + @disable_test( "Reduce shader does not support multiple reduction axes at the moment" ) def test_vulkan_backend_sum_dim_list(self): @@ -831,7 +850,7 @@ def forward(self, x): sample_inputs, ) - @unittest.skip( + @disable_test( "Reduce shader does not support multiple reduction axes at the moment" ) def test_vulkan_backend_sum(self): @@ -1028,7 +1047,7 @@ def forward(self, x): sample_inputs, ) - @unittest.skip("layer norm compute shader not working with swiftshader") + @disable_test("layer norm compute shader not working with swiftshader") def test_vulkan_backend_native_layer_norm(self): class NativeLayerNormModule(torch.nn.Module): def __init__(self): @@ -1459,7 +1478,7 @@ def forward(self, x): sample_inputs, ) - @unittest.skip( + @disable_test( "Softmax shader with shared memory does not work with swiftshader due to potential swiftshader bug" ) def test_vulkan_backend_softmax(self): @@ -1480,7 +1499,7 @@ def forward(self, x): sample_inputs, ) - @unittest.skip( + @disable_test( "Softmax shader with shared memory does not work with swiftshader due to potential swiftshader bug" ) def test_vulkan_backend_logsoftmax(self): @@ -1512,7 +1531,7 @@ def forward(self, x): self.lower_unary_module_and_test_output(GeluModule()) - @unittest.skip( + @disable_test( "Reduce shader does not support multiple reduction axes at the moment" ) def test_vulkan_backend_mean(self): @@ -2364,7 +2383,7 @@ def apply_quantization(self): quantized_linear_module_gemm, sample_inputs_gemm, atol=1e-2, rtol=1e-2 ) - @unittest.skip("Cannot run on swiftshader due to no integer dot product support") + @disable_test("Cannot run on swiftshader due to no integer dot product support") def test_vulkan_backend_xnnpack_pt2e_quantized_linear_sequence(self): """ Test a sequence of linear layers quantized with XNNPACK quantization config. @@ -2439,7 +2458,7 @@ def forward(self, x): rtol=1e-1, ) - @unittest.skip("Cannot run on swiftshader due to no integer dot product support") + @disable_test("Cannot run on swiftshader due to no integer dot product support") def test_vulkan_backend_xnnpack_pt2e_quantized_conv_sequence(self): """ Test a sequence of convolution layers quantized with PT2E quantization. @@ -2530,7 +2549,7 @@ def forward(self, x): rtol=1e-1, ) - @unittest.skip("Cannot run on swiftshader due to no integer dot product support") + @disable_test("Cannot run on swiftshader due to no integer dot product support") def test_vulkan_backend_xnnpack_pt2e_quantized_conv_sequence_all_reduced(self): """ Test a sequence of convolution layers quantized with PT2E quantization. @@ -2610,7 +2629,7 @@ def forward(self, x): rtol=1e-1, ) - @unittest.skip("Cannot run on swiftshader due to no 8-bit int support") + @disable_test("Cannot run on swiftshader due to no 8-bit int support") def test_vulkan_backend_torchao_8da4w_quantized_linear(self): """ Test TorchAO 8da4w quantization (int8 dynamic activation + int4 weight) with Vulkan backend. diff --git a/backends/webgpu/runtime/WebGPUGraph.cpp b/backends/webgpu/runtime/WebGPUGraph.cpp index f0e4c7959c0..91404fb164f 100644 --- a/backends/webgpu/runtime/WebGPUGraph.cpp +++ b/backends/webgpu/runtime/WebGPUGraph.cpp @@ -50,9 +50,15 @@ size_t vk_datatype_size(vkgraph::VkDataType dtype) { WebGPUGraph::WebGPUGraph() = default; WebGPUGraph::~WebGPUGraph() { - for (auto& t : tensors_) { - if (t.buffer) { - wgpuBufferRelease(t.buffer); + for (size_t i = 0; i < tensors_.size(); i++) { + if (tensors_[i].buffer && + (i >= tensor_mem_obj_ids_.size() || tensor_mem_obj_ids_[i] < 0)) { + wgpuBufferRelease(tensors_[i].buffer); + } + } + for (auto& buf : shared_buffers_) { + if (buf) { + wgpuBufferRelease(buf); } } for (auto& buf : output_staging_buffers_) { @@ -68,6 +74,21 @@ WebGPUGraph::~WebGPUGraph() { wgpuBindGroupRelease(d.bind_group); } } + for (auto& [_, shader] : shader_cache_) { + if (shader) { + wgpuShaderModuleRelease(shader); + } + } + for (auto& [_, pipeline] : pipeline_cache_) { + if (pipeline) { + wgpuComputePipelineRelease(pipeline); + } + } + for (auto& [_, bgl] : bgl_cache_) { + if (bgl) { + wgpuBindGroupLayoutRelease(bgl); + } + } } void WebGPUGraph::build( @@ -94,6 +115,7 @@ void WebGPUGraph::build( const int num_vals = values ? values->size() : 0; value_types_.resize(num_vals, ValueType::Null); tensors_.resize(num_vals); + tensor_mem_obj_ids_.resize(num_vals, -1); ints_.resize(num_vals, 0); doubles_.resize(num_vals, 0.0); bools_.resize(num_vals, false); @@ -121,27 +143,40 @@ void WebGPUGraph::build( } tensor.nbytes = numel * vk_datatype_size(vk_tensor->datatype()); - // Create GPU buffer - WGPUBufferDescriptor buf_desc = {}; - buf_desc.size = tensor.nbytes > 0 ? tensor.nbytes : 4; - buf_desc.usage = WGPUBufferUsage_Storage | WGPUBufferUsage_CopyDst | - WGPUBufferUsage_CopySrc; - buf_desc.mappedAtCreation = false; - tensor.buffer = wgpuDeviceCreateBuffer(device_, &buf_desc); - - // Upload constant data if this tensor has a constant_id int constant_id = vk_tensor->constant_id(); - if (constant_id >= 0 && constant_data) { - const auto* constants = graph->constants(); - if (constants && constant_id < static_cast(constants->size())) { - const auto* vk_bytes = constants->Get(constant_id); - // Only upload from embedded bytes (not named data map) - if (vk_bytes->offset() != UINT64_MAX) { - const uint8_t* src = constant_data + vk_bytes->offset(); - wgpuQueueWriteBuffer( - queue_, tensor.buffer, 0, src, tensor.nbytes); + int mem_obj_id = vk_tensor->mem_obj_id(); + + // Constants always get dedicated buffers regardless of mem_obj_id + if (constant_id >= 0 || mem_obj_id < 0) { + tensor_mem_obj_ids_[i] = -1; + WGPUBufferDescriptor buf_desc = {}; + buf_desc.size = std::max(tensor.nbytes, size_t(4)); + buf_desc.usage = WGPUBufferUsage_Storage | WGPUBufferUsage_CopyDst | + WGPUBufferUsage_CopySrc; + buf_desc.mappedAtCreation = false; + tensor.buffer = wgpuDeviceCreateBuffer(device_, &buf_desc); + + if (constant_id >= 0 && constant_data && tensor.nbytes > 0) { + const auto* constants = graph->constants(); + if (constants && + constant_id < static_cast(constants->size())) { + const auto* vk_bytes = constants->Get(constant_id); + if (vk_bytes->offset() != UINT64_MAX) { + const uint8_t* src = constant_data + vk_bytes->offset(); + wgpuQueueWriteBuffer( + queue_, tensor.buffer, 0, src, tensor.nbytes); + } } } + } else { + // Shared buffer: track required size, defer allocation to pass 2 + tensor_mem_obj_ids_[i] = mem_obj_id; + size_t id = static_cast(mem_obj_id); + if (id >= shared_buffer_sizes_.size()) { + shared_buffer_sizes_.resize(id + 1, 0); + } + shared_buffer_sizes_[id] = + std::max(shared_buffer_sizes_[id], tensor.nbytes); } break; } @@ -166,6 +201,23 @@ void WebGPUGraph::build( } } + // Allocate shared buffers and assign to tensors + shared_buffers_.resize(shared_buffer_sizes_.size(), nullptr); + for (size_t id = 0; id < shared_buffer_sizes_.size(); id++) { + WGPUBufferDescriptor buf_desc = {}; + buf_desc.size = std::max(shared_buffer_sizes_[id], size_t(4)); + buf_desc.usage = WGPUBufferUsage_Storage | WGPUBufferUsage_CopyDst | + WGPUBufferUsage_CopySrc; + buf_desc.mappedAtCreation = false; + shared_buffers_[id] = wgpuDeviceCreateBuffer(device_, &buf_desc); + } + for (int i = 0; i < num_vals; i++) { + int mid = tensor_mem_obj_ids_[i]; + if (mid >= 0) { + tensors_[i].buffer = shared_buffers_[mid]; + } + } + // Phase 2: Record input and output IDs const auto* fb_input_ids = graph->input_ids(); if (fb_input_ids) { @@ -181,7 +233,7 @@ void WebGPUGraph::build( // Create staging buffer for output readback WGPUBufferDescriptor staging_desc = {}; - staging_desc.size = tensors_[oid].nbytes > 0 ? tensors_[oid].nbytes : 4; + staging_desc.size = std::max(tensors_[oid].nbytes, size_t(4)); staging_desc.usage = WGPUBufferUsage_MapRead | WGPUBufferUsage_CopyDst; staging_desc.mappedAtCreation = false; output_staging_buffers_.push_back( @@ -189,6 +241,14 @@ void WebGPUGraph::build( } } + for (size_t i = 0; i < output_ids_.size(); i++) { + int oid = output_ids_[i]; + output_copies_.push_back( + {tensors_[oid].buffer, + output_staging_buffers_[i], + tensors_[oid].nbytes}); + } + // Phase 3: Build operator dispatch chain const auto* chain = graph->chain(); if (chain) { @@ -213,9 +273,70 @@ void WebGPUGraph::build( } } +WGPUShaderModule WebGPUGraph::get_or_create_shader( + const std::string& key, + const char* wgsl_source) { + auto it = shader_cache_.find(key); + if (it != shader_cache_.end()) { + return it->second; + } + + WGPUShaderSourceWGSL wgsl_desc = {}; + wgsl_desc.chain.sType = WGPUSType_ShaderSourceWGSL; + wgsl_desc.code = {wgsl_source, WGPU_STRLEN}; + + WGPUShaderModuleDescriptor shader_desc = {}; + shader_desc.nextInChain = &wgsl_desc.chain; + WGPUShaderModule shader = wgpuDeviceCreateShaderModule(device_, &shader_desc); + + shader_cache_[key] = shader; + return shader; +} + +WGPUComputePipeline WebGPUGraph::get_or_create_pipeline( + const std::string& key, + WGPUShaderModule shader, + WGPUPipelineLayout layout) { + auto it = pipeline_cache_.find(key); + if (it != pipeline_cache_.end()) { + return it->second; + } + + WGPUComputePipelineDescriptor pipeline_desc = {}; + pipeline_desc.layout = layout; + pipeline_desc.compute.module = shader; + pipeline_desc.compute.entryPoint = {"main", WGPU_STRLEN}; + WGPUComputePipeline pipeline = + wgpuDeviceCreateComputePipeline(device_, &pipeline_desc); + + pipeline_cache_[key] = pipeline; + return pipeline; +} + +WGPUBindGroupLayout WebGPUGraph::get_or_create_bgl( + const std::string& key, + const WGPUBindGroupLayoutEntry* entries, + uint32_t count) { + auto it = bgl_cache_.find(key); + if (it != bgl_cache_.end()) { + return it->second; + } + + WGPUBindGroupLayoutDescriptor bgl_desc = {}; + bgl_desc.entryCount = count; + bgl_desc.entries = entries; + WGPUBindGroupLayout bgl = wgpuDeviceCreateBindGroupLayout(device_, &bgl_desc); + + bgl_cache_[key] = bgl; + return bgl; +} + void WebGPUGraph::copy_inputs( const std::vector>& inputs) { for (size_t i = 0; i < inputs.size() && i < input_ids_.size(); i++) { + if (inputs[i].second == 0) { + continue; + } int tid = input_ids_[i]; const auto& tensor = tensors_[tid]; wgpuQueueWriteBuffer( @@ -224,43 +345,89 @@ void WebGPUGraph::copy_inputs( } void WebGPUGraph::execute() { - WGPUCommandEncoderDescriptor enc_desc = {}; - WGPUCommandEncoder encoder = - wgpuDeviceCreateCommandEncoder(device_, &enc_desc); - - WGPUComputePassDescriptor pass_desc = {}; - WGPUComputePassEncoder pass = - wgpuCommandEncoderBeginComputePass(encoder, &pass_desc); - - for (const auto& dispatch : dispatches_) { - wgpuComputePassEncoderSetPipeline(pass, dispatch.pipeline); - wgpuComputePassEncoderSetBindGroup( - pass, 0, dispatch.bind_group, 0, nullptr); - wgpuComputePassEncoderDispatchWorkgroups( - pass, dispatch.workgroup_count_x, 1, 1); - } + const size_t n = dispatches_.size(); + const size_t chunk = execute_config_.chunk_size; + + if (chunk == 0 || n <= chunk) { + WGPUCommandEncoderDescriptor enc_desc = {}; + WGPUCommandEncoder encoder = + wgpuDeviceCreateCommandEncoder(device_, &enc_desc); + + WGPUComputePassDescriptor pass_desc = {}; + WGPUComputePassEncoder pass = + wgpuCommandEncoderBeginComputePass(encoder, &pass_desc); + + for (const auto& dispatch : dispatches_) { + wgpuComputePassEncoderSetPipeline(pass, dispatch.pipeline); + wgpuComputePassEncoderSetBindGroup( + pass, 0, dispatch.bind_group, 0, nullptr); + wgpuComputePassEncoderDispatchWorkgroups( + pass, dispatch.workgroup_count_x, 1, 1); + } - wgpuComputePassEncoderEnd(pass); - wgpuComputePassEncoderRelease(pass); + wgpuComputePassEncoderEnd(pass); + wgpuComputePassEncoderRelease(pass); - // Copy outputs to staging buffers - for (size_t i = 0; i < output_ids_.size(); i++) { - int oid = output_ids_[i]; - wgpuCommandEncoderCopyBufferToBuffer( - encoder, - tensors_[oid].buffer, - 0, - output_staging_buffers_[i], - 0, - tensors_[oid].nbytes); + for (const auto& copy : output_copies_) { + wgpuCommandEncoderCopyBufferToBuffer( + encoder, copy.src_buffer, 0, copy.staging_buffer, 0, copy.nbytes); + } + + WGPUCommandBufferDescriptor cmd_desc = {}; + WGPUCommandBuffer cmd = wgpuCommandEncoderFinish(encoder, &cmd_desc); + wgpuQueueSubmit(queue_, 1, &cmd); + + wgpuCommandBufferRelease(cmd); + wgpuCommandEncoderRelease(encoder); + return; } - WGPUCommandBufferDescriptor cmd_desc = {}; - WGPUCommandBuffer cmd = wgpuCommandEncoderFinish(encoder, &cmd_desc); - wgpuQueueSubmit(queue_, 1, &cmd); + const size_t first_chunk = execute_config_.initial_chunk_size > 0 + ? execute_config_.initial_chunk_size + : chunk; + + size_t start = 0; + size_t current_chunk = first_chunk; - wgpuCommandBufferRelease(cmd); - wgpuCommandEncoderRelease(encoder); + while (start < n) { + size_t end = std::min(start + current_chunk, n); + + WGPUCommandEncoderDescriptor enc_desc = {}; + WGPUCommandEncoder encoder = + wgpuDeviceCreateCommandEncoder(device_, &enc_desc); + + WGPUComputePassDescriptor pass_desc = {}; + WGPUComputePassEncoder pass = + wgpuCommandEncoderBeginComputePass(encoder, &pass_desc); + + for (size_t i = start; i < end; i++) { + wgpuComputePassEncoderSetPipeline(pass, dispatches_[i].pipeline); + wgpuComputePassEncoderSetBindGroup( + pass, 0, dispatches_[i].bind_group, 0, nullptr); + wgpuComputePassEncoderDispatchWorkgroups( + pass, dispatches_[i].workgroup_count_x, 1, 1); + } + + wgpuComputePassEncoderEnd(pass); + wgpuComputePassEncoderRelease(pass); + + if (end == n) { + for (const auto& copy : output_copies_) { + wgpuCommandEncoderCopyBufferToBuffer( + encoder, copy.src_buffer, 0, copy.staging_buffer, 0, copy.nbytes); + } + } + + WGPUCommandBufferDescriptor cmd_desc = {}; + WGPUCommandBuffer cmd = wgpuCommandEncoderFinish(encoder, &cmd_desc); + wgpuQueueSubmit(queue_, 1, &cmd); + + wgpuCommandBufferRelease(cmd); + wgpuCommandEncoderRelease(encoder); + + start = end; + current_chunk = chunk; + } } namespace { @@ -283,24 +450,35 @@ void buffer_map_callback( } // namespace void WebGPUGraph::copy_outputs(std::vector>& outputs) { - for (size_t i = 0; i < outputs.size() && i < output_staging_buffers_.size(); - i++) { - MapCallbackData cb_data; + const size_t count = std::min(outputs.size(), output_staging_buffers_.size()); + + std::vector cb_data(count); + + for (size_t i = 0; i < count; i++) { + if (outputs[i].second == 0) { + cb_data[i].done = true; + cb_data[i].status = WGPUMapAsyncStatus_Success; + continue; + } WGPUBufferMapCallbackInfo cb_info = {}; cb_info.mode = WGPUCallbackMode_AllowSpontaneous; cb_info.callback = buffer_map_callback; - cb_info.userdata1 = &cb_data; + cb_info.userdata1 = &cb_data[i]; wgpuBufferMapAsync( output_staging_buffers_[i], WGPUMapMode_Read, 0, outputs[i].second, cb_info); + } - // Poll until the map callback fires. - wgpuDevicePoll(device_, true, nullptr); + wgpuDevicePoll(device_, true, nullptr); - if (cb_data.status == WGPUMapAsyncStatus_Success) { + for (size_t i = 0; i < count; i++) { + if (outputs[i].second == 0) { + continue; + } + if (cb_data[i].status == WGPUMapAsyncStatus_Success) { const void* mapped = wgpuBufferGetConstMappedRange( output_staging_buffers_[i], 0, outputs[i].second); std::memcpy(outputs[i].first, mapped, outputs[i].second); @@ -315,15 +493,28 @@ WebGPUMemoryStats WebGPUGraph::memory_stats() const { WebGPUMemoryStats stats; for (size_t i = 0; i < value_types_.size(); i++) { if (value_types_[i] == ValueType::Tensor && tensors_[i].nbytes > 0) { - stats.tensor_buffer_bytes += tensors_[i].nbytes; stats.num_tensors++; + // Shared tensors are tracked via shared_buffer_sizes_ + bool is_shared = + i < tensor_mem_obj_ids_.size() && tensor_mem_obj_ids_[i] >= 0; + if (!is_shared) { + stats.unshared_tensor_buffer_bytes += tensors_[i].nbytes; + } } } + for (size_t s : shared_buffer_sizes_) { + stats.shared_buffer_bytes += s; + } + stats.num_shared_objects = static_cast(shared_buffers_.size()); + stats.tensor_buffer_bytes = + stats.shared_buffer_bytes + stats.unshared_tensor_buffer_bytes; for (size_t i = 0; i < output_ids_.size(); i++) { stats.staging_buffer_bytes += tensors_[output_ids_[i]].nbytes; } stats.uniform_buffer_bytes = uniform_buffer_bytes_; stats.num_dispatches = static_cast(dispatches_.size()); + stats.num_cached_pipelines = static_cast(pipeline_cache_.size()); + stats.num_cached_shaders = static_cast(shader_cache_.size()); return stats; } diff --git a/backends/webgpu/runtime/WebGPUGraph.h b/backends/webgpu/runtime/WebGPUGraph.h index 2d6996e9219..3aa96917a4e 100644 --- a/backends/webgpu/runtime/WebGPUGraph.h +++ b/backends/webgpu/runtime/WebGPUGraph.h @@ -12,6 +12,7 @@ #include #include +#include #include namespace executorch { @@ -30,12 +31,28 @@ struct WebGPUDispatch { uint32_t workgroup_count_x = 1; }; +struct OutputCopy { + WGPUBuffer src_buffer = nullptr; + WGPUBuffer staging_buffer = nullptr; + size_t nbytes = 0; +}; + +struct ExecuteConfig { + size_t chunk_size = 0; + size_t initial_chunk_size = 0; +}; + struct WebGPUMemoryStats { size_t tensor_buffer_bytes = 0; + size_t shared_buffer_bytes = 0; + int num_shared_objects = 0; + size_t unshared_tensor_buffer_bytes = 0; size_t staging_buffer_bytes = 0; size_t uniform_buffer_bytes = 0; int num_tensors = 0; int num_dispatches = 0; + int num_cached_pipelines = 0; + int num_cached_shaders = 0; size_t total_bytes() const { return tensor_buffer_bytes + staging_buffer_bytes + uniform_buffer_bytes; @@ -99,6 +116,20 @@ class WebGPUGraph { uniform_buffer_bytes_ += bytes; } + WGPUShaderModule get_or_create_shader( + const std::string& key, + const char* wgsl_source); + + WGPUComputePipeline get_or_create_pipeline( + const std::string& key, + WGPUShaderModule shader, + WGPUPipelineLayout layout); + + WGPUBindGroupLayout get_or_create_bgl( + const std::string& key, + const WGPUBindGroupLayoutEntry* entries, + uint32_t count); + void set_instance(WGPUInstance instance) { instance_ = instance; } @@ -134,11 +165,26 @@ class WebGPUGraph { std::vector input_ids_; std::vector output_ids_; + // Memory aliasing: tensors with the same mem_obj_id share a WGPUBuffer. + std::vector tensor_mem_obj_ids_; + std::vector shared_buffers_; + std::vector shared_buffer_sizes_; + // Staging buffers for reading back outputs (MapRead | CopyDst). std::vector output_staging_buffers_; + // Pre-computed output copy descriptors for execute(). + std::vector output_copies_; + std::vector dispatches_; + ExecuteConfig execute_config_; + + // Caches for reusing GPU objects across dispatches. + std::unordered_map shader_cache_; + std::unordered_map pipeline_cache_; + std::unordered_map bgl_cache_; + size_t uniform_buffer_bytes_ = 0; }; diff --git a/backends/webgpu/test/ops/add/test_add.py b/backends/webgpu/test/ops/add/test_add.py index f4b33ced76d..e8da644a1f9 100644 --- a/backends/webgpu/test/ops/add/test_add.py +++ b/backends/webgpu/test/ops/add/test_add.py @@ -31,6 +31,8 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: z = x + y z = z + x z = z + y + z = z + x + z = z + y return z @@ -97,5 +99,18 @@ def export_add_model(output_path: str) -> None: print(f"Exported {output_path}") +def export_chained_add_model(output_path: str) -> None: + """Export a chained add model (z=x+y; z=z+x; z=z+y; z=z+x; z=z+y) to .pte for memory aliasing testing.""" + model = AddChainedModule() + example_inputs = (torch.randn(1024, 1024), torch.randn(1024, 1024)) + ep = torch.export.export(model, example_inputs) + et_program = to_edge_transform_and_lower( + ep, partitioner=[VulkanPartitioner()] + ).to_executorch() + with open(output_path, "wb") as f: + f.write(et_program.buffer) + print(f"Exported {output_path}") + + if __name__ == "__main__": unittest.main() diff --git a/backends/webgpu/test/test_build_webgpu.sh b/backends/webgpu/test/test_build_webgpu.sh index 684926cb181..a42b2304ee7 100755 --- a/backends/webgpu/test/test_build_webgpu.sh +++ b/backends/webgpu/test/test_build_webgpu.sh @@ -22,12 +22,14 @@ $PYTHON_EXECUTABLE -m pytest "${SCRIPT_DIR}/ops/add/test_add.py" -v # ── Step 2: Export .pte model ───────────────────────────────────────────────── -echo "=== Step 2: Export test model ===" +echo "=== Step 2: Export test models ===" PTE_MODEL="/tmp/webgpu_add_test.pte" +PTE_CHAINED_MODEL="/tmp/webgpu_chained_add_test.pte" cd "${EXECUTORCH_ROOT}" $PYTHON_EXECUTABLE -c " -from executorch.backends.webgpu.test.ops.add.test_add import export_add_model +from executorch.backends.webgpu.test.ops.add.test_add import export_add_model, export_chained_add_model export_add_model('${PTE_MODEL}') +export_chained_add_model('${PTE_CHAINED_MODEL}') " # ── Step 3: Native build + test (wgpu-native) ──────────────────────────────── @@ -60,6 +62,7 @@ cmake --build "${NATIVE_BUILD_DIR}" --target webgpu_native_test -j${NPROC} echo "=== Step 4: Run native test ===" WEBGPU_TEST_MODEL="${PTE_MODEL}" \ +WEBGPU_TEST_CHAINED_MODEL="${PTE_CHAINED_MODEL}" \ "${NATIVE_BUILD_DIR}/backends/webgpu/webgpu_native_test" echo "=== Done ===" diff --git a/backends/webgpu/test/test_webgpu_native.cpp b/backends/webgpu/test/test_webgpu_native.cpp index c60695e11c9..d3005debf37 100644 --- a/backends/webgpu/test/test_webgpu_native.cpp +++ b/backends/webgpu/test/test_webgpu_native.cpp @@ -75,6 +75,62 @@ static bool test_single_add(const std::string& model_path) { return true; } +static bool test_chained_add(const std::string& model_path) { + printf("\n--- Test: chained add (1024x1024, 5 ops) ---\n"); + + Module module(model_path); + auto err = module.load_forward(); + if (err != Error::Ok) { + printf("FAIL: could not load forward method (error %d)\n", (int)err); + return false; + } + printf("Model loaded: %s\n", model_path.c_str()); + + constexpr int dim = 1024; + constexpr int size = dim * dim; + + std::vector x_data(size); + std::vector y_data(size); + for (int i = 0; i < size; i++) { + x_data[i] = static_cast(i % 100) * 0.01f; + y_data[i] = static_cast(i % 50) * 0.02f; + } + + auto x = make_tensor_ptr({dim, dim}, std::vector(x_data)); + auto y = make_tensor_ptr({dim, dim}, std::vector(y_data)); + + auto result = module.forward({EValue(x), EValue(y)}); + if (!result.ok()) { + printf("FAIL: forward failed (error %d)\n", (int)result.error()); + return false; + } + + const auto& outputs = result.get(); + if (outputs.empty() || !outputs[0].isTensor()) { + printf("FAIL: no tensor output\n"); + return false; + } + + // z=x+y; z=z+x=2x+y; z=z+y=2x+2y; z=z+x=3x+2y; z=z+y=3x+3y + const auto& out_tensor = outputs[0].toTensor(); + const float* out_data = out_tensor.const_data_ptr(); + + float max_error = 0.0f; + for (int i = 0; i < size; i++) { + float expected = 3.0f * x_data[i] + 3.0f * y_data[i]; + float error = std::abs(out_data[i] - expected); + max_error = std::max(max_error, error); + } + + printf("Max error: %e (checked %d elements)\n", max_error, size); + if (max_error > 1e-3f) { + printf("FAIL: max error exceeds tolerance 1e-3\n"); + return false; + } + printf("PASS: chained add test\n"); + return true; +} + int main(int argc, char** argv) { std::string model_path = "webgpu_add_test.pte"; if (argc > 1) { @@ -84,6 +140,11 @@ int main(int argc, char** argv) { model_path = env; } + std::string chained_model_path; + if (const char* env = std::getenv("WEBGPU_TEST_CHAINED_MODEL")) { + chained_model_path = env; + } + WebGPUContext ctx; try { ctx = create_webgpu_context(); @@ -97,6 +158,10 @@ int main(int argc, char** argv) { bool ok = test_single_add(model_path); + if (!chained_model_path.empty()) { + ok = test_chained_add(chained_model_path) && ok; + } + set_default_webgpu_context(nullptr); destroy_webgpu_context(ctx); diff --git a/backends/xnnpack/operators/op_clone.py b/backends/xnnpack/operators/op_clone.py index e4ddf187ecc..c36d750148c 100644 --- a/backends/xnnpack/operators/op_clone.py +++ b/backends/xnnpack/operators/op_clone.py @@ -1,5 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. +# Copyright 2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -13,6 +14,7 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.xnnpack.operators.quant_params import QuantParams from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import ( XNNCopy, XNNGraph, @@ -25,9 +27,6 @@ class CloneVisitor(NodeVisitor): target = "aten.clone.default" - def __init__(self, *args) -> None: - super().__init__(*args) - def define_node( self, node: torch.fx.Node, @@ -35,7 +34,19 @@ def define_node( vals_to_ids: Dict[torch.fx.Node, int], debug_handle: int, ) -> None: - self.define_nodes_tensor_inputs_outputs(node, xnn_graph, vals_to_ids) + self.define_tensor( + node, + xnn_graph, + vals_to_ids, + quant_params=QuantParams.from_outputs(node), + ) + input_node = get_input_node(node, 0) + self.define_tensor( + input_node, + xnn_graph, + vals_to_ids, + quant_params=QuantParams.from_inputs(input_node, self._exported_program), + ) # Sanity check that the input and output dim order are the same. We don't # handle dim order conversions yet. diff --git a/backends/xnnpack/partition/config/__init__.py b/backends/xnnpack/partition/config/__init__.py index d0a3e94bbc9..c6c54f083d6 100644 --- a/backends/xnnpack/partition/config/__init__.py +++ b/backends/xnnpack/partition/config/__init__.py @@ -1,5 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. +# Copyright 2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -23,6 +24,7 @@ CatConfig, CeilConfig, ClampConfig, + CloneConfig, CloneDimOrderConfig, ConstantPadConfig, CosConfig, @@ -82,6 +84,7 @@ BMMConfig, CatConfig, CeilConfig, + CloneConfig, CloneDimOrderConfig, ConstantPadConfig, ConvolutionConfig, diff --git a/backends/xnnpack/partition/config/generic_node_configs.py b/backends/xnnpack/partition/config/generic_node_configs.py index f58c8eefdbe..2f45a8bba04 100644 --- a/backends/xnnpack/partition/config/generic_node_configs.py +++ b/backends/xnnpack/partition/config/generic_node_configs.py @@ -239,6 +239,27 @@ def supported_precision_types(self) -> List[ConfigPrecisionType]: return [ConfigPrecisionType.FP32] +class CloneConfig(GenericNodePartitionerConfig): + target_name = "clone.default" + + def supported_precision_types(self) -> List[ConfigPrecisionType]: + return [ConfigPrecisionType.FP32] + + def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool: + if not self.check_common_constraints(node, ep): + return False + + input_meta = node.args[0].meta["val"] + output_meta = node.meta["val"] + input_dim_order = list(input_meta.dim_order()) + output_dim_order = list(output_meta.dim_order()) + if input_dim_order != output_dim_order: + why(node, reason="Only dim-order preserving clones are supported.") + return False + + return True + + class ClampConfig(GenericNodePartitionerConfig): target_name = "clamp.default" diff --git a/backends/xnnpack/runtime/XNNExecutor.cpp b/backends/xnnpack/runtime/XNNExecutor.cpp index 103a8812931..5a150f92b6b 100644 --- a/backends/xnnpack/runtime/XNNExecutor.cpp +++ b/backends/xnnpack/runtime/XNNExecutor.cpp @@ -23,6 +23,28 @@ using executorch::runtime::is_contiguous_dim_order; using executorch::runtime::kTensorDimensionLimit; using executorch::runtime::Span; +namespace { +class InUseGuard { + public: + explicit InUseGuard(std::atomic& flag) : flag_(flag) {} + ~InUseGuard() { + if (!dismissed_) { + flag_.store(false, std::memory_order_release); + } + } + void dismiss() { + dismissed_ = true; + } + + InUseGuard(const InUseGuard&) = delete; + InUseGuard& operator=(const InUseGuard&) = delete; + + private: + std::atomic& flag_; + bool dismissed_ = false; +}; +} // namespace + /** * Initializes the XNNExecutor with the runtime and given number of * inputs/outputs externals_ is resized to the total number of inputs and @@ -71,6 +93,21 @@ ET_NODISCARD Error XNNExecutor::initialize( * delegate->execute() */ ET_NODISCARD Error XNNExecutor::prepare_args(Span args) { + ET_DCHECK_MSG( + !destroyed_.load(std::memory_order_acquire), + "XNNExecutor::prepare_args called after destroy"); + + bool was_in_use = in_use_.exchange(true, std::memory_order_acquire); + if (was_in_use) { + ET_LOG(Error, "XNNExecutor::prepare_args called concurrently"); + } + ET_DCHECK_MSG(!was_in_use, "XNNExecutor::prepare_args called concurrently"); + + InUseGuard in_use_guard(in_use_); + if (was_in_use) { + in_use_guard.dismiss(); + } + ET_CHECK_OR_RETURN_ERROR( runtime_ != nullptr, Internal, @@ -142,6 +179,7 @@ ET_NODISCARD Error XNNExecutor::prepare_args(Span args) { return err; } + in_use_guard.dismiss(); return Error::Ok; } @@ -152,6 +190,8 @@ ET_NODISCARD Error XNNExecutor::prepare_args(Span args) { * After which we then execute the runtime through invoke_runtime. */ ET_NODISCARD Error XNNExecutor::forward(BackendExecutionContext& context) { + InUseGuard in_use_guard(in_use_); + ET_CHECK_OR_RETURN_ERROR( runtime_ != nullptr, Internal, @@ -160,11 +200,13 @@ ET_NODISCARD Error XNNExecutor::forward(BackendExecutionContext& context) { xnn_status status = xnn_setup_runtime_v2( runtime_.get(), externals_.size(), externals_.data()); - ET_CHECK_OR_RETURN_ERROR( - status == xnn_status_success, - Internal, - "Internal Error: Setting up the runtime failed with code: %s", - xnn_status_to_string(status)); + if (status != xnn_status_success) { + ET_LOG( + Error, + "Internal Error: Setting up the runtime failed with code: %s", + xnn_status_to_string(status)); + return Error::Internal; + } auto error = profiler_.start(context.event_tracer()); if (error != Error::Ok) { diff --git a/backends/xnnpack/runtime/XNNExecutor.h b/backends/xnnpack/runtime/XNNExecutor.h index fa7c8360be4..2d709678c1c 100644 --- a/backends/xnnpack/runtime/XNNExecutor.h +++ b/backends/xnnpack/runtime/XNNExecutor.h @@ -16,6 +16,7 @@ #include #include +#include #include #include @@ -36,11 +37,20 @@ class XNNExecutor { std::vector externals_; std::vector packed_data_names_; std::shared_ptr workspace_; + std::atomic in_use_{false}; + std::atomic destroyed_{false}; public: XNNExecutor(std::shared_ptr workspace) : workspace_(workspace) {} + ~XNNExecutor() { + ET_DCHECK_MSG( + !in_use_.load(std::memory_order_acquire), + "XNNExecutor destroyed while in use"); + destroyed_.store(true, std::memory_order_release); + } + inline size_t getNumInputs() { return input_ids_.size(); } diff --git a/backends/xnnpack/runtime/XNNPACKBackend.cpp b/backends/xnnpack/runtime/XNNPACKBackend.cpp index c20fa985f46..9eaadda86f8 100644 --- a/backends/xnnpack/runtime/XNNPACKBackend.cpp +++ b/backends/xnnpack/runtime/XNNPACKBackend.cpp @@ -100,6 +100,7 @@ class XnnpackBackend final lock_weights_cache.lock(); weights_cache_->initialize_for_runtime( context.get_runtime_allocator(), named_data_map); + workspace->set_uses_weight_cache(); } auto [workspace_lock, workspace_ptr] = workspace->acquire(); @@ -129,6 +130,7 @@ class XnnpackBackend final Error, "XNNCompiler::compileModel failed: 0x%x", (unsigned int)err); return err; } + return executor; } @@ -138,13 +140,15 @@ class XnnpackBackend final Span args) const override { auto executor = static_cast(handle); + auto workspace = executor->get_workspace(); + std::unique_lock lock_weights_cache( weights_cache_mutex_, std::defer_lock); - if (executor->uses_weight_cache()) { + if (executor->uses_weight_cache() || workspace->uses_weight_cache()) { lock_weights_cache.lock(); } - auto [raii_lock, _] = executor->get_workspace()->acquire(); + auto [raii_lock, _] = workspace->acquire(); // Prepare Inputs/Outputs and Propagate Input Shapes Error err = executor->prepare_args(args); @@ -167,14 +171,16 @@ class XnnpackBackend final void destroy(DelegateHandle* handle) const override { if (handle != nullptr) { auto executor = static_cast(handle); + auto workspace = executor->get_workspace(); + + const std::lock_guard lock_weights_cache( + weights_cache_mutex_); #ifdef ENABLE_XNNPACK_PROFILING executor->print_avg_op_timings(); #endif if (executor->uses_weight_cache()) { - const std::lock_guard lock_weights_cache( - weights_cache_mutex_); weights_cache_->delete_packed_data(executor->get_packed_data_names()); } @@ -183,7 +189,6 @@ class XnnpackBackend final // the same backend instance. Make sure to hold onto the workspace // shared_ptr, as the pointer in the executor is freed, which includes // the mutex referenced by raii_lock. - auto workspace = executor->get_workspace(); auto [raii_lock, _] = workspace->acquire(); // XNNExecutor is not trivially destructible. Since this was constructed diff --git a/backends/xnnpack/runtime/XNNWorkspace.h b/backends/xnnpack/runtime/XNNWorkspace.h index b7ef442c460..e1b452a0a8b 100644 --- a/backends/xnnpack/runtime/XNNWorkspace.h +++ b/backends/xnnpack/runtime/XNNWorkspace.h @@ -59,6 +59,14 @@ class XNNWorkspace { lock_required_ = false; } + void set_uses_weight_cache() { + uses_weight_cache_.store(true, std::memory_order_release); + } + + bool uses_weight_cache() const { + return uses_weight_cache_.load(std::memory_order_acquire); + } + static runtime::Result> create() { // Because this class can't be moved, we need to construct it in-place. xnn_workspace_t workspace = nullptr; @@ -80,6 +88,7 @@ class XNNWorkspace { std::mutex mutex_; uint64_t id_; bool lock_required_ = true; + std::atomic uses_weight_cache_{false}; WorkspacePtr workspace_; }; diff --git a/backends/xnnpack/runtime/XNNWorkspaceManager.cpp b/backends/xnnpack/runtime/XNNWorkspaceManager.cpp index d3550da5cc7..e115074a108 100644 --- a/backends/xnnpack/runtime/XNNWorkspaceManager.cpp +++ b/backends/xnnpack/runtime/XNNWorkspaceManager.cpp @@ -61,7 +61,9 @@ XNNWorkspaceManager::get_or_create_workspace( return create_result.error(); } +#ifndef XNNPACK_WORKSPACE_ALWAYS_LOCK create_result.get()->disable_locking(); +#endif return create_result.get(); } else if (mode == WorkspaceSharingMode::PerModel) { return get_or_create_model_workspace(program_id); diff --git a/backends/xnnpack/targets.bzl b/backends/xnnpack/targets.bzl index 868e68e5b8c..b3af589df10 100644 --- a/backends/xnnpack/targets.bzl +++ b/backends/xnnpack/targets.bzl @@ -14,6 +14,8 @@ def _get_preprocessor_flags(): if native.read_config("executorch", "xnnpack_weights_cache", "0") != "0": preprocessor_flags.append("-DENABLE_XNNPACK_WEIGHTS_CACHE") + preprocessor_flags.append("-DXNNPACK_WORKSPACE_ALWAYS_LOCK") + # Enable if not disabled through config return preprocessor_flags diff --git a/backends/xnnpack/test/ops/test_clone.py b/backends/xnnpack/test/ops/test_clone.py index 0396b9b2bea..bb995a6cf1e 100644 --- a/backends/xnnpack/test/ops/test_clone.py +++ b/backends/xnnpack/test/ops/test_clone.py @@ -1,5 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. +# Copyright 2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -9,7 +10,8 @@ import unittest import torch -from executorch.backends.xnnpack.test.tester import Tester +from executorch.backends.xnnpack.test.tester import Tester, ToEdgeTransformAndLower +from executorch.backends.xnnpack.utils.configs import get_xnnpack_edge_compile_config class TestClone(unittest.TestCase): @@ -62,6 +64,32 @@ def test_fp32_clone(self): inputs = (torch.randn(2, 3, 4, 5),) self._test_clone_partitioned(inputs) + def test_fp32_clone_default_partitions_with_skip_dim_order(self): + """Test plain aten.clone.default partitioning without dim-order rewrite.""" + inputs = (torch.randn(2, 3, 4, 5),) + ( + Tester(self.Clone(), inputs) + .export() + .check_count({"torch.ops.aten.clone.default": 1}) + .to_edge_transform_and_lower( + ToEdgeTransformAndLower( + edge_compile_config=get_xnnpack_edge_compile_config( + skip_dim_order=True + ) + ) + ) + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .check_not( + [ + "executorch_exir_dialects_edge__ops_aten_clone_default", + "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default", + ] + ) + .to_executorch() + .serialize() + .run_method_and_compare_outputs() + ) + def test_fp32_clone_2d(self): """Test FP32 clone with 2D tensor - should be partitioned""" inputs = (torch.randn(10, 20),) diff --git a/backends/xnnpack/test/runtime/test_workspace_manager.cpp b/backends/xnnpack/test/runtime/test_workspace_manager.cpp index a7689966635..a239d19b415 100644 --- a/backends/xnnpack/test/runtime/test_workspace_manager.cpp +++ b/backends/xnnpack/test/runtime/test_workspace_manager.cpp @@ -116,7 +116,11 @@ TEST_F(XNNWorkspaceManagerTest, DisabledModeAcquireDoesNotLock) { auto [lock, ptr] = workspace->acquire(); ASSERT_NE(ptr, nullptr); +#ifdef XNNPACK_WORKSPACE_ALWAYS_LOCK + EXPECT_TRUE(lock.owns_lock()); +#else EXPECT_FALSE(lock.owns_lock()); +#endif } TEST_F(XNNWorkspaceManagerTest, PerModelMode) { diff --git a/backends/xnnpack/test/targets.bzl b/backends/xnnpack/test/targets.bzl index 812986a12e6..d690e1c9dcd 100644 --- a/backends/xnnpack/test/targets.bzl +++ b/backends/xnnpack/test/targets.bzl @@ -96,6 +96,9 @@ def define_common_targets(): runtime.cxx_test( name = "test_workspace_manager", srcs = ["runtime/test_workspace_manager.cpp"], + preprocessor_flags = [ + "-DXNNPACK_WORKSPACE_ALWAYS_LOCK", + ], deps = [ third_party_dep("XNNPACK"), "//executorch/backends/xnnpack:xnnpack_backend", diff --git a/conftest.py b/conftest.py index 19d777a74e0..be0e6e4ea3d 100644 --- a/conftest.py +++ b/conftest.py @@ -1,3 +1,4 @@ +import hashlib import sys import torch @@ -13,5 +14,8 @@ "backends/apple/**", ] -# Seed the run -torch.manual_seed(42) + +def pytest_runtest_setup(item): + # Set a stable seed for each test based on a hash of the test name. + seed = int(hashlib.sha256(item.nodeid.encode()).hexdigest(), 16) % (2**32) + torch.manual_seed(seed) diff --git a/devtools/inspector/tests/inspector_test.py b/devtools/inspector/tests/inspector_test.py index b33c5b37164..4c59190650c 100644 --- a/devtools/inspector/tests/inspector_test.py +++ b/devtools/inspector/tests/inspector_test.py @@ -7,6 +7,7 @@ # pyre-unsafe import copy +import functools import os import random import statistics @@ -90,6 +91,28 @@ def forward(self, indices: torch.Tensor, values: torch.Tensor) -> torch.Tensor: ETRECORD_PATH = "unittest_etrecord_path" +def disable_if(condition, reason): + """Disable a test when condition is true, still reporting it as executed. + + Conditional analogue of unittest.skipIf that keeps disabled tests visible in + logs instead of producing a skipped result, which some test runners handle + inconsistently. + """ + + def decorator(fn): + if not condition: + return fn + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + print(f"DISABLED_TEST: {fn.__qualname__}: {reason}") + return None + + return wrapper + + return decorator + + # TODO: write an E2E test: create an inspector instance, mock just the file reads, and then verify the external correctness class TestInspector(unittest.TestCase): def test_perf_data(self) -> None: @@ -1504,7 +1527,7 @@ def test_calculate_numeric_gap_with_edge_dialect_exported_program_name(self): self.assertIsInstance(df, pd.DataFrame) self.assertEqual(len(df), 1) - @unittest.skipIf(sys.platform.startswith("win"), "Skipping on Windows") + @disable_if(sys.platform.startswith("win"), "Skipping on Windows") def test_transformer_block_xnnpack_numeric_gap_within_tolerance(self): """ Test that the numeric gap between AOT and runtime intermediate outputs @@ -1693,7 +1716,7 @@ def forward( f"Stack trace for {op_name} doesn't contain file info", ) - @unittest.skipIf(sys.platform.startswith("win"), "Skipping on Windows") + @disable_if(sys.platform.startswith("win"), "Skipping on Windows") def test_intermediate_tensor_comparison_with_torch_export(self): """Test intermediate tensor comparison using torch.export.export and to_edge_transform_and_lower. @@ -1840,7 +1863,7 @@ def _gen_random_runtime_output( ) -> List[Union[None, List[torch.Tensor], bool, float, int, str, torch.Tensor]]: return [torch.randn(RAW_DATA_SIZE)] - @unittest.skipIf(sys.platform.startswith("win"), "Skipping on Windows") + @disable_if(sys.platform.startswith("win"), "Skipping on Windows") def test_disable_debug_handle_validation_with_symbolic_shapes(self): """ Test that demonstrates the issue with symbolic shape related nodes losing from_node info diff --git a/docs/source/backends/nxp/nxp-kernel-selection.md b/docs/source/backends/nxp/nxp-kernel-selection.md index 3ff61323694..307f06d1d02 100644 --- a/docs/source/backends/nxp/nxp-kernel-selection.md +++ b/docs/source/backends/nxp/nxp-kernel-selection.md @@ -1,25 +1,25 @@ # NXP eIQ Neutron Kernel Selective Kernel Registration -The NXP ExecuTorch backend supports selective Neutron kernel registration for `Neutron-C` targets, which decreases the +The NXP ExecuTorch backend supports selective Neutron kernel registration for `Neutron-C` targets, which reduces the size of the Neutron Firmware. During the backend's conversion to the Neutron representation by the Neutron Converter, microcode for the Neutron accelerator is generated. The microcode consists of kernel calls executed by the Neutron Driver. The code for kernel call functions is -distributed in Neutron Firmware. +distributed in the Neutron Firmware. -The `eiq_neutron_sdk.neutron_converter` optionally generates the `*_kernel_selection.c` file, registering -only kernels that are required for a particular model or in the case of ExecuTorch, a delegated subgraph. This -`*_kernel_selection.c`, when used during the application linking, takes precedence over the default list of registered +The `eiq_neutron_sdk.neutron_converter` optionally generates a `*_kernel_selection.c` file, registering +only kernels that are required for a particular model or, in the case of ExecuTorch, a delegated subgraph. This +`*_kernel_selection.c`, when used during application linking, takes precedence over the default list of registered kernels in the Neutron Firmware, and allows the linker to include only the necessary Neutron kernels. -This software is required for deployment on an edge device (e.g. `i.MXRT700`) and is -distributed via the MCUXpresso SDK. The MCUXpresso SDK enables building of a final application that is then flashed on +The Neutron Firmware is required for deployment on an edge device (e.g. `i.MX RT700`) and is +distributed via the MCUXpresso SDK. The MCUXpresso SDK enables the building of a final application that is then flashed on the edge device. For more details about this process, see [eIQ ExecuTorch Library User Guide](https://mcuxpresso.nxp.com/mcuxsdk/latest/html/middleware/eiq/executorch/docs/nxp/ugindex.html). -By default, for Neutron-C targets like `i.MXRT700`, all kernel implementations are present in the Neutron Firmware, which +By default, for Neutron-C targets like `i.MX RT700`, all kernel implementations are present in the Neutron Firmware, which is linked to the final application. This enables an easy build process for any model, but increases the size of the -final application with unused code. In the case of limited RAM, you can link only kernels that are used in the set of -models deployed. This way you can reduce the size of the final app by linking only selected kernels, used in one or -multiple models. +final application with unused code. In memory-constrained environments, you can link only the kernels required by the +deployed models. This way you can reduce the size of the final application by linking only selected kernels, used in one +or more models. The feature works as follows: The Neutron Converter with the appropriate flag exports a kernel selection file for each converted subgraph, the kernel selection files are then merged and ready to be included in the MCUXpresso SDK to use for @@ -30,7 +30,7 @@ a selection-only build. ## Export kernel selection file -To turn on this feature on the side of NXP ExecuTorch backend, use the parameter `--dump_kernel_selection_code` in +To enable this feature in the NXP ExecuTorch backend, use the parameter `--dump_kernel_selection_code` in `aot_neutron_compile.py`. An example with the CifarNet model: ```commandline @@ -43,7 +43,7 @@ This command will create a `*_kernel_selection.c` file alongside the converted P ## Kernel Registration for Multiple Models -If you want to use or experiment with multiple models in one application while having reduced kernel set, you can +If you want to use or experiment with multiple models in one application while having a reduced kernel set, you can create one kernel selection file with the script `merge_kernel_selection_code.py`: ```commandline diff --git a/examples/apple/coreml/scripts/BUCK b/examples/apple/coreml/scripts/BUCK index 164feb8d306..42a97ea893f 100644 --- a/examples/apple/coreml/scripts/BUCK +++ b/examples/apple/coreml/scripts/BUCK @@ -16,6 +16,19 @@ fbcode_target(_kind = python_binary, ], ) +fbcode_target(_kind = python_binary, + name = "coreml_compute_plan", + srcs = [ + "coreml_compute_plan.py", + ], + main_function = "executorch.examples.apple.coreml.scripts.coreml_compute_plan.main", + deps = [ + "//executorch/backends/apple/coreml:executorchcoreml", + "//executorch/exir:schema", + "//executorch/exir/_serialize:lib", + ], +) + fbcode_target(_kind = python_binary, name = "export", srcs = [ diff --git a/examples/apple/coreml/scripts/coreml_compute_plan.py b/examples/apple/coreml/scripts/coreml_compute_plan.py new file mode 100644 index 00000000000..c0ca08db831 --- /dev/null +++ b/examples/apple/coreml/scripts/coreml_compute_plan.py @@ -0,0 +1,236 @@ +# Copyright © 2026 Apple Inc. All rights reserved. +# +# Please refer to the license found in the LICENSE file in the root directory of the source tree. + +"""Report which CoreML operations would dispatch to ANE / GPU / CPU. + +The CoreML runtime decides at compile/load time which compute device each +MIL operation will run on; that decision is exposed by ``MLComputePlan`` +in coremltools 9.0+. This script wraps that API so users can answer +"why isn't my model running on the ANE?" without writing Swift. + +Usage:: + + # Analyze a CoreML model directly (mlpackage or compiled mlmodelc). + python coreml_compute_plan.py --model_path path/to/model.mlpackage + + # Analyze every Core ML partition embedded in an ExecuTorch .pte. + python coreml_compute_plan.py --model_path path/to/program.pte + + # Show ops that fell off the ANE, grouped by op type. + python coreml_compute_plan.py --model_path model.mlpackage --show_non_ane + + # Pick which devices the runtime is allowed to consider. + python coreml_compute_plan.py --model_path model.mlpackage \\ + --compute_units cpu_and_ne +""" + +import argparse +import os +import sys +import tempfile +from collections import Counter +from typing import Iterable, List, Tuple + +import coremltools as ct +from coremltools.models.compute_device import ( + MLCPUComputeDevice, + MLGPUComputeDevice, + MLNeuralEngineComputeDevice, +) +from coremltools.models.compute_plan import MLComputePlan + +from executorch.examples.apple.coreml.scripts.extract_coreml_models import ( + extract_coreml_models, +) + + +_DEVICE_NAMES: List[Tuple[type, str]] = [ + (MLNeuralEngineComputeDevice, "ANE"), + (MLGPUComputeDevice, "GPU"), + (MLCPUComputeDevice, "CPU"), +] + +_COMPUTE_UNIT_CHOICES = { + "all": ct.ComputeUnit.ALL, + "cpu_and_ne": ct.ComputeUnit.CPU_AND_NE, + "cpu_and_gpu": ct.ComputeUnit.CPU_AND_GPU, + "cpu_only": ct.ComputeUnit.CPU_ONLY, +} + + +def _device_name(device) -> str: + if device is None: + return "unknown" + for cls, name in _DEVICE_NAMES: + if isinstance(device, cls): + return name + return type(device).__name__ + + +def _iter_operations(block) -> Iterable: + for op in block.operations: + yield op + for nested in getattr(op, "blocks", None) or []: + yield from _iter_operations(nested) + + +def _ensure_compiled(model_path: str, tmpdir: str) -> str: + """Return a `.mlmodelc` path; compile from `.mlpackage` if needed.""" + if model_path.endswith(".mlmodelc"): + return model_path + if model_path.endswith(".mlpackage"): + dest = os.path.join( + tmpdir, os.path.basename(model_path).replace(".mlpackage", ".mlmodelc") + ) + return str(ct.models.utils.compile_model(model_path, destination_path=dest)) + raise ValueError(f"Expected a .mlpackage or .mlmodelc path, got: {model_path}") + + +def analyze_one( + model_path: str, compute_units: ct.ComputeUnit +) -> List[Tuple[str, str, str]]: + """Return [(function, operator_name, device)] for every op that has a plan. + + coremltools 9.0's ``MLComputePlan.load_from_path`` only exposes usage for + the default function of a multifunction package, so a multifunction + .mlpackage is analyzed function-by-function by projecting each function + as the ``main`` of a temp single-function copy. + """ + function_names = _mlpackage_function_names(model_path) + if len(function_names) <= 1: + return _analyze_compiled(model_path, compute_units) + rows: List[Tuple[str, str, str]] = [] + with tempfile.TemporaryDirectory() as tmpdir: + for fname in function_names: + projected = _project_to_single(model_path, fname, tmpdir) + for _, op_name, device in _analyze_compiled(projected, compute_units): + rows.append((fname, op_name, device)) + return rows + + +def _analyze_compiled( + model_path: str, compute_units: ct.ComputeUnit +) -> List[Tuple[str, str, str]]: + with tempfile.TemporaryDirectory() as tmpdir: + compiled = _ensure_compiled(model_path, tmpdir) + plan = MLComputePlan.load_from_path(compiled, compute_units=compute_units) + program = plan.model_structure.program + if program is None: + raise RuntimeError( + f"{model_path} is not an MLProgram model; this tool only supports " + "the MLProgram backend (the CoreML backend executorch produces today)." + ) + + rows: List[Tuple[str, str, str]] = [] + for fname, fn in program.functions.items(): + for op in _iter_operations(fn.block): + usage = plan.get_compute_device_usage_for_mlprogram_operation(op) + if usage is None: + # Constants and similar non-dispatched ops don't have a plan. + continue + rows.append( + ( + fname, + op.operator_name, + _device_name(usage.preferred_compute_device), + ) + ) + return rows + + +def _mlpackage_function_names(model_path: str) -> List[str]: + """Names of the MLProgram functions inside an .mlpackage, or [] otherwise.""" + if not model_path.endswith(".mlpackage"): + return [] + spec = ct.models.MLModel(model_path, skip_model_load=True).get_spec() + if spec.WhichOneof("Type") != "mlProgram": + return [] + return list(spec.mlProgram.functions.keys()) + + +def _project_to_single(src_mlpackage: str, function_name: str, tmpdir: str) -> str: + """Re-save ``src_mlpackage`` with only ``function_name`` exposed as ``main``.""" + from coremltools.models.utils import MultiFunctionDescriptor, save_multifunction + + dest = os.path.join(tmpdir, f"{function_name}.mlpackage") + desc = MultiFunctionDescriptor() + desc.add_function( + src_mlpackage, + src_function_name=function_name, + target_function_name="main", + ) + desc.default_function_name = "main" + save_multifunction(desc, dest) + return dest + + +def _print_report( + label: str, rows: List[Tuple[str, str, str]], show_non_ane: bool +) -> None: + print(f"\n=== {label} ===") + if not rows: + print(" (no dispatched operations found)") + return + by_device = Counter(device for _, _, device in rows) + total = sum(by_device.values()) + for device in ("ANE", "GPU", "CPU", "unknown"): + count = by_device.get(device, 0) + if count == 0: + continue + pct = 100.0 * count / total + print(f" {device}: {count:5d} / {total} ({pct:5.1f}%)") + + if show_non_ane: + non_ane = [(fn, op_name) for fn, op_name, dev in rows if dev != "ANE"] + if non_ane: + print("\n Non-ANE op types:") + for op_name, count in Counter(op for _, op in non_ane).most_common(): + print(f" {count:5d} {op_name}") + + +def main() -> int: + parser = argparse.ArgumentParser(description=__doc__.splitlines()[0]) + parser.add_argument( + "--model_path", + required=True, + help="Path to a .pte, .mlpackage, or .mlmodelc.", + ) + parser.add_argument( + "--compute_units", + default="cpu_and_ne", + choices=sorted(_COMPUTE_UNIT_CHOICES), + help="Which devices the runtime may use when planning dispatch.", + ) + parser.add_argument( + "--show_non_ane", + action="store_true", + help="List op types that did not get assigned to the ANE.", + ) + args = parser.parse_args() + + compute_units = _COMPUTE_UNIT_CHOICES[args.compute_units] + model_path = args.model_path + + if model_path.endswith(".pte"): + with open(model_path, "rb") as f: + pte_data = f.read() + with tempfile.TemporaryDirectory() as out_dir: + extracted = extract_coreml_models(pte_data, out_dir=out_dir) + if not extracted: + print( + f"{model_path} does not contain any CoreML delegate partitions.", + file=sys.stderr, + ) + return 1 + for path in extracted: + rows = analyze_one(str(path), compute_units) + _print_report(path.name, rows, args.show_non_ane) + else: + rows = analyze_one(model_path, compute_units) + _print_report(os.path.basename(model_path.rstrip("/")), rows, args.show_non_ane) + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/examples/apple/coreml/scripts/extract_coreml_models.py b/examples/apple/coreml/scripts/extract_coreml_models.py index 685b6b594f3..8956550eb4d 100644 --- a/examples/apple/coreml/scripts/extract_coreml_models.py +++ b/examples/apple/coreml/scripts/extract_coreml_models.py @@ -9,7 +9,7 @@ import shutil from pathlib import Path -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Union from executorch.backends.apple.coreml import executorchcoreml from executorch.exir._serialize._program import deserialize_pte_binary @@ -22,7 +22,12 @@ COREML_BACKEND_ID = "CoreMLBackend" -def extract_coreml_models(pte_data: bytes): +def extract_coreml_models( + pte_data: bytes, + out_dir: Optional[Union[str, Path]] = None, +) -> List[Path]: + out_root = Path(out_dir) if out_dir is not None else Path("extracted_coreml_models") + pte_file = deserialize_pte_binary(pte_data) program = pte_file.program @@ -44,6 +49,7 @@ def extract_coreml_models(pte_data: bytes): ] # Track extracted models to avoid duplicates (multifunction models share partitions) + extracted_paths: List[Path] = [] extracted_keys: set = set() model_index: int = 1 @@ -95,7 +101,7 @@ def extract_coreml_models(pte_data: bytes): if model_name is None: model_name = f"model_{model_index}" - model_path: Path = Path() / "extracted_coreml_models" / model_name + model_path: Path = out_root / model_name if model_path.exists(): shutil.rmtree(model_path.absolute()) os.makedirs(model_path.absolute()) @@ -104,11 +110,14 @@ def extract_coreml_models(pte_data: bytes): coreml_processed_bytes, str(model_path.absolute()) ): print(f"Core ML models are extracted and saved to path = {model_path}") + extracted_paths.append(model_path) model_index += 1 if len(coreml_delegates) == 0: print("The model isn't delegated to Core ML.") + return extracted_paths + def main() -> None: """ diff --git a/examples/apple/coreml/scripts/test_coreml_compute_plan.py b/examples/apple/coreml/scripts/test_coreml_compute_plan.py new file mode 100644 index 00000000000..83f06b7a2a8 --- /dev/null +++ b/examples/apple/coreml/scripts/test_coreml_compute_plan.py @@ -0,0 +1,161 @@ +# Copyright © 2026 Apple Inc. All rights reserved. +# +# Please refer to the license found in the LICENSE file in the root directory of the source tree. + +"""Tests for coreml_compute_plan.py.""" + +import os +import shutil +import tempfile +import unittest +from collections import Counter + +import coremltools as ct +import torch +from coremltools.models.utils import MultiFunctionDescriptor, save_multifunction + +from executorch.examples.apple.coreml.scripts.coreml_compute_plan import ( + _COMPUTE_UNIT_CHOICES, + _device_name, + analyze_one, +) + + +class _Op: + def __init__(self, operator_name: str, blocks=None): + self.operator_name = operator_name + self.blocks = blocks or [] + + +class _Block: + __slots__ = ("operations",) + + def __init__(self, ops): + self.operations = ops + + +def _build_small_mlpackage(out_dir: str) -> str: + class M(torch.nn.Module): + def forward(self, x): + return torch.nn.functional.relu(x @ x.T) + x.sum() + + model = M().eval() + ep = torch.export.export(model, (torch.randn(8, 8),), strict=True) + ep = ep.run_decompositions({}) + mlmodel = ct.convert( + ep, + source="pytorch", + convert_to="mlprogram", + minimum_deployment_target=ct.target.iOS17, + skip_model_load=True, + ) + out = os.path.join(out_dir, "tiny.mlpackage") + mlmodel.save(out) + return out + + +class TestDeviceName(unittest.TestCase): + def test_none_device(self): + self.assertEqual(_device_name(None), "unknown") + + def test_known_device_classes(self): + from coremltools.models.compute_device import MLNeuralEngineComputeDevice + + # Don't construct the device classes directly (they wrap proxies that + # may be unavailable in some envs); just confirm the type-mapping path + # returns sensible names by mocking the isinstance check with a fake. + class FakeNE(MLNeuralEngineComputeDevice): + def __init__(self): + pass + + self.assertEqual(_device_name(FakeNE()), "ANE") + + +class TestComputeUnitChoices(unittest.TestCase): + def test_includes_cpu_and_ne(self): + self.assertEqual(_COMPUTE_UNIT_CHOICES["cpu_and_ne"], ct.ComputeUnit.CPU_AND_NE) + + def test_includes_all(self): + self.assertEqual(_COMPUTE_UNIT_CHOICES["all"], ct.ComputeUnit.ALL) + + +class TestAnalyzeOne(unittest.TestCase): + """End-to-end: build a tiny mlpackage and analyze it.""" + + @classmethod + def setUpClass(cls): + cls.tmpdir = tempfile.mkdtemp() + cls.mlpackage = _build_small_mlpackage(cls.tmpdir) + + @classmethod + def tearDownClass(cls): + shutil.rmtree(cls.tmpdir, ignore_errors=True) + + def test_returns_rows_for_dispatched_ops(self): + rows = analyze_one(self.mlpackage, ct.ComputeUnit.CPU_AND_NE) + self.assertGreater(len(rows), 0, "expected at least one dispatched op") + # Every row is (function_name, operator_name, device_name). + for fname, op_name, device in rows: + self.assertIsInstance(fname, str) + self.assertIsInstance(op_name, str) + self.assertIn(device, {"ANE", "GPU", "CPU", "unknown"}) + + def test_main_function_present(self): + rows = analyze_one(self.mlpackage, ct.ComputeUnit.CPU_ONLY) + self.assertIn("main", {fname for fname, _, _ in rows}) + + def test_op_types_for_relu_matmul_model(self): + # The toy model is `relu(x @ x.T) + x.sum()` so the lowered MIL + # should at least contain matmul, relu, add and reduce_sum. + rows = analyze_one(self.mlpackage, ct.ComputeUnit.CPU_ONLY) + op_types = Counter(op for _, op, _ in rows) + # Op names are versioned (e.g. "ios17.matmul"), so match by suffix. + suffixes = {name.split(".")[-1] for name in op_types} + for expected in ("matmul", "relu", "add", "reduce_sum"): + self.assertIn(expected, suffixes, f"missing op {expected}: {suffixes}") + + +class TestAnalyzeOneMultifunction(unittest.TestCase): + """Verify analyze_one walks every function of a multifunction .mlpackage. + + coremltools 9.0's MLComputePlan.load_from_path only exposes usage for + the default function, so analyze_one re-projects each function through + MultiFunctionDescriptor to surface plans for the rest. + """ + + @classmethod + def setUpClass(cls): + cls.tmpdir = tempfile.mkdtemp() + single = _build_small_mlpackage(cls.tmpdir) + desc = MultiFunctionDescriptor() + desc.add_function( + single, src_function_name="main", target_function_name="prefill" + ) + desc.add_function( + single, src_function_name="main", target_function_name="decode" + ) + desc.default_function_name = "prefill" + cls.multi = os.path.join(cls.tmpdir, "multi.mlpackage") + save_multifunction(desc, cls.multi) + + @classmethod + def tearDownClass(cls): + shutil.rmtree(cls.tmpdir, ignore_errors=True) + + def test_reports_every_function(self): + rows = analyze_one(self.multi, ct.ComputeUnit.CPU_ONLY) + fnames = {fname for fname, _, _ in rows} + self.assertEqual(fnames, {"prefill", "decode"}) + + def test_each_function_lowers_the_same_ops(self): + rows = analyze_one(self.multi, ct.ComputeUnit.CPU_ONLY) + per_fn: dict = {} + for fname, op_name, _ in rows: + per_fn.setdefault(fname, set()).add(op_name.split(".")[-1]) + for fname in ("prefill", "decode"): + self.assertIn("matmul", per_fn.get(fname, set()), f"{fname} missing matmul") + self.assertIn("relu", per_fn.get(fname, set()), f"{fname} missing relu") + + +if __name__ == "__main__": + unittest.main() diff --git a/examples/arm/README.md b/examples/arm/README.md index c5f5bb24862..07aecec51e2 100644 --- a/examples/arm/README.md +++ b/examples/arm/README.md @@ -5,175 +5,95 @@ This source code is licensed under the BSD-style license found in the LICENSE file in the root directory of this source tree. --> -## ExecuTorch for Arm backends Ethos-U, VGF and Cortex-M +# Examples for Arm backends Ethos-U, VGF and Cortex-M -This project contains scripts to help you setup and run a PyTorch -model on a Arm backend via ExecuTorch. This backend supports Ethos-U and VGF as -targets (using TOSA) but you can also use the Ethos-U example runner as an example -on Cortex-M if you do not delegate the model. +This directory contains documentation and scripts to +help you setup and run a PyTorch model on the Arm backend +via ExecuTorch. -The main scripts are `setup.sh`, `run.sh` and -`backends/arm/scripts/aot_arm_compiler.py`. +## setup.sh -`setup.sh` will install the needed tools and with --root-dir -you can change the path to a scratch folder where it will download and generate build -artifacts. If supplied, you must also supply the same folder to run.sh with ---scratch-dir= If not supplied both scripts will use examples/arm/arm-scratch. +`setup.sh` downloads the Arm cross-compilation toolchain and Corstone FVP +simulators, installs the Python dependencies for TOSA, Ethos-U Vela, and +Cortex-M/CMSIS-NN, and generates `setup_path.sh` scripts for adding those tools +to your environment. Optional flags also install VGF/MLSDK and Vulkan +dependencies. -`run.sh` can be used to build, run and test a model in an easy way and it will call cmake for you -and in cases you want to run a simulator it will start it also. The script will call `aot_arm_compiler.py` -to convert a model and include it in the build/run. - -For bare-metal Ethos-U builds `run.sh` configures the standalone -`examples/arm/executor_runner/standalone` CMake entry point automatically. If -`--build-dir` is omitted, the script creates and owns a build tree under -`arm_test/_`. Supplying `--build-dir` reuses an existing tree -(for example a VGF host build or out-of-tree configuration) and `run.sh` -verifies it exposes the runner options it needs before compiling. - -Build and test artifacts are by default placed under the folder arm_test folder -this can be changed with --et_build_root= - -`aot_arm_compiler.py` is used to convert a Python model or a saved .pt model to a PTE file and is used by `run.sh` -and other test script but can also be used directly. - - -## Create a PTE file for Arm backends - -There is an easy to use example flow to compile your PyTorch model to a PTE file for the Arm backend called `aot_arm_compiler.py` -that you can use to generate PTE files, it can generate PTE files for the supported targets `-t` or even non delegated (Cortex-M) -using different memory modes and can both use a python file as input or just use the models from examples/models with `--model_name`. -It also supports generating Devtools artifacts like BundleIO BPTE files, and ETRecords. Run it with `--help` to check its capabilities. - -You point out the model to convert with `--model_name=` It supports running a model from examples/models or models -from a python file if you just specify `ModelUnderTest` and `ModelInputs` in it. - -``` -$ python3 -m backends.arm.scripts.aot_arm_compiler --help -``` - -This is how you generate a BundleIO BPTE of a simple add example +Example to install the default Arm backend dependencies and add them to your current shell: +```bash +./examples/arm/setup.sh --i-agree-to-the-contained-eula +source examples/arm/arm-scratch/setup_path.sh ``` -$ python3 -m backends.arm.scripts.aot_arm_compiler --model_name=examples/arm/example_modules/add.py --target=ethos-u55-128 --bundleio -``` - -The example model used has added two extra variables that is picked up to make this work. - -`ModelUnderTest` should be a `torch.nn.module` instance. - -`ModelInputs` should be a tuple of inputs to the forward function. - - -You can also use the models from example/models directly by just using the short name e.g. - -``` -$ python3 -m backends.arm.scripts.aot_arm_compiler --model_name=mv2 --target=ethos-u55-64 -``` - - -`aot_arm_compiler.py` is called from the scripts below so you don't need to, but it can be useful to do by hand in some cases. -## Host VGF example applications +## run.sh -The Arm examples directory also contains host-side VGF reference flows for -specific tasks: +`run.sh` is an end-to-end helper for building and executing an Arm backend +example. It sources the `setup_path.sh` script generated by `setup.sh`, runs +`aot_arm_compiler.py` to convert the selected model to a `.pte` or `.bpte`, +builds the matching runner with CMake, and starts the simulator or runtime for +the selected target when `--build_only` is not set. -- `examples/arm/image_classification_example_vgf` for DEiT image - classification. -- `examples/arm/super_resolution_example_vgf` for Swin2SR image - super-resolution. - - -## ExecuTorch on Arm Ethos-U55/U65 and U85 - -This example code will help you get going with the Corstone™-300/320 platforms and -run on the FVP and can be used a starting guide in your porting to your board/HW - -We will start from a PyTorch model in python, export it, convert it to a `.pte` -file - A binary format adopted by ExecuTorch. Then we will take the `.pte` -model file and embed that with a baremetal application executor_runner. We will -then take the executor_runner file, which contains not only the `.pte` binary but -also necessary software components to run standalone on a baremetal system. -The build flow will pick up the non delegated ops from the generated PTE file and -add CPU implementation of them. -Lastly, we will run the executor_runner binary on a Corstone™-300/320 FVP Simulator platform. - - -### Example workflow - -Below is example workflow to build an application for Ethos-U55/85. The script below requires an internet connection: - -``` -# Step [1] - setup necessary tools -$ cd -$ ./examples/arm/setup.sh --i-agree-to-the-contained-eula - -# Step [2] - Setup path to tools, The `setup.sh` script has generated a script that you need to source every time you restart you shell. -$ source examples/arm/arm-scratch/setup_path.sh +Build and test artifacts are written to `arm_test` by default. Use +`--et_build_root=` to choose another build root. -# Step [3] - build and run ExecuTorch and executor_runner baremetal example application -# on a Corstone(TM)-320 FVP to run a simple PyTorch model from a file. -$ ./examples/arm/run.sh --model_name=examples/arm/example_modules/add.py --target=ethos-u85-128 -``` - -The argument `--model_name=` is passed to `aot_arm_compiler.py` so you can use it in the same way -e.g. you can also use the models from example/models directly in the same way as above. +For example, after running `setup.sh` and sourcing the generated +`setup_path.sh`, build and run a model on an Ethos-U85 target with: -``` -$ ./examples/arm/run.sh --model_name=mv2 --target=ethos-u55-64 +```bash +./examples/arm/run.sh --model_name=examples/arm/example_modules/add.py --target=ethos-u85-128 ``` -The runner will by default set all inputs to "1" and you are supposed to add/change the code -handling the input for your hardware target to give the model proper input, maybe from your camera -or mic hardware. +For bundled input/output and ETDump testing: -While testing you can use the --bundleio flag to use the input from the python model file and -generate a .bpte instead of a .pte file. This will embed the input example data and reference output -in the bpte file/data, which is used to verify the model's output. You can also use --etdump to generate -an ETRecord and a ETDump trace files from your target (they are printed as base64 strings in the serial log). - -Just keep in mind that CPU cycles are NOT accurate on the FVP simulator and it can not be used for -performance measurements, so you need to run on FPGA or actual ASIC to get good results from --etdump. -As a note the printed NPU cycle numbers are still usable and closer to real values if the timing -adaptor is setup correctly. - -``` -# Build + run with BundleIO and ETDump -$ ./examples/arm/run.sh --model_name=lstm --target=ethos-u85-128 --bundleio --etdump +```bash +./examples/arm/run.sh --model_name=lstm --target=ethos-u85-128 --bundleio --etdump ``` +For Cortex-M testing, use a Cortex-M target and bundled I/O: -### Ethos-U minimal example - -See the jupyter notebook `ethos_u_minimal_example.ipynb` for an explained minimal example of the full flow for running a -PyTorch module on the EthosUDelegate. The notebook runs directly in some IDE:s s.a. VS Code, otherwise it can be run in -your browser using -``` -pip install jupyter -jupyter notebook ethos_u_minimal_example.ipynb +```bash +./examples/arm/run.sh --model_name=mv2 --target=cortex-m55 --bundleio ``` -## ExecuTorch on ARM Cortex-M +## Example Contents -For Cortex-M you run the script without delegating e.g `--no_delegate` as the build flow already supports picking up -the non delegated ops from the generated PTE file and add CPU implementation of them this will work out of the box in -most cases. +### Notebook examples -To run mobilenet_v2 on the Cortex-M55 only, without using the Ethos-U try this: +- [ethos_u_minimal_example.ipynb](ethos_u_minimal_example.ipynb) - Minimal + Ethos-U AOT, runtime build, and FVP execution flow. +- [vgf_minimal_example.ipynb](vgf_minimal_example.ipynb) - Minimal VGF + lowering and host execution flow. +- [cortex_m_mv2_example.ipynb](cortex_m_mv2_example.ipynb) - Cortex-M + MobileNetV2 export, quantization, runtime build, and FVP execution flow. +- [pruning_minimal_example.ipynb](pruning_minimal_example.ipynb) - Model + conditioning and pruning flow for Ethos-U85. +- [quantizer_tutorial.ipynb](quantizer_tutorial.ipynb) - Quantizer tutorial + for TOSA, Ethos-U, and VGF quantizers. -``` -$ ./examples/arm/run.sh --model_name=mv2 --target=ethos-u55-128 --no_delegate -``` +### Application examples +- [image_classification_example_ethos_u](image_classification_example_ethos_u/) + - End-to-end DEiT-Tiny image classification flow for Ethos-U, including + model fine-tuning, export, bare-metal runtime build, and Corstone-320 FVP + execution. +- [image_classification_example_vgf](image_classification_example_vgf/) - + DEiT-Tiny image classification flow for VGF host execution. +- [super_resolution_example_vgf](super_resolution_example_vgf) - Swin2SR image + super-resolution. +- [example_modules/add.py](example_modules/add.py) - Small external model file + usable with `run.sh --model_name=examples/arm/example_modules/add.py`. -### Online Tutorial +### Utility examples and guides -We also have a [tutorial](https://pytorch.org/executorch/stable/backends-arm-ethos-u) explaining the steps performed in these -scripts, expected results, possible problems and more. It is a step-by-step guide -you can follow to better understand this delegate. +- [ethos-u-porting-guide.md](ethos-u-porting-guide.md) - Notes for adapting + the example Ethos-U runtime integration to another target. +- [export_standalone_tosa_graph.py](export_standalone_tosa_graph.py) - + Example of exporting a standalone TOSA graph with multiple outputs. +- [visualize.py](visualize.py) - Helper used by `run.sh --model_explorer` to + visualize TOSA or PTE graphs. -### Project Templates +## Project Templates These project templates provide alternative starting points with different toolchains and build systems: diff --git a/examples/arm/ethos-u-setup/core_platform/0003-Guard-HardFault-Handler-for-Armv6-M.patch b/examples/arm/ethos-u-setup/core_platform/0003-Guard-HardFault-Handler-for-Armv6-M.patch new file mode 100644 index 00000000000..57a27cb3dee --- /dev/null +++ b/examples/arm/ethos-u-setup/core_platform/0003-Guard-HardFault-Handler-for-Armv6-M.patch @@ -0,0 +1,49 @@ +From 380045853a133f298cee1bcf0c959b93ea94f9a2 Mon Sep 17 00:00:00 2001 +From: RJ Ascani +Date: Wed, 13 May 2026 15:42:13 -0700 +Subject: [PATCH] Guard HardFault_Handler for Armv6-M / Armv8-M Baseline + +The Corstone-300 HardFault_Handler is written for Armv7-M / Armv8-M +Mainline: it uses an `ite eq` IT-block in inline asm, and dereferences +the SCB CFSR/BFAR/MMFAR fault-status registers. Neither is available +on Armv6-M (Cortex-M0/M0+) or Armv8-M Baseline (Cortex-M23), so the +file fails to compile when the Corstone-300 target source is built +with `-mcpu=cortex-m0plus` to exercise the scalar CMSIS-NN code paths +on the Corstone-300 M55 simulator (an ISA superset). + +Wrap the Mainline-only implementation in +`__ARM_ARCH_7M__ / 7EM / 8M_MAIN / 8_1M_MAIN` and fall back to a +minimal `printf("Hard fault"); exit(1)` stub on Baseline cores. +--- + targets/corstone-300/target.cpp | 8 ++++++++ + 1 file changed, 8 insertions(+) + +diff --git a/targets/corstone-300/target.cpp b/targets/corstone-300/target.cpp +index bda2248..4aa3eea 100644 +--- a/targets/corstone-300/target.cpp ++++ b/targets/corstone-300/target.cpp +@@ -246,6 +246,11 @@ struct ExcContext { + }; + + void HardFault_Handler() { ++ // Armv6-M (M0/M0+) and Armv8-M Baseline (M23) lack the IT instruction and ++ // the SCB CFSR/BFAR/MMFAR fault-status registers, so the rich handler ++ // can't compile or run there. Fall back to a minimal stub on those cores. ++#if defined(__ARM_ARCH_7M__) || defined(__ARM_ARCH_7EM__) || defined(__ARM_ARCH_8M_MAIN__) || \ ++ defined(__ARM_ARCH_8_1M_MAIN__) + int irq; + struct ExcContext *e; + uint32_t sp; +@@ -267,6 +272,9 @@ void HardFault_Handler() { + sp); + printf( + "%11s cfsr=0x%08" PRIx32 " bfar=0x%08" PRIx32 " mmfar=0x%08" PRIx32 "\n", "", SCB->CFSR, SCB->BFAR, SCB->MMFAR); ++#else ++ printf("Hard fault\n"); ++#endif + exit(1); + } + } +-- +2.53.0 + diff --git a/examples/arm/ethos-u-setup/core_software/0002-Fix-ARMCM0plus-directory-case-and-compile-define-mis.patch b/examples/arm/ethos-u-setup/core_software/0002-Fix-ARMCM0plus-directory-case-and-compile-define-mis.patch new file mode 100644 index 00000000000..96dcdd9f29d --- /dev/null +++ b/examples/arm/ethos-u-setup/core_software/0002-Fix-ARMCM0plus-directory-case-and-compile-define-mis.patch @@ -0,0 +1,77 @@ +From 1ee9cf9c956ea6a266fc79dfa62071131f162510 Mon Sep 17 00:00:00 2001 +From: RJ Ascani +Date: Wed, 13 May 2026 15:48:07 -0700 +Subject: [PATCH] Fix ARMCM0plus directory case and compile-define mismatch +MIME-Version: 1.0 +Content-Type: text/plain; charset=UTF-8 +Content-Transfer-Encoding: 8bit + +The Cortex DFP names the Cortex-M0+ device directory and headers +`ARMCM0plus` (lowercase suffix), while the device source files +(`startup_ARMCM0plus.c`, `system_ARMCM0plus.c`) gate their +implementations on the `ARMCM0P` preprocessor macro — three different +spellings. `cmsis.cmake` previously did +`string(TOUPPER \"ARMCM\${CPU_NUMBER}\" ARM_CPU)`, producing +`ARMCM0PLUS`: the include path lookup fails and the source files hit +their `#error device not specified!` guard. + +Override `ARM_CPU` to `ARMCM0plus` and introduce a separate +`CMSIS_DEVICE_CPU_DEFINE` set to `ARMCM0P` for the cmsis_startup and +cmsis_system compile-definitions; all other cores still drive both +paths from the uppercased default. +--- + cmsis.cmake | 20 ++++++++++++++++++-- + 1 file changed, 18 insertions(+), 2 deletions(-) + +diff --git a/cmsis.cmake b/cmsis.cmake +index 7f2b93f..c49f205 100644 +--- a/cmsis.cmake ++++ b/cmsis.cmake +@@ -23,6 +23,15 @@ endif() + + string(TOUPPER "ARMCM${CPU_NUMBER}" ARM_CPU) + ++# Cortex-M0+ is special: the Cortex DFP names the device directory and headers ++# `ARMCM0plus` (lowercase suffix), while the device sources gate their ++# implementations on the `ARMCM0P` preprocessor macro. Override both so the ++# directory lookup and `#include` resolution succeed; the compile-definition ++# override is applied instead of `CMSIS_DEVICE_CPU_FEATURE` further down. ++if(CPU_NUMBER STREQUAL "0plus") ++ set(ARM_CPU "ARMCM0plus") ++endif() ++ + # Set CPU specific features + if(CMAKE_SYSTEM_PROCESSOR MATCHES "cortex-m33(\\+|$)") + set(ARM_FEATURES "_DSP_FP") +@@ -50,6 +59,13 @@ else() + cmake_path(SET CMSIS_DEVICE_CPU_FEATURE "${ARM_CPU}") + endif() + ++# Macro the device sources gate on. Matches CMSIS_DEVICE_CPU_FEATURE for most ++# cores; Cortex-M0+ keys off `ARMCM0P`, not `ARMCM0plus`. ++set(CMSIS_DEVICE_CPU_DEFINE "${CMSIS_DEVICE_CPU_FEATURE}") ++if(CPU_NUMBER STREQUAL "0plus") ++ set(CMSIS_DEVICE_CPU_DEFINE "ARMCM0P") ++endif() ++ + target_include_directories(cmsis_device INTERFACE ${CMSIS_DEVICE_PATH}/${ARM_CPU}/Include) + + target_compile_options(cmsis_device INTERFACE +@@ -66,12 +82,12 @@ target_sources(cmsis_startup INTERFACE + set_source_files_properties(${CMSIS_DEVICE_PATH}/${ARM_CPU}/Source/startup_${ARM_CPU}.c + PROPERTIES COMPILE_FLAGS -Wno-redundant-decls) + +-target_compile_definitions(cmsis_startup INTERFACE ${CMSIS_DEVICE_CPU_FEATURE}) ++target_compile_definitions(cmsis_startup INTERFACE ${CMSIS_DEVICE_CPU_DEFINE}) + target_link_libraries(cmsis_startup INTERFACE cmsis_device) + + # CMSIS system + add_library(cmsis_system INTERFACE) + target_sources(cmsis_system INTERFACE + ${CMSIS_DEVICE_PATH}/${ARM_CPU}/Source/system_${ARM_CPU}.c) +-target_compile_definitions(cmsis_system INTERFACE ${CMSIS_DEVICE_CPU_FEATURE}) ++target_compile_definitions(cmsis_system INTERFACE ${CMSIS_DEVICE_CPU_DEFINE}) + target_link_libraries(cmsis_system INTERFACE cmsis_startup) +-- +2.53.0 + diff --git a/examples/arm/ethos_u_cmsis_nn_fallback_example.ipynb b/examples/arm/ethos_u_cmsis_nn_fallback_example.ipynb new file mode 100644 index 00000000000..0dd8f7045fb --- /dev/null +++ b/examples/arm/ethos_u_cmsis_nn_fallback_example.ipynb @@ -0,0 +1,262 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Copyright 2026 Arm Limited and/or its affiliates.\n", + "#\n", + "# This source code is licensed under the BSD-style license found in the\n", + "# LICENSE file in the root directory of this source tree." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Ethos-U55 with CMSIS-NN fallback example\n", + "\n", + "This guide demonstrates the current full flow for handling operators which does not lower\n", + "to the Ethos-U55 using the Cortex-M backend to make sure they use accelerated CMSIS-NN implementations. \n", + "The basic idea is that the Ethos-U backend will reject any nodes which are not supported,\n", + "leaving them to be handled by the Cortex-M backend.\n", + "\n", + "Before you begin: Make sure you have completed the `ethos_u_minimal_example` for a\n", + "basic understanding of the Ethos-U backend and have your environment setup. \n", + "\n", + "\n", + "*Some scripts in this notebook produces long output logs: Configuring the 'Customizing Notebook Layout' settings to enable 'Output:scrolling' and setting 'Output:Text Line Limit' makes this more manageable*" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup\n", + "\n", + "The first step is creating a simple model which does not fully lower to the Ethos-U55.\n", + "Importantly it is exported with channels_last data, since the Cortex-M backend currently\n", + "only supports lowering operators in that data-format. \n", + "\n", + "Constraints for the basic operations performed by the Ethos-U55 can be found in the\n", + "[Ethos-U Vela repository](https://gitlab.arm.com/artificial-intelligence/ethos-u/ethos-u-vela/-/blob/main/SUPPORTED_OPS.md?ref_type=heads#ethos-u55-and-ethos-u65-tosa-conv2d-constraints). Note that the listed operators does not map exactly to PyTorch operators, but rather a subset found in\n", + "the graph after decompositions in the Ethos-U backend." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from executorch.backends.arm.ethosu import EthosUCompileSpec, EthosUPartitioner\n", + "from executorch.backends.arm.quantizer import (\n", + " EthosUQuantizer,\n", + " get_symmetric_quantization_config,\n", + ")\n", + "from executorch.backends.cortex_m.passes.cortex_m_pass_manager import CortexMPassManager\n", + "from executorch.exir import (\n", + " EdgeCompileConfig,\n", + " ExecutorchBackendConfig,\n", + " to_edge_transform_and_lower,\n", + ")\n", + "from executorch.extension.export_util.utils import save_pte_program\n", + "from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e\n", + "\n", + "target = \"ethos-u55-128\"\n", + "output_path = \"ethos_u_cmsis_nn_fallback_example.pte\"\n", + "\n", + "class ToyMixedModule(torch.nn.Module):\n", + " def __init__(self):\n", + " super().__init__()\n", + " self.conv1 = torch.nn.Conv2d(\n", + " in_channels=3,\n", + " out_channels=4,\n", + " kernel_size=3,\n", + " stride=1,\n", + " padding=1,\n", + " bias=False,\n", + " )\n", + " self.conv2 = torch.nn.Conv2d(\n", + " in_channels=4,\n", + " out_channels=1,\n", + " kernel_size=3,\n", + " stride=4,\n", + " padding=1,\n", + " bias=False,\n", + " ) # Stride=4 not supported on Ethos-U55\n", + "\n", + " def forward(self, x: torch.Tensor) -> torch.Tensor:\n", + " x = self.conv1(x)\n", + " x = torch.relu(x)\n", + " return self.conv2(x)\n", + "\n", + "model = ToyMixedModule().eval().to(memory_format=torch.channels_last)\n", + "example_inputs = (\n", + " torch.randn(1, 3, 8, 8, dtype=torch.float32).to(memory_format=torch.channels_last),\n", + ")\n", + "exported_program = torch.export.export(model, example_inputs)\n", + "exported_program.module().graph.print_tabular()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Ethos-U lowering\n", + "\n", + "The Ethos-U lowering of the model is identical to the minimal example, and as expected\n", + "the printed graph leaves the regular `torch.nn.Conv2d` with `stride=4` and some quantization/dequantization nodes\n", + "outside of the Ethos_u call_delegate operator. \n", + "\n", + "One important part in this step is that this `torch.nn.Conv2d` with `stride=4` has been quantized to\n", + "a format supported by the Cortex-M backend by the Ethos-U quantizer even if it was not\n", + "delegated, since the Cortex-M backend will only lower correctly quantized operators. Would there be\n", + "a discrepancy, see the [quantizer tutorial](https://github.com/pytorch/executorch/blob/main/examples/arm/quantizer_tutorial.ipynb) for\n", + "how to configure more precise quantization." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "compile_spec = EthosUCompileSpec(target=target)\n", + "quantizer = EthosUQuantizer(compile_spec)\n", + "quantizer.set_global(get_symmetric_quantization_config(is_per_channel=True))\n", + "\n", + "prepared = prepare_pt2e(exported_program.module(), quantizer)\n", + "prepared(*example_inputs)\n", + "quantized_model = convert_pt2e(prepared)\n", + "quantized_exported_program = torch.export.export(quantized_model, example_inputs)\n", + "\n", + "edge_program_manager = to_edge_transform_and_lower(\n", + " quantized_exported_program,\n", + " partitioner=[EthosUPartitioner(compile_spec)],\n", + " compile_config=EdgeCompileConfig(_check_ir_validity=False),\n", + ")\n", + "\n", + "edge_program_manager.exported_program().graph_module.graph.print_tabular()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Cortex-M lowering\n", + "\n", + "Finally the Cortex-M backend is applied, and the graph is now fully accelerated. The\n", + "`cortex_m_kernels` can be spotted in the printed graph." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "edge_program_manager._edge_programs[\"forward\"] = CortexMPassManager(\n", + " edge_program_manager.exported_program()\n", + ").transform()\n", + "\n", + "executorch_program = edge_program_manager.to_executorch(\n", + " config=ExecutorchBackendConfig(extract_delegate_segments=False)\n", + ")\n", + "save_pte_program(executorch_program, output_path)\n", + "\n", + "edge_program_manager.exported_program().graph_module.graph.print_tabular()\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Build\n", + "\n", + "The executor runner is built as usual, making sure to link the Cortex-M dependencies. In the available\n", + "example executor_runner CMakeFile this is already done, with the Cortex-M kernel and kernel registration libraries\n", + "`cortex_m_kernels` and `cortex_m_ops_lib` corresponding to `portable_kernels` and `arm_portable_ops_lib` for the the\n", + "unaccelerated portable kernels. For more information about kernel registration, see the\n", + "[documentation](https://docs.pytorch.org/executorch/stable/kernel-library-custom-aten-kernel.html).\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%bash \n", + "source arm-scratch/setup_path.sh\n", + "# Ensure CMake resolves the ExecuTorch checkout root regardless of caller env\n", + "export EXECUTORCH_ROOT=$(cd ../.. && pwd)\n", + "\n", + "# Build example executor runner application to examples/arm/ethos_u_cmsis_nn_fallback_example\n", + "cmake -DCMAKE_TOOLCHAIN_FILE=$(pwd)/ethos-u-setup/arm-none-eabi-gcc.cmake \\\n", + " -DCMAKE_BUILD_TYPE=Release \\\n", + " -DET_PTE_FILE_PATH=ethos_u_cmsis_nn_fallback_example.pte \\\n", + " -DTARGET_CPU=cortex-m55 \\\n", + " -DETHOSU_TARGET_NPU_CONFIG=ethos-u55-128 \\\n", + " -DMEMORY_MODE=Shared_Sram \\\n", + " -DSYSTEM_CONFIG=Ethos_U55_High_End_Embedded \\\n", + " -Bethos_u_cmsis_nn_fallback_example \\\n", + " -S executor_runner/standalone\n", + "cmake --build ethos_u_cmsis_nn_fallback_example -j$(nproc) -- arm_executor_runner" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Sanity check output" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import subprocess\n", + "import re\n", + "\n", + "# Use quantized model in eager mode as reference. By default the executor runner will use 1:s as input.\n", + "test_inputs = (torch.ones_like(example_inputs[0]),)\n", + "reference_result = quantized_exported_program.module()(*test_inputs).flatten().tolist()\n", + "\n", + "# Run the lowered .pte file on FVP using helper script and extract the output numbers using regex\n", + "fvp_output = subprocess.run(\"../../backends/arm/scripts/run_fvp.sh --elf=ethos_u_cmsis_nn_fallback_example/arm_executor_runner --target=ethos-u55-128\", shell=True, capture_output=True)\n", + "lowered_result = [float(x) for x in re.findall(\"-?\\d\\.\\d{6}\" , str(fvp_output.stdout))]\n", + "\n", + "print(reference_result)\n", + "print(lowered_result)\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv (3.10.15)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.15" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/examples/arm/executor_runner/CMakeLists.txt b/examples/arm/executor_runner/CMakeLists.txt index d84947a75ad..88050a2ae77 100644 --- a/examples/arm/executor_runner/CMakeLists.txt +++ b/examples/arm/executor_runner/CMakeLists.txt @@ -349,7 +349,6 @@ elseif(FOUND_OPS_IN_FILE) "gen_oplist: EXECUTORCH_SELECT_OPS_MODEL=${ET_PTE_FILE_PATH} is used to auto generate ops from" ) else() - set(EXECUTORCH_SELECT_OPS_LIST "") set(EXECUTORCH_SELECT_OPS_MODEL "") message( "gen_oplist: No non delagated ops was found in ${ET_PTE_FILE_PATH} no ops added to build" diff --git a/examples/arm/run.sh b/examples/arm/run.sh index cfbcae2dbad..3ef4b0b829b 100755 --- a/examples/arm/run.sh +++ b/examples/arm/run.sh @@ -659,7 +659,7 @@ configure_ethosu_scratch_if_requested() { return fi local scratch_size - scratch_size=$(get_ethosu_scratch_size "$pte_path" || true) + scratch_size=$(get_ethosu_scratch_size "$pte_path" | tail -n 1) if [[ -z "${scratch_size}" ]]; then echo "WARNING: Failed to derive Ethos-U scratch size from ${pte_path}" >&2 return diff --git a/examples/models/gemma4_31b/README.md b/examples/models/gemma4_31b/README.md index da4aa893079..ae3bcb24c19 100644 --- a/examples/models/gemma4_31b/README.md +++ b/examples/models/gemma4_31b/README.md @@ -15,6 +15,7 @@ both export and eager inference: |---|---|---| | `quantize_and_save.py` | bf16 HF checkpoint → quantized checkpoint (one-time) | ~30 GB CPU | | `export.py --prequantized ` | quantized checkpoint → `model.pte` + `model.ptd` | ~24 GB CPU + CUDA for packing | +| `export.py --gguf [--backend mlx]` | GGUF file (Q4_K_M, etc.) → `model.pte` + `model.ptd` | ~24 GB CPU | | `inference.py --prequantized ` | quantized checkpoint → eager generation under `torch.compile` | ~24 GB GPU | | `inference.py --gguf ` | GGUF file (Q4_K_M, etc.) → eager generation | ~24 GB GPU | | `export.py --model-dir ` | one-shot bf16 → quantize → export (no intermediate file) | ~30 GB CPU + CUDA for packing | @@ -92,6 +93,24 @@ method with dynamic sequence length and host-side sampling. Writes `model.pte` (and optionally `model.ptd`) into `--output-dir`. +#### TurboQuant KV cache (long context, MLX only) + +For long-context inference, add `--turboquant` to swap the full-attention +layers' KV cache for a TurboQuant TQ4 cache (4-bit codebook + nibble pack). +This gives ~3.8× cache memory savings on the full-attention layers and lets +you fit context lengths that wouldn't fit in bf16. Sliding-window layers are unaffected. + +```bash +python examples/models/gemma4_31b/export.py \ + --prequantized ./gemma4_31b_int4 \ + --output-dir ./gemma4_31b_exports_mlx_tq \ + --max-seq-len 65536 \ + --backend mlx \ + --turboquant +``` + +Use TurboQuant when you need context beyond what bf16 fits; otherwise leave it off. + ## Eager inference The prompt is automatically wrapped with the Gemma 4 IT chat template. diff --git a/examples/models/gemma4_31b/export.py b/examples/models/gemma4_31b/export.py index 046e365947b..ed3dcdba9c3 100644 --- a/examples/models/gemma4_31b/export.py +++ b/examples/models/gemma4_31b/export.py @@ -141,12 +141,19 @@ def export_and_lower( config: Gemma4_31BConfig, output_dir: str, backend: str = "cuda", + use_turboquant: bool = False, ) -> None: """Export and lower the model to ExecuTorch for the given backend.""" if backend == "cuda": + if use_turboquant: + raise ValueError( + "--turboquant is only supported with --backend mlx " + "(the CUDA path here uses a different TurboQuant integration; " + "see examples/models/qwen3_5_moe/export.py)." + ) _export_cuda(model, config, output_dir) elif backend == "mlx": - _export_mlx(model, config, output_dir) + _export_mlx(model, config, output_dir, use_turboquant=use_turboquant) else: raise ValueError( f"Unsupported backend: {backend!r}. Supported: {_SUPPORTED_BACKENDS}." @@ -279,7 +286,12 @@ def _export_cuda(model: Gemma4_31B, config: Gemma4_31BConfig, output_dir: str) - print("Done.") -def _export_mlx(model: Gemma4_31B, config: Gemma4_31BConfig, output_dir: str) -> None: +def _export_mlx( + model: Gemma4_31B, + config: Gemma4_31BConfig, + output_dir: str, + use_turboquant: bool = False, +) -> None: """Export to .pte via torch.export + MLX backend. Unlike CUDA (which exports separate decode/prefill methods with an @@ -287,6 +299,10 @@ def _export_mlx(model: Gemma4_31B, config: Gemma4_31BConfig, output_dir: str) -> sequence length. No int4_dispatch import — IntxUnpackedToInt8Tensor's default dispatch produces the ``dequantize_affine → linear`` pattern that MLX's QuantizedLinearHandler matches. + + When ``use_turboquant=True``, full-attention layers swap to + ``MLXTurboQuantKVCache`` for ~3.8× KV cache memory savings. Sliding + layers are unaffected (already use ``RingBufferKVCache``). """ import gc @@ -304,10 +320,13 @@ def _export_mlx(model: Gemma4_31B, config: Gemma4_31BConfig, output_dir: str) -> from executorch.exir.passes import MemoryPlanningPass from torch.export import Dim, export - mlx_source_transformations(model, dtype=torch.bfloat16) + mlx_source_transformations( + model, dtype=torch.bfloat16, use_turboquant=use_turboquant + ) + materialize_runtime_buffers(model, dtype=torch.bfloat16) - max_prefill = min(config.max_seq_len - 1, config.sliding_window * 2) + max_prefill = 256 seq_dim = Dim("seq_len", min=1, max=max_prefill) print(f"Exporting (T in [1, {max_prefill}])...") @@ -418,8 +437,17 @@ def main() -> None: choices=list(_SUPPORTED_BACKENDS), help="Target backend for export.", ) + parser.add_argument( + "--turboquant", + action="store_true", + help="Use TurboQuant TQ4 KV cache compression (MLX backend only). " + "~3.8× cache memory savings; applies only to full-attention " + "(non-sliding) layers — sliding layers keep RingBufferKVCache.", + ) args = parser.parse_args() + if args.turboquant and args.backend != "mlx": + parser.error("--turboquant requires --backend mlx.") if args.backend == "cuda" and not torch.cuda.is_available(): parser.error("CUDA is required for the cuda backend.") @@ -443,7 +471,18 @@ def main() -> None: backend=args.backend, ) - export_and_lower(model, config, args.output_dir, backend=args.backend) + if args.gguf and args.backend == "mlx": + os.environ["ET_MLX_ALLOW_NON_FUSED_QUANTIZED_OPS"] = "1" + try: + export_and_lower( + model, + config, + args.output_dir, + backend=args.backend, + use_turboquant=args.turboquant, + ) + finally: + os.environ.pop("ET_MLX_ALLOW_NON_FUSED_QUANTIZED_OPS", None) if __name__ == "__main__": diff --git a/examples/models/gemma4_31b/gguf_loader.py b/examples/models/gemma4_31b/gguf_loader.py index 3e50991e553..35dddb5a0dc 100644 --- a/examples/models/gemma4_31b/gguf_loader.py +++ b/examples/models/gemma4_31b/gguf_loader.py @@ -12,6 +12,7 @@ Usage: model, config = load_gguf_model("model.gguf", backend="cuda") + model, config = load_gguf_model("model.gguf", backend="mlx") """ from typing import Optional @@ -104,10 +105,11 @@ def load_gguf_model( Streams tensors one at a time for low peak memory. GGUF ties ``embed_tokens`` and ``lm_head`` into a single Q4_K tensor. - We untie them: the embedding is dequantized to bf16 (``nn.Embedding`` - needs gather, which ``Int4TilePackedTo4dTensor`` does not support), - while ``lm_head`` keeps the original Q4_K quantization (``nn.Linear`` - matmul via tinygemm). + We untie them so ``lm_head`` keeps the original Q4_K quantization. + On CUDA, the embedding is dequantized to bf16 because ``Int4Tensor`` + does not support the gather op that ``nn.Embedding`` requires. On + MLX, the embedding stays quantized — ``QuantizedEmbeddingHandler`` + handles quantized gather natively. Returns ``(model, config)``. """ @@ -120,8 +122,12 @@ def load_gguf_model( from executorch.examples.models.gemma4_31b.quant import DEFAULT_CUDA_PACKERS packers = DEFAULT_CUDA_PACKERS + elif backend == "mlx": + from executorch.examples.models.gemma4_31b.quant import DEFAULT_MLX_PACKERS + + packers = DEFAULT_MLX_PACKERS else: - raise ValueError(f"Unsupported backend: {backend!r}. Supported: 'cuda'.") + raise ValueError(f"Unsupported backend: {backend!r}. Supported: 'cuda', 'mlx'.") config = Gemma4_31BConfig(max_seq_len=max_seq_len) @@ -143,7 +149,8 @@ def load_gguf_model( if model_key == "embed_tokens.weight" and isinstance(result, Int4Tensor): embed_quant = result - result = dequantize_weight(result, torch.bfloat16) + if backend == "cuda": + result = dequantize_weight(result, torch.bfloat16) pack_one(model, model_key, result, packers) diff --git a/examples/models/gemma4_31b/mlx_source_transformations.py b/examples/models/gemma4_31b/mlx_source_transformations.py index 3a8ae4420e3..0bbd4f7b250 100644 --- a/examples/models/gemma4_31b/mlx_source_transformations.py +++ b/examples/models/gemma4_31b/mlx_source_transformations.py @@ -24,6 +24,9 @@ KVCache as MLXKVCache, RingBufferKVCache as MLXRingKVCache, ) +from executorch.backends.mlx.llm.turboquant_cache import ( + TurboQuantKVCache as MLXTurboQuantKVCache, +) def _replace_attention_forward(attn: nn.Module) -> None: @@ -68,30 +71,34 @@ def _mlx_forward(self, x: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor q = torch.ops.mlx.rope(q, rotary_dim, start_pos, False, 0.0, 1.0, mlx_freqs) k = torch.ops.mlx.rope(k, rotary_dim, start_pos, False, 0.0, 1.0, mlx_freqs) - k_cache, v_cache = self.kv_cache.update(start_pos, k, v) - - if self.is_sliding: - sdpa_mask = self.kv_cache.create_sliding_window_mask(start_pos, T) - y = torch.ops.mlx.custom_sdpa( - q, - k_cache, - v_cache, - start_pos=self.kv_cache.buffer_size - T, - attn_mask=sdpa_mask, - dropout_p=0.0, - is_causal=False, - scale=self.scaling, - ) + if getattr(self, "is_turboquant", False): + self.kv_cache.update(start_pos, k, v) + y = self.kv_cache.sdpa(q, start_pos, scale=self.scaling) else: - y = torch.ops.mlx.custom_sdpa( - q, - k_cache, - v_cache, - start_pos=start_pos, - dropout_p=0.0, - is_causal=True, - scale=self.scaling, - ) + k_cache, v_cache = self.kv_cache.update(start_pos, k, v) + + if self.is_sliding: + sdpa_mask = self.kv_cache.create_sliding_window_mask(start_pos, T) + y = torch.ops.mlx.custom_sdpa( + q, + k_cache, + v_cache, + start_pos=self.kv_cache.buffer_size - T, + attn_mask=sdpa_mask, + dropout_p=0.0, + is_causal=False, + scale=self.scaling, + ) + else: + y = torch.ops.mlx.custom_sdpa( + q, + k_cache, + v_cache, + start_pos=start_pos, + dropout_p=0.0, + is_causal=True, + scale=self.scaling, + ) y = y.transpose(1, 2).contiguous().view(B, T, self.n_heads * self.head_dim) return self.o_proj(y) @@ -150,6 +157,7 @@ def _mlx_model_forward( def mlx_source_transformations( model: nn.Module, dtype: torch.dtype = torch.bfloat16, + use_turboquant: bool = False, ) -> None: """Apply MLX source transformations to a Gemma 4 31B model in-place. @@ -162,6 +170,13 @@ def mlx_source_transformations( - Rewrites layer forward to drop mask parameters (each attention builds its own mask via ``custom_sdpa``) - Rewrites model forward to drop the sampler and ``_build_masks`` + + Args: + model: Gemma4_31B model to transform in place. + dtype: dtype for KV cache buffers (bf16 by default). + use_turboquant: If True, swap full-attention layers' KV caches + for ``MLXTurboQuantKVCache`` (~3.8× cache memory savings). + Sliding-window layers are unaffected. """ config = model.config @@ -176,6 +191,17 @@ def mlx_source_transformations( head_dim=attn.head_dim, dtype=dtype, ) + attn.is_turboquant = False + elif use_turboquant: + attn.kv_cache = MLXTurboQuantKVCache( + max_batch_size=1, + max_context_length=config.max_seq_len, + n_heads=attn.n_kv_heads, + head_dim=attn.head_dim, + enable_dynamic_shape=True, + dtype=dtype, + ) + attn.is_turboquant = True else: attn.kv_cache = MLXKVCache( max_batch_size=1, @@ -185,6 +211,7 @@ def mlx_source_transformations( enable_dynamic_shape=True, dtype=dtype, ) + attn.is_turboquant = False _replace_attention_forward(attn) _replace_layer_forward(layer) diff --git a/examples/models/gemma4_31b/quant/README.md b/examples/models/gemma4_31b/quant/README.md index 2eacced4387..92ddbf97243 100644 --- a/examples/models/gemma4_31b/quant/README.md +++ b/examples/models/gemma4_31b/quant/README.md @@ -50,5 +50,3 @@ The format is compatible with torchao's `save_pretrained` / `load_pretrained`. - `pack_metal.py` — Metal backend packer. - `gguf.py` — extend with Q5_K, Q8_0 GGUF quant types. -- Upstream `Int4TilePackedTo4dTensor.from_int4_tensor()` to torchao - to replace the manual conversion in `pack_int4_for_cuda`. diff --git a/examples/models/gemma4_31b/quant/pack_mlx.py b/examples/models/gemma4_31b/quant/pack_mlx.py index 63aeca426a8..d627c9c437c 100644 --- a/examples/models/gemma4_31b/quant/pack_mlx.py +++ b/examples/models/gemma4_31b/quant/pack_mlx.py @@ -22,7 +22,7 @@ from .pack import ModulePackerFn, pack_model # noqa: F401 -_MLX_SUPPORTED_GROUP_SIZES = (128, 64, 32) +_MLX_SUPPORTED_GROUP_SIZES = (128, 64, 32, 16) # --------------------------------------------------------------------------- @@ -126,7 +126,9 @@ def pack_for_mlx(module: nn.Module, weights: dict[str, torch.Tensor]) -> None: default dispatch produces the ``dequantize_affine → linear`` pattern MLX expects. Regroups to a compatible group_size when needed (e.g. per-axis group_size=5376 → group_size=128) since MLX's - ``parse_dequant_node`` only accepts group_size in {32, 64, 128}. + ``parse_dequant_node`` only accepts group_size in {16, 32, 64, 128}. + Group sizes ≥ 32 use the fused ``QuantizedMatmulNode``; group_size=16 + (e.g. GGUF Q6_K) falls back to ``DequantizeNode`` + matmul at export. """ from torchao.quantization import IntxUnpackedToInt8Tensor from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor diff --git a/examples/models/gemma4_31b/quant/tests/test_pack_mlx.py b/examples/models/gemma4_31b/quant/tests/test_pack_mlx.py index ffb2e0e2dd3..2e6310b9c10 100644 --- a/examples/models/gemma4_31b/quant/tests/test_pack_mlx.py +++ b/examples/models/gemma4_31b/quant/tests/test_pack_mlx.py @@ -146,7 +146,7 @@ def test_regroup_preserves_dequant(self): class TestMlxGroupSize(unittest.TestCase): def test_passthrough(self): - for gs in (32, 64, 128): + for gs in (16, 32, 64, 128): self.assertEqual(_mlx_group_size(gs, 256), gs) def test_regroup_5376(self): @@ -157,7 +157,49 @@ def test_regroup_256(self): def test_rejects_indivisible(self): with self.assertRaises(ValueError): - _mlx_group_size(48, 48) + _mlx_group_size(7, 7) + + +class TestPackLinearGroupSize16(unittest.TestCase): + """Packing group_size=16 weights (GGUF Q6_K) preserves semantics.""" + + def _make_gs16_tensor(self, N=64, K=128): + from torchao.quantization import IntxUnpackedToInt8Tensor + + return IntxUnpackedToInt8Tensor( + qdata=torch.randint(-32, 31, (N, K), dtype=torch.int8), + scale=torch.randn(N, K // 16, dtype=torch.bfloat16), + zero_point=torch.zeros(N, K // 16, dtype=torch.int8), + target_dtype=torch.int8, + block_size=(1, 16), + dtype=torch.bfloat16, + activation_quantization=None, + ) + + def test_dequant_preserves_values(self): + """Packing preserves the dequantized weight values.""" + w = self._make_gs16_tensor(64, 128) + before = dequantize_weight(w, torch.float32) + + module = nn.Linear(128, 64, bias=False) + pack_for_mlx(module, {"weight": w}) + after = dequantize_weight(module.weight.data, torch.float32) + + self.assertTrue( + torch.allclose(before, after, atol=1e-5), + f"max diff: {(before - after).abs().max():.6g}", + ) + + def test_forward_produces_valid_output(self): + """Packed gs=16 weight produces finite output in a linear forward.""" + w = self._make_gs16_tensor(64, 128) + module = nn.Linear(128, 64, bias=False) + pack_for_mlx(module, {"weight": w}) + + x = torch.randn(1, 128, dtype=torch.bfloat16) + out = torch.nn.functional.linear(x, module.weight.data.dequantize()) + self.assertEqual(out.shape, torch.Size([1, 64])) + self.assertFalse(torch.isnan(out).any()) class TestPackEmbeddingForMlx(unittest.TestCase): diff --git a/examples/models/gemma4_31b/tests/test_mlx_pipeline.py b/examples/models/gemma4_31b/tests/test_mlx_pipeline.py index 0e62ab88e4b..37f61fddb0f 100644 --- a/examples/models/gemma4_31b/tests/test_mlx_pipeline.py +++ b/examples/models/gemma4_31b/tests/test_mlx_pipeline.py @@ -244,5 +244,84 @@ def test_export_to_pte(self): self.assertTrue(os.path.exists(os.path.join(out_dir, "model.pte"))) +class TestGgufMlxPipeline(unittest.TestCase): + """Test GGUF → MLX loading path with synthetic Q6_K-like tensors.""" + + def test_load_gguf_model_mlx_backend(self): + """gguf_loader.load_gguf_model accepts backend='mlx'.""" + try: + import gguf # noqa: F401 + except ModuleNotFoundError: + self.skipTest("gguf package not installed") + + from executorch.examples.models.gemma4_31b.gguf_loader import load_gguf_model + + # Will fail on missing file, but NOT on "Unsupported backend". + with self.assertRaisesRegex((FileNotFoundError, OSError, RuntimeError), ".*"): + load_gguf_model("/nonexistent.gguf", backend="mlx") + + def test_mlx_backend_rejects_unknown(self): + from executorch.examples.models.gemma4_31b.gguf_loader import load_gguf_model + + with self.assertRaisesRegex(ValueError, "Unsupported backend"): + load_gguf_model("/nonexistent.gguf", backend="tpu") + + def test_gs16_packing_preserves_values(self): + """Q6_K-like weight (gs=16) preserves dequantized values after packing.""" + from executorch.examples.models.gemma4_31b.quant.pack_mlx import pack_for_mlx + from executorch.examples.models.gemma4_31b.quant.quantize import ( + dequantize_weight, + ) + from torchao.quantization import IntxUnpackedToInt8Tensor + + w = IntxUnpackedToInt8Tensor( + qdata=torch.randint(-32, 31, (64, 128), dtype=torch.int8), + scale=torch.randn(64, 8, dtype=torch.bfloat16), + zero_point=torch.zeros(64, 8, dtype=torch.int8), + target_dtype=torch.int8, + block_size=(1, 16), + dtype=torch.bfloat16, + activation_quantization=None, + ) + before = dequantize_weight(w, torch.float32) + + module = nn.Linear(128, 64, bias=False) + pack_for_mlx(module, {"weight": w}) + after = dequantize_weight(module.weight.data, torch.float32) + + self.assertTrue( + torch.allclose(before, after, atol=1e-5), + f"max diff: {(before - after).abs().max():.6g}", + ) + + def test_embedding_packing_preserves_values(self): + """MLX embedding packing preserves dequantized weight values.""" + from executorch.examples.models.gemma4_31b.quant.pack_mlx import pack_for_mlx + from executorch.examples.models.gemma4_31b.quant.quantize import ( + dequantize_weight, + ) + from torchao.quantization import IntxUnpackedToInt8Tensor + + w = IntxUnpackedToInt8Tensor( + qdata=torch.randint(-8, 7, (256, 128), dtype=torch.int8), + scale=torch.randn(256, 4, dtype=torch.bfloat16), + zero_point=torch.zeros(256, 4, dtype=torch.bfloat16), + target_dtype=torch.int4, + block_size=(1, 32), + dtype=torch.bfloat16, + activation_quantization=None, + ) + before = dequantize_weight(w, torch.float32) + + module = nn.Embedding(256, 128) + pack_for_mlx(module, {"weight": w}) + after = dequantize_weight(module.weight.data, torch.float32) + + self.assertTrue( + torch.allclose(before, after, atol=1e-5), + f"max diff: {(before - after).abs().max():.6g}", + ) + + if __name__ == "__main__": unittest.main() diff --git a/examples/models/llama/eval_llama_lib.py b/examples/models/llama/eval_llama_lib.py index 23d00ff8c15..b562a2b3c70 100644 --- a/examples/models/llama/eval_llama_lib.py +++ b/examples/models/llama/eval_llama_lib.py @@ -1,5 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. +# Copyright 2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -46,9 +47,13 @@ def __init__( use_kv_cache: bool = False, generate_full_logits: bool = False, enable_dynamic_shape: bool = True, + device: Optional[str] = None, ): super().__init__( - model=model, tokenizer=tokenizer, max_seq_length=max_seq_length + model=model, + tokenizer=tokenizer, + max_seq_length=max_seq_length, + device=device, ) self._model = model.to(self.device) self._use_kv_cache = use_kv_cache @@ -57,30 +62,70 @@ def __init__( def _model_call(self, inps): if self._use_kv_cache: - if not self._enable_dynamic_shape: - # graph module exported without dynamic shape won't work with a different shape. - # And we have to do single token prefill here. - result_logits = [] - for pos in range(inps.shape[-1]): - pos_tensor = torch.tensor([pos], dtype=torch.int64) - logits = self._model( - inps[:, pos : pos + 1], {"input_pos": pos_tensor} - ) - result_logits.append(logits) - if self._generate_full_logits: - return torch.cat(result_logits, dim=1) - else: - return torch.stack(result_logits, dim=1) - else: - pos_tensor = torch.tensor([0], dtype=torch.int64, device=self.device) - # Batch process the whole sequence. - logits = self._model( - inps[:, : self._max_seq_length], {"input_pos": pos_tensor} - ) - return logits + return self._model_call_kv_cache(inps) + return self._model_call_no_kv_cache(inps) - else: - return self._model(inps) + def _model_call_kv_cache(self, inps): + if self._enable_dynamic_shape: + pos_tensor = torch.tensor([0], dtype=torch.int64, device=self.device) + return self._model( + inps[:, : self._max_seq_length], {"input_pos": pos_tensor} + ) + + # graph module exported without dynamic shape won't work with a different shape. + # And we have to do single token prefill here. + result_logits = [] + for pos in range(inps.shape[-1]): + pos_tensor = torch.tensor([pos], dtype=torch.int64) + logits = self._model(inps[:, pos : pos + 1], {"input_pos": pos_tensor}) + result_logits.append(logits) + if self._generate_full_logits: + return torch.cat(result_logits, dim=1) + return torch.stack(result_logits, dim=1) + + def _model_call_no_kv_cache(self, inps): + # lm-eval expects logits shaped [batch, seq, vocab]. In the non-KV path, + # some exported graphs (when generate_full_logits=False) return only + # last-position logits [batch, vocab], so reconstruct per-position + # logits by running prefix calls. + if not self._enable_dynamic_shape and not self._generate_full_logits: + raise ValueError( + "Static non-KV lm-eval requires generate_full_logits=True " + "so logits can be read from the last non-pad token." + ) + + if self._generate_full_logits: + return self._model(self._pad_to_max_len(inps)) + + result_logits = [] + seq_len = inps.shape[-1] + for pos in range(min(seq_len, self._max_seq_length)): + prefix = self._pad_to_max_len(inps[:, : pos + 1]) + logits = self._model(prefix) + if logits.dim() == 3: + logits = logits[:, -1, :] + result_logits.append(logits) + + return torch.stack(result_logits, dim=1) + + def _pad_to_max_len(self, tokens: torch.Tensor) -> torch.Tensor: + if self._enable_dynamic_shape: + return tokens + token_len = tokens.shape[-1] + if token_len > self._max_seq_length: + return tokens[:, : self._max_seq_length] + if token_len == self._max_seq_length: + return tokens + + pad_len = self._max_seq_length - token_len + pad_token = getattr(self._tokenizer, "pad_id", self._tokenizer.eos_id) + pad = torch.full( + (tokens.shape[0], pad_len), + pad_token, + dtype=tokens.dtype, + device=tokens.device, + ) + return torch.cat((tokens, pad), dim=-1) def _model_generate(self, context, max_length, eos_token_id): raise Exception("unimplemented") @@ -219,6 +264,7 @@ def gen_eval_wrapper( tokenizer=tokenizer, max_seq_length=llm_config.export.max_seq_length, use_kv_cache=llm_config.model.use_kv_cache, + generate_full_logits=llm_config.debug.generate_full_logits, enable_dynamic_shape=llm_config.model.enable_dynamic_shape, ) else: diff --git a/examples/models/llama/evaluate/eager_eval.py b/examples/models/llama/evaluate/eager_eval.py index 9d5d7ad447b..5c129e1c250 100644 --- a/examples/models/llama/evaluate/eager_eval.py +++ b/examples/models/llama/evaluate/eager_eval.py @@ -1,5 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. +# Copyright 2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -28,12 +29,13 @@ def __init__( tokenizer: Union[SentencePieceTokenizer, Tiktoken, HuggingFaceTokenizer], max_seq_length: Optional[int] = None, use_kv_cache: bool = False, + device: Optional[str] = None, ): - device = "cuda" if torch.cuda.is_available() else "cpu" - super().__init__(device=device, pretrained="gpt2") + resolved_device = device or ("cuda" if torch.cuda.is_available() else "cpu") + super().__init__(device=resolved_device, pretrained="gpt2") self._model = model self._tokenizer = tokenizer - self._device = torch.device(device) + self._device = torch.device(resolved_device) self._max_seq_length = 2048 if max_seq_length is None else max_seq_length self._use_kv_cache = use_kv_cache diff --git a/examples/models/llama/model.py b/examples/models/llama/model.py index f02621b66b2..8ae146dda0f 100644 --- a/examples/models/llama/model.py +++ b/examples/models/llama/model.py @@ -1,5 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. +# Copyright 2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -285,11 +286,25 @@ def get_example_inputs(self): if self.use_kv_cache: return self.get_example_inputs_kvcache_sdpa() else: - return ( - torch.tensor( - [[1, 2, 3]], dtype=torch.long - ), # tokens, with kv cache our input token length is always just 1 token. + max_seq_len = getattr(self.llm_config.export, "max_seq_length", 3) + # Preserve the historical three-token example input as the minimum. + max_seq_len = max(3, int(max_seq_len)) + max_len = max_seq_len - 1 if self.enable_dynamic_shape else max_seq_len + backend = self.llm_config.backend + token_dtype = ( + torch.int32 + if ( + backend.ethosu.enabled + or backend.tosa.enabled + or backend.vgf.enabled + ) + else torch.long ) + example_tokens = torch.arange(max_len, dtype=token_dtype).unsqueeze(0) + vocab_size = int(getattr(self.model_.params, "vocab_size", 0)) + if vocab_size > 1: + example_tokens = example_tokens % (vocab_size - 1) + 1 + return (example_tokens,) # assumption is the custom op doesnt support dynamic shape right now. It might but its untested so lets first get static shape working def get_example_inputs_kvcache_sdpa(self): diff --git a/examples/models/llama/norm.py b/examples/models/llama/norm.py index ec92b353eb4..0b6ed7f5b01 100644 --- a/examples/models/llama/norm.py +++ b/examples/models/llama/norm.py @@ -154,6 +154,14 @@ def replace_rms_norm_for_coreml_(model: torch.nn.Module) -> torch.nn.Module: # Preserve trained scale (no-op for ScalelessRMSNorm). if getattr(mod, "weight", None) is not None: new.weight = mod.weight + else: + # Source was weightless (e.g. ScalelessRMSNorm). The freshly-allocated + # `nn.Parameter(torch.ones(dim))` inside RMSNormCoreML defaults to fp32, + # which causes an fp32 leak in fp16 export. Match the model's existing + # parameter dtype/device. + ref = next((p for p in model.parameters() if p.is_floating_point()), None) + if ref is not None: + new.to(dtype=ref.dtype, device=ref.device) # Locate parent module via the dotted name and rebind the attribute. if "." in name: parent_name, attr = name.rsplit(".", 1) diff --git a/examples/models/parakeet/main.cpp b/examples/models/parakeet/main.cpp index 249e8fd14d4..b8a052004e4 100644 --- a/examples/models/parakeet/main.cpp +++ b/examples/models/parakeet/main.cpp @@ -152,13 +152,14 @@ int main(int argc, char** argv) { ET_LOG(Error, "Preprocessing failed."); return 1; } - auto mel_features = preprocess_result.get(); + auto preprocess_out = preprocess_result.get(); // --- Transcribe --- ET_LOG(Info, "Running TDT greedy decode..."); - auto result = runner.transcribe(mel_features, [](const std::string& piece) { - std::cout << piece << std::flush; - }); + auto result = runner.transcribe( + preprocess_out.features, + [](const std::string& piece) { std::cout << piece << std::flush; }, + preprocess_out.length); if (!result.ok()) { ET_LOG(Error, "Transcription failed."); diff --git a/examples/nxp/executor_runner/nxp_executor_runner.cpp b/examples/nxp/executor_runner/nxp_executor_runner.cpp index 65f5831e5c5..52d7c778227 100644 --- a/examples/nxp/executor_runner/nxp_executor_runner.cpp +++ b/examples/nxp/executor_runner/nxp_executor_runner.cpp @@ -384,71 +384,30 @@ int main(int argc, char* argv[]) { torch::executor::MemoryManager memory_manager( &method_allocator, &planned_memory, &tmp_allocator); - Result method = - program->load_method(method_name, &memory_manager); - if (!method.ok()) { - fprintf( - stderr, - "Loading of method (%s) failed with status %" PRIu32 "...\n", - method_name, - (unsigned int)method.error()); - exit(-1); - } - printf("Method loaded...\n"); - - Error status = Error::Ok; - if (!FLAGS_dataset.empty()) { - // Go through entire dataset for this model. - FLAGS_dataset += "/"; - while (dataset = readdir(datasetDir)) { - if (!strcmp(dataset->d_name, ".") || !strcmp(dataset->d_name, "..")) - continue; - - std::vector inputsData; - inputsData.push_back(FLAGS_dataset + dataset->d_name); - // Set input and call inferrence. - setInputs(method.get(), inputsData); - - status = method->execute(); - if (status != Error::Ok) { - fprintf( - stderr, - "Execution of method %s failed with status %" PRIu32 "...\n", - method_name, - (unsigned int)status); - exit(-1); - } else { - printf("Method executed successfully...\n"); - } - - // Save outputs in binary files. - saveOutputs(method.get(), FLAGS_output, dataset->d_name); - // Print result with highest confidence. - printOutput(method.get(), FLAGS_output, dataset->d_name); + { + Result method = + program->load_method(method_name, &memory_manager); + if (!method.ok()) { + fprintf( + stderr, + "Loading of method (%s) failed with status %" PRIu32 "...\n", + method_name, + (unsigned int)method.error()); + exit(-1); } - closedir(datasetDir); - } else if (!FLAGS_inputs.empty()) { - std::vector inputPaths; - - // Validate and process inputs and separate into two lists. - processInputs(inputPaths, FLAGS_inputs); - - if (std::all_of(inputPaths.begin(), inputPaths.end(), isDirectory)) { - // Inputs are in directories - use files in each directory as the inputs. - std::vector inputsData; - for (std::string& inputDir : inputPaths) { - datasetDir = opendir(inputDir.c_str()); - while (dataset = readdir(datasetDir)) { - if (!strcmp(dataset->d_name, ".") || !strcmp(dataset->d_name, "..")) - continue; - - inputsData.push_back(inputDir + "/" + dataset->d_name); - } - closedir(datasetDir); - - // Sort inputsData to ensure correct input ordering - std::sort(inputsData.begin(), inputsData.end()); - + printf("Method loaded...\n"); + + Error status = Error::Ok; + if (!FLAGS_dataset.empty()) { + // Go through entire dataset for this model. + FLAGS_dataset += "/"; + while (dataset = readdir(datasetDir)) { + if (!strcmp(dataset->d_name, ".") || !strcmp(dataset->d_name, "..")) + continue; + + std::vector inputsData; + inputsData.push_back(FLAGS_dataset + dataset->d_name); + // Set input and call inferrence. setInputs(method.get(), inputsData); status = method->execute(); @@ -463,37 +422,81 @@ int main(int argc, char* argv[]) { printf("Method executed successfully...\n"); } - if (inputDir.back() == '/') - inputDir.pop_back(); - - auto pos = inputDir.find_last_of('/'); - if (pos != std::string::npos) - inputDir = inputDir.substr(pos + 1); - // Save outputs in binary files. - saveOutputs(method.get(), FLAGS_output, inputDir.c_str()); - inputsData.clear(); + saveOutputs(method.get(), FLAGS_output, dataset->d_name); + // Print result with highest confidence. + printOutput(method.get(), FLAGS_output, dataset->d_name); } - } else { - // Inputs are files. - setInputs(method.get(), inputPaths); - - status = method->execute(); - if (status != Error::Ok) { - fprintf( - stderr, - "Execution of method %s failed with status %" PRIu32 "...\n", - method_name, - (unsigned int)status); - exit(-1); + closedir(datasetDir); + } else if (!FLAGS_inputs.empty()) { + std::vector inputPaths; + + // Validate and process inputs and separate into two lists. + processInputs(inputPaths, FLAGS_inputs); + + if (std::all_of(inputPaths.begin(), inputPaths.end(), isDirectory)) { + // Inputs are in directories - use files in each directory as the + // inputs. + std::vector inputsData; + for (std::string& inputDir : inputPaths) { + datasetDir = opendir(inputDir.c_str()); + while (dataset = readdir(datasetDir)) { + if (!strcmp(dataset->d_name, ".") || !strcmp(dataset->d_name, "..")) + continue; + + inputsData.push_back(inputDir + "/" + dataset->d_name); + } + closedir(datasetDir); + + // Sort inputsData to ensure correct input ordering + std::sort(inputsData.begin(), inputsData.end()); + + setInputs(method.get(), inputsData); + + status = method->execute(); + if (status != Error::Ok) { + fprintf( + stderr, + "Execution of method %s failed with status %" PRIu32 "...\n", + method_name, + (unsigned int)status); + exit(-1); + } else { + printf("Method executed successfully...\n"); + } + + if (inputDir.back() == '/') + inputDir.pop_back(); + + auto pos = inputDir.find_last_of('/'); + if (pos != std::string::npos) + inputDir = inputDir.substr(pos + 1); + + // Save outputs in binary files. + saveOutputs(method.get(), FLAGS_output, inputDir.c_str()); + inputsData.clear(); + } } else { - printf("Method executed successfully...\n"); - } + // Inputs are files. + setInputs(method.get(), inputPaths); + + status = method->execute(); + if (status != Error::Ok) { + fprintf( + stderr, + "Execution of method %s failed with status %" PRIu32 "...\n", + method_name, + (unsigned int)status); + exit(-1); + } else { + printf("Method executed successfully...\n"); + } - // Save outputs in binary files. - saveOutputs(method.get(), FLAGS_output); + // Save outputs in binary files. + saveOutputs(method.get(), FLAGS_output); + } } - } + } // Destruct the method object before destroying the Neutron Device. printf("Finished...\n"); diff --git a/examples/qualcomm/oss_scripts/llama/artifacts/stories260k_hybrid_llama_qnn.pte b/examples/qualcomm/oss_scripts/llama/artifacts/stories260k_hybrid_llama_qnn.pte index ad6bee06146..5903c5b5c32 100644 Binary files a/examples/qualcomm/oss_scripts/llama/artifacts/stories260k_hybrid_llama_qnn.pte and b/examples/qualcomm/oss_scripts/llama/artifacts/stories260k_hybrid_llama_qnn.pte differ diff --git a/examples/qualcomm/oss_scripts/llama/decoder_runtime_evaluator.py b/examples/qualcomm/oss_scripts/llama/decoder_runtime_evaluator.py index 7bebf513658..a75e67933e5 100644 --- a/examples/qualcomm/oss_scripts/llama/decoder_runtime_evaluator.py +++ b/examples/qualcomm/oss_scripts/llama/decoder_runtime_evaluator.py @@ -133,7 +133,7 @@ def _init_runner_base_cmd(self): base_cmd = " ".join( [ f"export LD_LIBRARY_PATH={self.qnn_sdk}/lib/x86_64-linux-clang/:{args.build_folder}/lib &&", - f"./{args.build_folder}/examples/qualcomm/oss_scripts/llama/{self.runner}", + f"{args.build_folder}/examples/qualcomm/oss_scripts/llama/{self.runner}", f"--decoder_model_version {DECODER_MODEL_VERSION[args.decoder_model]}", f"--tokenizer_path {self.runtime_tokenizer_path}", f"--output_path {self.device_output_response_path}", diff --git a/examples/qualcomm/oss_scripts/llama/decoder_utils.py b/examples/qualcomm/oss_scripts/llama/decoder_utils.py index 5380ff5220d..184eb857661 100644 --- a/examples/qualcomm/oss_scripts/llama/decoder_utils.py +++ b/examples/qualcomm/oss_scripts/llama/decoder_utils.py @@ -317,13 +317,9 @@ def retrieve_info_from_pte(pte_path: str) -> dict: pte_max_context_len = pte_max_seq_len # FP has no scale/zero_point, use following values, which is equivalent to not performing dequantize. - if kv_io_bit_width == 32: + if kv_io_bit_width == 32 or (logits_scale is None or logits_zero_point is None): logits_scale = 1 logits_zero_point = 0 - elif logits_scale is None or logits_zero_point is None: - raise RuntimeError( - "Unable to find scale/offset. The .pte file might be deprecated. Please generate a new .pte file" - ) assert output_vocab_size is not None, "Couldn't find the vocab size" assert pte_max_seq_len is not None, "Couldn't find the max_seq_len from pte" meta_info = { diff --git a/examples/qualcomm/oss_scripts/llama/llama.py b/examples/qualcomm/oss_scripts/llama/llama.py index a8e28f96b71..ce0b7a80cfc 100755 --- a/examples/qualcomm/oss_scripts/llama/llama.py +++ b/examples/qualcomm/oss_scripts/llama/llama.py @@ -21,6 +21,7 @@ ) from executorch.backends.qualcomm.utils.utils import ( + generate_gpu_compiler_spec, generate_htp_compiler_spec, generate_qnn_executorch_compiler_spec, get_soc_to_chipset_map, @@ -119,9 +120,15 @@ def compile( # because the encoder is quite sensitive and quantization can make it harder for the model to distinguish # between images within the same conversation. to_skip = len(args.image_path) > 1 - backend_options = generate_htp_compiler_spec( - use_fp16=to_skip, - ) + if args.backend == "htp": + backend_options = generate_htp_compiler_spec( + use_fp16=to_skip, + ) + elif args.backend == "gpu": + backend_options = generate_gpu_compiler_spec() + else: + raise ValueError(f"Unsupported backend {args.backend}") + encoder_compile_specs = generate_qnn_executorch_compiler_spec( soc_model=get_soc_to_chipset_map()[args.soc_model], backend_options=backend_options, @@ -131,27 +138,40 @@ def compile( skip_quantize[modality] = to_skip compile_specs[modality] = encoder_compile_specs elif is_multimodal and modality == TOK_EMBEDDING: - backend_options = generate_htp_compiler_spec( - use_fp16=False, - # x86 emulator does not support weight sharing - use_weight_sharing=not args.enable_x86_64, - ) + if args.backend == "htp": + backend_options = generate_htp_compiler_spec( + use_fp16=False, + # x86 emulator does not support weight sharing + use_weight_sharing=not args.enable_x86_64, + ) + elif args.backend == "gpu": + backend_options = generate_gpu_compiler_spec() + else: + raise ValueError(f"Unsupported backend {args.backend}") + compile_specs[modality] = [ generate_qnn_executorch_compiler_spec( soc_model=get_soc_to_chipset_map()[args.soc_model], backend_options=backend_options, # x86 emulator does not support shared buffer shared_buffer=not args.enable_x86_64, + online_prepare=args.online_prepare, ) ] * len(TOK_EMBEDDING_GRAPH_NAMES) elif modality == TEXT_DECODER: # compile spec for text decoder - backend_options = generate_htp_compiler_spec( - use_fp16=False, - use_multi_contexts=decoder_model_config.num_sharding > 1, - # x86 emulator does not support weight sharing - use_weight_sharing=not args.enable_x86_64, - ) + if args.backend == "htp": + backend_options = generate_htp_compiler_spec( + use_fp16=args.use_fp16, + use_multi_contexts=decoder_model_config.num_sharding > 1, + # x86 emulator does not support weight sharing + use_weight_sharing=not args.enable_x86_64, + ) + elif args.backend == "gpu": + backend_options = generate_gpu_compiler_spec() + else: + raise ValueError(f"Unsupported backend {args.backend}") + skip_quantize[modality] = args.use_fp16 compile_specs[modality] = [ generate_qnn_executorch_compiler_spec( soc_model=get_soc_to_chipset_map()[args.soc_model], @@ -159,6 +179,7 @@ def compile( # x86 emulator does not support shared buffer shared_buffer=not args.enable_x86_64, use_mha2sha=True, + online_prepare=args.online_prepare, ) ] * len(DECODER_GRAPH_NAMES) @@ -172,7 +193,11 @@ def compile( ) # perform compilation - multi_modal_mgr.compile(compile_specs=compile_specs, pte_filenames=pte_filenames) + multi_modal_mgr.compile( + compile_specs=compile_specs, + pte_filenames=pte_filenames, + skip_quantize=skip_quantize, + ) def inference( @@ -529,6 +554,14 @@ def _build_parser(): help="Number of examples in few-shot context", ) + parser.add_argument( + "-F", + "--use_fp16", + help="If specified, will run in fp16 precision and discard ptq setting", + action="store_true", + default=False, + ) + parser.add_argument("-v", "--verbose", action="store_true") parser.add_argument( @@ -592,6 +625,12 @@ def export_llama(args) -> None: pte_filename = "lookahead_llama_qnn" else: raise RuntimeError(f"Unknown model_mode: {args.model_mode}.") + + if args.model_mode == "hybrid" and args.online_prepare: + raise RuntimeError( + "Currently hybrid mode is not compatible with online_prepare." + ) + if args.decoder_model == "stories260k": pte_filename = f"{args.decoder_model}_" + pte_filename pte_filenames = { @@ -740,6 +779,7 @@ def export_llama(args) -> None: def main(): parser = _build_parser() args = parser.parse_args() + args.build_folder = os.path.realpath(args.build_folder) try: export_llama(args) except Exception as e: diff --git a/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp b/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp index d8d82fece33..9b8cdd7999e 100644 --- a/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp +++ b/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp @@ -210,7 +210,6 @@ std::string get_formatted_prompt( return formatted_prompt; } -template void start_runner( std::unique_ptr module, std::vector& prompts, @@ -219,7 +218,7 @@ void start_runner( gflags::GetCommandLineFlagInfoOrDie("tokenized_prompt").is_default ? false : true; // create llama runner - example::Runner runner( + example::Runner runner( std::move(module), FLAGS_decoder_model_version.c_str(), FLAGS_model_path.c_str(), @@ -298,26 +297,8 @@ int main(int argc, char** argv) { FLAGS_attention_sink_rope_path.c_str(), executorch::extension::Module::LoadMode::MmapUseMlockIgnoreErrors); } - // Using 8bit as default since this meta is introduced with 16bit kv io - // support and older models only have 8bit kv io. - example::KvBitWidth kv_bitwidth = example::KvBitWidth::kWidth8; - if (module->method_names()->count("get_kv_io_bit_width") > 0) { - kv_bitwidth = static_cast( - module->get("get_kv_io_bit_width").get().toScalar().to()); - } - - if (kv_bitwidth == example::KvBitWidth::kWidth8) { - start_runner( - std::move(module), prompts, std::move(attention_sink_rope_module)); - } else if (kv_bitwidth == example::KvBitWidth::kWidth16) { - start_runner( - std::move(module), prompts, std::move(attention_sink_rope_module)); - } else { - ET_CHECK_MSG( - false, - "Unsupported kv bitwidth: %ld", - static_cast(kv_bitwidth)); - } + start_runner( + std::move(module), prompts, std::move(attention_sink_rope_module)); return 0; } diff --git a/examples/qualcomm/oss_scripts/llama/qnn_multimodal_runner.cpp b/examples/qualcomm/oss_scripts/llama/qnn_multimodal_runner.cpp index 29b6b9d7ddc..c9c2bd19940 100644 --- a/examples/qualcomm/oss_scripts/llama/qnn_multimodal_runner.cpp +++ b/examples/qualcomm/oss_scripts/llama/qnn_multimodal_runner.cpp @@ -137,7 +137,6 @@ std::vector CollectPrompts(int argc, char** argv) { return prompts; } -template void start_multimodal_runner( std::unique_ptr encoder, std::unique_ptr tok_embedding, @@ -150,7 +149,7 @@ void start_multimodal_runner( : true; // Create multimodal runner - example::QNNMultimodalRunner runner( + example::QNNMultimodalRunner runner( std::move(encoder), std::move(tok_embedding), std::move(text_decoder), @@ -289,35 +288,12 @@ int main(int argc, char** argv) { FLAGS_decoder_path.c_str(), executorch::extension::Module::LoadMode::MmapUseMlockIgnoreErrors); - // Using 8bit as default since this meta is introduced with 16bit kv io - // support and older models only have 8bit kv io. - example::KvBitWidth kv_bitwidth = example::KvBitWidth::kWidth8; - if (text_decoder->method_names()->count("get_kv_io_bit_width") > 0) { - kv_bitwidth = static_cast( - text_decoder->get("get_kv_io_bit_width") - .get() - .toScalar() - .to()); - } - // Start runner with appropriate KV bitwidth - if (kv_bitwidth == example::KvBitWidth::kWidth8) { - start_multimodal_runner( - std::move(encoder), - std::move(tok_embedding), - std::move(text_decoder), - prompts); - } else if (kv_bitwidth == example::KvBitWidth::kWidth16) { - start_multimodal_runner( - std::move(encoder), - std::move(tok_embedding), - std::move(text_decoder), - prompts); - } else { - ET_CHECK_MSG( - false, - "Unsupported kv bitwidth: %ld", - static_cast(kv_bitwidth)); - } + // Start runner + start_multimodal_runner( + std::move(encoder), + std::move(tok_embedding), + std::move(text_decoder), + prompts); return 0; } diff --git a/examples/qualcomm/oss_scripts/llama/runner/decoder_runner.h b/examples/qualcomm/oss_scripts/llama/runner/decoder_runner.h index 888e9acd421..b714f737de3 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/decoder_runner.h +++ b/examples/qualcomm/oss_scripts/llama/runner/decoder_runner.h @@ -8,6 +8,7 @@ #pragma once +#include #include #include #include @@ -56,19 +57,36 @@ class DecoderRunner { inline int32_t logits_to_token( const executorch::aten::Tensor& logits_tensor, int64_t pos) { - auto* logits = logits_tensor.mutable_data_ptr(); + std::byte* logits = logits_tensor.mutable_data_ptr(); auto num_tokens = logits_tensor.size(1); auto vocab_size = logits_tensor.size(2); static std::vector logits_f(vocab_size); - auto* logits_last = logits; + std::byte* logits_last = logits; // offset to the meaningful logit we want for prefill model. + executorch::aten::ScalarType logits_dtype = logits_tensor.scalar_type(); + size_t logits_nbytes = getDtypeSize(logits_dtype); if (num_tokens > 1) { - logits_last += pos * vocab_size; + logits_last += pos * vocab_size * logits_nbytes; } - // Discard dequantization (converting uint16_t to float) because the + // Discard dequantization (converting std::byte to float) because the // relative order of elements remains the same without conversion for (int i = 0; i < vocab_size; i++) { - logits_f[i] = logits_last[i]; + switch (logits_dtype) { + case executorch::aten::ScalarType::UInt16: + logits_f[i] = reinterpret_cast(logits_last)[i]; + break; + case executorch::aten::ScalarType::Byte: + logits_f[i] = reinterpret_cast(logits_last)[i]; + break; + case executorch::aten::ScalarType::Float: + logits_f[i] = reinterpret_cast(logits_last)[i]; + break; + default: + ET_CHECK_MSG( + false, + "The scalar_type %s of logits is not supported", + executorch::runtime::toString(logits_dtype)); + } } return sampler_->sample(logits_f.data()); } diff --git a/examples/qualcomm/oss_scripts/llama/runner/kv_manager.cpp b/examples/qualcomm/oss_scripts/llama/runner/kv_manager.cpp index e5c12068bab..7288ca5fbd1 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/kv_manager.cpp +++ b/examples/qualcomm/oss_scripts/llama/runner/kv_manager.cpp @@ -7,24 +7,105 @@ */ #include +#include #include + +using executorch::runtime::MethodMeta; +using executorch::runtime::Result; +using executorch::runtime::TensorInfo; namespace example { -template -KVManager::KVManager(Metadata metadata) : metadata_(metadata) { + +namespace { +void fill_mask( + executorch::aten::ScalarType scalar_type, + std::byte* buf, + size_t size, + bool use_pos_value) { + if (use_pos_value) { + switch (scalar_type) { + case executorch::aten::ScalarType::UInt16: + std::fill_n(reinterpret_cast(buf), size, 65535u); + break; + case executorch::aten::ScalarType::Byte: + std::fill_n(reinterpret_cast(buf), size, 255u); + break; + case executorch::aten::ScalarType::Float: + std::fill_n(reinterpret_cast(buf), size, 0.0); + break; + default: + ET_CHECK_MSG( + false, + "Unsupported scalar type %s", + executorch::runtime::toString(scalar_type)); + break; + } + } else { + switch (scalar_type) { + case executorch::aten::ScalarType::UInt16: + std::fill_n(reinterpret_cast(buf), size, 0u); + break; + case executorch::aten::ScalarType::Byte: + std::fill_n(reinterpret_cast(buf), size, 0u); + break; + // -65535 acts as the additive "very negative" attention-mask value; + // chosen as a large finite negative so masked positions effectively + // zero out after softmax without relying on -inf. + case executorch::aten::ScalarType::Float: + std::fill_n(reinterpret_cast(buf), size, -65535.0); + break; + default: + ET_CHECK_MSG( + false, + "Unsupported scalar type %s", + executorch::runtime::toString(scalar_type)); + break; + } + } +} +} // namespace + +KVManager::KVManager(Metadata metadata, std::unique_ptr method_meta) + : metadata_(metadata) { + Result attention_mask = method_meta->input_tensor_meta(1); + attention_mask_dtype_ = attention_mask->scalar_type(); + + // inputs are [input_tokens, attention_mask, (sliding window attention_mask), + // (input_pos), kv_caches] search kv_cache in inputs + for (int i = 2; i < method_meta->num_inputs(); i++) { + Result tensor_meta = method_meta->input_tensor_meta(i); + // k_cache: [1, n_heads, head_dim, seq_len] + size_t tensor_nbytes = tensor_meta->nbytes(); + size_t expected_tensor_nbytes = metadata_.head_dim * metadata_.num_heads * + metadata_.max_cache_len * getDtypeSize(tensor_meta->scalar_type()); + if (tensor_nbytes != expected_tensor_nbytes) { + // Not a kv_cache tensor (e.g. input_pos, sliding window attention mask). + continue; + } + if (kv_cache_dtype_ == executorch::aten::ScalarType::Undefined) { + kv_cache_dtype_ = tensor_meta->scalar_type(); + } else { + ET_CHECK_MSG( + tensor_meta->scalar_type() == kv_cache_dtype_, + "Currently mixed scalar type of kv_cache is not allowed"); + } + } + ET_CHECK_MSG( + kv_cache_dtype_ != executorch::aten::ScalarType::Undefined, + "kv_cache_dtype was not detected from method inputs"); k_cache_.resize(metadata_.num_layers); v_cache_.resize(metadata_.num_layers); // Calculate cache size size_t cache_in_bytes = metadata_.num_layers * metadata_.num_heads * - metadata_.head_dim * metadata_.max_cache_len * sizeof(T); + metadata_.head_dim * metadata_.max_cache_len * + getDtypeSize(kv_cache_dtype_); size_t cache_out_bytes = metadata_.num_layers * metadata_.num_heads * - metadata_.head_dim * metadata_.max_ar_len * sizeof(T); + metadata_.head_dim * metadata_.max_ar_len * getDtypeSize(kv_cache_dtype_); total_cache_size_ = 2 * (cache_in_bytes + cache_out_bytes); }; -template -void KVManager::init_attention_mask( - uint16_t* attention_mask, +void KVManager::init_attention_mask( + std::byte* attention_mask, const std::vector& attention_map, int32_t ar_len, int32_t n_past) { @@ -33,38 +114,51 @@ void KVManager::init_attention_mask( "The size of attention_map (%zu) doesn't match with ar_len (%d)", attention_map.size(), ar_len); - uint16_t neg_val = 0; - uint16_t pos_val = 65535; // Clear the attention mask - std::fill_n(attention_mask, ar_len * metadata_.context_len, neg_val); + fill_mask( + attention_mask_dtype_, + attention_mask, + ar_len * metadata_.context_len, + /*use_pos_value=*/false); // SMART_MASK requires special handling of attention mask - uint16_t* past_ptr = attention_mask; - uint16_t* new_ptr = attention_mask + (metadata_.context_len - ar_len); + std::byte* past_ptr = attention_mask; + std::byte* new_ptr = attention_mask + + (metadata_.context_len - ar_len) * getDtypeSize(attention_mask_dtype_); // All inputs will necessarily attend to n_past and itself for (int i = 0; i < ar_len; i++) { // Iterate across ar_len if (attention_map[i] < 0) { // If negative, attend to only past tokens - std::fill_n(past_ptr, n_past, pos_val); + fill_mask( + attention_mask_dtype_, + past_ptr, + n_past, + /*use_pos_value=*/true); } else { // If positive, copy attention map from (relative to 0th input) parent // Parent token index const int32_t pidx = attention_map[i]; - uint16_t* parent_ptr = attention_mask + pidx * metadata_.context_len; + std::byte* parent_ptr = attention_mask + + pidx * metadata_.context_len * getDtypeSize(attention_mask_dtype_); std::memcpy( - past_ptr, parent_ptr, metadata_.context_len * sizeof(uint16_t)); + past_ptr, + parent_ptr, + metadata_.context_len * getDtypeSize(attention_mask_dtype_)); } // Attend to itself - new_ptr[i] = pos_val; - past_ptr += metadata_.context_len; - new_ptr += metadata_.context_len; + fill_mask( + attention_mask_dtype_, + new_ptr + i * getDtypeSize(attention_mask_dtype_), + 1, + /*use_pos_value=*/true); + past_ptr += metadata_.context_len * getDtypeSize(attention_mask_dtype_); + new_ptr += metadata_.context_len * getDtypeSize(attention_mask_dtype_); } } -template -void KVManager::init_attention_mask( - uint16_t* attention_mask, +void KVManager::init_attention_mask( + std::byte* attention_mask, const std::vector& attention_map, int32_t ar_len, int32_t n_past, @@ -75,30 +169,44 @@ void KVManager::init_attention_mask( "The size of attention_map (%zu) doesn't match with ar_len (%d)", attention_map.size(), ar_len); - uint16_t neg_val = 0; - uint16_t pos_val = 65535; // Clear the attention mask - std::fill_n(attention_mask, ar_len * metadata_.context_len, neg_val); + fill_mask( + attention_mask_dtype_, + attention_mask, + ar_len * metadata_.context_len, + /*use_pos_value=*/false); // SMART_MASK requires special handling of attention mask - uint16_t* past_ptr = attention_mask; - uint16_t* new_ptr = attention_mask + (metadata_.context_len - ar_len); + std::byte* past_ptr = attention_mask; + std::byte* new_ptr = attention_mask + + (metadata_.context_len - ar_len) * getDtypeSize(attention_mask_dtype_); // All inputs will necessarily attend to n_past and itself for (int i = 0; i < ar_len; i++) { // Iterate across ar_len if (attention_map[i] < 0) { // If negative, attend to only past tokens - std::fill_n(past_ptr, n_past, pos_val); + fill_mask( + attention_mask_dtype_, + past_ptr, + n_past, + /*use_pos_value=*/true); } else { // If positive, copy attention map from (relative to 0th input) parent // Parent token index const int32_t pidx = attention_map[i]; - uint16_t* parent_ptr = attention_mask + pidx * metadata_.context_len; + std::byte* parent_ptr = attention_mask + + pidx * metadata_.context_len * getDtypeSize(attention_mask_dtype_); std::memcpy( - past_ptr, parent_ptr, metadata_.context_len * sizeof(uint16_t)); + past_ptr, + parent_ptr, + metadata_.context_len * getDtypeSize(attention_mask_dtype_)); } // Attend to itself - new_ptr[i] = pos_val; + fill_mask( + attention_mask_dtype_, + new_ptr + i * getDtypeSize(attention_mask_dtype_), + 1, + /*use_pos_value=*/true); // mask by limitation of sliding_window int32_t available_context_len = position_offset.empty() @@ -107,87 +215,73 @@ void KVManager::init_attention_mask( // if available_context_len is less than 0, it means we need to mask some // tokens in the past to avoid exceeding the sliding window if (available_context_len < 0) { - std::fill_n(past_ptr, -available_context_len, neg_val); + fill_mask( + attention_mask_dtype_, + past_ptr, + -available_context_len, + /*use_pos_value=*/false); } - past_ptr += metadata_.context_len; - new_ptr += metadata_.context_len; + past_ptr += metadata_.context_len * getDtypeSize(attention_mask_dtype_); + new_ptr += metadata_.context_len * getDtypeSize(attention_mask_dtype_); } } -template -void KVManager::update_attention_mask( - uint16_t* attention_mask, +void KVManager::update_attention_mask( + std::byte* attention_mask, int32_t ar_len, int32_t n_past, int32_t n_update) { - uint16_t pos_val = 65535; - uint16_t* cur_ptr = attention_mask; - cur_ptr += n_past; + std::byte* cur_ptr = + attention_mask + n_past * getDtypeSize(attention_mask_dtype_); for (int i = 0; i < ar_len; i++) { - std::fill_n(cur_ptr, n_update, pos_val); - cur_ptr += metadata_.context_len; + fill_mask(attention_mask_dtype_, cur_ptr, n_update, /*use_pos_value=*/true); + cur_ptr += metadata_.context_len * getDtypeSize(attention_mask_dtype_); } } -template -void KVManager::update_attention_mask( - uint16_t* attention_mask, +void KVManager::update_attention_mask( + std::byte* attention_mask, int32_t ar_len, int32_t n_past, int32_t n_update, int32_t sliding_window, const std::vector& position_offset) { - uint16_t pos_val = 65535; - uint16_t neg_val = 0; - uint16_t* cur_ptr = attention_mask; - cur_ptr += n_past; + std::byte* cur_ptr = + attention_mask + n_past * getDtypeSize(attention_mask_dtype_); for (int i = 0; i < ar_len; i++) { - std::fill_n(cur_ptr, n_update, pos_val); + fill_mask(attention_mask_dtype_, cur_ptr, n_update, /*use_pos_value=*/true); int32_t available_cache_len = position_offset.empty() ? sliding_window - (i + 1) : sliding_window - (position_offset[i] + 1); if (n_past + n_update > available_cache_len) { - std::fill_n( - cur_ptr - n_past, n_past + n_update - available_cache_len, neg_val); + fill_mask( + attention_mask_dtype_, + cur_ptr - n_past * getDtypeSize(attention_mask_dtype_), + n_past + n_update, + /*use_pos_value=*/false); } - cur_ptr += metadata_.context_len; + cur_ptr += metadata_.context_len * getDtypeSize(attention_mask_dtype_); } } -template -void KVManager::init_cache(IMemAlloc* buffer_manager, int32_t ar_len) { +void KVManager::init_cache(IMemAlloc* buffer_manager, int32_t ar_len) { cur_ar_len_ = ar_len; - const size_t max_in_cache_block_in_bytes = - metadata_.max_cache_len * sizeof(T); - const size_t max_out_cache_block_in_bytes = metadata_.max_ar_len * sizeof(T); - - const size_t cache_in_bytes = - metadata_.num_heads * metadata_.head_dim * max_in_cache_block_in_bytes; - const size_t cache_out_bytes = - metadata_.num_heads * metadata_.head_dim * max_out_cache_block_in_bytes; + const size_t cache_in_bytes = metadata_.num_heads * metadata_.head_dim * + metadata_.max_cache_len * getDtypeSize(kv_cache_dtype_); + const size_t cache_out_bytes = metadata_.num_heads * metadata_.head_dim * + metadata_.max_ar_len * getDtypeSize(kv_cache_dtype_); for (int layer = 0; layer < metadata_.num_layers; ++layer) { - // Allocate buffer for key cache and value cache - T* single_layer_k_cache_in = - reinterpret_cast(buffer_manager->allocate(cache_in_bytes)); - T* single_layer_k_cache_out = - reinterpret_cast(buffer_manager->allocate(cache_out_bytes)); - T* single_layer_v_cache_in = - reinterpret_cast(buffer_manager->allocate(cache_in_bytes)); - T* single_layer_v_cache_out = - reinterpret_cast(buffer_manager->allocate(cache_out_bytes)); - - k_cache_[layer].buffer = single_layer_k_cache_in; - k_cache_[layer].output_buffer = single_layer_k_cache_out; - v_cache_[layer].buffer = single_layer_v_cache_in; - v_cache_[layer].output_buffer = single_layer_v_cache_out; + k_cache_[layer].buffer = buffer_manager->allocate(cache_in_bytes); + k_cache_[layer].output_buffer = buffer_manager->allocate(cache_out_bytes); + v_cache_[layer].buffer = buffer_manager->allocate(cache_in_bytes); + v_cache_[layer].output_buffer = buffer_manager->allocate(cache_out_bytes); } } -template -void KVManager::rearrange_cache(int32_t ar_len_dst) { +void KVManager::rearrange_cache(int32_t ar_len_dst) { // Don't need to rearrange if cur_ar_len_ is equal to target ar_len if (cur_ar_len_ == ar_len_dst) return; @@ -199,75 +293,73 @@ void KVManager::rearrange_cache(int32_t ar_len_dst) { cur_ar_len_ = ar_len_dst; } -template -void KVManager::rearrange_key(KVCache& k_cache, int32_t ar_len_dst) { +void KVManager::rearrange_key(KVCache& k_cache, int32_t ar_len_dst) { const int32_t src_cache_num = (cur_ar_len_ == metadata_.context_len) ? metadata_.context_len : metadata_.context_len - cur_ar_len_; const int32_t dst_cache_num = metadata_.context_len - ar_len_dst; - T* k_cache_in_read_ptr = k_cache.buffer; - T* k_cache_in_write_ptr = k_cache.buffer; - + std::byte* k_cache_in_read_ptr = k_cache.buffer; + std::byte* k_cache_in_write_ptr = k_cache.buffer; + size_t src_cache_nbytes = src_cache_num * getDtypeSize(kv_cache_dtype_); + size_t dst_cache_nbytes = dst_cache_num * getDtypeSize(kv_cache_dtype_); if (src_cache_num > dst_cache_num) { // copy from first dimension for (int i = 0; i < metadata_.head_dim * metadata_.num_heads; i++) { - std::memmove( - k_cache_in_write_ptr, k_cache_in_read_ptr, dst_cache_num * sizeof(T)); - k_cache_in_read_ptr += src_cache_num; - k_cache_in_write_ptr += dst_cache_num; + std::memmove(k_cache_in_write_ptr, k_cache_in_read_ptr, dst_cache_nbytes); + k_cache_in_read_ptr += src_cache_nbytes; + k_cache_in_write_ptr += dst_cache_nbytes; } } else { k_cache_in_read_ptr += - (metadata_.head_dim * metadata_.num_heads - 1) * src_cache_num; + (metadata_.head_dim * metadata_.num_heads - 1) * src_cache_nbytes; k_cache_in_write_ptr += - (metadata_.head_dim * metadata_.num_heads - 1) * dst_cache_num; + (metadata_.head_dim * metadata_.num_heads - 1) * dst_cache_nbytes; // copy from last dimension for (int i = 0; i < metadata_.head_dim * metadata_.num_heads; i++) { - std::memmove( - k_cache_in_write_ptr, k_cache_in_read_ptr, src_cache_num * sizeof(T)); - k_cache_in_read_ptr -= src_cache_num; - k_cache_in_write_ptr -= dst_cache_num; + std::memmove(k_cache_in_write_ptr, k_cache_in_read_ptr, src_cache_nbytes); + k_cache_in_read_ptr -= src_cache_nbytes; + k_cache_in_write_ptr -= dst_cache_nbytes; } } } -template -void KVManager::rearrange_value(KVCache& v_cache, int32_t ar_len_dst) { +void KVManager::rearrange_value(KVCache& v_cache, int32_t ar_len_dst) { const int32_t src_cache_num = (cur_ar_len_ == metadata_.context_len) ? metadata_.context_len : metadata_.context_len - cur_ar_len_; const int32_t dst_cache_num = metadata_.context_len - ar_len_dst; - T* v_cache_in_read_ptr = v_cache.buffer; - T* v_cache_in_write_ptr = v_cache.buffer; + std::byte* v_cache_in_read_ptr = v_cache.buffer; + std::byte* v_cache_in_write_ptr = v_cache.buffer; + size_t src_cache_nbytes = src_cache_num * getDtypeSize(kv_cache_dtype_); + size_t dst_cache_nbytes = dst_cache_num * getDtypeSize(kv_cache_dtype_); if (src_cache_num > dst_cache_num) { // copy from first dimension for (int i = 0; i < metadata_.num_heads; i++) { std::memmove( v_cache_in_write_ptr, v_cache_in_read_ptr, - dst_cache_num * metadata_.head_dim * sizeof(T)); - v_cache_in_read_ptr += src_cache_num * metadata_.head_dim; - v_cache_in_write_ptr += dst_cache_num * metadata_.head_dim; + dst_cache_nbytes * metadata_.head_dim); + v_cache_in_read_ptr += src_cache_nbytes * metadata_.head_dim; + v_cache_in_write_ptr += dst_cache_nbytes * metadata_.head_dim; } } else { v_cache_in_read_ptr += - metadata_.head_dim * (metadata_.num_heads - 1) * src_cache_num; + metadata_.head_dim * (metadata_.num_heads - 1) * src_cache_nbytes; v_cache_in_write_ptr += - metadata_.head_dim * (metadata_.num_heads - 1) * dst_cache_num; + metadata_.head_dim * (metadata_.num_heads - 1) * dst_cache_nbytes; // copy from last dimension for (int i = 0; i < metadata_.num_heads; i++) { std::memmove( v_cache_in_write_ptr, v_cache_in_read_ptr, - src_cache_num * metadata_.head_dim * sizeof(T)); - v_cache_in_read_ptr -= src_cache_num * metadata_.head_dim; - v_cache_in_write_ptr -= dst_cache_num * metadata_.head_dim; + src_cache_nbytes * metadata_.head_dim); + v_cache_in_read_ptr -= src_cache_nbytes * metadata_.head_dim; + v_cache_in_write_ptr -= dst_cache_nbytes * metadata_.head_dim; } } } -template -void KVManager::update_cache( +void KVManager::update_cache( int32_t ar_len, int32_t n_past, int32_t n_update, @@ -283,20 +375,19 @@ void KVManager::update_cache( } } -template -void KVManager::update_key( - KVCache& k_cache, +void KVManager::update_key( + KVCache& k_cache, int32_t n_past, int32_t n_update, const std::vector& selected) { - T* write_ptr = k_cache.buffer; - T* read_ptr = k_cache.output_buffer; - const int32_t copy_size = n_update * sizeof(T); + std::byte* write_ptr = k_cache.buffer; + std::byte* read_ptr = k_cache.output_buffer; + const int32_t copy_size = n_update * getDtypeSize(kv_cache_dtype_); const int32_t iter_size = (cur_ar_len_ == metadata_.context_len) - ? metadata_.context_len - : metadata_.context_len - cur_ar_len_; - const int32_t out_size = cur_ar_len_; - const int32_t past_size = n_past; + ? metadata_.context_len * getDtypeSize(kv_cache_dtype_) + : (metadata_.context_len - cur_ar_len_) * getDtypeSize(kv_cache_dtype_); + const int32_t out_size = cur_ar_len_ * getDtypeSize(kv_cache_dtype_); + const int32_t past_size = n_past * getDtypeSize(kv_cache_dtype_); const int32_t n_iter = metadata_.head_dim * metadata_.num_heads; write_ptr += past_size; @@ -316,7 +407,11 @@ void KVManager::update_key( for (int i = 0; i < n_iter; ++i) { auto wp = write_ptr, rp = read_ptr; for (auto ind : true_indices) { - *wp++ = rp[ind]; + std::memmove( + wp, + rp + ind * getDtypeSize(kv_cache_dtype_), + getDtypeSize(kv_cache_dtype_)); + wp += getDtypeSize(kv_cache_dtype_); } write_ptr += iter_size; read_ptr += out_size; @@ -324,21 +419,25 @@ void KVManager::update_key( } } -template -void KVManager::update_value( - KVCache& v_cache, +void KVManager::update_value( + KVCache& v_cache, int32_t n_past, int32_t n_update, const std::vector& selected) { - T* write_ptr = v_cache.buffer; - T* read_ptr = v_cache.output_buffer; - const int32_t copy_size = n_update * metadata_.head_dim * sizeof(T); - const int32_t past_size = n_past * metadata_.head_dim; + std::byte* write_ptr = v_cache.buffer; + std::byte* read_ptr = v_cache.output_buffer; + const int32_t copy_size = + n_update * metadata_.head_dim * getDtypeSize(kv_cache_dtype_); + const int32_t past_size = + n_past * metadata_.head_dim * getDtypeSize(kv_cache_dtype_); const int32_t n_iter = metadata_.num_heads; const int32_t iter_size = (cur_ar_len_ == metadata_.context_len) - ? metadata_.context_len * metadata_.head_dim - : (metadata_.context_len - cur_ar_len_) * metadata_.head_dim; - const int32_t out_size = cur_ar_len_ * metadata_.head_dim; + ? metadata_.context_len * metadata_.head_dim * + getDtypeSize(kv_cache_dtype_) + : (metadata_.context_len - cur_ar_len_) * metadata_.head_dim * + getDtypeSize(kv_cache_dtype_); + const int32_t out_size = + cur_ar_len_ * metadata_.head_dim * getDtypeSize(kv_cache_dtype_); write_ptr += past_size; @@ -354,13 +453,14 @@ void KVManager::update_value( auto wp = write_ptr, rp = read_ptr; for (auto sel : selected) { if (sel) { - std::memcpy(wp, rp, metadata_.head_dim * sizeof(T)); - wp += metadata_.head_dim; + std::memcpy( + wp, rp, metadata_.head_dim * getDtypeSize(kv_cache_dtype_)); + wp += metadata_.head_dim * getDtypeSize(kv_cache_dtype_); update_times--; if (update_times == 0) break; } - rp += metadata_.head_dim; + rp += metadata_.head_dim * getDtypeSize(kv_cache_dtype_); } write_ptr += iter_size; read_ptr += out_size; @@ -368,8 +468,4 @@ void KVManager::update_value( } } -// Explicit instantiations -template class KVManager; -template class KVManager; - } // namespace example diff --git a/examples/qualcomm/oss_scripts/llama/runner/kv_manager.h b/examples/qualcomm/oss_scripts/llama/runner/kv_manager.h index 06fe88517a7..3b8e67dd38d 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/kv_manager.h +++ b/examples/qualcomm/oss_scripts/llama/runner/kv_manager.h @@ -8,6 +8,7 @@ #pragma once #include +#include #include #include #include @@ -15,17 +16,15 @@ namespace example { // Structure to hold key-value cache buffers -template struct KVCache { - T* buffer; - T* output_buffer; + std::byte* buffer; + std::byte* output_buffer; }; /** * @class KVManager * @brief Class for kv cache update, rearrangement, and buffer allocatation. */ -template class KVManager { public: struct Metadata { @@ -36,7 +35,9 @@ class KVManager { int64_t num_heads; int64_t num_layers; }; - KVManager(Metadata metadata); + KVManager( + Metadata metadata, + std::unique_ptr method_meta); /** * @brief Allocate buffer for KV cache and set the cur_ar_len_. @@ -71,7 +72,7 @@ class KVManager { * @param n_past Number of past elements in the cache. */ void init_attention_mask( - uint16_t* attention_mask, + std::byte* attention_mask, const std::vector& attention_map, int32_t ar_len, int32_t n_past); @@ -98,7 +99,7 @@ class KVManager { * @param position_offset (optional) attention mask position offset of */ void init_attention_mask( - uint16_t* attention_mask, + std::byte* attention_mask, const std::vector& attention_map, int32_t ar_len, int32_t n_past, @@ -114,7 +115,7 @@ class KVManager { * @param n_update Number of elements to be updated. */ void update_attention_mask( - uint16_t* attention_mask, + std::byte* attention_mask, int32_t ar_len, int32_t n_past, int32_t n_update); @@ -132,7 +133,7 @@ class KVManager { * lookahead decoder */ void update_attention_mask( - uint16_t* attention_mask, + std::byte* attention_mask, int32_t ar_len, int32_t n_past, int32_t n_update, @@ -152,10 +153,10 @@ class KVManager { int32_t n_update, const std::vector& selected); - const std::vector>& get_k_cache_() const { + const std::vector& get_k_cache_() const { return k_cache_; } - const std::vector>& get_v_cache_() const { + const std::vector& get_v_cache_() const { return v_cache_; } @@ -169,15 +170,19 @@ class KVManager { private: // Helper functions to rearrange and update key and value caches - void rearrange_key(KVCache& k_cache, int32_t ar_len_dst); - void rearrange_value(KVCache& v_cache, int32_t ar_len_dst); + + void rearrange_key(KVCache& k_cache, int32_t ar_len_dst); + + void rearrange_value(KVCache& v_cache, int32_t ar_len_dst); + void update_key( - KVCache& k_cache, + KVCache& k_cache, int32_t n_past, int32_t n_update, const std::vector& selected); + void update_value( - KVCache& v_cache, + KVCache& v_cache, int32_t n_past, int32_t n_update, const std::vector& selected); @@ -186,10 +191,14 @@ class KVManager { Metadata metadata_; size_t total_cache_size_; int32_t cur_ar_len_; + executorch::aten::ScalarType attention_mask_dtype_ = + executorch::aten::ScalarType::Undefined; + executorch::aten::ScalarType kv_cache_dtype_ = + executorch::aten::ScalarType::Undefined; // Store start pointer of k and v cache for input and output // input: layer -> head * head_dim * max_cache_len // output: layer -> head * head_dim * max_ar_len - std::vector> k_cache_; - std::vector> v_cache_; + std::vector k_cache_; + std::vector v_cache_; }; } // namespace example diff --git a/examples/qualcomm/oss_scripts/llama/runner/lhd_token_generator.cpp b/examples/qualcomm/oss_scripts/llama/runner/lhd_token_generator.cpp index f7e44292f26..298fc1ac9ff 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/lhd_token_generator.cpp +++ b/examples/qualcomm/oss_scripts/llama/runner/lhd_token_generator.cpp @@ -13,20 +13,19 @@ using executorch::runtime::Result; namespace example { -template -void LhdTokenGenerator::prepare_io( +void LhdTokenGenerator::prepare_io( std::vector input_tokens, std::vector input_pos) { for (int i = 0; i < metadata_.ar_len; i++) { if (i < input_tokens.size()) { // Prepare pos data - this->input_pos_.data[i] = input_pos[i]; + reinterpret_cast(this->input_pos_.data)[i] = input_pos[i]; // Support CPU 4-bit embedding, which requires int64 input. // However, for QNN embedding, only int32 input is needed. // Therefore, we need to cast to the correct type to write the data. if (metadata_.use_int64_token) { - this->input_toks_.data[i] = input_tokens[i]; + reinterpret_cast(this->input_toks_.data)[i] = input_tokens[i]; } else { int32_t* input_toks_ptr = reinterpret_cast(this->input_toks_.data); @@ -36,8 +35,7 @@ void LhdTokenGenerator::prepare_io( } } -template -void LhdTokenGenerator::init_attention_mask(int32_t n_past) { +void LhdTokenGenerator::init_attention_mask(int32_t n_past) { std::vector attention_map; attention_map.reserve(metadata_.ar_len); // Initialize attention mask with current position @@ -73,8 +71,7 @@ void LhdTokenGenerator::init_attention_mask(int32_t n_past) { } } -template -void LhdTokenGenerator::init_lookahead_branch( +void LhdTokenGenerator::init_lookahead_branch( const std::vector& tokens) { for (int i = 0; i < metadata_.ngram - 1; ++i) { for (int j = 0; j < metadata_.window; ++j) { @@ -91,8 +88,7 @@ void LhdTokenGenerator::init_lookahead_branch( is_lhd_branch_initialized_ = true; } -template -void LhdTokenGenerator::init_verification_branch(uint64_t cur_token) { +void LhdTokenGenerator::init_verification_branch(uint64_t cur_token) { const int g_cur = ngrams_pool_.cnt[cur_token]; v_branch_.resize(g_cur); @@ -116,8 +112,7 @@ void LhdTokenGenerator::init_verification_branch(uint64_t cur_token) { } } -template -void LhdTokenGenerator::update_ngrams_pool() { +void LhdTokenGenerator::update_ngrams_pool() { std::vector ngram(metadata_.ngram - 1); // n-gram pool generation for (int f = 0; f < metadata_.window; ++f) { @@ -170,8 +165,7 @@ void LhdTokenGenerator::update_ngrams_pool() { } } -template -void LhdTokenGenerator::update_lookahead_branch( +void LhdTokenGenerator::update_lookahead_branch( const executorch::aten::Tensor& logits_tensor) { for (int i = 0; i < metadata_.window; i++) { lhd_branch_prev_[i] = lhd_branch_[0][i]; @@ -189,8 +183,7 @@ void LhdTokenGenerator::update_lookahead_branch( } } -template -Result LhdTokenGenerator::generate( +Result LhdTokenGenerator::generate( std::vector tokens, int64_t start_pos, int32_t seq_len, @@ -427,8 +420,4 @@ Result LhdTokenGenerator::generate( return pos - start_pos; } -// Explicit instantiations -template class LhdTokenGenerator; -template class LhdTokenGenerator; - } // namespace example diff --git a/examples/qualcomm/oss_scripts/llama/runner/lhd_token_generator.h b/examples/qualcomm/oss_scripts/llama/runner/lhd_token_generator.h index 796dde88014..8fdffb8af72 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/lhd_token_generator.h +++ b/examples/qualcomm/oss_scripts/llama/runner/lhd_token_generator.h @@ -15,8 +15,8 @@ namespace example { * @brief Class for generating the token using decoder and key-value manager * with lookahead decoding. */ -template -class LhdTokenGenerator : public TokenGenerator { + +class LhdTokenGenerator : public TokenGenerator { public: struct Metadata { int32_t context_len; @@ -34,18 +34,19 @@ class LhdTokenGenerator : public TokenGenerator { LhdTokenGenerator( tokenizers::Tokenizer* tokenizer, DecoderRunner* decoder_runner, - KVManager* kv_manager, + KVManager* kv_manager, const std::string& forward_name, std::unique_ptr>&& eos_ids, Metadata metadata, - executorch::llm::Stats* stats) - : TokenGenerator( + executorch::llm::Stats* stats, + std::unique_ptr method_meta) + : TokenGenerator( tokenizer, decoder_runner, kv_manager, forward_name, std::move(eos_ids), - typename TokenGenerator::Metadata{ + TokenGenerator::Metadata{ metadata.context_len, metadata.num_heads, metadata.num_layers, @@ -54,7 +55,8 @@ class LhdTokenGenerator : public TokenGenerator { metadata.use_int64_token, metadata.sliding_window, metadata.cache_mode}, - stats), + stats, + std::move(method_meta)), metadata_(metadata), lhd_branch_(metadata.ngram - 1, std::vector(metadata.window)), lhd_branch_prev_(metadata.window), @@ -104,7 +106,7 @@ class LhdTokenGenerator : public TokenGenerator { private: // Bring base class's virtual prepare_io into scope so the overload below // does not hide it (-Woverloaded-virtual). - using TokenGenerator::prepare_io; + using TokenGenerator::prepare_io; /** * @brief Fill in I/O buffers with prompt token and position. * @param cur_token Current token. diff --git a/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_lhd_token_generator.cpp b/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_lhd_token_generator.cpp index 14a93104e1a..de8d1bea0fe 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_lhd_token_generator.cpp +++ b/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_lhd_token_generator.cpp @@ -13,8 +13,7 @@ using executorch::runtime::Result; namespace example { -template -void MultimodalLhdTokenGenerator::prepare_io( +void MultimodalLhdTokenGenerator::prepare_io( std::vector input_tokens, std::vector input_pos) { for (int i = 0; i < metadata_.ar_len; i++) { @@ -51,8 +50,7 @@ void MultimodalLhdTokenGenerator::prepare_io( } } -template -void MultimodalLhdTokenGenerator::init_attention_mask(int32_t n_past) { +void MultimodalLhdTokenGenerator::init_attention_mask(int32_t n_past) { std::vector attention_map; attention_map.reserve(metadata_.ar_len); // Initialize attention mask with current position @@ -88,8 +86,7 @@ void MultimodalLhdTokenGenerator::init_attention_mask(int32_t n_past) { } } -template -void MultimodalLhdTokenGenerator::init_lookahead_branch( +void MultimodalLhdTokenGenerator::init_lookahead_branch( const std::vector& tokens) { for (int i = 0; i < metadata_.ngram - 1; ++i) { for (int j = 0; j < metadata_.window; ++j) { @@ -106,9 +103,7 @@ void MultimodalLhdTokenGenerator::init_lookahead_branch( is_lhd_branch_initialized_ = true; } -template -void MultimodalLhdTokenGenerator::init_verification_branch( - uint64_t cur_token) { +void MultimodalLhdTokenGenerator::init_verification_branch(uint64_t cur_token) { const int g_cur = ngrams_pool_.cnt[cur_token]; v_branch_.resize(g_cur); @@ -132,8 +127,7 @@ void MultimodalLhdTokenGenerator::init_verification_branch( } } -template -void MultimodalLhdTokenGenerator::update_ngrams_pool() { +void MultimodalLhdTokenGenerator::update_ngrams_pool() { std::vector ngram(metadata_.ngram - 1); // n-gram pool generation for (int f = 0; f < metadata_.window; ++f) { @@ -186,8 +180,7 @@ void MultimodalLhdTokenGenerator::update_ngrams_pool() { } } -template -void MultimodalLhdTokenGenerator::update_lookahead_branch( +void MultimodalLhdTokenGenerator::update_lookahead_branch( const executorch::aten::Tensor& logits_tensor) { for (int i = 0; i < metadata_.window; i++) { lhd_branch_prev_[i] = lhd_branch_[0][i]; @@ -205,8 +198,7 @@ void MultimodalLhdTokenGenerator::update_lookahead_branch( } } -template -Result MultimodalLhdTokenGenerator::generate( +Result MultimodalLhdTokenGenerator::generate( std::vector tokens, int64_t start_pos, int32_t seq_len, @@ -412,8 +404,4 @@ Result MultimodalLhdTokenGenerator::generate( return pos - start_pos; } -// Explicit instantiations -template class MultimodalLhdTokenGenerator; -template class MultimodalLhdTokenGenerator; - } // namespace example diff --git a/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_lhd_token_generator.h b/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_lhd_token_generator.h index 7494afec6da..6ffe285e536 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_lhd_token_generator.h +++ b/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_lhd_token_generator.h @@ -15,9 +15,7 @@ namespace example { * @class MultimodalLhdTokenGenerator * @brief Extended LhdTokenGenerator with multimodal embedding support */ -template -class MultimodalLhdTokenGenerator - : public example::MultimodalTokenGenerator { +class MultimodalLhdTokenGenerator : public example::MultimodalTokenGenerator { public: struct Metadata { int32_t context_len; @@ -37,19 +35,20 @@ class MultimodalLhdTokenGenerator tokenizers::Tokenizer* tokenizer, TokenEmbeddingProcessor* embedding_runner, DecoderRunner* decoder_runner, - KVManager* kv_manager, + KVManager* kv_manager, const std::string& forward_name, std::unique_ptr>&& eos_ids, Metadata metadata, - executorch::llm::Stats* stats) - : MultimodalTokenGenerator( + executorch::llm::Stats* stats, + std::unique_ptr method_meta) + : MultimodalTokenGenerator( tokenizer, embedding_runner, decoder_runner, kv_manager, forward_name, std::move(eos_ids), - typename MultimodalTokenGenerator::Metadata{ + MultimodalTokenGenerator::Metadata{ metadata.context_len, metadata.num_heads, metadata.num_layers, @@ -59,7 +58,8 @@ class MultimodalLhdTokenGenerator metadata.sliding_window, metadata.cache_mode, metadata.embedding_dim}, - stats), + stats, + std::move(method_meta)), tok_embedding_runner_(embedding_runner), metadata_(metadata), lhd_branch_(metadata.ngram - 1, std::vector(metadata.window)), @@ -110,7 +110,7 @@ class MultimodalLhdTokenGenerator private: // Bring base class's virtual prepare_io into scope so the overload below // does not hide it (-Woverloaded-virtual). - using TokenGenerator::prepare_io; + using TokenGenerator::prepare_io; /** * @brief Fill in I/O buffers with prompt token and position. * @param cur_token Current token. diff --git a/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_prompt_processor.cpp b/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_prompt_processor.cpp index 2859e16a42a..f63a431791b 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_prompt_processor.cpp +++ b/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_prompt_processor.cpp @@ -16,13 +16,13 @@ using executorch::runtime::TensorInfo; namespace example { -template -MultimodalPromptProcessor::MultimodalPromptProcessor( +MultimodalPromptProcessor::MultimodalPromptProcessor( DecoderRunner* decoder_runner, - KVManager* kv_manager, + KVManager* kv_manager, const std::string& method_name, - Metadata metadata) - : PromptProcessor( + Metadata metadata, + std::unique_ptr method_meta) + : PromptProcessor( decoder_runner, kv_manager, method_name, @@ -33,7 +33,8 @@ MultimodalPromptProcessor::MultimodalPromptProcessor( metadata.vocab_size, metadata.use_int64_token, metadata.sliding_window, - metadata.cache_mode}), + metadata.cache_mode}, + std::move(method_meta)), metadata_(metadata) { // Set input_toks_.size to 0 since we use embeddings instead input_toks_.size = 0; @@ -41,8 +42,7 @@ MultimodalPromptProcessor::MultimodalPromptProcessor( metadata_.ar_len * metadata_.embedding_dim * sizeof(float); }; -template -void MultimodalPromptProcessor::init_io( +void MultimodalPromptProcessor::init_io( IMemAlloc* buffer_manager, Result method_meta) { size_t idx = 0; @@ -66,8 +66,7 @@ void MultimodalPromptProcessor::init_io( // [I]: attention_mask Result attention_mask = method_meta->input_tensor_meta(idx++); - attention_mask_.data = reinterpret_cast( - buffer_manager->allocate(attention_mask_.size)); + attention_mask_.data = buffer_manager->allocate(attention_mask_.size); attention_mask_.tensor = std::make_unique( attention_mask->scalar_type(), attention_mask->sizes().size(), @@ -83,8 +82,8 @@ void MultimodalPromptProcessor::init_io( if (metadata_.cache_mode == CacheMode::HybridCache) { Result window_attention_mask = method_meta->input_tensor_meta(idx++); - window_attention_mask_.data = reinterpret_cast( - buffer_manager->allocate(window_attention_mask_.size)); + window_attention_mask_.data = + buffer_manager->allocate(window_attention_mask_.size); window_attention_mask_.tensor = std::make_unique( window_attention_mask->scalar_type(), window_attention_mask->sizes().size(), @@ -120,32 +119,29 @@ void MultimodalPromptProcessor::init_io( for (int cache_group = 0; cache_group < 2; ++cache_group) { std::vector>& cache = (cache_group == 0 ? k_cache_in_ : v_cache_in_); - std::vector> cache_ptrs = (cache_group == 0) + std::vector cache_ptrs = (cache_group == 0) ? kv_manager_->get_k_cache_() : kv_manager_->get_v_cache_(); for (int layer = 0; layer < metadata_.num_layers; ++layer, ++index) { Result kv_cache = method_meta->input_tensor_meta(index); - T* cache_ptr = cache_ptrs[layer].buffer; - cache[layer] = std::make_unique( kv_cache->scalar_type(), kv_cache->sizes().size(), const_cast(kv_cache->sizes().data()), - cache_ptr, + cache_ptrs[layer].buffer, const_cast( kv_cache->dim_order().data())); input_tensors_.emplace_back(cache[layer].get()); buffer_manager->add_memory_info( - cache_ptr, cache[layer]->nbytes(), kv_cache.get()); + cache_ptrs[layer].buffer, cache[layer]->nbytes(), kv_cache.get()); } } } // [O]: logits Result logits = method_meta->output_tensor_meta(0); - logits_.data = - reinterpret_cast(buffer_manager->allocate(logits_.size)); + logits_.data = buffer_manager->allocate(logits_.size); logits_.tensor = std::make_unique( logits->scalar_type(), logits->sizes().size(), @@ -160,21 +156,22 @@ void MultimodalPromptProcessor::init_io( for (int cache_group = 0; cache_group < 2; ++cache_group) { std::vector>& cache = (cache_group == 0 ? k_cache_out_ : v_cache_out_); - std::vector> cache_ptrs = (cache_group == 0) + std::vector cache_ptrs = (cache_group == 0) ? kv_manager_->get_k_cache_() : kv_manager_->get_v_cache_(); for (int layer = 0; layer < metadata_.num_layers; ++layer, ++index) { Result kv_cache = method_meta->output_tensor_meta(index); - T* cache_ptr = cache_ptrs[layer].output_buffer; cache[layer] = std::make_unique( kv_cache->scalar_type(), kv_cache->sizes().size(), const_cast(kv_cache->sizes().data()), - cache_ptr, + cache_ptrs[layer].output_buffer, const_cast(kv_cache->dim_order().data())); output_tensors_.emplace_back(cache[layer].get()); buffer_manager->add_memory_info( - cache_ptr, cache[layer]->nbytes(), kv_cache.get()); + cache_ptrs[layer].output_buffer, + cache[layer]->nbytes(), + kv_cache.get()); } } @@ -186,8 +183,7 @@ void MultimodalPromptProcessor::init_io( } // prepare embedding -template -void MultimodalPromptProcessor::prepare_io( +void MultimodalPromptProcessor::prepare_io( const TensorStruct& prompt_embedding, int32_t num_prompt_tokens, int64_t prompt_pos, @@ -208,8 +204,7 @@ void MultimodalPromptProcessor::prepare_io( } } -template -Result MultimodalPromptProcessor::prefill( +Result MultimodalPromptProcessor::prefill( const TensorStruct& prompt_embedding, int64_t start_pos, bool dump_logits, @@ -301,8 +296,4 @@ Result MultimodalPromptProcessor::prefill( return cur_token; } -// Explicit instantiations -template class MultimodalPromptProcessor; -template class MultimodalPromptProcessor; - } // namespace example diff --git a/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_prompt_processor.h b/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_prompt_processor.h index fcfc07c9590..c2769ed9f50 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_prompt_processor.h +++ b/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_prompt_processor.h @@ -16,8 +16,7 @@ namespace example { * @class MultimodalPromptProcessor * @brief Extended PromptProcessor with multimodal embedding support */ -template -class MultimodalPromptProcessor : public example::PromptProcessor { +class MultimodalPromptProcessor : public example::PromptProcessor { public: struct Metadata { int32_t context_len; @@ -33,9 +32,10 @@ class MultimodalPromptProcessor : public example::PromptProcessor { MultimodalPromptProcessor( DecoderRunner* decoder_runner, - KVManager* kv_manager, + KVManager* kv_manager, const std::string& method_name, - Metadata metadata); + Metadata metadata, + std::unique_ptr method_meta); int64_t get_num_heads() const { return metadata_.num_heads; @@ -74,34 +74,29 @@ class MultimodalPromptProcessor : public example::PromptProcessor { * @return Total I/O size in bytes. */ inline const size_t total_prompt_processor_io_size_in_bytes() const { - if (metadata_.cache_mode == CacheMode::HybridCache) { - return input_toks_.size + input_pos_.size + attention_mask_.size + - window_attention_mask_.size + logits_.size + input_embedding_.size; - } else { - return input_toks_.size + input_pos_.size + attention_mask_.size + - logits_.size + input_embedding_.size; - } + return input_toks_.size + input_pos_.size + attention_mask_.size + + window_attention_mask_.size + logits_.size + input_embedding_.size; } private: // Reuse members from token_generator - using PromptProcessor::decoder_runner_; - using PromptProcessor::kv_manager_; - using PromptProcessor::method_name_; - using PromptProcessor::k_cache_in_; - using PromptProcessor::v_cache_in_; - using PromptProcessor::k_cache_out_; - using PromptProcessor::v_cache_out_; - using PromptProcessor::input_toks_; - using PromptProcessor::input_pos_; - using PromptProcessor::attention_mask_; - using PromptProcessor::window_attention_mask_; - using PromptProcessor::logits_; - using PromptProcessor::inputs_; - using PromptProcessor::input_tensors_; - using PromptProcessor::output_tensors_; - using PromptProcessor::prompt_all_logits_; - using PromptProcessor::is_bert; + using PromptProcessor::attention_mask_; + using PromptProcessor::decoder_runner_; + using PromptProcessor::input_pos_; + using PromptProcessor::input_tensors_; + using PromptProcessor::input_toks_; + using PromptProcessor::inputs_; + using PromptProcessor::is_bert; + using PromptProcessor::k_cache_in_; + using PromptProcessor::k_cache_out_; + using PromptProcessor::kv_manager_; + using PromptProcessor::logits_; + using PromptProcessor::method_name_; + using PromptProcessor::output_tensors_; + using PromptProcessor::prompt_all_logits_; + using PromptProcessor::v_cache_in_; + using PromptProcessor::v_cache_out_; + using PromptProcessor::window_attention_mask_; /** * @brief Fill in I/O buffers with embedding data and position. diff --git a/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_runner.cpp b/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_runner.cpp index 32e3baf27a9..32575994222 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_runner.cpp +++ b/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_runner.cpp @@ -74,17 +74,17 @@ void print_performance_report( void save_logits( const std::string& dump_logits_path, - const std::vector& prefill_logits, - const std::vector& decode_logits) { + const std::vector& prefill_logits, + const std::vector& decode_logits) { std::ofstream outFile(dump_logits_path.c_str(), std::ios::binary); if (outFile.is_open()) { outFile.write( reinterpret_cast(prefill_logits.data()), - prefill_logits.size() * sizeof(uint16_t)); + prefill_logits.size()); outFile.write( reinterpret_cast(decode_logits.data()), - decode_logits.size() * sizeof(uint16_t)); + decode_logits.size()); outFile.close(); } else { ET_CHECK_MSG(false, "Error saving the dump logits file"); @@ -93,8 +93,7 @@ void save_logits( } // namespace -template -QNNMultimodalRunner::QNNMultimodalRunner( +QNNMultimodalRunner::QNNMultimodalRunner( std::unique_ptr encoder, std::unique_ptr tok_embedding, std::unique_ptr text_decoder, @@ -148,16 +147,14 @@ QNNMultimodalRunner::QNNMultimodalRunner( ET_LOG(Info, "eval mode=%d", eval_mode_); } -template -bool QNNMultimodalRunner::is_loaded() const { +bool QNNMultimodalRunner::is_loaded() const { return encoder_->is_loaded() && tok_embedding_->is_loaded() && text_decoder_->is_loaded() && embedding_merger_ && tokenizer_ && decoder_runner_ && prompt_processor_ && token_generator_ && kv_manager_ && buffer_manager_; } -template -Error QNNMultimodalRunner::load() { +Error QNNMultimodalRunner::load() { if (is_loaded()) { return Error::Ok; } @@ -298,19 +295,22 @@ Error QNNMultimodalRunner::load() { sliding_window = ET_UNWRAP(text_decoder_->get("get_sliding_window")).toInt(); } - kv_manager_ = std::make_unique>(typename KVManager::Metadata{ - context_len_, - head_dim, - max_ar_len, - max_cache_len, - num_heads, - num_layers}); - - prompt_processor_ = std::make_unique>( + kv_manager_ = std::make_unique( + KVManager::Metadata{ + context_len_, + head_dim, + max_ar_len, + max_cache_len, + num_heads, + num_layers}, + std::make_unique(std::move( + text_decoder_->method_meta(token_generator_method_name).get()))); + + prompt_processor_ = std::make_unique( decoder_runner_.get(), kv_manager_.get(), prompt_processor_method_name, - typename MultimodalPromptProcessor::Metadata{ + MultimodalPromptProcessor::Metadata{ context_len_, num_heads, num_layers, @@ -319,7 +319,9 @@ Error QNNMultimodalRunner::load() { use_int64_token, sliding_window, cache_mode_, - static_cast(dim)}); + static_cast(dim)}, + std::make_unique(std::move( + text_decoder_->method_meta(prompt_processor_method_name).get()))); // Initialize EmbeddingGenerator tok_embedding_generator_ = std::make_unique( @@ -333,14 +335,14 @@ Error QNNMultimodalRunner::load() { static_cast(dim)}); if (eval_mode_ == EvalMode::kLookaheadDecoding) { // Initialize TokenGenerator - token_generator_ = std::make_unique>( + token_generator_ = std::make_unique( tokenizer_.get(), tok_embedding_generator_.get(), decoder_runner_.get(), kv_manager_.get(), token_generator_method_name, std::move(eos_ids), - typename MultimodalLhdTokenGenerator::Metadata{ + MultimodalLhdTokenGenerator::Metadata{ context_len_, num_heads, num_layers, @@ -353,16 +355,18 @@ Error QNNMultimodalRunner::load() { sliding_window, cache_mode_, static_cast(dim)}, - &stats_); + &stats_, + std::make_unique(std::move( + text_decoder_->method_meta(token_generator_method_name).get()))); } else { - token_generator_ = std::make_unique>( + token_generator_ = std::make_unique( tokenizer_.get(), tok_embedding_generator_.get(), decoder_runner_.get(), kv_manager_.get(), token_generator_method_name, std::move(eos_ids), - typename MultimodalTokenGenerator::Metadata{ + MultimodalTokenGenerator::Metadata{ context_len_, num_heads, num_layers, @@ -372,7 +376,9 @@ Error QNNMultimodalRunner::load() { sliding_window, cache_mode_, static_cast(dim)}, - &stats_); + &stats_, + std::make_unique(std::move( + text_decoder_->method_meta(token_generator_method_name).get()))); } buffer_manager_ = std::make_unique(); @@ -409,8 +415,7 @@ Error QNNMultimodalRunner::load() { return Error::Ok; } -template -executorch::runtime::Error QNNMultimodalRunner::generate( +executorch::runtime::Error QNNMultimodalRunner::generate( const std::vector& inputs, const llm::GenerationConfig& config, std::function token_callback, @@ -561,8 +566,7 @@ executorch::runtime::Error QNNMultimodalRunner::generate( return Error::Ok; } -template -Result QNNMultimodalRunner::get_model_version() { +Result QNNMultimodalRunner::get_model_version() { if (!is_loaded()) { stats_.model_load_start_ms = time_in_ms(); ET_CHECK_OK_OR_RETURN_ERROR(load()); @@ -571,16 +575,11 @@ Result QNNMultimodalRunner::get_model_version() { return model_version_; } -template -Result QNNMultimodalRunner::get_encoder_method_meta() { +Result QNNMultimodalRunner::get_encoder_method_meta() { if (!is_loaded()) { ET_CHECK_OK_OR_RETURN_ERROR(load()); } return encoder_->method_meta(kEncoderForwardName); } -// Explicit instantiations -template class QNNMultimodalRunner; -template class QNNMultimodalRunner; - } // namespace example diff --git a/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_runner.h b/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_runner.h index 5407d5712b7..363ded0f055 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_runner.h +++ b/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_runner.h @@ -66,12 +66,6 @@ inline Modality modality_of(const ModelVersion& model_version) { [](const auto& model) { return modality_of(model); }, model_version); } -enum KvBitWidth { - kWidth8 = 8, - kWidth16 = 16, -}; - -template class QNNMultimodalRunner : public executorch::extension::llm::MultimodalRunner { public: @@ -139,11 +133,11 @@ class QNNMultimodalRunner ModelVersion model_version_; std::unique_ptr buffer_manager_; - std::unique_ptr> kv_manager_; + std::unique_ptr kv_manager_; std::unique_ptr tokenizer_; std::unique_ptr decoder_runner_; - std::unique_ptr> prompt_processor_; - std::unique_ptr> token_generator_; + std::unique_ptr prompt_processor_; + std::unique_ptr token_generator_; std::unique_ptr encoder_runner_; std::unique_ptr tok_embedding_runner_; std::unique_ptr tok_embedding_processor_; diff --git a/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_token_generator.cpp b/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_token_generator.cpp index 2ed8ae51f1d..e3f6f8e214e 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_token_generator.cpp +++ b/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_token_generator.cpp @@ -15,17 +15,17 @@ using executorch::runtime::TensorInfo; namespace example { // Constructor with embedding runner support -template -MultimodalTokenGenerator::MultimodalTokenGenerator( +MultimodalTokenGenerator::MultimodalTokenGenerator( tokenizers::Tokenizer* tokenizer, TokenEmbeddingProcessor* tok_embedding_runner, DecoderRunner* decoder_runner, - KVManager* kv_manager, + KVManager* kv_manager, const std::string& method_name, std::unique_ptr>&& eos_ids, Metadata metadata, - executorch::llm::Stats* stats) - : TokenGenerator( + executorch::llm::Stats* stats, + std::unique_ptr method_meta) + : TokenGenerator( tokenizer, decoder_runner, kv_manager, @@ -39,7 +39,8 @@ MultimodalTokenGenerator::MultimodalTokenGenerator( metadata.use_int64_token, metadata.sliding_window, metadata.cache_mode}, - stats), + stats, + std::move(method_meta)), tok_embedding_runner_(tok_embedding_runner), metadata_(metadata) { // Set input_toks_.size to 0 since we use embeddings instead @@ -48,8 +49,7 @@ MultimodalTokenGenerator::MultimodalTokenGenerator( metadata_.ar_len * metadata_.embedding_dim * sizeof(float); } -template -void MultimodalTokenGenerator::init_io( +void MultimodalTokenGenerator::init_io( IMemAlloc* buffer_manager, Result method_meta) { size_t idx = 0; @@ -73,8 +73,7 @@ void MultimodalTokenGenerator::init_io( // [I]: attention_mask Result attention_mask = method_meta->input_tensor_meta(idx++); - attention_mask_.data = reinterpret_cast( - buffer_manager->allocate(attention_mask_.size)); + attention_mask_.data = buffer_manager->allocate(attention_mask_.size); attention_mask_.tensor = std::make_unique( attention_mask->scalar_type(), attention_mask->sizes().size(), @@ -90,8 +89,8 @@ void MultimodalTokenGenerator::init_io( if (metadata_.cache_mode == CacheMode::HybridCache) { Result window_attention_mask = method_meta->input_tensor_meta(idx++); - window_attention_mask_.data = reinterpret_cast( - buffer_manager->allocate(window_attention_mask_.size)); + window_attention_mask_.data = + buffer_manager->allocate(window_attention_mask_.size); window_attention_mask_.tensor = std::make_unique( window_attention_mask->scalar_type(), window_attention_mask->sizes().size(), @@ -126,30 +125,27 @@ void MultimodalTokenGenerator::init_io( for (int cache_group = 0; cache_group < 2; ++cache_group) { std::vector>& cache = (cache_group == 0 ? k_cache_in_ : v_cache_in_); - std::vector> cache_ptrs = (cache_group == 0) + std::vector cache_ptrs = (cache_group == 0) ? kv_manager_->get_k_cache_() : kv_manager_->get_v_cache_(); for (int layer = 0; layer < metadata_.num_layers; ++layer, ++index) { Result kv_cache = method_meta->input_tensor_meta(index); - T* cache_ptr = cache_ptrs[layer].buffer; - cache[layer] = std::make_unique( kv_cache->scalar_type(), kv_cache->sizes().size(), const_cast(kv_cache->sizes().data()), - cache_ptr, + cache_ptrs[layer].buffer, const_cast(kv_cache->dim_order().data())); input_tensors_.emplace_back(cache[layer].get()); buffer_manager->add_memory_info( - cache_ptr, cache[layer]->nbytes(), kv_cache.get()); + cache_ptrs[layer].buffer, cache[layer]->nbytes(), kv_cache.get()); } } // [O]: logits Result logits = method_meta->output_tensor_meta(0); - logits_.data = - reinterpret_cast(buffer_manager->allocate(logits_.size)); + logits_.data = buffer_manager->allocate(logits_.size); logits_.tensor = std::make_unique( logits->scalar_type(), logits->sizes().size(), @@ -164,21 +160,22 @@ void MultimodalTokenGenerator::init_io( for (int cache_group = 0; cache_group < 2; ++cache_group) { std::vector>& cache = (cache_group == 0 ? k_cache_out_ : v_cache_out_); - std::vector> cache_ptrs = (cache_group == 0) + std::vector cache_ptrs = (cache_group == 0) ? kv_manager_->get_k_cache_() : kv_manager_->get_v_cache_(); for (int layer = 0; layer < metadata_.num_layers; ++layer, ++index) { Result kv_cache = method_meta->output_tensor_meta(index); - T* cache_ptr = cache_ptrs[layer].output_buffer; cache[layer] = std::make_unique( kv_cache->scalar_type(), kv_cache->sizes().size(), const_cast(kv_cache->sizes().data()), - cache_ptr, + cache_ptrs[layer].output_buffer, const_cast(kv_cache->dim_order().data())); output_tensors_.emplace_back(cache[layer].get()); buffer_manager->add_memory_info( - cache_ptr, cache[layer]->nbytes(), kv_cache.get()); + cache_ptrs[layer].output_buffer, + cache[layer]->nbytes(), + kv_cache.get()); } } @@ -190,8 +187,7 @@ void MultimodalTokenGenerator::init_io( } // This function only considers the case where token_generator_ar_len equals 1. -template -void MultimodalTokenGenerator::prepare_io( +void MultimodalTokenGenerator::prepare_io( uint64_t cur_token, int64_t start_pos) { // Generate embedding for current token using embedding runner @@ -209,8 +205,4 @@ void MultimodalTokenGenerator::prepare_io( *input_pos_.data = static_cast(start_pos); } -// Explicit instantiations -template class MultimodalTokenGenerator; -template class MultimodalTokenGenerator; - } // namespace example diff --git a/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_token_generator.h b/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_token_generator.h index 9eb9c79aaa4..2d0bf9385b4 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_token_generator.h +++ b/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_token_generator.h @@ -16,8 +16,7 @@ namespace example { * @class MultimodalTokenGenerator * @brief Extended TokenGenerator with multimodal embedding support */ -template -class MultimodalTokenGenerator : public example::TokenGenerator { +class MultimodalTokenGenerator : public example::TokenGenerator { public: struct Metadata { int32_t context_len; @@ -36,11 +35,12 @@ class MultimodalTokenGenerator : public example::TokenGenerator { tokenizers::Tokenizer* tokenizer, TokenEmbeddingProcessor* tok_embedding_runner, DecoderRunner* decoder_runner, - KVManager* kv_manager, + KVManager* kv_manager, const std::string& method_name, std::unique_ptr>&& eos_ids, Metadata metadata, - executorch::llm::Stats* stats); + executorch::llm::Stats* stats, + std::unique_ptr method_meta); virtual ~MultimodalTokenGenerator() = default; @@ -54,36 +54,31 @@ class MultimodalTokenGenerator : public example::TokenGenerator { override; inline const size_t total_token_generator_io_size_in_bytes() const { - if (metadata_.cache_mode == CacheMode::HybridCache) { - return input_toks_.size + input_pos_.size + attention_mask_.size + - window_attention_mask_.size + logits_.size + input_embedding_.size; - } else { - return input_toks_.size + input_pos_.size + attention_mask_.size + - logits_.size + input_embedding_.size; - } + return input_toks_.size + input_pos_.size + attention_mask_.size + + window_attention_mask_.size + logits_.size + input_embedding_.size; } protected: // Reuse members from token_generator - using TokenGenerator::kv_manager_; - using TokenGenerator::input_pos_; - using TokenGenerator::attention_mask_; - using TokenGenerator::window_attention_mask_; - using TokenGenerator::inputs_; - using TokenGenerator::input_tensors_; - using TokenGenerator::output_tensors_; + using TokenGenerator::attention_mask_; + using TokenGenerator::input_pos_; + using TokenGenerator::input_tensors_; + using TokenGenerator::inputs_; + using TokenGenerator::kv_manager_; + using TokenGenerator::output_tensors_; + using TokenGenerator::window_attention_mask_; // Additional members specific to multimodal TensorStruct input_embedding_; private: // Reuse members from token_generator - using TokenGenerator::input_toks_; - using TokenGenerator::logits_; - using TokenGenerator::k_cache_in_; - using TokenGenerator::v_cache_in_; - using TokenGenerator::k_cache_out_; - using TokenGenerator::v_cache_out_; + using TokenGenerator::input_toks_; + using TokenGenerator::k_cache_in_; + using TokenGenerator::k_cache_out_; + using TokenGenerator::logits_; + using TokenGenerator::v_cache_in_; + using TokenGenerator::v_cache_out_; // Additional members specific to multimodal TokenEmbeddingProcessor* tok_embedding_runner_; diff --git a/examples/qualcomm/oss_scripts/llama/runner/prompt_processor.cpp b/examples/qualcomm/oss_scripts/llama/runner/prompt_processor.cpp index 59744d488bd..0cb52246a39 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/prompt_processor.cpp +++ b/examples/qualcomm/oss_scripts/llama/runner/prompt_processor.cpp @@ -17,12 +17,12 @@ using executorch::runtime::Span; using executorch::runtime::TensorInfo; namespace example { -template -PromptProcessor::PromptProcessor( +PromptProcessor::PromptProcessor( DecoderRunner* decoder_runner, - KVManager* kv_manager, + KVManager* kv_manager, const std::string& method_name, - Metadata metadata) + Metadata metadata, + std::unique_ptr method_meta) : decoder_runner_(decoder_runner), kv_manager_(kv_manager), method_name_(method_name), @@ -32,33 +32,41 @@ PromptProcessor::PromptProcessor( k_cache_out_.resize(metadata_.num_layers); v_cache_out_.resize(metadata_.num_layers); // Calculate I/O size + Result attention_mask = method_meta->input_tensor_meta(1); + Result logits = method_meta->output_tensor_meta(0); input_toks_.size = metadata_.ar_len * sizeof(int64_t); - if (is_bert()) + if (is_bert()) { input_pos_.size = 0; - else + } else { input_pos_.size = metadata_.ar_len * sizeof(int32_t); + } + attention_mask_.dtype = attention_mask->scalar_type(); + attention_mask_.size = metadata_.ar_len * metadata_.context_len * + attention_mask_.getElementSize(); switch (metadata_.cache_mode) { case CacheMode::StaticCahce: - attention_mask_.size = - metadata_.ar_len * metadata_.context_len * sizeof(uint16_t); window_attention_mask_.size = 0; break; - case CacheMode::HybridCache: - attention_mask_.size = - metadata_.ar_len * metadata_.context_len * sizeof(uint16_t); - window_attention_mask_.size = - metadata_.ar_len * metadata_.context_len * sizeof(uint16_t); + case CacheMode::HybridCache: { + Result window_attention_mask = + method_meta->input_tensor_meta(2); + window_attention_mask_.dtype = window_attention_mask->scalar_type(); + window_attention_mask_.size = metadata_.ar_len * metadata_.context_len * + window_attention_mask_.getElementSize(); break; + } default: ET_CHECK_MSG(false, "Unsupported llama cache mode"); break; } - logits_.size = metadata_.ar_len * metadata_.vocab_size * sizeof(uint16_t); + logits_.dtype = logits->scalar_type(); + logits_.size = + metadata_.ar_len * metadata_.vocab_size * logits_.getElementSize(); }; -template -void PromptProcessor::init_io( + +void PromptProcessor::init_io( IMemAlloc* buffer_manager, Result method_meta) { size_t idx = 0; @@ -80,8 +88,7 @@ void PromptProcessor::init_io( // [I]: attention_mask Result attention_mask = method_meta->input_tensor_meta(idx++); - attention_mask_.data = reinterpret_cast( - buffer_manager->allocate(attention_mask_.size)); + attention_mask_.data = buffer_manager->allocate(attention_mask_.size); attention_mask_.tensor = std::make_unique( attention_mask->scalar_type(), attention_mask->sizes().size(), @@ -97,8 +104,8 @@ void PromptProcessor::init_io( if (metadata_.cache_mode == CacheMode::HybridCache) { Result window_attention_mask = method_meta->input_tensor_meta(idx++); - window_attention_mask_.data = reinterpret_cast( - buffer_manager->allocate(window_attention_mask_.size)); + window_attention_mask_.data = + buffer_manager->allocate(window_attention_mask_.size); window_attention_mask_.tensor = std::make_unique( window_attention_mask->scalar_type(), window_attention_mask->sizes().size(), @@ -136,33 +143,30 @@ void PromptProcessor::init_io( for (int cache_group = 0; cache_group < 2; ++cache_group) { std::vector>& cache = (cache_group == 0 ? k_cache_in_ : v_cache_in_); - std::vector> cache_ptrs = (cache_group == 0) + std::vector cache_ptrs = (cache_group == 0) ? kv_manager_->get_k_cache_() : kv_manager_->get_v_cache_(); for (int layer = 0; layer < metadata_.num_layers; ++layer, ++index) { Result kv_cache = method_meta->input_tensor_meta(index); - T* cache_ptr = cache_ptrs[layer].buffer; - cache[layer] = std::make_unique( kv_cache->scalar_type(), kv_cache->sizes().size(), const_cast(kv_cache->sizes().data()), - cache_ptr, + cache_ptrs[layer].buffer, const_cast( kv_cache->dim_order().data())); input_tensors_.emplace_back(cache[layer].get()); cache_inputs_.emplace_back(input_tensors_.back()); buffer_manager->add_memory_info( - cache_ptr, cache[layer]->nbytes(), kv_cache.get()); + cache_ptrs[layer].buffer, cache[layer]->nbytes(), kv_cache.get()); } } } // [O]: logits Result logits = method_meta->output_tensor_meta(0); - logits_.data = - reinterpret_cast(buffer_manager->allocate(logits_.size)); + logits_.data = buffer_manager->allocate(logits_.size); logits_.tensor = std::make_unique( logits->scalar_type(), logits->sizes().size(), @@ -177,21 +181,22 @@ void PromptProcessor::init_io( for (int cache_group = 0; cache_group < 2; ++cache_group) { std::vector>& cache = (cache_group == 0 ? k_cache_out_ : v_cache_out_); - std::vector> cache_ptrs = (cache_group == 0) + std::vector cache_ptrs = (cache_group == 0) ? kv_manager_->get_k_cache_() : kv_manager_->get_v_cache_(); for (int layer = 0; layer < metadata_.num_layers; ++layer, ++index) { Result kv_cache = method_meta->output_tensor_meta(index); - T* cache_ptr = cache_ptrs[layer].output_buffer; cache[layer] = std::make_unique( kv_cache->scalar_type(), kv_cache->sizes().size(), const_cast(kv_cache->sizes().data()), - cache_ptr, + cache_ptrs[layer].output_buffer, const_cast(kv_cache->dim_order().data())); output_tensors_.emplace_back(cache[layer].get()); buffer_manager->add_memory_info( - cache_ptr, cache[layer]->nbytes(), kv_cache.get()); + cache_ptrs[layer].output_buffer, + cache[layer]->nbytes(), + kv_cache.get()); } } // Prepare the vector of EValue to run inference @@ -201,13 +206,11 @@ void PromptProcessor::init_io( } } -template -const std::vector& PromptProcessor::get_all_logits() { +const std::vector& PromptProcessor::get_all_logits() { return prompt_all_logits_; } -template -void PromptProcessor::prepare_io( +void PromptProcessor::prepare_io( const std::vector& prompt_tokens, int64_t prompt_pos, int64_t start_pos) { @@ -232,8 +235,7 @@ void PromptProcessor::prepare_io( } } -template -Result PromptProcessor::prefill( +Result PromptProcessor::prefill( std::vector prompt_tokens, int64_t start_pos, bool dump_logits, @@ -339,7 +341,9 @@ Result PromptProcessor::prefill( prompt_all_logits_.insert( prompt_all_logits_.end(), logits_.data, - logits_.data + metadata_.ar_len * metadata_.vocab_size); + logits_.data + + metadata_.ar_len * metadata_.vocab_size * + logits_.getElementSize()); } // In the last run, offset to the meaningful logits. if (i == num_iters - 1) { @@ -369,8 +373,4 @@ Result PromptProcessor::prefill( return cur_token; } -// Explicit instantiations -template class PromptProcessor; -template class PromptProcessor; - } // namespace example diff --git a/examples/qualcomm/oss_scripts/llama/runner/prompt_processor.h b/examples/qualcomm/oss_scripts/llama/runner/prompt_processor.h index 599f7050d83..5317a8a77e1 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/prompt_processor.h +++ b/examples/qualcomm/oss_scripts/llama/runner/prompt_processor.h @@ -21,7 +21,7 @@ namespace example { * @class PromptProcessor * @brief Class for processing prompts using decoder and key-value manager. */ -template + class PromptProcessor { public: struct Metadata { @@ -36,9 +36,10 @@ class PromptProcessor { }; PromptProcessor( DecoderRunner* decoder_runner, - KVManager* kv_manager, + KVManager* kv_manager, const std::string& method_name, - Metadata metadata); + Metadata metadata, + std::unique_ptr method_meta); virtual ~PromptProcessor() = default; @@ -55,9 +56,9 @@ class PromptProcessor { /** * @brief Get the all logits generated * - * @return std::vector& all the logits generated + * @return std::vector& all the logits generated */ - virtual const std::vector& get_all_logits(); + virtual const std::vector& get_all_logits(); /** * Prefill an LLM Module with the given text input. @@ -79,13 +80,8 @@ class PromptProcessor { * @return Total I/O size in bytes. */ inline const size_t total_prompt_processor_io_size_in_bytes() const { - if (metadata_.cache_mode == CacheMode::HybridCache) { - return input_toks_.size + input_pos_.size + attention_mask_.size + - window_attention_mask_.size + logits_.size; - } else { - return input_toks_.size + input_pos_.size + attention_mask_.size + - logits_.size; - } + return input_toks_.size + input_pos_.size + attention_mask_.size + + window_attention_mask_.size + logits_.size; } protected: @@ -105,7 +101,7 @@ class PromptProcessor { int64_t prompt_pos, int64_t start_pos); DecoderRunner* decoder_runner_; - KVManager* kv_manager_; + KVManager* kv_manager_; std::string method_name_; // metadata @@ -114,9 +110,9 @@ class PromptProcessor { // inputs and outputs TensorStruct input_toks_; TensorStruct input_pos_; - TensorStruct attention_mask_; - TensorStruct window_attention_mask_; - TensorStruct logits_; + TensorStructRaw attention_mask_; + TensorStructRaw window_attention_mask_; + TensorStructRaw logits_; // layer -> TensorImpl std::vector> k_cache_in_; @@ -131,6 +127,6 @@ class PromptProcessor { std::vector cache_inputs_; // Unused by default, only used when dump_logits_path is provided. - std::vector prompt_all_logits_; + std::vector prompt_all_logits_; }; } // namespace example diff --git a/examples/qualcomm/oss_scripts/llama/runner/runner.cpp b/examples/qualcomm/oss_scripts/llama/runner/runner.cpp index 0a4a8b9abb5..7257e869dcc 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/runner.cpp +++ b/examples/qualcomm/oss_scripts/llama/runner/runner.cpp @@ -66,17 +66,17 @@ void print_performance_report( void save_logits( const std::string& dump_logits_path, - const std::vector& prefill_logits, - const std::vector& decode_logits) { + const std::vector& prefill_logits, + const std::vector& decode_logits) { std::ofstream outFile(dump_logits_path.c_str(), std::ios::binary); if (outFile.is_open()) { outFile.write( reinterpret_cast(prefill_logits.data()), - prefill_logits.size() * sizeof(uint16_t)); + prefill_logits.size()); outFile.write( reinterpret_cast(decode_logits.data()), - decode_logits.size() * sizeof(uint16_t)); + decode_logits.size()); outFile.close(); } else { ET_CHECK_MSG(false, "Error saving the dump logits file"); @@ -85,8 +85,7 @@ void save_logits( } // namespace -template -Runner::Runner( +Runner::Runner( std::unique_ptr module, const std::string& decoder_model_version, const std::string& model_path, @@ -152,14 +151,12 @@ Runner::Runner( ET_LOG(Info, "eval mode=%d", eval_mode_); } -template -bool Runner::is_loaded() const { +bool Runner::is_loaded() const { return module_->is_loaded() && tokenizer_ && decoder_runner_ && prompt_processor_ && token_generator_ && kv_manager_ && buffer_manager_; } -template -Error Runner::load() { +Error Runner::load() { if (is_loaded()) { return Error::Ok; } @@ -275,13 +272,16 @@ Error Runner::load() { if (module_->method_names()->count("get_sliding_window") > 0) { sliding_window = ET_UNWRAP(module_->get("get_sliding_window")).toInt(); } - kv_manager_ = std::make_unique>(typename KVManager::Metadata{ - context_len_, - head_dim, - max_ar_len, - max_cache_len, - num_heads, - num_layers}); + kv_manager_ = std::make_unique( + KVManager::Metadata{ + context_len_, + head_dim, + max_ar_len, + max_cache_len, + num_heads, + num_layers}, + std::make_unique( + std::move(module_->method_meta(token_generator_method_name).get()))); if (attention_sink_rope_module_ != nullptr) { attention_sink_rope_runner_ = std::make_unique( @@ -290,11 +290,11 @@ Error Runner::load() { attention_sink_rope_runner_->load(method_names)); } - prompt_processor_ = std::make_unique>( + prompt_processor_ = std::make_unique( decoder_runner_.get(), kv_manager_.get(), prompt_processor_method_name, - typename PromptProcessor::Metadata{ + PromptProcessor::Metadata{ context_len_, num_heads, num_layers, @@ -302,15 +302,17 @@ Error Runner::load() { vocab_size, use_int64_token, sliding_window, - cache_mode_}); + cache_mode_}, + std::make_unique( + std::move(module_->method_meta(prompt_processor_method_name).get()))); if (eval_mode_ == EvalMode::kLookaheadDecoding) { - token_generator_ = std::make_unique>( + token_generator_ = std::make_unique( tokenizer_.get(), decoder_runner_.get(), kv_manager_.get(), token_generator_method_name, std::move(eos_ids), - typename LhdTokenGenerator::Metadata{ + LhdTokenGenerator::Metadata{ context_len_, num_heads, num_layers, @@ -322,15 +324,17 @@ Error Runner::load() { gcap_, sliding_window, cache_mode_}, - &stats_); + &stats_, + std::make_unique(std::move( + module_->method_meta(token_generator_method_name).get()))); } else { - token_generator_ = std::make_unique>( + token_generator_ = std::make_unique( tokenizer_.get(), decoder_runner_.get(), kv_manager_.get(), token_generator_method_name, std::move(eos_ids), - typename TokenGenerator::Metadata{ + TokenGenerator::Metadata{ context_len_, num_heads, num_layers, @@ -339,7 +343,9 @@ Error Runner::load() { use_int64_token, sliding_window, cache_mode_}, - &stats_); + &stats_, + std::make_unique(std::move( + module_->method_meta(token_generator_method_name).get()))); } buffer_manager_ = std::make_unique(); @@ -360,8 +366,7 @@ Error Runner::load() { return Error::Ok; } -template -Error Runner::generate( +Error Runner::generate( const std::string& prompt, const llm::GenerationConfig& config, std::function token_callback, @@ -370,8 +375,7 @@ Error Runner::generate( prompt, false, config, token_callback, stats_callback); } -template -Error Runner::generate_from_prompt_or_file( +Error Runner::generate_from_prompt_or_file( const std::string& prompt, bool tokenized_prompt, const llm::GenerationConfig& config, @@ -500,8 +504,7 @@ Error Runner::generate_from_prompt_or_file( return Error::Ok; } -template -Result Runner::get_decoder_model_version() { +Result Runner::get_decoder_model_version() { if (!is_loaded()) { stats_.model_load_start_ms = time_in_ms(); ET_CHECK_OK_OR_RETURN_ERROR(load()); @@ -510,8 +513,4 @@ Result Runner::get_decoder_model_version() { return decoder_model_version_; } -// Explicit instantiations -template class Runner; -template class Runner; - } // namespace example diff --git a/examples/qualcomm/oss_scripts/llama/runner/runner.h b/examples/qualcomm/oss_scripts/llama/runner/runner.h index 39ce62c2d9f..5d03a12f61a 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/runner.h +++ b/examples/qualcomm/oss_scripts/llama/runner/runner.h @@ -46,12 +46,6 @@ enum DecoderModelVersion { kGemma2, }; -enum KvBitWidth { - kWidth8 = 8, - kWidth16 = 16, -}; - -template class Runner : public executorch::extension::llm::IRunner { public: explicit Runner( @@ -121,14 +115,15 @@ class Runner : public executorch::extension::llm::IRunner { DecoderModelVersion decoder_model_version_; std::unique_ptr buffer_manager_; - std::unique_ptr> kv_manager_; + std::unique_ptr kv_manager_; std::unique_ptr tokenizer_; std::unique_ptr decoder_runner_; std::unique_ptr attention_sink_rope_runner_; - std::unique_ptr> prompt_processor_; - std::unique_ptr> token_generator_; + std::unique_ptr prompt_processor_; + std::unique_ptr token_generator_; // stats executorch::llm::Stats stats_; }; + } // namespace example diff --git a/examples/qualcomm/oss_scripts/llama/runner/token_generator.cpp b/examples/qualcomm/oss_scripts/llama/runner/token_generator.cpp index 8ab82d932e1..098fcf9efa6 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/token_generator.cpp +++ b/examples/qualcomm/oss_scripts/llama/runner/token_generator.cpp @@ -17,15 +17,15 @@ using executorch::runtime::Span; using executorch::runtime::TensorInfo; namespace example { -template -TokenGenerator::TokenGenerator( +TokenGenerator::TokenGenerator( tokenizers::Tokenizer* tokenizer, DecoderRunner* decoder_runner, - KVManager* kv_manager, + KVManager* kv_manager, const std::string& method_name, std::unique_ptr>&& eos_ids, Metadata metadata, - executorch::llm::Stats* stats) + executorch::llm::Stats* stats, + std::unique_ptr method_meta) : tokenizer_(tokenizer), decoder_runner_(decoder_runner), kv_manager_(kv_manager), @@ -39,32 +39,37 @@ TokenGenerator::TokenGenerator( v_cache_out_.resize(metadata_.num_layers); // Calculate I/O size + Result attention_mask = method_meta->input_tensor_meta(1); + Result logits = method_meta->output_tensor_meta(0); + input_toks_.size = metadata_.ar_len * sizeof(int64_t); input_pos_.size = metadata_.ar_len * sizeof(int32_t); - attention_mask_.size = - metadata_.ar_len * metadata_.context_len * sizeof(uint16_t); + attention_mask_.dtype = attention_mask->scalar_type(); + attention_mask_.size = metadata_.ar_len * metadata_.context_len * + attention_mask_.getElementSize(); switch (metadata_.cache_mode) { case CacheMode::StaticCahce: - attention_mask_.size = - metadata_.ar_len * metadata_.context_len * sizeof(uint16_t); window_attention_mask_.size = 0; break; - case CacheMode::HybridCache: - attention_mask_.size = - metadata_.ar_len * metadata_.context_len * sizeof(uint16_t); - window_attention_mask_.size = - metadata_.ar_len * metadata_.context_len * sizeof(uint16_t); + case CacheMode::HybridCache: { + Result window_attention_mask = + method_meta->input_tensor_meta(2); + window_attention_mask_.dtype = window_attention_mask->scalar_type(); + window_attention_mask_.size = metadata_.ar_len * metadata_.context_len * + window_attention_mask_.getElementSize(); break; + } default: ET_CHECK_MSG(false, "Unsupported llama cache mode"); break; } - logits_.size = metadata_.ar_len * metadata_.vocab_size * sizeof(uint16_t); + logits_.dtype = logits->scalar_type(); + logits_.size = + metadata_.ar_len * metadata_.vocab_size * logits_.getElementSize(); } -template -void TokenGenerator::init_io( +void TokenGenerator::init_io( IMemAlloc* buffer_manager, Result method_meta) { size_t idx = 0; @@ -86,8 +91,7 @@ void TokenGenerator::init_io( // [I]: attention_mask Result attention_mask = method_meta->input_tensor_meta(idx++); - attention_mask_.data = reinterpret_cast( - buffer_manager->allocate(attention_mask_.size)); + attention_mask_.data = buffer_manager->allocate(attention_mask_.size); attention_mask_.tensor = std::make_unique( attention_mask->scalar_type(), attention_mask->sizes().size(), @@ -103,8 +107,8 @@ void TokenGenerator::init_io( if (metadata_.cache_mode == CacheMode::HybridCache) { Result window_attention_mask = method_meta->input_tensor_meta(idx++); - window_attention_mask_.data = reinterpret_cast( - buffer_manager->allocate(window_attention_mask_.size)); + window_attention_mask_.data = + buffer_manager->allocate(window_attention_mask_.size); window_attention_mask_.tensor = std::make_unique( window_attention_mask->scalar_type(), window_attention_mask->sizes().size(), @@ -141,31 +145,28 @@ void TokenGenerator::init_io( for (int cache_group = 0; cache_group < 2; ++cache_group) { std::vector>& cache = (cache_group == 0 ? k_cache_in_ : v_cache_in_); - std::vector> cache_ptrs = (cache_group == 0) + std::vector cache_ptrs = (cache_group == 0) ? kv_manager_->get_k_cache_() : kv_manager_->get_v_cache_(); for (int layer = 0; layer < metadata_.num_layers; ++layer, ++index) { Result kv_cache = method_meta->input_tensor_meta(index); - T* cache_ptr = cache_ptrs[layer].buffer; - cache[layer] = std::make_unique( kv_cache->scalar_type(), kv_cache->sizes().size(), const_cast(kv_cache->sizes().data()), - cache_ptr, + cache_ptrs[layer].buffer, const_cast(kv_cache->dim_order().data())); input_tensors_.emplace_back(cache[layer].get()); cache_inputs_.emplace_back(input_tensors_.back()); buffer_manager->add_memory_info( - cache_ptr, cache[layer]->nbytes(), kv_cache.get()); + cache_ptrs[layer].buffer, cache[layer]->nbytes(), kv_cache.get()); } } // [O]: logits Result logits = method_meta->output_tensor_meta(0); - logits_.data = - reinterpret_cast(buffer_manager->allocate(logits_.size)); + logits_.data = buffer_manager->allocate(logits_.size); logits_.tensor = std::make_unique( logits->scalar_type(), logits->sizes().size(), @@ -180,21 +181,22 @@ void TokenGenerator::init_io( for (int cache_group = 0; cache_group < 2; ++cache_group) { std::vector>& cache = (cache_group == 0 ? k_cache_out_ : v_cache_out_); - std::vector> cache_ptrs = (cache_group == 0) + std::vector cache_ptrs = (cache_group == 0) ? kv_manager_->get_k_cache_() : kv_manager_->get_v_cache_(); for (int layer = 0; layer < metadata_.num_layers; ++layer, ++index) { Result kv_cache = method_meta->output_tensor_meta(index); - T* cache_ptr = cache_ptrs[layer].output_buffer; cache[layer] = std::make_unique( kv_cache->scalar_type(), kv_cache->sizes().size(), const_cast(kv_cache->sizes().data()), - cache_ptr, + cache_ptrs[layer].output_buffer, const_cast(kv_cache->dim_order().data())); output_tensors_.emplace_back(cache[layer].get()); buffer_manager->add_memory_info( - cache_ptr, cache[layer]->nbytes(), kv_cache.get()); + cache_ptrs[layer].output_buffer, + cache[layer]->nbytes(), + kv_cache.get()); } } // Prepare the vector of EValue to run inference @@ -204,14 +206,12 @@ void TokenGenerator::init_io( } } -template -const std::vector& TokenGenerator::get_all_logits() { +const std::vector& TokenGenerator::get_all_logits() { return token_all_logits_; } // This function only considers the case where token_generator_ar_len equals 1. -template -void TokenGenerator::prepare_io(uint64_t cur_token, int64_t start_pos) { +void TokenGenerator::prepare_io(uint64_t cur_token, int64_t start_pos) { // update input_tok *input_toks_.data = metadata_.use_int64_token ? cur_token : static_cast(cur_token); @@ -219,8 +219,7 @@ void TokenGenerator::prepare_io(uint64_t cur_token, int64_t start_pos) { *input_pos_.data = static_cast(start_pos); } -template -Result TokenGenerator::generate( +Result TokenGenerator::generate( std::vector tokens, int64_t start_pos, int32_t seq_len, @@ -306,7 +305,9 @@ Result TokenGenerator::generate( token_all_logits_.insert( token_all_logits_.end(), logits_.data, - logits_.data + metadata_.ar_len * metadata_.vocab_size); + logits_.data + + metadata_.ar_len * metadata_.vocab_size * + logits_.getElementSize()); } ET_CHECK_OK_OR_RETURN_ERROR(logits_res.error()); executorch::aten::Tensor& logits_tensor = logits_res.get(); @@ -374,8 +375,5 @@ Result TokenGenerator::generate( return pos - start_pos; } -// Explicit instantiations -template class TokenGenerator; -template class TokenGenerator; } // namespace example diff --git a/examples/qualcomm/oss_scripts/llama/runner/token_generator.h b/examples/qualcomm/oss_scripts/llama/runner/token_generator.h index 7f9264b1102..6945d907a76 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/token_generator.h +++ b/examples/qualcomm/oss_scripts/llama/runner/token_generator.h @@ -22,7 +22,7 @@ namespace example { * @class TokenGenerator * @brief Class for generating the token using decoder and key-value manager. */ -template + class TokenGenerator { public: struct Metadata { @@ -38,11 +38,12 @@ class TokenGenerator { TokenGenerator( tokenizers::Tokenizer* tokenizer, DecoderRunner* decoder_runner, - KVManager* kv_manager, + KVManager* kv_manager, const std::string& method_name, std::unique_ptr>&& eos_ids, Metadata metadata, - executorch::llm::Stats* stats); + executorch::llm::Stats* stats, + std::unique_ptr method_meta); virtual ~TokenGenerator() = default; /** @@ -58,9 +59,9 @@ class TokenGenerator { /** * @brief Get the all logits generated * - * @return std::vector& all the logits generated + * @return std::vector& all the logits generated */ - virtual const std::vector& get_all_logits(); + virtual const std::vector& get_all_logits(); /**    * @brief Generate tokens. @@ -78,28 +79,23 @@ class TokenGenerator { bool dump_logits, AttentionSinkRopeRunner* attention_sink_rope_runner); inline const size_t total_token_generator_io_size_in_bytes() const { - if (metadata_.cache_mode == CacheMode::HybridCache) { - return input_toks_.size + input_pos_.size + attention_mask_.size + - window_attention_mask_.size + logits_.size; - } else { - return input_toks_.size + input_pos_.size + attention_mask_.size + - logits_.size; - } + return input_toks_.size + input_pos_.size + attention_mask_.size + + window_attention_mask_.size + logits_.size; } protected: tokenizers::Tokenizer* tokenizer_; DecoderRunner* decoder_runner_; - KVManager* kv_manager_; + KVManager* kv_manager_; std::string method_name_; std::unique_ptr> eos_ids_; // inputs and outputs TensorStruct input_toks_; TensorStruct input_pos_; - TensorStruct attention_mask_; - TensorStruct window_attention_mask_; - TensorStruct logits_; + TensorStructRaw attention_mask_; + TensorStructRaw window_attention_mask_; + TensorStructRaw logits_; // layer -> TensorImpl std::vector> k_cache_in_; @@ -128,6 +124,6 @@ class TokenGenerator { Metadata metadata_; // Unused by default, only used when dump_logits_path is provided. - std::vector token_all_logits_; + std::vector token_all_logits_; }; } // namespace example diff --git a/examples/qualcomm/oss_scripts/llama/runner/utils.h b/examples/qualcomm/oss_scripts/llama/runner/utils.h index bef6b1a2017..df6dddfdc6e 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/utils.h +++ b/examples/qualcomm/oss_scripts/llama/runner/utils.h @@ -8,10 +8,16 @@ #pragma once #include +#include #include #include // Template struct to hold tensor data and tensor + +// TODO: Refactor these struct to use TensorPtr +// see https://docs.pytorch.org/executorch/stable/extension-tensor.html + +// TensorStruct whose dtype known in compile time template struct TensorStruct { std::unique_ptr tensor; @@ -20,3 +26,38 @@ struct TensorStruct { // data size in bytes size_t size; }; + +inline size_t getDtypeSize(executorch::aten::ScalarType dtype) { + switch (dtype) { + case executorch::aten::ScalarType::Float: + return sizeof(float); + case executorch::aten::ScalarType::Double: + return sizeof(double); + case executorch::aten::ScalarType::Int: + return sizeof(int32_t); + case executorch::aten::ScalarType::Long: + return sizeof(int64_t); + case executorch::aten::ScalarType::Byte: + return sizeof(uint8_t); + case executorch::aten::ScalarType::UInt16: + return sizeof(uint16_t); + default: + ET_CHECK_MSG( + false, + "Unsupported scalar type %s", + executorch::runtime::toString(dtype)); + break; + } +} + +// TensorStruct whose dtype known in runtime, and raw file is used +struct TensorStructRaw { + std::unique_ptr tensor; + std::byte* data; + // data size in bytes + size_t size; + executorch::aten::ScalarType dtype; + size_t getElementSize() const { + return getDtypeSize(dtype); + } +}; diff --git a/examples/qualcomm/oss_scripts/llama/wrappers/attention_sink_wrappers.py b/examples/qualcomm/oss_scripts/llama/wrappers/attention_sink_wrappers.py index 48386f181d8..de857dfc17c 100644 --- a/examples/qualcomm/oss_scripts/llama/wrappers/attention_sink_wrappers.py +++ b/examples/qualcomm/oss_scripts/llama/wrappers/attention_sink_wrappers.py @@ -13,6 +13,7 @@ import torch from executorch.backends.qualcomm._passes import TagQuantIO +from executorch.backends.qualcomm._passes.build_quant_io import BuildQuantIo from executorch.backends.qualcomm._passes.qnn_pass_manager import ( get_capture_program_passes, ) @@ -460,6 +461,7 @@ def compile(self, attention_sink_evictor_pte_path: str): alloc_graph_input=False, alloc_graph_output=False, ), + passes=[BuildQuantIo()], extract_delegate_segments=True, ) exec_prog_mgr = edge_prog_mgr.to_executorch(executorch_config) diff --git a/examples/qualcomm/oss_scripts/llama/wrappers/llm_wrappers.py b/examples/qualcomm/oss_scripts/llama/wrappers/llm_wrappers.py index ef72e0765fd..0d5052c89bd 100644 --- a/examples/qualcomm/oss_scripts/llama/wrappers/llm_wrappers.py +++ b/examples/qualcomm/oss_scripts/llama/wrappers/llm_wrappers.py @@ -19,6 +19,7 @@ import torch from executorch.backends.qualcomm._passes import FoldQDQ, I64toI32, TagQuantIO +from executorch.backends.qualcomm._passes.build_quant_io import BuildQuantIo from executorch.backends.qualcomm._passes.qnn_pass_manager import ( get_capture_program_passes, ) @@ -607,23 +608,28 @@ def quantize(self, request: Request): # noqa: C901 ): return + data = request.method_data[TEXT_DECODER] # check bit width graph io fixed_point_type = {"kv_type": torch.float32, "io_type": torch.float32} - if self.quant_recipe.get_kv_io_bit_width() == 8: - fixed_point_type["kv_type"] = torch.uint8 - elif self.quant_recipe.get_kv_io_bit_width() == 16: - fixed_point_type["kv_type"] = torch.uint16 + if data.skip_quantize: + # already init as float32 + return else: - raise RuntimeError( - f"unknown kv io bit width {self.quant_recipe.get_kv_io_bit_width()}" - ) + if self.quant_recipe.get_kv_io_bit_width() == 8: + fixed_point_type["kv_type"] = torch.uint8 + elif self.quant_recipe.get_kv_io_bit_width() == 16: + fixed_point_type["kv_type"] = torch.uint16 + else: + raise RuntimeError( + f"unknown kv io bit width {self.quant_recipe.get_kv_io_bit_width()}" + ) - if self.quant_recipe.get_logits_output_bit_width() == 16: - fixed_point_type["io_type"] = torch.uint16 - else: - raise RuntimeError( - f"unknown logits io bit width {self.quant_recipe.get_logits_output_bit_width()}" - ) + if self.quant_recipe.get_logits_output_bit_width() == 16: + fixed_point_type["io_type"] = torch.uint16 + else: + raise RuntimeError( + f"unknown logits io bit width {self.quant_recipe.get_logits_output_bit_width()}" + ) data = request.method_data[TEXT_DECODER] audio_turns = request.method_data[ @@ -906,7 +912,11 @@ def compile(self, request: Request): # noqa: C901 # here we use a mechanism to make sure the encoding align correctly and # save AoT quantization time as well. # --- - if self.prefill.decoder is not None and self.prefill.model_args.use_kv_cache: + if ( + self.prefill.decoder is not None + and self.prefill.model_args.use_kv_cache + and not request.method_data[TEXT_DECODER].skip_quantize + ): self._encoding_override( decode_model=self.decode.decoder, prefill_model=self.prefill.decoder, @@ -973,6 +983,7 @@ def compile(self, request: Request): # noqa: C901 alloc_graph_input=False, alloc_graph_output=False, ), + passes=[BuildQuantIo()], ) tok_embedding_exec_prog_mgr = tok_embedding_edge_prog_mgr.to_executorch( executorch_config @@ -1009,6 +1020,7 @@ def compile(self, request: Request): # noqa: C901 alloc_graph_input=False, alloc_graph_output=False, ), + passes=[BuildQuantIo()], ) exec_prog_mgr = edge_prog_mgr.to_executorch(executorch_config) data = request.method_data[TEXT_DECODER] @@ -1127,7 +1139,9 @@ def compile(self, request: Request): if self.control_args.verbose: print_delegation_info(edge_prog_mgr.exported_program().graph_module) - exec_prog_mgr = edge_prog_mgr.to_executorch(ExecutorchBackendConfig()) + exec_prog_mgr = edge_prog_mgr.to_executorch( + ExecutorchBackendConfig(passes=[BuildQuantIo()]) + ) data = request.method_data[self.modality] with open( f"{self.control_args.artifact}/{data.pte_filename}.pte", "wb" @@ -1223,6 +1237,7 @@ def compile( self, compile_specs: Dict[str, List[CompileSpec]], pte_filenames: Dict[str, str], + skip_quantize: Dict[str, bool], ): compile_request = Request( inspect.currentframe().f_code.co_name, @@ -1230,6 +1245,7 @@ def compile( m: Request.Data( compile_spec=compile_specs[m], pte_filename=pte_filenames[m], + skip_quantize=skip_quantize[m] if m in skip_quantize else False, ) for m in self._modalities }, diff --git a/examples/riscv/README.md b/examples/riscv/README.md index 563ff4913fd..3ae8a151f24 100644 --- a/examples/riscv/README.md +++ b/examples/riscv/README.md @@ -1,41 +1,36 @@ # RISC-V -Cross-compile `executor_runner` for `riscv64-linux-gnu` and run it under -`qemu-user-static` against a small bundled program. The end-to-end check -mirrors the Arm Cortex-M e2e flow: a `Test_result: PASS` line in stdout from -the bundled-IO comparison path is the pass criterion. +End-to-end smoke tests that cross-compile ExecuTorch for RISC-V and run a bundled program under QEMU. A `Test_result: PASS` line emitted by the bundled-IO comparison path is the pass criterion. -This is the Phase 1 deliverable for the RISC-V Support RFC at -[pytorch/executorch#18991][rfc]. The cross-compile and runner artifacts -(toolchain file, preset, AOT script) are designed to carry over unchanged -to a hardware-runner job once one becomes available; only the invocation -step (qemu-user vs. native) would change. - -[rfc]: https://github.com/pytorch/executorch/issues/18991 +Part of the RISC-V Support RFC, [pytorch/executorch#18991](https://github.com/pytorch/executorch/issues/18991). ## Quick start (Ubuntu / Debian) ```bash -examples/riscv/setup.sh # apt: gcc-riscv64-linux-gnu, qemu-user-static -examples/riscv/run.sh # export, cross-compile, run under qemu-user +examples/riscv/setup-linux.sh # apt: gcc cross riscv64-linux-gnu + qemu-user +examples/riscv/setup-baremetal.sh # apt: gcc cross riscv64-unknown-elf + qemu-system + picolibc +examples/riscv/run.sh # export, cross-compile, run under qemu ``` -The driver does three steps: +`run.sh` accepts: + +| Flag | Values | Default | Notes | +|---|---|---|---| +| `--model=` | `add`, `mv2`, `mobilebert`, `llama2`, `resnet18`, `yolo26` | `add` | which model to export | +| `--quantize` | flag | off | XNNPACK quantizer (requires `--backend=xnnpack`) | +| `--backend=` | `portable`, `xnnpack` | `portable` | xnnpack is linux-only | +| `--os=` | `linux`, `baremetal` | `linux` | qemu-user vs qemu-system + semihosting | +| `--arch=` | `rv32`, `rv64` | `rv64` | valid - pairs are `linux-rv64`, `baremetal-rv32`, `baremetal-rv64` | +| `--qemu-cpu-ext=` | e.g. `v=true,vlen=128` | empty | extensions appended after the arch base | + +## Pipelines + +**linux**: `aot_riscv.py` → `cmake --preset riscv64-linux` → `executor_runner` under `qemu-riscv64`. Portable kernels + (optional) XNNPACK delegate. + +**baremetal**: `aot_riscv.py` → `cmake -S examples/riscv/baremetal` (standalone project; pulls executorch in via `add_subdirectory`) → `executor_runner_baremetal.elf` under `qemu-system-riscv64 -machine virt -bios none -semihosting-config target=native`. -1. `python examples/riscv/aot_riscv.py` exports a `torch.add` module to - `riscv_test/add_riscv.bpte` (a BundledProgram with reference outputs - embedded for two test cases). -2. `cmake --preset riscv64-linux` configures the cross-build using - `examples/riscv/riscv64-linux-gnu-toolchain.cmake` and - `tools/cmake/preset/riscv64_linux.cmake`. `executor_runner` is built - against portable kernels with `ET_BUNDLE_IO_ENABLED` defined. -3. `qemu-riscv64-static` invokes the runner with `--model_path` pointing at - the `.bpte`. The runner detects the bundle, runs every embedded test case, - and emits `Test_result: PASS` (or `FAIL`) per case. +The baremetal runner embeds the `.bpte` directly in `.rodata` via the same `examples/arm/executor_runner/pte_to_header.py` Cortex-M uses; semihosting SYS_WRITE0 / SYS_EXIT carry log output and exit status to the host. ## CI -`.github/workflows/_test_riscv_qemu.yml` is a reusable `workflow_call` -job (mirroring `_test_cortex_m_e2e.yml`) invoked from `pull.yml` to run on -every PR. It runs on the standard `linux.2xlarge` x86_64 runner using the -`executorch-ubuntu-22.04-gcc11` docker image. +`.github/workflows/riscv64.yml` is the entry point; it fans out into `_test_riscv.yml` over a `(model, backend, os, arch, quantize)` matrix and sweeps `qemu-cpu-ext` per backend. Runs on the `executorch-ubuntu-26.04-gcc15` docker image (needed for the `riscv64-unknown-elf` picolibc + libstdc++ packages - see [setup-linux.sh](setup-linux.sh) or [setup-baremetal.sh](setup-baremetal.sh)). diff --git a/examples/riscv/aot_riscv.py b/examples/riscv/aot_riscv.py index 529e2b1e767..e01fe6f954e 100644 --- a/examples/riscv/aot_riscv.py +++ b/examples/riscv/aot_riscv.py @@ -3,11 +3,12 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -"""AOT export for the RISC-V smoke test. +"""AOT export for the RISC-V smoke tests. -Exports a small model to a BundledProgram (.bpte) that the portable -executor_runner can load on a riscv64 target and verify against the embedded -reference output, emitting ``Test_result: PASS`` on success. +Exports the model selected by ``--model`` to a BundledProgram (.bpte) that +either ``executor_runner`` (linux) or ``executor_runner_baremetal`` (qemu +virt + semihosting) consumes. The bundled-IO comparison path inside the +runner emits ``Test_result: PASS`` per testset, which is what run.sh greps. """ import argparse @@ -114,12 +115,45 @@ def build_resnet18(): return model, example_inputs, test_inputs, False +def build_yolo26(): + # Mirrors examples/models/yolo26/export_and_validate.py: predict() once + # to materialise the predictor state Ultralytics expects pre-export. + import numpy as np + from ultralytics import YOLO + + input_h, input_w = 320, 320 + yolo = YOLO("yolo26n") + yolo.predict( + np.ones((input_h, input_w, 3)), + imgsz=(input_h, input_w), + device="cpu", + ) + + class Wrapper(torch.nn.Module): + def __init__(self): + super().__init__() + self.model = yolo.model.to(torch.device("cpu")).eval() + + def forward(self, x): + # yolo.model emits (predictions, feature_maps) in eval; keep the + # predictions tensor so BundledIO sees a single tensor output. + out = self.model(x) + return out[0] if isinstance(out, (tuple, list)) else out + + model = Wrapper().eval() + torch.manual_seed(0) + example_inputs = (torch.randn(1, 3, input_h, input_w),) + test_inputs = [example_inputs] + return model, example_inputs, test_inputs, False + + MODELS = { "add": build_add, "mv2": build_mv2, "mobilebert": build_mobilebert, "llama2": build_llama2, "resnet18": build_resnet18, + "yolo26": build_yolo26, } @@ -138,9 +172,19 @@ def main() -> None: help="Output .bpte path (default: _riscv.bpte)", ) parser.add_argument( - "--xnnpack", - action="store_true", - help="Lower through the XNNPACK partitioner", + "--backend", + choices=("portable", "xnnpack"), + default="portable", + help="AOT backend: 'portable' runs everything on the portable kernels, " + "'xnnpack' adds the XNNPACK partitioner (default: portable)", + ) + parser.add_argument( + "--os", + choices=("linux", "baremetal"), + default="linux", + help="Target OS for the runner that will consume this .bpte. The .bpte " + "itself is OS-independent; the flag is logged so callers can verify " + "the AOT/runtime sides agree (default: linux)", ) parser.add_argument( "--quantize", @@ -154,6 +198,13 @@ def main() -> None: ) args = parser.parse_args() + if args.debug_xnnpack and args.backend != "xnnpack": + parser.error("--debug-xnnpack requires --backend=xnnpack") + + # xnnpack pulls in pthreads + dynamic loading; baremetal runner doesn't have those. + if args.os == "baremetal" and args.backend == "xnnpack": + parser.error("--backend=xnnpack is not supported on --os=baremetal") + if args.debug_xnnpack: logging.basicConfig(level=logging.DEBUG) @@ -176,7 +227,7 @@ def main() -> None: exported = export(model, example_inputs, strict=strict) partitioners = [] - if args.xnnpack: + if args.backend == "xnnpack": from executorch.backends.xnnpack.partition.xnnpack_partitioner import ( XnnpackPartitioner, ) @@ -190,7 +241,9 @@ def main() -> None: compile_config = EdgeCompileConfig(_check_ir_validity=False) edge = to_edge_transform_and_lower( - exported, partitioner=partitioners, compile_config=compile_config + exported, + partitioner=partitioners, + compile_config=compile_config, ) delegated = sum( 1 @@ -198,7 +251,7 @@ def main() -> None: if n.op == "call_function" and "call_delegate" in str(n.target) ) print( - f"[aot_riscv] model={args.model} xnnpack={args.xnnpack} " + f"[aot_riscv] model={args.model} backend={args.backend} os={args.os} " f"quantize={args.quantize} delegated_nodes={delegated}" ) diff --git a/examples/riscv/baremetal/CMakeLists.txt b/examples/riscv/baremetal/CMakeLists.txt new file mode 100644 index 00000000000..b0208e41d2b --- /dev/null +++ b/examples/riscv/baremetal/CMakeLists.txt @@ -0,0 +1,117 @@ +# Copyright 2026 The ExecuTorch Authors. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# Standalone runner project, invoked from examples/riscv/run.sh as: +# ~~~ +# cmake -S examples/riscv/baremetal -B \ +# -DEXECUTORCH_ROOT= \ +# -DRISCV_BAREMETAL_PTE=.bpte \ +# -DCMAKE_TOOLCHAIN_FILE=.../riscv{32,64}-unknown-elf-toolchain.cmake +# ~~~ +# Mirrors examples/arm/executor_runner/standalone/CMakeLists.txt so the +# top-level executorch CMake has no reference to examples/riscv/. + +cmake_minimum_required(VERSION 3.20) +project(riscv_executor_runner_baremetal LANGUAGES C CXX ASM) + +get_filename_component( + _default_executorch_root "${CMAKE_CURRENT_LIST_DIR}/../../.." ABSOLUTE +) +if(NOT DEFINED EXECUTORCH_ROOT) + set(EXECUTORCH_ROOT + "${_default_executorch_root}" + CACHE PATH "Path to the ExecuTorch checkout" + ) +endif() +if(NOT EXISTS "${EXECUTORCH_ROOT}/CMakeLists.txt") + message( + FATAL_ERROR + "EXECUTORCH_ROOT (${EXECUTORCH_ROOT}) does not contain an ExecuTorch CMake project." + ) +endif() + +set(RISCV_BAREMETAL_PTE + "" + CACHE FILEPATH "Path to the .bpte to embed in the baremetal runner" +) +if(NOT RISCV_BAREMETAL_PTE) + message( + FATAL_ERROR + "RISCV_BAREMETAL_PTE not set; pass -DRISCV_BAREMETAL_PTE= from run.sh" + ) +endif() + +include("${EXECUTORCH_ROOT}/tools/cmake/common/preset.cmake") +if(NOT DEFINED EXECUTORCH_BUILD_PRESET_FILE) + set(EXECUTORCH_BUILD_PRESET_FILE + "${EXECUTORCH_ROOT}/tools/cmake/preset/riscv_baremetal.cmake" + CACHE PATH "Preset used when configuring the standalone baremetal runner" + ) +endif() +load_build_preset() +include("${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake") + +add_subdirectory( + "${EXECUTORCH_ROOT}" "${CMAKE_BINARY_DIR}/executorch" EXCLUDE_FROM_ALL +) + +find_package(Python3 REQUIRED COMPONENTS Interpreter) + +set(_pte_header "${CMAKE_CURRENT_BINARY_DIR}/model_pte.h") +add_custom_command( + OUTPUT "${_pte_header}" + COMMAND + "${Python3_EXECUTABLE}" + "${EXECUTORCH_ROOT}/examples/arm/executor_runner/pte_to_header.py" --pte + "${RISCV_BAREMETAL_PTE}" --outdir "${CMAKE_CURRENT_BINARY_DIR}" --outfile + "model_pte.h" --section ".rodata.model_pte" + DEPENDS "${RISCV_BAREMETAL_PTE}" + COMMENT "Embedding ${RISCV_BAREMETAL_PTE} into model_pte.h" + VERBATIM +) + +# pte_to_header.py emits the byte array but not its length; the glue TU +# materialises the matching `model_pte_len` and is the only place the header is +# included (avoids a double-definition at link time). +file( + WRITE "${CMAKE_CURRENT_BINARY_DIR}/model_pte_glue.cpp" + "#include \n#include \"model_pte.h\"\nextern \"C\" const size_t model_pte_len = sizeof(model_pte);\n" +) + +add_executable( + executor_runner_baremetal + start.S executor_runner_baremetal.cpp + "${CMAKE_CURRENT_BINARY_DIR}/model_pte_glue.cpp" "${_pte_header}" +) +set_target_properties( + executor_runner_baremetal PROPERTIES SUFFIX ".elf" LINKER_LANGUAGE CXX +) +target_include_directories( + executor_runner_baremetal PRIVATE "${CMAKE_CURRENT_BINARY_DIR}" +) +target_compile_options( + executor_runner_baremetal PRIVATE -fno-exceptions -fno-rtti -fdata-sections + -ffunction-sections +) +# --specs=picolibc.specs / -nostartfiles / -march / -mabi all come from the +# toolchain file; only the linker script (QEMU virt memory map) is target- +# specific here. +target_link_options( + executor_runner_baremetal PRIVATE + "-T${CMAKE_CURRENT_SOURCE_DIR}/riscv_virt.ld" +) + +# gen_operators_lib / executorch_target_link_options_shared_lib attach INTERFACE +# --whole-archive options to portable_ops_lib (so the static-init +# kernel-registration TU survives DCE) and to executorch itself. Listing the +# libs once each is enough; an extra --whole-archive wrapper around them would +# include the same archive twice and double-register every op. +target_link_libraries(executor_runner_baremetal PRIVATE bundled_program) +if(TARGET portable_ops_lib) + target_link_libraries(executor_runner_baremetal PRIVATE portable_ops_lib) +endif() +if(TARGET portable_kernels) + target_link_libraries(executor_runner_baremetal PRIVATE portable_kernels) +endif() diff --git a/examples/riscv/baremetal/executor_runner_baremetal.cpp b/examples/riscv/baremetal/executor_runner_baremetal.cpp new file mode 100644 index 00000000000..d0bb128bd98 --- /dev/null +++ b/examples/riscv/baremetal/executor_runner_baremetal.cpp @@ -0,0 +1,286 @@ +/* + * Copyright 2026 The ExecuTorch Authors. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Baremetal runner for qemu-system-riscv64 -machine virt + semihosting. Loads +// a .bpte embedded into the ELF and emits "TEST: BundleIO index[N] +// Test_result: PASS|FAIL" via ET_LOG so examples/riscv/run.sh's grep can +// detect success without a host filesystem. + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "semihosting.h" + +extern "C" const uint8_t model_pte[]; +extern "C" const size_t model_pte_len; + +using executorch::extension::BufferDataLoader; +using executorch::runtime::Error; +using executorch::runtime::HierarchicalAllocator; +using executorch::runtime::MemoryAllocator; +using executorch::runtime::MemoryManager; +using executorch::runtime::Method; +using executorch::runtime::MethodMeta; +using executorch::runtime::Program; +using executorch::runtime::Result; +using executorch::runtime::Span; + +namespace { + +// Pools are sized for the largest model we currently test (llama2 / yolo26) +// rather than per-model; the .bss grows but freestanding picolibc never +// allocates from it so the cost is just a bigger ELF. Bumping these requires +// matching headroom in riscv_virt.ld's RAM region and qemu's -m flag. +alignas(16) uint8_t method_allocator_pool[1u << 23]; // 8 MiB +alignas(16) uint8_t temp_allocator_pool[1u << 22]; // 4 MiB +alignas(16) uint8_t planned_memory_pool[1u << 26]; // 64 MiB + +constexpr size_t kMaxPlannedBuffers = 8; +constexpr double kRtol = 0.01; +constexpr double kAtol = 0.01; + +} // namespace + +extern "C" [[noreturn]] void baremetal_exit(int status) { + executorch::riscv::baremetal::semihost_exit(status); +} + +// picolibc's abort()/raise() resolve _exit; with our own start.S we don't +// link its crt0, so reroute it to the semihosting trap. +extern "C" [[noreturn]] void _exit(int status) { + executorch::riscv::baremetal::semihost_exit(status); +} + +// libstdc++'s drags std::random_device → getentropy/read. The portable +// rand kernels are never invoked at runtime for our bundled-IO tests, so a +// failing stub is enough to satisfy the link. +extern "C" int getentropy(void*, size_t) { + return -1; +} +extern "C" long read(int, void*, size_t) { + return -1; +} + +// Virtual destructors emit deleting variants that reference operator delete +// even when we never new/delete. Stubs satisfy the linker; never called. +void operator delete(void*) noexcept {} +void operator delete(void*, size_t) noexcept {} +void operator delete[](void*) noexcept {} +void operator delete[](void*, size_t) noexcept {} + +// op_rand / op_native_dropout / op_randn from portable_kernels reference +// std::random_device::_M_{init,getval,fini}, whose only definitions live in +// libstdc++.a's medlow-built random.o (won't relocate at 0x80000000). The +// bundled-IO smoke tests never invoke those ops, so satisfy the linker with +// no-op trampolines under the Itanium-mangled names. +asm(R"( + .globl _ZNSt13random_device7_M_initERKNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEE + .type _ZNSt13random_device7_M_initERKNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEE, @function +_ZNSt13random_device7_M_initERKNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEE: + ret + + .globl _ZNSt13random_device9_M_getvalEv + .type _ZNSt13random_device9_M_getvalEv, @function +_ZNSt13random_device9_M_getvalEv: + li a0, 0 + ret + + .globl _ZNSt13random_device7_M_finiEv + .type _ZNSt13random_device7_M_finiEv, @function +_ZNSt13random_device7_M_finiEv: + ret +)"); + +// Route ET_LOG through semihosting. Messages aren't null-terminated; copy and +// append \n\0 before forwarding to SYS_WRITE0. +extern "C" void et_pal_emit_log_message( + et_timestamp_t, + et_pal_log_level_t, + const char*, + const char*, + size_t, + const char* message, + size_t length) { + // The bundle doesn't expose a testset count, so we probe past the end and + // rely on InvalidArgument to terminate the loop. The accompanying ET_LOG + // ("testset_idx N is out of range ...") is benign noise — suppress it so + // run.sh's PASS/FAIL grep stays clean. + static const char kOorPrefix[] = "testset_idx "; + if (length >= sizeof(kOorPrefix) - 1 && + std::memcmp(message, kOorPrefix, sizeof(kOorPrefix) - 1) == 0) { + return; + } + char buf[512]; + size_t n = length < sizeof(buf) - 2 ? length : sizeof(buf) - 2; + std::memcpy(buf, message, n); + buf[n] = '\n'; + buf[n + 1] = '\0'; + executorch::riscv::baremetal::semihost_write0(buf); +} + +extern "C" void et_pal_init(void) {} +extern "C" [[noreturn]] void et_pal_abort(void) { + executorch::riscv::baremetal::semihost_exit(1); +} +extern "C" et_timestamp_t et_pal_current_ticks(void) { + return 0; +} +extern "C" et_tick_ratio_t et_pal_ticks_to_ns_multiplier(void) { + return {1, 1}; +} +extern "C" void* et_pal_allocate(size_t) { + return nullptr; +} +extern "C" void et_pal_free(void*) {} + +int main() { + executorch::runtime::runtime_init(); + + const void* program_data = nullptr; + size_t program_size = 0; + Error status = executorch::bundled_program::get_program_data( + const_cast(model_pte), + model_pte_len, + &program_data, + &program_size); + if (status != Error::Ok) { + ET_LOG( + Error, "get_program_data failed: 0x%x", static_cast(status)); + return 1; + } + + BufferDataLoader loader(program_data, program_size); + Result program = Program::load(&loader); + if (!program.ok()) { + ET_LOG( + Error, + "Program::load failed: 0x%x", + static_cast(program.error())); + return 1; + } + + // The harness always exports a single "forward" method. Skipping the + // Result deref of program->get_method_name(0) sidesteps a + // codegen wedge we hit under -mcmodel=medany + picolibc. + const char* method_name = "forward"; + ET_LOG(Info, "Using method %s", method_name); + + Result method_meta = program->method_meta(method_name); + if (!method_meta.ok()) { + ET_LOG( + Error, + "method_meta failed: 0x%x", + static_cast(method_meta.error())); + return 1; + } + + MemoryAllocator method_allocator( + sizeof(method_allocator_pool), method_allocator_pool); + MemoryAllocator temp_allocator( + sizeof(temp_allocator_pool), temp_allocator_pool); + + // One span per planned buffer, bumped through a single .bss arena so we + // don't need a heap. kMaxPlannedBuffers / pool size both grow with bigger + // models; failures here are loud rather than silent. + Span planned_spans[kMaxPlannedBuffers]; + size_t num_planned = method_meta->num_memory_planned_buffers(); + if (num_planned > kMaxPlannedBuffers) { + ET_LOG( + Error, + "num_planned=%zu exceeds kMaxPlannedBuffers=%zu", + num_planned, + kMaxPlannedBuffers); + return 1; + } + size_t offset = 0; + for (size_t id = 0; id < num_planned; ++id) { + size_t sz = + static_cast(method_meta->memory_planned_buffer_size(id).get()); + sz = (sz + 15u) & ~15u; + if (offset + sz > sizeof(planned_memory_pool)) { + ET_LOG( + Error, + "planned buffer %zu (size %zu) overflows pool (%zu/%zu)", + id, + sz, + offset, + sizeof(planned_memory_pool)); + return 1; + } + planned_spans[id] = Span(planned_memory_pool + offset, sz); + offset += sz; + } + HierarchicalAllocator planned_memory( + Span>(planned_spans, num_planned)); + MemoryManager memory_manager( + &method_allocator, &planned_memory, &temp_allocator); + + Result method = program->load_method(method_name, &memory_manager); + if (!method.ok()) { + ET_LOG( + Error, + "load_method failed: 0x%x", + static_cast(method.error())); + return 1; + } + + // load_bundled_input returns InvalidArgument past the last testset; that's + // how we detect the loop terminator (the bundle has no public count API). + int rc = 0; + for (size_t testset_idx = 0;; ++testset_idx) { + Error load = executorch::bundled_program::load_bundled_input( + *method, const_cast(model_pte), testset_idx); + if (load != Error::Ok) { + if (testset_idx == 0) { + ET_LOG( + Error, + "load_bundled_input failed for testset 0: 0x%x", + static_cast(load)); + rc = 1; + } + break; + } + Error exec = method->execute(); + if (exec != Error::Ok) { + ET_LOG( + Error, + "execute failed for testset %zu: 0x%x", + testset_idx, + static_cast(exec)); + ET_LOG(Error, "TEST: BundleIO index[%zu] Test_result: FAIL", testset_idx); + rc = 1; + continue; + } + Error verify = executorch::bundled_program::verify_method_outputs( + *method, const_cast(model_pte), testset_idx, kRtol, kAtol); + if (verify == Error::Ok) { + ET_LOG(Info, "TEST: BundleIO index[%zu] Test_result: PASS", testset_idx); + } else { + ET_LOG( + Error, + "verify_method_outputs failed for testset %zu: 0x%x", + testset_idx, + static_cast(verify)); + ET_LOG(Error, "TEST: BundleIO index[%zu] Test_result: FAIL", testset_idx); + rc = 1; + } + } + + return rc; +} diff --git a/examples/riscv/baremetal/riscv_virt.ld b/examples/riscv/baremetal/riscv_virt.ld new file mode 100644 index 00000000000..34980116b1d --- /dev/null +++ b/examples/riscv/baremetal/riscv_virt.ld @@ -0,0 +1,85 @@ +/* + * Copyright 2026 The ExecuTorch Authors. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* qemu-system-riscv{32,64} -machine virt -bios none -kernel: the virt board's + * reset stub at 0x1000 jumps to DRAM base 0x80000000, so _start has to live + * there. RAM size matches the qemu `-m 512M` we pass from run.sh — the + * embedded .bpte in .rodata can be tens of MB for mv2 / llama2 / yolo26. */ + +OUTPUT_ARCH(riscv) +ENTRY(_start) + +MEMORY +{ + RAM (rwx) : ORIGIN = 0x80000000, LENGTH = 512M +} + +SECTIONS +{ + .text 0x80000000 : + { + KEEP(*(.text.boot)) + *(.text .text.*) + } > RAM + + .rodata : ALIGN(8) + { + *(.rodata .rodata.*) + *(.srodata .srodata.*) + } > RAM + + /* C++ global ctors. start.S calls picolibc's __libc_init_array, which + * walks symbols __bothinit_array_start..__bothinit_array_end (preinit + + * init combined). The stock newlib names (__init_array_start/end) are + * defined too for portability, but it's the "both" pair picolibc reads. */ + .bothinit_array : ALIGN(8) + { + PROVIDE_HIDDEN(__bothinit_array_start = .); + PROVIDE_HIDDEN(__preinit_array_start = .); + KEEP(*(.preinit_array)) + PROVIDE_HIDDEN(__preinit_array_end = .); + PROVIDE_HIDDEN(__init_array_start = .); + KEEP(*(SORT_BY_INIT_PRIORITY(.init_array.*) SORT_BY_INIT_PRIORITY(.ctors.*))) + KEEP(*(.init_array EXCLUDE_FILE(*crtbegin.o *crtbegin?.o *crtend.o *crtend?.o) .ctors)) + PROVIDE_HIDDEN(__init_array_end = .); + PROVIDE_HIDDEN(__bothinit_array_end = .); + } > RAM + .fini_array : ALIGN(8) + { + PROVIDE_HIDDEN(__fini_array_start = .); + KEEP(*(SORT_BY_INIT_PRIORITY(.fini_array.*) SORT_BY_INIT_PRIORITY(.dtors.*))) + KEEP(*(.fini_array EXCLUDE_FILE(*crtbegin.o *crtbegin?.o *crtend.o *crtend?.o) .dtors)) + PROVIDE_HIDDEN(__fini_array_end = .); + } > RAM + + .data : ALIGN(8) + { + *(.data .data.*) + *(.sdata .sdata.*) + } > RAM + + .bss : ALIGN(8) + { + _bss_start = .; + *(.bss .bss.*) + *(.sbss .sbss.*) + *(COMMON) + . = ALIGN(8); + _bss_end = .; + } > RAM + + /* 2 MiB stack at the high end of RAM; grows downward. picolibc's sbrk + * looks up __heap_start / __heap_end (double-underscore). */ + . = ALIGN(16); + PROVIDE(__heap_start = .); + . = ORIGIN(RAM) + LENGTH(RAM) - 2M; + PROVIDE(__heap_end = .); + . = . + 2M; + _stack_top = .; + + /DISCARD/ : { *(.note.* .comment .eh_frame .riscv.attributes) } +} diff --git a/examples/riscv/baremetal/semihosting.h b/examples/riscv/baremetal/semihosting.h new file mode 100644 index 00000000000..7af63048d29 --- /dev/null +++ b/examples/riscv/baremetal/semihosting.h @@ -0,0 +1,51 @@ +/* + * Copyright 2026 The ExecuTorch Authors. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +namespace executorch { +namespace riscv { +namespace baremetal { + +// The RISC-V semihosting trigger is a fixed three-insn sequence (slli/ebreak/ +// srai of x0) so qemu can distinguish it from a normal ecall. Op number in +// a0, arg pointer in a1, return value back in a0. +inline long semihost_call(long op, const void* arg) { + register long a0 asm("a0") = op; + register long a1 asm("a1") = (long)arg; + asm volatile( + ".option push\n\t" + ".option norvc\n\t" + "slli x0, x0, 0x1f\n\t" + "ebreak\n\t" + "srai x0, x0, 0x7\n\t" + ".option pop" + : "+r"(a0) + : "r"(a1) + : "memory"); + return a0; +} + +constexpr long SYS_WRITE0 = 0x04; +constexpr long SYS_EXIT_EXTENDED = 0x20; + +inline void semihost_write0(const char* s) { + semihost_call(SYS_WRITE0, s); +} + +[[noreturn]] inline void semihost_exit(int status) { + // ADP_Stopped_ApplicationExit (0x20026) + status, per the semihosting spec. + long block[2] = {0x20026, (long)status}; + semihost_call(SYS_EXIT_EXTENDED, block); + __builtin_trap(); +} + +} // namespace baremetal +} // namespace riscv +} // namespace executorch diff --git a/examples/riscv/baremetal/start.S b/examples/riscv/baremetal/start.S new file mode 100644 index 00000000000..092eeffa4a6 --- /dev/null +++ b/examples/riscv/baremetal/start.S @@ -0,0 +1,49 @@ +/* + * Copyright 2026 The ExecuTorch Authors. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Boot stub for the qemu virt RISC-V baremetal runner: set sp, enable FPU, +// zero .bss, run C++ static ctors via __libc_init_array, jump to main. On +// return, call baremetal_exit so qemu terminates deterministically. + +#if __riscv_xlen == 64 +#define SX sd +#define XLEN_BYTES 8 +#else +#define SX sw +#define XLEN_BYTES 4 +#endif + + .section .text.boot, "ax" + .globl _start + .type _start, @function +_start: + la sp, _stack_top + + // mstatus.FS resets to Off in M-mode, so any FP insn (libstdc++ template + // code emits fsd/fld) traps. We have no trap vector, so the CPU would + // loop on the fault. FS=Dirty (0b11 in bits 13-14) keeps the FPU live. + li t0, 0x6000 + csrs mstatus, t0 + + la a0, _bss_start + la a1, _bss_end +1: + bgeu a0, a1, 2f + SX zero, 0(a0) + addi a0, a0, XLEN_BYTES + j 1b +2: + call __libc_init_array + li a0, 0 + li a1, 0 + call main + call baremetal_exit +3: + wfi + j 3b + + .size _start, .-_start diff --git a/examples/riscv/requirements.txt b/examples/riscv/requirements.txt index 273e7156a1d..649696ae65c 100644 --- a/examples/riscv/requirements.txt +++ b/examples/riscv/requirements.txt @@ -1,2 +1,3 @@ torchvision transformers +ultralytics diff --git a/examples/riscv/riscv32-unknown-elf-toolchain.cmake b/examples/riscv/riscv32-unknown-elf-toolchain.cmake new file mode 100644 index 00000000000..ae968ea6fe2 --- /dev/null +++ b/examples/riscv/riscv32-unknown-elf-toolchain.cmake @@ -0,0 +1,74 @@ +# Copyright 2026 The ExecuTorch Authors. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# rv32 baremetal cross-toolchain. Uses the multilib-aware riscv64-unknown-elf +# gcc (one package, both XLENs); `-march=rv32...` + `-mabi=ilp32d` selects the +# 32-bit picolibc + libstdc++ variant. ELF runs under qemu-system-riscv32 +# -machine virt with semihosting. + +set(CMAKE_SYSTEM_NAME Generic) +set(CMAKE_SYSTEM_PROCESSOR riscv32) + +set(CMAKE_C_COMPILER + "riscv64-unknown-elf-gcc" + CACHE FILEPATH "" +) +set(CMAKE_CXX_COMPILER + "riscv64-unknown-elf-g++" + CACHE FILEPATH "" +) +set(CMAKE_ASM_COMPILER + "riscv64-unknown-elf-gcc" + CACHE FILEPATH "" +) +set(CMAKE_AR + "riscv64-unknown-elf-ar" + CACHE FILEPATH "" +) +set(CMAKE_RANLIB + "riscv64-unknown-elf-ranlib" + CACHE FILEPATH "" +) +set(CMAKE_STRIP + "riscv64-unknown-elf-strip" + CACHE FILEPATH "" +) + +set(CMAKE_EXECUTABLE_SUFFIX ".elf") +# try_compile() can't link without crt0/specs; archive-only sidesteps that. +set(CMAKE_TRY_COMPILE_TARGET_TYPE STATIC_LIBRARY) +set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER) +set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY ONLY) +set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE ONLY) +set(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE ONLY) + +set(CMAKE_C_STANDARD 11) +set(CMAKE_CXX_STANDARD 17) + +# Baseline rv32imafdc / ilp32d — the rv32gc-equivalent multilib Ubuntu's +# picolibc + libstdc++ ship. (Unlike rv64, the full rv32gc multilib *is* +# packaged, so we don't have to drop M / C here.) -mcmodel=medany because medlow +# can't reach our 0x80000000 base. picolibc.specs must be on the compile line +# too so libstdc++ headers find picolibc's C headers via the spec's sysroot. +add_compile_options( + --specs=picolibc.specs + -march=rv32imafdc + -mabi=ilp32d + -mcmodel=medany + -fdata-sections + -ffunction-sections + "$<$:-fno-rtti;-fno-exceptions;-fno-unwind-tables>" +) +# -nostdlib++ drops g++'s implicit libstdc++.a (medlow-built, won't relocate). +# -nostartfiles drops picolibc's crt0 in favour of our start.S. +add_link_options( + --specs=picolibc.specs + -march=rv32imafdc + -mabi=ilp32d + -mcmodel=medany + -nostdlib++ + -nostartfiles + "LINKER:--gc-sections" +) diff --git a/examples/riscv/riscv64-unknown-elf-toolchain.cmake b/examples/riscv/riscv64-unknown-elf-toolchain.cmake new file mode 100644 index 00000000000..a4533675f89 --- /dev/null +++ b/examples/riscv/riscv64-unknown-elf-toolchain.cmake @@ -0,0 +1,77 @@ +# Copyright 2026 The ExecuTorch Authors. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# rv64 baremetal cross-toolchain (Ubuntu 26.04+ packages: +# gcc-riscv64-unknown-elf, picolibc-riscv64-unknown-elf, +# libstdc++-riscv64-unknown-elf-picolibc). The resulting ELF runs under +# qemu-system-riscv64 -machine virt with semihosting. + +set(CMAKE_SYSTEM_NAME Generic) +set(CMAKE_SYSTEM_PROCESSOR riscv64) + +set(CMAKE_C_COMPILER + "riscv64-unknown-elf-gcc" + CACHE FILEPATH "" +) +set(CMAKE_CXX_COMPILER + "riscv64-unknown-elf-g++" + CACHE FILEPATH "" +) +set(CMAKE_ASM_COMPILER + "riscv64-unknown-elf-gcc" + CACHE FILEPATH "" +) +set(CMAKE_AR + "riscv64-unknown-elf-ar" + CACHE FILEPATH "" +) +set(CMAKE_RANLIB + "riscv64-unknown-elf-ranlib" + CACHE FILEPATH "" +) +set(CMAKE_STRIP + "riscv64-unknown-elf-strip" + CACHE FILEPATH "" +) + +set(CMAKE_EXECUTABLE_SUFFIX ".elf") +# try_compile() can't link without crt0/specs; archive-only sidesteps that. +set(CMAKE_TRY_COMPILE_TARGET_TYPE STATIC_LIBRARY) +set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER) +set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY ONLY) +set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE ONLY) +set(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE ONLY) + +set(CMAKE_C_STANDARD 11) +set(CMAKE_CXX_STANDARD 17) + +# Picked baseline: rv64iafd / lp64d. Ubuntu's picolibc + libstdc++ packages +# don't ship the rv64gc (= rv64imafdc) multilib, so this drops M (integer mul) +# and C (compressed) but keeps double-float. -mcmodel=medany because medlow's +# signed-32-bit-around-0 reach can't address our 0x80000000 base. +# --specs=picolibc.specs has to appear at *compile* time too: libstdc++'s +# // need picolibc's C headers via the spec's +# sysroot. +add_compile_options( + --specs=picolibc.specs + -march=rv64iafd + -mabi=lp64d + -mcmodel=medany + -fdata-sections + -ffunction-sections + "$<$:-fno-rtti;-fno-exceptions;-fno-unwind-tables>" +) +# -nostdlib++ drops g++'s implicit libstdc++.a (medlow-built, won't relocate at +# 0x80000000); we only use its templates, no runtime calls. -nostartfiles drops +# picolibc's crt0 in favour of our start.S. +add_link_options( + --specs=picolibc.specs + -march=rv64iafd + -mabi=lp64d + -mcmodel=medany + -nostdlib++ + -nostartfiles + "LINKER:--gc-sections" +) diff --git a/examples/riscv/run.sh b/examples/riscv/run.sh index 2c207816bfc..0635bfedb4e 100755 --- a/examples/riscv/run.sh +++ b/examples/riscv/run.sh @@ -4,42 +4,52 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# RISC-V Phase 1 smoke test driver (pytorch/executorch#18991): -# 1. Export a tiny model to a BundledProgram (.bpte) on the x86_64 host. -# 2. Cross-compile executor_runner for riscv64 Linux glibc. -# 3. Invoke the runner under qemu-user-static and grep its stdout for the -# Test_result: PASS marker emitted by the bundled-IO comparison path. +# RISC-V smoke test driver: +# 1. Export a small model to a BundledProgram (.bpte) on the host. +# 2. Cross-compile a riscv32/64 runner (linux glibc or baremetal). +# 3. Invoke under qemu and grep stdout for the Test_result: PASS marker. -set -eu +set -euo pipefail script_dir=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd) et_root_dir=$(realpath "${script_dir}/../..") build_only=false -build_dir="${et_root_dir}/cmake-out-riscv" -output_dir="${et_root_dir}/riscv_test" -qemu="qemu-riscv64-static" -qemu_timeout="600" +build_dir= +qemu_timeout="1800" model="add" -xnnpack=false +backend="portable" +os="linux" +arch="rv64" +qemu_cpu_ext="" quantize=false debug_xnnpack=false verbose_xnnpack=false +qemu_override="" usage() { cat < Which model to export and run (default: ${model}) - --xnnpack Enable the XNNPACK backend (AOT partitioner + runtime) --quantize Produce an 8-bit quantized model - --verbose-xnnpack Build XNNPACK with XNN_LOG_LEVEL=4 to log microkernel dispatch at runtime + --backend= AOT backend (default: ${backend}): + - 'portable': portable kernels only + - 'xnnpack': XNNPACK delegate (linux only) + --os= Target OS (default: ${os}): + - 'linux': glibc, qemu-user + - 'baremetal': no OS, qemu-system + semihosting + --arch= Target arch (default: ${arch}): + - 'rv64': riscv64 + - 'rv32': riscv32 + --qemu-cpu-ext= QEMU -cpu extensions appended after the arch base + (e.g. 'v=true,vlen=128'); no rv32/rv64 prefix. + --verbose-xnnpack Build XNNPACK with XNN_LOG_LEVEL=4 to log microkernel dispatch --debug-xnnpack Enable XNNPACK partitioner DEBUG logging and dump the lowered graph --build_only Only export and cross-compile; do not invoke QEMU - --build_dir= CMake build directory (default: ${build_dir}) - --output_dir= Directory for the exported .bpte (default: ${output_dir}) - --qemu= qemu-user binary (default: ${qemu}) - --timeout= Maximum QEMU runtime; matches run_fvp.sh --timelimit (default: ${qemu_timeout}) + --build-dir= Build/output directory for this configuration (required) + --qemu= Override qemu binary + --timeout= Maximum QEMU runtime (default: ${qemu_timeout}) -h, --help Show this help EOF } @@ -47,51 +57,125 @@ EOF for arg in "$@"; do case $arg in --model=*) model="${arg#*=}" ;; - --xnnpack) xnnpack=true ;; --quantize) quantize=true ;; + --backend=*) backend="${arg#*=}" ;; + --os=*) os="${arg#*=}" ;; + --arch=*) arch="${arg#*=}" ;; + --qemu-cpu-ext=*) qemu_cpu_ext="${arg#*=}" ;; --debug-xnnpack) debug_xnnpack=true ;; --verbose-xnnpack) verbose_xnnpack=true ;; --build_only) build_only=true ;; - --build_dir=*) build_dir="${arg#*=}" ;; - --output_dir=*) output_dir="${arg#*=}" ;; - --qemu=*) qemu="${arg#*=}" ;; + --build-dir=*) build_dir="${arg#*=}" ;; + --qemu=*) qemu_override="${arg#*=}" ;; --timeout=*) qemu_timeout="${arg#*=}" ;; -h|--help) usage; exit 0 ;; *) echo "Unknown option: $arg" >&2; usage; exit 1 ;; esac done -mkdir -p "${output_dir}" -bpte_path="${output_dir}/${model}_riscv.bpte" +case "${backend}" in + portable|xnnpack) ;; + *) echo "Unknown backend: ${backend}" >&2; usage; exit 1 ;; +esac +case "${os}" in + linux|baremetal) ;; + *) echo "Unknown os: ${os}" >&2; usage; exit 1 ;; +esac +case "${arch}" in + rv32|rv64) ;; + *) echo "Unknown arch: ${arch}" >&2; usage; exit 1 ;; +esac -echo "[run.sh] Step 1/3: AOT export on host" -aot_extra_args=() -if ${xnnpack}; then - aot_extra_args+=(--xnnpack) +# xnnpack needs pthreads + dynamic loading: baremetal has neither, and the +# Ubuntu xnnpack microkernels don't ship an rv32 build. +if [[ "${backend}" == "xnnpack" && "${os}" == "baremetal" ]]; then + echo "[run.sh] --backend=xnnpack requires --os=linux" >&2 + exit 1 +fi +if [[ "${backend}" == "xnnpack" && "${arch}" == "rv32" ]]; then + echo "[run.sh] --backend=xnnpack requires --arch=rv64" >&2 + exit 1 +fi +# Ubuntu doesn't package a riscv32-linux-gnu cross (riscv64-linux-gnu has no +# rv32 multilib either), so rv32 linux is blocked on a custom toolchain build. +if [[ "${arch}" == "rv32" && "${os}" == "linux" ]]; then + echo "[run.sh] --arch=rv32 --os=linux not supported: no riscv32-linux-gnu toolchain on Ubuntu" >&2 + exit 1 +fi + +if ${debug_xnnpack} && [[ "${backend}" != "xnnpack" ]]; then + echo "[run.sh] --debug-xnnpack requires --backend=xnnpack" >&2 + exit 1 fi +if ${verbose_xnnpack} && [[ "${backend}" != "xnnpack" ]]; then + echo "[run.sh] --verbose-xnnpack requires --backend=xnnpack" >&2 + exit 1 +fi + +if [[ -z "${build_dir}" ]]; then + echo "[run.sh] --build-dir is required" >&2; usage; exit 1 +fi +mkdir -p "${build_dir}" + +bpte_path="${build_dir}/model.bpte" + +echo "[run.sh] Step 1/3: AOT export on host (backend=${backend} os=${os} arch=${arch})" +aot_extra_args=() if ${quantize}; then aot_extra_args+=(--quantize) fi if ${debug_xnnpack}; then aot_extra_args+=(--debug-xnnpack) fi -python "${script_dir}/aot_riscv.py" --model "${model}" "${aot_extra_args[@]}" --output "${bpte_path}" +python "${script_dir}/aot_riscv.py" --model "${model}" --backend "${backend}" --os "${os}" "${aot_extra_args[@]}" --output "${bpte_path}" -echo "[run.sh] Step 2/3: cross-compile executor_runner for riscv64-linux" +echo "[run.sh] Step 2/3: cross-compile executor_runner for ${arch}-${os}" cmake_extra_args=() -if ${xnnpack}; then +if [[ "${backend}" == "xnnpack" ]]; then cmake_extra_args+=(-DEXECUTORCH_BUILD_XNNPACK=ON) fi if ${verbose_xnnpack}; then cmake_extra_args+=(-DEXECUTORCH_XNNPACK_LOG_LEVEL=4 -DEXECUTORCH_BUILD_RISCV_ETDUMP=ON) fi -cmake -S "${et_root_dir}" -B "${build_dir}" \ - --preset riscv64-linux \ - "${cmake_extra_args[@]}" \ - -DCMAKE_BUILD_TYPE=Release -cmake --build "${build_dir}" -j"$(nproc)" --target executor_runner -runner="${build_dir}/executor_runner" +# Map our short arch (rv32/rv64) to the canonical riscv32/riscv64 prefix used +# by the cross toolchain and qemu binary names. +case "${arch}" in + rv32) arch_long="riscv32" ;; + rv64) arch_long="riscv64" ;; +esac + +if [[ "${os}" == "linux" ]]; then + build_target="executor_runner" + qemu_default="qemu-${arch_long}-static" + cmake -S "${et_root_dir}" -B "${build_dir}" --fresh \ + --preset "${arch_long}-linux" \ + "${cmake_extra_args[@]}" \ + -DCMAKE_BUILD_TYPE=Release + cmake --build "${build_dir}" -j"$(nproc)" --target "${build_target}" + runner="${build_dir}/${build_target}" + +elif [[ "${os}" == "baremetal" ]]; then + build_target="executor_runner_baremetal" + qemu_default="qemu-system-${arch_long}" + # Standalone build (mirrors examples/arm/executor_runner/standalone) + cmake -S "${et_root_dir}/examples/riscv/baremetal" -B "${build_dir}" --fresh \ + -DCMAKE_TOOLCHAIN_FILE=${et_root_dir}/examples/riscv/${arch_long}-unknown-elf-toolchain.cmake \ + -DEXECUTORCH_BUILD_PRESET_FILE=${et_root_dir}/tools/cmake/preset/riscv_baremetal.cmake \ + -DEXECUTORCH_ROOT="${et_root_dir}" \ + -DRISCV_BAREMETAL_PTE="${bpte_path}" \ + "${cmake_extra_args[@]}" \ + -DCMAKE_BUILD_TYPE=Release + cmake --build "${build_dir}" -j"$(nproc)" --target "${build_target}" + runner="${build_dir}/${build_target}.elf" + +else + echo "Unknown os: ${os}" >&2 + usage + exit 1 +fi + +qemu="${qemu_override:-${qemu_default}}" [[ -x "${runner}" ]] || { echo "[run.sh] runner not found at ${runner}" >&2; exit 1; } if file "${runner}" | grep -q "RISC-V"; then @@ -109,49 +193,79 @@ fi echo "[run.sh] Step 3/3: run under ${qemu}" hash "${qemu}" 2>/dev/null || { - echo "[run.sh] ERROR: ${qemu} not found on PATH; install with examples/riscv/setup.sh" >&2 + echo "[run.sh] ERROR: ${qemu} not found on PATH; install with examples/riscv/setup-${os}.sh" >&2 exit 1 } -# QEMU_LD_PREFIX points qemu-user at the riscv64 sysroot so the dynamic -# linker (ld-linux-riscv64-lp64d.so.1) referenced in the ELF resolves. -export QEMU_LD_PREFIX="${QEMU_LD_PREFIX:-/usr/riscv64-linux-gnu}" +log_file="${build_dir}/run.log" +rm -f "${log_file}" -if [[ -n "${QEMU_CPU+x}" ]]; then - echo "[run.sh] QEMU_CPU=${QEMU_CPU}" +# Compose the QEMU -cpu value once: ${arch} alone, or ${arch},${ext} when an +# extension list was supplied. qemu-user reads $QEMU_CPU; qemu-system takes +# -cpu on the command line. +qemu_cpu="${arch}" +if [[ -n "${qemu_cpu_ext}" ]]; then + qemu_cpu="${arch},${qemu_cpu_ext}" fi +echo "[run.sh] qemu -cpu = ${qemu_cpu}" -runner_extra_args=() -if ${quantize}; then - runner_extra_args+=(--bundleio_rtol=0.1 --bundleio_atol=0.25) -fi -etdump_path="" -if ${verbose_xnnpack}; then - etdump_path="${output_dir}/${model}_riscv.etdump" - rm -f "${etdump_path}" - runner_extra_args+=(--etdump_path="${etdump_path}") -fi +if [[ "${os}" == "linux" ]]; then + # QEMU_LD_PREFIX points qemu-user at the cross sysroot so the dynamic + # linker (ld-linux-riscv*) referenced in the ELF resolves. + if [[ "${arch}" == "rv64" ]]; then + export QEMU_LD_PREFIX="${QEMU_LD_PREFIX:-/usr/riscv64-linux-gnu}" + else + export QEMU_LD_PREFIX="${QEMU_LD_PREFIX:-/usr/riscv32-linux-gnu}" + fi + export QEMU_CPU="${qemu_cpu}" -# etdump_summary.py reads the XNN_LOG_LEVEL=4 registrations. -log_file="${output_dir}/${model}_riscv.run.log" -rm -f "${log_file}" + runner_extra_args=() + if ${quantize}; then + runner_extra_args+=(--bundleio_rtol=0.1 --bundleio_atol=0.25) + fi + etdump_path="" + if ${verbose_xnnpack}; then + etdump_path="${build_dir}/run.etdump" + rm -f "${etdump_path}" + runner_extra_args+=(--etdump_path="${etdump_path}") + fi -set +e -timeout --signal=KILL "${qemu_timeout}" "${qemu}" "${runner}" \ - --model_path="${bpte_path}" \ - "${runner_extra_args[@]}" \ - 2>&1 | tee "${log_file}" -qemu_status=${PIPESTATUS[0]} -set -e + set +e + timeout --signal=KILL "${qemu_timeout}" "${qemu}" "${runner}" \ + --model_path="${bpte_path}" \ + "${runner_extra_args[@]}" \ + |& tee "${log_file}" + qemu_status=${PIPESTATUS[0]} + set -e -echo "[run.sh] qemu exit status: ${qemu_status}" + if [[ -n "${etdump_path}" && -f "${etdump_path}" ]]; then + python "${script_dir}/etdump_summary.py" "${etdump_path}" \ + --run-log "${log_file}" \ + --json "${etdump_path}.json" || true + fi + +elif [[ "${os}" == "baremetal" ]]; then + # qemu-system -machine virt boots at 0x80000000; -bios none skips OpenSBI; + # semihosting target=native routes SYS_WRITE0/SYS_EXIT to host stdio. + # For deeper debugging, add: -accel tcg,one-insn-per-tb=on -d in_asm,nochain + # -D + set +e + timeout --signal=KILL "${qemu_timeout}" "${qemu}" \ + -machine virt -cpu "${qemu_cpu}" -m 512M -nographic -bios none \ + -semihosting-config enable=on,target=native \ + -kernel "${runner}" \ + |& tee "${log_file}" + qemu_status=${PIPESTATUS[0]} + set -e -if [[ -n "${etdump_path}" && -f "${etdump_path}" ]]; then - python "${script_dir}/etdump_summary.py" "${etdump_path}" \ - --run-log "${log_file}" \ - --json "${etdump_path}.json" || true +else + echo "Unknown os: ${os}" >&2 + usage + exit 1 fi +echo "[run.sh] qemu exit status: ${qemu_status}" + if grep -q "Test_result: PASS" "${log_file}"; then echo "[run.sh] Bundled I/O check PASSED" exit 0 diff --git a/examples/riscv/setup-baremetal.sh b/examples/riscv/setup-baremetal.sh new file mode 100755 index 00000000000..f96e8c75032 --- /dev/null +++ b/examples/riscv/setup-baremetal.sh @@ -0,0 +1,65 @@ +#!/usr/bin/env bash +# Copyright 2026 The ExecuTorch Authors. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# Host tooling for the RISC-V smoke tests. Targets Ubuntu 26.04: that's where +# libstdc++-riscv64-unknown-elf-picolibc was first packaged, and the baremetal +# build chain needs C++ stdlib headers paired with picolibc. + +set -euo pipefail + +script_dir=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd) + +if ! command -v apt-get >/dev/null 2>&1; then + echo "[$(basename "$0")] this setup script targets Debian/Ubuntu (apt-get not found)" >&2 + exit 1 +fi + +SUDO="" +if [[ $EUID -ne 0 ]]; then + SUDO="sudo" +fi + +source /etc/os-release + +GCC_VERSION="" +if [[ "${VERSION_ID:-}" == "24.04" || "${VERSION_ID:-}" == "26.04" ]]; then + GCC_VERSION="14" +fi + +${SUDO} apt-get update +${SUDO} apt-get install -y --no-install-recommends \ + build-essential \ + gcc${GCC_VERSION:+-${GCC_VERSION}} \ + g++${GCC_VERSION:+-${GCC_VERSION}} \ + gcc${GCC_VERSION:+-${GCC_VERSION}}-riscv64-linux-gnu \ + g++${GCC_VERSION:+-${GCC_VERSION}}-riscv64-linux-gnu \ + binutils-riscv64-linux-gnu \ + libc6-riscv64-cross \ + libc6-dev-riscv64-cross \ + gcc-riscv64-unknown-elf \ + picolibc-riscv64-unknown-elf \ + libstdc++-riscv64-unknown-elf-picolibc \ + cmake \ + file \ + ca-certificates \ + qemu-user \ + qemu-system-riscv \ + libglib2.0-0t64 \ + libxcb1 \ + libgl1 + +if [[ -n "${GCC_VERSION+x}" ]]; then + ${SUDO} update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc${GCC_VERSION:+-${GCC_VERSION}} 100 + ${SUDO} update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++${GCC_VERSION:+-${GCC_VERSION}} 100 + ${SUDO} update-alternatives --install /usr/bin/riscv64-linux-gnu-gcc riscv64-linux-gnu-gcc /usr/bin/riscv64-linux-gnu-gcc${GCC_VERSION:+-${GCC_VERSION}} 100 + ${SUDO} update-alternatives --install /usr/bin/riscv64-linux-gnu-g++ riscv64-linux-gnu-g++ /usr/bin/riscv64-linux-gnu-g++${GCC_VERSION:+-${GCC_VERSION}} 100 +fi + +riscv64-linux-gnu-gcc --version | head -n1 +qemu-riscv64 --version | head -n1 + +# Some python packages also need to be installed +pip install -r "${script_dir}/requirements.txt" diff --git a/examples/riscv/setup.sh b/examples/riscv/setup-linux.sh similarity index 74% rename from examples/riscv/setup.sh rename to examples/riscv/setup-linux.sh index 955c8ca3386..912557e3bfb 100755 --- a/examples/riscv/setup.sh +++ b/examples/riscv/setup-linux.sh @@ -8,7 +8,7 @@ # - gcc/g++/binutils for riscv64-linux-gnu (cross-compiler + sysroot) # - qemu-user-static (qemu-riscv64 user-mode emulator) -set -eu +set -euo pipefail script_dir=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd) @@ -22,9 +22,18 @@ if [[ $EUID -ne 0 ]]; then SUDO="sudo" fi +source /etc/os-release + +GCC_VERSION="" +if [[ "${VERSION_ID:-}" == "24.04" || "${VERSION_ID:-}" == "26.04" ]]; then + GCC_VERSION="14" +fi + ${SUDO} apt-get update ${SUDO} apt-get install -y --no-install-recommends \ build-essential \ + gcc${GCC_VERSION:+-${GCC_VERSION}} \ + g++${GCC_VERSION:+-${GCC_VERSION}} \ gcc${GCC_VERSION:+-${GCC_VERSION}}-riscv64-linux-gnu \ g++${GCC_VERSION:+-${GCC_VERSION}}-riscv64-linux-gnu \ binutils-riscv64-linux-gnu \ @@ -33,9 +42,14 @@ ${SUDO} apt-get install -y --no-install-recommends \ cmake \ file \ ca-certificates \ - qemu-user-static + qemu-user-static \ + libglib2.0-0t64 \ + libxcb1 \ + libgl1 if [[ -n "${GCC_VERSION+x}" ]]; then + ${SUDO} update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc${GCC_VERSION:+-${GCC_VERSION}} 100 + ${SUDO} update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++${GCC_VERSION:+-${GCC_VERSION}} 100 ${SUDO} update-alternatives --install /usr/bin/riscv64-linux-gnu-gcc riscv64-linux-gnu-gcc /usr/bin/riscv64-linux-gnu-gcc${GCC_VERSION:+-${GCC_VERSION}} 100 ${SUDO} update-alternatives --install /usr/bin/riscv64-linux-gnu-g++ riscv64-linux-gnu-g++ /usr/bin/riscv64-linux-gnu-g++${GCC_VERSION:+-${GCC_VERSION}} 100 fi diff --git a/examples/riscv/test-matrix.sh b/examples/riscv/test-matrix.sh new file mode 100644 index 00000000000..9ed8115de44 --- /dev/null +++ b/examples/riscv/test-matrix.sh @@ -0,0 +1,241 @@ +#!/usr/bin/env bash +# Copyright 2026 The ExecuTorch Authors. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# Local mirror of riscv64.yml's matrix using two docker containers: +# +# - executorch-riscv-linux (ubuntu:24.04 + gcc-14). +# - executorch-riscv-baremetal (ubuntu:26.04 + gcc-15). +# 26.04 is the only release shipping libstdc++-riscv64-unknown-elf-picolibc. +# +# Usage: +# examples/riscv/test-matrix.sh # full sweep +# examples/riscv/test-matrix.sh --model=mv2 # one model, all configs +# examples/riscv/test-matrix.sh --os=baremetal # one OS +# examples/riscv/test-matrix.sh --quantize-only # skip the no-q half +# examples/riscv/test-matrix.sh --setup-only # bootstrap containers, don't run +# +# Re-runs are cheap when the per-cell build dirs survive (set --keep-build). + +set -euo pipefail + +script_dir=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd) +et_root_dir=$(realpath "${script_dir}/../..") + +model_filter="" +os_filter="" +arch_filter="" +variant_filter="" +backend_filter="" +quantize_filter="" +setup_only=false +keep_build=false + +usage() { + cat < Only run cells for this model + --os= + --arch= + --backend= + --variant= + --quantize= + --setup-only Make sure both containers are ready, then exit + --keep-build Reuse riscv_test/ dirs instead of starting fresh + -h, --help +EOF +} + +for arg in "$@"; do + case $arg in + --model=*) model_filter="${arg#*=}" ;; + --os=*) os_filter="${arg#*=}" ;; + --arch=*) arch_filter="${arg#*=}" ;; + --backend=*) backend_filter="${arg#*=}" ;; + --variant=*) variant_filter="${arg#*=}" ;; + --quantize=*) quantize_filter="${arg#*=}" ;; + --setup-only) setup_only=true ;; + --keep-build) keep_build=true ;; + -h|--help) usage; exit 0 ;; + *) echo "Unknown: $arg" >&2; usage; exit 1 ;; + esac +done + +# Container names + image tags match what the CI workflow consumes. +LINUX_CTR=executorch-riscv-linux +BAREMETAL_CTR=executorch-riscv-baremetal + +MODELS="add mv2 resnet18 mobilebert llama2 yolo26" +BACKENDS="portable xnnpack" + +# qemu-cpu-ext sweeps; keep parity with the JSON arrays in riscv64.yml. +SCALAR_EXT="v=false" +RVV128_EXT="v=true,vext_spec=v1.0,vlen=128" +RVV256_EXT="v=true,vext_spec=v1.0,vlen=256" +RVV512_EXT="v=true,vext_spec=v1.0,vlen=512" + +# Check if a cell combination should be excluded (matching riscv64.yml excludes) +should_exclude() { + local os=$1 arch=$2 backend=$3 variant=$4 model=$5 quantize=$6 + + # Disable quantization testing with Portable Kernels + if [[ "${backend}" == "portable" && "${quantize}" == "true" ]]; then + return 0 + fi + # XNNPACK needs pthreads + dynamic loading (no baremetal) + if [[ "${backend}" == "xnnpack" && "${os}" == "baremetal" ]]; then + return 0 + fi + # XNNPACK needs RVV + if [[ "${backend}" == "xnnpack" && "${variant}" == "scalar" ]]; then + return 0 + fi + # No quantization recipe for Yolo26 + if [[ "${model}" == "yolo26" && "${quantize}" == "true" ]]; then + return 0 + fi + # No riscv32-linux-gnu cross is packaged on Ubuntu + if [[ "${os}" == "linux" && "${arch}" == "rv32" ]]; then + return 0 + fi + + return 1 +} + +# ---- container bootstrap (idempotent) ------------------------------------- + +ensure_linux() { + if ! docker ps -a --format '{{.Names}}' | grep -qx "${LINUX_CTR}"; then + echo "[matrix] starting ${LINUX_CTR} (ubuntu:24.04)" + docker run -d --name "${LINUX_CTR}" \ + -e DEBIAN_FRONTEND=noninteractive \ + -v "${et_root_dir}":/executorch -w /executorch \ + ubuntu:24.04 sleep infinity >/dev/null + fi + docker start "${LINUX_CTR}" >/dev/null + if ! docker exec "${LINUX_CTR}" test -d /executorch/.venv-docker-linux; then + echo "[matrix] bootstrapping ${LINUX_CTR} (this takes a few minutes)" + docker exec "${LINUX_CTR}" bash -eu -c ' + set -e + apt-get update -qq && apt-get install -y -qq --no-install-recommends \ + python3 python3-pip ca-certificates sudo + python3 -m pip install --break-system-packages --quiet uv + uv python install 3.10 + cd /executorch + uv venv --python 3.10 --seed .venv-docker-linux + ' + fi + docker exec "${LINUX_CTR}" bash -eu -c ' + set -e + cd /executorch + source .venv-docker-linux/bin/activate + pip install --upgrade pip + pip install executorch + bash examples/riscv/setup-linux.sh + ' +} + +ensure_baremetal() { + if ! docker ps -a --format '{{.Names}}' | grep -qx "${BAREMETAL_CTR}"; then + echo "[matrix] starting ${BAREMETAL_CTR} (ubuntu:26.04)" + docker run -d --name "${BAREMETAL_CTR}" \ + -e DEBIAN_FRONTEND=noninteractive \ + -v "${et_root_dir}":/executorch -w /executorch \ + ubuntu:26.04 sleep infinity >/dev/null + fi + docker start "${BAREMETAL_CTR}" >/dev/null + if ! docker exec "${BAREMETAL_CTR}" test -d /executorch/.venv-docker-baremetal; then + echo "[matrix] bootstrapping ${BAREMETAL_CTR} (this takes a few minutes)" + docker exec "${BAREMETAL_CTR}" bash -eu -c ' + set -e + apt-get update -qq && apt-get install -y -qq --no-install-recommends \ + python3 python3-pip ca-certificates sudo + python3 -m pip install --break-system-packages --quiet uv + uv python install 3.10 + cd /executorch + uv venv --python 3.10 --seed .venv-docker-baremetal + ' + fi + docker exec "${BAREMETAL_CTR}" bash -eu -c ' + set -e + cd /executorch + source .venv-docker-baremetal/bin/activate + pip install --upgrade pip + pip install executorch + bash examples/riscv/setup-baremetal.sh + ' +} + +ensure_linux +ensure_baremetal +if ${setup_only}; then exit 0; fi + +# ---- one cell -------------------------------------------------------------- + +# Args: ctr venv os arch backend variant ext model quantize_flag +run_cell() { + local ctr=$1 venv=$2 os=$3 arch=$4 backend=$5 variant=$6 ext=$7 model=$8 q=$9 + local cell="${model}${q:++q}-${backend}/${os}-${arch}" + local model_q="${model}${q:+-q}" + local variant_slug="${ext//,/_}"; variant_slug="${variant_slug//=/_}"; variant_slug="${variant_slug:-base}" + local build_dir="/executorch/riscv_test/${model_q}/${backend}/${os}-${arch}-${variant_slug}" + if ! ${keep_build}; then + docker exec "${ctr}" rm -rf "${build_dir}" + fi + if docker exec "${ctr}" bash -lc " + cd /executorch && source ${venv}/bin/activate && + timeout 1800 bash -eu examples/riscv/run.sh \ + --model=${model} ${q} --backend=${backend} \ + --os=${os} --arch=${arch} \ + --qemu-cpu-ext='${ext}' \ + --build-dir=${build_dir} --timeout=900 + "; then + echo " PASS ${cell}" + return 0 + else + echo " FAIL ${cell}" + return 1 + fi +} + +# ---- iterate --------------------------------------------------------------- + +passed=0; total=0 +for m in ${MODELS}; do +for backend in ${BACKENDS}; do +for os_arch in "linux:rv64" "baremetal:rv64" "baremetal:rv32"; do +for variant_lbl in "scalar:${SCALAR_EXT}" "rvv128:${RVV128_EXT}" "rvv256:${RVV256_EXT}" "rvv512:${RVV512_EXT}"; do + os="${os_arch%%:*}"; arch="${os_arch##*:}"; variant="${variant_lbl%%:*}"; ext="${variant_lbl#*:}" + + if [[ -n "${model_filter}" && "${m}" != "${model_filter}" ]]; then continue; fi + if [[ -n "${backend_filter}" && "${backend}" != "${backend_filter}" ]]; then continue; fi + if [[ -n "${os_filter}" && "${os}" != "${os_filter}" ]]; then continue; fi + if [[ -n "${arch_filter}" && "${arch}" != "${arch_filter}" ]]; then continue; fi + if [[ -n "${variant_filter}" && "${variant}" != "${variant_filter}" ]]; then continue; fi + + if [[ "${os}" == "linux" ]]; then ctr="${LINUX_CTR}"; venv=/executorch/.venv-docker-linux; + else ctr="${BAREMETAL_CTR}"; venv=/executorch/.venv-docker-baremetal; fi + + if [[ -z "${quantize_filter}" || "${quantize_filter}" = "no" ]]; then + if should_exclude "${os}" "${arch}" "${backend}" "${variant}" "${m}" "false"; then continue; fi + total=$((total+1)) + run_cell "${ctr}" "${venv}" "${os}" "${arch}" "${backend}" "${variant}" "${ext}" "${m}" "" \ + && passed=$((passed+1)) || exit 1 + fi + if [[ -z "${quantize_filter}" || "${quantize_filter}" = "yes" ]]; then + if should_exclude "${os}" "${arch}" "${backend}" "${variant}" "${m}" "true"; then continue; fi + total=$((total+1)) + run_cell "${ctr}" "${venv}" "${os}" "${arch}" "${backend}" "${variant}" "${ext}" "${m}" "--quantize" \ + && passed=$((passed+1)) || exit 1 + fi +done +done +done +done + +echo "" +echo "===== ${passed}/${total} cells passed =====" +test "${passed}" -eq "${total}" diff --git a/exir/BUCK b/exir/BUCK index f00b3f1c787..d70900c02ae 100644 --- a/exir/BUCK +++ b/exir/BUCK @@ -259,6 +259,16 @@ fbcode_target(_kind = runtime.python_library, ], ) +fbcode_target(_kind = runtime.python_library, + name = "_program_utils", + srcs = [ + "_program_utils.py", + ], + deps = [ + "//caffe2:torch", + ], +) + fbcode_target(_kind = runtime.python_library, name = "pass_manager", srcs = [ @@ -266,7 +276,9 @@ fbcode_target(_kind = runtime.python_library, ], deps = [ "fbsource//third-party/pypi/typing-extensions:typing-extensions", + ":_program_utils", ":error", + ":pass_base", "//caffe2:torch", ], ) diff --git a/exir/_program_utils.py b/exir/_program_utils.py new file mode 100644 index 00000000000..d0d2039d93a --- /dev/null +++ b/exir/_program_utils.py @@ -0,0 +1,104 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +import torch +from torch.export.exported_program import ( + ConstantArgument, + ExportGraphSignature, + InputSpec, + OutputSpec, +) + + +def _get_updated_range_constraints(gm): + def get_shape_env(gm): + vals = [ + node.meta["val"] + for node in gm.graph.nodes + if node.meta.get("val", None) is not None + ] + from torch._guards import detect_fake_mode # type: ignore[21] + + fake_mode = detect_fake_mode(vals) + if fake_mode is not None: + return fake_mode.shape_env + for v in vals: + if isinstance(v, torch.SymInt): + return v.node.shape_env + + shape_env = get_shape_env(gm) + if shape_env is None: + return {} + range_constraints = { + shape_env.replacements.get(k, k): v for k, v in shape_env.var_to_range.items() + } + # Only when we have an unbacked symint, and it's used as constructor inputs, + # runtime_var_to_range will make a difference compated to var_to_range. + # e.g. [2, oo) -> [0, oo) + for k, v in shape_env.var_to_range.items(): + if k not in shape_env.replacements: + range_constraints[k] = v + return range_constraints + + +def _get_updated_graph_signature( + old_signature: ExportGraphSignature, + new_gm: torch.fx.GraphModule, +) -> ExportGraphSignature: + """ + Update the graph signature's user_input/user_outputs. + """ + new_input_specs = [] + i = 0 + for node in new_gm.graph.nodes: + if node.op != "placeholder": + continue + + assert i < len( + old_signature.input_specs + ), "Number of inputs changed after transformation" + old_input_spec = old_signature.input_specs[i] + arg = ( + old_input_spec.arg + if isinstance(old_input_spec.arg, ConstantArgument) + # pyre-fixme[20]: Argument `class_fqn` expected. + else type(old_input_spec.arg)(node.name) + ) + new_input_specs.append( + InputSpec( + old_input_spec.kind, + arg, + old_input_spec.target, + persistent=old_input_spec.persistent, + ) + ) + i += 1 + + output_node = new_gm.graph.output_node() + assert output_node.op == "output" + + new_output_specs = [] + for i, node in enumerate(output_node.args[0]): + assert i < len( + old_signature.output_specs + ), "Number of outputs changed after transformation" + old_output_spec = old_signature.output_specs[i] + arg = ( + old_output_spec.arg + if isinstance(old_output_spec.arg, ConstantArgument) + # pyre-fixme[20]: Argument `class_fqn` expected. + else type(old_output_spec.arg)(node.name) + ) + new_output_specs.append( + OutputSpec(old_output_spec.kind, arg, old_output_spec.target) + ) + + new_signature = ExportGraphSignature( + input_specs=new_input_specs, output_specs=new_output_specs + ) + return new_signature diff --git a/exir/_serialize/_flatbuffer.py b/exir/_serialize/_flatbuffer.py index 219e4517aea..43e203d1ff9 100644 --- a/exir/_serialize/_flatbuffer.py +++ b/exir/_serialize/_flatbuffer.py @@ -12,7 +12,6 @@ import importlib.resources import os import re -import shutil import stat import subprocess import tempfile @@ -384,72 +383,6 @@ def _flatc_decompile( ) -def _program_json_to_flatbuffer( - program_json: str, - *, - constant_tensor_alignment: Optional[int] = None, - delegate_alignment: Optional[int] = None, -) -> _FlatbufferResult: - """Converts Program-compatible JSON into binary flatbuffer data. - - Args: - program_json: The JSON to convert. Must be compatible with the root - table type of //executorch/schema/program.fbs. - constant_tensor_alignment: If provided, the alignment to use for tensor - data embedded in the output flatbuffer data. If not provided, uses - the alignment in the schema. - delegate_alignment: If provided, the alignment to use for delegate - data embedded in the output flatbuffer data. If not provided, uses - the alignment in the schema. - - Returns: The flatbuffer data and associated metadata. - """ - with tempfile.TemporaryDirectory() as temp_dir: - schema_info = _prepare_schema( - out_dir=temp_dir, - constant_tensor_alignment=constant_tensor_alignment, - delegate_alignment=delegate_alignment, - ) - file_stem = "data" - json_path = os.path.join(temp_dir, file_stem + ".json") - output_path = os.path.join(temp_dir, file_stem + ".pte") - - with open(json_path, "wb") as json_file: - json_file.write(program_json.encode("ascii")) - - try: - _flatc_compile(temp_dir, schema_info.root_path, json_path) - except Exception as err: - # It's helpful to save the breaking files for debugging. Optionally - # move them out of the auto-deleting temporary directory. Don't do - # this by default because some input files can be many GB in size, - # and these copies won't be auto-deleted. - should_save = os.getenv(_SAVE_FLATC_ENV, "").strip() not in {"", "0"} - extra_message = "" - if should_save: - try: - saved_dir = tempfile.mkdtemp(prefix="exir-saved-flatc-") - for f in os.listdir(temp_dir): - shutil.move(src=os.path.join(temp_dir, f), dst=saved_dir) - extra_message += f" Moved input files to '{saved_dir}'." - except Exception as err2: - extra_message += ( - f" (Failed to save input files for debugging: {err2})" - ) - else: - extra_message += ( - f" Set {_SAVE_FLATC_ENV}=1 to save input files on failure." - ) - - raise RuntimeError( - f"Failed to compile {json_path} to {output_path}." + extra_message - ) from err - with open(output_path, "rb") as output_file: - return _FlatbufferResult( - data=output_file.read(), max_alignment=schema_info.max_alignment - ) - - def _replace_infinity_in_json_file(content: bytes) -> bytes: """Replace -inf and inf with "inf" and "-inf" in the JSON file. program.fbs is used to convert from flatbuffer to JSON. +-inf float values are not diff --git a/exir/_serialize/_flatbuffer_program.py b/exir/_serialize/_flatbuffer_program.py index 4c1c315347a..cd742c8361d 100644 --- a/exir/_serialize/_flatbuffer_program.py +++ b/exir/_serialize/_flatbuffer_program.py @@ -8,12 +8,14 @@ import enum import functools import importlib +import pkgutil import tempfile from contextvars import ContextVar from dataclasses import fields, is_dataclass from functools import lru_cache -from typing import Any, Dict, Optional +from types import ModuleType +from typing import Any, Dict, get_args, get_origin, get_type_hints, Optional, Union import flatbuffers # pyre-ignore[21] from executorch.exir._serialize._flatbuffer import ( @@ -22,6 +24,7 @@ _prepare_schema, _SchemaInfo, ) +from executorch.exir._serialize.generated import executorch_flatbuffer as _generated_fb from executorch.exir._serialize.generated.executorch_flatbuffer import ( BackendDelegateInlineData as _BackendDelegateInlineData, Buffer as _Buffer, @@ -33,6 +36,7 @@ _T_CLASS_CACHE: Dict[type, type] = {} _FIELD_NAME_CACHE: Dict[type, tuple[tuple[str, str], ...]] = {} +_TYPE_HINTS_CACHE: Dict[type, Dict[str, Any]] = {} _BUFFER_ALIGNMENT: ContextVar[int] = ContextVar("_BUFFER_ALIGNMENT", default=1) _DELEGATE_ALIGNMENT: ContextVar[int] = ContextVar("_DELEGATE_ALIGNMENT", default=1) @@ -64,6 +68,15 @@ def _dataclass_field_map(dataclass_type: type) -> tuple[tuple[str, str], ...]: return mapping +def _dataclass_type_hints(dataclass_type: type) -> Dict[str, Any]: + cached = _TYPE_HINTS_CACHE.get(dataclass_type) + if cached is not None: + return cached + type_hints = get_type_hints(dataclass_type) + _TYPE_HINTS_CACHE[dataclass_type] = type_hints + return type_hints + + def _create_aligned_byte_vector(builder: Any, data: bytes, alignment: int) -> int: if not _is_valid_alignment(alignment): raise ValueError(f"Bad alignment {alignment}") @@ -194,6 +207,126 @@ def convert_program(val: Program) -> ProgramT: return _convert_dataclass(val) +# The generated FlatBuffer Python modules import child tables/unions as modules +# (for example, Program.ExecutionPlan becomes the ExecutionPlan module), but the +# unpacking helpers later expect those globals to be the corresponding classes. +# Rebind module globals like ExecutionPlan -> ExecutionPlan.ExecutionPlan so the +# generated InitFromObj()/InitFromPackedBuf() code can instantiate nested types. +def _patch_generated_module_aliases(module: ModuleType) -> None: + for name, maybe_module in vars(module).items(): + if not isinstance(maybe_module, ModuleType): + continue + maybe_class = getattr(maybe_module, name, None) + if isinstance(maybe_class, type): + setattr(module, name, maybe_class) + + +@lru_cache(maxsize=1) +def _patch_generated_flatbuffer_aliases() -> None: + package_name = _generated_fb.__name__ + for module_info in pkgutil.iter_modules(_generated_fb.__path__): + module = importlib.import_module(f"{package_name}.{module_info.name}") + _patch_generated_module_aliases(module) + + +def _flatbuffer_dataclass_names(val: Any) -> tuple[str, Optional[str]]: + val_type_name = type(val).__name__ + if val_type_name.endswith("T"): + return val_type_name, val_type_name[:-1] + return val_type_name, None + + +def _matches_dataclass_union_type( + union_type: Any, val_type_name: str, val_dataclass_name: Optional[str] +) -> bool: + if not is_dataclass(union_type): + return False + union_name = union_type.__name__ + return union_name == val_type_name or ( + val_dataclass_name is not None and union_name == val_dataclass_name + ) + + +def _matches_non_dataclass_union_type(union_type: Any, val: Any) -> bool: + if union_type is Any: + return True + if union_type is str and isinstance(val, (bytes, bytearray, memoryview)): + return True + union_origin = get_origin(union_type) + if union_origin is list and hasattr(val, "__iter__"): + return True + return isinstance(union_type, type) and isinstance(val, union_type) + + +def _union_choice_from_value(union_types: tuple[Any, ...], val: Any) -> Any: + if val is None: + for union_type in union_types: + if union_type is type(None): + return union_type + return None + + val_type_name, val_dataclass_name = _flatbuffer_dataclass_names(val) + + for union_type in union_types: + if union_type is type(None): + continue + if _matches_dataclass_union_type(union_type, val_type_name, val_dataclass_name): + return union_type + if _matches_non_dataclass_union_type(union_type, val): + return union_type + return None + + +def _convert_from_flatbuffer_value(val: Any, expected_type: Any) -> Any: + if val is None: + return None + + origin = get_origin(expected_type) + if origin is list: + item_type = get_args(expected_type)[0] + return [_convert_from_flatbuffer_value(item, item_type) for item in val] + + if origin is Union: + union_type = _union_choice_from_value(get_args(expected_type), val) + if union_type is None: + raise TypeError( + f"Could not match value type {type(val)} to {expected_type}" + ) + if union_type is type(None): + return None + return _convert_from_flatbuffer_value(val, union_type) + + if expected_type is bytes: + return _coerce_bytes(val) + if expected_type is str and isinstance(val, (bytes, bytearray, memoryview)): + return _coerce_bytes(val).decode("utf-8") + if is_dataclass(expected_type): + return _convert_from_flatbuffer_dataclass(val, expected_type) + if isinstance(expected_type, type) and issubclass(expected_type, enum.Enum): + if isinstance(val, expected_type): + return val + return expected_type(val) + if isinstance(expected_type, type): + return expected_type(val) + return val + + +def _convert_from_flatbuffer_dataclass(val: Any, dataclass_type: type) -> Any: + result = {} + type_hints = _dataclass_type_hints(dataclass_type) + for src_name, dst_name in _dataclass_field_map(dataclass_type): + result[src_name] = _convert_from_flatbuffer_value( + getattr(val, dst_name), type_hints[src_name] + ) + return dataclass_type(**result) + + +def _flatbuffer_to_program(program_data: bytes) -> Program: + _patch_generated_flatbuffer_aliases() + program_t = ProgramT.InitFromPackedBuf(program_data) + return _convert_from_flatbuffer_dataclass(program_t, Program) + + @lru_cache(maxsize=1) def _get_schema_info( constant_tensor_alignment: Optional[int], delegate_alignment: Optional[int] @@ -213,11 +346,7 @@ def _program_to_flatbuffer( constant_tensor_alignment: Optional[int] = None, delegate_alignment: Optional[int] = None, ) -> _FlatbufferResult: - """Converts a Program dataclass into binary flatbuffer data. - - Unlike _program_json_to_flatbuffer(), this does not use JSON or invoke - flatc to build the binary. - """ + """Converts a Program dataclass into binary flatbuffer data.""" schema_info = _get_schema_info(constant_tensor_alignment, delegate_alignment) _set_pack_alignments(schema_info.tensor_alignment, schema_info.delegate_alignment) _install_fast_packers() diff --git a/exir/_serialize/_program.py b/exir/_serialize/_program.py index 4ab2a3572b4..230b50bf558 100644 --- a/exir/_serialize/_program.py +++ b/exir/_serialize/_program.py @@ -16,12 +16,12 @@ from typing import ClassVar, Dict, List, Literal, Optional, Sequence, Tuple from executorch.exir._serialize._cord import Cord -from executorch.exir._serialize._dataclass import _DataclassEncoder, _json_to_dataclass -from executorch.exir._serialize._flatbuffer import ( - _FlatbufferResult, - _program_flatbuffer_to_json, +from executorch.exir._serialize._dataclass import _DataclassEncoder +from executorch.exir._serialize._flatbuffer import _FlatbufferResult +from executorch.exir._serialize._flatbuffer_program import ( + _flatbuffer_to_program, + _program_to_flatbuffer, ) -from executorch.exir._serialize._flatbuffer_program import _program_to_flatbuffer from executorch.exir._serialize._named_data_store import ( NamedDataStore, NamedDataStoreOutput, @@ -86,12 +86,6 @@ def _program_to_json(program: Program) -> str: return json.dumps(program, cls=_DataclassEncoder) -def _json_to_program(program_json: bytes) -> Program: - """Returns a Program deserialized from the given JSON string.""" - # construct program class recursively from dict - return _json_to_dataclass(json.loads(program_json), cls=Program) - - def _insert_flatbuffer_header( flatbuffer_data: bytes, magic_regex: str, header_data: bytes ) -> bytes: @@ -757,9 +751,7 @@ def deserialize_pte_binary(program_data: bytes) -> PTEFile: segment_base_offset = eh.segment_base_offset # Parse the flatbuffer data. - program: Program = _json_to_program( - _program_flatbuffer_to_json(program_data[:program_size]) - ) + program: Program = _flatbuffer_to_program(program_data[:program_size]) if segment_base_offset != 0: # Move segment data back into the Program. @@ -799,9 +791,7 @@ def _extract_delegate_payload( program_size = len(pte_data) # Parse the program flatbuffer - program: Program = _json_to_program( - _program_flatbuffer_to_json(pte_data[:program_size]) - ) + program: Program = _flatbuffer_to_program(pte_data[:program_size]) # Search for the matching delegate match_count = 0 diff --git a/exir/_serialize/test/test_flatbuffer.py b/exir/_serialize/test/test_flatbuffer.py index 801ddca112d..e623da55cd2 100644 --- a/exir/_serialize/test/test_flatbuffer.py +++ b/exir/_serialize/test/test_flatbuffer.py @@ -7,19 +7,13 @@ # LICENSE file in the root directory of this source tree. import os -import re -import shutil import tempfile import unittest from typing import Dict, Optional, Sequence from unittest.mock import patch from executorch.exir._serialize import _flatbuffer -from executorch.exir._serialize._flatbuffer import ( - _program_json_to_flatbuffer, - _ResourceFiles, - _SchemaInfo, -) +from executorch.exir._serialize._flatbuffer import _ResourceFiles, _SchemaInfo def read_file(dir: str, filename: str) -> bytes: @@ -277,60 +271,3 @@ def test_bad_delegate_alignment_fails(self) -> None: out_dir, delegate_alignment=bad_alignment, ) - - -class TestProgramJsonToFlatbuffer(unittest.TestCase): - @patch.dict(os.environ, {_flatbuffer._SAVE_FLATC_ENV: "1"}) - def test_save_json_on_failure(self) -> None: - err_msg: Optional[str] = None - try: - _program_json_to_flatbuffer("} some bad json {") - self.fail("Should have raised an exception") - except RuntimeError as err: - err_msg = err.args[0] - - self.assertIsNotNone(err_msg) - match = re.search(r"Moved input files to '(.*?)'", err_msg) - self.assertTrue(match, msg=f"Unexpected error message: {err_msg}") - path = match.group(1) - - files = frozenset(os.listdir(path)) - # Delete the files otherwise they'll accumulate every time the - # test is run. - shutil.rmtree(path) - # Check for a couple of the files that should be there. - self.assertIn("data.json", files) - self.assertIn("program.fbs", files) - - @patch.dict(os.environ, {_flatbuffer._SAVE_FLATC_ENV: "1"}) - def test_unable_to_save_json_on_failure(self) -> None: - err_msg: Optional[str] = None - try: - with patch.object( - _flatbuffer.shutil, - "move", - side_effect=Exception("shutil.move mock failure"), - ): - _program_json_to_flatbuffer("} some bad json {") - self.fail("Should have raised an exception") - except RuntimeError as err: - err_msg = err.args[0] - - self.assertIsNotNone(err_msg) - self.assertIn("Failed to save input files", err_msg) - - @patch.dict(os.environ, {_flatbuffer._SAVE_FLATC_ENV: ""}) - def test_no_save_json_on_failure(self) -> None: - err_msg: Optional[str] = None - try: - _program_json_to_flatbuffer("} some bad json {") - self.fail("Should have raised an exception") - except RuntimeError as err: - err_msg = err.args[0] - - self.assertIsNotNone(err_msg) - self.assertIn( - f"Set {_flatbuffer._SAVE_FLATC_ENV}=1 to save input files", err_msg - ) - self.assertNotIn("Moved input files", err_msg) - self.assertNotIn("Failed to save input files", err_msg) diff --git a/exir/_serialize/test/test_flatbuffer_program.py b/exir/_serialize/test/test_flatbuffer_program.py index 05e05d4e610..4910f9b431f 100644 --- a/exir/_serialize/test/test_flatbuffer_program.py +++ b/exir/_serialize/test/test_flatbuffer_program.py @@ -4,15 +4,12 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import json import unittest -from executorch.exir._serialize._flatbuffer import ( - _program_flatbuffer_to_json, - _program_json_to_flatbuffer, +from executorch.exir._serialize._flatbuffer_program import ( + _flatbuffer_to_program, + _program_to_flatbuffer, ) -from executorch.exir._serialize._flatbuffer_program import _program_to_flatbuffer -from executorch.exir._serialize._program import _json_to_program, _program_to_json from executorch.exir.backend.compile_spec_schema import CompileSpec from executorch.exir.schema import ( AllocationDetails, @@ -157,50 +154,12 @@ def _make_program(self) -> Program: named_data=[], ) - def _flatbuffer_to_dict(self, flatbuffer_data: bytes) -> dict: - return json.loads(_program_flatbuffer_to_json(flatbuffer_data)) - - def test_roundtrip_via_json(self) -> None: + def test_roundtrip_via_direct_python(self) -> None: program = self._make_program() result = _program_to_flatbuffer( program, constant_tensor_alignment=32, delegate_alignment=64 ) - self.assertGreater(len(result.data), 8) - self.assertEqual(result.data[4:6], b"ET") - self.assertGreaterEqual(result.max_alignment, 64) - - program2 = _json_to_program(_program_flatbuffer_to_json(result.data)) - self.assertEqual(program2, program) - - def test_flatbuffer_paths_match(self) -> None: - program = self._make_program() - cases = [ - (None, None), - (32, 64), - ] - for constant_tensor_alignment, delegate_alignment in cases: - with self.subTest( - constant_tensor_alignment=constant_tensor_alignment, - delegate_alignment=delegate_alignment, - ): - result = _program_to_flatbuffer( - program, - constant_tensor_alignment=constant_tensor_alignment, - delegate_alignment=delegate_alignment, - ) - result2 = _program_json_to_flatbuffer( - _program_to_json(program), - constant_tensor_alignment=constant_tensor_alignment, - delegate_alignment=delegate_alignment, - ) - direct_dict = self._flatbuffer_to_dict(result.data) - json_path_dict = self._flatbuffer_to_dict(result2.data) - self.assertEqual( - direct_dict, - json_path_dict, - "Flatbuffer JSON differs between direct and JSON paths", - ) - self.assertEqual(result.max_alignment, result2.max_alignment) + self.assertEqual(_flatbuffer_to_program(result.data), program) def test_bad_alignment_fails(self) -> None: program = Program( diff --git a/exir/_serialize/test/test_program.py b/exir/_serialize/test/test_program.py index 579934e9d38..0d0d833c952 100644 --- a/exir/_serialize/test/test_program.py +++ b/exir/_serialize/test/test_program.py @@ -1,6 +1,7 @@ #!/usr/bin/env fbpython # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. +# Copyright 2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -16,12 +17,11 @@ from typing import Dict, List, Sequence -from executorch.exir._serialize._flatbuffer import _program_flatbuffer_to_json +from executorch.exir._serialize._flatbuffer_program import _flatbuffer_to_program from executorch.exir._serialize._named_data_store import NamedDataStoreOutput from executorch.exir._serialize._program import ( _ExtendedHeader, _get_extended_header, - _json_to_program, _program_to_json, deserialize_pte_binary, PTEFile, @@ -30,6 +30,8 @@ from executorch.exir._serialize.data_serializer import DataEntry from executorch.exir._serialize.padding import aligned_size +from executorch.exir.backend.compile_spec_schema import CompileSpec + from executorch.exir.schema import ( BackendDelegate, BackendDelegateDataReference, @@ -39,7 +41,15 @@ DataLocation, DataSegment, DeviceType, + Double, + EValue, ExecutionPlan, + Frame, + FrameList, + FreeCall, + Instruction, + JumpFalseCall, + MoveCall, NonConstBufferDevice, Program, SubsegmentOffsets, @@ -197,7 +207,7 @@ def constant_segment_with_tensor_alignment( self.assertGreater(eh.segment_data_size, 0) # Peek inside the actual flatbuffer data to see the segments. - program_with_segments = _json_to_program(_program_flatbuffer_to_json(pte_data)) + program_with_segments = _flatbuffer_to_program(pte_data) # The constant tensor data should appear as the only segment. self.assertEqual(len(program_with_segments.segments), 1) @@ -467,6 +477,68 @@ def test_round_trip_no_header_no_segments(self) -> None: self.assertEqual(deserialized.mutable_data, None) self.assertEqual(deserialized.named_data, None) + def test_deserialize_pte_binary_with_rich_flatbuffer_types(self) -> None: + program = get_test_program() + plan = program.execution_plan[0] + plan.values.append(EValue(Double(float("inf")))) + plan.delegates.append( + BackendDelegate( + id="delegate0", + processed=BackendDelegateDataReference( + location=DataLocation.INLINE, + index=0, + ), + compile_specs=[CompileSpec(key="k", value=b"v")], + ) + ) + plan.chains[0].instructions.extend( + [ + Instruction(MoveCall(move_from=0, move_to=1)), + Instruction( + JumpFalseCall(cond_value_index=1, destination_instruction=0) + ), + Instruction(FreeCall(value_index=0)), + ] + ) + plan.chains[0].stacktrace = [ + FrameList( + items=[ + Frame( + filename="file.py", + lineno=idx + 1, + name="fn", + context="ctx", + ) + ] + ) + for idx, _ in enumerate(plan.chains[0].instructions) + ] + program.constant_buffer.append(Buffer(storage=b"abcd")) + program.backend_delegate_data.append( + BackendDelegateInlineData(data=b"delegate-data") + ) + + deserialized = deserialize_pte_binary( + bytes(serialize_pte_binary(PTEFile(program=program))) + ) + + self.assert_programs_equal(program, deserialized.program) + self.assertEqual(deserialized.mutable_data, None) + self.assertEqual(deserialized.named_data, None) + self.assertIsInstance(plan.values[-1].val, Double) + self.assertIsInstance( + deserialized.program.execution_plan[0].values[-1].val, + Double, + ) + self.assertEqual( + deserialized.program.execution_plan[0].values[-1].val.double_val, + "inf", + ) + self.assertEqual( + deserialized.program.execution_plan[0].delegates[0].compile_specs[0].value, + b"v", + ) + def test_round_trip_large_buffer_sizes(self) -> None: """Tests that when the non_const_buffer_sizes contains integers overflowing a signed/unsigned 32 bit integer, we can still serialize the @@ -531,7 +603,7 @@ def test_round_trip_no_segments_and_no_header(self) -> None: self.assertIsNone(eh) # Peek inside the flatbuffer data to confirm that there are no segments. - program_with_segments = _json_to_program(_program_flatbuffer_to_json(pte_data)) + program_with_segments = _flatbuffer_to_program(pte_data) self.assertEqual(program_with_segments.segments, []) # Convert back. @@ -597,7 +669,7 @@ def test_round_trip_with_segments(self) -> None: # this also implicity tests the case where we try parsing the entire # file with segment data following it, demonstrating that the extra data # doesn't upset the flatbuffer parsing path. - program_with_segments = _json_to_program(_program_flatbuffer_to_json(pte_data)) + program_with_segments = _flatbuffer_to_program(pte_data) # The delegate blobs we added to the program should appear as segments. # The one empty blob should have been ignored, hence the `- 1`. @@ -694,7 +766,7 @@ def test_no_constants(self) -> None: self.assertEqual(program.segments, []) # Peek inside the actual flatbuffer data to see the segments. - flatbuffer_program = _json_to_program(_program_flatbuffer_to_json(pte_data)) + flatbuffer_program = _flatbuffer_to_program(pte_data) # Constant buffer should be empty. self.assertEqual(len(flatbuffer_program.constant_buffer), 0) @@ -814,7 +886,7 @@ def test_constant_delegate_and_named_data_segments(self) -> None: self.assertGreater(eh.segment_data_size, 0) # Peek inside the actual flatbuffer data to see the segments. - program_with_segments = _json_to_program(_program_flatbuffer_to_json(pte_data)) + program_with_segments = _flatbuffer_to_program(pte_data) # Segment table should contain a constant segment, the delegate blobs # and a named data segment. @@ -1017,7 +1089,7 @@ def test_named_data_segments(self) -> None: self.assertGreater(eh.segment_data_size, 0) # Peek inside the actual flatbuffer data to see the named data segments. - program_with_segments = _json_to_program(_program_flatbuffer_to_json(pte_data)) + program_with_segments = _flatbuffer_to_program(pte_data) # pyre-ignore Incompatible parameter type [6] self.assertEqual(len(program_with_segments.named_data), len(pte_named_data)) diff --git a/exir/pass_base.py b/exir/pass_base.py index 8ab0c675240..f93dd75d156 100644 --- a/exir/pass_base.py +++ b/exir/pass_base.py @@ -6,10 +6,11 @@ # LICENSE file in the root directory of this source tree. # pyre-strict - import operator import traceback +from abc import ABC, abstractmethod from contextlib import nullcontext +from dataclasses import dataclass from typing import ( Any, Callable, @@ -27,9 +28,7 @@ import torch from executorch.exir import memory - from executorch.exir.delegate import executorch_call_delegate, is_lowered_module - from executorch.exir.dialects.edge._ops import EdgeOpOverload from executorch.exir.error import ExportError, ExportErrorType from torch import fx @@ -37,6 +36,7 @@ from torch._subclasses import FakeTensorMode, UnsupportedFakeTensorException from torch._subclasses.fake_tensor import FakeTensor from torch._subclasses.functional_tensor import FunctionalTensor, FunctionalTensorMode +from torch.export import ExportedProgram from torch.fx import traceback as fx_traceback from torch.fx.experimental.proxy_tensor import PythonKeyTracer from torch.fx.graph import CodeGen @@ -182,6 +182,58 @@ class ExportPassBaseError(RuntimeError): pass +@dataclass(frozen=True) +class ExportedProgramPassResult: + exported_program: ExportedProgram + modified: bool + + +class ExportedProgramPassBase(ABC): + """ + Base interface for implementing passes that operate on ExportedProgram. + """ + + def __call__(self, exported_program: ExportedProgram) -> ExportedProgramPassResult: + """ + Runs the precondition check, the pass itself, and the postcondition check. + """ + + self.requires(exported_program) + res = self.call(exported_program) + self.ensures(exported_program) + return res + + @abstractmethod + def call(self, exported_program: ExportedProgram) -> ExportedProgramPassResult: + """ + The pass that is run through the given exported program. To implement a + pass, it is required to implement this function. + + Args: + exported_program: The exported program we will run a pass on + """ + + def requires(self, exported_program: ExportedProgram) -> None: # noqa: B027 + """ + This function will be called before the pass is run and will check that + the given exported program contains the preconditions needed to run the + pass. It is not required to implement this function. + + Args: + exported_program: The exported program we will run checks on + """ + + def ensures(self, exported_program: ExportedProgram) -> None: # noqa: B027 + """ + This function will be called after the pass is run and will check that + the given exported program contains the postconditions needed to run the + pass. It is not required to implement this function. + + Args: + exported_program: The exported program we will run checks on + """ + + class _ExportPassBase(PassBase): """ Interpreter-based pass class to help users maintain the IR spec while writing diff --git a/exir/pass_manager.py b/exir/pass_manager.py index b812ccea7b8..351e98651dd 100644 --- a/exir/pass_manager.py +++ b/exir/pass_manager.py @@ -5,28 +5,46 @@ # LICENSE file in the root directory of this source tree. # pyre-strict - -from typing import Callable, List, Optional, Union +import copy +import inspect +import logging +from typing import Callable, List, Optional, Type, TypeAlias, Union import torch import torch.fx.passes.infra.pass_manager as fx import torch.utils._pytree as pytree +from executorch.exir._program_utils import ( + _get_updated_graph_signature, + _get_updated_range_constraints, +) from executorch.exir.error import ExportError, ExportErrorType +from executorch.exir.pass_base import ExportedProgramPassBase, ExportedProgramPassResult +from torch._export.verifier import Verifier +from torch.export import ExportedProgram from torch.fx.passes.infra.pass_base import PassResult -from typing_extensions import TypeAlias +from torch.fx.passes.infra.pass_manager import pass_result_wrapper + +logger = logging.getLogger(__name__) +logger.setLevel(logging.WARNING) + +PassType: TypeAlias = Union[ + ExportedProgramPassBase, Callable[[torch.fx.GraphModule], Optional[PassResult]] +] + -PassType: TypeAlias = Callable[[torch.fx.GraphModule], Optional[PassResult]] +def _get_pass_name(fn: PassType) -> str: + """Returns a human-readable name for a pass.""" + return fn.__name__ if inspect.isfunction(fn) else type(fn).__name__ class PassManager(fx.PassManager): """ - Class to run multiple passes on a given graph module. The PassManager is - callable so to run it, we can just call the PassManager instance. + Runs multiple passes on a GraphModule. - Private Attributes: - * **passes**: A list of callable passes - * **params**: An instance of PassManagerParams containing the result of the - flags set in the constructor. + This is the legacy PassManager that extends torch.fx.passes.infra.pass_manager.PassManager. + Use this when you need to run passes on a GraphModule directly. + + For running passes on ExportedProgram, use ExportedProgramPassManager instead. """ def __init__( @@ -34,14 +52,11 @@ def __init__( passes: Optional[Union[List[PassType], List[List[PassType]]]] = None, run_checks_after_each_pass: bool = False, suppress_check_failures: bool = False, + steps: int = 1, ) -> None: - r""" - Args: - passes: A list of passes - enable_debug_pass: set to true to enable the debug passes - run_checks_after_each_pass: whether to run checks and linting after each pass - """ - + logger.warning( + "PassManager is deprecated. Please use ExportedProgramPassManager instead." + ) # Flatten the passes to a list of callables passes = passes if passes else [] flattened_passes = [ @@ -52,6 +67,7 @@ def __init__( flattened_passes, run_checks_after_each_pass=run_checks_after_each_pass, suppress_check_failures=suppress_check_failures, + steps=steps, ) def check(self, module: torch.nn.Module) -> None: @@ -65,10 +81,9 @@ def check(self, module: torch.nn.Module) -> None: node's spec field is a tuple) - Ensure that the graph module has type torch.fx.GraphModule """ - assert isinstance(module, fx.GraphModule) + assert isinstance(module, torch.fx.GraphModule) module.recompile() module.graph.lint() - # TODO(qihan): use verifier.check_is_exir for node in module.graph.nodes: if node.op == "call_method": @@ -76,3 +91,151 @@ def check(self, module: torch.nn.Module) -> None: ExportErrorType.NOT_SUPPORTED, f"call_method `{node}` is not supported except for backend delegate.", ) + + +class ExportedProgramPassManager(fx.PassManager): + """ + Runs multiple passes on an ExportedProgram. + + This PassManager is specifically designed for ExportedProgram and supports + both GraphModule-only passes and ExportedProgram-aware passes. + + For running passes on GraphModule directly, use PassManager instead. + """ + + def __init__( + self, + passes: Optional[Union[List[PassType], List[List[PassType]]]] = None, + constraints: Optional[List[Callable[[Callable, Callable], bool]]] = None, + run_checks_after_each_pass: bool = False, + suppress_check_failures: bool = False, + steps: int = 1, + ) -> None: + wrapped_passes = ( + [ + ( + fn + if isinstance(fn, ExportedProgramPassBase) + else pass_result_wrapper(fn) + ) + for fn in pytree.tree_flatten(passes)[0] + ] + if passes + else [] + ) + + super().__init__( + wrapped_passes, + constraints=constraints, + run_checks_after_each_pass=run_checks_after_each_pass, + suppress_check_failures=suppress_check_failures, + steps=steps, + ) + + def check(self, exported_program: ExportedProgram) -> None: + """Validates graph module invariants.""" + graph_module = exported_program.graph_module + graph_module.recompile() + graph_module.graph.lint() + + for node in graph_module.graph.nodes: + if node.op == "call_method": + raise ExportError( + ExportErrorType.NOT_SUPPORTED, + f"call_method `{node}` is not supported except for backend delegate.", + ) + + exported_program.validate() + + # pyre-ignore[14]: Intentionally overriding with different signature for ExportedProgram + def __call__( # noqa: C901 + self, + exported_program: ExportedProgram, + override_verifiers: Optional[list[Type[Verifier]]] = None, + ) -> ExportedProgramPassResult: + """ + Runs passes on an ExportedProgram. + + Handles both GraphModule-only passes and ExportedProgram-aware passes. Will create a shallow copy of the exported program before running passes. + + Args: + exported_program: The exported program to transform. + + Returns: + ExportedProgramPassResult containing the transformed program. + """ + if not self._validated: + self.solve_constraints() + + exported_program = copy.copy(exported_program) + + if override_verifiers: + exported_program._verifiers = override_verifiers + + self.check(exported_program) + + overall_modified = False + + for _ in range(self.steps): + step_modified = False + + for i, fn in enumerate(self.passes): + pass_modified = False + try: + if not isinstance(fn, ExportedProgramPassBase): + res = fn(exported_program.graph_module) + if res is None: + raise TypeError( + f"The result of pass {_get_pass_name(fn)} should be type PassResult. " + "Please wrap it with pass_result_wrapper()" + ) + + if res.modified: + # Not running _update_exported_program_graph_module here because it is + # possible that the verifier will fail upon new ExportedProgram construction, + # and we should only run verification after each pass if + # run_checks_after_each_pass is True. + res.graph_module.recompile() + exported_program._graph_module = res.graph_module + exported_program._graph_signature = ( + _get_updated_graph_signature( + exported_program.graph_signature, + res.graph_module, + ) + ) + exported_program._range_constraints = ( + _get_updated_range_constraints(res.graph_module) + ) + pass_modified = True + + else: + assert isinstance(fn, ExportedProgramPassBase) + ep_res = fn(exported_program) + exported_program = ep_res.exported_program + + if ep_res.modified: + pass_modified = True + exported_program.graph_module.recompile() + + if self.run_checks_after_each_pass: + self.check(exported_program) + + if pass_modified: + step_modified = True + logger.debug( + "Graph after pass '%s': %s", + _get_pass_name(fn), + exported_program.graph_module.graph, + ) + + except Exception as e: + prev_names = [_get_pass_name(p) for p in self.passes[:i]] + msg = f"An error occurred when running the '{_get_pass_name(fn)}' pass after the following passes: {prev_names}" + raise Exception(msg) from e # noqa: TRY002 + + overall_modified = overall_modified or step_modified + if not step_modified: + break + + self.check(exported_program) + return ExportedProgramPassResult(exported_program, overall_modified) diff --git a/exir/passes/BUCK b/exir/passes/BUCK index 954f1cfdb4f..4647388b388 100644 --- a/exir/passes/BUCK +++ b/exir/passes/BUCK @@ -381,6 +381,14 @@ fbcode_target(_kind = runtime.python_library, ], ) +fbcode_target(_kind = runtime.python_library, + name = "device_copy_ops_registry", + srcs = ["_device_copy_ops_registry.py"], + deps = [ + "//caffe2:torch", + ], +) + fbcode_target(_kind = runtime.python_library, name = "memory_format_ops_pass", srcs = [ diff --git a/exir/passes/__init__.py b/exir/passes/__init__.py index 9b1b8efe682..ede866549b2 100644 --- a/exir/passes/__init__.py +++ b/exir/passes/__init__.py @@ -62,6 +62,7 @@ from executorch.exir.passes.to_device_pass import ToDevicePass from executorch.exir.passes.weights_to_outputs_pass import weights_to_outputs_pass +from executorch.exir.sym_util import eval_shape_upper_bound from torch import fx from torch._subclasses import FakeTensor from torch.fx.passes.infra.pass_base import PassBase, PassResult @@ -281,31 +282,38 @@ def make_alloc_node( Note: tensor_metadata is only used in the case of a Tensor subclass, since fakifying a tensor subclass is not supported right now """ + + def materialize_alloc_spec( + shape: Union[torch.Size, Tuple[int, ...], List[int]], + dtype: torch.dtype, + ) -> memory.AllocSpec: + concrete_shape = eval_shape_upper_bound(shape) + if any(not isinstance(dim, int) for dim in concrete_shape): + raise RuntimeError( + "Memory allocator node requires concrete upper-bounded dimensions. " + f"Got shape {shape} and evaluated upper bounds {concrete_shape}." + ) + return (tuple(concrete_shape), dtype) + if val is None: if tensor_meta is not None: assert isinstance(tensor_meta, TensorMetadata) - alloc_spec = (tensor_meta.shape, tensor_meta.dtype) + alloc_spec = materialize_alloc_spec(tensor_meta.shape, tensor_meta.dtype) else: raise InternalError( "Memory allocator node needs FakeTensor val or TensorMetadata to proceed" ) elif isinstance(val, FakeTensor): - alloc_spec = (val.shape, val.dtype) + alloc_spec = materialize_alloc_spec(val.shape, val.dtype) else: assert isinstance(val, list) or isinstance(val, tuple) assert isinstance(tensor_meta, list) or isinstance(tensor_meta, tuple) alloc_spec: List[memory.AllocSpec] = [] for v, t in zip(val, tensor_meta): if v is not None: - # pyre-fixme[6]: For 1st argument expected - # `Union[List[Tuple[List[int], dtype]], Tuple[List[int], dtype]]` but - # got `Tuple[Size, dtype]`. - alloc_spec.append((v.shape, v.dtype)) + alloc_spec.append(materialize_alloc_spec(v.shape, v.dtype)) elif t is not None: - # pyre-fixme[6]: For 1st argument expected - # `Union[List[Tuple[List[int], dtype]], Tuple[List[int], dtype]]` but - # got `Tuple[Size, dtype]`. - alloc_spec.append((t.shape, t.dtype)) + alloc_spec.append(materialize_alloc_spec(t.shape, t.dtype)) else: raise InternalError( "Memory allocator node needs FakeTensor val or TensorMetadata to proceed" diff --git a/exir/passes/_device_copy_ops_registry.py b/exir/passes/_device_copy_ops_registry.py new file mode 100644 index 00000000000..a62b88d4234 --- /dev/null +++ b/exir/passes/_device_copy_ops_registry.py @@ -0,0 +1,58 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Registry for device copy ops used to insert explicit H2D (host-to-device) +and D2H (device-to-host) data transfer operations at delegate boundaries. + +These ops are inserted by PropagateDevicePass when enable_non_cpu_memory_planning +is True, making the graph functional by explicitly transferring data between +CPU and device memory. + +Follows the same registration pattern as dim_order_ops_registry.py. +""" + +import torch +from torch.library import impl, Library + +lib = Library("et_copy", "DEF") + +# _h2d_copy: copies a CPU tensor to device memory. +# At tracing time, this is a clone (both on CPU). At runtime, the out tensor +# is memory-planned on device, and the kernel calls +# DeviceAllocator::copy_host_to_device. +lib.define("_h2d_copy(Tensor self) -> Tensor") +lib.define("_h2d_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)") + +# _d2h_copy: copies a device tensor to CPU memory. +# At tracing time, this is a clone (both on CPU). At runtime, the self tensor +# has device memory, and the kernel calls DeviceAllocator::copy_device_to_host. +lib.define("_d2h_copy(Tensor self) -> Tensor") +lib.define("_d2h_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)") + + +@impl(lib, "_h2d_copy", "CompositeImplicitAutograd") +def _h2d_copy_impl(self: torch.Tensor) -> torch.Tensor: + # During tracing, both tensors are on CPU. Just clone to represent the transfer. + return self.clone() + + +@impl(lib, "_h2d_copy.out", "CompositeImplicitAutograd") +def _h2d_copy_out_impl(self: torch.Tensor, *, out: torch.Tensor) -> torch.Tensor: + out.copy_(self) + return out + + +@impl(lib, "_d2h_copy", "CompositeImplicitAutograd") +def _d2h_copy_impl(self: torch.Tensor) -> torch.Tensor: + # During tracing, both tensors are on CPU. Just clone to represent the transfer. + return self.clone() + + +@impl(lib, "_d2h_copy.out", "CompositeImplicitAutograd") +def _d2h_copy_out_impl(self: torch.Tensor, *, out: torch.Tensor) -> torch.Tensor: + out.copy_(self) + return out diff --git a/exir/passes/spec_prop_pass.py b/exir/passes/spec_prop_pass.py index 9adbf65dd90..73f943e55e0 100644 --- a/exir/passes/spec_prop_pass.py +++ b/exir/passes/spec_prop_pass.py @@ -11,6 +11,7 @@ import torch from executorch.exir.delegate import executorch_call_delegate +from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, ProxyValue from executorch.exir.tensor import TensorSpec from torch.export.exported_program import ExportGraphSignature @@ -18,6 +19,14 @@ from torch.fx.passes.infra.pass_base import PassResult from torch.utils import _pytree as pytree +# register llama.fallback (optional — only needed for QNN/llama sharding paths) +try: + import executorch.extension.llm.custom_ops.op_fallback # noqa: F401 + + _llama_fallback_default = exir_ops.edge.llama.fallback.default +except (ImportError, AttributeError): + _llama_fallback_default = None + # pyre-ignore def make_spec(x): @@ -75,9 +84,9 @@ def get_spec(x): elif node.op == "call_function" and node.target == operator.getitem: value_spec = pytree.tree_map(get_spec, node.args[0]) node.meta["spec"] = value_spec[node.args[1]] - elif ( - node.op == "call_function" - and node.target == executorch_call_delegate + elif node.op == "call_function" and node.target in ( + executorch_call_delegate, + _llama_fallback_default, ): # Note: We currently rely on delegate node specs not being regenerated, # as the spec is set somewhat manually when adding the call delegate node. diff --git a/exir/program/BUCK b/exir/program/BUCK index 7d9642efdb7..11f62edd99e 100644 --- a/exir/program/BUCK +++ b/exir/program/BUCK @@ -22,6 +22,7 @@ fbcode_target(_kind = runtime.python_library, ], deps = [ "//caffe2:torch", + "//executorch/exir:_program_utils", "//executorch/exir:error", "//executorch/exir:graph_module", "//executorch/exir:pass_base", diff --git a/exir/program/_program.py b/exir/program/_program.py index b3d94c8ffd7..485d72bbe45 100644 --- a/exir/program/_program.py +++ b/exir/program/_program.py @@ -5,8 +5,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe - +# pyre-strict import copy import io import logging @@ -38,7 +37,8 @@ from executorch.exir.operator.convert import _pybind_schema_to_native_schema from executorch.exir.operator.util import _QUANT_PRIMITIVES from executorch.exir.pass_base import PassBase -from executorch.exir.pass_manager import PassType +from executorch.exir.pass_manager import ExportedProgramPassManager, PassType + from executorch.exir.passes import ( base_post_op_replace_passes, base_pre_op_replace_passes, @@ -88,17 +88,11 @@ from torch.export._remove_auto_functionalized_pass import ( unsafe_remove_auto_functionalized_pass, ) -from torch.export.exported_program import ( - ConstantArgument, - ExportGraphSignature, - InputKind, - InputSpec, - OutputSpec, - TensorArgument, -) +from torch.export.exported_program import InputKind, InputSpec, TensorArgument from torch.fx import _pytree as fx_pytree from torch.fx._compatibility import compatibility -from torch.fx.passes.infra.pass_manager import PassManager +from torch.fx.passes.infra.pass_manager import PassManager as GraphModulePassManager + from torch.utils import _pytree as pytree Val = Any @@ -131,93 +125,10 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: transform_op_to_aten_op = {} -def _get_updated_range_constraints(gm): - def get_shape_env(gm): - vals = [ - node.meta["val"] - for node in gm.graph.nodes - if node.meta.get("val", None) is not None - ] - from torch._guards import detect_fake_mode # type: ignore[21] - - fake_mode = detect_fake_mode(vals) - if fake_mode is not None: - return fake_mode.shape_env - for v in vals: - if isinstance(v, torch.SymInt): - return v.node.shape_env - - shape_env = get_shape_env(gm) - if shape_env is None: - return {} - range_constraints = { - shape_env.replacements.get(k, k): v for k, v in shape_env.var_to_range.items() - } - # Only when we have an unbacked symint, and it's used as constructor inputs, - # runtime_var_to_range will make a difference compated to var_to_range. - # e.g. [2, oo) -> [0, oo) - for k, v in shape_env.var_to_range.items(): - if k not in shape_env.replacements: - range_constraints[k] = v - return range_constraints - - -def _get_updated_graph_signature( - old_signature: ExportGraphSignature, - new_gm: torch.fx.GraphModule, -) -> ExportGraphSignature: - """ - Update the graph signature's user_input/user_outputs. - """ - new_input_specs = [] - i = 0 - for node in new_gm.graph.nodes: - if node.op != "placeholder": - continue - - assert i < len( - old_signature.input_specs - ), "Number of inputs changed after transformation" - old_input_spec = old_signature.input_specs[i] - arg = ( - old_input_spec.arg - if isinstance(old_input_spec.arg, ConstantArgument) - # pyre-fixme[20]: Argument `class_fqn` expected. - else type(old_input_spec.arg)(node.name) - ) - new_input_specs.append( - InputSpec( - old_input_spec.kind, - arg, - old_input_spec.target, - persistent=old_input_spec.persistent, - ) - ) - i += 1 - - output_node = new_gm.graph.output_node() - assert output_node.op == "output" - - new_output_specs = [] - for i, node in enumerate(output_node.args[0]): - assert i < len( - old_signature.output_specs - ), "Number of outputs changed after transformation" - old_output_spec = old_signature.output_specs[i] - arg = ( - old_output_spec.arg - if isinstance(old_output_spec.arg, ConstantArgument) - # pyre-fixme[20]: Argument `class_fqn` expected. - else type(old_output_spec.arg)(node.name) - ) - new_output_specs.append( - OutputSpec(old_output_spec.kind, arg, old_output_spec.target) - ) - - new_signature = ExportGraphSignature( - input_specs=new_input_specs, output_specs=new_output_specs - ) - return new_signature +from executorch.exir._program_utils import ( # noqa: E402 + _get_updated_graph_signature, + _get_updated_range_constraints, +) def _transform( @@ -243,13 +154,13 @@ def _transform( ), f"Expected all passes to be of PassType, not list or Verifier. Use override_verifiers kwarg instead. Got: {list(passes)}" return _transform_with_pass_manager( - self, PassManager(list(passes)), override_verifiers + self, ExportedProgramPassManager(list(passes)), override_verifiers ) def _transform_with_pass_manager( - self, - pass_manager: PassManager, + self: ExportedProgram, + pass_manager: Union[ExportedProgramPassManager, GraphModulePassManager], override_verifiers: None | list[Type[Verifier]] = None, ) -> "ExportedProgram": """ @@ -258,22 +169,26 @@ def _transform_with_pass_manager( Args: self: The ExportedProgram instance to transform pass_manager: An instance of PassManager to apply transformations. + - ExportedProgramPassManager: operates on the full ExportedProgram + - GraphModulePassManager: operates on the GraphModule only override_verifiers: Optional list of verifier classes to use instead of the default verifiers. This is needed if the transforms yields illegal graph that the default verifier cannot handle. Returns: ExportedProgram: A new ExportedProgram with the transformations applied, or self if no changes were made """ - res = pass_manager(self.graph_module) - transformed_gm = res.graph_module if res is not None else self.graph_module - assert transformed_gm is not None - - if transformed_gm is self.graph_module and not res.modified: - return self - - return _update_exported_program_graph_module( - self, transformed_gm, override_verifiers - ) + if isinstance(pass_manager, ExportedProgramPassManager): + res = pass_manager(self, override_verifiers) + if not res.modified: + return self + return res.exported_program + else: + res = pass_manager(self.graph_module) + if not res.modified: + return self + return _update_exported_program_graph_module( + self, res.graph_module, override_verifiers + ) def _update_exported_program_graph_module( @@ -1324,7 +1239,12 @@ def collect_named_data_store_outputs( def to_edge_transform_and_lower( # noqa: C901 programs: Union[ExportedProgram, Dict[str, ExportedProgram]], transform_passes: Optional[ - Union[Sequence[PassType], Dict[str, Sequence[PassType]], PassManager] + Union[ + Sequence[PassType], + Dict[str, Sequence[PassType]], + GraphModulePassManager, + ExportedProgramPassManager, + ] ] = None, partitioner: Optional[ Union[List[Partitioner], Dict[str, List[Partitioner]]] @@ -1359,7 +1279,7 @@ def to_edge_transform_and_lower( # noqa: C901 2) a dictionary - only method names specified in the dictionary will be transformed with their corresponding passes - 3) an instance of a PassManager - + 3) an instance of a PassManager (either a GraphModulePassManager or an ExportedProgramPassManager) - all methods in the given EdgeProgramManager will be transformed with the given PassManager instance. @@ -1604,7 +1524,12 @@ def exported_program(self, method_name: str = "forward") -> ExportedProgram: @et_logger("transform") def transform( self, - passes: Union[Sequence[PassType], Dict[str, Sequence[PassType]], PassManager], + passes: Union[ + Sequence[PassType], + Dict[str, Sequence[PassType]], + ExportedProgramPassManager, + GraphModulePassManager, + ], compile_config: Optional[EdgeCompileConfig] = None, ) -> "EdgeProgramManager": """ @@ -1618,7 +1543,7 @@ def transform( 2) a dictionary mapping method names to lists of passes - only method names specified in the dictionary will be transformed with their corresponding passes. - 3) a PassManager instance - + 3) a PassManager (either ExportedProgramPassManager or GraphModulePassManager) instance - all methods in the given EdgeProgramManager will be transformed with the given PassManager instance. compile_config: Compile config to use for veriy the correctness of model @@ -1637,13 +1562,15 @@ def transform( # Cast passes parameter upfront. passes_seq: Optional[Sequence[PassType]] = None passes_dict: Optional[Dict[str, Sequence[PassType]]] = None - pass_manager: Optional[PassManager] = None + pass_manager: Optional[ + Union[ExportedProgramPassManager, GraphModulePassManager] + ] = None if isinstance(passes, Sequence): passes_seq = passes if isinstance(passes, dict): passes_dict = passes - if isinstance(passes, PassManager): + if isinstance(passes, (ExportedProgramPassManager, GraphModulePassManager)): pass_manager = passes for name, program in self._edge_programs.items(): diff --git a/exir/tests/TARGETS b/exir/tests/TARGETS index 322f72c870a..21493a69644 100644 --- a/exir/tests/TARGETS +++ b/exir/tests/TARGETS @@ -504,3 +504,14 @@ python_unittest( "//executorch/exir/passes:propagate_device_pass", ], ) + +python_unittest( + name = "device_copy_ops", + srcs = [ + "test_device_copy_ops.py", + ], + deps = [ + "//caffe2:torch", + "//executorch/exir/passes:device_copy_ops_registry", + ], +) diff --git a/exir/tests/test_device_copy_ops.py b/exir/tests/test_device_copy_ops.py new file mode 100644 index 00000000000..805159d9d81 --- /dev/null +++ b/exir/tests/test_device_copy_ops.py @@ -0,0 +1,73 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +# Import the registry to register the ops +import executorch.exir.passes._device_copy_ops_registry # noqa: F401 + +import torch + + +class DeviceCopyOpsRegistryTest(unittest.TestCase): + """Tests that et_copy._h2d_copy and et_copy._d2h_copy ops are correctly + registered and produce expected outputs during tracing (CPU-only).""" + + def test_h2d_copy_functional(self): + """_h2d_copy should return a clone of the input tensor.""" + x = torch.randn(2, 3) + result = torch.ops.et_copy._h2d_copy(x) + self.assertEqual(result.shape, x.shape) + self.assertEqual(result.dtype, x.dtype) + self.assertTrue(torch.equal(result, x)) + # Should be a new tensor, not the same object + self.assertFalse(result.data_ptr() == x.data_ptr()) + + def test_d2h_copy_functional(self): + """_d2h_copy should return a clone of the input tensor.""" + x = torch.randn(4, 5) + result = torch.ops.et_copy._d2h_copy(x) + self.assertEqual(result.shape, x.shape) + self.assertEqual(result.dtype, x.dtype) + self.assertTrue(torch.equal(result, x)) + self.assertFalse(result.data_ptr() == x.data_ptr()) + + def test_h2d_copy_out_variant(self): + """_h2d_copy.out should copy data into the provided out tensor.""" + x = torch.randn(3, 3) + out = torch.empty(3, 3) + result = torch.ops.et_copy._h2d_copy.out(x, out=out) + self.assertTrue(result is out) + self.assertTrue(torch.equal(out, x)) + + def test_d2h_copy_out_variant(self): + """_d2h_copy.out should copy data into the provided out tensor.""" + x = torch.randn(2, 4) + out = torch.empty(2, 4) + result = torch.ops.et_copy._d2h_copy.out(x, out=out) + self.assertTrue(result is out) + self.assertTrue(torch.equal(out, x)) + + def test_h2d_copy_preserves_dtype(self): + """_h2d_copy should work with various dtypes.""" + for dtype in [torch.float32, torch.float16, torch.int32, torch.int64]: + x = torch.ones(2, 2, dtype=dtype) + result = torch.ops.et_copy._h2d_copy(x) + self.assertEqual(result.dtype, dtype) + self.assertTrue(torch.equal(result, x)) + + def test_h2d_copy_scalar_tensor(self): + """_h2d_copy should handle 0-dim tensors.""" + x = torch.tensor(3.14) + result = torch.ops.et_copy._h2d_copy(x) + self.assertEqual(result.shape, torch.Size([])) + self.assertTrue(torch.equal(result, x)) + + def test_d2h_copy_empty_tensor(self): + """_d2h_copy should handle empty tensors.""" + x = torch.empty(0, 3) + result = torch.ops.et_copy._d2h_copy(x) + self.assertEqual(result.shape, torch.Size([0, 3])) diff --git a/exir/tests/test_pass_infra.py b/exir/tests/test_pass_infra.py index ded3c0e849d..7df6b76b93a 100644 --- a/exir/tests/test_pass_infra.py +++ b/exir/tests/test_pass_infra.py @@ -9,14 +9,22 @@ import unittest +import executorch.exir as exir import torch -from executorch.exir import to_edge -from executorch.exir.pass_base import ExportPassBaseError, ProxyValue -from executorch.exir.pass_manager import PassManager +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ( + ExportedProgramPassBase, + ExportedProgramPassResult, + ExportPassBaseError, + ProxyValue, +) +from executorch.exir.pass_manager import ExportedProgramPassManager, PassManager from executorch.exir.passes import ScalarToTensorPass from executorch.exir.passes.pass_registry import PassRegistry -from torch.export import Dim, export -from torch.fx.passes.infra.pass_base import PassBase +from executorch.exir.program import to_edge +from torch.export import Dim, export, ExportedProgram +from torch.export.graph_signature import InputKind, InputSpec, TensorArgument +from torch.fx.passes.infra.pass_base import PassBase, PassResult class TestPassInfra(unittest.TestCase): @@ -216,3 +224,228 @@ def test_rejects_implicit_symbolic_scalar_coercions(self) -> None: with self.assertRaisesRegex(ExportPassBaseError, "converted to float"): float(ProxyValue(sym_float, torch.fx.Graph().placeholder("x"))) + + +class TestExportedProgramPassManager(unittest.TestCase): + def test_runs_graph_module_passes_on_exported_program(self) -> None: + """ + Tests that ExportedProgramPassManager runs GraphModule passes + on an ExportedProgram and the graph is correctly modified. + """ + + def replace_add_with_mul(gm: torch.fx.GraphModule) -> PassResult: + modified = False + for node in gm.graph.find_nodes( + op="call_function", target=exir_ops.edge.aten.add.Tensor + ): + node.target = exir_ops.edge.aten.mul.Tensor + modified = True + return PassResult(gm, modified) + + def f(x: torch.Tensor) -> torch.Tensor: + y = torch.add(x, x) + z = torch.add(y, x) + return z + + exported_program = ( + exir.capture(f, (torch.randn(10),), exir.CaptureConfig()) + .to_edge() + .exported_program + ) + + pm = ExportedProgramPassManager(passes=[replace_add_with_mul]) + result = pm(exported_program) + + # Verify return type + self.assertIsInstance(result, ExportedProgramPassResult) + self.assertTrue(result.modified) + + # Check that all add ops were replaced with mul + self.assertEqual( + len( + result.exported_program.graph.find_nodes( + op="call_function", target=exir_ops.edge.aten.add.Tensor + ) + ), + 0, + ) + + def test_updates_constants_on_exported_program(self) -> None: + """ + Tests that ExportedProgramPassManager can update constants + in the ExportedProgram using an ExportedProgram-aware pass. + """ + + class DoubleConstantsPass(ExportedProgramPassBase): + """Pass that doubles all constant tensor values in the ExportedProgram.""" + + def call(self, ep: ExportedProgram) -> ExportedProgramPassResult: + modified = False + for key, const in ep.constants.items(): + if isinstance(const, torch.Tensor): + ep.constants[key] = const * 2 + modified = True + return ExportedProgramPassResult(ep, modified) + + class ModuleWithConstant(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.weight = torch.ones(3) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x + self.weight + + module = ModuleWithConstant() + exported_program = to_edge( + torch.export.export(module, (torch.randn(3),)) + ).exported_program() + + # Verify there are constants in the ExportedProgram + self.assertGreater( + len(exported_program.constants), 0, "Expected constants in ExportedProgram" + ) + + # Store original constant values + original_values = { + key: const.clone() + for key, const in exported_program.constants.items() + if isinstance(const, torch.Tensor) + } + + pm = ExportedProgramPassManager(passes=[DoubleConstantsPass()]) + result = pm(exported_program) + + self.assertIsInstance(result, ExportedProgramPassResult) + self.assertTrue(result.modified) + + # Verify constants were doubled + for key, original_const in original_values.items(): + new_const = result.exported_program.constants[key] + self.assertTrue( + torch.allclose(new_const, original_const * 2), + f"Constant {key} was not doubled correctly", + ) + + def test_adds_constant_to_exported_program(self) -> None: + """ + Tests that ExportedProgramPassManager can add a new constant + to the ExportedProgram, including updating the graph and input specs. + """ + + class AddConstantPass(ExportedProgramPassBase): + """Pass that adds a new constant tensor to the ExportedProgram.""" + + def call(self, ep: ExportedProgram) -> ExportedProgramPassResult: + graph = ep.graph_module.graph + sig = ep.graph_signature + + # Find the first user input to insert before it + placeholders = graph.find_nodes(op="placeholder") + assert len(placeholders) == 1 + user_input_node = placeholders[0] + + # Create a new constant tensor + new_constant_name = "_test_added_constant" + new_constant_tensor = torch.tensor([1.0, 2.0, 3.0]) + + # Add placeholder node for the new constant + with graph.inserting_before(user_input_node): + new_placeholder = graph.placeholder(new_constant_name) + # Set up meta for the new placeholder + new_placeholder.meta["val"] = new_constant_tensor + + # Add the constant to the constants dict + ep.constants[new_constant_name] = new_constant_tensor + + # Update input specs to include the new constant + new_input_spec = InputSpec( + kind=InputKind.CONSTANT_TENSOR, + arg=TensorArgument(name=new_placeholder.name), + target=new_constant_name, + persistent=False, + ) + sig.input_specs = (new_input_spec, sig.input_specs[0]) + + return ExportedProgramPassResult(ep, modified=True) + + class IdentityModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x + + exported_program = to_edge( + torch.export.export(IdentityModule(), (torch.randn(3),)) + ).exported_program() + assert len(exported_program.constants) == 0 + assert len(exported_program.graph_signature.input_specs) == 1 + + pm = ExportedProgramPassManager(passes=[AddConstantPass()]) + result = pm(exported_program) + + self.assertIsInstance(result, ExportedProgramPassResult) + self.assertTrue(result.modified) + + # Verify the new constant was added to constants dict + self.assertEqual(len(result.exported_program.constants), 1) + self.assertIn("_test_added_constant", result.exported_program.constants) + self.assertTrue( + torch.allclose( + result.exported_program.constants["_test_added_constant"], + torch.tensor([1.0, 2.0, 3.0]), + ) + ) + + # Verify input_specs was updated + self.assertEqual( + len(result.exported_program.graph_signature.input_specs), + 2, + ) + + # Verify the new placeholder exists in the graph + placeholder_names = [ + node.target + for node in result.exported_program.graph_module.graph.find_nodes( + op="placeholder" + ) + ] + self.assertTrue(len(placeholder_names) == 2) + + # Verify the new input spec has the correct kind + new_spec = None + for spec in result.exported_program.graph_signature.input_specs: + if spec.target == "_test_added_constant": + new_spec = spec + break + self.assertIsNotNone(new_spec) + self.assertEqual(new_spec.kind, InputKind.CONSTANT_TENSOR) + + def test_invalid_pass_creates_call_method(self) -> None: + """ + Tests that ExportedProgramPassManager detects invalid passes + that introduce call_method nodes. + """ + + def introduce_call_method(gm: torch.fx.GraphModule) -> PassResult: + node = list(gm.graph.nodes)[-2] + with gm.graph.inserting_after(node): + gm.graph.call_method("torch.ops.relu", (torch.randn(2),)) + return PassResult(gm, True) + + def f(x: torch.Tensor) -> torch.Tensor: + y = torch.add(x, x) + return y + + exported_program = ( + exir.capture(f, (torch.randn(10),), exir.CaptureConfig()) + .to_edge() + .exported_program + ) + + pm = ExportedProgramPassManager( + passes=[introduce_call_method], run_checks_after_each_pass=True + ) + + with self.assertRaisesRegex(Exception, "call_method"): + pm(exported_program) diff --git a/exir/tests/test_passes.py b/exir/tests/test_passes.py index 8a084ba491a..1316dffb828 100644 --- a/exir/tests/test_passes.py +++ b/exir/tests/test_passes.py @@ -1,5 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. +# Copyright 2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -74,6 +75,7 @@ ) from executorch.exir.passes.scalar_to_tensor_pass import ScalarToTensorPass from executorch.exir.passes.spec_prop_pass import SpecPropPass +from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass from executorch.exir.passes.sym_to_tensor_pass import SymToTensorPass from executorch.exir.program._program import lift_constant_tensor_pass from executorch.exir.schema import TensorShapeDynamism @@ -1036,6 +1038,53 @@ def test_alloc_node_spec(self) -> None: for node in alloc_nodes: self.assertTrue(isinstance(node.meta.get("spec", None), TensorSpec)) + def test_to_out_var_dynamic_alloc_uses_concrete_upper_bounds(self) -> None: + class DynamicRelu(nn.Module): + def forward(self, x): + return torch.relu(x) + + eager_model = DynamicRelu() + inputs = (torch.randn(2, 4, 8, 3),) + dynamic_shapes = { + "x": { + 0: torch.export.Dim("batch", min=0, max=2), + 2: torch.export.Dim("height", min=0, max=8), + 3: torch.export.Dim("width", min=0, max=8), + } + } + prog = to_edge( + export( + eager_model, + inputs, + dynamic_shapes=dynamic_shapes, + strict=True, + ), + compile_config=exir.EdgeCompileConfig(_check_ir_validity=False), + ) + new_prog = prog.transform( + [ + SpecPropPass(), + ConstraintBasedSymShapeEvalPass(), + ] + ) + + new_gm_res = ToOutVarPass()(new_prog.exported_program().graph_module) + self.assertIsNotNone(new_gm_res) + new_gm = new_gm_res.graph_module + + alloc_nodes = [] + for node in new_gm.graph.nodes: + if node.target == memory.alloc: + alloc_nodes.append(node) + + self.assertTrue(len(alloc_nodes) > 0) + for node in alloc_nodes: + alloc_spec = node.args[0] + self.assertIsInstance(alloc_spec, tuple) + shape, _dtype = alloc_spec + for dim in shape: + self.assertIsInstance(dim, int) + def test_debug_pass_file_log(self) -> None: eager_model = Mul() inputs = eager_model.get_random_inputs() diff --git a/extension/android/BUCK b/extension/android/BUCK index c7e275805e2..92cb7c8c040 100644 --- a/extension/android/BUCK +++ b/extension/android/BUCK @@ -8,17 +8,19 @@ non_fbcode_target(_kind = fb_android_library, warnings_as_errors = False, required_for_source_only_abi = True, srcs = [ - "executorch_android/src/main/java/org/pytorch/executorch/DType.java", - "executorch_android/src/main/java/org/pytorch/executorch/EValue.java", - "executorch_android/src/main/java/org/pytorch/executorch/ExecuTorchRuntime.java", - "executorch_android/src/main/java/org/pytorch/executorch/ExecutorchRuntimeException.java", - "executorch_android/src/main/java/org/pytorch/executorch/MethodMetadata.java", - "executorch_android/src/main/java/org/pytorch/executorch/Module.java", + "executorch_android/src/main/java/org/pytorch/executorch/DType.kt", + "executorch_android/src/main/java/org/pytorch/executorch/EValue.kt", + "executorch_android/src/main/java/org/pytorch/executorch/ExecuTorchRuntime.kt", + "executorch_android/src/main/java/org/pytorch/executorch/ExecutorchRuntimeException.kt", + "executorch_android/src/main/java/org/pytorch/executorch/MethodMetadata.kt", + "executorch_android/src/main/java/org/pytorch/executorch/Module.kt", "executorch_android/src/main/java/org/pytorch/executorch/Tensor.java", - "executorch_android/src/main/java/org/pytorch/executorch/annotations/Experimental.java", + "executorch_android/src/main/java/org/pytorch/executorch/annotations/Experimental.kt", ], autoglob = False, - language = "JAVA", + language = "KOTLIN", + pure_kotlin = False, + extra_kotlinc_arguments = ["-Xjvm-default=all"], deps = [ "//fbandroid/java/com/facebook/jni:jni", "//fbandroid/libraries/soloader/java/com/facebook/soloader/nativeloader:nativeloader", @@ -31,11 +33,11 @@ non_fbcode_target(_kind = fb_android_library, name = "executorch_training", warnings_as_errors = False, srcs = [ - "executorch_android/src/main/java/org/pytorch/executorch/training/SGD.java", - "executorch_android/src/main/java/org/pytorch/executorch/training/TrainingModule.java", + "executorch_android/src/main/java/org/pytorch/executorch/training/SGD.kt", + "executorch_android/src/main/java/org/pytorch/executorch/training/TrainingModule.kt", ], autoglob = False, - language = "JAVA", + language = "KOTLIN", deps = [ ":executorch", "//fbandroid/java/com/facebook/jni:jni", @@ -47,13 +49,14 @@ non_fbcode_target(_kind = fb_android_library, name = "executorch_llama", warnings_as_errors = False, srcs = [ - "executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmCallback.java", - "executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmGenerationConfig.java", - "executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java", - "executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModuleConfig.java", + "executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmCallback.kt", + "executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmGenerationConfig.kt", + "executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.kt", + "executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModuleConfig.kt", ], autoglob = False, - language = "JAVA", + language = "KOTLIN", + extra_kotlinc_arguments = ["-Xjvm-default=all"], deps = [ ":executorch", "//fbandroid/java/com/facebook/jni:jni", diff --git a/extension/android/executorch_android/android_test_setup.sh b/extension/android/executorch_android/android_test_setup.sh index 350c60b2e25..9ed1ae63da2 100644 --- a/extension/android/executorch_android/android_test_setup.sh +++ b/extension/android/executorch_android/android_test_setup.sh @@ -29,7 +29,7 @@ prepare_tinyllama() { } prepare_golden() { - local url="https://gha-artifacts.s3.amazonaws.com/pytorch/executorch/test-backend-artifacts/golden-artifacts-xnnpack/golden_artifacts_26022500.zip" + local url="https://gha-artifacts.s3.amazonaws.com/pytorch/executorch/test-backend-artifacts/golden-artifacts-xnnpack/golden_artifacts_26052718.zip" curl -sL -o /tmp/golden.zip "$url" unzip -o /tmp/golden.zip -d /tmp/golden/ for model in mobilenet_v2 vit_b_16; do diff --git a/extension/android/executorch_android/build.gradle b/extension/android/executorch_android/build.gradle index 3ee5b5877b3..2dbe0e1fb5f 100644 --- a/extension/android/executorch_android/build.gradle +++ b/extension/android/executorch_android/build.gradle @@ -51,6 +51,7 @@ android { } kotlinOptions { jvmTarget = "11" + freeCompilerArgs += ["-Xjvm-default=all"] } } diff --git a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/AsrModuleInstrumentationTest.kt b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/AsrModuleInstrumentationTest.kt new file mode 100644 index 00000000000..fe8a168e406 --- /dev/null +++ b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/AsrModuleInstrumentationTest.kt @@ -0,0 +1,260 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +package org.pytorch.executorch + +import androidx.test.ext.junit.runners.AndroidJUnit4 +import java.io.File +import java.io.IOException +import org.apache.commons.io.FileUtils +import org.junit.Assert.assertEquals +import org.junit.Assert.assertFalse +import org.junit.Assert.assertTrue +import org.junit.Assert.fail +import org.junit.Assume.assumeNotNull +import org.junit.Test +import org.junit.runner.RunWith +import org.pytorch.executorch.TestFileUtils.getTestFilePath +import org.pytorch.executorch.extension.asr.AsrCallback +import org.pytorch.executorch.extension.asr.AsrModule +import org.pytorch.executorch.extension.asr.AsrTranscribeConfig + +/** + * Instrumentation tests for [AsrModule], [AsrTranscribeConfig], and [AsrCallback]. + * + * Tests cover: + * - Constructor validation (invalid model/tokenizer/preprocessor paths) + * - AsrTranscribeConfig builder and validation + * - Lifecycle (close idempotency, use-after-close) + * - Transcribe validation (invalid WAV path) + * + * The test fixture is the TinyStories-110M LLM model, NOT an ASR model, so functional transcription + * tests are not possible. Tests that require a valid AsrModule instance handle the case where + * nativeCreate fails (stories.pte lacks encoder/text_decoder methods). + */ +@RunWith(AndroidJUnit4::class) +class AsrModuleInstrumentationTest { + + // ─── Constructor validation ───────────────────────────────────────────────── + + @Test(timeout = 30_000) + fun testInvalidModelPathThrows() { + try { + AsrModule("/nonexistent/model.pte", "/nonexistent/tokenizer") + fail("Should throw for invalid model path") + } catch (_: IllegalArgumentException) { + // Expected: require(modelFile.canRead() && modelFile.isFile) + } + } + + @Test(timeout = 30_000) + fun testInvalidTokenizerPathThrows() { + val modelFile = provisionModelFile() + assumeNotNull("Test resource $MODEL_FILE_NAME not available", modelFile) + try { + AsrModule(modelFile!!.absolutePath, "/nonexistent/tokenizer") + fail("Should throw for invalid tokenizer path") + } catch (_: IllegalArgumentException) { + // Expected: require(tokenizerFile.exists()) + } + } + + @Test(timeout = 30_000) + fun testInvalidPreprocessorPathThrows() { + val modelFile = provisionModelFile() + val tokenizerFile = provisionTokenizerFile() + assumeNotNull("Test resource $MODEL_FILE_NAME not available", modelFile) + assumeNotNull("Test resource $TOKENIZER_FILE_NAME not available", tokenizerFile) + try { + AsrModule( + modelFile!!.absolutePath, + tokenizerFile!!.absolutePath, + preprocessorPath = "/nonexistent/preprocessor.pte", + ) + fail("Should throw for invalid preprocessor path") + } catch (_: IllegalArgumentException) { + // Expected: require(preprocessorFile.canRead() && preprocessorFile.isFile) + } + } + + @Test(timeout = 30_000) + fun testNonAsrModelFailsGracefully() { + val modelFile = provisionModelFile() + val tokenizerFile = provisionTokenizerFile() + assumeNotNull("Test resource $MODEL_FILE_NAME not available", modelFile) + assumeNotNull("Test resource $TOKENIZER_FILE_NAME not available", tokenizerFile) + try { + val module = AsrModule(modelFile!!.absolutePath, tokenizerFile!!.absolutePath) + // If construction succeeds (model was accepted), verify basic state + assertTrue("Module should be valid after construction", module.isValid) + module.close() + } catch (_: ExecutorchRuntimeException) { + // Expected: nativeCreate returns 0 for non-ASR model + } catch (_: RuntimeException) { + // Also acceptable: native layer rejects the model + } + } + + // ─── Lifecycle ────────────────────────────────────────────────────────────── + + @Test(timeout = 30_000) + fun testCloseIsIdempotent() { + val module = tryCreateAsrModule() ?: return + module.close() + module.close() + module.close() + assertFalse("isValid must be false after close", module.isValid) + } + + @Test(timeout = 30_000) + fun testLoadAfterCloseThrows() { + val module = tryCreateAsrModule() ?: return + module.close() + try { + module.load() + fail("load() after close() must throw IllegalStateException") + } catch (_: IllegalStateException) { + // Expected + } + } + + @Test(timeout = 30_000) + fun testTranscribeAfterCloseThrows() { + val module = tryCreateAsrModule() ?: return + module.close() + try { + module.transcribe("/some/audio.wav") + fail("transcribe() after close() must throw IllegalStateException") + } catch (_: IllegalStateException) { + // Expected + } + } + + @Test(timeout = 30_000) + fun testIsValidAndIsLoadedState() { + val module = tryCreateAsrModule() ?: return + assertTrue("Module should be valid after construction", module.isValid) + module.close() + assertFalse("Module should not be valid after close", module.isValid) + assertFalse("Module should not be loaded after close", module.isLoaded) + } + + // ─── Transcribe validation ────────────────────────────────────────────────── + + @Test(timeout = 30_000) + fun testTranscribeInvalidWavPathThrows() { + val module = tryCreateAsrModule() ?: return + try { + module.transcribe("/nonexistent/audio.wav") + fail("transcribe() with invalid WAV path must throw") + } catch (_: IllegalArgumentException) { + // Expected: require(wavFile.canRead() && wavFile.isFile) + } finally { + module.close() + } + } + + // ─── AsrTranscribeConfig ──────────────────────────────────────────────────── + + @Test + fun testConfigDefaults() { + val config = AsrTranscribeConfig() + assertEquals(128L, config.maxNewTokens) + assertEquals(0.0f, config.temperature, 0.0f) + assertEquals(0L, config.decoderStartTokenId) + } + + @Test + fun testConfigBuilder() { + val config = + AsrTranscribeConfig.Builder() + .setMaxNewTokens(256) + .setTemperature(0.7f) + .setDecoderStartTokenId(50258) + .build() + assertEquals(256L, config.maxNewTokens) + assertEquals(0.7f, config.temperature, 0.001f) + assertEquals(50258L, config.decoderStartTokenId) + } + + @Test + fun testConfigCustomValues() { + val config = AsrTranscribeConfig(maxNewTokens = 64, temperature = 0.5f, decoderStartTokenId = 1) + assertEquals(64L, config.maxNewTokens) + assertEquals(0.5f, config.temperature, 0.001f) + assertEquals(1L, config.decoderStartTokenId) + } + + @Test(expected = IllegalArgumentException::class) + fun testConfigZeroMaxNewTokensThrows() { + AsrTranscribeConfig(maxNewTokens = 0) + } + + @Test(expected = IllegalArgumentException::class) + fun testConfigNegativeMaxNewTokensThrows() { + AsrTranscribeConfig(maxNewTokens = -1) + } + + @Test(expected = IllegalArgumentException::class) + fun testConfigNegativeTemperatureThrows() { + AsrTranscribeConfig(temperature = -0.1f) + } + + @Test(expected = IllegalArgumentException::class) + fun testConfigBuilderZeroMaxNewTokensThrows() { + AsrTranscribeConfig.Builder().setMaxNewTokens(0).build() + } + + @Test(expected = IllegalArgumentException::class) + fun testConfigBuilderNegativeTemperatureThrows() { + AsrTranscribeConfig.Builder().setTemperature(-1.0f).build() + } + + @Test + fun testConfigDataClassEquality() { + val a = AsrTranscribeConfig(maxNewTokens = 100, temperature = 0.5f, decoderStartTokenId = 42) + val b = AsrTranscribeConfig(maxNewTokens = 100, temperature = 0.5f, decoderStartTokenId = 42) + assertEquals(a, b) + assertEquals(a.hashCode(), b.hashCode()) + } + + // ─── Helpers ──────────────────────────────────────────────────────────────── + + @Throws(IOException::class) + private fun provisionModelFile(): File? { + val pteFile = File(getTestFilePath(MODEL_FILE_NAME)) + val stream = javaClass.getResourceAsStream(MODEL_FILE_NAME) ?: return null + stream.use { FileUtils.copyInputStreamToFile(it, pteFile) } + return pteFile + } + + @Throws(IOException::class) + private fun provisionTokenizerFile(): File? { + val tokenizerFile = File(getTestFilePath(TOKENIZER_FILE_NAME)) + val stream = javaClass.getResourceAsStream(TOKENIZER_FILE_NAME) ?: return null + stream.use { FileUtils.copyInputStreamToFile(it, tokenizerFile) } + return tokenizerFile + } + + private fun tryCreateAsrModule(): AsrModule? { + val modelFile = provisionModelFile() + val tokenizerFile = provisionTokenizerFile() + assumeNotNull("Test resource $MODEL_FILE_NAME not available", modelFile) + assumeNotNull("Test resource $TOKENIZER_FILE_NAME not available", tokenizerFile) + return try { + AsrModule(modelFile!!.absolutePath, tokenizerFile!!.absolutePath) + } catch (_: RuntimeException) { + // nativeCreate may reject non-ASR models — skip lifecycle tests in that case + null + } + } + + companion object { + private const val MODEL_FILE_NAME = "/stories.pte" + private const val TOKENIZER_FILE_NAME = "/tokenizer.bin" + } +} diff --git a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/LlmLoraInstrumentationTest.kt b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/LlmLoraInstrumentationTest.kt new file mode 100644 index 00000000000..a8d35b09de2 --- /dev/null +++ b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/LlmLoraInstrumentationTest.kt @@ -0,0 +1,291 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +package org.pytorch.executorch + +import androidx.test.ext.junit.runners.AndroidJUnit4 +import java.io.File +import java.io.IOException +import org.apache.commons.io.FileUtils +import org.junit.After +import org.junit.Assert.assertTrue +import org.junit.Assert.fail +import org.junit.Before +import org.junit.Test +import org.junit.runner.RunWith +import org.pytorch.executorch.TestFileUtils.getTestFilePath +import org.pytorch.executorch.extension.llm.LlmCallback +import org.pytorch.executorch.extension.llm.LlmModule +import org.pytorch.executorch.extension.llm.LlmModuleConfig + +/** + * Instrumentation tests for LlmModule's LoRA / dataFiles constructor paths. + * + * LoRA adapters are loaded at construction time via the `dataFiles` parameter or + * `LlmModuleConfig.dataPath`. These tests verify that: + * 1. The dataFiles constructor variants produce a functional module + * 2. LlmModuleConfig with dataPath integrates correctly + * 3. Invalid data file paths are handled gracefully + * 4. Empty vs null dataFiles behave identically to no-data constructors + * + * Uses TinyStories-110M; no LoRA adapter fixture is available so functional LoRA tests + * (output-changes-with-adapter) are not possible. + */ +@RunWith(AndroidJUnit4::class) +class LlmLoraInstrumentationTest { + + private var llmModule: LlmModule? = null + + @Before + @Throws(IOException::class) + fun setUp() { + val pteFile = File(getTestFilePath(MODEL_FILE_NAME)) + requireNotNull(javaClass.getResourceAsStream(MODEL_FILE_NAME)) { + "Test resource $MODEL_FILE_NAME not found; did android_test_setup.sh run?" + } + .use { FileUtils.copyInputStreamToFile(it, pteFile) } + + val tokenizerFile = File(getTestFilePath(TOKENIZER_FILE_NAME)) + requireNotNull(javaClass.getResourceAsStream(TOKENIZER_FILE_NAME)) { + "Test resource $TOKENIZER_FILE_NAME not found; did android_test_setup.sh run?" + } + .use { FileUtils.copyInputStreamToFile(it, tokenizerFile) } + } + + @After + fun tearDown() { + llmModule?.close() + llmModule = null + } + + // ─── dataFiles constructor variants ───────────────────────────────────────── + + @Test(timeout = MAX_TEST_TIMEOUT_MS) + fun testConstructorWithEmptyDataFilesList() { + llmModule = + LlmModule( + LlmModule.MODEL_TYPE_TEXT, + getTestFilePath(MODEL_FILE_NAME), + getTestFilePath(TOKENIZER_FILE_NAME), + 0.0f, + emptyList(), + ) + val tokens = generateAndCollect(llmModule!!) + assertTrue("Module with empty dataFiles should generate tokens", tokens.isNotEmpty()) + } + + @Test(timeout = MAX_TEST_TIMEOUT_MS) + fun testConstructorWithNullDataPath() { + llmModule = + LlmModule( + LlmModule.MODEL_TYPE_TEXT, + getTestFilePath(MODEL_FILE_NAME), + getTestFilePath(TOKENIZER_FILE_NAME), + 0.0f, + null as String?, + ) + val tokens = generateAndCollect(llmModule!!) + assertTrue("Module with null dataPath should generate tokens", tokens.isNotEmpty()) + } + + @Test(timeout = MAX_TEST_TIMEOUT_MS) + fun testConstructorWithDataFilesAndBosEos() { + llmModule = + LlmModule( + LlmModule.MODEL_TYPE_TEXT, + getTestFilePath(MODEL_FILE_NAME), + getTestFilePath(TOKENIZER_FILE_NAME), + 0.0f, + emptyList(), + 0, + 0, + ) + val tokens = generateAndCollect(llmModule!!) + assertTrue("Module with dataFiles+BOS/EOS should generate tokens", tokens.isNotEmpty()) + } + + // ─── LlmModuleConfig with dataPath ────────────────────────────────────────── + + @Test(timeout = MAX_TEST_TIMEOUT_MS) + fun testLlmModuleConfigNoDataPath() { + val config = + LlmModuleConfig.create() + .modulePath(getTestFilePath(MODEL_FILE_NAME)) + .tokenizerPath(getTestFilePath(TOKENIZER_FILE_NAME)) + .temperature(0.0f) + .build() + llmModule = LlmModule(config) + val tokens = generateAndCollect(llmModule!!) + assertTrue("Module via config with no dataPath should generate tokens", tokens.isNotEmpty()) + } + + @Test(timeout = MAX_TEST_TIMEOUT_MS) + fun testLlmModuleConfigWithNullDataPath() { + val config = + LlmModuleConfig.create() + .modulePath(getTestFilePath(MODEL_FILE_NAME)) + .tokenizerPath(getTestFilePath(TOKENIZER_FILE_NAME)) + .temperature(0.0f) + .dataPath(null) + .build() + llmModule = LlmModule(config) + val tokens = generateAndCollect(llmModule!!) + assertTrue("Module via config with null dataPath should generate tokens", tokens.isNotEmpty()) + } + + @Test(timeout = MAX_TEST_TIMEOUT_MS) + fun testLlmModuleConfigWithLoadMode() { + val config = + LlmModuleConfig.create() + .modulePath(getTestFilePath(MODEL_FILE_NAME)) + .tokenizerPath(getTestFilePath(TOKENIZER_FILE_NAME)) + .temperature(0.0f) + .loadMode(LlmModuleConfig.LOAD_MODE_FILE) + .build() + llmModule = LlmModule(config) + val tokens = generateAndCollect(llmModule!!) + assertTrue("Module via config with LOAD_MODE_FILE should generate tokens", tokens.isNotEmpty()) + } + + // ─── Invalid data file paths ──────────────────────────────────────────────── + + @Test(timeout = MAX_TEST_TIMEOUT_MS) + fun testInvalidDataFilePathThrowsOnConstruction() { + try { + llmModule = + LlmModule( + LlmModule.MODEL_TYPE_TEXT, + getTestFilePath(MODEL_FILE_NAME), + getTestFilePath(TOKENIZER_FILE_NAME), + 0.0f, + listOf("/nonexistent/lora_weights.bin"), + ) + // dataFiles are passed to native initHybrid — invalid paths should cause + // construction to fail. If we reach here, the native layer didn't validate. + llmModule!!.close() + fail("Construction should have thrown for invalid data file path") + } catch (e: RuntimeException) { + assertTrue( + "Exception message should be non-empty", + e.message != null && e.message!!.isNotEmpty(), + ) + } + } + + @Test(timeout = MAX_TEST_TIMEOUT_MS) + fun testMultipleInvalidDataFilePathsThrowOnConstruction() { + try { + llmModule = + LlmModule( + LlmModule.MODEL_TYPE_TEXT, + getTestFilePath(MODEL_FILE_NAME), + getTestFilePath(TOKENIZER_FILE_NAME), + 0.0f, + listOf("/nonexistent/a.bin", "/nonexistent/b.bin"), + ) + llmModule!!.close() + fail("Construction should have thrown for invalid data file paths") + } catch (e: RuntimeException) { + assertTrue( + "Exception message should be non-empty", + e.message != null && e.message!!.isNotEmpty(), + ) + } + } + + // ─── Baseline equivalence ─────────────────────────────────────────────────── + + @Test(timeout = MAX_TEST_TIMEOUT_MS) + fun testEmptyDataFilesMatchesNoDataConstructor() { + val moduleNoData = + LlmModule(getTestFilePath(MODEL_FILE_NAME), getTestFilePath(TOKENIZER_FILE_NAME), 0.0f) + val moduleEmptyList = + LlmModule( + LlmModule.MODEL_TYPE_TEXT, + getTestFilePath(MODEL_FILE_NAME), + getTestFilePath(TOKENIZER_FILE_NAME), + 0.0f, + emptyList(), + ) + + try { + val tokensNoData = generateAndCollect(moduleNoData) + val tokensEmptyList = generateAndCollect(moduleEmptyList) + + assertTrue("Both constructors should produce tokens", tokensNoData.isNotEmpty()) + assertTrue("Both constructors should produce tokens", tokensEmptyList.isNotEmpty()) + } finally { + moduleNoData.close() + moduleEmptyList.close() + } + } + + // ─── LlmModuleConfig builder validation ───────────────────────────────────── + + @Test(expected = IllegalArgumentException::class) + fun testConfigBuilderMissingModulePathThrows() { + LlmModuleConfig.create().tokenizerPath("/some/tokenizer.bin").build() + } + + @Test(expected = IllegalArgumentException::class) + fun testConfigBuilderMissingTokenizerPathThrows() { + LlmModuleConfig.create().modulePath("/some/model.pte").build() + } + + @Test(expected = IllegalArgumentException::class) + fun testConfigBuilderInvalidLoadModeThrows() { + LlmModuleConfig.create() + .modulePath("/some/model.pte") + .tokenizerPath("/some/tokenizer.bin") + .loadMode(99) + .build() + } + + @Test + fun testConfigBuilderAllLoadModes() { + val modes = + listOf( + LlmModuleConfig.LOAD_MODE_FILE, + LlmModuleConfig.LOAD_MODE_MMAP, + LlmModuleConfig.LOAD_MODE_MMAP_USE_MLOCK, + LlmModuleConfig.LOAD_MODE_MMAP_USE_MLOCK_IGNORE_ERRORS, + ) + for (mode in modes) { + val config = + LlmModuleConfig.create() + .modulePath("/some/model.pte") + .tokenizerPath("/some/tokenizer.bin") + .loadMode(mode) + .build() + assertTrue("Config should accept load mode $mode", config.loadMode == mode) + } + } + + // ─── Helpers ──────────────────────────────────────────────────────────────── + + private fun generateAndCollect(module: LlmModule): List { + val collector = mutableListOf() + module.generate( + TEST_PROMPT, + SEQ_LEN, + object : LlmCallback { + override fun onResult(result: String) { + collector.add(result) + } + }, + ) + return collector + } + + companion object { + private const val MODEL_FILE_NAME = "/stories.pte" + private const val TOKENIZER_FILE_NAME = "/tokenizer.bin" + private const val TEST_PROMPT = "Once" + private const val SEQ_LEN = 16 + private const val MAX_TEST_TIMEOUT_MS = 120_000L + } +} diff --git a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt index b2f10537c2f..1888466ffa6 100644 --- a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt +++ b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt @@ -94,7 +94,7 @@ class ModuleInstrumentationTest { } Assert.assertEquals( ExecutorchRuntimeException.INVALID_ARGUMENT, - exception.getErrorCode(), + exception.errorCode, ) } finally { module.destroy() diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/DType.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/DType.kt similarity index 77% rename from extension/android/executorch_android/src/main/java/org/pytorch/executorch/DType.java rename to extension/android/executorch_android/src/main/java/org/pytorch/executorch/DType.kt index 3aca4871d64..a58baa34b60 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/DType.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/DType.kt @@ -6,17 +6,17 @@ * LICENSE file in the root directory of this source tree. */ -package org.pytorch.executorch; +package org.pytorch.executorch -import org.pytorch.executorch.annotations.Experimental; +import org.pytorch.executorch.annotations.Experimental /** * Codes representing tensor data types. * - *

Warning: These APIs are experimental and subject to change without notice + * Warning: These APIs are experimental and subject to change without notice */ @Experimental -public enum DType { +enum class DType(@JvmField val jniCode: Int) { // NOTE: "jniCode" must be kept in sync with scalar_type.h. // NOTE: Never serialize "jniCode", because it can change between releases. @@ -68,18 +68,10 @@ public enum DType { BITS16(22), ; - final int jniCode; - - DType(int jniCode) { - this.jniCode = jniCode; - } - - public static DType fromJniCode(int jniCode) { - for (DType dtype : values()) { - if (dtype.jniCode == jniCode) { - return dtype; - } - } - throw new IllegalArgumentException("No DType found for jniCode " + jniCode); + companion object { + @JvmStatic + fun fromJniCode(jniCode: Int): DType = + entries.find { it.jniCode == jniCode } + ?: throw IllegalArgumentException("No DType found for jniCode $jniCode") } } diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/EValue.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/EValue.java deleted file mode 100644 index e85efb291e7..00000000000 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/EValue.java +++ /dev/null @@ -1,253 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -package org.pytorch.executorch; - -import com.facebook.jni.annotations.DoNotStrip; -import java.nio.ByteBuffer; -import java.nio.charset.StandardCharsets; -import java.util.Arrays; -import java.util.Locale; -import org.pytorch.executorch.annotations.Experimental; - -/** - * Java representation of an ExecuTorch value, which is implemented as tagged union that can be one - * of the supported types: https://pytorch.org/docs/stable/jit.html#types . - * - *

Calling {@code toX} methods for inappropriate types will throw {@link IllegalStateException}. - * - *

{@code EValue} objects are constructed with {@code EValue.from(value)}, {@code - * EValue.tupleFrom(value1, value2, ...)}, {@code EValue.listFrom(value1, value2, ...)}, or one of - * the {@code dict} methods, depending on the key type. - * - *

Data is retrieved from {@code EValue} objects with the {@code toX()} methods. Note that {@code - * str}-type EValues must be extracted with {@link #toStr()}, rather than {@link #toString()}. - * - *

{@code EValue} objects may retain references to objects passed into their constructors, and - * may return references to their internal state from {@code toX()}. - * - *

Warning: These APIs are experimental and subject to change without notice - */ -@Experimental -@DoNotStrip -public class EValue { - private static final int TYPE_CODE_NONE = 0; - - private static final int TYPE_CODE_TENSOR = 1; - private static final int TYPE_CODE_STRING = 2; - private static final int TYPE_CODE_DOUBLE = 3; - private static final int TYPE_CODE_INT = 4; - private static final int TYPE_CODE_BOOL = 5; - - private String[] TYPE_NAMES = { - "None", "Tensor", "String", "Double", "Int", "Bool", - }; - - @DoNotStrip private final int mTypeCode; - @DoNotStrip private Object mData; - - @DoNotStrip - private EValue(int typeCode) { - this.mTypeCode = typeCode; - } - - @DoNotStrip - public boolean isNone() { - return TYPE_CODE_NONE == this.mTypeCode; - } - - @DoNotStrip - public boolean isTensor() { - return TYPE_CODE_TENSOR == this.mTypeCode; - } - - @DoNotStrip - public boolean isBool() { - return TYPE_CODE_BOOL == this.mTypeCode; - } - - @DoNotStrip - public boolean isInt() { - return TYPE_CODE_INT == this.mTypeCode; - } - - @DoNotStrip - public boolean isDouble() { - return TYPE_CODE_DOUBLE == this.mTypeCode; - } - - @DoNotStrip - public boolean isString() { - return TYPE_CODE_STRING == this.mTypeCode; - } - - /** Creates a new {@code EValue} of type {@code Optional} that contains no value. */ - @DoNotStrip - public static EValue optionalNone() { - return new EValue(TYPE_CODE_NONE); - } - - /** Creates a new {@code EValue} of type {@code Tensor}. */ - @DoNotStrip - public static EValue from(Tensor tensor) { - final EValue iv = new EValue(TYPE_CODE_TENSOR); - iv.mData = tensor; - return iv; - } - - /** Creates a new {@code EValue} of type {@code bool}. */ - @DoNotStrip - public static EValue from(boolean value) { - final EValue iv = new EValue(TYPE_CODE_BOOL); - iv.mData = value; - return iv; - } - - /** Creates a new {@code EValue} of type {@code int}. */ - @DoNotStrip - public static EValue from(long value) { - final EValue iv = new EValue(TYPE_CODE_INT); - iv.mData = value; - return iv; - } - - /** Creates a new {@code EValue} of type {@code double}. */ - @DoNotStrip - public static EValue from(double value) { - final EValue iv = new EValue(TYPE_CODE_DOUBLE); - iv.mData = value; - return iv; - } - - /** Creates a new {@code EValue} of type {@code str}. */ - @DoNotStrip - public static EValue from(String value) { - final EValue iv = new EValue(TYPE_CODE_STRING); - iv.mData = value; - return iv; - } - - @DoNotStrip - public Tensor toTensor() { - preconditionType(TYPE_CODE_TENSOR, mTypeCode); - return (Tensor) mData; - } - - @DoNotStrip - public boolean toBool() { - preconditionType(TYPE_CODE_BOOL, mTypeCode); - return (boolean) mData; - } - - @DoNotStrip - public long toInt() { - preconditionType(TYPE_CODE_INT, mTypeCode); - return (long) mData; - } - - @DoNotStrip - public double toDouble() { - preconditionType(TYPE_CODE_DOUBLE, mTypeCode); - return (double) mData; - } - - @DoNotStrip - public String toStr() { - preconditionType(TYPE_CODE_STRING, mTypeCode); - return (String) mData; - } - - private void preconditionType(int typeCodeExpected, int typeCode) { - if (typeCode != typeCodeExpected) { - throw new IllegalStateException( - String.format( - Locale.US, - "Expected EValue type %s, actual type %s", - getTypeName(typeCodeExpected), - getTypeName(typeCode))); - } - } - - private String getTypeName(int typeCode) { - return typeCode >= 0 && typeCode < TYPE_NAMES.length ? TYPE_NAMES[typeCode] : "Unknown"; - } - - /** - * Serializes an {@code EValue} into a byte array. Note: This method is experimental and subject - * to change without notice. - * - * @return The serialized byte array. - */ - public byte[] toByteArray() { - if (isNone()) { - return ByteBuffer.allocate(1).put((byte) TYPE_CODE_NONE).array(); - } else if (isTensor()) { - Tensor t = toTensor(); - byte[] tByteArray = t.toByteArray(); - return ByteBuffer.allocate(1 + tByteArray.length) - .put((byte) TYPE_CODE_TENSOR) - .put(tByteArray) - .array(); - } else if (isBool()) { - return ByteBuffer.allocate(2) - .put((byte) TYPE_CODE_BOOL) - .put((byte) (toBool() ? 1 : 0)) - .array(); - } else if (isInt()) { - return ByteBuffer.allocate(9).put((byte) TYPE_CODE_INT).putLong(toInt()).array(); - } else if (isDouble()) { - return ByteBuffer.allocate(9).put((byte) TYPE_CODE_DOUBLE).putDouble(toDouble()).array(); - } else if (isString()) { - byte[] strBytes = toStr().getBytes(StandardCharsets.UTF_8); - return ByteBuffer.allocate(1 + 4 + strBytes.length) - .put((byte) TYPE_CODE_STRING) - .putInt(strBytes.length) - .put(strBytes) - .array(); - } else { - throw new IllegalArgumentException("Unknown EValue type code: " + mTypeCode); - } - } - - /** - * Deserializes an {@code EValue} from a byte[]. Note: This method is experimental and subject to - * change without notice. - * - * @param bytes The byte array to deserialize from. - * @return The deserialized {@code EValue}. - */ - public static EValue fromByteArray(byte[] bytes) { - ByteBuffer buffer = ByteBuffer.wrap(bytes); - if (buffer == null) { - throw new IllegalArgumentException("buffer cannot be null"); - } - if (!buffer.hasRemaining()) { - throw new IllegalArgumentException("invalid buffer"); - } - int typeCode = buffer.get(); - switch (typeCode) { - case TYPE_CODE_NONE: - return new EValue(TYPE_CODE_NONE); - case TYPE_CODE_TENSOR: - byte[] bufferArray = buffer.array(); - return from(Tensor.fromByteArray(Arrays.copyOfRange(bufferArray, 1, bufferArray.length))); - case TYPE_CODE_STRING: - int strLen = buffer.getInt(); - byte[] strBytes = new byte[strLen]; - buffer.get(strBytes); - return from(new String(strBytes, StandardCharsets.UTF_8)); - case TYPE_CODE_DOUBLE: - return from(buffer.getDouble()); - case TYPE_CODE_INT: - return from(buffer.getLong()); - case TYPE_CODE_BOOL: - return from(buffer.get() != 0); - } - throw new IllegalArgumentException("invalid type code: " + typeCode); - } -} diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/EValue.kt b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/EValue.kt new file mode 100644 index 00000000000..08c02d5c84a --- /dev/null +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/EValue.kt @@ -0,0 +1,209 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.executorch + +import com.facebook.jni.annotations.DoNotStrip +import java.nio.ByteBuffer +import java.nio.charset.StandardCharsets +import java.util.Arrays +import java.util.Locale +import org.pytorch.executorch.annotations.Experimental + +/** + * Java representation of an ExecuTorch value, which is implemented as tagged union that can be one + * of the supported types: https://pytorch.org/docs/stable/jit.html#types . + * + * Calling `toX` methods for inappropriate types will throw [IllegalStateException]. + * + * `EValue` objects are constructed with `EValue.from(value)`, depending on the value type. + * + * Data is retrieved from `EValue` objects with the `toX()` methods. Note that `str`-type EValues + * must be extracted with [toStr], rather than [toString]. + * + * `EValue` objects may retain references to objects passed into their constructors, and may return + * references to their internal state from `toX()`. + * + * Warning: These APIs are experimental and subject to change without notice + */ +@Experimental +@DoNotStrip +class EValue +@DoNotStrip +private constructor( + // JNI reads this field by name via GetFieldID("mTypeCode") + @JvmField @DoNotStrip val mTypeCode: Int +) { + + // JNI accesses this field by name via GetFieldID("mData"), requires @JvmField for direct field + // access + @JvmField @DoNotStrip var mData: Any? = null + + private val typeNames = arrayOf("None", "Tensor", "String", "Double", "Int", "Bool") + + val isNone: Boolean + @DoNotStrip get() = TYPE_CODE_NONE == mTypeCode + + val isTensor: Boolean + @DoNotStrip get() = TYPE_CODE_TENSOR == mTypeCode + + val isBool: Boolean + @DoNotStrip get() = TYPE_CODE_BOOL == mTypeCode + + val isInt: Boolean + @DoNotStrip get() = TYPE_CODE_INT == mTypeCode + + val isDouble: Boolean + @DoNotStrip get() = TYPE_CODE_DOUBLE == mTypeCode + + val isString: Boolean + @DoNotStrip get() = TYPE_CODE_STRING == mTypeCode + + @DoNotStrip + fun toTensor(): Tensor { + preconditionType(TYPE_CODE_TENSOR, mTypeCode) + return mData as? Tensor ?: throw IllegalStateException("EValue data is null or not a Tensor") + } + + @DoNotStrip + fun toBool(): Boolean { + preconditionType(TYPE_CODE_BOOL, mTypeCode) + return mData as? Boolean ?: throw IllegalStateException("EValue data is null or not a Boolean") + } + + @DoNotStrip + fun toInt(): Long { + preconditionType(TYPE_CODE_INT, mTypeCode) + return mData as? Long ?: throw IllegalStateException("EValue data is null or not a Long") + } + + @DoNotStrip + fun toDouble(): Double { + preconditionType(TYPE_CODE_DOUBLE, mTypeCode) + return mData as? Double ?: throw IllegalStateException("EValue data is null or not a Double") + } + + @DoNotStrip + fun toStr(): String { + preconditionType(TYPE_CODE_STRING, mTypeCode) + return mData as? String ?: throw IllegalStateException("EValue data is null or not a String") + } + + private fun preconditionType(typeCodeExpected: Int, typeCode: Int) { + if (typeCode != typeCodeExpected) { + throw IllegalStateException( + String.format( + Locale.US, + "Expected EValue type %s, actual type %s", + getTypeName(typeCodeExpected), + getTypeName(typeCode), + ) + ) + } + } + + private fun getTypeName(typeCode: Int): String = + if (typeCode in typeNames.indices) typeNames[typeCode] else "Unknown" + + /** + * Serializes an `EValue` into a byte array. Note: This method is experimental and subject to + * change without notice. + */ + fun toByteArray(): ByteArray = + when { + isNone -> ByteBuffer.allocate(1).put(TYPE_CODE_NONE.toByte()).array() + isTensor -> { + val tByteArray = toTensor().toByteArray() + ByteBuffer.allocate(1 + tByteArray.size) + .put(TYPE_CODE_TENSOR.toByte()) + .put(tByteArray) + .array() + } + isBool -> + ByteBuffer.allocate(2) + .put(TYPE_CODE_BOOL.toByte()) + .put(if (toBool()) 1.toByte() else 0.toByte()) + .array() + isInt -> ByteBuffer.allocate(9).put(TYPE_CODE_INT.toByte()).putLong(toInt()).array() + isDouble -> + ByteBuffer.allocate(9).put(TYPE_CODE_DOUBLE.toByte()).putDouble(toDouble()).array() + isString -> { + val strBytes = toStr().toByteArray(StandardCharsets.UTF_8) + ByteBuffer.allocate(1 + 4 + strBytes.size) + .put(TYPE_CODE_STRING.toByte()) + .putInt(strBytes.size) + .put(strBytes) + .array() + } + else -> throw IllegalArgumentException("Unknown EValue type code: $mTypeCode") + } + + companion object { + private const val TYPE_CODE_NONE = 0 + private const val TYPE_CODE_TENSOR = 1 + private const val TYPE_CODE_STRING = 2 + private const val TYPE_CODE_DOUBLE = 3 + private const val TYPE_CODE_INT = 4 + private const val TYPE_CODE_BOOL = 5 + + /** Creates a new `EValue` of type `Optional` that contains no value. */ + @DoNotStrip @JvmStatic fun optionalNone(): EValue = EValue(TYPE_CODE_NONE) + + /** Creates a new `EValue` of type `Tensor`. */ + @DoNotStrip + @JvmStatic + fun from(tensor: Tensor): EValue = EValue(TYPE_CODE_TENSOR).also { it.mData = tensor } + + /** Creates a new `EValue` of type `bool`. */ + @DoNotStrip + @JvmStatic + fun from(value: Boolean): EValue = EValue(TYPE_CODE_BOOL).also { it.mData = value } + + /** Creates a new `EValue` of type `int`. */ + @DoNotStrip + @JvmStatic + fun from(value: Long): EValue = EValue(TYPE_CODE_INT).also { it.mData = value } + + /** Creates a new `EValue` of type `double`. */ + @DoNotStrip + @JvmStatic + fun from(value: Double): EValue = EValue(TYPE_CODE_DOUBLE).also { it.mData = value } + + /** Creates a new `EValue` of type `str`. */ + @DoNotStrip + @JvmStatic + fun from(value: String): EValue = EValue(TYPE_CODE_STRING).also { it.mData = value } + + /** + * Deserializes an `EValue` from a byte[]. Note: This method is experimental and subject to + * change without notice. + */ + @JvmStatic + fun fromByteArray(bytes: ByteArray): EValue { + val buffer = ByteBuffer.wrap(bytes) + require(buffer.hasRemaining()) { "invalid buffer" } + return when (val typeCode = buffer.get().toInt()) { + TYPE_CODE_NONE -> EValue(TYPE_CODE_NONE) + TYPE_CODE_TENSOR -> { + val bufferArray = buffer.array() + from(Tensor.fromByteArray(Arrays.copyOfRange(bufferArray, 1, bufferArray.size))) + } + TYPE_CODE_STRING -> { + val strLen = buffer.getInt() + val strBytes = ByteArray(strLen) + buffer.get(strBytes) + from(String(strBytes, StandardCharsets.UTF_8)) + } + TYPE_CODE_DOUBLE -> from(buffer.getDouble()) + TYPE_CODE_INT -> from(buffer.getLong()) + TYPE_CODE_BOOL -> from(buffer.get().toInt() != 0) + else -> throw IllegalArgumentException("invalid type code: $typeCode") + } + } + } +} diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecuTorchRuntime.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecuTorchRuntime.java deleted file mode 100644 index 6372da9a397..00000000000 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecuTorchRuntime.java +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -package org.pytorch.executorch; - -import com.facebook.jni.annotations.DoNotStrip; -import com.facebook.soloader.nativeloader.NativeLoader; -import com.facebook.soloader.nativeloader.SystemDelegate; -import java.io.File; - -/** Class for entire ExecuTorch Runtime related functions. */ -public class ExecuTorchRuntime { - - static { - if (!NativeLoader.isInitialized()) { - NativeLoader.init(new SystemDelegate()); - } - // Loads libexecutorch.so from jniLibs - NativeLoader.loadLibrary("executorch"); - } - - private static final ExecuTorchRuntime sInstance = new ExecuTorchRuntime(); - - private ExecuTorchRuntime() {} - - /** Get the runtime instance. */ - public static ExecuTorchRuntime getRuntime() { - return sInstance; - } - - /** - * Validates that the given path points to a readable file. - * - * @throws IllegalArgumentException if the path is null, does not exist, is not a file, or is not - * readable. - */ - public static void validateFilePath(String path, String description) { - if (path == null) { - throw new IllegalArgumentException("Cannot load " + description + ": path is null"); - } - File file = new File(path); - if (!file.exists()) { - throw new IllegalArgumentException( - "Cannot load " + description + ": path does not exist: " + path); - } - if (!file.isFile()) { - throw new IllegalArgumentException( - "Cannot load " + description + ": path is not a file: " + path); - } - if (!file.canRead()) { - throw new IllegalArgumentException( - "Cannot load " + description + ": path is not readable: " + path); - } - } - - /** Get all registered ops. */ - @DoNotStrip - public static native String[] getRegisteredOps(); - - /** Get all registered backends. */ - @DoNotStrip - public static native String[] getRegisteredBackends(); -} diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecuTorchRuntime.kt b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecuTorchRuntime.kt new file mode 100644 index 00000000000..52d846c5647 --- /dev/null +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecuTorchRuntime.kt @@ -0,0 +1,62 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.executorch + +import com.facebook.jni.annotations.DoNotStrip +import com.facebook.soloader.nativeloader.NativeLoader +import com.facebook.soloader.nativeloader.SystemDelegate +import java.io.File + +/** Class for entire ExecuTorch Runtime related functions. */ +class ExecuTorchRuntime private constructor() { + + companion object { + init { + if (!NativeLoader.isInitialized()) { + NativeLoader.init(SystemDelegate()) + } + // Loads libexecutorch.so from jniLibs + NativeLoader.loadLibrary("executorch") + } + + private val sInstance = ExecuTorchRuntime() + + /** Get the runtime instance. */ + @JvmStatic fun getRuntime(): ExecuTorchRuntime = sInstance + + /** + * Validates that the given path points to a readable file. + * + * @throws IllegalArgumentException if the path is null, does not exist, is not a file, or is + * not readable. + */ + @JvmStatic + fun validateFilePath(path: String?, description: String) { + if (path == null) { + throw IllegalArgumentException("Cannot load $description: path is null") + } + val file = File(path) + if (!file.exists()) { + throw IllegalArgumentException("Cannot load $description: path does not exist: $path") + } + if (!file.isFile) { + throw IllegalArgumentException("Cannot load $description: path is not a file: $path") + } + if (!file.canRead()) { + throw IllegalArgumentException("Cannot load $description: path is not readable: $path") + } + } + + /** Get all registered ops. */ + @DoNotStrip @JvmStatic external fun getRegisteredOps(): Array + + /** Get all registered backends. */ + @DoNotStrip @JvmStatic external fun getRegisteredBackends(): Array + } +} diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecutorchRuntimeException.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecutorchRuntimeException.java deleted file mode 100644 index 6f9d654be66..00000000000 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecutorchRuntimeException.java +++ /dev/null @@ -1,198 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -package org.pytorch.executorch; - -import com.facebook.jni.annotations.DoNotStrip; -import java.util.Collections; -import java.util.HashMap; -import java.util.Map; - -/** - * Base exception for all ExecuTorch runtime errors. Each instance carries an integer error code - * corresponding to the native {@code runtime/core/error.h} values, accessible via {@link - * #getErrorCode()}. - */ -public class ExecutorchRuntimeException extends RuntimeException { - // Error code constants - keep in sync with runtime/core/error.h - - // System errors - - /** Operation completed successfully. */ - public static final int OK = 0x00; - - /** An unexpected internal error occurred in the runtime. */ - public static final int INTERNAL = 0x01; - - /** The runtime or method is in an invalid state for the requested operation. */ - public static final int INVALID_STATE = 0x02; - - /** The method has finished execution and has no more work to do. */ - public static final int END_OF_METHOD = 0x03; - - /** A required resource has already been loaded. */ - public static final int ALREADY_LOADED = 0x04; - - // Logical errors - - /** The requested operation is not supported by this build or backend. */ - public static final int NOT_SUPPORTED = 0x10; - - /** The requested operation has not been implemented. */ - public static final int NOT_IMPLEMENTED = 0x11; - - /** One or more arguments passed to the operation are invalid. */ - public static final int INVALID_ARGUMENT = 0x12; - - /** A value or tensor has an unexpected type. */ - public static final int INVALID_TYPE = 0x13; - - /** A required operator kernel is not registered. */ - public static final int OPERATOR_MISSING = 0x14; - - /** The maximum number of registered kernels has been exceeded. */ - public static final int REGISTRATION_EXCEEDING_MAX_KERNELS = 0x15; - - /** A kernel with the same name is already registered. */ - public static final int REGISTRATION_ALREADY_REGISTERED = 0x16; - - // Resource errors - - /** A required resource (file, tensor, program) was not found. */ - public static final int NOT_FOUND = 0x20; - - /** A memory allocation failed. */ - public static final int MEMORY_ALLOCATION_FAILED = 0x21; - - /** Access to a resource was denied or failed. */ - public static final int ACCESS_FAILED = 0x22; - - /** The loaded program is malformed or incompatible. */ - public static final int INVALID_PROGRAM = 0x23; - - /** External data referenced by the program is invalid or missing. */ - public static final int INVALID_EXTERNAL_DATA = 0x24; - - /** The system has run out of a required resource. */ - public static final int OUT_OF_RESOURCES = 0x25; - - // Delegate errors - - /** A delegate reported an incompatible model or configuration. */ - public static final int DELEGATE_INVALID_COMPATIBILITY = 0x30; - - /** A delegate failed to allocate required memory. */ - public static final int DELEGATE_MEMORY_ALLOCATION_FAILED = 0x31; - - /** A delegate received an invalid or stale handle. */ - public static final int DELEGATE_INVALID_HANDLE = 0x32; - - private static final Map ERROR_CODE_MESSAGES; - - static { - Map map = new HashMap<>(); - - // System errors - map.put(OK, "Operation successful"); - map.put(INTERNAL, "Internal error"); - map.put(INVALID_STATE, "Invalid state"); - map.put(END_OF_METHOD, "End of method reached"); - map.put(ALREADY_LOADED, "Already loaded"); - // Logical errors - map.put(NOT_SUPPORTED, "Operation not supported"); - map.put(NOT_IMPLEMENTED, "Operation not implemented"); - map.put(INVALID_ARGUMENT, "Invalid argument"); - map.put(INVALID_TYPE, "Invalid type"); - map.put(OPERATOR_MISSING, "Operator missing"); - map.put(REGISTRATION_EXCEEDING_MAX_KERNELS, "Exceeded max kernels"); - map.put(REGISTRATION_ALREADY_REGISTERED, "Kernel already registered"); - // Resource errors - map.put(NOT_FOUND, "Resource not found"); - map.put(MEMORY_ALLOCATION_FAILED, "Memory allocation failed"); - map.put(ACCESS_FAILED, "Access failed"); - map.put(INVALID_PROGRAM, "Invalid program"); - map.put(INVALID_EXTERNAL_DATA, "Invalid external data"); - map.put(OUT_OF_RESOURCES, "Out of resources"); - // Delegate errors - map.put(DELEGATE_INVALID_COMPATIBILITY, "Delegate invalid compatibility"); - map.put(DELEGATE_MEMORY_ALLOCATION_FAILED, "Delegate memory allocation failed"); - map.put(DELEGATE_INVALID_HANDLE, "Delegate invalid handle"); - ERROR_CODE_MESSAGES = Collections.unmodifiableMap(map); - } - - static class ErrorHelper { - static String formatMessage(int errorCode, String details) { - String baseMessage = ERROR_CODE_MESSAGES.get(errorCode); - if (baseMessage == null) { - baseMessage = "Unknown error code 0x" + Integer.toHexString(errorCode); - } - - String safeDetails = details != null ? details : "No details provided"; - return String.format( - "[ExecuTorch Error 0x%s] %s: %s", - Integer.toHexString(errorCode), baseMessage, safeDetails); - } - - static String getDetailedErrorLogs() { - StringBuilder sb = new StringBuilder(); - try { - String[] logEntries = Module.readLogBufferStatic(); // JNI call - if (logEntries != null && logEntries.length > 0) { - sb.append("\nDetailed logs:\n"); - for (String entry : logEntries) { - sb.append(entry).append("\n"); - } - } - } catch (Exception e) { - sb.append("Failed to retrieve detailed logs: ").append(e.getMessage()); - } - return sb.toString(); - } - } - - private final int errorCode; - - @DoNotStrip - public ExecutorchRuntimeException(int errorCode, String details) { - super(ErrorHelper.formatMessage(errorCode, details)); - this.errorCode = errorCode; - } - - public ExecutorchRuntimeException(int errorCode, String details, Throwable cause) { - super(ErrorHelper.formatMessage(errorCode, details), cause); - this.errorCode = errorCode; - } - - /** Returns the numeric error code from {@code runtime/core/error.h}. */ - public int getErrorCode() { - return errorCode; - } - - /** Returns detailed log output captured from the native runtime, if available. */ - public String getDetailedError() { - return ErrorHelper.getDetailedErrorLogs(); - } - - @DoNotStrip - public static class ExecutorchInvalidArgumentException extends ExecutorchRuntimeException { - @DoNotStrip - public ExecutorchInvalidArgumentException(String details) { - super(INVALID_ARGUMENT, details); - } - } - - @DoNotStrip - public static RuntimeException makeExecutorchException(int errorCode, String details) { - switch (errorCode) { - case INVALID_ARGUMENT: - return new ExecutorchInvalidArgumentException(details); - default: - return new ExecutorchRuntimeException(errorCode, details); - } - } -} diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecutorchRuntimeException.kt b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecutorchRuntimeException.kt new file mode 100644 index 00000000000..5ec3dd255d8 --- /dev/null +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecutorchRuntimeException.kt @@ -0,0 +1,133 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.executorch + +import com.facebook.jni.annotations.DoNotStrip + +/** + * Base exception for all ExecuTorch runtime errors. Each instance carries an integer error code + * corresponding to the native `runtime/core/error.h` values, accessible via [getErrorCode]. + */ +open class ExecutorchRuntimeException +@DoNotStrip +constructor( + val errorCode: Int, + details: String?, +) : RuntimeException(ErrorHelper.formatMessage(errorCode, details)) { + + constructor( + errorCode: Int, + details: String?, + cause: Throwable?, + ) : this(errorCode, details) { + if (cause != null) initCause(cause) + } + + /** Returns detailed log output captured from the native runtime, if available. */ + fun getDetailedError(): String = ErrorHelper.getDetailedErrorLogs() + + @DoNotStrip + class ExecutorchInvalidArgumentException @DoNotStrip constructor(details: String?) : + ExecutorchRuntimeException(INVALID_ARGUMENT, details) + + private object ErrorHelper { + private val ERROR_CODE_MESSAGES: Map = + mapOf( + // System errors + OK to "Operation successful", + INTERNAL to "Internal error", + INVALID_STATE to "Invalid state", + END_OF_METHOD to "End of method reached", + ALREADY_LOADED to "Already loaded", + // Logical errors + NOT_SUPPORTED to "Operation not supported", + NOT_IMPLEMENTED to "Operation not implemented", + INVALID_ARGUMENT to "Invalid argument", + INVALID_TYPE to "Invalid type", + OPERATOR_MISSING to "Operator missing", + REGISTRATION_EXCEEDING_MAX_KERNELS to "Exceeded max kernels", + REGISTRATION_ALREADY_REGISTERED to "Kernel already registered", + // Resource errors + NOT_FOUND to "Resource not found", + MEMORY_ALLOCATION_FAILED to "Memory allocation failed", + ACCESS_FAILED to "Access failed", + INVALID_PROGRAM to "Invalid program", + INVALID_EXTERNAL_DATA to "Invalid external data", + OUT_OF_RESOURCES to "Out of resources", + // Delegate errors + DELEGATE_INVALID_COMPATIBILITY to "Delegate invalid compatibility", + DELEGATE_MEMORY_ALLOCATION_FAILED to "Delegate memory allocation failed", + DELEGATE_INVALID_HANDLE to "Delegate invalid handle", + ) + + fun formatMessage(errorCode: Int, details: String?): String { + val baseMessage = + ERROR_CODE_MESSAGES[errorCode] ?: "Unknown error code 0x${Integer.toHexString(errorCode)}" + val safeDetails = details ?: "No details provided" + return "[ExecuTorch Error 0x${Integer.toHexString(errorCode)}] $baseMessage: $safeDetails" + } + + fun getDetailedErrorLogs(): String { + val sb = StringBuilder() + try { + val logEntries = Module.readLogBufferStatic() // JNI call + if (logEntries != null && logEntries.isNotEmpty()) { + sb.append("\nDetailed logs:\n") + for (entry in logEntries) { + sb.append(entry).append("\n") + } + } + } catch (e: Exception) { + sb.append("Failed to retrieve detailed logs: ").append(e.message) + } + return sb.toString() + } + } + + companion object { + // Error code constants - keep in sync with runtime/core/error.h + + // System errors + const val OK = 0x00 + const val INTERNAL = 0x01 + const val INVALID_STATE = 0x02 + const val END_OF_METHOD = 0x03 + const val ALREADY_LOADED = 0x04 + + // Logical errors + const val NOT_SUPPORTED = 0x10 + const val NOT_IMPLEMENTED = 0x11 + const val INVALID_ARGUMENT = 0x12 + const val INVALID_TYPE = 0x13 + const val OPERATOR_MISSING = 0x14 + const val REGISTRATION_EXCEEDING_MAX_KERNELS = 0x15 + const val REGISTRATION_ALREADY_REGISTERED = 0x16 + + // Resource errors + const val NOT_FOUND = 0x20 + const val MEMORY_ALLOCATION_FAILED = 0x21 + const val ACCESS_FAILED = 0x22 + const val INVALID_PROGRAM = 0x23 + const val INVALID_EXTERNAL_DATA = 0x24 + const val OUT_OF_RESOURCES = 0x25 + + // Delegate errors + const val DELEGATE_INVALID_COMPATIBILITY = 0x30 + const val DELEGATE_MEMORY_ALLOCATION_FAILED = 0x31 + const val DELEGATE_INVALID_HANDLE = 0x32 + + @DoNotStrip + @JvmStatic + fun makeExecutorchException(errorCode: Int, details: String?): RuntimeException = + when (errorCode) { + INVALID_ARGUMENT -> ExecutorchInvalidArgumentException(details) + else -> ExecutorchRuntimeException(errorCode, details) + } + } +} diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/MethodMetadata.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/MethodMetadata.java deleted file mode 100644 index a46b27ab39e..00000000000 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/MethodMetadata.java +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -package org.pytorch.executorch; - -/** Immutable metadata for a method in a Module. */ -public class MethodMetadata { - private final String mName; - private final String[] mBackends; - - MethodMetadata(String name, String[] backends) { - mName = name; - mBackends = backends; - } - - /** - * @return Method name - */ - public String getName() { - return mName; - } - - /** - * @return Backends used for this method - */ - public String[] getBackends() { - return mBackends; - } -} diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/MethodMetadata.kt b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/MethodMetadata.kt new file mode 100644 index 00000000000..2f25f32c92f --- /dev/null +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/MethodMetadata.kt @@ -0,0 +1,12 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.executorch + +/** Immutable metadata for a method in a Module. */ +class MethodMetadata internal constructor(val name: String, val backends: Array) diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java deleted file mode 100644 index 94a3ed8d160..00000000000 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java +++ /dev/null @@ -1,315 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -package org.pytorch.executorch; - -import com.facebook.jni.HybridData; -import com.facebook.jni.annotations.DoNotStrip; -import com.facebook.soloader.nativeloader.NativeLoader; -import com.facebook.soloader.nativeloader.SystemDelegate; -import java.io.Closeable; -import java.util.HashMap; -import java.util.Map; -import java.util.concurrent.locks.Lock; -import java.util.concurrent.locks.ReentrantLock; -import org.pytorch.executorch.annotations.Experimental; - -/** - * Java wrapper for ExecuTorch Module. - * - *

Warning: These APIs are experimental and subject to change without notice - */ -@Experimental -public class Module implements Closeable { - - static { - if (!NativeLoader.isInitialized()) { - NativeLoader.init(new SystemDelegate()); - } - // Loads libexecutorch.so from jniLibs - NativeLoader.loadLibrary("executorch"); - } - - /** Load mode for the module. Load the whole file as a buffer. */ - public static final int LOAD_MODE_FILE = 0; - - /** Load mode for the module. Use mmap to load pages into memory. */ - public static final int LOAD_MODE_MMAP = 1; - - /** Load mode for the module. Use memory locking and handle errors. */ - public static final int LOAD_MODE_MMAP_USE_MLOCK = 2; - - /** Load mode for the module. Use memory locking and ignore errors. */ - public static final int LOAD_MODE_MMAP_USE_MLOCK_IGNORE_ERRORS = 3; - - private final HybridData mHybridData; - - private final Map mMethodMetadata; - - @DoNotStrip - private static native HybridData initHybrid( - String moduleAbsolutePath, int loadMode, int numThreads); - - private Module(String moduleAbsolutePath, int loadMode, int numThreads) { - ExecuTorchRuntime runtime = ExecuTorchRuntime.getRuntime(); - - mHybridData = initHybrid(moduleAbsolutePath, loadMode, numThreads); - - mMethodMetadata = populateMethodMeta(); - } - - private Map populateMethodMeta() { - String[] methods = getMethods(); - Map metadata = new HashMap(); - for (String name : methods) { - metadata.put(name, new MethodMetadata(name, getUsedBackends(name))); - } - return metadata; - } - - /** Lock protecting the non-thread safe methods in mHybridData. */ - private Lock mLock = new ReentrantLock(); - - /** - * Loads a serialized ExecuTorch module from the specified path on the disk. - * - * @param modelPath path to file that contains the serialized ExecuTorch module. - * @param loadMode load mode for the module. See constants in {@link Module}. - * @return new {@link org.pytorch.executorch.Module} object which owns the model module. - */ - public static Module load(final String modelPath, int loadMode) { - return load(modelPath, loadMode, 0); - } - - /** - * Loads a serialized ExecuTorch module from the specified path on the disk. - * - * @param modelPath path to file that contains the serialized ExecuTorch module. - * @param loadMode load mode for the module. See constants in {@link Module}. - * @param numThreads the number of threads to use for inference. A value of 0 defaults to a - * hardware-specific default. - * @return new {@link org.pytorch.executorch.Module} object which owns the model module. - */ - public static Module load(final String modelPath, int loadMode, int numThreads) { - ExecuTorchRuntime.validateFilePath(modelPath, "model path"); - return new Module(modelPath, loadMode, numThreads); - } - - /** - * Loads a serialized ExecuTorch module from the specified path on the disk to run on CPU. - * - * @param modelPath path to file that contains the serialized ExecuTorch module. - * @return new {@link org.pytorch.executorch.Module} object which owns the model module. - */ - public static Module load(final String modelPath) { - return load(modelPath, LOAD_MODE_FILE); - } - - /** - * Runs the 'forward' method of this module with the specified arguments. - * - * @param inputs arguments for the ExecuTorch module's 'forward' method. Note: if method 'forward' - * requires inputs but no inputs are given, the function will not error out, but run 'forward' - * with sample inputs. - * @return return value from the 'forward' method. - */ - public EValue[] forward(EValue... inputs) { - return execute("forward", inputs); - } - - /** - * Runs the specified method of this module with the specified arguments. - * - * @param methodName name of the ExecuTorch method to run. - * @param inputs arguments that will be passed to ExecuTorch method. - * @return return value from the method. - */ - public EValue[] execute(String methodName, EValue... inputs) { - mLock.lock(); - try { - if (!mHybridData.isValid()) { - throw new IllegalStateException("Module has been destroyed"); - } - return executeNative(methodName, inputs); - } finally { - mLock.unlock(); - } - } - - @DoNotStrip - private native EValue[] executeNative(String methodName, EValue... inputs); - - /** - * Load a method on this module. This might help with the first time inference performance, - * because otherwise the method is loaded lazily when it's execute. Note: this function is - * synchronous, and will block until the method is loaded. Therefore, it is recommended to call - * this on a background thread. However, users need to make sure that they don't execute before - * this function returns. - */ - public void loadMethod(String methodName) { - mLock.lock(); - try { - if (!mHybridData.isValid()) { - throw new IllegalStateException("Module has been destroyed"); - } - int errorCode = loadMethodNative(methodName); - if (errorCode != 0) { - throw new ExecutorchRuntimeException(errorCode, "Failed to load method: " + methodName); - } - } finally { - mLock.unlock(); - } - } - - @DoNotStrip - private native int loadMethodNative(String methodName); - - /** - * Returns the names of the backends in a certain method. - * - * @param methodName method name to query - * @return an array of backend name - */ - @DoNotStrip - private native String[] getUsedBackends(String methodName); - - /** - * Returns the names of methods. - * - * @return name of methods in this Module - */ - public String[] getMethods() { - mLock.lock(); - try { - if (!mHybridData.isValid()) { - throw new IllegalStateException("Module has been destroyed"); - } - return getMethodsNative(); - } finally { - mLock.unlock(); - } - } - - @DoNotStrip - private native String[] getMethodsNative(); - - /** - * Get the corresponding @MethodMetadata for a method - * - * @param name method name - * @return @MethodMetadata for this method - */ - public MethodMetadata getMethodMetadata(String name) { - mLock.lock(); - try { - if (!mHybridData.isValid()) { - throw new IllegalStateException("Module has been destroyed"); - } - MethodMetadata methodMetadata = mMethodMetadata.get(name); - if (methodMetadata == null) { - throw new IllegalArgumentException("method " + name + " does not exist for this module"); - } - return methodMetadata; - } finally { - mLock.unlock(); - } - } - - @DoNotStrip - private static native String[] readLogBufferStaticNative(); - - public static String[] readLogBufferStatic() { - return readLogBufferStaticNative(); - } - - /** Retrieve the in-memory log buffer, containing the most recent ExecuTorch log entries. */ - public String[] readLogBuffer() { - mLock.lock(); - try { - if (!mHybridData.isValid()) { - throw new IllegalStateException("Module has been destroyed"); - } - return readLogBufferNative(); - } finally { - mLock.unlock(); - } - } - - @DoNotStrip - private native String[] readLogBufferNative(); - - /** - * Dump the ExecuTorch ETRecord file to /data/local/tmp/result.etdump. - * - *

Currently for internal (minibench) use only. - * - * @return true if the etdump was successfully written, false otherwise. - */ - @Experimental - public boolean etdump() { - mLock.lock(); - try { - if (!mHybridData.isValid()) { - throw new IllegalStateException("Module has been destroyed"); - } - return etdumpNative(); - } finally { - mLock.unlock(); - } - } - - @DoNotStrip - private native boolean etdumpNative(); - - /** - * Dump the ExecuTorch ETDump file to {@code outputPath}. - * - * @param outputPath absolute path to write the etdump file to. - * @return true if the etdump was successfully written, false otherwise. - */ - @Experimental - public boolean etdump(String outputPath) { - mLock.lock(); - try { - if (!mHybridData.isValid()) { - throw new IllegalStateException("Module has been destroyed"); - } - return etdumpToNative(outputPath); - } finally { - mLock.unlock(); - } - } - - @DoNotStrip - private native boolean etdumpToNative(String outputPath); - - /** - * Explicitly destroys the native Module object. Calling this method is not required, as the - * native object will be destroyed when this object is garbage-collected. However, the timing of - * garbage collection is not guaranteed, so proactively calling {@code destroy} can free memory - * more quickly. See {@link com.facebook.jni.HybridData#resetNative}. - */ - public void destroy() { - if (mLock.tryLock()) { - try { - if (mHybridData.isValid()) { - mHybridData.resetNative(); - } - } finally { - mLock.unlock(); - } - } else { - throw new IllegalStateException("Cannot destroy module while method is executing"); - } - } - - @Override - public void close() { - destroy(); - } -} diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.kt b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.kt new file mode 100644 index 00000000000..15f8dbbc992 --- /dev/null +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.kt @@ -0,0 +1,267 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.executorch + +import com.facebook.jni.HybridData +import com.facebook.jni.annotations.DoNotStrip +import com.facebook.soloader.nativeloader.NativeLoader +import com.facebook.soloader.nativeloader.SystemDelegate +import java.io.Closeable +import java.util.concurrent.locks.ReentrantLock +import org.pytorch.executorch.annotations.Experimental + +/** + * Java wrapper for ExecuTorch Module. + * + * Warning: These APIs are experimental and subject to change without notice + */ +@Experimental +open class Module private constructor(moduleAbsolutePath: String, loadMode: Int, numThreads: Int) : + Closeable { + + private val mHybridData: HybridData + private val mMethodMetadata: Map + + /** Lock protecting the non-thread safe methods in mHybridData. */ + private val mLock = ReentrantLock() + + init { + ExecuTorchRuntime.getRuntime() + mHybridData = initHybrid(moduleAbsolutePath, loadMode, numThreads) + mMethodMetadata = populateMethodMeta() + } + + private fun populateMethodMeta(): Map { + val methods = getMethodsNative() + val metadata = HashMap() + for (name in methods) { + metadata[name] = MethodMetadata(name, getUsedBackends(name)) + } + return metadata + } + + /** + * Runs the 'forward' method of this module with the specified arguments. + * + * @param inputs arguments for the ExecuTorch module's 'forward' method. Note: if method 'forward' + * requires inputs but no inputs are given, the function will not error out, but run 'forward' + * with sample inputs. + * @return return value from the 'forward' method. + */ + open fun forward(vararg inputs: EValue): Array = execute("forward", *inputs) + + /** + * Runs the specified method of this module with the specified arguments. + * + * @param methodName name of the ExecuTorch method to run. + * @param inputs arguments that will be passed to ExecuTorch method. + * @return return value from the method. + */ + open fun execute(methodName: String, vararg inputs: EValue): Array { + mLock.lock() + try { + check(mHybridData.isValid) { "Module has been destroyed" } + return executeNative(methodName, *inputs) + } finally { + mLock.unlock() + } + } + + @DoNotStrip + private external fun executeNative(methodName: String, vararg inputs: EValue): Array + + /** + * Load a method on this module. This might help with the first time inference performance, + * because otherwise the method is loaded lazily when it's execute. Note: this function is + * synchronous, and will block until the method is loaded. Therefore, it is recommended to call + * this on a background thread. However, users need to make sure that they don't execute before + * this function returns. + */ + open fun loadMethod(methodName: String) { + mLock.lock() + try { + check(mHybridData.isValid) { "Module has been destroyed" } + val errorCode = loadMethodNative(methodName) + if (errorCode != 0) { + throw ExecutorchRuntimeException(errorCode, "Failed to load method: $methodName") + } + } finally { + mLock.unlock() + } + } + + @DoNotStrip private external fun loadMethodNative(methodName: String): Int + + /** + * Returns the names of the backends in a certain method. + * + * @param methodName method name to query + * @return an array of backend name + */ + @DoNotStrip private external fun getUsedBackends(methodName: String): Array + + /** + * Returns the names of methods. + * + * @return name of methods in this Module + */ + open fun getMethods(): Array { + mLock.lock() + try { + check(mHybridData.isValid) { "Module has been destroyed" } + return getMethodsNative() + } finally { + mLock.unlock() + } + } + + @DoNotStrip private external fun getMethodsNative(): Array + + /** + * Get the corresponding [MethodMetadata] for a method + * + * @param name method name + * @return [MethodMetadata] for this method + */ + open fun getMethodMetadata(name: String): MethodMetadata { + mLock.lock() + try { + check(mHybridData.isValid) { "Module has been destroyed" } + return mMethodMetadata[name] + ?: throw IllegalArgumentException("method $name does not exist for this module") + } finally { + mLock.unlock() + } + } + + /** Retrieve the in-memory log buffer, containing the most recent ExecuTorch log entries. */ + open fun readLogBuffer(): Array? { + mLock.lock() + try { + check(mHybridData.isValid) { "Module has been destroyed" } + return readLogBufferNative() + } finally { + mLock.unlock() + } + } + + @DoNotStrip private external fun readLogBufferNative(): Array? + + /** + * Dump the ExecuTorch ETRecord file to /data/local/tmp/result.etdump. + * + * Currently for internal (minibench) use only. + * + * @return true if the etdump was successfully written, false otherwise. + */ + @Experimental + open fun etdump(): Boolean { + mLock.lock() + try { + check(mHybridData.isValid) { "Module has been destroyed" } + return etdumpNative() + } finally { + mLock.unlock() + } + } + + @DoNotStrip private external fun etdumpNative(): Boolean + + /** + * Dump the ExecuTorch ETDump file to [outputPath]. + * + * @param outputPath absolute path to write the etdump file to. + * @return true if the etdump was successfully written, false otherwise. + */ + @Experimental + open fun etdump(outputPath: String): Boolean { + mLock.lock() + try { + check(mHybridData.isValid) { "Module has been destroyed" } + return etdumpToNative(outputPath) + } finally { + mLock.unlock() + } + } + + @DoNotStrip private external fun etdumpToNative(outputPath: String): Boolean + + /** + * Explicitly destroys the native Module object. Calling this method is not required, as the + * native object will be destroyed when this object is garbage-collected. However, the timing of + * garbage collection is not guaranteed, so proactively calling `destroy` can free memory more + * quickly. See [com.facebook.jni.HybridData.resetNative]. + */ + open fun destroy() { + if (mLock.tryLock()) { + try { + if (mHybridData.isValid) { + mHybridData.resetNative() + } + } finally { + mLock.unlock() + } + } else { + throw IllegalStateException("Cannot destroy module while method is executing") + } + } + + override fun close() { + destroy() + } + + companion object { + init { + if (!NativeLoader.isInitialized()) { + NativeLoader.init(SystemDelegate()) + } + NativeLoader.loadLibrary("executorch") + } + + /** Load mode for the module. Load the whole file as a buffer. */ + const val LOAD_MODE_FILE = 0 + + /** Load mode for the module. Use mmap to load pages into memory. */ + const val LOAD_MODE_MMAP = 1 + + /** Load mode for the module. Use memory locking and handle errors. */ + const val LOAD_MODE_MMAP_USE_MLOCK = 2 + + /** Load mode for the module. Use memory locking and ignore errors. */ + const val LOAD_MODE_MMAP_USE_MLOCK_IGNORE_ERRORS = 3 + + /** + * Loads a serialized ExecuTorch module from the specified path on the disk. + * + * @param modelPath path to file that contains the serialized ExecuTorch module. + * @param loadMode load mode for the module. See constants in [Module]. + * @param numThreads the number of threads to use for inference. A value of 0 defaults to a + * hardware-specific default. + * @return new [Module] object which owns the model module. + */ + @JvmStatic + @JvmOverloads + fun load(modelPath: String?, loadMode: Int = LOAD_MODE_FILE, numThreads: Int = 0): Module { + ExecuTorchRuntime.validateFilePath(modelPath, "model path") + return Module(modelPath!!, loadMode, numThreads) + } + + @DoNotStrip + @JvmStatic + private external fun initHybrid( + moduleAbsolutePath: String, + loadMode: Int, + numThreads: Int, + ): HybridData + + @DoNotStrip @JvmStatic fun readLogBufferStatic(): Array? = readLogBufferStaticNative() + + @DoNotStrip @JvmStatic private external fun readLogBufferStaticNative(): Array? + } +} diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/annotations/Experimental.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/annotations/Experimental.kt similarity index 68% rename from extension/android/executorch_android/src/main/java/org/pytorch/executorch/annotations/Experimental.java rename to extension/android/executorch_android/src/main/java/org/pytorch/executorch/annotations/Experimental.kt index f5f36fc56da..42a5980d6ba 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/annotations/Experimental.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/annotations/Experimental.kt @@ -6,13 +6,13 @@ * LICENSE file in the root directory of this source tree. */ -package org.pytorch.executorch.annotations; +package org.pytorch.executorch.annotations /** * This annotation indicates that an API is experimental and may change or be removed at any time. * It does not provide any guarantees for API stability or backward-compatibility. * - *

This status is not permanent, and APIs marked with this annotation will need to be either made + * This status is not permanent, and APIs marked with this annotation will need to be either made * more robust or removed in the future. */ -public @interface Experimental {} +@Retention(AnnotationRetention.BINARY) annotation class Experimental diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/annotations/package-info.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/annotations/package-info.java deleted file mode 100644 index 2173a04c69d..00000000000 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/annotations/package-info.java +++ /dev/null @@ -1,2 +0,0 @@ -/** Annotations used by ExecuTorch Android Java/JNI package. */ -package org.pytorch.executorch.annotations; diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmCallback.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmCallback.kt similarity index 53% rename from extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmCallback.java rename to extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmCallback.kt index 4e834d06721..3b56986bf14 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmCallback.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmCallback.kt @@ -6,45 +6,42 @@ * LICENSE file in the root directory of this source tree. */ -package org.pytorch.executorch.extension.llm; +package org.pytorch.executorch.extension.llm -import com.facebook.jni.annotations.DoNotStrip; -import org.pytorch.executorch.annotations.Experimental; +import com.facebook.jni.annotations.DoNotStrip +import org.pytorch.executorch.annotations.Experimental /** - * Callback interface for Llama model. Users can implement this interface to receive the generated + * Callback interface for Llm model. Users can implement this interface to receive the generated * tokens and statistics. * - *

Warning: These APIs are experimental and subject to change without notice + * Warning: These APIs are experimental and subject to change without notice */ @Experimental -public interface LlmCallback { +interface LlmCallback { /** * Called when a new result is available from JNI. Users will keep getting onResult() invocations * until generate() finishes. * * @param result Last generated token */ - @DoNotStrip - public void onResult(String result); + @DoNotStrip fun onResult(result: String) /** * Called when the statistics for the generate() is available. * - *

The result will be a JSON string. See extension/llm/stats.h for the field definitions. + * The result will be a JSON string. See extension/llm/stats.h for the field definitions. * * @param stats JSON string containing the statistics for the generate() */ - @DoNotStrip - default void onStats(String stats) {} + @DoNotStrip fun onStats(stats: String) {} /** * Called when an error occurs during generate(). * - * @param errorCode Error code from the ExecuTorch runtime (see {@link - * org.pytorch.executorch.ExecutorchRuntimeException}) + * @param errorCode Error code from the ExecuTorch runtime (see + * [org.pytorch.executorch.ExecutorchRuntimeException]) * @param message Human-readable error description */ - @DoNotStrip - default void onError(int errorCode, String message) {} + @DoNotStrip fun onError(errorCode: Int, message: String) {} } diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmGenerationConfig.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmGenerationConfig.java deleted file mode 100644 index db7941aadad..00000000000 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmGenerationConfig.java +++ /dev/null @@ -1,198 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -package org.pytorch.executorch.extension.llm; - -/** - * Configuration class for controlling text generation parameters in LLM operations. - * - *

This class provides settings for text generation behavior including output formatting, - * generation limits, and sampling parameters. Instances should be created using the {@link - * #create()} method and the fluent builder pattern. - */ -public class LlmGenerationConfig { - private final boolean echo; - private final int maxNewTokens; - private final boolean warming; - private final int seqLen; - private final float temperature; - private final int numBos; - private final int numEos; - - private LlmGenerationConfig(Builder builder) { - this.echo = builder.echo; - this.maxNewTokens = builder.maxNewTokens; - this.warming = builder.warming; - this.seqLen = builder.seqLen; - this.temperature = builder.temperature; - this.numBos = builder.numBos; - this.numEos = builder.numEos; - } - - /** - * Creates a new Builder instance for constructing generation configurations. - * - * @return a new Builder with default configuration values - */ - public static Builder create() { - return new Builder(); - } - - /** - * @return true if input prompt should be included in the output - */ - public boolean isEcho() { - return echo; - } - - /** - * @return maximum number of tokens to generate (-1 for unlimited) - */ - public int getMaxNewTokens() { - return maxNewTokens; - } - - /** - * @return true if model warming is enabled - */ - public boolean isWarming() { - return warming; - } - - /** - * @return maximum sequence length for generation (-1 for default) - */ - public int getSeqLen() { - return seqLen; - } - - /** - * @return temperature value for sampling (higher = more random) - */ - public float getTemperature() { - return temperature; - } - - /** - * @return number of BOS tokens to prepend - */ - public int getNumBos() { - return numBos; - } - - /** - * @return number of EOS tokens to append - */ - public int getNumEos() { - return numEos; - } - - /** - * Builder class for constructing LlmGenerationConfig instances. - * - *

Provides a fluent interface for configuring generation parameters with sensible defaults. - * All methods return the builder instance to enable method chaining. - */ - public static class Builder { - private boolean echo = true; - private int maxNewTokens = -1; - private boolean warming = false; - private int seqLen = -1; - private float temperature = 0.8f; - private int numBos = 0; - private int numEos = 0; - - Builder() {} - - /** - * Sets whether to include the input prompt in the generated output. - * - * @param echo true to include input prompt, false to return only new tokens - * @return this builder instance - */ - public Builder echo(boolean echo) { - this.echo = echo; - return this; - } - - /** - * Sets the maximum number of new tokens to generate. - * - * @param maxNewTokens the token limit (-1 for unlimited generation) - * @return this builder instance - */ - public Builder maxNewTokens(int maxNewTokens) { - this.maxNewTokens = maxNewTokens; - return this; - } - - /** - * Enables or disables model warming. - * - * @param warming true to generate initial tokens for model warmup - * @return this builder instance - */ - public Builder warming(boolean warming) { - this.warming = warming; - return this; - } - - /** - * Sets the maximum sequence length for generation. - * - * @param seqLen maximum sequence length (-1 for default behavior) - * @return this builder instance - */ - public Builder seqLen(int seqLen) { - this.seqLen = seqLen; - return this; - } - - /** - * Sets the temperature for random sampling. - * - * @param temperature sampling temperature (typical range 0.0-1.0) - * @return this builder instance - */ - public Builder temperature(float temperature) { - this.temperature = temperature; - return this; - } - - /** - * Sets the number of BOS tokens to prepend. - * - * @param numBos number of BOS tokens - * @return this builder instance - */ - public Builder numBos(int numBos) { - this.numBos = numBos; - return this; - } - - /** - * Sets the number of EOS tokens to append. - * - * @param numEos number of EOS tokens - * @return this builder instance - */ - public Builder numEos(int numEos) { - this.numEos = numEos; - return this; - } - - /** - * Constructs the LlmGenerationConfig instance with the configured parameters. - * - * @return new LlmGenerationConfig instance with current builder settings - */ - public LlmGenerationConfig build() { - return new LlmGenerationConfig(this); - } - } -} diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmGenerationConfig.kt b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmGenerationConfig.kt new file mode 100644 index 00000000000..c0f8956fb7f --- /dev/null +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmGenerationConfig.kt @@ -0,0 +1,78 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.executorch.extension.llm + +/** + * Configuration class for controlling text generation parameters in LLM operations. + * + * This class provides settings for text generation behavior including output formatting, generation + * limits, and sampling parameters. Instances should be created using the [create] method and the + * fluent builder pattern. + */ +class LlmGenerationConfig +private constructor( + @get:JvmName("isEcho") val echo: Boolean, + val maxNewTokens: Int, + @get:JvmName("isWarming") val warming: Boolean, + val seqLen: Int, + val temperature: Float, + val numBos: Int, + val numEos: Int, +) { + + companion object { + /** + * Creates a new Builder instance for constructing generation configurations. + * + * @return a new Builder with default configuration values + */ + @JvmStatic fun create(): Builder = Builder() + } + + /** + * Builder class for constructing LlmGenerationConfig instances. + * + * Provides a fluent interface for configuring generation parameters with sensible defaults. All + * methods return the builder instance to enable method chaining. + */ + class Builder internal constructor() { + private var echo: Boolean = true + private var maxNewTokens: Int = -1 + private var warming: Boolean = false + private var seqLen: Int = -1 + private var temperature: Float = 0.8f + private var numBos: Int = 0 + private var numEos: Int = 0 + + /** Sets whether to include the input prompt in the generated output. */ + fun echo(echo: Boolean): Builder = apply { this.echo = echo } + + /** Sets the maximum number of new tokens to generate. */ + fun maxNewTokens(maxNewTokens: Int): Builder = apply { this.maxNewTokens = maxNewTokens } + + /** Enables or disables model warming. */ + fun warming(warming: Boolean): Builder = apply { this.warming = warming } + + /** Sets the maximum sequence length for generation. */ + fun seqLen(seqLen: Int): Builder = apply { this.seqLen = seqLen } + + /** Sets the temperature for random sampling. */ + fun temperature(temperature: Float): Builder = apply { this.temperature = temperature } + + /** Sets the number of BOS tokens to prepend. */ + fun numBos(numBos: Int): Builder = apply { this.numBos = numBos } + + /** Sets the number of EOS tokens to append. */ + fun numEos(numEos: Int): Builder = apply { this.numEos = numEos } + + /** Constructs the LlmGenerationConfig instance with the configured parameters. */ + fun build(): LlmGenerationConfig = + LlmGenerationConfig(echo, maxNewTokens, warming, seqLen, temperature, numBos, numEos) + } +} diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java deleted file mode 100644 index 0c467b13f44..00000000000 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java +++ /dev/null @@ -1,823 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -package org.pytorch.executorch.extension.llm; - -import com.facebook.jni.HybridData; -import com.facebook.jni.annotations.DoNotStrip; -import java.io.Closeable; -import java.nio.ByteBuffer; -import java.util.List; -import java.util.concurrent.locks.ReentrantLock; -import org.pytorch.executorch.ExecuTorchRuntime; -import org.pytorch.executorch.ExecutorchRuntimeException; -import org.pytorch.executorch.annotations.Experimental; - -/** - * LlmModule is a wrapper around the Executorch LLM. It provides a simple interface to generate text - * from the model. - * - *

Warning: These APIs are experimental and subject to change without notice - */ -@Experimental -public class LlmModule implements Closeable { - - public static final int MODEL_TYPE_TEXT = 1; - public static final int MODEL_TYPE_TEXT_VISION = 2; - public static final int MODEL_TYPE_MULTIMODAL = 2; - - private final HybridData mHybridData; - private final ReentrantLock mLock = new ReentrantLock(); - private volatile boolean mDestroyed = false; - private static final int DEFAULT_SEQ_LEN = 128; - private static final boolean DEFAULT_ECHO = true; - private static final float DEFAULT_TEMPERATURE = -1.0f; - private static final int DEFAULT_BOS = 0; - private static final int DEFAULT_EOS = 0; - private static final int DEFAULT_LOAD_MODE = LlmModuleConfig.LOAD_MODE_MMAP; - - @DoNotStrip - private static native HybridData initHybrid( - int modelType, - String modulePath, - String tokenizerPath, - float temperature, - List dataFiles, - int numBos, - int numEos, - int loadMode); - - private LlmModule( - int modelType, - String modulePath, - String tokenizerPath, - float temperature, - List dataFiles, - int numBos, - int numEos, - int loadMode) { - ExecuTorchRuntime.getRuntime(); - ExecuTorchRuntime.validateFilePath(modulePath, "model path"); - ExecuTorchRuntime.validateFilePath(tokenizerPath, "tokenizer path"); - - mHybridData = - initHybrid( - modelType, modulePath, tokenizerPath, temperature, dataFiles, numBos, numEos, loadMode); - } - - /** - * Constructs a LLM Module for a model with given type, model path, tokenizer, temperature, and - * dataFiles. - */ - public LlmModule( - int modelType, - String modulePath, - String tokenizerPath, - float temperature, - List dataFiles, - int numBos, - int numEos) { - this( - modelType, - modulePath, - tokenizerPath, - temperature, - dataFiles, - numBos, - numEos, - DEFAULT_LOAD_MODE); - } - - /** - * Constructs a LLM Module for a model with given type, model path, tokenizer, temperature, and - * dataFiles. - */ - public LlmModule( - int modelType, - String modulePath, - String tokenizerPath, - float temperature, - List dataFiles) { - this( - modelType, - modulePath, - tokenizerPath, - temperature, - dataFiles, - DEFAULT_BOS, - DEFAULT_EOS, - DEFAULT_LOAD_MODE); - } - - /** - * Constructs a LLM Module for a model with given type, model path, tokenizer, temperature, and - * data path. - */ - public LlmModule( - int modelType, - String modulePath, - String tokenizerPath, - float temperature, - String dataPath, - int numBos, - int numEos) { - this( - modelType, - modulePath, - tokenizerPath, - temperature, - dataPath != null ? List.of(dataPath) : List.of(), - numBos, - numEos); - } - - /** - * Constructs a LLM Module for a model with given type, model path, tokenizer, temperature, and - * data path. - */ - public LlmModule( - int modelType, String modulePath, String tokenizerPath, float temperature, String dataPath) { - this(modelType, modulePath, tokenizerPath, temperature, dataPath, DEFAULT_BOS, DEFAULT_EOS); - } - - /** Constructs a LLM Module for a model with given model path, tokenizer, temperature. */ - public LlmModule(String modulePath, String tokenizerPath, float temperature) { - this( - MODEL_TYPE_TEXT, - modulePath, - tokenizerPath, - temperature, - List.of(), - DEFAULT_BOS, - DEFAULT_EOS); - } - - /** - * Constructs a LLM Module for a model with given model path, tokenizer, temperature and data - * path. - */ - public LlmModule(String modulePath, String tokenizerPath, float temperature, String dataPath) { - this( - MODEL_TYPE_TEXT, - modulePath, - tokenizerPath, - temperature, - List.of(dataPath), - DEFAULT_BOS, - DEFAULT_EOS); - } - - /** Constructs a LLM Module for a model with given path, tokenizer, and temperature. */ - public LlmModule(int modelType, String modulePath, String tokenizerPath, float temperature) { - this(modelType, modulePath, tokenizerPath, temperature, List.of(), DEFAULT_BOS, DEFAULT_EOS); - } - - /** Constructs a LLM Module for a model with the given LlmModuleConfig */ - public LlmModule(LlmModuleConfig config) { - this( - config.getModelType(), - config.getModulePath(), - config.getTokenizerPath(), - config.getTemperature(), - config.getDataPath() != null ? List.of(config.getDataPath()) : List.of(), - config.getNumBos(), - config.getNumEos(), - config.getLoadMode()); - } - - private void checkNotDestroyed() { - if (mDestroyed) throw new IllegalStateException("LlmModule has been destroyed"); - } - - private void checkNotReentrant() { - if (mLock.getHoldCount() > 1) { - throw new IllegalStateException("Cannot call LlmModule methods from within a callback"); - } - } - - /** - * Releases native resources. Callers must ensure no other methods are in-flight. Call {@link - * #stop()} and wait for {@link #generate(String, LlmCallback)} to return before calling this - * method. - */ - @Override - public void close() { - if (mLock.tryLock()) { - try { - if (mLock.getHoldCount() > 1) { - throw new IllegalStateException( - "Cannot close module from within a callback during execution"); - } - if (!mDestroyed) { - mDestroyed = true; - mHybridData.resetNative(); - } - } finally { - mLock.unlock(); - } - } else { - throw new IllegalStateException("Cannot close module while method is executing"); - } - } - - /** - * @deprecated Use {@link #close()} instead. - */ - @Deprecated - public void resetNative() { - close(); - } - - /** - * Start generating tokens from the module. - * - * @param prompt Input prompt - * @param llmCallback callback object to receive results. - */ - public void generate(String prompt, LlmCallback llmCallback) { - generate( - prompt, - DEFAULT_SEQ_LEN, - llmCallback, - DEFAULT_ECHO, - DEFAULT_TEMPERATURE, - DEFAULT_BOS, - DEFAULT_EOS); - } - - /** - * Start generating tokens from the module. - * - * @param prompt Input prompt - * @param seqLen sequence length - * @param llmCallback callback object to receive results. - */ - public void generate(String prompt, int seqLen, LlmCallback llmCallback) { - generate( - null, - 0, - 0, - 0, - prompt, - seqLen, - llmCallback, - DEFAULT_ECHO, - DEFAULT_TEMPERATURE, - DEFAULT_BOS, - DEFAULT_EOS); - } - - /** - * Start generating tokens from the module. - * - * @param prompt Input prompt - * @param llmCallback callback object to receive results - * @param echo indicate whether to echo the input prompt or not (text completion vs chat) - */ - public void generate(String prompt, LlmCallback llmCallback, boolean echo) { - generate( - null, - 0, - 0, - 0, - prompt, - DEFAULT_SEQ_LEN, - llmCallback, - echo, - DEFAULT_TEMPERATURE, - DEFAULT_BOS, - DEFAULT_EOS); - } - - /** - * Start generating tokens from the module. - * - * @param prompt Input prompt - * @param seqLen sequence length - * @param llmCallback callback object to receive results - * @param echo indicate whether to echo the input prompt or not (text completion vs chat) - */ - public void generate(String prompt, int seqLen, LlmCallback llmCallback, boolean echo) { - generate(prompt, seqLen, llmCallback, echo, DEFAULT_TEMPERATURE, DEFAULT_BOS, DEFAULT_EOS); - } - - /** - * Start generating tokens from the module. - * - * @param prompt Input prompt - * @param seqLen sequence length - * @param llmCallback callback object to receive results - * @param echo indicate whether to echo the input prompt or not (text completion vs chat) - * @param temperature temperature for sampling (use negative value to use module default) - * @param numBos number of BOS tokens to prepend - * @param numEos number of EOS tokens to append - */ - public void generate( - String prompt, - int seqLen, - LlmCallback llmCallback, - boolean echo, - float temperature, - int numBos, - int numEos) { - mLock.lock(); - try { - checkNotReentrant(); - checkNotDestroyed(); - int err = generateNative(prompt, seqLen, llmCallback, echo, temperature, numBos, numEos); - if (err != 0) { - throw ExecutorchRuntimeException.makeExecutorchException(err, "Failed to generate"); - } - } finally { - mLock.unlock(); - } - } - - @DoNotStrip - private native int generateNative( - String prompt, - int seqLen, - LlmCallback llmCallback, - boolean echo, - float temperature, - int numBos, - int numEos); - - /** - * Start generating tokens from the module. - * - * @param prompt Input prompt - * @param config the config for generation - * @param llmCallback callback object to receive results - */ - public void generate(String prompt, LlmGenerationConfig config, LlmCallback llmCallback) { - int seqLen = config.getSeqLen(); - boolean echo = config.isEcho(); - float temperature = config.getTemperature(); - int numBos = config.getNumBos(); - int numEos = config.getNumEos(); - generate(null, 0, 0, 0, prompt, seqLen, llmCallback, echo, temperature, numBos, numEos); - } - - /** - * Start generating tokens from the module. - * - * @param image Input image as a byte array - * @param width Input image width - * @param height Input image height - * @param channels Input image number of channels - * @param prompt Input prompt - * @param seqLen sequence length - * @param llmCallback callback object to receive results. - * @param echo indicate whether to echo the input prompt or not (text completion vs chat) - */ - public void generate( - int[] image, - int width, - int height, - int channels, - String prompt, - int seqLen, - LlmCallback llmCallback, - boolean echo) { - generate( - image, - width, - height, - channels, - prompt, - seqLen, - llmCallback, - echo, - DEFAULT_TEMPERATURE, - DEFAULT_BOS, - DEFAULT_EOS); - } - - /** - * Start generating tokens from the module. - * - * @param image Input image as a byte array - * @param width Input image width - * @param height Input image height - * @param channels Input image number of channels - * @param prompt Input prompt - * @param seqLen sequence length - * @param llmCallback callback object to receive results. - * @param echo indicate whether to echo the input prompt or not (text completion vs chat) - * @param temperature temperature for sampling (use negative value to use module default) - */ - public void generate( - int[] image, - int width, - int height, - int channels, - String prompt, - int seqLen, - LlmCallback llmCallback, - boolean echo, - float temperature) { - generate( - image, - width, - height, - channels, - prompt, - seqLen, - llmCallback, - echo, - temperature, - DEFAULT_BOS, - DEFAULT_EOS); - } - - /** - * Start generating tokens from the module. - * - * @param image Input image as a byte array - * @param width Input image width - * @param height Input image height - * @param channels Input image number of channels - * @param prompt Input prompt - * @param seqLen sequence length - * @param llmCallback callback object to receive results. - * @param echo indicate whether to echo the input prompt or not (text completion vs chat) - * @param temperature temperature for sampling (use negative value to use module default) - * @param numBos number of BOS tokens to prepend - * @param numEos number of EOS tokens to append - */ - public void generate( - int[] image, - int width, - int height, - int channels, - String prompt, - int seqLen, - LlmCallback llmCallback, - boolean echo, - float temperature, - int numBos, - int numEos) { - mLock.lock(); - try { - checkNotReentrant(); - checkNotDestroyed(); - if (image != null) { - int nativeResult = prefillImagesInput(image, width, height, channels); - if (nativeResult != 0) { - throw ExecutorchRuntimeException.makeExecutorchException(nativeResult, "Prefill failed"); - } - } - int err = generateNative(prompt, seqLen, llmCallback, echo, temperature, numBos, numEos); - if (err != 0) { - throw ExecutorchRuntimeException.makeExecutorchException(err, "Failed to generate"); - } - } finally { - mLock.unlock(); - } - } - - /** - * Prefill the KV cache with the given image input. - * - * @param image Input image as a byte array - * @param width Input image width - * @param height Input image height - * @param channels Input image number of channels - * @throws ExecutorchRuntimeException if the prefill failed - */ - @Experimental - public void prefillImages(int[] image, int width, int height, int channels) { - mLock.lock(); - try { - checkNotReentrant(); - checkNotDestroyed(); - int nativeResult = prefillImagesInput(image, width, height, channels); - if (nativeResult != 0) { - throw ExecutorchRuntimeException.makeExecutorchException(nativeResult, "Prefill failed"); - } - } finally { - mLock.unlock(); - } - } - - /** - * Prefill a multimodal Module with the given image input via a direct ByteBuffer. The buffer data - * is accessed directly without JNI array copies, unlike {@link #prefillImages(int[], int, int, - * int)}. The ByteBuffer must contain raw uint8 pixel data in CHW format with at least channels * - * height * width bytes remaining. Only the first channels * height * width bytes from the - * buffer's current position are read; the position of the original ByteBuffer is not modified. - * - * @param image Input image as a direct ByteBuffer containing uint8 pixel data - * @param width Input image width - * @param height Input image height - * @param channels Input image number of channels - * @throws IllegalArgumentException if the ByteBuffer is not direct or has insufficient remaining - * bytes - * @throws ExecutorchRuntimeException if the prefill failed - */ - @Experimental - public void prefillImages(ByteBuffer image, int width, int height, int channels) { - mLock.lock(); - try { - checkNotReentrant(); - checkNotDestroyed(); - if (!image.isDirect()) { - throw new IllegalArgumentException("Input ByteBuffer must be direct."); - } - long expectedBytes; - try { - long pixels = Math.multiplyExact((long) width, (long) height); - expectedBytes = Math.multiplyExact(pixels, (long) channels); - } catch (ArithmeticException ex) { - throw new IllegalArgumentException( - "width*height*channels is too large and overflows the allowed range.", ex); - } - if (width <= 0 - || height <= 0 - || channels <= 0 - || expectedBytes > Integer.MAX_VALUE - || image.remaining() < expectedBytes) { - throw new IllegalArgumentException( - "ByteBuffer remaining (" - + image.remaining() - + ") must be at least width*height*channels (" - + expectedBytes - + ")."); - } - // slice() so that getDirectBufferAddress on the native side returns a pointer - // starting at the current position, not the base address. - int nativeResult = prefillImagesInputBuffer(image.slice(), width, height, channels); - if (nativeResult != 0) { - throw ExecutorchRuntimeException.makeExecutorchException(nativeResult, "Prefill failed"); - } - } finally { - mLock.unlock(); - } - } - - /** - * Prefill a multimodal Module with the given normalized image input via a direct ByteBuffer. The - * buffer data is accessed directly without JNI array copies, unlike {@link - * #prefillImages(float[], int, int, int)}. The ByteBuffer must contain normalized float pixel - * data in CHW format with at least channels * height * width * 4 bytes remaining. Only the first - * channels * height * width floats from the buffer's current position are consumed. The buffer - * must use the platform's native byte order (set via {@code - * buffer.order(ByteOrder.nativeOrder())}). - * - * @param image Input normalized image as a direct ByteBuffer containing float pixel data in - * native byte order - * @param width Input image width - * @param height Input image height - * @param channels Input image number of channels - * @throws IllegalArgumentException if the ByteBuffer is not direct, has insufficient remaining - * bytes, is not float-aligned, or does not use native byte order - * @throws ExecutorchRuntimeException if the prefill failed - */ - @Experimental - public void prefillNormalizedImage(ByteBuffer image, int width, int height, int channels) { - mLock.lock(); - try { - checkNotReentrant(); - checkNotDestroyed(); - if (!image.isDirect()) { - throw new IllegalArgumentException("Input ByteBuffer must be direct."); - } - if (image.order() != java.nio.ByteOrder.nativeOrder()) { - throw new IllegalArgumentException( - "Input ByteBuffer must use native byte order (ByteOrder.nativeOrder())."); - } - if (image.position() % Float.BYTES != 0) { - throw new IllegalArgumentException( - "Input ByteBuffer position (" + image.position() + ") must be 4-byte aligned."); - } - final long expectedBytes; - try { - int wh = Math.multiplyExact(width, height); - long whc = Math.multiplyExact((long) wh, (long) channels); - long totalBytes = Math.multiplyExact(whc, (long) Float.BYTES); - if (totalBytes > Integer.MAX_VALUE) { - throw new IllegalArgumentException( - "ByteBuffer size (width*height*channels*4) exceeds Integer.MAX_VALUE bytes: " - + totalBytes); - } - expectedBytes = totalBytes; - } catch (ArithmeticException e) { - throw new IllegalArgumentException( - "Overflow while computing width*height*channels*4 for ByteBuffer size.", e); - } - if (width <= 0 || height <= 0 || channels <= 0 || image.remaining() < expectedBytes) { - throw new IllegalArgumentException( - "ByteBuffer remaining (" - + image.remaining() - + ") must be at least width*height*channels*4 (" - + expectedBytes - + ")."); - } - if (image.remaining() % Float.BYTES != 0) { - throw new IllegalArgumentException( - "ByteBuffer remaining (" - + image.remaining() - + ") must be a multiple of 4 (float size)."); - } - // slice() so that getDirectBufferAddress on the native side returns a pointer - // starting at the current position, not the base address. - int nativeResult = prefillNormalizedImagesInputBuffer(image.slice(), width, height, channels); - if (nativeResult != 0) { - throw ExecutorchRuntimeException.makeExecutorchException(nativeResult, "Prefill failed"); - } - } finally { - mLock.unlock(); - } - } - - private native int prefillImagesInput(int[] image, int width, int height, int channels); - - private native int prefillImagesInputBuffer( - ByteBuffer image, int width, int height, int channels); - - private native int prefillNormalizedImagesInputBuffer( - ByteBuffer image, int width, int height, int channels); - - /** - * Prefill the KV cache with the given normalized image input. - * - * @param image Input normalized image as a float array - * @param width Input image width - * @param height Input image height - * @param channels Input image number of channels - * @throws ExecutorchRuntimeException if the prefill failed - */ - @Experimental - public void prefillImages(float[] image, int width, int height, int channels) { - mLock.lock(); - try { - checkNotReentrant(); - checkNotDestroyed(); - int nativeResult = prefillNormalizedImagesInput(image, width, height, channels); - if (nativeResult != 0) { - throw ExecutorchRuntimeException.makeExecutorchException(nativeResult, "Prefill failed"); - } - } finally { - mLock.unlock(); - } - } - - private native int prefillNormalizedImagesInput( - float[] image, int width, int height, int channels); - - /** - * Prefill the KV cache with the given preprocessed audio input. - * - * @param audio Input preprocessed audio as a byte array - * @param batch_size Input batch size - * @param n_bins Input number of bins - * @param n_frames Input number of frames - * @throws ExecutorchRuntimeException if the prefill failed - */ - @Experimental - public void prefillAudio(byte[] audio, int batch_size, int n_bins, int n_frames) { - mLock.lock(); - try { - checkNotReentrant(); - checkNotDestroyed(); - int nativeResult = prefillAudioInput(audio, batch_size, n_bins, n_frames); - if (nativeResult != 0) { - throw ExecutorchRuntimeException.makeExecutorchException(nativeResult, "Prefill failed"); - } - } finally { - mLock.unlock(); - } - } - - private native int prefillAudioInput(byte[] audio, int batch_size, int n_bins, int n_frames); - - /** - * Prefill the KV cache with the given preprocessed audio input. - * - * @param audio Input preprocessed audio as a float array - * @param batch_size Input batch size - * @param n_bins Input number of bins - * @param n_frames Input number of frames - * @throws ExecutorchRuntimeException if the prefill failed - */ - @Experimental - public void prefillAudio(float[] audio, int batch_size, int n_bins, int n_frames) { - mLock.lock(); - try { - checkNotReentrant(); - checkNotDestroyed(); - int nativeResult = prefillAudioInputFloat(audio, batch_size, n_bins, n_frames); - if (nativeResult != 0) { - throw ExecutorchRuntimeException.makeExecutorchException(nativeResult, "Prefill failed"); - } - } finally { - mLock.unlock(); - } - } - - private native int prefillAudioInputFloat( - float[] audio, int batch_size, int n_bins, int n_frames); - - /** - * Prefill the KV cache with the given raw audio input. - * - * @param audio Input raw audio as a byte array - * @param batch_size Input batch size - * @param n_channels Input number of channels - * @param n_samples Input number of samples - * @throws ExecutorchRuntimeException if the prefill failed - */ - @Experimental - public void prefillRawAudio(byte[] audio, int batch_size, int n_channels, int n_samples) { - mLock.lock(); - try { - checkNotReentrant(); - checkNotDestroyed(); - int nativeResult = prefillRawAudioInput(audio, batch_size, n_channels, n_samples); - if (nativeResult != 0) { - throw ExecutorchRuntimeException.makeExecutorchException(nativeResult, "Prefill failed"); - } - } finally { - mLock.unlock(); - } - } - - private native int prefillRawAudioInput( - byte[] audio, int batch_size, int n_channels, int n_samples); - - /** - * Prefill the KV cache with the given text prompt. - * - * @param prompt The text prompt to prefill. - * @throws ExecutorchRuntimeException if the prefill failed - */ - @Experimental - public void prefillPrompt(String prompt) { - mLock.lock(); - try { - checkNotReentrant(); - checkNotDestroyed(); - int nativeResult = prefillTextInput(prompt); - if (nativeResult != 0) { - throw ExecutorchRuntimeException.makeExecutorchException(nativeResult, "Prefill failed"); - } - } finally { - mLock.unlock(); - } - } - - // returns status - private native int prefillTextInput(String prompt); - - /** - * Reset the context of the LLM. This will clear the KV cache and reset the state of the LLM. - * - *

The startPos will be reset to 0. - */ - public void resetContext() { - mLock.lock(); - try { - checkNotReentrant(); - checkNotDestroyed(); - resetContextNative(); - } finally { - mLock.unlock(); - } - } - - @DoNotStrip - private native void resetContextNative(); - - /** Stop current generate() before it finishes. */ - public void stop() { - if (mDestroyed) return; - stopNative(); - } - - @DoNotStrip - private native void stopNative(); - - /** Force loading the module. Otherwise the model is loaded during first generate(). */ - public void load() { - mLock.lock(); - try { - checkNotReentrant(); - checkNotDestroyed(); - int err = loadNative(); - if (err != 0) { - throw ExecutorchRuntimeException.makeExecutorchException(err, "Failed to load model"); - } - } finally { - mLock.unlock(); - } - } - - @DoNotStrip - private native int loadNative(); -} diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.kt b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.kt new file mode 100644 index 00000000000..f95e796b83b --- /dev/null +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.kt @@ -0,0 +1,898 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.executorch.extension.llm + +import com.facebook.jni.HybridData +import com.facebook.jni.annotations.DoNotStrip +import java.io.Closeable +import java.nio.ByteBuffer +import java.nio.ByteOrder +import java.util.concurrent.locks.ReentrantLock +import org.pytorch.executorch.ExecuTorchRuntime +import org.pytorch.executorch.ExecutorchRuntimeException +import org.pytorch.executorch.annotations.Experimental + +/** + * LlmModule is a wrapper around the Executorch LLM. It provides a simple interface to generate text + * from the model. + * + * Warning: These APIs are experimental and subject to change without notice + */ +@Experimental +class LlmModule +private constructor( + modelType: Int, + modulePath: String, + tokenizerPath: String, + temperature: Float, + dataFiles: List, + numBos: Int, + numEos: Int, + loadMode: Int, +) : Closeable { + + private val mHybridData: HybridData + private val mLock = ReentrantLock() + @Volatile private var mDestroyed = false + + init { + ExecuTorchRuntime.getRuntime() + ExecuTorchRuntime.validateFilePath(modulePath, "model path") + ExecuTorchRuntime.validateFilePath(tokenizerPath, "tokenizer path") + mHybridData = + initHybrid( + modelType, + modulePath, + tokenizerPath, + temperature, + dataFiles, + numBos, + numEos, + loadMode, + ) + } + + /** + * Constructs a LLM Module for a model with given type, model path, tokenizer, temperature, and + * dataFiles. + */ + constructor( + modelType: Int, + modulePath: String, + tokenizerPath: String, + temperature: Float, + dataFiles: List, + numBos: Int, + numEos: Int, + ) : this( + modelType, + modulePath, + tokenizerPath, + temperature, + dataFiles, + numBos, + numEos, + DEFAULT_LOAD_MODE, + ) + + /** + * Constructs a LLM Module for a model with given type, model path, tokenizer, temperature, and + * dataFiles. + */ + constructor( + modelType: Int, + modulePath: String, + tokenizerPath: String, + temperature: Float, + dataFiles: List, + ) : this( + modelType, + modulePath, + tokenizerPath, + temperature, + dataFiles, + DEFAULT_BOS, + DEFAULT_EOS, + DEFAULT_LOAD_MODE, + ) + + /** + * Constructs a LLM Module for a model with given type, model path, tokenizer, temperature, and + * data path. + */ + constructor( + modelType: Int, + modulePath: String, + tokenizerPath: String, + temperature: Float, + dataPath: String?, + numBos: Int, + numEos: Int, + ) : this( + modelType, + modulePath, + tokenizerPath, + temperature, + listOfNotNull(dataPath), + numBos, + numEos, + ) + + /** + * Constructs a LLM Module for a model with given type, model path, tokenizer, temperature, and + * data path. + */ + constructor( + modelType: Int, + modulePath: String, + tokenizerPath: String, + temperature: Float, + dataPath: String?, + ) : this( + modelType, + modulePath, + tokenizerPath, + temperature, + dataPath, + DEFAULT_BOS, + DEFAULT_EOS, + ) + + /** Constructs a LLM Module for a model with given model path, tokenizer, temperature. */ + constructor( + modulePath: String, + tokenizerPath: String, + temperature: Float, + ) : this( + MODEL_TYPE_TEXT, + modulePath, + tokenizerPath, + temperature, + emptyList(), + DEFAULT_BOS, + DEFAULT_EOS, + ) + + /** + * Constructs a LLM Module for a model with given model path, tokenizer, temperature and data + * path. + */ + constructor( + modulePath: String, + tokenizerPath: String, + temperature: Float, + dataPath: String, + ) : this( + MODEL_TYPE_TEXT, + modulePath, + tokenizerPath, + temperature, + listOf(dataPath), + DEFAULT_BOS, + DEFAULT_EOS, + ) + + /** Constructs a LLM Module for a model with given path, tokenizer, and temperature. */ + constructor( + modelType: Int, + modulePath: String, + tokenizerPath: String, + temperature: Float, + ) : this( + modelType, + modulePath, + tokenizerPath, + temperature, + emptyList(), + DEFAULT_BOS, + DEFAULT_EOS, + ) + + /** Constructs a LLM Module for a model with the given LlmModuleConfig */ + constructor( + config: LlmModuleConfig + ) : this( + config.modelType, + config.modulePath, + config.tokenizerPath, + config.temperature, + listOfNotNull(config.dataPath), + config.numBos, + config.numEos, + config.loadMode, + ) + + private fun checkNotDestroyed() { + if (mDestroyed) throw IllegalStateException("LlmModule has been destroyed") + } + + private fun checkNotReentrant() { + if (mLock.holdCount > 1) { + throw IllegalStateException("Cannot call LlmModule methods from within a callback") + } + } + + /** + * Releases native resources. Callers must ensure no other methods are in-flight. Call [stop] and + * wait for [generate] to return before calling this method. + */ + override fun close() { + if (mLock.tryLock()) { + try { + if (mLock.holdCount > 1) { + throw IllegalStateException("Cannot close module from within a callback during execution") + } + if (!mDestroyed) { + mDestroyed = true + mHybridData.resetNative() + } + } finally { + mLock.unlock() + } + } else { + throw IllegalStateException("Cannot close module while method is executing") + } + } + + /** @deprecated Use [close] instead. */ + @Deprecated("Use close() instead", replaceWith = ReplaceWith("close()")) + fun resetNative() { + close() + } + + // --- generate overloads --- + + /** + * Start generating tokens from the module. + * + * @param prompt Input prompt + * @param llmCallback callback object to receive results. + */ + fun generate(prompt: String, llmCallback: LlmCallback) { + generate( + prompt, + DEFAULT_SEQ_LEN, + llmCallback, + DEFAULT_ECHO, + DEFAULT_TEMPERATURE, + DEFAULT_BOS, + DEFAULT_EOS, + ) + } + + /** + * Start generating tokens from the module. + * + * @param prompt Input prompt + * @param seqLen sequence length + * @param llmCallback callback object to receive results. + */ + fun generate(prompt: String, seqLen: Int, llmCallback: LlmCallback) { + generate( + null, + 0, + 0, + 0, + prompt, + seqLen, + llmCallback, + DEFAULT_ECHO, + DEFAULT_TEMPERATURE, + DEFAULT_BOS, + DEFAULT_EOS, + ) + } + + /** + * Start generating tokens from the module. + * + * @param prompt Input prompt + * @param llmCallback callback object to receive results + * @param echo indicate whether to echo the input prompt or not (text completion vs chat) + */ + fun generate(prompt: String, llmCallback: LlmCallback, echo: Boolean) { + generate( + null, + 0, + 0, + 0, + prompt, + DEFAULT_SEQ_LEN, + llmCallback, + echo, + DEFAULT_TEMPERATURE, + DEFAULT_BOS, + DEFAULT_EOS, + ) + } + + /** + * Start generating tokens from the module. + * + * @param prompt Input prompt + * @param seqLen sequence length + * @param llmCallback callback object to receive results + * @param echo indicate whether to echo the input prompt or not (text completion vs chat) + */ + fun generate(prompt: String, seqLen: Int, llmCallback: LlmCallback, echo: Boolean) { + generate(prompt, seqLen, llmCallback, echo, DEFAULT_TEMPERATURE, DEFAULT_BOS, DEFAULT_EOS) + } + + /** + * Start generating tokens from the module. + * + * @param prompt Input prompt + * @param seqLen sequence length + * @param llmCallback callback object to receive results + * @param echo indicate whether to echo the input prompt or not (text completion vs chat) + * @param temperature temperature for sampling (use negative value to use module default) + * @param numBos number of BOS tokens to prepend + * @param numEos number of EOS tokens to append + */ + fun generate( + prompt: String, + seqLen: Int, + llmCallback: LlmCallback, + echo: Boolean, + temperature: Float, + numBos: Int, + numEos: Int, + ) { + mLock.lock() + try { + checkNotReentrant() + checkNotDestroyed() + val err = generateNative(prompt, seqLen, llmCallback, echo, temperature, numBos, numEos) + if (err != 0) { + throw ExecutorchRuntimeException.makeExecutorchException(err, "Failed to generate") + } + } finally { + mLock.unlock() + } + } + + @DoNotStrip + private external fun generateNative( + prompt: String, + seqLen: Int, + llmCallback: LlmCallback, + echo: Boolean, + temperature: Float, + numBos: Int, + numEos: Int, + ): Int + + /** + * Start generating tokens from the module. + * + * @param prompt Input prompt + * @param config the config for generation + * @param llmCallback callback object to receive results + */ + fun generate(prompt: String, config: LlmGenerationConfig, llmCallback: LlmCallback) { + generate( + null, + 0, + 0, + 0, + prompt, + config.seqLen, + llmCallback, + config.echo, + config.temperature, + config.numBos, + config.numEos, + ) + } + + /** + * Start generating tokens from the module. + * + * @param image Input image as a byte array + * @param width Input image width + * @param height Input image height + * @param channels Input image number of channels + * @param prompt Input prompt + * @param seqLen sequence length + * @param llmCallback callback object to receive results. + * @param echo indicate whether to echo the input prompt or not (text completion vs chat) + */ + fun generate( + image: IntArray?, + width: Int, + height: Int, + channels: Int, + prompt: String, + seqLen: Int, + llmCallback: LlmCallback, + echo: Boolean, + ) { + generate( + image, + width, + height, + channels, + prompt, + seqLen, + llmCallback, + echo, + DEFAULT_TEMPERATURE, + DEFAULT_BOS, + DEFAULT_EOS, + ) + } + + /** + * Start generating tokens from the module. + * + * @param image Input image as a byte array + * @param width Input image width + * @param height Input image height + * @param channels Input image number of channels + * @param prompt Input prompt + * @param seqLen sequence length + * @param llmCallback callback object to receive results. + * @param echo indicate whether to echo the input prompt or not (text completion vs chat) + * @param temperature temperature for sampling (use negative value to use module default) + */ + fun generate( + image: IntArray?, + width: Int, + height: Int, + channels: Int, + prompt: String, + seqLen: Int, + llmCallback: LlmCallback, + echo: Boolean, + temperature: Float, + ) { + generate( + image, + width, + height, + channels, + prompt, + seqLen, + llmCallback, + echo, + temperature, + DEFAULT_BOS, + DEFAULT_EOS, + ) + } + + /** + * Start generating tokens from the module. + * + * @param image Input image as a byte array + * @param width Input image width + * @param height Input image height + * @param channels Input image number of channels + * @param prompt Input prompt + * @param seqLen sequence length + * @param llmCallback callback object to receive results. + * @param echo indicate whether to echo the input prompt or not (text completion vs chat) + * @param temperature temperature for sampling (use negative value to use module default) + * @param numBos number of BOS tokens to prepend + * @param numEos number of EOS tokens to append + */ + fun generate( + image: IntArray?, + width: Int, + height: Int, + channels: Int, + prompt: String, + seqLen: Int, + llmCallback: LlmCallback, + echo: Boolean, + temperature: Float, + numBos: Int, + numEos: Int, + ) { + mLock.lock() + try { + checkNotReentrant() + checkNotDestroyed() + if (image != null) { + val nativeResult = prefillImagesInput(image, width, height, channels) + if (nativeResult != 0) { + throw ExecutorchRuntimeException.makeExecutorchException(nativeResult, "Prefill failed") + } + } + val err = generateNative(prompt, seqLen, llmCallback, echo, temperature, numBos, numEos) + if (err != 0) { + throw ExecutorchRuntimeException.makeExecutorchException(err, "Failed to generate") + } + } finally { + mLock.unlock() + } + } + + // --- prefill methods --- + + /** + * Prefill the KV cache with the given image input. + * + * @param image Input image as a byte array + * @param width Input image width + * @param height Input image height + * @param channels Input image number of channels + * @throws ExecutorchRuntimeException if the prefill failed + */ + @Experimental + fun prefillImages(image: IntArray, width: Int, height: Int, channels: Int) { + mLock.lock() + try { + checkNotReentrant() + checkNotDestroyed() + val nativeResult = prefillImagesInput(image, width, height, channels) + if (nativeResult != 0) { + throw ExecutorchRuntimeException.makeExecutorchException(nativeResult, "Prefill failed") + } + } finally { + mLock.unlock() + } + } + + /** + * Prefill a multimodal Module with the given image input via a direct ByteBuffer. The buffer data + * is accessed directly without JNI array copies, unlike [prefillImages]. The ByteBuffer must + * contain raw uint8 pixel data in CHW format with at least channels * height * width bytes + * remaining. Only the first channels * height * width bytes from the buffer's current position + * are read; the position of the original ByteBuffer is not modified. + * + * @param image Input image as a direct ByteBuffer containing uint8 pixel data + * @param width Input image width + * @param height Input image height + * @param channels Input image number of channels + * @throws IllegalArgumentException if the ByteBuffer is not direct or has insufficient remaining + * bytes + * @throws ExecutorchRuntimeException if the prefill failed + */ + @Experimental + fun prefillImages(image: ByteBuffer, width: Int, height: Int, channels: Int) { + mLock.lock() + try { + checkNotReentrant() + checkNotDestroyed() + require(image.isDirect) { "Input ByteBuffer must be direct." } + val expectedBytes: Long + try { + val pixels = Math.multiplyExact(width.toLong(), height.toLong()) + expectedBytes = Math.multiplyExact(pixels, channels.toLong()) + } catch (ex: ArithmeticException) { + throw IllegalArgumentException( + "width*height*channels is too large and overflows the allowed range.", + ex, + ) + } + require( + width > 0 && + height > 0 && + channels > 0 && + expectedBytes <= Int.MAX_VALUE.toLong() && + image.remaining().toLong() >= expectedBytes + ) { + "ByteBuffer remaining (${image.remaining()}) must be at least width*height*channels ($expectedBytes)." + } + // slice() so that getDirectBufferAddress on the native side returns a pointer + // starting at the current position, not the base address. + val nativeResult = prefillImagesInputBuffer(image.slice(), width, height, channels) + if (nativeResult != 0) { + throw ExecutorchRuntimeException.makeExecutorchException(nativeResult, "Prefill failed") + } + } finally { + mLock.unlock() + } + } + + /** + * Prefill a multimodal Module with the given normalized image input via a direct ByteBuffer. The + * buffer data is accessed directly without JNI array copies, unlike [prefillImages]. The + * ByteBuffer must contain normalized float pixel data in CHW format with at least channels * + * height * width * 4 bytes remaining. Only the first channels * height * width floats from the + * buffer's current position are consumed. The buffer must use the platform's native byte order + * (set via `buffer.order(ByteOrder.nativeOrder())`). + * + * @param image Input normalized image as a direct ByteBuffer containing float pixel data in + * native byte order + * @param width Input image width + * @param height Input image height + * @param channels Input image number of channels + * @throws IllegalArgumentException if the ByteBuffer is not direct, has insufficient remaining + * bytes, is not float-aligned, or does not use native byte order + * @throws ExecutorchRuntimeException if the prefill failed + */ + @Experimental + fun prefillNormalizedImage(image: ByteBuffer, width: Int, height: Int, channels: Int) { + mLock.lock() + try { + checkNotReentrant() + checkNotDestroyed() + require(image.isDirect) { "Input ByteBuffer must be direct." } + require(image.order() == ByteOrder.nativeOrder()) { + "Input ByteBuffer must use native byte order (ByteOrder.nativeOrder())." + } + require(image.position() % Float.SIZE_BYTES == 0) { + "Input ByteBuffer position (${image.position()}) must be 4-byte aligned." + } + val expectedBytes: Long + try { + val wh = Math.multiplyExact(width, height) + val whc = Math.multiplyExact(wh.toLong(), channels.toLong()) + val totalBytes = Math.multiplyExact(whc, Float.SIZE_BYTES.toLong()) + if (totalBytes > Int.MAX_VALUE.toLong()) { + throw IllegalArgumentException( + "ByteBuffer size (width*height*channels*4) exceeds Integer.MAX_VALUE bytes: $totalBytes", + ) + } + expectedBytes = totalBytes + } catch (e: ArithmeticException) { + throw IllegalArgumentException( + "Overflow while computing width*height*channels*4 for ByteBuffer size.", + e, + ) + } + require( + width > 0 && height > 0 && channels > 0 && image.remaining().toLong() >= expectedBytes + ) { + "ByteBuffer remaining (${image.remaining()}) must be at least width*height*channels*4 ($expectedBytes)." + } + require(image.remaining() % Float.SIZE_BYTES == 0) { + "ByteBuffer remaining (${image.remaining()}) must be a multiple of 4 (float size)." + } + // slice() so that getDirectBufferAddress on the native side returns a pointer + // starting at the current position, not the base address. + val nativeResult = prefillNormalizedImagesInputBuffer(image.slice(), width, height, channels) + if (nativeResult != 0) { + throw ExecutorchRuntimeException.makeExecutorchException(nativeResult, "Prefill failed") + } + } finally { + mLock.unlock() + } + } + + private external fun prefillImagesInput( + image: IntArray, + width: Int, + height: Int, + channels: Int, + ): Int + + private external fun prefillImagesInputBuffer( + image: ByteBuffer, + width: Int, + height: Int, + channels: Int, + ): Int + + private external fun prefillNormalizedImagesInputBuffer( + image: ByteBuffer, + width: Int, + height: Int, + channels: Int, + ): Int + + /** + * Prefill the KV cache with the given normalized image input. + * + * @param image Input normalized image as a float array + * @param width Input image width + * @param height Input image height + * @param channels Input image number of channels + * @throws ExecutorchRuntimeException if the prefill failed + */ + @Experimental + fun prefillImages(image: FloatArray, width: Int, height: Int, channels: Int) { + mLock.lock() + try { + checkNotReentrant() + checkNotDestroyed() + val nativeResult = prefillNormalizedImagesInput(image, width, height, channels) + if (nativeResult != 0) { + throw ExecutorchRuntimeException.makeExecutorchException(nativeResult, "Prefill failed") + } + } finally { + mLock.unlock() + } + } + + private external fun prefillNormalizedImagesInput( + image: FloatArray, + width: Int, + height: Int, + channels: Int, + ): Int + + /** + * Prefill the KV cache with the given preprocessed audio input. + * + * @param audio Input preprocessed audio as a byte array + * @param batchSize Input batch size + * @param nBins Input number of bins + * @param nFrames Input number of frames + * @throws ExecutorchRuntimeException if the prefill failed + */ + @Experimental + fun prefillAudio(audio: ByteArray, batchSize: Int, nBins: Int, nFrames: Int) { + mLock.lock() + try { + checkNotReentrant() + checkNotDestroyed() + val nativeResult = prefillAudioInput(audio, batchSize, nBins, nFrames) + if (nativeResult != 0) { + throw ExecutorchRuntimeException.makeExecutorchException(nativeResult, "Prefill failed") + } + } finally { + mLock.unlock() + } + } + + private external fun prefillAudioInput( + audio: ByteArray, + batchSize: Int, + nBins: Int, + nFrames: Int, + ): Int + + /** + * Prefill the KV cache with the given preprocessed audio input. + * + * @param audio Input preprocessed audio as a float array + * @param batchSize Input batch size + * @param nBins Input number of bins + * @param nFrames Input number of frames + * @throws ExecutorchRuntimeException if the prefill failed + */ + @Experimental + fun prefillAudio(audio: FloatArray, batchSize: Int, nBins: Int, nFrames: Int) { + mLock.lock() + try { + checkNotReentrant() + checkNotDestroyed() + val nativeResult = prefillAudioInputFloat(audio, batchSize, nBins, nFrames) + if (nativeResult != 0) { + throw ExecutorchRuntimeException.makeExecutorchException(nativeResult, "Prefill failed") + } + } finally { + mLock.unlock() + } + } + + private external fun prefillAudioInputFloat( + audio: FloatArray, + batchSize: Int, + nBins: Int, + nFrames: Int, + ): Int + + /** + * Prefill the KV cache with the given raw audio input. + * + * @param audio Input raw audio as a byte array + * @param batchSize Input batch size + * @param nChannels Input number of channels + * @param nSamples Input number of samples + * @throws ExecutorchRuntimeException if the prefill failed + */ + @Experimental + fun prefillRawAudio(audio: ByteArray, batchSize: Int, nChannels: Int, nSamples: Int) { + mLock.lock() + try { + checkNotReentrant() + checkNotDestroyed() + val nativeResult = prefillRawAudioInput(audio, batchSize, nChannels, nSamples) + if (nativeResult != 0) { + throw ExecutorchRuntimeException.makeExecutorchException(nativeResult, "Prefill failed") + } + } finally { + mLock.unlock() + } + } + + private external fun prefillRawAudioInput( + audio: ByteArray, + batchSize: Int, + nChannels: Int, + nSamples: Int, + ): Int + + /** + * Prefill the KV cache with the given text prompt. + * + * @param prompt The text prompt to prefill. + * @throws ExecutorchRuntimeException if the prefill failed + */ + @Experimental + fun prefillPrompt(prompt: String) { + mLock.lock() + try { + checkNotReentrant() + checkNotDestroyed() + val nativeResult = prefillTextInput(prompt) + if (nativeResult != 0) { + throw ExecutorchRuntimeException.makeExecutorchException(nativeResult, "Prefill failed") + } + } finally { + mLock.unlock() + } + } + + // returns status + private external fun prefillTextInput(prompt: String): Int + + /** + * Reset the context of the LLM. This will clear the KV cache and reset the state of the LLM. + * + * The startPos will be reset to 0. + */ + fun resetContext() { + mLock.lock() + try { + checkNotReentrant() + checkNotDestroyed() + resetContextNative() + } finally { + mLock.unlock() + } + } + + @DoNotStrip private external fun resetContextNative() + + /** Stop current generate() before it finishes. */ + fun stop() { + if (mDestroyed) return + stopNative() + } + + @DoNotStrip private external fun stopNative() + + /** Force loading the module. Otherwise the model is loaded during first generate(). */ + fun load() { + mLock.lock() + try { + checkNotReentrant() + checkNotDestroyed() + val err = loadNative() + if (err != 0) { + throw ExecutorchRuntimeException.makeExecutorchException(err, "Failed to load model") + } + } finally { + mLock.unlock() + } + } + + @DoNotStrip private external fun loadNative(): Int + + companion object { + const val MODEL_TYPE_TEXT = 1 + const val MODEL_TYPE_TEXT_VISION = 2 + const val MODEL_TYPE_MULTIMODAL = 2 + + private const val DEFAULT_SEQ_LEN = 128 + private const val DEFAULT_ECHO = true + private const val DEFAULT_TEMPERATURE = -1.0f + private const val DEFAULT_BOS = 0 + private const val DEFAULT_EOS = 0 + private const val DEFAULT_LOAD_MODE = LlmModuleConfig.LOAD_MODE_MMAP + + @DoNotStrip + @JvmStatic + private external fun initHybrid( + modelType: Int, + modulePath: String, + tokenizerPath: String, + temperature: Float, + dataFiles: List, + numBos: Int, + numEos: Int, + loadMode: Int, + ): HybridData + } +} diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModuleConfig.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModuleConfig.java deleted file mode 100644 index feb52a2b34b..00000000000 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModuleConfig.java +++ /dev/null @@ -1,252 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -package org.pytorch.executorch.extension.llm; - -/** - * Configuration class for initializing a LlmModule. - * - *

{@link #create()} method and the fluent builder pattern. - */ -public class LlmModuleConfig { - private final String modulePath; - private final String tokenizerPath; - private final float temperature; - private final String dataPath; - private final int modelType; - private final int numBos; - private final int numEos; - private final int loadMode; - - /** Load entire model file into a buffer (no mmap). */ - public static final int LOAD_MODE_FILE = 0; - - /** Load model via mmap without mlock (default). Pages faulted in on demand. */ - public static final int LOAD_MODE_MMAP = 1; - - /** Load model via mmap and pin all pages with mlock. */ - public static final int LOAD_MODE_MMAP_USE_MLOCK = 2; - - /** Load model via mmap and attempt mlock, ignoring mlock failures. */ - public static final int LOAD_MODE_MMAP_USE_MLOCK_IGNORE_ERRORS = 3; - - private LlmModuleConfig(Builder builder) { - this.modulePath = builder.modulePath; - this.tokenizerPath = builder.tokenizerPath; - this.temperature = builder.temperature; - this.dataPath = builder.dataPath; - this.modelType = builder.modelType; - this.numBos = builder.numBos; - this.numEos = builder.numEos; - this.loadMode = builder.loadMode; - } - - /** Model type constant for text-only models. */ - public static final int MODEL_TYPE_TEXT = 1; - - /** Model type constant for text-and-vision multimodal models. */ - public static final int MODEL_TYPE_TEXT_VISION = 2; - - /** Model type constant for generic multimodal models. */ - public static final int MODEL_TYPE_MULTIMODAL = 2; - - /** - * Creates a new Builder instance for constructing LlmModuleConfig objects. - * - * @return a new Builder instance with default configuration values - */ - public static Builder create() { - return new Builder(); - } - - // Getters with documentation - /** - * @return Path to the compiled model module (.pte file) - */ - public String getModulePath() { - return modulePath; - } - - /** - * @return Path to the tokenizer file or directory - */ - public String getTokenizerPath() { - return tokenizerPath; - } - - /** - * @return Temperature value for sampling (higher = more random) - */ - public float getTemperature() { - return temperature; - } - - /** - * @return Optional path to additional data files - */ - public String getDataPath() { - return dataPath; - } - - /** - * @return Type of model (text-only or text-vision) - */ - public int getModelType() { - return modelType; - } - - /** - * @return Number of BOS tokens to prepend - */ - public int getNumBos() { - return numBos; - } - - /** - * @return Number of EOS tokens to append - */ - public int getNumEos() { - return numEos; - } - - /** - * @return Load mode for the model file (one of LOAD_MODE_* constants) - */ - public int getLoadMode() { - return loadMode; - } - - /** - * Builder class for constructing LlmModuleConfig instances with optional parameters. - * - *

The builder provides a fluent interface for configuring model parameters and validates - * required fields before construction. - */ - public static class Builder { - private String modulePath; - private String tokenizerPath; - private float temperature = 0.8f; - private String dataPath = ""; - private int modelType = MODEL_TYPE_TEXT; - private int numBos = 0; - private int numEos = 0; - private int loadMode = LOAD_MODE_MMAP; - - Builder() {} - - /** - * Sets the path to the module. - * - * @param modulePath Path to module - * @return This builder instance for method chaining - */ - public Builder modulePath(String modulePath) { - this.modulePath = modulePath; - return this; - } - - /** - * Sets the path to the tokenizer. - * - * @param tokenizerPath Path to tokenizer - * @return This builder instance for method chaining - */ - public Builder tokenizerPath(String tokenizerPath) { - this.tokenizerPath = tokenizerPath; - return this; - } - - /** - * Sets the temperature for sampling generation. - * - * @param temperature Temperature value (typical range 0.0-1.0) - * @return This builder instance for method chaining - */ - public Builder temperature(float temperature) { - this.temperature = temperature; - return this; - } - - /** - * Sets the path to optional additional data files. - * - * @param dataPath Path to supplementary data resources - * @return This builder instance for method chaining - */ - public Builder dataPath(String dataPath) { - this.dataPath = dataPath; - return this; - } - - /** - * Sets the model type (text-only or multimodal). - * - * @param modelType One of MODEL_TYPE_TEXT, MODEL_TYPE_TEXT_VISION, MODEL_TYPE_MULTIMODAL - * @return This builder instance for method chaining - */ - public Builder modelType(int modelType) { - this.modelType = modelType; - return this; - } - - /** - * Sets the number of BOS tokens to prepend. - * - * @param numBos number of BOS tokens - * @return This builder instance for method chaining - */ - public Builder numBos(int numBos) { - this.numBos = numBos; - return this; - } - - /** - * Sets the number of EOS tokens to append. - * - * @param numEos number of EOS tokens - * @return This builder instance for method chaining - */ - public Builder numEos(int numEos) { - this.numEos = numEos; - return this; - } - - /** - * Sets the load mode for the model file. Defaults to {@link #LOAD_MODE_MMAP} (mmap without - * mlock), which avoids pinning model pages in RAM. - * - * @param loadMode One of LOAD_MODE_FILE, LOAD_MODE_MMAP, LOAD_MODE_MMAP_USE_MLOCK, - * LOAD_MODE_MMAP_USE_MLOCK_IGNORE_ERRORS - * @return This builder instance for method chaining - * @throws IllegalArgumentException if {@code loadMode} is not one of the supported constants - */ - public Builder loadMode(int loadMode) { - if (loadMode != LOAD_MODE_FILE - && loadMode != LOAD_MODE_MMAP - && loadMode != LOAD_MODE_MMAP_USE_MLOCK - && loadMode != LOAD_MODE_MMAP_USE_MLOCK_IGNORE_ERRORS) { - throw new IllegalArgumentException("Unknown load mode: " + loadMode); - } - this.loadMode = loadMode; - return this; - } - - /** - * Constructs the LlmModuleConfig instance with validated parameters. - * - * @return New LlmModuleConfig instance with configured values - * @throws IllegalArgumentException if required fields are missing - */ - public LlmModuleConfig build() { - if (modulePath == null || tokenizerPath == null) { - throw new IllegalArgumentException("Module path and tokenizer path are required"); - } - return new LlmModuleConfig(this); - } - } -} diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModuleConfig.kt b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModuleConfig.kt new file mode 100644 index 00000000000..2d65633bb9f --- /dev/null +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModuleConfig.kt @@ -0,0 +1,134 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.executorch.extension.llm + +/** + * Configuration class for initializing a LlmModule. + * + * Use [create] method and the fluent builder pattern. + */ +class LlmModuleConfig +private constructor( + val modulePath: String, + val tokenizerPath: String, + val temperature: Float, + val dataPath: String?, + val modelType: Int, + val numBos: Int, + val numEos: Int, + val loadMode: Int, +) { + + companion object { + /** Load entire model file into a buffer (no mmap). */ + const val LOAD_MODE_FILE = 0 + + /** Load model via mmap without mlock (default). Pages faulted in on demand. */ + const val LOAD_MODE_MMAP = 1 + + /** Load model via mmap and pin all pages with mlock. */ + const val LOAD_MODE_MMAP_USE_MLOCK = 2 + + /** Load model via mmap and attempt mlock, ignoring mlock failures. */ + const val LOAD_MODE_MMAP_USE_MLOCK_IGNORE_ERRORS = 3 + + /** Model type constant for text-only models. */ + const val MODEL_TYPE_TEXT = 1 + + /** Model type constant for text-and-vision multimodal models. */ + const val MODEL_TYPE_TEXT_VISION = 2 + + /** Model type constant for generic multimodal models. */ + const val MODEL_TYPE_MULTIMODAL = 2 + + /** + * Creates a new Builder instance for constructing LlmModuleConfig objects. + * + * @return a new Builder instance with default configuration values + */ + @JvmStatic fun create(): Builder = Builder() + } + + /** + * Builder class for constructing LlmModuleConfig instances with optional parameters. + * + * The builder provides a fluent interface for configuring model parameters and validates required + * fields before construction. + */ + class Builder internal constructor() { + private var modulePath: String? = null + private var tokenizerPath: String? = null + private var temperature: Float = 0.8f + private var dataPath: String? = "" + private var modelType: Int = MODEL_TYPE_TEXT + private var numBos: Int = 0 + private var numEos: Int = 0 + private var loadMode: Int = LOAD_MODE_MMAP + + /** Sets the path to the module. */ + fun modulePath(modulePath: String): Builder = apply { this.modulePath = modulePath } + + /** Sets the path to the tokenizer. */ + fun tokenizerPath(tokenizerPath: String): Builder = apply { this.tokenizerPath = tokenizerPath } + + /** Sets the temperature for sampling generation. */ + fun temperature(temperature: Float): Builder = apply { this.temperature = temperature } + + /** Sets the path to optional additional data files. */ + fun dataPath(dataPath: String?): Builder = apply { this.dataPath = dataPath } + + /** Sets the model type (text-only or multimodal). */ + fun modelType(modelType: Int): Builder = apply { this.modelType = modelType } + + /** Sets the number of BOS tokens to prepend. */ + fun numBos(numBos: Int): Builder = apply { this.numBos = numBos } + + /** Sets the number of EOS tokens to append. */ + fun numEos(numEos: Int): Builder = apply { this.numEos = numEos } + + /** + * Sets the load mode for the model file. Defaults to [LOAD_MODE_MMAP] (mmap without mlock), + * which avoids pinning model pages in RAM. + * + * @throws IllegalArgumentException if loadMode is not one of the supported constants + */ + fun loadMode(loadMode: Int): Builder { + require( + loadMode == LOAD_MODE_FILE || + loadMode == LOAD_MODE_MMAP || + loadMode == LOAD_MODE_MMAP_USE_MLOCK || + loadMode == LOAD_MODE_MMAP_USE_MLOCK_IGNORE_ERRORS + ) { + "Unknown load mode: $loadMode" + } + return apply { this.loadMode = loadMode } + } + + /** + * Constructs the LlmModuleConfig instance with validated parameters. + * + * @throws IllegalArgumentException if required fields are missing + */ + fun build(): LlmModuleConfig { + require(modulePath != null && tokenizerPath != null) { + "Module path and tokenizer path are required" + } + return LlmModuleConfig( + modulePath!!, + tokenizerPath!!, + temperature, + dataPath, + modelType, + numBos, + numEos, + loadMode, + ) + } + } +} diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/package-info.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/package-info.java deleted file mode 100644 index 86e19d09133..00000000000 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/package-info.java +++ /dev/null @@ -1,51 +0,0 @@ -/** - * ExecuTorch LLM extension for Android. - * - *

This package provides Java bindings for running large language models (LLMs) on Android using - * ExecuTorch. It supports text generation, tokenization, and streaming token callbacks. - * - *

Quick Start

- * - *
{@code
- * import org.pytorch.executorch.extension.llm.LlmModule;
- *
- * // Load a Llama model
- * LlmModule llm = new LlmModule(
- *     "/data/local/tmp/llama.pte",
- *     "/data/local/tmp/tokenizer.bin",
- *     0.8f
- * );
- * llm.load();
- *
- * // Generate text token by token
- * llm.generate("Hello, my name is", 200, new LlmCallback() {
- *     public void onResult(String token) {
- *         System.out.print(token);
- *     }
- *     public void onStats(String stats) {
- *         System.out.println("\nStats: " + stats);
- *     }
- * });
- * }
- * - *

Key Classes

- * - *
    - *
  • {@link org.pytorch.executorch.extension.llm.LlmModule} — load and run an LLM - *
  • {@link org.pytorch.executorch.extension.llm.LlmModuleConfig} — configure model paths and - * settings - *
  • {@link org.pytorch.executorch.extension.llm.LlmGenerationConfig} — control generation - * (temperature, seq length) - *
- * - *

More Resources

- * - * - */ -package org.pytorch.executorch.extension.llm; diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/package-info.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/package-info.java deleted file mode 100644 index 7a5ed0bb5a5..00000000000 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/package-info.java +++ /dev/null @@ -1,57 +0,0 @@ -/** - * ExecuTorch Android Java API. - * - *

This package provides Java bindings for running ExecuTorch models on Android. Use these - * classes to load a {@code .pte} model file and run inference directly from your Java or Kotlin - * Android app — no C++ required. - * - *

Quick Start

- * - *

Step 1. Add the dependency to your {@code app/build.gradle.kts}: - * - *

{@code
- * dependencies {
- *     implementation("org.pytorch:executorch-android:${executorch_version}")
- * }
- * }
- * - *

Step 2. Load your model and run inference: - * - *

{@code
- * import org.pytorch.executorch.EValue;
- * import org.pytorch.executorch.Module;
- * import org.pytorch.executorch.Tensor;
- *
- * // Load your exported .pte model file
- * Module module = Module.load("/data/local/tmp/model.pte");
- *
- * // Build an input tensor  e.g. a 1x3x224x224 image
- * float[] inputData = new float[1 * 3 * 224 * 224];
- * Tensor inputTensor = Tensor.fromBlob(inputData, new long[]{1, 3, 224, 224});
- *
- * // Run inference
- * EValue[] output = module.forward(EValue.from(inputTensor));
- *
- * // Read the result
- * float[] scores = output[0].toTensor().getDataAsFloatArray();
- * }
- * - *

Key Classes

- * - *
    - *
  • {@link org.pytorch.executorch.Module} — load and run a {@code .pte} model - *
  • {@link org.pytorch.executorch.Tensor} — create input tensors and read outputs - *
  • {@link org.pytorch.executorch.EValue} — wrap inputs and unwrap outputs - *
  • {@link org.pytorch.executorch.DType} — supported data types (FLOAT, INT32, etc.) - *
- * - *

More Resources

- * - * - */ -package org.pytorch.executorch; diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/SGD.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/SGD.java deleted file mode 100644 index 58c7704b83e..00000000000 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/SGD.java +++ /dev/null @@ -1,103 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -package org.pytorch.executorch.training; - -import com.facebook.jni.HybridData; -import com.facebook.jni.annotations.DoNotStrip; -import com.facebook.soloader.nativeloader.NativeLoader; -import com.facebook.soloader.nativeloader.SystemDelegate; -import java.util.Map; -import org.pytorch.executorch.Tensor; -import org.pytorch.executorch.annotations.Experimental; - -/** - * Java wrapper for ExecuTorch SGD Optimizer. - * - *

Warning: These APIs are experimental and subject to change without notice - */ -@Experimental -public class SGD { - - static { - if (!NativeLoader.isInitialized()) { - NativeLoader.init(new SystemDelegate()); - } - // Loads libexecutorch.so from jniLibs - NativeLoader.loadLibrary("executorch"); - } - - private final HybridData mHybridData; - - @DoNotStrip - private static native HybridData initHybrid( - Map namedParameters, - double learningRate, - double momentum, - double dampening, - double weightDecay, - boolean nesterov); - - private SGD( - Map namedParameters, - double learningRate, - double momentum, - double dampening, - double weightDecay, - boolean nesterov) { - mHybridData = - initHybrid(namedParameters, learningRate, momentum, dampening, weightDecay, nesterov); - } - - /** - * Creates a new SGD optimizer with the specified parameters and options. - * - * @param namedParameters Map of parameter names to tensors to be optimized - * @param learningRate The learning rate for the optimizer - * @param momentum The momentum value - * @param dampening The dampening value - * @param weightDecay The weight decay value - * @param nesterov Whether to use Nesterov momentum - * @return new {@link SGD} object - */ - public static SGD create( - Map namedParameters, - double learningRate, - double momentum, - double dampening, - double weightDecay, - boolean nesterov) { - return new SGD(namedParameters, learningRate, momentum, dampening, weightDecay, nesterov); - } - - /** - * Creates a new SGD optimizer with default options. - * - * @param namedParameters Map of parameter names to tensors to be optimized - * @param learningRate The learning rate for the optimizer - * @return new {@link SGD} object - */ - public static SGD create(Map namedParameters, double learningRate) { - return create(namedParameters, learningRate, 0.0, 0.0, 0.0, false); - } - - /** - * Performs a single optimization step using the provided gradients. - * - * @param namedGradients Map of parameter names to gradient tensors - */ - public void step(Map namedGradients) { - if (!mHybridData.isValid()) { - throw new IllegalStateException("SGD optimizer has been destroyed"); - } - stepNative(namedGradients); - } - - @DoNotStrip - private native void stepNative(Map namedGradients); -} diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/SGD.kt b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/SGD.kt new file mode 100644 index 00000000000..e4aa5373498 --- /dev/null +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/SGD.kt @@ -0,0 +1,100 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.executorch.training + +import com.facebook.jni.HybridData +import com.facebook.jni.annotations.DoNotStrip +import com.facebook.soloader.nativeloader.NativeLoader +import com.facebook.soloader.nativeloader.SystemDelegate +import org.pytorch.executorch.Tensor +import org.pytorch.executorch.annotations.Experimental + +/** + * Kotlin wrapper for ExecuTorch SGD Optimizer. + * + * Warning: These APIs are experimental and subject to change without notice + */ +@Experimental +class SGD +private constructor( + namedParameters: Map, + learningRate: Double, + momentum: Double, + dampening: Double, + weightDecay: Double, + nesterov: Boolean, +) { + + private val mHybridData: HybridData = + initHybrid(namedParameters, learningRate, momentum, dampening, weightDecay, nesterov) + + /** + * Performs a single optimization step using the provided gradients. + * + * @param namedGradients Map of parameter names to gradient tensors + */ + fun step(namedGradients: Map) { + check(mHybridData.isValid) { "SGD optimizer has been destroyed" } + stepNative(namedGradients) + } + + @DoNotStrip private external fun stepNative(namedGradients: Map) + + companion object { + init { + if (!NativeLoader.isInitialized()) { + NativeLoader.init(SystemDelegate()) + } + NativeLoader.loadLibrary("executorch") + } + + @DoNotStrip + @JvmStatic + private external fun initHybrid( + namedParameters: Map, + learningRate: Double, + momentum: Double, + dampening: Double, + weightDecay: Double, + nesterov: Boolean, + ): HybridData + + /** + * Creates a new SGD optimizer with the specified parameters and options. + * + * @param namedParameters Map of parameter names to tensors to be optimized + * @param learningRate The learning rate for the optimizer + * @param momentum The momentum value + * @param dampening The dampening value + * @param weightDecay The weight decay value + * @param nesterov Whether to use Nesterov momentum + * @return new [SGD] object + */ + @JvmStatic + fun create( + namedParameters: Map, + learningRate: Double, + momentum: Double, + dampening: Double, + weightDecay: Double, + nesterov: Boolean, + ): SGD = SGD(namedParameters, learningRate, momentum, dampening, weightDecay, nesterov) + + /** + * Creates a new SGD optimizer with default options. + * + * @param namedParameters Map of parameter names to tensors to be optimized + * @param learningRate The learning rate for the optimizer + * @return new [SGD] object + */ + @JvmStatic + fun create(namedParameters: Map, learningRate: Double): SGD = + create(namedParameters, learningRate, 0.0, 0.0, 0.0, false) + } +} diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/TrainingModule.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/TrainingModule.java deleted file mode 100644 index dd2d5a37de2..00000000000 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/TrainingModule.java +++ /dev/null @@ -1,140 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -package org.pytorch.executorch.training; - -import com.facebook.jni.HybridData; -import com.facebook.jni.annotations.DoNotStrip; -import com.facebook.soloader.nativeloader.NativeLoader; -import com.facebook.soloader.nativeloader.SystemDelegate; -import java.io.Closeable; -import java.util.Map; -import java.util.concurrent.locks.ReentrantLock; -import org.pytorch.executorch.EValue; -import org.pytorch.executorch.ExecuTorchRuntime; -import org.pytorch.executorch.Tensor; -import org.pytorch.executorch.annotations.Experimental; - -/** - * Java wrapper for ExecuTorch TrainingModule. - * - *

Warning: These APIs are experimental and subject to change without notice - */ -@Experimental -public class TrainingModule implements Closeable { - - static { - if (!NativeLoader.isInitialized()) { - NativeLoader.init(new SystemDelegate()); - } - // Loads libexecutorch.so from jniLibs - NativeLoader.loadLibrary("executorch"); - } - - private final HybridData mHybridData; - private final ReentrantLock mLock = new ReentrantLock(); - private volatile boolean mDestroyed = false; - - @DoNotStrip - private static native HybridData initHybrid(String moduleAbsolutePath, String dataAbsolutePath); - - private TrainingModule(String moduleAbsolutePath, String dataAbsolutePath) { - mHybridData = initHybrid(moduleAbsolutePath, dataAbsolutePath); - } - - private void checkNotDestroyed() { - if (mDestroyed) throw new IllegalStateException("TrainingModule has been destroyed"); - } - - /** - * Loads a serialized ExecuTorch Training Module from the specified path on the disk. - * - * @param modelPath path to file that contains the serialized ExecuTorch module. - * @param dataPath path to file that contains the ExecuTorch module external weights. - * @return new {@link TrainingModule} object which owns the model module. - */ - public static TrainingModule load(final String modelPath, final String dataPath) { - ExecuTorchRuntime.validateFilePath(modelPath, "model path"); - ExecuTorchRuntime.validateFilePath(dataPath, "data path"); - return new TrainingModule(modelPath, dataPath); - } - - /** - * Loads a serialized ExecuTorch training module from the specified path on the disk. - * - * @param modelPath path to file that contains the serialized ExecuTorch module. This PTE does not - * rely on external weights. - * @return new {@link TrainingModule} object which owns the model module. - */ - public static TrainingModule load(final String modelPath) { - ExecuTorchRuntime.validateFilePath(modelPath, "model path"); - return new TrainingModule(modelPath, ""); - } - - /** - * Runs the specified joint-graph method of this module with the specified arguments. - * - * @param methodName name of the ExecuTorch method to run. - * @param inputs arguments that will be passed to ExecuTorch method. - * @return return value(s) from the method. - */ - public EValue[] executeForwardBackward(String methodName, EValue... inputs) { - mLock.lock(); - try { - checkNotDestroyed(); - return executeForwardBackwardNative(methodName, inputs); - } finally { - mLock.unlock(); - } - } - - @DoNotStrip - private native EValue[] executeForwardBackwardNative(String methodName, EValue... inputs); - - public Map namedParameters(String methodName) { - mLock.lock(); - try { - checkNotDestroyed(); - return namedParametersNative(methodName); - } finally { - mLock.unlock(); - } - } - - @DoNotStrip - private native Map namedParametersNative(String methodName); - - public Map namedGradients(String methodName) { - mLock.lock(); - try { - checkNotDestroyed(); - return namedGradientsNative(methodName); - } finally { - mLock.unlock(); - } - } - - @DoNotStrip - private native Map namedGradientsNative(String methodName); - - @Override - public void close() { - if (mLock.tryLock()) { - try { - if (!mDestroyed) { - mDestroyed = true; - mHybridData.resetNative(); - } - } finally { - mLock.unlock(); - } - } else { - throw new IllegalStateException("Cannot close module while method is executing"); - } - } -} diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/TrainingModule.kt b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/TrainingModule.kt new file mode 100644 index 00000000000..5556b0c16c4 --- /dev/null +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/TrainingModule.kt @@ -0,0 +1,144 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.executorch.training + +import com.facebook.jni.HybridData +import com.facebook.jni.annotations.DoNotStrip +import com.facebook.soloader.nativeloader.NativeLoader +import com.facebook.soloader.nativeloader.SystemDelegate +import java.io.Closeable +import java.util.concurrent.locks.ReentrantLock +import org.pytorch.executorch.EValue +import org.pytorch.executorch.ExecuTorchRuntime +import org.pytorch.executorch.Tensor +import org.pytorch.executorch.annotations.Experimental + +/** + * Kotlin wrapper for ExecuTorch TrainingModule. + * + * Warning: These APIs are experimental and subject to change without notice + */ +@Experimental +class TrainingModule private constructor(moduleAbsolutePath: String, dataAbsolutePath: String) : + Closeable { + + private val mHybridData: HybridData = initHybrid(moduleAbsolutePath, dataAbsolutePath) + private val mLock = ReentrantLock() + + @Volatile private var mDestroyed = false + + private fun checkNotDestroyed() { + check(!mDestroyed) { "TrainingModule has been destroyed" } + } + + /** + * Runs the specified joint-graph method of this module with the specified arguments. + * + * @param methodName name of the ExecuTorch method to run. + * @param inputs arguments that will be passed to ExecuTorch method. + * @return return value(s) from the method. + */ + fun executeForwardBackward(methodName: String, vararg inputs: EValue): Array { + mLock.lock() + try { + checkNotDestroyed() + return executeForwardBackwardNative(methodName, *inputs) + } finally { + mLock.unlock() + } + } + + @DoNotStrip + private external fun executeForwardBackwardNative( + methodName: String, + vararg inputs: EValue, + ): Array + + fun namedParameters(methodName: String): Map { + mLock.lock() + try { + checkNotDestroyed() + return namedParametersNative(methodName) + } finally { + mLock.unlock() + } + } + + @DoNotStrip private external fun namedParametersNative(methodName: String): Map + + fun namedGradients(methodName: String): Map { + mLock.lock() + try { + checkNotDestroyed() + return namedGradientsNative(methodName) + } finally { + mLock.unlock() + } + } + + @DoNotStrip private external fun namedGradientsNative(methodName: String): Map + + override fun close() { + if (mLock.tryLock()) { + try { + if (!mDestroyed) { + mDestroyed = true + mHybridData.resetNative() + } + } finally { + mLock.unlock() + } + } else { + throw IllegalStateException("Cannot close module while method is executing") + } + } + + companion object { + init { + if (!NativeLoader.isInitialized()) { + NativeLoader.init(SystemDelegate()) + } + NativeLoader.loadLibrary("executorch") + } + + @DoNotStrip + @JvmStatic + private external fun initHybrid( + moduleAbsolutePath: String, + dataAbsolutePath: String, + ): HybridData + + /** + * Loads a serialized ExecuTorch Training Module from the specified path on the disk. + * + * @param modelPath path to file that contains the serialized ExecuTorch module. + * @param dataPath path to file that contains the ExecuTorch module external weights. + * @return new [TrainingModule] object which owns the model module. + */ + @JvmStatic + fun load(modelPath: String, dataPath: String): TrainingModule { + ExecuTorchRuntime.validateFilePath(modelPath, "model path") + ExecuTorchRuntime.validateFilePath(dataPath, "data path") + return TrainingModule(modelPath, dataPath) + } + + /** + * Loads a serialized ExecuTorch training module from the specified path on the disk. + * + * @param modelPath path to file that contains the serialized ExecuTorch module. This PTE does + * not rely on external weights. + * @return new [TrainingModule] object which owns the model module. + */ + @JvmStatic + fun load(modelPath: String): TrainingModule { + ExecuTorchRuntime.validateFilePath(modelPath, "model path") + return TrainingModule(modelPath, "") + } + } +} diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp index e072694f913..b9215f978bc 100644 --- a/extension/android/jni/jni_layer_llama.cpp +++ b/extension/android/jni/jni_layer_llama.cpp @@ -206,41 +206,14 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass { data_files_vector, cpp_load_mode); std::string decoder_model = "llama3"; // use llama3 for now - // Using 8bit as default since this meta is introduced with 16bit kv io - // support and older models only have 8bit kv io. - example::KvBitWidth kv_bitwidth = example::KvBitWidth::kWidth8; - if (module->method_names()->count("get_kv_io_bit_width") > 0) { - kv_bitwidth = static_cast( - module->get("get_kv_io_bit_width") - .get() - .toScalar() - .to()); - } - - if (kv_bitwidth == example::KvBitWidth::kWidth8) { - runner_ = std::make_unique>( - std::move(module), - decoder_model.c_str(), - model_path->toStdString().c_str(), - tokenizer_path->toStdString().c_str(), - "", - "", - temperature_); - } else if (kv_bitwidth == example::KvBitWidth::kWidth16) { - runner_ = std::make_unique>( - std::move(module), - decoder_model.c_str(), - model_path->toStdString().c_str(), - tokenizer_path->toStdString().c_str(), - "", - "", - temperature_); - } else { - ET_CHECK_MSG( - false, - "Unsupported kv bitwidth: %ld", - static_cast(kv_bitwidth)); - } + runner_ = std::make_unique( + std::move(module), + decoder_model.c_str(), + model_path->toStdString().c_str(), + tokenizer_path->toStdString().c_str(), + "", + "", + temperature_); model_type_category_ = MODEL_TYPE_CATEGORY_LLM; #endif #if defined(EXECUTORCH_BUILD_MEDIATEK) diff --git a/extension/asr/runner/CMakeLists.txt b/extension/asr/runner/CMakeLists.txt index 66974aa2a24..b47cddaf48c 100644 --- a/extension/asr/runner/CMakeLists.txt +++ b/extension/asr/runner/CMakeLists.txt @@ -22,7 +22,7 @@ endif() include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake) set(runner_deps executorch_core extension_module extension_tensor - tokenizers::tokenizers + extension_llm_runner tokenizers::tokenizers ) # Define runner library diff --git a/extension/asr/runner/transducer_runner.cpp b/extension/asr/runner/transducer_runner.cpp index 3461cb09cc1..7b9298845a9 100644 --- a/extension/asr/runner/transducer_runner.cpp +++ b/extension/asr/runner/transducer_runner.cpp @@ -200,7 +200,7 @@ Error TransducerRunner::load() { return Error::Ok; } -Result<::executorch::extension::TensorPtr> TransducerRunner::preprocess( +Result TransducerRunner::preprocess( ::executorch::extension::TensorPtr raw_audio) { if (!is_loaded()) { ET_CHECK_OK_OR_RETURN_ERROR(load()); @@ -229,12 +229,18 @@ Result<::executorch::extension::TensorPtr> TransducerRunner::preprocess( "Preprocessor returned unexpected output."); auto mel = outputs[0].toTensor(); - return std::make_shared<::executorch::aten::Tensor>(std::move(mel)); + int64_t mel_len = mel.sizes()[1]; // default to tensor dim + if (outputs.size() >= 2 && outputs[1].isTensor()) { + mel_len = outputs[1].toTensor().const_data_ptr()[0]; + } + return PreprocessResult{ + std::make_shared<::executorch::aten::Tensor>(std::move(mel)), mel_len}; } Result> TransducerRunner::transcribe( ::executorch::extension::TensorPtr preprocessed_features, - std::function token_callback) { + std::function token_callback, + int64_t features_length) { if (!is_loaded()) { ET_CHECK_OK_OR_RETURN_ERROR(load()); } @@ -242,7 +248,9 @@ Result> TransducerRunner::transcribe( stats_.inference_start_ms = ::executorch::extension::llm::time_in_ms(); // --- Encode --- - int64_t mel_len_value = preprocessed_features->size(1); + // Use provided length, or fall back to tensor dimension + int64_t mel_len_value = + features_length > 0 ? features_length : preprocessed_features->size(1); std::vector mel_len_data = {mel_len_value}; auto mel_len = ::executorch::extension::from_blob( mel_len_data.data(), {1}, ::executorch::aten::ScalarType::Long); diff --git a/extension/asr/runner/transducer_runner.h b/extension/asr/runner/transducer_runner.h index ee819590141..aed0ad84cd6 100644 --- a/extension/asr/runner/transducer_runner.h +++ b/extension/asr/runner/transducer_runner.h @@ -29,6 +29,14 @@ using ::executorch::extension::llm::Stats; using ::executorch::runtime::Error; using ::executorch::runtime::Result; +/** + * Preprocessed audio features with actual (unpadded) length. + */ +struct PreprocessResult { + ::executorch::extension::TensorPtr features; + int64_t length; // Actual number of valid frames (excluding padding) +}; + /** * A decoded token with frame-level timing information. */ @@ -97,7 +105,7 @@ class ET_EXPERIMENTAL TransducerRunner { * @returns Preprocessed features tensor (e.g., mel spectrogram), * ready to pass to transcribe(). */ - Result<::executorch::extension::TensorPtr> preprocess( + Result preprocess( ::executorch::extension::TensorPtr raw_audio); /** @@ -112,7 +120,8 @@ class ET_EXPERIMENTAL TransducerRunner { */ Result> transcribe( ::executorch::extension::TensorPtr preprocessed_features, - std::function token_callback = {}); + std::function token_callback = {}, + int64_t features_length = -1); /** * Returns a reference to the loaded tokenizer, or nullptr if not loaded. diff --git a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.java b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.java deleted file mode 100644 index 5e1dd48926b..00000000000 --- a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.java +++ /dev/null @@ -1,136 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -package org.pytorch.minibench; - -import android.app.Activity; -import android.content.Intent; -import android.os.Bundle; -import android.os.Handler; -import android.os.HandlerThread; -import android.os.Looper; -import android.system.ErrnoException; -import android.system.Os; -import com.google.gson.Gson; -import java.io.File; -import java.io.FileWriter; -import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; - -public class BenchmarkActivity extends Activity { - - File mModel; - int mNumIter; - int mNumWarmupIter; - String mTokenizerPath; - float mTemperature; - String mPrompt; - - HandlerThread mHandlerThread; - BenchmarkHandler mHandler; - - List mResult; - - @Override - protected void onCreate(Bundle savedInstanceState) { - super.onCreate(savedInstanceState); - - try { - Os.setenv("ADSP_LIBRARY_PATH", getApplicationInfo().nativeLibraryDir, true); - } catch (ErrnoException e) { - finish(); - } - - Intent intent = getIntent(); - File modelDir = new File(intent.getStringExtra("model_dir")); - File model = - Arrays.stream(modelDir.listFiles()) - .filter(file -> file.getName().endsWith(".pte")) - .findFirst() - .get(); - - int numIter = intent.getIntExtra("num_iter", 50); - int numWarmupIter = intent.getIntExtra("num_warm_up_iter", 10); - String tokenizerPath = intent.getStringExtra("tokenizer_path"); - float temperature = intent.getFloatExtra("temperature", 0.8f); - String prompt = intent.getStringExtra("prompt"); - - mModel = model; - mNumIter = numIter; - mNumWarmupIter = numWarmupIter; - mTokenizerPath = tokenizerPath; - mTemperature = temperature; - mPrompt = prompt; - if (mPrompt == null) { - mPrompt = "The ultimate answer"; - } - mResult = new ArrayList<>(); - - mHandlerThread = new HandlerThread("ModelRunner"); - mHandlerThread.start(); - mHandler = new BenchmarkHandler(mHandlerThread.getLooper(), this); - - mHandler.sendEmptyMessage(BenchmarkHandler.MESSAGE_RUN_BENCHMARK); - } - - void writeResult() { - try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.json")) { - Gson gson = new Gson(); - writer.write(gson.toJson(mResult)); - } catch (IOException e) { - e.printStackTrace(); - } finally { - finish(); - } - } -} - -class BenchmarkHandler extends Handler { - public static int MESSAGE_RUN_BENCHMARK = 1; - public static int MESSAGE_LLM_RUN_BENCHMARK = 2; - - ModelRunner mModelRunner; - BenchmarkActivity mBenchmarkActivity; - - LlmModelRunner mLlmModelRunner; - LlmBenchmark mLlmBenchmark; - - public BenchmarkHandler(Looper looper, BenchmarkActivity benchmarkActivity) { - super(looper); - mModelRunner = new ModelRunner(); - mBenchmarkActivity = benchmarkActivity; - } - - @Override - public void handleMessage(android.os.Message msg) { - if (msg.what == MESSAGE_RUN_BENCHMARK) { - mModelRunner.runBenchmark( - mBenchmarkActivity.mModel, - mBenchmarkActivity.mNumWarmupIter, - mBenchmarkActivity.mNumIter, - mBenchmarkActivity.mResult); - - if (mBenchmarkActivity.mTokenizerPath == null) { - mBenchmarkActivity.writeResult(); - } else { - this.sendEmptyMessage(MESSAGE_LLM_RUN_BENCHMARK); - } - } else if (msg.what == MESSAGE_LLM_RUN_BENCHMARK) { - mLlmBenchmark = - new LlmBenchmark( - mBenchmarkActivity, - mBenchmarkActivity.mModel.getPath(), - mBenchmarkActivity.mTokenizerPath, - mBenchmarkActivity.mPrompt, - mBenchmarkActivity.mTemperature, - mBenchmarkActivity.mResult); - } - } -} diff --git a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.kt b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.kt new file mode 100644 index 00000000000..b1d69c5f24f --- /dev/null +++ b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.kt @@ -0,0 +1,116 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.minibench + +import android.app.Activity +import android.os.Bundle +import android.os.Handler +import android.os.HandlerThread +import android.os.Looper +import android.os.Message +import android.system.Os +import com.google.gson.Gson +import java.io.File +import java.io.FileWriter +import java.io.IOException + +class BenchmarkActivity : Activity() { + + lateinit var model: File + var numIter: Int = 0 + var numWarmupIter: Int = 0 + var tokenizerPath: String? = null + var temperature: Float = 0.8f + var prompt: String = "The ultimate answer" + + private lateinit var handlerThread: HandlerThread + private lateinit var handler: BenchmarkHandler + + val results: MutableList = mutableListOf() + + override fun onCreate(savedInstanceState: Bundle?) { + super.onCreate(savedInstanceState) + + try { + Os.setenv("ADSP_LIBRARY_PATH", applicationInfo.nativeLibraryDir, true) + } catch (e: android.system.ErrnoException) { + finish() + return + } + + val intent = intent + val modelDir = File(intent.getStringExtra("model_dir")!!) + model = modelDir.listFiles()!!.first { it.name.endsWith(".pte") } + + numIter = intent.getIntExtra("num_iter", 50) + numWarmupIter = intent.getIntExtra("num_warm_up_iter", 10) + tokenizerPath = intent.getStringExtra("tokenizer_path") + temperature = intent.getFloatExtra("temperature", 0.8f) + prompt = intent.getStringExtra("prompt") ?: "The ultimate answer" + + handlerThread = HandlerThread("ModelRunner") + handlerThread.start() + handler = BenchmarkHandler(handlerThread.looper, this) + + handler.sendEmptyMessage(BenchmarkHandler.MESSAGE_RUN_BENCHMARK) + } + + fun writeResult() { + try { + FileWriter("${filesDir}/benchmark_results.json").use { writer -> + writer.write(Gson().toJson(results)) + } + } catch (e: IOException) { + e.printStackTrace() + } finally { + finish() + } + } +} + +private class BenchmarkHandler( + looper: Looper, + private val activity: BenchmarkActivity, +) : Handler(looper) { + + private val modelRunner = ModelRunner() + + override fun handleMessage(msg: Message) { + when (msg.what) { + MESSAGE_RUN_BENCHMARK -> { + modelRunner.runBenchmark( + activity.model, + activity.numWarmupIter, + activity.numIter, + activity.results, + ) + if (activity.tokenizerPath == null) { + activity.writeResult() + } else { + sendEmptyMessage(MESSAGE_LLM_RUN_BENCHMARK) + } + } + MESSAGE_LLM_RUN_BENCHMARK -> { + LlmBenchmark( + activity, + activity.model.path, + activity.tokenizerPath!!, + activity.prompt, + activity.temperature, + activity.results, + ) + } + } + } + + companion object { + const val MESSAGE_RUN_BENCHMARK = 1 + const val MESSAGE_LLM_RUN_BENCHMARK = 2 + } +} diff --git a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkMetric.java b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkMetric.java deleted file mode 100644 index 66ab50550a4..00000000000 --- a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkMetric.java +++ /dev/null @@ -1,74 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -package org.pytorch.minibench; - -import android.app.ActivityManager; -import android.os.Build; -import java.util.regex.Matcher; -import java.util.regex.Pattern; - -class BenchmarkMetric { - public static class BenchmarkModel { - // The model name, i.e. stories110M - String name; - String backend; - String quantization; - - public BenchmarkModel(final String name, final String backend, final String quantization) { - this.name = name; - this.backend = backend; - this.quantization = quantization; - } - } - - BenchmarkModel benchmarkModel; - - // The metric name, i.e. TPS - String metric; - - // The actual value and the option target value - double actualValue; - double targetValue; - - public static class DeviceInfo { - // Let's see which information we want to include here - final String device = Build.BRAND; - // The phone model and Android release version - final String arch = Build.MODEL; - final String os = "Android " + Build.VERSION.RELEASE; - final long totalMem = new ActivityManager.MemoryInfo().totalMem; - final long availMem = new ActivityManager.MemoryInfo().availMem; - } - - DeviceInfo deviceInfo = new DeviceInfo(); - - public BenchmarkMetric( - final BenchmarkModel benchmarkModel, - final String metric, - final double actualValue, - final double targetValue) { - this.benchmarkModel = benchmarkModel; - this.metric = metric; - this.actualValue = actualValue; - this.targetValue = targetValue; - } - - // TODO (huydhn): Figure out a way to extract the backend and quantization information from - // the .pte model itself instead of parsing its name - public static BenchmarkMetric.BenchmarkModel extractBackendAndQuantization(final String model) { - final Matcher m = - Pattern.compile("(?\\w+)_(?[\\w\\+]+)_(?\\w+)").matcher(model); - if (m.matches()) { - return new BenchmarkMetric.BenchmarkModel( - m.group("name"), m.group("backend"), m.group("quantization")); - } else { - return new BenchmarkMetric.BenchmarkModel(model, "", ""); - } - } -} diff --git a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkMetric.kt b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkMetric.kt new file mode 100644 index 00000000000..7bed1ab05c0 --- /dev/null +++ b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkMetric.kt @@ -0,0 +1,54 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.minibench + +import android.app.ActivityManager +import android.os.Build + +class BenchmarkMetric( + val benchmarkModel: BenchmarkModel, + val metric: String, + val actualValue: Double, + val targetValue: Double, +) { + data class BenchmarkModel( + val name: String, + val backend: String, + val quantization: String, + ) + + class DeviceInfo { + val device: String = Build.BRAND + val arch: String = Build.MODEL + val os: String = "Android ${Build.VERSION.RELEASE}" + val totalMem: Long = ActivityManager.MemoryInfo().totalMem + val availMem: Long = ActivityManager.MemoryInfo().availMem + } + + val deviceInfo: DeviceInfo = DeviceInfo() + + companion object { + // TODO (huydhn): Figure out a way to extract the backend and quantization information from + // the .pte model itself instead of parsing its name + @JvmStatic + fun extractBackendAndQuantization(model: String): BenchmarkModel { + val pattern = Regex("(?\\w+)_(?[\\w+]+)_(?\\w+)") + val match = pattern.matchEntire(model) + return if (match != null) { + BenchmarkModel( + match.groups["name"]!!.value, + match.groups["backend"]!!.value, + match.groups["quantization"]!!.value, + ) + } else { + BenchmarkModel(model, "", "") + } + } + } +} diff --git a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmark.java b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmark.java deleted file mode 100644 index 0c0436d2676..00000000000 --- a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmark.java +++ /dev/null @@ -1,123 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -package org.pytorch.minibench; - -import android.util.Log; -import java.util.List; -import org.json.JSONException; -import org.json.JSONObject; - -public class LlmBenchmark implements LlmModelRunnerCallback { - LlmModelRunner mLlmModelRunner; - - String mPrompt; - StatsInfo mStatsInfo; - - List mResults; - BenchmarkActivity mActivity; - - LlmBenchmark( - BenchmarkActivity activity, - String modelFile, - String tokenizerPath, - String prompt, - float temperature, - List results) { - mResults = results; - mActivity = activity; - mStatsInfo = new StatsInfo(); - mStatsInfo.modelName = modelFile.substring(modelFile.lastIndexOf('/') + 1).replace(".pte", ""); - mPrompt = prompt; - mLlmModelRunner = new LlmModelRunner(modelFile, tokenizerPath, temperature, this); - mStatsInfo.loadStart = System.nanoTime(); - } - - @Override - public void onModelLoaded(int status) { - mStatsInfo.loadEnd = System.nanoTime(); - mStatsInfo.loadStatus = status; - if (status != 0) { - Log.e("LlmBenchmarkRunner", "Loaded failed: " + status); - onGenerationStopped(); - return; - } - mStatsInfo.generateStart = System.nanoTime(); - mLlmModelRunner.generate(mPrompt); - } - - @Override - public void onTokenGenerated(String token) {} - - @Override - public void onStats(String stats) { - float tps = 0; - try { - JSONObject jsonObject = new JSONObject(stats); - int numGeneratedTokens = jsonObject.getInt("generated_tokens"); - int inferenceEndMs = jsonObject.getInt("inference_end_ms"); - int promptEvalEndMs = jsonObject.getInt("prompt_eval_end_ms"); - tps = (float) numGeneratedTokens / (inferenceEndMs - promptEvalEndMs) * 1000; - mStatsInfo.tps = tps; - } catch (JSONException e) { - Log.e("LLM", "Error parsing JSON: " + e.getMessage()); - } - } - - @Override - public void onGenerationStopped() { - mStatsInfo.generateEnd = System.nanoTime(); - - final BenchmarkMetric.BenchmarkModel benchmarkModel = - BenchmarkMetric.extractBackendAndQuantization(mStatsInfo.modelName); - // The list of metrics we have atm includes: - // Load status - mResults.add(new BenchmarkMetric(benchmarkModel, "load_status", mStatsInfo.loadStatus, 0)); - // Model load time - mResults.add( - new BenchmarkMetric( - benchmarkModel, - "llm_model_load_time(ms)", - (mStatsInfo.loadEnd - mStatsInfo.loadStart) * 1e-6, - 0.0f)); - // LLM generate time - mResults.add( - new BenchmarkMetric( - benchmarkModel, - "generate_time(ms)", - (mStatsInfo.generateEnd - mStatsInfo.generateStart) * 1e-6, - 0.0f)); - // Token per second - mResults.add(new BenchmarkMetric(benchmarkModel, "token_per_sec", mStatsInfo.tps, 0.0f)); - mActivity.writeResult(); - } -} - -class StatsInfo { - int loadStatus; - long loadStart; - long loadEnd; - long generateStart; - long generateEnd; - float tps; - String modelName; - - @Override - public String toString() { - return "loadStart: " - + loadStart - + "\nloadEnd: " - + loadEnd - + "\ngenerateStart: " - + generateStart - + "\ngenerateEnd: " - + generateEnd - + "\n" - + tps; - } -} diff --git a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmark.kt b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmark.kt new file mode 100644 index 00000000000..5c75519f870 --- /dev/null +++ b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmark.kt @@ -0,0 +1,91 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.minibench + +import android.util.Log +import org.json.JSONException +import org.json.JSONObject + +class LlmBenchmark( + private val activity: BenchmarkActivity, + modelFile: String, + tokenizerPath: String, + private val prompt: String, + temperature: Float, + private val results: MutableList, +) : LlmModelRunnerCallback { + + private val runner: LlmModelRunner + private val statsInfo = StatsInfo() + + init { + statsInfo.modelName = modelFile.substringAfterLast('/').removeSuffix(".pte") + runner = LlmModelRunner(modelFile, tokenizerPath, temperature, this) + statsInfo.loadStart = System.nanoTime() + } + + override fun onModelLoaded(status: Int) { + statsInfo.loadEnd = System.nanoTime() + statsInfo.loadStatus = status + if (status != 0) { + Log.e("LlmBenchmarkRunner", "Loaded failed: $status") + onGenerationStopped() + return + } + statsInfo.generateStart = System.nanoTime() + runner.generate(prompt) + } + + override fun onTokenGenerated(token: String) {} + + override fun onStats(stats: String) { + try { + val json = JSONObject(stats) + val numGeneratedTokens = json.getInt("generated_tokens") + val inferenceEndMs = json.getInt("inference_end_ms") + val promptEvalEndMs = json.getInt("prompt_eval_end_ms") + statsInfo.tps = numGeneratedTokens.toFloat() / (inferenceEndMs - promptEvalEndMs) * 1000 + } catch (e: JSONException) { + Log.e("LLM", "Error parsing JSON: ${e.message}") + } + } + + override fun onGenerationStopped() { + statsInfo.generateEnd = System.nanoTime() + + val benchmarkModel = BenchmarkMetric.extractBackendAndQuantization(statsInfo.modelName) + results.add(BenchmarkMetric(benchmarkModel, "load_status", statsInfo.loadStatus.toDouble(), 0.0)) + results.add( + BenchmarkMetric( + benchmarkModel, + "llm_model_load_time(ms)", + (statsInfo.loadEnd - statsInfo.loadStart) * 1e-6, + 0.0, + )) + results.add( + BenchmarkMetric( + benchmarkModel, + "generate_time(ms)", + (statsInfo.generateEnd - statsInfo.generateStart) * 1e-6, + 0.0, + )) + results.add(BenchmarkMetric(benchmarkModel, "token_per_sec", statsInfo.tps.toDouble(), 0.0)) + activity.writeResult() + } +} + +private class StatsInfo { + var loadStatus: Int = 0 + var loadStart: Long = 0 + var loadEnd: Long = 0 + var generateStart: Long = 0 + var generateEnd: Long = 0 + var tps: Float = 0f + var modelName: String = "" +} diff --git a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmModelRunner.java b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmModelRunner.java deleted file mode 100644 index 3a345d3465b..00000000000 --- a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmModelRunner.java +++ /dev/null @@ -1,110 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -package org.pytorch.minibench; - -import android.os.Handler; -import android.os.HandlerThread; -import android.os.Looper; -import android.os.Message; -import android.util.Log; -import org.pytorch.executorch.extension.llm.LlmCallback; -import org.pytorch.executorch.extension.llm.LlmModule; - -/** A helper class to handle all model running logic within this class. */ -public class LlmModelRunner implements LlmCallback { - LlmModule mModule = null; - - String mModelFilePath = ""; - String mTokenizerFilePath = ""; - - LlmModelRunnerCallback mCallback = null; - - HandlerThread mHandlerThread = null; - Handler mHandler = null; - - /** - * ] Helper class to separate between UI logic and model runner logic. Automatically handle - * generate() request on worker thread. - * - * @param modelFilePath - * @param tokenizerFilePath - * @param callback - */ - LlmModelRunner( - String modelFilePath, - String tokenizerFilePath, - float temperature, - LlmModelRunnerCallback callback) { - mModelFilePath = modelFilePath; - mTokenizerFilePath = tokenizerFilePath; - mCallback = callback; - - mModule = new LlmModule(mModelFilePath, mTokenizerFilePath, 0.8f); - mHandlerThread = new HandlerThread("LlmModelRunner"); - mHandlerThread.start(); - mHandler = new LlmModelRunnerHandler(mHandlerThread.getLooper(), this); - - mHandler.sendEmptyMessage(LlmModelRunnerHandler.MESSAGE_LOAD_MODEL); - } - - int generate(String prompt) { - Message msg = Message.obtain(mHandler, LlmModelRunnerHandler.MESSAGE_GENERATE, prompt); - msg.sendToTarget(); - return 0; - } - - void stop() { - mModule.stop(); - } - - @Override - public void onResult(String result) { - mCallback.onTokenGenerated(result); - } - - @Override - public void onStats(String result) { - mCallback.onStats(result); - } -} - -class LlmModelRunnerHandler extends Handler { - public static int MESSAGE_LOAD_MODEL = 1; - public static int MESSAGE_GENERATE = 2; - - private final LlmModelRunner mLlmModelRunner; - - public LlmModelRunnerHandler(Looper looper, LlmModelRunner llmModelRunner) { - super(looper); - mLlmModelRunner = llmModelRunner; - } - - @Override - public void handleMessage(android.os.Message msg) { - if (msg.what == MESSAGE_LOAD_MODEL) { - int status = 0; - try { - mLlmModelRunner.mModule.load(); - } catch (Exception e) { - status = - (e instanceof org.pytorch.executorch.ExecutorchRuntimeException) - ? ((org.pytorch.executorch.ExecutorchRuntimeException) e).getErrorCode() - : -1; - } - mLlmModelRunner.mCallback.onModelLoaded(status); - } else if (msg.what == MESSAGE_GENERATE) { - try { - mLlmModelRunner.mModule.generate((String) msg.obj, mLlmModelRunner); - } catch (Exception e) { - Log.e("LlmModelRunner", "generate() failed", e); - } - mLlmModelRunner.mCallback.onGenerationStopped(); - } - } -} diff --git a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmModelRunner.kt b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmModelRunner.kt new file mode 100644 index 00000000000..29b9b177fb6 --- /dev/null +++ b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmModelRunner.kt @@ -0,0 +1,91 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.minibench + +import android.os.Handler +import android.os.HandlerThread +import android.os.Looper +import android.os.Message +import android.util.Log +import org.pytorch.executorch.ExecutorchRuntimeException +import org.pytorch.executorch.extension.llm.LlmCallback +import org.pytorch.executorch.extension.llm.LlmModule + +/** A helper class to handle all model running logic within this class. */ +class LlmModelRunner( + modelFilePath: String, + tokenizerFilePath: String, + temperature: Float, + val callback: LlmModelRunnerCallback, +) : LlmCallback { + + val module: LlmModule = LlmModule(modelFilePath, tokenizerFilePath, temperature) + private val handlerThread: HandlerThread = HandlerThread("LlmModelRunner") + private val handler: Handler + + init { + handlerThread.start() + handler = LlmModelRunnerHandler(handlerThread.looper, this) + handler.sendEmptyMessage(LlmModelRunnerHandler.MESSAGE_LOAD_MODEL) + } + + fun generate(prompt: String): Int { + val msg = Message.obtain(handler, LlmModelRunnerHandler.MESSAGE_GENERATE, prompt) + msg.sendToTarget() + return 0 + } + + fun stop() { + module.stop() + } + + override fun onResult(result: String) { + callback.onTokenGenerated(result) + } + + override fun onStats(stats: String) { + callback.onStats(stats) + } +} + +private class LlmModelRunnerHandler( + looper: Looper, + private val runner: LlmModelRunner, +) : Handler(looper) { + + override fun handleMessage(msg: Message) { + when (msg.what) { + MESSAGE_LOAD_MODEL -> { + val status = + try { + runner.module.load() + 0 + } catch (e: ExecutorchRuntimeException) { + e.errorCode + } catch (e: Exception) { + -1 + } + runner.callback.onModelLoaded(status) + } + MESSAGE_GENERATE -> { + try { + runner.module.generate(msg.obj as String, runner) + } catch (e: Exception) { + Log.e("LlmModelRunner", "generate() failed", e) + } + runner.callback.onGenerationStopped() + } + } + } + + companion object { + const val MESSAGE_LOAD_MODEL = 1 + const val MESSAGE_GENERATE = 2 + } +} diff --git a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.java b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.java deleted file mode 100644 index 915496a25af..00000000000 --- a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.java +++ /dev/null @@ -1,99 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -package org.pytorch.minibench; - -import android.os.Debug; -import java.io.File; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import org.pytorch.executorch.Module; - -public class ModelRunner { - /** - * @return list of #BenchmarkMetric - */ - public void runBenchmark( - File model, int numWarmupIter, int numIter, List results) { - long pssIdle = Debug.getPss(); - - List latency = new ArrayList<>(); - - long loadStart = System.nanoTime(); - Module module = Module.load(model.getPath()); - int errorCode = 0; - try { - module.loadMethod("forward"); - } catch (Exception e) { - errorCode = - (e instanceof org.pytorch.executorch.ExecutorchRuntimeException) - ? ((org.pytorch.executorch.ExecutorchRuntimeException) e).getErrorCode() - : -1; - } - long loadEnd = System.nanoTime(); - - final BenchmarkMetric.BenchmarkModel benchmarkModel = - BenchmarkMetric.extractBackendAndQuantization(model.getName().replace(".pte", "")); - - if (errorCode != 0) { - results.add( - new BenchmarkMetric( - benchmarkModel, "model_load_time(ms)", (loadEnd - loadStart) * 1e-6, 0.0f)); - results.add(new BenchmarkMetric(benchmarkModel, "load_status", errorCode, 0)); - module.destroy(); - return; - } - - try { - for (int i = 0; i < numWarmupIter; i++) { - module.forward(); - } - - for (int i = 0; i < numIter; i++) { - long start = System.nanoTime(); - module.forward(); - double forwardMs = (System.nanoTime() - start) * 1e-6; - latency.add(forwardMs); - } - - module.etdump(); - - // Currently the result has large variance from outliers, so only use - // 80% samples in the middle (trimmean 0.2) - Collections.sort(latency); - int resultSize = latency.size(); - List usedLatencyResults = latency.subList(resultSize / 10, resultSize * 9 / 10); - - results.add( - new BenchmarkMetric( - benchmarkModel, - "avg_inference_latency(ms)", - latency.stream().mapToDouble(l -> l).average().orElse(0.0f), - 0.0f)); - results.add( - new BenchmarkMetric( - benchmarkModel, - "trimmean_inference_latency(ms)", - usedLatencyResults.stream().mapToDouble(l -> l).average().orElse(0.0f), - 0.0f)); - // Model load time - results.add( - new BenchmarkMetric( - benchmarkModel, "model_load_time(ms)", (loadEnd - loadStart) * 1e-6, 0.0f)); - // Load status - results.add(new BenchmarkMetric(benchmarkModel, "load_status", errorCode, 0)); - // RAM PSS usage - results.add( - new BenchmarkMetric( - benchmarkModel, "ram_pss_usage(mb)", (Debug.getPss() - pssIdle) / 1024, 0)); - } finally { - module.destroy(); - } - } -} diff --git a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.kt b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.kt new file mode 100644 index 00000000000..0f292b0d900 --- /dev/null +++ b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.kt @@ -0,0 +1,90 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.minibench + +import android.os.Debug +import java.io.File +import org.pytorch.executorch.ExecutorchRuntimeException +import org.pytorch.executorch.Module + +class ModelRunner { + + fun runBenchmark( + model: File, + numWarmupIter: Int, + numIter: Int, + results: MutableList, + ) { + val pssIdle = Debug.getPss() + val latency = mutableListOf() + + val loadStart = System.nanoTime() + val module = Module.load(model.path) + var errorCode = 0 + try { + module.loadMethod("forward") + } catch (e: ExecutorchRuntimeException) { + errorCode = e.errorCode + } catch (e: Exception) { + errorCode = -1 + } + val loadEnd = System.nanoTime() + + val benchmarkModel = + BenchmarkMetric.extractBackendAndQuantization(model.name.removeSuffix(".pte")) + + if (errorCode != 0) { + results.add( + BenchmarkMetric(benchmarkModel, "model_load_time(ms)", (loadEnd - loadStart) * 1e-6, 0.0)) + results.add(BenchmarkMetric(benchmarkModel, "load_status", errorCode.toDouble(), 0.0)) + module.destroy() + return + } + + try { + repeat(numWarmupIter) { module.forward() } + + repeat(numIter) { + val start = System.nanoTime() + module.forward() + latency.add((System.nanoTime() - start) * 1e-6) + } + + module.etdump() + + // Currently the result has large variance from outliers, so only use + // 80% samples in the middle (trimmean 0.2) + latency.sort() + val trimmed = latency.subList(latency.size / 10, latency.size * 9 / 10) + + results.add( + BenchmarkMetric( + benchmarkModel, + "avg_inference_latency(ms)", + latency.average(), + 0.0, + )) + results.add( + BenchmarkMetric( + benchmarkModel, + "trimmean_inference_latency(ms)", + trimmed.average(), + 0.0, + )) + results.add( + BenchmarkMetric(benchmarkModel, "model_load_time(ms)", (loadEnd - loadStart) * 1e-6, 0.0)) + results.add(BenchmarkMetric(benchmarkModel, "load_status", errorCode.toDouble(), 0.0)) + results.add( + BenchmarkMetric( + benchmarkModel, "ram_pss_usage(mb)", (Debug.getPss() - pssIdle) / 1024.0, 0.0)) + } finally { + module.destroy() + } + } +} diff --git a/extension/benchmark/android/benchmark/app/src/test/java/org/pytorch/minibench/ExampleUnitTest.java b/extension/benchmark/android/benchmark/app/src/test/java/org/pytorch/minibench/ExampleUnitTest.kt similarity index 55% rename from extension/benchmark/android/benchmark/app/src/test/java/org/pytorch/minibench/ExampleUnitTest.java rename to extension/benchmark/android/benchmark/app/src/test/java/org/pytorch/minibench/ExampleUnitTest.kt index c6a6a76a4d8..b98a49e4bf9 100644 --- a/extension/benchmark/android/benchmark/app/src/test/java/org/pytorch/minibench/ExampleUnitTest.java +++ b/extension/benchmark/android/benchmark/app/src/test/java/org/pytorch/minibench/ExampleUnitTest.kt @@ -6,20 +6,19 @@ * LICENSE file in the root directory of this source tree. */ -package org.pytorch.minibench; +package org.pytorch.minibench -import static org.junit.Assert.*; - -import org.junit.Test; +import org.junit.Assert.assertEquals +import org.junit.Test /** * Example local unit test, which will execute on the development machine (host). * - * @see Testing documentation + * @see [Testing documentation](http://d.android.com/tools/testing) */ -public class ExampleUnitTest { +class ExampleUnitTest { @Test - public void addition_isCorrect() { - assertEquals(4, 2 + 2); + fun addition_isCorrect() { + assertEquals(4, 2 + 2) } } diff --git a/extension/flat_tensor/flat_tensor_data_map.cpp b/extension/flat_tensor/flat_tensor_data_map.cpp index 48684da1239..845778f45c2 100644 --- a/extension/flat_tensor/flat_tensor_data_map.cpp +++ b/extension/flat_tensor/flat_tensor_data_map.cpp @@ -21,6 +21,8 @@ #include #include +#include + using executorch::runtime::Error; using executorch::runtime::FreeableBuffer; using executorch::runtime::Result; @@ -52,7 +54,7 @@ Result get_named_data( flatbuffers::Offset>* named_data, const flatbuffers::Vector< flatbuffers::Offset>* segments, - size_t segment_end_offset) { + uint64_t segment_end_offset) { // Linear search by name. if (named_data == nullptr) { return Error::NotFound; @@ -81,19 +83,34 @@ Result get_named_data( static_cast(segments->Get(segment_index)->offset()), static_cast(segments->Get(segment_index)->size()), &seg_end) && - seg_end <= static_cast(segment_end_offset), + seg_end <= segment_end_offset, InvalidExternalData, "Invalid segment offset %" PRIu64 " is larger than the segment_base_offset + segment_data_size %" PRIu64 "; malformed PTD file.", segments->Get(segment_index)->offset(), - static_cast(segment_end_offset)); + segment_end_offset); return found; } } return Error::NotFound; } +Result get_segment_end_offset(const FlatTensorHeader& header) { + uint64_t segment_end_offset = 0; + ET_CHECK_OR_RETURN_ERROR( + !c10::add_overflows( + header.segment_base_offset, + header.segment_data_size, + &segment_end_offset), + InvalidExternalData, + "segment_base_offset %" PRIu64 " + segment_data_size %" PRIu64 + " overflows uint64_t; malformed PTD file.", + header.segment_base_offset, + header.segment_data_size); + return segment_end_offset; +} + Result create_tensor_layout( const flat_tensor_flatbuffer::TensorLayout* tensor_layout) { ScalarType scalar_type = @@ -111,11 +128,15 @@ Result create_tensor_layout( ET_NODISCARD Result FlatTensorDataMap::get_tensor_layout( executorch::aten::string_view key) const { + Result segment_end_offset = get_segment_end_offset(header_); + if (!segment_end_offset.ok()) { + return segment_end_offset.error(); + } Result named_data = get_named_data( key, flat_tensor_->named_data(), flat_tensor_->segments(), - header_.segment_base_offset + header_.segment_data_size); + segment_end_offset.get()); if (!named_data.ok()) { return named_data.error(); } @@ -124,11 +145,15 @@ ET_NODISCARD Result FlatTensorDataMap::get_tensor_layout( ET_NODISCARD Result FlatTensorDataMap::get_data( executorch::aten::string_view key) const { + Result segment_end_offset = get_segment_end_offset(header_); + if (!segment_end_offset.ok()) { + return segment_end_offset.error(); + } Result named_data = get_named_data( key, flat_tensor_->named_data(), flat_tensor_->segments(), - header_.segment_base_offset + header_.segment_data_size); + segment_end_offset.get()); if (!named_data.ok()) { return named_data.error(); } @@ -148,11 +173,15 @@ ET_NODISCARD Error FlatTensorDataMap::load_data_into( ET_UNUSED executorch::aten::string_view key, ET_UNUSED void* buffer, ET_UNUSED size_t size) const { + Result segment_end_offset = get_segment_end_offset(header_); + if (!segment_end_offset.ok()) { + return segment_end_offset.error(); + } Result named_data = get_named_data( key, flat_tensor_->named_data(), flat_tensor_->segments(), - header_.segment_base_offset + header_.segment_data_size); + segment_end_offset.get()); if (!named_data.ok()) { return named_data.error(); } diff --git a/extension/llm/custom_ops/model_sharding.py b/extension/llm/custom_ops/model_sharding.py index 6838b0958a2..916b13a90b8 100644 --- a/extension/llm/custom_ops/model_sharding.py +++ b/extension/llm/custom_ops/model_sharding.py @@ -7,8 +7,9 @@ import re from typing import List -import torch +import executorch.extension.llm.custom_ops.op_fallback # noqa: F401 +import torch from executorch.backends.qualcomm.utils.constants import ( QCOM_PASS_ACTIVATE_KEY, QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY, @@ -17,27 +18,6 @@ from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult from torch.export.exported_program import ExportedProgram -from torch.library import impl, Library - - -fallback_op_lib = Library("llama", "DEF") -# registering an operator. -fallback_op_lib.define("fallback(Tensor input) -> Tensor") - - -@impl(fallback_op_lib, "fallback") -def fallback_impl(a: torch.Tensor) -> torch.Tensor: - return a - - -# registering the out variant. -fallback_op_lib.define("fallback.out(Tensor input, *, Tensor(a!) output) -> Tensor(a!)") - - -@impl(fallback_op_lib, "fallback.out") -def fallback_out_impl(a: torch.Tensor, *, out: torch.Tensor) -> torch.Tensor: - out.copy_(a) - return out class SplitGraph(ExportPass): diff --git a/extension/llm/custom_ops/op_fallback.py b/extension/llm/custom_ops/op_fallback.py new file mode 100644 index 00000000000..e94c81db51a --- /dev/null +++ b/extension/llm/custom_ops/op_fallback.py @@ -0,0 +1,29 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# pyre-ignore-all-errors + +import torch + +from torch.library import impl, Library + +fallback_op_lib = Library("llama", "DEF") +# registering an operator. +fallback_op_lib.define("fallback(Tensor input) -> Tensor") + + +@impl(fallback_op_lib, "fallback") +def fallback_impl(a: torch.Tensor) -> torch.Tensor: + return a + + +# registering the out variant. +fallback_op_lib.define("fallback.out(Tensor input, *, Tensor(a!) output) -> Tensor(a!)") + + +@impl(fallback_op_lib, "fallback.out") +def fallback_out_impl(a: torch.Tensor, *, out: torch.Tensor) -> torch.Tensor: + out.copy_(a) + return out diff --git a/extension/llm/custom_ops/targets.bzl b/extension/llm/custom_ops/targets.bzl index 6746d7ab877..1d1feeda0c1 100644 --- a/extension/llm/custom_ops/targets.bzl +++ b/extension/llm/custom_ops/targets.bzl @@ -141,6 +141,7 @@ def define_common_targets(): name = "model_sharding_py", srcs = [ "model_sharding.py", + "op_fallback.py", ], visibility = ["PUBLIC"], deps = [ diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py index c25c1190990..5928e40dc4d 100644 --- a/extension/llm/export/builder.py +++ b/extension/llm/export/builder.py @@ -256,6 +256,35 @@ def run_canonical_optimizations(self): assert res.graph_module is not None, "Pass returned None" self.pre_autograd_graph_module = res.graph_module + def _check_calibration_prefix_options(self) -> None: + if ( + not self.use_kv_cache + and not self.enable_dynamic_shape + and not self.generate_full_logits + ): + raise ValueError( + "Static non-KV calibration with padded prefixes requires " + "generate_full_logits so calibration can sample the last " + "non-pad token position." + ) + + def _prepare_calibration_prefix( + self, token_list: List[int], pos: int, max_len: int, pad_token: int + ) -> Tuple[torch.Tensor, int]: + prefix_tokens = list(token_list[: pos + 1]) + logits_token_pos = min(len(prefix_tokens), max_len) - 1 + + if self.enable_dynamic_shape: + prefix_tokens = prefix_tokens[:max_len] + elif len(prefix_tokens) < max_len: + prefix_tokens.extend([pad_token] * (max_len - len(prefix_tokens))) + else: + prefix_tokens = prefix_tokens[:max_len] + + input_dtype = self.example_inputs[0].dtype + prefix = torch.tensor(prefix_tokens, dtype=input_dtype).unsqueeze(0) + return prefix, logits_token_pos + def pt2e_calibrate( self, prepared_module, @@ -266,39 +295,41 @@ def pt2e_calibrate( tokenizer_path, ): logging.info("Run calibration...") - try: - from executorch.examples.models.llama.eval_llama_lib import ( - GraphModuleEvalWrapper, - ) - from lm_eval.evaluator import simple_evaluate - except ImportError: - raise ImportError( - "Please install the llm eval dependency via examples/models/llama/install_requirements.sh" - ) - + self._check_calibration_prefix_options() tokenizer = get_tokenizer(tokenizer_path) def calibrate_template( module: torch.fx.GraphModule, tokenizer, prompts: str, max_len: int ): # TODO: change criteria & support batch inputs if necessary - pos = torch.tensor(0, dtype=torch.int64) + pos = 0 token_list = tokenizer.encode(prompts, bos=True, eos=False) + pad_token = getattr(tokenizer, "pad_id", tokenizer.eos_id) + with torch.no_grad(): while token_list[-1] != tokenizer.eos_id and pos < max_len: - logits = module( - torch.full((1, 1), token_list[pos]), - {"input_pos": torch.tensor((pos,))}, - ) + logits_token_pos = -1 + if self.use_kv_cache: + logits = module( + torch.full((1, 1), token_list[pos]), + {"input_pos": torch.tensor((pos,))}, + ) + else: + prefix, logits_token_pos = self._prepare_calibration_prefix( + token_list, pos, max_len, pad_token + ) + logits = module(prefix) + pos += 1 if pos >= len(token_list): if self.generate_full_logits: - token_list.append( - torch.argmax(logits[:, -1], dim=-1).item() - ) + next_token = torch.argmax( + logits[:, logits_token_pos], dim=-1 + ).item() else: - token_list.append(torch.argmax(logits[:], dim=-1).item()) + next_token = torch.argmax(logits[:], dim=-1).item() + token_list.append(next_token) calibrate_template( module=prepared_module, @@ -307,26 +338,41 @@ def calibrate_template( max_len=calibration_seq_length, ) - eval_wrapper = GraphModuleEvalWrapper( - model=prepared_module, - tokenizer=tokenizer, - max_seq_length=calibration_seq_length, - use_kv_cache=self.use_kv_cache, - generate_full_logits=self.generate_full_logits, - enable_dynamic_shape=self.enable_dynamic_shape, - ) + if calibration_tasks: + try: + from executorch.examples.models.llama.eval_llama_lib import ( + GraphModuleEvalWrapper, + ) + from lm_eval.evaluator import simple_evaluate + except ImportError: + raise ImportError( + "Please install the llm eval dependency via examples/models/llama/install_requirements.sh" + ) - # Evaluate the model - with torch.no_grad(): - eval_results = simple_evaluate( - model=eval_wrapper, - tasks=calibration_tasks, - limit=calibration_limit, + eval_wrapper = GraphModuleEvalWrapper( + model=prepared_module, + tokenizer=tokenizer, + max_seq_length=calibration_seq_length, + use_kv_cache=self.use_kv_cache, + generate_full_logits=self.generate_full_logits, + enable_dynamic_shape=self.enable_dynamic_shape, + # The exported graph can contain ops like aten.full.default + # without explicit device, which default to CPU and can + # trigger device-mismatch errors when lm_eval runs on CUDA. + # Calibrate on CPU for stability. + device="cpu", ) - for task, res in eval_results["results"].items(): - print(f"{task}: {res}") - logging.info("Calibration finish...") + with torch.no_grad(): + eval_results = simple_evaluate( + model=eval_wrapper, + tasks=calibration_tasks, + limit=calibration_limit, + ) + + for task, res in eval_results["results"].items(): + print(f"{task}: {res}") + logging.info("Calibration finish...") def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManager": """ @@ -351,18 +397,19 @@ def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManage assert ( self.pre_autograd_graph_module is not None ), "Please run export() first" + if self.calibration_tasks and self.calibration_limit is None: + logging.warning( + "calibration_tasks provided without calibration_limit; " + "lm-eval will run the full task dataset during " + "calibration." + ) m = prepare_pt2e( self.pre_autograd_graph_module, # pyre-ignore[6] composed_quantizer, ) - logging.info( - f"Calibrating with tasks: {self.calibration_tasks}, limit: {self.calibration_limit}, calibration_data: {self.calibration_data}, tokenizer_path: {self.tokenizer_path}, seq_length: {self.calibration_seq_length}" - ) # Calibrate if ( - self.calibration_tasks is not None - and self.calibration_limit is not None - and self.calibration_seq_length is not None + self.calibration_seq_length is not None and self.calibration_data is not None and self.tokenizer_path is not None ): diff --git a/extension/module/module.cpp b/extension/module/module.cpp index 5422fb15b71..11fea031603 100644 --- a/extension/module/module.cpp +++ b/extension/module/module.cpp @@ -13,6 +13,7 @@ #include #include #include +#include #include namespace executorch { @@ -367,6 +368,51 @@ Module::make_planned_memory_with_shared_arenas( return planned; } +std::unique_ptr Module::make_planned_memory_with_devices( + const ET_RUNTIME_NAMESPACE::MethodMeta& method_meta) { + auto planned = std::make_unique(); + const size_t num_buffers = method_meta.num_memory_planned_buffers(); + planned->planned_buffers.reserve(num_buffers); + planned->planned_spans.reserve(num_buffers); + planned->device_buffers.reserve(num_buffers); + planned->planned_devices.reserve(num_buffers); + + for (size_t i = 0; i < num_buffers; ++i) { + auto size = method_meta.memory_planned_buffer_size(i); + ET_CHECK_MSG(size.ok(), "Failed to get buffer size for index %zu", i); + auto device = method_meta.memory_planned_buffer_device(i); + ET_CHECK_MSG(device.ok(), "Failed to get buffer device for index %zu", i); + planned->planned_devices.push_back(device.get()); + + if (device->is_cpu()) { + planned->planned_buffers.emplace_back(size.get()); + planned->planned_spans.emplace_back( + planned->planned_buffers.back().data(), size.get()); + } else { + // Allocate device memory via DeviceAllocator and store the RAII buffer. + planned->planned_buffers.emplace_back(); // empty CPU placeholder + auto dmb = runtime::DeviceMemoryBuffer::create( + size.get(), device->type(), device->index()); + ET_CHECK_MSG( + dmb.ok(), + "Failed to allocate device memory for buffer %zu (device_type=%d)", + i, + static_cast(device->type())); + planned->planned_spans.emplace_back(dmb->as_span()); + planned->device_buffers.push_back(std::move(dmb.get())); + } + } + + // HierarchicalAllocator owns the per-buffer Device metadata so the + // MemoryManager can later expose it via planned_buffer_devices(). + planned->planned_memory = std::make_unique( + runtime::Span>( + planned->planned_spans.data(), planned->planned_spans.size()), + runtime::Span( + planned->planned_devices.data(), planned->planned_devices.size())); + return planned; +} + runtime::Result> Module::get_mem_planned_buffer_sizes( const std::string& method_name) { auto meta_res = program_->method_meta(method_name.c_str()); @@ -422,10 +468,38 @@ runtime::Error Module::load_method( MethodHolder method_holder; if (!planned_memory) { - if (!share_memory_arenas_) { + // Check if any buffers need device memory allocation. + auto meta_res = program_->method_meta(method_name.c_str()); + ET_CHECK_OK_OR_RETURN_ERROR(meta_res.error()); + auto& meta = meta_res.get(); + + bool has_device_buffers = false; + for (size_t i = 0; i < meta.num_memory_planned_buffers(); ++i) { + auto dev = meta.memory_planned_buffer_device(i); + if (dev.ok() && !dev->is_cpu()) { + has_device_buffers = true; + break; + } + } + + if (has_device_buffers) { + // Device memory with shared arenas is not yet supported. + ET_CHECK_OR_RETURN_ERROR( + !share_memory_arenas_, + NotSupported, + "Device memory buffers are not yet compatible with " + "share_memory_arenas. Please disable share_memory_arenas " + "when using models with device-planned memory."); + + // Device-aware path: allocate CPU and device buffers. The device + // span is owned by the HierarchicalAllocator inside PlannedMemory. + method_holder.planned_memory = make_planned_memory_with_devices(meta); + planned_memory = method_holder.planned_memory->planned_memory.get(); + } else if (!share_memory_arenas_) { auto sizes_res = get_mem_planned_buffer_sizes(method_name); ET_CHECK_OK_OR_RETURN_ERROR(sizes_res.error()); method_holder.planned_memory = make_planned_memory(sizes_res.get()); + planned_memory = method_holder.planned_memory->planned_memory.get(); } else { auto sizes_res = get_mem_planned_buffer_sizes(method_name); ET_CHECK_OK_OR_RETURN_ERROR(sizes_res.error()); @@ -442,8 +516,8 @@ runtime::Error Module::load_method( } method_holder.planned_memory = make_planned_memory_with_shared_arenas(sizes, shared_arenas_); + planned_memory = method_holder.planned_memory->planned_memory.get(); } - planned_memory = method_holder.planned_memory->planned_memory.get(); } method_holder.memory_manager = std::make_unique( diff --git a/extension/module/module.h b/extension/module/module.h index 47ead23032e..91c7feaad9b 100644 --- a/extension/module/module.h +++ b/extension/module/module.h @@ -18,6 +18,8 @@ #include #include +#include + #ifdef USE_ATEN_LIB #define ET_MODULE_NAMESPACE module::aten #else // !USE_ATEN_LIB @@ -716,6 +718,11 @@ class Module { struct PlannedMemory { std::vector> planned_buffers; std::vector> planned_spans; + std::vector device_buffers; + /// Per-buffer Device (type + index) metadata used by + /// HierarchicalAllocator. Owns the storage backing the device span the + /// allocator references, so it must outlive `planned_memory`. + std::vector planned_devices; std::unique_ptr planned_memory; }; std::unique_ptr make_planned_memory( @@ -723,6 +730,8 @@ class Module { std::unique_ptr make_planned_memory_with_shared_arenas( const std::vector& buffer_sizes, std::vector>& shared_arenas); + std::unique_ptr make_planned_memory_with_devices( + const ET_RUNTIME_NAMESPACE::MethodMeta& method_meta); runtime::Result> get_mem_planned_buffer_sizes( const std::string& method_name); runtime::Result> get_max_mem_planned_buffer_sizes(); diff --git a/extension/module/targets.bzl b/extension/module/targets.bzl index fa80203831a..e622b138ff6 100644 --- a/extension/module/targets.bzl +++ b/extension/module/targets.bzl @@ -30,6 +30,7 @@ def define_common_targets(): "//executorch/runtime/backend:backend_options", "//executorch/runtime/backend:backend_options_map", "//executorch/runtime/executor:program_no_prim_ops" + aten_suffix, + "//executorch/runtime/core:device_memory_buffer", ], ) diff --git a/extension/module/test/module_device_memory_test.cpp b/extension/module/test/module_device_memory_test.cpp new file mode 100644 index 00000000000..5031273ac2b --- /dev/null +++ b/extension/module/test/module_device_memory_test.cpp @@ -0,0 +1,218 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/** + * Tests that Module's device-aware memory allocation path works correctly. + * + * Uses ModuleAddWithDevice.pte which has: + * non_const_buffer_sizes: [0, 48] (1 buffer, index 0 reserved) + * non_const_buffer_device: [{buffer_idx=1, device_type=CUDA, device_index=0}] + * + * Since we don't have a real CUDA backend, we test that: + * 1. CPU-only models load through Module without invoking device allocator + * 2. Device-annotated models trigger DeviceMemoryBuffer::create via a mock + */ + +#include + +#include + +#include +#include +#include + +using executorch::extension::Module; +using executorch::runtime::DeviceAllocator; +using executorch::runtime::DeviceMemoryBuffer; +using executorch::runtime::Error; +using executorch::runtime::register_device_allocator; +using executorch::runtime::Result; +using executorch::runtime::etensor::DeviceIndex; +using executorch::runtime::etensor::DeviceType; + +namespace { + +class MockCudaAllocator : public DeviceAllocator { + public: + Result allocate( + size_t nbytes, + DeviceIndex index, + size_t alignment = kDefaultAlignment) override { + (void)alignment; + allocate_count_++; + last_allocate_size_ = nbytes; + last_allocate_index_ = index; + buffer_ = std::make_unique(nbytes); + return static_cast(buffer_.get()); + } + + void deallocate(void* ptr, DeviceIndex index) override { + deallocate_count_++; + buffer_.reset(); + } + + Error copy_host_to_device(void*, const void*, size_t, DeviceIndex) override { + return Error::Ok; + } + + Error copy_device_to_host(void*, const void*, size_t, DeviceIndex) override { + return Error::Ok; + } + + DeviceType device_type() const override { + return DeviceType::CUDA; + } + + int allocate_count_ = 0; + int deallocate_count_ = 0; + size_t last_allocate_size_ = 0; + DeviceIndex last_allocate_index_ = -1; + + private: + std::unique_ptr buffer_; +}; + +} // namespace + +static MockCudaAllocator g_mock_cuda; + +class ModuleDeviceMemoryTest : public ::testing::Test { + protected: + static void SetUpTestSuite() { + executorch::runtime::runtime_init(); + register_device_allocator(&g_mock_cuda); + } + + void SetUp() override { + g_mock_cuda.allocate_count_ = 0; + g_mock_cuda.deallocate_count_ = 0; + g_mock_cuda.last_allocate_size_ = 0; + g_mock_cuda.last_allocate_index_ = -1; + } +}; + +TEST_F(ModuleDeviceMemoryTest, CpuOnlyModelDoesNotAllocateDeviceMemory) { + const char* path = std::getenv("ET_MODULE_ADD_PATH"); + ASSERT_NE(path, nullptr) << "ET_MODULE_ADD_PATH not set"; + + Module module(path); + auto err = module.load_method("forward"); + ASSERT_EQ(err, Error::Ok); + + EXPECT_EQ(g_mock_cuda.allocate_count_, 0) + << "CPU-only model should not allocate device memory"; +} + +TEST_F(ModuleDeviceMemoryTest, DeviceMemoryBufferCreateCallsAllocator) { + // Directly test DeviceMemoryBuffer::create with the registered mock. + // This verifies the RAII allocation/deallocation path that Module uses. + { + auto result = DeviceMemoryBuffer::create(48, DeviceType::CUDA, 0); + ASSERT_TRUE(result.ok()); + auto buf = std::move(result.get()); + + EXPECT_EQ(g_mock_cuda.allocate_count_, 1); + EXPECT_EQ(g_mock_cuda.last_allocate_size_, 48); + EXPECT_EQ(g_mock_cuda.last_allocate_index_, 0); + EXPECT_NE(buf.data(), nullptr); + EXPECT_EQ(buf.size(), 48); + + // as_span() wraps the device pointer for HierarchicalAllocator. + auto span = buf.as_span(); + EXPECT_EQ(span.data(), static_cast(buf.data())); + EXPECT_EQ(span.size(), 48); + + EXPECT_EQ(g_mock_cuda.deallocate_count_, 0); + } + // RAII deallocation on scope exit. + EXPECT_EQ(g_mock_cuda.deallocate_count_, 1); +} + +TEST_F(ModuleDeviceMemoryTest, DeviceModelMethodMetaReportsCudaBuffer) { + // Verify MethodMeta reports the correct device for buffers in the + // device-annotated model, without needing to load the full method. + const char* path = std::getenv("ET_MODULE_ADD_WITH_DEVICE_PATH"); + ASSERT_NE(path, nullptr) << "ET_MODULE_ADD_WITH_DEVICE_PATH not set"; + + Module module(path); + auto err = module.load(); + ASSERT_EQ(err, Error::Ok); + + auto meta = module.method_meta("forward"); + ASSERT_TRUE(meta.ok()); + + // ModuleAddWithDevice has 1 planned buffer (48 bytes) on CUDA. + ASSERT_EQ(meta->num_memory_planned_buffers(), 1); + + auto size = meta->memory_planned_buffer_size(0); + ASSERT_TRUE(size.ok()); + EXPECT_EQ(size.get(), 48); + + auto device = meta->memory_planned_buffer_device(0); + ASSERT_TRUE(device.ok()); + EXPECT_EQ(device->type(), DeviceType::CUDA); + EXPECT_EQ(device->index(), 0); +} + +TEST_F(ModuleDeviceMemoryTest, DeviceModelWithSharedArenasReturnsNotSupported) { + const char* path = std::getenv("ET_MODULE_ADD_WITH_DEVICE_PATH"); + ASSERT_NE(path, nullptr) << "ET_MODULE_ADD_WITH_DEVICE_PATH not set"; + + // share_memory_arenas = true with a device-annotated model should fail. + Module module( + path, + Module::LoadMode::File, + /*event_tracer=*/nullptr, + /*memory_allocator=*/nullptr, + /*temp_allocator=*/nullptr, + /*share_memory_arenas=*/true); + + auto err = module.load_method("forward"); + EXPECT_EQ(err, Error::NotSupported); +} + +TEST_F( + ModuleDeviceMemoryTest, + LoadMethodAllocatesDeviceMemoryAndDeallocatesOnDestroy) { + const char* path = std::getenv("ET_MODULE_ADD_WITH_DEVICE_PATH"); + ASSERT_NE(path, nullptr) << "ET_MODULE_ADD_WITH_DEVICE_PATH not set"; + + { + Module module(path); + auto err = module.load_method("forward"); + + // Regardless of whether load_method succeeds or fails (e.g. due to + // backend init issues), the device-aware memory allocation path + // (make_planned_memory_with_devices) runs BEFORE backend init. + EXPECT_EQ(g_mock_cuda.allocate_count_, 1) + << "Expected 1 device allocation for the CUDA buffer" + << " (actual: " << g_mock_cuda.allocate_count_ << ")" + << ", deallocate_count=" << g_mock_cuda.deallocate_count_ + << ", load_method returned error=" << static_cast(err); + EXPECT_EQ(g_mock_cuda.last_allocate_size_, 48) + << "Expected 48 bytes allocated (3 CUDA tensors sharing one buffer)"; + EXPECT_EQ(g_mock_cuda.last_allocate_index_, 0) + << "Expected device_index=0 (cuda:0)"; + + if (err == Error::Ok) { + // Success path: MethodHolder moved into methods_ map. + // DeviceMemoryBuffer is alive as long as Module is alive. + EXPECT_EQ(g_mock_cuda.deallocate_count_, 0) + << "No deallocation while method is loaded"; + } else { + // Error path: local MethodHolder destroyed on return from load_method. + // RAII deallocation already happened. + EXPECT_EQ(g_mock_cuda.deallocate_count_, 1) + << "RAII deallocation on error path"; + } + } + + // After Module destroyed, all device memory must be freed. + EXPECT_EQ(g_mock_cuda.deallocate_count_, 1) + << "Expected deallocation after Module destroyed"; +} diff --git a/extension/module/test/targets.bzl b/extension/module/test/targets.bzl index f0d7e449efd..4dc3fb537f3 100644 --- a/extension/module/test/targets.bzl +++ b/extension/module/test/targets.bzl @@ -28,7 +28,7 @@ def define_common_targets(is_fbcode=False): aten_suffix = ("_aten" if aten_mode else "") runtime.cxx_test( - name = "test" + aten_suffix, + name = "module_test" + aten_suffix, srcs = [ "module_test.cpp", ], @@ -68,6 +68,26 @@ def define_common_targets(is_fbcode=False): ], ) + runtime.cxx_test( + name = "module_device_memory_test" + aten_suffix, + srcs = [ + "module_device_memory_test.cpp", + ], + deps = [ + "//executorch/kernels/portable:generated_lib" + aten_suffix, + "//executorch/extension/module:module" + aten_suffix, + "//executorch/runtime/core:device_allocator", + "//executorch/runtime/core:device_memory_buffer", + ], + env = { + "ET_MODULE_ADD_WITH_DEVICE_PATH": "$(location fbcode//executorch/test/models:exported_program_with_device_info[ModuleAddWithDevice.pte])", + "ET_MODULE_ADD_PATH": "$(location fbcode//executorch/test/models:exported_programs[ModuleAdd.pte])", + }, + compiler_flags = [ + "-Wno-error=deprecated-declarations", + ], + ) + runtime.filegroup( name = "resources", srcs = native.glob([ diff --git a/kernels/portable/cpu/op__device_copy.cpp b/kernels/portable/cpu/op__device_copy.cpp new file mode 100644 index 00000000000..5e1a51a83be --- /dev/null +++ b/kernels/portable/cpu/op__device_copy.cpp @@ -0,0 +1,154 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/** + * Runtime kernels for et_copy._h2d_copy and et_copy._d2h_copy ops. + * + * These ops transfer tensor data between CPU and device memory using + * the DeviceAllocator interface. The device type is inferred from the + * tensor metadata (out.device_type() for H2D, self.device_type() for D2H), + * which was set during AOT serialization by PropagateDevicePass. + */ + +#include +#include +#include + +namespace torch { +namespace executor { +namespace native { + +using Tensor = executorch::aten::Tensor; +using DeviceAllocator = executorch::runtime::DeviceAllocator; +using Error = executorch::runtime::Error; + +/** + * Copies tensor data from host (CPU) memory to device memory. + * + * self: source tensor on CPU + * out: destination tensor on device (memory-planned by runtime) + * + * The device type and index are inferred from out's TensorImpl metadata. + */ +Tensor& +_h2d_copy_out(KernelRuntimeContext& ctx, const Tensor& self, Tensor& out) { + auto device_type = out.unsafeGetTensorImpl()->device_type(); + auto device_index = out.unsafeGetTensorImpl()->device_index(); + + ET_KERNEL_CHECK_MSG( + ctx, + self.unsafeGetTensorImpl()->device_type() == + executorch::runtime::etensor::DeviceType::CPU, + InvalidArgument, + out, + "_h2d_copy: source tensor must be on CPU, got device_type=%d", + static_cast(self.unsafeGetTensorImpl()->device_type())); + + ET_KERNEL_CHECK_MSG( + ctx, + device_type != executorch::runtime::etensor::DeviceType::CPU, + InvalidArgument, + out, + "_h2d_copy: destination tensor must be on a non-CPU device"); + + auto nbytes = self.nbytes(); + ET_KERNEL_CHECK_MSG( + ctx, + nbytes == out.nbytes(), + InvalidArgument, + out, + "_h2d_copy: size mismatch: self.nbytes()=%zu, out.nbytes()=%zu", + nbytes, + out.nbytes()); + + DeviceAllocator* allocator = + executorch::runtime::get_device_allocator(device_type); + ET_KERNEL_CHECK_MSG( + ctx, + allocator != nullptr, + NotFound, + out, + "_h2d_copy: no device allocator registered for device_type=%d", + static_cast(device_type)); + + Error err = allocator->copy_host_to_device( + out.mutable_data_ptr(), self.const_data_ptr(), nbytes, device_index); + ET_KERNEL_CHECK_MSG( + ctx, + err == Error::Ok, + Internal, + out, + "_h2d_copy: copy_host_to_device failed"); + + return out; +} + +/** + * Copies tensor data from device memory to host (CPU) memory. + * + * self: source tensor on device + * out: destination tensor on CPU (memory-planned by runtime) + * + * The device type and index are inferred from self's TensorImpl metadata. + */ +Tensor& +_d2h_copy_out(KernelRuntimeContext& ctx, const Tensor& self, Tensor& out) { + auto device_type = self.unsafeGetTensorImpl()->device_type(); + auto device_index = self.unsafeGetTensorImpl()->device_index(); + + ET_KERNEL_CHECK_MSG( + ctx, + device_type != executorch::runtime::etensor::DeviceType::CPU, + InvalidArgument, + out, + "_d2h_copy: source tensor must be on a non-CPU device"); + + ET_KERNEL_CHECK_MSG( + ctx, + out.unsafeGetTensorImpl()->device_type() == + executorch::runtime::etensor::DeviceType::CPU, + InvalidArgument, + out, + "_d2h_copy: destination tensor must be on CPU, got device_type=%d", + static_cast(out.unsafeGetTensorImpl()->device_type())); + + auto nbytes = self.nbytes(); + ET_KERNEL_CHECK_MSG( + ctx, + nbytes == out.nbytes(), + InvalidArgument, + out, + "_d2h_copy: size mismatch: self.nbytes()=%zu, out.nbytes()=%zu", + nbytes, + out.nbytes()); + + DeviceAllocator* allocator = + executorch::runtime::get_device_allocator(device_type); + ET_KERNEL_CHECK_MSG( + ctx, + allocator != nullptr, + NotFound, + out, + "_d2h_copy: no device allocator registered for device_type=%d", + static_cast(device_type)); + + Error err = allocator->copy_device_to_host( + out.mutable_data_ptr(), self.const_data_ptr(), nbytes, device_index); + ET_KERNEL_CHECK_MSG( + ctx, + err == Error::Ok, + Internal, + out, + "_d2h_copy: copy_device_to_host failed"); + + return out; +} + +} // namespace native +} // namespace executor +} // namespace torch diff --git a/kernels/portable/functions.yaml b/kernels/portable/functions.yaml index 620d97d050f..ecf62ee3606 100644 --- a/kernels/portable/functions.yaml +++ b/kernels/portable/functions.yaml @@ -1045,6 +1045,16 @@ - arg_meta: null kernel_name: torch::executor::zeros_out +- func: et_copy::_h2d_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: torch::executor::_h2d_copy_out + +- func: et_copy::_d2h_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: torch::executor::_d2h_copy_out + - func: dim_order_ops::_empty_dim_order.out(int[] size, *, int[]? dim_order=None, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null diff --git a/kernels/portable/targets.bzl b/kernels/portable/targets.bzl index 2c6e0b5c35f..b80ce347768 100644 --- a/kernels/portable/targets.bzl +++ b/kernels/portable/targets.bzl @@ -66,15 +66,19 @@ def define_common_targets(): "visibility": ["PUBLIC"], } - executorch_generated_lib( - name = "generated_lib", - deps = [ - ":executorch_aten_ops", - ":executorch_custom_ops", - ], - kernel_deps = ["//executorch/kernels/portable:operators"], - **generated_lib_common_args - ) + for support_exceptions in [True, False]: + exception_suffix = "_no_exceptions" if not support_exceptions else "" + + executorch_generated_lib( + name = "generated_lib" + exception_suffix, + deps = [ + ":executorch_aten_ops", + ":executorch_custom_ops", + ], + kernel_deps = ["//executorch/kernels/portable:operators"], + support_exceptions = support_exceptions, + **generated_lib_common_args + ) if True in get_aten_mode_options(): executorch_generated_lib( diff --git a/kernels/test/op__device_copy_test.cpp b/kernels/test/op__device_copy_test.cpp new file mode 100644 index 00000000000..d345642bd37 --- /dev/null +++ b/kernels/test/op__device_copy_test.cpp @@ -0,0 +1,297 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/** + * Tests for et_copy._h2d_copy.out and et_copy._d2h_copy.out runtime kernels. + * + * Uses a MockDeviceAllocator to verify that the kernels correctly call + * copy_host_to_device / copy_device_to_host via the DeviceAllocator interface, + * and that device type is inferred from tensor metadata. + */ + +#include + +#include // Declares the operator +#include +#include +#include +#include +#include + +using executorch::aten::ScalarType; +using executorch::aten::Tensor; +using executorch::aten::TensorImpl; +using executorch::runtime::DeviceAllocator; +using executorch::runtime::Error; +using executorch::runtime::get_device_allocator; +using executorch::runtime::register_device_allocator; +using executorch::runtime::Result; +using executorch::runtime::etensor::DeviceIndex; +using executorch::runtime::etensor::DeviceType; + +using TensorShapeDynamism = executorch::runtime::TensorShapeDynamism; + +namespace { + +class MockDeviceAllocator : public DeviceAllocator { + public: + Result allocate( + size_t nbytes, + DeviceIndex index, + size_t alignment = kDefaultAlignment) override { + return Error::NotSupported; + } + + void deallocate(void* ptr, DeviceIndex index) override {} + + Error copy_host_to_device( + void* dst, + const void* src, + size_t nbytes, + DeviceIndex index) override { + h2d_call_count_++; + last_h2d_nbytes_ = nbytes; + last_h2d_device_index_ = index; + // Actually copy so we can verify data + std::memcpy(dst, src, nbytes); + return Error::Ok; + } + + Error copy_device_to_host( + void* dst, + const void* src, + size_t nbytes, + DeviceIndex index) override { + d2h_call_count_++; + last_d2h_nbytes_ = nbytes; + last_d2h_device_index_ = index; + std::memcpy(dst, src, nbytes); + return Error::Ok; + } + + DeviceType device_type() const override { + return DeviceType::CUDA; + } + + int h2d_call_count_ = 0; + int d2h_call_count_ = 0; + size_t last_h2d_nbytes_ = 0; + size_t last_d2h_nbytes_ = 0; + DeviceIndex last_h2d_device_index_ = -1; + DeviceIndex last_d2h_device_index_ = -1; +}; + +} // namespace + +static MockDeviceAllocator g_mock_cuda; + +class OpDeviceCopyTest : public OperatorTest { + protected: + Tensor& op_h2d_copy_out(const Tensor& self, Tensor& out) { + return torch::executor::et_copy::_h2d_copy_outf(context_, self, out); + } + + Tensor& op_d2h_copy_out(const Tensor& self, Tensor& out) { + return torch::executor::et_copy::_d2h_copy_outf(context_, self, out); + } + + static void SetUpTestSuite() { + executorch::runtime::runtime_init(); + if (get_device_allocator(DeviceType::CUDA) == nullptr) { + register_device_allocator(&g_mock_cuda); + } + } + + void SetUp() override { + OperatorTest::SetUp(); + g_mock_cuda.h2d_call_count_ = 0; + g_mock_cuda.d2h_call_count_ = 0; + g_mock_cuda.last_h2d_nbytes_ = 0; + g_mock_cuda.last_d2h_nbytes_ = 0; + g_mock_cuda.last_h2d_device_index_ = -1; + g_mock_cuda.last_d2h_device_index_ = -1; + } +}; + +TEST_F(OpDeviceCopyTest, H2dCopyCopiesDataAndCallsAllocator) { + // Set up a CPU source tensor with known data. + float src_data[] = {1.0f, 2.0f, 3.0f, 4.0f}; + int32_t sizes[] = {4}; + uint8_t dim_order[] = {0}; + int32_t strides[] = {1}; + TensorImpl src_impl( + ScalarType::Float, + 1, + sizes, + src_data, + dim_order, + strides, + TensorShapeDynamism::STATIC, + DeviceType::CPU, + 0); + Tensor src(&src_impl); + + // Set up a CUDA destination tensor (simulated with host memory). + float dst_data[] = {0.0f, 0.0f, 0.0f, 0.0f}; + TensorImpl dst_impl( + ScalarType::Float, + 1, + sizes, + dst_data, + dim_order, + strides, + TensorShapeDynamism::STATIC, + DeviceType::CUDA, + 0); + Tensor dst(&dst_impl); + + Tensor& result = op_h2d_copy_out(src, dst); + + // Verify the allocator was called correctly. + EXPECT_EQ(g_mock_cuda.h2d_call_count_, 1); + EXPECT_EQ(g_mock_cuda.last_h2d_nbytes_, 4 * sizeof(float)); + EXPECT_EQ(g_mock_cuda.last_h2d_device_index_, 0); + + // Verify data was copied (mock does a real memcpy). + EXPECT_EQ(dst_data[0], 1.0f); + EXPECT_EQ(dst_data[1], 2.0f); + EXPECT_EQ(dst_data[2], 3.0f); + EXPECT_EQ(dst_data[3], 4.0f); + + // Verify return value is the out tensor. + EXPECT_EQ(&result, &dst); +} + +TEST_F(OpDeviceCopyTest, D2hCopyCopiesDataAndCallsAllocator) { + // Set up a CUDA source tensor with known data. + float src_data[] = {5.0f, 6.0f, 7.0f, 8.0f}; + int32_t sizes[] = {4}; + uint8_t dim_order[] = {0}; + int32_t strides[] = {1}; + TensorImpl src_impl( + ScalarType::Float, + 1, + sizes, + src_data, + dim_order, + strides, + TensorShapeDynamism::STATIC, + DeviceType::CUDA, + 0); + Tensor src(&src_impl); + + // Set up a CPU destination tensor. + float dst_data[] = {0.0f, 0.0f, 0.0f, 0.0f}; + TensorImpl dst_impl( + ScalarType::Float, + 1, + sizes, + dst_data, + dim_order, + strides, + TensorShapeDynamism::STATIC, + DeviceType::CPU, + 0); + Tensor dst(&dst_impl); + + Tensor& result = op_d2h_copy_out(src, dst); + + // Verify the allocator was called correctly. + EXPECT_EQ(g_mock_cuda.d2h_call_count_, 1); + EXPECT_EQ(g_mock_cuda.last_d2h_nbytes_, 4 * sizeof(float)); + EXPECT_EQ(g_mock_cuda.last_d2h_device_index_, 0); + + // Verify data was copied. + EXPECT_EQ(dst_data[0], 5.0f); + EXPECT_EQ(dst_data[1], 6.0f); + EXPECT_EQ(dst_data[2], 7.0f); + EXPECT_EQ(dst_data[3], 8.0f); + + EXPECT_EQ(&result, &dst); +} + +TEST_F(OpDeviceCopyTest, H2dCopyWithDeviceIndex1) { + // Verify device_index is correctly forwarded to the allocator. + float src_data[] = {1.0f}; + float dst_data[] = {0.0f}; + int32_t sizes[] = {1}; + uint8_t dim_order[] = {0}; + int32_t strides[] = {1}; + + TensorImpl src_impl( + ScalarType::Float, + 1, + sizes, + src_data, + dim_order, + strides, + TensorShapeDynamism::STATIC, + DeviceType::CPU, + 0); + Tensor src(&src_impl); + + // Device index = 1 (e.g., cuda:1) + TensorImpl dst_impl( + ScalarType::Float, + 1, + sizes, + dst_data, + dim_order, + strides, + TensorShapeDynamism::STATIC, + DeviceType::CUDA, + 1); + Tensor dst(&dst_impl); + + op_h2d_copy_out(src, dst); + + EXPECT_EQ(g_mock_cuda.h2d_call_count_, 1); + EXPECT_EQ(g_mock_cuda.last_h2d_device_index_, 1); +} + +TEST_F(OpDeviceCopyTest, H2dCopyMultidimensionalTensor) { + // Test with a 2D tensor [2, 3]. + float src_data[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + float dst_data[] = {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}; + int32_t sizes[] = {2, 3}; + uint8_t dim_order[] = {0, 1}; + int32_t strides[] = {3, 1}; + + TensorImpl src_impl( + ScalarType::Float, + 2, + sizes, + src_data, + dim_order, + strides, + TensorShapeDynamism::STATIC, + DeviceType::CPU, + 0); + Tensor src(&src_impl); + + TensorImpl dst_impl( + ScalarType::Float, + 2, + sizes, + dst_data, + dim_order, + strides, + TensorShapeDynamism::STATIC, + DeviceType::CUDA, + 0); + Tensor dst(&dst_impl); + + op_h2d_copy_out(src, dst); + + EXPECT_EQ(g_mock_cuda.h2d_call_count_, 1); + EXPECT_EQ(g_mock_cuda.last_h2d_nbytes_, 6 * sizeof(float)); + + for (int i = 0; i < 6; ++i) { + EXPECT_EQ(dst_data[i], src_data[i]); + } +} diff --git a/kernels/test/targets.bzl b/kernels/test/targets.bzl index bc51e336cb8..5212d691c5b 100644 --- a/kernels/test/targets.bzl +++ b/kernels/test/targets.bzl @@ -1,14 +1,14 @@ load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") load("@fbsource//xplat/executorch/kernels/test:util.bzl", "codegen_function_header_wrapper", "op_test") -def _common_op_test(name, kernels): +def _common_op_test(name, kernels, deps = []): """ Defines test targets in format of _op__test For ATen kernel testing, let's use portable functions.yaml for tested ops. """ for kernel in kernels: - deps = [":function_header_wrapper_{}".format(kernel)] - op_test(name, kernel_name = kernel, use_kernel_prefix = True, deps = deps) + op_deps = [":function_header_wrapper_{}".format(kernel)] + deps + op_test(name, kernel_name = kernel, use_kernel_prefix = True, deps = op_deps) def define_common_targets(): """Defines targets that should be shared between fbcode and xplat. @@ -177,6 +177,14 @@ def define_common_targets(): _common_op_test("op__clone_dim_order_test", ["aten", "portable"]) _common_op_test("op__conj_physical_test", ["aten", "portable"]) _common_op_test("op__adaptive_avg_pool2d_test", ["aten", "portable"]) + _common_op_test( + "op__device_copy_test", + ["portable"], + deps = [ + "//executorch/runtime/core:device_allocator", + "//executorch/runtime/platform:platform", + ], + ) _common_op_test("op_abs_test", ["aten", "portable"]) _common_op_test("op_acos_test", ["aten", "portable"]) _common_op_test("op_acosh_test", ["aten", "portable"]) diff --git a/runtime/core/error.h b/runtime/core/error.h index 80c2ef645d4..b923604ca89 100644 --- a/runtime/core/error.h +++ b/runtime/core/error.h @@ -151,8 +151,9 @@ constexpr const char* to_string(const Error error) { return "Error::RegistrationExceedingMaxKernels"; case Error::RegistrationAlreadyRegistered: return "Error::RegistrationAlreadyRegistered"; + default: + return "Error::Unknown"; } - return "Error::Unknown"; } } // namespace runtime diff --git a/runtime/executor/targets.bzl b/runtime/executor/targets.bzl index 90f8d0221e9..81d0a58667f 100644 --- a/runtime/executor/targets.bzl +++ b/runtime/executor/targets.bzl @@ -16,8 +16,14 @@ def _program_preprocessor_flags(): if enable_verification == "false": return ["-DET_ENABLE_PROGRAM_VERIFICATION=0"] elif enable_verification == "true": - # Enabled by default. - return [] + # Enabled by default; allow opt-out via constraint + if not runtime.is_oss: + return select({ + "DEFAULT": [], + "fbsource//xplat/executorch/tools/buck/constraints:executorch-program-verification-disabled": ["-DET_ENABLE_PROGRAM_VERIFICATION=0"], + }) + else: + return [] else: fail("executorch.enable_program_verification must be one of 'true' or 'false'; saw '" + enable_verification + "'") diff --git a/runtime/platform/compiler.h b/runtime/platform/compiler.h index edd340d1fb0..692d590f44c 100644 --- a/runtime/platform/compiler.h +++ b/runtime/platform/compiler.h @@ -138,8 +138,14 @@ #define __has_builtin(x) (0) #endif -#if __has_builtin(__builtin_strrchr) +#if defined(__FILE_NAME__) +/// __FILE_NAME__ provides just the filename at +/// compile time, avoiding embedding full paths in the binary +#define ET_SHORT_FILENAME __FILE_NAME__ +#elif __has_builtin(__builtin_strrchr) /// Name of the source file without a directory string. +/// Note: This approach embeds the full path in .rodata even though only the +/// basename is used at runtime. __FILE_NAME__ is preferred when available. #define ET_SHORT_FILENAME (__builtin_strrchr("/" __FILE__, '/') + 1) #else #define ET_SHORT_FILENAME __FILE__ @@ -152,12 +158,17 @@ #define ET_LINE __LINE__ #endif // __has_builtin(__builtin_LINE) -#if __has_builtin(__builtin_FUNCTION) +#if defined(ET_USE_BUILTIN_FUNCTION_NAME) && ET_USE_BUILTIN_FUNCTION_NAME == 0 +/// __FUNCTION__ provides a short undecorated name, saving .rodata space +/// compared to __builtin_FUNCTION() which includes the full signature +/// (namespace, parameters, return type). +#define ET_FUNCTION __FUNCTION__ +#elif __has_builtin(__builtin_FUNCTION) /// Name of the current function as a const char[]. #define ET_FUNCTION __builtin_FUNCTION() #else #define ET_FUNCTION __FUNCTION__ -#endif // __has_builtin(__builtin_FUNCTION) +#endif // As of G3 RJ-2024.3 toolchain, zu format specifier is not supported for Xtensa #if defined(__XTENSA__) diff --git a/runtime/platform/profiler.h b/runtime/platform/profiler.h index d6362781394..cb011bd0ef9 100644 --- a/runtime/platform/profiler.h +++ b/runtime/platform/profiler.h @@ -227,8 +227,12 @@ using ::executorch::runtime::track_allocator; #define EXECUTORCH_END_PROF(token_id) \ ::executorch::runtime::end_profiling(token_id); -#define EXECUTORCH_SCOPE_PROF(name) \ - ::executorch::runtime::ExecutorchProfiler profiler(name); +#define EXECUTORCH_SCOPE_PROF_CONCAT_IMPL(a, b) a##b +#define EXECUTORCH_SCOPE_PROF_CONCAT(a, b) \ + EXECUTORCH_SCOPE_PROF_CONCAT_IMPL(a, b) +#define EXECUTORCH_SCOPE_PROF(name) \ + ::executorch::runtime::ExecutorchProfiler EXECUTORCH_SCOPE_PROF_CONCAT( \ + et_profiler_, __LINE__)(name); #define EXECUTORCH_PROFILE_INSTRUCTION_SCOPE(chain_idx, instruction_idx) \ ::executorch::runtime::ExecutorchProfilerInstructionScope \ diff --git a/runtime/platform/targets.bzl b/runtime/platform/targets.bzl index 65d92b134d6..63b8cb553ef 100644 --- a/runtime/platform/targets.bzl +++ b/runtime/platform/targets.bzl @@ -116,5 +116,9 @@ def define_common_targets(): exported_headers = [ "compiler.h", ], + exported_preprocessor_flags = select({ + "DEFAULT": [], + "fbsource//xplat/executorch/tools/buck/constraints:executorch-builtin-function-name-disabled": ["-DET_USE_BUILTIN_FUNCTION_NAME=0"], + }) if not runtime.is_oss else [], visibility = ["PUBLIC"], ) diff --git a/shim_et/xplat/executorch/build/build_variables.bzl b/shim_et/xplat/executorch/build/build_variables.bzl index b0545b8ce18..659a128994f 100644 --- a/shim_et/xplat/executorch/build/build_variables.bzl +++ b/shim_et/xplat/executorch/build/build_variables.bzl @@ -50,6 +50,8 @@ PLATFORM_SRCS = [ EXECUTORCH_CORE_SRCS = sorted([ "runtime/backend/interface.cpp", + "runtime/core/device_allocator.cpp", + "runtime/core/device_memory_buffer.cpp", "runtime/core/evalue.cpp", "runtime/core/exec_aten/util/tensor_shape_to_c_string.cpp", "runtime/core/exec_aten/util/tensor_util_portable.cpp", diff --git a/shim_et/xplat/executorch/codegen/codegen.bzl b/shim_et/xplat/executorch/codegen/codegen.bzl index 5ffa7b65a36..318996784a1 100644 --- a/shim_et/xplat/executorch/codegen/codegen.bzl +++ b/shim_et/xplat/executorch/codegen/codegen.bzl @@ -535,6 +535,7 @@ def get_portable_lib_deps(): "//executorch/kernels/portable/cpu:vec_ops", "//executorch/kernels/portable/cpu/pattern:all_deps", "//executorch/kernels/portable/cpu/util:all_deps", + "//executorch/runtime/core:device_allocator", ] def get_optimized_lib_deps(): diff --git a/shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl b/shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl index cc2a0f78c75..479f3913f8f 100644 --- a/shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl +++ b/shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl @@ -1405,6 +1405,12 @@ ATEN_OPS = ( "//executorch/kernels/portable/cpu/util:copy_ops_util", ], ), + op_target( + name = "op__device_copy", + deps = [ + "//executorch/runtime/core:device_allocator", + ], + ), ) # Operators that are not listed in `functions.yaml` (i.e., operators listed in diff --git a/test/models/targets.bzl b/test/models/targets.bzl index c9fb67b7d31..a80244b1383 100644 --- a/test/models/targets.bzl +++ b/test/models/targets.bzl @@ -226,6 +226,7 @@ def define_common_targets(): default_outs = ["."], visibility = [ "//executorch/runtime/executor/test/...", + "//executorch/extension/module/test/...", ], ) diff --git a/test/targets.bzl b/test/targets.bzl index 023a1d48960..0047d5563fc 100644 --- a/test/targets.bzl +++ b/test/targets.bzl @@ -36,7 +36,9 @@ def define_common_targets(): name = "size_test_all_ops", srcs = SIZE_TEST_SOURCES, deps = SIZE_TEST_DEPS + [ - "//executorch/kernels/portable:generated_lib", + # size_test_all_ops is built with -fno-exceptions in the size CI; + # use the _no_exceptions variant whose codegen omits try/catch. + "//executorch/kernels/portable:generated_lib_no_exceptions", "//executorch/runtime/executor/test:test_backend_compiler_lib", ], define_static_target = True, diff --git a/tools/buck/constraints/BUCK b/tools/buck/constraints/BUCK index b558bb9e4a4..49fbaabe06f 100644 --- a/tools/buck/constraints/BUCK +++ b/tools/buck/constraints/BUCK @@ -61,3 +61,41 @@ fb_native.constraint_value( constraint_setting = ":executorch-event-tracer", visibility = ["PUBLIC"], ) + +fb_native.config_setting( + name = "executorch-program-verification-disabled", + constraint_values = [ + ":program-verification-disabled", + ], + visibility = ["PUBLIC"], +) + +fb_native.constraint_setting( + name = "executorch-program-verification", + visibility = ["PUBLIC"], +) + +fb_native.constraint_value( + name = "program-verification-disabled", + constraint_setting = ":executorch-program-verification", + visibility = ["PUBLIC"], +) + +fb_native.config_setting( + name = "executorch-builtin-function-name-disabled", + constraint_values = [ + ":builtin-function-name-disabled", + ], + visibility = ["PUBLIC"], +) + +fb_native.constraint_setting( + name = "executorch-builtin-function-name", + visibility = ["PUBLIC"], +) + +fb_native.constraint_value( + name = "builtin-function-name-disabled", + constraint_setting = ":executorch-builtin-function-name", + visibility = ["PUBLIC"], +) diff --git a/tools/cmake/preset/riscv_baremetal.cmake b/tools/cmake/preset/riscv_baremetal.cmake new file mode 100644 index 00000000000..e70fc57ba57 --- /dev/null +++ b/tools/cmake/preset/riscv_baremetal.cmake @@ -0,0 +1,50 @@ +# Copyright 2026 The ExecuTorch Authors. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# Baremetal builds consume the build tree directly; mirror arm_baremetal so +# install rules stay invokable but write back into the build dir. +define_overridable_option( + EXECUTORCH_BAREMETAL_SKIP_INSTALL + "Skip emitting install/export rules when building bare-metal artifacts" BOOL + ON +) + +if(EXECUTORCH_BAREMETAL_SKIP_INSTALL) + set(CMAKE_INSTALL_PREFIX "${CMAKE_BINARY_DIR}") + unset(CMAKE_SKIP_INSTALL_RULES CACHE) + set(CMAKE_SKIP_INSTALL_RULES + OFF + CACHE + BOOL + "Retain install() rules so docs/scripts can keep calling --target install" + FORCE + ) +endif() + +set_overridable_option(EXECUTORCH_BUILD_EXECUTOR_RUNNER OFF) +set_overridable_option(EXECUTORCH_BUILD_EXTENSION_DATA_LOADER OFF) +set_overridable_option(EXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR OFF) +set_overridable_option(EXECUTORCH_BUILD_EXTENSION_EVALUE_UTIL ON) +set_overridable_option(EXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL ON) +set_overridable_option(EXECUTORCH_BUILD_KERNELS_QUANTIZED ON) +# BUNDLE_IO requires DEVTOOLS to provide the bundled_program lib. +set_overridable_option(EXECUTORCH_BUILD_DEVTOOLS ON) +set_overridable_option(EXECUTORCH_ENABLE_BUNDLE_IO ON) +set_overridable_option(EXECUTORCH_ENABLE_LOGGING ON) +# Freestanding target: no pthreadpool, no cpuinfo, no shared lib. +set_overridable_option(EXECUTORCH_BUILD_PTHREADPOOL OFF) +set_overridable_option(EXECUTORCH_BUILD_CPUINFO OFF) + +define_overridable_option( + EXECUTORCH_BUILD_RISCV_ETDUMP "Build etdump support for RISC-V" BOOL OFF +) + +if("${EXECUTORCH_BUILD_RISCV_ETDUMP}") + set(EXECUTORCH_BUILD_DEVTOOLS ON) + set(EXECUTORCH_ENABLE_EVENT_TRACER ON) + set(FLATCC_ALLOW_WERROR OFF) +else() + set(EXECUTORCH_ENABLE_EVENT_TRACER OFF) +endif() diff --git a/tools/cmake/preset/riscv64_linux.cmake b/tools/cmake/preset/riscv_linux.cmake similarity index 100% rename from tools/cmake/preset/riscv64_linux.cmake rename to tools/cmake/preset/riscv_linux.cmake