minimind-学习记录-环境的配置与跑通

复制代码
minimind作为一款低成本的大模型训练框架,以及普适性的微调框架。在此记录本人学习过程中的一些步骤与问题。

作者源代码链接:点击这里

环境的配置

本人习惯用conda去配置环境

python 复制代码
conda create -n minimind python=3.10 -y
conda activate

本项目基于torch框架,所以要先安装torch-cuda版本的。方法有很多,就不一一列举了

再严格按照requirements.txt中的进行安装(记得去掉torch与torchvision)

bash 复制代码
datasets==3.6.0
datasketch==1.6.4
Flask==3.0.3
Flask_Cors==4.0.0
jieba==0.42.1
jsonlines==4.0.0
marshmallow==3.22.0
matplotlib==3.10.0
ngrok==1.4.0
nltk==3.8
numpy==1.26.4
openai==1.59.6
peft==0.7.1
psutil==5.9.8
pydantic==2.11.5
rich==13.7.1
scikit_learn==1.5.1
sentence_transformers==2.3.1
simhash==2.1.2
tiktoken==0.10.0
transformers==4.57.1
jinja2==3.1.2
jsonlines==4.0.0
trl==0.13.0
ujson==5.1.0
wandb==0.18.3
streamlit==1.50.0
einops==0.8.1
swanlab==0.6.8
python 复制代码
pip install -r requirements.txt

最后附上检查代码环境的部分

test_environment.py

python 复制代码
# 测试torch是否可用cuda
import torch
print(torch.cuda.is_available())

配套数据集的载入与跑通

本文只复现预训练的部分,选择pretrain_hq.jsonl数据集;点此下载数据集

配置train_pretrain.py

python 复制代码
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="MiniMind Pretraining")
    parser.add_argument("--save_dir", type=str, default="../out", help="模型保存目录")
    parser.add_argument('--save_weight', default='pretrain', type=str, help="保存权重的前缀名")
    parser.add_argument("--epochs", type=int, default=2, help="训练轮数(建议1轮zero或2-6轮充分训练)")
    parser.add_argument("--batch_size", type=int, default=2, help="batch size")
    parser.add_argument("--learning_rate", type=float, default=5e-4, help="初始学习率")
    parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="训练设备")
    parser.add_argument("--dtype", type=str, default="bfloat16", help="混合精度类型")
    parser.add_argument("--num_workers", type=int, default=0, help="数据加载线程数")
    parser.add_argument("--accumulation_steps", type=int, default=8, help="梯度累积步数")
    parser.add_argument("--grad_clip", type=float, default=1.0, help="梯度裁剪阈值")
    parser.add_argument("--log_interval", type=int, default=100, help="日志打印间隔")
    parser.add_argument("--save_interval", type=int, default=1000, help="模型保存间隔")
    parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度")
    parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量")
    parser.add_argument('--max_seq_len', default=340, type=int, help="训练的最大截断长度(中文1token≈1.5~1.7字符)")
    parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构(0=否,1=是)")
    parser.add_argument("--data_path", type=str, default="../medical_dataset.jsonl", help="预训练数据路径")
    parser.add_argument('--from_weight', default='none', type=str, help="基于哪个权重训练,为none则从头开始")
    parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="是否自动检测&续训(0=否,1=是)")
    parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
    parser.add_argument("--wandb_project", type=str, default="MiniMind-Pretrain", help="wandb项目名")
    args = parser.parse_args()

具体参数可根据个人电脑进行调整。

个人数据集的载入说明

下面给出常见的csv-jsonl的代码

python 复制代码
import csv
import json

# -------------------------- 配置参数(可根据你的文件修改) --------------------------
INPUT_CSV_PATH = "111.csv"  # 你的原始CSV文件路径
OUTPUT_JSONL_PATH = "medical_data.jsonl"  # 输出的JSONL文件路径
ENCODING = "gbk"  # 文件编码,若读取报错可改为 "gbk" / "gb2312"


# -------------------------- 核心转换逻辑 --------------------------
def csv_to_medical_jsonl():
    """将医疗CSV数据集转换为指定格式的JSONL文件"""
    # 统计转换数据量
    total_count = 0
    success_count = 0

    try:
        with open(INPUT_CSV_PATH, "r", encoding=ENCODING, newline="") as csv_file, \
                open(OUTPUT_JSONL_PATH, "w", encoding=ENCODING) as jsonl_file:

            # 读取CSV,按表头映射字段
            csv_reader = csv.DictReader(csv_file)

            for row in csv_reader:
                total_count += 1
                # 提取用户提问和医生回答(严格对应CSV表头)
                user_ask = row.get("ask", "").strip()
                doctor_answer = row.get("answer", "").strip()

                # 过滤空数据,避免无效行
                if not user_ask or not doctor_answer:
                    print(f"跳过第{total_count}行:提问/回答为空,忽略")
                    continue

                # ✅ 拼接成你要求的【固定格式】
                text_content = f"<|im_start|>{user_ask}<|im_end|> <|im_start|>{doctor_answer}<|im_end|>"

                # ✅ 构造单行JSON对象(仅text一个key,完全匹配需求)
                json_line = {"text": text_content}

                # 写入JSONL文件(每行1个JSON,ensure_ascii=False保留中文)
                jsonl_file.write(json.dumps(json_line, ensure_ascii=False) + "\n")
                success_count += 1

        print(f"\n✅ 转换完成!总计处理 {total_count} 行,成功生成 {success_count} 行有效数据")
        print(f"📄 输出文件路径:{OUTPUT_JSONL_PATH}")

    except FileNotFoundError:
        print(f"❌ 错误:未找到文件 {INPUT_CSV_PATH},请检查文件路径是否正确")
    except UnicodeDecodeError:
        print(f"❌ 错误:文件编码读取失败,请将代码中ENCODING改为 'gbk' 重试")
    except Exception as e:
        print(f"❌ 转换失败,未知错误:{str(e)}")


# 执行转换
if __name__ == "__main__":
    csv_to_medical_jsonl()
相关推荐
西岸行者3 天前
学习笔记:SKILLS 能帮助更好的vibe coding
笔记·学习
悠哉悠哉愿意3 天前
【单片机学习笔记】串口、超声波、NE555的同时使用
笔记·单片机·学习
别催小唐敲代码3 天前
嵌入式学习路线
学习
毛小茛3 天前
计算机系统概论——校验码
学习
babe小鑫3 天前
大专经济信息管理专业学习数据分析的必要性
学习·数据挖掘·数据分析
winfreedoms3 天前
ROS2知识大白话
笔记·学习·ros2
在这habit之下3 天前
Linux Virtual Server(LVS)学习总结
linux·学习·lvs
我想我不够好。3 天前
2026.2.25监控学习
学习
im_AMBER3 天前
Leetcode 127 删除有序数组中的重复项 | 删除有序数组中的重复项 II
数据结构·学习·算法·leetcode
CodeJourney_J3 天前
从“Hello World“ 开始 C++
c语言·c++·学习