如何使用uer做多分类任务

如何使用uer做多分类任务

语料集下载

找到这里点击即可

里面是这有json文件的

因此我们对此要做一些处理,将其转为tsv格式

python 复制代码
# -*- coding: utf-8 -*-
import json
import csv
import chardet

# 检测文件编码
def detect_encoding(file_path):
    with open(file_path, 'rb') as f:
        raw_data = f.read()
    return chardet.detect(raw_data)['encoding']

# 输入文件名
input_file = './datasets/iflytek/train.json'
# 输出文件名
output_file = './datasets/iflytek/train.tsv'

# 检测输入文件的编码格式
file_encoding = detect_encoding(input_file)

# 打开输入的 JSON 文件和输出的 TSV 文件
with open(input_file, 'r', encoding=file_encoding) as json_file, open(output_file, 'w', newline='', encoding='utf-8') as tsv_file:
    # 准备 TSV 写入器
    tsv_writer = csv.writer(tsv_file, delimiter='\t')

    # 写入表头(列表['label', 'label_des', 'sentence']中要注意根据json文件中的键值做更换)
    tsv_writer.writerow(['label', 'label_des', 'sentence'])

    # 逐行读取 JSON 文件
    for line in json_file:
        try:
            # 解析每一行的 JSON 数据
            json_data = json.loads(line.strip())
            # 写入到 TSV 文件中,(列表['label', 'label_des', 'sentence']中要注意根据json文件中的键值做更换)
            tsv_writer.writerow([json_data['label'], json_data['label_des'], json_data['sentence']])
        except json.JSONDecodeError as e:
            print(f"无法解析的行: {line.strip()}")
            print(f"错误信息: {e}")

print(f"JSON 文件已成功转换为 TSV 文件,输入文件编码: {file_encoding}")

接着呢要把所有tsv文件的sentence表头名改成text_a,不然运行uer框架会报错,原因请看源代码逻辑

python 复制代码
def read_dataset(args, path):
    dataset, columns = [], {}
    with open(path, mode="r", encoding="utf-8") as f:
        for line_id, line in enumerate(f):
            if line_id == 0:
                for i, column_name in enumerate(line.rstrip("\r\n").split("\t")):
                    columns[column_name] = i
                continue
            line = line.rstrip("\r\n").split("\t")
            tgt = int(line[columns["label"]])
            if args.soft_targets and "logits" in columns.keys():
                soft_tgt = [float(value) for value in line[columns["logits"]].split(" ")]
            if "text_b" not in columns:  # Sentence classification.
                text_a = line[columns["text_a"]]
                src = args.tokenizer.convert_tokens_to_ids([CLS_TOKEN] + args.tokenizer.tokenize(text_a) + [SEP_TOKEN])
                seg = [1] * len(src)
            else:  # Sentence-pair classification.
                text_a, text_b = line[columns["text_a"]], line[columns["text_b"]]
                src_a = args.tokenizer.convert_tokens_to_ids([CLS_TOKEN] + args.tokenizer.tokenize(text_a) + [SEP_TOKEN])
                src_b = args.tokenizer.convert_tokens_to_ids(args.tokenizer.tokenize(text_b) + [SEP_TOKEN])
                src = src_a + src_b
                seg = [1] * len(src_a) + [2] * len(src_b)

            if len(src) > args.seq_length:
                src = src[: args.seq_length]
                seg = seg[: args.seq_length]
            if len(src) < args.seq_length:
                PAD_ID = args.tokenizer.convert_tokens_to_ids([PAD_TOKEN])[0]
                src += [PAD_ID] * (args.seq_length - len(src))
                seg += [0] * (args.seq_length - len(seg))
            if args.soft_targets and "logits" in columns.keys():
                dataset.append((src, tgt, seg, soft_tgt))
            else:
                dataset.append((src, tgt, seg))

    return dataset

这里规定好了表头名只有label,text_a,text_b

搞完之后进入训练代码,我的显存只有16G,因此

python 复制代码
python finetune/run_classifier.py --pretrained_model_path models/cluecorpussmall_roberta_wwm_large_seq512_model.bin --vocab_path models/google_zh_vocab.txt --config_path models/bert/large_config.json --train_path datasets/iflytek/train.tsv --dev_path datasets/iflytek/dev.tsv --output_model_path models/iflytek_classifier_model.bin --epochs_num 3 --batch_size 16 --seq_length 128


这里可以看到只有61.49的正确率,其实是因为显存还不够,训练不了那么大的,标准的参数应该设置为batch_size=32 seq_length=256

有能力的可以更改参数进行训练

接着来预测

python 复制代码
python inference/run_classifier_infer.py --load_model_path models/iflytek_classifier_model.bin --vocab_path models/google_zh_vocab.txt --config_path models/bert/large_config.json --test_path datasets/iflytek/test.tsv --prediction_path datasets/iflytek/prediction.tsv --seq_length 256 --labels_num 119

最后自行查看预测效果

相关推荐
视***间6 分钟前
端侧大模型落地新标杆:视程空间将GPT-OSS边缘AI深度导入NVIDIA Jetson平台
人工智能·gpt·边缘计算·nvidia·ai算力·gpt-oss·视程空间
1892280486120 分钟前
NY379固态MT29F32T08GSLBHL8-36QA:B
大数据·服务器·人工智能·科技·缓存
Adair_z20 分钟前
[SEO艺术重读] 第9篇 熊猫算法、企鹅算法和惩罚机制
人工智能·熊猫算法·企鹅算法·谷歌算法恢复·网站seo诊断·高质量内容创作·e-e-a-t原则
ZZH_AI项目交付22 分钟前
我把 AI 最容易改坏真实 App 的地方,整理成了 skills
人工智能·ios·app
忆~遂愿23 分钟前
从文字应答到具象共情:Agent 交互的底层革新
人工智能·深度学习·目标检测·microsoft·机器学习·ar·交互
Ai.den24 分钟前
Windows 安装 MinerU 3.x 实现本地批量解析 PDF
人工智能·windows·ai
枫叶林FYL30 分钟前
【强化学习】长上下文可验证奖励强化学习:原理推导与系统架构
人工智能·系统架构
Teable任意门互动30 分钟前
深度解析:AI 赋能开源多维表格,实现企业全场景数据整合与高效应用
数据库·人工智能·低代码·信息可视化·开源·数据库开发
沪漂阿龙33 分钟前
Hermes Agent 安全边界全解析:让 AI Agent 敢执行、可控制、能回滚
人工智能·安全
天天进步201534 分钟前
从零打造 Python 全栈项目:智能教学辅助系统
开发语言·人工智能·python