Support megatron tokenization for post training datasets#1018
kevalmorabia97 wants to merge 1 commit into main
Conversation
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
📝 Walkthrough

Enhances megatron preprocessing by adding chat-template rendering for list-typed JSON keys, implementing runtime logging for document and sequence truncation events, reorganizing length accounting logic to reflect post-processing state, and introducing per-key tracking to prevent duplicate template log messages.
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
🚥 Pre-merge checks: ✅ 3 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 1
🧹 Nitpick comments (2)
modelopt/torch/utils/plugins/megatron_preprocess_data.py (2)
119-131: Consider handling unexpected value types gracefully.

The code assumes value is either a list (chat format) or a string (raw text). If a JSON key contains an unexpected type (e.g., dict, int, None), line 131 assigns it directly to text, which will fail later at _Encoder.tokenizer.encode(text) with an unclear error. Consider adding explicit type validation:
🛡️ Suggested defensive check
```diff
     if isinstance(value, list):
         if key not in _Encoder._chat_template_logged:
             _Encoder._chat_template_logged.add(key)
             print(f"Applying chat_template to '{key}' key")
         kwargs = {}
         tools = data.get("tools")
         if tools:
             kwargs["tools"] = tools
         text = _Encoder.tokenizer.apply_chat_template(value, tokenize=False, **kwargs)
-    else:
+    elif isinstance(value, str):
         text = value
+    else:
+        raise TypeError(f"Expected list or str for key '{key}', got {type(value).__name__}")
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/utils/plugins/megatron_preprocess_data.py` around lines 119 - 131, The code assumes value is list or string and assigns non-string types directly to text, causing failures later in _Encoder.tokenizer.encode(text); update the handling in _Encoder (in the block around tokenizer.apply_chat_template) to validate value types: if value is list handle as before, if value is str use directly, if value is None/other primitive convert to string (e.g., str(value)) or skip/log a warning (use _Encoder._chat_template_logged or a logger) and ensure text is always a string before calling _Encoder.tokenizer.encode; add a clear warning or raise a controlled TypeError for unexpected dict/complex types to make failures explicit.
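A self-contained sketch of the dispatch this comment recommends. The StubTokenizer below is a hypothetical stand-in for the real Hugging Face tokenizer used by _Encoder; only the list/str/other branching is the point.

```python
# Sketch of the list-vs-str dispatch suggested above. StubTokenizer is a
# hypothetical stand-in for the real tokenizer; its apply_chat_template
# mimics the real interface only loosely.
class StubTokenizer:
    def apply_chat_template(self, messages, tokenize=False, **kwargs):
        # Real tokenizers render a Jinja chat template; this just joins turns.
        return "\n".join(f"{m['role']}: {m['content']}" for m in messages)

def to_text(tokenizer, key, value, data):
    """Normalize one JSON value to a string before tokenization."""
    if isinstance(value, list):  # chat format: list of {role, content} dicts
        kwargs = {}
        tools = data.get("tools")
        if tools:
            kwargs["tools"] = tools
        return tokenizer.apply_chat_template(value, tokenize=False, **kwargs)
    if isinstance(value, str):  # raw pretraining text
        return value
    raise TypeError(f"Expected list or str for key '{key}', got {type(value).__name__}")

tok = StubTokenizer()
print(to_text(tok, "messages", [{"role": "user", "content": "hi"}], {}))  # user: hi
print(to_text(tok, "text", "plain document", {}))  # plain document
```

With this shape, a dict or None value fails fast with an explicit message instead of surfacing later inside tokenizer.encode.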
135-146: Inconsistent truncation logging and potential noise.

Document truncation is logged (lines 137-138), but sequence truncation logging is commented out (line 146). Additionally, the commented-out code references original_length, which would be undefined in that context since it was overwritten at line 135.

For large datasets with many long documents, the per-document truncation logging could produce excessive output. Consider:
- Making truncation logging consistent (either log both or neither)
- Using a counter and logging summary statistics at the end, rather than per-document
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/utils/plugins/megatron_preprocess_data.py` around lines 135 - 146, The code currently prints per-document truncation and has a commented-out, buggy sequence-truncation log that references original_length incorrectly; instead add counters (e.g., num_doc_truncated, num_seq_truncated) and track original character length and original token length (capture original_token_length = len(encoded) before truncating to self.max_sequence_length) inside the block using _Encoder.tokenizer.encode and the document trimming to self.max_document_length, increment the appropriate counter when truncation occurs, remove per-document print calls, and emit a single summary log at the end of preprocessing (including counts and maybe totals like total_docs and total_truncated_tokens) so truncation reporting is consistent and not noisy.
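A minimal sketch of the counter-based summary this comment proposes. The parameter names (max_doc_len, max_seq_len) follow the review's wording and are assumptions about the real encoder, not its actual fields.

```python
class TruncationStats:
    """Accumulate truncation counts and report once, instead of per-document prints."""

    def __init__(self):
        self.total_docs = 0
        self.num_doc_truncated = 0   # character-level document trimming
        self.num_seq_truncated = 0   # token-level sequence trimming

    def record(self, char_len, max_doc_len, token_len, max_seq_len):
        # Capture lengths *before* trimming so comparisons reflect originals.
        self.total_docs += 1
        if max_doc_len is not None and char_len > max_doc_len:
            self.num_doc_truncated += 1
        if max_seq_len is not None and token_len > max_seq_len:
            self.num_seq_truncated += 1

    def summary(self):
        return (
            f"Truncated {self.num_doc_truncated}/{self.total_docs} documents "
            f"and {self.num_seq_truncated}/{self.total_docs} sequences"
        )

stats = TruncationStats()
stats.record(char_len=5000, max_doc_len=4096, token_len=900, max_seq_len=1024)
stats.record(char_len=100, max_doc_len=4096, token_len=2048, max_seq_len=1024)
print(stats.summary())  # Truncated 1/2 documents and 1/2 sequences
```

Emitting summary() once at the end of preprocessing keeps the report consistent for both truncation kinds without flooding stdout on large corpora.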
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/torch/utils/plugins/megatron_preprocess_data.py`:
- Line 90: The class-level _chat_template_logged set cannot deduplicate logs
across worker processes because each multiprocessing.Pool worker has its own
copy; either move the "Applying chat_template..." logging out of worker code
into the parent process before dispatching tasks (so logging/deduplication is
done once), or replace the in-process set with a process-shared structure
created from multiprocessing.Manager() (e.g., a Manager().dict or list used as a
set) and use that shared object where _chat_template_logged is referenced so all
Pool workers see the same dedupe state; update references to
_chat_template_logged and the code that writes the log accordingly.
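A sketch of the Manager-backed alternative described above, assuming the dedupe is best-effort: the check-then-set below is not atomic across workers, so a rare duplicate log line is possible but harmless.

```python
import multiprocessing as mp

def log_template_once(logged, key):
    """Log 'Applying chat_template' at most ~once per key across Pool workers.

    `logged` is a Manager().dict() shared by all workers, replacing the
    per-process class-level set that cannot deduplicate across processes.
    """
    if key not in logged:
        logged[key] = True
        print(f"Applying chat_template to '{key}' key")
        return True
    return False

if __name__ == "__main__":
    with mp.Manager() as manager:
        logged = manager.dict()
        log_template_once(logged, "messages")  # logs
        log_template_once(logged, "messages")  # deduped, even from other workers
```

The Manager dict would be created in the parent before the Pool is spawned and passed to workers (or via an initializer), so every worker consults the same state.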
📒 Files selected for processing (1)
modelopt/torch/utils/plugins/megatron_preprocess_data.py
Codecov Report

✅ All modified and coverable lines are covered by tests.

Additional details and impacted files:

```
@@            Coverage Diff             @@
##             main    #1018      +/-   ##
==========================================
- Coverage   70.11%   70.09%   -0.03%
==========================================
  Files         220      220
  Lines       25240    25240
==========================================
- Hits        17698    17692       -6
- Misses       7542     7548       +6
```

☔ View full report in Codecov by Sentry.
Can we add a test in
Do you know any small dataset we can use for a test here?
The test_sft split of this dataset is very small: https://huggingface.co/datasets/HuggingFaceTB/everyday-conversations-llama3.1-2k/viewer/default/test_sft
What does this PR do?
Update megatron_preprocess_data.py to support applying a chat template when tokenizing chat-based post-training datasets.
Usage
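An illustrative sketch of the two record shapes the updated preprocessor distinguishes. The key names "messages", "text", and "tools" are assumptions drawn from the review comments, not the confirmed dataset schema.

```python
import json

# Chat-format record: a list-typed key triggers tokenizer.apply_chat_template.
# Key names here ("messages", "tools") are assumed, not the confirmed schema.
chat_record = {
    "messages": [
        {"role": "user", "content": "What is 2 + 2?"},
        {"role": "assistant", "content": "4"},
    ],
    "tools": [],  # optional; forwarded to apply_chat_template when present
}
# Raw-text record: string values are tokenized directly, as before this PR.
raw_record = {"text": "Plain pretraining text."}

json_keys = ["messages", "text"]  # keys selected for tokenization (assumed)
for line in (json.dumps(chat_record), json.dumps(raw_record)):
    data = json.loads(line)
    for key in json_keys:
        if key not in data:
            continue
        mode = "chat template" if isinstance(data[key], list) else "raw text"
        print(f"{key} -> {mode}")
```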
# Add a code snippet demonstrating how to use this

Testing
Before your PR is "Ready for review"
- Make sure you read and follow the Contributor guidelines and your commits are signed (git commit -s -S).
- Make sure you read and follow the Security Best Practices (e.g., avoiding hardcoded trust_remote_code=True, torch.load(..., weights_only=False), pickle, etc.).
- CONTRIBUTING.md: N/A