《动手学深度学习》-60translate实现

复制代码
1.读取数据
2.数据预处理:加空格,处理标点
3.分词:去掉第三列,并改为按单词分词
4.序列截断或填充
5.序列最后加结束符,并计算有效长度
6.生成数据集
def data_nmt(path='D:/PycharmDocument/limu/data/fra.txt'):
    with open(path,'r',encoding='utf-8',errors='ignore') as f:
        return f.read()
raw_txt=data_nmt()
# print(raw_txt[:175])
def preprocess_nmt(text):
    def no_space(char,prev_char):
        return char in set(',.!?') and prev_char !=' '
    text=text.replace('\u202f',' ').replace('\xa0',' ').lower()#将(窄)不换行空格换成普通空格
    out=[' '+char if i>0 and no_space(char,text[i-1]) else char for i,char in enumerate(text)]#如果是char标点符号,标点符号前加空格,否则直接加char
    return ''.join(out)
text=preprocess_nmt(raw_txt)
# print(text[:80])
def tokenize_nmt(text,num_examples=None):#划分标签和输入集
    sourch,target = [],[]
    for i,line in enumerate(text.split('\n')):
        if num_examples and i>num_examples:
            break
        parts=line.split('\t')[:2]
        if len(parts)==2:
            sourch.append(list(parts[0].split(' ')))
            target.append(list(parts[1].split(' ')))
    return sourch,target
source,target=tokenize_nmt(text)
# print(source[:6],target[:6])
def show_list_len_pair_hist(legend,xlabel,ylabel,xlist,ylist):
    plt.figure(figsize=(10,6))
    _,_,patches=plt.hist([[len(l) for l in xlist],[len(l) for l in ylist]])
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    for patch in patches[1].patches:
        patch.set_hatch('/')
    plt.legend(legend)
    plt.show()
show_list_len_pair_hist(['source','target'],'# tokens per sequence','count',source,target)
复制代码
src_voc=test52text_process.Vocab(source,min_freq=2,reserved_tokens=['<pad>','<bos>','<eos>'])#句子填充、句子开始、结束
def truncate_pad(line,num_steps,padding_token):
    if len(line)>num_steps:
        return line[:num_steps]
    return line+[padding_token]*(num_steps-len(line))
def build_array_nmt(lines,vocab,num_steps):
    lines=[vocab[l] for l in lines]
    array=torch.tensor([truncate_pad(l, num_steps,vocab['<pad>']) for l in lines])
    valid_len=(array!=vocab['<pad>']).type(torch.int32).sum(1)
    return array,valid_len
复制代码
def load_data_nmt(batch_size, num_steps, num_examples=600):
    text = preprocess_nmt(raw_txt)
    source, target = tokenize_nmt(text, num_examples)
    # 1. 创建词表对象 (不要试图在这里解包 valid_len)
    src_vocab = test52text_process.Vocab(source, min_freq=2,
                                         reserved_tokens=['<pad>', '<bos>', '<eos>'])
    tgt_vocab = test52text_process.Vocab(target, min_freq=2,
                                         reserved_tokens=['<pad>', '<bos>', '<eos>'])
    src_array, src_valid_len = build_array_nmt(source, src_vocab, num_steps)
    tgt_array, tgt_valid_len = build_array_nmt(target, tgt_vocab, num_steps)
    # 3. 将 Tensor 放入数组,而不是放 Vocab 对象
    data_arrays = (src_array, src_valid_len, tgt_array, tgt_valid_len)
    # 4. 生成迭代器
    data_iter = d2l.load_array(data_arrays, batch_size)
    # --- 修复结束 ---
    return data_iter, src_vocab, tgt_vocab
train_iter,src_vocab,tgt_vocab=load_data_nmt(2,
复制代码
train_iter,src_vocab,tgt_vocab=load_data_nmt(2,8)
for X,X_valid_len,Y,Y_valid_len in train_iter:
    print(X.type(torch.int32))
    print(X_valid_len)
    print(Y.type(torch.int32))
    print(Y_valid_len)
    break
相关推荐
zhangzeyuaaa2 分钟前
深入理解 Python GIL:从机制到释放时机
java·网络·python
PSLoverS5 分钟前
c++如何读取和修改可执行文件的PE头信息_IMAGE_NT_HEADERS解析【进阶】
jvm·数据库·python
探物 AI6 分钟前
【感知·车道线检测】UFLDv2车道线检测与车道偏离预警(LDWS)实战
人工智能·算法·目标检测·计算机视觉
Swilderrr7 分钟前
学术研读报告:MEM1面向长视距智能体的记忆 - 推理协同框架
人工智能
gmaajt13 分钟前
React Native 单元测试中第三方依赖的正确 Mock 策略
jvm·数据库·python
aLTttY16 分钟前
Spring Boot整合AI大模型实现智能问答系统实战
人工智能·spring boot·后端
a95114164218 分钟前
宝塔面板数据库查询响应慢_利用慢查询日志进行优化
jvm·数据库·python
zhangzeyuaaa24 分钟前
深入理解 Python 进程间通信:Queue 与 Pipe 实战解析
网络·python·中间件
easy_coder31 分钟前
《工程化视角下的Prompt设计与迭代:云诊断与CICD变更风控中的实践》
人工智能·云计算·prompt
2401_8314194436 分钟前
如何用 http 模块创建一个基础的 Web 服务器处理请求
jvm·数据库·python