diff --git a/onnx_diagnostic/tasks/text_generation.py b/onnx_diagnostic/tasks/text_generation.py
index 25b4d29c..4fec3ac2 100644
--- a/onnx_diagnostic/tasks/text_generation.py
+++ b/onnx_diagnostic/tasks/text_generation.py
@@ -13,7 +13,7 @@ def reduce_model_config(config: Any) -> Dict[str, Any]:
     """Reduces a model size."""
-    # FalconMambaConfig: use_mambapy
+    # Mamba models (e.g. FalconMambaConfig) use use_mambapy instead of num_attention_heads
     if hasattr(config, "text_config"):
         # The model is probably of mixture of models used only for text.
         config = config.text_config
@@ -25,7 +25,7 @@ def reduce_model_config(config: Any) -> Dict[str, Any]:
         "hidden_size",
         "vocab_size",
     )
-    if config.__class__.__name__ == "FalconMambaConfig":
+    if hasattr(config, "use_mambapy"):
         check_hasattr(config, "conv_kernel", "state_size", "intermediate_size")  # 4 and 8
         kwargs = dict(
             num_hidden_layers=min(config.num_hidden_layers, nhl()),
@@ -54,7 +54,7 @@ def reduce_model_config(config: Any) -> Dict[str, Any]:
     return kwargs


-def _get_input_falcon_mamba(
+def _get_input_mamba(
     model: torch.nn.Module,
     config: Optional[Any],
     dummy_max_token_id: int,
@@ -157,8 +157,8 @@ def get_inputs(
     seq_length = "seq_length"  # torch.export.Dim("seq_length", min=1, max=4096)
     cache_length = "cache_length"  # torch.export.Dim("cache_length", min=1, max=4096)

-    if config is not None and config.__class__.__name__ == "FalconMambaConfig":
-        res = _get_input_falcon_mamba(
+    if config is not None and hasattr(config, "use_mambapy"):
+        res = _get_input_mamba(
             model=model,
             config=config,
             dummy_max_token_id=dummy_max_token_id,
@@ -343,7 +343,7 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
         ("num_key_value_heads", "num_attention_heads", "use_mambapy"),
         "hidden_size",
     )
-    if config.__class__.__name__ == "FalconMambaConfig":
+    if hasattr(config, "use_mambapy"):
         check_hasattr(config, "conv_kernel", "state_size", "intermediate_size")  # 4 and 8
         kwargs = dict(
             batch_size=2,
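
The diff replaces a class-name comparison with duck typing on the Mamba-specific `use_mambapy` config attribute, so the Mamba code path also covers configs other than `FalconMambaConfig`. A minimal sketch of that idea, assuming `transformers` is installed and exposes `MambaConfig`, `FalconMambaConfig`, and `LlamaConfig` with default constructors (the helper name `is_mamba_like` is hypothetical, not from the diff):

```python
# Sketch (not part of the diff): probe for the Mamba-specific attribute
# instead of matching a single class name.
from transformers import FalconMambaConfig, LlamaConfig, MambaConfig


def is_mamba_like(config) -> bool:
    # Duck typing: any Mamba-style config carries ``use_mambapy``,
    # while attention-based configs carry ``num_attention_heads`` instead.
    return hasattr(config, "use_mambapy")


assert is_mamba_like(FalconMambaConfig())  # matched by the old class-name check too
assert is_mamba_like(MambaConfig())        # newly covered by the hasattr check
assert not is_mamba_like(LlamaConfig())    # attention model, no use_mambapy
```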