《动手学深度学习》-60translate实现

复制代码
1.读取数据
2.数据预处理:加空格,处理标点
3.分词:去掉第三列,并改为按单词分词
4.序列截断或填充
5.序列最后加结束符,并计算有效长度
6.生成数据集
def data_nmt(path='D:/PycharmDocument/limu/data/fra.txt'):
    with open(path,'r',encoding='utf-8',errors='ignore') as f:
        return f.read()
raw_txt=data_nmt()
# print(raw_txt[:175])
def preprocess_nmt(text):
    def no_space(char,prev_char):
        return char in set(',.!?') and prev_char !=' '
    text=text.replace('\u202f',' ').replace('\xa0',' ').lower()#将(窄)不换行空格换成普通空格
    out=[' '+char if i>0 and no_space(char,text[i-1]) else char for i,char in enumerate(text)]#如果是char标点符号,标点符号前加空格,否则直接加char
    return ''.join(out)
text=preprocess_nmt(raw_txt)
# print(text[:80])
def tokenize_nmt(text,num_examples=None):#划分标签和输入集
    sourch,target = [],[]
    for i,line in enumerate(text.split('\n')):
        if num_examples and i>num_examples:
            break
        parts=line.split('\t')[:2]
        if len(parts)==2:
            sourch.append(list(parts[0].split(' ')))
            target.append(list(parts[1].split(' ')))
    return sourch,target
source,target=tokenize_nmt(text)
# print(source[:6],target[:6])
def show_list_len_pair_hist(legend,xlabel,ylabel,xlist,ylist):
    plt.figure(figsize=(10,6))
    _,_,patches=plt.hist([[len(l) for l in xlist],[len(l) for l in ylist]])
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    for patch in patches[1].patches:
        patch.set_hatch('/')
    plt.legend(legend)
    plt.show()
show_list_len_pair_hist(['source','target'],'# tokens per sequence','count',source,target)
复制代码
src_voc=test52text_process.Vocab(source,min_freq=2,reserved_tokens=['<pad>','<bos>','<eos>'])#句子填充、句子开始、结束
def truncate_pad(line,num_steps,padding_token):
    if len(line)>num_steps:
        return line[:num_steps]
    return line+[padding_token]*(num_steps-len(line))
def build_array_nmt(lines,vocab,num_steps):
    lines=[vocab[l] for l in lines]
    array=torch.tensor([truncate_pad(l, num_steps,vocab['<pad>']) for l in lines])
    valid_len=(array!=vocab['<pad>']).type(torch.int32).sum(1)
    return array,valid_len
复制代码
def load_data_nmt(batch_size, num_steps, num_examples=600):
    text = preprocess_nmt(raw_txt)
    source, target = tokenize_nmt(text, num_examples)
    # 1. 创建词表对象 (不要试图在这里解包 valid_len)
    src_vocab = test52text_process.Vocab(source, min_freq=2,
                                         reserved_tokens=['<pad>', '<bos>', '<eos>'])
    tgt_vocab = test52text_process.Vocab(target, min_freq=2,
                                         reserved_tokens=['<pad>', '<bos>', '<eos>'])
    src_array, src_valid_len = build_array_nmt(source, src_vocab, num_steps)
    tgt_array, tgt_valid_len = build_array_nmt(target, tgt_vocab, num_steps)
    # 3. 将 Tensor 放入数组,而不是放 Vocab 对象
    data_arrays = (src_array, src_valid_len, tgt_array, tgt_valid_len)
    # 4. 生成迭代器
    data_iter = d2l.load_array(data_arrays, batch_size)
    # --- 修复结束 ---
    return data_iter, src_vocab, tgt_vocab
train_iter,src_vocab,tgt_vocab=load_data_nmt(2,
复制代码
train_iter,src_vocab,tgt_vocab=load_data_nmt(2,8)
for X,X_valid_len,Y,Y_valid_len in train_iter:
    print(X.type(torch.int32))
    print(X_valid_len)
    print(Y.type(torch.int32))
    print(Y_valid_len)
    break
相关推荐
橘颂TA2 小时前
【测试】自动化测试函数介绍——web 测试
python·功能测试·selenium·测试工具·dubbo
爱学习的阿磊2 小时前
Python上下文管理器(with语句)的原理与实践
jvm·数据库·python
m0_736919102 小时前
Python面向对象编程(OOP)终极指南
jvm·数据库·python
one____dream2 小时前
【网安】Reverse-非常规题目
linux·python·安全·网络安全·ctf
loui robot2 小时前
规划与控制之局部路径规划算法local_planner
人工智能·算法·自动驾驶
玄同7652 小时前
Llama.cpp 全实战指南:跨平台部署本地大模型的零门槛方案
人工智能·语言模型·自然语言处理·langchain·交互·llama·ollama
格林威2 小时前
Baumer相机金属焊缝缺陷识别:提升焊接质量检测可靠性的 7 个关键技术,附 OpenCV+Halcon 实战代码!
人工智能·数码相机·opencv·算法·计算机视觉·视觉检测·堡盟相机
冷雨夜中漫步2 小时前
python反转列表reverse()和[::-1]哪个效率更高
开发语言·python
rainbow68892 小时前
Python面向对象编程与异常处理实战
开发语言·python