《动手学深度学习》-60translate实现

复制代码
1.读取数据
2.数据预处理:加空格,处理标点
3.分词:去掉第三列,并改为按单词分词
4.序列截断或填充
5.序列最后加结束符,并计算有效长度
6.生成数据集
def data_nmt(path='D:/PycharmDocument/limu/data/fra.txt'):
    with open(path,'r',encoding='utf-8',errors='ignore') as f:
        return f.read()
raw_txt=data_nmt()
# print(raw_txt[:175])
def preprocess_nmt(text):
    def no_space(char,prev_char):
        return char in set(',.!?') and prev_char !=' '
    text=text.replace('\u202f',' ').replace('\xa0',' ').lower()#将(窄)不换行空格换成普通空格
    out=[' '+char if i>0 and no_space(char,text[i-1]) else char for i,char in enumerate(text)]#如果是char标点符号,标点符号前加空格,否则直接加char
    return ''.join(out)
text=preprocess_nmt(raw_txt)
# print(text[:80])
def tokenize_nmt(text,num_examples=None):#划分标签和输入集
    sourch,target = [],[]
    for i,line in enumerate(text.split('\n')):
        if num_examples and i>num_examples:
            break
        parts=line.split('\t')[:2]
        if len(parts)==2:
            sourch.append(list(parts[0].split(' ')))
            target.append(list(parts[1].split(' ')))
    return sourch,target
source,target=tokenize_nmt(text)
# print(source[:6],target[:6])
def show_list_len_pair_hist(legend,xlabel,ylabel,xlist,ylist):
    plt.figure(figsize=(10,6))
    _,_,patches=plt.hist([[len(l) for l in xlist],[len(l) for l in ylist]])
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    for patch in patches[1].patches:
        patch.set_hatch('/')
    plt.legend(legend)
    plt.show()
show_list_len_pair_hist(['source','target'],'# tokens per sequence','count',source,target)
复制代码
src_voc=test52text_process.Vocab(source,min_freq=2,reserved_tokens=['<pad>','<bos>','<eos>'])#句子填充、句子开始、结束
def truncate_pad(line,num_steps,padding_token):
    if len(line)>num_steps:
        return line[:num_steps]
    return line+[padding_token]*(num_steps-len(line))
def build_array_nmt(lines,vocab,num_steps):
    lines=[vocab[l] for l in lines]
    array=torch.tensor([truncate_pad(l, num_steps,vocab['<pad>']) for l in lines])
    valid_len=(array!=vocab['<pad>']).type(torch.int32).sum(1)
    return array,valid_len
复制代码
def load_data_nmt(batch_size, num_steps, num_examples=600):
    text = preprocess_nmt(raw_txt)
    source, target = tokenize_nmt(text, num_examples)
    # 1. 创建词表对象 (不要试图在这里解包 valid_len)
    src_vocab = test52text_process.Vocab(source, min_freq=2,
                                         reserved_tokens=['<pad>', '<bos>', '<eos>'])
    tgt_vocab = test52text_process.Vocab(target, min_freq=2,
                                         reserved_tokens=['<pad>', '<bos>', '<eos>'])
    src_array, src_valid_len = build_array_nmt(source, src_vocab, num_steps)
    tgt_array, tgt_valid_len = build_array_nmt(target, tgt_vocab, num_steps)
    # 3. 将 Tensor 放入数组,而不是放 Vocab 对象
    data_arrays = (src_array, src_valid_len, tgt_array, tgt_valid_len)
    # 4. 生成迭代器
    data_iter = d2l.load_array(data_arrays, batch_size)
    # --- 修复结束 ---
    return data_iter, src_vocab, tgt_vocab
train_iter,src_vocab,tgt_vocab=load_data_nmt(2,
复制代码
train_iter,src_vocab,tgt_vocab=load_data_nmt(2,8)
for X,X_valid_len,Y,Y_valid_len in train_iter:
    print(X.type(torch.int32))
    print(X_valid_len)
    print(Y.type(torch.int32))
    print(Y_valid_len)
    break
相关推荐
Yolanda948 小时前
【人工智能】《从零搭建AI问答助手项目(九):Prompt优化》
人工智能·prompt
wj3055853788 小时前
课程 9:模型测试记录与 Prompt 策略
linux·人工智能·python·comfyui
小和尚同志8 小时前
深入使用 skill-creator:结合真实生产级实践
人工智能·aigc
DevSecOps选型指南8 小时前
安全419专访悬镜安全 | 穿越周期在 AI 浪潮中定义数字供应链安全新范式
人工智能
沪漂阿龙8 小时前
面试题详解:GraphRAG 全面解析——知识图谱增强 RAG、Local Search、Global Search、社区摘要、工程落地与评估指标一次讲透
人工智能·知识图谱
WangN28 小时前
Unitree RL Lab 学习笔记【通识】
人工智能·机器学习
haina20198 小时前
海纳AI亮相《科创中国》,解码招聘“智”变之路
人工智能·ai面试·ai招聘
星寂樱易李8 小时前
iperf3 + Python-- 网络带宽、网速、网络稳定性
开发语言·网络·python
阿星AI工作室8 小时前
刘润年中大课笔记:一句话说清AI落地之战的本质
大数据·人工智能·创业创新·商业
qingfeng154158 小时前
企业微信机器人开发:如何实现自动化与智能运营?
人工智能·python·机器人·自动化·企业微信