基于Deepseek大模型API完成文本分类预测功能

方式1:ChatOpenAI

python 复制代码
from langchain_openai import ChatOpenAI
from dotenv import load_dotenv
import os
import warnings

warnings.filterwarnings('ignore')

# 加载环境变量
load_dotenv()

# 实例化llm模型对象
llm = ChatOpenAI(base_url=os.getenv("base_url"),
                 api_key=os.getenv("DEEPSEEK_API_KEY"),
                 model='deepseek-chat')

print('llm--->', llm)

# 准备提示词, 消息
# system: 系统角色, 描述大模型的角色定位
# user: 用户角色, 描述用户意图
prompt = [{"role": "system", "content": "You are a helpful assistant"},
          {"role": "user", "content": "你是谁?"}]

# 模型预测/推理
response = llm.invoke(prompt)
print('response--->', response)
print('=' * 80)
print(response.content)

方式2:OpenAI

python 复制代码
from openai import OpenAI

# for backward compatibility, you can still use `https://api.deepseek.com/v1` as `base_url`.
client = OpenAI(api_key=os.getenv("DEEPSEEK_API_KEY"), base_url="https://api.deepseek.com")

response = client.chat.completions.create(
    model="deepseek-chat",
    messages=[
        {"role": "system", "content": "You are a helpful assistant"},
        {"role": "user", "content": "Hello"},
],
    max_tokens=1024,
    temperature=0.7,
    stream=False
)
print('response--->\n', response)
print(response.choices[0].message.content)

config路径管理

python 复制代码
class Config(object):
    def __init__(self):
        # 原始数据路径
        self.train_datapath = "../01-data/train.txt"
        self.test_datapath = "../01-data/test.txt"
        self.dev_datapath = "../01-data/dev2.txt"
        self.class_datapath = "../01-data/class.txt"


if __name__ == '__main__':
    conf = Config()
    print(conf.train_datapath)
    print(conf.test_datapath)

LLM预测方法封装

python 复制代码
import os
import json
import time
# 导入LangChain的ChatOpenAI模块,用于与大语言模型交互
from langchain_openai import ChatOpenAI
# 导入tenacity库的相关装饰器,用于实现重试机制
from tenacity import retry, stop_after_attempt, wait_fixed
# 导入dotenv库,用于加载.env文件中的环境变量
from dotenv import load_dotenv
import warnings

warnings.filterwarnings("ignore")

# todo:1-加载环境变量
load_dotenv()

# todo:2-实例化llm模型对象
llm = ChatOpenAI(base_url=os.getenv('base_url'),
                 api_key=os.getenv('DEEPSEEK_API_KEY'),
                 model='deepseek-chat',
                 # 指定模型返回的格式为json格式
                 model_kwargs={"response_format": {"type": "json_object"}})


# todo:3-使用tenacity装饰器对llm模型进行封装, 实现重试机制
# stop_after_attempt: 重试次数限制
# wait_fixed: 重试间隔时间
@retry(stop=stop_after_attempt(3), wait=wait_fixed(5))
def llm_invoke(llm, prompt):
    return llm.invoke(prompt)


# todo:4-预测主函数
def model2pred(title):
    """
    预测函数
    :param title: str 文本字符串 "xxxxxx"
    :return: dict {类别:xxx, 分类原因:xxx}  请求时间
    """
    start_time = time.time()
    # 构建提示词 prompt
    prompt = [
        {
            "role": "system",
            "content": """
                        你是一名新闻分类审核员,任务是将新闻标题分类到以下类别之一:
                        finance, realty, stocks, education, science, society, politics, sports, game, entertainment。
                        请根据标题内容和以下关键词与示例,匹配最相关的类别。如果标题涉及教育机构但核心是社会贡献,优先归为 society。
                        返回 JSON 格式:{"category": "类别", "reason": "分类原因"}

                        类别关键词与示例:
                        - finance: 银行、信用卡、贷款、利率 (例: "各银行信用卡挂失费迥异")
                        - realty: 房产、地价、楼盘 (例: "东5环海棠公社230平准现房")
                        - stocks: 股市、股指、期货 (例: "金证顾问:过山车行情意味着什么")
                        - education: 学校、考试、招生 (例: "中华女子学院仅1专业招男生")
                        - science: 技术、网站、宇航 (例: ""手机钱包"亮相科博会")
                        - society: 社会事件、犯罪、公益 (例: "82岁老太为学生做饭扫地44年")
                        - politics: 政策、国际关系 (例: "查韦斯称愿为俄罗斯提供空军基地")
                        - sports: 比赛、运动员、奥运 (例: "卡佩罗:德国脚生猛的原因")
                        - game: 电子游戏、网游、电竞 (例: "《赤壁OL》攻城战硝烟又起")
                        - entertainment: 明星、影视、综艺 (例: "冯德伦徐若瑄隔空传情")
                        """
        },
        {
            "role": "user",
            "content": f"新闻标题:'{title}',请分类并说明原因。"
        }
    ]

    # 调用模型进行预测
    response = llm_invoke(llm, prompt)
    # print('response--->', response)
    # 获取模型返回的json格式, 解析为字典
    print(type(response.content))
    data = json.loads(response.content)
    print('data--->', type(data), data)
    # print('data--->', type(data), data)
    # 获取字典中的类别名称 和 分类原因
    # get():如果键不存在,则返回society值
    category = data.get('category', 'society')
    reason = data.get('reason', '未明确分类,归为社会类别')
    # 统计请求消耗时间
    duration = int(time.time() - start_time) * 1000
    # 返回结果
    return {"category": category, "reason": reason}, duration


if __name__ == '__main__':
    result, duration = model2pred("美国政府没收缅北电诈集团老大的比特币资产")
    print('预测结果:\n', result['category'])
    print('分类原因:\n', result['reason'])
    print('请求耗时:\n', duration, 'ms')

加载文件完成模型推理评估

python 复制代码
import pandas as pd
from tqdm import tqdm
from rich import print
from sklearn.metrics import accuracy_score, precision_score, f1_score, classification_report
from model2pred import *
import time
import os

# todo:1-获取 类别2下标 映射字典
class2id = {line.strip(): i for i, line in enumerate(open(os.getenv('class_file'), 'r', encoding='utf-8'))}
print('class2id--->', class2id)


# todo:2-加载验证数据集, 获取文本和真实标签
def load_data(file_path):
    df = pd.read_csv(file_path, sep='\t', header=None, names=['text', 'label'], encoding='utf-8')
    # print(df.head())
    return df['text'].to_list(), df['label'].to_list()


# todo:3-模型评估
def evaluate(file_path):
    start_time = time.time()
    # 创建两个空列表, 保存预测下标和预测信息
    preds_ids_list = []
    result = []
    # 加载验证数据集, 获取文本和真实标签
    texts, true_labels_list = load_data(file_path)

    # 循环变量进行预测
    for text in tqdm(texts, desc='Evaluate...'):
        # 调用预测接口
        pred_result, _ = model2pred(text)
        print('pred_result--->', pred_result)

        # 获取预测结果
        category = pred_result['category']
        category_id = class2id.get(category, 5)
        preds_ids_list.append(category_id)

        # 保存其他信息
        result.append({"title": text,
                        "category": category,
                        "reason": pred_result["reason"]})

    # 计算评估指标
    accuracy = accuracy_score(true_labels_list, preds_ids_list)  # 计算整体准确率
    precision = precision_score(true_labels_list, preds_ids_list, average='micro')  # 使用微平均计算精确率
    f1score = f1_score(true_labels_list, preds_ids_list, average='micro')  # 使用微平均计算F1分数
    report = classification_report(true_labels_list, preds_ids_list)  # 生成详细的分类报告
    elapsed_time = (time.time() - start_time)  # 计算总耗时(秒)

    # 返回所有评估结果
    return accuracy, precision, f1score, report, result, elapsed_time


if __name__ == '__main__':
    accuracy, precision, f1score, report, result, elapsed_time = evaluate('../01-data/dev2.txt')
    print('accuracy--->', accuracy)
    print('precision--->', precision)
    print('f1score--->', f1score)
    print('report--->', report)
    print('elapsed_time--->', elapsed_time)
相关推荐
IAUTOMOBILE2 小时前
Code Marathon 项目源码解析与技术实践
java·前端·算法
Lyyaoo.2 小时前
【JAVA基础面经】深拷贝与浅拷贝
java·开发语言·算法
饼干哥哥2 小时前
怎么写好一个AI提示词?10个场景与50个技巧+官方100个教程合集
人工智能
. . . . .2 小时前
git-ai 项目详解
人工智能·git
白狐_7982 小时前
深度解析:大语言模型(LLM)联网搜索与实时数据获取的底层原理
人工智能·语言模型·自然语言处理
oyzz1202 小时前
Redis 安装及配置教程(Windows)【安装】
java
AI科技2 小时前
原创音乐人用哼唱歌曲旋律,通过AI编曲软件快速打造出完整歌曲的编曲伴奏
人工智能
名字很费劲2 小时前
vue项目,刷新后出现404错误,怎么解决
前端·javascript·vue·404