用pytorch和txt文件自己训练模型

先打个甲:我是java,不是搞py的,代码全是用trae.cn + deepseek R1生成的

开发环境:

硬件

9代i5 + RTX2070

软件

python 3.11 + pytorch GPU

一言不合先上代码!!!

注意:第九行那个with,就是把文件里的文字提取成字符串并放在text对象里,如果小伙伴对py熟的话,什么PDF,一个文件夹里的所有代码(不建议!!!我训练一个springboot项目搞了好久)都可以

代码直接丢ide里,细节慢慢看,代码后面有调试

python 复制代码
import torch
import torch.nn as nn
import re
import numpy as np
from torch.utils.data import Dataset, DataLoader
from collections import Counter

# 1. 数据加载与预处理
# 数据加载(使用GB2312编码处理中文文本)
with open('data/西游记.txt', 'r', encoding='gb2312', errors='ignore') as f:
    text = f.read()
    # 保留中文标点和常见符号
    # 正则表达式过滤:保留中文和常见标点符号
text = re.sub(r'[^\u4e00-\u9fa5,。!?、;:""''《》【】()〔〕---...]', '', text)

# 生成训练序列
seq_length = 100
sequences = [text[i:i+seq_length] for i in range(len(text)-seq_length)]
next_chars = [text[i+seq_length] for i in range(len(text)-seq_length)]

print(f'预处理后的文本长度: {len(text)}')
print(f'生成的序列数量: {len(sequences)}')
print(f'下一个字符数量: {len(next_chars)}')

print(f'总字符数: {len(text)}')

# 创建字符词典
chars = sorted(list(set(text)))
print(f'唯一字符数: {len(chars)}')

char_to_idx = {char: i for i, char in enumerate(chars)}
idx_to_char = {i: char for i, char in enumerate(chars)}
vocab_size = len(chars)

# 2. 定义数据集
# 自定义数据集类(继承PyTorch的Dataset)
class TextDataset(Dataset):
    def __init__(self, sequences, next_chars, char_to_idx):
        self.char_to_idx = char_to_idx
        self.X = torch.tensor([[self.char_to_idx[c] for c in seq] for seq in sequences], dtype=torch.long)
        self.y = torch.tensor([self.char_to_idx[c] for c in next_chars], dtype=torch.long)

    def __len__(self):
        return len(self.X)
    
    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

# 3. 定义LSTM模型
class CharLSTM(nn.Module):
    # LSTM模型定义
# 参数说明:
# - vocab_size: 词汇表大小
# - embedding_dim: 字符嵌入维度
# - hidden_size: LSTM隐藏层维度
    def __init__(self, vocab_size, embedding_dim=128, hidden_size=256):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # 定义双层LSTM(num_layers=2),设置20%的dropout防止过拟合
        self.lstm = nn.LSTM(embedding_dim, hidden_size, num_layers=2, dropout=0.2)
        self.fc = nn.Linear(hidden_size, vocab_size)
        
    def forward(self, x):
        embedded = self.embedding(x)
        output, _ = self.lstm(embedded)
        return self.fc(output[:, -1, :])

# 4. 训练配置
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('CUDA可用:', torch.cuda.is_available())
print('PyTorch版本:', torch.__version__)
if torch.cuda.is_available():
    t = torch.tensor([1.]).cuda()
    print('张量设备:', t.device)
    print('CUDA设备名称:', torch.cuda.get_device_name(0))
else:
    print('未检测到可用CUDA设备')
# 创建数据集
model = CharLSTM(vocab_size).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# 创建数据集
dataset = TextDataset(sequences, next_chars, char_to_idx)
dataloader = DataLoader(dataset, batch_size=256, shuffle=True)

# 5. 训练循环
import time

total_time = 0  # 初始化总耗时计数器

# 训练循环(共20个epoch)
for epoch in range(20):
    model.train()
    total_loss = 0
    start_time = time.time()
    for batch_idx, (X_batch, y_batch) in enumerate(dataloader):
        X_batch, y_batch = X_batch.to(device), y_batch.to(device)
        
        optimizer.zero_grad()
        output = model(X_batch)
        loss = criterion(output, y_batch)
        
        # 梯度检查
        total_grad_norm = 0.0
        for p in model.parameters():
            if p.grad is not None:
                param_norm = p.grad.data.norm(2)
                total_grad_norm += param_norm.item() ** 2
        total_grad_norm = total_grad_norm ** 0.5
        
        loss.backward()
        # 梯度裁剪(防止梯度爆炸,限制最大梯度范数为0.5)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()
        
        total_loss += loss.item()
        
#         if batch_idx % 100 == 0:
#             print(f'Epoch: {epoch+1} | Batch: {batch_idx}/{len(dataloader)} | Loss: {loss.item():.4f}')
    
    avg_loss = total_loss / len(dataloader)
    epoch_time = time.time() - start_time
    total_time += epoch_time
    print(f'Epoch {epoch+1} 完成 | 平均损失: {avg_loss:.4f} | 单epoch耗时: {epoch_time:.2f}s | 累计总耗时: {total_time:.2f}s | 当前学习率: {optimizer.param_groups[0]["lr"]:.6f}')
    
# 训练时间统计
print(f'\n总训练时间: {total_time:.2f}秒')
print(f'平均每个epoch耗时: {total_time/20:.2f}秒')

# 6. 保存模型
import os
os.makedirs('pytorch', exist_ok=True)
torch.save(model.state_dict(), 'pytorch/xiyouji_lstm.pth')
with open('pytorch/char_mappings.pkl', 'wb') as f:
    import pickle
    pickle.dump((char_to_idx, idx_to_char), f)

# 7. 文本生成函数
# 文本生成函数
# 参数说明:
# - seed: 起始文本
# - length: 要生成的字符数
# 返回:生成的中文文本
def generate_text(seed, model, char_to_idx, idx_to_char, length=100):
    model.eval()
    # 关闭梯度计算(预测阶段不需要反向传播)
    with torch.no_grad():
        generated = seed
        for _ in range(length):
            input_seq = torch.tensor([char_to_idx[c] for c in generated[-seq_length:]], dtype=torch.long).unsqueeze(0).to(device)
            output = model(input_seq)
            pred_idx = torch.argmax(output).item()
            generated += idx_to_char[pred_idx]
        return generated

# 示例使用
print(generate_text("孙悟空是谁", model, char_to_idx, idx_to_char))

调试

代码启动后,看一下控制台日志,目前我觉得最重要的值是loss,loss最小越好,行业内认为0.1 - 1比较稳妥,高于2就是一个大憨憨

第93行提高训练循环,可以降低loss的值

// TODO DoDayum 还有一大堆,我在慢慢搞

问题

GPU未识别

这段代码开始运行时会检查环境,可能会出现以下问题,说明代码没有识别到CUDA

排查

1、英伟达驱动安了没

2、cuda安装了没

3、pytorch分CPU版和GPU版,可能下载了CPU版本,运行以下代码进行验证

python 复制代码
import torch
print('CUDA可用:', torch.cuda.is_available())
print('PyTorch版本:', torch.__version__)
if torch.cuda.is_available():
    t = torch.tensor([1.]).cuda()
    print('张量设备:', t.device)
    print('CUDA设备名称:', torch.cuda.get_device_name(0))
else:
    print('未检测到可用CUDA设备')
相关推荐
二闹8 分钟前
三个注解,到底该用哪一个?别再傻傻分不清了!
后端
用户490558160812520 分钟前
当控制面更新一条 ACL 规则时,如何更新给数据面
后端
林太白22 分钟前
Nuxt.js搭建一个官网如何简单
前端·javascript·后端
码事漫谈23 分钟前
VS Code 终端完全指南
后端
该用户已不存在1 小时前
OpenJDK、Temurin、GraalVM...到底该装哪个?
java·后端
怀刃1 小时前
内存监控对应解决方案
后端
码事漫谈1 小时前
VS Code Copilot 内联聊天与提示词技巧指南
后端
Moonbit2 小时前
MoonBit Perals Vol.06: MoonBit 与 LLVM 共舞 (上):编译前端实现
后端·算法·编程语言
Moonbit2 小时前
MoonBit Perals Vol.06: MoonBit 与 LLVM 共舞(下):llvm IR 代码生成
后端·程序员·代码规范
Moonbit2 小时前
MoonBit Pearls Vol.05: 函数式里的依赖注入:Reader Monad
后端·rust·编程语言