diff --git a/.gitignore b/.gitignore index bdb3f5f4..7b64dc3e 100644 --- a/.gitignore +++ b/.gitignore @@ -170,3 +170,5 @@ modelscope_cache prompts swarmexp swarmlog +werewolves_swarm +.claude diff --git a/ajet/backbone/main_trinity.py b/ajet/backbone/main_trinity.py index 7956305f..dc06c21c 100644 --- a/ajet/backbone/main_trinity.py +++ b/ajet/backbone/main_trinity.py @@ -52,7 +52,7 @@ def patched_trainer_get_actor(cls, config: Config): Explorer.get_actor = classmethod(patched_explorer_get_actor) Trainer.get_actor = classmethod(patched_trainer_get_actor) - if ajet_config.ajet.enable_experimental_interchange_server: + if ajet_config.ajet.enable_interchange_server: from ajet.tuner_lib.experimental.as_oai_model_server import start_interchange_server start_interchange_server(ajet_config) diff --git a/ajet/backbone/main_verl.py b/ajet/backbone/main_verl.py index 1e8c9c01..8eebb95f 100644 --- a/ajet/backbone/main_verl.py +++ b/ajet/backbone/main_verl.py @@ -67,7 +67,7 @@ def run_ppo(config: DictConfig) -> None: def on_shutdown(): if ray.is_initialized(): ray.shutdown() - if config.ajet.enable_experimental_interchange_server: + if config.ajet.enable_interchange_server: if config.ajet.enable_swarm_mode: from ajet.tuner_lib.experimental.interchange_utils import http_change_engine_status print("Changing engine status to OFFLINE before shutdown...") @@ -250,7 +250,7 @@ def run(self, config): from ajet.backbone.trainer_verl import AjetRayPPOTrainer - if config.ajet.enable_experimental_interchange_server: + if config.ajet.enable_interchange_server: from ajet.tuner_lib.experimental.as_oai_model_server import start_interchange_server start_interchange_server(config) diff --git a/ajet/backbone/main_vllm.py b/ajet/backbone/main_vllm.py index edbcb161..4e8b717c 100644 --- a/ajet/backbone/main_vllm.py +++ b/ajet/backbone/main_vllm.py @@ -186,7 +186,7 @@ def main(config): os.environ.update(runtime_env["env_vars"]) # atexit.register(lambda: print("Process exiting, performing 
cleanup...")) - if config.ajet.enable_experimental_interchange_server: + if config.ajet.enable_interchange_server: from ajet.tuner_lib.experimental.as_oai_model_server import start_interchange_server start_interchange_server(config) if config.ajet.enable_swarm_mode: diff --git a/ajet/backbone/trainer_verl.py b/ajet/backbone/trainer_verl.py index 5ca7ce83..28f09f95 100644 --- a/ajet/backbone/trainer_verl.py +++ b/ajet/backbone/trainer_verl.py @@ -457,7 +457,7 @@ def init_workers(self): ) def _update_interchange_server_status_flag(self, status: str): - if self.config.ajet.enable_experimental_interchange_server: + if self.config.ajet.enable_interchange_server: if self.config.ajet.enable_swarm_mode: from ajet.tuner_lib.experimental.interchange_utils import http_change_engine_status http_change_engine_status(self.config, status, global_step=self.global_steps) @@ -858,7 +858,7 @@ def fit(self): # noqa: C901 self.global_steps += 1 # # when enabled oai request interchange, we need to clear the cache from time to time - # if self.config.ajet.enable_experimental_interchange_server: + # if self.config.ajet.enable_interchange_server: # from ajet.tuner_lib.experimental.as_oai_model_server import ensure_dat_interchange_server_cache_clear # ensure_dat_interchange_server_cache_clear() diff --git a/ajet/backbone/warm_up.py b/ajet/backbone/warm_up.py index 192f4a0b..6e261b6a 100644 --- a/ajet/backbone/warm_up.py +++ b/ajet/backbone/warm_up.py @@ -54,7 +54,7 @@ def warm_up_task_judge_when_needed(config): def clean_up_tmp_ajet_dir(config): """Clean up old IPC socket files in /tmp/ajet directory.""" import time - if config.ajet.enable_experimental_interchange_server is False: + if config.ajet.enable_interchange_server is False: return tmp_dir = "/tmp/ajet" diff --git a/ajet/context_tracker/multiagent_tracking.py b/ajet/context_tracker/multiagent_tracking.py index b332a11d..9a6069fc 100644 --- a/ajet/context_tracker/multiagent_tracking.py +++ 
b/ajet/context_tracker/multiagent_tracking.py @@ -84,19 +84,6 @@ def extract_text_content_from_content_dict(self, msg): # }, # ], # } - # or tool_result format?? not observed yet: - # msg = { - # "role": "tool", - # "content": [ - # { - # "type": "tool_result", - # "id": "call_xxx", - # "output": "tool output content", - # "name": "tool_name" - # }, - # ], - # } - str_content = "" for item in msg["content"]: @@ -332,6 +319,7 @@ def save_llm_interaction_timeline(self, tools, llm_ext_msg, timeline): ) ): logger.bind(exception=True).info(f"General Warning: merge failure discovered.\n") + # from ajet import bp; bp("SWARM") return @@ -346,7 +334,9 @@ def detect_tool_call_madness(self, llm_output): # llm_output["tool_calls"] is not None, and is not [] tool_calls = llm_output["tool_calls"] if "wrong_toolcall" in self.config.ajet.rollout.compute_madness_checklist: - copy_tool_calls = copy.deepcopy(tool_calls) + # copy_tool_calls = copy.deepcopy(tool_calls) + # Shallow copy is sufficient - we're only reading the data + copy_tool_calls = tool_calls wrong_toolcall = False for i in range(len(copy_tool_calls)): if ("function" in copy_tool_calls[i]) and ( diff --git a/ajet/copilot/job.py b/ajet/copilot/job.py index 665c41cd..96ae2031 100644 --- a/ajet/copilot/job.py +++ b/ajet/copilot/job.py @@ -8,14 +8,13 @@ from __future__ import annotations import os +import time +import yaml import tempfile -from types import SimpleNamespace -from typing import Any, Callable, Union -import yaml +from types import SimpleNamespace +from typing import Any, Callable, Union, cast from loguru import logger - - from ajet.default_config.ajet_default import Config from ajet.utils.config_utils import ( expand_ajet_hierarchical_config, @@ -30,70 +29,118 @@ setup_environment_vars, ) -DEFAULT_DIR = "saved_experiments" + +def override_current_yaml_value_if_given(override_value, current_value): + if override_value is not None: + return override_value + else: + return current_value + +def 
_set_nested_attr(obj, attr_path: str, value): + keys = attr_path.split(".") + for key in keys[:-1]: + obj = getattr(obj, key) + setattr(obj, keys[-1], value) + +def _get_nested_attr(obj, attr_path: str): + for key in attr_path.split("."): + obj = getattr(obj, key) + return obj class AgentJetJob: - """Lightweight builder that launches AgentJet training as a subprocess.""" + """ + arg: base_yaml_config + **kwargs (yaml config, then override with kwargs) + arg: base_yaml_config (yaml config) + arg: **kwargs (yaml config, then override with kwargs) + """ def __init__( self, - backbone: str = "verl", - model: str = "Qwen/Qwen2___5-7B-Instruct", - n_gpu: int = 8, - algorithm: str = "grpo", - project_name="ajet-swarm", - experiment_name="test", - n_gpu_for_infer: int | None = None, # only for trinity backbone - num_repeat: int = 8, - batch_size: int = 32, - swarm_mode: bool = True, - sample_collection_method: str = "rollout_until_finish_enough_tasks", - *kwargs, + base_yaml_config: str | None = None, + experiment_dir: str | None = None, + project_name: str | None = None, + experiment_name: str | None = None, + n_gpu: int | None = None, + model: str | None = None, + algorithm: str | None = None, + num_repeat: int | None = None, + batch_size: int | None = None, + swarm_mode: bool | None = None, + swarm_mode_sample_collection_method: str | None = None, + max_env_worker: int | None = None, + backbone: str | None = None, ) -> None: - self.backbone = backbone - self.exp_dir = DEFAULT_DIR - self.project_name = project_name - self.exp_name = experiment_name - self.sample_collection_method = sample_collection_method - if swarm_mode: - default_yaml = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', "default_config/ajet_ts_default.yaml")) + + if base_yaml_config is None: + base_yaml_config = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', "default_config/ajet_ts_default.yaml")) else: - default_yaml = None - self.config_as_dict: dict = 
self.build_job_from_yaml(default_yaml) + logger.warning(f"Reading config from {base_yaml_config}.") + time.sleep(1) + self.config_as_dict: dict = self.build_job_from_yaml(base_yaml_config) self.config = Config.update_from_dict_recursive(Config(), self.config_as_dict) - self.config.ajet.experiment_name = experiment_name - self.config.ajet.backbone = backbone - self.config.ajet.model.path = model - self.config.ajet.trainer_common.n_gpus_per_node = n_gpu - self.config.ajet.trainer_common.algorithm.adv_estimator = algorithm - self.config.ajet.rollout.num_repeat = num_repeat - self.config.ajet.data.train_batch_size = batch_size - self.config.ajet.enable_swarm_mode = swarm_mode - self.config.ajet.swarm_mode_sample_collection_method = sample_collection_method - if n_gpu_for_infer is None and backbone == "trinity": - raise ValueError("Please specify `n_gpu_for_infer` (n_gpu_for_infer < n_gpu) for trinity backbone.") - if (n_gpu_for_infer is not None) and backbone == "verl": - raise ValueError("n_gpu_for_infer is only for trinity backbone, please set it to `None`.") - else: - if backbone == "trinity": - assert isinstance(n_gpu_for_infer, int), f"`n_gpu_for_infer` should be int, got {type(n_gpu_for_infer)}." - assert n_gpu_for_infer < n_gpu, "`n_gpu_for_infer` should be less than `n_gpu`." 
- self.config.ajet.rollout.n_vllm_engine = n_gpu_for_infer - self.config.ajet.rollout.tensor_model_parallel_size = 1 + self.base_yaml_config: str = cast(str, base_yaml_config) # currently may be None, but will be set later + self.experiment_dir: str = cast(str, experiment_dir) + self.project_name: str = cast(str, project_name) + self.experiment_name: str = cast(str, experiment_name) + self.n_gpu: int = cast(int, n_gpu) + self.model: str = cast(str, model) + self.algorithm: str = cast(str, algorithm) + self.num_repeat: int = cast(int, num_repeat) + self.batch_size: int = cast(int, batch_size) + self.swarm_mode: bool = cast(bool, swarm_mode) + self.swarm_mode_sample_collection_method: str = cast(str, swarm_mode_sample_collection_method) + self.max_env_worker: int = cast(int, max_env_worker) + self.backbone: str = cast(str, backbone) + + # see `ajet/default_config/ajet_ts_default.yaml` + overrides = { + "ajet.experiment_dir": "experiment_dir", + "ajet.project_name": "project_name", + "ajet.experiment_name": "experiment_name", + "ajet.model.path": "model", + "ajet.trainer_common.n_gpus_per_node": "n_gpu", + "ajet.trainer_common.algorithm.adv_estimator": "algorithm", + "ajet.rollout.num_repeat": "num_repeat", + "ajet.data.train_batch_size": "batch_size", + "ajet.enable_swarm_mode": "swarm_mode", + "ajet.swarm_mode_sample_collection_method": "swarm_mode_sample_collection_method", + "ajet.rollout.max_env_worker": "max_env_worker", + "ajet.backbone": "backbone", + } + + # if any value given in kwargs, override the corresponding value in config + for attr_path, override_val in overrides.items(): + # get value from yaml config + # >> e.g. current_model = self.config.model.path + current_val = _get_nested_attr(self.config, attr_path) + + # if override_val (given in __init__) is not None, use it to override the value from yaml config + # >> e.g. 
new_model = self.model if (self.model is not None) else current_model + new_val = override_current_yaml_value_if_given(getattr(self, override_val), current_val) + + # write final value to `self.config`` + # >> e.g. self.config.model.path = new_model + _set_nested_attr(self.config, attr_path, new_val) + + # write final value to `self` + # >> e.g. self.model = new_model + setattr(self, override_val, new_val) + + if self.backbone == "trinity": + raise NotImplementedError("Trinity backbone is not yet supported in AgentJetJob.") + def build_job_from_yaml(self, yaml_path: str | None) -> dict: self.config_as_dict = read_ajet_hierarchical_config( yaml_path, - exp_name=self.exp_name, - backbone=self.backbone, write_to=None, - exp_dir=self.exp_dir, ) self.config_as_dict = expand_ajet_hierarchical_config(self.config_as_dict, write_to=None) logger.info(f"Built AgentJet job config: {yaml_path}") return self.config_as_dict + def dump_job_as_yaml(self, yaml_path: str) -> str: if os.path.dirname(yaml_path): os.makedirs(os.path.dirname(yaml_path), exist_ok=True) @@ -102,6 +149,7 @@ def dump_job_as_yaml(self, yaml_path: str) -> str: logger.info(f"Saved training config to {yaml_path}") return yaml_path + def set_workflow( self, workflow: Union[str, Callable[..., Any]], ensure_reward_in_workflow: bool = False ) -> "AgentJetJob": @@ -110,6 +158,7 @@ def set_workflow( # ensure_reward_in_workflow return self + def set_data( self, type: str, @@ -136,60 +185,3 @@ def set_data( return self - def tune(self, *args, **kwargs) -> "AgentJetJob": - import ray - ast_cfg = self.config.ajet - if not ast_cfg.rollout or not ast_cfg.rollout.user_workflow: - raise ValueError("Workflow must be set via set_workflow before tuning.") - if not ast_cfg.task_reader: - raise ValueError("Data source must be set via set_data before tuning.") - - backbone = self.config.ajet.backbone - exp_dir = self.config.ajet.experiment_dir - - with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".yaml") as 
temp_yaml: - yaml_path = temp_yaml.name - self.dump_job_as_yaml(yaml_path) - args = SimpleNamespace( - conf=yaml_path, - backbone=backbone, - exp_dir=exp_dir, - with_logview=False, - debug=False, - ) - - if args.backbone != "debug": - # Enforce GPU availability and free memory threshold before proceeding - check_avail_gpu(min_free_ratio=0.95) - - # finalize experiment config - main_yaml_fp, exe_exp_base, exp_name, exp_config = prepare_experiment_config( - yaml_path, exp_dir, backbone - ) - - # setup environment variables for ray - env = setup_environment_vars(args, exp_config, main_yaml_fp) - - # start ray if not already started - if not ray.is_initialized(): - from ajet.utils.launch_utils import start_ray_service - - start_ray_service(args, env) - else: - raise RuntimeError( - "Ray is already initialized. Please shutdown existing Ray instance before starting a new tuning job." - ) - - # start training process - if args.conf and main_yaml_fp and exe_exp_base and exp_config: - execute_training_process( - args, - get_backbone_target(args.backbone), - main_yaml_fp, - exe_exp_base, - main_yaml_fp, - env, - exp_config, - ) - - return self diff --git a/ajet/copilot/train-complex-blackbox/SKILL.md b/ajet/copilot/train-complex-blackbox/SKILL.md new file mode 100644 index 00000000..6bd8e808 --- /dev/null +++ b/ajet/copilot/train-complex-blackbox/SKILL.md @@ -0,0 +1,174 @@ +--- +name: train-complex-blackbox +description: Create a trainable agent loop or agent workflow with AgentJet +license: Complete terms in LICENSE.txt +--- + + +## 0. Ask user for API key + model (or API key + base url + model) for debugging + +This is not 100% necessary, but it can help a lot in debugging in step 1. +If user has not given a API, ask user to give your one. + + +By default, the code you write should be located at ./tutorial/opencode_build_xxxxxx/*.py + +## 1. 
Initial Programming + +### Writing dataset collector (`get_training_dataset_item_list.py`) +- `get_training_dataset_item_list.py`: Returns a list of training data items. Maybe a list of training tasks, each item is a string identifier of a training task, or a dict containing necessary information for the training task. + +### Episode Runner (`run_episode_once.py`) +- `run_episode_once.py`: + + - Argument Parser: takes (training data item identifier + api-key + base-url) as input, model-name is not required, you can make up a model name because we ignore it. + + - Execute the agent: read the document of the agent user asked you to train, figure out how to execute the agent. In most cases you can use subprocess to start a commandline process to execute the agent, your biggest issue is to figure out how to pass the training data item identifier, api-key and base-url to that commandline process. You can also use python code to execute the agent if you think it's more convenient. + + - Reward: extract / compute the reward/score for the agent's output. Some agents have clear reward signal, but others don't. + - clear reward signal: take that down as the reward, no need to do extra reward engineering. + - no clear reward signal: you need to design a reward function to compute the reward/score for the agent's output. You can use another LLM to help you design the reward function, or you can design it by yourself if you have domain knowledge. + + +### Test + +Remember to test these two parts before moving to step 2, make sure they work as expected. + + + +## 2. Writing training code + +This part is easy, simply follow this template and change the necessary part such as dataset path, model name, etc. 
+ +`agent_roll.py` + +```python +# -*- coding: utf-8 -*- + +import os +import re +import requests +from textwrap import dedent +from ajet.schema.task import Task, WorkflowOutput +from ajet.copilot.job import AgentJetJob +from ajet.task_reader import RouterTaskReader +from ajet.utils.thread_executors import PeriodicDrainThreadPoolExecutor +from ajet.tuner_lib.as_oai_baseurl_apikey import OpenaiBaseUrlAndApiKey +from ajet.default_config.ajet_default import AjetTaskReader, HuggingfaceDatRepo +from ajet.tuner_lib.experimental.as_swarm_client import SwarmClient + +# python -m tutorial.example_math_swarm.math + +GRPO_N = 4 # grpo group size +NUM_EPOCH = 10000 +AJET_SWARM_URL = os.getenv("AJET_SWARM_URL", "http://localhost:10086") +REMOTE_MODEL_PATH = os.getenv("REMOTE_MODEL_PATH", "/mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2.5-7B-Instruct") +REMOTE_BATCH_SIZE = 32 +REMOTE_ALLOCATE_GPU_PER_NODE = 8 + +def main(): + + # Handshake with swarm remote, then send training param to swarm remote (such as model to be trained, algorithm, etc) + dataset = RouterTaskReader( + reader_type = "huggingface_dat_repo", + reader_config = AjetTaskReader( + huggingface_dat_repo = HuggingfaceDatRepo( + dataset_path = '/mnt/data_cpfs/model_cache/modelscope/dataset/openai/gsm8k/main', + # dataset_path = "/root/agentjet/benchmark_datasets/dataset/gsm8k/socratic", + # dataset_path = "openai/gsm8k", + # dataset_name = "main", + ) + ) + ) + # Load the CountDown dataset + # print(f"Loading dataset from: {LOCAL_DATASET_PATH}") + # dataset = RouterTaskReader( + # reader_type="jsonl_dataset_file", + # reader_config=AjetTaskReader( + # jsonl_dataset_file=JsonlDatasetFile( + # training=JsonlTrainingFp(file_path=LOCAL_DATASET_PATH) + # ) + # ), + # ) + + # Hand shake with remote swarm server + swarm_worker = SwarmClient(AJET_SWARM_URL) + ajet_job = AgentJetJob( + experiment_name="math_gsm8k_grpo", + algorithm="grpo", + n_gpu=REMOTE_ALLOCATE_GPU_PER_NODE, + model=REMOTE_MODEL_PATH, + 
batch_size=REMOTE_BATCH_SIZE, + num_repeat=GRPO_N, + ) + print(ajet_job.config.to_dict()) + swarm_worker.auto_sync_train_config_and_start_engine( + ajet_job, + force_restart=True, + ) + + def rollout(task): + # begin episode + episode_uuid, api_baseurl_key = swarm_worker.begin_episode(discard_episode_timeout=60) + # execute agent ( base_url = api_baseurl_key.base_url, api_key = api_baseurl_key.api_key ) + workflow_output = execute_agent(task, api_baseurl_key) # reward is in `workflow_output` + # report output back to swarm remote + swarm_worker.end_episode(task, episode_uuid, workflow_output) + return + + executor = PeriodicDrainThreadPoolExecutor(workers=GRPO_N * REMOTE_BATCH_SIZE, auto_retry=True) + for _ in range(NUM_EPOCH): + for _, task in enumerate(dataset.generate_training_tasks()): + for _ in range(GRPO_N): + executor.submit_with_periodic_drain(fn=rollout, task=task) + + return None + + +def execute_agent(task: Task, api_baseurl_key: OpenaiBaseUrlAndApiKey): + .... + raw_reward: float = ... # compute the reward for the agent's output + return WorkflowOutput(reward=raw_reward, metadata={"important_metadata": important_metadata}) + + +if __name__ == "__main__": + main() + + +``` + + +It is very clear now, your job in step 2 is to: + +- use `get_training_dataset_item_list.py` to generate `List[Task]` (`from ajet.schema.task import Task`) +- use `run_episode_once.py` to execute a single episode and place it in `execute_agent` function + + +## 3. Simplify your code and fix bugs + +before moving to step 4, you can simplify your code and fix bugs to make sure it can run smoothly. + + +## 4. Training + +Finally, you can start training. + +Run `ajet-swarm start` to start training server (if the user has already installed agentjet swarm environment), +if the user has docker environment, you can also refer to `docs/en/ajet-swarm-docker.md` to start a AgentSwarm docker container. 
+ +Create a duplication of `agent_roll.py` named `agent_roll_one_episode_debug.py`, and modify it to only run one episode, this can help you debug whether the episode runner and reward function work as expected. + +After the server side is ready, run +```bash +python /path/to/agent_roll_one_episode_debug.py +``` +watch console log to see if the episode can be executed successfully and reward can be computed correctly. + +If anything goes wrong, keep server running, rewrite and fix `agent_roll_one_episode_debug.py`, and run it again until it can run one episode successfully. + +Next, patch `agent_roll.py` if there are any bugs discovered via the debugging of `agent_roll_one_episode_debug.py`, and then run +```bash +python /path/to/agent_roll.py +``` + +to start the training! diff --git a/ajet/copilot/write-swarm-client/SKILL.md b/ajet/copilot/write-swarm-client/SKILL.md index c62693e5..0b98902d 100644 --- a/ajet/copilot/write-swarm-client/SKILL.md +++ b/ajet/copilot/write-swarm-client/SKILL.md @@ -4,25 +4,24 @@ description: Create a trainable agent loop or agent workflow with AgentJet license: Complete terms in LICENSE.txt --- -## 简介: -你的任务是根据要求,创建一个可训练 Agent (或者Agent Loop,多智能体系统等等),提供给用户做强化学习训练。 -在AgentJet强化学习框架下,这是非常简单的。 +## Introduction: -首先,根据用户的要求,给智能体系统起一个名字,例如 user_math_agent +Your task is to create a trainable Agent (or Agent Loop, multi-agent system, etc.) based on the requirements, and provide it to the user for reinforcement learning training. Under the AgentJet reinforcement learning framework, this is very simple. -其次,创建文件: -tutorial/user_math_agent +First, give the agent system a name based on the user's requirements, for example `user_math_agent`. 
-接下来,创建Agent源文件: -tutorial/user_math_agent/agent_roll.py (以 tutorial/example_academic_trans_swarm/trans_roll.py 为模板,变化不大,关键是向用户索取必要的参数) -tutorial/user_math_agent/agent_run.py (根据用户的要求,创建运行智能体的函数,或者类,都可以。同步异步都可以。) -tutorial/user_math_agent/readme.md (Agent说明,以及训练、调试方法说明) +Next, create the directory: +`tutorial/user_math_agent` +Then, create the Agent source files: +- `tutorial/user_math_agent/agent_roll.py` (Use `tutorial/example_academic_trans_swarm/trans_roll.py` as a template. There aren't many changes — the key is to ask the user for the necessary parameters.) +- `tutorial/user_math_agent/agent_run.py` (Create the function or class to run the agent based on the user's requirements. Synchronous or asynchronous are both fine.) +- `tutorial/user_math_agent/readme.md` (Agent description, along with training and debugging instructions.) -## 智能体编写方法 +## How to Write the Agent -使用 OpenAI SDK 编写智能体,主要包含以下三个函数(以及必要的子函数和子模块): +Write the agent using the OpenAI SDK. It mainly includes the following three functions (along with any necessary sub-functions and sub-modules): ``` from ajet.schema.task import Task, WorkflowOutput @@ -31,24 +30,68 @@ def _compute_reward(...) def _execute_agent(...) -def run_agent_and_compute_reward(task: Task, base_url:string, api_key:string) -> WorkflowOutput: +def run_agent_and_compute_reward(task: Task, base_url: string, api_key: string) -> WorkflowOutput: ``` -在 agent_roll 中,直接import run_agent_and_compute_reward即可。 +In `agent_roll`, simply import `run_agent_and_compute_reward`. -- 智能体的编写要领:通过一个或几个Agent的协作,高效完成用户给定的任务。 -- 奖励编写的要领:容易验证的,使用规则直接计算。不容易验证的,模仿 `tutorial/example_academic_trans_swarm/train_multi_model/trans_reward.py` 中的方法,使用其他大型模型生成 LLM as Judge 程序。 +- **Key points for writing the agent:** Efficiently complete the user's given task through the collaboration of one or several Agents. +- **Key points for writing the reward:** For things that are easy to verify, calculate directly using rules. 
For things that are hard to verify, follow the approach in `tutorial/example_academic_trans_swarm/train_multi_model/trans_reward.py` and use other large models to create an LLM-as-Judge program. +## Training and Debugging Instructions -## 训练、调试方法说明 +Overall, the user first runs `ajet-swarm start`, then runs `agent_roll.py`, and training begins. You do not need to and are not allowed to run these bash commands. +- First, help the user write `agent_run.py` and `agent_roll.py`. +- Then, write clear instructions to guide the user through training (`readme.md`). -总体而言,就是用户先运行 `ajet-swarm start`, 然后再运行 `agent_roll.py` 训练就开始了。你不需要也不被允许运行这些bash命令。 -- 首先帮助用户写好 `agent_run.py` 和 `agent_roll.py`, -- 然后写清楚引导用户训练的说明(readme.md), -你的任务就完成了。 +Your task is then complete. -以下是一些参考资料。 +Below are some reference materials. +--- + +## Introduction: + +Your task is to create a trainable Agent (or Agent Loop, multi-agent system, etc.) based on the requirements, and provide it to the user for reinforcement learning training. Under the AgentJet reinforcement learning framework, this is very simple. + +First, give the agent system a name based on the user's requirements, for example `user_math_agent`. + +Next, create the directory: +`tutorial/user_math_agent` + +Then, create the Agent source files: +- `tutorial/user_math_agent/agent_roll.py` (Use `tutorial/example_academic_trans_swarm/trans_roll.py` as a template. There aren't many changes — the key is to ask the user for the necessary parameters.) +- `tutorial/user_math_agent/agent_run.py` (Create the function or class to run the agent based on the user's requirements. Synchronous or asynchronous are both fine.) +- `tutorial/user_math_agent/readme.md` (Agent description, along with training and debugging instructions.) + +## How to Write the Agent + +Write the agent using the OpenAI SDK. 
It mainly includes the following three functions (along with any necessary sub-functions and sub-modules): + +``` +from ajet.schema.task import Task, WorkflowOutput + +def _compute_reward(...) + +def _execute_agent(...) + +def run_agent_and_compute_reward(task: Task, base_url: string, api_key: string) -> WorkflowOutput: +``` + +In `agent_roll`, simply import `run_agent_and_compute_reward`. + +- **Key points for writing the agent:** Efficiently complete the user's given task through the collaboration of one or several Agents. +- **Key points for writing the reward:** For things that are easy to verify, calculate directly using rules. For things that are hard to verify, follow the approach in `tutorial/example_academic_trans_swarm/train_multi_model/trans_reward.py` and use other large models to create an LLM-as-Judge program. + +## Training and Debugging Instructions + +Overall, the user first runs `ajet-swarm start`, then runs `agent_roll.py`, and training begins. You do not need to and are not allowed to run these bash commands. +- First, help the user write `agent_run.py` and `agent_roll.py`. +- Then, write clear instructions to guide the user through training (`readme.md`). + +Your task is then complete. + +Below are some reference materials. 
--- # Using AgentJet Swarm to Train Your Agents diff --git a/ajet/default_config/ajet_default.yaml b/ajet/default_config/ajet_default.yaml index 0f164971..1539d028 100644 --- a/ajet/default_config/ajet_default.yaml +++ b/ajet/default_config/ajet_default.yaml @@ -3,7 +3,7 @@ ajet: project_name: "ajet_default_project" experiment_name: "read_yaml_name" experiment_dir: "auto" # {exp-dir}/{experiment_name} - backbone: debug # `debug` or `trinity` or `verl` + backbone: verl # `debug` or `trinity` or `verl` model: @@ -85,6 +85,7 @@ ajet: num_repeat: 1 + task_reader: # how to read dataset / environment type: huggingface_dat_repo # `env_service` or `jsonl_dataset_file` or `huggingface_dat_repo` or `data_generation` or `random_dummy` @@ -284,7 +285,7 @@ ajet: # the experimental ZeroMQ interchange server feature that allows `tuner.as_oai_baseurl_apikey` feature enable_swarm_mode: False # both swarm / oai share the same interchange server - enable_experimental_interchange_server: False + enable_interchange_server: False # interchange server configuration interchange_server: interchange_method: 'ipc' # options: 'tcp' (multi-nodes) or 'ipc' (1 node) @@ -306,8 +307,6 @@ ajet: swarm_mode_sample_collection_max_cached_episodes: 9999 task_runner: - # submit llm infer submit method - llm_infer_submit_method: "async" # options: "sync", "async" # how to wrap the user-defined workflow wrapper_type: "asyncio-with-gc" diff --git a/ajet/default_config/ajet_ts_default.yaml b/ajet/default_config/ajet_ts_default.yaml index 1db9bdd5..90e3f4bd 100644 --- a/ajet/default_config/ajet_ts_default.yaml +++ b/ajet/default_config/ajet_ts_default.yaml @@ -3,15 +3,19 @@ ajet: project_name: "ajet_default_project" experiment_name: "read_yaml_name" experiment_dir: "auto" # {exp-dir}/{experiment_name} - backbone: debug # `debug` or `trinity` or `verl` + backbone: verl model: # which model should be trained - path: /mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2___5-3B-Instruct + path: 
"Qwen/Qwen2.5-7B-Instruct" rollout: # the path to the workflow class user_workflow: null + # maximum number of parallel environments / simulate workers + max_env_worker: 128 + # how many times a task should be repeated + num_repeat: 4 task_reader: type: random_dummy # `env_service` or `jsonl_dataset_file` or `huggingface_dat_repo` or `data_generation` or `random_dummy` @@ -21,7 +25,7 @@ ajet: judge_protocol: null # reward must come from remote user agent workflow, so set to null # the experimental ZeroMQ interchange server feature that allows `tuner.as_oai_baseurl_apikey` feature - enable_experimental_interchange_server: True + enable_interchange_server: True # train in cloud, run episode locally enable_swarm_mode: True # both swarm / oai share the same interchange server @@ -44,21 +48,29 @@ ajet: # (Hint: a **task_id** is considered "NON-DUMMY" at least one of **episodes** of **task_id** has **different** reward value.) swarm_mode_sample_collection_method: "rollout_until_finish_enough_tasks" - rollout: - # maximum number of parallel environments / simulate workers - max_env_worker: 128 + data: + # max number of tokens for prompt + max_prompt_length: 3000 + # max number of tokens for response + max_response_length: 15000 + # how many tasks per training batch + train_batch_size: 32 + # [Hint]: The final number of samples per update will be: N_{sample} = (data.train_batch_size * rollout.num_repeat * rollout.multi_turn.expected_steps) trainer_common: logger: tensorboard + n_gpus_per_node: 8 + algorithm: + adv_estimator: grpo -# ------------------ 不需要修改 ------------------ +# ------------------ do not edit ------------------ hydra: searchpath: - file://ajet/default_config - file://ajet/default_config/verl # verl only -# ------------------ 不需要修改 ------------------ +# ------------------ do not edit ------------------ defaults: - - verl_default # verl inherit 1/1 + - verl_default - ajet_default - _self_ diff --git a/ajet/launcher.py b/ajet/launcher.py index 
c9d5705b..d3a5b4cd 100644 --- a/ajet/launcher.py +++ b/ajet/launcher.py @@ -150,8 +150,8 @@ def start_swarm_server(env, config): assert config.ajet.enable_swarm_mode, ( "Please enable_swarm_mode in config to start swarm server." ) - assert config.ajet.enable_experimental_interchange_server, ( - "Please enable_experimental_interchange_server in config to start swarm server." + assert config.ajet.enable_interchange_server, ( + "Please enable_interchange_server in config to start swarm server." ) from ajet.tuner_lib.experimental.as_oai_model_server import ( start_interchange_server, @@ -168,7 +168,7 @@ def main(): from ajet.utils.swarm_overwatch import start_overwatch logger.info(f"Starting Swarm Overwatch for server: {args.swarm_overwatch}") - start_overwatch(args.swarm_overwatch, refresh_interval=1.0) + start_overwatch(args.swarm_overwatch, refresh_interval=2.0) return # Enforce GPU availability and free memory threshold before proceeding @@ -204,7 +204,6 @@ def main(): # read configuration from yaml exp_config = None - exp_dir = args.exp_dir or DEFAULT_DIR if args.swarm_server and (not args.conf): args.conf = os.path.abspath( os.path.join( @@ -215,6 +214,7 @@ def main(): "Please provide a valid config file for swarm server mode." 
) if args.conf: + exp_base_dir = args.exp_dir or DEFAULT_DIR yaml_path = args.conf ( main_yaml_fp, @@ -222,7 +222,10 @@ def main(): exp_name, exp_config, ) = prepare_experiment_config( - yaml_path, exp_dir, args.backbone, storage=(not args.swarm_server) + yaml_path=yaml_path, + exp_base_dir=exp_base_dir, + backbone=args.backbone, + storage=(not args.swarm_server) ) # setup environment variables diff --git a/ajet/schema/extended_msg.py b/ajet/schema/extended_msg.py index dfaa7460..5e789074 100644 --- a/ajet/schema/extended_msg.py +++ b/ajet/schema/extended_msg.py @@ -244,9 +244,11 @@ def get_inc_simple(self, text_frag_from, text_frag_to, tokenizer): tokenizer_output = tokenizer(text_frag_from, return_tensors="pt", padding=False) tokenizer_input_ids = tokenizer_output["input_ids"][0].tolist() token_ids_acc = tokenizer_input_ids + del tokenizer_output # Free memory immediately tokenizer_output = tokenizer(text_frag_to, return_tensors="pt", padding=False) input_ids = tokenizer_output["input_ids"][0].tolist() + del tokenizer_output # Free memory immediately # get the new tokens added in this step input_id_increment = input_ids[len(token_ids_acc) :] FN_DEBUG = False diff --git a/ajet/swarm_cli.py b/ajet/swarm_cli.py index eb5dd866..723d9b5a 100644 --- a/ajet/swarm_cli.py +++ b/ajet/swarm_cli.py @@ -24,8 +24,8 @@ def start_swarm_server(env, config, port): assert config.ajet.enable_swarm_mode, ( "Please enable_swarm_mode in config to start swarm server." ) - assert config.ajet.enable_experimental_interchange_server, ( - "Please enable_experimental_interchange_server in config to start swarm server." + assert config.ajet.enable_interchange_server, ( + "Please enable_interchange_server in config to start swarm server." 
) # Set the port in the config @@ -42,7 +42,7 @@ def start_swarm_server(env, config, port): def cmd_start(args): """Handle the 'start' subcommand.""" # Use default config if not provided - exp_dir = args.exp_dir or DEFAULT_DIR + exp_base_dir = args.exp_dir or DEFAULT_DIR if not args.conf: args.conf = os.path.abspath( os.path.join( @@ -61,7 +61,10 @@ def cmd_start(args): exp_name, exp_config, ) = prepare_experiment_config( - yaml_path, exp_dir, "verl", storage=False + yaml_path=yaml_path, + exp_base_dir=exp_base_dir, + backbone="verl", + storage=False ) # Setup environment variables @@ -73,7 +76,6 @@ def __init__(self, conf, backbone, exp_dir): self.swarm_server = True self.swarm_overwatch = "" self.debug = "" - swarm_args = SwarmArgs(args.conf, "verl", args.exp_dir) env, exp_config = setup_environment_vars(swarm_args, exp_config, main_yaml_fp) @@ -131,9 +133,9 @@ def main(): parser_overwatch.add_argument( "--refresh-interval", type=float, - default=1.0, + default=2.0, required=False, - help="Refresh interval in seconds (default: 1.0)", + help="Refresh interval in seconds (default: 2.0)", ) parser_overwatch.set_defaults(func=cmd_overwatch) diff --git a/ajet/task_rollout/async_llm_bridge.py b/ajet/task_rollout/async_llm_bridge.py index 3015f631..ced9cf16 100644 --- a/ajet/task_rollout/async_llm_bridge.py +++ b/ajet/task_rollout/async_llm_bridge.py @@ -1,9 +1,9 @@ -import asyncio import copy import json import time import uuid -from typing import Any, Callable, Dict, List, Literal, Union +from typing import Any, Callable, Dict, List, Literal, Union, Awaitable +from typing import TYPE_CHECKING from loguru import logger from omegaconf import DictConfig @@ -15,12 +15,13 @@ from ajet.schema.logprob import TokenAndProb from ajet.utils.tokenizer import ajet_apply_chat_template -from ajet.utils.async_utils import run_async_coroutine_with_timeout -from ajet.utils.testing_utils import _mock_if_test_mode, _test_if_test_mode from ajet.schema.convertion import 
convert_llm_proxy_response_to_oai_response from ajet.schema.convertion import convert_llm_proxy_response_to_agentscope_response from ajet.context_tracker.multiagent_tracking import MultiAgentContextTracker +if TYPE_CHECKING: + from vllm.entrypoints.openai.protocol import ChatCompletionRequest + ChatResponse = Union[OpenAIChatCompletion, AgentScopeChatResponse] @@ -58,207 +59,6 @@ def __init__( self.max_llm_retries = max_llm_retries self.tool_parser = Hermes2ProToolParser(self.tokenizer) - def get_llm_inference_fn_sync(self, sampling_params: dict = {}) -> Callable: # noqa: C901 - - def llm_chat_verl( - messages: List[Dict[str, str]], - custom_sampling_params: dict = {}, - tools=[], - request_id: str = "", - ) -> dict: - request_id = uuid.uuid4().hex - - updated_sampling_params = {} - if sampling_params: - updated_sampling_params.update(sampling_params) - if custom_sampling_params: - updated_sampling_params.update(custom_sampling_params) - - input_messages = copy.deepcopy(messages) - prompt_text = ajet_apply_chat_template( - tokenizer=self.tokenizer, - conversation=input_messages, - tools=tools, - add_generation_prompt=True, - tokenize=False, - ) - prompt_ids = self.tokenizer(prompt_text)["input_ids"] - - if self.config.ajet.execute_test: - _test_if_test_mode("prompt_text", prompt_text, self.config) - - final_res = run_async_coroutine_with_timeout( - self.async_rollout_manager.generate( - request_id=request_id, - prompt_ids=prompt_ids, - sampling_params=updated_sampling_params, - ), - timeout=1800, - ) - - if self.config.ajet.rollout.name == "vllm": - final_res: VerlVllmRequestOutput - token_array = final_res.outputs[0].token_ids - logprob_array = final_res.outputs[0].logprobs - elif self.config.ajet.rollout.name == "sglang": - token_array = final_res - - decoded_text = self.tokenizer.decode(token_array) # type: ignore - if self.config.ajet.execute_test: - decoded_text = _mock_if_test_mode("mock_decoded_text", decoded_text, self.config) - - if 
decoded_text.endswith("<|im_end|>"): - decoded_text = decoded_text[: -len("<|im_end|>")] - - # if tool call - tool_calls = None - if ( - ("" in decoded_text) - and ("" in decoded_text) - and (not self.config.ajet.rollout.force_disable_toolcalls) - ): - parsed_tool_calls = self.tool_parser.extract_tool_calls(decoded_text, None) # type: ignore - parsed_tool_calls = parsed_tool_calls.model_dump() - if self.config.ajet.execute_test: - _test_if_test_mode( - "parsed_tool_calls", parsed_tool_calls["tool_calls"], self.config - ) - model_called = parsed_tool_calls["tools_called"] - if model_called: - tool_calls = parsed_tool_calls["tool_calls"] - is_bad_toolcall = False - for i in range(len(tool_calls)): - if "function" in tool_calls[i] and "arguments" in tool_calls[i]["function"]: - expect_dict = json.loads(tool_calls[i]["function"]["arguments"]) - if not isinstance(expect_dict, dict): - is_bad_toolcall = True - if is_bad_toolcall: - tool_calls = None - decoded_text = decoded_text - else: - decoded_text = parsed_tool_calls["content"] - if decoded_text is None: - decoded_text = "" - - return { - "role": "assistant", - "request_id": request_id, - "content": decoded_text, - "tool_calls": tool_calls, - "tokens": [ - TokenAndProb( - token_id=token_id, - logprob=logprob[token_id].logprob, # Warning: vllm logprob does not participant training (not reliable enough), for log only. 
- decoded_string=logprob[token_id].decoded_token, - ) - for token_id, logprob in zip(token_array, logprob_array) # type: ignore - ], - } - - - def llm_chat_remote( - messages: List[Dict[str, str]], - custom_sampling_params: dict = {}, - tools=[], - request_id: str = "", - ) -> dict: - updated_sampling_params = {} - if sampling_params: - updated_sampling_params.update(sampling_params) - if custom_sampling_params: - updated_sampling_params.update(custom_sampling_params) - updated_sampling_params.update({"logprobs": 1, "return_tokens_as_token_ids": True}) - input_messages = copy.deepcopy(messages) - for i in range(self.max_llm_retries): - try: - # this function is defined in `ajet/backbone/main_vllm.py` - output_message = self.async_rollout_manager.submit_chat_completions( - messages=input_messages, - sampling_params=updated_sampling_params, - tools=tools, - request_id=request_id, - ) - break - except Exception as e: - logger.bind(exception=True).exception(f"rollout_server.{i} error: {e.args}") - time.sleep(i + 1) - return output_message[-1] # type: ignore - - - def llm_chat_trinity( - messages: List[Dict[str, str]], - custom_sampling_params: dict = {}, - tools=[], - request_id: str = "", - ) -> dict: - async def main(): - updated_sampling_params = {} - if sampling_params: - updated_sampling_params.update(sampling_params) - if custom_sampling_params: - updated_sampling_params.update(custom_sampling_params) - updated_sampling_params.pop("min_tokens") - - if tools: - response = await self.async_rollout_manager.chat.completions.create( - model=self.async_rollout_manager.model_path, - messages=messages, - logprobs=True, - tools=tools, - top_logprobs=0, - **updated_sampling_params, - ) - else: - response = await self.async_rollout_manager.chat.completions.create( - model=self.async_rollout_manager.model_path, - messages=messages, - logprobs=True, - top_logprobs=0, - **updated_sampling_params, - ) - return response - - response = run_async_coroutine_with_timeout(main(), 
timeout=1800) # type: ignore - prompt_text = self.tokenizer.decode(response.model_extra["prompt_token_ids"]) - prompt_token_ids = response.model_extra["prompt_token_ids"] - content = response.choices[0].message.content - message = response.choices[0].message.model_dump(exclude_unset=True, exclude_none=True) - - if content is None: - content = "" - - if ("" in content) and (not message.get("tool_calls", None)): - # logger.bind(exception=True).exception(f"Bad toolcall discovered \n\nprompt_text:\n{prompt_text}\n\nrepsonse:\n{content}") - logger.warning(f"Bad toolcall discovered: {content}") - - return { - "role": "assistant", - "request_id": response.id, - "content": content, - "prompt_text": prompt_text, - "prompt_token_ids": prompt_token_ids, - "tool_calls": message.get("tool_calls", []), - "tokens": [ - TokenAndProb( - token_id=token, - logprob=tokenlogprob.logprob, # Warning: vllm logprob does not participant training, for log only. - decoded_string=tokenlogprob.token, - ) - for tokenlogprob, token in zip( - response.choices[0].logprobs.content, - response.choices[0].token_ids, - ) - ], - } - - if self.llm_mode == "remote": - return llm_chat_remote - if self.llm_mode == "trinity": - return llm_chat_trinity - else: - return llm_chat_verl - - def get_llm_inference_fn_async(self, sampling_params: dict = {}) -> Callable: # noqa: C901 @@ -286,9 +86,6 @@ async def llm_chat_verl( ) prompt_ids = self.tokenizer(prompt_text)["input_ids"] - if self.config.ajet.execute_test: - _test_if_test_mode("prompt_text", prompt_text, self.config) - final_res = await self.async_rollout_manager.generate( request_id=request_id, prompt_ids=prompt_ids, @@ -303,13 +100,11 @@ async def llm_chat_verl( token_array = final_res decoded_text = self.tokenizer.decode(token_array) # type: ignore - if self.config.ajet.execute_test: - decoded_text = _mock_if_test_mode("mock_decoded_text", decoded_text, self.config) if decoded_text.endswith("<|im_end|>"): decoded_text = decoded_text[: 
-len("<|im_end|>")] - # if tool call + # if tool call, use vLLM tool parser to extract tool calls and validate them tool_calls = None if ( ("" in decoded_text) @@ -319,10 +114,7 @@ async def llm_chat_verl( parsed_tool_calls = self.tool_parser.extract_tool_calls(decoded_text, None) # type: ignore parsed_tool_calls = parsed_tool_calls.model_dump() - if self.config.ajet.execute_test: - _test_if_test_mode( - "parsed_tool_calls", parsed_tool_calls["tool_calls"], self.config - ) + model_called = parsed_tool_calls["tools_called"] if model_called: tool_calls = parsed_tool_calls["tool_calls"] @@ -474,7 +266,7 @@ class OpenaiLlmProxyWithTracker(object): def __init__( self, - llm_inference_fn: Callable, # Callable[AjetStandardLlmBridgeRequest, AjetStandardLlmBridgeResponse] + llm_inference_fn: Callable[..., Awaitable[Dict]], # Callable[AjetStandardLlmBridgeRequest, AjetStandardLlmBridgeResponse] context_tracker: MultiAgentContextTracker, config, ) -> None: @@ -483,15 +275,39 @@ def __init__( self.config = config + async def chat_completion_request( + self, + req: "ChatCompletionRequest", + timeline_uuid: str, + agent_name: str, + target_tag: str, + episode_uuid: str, + ): + from openai.types.chat.chat_completion import ChatCompletion + req_as_dict = req.model_dump() + + # infer + process with context tracker + llm_output = await self.run_infer( + messages=req_as_dict["messages"], + tools=req_as_dict["tools"], + tool_choice="auto", + ) + # convert to OpenAI ChatCompletion format + response: ChatCompletion = convert_llm_proxy_response_to_oai_response(llm_output) + # this is an important id assignment + response.id = timeline_uuid + assert isinstance(response, ChatCompletion) + return response + + async def __call__( self, messages: List[dict], tools: List = [], tool_choice: str = "auto", - structured_model=None, **kwargs, ) -> ChatResponse: - llm_output = await self.run_infer(messages, tools, tool_choice, structured_model, **kwargs) + llm_output = await self.run_infer(messages, 
tools, tool_choice, **kwargs) return convert_llm_proxy_response_to_oai_response(llm_output) @@ -500,9 +316,8 @@ async def run_infer( messages: List[dict], tools: List = [], tool_choice: str = "auto", # always auto - structured_model=None, # this is for AgentScope only **kwargs, - ): + ) -> Dict: # generate timeline uuid timeline_uuid = uuid.uuid4().hex @@ -527,16 +342,10 @@ async def run_infer( # else: # otherwise, for abnormal output, can still proceed, but we do not track output anymore - # run llm inference ✨ - if self.config.ajet.task_runner.llm_infer_submit_method == "sync": - llm_output = await asyncio.to_thread( - self.llm_inference_fn, converted_message, custom_sampling_params, tools - ) - else: - llm_output = await self.llm_inference_fn(converted_message, custom_sampling_params, tools) - + # run llm inference ✨ (llm_chat_verl) + llm_output = await self.llm_inference_fn(converted_message, custom_sampling_params, tools) - # begin context tracking + # context tracking self.context_tracker.step_track(llm_output, context_safe, converted_message, tools, timeline_uuid=timeline_uuid) return llm_output @@ -554,7 +363,6 @@ def construct_overflow_response(self): - # ---------------------------------------------------------------------------------------------- # ------------------------ call async llm with context tracker (AgentScope) -------------------- # ---------------------------------------------------------------------------------------------- @@ -570,6 +378,6 @@ async def __call__( **kwargs, ) -> AgentScopeChatResponse: - llm_output = await self.run_infer(messages, tools, tool_choice, structured_model) + llm_output = await self.run_infer(messages, tools, tool_choice) response = convert_llm_proxy_response_to_agentscope_response(llm_output, structured_model=structured_model) return response diff --git a/ajet/task_rollout/native_parallel_worker.py b/ajet/task_rollout/native_parallel_worker.py index 79f0bdcd..8db4058c 100644 --- 
a/ajet/task_rollout/native_parallel_worker.py +++ b/ajet/task_rollout/native_parallel_worker.py @@ -1,7 +1,9 @@ """Parallel environment rollout orchestration utilities.""" import os +import gc import time +import tracemalloc from concurrent.futures import Future, ThreadPoolExecutor from typing import Dict, List, Literal from urllib.parse import quote @@ -9,6 +11,7 @@ import numpy as np import torch import threading +from math import ceil from loguru import logger from tensordict import TensorDict from torch.nn.utils.rnn import pad_sequence @@ -89,6 +92,64 @@ def _write_swarm_rollout_dynamic_log(self, observation_window): f.write(string_buffer) return + def _check_memory_leak(self): + """Check for memory leaks by comparing memory snapshots.""" + if not self._tracemalloc_started: + tracemalloc.start() + self._tracemalloc_started = True + logger.info("Memory tracking started (tracemalloc)") + self._memory_snapshot = tracemalloc.take_snapshot() + return + + # Take a new snapshot + gc.collect() # Force garbage collection before snapshot + current_snapshot = tracemalloc.take_snapshot() + + if self._memory_snapshot is not None: + # Compare snapshots + top_stats = current_snapshot.compare_to(self._memory_snapshot, 'lineno') + + logger.info("=" * 80) + logger.info("Memory Leak Detection: Top 10 differences since last rollout_swarm call") + logger.info("=" * 80) + + total_size_diff = 0 + for stat in top_stats[:10]: + total_size_diff += stat.size_diff + logger.info(f"{stat}") + + # Convert to MB + total_size_diff_mb = total_size_diff / 1024 / 1024 + logger.info(f"\nTotal memory difference: {total_size_diff_mb:.2f} MB") + + # Show top current memory consumers + logger.info("\n" + "=" * 80) + logger.info("Top 10 current memory allocations") + logger.info("=" * 80) + top_current = current_snapshot.statistics('lineno') + for stat in top_current[:10]: + logger.info(f"{stat}") + + logger.info("=" * 80) + + # Enhanced leak detection: show traceback for largest leak + if 
total_size_diff_mb > 10: # Only if leak is significant (>10MB) + logger.warning(f"SIGNIFICANT MEMORY LEAK DETECTED: {total_size_diff_mb:.2f} MB") + logger.info("\n" + "=" * 80) + logger.info("Detailed traceback for top 3 memory leaks:") + logger.info("=" * 80) + for i, stat in enumerate(top_stats[:3], 1): + if stat.size_diff > 0: + logger.info(f"\n--- Leak #{i}: +{stat.size_diff / 1024 / 1024:.2f} MB, {stat.count_diff} objects ---") + logger.info(f"File: {stat.traceback.format()[0] if stat.traceback else 'Unknown'}") + if stat.traceback and len(stat.traceback) > 1: + logger.info("Full traceback:") + for line in stat.traceback.format(): + logger.info(f" {line}") + logger.info("=" * 80) + + # Update snapshot for next comparison + self._memory_snapshot = current_snapshot def rollout_static( self, @@ -173,10 +234,13 @@ def rollout_swarm( # noqa: C901 each thread re-spawn after complete, until reaching conditions to stop. """ + # # Memory leak detection: compare with previous snapshot + # self._check_memory_leak() + tracker_array: List[SingleAgentContextTracker] = [] rollout_n = self.rollout_n n_batch_task = len(tasks) - n_task = min(len(tasks), self.max_parallel // rollout_n) + n_task = min(len(tasks), ceil(self.max_parallel / rollout_n)) assert n_task > 0, f"n_task is not valid, n_task = min(len(tasks), self.max_parallel // rollout_n) = {n_task}" self.current_token_count_time = time.time() @@ -232,7 +296,7 @@ def enough_finished_task_stop_condition(completed_task_id_map_ct) -> bool: f"Too many cached episodes [{total_completed_episodes}] has exceeded the max cached episodes [{self.config.ajet.swarm_mode_sample_collection_max_cached_episodes}] " f"Deleting cached episodes to release memory..." 
) - completed_task_id_map_ct = {} + completed_task_id_map_ct.clear() return (total_completed_tasks >= n_batch_task) def enough_non_dummy_task_stop_condition(completed_task_id_map_ct) -> bool: @@ -257,7 +321,7 @@ def enough_non_dummy_task_stop_condition(completed_task_id_map_ct) -> bool: f"Too many cached episodes [{total_completed_episodes}] has exceeded the max cached episodes [{self.config.ajet.swarm_mode_sample_collection_max_cached_episodes}] " f"Deleting cached episodes to release memory..." ) - completed_task_id_map_ct = {} + completed_task_id_map_ct.clear() return (total_completed_non_dummy_tasks >= n_batch_task) # select stop condition function based on config @@ -370,22 +434,23 @@ def update_rollout_result_array_preview(observation_window, completed_task_id_ma self._write_swarm_rollout_dynamic_log(observation_window) meet_stop_condition_after_new_results = stop_condition(completed_task_id_map_ct) if meet_stop_condition_after_new_results: - print("Sending soft stop signal to all threads...") + logger.info("Sending soft stop signal to all threads...") stop_all_threads_soft() break # wait for all threads to complete - print('Finalizing all threads...') + logger.info('Finalizing all threads...') executor.shutdown(wait=True) # stop all threads hard - print("Sending hard stop signal to all threads...") + logger.info("Sending hard stop signal to all threads...") stop_all_threads_hard() # build tracker_array - print('Collecting results...') + logger.info('Collecting results...') for ct_list in completed_task_id_map_ct.values(): tracker_array.extend(ct_list) + completed_task_id_map_ct.clear() # TODO: support multi-step reward task_success_rate = np.mean( @@ -402,6 +467,20 @@ def update_rollout_result_array_preview(observation_window, completed_task_id_ma update_rollout_result_array_preview(observation_window, completed_task_id_map_ct) self._write_swarm_rollout_dynamic_log(observation_window) + # Explicit cleanup to prevent memory leaks + logger.debug("Performing 
explicit cleanup...") + # Clear futures list + futures.clear() + # Clear observation window + observation_window.clear() + # Delete local function references to break circular refs + del stop_condition_callback + del stop_condition + del update_rollout_result_array_preview + del count_tasks + # Force garbage collection + gc.collect() + return tracker_array diff --git a/ajet/task_rollout/single_worker.py b/ajet/task_rollout/single_worker.py index 4791a47f..baca7e34 100644 --- a/ajet/task_rollout/single_worker.py +++ b/ajet/task_rollout/single_worker.py @@ -72,6 +72,10 @@ def __init__( max_llm_retries=max_llm_retries, ) + # Memory leak tracking + self._memory_snapshot = None + self._tracemalloc_started = False + @retry_with_backoff(max_retry_attr="max_llm_retries") def rollout_env_worker( self, @@ -90,14 +94,9 @@ def rollout_env_worker( """ sampling_params = get_sample_params(mode, self.config) - if self.config.ajet.task_runner.llm_infer_submit_method == "sync": - llm_inference_fn = self.async_llm_bridge.get_llm_inference_fn_sync( - sampling_params=sampling_params - ) - else: - llm_inference_fn = self.async_llm_bridge.get_llm_inference_fn_async( - sampling_params=sampling_params - ) + llm_inference_fn = self.async_llm_bridge.get_llm_inference_fn_async( + sampling_params=sampling_params + ) episode_uuid = uuid.uuid4().hex workflow_task = WorkflowTask( diff --git a/ajet/tuner.py b/ajet/tuner.py index b780d44a..45a54425 100644 --- a/ajet/tuner.py +++ b/ajet/tuner.py @@ -23,7 +23,7 @@ def __init__( self.context_tracker = context_tracker self.llm_inference_fn = llm_inference_fn self.target2proxy_registry: dict[str, dict[str,TunerTypeUnion]] = {} - self.enable_interchange_server = config.ajet.enable_experimental_interchange_server + self.enable_interchange_server = config.ajet.enable_interchange_server if self.enable_interchange_server: self.proxy_client_started = False @@ -102,10 +102,10 @@ def as_oai_baseurl_apikey( ``` """ - assert self.enable_interchange_server, 
"Please enable `ajet.enable_experimental_interchange_server` in yaml config to use `as_oai_baseurl_apikey` feature." + assert self.enable_interchange_server, "Please enable `ajet.enable_interchange_server` in yaml config to use `as_oai_baseurl_apikey` feature." if self.proxy_client_started is False: self.proxy_client_started = True - self._enable_experimental_interchange_server(self.llm_inference_fn) + self._enable_interchange_server(self.llm_inference_fn) baseurl_apikey_model = OpenaiClientBaseUrlTuner( config=self.config, context_tracker=self.context_tracker, @@ -168,7 +168,7 @@ def get_context_tracker(self) -> MultiAgentContextTracker: return self.context_tracker - def _enable_experimental_interchange_server(self, llm_inference_fn): + def _enable_interchange_server(self, llm_inference_fn): # experimental reverse proxy start if self.enable_interchange_server: from ajet.tuner_lib.experimental.as_oai_model_client import InterchangeClient diff --git a/ajet/tuner_lib/as_oai_baseurl_apikey.py b/ajet/tuner_lib/as_oai_baseurl_apikey.py index cc40c61b..0931e980 100644 --- a/ajet/tuner_lib/as_oai_baseurl_apikey.py +++ b/ajet/tuner_lib/as_oai_baseurl_apikey.py @@ -12,11 +12,13 @@ class MockAsyncCompletions(AsyncCompletions): async def create(self, *args, **kwargs) -> Any: # type: ignore return await self._client.create(*args, **kwargs) # type: ignore + class MockAsyncChat(AsyncChat): @property def completions(self) -> MockAsyncCompletions: # type: ignore return MockAsyncCompletions(self._client) + class OpenaiBaseUrlAndApiKey(BaseModel): """ At this layer, we will determine which model to use: - training model @@ -29,13 +31,17 @@ class OpenaiBaseUrlAndApiKey(BaseModel): episode_uuid: str = Field(default="episode_id", description="reserved field.") def as_agentscope_model(self, *args, **kwargs): - from agentscope.model import DashScopeChatModel - return DashScopeChatModel(model_name="AgentJet-Model", api_key=self.api_key, base_http_api_url=self.base_url) + from 
agentscope.model import OpenAIChatModel + return OpenAIChatModel( + model_name="AgentJet-Model", api_key=self.api_key, + client_args={"base_url": self.base_url} + ) def as_raw_openai_sdk_client(self, *args, **kwargs): from openai import AsyncOpenAI return AsyncOpenAI(api_key=self.api_key, base_url=self.base_url) + class OpenaiClientBaseUrlTuner(BaseModel): """ At this layer, we will determine which model to use: - training model diff --git a/ajet/tuner_lib/experimental/as_oai_model_client.py b/ajet/tuner_lib/experimental/as_oai_model_client.py index 8288b609..aaecde5c 100644 --- a/ajet/tuner_lib/experimental/as_oai_model_client.py +++ b/ajet/tuner_lib/experimental/as_oai_model_client.py @@ -5,19 +5,17 @@ import os import time import zmq -import base64 import json from loguru import logger from typing import TYPE_CHECKING -from openai.types.chat.chat_completion import ChatCompletion from ajet.tuner_lib.experimental.as_oai_model_server import InterchangeCompletionRequest from ajet.utils.thread_executors import SharedInferenceTrackerThreadExecutor, SharedInterchangeThreadExecutor from ajet.tuner_lib.experimental.interchange_utils import get_zmq_socket -from ajet.tuner_lib.experimental.interchange_utils import DEBUG, API_KEY_PREFIX +from ajet.tuner_lib.experimental.interchange_utils import DEBUG if TYPE_CHECKING: - from vllm.entrypoints.openai.protocol import ChatCompletionRequest + pass context = zmq.Context() atexit.register(context.term) @@ -31,6 +29,7 @@ class InterchangeClient: """ def __init__(self, episode_uuid: str, context_tracker: "MultiAgentContextTracker", llm_inference_fn, config): + from ajet.task_rollout.async_llm_bridge import OpenaiLlmProxyWithTracker self.episode_uuid = episode_uuid self.context_tracker = context_tracker self.llm_inference_fn = llm_inference_fn @@ -40,37 +39,12 @@ def __init__(self, episode_uuid: str, context_tracker: "MultiAgentContextTracker self.ipc_path = ipc_path self.interchange_method = 
config.ajet.interchange_server.interchange_method self.max_inference_tracker_threads = config.ajet.interchange_server.max_inference_tracker_threads - - async def llm_infer( - self, - req: "ChatCompletionRequest", - timeline_uuid: str, - agent_name: str, - target_tag: str, - episode_uuid: str, - ) -> ChatCompletion: - from ajet.task_rollout.async_llm_bridge import OpenaiLlmProxyWithTracker - - req_as_dict = req.model_dump() self.llm_proxy_with_tracker = OpenaiLlmProxyWithTracker( context_tracker=self.context_tracker, config=self.config, llm_inference_fn=self.llm_inference_fn, ) - # infer + process with context tracker - response = await self.llm_proxy_with_tracker( - messages=req_as_dict["messages"], - tools=req_as_dict["tools"], - tool_choice="auto", - ) - - # this is an important id assignment - response.id = timeline_uuid - assert isinstance(response, ChatCompletion) - return response - - @property def should_soft_terminate(self) -> bool: if self._should_terminate: @@ -98,15 +72,15 @@ def begin_service(self): if DEBUG: logger.info(f"[client] {self.episode_uuid} | Starting InterchangeClient service loop...") self.socket = context.socket(zmq.REP) self.socket.bind(f"{self.episode_contect_address}") - self.socket.setsockopt(zmq.RCVTIMEO, 1*1000) # 3 second timeout for REP + self.socket.setsockopt(zmq.RCVTIMEO, 1*1000) # 1 second timeout for REP self.executor = SharedInterchangeThreadExecutor(self.max_inference_tracker_threads).get_shared_executor() if DEBUG: logger.info(f"[client] {self.episode_uuid} | Submitting _begin_service_threading to executor...") future = self.executor.submit(self._begin_service_threading) # wait till service begin running - time.sleep(0.5) wait_time = 1 + time.sleep(wait_time) while future._state == 'PENDING': if self.should_soft_terminate or self.should_hard_terminate: future.cancel() @@ -130,12 +104,15 @@ def _begin_service_threading(self): try: while not self.should_hard_terminate: - # listen for next request from remote try: - # if 
DEBUG: logger.info(f"[client] {self.episode_uuid} | socket.recv_string() has begun (should_terminate {self.should_terminate})") + + # : + # : ajet/tuner_lib/experimental/as_oai_model_server.py + # : socket.send_string(int_req.model_dump_json()) + # : InterchangeCompletionRequest object in JSON string format message = self.socket.recv_string() + ever_receive_anything = True - # if DEBUG: logger.info(f"[client] {self.episode_uuid} | socket.recv_string() is done") except zmq.Again as e: if self.should_hard_terminate: # abort_episode() @@ -154,27 +131,44 @@ def _begin_service_threading(self): # begin to run the llm request, monitored by context tracker # we re-use previously created thread for best performance if DEBUG: logger.info(f"[client] {self.episode_uuid} | before asyncio run self.llm_infer") + + # Check if there's a running event loop try: loop = asyncio.get_running_loop() - except: + created_new_loop = False + except RuntimeError: + # No running loop, create a new one loop = asyncio.new_event_loop() - context_tracker_executor = SharedInferenceTrackerThreadExecutor(self.max_inference_tracker_threads).get_shared_executor() - future = loop.run_in_executor( - context_tracker_executor, - asyncio.run, - self.llm_infer( - req=parsed_msg.completion_request, - timeline_uuid=parsed_msg.timeline_uuid, - agent_name=parsed_msg.agent_name, - target_tag=parsed_msg.target_tag, - episode_uuid=parsed_msg.episode_uuid, + asyncio.set_event_loop(loop) + created_new_loop = True + + try: + context_tracker_executor = SharedInferenceTrackerThreadExecutor(self.max_inference_tracker_threads).get_shared_executor() + future = loop.run_in_executor( + context_tracker_executor, + asyncio.run, + self.llm_proxy_with_tracker.chat_completion_request( + req=parsed_msg.completion_request, + timeline_uuid=parsed_msg.timeline_uuid, + agent_name=parsed_msg.agent_name, + target_tag=parsed_msg.target_tag, + episode_uuid=parsed_msg.episode_uuid, + ) ) - ) - result = 
loop.run_until_complete(future).model_dump_json() # type: ignore + result = loop.run_until_complete(future).model_dump_json() # type: ignore + finally: + # Clean up the event loop if we created it + if created_new_loop: + loop.close() + asyncio.set_event_loop(None) - # great, let's send back the result if DEBUG: logger.info(f"[client] {self.episode_uuid} | before send_string (send llm call result)") + + # + # : ajet/tuner_lib/experimental/as_oai_model_server.py + # : result_str = socket.recv_string() self.socket.send_string(result) + if DEBUG: logger.info(f"[client] {self.episode_uuid} | after send_string (send llm call result)") except: logger.exception(f"[client] {self.episode_uuid} | Exception occurred in service loop.") diff --git a/ajet/tuner_lib/experimental/as_oai_model_server.py b/ajet/tuner_lib/experimental/as_oai_model_server.py index 301483ff..367c8086 100644 --- a/ajet/tuner_lib/experimental/as_oai_model_server.py +++ b/ajet/tuner_lib/experimental/as_oai_model_server.py @@ -24,7 +24,9 @@ from loguru import logger from pydantic import BaseModel +from functools import lru_cache from fastapi import FastAPI, Header, HTTPException, Request +from fastapi.responses import StreamingResponse from contextlib import asynccontextmanager from multiprocessing import Manager, Process from concurrent.futures import ThreadPoolExecutor @@ -32,6 +34,9 @@ from vllm.entrypoints.openai.protocol import ChatCompletionRequest from openai.types.chat.chat_completion import ChatCompletion +from openai.types.chat.chat_completion_chunk import ChatCompletionChunk +from openai.types.chat.chat_completion_chunk import Choice as ChunkChoice +from openai.types.chat.chat_completion_chunk import ChoiceDelta, ChoiceDeltaToolCall, ChoiceDeltaToolCallFunction from ajet.utils.networking import get_host_ip from ajet.tuner_lib.experimental.interchange_utils import EpisodeStatus @@ -44,6 +49,7 @@ class InterchangeCompletionRequest(BaseModel): target_tag: str episode_uuid: str timeline_uuid: str + 
preserve_sampling_params: bool = False class HealthCheckRequest(BaseModel): agent_name: str @@ -58,6 +64,9 @@ class HealthCheckRequest(BaseModel): context = zmq.Context() atexit.register(context.term) +@lru_cache(maxsize=128) +def ep_key(episode_uuid: str) -> str: + return f"episodes-{episode_uuid}" def get_app(max_fastapi_threads: int = 512, enable_swarm_mode=False, shared_mem_dict=None, shared_mem_dict_lock=None) -> Tuple[FastAPI, Optional[Coroutine]]: @@ -85,14 +94,33 @@ def _begin_handle_chat_completion(episode_address, int_req: InterchangeCompletio socket.setsockopt(zmq.RCVTIMEO, 6*1000) # 6 second recv timeout socket.connect(f"{episode_address}") if DEBUG: logger.info(f"[server] episode_uuid: {episode_uuid} | connect done") + + # + # : ajet/tuner_lib/experimental/as_oai_model_client.py + # : message = self.socket.recv_string() socket.send_string(int_req.model_dump_json()) + if DEBUG: logger.info(f"[server] episode_uuid: {episode_uuid} | send_string") result_str = "" for _ in range(50): # max 5 minutes wait + + if enable_swarm_mode: + assert shared_mem_dict is not None + ep_stat = shared_mem_dict[ep_key(episode_uuid)] + episode_status = ep_stat.episode_status + if episode_status != "claimed": + raise HTTPException(status_code=404, detail="The episode is not claimed, cannot accept new requests.") + try: if DEBUG: logger.info(f"[server] episode_uuid: {episode_uuid} | recv_string begin.") + + # : + # : ajet/tuner_lib/experimental/as_oai_model_client.py + # : self.socket.send_string(result) + # : ChatCompletion object in JSON string format result_str = socket.recv_string() + break except zmq.Again as e: # check whether server is still in rolling status @@ -112,6 +140,89 @@ def _begin_handle_chat_completion(episode_address, int_req: InterchangeCompletio return result_object + async def mock_as_stream_response(result: ChatCompletion): + """ + Convert a non-streaming ChatCompletion result to streaming format. 
+ + Args: + result: ChatCompletion object to convert to streaming format + + Yields: + Server-sent events formatted as streaming chat completion chunks + """ + content = result.choices[0].message.content if result.choices else "" + role = result.choices[0].message.role if result.choices else "assistant" + # try: + # thinking = result.choices[0].message.reasoning_content + # except: + # thinking = None + tool_calls = result.choices[0].message.tool_calls if result.choices and result.choices[0].message.tool_calls else None + delta_tool_calls = [] # tool_calls: Optional[List[ChoiceDeltaToolCall]] = None + finish_reason = result.choices[0].finish_reason + if tool_calls: + delta_tool_calls = [ChoiceDeltaToolCall( + index=index, + id=tc.id, + function=ChoiceDeltaToolCallFunction( + name = tc.function.name, + arguments = tc.function.arguments, + ), + type=tc.type + ) for index, tc in enumerate(tool_calls)] + + # First chunk with role + first_chunk = ChatCompletionChunk( + id=result.id, + model=result.model, + created=result.created, + object="chat.completion.chunk", + choices=[ + ChunkChoice( + index=0, + delta=ChoiceDelta(role=role, content=""), + finish_reason=None + ) + ] + ) + dat = f"data: {first_chunk.model_dump_json()}\n\n" + yield dat + + # Content chunk + content_chunk = ChatCompletionChunk( + id=result.id, + model=result.model, + created=result.created, + object="chat.completion.chunk", + choices=[ + ChunkChoice( + index=0, + delta=ChoiceDelta(role=role, content=content, tool_calls=delta_tool_calls), + finish_reason=None + ) + ] + ) + dat = f"data: {content_chunk.model_dump_json()}\n\n" + yield dat + + # Final chunk with finish_reason + final_chunk = ChatCompletionChunk( + id=result.id, + model=result.model, + created=result.created, + object="chat.completion.chunk", + choices=[ + ChunkChoice( + index=0, + delta=ChoiceDelta(), + finish_reason=finish_reason + ) + ] + ) + dat = f"data: {final_chunk.model_dump_json()}\n\n" + yield dat + yield "data: [DONE]\n\n" + + 
@app.get("/health") async def health(): return {"status": "ok"} @@ -149,12 +260,13 @@ async def chat_completions(request: Request, authorization: str = Header(None)): # Parse request body body = await request.json() new_req = ChatCompletionRequest.model_validate(body) - if new_req.stream: - return HTTPException(status_code=400, detail="Streaming responses not supported in current AgentJet version, please set `stream=false` for now.") # Create timeline UUID timeline_uuid = uuid.uuid4().hex + # if training, ignore all sampling parameters from request + preserve_sampling_params = False + # enable_swarm_mode if enable_swarm_mode: from ajet.tuner_lib.experimental.as_swarm_server import ep_key @@ -174,6 +286,14 @@ async def chat_completions(request: Request, authorization: str = Header(None)): es.latest_activity_timestamp = time.time() es.llm_call_count += 1 shared_mem_dict[ep_key(episode_uuid)] = es + if es.episode_type == "eval": + preserve_sampling_params = True + + # For streaming, we process as non-streaming but return in streaming format + original_stream = new_req.stream + if original_stream: + new_req.stream = False + new_req.stream_options = None # Add to received queue int_req = InterchangeCompletionRequest( @@ -182,10 +302,16 @@ async def chat_completions(request: Request, authorization: str = Header(None)): target_tag = target_tag, episode_uuid = episode_uuid, timeline_uuid = timeline_uuid, + preserve_sampling_params = preserve_sampling_params, ) if DEBUG: logger.info(f"episode_uuid: {episode_uuid} | Received new chat completion request (outside thread)") loop = asyncio.get_running_loop() - return await loop.run_in_executor(request.app.state.executor, _begin_handle_chat_completion, episode_address, int_req, episode_uuid) + result = await loop.run_in_executor(request.app.state.executor, _begin_handle_chat_completion, episode_address, int_req, episode_uuid) + + if original_stream: + return StreamingResponse(mock_as_stream_response(result), 
media_type="text/event-stream") + + return result if enable_swarm_mode: @@ -314,6 +440,7 @@ def start_interchange_server(config, blocking=False, env={}) -> int: # polling for server ready start_time = time.time() + _httpx_client = httpx.Client(timeout=0.5) while True: if interchange_server and interchange_server.exitcode is not None: logger.error(f"Interchange server subprocess failed to start. Return code: {interchange_server.exitcode}") @@ -323,7 +450,7 @@ def start_interchange_server(config, blocking=False, env={}) -> int: logger.error(msg) raise RuntimeError(msg) try: - if httpx.get(health_url, timeout=0.5).status_code == 200: + if _httpx_client.get(health_url).status_code == 200: break except Exception: # keep waiting @@ -348,7 +475,7 @@ def start_interchange_server(config, blocking=False, env={}) -> int: interchange_server.join() except KeyboardInterrupt: logger.info("Shutting down interchange server...") - try: httpx.post(f"http://127.0.0.1:{port}/stop_engine", timeout=8).status_code + try: _httpx_client.post(f"http://127.0.0.1:{port}/stop_engine", timeout=8).status_code except Exception: pass if interchange_server: diff --git a/ajet/tuner_lib/experimental/as_swarm_client.py b/ajet/tuner_lib/experimental/as_swarm_client.py index bf7c57f3..4bf9d8e5 100644 --- a/ajet/tuner_lib/experimental/as_swarm_client.py +++ b/ajet/tuner_lib/experimental/as_swarm_client.py @@ -52,6 +52,8 @@ def raise_for_status_with_detail(resp): raise RuntimeError(f"SwarmClient error {resp.status_code} with non-JSON response: {response_text}") from e +class SwarmServerOfflineError(Exception): ... 
+ class SwarmClient(object): @@ -69,6 +71,8 @@ def __init__(self, server_url: str): self._agent_jet_job = None # throttle self._recent_seen_tasks = [] + # reuse httpx client to avoid creating SSL context repeatedly + self._http_client = httpx.Client(timeout=GENERAL_TIMEOUT) def logger_info(self, message): # logger with de-duplication within 1 second to prevent log flooding @@ -113,8 +117,8 @@ def _check_throttle_policy(self, throttle_policy: SwarmThrottlePolicy, pool_info if self._agent_jet_job: # check and raise early errors when possible - assert self._agent_jet_job.sample_collection_method == "rollout_until_finish_enough_tasks", \ - f"Current sample collection method ({self._agent_jet_job.sample_collection_method}) does not support throttle policy." + assert self._agent_jet_job.swarm_mode_sample_collection_method == "rollout_until_finish_enough_tasks", \ + f"Current sample collection method ({self._agent_jet_job.swarm_mode_sample_collection_method}) does not support throttle policy." # only_this_client_uuid = throttle_policy.throttle_method in ["Task_Ratio_Limit"] only_this_client_uuid = True @@ -193,7 +197,7 @@ def _should_throttle(self, throttle_policy: SwarmThrottlePolicy, pool_info: Curr self._remember_seen_task(throttle_policy.current_task_id, throttle_policy.expected_batch_size, throttle_policy.expected_num_repeat) return should_throttle - def begin_episode(self, discard_episode_timeout=600, episode_type="train", throttle_policy: SwarmThrottlePolicy|None = None) -> Tuple[str, OpenaiBaseUrlAndApiKey]: + def begin_episode(self, discard_episode_timeout=240, episode_type="train", throttle_policy: SwarmThrottlePolicy|None = None) -> Tuple[str, OpenaiBaseUrlAndApiKey]: """ Block until an episode is claimed. 
Argument: @@ -208,9 +212,9 @@ def begin_episode(self, discard_episode_timeout=600, episode_type="train", throt """ return self._begin_episode_auto_retry(discard_episode_timeout, episode_type, throttle_policy) - def _begin_episode_auto_retry(self, discard_episode_timeout=600, episode_type="train", throttle_policy: SwarmThrottlePolicy|None = None) -> Tuple[str, OpenaiBaseUrlAndApiKey]: + def _begin_episode_auto_retry(self, discard_episode_timeout=240, episode_type="train", throttle_policy: SwarmThrottlePolicy|None = None) -> Tuple[str, OpenaiBaseUrlAndApiKey]: # max_episode_time: when an episode has **lasted** for more than X seconds, it will be terminated **locally** by client (call `end_episode` will be re-route to `abort_episode`) - max_episode_time = 2*discard_episode_timeout + max_episode_time = 8*discard_episode_timeout status, status_json = self.get_engine_status() # warm up connection and log the status if status not in ["ENGINE.ROLLING"]: @@ -250,10 +254,9 @@ def _begin_episode_auto_retry(self, discard_episode_timeout=600, episode_type="t discard_episode_timeout=discard_episode_timeout, throttle_policy=throttle_policy ) - resp = httpx.post( + resp = self._http_client.post( f"{self.server_url}/claim_episode", - json=req_obj.model_dump(), - timeout=GENERAL_TIMEOUT + json=req_obj.model_dump() ) raise_for_status_with_detail(resp) data = ClaimEpisodeResponse.model_validate(resp.json()) @@ -316,13 +319,13 @@ def end_episode(self, task:Task, episode_uuid: str, workflow_output: WorkflowOut if episode_uuid in self.record_episode_expire_time: remain_time = self.record_episode_expire_time.pop(episode_uuid, 0) - time.time() if remain_time < 0: - logger.warning(f"Episode {episode_uuid} has expired (expired {-remain_time} seconds ago). Please use a larger `discard_episode_timeout` and `max_episode_time` when `begin_episode`. Skipping end_episode.") + logger.warning(f"Episode {episode_uuid} has expired (expired {-remain_time} seconds ago). 
Please use a larger `discard_episode_timeout` when `begin_episode`. Skipping end_episode.") # send abort signal to server to clean up episode self.abort_episode(episode_uuid) return else: # send abort signal to server to clean up episode - logger.warning(f"Episode {episode_uuid} has expired (expired at least {CLEAN_RECORD_TIMEOUT} seconds ago). Please use a larger `discard_episode_timeout` and `max_episode_time` when `begin_episode`. Skipping end_episode.") + logger.warning(f"Episode {episode_uuid} has expired (expired at least {CLEAN_RECORD_TIMEOUT} seconds ago). Please use a larger `discard_episode_timeout` when `begin_episode`. Skipping end_episode.") self.abort_episode(episode_uuid) return @@ -335,10 +338,9 @@ def end_episode(self, task:Task, episode_uuid: str, workflow_output: WorkflowOut task_id=task_id ) - resp = httpx.post( + resp = self._http_client.post( f"{self.server_url}/end_episode", - json=req_obj.model_dump(), - timeout=GENERAL_TIMEOUT + json=req_obj.model_dump() ) raise_for_status_with_detail(resp) data = EndEpisodeResponse.model_validate(resp.json()) @@ -364,10 +366,9 @@ def abort_episode(self, episode_uuid: str): task_id="" ) - resp = httpx.post( + resp = self._http_client.post( f"{self.server_url}/abort_episode", - json=req_obj.model_dump(), - timeout=GENERAL_TIMEOUT + json=req_obj.model_dump() ) raise_for_status_with_detail(resp) data = EndEpisodeResponse.model_validate(resp.json()) @@ -397,10 +398,9 @@ def sync_train_config(self, agent_jet_job: AgentJetJob): req_obj = SyncTrainConfigRequest(yaml_as_string=yaml_str) - resp = httpx.post( + resp = self._http_client.post( f"{self.server_url}/sync_train_config", - json=req_obj.model_dump(), - timeout=GENERAL_TIMEOUT + json=req_obj.model_dump() ) raise_for_status_with_detail(resp) self.logger_info("Synced train config to Swarm server") @@ -420,7 +420,7 @@ def start_engine(self): raise RuntimeError(f"Cannot start engine when engine is NOT ENGINE.OFFLINE. 
(current status: {current_status})") # Send start engine request - resp = httpx.post( + resp = self._http_client.post( f"{self.server_url}/start_engine", json={}, timeout=600 @@ -437,7 +437,7 @@ def start_engine(self): self._wait_until_status_change_to(desired_status="ENGINE.ROLLING") logger.success("Training engine is now ROLLING and ready.") - def _wait_until_status_change_to(self, desired_status="ENGINE.ROLLING", verbose=True): + def _wait_until_status_change_to(self, desired_status="ENGINE.ROLLING", verbose=True, timeout=1800): """ Poll engine status until it reaches desired_status. Reports status every 5 seconds while waiting. @@ -446,12 +446,20 @@ def _wait_until_status_change_to(self, desired_status="ENGINE.ROLLING", verbose= self.logger_info(f"Polling engine status until {desired_status}...") last_report_time = time.time() init_poll_time = last_report_time + initial_status, _ = self.get_engine_status() while True: try: current_status, _ = self.get_engine_status() current_time = time.time() + # Check if timeout has been reached + if current_time - init_poll_time >= timeout: + raise TimeoutError(f"Timeout reached while waiting for engine status to change to {desired_status}") + + if (initial_status == "ENGINE.OFFLINE") and (current_status == "ENGINE.OFFLINE") and (desired_status!="ENGINE.OFFLINE"): + raise SwarmServerOfflineError(f"Engine status changed from {initial_status} to OFFLINE while waiting for {desired_status}. This may indicate an error in the engine. 
Please check the swarm server logs for details.") + # Report status every 5 seconds if current_time - last_report_time >= 30: if verbose: @@ -467,6 +475,9 @@ def _wait_until_status_change_to(self, desired_status="ENGINE.ROLLING", verbose= # Wait a bit before next poll time.sleep(5) + except SwarmServerOfflineError as e: + raise e + except Exception as e: logger.error(f"Error polling engine status: {e}") time.sleep(5) @@ -474,7 +485,7 @@ def _wait_until_status_change_to(self, desired_status="ENGINE.ROLLING", verbose= @cache_with_ttl(ttl=0.5) def get_engine_status(self) -> Tuple[str, dict]: try: - resp = httpx.get( + resp = self._http_client.get( f"{self.server_url}/get_engine_status", timeout=10 ) @@ -499,7 +510,7 @@ def can_continue_episode(self, episode_uuid: str) -> bool: client_uuid=self.client_uuid, episode_uuid=episode_uuid ) - resp = httpx.post( + resp = self._http_client.post( f"{self.server_url}/can_continue_episode", json=req_obj.model_dump(), timeout=10 @@ -513,7 +524,7 @@ def can_continue_episode(self, episode_uuid: str) -> bool: def get_episode_buffer(self) -> List[EpisodeStatus]: try: - resp = httpx.post( + resp = self._http_client.post( f"{self.server_url}/get_episode_buffer", json={}, timeout=10 @@ -572,7 +583,7 @@ def stop_engine(self): self.logger_info("Engine is already OFFLINE. No action needed.") return - resp = httpx.post( + resp = self._http_client.post( f"{self.server_url}/stop_engine", json={}, timeout=600 @@ -592,7 +603,7 @@ def get_rollout_stat(self) -> CurrentBatchRolloutPoolInformation: Returns statistics about completed episodes, tasks, and progress. 
""" try: - resp = httpx.get( + resp = self._http_client.get( f"{self.server_url}/get_current_batch_rollout_pool_information", timeout=10 ) diff --git a/ajet/tuner_lib/experimental/as_swarm_server.py b/ajet/tuner_lib/experimental/as_swarm_server.py index a5d156d9..aa82984a 100644 --- a/ajet/tuner_lib/experimental/as_swarm_server.py +++ b/ajet/tuner_lib/experimental/as_swarm_server.py @@ -335,11 +335,11 @@ async def start_engine(): config_dict = yaml_module.safe_load(yaml_str) backbone = config_dict.get("ajet", {}).get("backbone", "verl") DEFAULT_DIR = "saved_experiments" - exp_dir_final = config_dict.get("ajet", {}).get("experiment_dir", DEFAULT_DIR) - if exp_dir_final != DEFAULT_DIR: - # remove last dir level if possible - exp_dir_final = os.path.dirname(exp_dir_final) - + experiment_dir = config_dict.get("ajet", {}).get("experiment_dir", DEFAULT_DIR) + if experiment_dir == "auto": + exp_base_dir = DEFAULT_DIR + else: + exp_base_dir = os.path.dirname(os.path.abspath(experiment_dir)) # Save YAML to temporary file with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".yaml") as temp_file: @@ -351,7 +351,6 @@ async def start_engine(): args = SimpleNamespace( conf=main_yaml_fp, backbone=backbone, - exp_dir=exp_dir_final, with_logview=False, debug=False, ) @@ -367,7 +366,12 @@ def override_param_callback(config): return config # Finalize experiment config - main_yaml_fp, exe_exp_base, exp_name, exp_config = prepare_experiment_config(main_yaml_fp, exp_dir_final, backbone, override_param_callback) + main_yaml_fp, exe_exp_base, exp_name, exp_config = prepare_experiment_config( + yaml_path=main_yaml_fp, + exp_base_dir=exp_base_dir, + backbone=backbone, + override_param_callback=override_param_callback, + ) # Setup environment variables env, exp_config = setup_environment_vars(args, exp_config, main_yaml_fp) @@ -393,6 +397,7 @@ def override_param_callback(config): main_yaml_fp, env, exp_config, + True, # is_swarm_server ), ) p.daemon = True @@ -707,7 +712,7 @@ 
async def get_episode_buffer(): @app.post("/update_current_batch_rollout_pool_information", response_model=BoolResponse) async def update_current_batch_rollout_pool_information(req: CurrentBatchRolloutPoolInformation): """Update the current batch rollout pool information.""" - if VERBOSE: + if DEBUG: logger.info(f"Running /update_current_batch_rollout_pool_information") try: with shared_mem_dict_lock: diff --git a/ajet/tuner_lib/experimental/interchange_utils.py b/ajet/tuner_lib/experimental/interchange_utils.py index c50dd93d..880b87b0 100644 --- a/ajet/tuner_lib/experimental/interchange_utils.py +++ b/ajet/tuner_lib/experimental/interchange_utils.py @@ -109,10 +109,16 @@ class UpdateEngineStatusRequest(BaseModel): VERBOSE = True +shared_http_client = httpx.Client(timeout=10.0) + def get_interchange_server_url(config): port = os.getenv("AJET_DAT_INTERCHANGE_PORT") - if config.ajet.interchange_server.interchange_server_port != 'auto': - port = str(int(config.ajet.interchange_server.interchange_server_port)) + if isinstance(config, dict): + interchange_server_port = config.get("ajet", {}).get("interchange_server", {}).get("interchange_server_port", "auto") + else: + interchange_server_port = config.ajet.interchange_server.interchange_server_port + if interchange_server_port != 'auto': + port = str(int(interchange_server_port)) assert port is not None, "AJET_DAT_INTERCHANGE_PORT env var must be set" master_node_ip = os.getenv("MASTER_NODE_IP", "localhost") base_url = f"http://{master_node_ip}:{port}" @@ -123,7 +129,7 @@ def http_change_engine_status(config, new_status: str, new_status_detail: str|No if new_status not in VALID_STATUSES: raise ValueError(f"Invalid engine status: {new_status}") - resp = httpx.post( + resp = shared_http_client.post( f"{get_interchange_server_url(config)}/update_engine_status", json={"engine_status": new_status, "engine_status_detail": new_status_detail, "global_step": global_step}, timeout=10 @@ -133,7 +139,7 @@ def 
http_change_engine_status(config, new_status: str, new_status_detail: str|No def is_episode_claimed(config, episode_uuid: str, unregister_if_not_claimed: bool) -> bool: - resp = httpx.post( + resp = shared_http_client.post( f"{get_interchange_server_url(config)}/is_episode_claimed", json={"episode_uuid": episode_uuid, "unregister_if_not_claimed": unregister_if_not_claimed}, timeout=5 @@ -164,7 +170,7 @@ def http_register_episode(config, zmq_listen_result_addr=zmq_listen_result_addr, ) # send http request to swarm server to register episode - response = httpx.post( + response = shared_http_client.post( f"{interchange_http_addr}/register_episode", json=rer.model_dump(), # 或者 rer.model_dump() 如果使用 Pydantic v2 timeout=2 diff --git a/ajet/utils/config_utils.py b/ajet/utils/config_utils.py index b02f1205..13c4d693 100644 --- a/ajet/utils/config_utils.py +++ b/ajet/utils/config_utils.py @@ -171,7 +171,7 @@ def config_safe_guard(config: dict, backbone: str) -> dict: def read_ajet_hierarchical_config( - yaml_fp, exp_name, backbone, write_to=None, exp_dir=DEFAULT_DIR, override_param_callback=None + yaml_fp, experiment_name=None, backbone=None, write_to=None, experiment_dir=None, override_param_callback=None ): if yaml_fp is None: config = { @@ -193,9 +193,12 @@ def read_ajet_hierarchical_config( else: with open(yaml_fp, "r", encoding="utf-8") as file: config = yaml.safe_load(file) - config["ajet"]["experiment_name"] = exp_name - config["ajet"]["experiment_dir"] = os.path.join(exp_dir, exp_name) - config["ajet"]["backbone"] = backbone + if experiment_name is not None: + config["ajet"]["experiment_name"] = experiment_name + if (experiment_dir is not None): + config["ajet"]["experiment_dir"] = experiment_dir + if backbone is not None: + config["ajet"]["backbone"] = backbone # remove extra config of verl for trinity if backbone == "debug": @@ -245,14 +248,14 @@ def expand_ajet_hierarchical_config(config, write_to=None): return config_final -def 
prepare_experiment_config(yaml_path, exp_dir, backbone, override_param_callback=None, storage=True): +def prepare_experiment_config(yaml_path, exp_base_dir, backbone, override_param_callback=None, storage=True): """ Prepare experiment configuration by reading YAML, setting up backup directories, and copying necessary files for the experiment. Args: yaml_path: Path to the YAML configuration file - exp_dir: Directory where experiment artifacts and backups should be stored + exp_base_dir: Directory where experiment artifacts and backups should be stored backbone: Backbone identifier that controls config munging Returns: @@ -281,8 +284,8 @@ def prepare_experiment_config(yaml_path, exp_dir, backbone, override_param_callb else: exp_name = exp_name.replace("|", "-") - backup_dir = os.path.abspath(os.path.join(exp_dir, exp_name, "backup")) - yaml_backup_dst = os.path.join(exp_dir, exp_name, "yaml_backup.yaml") + backup_dir = os.path.abspath(os.path.join(exp_base_dir, exp_name, "backup")) + yaml_backup_dst = os.path.join(exp_base_dir, exp_name, "yaml_backup.yaml") yaml_backup_dst = os.path.abspath(yaml_backup_dst) exe_exp_base = os.path.dirname(yaml_backup_dst) @@ -323,12 +326,18 @@ def prepare_experiment_config(yaml_path, exp_dir, backbone, override_param_callb shutil.copyfile(yaml_backup_src, yaml_backup_dst) ## 4. 
edit new yaml + experiment_dir = f"{exp_base_dir}/{exp_name}" config = read_ajet_hierarchical_config( - yaml_backup_dst, exp_name, backbone, write_to=yaml_backup_dst, exp_dir=exp_dir, override_param_callback=override_param_callback + yaml_backup_dst, + experiment_name=exp_name, + backbone=backbone, + write_to=yaml_backup_dst, + experiment_dir=experiment_dir, + override_param_callback=override_param_callback ) config_final = expand_ajet_hierarchical_config(config, write_to=yaml_backup_dst) if not storage: - shutil.rmtree(os.path.join(exp_dir, exp_name)) + shutil.rmtree(os.path.join(exp_base_dir, exp_name)) return yaml_backup_dst, exe_exp_base, exp_name, config_final diff --git a/ajet/utils/core_env_vars.py b/ajet/utils/core_env_vars.py index e48e1dda..9df18216 100644 --- a/ajet/utils/core_env_vars.py +++ b/ajet/utils/core_env_vars.py @@ -15,7 +15,7 @@ def get_runtime_env(config, is_trinity: bool = False) -> dict: if config.ajet.trainer_common.nnodes == 1: master_node_ip = "localhost" else: - if config.ajet.enable_experimental_interchange_server: + if config.ajet.enable_interchange_server: if config.ajet.interchange_server.interchange_method == "ipc": raise ValueError("IPC interchange method is not supported for multi-node setup. Please set `ajet.interchange_server.interchange_method: tcp` ") diff --git a/ajet/utils/launch_utils.py b/ajet/utils/launch_utils.py index 441fbc93..46200ad9 100644 --- a/ajet/utils/launch_utils.py +++ b/ajet/utils/launch_utils.py @@ -319,6 +319,7 @@ def execute_training_process( exe_yaml_path, env, exp_config, + is_swarm_server=False, ): """ Execute the training process based on the specified backbone and configuration. 
@@ -403,7 +404,13 @@ def execute_training_process( subprocess.run(cmd, check=True, cwd=os.path.abspath("./"), env=env) except subprocess.CalledProcessError as e: logger.error(f"Error running subprocess: {e}") + if is_swarm_server: + from ajet.tuner_lib.experimental.interchange_utils import http_change_engine_status + http_change_engine_status(exp_config, "ENGINE.OFFLINE", global_step=0) sys.exit(1) except Exception as e: logger.error(f"Unexpected error: {e}") + if is_swarm_server: + from ajet.tuner_lib.experimental.interchange_utils import http_change_engine_status + http_change_engine_status(exp_config, "ENGINE.OFFLINE", global_step=0) sys.exit(1) diff --git a/ajet/utils/swarm_overwatch.py b/ajet/utils/swarm_overwatch.py index da234a8f..9f8b7717 100644 --- a/ajet/utils/swarm_overwatch.py +++ b/ajet/utils/swarm_overwatch.py @@ -23,13 +23,13 @@ class SwarmOverwatch: """Real-time monitoring interface for swarm rollout pool""" - def __init__(self, server_url: str, refresh_interval: float = 1.0): + def __init__(self, server_url: str, refresh_interval: float = 2.0): """ Initialize the overwatch monitor Args: server_url: Base URL of the swarm server (e.g., http://localhost:10086) - refresh_interval: Refresh interval in seconds (default: 1.0) + refresh_interval: Refresh interval in seconds (default: 2.0) """ self.server_url = server_url.rstrip("/") self.refresh_interval = refresh_interval @@ -37,11 +37,12 @@ def __init__(self, server_url: str, refresh_interval: float = 1.0): self.last_update_time = None self.error_count = 0 self.total_requests = 0 + self._httpx_client = httpx.Client(timeout=5.0) def fetch_pool_info(self) -> Optional[CurrentBatchRolloutPoolInformation]: """Fetch current batch rollout pool information from server""" try: - response = httpx.get( + response = self._httpx_client.get( f"{self.server_url}/get_current_batch_rollout_pool_information", timeout=5.0, ) @@ -480,13 +481,13 @@ def run(self): ) -def start_overwatch(server_url: str, refresh_interval: 
float = 1.0): +def start_overwatch(server_url: str, refresh_interval: float = 2.0): """ Start the swarm overwatch monitoring interface Args: server_url: Base URL of the swarm server - refresh_interval: Refresh interval in seconds (default: 1.0) + refresh_interval: Refresh interval in seconds (default: 2.0) """ overwatch = SwarmOverwatch(server_url, refresh_interval) overwatch.run() diff --git a/ajet/utils/thread_executors.py b/ajet/utils/thread_executors.py index 8702e00c..87ea0744 100644 --- a/ajet/utils/thread_executors.py +++ b/ajet/utils/thread_executors.py @@ -45,12 +45,17 @@ def shutdown(self, wait=True): class PeriodicDrainThreadPoolExecutor: """A ThreadPoolExecutor that bounds the number of pending tasks via a semaphore.""" - def __init__(self, workers=100, auto_retry=True): + def __init__(self, workers=100, max_parallel=None, auto_retry=True, block_first_run=False): self._max_workers = workers - self._executor = ThreadPoolExecutor(max_workers=workers) + if max_parallel is None: + self._max_parallel = workers + else: + self._max_parallel = max_parallel + self._executor = ThreadPoolExecutor(max_workers=self._max_parallel) self._submitted_count = 0 self._auto_retry = auto_retry self.current_futures = [] + self._slow_first_run = block_first_run def submit(self, fn, *args, **kwargs): """Submit a task, blocking if the pending queue is full.""" @@ -63,9 +68,15 @@ def retry_wrapper(fn, *args, **kwargs): logger.exception(f"[run_episodes_until_all_complete] Error executing episode: {e}. 
Retrying...") if self._auto_retry: - return self._executor.submit(retry_wrapper, fn, *args, **kwargs) + future = self._executor.submit(retry_wrapper, fn, *args, **kwargs) else: - return self._executor.submit(fn, *args, **kwargs) + future = self._executor.submit(fn, *args, **kwargs) + + if self._slow_first_run: + self._slow_first_run = False + future.result() # Wait for the first run to complete before allowing more tasks to be submitted + + return future def submit_with_periodic_drain(self, fn, *args, **kwargs): """Submit a task, draining all in-flight work every `drain_every_n_job` submissions.""" @@ -86,4 +97,4 @@ def submit_with_periodic_drain(self, fn, *args, **kwargs): def shutdown(self, wait=True): """Shut down the underlying executor.""" - self._executor.shutdown(wait=wait) \ No newline at end of file + self._executor.shutdown(wait=wait) diff --git a/ajet/utils/tokenizer.py b/ajet/utils/tokenizer.py index 94ab8007..64587381 100644 --- a/ajet/utils/tokenizer.py +++ b/ajet/utils/tokenizer.py @@ -1,5 +1,6 @@ import copy import json +import threading from typing import Dict, List @@ -19,6 +20,10 @@ def cleanup_messages(messages: List[Dict]) -> List[Dict]: pass return messages_copied +# Cache storage +_cache = {} +_cache_lock = threading.Lock() + def ajet_apply_chat_template( tokenizer, @@ -28,16 +33,46 @@ def ajet_apply_chat_template( tokenize: bool = True, ): conversation = cleanup_messages(conversation) + + # Create cache key by hashing all inputs + cache_key = ( + id(tokenizer), + hash(json.dumps(conversation, sort_keys=True)), + hash(json.dumps(tools, sort_keys=True)) if tools else 0, + add_generation_prompt, + tokenize, + ) + + # Check cache with thread safety + with _cache_lock: + if cache_key in _cache: + return _cache[cache_key] + + # Compute result (time consuming) - outside lock to avoid blocking other threads if tools: - return tokenizer.apply_chat_template( + result = tokenizer.apply_chat_template( conversation, tools, 
add_generation_prompt=add_generation_prompt, tokenize=tokenize, ) else: - return tokenizer.apply_chat_template( + result = tokenizer.apply_chat_template( conversation, tokenize=tokenize, add_generation_prompt=add_generation_prompt, ) + + # Store in cache with thread safety (implement LRU eviction if cache gets too large) + with _cache_lock: + if len(_cache) >= 1024: + # Remove oldest item (first inserted) + try: + _cache.pop(next(iter(_cache))) + except KeyError: + # Cache was modified by another thread, which is fine + pass + + _cache[cache_key] = result + + return result diff --git a/docs/en/ajet-swarm-docker.md b/docs/en/ajet-swarm-docker.md index 38a3c653..15c054a7 100644 --- a/docs/en/ajet-swarm-docker.md +++ b/docs/en/ajet-swarm-docker.md @@ -24,6 +24,7 @@ docker run --rm -it \ -v ./swarmlog:/workspace/log \ -v ./swarmexp:/workspace/saved_experiments \ -p 10086:10086 \ + -e SWANLAB_API_KEY=$SWANLAB_API_KEY \ --gpus=all \ --shm-size=32GB \ ghcr.io/modelscope/agentjet:main \ @@ -89,6 +90,7 @@ docker run --rm -it \ -v ./swarmlog:/workspace/log \ -v ./swarmexp:/workspace/saved_experiments \ -p 10086:10086 \ + -e SWANLAB_API_KEY=$SWANLAB_API_KEY \ --gpus=all \ --shm-size=32GB \ ghcr.io/modelscope/agentjet:main \ diff --git a/docs/en/support_agentscope.md b/docs/en/support_agentscope.md index e551e4d9..13d308a5 100644 --- a/docs/en/support_agentscope.md +++ b/docs/en/support_agentscope.md @@ -64,7 +64,7 @@ This article introduce the way to convert different types of ways to convert you ajet: ... - enable_experimental_interchange_server: True + enable_interchange_server: True ... ``` diff --git a/docs/en/support_http.md b/docs/en/support_http.md index 0bf3ab3d..d7659b1b 100644 --- a/docs/en/support_http.md +++ b/docs/en/support_http.md @@ -89,7 +89,7 @@ in this AI era, you can always start from scratch and build your own "high-scrap ajet: ... - enable_experimental_interchange_server: True + enable_interchange_server: True ... 
``` diff --git a/docs/en/support_langchain.md b/docs/en/support_langchain.md index d1e12890..344163e2 100644 --- a/docs/en/support_langchain.md +++ b/docs/en/support_langchain.md @@ -80,7 +80,7 @@ This article introduce the way to convert different types of ways to convert you ajet: ... - enable_experimental_interchange_server: True + enable_interchange_server: True ... ``` diff --git a/docs/en/support_oaisdk.md b/docs/en/support_oaisdk.md index b60b03e3..104a1c26 100644 --- a/docs/en/support_oaisdk.md +++ b/docs/en/support_oaisdk.md @@ -84,7 +84,7 @@ This article introduce the way to convert different types of ways to convert you ajet: ... - enable_experimental_interchange_server: True + enable_interchange_server: True ... ``` diff --git a/tests/bench/benchmark_appworld/benchmark_appworld.yaml b/tests/bench/benchmark_appworld/benchmark_appworld.yaml index f83e91f0..3622ed1b 100644 --- a/tests/bench/benchmark_appworld/benchmark_appworld.yaml +++ b/tests/bench/benchmark_appworld/benchmark_appworld.yaml @@ -58,14 +58,14 @@ ajet: execute_testing_lambda: "tests/bench/benchmark_appworld/benchmark_appworld.py->TestProbe" # -# ------------------ 不需要修改 ------------------ +# ------------------ do not edit ------------------ hydra: searchpath: - file://ajet/default_config - file://ajet/default_config/verl # verl only - file://ajet/default_config/trinity # trinity only -# ------------------ 不需要修改 ------------------ +# ------------------ do not edit ------------------ defaults: - verl_default # verl inherit 1/1 - trinity_default # trinity inherit 1/1 diff --git a/tests/bench/benchmark_appworld/benchmark_appworld_2nodes.yaml b/tests/bench/benchmark_appworld/benchmark_appworld_2nodes.yaml index 4ae12f17..f53ca63b 100644 --- a/tests/bench/benchmark_appworld/benchmark_appworld_2nodes.yaml +++ b/tests/bench/benchmark_appworld/benchmark_appworld_2nodes.yaml @@ -63,14 +63,14 @@ trinity: sync_offset: 0 sync_method: nccl -# ------------------ 不需要修改 ------------------ +# 
------------------ do not edit ------------------ hydra: searchpath: - file://ajet/default_config - file://ajet/default_config/verl # verl only - file://ajet/default_config/trinity # trinity only -# ------------------ 不需要修改 ------------------ +# ------------------ do not edit ------------------ defaults: - verl_default # verl inherit 1/1 - trinity_default # trinity inherit 1/1 diff --git a/tests/bench/benchmark_appworld/benchmark_appworld_oai_sdk.yaml b/tests/bench/benchmark_appworld/benchmark_appworld_oai_sdk.yaml index e3175d19..89d82afb 100644 --- a/tests/bench/benchmark_appworld/benchmark_appworld_oai_sdk.yaml +++ b/tests/bench/benchmark_appworld/benchmark_appworld_oai_sdk.yaml @@ -56,14 +56,14 @@ ajet: execute_testing_lambda: "tests/bench/benchmark_appworld/benchmark_appworld.py->TestProbe" # -# ------------------ 不需要修改 ------------------ +# ------------------ do not edit ------------------ hydra: searchpath: - file://ajet/default_config - file://ajet/default_config/verl # verl only - file://ajet/default_config/trinity # trinity only -# ------------------ 不需要修改 ------------------ +# ------------------ do not edit ------------------ defaults: - verl_default # verl inherit 1/1 - trinity_default # trinity inherit 1/1 diff --git a/tests/bench/benchmark_countdown/benchmark_countdown.yaml b/tests/bench/benchmark_countdown/benchmark_countdown.yaml index fcd07f35..53cdd902 100644 --- a/tests/bench/benchmark_countdown/benchmark_countdown.yaml +++ b/tests/bench/benchmark_countdown/benchmark_countdown.yaml @@ -124,14 +124,14 @@ ajet: execute_testing_lambda: "tests/bench/benchmark_countdown/benchmark_countdown.py->TestProbe" # FOR ROBOT TESTING PURPOSE ONLY. 
NOT FOR HUMAN -# ------------------ 不需要修改 ------------------ +# ------------------ do not edit ------------------ hydra: searchpath: - file://ajet/default_config - file://ajet/default_config/verl # verl only - file://ajet/default_config/trinity # trinity only -# ------------------ 不需要修改 ------------------ +# ------------------ do not edit ------------------ defaults: - verl_default # verl inherit 1/1 - trinity_default # trinity inherit 1/1 diff --git a/tests/bench/benchmark_learn2ask/benchmark_learn2ask.yaml b/tests/bench/benchmark_learn2ask/benchmark_learn2ask.yaml index dd3b6a18..b435a0ac 100644 --- a/tests/bench/benchmark_learn2ask/benchmark_learn2ask.yaml +++ b/tests/bench/benchmark_learn2ask/benchmark_learn2ask.yaml @@ -57,14 +57,14 @@ trinity: sync_method: nccl -# ------------------ 不需要修改 ------------------ +# ------------------ do not edit ------------------ hydra: searchpath: - file://ajet/default_config - file://ajet/default_config/verl # verl only - file://ajet/default_config/trinity # trinity only -# ------------------ 不需要修改 ------------------ +# ------------------ do not edit ------------------ defaults: - verl_default # verl inherit 1/1 - trinity_default # trinity inherit 1/1 diff --git a/tests/bench/benchmark_math/benchmark_math.yaml b/tests/bench/benchmark_math/benchmark_math.yaml index 36648d5f..f0f8d896 100644 --- a/tests/bench/benchmark_math/benchmark_math.yaml +++ b/tests/bench/benchmark_math/benchmark_math.yaml @@ -62,14 +62,14 @@ trinity: sync_method: nccl -# ------------------ 不需要修改 ------------------ +# ------------------ do not edit ------------------ hydra: searchpath: - file://ajet/default_config - file://ajet/default_config/verl # verl only - file://ajet/default_config/trinity # trinity only -# ------------------ 不需要修改 ------------------ +# ------------------ do not edit ------------------ defaults: - verl_default # verl inherit - trinity_default # trinity inherit diff --git a/tests/bench/benchmark_math/benchmark_math_oai_sdk.yaml 
b/tests/bench/benchmark_math/benchmark_math_oai_sdk.yaml index a3dadd1a..e7bf0aba 100644 --- a/tests/bench/benchmark_math/benchmark_math_oai_sdk.yaml +++ b/tests/bench/benchmark_math/benchmark_math_oai_sdk.yaml @@ -59,14 +59,14 @@ trinity: sync_method: nccl -# ------------------ 不需要修改 ------------------ +# ------------------ do not edit ------------------ hydra: searchpath: - file://ajet/default_config - file://ajet/default_config/verl # verl only - file://ajet/default_config/trinity # trinity only -# ------------------ 不需要修改 ------------------ +# ------------------ do not edit ------------------ defaults: - verl_default # verl inherit - trinity_default # trinity inherit diff --git a/tests/bench/benchmark_math/benchmark_math_raw_http.yaml b/tests/bench/benchmark_math/benchmark_math_raw_http.yaml index 8a4fc433..88c9aa15 100644 --- a/tests/bench/benchmark_math/benchmark_math_raw_http.yaml +++ b/tests/bench/benchmark_math/benchmark_math_raw_http.yaml @@ -59,14 +59,14 @@ trinity: sync_method: nccl -# ------------------ 不需要修改 ------------------ +# ------------------ do not edit ------------------ hydra: searchpath: - file://ajet/default_config - file://ajet/default_config/verl # verl only - file://ajet/default_config/trinity # trinity only -# ------------------ 不需要修改 ------------------ +# ------------------ do not edit ------------------ defaults: - verl_default # verl inherit - trinity_default # trinity inherit diff --git a/tutorial/example_appworld/appworld.yaml b/tutorial/example_appworld/appworld.yaml index 316c605b..3ccb91b7 100644 --- a/tutorial/example_appworld/appworld.yaml +++ b/tutorial/example_appworld/appworld.yaml @@ -54,14 +54,14 @@ ajet: n_gpus_per_node: 8 -# ------------------ 不需要修改 ------------------ +# ------------------ do not edit ------------------ hydra: searchpath: - file://ajet/default_config - file://ajet/default_config/verl # verl only - file://ajet/default_config/trinity # trinity only -# ------------------ 不需要修改 ------------------ +# 
------------------ do not edit ------------------ defaults: - verl_default # verl inherit 1/1 - trinity_default # trinity inherit 1/1 diff --git a/tutorial/example_appworld/appworld_oai_sdk.yaml b/tutorial/example_appworld/appworld_oai_sdk.yaml index 056aac91..4b159b7f 100644 --- a/tutorial/example_appworld/appworld_oai_sdk.yaml +++ b/tutorial/example_appworld/appworld_oai_sdk.yaml @@ -53,14 +53,14 @@ ajet: n_gpus_per_node: 8 -# ------------------ 不需要修改 ------------------ +# ------------------ do not edit ------------------ hydra: searchpath: - file://ajet/default_config - file://ajet/default_config/verl # verl only - file://ajet/default_config/trinity # trinity only -# ------------------ 不需要修改 ------------------ +# ------------------ do not edit ------------------ defaults: - verl_default # verl inherit 1/1 - trinity_default # trinity inherit 1/1 diff --git a/tutorial/example_countdown/countdown.yaml b/tutorial/example_countdown/countdown.yaml index d5b161bf..6dcadf81 100644 --- a/tutorial/example_countdown/countdown.yaml +++ b/tutorial/example_countdown/countdown.yaml @@ -135,14 +135,14 @@ ajet: execute_testing_lambda: "" # DO NOT EDIT, FOR ROBOT TESTING PURPOSE ONLY. NOT FOR HUMAN. 
-# ------------------ 不需要修改 ------------------ +# ------------------ do not edit ------------------ hydra: searchpath: - file://ajet/default_config - file://ajet/default_config/verl # verl only - file://ajet/default_config/trinity # trinity only -# ------------------ 不需要修改 ------------------ +# ------------------ do not edit ------------------ defaults: - verl_default # verl inherit 1/1 - trinity_default # trinity inherit 1/1 diff --git a/tutorial/example_deep_finance/deep_finance.yaml b/tutorial/example_deep_finance/deep_finance.yaml index e5de33da..fcb429c4 100644 --- a/tutorial/example_deep_finance/deep_finance.yaml +++ b/tutorial/example_deep_finance/deep_finance.yaml @@ -71,14 +71,14 @@ actor_rollout_ref: rollout: tensor_model_parallel_size: 8 gpu_memory_utilization: 0.8 -# ------------------ 不需要修改 ------------------ +# ------------------ do not edit ------------------ hydra: searchpath: - file://ajet/default_config - file://ajet/default_config/verl # verl only - file://ajet/default_config/trinity # trinity only -# ------------------ 不需要修改 ------------------ +# ------------------ do not edit ------------------ defaults: - verl_default # verl inherit 1/1 - trinity_default # trinity inherit 1/1 diff --git a/tutorial/example_deep_finance/yaml_template/deep_finance_template.yaml b/tutorial/example_deep_finance/yaml_template/deep_finance_template.yaml index 38aa82ed..d9f559b9 100644 --- a/tutorial/example_deep_finance/yaml_template/deep_finance_template.yaml +++ b/tutorial/example_deep_finance/yaml_template/deep_finance_template.yaml @@ -75,14 +75,14 @@ actor_rollout_ref: rollout: tensor_model_parallel_size: 8 gpu_memory_utilization: 0.8 -# ------------------ 不需要修改 ------------------ +# ------------------ do not edit ------------------ hydra: searchpath: - file://ajet/default_config - file://ajet/default_config/verl # verl only - file://ajet/default_config/trinity # trinity only -# ------------------ 不需要修改 ------------------ +# ------------------ do not edit 
------------------ defaults: - verl_default # verl inherit 1/1 - trinity_default # trinity inherit 1/1 diff --git a/tutorial/example_deep_finance/yaml_template/deep_finance_template_maxlen.yaml b/tutorial/example_deep_finance/yaml_template/deep_finance_template_maxlen.yaml index 0ddd541c..02fa6f73 100644 --- a/tutorial/example_deep_finance/yaml_template/deep_finance_template_maxlen.yaml +++ b/tutorial/example_deep_finance/yaml_template/deep_finance_template_maxlen.yaml @@ -76,14 +76,14 @@ actor_rollout_ref: rollout: tensor_model_parallel_size: 8 gpu_memory_utilization: 0.8 -# ------------------ 不需要修改 ------------------ +# ------------------ do not edit ------------------ hydra: searchpath: - file://ajet/default_config - file://ajet/default_config/verl # verl only - file://ajet/default_config/trinity # trinity only -# ------------------ 不需要修改 ------------------ +# ------------------ do not edit ------------------ defaults: - verl_default # verl inherit 1/1 - trinity_default # trinity inherit 1/1 diff --git a/tutorial/example_deep_finance/yaml_template/infer.yaml b/tutorial/example_deep_finance/yaml_template/infer.yaml index 5e9d400e..7dcf60ff 100644 --- a/tutorial/example_deep_finance/yaml_template/infer.yaml +++ b/tutorial/example_deep_finance/yaml_template/infer.yaml @@ -76,14 +76,14 @@ actor_rollout_ref: rollout: tensor_model_parallel_size: 8 gpu_memory_utilization: 0.8 -# ------------------ 不需要修改 ------------------ +# ------------------ do not edit ------------------ hydra: searchpath: - file://ajet/default_config - file://ajet/default_config/verl # verl only - file://ajet/default_config/trinity # trinity only -# ------------------ 不需要修改 ------------------ +# ------------------ do not edit ------------------ defaults: - verl_default # verl inherit 1/1 - trinity_default # trinity inherit 1/1 diff --git a/tutorial/example_feedback_tracing/example_feedback_tracing.yaml b/tutorial/example_feedback_tracing/example_feedback_tracing.yaml index 1cb01333..894aca7c 100644 
--- a/tutorial/example_feedback_tracing/example_feedback_tracing.yaml +++ b/tutorial/example_feedback_tracing/example_feedback_tracing.yaml @@ -62,14 +62,14 @@ trainer: - console -# ------------------ 不需要修改 ------------------ +# ------------------ do not edit ------------------ hydra: searchpath: - file://ajet/default_config - file://ajet/default_config/verl # verl only - file://ajet/default_config/trinity # trinity only -# ------------------ 不需要修改 ------------------ +# ------------------ do not edit ------------------ defaults: - verl_default # verl inherit 1/1 - trinity_default # trinity inherit 1/1 diff --git a/tutorial/example_learn2ask/learn2ask.yaml b/tutorial/example_learn2ask/learn2ask.yaml index acacbce2..211c1e8a 100644 --- a/tutorial/example_learn2ask/learn2ask.yaml +++ b/tutorial/example_learn2ask/learn2ask.yaml @@ -53,14 +53,14 @@ trinity: sync_method: nccl -# ------------------ 不需要修改 ------------------ +# ------------------ do not edit ------------------ hydra: searchpath: - file://ajet/default_config - file://ajet/default_config/verl # verl only - file://ajet/default_config/trinity # trinity only -# ------------------ 不需要修改 ------------------ +# ------------------ do not edit ------------------ defaults: - verl_default # verl inherit 1/1 - trinity_default # trinity inherit 1/1 diff --git a/tutorial/example_math_swarm/math.py b/tutorial/example_math_swarm/math.py index 041e544f..c1351a9d 100644 --- a/tutorial/example_math_swarm/math.py +++ b/tutorial/example_math_swarm/math.py @@ -28,38 +28,38 @@ def main(): reader_type = "huggingface_dat_repo", reader_config = AjetTaskReader( huggingface_dat_repo = HuggingfaceDatRepo( - dataset_path = "/root/agentjet/benchmark_datasets/dataset/gsm8k/socratic", + dataset_path = '/mnt/data_cpfs/model_cache/modelscope/dataset/openai/gsm8k/main', + # dataset_path = "/root/agentjet/benchmark_datasets/dataset/gsm8k/socratic", # dataset_path = "openai/gsm8k", # dataset_name = "main", ) ) ) - # # Hand shake with remote swarm 
server + # Hand shake with remote swarm server swarm_worker = SwarmClient(AJET_SWARM_URL) + ajet_job = AgentJetJob( + experiment_name="math_gsm8k_grpo", + algorithm="grpo", + n_gpu=REMOTE_ALLOCATE_GPU_PER_NODE, + model=REMOTE_MODEL_PATH, + batch_size=REMOTE_BATCH_SIZE, + num_repeat=GRPO_N, + ) + print(ajet_job.config.to_dict()) swarm_worker.auto_sync_train_config_and_start_engine( - AgentJetJob( - experiment_name="math_gsm8k_grpo", - algorithm="grpo", - n_gpu=REMOTE_ALLOCATE_GPU_PER_NODE, - model=REMOTE_MODEL_PATH, - batch_size=REMOTE_BATCH_SIZE, - num_repeat=GRPO_N, - ), + ajet_job, force_restart=True, ) def rollout(task): - try: - # begin episode - episode_uuid, api_baseurl_key = swarm_worker.begin_episode(discard_episode_timeout=60) - # execute agent ( base_url = api_baseurl_key.base_url, api_key = api_baseurl_key.api_key ) - workflow_output = execute_agent(task, api_baseurl_key) # reward is in `workflow_output` - # report output back to swarm remote - swarm_worker.end_episode(task, episode_uuid, workflow_output) - return - except: - pass + # begin episode + episode_uuid, api_baseurl_key = swarm_worker.begin_episode(discard_episode_timeout=60) + # execute agent ( base_url = api_baseurl_key.base_url, api_key = api_baseurl_key.api_key ) + workflow_output = execute_agent(task, api_baseurl_key) # reward is in `workflow_output` + # report output back to swarm remote + swarm_worker.end_episode(task, episode_uuid, workflow_output) + return executor = PeriodicDrainThreadPoolExecutor(workers=GRPO_N * REMOTE_BATCH_SIZE, auto_retry=True) for _ in range(NUM_EPOCH): diff --git a/tutorial/example_rubrics_judge/r_judge.yaml b/tutorial/example_rubrics_judge/r_judge.yaml index 5834da5f..94e0e940 100644 --- a/tutorial/example_rubrics_judge/r_judge.yaml +++ b/tutorial/example_rubrics_judge/r_judge.yaml @@ -57,14 +57,14 @@ ajet: -# ------------------ 不需要修改 ------------------ +# ------------------ do not edit ------------------ hydra: searchpath: - file://ajet/default_config - 
file://ajet/default_config/verl # verl only - file://ajet/default_config/trinity # trinity only -# ------------------ 不需要修改 ------------------ +# ------------------ do not edit ------------------ defaults: - verl_default # verl inherit 1/1 - trinity_default # trinity inherit 1/1 diff --git a/tutorial/example_werewolves/game.py b/tutorial/example_werewolves/game.py index 10246c32..8eca099e 100644 --- a/tutorial/example_werewolves/game.py +++ b/tutorial/example_werewolves/game.py @@ -44,7 +44,7 @@ async def hunter_stage( global moderator msg_hunter = await hunter_agent( await moderator(Prompts.to_hunter.format(name=hunter_agent.name)), - structured_model=get_hunter_model(players.current_alive), + structured_model=get_hunter_model(players.all_players), ) if msg_hunter.metadata.get("shoot"): return msg_hunter.metadata.get("name", None) @@ -134,7 +134,7 @@ async def werewolves_game(agents: list[ReActAgent], roles) -> bool: # noqa: C90 msgs_vote = await fanout_pipeline( players.werewolves, msg=await moderator(content=Prompts.to_wolves_vote), - structured_model=get_vote_model(players.current_alive), + structured_model=get_vote_model(players.all_players), enable_gather=False, ) killed_player, votes = majority_vote( @@ -187,7 +187,7 @@ async def werewolves_game(agents: list[ReActAgent], roles) -> bool: # noqa: C90 ), ), structured_model=get_poison_model( - players.current_alive, + players.all_players, ), ) if msg_witch_poison.metadata.get("poison"): @@ -206,7 +206,7 @@ async def werewolves_game(agents: list[ReActAgent], roles) -> bool: # noqa: C90 names_to_str(players.current_alive), ), ), - structured_model=get_seer_model(players.current_alive), + structured_model=get_seer_model(players.all_players), ) if msg_seer.metadata.get("name"): player = msg_seer.metadata["name"] @@ -282,7 +282,7 @@ async def werewolves_game(agents: list[ReActAgent], roles) -> bool: # noqa: C90 names_to_str(players.current_alive), ), ), - structured_model=get_vote_model(players.current_alive), + 
structured_model=get_vote_model(players.all_players), enable_gather=False, ) voted_player, votes = majority_vote( diff --git a/tutorial/example_werewolves/start.py b/tutorial/example_werewolves/start.py index 879b6101..0e0ab4b0 100644 --- a/tutorial/example_werewolves/start.py +++ b/tutorial/example_werewolves/start.py @@ -4,6 +4,7 @@ """The main entry point for the werewolf game.""" from typing import List +import agentscope import numpy as np import dotenv dotenv.load_dotenv() @@ -12,7 +13,7 @@ from agentscope.agent import ReActAgent from agentscope.formatter import DashScopeMultiAgentFormatter, OpenAIMultiAgentFormatter -from agentscope.model import OpenAIChatModel +from agentscope.model import OpenAIChatModel, DashScopeChatModel from loguru import logger from pydantic import Field @@ -81,9 +82,12 @@ def get_official_agent_prompt(name) -> str: class ExampleWerewolves(Workflow): trainable_targets: List[str] | None = Field(default=["werewolf"], description="List of agents to be fine-tuned.") + big_external_opponent_llm_url: str = Field(default="http://22.17.52.4:2888/v1", description="The URL of the big external opponent LLM. You can replace it with any OpenAI-compatible LLM API URL.") + big_external_opponent_llm_name: str = Field(default="/mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen3-235B-A22B-Instruct-2507/", description="The model name of the big external opponent LLM. You can replace it with any OpenAI-compatible LLM name.") async def execute(self, workflow_task: WorkflowTask, tuner: AjetTuner) -> WorkflowOutput: + assert agentscope.__version__ == "1.0.7", "AgentScope has too many bugs across versions, please use version 1.0.7 for werewolves example." # ensure trainable targets is legal assert self.trainable_targets is not None, "trainable_targets cannot be None in ExampleWerewolves (because we want to demonstrate a explicit multi-agent case)." 
@@ -103,28 +107,27 @@ async def execute(self, workflow_task: WorkflowTask, tuner: AjetTuner) -> Workfl # initialize agents players = [] for i, role in enumerate(roles): - default_model = OpenAIChatModel( - model_name="/mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen3-235B-A22B-Instruct-2507/", - stream=False, - client_args={"base_url": "http://22.17.52.4:2888/v1"}, - api_key="no_api_key", - generate_kwargs={"temperature": 0.01}, - ) - model_for_this_agent = tuner.as_agentscope_model( - agent_name=f"Player{i + 1}", # the name of this agent - target_tag=role, # `target_tag in self.trainable_targets` means we train this agent, otherwise we do not train this agent. - debug_model=default_model, # the model used when this agent is not in `self.trainable_targets` - ) + if role not in self.trainable_targets: + model_for_this_agent = OpenAIChatModel( + stream=False, + api_key="no_api_key", + generate_kwargs={"temperature": 0.01}, + model_name=self.big_external_opponent_llm_name, + client_args={"base_url": self.big_external_opponent_llm_url}, + ) + else: + model_for_this_agent = tuner.as_agentscope_model( + agent_name=f"Player{i + 1}", + target_tag=role, + ) agent = ReActAgent( name=f"Player{i + 1}", sys_prompt=get_official_agent_prompt(f"Player{i + 1}"), model=model_for_this_agent, - formatter=DashScopeMultiAgentFormatter() - if role in self.trainable_targets - else OpenAIMultiAgentFormatter(), + formatter=DashScopeMultiAgentFormatter() if isinstance(model_for_this_agent, DashScopeChatModel) else OpenAIMultiAgentFormatter(), max_iters=3 if role in self.trainable_targets else 5, ) - # agent.set_console_output_enabled(False) + agent.set_console_output_enabled(False) players += [agent] # reward condition diff --git a/tutorial/example_werewolves_swarm/agent_roll.py b/tutorial/example_werewolves_swarm/agent_roll.py new file mode 100644 index 00000000..c51bc01a --- /dev/null +++ b/tutorial/example_werewolves_swarm/agent_roll.py @@ -0,0 +1,72 @@ +# -*- coding: utf-8 -*- + 
+import os +from ajet.schema.task import Task, WorkflowTask +from ajet.copilot.job import AgentJetJob +from ajet.task_reader import RouterTaskReader +from ajet.utils.thread_executors import PeriodicDrainThreadPoolExecutor +from ajet.tuner_lib.as_oai_baseurl_apikey import OpenaiBaseUrlAndApiKey +from ajet.default_config.ajet_default import AjetTaskReader +from ajet.tuner_lib.experimental.as_swarm_client import SwarmClient + +NUM_EPOCH = 10000 +AJET_SWARM_URL = os.getenv("AJET_SWARM_URL", "http://localhost:10086") + +def main(): + + # Handshake with swarm remote, then send training param to swarm remote (such as model to be trained, algorithm, etc) + dataset = RouterTaskReader( + reader_type = "random_dummy", + reader_config = AjetTaskReader() + ) + + ajet_job = AgentJetJob( + base_yaml_config="tutorial/example_werewolves_swarm/werewolves.yaml", + algorithm="grpo", + experiment_name="werewolves_swarm", + max_env_worker=128, + ) + + # Hand shake with remote swarm server + swarm_worker = SwarmClient(AJET_SWARM_URL) + swarm_worker.auto_sync_train_config_and_start_engine( + ajet_job, + # force_restart=True, + ) + + GRPO_N = ajet_job.num_repeat + REMOTE_BATCH_SIZE = ajet_job.batch_size + + def rollout(task): + # begin episode + episode_uuid, api_baseurl_key = swarm_worker.begin_episode(discard_episode_timeout=240) + # execute agent ( base_url = api_baseurl_key.base_url, api_key = api_baseurl_key.api_key ) + workflow_output = execute_agent(task, api_baseurl_key) # reward is in `workflow_output` + # report output back to swarm remote + swarm_worker.end_episode(task, episode_uuid, workflow_output) + return + + + executor = PeriodicDrainThreadPoolExecutor(workers=1, max_parallel=64, auto_retry=True, block_first_run=True) + for _ in range(NUM_EPOCH): + for _, task in enumerate(dataset.generate_training_tasks()): + for _ in range(GRPO_N): + executor.submit_with_periodic_drain(fn=rollout, task=task) + + return + + +def execute_agent(task: Task, api_baseurl_key: 
OpenaiBaseUrlAndApiKey): + import asyncio + from tutorial.example_werewolves.start import ExampleWerewolves + game = ExampleWerewolves( + trainable_targets=["werewolf"], + big_external_opponent_llm_name="Qwen/Qwen3-235B-A22B-Instruct-2507", + big_external_opponent_llm_url="http://22.14.116.243:2888/v1", + ) + res = asyncio.run(game.execute(WorkflowTask(task=task), api_baseurl_key)) + return res + + +if __name__ == "__main__": + main() diff --git a/tutorial/example_werewolves_swarm/convert_skill.md b/tutorial/example_werewolves_swarm/convert_skill.md new file mode 100644 index 00000000..9ad3ff16 --- /dev/null +++ b/tutorial/example_werewolves_swarm/convert_skill.md @@ -0,0 +1,77 @@ +训练复杂智能体的时候,推荐先从yaml配置出发 + +首先,复制一份基础配置 ajet/default_config/ajet_ts_default.yaml + +cp ajet/default_config/ajet_ts_default.yaml tutorial/example_werewolves_swarm/werewolves.yaml + +然后对配置中的参数进行修改: + +---- opencode命令:这里补充一个参数配置说明表格,参考tutorial/example_werewolves_swarm/werewolves.yaml ---- + + +# 编写训练循环 (Swarm Client) + +Swarm Client 流程如下: + +- 连接蜂群 +- 蜂群server初始化 +- 开始EPOCH循环 + + +```python +# -*- coding: utf-8 -*- + +NUM_EPOCH = 10000 +AJET_SWARM_URL = os.getenv("AJET_SWARM_URL", "http://localhost:10086") + +def main(): + + ajet_job = AgentJetJob( + base_yaml_config="tutorial/example_werewolves_swarm/werewolves.yaml", + algorithm="grpo", + experiment_name="werewolves_swarm", + ) + + # Hand shake with remote swarm server + swarm_worker = SwarmClient(AJET_SWARM_URL) + swarm_worker.auto_sync_train_config_and_start_engine( ajet_job, force_restart=True ) + + GRPO_N = ajet_job.num_repeat + REMOTE_BATCH_SIZE = ajet_job.batch_size + + def rollout(task): + try: + # begin episode + episode_uuid, api_baseurl_key = swarm_worker.begin_episode(discard_episode_timeout=60) + # execute agent ( base_url = api_baseurl_key.base_url, api_key = api_baseurl_key.api_key ) + workflow_output = execute_agent(task, api_baseurl_key) # reward is in `workflow_output` + # report output back to swarm remote + 
swarm_worker.end_episode(task, episode_uuid, workflow_output) + return + except: + pass + + # Handshake with swarm remote, then send training param to swarm remote (such as model to be trained, algorithm, etc) + dataset = RouterTaskReader( + reader_type = "random_dummy", + reader_config = AjetTaskReader() + ) + executor = PeriodicDrainThreadPoolExecutor(workers=GRPO_N * REMOTE_BATCH_SIZE, auto_retry=True) + for _ in range(NUM_EPOCH): + for _, task in enumerate(dataset.generate_training_tasks()): + for _ in range(GRPO_N): + executor.submit_with_periodic_drain(fn=rollout, task=task) + + return None + + +def execute_agent(task: Task, api_baseurl_key: OpenaiBaseUrlAndApiKey): + raise NotImplementedError("see below.") + + +if __name__ == "__main__": + main() + +``` + +# 编写Agent (Swarm Client) diff --git a/tutorial/example_werewolves_swarm/werewolves.yaml b/tutorial/example_werewolves_swarm/werewolves.yaml new file mode 100644 index 00000000..e096e8f6 --- /dev/null +++ b/tutorial/example_werewolves_swarm/werewolves.yaml @@ -0,0 +1,74 @@ +# ------------------ main config ------------------ +ajet: + project_name: example_werewolves_swarm + experiment_dir: "auto" # {exp-dir}/{experiment_name} + + model: + # ✨ select model to be trained + path: /mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2___5-7B-Instruct + + + rollout: + user_workflow: null + temperature: 0.7 + max_env_worker: 64 + num_repeat: 6 + agent_madness_reward: 0.0 + tensor_model_parallel_size: 1 + # max_num_seqs: 40 + # monitor LLM's abormal behaviors during rollout + compute_madness_checklist: + - "nonsense" + max_response_length_in_one_turn: 1024 + max_model_len: 22000 + + task_reader: + type: random_dummy # `env_service` or `jsonl_dataset_file` or `huggingface_dat_repo` or `data_generation` or `random_dummy` + + task_judge: + # ✨ select evaluation function + judge_protocol: null + + # the experimental ZeroMQ interchange server feature that allows `tuner.as_oai_baseurl_apikey` feature + 
enable_interchange_server: True + # train in cloud, run episode locally + enable_swarm_mode: True + # both swarm / oai share the same interchange server + interchange_server: + interchange_method: 'ipc' # options: 'tcp' (multi-nodes) or 'ipc' (1 node) + interchange_server_port: 10086 + num_fastapi_process: 2 # 1, 2 or 4 is fine + max_fastapi_threads: 512 # 64 or 128 is fine + max_inference_tracker_threads: 64 # recommend to be equal to `ajet.rollout.max_env_worker` + already_started: False # do not edit, used by `swarm` + + swarm_mode_sample_collection_method: "rollout_until_finish_enough_tasks" + + debug: + debug_max_parallel: 1 + debug_first_n_tasks: 1 + + data: + train_batch_size: 32 + max_prompt_length: 4000 + max_response_length: 18000 + + trainer_common: + save_freq: 5 + test_freq: 9999999 + total_epochs: 9999999 + total_training_steps: 25 + nnodes: 1 + n_gpus_per_node: 8 + +# ------------------ do not edit ------------------ +hydra: + searchpath: + - file://ajet/default_config + - file://ajet/default_config/verl + +# ------------------ do not edit ------------------ +defaults: + - verl_default + - ajet_default + - _self_ diff --git a/tutorial/opencode_build_skillbench_agent.prompt.md b/tutorial/opencode_build_skillbench_agent.prompt.md new file mode 100644 index 00000000..4a290cf9 --- /dev/null +++ b/tutorial/opencode_build_skillbench_agent.prompt.md @@ -0,0 +1,17 @@ +# Train SkillBench with AgentJet Swarm with Vibe Coding + +result is generated by `claude sonnet 4.5` + +============================= + +你的任务是训练这个仓库中的智能体:https://github.com/benchflow-ai/skillsbench.git +仓库你需要下载到 ../skillsbench_swarm_test +这是在调试过程中你可以使用的模型(openrouter) + "url": "https://openrouter-openrouter-esyubhyrxv.ap-northeast-1.fcapp.run/api/v1", + "key": "sk-or-v1-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx" + "model": "qwen/qwen3-max" + + + +你的skill(首先读取该SKILL文件,获取必要知识): +- ajet/copilot/train-complex-blackbox/SKILL.md