Skip to content

[tinker][megatron] Multi-LoRA Megatron + Tinker API#1617

Open
erictang000 wants to merge 10 commits intoNovaSky-AI:mainfrom
erictang000:multi_lora
Open

[tinker][megatron] Multi-LoRA Megatron + Tinker API#1617
erictang000 wants to merge 10 commits intoNovaSky-AI:mainfrom
erictang000:multi_lora

Conversation

@erictang000
Copy link
Copy Markdown
Collaborator

Adds the design write-up for multi-tenant LoRA training on the Megatron backend exposed via the Tinker API. v1 is training-only; sampling and adapter-only checkpoint export are deferred. Implementation follows on the multi_lora branch.

Adds the design write-up for multi-tenant LoRA training on the Megatron
backend exposed via the Tinker API. v1 is training-only; sampling and
adapter-only checkpoint export are deferred. Implementation follows on
the multi_lora branch.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a design document for multi-tenant LoRA training on the Megatron backend, outlining a strategy to swap adapter weights and optimizer states between GPU and pinned CPU memory. Feedback focuses on technical risks and optimizations: specifically, the potential for gradient corruption during interleaved training steps, the need for a no-op check to avoid redundant synchronization overhead, and concerns regarding host memory pressure from pinned CPU storage. Additionally, it is recommended to remove specific line number references to ensure the documentation remains maintainable.

2. The fp32 main copy in `_opt.shard_fp32_from_float16_groups[g][i]` — independent storage, not a view.
3. The Adam moments in `_opt.optimizer.state[main_param]`, keyed by the **fp32 main param**: `exp_avg`, `exp_avg_sq`.

Param-object identity is preserved across `param.data.copy_(...)`, so optimizer state-dict keys remain valid. Grads are not swapped — `optimizer.zero_grad()` runs after every step (`megatron_strategy.py:215`), so they're zero at swap time.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The assumption that gradients are zero at swap time is problematic for multi-tenancy. If two training loops run in parallel, their forward_backward and optim_step calls can interleave. If Model B's forward_backward occurs between Model A's forward_backward and optim_step, Model A's gradients will be corrupted or cleared because they share the same GPU grad buffers. To support true parallel multi-tenancy, the design must either swap the gradient buffers or enforce atomicity for the entire training step (from forward_backward to optim_step) per model.


## Why

Today the SkyRL-Train backend exposed via the Tinker API is single-tenant: a second `create_model` is rejected at `skyrl/backends/skyrl_train_backend.py:342`, and `delete_model` does a full `ray.shutdown()` (line 404) so a fresh model can be created. This is documented under [Single-tenant LoRA](./limitations#single-tenant-lora).
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Avoid referencing specific line numbers in source files (e.g., skyrl_train_backend.py:342). These references quickly become inaccurate as the code changes. It is better to refer to method names or descriptive logic blocks.


The per-adapter slot store (the `AdapterStore`) lives **on each `PolicyWorker`** because Megatron's `DistributedOptimizer` shards optimizer state across DP ranks; each rank owns its own slice and must snapshot/restore it locally. The controller (`SkyRLTrainBackend`) holds only `model_id → role` maps; the dispatch layer (`WorkerDispatch`) fans `swap_to_adapter(model_id)` out to all policy actors.

The swap is **implicit** at the top of every per-model dispatch entry point: `forward`, `forward_backward`, `optim_step`, `set_lr`, `save_checkpoint`, `load_checkpoint`. Callers do not need to swap manually.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The design should specify that the implicit swap is a no-op if the requested model_id is already the active one. Given that this check happens at the top of every forward and forward_backward call, avoiding unnecessary dist.barrier and cuda.synchronize calls is essential for performance.

Comment on lines +78 to +82
Per worker, pinned memory:

- `cpu_param_data[mc][buf_idx]` — bf16, one tensor per `mc.buffers + mc.expert_parallel_buffers` entry, shape matches the bucket.
- `cpu_main_param[g][i]` — fp32, shape matches `shard_fp32_from_float16_groups[g][i]`.
- `cpu_exp_avg[g][i]`, `cpu_exp_avg_sq[g][i]` — fp32, same shapes.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Using pinned CPU memory for all adapter slots can lead to significant host memory pressure as the number of adapters grows. It would be beneficial to include a strategy for managing this, such as a maximum number of resident slots or falling back to non-pinned memory with an LRU eviction policy.

erictang000 and others added 7 commits May 4, 2026 21:10
New module holding per-adapter pinned-CPU snapshots of the LoRA bucket
params + DistributedOptimizer fp32-main + Adam state on each Megatron
PolicyWorker. swap_to() walks mc.buffers + expert_parallel_buffers and
shard_fp32_from_float16_groups, doing tensor.copy_() in both directions
under torch.no_grad with dp_group barriers + cuda stream syncs.

Also includes a sanity check that every trainable param under DDP
buffers is a LoRA adapter param (named "...adapter..."), so a future
regression that unfreezes a non-LoRA param fails loudly at registration
rather than silently corrupting state.

Wiring into PolicyWorker / WorkerDispatch / SkyRLTrainBackend follows
in subsequent commits.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Adds an `adapter_store: AdapterStore | None` attribute on the policy
worker (allocated only when LoRA is active so the FFT path is unchanged)
plus five Ray-callable methods:

- prime_optimizer_state — calls Megatron's
  DistributedOptimizer._init_optimizer_states_with_dummy_values() so
  exp_avg/exp_avg_sq exist before we snapshot the pristine slot.
- register_pristine_adapter — derives a LoraSignature from the worker's
  own lora config + parallel state, snapshots live state into pristine.
- register_adapter(model_id) — allocates a fresh slot; first call uses
  live as the slot, subsequent calls seed from pristine.
- delete_adapter(model_id) — drops a slot.
- swap_to_adapter(model_id) — local tensor.copy_() between live and slot
  storages plus dp_group barriers.

Plus an adapter_store_state() diagnostic for tests. Orchestration from
the controller follows in subsequent commits.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
WorkerDispatch now exposes:
  - ensure_active_adapter(role, model_id): fans swap_to_adapter to all
    actors of `role`. No-op when model_id is None or the workers don't
    own an AdapterStore (FFT path).
  - prime_adapter_store(role, model_id): one-shot bootstrap for the very
    first create_model — primes optimizer state, registers pristine slot,
    registers the first adapter in one Ray-fanout sequence.
  - register_adapter / delete_adapter: per-call slot maintenance.

forward / forward_backward / forward_backward_from_staged / optim_step /
set_lr / save_checkpoint / load_checkpoint take an optional model_id and
call ensure_active_adapter after _ensure_on_gpu. Default None preserves
single-tenant behavior.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
create_model now allows additional 'policy' models when LoRA is active
and the first policy model has been built. Subsequent calls validate
(rank, alpha, target_modules) match the first adapter's signature, then
register a new slot via WorkerDispatch.register_adapter. FFT (rank=0)
keeps the original single-tenant gate.

_build_policy takes the first model_id and, when LoRA is active, fires
the AdapterStore bootstrap (prime_optimizer_state +
register_pristine_adapter + register_adapter) on every worker before
the colocate_all offload while model + optimizer are still GPU-resident.

delete_model: when more than one model is registered and the role is a
LoRA policy, just drop the slot via dispatch.delete_adapter and pop the
controller-side maps. Last-adapter delete still does the full
ray.shutdown teardown so the runtime can be rebuilt cleanly.

Plumbed model_id through forward / forward_backward / optim_step /
set_lr / save_checkpoint / load_checkpoint dispatch calls so the active
adapter is swapped in on every per-model entry point.

sample() and save_sampler_checkpoint() refuse with a clear error when
more than one LoRA adapter is registered (v1 inference path is single-
tenant; per-adapter sampling is deferred).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
End-to-end test that starts a Tinker API server with the SkyRL-Train
Megatron backend and exercises:

  - two LoRA adapters training independently without weight contamination,
  - rank-mismatch on a second create_model raises a clear error,
  - sample()/save_sampler_checkpoint with two adapters raises (v1 scope),
  - delete_model on one adapter leaves the runtime alive and the other
    adapter still trainable.

Auto-skips when no CUDA device is visible. Server lifecycle uses the
same wait_for_condition pattern as test_api.py.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Manual smoke test (the gate before merging multi_lora): launch a Tinker
API server with the SkyRL-Train Megatron backend, run two
tinker-cookbook sl_loop clients in parallel against it with distinct
model_ids, and verify

  - the policy model is built once (no second `init policy model done`),
  - the second client triggers `Registered additional LoRA adapter`,
  - both clients converge on their respective NLLs without weight
    contamination,
  - GPU memory stays bounded as the second client connects,
  - rank-mismatch / two-adapter sample / single-adapter-delete behave per
    the v1 contract.

Plus troubleshooting notes for the common failure modes.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…_modules

Tinker's public LoraConfig (skyrl/tinker/types.py:66) exposes only
rank + alpha + seed + train_{attn,mlp,unembed}; it has no
target_modules attribute. The Megatron path reads target_modules from
the server-side cfg.trainer.policy.model.lora.target_modules, which is
fixed at startup, so multi-adapter signature equality reduces to
(rank, alpha). The worker-side AdapterStore still verifies parallel
state equality via its own LoraSignature.

Fixes the AttributeError on the first create_model in the smoke test.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@erictang000
Copy link
Copy Markdown
Collaborator Author

erictang000 commented May 4, 2026

LoRA A and LoRA B are running concurrently against the same Tinker API server (B200), LoRA Baseline was run on a separate tinker API server on a different node (h100).

Things are looking relatively uncontaminated, and are almost identically matching!

image

erictang000 and others added 2 commits May 4, 2026 22:13
Fixes a cross-tenant grad-corruption race surfaced in review:

  Tick N: batched fwd_bwd = [A.fb, B.fb]
    - sub-batch A: swap_to("A"), zero_grad_buffer, accumulate A's grads
    - sub-batch B: swap_to("B")  <-- only params + opt state swapped
                   zero_grad_buffer  <-- A's grads CLOBBERED here
                   accumulate B's grads
  Tick N+1: A.optim_step
    - swap_to("A") restores A's params + opt state
    - optimizer.step() reads grad_data, which holds B's grads -> B's
      gradient is applied to A's weights, A's actual gradient is lost

The fix is to snapshot/restore `mc.buffers[i].grad_data` (and
`expert_parallel_buffers`) alongside `param_data`. AdapterSlot now
carries a parallel cpu_grad_data list; _allocate_empty_slot,
_snapshot, _restore, and _copy_slot all maintain it. The fp32 grad
accumulator inside DistributedOptimizer.step() is short-lived (created
and consumed within one call) so it doesn't need slot storage.

Memory cost: ~+1x per slot for the grad mirror (bf16, same size as
param buffer). For a 7B base + rank-32 LoRA on a single DP shard this
is on the order of tens of MB, dwarfed by the existing fp32 main +
Adam moments.

Updates the design doc to reflect the four storages per LoRA param and
adds a "Why grads must travel with the slot" section walking through
the race the review caught.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@erictang000 erictang000 changed the title [docs] Add Multi-LoRA Megatron Tinker design doc (v1) [tinker][megatron] Multi-LoRA Megatron + Tinker API May 4, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant