24 changes: 20 additions & 4 deletions modelopt/torch/peft/config.py
@@ -21,7 +21,7 @@
from typing import Annotated, Any

import torch.nn.init as init
from pydantic import PlainSerializer, WithJsonSchema, field_validator
from pydantic import PlainSerializer, WithJsonSchema, field_validator, model_validator

from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField

@@ -188,9 +188,15 @@ class PEFTConfig(ModeloptBaseConfig):

freeze_base_model: bool = ModeloptField(
default=True,
title="Freeze base weights during training",
description="Whether to freeze the base model weights; in most cases, this should be set to True.",
validate_default=True,
title="Freeze all base model weights during training",
description="Whether to freeze all base model weights. Mutually exclusive with freeze_base_layers.",
)

freeze_base_layers: bool = ModeloptField(
default=False,
title="Freeze base weights only for layers with LoRA adapters",
description="Whether to freeze the base weights of only the layers that have LoRA adapters applied. "
"Layers without LoRA adapters are left unchanged. Mutually exclusive with freeze_base_model.",
)

freeze_lora_weights: bool = ModeloptField(
@@ -200,6 +206,16 @@
validate_default=True,
)

@model_validator(mode="after")
def validate_freeze_flags(self):
"""Ensure freeze_base_model and freeze_base_layers are not both enabled."""
if self.freeze_base_model and self.freeze_base_layers:
raise ValueError(
"freeze_base_model and freeze_base_layers are mutually exclusive. "
"Only one can be True."
)
return self

@field_validator("adapter_type")
@classmethod
def validate_adapter_type(cls, v):
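To make the new flag pair concrete, here is a minimal, hedged usage sketch (not part of the diff). Only freeze_base_model, freeze_base_layers, and the model_validator behavior come from this change; the adapter_type/adapter_cfg values below are illustrative placeholders and other PEFTConfig fields keep their defaults.

from pydantic import ValidationError

from modelopt.torch.peft.config import PEFTConfig

# Valid: freeze only the base weights of layers that receive LoRA adapters.
cfg = PEFTConfig(
    adapter_type="lora",
    adapter_cfg={"*": {"enable": False}, "*linear_qkv*": {"rank": 64, "enable": True}},
    freeze_base_model=False,
    freeze_base_layers=True,
)

# Invalid: both flags enabled. The model_validator raises ValueError, which
# pydantic surfaces as a ValidationError at construction time.
try:
    PEFTConfig(
        adapter_type="lora",
        adapter_cfg={"*": {"enable": False}},
        freeze_base_model=True,
        freeze_base_layers=True,
    )
except ValidationError as err:
    print(err)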
73 changes: 68 additions & 5 deletions modelopt/torch/peft/conversion.py
@@ -15,6 +15,8 @@

"""PEFT conversion and restore utilities for LoRA modules."""

import torch
import torch.distributed as dist
import torch.nn as nn

from modelopt.torch.opt.conversion import ModelLikeModule, ModeloptStateManager
@@ -28,6 +30,7 @@
__all__ = [
"freeze_lora_weights",
"replace_lora_module",
"sync_lora_weights",
]


Expand All @@ -36,15 +39,17 @@ def convert_to_peft_model(model: ModelLikeModule, config: PEFTConfig) -> Convert
# initialize the true module if necessary
model = model.init_modellike() if isinstance(model, ModelLikeModule) else model

# Freeze all base model weights before replacing modules if freeze_base_model is True
if config.freeze_base_model:
for param in model.parameters():
param.requires_grad = False

replace_lora_module(model, version=ModeloptStateManager(model).state_version, config=config)

metadata = {}
add_adapter(model, config)

# Freeze base weights based on config flags (mutually exclusive)
if config.freeze_base_layers:
_freeze_base_weights_of_lora_layers(model)
elif config.freeze_base_model:
_freeze_all_base_weights(model)

# Update gradient settings for LoRA parameters only
_update_lora_grads(model, config)

@@ -228,6 +233,64 @@ def unfreeze_lora_weights(model, *, layer_patterns=None, adapter_patterns=None):
)


def sync_lora_weights(model, group=None):
"""Broadcast LoRA adapter weights from src rank 0 to all other ranks in the group.

This ensures LoRA weights are identical across data-parallel replicas after
random initialization. Should be called after LoRA adapters are added to the model.

Args:
model: Model containing LoRA modules to synchronize.
group: The process group to broadcast over (e.g., the data-parallel group).
If None, uses the default process group.
"""
if not dist.is_initialized():
return

src = dist.get_global_rank(group, 0) if group is not None else 0

for _, module in model.named_modules():
if isinstance(module, LoRAModule):
for adapter in module._lora_adapters.values():
for submodule in ("lora_a", "lora_b"):
for param in adapter[submodule].parameters():
dist.broadcast(param.data, src=src, group=group)


def _freeze_all_base_weights(model):
"""Freeze all non-LoRA parameters in the model."""
lora_param_ids = set()
for _, module in model.named_modules():
if isinstance(module, LoRAModule):
for adapter in module._lora_adapters.values():
for submodule in ("lora_a", "lora_b"):
for param in adapter[submodule].parameters():
lora_param_ids.add(id(param))

for param in model.parameters():
if id(param) not in lora_param_ids:
param.requires_grad = False


def _freeze_base_weights_of_lora_layers(model):
"""Freeze base (non-LoRA) parameters only in modules that have active LoRA adapters."""
# Collect parameter IDs of all LoRA adapter weights
lora_param_ids = set()
for _, module in model.named_modules():
if isinstance(module, LoRAModule):
for adapter in module._lora_adapters.values():
for submodule in ("lora_a", "lora_b"):
for param in adapter[submodule].parameters():
lora_param_ids.add(id(param))

# For each LoRA module that has at least one adapter, freeze its base parameters
for _, module in model.named_modules():
if isinstance(module, LoRAModule) and module._lora_adapters:
for param in module.parameters():
if id(param) not in lora_param_ids:
param.requires_grad = False


def _update_lora_grads(model, config: PEFTConfig):
"""Update gradient computation settings for LoRA parameters only (internal function).

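Below is a hedged sketch (not part of the PR) of how the new sync_lora_weights helper might be combined with the freeze logic above. update_model is referenced in the new lora/config.py docstring and is assumed here to take the model and a config dict; the model object and data-parallel group are placeholders supplied by the caller.

import torch.distributed as dist

from modelopt.torch.peft import update_model
from modelopt.torch.peft.conversion import sync_lora_weights
from modelopt.torch.peft.lora.config import MOE_LORA_RANDOM_INIT_CFG


def add_and_sync_adapters(model, dp_group=None):
    """Attach randomly initialized LoRA adapters, then align them across ranks."""
    # freeze_base_layers=True in MOE_LORA_RANDOM_INIT_CFG freezes only the base
    # weights of the expert fc1/fc2 layers that receive adapters.
    model = update_model(model, MOE_LORA_RANDOM_INIT_CFG)
    if dist.is_initialized():
        # Broadcast every adapter's lora_a/lora_b parameters from rank 0 of
        # dp_group so all data-parallel replicas start from identical weights.
        sync_lora_weights(model, group=dp_group)
    return model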
97 changes: 97 additions & 0 deletions modelopt/torch/peft/lora/config.py
@@ -0,0 +1,97 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Predefined LoRA configurations for common Megatron-Core model architectures.

These configs are designed to be passed directly to
:func:`modelopt.torch.peft.update_model` or used as the ``--lora-config``
argument in the PTQ script.

Config name strings map to entries in :data:`LORA_CFG_CHOICES`.
"""

import torch.nn.init as init

__all__ = ["DENSE_LORA_CFG", "LORA_CFG_CHOICES", "MOE_LORA_CFG"]

# ---------------------------------------------------------------------------
# Dense (non-MoE) model config
# ---------------------------------------------------------------------------
# Targets the four linear projections that are standard in every transformer
# decoder layer:
# - linear_qkv : fused Q/K/V projection (ColumnParallelLinear)
# - linear_proj : attention output projection (RowParallelLinear)
# - linear_fc1 : MLP gate/up projection (ColumnParallelLinear)
# - linear_fc2 : MLP down projection (RowParallelLinear)
#
# All other modules are excluded via the wildcard ``"*": {"enable": False}``
# fallback (later patterns override earlier ones).
DENSE_LORA_CFG = {
"adapter_type": "lora",
"adapter_cfg": {
"*": {"enable": False},
"*linear_qkv*": {"rank": 64, "enable": True},
"*linear_proj*": {"rank": 64, "enable": True},
"*linear_fc1*": {"rank": 64, "enable": True},
"*linear_fc2*": {"rank": 64, "enable": True},
},
}

# ---------------------------------------------------------------------------
# MoE model config
# ---------------------------------------------------------------------------
# Applies LoRA adapters to the fc1/fc2 projections of each local expert and
# freezes only the base weights of those layers (freeze_base_layers=True).
MOE_LORA_CFG = {
"adapter_type": "lora",
"freeze_base_model": False,
"freeze_base_layers": True,
"adapter_cfg": {
"*": {"enable": False},
"*local_experts*linear_fc1*": {"rank": 64, "enable": True},
"*local_experts*linear_fc2*": {"rank": 64, "enable": True},
},
}

# Same as MOE_LORA_CFG, but with explicit Kaiming-uniform initialization for
# both LoRA A and B matrices and a scale of 1.
MOE_LORA_RANDOM_INIT_CFG = {
"adapter_type": "lora",
"freeze_base_model": False,
"freeze_base_layers": True,
"adapter_cfg": {
"*": {"enable": False},
"*local_experts*linear_fc1*": {
"rank": 64,
"enable": True,
"scale": 1,
"lora_a_init": init.kaiming_uniform_,
"lora_b_init": init.kaiming_uniform_,
},
"*local_experts*linear_fc2*": {
"rank": 64,
"enable": True,
"scale": 1,
"lora_a_init": init.kaiming_uniform_,
"lora_b_init": init.kaiming_uniform_,
},
},
}


# ---------------------------------------------------------------------------
# Registry
# ---------------------------------------------------------------------------
LORA_CFG_CHOICES: dict[str, dict] = {
"dense": DENSE_LORA_CFG,
"moe": MOE_LORA_CFG,
"moe_random": MOE_LORA_RANDOM_INIT_CFG,
}
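A short, hedged usage sketch for the registry: only update_model and the --lora-config mapping are stated in the module docstring above; the exact call signature and the model object below are assumptions.

from modelopt.torch.peft import update_model
from modelopt.torch.peft.lora.config import DENSE_LORA_CFG, LORA_CFG_CHOICES

megatron_model = ...  # placeholder: an already-built Megatron-Core model

# Pass a predefined config dict directly ...
model = update_model(megatron_model, DENSE_LORA_CFG)

# ... or look one up by name; the keys ("dense", "moe", "moe_random") are what
# the PTQ script's --lora-config argument presumably maps to.
moe_cfg = LORA_CFG_CHOICES["moe"]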