Add MLX backend support for Gemma 4 31B #19524
Conversation
- pack_mlx.py: converts Int4Tensor → IntxUnpackedToInt8Tensor at pack time (nibble unpack + scale transpose) so the default dispatch produces the dequantize_affine → linear pattern MLX expects. IntxUnpackedToInt8Tensor passes through unchanged. Embedding weights with an incompatible per-axis group_size are regrouped to gs=128. (A sketch of the conversion follows this list.)
- export.py: adds --backend mlx with single-method export (dynamic seq_len), sampler stripping, and MLXPartitioner lowering. No int4_dispatch import is needed; MLX uses the standard dequantize_affine path.
- main.cpp: handles both CUDA (prefill+decode, on-device sampling) and MLX (single forward method, host-side argmax) via #ifdef.
- CMakeLists.txt / CMakePresets.json / Makefile: add a gemma4_31b-mlx build target linking mlxdelegate.
- test_pack_mlx.py: 15 tests covering Int4→IntxUnpacked conversion correctness, passthrough, regrouping, and error cases.
- test_mlx_pipeline.py: 4 e2e tests, including export-to-pte.

Validated: the same CUDA-quantized checkpoint packs for both backends, 100% of ops delegate to MLX, and a real 31B checkpoint packs at 4.0 GB RSS.

PR authored with Claude.
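To make the conversion concrete, here is a minimal sketch of the nibble-unpack half of the Int4Tensor → IntxUnpackedToInt8Tensor step. `unpack_int4_nibbles` is a hypothetical helper written for this example (not the actual pack_mlx.py code), and the low-nibble-first packing order is an assumption:

```python
import torch

def unpack_int4_nibbles(packed: torch.Tensor) -> torch.Tensor:
    """Unpack two 4-bit values per byte into signed int8 values in [-8, 7]."""
    p = packed.to(torch.int16)  # widen so masks and shifts behave uniformly
    low = p & 0x0F
    high = (p >> 4) & 0x0F
    # Sign-extend 4-bit two's-complement values.
    low = torch.where(low >= 8, low - 16, low)
    high = torch.where(high >= 8, high - 16, high)
    # Assumed layout: low nibble holds the even element, high nibble the odd one.
    return torch.stack([low, high], dim=-1).flatten(-2).to(torch.int8)
```

The scale transpose is the other half of the conversion: the per-group scales are moved (e.g., `scales.t().contiguous()`) into the layout the dequantize_affine → linear pattern expects; the exact axes involved are an assumption here. Doing the conversion once at pack time is what lets export skip any custom int4 dispatch at runtime.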
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/19524
Note: Links to docs will display an error until the docs builds have been completed.
❗ 1 Active SEV. If your PR is affected, please view it below.
❌ 2 New Failures, 5 Unrelated Failures as of commit 023a5f1 with merge base bd5752a:
NEW FAILURES: the following jobs have failed.
FLAKY: the following job failed but was likely due to flakiness present on trunk.
BROKEN TRUNK: the following jobs failed but were already failing on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
```cpp
}

// Greedy argmax over the last token's logits (MLX path).
static uint64_t argmax_last_token(const executorch::aten::Tensor& logits) {
```
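Functionally, the host-side greedy step this function implements reduces to an argmax over the final sequence position of the logits. A minimal Python equivalent, as a sketch rather than the C++ above (the [batch, seq_len, vocab_size] logits shape is an assumption):

```python
import torch

def argmax_last_token(logits: torch.Tensor) -> int:
    # Greedily pick the highest-scoring token id at the last sequence position.
    return int(torch.argmax(logits[0, -1]).item())
```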
```python
if isinstance(w, Int4Tensor):
    raise ValueError(
        "Only 8-bit embedding quantization is supported on MLX. "
        "INT4 does not implement the embedding op."
    )
```
int4 embedding should work on mlx?
```python
from torchao.quantization import IntxUnpackedToInt8Tensor

old_gs = w.block_size[-1]
repeat_factor = old_gs // new_gs
```
Shouldn't we error out if old_gs doesn't divide new_gs?
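A guard along those lines could look like the sketch below. `regroup_scales` is a hypothetical helper, assuming scales are stored group-major along the last dim; note that because `repeat_factor = old_gs // new_gs`, the exact condition is that new_gs evenly divides old_gs:

```python
import torch

def regroup_scales(scales: torch.Tensor, old_gs: int, new_gs: int) -> torch.Tensor:
    # Splitting groups by repeating scales is only exact when new_gs divides old_gs.
    if old_gs % new_gs != 0:
        raise ValueError(
            f"cannot regroup group_size {old_gs} to {new_gs}: not a multiple"
        )
    repeat_factor = old_gs // new_gs
    # Each original scale now covers repeat_factor smaller groups of new_gs elements.
    return scales.repeat_interleave(repeat_factor, dim=-1)
```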