diff --git a/examples/models/llama/attention.py b/examples/models/llama/attention.py
index d43533b5a70..126333c55d6 100644
--- a/examples/models/llama/attention.py
+++ b/examples/models/llama/attention.py
@@ -9,6 +9,7 @@
 from executorch.examples.models.llama.model_args import ModelArgs
 from executorch.examples.models.llama.norm import (
     RMSNorm,
+    RMSNormCoreML,
     RMSNormGated,
     ScalelessRMSNorm,
 )
@@ -425,21 +426,29 @@ def _init_norms(self, args: ModelArgs) -> None:
         """Initialize QK normalization layers."""
         if self.use_qk_norm:
             if args.qk_norm_affine:
-                self.q_norm_fn = RMSNorm(
-                    self.head_dim,
-                    eps=args.norm_eps,
-                    add_unit_offset=args.rms_norm_add_unit_offset,
-                )
-                if self.has_kv_weights:
-                    self.k_norm_fn = RMSNorm(
+                if args.use_coreml_norm:
+                    self.q_norm_fn = RMSNormCoreML(self.head_dim, eps=args.norm_eps)
+                    if self.has_kv_weights:
+                        self.k_norm_fn = RMSNormCoreML(
+                            self.head_dim, eps=args.norm_eps
+                        )
+                else:
+                    self.q_norm_fn = RMSNorm(
                         self.head_dim,
                         eps=args.norm_eps,
                         add_unit_offset=args.rms_norm_add_unit_offset,
                     )
+                    if self.has_kv_weights:
+                        self.k_norm_fn = RMSNorm(
+                            self.head_dim,
+                            eps=args.norm_eps,
+                            add_unit_offset=args.rms_norm_add_unit_offset,
+                        )
             else:
-                self.q_norm_fn = ScalelessRMSNorm(self.head_dim, eps=args.norm_eps)
+                cls = RMSNormCoreML if args.use_coreml_norm else ScalelessRMSNorm
+                self.q_norm_fn = cls(self.head_dim, eps=args.norm_eps)
                 if self.has_kv_weights:
-                    self.k_norm_fn = ScalelessRMSNorm(self.head_dim, eps=args.norm_eps)
+                    self.k_norm_fn = cls(self.head_dim, eps=args.norm_eps)
         if self.use_attn_o_norm:
             self.o_norm = ScalelessRMSNorm(self.head_dim, eps=args.norm_eps)
         if self.use_attn_o_gate:
diff --git a/examples/models/llama/llama_transformer.py b/examples/models/llama/llama_transformer.py
index d87eef3f906..91e9fcb788c 100644
--- a/examples/models/llama/llama_transformer.py
+++ b/examples/models/llama/llama_transformer.py
@@ -23,6 +23,7 @@
 from executorch.examples.models.llama.model_args import ModelArgs
 from executorch.examples.models.llama.norm import (
     RMSNorm,
+    RMSNormCoreML,
     RMSNormWithInputScale,
     ScalelessRMSNorm,
 )
@@ -168,6 +169,8 @@ def __init__(
 
         if isinstance(self.attention, AttentionSkip):
            self.attention_norm = nn.Identity()
+        elif args.use_coreml_norm:
+            self.attention_norm = RMSNormCoreML(args.dim, eps=args.norm_eps)
         else:
             self.attention_norm = RMSNorm(
                 args.dim,
@@ -175,11 +178,14 @@
                 add_unit_offset=args.rms_norm_add_unit_offset,
             )
         if self.mlp_type != "skip":
-            self.ffn_norm = RMSNorm(
-                args.dim,
-                eps=args.norm_eps,
-                add_unit_offset=args.rms_norm_add_unit_offset,
-            )
+            if args.use_coreml_norm:
+                self.ffn_norm = RMSNormCoreML(args.dim, eps=args.norm_eps)
+            else:
+                self.ffn_norm = RMSNorm(
+                    args.dim,
+                    eps=args.norm_eps,
+                    add_unit_offset=args.rms_norm_add_unit_offset,
+                )
 
         if args.use_residual_gate:
             attn_init = 1.0 / (2 * layer_id + 1) if layer_id > 0 else 0.5
@@ -273,11 +279,14 @@ def __init__(self, params: ModelArgs, layers: nn.ModuleList, rope: Rope):
         )
         self.layers = layers
         self.rope = rope
-        self.norm = RMSNorm(
-            params.dim,
-            eps=params.norm_eps,
-            add_unit_offset=params.rms_norm_add_unit_offset,
-        )
+        if params.use_coreml_norm:
+            self.norm = RMSNormCoreML(params.dim, eps=params.norm_eps)
+        else:
+            self.norm = RMSNorm(
+                params.dim,
+                eps=params.norm_eps,
+                add_unit_offset=params.rms_norm_add_unit_offset,
+            )
         self.output = (
             nn.Linear(params.dim, params.vocab_size, bias=False)
             if self.apply_output
diff --git a/examples/models/llama/model_args.py b/examples/models/llama/model_args.py
index ed661c75517..3036b877cae 100644
--- a/examples/models/llama/model_args.py
+++ b/examples/models/llama/model_args.py
@@ -76,6 +76,10 @@ class ModelArgs:
         False  # Use q-gated projection in attention (Qwen3.5 full attention)
     )
     norm_type: str = "rmsnorm"  # Normalization type, registered in norm.py
+    # When True, swap RMSNorm for the CoreML-friendly RMSNormCoreML at every
+    # norm site. The CoreML formulation uses torch.linalg.vector_norm so the
+    # op is preserved in the CoreML graph (FP32 casts get stripped by CoreML).
+    use_coreml_norm: bool = False
     act_fn: ActFn = dataclasses.field(default=ActFn.SILU)  # Activation function type
     attention_qkv_bias: bool = False
     use_kv_cache: bool = False  # Use key/value cache
diff --git a/examples/models/llama/norm.py b/examples/models/llama/norm.py
index e424ee0361a..1cf20f2fa7f 100644
--- a/examples/models/llama/norm.py
+++ b/examples/models/llama/norm.py
@@ -57,6 +57,47 @@ def __init__(self, dim: int, eps: float = 1e-6):
         self.weight.requires_grad = False
 
 
+class RMSNormCoreML(torch.nn.Module):
+    def __init__(self, dim: int, eps: float = 1e-6):
+        """
+        CoreML-friendly RMSNorm — uses `torch.linalg.vector_norm` so the op is
+        preserved in the CoreML graph for numerical stability.
+
+        Args:
+            dim (int): The dimension of the input tensor.
+            eps (float, optional): Epsilon used to build the `sqrt(dim * eps)` floor.
+
+        Attributes:
+            eps (float): Epsilon used by `_norm` to floor the denominator.
+            weight (nn.Parameter): Learnable scaling parameter.
+        """
+        super().__init__()
+        self.dim = dim
+        self.eps = eps
+        self.weight = nn.Parameter(torch.ones(dim))
+
+    def _norm(self, x):
+        # Floor the denominator to avoid 0 / 0 = NaN on zero-padded positions
+        # (chunked prefill in StaticAttentionIOManager pads each chunk to
+        # input_len with zeros). Use sqrt(dim * eps) so the floor matches
+        # standard RMSNorm's eps semantics (`rsqrt(mean(x²) + eps)`) and is
+        # large enough to survive fp16 (1e-6 alone underflows in fp16).
+        floor_val = torch.sqrt(torch.tensor(self.dim * self.eps, dtype=x.dtype))
+        norm_val = torch.clamp_min(
+            torch.linalg.vector_norm(x, dim=-1, keepdim=True), floor_val
+        )
+        rms_norm_eps0 = (
+            x
+            * torch.sqrt(torch.tensor(self.dim, dtype=x.dtype))
+            * torch.reciprocal(norm_val)
+        )
+        return rms_norm_eps0
+
+    def forward(self, x):
+        output = self._norm(x)
+        return output * self.weight
+
+
 class RMSNormWithInputScale(torch.nn.Module):
     def __init__(self, dim: int, eps: float = 1e-5):
         super().__init__()
diff --git a/examples/models/llama/static_attention.py b/examples/models/llama/static_attention.py
index 72ce31438d6..6299d0f1902 100644
--- a/examples/models/llama/static_attention.py
+++ b/examples/models/llama/static_attention.py
@@ -15,7 +15,7 @@
 )
 from executorch.examples.models.llama.lora import LoRALinear
 from executorch.examples.models.llama.model_args import ModelArgs
-from executorch.examples.models.llama.norm import ScalelessRMSNorm
+from executorch.examples.models.llama.norm import RMSNormCoreML, ScalelessRMSNorm
 from executorch.examples.models.llama.rope import Rope
 
 
@@ -898,6 +898,14 @@ def _init_wo(self, config: ModelArgs) -> None:
 
     def _init_qk_norms(self, config: ModelArgs, is_kv_shared_layer: bool) -> None:
         if self.use_qk_norm:
+            # When use_coreml_norm is set, match the rlformers reference path
+            # which constructs q_norm/k_norm via RMSNormCoreML (vector_norm-based,
+            # no fp32 cast, eps only as a floor) instead of ScalelessRMSNorm.
+            _scaleless_cls = (
+                RMSNormCoreML
+                if getattr(config, "use_coreml_norm", False)
+                else ScalelessRMSNorm
+            )
             if getattr(config, "qk_norm_affine", True):
                 self.q_norm = torch.nn.RMSNorm(self.head_dim, config.norm_eps)
                 if is_kv_shared_layer:
@@ -905,11 +913,11 @@ def _init_qk_norms(self, config: ModelArgs, is_kv_shared_layer: bool) -> None:
                 else:
                     self.k_norm = torch.nn.RMSNorm(self.head_dim, config.norm_eps)
             else:
-                self.q_norm = ScalelessRMSNorm(self.head_dim, eps=config.norm_eps)
+                self.q_norm = _scaleless_cls(self.head_dim, eps=config.norm_eps)
                 if is_kv_shared_layer:
                     self.k_norm = nn.Identity()
                 else:
-                    self.k_norm = ScalelessRMSNorm(self.head_dim, eps=config.norm_eps)
+                    self.k_norm = _scaleless_cls(self.head_dim, eps=config.norm_eps)
         else:
             self.q_norm = torch.nn.Identity()
             self.k_norm = torch.nn.Identity()
@@ -949,6 +957,13 @@ def from_attention_mha(
             hasattr(other.q_norm_fn, "weight") if other.use_qk_norm else True
         )
 
+        # Propagate use_coreml_norm so _init_qk_norms picks RMSNormCoreML for
+        # scaleless q/k norms (matches the rlformers reference path). Detect
+        # via the rms_norm_class kwarg — `transform_attention_mha_to_static_attention`
+        # forwards it through, and the static_transformer_export caller already
+        # selects RMSNormCoreML when use_coreml_norm is set on the model args.
+        _use_coreml_norm = rms_norm_class is RMSNormCoreML
+
         config = ModelArgs(
             dim=other.dim,
             n_layers=1,  # Not used in attention layer
@@ -964,6 +979,7 @@ def from_attention_mha(
             norm_eps=other.q_norm_fn.eps if other.use_qk_norm else 1e-5,
             num_kv_shared_layers=getattr(other, "num_kv_shared_layers", 0),
             scale_query_by=getattr(other, "scale_query_by", 1.0),
+            use_coreml_norm=_use_coreml_norm,
         )
 
         instance = cls(
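
A quick numerical check of the equivalence the `_norm` comment relies on (a minimal sketch assuming only `torch`; `rmsnorm_reference` and `rmsnorm_vector_norm` below are local restatements for illustration, not the classes added in this diff). Because `||x||² = dim · mean(x²)`, the vector_norm form reduces to `x * rsqrt(mean(x²))` on nonzero rows, and the `sqrt(dim * eps)` floor is what keeps zero-padded rows finite:

import torch


def rmsnorm_reference(x, eps=1e-6):
    # Standard RMSNorm with the scale weight omitted: x * rsqrt(mean(x^2) + eps).
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)


def rmsnorm_vector_norm(x, eps=1e-6):
    # Formulation used by RMSNormCoreML in this diff:
    # x * sqrt(dim) / max(||x||_2, sqrt(dim * eps)).
    # Since ||x||_2^2 == dim * mean(x^2), this equals x * rsqrt(mean(x^2))
    # whenever the floor is inactive.
    dim = x.shape[-1]
    floor_val = torch.sqrt(torch.tensor(dim * eps, dtype=x.dtype))
    norm_val = torch.clamp_min(
        torch.linalg.vector_norm(x, dim=-1, keepdim=True), floor_val
    )
    return x * torch.sqrt(torch.tensor(dim, dtype=x.dtype)) / norm_val


x = torch.randn(4, 64)
x[0] = 0.0  # a zero-padded position, as produced by chunked-prefill padding

# Without the floor, the zero row hits 0 / 0 and turns into NaN.
unfloored = x * torch.sqrt(torch.tensor(64.0)) / torch.linalg.vector_norm(
    x, dim=-1, keepdim=True
)
assert torch.isnan(unfloored[0]).all()

# With the floor, every row stays finite and nonzero rows match standard RMSNorm.
out = rmsnorm_vector_norm(x)
ref = rmsnorm_reference(x)
assert torch.isfinite(out).all()
assert torch.allclose(out[1:], ref[1:], atol=1e-4)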