模型训练与验证minicpm-v

minicpm-v 模型进行微调并进行验证

训练使用混合数据集进行训练,对minicpm-V进行lora微调,微调后使用llama3_1对输出结果与标签值进行比对,计算准确率。
验证代码为:

python 复制代码
# URL = https://swift.readthedocs.io/zh-cn/latest/LLM/VLLM%E6%8E%A8%E7%90%86%E5%8A%A0%E9%80%9F%E4%B8%8E%E9%83%A8%E7%BD%B2.html

'''
使用swift进行部署的示例
model: llama3_1-8b-instruct
CUDA_VISIBLE_DEVICES=2 swift deploy --max_model_len 4096 --model_type llama3_1-8b-instruct --model_id_or_path /nas/share/model/huggingface/models--meta-llama--Meta-Llama-3.1-8B-Instruct/snapshots/8c22764a7e3675c50d4c7c9a4edb474456022b16
'''

# 客户端
# example of using swift client

import os
import json
from tqdm import tqdm
from swift.llm import get_model_list_client, XRequestConfig, inference_client

model_list = get_model_list_client()
model_type = model_list.data[0].id


def get_data_jsonl(data_path):
    datas =[]
    with open(data_path, 'r') as f:
        data = f.readlines()
        
    for i in range(len(data)):
        datas.append(json.loads(data[i]))
        
    return datas

def save_data_jsonl(save_path:str, datas: list):
    with open(save_path, 'w') as f:
        for data in datas:
            json.dump(data, f,ensure_ascii=False)
            f.write('\n')
    return save_path


def get_save_path(data_path:str):
    save_path =os.path.splitext(data_path)[0] + '_llm_eval.jsonl'
    if not os.path.exists(os.path.dirname(save_path)):
        os.makedirs(os.path.dirname(save_path))
    return save_path
    
def get_result_correct_path(data_path:str):
    result_correct_path =os.path.splitext(data_path)[0] + '_llm_eval_correct.json'
    if not os.path.exists(os.path.dirname(result_correct_path)):
        os.makedirs(os.path.dirname(result_correct_path))
    return result_correct_path

def save_result_correct(data_path:str,precision:float):
    result =  {
        'data_path': data_path,
        'precision': precision
    }
    with open(data_path, 'w') as f:
        json.dump(result, f,ensure_ascii=False)
    print(f'precision: {precision*100}%')
    return result


def main():
    
    def q_tempalte(response, reply):
        question_template = '''你可以作为一个语言专家,判断一下两个回答是否相同吗 \n
                "response": "{}",
                "reply": "{}",
                ----
                如果是,返回"YES",否则返回"NO"。
                '''.format(response, reply)
        return question_template

    # origin
    data_paths = 
    
    for data_path in tqdm(data_paths):
        datas = get_data_jsonl(data_path)

        save_data =[]
        total_correct = 0
        total = 0
        for data in tqdm(datas):
            
            response = data['response']
            reply = data['reply']
            
            question = q_tempalte(response, reply)
            
            request_config = XRequestConfig(max_tokens=32, temperature=0.1, seed=42)    
            resp = inference_client(model_type, question, request_config=request_config)
            response = resp.choices[0].message.content
            data['llm_response'] = response
            save_data.append(data)
            
            if response.lower() == 'yes':
                total_correct += 1
                
            total += 1
            
            precision = total_correct / total    

        # 保存数据
        save_path = get_save_path(data_path)
        save_data_jsonl(save_path, save_data)
        
        # 保存精度
        result_correct_path = get_result_correct_path(data_path)
        save_result_correct(result_correct_path, precision)
if __name__ == '__main__':
    main()
相关推荐
一个会的不多的人7 分钟前
人工智能基础篇:概念性名词浅谈(第二十九讲)
人工智能·制造·数字化转型
edisao12 分钟前
四。SpaceX、网络化与未来的跨越:低成本、高频次的真正威胁
大数据·开发语言·人工智能·科技·php
万行13 分钟前
差速两轮机器人位移与航向角增量计算
人工智能·python·算法·机器人
瑞华丽PLM16 分钟前
PLM系统中的BOM管理演进:从数据孤岛到全生命周期协同
大数据·人工智能·plm·国产plm·瑞华丽plm·瑞华丽
咚咚王者20 分钟前
人工智能之核心基础 机器学习 第十六章 模型优化
人工智能·机器学习
电商API_1800790524722 分钟前
1688商品详情采集API全解析:技术原理、实操指南与业务落地
大数据·前端·人工智能·网络爬虫
向上的车轮27 分钟前
麦肯锡《智能体、机器人与我们:AI时代的技能协作》
人工智能·机器人
叫我:松哥29 分钟前
基于Flask框架开发的二手房数据分析与推荐管理平台,集成大数据分析、机器学习预测和智能推荐技术
大数据·python·深度学习·机器学习·数据分析·flask
2501_9458374335 分钟前
数字经济的 “安全基石”—— 云服务器零信任架构如何筑牢数据安全防线
人工智能
2501_9421917737 分钟前
【深度学习应用】香蕉镰刀菌症状识别与分类:基于YOLO13-C3k2-MBRConv5模型的实现与分析
人工智能·深度学习·分类