20 changes: 19 additions & 1 deletion moe_infinity/entrypoints/big_modeling.py
@@ -62,6 +62,10 @@ def __init__(
from moe_infinity.runtime import OffloadEngine
from moe_infinity.utils import ArcherConfig, get_checkpoint_paths
from moe_infinity.utils.hf_config import ensure_config_compat
from moe_infinity.utils.quantization import (
detect_quantization,
validate_quantization_support,
)

# TODO: remove the torch version check once older versions are supported
if is_torch_version is not None and not is_torch_version(">=", "2.0"):
@@ -92,6 +96,11 @@ def __init__(
model_name_or_path, trust_remote_code=True
)
model_config = ensure_config_compat(model_config)

quant_info = detect_quantization(model_config, "")
if quant_info is not None:
validate_quantization_support(quant_info, model_name_or_path)

architectures = getattr(model_config, "architectures", None)
if not architectures or not isinstance(architectures, list):
raise RuntimeError("Unable to resolve model architecture")
@@ -113,7 +122,11 @@ def __init__(
# with init_empty_weights():
# self.model = model_cls(model_config)
if os.path.exists(model_name_or_path):
checkpoint_paths = get_checkpoint_paths(model_name_or_path)
model_path = model_name_or_path
quant_info = detect_quantization(model_config, model_path)
if quant_info is not None:
validate_quantization_support(quant_info, model_name_or_path)
checkpoint_paths = get_checkpoint_paths(model_path)
else:
checkpoint_paths = None
# get the checkpoint download path from huggingface hub
Expand All @@ -127,6 +140,11 @@ def __init__(
f"The `snapshot_download` function could not find the checkpoint {model_name_or_path}. "
f"Please provide a valid checkpoint."
)

quant_info = detect_quantization(model_config, model_path)
if quant_info is not None:
validate_quantization_support(quant_info, model_name_or_path)

checkpoint_paths = get_checkpoint_paths(model_path)

if isinstance(config, dict):
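Note: the helpers imported above come from moe_infinity/utils/quantization.py, which is not part of this diff. A minimal sketch of the interface these call sites appear to assume; the field names and validation rules are inferred, not taken from the PR:

```python
# Hypothetical sketch of moe_infinity/utils/quantization.py as assumed by the
# call sites above; the real module is not shown in this diff.
from dataclasses import dataclass, field
from typing import Any, Optional


@dataclass
class QuantizationInfo:
    method: str  # e.g. "gptq" or "awq"
    config_dict: dict[str, Any] = field(default_factory=dict)


def detect_quantization(model_config: Any, checkpoint_dir: str) -> Optional[QuantizationInfo]:
    """Return QuantizationInfo when the config declares a quantization_config,
    otherwise None. checkpoint_dir could additionally be scanned for quantized
    weight files; that part is omitted here."""
    quant_cfg = getattr(model_config, "quantization_config", None)
    if isinstance(quant_cfg, dict) and "quant_method" in quant_cfg:
        return QuantizationInfo(str(quant_cfg["quant_method"]), dict(quant_cfg))
    return None


def validate_quantization_support(info: QuantizationInfo, model_name_or_path: str) -> None:
    """Fail early, naming the model, when the detected quantization method is
    not one the offload engine knows how to convert."""
    if info.method not in ("gptq", "awq"):
        raise ValueError(
            f"Unsupported quantization method '{info.method}' detected for "
            f"{model_name_or_path}; only GPTQ and AWQ are handled here."
        )
```

The diff calls detect_quantization(model_config, "") before a local path is resolved and again with the final model_path, so whatever the real helper does, it has to tolerate an empty directory argument.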
152 changes: 127 additions & 25 deletions moe_infinity/runtime/model_offload.py
@@ -1,3 +1,4 @@
# pyright: reportMissingImports=false, reportMissingTypeStubs=false, reportUnknownMemberType=false, reportUnknownVariableType=false, reportUnknownArgumentType=false, reportAttributeAccessIssue=false, reportCallIssue=false, reportUnannotatedClassAttribute=false, reportUninitializedInstanceVariable=false, reportPrivateUsage=false, reportPrivateLocalImportUsage=false, reportUnusedImport=false, reportUnusedCallResult=false, reportUnknownParameterType=false, reportMissingParameterType=false, reportExplicitAny=false, reportAny=false, reportArgumentType=false, reportOperatorIssue=false, reportImplicitStringConcatenation=false, reportUnnecessaryComparison=false, reportUnreachable=false, reportMissingTypeArgument=false, reportDeprecated=false, reportGeneralTypeIssues=false
# Copyright (c) EfficientMoE.
# SPDX-License-Identifier: Apache-2.0

@@ -69,6 +70,11 @@ class QuantLinearOld:
wait_transfer,
)
from moe_infinity.utils.device import get_default_device, get_device
from moe_infinity.utils.quantization import (
detect_quantization,
should_cast_tensor,
validate_quantization_support,
)

_prefetch_lib = None
# Alias for compatibility
@@ -209,6 +215,7 @@ def __init__(
self.config = config

self.quant_method = None
self._quant_info = None

# AttentionBackend scaffolding (no-op by default)
# Set enable_attention_offload=True to activate (future work)
@@ -476,6 +483,16 @@ def archer_from_pretrained(cls, *args, **kwargs):

self.model_name = model_name = args[0]

checkpoint_path = self.ckpt_files[0] if self.ckpt_files else ""
checkpoint_dir = (
os.path.dirname(checkpoint_path) if checkpoint_path else ""
)
self._quant_info = detect_quantization(
self.config, checkpoint_dir
)
if self._quant_info is not None:
validate_quantization_support(self._quant_info, model_name)

self.num_layers, self.num_experts, self.num_encoder_layers = (
parse_moe_param(self.config)
)
@@ -521,16 +538,7 @@ def archer_from_pretrained(cls, *args, **kwargs):
else:
state_dict = torch.load(ckpt)

# convert all tensors in state_dict to self.dtype_cls
for k, v in state_dict.items():
try:
state_dict[k] = v.to(self.dtype_cls).to("cpu")
except Exception as e:
print(
f"Error converting {k} (device={v.device}) to {self.dtype_cls} on CPU: {e}",
flush=True,
)
raise
self._cast_state_dict_tensors(state_dict)

self._offload_state_dict(state_dict, empty_state_dict)

@@ -575,20 +583,7 @@ def archer_from_pretrained(cls, *args, **kwargs):

base_model_prefix = model.base_model_prefix

if hasattr(self.config, "quantization_config"):
self.quant_method = self.config.quantization_config[
"quant_method"
]
self.config.quantization_config["use_exllama"] = False
self.config.quantization_config["disable_exllama"] = True
if self.quant_method == "gptq":
from optimum.gptq import GPTQQuantizer

optimum_quantizer = GPTQQuantizer.from_dict(
self.config.quantization_config
)

model = optimum_quantizer.convert_model(model)
model = self._apply_quantized_model_conversion(model)

self.expert_prefetcher = ExpertPrefetcher(self.config)
self.expert_prefetcher.set_archer_engine(self.archer_engine)
@@ -752,7 +747,16 @@ def get_topology(self, model):

else:
match = re.match(r"(.*\.\d+\.)", name)
last_number_position = match.end() - 2
if match:
last_number_position = match.end() - 2
else:
matches = [
each_match
for each_match in re.finditer(r"\d", name)
]
last_number_position = (
matches[-1].start() if matches else -1
)
stored_name = name[: last_number_position + 1]

if stored_name in name_lst:
@@ -966,6 +970,104 @@ def _offload_state_dict(
gc.collect()
torch.cuda.empty_cache()

def _cast_state_dict_tensors(
self, state_dict: Dict[str, torch.Tensor]
) -> None:
for k, v in state_dict.items():
try:
if should_cast_tensor(k, self._quant_info):
state_dict[k] = v.to(self.dtype_cls).to("cpu")
else:
state_dict[k] = v.to("cpu")
except Exception as e:
print(
f"Error converting {k} (device={v.device}) to {self.dtype_cls} on CPU: {e}",
flush=True,
)
raise

def _apply_quantized_model_conversion(self, model):
if self._quant_info is None:
return model

if self._quant_info.method == "gptq":
self.quant_method = "gptq"
quantization_config = getattr(
self.config, "quantization_config", None
)
if not isinstance(quantization_config, dict):
quantization_config = dict(self._quant_info.config_dict)
self.config.quantization_config = quantization_config

quantization_config.setdefault("quant_method", "gptq")
quantization_config["use_exllama"] = False
quantization_config["disable_exllama"] = True

try:
gptq_module = importlib.import_module("optimum.gptq")
GPTQQuantizer = getattr(gptq_module, "GPTQQuantizer")
except ImportError as e:
raise ImportError(
"GPTQ model detected but 'optimum' is not installed. "
"Install with: pip install optimum[gptq]"
) from e

optimum_quantizer = GPTQQuantizer.from_dict(quantization_config)
return optimum_quantizer.convert_model(model)

if self._quant_info.method == "awq":
self.quant_method = "awq"
return self._convert_model_for_awq(model)

return model

def _convert_model_for_awq(self, model):
try:
awq_module = importlib.import_module("awq")
except ImportError as e:
raise ImportError(
"AWQ model detected but 'autoawq' is not installed. "
"Install with: pip install autoawq"
) from e

replace_fn = getattr(awq_module, "replace_linear_modules", None)
if callable(replace_fn):
converted_model = replace_fn(model)
return converted_model if converted_model is not None else model

autoawq_cls = getattr(awq_module, "AutoAWQForCausalLM", None)
if autoawq_cls is not None:
for fn_name in ("replace_linear_modules", "convert_model"):
autoawq_fn = getattr(autoawq_cls, fn_name, None)
if callable(autoawq_fn):
converted_model = autoawq_fn(model)
return (
converted_model
if converted_model is not None
else model
)

try:
awq_linear_module = importlib.import_module("awq.modules.linear")
except Exception:
awq_linear_module = None

if awq_linear_module is not None:
for fn_name in (
"replace_linear_modules",
"replace_with_awq_linear",
):
linear_replace_fn = getattr(awq_linear_module, fn_name, None)
if callable(linear_replace_fn):
converted_model = linear_replace_fn(model)
return (
converted_model
if converted_model is not None
else model
)

return model

@torch.no_grad()
def _capture_kv_cache(self, seq_id: int, past_key_values):
if not getattr(self, "_enable_kv_cache_offload", True):
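Note: `_cast_state_dict_tensors` defers the dtype decision to `should_cast_tensor`, which is also not included in this diff. Presumably it keeps packed quantization tensors from GPTQ/AWQ checkpoints in their original integer dtype while ordinary weights are still cast; a hedged sketch under that assumption (the exact suffix list is a guess):

```python
# Hypothetical sketch of should_cast_tensor; the real helper lives in
# moe_infinity/utils/quantization.py and may use different rules.
from typing import Optional

# Packed/auxiliary tensors written by GPTQ/AWQ checkpoints that should keep
# their stored dtype rather than being cast to the engine's float dtype.
_QUANT_TENSOR_SUFFIXES = ("qweight", "qzeros", "g_idx", "scales")


def should_cast_tensor(name: str, quant_info: Optional["QuantizationInfo"]) -> bool:
    if quant_info is None:
        return True  # unquantized checkpoint: cast everything, as before
    return name.split(".")[-1] not in _QUANT_TENSOR_SUFFIXES
```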
3 changes: 2 additions & 1 deletion moe_infinity/serving/kv_cache.py
@@ -180,7 +180,8 @@ def free_gpu_blocks(self, seq_id: int) -> None:
if seq_id not in self._sequence_tables:
return

return
block_table = self._sequence_tables[seq_id]
block_table.release()

def get_block_table(self, seq_id: int) -> list[int]:
block_table = self._require_sequence(seq_id)
8 changes: 8 additions & 0 deletions moe_infinity/utils/__init__.py
@@ -15,11 +15,17 @@
parse_expert_id,
parse_moe_param,
)
from .quantization import (
QuantizationInfo,
detect_quantization,
validate_quantization_support,
)

__all__ = [
"ArcherConfig",
"async_d2h",
"async_h2d",
"detect_quantization",
"DeviceConfig",
"get_checkpoint_paths",
"get_default_device",
@@ -30,6 +36,8 @@
"parse_expert_dtype",
"parse_expert_id",
"parse_moe_param",
"QuantizationInfo",
"validate_quantization_support",
"wait_transfer",
"to_device",
]
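Note: with these re-exports, the detection path used in big_modeling.py can be exercised from the package root. A small usage sketch; the checkpoint path is a placeholder and the (config, path) call signature is taken from the call sites in this diff:

```python
from transformers import AutoConfig

from moe_infinity.utils import detect_quantization, validate_quantization_support

# Placeholder path to a local GPTQ or AWQ checkpoint.
model_path = "/models/my-quantized-moe"

config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
quant_info = detect_quantization(config, model_path)
if quant_info is not None:
    validate_quantization_support(quant_info, model_path)
```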