[WIP] Support GKD and on-policy distillation #112

Open
tastelikefeet wants to merge 5 commits into modelscope:main from tastelikefeet:feat/gkd

Conversation

@tastelikefeet (Collaborator)

PR type

  • Bug Fix
  • New Feature
  • Document Updates
  • More Models or Datasets Support

PR information

Write the detailed information belonging to this PR.

Experiment results

Paste your experiment results here (if needed).

@tastelikefeet changed the title from "Support GKD and on-policy distillation" to "[WIP] Support GKD and on-policy distillation" on Mar 10, 2026
@gemini-code-assist (Contributor)

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers get up to speed quickly.

This pull request significantly expands the framework's capabilities in knowledge distillation by integrating the Generalized Knowledge Distillation (GKD) loss. It provides two distinct strategies for model training: an off-policy approach that leverages existing data, and an on-policy method that dynamically generates teacher responses. These additions enable more versatile and efficient training of smaller student models from larger teacher models, catering to different data availability and exploration needs.

Highlights

  • New Feature: GKD Loss Implementation: Introduced the Generalized Knowledge Distillation (GKD) loss, based on Jensen-Shannon Divergence, which supports various modes including forward KL, reverse KL, and symmetric JSD.
  • New Feature: Off-Policy Distillation: Added a new cookbook example for off-policy GKD distillation, where the student model learns from a teacher's token distribution on pre-existing reference responses from a dataset.
  • New Feature: On-Policy Distillation: Implemented a new cookbook example for on-policy GKD distillation, allowing the teacher vLLM to generate fresh responses dynamically for each prompt, and the student to learn from these on-the-fly completions.
  • Data Preprocessing Enhancements: Added a new GSM8KFullProcessor for handling full-text datasets in distillation tasks and made a minor adjustment to the GSM8KProcessor regex.
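The generalized Jensen-Shannon divergence behind these modes can be sketched as a standalone function. This is a hedged illustration, not the PR's actual GKDLoss (the name `generalized_jsd_loss` and the scalar reduction are assumptions, and the PR's class additionally supports top-k, remote-teacher, and chunked modes); the KL direction follows the convention stated in the review comment below.

```python
import torch
import torch.nn.functional as F

def generalized_jsd_loss(student_logits, teacher_logits, beta=0.5):
    """Generalized JSD between student (S) and teacher (T) token
    distributions: beta=0 gives forward KL(S||T), beta=1 gives
    reverse KL(T||S), beta=0.5 gives symmetric JSD.
    Logits have shape (..., vocab_size)."""
    s_log_probs = F.log_softmax(student_logits, dim=-1)
    t_log_probs = F.log_softmax(teacher_logits, dim=-1)
    if beta == 0:
        # Forward KL(S || T): with log_target=True, F.kl_div(input, target)
        # computes KL(target || input), so the student goes in the *target* slot.
        jsd = F.kl_div(t_log_probs, s_log_probs, reduction='none', log_target=True)
    elif beta == 1:
        # Reverse KL(T || S)
        jsd = F.kl_div(s_log_probs, t_log_probs, reduction='none', log_target=True)
    else:
        # Mixture M = beta * S + (1 - beta) * T, formed in log space
        beta_t = torch.tensor(beta, dtype=s_log_probs.dtype)
        m_log = torch.logsumexp(
            torch.stack([s_log_probs + torch.log(beta_t),
                         t_log_probs + torch.log1p(-beta_t)]), dim=0)
        # beta * KL(S || M) + (1 - beta) * KL(T || M)
        jsd = (beta * F.kl_div(m_log, s_log_probs, reduction='none', log_target=True)
               + (1 - beta) * F.kl_div(m_log, t_log_probs, reduction='none', log_target=True))
    # sum over the vocabulary, average over tokens
    return jsd.sum(-1).mean()
```

At beta=0.5 the loss is symmetric in student and teacher, which is the usual JSD; the endpoints are special-cased because the mixture formula does not reduce to the pure KLs exactly.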

Changelog
  • cookbook/rl/gkd_off_policy.py
    • Added a new script for off-policy GKD distillation.
  • cookbook/rl/gkd_on_policy.py
    • Added a new script for on-policy GKD distillation.
  • src/twinkle/loss/__init__.py
    • Imported the new GKDLoss.
    • Registered GKDLoss in the loss factory.
  • src/twinkle/loss/gkd.py
    • Added the GKDLoss class, implementing generalized Jensen-Shannon Divergence for knowledge distillation.
    • Included support for full-vocabulary, top-k local teacher, and remote API teacher modes.
    • Implemented chunked loss accumulation to manage GPU memory efficiently.
  • src/twinkle/preprocessor/__init__.py
    • Imported GSM8KFullProcessor and GSM8KProcessor.
  • src/twinkle/preprocessor/llm.py
    • Adjusted the regex in GSM8KProcessor to handle decimal points in ground truth extraction.
    • Added GSM8KFullProcessor to include reference answers as assistant messages for full-text datasets.
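The chunked loss accumulation mentioned for src/twinkle/loss/gkd.py can be illustrated with a simplified forward-KL variant. The helper name `chunked_kl_loss` is hypothetical; the PR's GKDLoss applies the same pattern to the generalized JSD so that the full (tokens x vocab) log-prob tensors are never materialized at once.

```python
import torch
import torch.nn.functional as F

def chunked_kl_loss(student_logits, teacher_logits, chunk_size=1024):
    """Accumulate a token-level forward KL(S || T) over sequence chunks.
    Logits have shape (n_tokens, vocab_size); only one chunk of
    log-probabilities lives in memory at a time."""
    n_tokens = student_logits.shape[0]
    total = student_logits.new_zeros(())
    for start in range(0, n_tokens, chunk_size):
        s = F.log_softmax(student_logits[start:start + chunk_size], dim=-1)
        t = F.log_softmax(teacher_logits[start:start + chunk_size], dim=-1)
        # forward KL(S || T) for this chunk, summed over vocab and tokens;
        # with log_target=True, F.kl_div(input, target) = KL(target || input)
        total = total + F.kl_div(t, s, reduction='sum', log_target=True)
    return total / n_tokens
```

Because softmax is applied per token, chunking over the token dimension changes memory use but not the result.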
Activity
  • The pull request is currently a Work In Progress (WIP) by the author, tastelikefeet.


@gemini-code-assist (bot) left a comment

Code Review

This pull request introduces support for Generalized Knowledge Distillation (GKD) with both on-policy and off-policy strategies. The changes include a new GKDLoss implementation, corresponding preprocessors for the GSM8K dataset, and two example scripts in the cookbook demonstrating how to use these new features. The code is well-structured and the examples are clear. I've found one critical issue in the GKDLoss implementation where the forward and reverse KL-divergence calculations for the special cases of beta=0 and beta=1 are swapped. I've provided a suggestion to correct this. Other than that, the changes look good.

Comment on lines +211 to +216
if beta == 0:
    # Forward KL: KL(S || T)
    jsd_chunk = F.kl_div(s_log_probs, t_log_probs, reduction='none', log_target=True)
elif beta == 1:
    # Reverse KL: KL(T || S)
    jsd_chunk = F.kl_div(t_log_probs, s_log_probs, reduction='none', log_target=True)

critical

The implementation for the special cases of beta=0 (forward KL) and beta=1 (reverse KL) appears to be swapped.

According to the PyTorch documentation for torch.nn.functional.kl_div, when log_target=True, F.kl_div(input, target, ...) computes KL(target || input).

  • For beta = 0, which corresponds to forward KL divergence KL(S || T), the calculation should be KL(student || teacher). The code currently has F.kl_div(s_log_probs, t_log_probs, ...), which computes KL(teacher || student).
  • For beta = 1, which corresponds to reverse KL divergence KL(T || S), the calculation should be KL(teacher || student). The code currently has F.kl_div(t_log_probs, s_log_probs, ...), which computes KL(student || teacher).

This swap will lead to incorrect loss calculations for these important special cases.

Suggested change

-if beta == 0:
-    # Forward KL: KL(S || T)
-    jsd_chunk = F.kl_div(s_log_probs, t_log_probs, reduction='none', log_target=True)
-elif beta == 1:
-    # Reverse KL: KL(T || S)
-    jsd_chunk = F.kl_div(t_log_probs, s_log_probs, reduction='none', log_target=True)
+if beta == 0:
+    # Forward KL: KL(S || T)
+    jsd_chunk = F.kl_div(t_log_probs, s_log_probs, reduction='none', log_target=True)
+elif beta == 1:
+    # Reverse KL: KL(T || S)
+    jsd_chunk = F.kl_div(s_log_probs, t_log_probs, reduction='none', log_target=True)
