27 changes: 18 additions & 9 deletions examples/models/llama/attention.py
@@ -1,3 +1,3 @@
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Dict, Optional, Tuple, Type, TypedDict
@@ -9,6 +9,7 @@
from executorch.examples.models.llama.model_args import ModelArgs
from executorch.examples.models.llama.norm import (
RMSNorm,
RMSNormCoreML,
RMSNormGated,
ScalelessRMSNorm,
)
@@ -425,21 +426,29 @@
"""Initialize QK normalization layers."""
if self.use_qk_norm:
if args.qk_norm_affine:
self.q_norm_fn = RMSNorm(
self.head_dim,
eps=args.norm_eps,
add_unit_offset=args.rms_norm_add_unit_offset,
)
if self.has_kv_weights:
self.k_norm_fn = RMSNorm(
if args.use_coreml_norm:
self.q_norm_fn = RMSNormCoreML(self.head_dim, eps=args.norm_eps)
if self.has_kv_weights:
self.k_norm_fn = RMSNormCoreML(
self.head_dim, eps=args.norm_eps
)
else:
self.q_norm_fn = RMSNorm(
self.head_dim,
eps=args.norm_eps,
add_unit_offset=args.rms_norm_add_unit_offset,
)
if self.has_kv_weights:
self.k_norm_fn = RMSNorm(
self.head_dim,
eps=args.norm_eps,
add_unit_offset=args.rms_norm_add_unit_offset,
)
else:
self.q_norm_fn = ScalelessRMSNorm(self.head_dim, eps=args.norm_eps)
cls = RMSNormCoreML if args.use_coreml_norm else ScalelessRMSNorm
self.q_norm_fn = cls(self.head_dim, eps=args.norm_eps)
if self.has_kv_weights:
self.k_norm_fn = ScalelessRMSNorm(self.head_dim, eps=args.norm_eps)
self.k_norm_fn = cls(self.head_dim, eps=args.norm_eps)
if self.use_attn_o_norm:
self.o_norm = ScalelessRMSNorm(self.head_dim, eps=args.norm_eps)
if self.use_attn_o_gate:
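Editorial aside: the branch added above reduces to a single class-selection rule. A minimal sketch, not part of the PR; it assumes the three norm classes exported by examples/models/llama/norm.py and leaves out the add_unit_offset argument that the plain-RMSNorm call site also passes:

from executorch.examples.models.llama.norm import (
    RMSNorm,
    RMSNormCoreML,
    ScalelessRMSNorm,
)


def select_qk_norm_cls(qk_norm_affine: bool, use_coreml_norm: bool) -> type:
    # use_coreml_norm wins in both the affine and the scaleless branch;
    # otherwise affine picks RMSNorm and non-affine picks ScalelessRMSNorm.
    if use_coreml_norm:
        return RMSNormCoreML
    return RMSNorm if qk_norm_affine else ScalelessRMSNorm

With this helper, _init_qk_norms effectively calls select_qk_norm_cls(args.qk_norm_affine, args.use_coreml_norm)(self.head_dim, eps=args.norm_eps) for both q_norm_fn and k_norm_fn, plus add_unit_offset on the plain-RMSNorm path.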
29 changes: 19 additions & 10 deletions examples/models/llama/llama_transformer.py
@@ -23,6 +23,7 @@
from executorch.examples.models.llama.model_args import ModelArgs
from executorch.examples.models.llama.norm import (
RMSNorm,
RMSNormCoreML,
RMSNormWithInputScale,
ScalelessRMSNorm,
)
@@ -168,18 +169,23 @@ def __init__(

if isinstance(self.attention, AttentionSkip):
self.attention_norm = nn.Identity()
elif args.use_coreml_norm:
self.attention_norm = RMSNormCoreML(args.dim, eps=args.norm_eps)
else:
self.attention_norm = RMSNorm(
args.dim,
eps=args.norm_eps,
add_unit_offset=args.rms_norm_add_unit_offset,
)
if self.mlp_type != "skip":
self.ffn_norm = RMSNorm(
args.dim,
eps=args.norm_eps,
add_unit_offset=args.rms_norm_add_unit_offset,
)
if args.use_coreml_norm:
self.ffn_norm = RMSNormCoreML(args.dim, eps=args.norm_eps)
else:
self.ffn_norm = RMSNorm(
args.dim,
eps=args.norm_eps,
add_unit_offset=args.rms_norm_add_unit_offset,
)

if args.use_residual_gate:
attn_init = 1.0 / (2 * layer_id + 1) if layer_id > 0 else 0.5
@@ -273,11 +279,14 @@ def __init__(self, params: ModelArgs, layers: nn.ModuleList, rope: Rope):
)
self.layers = layers
self.rope = rope
self.norm = RMSNorm(
params.dim,
eps=params.norm_eps,
add_unit_offset=params.rms_norm_add_unit_offset,
)
if params.use_coreml_norm:
self.norm = RMSNormCoreML(params.dim, eps=params.norm_eps)
else:
self.norm = RMSNorm(
params.dim,
eps=params.norm_eps,
add_unit_offset=params.rms_norm_add_unit_offset,
)
self.output = (
nn.Linear(params.dim, params.vocab_size, bias=False)
if self.apply_output
4 changes: 4 additions & 0 deletions examples/models/llama/model_args.py
@@ -76,6 +76,10 @@ class ModelArgs:
False # Use q-gated projection in attention (Qwen3.5 full attention)
)
norm_type: str = "rmsnorm" # Normalization type, registered in norm.py
# When True, swap RMSNorm for the CoreML-friendly RMSNormCoreML at every
# norm site. The CoreML formulation uses torch.linalg.vector_norm so the
# op is preserved in the CoreML graph (FP32 casts get stripped by CoreML).
use_coreml_norm: bool = False
act_fn: ActFn = dataclasses.field(default=ActFn.SILU) # Activation function type
attention_qkv_bias: bool = False
use_kv_cache: bool = False # Use key/value cache
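Editorial aside: a minimal sketch of how the new flag is intended to be used. The dim and norm_eps values below are illustrative; only the field names come from this diff, and all other fields keep their dataclass defaults:

from executorch.examples.models.llama.model_args import ModelArgs

# Opt a model into the CoreML-friendly norms.
args = ModelArgs(dim=2048, norm_eps=1e-5, use_coreml_norm=True)

# With the flag set, every norm construction site in attention.py and
# llama_transformer.py builds RMSNormCoreML(dim, eps=args.norm_eps) instead of
# RMSNorm / ScalelessRMSNorm.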
41 changes: 41 additions & 0 deletions examples/models/llama/norm.py
@@ -57,6 +57,47 @@ def __init__(self, dim: int, eps: float = 1e-6):
self.weight.requires_grad = False


class RMSNormCoreML(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
"""
CoreML-friendly RMSNorm — uses `torch.linalg.vector_norm` so the op is
preserved in the CoreML graph for numerical stability.

Args:
dim (int): The dimension of the input tensor.
eps (float, optional): Stored for API compatibility; ignored in the math.

Attributes:
eps (float): Stored for API compatibility; not consumed by `_norm`.
Contributor (inline review comment): Can we assert eps is 0 rather than silently drop it?

weight (nn.Parameter): Learnable scaling parameter.
"""
super().__init__()
self.dim = dim
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))

def _norm(self, x):
# Floor the denominator to avoid 0 / 0 = NaN on zero-padded positions
# (chunked prefill in StaticAttentionIOManager pads each chunk to
# input_len with zeros). Use sqrt(dim * eps) so the floor matches
# standard RMSNorm's eps semantics (`rsqrt(mean(x²) + eps)`) and is
# large enough to survive fp16 (1e-6 alone underflows in fp16).
floor_val = torch.sqrt(torch.tensor(self.dim * self.eps, dtype=x.dtype))
norm_val = torch.clamp_min(
torch.linalg.vector_norm(x, dim=-1, keepdim=True), floor_val
)
rms_norm_eps0 = (
x
* torch.sqrt(torch.tensor(self.dim, dtype=x.dtype))
* torch.reciprocal(norm_val)
)
return rms_norm_eps0

def forward(self, x):
output = self._norm(x)
return output * self.weight


class RMSNormWithInputScale(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-5):
super().__init__()
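Editorial aside: a small numerical check of the _norm comment in RMSNormCoreML above (illustrative, not part of the PR). It compares the vector_norm formulation against the textbook rsqrt(mean(x^2) + eps) form on a random row, and confirms that an all-zero row, like a zero-padded chunked-prefill position, stays finite thanks to the clamp_min floor:

import torch

dim, eps = 64, 1e-6
x = torch.randn(2, dim)
x[1] = 0.0  # simulate a zero-padded position

# Textbook RMSNorm without the learnable weight: x * rsqrt(mean(x^2) + eps).
ref = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

# RMSNormCoreML formulation: x * sqrt(dim) / max(||x||_2, sqrt(dim * eps)).
floor_val = torch.sqrt(torch.tensor(dim * eps, dtype=x.dtype))
norm_val = torch.clamp_min(
    torch.linalg.vector_norm(x, dim=-1, keepdim=True), floor_val
)
out = x * torch.sqrt(torch.tensor(dim, dtype=x.dtype)) * torch.reciprocal(norm_val)

print(torch.allclose(ref[0], out[0], atol=1e-4))  # True: formulations agree
print(bool(out[1].isfinite().all()))               # True: no 0/0 on the zero row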
23 changes: 20 additions & 3 deletions examples/models/llama/static_attention.py
@@ -1,3 +1,3 @@
import copy
import logging
from abc import ABC, abstractmethod
@@ -15,7 +15,7 @@
)
from executorch.examples.models.llama.lora import LoRALinear
from executorch.examples.models.llama.model_args import ModelArgs
from executorch.examples.models.llama.norm import ScalelessRMSNorm
from executorch.examples.models.llama.norm import RMSNormCoreML, ScalelessRMSNorm
from executorch.examples.models.llama.rope import Rope


@@ -898,18 +898,26 @@

def _init_qk_norms(self, config: ModelArgs, is_kv_shared_layer: bool) -> None:
if self.use_qk_norm:
# When use_coreml_norm is set, match the rlformers reference path
# which constructs q_norm/k_norm via RMSNormCoreML (no fp32 cast,
# no eps, vector_norm-based) instead of ScalelessRMSNorm.
_scaleless_cls = (
RMSNormCoreML
if getattr(config, "use_coreml_norm", False)
else ScalelessRMSNorm
)
if getattr(config, "qk_norm_affine", True):
self.q_norm = torch.nn.RMSNorm(self.head_dim, config.norm_eps)
if is_kv_shared_layer:
self.k_norm = nn.Identity()
else:
self.k_norm = torch.nn.RMSNorm(self.head_dim, config.norm_eps)
else:
self.q_norm = ScalelessRMSNorm(self.head_dim, eps=config.norm_eps)
self.q_norm = _scaleless_cls(self.head_dim, eps=config.norm_eps)
if is_kv_shared_layer:
self.k_norm = nn.Identity()
else:
self.k_norm = ScalelessRMSNorm(self.head_dim, eps=config.norm_eps)
self.k_norm = _scaleless_cls(self.head_dim, eps=config.norm_eps)
else:
self.q_norm = torch.nn.Identity()
self.k_norm = torch.nn.Identity()
@@ -949,6 +957,14 @@
hasattr(other.q_norm_fn, "weight") if other.use_qk_norm else True
)

# Propagate use_coreml_norm so _init_qk_norms picks RMSNormCoreML for
# scaleless q/k norms (matches the rlformers reference path). Detect
# via the rms_norm_class kwarg — `transform_attention_mha_to_static_attention`
# forwards it through, and the static_transformer_export caller already
# selects RMSNormCoreML when use_coreml_norm is set on the model args.
_use_coreml_norm = rms_norm_class is RMSNormCoreML

config = ModelArgs(
dim=other.dim,
n_layers=1, # Not used in attention layer
@@ -964,6 +980,7 @@
norm_eps=other.q_norm_fn.eps if other.use_qk_norm else 1e-5,
num_kv_shared_layers=getattr(other, "num_kv_shared_layers", 0),
scale_query_by=getattr(other, "scale_query_by", 1.0),
use_coreml_norm=_use_coreml_norm,
)

instance = cls(
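Editorial aside: restating the two static_attention.py comments above as one sketch. The helper name is hypothetical; only the detection expression and the class choice come from this diff. Note that the affine q/k path in this file keeps torch.nn.RMSNorm regardless of use_coreml_norm; only the scaleless branch switches class:

from executorch.examples.models.llama.norm import RMSNormCoreML, ScalelessRMSNorm


def scaleless_qk_norm_cls(rms_norm_class: type) -> type:
    # Detection: the caller's rms_norm_class choice doubles as the
    # use_coreml_norm signal (RMSNormCoreML when the model args enable it).
    use_coreml_norm = rms_norm_class is RMSNormCoreML
    # _init_qk_norms then uses this class for the non-affine q/k norms.
    return RMSNormCoreML if use_coreml_norm else ScalelessRMSNorm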