textCnn笔记

文章目录

先说明下,textCnn并不是一个具体的包或产品,它是一个论文思路, 用基础的神经网络组件(如卷积层、池化层、全连接层)"搭积木"搭出来的一个自定义模型。

textCnn类基本上不用变动,构造的时候调整入参即可,这样可以省代码。

textCnn类(可复用)

1、项目下创建models文件夹。models下创建__init__.py,代码:

python 复制代码
from .textcnn import TextCNN
__all__ = ['TextCNN']

2、models下创建textcnn.py,代码:**

python 复制代码
import torch
import torch.nn as nn

class TextCNN(nn.Module):
    def __init__(self, vocab_size, embedding_dim, num_classes, kernel_sizes=(2, 3, 4), num_filters=128, dropout=0.5):
        super(TextCNN, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # 使用 ModuleList 来管理多个不同大小的卷积核
        self.convs = nn.ModuleList([
            nn.Conv2d(1, num_filters, kernel_size=(k, embedding_dim)) for k in kernel_sizes
        ])
        self.fc = nn.Linear(num_filters * len(kernel_sizes), num_classes)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.embedding(x).unsqueeze(1) # (batch, 1, seq_len, embed_dim)
        conv_outs = []
        for conv in self.convs:
            conv_out = torch.relu(conv(x)).squeeze(3) # (batch, num_filters, seq_len - k + 1)
            pooled = torch.max(conv_out, dim=2)[0] # (batch, num_filters)
            conv_outs.append(pooled)
        x = torch.cat(conv_outs, dim=1) # (batch, num_filters * len(kernel_sizes))
        x = self.dropout(x)
        x = self.fc(x)
        return x

这样引用时很方便,如:

python 复制代码
from models import TextCNN
model = TextCNN(...)

如果要求大模型发示例,可以如下描述:

TextCNN类已维护到models下,可直接引用,在此基础上提供代码即可。

引用示例:

python 复制代码
from models import TextCNN
model = TextCNN(...)

textCnn示例

代码:

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

# 1. 模拟数据 (假设词汇表大小 1000, 句子长度 10, 词向量维度 128)
vocab_size = 1000
max_len = 10
embedding_dim = 128
num_classes = 2  # 正面/负面

# 随机生成一批数据 (batch_size=4)
X = torch.randint(0, vocab_size, (4, max_len))
y = torch.tensor([1, 0, 1, 0])  # 标签


# 2. 定义 TextCNN 模型
class TextCNN(nn.Module):
    def __init__(self, vocab_size, embedding_dim, num_classes):
        super(TextCNN, self).__init__()
        # 词嵌入层
        self.embedding = nn.Embedding(vocab_size, embedding_dim)

        # 定义多个不同大小的卷积核 (模拟 n-gram 特征提取)
        # 卷积核高度分别为 2, 3, 4,宽度固定为 embedding_dim
        self.convs = nn.ModuleList([
            nn.Conv2d(1, 100, kernel_size=(k, embedding_dim)) for k in (2, 3, 4)
        ])

        # 全连接层
        self.fc = nn.Linear(300, num_classes)  # 3个卷积核 * 100个通道
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        # x shape: (batch_size, seq_len)
        x = self.embedding(x)  # (batch_size, seq_len, embedding_dim)
        x = x.unsqueeze(1)  # (batch_size, 1, seq_len, embedding_dim) - 增加通道维度

        # 卷积 + 激活 + 最大池化
        conv_outs = []
        for conv in self.convs:
            conv_out = torch.relu(conv(x)).squeeze(3)  # (batch, 100, seq_len-k+1)
            pooled = torch.max(conv_out, dim=2)[0]  # 最大池化 (batch, 100)
            conv_outs.append(pooled)

        # 拼接所有卷积核的输出
        x = torch.cat(conv_outs, dim=1)  # (batch, 300)
        x = self.dropout(x)
        x = self.fc(x)  # (batch, num_classes)
        return x


# 3. 实例化模型、损失函数和优化器
model = TextCNN(vocab_size, embedding_dim, num_classes)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 4. 简单的训练循环 (演示用,仅运行 1 次)
print("开始训练...")
model.train()
optimizer.zero_grad()
outputs = model(X)
loss = criterion(outputs, y)
loss.backward()
optimizer.step()

print(f"Loss: {loss.item():.4f}")

# 5. 预测示例
model.eval()
with torch.no_grad():
    test_input = torch.randint(0, vocab_size, (1, max_len))
    pred = model(test_input)
    predicted_class = torch.argmax(pred, dim=1).item()
    print(f"预测结果: {'正面' if predicted_class == 1 else '负面'}")
相关推荐
LinXunFeng6 小时前
Obsidian - 使用 Share Note 分享笔记并自部署
前端·笔记·github
闪闪发亮的小星星5 天前
高斯光以及高斯光公式解释
笔记
cqbzcsq5 天前
CellFlow虚拟细胞论文阅读
论文阅读·人工智能·笔记·学习·生物信息
阿米亚波5 天前
【Windows】QEMU 启动 openEuler aarch64/arm64 架构系统 + 离线软件源
linux·windows·经验分享·笔记·架构·arm
自传.5 天前
尚硅谷 Vibe Coding|第三章(1) Claude Code深度使用与进阶技巧 学习笔记
笔记·学习·尚硅谷·vibecoding
.千余5 天前
【C++】模板进阶全解:非类型参数|全特化|偏特化|分离编译完全指南
开发语言·c++·笔记·学习·其他
自传.5 天前
尚硅谷 Vibe Coding|第二章 AI编程工具生态 学习笔记
笔记·学习·ai编程·尚硅谷·vibe coding
秋波。未央5 天前
Java Agent 开发 · Day 1 学习笔记(含作业完整标准答案)
java·笔记·学习
中屹指纹浏览器5 天前
2026指纹浏览器字体指纹、字体渲染偏差检测与全维度虚拟字体池搭建方案
经验分享·笔记