textCnn笔记

文章目录

先说明下,textCnn并不是一个具体的包或产品,它是一个论文思路, 用基础的神经网络组件(如卷积层、池化层、全连接层)"搭积木"搭出来的一个自定义模型。

textCnn类基本上不用变动,构造的时候调整入参即可,这样可以省代码。

textCnn类(可复用)

1、项目下创建models文件夹。models下创建__init__.py,代码:

python 复制代码
from .textcnn import TextCNN
__all__ = ['TextCNN']

2、models下创建textcnn.py,代码:**

python 复制代码
import torch
import torch.nn as nn

class TextCNN(nn.Module):
    def __init__(self, vocab_size, embedding_dim, num_classes, kernel_sizes=(2, 3, 4), num_filters=128, dropout=0.5):
        super(TextCNN, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # 使用 ModuleList 来管理多个不同大小的卷积核
        self.convs = nn.ModuleList([
            nn.Conv2d(1, num_filters, kernel_size=(k, embedding_dim)) for k in kernel_sizes
        ])
        self.fc = nn.Linear(num_filters * len(kernel_sizes), num_classes)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.embedding(x).unsqueeze(1) # (batch, 1, seq_len, embed_dim)
        conv_outs = []
        for conv in self.convs:
            conv_out = torch.relu(conv(x)).squeeze(3) # (batch, num_filters, seq_len - k + 1)
            pooled = torch.max(conv_out, dim=2)[0] # (batch, num_filters)
            conv_outs.append(pooled)
        x = torch.cat(conv_outs, dim=1) # (batch, num_filters * len(kernel_sizes))
        x = self.dropout(x)
        x = self.fc(x)
        return x

这样引用时很方便,如:

python 复制代码
from models import TextCNN
model = TextCNN(...)

如果要求大模型发示例,可以如下描述:

TextCNN类已维护到models下,可直接引用,在此基础上提供代码即可。

引用示例:

python 复制代码
from models import TextCNN
model = TextCNN(...)

textCnn示例

代码:

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

# 1. 模拟数据 (假设词汇表大小 1000, 句子长度 10, 词向量维度 128)
vocab_size = 1000
max_len = 10
embedding_dim = 128
num_classes = 2  # 正面/负面

# 随机生成一批数据 (batch_size=4)
X = torch.randint(0, vocab_size, (4, max_len))
y = torch.tensor([1, 0, 1, 0])  # 标签


# 2. 定义 TextCNN 模型
class TextCNN(nn.Module):
    def __init__(self, vocab_size, embedding_dim, num_classes):
        super(TextCNN, self).__init__()
        # 词嵌入层
        self.embedding = nn.Embedding(vocab_size, embedding_dim)

        # 定义多个不同大小的卷积核 (模拟 n-gram 特征提取)
        # 卷积核高度分别为 2, 3, 4,宽度固定为 embedding_dim
        self.convs = nn.ModuleList([
            nn.Conv2d(1, 100, kernel_size=(k, embedding_dim)) for k in (2, 3, 4)
        ])

        # 全连接层
        self.fc = nn.Linear(300, num_classes)  # 3个卷积核 * 100个通道
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        # x shape: (batch_size, seq_len)
        x = self.embedding(x)  # (batch_size, seq_len, embedding_dim)
        x = x.unsqueeze(1)  # (batch_size, 1, seq_len, embedding_dim) - 增加通道维度

        # 卷积 + 激活 + 最大池化
        conv_outs = []
        for conv in self.convs:
            conv_out = torch.relu(conv(x)).squeeze(3)  # (batch, 100, seq_len-k+1)
            pooled = torch.max(conv_out, dim=2)[0]  # 最大池化 (batch, 100)
            conv_outs.append(pooled)

        # 拼接所有卷积核的输出
        x = torch.cat(conv_outs, dim=1)  # (batch, 300)
        x = self.dropout(x)
        x = self.fc(x)  # (batch, num_classes)
        return x


# 3. 实例化模型、损失函数和优化器
model = TextCNN(vocab_size, embedding_dim, num_classes)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 4. 简单的训练循环 (演示用,仅运行 1 次)
print("开始训练...")
model.train()
optimizer.zero_grad()
outputs = model(X)
loss = criterion(outputs, y)
loss.backward()
optimizer.step()

print(f"Loss: {loss.item():.4f}")

# 5. 预测示例
model.eval()
with torch.no_grad():
    test_input = torch.randint(0, vocab_size, (1, max_len))
    pred = model(test_input)
    predicted_class = torch.argmax(pred, dim=1).item()
    print(f"预测结果: {'正面' if predicted_class == 1 else '负面'}")
相关推荐
阿Y加油吧14 小时前
堆 / 优先队列专题二刷笔记:前 K 个高频元素 & 数据流的中位数
java·笔记·算法
Codector14 小时前
在Ubuntu中使用Edge侧边栏无法添加和查看同步的侧边栏问题解决
笔记·ubuntu·develop
Brilliantwxx14 小时前
【C++】认识标准库STL(1)
开发语言·c++·笔记·程序人生·算法
想成为优秀工程师的爸爸14 小时前
第二十四篇技术笔记:郭大侠学DoIP - 从“偶睡破庙”到“天字一号”
网络·笔记·网络协议·tcp/ip·信息与通信
天才少女爱迪生14 小时前
【迪士尼机器人】硬件接入记录(自用笔记版)
笔记
Nice_Fold14 小时前
Kubernetes命名空间与Pod核心概念(自用笔记)
笔记·容器·kubernetes
你数过天上的星星吗15 小时前
Python学习笔记一(标识符、关键字、变量、数据类型、关系运算)
笔记·python·学习
轻舟行715 小时前
ctfshow-Web应用安全与防护challenge做题笔记 长期更新
笔记·web安全·网络安全
想成为优秀工程师的爸爸1 天前
第十九篇技术笔记:UDP——相思传得快,飞鸽传书在
笔记·网络协议·tcp/ip·udp·信息与通信
Yeh2020581 天前
cookie与Session笔记
笔记