diff --git a/docs/optimizer_search_space_config.schema.json b/docs/optimizer_search_space_config.schema.json index c46835ff..df8560ea 100644 --- a/docs/optimizer_search_space_config.schema.json +++ b/docs/optimizer_search_space_config.schema.json @@ -201,6 +201,98 @@ "title": "BERTLoRAScorerInitModel", "type": "object" }, + "BaseEmbedderConfig": { + "additionalProperties": false, + "description": "Base class for embedder configurations.", + "properties": { + "default_prompt": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Default prompt for the model. This is used when no task-specific prompt is provided.", + "title": "Default Prompt" + }, + "classification_prompt": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Prompt for classifier.", + "title": "Classification Prompt" + }, + "cluster_prompt": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Prompt for clustering.", + "title": "Cluster Prompt" + }, + "sts_prompt": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Prompt for finding most similar sentences.", + "title": "Sts Prompt" + }, + "query_prompt": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Prompt for query.", + "title": "Query Prompt" + }, + "passage_prompt": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Prompt for passage.", + "title": "Passage Prompt" + }, + "use_cache": { + "default": true, + "description": "Whether to use embeddings caching.", + "title": "Use Cache", + "type": "boolean" + } + }, + "title": "BaseEmbedderConfig", + "type": "object" + }, "BertScorerInitModel": { "additionalProperties": false, "properties": { @@ -398,6 +490,12 @@ { "$ref": 
"#/$defs/HashingVectorizerEmbeddingConfig" }, + { + "$ref": "#/$defs/VllmEmbeddingConfig" + }, + { + "$ref": "#/$defs/BaseEmbedderConfig" + }, { "type": "string" }, @@ -664,6 +762,12 @@ { "$ref": "#/$defs/HashingVectorizerEmbeddingConfig" }, + { + "$ref": "#/$defs/VllmEmbeddingConfig" + }, + { + "$ref": "#/$defs/BaseEmbedderConfig" + }, { "type": "string" }, @@ -1028,6 +1132,12 @@ { "$ref": "#/$defs/HashingVectorizerEmbeddingConfig" }, + { + "$ref": "#/$defs/VllmEmbeddingConfig" + }, + { + "$ref": "#/$defs/BaseEmbedderConfig" + }, { "type": "string" }, @@ -1210,6 +1320,12 @@ "title": "Val Fraction", "type": "number" }, + "seed": { + "default": 42, + "description": "Random seed for train/val split and fine-tuning.", + "title": "Seed", + "type": "integer" + }, "fp16": { "default": false, "title": "Fp16", @@ -1377,6 +1493,12 @@ { "$ref": "#/$defs/HashingVectorizerEmbeddingConfig" }, + { + "$ref": "#/$defs/VllmEmbeddingConfig" + }, + { + "$ref": "#/$defs/BaseEmbedderConfig" + }, { "type": "string" }, @@ -1407,6 +1529,12 @@ { "$ref": "#/$defs/HashingVectorizerEmbeddingConfig" }, + { + "$ref": "#/$defs/VllmEmbeddingConfig" + }, + { + "$ref": "#/$defs/BaseEmbedderConfig" + }, { "type": "string" }, @@ -1926,6 +2054,12 @@ { "$ref": "#/$defs/HashingVectorizerEmbeddingConfig" }, + { + "$ref": "#/$defs/VllmEmbeddingConfig" + }, + { + "$ref": "#/$defs/BaseEmbedderConfig" + }, { "type": "string" }, @@ -2125,6 +2259,12 @@ { "$ref": "#/$defs/HashingVectorizerEmbeddingConfig" }, + { + "$ref": "#/$defs/VllmEmbeddingConfig" + }, + { + "$ref": "#/$defs/BaseEmbedderConfig" + }, { "type": "string" }, @@ -2184,6 +2324,12 @@ { "$ref": "#/$defs/HashingVectorizerEmbeddingConfig" }, + { + "$ref": "#/$defs/VllmEmbeddingConfig" + }, + { + "$ref": "#/$defs/BaseEmbedderConfig" + }, { "type": "string" }, @@ -2336,6 +2482,12 @@ { "$ref": "#/$defs/HashingVectorizerEmbeddingConfig" }, + { + "$ref": "#/$defs/VllmEmbeddingConfig" + }, + { + "$ref": "#/$defs/BaseEmbedderConfig" + }, { "type": "string" 
}, @@ -2465,6 +2617,20 @@ "title": "Batch Size", "type": "integer" }, + "max_tokens_in_batch": { + "anyOf": [ + { + "exclusiveMinimum": 0, + "type": "integer" + }, + { + "type": "null" + } + ], + "default": 200000, + "description": "When set, cap each embeddings API call by the summed tiktoken length of inputs (using the encoding for `model_name`). Requests are also limited to at most `batch_size` strings. Use values around 200000 to avoid OpenAI `max_tokens_per_request` errors on long texts. Requires `tiktoken` (installed with `autointent[openai]`).", + "title": "Max Tokens In Batch" + }, "max_retries": { "default": 3, "description": "Maximum number of retries for failed API requests.", @@ -3172,6 +3338,12 @@ { "$ref": "#/$defs/HashingVectorizerEmbeddingConfig" }, + { + "$ref": "#/$defs/VllmEmbeddingConfig" + }, + { + "$ref": "#/$defs/BaseEmbedderConfig" + }, { "type": "string" }, @@ -3237,6 +3409,12 @@ { "$ref": "#/$defs/HashingVectorizerEmbeddingConfig" }, + { + "$ref": "#/$defs/VllmEmbeddingConfig" + }, + { + "$ref": "#/$defs/BaseEmbedderConfig" + }, { "type": "string" }, @@ -3653,6 +3831,12 @@ { "$ref": "#/$defs/HashingVectorizerEmbeddingConfig" }, + { + "$ref": "#/$defs/VllmEmbeddingConfig" + }, + { + "$ref": "#/$defs/BaseEmbedderConfig" + }, { "type": "string" }, @@ -3838,6 +4022,155 @@ "title": "TunableDecisionInitModel", "type": "object" }, + "VllmEmbeddingConfig": { + "additionalProperties": false, + "description": "Configuration for vLLM-based embeddings.", + "properties": { + "default_prompt": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Default prompt for the model. 
This is used when no task-specific prompt is provided.", + "title": "Default Prompt" + }, + "classification_prompt": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Prompt for classifier.", + "title": "Classification Prompt" + }, + "cluster_prompt": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Prompt for clustering.", + "title": "Cluster Prompt" + }, + "sts_prompt": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Prompt for finding most similar sentences.", + "title": "Sts Prompt" + }, + "query_prompt": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Prompt for query.", + "title": "Query Prompt" + }, + "passage_prompt": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Prompt for passage.", + "title": "Passage Prompt" + }, + "use_cache": { + "default": true, + "description": "Whether to use embeddings caching.", + "title": "Use Cache", + "type": "boolean" + }, + "model_name": { + "default": "BAAI/bge-base-en-v1.5", + "description": "Name of the HuggingFace model to load via vLLM.", + "title": "Model Name", + "type": "string" + }, + "batch_size": { + "default": 32, + "description": "Number of texts to encode per vLLM encode() call.", + "title": "Batch Size", + "type": "integer" + }, + "max_model_len": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Maximum sequence length. 
Reduces VRAM usage for long-context models.", + "title": "Max Model Len" + }, + "gpu_memory_utilization": { + "default": 0.9, + "description": "Fraction of GPU memory vLLM is allowed to use (0.0 to 1.0).", + "maximum": 1.0, + "minimum": 0.0, + "title": "Gpu Memory Utilization", + "type": "number" + }, + "dtype": { + "default": "auto", + "description": "Data type for model weights: 'auto', 'float16', 'bfloat16', 'float32'.", + "title": "Dtype", + "type": "string" + }, + "trust_remote_code": { + "default": false, + "description": "Whether to trust remote code when loading the model.", + "title": "Trust Remote Code", + "type": "boolean" + }, + "extra_init_kwargs": { + "additionalProperties": true, + "description": "Extra keyword arguments passed to the vLLM LLM() constructor.", + "title": "Extra Init Kwargs", + "type": "object" + }, + "extra_encode_kwargs": { + "additionalProperties": true, + "description": "Extra keyword arguments passed to llm.encode() at inference time (e.g. custom SamplingParams).", + "title": "Extra Encode Kwargs", + "type": "object" + } + }, + "title": "VllmEmbeddingConfig", + "type": "object" + }, "VocabConfig": { "additionalProperties": false, "properties": { diff --git a/src/autointent/_wrappers/embedder/sentence_transformers.py b/src/autointent/_wrappers/embedder/sentence_transformers.py index 4308625a..9ad39567 100644 --- a/src/autointent/_wrappers/embedder/sentence_transformers.py +++ b/src/autointent/_wrappers/embedder/sentence_transformers.py @@ -33,6 +33,21 @@ logger = logging.getLogger(__name__) +def _set_training_seed(seed: int) -> None: + import random + + random.seed(seed) + np.random.seed(seed) # noqa: NPY002 + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + + require("transformers", extra="transformers") + from transformers import set_seed + + set_seed(seed) + + @lru_cache(maxsize=128) def _get_latest_commit_hash(model_name: str) -> str: """Get the latest commit hash for a given Hugging 
Face model. @@ -239,6 +254,7 @@ def train(self, utterances: list[str], labels: ListOfLabels, config: EmbedderFin return model = self._load_model() + _set_training_seed(config.seed) # Lazy import sentence-transformers training components (only needed for fine-tuning) require("sentence_transformers", extra="sentence-transformers") @@ -252,7 +268,13 @@ def train(self, utterances: list[str], labels: ListOfLabels, config: EmbedderFin ) from transformers import EarlyStoppingCallback - x_train, x_val, y_train, y_val = train_test_split(utterances, labels, test_size=config.val_fraction) + x_train, x_val, y_train, y_val = train_test_split( + utterances, + labels, + test_size=config.val_fraction, + random_state=config.seed, + stratify=labels, + ) tr_ds = Dataset.from_dict({"text": x_train, "label": y_train}) val_ds = Dataset.from_dict({"text": x_val, "label": y_val}) @@ -269,6 +291,8 @@ def train(self, utterances: list[str], labels: ListOfLabels, config: EmbedderFin warmup_ratio=config.warmup_ratio, fp16=config.fp16, bf16=config.bf16, + seed=config.seed, + data_seed=config.seed, batch_sampler=training_args.BatchSamplers.NO_DUPLICATES, metric_for_best_model="eval_loss", load_best_model_at_end=True, diff --git a/src/autointent/configs/_transformers.py b/src/autointent/configs/_transformers.py index 90bed58e..1519f340 100644 --- a/src/autointent/configs/_transformers.py +++ b/src/autointent/configs/_transformers.py @@ -28,6 +28,7 @@ class EmbedderFineTuningConfig(BaseModel): early_stopping_patience: int = Field(default=1) early_stopping_threshold: float = Field(default=0.0) val_fraction: float = Field(default=0.2) + seed: int = Field(default=42, description="Random seed for train/val split and fine-tuning.") fp16: bool = Field(default=False) bf16: bool = Field(default=False)