[train][multimodal][3/3] Add multi-turn VLM generator #1486
nithinvc wants to merge 18 commits into NovaSky-AI:main
Conversation
SumanthRH left a comment:
Looking good! I have some questions on the concatenation logic for obs_tokens
```python
register(
    id="color_square_test_env",
    entry_point="tests.backends.skyrl_train.gpu.gpu_ci.test_skyrl_vlm_gym_generator:ColorSquareEnv",
)
```
SumanthRH: Tests should ideally register and deregister envs at cleanup; we didn't do this properly before, I think. Let's add a deregister helper in skyrl_gym's registration file for this. You can then run register + deregister in an autouse, module-level fixture.
nithinvc: Done, added a deregister function.
```python
rollout_expert_indices: Optional[List[List[List[List[int]]]]]  # [batch_size, seq_len, layer_num, topk]
# Applicable only for step-wise training
is_last_step: Optional[List[bool]]
pixel_values: Optional[List[torch.Tensor]]
```
SumanthRH: Nit: add a comment that these fields are for VLMs; the preceding comment makes it look like they are only for step-wise training.
```python
if gen_ids:
    per_step_rewards.append((step_reward, len(response_ids) - 1))

# 6. If episode continues, defer obs token extraction to next render
if not done:
    conversation.extend(new_obs)
    pending_obs_offset = len(input_ids) + len(gen_ids)
```
SumanthRH: Isn't this book-keeping wrong for some chat templates? If the rendering of an assistant message changes depending on whether it is the last assistant message, then gen_ids doesn't match the rendered token ids for new_obs in the new conversation.
For example, with some thinking models like Qwen3 (text-only), previous assistant messages have their thinking tokens stripped when you apply the chat template. In this case gen_ids will include thinking tokens, but once you add gen_text and new_obs to the conversation and re-tokenize, the thinking part of gen_text is stripped out, so pending_obs_offset will be incorrect.
There are probably multimodal models where this holds true as well.
We did this to get obs tokens correct for text-only: https://github.com/nithinvc/SkyRL/blob/da798404f9c99e4466bf61a56e5d9cec2247aeda/skyrl/train/generators/skyrl_gym_generator.py#L584-L593
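A toy illustration of the offset drift described above. This is not the real Qwen3 chat template; `render` below is a made-up stand-in that strips thinking blocks from every assistant message except the last, which is enough to show why an offset recorded before re-rendering no longer points at the new observation:

```python
import re

def render(conversation):
    # Toy template: keep <think>...</think> only in the FINAL assistant
    # message, mimicking Qwen3-style thinking-token stripping.
    parts = []
    for i, msg in enumerate(conversation):
        text = msg["content"]
        if msg["role"] == "assistant" and i != len(conversation) - 1:
            text = re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL)
        parts.append(f"<{msg['role']}>{text}</{msg['role']}>")
    return "".join(parts)

conv = [
    {"role": "user", "content": "hi"},
    {"role": "assistant", "content": "<think>hmm</think>ok"},
]
before = len(render(conv))  # offset recorded while thinking is still present
conv.append({"role": "user", "content": "next obs"})
after = render(conv)
# The stripped thinking block shifts everything after the first assistant
# turn, so `before` no longer marks where the new observation begins.
```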
nithinvc: Yeah, that makes sense. I got rid of the pending_obs_offset logic and switched to book-keeping the last render length, prev_render_len. In the next agent-loop iteration, we walk forward from prev_render_len looking for an EOS token; the renderer adds the EOS token to indicate the end of the previous turn, and the observation tokens start from there.
Alternatively, we can break up the conversation history and do a /render call on individual observation chunks (like the linked code snippet). The downside is more generator-side book-keeping, especially for tracking the multi-modal inputs. On each render call we would have to deserialize and build a GenerateRequest, in two steps:
- Concatenate the serialized tensor data (deserialize, cat, serialize)
- Update the placeholder indexing returned by vLLM, since from the endpoint's perspective the new observation is the only observation
I figured the prev_render_len approach is simpler, but let me know what you think.
SumanthRH: It would be good to add a test asserting the exact return value of agent_loop here, so that the concatenation logic is tested. The right way would be a CPU test with a dummy vLLM endpoint: for rendering, we use a real Hugging Face renderer and hardcode the expected full token sequence. Ideally this is done with a thinking model that has a non-standard chat template.
nithinvc: Makes sense. I added some CPU tests alongside the existing generator tests: tests/train/generators/test_skyrl_vlm_generator.py
SumanthRH: Have you also tested the generate entrypoint E2E? It looks like it will just work given the current generator changes, but it would be good to run E2E.
nithinvc: Yes, but it does require the trainer changes in #1498.
SumanthRH left a comment:
Ok, I realize there is more complexity with thinking models: loss_mask can't be made append-only. Let's focus on regular non-thinking models for this PR.
The best way to add support for thinking models would be to add step-wise training support to SkyRLVLMGenerator, best done in a follow-up.
```python
if prev_render_len is not None:
    obs_offset = None
    for i in range(prev_render_len, len(input_ids)):
        if input_ids[i] == self.tokenizer.eos_token_id:
```
SumanthRH: Hmm, I guess this works for now because we pass include_stop_str_in_output=True.
nithinvc: Yes, but I also think it works because stop strings are added to intermediate turns. For example, for a generated:

```
<user> </user>
<assistant></EOS>
```

even if EOS is excluded during generation, I believe it's added back when we re-render the next turn to indicate the end of the assistant message:

```
<user></user>
<assistant></EOS>
<user></user>
```
SumanthRH: Actually, because generation stops at the stop string, you will get an issue with response_ids being incorrect: gen_ids won't include the EOS token. Concretely:

```
stop string   -> <stop>
gen_ids       -> [Hi, <stop>]
obs_tokens    -> [Hello]
response_ids  =  [Hi, <stop>, Hello]
expected      -> [Hi, <stop>, EOS, Hello]
```
```python
# If no EOS token found, assume no thinking tokens
obs_offset = prev_render_len
logger.warning("No EOS token found after prev_render_len; obs offset may be incorrect")
```
SumanthRH: Actually, instead of having a sometimes-wrong codepath here, let's just error out loudly.
nithinvc: Switched to raising an error.
```python
output: GeneratorOutput = await generator.generate(input_batch)

response_ids = output["response_ids"][0]
loss_masks = output["loss_masks"][0]
```
SumanthRH: Nit: this is the loss_mask for a single entry, not loss_masks.
nithinvc: Renamed the variable to loss_mask, but the dict key is still loss_masks.
```python
# 5. Track generated tokens (loss_mask=1)
response_ids.extend(gen_ids)
loss_mask.extend([1] * len(gen_ids))
```
SumanthRH: This is the logic for models with a standard chat template; otherwise the loss mask can't be fully append-only. It's okay as-is in this PR, but it will lead to some off-policyness in training for thinking models: we retain all the thinking tokens in all the turns, even though we sampled while ignoring thinking tokens for previous turns.
For proper support for thinking models, step-wise training would be the way.
nithinvc: I agree, this makes sense to me.
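The append-only bookkeeping discussed above can be sketched with illustrative token ids (the helper names below are invented for the example): generated tokens get mask 1, and re-rendered observation tokens get mask 0.

```python
# Append-only loss-mask bookkeeping for a standard (non-thinking) template.
response_ids, loss_mask = [], []

def track_generation(gen_ids):
    # Assistant tokens are trained on: mask 1.
    response_ids.extend(gen_ids)
    loss_mask.extend([1] * len(gen_ids))

def track_observation(obs_ids):
    # Observation tokens sliced from the re-render are masked out: mask 0.
    response_ids.extend(obs_ids)
    loss_mask.extend([0] * len(obs_ids))

track_generation([11, 12])      # assistant turn 1
track_observation([21])         # env observation
track_generation([31, 32, 33])  # assistant turn 2
```

For a thinking model this breaks down, because earlier assistant turns are re-rendered with thinking stripped, so the flat lists no longer match any single rendered sequence.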
```python
MODEL_NAME = "Qwen/Qwen3-0.6B"
THINKING_PREFIX = "<think>\nmock thinking\n</think>\n\n"
```
SumanthRH: Let's focus the test on non-thinking models, actually; non-standard chat templates will need step-wise training support.
nithinvc: @SumanthRH Updated and removed the thinking tests; everything now focuses on standard tokenizers.
```python
gen_text = engine_output["responses"][0]
gen_ids = engine_output["response_ids"][0]
stop_reason = engine_output["stop_reasons"][0]
gen_logprobs = engine_output["response_logprobs"][0] if engine_output.get("response_logprobs") else None

# 3. Environment step
env_step_output = await self._run_in_executor_if_available(env.step, gen_text)
new_obs = env_step_output["observations"]
step_reward: float = env_step_output["reward"]
done = env_step_output["done"]

# 4. Append assistant message to conversation
conversation.append({"role": "assistant", "content": gen_text})

# 5. Track generated tokens (loss_mask=1)
response_ids.extend(gen_ids)
loss_mask.extend([1] * len(gen_ids))
if rollout_logprobs is not None:
    rollout_logprobs.extend(gen_logprobs if gen_logprobs else [0.0] * len(gen_ids))

if gen_ids:
    per_step_rewards.append((step_reward, len(response_ids) - 1))
```
Automated review comment:
🟡 VLM generator omits EOS token between assistant turns and observations when stop strings cause generation to end
The VLM agent_loop does not replicate the parent class's append_eos_token_after_stop_str_in_multi_turn logic (see skyrl/train/generators/skyrl_gym_generator.py:358-370). When the inference engine stops generation at a stop string (e.g., </answer>, </search>) rather than by generating the EOS token, the parent's agent_loop explicitly appends an EOS token to output_ids. The VLM generator lacks this step, so gen_ids won't end with an EOS token. The subsequent response_ids.extend(gen_ids) at line 171 means the response sequence transitions directly from the generated content into observation tokens (extracted on the next iteration at line 137) without an EOS boundary. This causes the training sequence (prompt_ids + response_ids) to diverge from the rendered conversation used for generation (which contains the chat template's EOS between the assistant message and observation), producing incorrect positional embeddings for all subsequent tokens in the trajectory.
Prompt for agents
The VLM generator's agent_loop is missing the stop-string EOS handling that the parent class implements at skyrl/train/generators/skyrl_gym_generator.py:358-370. After obtaining gen_ids and gen_text from the engine output (around line 156-159), the code should check if stop strings are configured in current_sampling_params, and if the output ends with a stop string without an EOS token, append the EOS token ID to gen_ids. This mirrors the parent class logic that checks self.generator_cfg.append_eos_token_after_stop_str_in_multi_turn. Without this, when stop strings are used (e.g. for search or tool-use VLM environments), the response_ids will be missing EOS tokens at turn boundaries, causing the training sequence to diverge from what the model saw during generation.
```python
conversation.append({"role": "assistant", "content": gen_text})

# 5. Track generated tokens (loss_mask=1)
response_ids.extend(gen_ids)
```
SumanthRH: Can we make sure to append the EOS token to gen_ids in case generation stopped at a stop token, @nithinvc? Use the same append_eos_token_after_stop_str_in_multi_turn flag.
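A sketch of the requested fix, mirroring the parent class's append_eos_token_after_stop_str_in_multi_turn behavior (maybe_append_eos is a hypothetical helper name, not the parent class's actual code):

```python
def maybe_append_eos(gen_ids, gen_text, stop_strings, eos_token_id):
    # If generation ended at a configured stop string rather than by
    # emitting EOS, append the EOS token id so the training sequence
    # matches the re-rendered conversation at the turn boundary.
    if (
        stop_strings
        and any(gen_text.endswith(s) for s in stop_strings)
        and (not gen_ids or gen_ids[-1] != eos_token_id)
    ):
        return gen_ids + [eos_token_id]
    return gen_ids
```

This would run right after unpacking gen_ids/gen_text from the engine output, gated on the config flag.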
Summary
When reviewing, please diff against #1494. Generator implementation for #1493.
Adds:
- SkyRLVLMGymGenerator, a VLM-specific generator subclass that handles multi-modal (text + image) observations from VisGym environments, and extends the base generator types to carry vision tensors through the training pipeline
- pixel_values and image_grid_thw fields on GeneratorOutput and TrajectoryOutput so vision tensors flow through to the trainer
- An extension of SkyRLGymGenerator.generate() to collect and forward vision features when trajectories contain them
- A "render delta" approach in SkyRLVLMGymGenerator: the conversation (OpenAI-format messages with base64 images) is the source of truth, re-tokenized via render_chat_completion at each turn; generated tokens retain their original logprobs, and observation tokens are sliced from the re-render and masked out (loss_mask=0)
- A GPU test with ColorSquareEnv that verifies structural output, multimodal tensors, loss mask correctness, and semantic generation quality
Test plan
```shell
uv run pytest tests/train/generators/test_skyrl_gym_generator.py -v
SKYRL_LOCAL_VLLM=1 uv run --isolated --extra dev --extra fsdp pytest tests/backends/skyrl_train/gpu/gpu_ci/test_skyrl_vlm_gym_generator.py -m vllm -v
```