textCnn笔记

文章目录

先说明下,textCnn并不是一个具体的包或产品,它是一个论文思路, 用基础的神经网络组件(如卷积层、池化层、全连接层)"搭积木"搭出来的一个自定义模型。

textCnn类基本上不用变动,构造的时候调整入参即可,这样可以省代码。

textCnn类(可复用)

1、项目下创建models文件夹。models下创建__init__.py,代码:

python 复制代码
from .textcnn import TextCNN
__all__ = ['TextCNN']

2、models下创建textcnn.py,代码:**

python 复制代码
import torch
import torch.nn as nn

class TextCNN(nn.Module):
    def __init__(self, vocab_size, embedding_dim, num_classes, kernel_sizes=(2, 3, 4), num_filters=128, dropout=0.5):
        super(TextCNN, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # 使用 ModuleList 来管理多个不同大小的卷积核
        self.convs = nn.ModuleList([
            nn.Conv2d(1, num_filters, kernel_size=(k, embedding_dim)) for k in kernel_sizes
        ])
        self.fc = nn.Linear(num_filters * len(kernel_sizes), num_classes)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.embedding(x).unsqueeze(1) # (batch, 1, seq_len, embed_dim)
        conv_outs = []
        for conv in self.convs:
            conv_out = torch.relu(conv(x)).squeeze(3) # (batch, num_filters, seq_len - k + 1)
            pooled = torch.max(conv_out, dim=2)[0] # (batch, num_filters)
            conv_outs.append(pooled)
        x = torch.cat(conv_outs, dim=1) # (batch, num_filters * len(kernel_sizes))
        x = self.dropout(x)
        x = self.fc(x)
        return x

这样引用时很方便,如:

python 复制代码
from models import TextCNN
model = TextCNN(...)

如果要求大模型发示例,可以如下描述:

TextCNN类已维护到models下,可直接引用,在此基础上提供代码即可。

引用示例:

python 复制代码
from models import TextCNN
model = TextCNN(...)

textCnn示例

代码:

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

# 1. 模拟数据 (假设词汇表大小 1000, 句子长度 10, 词向量维度 128)
vocab_size = 1000
max_len = 10
embedding_dim = 128
num_classes = 2  # 正面/负面

# 随机生成一批数据 (batch_size=4)
X = torch.randint(0, vocab_size, (4, max_len))
y = torch.tensor([1, 0, 1, 0])  # 标签


# 2. 定义 TextCNN 模型
class TextCNN(nn.Module):
    def __init__(self, vocab_size, embedding_dim, num_classes):
        super(TextCNN, self).__init__()
        # 词嵌入层
        self.embedding = nn.Embedding(vocab_size, embedding_dim)

        # 定义多个不同大小的卷积核 (模拟 n-gram 特征提取)
        # 卷积核高度分别为 2, 3, 4,宽度固定为 embedding_dim
        self.convs = nn.ModuleList([
            nn.Conv2d(1, 100, kernel_size=(k, embedding_dim)) for k in (2, 3, 4)
        ])

        # 全连接层
        self.fc = nn.Linear(300, num_classes)  # 3个卷积核 * 100个通道
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        # x shape: (batch_size, seq_len)
        x = self.embedding(x)  # (batch_size, seq_len, embedding_dim)
        x = x.unsqueeze(1)  # (batch_size, 1, seq_len, embedding_dim) - 增加通道维度

        # 卷积 + 激活 + 最大池化
        conv_outs = []
        for conv in self.convs:
            conv_out = torch.relu(conv(x)).squeeze(3)  # (batch, 100, seq_len-k+1)
            pooled = torch.max(conv_out, dim=2)[0]  # 最大池化 (batch, 100)
            conv_outs.append(pooled)

        # 拼接所有卷积核的输出
        x = torch.cat(conv_outs, dim=1)  # (batch, 300)
        x = self.dropout(x)
        x = self.fc(x)  # (batch, num_classes)
        return x


# 3. 实例化模型、损失函数和优化器
model = TextCNN(vocab_size, embedding_dim, num_classes)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 4. 简单的训练循环 (演示用,仅运行 1 次)
print("开始训练...")
model.train()
optimizer.zero_grad()
outputs = model(X)
loss = criterion(outputs, y)
loss.backward()
optimizer.step()

print(f"Loss: {loss.item():.4f}")

# 5. 预测示例
model.eval()
with torch.no_grad():
    test_input = torch.randint(0, vocab_size, (1, max_len))
    pred = model(test_input)
    predicted_class = torch.argmax(pred, dim=1).item()
    print(f"预测结果: {'正面' if predicted_class == 1 else '负面'}")
相关推荐
迷路爸爸1802 小时前
Docker 入门学习笔记 07:用一个多服务案例真正理解 Docker Compose
运维·笔记·学习·spring cloud·docker·容器·eureka
chushiyunen3 小时前
milvus数据库管理工具attu使用笔记
笔记·milvus
鱼鳞_3 小时前
Java学习笔记_Day23(HashMap)
java·笔记·学习
sheeta19983 小时前
LeetCode 每日一题笔记 日期:2026.04.07 题目:2069.模拟行走机器人二
笔记·leetcode·机器人
代码旅人ing3 小时前
数组算法刷题指南
笔记
江湖有缘3 小时前
基于华为openEuler系统部署Memory笔记管理工具
笔记
小陈phd3 小时前
多模态大模型学习笔记(三十三)——基于YOLOv11的安全帽佩戴检测算法
笔记·学习·yolo
雨浓YN3 小时前
OPC DA 通讯开发笔记
windows·笔记
taoqick3 小时前
rubric系列论文粗读笔记
笔记