diff --git a/README.md b/README.md
index d0693836..877a3890 100644
--- a/README.md
+++ b/README.md
@@ -213,7 +213,7 @@ from twinkle import init_tinker_client
 from twinkle.dataloader import DataLoader
 from twinkle.dataset import Dataset, DatasetMeta
 from twinkle.preprocessor import SelfCognitionProcessor
-from twinkle.server.tinker.common import input_feature_to_datum
+from twinkle.server.common import input_feature_to_datum

 base_model = 'ms://Qwen/Qwen3-30B-A3B-Instruct-2507'
 base_url='your-base-url'
diff --git a/README_ZH.md b/README_ZH.md
index 11a6cccc..2ded262f 100644
--- a/README_ZH.md
+++ b/README_ZH.md
@@ -193,7 +193,7 @@ from twinkle import init_tinker_client
 from twinkle.dataloader import DataLoader
 from twinkle.dataset import Dataset, DatasetMeta
 from twinkle.preprocessor import SelfCognitionProcessor
-from twinkle.server.tinker.common import input_feature_to_datum
+from twinkle.server.common import input_feature_to_datum

 base_model = 'ms://Qwen/Qwen3-30B-A3B-Instruct-2507'
 base_url='your-base-url'
diff --git a/client_tools/client_generator.py b/client_tools/client_generator.py
index c337c464..c0df54d3 100644
--- a/client_tools/client_generator.py
+++ b/client_tools/client_generator.py
@@ -243,7 +243,7 @@ def build_imports() -> Tuple[List[str], str]:
     if typing_imports:
         lines.append(f"from typing import {', '.join(sorted(typing_imports))}")
     lines.extend([
-        'from twinkle_client.http import http_post, heartbeat_manager',
+        'from twinkle_client.http import http_post',
     ])
     lines.extend(sorted(twinkle_imports))
@@ -274,7 +274,7 @@ def build_method(name: str, signature: str) -> str:
     code = f'''
 def {name}(self{sig_part}):
     response = http_post(
-        url=f'{{self.server_url}}/processors/call',
+        url=f'{{self.server_url}}/call',
         json_data={{
             'processor_id': self.processor_id,
             'function': '{name}',
@@ -288,7 +288,7 @@ def {name}(self{sig_part}):
     code += '''
 def __next__(self):
     response = http_post(
-        url=f'{self.server_url}/processors/call',
+        url=f'{self.server_url}/call',
         json_data={
             'processor_id': self.processor_id,
             'function': '__next__',
@@ -346,10 +346,10 @@ class {class_name}({inheritance}):

 def __init__({init_params}):
     from twinkle_client.http import get_base_url
-    self.server_url = get_base_url()
+    self.server_url = f'{{get_base_url()}}/processor/twinkle'

     response = http_post(
-        url=f'{{self.server_url}}/processors/create',
+        url=f'{{self.server_url}}/create',
         json_data={{
             'processor_type': '{processor_type}',
             'class_type': '{class_name}',
@@ -358,13 +358,6 @@ def __init__({init_params}):
     )
     response.raise_for_status()
     self.processor_id = response.json()['processor_id']
-    heartbeat_manager.register_processor(self.processor_id)
-
-def __del__(self):
-    try:
-        heartbeat_manager.unregister_processor(self.processor_id)
-    except:
-        pass
 '''
@@ -444,18 +437,25 @@ def generate_models():
     client_module_path = src_client_path / 'model'
     client_module_path.mkdir(parents=True, exist_ok=True)

-    model_code = AUTO_GEN_WARNING + '''from typing import Any, Optional, Union, Type, Dict, Literal, List
-import uuid
-from twinkle_client.http import http_post, heartbeat_manager
-from twinkle import DeviceMesh
-from twinkle.data_format import InputFeature, Trajectory
+    model_code = AUTO_GEN_WARNING + '''from typing import Any, Dict, Optional
+from twinkle_client.http import http_post
+from twinkle_client.types.model import (
+    CalculateLossResponse,
+    CalculateMetricResponse,
+    ClipGradNormResponse,
+    ForwardBackwardResponse,
+    ForwardResponse,
+    GetStateDictResponse,
+    GetTrainConfigsResponse,
+    SaveResponse,
+)


 class MultiLoraTransformersModel:
     """Client wrapper for TwinkleModel that calls server HTTP endpoints.

     This client manages adapters and sends training/inference requests to the model server.
-    Each adapter has its own lifecycle managed through automatic heartbeats.
+    The server-side session (managed by TwinkleClient) keeps the model alive.
     """

     def __init__(self, model_id: str, **kwargs):
@@ -466,215 +466,214 @@ def __init__(self, model_id: str, **kwargs):
         self.model_id = model_id
         if '://' in model_id:
             model_id = model_id.split('://')[1]
-        self.server_url = f'{self.server_url}/models/{model_id}'
+        self.server_url = f'{self.server_url}/model/{model_id}/twinkle'
         self.adapter_name = None
         response = http_post(
             url=f'{self.server_url}/create',
         )
         response.raise_for_status()

-    def _send_adapter_heartbeat(self):
-        """Internal method to send adapter heartbeat."""
-        response = http_post(
-            url=f'{self.server_url}/heartbeat',
-            json_data={'adapter_name': self.adapter_name}
-        )
-        response.raise_for_status()
-
-    def add_adapter_to_model(self, adapter_name: str, config: Dict[str, Any], **kwargs):
-        """Add a new adapter to the model and start automatic heartbeat."""
+    def add_adapter_to_model(self, adapter_name: str, config: Dict[str, Any], **kwargs) -> None:
+        """Add a new adapter to the model."""
         response = http_post(
             url=f'{self.server_url}/add_adapter_to_model',
             json_data={'adapter_name': adapter_name, 'config': config, **kwargs}
         )
         response.raise_for_status()
-
-        # Register adapter for automatic heartbeat after successful creation
         self.adapter_name = adapter_name
-        heartbeat_manager.register_adapter(
-            self.adapter_name,
-            self._send_adapter_heartbeat
-        )
-
-    def __del__(self):
-        """Cleanup: unregister adapter from heartbeat manager."""
-        try:
-            heartbeat_manager.unregister_adapter(self.adapter_name)
-        except:
-            pass

-    def forward(self, inputs: Any, **kwargs):
+    def forward(self, inputs: Any, **kwargs) -> ForwardResponse:
         """Execute forward pass on the model."""
         response = http_post(
             url=f'{self.server_url}/forward',
             json_data={'inputs': inputs, 'adapter_name': self.adapter_name, **kwargs}
         )
         response.raise_for_status()
-        return response.json()['result']
+        return ForwardResponse(**response.json())

-    def forward_only(self, inputs: Any, **kwargs):
+    def forward_only(self, inputs: Any, **kwargs) -> ForwardResponse:
         """Execute forward pass without gradient computation."""
         response = http_post(
             url=f'{self.server_url}/forward_only',
             json_data={'inputs': inputs, 'adapter_name': self.adapter_name, **kwargs}
         )
         response.raise_for_status()
-        return response.json()['result']
+        return ForwardResponse(**response.json())

-    def calculate_loss(self, **kwargs):
+    def calculate_loss(self, **kwargs) -> CalculateLossResponse:
         """Calculate loss from model outputs."""
         response = http_post(
             url=f'{self.server_url}/calculate_loss',
             json_data={'adapter_name': self.adapter_name, **kwargs}
         )
         response.raise_for_status()
-        return response.json()['result']
+        return CalculateLossResponse(**response.json())

-    def get_train_configs(self, **kwargs):
-        """Get training configs"""
+    def get_train_configs(self, **kwargs) -> GetTrainConfigsResponse:
+        """Get training configs."""
         response = http_post(
             url=f'{self.server_url}/get_train_configs',
             json_data={'adapter_name': self.adapter_name, **kwargs}
         )
         response.raise_for_status()
-        return response.json()['result']
+        return GetTrainConfigsResponse(**response.json())

-    def backward(self, **kwargs):
+    def backward(self, **kwargs) -> None:
         """Execute backward pass."""
         response = http_post(
             url=f'{self.server_url}/backward',
             json_data={'adapter_name': self.adapter_name, **kwargs}
         )
         response.raise_for_status()
-        return response.json()['result']

-    def forward_backward(self, inputs: Any, **kwargs):
+    def forward_backward(self, inputs: Any, **kwargs) -> ForwardBackwardResponse:
         """Execute combined forward and backward pass."""
         response = http_post(
             url=f'{self.server_url}/forward_backward',
             json_data={'inputs': inputs, 'adapter_name': self.adapter_name, **kwargs}
         )
         response.raise_for_status()
-        return response.json()['result']
+        return ForwardBackwardResponse(**response.json())

-    def step(self, **kwargs):
+    def step(self, **kwargs) -> None:
         """Execute optimizer step."""
         response = http_post(
             url=f'{self.server_url}/step',
             json_data={'adapter_name': self.adapter_name, **kwargs}
         )
         response.raise_for_status()
-        return response.json()['result']

-    def zero_grad(self, **kwargs):
+    def zero_grad(self, **kwargs) -> None:
         """Zero out gradients."""
         response = http_post(
             url=f'{self.server_url}/zero_grad',
             json_data={'adapter_name': self.adapter_name, **kwargs}
         )
         response.raise_for_status()
-        return response.json()['result']

-    def lr_step(self, **kwargs):
+    def lr_step(self, **kwargs) -> None:
         """Execute learning rate scheduler step."""
         response = http_post(
             url=f'{self.server_url}/lr_step',
             json_data={'adapter_name': self.adapter_name, **kwargs}
         )
         response.raise_for_status()
-        return response.json()['result']

-    def set_loss(self, loss_cls: str, **kwargs):
-        """Set the loss function."""
+    def clip_grad_norm(self, max_grad_norm: float = 1.0, norm_type: int = 2, **kwargs) -> ClipGradNormResponse:
+        """Clip gradient norm."""
         response = http_post(
-            url=f'{self.server_url}/set_loss',
-            json_data={'loss_cls': loss_cls, 'adapter_name': self.adapter_name, **kwargs}
+            url=f'{self.server_url}/clip_grad_norm',
+            json_data={'max_grad_norm': max_grad_norm, 'norm_type': norm_type, 'adapter_name': self.adapter_name, **kwargs}
         )
         response.raise_for_status()
-        return response.json()['result']
+        return ClipGradNormResponse(**response.json())

-    def clip_grad_norm(self, max_grad_norm: float=1.0, norm_type=2, **kwargs):
-        """Set the loss function."""
+    def clip_grad_and_step(self, max_grad_norm: float = 1.0, norm_type: int = 2, **kwargs) -> None:
+        """Clip gradient norm and execute optimizer step in one call."""
         response = http_post(
-            url=f'{self.server_url}/clip_grad_norm',
+            url=f'{self.server_url}/clip_grad_and_step',
             json_data={'max_grad_norm': max_grad_norm, 'norm_type': norm_type, 'adapter_name': self.adapter_name, **kwargs}
         )
         response.raise_for_status()
-        return response.json()['result']

-    def set_optimizer(self, optimizer_cls: str, **kwargs):
+    def set_loss(self, loss_cls: str, **kwargs) -> None:
+        """Set the loss function."""
+        response = http_post(
+            url=f'{self.server_url}/set_loss',
+            json_data={'loss_cls': loss_cls, 'adapter_name': self.adapter_name, **kwargs}
+        )
+        response.raise_for_status()
+
+    def set_optimizer(self, optimizer_cls: str, **kwargs) -> None:
         """Set the optimizer."""
         response = http_post(
             url=f'{self.server_url}/set_optimizer',
             json_data={'optimizer_cls': optimizer_cls, 'adapter_name': self.adapter_name, **kwargs}
         )
         response.raise_for_status()
-        return response.json()['result']

-    def set_lr_scheduler(self, scheduler_cls: str, **kwargs):
+    def set_lr_scheduler(self, scheduler_cls: str, **kwargs) -> None:
         """Set the learning rate scheduler."""
         response = http_post(
             url=f'{self.server_url}/set_lr_scheduler',
             json_data={'scheduler_cls': scheduler_cls, 'adapter_name': self.adapter_name, **kwargs}
         )
         response.raise_for_status()
-        return response.json()['result']

-    def save(self, name: str, **kwargs):
+    def save(self, name: str, **kwargs) -> SaveResponse:
         """Save model checkpoint."""
         response = http_post(
             url=f'{self.server_url}/save',
             json_data={'name': name, 'adapter_name': self.adapter_name, **kwargs}
         )
         response.raise_for_status()
-        return response.json()['result']
+        return SaveResponse(**response.json())

-    def load(self, name: str, **kwargs):
+    def load(self, name: str, **kwargs) -> None:
         """Load model checkpoint."""
         response = http_post(
             url=f'{self.server_url}/load',
             json_data={'name': name, 'adapter_name': self.adapter_name, **kwargs}
         )
         response.raise_for_status()
-        return response.json()['result']

-    def set_template(self, template_cls: str, **kwargs):
+    def apply_patch(self, patch_cls: str, **kwargs) -> None:
+        """Apply a patch to the model."""
+        response = http_post(
+            url=f'{self.server_url}/apply_patch',
+            json_data={'patch_cls': patch_cls, 'adapter_name': self.adapter_name, **kwargs}
+        )
+        response.raise_for_status()
+
+    def add_metric(self, metric_cls: str, is_training: Optional[bool] = None, **kwargs) -> None:
+        """Add a metric to the model."""
+        response = http_post(
+            url=f'{self.server_url}/add_metric',
+            json_data={'metric_cls': metric_cls, 'is_training': is_training, 'adapter_name': self.adapter_name, **kwargs}
+        )
+        response.raise_for_status()
+
+    def set_template(self, template_cls: str, **kwargs) -> None:
         """Set the template for data processing."""
         response = http_post(
             url=f'{self.server_url}/set_template',
             json_data={'template_cls': template_cls, 'adapter_name': self.adapter_name, 'model_id': self.model_id, **kwargs}
         )
         response.raise_for_status()
-        return response.json()['result']

-    def set_processor(self, processor_cls: str, **kwargs):
+    def set_processor(self, processor_cls: str, **kwargs) -> None:
         """Set the input processor."""
         response = http_post(
             url=f'{self.server_url}/set_processor',
             json_data={'processor_cls': processor_cls, 'adapter_name': self.adapter_name, **kwargs}
         )
         response.raise_for_status()
-        return response.json()['result']

-    def calculate_metric(self, is_training: bool = True, **kwargs):
+    def calculate_metric(self, is_training: bool = True, **kwargs) -> CalculateMetricResponse:
         """Calculate metrics from model outputs."""
         response = http_post(
             url=f'{self.server_url}/calculate_metric',
             json_data={'is_training': is_training, 'adapter_name': self.adapter_name, **kwargs}
         )
         response.raise_for_status()
-        return response.json()['result']
+        return CalculateMetricResponse(**response.json())

-    def get_state_dict(self, **kwargs):
+    def get_state_dict(self, **kwargs) -> GetStateDictResponse:
         """Get model state dictionary."""
         response = http_post(
             url=f'{self.server_url}/get_state_dict',
             json_data={'adapter_name': self.adapter_name, **kwargs}
         )
         response.raise_for_status()
-        return response.json()['result']
+        return GetStateDictResponse(**response.json())

-    def upload_to_hub(self, checkpoint_dir: str, hub_model_id: str, hub_token: Optional[str] = None, async_upload: bool = True):
+    def upload_to_hub(
+        self,
+        checkpoint_dir: str,
+        hub_model_id: str,
+        hub_token: Optional[str] = None,
+        async_upload: bool = True,
+    ) -> None:
         """Upload model checkpoint to hub.

         Args:
@@ -689,11 +688,10 @@ def upload_to_hub(self, checkpoint_dir: str, hub_model_id: str, hub_token: Optio
                 'checkpoint_dir': checkpoint_dir,
                 'hub_model_id': hub_model_id,
                 'hub_token': hub_token,
-                'async_upload': async_upload
+                'async_upload': async_upload,
             }
         )
         response.raise_for_status()
-        return response.json()
 '''

     # Write the model client file
@@ -721,9 +719,10 @@ def generate_samplers():
     client_module_path = src_client_path / 'sampler'
     client_module_path.mkdir(parents=True, exist_ok=True)

-    sampler_code = AUTO_GEN_WARNING + '''from typing import Any, Optional, List, Dict, Union
-from twinkle_client.http import http_post, heartbeat_manager
+    sampler_code = AUTO_GEN_WARNING + '''from typing import Any, Dict, List, Optional, Union
+from twinkle_client.http import http_post
 from twinkle.sampler.base import Sampler
+from twinkle_client.types.sampler import AddAdapterResponse, SampleResponseModel, SetTemplateResponse
 from peft import PeftConfig
 from twinkle.data_format import Trajectory, InputFeature
@@ -732,7 +731,7 @@ class vLLMSampler(Sampler):
     """Client wrapper for Sampler that calls server HTTP endpoints.

     This client manages sampling operations and adapter synchronization with the sampler server.
-    Each adapter has its own lifecycle managed through automatic heartbeats.
+    The server-side session (managed by TwinkleClient) keeps the sampler alive.
     """

     def __init__(self, model_id: str, **kwargs):
@@ -743,25 +742,15 @@ def __init__(self, model_id: str, **kwargs):
         self.adapter_name = None
         if '://' in model_id:
             model_id = model_id.split('://')[1]
-        self.server_url = f'{self.server_url}/samplers/{model_id}'
+        self.server_url = f'{self.server_url}/sampler/{model_id}/twinkle'
         response = http_post(
             url=f'{self.server_url}/create',
             json_data=kwargs
         )
         response.raise_for_status()

-    def _send_adapter_heartbeat(self):
-        """Internal method to send adapter heartbeat."""
-        if not self.adapter_name:
-            return
-        response = http_post(
-            url=f'{self.server_url}/heartbeat',
-            json_data={'adapter_name': self.adapter_name}
-        )
-        response.raise_for_status()
-
-    def add_adapter_to_sampler(self, adapter_name: str, config: PeftConfig, **kwargs):
-        """Add a new adapter to the sampler and start automatic heartbeat."""
+    def add_adapter_to_sampler(self, adapter_name: str, config: PeftConfig, **kwargs) -> AddAdapterResponse:
+        """Add a new adapter to the sampler."""
         if isinstance(config, PeftConfig):
             config = config.__dict__
         response = http_post(
@@ -769,23 +758,8 @@ def add_adapter_to_sampler(self, adapter_name: str, config: PeftConfig, **kwargs
             json_data={'adapter_name': adapter_name, 'config': config, **kwargs}
         )
         response.raise_for_status()
-
-        # Register adapter for automatic heartbeat after successful creation
         self.adapter_name = adapter_name
-        heartbeat_manager.register_adapter(
-            self.adapter_name,
-            self._send_adapter_heartbeat
-        )
-
-        return response.json()
-
-    def __del__(self):
-        """Cleanup: unregister adapter from heartbeat manager."""
-        try:
-            if self.adapter_name:
-                heartbeat_manager.unregister_adapter(self.adapter_name)
-        except:
-            pass
+        return AddAdapterResponse(**response.json())

     def sample(
         self,
@@ -794,7 +768,7 @@ def sample(
         adapter_name: str = '',
         adapter_uri: Optional[str] = None,
         num_samples: int = 1,
-    ) -> Dict[str, Any]:
+    ) -> SampleResponseModel:
         """Sample from the model.

         Args:
@@ -805,7 +779,7 @@ def sample(
             num_samples: Number of completions to generate per prompt.

         Returns:
-            Dict with 'sequences' list, each containing tokens, logprobs, stop_reason.
+            SampleResponseModel with 'sequences' list, each containing tokens, logprobs, stop_reason.
         """
         json_data = {
             'inputs': inputs,
@@ -821,16 +795,16 @@ def sample(
             json_data=json_data
         )
         response.raise_for_status()
-        return response.json()
+        return SampleResponseModel(**response.json())

-    def set_template(self, template_cls: str, adapter_name: str = '', **kwargs):
+    def set_template(self, template_cls: str, adapter_name: str = '', **kwargs) -> SetTemplateResponse:
         """Set the template for encoding trajectories."""
         response = http_post(
             url=f'{self.server_url}/set_template',
             json_data={'template_cls': template_cls, 'adapter_name': adapter_name, **kwargs}
         )
         response.raise_for_status()
-        return response.json()
+        return SetTemplateResponse(**response.json())
 '''

     # Write the sampler client file
diff --git a/cookbook/client/tinker/custom_service/megatron/server.py b/cookbook/client/server/megatron/server.py
similarity index 92%
rename from cookbook/client/tinker/custom_service/megatron/server.py
rename to cookbook/client/server/megatron/server.py
index e38f43a4..abce8cf6 100644
--- a/cookbook/client/tinker/custom_service/megatron/server.py
+++ b/cookbook/client/server/megatron/server.py
@@ -15,7 +15,7 @@

 # Resolve the path to the server config YAML relative to this script's location
 file_dir = os.path.abspath(os.path.dirname(__file__))
-config_path = os.path.join(file_dir, 'server_config.yaml')
+config_path = os.path.join(file_dir, 'server_config_4b.yaml')

 # Launch the Twinkle server — this call blocks until the server is shut down
 launch_server(config_path=config_path)
diff --git a/cookbook/client/tinker/modelscope_service/server_config.yaml b/cookbook/client/server/megatron/server_config.yaml
similarity index 89%
rename from cookbook/client/tinker/modelscope_service/server_config.yaml
rename to cookbook/client/server/megatron/server_config.yaml
index 18b0c1d2..becda8b0 100644
--- a/cookbook/client/tinker/modelscope_service/server_config.yaml
+++ b/cookbook/client/server/megatron/server_config.yaml
@@ -1,8 +1,5 @@
 # Twinkle Server Configuration - Tinker-Compatible Transformers Backend

-# Server protocol type: "tinker" enables the Tinker-compatible API
-server_type: tinker
-
 # proxy_location: determines where the HTTP proxy runs.
 # "EveryNode" means each Ray node runs its own proxy (good for multi-node).
 proxy_location: EveryNode
@@ -26,6 +23,7 @@ applications:
     deployments:
       - name: TinkerCompatServer
+        max_ongoing_requests: 50
         autoscaling_config:
           min_replicas: 1  # Minimum number of replicas
           max_replicas: 1  # Maximum number of replicas
@@ -110,3 +108,25 @@ applications:
       runtime_env:
         env_vars:
           TWINKLE_TRUST_REMOTE_CODE: "0"
+
+  # 4. Processor Service
+  - name: processor
+    route_prefix: /api/v1/processor
+    import_path: processor
+    args:
+      ncpu_proc_per_node: 2
+      device_group:
+        name: model
+        ranks: 2
+        device_type: CPU
+      device_mesh:
+        device_type: CPU
+        dp_size: 2
+    deployments:
+      - name: ProcessorManagement
+        autoscaling_config:
+          min_replicas: 1
+          max_replicas: 1
+          target_ongoing_requests: 128
+        ray_actor_options:
+          num_cpus: 0.1
diff --git a/cookbook/client/tinker/custom_service/megatron/server_config.yaml b/cookbook/client/server/megatron/server_config_4b.yaml
similarity index 89%
rename from cookbook/client/tinker/custom_service/megatron/server_config.yaml
rename to cookbook/client/server/megatron/server_config_4b.yaml
index a8103b76..0ea99551 100644
--- a/cookbook/client/tinker/custom_service/megatron/server_config.yaml
+++ b/cookbook/client/server/megatron/server_config_4b.yaml
@@ -1,8 +1,5 @@
 # Twinkle Server Configuration - Tinker-Compatible Transformers Backend

-# Server protocol type: "tinker" enables the Tinker-compatible API
-server_type: tinker
-
 # proxy_location: determines where the HTTP proxy runs.
 # "EveryNode" means each Ray node runs its own proxy (good for multi-node).
 proxy_location: EveryNode
@@ -27,6 +24,7 @@ applications:
       - Qwen/Qwen3.5-4B
     deployments:
       - name: TinkerCompatServer
+        max_ongoing_requests: 50
         autoscaling_config:
           min_replicas: 1  # Minimum number of replicas
           max_replicas: 1  # Maximum number of replicas
@@ -106,3 +104,25 @@ applications:
       runtime_env:
         env_vars:
           TWINKLE_TRUST_REMOTE_CODE: "0"
+
+  # 4. Processor Service
+  - name: processor
+    route_prefix: /api/v1/processor
+    import_path: processor
+    args:
+      ncpu_proc_per_node: 2
+      device_group:
+        name: model
+        ranks: 2
+        device_type: CPU
+      device_mesh:
+        device_type: CPU
+        dp_size: 2
+    deployments:
+      - name: ProcessorManagement
+        autoscaling_config:
+          min_replicas: 1
+          max_replicas: 1
+          target_ongoing_requests: 128
+        ray_actor_options:
+          num_cpus: 0.1
diff --git a/cookbook/client/tinker/custom_service/transformer/server.py b/cookbook/client/server/transformer/server.py
similarity index 100%
rename from cookbook/client/tinker/custom_service/transformer/server.py
rename to cookbook/client/server/transformer/server.py
diff --git a/cookbook/client/tinker/custom_service/transformer/server_config.yaml b/cookbook/client/server/transformer/server_config.yaml
similarity index 87%
rename from cookbook/client/tinker/custom_service/transformer/server_config.yaml
rename to cookbook/client/server/transformer/server_config.yaml
index e79ad6f2..570142af 100644
--- a/cookbook/client/tinker/custom_service/transformer/server_config.yaml
+++ b/cookbook/client/server/transformer/server_config.yaml
@@ -1,8 +1,5 @@
 # Twinkle Server Configuration - Tinker-Compatible Transformers Backend

-# Server protocol type: "tinker" enables the Tinker-compatible API
-server_type: tinker
-
 # proxy_location: determines where the HTTP proxy runs.
 # "EveryNode" means each Ray node runs its own proxy (good for multi-node).
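 # (Ray Serve's ProxyLocation setting also accepts "HeadOnly", a single proxy on
 # the head node only, and "Disabled" for no HTTP proxy at all.)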
 proxy_location: EveryNode
@@ -27,6 +24,7 @@ applications:
       - Qwen/Qwen3.5-4B
     deployments:
       - name: TinkerCompatServer
+        max_ongoing_requests: 50
         autoscaling_config:
           min_replicas: 1  # Minimum number of replicas
           max_replicas: 1  # Maximum number of replicas
@@ -43,14 +41,14 @@ applications:
       use_megatron: false  # Use HuggingFace Transformers backend
       model_id: "ms://Qwen/Qwen3.5-4B"  # ModelScope model identifier
       max_length: 10240
-      nproc_per_node: 2  # Number of GPU processes per node
+      nproc_per_node: 1  # Number of GPU processes per node
       device_group:
         name: model
-        ranks: 2
+        ranks: 1
         device_type: cuda
       device_mesh:
         device_type: cuda
-        dp_size: 2
+        dp_size: 1
       queue_config:
         rps_limit: 100  # Max requests per second
         tps_limit: 100000  # Max tokens per second
@@ -103,3 +101,25 @@ applications:
       runtime_env:
         env_vars:
           TWINKLE_TRUST_REMOTE_CODE: "0"
+
+  # 4. Processor Service
+  - name: processor
+    route_prefix: /api/v1/processor
+    import_path: processor
+    args:
+      ncpu_proc_per_node: 2
+      device_group:
+        name: model
+        ranks: 2
+        device_type: CPU
+      device_mesh:
+        device_type: CPU
+        dp_size: 2
+    deployments:
+      - name: ProcessorManagement
+        autoscaling_config:
+          min_replicas: 1
+          max_replicas: 1
+          target_ongoing_requests: 128
+        ray_actor_options:
+          num_cpus: 0.1
diff --git a/cookbook/client/tinker/modelscope_service/sample.py b/cookbook/client/tinker/modelscope/sample.py
similarity index 100%
rename from cookbook/client/tinker/modelscope_service/sample.py
rename to cookbook/client/tinker/modelscope/sample.py
diff --git a/cookbook/client/tinker/modelscope_service/self_cognition.py b/cookbook/client/tinker/modelscope/self_cognition.py
similarity index 98%
rename from cookbook/client/tinker/modelscope_service/self_cognition.py
rename to cookbook/client/tinker/modelscope/self_cognition.py
index f8d2a607..cb3b1700 100644
--- a/cookbook/client/tinker/modelscope_service/self_cognition.py
+++ b/cookbook/client/tinker/modelscope/self_cognition.py
@@ -15,7 +15,7 @@
 from twinkle.dataloader import DataLoader
 from twinkle.dataset import Dataset, DatasetMeta
 from twinkle.preprocessor import SelfCognitionProcessor
-from twinkle.server.tinker.common import input_feature_to_datum
+from twinkle.server.common import input_feature_to_datum

 # Initialize the Tinker client before importing ServiceClient
 init_tinker_client()
diff --git a/cookbook/client/tinker/modelscope_service/short_math_grpo.py b/cookbook/client/tinker/modelscope/short_math_grpo.py
similarity index 100%
rename from cookbook/client/tinker/modelscope_service/short_math_grpo.py
rename to cookbook/client/tinker/modelscope/short_math_grpo.py
diff --git a/cookbook/client/tinker/modelscope_service/server.py b/cookbook/client/tinker/modelscope_service/server.py
deleted file mode 100644
index e38f43a4..00000000
--- a/cookbook/client/tinker/modelscope_service/server.py
+++ /dev/null
@@ -1,21 +0,0 @@
-# Twinkle Server Launcher - Tinker-Compatible Megatron Backend
-#
-# This script starts the Twinkle server with Tinker-compatible API support
-# using the Megatron model backend.
-# It reads the server_config.yaml in the same directory for all
-# configuration (model, deployment settings, etc.).
-# Run this script BEFORE running the client training script (lora.py).
-
-import os
-
-# Enable Ray debug mode for verbose logging during development
-os.environ['TWINKLE_TRUST_REMOTE_CODE'] = '1'
-
-from twinkle.server import launch_server
-
-# Resolve the path to server_config.yaml relative to this script's location
-file_dir = os.path.abspath(os.path.dirname(__file__))
-config_path = os.path.join(file_dir, 'server_config.yaml')
-
-# Launch the Twinkle server — this call blocks until the server is shut down
-launch_server(config_path=config_path)
diff --git a/cookbook/client/tinker/custom_service/lora.py b/cookbook/client/tinker/self_host/lora.py
similarity index 100%
rename from cookbook/client/tinker/custom_service/lora.py
rename to cookbook/client/tinker/self_host/lora.py
diff --git a/cookbook/client/tinker/custom_service/sample.py b/cookbook/client/tinker/self_host/sample.py
similarity index 100%
rename from cookbook/client/tinker/custom_service/sample.py
rename to cookbook/client/tinker/self_host/sample.py
diff --git a/cookbook/client/tinker/custom_service/self_cognition.py b/cookbook/client/tinker/self_host/self_cognition.py
similarity index 94%
rename from cookbook/client/tinker/custom_service/self_cognition.py
rename to cookbook/client/tinker/self_host/self_cognition.py
index e285cc7f..6951760d 100644
--- a/cookbook/client/tinker/custom_service/self_cognition.py
+++ b/cookbook/client/tinker/self_host/self_cognition.py
@@ -16,7 +16,7 @@
 from twinkle.dataloader import DataLoader
 from twinkle.dataset import Dataset, DatasetMeta
 from twinkle.preprocessor import SelfCognitionProcessor
-from twinkle.server.tinker.common import input_feature_to_datum
+from twinkle.server.common import input_feature_to_datum

 # Initialize the Tinker client before importing ServiceClient
 init_tinker_client()
@@ -92,9 +92,9 @@ def eval():

     # Step 1: Load the trained LoRA checkpoint for inference
     # Path to a previously saved LoRA checkpoint (twinkle:// URI)
-    weight_path = 'twinkle://20260212_174205-Qwen_Qwen2_5-7B-Instruct-51edc9ed/weights/twinkle-lora-2'
+    weight_path = 'twinkle://20260301_142318-Qwen_Qwen3-4B-199d2cdb/weights/twinkle-lora-0'

-    service_client = ServiceClient(base_url=base_url, api_key=os.environ.get('MODELSCOPE_TOKEN'))
+    service_client = ServiceClient(base_url=base_url, api_key=api_key)
     sampling_client = service_client.create_sampling_client(model_path=weight_path, base_model=base_model)

     # Step 2: Prepare the chat prompt
@@ -119,7 +119,6 @@ def eval():
     params = types.SamplingParams(
         max_tokens=50,  # Maximum tokens to generate
         temperature=0.2,  # Low temperature for more focused responses
-        stop=['\n']  # Stop at newline
     )

     # Sample 8 independent completions
diff --git a/cookbook/client/tinker/custom_service/short_math_grpo.py b/cookbook/client/tinker/self_host/short_math_grpo.py
similarity index 100%
rename from cookbook/client/tinker/custom_service/short_math_grpo.py
rename to cookbook/client/tinker/self_host/short_math_grpo.py
diff --git a/cookbook/client/twinkle/megatron/server.py b/cookbook/client/twinkle/megatron/server.py
deleted file mode 100644
index 3e58a5a9..00000000
--- a/cookbook/client/twinkle/megatron/server.py
+++ /dev/null
@@ -1,20 +0,0 @@
-# Twinkle Server Launcher - Megatron Backend
-#
-# This script starts the Twinkle server using Ray Serve with Megatron support.
-# It reads the server_config.yaml in the same directory for all
-# configuration (model, processor, deployment settings, etc.).
-# Run this script BEFORE running the client training script (lora.py).
-
-import os
-
-# Enable Ray debug mode for verbose logging during development
-os.environ['RAY_DEBUG'] = '1'
-
-from twinkle.server import launch_server
-
-# Resolve the path to server_config.yaml relative to this script's location
-file_dir = os.path.abspath(os.path.dirname(__file__))
-config_path = os.path.join(file_dir, 'server_config.yaml')
-
-# Launch the Twinkle server — this call blocks until the server is shut down
-launch_server(config_path=config_path)
diff --git a/cookbook/client/twinkle/megatron/server_config.yaml b/cookbook/client/twinkle/megatron/server_config.yaml
deleted file mode 100644
index c8efe648..00000000
--- a/cookbook/client/twinkle/megatron/server_config.yaml
+++ /dev/null
@@ -1,85 +0,0 @@
-# Twinkle Server Configuration - Megatron Backend
-
-# Server protocol type: "twinkle" for the native Twinkle client protocol
-server_type: twinkle
-
-# proxy_location: determines where the HTTP proxy runs.
-# "EveryNode" means each Ray node runs its own proxy (good for multi-node).
-proxy_location: EveryNode
-
-# HTTP listener settings
-http_options:
-  host: 0.0.0.0  # Listen on all network interfaces
-  port: 8000  # Port number for the server
-
-# Applications: each entry defines a service component deployed on the server
-applications:
-
-  # 1. TwinkleServer - The central management server
-  #    Handles client connections, training run tracking, checkpoint listing.
-  - name: server
-    route_prefix: /server  # API endpoint prefix
-    import_path: server  # Python module to import
-    args:
-      server_config:
-        per_token_model_limit: 3  # Maximum number of models (adapters) per token (server-globally enforced)
-    deployments:
-      - name: TwinkleServer
-        autoscaling_config:
-          min_replicas: 1  # Minimum number of replicas
-          max_replicas: 1  # Maximum number of replicas
-          target_ongoing_requests: 128  # Target concurrent requests per replica
-        ray_actor_options:
-          num_cpus: 0.1  # CPU resources allocated to this actor
-
-  # 2. Model Service - Hosts the base model for training (Megatron backend)
-  #    This is the actual model worker that performs forward/backward passes.
-  - name: models-Qwen3.5-4B
-    route_prefix: /models/Qwen/Qwen3.5-4B  # REST path for this model
-    import_path: model
-    args:
-      use_megatron: true  # Use Megatron-LM backend (not HuggingFace)
-      mixed_precision: bf16
-      model_id: "ms://Qwen/Qwen3.5-4B"  # ModelScope model identifier to load
-      nproc_per_node: 2  # Number of GPU processes per node
-      device_group:  # Logical device group for this model
-        name: model
-        ranks: 2  # Number of GPUs to use
-        device_type: cuda
-      device_mesh:  # Distributed training mesh configuration
-        device_type: cuda
-        dp_size: 2  # Data parallel size
-      adapter_config:
-        adapter_timeout: 1800  # Seconds before idle adapter unload
-    deployments:
-      - name: ModelManagement
-        autoscaling_config:
-          min_replicas: 1
-          max_replicas: 1
-          target_ongoing_requests: 16
-        ray_actor_options:
-          num_cpus: 0.1
-
-  # 3. Processor Service - Handles data preprocessing on CPU
-  #    Runs tokenization, template application, and other CPU-bound tasks.
-  - name: processor
-    route_prefix: /processors
-    import_path: processor
-    args:
-      nproc_per_node: 2  # Number of processor workers per node
-      ncpu_proc_per_node: 2  # Number of CPU processes per node
-      device_group:
-        name: model
-        ranks: 2  # Number of CPU workers to use
-        device_type: CPU
-      device_mesh:
-        device_type: CPU
-        dp_size: 2  # Data parallel size
-    deployments:
-      - name: ProcessorManagement
-        autoscaling_config:
-          min_replicas: 1
-          max_replicas: 1
-          target_ongoing_requests: 128
-        ray_actor_options:
-          num_cpus: 0.1
diff --git a/cookbook/client/twinkle/grpo.py b/cookbook/client/twinkle/self_host/grpo.py
similarity index 61%
rename from cookbook/client/twinkle/grpo.py
rename to cookbook/client/twinkle/self_host/grpo.py
index 1f7c0553..883b2323 100644
--- a/cookbook/client/twinkle/grpo.py
+++ b/cookbook/client/twinkle/self_host/grpo.py
@@ -22,16 +22,14 @@
 import dotenv

 dotenv.load_dotenv('.env')

-import re
-from twinkle.data_format import Trajectory
-from twinkle.reward.base import Reward
 import gc
 import os

 from peft import LoraConfig
-from typing import List, Tuple
+from typing import List, Tuple, Dict, Any

 from twinkle import get_logger
+from twinkle.reward import GSM8KAccuracyReward, GSM8KFormatReward
 from twinkle.advantage import GRPOAdvantage
 from twinkle.dataset import DatasetMeta
 from twinkle.metric import CompletionRewardMetric
@@ -40,6 +38,7 @@
 from twinkle_client.dataset import Dataset
 from twinkle_client.model import MultiLoraTransformersModel
 from twinkle_client.sampler import vLLMSampler
+from twinkle.preprocessor.llm import GSM8KProcessor

 logger = get_logger()
@@ -55,62 +54,22 @@
 GRADIENT_ACCUMULATION_STEPS = 4


-def create_countdown_dataset():
-    """Create Countdown Game dataset for GRPO training."""
-
-    dataset = Dataset(dataset_meta=DatasetMeta('ms://zouxuhong/Countdown-Tasks-3to4', data_slice=range(500)))
-    dataset.set_template('Template', model_id=MODEL_ID, max_length=8192)
-    dataset.map('CountdownProcessor')
-    dataset.encode(add_generation_prompt=True, batched=True)
+def create_gsm8k_dataset():
+    dataset = Dataset(DatasetMeta('ms://modelscope/gsm8k', subset_name='main', split='train'))
+    dataset.set_template('Template', model_id=MODEL_ID, max_length=2048)
+    dataset.map('GSM8KProcessor')
+    dataset.encode(add_generation_prompt=True)
     return dataset


+def compute_rewards(
+    trajectories: List[Dict[str, Any]],
+) -> Tuple[List[float], List[float], List[float]]:
+    """Compute total, format, and accuracy rewards for GSM8K completions."""
+    accuracy_reward_fn = GSM8KAccuracyReward()
+    format_reward_fn = GSM8KFormatReward()

-class CountDownAccuracy(Reward):
-
-    @staticmethod
-    def countdown_accuracy_reward(completion: str, target: int, nums: List[int]) -> float:
-        """Accuracy reward: checks if equation is correct."""
-        try:
-            match = re.search(r'<answer>(.*?)<\/answer>', completion)
-            if match is None:
-                return 0.0
-            equation = match.group(1).strip()
-            if '=' in equation:
-                equation = equation.split('=')[0]
-            used_numbers = [int(n) for n in re.findall(r'\d+', equation)]
-            if sorted(used_numbers) != sorted(nums):
-                return 0.0
-            if not re.match(r'^[\d+\-*/().\s]+$', equation):
-                return 0.0
-            result = eval(equation, {'__builtins__': None}, {})
-            return 1.0 if abs(float(result) - float(target)) < 1e-5 else 0.0
-        except Exception:  # noqa
-            return 0.0
-
-    def __call__(self, trajectories: List[Trajectory], ground_truths: List[Trajectory]):
-        rewards = []
-        for trajectory in trajectories:
-            messages = trajectory.get('messages', [])
-            completion = ''
-            for msg in reversed(messages):
-                if msg.get('role') == 'assistant':
-                    completion = msg.get('content', '')
-                    break
-            user_data = trajectory.get('user_data', [{}])
-            data = user_data[0] if isinstance(user_data, list) and user_data else {}
-            target = data.get('target', 0)
-            nums = data.get('nums', [])
-            acc_reward = self.countdown_accuracy_reward(completion, target, nums)
-            rewards.append(acc_reward)
-        return rewards
-
-
-def compute_rewards(trajectories: List[dict], ) -> Tuple[List[float], List[float], List[float]]:
-    """Compute format and accuracy rewards for Countdown game."""
-    from twinkle.reward import FormatReward
-    format_rewards = FormatReward()(trajectories, [])
-    accuracy_rewards = CountDownAccuracy()(trajectories, [])
-    total_rewards = [a + b for a, b in zip(accuracy_rewards, format_rewards)]
+    accuracy_rewards = accuracy_reward_fn(trajectories)
+    format_rewards = format_reward_fn(trajectories)
+    total_rewards = [a + f for a, f in zip(accuracy_rewards, format_rewards)]
     return total_rewards, format_rewards, accuracy_rewards
@@ -122,7 +81,7 @@ def train():
     )

     # Step 2: Prepare dataset and dataloader
-    dataset = create_countdown_dataset()
+    dataset = create_gsm8k_dataset()
     dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE)

     # Step 3: Configure the training model
@@ -185,11 +144,11 @@ def train():
         # the resulting path to the sampler as adapter_uri
         if step % SYNC_INTERVAL == 0:
             logger.info(f'Step {step}: Saving weights for sampler...')
-            twinkle_path = model.save(
+            result = model.save(
                 name=f'grpo-sampler-step-{step}',
                 save_optimizer=False,
             )
-            current_adapter_uri = twinkle_path
+            current_adapter_uri = result.twinkle_path
             logger.info(f'Step {step}: Saved weights to {current_adapter_uri}')

         # ========== 2. Sample completions ==========
@@ -200,32 +159,29 @@ def train():
             num_samples=NUM_GENERATIONS,
         )

-        input_features = []
-        old_logps_list = []
-        completion_lengths = []
+        all_input_data: List[Dict[str, Any]] = []
+        all_old_logps: List[List[float]] = []
+        all_completion_lengths: List[int] = []

-        sequences = sample_response.get('sequences', [])
-        for seq in sequences:
-            input_features.append(seq.get('new_input_feature', seq))
-            old_logps_list.append(seq.get('logprobs', []))
-            completion_lengths.append(len(seq.get('tokens', [])))
-
-        if not input_features:
-            logger.warning(f'Step {step}: No valid samples, skipping')
-            step += 1
-            continue
+        for sequence in sample_response.sequences:
+            all_input_data.append(sequence.new_input_feature)
+            all_old_logps.append(sequence.logprobs)
+            all_completion_lengths.append(len(sequence.tokens))

         # ========== 3. Compute rewards ==========
-        total_rewards, format_rewards, accuracy_rewards = compute_rewards(input_features)
+
+        total_rewards, format_rewards, accuracy_rewards = compute_rewards(
+            all_input_data
+        )
         metrics.accumulate(
-            None,
-            None,
-            completion_lengths=completion_lengths,
+            completion_lengths=all_completion_lengths,
             rewards={
                 'total': total_rewards,
                 'format': format_rewards,
                 'accuracy': accuracy_rewards,
-            })
+            },
+        )
+
         # ========== 4. Compute advantages ==========
         advantages = advantage_fn(
@@ -244,29 +200,27 @@ def train():
         # forward_backward with GRPO loss: passes advantages and old_logps
         # to the server-side GRPOLoss for proper policy optimization
         model.forward_backward(
-            inputs=input_features,
+            inputs=all_input_data,
             advantages=advantages,
-            old_logps=old_logps_list,
+            old_logps=all_old_logps,
         )

         # Gradient clipping and optimizer step
-        model.clip_grad_norm(1.0)
-        model.step()
-        model.zero_grad()
-        model.lr_step()
+        model.clip_grad_and_step()

         gc.collect()

         # ========== 6. Log ==========
         log_dict = metrics.calculate()
-        log_dict.update(model.calculate_metric())
+        log_dict.update(model.calculate_metric(is_training=True).result)
         log_dict['train/frac_reward_zero_std'] = frac_zero_std
         logger.info(f'Step {step}: {log_dict}')

         step += 1
+        metrics.reset()

     # Save final checkpoint
-    twinkle_path = model.save(name='grpo-countdown-final', save_optimizer=True)
-    logger.info(f'Saved final checkpoint: {twinkle_path}')
+    result = model.save(name='grpo-gsm8k-final', save_optimizer=True)
+    logger.info(f'Saved final checkpoint: {result}')


 if __name__ == '__main__':
diff --git a/cookbook/client/twinkle/sample.py b/cookbook/client/twinkle/self_host/sample.py
similarity index 89%
rename from cookbook/client/twinkle/sample.py
rename to cookbook/client/twinkle/self_host/sample.py
index 9437bb36..d800b635 100644
--- a/cookbook/client/twinkle/sample.py
+++ b/cookbook/client/twinkle/self_host/sample.py
@@ -29,14 +29,13 @@
 # or None to use the base model
 # ADAPTER_URI = None
 # Example:
-ADAPTER_URI = 'twinkle://20260208_224851-fa3cdd11-default/weights/twinkle-epoch-2'
-
+ADAPTER_URI = 'twinkle://20260301_142318-Qwen_Qwen3-4B-199d2cdb/weights/twinkle-lora-0'

 def sample():
     # Step 2: Initialize the Twinkle client to communicate with the remote server.
     client = init_twinkle_client(
         base_url='http://127.0.0.1:8000',
-        api_key=os.environ.get('MODELSCOPE_TOKEN'),
+        api_key='EMPTY_API_KEY',
     )

     # Step 3: Create the sampler client pointing to the model on the server
@@ -84,11 +83,11 @@ def sample():

     # Step 8: Decode and print the results
     tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
-    logger.info(f"Generated {len(response['sequences'])} sequences "
+    logger.info(f'Generated {len(response.sequences)} sequences '
                 f'({num_prompts} prompts x {num_samples} samples)')

-    for i, seq in enumerate(response['sequences']):
-        text = tokenizer.decode(seq['tokens'], skip_special_tokens=True)
+    for i, seq in enumerate(response.sequences):
+        text = tokenizer.decode(seq.tokens, skip_special_tokens=True)
         logger.info(f'Sequence {i}:\n {text}\n')

diff --git a/cookbook/client/twinkle/self_congnition.py b/cookbook/client/twinkle/self_host/self_congnition.py
similarity index 84%
rename from cookbook/client/twinkle/self_congnition.py
rename to cookbook/client/twinkle/self_host/self_congnition.py
index f9e56dd1..6bf6afce 100644
--- a/cookbook/client/twinkle/self_congnition.py
+++ b/cookbook/client/twinkle/self_host/self_congnition.py
@@ -87,7 +87,7 @@ def train():
     model.set_optimizer('Adam', lr=1e-4)

     # Use a linear learning rate scheduler (the LR scheduler is not supported when the server uses Megatron)
-    model.set_lr_scheduler('LinearLR')
+    # model.set_lr_scheduler('LinearLR')

     # Step 6: Optionally resume from a previous checkpoint
     if resume_path:
@@ -95,30 +95,31 @@ def train():
         model.load(resume_path, load_optimizer=True)

     # Step 7: Run the training loop
-    logger.info(model.get_train_configs())
+    logger.info(model.get_train_configs().model_dump())
     for epoch in range(3):
         logger.info(f'Starting epoch {epoch}')
         for step, batch in enumerate(dataloader):
             # Forward pass + backward pass (computes gradients)
-            output = model.forward_backward(inputs=batch)
-            loss=output.get('loss', 'N/A')
+            model.forward_backward(inputs=batch)
+
+            # Step
+            model.clip_grad_and_step()
+            # Equivalent to the following steps:
+            # # Clip gradients to prevent exploding gradients (max norm = 1.0)
+            # model.clip_grad_norm(1.0)
+            # # Perform one optimizer step (update model weights)
+            # model.step()
+            # # Reset gradients to zero for the next iteration
+            # model.zero_grad()
+            # # Advance the learning rate scheduler by one step
+            # model.lr_step()

             # Log metrics every 2 steps (aligned with gradient accumulation)
             if step % 2 == 0:
-                logger.info(f'Current is step {step // 2}, loss: {loss}')
-
-            # Clip gradients to prevent exploding gradients (max norm = 1.0)
-            model.clip_grad_norm(1.0)
-
-            # Perform one optimizer step (update model weights)
-            model.step()
-
-            # Reset gradients to zero for the next iteration
-            model.zero_grad()
-
-            # Advance the learning rate scheduler by one step
-            model.lr_step()
+                # Print metric
+                metric = model.calculate_metric(is_training=True)
+                logger.info(f'Step {step} of {len(dataloader)}, metric: {metric.result}')

         # Step 8: Save the trained checkpoint
         twinkle_path = model.save(name=f'twinkle-epoch-{epoch}', save_optimizer=True)
diff --git a/cookbook/client/twinkle/transformer/server.py b/cookbook/client/twinkle/transformer/server.py
deleted file mode 100644
index ba84e2dd..00000000
--- a/cookbook/client/twinkle/transformer/server.py
+++ /dev/null
@@ -1,20 +0,0 @@
-# Twinkle Server Launcher - Transformers Backend
-#
-# This script starts the Twinkle server using Ray Serve.
-# It reads the server_config.yaml in the same directory for all
-# configuration (model, processor, deployment settings, etc.).
-# Run this script BEFORE running the client training script (lora.py).
-
-import os
-
-# Enable Ray debug mode for verbose logging during development
-os.environ['RAY_DEBUG'] = '1'
-
-from twinkle.server import launch_server
-
-# Resolve the path to server_config.yaml relative to this script's location
-file_dir = os.path.abspath(os.path.dirname(__file__))
-config_path = os.path.join(file_dir, 'server_config.yaml')
-
-# Launch the Twinkle server — this call blocks until the server is shut down
-launch_server(config_path=config_path)
diff --git a/cookbook/client/twinkle/transformer/server_config.yaml b/cookbook/client/twinkle/transformer/server_config.yaml
deleted file mode 100644
index e16ced6a..00000000
--- a/cookbook/client/twinkle/transformer/server_config.yaml
+++ /dev/null
@@ -1,123 +0,0 @@
-# Twinkle Server Configuration - Transformers Backend
-
-# Server protocol type: "twinkle" for the native Twinkle client protocol
-server_type: twinkle
-
-# proxy_location: determines where the HTTP proxy runs.
-# "EveryNode" means each Ray node runs its own proxy (good for multi-node).
-proxy_location: EveryNode
-
-# HTTP listener settings
-http_options:
-  host: 0.0.0.0  # Listen on all network interfaces
-  port: 8000  # Port number for the server
-
-# Applications: each entry defines a service component deployed on the server
-applications:
-
-  # 1. TwinkleServer - The central management server
-  #    Handles client connections, training run tracking, checkpoint listing.
-  - name: server
-    route_prefix: /server  # API endpoint prefix
-    import_path: server  # Python module to import
-    args:
-      server_config:
-        per_token_model_limit: 3  # Maximum number of models (adapters) per token (server-globally enforced)
-    deployments:
-      - name: TwinkleServer
-        autoscaling_config:
-          min_replicas: 1  # Minimum number of replicas
-          max_replicas: 1  # Maximum number of replicas
-          target_ongoing_requests: 128  # Target concurrent requests per replica
-        ray_actor_options:
-          num_cpus: 0.1  # CPU resources allocated to this actor
-
-  # 2. Model Service - Hosts the base model for training
-  #    This is the actual model worker that performs forward/backward passes.
-  - name: models-Qwen3.5-4B
-    route_prefix: /models/Qwen/Qwen3.5-4B  # REST path for this model
-    import_path: model
-    args:
-      use_megatron: false  # Use HuggingFace Transformers (not Megatron)
-      model_id: "ms://Qwen/Qwen3.5-4B"  # ModelScope model identifier to load
-      adapter_config:
-        adapter_timeout: 1800  # Seconds before an idle adapter is unloaded
-      nproc_per_node: 2  # Number of GPU processes per node
-      device_group:  # Logical device group for this model
-        name: model
-        ranks: 2  # Number of GPUs to use
-        device_type: cuda
-      device_mesh:  # Distributed training mesh configuration
-        device_type: cuda
-        dp_size: 2  # Mesh dimension names: 'dp' = data parallel
-    deployments:
-      - name: ModelManagement
-        autoscaling_config:
-          min_replicas: 1
-          max_replicas: 1
-          target_ongoing_requests: 16
-        ray_actor_options:
-          num_cpus: 0.1
-        runtime_env:
-          env_vars:
-            TWINKLE_TRUST_REMOTE_CODE: "0"
-
-  # 3. Processor Service - Handles data preprocessing on CPU
-  #    Runs tokenization, template application, and other CPU-bound tasks.
-  - name: processor
-    route_prefix: /processors
-    import_path: processor
-    args:
-      nproc_per_node: 2  # Number of processor workers per node
-      ncpu_proc_per_node: 2  # Number of CPU processes per node
-      device_group:
-        name: model
-        ranks: 2  # Number of CPU workers to use
-        device_type: CPU
-      device_mesh:
-        device_type: CPU
-        dp_size: 2  # Data parallel size
-    deployments:
-      - name: ProcessorManagement
-        autoscaling_config:
-          min_replicas: 1
-          max_replicas: 1
-          target_ongoing_requests: 128
-        ray_actor_options:
-          num_cpus: 0.1
-        runtime_env:
-          env_vars:
-            TWINKLE_TRUST_REMOTE_CODE: "0"
-
-  # 4. Sampler Service - Handles text generation inference
-  #    Uses vLLM for efficient batched generation with optional LoRA adapters.
-  - name: sampler-Qwen3.5-4B
-    route_prefix: /samplers/Qwen/Qwen3.5-4B  # REST path for this sampler
-    import_path: sampler
-    args:
-      model_id: "ms://Qwen/Qwen3.5-4B"  # ModelScope model identifier to load
-      sampler_type: vllm  # Sampler backend (vllm or torch)
-      nproc_per_node: 2  # Number of GPU processes per node
-      engine_args:  # vLLM engine configuration
-        gpu_memory_utilization: 0.4
-        max_model_len: 1024
-      adapter_config:  # Adapter lifecycle management
-        adapter_timeout: 1800  # Seconds before idle adapter is unloaded
-      device_group:
-        name: sampler
-        ranks: 1  # Number of GPUs to use
-        device_type: cuda
-      device_mesh:
-        device_type: cuda
-        dp_size: 1
-    deployments:
-      - name: SamplerManagement
-        autoscaling_config:
-          min_replicas: 1
-          max_replicas: 1
-          target_ongoing_requests: 16
-        ray_actor_options:
-          num_cpus: 0.1
-        runtime_env:
-          env_vars:
-            TWINKLE_TRUST_REMOTE_CODE: "0"
diff --git a/cookbook/transformers/fsdp2.py b/cookbook/transformers/fsdp2.py
index ca37d724..a9c60c82 100644
--- a/cookbook/transformers/fsdp2.py
+++ b/cookbook/transformers/fsdp2.py
@@ -20,7 +20,7 @@ def eval(model):

     # 100 Samples
     dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(100)))
-    dataset.set_template('Template', model_id='ms://Qwen/Qwen3-4B')
+    dataset.set_template('Template', model_id='ms://Qwen/Qwen3.5-4B')
     dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区'))
     dataset.encode()
     dataloader = DataLoader(dataset=dataset, batch_size=8)
@@ -35,7 +35,7 @@ def train():
     # 1000 samples
     dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(1000)))
     # Set template to prepare encoding
-    dataset.set_template('Template', model_id='ms://Qwen/Qwen3-4B')
+    dataset.set_template('Template', model_id='ms://Qwen/Qwen3.5-4B')
     # Preprocess the dataset to standard format
     dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区'))
     # Encode dataset
@@ -43,7 +43,7 @@ def train():
     # Global batch size = 8, for 8 GPUs, so 1 sample per GPU
     dataloader = DataLoader(dataset=dataset, batch_size=8)
     # Use a TransformersModel
-    model = TransformersModel(model_id='ms://Qwen/Qwen3-4B')
+    model = TransformersModel(model_id='ms://Qwen/Qwen3.5-4B')

     lora_config = LoraConfig(r=8, lora_alpha=32, target_modules='all-linear')

diff --git a/cookbook/transformers/sp_fsdp_dense.py b/cookbook/transformers/sp_fsdp_dense.py
index da6e2d28..868b61c0 100644
--- a/cookbook/transformers/sp_fsdp_dense.py
+++ b/cookbook/transformers/sp_fsdp_dense.py
@@ -10,7 +10,7 @@ from twinkle.preprocessor import SelfCognitionProcessor

 logger = get_logger()

-MODEL_ID = 'ms://Qwen/Qwen3-4B'
+MODEL_ID = 'ms://Qwen/Qwen3.5-4B'
 DATASETS = 'ms://swift/self-cognition'

 device_group = [DeviceGroup(
diff --git a/docs/source_en/Usage Guide/Introduction-with-Qwen3.5.md b/docs/source_en/Usage Guide/Introduction-with-Qwen3.5.md
index 2f67e37b..3ef72e3e 100644
--- a/docs/source_en/Usage Guide/Introduction-with-Qwen3.5.md
+++ b/docs/source_en/Usage Guide/Introduction-with-Qwen3.5.md
@@ -458,7 +458,7 @@ from twinkle import init_tinker_client
 from twinkle.dataloader import DataLoader
 from twinkle.dataset import Dataset, DatasetMeta
 from twinkle.preprocessor import SelfCognitionProcessor
-from twinkle.server.tinker.common import input_feature_to_datum
+from twinkle.server.common import input_feature_to_datum

 # Initialize Tinker client (must be called before importing ServiceClient)
 init_tinker_client()
diff --git a/docs/source_en/Usage Guide/Quick-Start.md b/docs/source_en/Usage Guide/Quick-Start.md
index 24820fea..cde5bf19 100644
--- a/docs/source_en/Usage Guide/Quick-Start.md
+++ b/docs/source_en/Usage Guide/Quick-Start.md
@@ -679,7 +679,7 @@ from tinker import ServiceClient
 from twinkle.dataloader import DataLoader
 from twinkle.dataset import Dataset, DatasetMeta
 from twinkle.preprocessor import SelfCognitionProcessor
-from twinkle.server.tinker.common import input_feature_to_datum
+from twinkle.server.common import input_feature_to_datum

 # The base model to fine-tune / evaluate
 base_model = 'ms://Qwen/Qwen3.5-4B'
diff --git a/docs/source_en/Usage Guide/Server and Client/Tinker-Compatible-Client.md b/docs/source_en/Usage Guide/Server and Client/Tinker-Compatible-Client.md
index a01fd141..e44f3cea 100644
--- a/docs/source_en/Usage Guide/Server and Client/Tinker-Compatible-Client.md
+++ b/docs/source_en/Usage Guide/Server and Client/Tinker-Compatible-Client.md
@@ -143,7 +143,7 @@ from twinkle import init_tinker_client
 from twinkle.dataloader import DataLoader
 from twinkle.dataset import Dataset, DatasetMeta
 from twinkle.preprocessor import SelfCognitionProcessor
-from twinkle.server.tinker.common import input_feature_to_datum
+from twinkle.server.common import input_feature_to_datum

 # Initialize Tinker client before importing ServiceClient
 init_tinker_client()
diff --git a/docs/source_en/Usage Guide/Train-as-a-Service.md b/docs/source_en/Usage Guide/Train-as-a-Service.md
index fd6c30f3..29828091 100644
--- a/docs/source_en/Usage Guide/Train-as-a-Service.md
+++ b/docs/source_en/Usage Guide/Train-as-a-Service.md
@@ -28,7 +28,7 @@ from twinkle_client import init_tinker_client
 from twinkle.dataloader import DataLoader
 from twinkle.dataset import Dataset, DatasetMeta
 from twinkle.preprocessor import SelfCognitionProcessor
-from twinkle.server.tinker.common import input_feature_to_datum
+from twinkle.server.common import input_feature_to_datum

 base_model = 'ms://Qwen/Qwen3-30B-A3B-Instruct-2507'
 base_url='http://www.modelscope.cn/twinkle'
diff --git "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/Qwen3.5\346\234\200\344\275\263\345\256\236\350\267\265.md" "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/Qwen3.5\346\234\200\344\275\263\345\256\236\350\267\265.md"
index 8b86b9b0..cfb57655 100644
--- "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/Qwen3.5\346\234\200\344\275\263\345\256\236\350\267\265.md"
+++ "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/Qwen3.5\346\234\200\344\275\263\345\256\236\350\267\265.md"
@@ -458,7 +458,7 @@ from twinkle import init_tinker_client
 from twinkle.dataloader import DataLoader
 from twinkle.dataset import Dataset, DatasetMeta
 from twinkle.preprocessor import SelfCognitionProcessor
-from twinkle.server.tinker.common import input_feature_to_datum
+from twinkle.server.common import input_feature_to_datum

 # 初始化 Tinker 客户端(必须在导入 ServiceClient 之前)
 init_tinker_client()
diff --git "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\277\253\351\200\237\345\274\200\345\247\213.md" "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\277\253\351\200\237\345\274\200\345\247\213.md"
index 5e4cbf0d..b79126af 100644
--- "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\277\253\351\200\237\345\274\200\345\247\213.md"
+++ "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\277\253\351\200\237\345\274\200\345\247\213.md"
@@ -681,7 +681,7 @@ from tinker import ServiceClient
 from twinkle.dataloader import DataLoader
 from twinkle.dataset import Dataset, DatasetMeta
 from twinkle.preprocessor import SelfCognitionProcessor
-from twinkle.server.tinker.common import input_feature_to_datum
+from twinkle.server.common import input_feature_to_datum

 # The base model to fine-tune / evaluate
 base_model = 'Qwen/Qwen3.5-4B'
diff --git "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/Tinker\345\205\274\345\256\271\345\256\242\346\210\267\347\253\257.md" "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/Tinker\345\205\274\345\256\271\345\256\242\346\210\267\347\253\257.md"
index 11b51303..27db69b2 100644
--- "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/Tinker\345\205\274\345\256\271\345\256\242\346\210\267\347\253\257.md"
+++ "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/Tinker\345\205\274\345\256\271\345\256\242\346\210\267\347\253\257.md"
@@ -143,7 +143,7 @@ from twinkle import init_tinker_client
 from twinkle.dataloader import DataLoader
 from twinkle.dataset import Dataset, DatasetMeta
 from twinkle.preprocessor import SelfCognitionProcessor
-from twinkle.server.tinker.common import input_feature_to_datum
+from twinkle.server.common import input_feature_to_datum

 # 在导入 ServiceClient 之前,先初始化 Tinker 客户端
 init_tinker_client()
diff --git "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\350\256\255\347\273\203\346\234\215\345\212\241.md" "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\350\256\255\347\273\203\346\234\215\345\212\241.md"
"b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\350\256\255\347\273\203\346\234\215\345\212\241.md" index c0d5b68f..5d0272c3 100644 --- "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\350\256\255\347\273\203\346\234\215\345\212\241.md" +++ "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\350\256\255\347\273\203\346\234\215\345\212\241.md" @@ -31,7 +31,7 @@ from twinkle_client import init_tinker_client from twinkle.dataloader import DataLoader from twinkle.dataset import Dataset, DatasetMeta from twinkle.preprocessor import SelfCognitionProcessor -from twinkle.server.tinker.common import input_feature_to_datum +from twinkle.server.common import input_feature_to_datum base_model = 'ms://Qwen/Qwen3-30B-A3B-Instruct-2507' base_url='http://www.modelscope.cn/twinkle' diff --git a/setup.cfg b/setup.cfg index 3ca70ce3..811fd55c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -22,7 +22,7 @@ ignore-words-list = patten,nd,ty,mot,hist,formating,winn,gool,datas,wan,confids [flake8] max-line-length = 120 select = B,E,F,P,T4,W,B9 -ignore = F401,F403,F405,F821,W503,E251,W504,E126 +ignore = F401,F403,F405,F821,W503,E251,W504,E126,E125 exclude = docs/src,*.pyi,.git,peft.py [darglint] diff --git a/src/twinkle/hub/hub.py b/src/twinkle/hub/hub.py index 6e1653e1..916a42b2 100644 --- a/src/twinkle/hub/hub.py +++ b/src/twinkle/hub/hub.py @@ -374,7 +374,7 @@ def push_to_hub(cls, ignore_patterns = [] if revision is None or revision == 'main': revision = 'master' - return push_to_hub( + result = push_to_hub( repo_id, folder_path, token or cls.ms_token, @@ -383,6 +383,8 @@ def push_to_hub(cls, ignore_file_pattern=ignore_patterns, revision=revision, tag=path_in_repo) + if not result: + raise Exception('Failed to push to hub') @classmethod def load_dataset(cls, diff --git a/src/twinkle/infra/_ray/ray_helper.py b/src/twinkle/infra/_ray/ray_helper.py index f0a4011b..0d8908a3 100644 --- a/src/twinkle/infra/_ray/ray_helper.py +++ b/src/twinkle/infra/_ray/ray_helper.py @@ -157,7 +157,7 @@ def get_master_id_port(placement_group): def get_node_address(): return find_node_ip(), find_free_port() - ip, port = ray.get(get_node_address.options(placement_group=placement_group).remote()) + ip, port = ray.get(get_node_address.options(placement_group=placement_group, num_cpus=0.01).remote()) return ip, port @staticmethod diff --git a/src/twinkle/preprocessor/__init__.py b/src/twinkle/preprocessor/__init__.py index 1c19815e..13b52d99 100644 --- a/src/twinkle/preprocessor/__init__.py +++ b/src/twinkle/preprocessor/__init__.py @@ -1,4 +1,4 @@ # Copyright (c) ModelScope Contributors. All rights reserved. 
from .base import DataFilter, Preprocessor from .llm import (AlpacaProcessor, CompetitionMathGRPOProcessor, CompetitionMathProcessor, CountdownProcessor, - SelfCognitionProcessor) + GSM8KProcessor, SelfCognitionProcessor) diff --git a/src/twinkle/server/__main__.py b/src/twinkle/server/__main__.py index 17ea2e1f..e18283c3 100644 --- a/src/twinkle/server/__main__.py +++ b/src/twinkle/server/__main__.py @@ -5,12 +5,6 @@ Usage: # From config file python -m twinkle.server --config server_config.yaml - - # With server type override - python -m twinkle.server --config server_config.yaml --server-type tinker - - # Quick start with minimal args - python -m twinkle.server --server-type tinker --port 8000 --model-id "Qwen/Qwen3.5-4B" """ from __future__ import annotations @@ -27,15 +21,12 @@ def create_parser() -> argparse.ArgumentParser: """Create the argument parser.""" parser = argparse.ArgumentParser( prog='python -m twinkle.server', - description='Twinkle Server Launcher - Unified launcher for tinker and twinkle servers', + description='Twinkle Server Launcher - Unified launcher supporting both Tinker and Twinkle clients', formatter_class=argparse.RawDescriptionHelpFormatter, epilog=""" Examples: # Start server from YAML config file python -m twinkle.server --config server_config.yaml - - # Start tinker server with specific config - python -m twinkle.server -c config.yaml -t tinker """, ) @@ -49,23 +40,12 @@ def create_parser() -> argparse.ArgumentParser: help='Path to YAML configuration file (required)', ) - # Server type - parser.add_argument( - '-t', - '--server-type', - type=str, - default='twinkle', - choices=['tinker', 'twinkle'], - metavar='TYPE', - help="Server type: 'tinker' or 'twinkle' (default: twinkle)", - ) - # Ray options parser.add_argument( '--namespace', type=str, metavar='NS', - help="Ray namespace (default: 'twinkle_cluster' for tinker, None for twinkle)", + help="Ray namespace (default: 'twinkle_cluster')", ) # Runtime options @@ -97,7 +77,6 @@ def main(args: list[str] | None = None) -> int: try: from twinkle.server.launcher import launch_server - # Config file mode config_path = Path(parsed_args.config) if not config_path.exists(): logger.error(f'Config file not found: {config_path}') @@ -105,7 +84,6 @@ def main(args: list[str] | None = None) -> int: launch_server( config_path=config_path, - server_type=parsed_args.server_type, ray_namespace=parsed_args.namespace, ) diff --git a/src/twinkle/server/common/__init__.py b/src/twinkle/server/common/__init__.py new file mode 100644 index 00000000..bb00e2bd --- /dev/null +++ b/src/twinkle/server/common/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +from .checkpoint_factory import create_checkpoint_manager, create_training_run_manager +from .datum import datum_to_input_feature, extract_rl_feature, input_feature_to_datum +from .router import StickyLoraRequestRouter +from .serialize import deserialize_object, serialize_object + +__all__ = [ + 'datum_to_input_feature', + 'extract_rl_feature', + 'input_feature_to_datum', + 'create_checkpoint_manager', + 'create_training_run_manager', + 'StickyLoraRequestRouter', + 'deserialize_object', + 'serialize_object', +] diff --git a/src/twinkle/server/common/checkpoint_factory.py b/src/twinkle/server/common/checkpoint_factory.py new file mode 100644 index 00000000..cbb2f2c6 --- /dev/null +++ b/src/twinkle/server/common/checkpoint_factory.py @@ -0,0 +1,39 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. 
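With the new `twinkle.server.common` package introduced above, helpers that previously had to be imported from `twinkle.server.tinker.common` or `twinkle.server.twinkle.common` now share a single path. An illustrative snippet, using only the names from the package's `__all__`:

```python
# All shared server helpers now come from one package, regardless of whether
# the caller speaks the Tinker or the Twinkle protocol.
from twinkle.server.common import (
    StickyLoraRequestRouter,
    create_checkpoint_manager,
    create_training_run_manager,
    datum_to_input_feature,
    deserialize_object,
    extract_rl_feature,
    input_feature_to_datum,
    serialize_object,
)
```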
+""" +Factory functions for creating checkpoint and training-run manager instances. + +Use these functions as the entry point rather than instantiating managers directly: + + from twinkle.server.common.checkpoint_factory import ( + create_checkpoint_manager, + create_training_run_manager, + ) +""" +from twinkle.server.common.tinker_checkpoint import TinkerCheckpointManager, TinkerTrainingRunManager +from twinkle.server.common.twinkle_checkpoint import TwinkleCheckpointManager, TwinkleTrainingRunManager + + +def create_training_run_manager(token: str, client_type: str = 'twinkle'): + """Create a TrainingRunManager for the given token. + + Args: + token: User authentication token. + client_type: 'tinker' or 'twinkle' (default 'twinkle'). + """ + if client_type == 'tinker': + return TinkerTrainingRunManager(token) + return TwinkleTrainingRunManager(token) + + +def create_checkpoint_manager(token: str, client_type: str = 'twinkle'): + """Create a CheckpointManager for the given token. + + Args: + token: User authentication token. + client_type: 'tinker' or 'twinkle' (default 'twinkle'). + """ + if client_type == 'tinker': + run_mgr = TinkerTrainingRunManager(token) + return TinkerCheckpointManager(token, run_mgr) + run_mgr = TwinkleTrainingRunManager(token) + return TwinkleCheckpointManager(token, run_mgr) diff --git a/src/twinkle/server/tinker/common/datum.py b/src/twinkle/server/common/datum.py similarity index 98% rename from src/twinkle/server/tinker/common/datum.py rename to src/twinkle/server/common/datum.py index 0eb74f82..9091f388 100644 --- a/src/twinkle/server/tinker/common/datum.py +++ b/src/twinkle/server/common/datum.py @@ -1,3 +1,4 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. from __future__ import annotations import numpy as np diff --git a/src/twinkle/server/tinker/common/router.py b/src/twinkle/server/common/router.py similarity index 93% rename from src/twinkle/server/tinker/common/router.py rename to src/twinkle/server/common/router.py index 19ec8650..dee1bd36 100644 --- a/src/twinkle/server/tinker/common/router.py +++ b/src/twinkle/server/common/router.py @@ -1,3 +1,5 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +# Moved from tinker/common/router.py — logic unchanged. from ray.serve.request_router import (FIFOMixin, MultiplexMixin, PendingRequest, ReplicaID, ReplicaResult, RequestRouter, RunningReplica) from typing import Dict, List, Optional @@ -54,7 +56,7 @@ async def choose_replicas( # Filter out replicas that exceed max lora count (query from server state) candidate_ids = [r.replica_id.unique_id for r in top_ranked_replicas.values()] - available_ids = set(self.state.get_available_replica_ids(candidate_ids)) + available_ids = set(await self.state.get_available_replica_ids(candidate_ids)) if available_ids: top_ranked_replicas = { rid: r diff --git a/src/twinkle/server/twinkle/common/serialize.py b/src/twinkle/server/common/serialize.py similarity index 97% rename from src/twinkle/server/twinkle/common/serialize.py rename to src/twinkle/server/common/serialize.py index de3ca4bb..f1b3f6dd 100644 --- a/src/twinkle/server/twinkle/common/serialize.py +++ b/src/twinkle/server/common/serialize.py @@ -1,4 +1,5 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +# Moved from twinkle/common/serialize.py — logic unchanged. 
import json from numbers import Number from peft import LoraConfig diff --git a/src/twinkle/server/common/tinker_checkpoint.py b/src/twinkle/server/common/tinker_checkpoint.py new file mode 100644 index 00000000..fa7e5a11 --- /dev/null +++ b/src/twinkle/server/common/tinker_checkpoint.py @@ -0,0 +1,134 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +""" +Tinker-specific checkpoint and training-run managers. + +Uses ``tinker.types`` models for all serialization and response construction. +""" +from datetime import datetime +from tinker import types as tinker_types +from typing import Any, Dict, List, Optional + +from twinkle.server.utils.checkpoint_base import TRAIN_RUN_INFO_FILENAME, BaseCheckpointManager, BaseTrainingRunManager + + +class TinkerTrainingRunManager(BaseTrainingRunManager): + """Tinker-specific training run manager using tinker.types models.""" + + @property + def train_run_info_filename(self) -> str: + return TRAIN_RUN_INFO_FILENAME + + def _create_training_run(self, model_id: str, run_config: tinker_types.CreateModelRequest) -> Dict[str, Any]: + lora_config = run_config.lora_config + train_run_data = tinker_types.TrainingRun( + training_run_id=model_id, + base_model=run_config.base_model, + model_owner=self.token, + is_lora=True if lora_config else False, + corrupted=False, + lora_rank=lora_config.rank if lora_config else None, + last_request_time=datetime.now(), + last_checkpoint=None, + last_sampler_checkpoint=None, + user_metadata=run_config.user_metadata) + + new_data = train_run_data.model_dump(mode='json') + if lora_config: + new_data['train_unembed'] = lora_config.train_unembed + new_data['train_mlp'] = lora_config.train_mlp + new_data['train_attn'] = lora_config.train_attn + return new_data + + def _parse_training_run(self, data: Dict[str, Any]) -> tinker_types.TrainingRun: + data = self._transform_checkpoint_fields(data) + return tinker_types.TrainingRun(**data) + + def _transform_checkpoint_fields(self, data: Dict[str, Any]) -> Dict[str, Any]: + data = data.copy() + for field in ['last_checkpoint', 'last_sampler_checkpoint']: + if field in data and data[field] is not None: + ckpt = data[field].copy() + if 'twinkle_path' in ckpt and 'tinker_path' not in ckpt: + ckpt['tinker_path'] = ckpt.pop('twinkle_path') + elif 'tinker_path' not in ckpt: + path = ckpt.get('path') or ckpt.get('twinkle_path') + if path: + ckpt['tinker_path'] = path + elif 'checkpoint_id' in ckpt and 'training_run_id' in data: + ckpt['tinker_path'] = f"twinkle://{data['training_run_id']}/{ckpt['checkpoint_id']}" + data[field] = ckpt + return data + + def _create_training_runs_response(self, runs: List[tinker_types.TrainingRun], limit: int, offset: int, + total: int) -> tinker_types.TrainingRunsResponse: + return tinker_types.TrainingRunsResponse( + training_runs=runs, cursor=tinker_types.Cursor(limit=limit, offset=offset, total_count=total)) + + +class TinkerCheckpointManager(BaseCheckpointManager): + """Tinker-specific checkpoint manager using tinker.types models.""" + + @property + def path_prefix(self) -> str: + return 'twinkle://' + + @property + def path_field_name(self) -> str: + return 'tinker_path' + + def _create_checkpoint(self, + checkpoint_id, + checkpoint_type, + path, + size_bytes, + public, + base_model=None, + is_lora=False, + lora_rank=None, + train_unembed=None, + train_mlp=None, + train_attn=None, + user_metadata=None) -> Dict[str, Any]: + checkpoint = tinker_types.Checkpoint( + checkpoint_id=checkpoint_id, + checkpoint_type=checkpoint_type, + 
time=datetime.now(), + tinker_path=path, + size_bytes=size_bytes, + public=public) + result = checkpoint.model_dump(mode='json') + result['base_model'] = base_model + result['is_lora'] = is_lora + result['lora_rank'] = lora_rank + result['train_unembed'] = train_unembed + result['train_mlp'] = train_mlp + result['train_attn'] = train_attn + result['user_metadata'] = user_metadata + return result + + def _parse_checkpoint(self, data: Dict[str, Any]) -> tinker_types.Checkpoint: + data = data.copy() + if 'twinkle_path' in data and 'tinker_path' not in data: + data['tinker_path'] = data.pop('twinkle_path') + elif 'tinker_path' not in data and 'path' in data: + data['tinker_path'] = data.pop('path') + return tinker_types.Checkpoint(**data) + + def _create_checkpoints_response( + self, checkpoints: List[tinker_types.Checkpoint]) -> tinker_types.CheckpointsListResponse: + return tinker_types.CheckpointsListResponse(checkpoints=checkpoints, cursor=None) + + def _create_parsed_path(self, path, training_run_id, checkpoint_type, + checkpoint_id) -> tinker_types.ParsedCheckpointTinkerPath: + return tinker_types.ParsedCheckpointTinkerPath( + tinker_path=path, + training_run_id=training_run_id, + checkpoint_type=checkpoint_type, + checkpoint_id=checkpoint_id, + ) + + def _create_weights_info(self, run_info: Dict[str, Any]) -> tinker_types.WeightsInfoResponse: + return tinker_types.WeightsInfoResponse(**run_info) + + def parse_tinker_path(self, tinker_path: str) -> Optional[tinker_types.ParsedCheckpointTinkerPath]: + return self.parse_path(tinker_path) diff --git a/src/twinkle/server/twinkle/common/io_utils.py b/src/twinkle/server/common/twinkle_checkpoint.py similarity index 51% rename from src/twinkle/server/twinkle/common/io_utils.py rename to src/twinkle/server/common/twinkle_checkpoint.py index 4693c381..4b77d581 100644 --- a/src/twinkle/server/twinkle/common/io_utils.py +++ b/src/twinkle/server/common/twinkle_checkpoint.py @@ -1,66 +1,20 @@ # Copyright (c) ModelScope Contributors. All rights reserved. """ -Twinkle-specific IO utilities for managing training runs and checkpoints. +Twinkle-specific checkpoint and training-run managers. -This module extends the base IO utilities with Twinkle-specific implementations. +Uses ``twinkle_client.types.training`` models for all serialization and response construction. 
""" from datetime import datetime -from pydantic import BaseModel from typing import Any, Dict, List, Optional -from twinkle.server.utils.io_utils import (CHECKPOINT_INFO_FILENAME, TRAIN_RUN_INFO_FILENAME, TWINKLE_DEFAULT_SAVE_DIR, - BaseCheckpoint, BaseCheckpointManager, BaseCreateModelRequest, - BaseLoraConfig, BaseParsedCheckpointPath, BaseTrainingRun, - BaseTrainingRunManager, BaseWeightsInfoResponse, Cursor, ResolvedLoadPath, - validate_ownership, validate_user_path) +from twinkle.server.utils.checkpoint_base import (TRAIN_RUN_INFO_FILENAME, BaseCheckpointManager, + BaseTrainingRunManager, validate_ownership) +from twinkle_client.types.training import (Checkpoint, CheckpointsListResponse, CreateModelRequest, Cursor, + ParsedCheckpointTwinklePath, TrainingRun, TrainingRunsResponse, + WeightsInfoResponse) -# ----- Twinkle-specific Pydantic Models ----- - -class Checkpoint(BaseCheckpoint): - """Twinkle checkpoint model.""" - twinkle_path: str - - -class TrainingRun(BaseTrainingRun): - """Twinkle training run model.""" - pass - - -class TrainingRunsResponse(BaseModel): - training_runs: List[TrainingRun] - cursor: Cursor - - -class CheckpointsListResponse(BaseModel): - checkpoints: List[Checkpoint] - cursor: Optional[Cursor] = None - - -class ParsedCheckpointTwinklePath(BaseParsedCheckpointPath): - """Twinkle-specific parsed path model.""" - twinkle_path: str - - -class WeightsInfoResponse(BaseWeightsInfoResponse): - """Twinkle weights info response.""" - pass - - -class LoraConfig(BaseLoraConfig): - """Twinkle LoRA configuration.""" - pass - - -class CreateModelRequest(BaseCreateModelRequest): - """Twinkle create model request.""" - lora_config: Optional[LoraConfig] = None - - -# ----- Twinkle Training Run Manager ----- - - -class TrainingRunManager(BaseTrainingRunManager): +class TwinkleTrainingRunManager(BaseTrainingRunManager): """Twinkle-specific training run manager.""" @property @@ -68,7 +22,6 @@ def train_run_info_filename(self) -> str: return TRAIN_RUN_INFO_FILENAME def _create_training_run(self, model_id: str, run_config: CreateModelRequest) -> Dict[str, Any]: - """Create training run data from model_id and run_config.""" lora_config = run_config.lora_config train_run_data = TrainingRun( training_run_id=model_id, @@ -83,43 +36,27 @@ def _create_training_run(self, model_id: str, run_config: CreateModelRequest) -> user_metadata=run_config.user_metadata) new_data = train_run_data.model_dump(mode='json') - # Store lora config details separately if needed if lora_config: new_data['train_unembed'] = lora_config.train_unembed new_data['train_mlp'] = lora_config.train_mlp new_data['train_attn'] = lora_config.train_attn - return new_data def _parse_training_run(self, data: Dict[str, Any]) -> TrainingRun: - """Parse training run data into TrainingRun model.""" return TrainingRun(**data) def _create_training_runs_response(self, runs: List[TrainingRun], limit: int, offset: int, total: int) -> TrainingRunsResponse: - """Create a training runs response.""" return TrainingRunsResponse(training_runs=runs, cursor=Cursor(limit=limit, offset=offset, total_count=total)) def get_with_permission(self, model_id: str) -> Optional[TrainingRun]: - """ - Get training run with ownership validation. 
- - Args: - model_id: The model identifier - - Returns: - TrainingRun if found and owned by user, None otherwise - """ run = self.get(model_id) if run and validate_ownership(self.token, run.model_owner): return run return None -# ----- Twinkle Checkpoint Manager ----- - - -class CheckpointManager(BaseCheckpointManager): +class TwinkleCheckpointManager(BaseCheckpointManager): """Twinkle-specific checkpoint manager.""" @property @@ -131,19 +68,18 @@ def path_field_name(self) -> str: return 'twinkle_path' def _create_checkpoint(self, - checkpoint_id: str, - checkpoint_type: str, - path: str, - size_bytes: int, - public: bool, - base_model: Optional[str] = None, - is_lora: bool = False, - lora_rank: Optional[int] = None, - train_unembed: Optional[bool] = None, - train_mlp: Optional[bool] = None, - train_attn: Optional[bool] = None, - user_metadata: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: - """Create checkpoint data.""" + checkpoint_id, + checkpoint_type, + path, + size_bytes, + public, + base_model=None, + is_lora=False, + lora_rank=None, + train_unembed=None, + train_mlp=None, + train_attn=None, + user_metadata=None) -> Dict[str, Any]: checkpoint = Checkpoint( checkpoint_id=checkpoint_id, checkpoint_type=checkpoint_type, @@ -161,9 +97,7 @@ def _create_checkpoint(self, return checkpoint.model_dump(mode='json') def _parse_checkpoint(self, data: Dict[str, Any]) -> Checkpoint: - """Parse checkpoint data into Checkpoint model.""" data = data.copy() - # Transform tinker_path to twinkle_path if needed if 'tinker_path' in data and 'twinkle_path' not in data: data['twinkle_path'] = data.pop('tinker_path') elif 'twinkle_path' not in data and 'path' in data: @@ -171,20 +105,9 @@ def _parse_checkpoint(self, data: Dict[str, Any]) -> Checkpoint: return Checkpoint(**data) def get(self, model_id: str, checkpoint_id: str) -> Optional[Checkpoint]: - """ - Get checkpoint metadata with backwards compatibility. 
- - Args: - model_id: The model identifier - checkpoint_id: The checkpoint identifier - - Returns: - Checkpoint object or None if not found - """ data = self._read_ckpt_info(model_id, checkpoint_id) if not data: return None - # Handle backwards compatibility: construct twinkle_path if missing if 'twinkle_path' not in data and 'tinker_path' not in data and 'path' not in data: if 'checkpoint_id' in data: data = data.copy() @@ -192,12 +115,9 @@ def get(self, model_id: str, checkpoint_id: str) -> Optional[Checkpoint]: return self._parse_checkpoint(data) def _create_checkpoints_response(self, checkpoints: List[Checkpoint]) -> CheckpointsListResponse: - """Create a checkpoints list response.""" return CheckpointsListResponse(checkpoints=checkpoints, cursor=None) - def _create_parsed_path(self, path: str, training_run_id: str, checkpoint_type: str, - checkpoint_id: str) -> ParsedCheckpointTwinklePath: - """Create a parsed path model.""" + def _create_parsed_path(self, path, training_run_id, checkpoint_type, checkpoint_id) -> ParsedCheckpointTwinklePath: return ParsedCheckpointTwinklePath( path=path, twinkle_path=path, @@ -207,7 +127,6 @@ def _create_parsed_path(self, path: str, training_run_id: str, checkpoint_type: ) def _create_weights_info(self, run_info: Dict[str, Any]) -> WeightsInfoResponse: - """Create weights info from run info.""" return WeightsInfoResponse( training_run_id=run_info.get('training_run_id', ''), base_model=run_info.get('base_model', ''), @@ -217,19 +136,4 @@ def _create_weights_info(self, run_info: Dict[str, Any]) -> WeightsInfoResponse: ) def parse_twinkle_path(self, twinkle_path: str) -> Optional[ParsedCheckpointTwinklePath]: - """Parse a twinkle:// path into its components (alias for parse_path).""" return self.parse_path(twinkle_path) - - -# ----- Factory Functions ----- - - -def create_training_run_manager(token: str) -> TrainingRunManager: - """Create a TrainingRunManager for the given token.""" - return TrainingRunManager(token) - - -def create_checkpoint_manager(token: str) -> CheckpointManager: - """Create a CheckpointManager for the given token.""" - training_run_manager = TrainingRunManager(token) - return CheckpointManager(token, training_run_manager) diff --git a/src/twinkle/server/gateway/__init__.py b/src/twinkle/server/gateway/__init__.py new file mode 100644 index 00000000..1e6c2cbd --- /dev/null +++ b/src/twinkle/server/gateway/__init__.py @@ -0,0 +1,3 @@ +from .server import build_server_app + +__all__ = ['build_server_app'] diff --git a/src/twinkle/server/tinker/proxy.py b/src/twinkle/server/gateway/proxy.py similarity index 62% rename from src/twinkle/server/tinker/proxy.py rename to src/twinkle/server/gateway/proxy.py index bc429199..e8346d6c 100644 --- a/src/twinkle/server/tinker/proxy.py +++ b/src/twinkle/server/gateway/proxy.py @@ -1,15 +1,14 @@ # Copyright (c) ModelScope Contributors. All rights reserved. """ -Proxy utilities for forwarding requests to internal services. +Proxy utilities for forwarding requests to internal model/sampler services. -This module provides HTTP proxy functionality to route requests from the Tinker server -to appropriate model or sampler services based on base_model routing. +Moved from tinker/proxy.py. Updated proxy_to_model and proxy_to_sampler +to prepend the 'tinker/' prefix to endpoints so they route to /tinker/* paths +on the unified model/sampler deployments. 
""" - from __future__ import annotations import httpx -import os from fastapi import Request, Response from typing import Any @@ -21,11 +20,13 @@ class ServiceProxy: """HTTP proxy for routing requests to internal model and sampler services. - This proxy handles: + Handles: 1. URL construction using localhost to avoid external routing loops 2. Header forwarding with appropriate cleanup 3. Debug logging for troubleshooting 4. Error handling and response forwarding + + Tinker endpoints are routed to /tinker/ on the unified deployments. """ def __init__( @@ -33,28 +34,18 @@ def __init__( http_options: dict[str, Any] | None = None, route_prefix: str = '/api/v1', ): - """Initialize the service proxy. - - Args: - http_options: HTTP server options (host, port) for internal routing - route_prefix: URL prefix for routing (default: '/api/v1') - """ self.http_options = http_options or {} self.route_prefix = route_prefix - # Disable proxy for internal requests to avoid routing through external proxies + # Disable proxy env vars to avoid external routing self.client = httpx.AsyncClient(timeout=None, trust_env=False) def _build_target_url(self, service_type: str, base_model: str, endpoint: str) -> str: """Build the target URL for internal service routing. - Constructs URLs using localhost to avoid extra external hops. - When requests come from www.modelscope.com/twinkle, we proxy to - localhost:port directly instead of back to modelscope.com. - Args: service_type: Either 'model' or 'sampler' base_model: The base model name for routing - endpoint: The target endpoint name + endpoint: The target endpoint name (already includes tinker/ or twinkle/ prefix) Returns: Complete target URL for the internal service @@ -63,7 +54,6 @@ def _build_target_url(self, service_type: str, base_model: str, endpoint: str) - host = self.http_options.get('host', 'localhost') port = self.http_options.get('port', 8000) - # Use localhost for internal routing if host == '0.0.0.0': host = 'localhost' @@ -71,24 +61,13 @@ def _build_target_url(self, service_type: str, base_model: str, endpoint: str) - return f'{base_url}{prefix}/{service_type}/{base_model}/{endpoint}' def _prepare_headers(self, request_headers) -> dict[str, str]: - """Prepare headers for proxying by removing problematic headers. - - Args: - request_headers: Original request headers (case-insensitive from FastAPI) - - Returns: - Cleaned headers safe for proxying - """ + """Prepare headers for proxying by removing problematic headers.""" logger.debug('prepare_headers request_headers=%s', request_headers) - # Convert to dict while preserving case-insensitive lookups for special headers headers = dict(request_headers) - # Remove headers that should not be forwarded headers.pop('host', None) headers.pop('content-length', None) - # Add serve_multiplexed_model_id for sticky sessions if present - # Use case-insensitive lookup from original request_headers request_id = request_headers.get('X-Ray-Serve-Request-Id') - if request_id is not None: + if request_id is not None and not request_headers.get('serve_multiplexed_model_id'): headers['serve_multiplexed_model_id'] = request_id return headers @@ -101,24 +80,20 @@ async def proxy_request( ) -> Response: """Generic proxy method to forward requests to model or sampler services. - This method consolidates the common proxy logic for both model and sampler endpoints. 
- Args: request: The incoming FastAPI request - endpoint: The target endpoint name (e.g., 'create_model', 'asample') + endpoint: The target endpoint path (e.g., 'tinker/create_model') base_model: The base model name for routing - service_type: Either 'model' or 'sampler' to determine the target service + service_type: Either 'model' or 'sampler' Returns: Proxied response from the target service """ body_bytes = await request.body() target_url = self._build_target_url(service_type, base_model, endpoint) - # Pass original request.headers (case-insensitive) instead of dict conversion headers = self._prepare_headers(request.headers) try: - # Debug logging for troubleshooting proxy issues logger.debug( 'proxy_request service=%s endpoint=%s target_url=%s request_id=%s', service_type, @@ -127,7 +102,6 @@ async def proxy_request( headers.get('serve_multiplexed_model_id'), ) - # Forward the request to the target service response = await self.client.request( method=request.method, url=target_url, @@ -136,7 +110,6 @@ async def proxy_request( params=request.query_params, ) - # Debug logging for response logger.debug( 'proxy_response status=%s body_preview=%s', response.status_code, @@ -154,31 +127,21 @@ async def proxy_request( return Response(content=f'Proxy Error: {str(e)}', status_code=502) async def proxy_to_model(self, request: Request, endpoint: str, base_model: str) -> Response: - """Proxy request to model endpoint. - - Routes the request to the appropriate model deployment based on base_model. + """Proxy request to model's tinker endpoint (/tinker/<endpoint>). Args: request: The incoming FastAPI request - endpoint: The target endpoint name (e.g., 'create_model', 'forward') + endpoint: The tinker endpoint name (e.g., 'create_model', 'forward') base_model: The base model name for routing - - Returns: - Proxied response from the model service """ - return await self.proxy_request(request, endpoint, base_model, 'model') + return await self.proxy_request(request, f'tinker/{endpoint}', base_model, 'model') async def proxy_to_sampler(self, request: Request, endpoint: str, base_model: str) -> Response: - """Proxy request to sampler endpoint. - - Routes the request to the appropriate sampler deployment based on base_model. + """Proxy request to sampler's tinker endpoint (/tinker/<endpoint>). Args: request: The incoming FastAPI request - endpoint: The target endpoint name (e.g., 'asample') + endpoint: The tinker endpoint name (e.g., 'asample') base_model: The base model name for routing - - Returns: - Proxied response from the sampler service """ - return await self.proxy_request(request, endpoint, base_model, 'sampler') + return await self.proxy_request(request, f'tinker/{endpoint}', base_model, 'sampler') diff --git a/src/twinkle/server/gateway/server.py b/src/twinkle/server/gateway/server.py new file mode 100644 index 00000000..dd591ccf --- /dev/null +++ b/src/twinkle/server/gateway/server.py @@ -0,0 +1,105 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +""" +Unified Gateway Server. + +A single Ray Serve deployment that serves both Tinker (/tinker/*) and +Twinkle (/twinkle/*) management and proxy endpoints.
+""" +from __future__ import annotations + +import asyncio +from fastapi import FastAPI, HTTPException, Request +from ray import serve +from tinker import types as tinker_types +from typing import Any + +from twinkle.server.utils.state import get_server_state +from twinkle.server.utils.validation import verify_request_token +from twinkle.utils.logger import get_logger +from .proxy import ServiceProxy +from .tinker_gateway_handlers import _register_tinker_routes +from .twinkle_gateway_handlers import _register_twinkle_routes + +logger = get_logger() + + +class GatewayServer: + """Unified gateway server handling both Tinker and Twinkle API clients.""" + + def __init__(self, + supported_models: list | None = None, + server_config: dict[str, Any] = {}, + http_options: dict[str, Any] | None = None, + **kwargs) -> None: + self.state = get_server_state(**server_config) + self.route_prefix = kwargs.get('route_prefix', '/api/v1') + self.http_options = http_options or {} + self.proxy = ServiceProxy(http_options=http_options, route_prefix=self.route_prefix) + self.supported_models = self._normalize_models(supported_models) or [ + tinker_types.SupportedModel(model_name='Qwen/Qwen3-30B-A3B-Instruct-2507'), + ] + self._modelscope_config_lock = asyncio.Lock() + + def _normalize_models(self, supported_models): + if not supported_models: + return [] + normalized = [] + for item in supported_models: + if isinstance(item, tinker_types.SupportedModel): + normalized.append(item) + elif isinstance(item, dict): + normalized.append(tinker_types.SupportedModel(**item)) + elif isinstance(item, str): + normalized.append(tinker_types.SupportedModel(model_name=item)) + return normalized + + def _validate_base_model(self, base_model: str) -> None: + supported_model_names = [m.model_name for m in self.supported_models] + if base_model not in supported_model_names: + raise HTTPException( + status_code=400, + detail=f"Base model '{base_model}' is not supported. " + f"Supported models: {', '.join(supported_model_names)}") + + def _get_base_model(self, model_id: str) -> str: + metadata = self.state.get_model_metadata(model_id) + if metadata and metadata.get('base_model'): + return metadata['base_model'] + raise HTTPException(status_code=404, detail=f'Model {model_id} not found') + + +def build_server_app(deploy_options: dict[str, Any], + supported_models: list | None = None, + server_config: dict[str, Any] = {}, + http_options: dict[str, Any] | None = None, + **kwargs): + """Build and configure the unified gateway server application. + + Serves Tinker endpoints at /* and Twinkle endpoints at /twinkle/*. + + Args: + deploy_options: Ray Serve deployment configuration + supported_models: List of supported base models for tinker validation + server_config: Server configuration options + http_options: HTTP server options (host, port) for internal proxy routing + **kwargs: Additional keyword arguments (route_prefix, etc.) 
+ + Returns: + Configured Ray Serve deployment bound with options + """ + app = FastAPI() + + @app.middleware('http') + async def verify_token(request: Request, call_next): + return await verify_request_token(request=request, call_next=call_next) + + def get_self() -> GatewayServer: + return serve.get_replica_context().servable_object + + _register_tinker_routes(app, get_self) + _register_twinkle_routes(app, get_self) + + GatewayServerWithIngress = serve.ingress(app)(GatewayServer) + DeploymentClass = serve.deployment(name='GatewayServer')(GatewayServerWithIngress) + return DeploymentClass.options(**deploy_options).bind( + supported_models=supported_models, server_config=server_config, http_options=http_options, **kwargs) diff --git a/src/twinkle/server/gateway/tinker_gateway_handlers.py b/src/twinkle/server/gateway/tinker_gateway_handlers.py new file mode 100644 index 00000000..71c0654f --- /dev/null +++ b/src/twinkle/server/gateway/tinker_gateway_handlers.py @@ -0,0 +1,272 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +""" +Tinker-compatible gateway handlers. + +All endpoints are prefixed /* and registered via _register_tinker_routes(app, self_fn). +self_fn is injected via FastAPI Depends to obtain the GatewayServer instance at request time. +""" +from __future__ import annotations + +import asyncio +import os +from fastapi import Depends, FastAPI, HTTPException, Request, Response +from tinker import types +from typing import TYPE_CHECKING, Any, Callable + +if TYPE_CHECKING: + from .server import GatewayServer + +from twinkle.hub import HubOperation +from twinkle.server.common.checkpoint_factory import create_checkpoint_manager, create_training_run_manager +from twinkle.server.utils.task_queue import QueueState +from twinkle.server.utils.validation import get_token_from_request +from twinkle.utils.logger import get_logger + +logger = get_logger() + + +def _register_tinker_routes(app: FastAPI, self_fn: Callable[[], GatewayServer]) -> None: + """Register all /* Tinker routes on the given FastAPI app. + + self_fn is a zero-argument callable that returns the current GatewayServer + replica instance (e.g. ``lambda: serve.get_replica_context().servable_object``). + It is wired in via ``Depends`` so it is resolved lazily at request time. 
+ """ + + @app.get('/healthz') + async def healthz(request: Request) -> types.HealthResponse: + return types.HealthResponse(status='ok') + + @app.get('/get_server_capabilities') + async def get_server_capabilities( + request: Request, + self: GatewayServer = Depends(self_fn), + ) -> types.GetServerCapabilitiesResponse: + return types.GetServerCapabilitiesResponse(supported_models=self.supported_models) + + @app.post('/telemetry') + async def telemetry(request: Request, body: types.TelemetrySendRequest) -> types.TelemetryResponse: + return types.TelemetryResponse(status='accepted') + + @app.post('/create_session') + async def create_session( + request: Request, + body: types.CreateSessionRequest, + self: GatewayServer = Depends(self_fn), + ) -> types.CreateSessionResponse: + session_id = self.state.create_session(body.model_dump()) + return types.CreateSessionResponse(session_id=session_id) + + @app.post('/session_heartbeat') + async def session_heartbeat( + request: Request, body: types.SessionHeartbeatRequest, self: GatewayServer = Depends(self_fn) + ) -> types.SessionHeartbeatResponse: # noqa: E125 + alive = await self.state.touch_session(body.session_id) + if not alive: + raise HTTPException(status_code=404, detail='Unknown session') + return types.SessionHeartbeatResponse() + + @app.post('/create_sampling_session') + async def create_sampling_session( + request: Request, body: types.CreateSamplingSessionRequest, self: GatewayServer = Depends(self_fn) + ) -> types.CreateSamplingSessionResponse: # noqa: E125 + sampling_session_id = self.state.create_sampling_session(body.model_dump()) + return types.CreateSamplingSessionResponse(sampling_session_id=sampling_session_id) + + @app.post('/retrieve_future') + async def retrieve_future(request: Request, + body: types.FutureRetrieveRequest, + self: GatewayServer = Depends(self_fn)) -> Any: + """Retrieve the result of an async task with long polling.""" + request_id = body.request_id + max_wait = float(os.environ.get('TWINKLE_LONG_POLL_TIMEOUT', '30')) + poll_interval = float(os.environ.get('TWINKLE_POLL_INTERVAL', '0.5')) + start = asyncio.get_event_loop().time() + + while True: + record = await self.state.get_future(request_id) + + if record is None: + return {'type': 'try_again'} + + status = record.get('status') + if status not in ('pending', 'queued', 'running', 'rate_limited'): + break + + if asyncio.get_event_loop().time() - start >= max_wait: + response_data = {'type': 'try_again'} + if queue_state := record.get('queue_state'): + response_data['queue_state'] = queue_state + if queue_state_reason := record.get('queue_state_reason'): + response_data['queue_state_reason'] = queue_state_reason + return response_data + + await asyncio.sleep(poll_interval) + + record = await self.state.get_future(request_id) + if not record: + return {'type': 'try_again'} + + status = record.get('status') + + if status == 'rate_limited': + return { + 'type': 'try_again', + 'queue_state': QueueState.PAUSED_RATE_LIMIT.value, + 'queue_state_reason': record.get('reason', 'Rate limit exceeded') + } + + if status == 'failed': + result = record.get('result', {}) + return {'error': result.get('error', 'Unknown error'), 'category': result.get('category', 'Server')} + + result = record.get('result') + if result is None: + raise HTTPException(status_code=500, detail='Task completed but no result found') + + if hasattr(result, 'model_dump'): + return result.model_dump() + return result + + # --- Training Runs Endpoints --- + + @app.get('/training_runs') + async def 
get_training_runs(request: Request, limit: int = 20, offset: int = 0) -> types.TrainingRunsResponse: + token = get_token_from_request(request) + training_run_manager = create_training_run_manager(token, client_type='tinker') + return training_run_manager.list_runs(limit=limit, offset=offset) + + @app.get('/training_runs/{run_id}') + async def get_training_run(request: Request, run_id: str) -> types.TrainingRun: + token = get_token_from_request(request) + training_run_manager = create_training_run_manager(token, client_type='tinker') + run = training_run_manager.get(run_id) + if not run: + raise HTTPException(status_code=404, detail=f'Training run {run_id} not found') + return run + + @app.get('/training_runs/{run_id}/checkpoints') + async def get_run_checkpoints(request: Request, run_id: str) -> types.CheckpointsListResponse: + token = get_token_from_request(request) + checkpoint_manager = create_checkpoint_manager(token, client_type='tinker') + response = checkpoint_manager.list_checkpoints(run_id) + if not response: + raise HTTPException(status_code=404, detail=f'Training run {run_id} not found') + return response + + @app.delete('/training_runs/{run_id}/checkpoints/{checkpoint_id:path}') + async def delete_run_checkpoint(request: Request, run_id: str, checkpoint_id: str) -> Any: + token = get_token_from_request(request) + checkpoint_manager = create_checkpoint_manager(token, client_type='tinker') + success = checkpoint_manager.delete(run_id, checkpoint_id) + if not success: + raise HTTPException(status_code=404, detail=f'Checkpoint {checkpoint_id} not found for run {run_id}') + return None + + @app.post('/weights_info') + async def weights_info(request: Request, body: dict[str, Any]) -> types.WeightsInfoResponse: + token = get_token_from_request(request) + checkpoint_manager = create_checkpoint_manager(token, client_type='tinker') + tinker_path = body.get('tinker_path') + response = checkpoint_manager.get_weights_info(tinker_path) + if not response: + raise HTTPException(status_code=404, detail=f'Weights at {tinker_path} not found') + return response + + @app.post('/training_runs/{run_id}/checkpoints/{checkpoint_id:path}/publish') + async def publish_checkpoint(request: Request, + run_id: str, + checkpoint_id: str, + self: GatewayServer = Depends(self_fn)) -> Response: + token = get_token_from_request(request) + + training_run_manager = create_training_run_manager(token, client_type='tinker') + checkpoint_manager = create_checkpoint_manager(token, client_type='tinker') + + run = training_run_manager.get(run_id) + if not run: + raise HTTPException(status_code=404, detail=f'Training run {run_id} not found or access denied') + + checkpoint = checkpoint_manager.get(run_id, checkpoint_id) + if not checkpoint: + raise HTTPException(status_code=404, detail=f'Checkpoint {checkpoint_id} not found') + + checkpoint_dir = str(checkpoint_manager.get_ckpt_dir(run_id, checkpoint_id)) + + async with self._modelscope_config_lock: + try: + from modelscope.hub.api import HubApi, ModelScopeConfig + hub_api = HubApi(token=token) + hub_api.login() + username = ModelScopeConfig.get_user_info()[0] + except Exception as e: + logger.error(f'Failed to get username from ModelScope: {e}') + raise HTTPException( + status_code=401, + detail='Failed to get username from ModelScope. 
Please ensure your token is valid.') + + checkpoint_name = checkpoint_id.split('/')[-1] + hub_model_id = f'{username}/{run_id}_{checkpoint_name}' + HubOperation.push_to_hub(repo_id=hub_model_id, folder_path=checkpoint_dir, token=token, private=True) + + return Response(status_code=204) + + # --- Model Proxy Endpoints --- + + @app.post('/create_model') + async def create_model(request: Request, body: types.CreateModelRequest, + self: GatewayServer = Depends(self_fn)) -> Any: + self._validate_base_model(body.base_model) + return await self.proxy.proxy_to_model(request, 'create_model', body.base_model) + + @app.post('/get_info') + async def get_info(request: Request, body: types.GetInfoRequest, self: GatewayServer = Depends(self_fn)) -> Any: + return await self.proxy.proxy_to_model(request, 'get_info', self._get_base_model(body.model_id)) + + @app.post('/unload_model') + async def unload_model(request: Request, body: types.UnloadModelRequest, + self: GatewayServer = Depends(self_fn)) -> Any: + return await self.proxy.proxy_to_model(request, 'unload_model', self._get_base_model(body.model_id)) + + @app.post('/forward') + async def forward(request: Request, body: types.ForwardRequest, self: GatewayServer = Depends(self_fn)) -> Any: + return await self.proxy.proxy_to_model(request, 'forward', self._get_base_model(body.model_id)) + + @app.post('/forward_backward') + async def forward_backward(request: Request, + body: types.ForwardBackwardRequest, + self: GatewayServer = Depends(self_fn)) -> Any: + return await self.proxy.proxy_to_model(request, 'forward_backward', self._get_base_model(body.model_id)) + + @app.post('/optim_step') + async def optim_step(request: Request, body: types.OptimStepRequest, self: GatewayServer = Depends(self_fn)) -> Any: + return await self.proxy.proxy_to_model(request, 'optim_step', self._get_base_model(body.model_id)) + + @app.post('/save_weights') + async def save_weights(request: Request, body: types.SaveWeightsRequest, + self: GatewayServer = Depends(self_fn)) -> Any: + return await self.proxy.proxy_to_model(request, 'save_weights', self._get_base_model(body.model_id)) + + @app.post('/load_weights') + async def load_weights(request: Request, body: types.LoadWeightsRequest, + self: GatewayServer = Depends(self_fn)) -> Any: + return await self.proxy.proxy_to_model(request, 'load_weights', self._get_base_model(body.model_id)) + + # --- Sampler Proxy Endpoints --- + + @app.post('/asample') + async def asample(request: Request, body: types.SampleRequest, self: GatewayServer = Depends(self_fn)) -> Any: + base_model = body.base_model + if not base_model and body.sampling_session_id: + session = self.state.get_sampling_session(body.sampling_session_id) + if session: + base_model = session.get('base_model') + return await self.proxy.proxy_to_sampler(request, 'asample', base_model) + + @app.post('/save_weights_for_sampler') + async def save_weights_for_sampler( + request: Request, + body: types.SaveWeightsForSamplerRequest, + self: GatewayServer = Depends(self_fn), + ) -> Any: + return await self.proxy.proxy_to_model(request, 'save_weights_for_sampler', self._get_base_model(body.model_id)) diff --git a/src/twinkle/server/gateway/twinkle_gateway_handlers.py b/src/twinkle/server/gateway/twinkle_gateway_handlers.py new file mode 100644 index 00000000..9c0a3ba7 --- /dev/null +++ b/src/twinkle/server/gateway/twinkle_gateway_handlers.py @@ -0,0 +1,120 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +""" +Twinkle-native gateway handlers. 
+ +All endpoints are prefixed /twinkle/* and registered via _register_twinkle_routes(app, self_fn). +""" +from __future__ import annotations + +from fastapi import Depends, FastAPI, HTTPException, Request +from typing import TYPE_CHECKING, Callable + +if TYPE_CHECKING: + from .server import GatewayServer + +import twinkle_client.types as types +from twinkle.server.common.checkpoint_factory import create_checkpoint_manager, create_training_run_manager +from twinkle.server.utils.checkpoint_base import validate_user_path +from twinkle.server.utils.validation import get_token_from_request +from twinkle.utils.logger import get_logger + +logger = get_logger() + + +def _register_twinkle_routes(app: FastAPI, self_fn: Callable[[], GatewayServer]) -> None: + """Register all /twinkle/* routes on the given FastAPI app.""" + + @app.get('/twinkle/healthz', response_model=types.HealthResponse) + async def healthz(request: Request) -> types.HealthResponse: + return types.HealthResponse(status='ok') + + @app.post('/twinkle/create_session', response_model=types.CreateSessionResponse) + async def create_session( + request: Request, + body: types.CreateSessionRequest, + self: GatewayServer = Depends(self_fn), + ) -> types.CreateSessionResponse: + session_id = self.state.create_session(body.model_dump()) + return types.CreateSessionResponse(session_id=session_id) + + @app.post('/twinkle/session_heartbeat', response_model=types.SessionHeartbeatResponse) + async def session_heartbeat( + request: Request, + body: types.SessionHeartbeatRequest, + self: GatewayServer = Depends(self_fn), + ) -> types.SessionHeartbeatResponse: + alive = await self.state.touch_session(body.session_id) + if not alive: + raise HTTPException(status_code=404, detail='Unknown session') + return types.SessionHeartbeatResponse() + + @app.get('/twinkle/training_runs', response_model=types.TrainingRunsResponse) + async def get_training_runs(request: Request, limit: int = 20, offset: int = 0) -> types.TrainingRunsResponse: + token = get_token_from_request(request) + training_run_manager = create_training_run_manager(token, client_type='twinkle') + return training_run_manager.list_runs(limit=limit, offset=offset) + + @app.get('/twinkle/training_runs/{run_id}', response_model=types.TrainingRun) + async def get_training_run(request: Request, run_id: str) -> types.TrainingRun: + token = get_token_from_request(request) + training_run_manager = create_training_run_manager(token, client_type='twinkle') + run = training_run_manager.get_with_permission(run_id) + if not run: + raise HTTPException(status_code=404, detail=f'Training run {run_id} not found or access denied') + return run + + @app.get('/twinkle/training_runs/{run_id}/checkpoints', response_model=types.CheckpointsListResponse) + async def get_run_checkpoints(request: Request, run_id: str) -> types.CheckpointsListResponse: + token = get_token_from_request(request) + checkpoint_manager = create_checkpoint_manager(token, client_type='twinkle') + response = checkpoint_manager.list_checkpoints(run_id) + if response is None: + raise HTTPException(status_code=404, detail=f'Training run {run_id} not found or access denied') + return response + + @app.delete( + '/twinkle/training_runs/{run_id}/checkpoints/{checkpoint_id:path}', + response_model=types.DeleteCheckpointResponse) + async def delete_run_checkpoint(request: Request, run_id: str, + checkpoint_id: str) -> types.DeleteCheckpointResponse: + token = get_token_from_request(request) + + if not validate_user_path(token, checkpoint_id): + raise 
HTTPException(status_code=400, detail='Invalid checkpoint path: path traversal not allowed') + + checkpoint_manager = create_checkpoint_manager(token, client_type='twinkle') + success = checkpoint_manager.delete(run_id, checkpoint_id) + if not success: + raise HTTPException(status_code=404, detail=f'Checkpoint {checkpoint_id} not found or access denied') + + return types.DeleteCheckpointResponse(success=True, message=f'Checkpoint {checkpoint_id} deleted successfully') + + @app.post('/twinkle/weights_info', response_model=types.WeightsInfoResponse) + async def weights_info(request: Request, body: types.WeightsInfoRequest) -> types.WeightsInfoResponse: + token = get_token_from_request(request) + checkpoint_manager = create_checkpoint_manager(token, client_type='twinkle') + response = checkpoint_manager.get_weights_info(body.twinkle_path) + if response is None: + raise HTTPException(status_code=404, detail=f'Weights at {body.twinkle_path} not found or access denied') + return response + + @app.get('/twinkle/checkpoint_path/{run_id}/{checkpoint_id:path}', response_model=types.CheckpointPathResponse) + async def get_checkpoint_path(request: Request, run_id: str, checkpoint_id: str) -> types.CheckpointPathResponse: + token = get_token_from_request(request) + + if not validate_user_path(token, checkpoint_id): + raise HTTPException(status_code=400, detail='Invalid checkpoint path: path traversal not allowed') + + training_run_manager = create_training_run_manager(token, client_type='twinkle') + checkpoint_manager = create_checkpoint_manager(token, client_type='twinkle') + + run = training_run_manager.get(run_id) + if not run: + raise HTTPException(status_code=404, detail=f'Training run {run_id} not found or access denied') + + checkpoint = checkpoint_manager.get(run_id, checkpoint_id) + if not checkpoint: + raise HTTPException(status_code=404, detail=f'Checkpoint {checkpoint_id} not found') + + ckpt_dir = checkpoint_manager.get_ckpt_dir(run_id, checkpoint_id) + return types.CheckpointPathResponse(path=str(ckpt_dir), twinkle_path=checkpoint.twinkle_path) diff --git a/src/twinkle/server/launcher.py b/src/twinkle/server/launcher.py index 843418c2..53b88350 100644 --- a/src/twinkle/server/launcher.py +++ b/src/twinkle/server/launcher.py @@ -2,8 +2,8 @@ """ Unified Server Launcher for Twinkle. -This module provides a unified way to launch both tinker and twinkle servers -with support for YAML config files, Python dict config, and CLI. +This module provides a unified way to launch the server with support for +YAML config files, Python dict config, and CLI. Usage: # From YAML config @@ -12,7 +12,6 @@ # From Python dict launch_server(config={ - "server_type": "tinker", "http_options": {"host": "0.0.0.0", "port": 8000}, "applications": [...] }) @@ -33,26 +32,17 @@ class ServerLauncher: """ - Unified server launcher for tinker and twinkle servers. + Unified server launcher. - This class handles Ray/Serve initialization and application deployment - for both tinker and twinkle server types. + This class handles Ray/Serve initialization and application deployment. 
Attributes: - server_type: The type of server ('tinker' or 'twinkle') config: The server configuration dictionary ray_namespace: The Ray namespace for the cluster """ - # Mapping of simplified import_path names to actual builder functions - # These will be populated lazily to avoid circular imports - _TINKER_BUILDERS: dict[str, str] = { - 'server': 'build_server_app', - 'model': 'build_model_app', - 'sampler': 'build_sampler_app', - } - - _TWINKLE_BUILDERS: dict[str, str] = { + # Mapping of simplified import_path names to builder function names + _BUILDERS: dict[str, str] = { 'server': 'build_server_app', 'model': 'build_model_app', 'sampler': 'build_sampler_app', @@ -61,7 +51,6 @@ class ServerLauncher: def __init__( self, - server_type: str = 'twinkle', config: dict[str, Any] | None = None, ray_namespace: str | None = None, ): @@ -69,14 +58,9 @@ def __init__( Initialize the server launcher. Args: - server_type: Server type ('tinker' or 'twinkle') config: Configuration dictionary - ray_namespace: Ray namespace (default: 'twinkle_cluster' for tinker, None for twinkle) + ray_namespace: Ray namespace (default: 'twinkle_cluster') """ - if server_type not in ('tinker', 'twinkle'): - raise ValueError(f"server_type must be 'tinker' or 'twinkle', got '{server_type}'") - - self.server_type = server_type self.config = config or {} self.ray_namespace = ray_namespace self._builders: dict[str, Callable] = {} @@ -84,30 +68,21 @@ def __init__( self._serve_started = False def _get_builders(self) -> dict[str, Callable]: - """ - Get the appropriate builder functions for the server type. - - Returns: - Dictionary mapping import_path names to builder functions - """ + """Get the builder functions for all app types.""" if self._builders: return self._builders - if self.server_type == 'tinker': - from twinkle.server.tinker import build_model_app, build_sampler_app, build_server_app - self._builders = { - 'build_server_app': build_server_app, - 'build_model_app': build_model_app, - 'build_sampler_app': build_sampler_app, - } - else: # twinkle - from twinkle.server.twinkle import build_model_app, build_processor_app, build_sampler_app, build_server_app - self._builders = { - 'build_server_app': build_server_app, - 'build_model_app': build_model_app, - 'build_sampler_app': build_sampler_app, - 'build_processor_app': build_processor_app, - } + from twinkle.server.gateway import build_server_app + from twinkle.server.model import build_model_app + from twinkle.server.processor import build_processor_app + from twinkle.server.sampler import build_sampler_app + + self._builders = { + 'build_server_app': build_server_app, + 'build_model_app': build_model_app, + 'build_sampler_app': build_sampler_app, + 'build_processor_app': build_processor_app, + } return self._builders @@ -116,7 +91,7 @@ def _resolve_builder(self, import_path: str) -> Callable: Resolve an import_path to a builder function. 
Args: - import_path: The import path from config (e.g., 'server', 'main:build_server_app') + import_path: The import path from config (e.g., 'server', 'model') Returns: The builder function @@ -125,11 +100,10 @@ def _resolve_builder(self, import_path: str) -> Callable: ValueError: If the import_path cannot be resolved """ builders = self._get_builders() - builder_map = self._TINKER_BUILDERS if self.server_type == 'tinker' else self._TWINKLE_BUILDERS - # Try to resolve through the mapping - if import_path in builder_map: - builder_name = builder_map[import_path] + # Try to resolve through the simplified name mapping + if import_path in self._BUILDERS: + builder_name = self._BUILDERS[import_path] if builder_name in builders: return builders[builder_name] @@ -137,8 +111,8 @@ def _resolve_builder(self, import_path: str) -> Callable: if import_path in builders: return builders[import_path] - raise ValueError(f"Unknown import_path '{import_path}' for server_type '{self.server_type}'. " - f'Available: {list(builder_map.keys())}') + raise ValueError(f"Unknown import_path '{import_path}'. " + f'Available: {list(self._BUILDERS.keys())}') def _init_ray(self) -> None: """Initialize Ray if not already initialized.""" @@ -147,14 +121,10 @@ def _init_ray(self) -> None: import ray - # Determine namespace namespace = self.ray_namespace or self.config.get('ray_namespace') or 'twinkle_cluster' - init_kwargs = {} - init_kwargs['namespace'] = namespace - if not ray.is_initialized(): - ray.init(**init_kwargs) + ray.init(namespace=namespace) logger.info(f'Ray initialized with namespace={namespace}') self._ray_initialized = True @@ -166,19 +136,16 @@ def _start_serve(self) -> None: from ray import serve - # Shutdown any existing serve instance try: serve.shutdown() - time.sleep(2) # Wait for cleanup + time.sleep(2) except Exception: pass - # Get http_options from config http_options = self.config.get('http_options', {}) if isinstance(http_options, dict): http_options = dict(http_options) else: - # Handle OmegaConf or other config objects http_options = dict(http_options) if http_options else {} serve.start(http_options=http_options) @@ -187,8 +154,7 @@ def _start_serve(self) -> None: self._serve_started = True def _deploy_application(self, app_config: dict[str, Any]) -> None: - """ - Deploy a single application. + """Deploy a single application. Args: app_config: Application configuration dictionary @@ -203,15 +169,12 @@ def _deploy_application(self, app_config: dict[str, Any]) -> None: logger.info(f'Starting {name} at {route_prefix}...') - # Resolve builder function builder = self._resolve_builder(import_path) - # Build deploy_options from deployments config deploy_options = {} if deployments: deploy_config = deployments[0] if isinstance(deploy_config, dict): - # Copy all deployment options from the config, except 'name'. deploy_options = {k: v for k, v in deploy_config.items() if k != 'name'} # Pass http_options to server apps for internal proxy routing @@ -219,16 +182,13 @@ def _deploy_application(self, app_config: dict[str, Any]) -> None: if import_path == 'server' and http_options: args['http_options'] = http_options - # Build and deploy the application app = builder(deploy_options=deploy_options, **{k: v for k, v in args.items()}) serve.run(app, name=name, route_prefix=route_prefix) logger.info(f'Deployed {name} at {route_prefix}') def launch(self) -> None: - """ - Launch the server with all configured applications. 
- """ + """Launch the server with all configured applications.""" self._init_ray() self._start_serve() @@ -237,15 +197,12 @@ def launch(self) -> None: logger.warning('No applications configured') return - # Deploy each application for app_config in applications: if isinstance(app_config, dict): self._deploy_application(app_config) else: - # Handle OmegaConf or other config objects self._deploy_application(dict(app_config)) - # Print endpoints http_options = self.config.get('http_options', {}) host = http_options.get('host', 'localhost') port = http_options.get('port', 8000) @@ -264,7 +221,6 @@ def launch(self) -> None: def from_yaml( cls, config_path: str | Path, - server_type: str = 'twinkle', ray_namespace: str | None = None, ) -> ServerLauncher: """ @@ -272,7 +228,6 @@ def from_yaml( Args: config_path: Path to the YAML config file - server_type: Server type ('tinker' or 'twinkle'), default is 'twinkle' ray_namespace: Override Ray namespace from config Returns: @@ -287,12 +242,7 @@ def from_yaml( config = OmegaConf.load(config_path) config_dict = OmegaConf.to_container(config, resolve=True) - # Override server_type from config if specified - if 'server_type' in config_dict: - server_type = config_dict['server_type'] - return cls( - server_type=server_type, config=config_dict, ray_namespace=ray_namespace or config_dict.get('ray_namespace'), ) @@ -301,7 +251,6 @@ def from_yaml( def launch_server( config: dict[str, Any] | None = None, config_path: str | Path | None = None, - server_type: str = 'twinkle', ray_namespace: str | None = None, ) -> ServerLauncher: """ @@ -312,7 +261,6 @@ def launch_server( Args: config: Configuration dictionary (takes precedence over config_path) config_path: Path to YAML config file - server_type: Server type ('tinker' or 'twinkle'), default is 'twinkle' ray_namespace: Ray namespace Returns: @@ -322,15 +270,11 @@ def launch_server( ValueError: If neither config nor config_path is provided Examples: - # From YAML config (twinkle mode) + # From YAML config launch_server(config_path="server_config.yaml") - # From YAML config (tinker mode) - launch_server(config_path="server_config.yaml", server_type="tinker") - # From Python dict launch_server(config={ - "server_type": "tinker", "http_options": {"host": "0.0.0.0", "port": 8000}, "applications": [...] }) @@ -338,21 +282,14 @@ def launch_server( if config is None and config_path is None: raise ValueError("Either 'config' or 'config_path' must be provided") - launcher: ServerLauncher - if config is not None: - # From Python dict config - override with config's server_type if specified - final_server_type = config.get('server_type', server_type) launcher = ServerLauncher( - server_type=final_server_type, config=config, ray_namespace=ray_namespace or config.get('ray_namespace'), ) else: - # From YAML config file launcher = ServerLauncher.from_yaml( config_path=config_path, - server_type=server_type, ray_namespace=ray_namespace, ) diff --git a/src/twinkle/server/model/__init__.py b/src/twinkle/server/model/__init__.py new file mode 100644 index 00000000..1a203083 --- /dev/null +++ b/src/twinkle/server/model/__init__.py @@ -0,0 +1,3 @@ +from .app import build_model_app + +__all__ = ['build_model_app'] diff --git a/src/twinkle/server/model/app.py b/src/twinkle/server/model/app.py new file mode 100644 index 00000000..49692e7f --- /dev/null +++ b/src/twinkle/server/model/app.py @@ -0,0 +1,167 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +""" +Unified model management application. 
+ +Builds a single Ray Serve deployment (ModelManagement) that simultaneously handles +both Tinker (/tinker/*) and Twinkle (/twinkle/*) model endpoints. +""" +from __future__ import annotations + +from fastapi import FastAPI, Request +from ray import serve +from ray.serve.config import RequestRouterConfig +from typing import Any, Dict, Optional + +import twinkle +from twinkle import DeviceGroup, DeviceMesh +from twinkle.server.utils.adapter_manager import AdapterManagerMixin +from twinkle.server.utils.state import ServerStateProxy, get_server_state +from twinkle.server.utils.task_queue import TaskQueueConfig, TaskQueueMixin +from twinkle.server.utils.validation import get_token_from_request, verify_request_token +from twinkle.utils.logger import get_logger +from ..common.router import StickyLoraRequestRouter +from ..utils import wrap_builder_with_device_group_env +from .tinker_handlers import _register_tinker_routes +from .twinkle_handlers import _register_twinkle_routes + +logger = get_logger() + + +class ModelManagement(TaskQueueMixin, AdapterManagerMixin): + """Unified model management service. + + Handles: + - Base model and multiple LoRA adapters (multi-user) + - Tinker training operations via /tinker/* endpoints (async/polling) + - Twinkle training operations via /twinkle/* endpoints (synchronous) + - Adapter lifecycle via AdapterManagerMixin + - Per-user rate limiting via TaskQueueMixin + """ + + def __init__(self, + model_id: str, + nproc_per_node: int, + device_group: dict[str, Any], + device_mesh: dict[str, Any], + use_megatron: bool = False, + adapter_config: dict[str, Any] = {}, + queue_config: dict[str, Any] | None = None, + **kwargs): + self.device_group = DeviceGroup(**device_group) + twinkle.initialize(mode='ray', nproc_per_node=nproc_per_node, groups=[self.device_group], lazy_collect=False) + if 'mesh_dim_names' in device_mesh: + self.device_mesh = DeviceMesh(**device_mesh) + else: + self.device_mesh = DeviceMesh.from_sizes(**device_mesh) + self.use_megatron = use_megatron + self.replica_id = serve.get_replica_context().replica_id.unique_id + self.max_loras = kwargs.get('max_loras', 5) + self.base_model = model_id + + # Choose model backend + if use_megatron: + from ..model.backends.megatron_model import TwinkleCompatMegatronModel + + self.model = TwinkleCompatMegatronModel( + model_id=model_id, + device_mesh=self.device_mesh, + remote_group=self.device_group.name, + instance_id=self.replica_id, + **kwargs) + else: + from ..model.backends.transformers_model import TwinkleCompatTransformersModel + self.model = TwinkleCompatTransformersModel( + model_id=model_id, + device_mesh=self.device_mesh, + remote_group=self.device_group.name, + instance_id=self.replica_id, + **kwargs) + + self.state: ServerStateProxy = get_server_state() + self.state.register_replica(self.replica_id, self.max_loras) + + # Initialize mixins + self._init_task_queue(TaskQueueConfig.from_dict(queue_config)) + self._init_adapter_manager(**adapter_config) + self.start_adapter_countdown() + + @serve.multiplexed(max_num_models_per_replica=5) + async def _sticky_entry(self, sticky_key: str): + return sticky_key + + async def _ensure_sticky(self): + sticky_key = serve.get_multiplexed_model_id() + await self._sticky_entry(sticky_key) + + async def _on_request_start(self, request: Request) -> str: + await self._ensure_sticky() + token = get_token_from_request(request) + return token + + def __del__(self): + self.state.unregister_replica(self.replica_id) + + def _cleanup_adapter(self, adapter_name: str) -> None: + 
if self.get_adapter_info(adapter_name): + self.clear_adapter_state(adapter_name) + self.model.remove_adapter(adapter_name) + self.unregister_adapter(adapter_name) + self.state.unload_model(adapter_name) + + def _on_adapter_expired(self, adapter_name: str) -> None: + self.fail_pending_tasks_for_model(adapter_name, reason='Adapter expired') + self._cleanup_adapter(adapter_name) + + +def build_model_app(model_id: str, + nproc_per_node: int, + device_group: dict[str, Any], + device_mesh: dict[str, Any], + deploy_options: dict[str, Any], + use_megatron: bool = False, + adapter_config: dict[str, Any] = {}, + queue_config: dict[str, Any] | None = None, + **kwargs): + """Build a unified model management application for distributed training. + + Supports both Tinker (polling-style) and Twinkle (synchronous) clients. + + Args: + model_id: Base model identifier (e.g., "Qwen/Qwen2.5-0.5B-Instruct") + nproc_per_node: Number of processes per node for distributed training + device_group: Device group configuration dict + device_mesh: Device mesh configuration dict for tensor parallelism + deploy_options: Ray Serve deployment options + use_megatron: Whether to use Megatron backend (vs Transformers) + adapter_config: Adapter lifecycle config (timeout, per-token limits) + queue_config: Task queue configuration (rate limiting, etc.) + **kwargs: Additional model initialization arguments + + Returns: + Configured Ray Serve deployment bound with parameters + """ + # Build the FastAPI app and register all routes BEFORE serve.ingress so that + # the frozen app contains the complete route table (visible to ProxyActor). + app = FastAPI() + + @app.middleware('http') + async def verify_token(request: Request, call_next): + return await verify_request_token(request=request, call_next=call_next) + + def get_self() -> ModelManagement: + return serve.get_replica_context().servable_object + + _register_tinker_routes(app, get_self) + _register_twinkle_routes(app, get_self) + + ModelManagementWithIngress = serve.ingress(app)(ModelManagement) + DeploymentClass = serve.deployment( + name='ModelManagement', + request_router_config=RequestRouterConfig(request_router_class=StickyLoraRequestRouter), + )( + ModelManagementWithIngress) + return DeploymentClass.options(**deploy_options).bind(model_id, nproc_per_node, device_group, device_mesh, + use_megatron, adapter_config, queue_config, **kwargs) + + +build_model_app = wrap_builder_with_device_group_env(build_model_app) diff --git a/src/twinkle/server/twinkle/common/__init__.py b/src/twinkle/server/model/backends/__init__.py similarity index 100% rename from src/twinkle/server/twinkle/common/__init__.py rename to src/twinkle/server/model/backends/__init__.py diff --git a/src/twinkle/server/tinker/common/compat_base.py b/src/twinkle/server/model/backends/common.py similarity index 71% rename from src/twinkle/server/tinker/common/compat_base.py rename to src/twinkle/server/model/backends/common.py index 62e22ff6..e1f62e23 100644 --- a/src/twinkle/server/tinker/common/compat_base.py +++ b/src/twinkle/server/model/backends/common.py @@ -1,5 +1,11 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +""" +Shared helpers and base classes for backend model implementations. +""" import numpy as np +import re import torch +from numbers import Number from tinker import types from typing import List @@ -8,42 +14,29 @@ def collect_forward_backward_results(results, device_mesh: DeviceMesh): - """Custom collect function for forward_backward that handles list [outputs, loss]. 
- - Args: - results: List of lists from each worker, where each list is [outputs_list, loss_float] - - Returns: - List of [flattened_outputs, averaged_loss] - """ + """Custom collect function for forward_backward that handles list [outputs, loss].""" if not results: return results - # Filter for last pipeline stage if PP is enabled pp_last_ranks = None if device_mesh.pp_world_size > 1: pp_last_ranks = set(device_mesh.get_pp_last_ranks()) - # Filter for last tp rank if TP is enabled tp_last_ranks = None if device_mesh.tp_world_size > 1: tp_last_ranks = set(device_mesh.get_tp_last_ranks()) mesh_flat = device_mesh.mesh.flatten() - # results is a list of lists: [[outputs1, loss1], [outputs2, loss2], ...] - # Flatten outputs (first element of each list) all_outputs = [] all_losses = [] for i, result in enumerate(results): rank = mesh_flat[i] if i < len(mesh_flat) else -1 - # Only collect from the last PP rank to avoid duplicates if pp_last_ranks is not None: if rank not in pp_last_ranks: continue - # Only collect from the last TP rank to avoid duplicates if tp_last_ranks is not None: if rank not in tp_last_ranks: continue @@ -57,7 +50,6 @@ def collect_forward_backward_results(results, device_mesh: DeviceMesh): all_outputs.extend(outputs) all_losses.append(loss) - # Average the losses if all_losses: avg_loss = float(np.mean(all_losses)) else: @@ -67,17 +59,13 @@ def collect_forward_backward_results(results, device_mesh: DeviceMesh): def clean_metrics(metrics: dict) -> dict: - import re - from numbers import Number def _to_float(v): - # python numeric / numpy scalar if isinstance(v, (float, int, Number, np.generic, str)): try: return float(v) except Exception: return None - # 0-d torch tensor if isinstance(v, torch.Tensor) and v.numel() == 1: try: return float(v.item()) @@ -92,12 +80,11 @@ def _to_float(v): cleaned[key] = fv continue - # handle common metric strings: "123 seconds", "1.23 iters/s" if isinstance(value, str): s = value.strip() if s: try: - head, unit = s.split() # ignore unit/tail + head, unit = s.split(maxsplit=1) cleaned[f'{key}/{unit}'] = float(head) except Exception: m = re.match(r'^([+-]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][+-]?\d+)?)', s) @@ -117,31 +104,28 @@ def get_template(self, adapter_name: str) -> Template: def _get_forward_output(inputs: List[types.Datum], logits: torch.Tensor, logps: torch.Tensor) -> List[dict]: """Convert raw logits to the expected output format with logprobs and elementwise_loss.""" from twinkle.utils.torch_utils import selective_log_softmax - device = logits.device if logits is not None else logps.device + if logps is not None: + device = logps.device + elif logits is not None: + device = logits.device + else: + raise ValueError('At least one of logits or logps must be provided.') results = [] if logits is None: logits = [None] * len(inputs) for idx, (feature, logit) in enumerate(zip(inputs, logits)): - # Ensure 1D shape and correct device to avoid dimension mismatch and device errors - labels = feature.loss_fn_inputs['target_tokens'].to_torch().long().view(-1).to(device) # shape (seq_len,) - weights = feature.loss_fn_inputs['weights'].to_torch().view(-1).to(device) # shape (seq_len,) + labels = feature.loss_fn_inputs['target_tokens'].to_torch().long().view(-1).to(device) + weights = feature.loss_fn_inputs['weights'].to_torch().view(-1).to(device) - # Slice logits to match the sequence length of labels - # Labels are assumed to be already shifted/aligned with logits seq_len = labels.numel() if logps is None: - assert logits is not None - # Check if index 
is within logits bounds
-            # Right padding
+            assert logit is not None, 'logit must not be None when logps is None'
             feature_logits = logit[:seq_len, :]
-
-            # Calculate log probs for all labels
             token_log_probs = selective_log_softmax(feature_logits, labels)
         else:
             token_log_probs = logps[idx, :seq_len]
 
-        # elementwise_loss: positive NLL loss (0.0 where masked)
         elementwise_loss = -token_log_probs * weights
 
         results.append({
diff --git a/src/twinkle/server/model/backends/megatron_model.py b/src/twinkle/server/model/backends/megatron_model.py
new file mode 100644
index 00000000..b471cdb1
--- /dev/null
+++ b/src/twinkle/server/model/backends/megatron_model.py
@@ -0,0 +1,115 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+"""
+Megatron backend model for the unified model deployment.
+"""
+import torch
+from tinker import types
+from typing import List
+
+from twinkle import remote_class, remote_function
+from twinkle.model.megatron import MultiLoraMegatronModel
+from twinkle.server.common.datum import datum_to_input_feature, extract_rl_feature
+from twinkle.server.model.backends.common import TwinkleCompatModelBase, clean_metrics, collect_forward_backward_results
+
+
+@remote_class(execute='all')
+class TwinkleCompatMegatronModel(MultiLoraMegatronModel, TwinkleCompatModelBase):
+    """Compatibility wrapper around MultiLoraMegatronModel for Twinkle/Tinker.
+
+    Moved from tinker/common/megatron_model.py.
+    """
+
+    @remote_function(dispatch='slice_dp', collect=collect_forward_backward_results, sync=True)
+    def tinker_forward_backward(self, *, inputs: List[types.Datum], adapter_name: str, loss_fn: str, **kwargs):
+        """Combined forward and backward pass."""
+        if loss_fn == 'importance_sampling':
+            super().set_loss('GRPOLoss', adapter_name=adapter_name, epsilon=0.2, beta=0.0)
+        template = self.get_template(adapter_name=adapter_name)
+        input_features = datum_to_input_feature(inputs, template)
+        loss_values = extract_rl_feature(inputs)
+        loss_kwargs = kwargs.copy()
+        loss_kwargs.update(loss_values)
+        outputs = super().forward_backward(inputs=input_features, adapter_name=adapter_name, **loss_kwargs)
+        loss = outputs.get('loss', None)
+        logits_list = outputs.get('logits', None)
+        logps = outputs.get('logps', None)
+        if logits_list is None and logps is None:
+            return [None, None]
+
+        logits = None
+        if logits_list is not None:
+            if isinstance(logits_list, torch.Tensor):
+                logits = logits_list.detach()
+            else:
+                logits = torch.cat([logit.detach() for logit in logits_list], dim=0)
+        if logps is not None:
+            logps = logps.detach().cpu()
+        results = self._get_forward_output(inputs, logits, logps)
+
+        if isinstance(loss, torch.Tensor):
+            loss = loss.item()
+        elif loss is not None:
+            loss = float(loss)
+
+        return [results, loss]
+
+    @remote_function(dispatch='slice_dp', collect='flatten')
+    def tinker_forward_only(self, *, inputs: List[types.Datum], **kwargs):
+        """Forward pass without gradient computation."""
+        template = self.get_template(**kwargs)
+        input_features = datum_to_input_feature(inputs, template)
+        outputs = super().forward_only(inputs=input_features, **kwargs)
+        logits = outputs.get('logits', None)
+        logps = outputs.get('logps', None)
+
+        if logits is not None:
+            if isinstance(logits, torch.Tensor):
+                logits = logits.detach().cpu()
+            elif isinstance(logits, list) and len(logits) > 0:
+                logits = torch.cat([logit.detach().cpu() for logit in logits], dim=0)
+            results = self._get_forward_output(inputs, logits, logps)
+        else:
+            results = [{'logprobs': None, 'elementwise_loss':
None} for _ in inputs] + + return results + + @remote_function(dispatch='all') + def tinker_step(self, *, adam_params: types.AdamParams, **kwargs): + """Optimizer step with AdamParams configuration.""" + adapter_name = kwargs.get('adapter_name') + optimizer_config = self.optimizer_group.get(adapter_name) + + if optimizer_config and optimizer_config.optimizer: + opt = optimizer_config.optimizer + if hasattr(opt, 'chained_optimizers'): + for chained_opt in opt.chained_optimizers: + if hasattr(chained_opt, 'config'): + chained_opt.config.lr = adam_params.learning_rate + chained_opt.config.adam_eps = adam_params.eps + chained_opt.config.adam_beta1 = adam_params.beta1 + chained_opt.config.adam_beta2 = adam_params.beta2 + chained_opt.config.weight_decay = adam_params.weight_decay + if adam_params.grad_clip_norm > 0: + chained_opt.config.clip_grad = adam_params.grad_clip_norm + + super().step(**kwargs) + super().zero_grad(**kwargs) + + @remote_function(collect='first', lazy_collect=False) + def tinker_calculate_metric(self, is_training, **kwargs): + metric = super().calculate_metric(is_training, **kwargs) + return clean_metrics(metric) + + @remote_function(dispatch='all', sync=True) + def tinker_load(self, checkpoint_dir: str, **kwargs): + """Load checkpoint with token-based isolation support.""" + token = kwargs.pop('token', None) + if not token: + raise ValueError('Token is required for loading checkpoints') + from twinkle.server.common.checkpoint_factory import create_checkpoint_manager + checkpoint_manager = create_checkpoint_manager(token, client_type='tinker') + resolved = checkpoint_manager.resolve_load_path(checkpoint_dir) + if resolved.is_twinkle_path: + return super().load(name=resolved.checkpoint_name, output_dir=resolved.checkpoint_dir, **kwargs) + else: + return super().load(name=resolved.checkpoint_name, **kwargs) diff --git a/src/twinkle/server/model/backends/transformers_model.py b/src/twinkle/server/model/backends/transformers_model.py new file mode 100644 index 00000000..20d6b75b --- /dev/null +++ b/src/twinkle/server/model/backends/transformers_model.py @@ -0,0 +1,138 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +""" +Backend model implementations for the unified model deployment. + +Contains one unified class: +- TwinkleCompatTransformersModel: handles both tinker (Datum-based I/O) via /tinker/* + endpoints and twinkle-native (InputFeature/Trajectory-based I/O) via /twinkle/* endpoints. +""" +import numpy as np +import torch +from collections.abc import Mapping +from tinker import types +from typing import Any, List, Union + +from twinkle import remote_class, remote_function +from twinkle.data_format import InputFeature, Trajectory +from twinkle.model import MultiLoraTransformersModel +from twinkle.server.common.datum import datum_to_input_feature, extract_rl_feature +from twinkle.server.model.backends.common import TwinkleCompatModelBase, clean_metrics, collect_forward_backward_results + + +@remote_class() +class TwinkleCompatTransformersModel(MultiLoraTransformersModel, TwinkleCompatModelBase): + """Unified wrapper around MultiLoraTransformersModel. + + Handles both: + - Tinker-compat I/O (Datum / TensorData) via /tinker/* endpoints. + - Twinkle-native I/O (InputFeature / Trajectory) via /twinkle/* endpoints. 
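+
+    Both families operate on the same underlying base model instance; only the
+    request/response format and endpoint prefix differ at this layer.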
+ """ + + # ------------------------------------------------------------------ + # Shared helper: CPU-safe serialisation for HTTP transport + # ------------------------------------------------------------------ + + @staticmethod + def _to_cpu_safe_output(obj: Any) -> Any: + """Convert nested outputs into CPU-safe Python objects for HTTP transport.""" + from twinkle.utils import torch_util + + if isinstance(obj, torch.Tensor): + tensor = torch_util.to_local_tensor(obj).detach().cpu() + if tensor.numel() == 1: + return tensor.item() + return tensor.tolist() + if isinstance(obj, np.ndarray): + if obj.size == 1: + return obj.item() + return obj.tolist() + if isinstance(obj, np.generic): + return obj.item() + if isinstance(obj, Mapping): + return {key: TwinkleCompatTransformersModel._to_cpu_safe_output(value) for key, value in obj.items()} + if isinstance(obj, (list, tuple)): + return [TwinkleCompatTransformersModel._to_cpu_safe_output(value) for value in obj] + return obj + + # ------------------------------------------------------------------ + # Tinker-compat methods (Datum-based I/O) + # ------------------------------------------------------------------ + + @remote_function(dispatch='slice_dp', collect='flatten') + def tinker_forward_only(self, *, inputs: List[types.Datum], **kwargs): + template = self.get_template(**kwargs) + input_features = datum_to_input_feature(inputs, template) + outputs = super().forward_only(inputs=input_features, **kwargs) + logits = outputs['logits'].detach().cpu() + logps = outputs.get('logps', None) + if logps is not None: + logps = logps.detach().cpu() + results = self._get_forward_output(inputs, logits, logps) + return results + + @remote_function(dispatch='slice_dp', collect=collect_forward_backward_results) + def tinker_forward_backward(self, *, inputs: List[types.Datum], adapter_name: str, loss_fn: str, **kwargs): + if loss_fn == 'cross_entropy': + super().set_loss('CrossEntropyLoss', adapter_name=adapter_name) + elif loss_fn == 'importance_sampling': + super().set_loss('GRPOLoss', adapter_name=adapter_name, epsilon=0.2, beta=0.0) + else: + super().set_loss('CrossEntropyLoss', adapter_name=adapter_name) + template = self.get_template(adapter_name) + input_features = datum_to_input_feature(inputs, template) + outputs = super().forward(inputs=input_features, adapter_name=adapter_name, **kwargs) + loss_values = extract_rl_feature(inputs) + loss_kwargs = kwargs.copy() + loss_kwargs.update(loss_values) + loss = super().calculate_loss(adapter_name=adapter_name, **loss_kwargs) + super().backward(adapter_name=adapter_name, **kwargs) + logits = outputs['logits'].detach() + logps = outputs.get('logps', None) + if logps is not None: + logps = logps.detach().cpu() + results = self._get_forward_output(inputs, logits, logps) + return [results, loss] + + @remote_function() + def tinker_step(self, *, adam_params: types.AdamParams, **kwargs): + grad_clip_norm = adam_params.grad_clip_norm + if grad_clip_norm > 0.0: + self.clip_grad_norm(max_grad_norm=grad_clip_norm, norm_type=2, **kwargs) + optim_params = { + 'lr': adam_params.learning_rate, + 'eps': adam_params.eps, + 'betas': (adam_params.beta1, adam_params.beta2), + 'weight_decay': adam_params.weight_decay, + } + super().step(optim_params=optim_params, **kwargs) + super().zero_grad(**kwargs) + + @remote_function(collect='first', lazy_collect=False) + def tinker_calculate_metric(self, is_training, **kwargs): + metric = super().calculate_metric(is_training, **kwargs) + return clean_metrics(metric) + + @remote_function() + 
def tinker_load(self, checkpoint_dir: str, **kwargs):
+        """Load checkpoint with token-based isolation support."""
+        token = kwargs.pop('token', None)
+        if not token:
+            raise ValueError('Token is required for loading checkpoints')
+        from twinkle.server.common.checkpoint_factory import create_checkpoint_manager
+        checkpoint_manager = create_checkpoint_manager(token, client_type='tinker')
+        resolved = checkpoint_manager.resolve_load_path(checkpoint_dir)
+        if resolved.is_twinkle_path:
+            return super().load(name=resolved.checkpoint_name, output_dir=resolved.checkpoint_dir, **kwargs)
+        else:
+            return super().load(name=resolved.checkpoint_name, **kwargs)
+
+    # ------------------------------------------------------------------
+    # Twinkle-native methods (InputFeature/Trajectory-based I/O)
+    # ------------------------------------------------------------------
+
+    @remote_function(dispatch='slice_dp', collect='mean')
+    def forward_backward(self, *, inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]],
+                         **kwargs):
+        """Forward+backward for twinkle-native clients (InputFeature/Trajectory I/O)."""
+        output = super().forward_backward(inputs=inputs, **kwargs)
+        return self._to_cpu_safe_output(output)
diff --git a/src/twinkle/server/model/tinker_handlers.py b/src/twinkle/server/model/tinker_handlers.py
new file mode 100644
index 00000000..6f458d8f
--- /dev/null
+++ b/src/twinkle/server/model/tinker_handlers.py
@@ -0,0 +1,303 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+"""
+Tinker-compatible model route handlers.
+
+All endpoints are prefixed /tinker/... and use schedule_task() returning UntypedAPIFuture.
+self_fn is injected via FastAPI Depends to obtain the ModelManagement instance at request time.
+"""
+from __future__ import annotations
+
+import traceback
+from fastapi import Depends, FastAPI, Request
+from peft import LoraConfig
+from tinker import types
+from typing import TYPE_CHECKING, Any, Callable
+
+if TYPE_CHECKING:
+    from .app import ModelManagement
+
+from twinkle.server.common.checkpoint_factory import create_checkpoint_manager, create_training_run_manager
+from twinkle.utils.logger import get_logger
+
+logger = get_logger()
+
+
+def _register_tinker_routes(app: FastAPI, self_fn: Callable[[], ModelManagement]) -> None:
+    """Register all /tinker/* routes on the given FastAPI app.
+
+    self_fn is a zero-argument callable that returns the current ModelManagement
+    replica instance. It is wired in via Depends so it is resolved lazily at request time.
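+
+    Handlers generally wrap their work in schedule_task(), so the HTTP response
+    is an UntypedAPIFuture that the client polls; lightweight calls such as
+    get_info return their result directly.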
+ """ + + @app.post('/tinker/create_model') + async def create_model( + request: Request, + body: types.CreateModelRequest, + self: ModelManagement = Depends(self_fn), + ) -> types.UntypedAPIFuture: + token = await self._on_request_start(request) + + async def _create_adapter(): + _model_id = None + try: + + _model_id = self.state.register_model(body.model_dump(), token=token, replica_id=self.replica_id) + if body.lora_config: + lora_cfg = LoraConfig(r=body.lora_config.rank, target_modules='all-linear') + adapter_name = self.get_adapter_name(adapter_name=_model_id) + self.register_adapter(adapter_name, token, session_id=body.session_id) + self.model.add_adapter_to_model(adapter_name=adapter_name, config_or_dir=lora_cfg) + self.model.set_template('Template', adapter_name=adapter_name, model_id=self.base_model) + self.model.set_processor('InputProcessor', adapter_name=adapter_name) + self.model.set_optimizer('Adam', adapter_name=adapter_name) + self.set_adapter_state(adapter_name, 'grad_ready', False) + training_run_manager = create_training_run_manager(token, client_type='tinker') + training_run_manager.save(_model_id, body) + return types.CreateModelResponse(model_id=_model_id) + except Exception: + if _model_id: + adapter_name = self.get_adapter_name(adapter_name=_model_id) + self._cleanup_adapter(adapter_name) + logger.error(traceback.format_exc()) + return types.RequestFailedResponse( + error=traceback.format_exc(), + category=types.RequestErrorCategory.Server, + ) + + return await self.schedule_task(_create_adapter, token=token, task_type='create_model') + + @app.post('/tinker/get_info') + async def get_info(request: Request, body: types.GetInfoRequest, + self: ModelManagement = Depends(self_fn)) -> types.GetInfoResponse: + token = await self._on_request_start(request) + training_run_manager = create_training_run_manager(token, client_type='tinker') + metadata = training_run_manager.get(str(body.model_id)) + model_name = metadata.base_model if metadata else self.base_model + lora_rank = None + is_lora = False + if metadata and hasattr(metadata, 'lora_rank') and metadata.lora_rank: + lora_rank = metadata.lora_rank + is_lora = metadata.is_lora + return types.GetInfoResponse( + model_data=types.ModelData(model_name=model_name), + model_id=body.model_id, + is_lora=is_lora, + lora_rank=lora_rank, + model_name=model_name, + ) + + @app.post('/tinker/unload_model') + async def unload_model( + request: Request, + body: types.UnloadModelRequest, + self: ModelManagement = Depends(self_fn), + ) -> types.UntypedAPIFuture: + token = await self._on_request_start(request) + + async def _do_unload(): + adapter_name = self.get_adapter_name(adapter_name=body.model_id) + self._cleanup_adapter(adapter_name) + return types.UnloadModelResponse(model_id=body.model_id) + + return await self.schedule_task(_do_unload, model_id=body.model_id, token=token, task_type='unload_model') + + @app.post('/tinker/forward') + async def forward(request: Request, body: types.ForwardRequest, + self: ModelManagement = Depends(self_fn)) -> types.UntypedAPIFuture: + token = await self._on_request_start(request) + + async def _do_forward(): + try: + adapter_name = self.get_adapter_name(adapter_name=body.model_id) + self.assert_adapter_exists(adapter_name=adapter_name) + datum_list = body.forward_input.data + loss_fn_config = body.forward_input.loss_fn_config or {} + output = self.model.tinker_forward_only(inputs=datum_list, adapter_name=adapter_name) + loss = self.model.calculate_loss(adapter_name=adapter_name, **loss_fn_config) + 
return types.ForwardBackwardOutput( + loss_fn_output_type='CrossEntropyLossReturn', + loss_fn_outputs=output, + metrics={'loss:sum': loss}, + ) + except Exception: + logger.error(traceback.format_exc()) + return types.RequestFailedResponse( + error=traceback.format_exc(), + category=types.RequestErrorCategory.Server, + ) + + datum_list = body.forward_input.data + input_tokens = sum(len(d.model_input.to_ints()) for d in datum_list) + batch_size = len(datum_list) + return await self.schedule_task( + _do_forward, + model_id=body.model_id, + token=token, + input_tokens=input_tokens, + batch_size=batch_size, + data_world_size=self.device_mesh.data_world_size, + task_type='forward', + ) + + @app.post('/tinker/forward_backward') + async def forward_backward( + request: Request, + body: types.ForwardBackwardRequest, + self: ModelManagement = Depends(self_fn), + ) -> types.UntypedAPIFuture: + token = await self._on_request_start(request) + + async def _do_forward_backward(): + try: + adapter_name = self.get_adapter_name(adapter_name=body.model_id) + self.assert_adapter_exists(adapter_name=adapter_name) + datum_list = body.forward_backward_input.data + loss_fn = body.forward_backward_input.loss_fn + loss_fn_config = body.forward_backward_input.loss_fn_config or {} + output, loss = self.model.tinker_forward_backward( + inputs=datum_list, adapter_name=adapter_name, loss_fn=loss_fn, **loss_fn_config) + output_type = ('ImportanceSamplingLossReturn' + if loss_fn == 'importance_sampling' else 'CrossEntropyLossReturn') + self.set_adapter_state(adapter_name, 'grad_ready', True) + return types.ForwardBackwardOutput( + loss_fn_output_type=output_type, + loss_fn_outputs=output, + metrics={'loss:avg': loss}, + ) + except Exception: + logger.error(traceback.format_exc()) + return types.RequestFailedResponse( + error=traceback.format_exc(), + category=types.RequestErrorCategory.Server, + ) + + datum_list = body.forward_backward_input.data + input_tokens = sum(len(d.model_input.to_ints()) for d in datum_list) + batch_size = len(datum_list) + return await self.schedule_task( + _do_forward_backward, + model_id=body.model_id, + token=token, + input_tokens=input_tokens, + batch_size=batch_size, + data_world_size=self.device_mesh.data_world_size, + task_type='forward_backward', + ) + + @app.post('/tinker/optim_step') + async def optim_step( + request: Request, + body: types.OptimStepRequest, + self: ModelManagement = Depends(self_fn), + ) -> types.UntypedAPIFuture: + token = await self._on_request_start(request) + + async def _do_optim(): + try: + adapter_name = self.get_adapter_name(adapter_name=body.model_id) + self.assert_adapter_exists(adapter_name=adapter_name) + if not self.get_adapter_state(adapter_name, 'grad_ready', False): + raise RuntimeError(f'No accumulated gradients for adapter={adapter_name}; ' + 'call forward_backward before optim_step') + self.model.tinker_step(adam_params=body.adam_params, adapter_name=adapter_name) + self.set_adapter_state(adapter_name, 'grad_ready', False) + metrics = self.model.tinker_calculate_metric(is_training=True, adapter_name=adapter_name) + return types.OptimStepResponse(metrics=metrics) + except Exception: + logger.error(traceback.format_exc()) + return types.RequestFailedResponse( + error=traceback.format_exc(), + category=types.RequestErrorCategory.Server, + ) + + return await self.schedule_task(_do_optim, model_id=body.model_id, token=token, task_type='optim_step') + + @app.post('/tinker/save_weights') + async def save_weights( + request: Request, + body: 
types.SaveWeightsRequest, + self: ModelManagement = Depends(self_fn), + ) -> types.UntypedAPIFuture: + token = await self._on_request_start(request) + + async def _do_save(): + try: + adapter_name = self.get_adapter_name(adapter_name=body.model_id) + self.assert_adapter_exists(adapter_name=adapter_name) + checkpoint_manager = create_checkpoint_manager(token, client_type='tinker') + checkpoint_name = checkpoint_manager.get_ckpt_name(body.path) + save_dir = checkpoint_manager.get_save_dir(model_id=body.model_id, is_sampler=False) + self.model.save( + name=checkpoint_name, output_dir=save_dir, adapter_name=adapter_name, save_optimizer=True) + tinker_path = checkpoint_manager.save(body.model_id, name=checkpoint_name, is_sampler=False) + return types.SaveWeightsResponse(path=tinker_path, type='save_weights') + except Exception: + logger.error(traceback.format_exc()) + return types.RequestFailedResponse( + error=traceback.format_exc(), + category=types.RequestErrorCategory.Server, + ) + + return await self.schedule_task(_do_save, model_id=body.model_id, token=token, task_type='save_weights') + + @app.post('/tinker/save_weights_for_sampler') + async def save_weights_for_sampler( + request: Request, + body: types.SaveWeightsForSamplerRequest, + self: ModelManagement = Depends(self_fn), + ) -> types.UntypedAPIFuture: + token = await self._on_request_start(request) + + async def _do_save_for_sampler(): + try: + adapter_name = self.get_adapter_name(adapter_name=body.model_id) + self.assert_adapter_exists(adapter_name=adapter_name) + checkpoint_manager = create_checkpoint_manager(token, client_type='tinker') + checkpoint_name = checkpoint_manager.get_ckpt_name(body.path) + save_dir = checkpoint_manager.get_save_dir(model_id=body.model_id, is_sampler=True) + tinker_path = checkpoint_manager.save(body.model_id, name=checkpoint_name, is_sampler=True) + logger.info(f'Saving weights to {save_dir}') + self.model.save( + name=checkpoint_name, output_dir=save_dir, adapter_name=adapter_name, save_optimizer=False) + payload = body.model_dump() + payload['model_path'] = tinker_path + metadata = self.state.get_model_metadata(body.model_id) or {} + if metadata.get('base_model'): + payload['base_model'] = metadata['base_model'] + sampling_session_id = self.state.create_sampling_session(payload) + return types.SaveWeightsForSamplerResponseInternal(path=None, sampling_session_id=sampling_session_id) + except Exception: + logger.error(traceback.format_exc()) + return types.RequestFailedResponse( + error=traceback.format_exc(), + category=types.RequestErrorCategory.Server, + ) + + return await self.schedule_task( + _do_save_for_sampler, model_id=body.model_id, token=token, task_type='save_weights_for_sampler') + + @app.post('/tinker/load_weights') + async def load_weights( + request: Request, + body: types.LoadWeightsRequest, + self: ModelManagement = Depends(self_fn), + ) -> types.UntypedAPIFuture: + token = await self._on_request_start(request) + + async def _do_load(): + try: + assert self.model is not None, 'Model not loaded, please load model first' + adapter_name = self.get_adapter_name(adapter_name=body.model_id) + self.assert_adapter_exists(adapter_name=adapter_name) + self.model.tinker_load( + checkpoint_dir=body.path, load_optimizer=body.optimizer, adapter_name=adapter_name, token=token) + self.set_adapter_state(adapter_name, 'grad_ready', False) + return types.LoadWeightsResponse(path=body.path, type='load_weights') + except Exception: + logger.error(traceback.format_exc()) + return 
types.RequestFailedResponse(
+                    error=traceback.format_exc(),
+                    category=types.RequestErrorCategory.Server,
+                )
+
+        return await self.schedule_task(_do_load, model_id=body.model_id, token=token, task_type='load_weights')
diff --git a/src/twinkle/server/model/twinkle_handlers.py b/src/twinkle/server/model/twinkle_handlers.py
new file mode 100644
index 00000000..35c87441
--- /dev/null
+++ b/src/twinkle/server/model/twinkle_handlers.py
@@ -0,0 +1,469 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+"""
+Twinkle-native model route handlers.
+
+All endpoints are prefixed /twinkle/... and use schedule_task_and_wait() returning
+results directly (synchronous from the client's perspective).
+self_fn is injected via FastAPI Depends to obtain the ModelManagement instance at request time.
+"""
+from __future__ import annotations
+
+import traceback
+from fastapi import Depends, FastAPI, HTTPException, Request
+from peft import LoraConfig
+from typing import TYPE_CHECKING, Any, Callable, Optional
+
+if TYPE_CHECKING:
+    from .app import ModelManagement
+
+import twinkle_client.types as types
+from twinkle.data_format import InputFeature, Trajectory
+from twinkle.server.common.checkpoint_factory import create_checkpoint_manager, create_training_run_manager
+from twinkle.server.common.serialize import deserialize_object
+from twinkle.server.utils.validation import get_session_id_from_request
+from twinkle.utils.logger import get_logger
+
+logger = get_logger()
+
+
+def _parse_inputs(inputs: Any):
+    """Convert raw dict/list inputs to InputFeature or Trajectory objects."""
+    if isinstance(inputs, list) and inputs:
+        first = inputs[0]
+        if isinstance(first, dict) and 'input_ids' in first:
+            return [InputFeature(**item) for item in inputs]
+        else:
+            return [Trajectory(**item) for item in inputs]
+    elif isinstance(inputs, dict):
+        if 'input_ids' in inputs:
+            return [InputFeature(**inputs)]
+        else:
+            return [Trajectory(**inputs)]
+    return inputs
+
+
+def _get_twinkle_adapter_name(request: Request, adapter_name: str | None) -> str | None:
+    """Build the per-request adapter name from the request_id prefix."""
+    if adapter_name is None or adapter_name == '':
+        return None
+    return request.state.request_id + '-' + adapter_name
+
+
+def _register_twinkle_routes(app: FastAPI, self_fn: Callable[[], ModelManagement]) -> None:
+    """Register all /twinkle/* routes on the given FastAPI app.
+
+    self_fn is a zero-argument callable that returns the current ModelManagement
+    replica instance. It is wired in via Depends so it is resolved lazily at request time.
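+
+    Unlike the /tinker/* routes, these handlers run through schedule_task_and_wait(),
+    so each call blocks until the underlying model operation finishes and returns
+    its result directly.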
+ """ + + async def run_task(coro): + """Await a schedule_task_and_wait coroutine and surface any exception as a + structured HTTP 500 response so the client receives the full traceback instead + of an opaque connection-level error.""" + try: + return await coro + except Exception: + logger.error(traceback.format_exc()) + raise HTTPException(status_code=500, detail=traceback.format_exc()) + + @app.post('/twinkle/create', response_model=types.CreateResponse) + async def create(request: Request, body: types.CreateRequest, + self: ModelManagement = Depends(self_fn)) -> types.CreateResponse: + return types.CreateResponse() + + @app.post('/twinkle/forward', response_model=types.ForwardResponse) + async def forward(request: Request, body: types.ForwardRequest, + self: ModelManagement = Depends(self_fn)) -> types.ForwardResponse: + adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) + + async def _task(): + self.assert_adapter_exists(adapter_name=adapter_name) + extra_kwargs = body.model_extra or {} + inputs = _parse_inputs(body.inputs) + ret = self.model.forward(inputs=inputs, adapter_name=adapter_name, **extra_kwargs) + return {'result': ret} + + return await run_task(self.schedule_task_and_wait(_task, task_type='forward')) + + @app.post('/twinkle/forward_only', response_model=types.ForwardResponse) + async def forward_only( + request: Request, + body: types.ForwardOnlyRequest, + self: ModelManagement = Depends(self_fn), + ) -> types.ForwardResponse: + adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) + + async def _task(): + self.assert_adapter_exists(adapter_name=adapter_name) + extra_kwargs = body.model_extra or {} + inputs = _parse_inputs(body.inputs) + ret = self.model.forward_only(inputs=inputs, adapter_name=adapter_name, **extra_kwargs) + return {'result': ret} + + return await run_task(self.schedule_task_and_wait(_task, task_type='forward_only')) + + @app.post('/twinkle/calculate_loss', response_model=types.CalculateLossResponse) + async def calculate_loss( + request: Request, + body: types.AdapterRequest, + self: ModelManagement = Depends(self_fn), + ) -> types.CalculateLossResponse: + adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) + + async def _task(): + self.assert_adapter_exists(adapter_name=adapter_name) + extra_kwargs = body.model_extra or {} + ret = self.model.calculate_loss(adapter_name=adapter_name, **extra_kwargs) + return {'result': ret} + + return await run_task(self.schedule_task_and_wait(_task, task_type='calculate_loss')) + + @app.post('/twinkle/backward') + async def backward(request: Request, body: types.AdapterRequest, self: ModelManagement = Depends(self_fn)) -> None: + adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) + + async def _task(): + self.assert_adapter_exists(adapter_name=adapter_name) + extra_kwargs = body.model_extra or {} + self.model.backward(adapter_name=adapter_name, **extra_kwargs) + + await run_task(self.schedule_task_and_wait(_task, task_type='backward')) + + @app.post('/twinkle/forward_backward', response_model=types.ForwardBackwardResponse) + async def forward_backward( + request: Request, + body: types.ForwardRequest, + self: ModelManagement = Depends(self_fn), + ) -> types.ForwardBackwardResponse: + adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) + + async def _task(): + self.assert_adapter_exists(adapter_name=adapter_name) + extra_kwargs = body.model_extra or {} + inputs = _parse_inputs(body.inputs) + ret = 
self.model.forward_backward(inputs=inputs, adapter_name=adapter_name, **extra_kwargs) + return {'result': ret} + + return await run_task(self.schedule_task_and_wait(_task, task_type='forward_backward')) + + @app.post('/twinkle/clip_grad_norm', response_model=types.ClipGradNormResponse) + async def clip_grad_norm( + request: Request, + body: types.AdapterRequest, + self: ModelManagement = Depends(self_fn), + ) -> types.ClipGradNormResponse: + adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) + + async def _task(): + self.assert_adapter_exists(adapter_name=adapter_name) + extra_kwargs = body.model_extra or {} + ret = self.model.clip_grad_norm(adapter_name=adapter_name, **extra_kwargs) + return {'result': str(ret)} + + return await run_task(self.schedule_task_and_wait(_task, task_type='clip_grad_norm')) + + @app.post('/twinkle/step') + async def step(request: Request, body: types.AdapterRequest, self: ModelManagement = Depends(self_fn)) -> None: + adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) + + async def _task(): + self.assert_adapter_exists(adapter_name=adapter_name) + extra_kwargs = body.model_extra or {} + self.model.step(adapter_name=adapter_name, **extra_kwargs) + + await run_task(self.schedule_task_and_wait(_task, task_type='step')) + + @app.post('/twinkle/zero_grad') + async def zero_grad(request: Request, body: types.AdapterRequest, self: ModelManagement = Depends(self_fn)) -> None: + adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) + + async def _task(): + self.assert_adapter_exists(adapter_name=adapter_name) + extra_kwargs = body.model_extra or {} + self.model.zero_grad(adapter_name=adapter_name, **extra_kwargs) + + await run_task(self.schedule_task_and_wait(_task, task_type='zero_grad')) + + @app.post('/twinkle/lr_step') + async def lr_step(request: Request, body: types.AdapterRequest, self: ModelManagement = Depends(self_fn)) -> None: + adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) + + async def _task(): + self.assert_adapter_exists(adapter_name=adapter_name) + extra_kwargs = body.model_extra or {} + self.model.lr_step(adapter_name=adapter_name, **extra_kwargs) + + await run_task(self.schedule_task_and_wait(_task, task_type='lr_step')) + + @app.post('/twinkle/clip_grad_and_step') + async def clip_grad_and_step( + request: Request, + body: types.ClipGradAndStepRequest, + self: ModelManagement = Depends(self_fn), + ) -> None: + adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) + + async def _task(): + self.assert_adapter_exists(adapter_name=adapter_name) + extra_kwargs = body.model_extra or {} + self.model.clip_grad_and_step( + max_grad_norm=body.max_grad_norm, + norm_type=body.norm_type, + adapter_name=adapter_name, + **extra_kwargs, + ) + + await run_task(self.schedule_task_and_wait(_task, task_type='clip_grad_and_step')) + + @app.post('/twinkle/get_train_configs', response_model=types.GetTrainConfigsResponse) + async def get_train_configs( + request: Request, + body: types.AdapterRequest, + self: ModelManagement = Depends(self_fn), + ) -> types.GetTrainConfigsResponse: + adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) + + async def _task(): + self.assert_adapter_exists(adapter_name=adapter_name) + extra_kwargs = body.model_extra or {} + ret = self.model.get_train_configs(adapter_name=adapter_name, **extra_kwargs) + return {'result': ret} + + return await run_task(self.schedule_task_and_wait(_task, task_type='get_train_configs')) + + 
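+
+    # For orientation (illustrative only, not executed): a twinkle-native client
+    # drives these synchronous routes roughly as below. The gateway address, the
+    # bearer-token auth header, the model id, and the adapter name 'demo' are all
+    # assumptions; real clients should use the generated twinkle_client wrappers
+    # rather than raw HTTP, and the adapter must first be created via
+    # /twinkle/add_adapter_to_model.
+    #
+    #     import requests
+    #
+    #     base = 'http://localhost:8000/model/Qwen/Qwen2.5-0.5B-Instruct/twinkle'
+    #     headers = {'Authorization': 'Bearer your-token'}
+    #     features = [{'input_ids': [1, 2, 3, 4]}]  # dicts with 'input_ids' parse as InputFeature
+    #
+    #     r = requests.post(f'{base}/forward_backward', headers=headers,
+    #                       json={'adapter_name': 'demo', 'inputs': features})
+    #     r.raise_for_status()
+    #
+    #     r = requests.post(f'{base}/clip_grad_and_step', headers=headers,
+    #                       json={'adapter_name': 'demo', 'max_grad_norm': 1.0, 'norm_type': 2})
+    #     r.raise_for_status()
+    #
+    # Any extra JSON keys land in body.model_extra and are forwarded to the
+    # underlying model call as **kwargs.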
@app.post('/twinkle/set_loss') + async def set_loss(request: Request, body: types.SetLossRequest, self: ModelManagement = Depends(self_fn)) -> None: + adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) + + async def _task(): + self.assert_adapter_exists(adapter_name=adapter_name) + extra_kwargs = body.model_extra or {} + self.model.set_loss(body.loss_cls, adapter_name=adapter_name, **extra_kwargs) + + await run_task(self.schedule_task_and_wait(_task, task_type='set_loss')) + + @app.post('/twinkle/set_optimizer') + async def set_optimizer( + request: Request, + body: types.SetOptimizerRequest, + self: ModelManagement = Depends(self_fn), + ) -> None: + adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) + + async def _task(): + self.assert_adapter_exists(adapter_name=adapter_name) + extra_kwargs = body.model_extra or {} + self.model.set_optimizer(body.optimizer_cls, adapter_name=adapter_name, **extra_kwargs) + + await run_task(self.schedule_task_and_wait(_task, task_type='set_optimizer')) + + @app.post('/twinkle/set_lr_scheduler') + async def set_lr_scheduler( + request: Request, + body: types.SetLrSchedulerRequest, + self: ModelManagement = Depends(self_fn), + ) -> None: + adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) + + async def _task(): + self.assert_adapter_exists(adapter_name=adapter_name) + extra_kwargs = body.model_extra or {} + self.model.set_lr_scheduler(body.scheduler_cls, adapter_name=adapter_name, **extra_kwargs) + + await run_task(self.schedule_task_and_wait(_task, task_type='set_lr_scheduler')) + + @app.post('/twinkle/save', response_model=types.SaveResponse) + async def save(request: Request, body: types.SaveRequest, + self: ModelManagement = Depends(self_fn)) -> types.SaveResponse: + token = await self._on_request_start(request) + adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) + + async def _task(): + self.assert_adapter_exists(adapter_name=adapter_name) + extra_kwargs = body.model_extra or {} + checkpoint_manager = create_checkpoint_manager(token, client_type='twinkle') + checkpoint_name = checkpoint_manager.get_ckpt_name(body.name) + save_dir = checkpoint_manager.get_save_dir(model_id=adapter_name, is_sampler=False) + checkpoint_dir = self.model.save( + name=checkpoint_name, + output_dir=save_dir, + adapter_name=adapter_name, + save_optimizer=body.save_optimizer, + **extra_kwargs) + twinkle_path = checkpoint_manager.save(model_id=adapter_name, name=checkpoint_name, is_sampler=False) + return {'twinkle_path': twinkle_path, 'checkpoint_dir': checkpoint_dir} + + return await run_task(self.schedule_task_and_wait(_task, task_type='save')) + + @app.post('/twinkle/load') + async def load(request: Request, body: types.LoadRequest, self: ModelManagement = Depends(self_fn)) -> None: + token = await self._on_request_start(request) + adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) + + async def _task(): + self.assert_adapter_exists(adapter_name=adapter_name) + extra_kwargs = body.model_extra or {} + checkpoint_manager = create_checkpoint_manager(token, client_type='twinkle') + resolved = checkpoint_manager.resolve_load_path(body.name) + self.model.load( + name=resolved.checkpoint_name, + output_dir=resolved.checkpoint_dir, + adapter_name=adapter_name, + load_optimizer=body.load_optimizer, + token=token, + **extra_kwargs) + + await run_task(self.schedule_task_and_wait(_task, task_type='load')) + + @app.post('/twinkle/upload_to_hub') + async def upload_to_hub( + request: Request, + 
body: types.UploadToHubRequest, + self: ModelManagement = Depends(self_fn), + ) -> None: + token = await self._on_request_start(request) + + async def _task(): + if body.checkpoint_dir.startswith('twinkle://'): + checkpoint_manager = create_checkpoint_manager(token, client_type='twinkle') + parsed = checkpoint_manager.parse_twinkle_path(body.checkpoint_dir) + if not parsed: + raise ValueError(f'Invalid twinkle path format: {body.checkpoint_dir}') + checkpoint_id = parsed.checkpoint_id + model_id_to_load = parsed.training_run_id + checkpoint = checkpoint_manager.get(model_id_to_load, checkpoint_id) + if not checkpoint: + raise ValueError(f'Checkpoint not found or access denied: {body.checkpoint_dir}') + checkpoint_dir = str( + checkpoint_manager.get_ckpt_dir(model_id=model_id_to_load, checkpoint_id=checkpoint_id)) + else: + checkpoint_dir = body.checkpoint_dir + self.model.upload_to_hub( + checkpoint_dir=checkpoint_dir, + hub_model_id=body.hub_model_id, + hub_token=body.hub_token or token, + async_upload=body.async_upload) + + await run_task(self.schedule_task_and_wait(_task, task_type='upload_to_hub')) + + @app.post('/twinkle/add_adapter_to_model', response_model=types.AddAdapterResponse) + async def add_adapter_to_model( + request: Request, + body: types.AddAdapterRequest, + self: ModelManagement = Depends(self_fn), + ) -> types.AddAdapterResponse: + assert body.adapter_name, 'You need to specify a valid `adapter_name`' + token = await self._on_request_start(request) + adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) + session_id = get_session_id_from_request(request) + + async def _task(): + config = deserialize_object(body.config) + extra_kwargs = body.model_extra or {} + training_run_manager = create_training_run_manager(token, client_type='twinkle') + self.register_adapter(adapter_name, token, session_id=session_id) + self.model.add_adapter_to_model(adapter_name, config, **extra_kwargs) + + lora_config = None + if isinstance(config, LoraConfig): + lora_config = types.LoraConfig(rank=config.r, train_unembed=False, train_mlp=True, train_attn=True) + run_config = types.CreateModelRequest( + base_model=self.base_model, lora_config=lora_config, user_metadata={'adapter_name': body.adapter_name}) + training_run_manager.save(adapter_name, run_config) + return {'status': 'ok', 'adapter_name': adapter_name} + + return await run_task(self.schedule_task_and_wait(_task, task_type='add_adapter_to_model')) + + @app.post('/twinkle/apply_patch') + async def apply_patch( + request: Request, + body: types.ApplyPatchRequest, + self: ModelManagement = Depends(self_fn), + ) -> None: + adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) + + async def _task(): + self.assert_adapter_exists(adapter_name=adapter_name) + extra_kwargs = body.model_extra or {} + patch_cls = deserialize_object(body.patch_cls) + self.model.apply_patch(patch_cls, adapter_name=adapter_name, **extra_kwargs) + + await run_task(self.schedule_task_and_wait(_task, task_type='apply_patch')) + + @app.post('/twinkle/add_metric') + async def add_metric( + request: Request, + body: types.AddMetricRequest, + self: ModelManagement = Depends(self_fn), + ) -> None: + adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) + + async def _task(): + self.assert_adapter_exists(adapter_name=adapter_name) + extra_kwargs = body.model_extra or {} + metric_cls = deserialize_object(body.metric_cls) + self.model.add_metric(metric_cls, is_training=body.is_training, adapter_name=adapter_name, **extra_kwargs) + + 
await run_task(self.schedule_task_and_wait(_task, task_type='add_metric')) + + @app.post('/twinkle/set_template') + async def set_template( + request: Request, + body: types.SetTemplateRequest, + self: ModelManagement = Depends(self_fn), + ) -> None: + adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) + + async def _task(): + self.assert_adapter_exists(adapter_name=adapter_name) + extra_kwargs = body.model_extra or {} + self.model.set_template(body.template_cls, adapter_name=adapter_name, **extra_kwargs) + + await run_task(self.schedule_task_and_wait(_task, task_type='set_template')) + + @app.post('/twinkle/set_processor') + async def set_processor( + request: Request, + body: types.SetProcessorRequest, + self: ModelManagement = Depends(self_fn), + ) -> None: + adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) + + async def _task(): + self.assert_adapter_exists(adapter_name=adapter_name) + extra_kwargs = body.model_extra or {} + self.model.set_processor(body.processor_cls, adapter_name=adapter_name, **extra_kwargs) + + await run_task(self.schedule_task_and_wait(_task, task_type='set_processor')) + + @app.post('/twinkle/calculate_metric', response_model=types.CalculateMetricResponse) + async def calculate_metric( + request: Request, + body: types.CalculateMetricRequest, + self: ModelManagement = Depends(self_fn), + ) -> types.CalculateMetricResponse: + adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) + + async def _task(): + self.assert_adapter_exists(adapter_name=adapter_name) + extra_kwargs = body.model_extra or {} + ret = self.model.calculate_metric(is_training=body.is_training, adapter_name=adapter_name, **extra_kwargs) + return {'result': ret} + + return await run_task(self.schedule_task_and_wait(_task, task_type='calculate_metric')) + + @app.post('/twinkle/get_state_dict', response_model=types.GetStateDictResponse) + async def get_state_dict( + request: Request, + body: types.GetStateDictRequest, + self: ModelManagement = Depends(self_fn), + ) -> types.GetStateDictResponse: + adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) + + async def _task(): + self.assert_adapter_exists(adapter_name=adapter_name) + extra_kwargs = body.model_extra or {} + ret = self.model.get_state_dict(adapter_name=adapter_name, **extra_kwargs) + return {'result': ret} + + return await run_task(self.schedule_task_and_wait(_task, task_type='get_state_dict')) diff --git a/src/twinkle/server/processor/__init__.py b/src/twinkle/server/processor/__init__.py new file mode 100644 index 00000000..4032f5bf --- /dev/null +++ b/src/twinkle/server/processor/__init__.py @@ -0,0 +1,3 @@ +from .app import build_processor_app + +__all__ = ['build_processor_app'] diff --git a/src/twinkle/server/processor/app.py b/src/twinkle/server/processor/app.py new file mode 100644 index 00000000..4b03af86 --- /dev/null +++ b/src/twinkle/server/processor/app.py @@ -0,0 +1,133 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +""" +Processor management application. + +Provides a Ray Serve deployment for managing distributed processors +(datasets, dataloaders, preprocessors, rewards, templates, weight loaders, etc.). 
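+
+Clients create a processor once per session and then drive it through a generic
+call endpoint that dispatches to the named method on the managed object.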
+ +Follows the same structural pattern as model/app.py: +- ProcessorManagement is a top-level class inheriting ProcessorManagerMixin +- Routes are registered in build_processor_app() via _register_processor_routes() +- serve.ingress(app)(ProcessorManagement) applied before deployment +- Sticky session routing via @serve.multiplexed keyed on session ID +""" +from __future__ import annotations + +import os +from fastapi import FastAPI, Request +from ray import serve +from typing import Any, Dict, Optional + +import twinkle +from twinkle import DeviceGroup, DeviceMesh, get_logger +from twinkle.server.utils.processor_manager import ProcessorManagerMixin +from twinkle.server.utils.state import ServerStateProxy, get_server_state +from twinkle.server.utils.validation import verify_request_token +from .twinkle_handlers import _register_processor_routes + +logger = get_logger() + + +class ProcessorManagement(ProcessorManagerMixin): + """Processor management service. + + Manages lifecycle and invocation of distributed processor objects + (datasets, dataloaders, rewards, templates, etc.). + + Lifecycle is handled by ProcessorManagerMixin: + - Processors are registered with a session ID on creation. + - A background thread expires processors whose session has timed out. + - Per-user processor limit is enforced at registration. + - Sticky session routing ensures session requests hit the same replica. + """ + + def __init__(self, + ncpu_proc_per_node: int, + device_group: dict[str, Any], + device_mesh: dict[str, Any], + nproc_per_node: int = 1, + processor_config: dict[str, Any] | None = None): + self.device_group = DeviceGroup(**device_group) + twinkle.initialize( + mode='ray', + nproc_per_node=nproc_per_node, + groups=[self.device_group], + lazy_collect=False, + ncpu_proc_per_node=ncpu_proc_per_node) + if 'mesh_dim_names' in device_mesh: + self.device_mesh = DeviceMesh(**device_mesh) + else: + self.device_mesh = DeviceMesh.from_sizes(**device_mesh) + + # processor objects keyed by processor_id + self.resource_dict: dict[str, Any] = {} + self.state: ServerStateProxy = get_server_state() + + _cfg = processor_config or {} + _env_limit = int(os.environ.get('TWINKLE_PER_USER_PROCESSOR_LIMIT', 20)) + self._init_processor_manager( + processor_timeout=float(_cfg.get('processor_timeout', 1800.0)), + per_token_processor_limit=int(_cfg.get('per_token_processor_limit', _env_limit)), + ) + self.start_processor_countdown() + + @serve.multiplexed(max_num_models_per_replica=100) + async def _sticky_entry(self, sticky_key: str): + return sticky_key + + async def _ensure_sticky(self): + sticky_key = serve.get_multiplexed_model_id() + await self._sticky_entry(sticky_key) + + def _on_processor_expired(self, processor_id: str) -> None: + """Called by the countdown thread when a processor's session expires.""" + self.resource_dict.pop(processor_id, None) + self.unregister_processor(processor_id) + + +def build_processor_app(ncpu_proc_per_node: int, + device_group: dict[str, Any], + device_mesh: dict[str, Any], + deploy_options: dict[str, Any], + nproc_per_node: int = 1, + processor_config: dict[str, Any] | None = None, + **kwargs): + """Build the processor management application. + + Follows the same pattern as build_model_app(): FastAPI app and routes are + built here BEFORE serve.ingress so that the frozen app contains the full + route table visible to ProxyActor. + + Args: + ncpu_proc_per_node: Number of CPU processes per node. + device_group: Device group configuration dict. 
+ device_mesh: Device mesh configuration dict. + deploy_options: Ray Serve deployment options. + nproc_per_node: Number of GPU processes per node (default 1). + processor_config: Optional lifecycle configuration dict. + Supported keys: + - ``processor_timeout`` (float): Session inactivity timeout seconds. Default 1800.0. + - ``per_token_processor_limit`` (int): Max processors per user. + Overrides ``TWINKLE_PER_USER_PROCESSOR_LIMIT`` env var when provided. + **kwargs: Additional arguments. + + Returns: + Ray Serve deployment bound with configuration. + """ + # Build the FastAPI app and register all routes BEFORE serve.ingress so that + # the frozen app contains the complete route table (visible to ProxyActor). + app = FastAPI() + + @app.middleware('http') + async def verify_token(request: Request, call_next): + return await verify_request_token(request=request, call_next=call_next) + + def get_self() -> ProcessorManagement: + return serve.get_replica_context().servable_object + + _register_processor_routes(app, get_self) + + ProcessorManagementWithIngress = serve.ingress(app)(ProcessorManagement) + DeploymentClass = serve.deployment(name='ProcessorManagement')(ProcessorManagementWithIngress) + return DeploymentClass.options(**deploy_options).bind(ncpu_proc_per_node, device_group, device_mesh, nproc_per_node, + processor_config) diff --git a/src/twinkle/server/processor/twinkle_handlers.py b/src/twinkle/server/processor/twinkle_handlers.py new file mode 100644 index 00000000..86e35f86 --- /dev/null +++ b/src/twinkle/server/processor/twinkle_handlers.py @@ -0,0 +1,130 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +""" +Processor management handler mixin. + +All endpoints are prefixed /twinkle/... and handle processor lifecycle +(create, call). self_fn is injected via FastAPI Depends to obtain the +ProcessorManagement instance at request time. +""" +from __future__ import annotations + +import asyncio +import importlib +import uuid +from fastapi import Depends, FastAPI, HTTPException, Request +from typing import TYPE_CHECKING, Callable + +if TYPE_CHECKING: + from .app import ProcessorManagement + +import twinkle_client.types as types +from twinkle.server.common.serialize import deserialize_object +from twinkle.server.utils.validation import get_session_id_from_request, get_token_from_request +from twinkle.utils.logger import get_logger + +logger = get_logger() + +_PROCESSOR_TYPES = ['dataset', 'dataloader', 'preprocessor', 'processor', 'reward', 'template', 'weight_loader'] + + +def _register_processor_routes(app: FastAPI, self_fn: Callable[[], ProcessorManagement]) -> None: + """Register all /twinkle/* processor routes on the given FastAPI app. + + self_fn is a zero-argument callable that returns the current ProcessorManagement + replica instance. It is wired in via Depends so it is resolved lazily at request time. 
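+
+    For example, app.py above wires it in as:
+
+        def get_self() -> ProcessorManagement:
+            return serve.get_replica_context().servable_object
+
+        _register_processor_routes(app, get_self)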
+ """ + + @app.post('/twinkle/create', response_model=types.ProcessorCreateResponse) + async def create( + request: Request, body: types.ProcessorCreateRequest, + self: ProcessorManagement = Depends(self_fn)) -> types.ProcessorCreateResponse: + await self._ensure_sticky() + + processor_type_name = body.processor_type + class_type = body.class_type + _kwargs = body.model_extra or {} + + assert processor_type_name in _PROCESSOR_TYPES, f'Invalid processor type: {processor_type_name}' + processor_module = importlib.import_module(f'twinkle.{processor_type_name}') + assert hasattr(processor_module, class_type), f'Class {class_type} not found in {processor_type_name}' + + token = get_token_from_request(request) + session_id = get_session_id_from_request(request) + processor_id = str(uuid.uuid4().hex) + + # Register for lifecycle tracking (enforces per-user limit) + self.register_processor(processor_id, token, session_id) + + _kwargs.pop('remote_group', None) + _kwargs.pop('device_mesh', None) + + resolved_kwargs = {} + for key, value in _kwargs.items(): + if isinstance(value, str) and value.startswith('pid:'): + ref_id = value[4:] + resolved_kwargs[key] = self.resource_dict[ref_id] + else: + value = deserialize_object(value) + resolved_kwargs[key] = value + + # Run processor instantiation in a thread to avoid blocking the event loop, + # which would starve the session-liveness coroutines submitted by the + # countdown thread via asyncio.run_coroutine_threadsafe. + _remote_group = self.device_group.name + _device_mesh = self.device_mesh + + def _do_create(): + return getattr(processor_module, class_type)( + remote_group=_remote_group, device_mesh=_device_mesh, instance_id=processor_id, **resolved_kwargs) + + processor = await asyncio.get_event_loop().run_in_executor(None, _do_create) + self.resource_dict[processor_id] = processor + return types.ProcessorCreateResponse(processor_id='pid:' + processor_id) + + @app.post('/twinkle/call', response_model=types.ProcessorCallResponse) + async def call( + request: Request, body: types.ProcessorCallRequest, + self: ProcessorManagement = Depends(self_fn)) -> types.ProcessorCallResponse: + await self._ensure_sticky() + + processor_id = body.processor_id + function_name = body.function + _kwargs = body.model_extra or {} + processor_id = processor_id[4:] + self.assert_processor_exists(processor_id=processor_id) + processor = self.resource_dict.get(processor_id) + function = getattr(processor, function_name, None) + + assert function is not None, f'`{function_name}` not found in {processor.__class__}' + assert hasattr(function, '_execute'), f'Cannot call inner method of {processor.__class__}' + + resolved_kwargs = {} + for key, value in _kwargs.items(): + if isinstance(value, str) and value.startswith('pid:'): + ref_id = value[4:] + resolved_kwargs[key] = self.resource_dict[ref_id] + else: + value = deserialize_object(value) + resolved_kwargs[key] = value + + # Run the processor function in a thread to avoid blocking the event loop. + # StopIteration cannot propagate through asyncio coroutine boundaries + # (Python 3.7+ converts it to RuntimeError), so capture it as a sentinel tuple. 
+ def _do_call(): + try: + result = function(**resolved_kwargs) + return False, result + except StopIteration: + return True, None + + is_exhausted, result = await asyncio.get_event_loop().run_in_executor(None, _do_call) + + if function_name == '__next__': + if is_exhausted: + # HTTP 410 Gone signals iterator exhausted + raise HTTPException(status_code=410, detail='Iterator exhausted') + return types.ProcessorCallResponse(result=result) + + if function_name == '__iter__': + return types.ProcessorCallResponse(result='ok') + return types.ProcessorCallResponse(result=result) diff --git a/src/twinkle/server/sampler/__init__.py b/src/twinkle/server/sampler/__init__.py new file mode 100644 index 00000000..58db9098 --- /dev/null +++ b/src/twinkle/server/sampler/__init__.py @@ -0,0 +1,3 @@ +from .app import build_sampler_app + +__all__ = ['build_sampler_app'] diff --git a/src/twinkle/server/sampler/app.py b/src/twinkle/server/sampler/app.py new file mode 100644 index 00000000..c69a6956 --- /dev/null +++ b/src/twinkle/server/sampler/app.py @@ -0,0 +1,168 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +""" +Unified sampler management application. + +Builds a single Ray Serve deployment (SamplerManagement) that simultaneously handles +both Tinker (/tinker/asample) and Twinkle (/twinkle/*) sampler endpoints. +""" +from __future__ import annotations + +from fastapi import FastAPI, Request +from ray import serve +from typing import Any, Dict, Optional + +import twinkle +from twinkle import DeviceGroup, DeviceMesh +from twinkle.server.utils.adapter_manager import AdapterManagerMixin +from twinkle.server.utils.state import ServerStateProxy, get_server_state +from twinkle.server.utils.task_queue import TaskQueueConfig, TaskQueueMixin +from twinkle.server.utils.validation import get_token_from_request, verify_request_token +from twinkle.utils.logger import get_logger +from ..utils import wrap_builder_with_device_group_env +from .tinker_handlers import _register_tinker_sampler_routes +from .twinkle_handlers import _register_twinkle_sampler_routes + +logger = get_logger() + + +class SamplerManagement(TaskQueueMixin, AdapterManagerMixin): + """Unified sampler management service. 
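+
+    A single replica hosts one sampler backend (vLLM or Torch, selected by
+    ``sampler_type``) and serves both API dialects from the same FastAPI app.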
+
+    Manages:
+    - vLLM or Torch sampler initialization and lifecycle
+    - Tinker inference requests (/tinker/asample) with rate limiting via TaskQueueMixin
+    - Twinkle inference requests (/twinkle/*) calling sampler directly
+    - Adapter lifecycle via AdapterManagerMixin
+    - Template configuration for trajectory encoding
+    """
+
+    def __init__(self,
+                 model_id: str,
+                 nproc_per_node: int,
+                 device_group: dict[str, Any],
+                 device_mesh: dict[str, Any],
+                 sampler_type: str = 'vllm',
+                 engine_args: dict[str, Any] | None = None,
+                 adapter_config: dict[str, Any] | None = None,
+                 queue_config: dict[str, Any] | None = None,
+                 **kwargs):
+        self.device_group = DeviceGroup(**device_group)
+        twinkle.initialize(mode='ray', nproc_per_node=nproc_per_node, groups=[self.device_group], lazy_collect=False)
+        if 'mesh_dim_names' in device_mesh:
+            self.device_mesh = DeviceMesh(**device_mesh)
+        else:
+            self.device_mesh = DeviceMesh.from_sizes(**device_mesh)
+        self.sampler_type = sampler_type
+        replica_context = serve.get_replica_context()
+        replica_id = replica_context.replica_id.unique_id
+
+        # Initialize sampler based on type
+        if sampler_type == 'vllm':
+            from twinkle.sampler import vLLMSampler
+            sampler_kwargs = engine_args or {}
+            self.sampler = vLLMSampler(
+                model_id=model_id,
+                engine_args=sampler_kwargs,
+                device_mesh=self.device_mesh,
+                remote_group=self.device_group.name,
+                instance_id=replica_id,
+                **{
+                    k: v
+                    for k, v in kwargs.items() if k not in ['engine_args']
+                })
+        else:
+            from twinkle.sampler import TorchSampler
+            self.sampler = TorchSampler(
+                model_id=model_id,
+                device_mesh=self.device_mesh,
+                instance_id=replica_id,
+                remote_group=self.device_group.name,
+                **kwargs)
+
+        self.sampler.set_template('Template', model_id=model_id)
+        self.state: ServerStateProxy = get_server_state()
+
+        # Initialize both mixins
+        self._init_task_queue(TaskQueueConfig.from_dict(queue_config))
+        _adapter_config = adapter_config or {}
+        self._init_adapter_manager(**_adapter_config)
+        self.start_adapter_countdown()
+
+    @serve.multiplexed(max_num_models_per_replica=5)
+    async def _sticky_entry(self, sticky_key: str):
+        return sticky_key
+
+    async def _ensure_sticky(self):
+        sticky_key = serve.get_multiplexed_model_id()
+        await self._sticky_entry(sticky_key)
+
+    async def _on_request_start(self, request: Request) -> str:
+        await self._ensure_sticky()
+        token = get_token_from_request(request)
+        return token
+
+    def _on_adapter_expired(self, adapter_name: str, token: Optional[str] = None) -> None:
+        """Handle expired adapters by removing them from the sampler."""
+        try:
+            self.sampler.remove_adapter(adapter_name)
+            logger.info(f'Removed expired adapter {adapter_name}')
+        except Exception as e:
+            logger.warning(f'Failed to remove expired adapter {adapter_name}: {e}')
+
+
+def build_sampler_app(model_id: str,
+                      nproc_per_node: int,
+                      device_group: dict[str, Any],
+                      device_mesh: dict[str, Any],
+                      deploy_options: dict[str, Any],
+                      sampler_type: str = 'vllm',
+                      engine_args: dict[str, Any] | None = None,
+                      adapter_config: dict[str, Any] | None = None,
+                      queue_config: dict[str, Any] | None = None,
+                      **kwargs):
+    """Build a unified sampler application for text generation inference.
+
+    Supports both Tinker (polling-style /tinker/asample) and
+    Twinkle (synchronous /twinkle/*) sampler clients.
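+
+    Example (illustrative; configuration values are hypothetical):
+
+        app = build_sampler_app(
+            model_id='Qwen/Qwen2.5-0.5B-Instruct',
+            nproc_per_node=1,
+            device_group={...},      # DeviceGroup kwargs
+            device_mesh={...},       # DeviceMesh kwargs
+            deploy_options={'num_replicas': 1},
+        )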
+ + Args: + model_id: Model identifier (e.g., "Qwen/Qwen2.5-0.5B-Instruct") + nproc_per_node: Number of processes per node + device_group: Device group configuration dict + device_mesh: Device mesh configuration dict for parallelism + deploy_options: Ray Serve deployment options + sampler_type: Type of sampler to use ('vllm' or 'torch') + engine_args: Additional engine arguments for the sampler + adapter_config: Adapter lifecycle config (timeout, per-token limits) + queue_config: Task queue configuration dict (rps_limit, tps_limit, etc.) + **kwargs: Additional arguments passed to the sampler + + Returns: + Ray Serve deployment bound with configuration + """ + # Build the FastAPI app and register all routes BEFORE serve.ingress so that + # the frozen app contains the complete route table (visible to ProxyActor). + app = FastAPI( + title='Unified Sampler', + description='REST API for distributed text generation inference (Tinker + Twinkle)', + version='1.0.0') + + @app.middleware('http') + async def verify_token(request: Request, call_next): + return await verify_request_token(request=request, call_next=call_next) + + def get_self() -> SamplerManagement: + return serve.get_replica_context().servable_object + + # Register routes BEFORE @serve.ingress so Ray Serve captures them at decoration time + _register_tinker_sampler_routes(app, get_self) + _register_twinkle_sampler_routes(app, get_self) + + SamplerManagementWithIngress = serve.ingress(app)(SamplerManagement) + DeploymentClass = serve.deployment(name='SamplerManagement')(SamplerManagementWithIngress) + return DeploymentClass.options(**deploy_options).bind(model_id, nproc_per_node, device_group, device_mesh, + sampler_type, engine_args, adapter_config, queue_config, + **kwargs) + + +build_sampler_app = wrap_builder_with_device_group_env(build_sampler_app) diff --git a/src/twinkle/server/sampler/tinker_handlers.py b/src/twinkle/server/sampler/tinker_handlers.py new file mode 100644 index 00000000..4cd574be --- /dev/null +++ b/src/twinkle/server/sampler/tinker_handlers.py @@ -0,0 +1,121 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +""" +Tinker-compatible sampler handler mixin. + +Provides POST /tinker/asample using schedule_task() returning UntypedAPIFuture. +""" +from __future__ import annotations + +import os +import traceback +from fastapi import Depends, FastAPI, Request +from tinker import types +from typing import TYPE_CHECKING, Callable + +if TYPE_CHECKING: + from .app import SamplerManagement + +from twinkle.data_format import SamplingParams +from twinkle.server.common.checkpoint_factory import create_checkpoint_manager +from twinkle.utils.logger import get_logger + +logger = get_logger() + + +def _register_tinker_sampler_routes(app: FastAPI, self_fn: Callable[[], SamplerManagement]) -> None: + """Register the tinker sampler route on the given FastAPI app. + + self_fn is a zero-argument callable returning the current SamplerManagement replica instance. + It is wired in via Depends so it is resolved lazily at request time. + """ + + @app.post('/tinker/asample') + async def asample(request: Request, body: types.SampleRequest, + self: SamplerManagement = Depends(self_fn)) -> types.UntypedAPIFuture: + """Execute text generation (inference) for Tinker clients. 
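+
+        The returned UntypedAPIFuture resolves to a SampleResponse on success
+        or a RequestFailedResponse on error; clients poll it tinker-style.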
+ + Args: + request: FastAPI request with auth token + body: SampleRequest with prompt, sampling params, and adapter info + + Returns: + UntypedAPIFuture wrapping SampleResponse with generated sequences + """ + token = await self._on_request_start(request) + + async def _do_sample(): + try: + # Extract prompt token IDs from ModelInput + prompt_inputs = {'input_ids': body.prompt.to_ints()} + + # Get model_path from body or sampling session + model_path = body.model_path + if not model_path and body.sampling_session_id: + session = self.state.get_sampling_session(body.sampling_session_id) + if session: + model_path = session.get('model_path') + + # Parse and resolve adapter URI from model_path + adapter_uri = None + if model_path: + checkpoint_manager = create_checkpoint_manager(token, client_type='tinker') + adapter_name, adapter_uri = checkpoint_manager.parse_adapter_uri(model_path) + + # Validate adapter URI + if not adapter_uri or not os.path.exists(adapter_uri): + return types.RequestFailedResponse( + error=f'Adapter URI {model_path} does not exist. Please check the model_path.', + category=types.RequestErrorCategory.User, + ) + + # Convert tinker SamplingParams to twinkle SamplingParams if needed + sampling_params = None + if body.sampling_params: + sampling_params = SamplingParams( + max_tokens=body.sampling_params.max_tokens or 256, + temperature=body.sampling_params.temperature or 1.0, + top_p=body.sampling_params.top_p, + top_k=body.sampling_params.top_k, + stop=body.sampling_params.stop, + ) + + response = self.sampler.sample( + inputs=[prompt_inputs] * body.num_samples, + sampling_params=sampling_params, + adapter_path=adapter_uri, + ) + + # Convert twinkle SampleResponse to tinker types + tinker_sequences = [] + for seq in response.sequences: + logprobs = None + if seq.logprobs is not None: + if any(lp is None for lp in seq.logprobs): + logprobs = None + else: + logprobs = list(seq.logprobs) + tinker_sequences.append( + types.SampledSequence( + stop_reason=seq.stop_reason, + tokens=list(seq.tokens), + logprobs=logprobs, + )) + return types.SampleResponse( + sequences=tinker_sequences, + prompt_logprobs=response.prompt_logprobs, + topk_prompt_logprobs=response.topk_prompt_logprobs, + ) + except Exception: + logger.error(traceback.format_exc()) + return types.RequestFailedResponse( + error=traceback.format_exc(), + category=types.RequestErrorCategory.Server, + ) + + input_tokens = len(body.prompt.to_ints()) + return await self.schedule_task( + _do_sample, + token=token, + input_tokens=input_tokens, + task_type='sample', + ) diff --git a/src/twinkle/server/sampler/twinkle_handlers.py b/src/twinkle/server/sampler/twinkle_handlers.py new file mode 100644 index 00000000..a31f4046 --- /dev/null +++ b/src/twinkle/server/sampler/twinkle_handlers.py @@ -0,0 +1,161 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +""" +Twinkle-native sampler handler mixin. + +Provides /twinkle/* sampler endpoints that call the sampler directly (no queue needed). 
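+
+A minimal /twinkle/sample request body might look like this (illustrative;
+token-id inputs, simplified fields):
+
+    {"inputs": [{"input_ids": [1, 2, 3]}],
+     "sampling_params": {"max_tokens": 64},
+     "num_samples": 1}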
+""" +from __future__ import annotations + +import traceback +from fastapi import Depends, FastAPI, HTTPException, Request +from typing import TYPE_CHECKING, Callable, Optional + +if TYPE_CHECKING: + from .app import SamplerManagement + +import numpy as np + +import twinkle_client.types as types +from twinkle.data_format import InputFeature, SamplingParams, Trajectory +from twinkle.utils.logger import get_logger + +logger = get_logger() + + +def _serialize_input_feature(feature: dict) -> dict: + """Convert numpy arrays / torch tensors in an InputFeature to plain Python lists.""" + result = {} + for k, v in feature.items(): + if isinstance(v, np.ndarray): + result[k] = v.tolist() + else: + try: + import torch + if isinstance(v, torch.Tensor): + result[k] = v.tolist() + continue + except ImportError: + pass + result[k] = v + return result + + +def _get_twinkle_sampler_adapter_name(request: Request, adapter_name: str | None) -> str | None: + """Prefix the adapter name with the request ID for per-request isolation.""" + if adapter_name is None or adapter_name == '': + return None + return request.state.request_id + '-' + adapter_name + + +def _register_twinkle_sampler_routes(app: FastAPI, self_fn: Callable[[], SamplerManagement]) -> None: + """Register all /twinkle/* sampler routes on the given FastAPI app. + + self_fn is a zero-argument callable returning the current SamplerManagement replica instance. + It is wired in via Depends so it is resolved lazily at request time. + """ + + @app.post('/twinkle/create', response_model=types.CreateResponse) + def create(request: Request, self: SamplerManagement = Depends(self_fn)) -> types.CreateResponse: + """Health check / session creation endpoint.""" + return types.CreateResponse() + + @app.post('/twinkle/sample', response_model=types.SampleResponseModel) + def sample(request: Request, body: types.SampleRequest, + self: SamplerManagement = Depends(self_fn)) -> types.SampleResponseModel: + """Sample completions from the model. + + Supports Trajectory or InputFeature inputs, with optional LoRA adapter. 
+ """ + try: + # Resolve adapter + adapter_path = None + adapter_name = body.adapter_name or '' + full_adapter_name = _get_twinkle_sampler_adapter_name(request, adapter_name) or '' + + if body.adapter_uri: + from twinkle.server.common.checkpoint_factory import create_checkpoint_manager + from twinkle.server.utils.validation import get_token_from_request + token = get_token_from_request(request) + checkpoint_manager = create_checkpoint_manager(token, client_type='twinkle') + _, adapter_path = checkpoint_manager.parse_adapter_uri(body.adapter_uri) + + # Parse inputs + inputs = body.inputs + if isinstance(inputs, list) and inputs: + first = inputs[0] + if isinstance(first, dict) and 'input_ids' in first: + inputs = [InputFeature(**item) for item in inputs] + else: + inputs = [Trajectory(**item) for item in inputs] + elif isinstance(inputs, dict): + if 'input_ids' in inputs: + inputs = [InputFeature(**inputs)] + else: + inputs = [Trajectory(**inputs)] + + # Build sampling params + params = None + if body.sampling_params: + params = SamplingParams.from_dict(body.sampling_params) + + # Call sampler + response = self.sampler.sample( + inputs, + params, + adapter_name=full_adapter_name, + adapter_path=adapter_path, + num_samples=body.num_samples, + ) + if callable(response): + response = response() + + sequences = [ + types.SampledSequenceModel( + stop_reason=seq.stop_reason, + tokens=list(seq.tokens), + logprobs=list(seq.logprobs) if seq.logprobs is not None else None, + decoded=seq.decoded, + new_input_feature=_serialize_input_feature(seq.new_input_feature) + if seq.new_input_feature is not None else None, + ) for seq in response.sequences + ] + + return types.SampleResponseModel( + sequences=sequences, + prompt_logprobs=response.prompt_logprobs, + topk_prompt_logprobs=response.topk_prompt_logprobs, + ) + except Exception: + logger.error(traceback.format_exc()) + raise HTTPException(status_code=500, detail=traceback.format_exc()) + + @app.post('/twinkle/set_template', response_model=types.SetTemplateResponse) + def set_template( + request: Request, + body: types.SetTemplateRequest, + self: SamplerManagement = Depends(self_fn), + ) -> types.SetTemplateResponse: + """Set the chat template for encoding Trajectory inputs.""" + extra_kwargs = body.model_extra or {} + self.sampler.set_template(body.template_cls, **extra_kwargs) + return types.SetTemplateResponse() + + @app.post('/twinkle/add_adapter_to_sampler', response_model=types.AddAdapterResponse) + def add_adapter_to_sampler( + request: Request, + body: types.AddAdapterRequest, + self: SamplerManagement = Depends(self_fn), + ) -> types.AddAdapterResponse: + """Add a LoRA adapter to the sampler.""" + assert body.adapter_name, 'You need to specify a valid `adapter_name`' + full_adapter_name = _get_twinkle_sampler_adapter_name(request, body.adapter_name) + from twinkle.server.utils.validation import get_token_from_request + token = get_token_from_request(request) + + from peft import LoraConfig + config = LoraConfig(**body.config) if isinstance(body.config, dict) else body.config + + self.register_adapter(full_adapter_name, token) + self.sampler.add_adapter_to_sampler(full_adapter_name, config) + + return types.AddAdapterResponse(adapter_name=full_adapter_name) diff --git a/src/twinkle/server/tinker/__init__.py b/src/twinkle/server/tinker/__init__.py deleted file mode 100644 index 40688d64..00000000 --- a/src/twinkle/server/tinker/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. 
-import sys -from typing import TYPE_CHECKING - -from twinkle.utils.import_utils import _LazyModule - -_import_structure = { - 'model': ['build_model_app'], - 'sampler': ['build_sampler_app'], - 'server': ['build_server_app'], -} - -if TYPE_CHECKING: - from .model import build_model_app - from .sampler import build_sampler_app - from .server import build_server_app -else: - sys.modules[__name__] = _LazyModule(__name__, __file__, _import_structure, module_spec=__spec__) diff --git a/src/twinkle/server/tinker/common/__init__.py b/src/twinkle/server/tinker/common/__init__.py deleted file mode 100644 index ae59d58f..00000000 --- a/src/twinkle/server/tinker/common/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -from twinkle.utils import exists, requires -from .datum import datum_to_input_feature, input_feature_to_datum diff --git a/src/twinkle/server/tinker/common/io_utils.py b/src/twinkle/server/tinker/common/io_utils.py deleted file mode 100644 index f3128e99..00000000 --- a/src/twinkle/server/tinker/common/io_utils.py +++ /dev/null @@ -1,181 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -"""Tinker-specific IO utilities for managing training runs and checkpoints. - -This module extends the base IO utilities with Tinker-specific implementations. -It uses types from the tinker package for compatibility with the Tinker API. -""" -from datetime import datetime -from tinker import types -from typing import Any, Dict, List, Optional - -from twinkle.server.utils.io_utils import (CHECKPOINT_INFO_FILENAME, TRAIN_RUN_INFO_FILENAME, TWINKLE_DEFAULT_SAVE_DIR, - BaseCheckpointManager, BaseTrainingRunManager, ResolvedLoadPath, - validate_ownership, validate_user_path) - -# ----- Tinker Training Run Manager ----- - - -class TrainingRunManager(BaseTrainingRunManager): - """Tinker-specific training run manager using tinker.types models.""" - - @property - def train_run_info_filename(self) -> str: - return TRAIN_RUN_INFO_FILENAME - - def _create_training_run(self, model_id: str, run_config: types.CreateModelRequest) -> Dict[str, Any]: - """Create training run data from model_id and run_config.""" - lora_config = run_config.lora_config - train_run_data = types.TrainingRun( - training_run_id=model_id, - base_model=run_config.base_model, - model_owner=self.token, - is_lora=True if lora_config else False, - corrupted=False, - lora_rank=lora_config.rank if lora_config else None, - last_request_time=datetime.now(), - last_checkpoint=None, - last_sampler_checkpoint=None, - user_metadata=run_config.user_metadata) - - new_data = train_run_data.model_dump(mode='json') - # Store lora config details separately if needed - if lora_config: - new_data['train_unembed'] = lora_config.train_unembed - new_data['train_mlp'] = lora_config.train_mlp - new_data['train_attn'] = lora_config.train_attn - - return new_data - - def _parse_training_run(self, data: Dict[str, Any]) -> types.TrainingRun: - """Parse training run data into TrainingRun model.""" - # Transform checkpoint data to ensure tinker_path field exists - data = self._transform_checkpoint_fields(data) - return types.TrainingRun(**data) - - def _transform_checkpoint_fields(self, data: Dict[str, Any]) -> Dict[str, Any]: - """Transform checkpoint data to ensure compatibility with tinker types. 
- - Handles cases where: - - last_checkpoint/last_sampler_checkpoint might have twinkle_path instead of tinker_path - - Missing path field that needs to be constructed from other data - """ - data = data.copy() - for field in ['last_checkpoint', 'last_sampler_checkpoint']: - if field in data and data[field] is not None: - ckpt = data[field].copy() - # If twinkle_path exists but tinker_path doesn't, use twinkle_path - if 'twinkle_path' in ckpt and 'tinker_path' not in ckpt: - ckpt['tinker_path'] = ckpt.pop('twinkle_path') - # If neither exists, try to construct from checkpoint_id - elif 'tinker_path' not in ckpt: - # Try to get path from any available path field - path = ckpt.get('path') or ckpt.get('twinkle_path') - if path: - ckpt['tinker_path'] = path - elif 'checkpoint_id' in ckpt and 'training_run_id' in data: - # Construct path from components - ckpt['tinker_path'] = f"twinkle://{data['training_run_id']}/{ckpt['checkpoint_id']}" - data[field] = ckpt - return data - - def _create_training_runs_response(self, runs: List[types.TrainingRun], limit: int, offset: int, - total: int) -> types.TrainingRunsResponse: - """Create a training runs response.""" - return types.TrainingRunsResponse( - training_runs=runs, cursor=types.Cursor(limit=limit, offset=offset, total_count=total)) - - -# ----- Tinker Checkpoint Manager ----- - - -class CheckpointManager(BaseCheckpointManager): - """Tinker-specific checkpoint manager using tinker.types models.""" - - @property - def path_prefix(self) -> str: - return 'twinkle://' - - @property - def path_field_name(self) -> str: - return 'tinker_path' - - def _create_checkpoint(self, - checkpoint_id: str, - checkpoint_type: str, - path: str, - size_bytes: int, - public: bool, - base_model: Optional[str] = None, - is_lora: bool = False, - lora_rank: Optional[int] = None, - train_unembed: Optional[bool] = None, - train_mlp: Optional[bool] = None, - train_attn: Optional[bool] = None, - user_metadata: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: - """Create checkpoint data.""" - # Create base checkpoint using tinker types - checkpoint = types.Checkpoint( - checkpoint_id=checkpoint_id, - checkpoint_type=checkpoint_type, - time=datetime.now(), - tinker_path=path, - size_bytes=size_bytes, - public=public) - result = checkpoint.model_dump(mode='json') - - # Add training run info fields (may not be supported by external types.Checkpoint) - result['base_model'] = base_model - result['is_lora'] = is_lora - result['lora_rank'] = lora_rank - result['train_unembed'] = train_unembed - result['train_mlp'] = train_mlp - result['train_attn'] = train_attn - result['user_metadata'] = user_metadata - - return result - - def _parse_checkpoint(self, data: Dict[str, Any]) -> types.Checkpoint: - """Parse checkpoint data into Checkpoint model.""" - data = data.copy() - # Transform twinkle_path to tinker_path if needed - if 'twinkle_path' in data and 'tinker_path' not in data: - data['tinker_path'] = data.pop('twinkle_path') - elif 'tinker_path' not in data and 'path' in data: - data['tinker_path'] = data.pop('path') - return types.Checkpoint(**data) - - def _create_checkpoints_response(self, checkpoints: List[types.Checkpoint]) -> types.CheckpointsListResponse: - """Create a checkpoints list response.""" - return types.CheckpointsListResponse(checkpoints=checkpoints, cursor=None) - - def _create_parsed_path(self, path: str, training_run_id: str, checkpoint_type: str, - checkpoint_id: str) -> types.ParsedCheckpointTinkerPath: - """Create a parsed path model.""" - return 
types.ParsedCheckpointTinkerPath( - tinker_path=path, - training_run_id=training_run_id, - checkpoint_type=checkpoint_type, - checkpoint_id=checkpoint_id, - ) - - def _create_weights_info(self, run_info: Dict[str, Any]) -> types.WeightsInfoResponse: - """Create weights info from run info.""" - return types.WeightsInfoResponse(**run_info) - - def parse_tinker_path(self, tinker_path: str) -> Optional[types.ParsedCheckpointTinkerPath]: - """Parse a twinkle:// path into its components (alias for parse_path).""" - return self.parse_path(tinker_path) - - -# ----- Factory Functions ----- - - -def create_training_run_manager(token: str) -> TrainingRunManager: - """Create a TrainingRunManager for the given token.""" - return TrainingRunManager(token) - - -def create_checkpoint_manager(token: str) -> CheckpointManager: - """Create a CheckpointManager for the given token.""" - training_run_manager = TrainingRunManager(token) - return CheckpointManager(token, training_run_manager) diff --git a/src/twinkle/server/tinker/common/megatron_model.py b/src/twinkle/server/tinker/common/megatron_model.py deleted file mode 100644 index ebd4df76..00000000 --- a/src/twinkle/server/tinker/common/megatron_model.py +++ /dev/null @@ -1,189 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. - -import torch -from tinker import types -from typing import TYPE_CHECKING, Any, List, Optional, Tuple - -from twinkle import remote_class, remote_function -from twinkle.utils import exists, requires -from .compat_base import TwinkleCompatModelBase, clean_metrics, collect_forward_backward_results -from .datum import datum_to_input_feature, extract_rl_feature -from .io_utils import create_checkpoint_manager - -if TYPE_CHECKING: - from twinkle.model.megatron import MultiLoraMegatronModel as _MegatronBase -elif exists('megatron_core'): - # Use module-level import to trigger LazyModule's __getattr__ correctly - import twinkle.model.megatron as megatron_module - _MegatronBase = megatron_module.MultiLoraMegatronModel -else: - - class _MegatronBase: - - def __init__(self, *args, **kwargs): - requires('megatron_core') - - -@remote_class(execute='all') -class TwinkleCompatMegatronModel(_MegatronBase, TwinkleCompatModelBase): - """ - Compatibility wrapper around :class:`MultiLoraMegatronModel` for Twinkle/Tinker. - - This class adapts the core `MultiLoraMegatronModel` API to the data types and - remote-call semantics used by Twinkle: - - * Inputs to :meth:`forward_backward` and :meth:`forward_only` are provided as - ``List[types.Datum]`` and are converted to the underlying model's - ``InputFeature`` format via :func:`datum_to_input_feature`. - * The outputs are a list of dictionaries, one per input example, containing: - - - ``"logprobs"``: token-level log-probabilities as ``types.TensorData``. - - ``"elementwise_loss"``: per-token (masked) NLL loss as ``types.TensorData``. - - These are derived from the underlying logits by applying ``log_softmax`` - and slicing to the label sequence length. - * :meth:`forward_backward` returns a tuple of (outputs, loss) where loss is a - Python scalar for the aggregated loss. - * :meth:`step` accepts optimizer hyperparameters as :class:`types.AdamParams`, - and updates the optimizer configuration before calling the base ``step``. - - Note: Megatron uses combined forward_backward instead of separate forward/backward. - This wrapper provides a direct forward_backward interface. 
- """ - - @remote_function(dispatch='slice_dp', collect=collect_forward_backward_results, sync=True) - def forward_backward(self, *, inputs: List[types.Datum], adapter_name: str, loss_fn: str, **kwargs): - """Combined forward and backward pass. - - Returns: - Tuple of (outputs, loss) where outputs is a list of dicts with - 'logprobs' and 'elementwise_loss', and loss is a scalar. - """ - if loss_fn == 'importance_sampling': - super().set_loss( - 'GRPOLoss', - adapter_name=adapter_name, - epsilon=0.2, # Default GRPO epsilon - beta=0.0) # No KL penalty by default - # Get template for input processing - template = self.get_template(adapter_name=adapter_name) - # Convert Datum to InputFeature - input_features = datum_to_input_feature(inputs, template) - # Extract old_logps and advantages using common utility - loss_values = extract_rl_feature(inputs) - loss_kwargs = kwargs.copy() - loss_kwargs.update(loss_values) - # Megatron forward_backward returns loss directly - outputs = super().forward_backward(inputs=input_features, adapter_name=adapter_name, **loss_kwargs) - loss = outputs.get('loss', None) - logits_list = outputs.get('logits', []) - logps = outputs.get('logps', []) - # When PP enabled, only logits from last stage are available - if logits_list is None and logps is None: - return [None, None] - - logits = None - if logits_list is not None: - # Process logits to match transformers output format - if isinstance(logits_list, torch.Tensor): - logits = logits_list.detach() - else: - # Concatenate logits from multiple microbatches - logits = torch.cat([logit.detach() for logit in logits_list], dim=0) - logps = logps.detach().cpu() - results = self._get_forward_output(inputs, logits, logps) - - # Convert loss to scalar - if isinstance(loss, torch.Tensor): - loss = loss.item() - else: - loss = float(loss) - - return [results, loss] - - @remote_function(dispatch='slice_dp', collect='flatten') - def forward_only(self, *, inputs: List[types.Datum], **kwargs): - """Forward pass without gradient computation.""" - # Get template for input processing - template = self.get_template(**kwargs) - # Convert Datum to InputFeature - input_features = datum_to_input_feature(inputs, template) - - outputs = super().forward_only(inputs=input_features, **kwargs) - - # Get logits - logits = outputs.get('logits', None) - logps = outputs.get('logps', None) - - if logits is not None: - if isinstance(logits, torch.Tensor): - logits = logits.detach().cpu() - elif isinstance(logits, list) and len(logits) > 0: - logits = torch.cat([logit.detach().cpu() for logit in logits], dim=0) - results = self._get_forward_output(inputs, logits, logps) - else: - # If no logits available (non-last PP stage), return empty results - results = [{'logprobs': None, 'elementwise_loss': None} for _ in inputs] - - return results - - @remote_function(dispatch='all') - def step(self, *, adam_params: types.AdamParams, **kwargs): - """Optimizer step with AdamParams configuration. - - Updates the optimizer configuration and performs the step. 
- """ - adapter_name = kwargs.get('adapter_name') - optimizer_config = self.optimizer_group.get(adapter_name) - - if optimizer_config and optimizer_config.optimizer: - # Update optimizer config with adam_params - # Megatron optimizer handles gradient clipping internally - opt = optimizer_config.optimizer - if hasattr(opt, 'chained_optimizers'): - for chained_opt in opt.chained_optimizers: - if hasattr(chained_opt, 'config'): - chained_opt.config.lr = adam_params.learning_rate - chained_opt.config.adam_eps = adam_params.eps - chained_opt.config.adam_beta1 = adam_params.beta1 - chained_opt.config.adam_beta2 = adam_params.beta2 - chained_opt.config.weight_decay = adam_params.weight_decay - if adam_params.grad_clip_norm > 0: - chained_opt.config.clip_grad = adam_params.grad_clip_norm - - # Perform optimizer step - super().step(**kwargs) - # Zero gradients - super().zero_grad(**kwargs) - - @remote_function(collect='first', lazy_collect=False) - def calculate_metric(self, is_training, **kwargs): - metric = super().calculate_metric(is_training, **kwargs) - return clean_metrics(metric) - - @remote_function(dispatch='all', sync=True) - def load(self, checkpoint_dir: str, **kwargs): - """ - Load checkpoint with token-based isolation support. - - Args: - checkpoint_dir: The twinkle:// path to the checkpoint or hub model ID - **kwargs: Additional keyword arguments including optional 'token' - """ - # Extract token from kwargs if provided (for user isolation) - token = kwargs.pop('token', None) - if not token: - raise ValueError('Token is required for loading checkpoints') - - # Create checkpoint manager with the token - checkpoint_manager = create_checkpoint_manager(token) - - # Use resolve_load_path to handle path resolution - resolved = checkpoint_manager.resolve_load_path(checkpoint_dir) - - if resolved.is_twinkle_path: - # Load from twinkle checkpoint - return super().load(name=resolved.checkpoint_name, output_dir=resolved.checkpoint_dir, **kwargs) - else: - # Load from hub - return super().load(name=resolved.checkpoint_name, **kwargs) diff --git a/src/twinkle/server/tinker/common/transformers_model.py b/src/twinkle/server/tinker/common/transformers_model.py deleted file mode 100644 index 98ae0134..00000000 --- a/src/twinkle/server/tinker/common/transformers_model.py +++ /dev/null @@ -1,148 +0,0 @@ -from tinker import types -from typing import List - -from twinkle import remote_class, remote_function -from twinkle.model import MultiLoraTransformersModel -from .compat_base import TwinkleCompatModelBase, clean_metrics, collect_forward_backward_results -from .datum import datum_to_input_feature, extract_rl_feature -from .io_utils import create_checkpoint_manager - - -@remote_class() -class TwinkleCompatTransformersModel(MultiLoraTransformersModel, TwinkleCompatModelBase): - """ - Compatibility wrapper around :class:`MultiLoraTransformersModel` for Twinkle/Tinker. - - This class adapts the core `MultiLoraTransformersModel` API to the data types and - remote-call semantics used by Twinkle: - - * Inputs to :meth:`forward` and :meth:`forward_only` are provided as - ``List[types.Datum]`` and are converted to the underlying model's - ``InputFeature`` format via :func:`datum_to_input_feature`. - * The outputs of :meth:`forward` and :meth:`forward_only` are not the raw - transformer outputs; instead they are a list of dictionaries, one per - input example, containing: - - - ``"logprobs"``: token-level log-probabilities as ``types.TensorData``. 
- - ``"elementwise_loss"``: per-token (masked) NLL loss as ``types.TensorData``. - - These are derived from the underlying logits by applying ``log_softmax`` - and slicing to the label sequence length. - * :meth:`calculate_loss` returns a Python scalar (via ``tensor.item()``) - and is exposed as a remote function with ``collect='sum'``, so the - distributed caller receives an aggregated scalar loss instead of a - tensor object. - * :meth:`step` accepts optimizer hyperparameters as :class:`types.AdamParams`, - performs optional gradient clipping, translates them into the optimizer - configuration expected by the base class, invokes the base ``step`` - implementation, and finally zeros gradients. - - Overall, this wrapper ensures that callers using Twinkle's higher-level - ``Datum``/``TensorData`` abstractions and remote functions can interact - with a ``MultiLoraTransformersModel`` instance without needing to know its - internal input feature schema, output structure, or optimizer API. - """ - - @remote_function(dispatch='slice_dp', collect='flatten') - def forward_only(self, *, inputs: List[types.Datum], **kwargs): - # Get template for input processing - template = self.get_template(**kwargs) - # Convert Datum to InputFeature - input_features = datum_to_input_feature(inputs, template) - outputs = super().forward_only(inputs=input_features, **kwargs) - # shape (batch_size, seq_len, vocab_size) - logits = outputs['logits'].detach().cpu() - logps = outputs.get('logps', None) - if logps is not None: - logps = logps.detach().cpu() - results = self._get_forward_output(inputs, logits, logps) - return results - - @remote_function(dispatch='slice_dp', collect=collect_forward_backward_results) - def forward_backward(self, *, inputs: List[types.Datum], adapter_name: str, loss_fn: str, **kwargs): - # Set loss first based on loss_fn - if loss_fn == 'cross_entropy': - super().set_loss('CrossEntropyLoss', adapter_name=adapter_name) - elif loss_fn == 'importance_sampling': - super().set_loss( - 'GRPOLoss', - adapter_name=adapter_name, - epsilon=0.2, # Default GRPO epsilon - beta=0.0) # No KL penalty by default - else: - super().set_loss('CrossEntropyLoss', adapter_name=adapter_name) - # Get template for input processing - template = self.get_template(adapter_name) - - # Convert Datum to InputFeature - input_features = datum_to_input_feature(inputs, template) - - # Forward pass - outputs = super().forward(inputs=input_features, adapter_name=adapter_name, **kwargs) - - # Calculate loss with extra parameters - # Extract old_logps and advantages using common utility - loss_values = extract_rl_feature(inputs) - loss_kwargs = kwargs.copy() - loss_kwargs.update(loss_values) - loss = super().calculate_loss(adapter_name=adapter_name, **loss_kwargs) - - # Backward pass - super().backward(adapter_name=adapter_name, **kwargs) - - # shape (batch_size, seq_len, vocab_size) - logits = outputs['logits'].detach() - logps = outputs.get('logps', None) - if logps is not None: - logps = logps.detach().cpu() - results = self._get_forward_output(inputs, logits, logps) - return [results, loss] - - @remote_function() - def step(self, *, adam_params: types.AdamParams, **kwargs): - # Gradient clipping - grad_clip_norm = adam_params.grad_clip_norm - if grad_clip_norm > 0.0: - self.clip_grad_norm(max_grad_norm=grad_clip_norm, norm_type=2, **kwargs) - # Optimizer step - optim_params = { - 'lr': adam_params.learning_rate, - 'eps': adam_params.eps, - 'betas': (adam_params.beta1, adam_params.beta2), - 'weight_decay': 
adam_params.weight_decay, - } - super().step(optim_params=optim_params, **kwargs) - # Zero gradients - super().zero_grad(**kwargs) - - @remote_function(collect='first', lazy_collect=False) - def calculate_metric(self, is_training, **kwargs): - metric = super().calculate_metric(is_training, **kwargs) - return clean_metrics(metric) - - @remote_function() - def load(self, checkpoint_dir: str, **kwargs): - """ - Load checkpoint with token-based isolation support. - - Args: - checkpoint_dir: The twinkle:// path to the checkpoint or hub model ID - **kwargs: Additional keyword arguments including optional 'token' - """ - # Extract token from kwargs if provided (for user isolation) - token = kwargs.pop('token', None) - if not token: - raise ValueError('Token is required for loading checkpoints') - - # Create checkpoint manager with the token - checkpoint_manager = create_checkpoint_manager(token) - - # Use resolve_load_path to handle path resolution - resolved = checkpoint_manager.resolve_load_path(checkpoint_dir) - - if resolved.is_twinkle_path: - # Load from twinkle checkpoint - return super().load(name=resolved.checkpoint_name, output_dir=resolved.checkpoint_dir, **kwargs) - else: - # Load from hub - return super().load(name=resolved.checkpoint_name, **kwargs) diff --git a/src/twinkle/server/tinker/model.py b/src/twinkle/server/tinker/model.py deleted file mode 100644 index 80778c36..00000000 --- a/src/twinkle/server/tinker/model.py +++ /dev/null @@ -1,659 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -""" -Tinker-compatible model management server. - -This module provides a Ray Serve deployment that manages distributed training models. -It handles: -1. Model and adapter lifecycle (create, load, unload) -2. Training operations (forward, backward, optimizer steps) -3. Checkpoint management (save/load weights) -4. Multi-user support with token-based isolation -""" -import traceback -from fastapi import FastAPI, Request -from peft import LoraConfig -from ray import serve -from ray.serve.config import RequestRouterConfig -from tinker import types -from typing import Any, Dict, Optional - -import twinkle -from twinkle import DeviceGroup, DeviceMesh -from twinkle.server.utils.adapter_manager import AdapterManagerMixin -from twinkle.server.utils.state import ServerStateProxy, get_server_state -from twinkle.server.utils.task_queue import TaskQueueConfig, TaskQueueMixin -from twinkle.server.utils.validation import get_token_from_request, verify_request_token -from twinkle.utils.logger import get_logger -from ..utils import wrap_builder_with_device_group_env -from .common.io_utils import create_checkpoint_manager, create_training_run_manager -from .common.router import StickyLoraRequestRouter - -logger = get_logger() - - -def build_model_app(model_id: str, - nproc_per_node: int, - device_group: Dict[str, Any], - device_mesh: Dict[str, Any], - deploy_options: Dict[str, Any], - use_megatron: bool = False, - adapter_config: Dict[str, Any] = {}, - queue_config: Optional[Dict[str, Any]] = {}, - **kwargs): - """Build a model management application for distributed training. - - This factory function creates a Ray Serve deployment that manages a training model - with support for multiple adapters (LoRA) and multi-user isolation. 
- - Args: - model_id: Base model identifier (e.g., "Qwen/Qwen2.5-0.5B-Instruct") - nproc_per_node: Number of processes per node for distributed training - device_group: Device group configuration dict - device_mesh: Device mesh configuration dict for tensor parallelism - deploy_options: Ray Serve deployment options - use_megatron: Whether to use Megatron backend (vs Transformers) - queue_config: Task queue configuration (rate limiting, etc.) - **kwargs: Additional model initialization arguments - - Returns: - Configured Ray Serve deployment bound with parameters - """ - app = FastAPI() - - @app.middleware('http') - async def verify_token(request: Request, call_next): - """Middleware to verify authentication token for all requests.""" - return await verify_request_token(request=request, call_next=call_next) - - @serve.deployment( - name='ModelManagement', - request_router_config=RequestRouterConfig(request_router_class=StickyLoraRequestRouter, ), - ) - @serve.ingress(app) - class ModelManagement(TaskQueueMixin, AdapterManagerMixin): - """Model management service handling training operations. - - This class manages: - - Base model and multiple adapter instances (multi-user LoRA) - - Training operations (forward, backward, optimizer steps) - - Adapter lifecycle with automatic cleanup via AdapterManagerMixin - - Per-user adapter limits and tracking - """ - - def __init__(self, - nproc_per_node: int, - device_group: Dict[str, Any], - device_mesh: Dict[str, Any], - use_megatron: bool = False, - queue_config: Optional[Dict[str, Any]] = None, - **kwargs): - """Initialize the model management service. - - Args: - nproc_per_node: Number of processes per node - device_group: Device group configuration - device_mesh: Device mesh configuration for parallelism - use_megatron: Whether to use Megatron backend - queue_config: Task queue configuration dict - **kwargs: Additional model initialization arguments - """ - self.device_group = DeviceGroup(**device_group) - twinkle.initialize( - mode='ray', nproc_per_node=nproc_per_node, groups=[self.device_group], lazy_collect=False) - if 'mesh_dim_names' in device_mesh: - self.device_mesh = DeviceMesh(**device_mesh) - else: - self.device_mesh = DeviceMesh.from_sizes(**device_mesh) - self.use_megatron = use_megatron - self.replica_id = serve.get_replica_context().replica_id.unique_id - self.max_loras = kwargs.get('max_loras', 5) - # Initialize model immediately - choose backend based on use_megatron - if use_megatron: - from .common.megatron_model import TwinkleCompatMegatronModel - self.model = TwinkleCompatMegatronModel( - model_id=model_id, - device_mesh=self.device_mesh, - remote_group=self.device_group.name, - instance_id=self.replica_id, - **kwargs) - else: - from .common.transformers_model import TwinkleCompatTransformersModel - self.model = TwinkleCompatTransformersModel( - model_id=model_id, - device_mesh=self.device_mesh, - remote_group=self.device_group.name, - instance_id=self.replica_id, - **kwargs) - self.base_model = model_id - self.state: ServerStateProxy = get_server_state() - - # Register this replica so the router can track capacity - self.state.register_replica(self.replica_id, self.max_loras) - - # Initialize task queue - self._init_task_queue(TaskQueueConfig.from_dict(queue_config)) - - self._init_adapter_manager(**adapter_config) - self.start_adapter_countdown() - - """ - This is a cache system, we must change to sticky routing - Reference docs: - 1. [Now]https://docs.ray.io/en/latest/serve/model-multiplexing.html - 2. 
https://docs.ray.io/en/latest/serve/llm/architecture/routing-policies.html - 3. https://github.com/ray-project/ray/pull/56855/changes - 4. Direct call actor instead of http or handler in server.py - """ - - @serve.multiplexed(max_num_models_per_replica=kwargs.get('max_loras', 5)) - async def _sticky_entry(self, sticky_key: str): - return sticky_key - - async def _ensure_sticky(self): - sticky_key = serve.get_multiplexed_model_id() - await self._sticky_entry(sticky_key) - - async def _on_request_start(self, request: Request) -> str: - await self._ensure_sticky() - token = get_token_from_request(request) - return token - - def __del__(self): - self.state.unregister_replica(self.replica_id) - - def _cleanup_adapter(self, adapter_name: str) -> None: - """Common adapter cleanup logic used by both manual unload and automatic expiration. - - This method handles: - 1. Clearing adapter state - 2. Removing adapter from model - 3. Unregistering from adapter manager - 4. Removing from server state - - Args: - adapter_name: Name of the adapter to clean up - """ - # Remove from model if it exists - if self.get_adapter_info(adapter_name): - # Clear adapter state - self.clear_adapter_state(adapter_name) - - self.model.remove_adapter(adapter_name) - # Unregister from adapter manager - self.unregister_adapter(adapter_name) - - # Remove from server state - self.state.unload_model(adapter_name) - - def _on_adapter_expired(self, adapter_name: str) -> None: - # Called from AdapterManagerMixin's countdown thread. - # Fail any pending tasks for this adapter/model. - self.fail_pending_tasks_for_model(adapter_name, reason='Adapter expired') - # Perform common cleanup (without token since it's automatic) - self._cleanup_adapter(adapter_name) - - @app.post('/create_model') - async def create_model(self, request: Request, body: types.CreateModelRequest) -> types.UntypedAPIFuture: - """Create a new model adapter for training. - - This endpoint: - 1. Registers the model in server state - 2. Creates a LoRA adapter with specified config - 3. Sets up processor, loss, and optimizer for the adapter - 4. Saves metadata to training run manager - - Args: - request: FastAPI request with auth token - body: CreateModelRequest with base_model and lora_config - - Returns: - UntypedAPIFuture wrapping CreateModelResponse with model_id - """ - token = await self._on_request_start(request) - - async def _create_adapter(): - model_id = None - try: - # Register a new model_id for each create_model call - model_id = self.state.register_model(body.model_dump(), token=token, replica_id=self.replica_id) - - # Create a new LoRA adapter for the model - if body.lora_config: - # TODO: support more lora config parameters, train_unembed, etc. - lora_cfg = LoraConfig(r=body.lora_config.rank, target_modules='all-linear') - - adapter_name = self.get_adapter_name(adapter_name=model_id) - - # Register adapter FIRST - self.register_adapter(adapter_name, token, session_id=body.session_id) - - # Create adapter AFTER successful registration - self.model.add_adapter_to_model(adapter_name=adapter_name, config_or_dir=lora_cfg) - - self.model.set_template('Template', adapter_name=adapter_name, model_id=self.base_model) - self.model.set_processor('InputProcessor', adapter_name=adapter_name) - self.model.set_optimizer('Adam', adapter_name=adapter_name) - - # Fresh adapter has no accumulated gradients. 
-                    self.set_adapter_state(adapter_name, 'grad_ready', False)
-
-                    training_run_manager = create_training_run_manager(token)
-                    training_run_manager.save(model_id, body)
-
-                    return types.CreateModelResponse(model_id=model_id)
-                except Exception:
-                    # Roll back the partially created adapter so no stale registration
-                    # or grad state is left behind.
-                    if model_id:
-                        adapter_name = self.get_adapter_name(adapter_name=model_id)
-                        self._cleanup_adapter(adapter_name)
-
-                    logger.error(traceback.format_exc())
-                    return types.RequestFailedResponse(
-                        error=traceback.format_exc(),
-                        category=types.RequestErrorCategory.Server,
-                    )
-
-            return await self.schedule_task(
-                _create_adapter,
-                token=token,
-                task_type='create_model',
-            )
-
-        @app.post('/get_info')
-        async def get_info(self, request: Request, body: types.GetInfoRequest) -> types.GetInfoResponse:
-            """Get information about a model.
-
-            Args:
-                request: FastAPI request with auth token
-                body: GetInfoRequest with model_id
-
-            Returns:
-                GetInfoResponse with model metadata (name, lora_rank, etc.)
-            """
-            token = await self._on_request_start(request)
-            # get_info is read-only; the token is only used to locate the caller's
-            # metadata directory.
-            training_run_manager = create_training_run_manager(token)
-            metadata = training_run_manager.get(str(body.model_id))
-            model_name = metadata.base_model if metadata else model_id
-            lora_rank = None
-            is_lora = False
-            if metadata and hasattr(metadata, 'lora_rank') and metadata.lora_rank:
-                lora_rank = metadata.lora_rank
-                is_lora = metadata.is_lora
-            return types.GetInfoResponse(
-                model_data=types.ModelData(model_name=model_name),
-                model_id=body.model_id,
-                is_lora=is_lora,
-                lora_rank=lora_rank,
-                model_name=model_name,
-            )
-
-        @app.post('/unload_model')
-        async def unload_model(self, request: Request, body: types.UnloadModelRequest) -> types.UntypedAPIFuture:
-            """Unload a model adapter from memory.
-
-            Removes the adapter and updates user adapter counts.
-
-            Args:
-                request: FastAPI request with auth token
-                body: UnloadModelRequest with model_id
-
-            Returns:
-                UntypedAPIFuture wrapping UnloadModelResponse
-            """
-            token = await self._on_request_start(request)
-
-            async def _do_unload():
-                # Only remove the adapter, not the base model
-                adapter_name = self.get_adapter_name(adapter_name=body.model_id)
-                # Use common cleanup logic
-                self._cleanup_adapter(adapter_name)
-                return types.UnloadModelResponse(model_id=body.model_id)
-
-            return await self.schedule_task(
-                _do_unload,
-                model_id=body.model_id,
-                token=token,
-                task_type='unload_model',
-            )
-
-        @app.post('/forward')
-        async def forward(self, request: Request, body: types.ForwardRequest) -> types.UntypedAPIFuture:
-            """Execute forward pass without backward pass.
-
-            Used for inference or evaluation without gradient computation.
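-
-            Illustrative client flow (endpoint names from this file; response shapes
-            abridged):
-
-                POST /forward          -> UntypedAPIFuture (poll with its request_id)
-                POST /retrieve_future  -> long-polls until the ForwardBackwardOutput is ready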
-
-            Args:
-                request: FastAPI request with auth token
-                body: ForwardRequest with input data
-
-            Returns:
-                UntypedAPIFuture wrapping ForwardBackwardOutput with loss
-            """
-            token = await self._on_request_start(request)
-
-            async def _do_forward():
-                try:
-                    adapter_name = self.get_adapter_name(adapter_name=body.model_id)
-                    self.assert_adapter_exists(adapter_name=adapter_name)
-
-                    # Touch adapter to reset inactivity counter
-                    self.touch_adapter(adapter_name)
-
-                    datum_list = body.forward_input.data
-                    loss_fn_config = body.forward_input.loss_fn_config or {}
-
-                    output = self.model.forward_only(inputs=datum_list, adapter_name=adapter_name)
-                    loss = self.model.calculate_loss(adapter_name=adapter_name, **loss_fn_config)
-                    return types.ForwardBackwardOutput(
-                        loss_fn_output_type='CrossEntropyLossReturn',
-                        loss_fn_outputs=output,
-                        metrics={'loss:sum': loss},
-                    )
-                except Exception:
-                    logger.error(traceback.format_exc())
-                    return types.RequestFailedResponse(
-                        error=traceback.format_exc(),
-                        category=types.RequestErrorCategory.Server,
-                    )
-
-            # Calculate input tokens and batch size for validation
-            datum_list = body.forward_input.data
-            input_tokens = sum(len(d.model_input.to_ints()) for d in datum_list)
-            batch_size = len(datum_list)
-            return await self.schedule_task(
-                _do_forward,
-                model_id=body.model_id,
-                token=token,
-                input_tokens=input_tokens,
-                batch_size=batch_size,
-                data_world_size=self.device_mesh.data_world_size,
-                task_type='forward',
-            )
-
-        @app.post('/forward_backward')
-        async def forward_backward(self, request: Request,
-                                   body: types.ForwardBackwardRequest) -> types.UntypedAPIFuture:
-            """Execute forward and backward pass for training.
-
-            This combines the forward pass and gradient computation. Both the Megatron
-            and Transformers backends are driven through the same unified
-            model.forward_backward call; how that call executes internally is a detail
-            of the model class.
-
-            Args:
-                request: FastAPI request with auth token
-                body: ForwardBackwardRequest with training data
-
-            Returns:
-                UntypedAPIFuture wrapping ForwardBackwardOutput with loss and metrics
-            """
-            token = await self._on_request_start(request)
-
-            async def _do_forward_backward():
-                try:
-                    adapter_name = self.get_adapter_name(adapter_name=body.model_id)
-                    self.assert_adapter_exists(adapter_name=adapter_name)
-
-                    # Touch adapter to reset inactivity counter
-                    self.touch_adapter(adapter_name)
-
-                    datum_list = body.forward_backward_input.data
-                    loss_fn = body.forward_backward_input.loss_fn
-                    loss_fn_config = body.forward_backward_input.loss_fn_config or {}
-
-                    # Unified forward_backward for both Megatron and Transformers
-                    output, loss = self.model.forward_backward(
-                        inputs=datum_list, adapter_name=adapter_name, loss_fn=loss_fn, **loss_fn_config)
-                    if loss_fn == 'importance_sampling':
-                        output_type = 'ImportanceSamplingLossReturn'
-                    else:
-                        output_type = 'CrossEntropyLossReturn'
-                    # Mark gradients as ready after a successful forward_backward.
- self.set_adapter_state(adapter_name, 'grad_ready', True) - return types.ForwardBackwardOutput( - loss_fn_output_type=output_type, - loss_fn_outputs=output, - metrics={'loss:avg': loss}, - ) - except Exception: - logger.error(traceback.format_exc()) - return types.RequestFailedResponse( - error=traceback.format_exc(), - category=types.RequestErrorCategory.Server, - ) - - # Calculate input tokens and batch size for validation - datum_list = body.forward_backward_input.data - input_tokens = sum(len(d.model_input.to_ints()) for d in datum_list) - batch_size = len(datum_list) - return await self.schedule_task( - _do_forward_backward, - model_id=body.model_id, - token=token, - input_tokens=input_tokens, - batch_size=batch_size, - data_world_size=self.device_mesh.data_world_size, - task_type='forward_backward', - ) - - @app.post('/optim_step') - async def optim_step(self, request: Request, body: types.OptimStepRequest) -> types.UntypedAPIFuture: - """Execute optimizer step to update model weights. - - Applies accumulated gradients to update adapter parameters. - - Args: - request: FastAPI request with auth token - body: OptimStepRequest with optimizer parameters - - Returns: - UntypedAPIFuture wrapping OptimStepResponse - """ - token = await self._on_request_start(request) - - async def _do_optim(): - try: - adapter_name = self.get_adapter_name(adapter_name=body.model_id) - self.assert_adapter_exists(adapter_name=adapter_name) - - # Disallow empty step (must have at least one forward_backward since last step) - if not self.get_adapter_state(adapter_name, 'grad_ready', False): - raise RuntimeError( - f'No accumulated gradients for adapter={adapter_name}; call forward_backward before optim_step' # noqa: E501 - ) - - # Touch adapter to reset inactivity counter - self.touch_adapter(adapter_name) - - self.model.step(adam_params=body.adam_params, adapter_name=adapter_name) - # Clear grad-ready after a successful step. - self.set_adapter_state(adapter_name, 'grad_ready', False) - metrics = self.model.calculate_metric(is_training=True, adapter_name=adapter_name) - return types.OptimStepResponse(metrics=metrics) - except Exception: - logger.error(traceback.format_exc()) - return types.RequestFailedResponse( - error=traceback.format_exc(), - category=types.RequestErrorCategory.Server, - ) - - return await self.schedule_task( - _do_optim, - model_id=body.model_id, - token=token, - task_type='optim_step', - ) - - @app.post('/save_weights') - async def save_weights(self, request: Request, body: types.SaveWeightsRequest) -> types.UntypedAPIFuture: - """Save model adapter weights to storage. - - Saves both model weights and optimizer state for training resumption. - Uses token-based isolation for user-specific storage. 
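-
-            Illustrative sequence (helper names match the body below; on-disk layout
-            assumed):
-
-                checkpoint_name = checkpoint_manager.get_ckpt_name(body.path)
-                save_dir = checkpoint_manager.get_save_dir(model_id=body.model_id, is_sampler=False)
-                # model.save(...) then writes weights plus optimizer state under save_dir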
- - Args: - request: FastAPI request with auth token - body: SaveWeightsRequest with path and model_id - - Returns: - UntypedAPIFuture wrapping SaveWeightsResponse with saved path - """ - token = await self._on_request_start(request) - - async def _do_save(): - try: - adapter_name = self.get_adapter_name(adapter_name=body.model_id) - self.assert_adapter_exists(adapter_name=adapter_name) - - # Touch adapter to reset inactivity counter - self.touch_adapter(adapter_name) - - checkpoint_manager = create_checkpoint_manager(token) - - # get save dir with token-based isolation - checkpoint_name = checkpoint_manager.get_ckpt_name(body.path) - save_dir = checkpoint_manager.get_save_dir(model_id=body.model_id, is_sampler=False) - - self.model.save( - name=checkpoint_name, output_dir=save_dir, adapter_name=adapter_name, save_optimizer=True) - - tinker_path = checkpoint_manager.save(body.model_id, name=checkpoint_name, is_sampler=False) - - return types.SaveWeightsResponse(path=tinker_path, type='save_weights') - except Exception: - logger.error(traceback.format_exc()) - return types.RequestFailedResponse( - error=traceback.format_exc(), - category=types.RequestErrorCategory.Server, - ) - - return await self.schedule_task( - _do_save, - model_id=body.model_id, - token=token, - task_type='save_weights', - ) - - @app.post('/save_weights_for_sampler') - async def save_weights_for_sampler(self, request: Request, - body: types.SaveWeightsForSamplerRequest) -> types.UntypedAPIFuture: - """Save/convert weights for inference use. - - Saves adapter weights without optimizer state for use with sampler. - Creates a sampling session for tracking. - - Args: - request: FastAPI request with auth token - body: SaveWeightsForSamplerRequest with model_id and path - - Returns: - UntypedAPIFuture wrapping SaveWeightsForSamplerResponseInternal - """ - token = await self._on_request_start(request) - - async def _do_save_for_sampler(): - try: - - adapter_name = self.get_adapter_name(adapter_name=body.model_id) - self.assert_adapter_exists(adapter_name=adapter_name) - - # Touch adapter to reset inactivity counter - self.touch_adapter(adapter_name) - - checkpoint_manager = create_checkpoint_manager(token) - - # get save dir with token-based isolation - checkpoint_name = checkpoint_manager.get_ckpt_name(body.path) - save_dir = checkpoint_manager.get_save_dir(model_id=body.model_id, is_sampler=True) - # NOTE: Need to save meta first to ensure only one sample weight exists - tinker_path = checkpoint_manager.save(body.model_id, name=checkpoint_name, is_sampler=True) - - logger.info(f'Saving weights to {save_dir}') - # Save weights with save_optimizer=False for sampler use - self.model.save( - name=checkpoint_name, output_dir=save_dir, adapter_name=adapter_name, save_optimizer=False) - - # Create sampling session with resolved model_path/base_model. 
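-                    # Illustrative shape of the enriched payload (keys exactly as set below):
-                    #   {**body.model_dump(), 'model_path': tinker_path,
-                    #    'base_model': metadata['base_model']}  # base_model only when recorded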
- payload = body.model_dump() - payload['model_path'] = tinker_path - metadata = self.state.get_model_metadata(body.model_id) or {} - if metadata.get('base_model'): - payload['base_model'] = metadata['base_model'] - sampling_session_id = self.state.create_sampling_session(payload) - - return types.SaveWeightsForSamplerResponseInternal( - path=None, # Disable path return for internal use - sampling_session_id=sampling_session_id) - except Exception: - logger.error(traceback.format_exc()) - return types.RequestFailedResponse( - error=traceback.format_exc(), - category=types.RequestErrorCategory.Server, - ) - - return await self.schedule_task( - _do_save_for_sampler, - model_id=body.model_id, - token=token, - task_type='save_weights_for_sampler', - ) - - @app.post('/load_weights') - async def load_weights(self, request: Request, body: types.LoadWeightsRequest) -> types.UntypedAPIFuture: - """Load model adapter weights from storage. - - Loads weights and optionally optimizer state for training resumption. - Uses token-based isolation for user-specific storage access. - - Args: - request: FastAPI request with auth token - body: LoadWeightsRequest with path and optimizer flag - - Returns: - UntypedAPIFuture wrapping LoadWeightsResponse - """ - token = await self._on_request_start(request) - - async def _do_load(): - try: - assert self.model is not None, 'Model not loaded, please load model first' - - adapter_name = self.get_adapter_name(adapter_name=body.model_id) - self.assert_adapter_exists(adapter_name=adapter_name) - - # Touch adapter to reset inactivity counter - self.touch_adapter(adapter_name) - - weight_path = body.path - load_optimizer = body.optimizer - - self.model.load( - checkpoint_dir=weight_path, - load_optimizer=load_optimizer, - adapter_name=adapter_name, - token=token) - - # Loading a checkpoint should reset step readiness. - self.set_adapter_state(adapter_name, 'grad_ready', False) - return types.LoadWeightsResponse(path=body.path, type='load_weights') - except Exception: - logger.error(traceback.format_exc()) - return types.RequestFailedResponse( - error=traceback.format_exc(), - category=types.RequestErrorCategory.Server, - ) - - return await self.schedule_task( - _do_load, - model_id=body.model_id, - token=token, - task_type='load_weights', - ) - - return ModelManagement.options(**deploy_options).bind(nproc_per_node, device_group, device_mesh, use_megatron, - queue_config, **kwargs) - - -build_model_app = wrap_builder_with_device_group_env(build_model_app) diff --git a/src/twinkle/server/tinker/sampler.py b/src/twinkle/server/tinker/sampler.py deleted file mode 100644 index 406524f3..00000000 --- a/src/twinkle/server/tinker/sampler.py +++ /dev/null @@ -1,251 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -""" -Tinker-compatible sampler (inference) server. - -This module provides a Ray Serve deployment for distributed text generation/inference. -It supports: -1. vLLM and Torch sampler backends -2. LoRA adapter loading via adapter URIs -3. Multi-user inference with rate limiting -4. 
Flexible sampling parameters -""" -import os -import traceback -from fastapi import FastAPI, Request -from ray import serve -from tinker import types -from typing import Any, Dict, Optional - -import twinkle -from twinkle import DeviceGroup, DeviceMesh -from twinkle.data_format import SamplingParams -from twinkle.server.utils.state import ServerStateProxy, get_server_state -from twinkle.server.utils.task_queue import TaskQueueConfig, TaskQueueMixin -from twinkle.server.utils.validation import get_token_from_request, verify_request_token -from twinkle.utils.logger import get_logger -from ..utils import wrap_builder_with_device_group_env -from .common.io_utils import create_checkpoint_manager - -logger = get_logger() - - -def build_sampler_app(model_id: str, - nproc_per_node: int, - device_group: Dict[str, Any], - device_mesh: Dict[str, Any], - deploy_options: Dict[str, Any], - sampler_type: str = 'vllm', - engine_args: Optional[Dict[str, Any]] = None, - queue_config: Optional[Dict[str, Any]] = None, - **kwargs): - """Build a sampler application for tinker-compatible inference. - - This factory function creates a Ray Serve deployment that manages a sampler - (inference engine) with support for LoRA adapters and rate limiting. - - Args: - model_id: Model identifier (e.g., "ms://Qwen/Qwen2.5-0.5B-Instruct") - nproc_per_node: Number of processes per node - device_group: Device group configuration dict - device_mesh: Device mesh configuration dict for parallelism - deploy_options: Ray Serve deployment options - sampler_type: Type of sampler to use ('vllm' or 'torch') - engine_args: Additional engine arguments for the sampler - queue_config: Task queue configuration dict (rps_limit, tps_limit, etc.) - **kwargs: Additional arguments passed to the sampler - - Returns: - Ray Serve deployment bound with configuration - """ - app = FastAPI() - - @app.middleware('http') - async def verify_token(request: Request, call_next): - """Middleware to verify authentication token for all requests.""" - return await verify_request_token(request=request, call_next=call_next) - - @serve.deployment(name='SamplerManagement') - @serve.ingress(app) - class SamplerManagement(TaskQueueMixin): - """Sampler management service for text generation inference. - - This class manages: - - vLLM or Torch sampler initialization and lifecycle - - Inference requests with LoRA adapter support - - Rate limiting via task queue - - Sampling parameter conversion between Tinker and Twinkle formats - """ - - def __init__(self, - nproc_per_node: int, - device_group: Dict[str, Any], - device_mesh: Dict[str, Any], - sampler_type: str = 'vllm', - engine_args: Optional[Dict[str, Any]] = None, - queue_config: Optional[Dict[str, Any]] = None, - **kwargs): - """Initialize the sampler management service. 
- - Args: - nproc_per_node: Number of processes per node - device_group: Device group configuration - device_mesh: Device mesh configuration for parallelism - sampler_type: Type of sampler ('vllm' or 'torch') - engine_args: Additional engine arguments for sampler - queue_config: Task queue configuration dict - **kwargs: Additional sampler initialization arguments - """ - self.device_group = DeviceGroup(**device_group) - twinkle.initialize( - mode='ray', nproc_per_node=nproc_per_node, groups=[self.device_group], lazy_collect=False) - if 'mesh_dim_names' in device_mesh: - self.device_mesh = DeviceMesh(**device_mesh) - else: - self.device_mesh = DeviceMesh.from_sizes(**device_mesh) - self.sampler_type = sampler_type - replica_context = serve.get_replica_context() - replica_id = replica_context.replica_id.unique_id - - # Initialize sampler based on type - if sampler_type == 'vllm': - from twinkle.sampler import vLLMSampler - sampler_kwargs = engine_args or {} - self.sampler = vLLMSampler( - model_id=model_id, - engine_args=sampler_kwargs, - device_mesh=self.device_mesh, - remote_group=self.device_group.name, - instance_id=replica_id, - **{ - k: v - for k, v in kwargs.items() if k not in ['engine_args'] - }) - else: # torch sampler - from twinkle.sampler import TorchSampler - self.sampler = TorchSampler(model_id=model_id, device_mesh=self.device_mesh, **kwargs) - self.sampler.set_template('Template', model_id=model_id) - self.state: ServerStateProxy = get_server_state() - self._init_task_queue(TaskQueueConfig.from_dict(queue_config)) - - @serve.multiplexed(max_num_models_per_replica=5) - async def _sticky_entry(self, sticky_key: str): - return sticky_key - - async def _ensure_sticky(self): - sticky_key = serve.get_multiplexed_model_id() - await self._sticky_entry(sticky_key) - - async def _on_request_start(self, request: Request) -> str: - await self._ensure_sticky() - token = get_token_from_request(request) - return token - - @app.post('/asample') - async def asample(self, request: Request, body: types.SampleRequest) -> types.UntypedAPIFuture: - """Execute text generation (inference). - - This endpoint: - 1. Extracts prompt token IDs from the request - 2. Determines adapter URI from model_path if provided - 3. Converts Tinker sampling params to Twinkle format - 4. Calls the sampler engine to generate text - 5. Converts results back to Tinker format - - Args: - request: FastAPI request with auth token - body: SampleRequest with prompt, sampling params, and adapter info - - Returns: - UntypedAPIFuture wrapping SampleResponse with generated sequences - """ - token = await self._on_request_start(request) - - async def _do_sample(): - try: - # Extract prompt token IDs from ModelInput - prompt_inputs = {'input_ids': body.prompt.to_ints()} - - # Get model_path: use body.model_path or look up from sampling session - model_path = body.model_path - if not model_path and body.sampling_session_id: - session = self.state.get_sampling_session(body.sampling_session_id) - if session: - model_path = session.get('model_path') - - # Parse and resolve adapter URI from model_path - adapter_uri = None - if model_path: - checkpoint_manager = create_checkpoint_manager(token) - adapter_name, adapter_uri = checkpoint_manager.parse_adapter_uri(model_path) - - # Validate adapter URI existence if provided - if not adapter_uri or not os.path.exists(adapter_uri): - return types.RequestFailedResponse( - error=f'Adapter URI {model_path} does not exist. 
Please check the model_path.',
-                            category=types.RequestErrorCategory.User,
-                        )
-
-                    # Convert tinker SamplingParams to twinkle SamplingParams if needed
-                    sampling_params = None
-                    if body.sampling_params:
-                        sampling_params = SamplingParams(
-                            max_tokens=body.sampling_params.max_tokens or 256,
-                            temperature=body.sampling_params.temperature or 1.0,
-                            top_p=body.sampling_params.top_p,
-                            top_k=body.sampling_params.top_k,
-                            stop=body.sampling_params.stop,
-                        )
-
-                    # Note: some backends can return None entries in logprobs; the
-                    # conversion below drops the field in that case.
-                    response = self.sampler.sample(
-                        inputs=[prompt_inputs] * body.num_samples,  # batch all samples in one call
-                        sampling_params=sampling_params,
-                        adapter_path=adapter_uri,
-                        # adapter_name=adapter_name,
-                    )
-
-                    # Convert twinkle SampleResponse to tinker types.SampleResponse
-                    tinker_sequences = []
-                    for seq in response.sequences:
-                        logprobs = None
-                        if seq.logprobs is not None:
-                            if any(lp is None for lp in seq.logprobs):
-                                # Fix: backend can emit None logprobs for some tokens, which triggers
-                                # pydantic "Input should be a valid number" errors in SampleResponse.
-                                # We drop the field to keep the response valid.
-                                logprobs = None
-                            else:
-                                logprobs = list(seq.logprobs)
-                        tinker_sequences.append(
-                            types.SampledSequence(
-                                stop_reason=seq.stop_reason,
-                                tokens=list(seq.tokens),
-                                logprobs=logprobs,
-                            ))
-                    return types.SampleResponse(
-                        sequences=tinker_sequences,
-                        prompt_logprobs=response.prompt_logprobs,
-                        topk_prompt_logprobs=response.topk_prompt_logprobs,
-                    )
-                except Exception:
-                    logger.error(traceback.format_exc())
-                    return types.RequestFailedResponse(
-                        error=traceback.format_exc(),
-                        category=types.RequestErrorCategory.Server,
-                    )
-
-            # Calculate input tokens for rate limiting
-            input_tokens = len(body.prompt.to_ints())
-            return await self.schedule_task(
-                _do_sample,
-                token=token,
-                input_tokens=input_tokens,
-                task_type='sample',
-            )
-
-    return SamplerManagement.options(**deploy_options).bind(nproc_per_node, device_group, device_mesh, sampler_type,
-                                                            engine_args, queue_config, **kwargs)
-
-
-build_sampler_app = wrap_builder_with_device_group_env(build_sampler_app)
diff --git a/src/twinkle/server/tinker/server.py b/src/twinkle/server/tinker/server.py
deleted file mode 100644
index 81543c58..00000000
--- a/src/twinkle/server/tinker/server.py
+++ /dev/null
@@ -1,613 +0,0 @@
-# Copyright (c) ModelScope Contributors. All rights reserved.
-"""
-Tinker-compatible server implementation.
-
-This module provides a Ray Serve-based server that implements the Tinker API for distributed
-training and inference. It acts as a routing layer that:
-1. Handles client requests and validates tokens
-2. Manages training runs and checkpoints with user isolation
-3. 
Proxies requests to appropriate model or sampler deployments based on base_model -""" - -from __future__ import annotations - -import asyncio -import os -from fastapi import FastAPI, HTTPException, Request, Response -from ray import serve -from tinker import types -from typing import Any, Dict, List, Optional - -from twinkle.hub import HubOperation -from twinkle.server.utils.state import get_server_state -from twinkle.server.utils.task_queue import QueueState -from twinkle.server.utils.validation import get_token_from_request, verify_request_token -from twinkle.utils.logger import get_logger -from .common.io_utils import create_checkpoint_manager, create_training_run_manager -from .proxy import ServiceProxy - -logger = get_logger() - - -def build_server_app(deploy_options: dict[str, Any], - supported_models: list[types.SupportedModel] | None = None, - server_config: dict[str, Any] = {}, - http_options: dict[str, Any] | None = None, - **kwargs): - """Build and configure the Tinker-compatible server application. - - This factory function creates a FastAPI application with Ray Serve deployment - that handles routing, authentication, and proxying for training and inference. - - Args: - deploy_options: Ray Serve deployment configuration (num_replicas, etc.) - supported_models: List of supported base models for validation - server_config: Server configuration options (per_token_adapter_limit, etc.) - **kwargs: Additional keyword arguments (route_prefix, etc.) - - Returns: - Configured Ray Serve deployment bound with options - """ - app = FastAPI() - - @app.middleware('http') - async def verify_token(request: Request, call_next): - """Middleware to verify authentication token for all requests.""" - return await verify_request_token(request=request, call_next=call_next) - - @serve.deployment(name='TinkerCompatServer') - @serve.ingress(app) - class TinkerCompatServer: - """Main server class handling Tinker API endpoints and request routing. - - This class manages: - - Server state and session management - - Request validation and authentication - - Proxying to model/sampler deployments - - Training run and checkpoint CRUD operations - """ - - def __init__(self, - supported_models: list[types.SupportedModel] | None = None, - server_config: dict[str, Any] = {}, - http_options: dict[str, Any] | None = None, - **kwargs) -> None: - """Initialize the Tinker-compatible server. - - Args: - supported_models: List of supported base models for validation - server_config: Server configuration options - http_options: HTTP server options (host, port) for internal proxy routing - **kwargs: Additional configuration (route_prefix, etc.) - """ - self.state = get_server_state(**server_config) - self.route_prefix = kwargs.get('route_prefix', '/api/v1') - self.http_options = http_options or {} - - # Initialize service proxy for routing requests to model/sampler services - self.proxy = ServiceProxy(http_options=http_options, route_prefix=self.route_prefix) - - self.supported_models = self.normalize_models(supported_models) or [ - types.SupportedModel(model_name='Qwen/Qwen3-30B-A3B-Instruct-2507'), - ] - # Lock for ModelScope config file operations (login writes, get_user_info reads) - self._modelscope_config_lock = asyncio.Lock() - - def normalize_models(self, supported_models): - # Normalize supported_models to objects; passing raw dicts can trigger internal errors - # when creating LoRA training clients via the tinker API. 
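-            # e.g. each of the following inputs normalizes to the same object:
-            #   'Qwen/Qwen3-30B-A3B-Instruct-2507'
-            #   {'model_name': 'Qwen/Qwen3-30B-A3B-Instruct-2507'}
-            #   types.SupportedModel(model_name='Qwen/Qwen3-30B-A3B-Instruct-2507')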
- if not supported_models: - return [] - normalized = [] - for item in supported_models: - if isinstance(item, types.SupportedModel): - normalized.append(item) - elif isinstance(item, dict): - normalized.append(types.SupportedModel(**item)) - elif isinstance(item, str): - normalized.append(types.SupportedModel(model_name=item)) - return normalized - - def _validate_base_model(self, base_model: str) -> None: - """Validate that base_model is in supported_models list. - - Args: - base_model: The base model name to validate - - Raises: - HTTPException: If base_model is not supported - """ - supported_model_names = [m.model_name for m in self.supported_models] - if base_model not in supported_model_names: - raise HTTPException( - status_code=400, - detail=f"Base model '{base_model}' is not supported. " - f"Supported models: {', '.join(supported_model_names)}") - - def _get_base_model(self, model_id: str) -> str: - """Get base_model for a model_id from state metadata. - - Args: - model_id: The model identifier to lookup - - Returns: - The base model name - - Raises: - HTTPException: If model_id not found in state - """ - metadata = self.state.get_model_metadata(model_id) - if metadata and metadata.get('base_model'): - return metadata['base_model'] - raise HTTPException(status_code=404, detail=f'Model {model_id} not found') - - # --- Endpoints --------------------------------------------------------- - - @app.get('/healthz') - async def healthz(self, request: Request) -> types.HealthResponse: - """Health check endpoint. - - Returns: - HealthResponse indicating server is operational - """ - return types.HealthResponse(status='ok') - - @app.get('/get_server_capabilities') - async def get_server_capabilities(self, request: Request) -> types.GetServerCapabilitiesResponse: - """Get server capabilities including supported models. - - Returns: - GetServerCapabilitiesResponse with list of supported models - """ - return types.GetServerCapabilitiesResponse(supported_models=self.supported_models) - - @app.post('/telemetry') - async def telemetry(self, request: Request, body: types.TelemetrySendRequest) -> types.TelemetryResponse: - """Accept telemetry data from clients. - - Note: Telemetry is accepted but not persisted; this endpoint is intentionally lightweight. - - Returns: - TelemetryResponse indicating data was accepted - """ - return types.TelemetryResponse(status='accepted') - - @app.post('/create_session') - async def create_session(self, request: Request, - body: types.CreateSessionRequest) -> types.CreateSessionResponse: - """Create a new training session. - - Args: - body: Session creation parameters - - Returns: - CreateSessionResponse with new session_id - """ - session_id = self.state.create_session(body.model_dump()) - return types.CreateSessionResponse(session_id=session_id) - - @app.post('/session_heartbeat') - async def session_heartbeat(self, request: Request, - body: types.SessionHeartbeatRequest) -> types.SessionHeartbeatResponse: - """Keep a session alive via heartbeat. 
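-
-            Example body (illustrative): {"session_id": "<id returned by /create_session>"}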
- - Args: - body: Heartbeat request with session_id - - Returns: - SessionHeartbeatResponse if session is alive - - Raises: - HTTPException: If session not found - """ - alive = self.state.touch_session(body.session_id) - if not alive: - raise HTTPException(status_code=404, detail='Unknown session') - return types.SessionHeartbeatResponse() - - @app.post('/create_sampling_session') - async def create_sampling_session( - self, request: Request, - body: types.CreateSamplingSessionRequest) -> types.CreateSamplingSessionResponse: - """Create a new sampling (inference) session. - - Args: - body: Sampling session creation parameters - - Returns: - CreateSamplingSessionResponse with new sampling_session_id - """ - sampling_session_id = self.state.create_sampling_session(body.model_dump()) - return types.CreateSamplingSessionResponse(sampling_session_id=sampling_session_id) - - @app.post('/retrieve_future') - async def retrieve_future(self, request: Request, body: types.FutureRetrieveRequest) -> Any: - """Retrieve the result of an async task with long polling. - - Server waits up to 30s for task completion instead of immediately returning try_again. - This reduces client polling frequency from ~100 req/s to ~1 req/30s. - """ - request_id = body.request_id - max_wait = float(os.environ.get('TWINKLE_LONG_POLL_TIMEOUT', '30')) - poll_interval = float(os.environ.get('TWINKLE_POLL_INTERVAL', '0.5')) - start = asyncio.get_event_loop().time() - - # Long poll: wait for task completion or timeout - while True: - record = self.state.get_future(request_id) - - if record is None: - return {'type': 'try_again'} - - status = record.get('status') - - # Task finished, return immediately - if status not in ('pending', 'queued', 'running', 'rate_limited'): - break - - # Timeout, let client retry - if asyncio.get_event_loop().time() - start >= max_wait: - response_data = {'type': 'try_again'} - if queue_state := record.get('queue_state'): - response_data['queue_state'] = queue_state - if queue_state_reason := record.get('queue_state_reason'): - response_data['queue_state_reason'] = queue_state_reason - return response_data - - await asyncio.sleep(poll_interval) - - # Handle final result - record = self.state.get_future(request_id) - if not record: - return {'type': 'try_again'} - - status = record.get('status') - - if status == 'rate_limited': - return { - 'type': 'try_again', - 'queue_state': QueueState.PAUSED_RATE_LIMIT.value, - 'queue_state_reason': record.get('reason', 'Rate limit exceeded') - } - - if status == 'failed': - result = record.get('result', {}) - return {'error': result.get('error', 'Unknown error'), 'category': result.get('category', 'Server')} - - result = record.get('result') - if result is None: - raise HTTPException(status_code=500, detail='Task completed but no result found') - - if hasattr(result, 'model_dump'): - return result.model_dump() - return result - - # --- Restful Endpoints ------------------------------------------ - - @app.get('/training_runs') - async def get_training_runs(self, - request: Request, - limit: int = 20, - offset: int = 0) -> types.TrainingRunsResponse: - """ - List training runs for the current user. - - Uses token-based isolation to only show runs owned by the requesting user. 
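-
-        Example (illustrative): GET /training_runs?limit=20&offset=0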
- - Args: - request: FastAPI request with token in state - limit: Maximum number of results - offset: Pagination offset - - Returns: - TrainingRunsResponse with user's training runs - """ - token = get_token_from_request(request) - training_run_manager = create_training_run_manager(token) - return training_run_manager.list_runs(limit=limit, offset=offset) - - @app.get('/training_runs/{run_id}') - async def get_training_run(self, request: Request, run_id: str) -> types.TrainingRun: - """ - Get a specific training run. - - Uses token-based isolation to verify user owns the run. - - Args: - request: FastAPI request with token in state - run_id: The training run identifier - - Returns: - TrainingRun details - - Raises: - HTTPException 404 if run not found in user's token directory - """ - token = get_token_from_request(request) - training_run_manager = create_training_run_manager(token) - run = training_run_manager.get(run_id) - if not run: - raise HTTPException(status_code=404, detail=f'Training run {run_id} not found') - return run - - @app.get('/training_runs/{run_id}/checkpoints') - async def get_run_checkpoints(self, request: Request, run_id: str) -> types.CheckpointsListResponse: - """ - List checkpoints for a training run. - - Uses token-based isolation to verify user owns the run. - - Args: - request: FastAPI request with token in state - run_id: The training run identifier - - Returns: - CheckpointsListResponse with list of checkpoints - - Raises: - HTTPException 404 if run not found in user's token directory - """ - token = get_token_from_request(request) - checkpoint_manager = create_checkpoint_manager(token) - response = checkpoint_manager.list_checkpoints(run_id) - if not response: - raise HTTPException(status_code=404, detail=f'Training run {run_id} not found') - return response - - @app.delete('/training_runs/{run_id}/checkpoints/{checkpoint_id:path}') - async def delete_run_checkpoint(self, request: Request, run_id: str, checkpoint_id: str) -> Any: - """ - Delete a checkpoint from a training run. - - Uses token-based isolation to verify user owns the checkpoint. - - Args: - request: FastAPI request with token in state - run_id: The training run identifier - checkpoint_id: The checkpoint identifier (path) - - Returns: - None (200 OK) if successful - - Raises: - HTTPException 404 if checkpoint not found in user's token directory - """ - token = get_token_from_request(request) - checkpoint_manager = create_checkpoint_manager(token) - success = checkpoint_manager.delete(run_id, checkpoint_id) - if not success: - raise HTTPException(status_code=404, detail=f'Checkpoint {checkpoint_id} not found for run {run_id}') - return None - - @app.post('/weights_info') - async def weights_info(self, request: Request, body: dict[str, Any]) -> types.WeightsInfoResponse: - """ - Get weights information from a tinker path. - - Uses token-based isolation to verify user owns the weights. 
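-
-        Example body (illustrative): {"tinker_path": "<path previously returned by save_weights>"}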
- - Args: - request: FastAPI request with token in state - body: Dict with 'tinker_path' key - - Returns: - WeightsInfoResponse with weight details - - Raises: - HTTPException 404 if weights not found in user's token directory - """ - token = get_token_from_request(request) - checkpoint_manager = create_checkpoint_manager(token) - tinker_path = body.get('tinker_path') - response = checkpoint_manager.get_weights_info(tinker_path) - if not response: - raise HTTPException(status_code=404, detail=f'Weights at {tinker_path} not found') - return response - - @app.post('/training_runs/{run_id}/checkpoints/{checkpoint_id:path}/publish') - async def publish_checkpoint(self, request: Request, run_id: str, checkpoint_id: str) -> Response: - """ - Publish a checkpoint to the hub. - - This endpoint uploads a checkpoint to a hub repository. The hub_model_id - is automatically generated from the checkpoint content and user token. - The upload is performed asynchronously by default. - - Args: - request: FastAPI request object (contains token in state) - run_id: The training run identifier - checkpoint_id: The checkpoint identifier (can include path like weights/checkpoint_name) - - Returns: - Response with 204 No Content status - - Raises: - HTTPException 404 if checkpoint not found or access denied - """ - token = get_token_from_request(request) - - training_run_manager = create_training_run_manager(token) - checkpoint_manager = create_checkpoint_manager(token) - - # Check ownership and get training run info - run = training_run_manager.get(run_id) - if not run: - raise HTTPException(status_code=404, detail=f'Training run {run_id} not found or access denied') - - # Get checkpoint with token-based path - checkpoint = checkpoint_manager.get(run_id, checkpoint_id) - if not checkpoint: - raise HTTPException(status_code=404, detail=f'Checkpoint {checkpoint_id} not found') - - # Get the filesystem path for the checkpoint - checkpoint_dir = str(checkpoint_manager.get_ckpt_dir(run_id, checkpoint_id)) - - # Generate hub_model_id from checkpoint content and user token - # Format: {username}/{run_id}_{checkpoint_name} - # Use lock to prevent race conditions when multiple requests access ModelScope config file - async with self._modelscope_config_lock: - try: - from modelscope.hub.api import HubApi, ModelScopeConfig - hub_api = HubApi(token=token) - hub_api.login() # Save user info to local - username = ModelScopeConfig.get_user_info()[0] - except Exception as e: - logger.error(f'Failed to get username from ModelScope: {e}') - raise HTTPException( - status_code=401, - detail='Failed to get username from ModelScope. Please ensure your token is valid.') - - # Extract checkpoint name from checkpoint_id (e.g., "weights/step-8" -> "step-8") - checkpoint_name = checkpoint_id.split('/')[-1] - hub_model_id = f'{username}/{run_id}_{checkpoint_name}' - - # Upload to hub asynchronously with default async_upload=True - HubOperation.async_push_to_hub(repo_id=hub_model_id, folder_path=checkpoint_dir, token=token, private=True) - - # Return 204 No Content (successful with no response body) - return Response(status_code=204) - - # --- Proxy Endpoints --------------------------------------------------------- - - # --- Model Proxy Endpoints ---------------------------------------- - - @app.post('/create_model') - async def create_model(self, request: Request, body: types.CreateModelRequest) -> Any: - """Create a new model (adapter) for training. 
- - Args: - body: Model creation request with base_model and config - - Returns: - Proxied response from model service - """ - self._validate_base_model(body.base_model) - return await self.proxy.proxy_to_model(request, 'create_model', body.base_model) - - @app.post('/get_info') - async def get_info(self, request: Request, body: types.GetInfoRequest) -> Any: - """Get information about a model. - - Args: - body: Info request with model_id - - Returns: - Proxied response from model service - """ - return await self.proxy.proxy_to_model(request, 'get_info', self._get_base_model(body.model_id)) - - @app.post('/unload_model') - async def unload_model(self, request: Request, body: types.UnloadModelRequest) -> Any: - """Unload a model adapter from memory. - - Args: - body: Unload request with model_id - - Returns: - Proxied response from model service - """ - return await self.proxy.proxy_to_model(request, 'unload_model', self._get_base_model(body.model_id)) - - @app.post('/forward') - async def forward(self, request: Request, body: types.ForwardRequest) -> Any: - """Execute forward pass without backward. - - Args: - body: Forward request with inputs - - Returns: - Proxied response from model service - """ - return await self.proxy.proxy_to_model(request, 'forward', self._get_base_model(body.model_id)) - - @app.post('/forward_backward') - async def forward_backward(self, request: Request, body: types.ForwardBackwardRequest) -> Any: - """Execute forward and backward pass for training. - - Args: - body: Forward-backward request with inputs - - Returns: - Proxied response from model service - """ - return await self.proxy.proxy_to_model(request, 'forward_backward', self._get_base_model(body.model_id)) - - @app.post('/optim_step') - async def optim_step(self, request: Request, body: types.OptimStepRequest) -> Any: - """Execute optimizer step to update model weights. - - Args: - body: Optimizer step request with parameters - - Returns: - Proxied response from model service - """ - return await self.proxy.proxy_to_model(request, 'optim_step', self._get_base_model(body.model_id)) - - @app.post('/save_weights') - async def save_weights(self, request: Request, body: types.SaveWeightsRequest) -> Any: - """Save model weights to storage. - - Args: - body: Save weights request with path - - Returns: - Proxied response from model service - """ - return await self.proxy.proxy_to_model(request, 'save_weights', self._get_base_model(body.model_id)) - - @app.post('/load_weights') - async def load_weights(self, request: Request, body: types.LoadWeightsRequest) -> Any: - """Load model weights from storage. - - Args: - body: Load weights request with path - - Returns: - Proxied response from model service - """ - return await self.proxy.proxy_to_model(request, 'load_weights', self._get_base_model(body.model_id)) - - # --- Sampler Proxy Endpoints ---------------------------------------- - - @app.post('/asample') - async def asample(self, request: Request, body: types.SampleRequest) -> Any: - """Execute text generation (inference). - - Proxies the request to the sampler service based on base_model. - The sampler handles model_path resolution from sampling session. 
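-
-        base_model resolution, as implemented below: body.base_model when provided,
-        otherwise the base_model recorded on the sampling session.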
- - Args: - body: Sample request with prompt and sampling parameters - - Returns: - Proxied response from sampler service - """ - base_model = body.base_model - - # If base_model not provided, look up from sampling session - if not base_model and body.sampling_session_id: - session = self.state.get_sampling_session(body.sampling_session_id) - if session: - base_model = session.get('base_model') - - return await self.proxy.proxy_to_sampler(request, 'asample', base_model) - - @app.post('/save_weights_for_sampler') - async def save_weights_for_sampler(self, request: Request, body: types.SaveWeightsForSamplerRequest) -> Any: - """Save/convert weights for inference use. - - This endpoint proxies to the model service to save weights for sampler. - - Args: - body: Save weights request with model_id - - Returns: - Proxied response from model service - """ - # Proxy to model service for save_weights_for_sampler - base_model = self._get_base_model(body.model_id) - return await self.proxy.proxy_to_model(request, 'save_weights_for_sampler', base_model) - - return TinkerCompatServer.options(**deploy_options).bind( - supported_models=supported_models, server_config=server_config, http_options=http_options, **kwargs) diff --git a/src/twinkle/server/twinkle/__init__.py b/src/twinkle/server/twinkle/__init__.py deleted file mode 100644 index 7371b1d7..00000000 --- a/src/twinkle/server/twinkle/__init__.py +++ /dev/null @@ -1,20 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -import sys -from typing import TYPE_CHECKING - -from twinkle.utils.import_utils import _LazyModule - -_import_structure = { - 'model': ['build_model_app'], - 'processor': ['build_processor_app'], - 'sampler': ['build_sampler_app'], - 'server': ['build_server_app'], -} - -if TYPE_CHECKING: - from .model import build_model_app - from .processor import build_processor_app - from .sampler import build_sampler_app - from .server import build_server_app -else: - sys.modules[__name__] = _LazyModule(__name__, __file__, _import_structure, module_spec=__spec__) diff --git a/src/twinkle/server/twinkle/common/transformers_model.py b/src/twinkle/server/twinkle/common/transformers_model.py deleted file mode 100644 index c67a0a28..00000000 --- a/src/twinkle/server/twinkle/common/transformers_model.py +++ /dev/null @@ -1,41 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. 
-import numpy as np -import torch -from collections.abc import Mapping -from typing import Any, List, Union - -from twinkle import remote_class, remote_function -from twinkle.data_format import InputFeature, Trajectory -from twinkle.model import MultiLoraTransformersModel - - -@remote_class() -class TwinkleCompatTransformersModel(MultiLoraTransformersModel): - - @staticmethod - def _to_cpu_safe_output(obj: Any) -> Any: - """Convert nested outputs into CPU-safe Python objects for HTTP transport.""" - from twinkle.utils import torch_util - - if isinstance(obj, torch.Tensor): - tensor = torch_util.to_local_tensor(obj).detach().cpu() - if tensor.numel() == 1: - return tensor.item() - return tensor.tolist() - if isinstance(obj, np.ndarray): - if obj.size == 1: - return obj.item() - return obj.tolist() - if isinstance(obj, np.generic): - return obj.item() - if isinstance(obj, Mapping): - return {key: TwinkleCompatTransformersModel._to_cpu_safe_output(value) for key, value in obj.items()} - if isinstance(obj, (list, tuple)): - return [TwinkleCompatTransformersModel._to_cpu_safe_output(value) for value in obj] - return obj - - @remote_function(dispatch='slice_dp', collect='mean') - def forward_backward(self, *, inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]], - **kwargs): - output = super().forward_backward(inputs=inputs, **kwargs) - return self._to_cpu_safe_output(output) diff --git a/src/twinkle/server/twinkle/model.py b/src/twinkle/server/twinkle/model.py deleted file mode 100644 index 4bf4bf4b..00000000 --- a/src/twinkle/server/twinkle/model.py +++ /dev/null @@ -1,584 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -import os -from fastapi import FastAPI, Request -from peft import LoraConfig -from pydantic import BaseModel -from ray import serve -from typing import Any, Dict, Optional - -import twinkle -from twinkle import DeviceGroup, DeviceMesh -from twinkle.data_format import InputFeature, Trajectory -from twinkle.server.utils.adapter_manager import AdapterManagerMixin -from twinkle.server.utils.state import ServerStateProxy, get_server_state -from twinkle.server.utils.validation import verify_request_token -from twinkle.utils.logger import get_logger -from .common.io_utils import CreateModelRequest -from .common.io_utils import LoraConfig as IoLoraConfig -from .common.io_utils import create_checkpoint_manager, create_training_run_manager -from .common.serialize import deserialize_object - -logger = get_logger() - - -class CreateRequest(BaseModel): - - class Config: - extra = 'allow' - - -class ForwardRequest(BaseModel): - inputs: Any - adapter_name: str - - class Config: - extra = 'allow' - - -class ForwardOnlyRequest(BaseModel): - inputs: Any - adapter_name: Optional[str] = None - - class Config: - extra = 'allow' - - -class AdapterRequest(BaseModel): - adapter_name: str - - class Config: - extra = 'allow' - - -class SetLossRequest(BaseModel): - loss_cls: str - adapter_name: str - - class Config: - extra = 'allow' - - -class SetOptimizerRequest(BaseModel): - optimizer_cls: str - adapter_name: str - - class Config: - extra = 'allow' - - -class SetLrSchedulerRequest(BaseModel): - scheduler_cls: str - adapter_name: str - - class Config: - extra = 'allow' - - -class SaveRequest(BaseModel): - adapter_name: str - save_optimizer: bool = False - name: Optional[str] = None - - class Config: - extra = 'allow' - - -class UploadToHubRequest(BaseModel): - checkpoint_dir: str - hub_model_id: str - hub_token: Optional[str] = None - async_upload: bool 
= True - - class Config: - extra = 'allow' - - -class LoadRequest(BaseModel): - adapter_name: str - load_optimizer: bool = False - name: str - - class Config: - extra = 'allow' - - -class AddAdapterRequest(BaseModel): - adapter_name: str - config: str - - class Config: - extra = 'allow' - - -class SetTemplateRequest(BaseModel): - template_cls: str - adapter_name: str - - class Config: - extra = 'allow' - - -class SetProcessorRequest(BaseModel): - processor_cls: str - adapter_name: str - - class Config: - extra = 'allow' - - -class HeartbeatRequest(BaseModel): - adapter_name: str - - -class CalculateMetricRequest(BaseModel): - adapter_name: str - is_training: bool = True - - class Config: - extra = 'allow' - - -class GetStateDictRequest(BaseModel): - adapter_name: str - - class Config: - extra = 'allow' - - -def build_model_app(model_id: str, - nproc_per_node: int, - device_group: Dict[str, Any], - device_mesh: Dict[str, Any], - deploy_options: Dict[str, Any], - use_megatron: bool = False, - adapter_config: Dict[str, Any] = {}, - **kwargs): - app = FastAPI() - - @app.middleware('http') - async def verify_token(request: Request, call_next): - return await verify_request_token(request=request, call_next=call_next) - - @serve.deployment(name='ModelManagement') - @serve.ingress(app) - class ModelManagement(AdapterManagerMixin): - - def __init__(self, nproc_per_node: int, device_group: Dict[str, Any], device_mesh: Dict[str, Any]): - self.device_group = DeviceGroup(**device_group) - twinkle.initialize( - mode='ray', nproc_per_node=nproc_per_node, groups=[self.device_group], lazy_collect=False) - if 'mesh_dim_names' in device_mesh: - self.device_mesh = DeviceMesh(**device_mesh) - else: - self.device_mesh = DeviceMesh.from_sizes(**device_mesh) - replica_context = serve.get_replica_context() - replica_id = replica_context.replica_id.unique_id - if use_megatron: - from twinkle.model import MultiLoraMegatronModel - self.model = MultiLoraMegatronModel( - model_id=model_id, - device_mesh=self.device_mesh, - remote_group=self.device_group.name, - instance_id=replica_id, - **kwargs) - else: - from .common.transformers_model import TwinkleCompatTransformersModel - self.model = TwinkleCompatTransformersModel( - model_id=model_id, - device_mesh=self.device_mesh, - remote_group=self.device_group.name, - instance_id=replica_id, - **kwargs) - - # Initialize state before adapter manager (mixin needs self.state) - self.state: ServerStateProxy = get_server_state() - - # Initialize adapter manager from mixin - self._init_adapter_manager(**adapter_config) - self.start_adapter_countdown() - - def _on_adapter_expired(self, adapter_name: str) -> None: - """Handle adapter expiration by removing it from the model. - - This method is called automatically by AdapterManagerMixin when - an adapter exceeds its timeout or TTL. - - Args: - adapter_name: Name of the expired adapter to remove. 
- """ - # Remove from model if it exists - if self.get_adapter_info(adapter_name): - # Clear adapter state - self.clear_adapter_state(adapter_name) - # Unregister from adapter manager - self.unregister_adapter(adapter_name) - - # Remove from server state - self.state.unload_model(adapter_name) - # Remove adapter from model - self.model.remove_adapter(adapter_name) - - @app.post('/create') - def create(self, request: Request, body: CreateRequest): - return {'status': 'ok'} - - @staticmethod - def get_adapter_name(request: Request, adapter_name: Optional[str]) -> Optional[str]: - if adapter_name is None or adapter_name == '': - return None - return request.state.request_id + '-' + adapter_name - - @app.post('/forward') - def forward(self, request: Request, body: ForwardRequest): - adapter_name = self.get_adapter_name(request, adapter_name=body.adapter_name) - self.assert_adapter_exists(adapter_name=adapter_name) - extra_kwargs = body.model_extra or {} - inputs = body.inputs - if isinstance(inputs, list): - _input = inputs[0] - if 'input_ids' in _input: - inputs = [InputFeature(**_input) for _input in inputs] - else: - inputs = [Trajectory(**_input) for _input in inputs] - else: - assert isinstance(inputs, dict) - inputs = InputFeature(**inputs) if 'input_ids' in inputs else Trajectory(**inputs) - ret = self.model.forward(inputs=inputs, adapter_name=adapter_name, **extra_kwargs) - return {'result': ret} - - @app.post('/forward_only') - def forward_only(self, request: Request, body: ForwardOnlyRequest): - adapter_name = self.get_adapter_name(request, adapter_name=body.adapter_name) - self.assert_adapter_exists(adapter_name=adapter_name) - extra_kwargs = body.model_extra or {} - inputs = body.inputs - if isinstance(inputs, list): - _input = inputs[0] - if 'input_ids' in _input: - inputs = [InputFeature(**_input) for _input in inputs] - else: - inputs = [Trajectory(**_input) for _input in inputs] - else: - assert isinstance(inputs, dict) - inputs = InputFeature(**inputs) if 'input_ids' in inputs else Trajectory(**inputs) - ret = self.model.forward_only(inputs=inputs, adapter_name=adapter_name, **extra_kwargs) - return {'result': ret} - - @app.post('/calculate_loss') - def calculate_loss(self, request: Request, body: AdapterRequest): - adapter_name = self.get_adapter_name(request, adapter_name=body.adapter_name) - self.assert_adapter_exists(adapter_name=adapter_name) - extra_kwargs = body.model_extra or {} - ret = self.model.calculate_loss(adapter_name=adapter_name, **extra_kwargs) - return {'result': ret} - - @app.post('/backward') - def backward(self, request: Request, body: AdapterRequest): - adapter_name = self.get_adapter_name(request, adapter_name=body.adapter_name) - self.assert_adapter_exists(adapter_name=adapter_name) - extra_kwargs = body.model_extra or {} - ret = self.model.backward(adapter_name=adapter_name, **extra_kwargs) - return {'result': ret} - - @app.post('/forward_backward') - def forward_backward(self, request: Request, body: ForwardRequest): - adapter_name = self.get_adapter_name(request, adapter_name=body.adapter_name) - self.assert_adapter_exists(adapter_name=adapter_name) - extra_kwargs = body.model_extra or {} - inputs = body.inputs - if isinstance(inputs, list): - _input = inputs[0] - if 'input_ids' in _input: - inputs = [InputFeature(**_input) for _input in inputs] - else: - inputs = [Trajectory(**_input) for _input in inputs] - else: - assert isinstance(inputs, dict) - inputs = InputFeature(**inputs) if 'input_ids' in inputs else Trajectory(**inputs) - ret = 
self.model.forward_backward(inputs=inputs, adapter_name=adapter_name, **extra_kwargs) - return {'result': ret} - - @app.post('/get_train_configs') - def get_train_configs(self, request: Request, body: AdapterRequest): - adapter_name = self.get_adapter_name(request, adapter_name=body.adapter_name) - self.assert_adapter_exists(adapter_name=adapter_name) - extra_kwargs = body.model_extra or {} - ret = self.model.get_train_configs(adapter_name=adapter_name, **extra_kwargs) - return {'result': ret} - - @app.post('/clip_grad_norm') - def clip_grad_norm(self, request: Request, body: AdapterRequest): - adapter_name = self.get_adapter_name(request, adapter_name=body.adapter_name) - self.assert_adapter_exists(adapter_name=adapter_name) - extra_kwargs = body.model_extra or {} - ret = self.model.clip_grad_norm(adapter_name=adapter_name, **extra_kwargs) - return {'result': str(ret)} - - @app.post('/step') - def step(self, request: Request, body: AdapterRequest): - adapter_name = self.get_adapter_name(request, adapter_name=body.adapter_name) - self.assert_adapter_exists(adapter_name=adapter_name) - extra_kwargs = body.model_extra or {} - ret = self.model.step(adapter_name=adapter_name, **extra_kwargs) - return {'result': ret} - - @app.post('/zero_grad') - def zero_grad(self, request: Request, body: AdapterRequest): - adapter_name = self.get_adapter_name(request, adapter_name=body.adapter_name) - self.assert_adapter_exists(adapter_name=adapter_name) - extra_kwargs = body.model_extra or {} - ret = self.model.zero_grad(adapter_name=adapter_name, **extra_kwargs) - return {'result': ret} - - @app.post('/lr_step') - def lr_step(self, request: Request, body: AdapterRequest): - adapter_name = self.get_adapter_name(request, adapter_name=body.adapter_name) - self.assert_adapter_exists(adapter_name=adapter_name) - extra_kwargs = body.model_extra or {} - ret = self.model.lr_step(adapter_name=adapter_name, **extra_kwargs) - return {'result': ret} - - @app.post('/set_loss') - def set_loss(self, request: Request, body: SetLossRequest): - adapter_name = self.get_adapter_name(request, adapter_name=body.adapter_name) - self.assert_adapter_exists(adapter_name=adapter_name) - extra_kwargs = body.model_extra or {} - ret = self.model.set_loss(body.loss_cls, adapter_name=adapter_name, **extra_kwargs) - return {'result': ret} - - @app.post('/set_optimizer') - def set_optimizer(self, request: Request, body: SetOptimizerRequest): - adapter_name = self.get_adapter_name(request, adapter_name=body.adapter_name) - self.assert_adapter_exists(adapter_name=adapter_name) - extra_kwargs = body.model_extra or {} - ret = self.model.set_optimizer(body.optimizer_cls, adapter_name=adapter_name, **extra_kwargs) - return {'result': ret} - - @app.post('/set_lr_scheduler') - def set_lr_scheduler(self, request: Request, body: SetLrSchedulerRequest): - adapter_name = self.get_adapter_name(request, adapter_name=body.adapter_name) - self.assert_adapter_exists(adapter_name=adapter_name) - extra_kwargs = body.model_extra or {} - ret = self.model.set_lr_scheduler(body.scheduler_cls, adapter_name=adapter_name, **extra_kwargs) - return {'result': ret} - - @app.post('/save') - def save(self, request: Request, body: SaveRequest): - """ - Save adapter weights with token-based isolation. - - This endpoint: - 1. Saves adapter weights to token-specific directory - 2. 
Saves checkpoint metadata with ownership tracking - - Args: - request: FastAPI request object (contains token in state) - body: SaveRequest with adapter_name, name, and save_optimizer flag - - Returns: - Dict with result containing the twinkle:// path to saved checkpoint - """ - adapter_name = self.get_adapter_name(request, adapter_name=body.adapter_name) - self.assert_adapter_exists(adapter_name=adapter_name) - extra_kwargs = body.model_extra or {} - - # Extract token for directory isolation - token = request.state.token - checkpoint_manager = create_checkpoint_manager(token) - - # Get checkpoint name and save directory with token-based path - checkpoint_name = checkpoint_manager.get_ckpt_name(body.name) - save_dir = checkpoint_manager.get_save_dir(model_id=adapter_name, is_sampler=False) - - # Save the model weights - checkpoint_dir = self.model.save( - name=checkpoint_name, - output_dir=save_dir, - adapter_name=adapter_name, - save_optimizer=body.save_optimizer, - **extra_kwargs) - - # Save checkpoint metadata - twinkle_path = checkpoint_manager.save(model_id=adapter_name, name=checkpoint_name, is_sampler=False) - - return {'result': twinkle_path, 'checkpoint_dir': checkpoint_dir} - - @app.post('/load') - def load(self, request: Request, body: LoadRequest): - """ - Load adapter weights with token-based access validation. - - This endpoint: - 1. Validates user has access to the checkpoint - 2. Loads weights from token-specific directory - - Args: - request: FastAPI request object (contains token in state) - body: LoadRequest with adapter_name, name, and load_optimizer flag - - Returns: - Dict with result indicating load status - """ - adapter_name = self.get_adapter_name(request, adapter_name=body.adapter_name) - self.assert_adapter_exists(adapter_name=adapter_name) - extra_kwargs = body.model_extra or {} - - # Extract token for directory isolation - token = request.state.token - checkpoint_manager = create_checkpoint_manager(token) - - # Use resolve_load_path to handle path resolution - resolved = checkpoint_manager.resolve_load_path(body.name) - - # Load from twinkle checkpoint directory - ret = self.model.load( - name=resolved.checkpoint_name, - output_dir=resolved.checkpoint_dir, - adapter_name=adapter_name, - load_optimizer=body.load_optimizer, - token=token, - **extra_kwargs) - - return {'result': ret} - - @app.post('/upload_to_hub') - def upload_to_hub(self, request: Request, body: UploadToHubRequest): - """ - Upload model checkpoint to hub. - - This endpoint uploads a previously saved checkpoint to a hub repository. 
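For concreteness, a minimal client-side sketch of the save-then-upload flow served by the handlers removed here; the server address, auth header, ids, and use of `requests` are assumptions for illustration, while the JSON fields mirror the request bodies these endpoints accept.

import requests  # assumed HTTP client for this sketch

BASE = 'http://localhost:8000/models/my-model'   # illustrative address and (old) route prefix
HEADERS = {'Authorization': 'Bearer <token>'}    # assumed auth scheme

# Save adapter weights; per the /save handler above, the response carries a
# twinkle:// path under 'result' plus the raw checkpoint directory.
save = requests.post(
    f'{BASE}/save',
    json={'adapter_name': 'demo', 'name': 'step-8', 'save_optimizer': False},
    headers=HEADERS)
save.raise_for_status()
twinkle_path = save.json()['result']

# Upload that checkpoint to the hub by its twinkle:// path; the handler resolves
# it to a filesystem directory after an ownership check.
upload = requests.post(
    f'{BASE}/upload_to_hub',
    json={'checkpoint_dir': twinkle_path, 'hub_model_id': 'org/demo-adapter',
          'hub_token': None, 'async_upload': True},
    headers=HEADERS)
upload.raise_for_status()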
- - Args: - request: FastAPI request object (contains token in state) - body: UploadToHubRequest with checkpoint_dir, hub_model_id, hub_token, and async_upload - - Returns: - Dict with success status and message - """ - token = request.state.token - - # Check if body.name is a twinkle:// path or a simple checkpoint name - if body.checkpoint_dir.startswith('twinkle://'): - # Parse twinkle:// path - checkpoint_manager = create_checkpoint_manager(token) - parsed = checkpoint_manager.parse_twinkle_path(body.checkpoint_dir) - if not parsed: - raise ValueError(f'Invalid twinkle path format: {body.checkpoint_dir}') - # parsed.checkpoint_id is like "weights/step-8" - checkpoint_id = parsed.checkpoint_id - - # Use the training_run_id from the path as the model_id - model_id_to_load = parsed.training_run_id - - # Verify checkpoint exists and user has access - checkpoint = checkpoint_manager.get(model_id_to_load, checkpoint_id) - if not checkpoint: - raise ValueError(f'Checkpoint not found or access denied: {body.checkpoint_dir}') - - # Get the actual directory path for the specific checkpoint - checkpoint_dir = str( - checkpoint_manager.get_ckpt_dir(model_id=model_id_to_load, checkpoint_id=checkpoint_id)) - else: - checkpoint_dir = body.checkpoint_dir - - # Call the model's upload_to_hub method - self.model.upload_to_hub( - checkpoint_dir=checkpoint_dir, - hub_model_id=body.hub_model_id, - hub_token=body.hub_token or token, - async_upload=body.async_upload) - - return {'result': body.hub_model_id} - - @app.post('/add_adapter_to_model') - def add_adapter_to_model(self, request: Request, body: AddAdapterRequest): - """ - Add a new adapter to the model. - - This endpoint: - 1. Creates a new adapter with the specified configuration - 2. Registers it in the adapter tracking system - 3. 
Saves training run metadata with token-based isolation - - Args: - request: FastAPI request object (contains token in state) - body: AddAdapterRequest with adapter_name and config - - Returns: - Dict with status and adapter_name - """ - assert body.adapter_name, 'You need to specify a valid `adapter_name`' - adapter_name = self.get_adapter_name(request, adapter_name=body.adapter_name) - config = deserialize_object(body.config) - extra_kwargs = body.model_extra or {} - - # Extract token for metadata storage - token = request.state.token - training_run_manager = create_training_run_manager(token) - - # Register adapter FIRST - self.register_adapter(adapter_name, token) - - # Create adapter AFTER successful registration - self.model.add_adapter_to_model(adapter_name, config, **extra_kwargs) - - # Save training run metadata (similar to tinker's create_model) - # Create a training run config from the adapter configuration - lora_config = None - if isinstance(config, LoraConfig): - lora_config = IoLoraConfig( - rank=config.r, - train_unembed=False, # Default values - train_mlp=True, - train_attn=True) - - run_config = CreateModelRequest( - base_model=model_id, # Use the model_id from build_model_app - lora_config=lora_config, - user_metadata={'adapter_name': body.adapter_name}) - - # Save training run metadata with token-based isolation - training_run_manager.save(adapter_name, run_config) - - return {'status': 'ok', 'adapter_name': adapter_name} - - @app.post('/set_template') - def set_template(self, request: Request, body: SetTemplateRequest): - adapter_name = self.get_adapter_name(request, adapter_name=body.adapter_name) - self.assert_adapter_exists(adapter_name=adapter_name) - extra_kwargs = body.model_extra or {} - ret = self.model.set_template(body.template_cls, adapter_name=adapter_name, **extra_kwargs) - return {'result': ret} - - @app.post('/set_processor') - def set_processor(self, request: Request, body: SetProcessorRequest): - adapter_name = self.get_adapter_name(request, adapter_name=body.adapter_name) - self.assert_adapter_exists(adapter_name=adapter_name) - extra_kwargs = body.model_extra or {} - ret = self.model.set_processor(body.processor_cls, adapter_name=adapter_name, **extra_kwargs) - return {'result': ret} - - @app.post('/heartbeat') - def heartbeat(self, request: Request, body: HeartbeatRequest): - adapter_name = self.get_adapter_name(request, adapter_name=body.adapter_name) - self.assert_adapter_exists(adapter_name=adapter_name) - self.touch_adapter(adapter_name) - return {'status': 'ok'} - - @app.post('/calculate_metric') - def calculate_metric(self, request: Request, body: CalculateMetricRequest): - adapter_name = self.get_adapter_name(request, adapter_name=body.adapter_name) - self.assert_adapter_exists(adapter_name=adapter_name) - extra_kwargs = body.model_extra or {} - ret = self.model.calculate_metric(is_training=body.is_training, adapter_name=adapter_name, **extra_kwargs) - return {'result': ret} - - @app.post('/get_state_dict') - def get_state_dict(self, request: Request, body: GetStateDictRequest): - adapter_name = self.get_adapter_name(request, adapter_name=body.adapter_name) - self.assert_adapter_exists(adapter_name=adapter_name) - extra_kwargs = body.model_extra or {} - ret = self.model.get_state_dict(adapter_name=adapter_name, **extra_kwargs) - return {'result': ret} - - return ModelManagement.options(**deploy_options).bind(nproc_per_node, device_group, device_mesh) diff --git a/src/twinkle/server/twinkle/processor.py b/src/twinkle/server/twinkle/processor.py 
deleted file mode 100644 index cbead9b7..00000000 --- a/src/twinkle/server/twinkle/processor.py +++ /dev/null @@ -1,188 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -import importlib -import os -import threading -import uuid -from fastapi import FastAPI, HTTPException, Request -from pydantic import BaseModel -from ray import serve -from typing import Any, Dict - -import twinkle -from twinkle import DeviceGroup, DeviceMesh, get_logger -from twinkle.server.utils.state import ServerStateProxy, get_server_state -from twinkle.server.utils.validation import verify_request_token -from .common.serialize import deserialize_object - -logger = get_logger() - - -class CreateRequest(BaseModel): - processor_type: str - class_type: str - - class Config: - extra = 'allow' - - -class HeartbeatRequest(BaseModel): - processor_id: str - - -class CallRequest(BaseModel): - processor_id: str - function: str - - class Config: - extra = 'allow' - - -def build_processor_app(nproc_per_node: int, ncpu_proc_per_node: int, device_group: Dict[str, Any], - device_mesh: Dict[str, Any], deploy_options: Dict[str, Any], **kwargs): - app = FastAPI() - - @app.middleware('http') - async def verify_token(request: Request, call_next): - return await verify_request_token(request=request, call_next=call_next) - - processors = ['dataset', 'dataloader', 'preprocessor', 'processor', 'reward', 'template', 'weight_loader'] - - @serve.deployment(name='ProcessorManagement') - @serve.ingress(app) - class ProcessorManagement: - - COUNT_DOWN = 60 * 30 - - def __init__(self, nproc_per_node: int, ncpu_proc_per_node: int, device_group: Dict[str, Any], - device_mesh: Dict[str, Any]): - self.device_group = DeviceGroup(**device_group) - twinkle.initialize( - mode='ray', - nproc_per_node=nproc_per_node, - groups=[self.device_group], - lazy_collect=False, - ncpu_proc_per_node=ncpu_proc_per_node) - if 'mesh_dim_names' in device_mesh: - self.device_mesh = DeviceMesh(**device_mesh) - else: - self.device_mesh = DeviceMesh.from_sizes(**device_mesh) - self.resource_dict = {} - self.resource_records: Dict[str, int] = {} - self.hb_thread = threading.Thread(target=self.countdown, daemon=True) - self.hb_thread.start() - self.state: ServerStateProxy = get_server_state() - self.per_token_processor_limit = int(os.environ.get('TWINKLE_PER_USER_PROCESSOR_LIMIT', 20)) - self.key_token_dict = {} - - def countdown(self): - import time - while True: - time.sleep(1) - for key in list(self.resource_records.keys()): - self.resource_records[key] += 1 - if self.resource_records[key] > self.COUNT_DOWN: - self.resource_records.pop(key, None) - self.resource_dict.pop(key, None) - if key in self.key_token_dict: - self.handle_processor_count(self.key_token_dict.pop(key), False) - - def assert_processor_exists(self, processor_id: str): - assert processor_id and processor_id in self.resource_dict, f'Processor {processor_id} not found' - - def handle_processor_count(self, token: str, add: bool): - user_key = token + '_' + 'processor' - cur_count = self.state.get_config(user_key) or 0 - if add: - if cur_count < self.per_token_processor_limit: - self.state.add_config(user_key, cur_count + 1) - else: - raise RuntimeError(f'Processor count limitation reached: {self.per_token_processor_limit}') - else: - if cur_count > 0: - cur_count -= 1 - self.state.add_config(user_key, cur_count) - if cur_count <= 0: - self.state.pop_config(user_key) - - @app.post('/create') - def create(self, request: Request, body: CreateRequest): - - processor_type_name = 
body.processor_type - class_type = body.class_type - kwargs = body.model_extra or {} - - assert processor_type_name in processors, f'Invalid processor type: {processor_type_name}' - processor_module = importlib.import_module(f'twinkle.{processor_type_name}') - assert hasattr(processor_module, class_type), f'Class {class_type} not found in {processor_type_name}' - self.handle_processor_count(request.state.token, True) - processor_id = str(uuid.uuid4().hex) - self.key_token_dict[processor_id] = request.state.token - - kwargs.pop('remote_group', None) - kwargs.pop('device_mesh', None) - - _kwargs = {} - for key, value in kwargs.items(): - if isinstance(value, str) and value.startswith('pid:'): - ref_id = value[4:] - _kwargs[key] = self.resource_dict[ref_id] - else: - value = deserialize_object(value) - _kwargs[key] = value - - processor = getattr(processor_module, class_type)( - remote_group=self.device_group.name, device_mesh=self.device_mesh, instance_id=processor_id, **_kwargs) - self.resource_dict[processor_id] = processor - self.resource_records[processor_id] = 0 - return {'processor_id': 'pid:' + processor_id} - - @app.post('/heartbeat') - def heartbeat(self, body: HeartbeatRequest): - processor_ids = body.processor_id.split(',') - for _id in processor_ids: - if _id and _id in self.resource_dict: - self.resource_records[_id] = 0 - return {'status': 'ok'} - - @app.post('/call') - def call(self, body: CallRequest): - processor_id = body.processor_id - function_name = body.function - kwargs = body.model_extra or {} - processor_id = processor_id[4:] - self.assert_processor_exists(processor_id=processor_id) - processor = self.resource_dict.get(processor_id) - function = getattr(processor, function_name, None) - - assert function is not None, f'`{function_name}` not found in {processor.__class__}' - assert hasattr(function, '_execute'), f'Cannot call inner method of {processor.__class__}' - - _kwargs = {} - for key, value in kwargs.items(): - if isinstance(value, str) and value.startswith('pid:'): - ref_id = value[4:] - _kwargs[key] = self.resource_dict[ref_id] - else: - value = deserialize_object(value) - _kwargs[key] = value - - # Special handling for __next__ to catch StopIteration - # We convert StopIteration to HTTP 410 (Gone) which semantically means - # "the resource (next item) is no longer available" - if function_name == '__next__': - try: - result = function(**_kwargs) - return {'result': result} - except StopIteration: - # Use HTTP 410 Gone to indicate iterator exhausted - # This is a clean signal that won't be confused with errors - raise HTTPException(status_code=410, detail='Iterator exhausted') - - result = function(**_kwargs) - if function_name == '__iter__': - return {'result': 'ok'} - else: - return {'result': result} - - return ProcessorManagement.options(**deploy_options).bind(nproc_per_node, ncpu_proc_per_node, device_group, - device_mesh) diff --git a/src/twinkle/server/twinkle/sampler.py b/src/twinkle/server/twinkle/sampler.py deleted file mode 100644 index 27ffd694..00000000 --- a/src/twinkle/server/twinkle/sampler.py +++ /dev/null @@ -1,308 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -""" -Twinkle sampler (inference) server. - -This module provides a Ray Serve deployment for distributed text generation/inference. -It supports: -1. vLLM and Torch sampler backends -2. LoRA adapter loading via adapter URIs (twinkle:// paths or local paths) -3. Multi-user inference with adapter lifecycle management -4. 
Flexible sampling parameters -""" -import traceback -from fastapi import FastAPI, Request -from pydantic import BaseModel, Field -from ray import serve -from typing import Any, Dict, List, Optional, Union - -import twinkle -from twinkle import DeviceGroup, DeviceMesh -from twinkle.data_format import InputFeature, SamplingParams, Trajectory -from twinkle.server.utils.adapter_manager import AdapterManagerMixin -from twinkle.server.utils.state import ServerStateProxy, get_server_state -from twinkle.server.utils.validation import get_token_from_request, verify_request_token -from twinkle.utils.logger import get_logger - -logger = get_logger() - -# ----- Request/Response Models ----- - - -class SampleRequest(BaseModel): - """Request body for the /sample endpoint.""" - inputs: Any = Field(..., description='List of Trajectory or InputFeature dicts') - sampling_params: Optional[Dict[str, Any]] = Field( - None, description='Sampling parameters (max_tokens, temperature, etc.)') - adapter_name: str = Field('', description='Adapter name for LoRA inference') - adapter_uri: Optional[str] = Field( - None, description='Adapter URI (twinkle:// path or local path) for LoRA inference') - num_samples: int = Field(1, description='Number of completions to generate per prompt') - - -class SampleResponseModel(BaseModel): - """Response body for the /sample endpoint.""" - sequences: List[Dict[str, - Any]] = Field(..., - description='List of sampled sequences, each with tokens, logprobs, stop_reason') - prompt_logprobs: Optional[List[Optional[float]]] = None - topk_prompt_logprobs: Optional[List[Optional[List]]] = None - - -class SetTemplateRequest(BaseModel): - """Request body for the /set_template endpoint.""" - template_cls: str = Field(..., description="Template class name (e.g. 'Template')") - adapter_name: str = Field('', description='Adapter name to associate the template with') - - class Config: - extra = 'allow' - - -class SetTemplateResponse(BaseModel): - """Response body for the /set_template endpoint.""" - status: str = 'ok' - - -class AddAdapterRequest(BaseModel): - """Request body for the /add_adapter_to_sampler endpoint.""" - adapter_name: str = Field(..., description='Name of the adapter to add') - config: Any = Field(..., description='LoRA configuration dict') - - -class AddAdapterResponse(BaseModel): - """Response body for the /add_adapter_to_sampler endpoint.""" - status: str = 'ok' - adapter_name: str - - -class HeartbeatRequest(BaseModel): - """Request body for the /heartbeat endpoint.""" - adapter_name: str = Field(..., description='Adapter name to keep alive') - - -class HeartbeatResponse(BaseModel): - """Response body for the /heartbeat endpoint.""" - status: str = 'ok' - - -class CreateResponse(BaseModel): - """Response body for the /create endpoint.""" - status: str = 'ok' - - -# ----- Application Builder ----- - - -def build_sampler_app(model_id: str, - nproc_per_node: int = 1, - device_group: Dict[str, Any] = None, - device_mesh: Dict[str, Any] = None, - deploy_options: Dict[str, Any] = None, - sampler_type: str = 'vllm', - engine_args: Optional[Dict[str, Any]] = None, - adapter_config: Optional[Dict[str, Any]] = None, - **kwargs): - """Build a sampler application for text generation inference. 
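To make the SampleRequest schema above concrete, a hypothetical request against the /sample route this (now removed) module served; host, port, and the prompt are assumptions, and Trajectory-style inputs additionally require a template set via /set_template, per the endpoint docstring below.

import requests  # assumed HTTP client for this sketch

resp = requests.post(
    'http://localhost:8000/sample',  # illustrative address
    json={
        # Trajectory-style input (messages); pre-tokenized InputFeature dicts
        # with 'input_ids' are also accepted, per the parsing logic in sample().
        'inputs': [{'messages': [{'role': 'user', 'content': 'Hello!'}]}],
        'sampling_params': {'max_tokens': 64, 'temperature': 0.7},
        'adapter_name': '',  # empty string means base model
        'num_samples': 2,    # two completions per prompt
    })
resp.raise_for_status()
for seq in resp.json()['sequences']:
    print(seq['stop_reason'], len(seq['tokens']))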
- - Args: - model_id: Model identifier (e.g., "Qwen/Qwen3.5-4B") - nproc_per_node: Number of GPU processes per node - device_group: Device group configuration dict - device_mesh: Device mesh configuration dict for parallelism - deploy_options: Ray Serve deployment options - sampler_type: Type of sampler to use ('vllm' or 'torch') - engine_args: Additional engine arguments for the sampler - adapter_config: Adapter lifecycle config (adapter_timeout, per_token_adapter_limit) - **kwargs: Additional arguments passed to the sampler - - Returns: - Ray Serve deployment bound with configuration - """ - app = FastAPI( - title='Twinkle Sampler', description='REST API for distributed text generation inference', version='1.0.0') - - @app.middleware('http') - async def verify_token(request: Request, call_next): - return await verify_request_token(request=request, call_next=call_next) - - @serve.deployment(name='SamplerManagement') - @serve.ingress(app) - class SamplerManagement(AdapterManagerMixin): - """Sampler management service for text generation inference. - - Manages: - - vLLM or Torch sampler initialization and lifecycle - - Adapter lifecycle via AdapterManagerMixin - - Inference requests with LoRA adapter support - - Template configuration for trajectory encoding - """ - - def __init__(self, - nproc_per_node: int, - device_group: Dict[str, Any], - device_mesh: Dict[str, Any], - sampler_type: str = 'vllm', - engine_args: Optional[Dict[str, Any]] = None, - adapter_config: Optional[Dict[str, Any]] = None, - **kwargs): - self.device_group = DeviceGroup(**device_group) - twinkle.initialize( - mode='ray', nproc_per_node=nproc_per_node, groups=[self.device_group], lazy_collect=False) - if 'mesh_dim_names' in device_mesh: - self.device_mesh = DeviceMesh(**device_mesh) - else: - self.device_mesh = DeviceMesh.from_sizes(**device_mesh) - self.sampler_type = sampler_type - replica_context = serve.get_replica_context() - replica_id = replica_context.replica_id.unique_id - # Initialize sampler based on type - if sampler_type == 'vllm': - from twinkle.sampler import vLLMSampler - sampler_kwargs = engine_args or {} - self.sampler = vLLMSampler( - model_id=model_id, - engine_args=sampler_kwargs, - device_mesh=self.device_mesh, - remote_group=self.device_group.name, - instance_id=replica_id, - **{ - k: v - for k, v in kwargs.items() if k not in ['engine_args'] - }) - else: - from twinkle.sampler import TorchSampler - self.sampler = TorchSampler( - model_id=model_id, - device_mesh=self.device_mesh, - instance_id=replica_id, - remote_group=self.device_group.name, - **kwargs) - - # Initialize state and adapter manager - self.state: ServerStateProxy = get_server_state() - _adapter_config = adapter_config or {} - self._init_adapter_manager(**_adapter_config) - self.start_adapter_countdown() - - def _on_adapter_expired(self, adapter_name: str, token: str) -> None: - """Handle expired adapters by removing them from the sampler.""" - try: - self.sampler.remove_adapter(adapter_name) - logger.info(f'Removed expired adapter {adapter_name}') - # Adapter count is now tracked dynamically, no manual update needed - except Exception as e: - logger.warning(f'Failed to remove expired adapter {adapter_name}: {e}') - - @staticmethod - def _get_adapter_name(request: Request, adapter_name: Optional[str]) -> Optional[str]: - if adapter_name is None or adapter_name == '': - return None - return request.state.request_id + '-' + adapter_name - - @app.post('/create', response_model=CreateResponse) - def create(self, request: Request) -> 
CreateResponse: - """Health check / session creation endpoint.""" - return CreateResponse() - - @app.post('/sample', response_model=SampleResponseModel) - def sample(self, request: Request, body: SampleRequest) -> SampleResponseModel: - """Sample completions from the model. - - Supports: - - Trajectory inputs (messages-based, requires template to be set) - - InputFeature inputs (pre-tokenized input_ids) - - LoRA adapter via adapter_name or adapter_uri (twinkle:// path) - - Multiple completions per prompt via num_samples - """ - try: - # Resolve adapter - adapter_path = None - adapter_name = body.adapter_name or '' - full_adapter_name = self._get_adapter_name(request, adapter_name) or '' - - if body.adapter_uri: - from .common.io_utils import create_checkpoint_manager - token = get_token_from_request(request) - checkpoint_manager = create_checkpoint_manager(token) - _, adapter_path = checkpoint_manager.parse_adapter_uri(body.adapter_uri) - - # Parse inputs - inputs = body.inputs - if isinstance(inputs, list) and inputs: - first = inputs[0] - if isinstance(first, dict) and 'input_ids' in first: - inputs = [InputFeature(**item) for item in inputs] - else: - inputs = [Trajectory(**item) for item in inputs] - elif isinstance(inputs, dict): - if 'input_ids' in inputs: - inputs = [InputFeature(**inputs)] - else: - inputs = [Trajectory(**inputs)] - - # Build sampling params - params = None - if body.sampling_params: - params = SamplingParams.from_dict(body.sampling_params) - - # Call sampler - response = self.sampler.sample( - inputs, - params, - adapter_name=full_adapter_name, - adapter_path=adapter_path, - num_samples=body.num_samples, - ) - if callable(response): - response = response() - - # Convert to response model - sequences = [] - for seq in response.sequences: - sequences.append({ - 'stop_reason': seq.stop_reason, - 'tokens': list(seq.tokens), - 'logprobs': list(seq.logprobs) if seq.logprobs is not None else None, - }) - - return SampleResponseModel( - sequences=sequences, - prompt_logprobs=response.prompt_logprobs, - topk_prompt_logprobs=response.topk_prompt_logprobs, - ) - except Exception: - logger.error(traceback.format_exc()) - raise - - @app.post('/set_template', response_model=SetTemplateResponse) - def set_template(self, request: Request, body: SetTemplateRequest) -> SetTemplateResponse: - """Set the chat template for encoding Trajectory inputs.""" - extra_kwargs = body.model_extra or {} - self.sampler.set_template(body.template_cls, **extra_kwargs) - return SetTemplateResponse() - - @app.post('/add_adapter_to_sampler', response_model=AddAdapterResponse) - def add_adapter_to_sampler(self, request: Request, body: AddAdapterRequest) -> AddAdapterResponse: - """Add a LoRA adapter to the sampler.""" - assert body.adapter_name, 'You need to specify a valid `adapter_name`' - full_adapter_name = self._get_adapter_name(request, body.adapter_name) - token = get_token_from_request(request) - - from peft import LoraConfig - config = LoraConfig(**body.config) if isinstance(body.config, dict) else body.config - - self.register_adapter(full_adapter_name, token) - - self.sampler.add_adapter_to_sampler(full_adapter_name, config) - - return AddAdapterResponse(adapter_name=full_adapter_name) - - @app.post('/heartbeat', response_model=HeartbeatResponse) - def heartbeat(self, request: Request, body: HeartbeatRequest) -> HeartbeatResponse: - """Keep an adapter alive by resetting its inactivity timer.""" - full_adapter_name = self._get_adapter_name(request, body.adapter_name) - 
self.assert_adapter_exists(adapter_name=full_adapter_name) - self.touch_adapter(full_adapter_name) - return HeartbeatResponse() - - return SamplerManagement.options(**deploy_options).bind(nproc_per_node, device_group, device_mesh, sampler_type, - engine_args, adapter_config, **kwargs) diff --git a/src/twinkle/server/twinkle/server.py b/src/twinkle/server/twinkle/server.py deleted file mode 100644 index 86857647..00000000 --- a/src/twinkle/server/twinkle/server.py +++ /dev/null @@ -1,270 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -""" -Twinkle REST API Server - -This module provides a FastAPI server with REST API endpoints for: -- Training run management (list, get, update) -- Checkpoint management (list, delete) -- Weights info retrieval - -All endpoints include permission control to ensure users can only -access their own resources. -""" -from __future__ import annotations - -from fastapi import FastAPI, HTTPException, Request -from pydantic import BaseModel -from ray import serve -from typing import Any - -from twinkle.server.utils.state import ServerStateProxy, get_server_state -from twinkle.server.utils.validation import get_token_from_request, verify_request_token -from .common.io_utils import (CheckpointsListResponse, TrainingRun, TrainingRunsResponse, WeightsInfoResponse, - create_checkpoint_manager, create_training_run_manager, validate_user_path) - -# ----- Request/Response Models ----- - - -class HealthResponse(BaseModel): - status: str - - -class WeightsInfoRequest(BaseModel): - twinkle_path: str - - -class DeleteCheckpointResponse(BaseModel): - success: bool - message: str - - -class ErrorResponse(BaseModel): - detail: str - - -def build_server_app(deploy_options: dict[str, Any], **kwargs): - """ - Build the Twinkle REST API server application. - - This function creates a FastAPI application wrapped in a Ray Serve deployment - that provides REST API endpoints for managing training runs and checkpoints. - - Args: - deploy_options: Ray Serve deployment options (num_replicas, etc.) - **kwargs: Additional configuration options - - Returns: - A Ray Serve deployment handle - """ - app = FastAPI( - title='Twinkle Server', description='REST API for managing training runs and checkpoints', version='1.0.0') - - @app.middleware('http') - async def verify_token(request: Request, call_next): - """Verify authentication token for all requests.""" - return await verify_request_token(request=request, call_next=call_next) - - @serve.deployment(name='TwinkleServer') - @serve.ingress(app) - class TwinkleServer: - """ - Twinkle REST API Server. - - This server provides endpoints for: - - Health checks - - Training run management - - Checkpoint management - - Weights info retrieval - - All modifying operations (delete, etc.) are protected by permission checks - to ensure users can only modify their own resources. - """ - - def __init__(self, **kwargs) -> None: - self.state: ServerStateProxy = get_server_state() - self.route_prefix = kwargs.get('route_prefix', '/api/v1') - - def _get_user_token(self, request: Request) -> str: - """Extract user token from request state.""" - return get_token_from_request(request) - - # ----- Health Check ----- - - @app.get('/healthz', response_model=HealthResponse) - async def healthz(self, request: Request) -> HealthResponse: - """ - Health check endpoint. 
- - Returns: - HealthResponse with status "ok" if server is healthy - """ - return HealthResponse(status='ok') - - # ----- Training Runs Endpoints ----- - - @app.get('/training_runs', response_model=TrainingRunsResponse) - async def get_training_runs(self, request: Request, limit: int = 20, offset: int = 0) -> TrainingRunsResponse: - """ - List training runs. - - Returns training runs owned by the current user. - - Args: - limit: Maximum number of results (default: 20) - offset: Offset for pagination (default: 0) - - Returns: - TrainingRunsResponse with list of training runs and pagination info - """ - token = self._get_user_token(request) - training_run_manager = create_training_run_manager(token) - return training_run_manager.list_runs(limit=limit, offset=offset) - - @app.get('/training_runs/{run_id}', response_model=TrainingRun) - async def get_training_run(self, request: Request, run_id: str) -> TrainingRun: - """ - Get details of a specific training run. - - Users can only view their own training runs. - - Args: - run_id: The training run identifier - - Returns: - TrainingRun details - - Raises: - HTTPException 404 if run not found or not owned by user - """ - token = self._get_user_token(request) - training_run_manager = create_training_run_manager(token) - run = training_run_manager.get_with_permission(run_id) - if not run: - raise HTTPException(status_code=404, detail=f'Training run {run_id} not found or access denied') - return run - - @app.get('/training_runs/{run_id}/checkpoints', response_model=CheckpointsListResponse) - async def get_run_checkpoints(self, request: Request, run_id: str) -> CheckpointsListResponse: - """ - List checkpoints for a training run. - - Users can only view checkpoints for their own training runs. - - Args: - run_id: The training run identifier - - Returns: - CheckpointsListResponse with list of checkpoints - - Raises: - HTTPException 404 if run not found or not owned by user - """ - token = self._get_user_token(request) - checkpoint_manager = create_checkpoint_manager(token) - response = checkpoint_manager.list_checkpoints(run_id) - if response is None: - raise HTTPException(status_code=404, detail=f'Training run {run_id} not found or access denied') - return response - - @app.delete('/training_runs/{run_id}/checkpoints/{checkpoint_id:path}') - async def delete_run_checkpoint(self, request: Request, run_id: str, - checkpoint_id: str) -> DeleteCheckpointResponse: - """ - Delete a checkpoint from a training run. - - Users can only delete checkpoints from their own training runs. - Path traversal (using ..) is not allowed. 
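To make the read/delete routes of this (now removed) REST server concrete, a hypothetical exchange; the base URL follows the default route_prefix ('/api/v1') set in __init__ above, while the run id, checkpoint id, and auth header are illustrative.

import requests  # assumed HTTP client for this sketch

API = 'http://localhost:8000/api/v1'            # default route_prefix per __init__ above
HEADERS = {'Authorization': 'Bearer <token>'}   # assumed auth scheme

# Paginated listing of the caller's own training runs.
runs = requests.get(f'{API}/training_runs', params={'limit': 5, 'offset': 0}, headers=HEADERS)
runs.raise_for_status()

# Checkpoints for one run; 404 covers both "missing" and "not owned by this token".
ckpts = requests.get(f'{API}/training_runs/run-123/checkpoints', headers=HEADERS)

# Deleting a checkpoint; ids containing '..' are rejected with HTTP 400
# before any filesystem access.
gone = requests.delete(f'{API}/training_runs/run-123/checkpoints/weights/step-8', headers=HEADERS)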
- - Args: - run_id: The training run identifier - checkpoint_id: The checkpoint identifier (can include path like weights/checkpoint_name) - - Returns: - DeleteCheckpointResponse indicating success or failure - - Raises: - HTTPException 400 for invalid paths - HTTPException 403 if not owned by user - HTTPException 404 if checkpoint not found - """ - token = self._get_user_token(request) - - # Validate path safety - if not validate_user_path(token, checkpoint_id): - raise HTTPException(status_code=400, detail='Invalid checkpoint path: path traversal not allowed') - - checkpoint_manager = create_checkpoint_manager(token) - success = checkpoint_manager.delete(run_id, checkpoint_id) - if not success: - raise HTTPException(status_code=404, detail=f'Checkpoint {checkpoint_id} not found or access denied') - - return DeleteCheckpointResponse(success=True, message=f'Checkpoint {checkpoint_id} deleted successfully') - - @app.post('/weights_info', response_model=WeightsInfoResponse) - async def weights_info(self, request: Request, body: WeightsInfoRequest) -> WeightsInfoResponse: - """ - Get information about saved weights. - - Users can only view info for their own weights. - - Args: - body: Request containing the twinkle_path - - Returns: - WeightsInfoResponse with weight details - - Raises: - HTTPException 404 if weights not found or not owned by user - """ - token = self._get_user_token(request) - checkpoint_manager = create_checkpoint_manager(token) - response = checkpoint_manager.get_weights_info(body.twinkle_path) - if response is None: - raise HTTPException( - status_code=404, detail=f'Weights at {body.twinkle_path} not found or access denied') - return response - - # ----- Checkpoint Path Resolution ----- - - @app.get('/checkpoint_path/{run_id}/{checkpoint_id:path}') - async def get_checkpoint_path(self, request: Request, run_id: str, checkpoint_id: str) -> dict[str, str]: - """ - Get the filesystem path for a checkpoint. - - This endpoint resolves a checkpoint ID to its actual filesystem path, - which can be used for loading weights during resume training. - - Args: - run_id: The training run identifier - checkpoint_id: The checkpoint identifier - - Returns: - Dict with 'path' key containing the filesystem path - - Raises: - HTTPException 403/404 for permission/not found errors - """ - token = self._get_user_token(request) - - # Validate path safety - if not validate_user_path(token, checkpoint_id): - raise HTTPException(status_code=400, detail='Invalid checkpoint path: path traversal not allowed') - - training_run_manager = create_training_run_manager(token) - checkpoint_manager = create_checkpoint_manager(token) - - # Check ownership - run = training_run_manager.get(run_id) - if not run: - raise HTTPException(status_code=404, detail=f'Training run {run_id} not found or access denied') - - # Get checkpoint with token-based path - checkpoint = checkpoint_manager.get(run_id, checkpoint_id) - if not checkpoint: - raise HTTPException(status_code=404, detail=f'Checkpoint {checkpoint_id} not found') - - # Return the filesystem path - ckpt_dir = checkpoint_manager.get_ckpt_dir(run_id, checkpoint_id) - return {'path': str(ckpt_dir), 'twinkle_path': checkpoint.twinkle_path} - - return TwinkleServer.options(**deploy_options).bind(**kwargs) diff --git a/src/twinkle/server/utils/__init__.py b/src/twinkle/server/utils/__init__.py index dca07caf..d19d34d0 100644 --- a/src/twinkle/server/utils/__init__.py +++ b/src/twinkle/server/utils/__init__.py @@ -1,7 +1,8 @@ # Copyright (c) ModelScope Contributors. 
All rights reserved. from .adapter_manager import AdapterManagerMixin +from .checkpoint_base import (TRAIN_RUN_INFO_FILENAME, TWINKLE_DEFAULT_SAVE_DIR, BaseCheckpointManager, BaseFileManager, + BaseTrainingRunManager) from .device_utils import auto_fill_device_group_visible_devices, wrap_builder_with_device_group_env -from .io_utils import (TRAIN_RUN_INFO_FILENAME, TWINKLE_DEFAULT_SAVE_DIR, BaseCheckpointManager, BaseFileManager, - BaseTrainingRunManager) +from .processor_manager import ProcessorManagerMixin from .rate_limiter import RateLimiter from .task_queue import QueueState, TaskQueueConfig, TaskQueueMixin, TaskStatus diff --git a/src/twinkle/server/utils/adapter_manager.py b/src/twinkle/server/utils/adapter_manager.py index 8337ed6b..844ccfd1 100644 --- a/src/twinkle/server/utils/adapter_manager.py +++ b/src/twinkle/server/utils/adapter_manager.py @@ -24,17 +24,18 @@ class AdapterManagerMixin: - """Mixin for adapter lifecycle management with automatic timeout. + """Mixin for adapter lifecycle management with session-based expiration. This mixin tracks adapter activity and automatically expires adapters - that have been inactive for longer than the configured timeout period. + when their associated session expires. Inheriting classes should: 1. Call _init_adapter_manager() in __init__ 2. Override _on_adapter_expired() to customize expiration handling Attributes: - _adapter_timeout: Timeout in seconds for inactive adapters. + _adapter_timeout: Session inactivity timeout in seconds used to determine if a session is alive. + _adapter_max_lifetime: Maximum lifetime in seconds for any adapter, regardless of session liveness. """ # Type hint for state attribute that inheriting classes must provide @@ -43,52 +44,55 @@ class AdapterManagerMixin: def _init_adapter_manager( self, adapter_timeout: float = 1800.0, - adapter_max_lifetime: float = 12 * 60 * 60, + adapter_max_lifetime: float = 36000.0, ) -> None: """Initialize the adapter manager. This should be called in the __init__ of the inheriting class. Args: - adapter_timeout: Timeout in seconds for inactive adapters and session-based expiration. - Default is 1800.0 (30 minutes). Adapters linked to sessions will expire - when their session hasn't been touched for this duration. - adapter_max_lifetime: Maximum lifetime in seconds for an adapter since creation. - Default is 43200.0 (12 hours). If <= 0, lifetime enforcement is disabled. + adapter_timeout: Timeout in seconds used to check whether a session is still alive. + Default is 1800.0 (30 minutes). + adapter_max_lifetime: Maximum lifetime in seconds for an adapter regardless of session + liveness. Adapters older than this are treated as expired. Default is 36000.0 (10 hours). """ self._adapter_timeout = adapter_timeout self._adapter_max_lifetime = adapter_max_lifetime # Adapter lifecycle tracking # Dict mapping adapter_name -> - # {'token': str, 'session_id': str, 'last_activity': float, 'created_at': float, 'inactivity_counter': int} + # {'token': str, 'session_id': str, 'created_at': float, 'state': dict, 'expiring': bool} self._adapter_records: dict[str, dict[str, Any]] = {} # Countdown thread self._adapter_countdown_thread: threading.Thread | None = None self._adapter_countdown_running = False - def register_adapter(self, adapter_name: str, token: str, session_id: str | None = None) -> None: + def register_adapter(self, adapter_name: str, token: str, session_id: str) -> None: """Register a new adapter for lifecycle tracking. 
+ The adapter will expire when its associated session expires. + Args: adapter_name: Name of the adapter to register. token: User token that owns this adapter. - session_id: Optional session ID to associate with this adapter. - If provided, adapter will expire when the session expires. + session_id: Session ID to associate with this adapter. Must be a non-empty string. + + Raises: + ValueError: If session_id is None or empty. """ + if not session_id: + raise ValueError(f'session_id must be provided when registering adapter {adapter_name}') current_time = time.time() self._adapter_records[adapter_name] = { 'token': token, 'session_id': session_id, - 'last_activity': current_time, 'created_at': current_time, - 'inactivity_counter': 0, 'state': {}, 'expiring': False, } - logger.debug(f'[AdapterManager] Registered adapter {adapter_name} for token {token[:8]}...' - + (f' (session: {session_id})' if session_id else '')) + logger.debug( + f'[AdapterManager] Registered adapter {adapter_name} for token {token[:8]}... (session: {session_id})') def _is_session_alive(self, session_id: str) -> bool: """Check if a session is still alive via state proxy. @@ -166,24 +170,6 @@ def clear_adapter_state(self, adapter_name: str) -> None: return info['state'] = {} - def touch_adapter(self, adapter_name: str) -> bool: - """Update adapter activity timestamp to prevent timeout. - - Args: - adapter_name: Name of the adapter to touch. - - Returns: - True if adapter was found and touched, False otherwise. - """ - info = self._adapter_records.get(adapter_name) - if not info: - return False - if info.get('expiring'): - return False - info['last_activity'] = time.time() - info['inactivity_counter'] = 0 - return True - def get_adapter_info(self, adapter_name: str) -> dict[str, Any] | None: """Get information about a registered adapter. @@ -230,18 +216,18 @@ def assert_adapter_exists(self, adapter_name: str) -> None: f'Adapter {adapter_name} not found' def _adapter_countdown_loop(self) -> None: - """Background thread that monitors and handles inactive adapters. + """Background thread that monitors and handles adapters whose session has expired or exceeded max lifetime. This thread runs continuously and: - 1. Increments inactivity counters for all adapters every second - 2. Calls _on_adapter_expired() for adapters that exceed timeout - 3. Removes expired adapters from tracking + 1. Checks whether an adapter has exceeded `_adapter_max_lifetime` (sync, no async call) + 2. Checks session liveness for remaining adapters every ten seconds + 3. Calls _on_adapter_expired() for adapters that have expired + 4. 
Removes expired adapters from tracking """ - logger.debug(f'[AdapterManager] Countdown thread started (timeout={self._adapter_timeout}s)') + logger.debug(f'[AdapterManager] Countdown thread started (session_timeout={self._adapter_timeout}s)') while self._adapter_countdown_running: try: - time.sleep(1) - now = time.time() + time.sleep(10) - expired_adapters: list[tuple[str, str | None]] = [] + expired_adapters: list[tuple[str, str | None, str | None]] = [] # Create snapshot to avoid modification during iteration @@ -251,54 +237,41 @@ def _adapter_countdown_loop(self) -> None: continue session_id = info.get('session_id') - created_at = info.get('created_at') - - # Check TTL for both cases - exceeded_ttl = ( - self._adapter_max_lifetime and self._adapter_max_lifetime > 0 - and (now - created_at) > self._adapter_max_lifetime) - - # Different logic based on session association - if session_id: - # Has session: check session expiration and TTL - session_expired = not self._is_session_alive(session_id) - should_expire = session_expired or exceeded_ttl - logger.debug( - f'[AdapterManager] Adapter {adapter_name} session expiration check ' - f'(session_id={session_id}, session_alive={not session_expired}, should_expire={should_expire})' # noqa:E501 - ) - expiration_reasons = [] - if exceeded_ttl: - expiration_reasons.append('ttl_exceeded') - if session_expired: - expiration_reasons.append('session_expired') - else: - # No session: check inactivity timeout and TTL - info['inactivity_counter'] = info.get('inactivity_counter', 0) + 1 - exceeded_inactivity = info['inactivity_counter'] > self._adapter_timeout - should_expire = exceeded_ttl or exceeded_inactivity - logger.debug( - f'[AdapterManager] Adapter {adapter_name} inactivity check ' - f'(inactivity_counter={info["inactivity_counter"]}, timeout={self._adapter_timeout}, should_expire={should_expire})' # noqa:E501 - ) - expiration_reasons = [] - if exceeded_ttl: - expiration_reasons.append('ttl_exceeded') - if exceeded_inactivity: - expiration_reasons.append('inactivity_timeout') - - if should_expire: + created_at = info.get('created_at', 0.0) + now = time.time() + + # Check max lifetime first (no async call needed) + if now - created_at >= self._adapter_max_lifetime: + logger.debug(f'[AdapterManager] Adapter {adapter_name} exceeded max lifetime ' + f'({self._adapter_max_lifetime}s), marking as expired') + info['expiring'] = True + info['state'] = {} + token = info.get('token') + expired_adapters.append((adapter_name, token, session_id)) + continue + + try: + session_alive = self._is_session_alive(session_id) + except Exception as e: + logger.warning(f'[AdapterManager] Failed to check session liveness for {adapter_name}: ' + f'{type(e).__name__}: {e}') + continue + session_expired = not session_alive + logger.debug(f'[AdapterManager] Adapter {adapter_name} session check ' + f'(session_id={session_id}, session_alive={not session_expired})') + + if session_expired: info['expiring'] = True info['state'] = {} # best-effort clear token = info.get('token') - expired_adapters.append((adapter_name, token)) + expired_adapters.append((adapter_name, token, session_id)) - for adapter_name, token in expired_adapters: + for adapter_name, _token, session_id in expired_adapters: success = False try: self._on_adapter_expired(adapter_name) logger.info(f'[AdapterManager] Adapter {adapter_name} expired ' - f"(reasons={','.join(expiration_reasons)}, session={session_id})") + f'(reason=session_expired, session={session_id})') success = True except Exception as e: - logger.warning(f'[AdapterManager] Error while expiring adapter 
{adapter_name}: {e}') diff --git a/src/twinkle/server/utils/io_utils.py b/src/twinkle/server/utils/checkpoint_base.py similarity index 96% rename from src/twinkle/server/utils/io_utils.py rename to src/twinkle/server/utils/checkpoint_base.py index 1a95b6c2..cbe05602 100644 --- a/src/twinkle/server/utils/io_utils.py +++ b/src/twinkle/server/utils/checkpoint_base.py @@ -1,10 +1,16 @@ # Copyright (c) ModelScope Contributors. All rights reserved. """ -Base IO utilities for managing training runs and checkpoints. +Base infrastructure for checkpoint and training-run persistence. -This module provides abstract base classes that encapsulate common logic for -file-based storage of training run metadata and checkpoint information. -Both tinker and twinkle servers inherit from these classes. +Provides: +- Constants and path-hashing utilities +- Permission-check helpers (``validate_user_path``, ``validate_ownership``) +- Internal Pydantic base specs used as type constraints for the generic managers +- Abstract base managers: ``BaseTrainingRunManager``, ``BaseCheckpointManager`` + +Concrete implementations live in: + - ``twinkle.server.common.tinker_checkpoint`` + - ``twinkle.server.common.twinkle_checkpoint`` """ import hashlib import hmac @@ -20,6 +26,7 @@ from twinkle import get_logger from twinkle.hub import HubOperation +from twinkle_client.types import ResolvedLoadPath logger = get_logger() @@ -41,13 +48,7 @@ def _hash_token(token: str) -> str: return hmac.new(_TOKEN_SALT, token.encode('utf-8'), hashlib.sha256).hexdigest()[:16] -# ----- Common Pydantic Models ----- - - -class Cursor(BaseModel): - limit: int - offset: int - total_count: int +# ----- Internal Pydantic Base Specs ----- class BaseCheckpoint(BaseModel): @@ -104,23 +105,6 @@ class BaseParsedCheckpointPath(BaseModel): checkpoint_id: str -class ResolvedLoadPath(BaseModel): - """Result of resolving a load path. - - Attributes: - checkpoint_name: The name of the checkpoint (e.g., 'step-8' or hub model id) - checkpoint_dir: The directory containing the checkpoint, or None if loading from hub - is_twinkle_path: Whether the path was a twinkle:// path - training_run_id: The training run ID (only set for twinkle:// paths) - checkpoint_id: The checkpoint ID (only set for twinkle:// paths) - """ - checkpoint_name: str - checkpoint_dir: Optional[str] = None - is_twinkle_path: bool = False - training_run_id: Optional[str] = None - checkpoint_id: Optional[str] = None - - class BaseWeightsInfoResponse(BaseModel): """Base model for weights info response.""" training_run_id: str diff --git a/src/twinkle/server/utils/processor_manager.py b/src/twinkle/server/utils/processor_manager.py new file mode 100644 index 00000000..df289b39 --- /dev/null +++ b/src/twinkle/server/utils/processor_manager.py @@ -0,0 +1,195 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +""" +Processor Lifecycle Manager Mixin for Twinkle Server. + +Mirrors AdapterManagerMixin but adds a global per-token processor limit. +Sessions are tracked via session ID; processors expire when their session expires. +""" +from __future__ import annotations + +import threading +import time +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from twinkle.server.utils.state import ServerStateProxy + +from twinkle.utils.logger import get_logger + +logger = get_logger() + + +class ProcessorManagerMixin: + """Mixin for processor lifecycle management with session-based expiration. + + Mirrors AdapterManagerMixin with an additional per-token processor limit. 
+ + Inheriting classes should: + 1. Call _init_processor_manager() in __init__ + 2. Override _on_processor_expired() to handle cleanup + + Attributes: + _processor_timeout: Session inactivity timeout in seconds. + _per_token_processor_limit: Maximum active processors per user token. + """ + + # Type hint for state attribute that inheriting classes must provide + state: ServerStateProxy + + def _init_processor_manager( + self, + processor_timeout: float = 1800.0, + per_token_processor_limit: int = 20, + ) -> None: + """Initialize the processor manager. + + Args: + processor_timeout: Timeout in seconds to determine if a session is alive. + Default is 1800.0 (30 minutes). + per_token_processor_limit: Maximum active processors per user token. + Default is 20. + """ + self._processor_timeout = processor_timeout + self._per_token_processor_limit = per_token_processor_limit + + # processor_id -> {'token': str, 'session_id': str, 'created_at': float, 'expiring': bool} + self._processor_records: dict[str, dict[str, Any]] = {} + + self._processor_countdown_thread: threading.Thread | None = None + self._processor_countdown_running = False + + def register_processor(self, processor_id: str, token: str, session_id: str) -> None: + """Register a new processor for lifecycle tracking. + + Args: + processor_id: Unique identifier of the processor. + token: User token that owns this processor. + session_id: Session ID to associate with this processor. Must be non-empty. + + Raises: + ValueError: If session_id is None or empty. + RuntimeError: If the per-token processor limit has been reached. + """ + if not session_id: + raise ValueError(f'session_id must be provided when registering processor {processor_id}') + + current_count = sum(1 for info in self._processor_records.values() if info.get('token') == token) + if current_count >= self._per_token_processor_limit: + raise RuntimeError(f'Per-user processor limit ({self._per_token_processor_limit}) reached ' + f'for token {token[:8]}...') + + self._processor_records[processor_id] = { + 'token': token, + 'session_id': session_id, + 'created_at': time.time(), + 'expiring': False, + } + logger.debug(f'[ProcessorManager] Registered processor {processor_id} ' + f'for token {token[:8]}... (session: {session_id})') + + def unregister_processor(self, processor_id: str) -> bool: + """Unregister a processor from lifecycle tracking. + + Returns: + True if found and removed, False otherwise. + """ + if processor_id in self._processor_records: + info = self._processor_records.pop(processor_id) + token = info.get('token', '') + logger.debug(f'[ProcessorManager] Unregistered processor {processor_id} ' + f'for token {token[:8] if token else "unknown"}...') + return True + return False + + def get_processor_info(self, processor_id: str) -> dict[str, Any] | None: + """Get tracking info for a registered processor, or None if not found.""" + return self._processor_records.get(processor_id) + + def assert_processor_exists(self, processor_id: str) -> None: + """Assert a processor exists and is not expiring.""" + info = self._processor_records.get(processor_id) + assert processor_id and info is not None and not info.get('expiring'), \ + f'Processor {processor_id} not found' + + def _on_processor_expired(self, processor_id: str) -> None: + """Hook called when a processor's session expires. + + Must be overridden by inheriting classes. + + Raises: + NotImplementedError: If not overridden. 
+ """ + raise NotImplementedError(f'_on_processor_expired must be implemented by {self.__class__.__name__}') + + def _is_session_alive(self, session_id: str) -> bool: + """Check if a session is still alive via state proxy.""" + if not session_id: + return True + last_heartbeat = self.state.get_session_last_heartbeat(session_id) + if last_heartbeat is None: + return False + return (time.time() - last_heartbeat) < self._processor_timeout + + def _processor_countdown_loop(self) -> None: + """Background thread: checks session liveness and expires stale processors.""" + logger.debug(f'[ProcessorManager] Countdown thread started (session_timeout={self._processor_timeout}s)') + while self._processor_countdown_running: + try: + time.sleep(1) + + expired: list[tuple[str, str | None]] = [] + for processor_id, info in list(self._processor_records.items()): + if info.get('expiring'): + continue + session_id = info.get('session_id') + try: + session_alive = self._is_session_alive(session_id) + except Exception as e: + logger.warning(f'[ProcessorManager] Failed to check session liveness ' + f'for {processor_id}: {type(e).__name__}: {e}') + continue + + logger.debug(f'[ProcessorManager] Processor {processor_id} session check ' + f'(session_id={session_id}, session_alive={session_alive})') + if not session_alive: + info['expiring'] = True + expired.append((processor_id, session_id)) + + for processor_id, session_id in expired: + success = False + try: + self._on_processor_expired(processor_id) + logger.info(f'[ProcessorManager] Processor {processor_id} expired ' + f'(reason=session_expired, session={session_id})') + success = True + except Exception as e: + logger.warning(f'[ProcessorManager] Error while expiring processor {processor_id}: {e}') + finally: + if success: + self._processor_records.pop(processor_id, None) + else: + info = self._processor_records.get(processor_id) + if info is not None: + info['expiring'] = False + + except Exception as e: + logger.warning(f'[ProcessorManager] Error in countdown loop: {e}') + continue + + logger.debug('[ProcessorManager] Countdown thread stopped') + + def start_processor_countdown(self) -> None: + """Start the background countdown thread. Safe to call multiple times.""" + if not self._processor_countdown_running: + self._processor_countdown_running = True + self._processor_countdown_thread = threading.Thread(target=self._processor_countdown_loop, daemon=True) + self._processor_countdown_thread.start() + logger.debug('[ProcessorManager] Countdown thread started') + + def stop_processor_countdown(self) -> None: + """Stop the background countdown thread.""" + if self._processor_countdown_running: + self._processor_countdown_running = False + if self._processor_countdown_thread: + self._processor_countdown_thread.join(timeout=2.0) + logger.debug('[ProcessorManager] Countdown thread stopped') diff --git a/src/twinkle/server/utils/state/server_state.py b/src/twinkle/server/utils/state/server_state.py index 7588c65d..a70fdac5 100644 --- a/src/twinkle/server/utils/state/server_state.py +++ b/src/twinkle/server/utils/state/server_state.py @@ -71,7 +71,7 @@ def create_session(self, payload: dict[str, Any]) -> str: self._session_mgr.add(session_id, record) return session_id - def touch_session(self, session_id: str) -> bool: + async def touch_session(self, session_id: str) -> bool: """Update session heartbeat timestamp. 
Returns: @@ -154,7 +154,7 @@ def unregister_replica(self, replica_id: str) -> None: """ self._model_mgr.unregister_replica(replica_id) - def get_available_replica_ids(self, candidate_ids: list[str]) -> list[str]: + async def get_available_replica_ids(self, candidate_ids: list[str]) -> list[str]: """Return candidate replica IDs that have not reached their max_loras limit. Args: @@ -195,12 +195,12 @@ def get_sampling_session(self, sampling_session_id: str) -> dict[str, Any] | Non # ----- Future Management ----- - def get_future(self, request_id: str) -> dict[str, Any] | None: + async def get_future(self, request_id: str) -> dict[str, Any] | None: """Retrieve a stored future result as a plain dict.""" record = self._future_mgr.get(request_id) return record.model_dump() if record is not None else None - def store_future_status( + async def store_future_status( self, request_id: str, status: str, @@ -239,28 +239,6 @@ def store_future_status( queue_state_reason=queue_state_reason, ) - # ----- Config Management ----- - - def add_config(self, key: str, value: Any) -> None: - """Add or update a configuration value.""" - self._config_mgr.add(key, value) - - def add_or_get(self, key: str, value: Any) -> Any: - """Add a config value if the key does not exist; otherwise return the existing value.""" - return self._config_mgr.add_or_get(key, value) - - def get_config(self, key: str) -> Any | None: - """Get a configuration value by key.""" - return self._config_mgr.get(key) - - def pop_config(self, key: str) -> Any | None: - """Remove and return a configuration value.""" - return self._config_mgr.pop(key) - - def clear_config(self) -> None: - """Clear all configuration values.""" - self._config_mgr.clear() - # ----- Resource Cleanup ----- def cleanup_expired_resources(self) -> dict[str, int]: @@ -372,8 +350,8 @@ def __init__(self, actor_handle) -> None: def create_session(self, payload: dict[str, Any]) -> str: return ray.get(self._actor.create_session.remote(payload)) - def touch_session(self, session_id: str) -> bool: - return ray.get(self._actor.touch_session.remote(session_id)) + async def touch_session(self, session_id: str) -> bool: + return await self._actor.touch_session.remote(session_id) def get_session_last_heartbeat(self, session_id: str) -> float | None: return ray.get(self._actor.get_session_last_heartbeat.remote(session_id)) @@ -401,8 +379,8 @@ def register_replica(self, replica_id: str, max_loras: int) -> None: def unregister_replica(self, replica_id: str) -> None: ray.get(self._actor.unregister_replica.remote(replica_id)) - def get_available_replica_ids(self, candidate_ids: list[str]) -> list[str]: - return ray.get(self._actor.get_available_replica_ids.remote(candidate_ids)) + async def get_available_replica_ids(self, candidate_ids: list[str]) -> list[str]: + return await self._actor.get_available_replica_ids.remote(candidate_ids) # ----- Sampling Session Management ----- @@ -414,10 +392,10 @@ def get_sampling_session(self, sampling_session_id: str) -> dict[str, Any] | Non # ----- Future Management ----- - def get_future(self, request_id: str) -> dict[str, Any] | None: - return ray.get(self._actor.get_future.remote(request_id)) + async def get_future(self, request_id: str) -> dict[str, Any] | None: + return await self._actor.get_future.remote(request_id) - def store_future_status( + async def store_future_status( self, request_id: str, status: str, @@ -428,26 +406,8 @@ def store_future_status( queue_state_reason: str | None = None, ) -> None: """Store task status with optional result 
(asynchronous)."""
-        ray.get(
-            self._actor.store_future_status.remote(request_id, status, model_id, reason, result, queue_state,
-                                                   queue_state_reason))
-
-    # ----- Config Management -----
-
-    def add_config(self, key: str, value: Any):
-        return ray.get(self._actor.add_config.remote(key, value))
-
-    def add_or_get(self, key: str, value: Any) -> Any:
-        return ray.get(self._actor.add_or_get.remote(key, value))
-
-    def get_config(self, key: str) -> Any | None:
-        return ray.get(self._actor.get_config.remote(key))
-
-    def pop_config(self, key: str) -> Any | None:
-        return ray.get(self._actor.pop_config.remote(key))
-
-    def clear_config(self):
-        return ray.get(self._actor.clear_config.remote())
+        await self._actor.store_future_status.remote(request_id, status, model_id, reason, result, queue_state,
+                                                     queue_state_reason)
 
     # ----- Resource Cleanup -----
 
diff --git a/src/twinkle/server/utils/task_queue.py b/src/twinkle/server/utils/task_queue.py
index 39511659..d0985c15 100644
--- a/src/twinkle/server/utils/task_queue.py
+++ b/src/twinkle/server/utils/task_queue.py
@@ -255,7 +255,7 @@ async def _queue_worker(self) -> None:
                         'error': f'Queue timeout exceeded: waited {now - task.created_at:.2f}s',
                         'category': 'Server'
                     }
-                    self.state.store_future_status(
+                    await self.state.store_future_status(
                         task.request_id,
                         TaskStatus.FAILED.value,
                         task.model_id,
@@ -270,13 +270,13 @@ async def _queue_worker(self) -> None:
 
                 # Execute
                 executed_any = True
-                self.state.store_future_status(
+                await self.state.store_future_status(
                     task.request_id, TaskStatus.RUNNING.value, task.model_id, queue_state=QueueState.ACTIVE.value)
                 try:
                     coro = task.coro_factory()
                     result = await coro
-                    self.state.store_future_status(
+                    await self.state.store_future_status(
                         task.request_id,
                         TaskStatus.COMPLETED.value,
                         task.model_id,
@@ -284,7 +284,7 @@ async def _queue_worker(self) -> None:
                 except Exception:
                     error_payload = {'error': traceback.format_exc(), 'category': 'Server'}
-                    self.state.store_future_status(
+                    await self.state.store_future_status(
                         task.request_id,
                         TaskStatus.FAILED.value,
                         task.model_id,
@@ -321,7 +321,7 @@ async def _fail_queue_tasks_async(self, queue_key: str, reason: str) -> None:
 
         for task in drained:
             error_payload = {'error': reason, 'category': 'Server'}
-            self.state.store_future_status(
+            await self.state.store_future_status(
                 task.request_id,
                 TaskStatus.FAILED.value,
                 task.model_id,
@@ -381,7 +381,7 @@ async def _perform_preflight_checks(
         if input_tokens > self._task_queue_config.max_input_tokens:
             error_msg = f'Input tokens ({input_tokens}) exceed maximum allowed ({self._task_queue_config.max_input_tokens})'  # noqa: E501
             error_payload = {'error': error_msg, 'category': 'User'}
-            self.state.store_future_status(
+            await self.state.store_future_status(
                 request_id,
                 TaskStatus.FAILED.value,
                 model_id,
@@ -396,7 +396,7 @@ async def _perform_preflight_checks(
         if batch_size < data_world_size:
             error_msg = f'Batch size {batch_size} must be greater than or equal to data world size {data_world_size}'  # noqa: E501
             error_payload = {'error': error_msg, 'category': 'User'}
-            self.state.store_future_status(
+            await self.state.store_future_status(
                 request_id,
                 TaskStatus.FAILED.value,
                 model_id,
@@ -411,7 +411,7 @@ async def _perform_preflight_checks(
         if not allowed:
             error_msg = f'Rate limit exceeded: {reason}'
             error_payload = {'error': error_msg, 'category': 'User'}
-            self.state.store_future_status(
+            await self.state.store_future_status(
                 request_id,
                 TaskStatus.FAILED.value,
                 model_id,
@@ -475,7 +475,7 @@
async def schedule_task(
         )
 
         # 2. Register PENDING status FIRST
-        self.state.store_future_status(
+        await self.state.store_future_status(
             request_id, TaskStatus.PENDING.value, model_id, queue_state=QueueState.ACTIVE.value)
 
         # 3. Route to per-model/per-token queue
@@ -500,7 +500,7 @@ async def schedule_task(
             task_type=task_type,
             created_at=time.monotonic(),
         ))
-        self.state.store_future_status(
+        await self.state.store_future_status(
             request_id, TaskStatus.QUEUED.value, model_id, queue_state=QueueState.ACTIVE.value)
         logger.debug(f'[TaskQueue] Task {request_id} queued, new queue size: {q.qsize()} key={queue_key}')
@@ -544,6 +544,58 @@ def get_rate_limiter_memory_stats(self) -> dict[str, Any]:
         """
         return self._rate_limiter.get_memory_stats()
 
+    async def schedule_task_and_wait(
+        self,
+        coro_factory: Callable[[], Coroutine],
+        model_id: str | None = None,
+        token: str | None = None,
+        input_tokens: int = 0,
+        task_type: str | None = None,
+    ) -> Any:
+        """Schedule an async task and wait for its result.
+
+        This is the twinkle-side counterpart to :meth:`schedule_task`.
+        It enqueues the task through the same serial worker, then polls
+        (via async sleep) until the task completes, and returns the result
+        directly instead of a future reference dict.
+
+        Args:
+            coro_factory: Factory that creates the coroutine to execute.
+            model_id: Optional model_id to associate with the result.
+            token: Optional user token for rate limiting.
+            input_tokens: Number of input tokens for tps rate limiting.
+            task_type: Optional task type for logging/observability.
+
+        Returns:
+            The direct return value of the coroutine.
+
+        Raises:
+            RuntimeError: If the task fails.
+        """
+        future_ref = await self.schedule_task(
+            coro_factory,
+            model_id=model_id,
+            token=token,
+            input_tokens=input_tokens,
+            task_type=task_type,
+        )
+        request_id = future_ref.get('request_id')
+        if request_id is None:
+            # Pre-flight check failed; surface the error from the stored future
+            raise RuntimeError(f'Task scheduling failed: {future_ref}')
+
+        while True:
+            record = await self.state.get_future(request_id)
+            if record and record.get('status') not in ('pending', 'queued', 'running'):
+                break
+            await asyncio.sleep(0.05)
+
+        if record['status'] == 'failed':
+            error = (record.get('result') or {}).get('error', 'Unknown error')
+            raise RuntimeError(error)
+
+        return record['result']
+
     async def shutdown_task_queue(self) -> None:
         """Gracefully shutdown the task queue and cleanup tasks.
 
diff --git a/src/twinkle/server/utils/validation.py b/src/twinkle/server/utils/validation.py
index 23539ed8..96a1f33a 100644
--- a/src/twinkle/server/utils/validation.py
+++ b/src/twinkle/server/utils/validation.py
@@ -32,6 +32,7 @@ async def verify_request_token(request: Request, call_next):
                 status_code=400, content={'detail': 'Missing X-Ray-Serve-Request-Id header, required for sticky session'})
     request.state.request_id = request_id
     request.state.token = token
+    request.state.session_id = request.headers.get('X-Twinkle-Session-Id') or ''
     response = await call_next(request)
     return response
@@ -63,3 +64,16 @@ def get_token_from_request(request: Request) -> str:
         The extracted token or empty string if not found
     """
     return getattr(request.state, 'token', '') or ''
+
+
+def get_session_id_from_request(request: Request) -> str:
+    """
+    Extract session ID from request.
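The key addition here is `schedule_task_and_wait`, which trades the `{'request_id': ...}` future reference for an in-coroutine poll of `get_future`. A runnable sketch of that polling contract against a stand-in store (`FakeState` and `wait_for` are illustrative, not part of this diff):

```python
import asyncio


class FakeState:
    """Stand-in future store exposing the get_future/store_future_status shape."""

    def __init__(self) -> None:
        self._futures: dict[str, dict] = {}

    async def store_future_status(self, request_id: str, status: str, result=None) -> None:
        self._futures[request_id] = {'status': status, 'result': result}

    async def get_future(self, request_id: str) -> dict | None:
        return self._futures.get(request_id)


async def wait_for(state: FakeState, request_id: str):
    # Same loop as schedule_task_and_wait: spin on get_future until the record
    # leaves the pending/queued/running states, then unwrap the result or raise.
    while True:
        record = await state.get_future(request_id)
        if record and record.get('status') not in ('pending', 'queued', 'running'):
            break
        await asyncio.sleep(0.05)
    if record['status'] == 'failed':
        raise RuntimeError((record.get('result') or {}).get('error', 'Unknown error'))
    return record['result']


async def main() -> None:
    state = FakeState()
    await state.store_future_status('req-1', 'queued')

    async def worker() -> None:  # stands in for the serial queue worker
        await asyncio.sleep(0.1)
        await state.store_future_status('req-1', 'completed', result={'loss': 0.1})

    asyncio.create_task(worker())
    print(await wait_for(state, 'req-1'))  # -> {'loss': 0.1}


asyncio.run(main())
```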
+ + Args: + request: The FastAPI Request object + + Returns: + The extracted session ID or empty string if not found + """ + return getattr(request.state, 'session_id', '') or '' diff --git a/src/twinkle_client/__init__.py b/src/twinkle_client/__init__.py index 58c43a37..f41a83ce 100644 --- a/src/twinkle_client/__init__.py +++ b/src/twinkle_client/__init__.py @@ -1,10 +1,13 @@ # Copyright (c) ModelScope Contributors. All rights reserved. from __future__ import annotations +from typing import Optional + def init_tinker_client(**kwargs) -> None: """Initialize Tinker client with Twinkle-specific headers. - After calling this function, users can directly use: + After calling this function, users can directly use:: + from tinker import ServiceClient client = ServiceClient(base_url='...', api_key='...') @@ -13,39 +16,57 @@ def init_tinker_client(**kwargs) -> None: Args: **kwargs: Additional keyword arguments (currently unused, reserved for future) - Example: - >>> from twinkle import init_tinker_client + Example:: + + >>> from twinkle_client import init_tinker_client >>> init_tinker_client() >>> from tinker import ServiceClient >>> client = ServiceClient(base_url='http://localhost:8000', api_key='your_token') """ from twinkle.utils import requires - + requires('tinker') from twinkle_client.utils.patch_tinker import patch_tinker - # Apply patches to tinker library (includes header injection) patch_tinker() -def init_twinkle_client(base_url: str | None = None, api_key: str | None = None, **kwargs) -> TwinkleClient: +def init_twinkle_client( + base_url: Optional[str] = None, + api_key: Optional[str] = None, + session_heartbeat_interval: int = 10, + **kwargs, +) -> 'TwinkleClient': """ - Initialize a Twinkle client and setup context variables. + Initialize a Twinkle client. + + This function: + + * Resolves ``base_url`` and ``api_key`` (env-vars as fallbacks). + * Sets both values into the shared context so that all other client objects + (``MultiLoraTransformersModel``, ``vLLMSampler``, processor clients) created + afterwards automatically inherit the same server configuration. + * Creates a server-side session and stores the ``session_id`` in context so + every subsequent HTTP request carries it in ``X-Twinkle-Session-Id``. + * Starts a background thread that touches the session every + ``session_heartbeat_interval`` seconds. + + Args: + base_url: Twinkle server base URL. Falls back to ``TWINKLE_SERVER_URL``. + api_key: Authentication token. Falls back to ``TWINKLE_SERVER_TOKEN``. + session_heartbeat_interval: Seconds between session touch calls (default: 10). + **kwargs: Additional keyword arguments forwarded to :class:`TwinkleClient`. + + Returns: + An initialised :class:`~twinkle_client.manager.TwinkleClient` instance. 
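Assumed end-to-end bootstrap under the new session model (URL and token are placeholders):

```python
from twinkle_client import init_twinkle_client

client = init_twinkle_client(
    base_url='http://localhost:8000',  # get_base_url() normalizes this to .../api/v1
    api_key='your_token',
    session_heartbeat_interval=10,     # background thread touches the session every 10 s
)
print(client.health_check())
# Objects created from here on (models, samplers, dataloaders) inherit the same
# base_url/api_key and send X-Twinkle-Session-Id on every request.
client.close()  # stops the heartbeat thread; also registered via atexit
```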
""" - from .http.utils import get_api_key, get_base_url, set_api_key, set_base_url - from .manager import TwinkleClient, TwinkleClientError - - if base_url is not None: - set_base_url(base_url) - else: - base_url = get_base_url() - - if api_key is not None: - set_api_key(api_key) - else: - api_key = get_api_key() - - return TwinkleClient(base_url=base_url, api_key=api_key, **kwargs) + from .manager import TwinkleClient + return TwinkleClient( + base_url=base_url, + api_key=api_key, + session_heartbeat_interval=session_heartbeat_interval, + **kwargs, + ) __all__ = ['init_tinker_client', 'init_twinkle_client'] diff --git a/src/twinkle_client/dataloader/dataloader.py b/src/twinkle_client/dataloader/dataloader.py index 3cd2b564..0a067ddd 100644 --- a/src/twinkle_client/dataloader/dataloader.py +++ b/src/twinkle_client/dataloader/dataloader.py @@ -10,7 +10,7 @@ # ============================================================================ from typing import Callable, Type, Union -from twinkle_client.http import http_post, heartbeat_manager +from twinkle_client.http import http_post from twinkle.dataset import Dataset from twinkle.processor import InputProcessor @@ -19,10 +19,10 @@ class DataLoader(object): def __init__(self, dataset: Union[Dataset, Callable], **kwargs): from twinkle_client.http import get_base_url - self.server_url = get_base_url() + self.server_url = f'{get_base_url()}/processor/twinkle' response = http_post( - url=f'{self.server_url}/processors/create', + url=f'{self.server_url}/create', json_data={ 'processor_type': 'dataloader', 'class_type': 'DataLoader', @@ -31,18 +31,11 @@ def __init__(self, dataset: Union[Dataset, Callable], **kwargs): ) response.raise_for_status() self.processor_id = response.json()['processor_id'] - heartbeat_manager.register_processor(self.processor_id) - - def __del__(self): - try: - heartbeat_manager.unregister_processor(self.processor_id) - except: - pass def __len__(self): response = http_post( - url=f'{self.server_url}/processors/call', + url=f'{self.server_url}/call', json_data={ 'processor_id': self.processor_id, 'function': '__len__', @@ -55,7 +48,7 @@ def __len__(self): def set_processor(self, processor_cls: Union[Type[InputProcessor], str, InputProcessor, Callable], **kwargs): response = http_post( - url=f'{self.server_url}/processors/call', + url=f'{self.server_url}/call', json_data={ 'processor_id': self.processor_id, 'function': 'set_processor', @@ -69,7 +62,7 @@ def set_processor(self, processor_cls: Union[Type[InputProcessor], str, InputPro def __iter__(self): response = http_post( - url=f'{self.server_url}/processors/call', + url=f'{self.server_url}/call', json_data={ 'processor_id': self.processor_id, 'function': '__iter__', @@ -81,7 +74,7 @@ def __iter__(self): def __next__(self): response = http_post( - url=f'{self.server_url}/processors/call', + url=f'{self.server_url}/call', json_data={ 'processor_id': self.processor_id, 'function': '__next__', diff --git a/src/twinkle_client/dataset/base.py b/src/twinkle_client/dataset/base.py index 3d5b5062..0487f733 100644 --- a/src/twinkle_client/dataset/base.py +++ b/src/twinkle_client/dataset/base.py @@ -10,7 +10,7 @@ # ============================================================================ from typing import Any, Callable, Dict, Type, Union -from twinkle_client.http import http_post, heartbeat_manager +from twinkle_client.http import http_post from twinkle.dataset import Dataset from twinkle.dataset import DatasetMeta from twinkle.preprocessor import DataFilter @@ -22,10 +22,10 @@ 
class Dataset(object): def __init__(self, dataset_meta: DatasetMeta, **kwargs): from twinkle_client.http import get_base_url - self.server_url = get_base_url() + self.server_url = f'{get_base_url()}/processor/twinkle' response = http_post( - url=f'{self.server_url}/processors/create', + url=f'{self.server_url}/create', json_data={ 'processor_type': 'dataset', 'class_type': 'Dataset', @@ -34,18 +34,11 @@ def __init__(self, dataset_meta: DatasetMeta, **kwargs): ) response.raise_for_status() self.processor_id = response.json()['processor_id'] - heartbeat_manager.register_processor(self.processor_id) - - def __del__(self): - try: - heartbeat_manager.unregister_processor(self.processor_id) - except: - pass def set_template(self, template_func: Union[Template, Type[Template], str], **kwargs): response = http_post( - url=f'{self.server_url}/processors/call', + url=f'{self.server_url}/call', json_data={ 'processor_id': self.processor_id, 'function': 'set_template', @@ -59,7 +52,7 @@ def set_template(self, template_func: Union[Template, Type[Template], str], **kw def encode(self, add_generation_prompt: bool = False, **kwargs): response = http_post( - url=f'{self.server_url}/processors/call', + url=f'{self.server_url}/call', json_data={ 'processor_id': self.processor_id, 'function': 'encode', @@ -73,7 +66,7 @@ def encode(self, add_generation_prompt: bool = False, **kwargs): def check(self, **kwargs): response = http_post( - url=f'{self.server_url}/processors/call', + url=f'{self.server_url}/call', json_data={ 'processor_id': self.processor_id, 'function': 'check', @@ -87,7 +80,7 @@ def check(self, **kwargs): def map(self, preprocess_func: Union[Preprocessor, Callable, str, Type[Preprocessor]], dataset_meta: DatasetMeta = None, init_args: Dict[str, Any] = None, **kwargs): response = http_post( - url=f'{self.server_url}/processors/call', + url=f'{self.server_url}/call', json_data={ 'processor_id': self.processor_id, 'function': 'map', @@ -101,7 +94,7 @@ def map(self, preprocess_func: Union[Preprocessor, Callable, str, Type[Preproces def filter(self, filter_func: Union[Callable, str, Type[DataFilter], DataFilter], dataset_meta: DatasetMeta = None, init_args: Dict[str, Any] = None, **kwargs): response = http_post( - url=f'{self.server_url}/processors/call', + url=f'{self.server_url}/call', json_data={ 'processor_id': self.processor_id, 'function': 'filter', @@ -115,7 +108,7 @@ def filter(self, filter_func: Union[Callable, str, Type[DataFilter], DataFilter] def add_dataset(self, dataset_meta: DatasetMeta, **kwargs): response = http_post( - url=f'{self.server_url}/processors/call', + url=f'{self.server_url}/call', json_data={ 'processor_id': self.processor_id, 'function': 'add_dataset', @@ -129,7 +122,7 @@ def add_dataset(self, dataset_meta: DatasetMeta, **kwargs): def mix_dataset(self, interleave = True): response = http_post( - url=f'{self.server_url}/processors/call', + url=f'{self.server_url}/call', json_data={ 'processor_id': self.processor_id, 'function': 'mix_dataset', @@ -142,7 +135,7 @@ def mix_dataset(self, interleave = True): def __getitem__(self, idx): response = http_post( - url=f'{self.server_url}/processors/call', + url=f'{self.server_url}/call', json_data={ 'processor_id': self.processor_id, 'function': '__getitem__', @@ -155,7 +148,7 @@ def __getitem__(self, idx): def __len__(self): response = http_post( - url=f'{self.server_url}/processors/call', + url=f'{self.server_url}/call', json_data={ 'processor_id': self.processor_id, 'function': '__len__', diff --git 
a/src/twinkle_client/dataset/iterable_dataset.py b/src/twinkle_client/dataset/iterable_dataset.py index 347d1012..25c48919 100644 --- a/src/twinkle_client/dataset/iterable_dataset.py +++ b/src/twinkle_client/dataset/iterable_dataset.py @@ -9,7 +9,7 @@ # 2. Run: python client_tools/client_generator.py # ============================================================================ -from twinkle_client.http import http_post, heartbeat_manager +from twinkle_client.http import http_post from twinkle.dataset import Dataset from twinkle.dataset import DatasetMeta from torch.utils.data import IterableDataset @@ -19,10 +19,10 @@ class IterableDataset(IterableDataset): def __init__(self, dataset_meta: DatasetMeta, **kwargs): from twinkle_client.http import get_base_url - self.server_url = get_base_url() + self.server_url = f'{get_base_url()}/processor/twinkle' response = http_post( - url=f'{self.server_url}/processors/create', + url=f'{self.server_url}/create', json_data={ 'processor_type': 'dataset', 'class_type': 'IterableDataset', @@ -31,18 +31,11 @@ def __init__(self, dataset_meta: DatasetMeta, **kwargs): ) response.raise_for_status() self.processor_id = response.json()['processor_id'] - heartbeat_manager.register_processor(self.processor_id) - - def __del__(self): - try: - heartbeat_manager.unregister_processor(self.processor_id) - except: - pass def add_dataset(self, dataset_meta: DatasetMeta, **kwargs): response = http_post( - url=f'{self.server_url}/processors/call', + url=f'{self.server_url}/call', json_data={ 'processor_id': self.processor_id, 'function': 'add_dataset', @@ -56,7 +49,7 @@ def add_dataset(self, dataset_meta: DatasetMeta, **kwargs): def __len__(self): response = http_post( - url=f'{self.server_url}/processors/call', + url=f'{self.server_url}/call', json_data={ 'processor_id': self.processor_id, 'function': '__len__', @@ -69,7 +62,7 @@ def __len__(self): def __getitem__(self, idx): response = http_post( - url=f'{self.server_url}/processors/call', + url=f'{self.server_url}/call', json_data={ 'processor_id': self.processor_id, 'function': '__getitem__', @@ -82,7 +75,7 @@ def __getitem__(self, idx): def __iter__(self): response = http_post( - url=f'{self.server_url}/processors/call', + url=f'{self.server_url}/call', json_data={ 'processor_id': self.processor_id, 'function': '__iter__', @@ -94,7 +87,7 @@ def __iter__(self): def __next__(self): response = http_post( - url=f'{self.server_url}/processors/call', + url=f'{self.server_url}/call', json_data={ 'processor_id': self.processor_id, 'function': '__next__', diff --git a/src/twinkle_client/dataset/iterable_packing_dataset.py b/src/twinkle_client/dataset/iterable_packing_dataset.py index ce2d918d..12a958a4 100644 --- a/src/twinkle_client/dataset/iterable_packing_dataset.py +++ b/src/twinkle_client/dataset/iterable_packing_dataset.py @@ -10,7 +10,7 @@ # ============================================================================ from typing import Type, Union -from twinkle_client.http import http_post, heartbeat_manager +from twinkle_client.http import http_post from twinkle.dataset import Dataset from twinkle.dataset import DatasetMeta from twinkle.template import Template @@ -21,10 +21,10 @@ class IterablePackingDataset(IterableDataset): def __init__(self, dataset_meta: DatasetMeta, packing_interval: int = 128, packing_num_proc: int = 1, cyclic: bool = False, **kwargs): from twinkle_client.http import get_base_url - self.server_url = get_base_url() + self.server_url = f'{get_base_url()}/processor/twinkle' response = http_post( - 
url=f'{self.server_url}/processors/create', + url=f'{self.server_url}/create', json_data={ 'processor_type': 'dataset', 'class_type': 'IterablePackingDataset', @@ -33,18 +33,11 @@ def __init__(self, dataset_meta: DatasetMeta, packing_interval: int = 128, packi ) response.raise_for_status() self.processor_id = response.json()['processor_id'] - heartbeat_manager.register_processor(self.processor_id) - - def __del__(self): - try: - heartbeat_manager.unregister_processor(self.processor_id) - except: - pass def set_template(self, template_cls: Union[Type[Template], str, Template], **kwargs): response = http_post( - url=f'{self.server_url}/processors/call', + url=f'{self.server_url}/call', json_data={ 'processor_id': self.processor_id, 'function': 'set_template', @@ -58,7 +51,7 @@ def set_template(self, template_cls: Union[Type[Template], str, Template], **kwa def pack_dataset(self): response = http_post( - url=f'{self.server_url}/processors/call', + url=f'{self.server_url}/call', json_data={ 'processor_id': self.processor_id, 'function': 'pack_dataset', @@ -71,7 +64,7 @@ def pack_dataset(self): def __iter__(self): response = http_post( - url=f'{self.server_url}/processors/call', + url=f'{self.server_url}/call', json_data={ 'processor_id': self.processor_id, 'function': '__iter__', @@ -83,7 +76,7 @@ def __iter__(self): def __next__(self): response = http_post( - url=f'{self.server_url}/processors/call', + url=f'{self.server_url}/call', json_data={ 'processor_id': self.processor_id, 'function': '__next__', diff --git a/src/twinkle_client/dataset/lazy_dataset.py b/src/twinkle_client/dataset/lazy_dataset.py index ce8178b1..62b13dea 100644 --- a/src/twinkle_client/dataset/lazy_dataset.py +++ b/src/twinkle_client/dataset/lazy_dataset.py @@ -9,7 +9,7 @@ # 2. Run: python client_tools/client_generator.py # ============================================================================ -from twinkle_client.http import http_post, heartbeat_manager +from twinkle_client.http import http_post from twinkle.dataset import Dataset from twinkle.dataset import DatasetMeta from .base import Dataset @@ -19,10 +19,10 @@ class LazyDataset(Dataset): def __init__(self, dataset_meta: DatasetMeta, **kwargs): from twinkle_client.http import get_base_url - self.server_url = get_base_url() + self.server_url = f'{get_base_url()}/processor/twinkle' response = http_post( - url=f'{self.server_url}/processors/create', + url=f'{self.server_url}/create', json_data={ 'processor_type': 'dataset', 'class_type': 'LazyDataset', @@ -31,18 +31,11 @@ def __init__(self, dataset_meta: DatasetMeta, **kwargs): ) response.raise_for_status() self.processor_id = response.json()['processor_id'] - heartbeat_manager.register_processor(self.processor_id) - - def __del__(self): - try: - heartbeat_manager.unregister_processor(self.processor_id) - except: - pass def encode(self, **kwargs): response = http_post( - url=f'{self.server_url}/processors/call', + url=f'{self.server_url}/call', json_data={ 'processor_id': self.processor_id, 'function': 'encode', @@ -56,7 +49,7 @@ def encode(self, **kwargs): def check(self, **kwargs): response = http_post( - url=f'{self.server_url}/processors/call', + url=f'{self.server_url}/call', json_data={ 'processor_id': self.processor_id, 'function': 'check', @@ -70,7 +63,7 @@ def check(self, **kwargs): def __getitem__(self, idx): response = http_post( - url=f'{self.server_url}/processors/call', + url=f'{self.server_url}/call', json_data={ 'processor_id': self.processor_id, 'function': '__getitem__', @@ -83,7 +76,7 @@ def 
__getitem__(self, idx): def __len__(self): response = http_post( - url=f'{self.server_url}/processors/call', + url=f'{self.server_url}/call', json_data={ 'processor_id': self.processor_id, 'function': '__len__', diff --git a/src/twinkle_client/dataset/packing_dataset.py b/src/twinkle_client/dataset/packing_dataset.py index 0d91546f..dd901d1d 100644 --- a/src/twinkle_client/dataset/packing_dataset.py +++ b/src/twinkle_client/dataset/packing_dataset.py @@ -9,7 +9,7 @@ # 2. Run: python client_tools/client_generator.py # ============================================================================ -from twinkle_client.http import http_post, heartbeat_manager +from twinkle_client.http import http_post from twinkle.dataset import Dataset from twinkle.dataset import DatasetMeta from .base import Dataset @@ -19,10 +19,10 @@ class PackingDataset(Dataset): def __init__(self, dataset_meta: DatasetMeta, packing_num_proc: int = 1, **kwargs): from twinkle_client.http import get_base_url - self.server_url = get_base_url() + self.server_url = f'{get_base_url()}/processor/twinkle' response = http_post( - url=f'{self.server_url}/processors/create', + url=f'{self.server_url}/create', json_data={ 'processor_type': 'dataset', 'class_type': 'PackingDataset', @@ -31,18 +31,11 @@ def __init__(self, dataset_meta: DatasetMeta, packing_num_proc: int = 1, **kwarg ) response.raise_for_status() self.processor_id = response.json()['processor_id'] - heartbeat_manager.register_processor(self.processor_id) - - def __del__(self): - try: - heartbeat_manager.unregister_processor(self.processor_id) - except: - pass def pack_dataset(self): response = http_post( - url=f'{self.server_url}/processors/call', + url=f'{self.server_url}/call', json_data={ 'processor_id': self.processor_id, 'function': 'pack_dataset', @@ -55,7 +48,7 @@ def pack_dataset(self): def __getitem__(self, index): response = http_post( - url=f'{self.server_url}/processors/call', + url=f'{self.server_url}/call', json_data={ 'processor_id': self.processor_id, 'function': '__getitem__', @@ -68,7 +61,7 @@ def __getitem__(self, index): def __len__(self): response = http_post( - url=f'{self.server_url}/processors/call', + url=f'{self.server_url}/call', json_data={ 'processor_id': self.processor_id, 'function': '__len__', diff --git a/src/twinkle_client/http/__init__.py b/src/twinkle_client/http/__init__.py index 39bedf71..63880a7f 100644 --- a/src/twinkle_client/http/__init__.py +++ b/src/twinkle_client/http/__init__.py @@ -1,7 +1,7 @@ from .heartbeat import heartbeat_manager from .http_utils import http_delete, http_get, http_post -from .utils import (TWINKLE_SERVER_TOKEN, TWINKLE_SERVER_URL, clear_api_key, clear_base_url, clear_request_id, - get_api_key, get_base_url, get_request_id, set_api_key, set_base_url, set_request_id) +from .utils import (TWINKLE_SERVER_TOKEN, TWINKLE_SERVER_URL, get_api_key, get_base_url, get_request_id, + get_session_id, set_api_key, set_base_url, set_request_id, set_session_id) __all__ = [ 'http_get', @@ -12,11 +12,10 @@ 'TWINKLE_SERVER_TOKEN', 'set_base_url', 'get_base_url', - 'clear_base_url', 'set_api_key', 'get_api_key', - 'clear_api_key', + 'set_session_id', + 'get_session_id', 'set_request_id', 'get_request_id', - 'clear_request_id', ] diff --git a/src/twinkle_client/http/heartbeat.py b/src/twinkle_client/http/heartbeat.py index 4a42f75a..5194d75b 100644 --- a/src/twinkle_client/http/heartbeat.py +++ b/src/twinkle_client/http/heartbeat.py @@ -4,7 +4,7 @@ from typing import Callable, Dict, Optional, Set from .http_utils import 
http_post -from .utils import TWINKLE_SERVER_URL +from .utils import get_base_url class HeartbeatManager: @@ -33,7 +33,6 @@ def __init__(self): return self._initialized = True - self.server_url = TWINKLE_SERVER_URL # Processor heartbeat management self.processor_ids: Set[str] = set() @@ -52,7 +51,7 @@ def __init__(self): def processor_heartbeat_func(self, processor_id_list: str): response = http_post( - url=f'{self.server_url}/processors/heartbeat', json_data={'processor_id': processor_id_list}) + url=f'{get_base_url()}/processor/twinkle/heartbeat', json_data={'processor_id': processor_id_list}) response.raise_for_status() def register_processor(self, processor_id: str): diff --git a/src/twinkle_client/http/http_utils.py b/src/twinkle_client/http/http_utils.py index 522b46af..70001f7e 100644 --- a/src/twinkle_client/http/http_utils.py +++ b/src/twinkle_client/http/http_utils.py @@ -1,8 +1,7 @@ import requests -from numbers import Number from typing import Any, Callable, Dict, Mapping, Optional -from .utils import get_api_key, get_base_url, get_request_id +from .utils import get_api_key, get_base_url, get_request_id, get_session_id def _build_headers(additional_headers: Optional[Dict[str, str]] = None) -> Dict[str, str]: @@ -17,10 +16,14 @@ def _build_headers(additional_headers: Optional[Dict[str, str]] = None) -> Dict[ """ headers = { 'X-Ray-Serve-Request-Id': get_request_id(), + 'serve_multiplexed_model_id': get_request_id(), # For model multiplexing 'Authorization': 'Bearer ' + get_api_key(), 'Twinkle-Authorization': 'Bearer ' + get_api_key(), # For server compatibility } + if session_id := get_session_id(): + headers['X-Twinkle-Session-Id'] = session_id + if additional_headers: headers.update(additional_headers) @@ -42,7 +45,7 @@ def _serialize_params(params: Dict[str, Any]) -> Dict[str, Any]: if hasattr(value, 'processor_id'): serialized[key] = value.processor_id elif hasattr(value, '__dict__'): - from twinkle.server.twinkle.common.serialize import serialize_object + from twinkle.server.common.serialize import serialize_object serialized[key] = serialize_object(value) else: serialized[key] = value @@ -61,12 +64,26 @@ def _handle_response(response: requests.Response) -> requests.Response: Raises: StopIteration: When server returns HTTP 410 (iterator exhausted) + requests.HTTPError: When server returns a 4xx/5xx error, with the + server-side ``detail`` field (full traceback) included in the + exception message so callers don't need to inspect the response body. 
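Caller-side view of the enriched error handling, as a sketch (the endpoint URL is a placeholder and a reachable server is assumed):

```python
import requests

from twinkle_client.http import http_post

try:
    http_post(url='http://localhost:8000/api/v1/processor/twinkle/call', json_data={})
except StopIteration:
    pass      # HTTP 410: the server-side iterator is exhausted
except requests.HTTPError as e:
    print(e)  # '<status> Error for url: ...\nServer detail:\n<traceback>'
    assert e.response is not None  # the original Response object is still attached
```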
""" # Convert HTTP 410 Gone to StopIteration # This indicates an iterator has been exhausted if response.status_code == 410: raise StopIteration(response.json().get('detail', 'Iterator exhausted')) + if not response.ok: + try: + detail = response.json().get('detail', response.text) + except Exception: + detail = response.text + http_error_msg = ( + f'{response.status_code} Error for url: {response.url}\n' + f'Server detail:\n{detail}' + ) + raise requests.HTTPError(http_error_msg, response=response) + return response diff --git a/src/twinkle_client/http/utils.py b/src/twinkle_client/http/utils.py index ad49ffe1..f5b34835 100644 --- a/src/twinkle_client/http/utils.py +++ b/src/twinkle_client/http/utils.py @@ -1,68 +1,64 @@ import os import uuid -from contextvars import ContextVar from datetime import datetime from typing import Optional TWINKLE_SERVER_URL = os.environ.get('TWINKLE_SERVER_URL', 'http://127.0.0.1:8000') TWINKLE_SERVER_TOKEN = os.environ.get('TWINKLE_SERVER_TOKEN', 'EMPTY_TOKEN') -# Context variables for flexible configuration -_base_url_context: ContextVar[Optional[str]] = ContextVar('base_url', default=None) -_api_key_context: ContextVar[Optional[str]] = ContextVar('api_key', default=None) - -# Global static request ID shared across all threads -# This ensures heartbeat threads use the same request ID as the main training thread -_global_request_id: Optional[str] = None +# Global variables for configuration +_base_url: Optional[str] = None +_api_key: Optional[str] = None +_session_id: Optional[str] = None +_request_id: Optional[str] = None def set_base_url(url: str): - """Set the base URL for HTTP requests in the current context.""" - _base_url_context.set(url.rstrip('/')) - - -def get_base_url() -> Optional[str]: - """Get the current base URL from context or environment variable.""" - return _base_url_context.get() or TWINKLE_SERVER_URL + """Set the base URL for HTTP requests.""" + global _base_url + _base_url = url.rstrip('/') -def clear_base_url(): - """Clear the base URL context, falling back to environment variable.""" - _base_url_context.set(None) +def get_base_url() -> str: + """Get the current base URL.""" + base_url = _base_url or TWINKLE_SERVER_URL + if not base_url.endswith('/api/v1'): + base_url += '/api/v1' + return base_url def set_api_key(api_key: str): - """Set the API key for HTTP requests in the current context.""" - _api_key_context.set(api_key) + """Set the API key for HTTP requests.""" + global _api_key + _api_key = api_key def get_api_key() -> str: - """Get the current API key from context or environment variable.""" - return _api_key_context.get() or TWINKLE_SERVER_TOKEN + """Get the current API key.""" + return _api_key or TWINKLE_SERVER_TOKEN + + +def set_session_id(session_id: str): + """Set the session ID.""" + global _session_id + _session_id = session_id -def clear_api_key(): - """Clear the API key context, falling back to environment variable.""" - _api_key_context.set(None) +def get_session_id() -> Optional[str]: + """Get the current session ID.""" + return _session_id def set_request_id(request_id: str): """Set the global request ID for HTTP requests (shared across all threads).""" - global _global_request_id - _global_request_id = request_id + global _request_id + _request_id = request_id def get_request_id() -> str: """Get the global request ID or generate and cache a new one.""" - global _global_request_id - if _global_request_id is not None: - return _global_request_id - # Generate a new request ID and cache it globally for consistency 
across threads
-    _global_request_id = datetime.now().strftime('%Y%m%d_%H%M%S') + '-' + str(uuid.uuid4().hex)[0:8]
-    return _global_request_id
-
-
-def clear_request_id():
-    """Clear the global request ID."""
-    global _global_request_id
-    _global_request_id = None
+    global _request_id
+    if _request_id is not None:
+        return _request_id
+    _request_id = datetime.now().strftime('%Y%m%d_%H%M%S') + '-' + str(uuid.uuid4().hex)[0:8]
+    return _request_id
diff --git a/src/twinkle_client/manager.py b/src/twinkle_client/manager.py
index f0f987a6..108465ec 100644
--- a/src/twinkle_client/manager.py
+++ b/src/twinkle_client/manager.py
@@ -1,12 +1,18 @@
 # Copyright (c) ModelScope Contributors. All rights reserved.
 from __future__ import annotations
 
-from typing import Any, Dict, List, Optional
-
-# Reuse Pydantic models from server
-from twinkle.server.twinkle.common.io_utils import Checkpoint, Cursor, TrainingRun
-from .http.http_utils import http_get, http_post
-
+import atexit
+import threading
+from typing import Any, Dict, List, Optional, Tuple
+from twinkle import get_logger
+from twinkle_client.types.server import DeleteCheckpointResponse
+from twinkle_client.types.session import (CreateSessionRequest, CreateSessionResponse, SessionHeartbeatRequest,
+                                          SessionHeartbeatResponse)
+from twinkle_client.types.training import (Checkpoint, Cursor, ParsedCheckpointTwinklePath, TrainingRun,
+                                           TrainingRunsResponse, WeightsInfoResponse)
+from .http import get_api_key, get_base_url, http_delete, http_get, http_post, set_api_key, set_base_url, set_session_id
+
+logger = get_logger()
 
 class TwinkleClientError(Exception):
     """Base exception for TwinkleManager errors."""
@@ -17,28 +23,63 @@ class TwinkleClient:
     """
     Client manager for interacting with Twinkle REST API.
 
-    This manager provides methods to:
-    - List training runs owned by the current user
-    - Get details of specific training runs
-    - List checkpoints for a training run
-    - Get checkpoint file paths for resume training
-    - Delete checkpoints
-
-    All operations respect user permissions - users can only access
-    and modify their own resources.
+    On initialization this client:
+    - Sets the base_url and api_key into the shared context so that all other
+      client objects (MultiLoraTransformersModel, vLLMSampler, processor clients)
+      automatically pick up the same configuration.
+    - Creates a server-side session and stores the session_id in context so that
+      every outgoing HTTP request carries it in the ``X-Twinkle-Session-Id`` header.
+    - Starts a lightweight background thread that touches the session every
+      ``session_heartbeat_interval`` seconds to keep it alive.
 
     Args:
-        base_url: Base URL of the Twinkle server (e.g., "http://localhost:8000").
-        api_key: API key for authentication. If not provided, uses
-            TWINKLE_SERVER_TOKEN environment variable
-        route_prefix: API route prefix (default: "/server")
+        base_url: Base URL of the Twinkle server (e.g. "http://localhost:8000").
+            Falls back to the ``TWINKLE_SERVER_URL`` environment variable.
+        api_key: API key for authentication. Falls back to the
+            ``TWINKLE_SERVER_TOKEN`` environment variable.
+        route_prefix: API route prefix (default: "/twinkle").
+        session_heartbeat_interval: Seconds between session touch calls (default: 10).
+        session_metadata: Optional metadata dict stored with the session on the server.
""" - def __init__(self, base_url: str = None, api_key: str = None, route_prefix: str | None = '/server'): - self.base_url = base_url - self.api_key = api_key + def __init__( + self, + base_url: Optional[str] = None, + api_key: Optional[str] = None, + route_prefix: Optional[str] = '/twinkle', + session_heartbeat_interval: int = 10, + session_metadata: Optional[Dict[str, Any]] = None, + ): + # Resolve and store config, then propagate to context so all generated + # client objects that call get_base_url() / get_api_key() get these values. + if base_url: + set_base_url(base_url) + if api_key: + set_api_key(api_key) + + self.base_url = get_base_url() + self.api_key = get_api_key() self.route_prefix = route_prefix.rstrip('/') if route_prefix else '' + # Create a server-side session. + self._session_id: str = self.create_session(session_metadata) + set_session_id(self._session_id) + + # Start background session-touch thread. + self._heartbeat_interval = session_heartbeat_interval + self._stop_event = threading.Event() + self._heartbeat_thread = threading.Thread( + target=self._touch_session_loop, + daemon=True, + name='TwinkleSessionHeartbeat', + ) + self._heartbeat_thread.start() + atexit.register(self.close) + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + def _get_url(self, endpoint: str) -> str: """Construct full URL for an endpoint.""" return f'{self.base_url}{self.route_prefix}{endpoint}' @@ -54,14 +95,71 @@ def _handle_response(self, response, expected_code: int = 200) -> dict[str, Any] raise TwinkleClientError(f'Request failed with status {response.status_code}: {detail}') return response.json() - # ----- Health Check ----- + def create_session(self, metadata: Optional[Dict[str, Any]] = None) -> str: + """ + Create a server-side session. + + Args: + metadata: Optional metadata dict stored with the session on the server. + + Returns: + The session ID string. + + Raises: + TwinkleClientError: If the session creation request fails. + """ + resp = http_post( + self._get_url('/create_session'), + json_data=CreateSessionRequest(metadata=metadata).model_dump(), + ) + resp.raise_for_status() + return CreateSessionResponse(**resp.json()).session_id + + def _touch_session_loop(self) -> None: + """Background loop: touch the session every ``_heartbeat_interval`` seconds. + + Uses a fixed-rate design: the wall-clock period between successive + server-side heartbeats stays close to ``_heartbeat_interval`` regardless + of how long the HTTP call takes, by subtracting elapsed time from the + subsequent sleep. 
+        """
+        import time
+        while not self._stop_event.is_set():
+            t0 = time.monotonic()
+            success = False
+            try:
+                logger.debug(f'[TwinkleClient] Touching session (session={self._session_id})...')
+                resp = http_post(
+                    self._get_url('/session_heartbeat'),
+                    json_data=SessionHeartbeatRequest(session_id=self._session_id).model_dump(),
+                    timeout=min(self._heartbeat_interval, 10),
+                )
+                resp.raise_for_status()
+                success = True
+            except Exception as e:
+                logger.error(f'[TwinkleClient] Session heartbeat error: {e}')
+            elapsed = time.monotonic() - t0
+            if success:
+                logger.debug(f'[TwinkleClient] Session heartbeat OK (elapsed={elapsed:.2f}s)')
+            sleep_time = max(0.0, self._heartbeat_interval - elapsed)
+            self._stop_event.wait(timeout=sleep_time)
+
+    def close(self) -> None:
+        """Stop the background heartbeat thread."""
+        self._stop_event.set()
+        if self._heartbeat_thread.is_alive():
+            self._heartbeat_thread.join(timeout=2)
+
+    # ------------------------------------------------------------------
+    # Health Check
+    # ------------------------------------------------------------------
 
     def health_check(self) -> bool:
         """
         Check if the Twinkle server is healthy.
 
         Returns:
-            True if server is healthy, False otherwise
+            True if server is healthy, False otherwise.
         """
         try:
             response = http_get(self._get_url('/healthz'))
@@ -69,66 +167,64 @@ def health_check(self) -> bool:
         except Exception:
             return False
 
-    # ----- Training Runs -----
+    # ------------------------------------------------------------------
+    # Training Runs
+    # ------------------------------------------------------------------
 
-    def list_training_runs(self, limit: int = 20, offset: int = 0, all_users: bool = False) -> list[TrainingRun]:
+    def list_training_runs(self, limit: int = 20, offset: int = 0, all_users: bool = False) -> List[TrainingRun]:
         """
         List training runs.
 
         By default, only returns training runs owned by the current user.
 
         Args:
-            limit: Maximum number of results (default: 20)
-            offset: Offset for pagination (default: 0)
-            all_users: If True, return all runs (if permission allows)
+            limit: Maximum number of results (default: 20).
+            offset: Offset for pagination (default: 0).
+            all_users: If True, return all runs (if permission allows).
 
         Returns:
-            List of TrainingRun objects
+            List of :class:`~twinkle_client.types.training.TrainingRun` objects.
 
         Raises:
-            TwinkleManagerError: If the request fails
+            TwinkleClientError: If the request fails.
         """
-        params = {'limit': limit, 'offset': offset}
+        params: Dict[str, Any] = {'limit': limit, 'offset': offset}
         if all_users:
             params['all_users'] = 'true'
 
         response = http_get(self._get_url('/training_runs'), params=params)
         data = self._handle_response(response)
 
-        runs = []
-        for run_data in data.get('training_runs', []):
-            runs.append(TrainingRun(**run_data))
-        return runs
+        return [TrainingRun(**r) for r in data.get('training_runs', [])]
 
-    def list_training_runs_with_cursor(self,
-                                       limit: int = 20,
-                                       offset: int = 0,
-                                       all_users: bool = False) -> tuple[list[TrainingRun], Cursor]:
+    def list_training_runs_with_cursor(
+        self,
+        limit: int = 20,
+        offset: int = 0,
+        all_users: bool = False,
+    ) -> Tuple[List[TrainingRun], Cursor]:
         """
         List training runs with pagination info.
 
         Args:
-            limit: Maximum number of results (default: 20)
-            offset: Offset for pagination (default: 0)
+ all_users: If True, return all runs (if permission allows). Returns: - Tuple of (list of TrainingRun, Cursor with pagination info) + Tuple of (list of TrainingRun, Cursor with pagination info). Raises: - TwinkleManagerError: If the request fails + TwinkleClientError: If the request fails. """ - params = {'limit': limit, 'offset': offset} + params: Dict[str, Any] = {'limit': limit, 'offset': offset} if all_users: params['all_users'] = 'true' response = http_get(self._get_url('/training_runs'), params=params) data = self._handle_response(response) - runs = [] - for run_data in data.get('training_runs', []): - runs.append(TrainingRun(**run_data)) - + runs = [TrainingRun(**r) for r in data.get('training_runs', [])] cursor = Cursor(**data.get('cursor', {})) return runs, cursor @@ -137,158 +233,156 @@ def get_training_run(self, run_id: str) -> TrainingRun: Get details of a specific training run. Args: - run_id: The training run identifier + run_id: The training run identifier. Returns: - TrainingRun object with run details + :class:`~twinkle_client.types.training.TrainingRun` object with run details. Raises: - TwinkleManagerError: If run not found or access denied + TwinkleClientError: If run not found or access denied. """ response = http_get(self._get_url(f'/training_runs/{run_id}')) data = self._handle_response(response) return TrainingRun(**data) - # ----- Checkpoints ----- + # ------------------------------------------------------------------ + # Checkpoints + # ------------------------------------------------------------------ - def list_checkpoints(self, run_id: str) -> list[Checkpoint]: + def list_checkpoints(self, run_id: str) -> List[Checkpoint]: """ List checkpoints for a training run. Args: - run_id: The training run identifier + run_id: The training run identifier. Returns: - List of Checkpoint objects + List of :class:`~twinkle_client.types.training.Checkpoint` objects. Raises: - TwinkleManagerError: If run not found or access denied + TwinkleClientError: If run not found or access denied. """ response = http_get(self._get_url(f'/training_runs/{run_id}/checkpoints')) data = self._handle_response(response) + return [Checkpoint(**c) for c in data.get('checkpoints', [])] - checkpoints = [] - for ckpt_data in data.get('checkpoints', []): - checkpoints.append(Checkpoint(**ckpt_data)) - return checkpoints - - def get_checkpoint_path(self, run_id: str, checkpoint_id: str) -> str: + def get_checkpoint_path(self, run_id: str, checkpoint_id: str) -> ParsedCheckpointTwinklePath: """ - Get the filesystem path for a checkpoint. - - This path can be used to load weights for resume training. + Get the filesystem path and twinkle:// path for a checkpoint. Args: - run_id: The training run identifier - checkpoint_id: The checkpoint identifier (e.g., "weights/20240101_120000") + run_id: The training run identifier. + checkpoint_id: The checkpoint identifier (e.g. "weights/20240101_120000"). Returns: - Filesystem path to the checkpoint directory + :class:`~twinkle_client.types.training.ParsedCheckpointTwinklePath` with + ``path`` (filesystem) and ``twinkle_path`` fields. Raises: - TwinkleManagerError: If checkpoint not found or access denied + TwinkleClientError: If checkpoint not found or access denied. 
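A plausible resume-training lookup built on these helpers (the run id is a placeholder; `init_twinkle_client` falls back to the `TWINKLE_SERVER_URL`/`TWINKLE_SERVER_TOKEN` environment variables):

```python
from twinkle_client import init_twinkle_client

client = init_twinkle_client()
ckpts = client.list_checkpoints('run-123')
if ckpts:
    parsed = client.get_checkpoint_path('run-123', ckpts[-1].checkpoint_id)
    print(parsed.path)          # filesystem location for loading weights
    print(parsed.twinkle_path)  # e.g. twinkle://run-123/weights/<name>
    # Or, collapsing the two steps above:
    print(client.get_latest_checkpoint_path('run-123'))
```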
""" response = http_get(self._get_url(f'/checkpoint_path/{run_id}/{checkpoint_id}')) data = self._handle_response(response) - return data.get('path', '') + return ParsedCheckpointTwinklePath( + path=data.get('path', ''), + twinkle_path=data.get('twinkle_path', ''), + training_run_id=run_id, + checkpoint_type=checkpoint_id.split('/')[0] if '/' in checkpoint_id else '', + checkpoint_id=checkpoint_id, + ) def get_checkpoint_twinkle_path(self, run_id: str, checkpoint_id: str) -> str: """ Get the twinkle:// path for a checkpoint. Args: - run_id: The training run identifier - checkpoint_id: The checkpoint identifier + run_id: The training run identifier. + checkpoint_id: The checkpoint identifier. Returns: - Twinkle path (e.g., "twinkle://run_id/weights/checkpoint_name") + Twinkle path string (e.g. "twinkle://run_id/weights/checkpoint_name"). Raises: - TwinkleManagerError: If checkpoint not found or access denied + TwinkleClientError: If checkpoint not found or access denied. """ - response = http_get(self._get_url(f'/checkpoint_path/{run_id}/{checkpoint_id}')) - data = self._handle_response(response) - return data.get('twinkle_path', '') + return self.get_checkpoint_path(run_id, checkpoint_id).twinkle_path - def delete_checkpoint(self, run_id: str, checkpoint_id: str) -> bool: + def delete_checkpoint(self, run_id: str, checkpoint_id: str) -> DeleteCheckpointResponse: """ Delete a checkpoint. Args: - run_id: The training run identifier - checkpoint_id: The checkpoint identifier + run_id: The training run identifier. + checkpoint_id: The checkpoint identifier. Returns: - True if deletion was successful + :class:`~twinkle_client.types.server.DeleteCheckpointResponse` indicating success. Raises: - TwinkleManagerError: If checkpoint not found or access denied + TwinkleClientError: If checkpoint not found or access denied. """ - from .http import http_delete - url = self._get_url(f'/training_runs/{run_id}/checkpoints/{checkpoint_id}') response = http_delete(url) data = self._handle_response(response) - return data.get('success', False) + return DeleteCheckpointResponse(**data) - # ----- Weights Info ----- + # ------------------------------------------------------------------ + # Weights Info + # ------------------------------------------------------------------ - def get_weights_info(self, twinkle_path: str) -> dict[str, Any]: + def get_weights_info(self, twinkle_path: str) -> WeightsInfoResponse: """ Get information about saved weights. Args: - twinkle_path: The twinkle:// path to the weights + twinkle_path: The twinkle:// path to the weights. Returns: - Dictionary with weight information including: - - training_run_id - - base_model - - model_owner - - is_lora - - lora_rank + :class:`~twinkle_client.types.training.WeightsInfoResponse` with fields: + ``training_run_id``, ``base_model``, ``model_owner``, ``is_lora``, ``lora_rank``. Raises: - TwinkleManagerError: If weights not found or access denied + TwinkleClientError: If weights not found or access denied. 
""" response = http_post(self._get_url('/weights_info'), json_data={'twinkle_path': twinkle_path}) - return self._handle_response(response) + data = self._handle_response(response) + return WeightsInfoResponse(**data) - # ----- Convenience Methods for Resume Training ----- + # ------------------------------------------------------------------ + # Convenience Methods + # ------------------------------------------------------------------ - def get_latest_checkpoint_path(self, run_id: str) -> str | None: + def get_latest_checkpoint_path(self, run_id: str) -> Optional[str]: """ - Get the path to the latest checkpoint for a training run. + Get the filesystem path to the latest checkpoint for a training run. - This is useful for resume training - it returns the path to the - most recent checkpoint that can be loaded. + Useful for resume training — returns the path to the most recent checkpoint. Args: - run_id: The training run identifier + run_id: The training run identifier. Returns: - Filesystem path to the latest checkpoint, or None if no checkpoints exist + Filesystem path string to the latest checkpoint, or ``None`` if none exist. Raises: - TwinkleManagerError: If run not found or access denied + TwinkleClientError: If run not found or access denied. """ checkpoints = self.list_checkpoints(run_id) if not checkpoints: return None - - # Checkpoints are sorted by time, so last one is the latest latest = checkpoints[-1] - return self.get_checkpoint_path(run_id, latest.checkpoint_id) + return self.get_checkpoint_path(run_id, latest.checkpoint_id).path - def find_training_run_by_model(self, base_model: str) -> list[TrainingRun]: + def find_training_run_by_model(self, base_model: str) -> List[TrainingRun]: """ Find training runs for a specific base model. Args: - base_model: The base model name to search for + base_model: The base model name to search for. Returns: - List of TrainingRun objects matching the base model + List of :class:`~twinkle_client.types.training.TrainingRun` objects + matching the base model. """ all_runs = self.list_training_runs(limit=100) return [run for run in all_runs if run.base_model == base_model] diff --git a/src/twinkle_client/model/multi_lora_transformers.py b/src/twinkle_client/model/multi_lora_transformers.py index f681c96b..743125d9 100644 --- a/src/twinkle_client/model/multi_lora_transformers.py +++ b/src/twinkle_client/model/multi_lora_transformers.py @@ -8,18 +8,25 @@ # 1. Modify the source files in src/twinkle/ # 2. Run: python client_tools/client_generator.py # ============================================================================ -from typing import Any, Optional, Union, Type, Dict, Literal, List -import uuid -from twinkle_client.http import http_post, heartbeat_manager -from twinkle import DeviceMesh -from twinkle.data_format import InputFeature, Trajectory +from typing import Any, Dict, Optional +from twinkle_client.http import http_post +from twinkle_client.types.model import ( + CalculateLossResponse, + CalculateMetricResponse, + ClipGradNormResponse, + ForwardBackwardResponse, + ForwardResponse, + GetStateDictResponse, + GetTrainConfigsResponse, + SaveResponse, +) class MultiLoraTransformersModel: """Client wrapper for TwinkleModel that calls server HTTP endpoints. This client manages adapters and sends training/inference requests to the model server. - Each adapter has its own lifecycle managed through automatic heartbeats. + The server-side session (managed by TwinkleClient) keeps the model alive. 
""" def __init__(self, model_id: str, **kwargs): @@ -30,215 +37,214 @@ def __init__(self, model_id: str, **kwargs): self.model_id = model_id if '://' in model_id: model_id = model_id.split('://')[1] - self.server_url = f'{self.server_url}/models/{model_id}' + self.server_url = f'{self.server_url}/model/{model_id}/twinkle' self.adapter_name = None response = http_post( url=f'{self.server_url}/create', ) response.raise_for_status() - def _send_adapter_heartbeat(self): - """Internal method to send adapter heartbeat.""" - response = http_post( - url=f'{self.server_url}/heartbeat', - json_data={'adapter_name': self.adapter_name} - ) - response.raise_for_status() - - def add_adapter_to_model(self, adapter_name: str, config: Dict[str, Any], **kwargs): - """Add a new adapter to the model and start automatic heartbeat.""" + def add_adapter_to_model(self, adapter_name: str, config: Dict[str, Any], **kwargs) -> None: + """Add a new adapter to the model.""" response = http_post( url=f'{self.server_url}/add_adapter_to_model', json_data={'adapter_name': adapter_name, 'config': config, **kwargs} ) response.raise_for_status() - - # Register adapter for automatic heartbeat after successful creation self.adapter_name = adapter_name - heartbeat_manager.register_adapter( - self.adapter_name, - self._send_adapter_heartbeat - ) - def __del__(self): - """Cleanup: unregister adapter from heartbeat manager.""" - try: - heartbeat_manager.unregister_adapter(self.adapter_name) - except: - pass - - def forward(self, inputs: Any, **kwargs): + def forward(self, inputs: Any, **kwargs) -> ForwardResponse: """Execute forward pass on the model.""" response = http_post( url=f'{self.server_url}/forward', json_data={'inputs': inputs, 'adapter_name': self.adapter_name, **kwargs} ) response.raise_for_status() - return response.json()['result'] + return ForwardResponse(**response.json()) - def forward_only(self, inputs: Any, **kwargs): + def forward_only(self, inputs: Any, **kwargs) -> ForwardResponse: """Execute forward pass without gradient computation.""" response = http_post( url=f'{self.server_url}/forward_only', json_data={'inputs': inputs, 'adapter_name': self.adapter_name, **kwargs} ) response.raise_for_status() - return response.json()['result'] + return ForwardResponse(**response.json()) - def calculate_loss(self, **kwargs): + def calculate_loss(self, **kwargs) -> CalculateLossResponse: """Calculate loss from model outputs.""" response = http_post( url=f'{self.server_url}/calculate_loss', json_data={'adapter_name': self.adapter_name, **kwargs} ) response.raise_for_status() - return response.json()['result'] + return CalculateLossResponse(**response.json()) - def get_train_configs(self, **kwargs): - """Get training configs""" + def get_train_configs(self, **kwargs) -> GetTrainConfigsResponse: + """Get training configs.""" response = http_post( url=f'{self.server_url}/get_train_configs', json_data={'adapter_name': self.adapter_name, **kwargs} ) response.raise_for_status() - return response.json()['result'] + return GetTrainConfigsResponse(**response.json()) - def backward(self, **kwargs): + def backward(self, **kwargs) -> None: """Execute backward pass.""" response = http_post( url=f'{self.server_url}/backward', json_data={'adapter_name': self.adapter_name, **kwargs} ) response.raise_for_status() - return response.json()['result'] - def forward_backward(self, inputs: Any, **kwargs): + def forward_backward(self, inputs: Any, **kwargs) -> ForwardBackwardResponse: """Execute combined forward and backward pass.""" response = 
http_post( url=f'{self.server_url}/forward_backward', json_data={'inputs': inputs, 'adapter_name': self.adapter_name, **kwargs} ) response.raise_for_status() - return response.json()['result'] + return ForwardBackwardResponse(**response.json()) - def step(self, **kwargs): + def step(self, **kwargs) -> None: """Execute optimizer step.""" response = http_post( url=f'{self.server_url}/step', json_data={'adapter_name': self.adapter_name, **kwargs} ) response.raise_for_status() - return response.json()['result'] - def zero_grad(self, **kwargs): + def zero_grad(self, **kwargs) -> None: """Zero out gradients.""" response = http_post( url=f'{self.server_url}/zero_grad', json_data={'adapter_name': self.adapter_name, **kwargs} ) response.raise_for_status() - return response.json()['result'] - def lr_step(self, **kwargs): + def lr_step(self, **kwargs) -> None: """Execute learning rate scheduler step.""" response = http_post( url=f'{self.server_url}/lr_step', json_data={'adapter_name': self.adapter_name, **kwargs} ) response.raise_for_status() - return response.json()['result'] - def set_loss(self, loss_cls: str, **kwargs): - """Set the loss function.""" + def clip_grad_norm(self, max_grad_norm: float = 1.0, norm_type: int = 2, **kwargs) -> ClipGradNormResponse: + """Clip gradient norm.""" response = http_post( - url=f'{self.server_url}/set_loss', - json_data={'loss_cls': loss_cls, 'adapter_name': self.adapter_name, **kwargs} + url=f'{self.server_url}/clip_grad_norm', + json_data={'max_grad_norm': max_grad_norm, 'norm_type': norm_type, 'adapter_name': self.adapter_name, **kwargs} ) response.raise_for_status() - return response.json()['result'] + return ClipGradNormResponse(**response.json()) - def clip_grad_norm(self, max_grad_norm: float=1.0, norm_type=2, **kwargs): - """Set the loss function.""" + def clip_grad_and_step(self, max_grad_norm: float = 1.0, norm_type: int = 2, **kwargs) -> None: + """Clip gradient norm and execute optimizer step in one call.""" response = http_post( - url=f'{self.server_url}/clip_grad_norm', + url=f'{self.server_url}/clip_grad_and_step', json_data={'max_grad_norm': max_grad_norm, 'norm_type': norm_type, 'adapter_name': self.adapter_name, **kwargs} ) response.raise_for_status() - return response.json()['result'] - def set_optimizer(self, optimizer_cls: str, **kwargs): + def set_loss(self, loss_cls: str, **kwargs) -> None: + """Set the loss function.""" + response = http_post( + url=f'{self.server_url}/set_loss', + json_data={'loss_cls': loss_cls, 'adapter_name': self.adapter_name, **kwargs} + ) + response.raise_for_status() + + def set_optimizer(self, optimizer_cls: str, **kwargs) -> None: """Set the optimizer.""" response = http_post( url=f'{self.server_url}/set_optimizer', json_data={'optimizer_cls': optimizer_cls, 'adapter_name': self.adapter_name, **kwargs} ) response.raise_for_status() - return response.json()['result'] - def set_lr_scheduler(self, scheduler_cls: str, **kwargs): + def set_lr_scheduler(self, scheduler_cls: str, **kwargs) -> None: """Set the learning rate scheduler.""" response = http_post( url=f'{self.server_url}/set_lr_scheduler', json_data={'scheduler_cls': scheduler_cls, 'adapter_name': self.adapter_name, **kwargs} ) response.raise_for_status() - return response.json()['result'] - def save(self, name: str, **kwargs): + def save(self, name: str, **kwargs) -> SaveResponse: """Save model checkpoint.""" response = http_post( url=f'{self.server_url}/save', json_data={'name': name, 'adapter_name': self.adapter_name, **kwargs} ) 
         response.raise_for_status()
-        return response.json()['result']
+        return SaveResponse(**response.json())

-    def load(self, name: str, **kwargs):
+    def load(self, name: str, **kwargs) -> None:
         """Load model checkpoint."""
         response = http_post(
             url=f'{self.server_url}/load',
             json_data={'name': name, 'adapter_name': self.adapter_name, **kwargs}
         )
         response.raise_for_status()
-        return response.json()['result']

-    def set_template(self, template_cls: str, **kwargs):
+    def apply_patch(self, patch_cls: str, **kwargs) -> None:
+        """Apply a patch to the model."""
+        response = http_post(
+            url=f'{self.server_url}/apply_patch',
+            json_data={'patch_cls': patch_cls, 'adapter_name': self.adapter_name, **kwargs}
+        )
+        response.raise_for_status()
+
+    def add_metric(self, metric_cls: str, is_training: Optional[bool] = None, **kwargs) -> None:
+        """Add a metric to the model."""
+        response = http_post(
+            url=f'{self.server_url}/add_metric',
+            json_data={'metric_cls': metric_cls, 'is_training': is_training, 'adapter_name': self.adapter_name, **kwargs}
+        )
+        response.raise_for_status()
+
+    def set_template(self, template_cls: str, **kwargs) -> None:
         """Set the template for data processing."""
         response = http_post(
             url=f'{self.server_url}/set_template',
             json_data={'template_cls': template_cls, 'adapter_name': self.adapter_name, 'model_id': self.model_id, **kwargs}
         )
         response.raise_for_status()
-        return response.json()['result']

-    def set_processor(self, processor_cls: str, **kwargs):
+    def set_processor(self, processor_cls: str, **kwargs) -> None:
         """Set the input processor."""
         response = http_post(
             url=f'{self.server_url}/set_processor',
             json_data={'processor_cls': processor_cls, 'adapter_name': self.adapter_name, **kwargs}
         )
         response.raise_for_status()
-        return response.json()['result']

-    def calculate_metric(self, is_training: bool = True, **kwargs):
+    def calculate_metric(self, is_training: bool = True, **kwargs) -> CalculateMetricResponse:
         """Calculate metrics from model outputs."""
         response = http_post(
             url=f'{self.server_url}/calculate_metric',
             json_data={'is_training': is_training, 'adapter_name': self.adapter_name, **kwargs}
         )
         response.raise_for_status()
-        return response.json()['result']
+        return CalculateMetricResponse(**response.json())

-    def get_state_dict(self, **kwargs):
+    def get_state_dict(self, **kwargs) -> GetStateDictResponse:
         """Get model state dictionary."""
         response = http_post(
             url=f'{self.server_url}/get_state_dict',
             json_data={'adapter_name': self.adapter_name, **kwargs}
         )
         response.raise_for_status()
-        return response.json()['result']
+        return GetStateDictResponse(**response.json())

-    def upload_to_hub(self, checkpoint_dir: str, hub_model_id: str, hub_token: Optional[str] = None, async_upload: bool = True):
+    def upload_to_hub(
+        self,
+        checkpoint_dir: str,
+        hub_model_id: str,
+        hub_token: Optional[str] = None,
+        async_upload: bool = True,
+    ) -> None:
         """Upload model checkpoint to hub.

         Args:
@@ -253,8 +259,7 @@ def upload_to_hub(self, checkpoint_dir: str, hub_model_id: str, hub_token: Optio
                 'checkpoint_dir': checkpoint_dir,
                 'hub_model_id': hub_model_id,
                 'hub_token': hub_token,
-                'async_upload': async_upload
+                'async_upload': async_upload,
             }
         )
         response.raise_for_status()
-        return response.json()
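
Example (editor's sketch, not part of the patch): one plausible way to drive the
regenerated client after this change. The import path, model id, adapter config,
and the `batches` iterable are assumptions for illustration only.

    from twinkle_client.model import MultiLoraTransformersModel

    model = MultiLoraTransformersModel('ms://Qwen/Qwen3-30B-A3B-Instruct-2507')
    model.add_adapter_to_model('demo-lora', config={'r': 8, 'lora_alpha': 32})
    model.set_loss('CrossEntropyLoss')
    model.set_optimizer('AdamW', lr=1e-4)

    for batch in batches:  # encoded InputFeature batches, provided by the caller
        out = model.forward_backward(inputs=batch)   # typed ForwardBackwardResponse
        model.clip_grad_and_step(max_grad_norm=1.0)  # clip + optimizer step in one round trip
        model.zero_grad()

    saved = model.save('step-8')  # typed SaveResponse
    print(saved.twinkle_path)     # twinkle:// path to the checkpoint
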
diff --git a/src/twinkle_client/processor/base.py b/src/twinkle_client/processor/base.py
index d59572a7..048ace5e 100644
--- a/src/twinkle_client/processor/base.py
+++ b/src/twinkle_client/processor/base.py
@@ -10,7 +10,7 @@
 # ============================================================================
 from typing import List, Literal, Optional, Union

-from twinkle_client.http import http_post, heartbeat_manager
+from twinkle_client.http import http_post
 from twinkle import DeviceMesh
 from twinkle.data_format import InputFeature

@@ -19,10 +19,10 @@ class InputProcessor(object):

     def __init__(self, device_mesh: Optional[DeviceMesh] = None, padding_free: bool = False,
                  framework: Literal['transformers', 'megatron'] = 'transformers', **kwargs):
         from twinkle_client.http import get_base_url
-        self.server_url = get_base_url()
+        self.server_url = f'{get_base_url()}/processor/twinkle'
         response = http_post(
-            url=f'{self.server_url}/processors/create',
+            url=f'{self.server_url}/create',
             json_data={
                 'processor_type': 'processor',
                 'class_type': 'InputProcessor',
@@ -31,18 +31,11 @@ def __init__(self, device_mesh: Optional[DeviceMesh] = None, padding_free: bool
         )
         response.raise_for_status()
         self.processor_id = response.json()['processor_id']
-        heartbeat_manager.register_processor(self.processor_id)
-
-    def __del__(self):
-        try:
-            heartbeat_manager.unregister_processor(self.processor_id)
-        except:
-            pass

     def __call__(self, inputs: Union[InputFeature, List[InputFeature]], **kwargs):
         response = http_post(
-            url=f'{self.server_url}/processors/call',
+            url=f'{self.server_url}/call',
             json_data={
                 'processor_id': self.processor_id,
                 'function': '__call__',
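
A quick sketch of the resulting processor flow (illustrative only; `features` is
assumed to be an InputFeature or a list of them):

    from twinkle_client.processor.base import InputProcessor

    proc = InputProcessor(padding_free=True)  # POST {base}/processor/twinkle/create
    batch = proc(features)                    # POST {base}/processor/twinkle/call
                                              # with function='__call__'
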
""" def __init__(self, model_id: str, **kwargs): @@ -30,25 +31,15 @@ def __init__(self, model_id: str, **kwargs): self.adapter_name = None if '://' in model_id: model_id = model_id.split('://')[1] - self.server_url = f'{self.server_url}/samplers/{model_id}' + self.server_url = f'{self.server_url}/sampler/{model_id}/twinkle' response = http_post( url=f'{self.server_url}/create', json_data=kwargs ) response.raise_for_status() - def _send_adapter_heartbeat(self): - """Internal method to send adapter heartbeat.""" - if not self.adapter_name: - return - response = http_post( - url=f'{self.server_url}/heartbeat', - json_data={'adapter_name': self.adapter_name} - ) - response.raise_for_status() - - def add_adapter_to_sampler(self, adapter_name: str, config: PeftConfig, **kwargs): - """Add a new adapter to the sampler and start automatic heartbeat.""" + def add_adapter_to_sampler(self, adapter_name: str, config: PeftConfig, **kwargs) -> AddAdapterResponse: + """Add a new adapter to the sampler.""" if isinstance(config, PeftConfig): config = config.__dict__ response = http_post( @@ -56,23 +47,8 @@ def add_adapter_to_sampler(self, adapter_name: str, config: PeftConfig, **kwargs json_data={'adapter_name': adapter_name, 'config': config, **kwargs} ) response.raise_for_status() - - # Register adapter for automatic heartbeat after successful creation self.adapter_name = adapter_name - heartbeat_manager.register_adapter( - self.adapter_name, - self._send_adapter_heartbeat - ) - - return response.json() - - def __del__(self): - """Cleanup: unregister adapter from heartbeat manager.""" - try: - if self.adapter_name: - heartbeat_manager.unregister_adapter(self.adapter_name) - except: - pass + return AddAdapterResponse(**response.json()) def sample( self, @@ -81,7 +57,7 @@ def sample( adapter_name: str = '', adapter_uri: Optional[str] = None, num_samples: int = 1, - ) -> Dict[str, Any]: + ) -> SampleResponseModel: """Sample from the model. Args: @@ -92,7 +68,7 @@ def sample( num_samples: Number of completions to generate per prompt. Returns: - Dict with 'sequences' list, each containing tokens, logprobs, stop_reason. + SampleResponseModel with 'sequences' list, each containing tokens, logprobs, stop_reason. """ json_data = { 'inputs': inputs, @@ -108,13 +84,13 @@ def sample( json_data=json_data ) response.raise_for_status() - return response.json() + return SampleResponseModel(**response.json()) - def set_template(self, template_cls: str, adapter_name: str = '', **kwargs): + def set_template(self, template_cls: str, adapter_name: str = '', **kwargs) -> SetTemplateResponse: """Set the template for encoding trajectories.""" response = http_post( url=f'{self.server_url}/set_template', json_data={'template_cls': template_cls, 'adapter_name': adapter_name, **kwargs} ) response.raise_for_status() - return response.json() + return SetTemplateResponse(**response.json()) diff --git a/src/twinkle_client/types/__init__.py b/src/twinkle_client/types/__init__.py new file mode 100644 index 00000000..b6650a28 --- /dev/null +++ b/src/twinkle_client/types/__init__.py @@ -0,0 +1,87 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. 
diff --git a/src/twinkle_client/types/__init__.py b/src/twinkle_client/types/__init__.py
new file mode 100644
index 00000000..b6650a28
--- /dev/null
+++ b/src/twinkle_client/types/__init__.py
@@ -0,0 +1,87 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+from .model import (
+    AddAdapterRequest,
+    AddAdapterResponse,
+    AddMetricRequest,
+    AddMetricResponse,
+    AdapterRequest,
+    ApplyPatchRequest,
+    ApplyPatchResponse,
+    BackwardResponse,
+    CalculateLossResponse,
+    CalculateMetricRequest,
+    CalculateMetricResponse,
+    ClipGradAndStepRequest,
+    ClipGradAndStepResponse,
+    ClipGradNormResponse,
+    CreateRequest,
+    CreateResponse,
+    ForwardBackwardResponse,
+    ForwardOnlyRequest,
+    ForwardRequest,
+    ForwardResponse,
+    GetStateDictRequest,
+    GetStateDictResponse,
+    GetTrainConfigsResponse,
+    LoadRequest,
+    LoadResponse,
+    LrStepResponse,
+    ModelResult,
+    OkResponse,
+    SaveRequest,
+    SaveResponse,
+    SetLossRequest,
+    SetLossResponse,
+    SetLrSchedulerRequest,
+    SetLrSchedulerResponse,
+    SetOptimizerRequest,
+    SetOptimizerResponse,
+    SetProcessorRequest,
+    SetProcessorResponse,
+    SetTemplateRequest,
+    SetTemplateResponse,
+    StepResponse,
+    UploadToHubRequest,
+    UploadToHubResponse,
+    ZeroGradResponse,
+)
+from .processor import (
+    ProcessorCallRequest,
+    ProcessorCallResponse,
+    ProcessorCreateRequest,
+    ProcessorCreateResponse,
+    ProcessorHeartbeatRequest,
+    ProcessorHeartbeatResponse,
+)
+from .sampler import (
+    AddAdapterRequest as SamplerAddAdapterRequest,
+    AddAdapterResponse as SamplerAddAdapterResponse,
+    CreateResponse as SamplerCreateResponse,
+    SampledSequenceModel,
+    SampleRequest,
+    SampleResponseModel,
+    SetTemplateRequest as SamplerSetTemplateRequest,
+    SetTemplateResponse as SamplerSetTemplateResponse,
+)
+from .server import (
+    CheckpointPathResponse,
+    DeleteCheckpointResponse,
+    ErrorResponse,
+    HealthResponse,
+    WeightsInfoRequest,
+    WeightsInfoResponse as ServerWeightsInfoResponse,
+)
+from .session import CreateSessionRequest, CreateSessionResponse, SessionHeartbeatRequest, SessionHeartbeatResponse
+from .training import (
+    Checkpoint,
+    CheckpointsListResponse,
+    CreateModelRequest,
+    Cursor,
+    LoraConfig,
+    ParsedCheckpointTwinklePath,
+    TrainingRun,
+    TrainingRunsResponse,
+    WeightsInfoResponse,
+)
+
+from .checkpoint import ResolvedLoadPath
diff --git a/src/twinkle_client/types/checkpoint.py b/src/twinkle_client/types/checkpoint.py
new file mode 100644
index 00000000..fe89cb41
--- /dev/null
+++ b/src/twinkle_client/types/checkpoint.py
@@ -0,0 +1,23 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+"""
+Server-specific Pydantic models for checkpoint path resolution.
+"""
+from pydantic import BaseModel
+from typing import Optional
+
+
+class ResolvedLoadPath(BaseModel):
+    """Result of resolving a load path.
+
+    Attributes:
+        checkpoint_name: The name of the checkpoint (e.g., 'step-8' or hub model id)
+        checkpoint_dir: The directory containing the checkpoint, or None if loading from hub
+        is_twinkle_path: Whether the path was a twinkle:// path
+        training_run_id: The training run ID (only set for twinkle:// paths)
+        checkpoint_id: The checkpoint ID (only set for twinkle:// paths)
+    """
+    checkpoint_name: str
+    checkpoint_dir: Optional[str] = None
+    is_twinkle_path: bool = False
+    training_run_id: Optional[str] = None
+    checkpoint_id: Optional[str] = None
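
Since model.py and sampler.py both define AddAdapter*, SetTemplate*, and
CreateResponse names, the colliding sampler classes are re-exported under
Sampler* aliases so that both variants remain importable from the package
root, for example:

    from twinkle_client.types import (
        AddAdapterResponse,         # model variant
        SamplerAddAdapterResponse,  # sampler variant (aliased on export)
        SampleResponseModel,
    )
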
+""" +from pydantic import BaseModel +from typing import Any, Dict, List, Optional + + +class CreateRequest(BaseModel): + + class Config: + extra = 'allow' + + +class ForwardRequest(BaseModel): + inputs: Any + adapter_name: str + + class Config: + extra = 'allow' + + +class ForwardOnlyRequest(BaseModel): + inputs: Any + adapter_name: Optional[str] = None + + class Config: + extra = 'allow' + + +class AdapterRequest(BaseModel): + adapter_name: str + + class Config: + extra = 'allow' + + +class SetLossRequest(BaseModel): + loss_cls: str + adapter_name: str + + class Config: + extra = 'allow' + + +class SetOptimizerRequest(BaseModel): + optimizer_cls: str + adapter_name: str + + class Config: + extra = 'allow' + + +class SetLrSchedulerRequest(BaseModel): + scheduler_cls: str + adapter_name: str + + class Config: + extra = 'allow' + + +class SaveRequest(BaseModel): + adapter_name: str + save_optimizer: bool = False + name: Optional[str] = None + + class Config: + extra = 'allow' + + +class UploadToHubRequest(BaseModel): + checkpoint_dir: str + hub_model_id: str + hub_token: Optional[str] = None + async_upload: bool = False + + class Config: + extra = 'allow' + + +class LoadRequest(BaseModel): + adapter_name: str + load_optimizer: bool = False + name: str + + class Config: + extra = 'allow' + + +class AddAdapterRequest(BaseModel): + adapter_name: str + config: str + + class Config: + extra = 'allow' + + +class SetTemplateRequest(BaseModel): + template_cls: str + adapter_name: str + + class Config: + extra = 'allow' + + +class SetProcessorRequest(BaseModel): + processor_cls: str + adapter_name: str + + class Config: + extra = 'allow' + + +class CalculateMetricRequest(BaseModel): + adapter_name: str + is_training: bool = True + + class Config: + extra = 'allow' + + +class GetStateDictRequest(BaseModel): + adapter_name: str + + class Config: + extra = 'allow' + + +class ClipGradAndStepRequest(BaseModel): + adapter_name: str + max_grad_norm: float = 1.0 + norm_type: int = 2 + + class Config: + extra = 'allow' + + +class ApplyPatchRequest(BaseModel): + patch_cls: str + adapter_name: str + + class Config: + extra = 'allow' + + +class AddMetricRequest(BaseModel): + metric_cls: str + adapter_name: str + is_training: Optional[bool] = None + + class Config: + extra = 'allow' + + +# --------------------------------------------------------------------------- +# Response models +# --------------------------------------------------------------------------- + + +class OkResponse(BaseModel): + """Response for endpoints whose underlying method returns None.""" + status: str = 'ok' + + +class ModelResult(BaseModel): + """Generic single-value result wrapper returned by result-bearing endpoints.""" + result: Any + + +# --- Result-bearing responses --- + +class ForwardResponse(BaseModel): + """Response for /forward and /forward_only endpoints (returns ModelOutput).""" + result: Any + + +class ForwardBackwardResponse(BaseModel): + """Response for /forward_backward endpoint (returns ModelOutput).""" + result: Any + + +class CalculateLossResponse(BaseModel): + """Response for /calculate_loss endpoint (returns float).""" + result: float + + +class ClipGradNormResponse(BaseModel): + """Response for /clip_grad_norm endpoint (returns float as str).""" + result: str + + +class GetTrainConfigsResponse(BaseModel): + """Response for /get_train_configs endpoint (returns str).""" + result: str + + +class GetStateDictResponse(BaseModel): + """Response for /get_state_dict endpoint (returns Dict).""" + result: Dict[str, Any] + + 
+class CalculateMetricResponse(BaseModel):
+    """Response for /calculate_metric endpoint (returns Dict)."""
+    result: Dict[str, Any]
+
+
+class SaveResponse(BaseModel):
+    """Response for /save endpoint (returns twinkle path + checkpoint dir)."""
+    twinkle_path: str
+    checkpoint_dir: Optional[str] = None
+
+
+# --- Void responses (return None → OkResponse) ---
+
+class BackwardResponse(OkResponse):
+    """Response for /backward endpoint."""
+    pass
+
+
+class StepResponse(OkResponse):
+    """Response for /step (optimizer step) endpoint."""
+    pass
+
+
+class ZeroGradResponse(OkResponse):
+    """Response for /zero_grad endpoint."""
+    pass
+
+
+class LrStepResponse(OkResponse):
+    """Response for /lr_step endpoint."""
+    pass
+
+
+class SetLossResponse(OkResponse):
+    """Response for /set_loss endpoint."""
+    pass
+
+
+class SetOptimizerResponse(OkResponse):
+    """Response for /set_optimizer endpoint."""
+    pass
+
+
+class SetLrSchedulerResponse(OkResponse):
+    """Response for /set_lr_scheduler endpoint."""
+    pass
+
+
+class LoadResponse(OkResponse):
+    """Response for /load endpoint."""
+    pass
+
+
+class SetTemplateResponse(OkResponse):
+    """Response for /set_template endpoint."""
+    pass
+
+
+class SetProcessorResponse(OkResponse):
+    """Response for /set_processor endpoint."""
+    pass
+
+
+class UploadToHubResponse(OkResponse):
+    """Response for /upload_to_hub endpoint."""
+    pass
+
+
+class ClipGradAndStepResponse(OkResponse):
+    """Response for /clip_grad_and_step endpoint."""
+    pass
+
+
+class ApplyPatchResponse(OkResponse):
+    """Response for /apply_patch endpoint."""
+    pass
+
+
+class AddMetricResponse(OkResponse):
+    """Response for /add_metric endpoint."""
+    pass
+
+
+# --- Other responses ---
+
+class CreateResponse(BaseModel):
+    """Response for /create endpoint."""
+    status: str = 'ok'
+
+
+class AddAdapterResponse(BaseModel):
+    """Response for /add_adapter_to_model endpoint."""
+    status: str = 'ok'
+    adapter_name: str
diff --git a/src/twinkle_client/types/processor.py b/src/twinkle_client/types/processor.py
new file mode 100644
index 00000000..fe8674ce
--- /dev/null
+++ b/src/twinkle_client/types/processor.py
@@ -0,0 +1,46 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+"""
+Pydantic request/response models for twinkle processor endpoints.
+
+These models are used by both the server-side handler and the twinkle client.
+
+Note: Class names are prefixed with 'Processor' to avoid name collisions when
+importing from twinkle_client.types alongside model.py classes.
+"""
+from pydantic import BaseModel
+from typing import Any
+
+
+class ProcessorCreateRequest(BaseModel):
+    processor_type: str
+    class_type: str
+
+    class Config:
+        extra = 'allow'
+
+
+class ProcessorHeartbeatRequest(BaseModel):
+    processor_id: str
+
+
+class ProcessorCallRequest(BaseModel):
+    processor_id: str
+    function: str
+
+    class Config:
+        extra = 'allow'
+
+
+class ProcessorCreateResponse(BaseModel):
+    """Response body for the /create endpoint."""
+    processor_id: str
+
+
+class ProcessorHeartbeatResponse(BaseModel):
+    """Response body for the /heartbeat endpoint."""
+    status: str = 'ok'
+
+
+class ProcessorCallResponse(BaseModel):
+    """Response body for the /call endpoint."""
+    result: Any
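
These are plain Pydantic models, so construction and attribute access replace
the old dict indexing; a small illustration (the values are made up):

    from twinkle_client.types.model import CalculateLossResponse, OkResponse

    loss = CalculateLossResponse(result=0.4182)
    loss.result          # 0.4182 as a float, no ['result'] lookup needed
    OkResponse().status  # 'ok', the shape returned by void endpoints like /backward
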
+""" +Pydantic request/response models for twinkle sampler endpoints. + +These models are used by both the server-side handler and the twinkle client. +""" +from pydantic import BaseModel, Field +from typing import Any, Dict, List, Literal, Optional, Tuple + +StopReason = Literal['length', 'stop'] + + +class SampleRequest(BaseModel): + """Request body for the /sample endpoint.""" + inputs: Any = Field(..., description='List of Trajectory or InputFeature dicts') + sampling_params: Optional[Dict[str, Any]] = Field( + None, description='Sampling parameters (max_tokens, temperature, etc.)') + adapter_name: str = Field('', description='Adapter name for LoRA inference') + adapter_uri: Optional[str] = Field( + None, description='Adapter URI (twinkle:// path or local path) for LoRA inference') + num_samples: int = Field(1, description='Number of completions to generate per prompt') + + +class SampledSequenceModel(BaseModel): + """A single sampled sequence, mirroring twinkle.data_format.SampledSequence.""" + stop_reason: StopReason = Field(..., description="Stop reason: 'length' or 'stop'") + tokens: List[int] = Field(..., description='Token IDs of the sampled sequence') + logprobs: Optional[List[float]] = Field(None, description='Per-token log-probabilities') + decoded: Optional[str] = Field(None, description='Decoded text of the sampled sequence') + new_input_feature: Optional[Dict[str, Any]] = Field( + None, description='Updated InputFeature after sampling (input_ids, labels, etc.)') + + +class SampleResponseModel(BaseModel): + """Response body for the /sample endpoint, mirroring twinkle.data_format.SampleResponse.""" + sequences: List[SampledSequenceModel] = Field( + ..., description='List of sampled sequences') + prompt_logprobs: Optional[List[Optional[float]]] = None + topk_prompt_logprobs: Optional[List[Optional[List[Tuple[int, float]]]]] = None + + +class SetTemplateRequest(BaseModel): + """Request body for the /set_template endpoint.""" + template_cls: str = Field(..., description="Template class name (e.g. 'Template')") + adapter_name: str = Field('', description='Adapter name to associate the template with') + + class Config: + extra = 'allow' + + +class SetTemplateResponse(BaseModel): + """Response body for the /set_template endpoint.""" + status: str = 'ok' + + +class AddAdapterRequest(BaseModel): + """Request body for the /add_adapter_to_sampler endpoint.""" + adapter_name: str = Field(..., description='Name of the adapter to add') + config: Any = Field(..., description='LoRA configuration dict') + + +class AddAdapterResponse(BaseModel): + """Response body for the /add_adapter_to_sampler endpoint.""" + status: str = 'ok' + adapter_name: str + + +class CreateResponse(BaseModel): + """Response body for the /create endpoint.""" + status: str = 'ok' diff --git a/src/twinkle_client/types/server.py b/src/twinkle_client/types/server.py new file mode 100644 index 00000000..df7ed58a --- /dev/null +++ b/src/twinkle_client/types/server.py @@ -0,0 +1,32 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. 
+"""Shared Pydantic response models for the twinkle server health/error endpoints.""" +from pydantic import BaseModel +from typing import Any + + +class HealthResponse(BaseModel): + status: str + + +class DeleteCheckpointResponse(BaseModel): + success: bool + message: str + + +class ErrorResponse(BaseModel): + detail: str + + +class WeightsInfoRequest(BaseModel): + twinkle_path: str + + +class WeightsInfoResponse(BaseModel): + """Response body for the /weights_info endpoint.""" + weights_info: Any + + +class CheckpointPathResponse(BaseModel): + """Response body for the /checkpoint_path endpoint.""" + path: str + twinkle_path: str diff --git a/src/twinkle_client/types/session.py b/src/twinkle_client/types/session.py new file mode 100644 index 00000000..f6b1adb7 --- /dev/null +++ b/src/twinkle_client/types/session.py @@ -0,0 +1,24 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Pydantic models for twinkle session management endpoints.""" +from pydantic import BaseModel +from typing import Any, Dict, Optional + + +class CreateSessionRequest(BaseModel): + """Request body for POST /twinkle/create_session.""" + metadata: Optional[Dict[str, Any]] = None + + +class CreateSessionResponse(BaseModel): + """Response body for POST /twinkle/create_session.""" + session_id: str + + +class SessionHeartbeatRequest(BaseModel): + """Request body for POST /twinkle/session_heartbeat.""" + session_id: str + + +class SessionHeartbeatResponse(BaseModel): + """Response body for POST /twinkle/session_heartbeat.""" + pass diff --git a/src/twinkle_client/types/training.py b/src/twinkle_client/types/training.py new file mode 100644 index 00000000..4c8cba83 --- /dev/null +++ b/src/twinkle_client/types/training.py @@ -0,0 +1,91 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +""" +Shared Pydantic models for twinkle training runs and checkpoints. + +These types are used both by twinkle_client (as request/response shapes) +and by twinkle.server.common.io_utils (as persistence models). 
+""" +from datetime import datetime +from pydantic import BaseModel +from typing import Any, Dict, List, Optional + + +class Cursor(BaseModel): + limit: int + offset: int + total_count: int + + +class Checkpoint(BaseModel): + """Twinkle checkpoint model.""" + checkpoint_id: str + checkpoint_type: str + time: datetime + size_bytes: int + public: bool = False + twinkle_path: str + # Training run info (stored for hub downloads) + base_model: Optional[str] = None + is_lora: bool = False + lora_rank: Optional[int] = None + train_unembed: Optional[bool] = None + train_mlp: Optional[bool] = None + train_attn: Optional[bool] = None + user_metadata: Optional[Dict[str, Any]] = None + + +class TrainingRun(BaseModel): + """Twinkle training run model.""" + training_run_id: str + base_model: str + model_owner: str + is_lora: bool = False + corrupted: bool = False + lora_rank: Optional[int] = None + last_request_time: Optional[datetime] = None + last_checkpoint: Optional[Dict[str, Any]] = None + last_sampler_checkpoint: Optional[Dict[str, Any]] = None + user_metadata: Optional[Dict[str, Any]] = None + + +class TrainingRunsResponse(BaseModel): + training_runs: List[TrainingRun] + cursor: Cursor + + +class CheckpointsListResponse(BaseModel): + checkpoints: List[Checkpoint] + cursor: Optional[Cursor] = None + + +class ParsedCheckpointTwinklePath(BaseModel): + """Twinkle-specific parsed path model.""" + path: str + twinkle_path: str + training_run_id: str + checkpoint_type: str + checkpoint_id: str + + +class WeightsInfoResponse(BaseModel): + """Twinkle weights info response.""" + training_run_id: str + base_model: str + model_owner: str + is_lora: bool = False + lora_rank: Optional[int] = None + + +class LoraConfig(BaseModel): + """Twinkle LoRA configuration.""" + rank: int = 8 + train_unembed: bool = False + train_mlp: bool = True + train_attn: bool = True + + +class CreateModelRequest(BaseModel): + """Twinkle create model request.""" + base_model: str + lora_config: Optional[LoraConfig] = None + user_metadata: Optional[Dict[str, Any]] = None diff --git a/src/twinkle_client/utils/patch_tinker.py b/src/twinkle_client/utils/patch_tinker.py index 826274ae..5f6d955e 100644 --- a/src/twinkle_client/utils/patch_tinker.py +++ b/src/twinkle_client/utils/patch_tinker.py @@ -53,10 +53,10 @@ def _patched_async_tinker_init( # Get api_key from environment if not provided if api_key is None: - api_key = os.environ.get('TINKER_API_KEY') + api_key = os.environ.get('TWINKLE_SERVER_TOKEN') if api_key is None: raise TinkerError( - 'The api_key client option must be set either by passing api_key to the client or by setting the TINKER_API_KEY environment variable' + 'The api_key client option must be set either by passing api_key to the client or by setting the TWINKLE_SERVER_TOKEN environment variable' ) # REMOVED: api_key 'tml-' prefix validation # Original code: