关于MediaEval数据集的Dataset构建(Text部分-使用PLM BERT)

python 复制代码
import random
import numpy as np
import pandas as pd
import torch
from transformers import BertModel,BertTokenizer
from tqdm.auto import tqdm
from torch.utils.data import Dataset
import re
"""参考Game-On论文"""
"""util.py"""
def set_seed(seed_value=42):
    random.seed(seed_value)
    np.random.seed(seed_value)
    # 用于设置生成随机数的种子
    torch.manual_seed(seed_value)
    torch.cuda.manual_seed_all(seed_value)
"""util.py"""

"""文本预处理-textGraph.py"""
# 文本DataSet类

def text_preprocessing(text):
    """
    - Remove entity mentions (eg. '@united')
    - Correct errors (eg. '&' to '&')
    @param    text (str): a string to be processed.
    @return   text (Str): the processed string.
    """
    # Remove '@name'
    text = re.sub(r'(@.*?)[\s]', ' ', text)

    # Replace '&' with '&'
    text = re.sub(r'&', '&', text)

    # Remove trailing whitespace
    text = re.sub(r'\s+', ' ', text).strip()

    # removes links
    text = re.sub(r'(?P<url>https?://[^\s]+)', r'', text)

    # remove @usernames
    text = re.sub(r"\@(\w+)", "", text)

    # remove # from #tags
    text = text.replace('#', '')

    return text

class TextDataset(Dataset):
    def __init__(self,df,tokenizer):
        # 包含推文的主文件框架
        self.df = df.reset_index(drop=True)

        # 使用的分词器
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        # 帖子的文本内容
        text = self.df['tweetText'][idx]
        # 作为唯一标识符的id 'tweetId'
        unique_id = self.df['tweetId'][idx]

        # 创建一个空的列表来存储输出结果
        input_ids = []
        attention_mask = []
        # 使用tokenizer分词器
        encoded_sent = self.tokenizer.encode_plus(
            text = text_preprocessing(text), # 这里使用的是预处理的句子,而不是直接对原句子使用tokenizer
            add_special_tokens=True,        # 添加[CLS]以及[SEP]等特殊词元
            max_length=512,                 # 最大截断长度
            padding='max_length',            # padding的最大长度
            return_attention_mask=True,     # 返回attention_mask
            truncation=True                 #
        )
        # 获取编码效果
        input_ids = encoded_sent.get('input_ids')
        # 获取attention_mask结果
        attention_mask = encoded_sent.get('attention_mask')

        # 将列表转换成张量
        input_ids = torch.tensor(input_ids)
        attention_mask =torch.tensor(attention_mask)

        return {'input_ids':input_ids,'attention_mask':attention_mask,'unique_id':unique_id}

def store_data(bert,device,df,dataset,store_dir):
    lengths = []
    bert.eval()

    for idx in tqdm(range(len(df))):
        sample = dataset.__getitem__(idx)
        print('原始sample[input_ids]和sample[attention_mask]的维度:',sample['input_ids'].shape,sample['attention_mask'].shape)
        # 升维
        input_ids,attention_mask = sample['input_ids'].unsqueeze(0),sample['attention_mask'].unsqueeze(0)
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        # 得到唯一标识属性
        unique_id = sample['unique_id']

        # 计算token的个数
        num_tokens = attention_mask.sum().detach().cpu().item()
        """不生成新的计算图,而是只做权重更新"""
        with torch.no_grad():
            out = bert(input_ids=input_ids,attention_mask=attention_mask)
        # last_hidden_state.shape是(batch_size,sequence_length,hidden_size)
        out_tokens = out.last_hidden_state[:,1:num_tokens,:].detach().cpu().squeeze(0).numpy() # token向量

        # 保存token级别表示
        filename = f'{emed_dir}{unique_id}.npy'

        try:
            np.save(filename, out_tokens)
            print(f"文件{filename}保存成功")
        except FileNotFoundError:
            # 文件不存在,创建新文件并保存
            np.save(filename, out_tokens)
            print(f"文件{filename}创建成功并保存成功")
        lengths.append(num_tokens)

        ## Save semantic/ whole text representation
        # 保存语义  也就是整个文本的表示
        out_cls = out.last_hidden_state[:,0,:].unsqueeze(0).detach().cpu().squeeze(0).numpy() ## cls vector
        filename = f'{emed_dir}{unique_id}_full_text.npy'
        # 尝试保存.npy文件,如果文件不存在则自动创建
        try:
            np.save(filename, out_cls)
            print(f"文件{filename}保存成功")
        except FileNotFoundError:
            # 文件不存在,创建新文件并保存
            np.save(filename, out_cls)
            print(f"文件{filename}创建成功并保存成功")
    return lengths

if __name__=='__main__':
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # 根目录
    root_dir = "./dataset/image-verification-corpus-master/image-verification-corpus-master/mediaeval2015/"
    emed_dir = './Embedding_File'
    # 文件路径
    train_csv_name = "tweetsTrain.csv"
    test_csv_name = "tweetsTest.csv"

    # 加载PLM和分词器
    tokenizer = BertTokenizer.from_pretrained('./bert/')
    bert = BertModel.from_pretrained('./bert/', return_dict=True)
    bert = bert.to(device)

    # 用于存储每个推文的Embedding
    store_dir ="Embed_Post/"

    # 创建训练数据集的Embedding表示
    df_train = pd.read_csv(f'{root_dir}{train_csv_name}')
    df_train = df_train.dropna().reset_index(drop=True)

    # 训练数据集的编码结果
    train_dataset = TextDataset(df_train,tokenizer)
    lengths = store_data(bert, device, df_train, train_dataset, store_dir)

    ## Create graph data for testing set
    # 为测试集创建Embedding表示
    df_test = pd.read_csv(f'{root_dir}{test_csv_name}')
    df_test = df_test.dropna().reset_index(drop=True)
    test_dataset = TextDataset(df_test, tokenizer)

    lengths = store_data(bert, device, df_test, test_dataset, store_dir)

"""文本预处理-textGraph.py"""
相关推荐
缝艺智研社2 分钟前
誉财 YC - 16 POLO 衫智能自动钉扣机:POLO 衫钉扣新变革
人工智能·新人首发·自动化缝纫机·线上模板机·无人自动化产线
带电的小王4 分钟前
【动手学深度学习】8.4. 循环神经网络
人工智能·pytorch·rnn·深度学习
yigan_Eins4 分钟前
Transformer|残差连接的技术演进:从CNN到ResNet
人工智能·深度学习·cnn·transformer
道可云6 分钟前
道可云人工智能&OPC每日资讯|《广东省加快推进人工智能全域全时全行业高水平应用行动方案》发布
人工智能
0xR3lativ1ty8 分钟前
每周AI新工具速览:Kiln与OpenRA-RL登场
人工智能·ai
精益数智工坊9 分钟前
拆解制造业仓库物料管理流程:如何通过标准化仓库物料管理流程解决账实不符难题
大数据·前端·数据库·人工智能·精益工程
大龄程序员狗哥16 分钟前
第46篇:语音识别入门——让AI“听懂”人类语言(概念入门)
人工智能·语音识别
weixin_4171970518 分钟前
谷歌400亿押注Anthropic:AI军备竞赛升级
人工智能
sunneo19 分钟前
专栏B-产品心理学深度-06-说服架构
人工智能·架构·产品运营·产品经理·ai编程·ai-native
烟台业荣数据科技有限公司19 分钟前
智能建造:从“能做”到“值得做”,我们还需跨越什么?
大数据·人工智能