128K长序列数据生成

复制代码
#!/usr/bin/env python3
"""
Generate fake long-sequence (~128K) multimodal data for validating Qwen3_5 long-sequence training.

Writes ``NUM_SAMPLES`` single-image conversations whose assistant turn is one very long
repeated Chinese sentence, as a JSON list of ``{"images": [...], "messages": [...]}``
records. Run as a script; importing this module performs no file I/O.
"""
import json
import os

# Configuration
TARGET_SEQ_LEN = 131072  # 128K
OUTPUT_DIR = "/home/data/datasets/coco"
OUTPUT_FILE = os.path.join(OUTPUT_DIR, "mllm_format_long_seq_128k.json")
IMAGE_PATH = "/home/data/datasets/coco/COCO2017/train2017/000000033471.jpg"

# Rough token estimate: ~1 token per Chinese character, ~1.3 tokens per English word.
# The unit sentence is 15 characters, so 10000 repeats give 150,000 characters, which is
# MORE than TARGET_SEQ_LEN (an earlier comment claimed "about 128K characters").
# NOTE(review): presumably the overshoot is meant to guarantee the tokenized sample still
# reaches cutoff_len even if the tokenizer merges characters — confirm before shrinking it.
TEXT_UNIT = "这是一段用于测试长序列的文本。"
TEXT_REPEATS = 10000
LONG_TEXT = TEXT_UNIT * TEXT_REPEATS

# At least 2 samples are required because training uses 2 devices with drop_last=true.
NUM_SAMPLES = 4


def build_samples(num_samples=NUM_SAMPLES, long_text=LONG_TEXT, image_path=IMAGE_PATH):
    """Return ``num_samples`` conversation records.

    Each record references ``image_path`` once and pairs a numbered user prompt
    (containing the ``<image>`` placeholder) with ``long_text`` as the assistant reply.
    """
    return [
        {
            "images": [image_path],
            "messages": [
                {
                    "role": "user",
                    "content": f"<image>\n请详细描述这张图片的内容,这是第{i + 1}条测试数据。",
                },
                {
                    "role": "assistant",
                    "content": long_text,
                },
            ],
        }
        for i in range(num_samples)
    ]


def write_samples(samples, output_file=OUTPUT_FILE):
    """Write ``samples`` to ``output_file`` as indented UTF-8 JSON, creating its directory."""
    output_dir = os.path.dirname(output_file)
    if output_dir:  # os.makedirs("") would raise for a bare filename
        os.makedirs(output_dir, exist_ok=True)
    with open(output_file, "w", encoding="utf-8") as f:
        # ensure_ascii=False keeps the Chinese text readable in the file
        json.dump(samples, f, ensure_ascii=False, indent=2)


def main():
    """Generate the dataset file and print a short summary."""
    write_samples(build_samples())

    print(f"已生成假数据文件: {OUTPUT_FILE}")
    print(f"文本长度: {len(LONG_TEXT)} 字符")
    print(f"目标序列长度: {TARGET_SEQ_LEN}")
    print(f"\n请确保配置文件中的 cutoff_len 设置为 {TARGET_SEQ_LEN} 或更大")


if __name__ == "__main__":
    main()

# pylint: skip-file
import logging
import os
from datetime import datetime

import torch

from mindspeed.fsdp.utils.log import print_rank

from mindspeed_mm.fsdp.utils.dtype import get_dtype
from mindspeed_mm.fsdp.distributed.fully_shard_parallel import pregather_fsdp_params
from mindspeed_mm.fsdp.distributed.parallel_state import get_parallel_state
from mindspeed_mm.fsdp.utils.utils import move_to_device, get_time, configure_hsdp_gradient_sync, tensor_to_dtensor
from mindspeed_mm.fsdp.data.data_utils.utils import build_iterations
from mindspeed_mm.fsdp.optimizer.clip_grad_norm import clip_grad_norm
from mindspeed_mm.fsdp.tools.profiler import Profiler
from mindspeed_mm.fsdp.tools.memory_profiler import memory_profiler
from mindspeed_mm.fsdp.loss.loss_func import build_loss_func
from mindspeed_mm.fsdp.params.argument import Arguments
from mindspeed_mm.fsdp.utils.lora_utils import load_state_dict
from mindspeed_mm.fsdp.data.dataloader.dataloader import Preloader

# Module-level logger; messages in this file are emitted via print_rank(logger.<level>, ...).
logger = logging.getLogger(__name__)


class TrainEngine:
    """Training engine that manages the main training loop and operations.

    Holds the model, dataloader, optimizer, LR scheduler and checkpointer, and
    drives forward/backward passes, gradient clipping, optimizer/scheduler
    steps, logging and checkpointing.

    Two checkpoint layouts are handled:

    * LoRA mode (``args.training.lora.enable``): adapter weights are saved via
      ``lora_weight_manager.save_lora_only`` (and reloaded from
      ``lora_adapter_iteration_{it}.safetensors``), plus a lightweight
      ``training_state_{it}.pt`` holding iteration counters and
      scheduler/dataloader/optimizer/RNG state.
    * Full mode: everything goes through ``self.checkpointer``.

    On construction, ``latest_checkpointed_iteration.txt`` under
    ``args.training.load`` selects which layout to resume from.
    """

    def __init__(
        self,
        args: Arguments,
        train_dataloader,
        model,
        optimizer,
        scheduler,
        checkpointer,
        lora_weight_manager=None,
        **kwargs,
    ):
        """Store the training components, optionally resume, and start the profiler.

        Args:
            args: Parsed training arguments.
            train_dataloader: Dataloader exposing ``state_dict``/``load_state_dict``;
                may be None.
            model: Model to train.
            optimizer: Optimizer for ``model``.
            scheduler: Learning-rate scheduler (stored as ``self.lr_scheduler``).
            checkpointer: Object used for full (non-LoRA) checkpoint save/load.
            lora_weight_manager: Optional helper that saves/loads LoRA adapter weights.
            **kwargs: Accepted for interface compatibility; unused.
        """
        self.args = args

        self.model = model
        self.train_dataloader = train_dataloader
        self.optimizer = optimizer
        self.lr_scheduler = scheduler
        self.checkpointer = checkpointer
        self.lora_weight_manager = lora_weight_manager

        # Training state tracking
        self.iteration, self.consumed_train_samples = 0, 0

        # Resume from a checkpoint if specified
        if args.training.load:
            self._resume_from_checkpoint()

        # Load pretrained LoRA weights (if configured)
        if (
            args.training.init_model_with_meta_device
            and args.training.lora.enable
            and args.training.lora.pretrained_lora_path
        ):
            self._reload_pretrained_lora()

        self.profiler = Profiler(args.tools.profile)
        self.profiler.start()

    def _resume_from_checkpoint(self):
        """Resume from ``args.training.load``, preferring a LoRA training-state file.

        Reads the iteration number from ``latest_checkpointed_iteration.txt``.
        If ``training_state_{iteration}.pt`` exists, LoRA-mode state is restored;
        otherwise the full checkpoint is loaded through ``self.load()``.
        """
        args = self.args
        tracker_file = os.path.join(args.training.load, 'latest_checkpointed_iteration.txt')
        if not os.path.exists(tracker_file):
            print_rank(logger.warning, f"Checkpoint tracker file not found at {args.training.load}, starting from scratch.")
            return

        try:
            with open(tracker_file, 'r') as f:
                iteration = int(f.read().strip())

            # Check for lightweight training state file (LoRA mode)
            state_path = os.path.join(args.training.load, f"training_state_{iteration}.pt")
            if os.path.exists(state_path):
                self._load_lora_training_state(state_path, iteration)
            else:
                # Fallback to full checkpoint loading
                self.iteration, self.consumed_train_samples = self.load()
        except Exception as e:
            # Best-effort resume: the failure is logged and training continues with
            # whatever state had been restored before the error was raised.
            print_rank(logger.warning, f"Failed to load checkpoint: {e}")

    def _load_lora_training_state(self, state_path, iteration):
        """Restore LoRA-mode counters, adapter weights and optimizer/scheduler/dataloader/RNG state."""
        args = self.args
        training_state = torch.load(state_path, map_location="cpu")
        self.iteration = training_state["iteration"]
        self.consumed_train_samples = training_state["consumed_train_samples"]

        # Load LoRA weights
        lora_path = os.path.join(args.training.load, f"lora_adapter_iteration_{iteration}.safetensors")
        if os.path.exists(lora_path) and self.lora_weight_manager is not None:
            self.lora_weight_manager.load_lora_weights(lora_path)
        else:
            # Surface this instead of skipping silently: the remaining training
            # state below is still restored, but the adapter weights are not.
            print_rank(
                logger.warning,
                f"LoRA adapter weights not loaded: {lora_path} is missing or no lora_weight_manager was given",
            )

        # Load optimizer and scheduler state
        if not args.training.no_load_optim and "optimizer" in training_state:
            self.optimizer.load_state_dict(training_state["optimizer"])
        if "lr_scheduler" in training_state:
            self.lr_scheduler.load_state_dict(training_state["lr_scheduler"])
        if self.train_dataloader is not None and "train_dataloader" in training_state:
            self.train_dataloader.load_state_dict(training_state["train_dataloader"])
        if not args.training.no_load_rng and "torch_rng_state" in training_state:
            torch.set_rng_state(training_state["torch_rng_state"])

        print_rank(logger.info, f"Loaded LoRA training state from {state_path}")

    def _reload_pretrained_lora(self):
        """Copy tensors from ``pretrained_lora_path`` into matching model state-dict entries.

        Keys absent from the model's state dict are skipped and counted; DTensor
        targets (those exposing ``device_mesh`` and ``placements``) receive the
        value converted with ``tensor_to_dtensor``.
        """
        lora_path = self.args.training.lora.pretrained_lora_path
        lora_state_dict = load_state_dict(lora_path)
        model_state_dict = self.model.state_dict()

        num_loaded = 0
        for key, value in lora_state_dict.items():
            if key not in model_state_dict:
                continue
            target_tensor = model_state_dict[key]
            device_mesh = getattr(target_tensor, "device_mesh", None)
            placements = getattr(target_tensor, "placements", None)
            if device_mesh is not None and placements is not None:
                target_tensor.copy_(tensor_to_dtensor(value, device_mesh, placements))
            else:
                target_tensor.copy_(value)
            num_loaded += 1

        # Report what was actually copied, not the size of the file's state dict.
        print_rank(logger.info, f"Reloaded {num_loaded} LoRA parameters from {lora_path}")
        num_skipped = len(lora_state_dict) - num_loaded
        if num_skipped:
            print_rank(
                logger.warning,
                f"Skipped {num_skipped} LoRA parameters from {lora_path} not found in the model state dict",
            )

    def average_losses_across_data_parallel_group(self, losses):
        """Average scalar losses across the data-parallel group.

        Args:
            losses: Iterable of scalar (single-element) tensors.

        Returns:
            A 1-D detached tensor with one entry per input loss, all-reduced (sum)
            over the DP group and divided by the group size.
        """
        ps = get_parallel_state()
        averaged_losses = torch.cat([loss.clone().detach().view(1) for loss in losses])
        torch.distributed.all_reduce(averaged_losses, group=ps.get_dp_group())
        averaged_losses = averaged_losses / torch.distributed.get_world_size(group=ps.get_dp_group())

        return averaged_losses

    def get_batch(self, data_iterator):
        """Fetch the next batch from ``data_iterator``.

        When preloading is disabled, the batch is moved to the device here and cast
        to the configured FSDP ``param_dtype`` (if any); with preloading enabled,
        ``train`` hands the dtype to the ``Preloader`` instead.

        Raises:
            ValueError: If ``data_iterator`` is None.
        """
        if data_iterator is None:
            raise ValueError("Data iterator is None. Unable to retrieve batch.")
        batch = next(data_iterator)

        # Move input to device and cast precision
        if not self.args.data.dataloader_param.enable_preload:
            param_dtype = self.args.parallel.fsdp_plan.param_dtype
            batch = move_to_device(batch, get_dtype(param_dtype) if param_dtype else None)
        return batch

    def set_loss_func(self, batch_data):
        """Build the configured loss function and attach it to the model as ``loss_function``.

        Does nothing when ``loss_type == "raw"``. Otherwise mutates ``batch_data``:
        ``total_chunk_size`` is added before the loss is built (dynamic chunk loss),
        and ``output_router_logits=True`` is added when ``router_aux_loss_coef > 0``.
        """
        args = self.args
        if args.features.loss_cfg.loss_type == "raw":
            return
        chunk_size = args.features.chunkloss_plan.chunk_size if args.features.enable_chunk_loss else None
        if args.features.enable_dynamic_chunk_loss:
            batch_data['total_chunk_size'] = args.features.chunkloss_plan.total_chunk_size
        loss_func = build_loss_func(args.features.loss_cfg.loss_type, chunk_size=chunk_size, **batch_data)

        # Plain assignment both overwrites an existing attribute and creates a missing one.
        self.model.loss_function = loss_func

        if args.features.loss_cfg.router_aux_loss_coef > 0.0:
            batch_data.update(output_router_logits=True)

    def train_step(self, train_dataloader_iter):
        """Perform a single training step with gradient accumulation.

        Runs ``gradient_accumulation_steps`` micro-steps; each loss is divided by
        the number of micro-steps before ``backward``.

        Returns:
            ``(total_loss, total_aux_loss)``: ``total_loss`` is the accumulated loss
            averaged over the DP group (1-element tensor); ``total_aux_loss`` is the
            local accumulated ``aux_loss`` or None if the output has none. Both are
            detached; in this class they are only used for logging.
        """
        args = self.args
        accum_steps = args.training.gradient_accumulation_steps
        total_loss = 0
        total_aux_loss = None
        for step in range(accum_steps):
            # Wait for the preloaded batch to be ready
            batch_data = self.get_batch(train_dataloader_iter)

            # setup loss ctx
            self.set_loss_func(batch_data)

            # Determine if this is the last step of gradient accumulation
            is_last_step = step == accum_steps - 1
            configure_hsdp_gradient_sync(self.model, is_last_step)

            # forward step
            output = self.model(**batch_data, use_cache=False)
            loss = output.loss / accum_steps

            # Backward
            loss.backward()

            # Detach: the totals are for reporting and must not keep autograd
            # graph references alive across micro-steps.
            total_loss += loss.detach()
            raw_aux_loss = getattr(output, 'aux_loss', None)
            if raw_aux_loss is not None:
                aux_loss = raw_aux_loss.detach() / accum_steps
                total_aux_loss = aux_loss if total_aux_loss is None else total_aux_loss + aux_loss

        # Average loss across data parallel group
        total_loss = self.average_losses_across_data_parallel_group([total_loss])

        return total_loss, total_aux_loss

    def train(self):
        """Main training loop: run until ``train_iters``, logging and checkpointing periodically."""
        args = self.args

        # Get data iterator
        train_dataloader_iter, _, _ = build_iterations(self.train_dataloader)
        param_dtype = get_dtype(args.parallel.fsdp_plan.param_dtype) if args.parallel.fsdp_plan.param_dtype else None

        # Preload data
        if args.data.dataloader_param.enable_preload:
            train_dataloader_iter = Preloader(train_dataloader_iter, param_dtype=param_dtype)

        self.model.train()

        # --- Train Loop ---
        # The LR logged for an iteration is the one in effect before its scheduler step.
        curr_step_lr = self.lr_scheduler.get_last_lr()[0]
        while self.iteration < args.training.train_iters:
            # Record memory usage if enabled
            memory_profiler.step()
            start_time = get_time(barrier=True)

            if self.args.parallel.fsdp_plan.pregather:
                pregather_fsdp_params(self.model)

            loss, aux_loss = self.train_step(train_dataloader_iter)

            # Clip gradients when clip_grad>0 and get total grad_norm
            grad_norm = clip_grad_norm(
                self.model, max_norm=args.training.clip_grad, foreach=args.training.clip_grad_foreach
            )

            # Update parameters
            self.optimizer.step()
            self.lr_scheduler.step()
            self.optimizer.zero_grad()

            # Update training state
            self.consumed_train_samples += args.training.global_batch_size
            self.iteration += 1

            # Calculate iteration time
            elapsed_time_per_iteration = get_time(barrier=True) - start_time

            # Advance the profiler schedule
            self.profiler.step()

            # Logging
            if self.iteration % args.training.log_interval == 0:
                self.training_log(
                    self.iteration,
                    elapsed_time_per_iteration,
                    curr_step_lr,
                    self.consumed_train_samples,
                    loss,
                    aux_loss,
                    grad_norm,
                )

            curr_step_lr = self.lr_scheduler.get_last_lr()[0]

            # Save checkpoint at specified intervals
            if (
                args.training.save
                and args.training.save_interval > 0
                and self.iteration % args.training.save_interval == 0
            ):
                self.save(self.iteration, self.consumed_train_samples)

        # Stop profiling if enabled
        self.profiler.stop()
        memory_profiler.stop()
        # Final save after training completes
        if args.training.save:
            self.save(self.iteration, self.consumed_train_samples)

    def training_log(
        self, iteration, elapsed_time_per_iteration, curr_step_lr, consumed_train_samples, loss, aux_loss, grad_norm
    ):
        """Emit one "|"-separated progress line for ``iteration``.

        ``loss`` (and ``aux_loss`` when not None) must support ``.item()``;
        ``grad_norm`` is included only when not None.
        """
        args = self.args
        log_string = f" [{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}]"
        log_string += ' iteration {:8d}/{:8d} |'.format(iteration, args.training.train_iters)
        log_string += ' consumed samples: {:12d} |'.format(consumed_train_samples)
        log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(elapsed_time_per_iteration * 1000.0)
        log_string += ' learning rate: {:.6E} |'.format(curr_step_lr)
        log_string += ' global batch size: {:5d} |'.format(args.training.global_batch_size)
        log_string += ' loss: {:.6E} |'.format(loss.item())

        if aux_loss is not None:
            log_string += ' aux loss: {:.6E} |'.format(aux_loss.item())

        if grad_norm is not None:
            log_string += ' grad norm: {:.3f} |'.format(grad_norm)

        print_rank(logger.info, log_string)

    def load(self):
        """Load a full checkpoint via ``self.checkpointer`` and restore training state.

        Returns:
            ``(iteration, consumed_train_samples)``; both are 0 when the checkpointer
            reports a release checkpoint (model weights only).
        """
        args = self.args
        iteration, consumed_train_samples = 0, 0

        state = {"model": self.model, "extra_state": {}}  # cannot be None
        if not args.training.no_load_optim:
            state["optimizer"] = self.optimizer

        release = self.checkpointer.load(
            path=args.training.load,
            state=state,
            load_rank0_and_broadcast=args.training.load_rank0_and_broadcast,
            load_strict=args.training.load_strict,
            enable_lora=args.training.lora.enable,
        )

        if not release:
            extra_state = state["extra_state"]
            iteration = extra_state["iteration"]
            consumed_train_samples = extra_state["consumed_train_samples"]

            self.lr_scheduler.load_state_dict(extra_state["lr_scheduler"])
            # Checkpoints saved without a dataloader carry no dataloader state.
            if self.train_dataloader is not None and "train_dataloader" in extra_state:
                self.train_dataloader.load_state_dict(extra_state["train_dataloader"])
            if not args.training.no_load_rng:
                if "torch_rng_state" not in extra_state:
                    print_rank(logger.warning, "No RNG state found in checkpoint, skipping RNG loading")
                else:
                    torch.set_rng_state(extra_state["torch_rng_state"])

        # Synchronize all processes after loading
        torch.distributed.barrier()

        return iteration, consumed_train_samples

    def save(self, iteration, consumed_train_samples):
        """Save checkpoint with model, optimizer, and training state.

        In LoRA mode only the adapter weights and a lightweight training-state file
        are written; otherwise a full checkpoint is saved via ``self.checkpointer``.
        All ranks synchronize on a barrier before returning.
        """
        args = self.args

        # Handle LoRA save modes: Save only LoRA adapter weights + training state
        if args.training.lora.enable:
            self._save_lora_training_state(iteration, consumed_train_samples)
            torch.distributed.barrier()
            return

        # Default save behavior (full model)
        extra_state = {
            "iteration": iteration,
            "consumed_train_samples": consumed_train_samples,
            "lr_scheduler": self.lr_scheduler.state_dict(),
        }
        # train_dataloader may be None (load() tolerates that), so guard it here too.
        if self.train_dataloader is not None:
            extra_state["train_dataloader"] = self.train_dataloader.state_dict()
        if not args.training.no_save_rng:
            extra_state["torch_rng_state"] = torch.get_rng_state()

        state = {"model": self.model, "extra_state": extra_state}
        if not args.training.no_save_optim:
            state["optimizer"] = self.optimizer
        self.checkpointer.save(args.training.save, state=state, iteration=iteration)

        # Synchronize all processes after saving
        torch.distributed.barrier()

    def _save_lora_training_state(self, iteration, consumed_train_samples):
        """Save LoRA adapter weights plus ``training_state_{iteration}.pt`` and the tracker file."""
        args = self.args
        if self.lora_weight_manager is not None:
            self.lora_weight_manager.save_lora_only(
                save_path=args.training.save,
                iteration=iteration,
            )

        # Save optimizer state and training metadata for resume.
        # This avoids saving the full DCP checkpoint (which includes base model weights).
        training_state = {
            "iteration": iteration,
            "consumed_train_samples": consumed_train_samples,
            "lr_scheduler": self.lr_scheduler.state_dict(),
        }
        if self.train_dataloader is not None:
            training_state["train_dataloader"] = self.train_dataloader.state_dict()
        if not args.training.no_save_optim:
            # NOTE(review): if optimizer state is sharded per rank (e.g. FSDP), one
            # shared file cannot hold every rank's shard — confirm; per-rank files
            # may be needed.
            training_state["optimizer"] = self.optimizer.state_dict()
        if not args.training.no_save_rng:
            training_state["torch_rng_state"] = torch.get_rng_state()

        state_path = os.path.join(args.training.save, f"training_state_{iteration}.pt")
        # Every rank builds the state dicts above (in case any of them involve
        # collectives), but only rank 0 writes: concurrent torch.save / open('w')
        # on the same paths from all ranks races. Writing the tracker after the
        # state file keeps it from pointing at a file that is not yet complete.
        if torch.distributed.get_rank() == 0:
            os.makedirs(args.training.save, exist_ok=True)
            torch.save(training_state, state_path)

            # Update tracker file
            tracker_file = os.path.join(args.training.save, 'latest_checkpointed_iteration.txt')
            with open(tracker_file, 'w') as f:
                f.write(str(iteration))

        print_rank(logger.info, f"Saved training state to {state_path}")
相关推荐
来恩10034 小时前
JSTL的标签库种类
java·开发语言
love530love4 小时前
MingLi-Bench 项目部署实录:基于 EPGF 架构的工程化实践
人工智能·windows·python·架构·aigc·epgf·mingli-bench
小宋0014 小时前
QT中控件qss样式修改
开发语言·qt
图像僧4 小时前
vs2019中的属性页使用说明
java·开发语言·jvm
YOU OU4 小时前
SpringBoot 日志
java·开发语言
猿儿本无心4 小时前
快速搭建Python项目(Vscode+uv+FastAPI)
vscode·python·uv
AI算法沐枫4 小时前
大模型 | 大模型之机器学习基本理论
人工智能·python·神经网络·学习·算法·机器学习·计算机视觉
li星野4 小时前
Transformer 核心模块详解:多头注意力、前馈网络与词嵌入
人工智能·深度学习·transformer
动物园猫5 小时前
面向智慧牧场的牛行为识别数据集分享(适用于YOLO系列深度学习分类检测任务)
深度学习·yolo·分类