Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 38 additions & 17 deletions cookbook/transformers/ep_fsdp_qwen3_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,23 @@
TEMPLATE_ID = os.environ.get('TEMPLATE_ID', 'Template')
_num_layers_env = os.environ.get('NUM_LAYERS')
NUM_LAYERS = int(_num_layers_env) if _num_layers_env is not None else None
BATCH_SIZE = int(os.environ.get('BATCH_SIZE', '4'))
GRAD_ACCUM_STEPS = int(os.environ.get('GRAD_ACCUM_STEPS', '4'))
LR = float(os.environ.get('LR', '1e-5'))
MAX_GRAD_NORM = float(os.environ.get('MAX_GRAD_NORM', '1.0'))
KEEP_ROUTER_LOGITS = os.environ.get('KEEP_ROUTER_LOGITS', '0') == '1'

# 4 gpus, dp=2, ep=2
dp_size = 2
ep_size = 2
# 8 GPUs: dp=1, fsdp=8 (fully sharded data parallel), ep_size=8 (expert parallel)
# The main mesh does NOT include an 'ep' dimension; EP is handled by a separate ep_fsdp_device_mesh.
dp_size = 1
fsdp_size = 8
ep_size = 8

device_mesh = DeviceMesh(
device_type=Platform.get_platform().device_prefix(),
mesh=np.arange(dp_size * ep_size).reshape(dp_size, ep_size),
mesh_dim_names=('dp', 'ep'),
mesh=np.arange(fsdp_size * dp_size).reshape(fsdp_size, dp_size),
mesh_dim_names=('fsdp', 'dp'),
ep_size=ep_size,
)

twinkle.initialize(
Expand All @@ -41,7 +49,7 @@ def train():
if hasattr(config, 'use_cache'):
config.use_cache = False

dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(1000)))
dataset = Dataset(dataset_meta=DatasetMeta(DATASET_ID, data_slice=range(1000)))
try:
dataset.set_template(TEMPLATE_ID, model_id=MODEL_ID)
except ValueError:
Expand All @@ -51,11 +59,10 @@ def train():
dataset.encode(batched=True)
dataloader = DataLoader(
dataset=dataset,
batch_size=4,
batch_size=BATCH_SIZE,
device_mesh=device_mesh,
)

grad_accum_steps = 4
model = TransformersModel(
model_id=MODEL_ID,
config=config,
Expand All @@ -64,29 +71,43 @@ def train():
'expert_parallel': {
'enabled': True,
'router_dtype': 'fp32',
'all_to_all': 'torch',
'keep_router_logits': False,
'keep_router_logits': KEEP_ROUTER_LOGITS,
}
},
)
# Disable foreach to avoid DTensor mixed-type errors in EP runs.
model.set_optimizer('AdamW', foreach=False)
model.set_optimizer('AdamW', lr=LR, foreach=False)
model.set_lr_scheduler(
scheduler_cls='CosineWarmupScheduler',
num_warmup_steps=5,
num_training_steps=len(dataloader),
)

logger.info(get_device_placement())
logger.info(model.get_train_configs())
logger.info(
f'Total steps: {len(dataloader)}, batch_size={BATCH_SIZE}, grad_accum={GRAD_ACCUM_STEPS}, '
f'lr={LR:.2e}, max_grad_norm={MAX_GRAD_NORM}, '
f'keep_router_logits={KEEP_ROUTER_LOGITS}')

for step, batch in enumerate(dataloader):
if callable(batch):
batch = batch()
model.forward_backward(inputs=batch, gradient_accumulation_steps=grad_accum_steps)
model.clip_grad_and_step(gradient_accumulation_steps=grad_accum_steps)
if step % grad_accum_steps == 0:
model.forward_backward(inputs=batch, gradient_accumulation_steps=GRAD_ACCUM_STEPS)
model.clip_grad_and_step(
max_grad_norm=MAX_GRAD_NORM,
gradient_accumulation_steps=GRAD_ACCUM_STEPS,
)

is_sync_step = ((step + 1) % GRAD_ACCUM_STEPS == 0)
if is_sync_step:
optimizer_step = (step + 1) // GRAD_ACCUM_STEPS
metric = model.calculate_metric(is_training=True)
if callable(metric):
metric = metric()
logger.info(f'Current is step {step // grad_accum_steps}, metric: {metric}')
if step > 0 and step % 50 == 0:
model.save('./output')
logger.info(f'Current optimizer_step {optimizer_step}, metric: {metric}')
if optimizer_step > 0 and optimizer_step % 50 == 0:
model.save(name=f'checkpoint-step-{optimizer_step}', output_dir='./output')


if __name__ == '__main__':
Expand Down
4 changes: 2 additions & 2 deletions src/twinkle/model/transformers/moe/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
from .expert_parallel import apply_expert_parallel
from .expert_parallel import ExpertShardingSpec, apply_expert_parallel

__all__ = ['apply_expert_parallel']
__all__ = ['ExpertShardingSpec', 'apply_expert_parallel']
Loading
Loading