From 023a5f186da616f28cdffa9c6ba98ff4fe3ff2eb Mon Sep 17 00:00:00 2001 From: Mergen Nachin Date: Tue, 12 May 2026 17:45:28 -0400 Subject: [PATCH] Add MLX backend support for Gemma 4 31B MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - pack_mlx.py: converts Int4Tensor → IntxUnpackedToInt8Tensor at pack time (nibble unpack + scale transpose) so the default dispatch produces the dequantize_affine → linear pattern MLX expects. IntxUnpackedToInt8Tensor passes through unchanged. Embedding with incompatible per-axis group_size is regrouped to gs=128. - export.py: add --backend mlx with single-method export (dynamic seq_len), sampler stripping, and MLXPartitioner lowering. No int4_dispatch import — MLX uses the standard dequantize_affine path. - main.cpp: handle both CUDA (prefill+decode, on-device sampling) and MLX (single forward method, host-side argmax) via #ifdef. - CMakeLists.txt / CMakePresets.json / Makefile: add gemma4_31b-mlx build target linking mlxdelegate. - test_pack_mlx.py: 15 tests covering Int4→IntxUnpacked conversion correctness, passthrough, regrouping, error cases. - test_mlx_pipeline.py: 4 e2e tests including export-to-pte. Validated: same CUDA-quantized checkpoint packs for both backends, 100% op delegation to MLX, real 31B checkpoint packs at 4.0 GB RSS. PR authored with Claude. --- Makefile | 12 +- examples/models/gemma4_31b/CMakeLists.txt | 15 +- examples/models/gemma4_31b/CMakePresets.json | 31 +++ examples/models/gemma4_31b/README.md | 22 +- examples/models/gemma4_31b/export.py | 139 ++++++++++- examples/models/gemma4_31b/main.cpp | 163 ++++++------- examples/models/gemma4_31b/model.md | 33 ++- examples/models/gemma4_31b/quant/README.md | 4 +- examples/models/gemma4_31b/quant/__init__.py | 1 + examples/models/gemma4_31b/quant/pack_mlx.py | 215 ++++++++++++++++++ .../gemma4_31b/quant/tests/test_pack_mlx.py | 199 ++++++++++++++++ .../gemma4_31b/tests/test_mlx_pipeline.py | 154 +++++++++++++ 12 files changed, 888 insertions(+), 100 deletions(-) create mode 100644 examples/models/gemma4_31b/quant/pack_mlx.py create mode 100644 examples/models/gemma4_31b/quant/tests/test_pack_mlx.py create mode 100644 examples/models/gemma4_31b/tests/test_mlx_pipeline.py diff --git a/Makefile b/Makefile index ba61dddce44..9b7f24b2f83 100644 --- a/Makefile +++ b/Makefile @@ -91,7 +91,7 @@ # # ============================================================================== -.PHONY: voxtral-cuda voxtral-cpu voxtral-metal voxtral-mlx voxtral_realtime-cuda voxtral_realtime-cpu voxtral_realtime-metal voxtral_realtime-mlx voxtral_tts-cpu voxtral_tts-cuda whisper-cuda whisper-cuda-debug whisper-cpu whisper-metal parakeet-cuda parakeet-cuda-debug parakeet-cpu parakeet-metal parakeet-mlx parakeet-vulkan dinov2-cuda dinov2-cuda-debug sortformer-cuda sortformer-cpu silero-vad-cpu llama-cuda llama-cuda-debug llama-cpu llava-cpu gemma3-cuda gemma3-cpu gemma4_31b-cuda qwen3_5_moe-cuda qwen3_5_moe-metal clean help +.PHONY: voxtral-cuda voxtral-cpu voxtral-metal voxtral-mlx voxtral_realtime-cuda voxtral_realtime-cpu voxtral_realtime-metal voxtral_realtime-mlx voxtral_tts-cpu voxtral_tts-cuda whisper-cuda whisper-cuda-debug whisper-cpu whisper-metal parakeet-cuda parakeet-cuda-debug parakeet-cpu parakeet-metal parakeet-mlx parakeet-vulkan dinov2-cuda dinov2-cuda-debug sortformer-cuda sortformer-cpu silero-vad-cpu llama-cuda llama-cuda-debug llama-cpu llava-cpu gemma3-cuda gemma3-cpu gemma4_31b-cuda gemma4_31b-mlx qwen3_5_moe-cuda qwen3_5_moe-metal clean help help: @echo 
"This Makefile adds targets to build runners for various models on various backends. Run using \`make \`. Available targets:" @@ -127,6 +127,7 @@ help: @echo " gemma3-cuda - Build Gemma3 runner with CUDA backend" @echo " gemma3-cpu - Build Gemma3 runner with CPU backend" @echo " gemma4_31b-cuda - Build Gemma 4 31B runner with CUDA backend" + @echo " gemma4_31b-mlx - Build Gemma 4 31B runner with MLX backend" @echo " qwen3_5_moe-cuda - Build Qwen3.5 MoE runner with CUDA backend" @echo " qwen3_5_moe-metal - Build Qwen3.5 MoE runner with Metal backend" @echo " clean - Clean build artifacts" @@ -435,6 +436,15 @@ gemma4_31b-cuda: @echo "✓ Build complete!" @echo " Binary: cmake-out/examples/models/gemma4_31b/gemma4_31b_runner" +gemma4_31b-mlx: + @echo "==> Building and installing ExecuTorch with MLX..." + cmake --workflow --preset mlx-release + @echo "==> Building Gemma 4 31B runner with MLX..." + cd examples/models/gemma4_31b && cmake --workflow --preset gemma4-31b-mlx + @echo "" + @echo "✓ Build complete!" + @echo " Binary: cmake-out/examples/models/gemma4_31b/gemma4_31b_runner" + qwen3_5_moe-metal: @echo "==> Building and installing ExecuTorch with Metal..." cmake --workflow --preset llm-release-metal diff --git a/examples/models/gemma4_31b/CMakeLists.txt b/examples/models/gemma4_31b/CMakeLists.txt index 8d536a47fc5..52419eb95bc 100644 --- a/examples/models/gemma4_31b/CMakeLists.txt +++ b/examples/models/gemma4_31b/CMakeLists.txt @@ -42,14 +42,17 @@ list( extension_flat_tensor ) -# CUDA backend (the only supported backend for this example for now) +# Backend: CUDA or MLX (exactly one required) if(EXECUTORCH_BUILD_CUDA) find_package(CUDAToolkit REQUIRED) list(APPEND link_libraries aoti_cuda_backend) executorch_target_link_options_shared_lib(aoti_cuda_backend) add_compile_definitions(EXECUTORCH_BUILD_CUDA) +elseif(TARGET mlxdelegate) + list(APPEND link_libraries mlxdelegate mlx) + executorch_target_link_options_shared_lib(mlxdelegate) else() - message(FATAL_ERROR "Set EXECUTORCH_BUILD_CUDA=ON") + message(FATAL_ERROR "Set EXECUTORCH_BUILD_CUDA=ON or EXECUTORCH_BUILD_MLX=ON") endif() # Tokenizer (HuggingFace tokenizer.json) @@ -63,5 +66,11 @@ target_link_libraries(gemma4_31b_runner PUBLIC ${link_libraries}) if(NOT CMAKE_BUILD_TYPE STREQUAL "Debug") target_link_options_gc_sections(gemma4_31b_runner) - target_link_options(gemma4_31b_runner PRIVATE "LINKER:-s") + if(NOT APPLE AND NOT MSVC) + target_link_options(gemma4_31b_runner PRIVATE "LINKER:-s") + endif() +endif() + +if(TARGET mlxdelegate) + executorch_target_copy_mlx_metallib(gemma4_31b_runner) endif() diff --git a/examples/models/gemma4_31b/CMakePresets.json b/examples/models/gemma4_31b/CMakePresets.json index 97ba7f4c57a..23a7d42e035 100644 --- a/examples/models/gemma4_31b/CMakePresets.json +++ b/examples/models/gemma4_31b/CMakePresets.json @@ -23,6 +23,17 @@ "string": "${hostSystemName}", "list": ["Linux", "Windows"] } + }, + { + "name": "gemma4-31b-mlx", + "displayName": "Gemma 4 31B runner (MLX)", + "inherits": ["gemma4-31b-base"], + "cacheVariables": {}, + "condition": { + "type": "equals", + "lhs": "${hostSystemName}", + "rhs": "Darwin" + } } ], "buildPresets": [ @@ -31,6 +42,12 @@ "displayName": "Build Gemma 4 31B runner (CUDA)", "configurePreset": "gemma4-31b-cuda", "targets": ["gemma4_31b_runner"] + }, + { + "name": "gemma4-31b-mlx", + "displayName": "Build Gemma 4 31B runner (MLX)", + "configurePreset": "gemma4-31b-mlx", + "targets": ["gemma4_31b_runner"] } ], "workflowPresets": [ @@ -47,6 +64,20 @@ "name": "gemma4-31b-cuda" } ] + }, 
+ { + "name": "gemma4-31b-mlx", + "displayName": "Configure and build Gemma 4 31B runner (MLX)", + "steps": [ + { + "type": "configure", + "name": "gemma4-31b-mlx" + }, + { + "type": "build", + "name": "gemma4-31b-mlx" + } + ] } ] } diff --git a/examples/models/gemma4_31b/README.md b/examples/models/gemma4_31b/README.md index 6f567d739b7..1623acea320 100644 --- a/examples/models/gemma4_31b/README.md +++ b/examples/models/gemma4_31b/README.md @@ -1,7 +1,7 @@ # Gemma 4 31B-IT Text-only export of Google's Gemma 4 31B-IT to ExecuTorch with INT4/INT8 -weight quantization. Currently supports the CUDA backend. +weight quantization. Supports CUDA and MLX (Apple Silicon) backends. For architecture and design notes see [model.md](model.md). @@ -67,6 +67,8 @@ recipe. Writes `model.safetensors`, `config.json`, and `tokenizer.json` into ## Export to ExecuTorch +### CUDA + ```bash python examples/models/gemma4_31b/export.py \ --prequantized ./gemma4_31b_int4 \ @@ -75,7 +77,20 @@ python examples/models/gemma4_31b/export.py \ --backend cuda ``` -Writes `model.pte` and `model.ptd` into `--output-dir`. +### MLX (Apple Silicon) + +```bash +python examples/models/gemma4_31b/export.py \ + --prequantized ./gemma4_31b_int4 \ + --output-dir ./gemma4_31b_exports_mlx \ + --max-seq-len 4096 \ + --backend mlx +``` + +The same quantized checkpoint works for both backends. MLX exports a single +method with dynamic sequence length and host-side sampling. + +Writes `model.pte` (and optionally `model.ptd`) into `--output-dir`. ## Eager inference @@ -102,7 +117,8 @@ model produces sensible text. ## Build the runner ```bash -make gemma4_31b-cuda +make gemma4_31b-cuda # Linux — CUDA backend +make gemma4_31b-mlx # macOS — MLX backend (Apple Silicon) ``` The binary lands at `cmake-out/examples/models/gemma4_31b/gemma4_31b_runner`. diff --git a/examples/models/gemma4_31b/export.py b/examples/models/gemma4_31b/export.py index a96dba0d512..fa3a5d3a1fe 100644 --- a/examples/models/gemma4_31b/export.py +++ b/examples/models/gemma4_31b/export.py @@ -19,6 +19,8 @@ Backends: --backend cuda (default) CUDA via tinygemm INT4 + CudaPartitioner. + --backend mlx Apple Silicon via MLXPartitioner (single method, + dynamic seq_len, host-side sampling). """ import argparse @@ -98,12 +100,21 @@ def load_and_quantize( # Backend dispatch helpers +_SUPPORTED_BACKENDS = ("cuda", "mlx") + + def _get_packers(backend: str) -> dict: if backend == "cuda": from executorch.examples.models.gemma4_31b.quant import DEFAULT_CUDA_PACKERS return DEFAULT_CUDA_PACKERS - raise ValueError(f"Unsupported backend: {backend!r}. Supported: 'cuda'.") + if backend == "mlx": + from executorch.examples.models.gemma4_31b.quant import DEFAULT_MLX_PACKERS + + return DEFAULT_MLX_PACKERS + raise ValueError( + f"Unsupported backend: {backend!r}. Supported: {_SUPPORTED_BACKENDS}." + ) def _pack_for_backend(model: nn.Module, path: str, backend: str) -> None: @@ -111,8 +122,14 @@ def _pack_for_backend(model: nn.Module, path: str, backend: str) -> None: from executorch.examples.models.gemma4_31b.quant import load_and_pack_for_cuda load_and_pack_for_cuda(path, model) + elif backend == "mlx": + from executorch.examples.models.gemma4_31b.quant import load_and_pack_for_mlx + + load_and_pack_for_mlx(path, model) else: - raise ValueError(f"Unsupported backend: {backend!r}. Supported: 'cuda'.") + raise ValueError( + f"Unsupported backend: {backend!r}. Supported: {_SUPPORTED_BACKENDS}." 
+ ) # --------------------------------------------------------------------------- @@ -128,8 +145,12 @@ def export_and_lower( """Export and lower the model to ExecuTorch for the given backend.""" if backend == "cuda": _export_cuda(model, config, output_dir) + elif backend == "mlx": + _export_mlx(model, config, output_dir) else: - raise ValueError(f"Unsupported backend: {backend!r}. Supported: 'cuda'.") + raise ValueError( + f"Unsupported backend: {backend!r}. Supported: {_SUPPORTED_BACKENDS}." + ) def _export_cuda(model: Gemma4_31B, config: Gemma4_31BConfig, output_dir: str) -> None: @@ -258,6 +279,116 @@ def _export_cuda(model: Gemma4_31B, config: Gemma4_31BConfig, output_dir: str) - print("Done.") +def _strip_sampler_from_forward(model: Gemma4_31B) -> None: + """Replace forward with a ``(tokens, input_pos) → logits`` variant. + + MLX samples on the host, so the on-device Gumbel-max sampler and its + temperature input are dead code. Stripping them produces a cleaner + exported graph. + """ + import types + + def _clean_forward(self, tokens, input_pos): + x = self.embed_tokens(tokens) * self.embed_normalizer + sliding_mask, full_mask = self._build_masks(input_pos) + for layer in self.layers: + x = layer(x, input_pos, sliding_mask, full_mask) + x = self.norm(x) + logits = self.lm_head(x).float() + cap = self.logit_softcap.float() + return torch.tanh(logits / cap) * cap + + model.forward = types.MethodType(_clean_forward, model) + + +def _export_mlx(model: Gemma4_31B, config: Gemma4_31BConfig, output_dir: str) -> None: + """Export to .pte via torch.export + MLX backend. + + Unlike CUDA (which exports separate decode/prefill methods with an + Int4Tensor dispatch override), MLX uses a single method with dynamic + sequence length. No int4_dispatch import — IntxUnpackedToInt8Tensor's + default dispatch produces the ``dequantize_affine → linear`` pattern + that MLX's QuantizedLinearHandler matches. 
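+
+    The exported program exposes a single ``forward(tokens, input_pos)``
+    method: tokens ``(1, T)`` int64 and input_pos ``(T,)`` int64 in,
+    ``(1, T, vocab_size)`` float logits out, with T dynamic in
+    ``[1, min(max_seq_len - 1, 2 * sliding_window)]``.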
+ """ + import gc + + from executorch.backends.mlx import MLXPartitioner + from executorch.backends.mlx.passes import get_default_passes + from executorch.exir import ( + EdgeCompileConfig, + ExecutorchBackendConfig, + to_edge_transform_and_lower, + ) + from executorch.exir.passes import MemoryPlanningPass + from torch.export import Dim, export + + _strip_sampler_from_forward(model) + materialize_runtime_buffers(model, dtype=torch.bfloat16) + + max_prefill = min(config.max_seq_len - 1, config.sliding_window * 2) + seq_dim = Dim("seq_len", min=1, max=max_prefill) + + print(f"Exporting (T in [1, {max_prefill}])...") + with torch.no_grad(): + exported = export( + model, + ( + torch.tensor([[0, 1]], dtype=torch.long), + torch.tensor([0, 1], dtype=torch.long), + ), + dynamic_shapes=({1: seq_dim}, {0: seq_dim}), + strict=True, + ) + + del model + gc.collect() + + print("Lowering to ExecuTorch with MLX backend...") + et_prog = to_edge_transform_and_lower( + exported, + transform_passes=get_default_passes(), + partitioner=[MLXPartitioner()], + compile_config=EdgeCompileConfig( + _check_ir_validity=False, + _skip_dim_order=True, + ), + constant_methods={ + "get_max_seq_len": config.max_seq_len, + "get_vocab_size": config.vocab_size, + "get_n_layers": config.num_hidden_layers, + "get_max_prefill_chunk": max_prefill, + "use_kv_cache": True, + "use_sdpa_with_kv_cache": False, + "enable_dynamic_shape": True, + }, + ) + + del exported + gc.collect() + + et_program = et_prog.to_executorch( + config=ExecutorchBackendConfig( + extract_delegate_segments=True, + memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False), + ), + ) + + del et_prog + gc.collect() + + os.makedirs(output_dir, exist_ok=True) + pte_path = os.path.join(output_dir, "model.pte") + print(f"Saving to {pte_path}...") + with open(pte_path, "wb") as f: + et_program.write_to_file(f) + print(f" {os.path.getsize(pte_path) / 1024**2:.1f} MB") + + if et_program._tensor_data: + et_program.write_tensor_data_to_file(output_dir) + print(f" Saved tensor data (.ptd) to {output_dir}/") + print("Done.") + + # --------------------------------------------------------------------------- # CLI @@ -302,7 +433,7 @@ def main() -> None: parser.add_argument( "--backend", default="cuda", - choices=["cuda"], + choices=list(_SUPPORTED_BACKENDS), help="Target backend for export.", ) args = parser.parse_args() diff --git a/examples/models/gemma4_31b/main.cpp b/examples/models/gemma4_31b/main.cpp index 0be2fef517c..0aafa9b384e 100644 --- a/examples/models/gemma4_31b/main.cpp +++ b/examples/models/gemma4_31b/main.cpp @@ -6,12 +6,12 @@ * LICENSE file in the root directory of this source tree. */ -// Gemma 4 31B-IT runner for the CUDA ExecuTorch backend. +// Gemma 4 31B-IT runner for ExecuTorch. // -// Drives the prefill + decode methods produced by export.py. -// The exported model performs Gumbel-max sampling on-device and returns a -// single float token ID per call, so this runner only has to feed tokens -// in and decode them via the HuggingFace tokenizer. +// Supports two backends: +// CUDA — two methods (prefill + decode), on-device Gumbel-max sampling, +// temperature as a third input. +// MLX — single "forward" method, returns logits, host-side greedy argmax. #include @@ -78,6 +78,7 @@ using ::executorch::runtime::EValue; using SizesType = executorch::aten::SizesType; +// Read a sampled token ID from a scalar float output (CUDA path). 
static uint64_t read_token(const executorch::aten::Tensor& output) { const void* ptr = output.const_data_ptr(); float val = 0.0f; @@ -106,6 +107,25 @@ static uint64_t read_token(const executorch::aten::Tensor& output) { return static_cast<uint64_t>(llrintf(val)); } +// Greedy argmax over the last token's logits (MLX path). +static uint64_t argmax_last_token(const executorch::aten::Tensor& logits) { + int32_t ndim = logits.dim(); + int64_t V = logits.size(ndim - 1); + int64_t T = (ndim >= 2) ? logits.size(ndim - 2) : 1; + const float* data = logits.const_data_ptr<float>(); + const float* last = data + (T - 1) * V; + + uint64_t best = 0; + float best_val = last[0]; + for (int64_t i = 1; i < V; i++) { + if (last[i] > best_val) { + best_val = last[i]; + best = static_cast<uint64_t>(i); + } + } + return best; +} + int main(int argc, char** argv) { gflags::ParseCommandLineFlags(&argc, &argv, true); @@ -139,8 +159,7 @@ int main(int argc, char** argv) { return 1; } - // Module: share_memory_arenas=true so prefill and decode see the same - // KV-cache memory (we exported with share_mutable_buffers=True). + // Module std::vector<std::string> data_files; if (!FLAGS_data_path.empty()) { data_files.push_back(FLAGS_data_path); @@ -161,6 +180,16 @@ return 1; } + int64_t max_prefill_chunk = (*metadata_result)[llm::kMaxSeqLen] - 1; + { + auto get_result = module->get("get_max_prefill_chunk"); + if (get_result.ok()) { + max_prefill_chunk = get_result->toScalar().to<int64_t>(); + } + } + + auto S = [](int64_t v) -> SizesType { return static_cast<SizesType>(v); }; + #ifdef EXECUTORCH_BUILD_CUDA if (FLAGS_cuda_graph) { executorch::runtime::BackendOptions<2> cuda_opts; @@ -168,38 +197,11 @@ executorch::runtime::set_option("CudaBackend", cuda_opts.view()); printf("CUDA graph enabled for decode method\n"); } - - // Cross-method per-FQN weight sharing: prefill + decode share the same - // weight tensors and (more importantly) the same KV-cache buffers, so - // without this flag we would allocate them twice. MUST be set before - // load_method. { executorch::runtime::BackendOptions<1> backend_options; - auto set_err = - backend_options.set_option("weight_sharing_across_methods", true); - if (set_err != Error::Ok) { - ET_LOG( - Error, - "Failed to construct weight_sharing_across_methods option: %d", - static_cast<int>(set_err)); - return 1; - } - auto opt_err = - executorch::runtime::set_option("CudaBackend", backend_options.view()); - if (opt_err != Error::Ok) { - ET_LOG( - Error, - "Failed to enable weight_sharing_across_methods: %d", - static_cast<int>(opt_err)); - return 1; - } - } -#else - if (FLAGS_cuda_graph) { - ET_LOG(Info, "--cuda_graph ignored on non-CUDA build"); + backend_options.set_option("weight_sharing_across_methods", true); + executorch::runtime::set_option("CudaBackend", backend_options.view()); } -#endif - printf("Loading methods...\n"); if (module->load_method("prefill") != Error::Ok) { ET_LOG(Error, "Failed to load prefill method"); @@ -209,6 +211,18 @@ ET_LOG(Error, "Failed to load decode method"); return 1; } + float temp_val = + FLAGS_temperature <= 0.0 ?
1e-6f : static_cast<float>(FLAGS_temperature); + auto temp_tensor = + from_blob(&temp_val, {1}, executorch::aten::ScalarType::Float); +#else + printf("Loading model...\n"); + if (module->load_method("forward") != Error::Ok) { + ET_LOG(Error, "Failed to load forward method"); + return 1; + } +#endif + stats.model_load_end_ms = llm::time_in_ms(); #ifdef EXECUTORCH_BUILD_CUDA @@ -219,7 +233,7 @@ auto eos_ids = llm::get_eos_ids(tokenizer.get(), module.get()); eos_ids.insert(static_cast<uint64_t>(FLAGS_eos_id)); - // Read prompt from file or flag + // Read prompt std::string prompt_text = FLAGS_prompt; if (!FLAGS_prompt_file.empty()) { std::ifstream f(FLAGS_prompt_file); @@ -232,7 +246,6 @@ (std::istreambuf_iterator<char>(f)), std::istreambuf_iterator<char>()); } - // Encode prompt auto encode_result = tokenizer->encode(prompt_text); if (!encode_result.ok()) { ET_LOG(Error, "Failed to encode prompt"); @@ -248,38 +261,15 @@ stats.inference_start_ms = llm::time_in_ms(); - auto S = [](int64_t v) -> SizesType { return static_cast<SizesType>(v); }; - -#ifdef EXECUTORCH_BUILD_CUDA - // CUDA build: model fuses the sampler. Pass temperature as a third input. - float temp_val = - FLAGS_temperature <= 0.0 ? 1e-6f : static_cast<float>(FLAGS_temperature); - auto temp_tensor = - from_blob(&temp_val, {1}, executorch::aten::ScalarType::Float); -#endif - // --------------------------------------------------------------- // Prefill (chunked to respect ring-buffer KV cache limit) // --------------------------------------------------------------- - // Sliding layers use a ring buffer sized to 2×sliding_window. A single - // prefill call must not exceed this size, otherwise index_copy_ with - // wrapped indices produces non-deterministic results on CUDA. - int64_t max_prefill_chunk = (*metadata_result)[llm::kMaxSeqLen] - 1; - { - auto get_result = module->get("get_max_prefill_chunk"); - if (get_result.ok()) { - max_prefill_chunk = get_result->toScalar().to<int64_t>(); - } - } - uint64_t cur_token = 0; int64_t prefill_pos = 0; while (prefill_pos < num_prompt_tokens) { int64_t chunk_len = std::min(num_prompt_tokens - prefill_pos, max_prefill_chunk); - std::string run_method = (chunk_len == 1) ? "decode" : "prefill"; - std::vector<int64_t> token_data( prompt_tokens.begin() + prefill_pos, prompt_tokens.begin() + prefill_pos + chunk_len); @@ -294,20 +284,29 @@ auto pos_tensor = from_blob( pos_data.data(), {S(chunk_len)}, executorch::aten::ScalarType::Long); - std::vector<EValue> prefill_inputs; - prefill_inputs.push_back(EValue(tokens_tensor)); - prefill_inputs.push_back(EValue(pos_tensor)); + std::vector<EValue> inputs; + inputs.push_back(EValue(tokens_tensor)); + inputs.push_back(EValue(pos_tensor)); + #ifdef EXECUTORCH_BUILD_CUDA - prefill_inputs.push_back(EValue(temp_tensor)); + inputs.push_back(EValue(temp_tensor)); + std::string method = (chunk_len == 1) ?
"decode" : "prefill"; +#else + std::string method = "forward"; #endif - auto prefill_result = module->execute(run_method, prefill_inputs); - if (prefill_result.error() != Error::Ok) { - ET_LOG( - Error, "%s failed at pos %" PRId64, run_method.c_str(), prefill_pos); + auto result = module->execute(method, inputs); + if (result.error() != Error::Ok) { + ET_LOG(Error, "%s failed at pos %" PRId64, method.c_str(), prefill_pos); return 1; } - cur_token = read_token(prefill_result.get()[0].toTensor()); + +#ifdef EXECUTORCH_BUILD_CUDA + cur_token = read_token(result.get()[0].toTensor()); +#else + cur_token = argmax_last_token(result.get()[0].toTensor()); +#endif + prefill_pos += chunk_len; } @@ -321,9 +320,6 @@ num_prompt_tokens * 1000.0 / prefill_ms); #ifdef EXECUTORCH_BUILD_CUDA - // Synchronize CUDA device to ensure prefill's writes to shared mutable - // buffers (KV cache) are visible to the decode method, which may run on - // a different CUDA stream. cudaDeviceSynchronize(); #endif @@ -343,21 +339,28 @@ decode_token_data[0] = static_cast<int64_t>(cur_token); decode_pos_data[0] = pos; - std::vector<EValue> decode_inputs; - decode_inputs.push_back(EValue(decode_tokens)); - decode_inputs.push_back(EValue(decode_pos)); + std::vector<EValue> inputs; + inputs.push_back(EValue(decode_tokens)); + inputs.push_back(EValue(decode_pos)); + #ifdef EXECUTORCH_BUILD_CUDA - decode_inputs.push_back(EValue(temp_tensor)); + inputs.push_back(EValue(temp_tensor)); + auto result = module->execute("decode", inputs); +#else + auto result = module->execute("forward", inputs); #endif - auto decode_result = module->execute("decode", decode_inputs); - if (decode_result.error() != Error::Ok) { + if (result.error() != Error::Ok) { ET_LOG(Error, "Decode step %d failed", step); return 1; } prev_token = cur_token; - cur_token = read_token(decode_result.get()[0].toTensor()); +#ifdef EXECUTORCH_BUILD_CUDA + cur_token = read_token(result.get()[0].toTensor()); +#else + cur_token = argmax_last_token(result.get()[0].toTensor()); +#endif if (step == 0) { stats.first_token_ms = llm::time_in_ms(); diff --git a/examples/models/gemma4_31b/model.md b/examples/models/gemma4_31b/model.md index 8233b6d430e..65bd2c412b7 100644 --- a/examples/models/gemma4_31b/model.md +++ b/examples/models/gemma4_31b/model.md @@ -102,6 +102,8 @@ Decoder norms per layer: `input_layernorm`, `post_attention_layernorm`, ## Methods exported (`export.py`) +### CUDA (`--backend cuda`) + | Method | Input | Output (sampled) | |-----------|------------------------------------------------------------|------------------| | `decode` | tokens `(1, 1)` + input_pos `(1,)` + temperature `(1,)` | `(1, 1)` float | @@ -113,6 +115,20 @@ Both methods share the same KV-cache buffers via sampling on-device and returns a single token ID per call so the C++ runner only has to feed tokens. +### MLX (`--backend mlx`) + +| Method | Input | Output | |-----------|------------------------------------------|------------------| +| `forward` | tokens `(1, T)` + input_pos `(T,)`, T∈[1, min(max_seq_len-1, 2×sliding_window)] | `(1, T, V)` logits | + +Single method with dynamic sequence length. No on-device sampling — the +C++ runner performs greedy argmax on the host. Int4Tensor weights are +converted to IntxUnpackedToInt8Tensor at pack time so the default +`dequantize_affine → linear` dispatch produces the pattern MLX's +`QuantizedLinearHandler` fuses into `QuantizedMatmulNode`.
+ +### Shared + Prefill length is capped to the ring-buffer KV cache size (`2 × sliding_window`) to avoid duplicate wrapped indices in `index_copy_`. The C++ runner chunks longer prompts automatically using @@ -130,9 +146,11 @@ Modules in `quant/`: `IntxUnpackedToInt8Tensor`) from fp weights. - **Serialization**: callers use torchao's safetensors integration (`torchao.prototype.safetensors`) directly — no wrapper module needed. -- **Pack** (`pack.py` + `pack_cuda.py`): `pack_model` groups weights by - parent module, `pack_one` handles single weights. Per-module packers - dispatch by module type (`nn.Linear`, `nn.Embedding`, extensible for MoE). +- **Pack** (`pack.py` + `pack_cuda.py` + `pack_mlx.py`): `pack_model` groups + weights by parent module, `pack_one` handles single weights. Per-module + packers dispatch by module type (`nn.Linear`, `nn.Embedding`). CUDA passes + Int4Tensor through (dispatch handled by `int4_dispatch.py`); MLX converts + Int4Tensor → IntxUnpackedToInt8Tensor and regroups per-axis embeddings. - **GGUF** (`gguf.py`): `unpack_gguf_tensor` / `iter_gguf_tensors` for loading community-quantized GGUF files (Q4_K, Q6_K). @@ -145,11 +163,12 @@ quantize_and_save.py export.py / inference.py | | quantize_weight() load (torchao safetensors) | | - Int4Tensor / IntxUnpacked Int4Tensor / IntxUnpacked (used directly) - | | - save (torchao safetensors) int4_dispatch routes to int4_plain_mm + Int4Tensor / IntxUnpacked pack for backend: | | - model.safetensors dp4a decode / dequant+cuBLAS prefill + save (torchao safetensors) CUDA: Int4Tensor passed through + | → int4_dispatch → dp4a / dequant+cuBLAS + model.safetensors MLX: Int4Tensor → IntxUnpacked(int4) + → dequantize_affine → QuantizedMatmulNode ``` `embed_tokens` and `lm_head` start tied; they are untied before diff --git a/examples/models/gemma4_31b/quant/README.md b/examples/models/gemma4_31b/quant/README.md index 31b1c43d574..2eacced4387 100644 --- a/examples/models/gemma4_31b/quant/README.md +++ b/examples/models/gemma4_31b/quant/README.md @@ -9,7 +9,8 @@ Quantization framework: **recipe → quantize → pack**. | `recipe.py` | **Policy** — what to quantize, what precision, which layers | nothing | | `quantize.py` | **Computation** — produces torchao subclass tensors | recipe, torchao | | `pack.py` | **Packing dispatch** — `pack_model` (bulk) and `pack_one` (streaming) | — | -| `pack_cuda.py` | **CUDA packing** — converts Int4Tensor to tinygemm format | pack | +| `pack_cuda.py` | **CUDA packing** — passes Int4Tensor/IntxUnpacked through for CUDA dispatch | pack | +| `pack_mlx.py` | **MLX packing** — converts Int4Tensor → IntxUnpacked, regroups per-axis embeddings | pack | | `gguf.py` | **GGUF import** — unpacks Q4_K/Q6_K blocks to torchao subclasses | torchao | ## Data flow @@ -48,7 +49,6 @@ The format is compatible with torchao's `save_pretrained` / `load_pretrained`. ## TODO - `pack_metal.py` — Metal backend packer. -- `pack_mlx.py` — MLX backend packer. - `gguf.py` — extend with Q5_K, Q8_0 GGUF quant types. - Upstream `Int4TilePackedTo4dTensor.from_int4_tensor()` to torchao to replace the manual conversion in `pack_int4_for_cuda`. 
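+
+## Example: packing for MLX
+
+A minimal sketch of the recipe → quantize → pack flow targeting MLX. The toy
+module and rule below are illustrative only; the real export flow streams a
+saved checkpoint through `load_and_pack_for_mlx` instead, which performs the
+same Int4Tensor → IntxUnpackedToInt8Tensor conversion one tensor at a time.
+
+```python
+import torch
+import torch.nn as nn
+
+from executorch.examples.models.gemma4_31b.quant import (
+    DEFAULT_MLX_PACKERS,
+    QuantConfig,
+    QuantRecipe,
+    QuantRule,
+    pack_model,
+    quantize_model,
+)
+
+# Toy module standing in for a transformer projection (illustrative only).
+model = nn.ModuleDict(
+    {"proj": nn.Linear(128, 64, bias=False, dtype=torch.bfloat16)}
+)
+
+recipe = QuantRecipe(
+    rules=[
+        QuantRule(
+            r".*\.weight",
+            QuantConfig(bits=4, group_size=32, symmetric=True, method="min_max"),
+        )
+    ]
+)
+
+state_dict = quantize_model(model, recipe)          # torchao subclass tensors
+pack_model(model, state_dict, DEFAULT_MLX_PACKERS)  # Int4Tensor → IntxUnpackedToInt8Tensor
+```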
diff --git a/examples/models/gemma4_31b/quant/__init__.py b/examples/models/gemma4_31b/quant/__init__.py index 93efb69865f..7e9ab97a1bb 100644 --- a/examples/models/gemma4_31b/quant/__init__.py +++ b/examples/models/gemma4_31b/quant/__init__.py @@ -6,5 +6,6 @@ from .pack import ModulePackerFn, pack_model, pack_one # noqa: F401 from .pack_cuda import DEFAULT_CUDA_PACKERS, load_and_pack_for_cuda # noqa: F401 +from .pack_mlx import DEFAULT_MLX_PACKERS, load_and_pack_for_mlx # noqa: F401 from .quantize import dequantize_weight, quantize_model, quantize_weight # noqa: F401 from .recipe import QuantConfig, QuantRecipe, QuantRule # noqa: F401 diff --git a/examples/models/gemma4_31b/quant/pack_mlx.py b/examples/models/gemma4_31b/quant/pack_mlx.py new file mode 100644 index 00000000000..8587e8e4f9d --- /dev/null +++ b/examples/models/gemma4_31b/quant/pack_mlx.py @@ -0,0 +1,215 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""MLX packer: convert quantized weights to MLX-compatible format. + +MLX's ``QuantizedLinearHandler`` matches ``dequantize_affine → linear`` +in the exported graph. ``IntxUnpackedToInt8Tensor`` produces this +pattern naturally, but ``Int4Tensor`` does not (its dispatch calls +CUDA-specific mslk kernels). So INT4 weights are converted to +``IntxUnpackedToInt8Tensor(target_dtype=torch.int4)`` at pack time. + +The backend-agnostic ``pack_model`` dispatcher lives in ``pack.py``. +""" + +import json + +import torch +import torch.nn as nn + +from .pack import ModulePackerFn, pack_model # noqa: F401 + +_MLX_SUPPORTED_GROUP_SIZES = (128, 64, 32) + + +# --------------------------------------------------------------------------- +# Int4Tensor → IntxUnpackedToInt8Tensor conversion + + +def _int4_to_intx_unpacked(w: torch.Tensor) -> torch.Tensor: + """Convert an ``Int4Tensor`` to ``IntxUnpackedToInt8Tensor``. + + Int4Tensor stores qdata as nibble-packed uint8 ``(N, K/2)`` with + scale/zero transposed to ``(K//gs, N)``. IntxUnpackedToInt8Tensor + stores qdata as int8 ``(N, K)`` with scale/zero as ``(N, K//gs)``. + """ + from torchao.quantization import IntxUnpackedToInt8Tensor + + # Unpack nibbles: packed = even | (odd << 4), unsigned [0, 15] + p = w.qdata.to(torch.uint8) + low = (p & 0x0F).to(torch.int8) + high = ((p >> 4) & 0x0F).to(torch.int8) + qdata = torch.stack([low, high], dim=-1).reshape(w.shape) + + # Shift unsigned [0, 15] → signed [-8, 7] + qdata = qdata - 8 + + gs = w.block_size[-1] + + # Transpose scale/zero from (K//gs, N) → (N, K//gs) + scale = w.scale.t().contiguous() + zero_point = (w.zero_point - 8).t().contiguous() + + return IntxUnpackedToInt8Tensor( + qdata=qdata, + scale=scale, + zero_point=zero_point, + target_dtype=torch.int4, + block_size=(1, gs), + dtype=scale.dtype, + activation_quantization=None, + ) + + +# --------------------------------------------------------------------------- +# Embedding group_size regrouping + + +def _mlx_group_size(gs: int, K: int) -> int: + """Find an MLX-compatible group_size for the given weight group_size. + + If ``gs`` is already in {32, 64, 128}, return it. Otherwise find the + largest supported group_size that divides ``gs`` so per-axis scales can + be repeated to fill finer groups. 
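+
+    For example, the per-axis embedding quantization (group_size equal to the
+    hidden size, 5376) regroups to 128, since 5376 = 42 × 128; each original
+    scale is then repeated 42 times by ``_regroup_intx``::
+
+        >>> _mlx_group_size(5376, 5376)
+        128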
+ """ + if gs in _MLX_SUPPORTED_GROUP_SIZES: + return gs + for candidate in _MLX_SUPPORTED_GROUP_SIZES: + if gs % candidate == 0 and K % candidate == 0: + return candidate + raise ValueError( + f"MLX requires group_size in {set(_MLX_SUPPORTED_GROUP_SIZES)} " + f"(or a multiple thereof), got {gs}" + ) + + +def _regroup_intx(w: torch.Tensor, new_gs: int) -> torch.Tensor: + """Regroup an ``IntxUnpackedToInt8Tensor`` to a finer group_size.""" + from torchao.quantization import IntxUnpackedToInt8Tensor + + old_gs = w.block_size[-1] + repeat_factor = old_gs // new_gs + N = w.qdata.shape[0] + n_groups = w.qdata.shape[-1] // new_gs + + scale = w.scale.repeat_interleave(repeat_factor, dim=-1).reshape(N, n_groups) + zero_point = w.zero_point.repeat_interleave(repeat_factor, dim=-1).reshape( + N, n_groups + ) + + return IntxUnpackedToInt8Tensor( + qdata=w.qdata, + scale=scale, + zero_point=zero_point, + target_dtype=w.target_dtype, + block_size=(1, new_gs), + dtype=w.dtype, + activation_quantization=w.activation_quantization, + ) + + +# --------------------------------------------------------------------------- +# Per-module packers + + +def pack_linear_for_mlx(module: nn.Module, weights: dict[str, torch.Tensor]) -> None: + """Pack a quantized ``nn.Linear`` for MLX. + + ``IntxUnpackedToInt8Tensor`` passes through (already produces the + ``dequantize_affine → linear`` pattern MLX expects). + ``Int4Tensor`` is converted to ``IntxUnpackedToInt8Tensor`` so the + default dispatch produces the right pattern instead of calling + CUDA-specific mslk kernels. + """ + from torchao.quantization import IntxUnpackedToInt8Tensor + from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor + + w = weights["weight"] + if isinstance(w, Int4Tensor): + w = _int4_to_intx_unpacked(w) + elif not isinstance(w, IntxUnpackedToInt8Tensor): + raise ValueError(f"Unsupported weight type: {type(w).__name__}") + module.weight = nn.Parameter(w, requires_grad=False) + + +def pack_embedding_for_mlx(module: nn.Module, weights: dict[str, torch.Tensor]) -> None: + """Pack a quantized ``nn.Embedding`` for MLX. + + Regroups to a compatible group_size when needed (e.g. per-axis + group_size=5376 → group_size=128) since MLX's ``parse_dequant_node`` + only accepts group_size in {32, 64, 128}. + """ + from torchao.quantization import IntxUnpackedToInt8Tensor + from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor + + w = weights["weight"] + if isinstance(w, Int4Tensor): + raise ValueError( + "Only 8-bit embedding quantization is supported on MLX. " + "INT4 does not implement the embedding op." + ) + if isinstance(w, IntxUnpackedToInt8Tensor): + gs = w.block_size[-1] + K = w.qdata.shape[-1] + target_gs = _mlx_group_size(gs, K) + if target_gs != gs: + w = _regroup_intx(w, target_gs) + module.weight = nn.Parameter(w, requires_grad=False) + + +DEFAULT_MLX_PACKERS: dict[type, ModulePackerFn] = { + nn.Linear: pack_linear_for_mlx, + nn.Embedding: pack_embedding_for_mlx, +} + + +# --------------------------------------------------------------------------- +# Load + pack (I/O wrapper) + + +def load_and_pack_for_mlx( + path: str, + model: nn.Module, + packers: dict[type, ModulePackerFn] | None = None, +) -> None: + """Load a quantized safetensors file and pack for MLX. + + Streams one weight at a time via torchao's safetensors support. 
+ """ + from safetensors import safe_open + from torchao.prototype.safetensors.safetensors_support import ( + unflatten_tensor_state_dict, + ) + + from .pack import pack_one + + _packers = packers or DEFAULT_MLX_PACKERS + with safe_open(path, framework="pt", device="cpu") as f: + metadata = f.metadata() + all_keys = list(f.keys()) + tensor_names = json.loads(metadata.get("tensor_names", "[]")) + + for name in tensor_names: + parts = name.rsplit(".", 1) + module_fqn = parts[0] if len(parts) > 1 else "" + weight_name = parts[-1] + prefix = ( + f"{module_fqn}._{weight_name}_" if module_fqn else f"_{weight_name}_" + ) + partial = {} + for key in all_keys: + if key.startswith(prefix) or key == name: + partial[key] = f.get_tensor(key) + result, _ = unflatten_tensor_state_dict(partial, metadata) + for fqn, value in result.items(): + pack_one(model, fqn, value, _packers) + + for fqn, p in model.named_parameters(): + if p.device.type == "meta": + raise RuntimeError( + f"Weight '{fqn}' not found in checkpoint " + f"(model/checkpoint version mismatch?)" + ) diff --git a/examples/models/gemma4_31b/quant/tests/test_pack_mlx.py b/examples/models/gemma4_31b/quant/tests/test_pack_mlx.py new file mode 100644 index 00000000000..6aef99b0e75 --- /dev/null +++ b/examples/models/gemma4_31b/quant/tests/test_pack_mlx.py @@ -0,0 +1,199 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Unit tests for quant/pack_mlx.py. No CUDA or MLX hardware required.""" + +import unittest + +import torch +import torch.nn as nn + +from executorch.examples.models.gemma4_31b.quant.pack import pack_model +from executorch.examples.models.gemma4_31b.quant.pack_mlx import ( + _int4_to_intx_unpacked, + _mlx_group_size, + DEFAULT_MLX_PACKERS, + pack_embedding_for_mlx, + pack_linear_for_mlx, +) +from executorch.examples.models.gemma4_31b.quant.quantize import ( + dequantize_weight, + quantize_weight, +) +from executorch.examples.models.gemma4_31b.quant.recipe import QuantConfig + + +class TestInt4ToIntxConversion(unittest.TestCase): + """Int4Tensor → IntxUnpackedToInt8Tensor conversion.""" + + def test_symmetric_dequant_matches(self): + """Converted weight dequantizes to same values as original.""" + torch.manual_seed(0) + weight = torch.randn(64, 128, dtype=torch.bfloat16) + config = QuantConfig(bits=4, group_size=32, symmetric=True, method="min_max") + int4_w = quantize_weight(weight, config) + intx_w = _int4_to_intx_unpacked(int4_w) + + int4_dense = dequantize_weight(int4_w, torch.float32) + intx_dense = dequantize_weight(intx_w, torch.float32) + self.assertTrue( + torch.allclose(int4_dense, intx_dense, atol=1e-5), + f"max diff: {(int4_dense - intx_dense).abs().max():.6g}", + ) + + def test_asymmetric_dequant_matches(self): + torch.manual_seed(0) + weight = torch.randn(64, 128, dtype=torch.bfloat16) + config = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") + int4_w = quantize_weight(weight, config) + intx_w = _int4_to_intx_unpacked(int4_w) + + int4_dense = dequantize_weight(int4_w, torch.float32) + intx_dense = dequantize_weight(intx_w, torch.float32) + self.assertTrue( + torch.allclose(int4_dense, intx_dense, atol=1e-5), + f"max diff: {(int4_dense - intx_dense).abs().max():.6g}", + ) + + def test_output_type_and_shape(self): + from torchao.quantization import IntxUnpackedToInt8Tensor + + torch.manual_seed(0) + config = QuantConfig(bits=4, 
group_size=32, symmetric=True, method="min_max") + int4_w = quantize_weight(torch.randn(128, 256, dtype=torch.bfloat16), config) + intx_w = _int4_to_intx_unpacked(int4_w) + + self.assertIsInstance(intx_w, IntxUnpackedToInt8Tensor) + self.assertEqual(intx_w.shape, torch.Size([128, 256])) + self.assertEqual(intx_w.qdata.shape, torch.Size([128, 256])) + self.assertEqual(intx_w.target_dtype, torch.int4) + + def test_different_group_sizes(self): + torch.manual_seed(0) + for gs in (32, 64, 128): + with self.subTest(group_size=gs): + config = QuantConfig( + bits=4, group_size=gs, symmetric=True, method="min_max" + ) + int4_w = quantize_weight( + torch.randn(64, 256, dtype=torch.bfloat16), config + ) + intx_w = _int4_to_intx_unpacked(int4_w) + self.assertEqual(intx_w.shape, torch.Size([64, 256])) + + def test_matmul_approximates_original(self): + torch.manual_seed(0) + weight = torch.randn(256, 128, dtype=torch.bfloat16) + x = torch.randn(1, 128, dtype=torch.bfloat16) + original_out = torch.nn.functional.linear(x, weight) + + config = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") + int4_w = quantize_weight(weight, config) + intx_w = _int4_to_intx_unpacked(int4_w) + packed_out = torch.nn.functional.linear(x, intx_w.dequantize()) + + rel_error = ( + packed_out.float() - original_out.float() + ).abs().mean() / original_out.float().abs().mean() + self.assertLess(rel_error.item(), 0.15) + + +class TestPackLinearForMlx(unittest.TestCase): + def test_int4_converts_to_intx(self): + from torchao.quantization import IntxUnpackedToInt8Tensor + + module = nn.Linear(128, 64, bias=False) + config = QuantConfig(bits=4, group_size=32, symmetric=True, method="min_max") + w = quantize_weight(torch.randn(64, 128, dtype=torch.bfloat16), config) + pack_linear_for_mlx(module, {"weight": w}) + + self.assertIsInstance(module.weight.data, IntxUnpackedToInt8Tensor) + self.assertEqual(module.weight.shape, torch.Size([64, 128])) + self.assertFalse(module.weight.requires_grad) + + def test_int8_passes_through(self): + from torchao.quantization import IntxUnpackedToInt8Tensor + + module = nn.Linear(128, 64, bias=False) + config = QuantConfig(bits=8, group_size=32, symmetric=True, method="min_max") + w = quantize_weight(torch.randn(64, 128, dtype=torch.bfloat16), config) + self.assertIsInstance(w, IntxUnpackedToInt8Tensor) + pack_linear_for_mlx(module, {"weight": w}) + + self.assertIsInstance(module.weight.data, IntxUnpackedToInt8Tensor) + self.assertEqual(module.weight.shape, torch.Size([64, 128])) + + +class TestMlxGroupSize(unittest.TestCase): + def test_passthrough(self): + for gs in (32, 64, 128): + self.assertEqual(_mlx_group_size(gs, 256), gs) + + def test_regroup_5376(self): + self.assertEqual(_mlx_group_size(5376, 5376), 128) + + def test_regroup_256(self): + self.assertEqual(_mlx_group_size(256, 256), 128) + + def test_rejects_indivisible(self): + with self.assertRaises(ValueError): + _mlx_group_size(48, 48) + + +class TestPackEmbeddingForMlx(unittest.TestCase): + def test_compatible_passes_through(self): + module = nn.Embedding(100, 64) + config = QuantConfig(bits=8, group_size=32, symmetric=True, method="min_max") + w = quantize_weight(torch.randn(100, 64, dtype=torch.bfloat16), config) + pack_embedding_for_mlx(module, {"weight": w}) + self.assertEqual(module.weight.shape, torch.Size([100, 64])) + + def test_per_axis_regroups(self): + module = nn.Embedding(50, 256) + config = QuantConfig(bits=8, group_size=256, symmetric=True, method="min_max") + w = quantize_weight(torch.randn(50, 256, 
dtype=torch.bfloat16), config) + pack_embedding_for_mlx(module, {"weight": w}) + self.assertEqual(module.weight.shape, torch.Size([50, 256])) + self.assertEqual(module.weight.data.block_size, (1, 128)) + + def test_rejects_int4(self): + module = nn.Embedding(100, 64) + config = QuantConfig(bits=4, group_size=32, symmetric=True, method="min_max") + w = quantize_weight(torch.randn(100, 64, dtype=torch.bfloat16), config) + with self.assertRaises(ValueError): + pack_embedding_for_mlx(module, {"weight": w}) + + +class TestPackModelMlx(unittest.TestCase): + def test_mixed_precision(self): + q4 = QuantConfig(bits=4, group_size=32, symmetric=True, method="min_max") + q8 = QuantConfig(bits=8, group_size=32, symmetric=True, method="min_max") + w4 = quantize_weight(torch.randn(64, 128, dtype=torch.bfloat16), q4) + w8 = quantize_weight(torch.randn(64, 128, dtype=torch.bfloat16), q8) + + state_dict = { + "q_proj.weight": w4, + "v_proj.weight": w8, + "norm.weight": torch.randn(64, dtype=torch.bfloat16), + } + + with torch.device("meta"): + model = nn.ModuleDict( + { + "q_proj": nn.Linear(128, 64, bias=False), + "v_proj": nn.Linear(128, 64, bias=False), + "norm": nn.LayerNorm(64, bias=False), + } + ) + pack_model(model, state_dict, DEFAULT_MLX_PACKERS) + + self.assertEqual(model.q_proj.weight.shape, torch.Size([64, 128])) + self.assertEqual(model.v_proj.weight.shape, torch.Size([64, 128])) + self.assertEqual(model.norm.weight.shape, torch.Size([64])) + + +if __name__ == "__main__": + unittest.main() diff --git a/examples/models/gemma4_31b/tests/test_mlx_pipeline.py b/examples/models/gemma4_31b/tests/test_mlx_pipeline.py new file mode 100644 index 00000000000..c766270c65b --- /dev/null +++ b/examples/models/gemma4_31b/tests/test_mlx_pipeline.py @@ -0,0 +1,154 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""End-to-end MLX backend tests for the Gemma 4 31B-IT pipeline. + +Tests quantize → save → load → pack-for-MLX on a tiny model. +No CUDA or MLX hardware required. 
+ +Usage: + python -m pytest examples/models/gemma4_31b/tests/test_mlx_pipeline.py -v +""" + +import json +import os +import tempfile +import unittest + +import torch +import torch.nn as nn + +from executorch.examples.models.gemma4_31b.model import Gemma4_31B +from executorch.examples.models.gemma4_31b.quant import ( + DEFAULT_MLX_PACKERS, + pack_model, + QuantConfig, + quantize_model, + QuantRecipe, + QuantRule, +) +from executorch.examples.models.gemma4_31b.tests.test_pipeline import ( + build_random_tiny_model, + config_dict, + save_checkpoint, + TINY_CONFIG, +) + +_INT4 = QuantConfig(bits=4, group_size=32, symmetric=True, method="min_max") +_INT8 = QuantConfig(bits=8, group_size=32, symmetric=True, method="min_max") +_INT8_PER_AXIS = QuantConfig( + bits=8, group_size=TINY_CONFIG.hidden_size, symmetric=True, method="min_max" +) +_EDGE_LAYERS = set(range(3)) + +TINY_SENSITIVE_RECIPE = QuantRecipe( + rules=[ + QuantRule(r"embed_tokens\.weight", _INT8_PER_AXIS), + QuantRule(r".*norm\.weight", None), + QuantRule(r".*\.(v_proj|down_proj)\.weight", _INT8, layers=_EDGE_LAYERS), + QuantRule(r".*\.weight", _INT4), + ] +) + + +class TestMlxPipeline(unittest.TestCase): + """End-to-end: quantize → pack for MLX → forward.""" + + def test_pack_for_mlx(self): + """Quantize with sensitive recipe, pack for MLX, no meta weights.""" + model = build_random_tiny_model() + model.lm_head.weight = nn.Parameter(model.embed_tokens.weight.clone()) + state_dict = quantize_model(model, TINY_SENSITIVE_RECIPE) + + with torch.device("meta"): + model = Gemma4_31B(TINY_CONFIG) + model.lm_head.weight = nn.Parameter(model.embed_tokens.weight.clone()) + pack_model(model, state_dict, DEFAULT_MLX_PACKERS) + + for fqn, p in model.named_parameters(): + self.assertNotEqual(p.device.type, "meta", f"Weight '{fqn}' still on meta") + + def test_forward_after_pack(self): + """Model produces valid output after MLX packing.""" + model = build_random_tiny_model() + model.lm_head.weight = nn.Parameter(model.embed_tokens.weight.clone()) + state_dict = quantize_model(model, TINY_SENSITIVE_RECIPE) + + with torch.device("meta"): + model = Gemma4_31B(TINY_CONFIG) + model.lm_head.weight = nn.Parameter(model.embed_tokens.weight.clone()) + pack_model(model, state_dict, DEFAULT_MLX_PACKERS) + model.eval() + + from executorch.examples.models.gemma4_31b.model import ( + materialize_runtime_buffers, + ) + + materialize_runtime_buffers(model, dtype=torch.bfloat16) + + tokens = torch.randint(0, TINY_CONFIG.vocab_size, (1, 1)) + input_pos = torch.tensor([0], dtype=torch.long) + + with torch.no_grad(): + out = model(tokens, input_pos, None) + + self.assertEqual(out.shape[-1], TINY_CONFIG.vocab_size) + self.assertFalse(torch.isnan(out).any()) + self.assertFalse(torch.isinf(out).any()) + + def test_multi_token_forward(self): + model = build_random_tiny_model() + model.lm_head.weight = nn.Parameter(model.embed_tokens.weight.clone()) + state_dict = quantize_model(model, TINY_SENSITIVE_RECIPE) + + with torch.device("meta"): + model = Gemma4_31B(TINY_CONFIG) + model.lm_head.weight = nn.Parameter(model.embed_tokens.weight.clone()) + pack_model(model, state_dict, DEFAULT_MLX_PACKERS) + model.eval() + + from executorch.examples.models.gemma4_31b.model import ( + materialize_runtime_buffers, + ) + + materialize_runtime_buffers(model, dtype=torch.bfloat16) + + seq_len = 4 + tokens = torch.randint(0, TINY_CONFIG.vocab_size, (1, seq_len)) + input_pos = torch.arange(seq_len, dtype=torch.long) + + with torch.no_grad(): + out = model(tokens, input_pos, None) + + 
self.assertEqual(out.shape, torch.Size([1, seq_len, TINY_CONFIG.vocab_size])) + self.assertFalse(torch.isnan(out).any()) + + def test_export_to_pte(self): + """Full export: quantize → pack → export with MLXPartitioner.""" + try: + from executorch.backends.mlx import MLXPartitioner # noqa: F401 + except ImportError: + self.skipTest("MLX backend not available") + + from executorch.examples.models.gemma4_31b.export import ( + export_and_lower, + load_prequantized_model, + ) + + with tempfile.TemporaryDirectory() as ckpt_dir, tempfile.TemporaryDirectory() as out_dir: + save_checkpoint(ckpt_dir) + with open(os.path.join(ckpt_dir, "config.json"), "w") as f: + json.dump(config_dict(), f) + + model, config = load_prequantized_model( + ckpt_dir, max_seq_len=TINY_CONFIG.max_seq_len, backend="mlx" + ) + export_and_lower(model, config, out_dir, backend="mlx") + self.assertTrue(os.path.exists(os.path.join(out_dir, "model.pte"))) + + +if __name__ == "__main__": + unittest.main()