基于Deepseek大模型API完成文本分类预测功能

方式1:ChatOpenAI

python 复制代码
from langchain_openai import ChatOpenAI
from dotenv import load_dotenv
import os
import warnings

warnings.filterwarnings('ignore')

# 加载环境变量
load_dotenv()

# 实例化llm模型对象
llm = ChatOpenAI(base_url=os.getenv("base_url"),
                 api_key=os.getenv("DEEPSEEK_API_KEY"),
                 model='deepseek-chat')

print('llm--->', llm)

# 准备提示词, 消息
# system: 系统角色, 描述大模型的角色定位
# user: 用户角色, 描述用户意图
prompt = [{"role": "system", "content": "You are a helpful assistant"},
          {"role": "user", "content": "你是谁?"}]

# 模型预测/推理
response = llm.invoke(prompt)
print('response--->', response)
print('=' * 80)
print(response.content)

方式2:OpenAI

python 复制代码
from openai import OpenAI

# for backward compatibility, you can still use `https://api.deepseek.com/v1` as `base_url`.
client = OpenAI(api_key=os.getenv("DEEPSEEK_API_KEY"), base_url="https://api.deepseek.com")

response = client.chat.completions.create(
    model="deepseek-chat",
    messages=[
        {"role": "system", "content": "You are a helpful assistant"},
        {"role": "user", "content": "Hello"},
],
    max_tokens=1024,
    temperature=0.7,
    stream=False
)
print('response--->\n', response)
print(response.choices[0].message.content)

config路径管理

python 复制代码
class Config(object):
    def __init__(self):
        # 原始数据路径
        self.train_datapath = "../01-data/train.txt"
        self.test_datapath = "../01-data/test.txt"
        self.dev_datapath = "../01-data/dev2.txt"
        self.class_datapath = "../01-data/class.txt"


if __name__ == '__main__':
    conf = Config()
    print(conf.train_datapath)
    print(conf.test_datapath)

LLM预测方法封装

python 复制代码
import os
import json
import time
# 导入LangChain的ChatOpenAI模块,用于与大语言模型交互
from langchain_openai import ChatOpenAI
# 导入tenacity库的相关装饰器,用于实现重试机制
from tenacity import retry, stop_after_attempt, wait_fixed
# 导入dotenv库,用于加载.env文件中的环境变量
from dotenv import load_dotenv
import warnings

warnings.filterwarnings("ignore")

# todo:1-加载环境变量
load_dotenv()

# todo:2-实例化llm模型对象
llm = ChatOpenAI(base_url=os.getenv('base_url'),
                 api_key=os.getenv('DEEPSEEK_API_KEY'),
                 model='deepseek-chat',
                 # 指定模型返回的格式为json格式
                 model_kwargs={"response_format": {"type": "json_object"}})


# todo:3-使用tenacity装饰器对llm模型进行封装, 实现重试机制
# stop_after_attempt: 重试次数限制
# wait_fixed: 重试间隔时间
@retry(stop=stop_after_attempt(3), wait=wait_fixed(5))
def llm_invoke(llm, prompt):
    return llm.invoke(prompt)


# todo:4-预测主函数
def model2pred(title):
    """
    预测函数
    :param title: str 文本字符串 "xxxxxx"
    :return: dict {类别:xxx, 分类原因:xxx}  请求时间
    """
    start_time = time.time()
    # 构建提示词 prompt
    prompt = [
        {
            "role": "system",
            "content": """
                        你是一名新闻分类审核员,任务是将新闻标题分类到以下类别之一:
                        finance, realty, stocks, education, science, society, politics, sports, game, entertainment。
                        请根据标题内容和以下关键词与示例,匹配最相关的类别。如果标题涉及教育机构但核心是社会贡献,优先归为 society。
                        返回 JSON 格式:{"category": "类别", "reason": "分类原因"}

                        类别关键词与示例:
                        - finance: 银行、信用卡、贷款、利率 (例: "各银行信用卡挂失费迥异")
                        - realty: 房产、地价、楼盘 (例: "东5环海棠公社230平准现房")
                        - stocks: 股市、股指、期货 (例: "金证顾问:过山车行情意味着什么")
                        - education: 学校、考试、招生 (例: "中华女子学院仅1专业招男生")
                        - science: 技术、网站、宇航 (例: ""手机钱包"亮相科博会")
                        - society: 社会事件、犯罪、公益 (例: "82岁老太为学生做饭扫地44年")
                        - politics: 政策、国际关系 (例: "查韦斯称愿为俄罗斯提供空军基地")
                        - sports: 比赛、运动员、奥运 (例: "卡佩罗:德国脚生猛的原因")
                        - game: 电子游戏、网游、电竞 (例: "《赤壁OL》攻城战硝烟又起")
                        - entertainment: 明星、影视、综艺 (例: "冯德伦徐若瑄隔空传情")
                        """
        },
        {
            "role": "user",
            "content": f"新闻标题:'{title}',请分类并说明原因。"
        }
    ]

    # 调用模型进行预测
    response = llm_invoke(llm, prompt)
    # print('response--->', response)
    # 获取模型返回的json格式, 解析为字典
    print(type(response.content))
    data = json.loads(response.content)
    print('data--->', type(data), data)
    # print('data--->', type(data), data)
    # 获取字典中的类别名称 和 分类原因
    # get():如果键不存在,则返回society值
    category = data.get('category', 'society')
    reason = data.get('reason', '未明确分类,归为社会类别')
    # 统计请求消耗时间
    duration = int(time.time() - start_time) * 1000
    # 返回结果
    return {"category": category, "reason": reason}, duration


if __name__ == '__main__':
    result, duration = model2pred("美国政府没收缅北电诈集团老大的比特币资产")
    print('预测结果:\n', result['category'])
    print('分类原因:\n', result['reason'])
    print('请求耗时:\n', duration, 'ms')

加载文件完成模型推理评估

python 复制代码
import pandas as pd
from tqdm import tqdm
from rich import print
from sklearn.metrics import accuracy_score, precision_score, f1_score, classification_report
from model2pred import *
import time
import os

# todo:1-获取 类别2下标 映射字典
class2id = {line.strip(): i for i, line in enumerate(open(os.getenv('class_file'), 'r', encoding='utf-8'))}
print('class2id--->', class2id)


# todo:2-加载验证数据集, 获取文本和真实标签
def load_data(file_path):
    df = pd.read_csv(file_path, sep='\t', header=None, names=['text', 'label'], encoding='utf-8')
    # print(df.head())
    return df['text'].to_list(), df['label'].to_list()


# todo:3-模型评估
def evaluate(file_path):
    start_time = time.time()
    # 创建两个空列表, 保存预测下标和预测信息
    preds_ids_list = []
    result = []
    # 加载验证数据集, 获取文本和真实标签
    texts, true_labels_list = load_data(file_path)

    # 循环变量进行预测
    for text in tqdm(texts, desc='Evaluate...'):
        # 调用预测接口
        pred_result, _ = model2pred(text)
        print('pred_result--->', pred_result)

        # 获取预测结果
        category = pred_result['category']
        category_id = class2id.get(category, 5)
        preds_ids_list.append(category_id)

        # 保存其他信息
        result.append({"title": text,
                        "category": category,
                        "reason": pred_result["reason"]})

    # 计算评估指标
    accuracy = accuracy_score(true_labels_list, preds_ids_list)  # 计算整体准确率
    precision = precision_score(true_labels_list, preds_ids_list, average='micro')  # 使用微平均计算精确率
    f1score = f1_score(true_labels_list, preds_ids_list, average='micro')  # 使用微平均计算F1分数
    report = classification_report(true_labels_list, preds_ids_list)  # 生成详细的分类报告
    elapsed_time = (time.time() - start_time)  # 计算总耗时(秒)

    # 返回所有评估结果
    return accuracy, precision, f1score, report, result, elapsed_time


if __name__ == '__main__':
    accuracy, precision, f1score, report, result, elapsed_time = evaluate('../01-data/dev2.txt')
    print('accuracy--->', accuracy)
    print('precision--->', precision)
    print('f1score--->', f1score)
    print('report--->', report)
    print('elapsed_time--->', elapsed_time)
相关推荐
MiNG MENS4 分钟前
基于SpringBoot和Leaflet的行政区划地图掩膜效果实战
java·spring boot·后端
IT_陈寒6 分钟前
Vite静态资源加载把我坑惨了
前端·人工智能·后端
2601_949814698 分钟前
Spring Boot中的404错误:原因、影响及处理策略
java·spring boot·后端
herinspace8 分钟前
管家婆实用贴-如何分离和附加数据库
开发语言·前端·javascript·数据库·语音识别
后端小肥肠8 分钟前
我把自己蒸馏成小肥肠.skill,相关答疑全能做,一人公司终于能聚焦核心业务
人工智能·agent
天一生水water25 分钟前
Time-Series-Library 仓库的使用
人工智能
HeteroCat25 分钟前
DeepSeek V4 来了:我熬了一中午,把技术报告啃完了
人工智能
阿杰学AI30 分钟前
AI核心知识135—大语言模型之 OpenClaw(简洁且通俗易懂版)
人工智能·ai·语言模型·自然语言处理·aigc·ai编程·openclaw
小码哥_常32 分钟前
从MVC到MVI:一文吃透架构模式进化史
前端
嗷o嗷o33 分钟前
Android BLE 的 notify 和 indicate 到底有什么区别
前端