12 changes: 6 additions & 6 deletions onnx_diagnostic/tasks/text_generation.py
@@ -13,7 +13,7 @@

def reduce_model_config(config: Any) -> Dict[str, Any]:
"""Reduces a model size."""
- # FalconMambaConfig: use_mambapy
+ # Mamba models (e.g. FalconMambaConfig) use use_mambapy instead of num_attention_heads
if hasattr(config, "text_config"):
# The model is probably a mixture of models used only for text.
config = config.text_config
@@ -25,7 +25,7 @@ def reduce_model_config(config: Any) -> Dict[str, Any]:
"hidden_size",
"vocab_size",
)
- if config.__class__.__name__ == "FalconMambaConfig":
+ if hasattr(config, "use_mambapy"):
check_hasattr(config, "conv_kernel", "state_size", "intermediate_size") # 4 and 8
kwargs = dict(
num_hidden_layers=min(config.num_hidden_layers, nhl()),
@@ -54,7 +54,7 @@ def reduce_model_config(config: Any) -> Dict[str, Any]:
return kwargs


- def _get_input_falcon_mamba(
+ def _get_input_mamba(
model: torch.nn.Module,
config: Optional[Any],
dummy_max_token_id: int,
@@ -157,8 +157,8 @@ def get_inputs(
seq_length = "seq_length" # torch.export.Dim("seq_length", min=1, max=4096)
cache_length = "cache_length" # torch.export.Dim("cache_length", min=1, max=4096)

- if config is not None and config.__class__.__name__ == "FalconMambaConfig":
-     res = _get_input_falcon_mamba(
+ if config is not None and hasattr(config, "use_mambapy"):
+     res = _get_input_mamba(
model=model,
config=config,
dummy_max_token_id=dummy_max_token_id,
@@ -343,7 +343,7 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
("num_key_value_heads", "num_attention_heads", "use_mambapy"),
"hidden_size",
)
- if config.__class__.__name__ == "FalconMambaConfig":
+ if hasattr(config, "use_mambapy"):
check_hasattr(config, "conv_kernel", "state_size", "intermediate_size") # 4 and 8
kwargs = dict(
batch_size=2,
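The recurring edit in this diff replaces a class-name comparison with a `hasattr` check, so any Mamba-family configuration exposing `use_mambapy` (not only `FalconMambaConfig`) takes the Mamba-specific branch, and `_get_input_falcon_mamba` is renamed to `_get_input_mamba` to match. Below is a minimal sketch of the duck-typing idea; the stand-in config classes and the `is_mamba_config` helper are illustrative only and are not part of this PR.

```python
# Simplified stand-ins for config classes; real ones live in transformers.
class FalconMambaConfig:
    use_mambapy = False
    conv_kernel = 4
    state_size = 8


class MambaConfig:
    use_mambapy = False
    conv_kernel = 4
    state_size = 16


class LlamaConfig:
    num_attention_heads = 32


def is_mamba_config(config) -> bool:
    # Old check: config.__class__.__name__ == "FalconMambaConfig"
    # matched exactly one concrete class; the hasattr check matches any
    # config that carries the Mamba-specific attribute.
    return hasattr(config, "use_mambapy")


for cfg in (FalconMambaConfig(), MambaConfig(), LlamaConfig()):
    kind = "mamba" if is_mamba_config(cfg) else "attention"
    print(f"{cfg.__class__.__name__}: {kind}")
```

With this pattern, other Mamba-style configs should route through the shared Mamba input-building path without the module having to enumerate class names.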