22 commits
3f20131
x
hao-aaron Apr 28, 2026
e39d8ea
Merge remote-tracking branch 'upstream/main' into multi-lora
hao-aaron Apr 28, 2026
68ed142
[docs] Add Multi-LoRA Megatron Tinker design doc (v1)
erictang000 May 4, 2026
c0c3a58
[multi-lora] Add AdapterStore for per-worker LoRA slot bookkeeping
erictang000 May 4, 2026
e923894
[multi-lora] Wire AdapterStore into MegatronPolicyWorkerBase
erictang000 May 4, 2026
46c1658
[multi-lora] Add ensure_active_adapter + model_id threading to dispatch
erictang000 May 4, 2026
90dc178
[multi-lora] Allow multiple LoRA policy adapters in SkyRLTrainBackend
erictang000 May 4, 2026
8bb9157
[multi-lora] Add GPU-gated multi-LoRA integration test for Megatron
erictang000 May 4, 2026
301059b
[multi-lora] Add two-client smoke runbook
erictang000 May 4, 2026
b712bca
[multi-lora] Fix _lora_signature_from to not read non-existent target…
erictang000 May 4, 2026
d4a0a04
x
erictang000 May 4, 2026
3c0239e
[multi-lora] Swap grad buffers along with params + optimizer state
erictang000 May 4, 2026
f5ba5c9
[multi-lora-rl] Wire model_id through the LoRA sync + sampling path
erictang000 May 4, 2026
40b1ae4
[multi-lora-rl] Tolerate non-JSON error bodies in load/unload_lora_ad…
erictang000 May 4, 2026
8cff746
[multi-lora-rl] Update design doc + RL two-client smoke runbook
erictang000 May 4, 2026
9d1392d
Merge remote-tracking branch 'origin/main' into multi_lora_rl
erictang000 May 4, 2026
a46b587
[multi-lora-rl] Allow mixed model_ids in a single sample() batch
erictang000 May 5, 2026
edefc84
[multi-lora-rl] Pass load_inplace=True to vLLM load_lora_adapter
erictang000 May 5, 2026
fa3bfbc
Revert "[multi-lora-rl] Pass load_inplace=True to vLLM load_lora_adap…
erictang000 May 5, 2026
cb38614
x
erictang000 May 5, 2026
e18e82b
[docs] Design doc for non-colocated sample routing via EXTERNAL path
erictang000 May 6, 2026
57a474a
[smoke logs] Snapshot rl_loop / sl_loop runs from manual smoke tests
erictang000 May 6, 2026
1 change: 1 addition & 0 deletions .python-version
@@ -0,0 +1 @@
3.12
401 changes: 401 additions & 0 deletions docs/content/docs/tinker/async_sample_routing.mdx

Large diffs are not rendered by default.

164 changes: 164 additions & 0 deletions docs/content/docs/tinker/multi_lora_design.mdx
@@ -0,0 +1,164 @@
---
title: "Multi-LoRA on Megatron — Design"
---

# Multi-LoRA on Megatron — Design (v1)

This document describes the design for multi-tenant LoRA training on the SkyRL Megatron Tinker backend. It is the in-repo companion to the implementation tracked on the `multi_lora` branch.

## Why

Today the SkyRL-Train backend exposed via the Tinker API is single-tenant: a second `create_model` is rejected at `skyrl/backends/skyrl_train_backend.py:342`, and `delete_model` does a full `ray.shutdown()` (line 404) so a fresh model can be created. This is documented under [Single-tenant LoRA](./limitations#single-tenant-lora).

The driver for changing this is Trajectory AI, who want to run Tinker workloads on their own hardware and need many tenants on a shared training pool. There is no multi-tenant Megatron SFT/RL framework today; only Prime-RL offers first-class multi-tenancy, and only on FSDP/CP/EP. A Megatron-backed solution is therefore both a user-requested feature and a meaningful differentiator.

## Scope

- **Training and per-adapter sampling** on a single Tinker server, exercised via the `tinker-cookbook` `sl_loop` (SFT) and `rl_loop` (RL) recipes.
- One base model, multiple LoRA adapters with fixed `(rank, alpha, target_modules)` across all adapters. Mismatched configs on a second `create_model` are hard-rejected with a clear `ValueError`.
- The FFT (no-LoRA) path stays single-tenant — the relaxation is gated behind `lora_config.rank > 0`.
- The RL path requires `merge_lora=False` on Megatron so vLLM serves the adapter (not pre-merged weights) and supports multiple LoRA adapters concurrently. PR #1579 contributes the inference-side scaffolding (`load_lora_adapter` / `unload_lora_adapter`, `max_loras`, `max_cpu_loras`, mandatory `model` per data-plane call); this design plumbs the Tinker `model_id` through end-to-end so each tenant's adapter is registered and addressed by its own name on vLLM.

## Strategy

Keep one base model GPU-resident at all times. At any moment exactly one LoRA adapter is "live" in the model + optimizer. A swap is a set of `tensor.copy_()` calls that move the LoRA buffer params and the `DistributedOptimizer` fp32-main / `exp_avg` / `exp_avg_sq` state between live GPU storage and per-adapter pinned-CPU slots.

The per-adapter slot store (the `AdapterStore`) lives **on each `PolicyWorker`** because Megatron's `DistributedOptimizer` shards optimizer state across DP ranks; each rank owns its own slice and must snapshot/restore it locally. The controller (`SkyRLTrainBackend`) holds only `model_id → role` maps; the dispatch layer (`WorkerDispatch`) fans `swap_to_adapter(model_id)` out to all policy actors.

The swap is **implicit** at the top of every per-model dispatch entry point: `forward`, `forward_backward`, `optim_step`, `set_lr`, `save_checkpoint`, `load_checkpoint`. Callers do not need to swap manually.
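A minimal sketch of what that implicit swap could look like at the dispatch layer, assuming one Ray actor handle per policy worker; the class shape, the `_active_model_id` cache, and the actor method names other than `swap_to_adapter` are illustrative, not the final implementation:

```python
import ray
from typing import Any, List, Optional


class WorkerDispatchSketch:
    """Illustrative fan-out of swap_to_adapter; not the real WorkerDispatch."""

    def __init__(self, policy_actors: List[Any]):
        self._policy_actors = policy_actors
        self._active_model_id: Optional[str] = None

    def ensure_active_adapter(self, model_id: str) -> None:
        # No-op when the requested adapter is already live on every worker.
        if model_id == self._active_model_id:
            return
        ray.get([a.swap_to_adapter.remote(model_id) for a in self._policy_actors])
        self._active_model_id = model_id

    def forward_backward(self, model_id: str, batch: Any) -> List[Any]:
        # Every per-model entry point swaps first; callers never swap manually.
        self.ensure_active_adapter(model_id)
        return ray.get([a.forward_backward.remote(batch) for a in self._policy_actors])
```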

## Why a buffer-level swap is correct

Megatron's `DistributedDataParallel` filters out frozen params before constructing the param-and-grad buffers (`Megatron-LM/core_v0.16.0/megatron/core/distributed/distributed_data_parallel.py:139-141`). Combined with the LoRA pre-wrap hook at `megatron_worker.py:454-460` and `bridge.peft.lora.LoRA` setting `requires_grad=False` on base `to_wrap` params, the DDP `param_and_grad_buffer.param_data` contains **only LoRA A/B params**. Frozen base weights live as plain `nn.Parameter`s outside the buffer (which is exactly what the LoRA-aware branch in `megatron_utils.py:158-170` already handles).

Buffer-level `tensor.copy_()` therefore swaps adapter-only state. The base model stays GPU-resident and is shared across all tenants.

## Four storages per LoRA param

For each LoRA `nn.Parameter` `p`, four independent storages must be swapped:

1. The bf16 view in `mc.buffers[i].param_data` (or `mc.expert_parallel_buffers`).
2. The bf16 grad view in `mc.buffers[i].grad_data` — it must travel with the slot; otherwise an interleaved tenant's `forward_backward` will clobber unconsumed grads via the `chunk.zero_grad_buffer()` call at the top of every fwd_bwd, before this adapter's own `optim_step` runs.
3. The fp32 main copy in `_opt.shard_fp32_from_float16_groups[g][i]` — independent storage, not a view.
4. The Adam moments in `_opt.optimizer.state[main_param]`, keyed by the **fp32 main param**: `exp_avg`, `exp_avg_sq`.

Param-object identity is preserved across `param.data.copy_(...)`, so optimizer state-dict keys remain valid. The fp32 grad accumulator inside `DistributedOptimizer.step()` (reduce-scatter destination) is short-lived — it's allocated and consumed within `step()`, so it never persists across a swap and doesn't need its own slot storage.

## Pristine slot

Adam allocates `exp_avg` / `exp_avg_sq` lazily on the first non-trivial step. Megatron exposes `DistributedOptimizer._init_optimizer_states_with_dummy_values()` (in `distrib_optimizer.py`), which materialises state without a real fwd+bwd. The first `create_model` call:

1. Builds the policy worker and its `DistributedOptimizer` as today.
2. Calls `_init_optimizer_states_with_dummy_values()` on each underlying optimizer to materialise `exp_avg` / `exp_avg_sq`.
3. Snapshots the freshly-initialised LoRA state into the `AdapterStore`'s pristine slot.

Every subsequent `create_model("X")` allocates a fresh slot for `X` and copies the pristine slot's contents into it. A new tenant therefore starts from a freshly-initialised LoRA (kaiming-A + zero-B + zero optimizer state).
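A sketch of that flow on the worker side; `_init_optimizer_states_with_dummy_values` and `_iter_opts` are the names used elsewhere in this doc, while the `AdapterStore` methods (`snapshot_live`, `allocate_slot_like`, `tensors`) and the `pristine` / `slots` attributes are hypothetical:

```python
def prime_optimizer_state(worker) -> None:
    # Materialise exp_avg / exp_avg_sq without a real fwd+bwd so the very first
    # snapshot already contains full (zero) optimizer state.
    for opt in worker._iter_opts(worker.optimizer):
        opt._init_optimizer_states_with_dummy_values()


def register_adapter(store, model_id: str) -> None:
    if store.pristine is None:
        # First create_model: snapshot the freshly-initialised LoRA
        # (kaiming-A, zero-B, zero optimizer moments) as the pristine slot.
        store.pristine = store.snapshot_live()
    # Every new tenant starts from its own pinned-CPU copy of the pristine slot.
    slot = store.allocate_slot_like(store.pristine)
    for dst, src in zip(slot.tensors(), store.pristine.tensors()):
        dst.copy_(src)
    store.slots[model_id] = slot
```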

## Concurrency

`DistributedOptimizer.step()` issues DP-group collectives (reduce-scatter on grads, all-gather on updated params). Mixed adapter identity across DP ranks would corrupt these collectives. Therefore each `swap_to` ends with a `dist.barrier(group=mpu.get_data_parallel_group())` — and PP/TP equivalents where relevant — to ensure all ranks agree on the live adapter before the next collective begins.

A `torch.cuda.current_stream().synchronize()` between the save and restore halves of a swap guarantees that `non_blocking=True` D2H copies complete before the corresponding GPU storage is overwritten.

## Concrete `swap_to(adapter_id)` algorithm

Per worker, all under `torch.no_grad()` (a condensed code sketch follows the numbered steps):

1. `dist.barrier(dp_group)` — wait for the previous adapter's last collective to finish.
2. Save current adapter into its slot:
- For each `mc` and each `buffer ∈ mc.buffers + mc.expert_parallel_buffers`: copy `buffer.param_data` AND `buffer.grad_data` into `slot.cpu_param_data[mc][i]` / `slot.cpu_grad_data[mc][i]` (`non_blocking=True`).
- For each `_opt ∈ _iter_opts(self.optimizer)` and each `(g, i)`: `slot.cpu_main_param[g][i].copy_(_opt.shard_fp32_from_float16_groups[g][i], non_blocking=True)`. Then `slot.cpu_exp_avg[g][i].copy_(state['exp_avg'], non_blocking=True)`, `slot.cpu_exp_avg_sq[g][i].copy_(state['exp_avg_sq'], non_blocking=True)`.
3. `torch.cuda.current_stream().synchronize()` — D2H complete.
4. Load target adapter — same loops in reverse, copying CPU → GPU into the same storages.
5. `torch.cuda.current_stream().synchronize()` — H2D complete.
6. `dist.barrier(dp_group)` — agreement on live adapter before next collective.
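A condensed sketch of the six steps as a worker method. The slot field names match the "Per-AdapterSlot CPU storage" section below; the `_all_buffers()` helper (iterating `mc.buffers + mc.expert_parallel_buffers`) and the `slots` / `active_id` attributes are assumptions made for brevity:

```python
import torch
import torch.distributed as dist
from megatron.core import parallel_state as mpu


@torch.no_grad()
def swap_to(self, adapter_id: str) -> None:
    dp_group = mpu.get_data_parallel_group()
    dist.barrier(group=dp_group)                     # (1) previous adapter's last collective done

    cur, tgt = self.slots[self.active_id], self.slots[adapter_id]

    # (2) Save the live adapter into its slot (D2H, async).
    for key, buf in self._all_buffers():
        cur.cpu_param_data[key].copy_(buf.param_data, non_blocking=True)
        cur.cpu_grad_data[key].copy_(buf.grad_data, non_blocking=True)
    for opt in self._iter_opts(self.optimizer):
        for g, group in enumerate(opt.shard_fp32_from_float16_groups):
            for i, main_p in enumerate(group):
                state = opt.optimizer.state[main_p]
                cur.cpu_main_param[g][i].copy_(main_p, non_blocking=True)
                cur.cpu_exp_avg[g][i].copy_(state["exp_avg"], non_blocking=True)
                cur.cpu_exp_avg_sq[g][i].copy_(state["exp_avg_sq"], non_blocking=True)

    torch.cuda.current_stream().synchronize()        # (3) D2H complete

    # (4) Load the target adapter: same loops in reverse, CPU -> GPU, same storages.
    for key, buf in self._all_buffers():
        buf.param_data.copy_(tgt.cpu_param_data[key], non_blocking=True)
        buf.grad_data.copy_(tgt.cpu_grad_data[key], non_blocking=True)
    for opt in self._iter_opts(self.optimizer):
        for g, group in enumerate(opt.shard_fp32_from_float16_groups):
            for i, main_p in enumerate(group):
                state = opt.optimizer.state[main_p]
                main_p.copy_(tgt.cpu_main_param[g][i], non_blocking=True)
                state["exp_avg"].copy_(tgt.cpu_exp_avg[g][i], non_blocking=True)
                state["exp_avg_sq"].copy_(tgt.cpu_exp_avg_sq[g][i], non_blocking=True)

    torch.cuda.current_stream().synchronize()        # (5) H2D complete
    dist.barrier(group=dp_group)                     # (6) all ranks agree on the live adapter
    self.active_id = adapter_id
```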

## Per-AdapterSlot CPU storage

Per worker, pinned memory:

- `cpu_param_data[mc][buf_idx]` — bf16, one tensor per `mc.buffers + mc.expert_parallel_buffers` entry, shape matches the bucket.
- `cpu_grad_data[mc][buf_idx]` — bf16, parallel to `cpu_param_data`. Required for cross-tenant interleaving correctness (see "Why grads must travel with the slot" below).
- `cpu_main_param[g][i]` — fp32, shape matches `shard_fp32_from_float16_groups[g][i]`.
- `cpu_exp_avg[g][i]`, `cpu_exp_avg_sq[g][i]` — fp32, same shapes.

Frozen base weights are not duplicated per adapter; they live in their own pinned storage already managed by `offload_megatron_model_to_cpu`.
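One possible shape for the slot, assuming dict / nested-list containers; only the field names and dtypes come from the bullets above:

```python
from dataclasses import dataclass, field
from typing import Dict, List, Tuple

import torch


@dataclass
class AdapterSlot:
    # bf16 mirrors of the DDP buckets, keyed by (model_chunk_idx, buffer_idx).
    cpu_param_data: Dict[Tuple[int, int], torch.Tensor] = field(default_factory=dict)
    cpu_grad_data: Dict[Tuple[int, int], torch.Tensor] = field(default_factory=dict)
    # fp32 optimizer shards, indexed [param_group][param_idx].
    cpu_main_param: List[List[torch.Tensor]] = field(default_factory=list)
    cpu_exp_avg: List[List[torch.Tensor]] = field(default_factory=list)
    cpu_exp_avg_sq: List[List[torch.Tensor]] = field(default_factory=list)


def pinned_like(t: torch.Tensor) -> torch.Tensor:
    # Pinned host memory so the non_blocking=True copies in swap_to can overlap.
    return torch.empty(t.shape, dtype=t.dtype, device="cpu", pin_memory=True)
```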

### Why grads must travel with the slot

The Tinker engine batches pending `forward_backward`s across model_ids in 100 ms ticks (`engine.py:721`). Within one tick, `_split_model_pass_batch_by_model_id` (`skyrl_train_backend.py:152`) splits into per-model sub-batches and runs them sequentially. Each `forward_backward` starts with `chunk.zero_grad_buffer()` (`megatron_worker.py:702`).

If we swapped only params + optimizer state, this sequence corrupts:

```
batched fwd_bwd = [A.fb, B.fb]
sub-batch A: swap_to("A"), zero_grad_buffer, accumulate A's grads into grad_data
sub-batch B: swap_to("B") ← params + opt state swap; grads UNTOUCHED
zero_grad_buffer ← clobbers A's grads
accumulate B's grads into grad_data
A.optim_step: swap_to("A") ← restores A's params + opt state; grads = B's
optimizer.step() ← applies B's grads to A's weights ✗
```

Snapshotting `grad_data` into the slot on every swap fixes this. After A's fwd_bwd, A's slot holds A's grads. swap_to(B) overwrites the live `grad_data` with B's saved grads (zero on the first visit), so B's `zero_grad_buffer` clears B's data, not A's. swap_to(A) before A's `optim_step` restores A's grads exactly. The fp32 grad accumulator that DistributedOptimizer materialises inside `step()` is short-lived (created and consumed within one call), so it doesn't need its own slot storage.

## Files to add / modify

### New

- `skyrl/backends/skyrl_train/workers/megatron/adapter_store.py` — `AdapterSlot` + `AdapterStore`.

### Modified

- `skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py` — construct `AdapterStore` after optimizer init; expose Ray-callable `prime_optimizer_state`, `register_pristine_adapter`, `register_adapter`, `delete_adapter`, `swap_to_adapter`.
- `skyrl/backends/skyrl_train/workers/worker_dispatch.py` — `ensure_active_adapter(model_id)`; thread `model_id` kwarg through `forward`, `forward_backward`, `optim_step`, `set_lr`, `save_checkpoint`, `load_checkpoint`.
- `skyrl/backends/skyrl_train_backend.py` — relax the single-policy gate when `lora_config.rank > 0`; rework `delete_model` to only `ray.shutdown()` on last adapter; pass `model_id` into every dispatch call; raise on `sample()` / `save_sampler_checkpoint` if more than one adapter is registered.

## Verification

- `pytest tests/tinker/test_multi_lora_megatron.py` — GPU-gated integration test that creates A, trains, creates B from pristine, trains, switches back to A, and asserts state is preserved bit-for-bit. Negative tests for rank mismatch and `sample()` with two adapters.
- Existing `tests/tinker/test_api.py` continues to pass (single-tenant path unchanged).
- End-to-end smoke (manual): launch the Tinker server with `trainer.strategy=megatron`, run two `tinker_cookbook.recipes.sl_loop` clients with distinct `model_id`s in parallel against `base_url=http://localhost:8000`, verify both converge on their respective tasks and GPU memory stays bounded.

## Per-adapter sampling and weight sync (RL path)

For RL, sampling has to be per-adapter: each tenant runs `forward_backward → optim_step → save_weights_for_sampler → sample` against its own model_id, and vLLM must know which adapter's weights to serve for each call.

The plumbing is straightforward once the SFT scaffolding is in place:

- **Adapter name = Tinker model_id.** The Tinker `model_id` is forwarded as the vLLM adapter name end-to-end. There is no separate naming layer.
- **Per-tenant `lora_sync_path`.** When `model_id` is set, the worker writes its adapter into `os.path.join(cfg.trainer.policy.model.lora.lora_sync_path, model_id)/` so concurrent saves from different tenants don't collide. Single-tenant calls (`model_id=None`, the FFT path or pre-Tinker callers) keep the legacy shared path.
- **`save_sampler_checkpoint(model_id)`** swaps the requested adapter live (via `WorkerDispatch.ensure_active_adapter`), then broadcasts to vLLM. The worker calls `RemoteInferenceClient.load_lora_adapter(model_id, lora_sync_path/model_id, load_inplace=True)` so re-syncs of the same adapter overwrite vLLM's slot in place (no fresh int id, no eviction churn).
- **`sample(prepared_batch)`** now resolves the data-plane `model` per request from `prepared_batch.all_model_ids[i]` when LoRA is active, falling back to `resolve_policy_model_name(cfg)` only on the FFT / single-tenant path; see the sketch below. The previous "raise if >1 adapter" guards are gone.
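A minimal sketch of that per-request resolution; `all_model_ids` and the fallback helper come from this design, while `prepared_batch.requests` and the surrounding function are illustrative only:

```python
from typing import Any, List

from skyrl.backends.skyrl_train.inference_servers.utils import resolve_policy_model_name


async def sample(client, cfg, prepared_batch, lora_active: bool) -> List[Any]:
    fallback = resolve_policy_model_name(cfg)  # FFT / single-tenant path
    outputs: List[Any] = []
    for i, request in enumerate(prepared_batch.requests):
        # Each tenant's adapter is addressed on vLLM by its own Tinker model_id.
        model = prepared_batch.all_model_ids[i] if lora_active else fallback
        outputs.append(await client.generate(request, model=model))
    return outputs
```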

### vLLM capacity contract

vLLM exposes two LoRA capacity knobs:

- `max_loras` — concurrent adapters in a single GPU batch.
- `max_cpu_loras` — total LoRA capacity in vLLM's CPU LRU cache (defaults to `max_loras` when unset).

Both are config fields on `SkyRLLoraConfig` (`config.py`, surfaced in `ppo_base_config.yaml`). For multi-tenant RL, `max_cpu_loras` MUST be set to at least the expected number of concurrent registered adapters; otherwise vLLM will silently evict an adapter from its CPU cache and the next `sample()` against that adapter will 404. The Tinker server doesn't auto-size this — operators set it explicitly when they expect N tenants. Detecting and reloading-from-disk on 404 is a future improvement.
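The contract can be expressed as a simple guard. The check below is not part of the design, just one way an operator-facing assertion could be phrased, assuming the adapter count is known at registration time:

```python
from typing import Optional


def check_lora_capacity(num_registered: int, max_loras: int, max_cpu_loras: Optional[int]) -> None:
    # vLLM's CPU LRU cache defaults to max_loras slots when max_cpu_loras is unset.
    cpu_slots = max_cpu_loras if max_cpu_loras is not None else max_loras
    if num_registered > cpu_slots:
        raise ValueError(
            f"{num_registered} adapters registered but only {cpu_slots} CPU LoRA slots; "
            "vLLM would silently evict one and the next sample() against it would 404. "
            "Set max_cpu_loras to at least the expected number of concurrent adapters."
        )
```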

### Concurrency between training and sampling

With `colocate_all=True`, vLLM is asleep during training and woken in `save_weights_for_sampler` for the broadcast + KV-cache wake. With multi-tenant RL the engine queue's destructive-barrier scheduling (`engine.py`) keeps each model_id's `forward_backward / optim_step / save_weights_for_sampler / sample` chain coherent. For non-colocated mode the same applies — `pause_generation / resume_generation` brackets the broadcast, and per-adapter sample requests fan out independently.

## PR #1579 (foundation)

The RL multi-tenancy work sits on top of [NovaSky-AI/SkyRL#1579](https://github.com/NovaSky-AI/SkyRL/pull/1579) (`hao-aaron:multi-lora`), which provides:

- `RemoteInferenceClient.load_lora_adapter(name, path, load_inplace)` and `unload_lora_adapter(name)`.
- `SKYRL_LORA_ADAPTER_NAME` constant + `resolve_policy_model_name(cfg)` for the single-tenant fallback.
- Mandatory `model` field on every data-plane call (`generate / sample / chat_completion / completion / render_chat_completion`).
- `max_loras` + `max_cpu_loras` config knobs plumbed through to vLLM CLI args.

Our `multi_lora_rl` branch is based on PR #1579's HEAD with our SFT AdapterStore commits cherry-picked on top.
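For orientation, a hedged usage sketch of the client surface listed above; whether these calls are sync or async is an assumption, and the config path mirrors the per-tenant `lora_sync_path` convention described earlier:

```python
import os


async def sync_adapter_to_vllm(client, cfg, model_id: str) -> None:
    # Per-tenant save path: <lora_sync_path>/<model_id>/
    path = os.path.join(cfg.trainer.policy.model.lora.lora_sync_path, model_id)
    # load_inplace lets a re-sync of the same adapter overwrite vLLM's slot in place.
    await client.load_lora_adapter(model_id, path, load_inplace=True)


async def drop_adapter_from_vllm(client, model_id: str) -> None:
    await client.unload_lora_adapter(model_id)
```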

## Out of scope (explicit non-goals)

- Adapter-only checkpoint export — `save_checkpoint` still saves the whole base+LoRA state per swap.
- Variable rank / alpha / target_modules across adapters.
- Critic role multi-tenancy — Megatron critic is `NotImplementedError` today; no change.
- `HybridDeviceOptimizer` path — TODO comment only.
- Auto-reload of LoRA adapters evicted from vLLM's CPU LRU cache; for now `max_cpu_loras` must be set ≥ expected concurrent adapters.
@@ -1,18 +1,26 @@
from typing import Any, List

from skyrl.backends.skyrl_train.inference_servers.utils import resolve_policy_model_name

from ..base import AsyncInferBackend, GeneratorOutput, GeneratorInput


class SkyRLBackend(AsyncInferBackend):
def __init__(self, infer_engine, tokenizer: Any = None, cfg: Any = None):
self.client = infer_engine
# Resolve the name the inference engine knows the policy by (base
# model or registered LoRA adapter) once at construction. Threaded
# into every ``client.generate`` call so the data plane never has
# to guess the target adapter.
self.policy_model_name = resolve_policy_model_name(cfg) if cfg is not None else self.client.model_name

async def async_generate_prompts(self, prompts: Any, sampling_params: Any, **kwargs) -> List[str]:
input_obj = {
"prompts": [prompts],
"session_ids": [kwargs.get("request_id", None)],
"sampling_params": sampling_params,
}
output = await self.client.generate(input_obj)
output = await self.client.generate(input_obj, model=self.policy_model_name)
return output["responses"][0], output["stop_reasons"][0]

async def async_generate_ids(self, input_ids: List[int], sampling_params: Any, **kwargs) -> List[str]:
@@ -21,7 +29,7 @@ async def async_generate_ids(self, input_ids: List[int], sampling_params: Any, *
"session_ids": [kwargs.get("request_id", None)],
"sampling_params": sampling_params,
}
output = await self.client.generate(input_obj)
output = await self.client.generate(input_obj, model=self.policy_model_name)
# todo(@csy) probably need to be finish_reason
# https://github.com/vllm-project/vllm/blob/a0f8a7964694a6077689b242b5eca95de392d4bb/vllm/v1/engine/__init__.py#L22
meta_info = {