A*-like LLM decoding: lower hallucination, more stable

This code implements a language-model-based generative dialogue system whose decoding/inference stage uses a beam-search-style algorithm rather than A*. The main points of the decoding logic are:

  1. The goal of decoding is to generate a response from the input dialogue context. A beam-search-style algorithm is used here rather than greedy decoding or A*.
  2. Beam search maintains a beam of size B: at each decoding step it keeps the B highest-probability candidates instead of only the single best one, which increases diversity and avoids the local optima of greedy decoding.
  3. The core code is as follows:
```python
for _ in range(max_len):
    out, _ = model(torch.Tensor([prompt_list]).to(device).long())
    out = out[:, -1:]                       # keep only the last position's logits
    score = torch.softmax(out, -1)[0, 0]    # next-token probability distribution
    score, score_index = torch.sort(score)  # ascending sort
    score = score[-B:]                      # keep the B most probable candidates
    score_index = score_index[-B:]
    # note: dividing probabilities by temp is a no-op here, since multinomial
    # renormalizes; temperature is normally applied to the logits before softmax
    score /= temp
    idx_next = torch.multinomial(score, num_samples=1, generator=None)
    prompt += [voc["voc"][score_index[idx_next]]]
    print(prompt[-1], end="", flush=True)
```
  1. At each step, the model produces a probability distribution over the next token given the current prompt; the probabilities are sorted and only the top-B candidates are kept.
  2. Temperature scaling is applied to the scores to encourage exploration.
  3. The next token is sampled from the B candidates, appended to the prompt, and generation continues.
  4. Compared with A*, beam search has these advantages:
  • it suits the sequential, combinatorially exploding nature of language-model decoding
  • its computational cost is controllable, whereas A*'s search space is too large
  • it yields more natural, fluent responses
    In short, the code decodes with beam search, which fits language-model generation better than A* and produces higher-quality responses. Note that the excerpt above samples a single token from the top-B candidates at each step; a sketch of a full beam search that tracks B parallel hypotheses follows.
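A minimal sketch of that full beam-search variant is shown below. It assumes, as in the code later in this post, that `model` returns a `(logits, hidden)` pair for a batch of token-id sequences and that `voc["voc"]` maps ids back to tokens; end-of-sequence handling is omitted for brevity, and this is an illustration, not the author's implementation.

```python
import torch

def beam_search(model, voc, prompt_list, max_len, B=4, device="cpu"):
    # each hypothesis is (token ids so far, cumulative log-probability)
    beams = [(list(prompt_list), 0.0)]
    for _ in range(max_len):
        candidates = []
        for seq, logp in beams:
            out, _ = model(torch.tensor([seq], device=device).long())
            log_probs = torch.log_softmax(out[0, -1], dim=-1)
            top_logp, top_idx = torch.topk(log_probs, B)
            for lp, idx in zip(top_logp.tolist(), top_idx.tolist()):
                candidates.append((seq + [idx], logp + lp))
        # keep the B best partial sequences by total log-probability
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = candidates[:B]
    best, _ = beams[0]
    return [voc["voc"][i] for i in best[len(prompt_list):]]
```

Unlike the sampled top-B loop above, this variant keeps the B best partial hypotheses at every step and is deterministic.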
```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from glob import glob
from tqdm import tqdm
from model import SamOut

import polars as pl
from collections import Counter


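# train(): next-token pretraining loop. Resumes from pretrain_768.pth, streams
# the pre_data_set_*.pkl id sequences in batches of 30 (pad id 3 is ignored by
# the loss), and checkpoints the weights and running loss after every file.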
def train():
    voc = pd.read_pickle("total_voc.pkl")

    net = SamOut(len(voc["voc"]), 768, 32, 16)
    print(sum([i.shape[0] * i.shape[1] for i in net.parameters() if len(i.shape) > 1]) + sum(
        [i.shape[0] for i in net.parameters() if len(i.shape) == 1]))

    net.load_state_dict(torch.load("pretrain_768.pth"))
    net.to("cuda")

    opt = torch.optim.Adam(params=net.parameters(), lr=0.00002)
    loss_func0 = torch.nn.CrossEntropyLoss(ignore_index=3)

    bar = tqdm(range(10))
    steps = 0
    epoch_loss = []
    batch_size = 30

    for epoch in bar:
        paths = glob("./pre_data_set_*.pkl")
        data_set = []
        for ii in range(0, len(paths), 2):

            for one_path in paths[ii:ii + 2]:

                data_set = pd.read_pickle(one_path)
                np.random.shuffle(data_set)
                loss_list = []
                for i in range(0, len(data_set), batch_size):
                    # weights.append(list(net.state_dict().values())[0])
                    j = i + batch_size
                    input_one = data_set[i:j]

                    out0, _ = net(torch.Tensor(input_one)[:, :-1].int().to("cuda"))
                    loss = loss_func0(out0.reshape([-1, out0.shape[-1]]),
                                      torch.Tensor(input_one)[:, 1:].reshape([-1]).long().to("cuda"))

                    loss_list.append(loss.item())
                    bar.set_description(
                        "epoch___{}____loss___{:.6f}____steps___{}".format(epoch, np.mean(loss_list), steps))
                    opt.zero_grad()
                    loss.backward()
                    opt.step()
                    steps += batch_size

                torch.save(net.state_dict(), "pretrain_768.pth")
                # eval_model()
                epoch_loss.append(np.mean(loss_list))
                pd.to_pickle(epoch_loss, "loss916")


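# gen_one_voc(): character-frequency pass over the pretrain and SFT corpora.
# Characters seen more than 100 times become vocab entries (voc.pkl); rarer
# ones are re-encoded as positional digit tokens of their index (voc0.pkl).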
def gen_one_voc():
    data = pd.read_csv("pretrain_data.csv")

    data = data["text"].values.tolist()
    data = "".join(data)
    count = Counter()
    for ii in tqdm(range(0, len(data), len(data) // 8)):
        jj = ii + len(data) // 8
        for k, v in Counter(data[ii:jj]).items():
            count[k] = count.get(k, 0) + v

    data = ""
    data0 = pd.read_csv("sft_data_multi.csv")
    for ii in tqdm(range(0, len(data0), len(data0) // 8)):
        jj = ii + len(data0) // 8
        for k, v in Counter(data0[ii:jj]).items():
            count[k] = count.get(k, 0) + v
    data0 = ""
    data1 = pd.read_csv("sft_data_single.csv")
    for ii in tqdm(range(0, len(data1), len(data1) // 8)):
        jj = ii + len(data1) // 8
        for k, v in Counter(data1[ii:jj]).items():
            count[k] = count.get(k, 0) + v
    data1 = ""

    # plt.plot(sorted(count.values()))
    # plt.show()
    count = pd.DataFrame({"voc": count.keys(), "count": count.values()})
    voc = count.loc[count["count"] > 100, "voc"].values.tolist()
    voc0 = [[[["<|pos_{}_{}|>".format(jj, ii) for jj, ii in enumerate(list(str(i)))], j] for i, j in
             enumerate(count.loc[count["count"] <= 100, "voc"].values.tolist())]]
    pd.to_pickle(voc, "voc.pkl")
    pd.to_pickle(voc0, "voc0.pkl")


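# gen_voc(): merges voc.pkl and voc0.pkl, adds the <|pos_i_j|> digit tokens and
# the special tokens, and writes the combined table to total_voc.pkl.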
def gen_voc():
    voc = pd.read_pickle("voc.pkl")
    voc0 = pd.read_pickle("voc0.pkl")
    voc0 = {j: i for i, j in voc0[0]}
    for i in range(6):
        for j in range(10):
            voc.append("<|pos_{}_{}|>".format(i, j))
    voc = ["<|sos|>", "<|user|>", "<|agent|>", "<|pad|>", "<|history|>"] + sorted(voc)

    pd.to_pickle({"voc": voc, "voc0": voc0}, "total_voc.pkl")


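# gen_pre_data_align(num, total_num): converts the num-th of total_num shards
# of pretrain_data.csv into fixed-length (512) id sequences padded with id 3.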
def gen_pre_data_align(num, total_num):
    voc = pd.read_pickle("total_voc.pkl")
    voc["voc0"] = [[i, [voc["voc"].index(j) for j in ii]] for i, ii in voc["voc0"].items()]
    voc["voc"] = [i for i in voc["voc"]]
    voc = {"voc": voc["voc"] + [i for i, j in voc["voc0"]],
           "voc_id": [[i] for i in list(range(len(voc["voc"])))] + [j for i, j in voc["voc0"]]}
    voc = pd.DataFrame(voc)
    # voc=pl.DataFrame(voc)

    pre_data = pl.read_csv("pretrain_data.csv")
    pre_data = pre_data["text"].to_numpy().tolist()
    count = len(pre_data) // total_num
    pre_data = pre_data[(num - 1) * count:count * num]
    data_set = []
    bar = tqdm(range(len(pre_data)))

    while pre_data:
        bar.update()
        one = pre_data.pop()
        one = pd.merge(pd.DataFrame({"voc": list(one)}), voc, on="voc", how="left")

        thr = np.hstack(one["voc_id"].to_numpy()).tolist()

        thr += (518 - len(thr)) * [3]
        thr = thr[:512]
        data_set.append(thr)
    pd.to_pickle(data_set, "pre_data_set_{}.pkl".format(num))


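# gen_sft_single_data_align(): encodes each (history, question, answer) row as
# <|user|> q <|agent|> a, maps chars to ids (rare chars via voc0, unknown chars
# to the pad id 3), pads/truncates to 512, and writes pickled shards.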
def gen_sft_single_data_align():
    voc = pd.read_pickle("total_voc.pkl")
    voc["voc0"] = {i: [voc["voc"].index(j) for j in ii] for i, ii in voc["voc0"].items()}
    voc["voc"] = {v: i for i, v in enumerate(voc["voc"])}

    pre_data = pl.read_csv("sft_data_single.csv")
    pre_data = pre_data.to_numpy().tolist()
    data_set = []
    index_id = 0
    for h, q, a in tqdm(pre_data):
        index_id += 1
        one = ["<|user|>"] + list(q) + ["<|agent|>"] + list(a)
        one_list = []
        for i in one:
            voc_id = voc["voc"].get(i, None)
            if voc_id is not None:
                one_list.append(voc_id)
            else:
                one_list += voc["voc0"].get(i, [3])
        one_list += (512 - len(one_list)) * [3]
        data_set.append(one_list[:512])
        if len(data_set) > 1000000:
            pd.to_pickle(data_set, "sft_data_single_{}.pkl".format(index_id))
            data_set = []
    pd.to_pickle(data_set, "sft_data_single_{}.pkl".format(index_id))


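# train_single(): SFT fine-tuning loop, same structure as train() but over the
# sft_data_*.pkl shards with batch size 80 and a lower learning rate.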
def train_single():
    voc = pd.read_pickle("total_voc.pkl")

    net = SamOut(len(voc["voc"]), 512, 32, 8)

    net.load_state_dict(torch.load("pretrain_sft_single.pth"))
    net.to("cuda")

    opt = torch.optim.Adam(params=net.parameters(), lr=0.000003)
    loss_func0 = torch.nn.CrossEntropyLoss(ignore_index=3)

    bar = tqdm(range(2))
    steps = 0
    epoch_loss = []

    for epoch in bar:
        paths = glob("./sft_data_*.pkl")
        np.random.shuffle(paths)
        for o in range(0, len(paths), 2):
            data_set = []
            for one_path in paths[o:o + 2]:
                data_set += pd.read_pickle(one_path)

            np.random.shuffle(data_set)

            loss_list = []
            for i in range(0, len(data_set), 80):
                # weights.append(list(net.state_dict().values())[0])
                j = i + 80
                input_one = data_set[i:j]

                out0, _ = net(torch.Tensor(input_one)[:, :-1].int().to("cuda"))
                loss = loss_func0(out0.reshape([-1, out0.shape[-1]]),
                                  torch.Tensor(input_one)[:, 1:].reshape([-1]).long().to("cuda"))

                loss_list.append(loss.item())
                bar.set_description(
                    "epoch___{}____loss___{:.6f}____steps___{}".format(epoch, np.mean(loss_list), steps))
                opt.zero_grad()
                loss.backward()
                opt.step()
                steps += 80

            torch.save(net.state_dict(), "pretrain_sft_single.pth")
            # eval_model()
            epoch_loss.append(np.mean(loss_list))
            pd.to_pickle(epoch_loss, "loss916")


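# load_model_and_voc(): builds the 768-dim SamOut, loads the SFT checkpoint,
# and returns the model in eval mode together with the vocabulary tables.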
def load_model_and_voc(device="cpu"):
    voc = pd.read_pickle("total_voc.pkl")

    net = SamOut(len(voc["voc"]), 768, 32, 16)
    # net = SamOut(len(voc["voc"]), 512, 32, 8)
    print(sum([i.shape[0] * i.shape[1] for i in net.parameters() if len(i.shape) > 1]) + sum(
        [i.shape[0] for i in net.parameters() if len(i.shape) == 1]))

    # net.load_state_dict(torch.load("pretrain_768.pth", map_location=device))
    # net.load_state_dict(torch.load("pretrain_sft_single.pth", map_location=device))
    net.load_state_dict(torch.load("pretrain_sft_single_768.pth", map_location=device))
    # net.load_state_dict(torch.load("pretrain.pth", map_location=device))
    net.to(device)
    net.eval()
    return net, voc


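# gen_token(): the A*-like decoding step from the title. Interpreting the code
# (an assumption, since the source does not spell it out):
#   gn ~ g(n), the path cost: how strongly the model assigns probability to the
#        tokens already in the prompt, weighted by an embedding similarity term;
#   hn ~ h(n), the heuristic: similarity of each candidate token's embedding to
#        the last prompt token, weighted by its temperature-scaled probability.
# The next token is chosen to minimize the combined score f(n) = g(n) + h(n).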
def gen_token(voc, model, prompt, max_len, rp=1.2, temp=0.5, top_k=16, device="cpu"):
    print("agent:", end="", flush=True)

    for _ in range(max_len):

        prompt_list = []
        for i in prompt:
            if i not in voc["voc"]:
                # unknown chars fall back to the pad token (an assumption; the
                # original would crash on a char missing from both maps)
                prompt_list += [voc["voc"].index(ii) for ii in voc["voc0"].get(i, ["<|pad|>"])]
            else:

                prompt_list.append(voc["voc"].index(i))
        # g(n): probability the model assigns to each prompt token at its own
        # position, weighted by the cosine similarity of that token's embedding
        # with itself shifted by one dimension (model.em is the embedding layer)
        prompt_tensor = model.em(torch.Tensor([prompt_list]).to(device).long())
        prompt_tensor = torch.nn.functional.cosine_similarity(prompt_tensor[:, :, :-1], prompt_tensor[:, :, 1:], dim=-1)
        out, _ = model(torch.Tensor([prompt_list]).to(device).long())
        gn = np.array([torch.nn.functional.softmax(out, -1)[:, i, ii].item() for i, ii in enumerate(prompt_list)]) * prompt_tensor.detach().numpy().reshape(-1)
        out = out[:, -1:]
        # repetition penalty: damp the logit of every token already in the
        # prompt (the original looped over enumerate(prompt_list), which
        # indexes with (position, token) tuples rather than token ids)
        for token_id in set(prompt_list):
            out[:, :, token_id] /= rp
        score = torch.softmax(out, -1)[0, 0]
        score, score_index = torch.sort(score)
        score = score.detach().numpy()
        score_sum = np.cumsum(score)
        score_index = score_index.detach().numpy()
        # drop the low-probability tail holding the bottom 20% of the mass
        # (a nucleus-style cutoff; score is sorted ascending at this point)
        score = score[score_sum > 0.2]
        score_index = score_index[score_sum > 0.2]
        # reverse to descending order; copy() so torch can consume the arrays,
        # since torch rejects numpy views with negative strides
        score = score[::-1].copy()
        score_index = score_index[::-1].copy()
        score /= temp

        # h(n): similarity of each candidate's embedding to the last prompt
        # token, weighted by its temperature-scaled probability
        hn = torch.nn.functional.cosine_similarity(model.em(torch.Tensor([score_index]).long()),
                                                   model.em(torch.Tensor([prompt_list[-1:]]).long()), -1)[
            0].detach().numpy() * score
        # A*-like choice: minimize f(n) = g(n) + h(n) summed over prompt
        # positions (the original passed the axis to argmin instead of sum,
        # which raises on the resulting 0-d array)
        idx_index = score_index[np.argmin(np.sum(gn.reshape([-1, 1]) + hn, 0))]

        # out = score / temp

        # v = out[:min(top_k, score.size)]

        # idx_next = torch.multinomial(torch.Tensor(v), num_samples=1, generator=None)
        if voc["voc"][idx_index] == "<|sos|>":
            break
        prompt += [voc["voc"][idx_index]]
        print(prompt[-1], end="", flush=True)

      


def t_infre():
    model, voc = load_model_and_voc()
    while True:
        text = input("user:")
        gen_token(voc, model, ["<|user|>"] + list("{}".format(text)) + ["<|agent|>"], 100)
        print()


if __name__ == '__main__':
    # print(pd.read_pickle("loss916"))
    # gen_one_voc()
    # gen_voc()
    # for i in range(17,18):
    #     gen_pre_data_align(i, 16)

    # train()
    # gen_sft_single_data_align()
    # train_single()
    # sft inference: the model has mastered spouting nonsense in all seriousness

    t_infre()
```
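To make the A*-like selection in `gen_token` concrete, here is a self-contained toy version of the final scoring step, with made-up numbers standing in for real model outputs; the names mirror `gn`, `hn`, and `score_index` above.

```python
import numpy as np

# toy stand-ins for the quantities gen_token computes from the model:
gn = np.array([0.10, 0.30, 0.20])        # g(n): one entry per prompt token
hn = np.array([0.05, 0.40, 0.15, 0.25])  # h(n): one entry per candidate token
score_index = np.array([17, 4, 92, 8])   # vocabulary ids of the candidates

# broadcast g(n) over the candidates, sum over prompt positions, and pick the
# candidate with the smallest combined cost f(n) = g(n) + h(n)
f = np.sum(gn.reshape([-1, 1]) + hn, 0)   # -> [0.75, 1.80, 1.05, 1.35]
idx_index = score_index[np.argmin(f)]     # -> 17
```

Note that summing over prompt positions makes the g(n) part a constant offset, so the choice effectively reduces to the candidate with the smallest hn; whether that is intended is not clear from the source.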