用pytorch和txt文件自己训练模型

先打个甲:我是java,不是搞py的,代码全是用trae.cn + deepseek R1生成的

开发环境:

硬件

9代i5 + RTX2070

软件

python 3.11 + pytorch GPU

一言不合先上代码!!!

注意:第九行那个with,就是把文件里的文字提取成字符串并放在text对象里,如果小伙伴对py熟的话,什么PDF,一个文件夹里的所有代码(不建议!!!我训练一个springboot项目搞了好久)都可以

代码直接丢ide里,细节慢慢看,代码后面有调试

python 复制代码
import torch
import torch.nn as nn
import re
import numpy as np
from torch.utils.data import Dataset, DataLoader
from collections import Counter

# 1. 数据加载与预处理
# 数据加载(使用GB2312编码处理中文文本)
with open('data/西游记.txt', 'r', encoding='gb2312', errors='ignore') as f:
    text = f.read()
    # 保留中文标点和常见符号
    # 正则表达式过滤:保留中文和常见标点符号
text = re.sub(r'[^\u4e00-\u9fa5,。!?、;:""''《》【】()〔〕---...]', '', text)

# 生成训练序列
seq_length = 100
sequences = [text[i:i+seq_length] for i in range(len(text)-seq_length)]
next_chars = [text[i+seq_length] for i in range(len(text)-seq_length)]

print(f'预处理后的文本长度: {len(text)}')
print(f'生成的序列数量: {len(sequences)}')
print(f'下一个字符数量: {len(next_chars)}')

print(f'总字符数: {len(text)}')

# 创建字符词典
chars = sorted(list(set(text)))
print(f'唯一字符数: {len(chars)}')

char_to_idx = {char: i for i, char in enumerate(chars)}
idx_to_char = {i: char for i, char in enumerate(chars)}
vocab_size = len(chars)

# 2. 定义数据集
# 自定义数据集类(继承PyTorch的Dataset)
class TextDataset(Dataset):
    def __init__(self, sequences, next_chars, char_to_idx):
        self.char_to_idx = char_to_idx
        self.X = torch.tensor([[self.char_to_idx[c] for c in seq] for seq in sequences], dtype=torch.long)
        self.y = torch.tensor([self.char_to_idx[c] for c in next_chars], dtype=torch.long)

    def __len__(self):
        return len(self.X)
    
    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

# 3. 定义LSTM模型
class CharLSTM(nn.Module):
    # LSTM模型定义
# 参数说明:
# - vocab_size: 词汇表大小
# - embedding_dim: 字符嵌入维度
# - hidden_size: LSTM隐藏层维度
    def __init__(self, vocab_size, embedding_dim=128, hidden_size=256):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # 定义双层LSTM(num_layers=2),设置20%的dropout防止过拟合
        self.lstm = nn.LSTM(embedding_dim, hidden_size, num_layers=2, dropout=0.2)
        self.fc = nn.Linear(hidden_size, vocab_size)
        
    def forward(self, x):
        embedded = self.embedding(x)
        output, _ = self.lstm(embedded)
        return self.fc(output[:, -1, :])

# 4. 训练配置
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('CUDA可用:', torch.cuda.is_available())
print('PyTorch版本:', torch.__version__)
if torch.cuda.is_available():
    t = torch.tensor([1.]).cuda()
    print('张量设备:', t.device)
    print('CUDA设备名称:', torch.cuda.get_device_name(0))
else:
    print('未检测到可用CUDA设备')
# 创建数据集
model = CharLSTM(vocab_size).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# 创建数据集
dataset = TextDataset(sequences, next_chars, char_to_idx)
dataloader = DataLoader(dataset, batch_size=256, shuffle=True)

# 5. 训练循环
import time

total_time = 0  # 初始化总耗时计数器

# 训练循环(共20个epoch)
for epoch in range(20):
    model.train()
    total_loss = 0
    start_time = time.time()
    for batch_idx, (X_batch, y_batch) in enumerate(dataloader):
        X_batch, y_batch = X_batch.to(device), y_batch.to(device)
        
        optimizer.zero_grad()
        output = model(X_batch)
        loss = criterion(output, y_batch)
        
        # 梯度检查
        total_grad_norm = 0.0
        for p in model.parameters():
            if p.grad is not None:
                param_norm = p.grad.data.norm(2)
                total_grad_norm += param_norm.item() ** 2
        total_grad_norm = total_grad_norm ** 0.5
        
        loss.backward()
        # 梯度裁剪(防止梯度爆炸,限制最大梯度范数为0.5)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()
        
        total_loss += loss.item()
        
#         if batch_idx % 100 == 0:
#             print(f'Epoch: {epoch+1} | Batch: {batch_idx}/{len(dataloader)} | Loss: {loss.item():.4f}')
    
    avg_loss = total_loss / len(dataloader)
    epoch_time = time.time() - start_time
    total_time += epoch_time
    print(f'Epoch {epoch+1} 完成 | 平均损失: {avg_loss:.4f} | 单epoch耗时: {epoch_time:.2f}s | 累计总耗时: {total_time:.2f}s | 当前学习率: {optimizer.param_groups[0]["lr"]:.6f}')
    
# 训练时间统计
print(f'\n总训练时间: {total_time:.2f}秒')
print(f'平均每个epoch耗时: {total_time/20:.2f}秒')

# 6. 保存模型
import os
os.makedirs('pytorch', exist_ok=True)
torch.save(model.state_dict(), 'pytorch/xiyouji_lstm.pth')
with open('pytorch/char_mappings.pkl', 'wb') as f:
    import pickle
    pickle.dump((char_to_idx, idx_to_char), f)

# 7. 文本生成函数
# 文本生成函数
# 参数说明:
# - seed: 起始文本
# - length: 要生成的字符数
# 返回:生成的中文文本
def generate_text(seed, model, char_to_idx, idx_to_char, length=100):
    model.eval()
    # 关闭梯度计算(预测阶段不需要反向传播)
    with torch.no_grad():
        generated = seed
        for _ in range(length):
            input_seq = torch.tensor([char_to_idx[c] for c in generated[-seq_length:]], dtype=torch.long).unsqueeze(0).to(device)
            output = model(input_seq)
            pred_idx = torch.argmax(output).item()
            generated += idx_to_char[pred_idx]
        return generated

# 示例使用
print(generate_text("孙悟空是谁", model, char_to_idx, idx_to_char))

调试

代码启动后,看一下控制台日志,目前我觉得最重要的值是loss,loss最小越好,行业内认为0.1 - 1比较稳妥,高于2就是一个大憨憨

第93行提高训练循环,可以降低loss的值

// TODO DoDayum 还有一大堆,我在慢慢搞

问题

GPU未识别

这段代码开始运行时会检查环境,可能会出现以下问题,说明代码没有识别到CUDA

排查

1、英伟达驱动安了没

2、cuda安装了没

3、pytorch分CPU版和GPU版,可能下载了CPU版本,运行以下代码进行验证

python 复制代码
import torch
print('CUDA可用:', torch.cuda.is_available())
print('PyTorch版本:', torch.__version__)
if torch.cuda.is_available():
    t = torch.tensor([1.]).cuda()
    print('张量设备:', t.device)
    print('CUDA设备名称:', torch.cuda.get_device_name(0))
else:
    print('未检测到可用CUDA设备')
相关推荐
Goboy8 分钟前
老婆问我:“大模型的 Token 究竟是个啥?”
后端·程序员·架构
子洋1 小时前
Chroma+LangChain:让AI联网回答更精准
前端·人工智能·后端
追逐时光者1 小时前
基于 .NET Blazor 开源、低代码、易扩展的插件开发框架
后端·.net
MZWeiei4 小时前
Scala:解构声明(用例子通俗易懂)
开发语言·后端·scala
woniu_maggie7 小时前
SAP DOI EXCEL&宏的使用
后端·excel
二两小咸鱼儿8 小时前
Java Demo - JUnit :Unit Test(Assert Methods)
java·后端·junit
字节源流8 小时前
【spring】配置类和整合Junit
java·后端·spring
zhuyasen9 小时前
Go语言配置解析:基于viper的conf库优雅解析配置文件
后端·go
2a3b4c9 小时前
读取 Resource 目录下文件内容
后端
Asthenia041210 小时前
NIO:Buffer对象均是在Jvm堆中分配么?听说过DirectByteBuffer和MappedByteBuffer么?
后端