diff --git a/.gitignore b/.gitignore
index bdb3f5f4..7b64dc3e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -170,3 +170,5 @@ modelscope_cache
prompts
swarmexp
swarmlog
+werewolves_swarm
+.claude
diff --git a/ajet/backbone/main_trinity.py b/ajet/backbone/main_trinity.py
index 7956305f..dc06c21c 100644
--- a/ajet/backbone/main_trinity.py
+++ b/ajet/backbone/main_trinity.py
@@ -52,7 +52,7 @@ def patched_trainer_get_actor(cls, config: Config):
Explorer.get_actor = classmethod(patched_explorer_get_actor)
Trainer.get_actor = classmethod(patched_trainer_get_actor)
- if ajet_config.ajet.enable_experimental_interchange_server:
+ if ajet_config.ajet.enable_interchange_server:
from ajet.tuner_lib.experimental.as_oai_model_server import start_interchange_server
start_interchange_server(ajet_config)
diff --git a/ajet/backbone/main_verl.py b/ajet/backbone/main_verl.py
index 1e8c9c01..8eebb95f 100644
--- a/ajet/backbone/main_verl.py
+++ b/ajet/backbone/main_verl.py
@@ -67,7 +67,7 @@ def run_ppo(config: DictConfig) -> None:
def on_shutdown():
if ray.is_initialized():
ray.shutdown()
- if config.ajet.enable_experimental_interchange_server:
+ if config.ajet.enable_interchange_server:
if config.ajet.enable_swarm_mode:
from ajet.tuner_lib.experimental.interchange_utils import http_change_engine_status
print("Changing engine status to OFFLINE before shutdown...")
@@ -250,7 +250,7 @@ def run(self, config):
from ajet.backbone.trainer_verl import AjetRayPPOTrainer
- if config.ajet.enable_experimental_interchange_server:
+ if config.ajet.enable_interchange_server:
from ajet.tuner_lib.experimental.as_oai_model_server import start_interchange_server
start_interchange_server(config)
diff --git a/ajet/backbone/main_vllm.py b/ajet/backbone/main_vllm.py
index edbcb161..4e8b717c 100644
--- a/ajet/backbone/main_vllm.py
+++ b/ajet/backbone/main_vllm.py
@@ -186,7 +186,7 @@ def main(config):
os.environ.update(runtime_env["env_vars"])
# atexit.register(lambda: print("Process exiting, performing cleanup..."))
- if config.ajet.enable_experimental_interchange_server:
+ if config.ajet.enable_interchange_server:
from ajet.tuner_lib.experimental.as_oai_model_server import start_interchange_server
start_interchange_server(config)
if config.ajet.enable_swarm_mode:
diff --git a/ajet/backbone/trainer_verl.py b/ajet/backbone/trainer_verl.py
index 5ca7ce83..28f09f95 100644
--- a/ajet/backbone/trainer_verl.py
+++ b/ajet/backbone/trainer_verl.py
@@ -457,7 +457,7 @@ def init_workers(self):
)
def _update_interchange_server_status_flag(self, status: str):
- if self.config.ajet.enable_experimental_interchange_server:
+ if self.config.ajet.enable_interchange_server:
if self.config.ajet.enable_swarm_mode:
from ajet.tuner_lib.experimental.interchange_utils import http_change_engine_status
http_change_engine_status(self.config, status, global_step=self.global_steps)
@@ -858,7 +858,7 @@ def fit(self): # noqa: C901
self.global_steps += 1
# # when enabled oai request interchange, we need to clear the cache from time to time
- # if self.config.ajet.enable_experimental_interchange_server:
+ # if self.config.ajet.enable_interchange_server:
# from ajet.tuner_lib.experimental.as_oai_model_server import ensure_dat_interchange_server_cache_clear
# ensure_dat_interchange_server_cache_clear()
diff --git a/ajet/backbone/warm_up.py b/ajet/backbone/warm_up.py
index 192f4a0b..6e261b6a 100644
--- a/ajet/backbone/warm_up.py
+++ b/ajet/backbone/warm_up.py
@@ -54,7 +54,7 @@ def warm_up_task_judge_when_needed(config):
def clean_up_tmp_ajet_dir(config):
"""Clean up old IPC socket files in /tmp/ajet directory."""
import time
- if config.ajet.enable_experimental_interchange_server is False:
+ if config.ajet.enable_interchange_server is False:
return
tmp_dir = "/tmp/ajet"
diff --git a/ajet/context_tracker/multiagent_tracking.py b/ajet/context_tracker/multiagent_tracking.py
index b332a11d..9a6069fc 100644
--- a/ajet/context_tracker/multiagent_tracking.py
+++ b/ajet/context_tracker/multiagent_tracking.py
@@ -84,19 +84,6 @@ def extract_text_content_from_content_dict(self, msg):
# },
# ],
# }
- # or tool_result format?? not observed yet:
- # msg = {
- # "role": "tool",
- # "content": [
- # {
- # "type": "tool_result",
- # "id": "call_xxx",
- # "output": "tool output content",
- # "name": "tool_name"
- # },
- # ],
- # }
-
str_content = ""
for item in msg["content"]:
@@ -332,6 +319,7 @@ def save_llm_interaction_timeline(self, tools, llm_ext_msg, timeline):
)
):
logger.bind(exception=True).info(f"General Warning: merge failure discovered.\n")
+ # from ajet import bp; bp("SWARM")
return
@@ -346,7 +334,9 @@ def detect_tool_call_madness(self, llm_output):
# llm_output["tool_calls"] is not None, and is not []
tool_calls = llm_output["tool_calls"]
if "wrong_toolcall" in self.config.ajet.rollout.compute_madness_checklist:
- copy_tool_calls = copy.deepcopy(tool_calls)
+ # copy_tool_calls = copy.deepcopy(tool_calls)
+ # Shallow copy is sufficient - we're only reading the data
+ copy_tool_calls = tool_calls
wrong_toolcall = False
for i in range(len(copy_tool_calls)):
if ("function" in copy_tool_calls[i]) and (
diff --git a/ajet/copilot/job.py b/ajet/copilot/job.py
index 665c41cd..96ae2031 100644
--- a/ajet/copilot/job.py
+++ b/ajet/copilot/job.py
@@ -8,14 +8,13 @@
from __future__ import annotations
import os
+import time
+import yaml
import tempfile
-from types import SimpleNamespace
-from typing import Any, Callable, Union
-import yaml
+from types import SimpleNamespace
+from typing import Any, Callable, Union, cast
from loguru import logger
-
-
from ajet.default_config.ajet_default import Config
from ajet.utils.config_utils import (
expand_ajet_hierarchical_config,
@@ -30,70 +29,118 @@
setup_environment_vars,
)
-DEFAULT_DIR = "saved_experiments"
+
+def override_current_yaml_value_if_given(override_value, current_value):
+ if override_value is not None:
+ return override_value
+ else:
+ return current_value
+
+def _set_nested_attr(obj, attr_path: str, value):
+ keys = attr_path.split(".")
+ for key in keys[:-1]:
+ obj = getattr(obj, key)
+ setattr(obj, keys[-1], value)
+
+def _get_nested_attr(obj, attr_path: str):
+ for key in attr_path.split("."):
+ obj = getattr(obj, key)
+ return obj
class AgentJetJob:
- """Lightweight builder that launches AgentJet training as a subprocess."""
+ """
+ arg: base_yaml_config + **kwargs (yaml config, then override with kwargs)
+ arg: base_yaml_config (yaml config)
+ arg: **kwargs (yaml config, then override with kwargs)
+ """
def __init__(
self,
- backbone: str = "verl",
- model: str = "Qwen/Qwen2___5-7B-Instruct",
- n_gpu: int = 8,
- algorithm: str = "grpo",
- project_name="ajet-swarm",
- experiment_name="test",
- n_gpu_for_infer: int | None = None, # only for trinity backbone
- num_repeat: int = 8,
- batch_size: int = 32,
- swarm_mode: bool = True,
- sample_collection_method: str = "rollout_until_finish_enough_tasks",
- *kwargs,
+ base_yaml_config: str | None = None,
+ experiment_dir: str | None = None,
+ project_name: str | None = None,
+ experiment_name: str | None = None,
+ n_gpu: int | None = None,
+ model: str | None = None,
+ algorithm: str | None = None,
+ num_repeat: int | None = None,
+ batch_size: int | None = None,
+ swarm_mode: bool | None = None,
+ swarm_mode_sample_collection_method: str | None = None,
+ max_env_worker: int | None = None,
+ backbone: str | None = None,
) -> None:
- self.backbone = backbone
- self.exp_dir = DEFAULT_DIR
- self.project_name = project_name
- self.exp_name = experiment_name
- self.sample_collection_method = sample_collection_method
- if swarm_mode:
- default_yaml = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', "default_config/ajet_ts_default.yaml"))
+
+ if base_yaml_config is None:
+ base_yaml_config = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', "default_config/ajet_ts_default.yaml"))
else:
- default_yaml = None
- self.config_as_dict: dict = self.build_job_from_yaml(default_yaml)
+ logger.warning(f"Reading config from {base_yaml_config}.")
+ time.sleep(1)
+ self.config_as_dict: dict = self.build_job_from_yaml(base_yaml_config)
self.config = Config.update_from_dict_recursive(Config(), self.config_as_dict)
- self.config.ajet.experiment_name = experiment_name
- self.config.ajet.backbone = backbone
- self.config.ajet.model.path = model
- self.config.ajet.trainer_common.n_gpus_per_node = n_gpu
- self.config.ajet.trainer_common.algorithm.adv_estimator = algorithm
- self.config.ajet.rollout.num_repeat = num_repeat
- self.config.ajet.data.train_batch_size = batch_size
- self.config.ajet.enable_swarm_mode = swarm_mode
- self.config.ajet.swarm_mode_sample_collection_method = sample_collection_method
- if n_gpu_for_infer is None and backbone == "trinity":
- raise ValueError("Please specify `n_gpu_for_infer` (n_gpu_for_infer < n_gpu) for trinity backbone.")
- if (n_gpu_for_infer is not None) and backbone == "verl":
- raise ValueError("n_gpu_for_infer is only for trinity backbone, please set it to `None`.")
- else:
- if backbone == "trinity":
- assert isinstance(n_gpu_for_infer, int), f"`n_gpu_for_infer` should be int, got {type(n_gpu_for_infer)}."
- assert n_gpu_for_infer < n_gpu, "`n_gpu_for_infer` should be less than `n_gpu`."
- self.config.ajet.rollout.n_vllm_engine = n_gpu_for_infer
- self.config.ajet.rollout.tensor_model_parallel_size = 1
+ self.base_yaml_config: str = cast(str, base_yaml_config) # currently may be None, but will be set later
+ self.experiment_dir: str = cast(str, experiment_dir)
+ self.project_name: str = cast(str, project_name)
+ self.experiment_name: str = cast(str, experiment_name)
+ self.n_gpu: int = cast(int, n_gpu)
+ self.model: str = cast(str, model)
+ self.algorithm: str = cast(str, algorithm)
+ self.num_repeat: int = cast(int, num_repeat)
+ self.batch_size: int = cast(int, batch_size)
+ self.swarm_mode: bool = cast(bool, swarm_mode)
+ self.swarm_mode_sample_collection_method: str = cast(str, swarm_mode_sample_collection_method)
+ self.max_env_worker: int = cast(int, max_env_worker)
+ self.backbone: str = cast(str, backbone)
+
+ # see `ajet/default_config/ajet_ts_default.yaml`
+ overrides = {
+ "ajet.experiment_dir": "experiment_dir",
+ "ajet.project_name": "project_name",
+ "ajet.experiment_name": "experiment_name",
+ "ajet.model.path": "model",
+ "ajet.trainer_common.n_gpus_per_node": "n_gpu",
+ "ajet.trainer_common.algorithm.adv_estimator": "algorithm",
+ "ajet.rollout.num_repeat": "num_repeat",
+ "ajet.data.train_batch_size": "batch_size",
+ "ajet.enable_swarm_mode": "swarm_mode",
+ "ajet.swarm_mode_sample_collection_method": "swarm_mode_sample_collection_method",
+ "ajet.rollout.max_env_worker": "max_env_worker",
+ "ajet.backbone": "backbone",
+ }
+
+ # if any value given in kwargs, override the corresponding value in config
+ for attr_path, override_val in overrides.items():
+ # get value from yaml config
+ # >> e.g. current_model = self.config.model.path
+ current_val = _get_nested_attr(self.config, attr_path)
+
+ # if override_val (given in __init__) is not None, use it to override the value from yaml config
+ # >> e.g. new_model = self.model if (self.model is not None) else current_model
+ new_val = override_current_yaml_value_if_given(getattr(self, override_val), current_val)
+
+ # write final value to `self.config``
+ # >> e.g. self.config.model.path = new_model
+ _set_nested_attr(self.config, attr_path, new_val)
+
+ # write final value to `self`
+ # >> e.g. self.model = new_model
+ setattr(self, override_val, new_val)
+
+ if self.backbone == "trinity":
+ raise NotImplementedError("Trinity backbone is not yet supported in AgentJetJob.")
+
def build_job_from_yaml(self, yaml_path: str | None) -> dict:
self.config_as_dict = read_ajet_hierarchical_config(
yaml_path,
- exp_name=self.exp_name,
- backbone=self.backbone,
write_to=None,
- exp_dir=self.exp_dir,
)
self.config_as_dict = expand_ajet_hierarchical_config(self.config_as_dict, write_to=None)
logger.info(f"Built AgentJet job config: {yaml_path}")
return self.config_as_dict
+
def dump_job_as_yaml(self, yaml_path: str) -> str:
if os.path.dirname(yaml_path):
os.makedirs(os.path.dirname(yaml_path), exist_ok=True)
@@ -102,6 +149,7 @@ def dump_job_as_yaml(self, yaml_path: str) -> str:
logger.info(f"Saved training config to {yaml_path}")
return yaml_path
+
def set_workflow(
self, workflow: Union[str, Callable[..., Any]], ensure_reward_in_workflow: bool = False
) -> "AgentJetJob":
@@ -110,6 +158,7 @@ def set_workflow(
# ensure_reward_in_workflow
return self
+
def set_data(
self,
type: str,
@@ -136,60 +185,3 @@ def set_data(
return self
- def tune(self, *args, **kwargs) -> "AgentJetJob":
- import ray
- ast_cfg = self.config.ajet
- if not ast_cfg.rollout or not ast_cfg.rollout.user_workflow:
- raise ValueError("Workflow must be set via set_workflow before tuning.")
- if not ast_cfg.task_reader:
- raise ValueError("Data source must be set via set_data before tuning.")
-
- backbone = self.config.ajet.backbone
- exp_dir = self.config.ajet.experiment_dir
-
- with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".yaml") as temp_yaml:
- yaml_path = temp_yaml.name
- self.dump_job_as_yaml(yaml_path)
- args = SimpleNamespace(
- conf=yaml_path,
- backbone=backbone,
- exp_dir=exp_dir,
- with_logview=False,
- debug=False,
- )
-
- if args.backbone != "debug":
- # Enforce GPU availability and free memory threshold before proceeding
- check_avail_gpu(min_free_ratio=0.95)
-
- # finalize experiment config
- main_yaml_fp, exe_exp_base, exp_name, exp_config = prepare_experiment_config(
- yaml_path, exp_dir, backbone
- )
-
- # setup environment variables for ray
- env = setup_environment_vars(args, exp_config, main_yaml_fp)
-
- # start ray if not already started
- if not ray.is_initialized():
- from ajet.utils.launch_utils import start_ray_service
-
- start_ray_service(args, env)
- else:
- raise RuntimeError(
- "Ray is already initialized. Please shutdown existing Ray instance before starting a new tuning job."
- )
-
- # start training process
- if args.conf and main_yaml_fp and exe_exp_base and exp_config:
- execute_training_process(
- args,
- get_backbone_target(args.backbone),
- main_yaml_fp,
- exe_exp_base,
- main_yaml_fp,
- env,
- exp_config,
- )
-
- return self
diff --git a/ajet/copilot/train-complex-blackbox/SKILL.md b/ajet/copilot/train-complex-blackbox/SKILL.md
new file mode 100644
index 00000000..6bd8e808
--- /dev/null
+++ b/ajet/copilot/train-complex-blackbox/SKILL.md
@@ -0,0 +1,174 @@
+---
+name: train-complex-blackbox
+description: Create a trainable agent loop or agent workflow with AgentJet
+license: Complete terms in LICENSE.txt
+---
+
+
+## 0. Ask user for API key + model (or API key + base url + model) for debugging
+
+This is not 100% necessary, but it can help a lot in debugging in step 1.
+If the user has not given an API key, ask the user to provide one.
+
+
+By default, the code you write should be located at ./tutorial/opencode_build_xxxxxx/*.py
+
+## 1. Initial Programming
+
+### Writing dataset collector (`get_training_dataset_item_list.py`)
+- `get_training_dataset_item_list.py`: Returns a list of training data items. Each item may be a string identifier of a training task, or a dict containing the necessary information for the training task.
+
+### Episode Runner (`run_episode_once.py`)
+- `run_episode_once.py`:
+
+  - Argument Parser: takes (training data item identifier + api-key + base-url) as input; model-name is not required — you can make up a model name because it is ignored.
+
+ - Execute the agent: read the document of the agent user asked you to train, figure out how to execute the agent. In most cases you can use subprocess to start a commandline process to execute the agent, your biggest issue is to figure out how to pass the training data item identifier, api-key and base-url to that commandline process. You can also use python code to execute the agent if you think it's more convenient.
+
+  - Reward: extract / compute the reward/score for the agent's output. Some agents have a clear reward signal, but others don't.
+ - clear reward signal: take that down as the reward, no need to do extra reward engineering.
+ - no clear reward signal: you need to design a reward function to compute the reward/score for the agent's output. You can use another LLM to help you design the reward function, or you can design it by yourself if you have domain knowledge.
+
+
+### Test
+
+Remember to test these two parts before moving to step 2, make sure they work as expected.
+
+
+
+## 2. Writing training code
+
+This part is easy, simply follow this template and change the necessary part such as dataset path, model name, etc.
+
+`agent_roll.py`
+
+```python
+# -*- coding: utf-8 -*-
+
+import os
+import re
+import requests
+from textwrap import dedent
+from ajet.schema.task import Task, WorkflowOutput
+from ajet.copilot.job import AgentJetJob
+from ajet.task_reader import RouterTaskReader
+from ajet.utils.thread_executors import PeriodicDrainThreadPoolExecutor
+from ajet.tuner_lib.as_oai_baseurl_apikey import OpenaiBaseUrlAndApiKey
+from ajet.default_config.ajet_default import AjetTaskReader, HuggingfaceDatRepo
+from ajet.tuner_lib.experimental.as_swarm_client import SwarmClient
+
+# python -m tutorial.example_math_swarm.math
+
+GRPO_N = 4 # grpo group size
+NUM_EPOCH = 10000
+AJET_SWARM_URL = os.getenv("AJET_SWARM_URL", "http://localhost:10086")
+REMOTE_MODEL_PATH = os.getenv("REMOTE_MODEL_PATH", "/mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2.5-7B-Instruct")
+REMOTE_BATCH_SIZE = 32
+REMOTE_ALLOCATE_GPU_PER_NODE = 8
+
+def main():
+
+ # Handshake with swarm remote, then send training param to swarm remote (such as model to be trained, algorithm, etc)
+ dataset = RouterTaskReader(
+ reader_type = "huggingface_dat_repo",
+ reader_config = AjetTaskReader(
+ huggingface_dat_repo = HuggingfaceDatRepo(
+ dataset_path = '/mnt/data_cpfs/model_cache/modelscope/dataset/openai/gsm8k/main',
+ # dataset_path = "/root/agentjet/benchmark_datasets/dataset/gsm8k/socratic",
+ # dataset_path = "openai/gsm8k",
+ # dataset_name = "main",
+ )
+ )
+ )
+ # Load the CountDown dataset
+ # print(f"Loading dataset from: {LOCAL_DATASET_PATH}")
+ # dataset = RouterTaskReader(
+ # reader_type="jsonl_dataset_file",
+ # reader_config=AjetTaskReader(
+ # jsonl_dataset_file=JsonlDatasetFile(
+ # training=JsonlTrainingFp(file_path=LOCAL_DATASET_PATH)
+ # )
+ # ),
+ # )
+
+ # Hand shake with remote swarm server
+ swarm_worker = SwarmClient(AJET_SWARM_URL)
+ ajet_job = AgentJetJob(
+ experiment_name="math_gsm8k_grpo",
+ algorithm="grpo",
+ n_gpu=REMOTE_ALLOCATE_GPU_PER_NODE,
+ model=REMOTE_MODEL_PATH,
+ batch_size=REMOTE_BATCH_SIZE,
+ num_repeat=GRPO_N,
+ )
+ print(ajet_job.config.to_dict())
+ swarm_worker.auto_sync_train_config_and_start_engine(
+ ajet_job,
+ force_restart=True,
+ )
+
+ def rollout(task):
+ # begin episode
+ episode_uuid, api_baseurl_key = swarm_worker.begin_episode(discard_episode_timeout=60)
+ # execute agent ( base_url = api_baseurl_key.base_url, api_key = api_baseurl_key.api_key )
+ workflow_output = execute_agent(task, api_baseurl_key) # reward is in `workflow_output`
+ # report output back to swarm remote
+ swarm_worker.end_episode(task, episode_uuid, workflow_output)
+ return
+
+ executor = PeriodicDrainThreadPoolExecutor(workers=GRPO_N * REMOTE_BATCH_SIZE, auto_retry=True)
+ for _ in range(NUM_EPOCH):
+ for _, task in enumerate(dataset.generate_training_tasks()):
+ for _ in range(GRPO_N):
+ executor.submit_with_periodic_drain(fn=rollout, task=task)
+
+ return None
+
+
+def execute_agent(task: Task, api_baseurl_key: OpenaiBaseUrlAndApiKey):
+ ....
+ raw_reward: float = ... # compute the reward for the agent's output
+ return WorkflowOutput(reward=raw_reward, metadata={"important_metadata": important_metadata})
+
+
+if __name__ == "__main__":
+ main()
+
+
+```
+
+
+It is very clear now, your job in step 2 is to:
+
+- use `get_training_dataset_item_list.py` to generate `List[Task]` (`from ajet.schema.task import Task`)
+- use `run_episode_once.py` to execute a single episode and place it in `execute_agent` function
+
+
+## 3. Simplify your code and fix bugs
+
+before moving to step 4, you can simplify your code and fix bugs to make sure it can run smoothly.
+
+
+## 4. Training
+
+Finally, you can start training.
+
+Run `ajet-swarm start` to start training server (if the user has already installed agentjet swarm environment),
+if the user has a docker environment, you can also refer to `docs/en/ajet-swarm-docker.md` to start an AgentSwarm docker container.
+
+Create a duplication of `agent_roll.py` named `agent_roll_one_episode_debug.py`, and modify it to only run one episode, this can help you debug whether the episode runner and reward function work as expected.
+
+After the server side is ready, run
+```bash
+python /path/to/agent_roll_one_episode_debug.py
+```
+watch console log to see if the episode can be executed successfully and reward can be computed correctly.
+
+If anything goes wrong, keep server running, rewrite and fix `agent_roll_one_episode_debug.py`, and run it again until it can run one episode successfully.
+
+Next, patch `agent_roll.py` if there are any bugs discovered via the debugging of `agent_roll_one_episode_debug.py`, and then run
+```bash
+python /path/to/agent_roll.py
+```
+
+to start the training!
diff --git a/ajet/copilot/write-swarm-client/SKILL.md b/ajet/copilot/write-swarm-client/SKILL.md
index c62693e5..0b98902d 100644
--- a/ajet/copilot/write-swarm-client/SKILL.md
+++ b/ajet/copilot/write-swarm-client/SKILL.md
@@ -4,25 +4,24 @@ description: Create a trainable agent loop or agent workflow with AgentJet
license: Complete terms in LICENSE.txt
---
-## 简介:
-你的任务是根据要求,创建一个可训练 Agent (或者Agent Loop,多智能体系统等等),提供给用户做强化学习训练。
-在AgentJet强化学习框架下,这是非常简单的。
+## Introduction:
-首先,根据用户的要求,给智能体系统起一个名字,例如 user_math_agent
+Your task is to create a trainable Agent (or Agent Loop, multi-agent system, etc.) based on the requirements, and provide it to the user for reinforcement learning training. Under the AgentJet reinforcement learning framework, this is very simple.
-其次,创建文件:
-tutorial/user_math_agent
+First, give the agent system a name based on the user's requirements, for example `user_math_agent`.
-接下来,创建Agent源文件:
-tutorial/user_math_agent/agent_roll.py (以 tutorial/example_academic_trans_swarm/trans_roll.py 为模板,变化不大,关键是向用户索取必要的参数)
-tutorial/user_math_agent/agent_run.py (根据用户的要求,创建运行智能体的函数,或者类,都可以。同步异步都可以。)
-tutorial/user_math_agent/readme.md (Agent说明,以及训练、调试方法说明)
+Next, create the directory:
+`tutorial/user_math_agent`
+Then, create the Agent source files:
+- `tutorial/user_math_agent/agent_roll.py` (Use `tutorial/example_academic_trans_swarm/trans_roll.py` as a template. There aren't many changes — the key is to ask the user for the necessary parameters.)
+- `tutorial/user_math_agent/agent_run.py` (Create the function or class to run the agent based on the user's requirements. Synchronous or asynchronous are both fine.)
+- `tutorial/user_math_agent/readme.md` (Agent description, along with training and debugging instructions.)
-## 智能体编写方法
+## How to Write the Agent
-使用 OpenAI SDK 编写智能体,主要包含以下三个函数(以及必要的子函数和子模块):
+Write the agent using the OpenAI SDK. It mainly includes the following three functions (along with any necessary sub-functions and sub-modules):
```
from ajet.schema.task import Task, WorkflowOutput
@@ -31,24 +30,68 @@ def _compute_reward(...)
def _execute_agent(...)
-def run_agent_and_compute_reward(task: Task, base_url:string, api_key:string) -> WorkflowOutput:
+def run_agent_and_compute_reward(task: Task, base_url: string, api_key: string) -> WorkflowOutput:
```
-在 agent_roll 中,直接import run_agent_and_compute_reward即可。
+In `agent_roll`, simply import `run_agent_and_compute_reward`.
-- 智能体的编写要领:通过一个或几个Agent的协作,高效完成用户给定的任务。
-- 奖励编写的要领:容易验证的,使用规则直接计算。不容易验证的,模仿 `tutorial/example_academic_trans_swarm/train_multi_model/trans_reward.py` 中的方法,使用其他大型模型生成 LLM as Judge 程序。
+- **Key points for writing the agent:** Efficiently complete the user's given task through the collaboration of one or several Agents.
+- **Key points for writing the reward:** For things that are easy to verify, calculate directly using rules. For things that are hard to verify, follow the approach in `tutorial/example_academic_trans_swarm/train_multi_model/trans_reward.py` and use other large models to create an LLM-as-Judge program.
+## Training and Debugging Instructions
-## 训练、调试方法说明
+Overall, the user first runs `ajet-swarm start`, then runs `agent_roll.py`, and training begins. You do not need to and are not allowed to run these bash commands.
+- First, help the user write `agent_run.py` and `agent_roll.py`.
+- Then, write clear instructions to guide the user through training (`readme.md`).
-总体而言,就是用户先运行 `ajet-swarm start`, 然后再运行 `agent_roll.py` 训练就开始了。你不需要也不被允许运行这些bash命令。
-- 首先帮助用户写好 `agent_run.py` 和 `agent_roll.py`,
-- 然后写清楚引导用户训练的说明(readme.md),
-你的任务就完成了。
+Your task is then complete.
-以下是一些参考资料。
+Below are some reference materials.
+---
+
+## Introduction:
+
+Your task is to create a trainable Agent (or Agent Loop, multi-agent system, etc.) based on the requirements, and provide it to the user for reinforcement learning training. Under the AgentJet reinforcement learning framework, this is very simple.
+
+First, give the agent system a name based on the user's requirements, for example `user_math_agent`.
+
+Next, create the directory:
+`tutorial/user_math_agent`
+
+Then, create the Agent source files:
+- `tutorial/user_math_agent/agent_roll.py` (Use `tutorial/example_academic_trans_swarm/trans_roll.py` as a template. There aren't many changes — the key is to ask the user for the necessary parameters.)
+- `tutorial/user_math_agent/agent_run.py` (Create the function or class to run the agent based on the user's requirements. Synchronous or asynchronous are both fine.)
+- `tutorial/user_math_agent/readme.md` (Agent description, along with training and debugging instructions.)
+
+## How to Write the Agent
+
+Write the agent using the OpenAI SDK. It mainly includes the following three functions (along with any necessary sub-functions and sub-modules):
+
+```
+from ajet.schema.task import Task, WorkflowOutput
+
+def _compute_reward(...)
+
+def _execute_agent(...)
+
+def run_agent_and_compute_reward(task: Task, base_url: string, api_key: string) -> WorkflowOutput:
+```
+
+In `agent_roll`, simply import `run_agent_and_compute_reward`.
+
+- **Key points for writing the agent:** Efficiently complete the user's given task through the collaboration of one or several Agents.
+- **Key points for writing the reward:** For things that are easy to verify, calculate directly using rules. For things that are hard to verify, follow the approach in `tutorial/example_academic_trans_swarm/train_multi_model/trans_reward.py` and use other large models to create an LLM-as-Judge program.
+
+## Training and Debugging Instructions
+
+Overall, the user first runs `ajet-swarm start`, then runs `agent_roll.py`, and training begins. You do not need to and are not allowed to run these bash commands.
+- First, help the user write `agent_run.py` and `agent_roll.py`.
+- Then, write clear instructions to guide the user through training (`readme.md`).
+
+Your task is then complete.
+
+Below are some reference materials.
---
# Using AgentJet Swarm to Train Your Agents
diff --git a/ajet/default_config/ajet_default.yaml b/ajet/default_config/ajet_default.yaml
index 0f164971..1539d028 100644
--- a/ajet/default_config/ajet_default.yaml
+++ b/ajet/default_config/ajet_default.yaml
@@ -3,7 +3,7 @@ ajet:
project_name: "ajet_default_project"
experiment_name: "read_yaml_name"
experiment_dir: "auto" # {exp-dir}/{experiment_name}
- backbone: debug # `debug` or `trinity` or `verl`
+ backbone: verl # `debug` or `trinity` or `verl`
model:
@@ -85,6 +85,7 @@ ajet:
num_repeat: 1
+
task_reader:
# how to read dataset / environment
type: huggingface_dat_repo # `env_service` or `jsonl_dataset_file` or `huggingface_dat_repo` or `data_generation` or `random_dummy`
@@ -284,7 +285,7 @@ ajet:
# the experimental ZeroMQ interchange server feature that allows `tuner.as_oai_baseurl_apikey` feature
enable_swarm_mode: False
# both swarm / oai share the same interchange server
- enable_experimental_interchange_server: False
+ enable_interchange_server: False
# interchange server configuration
interchange_server:
interchange_method: 'ipc' # options: 'tcp' (multi-nodes) or 'ipc' (1 node)
@@ -306,8 +307,6 @@ ajet:
swarm_mode_sample_collection_max_cached_episodes: 9999
task_runner:
- # submit llm infer submit method
- llm_infer_submit_method: "async" # options: "sync", "async"
# how to wrap the user-defined workflow
wrapper_type: "asyncio-with-gc"
diff --git a/ajet/default_config/ajet_ts_default.yaml b/ajet/default_config/ajet_ts_default.yaml
index 1db9bdd5..90e3f4bd 100644
--- a/ajet/default_config/ajet_ts_default.yaml
+++ b/ajet/default_config/ajet_ts_default.yaml
@@ -3,15 +3,19 @@ ajet:
project_name: "ajet_default_project"
experiment_name: "read_yaml_name"
experiment_dir: "auto" # {exp-dir}/{experiment_name}
- backbone: debug # `debug` or `trinity` or `verl`
+ backbone: verl
model:
# which model should be trained
- path: /mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2___5-3B-Instruct
+ path: "Qwen/Qwen2.5-7B-Instruct"
rollout:
# the path to the workflow class
user_workflow: null
+ # maximum number of parallel environments / simulate workers
+ max_env_worker: 128
+ # how many times a task should be repeated
+ num_repeat: 4
task_reader:
type: random_dummy # `env_service` or `jsonl_dataset_file` or `huggingface_dat_repo` or `data_generation` or `random_dummy`
@@ -21,7 +25,7 @@ ajet:
judge_protocol: null # reward must come from remote user agent workflow, so set to null
# the experimental ZeroMQ interchange server feature that allows `tuner.as_oai_baseurl_apikey` feature
- enable_experimental_interchange_server: True
+ enable_interchange_server: True
# train in cloud, run episode locally
enable_swarm_mode: True
# both swarm / oai share the same interchange server
@@ -44,21 +48,29 @@ ajet:
# (Hint: a **task_id** is considered "NON-DUMMY" at least one of **episodes** of **task_id** has **different** reward value.)
swarm_mode_sample_collection_method: "rollout_until_finish_enough_tasks"
- rollout:
- # maximum number of parallel environments / simulate workers
- max_env_worker: 128
+ data:
+ # max number of tokens for prompt
+ max_prompt_length: 3000
+ # max number of tokens for response
+ max_response_length: 15000
+ # how many tasks per training batch
+ train_batch_size: 32
+ # [Hint]: The final number of samples per update will be: N_{sample} = (data.train_batch_size * rollout.num_repeat * rollout.multi_turn.expected_steps)
trainer_common:
logger: tensorboard
+ n_gpus_per_node: 8
+ algorithm:
+ adv_estimator: grpo
-# ------------------ 不需要修改 ------------------
+# ------------------ do not edit ------------------
hydra:
searchpath:
- file://ajet/default_config
- file://ajet/default_config/verl # verl only
-# ------------------ 不需要修改 ------------------
+# ------------------ do not edit ------------------
defaults:
- - verl_default # verl inherit 1/1
+ - verl_default
- ajet_default
- _self_
diff --git a/ajet/launcher.py b/ajet/launcher.py
index c9d5705b..d3a5b4cd 100644
--- a/ajet/launcher.py
+++ b/ajet/launcher.py
@@ -150,8 +150,8 @@ def start_swarm_server(env, config):
assert config.ajet.enable_swarm_mode, (
"Please enable_swarm_mode in config to start swarm server."
)
- assert config.ajet.enable_experimental_interchange_server, (
- "Please enable_experimental_interchange_server in config to start swarm server."
+ assert config.ajet.enable_interchange_server, (
+ "Please enable_interchange_server in config to start swarm server."
)
from ajet.tuner_lib.experimental.as_oai_model_server import (
start_interchange_server,
@@ -168,7 +168,7 @@ def main():
from ajet.utils.swarm_overwatch import start_overwatch
logger.info(f"Starting Swarm Overwatch for server: {args.swarm_overwatch}")
- start_overwatch(args.swarm_overwatch, refresh_interval=1.0)
+ start_overwatch(args.swarm_overwatch, refresh_interval=2.0)
return
# Enforce GPU availability and free memory threshold before proceeding
@@ -204,7 +204,6 @@ def main():
# read configuration from yaml
exp_config = None
- exp_dir = args.exp_dir or DEFAULT_DIR
if args.swarm_server and (not args.conf):
args.conf = os.path.abspath(
os.path.join(
@@ -215,6 +214,7 @@ def main():
"Please provide a valid config file for swarm server mode."
)
if args.conf:
+ exp_base_dir = args.exp_dir or DEFAULT_DIR
yaml_path = args.conf
(
main_yaml_fp,
@@ -222,7 +222,10 @@ def main():
exp_name,
exp_config,
) = prepare_experiment_config(
- yaml_path, exp_dir, args.backbone, storage=(not args.swarm_server)
+ yaml_path=yaml_path,
+ exp_base_dir=exp_base_dir,
+ backbone=args.backbone,
+ storage=(not args.swarm_server)
)
# setup environment variables
diff --git a/ajet/schema/extended_msg.py b/ajet/schema/extended_msg.py
index dfaa7460..5e789074 100644
--- a/ajet/schema/extended_msg.py
+++ b/ajet/schema/extended_msg.py
@@ -244,9 +244,11 @@ def get_inc_simple(self, text_frag_from, text_frag_to, tokenizer):
tokenizer_output = tokenizer(text_frag_from, return_tensors="pt", padding=False)
tokenizer_input_ids = tokenizer_output["input_ids"][0].tolist()
token_ids_acc = tokenizer_input_ids
+ del tokenizer_output # Free memory immediately
tokenizer_output = tokenizer(text_frag_to, return_tensors="pt", padding=False)
input_ids = tokenizer_output["input_ids"][0].tolist()
+ del tokenizer_output # Free memory immediately
# get the new tokens added in this step
input_id_increment = input_ids[len(token_ids_acc) :]
FN_DEBUG = False
diff --git a/ajet/swarm_cli.py b/ajet/swarm_cli.py
index eb5dd866..723d9b5a 100644
--- a/ajet/swarm_cli.py
+++ b/ajet/swarm_cli.py
@@ -24,8 +24,8 @@ def start_swarm_server(env, config, port):
assert config.ajet.enable_swarm_mode, (
"Please enable_swarm_mode in config to start swarm server."
)
- assert config.ajet.enable_experimental_interchange_server, (
- "Please enable_experimental_interchange_server in config to start swarm server."
+ assert config.ajet.enable_interchange_server, (
+ "Please enable_interchange_server in config to start swarm server."
)
# Set the port in the config
@@ -42,7 +42,7 @@ def start_swarm_server(env, config, port):
def cmd_start(args):
"""Handle the 'start' subcommand."""
# Use default config if not provided
- exp_dir = args.exp_dir or DEFAULT_DIR
+ exp_base_dir = args.exp_dir or DEFAULT_DIR
if not args.conf:
args.conf = os.path.abspath(
os.path.join(
@@ -61,7 +61,10 @@ def cmd_start(args):
exp_name,
exp_config,
) = prepare_experiment_config(
- yaml_path, exp_dir, "verl", storage=False
+ yaml_path=yaml_path,
+ exp_base_dir=exp_base_dir,
+ backbone="verl",
+ storage=False
)
# Setup environment variables
@@ -73,7 +76,6 @@ def __init__(self, conf, backbone, exp_dir):
self.swarm_server = True
self.swarm_overwatch = ""
self.debug = ""
-
swarm_args = SwarmArgs(args.conf, "verl", args.exp_dir)
env, exp_config = setup_environment_vars(swarm_args, exp_config, main_yaml_fp)
@@ -131,9 +133,9 @@ def main():
parser_overwatch.add_argument(
"--refresh-interval",
type=float,
- default=1.0,
+ default=2.0,
required=False,
- help="Refresh interval in seconds (default: 1.0)",
+ help="Refresh interval in seconds (default: 2.0)",
)
parser_overwatch.set_defaults(func=cmd_overwatch)
diff --git a/ajet/task_rollout/async_llm_bridge.py b/ajet/task_rollout/async_llm_bridge.py
index 3015f631..ced9cf16 100644
--- a/ajet/task_rollout/async_llm_bridge.py
+++ b/ajet/task_rollout/async_llm_bridge.py
@@ -1,9 +1,9 @@
-import asyncio
import copy
import json
import time
import uuid
-from typing import Any, Callable, Dict, List, Literal, Union
+from typing import Any, Callable, Dict, List, Literal, Union, Awaitable
+from typing import TYPE_CHECKING
from loguru import logger
from omegaconf import DictConfig
@@ -15,12 +15,13 @@
from ajet.schema.logprob import TokenAndProb
from ajet.utils.tokenizer import ajet_apply_chat_template
-from ajet.utils.async_utils import run_async_coroutine_with_timeout
-from ajet.utils.testing_utils import _mock_if_test_mode, _test_if_test_mode
from ajet.schema.convertion import convert_llm_proxy_response_to_oai_response
from ajet.schema.convertion import convert_llm_proxy_response_to_agentscope_response
from ajet.context_tracker.multiagent_tracking import MultiAgentContextTracker
+if TYPE_CHECKING:
+ from vllm.entrypoints.openai.protocol import ChatCompletionRequest
+
ChatResponse = Union[OpenAIChatCompletion, AgentScopeChatResponse]
@@ -58,207 +59,6 @@ def __init__(
self.max_llm_retries = max_llm_retries
self.tool_parser = Hermes2ProToolParser(self.tokenizer)
- def get_llm_inference_fn_sync(self, sampling_params: dict = {}) -> Callable: # noqa: C901
-
- def llm_chat_verl(
- messages: List[Dict[str, str]],
- custom_sampling_params: dict = {},
- tools=[],
- request_id: str = "",
- ) -> dict:
- request_id = uuid.uuid4().hex
-
- updated_sampling_params = {}
- if sampling_params:
- updated_sampling_params.update(sampling_params)
- if custom_sampling_params:
- updated_sampling_params.update(custom_sampling_params)
-
- input_messages = copy.deepcopy(messages)
- prompt_text = ajet_apply_chat_template(
- tokenizer=self.tokenizer,
- conversation=input_messages,
- tools=tools,
- add_generation_prompt=True,
- tokenize=False,
- )
- prompt_ids = self.tokenizer(prompt_text)["input_ids"]
-
- if self.config.ajet.execute_test:
- _test_if_test_mode("prompt_text", prompt_text, self.config)
-
- final_res = run_async_coroutine_with_timeout(
- self.async_rollout_manager.generate(
- request_id=request_id,
- prompt_ids=prompt_ids,
- sampling_params=updated_sampling_params,
- ),
- timeout=1800,
- )
-
- if self.config.ajet.rollout.name == "vllm":
- final_res: VerlVllmRequestOutput
- token_array = final_res.outputs[0].token_ids
- logprob_array = final_res.outputs[0].logprobs
- elif self.config.ajet.rollout.name == "sglang":
- token_array = final_res
-
- decoded_text = self.tokenizer.decode(token_array) # type: ignore
- if self.config.ajet.execute_test:
- decoded_text = _mock_if_test_mode("mock_decoded_text", decoded_text, self.config)
-
- if decoded_text.endswith("<|im_end|>"):
- decoded_text = decoded_text[: -len("<|im_end|>")]
-
- # if tool call
- tool_calls = None
- if (
- ("" in decoded_text)
- and ("" in decoded_text)
- and (not self.config.ajet.rollout.force_disable_toolcalls)
- ):
- parsed_tool_calls = self.tool_parser.extract_tool_calls(decoded_text, None) # type: ignore
- parsed_tool_calls = parsed_tool_calls.model_dump()
- if self.config.ajet.execute_test:
- _test_if_test_mode(
- "parsed_tool_calls", parsed_tool_calls["tool_calls"], self.config
- )
- model_called = parsed_tool_calls["tools_called"]
- if model_called:
- tool_calls = parsed_tool_calls["tool_calls"]
- is_bad_toolcall = False
- for i in range(len(tool_calls)):
- if "function" in tool_calls[i] and "arguments" in tool_calls[i]["function"]:
- expect_dict = json.loads(tool_calls[i]["function"]["arguments"])
- if not isinstance(expect_dict, dict):
- is_bad_toolcall = True
- if is_bad_toolcall:
- tool_calls = None
- decoded_text = decoded_text
- else:
- decoded_text = parsed_tool_calls["content"]
- if decoded_text is None:
- decoded_text = ""
-
- return {
- "role": "assistant",
- "request_id": request_id,
- "content": decoded_text,
- "tool_calls": tool_calls,
- "tokens": [
- TokenAndProb(
- token_id=token_id,
- logprob=logprob[token_id].logprob, # Warning: vllm logprob does not participant training (not reliable enough), for log only.
- decoded_string=logprob[token_id].decoded_token,
- )
- for token_id, logprob in zip(token_array, logprob_array) # type: ignore
- ],
- }
-
-
- def llm_chat_remote(
- messages: List[Dict[str, str]],
- custom_sampling_params: dict = {},
- tools=[],
- request_id: str = "",
- ) -> dict:
- updated_sampling_params = {}
- if sampling_params:
- updated_sampling_params.update(sampling_params)
- if custom_sampling_params:
- updated_sampling_params.update(custom_sampling_params)
- updated_sampling_params.update({"logprobs": 1, "return_tokens_as_token_ids": True})
- input_messages = copy.deepcopy(messages)
- for i in range(self.max_llm_retries):
- try:
- # this function is defined in `ajet/backbone/main_vllm.py`
- output_message = self.async_rollout_manager.submit_chat_completions(
- messages=input_messages,
- sampling_params=updated_sampling_params,
- tools=tools,
- request_id=request_id,
- )
- break
- except Exception as e:
- logger.bind(exception=True).exception(f"rollout_server.{i} error: {e.args}")
- time.sleep(i + 1)
- return output_message[-1] # type: ignore
-
-
- def llm_chat_trinity(
- messages: List[Dict[str, str]],
- custom_sampling_params: dict = {},
- tools=[],
- request_id: str = "",
- ) -> dict:
- async def main():
- updated_sampling_params = {}
- if sampling_params:
- updated_sampling_params.update(sampling_params)
- if custom_sampling_params:
- updated_sampling_params.update(custom_sampling_params)
- updated_sampling_params.pop("min_tokens")
-
- if tools:
- response = await self.async_rollout_manager.chat.completions.create(
- model=self.async_rollout_manager.model_path,
- messages=messages,
- logprobs=True,
- tools=tools,
- top_logprobs=0,
- **updated_sampling_params,
- )
- else:
- response = await self.async_rollout_manager.chat.completions.create(
- model=self.async_rollout_manager.model_path,
- messages=messages,
- logprobs=True,
- top_logprobs=0,
- **updated_sampling_params,
- )
- return response
-
- response = run_async_coroutine_with_timeout(main(), timeout=1800) # type: ignore
- prompt_text = self.tokenizer.decode(response.model_extra["prompt_token_ids"])
- prompt_token_ids = response.model_extra["prompt_token_ids"]
- content = response.choices[0].message.content
- message = response.choices[0].message.model_dump(exclude_unset=True, exclude_none=True)
-
- if content is None:
- content = ""
-
- if ("" in content) and (not message.get("tool_calls", None)):
- # logger.bind(exception=True).exception(f"Bad toolcall discovered \n\nprompt_text:\n{prompt_text}\n\nrepsonse:\n{content}")
- logger.warning(f"Bad toolcall discovered: {content}")
-
- return {
- "role": "assistant",
- "request_id": response.id,
- "content": content,
- "prompt_text": prompt_text,
- "prompt_token_ids": prompt_token_ids,
- "tool_calls": message.get("tool_calls", []),
- "tokens": [
- TokenAndProb(
- token_id=token,
- logprob=tokenlogprob.logprob, # Warning: vllm logprob does not participant training, for log only.
- decoded_string=tokenlogprob.token,
- )
- for tokenlogprob, token in zip(
- response.choices[0].logprobs.content,
- response.choices[0].token_ids,
- )
- ],
- }
-
- if self.llm_mode == "remote":
- return llm_chat_remote
- if self.llm_mode == "trinity":
- return llm_chat_trinity
- else:
- return llm_chat_verl
-
-
def get_llm_inference_fn_async(self, sampling_params: dict = {}) -> Callable: # noqa: C901
@@ -286,9 +86,6 @@ async def llm_chat_verl(
)
prompt_ids = self.tokenizer(prompt_text)["input_ids"]
- if self.config.ajet.execute_test:
- _test_if_test_mode("prompt_text", prompt_text, self.config)
-
final_res = await self.async_rollout_manager.generate(
request_id=request_id,
prompt_ids=prompt_ids,
@@ -303,13 +100,11 @@ async def llm_chat_verl(
token_array = final_res
decoded_text = self.tokenizer.decode(token_array) # type: ignore
- if self.config.ajet.execute_test:
- decoded_text = _mock_if_test_mode("mock_decoded_text", decoded_text, self.config)
if decoded_text.endswith("<|im_end|>"):
decoded_text = decoded_text[: -len("<|im_end|>")]
- # if tool call
+ # if tool call, use vLLM tool parser to extract tool calls and validate them
tool_calls = None
if (
("" in decoded_text)
@@ -319,10 +114,7 @@ async def llm_chat_verl(
parsed_tool_calls = self.tool_parser.extract_tool_calls(decoded_text, None) # type: ignore
parsed_tool_calls = parsed_tool_calls.model_dump()
- if self.config.ajet.execute_test:
- _test_if_test_mode(
- "parsed_tool_calls", parsed_tool_calls["tool_calls"], self.config
- )
+
model_called = parsed_tool_calls["tools_called"]
if model_called:
tool_calls = parsed_tool_calls["tool_calls"]
@@ -474,7 +266,7 @@ class OpenaiLlmProxyWithTracker(object):
def __init__(
self,
- llm_inference_fn: Callable, # Callable[AjetStandardLlmBridgeRequest, AjetStandardLlmBridgeResponse]
+ llm_inference_fn: Callable[..., Awaitable[Dict]], # Callable[AjetStandardLlmBridgeRequest, AjetStandardLlmBridgeResponse]
context_tracker: MultiAgentContextTracker,
config,
) -> None:
@@ -483,15 +275,39 @@ def __init__(
self.config = config
+ async def chat_completion_request(
+ self,
+ req: "ChatCompletionRequest",
+ timeline_uuid: str,
+ agent_name: str,
+ target_tag: str,
+ episode_uuid: str,
+ ):
+ from openai.types.chat.chat_completion import ChatCompletion
+ req_as_dict = req.model_dump()
+
+ # infer + process with context tracker
+ llm_output = await self.run_infer(
+ messages=req_as_dict["messages"],
+ tools=req_as_dict["tools"],
+ tool_choice="auto",
+ )
+ # convert to OpenAI ChatCompletion format
+ response: ChatCompletion = convert_llm_proxy_response_to_oai_response(llm_output)
+ # this is an important id assignment
+ response.id = timeline_uuid
+ assert isinstance(response, ChatCompletion)
+ return response
+
+
async def __call__(
self,
messages: List[dict],
tools: List = [],
tool_choice: str = "auto",
- structured_model=None,
**kwargs,
) -> ChatResponse:
- llm_output = await self.run_infer(messages, tools, tool_choice, structured_model, **kwargs)
+ llm_output = await self.run_infer(messages, tools, tool_choice, **kwargs)
return convert_llm_proxy_response_to_oai_response(llm_output)
@@ -500,9 +316,8 @@ async def run_infer(
messages: List[dict],
tools: List = [],
tool_choice: str = "auto", # always auto
- structured_model=None, # this is for AgentScope only
**kwargs,
- ):
+ ) -> Dict:
# generate timeline uuid
timeline_uuid = uuid.uuid4().hex
@@ -527,16 +342,10 @@ async def run_infer(
# else:
# otherwise, for abnormal output, can still proceed, but we do not track output anymore
- # run llm inference ✨
- if self.config.ajet.task_runner.llm_infer_submit_method == "sync":
- llm_output = await asyncio.to_thread(
- self.llm_inference_fn, converted_message, custom_sampling_params, tools
- )
- else:
- llm_output = await self.llm_inference_fn(converted_message, custom_sampling_params, tools)
-
+ # run llm inference ✨ (llm_chat_verl)
+ llm_output = await self.llm_inference_fn(converted_message, custom_sampling_params, tools)
- # begin context tracking
+ # context tracking
self.context_tracker.step_track(llm_output, context_safe, converted_message, tools, timeline_uuid=timeline_uuid)
return llm_output
@@ -554,7 +363,6 @@ def construct_overflow_response(self):
-
# ----------------------------------------------------------------------------------------------
# ------------------------ call async llm with context tracker (AgentScope) --------------------
# ----------------------------------------------------------------------------------------------
@@ -570,6 +378,6 @@ async def __call__(
**kwargs,
) -> AgentScopeChatResponse:
- llm_output = await self.run_infer(messages, tools, tool_choice, structured_model)
+ llm_output = await self.run_infer(messages, tools, tool_choice)
response = convert_llm_proxy_response_to_agentscope_response(llm_output, structured_model=structured_model)
return response
diff --git a/ajet/task_rollout/native_parallel_worker.py b/ajet/task_rollout/native_parallel_worker.py
index 79f0bdcd..8db4058c 100644
--- a/ajet/task_rollout/native_parallel_worker.py
+++ b/ajet/task_rollout/native_parallel_worker.py
@@ -1,7 +1,9 @@
"""Parallel environment rollout orchestration utilities."""
import os
+import gc
import time
+import tracemalloc
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Dict, List, Literal
from urllib.parse import quote
@@ -9,6 +11,7 @@
import numpy as np
import torch
import threading
+from math import ceil
from loguru import logger
from tensordict import TensorDict
from torch.nn.utils.rnn import pad_sequence
@@ -89,6 +92,64 @@ def _write_swarm_rollout_dynamic_log(self, observation_window):
f.write(string_buffer)
return
+ def _check_memory_leak(self):
+ """Check for memory leaks by comparing memory snapshots."""
+ if not self._tracemalloc_started:
+ tracemalloc.start()
+ self._tracemalloc_started = True
+ logger.info("Memory tracking started (tracemalloc)")
+ self._memory_snapshot = tracemalloc.take_snapshot()
+ return
+
+ # Take a new snapshot
+ gc.collect() # Force garbage collection before snapshot
+ current_snapshot = tracemalloc.take_snapshot()
+
+ if self._memory_snapshot is not None:
+ # Compare snapshots
+ top_stats = current_snapshot.compare_to(self._memory_snapshot, 'lineno')
+
+ logger.info("=" * 80)
+ logger.info("Memory Leak Detection: Top 10 differences since last rollout_swarm call")
+ logger.info("=" * 80)
+
+ total_size_diff = 0
+ for stat in top_stats[:10]:
+ total_size_diff += stat.size_diff
+ logger.info(f"{stat}")
+
+ # Convert to MB
+ total_size_diff_mb = total_size_diff / 1024 / 1024
+ logger.info(f"\nTotal memory difference: {total_size_diff_mb:.2f} MB")
+
+ # Show top current memory consumers
+ logger.info("\n" + "=" * 80)
+ logger.info("Top 10 current memory allocations")
+ logger.info("=" * 80)
+ top_current = current_snapshot.statistics('lineno')
+ for stat in top_current[:10]:
+ logger.info(f"{stat}")
+
+ logger.info("=" * 80)
+
+ # Enhanced leak detection: show traceback for largest leak
+ if total_size_diff_mb > 10: # Only if leak is significant (>10MB)
+ logger.warning(f"SIGNIFICANT MEMORY LEAK DETECTED: {total_size_diff_mb:.2f} MB")
+ logger.info("\n" + "=" * 80)
+ logger.info("Detailed traceback for top 3 memory leaks:")
+ logger.info("=" * 80)
+ for i, stat in enumerate(top_stats[:3], 1):
+ if stat.size_diff > 0:
+ logger.info(f"\n--- Leak #{i}: +{stat.size_diff / 1024 / 1024:.2f} MB, {stat.count_diff} objects ---")
+ logger.info(f"File: {stat.traceback.format()[0] if stat.traceback else 'Unknown'}")
+ if stat.traceback and len(stat.traceback) > 1:
+ logger.info("Full traceback:")
+ for line in stat.traceback.format():
+ logger.info(f" {line}")
+ logger.info("=" * 80)
+
+ # Update snapshot for next comparison
+ self._memory_snapshot = current_snapshot
def rollout_static(
self,
@@ -173,10 +234,13 @@ def rollout_swarm( # noqa: C901
each thread re-spawn after complete, until reaching conditions to stop.
"""
+ # # Memory leak detection: compare with previous snapshot
+ # self._check_memory_leak()
+
tracker_array: List[SingleAgentContextTracker] = []
rollout_n = self.rollout_n
n_batch_task = len(tasks)
- n_task = min(len(tasks), self.max_parallel // rollout_n)
+ n_task = min(len(tasks), ceil(self.max_parallel / rollout_n))
assert n_task > 0, f"n_task is not valid, n_task = min(len(tasks), self.max_parallel // rollout_n) = {n_task}"
self.current_token_count_time = time.time()
@@ -232,7 +296,7 @@ def enough_finished_task_stop_condition(completed_task_id_map_ct) -> bool:
f"Too many cached episodes [{total_completed_episodes}] has exceeded the max cached episodes [{self.config.ajet.swarm_mode_sample_collection_max_cached_episodes}] "
f"Deleting cached episodes to release memory..."
)
- completed_task_id_map_ct = {}
+ completed_task_id_map_ct.clear()
return (total_completed_tasks >= n_batch_task)
def enough_non_dummy_task_stop_condition(completed_task_id_map_ct) -> bool:
@@ -257,7 +321,7 @@ def enough_non_dummy_task_stop_condition(completed_task_id_map_ct) -> bool:
f"Too many cached episodes [{total_completed_episodes}] has exceeded the max cached episodes [{self.config.ajet.swarm_mode_sample_collection_max_cached_episodes}] "
f"Deleting cached episodes to release memory..."
)
- completed_task_id_map_ct = {}
+ completed_task_id_map_ct.clear()
return (total_completed_non_dummy_tasks >= n_batch_task)
# select stop condition function based on config
@@ -370,22 +434,23 @@ def update_rollout_result_array_preview(observation_window, completed_task_id_ma
self._write_swarm_rollout_dynamic_log(observation_window)
meet_stop_condition_after_new_results = stop_condition(completed_task_id_map_ct)
if meet_stop_condition_after_new_results:
- print("Sending soft stop signal to all threads...")
+ logger.info("Sending soft stop signal to all threads...")
stop_all_threads_soft()
break
# wait for all threads to complete
- print('Finalizing all threads...')
+ logger.info('Finalizing all threads...')
executor.shutdown(wait=True)
# stop all threads hard
- print("Sending hard stop signal to all threads...")
+ logger.info("Sending hard stop signal to all threads...")
stop_all_threads_hard()
# build tracker_array
- print('Collecting results...')
+ logger.info('Collecting results...')
for ct_list in completed_task_id_map_ct.values():
tracker_array.extend(ct_list)
+ completed_task_id_map_ct.clear()
# TODO: support multi-step reward
task_success_rate = np.mean(
@@ -402,6 +467,20 @@ def update_rollout_result_array_preview(observation_window, completed_task_id_ma
update_rollout_result_array_preview(observation_window, completed_task_id_map_ct)
self._write_swarm_rollout_dynamic_log(observation_window)
+ # Explicit cleanup to prevent memory leaks
+ logger.debug("Performing explicit cleanup...")
+ # Clear futures list
+ futures.clear()
+ # Clear observation window
+ observation_window.clear()
+ # Delete local function references to break circular refs
+ del stop_condition_callback
+ del stop_condition
+ del update_rollout_result_array_preview
+ del count_tasks
+ # Force garbage collection
+ gc.collect()
+
return tracker_array
diff --git a/ajet/task_rollout/single_worker.py b/ajet/task_rollout/single_worker.py
index 4791a47f..baca7e34 100644
--- a/ajet/task_rollout/single_worker.py
+++ b/ajet/task_rollout/single_worker.py
@@ -72,6 +72,10 @@ def __init__(
max_llm_retries=max_llm_retries,
)
+ # Memory leak tracking
+ self._memory_snapshot = None
+ self._tracemalloc_started = False
+
@retry_with_backoff(max_retry_attr="max_llm_retries")
def rollout_env_worker(
self,
@@ -90,14 +94,9 @@ def rollout_env_worker(
"""
sampling_params = get_sample_params(mode, self.config)
- if self.config.ajet.task_runner.llm_infer_submit_method == "sync":
- llm_inference_fn = self.async_llm_bridge.get_llm_inference_fn_sync(
- sampling_params=sampling_params
- )
- else:
- llm_inference_fn = self.async_llm_bridge.get_llm_inference_fn_async(
- sampling_params=sampling_params
- )
+ llm_inference_fn = self.async_llm_bridge.get_llm_inference_fn_async(
+ sampling_params=sampling_params
+ )
episode_uuid = uuid.uuid4().hex
workflow_task = WorkflowTask(
diff --git a/ajet/tuner.py b/ajet/tuner.py
index b780d44a..45a54425 100644
--- a/ajet/tuner.py
+++ b/ajet/tuner.py
@@ -23,7 +23,7 @@ def __init__(
self.context_tracker = context_tracker
self.llm_inference_fn = llm_inference_fn
self.target2proxy_registry: dict[str, dict[str,TunerTypeUnion]] = {}
- self.enable_interchange_server = config.ajet.enable_experimental_interchange_server
+ self.enable_interchange_server = config.ajet.enable_interchange_server
if self.enable_interchange_server:
self.proxy_client_started = False
@@ -102,10 +102,10 @@ def as_oai_baseurl_apikey(
```
"""
- assert self.enable_interchange_server, "Please enable `ajet.enable_experimental_interchange_server` in yaml config to use `as_oai_baseurl_apikey` feature."
+ assert self.enable_interchange_server, "Please enable `ajet.enable_interchange_server` in yaml config to use `as_oai_baseurl_apikey` feature."
if self.proxy_client_started is False:
self.proxy_client_started = True
- self._enable_experimental_interchange_server(self.llm_inference_fn)
+ self._enable_interchange_server(self.llm_inference_fn)
baseurl_apikey_model = OpenaiClientBaseUrlTuner(
config=self.config,
context_tracker=self.context_tracker,
@@ -168,7 +168,7 @@ def get_context_tracker(self) -> MultiAgentContextTracker:
return self.context_tracker
- def _enable_experimental_interchange_server(self, llm_inference_fn):
+ def _enable_interchange_server(self, llm_inference_fn):
# experimental reverse proxy start
if self.enable_interchange_server:
from ajet.tuner_lib.experimental.as_oai_model_client import InterchangeClient
diff --git a/ajet/tuner_lib/as_oai_baseurl_apikey.py b/ajet/tuner_lib/as_oai_baseurl_apikey.py
index cc40c61b..0931e980 100644
--- a/ajet/tuner_lib/as_oai_baseurl_apikey.py
+++ b/ajet/tuner_lib/as_oai_baseurl_apikey.py
@@ -12,11 +12,13 @@ class MockAsyncCompletions(AsyncCompletions):
async def create(self, *args, **kwargs) -> Any: # type: ignore
return await self._client.create(*args, **kwargs) # type: ignore
+
class MockAsyncChat(AsyncChat):
@property
def completions(self) -> MockAsyncCompletions: # type: ignore
return MockAsyncCompletions(self._client)
+
class OpenaiBaseUrlAndApiKey(BaseModel):
""" At this layer, we will determine which model to use:
- training model
@@ -29,13 +31,17 @@ class OpenaiBaseUrlAndApiKey(BaseModel):
episode_uuid: str = Field(default="episode_id", description="reserved field.")
def as_agentscope_model(self, *args, **kwargs):
- from agentscope.model import DashScopeChatModel
- return DashScopeChatModel(model_name="AgentJet-Model", api_key=self.api_key, base_http_api_url=self.base_url)
+ from agentscope.model import OpenAIChatModel
+ return OpenAIChatModel(
+ model_name="AgentJet-Model", api_key=self.api_key,
+ client_args={"base_url": self.base_url}
+ )
def as_raw_openai_sdk_client(self, *args, **kwargs):
from openai import AsyncOpenAI
return AsyncOpenAI(api_key=self.api_key, base_url=self.base_url)
+
class OpenaiClientBaseUrlTuner(BaseModel):
""" At this layer, we will determine which model to use:
- training model
diff --git a/ajet/tuner_lib/experimental/as_oai_model_client.py b/ajet/tuner_lib/experimental/as_oai_model_client.py
index 8288b609..aaecde5c 100644
--- a/ajet/tuner_lib/experimental/as_oai_model_client.py
+++ b/ajet/tuner_lib/experimental/as_oai_model_client.py
@@ -5,19 +5,17 @@
import os
import time
import zmq
-import base64
import json
from loguru import logger
from typing import TYPE_CHECKING
-from openai.types.chat.chat_completion import ChatCompletion
from ajet.tuner_lib.experimental.as_oai_model_server import InterchangeCompletionRequest
from ajet.utils.thread_executors import SharedInferenceTrackerThreadExecutor, SharedInterchangeThreadExecutor
from ajet.tuner_lib.experimental.interchange_utils import get_zmq_socket
-from ajet.tuner_lib.experimental.interchange_utils import DEBUG, API_KEY_PREFIX
+from ajet.tuner_lib.experimental.interchange_utils import DEBUG
if TYPE_CHECKING:
- from vllm.entrypoints.openai.protocol import ChatCompletionRequest
+ pass
context = zmq.Context()
atexit.register(context.term)
@@ -31,6 +29,7 @@ class InterchangeClient:
"""
def __init__(self, episode_uuid: str, context_tracker: "MultiAgentContextTracker", llm_inference_fn, config):
+ from ajet.task_rollout.async_llm_bridge import OpenaiLlmProxyWithTracker
self.episode_uuid = episode_uuid
self.context_tracker = context_tracker
self.llm_inference_fn = llm_inference_fn
@@ -40,37 +39,12 @@ def __init__(self, episode_uuid: str, context_tracker: "MultiAgentContextTracker
self.ipc_path = ipc_path
self.interchange_method = config.ajet.interchange_server.interchange_method
self.max_inference_tracker_threads = config.ajet.interchange_server.max_inference_tracker_threads
-
- async def llm_infer(
- self,
- req: "ChatCompletionRequest",
- timeline_uuid: str,
- agent_name: str,
- target_tag: str,
- episode_uuid: str,
- ) -> ChatCompletion:
- from ajet.task_rollout.async_llm_bridge import OpenaiLlmProxyWithTracker
-
- req_as_dict = req.model_dump()
self.llm_proxy_with_tracker = OpenaiLlmProxyWithTracker(
context_tracker=self.context_tracker,
config=self.config,
llm_inference_fn=self.llm_inference_fn,
)
- # infer + process with context tracker
- response = await self.llm_proxy_with_tracker(
- messages=req_as_dict["messages"],
- tools=req_as_dict["tools"],
- tool_choice="auto",
- )
-
- # this is an important id assignment
- response.id = timeline_uuid
- assert isinstance(response, ChatCompletion)
- return response
-
-
@property
def should_soft_terminate(self) -> bool:
if self._should_terminate:
@@ -98,15 +72,15 @@ def begin_service(self):
if DEBUG: logger.info(f"[client] {self.episode_uuid} | Starting InterchangeClient service loop...")
self.socket = context.socket(zmq.REP)
self.socket.bind(f"{self.episode_contect_address}")
- self.socket.setsockopt(zmq.RCVTIMEO, 1*1000) # 3 second timeout for REP
+ self.socket.setsockopt(zmq.RCVTIMEO, 1*1000) # 1 second timeout for REP
self.executor = SharedInterchangeThreadExecutor(self.max_inference_tracker_threads).get_shared_executor()
if DEBUG: logger.info(f"[client] {self.episode_uuid} | Submitting _begin_service_threading to executor...")
future = self.executor.submit(self._begin_service_threading)
# wait till service begin running
- time.sleep(0.5)
wait_time = 1
+ time.sleep(wait_time)
while future._state == 'PENDING':
if self.should_soft_terminate or self.should_hard_terminate:
future.cancel()
@@ -130,12 +104,15 @@ def _begin_service_threading(self):
try:
while not self.should_hard_terminate:
- # listen for next request from remote
try:
- # if DEBUG: logger.info(f"[client] {self.episode_uuid} | socket.recv_string() has begun (should_terminate {self.should_terminate})")
+
+                    # handshake counterpart:
+                    #   peer file : ajet/tuner_lib/experimental/as_oai_model_server.py
+                    #   peer call : socket.send_string(int_req.model_dump_json())
+                    #   payload   : InterchangeCompletionRequest object in JSON string format
message = self.socket.recv_string()
+
ever_receive_anything = True
- # if DEBUG: logger.info(f"[client] {self.episode_uuid} | socket.recv_string() is done")
except zmq.Again as e:
if self.should_hard_terminate:
# abort_episode()
@@ -154,27 +131,44 @@ def _begin_service_threading(self):
# begin to run the llm request, monitored by context tracker
# we re-use previously created thread for best performance
if DEBUG: logger.info(f"[client] {self.episode_uuid} | before asyncio run self.llm_infer")
+
+ # Check if there's a running event loop
try:
loop = asyncio.get_running_loop()
- except:
+ created_new_loop = False
+ except RuntimeError:
+ # No running loop, create a new one
loop = asyncio.new_event_loop()
- context_tracker_executor = SharedInferenceTrackerThreadExecutor(self.max_inference_tracker_threads).get_shared_executor()
- future = loop.run_in_executor(
- context_tracker_executor,
- asyncio.run,
- self.llm_infer(
- req=parsed_msg.completion_request,
- timeline_uuid=parsed_msg.timeline_uuid,
- agent_name=parsed_msg.agent_name,
- target_tag=parsed_msg.target_tag,
- episode_uuid=parsed_msg.episode_uuid,
+ asyncio.set_event_loop(loop)
+ created_new_loop = True
+
+ try:
+ context_tracker_executor = SharedInferenceTrackerThreadExecutor(self.max_inference_tracker_threads).get_shared_executor()
+ future = loop.run_in_executor(
+ context_tracker_executor,
+ asyncio.run,
+ self.llm_proxy_with_tracker.chat_completion_request(
+ req=parsed_msg.completion_request,
+ timeline_uuid=parsed_msg.timeline_uuid,
+ agent_name=parsed_msg.agent_name,
+ target_tag=parsed_msg.target_tag,
+ episode_uuid=parsed_msg.episode_uuid,
+ )
)
- )
- result = loop.run_until_complete(future).model_dump_json() # type: ignore
+ result = loop.run_until_complete(future).model_dump_json() # type: ignore
+ finally:
+ # Clean up the event loop if we created it
+ if created_new_loop:
+ loop.close()
+ asyncio.set_event_loop(None)
- # great, let's send back the result
if DEBUG: logger.info(f"[client] {self.episode_uuid} | before send_string (send llm call result)")
+
+                    # handshake counterpart:
+                    #   peer file : ajet/tuner_lib/experimental/as_oai_model_server.py
+                    #   peer call : result_str = socket.recv_string()
self.socket.send_string(result)
+
if DEBUG: logger.info(f"[client] {self.episode_uuid} | after send_string (send llm call result)")
except:
logger.exception(f"[client] {self.episode_uuid} | Exception occurred in service loop.")
diff --git a/ajet/tuner_lib/experimental/as_oai_model_server.py b/ajet/tuner_lib/experimental/as_oai_model_server.py
index 301483ff..367c8086 100644
--- a/ajet/tuner_lib/experimental/as_oai_model_server.py
+++ b/ajet/tuner_lib/experimental/as_oai_model_server.py
@@ -24,7 +24,9 @@
from loguru import logger
from pydantic import BaseModel
+from functools import lru_cache
from fastapi import FastAPI, Header, HTTPException, Request
+from fastapi.responses import StreamingResponse
from contextlib import asynccontextmanager
from multiprocessing import Manager, Process
from concurrent.futures import ThreadPoolExecutor
@@ -32,6 +34,9 @@
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
from openai.types.chat.chat_completion import ChatCompletion
+from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
+from openai.types.chat.chat_completion_chunk import Choice as ChunkChoice
+from openai.types.chat.chat_completion_chunk import ChoiceDelta, ChoiceDeltaToolCall, ChoiceDeltaToolCallFunction
from ajet.utils.networking import get_host_ip
from ajet.tuner_lib.experimental.interchange_utils import EpisodeStatus
@@ -44,6 +49,7 @@ class InterchangeCompletionRequest(BaseModel):
target_tag: str
episode_uuid: str
timeline_uuid: str
+ preserve_sampling_params: bool = False
class HealthCheckRequest(BaseModel):
agent_name: str
@@ -58,6 +64,9 @@ class HealthCheckRequest(BaseModel):
context = zmq.Context()
atexit.register(context.term)
+@lru_cache(maxsize=128)
+def ep_key(episode_uuid: str) -> str:
+ return f"episodes-{episode_uuid}"
def get_app(max_fastapi_threads: int = 512, enable_swarm_mode=False, shared_mem_dict=None, shared_mem_dict_lock=None) -> Tuple[FastAPI, Optional[Coroutine]]:
@@ -85,14 +94,33 @@ def _begin_handle_chat_completion(episode_address, int_req: InterchangeCompletio
socket.setsockopt(zmq.RCVTIMEO, 6*1000) # 6 second recv timeout
socket.connect(f"{episode_address}")
if DEBUG: logger.info(f"[server] episode_uuid: {episode_uuid} | connect done")
+
+ #
+ # : ajet/tuner_lib/experimental/as_oai_model_client.py
+ # : message = self.socket.recv_string()
socket.send_string(int_req.model_dump_json())
+
if DEBUG: logger.info(f"[server] episode_uuid: {episode_uuid} | send_string")
result_str = ""
for _ in range(50): # max 5 minutes wait
+
+ if enable_swarm_mode:
+ assert shared_mem_dict is not None
+ ep_stat = shared_mem_dict[ep_key(episode_uuid)]
+ episode_status = ep_stat.episode_status
+ if episode_status != "claimed":
+ raise HTTPException(status_code=404, detail="The episode is not claimed, cannot accept new requests.")
+
try:
if DEBUG: logger.info(f"[server] episode_uuid: {episode_uuid} | recv_string begin.")
+
+ # :
+ # : ajet/tuner_lib/experimental/as_oai_model_client.py
+ # : self.socket.send_string(result)
+ # : ChatCompletion object in JSON string format
result_str = socket.recv_string()
+
break
except zmq.Again as e:
# check whether server is still in rolling status
@@ -112,6 +140,89 @@ def _begin_handle_chat_completion(episode_address, int_req: InterchangeCompletio
return result_object
+ async def mock_as_stream_response(result: ChatCompletion):
+ """
+ Convert a non-streaming ChatCompletion result to streaming format.
+
+ Args:
+ result: ChatCompletion object to convert to streaming format
+
+ Yields:
+ Server-sent events formatted as streaming chat completion chunks
+ """
+ content = result.choices[0].message.content if result.choices else ""
+ role = result.choices[0].message.role if result.choices else "assistant"
+ # try:
+ # thinking = result.choices[0].message.reasoning_content
+ # except:
+ # thinking = None
+ tool_calls = result.choices[0].message.tool_calls if result.choices and result.choices[0].message.tool_calls else None
+ delta_tool_calls = [] # tool_calls: Optional[List[ChoiceDeltaToolCall]] = None
+        finish_reason = result.choices[0].finish_reason if result.choices else None
+ if tool_calls:
+ delta_tool_calls = [ChoiceDeltaToolCall(
+ index=index,
+ id=tc.id,
+ function=ChoiceDeltaToolCallFunction(
+ name = tc.function.name,
+ arguments = tc.function.arguments,
+ ),
+ type=tc.type
+ ) for index, tc in enumerate(tool_calls)]
+
+ # First chunk with role
+ first_chunk = ChatCompletionChunk(
+ id=result.id,
+ model=result.model,
+ created=result.created,
+ object="chat.completion.chunk",
+ choices=[
+ ChunkChoice(
+ index=0,
+ delta=ChoiceDelta(role=role, content=""),
+ finish_reason=None
+ )
+ ]
+ )
+ dat = f"data: {first_chunk.model_dump_json()}\n\n"
+ yield dat
+
+ # Content chunk
+ content_chunk = ChatCompletionChunk(
+ id=result.id,
+ model=result.model,
+ created=result.created,
+ object="chat.completion.chunk",
+ choices=[
+ ChunkChoice(
+ index=0,
+ delta=ChoiceDelta(role=role, content=content, tool_calls=delta_tool_calls),
+ finish_reason=None
+ )
+ ]
+ )
+ dat = f"data: {content_chunk.model_dump_json()}\n\n"
+ yield dat
+
+ # Final chunk with finish_reason
+ final_chunk = ChatCompletionChunk(
+ id=result.id,
+ model=result.model,
+ created=result.created,
+ object="chat.completion.chunk",
+ choices=[
+ ChunkChoice(
+ index=0,
+ delta=ChoiceDelta(),
+ finish_reason=finish_reason
+ )
+ ]
+ )
+ dat = f"data: {final_chunk.model_dump_json()}\n\n"
+ yield dat
+ yield "data: [DONE]\n\n"
+
+
@app.get("/health")
async def health():
return {"status": "ok"}
@@ -149,12 +260,13 @@ async def chat_completions(request: Request, authorization: str = Header(None)):
# Parse request body
body = await request.json()
new_req = ChatCompletionRequest.model_validate(body)
- if new_req.stream:
- return HTTPException(status_code=400, detail="Streaming responses not supported in current AgentJet version, please set `stream=false` for now.")
# Create timeline UUID
timeline_uuid = uuid.uuid4().hex
+ # if training, ignore all sampling parameters from request
+ preserve_sampling_params = False
+
# enable_swarm_mode
if enable_swarm_mode:
from ajet.tuner_lib.experimental.as_swarm_server import ep_key
@@ -174,6 +286,14 @@ async def chat_completions(request: Request, authorization: str = Header(None)):
es.latest_activity_timestamp = time.time()
es.llm_call_count += 1
shared_mem_dict[ep_key(episode_uuid)] = es
+ if es.episode_type == "eval":
+ preserve_sampling_params = True
+
+ # For streaming, we process as non-streaming but return in streaming format
+ original_stream = new_req.stream
+ if original_stream:
+ new_req.stream = False
+ new_req.stream_options = None
# Add to received queue
int_req = InterchangeCompletionRequest(
@@ -182,10 +302,16 @@ async def chat_completions(request: Request, authorization: str = Header(None)):
target_tag = target_tag,
episode_uuid = episode_uuid,
timeline_uuid = timeline_uuid,
+ preserve_sampling_params = preserve_sampling_params,
)
if DEBUG: logger.info(f"episode_uuid: {episode_uuid} | Received new chat completion request (outside thread)")
loop = asyncio.get_running_loop()
- return await loop.run_in_executor(request.app.state.executor, _begin_handle_chat_completion, episode_address, int_req, episode_uuid)
+ result = await loop.run_in_executor(request.app.state.executor, _begin_handle_chat_completion, episode_address, int_req, episode_uuid)
+
+ if original_stream:
+ return StreamingResponse(mock_as_stream_response(result), media_type="text/event-stream")
+
+ return result
if enable_swarm_mode:
@@ -314,6 +440,7 @@ def start_interchange_server(config, blocking=False, env={}) -> int:
# polling for server ready
start_time = time.time()
+ _httpx_client = httpx.Client(timeout=0.5)
while True:
if interchange_server and interchange_server.exitcode is not None:
logger.error(f"Interchange server subprocess failed to start. Return code: {interchange_server.exitcode}")
@@ -323,7 +450,7 @@ def start_interchange_server(config, blocking=False, env={}) -> int:
logger.error(msg)
raise RuntimeError(msg)
try:
- if httpx.get(health_url, timeout=0.5).status_code == 200:
+ if _httpx_client.get(health_url).status_code == 200:
break
except Exception:
# keep waiting
@@ -348,7 +475,7 @@ def start_interchange_server(config, blocking=False, env={}) -> int:
interchange_server.join()
except KeyboardInterrupt:
logger.info("Shutting down interchange server...")
- try: httpx.post(f"http://127.0.0.1:{port}/stop_engine", timeout=8).status_code
+ try: _httpx_client.post(f"http://127.0.0.1:{port}/stop_engine", timeout=8).status_code
except Exception: pass
if interchange_server:
diff --git a/ajet/tuner_lib/experimental/as_swarm_client.py b/ajet/tuner_lib/experimental/as_swarm_client.py
index bf7c57f3..4bf9d8e5 100644
--- a/ajet/tuner_lib/experimental/as_swarm_client.py
+++ b/ajet/tuner_lib/experimental/as_swarm_client.py
@@ -52,6 +52,8 @@ def raise_for_status_with_detail(resp):
raise RuntimeError(f"SwarmClient error {resp.status_code} with non-JSON response: {response_text}") from e
+class SwarmServerOfflineError(Exception): ...
+
class SwarmClient(object):
@@ -69,6 +71,8 @@ def __init__(self, server_url: str):
self._agent_jet_job = None
# throttle
self._recent_seen_tasks = []
+ # reuse httpx client to avoid creating SSL context repeatedly
+ self._http_client = httpx.Client(timeout=GENERAL_TIMEOUT)
def logger_info(self, message):
# logger with de-duplication within 1 second to prevent log flooding
@@ -113,8 +117,8 @@ def _check_throttle_policy(self, throttle_policy: SwarmThrottlePolicy, pool_info
if self._agent_jet_job:
# check and raise early errors when possible
- assert self._agent_jet_job.sample_collection_method == "rollout_until_finish_enough_tasks", \
- f"Current sample collection method ({self._agent_jet_job.sample_collection_method}) does not support throttle policy."
+ assert self._agent_jet_job.swarm_mode_sample_collection_method == "rollout_until_finish_enough_tasks", \
+ f"Current sample collection method ({self._agent_jet_job.swarm_mode_sample_collection_method}) does not support throttle policy."
# only_this_client_uuid = throttle_policy.throttle_method in ["Task_Ratio_Limit"]
only_this_client_uuid = True
@@ -193,7 +197,7 @@ def _should_throttle(self, throttle_policy: SwarmThrottlePolicy, pool_info: Curr
self._remember_seen_task(throttle_policy.current_task_id, throttle_policy.expected_batch_size, throttle_policy.expected_num_repeat)
return should_throttle
- def begin_episode(self, discard_episode_timeout=600, episode_type="train", throttle_policy: SwarmThrottlePolicy|None = None) -> Tuple[str, OpenaiBaseUrlAndApiKey]:
+ def begin_episode(self, discard_episode_timeout=240, episode_type="train", throttle_policy: SwarmThrottlePolicy|None = None) -> Tuple[str, OpenaiBaseUrlAndApiKey]:
"""
Block until an episode is claimed.
Argument:
@@ -208,9 +212,9 @@ def begin_episode(self, discard_episode_timeout=600, episode_type="train", throt
"""
return self._begin_episode_auto_retry(discard_episode_timeout, episode_type, throttle_policy)
- def _begin_episode_auto_retry(self, discard_episode_timeout=600, episode_type="train", throttle_policy: SwarmThrottlePolicy|None = None) -> Tuple[str, OpenaiBaseUrlAndApiKey]:
+ def _begin_episode_auto_retry(self, discard_episode_timeout=240, episode_type="train", throttle_policy: SwarmThrottlePolicy|None = None) -> Tuple[str, OpenaiBaseUrlAndApiKey]:
# max_episode_time: when an episode has **lasted** for more than X seconds, it will be terminated **locally** by client (call `end_episode` will be re-route to `abort_episode`)
- max_episode_time = 2*discard_episode_timeout
+ max_episode_time = 8*discard_episode_timeout
status, status_json = self.get_engine_status() # warm up connection and log the status
if status not in ["ENGINE.ROLLING"]:
@@ -250,10 +254,9 @@ def _begin_episode_auto_retry(self, discard_episode_timeout=600, episode_type="t
discard_episode_timeout=discard_episode_timeout,
throttle_policy=throttle_policy
)
- resp = httpx.post(
+ resp = self._http_client.post(
f"{self.server_url}/claim_episode",
- json=req_obj.model_dump(),
- timeout=GENERAL_TIMEOUT
+ json=req_obj.model_dump()
)
raise_for_status_with_detail(resp)
data = ClaimEpisodeResponse.model_validate(resp.json())
@@ -316,13 +319,13 @@ def end_episode(self, task:Task, episode_uuid: str, workflow_output: WorkflowOut
if episode_uuid in self.record_episode_expire_time:
remain_time = self.record_episode_expire_time.pop(episode_uuid, 0) - time.time()
if remain_time < 0:
- logger.warning(f"Episode {episode_uuid} has expired (expired {-remain_time} seconds ago). Please use a larger `discard_episode_timeout` and `max_episode_time` when `begin_episode`. Skipping end_episode.")
+ logger.warning(f"Episode {episode_uuid} has expired (expired {-remain_time} seconds ago). Please use a larger `discard_episode_timeout` when `begin_episode`. Skipping end_episode.")
# send abort signal to server to clean up episode
self.abort_episode(episode_uuid)
return
else:
# send abort signal to server to clean up episode
- logger.warning(f"Episode {episode_uuid} has expired (expired at least {CLEAN_RECORD_TIMEOUT} seconds ago). Please use a larger `discard_episode_timeout` and `max_episode_time` when `begin_episode`. Skipping end_episode.")
+ logger.warning(f"Episode {episode_uuid} has expired (expired at least {CLEAN_RECORD_TIMEOUT} seconds ago). Please use a larger `discard_episode_timeout` when `begin_episode`. Skipping end_episode.")
self.abort_episode(episode_uuid)
return
@@ -335,10 +338,9 @@ def end_episode(self, task:Task, episode_uuid: str, workflow_output: WorkflowOut
task_id=task_id
)
- resp = httpx.post(
+ resp = self._http_client.post(
f"{self.server_url}/end_episode",
- json=req_obj.model_dump(),
- timeout=GENERAL_TIMEOUT
+ json=req_obj.model_dump()
)
raise_for_status_with_detail(resp)
data = EndEpisodeResponse.model_validate(resp.json())
@@ -364,10 +366,9 @@ def abort_episode(self, episode_uuid: str):
task_id=""
)
- resp = httpx.post(
+ resp = self._http_client.post(
f"{self.server_url}/abort_episode",
- json=req_obj.model_dump(),
- timeout=GENERAL_TIMEOUT
+ json=req_obj.model_dump()
)
raise_for_status_with_detail(resp)
data = EndEpisodeResponse.model_validate(resp.json())
@@ -397,10 +398,9 @@ def sync_train_config(self, agent_jet_job: AgentJetJob):
req_obj = SyncTrainConfigRequest(yaml_as_string=yaml_str)
- resp = httpx.post(
+ resp = self._http_client.post(
f"{self.server_url}/sync_train_config",
- json=req_obj.model_dump(),
- timeout=GENERAL_TIMEOUT
+ json=req_obj.model_dump()
)
raise_for_status_with_detail(resp)
self.logger_info("Synced train config to Swarm server")
@@ -420,7 +420,7 @@ def start_engine(self):
raise RuntimeError(f"Cannot start engine when engine is NOT ENGINE.OFFLINE. (current status: {current_status})")
# Send start engine request
- resp = httpx.post(
+ resp = self._http_client.post(
f"{self.server_url}/start_engine",
json={},
timeout=600
@@ -437,7 +437,7 @@ def start_engine(self):
self._wait_until_status_change_to(desired_status="ENGINE.ROLLING")
logger.success("Training engine is now ROLLING and ready.")
- def _wait_until_status_change_to(self, desired_status="ENGINE.ROLLING", verbose=True):
+ def _wait_until_status_change_to(self, desired_status="ENGINE.ROLLING", verbose=True, timeout=1800):
"""
Poll engine status until it reaches desired_status.
Reports status every 5 seconds while waiting.
@@ -446,12 +446,20 @@ def _wait_until_status_change_to(self, desired_status="ENGINE.ROLLING", verbose=
self.logger_info(f"Polling engine status until {desired_status}...")
last_report_time = time.time()
init_poll_time = last_report_time
+ initial_status, _ = self.get_engine_status()
while True:
try:
current_status, _ = self.get_engine_status()
current_time = time.time()
+ # Check if timeout has been reached
+ if current_time - init_poll_time >= timeout:
+ raise TimeoutError(f"Timeout reached while waiting for engine status to change to {desired_status}")
+
+ if (initial_status == "ENGINE.OFFLINE") and (current_status == "ENGINE.OFFLINE") and (desired_status!="ENGINE.OFFLINE"):
+                    raise SwarmServerOfflineError(f"Engine has remained OFFLINE while waiting for {desired_status}. This may indicate the engine failed to start or crashed. Please check the swarm server logs for details.")
+
# Report status every 5 seconds
if current_time - last_report_time >= 30:
if verbose:
@@ -467,6 +475,9 @@ def _wait_until_status_change_to(self, desired_status="ENGINE.ROLLING", verbose=
# Wait a bit before next poll
time.sleep(5)
+ except SwarmServerOfflineError as e:
+ raise e
+
except Exception as e:
logger.error(f"Error polling engine status: {e}")
time.sleep(5)
@@ -474,7 +485,7 @@ def _wait_until_status_change_to(self, desired_status="ENGINE.ROLLING", verbose=
@cache_with_ttl(ttl=0.5)
def get_engine_status(self) -> Tuple[str, dict]:
try:
- resp = httpx.get(
+ resp = self._http_client.get(
f"{self.server_url}/get_engine_status",
timeout=10
)
@@ -499,7 +510,7 @@ def can_continue_episode(self, episode_uuid: str) -> bool:
client_uuid=self.client_uuid,
episode_uuid=episode_uuid
)
- resp = httpx.post(
+ resp = self._http_client.post(
f"{self.server_url}/can_continue_episode",
json=req_obj.model_dump(),
timeout=10
@@ -513,7 +524,7 @@ def can_continue_episode(self, episode_uuid: str) -> bool:
def get_episode_buffer(self) -> List[EpisodeStatus]:
try:
- resp = httpx.post(
+ resp = self._http_client.post(
f"{self.server_url}/get_episode_buffer",
json={},
timeout=10
@@ -572,7 +583,7 @@ def stop_engine(self):
self.logger_info("Engine is already OFFLINE. No action needed.")
return
- resp = httpx.post(
+ resp = self._http_client.post(
f"{self.server_url}/stop_engine",
json={},
timeout=600
@@ -592,7 +603,7 @@ def get_rollout_stat(self) -> CurrentBatchRolloutPoolInformation:
Returns statistics about completed episodes, tasks, and progress.
"""
try:
- resp = httpx.get(
+ resp = self._http_client.get(
f"{self.server_url}/get_current_batch_rollout_pool_information",
timeout=10
)
diff --git a/ajet/tuner_lib/experimental/as_swarm_server.py b/ajet/tuner_lib/experimental/as_swarm_server.py
index a5d156d9..aa82984a 100644
--- a/ajet/tuner_lib/experimental/as_swarm_server.py
+++ b/ajet/tuner_lib/experimental/as_swarm_server.py
@@ -335,11 +335,11 @@ async def start_engine():
config_dict = yaml_module.safe_load(yaml_str)
backbone = config_dict.get("ajet", {}).get("backbone", "verl")
DEFAULT_DIR = "saved_experiments"
- exp_dir_final = config_dict.get("ajet", {}).get("experiment_dir", DEFAULT_DIR)
- if exp_dir_final != DEFAULT_DIR:
- # remove last dir level if possible
- exp_dir_final = os.path.dirname(exp_dir_final)
-
+ experiment_dir = config_dict.get("ajet", {}).get("experiment_dir", DEFAULT_DIR)
+ if experiment_dir == "auto":
+ exp_base_dir = DEFAULT_DIR
+ else:
+ exp_base_dir = os.path.dirname(os.path.abspath(experiment_dir))
# Save YAML to temporary file
with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".yaml") as temp_file:
@@ -351,7 +351,6 @@ async def start_engine():
args = SimpleNamespace(
conf=main_yaml_fp,
backbone=backbone,
- exp_dir=exp_dir_final,
with_logview=False,
debug=False,
)
@@ -367,7 +366,12 @@ def override_param_callback(config):
return config
# Finalize experiment config
- main_yaml_fp, exe_exp_base, exp_name, exp_config = prepare_experiment_config(main_yaml_fp, exp_dir_final, backbone, override_param_callback)
+ main_yaml_fp, exe_exp_base, exp_name, exp_config = prepare_experiment_config(
+ yaml_path=main_yaml_fp,
+ exp_base_dir=exp_base_dir,
+ backbone=backbone,
+ override_param_callback=override_param_callback,
+ )
# Setup environment variables
env, exp_config = setup_environment_vars(args, exp_config, main_yaml_fp)
@@ -393,6 +397,7 @@ def override_param_callback(config):
main_yaml_fp,
env,
exp_config,
+ True, # is_swarm_server
),
)
p.daemon = True
@@ -707,7 +712,7 @@ async def get_episode_buffer():
@app.post("/update_current_batch_rollout_pool_information", response_model=BoolResponse)
async def update_current_batch_rollout_pool_information(req: CurrentBatchRolloutPoolInformation):
"""Update the current batch rollout pool information."""
- if VERBOSE:
+ if DEBUG:
logger.info(f"Running /update_current_batch_rollout_pool_information")
try:
with shared_mem_dict_lock:
diff --git a/ajet/tuner_lib/experimental/interchange_utils.py b/ajet/tuner_lib/experimental/interchange_utils.py
index c50dd93d..880b87b0 100644
--- a/ajet/tuner_lib/experimental/interchange_utils.py
+++ b/ajet/tuner_lib/experimental/interchange_utils.py
@@ -109,10 +109,16 @@ class UpdateEngineStatusRequest(BaseModel):
VERBOSE = True
+shared_http_client = httpx.Client(timeout=10.0)
+
def get_interchange_server_url(config):
port = os.getenv("AJET_DAT_INTERCHANGE_PORT")
- if config.ajet.interchange_server.interchange_server_port != 'auto':
- port = str(int(config.ajet.interchange_server.interchange_server_port))
+ if isinstance(config, dict):
+ interchange_server_port = config.get("ajet", {}).get("interchange_server", {}).get("interchange_server_port", "auto")
+ else:
+ interchange_server_port = config.ajet.interchange_server.interchange_server_port
+ if interchange_server_port != 'auto':
+ port = str(int(interchange_server_port))
assert port is not None, "AJET_DAT_INTERCHANGE_PORT env var must be set"
master_node_ip = os.getenv("MASTER_NODE_IP", "localhost")
base_url = f"http://{master_node_ip}:{port}"
@@ -123,7 +129,7 @@ def http_change_engine_status(config, new_status: str, new_status_detail: str|No
if new_status not in VALID_STATUSES:
raise ValueError(f"Invalid engine status: {new_status}")
- resp = httpx.post(
+ resp = shared_http_client.post(
f"{get_interchange_server_url(config)}/update_engine_status",
json={"engine_status": new_status, "engine_status_detail": new_status_detail, "global_step": global_step},
timeout=10
@@ -133,7 +139,7 @@ def http_change_engine_status(config, new_status: str, new_status_detail: str|No
def is_episode_claimed(config, episode_uuid: str, unregister_if_not_claimed: bool) -> bool:
- resp = httpx.post(
+ resp = shared_http_client.post(
f"{get_interchange_server_url(config)}/is_episode_claimed",
json={"episode_uuid": episode_uuid, "unregister_if_not_claimed": unregister_if_not_claimed},
timeout=5
@@ -164,7 +170,7 @@ def http_register_episode(config,
zmq_listen_result_addr=zmq_listen_result_addr,
)
# send http request to swarm server to register episode
- response = httpx.post(
+ response = shared_http_client.post(
f"{interchange_http_addr}/register_episode",
json=rer.model_dump(), # 或者 rer.model_dump() 如果使用 Pydantic v2
timeout=2
diff --git a/ajet/utils/config_utils.py b/ajet/utils/config_utils.py
index b02f1205..13c4d693 100644
--- a/ajet/utils/config_utils.py
+++ b/ajet/utils/config_utils.py
@@ -171,7 +171,7 @@ def config_safe_guard(config: dict, backbone: str) -> dict:
def read_ajet_hierarchical_config(
- yaml_fp, exp_name, backbone, write_to=None, exp_dir=DEFAULT_DIR, override_param_callback=None
+ yaml_fp, experiment_name=None, backbone=None, write_to=None, experiment_dir=None, override_param_callback=None
):
if yaml_fp is None:
config = {
@@ -193,9 +193,12 @@ def read_ajet_hierarchical_config(
else:
with open(yaml_fp, "r", encoding="utf-8") as file:
config = yaml.safe_load(file)
- config["ajet"]["experiment_name"] = exp_name
- config["ajet"]["experiment_dir"] = os.path.join(exp_dir, exp_name)
- config["ajet"]["backbone"] = backbone
+ if experiment_name is not None:
+ config["ajet"]["experiment_name"] = experiment_name
+ if (experiment_dir is not None):
+ config["ajet"]["experiment_dir"] = experiment_dir
+ if backbone is not None:
+ config["ajet"]["backbone"] = backbone
# remove extra config of verl for trinity
if backbone == "debug":
@@ -245,14 +248,14 @@ def expand_ajet_hierarchical_config(config, write_to=None):
return config_final
-def prepare_experiment_config(yaml_path, exp_dir, backbone, override_param_callback=None, storage=True):
+def prepare_experiment_config(yaml_path, exp_base_dir, backbone, override_param_callback=None, storage=True):
"""
Prepare experiment configuration by reading YAML, setting up backup directories,
and copying necessary files for the experiment.
Args:
yaml_path: Path to the YAML configuration file
- exp_dir: Directory where experiment artifacts and backups should be stored
+ exp_base_dir: Directory where experiment artifacts and backups should be stored
backbone: Backbone identifier that controls config munging
Returns:
@@ -281,8 +284,8 @@ def prepare_experiment_config(yaml_path, exp_dir, backbone, override_param_callb
else:
exp_name = exp_name.replace("|", "-")
- backup_dir = os.path.abspath(os.path.join(exp_dir, exp_name, "backup"))
- yaml_backup_dst = os.path.join(exp_dir, exp_name, "yaml_backup.yaml")
+ backup_dir = os.path.abspath(os.path.join(exp_base_dir, exp_name, "backup"))
+ yaml_backup_dst = os.path.join(exp_base_dir, exp_name, "yaml_backup.yaml")
yaml_backup_dst = os.path.abspath(yaml_backup_dst)
exe_exp_base = os.path.dirname(yaml_backup_dst)
@@ -323,12 +326,18 @@ def prepare_experiment_config(yaml_path, exp_dir, backbone, override_param_callb
shutil.copyfile(yaml_backup_src, yaml_backup_dst)
## 4. edit new yaml
+ experiment_dir = f"{exp_base_dir}/{exp_name}"
config = read_ajet_hierarchical_config(
- yaml_backup_dst, exp_name, backbone, write_to=yaml_backup_dst, exp_dir=exp_dir, override_param_callback=override_param_callback
+ yaml_backup_dst,
+ experiment_name=exp_name,
+ backbone=backbone,
+ write_to=yaml_backup_dst,
+ experiment_dir=experiment_dir,
+ override_param_callback=override_param_callback
)
config_final = expand_ajet_hierarchical_config(config, write_to=yaml_backup_dst)
if not storage:
- shutil.rmtree(os.path.join(exp_dir, exp_name))
+ shutil.rmtree(os.path.join(exp_base_dir, exp_name))
return yaml_backup_dst, exe_exp_base, exp_name, config_final
diff --git a/ajet/utils/core_env_vars.py b/ajet/utils/core_env_vars.py
index e48e1dda..9df18216 100644
--- a/ajet/utils/core_env_vars.py
+++ b/ajet/utils/core_env_vars.py
@@ -15,7 +15,7 @@ def get_runtime_env(config, is_trinity: bool = False) -> dict:
if config.ajet.trainer_common.nnodes == 1:
master_node_ip = "localhost"
else:
- if config.ajet.enable_experimental_interchange_server:
+ if config.ajet.enable_interchange_server:
if config.ajet.interchange_server.interchange_method == "ipc":
raise ValueError("IPC interchange method is not supported for multi-node setup. Please set `ajet.interchange_server.interchange_method: tcp` ")
diff --git a/ajet/utils/launch_utils.py b/ajet/utils/launch_utils.py
index 441fbc93..46200ad9 100644
--- a/ajet/utils/launch_utils.py
+++ b/ajet/utils/launch_utils.py
@@ -319,6 +319,7 @@ def execute_training_process(
exe_yaml_path,
env,
exp_config,
+ is_swarm_server=False,
):
"""
Execute the training process based on the specified backbone and configuration.
@@ -403,7 +404,13 @@ def execute_training_process(
subprocess.run(cmd, check=True, cwd=os.path.abspath("./"), env=env)
except subprocess.CalledProcessError as e:
logger.error(f"Error running subprocess: {e}")
+ if is_swarm_server:
+ from ajet.tuner_lib.experimental.interchange_utils import http_change_engine_status
+ http_change_engine_status(exp_config, "ENGINE.OFFLINE", global_step=0)
sys.exit(1)
except Exception as e:
logger.error(f"Unexpected error: {e}")
+ if is_swarm_server:
+ from ajet.tuner_lib.experimental.interchange_utils import http_change_engine_status
+ http_change_engine_status(exp_config, "ENGINE.OFFLINE", global_step=0)
sys.exit(1)
diff --git a/ajet/utils/swarm_overwatch.py b/ajet/utils/swarm_overwatch.py
index da234a8f..9f8b7717 100644
--- a/ajet/utils/swarm_overwatch.py
+++ b/ajet/utils/swarm_overwatch.py
@@ -23,13 +23,13 @@
class SwarmOverwatch:
"""Real-time monitoring interface for swarm rollout pool"""
- def __init__(self, server_url: str, refresh_interval: float = 1.0):
+ def __init__(self, server_url: str, refresh_interval: float = 2.0):
"""
Initialize the overwatch monitor
Args:
server_url: Base URL of the swarm server (e.g., http://localhost:10086)
- refresh_interval: Refresh interval in seconds (default: 1.0)
+ refresh_interval: Refresh interval in seconds (default: 2.0)
"""
self.server_url = server_url.rstrip("/")
self.refresh_interval = refresh_interval
@@ -37,11 +37,12 @@ def __init__(self, server_url: str, refresh_interval: float = 1.0):
self.last_update_time = None
self.error_count = 0
self.total_requests = 0
+ self._httpx_client = httpx.Client(timeout=5.0)
def fetch_pool_info(self) -> Optional[CurrentBatchRolloutPoolInformation]:
"""Fetch current batch rollout pool information from server"""
try:
- response = httpx.get(
+ response = self._httpx_client.get(
f"{self.server_url}/get_current_batch_rollout_pool_information",
timeout=5.0,
)
@@ -480,13 +481,13 @@ def run(self):
)
-def start_overwatch(server_url: str, refresh_interval: float = 1.0):
+def start_overwatch(server_url: str, refresh_interval: float = 2.0):
"""
Start the swarm overwatch monitoring interface
Args:
server_url: Base URL of the swarm server
- refresh_interval: Refresh interval in seconds (default: 1.0)
+ refresh_interval: Refresh interval in seconds (default: 2.0)
"""
overwatch = SwarmOverwatch(server_url, refresh_interval)
overwatch.run()
diff --git a/ajet/utils/thread_executors.py b/ajet/utils/thread_executors.py
index 8702e00c..87ea0744 100644
--- a/ajet/utils/thread_executors.py
+++ b/ajet/utils/thread_executors.py
@@ -45,12 +45,17 @@ def shutdown(self, wait=True):
class PeriodicDrainThreadPoolExecutor:
"""A ThreadPoolExecutor that bounds the number of pending tasks via a semaphore."""
- def __init__(self, workers=100, auto_retry=True):
+ def __init__(self, workers=100, max_parallel=None, auto_retry=True, block_first_run=False):
self._max_workers = workers
- self._executor = ThreadPoolExecutor(max_workers=workers)
+ if max_parallel is None:
+ self._max_parallel = workers
+ else:
+ self._max_parallel = max_parallel
+ self._executor = ThreadPoolExecutor(max_workers=self._max_parallel)
self._submitted_count = 0
self._auto_retry = auto_retry
self.current_futures = []
+ self._slow_first_run = block_first_run
def submit(self, fn, *args, **kwargs):
"""Submit a task, blocking if the pending queue is full."""
@@ -63,9 +68,15 @@ def retry_wrapper(fn, *args, **kwargs):
logger.exception(f"[run_episodes_until_all_complete] Error executing episode: {e}. Retrying...")
if self._auto_retry:
- return self._executor.submit(retry_wrapper, fn, *args, **kwargs)
+ future = self._executor.submit(retry_wrapper, fn, *args, **kwargs)
else:
- return self._executor.submit(fn, *args, **kwargs)
+ future = self._executor.submit(fn, *args, **kwargs)
+
+ if self._slow_first_run:
+ self._slow_first_run = False
+ future.result() # Wait for the first run to complete before allowing more tasks to be submitted
+
+ return future
def submit_with_periodic_drain(self, fn, *args, **kwargs):
"""Submit a task, draining all in-flight work every `drain_every_n_job` submissions."""
@@ -86,4 +97,4 @@ def submit_with_periodic_drain(self, fn, *args, **kwargs):
def shutdown(self, wait=True):
"""Shut down the underlying executor."""
- self._executor.shutdown(wait=wait)
\ No newline at end of file
+ self._executor.shutdown(wait=wait)
diff --git a/ajet/utils/tokenizer.py b/ajet/utils/tokenizer.py
index 94ab8007..64587381 100644
--- a/ajet/utils/tokenizer.py
+++ b/ajet/utils/tokenizer.py
@@ -1,5 +1,6 @@
import copy
import json
+import threading
from typing import Dict, List
@@ -19,6 +20,10 @@ def cleanup_messages(messages: List[Dict]) -> List[Dict]:
pass
return messages_copied
+# Cache storage
+_cache = {}
+_cache_lock = threading.Lock()
+
def ajet_apply_chat_template(
tokenizer,
@@ -28,16 +33,46 @@ def ajet_apply_chat_template(
tokenize: bool = True,
):
conversation = cleanup_messages(conversation)
+
+ # Create cache key by hashing all inputs
+ cache_key = (
+ id(tokenizer),
+ hash(json.dumps(conversation, sort_keys=True)),
+ hash(json.dumps(tools, sort_keys=True)) if tools else 0,
+ add_generation_prompt,
+ tokenize,
+ )
+
+ # Check cache with thread safety
+ with _cache_lock:
+ if cache_key in _cache:
+ return _cache[cache_key]
+
+ # Compute result (time consuming) - outside lock to avoid blocking other threads
if tools:
- return tokenizer.apply_chat_template(
+ result = tokenizer.apply_chat_template(
conversation,
tools,
add_generation_prompt=add_generation_prompt,
tokenize=tokenize,
)
else:
- return tokenizer.apply_chat_template(
+ result = tokenizer.apply_chat_template(
conversation,
tokenize=tokenize,
add_generation_prompt=add_generation_prompt,
)
+
+ # Store in cache with thread safety (implement LRU eviction if cache gets too large)
+ with _cache_lock:
+ if len(_cache) >= 1024:
+ # Remove oldest item (first inserted)
+ try:
+ _cache.pop(next(iter(_cache)))
+ except KeyError:
+                # Defensive only: the lock is held here, so this should not occur in practice
+ pass
+
+ _cache[cache_key] = result
+
+ return result
diff --git a/docs/en/ajet-swarm-docker.md b/docs/en/ajet-swarm-docker.md
index 38a3c653..15c054a7 100644
--- a/docs/en/ajet-swarm-docker.md
+++ b/docs/en/ajet-swarm-docker.md
@@ -24,6 +24,7 @@ docker run --rm -it \
-v ./swarmlog:/workspace/log \
-v ./swarmexp:/workspace/saved_experiments \
-p 10086:10086 \
+ -e SWANLAB_API_KEY=$SWANLAB_API_KEY \
--gpus=all \
--shm-size=32GB \
ghcr.io/modelscope/agentjet:main \
@@ -89,6 +90,7 @@ docker run --rm -it \
-v ./swarmlog:/workspace/log \
-v ./swarmexp:/workspace/saved_experiments \
-p 10086:10086 \
+ -e SWANLAB_API_KEY=$SWANLAB_API_KEY \
--gpus=all \
--shm-size=32GB \
ghcr.io/modelscope/agentjet:main \
diff --git a/docs/en/support_agentscope.md b/docs/en/support_agentscope.md
index e551e4d9..13d308a5 100644
--- a/docs/en/support_agentscope.md
+++ b/docs/en/support_agentscope.md
@@ -64,7 +64,7 @@ This article introduce the way to convert different types of ways to convert you
ajet:
...
- enable_experimental_interchange_server: True
+ enable_interchange_server: True
...
```
diff --git a/docs/en/support_http.md b/docs/en/support_http.md
index 0bf3ab3d..d7659b1b 100644
--- a/docs/en/support_http.md
+++ b/docs/en/support_http.md
@@ -89,7 +89,7 @@ in this AI era, you can always start from scratch and build your own "high-scrap
ajet:
...
- enable_experimental_interchange_server: True
+ enable_interchange_server: True
...
```
diff --git a/docs/en/support_langchain.md b/docs/en/support_langchain.md
index d1e12890..344163e2 100644
--- a/docs/en/support_langchain.md
+++ b/docs/en/support_langchain.md
@@ -80,7 +80,7 @@ This article introduce the way to convert different types of ways to convert you
ajet:
...
- enable_experimental_interchange_server: True
+ enable_interchange_server: True
...
```
diff --git a/docs/en/support_oaisdk.md b/docs/en/support_oaisdk.md
index b60b03e3..104a1c26 100644
--- a/docs/en/support_oaisdk.md
+++ b/docs/en/support_oaisdk.md
@@ -84,7 +84,7 @@ This article introduce the way to convert different types of ways to convert you
ajet:
...
- enable_experimental_interchange_server: True
+ enable_interchange_server: True
...
```
diff --git a/tests/bench/benchmark_appworld/benchmark_appworld.yaml b/tests/bench/benchmark_appworld/benchmark_appworld.yaml
index f83e91f0..3622ed1b 100644
--- a/tests/bench/benchmark_appworld/benchmark_appworld.yaml
+++ b/tests/bench/benchmark_appworld/benchmark_appworld.yaml
@@ -58,14 +58,14 @@ ajet:
execute_testing_lambda: "tests/bench/benchmark_appworld/benchmark_appworld.py->TestProbe" #
-# ------------------ 不需要修改 ------------------
+# ------------------ do not edit ------------------
hydra:
searchpath:
- file://ajet/default_config
- file://ajet/default_config/verl # verl only
- file://ajet/default_config/trinity # trinity only
-# ------------------ 不需要修改 ------------------
+# ------------------ do not edit ------------------
defaults:
- verl_default # verl inherit 1/1
- trinity_default # trinity inherit 1/1
diff --git a/tests/bench/benchmark_appworld/benchmark_appworld_2nodes.yaml b/tests/bench/benchmark_appworld/benchmark_appworld_2nodes.yaml
index 4ae12f17..f53ca63b 100644
--- a/tests/bench/benchmark_appworld/benchmark_appworld_2nodes.yaml
+++ b/tests/bench/benchmark_appworld/benchmark_appworld_2nodes.yaml
@@ -63,14 +63,14 @@ trinity:
sync_offset: 0
sync_method: nccl
-# ------------------ 不需要修改 ------------------
+# ------------------ do not edit ------------------
hydra:
searchpath:
- file://ajet/default_config
- file://ajet/default_config/verl # verl only
- file://ajet/default_config/trinity # trinity only
-# ------------------ 不需要修改 ------------------
+# ------------------ do not edit ------------------
defaults:
- verl_default # verl inherit 1/1
- trinity_default # trinity inherit 1/1
diff --git a/tests/bench/benchmark_appworld/benchmark_appworld_oai_sdk.yaml b/tests/bench/benchmark_appworld/benchmark_appworld_oai_sdk.yaml
index e3175d19..89d82afb 100644
--- a/tests/bench/benchmark_appworld/benchmark_appworld_oai_sdk.yaml
+++ b/tests/bench/benchmark_appworld/benchmark_appworld_oai_sdk.yaml
@@ -56,14 +56,14 @@ ajet:
execute_testing_lambda: "tests/bench/benchmark_appworld/benchmark_appworld.py->TestProbe" #
-# ------------------ 不需要修改 ------------------
+# ------------------ do not edit ------------------
hydra:
searchpath:
- file://ajet/default_config
- file://ajet/default_config/verl # verl only
- file://ajet/default_config/trinity # trinity only
-# ------------------ 不需要修改 ------------------
+# ------------------ do not edit ------------------
defaults:
- verl_default # verl inherit 1/1
- trinity_default # trinity inherit 1/1
diff --git a/tests/bench/benchmark_countdown/benchmark_countdown.yaml b/tests/bench/benchmark_countdown/benchmark_countdown.yaml
index fcd07f35..53cdd902 100644
--- a/tests/bench/benchmark_countdown/benchmark_countdown.yaml
+++ b/tests/bench/benchmark_countdown/benchmark_countdown.yaml
@@ -124,14 +124,14 @@ ajet:
execute_testing_lambda: "tests/bench/benchmark_countdown/benchmark_countdown.py->TestProbe" # FOR ROBOT TESTING PURPOSE ONLY. NOT FOR HUMAN
-# ------------------ 不需要修改 ------------------
+# ------------------ do not edit ------------------
hydra:
searchpath:
- file://ajet/default_config
- file://ajet/default_config/verl # verl only
- file://ajet/default_config/trinity # trinity only
-# ------------------ 不需要修改 ------------------
+# ------------------ do not edit ------------------
defaults:
- verl_default # verl inherit 1/1
- trinity_default # trinity inherit 1/1
diff --git a/tests/bench/benchmark_learn2ask/benchmark_learn2ask.yaml b/tests/bench/benchmark_learn2ask/benchmark_learn2ask.yaml
index dd3b6a18..b435a0ac 100644
--- a/tests/bench/benchmark_learn2ask/benchmark_learn2ask.yaml
+++ b/tests/bench/benchmark_learn2ask/benchmark_learn2ask.yaml
@@ -57,14 +57,14 @@ trinity:
sync_method: nccl
-# ------------------ 不需要修改 ------------------
+# ------------------ do not edit ------------------
hydra:
searchpath:
- file://ajet/default_config
- file://ajet/default_config/verl # verl only
- file://ajet/default_config/trinity # trinity only
-# ------------------ 不需要修改 ------------------
+# ------------------ do not edit ------------------
defaults:
- verl_default # verl inherit 1/1
- trinity_default # trinity inherit 1/1
diff --git a/tests/bench/benchmark_math/benchmark_math.yaml b/tests/bench/benchmark_math/benchmark_math.yaml
index 36648d5f..f0f8d896 100644
--- a/tests/bench/benchmark_math/benchmark_math.yaml
+++ b/tests/bench/benchmark_math/benchmark_math.yaml
@@ -62,14 +62,14 @@ trinity:
sync_method: nccl
-# ------------------ 不需要修改 ------------------
+# ------------------ do not edit ------------------
hydra:
searchpath:
- file://ajet/default_config
- file://ajet/default_config/verl # verl only
- file://ajet/default_config/trinity # trinity only
-# ------------------ 不需要修改 ------------------
+# ------------------ do not edit ------------------
defaults:
- verl_default # verl inherit
- trinity_default # trinity inherit
diff --git a/tests/bench/benchmark_math/benchmark_math_oai_sdk.yaml b/tests/bench/benchmark_math/benchmark_math_oai_sdk.yaml
index a3dadd1a..e7bf0aba 100644
--- a/tests/bench/benchmark_math/benchmark_math_oai_sdk.yaml
+++ b/tests/bench/benchmark_math/benchmark_math_oai_sdk.yaml
@@ -59,14 +59,14 @@ trinity:
sync_method: nccl
-# ------------------ 不需要修改 ------------------
+# ------------------ do not edit ------------------
hydra:
searchpath:
- file://ajet/default_config
- file://ajet/default_config/verl # verl only
- file://ajet/default_config/trinity # trinity only
-# ------------------ 不需要修改 ------------------
+# ------------------ do not edit ------------------
defaults:
- verl_default # verl inherit
- trinity_default # trinity inherit
diff --git a/tests/bench/benchmark_math/benchmark_math_raw_http.yaml b/tests/bench/benchmark_math/benchmark_math_raw_http.yaml
index 8a4fc433..88c9aa15 100644
--- a/tests/bench/benchmark_math/benchmark_math_raw_http.yaml
+++ b/tests/bench/benchmark_math/benchmark_math_raw_http.yaml
@@ -59,14 +59,14 @@ trinity:
sync_method: nccl
-# ------------------ 不需要修改 ------------------
+# ------------------ do not edit ------------------
hydra:
searchpath:
- file://ajet/default_config
- file://ajet/default_config/verl # verl only
- file://ajet/default_config/trinity # trinity only
-# ------------------ 不需要修改 ------------------
+# ------------------ do not edit ------------------
defaults:
- verl_default # verl inherit
- trinity_default # trinity inherit
diff --git a/tutorial/example_appworld/appworld.yaml b/tutorial/example_appworld/appworld.yaml
index 316c605b..3ccb91b7 100644
--- a/tutorial/example_appworld/appworld.yaml
+++ b/tutorial/example_appworld/appworld.yaml
@@ -54,14 +54,14 @@ ajet:
n_gpus_per_node: 8
-# ------------------ 不需要修改 ------------------
+# ------------------ do not edit ------------------
hydra:
searchpath:
- file://ajet/default_config
- file://ajet/default_config/verl # verl only
- file://ajet/default_config/trinity # trinity only
-# ------------------ 不需要修改 ------------------
+# ------------------ do not edit ------------------
defaults:
- verl_default # verl inherit 1/1
- trinity_default # trinity inherit 1/1
diff --git a/tutorial/example_appworld/appworld_oai_sdk.yaml b/tutorial/example_appworld/appworld_oai_sdk.yaml
index 056aac91..4b159b7f 100644
--- a/tutorial/example_appworld/appworld_oai_sdk.yaml
+++ b/tutorial/example_appworld/appworld_oai_sdk.yaml
@@ -53,14 +53,14 @@ ajet:
n_gpus_per_node: 8
-# ------------------ 不需要修改 ------------------
+# ------------------ do not edit ------------------
hydra:
searchpath:
- file://ajet/default_config
- file://ajet/default_config/verl # verl only
- file://ajet/default_config/trinity # trinity only
-# ------------------ 不需要修改 ------------------
+# ------------------ do not edit ------------------
defaults:
- verl_default # verl inherit 1/1
- trinity_default # trinity inherit 1/1
diff --git a/tutorial/example_countdown/countdown.yaml b/tutorial/example_countdown/countdown.yaml
index d5b161bf..6dcadf81 100644
--- a/tutorial/example_countdown/countdown.yaml
+++ b/tutorial/example_countdown/countdown.yaml
@@ -135,14 +135,14 @@ ajet:
execute_testing_lambda: "" # DO NOT EDIT, FOR ROBOT TESTING PURPOSE ONLY. NOT FOR HUMAN.
-# ------------------ 不需要修改 ------------------
+# ------------------ do not edit ------------------
hydra:
searchpath:
- file://ajet/default_config
- file://ajet/default_config/verl # verl only
- file://ajet/default_config/trinity # trinity only
-# ------------------ 不需要修改 ------------------
+# ------------------ do not edit ------------------
defaults:
- verl_default # verl inherit 1/1
- trinity_default # trinity inherit 1/1
diff --git a/tutorial/example_deep_finance/deep_finance.yaml b/tutorial/example_deep_finance/deep_finance.yaml
index e5de33da..fcb429c4 100644
--- a/tutorial/example_deep_finance/deep_finance.yaml
+++ b/tutorial/example_deep_finance/deep_finance.yaml
@@ -71,14 +71,14 @@ actor_rollout_ref:
rollout:
tensor_model_parallel_size: 8
gpu_memory_utilization: 0.8
-# ------------------ 不需要修改 ------------------
+# ------------------ do not edit ------------------
hydra:
searchpath:
- file://ajet/default_config
- file://ajet/default_config/verl # verl only
- file://ajet/default_config/trinity # trinity only
-# ------------------ 不需要修改 ------------------
+# ------------------ do not edit ------------------
defaults:
- verl_default # verl inherit 1/1
- trinity_default # trinity inherit 1/1
diff --git a/tutorial/example_deep_finance/yaml_template/deep_finance_template.yaml b/tutorial/example_deep_finance/yaml_template/deep_finance_template.yaml
index 38aa82ed..d9f559b9 100644
--- a/tutorial/example_deep_finance/yaml_template/deep_finance_template.yaml
+++ b/tutorial/example_deep_finance/yaml_template/deep_finance_template.yaml
@@ -75,14 +75,14 @@ actor_rollout_ref:
rollout:
tensor_model_parallel_size: 8
gpu_memory_utilization: 0.8
-# ------------------ 不需要修改 ------------------
+# ------------------ do not edit ------------------
hydra:
searchpath:
- file://ajet/default_config
- file://ajet/default_config/verl # verl only
- file://ajet/default_config/trinity # trinity only
-# ------------------ 不需要修改 ------------------
+# ------------------ do not edit ------------------
defaults:
- verl_default # verl inherit 1/1
- trinity_default # trinity inherit 1/1
diff --git a/tutorial/example_deep_finance/yaml_template/deep_finance_template_maxlen.yaml b/tutorial/example_deep_finance/yaml_template/deep_finance_template_maxlen.yaml
index 0ddd541c..02fa6f73 100644
--- a/tutorial/example_deep_finance/yaml_template/deep_finance_template_maxlen.yaml
+++ b/tutorial/example_deep_finance/yaml_template/deep_finance_template_maxlen.yaml
@@ -76,14 +76,14 @@ actor_rollout_ref:
rollout:
tensor_model_parallel_size: 8
gpu_memory_utilization: 0.8
-# ------------------ 不需要修改 ------------------
+# ------------------ do not edit ------------------
hydra:
searchpath:
- file://ajet/default_config
- file://ajet/default_config/verl # verl only
- file://ajet/default_config/trinity # trinity only
-# ------------------ 不需要修改 ------------------
+# ------------------ do not edit ------------------
defaults:
- verl_default # verl inherit 1/1
- trinity_default # trinity inherit 1/1
diff --git a/tutorial/example_deep_finance/yaml_template/infer.yaml b/tutorial/example_deep_finance/yaml_template/infer.yaml
index 5e9d400e..7dcf60ff 100644
--- a/tutorial/example_deep_finance/yaml_template/infer.yaml
+++ b/tutorial/example_deep_finance/yaml_template/infer.yaml
@@ -76,14 +76,14 @@ actor_rollout_ref:
rollout:
tensor_model_parallel_size: 8
gpu_memory_utilization: 0.8
-# ------------------ 不需要修改 ------------------
+# ------------------ do not edit ------------------
hydra:
searchpath:
- file://ajet/default_config
- file://ajet/default_config/verl # verl only
- file://ajet/default_config/trinity # trinity only
-# ------------------ 不需要修改 ------------------
+# ------------------ do not edit ------------------
defaults:
- verl_default # verl inherit 1/1
- trinity_default # trinity inherit 1/1
diff --git a/tutorial/example_feedback_tracing/example_feedback_tracing.yaml b/tutorial/example_feedback_tracing/example_feedback_tracing.yaml
index 1cb01333..894aca7c 100644
--- a/tutorial/example_feedback_tracing/example_feedback_tracing.yaml
+++ b/tutorial/example_feedback_tracing/example_feedback_tracing.yaml
@@ -62,14 +62,14 @@ trainer:
- console
-# ------------------ 不需要修改 ------------------
+# ------------------ do not edit ------------------
hydra:
searchpath:
- file://ajet/default_config
- file://ajet/default_config/verl # verl only
- file://ajet/default_config/trinity # trinity only
-# ------------------ 不需要修改 ------------------
+# ------------------ do not edit ------------------
defaults:
- verl_default # verl inherit 1/1
- trinity_default # trinity inherit 1/1
diff --git a/tutorial/example_learn2ask/learn2ask.yaml b/tutorial/example_learn2ask/learn2ask.yaml
index acacbce2..211c1e8a 100644
--- a/tutorial/example_learn2ask/learn2ask.yaml
+++ b/tutorial/example_learn2ask/learn2ask.yaml
@@ -53,14 +53,14 @@ trinity:
sync_method: nccl
-# ------------------ 不需要修改 ------------------
+# ------------------ do not edit ------------------
hydra:
searchpath:
- file://ajet/default_config
- file://ajet/default_config/verl # verl only
- file://ajet/default_config/trinity # trinity only
-# ------------------ 不需要修改 ------------------
+# ------------------ do not edit ------------------
defaults:
- verl_default # verl inherit 1/1
- trinity_default # trinity inherit 1/1
diff --git a/tutorial/example_math_swarm/math.py b/tutorial/example_math_swarm/math.py
index 041e544f..c1351a9d 100644
--- a/tutorial/example_math_swarm/math.py
+++ b/tutorial/example_math_swarm/math.py
@@ -28,38 +28,38 @@ def main():
reader_type = "huggingface_dat_repo",
reader_config = AjetTaskReader(
huggingface_dat_repo = HuggingfaceDatRepo(
- dataset_path = "/root/agentjet/benchmark_datasets/dataset/gsm8k/socratic",
+ dataset_path = '/mnt/data_cpfs/model_cache/modelscope/dataset/openai/gsm8k/main',
+ # dataset_path = "/root/agentjet/benchmark_datasets/dataset/gsm8k/socratic",
# dataset_path = "openai/gsm8k",
# dataset_name = "main",
)
)
)
- # # Hand shake with remote swarm server
+ # Handshake with remote swarm server
swarm_worker = SwarmClient(AJET_SWARM_URL)
+ ajet_job = AgentJetJob(
+ experiment_name="math_gsm8k_grpo",
+ algorithm="grpo",
+ n_gpu=REMOTE_ALLOCATE_GPU_PER_NODE,
+ model=REMOTE_MODEL_PATH,
+ batch_size=REMOTE_BATCH_SIZE,
+ num_repeat=GRPO_N,
+ )
+ print(ajet_job.config.to_dict())
swarm_worker.auto_sync_train_config_and_start_engine(
- AgentJetJob(
- experiment_name="math_gsm8k_grpo",
- algorithm="grpo",
- n_gpu=REMOTE_ALLOCATE_GPU_PER_NODE,
- model=REMOTE_MODEL_PATH,
- batch_size=REMOTE_BATCH_SIZE,
- num_repeat=GRPO_N,
- ),
+ ajet_job,
force_restart=True,
)
def rollout(task):
- try:
- # begin episode
- episode_uuid, api_baseurl_key = swarm_worker.begin_episode(discard_episode_timeout=60)
- # execute agent ( base_url = api_baseurl_key.base_url, api_key = api_baseurl_key.api_key )
- workflow_output = execute_agent(task, api_baseurl_key) # reward is in `workflow_output`
- # report output back to swarm remote
- swarm_worker.end_episode(task, episode_uuid, workflow_output)
- return
- except:
- pass
+ # begin episode
+ episode_uuid, api_baseurl_key = swarm_worker.begin_episode(discard_episode_timeout=60)
+ # execute agent ( base_url = api_baseurl_key.base_url, api_key = api_baseurl_key.api_key )
+ workflow_output = execute_agent(task, api_baseurl_key) # reward is in `workflow_output`
+ # report output back to swarm remote
+ swarm_worker.end_episode(task, episode_uuid, workflow_output)
+ return
executor = PeriodicDrainThreadPoolExecutor(workers=GRPO_N * REMOTE_BATCH_SIZE, auto_retry=True)
for _ in range(NUM_EPOCH):
diff --git a/tutorial/example_rubrics_judge/r_judge.yaml b/tutorial/example_rubrics_judge/r_judge.yaml
index 5834da5f..94e0e940 100644
--- a/tutorial/example_rubrics_judge/r_judge.yaml
+++ b/tutorial/example_rubrics_judge/r_judge.yaml
@@ -57,14 +57,14 @@ ajet:
-# ------------------ 不需要修改 ------------------
+# ------------------ do not edit ------------------
hydra:
searchpath:
- file://ajet/default_config
- file://ajet/default_config/verl # verl only
- file://ajet/default_config/trinity # trinity only
-# ------------------ 不需要修改 ------------------
+# ------------------ do not edit ------------------
defaults:
- verl_default # verl inherit 1/1
- trinity_default # trinity inherit 1/1
diff --git a/tutorial/example_werewolves/game.py b/tutorial/example_werewolves/game.py
index 10246c32..8eca099e 100644
--- a/tutorial/example_werewolves/game.py
+++ b/tutorial/example_werewolves/game.py
@@ -44,7 +44,7 @@ async def hunter_stage(
global moderator
msg_hunter = await hunter_agent(
await moderator(Prompts.to_hunter.format(name=hunter_agent.name)),
- structured_model=get_hunter_model(players.current_alive),
+ structured_model=get_hunter_model(players.all_players),
)
if msg_hunter.metadata.get("shoot"):
return msg_hunter.metadata.get("name", None)
@@ -134,7 +134,7 @@ async def werewolves_game(agents: list[ReActAgent], roles) -> bool: # noqa: C90
msgs_vote = await fanout_pipeline(
players.werewolves,
msg=await moderator(content=Prompts.to_wolves_vote),
- structured_model=get_vote_model(players.current_alive),
+ structured_model=get_vote_model(players.all_players),
enable_gather=False,
)
killed_player, votes = majority_vote(
@@ -187,7 +187,7 @@ async def werewolves_game(agents: list[ReActAgent], roles) -> bool: # noqa: C90
),
),
structured_model=get_poison_model(
- players.current_alive,
+ players.all_players,
),
)
if msg_witch_poison.metadata.get("poison"):
@@ -206,7 +206,7 @@ async def werewolves_game(agents: list[ReActAgent], roles) -> bool: # noqa: C90
names_to_str(players.current_alive),
),
),
- structured_model=get_seer_model(players.current_alive),
+ structured_model=get_seer_model(players.all_players),
)
if msg_seer.metadata.get("name"):
player = msg_seer.metadata["name"]
@@ -282,7 +282,7 @@ async def werewolves_game(agents: list[ReActAgent], roles) -> bool: # noqa: C90
names_to_str(players.current_alive),
),
),
- structured_model=get_vote_model(players.current_alive),
+ structured_model=get_vote_model(players.all_players),
enable_gather=False,
)
voted_player, votes = majority_vote(
diff --git a/tutorial/example_werewolves/start.py b/tutorial/example_werewolves/start.py
index 879b6101..0e0ab4b0 100644
--- a/tutorial/example_werewolves/start.py
+++ b/tutorial/example_werewolves/start.py
@@ -4,6 +4,7 @@
"""The main entry point for the werewolf game."""
from typing import List
+import agentscope
import numpy as np
import dotenv
dotenv.load_dotenv()
@@ -12,7 +13,7 @@
from agentscope.agent import ReActAgent
from agentscope.formatter import DashScopeMultiAgentFormatter, OpenAIMultiAgentFormatter
-from agentscope.model import OpenAIChatModel
+from agentscope.model import OpenAIChatModel, DashScopeChatModel
from loguru import logger
from pydantic import Field
@@ -81,9 +82,12 @@ def get_official_agent_prompt(name) -> str:
class ExampleWerewolves(Workflow):
trainable_targets: List[str] | None = Field(default=["werewolf"], description="List of agents to be fine-tuned.")
+ big_external_opponent_llm_url: str = Field(default="http://22.17.52.4:2888/v1", description="The URL of the big external opponent LLM. You can replace it with any OpenAI-compatible LLM API URL.")
+ big_external_opponent_llm_name: str = Field(default="/mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen3-235B-A22B-Instruct-2507/", description="The model name of the big external opponent LLM. You can replace it with any OpenAI-compatible LLM name.")
async def execute(self, workflow_task: WorkflowTask, tuner: AjetTuner) -> WorkflowOutput:
+ assert agentscope.__version__ == "1.0.7", "AgentScope has too many bugs across versions, please use version 1.0.7 for werewolves example."
# ensure trainable targets is legal
assert self.trainable_targets is not None, "trainable_targets cannot be None in ExampleWerewolves (because we want to demonstrate a explicit multi-agent case)."
@@ -103,28 +107,27 @@ async def execute(self, workflow_task: WorkflowTask, tuner: AjetTuner) -> Workfl
# initialize agents
players = []
for i, role in enumerate(roles):
- default_model = OpenAIChatModel(
- model_name="/mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen3-235B-A22B-Instruct-2507/",
- stream=False,
- client_args={"base_url": "http://22.17.52.4:2888/v1"},
- api_key="no_api_key",
- generate_kwargs={"temperature": 0.01},
- )
- model_for_this_agent = tuner.as_agentscope_model(
- agent_name=f"Player{i + 1}", # the name of this agent
- target_tag=role, # `target_tag in self.trainable_targets` means we train this agent, otherwise we do not train this agent.
- debug_model=default_model, # the model used when this agent is not in `self.trainable_targets`
- )
+ if role not in self.trainable_targets:
+ model_for_this_agent = OpenAIChatModel(
+ stream=False,
+ api_key="no_api_key",
+ generate_kwargs={"temperature": 0.01},
+ model_name=self.big_external_opponent_llm_name,
+ client_args={"base_url": self.big_external_opponent_llm_url},
+ )
+ else:
+ model_for_this_agent = tuner.as_agentscope_model(
+ agent_name=f"Player{i + 1}",
+ target_tag=role,
+ )
agent = ReActAgent(
name=f"Player{i + 1}",
sys_prompt=get_official_agent_prompt(f"Player{i + 1}"),
model=model_for_this_agent,
- formatter=DashScopeMultiAgentFormatter()
- if role in self.trainable_targets
- else OpenAIMultiAgentFormatter(),
+ formatter=DashScopeMultiAgentFormatter() if isinstance(model_for_this_agent, DashScopeChatModel) else OpenAIMultiAgentFormatter(),
max_iters=3 if role in self.trainable_targets else 5,
)
- # agent.set_console_output_enabled(False)
+ agent.set_console_output_enabled(False)
players += [agent]
# reward condition
diff --git a/tutorial/example_werewolves_swarm/agent_roll.py b/tutorial/example_werewolves_swarm/agent_roll.py
new file mode 100644
index 00000000..c51bc01a
--- /dev/null
+++ b/tutorial/example_werewolves_swarm/agent_roll.py
@@ -0,0 +1,72 @@
+# -*- coding: utf-8 -*-
+
+import os
+from ajet.schema.task import Task, WorkflowTask
+from ajet.copilot.job import AgentJetJob
+from ajet.task_reader import RouterTaskReader
+from ajet.utils.thread_executors import PeriodicDrainThreadPoolExecutor
+from ajet.tuner_lib.as_oai_baseurl_apikey import OpenaiBaseUrlAndApiKey
+from ajet.default_config.ajet_default import AjetTaskReader
+from ajet.tuner_lib.experimental.as_swarm_client import SwarmClient
+
+NUM_EPOCH = 10000
+AJET_SWARM_URL = os.getenv("AJET_SWARM_URL", "http://localhost:10086")
+
+def main():
+
+ # Build the task reader that yields (dummy) training tasks for the rollout loop below
+ dataset = RouterTaskReader(
+ reader_type = "random_dummy",
+ reader_config = AjetTaskReader()
+ )
+
+ ajet_job = AgentJetJob(
+ base_yaml_config="tutorial/example_werewolves_swarm/werewolves.yaml",
+ algorithm="grpo",
+ experiment_name="werewolves_swarm",
+ max_env_worker=128,
+ )
+
+ # Hand shake with remote swarm server
+ swarm_worker = SwarmClient(AJET_SWARM_URL)
+ swarm_worker.auto_sync_train_config_and_start_engine(
+ ajet_job,
+ # force_restart=True,
+ )
+
+ GRPO_N = ajet_job.num_repeat
+ REMOTE_BATCH_SIZE = ajet_job.batch_size
+
+ def rollout(task):
+ # begin episode
+ episode_uuid, api_baseurl_key = swarm_worker.begin_episode(discard_episode_timeout=240)
+ # execute agent ( base_url = api_baseurl_key.base_url, api_key = api_baseurl_key.api_key )
+ workflow_output = execute_agent(task, api_baseurl_key) # reward is in `workflow_output`
+ # report output back to swarm remote
+ swarm_worker.end_episode(task, episode_uuid, workflow_output)
+ return
+
+
+ executor = PeriodicDrainThreadPoolExecutor(workers=1, max_parallel=64, auto_retry=True, block_first_run=True)
+ for _ in range(NUM_EPOCH):
+ for _, task in enumerate(dataset.generate_training_tasks()):
+ for _ in range(GRPO_N):
+ executor.submit_with_periodic_drain(fn=rollout, task=task)
+
+ return
+
+
+def execute_agent(task: Task, api_baseurl_key: OpenaiBaseUrlAndApiKey):
+ import asyncio
+ from tutorial.example_werewolves.start import ExampleWerewolves
+ game = ExampleWerewolves(
+ trainable_targets=["werewolf"],
+ big_external_opponent_llm_name="Qwen/Qwen3-235B-A22B-Instruct-2507",
+ big_external_opponent_llm_url="http://22.14.116.243:2888/v1",
+ )
+ res = asyncio.run(game.execute(WorkflowTask(task=task), api_baseurl_key))
+ return res
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tutorial/example_werewolves_swarm/convert_skill.md b/tutorial/example_werewolves_swarm/convert_skill.md
new file mode 100644
index 00000000..9ad3ff16
--- /dev/null
+++ b/tutorial/example_werewolves_swarm/convert_skill.md
@@ -0,0 +1,77 @@
+训练复杂智能体的时候,推荐先从yaml配置出发
+
+首先,复制一份基础配置 ajet/default_config/ajet_ts_default.yaml
+
+cp ajet/default_config/ajet_ts_default.yaml tutorial/example_werewolves_swarm/werewolves.yaml
+
+然后对配置中的参数进行修改:
+
+---- opencode命令:这里补充一个参数配置说明表格,参考tutorial/example_werewolves_swarm/werewolves.yaml ----
+
+
+# 编写训练循环 (Swarm Client)
+
+Swarm Client 流程如下:
+
+- 连接蜂群
+- 蜂群server初始化
+- 开始EPOCH循环
+
+
+```python
+# -*- coding: utf-8 -*-
+
+NUM_EPOCH = 10000
+AJET_SWARM_URL = os.getenv("AJET_SWARM_URL", "http://localhost:10086")
+
+def main():
+
+ ajet_job = AgentJetJob(
+ base_yaml_config="tutorial/example_werewolves_swarm/werewolves.yaml",
+ algorithm="grpo",
+ experiment_name="werewolves_swarm",
+ )
+
+ # Hand shake with remote swarm server
+ swarm_worker = SwarmClient(AJET_SWARM_URL)
+ swarm_worker.auto_sync_train_config_and_start_engine( ajet_job, force_restart=True )
+
+ GRPO_N = ajet_job.num_repeat
+ REMOTE_BATCH_SIZE = ajet_job.batch_size
+
+ def rollout(task):
+ try:
+ # begin episode
+ episode_uuid, api_baseurl_key = swarm_worker.begin_episode(discard_episode_timeout=60)
+ # execute agent ( base_url = api_baseurl_key.base_url, api_key = api_baseurl_key.api_key )
+ workflow_output = execute_agent(task, api_baseurl_key) # reward is in `workflow_output`
+ # report output back to swarm remote
+ swarm_worker.end_episode(task, episode_uuid, workflow_output)
+ return
+ except Exception:
+ pass  # best-effort: a failed episode is skipped (auto_retry resubmits it)
+
+ # Build the task reader that yields training tasks for the rollout loop below
+ dataset = RouterTaskReader(
+ reader_type = "random_dummy",
+ reader_config = AjetTaskReader()
+ )
+ executor = PeriodicDrainThreadPoolExecutor(workers=GRPO_N * REMOTE_BATCH_SIZE, auto_retry=True)
+ for _ in range(NUM_EPOCH):
+ for _, task in enumerate(dataset.generate_training_tasks()):
+ for _ in range(GRPO_N):
+ executor.submit_with_periodic_drain(fn=rollout, task=task)
+
+ return None
+
+
+def execute_agent(task: Task, api_baseurl_key: OpenaiBaseUrlAndApiKey):
+ raise NotImplementedError("see below.")
+
+
+if __name__ == "__main__":
+ main()
+
+```
+
+# 编写Agent (Swarm Client)
diff --git a/tutorial/example_werewolves_swarm/werewolves.yaml b/tutorial/example_werewolves_swarm/werewolves.yaml
new file mode 100644
index 00000000..e096e8f6
--- /dev/null
+++ b/tutorial/example_werewolves_swarm/werewolves.yaml
@@ -0,0 +1,74 @@
+# ------------------ main config ------------------
+ajet:
+ project_name: example_werewolves_swarm
+ experiment_dir: "auto" # {exp-dir}/{experiment_name}
+
+ model:
+ # ✨ select model to be trained
+ path: /mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2___5-7B-Instruct
+
+
+ rollout:
+ user_workflow: null
+ temperature: 0.7
+ max_env_worker: 64
+ num_repeat: 6
+ agent_madness_reward: 0.0
+ tensor_model_parallel_size: 1
+ # max_num_seqs: 40
+ # monitor LLM's abormal behaviors during rollout
+ compute_madness_checklist:
+ - "nonsense"
+ max_response_length_in_one_turn: 1024
+ max_model_len: 22000
+
+ task_reader:
+ type: random_dummy # `env_service` or `jsonl_dataset_file` or `huggingface_dat_repo` or `data_generation` or `random_dummy`
+
+ task_judge:
+ # ✨ select evaluation function
+ judge_protocol: null
+
+ # the experimental ZeroMQ interchange server feature that allows `tuner.as_oai_baseurl_apikey` feature
+ enable_interchange_server: True
+ # train in cloud, run episode locally
+ enable_swarm_mode: True
+ # both swarm / oai share the same interchange server
+ interchange_server:
+ interchange_method: 'ipc' # options: 'tcp' (multi-nodes) or 'ipc' (1 node)
+ interchange_server_port: 10086
+ num_fastapi_process: 2 # 1, 2 or 4 is fine
+ max_fastapi_threads: 512 # 64 or 128 is fine
+ max_inference_tracker_threads: 64 # recommend to be equal to `ajet.rollout.max_env_worker`
+ already_started: False # do not edit, used by `swarm`
+
+ swarm_mode_sample_collection_method: "rollout_until_finish_enough_tasks"
+
+ debug:
+ debug_max_parallel: 1
+ debug_first_n_tasks: 1
+
+ data:
+ train_batch_size: 32
+ max_prompt_length: 4000
+ max_response_length: 18000
+
+ trainer_common:
+ save_freq: 5
+ test_freq: 9999999
+ total_epochs: 9999999
+ total_training_steps: 25
+ nnodes: 1
+ n_gpus_per_node: 8
+
+# ------------------ do not edit ------------------
+hydra:
+ searchpath:
+ - file://ajet/default_config
+ - file://ajet/default_config/verl
+
+# ------------------ do not edit ------------------
+defaults:
+ - verl_default
+ - ajet_default
+ - _self_
diff --git a/tutorial/opencode_build_skillbench_agent.prompt.md b/tutorial/opencode_build_skillbench_agent.prompt.md
new file mode 100644
index 00000000..4a290cf9
--- /dev/null
+++ b/tutorial/opencode_build_skillbench_agent.prompt.md
@@ -0,0 +1,17 @@
+# Train SkillBench with AgentJet Swarm with Vibe Coding
+
+result is generated by `claude sonnet 4.5`
+
+=============================
+
+你的任务是训练这个仓库中的智能体:https://github.com/benchflow-ai/skillsbench.git
+仓库你需要下载到 ../skillsbench_swarm_test
+这是在调试过程中你可以使用的模型(openrouter)
+ "url": "https://openrouter-openrouter-esyubhyrxv.ap-northeast-1.fcapp.run/api/v1",
+ "key": "sk-or-v1-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
+ "model": "qwen/qwen3-max"
+
+
+
+你的skill(首先读取该SKILL文件,获取必要知识):
+- ajet/copilot/train-complex-blackbox/SKILL.md