#!/usr/bin/env python3
"""
Generate fake 128K-long-sequence data for validating Qwen3_5 long-sequence training.

Writes NUM_SAMPLES samples in the MLLM "messages" format (one image, a short
user prompt and a very long assistant reply) to OUTPUT_FILE as JSON.
"""
import json
import os

# Configuration
TARGET_SEQ_LEN = 131072  # 128K
OUTPUT_DIR = "/home/data/datasets/coco"
OUTPUT_FILE = os.path.join(OUTPUT_DIR, "mllm_format_long_seq_128k.json")
IMAGE_PATH = "/home/data/datasets/coco/COCO2017/train2017/000000033471.jpg"

# Build a long text of roughly TARGET_SEQ_LEN characters.
# Rough heuristic: one Chinese character is about one token (an English word is
# about 1.3 tokens), so TARGET_SEQ_LEN characters land near the target length.
TEXT_UNIT = "这是一段用于测试长序列的文本。"
# Ceiling division: the smallest repeat count whose text has >= TARGET_SEQ_LEN
# characters. (A fixed "* 10000" produced 150,000 characters, not ~128K.)
REPEAT_TIMES = -(-TARGET_SEQ_LEN // len(TEXT_UNIT))
LONG_TEXT = TEXT_UNIT * REPEAT_TIMES

# Generate several over-long samples (at least 2, because training runs on
# 2 devices with drop_last=true).
NUM_SAMPLES = 4

data = [
    {
        "images": [IMAGE_PATH],
        "messages": [
            {
                "role": "user",
                "content": f"<image>\n请详细描述这张图片的内容,这是第{i+1}条测试数据。",
            },
            {
                "role": "assistant",
                "content": LONG_TEXT,
            },
        ],
    }
    for i in range(NUM_SAMPLES)
]

# Make sure the output directory exists.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Write the samples as UTF-8 JSON (ensure_ascii=False keeps the Chinese text readable).
with open(OUTPUT_FILE, "w", encoding="utf-8") as f:
    json.dump(data, f, ensure_ascii=False, indent=2)

print(f"已生成假数据文件: {OUTPUT_FILE}")
print(f"文本长度: {len(LONG_TEXT)} 字符")
print(f"目标序列长度: {TARGET_SEQ_LEN}")
print(f"\n请确保配置文件中的 cutoff_len 设置为 {TARGET_SEQ_LEN} 或更大")
# pylint: skip-file
import logging
import os
from datetime import datetime
import torch
from mindspeed.fsdp.utils.log import print_rank
from mindspeed_mm.fsdp.utils.dtype import get_dtype
from mindspeed_mm.fsdp.distributed.fully_shard_parallel import pregather_fsdp_params
from mindspeed_mm.fsdp.distributed.parallel_state import get_parallel_state
from mindspeed_mm.fsdp.utils.utils import move_to_device, get_time, configure_hsdp_gradient_sync, tensor_to_dtensor
from mindspeed_mm.fsdp.data.data_utils.utils import build_iterations
from mindspeed_mm.fsdp.optimizer.clip_grad_norm import clip_grad_norm
from mindspeed_mm.fsdp.tools.profiler import Profiler
from mindspeed_mm.fsdp.tools.memory_profiler import memory_profiler
from mindspeed_mm.fsdp.loss.loss_func import build_loss_func
from mindspeed_mm.fsdp.params.argument import Arguments
from mindspeed_mm.fsdp.utils.lora_utils import load_state_dict
from mindspeed_mm.fsdp.data.dataloader.dataloader import Preloader
# Module-level logger; TrainEngine passes its methods (logger.info / logger.warning)
# to print_rank as the log function.
logger = logging.getLogger(__name__)
class TrainEngine:
    """Training engine that manages the main training loop and operations.

    Responsibilities:
      * optionally re-apply pretrained LoRA weights and resume from a checkpoint
        found in ``args.training.load``;
      * run the gradient-accumulated training loop (``train``);
      * log progress and save checkpoints, either as a full checkpoint through
        ``checkpointer`` or, in LoRA mode, as LoRA adapter weights plus
        lightweight per-rank training-state files.
    """

    def __init__(
        self,
        args: Arguments,
        train_dataloader,
        model,
        optimizer,
        scheduler,
        checkpointer,
        lora_weight_manager=None,
        **kwargs,
    ):
        """Store the training components, restore state if configured and start the profiler.

        Args:
            args: Parsed training arguments.
            train_dataloader: Training dataloader, may be ``None``. Its
                ``state_dict``/``load_state_dict`` are used for checkpointing.
            model: Model to train.
            optimizer: Optimizer over the model parameters.
            scheduler: Learning-rate scheduler.
            checkpointer: Object providing ``load``/``save`` for full checkpoints.
            lora_weight_manager: Optional helper used to save/load LoRA adapter weights.
            **kwargs: Accepted and ignored.
        """
        self.args = args
        self.model = model
        self.train_dataloader = train_dataloader
        self.optimizer = optimizer
        self.lr_scheduler = scheduler
        self.checkpointer = checkpointer
        self.lora_weight_manager = lora_weight_manager
        # Training state tracking
        self.iteration, self.consumed_train_samples = 0, 0

        # Re-apply pretrained LoRA weights (if configured) BEFORE resuming, so that
        # weights restored from a checkpoint take precedence over the initial adapter.
        if (
            args.training.init_model_with_meta_device
            and args.training.lora.enable
            and args.training.lora.pretrained_lora_path
        ):
            self._load_pretrained_lora_weights(args.training.lora.pretrained_lora_path)

        # Resume from checkpoint if specified
        if args.training.load:
            self._resume_from_checkpoint(args.training.load)

        self.profiler = Profiler(args.tools.profile)
        self.profiler.start()

    @staticmethod
    def _get_rank():
        """Return this process's global rank, or 0 when torch.distributed is not initialized."""
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            return torch.distributed.get_rank()
        return 0

    def _training_state_path(self, root, iteration):
        """Return this rank's LoRA training-state file path for ``iteration`` under ``root``.

        The file is per rank because optimizer state, RNG state and dataloader state
        can differ between ranks, and concurrent writes to one shared path corrupt it.
        """
        return os.path.join(root, f"training_state_{iteration}_rank{self._get_rank()}.pt")

    def _find_training_state(self, load_dir, iteration):
        """Return the LoRA training-state file to resume from, or ``None`` if there is none.

        Prefers this rank's file written by ``save``; falls back to the legacy single
        shared ``training_state_{iteration}.pt`` written by older versions.
        """
        candidates = (
            self._training_state_path(load_dir, iteration),
            os.path.join(load_dir, f"training_state_{iteration}.pt"),
        )
        for candidate in candidates:
            if os.path.exists(candidate):
                return candidate
        return None

    def _load_pretrained_lora_weights(self, lora_path):
        """Copy the tensors in ``lora_path`` into matching entries of the model state dict.

        Keys absent from the model are skipped. Targets that expose ``device_mesh`` and
        ``placements`` (distributed tensors) receive the value converted via
        ``tensor_to_dtensor``; other targets are copied directly.
        """
        lora_state_dict = load_state_dict(lora_path)
        model_state_dict = self.model.state_dict()
        for key, value in lora_state_dict.items():
            if key not in model_state_dict:
                continue
            target_tensor = model_state_dict[key]
            device_mesh = getattr(target_tensor, "device_mesh", None)
            placements = getattr(target_tensor, "placements", None)
            if device_mesh is not None and placements is not None:
                target_tensor.copy_(tensor_to_dtensor(value, device_mesh, placements))
            else:
                target_tensor.copy_(value)
        print_rank(
            logger.info,
            f"Reloaded {len(lora_state_dict)} LoRA parameters from {lora_path}",
        )

    def _resume_from_checkpoint(self, load_dir):
        """Resume training state from ``load_dir`` using its iteration tracker file.

        Uses the lightweight LoRA training-state file when one exists for the tracked
        iteration, otherwise falls back to a full checkpoint via ``load``. Failures are
        logged as warnings and training continues from the state reached so far.
        """
        tracker_file = os.path.join(load_dir, 'latest_checkpointed_iteration.txt')
        if not os.path.exists(tracker_file):
            print_rank(logger.warning, f"Checkpoint tracker file not found at {load_dir}, starting from scratch.")
            return
        try:
            with open(tracker_file, 'r') as f:
                iteration = int(f.read().strip())
            state_path = self._find_training_state(load_dir, iteration)
            if state_path is not None:
                # LoRA mode: load training state and LoRA weights separately
                self._load_lora_training_state(load_dir, iteration, state_path)
            else:
                # Fallback to full DCP checkpoint loading
                self.iteration, self.consumed_train_samples = self.load()
        except Exception as e:
            # Best-effort resume: keep training from whatever state was restored.
            print_rank(logger.warning, f"Failed to load checkpoint: {e}")

    def _load_lora_training_state(self, load_dir, iteration, state_path):
        """Restore counters, LoRA weights, optimizer, scheduler, dataloader and RNG state."""
        args = self.args
        # NOTE(review): torch.load unpickles the file; only resume from trusted checkpoints.
        training_state = torch.load(state_path, map_location="cpu")
        self.iteration = training_state["iteration"]
        self.consumed_train_samples = training_state["consumed_train_samples"]
        # Load LoRA weights
        lora_path = os.path.join(load_dir, f"lora_adapter_iteration_{iteration}.safetensors")
        if self.lora_weight_manager is not None:
            if os.path.exists(lora_path):
                self.lora_weight_manager.load_lora_weights(lora_path)
            else:
                print_rank(logger.warning, f"LoRA adapter weights not found at {lora_path}, resuming without them.")
        # Load optimizer and scheduler state
        if not args.training.no_load_optim and "optimizer" in training_state:
            self.optimizer.load_state_dict(training_state["optimizer"])
        if "lr_scheduler" in training_state:
            self.lr_scheduler.load_state_dict(training_state["lr_scheduler"])
        if self.train_dataloader is not None and "train_dataloader" in training_state:
            self.train_dataloader.load_state_dict(training_state["train_dataloader"])
        if not args.training.no_load_rng and "torch_rng_state" in training_state:
            torch.set_rng_state(training_state["torch_rng_state"])
        print_rank(logger.info, f"Loaded LoRA training state from {state_path}")

    def average_losses_across_data_parallel_group(self, losses):
        """Reduce a tensor of losses across all GPUs.

        Concatenates the (detached) losses into one 1-D tensor, sums it over the
        data-parallel group and divides by that group's world size.
        """
        ps = get_parallel_state()
        averaged_losses = torch.cat([loss.clone().detach().view(1) for loss in losses])
        torch.distributed.all_reduce(averaged_losses, group=ps.get_dp_group())
        averaged_losses = averaged_losses / torch.distributed.get_world_size(group=ps.get_dp_group())
        return averaged_losses

    def get_batch(self, data_iterator):
        """Fetch the next batch from ``data_iterator``.

        When preloading is disabled, the batch is moved to the device here (cast to the
        FSDP param dtype if one is configured).

        Raises:
            ValueError: if ``data_iterator`` is ``None``.
        """
        if data_iterator is None:
            raise ValueError("Data iterator is None. Unable to retrieve batch.")
        batch = next(data_iterator)
        # Move input to device and cast precision
        if not self.args.data.dataloader_param.enable_preload:
            param_dtype = self.args.parallel.fsdp_plan.param_dtype
            batch = move_to_device(batch, get_dtype(param_dtype) if param_dtype else None)
        return batch

    def set_loss_func(self, batch_data):
        """Install the configured loss function on the model for this batch.

        Does nothing for ``loss_type == "raw"``. Otherwise this may add
        ``total_chunk_size`` and ``output_router_logits`` to ``batch_data`` in place.
        """
        args = self.args
        if args.features.loss_cfg.loss_type == "raw":
            # NOTE(review): this early return also skips the output_router_logits flag
            # below; confirm raw-loss MoE runs do not need it.
            return
        chunk_size = args.features.chunkloss_plan.chunk_size if args.features.enable_chunk_loss else None
        if args.features.enable_dynamic_chunk_loss:
            batch_data['total_chunk_size'] = args.features.chunkloss_plan.total_chunk_size
        loss_func = build_loss_func(args.features.loss_cfg.loss_type, chunk_size=chunk_size, **batch_data)
        # Plain attribute assignment both overwrites and creates the attribute.
        self.model.loss_function = loss_func
        if args.features.loss_cfg.router_aux_loss_coef > 0.0:
            batch_data.update(output_router_logits=True)

    def train_step(self, train_dataloader_iter):
        """Perform a single training step with gradient accumulation.

        Returns:
            ``(total_loss, total_aux_loss)``. ``total_loss`` is the accumulation-scaled
            loss summed over micro-steps and averaged across the data-parallel group
            (a 1-element tensor). ``total_aux_loss`` is the rank-local, accumulation-scaled
            sum of ``output.aux_loss``, or ``None`` if the model reports none.
        """
        args = self.args
        accum_steps = args.training.gradient_accumulation_steps
        total_loss = 0
        total_aux_loss = None
        for step in range(accum_steps):
            # Wait for the (possibly preloaded) batch to be ready
            batch_data = self.get_batch(train_dataloader_iter)
            # setup loss ctx
            self.set_loss_func(batch_data)
            # Presumably enables HSDP gradient synchronization only on the last micro-step.
            is_last_step = step == accum_steps - 1
            configure_hsdp_gradient_sync(self.model, is_last_step)
            # forward step
            output = self.model(**batch_data, use_cache=False)
            loss = output.loss / accum_steps
            # Backward
            loss.backward()
            # The running sums are only used for logging; accumulating detached values
            # avoids keeping every micro-step's autograd graph referenced.
            total_loss += loss.detach()
            if getattr(output, 'aux_loss', None) is not None:
                aux_loss = output.aux_loss.detach() / accum_steps
                total_aux_loss = aux_loss if total_aux_loss is None else total_aux_loss + aux_loss
        # Average loss across data parallel group
        total_loss = self.average_losses_across_data_parallel_group([total_loss])
        return total_loss, total_aux_loss

    def train(self):
        """Main training loop.

        Runs until ``self.iteration`` reaches ``args.training.train_iters``. Each
        iteration runs one accumulated step, clips gradients, steps the optimizer and
        LR scheduler, logs every ``log_interval`` iterations and saves every
        ``save_interval`` iterations. When ``save`` is set, it also saves once after the loop.
        """
        args = self.args
        # Get data iterator
        train_dataloader_iter, _, _ = build_iterations(self.train_dataloader)
        param_dtype = get_dtype(args.parallel.fsdp_plan.param_dtype) if args.parallel.fsdp_plan.param_dtype else None
        # get_batch skips move_to_device when preloading is enabled, so the Preloader
        # receives the param dtype instead.
        if args.data.dataloader_param.enable_preload:
            train_dataloader_iter = Preloader(train_dataloader_iter, param_dtype=param_dtype)
        self.model.train()

        # --- Train Loop ---
        # LR used by the upcoming step (read before the scheduler advances).
        curr_step_lr = self.lr_scheduler.get_last_lr()[0]
        while self.iteration < args.training.train_iters:
            # Record memory usage if enabled
            memory_profiler.step()
            start_time = get_time(barrier=True)
            if self.args.parallel.fsdp_plan.pregather:
                pregather_fsdp_params(self.model)
            loss, aux_loss = self.train_step(train_dataloader_iter)
            # Clip gradients when clip_grad>0 and get total grad_norm
            grad_norm = clip_grad_norm(
                self.model, max_norm=args.training.clip_grad, foreach=args.training.clip_grad_foreach
            )
            # Update parameters
            self.optimizer.step()
            self.lr_scheduler.step()
            self.optimizer.zero_grad()
            # Update training state
            self.consumed_train_samples += args.training.global_batch_size
            self.iteration += 1
            # Calculate iteration time
            elapsed_time_per_iteration = get_time(barrier=True) - start_time
            # Advance the profiler schedule by one step
            self.profiler.step()
            # Logging
            if self.iteration % args.training.log_interval == 0:
                self.training_log(
                    self.iteration,
                    elapsed_time_per_iteration,
                    curr_step_lr,
                    self.consumed_train_samples,
                    loss,
                    aux_loss,
                    grad_norm,
                )
            curr_step_lr = self.lr_scheduler.get_last_lr()[0]
            # Save checkpoint at specified intervals
            if (
                args.training.save
                and args.training.save_interval > 0
                and self.iteration % args.training.save_interval == 0
            ):
                self.save(self.iteration, self.consumed_train_samples)

        # Stop profiling if enabled
        self.profiler.stop()
        memory_profiler.stop()
        # Final save after training completes
        if args.training.save:
            self.save(self.iteration, self.consumed_train_samples)

    def training_log(
        self, iteration, elapsed_time_per_iteration, curr_step_lr, consumed_train_samples, loss, aux_loss, grad_norm
    ):
        """Format one progress line (timestamp, counters, timing, LR, losses, grad norm) and log it via print_rank."""
        args = self.args
        log_string = f" [{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}]"
        log_string += ' iteration {:8d}/{:8d} |'.format(iteration, args.training.train_iters)
        log_string += ' consumed samples: {:12d} |'.format(consumed_train_samples)
        log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(elapsed_time_per_iteration * 1000.0)
        log_string += ' learning rate: {:.6E} |'.format(curr_step_lr)
        log_string += ' global batch size: {:5d} |'.format(args.training.global_batch_size)
        log_string += ' loss: {:.6E} |'.format(loss.item())
        if aux_loss is not None:
            log_string += ' aux loss: {:.6E} |'.format(aux_loss.item())
        if grad_norm is not None:
            log_string += ' grad norm: {:.3f} |'.format(grad_norm)
        print_rank(logger.info, log_string)

    def load(self):
        """Load a full checkpoint via ``checkpointer`` and restore training state.

        Returns:
            ``(iteration, consumed_train_samples)``. Both are 0 when the checkpointer
            reports a release checkpoint, which has no extra training state.
        """
        args = self.args
        iteration, consumed_train_samples = 0, 0
        state = {"model": self.model, "extra_state": {}}  # cannot be None
        if not args.training.no_load_optim:
            state["optimizer"] = self.optimizer
        release = self.checkpointer.load(
            path=args.training.load,
            state=state,
            load_rank0_and_broadcast=args.training.load_rank0_and_broadcast,
            load_strict=args.training.load_strict,
            enable_lora=args.training.lora.enable,
        )
        if not release:
            extra_state = state["extra_state"]
            iteration = extra_state["iteration"]
            consumed_train_samples = extra_state["consumed_train_samples"]
            self.lr_scheduler.load_state_dict(extra_state["lr_scheduler"])
            # The key is absent when the checkpoint was saved without a dataloader.
            if self.train_dataloader is not None and "train_dataloader" in extra_state:
                self.train_dataloader.load_state_dict(extra_state["train_dataloader"])
            if not args.training.no_load_rng:
                if "torch_rng_state" not in extra_state:
                    print_rank(logger.warning, "No RNG state found in checkpoint, skipping RNG loading")
                else:
                    torch.set_rng_state(extra_state["torch_rng_state"])
        # Synchronize all processes after loading
        torch.distributed.barrier()
        return iteration, consumed_train_samples

    def _collect_training_state(self, iteration, consumed_train_samples):
        """Build the non-model training metadata dict shared by both save modes."""
        training_state = {
            "iteration": iteration,
            "consumed_train_samples": consumed_train_samples,
            "lr_scheduler": self.lr_scheduler.state_dict(),
        }
        # train_dataloader may be None (load already tolerates that); don't crash on save.
        if self.train_dataloader is not None:
            training_state["train_dataloader"] = self.train_dataloader.state_dict()
        return training_state

    def _save_lora(self, iteration, consumed_train_samples):
        """Save LoRA adapter weights plus this rank's lightweight training state.

        This avoids writing the full checkpoint, which would include base model weights.
        Every rank writes its own state file. After all ranks finish, rank 0 alone updates
        the tracker file, so the tracker never points at a partially written checkpoint.
        """
        args = self.args
        if self.lora_weight_manager is not None:
            self.lora_weight_manager.save_lora_only(
                save_path=args.training.save,
                iteration=iteration,
            )
        training_state = self._collect_training_state(iteration, consumed_train_samples)
        if not args.training.no_save_optim:
            training_state["optimizer"] = self.optimizer.state_dict()
        if not args.training.no_save_rng:
            training_state["torch_rng_state"] = torch.get_rng_state()
        os.makedirs(args.training.save, exist_ok=True)
        state_path = self._training_state_path(args.training.save, iteration)
        torch.save(training_state, state_path)
        # Publish the new iteration only after every rank's state file is written.
        torch.distributed.barrier()
        if self._get_rank() == 0:
            tracker_file = os.path.join(args.training.save, 'latest_checkpointed_iteration.txt')
            with open(tracker_file, 'w') as f:
                f.write(str(iteration))
        print_rank(logger.info, f"Saved training state to {state_path}")
        torch.distributed.barrier()

    def save(self, iteration, consumed_train_samples):
        """Save checkpoint with model, optimizer, and training state.

        In LoRA mode, saves only the adapter weights plus per-rank training state.
        Otherwise it delegates a full checkpoint (model, optional optimizer, extra
        state) to ``checkpointer.save``.
        """
        args = self.args
        if args.training.lora.enable:
            self._save_lora(iteration, consumed_train_samples)
            return
        # Default save behavior (full model)
        state = {
            "model": self.model,
            "extra_state": self._collect_training_state(iteration, consumed_train_samples),
        }
        if not args.training.no_save_optim:
            state["optimizer"] = self.optimizer
        if not args.training.no_save_rng:
            state["extra_state"]["torch_rng_state"] = torch.get_rng_state()
        self.checkpointer.save(args.training.save, state=state, iteration=iteration)
        # Synchronize all processes after saving
        torch.distributed.barrier()
128K长序列数据生成
Miss_min 2026-05-21 16:48
相关推荐
来恩10034 小时前
JSTL的标签库种类love530love4 小时前
MingLi-Bench 项目部署实录:基于 EPGF 架构的工程化实践小宋0014 小时前
QT中控件qss样式修改图像僧4 小时前
vs2019中的属性页使用说明YOU OU4 小时前
SpringBoot 日志猿儿本无心4 小时前
快速搭建Python项目(Vscode+uv+FastAPI)AI算法沐枫4 小时前
大模型 | 大模型之机器学习基本理论li星野4 小时前
Transformer 核心模块详解:多头注意力、前馈网络与词嵌入动物园猫5 小时前
面向智慧牧场的牛行为识别数据集分享(适用于YOLO系列深度学习分类检测任务)