textCnn笔记

文章目录

先说明下,textCnn并不是一个具体的包或产品,它是一个论文思路, 用基础的神经网络组件(如卷积层、池化层、全连接层)"搭积木"搭出来的一个自定义模型。

textCnn类基本上不用变动,构造的时候调整入参即可,这样可以省代码。

textCnn类(可复用)

1、项目下创建models文件夹。models下创建__init__.py,代码:

python 复制代码
from .textcnn import TextCNN
__all__ = ['TextCNN']

2、models下创建textcnn.py,代码:**

python 复制代码
import torch
import torch.nn as nn

class TextCNN(nn.Module):
    def __init__(self, vocab_size, embedding_dim, num_classes, kernel_sizes=(2, 3, 4), num_filters=128, dropout=0.5):
        super(TextCNN, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # 使用 ModuleList 来管理多个不同大小的卷积核
        self.convs = nn.ModuleList([
            nn.Conv2d(1, num_filters, kernel_size=(k, embedding_dim)) for k in kernel_sizes
        ])
        self.fc = nn.Linear(num_filters * len(kernel_sizes), num_classes)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.embedding(x).unsqueeze(1) # (batch, 1, seq_len, embed_dim)
        conv_outs = []
        for conv in self.convs:
            conv_out = torch.relu(conv(x)).squeeze(3) # (batch, num_filters, seq_len - k + 1)
            pooled = torch.max(conv_out, dim=2)[0] # (batch, num_filters)
            conv_outs.append(pooled)
        x = torch.cat(conv_outs, dim=1) # (batch, num_filters * len(kernel_sizes))
        x = self.dropout(x)
        x = self.fc(x)
        return x

这样引用时很方便,如:

python 复制代码
from models import TextCNN
model = TextCNN(...)

如果要求大模型发示例,可以如下描述:

TextCNN类已维护到models下,可直接引用,在此基础上提供代码即可。

引用示例:

python 复制代码
from models import TextCNN
model = TextCNN(...)

textCnn示例

代码:

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

# 1. 模拟数据 (假设词汇表大小 1000, 句子长度 10, 词向量维度 128)
vocab_size = 1000
max_len = 10
embedding_dim = 128
num_classes = 2  # 正面/负面

# 随机生成一批数据 (batch_size=4)
X = torch.randint(0, vocab_size, (4, max_len))
y = torch.tensor([1, 0, 1, 0])  # 标签


# 2. 定义 TextCNN 模型
class TextCNN(nn.Module):
    def __init__(self, vocab_size, embedding_dim, num_classes):
        super(TextCNN, self).__init__()
        # 词嵌入层
        self.embedding = nn.Embedding(vocab_size, embedding_dim)

        # 定义多个不同大小的卷积核 (模拟 n-gram 特征提取)
        # 卷积核高度分别为 2, 3, 4,宽度固定为 embedding_dim
        self.convs = nn.ModuleList([
            nn.Conv2d(1, 100, kernel_size=(k, embedding_dim)) for k in (2, 3, 4)
        ])

        # 全连接层
        self.fc = nn.Linear(300, num_classes)  # 3个卷积核 * 100个通道
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        # x shape: (batch_size, seq_len)
        x = self.embedding(x)  # (batch_size, seq_len, embedding_dim)
        x = x.unsqueeze(1)  # (batch_size, 1, seq_len, embedding_dim) - 增加通道维度

        # 卷积 + 激活 + 最大池化
        conv_outs = []
        for conv in self.convs:
            conv_out = torch.relu(conv(x)).squeeze(3)  # (batch, 100, seq_len-k+1)
            pooled = torch.max(conv_out, dim=2)[0]  # 最大池化 (batch, 100)
            conv_outs.append(pooled)

        # 拼接所有卷积核的输出
        x = torch.cat(conv_outs, dim=1)  # (batch, 300)
        x = self.dropout(x)
        x = self.fc(x)  # (batch, num_classes)
        return x


# 3. 实例化模型、损失函数和优化器
model = TextCNN(vocab_size, embedding_dim, num_classes)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 4. 简单的训练循环 (演示用,仅运行 1 次)
print("开始训练...")
model.train()
optimizer.zero_grad()
outputs = model(X)
loss = criterion(outputs, y)
loss.backward()
optimizer.step()

print(f"Loss: {loss.item():.4f}")

# 5. 预测示例
model.eval()
with torch.no_grad():
    test_input = torch.randint(0, vocab_size, (1, max_len))
    pred = model(test_input)
    predicted_class = torch.argmax(pred, dim=1).item()
    print(f"预测结果: {'正面' if predicted_class == 1 else '负面'}")
相关推荐
星恒随风16 小时前
C语言数据结构排序算法详解(下):冒泡排序、快速排序、归并排序和计数排序
c语言·数据结构·笔记·学习·排序算法
米小葱16 小时前
【学习笔记】cmake
笔记·学习
小+不通文墨18 小时前
把树莓派外接的DHT11接收的温湿度发送到emqx上
经验分享·笔记·嵌入式硬件·学习·树莓派
会编程的土豆19 小时前
Go 方法接收者超清晰笔记(类型名 vs 变量名)
开发语言·笔记·golang
fanged20 小时前
C++的汇编实现(TODO)
笔记
不羁的木木20 小时前
Form Kit(卡片开发服务)学习笔记01-核心概念与架构设计
笔记·学习·harmonyos
不羁的木木20 小时前
ArkWeb实战学习笔记01-核心概念与架构设计
笔记·学习·harmonyos
大明者省20 小时前
IIS 端口绑定正常访问的原理说明与常见误区澄清
运维·服务器·笔记
数据皮皮侠AI21 小时前
上市公司耐心资本数据(2010-2025)
大数据·人工智能·笔记·能源·1024程序员节