NLP —— 迁移学习 FastText

一、FastText 介绍

官网:https://fasttext.cc/

1.作用:作为NLP工程领域常用的工具包,FastText有两大作用

① 进行文本分类

② 训练词向量

模型领域中起着承上启下的作用

上:transformer

下:训练模型 || 大模型

2.FastText工具包的优势

① fastText 工具包中内含的fastText模型 十分简单的网络结构

② 使用fastText模型训练词向量时使用层次softmax结构,来提升超多类别下的模型性能。

③ 由于fasttext模型过于简单无法捕捉词序特征, 因此会进行n-gram特征提取以弥补模型缺陷提升精度

3.FastText安装

优先用:pip install fasttext

报错就用:pip install fasttext-wheel

二、FastText 模型架构

FastText 模型架构和 Word2Vec 中的 CBOW 模型很类似, 不同之处在于, FastText 预测标签, 而 CBOW 模型预测中间词.

FastText的模型分为三层架构:

  • 输入层: 是对文档embedding之后的向量, 包含N-gram特征

  • 隐藏层: 是对输入数据的求和平均

  • 输出层: 是文档对应的label

层次softmax(hierarchical softmax)

  • 为了提高效率, 在fastText中计算分类标签概率的时候, 不再使用传统的softmax来进行多分类的计算, 而是使用哈夫曼树, 使用层次化的softmax来进行概率的计算.

二、FastText 文本分类

文本分类的过程

  • 第一步: 获取数据

  • 第二步: 训练集与验证集的划分

  • 第三步: 训练模型

  • 第四步: 使用模型进行预测并评估

  • 第五步: 模型调优

  • 第六步: 模型保存与重加载

API使用代码

python 复制代码
# fasttext文本分类API的使用
import fasttext

# 1- 模型训练和预测
def demo01_base():
    # 1- 模型训练
    model = fasttext.train_supervised(input="data/cooking_train.txt")

    # 2- 使用训练好的模型对数据进行预测
    result1 = model.predict("Which baking dish is best to bake a banana bread ?")
    print(result1)
    result2 = model.predict("Why not put knives in the dishwasher?")
    print(result2)

    # 3- 模型测试
    result = model.test(path="data/cooking_valid.txt")
    print(result)   # (3000, 0.15566666666666668, 0.06732016721925904) 样本条数 精确率 召回率

# 2- 数据基本处理:统一成大小写、标点符号前面加空格。。。
def demo02_preprocessing():
    # 1- 模型训练
    model = fasttext.train_supervised(input="data/cooking.pre.train")

    # 2- 模型测试
    result = model.test(path="data/cooking.pre.valid")
    print(result)

# 3- 增加训练轮次
def demo03_epoch():
    # 1- 模型训练
    model = fasttext.train_supervised(input="data/cooking.pre.train",epoch=20)

    # 2- 模型测试
    result = model.test(path="data/cooking.pre.valid")
    print(result)

# 4- 调整学习率
def demo04_lr():
    # 1- 模型训练
    model = fasttext.train_supervised(input="data/cooking.pre.train",epoch=20,lr=1)

    # 2- 模型测试
    result = model.test(path="data/cooking.pre.valid")
    print(result)

# 5- 设置n-gram参数
def demo05_n_gram():
    # 1- 模型训练
    model = fasttext.train_supervised(input="data/cooking.pre.train",epoch=20,lr=1,wordNgrams=2)

    # 2- 模型测试
    result = model.test(path="data/cooking.pre.valid")
    print(result)

# 6- 调整损失函数
def demo06_loss():
    # 1- 模型训练
    model = fasttext.train_supervised(input="data/cooking.pre.train",epoch=20,lr=1,wordNgrams=2,loss="hs")

    # 2- 模型测试
    result = model.test(path="data/cooking.pre.valid")
    print(result)

# 7- 自动超参数调优
def demo07_auto():
    # 1- 模型训练
    """
        参数解释:
            autotuneDuration:查找最优超参数组合的时间,最终找到的参数组合不一定是最优的
            autotuneValidationFile:找到最优超参数组合,使用验证集数据对参数效果进行验证
    """
    model = fasttext.train_supervised(
        input="data/cooking.pre.train",
        autotuneValidationFile="data/cooking.pre.valid",
        autotuneDuration=60*2
    )

    # 2- 模型测试
    result = model.test(path="data/cooking.pre.valid")
    print(result)

# 8- 多标签多分类问题:将问题拆解为单标签多分类问题,loss损失函数需要设置为ova->one vs all
def demo08_ova():
    # 1- 模型训练
    # 注意:lr的学习率不要过大,如果过大会出现梯度消失/爆炸的情况
    model = fasttext.train_supervised(input="data/cooking.pre.train",epoch=20,lr=0.1,wordNgrams=2,loss="ova")

    # 2- 预测
    """
        k:表示预测结果中目标值最多展示多少个。如果值为-1,那么是尽可能将所有的目标值全部都展示
        threshold:阈值。如果预测的目标值的概率超过了threshold,那么才有可能显示出来
        
        result1中输出的概率值是经过sigmoid计算后的结果,计算是__label__baking标签的概率以及不是__label__baking标签的概率。
        多标签中每个标签的概率计算是相互独立的
    """
    result1 = model.predict("Which baking dish is best to bake a banana bread ?",k=3,threshold=0.5)
    print(result1)

    result2 = model.predict("Which baking dish is best to bake a banana bread ?",k=-1)
    print(result2)

    # 3- 模型测试
    result = model.test(path="data/cooking.pre.valid")
    print(result)

# 9- 保存模型和重新加载模型
def demo09_savemodel():
    # 1- 模型训练
    model = fasttext.train_supervised(input="data/cooking.pre.train",epoch=20,lr=1,wordNgrams=2,loss="hs")

    # 2- 保存训练好的模型
    model.save_model("model/cooking_model.pkl")

    # 3- 加载训练好的模型
    model2 = fasttext.load_model("model/cooking_model.pkl")
    result = model2.test(path="data/cooking.pre.valid")
    print(result)

if __name__ == '__main__':
    # 1- 模型训练和预测
    # demo01_base()   # (3000, 0.15566666666666668, 0.06732016721925904)

    # 2- 数据基本处理
    # demo02_preprocessing() # (3000, 0.172, 0.07438373936860314)

    # 3- 增加训练轮次
    # demo03_epoch()  # (3000, 0.48733333333333334, 0.21075392821104225)

    # 4- 调整学习率
    # demo04_lr() # (3000, 0.5976666666666667, 0.2584690788525299)

    # 5- 设置n-gram参数
    # demo05_n_gram() # (3000, 0.596, 0.2577483061842295)

    # 6- 调整损失函数
    # demo06_loss() # (3000, 0.5946666666666667, 0.2571716880495892)

    # 7- 自动超参数调优
    # demo07_auto() # (3000, 0.536, 0.23180049012541445)

    # 8- 多标签多分类问题
    # demo08_ova() # (3000, 0.532, 0.23007063572149344)

    # 9- 保存模型和重新加载模型
    demo09_savemodel()
相关推荐
寺中人1 小时前
基于 5G 物联网的智慧养老全方位安全监测系统
人工智能·物联网·5g·安全·智能家居
Python私教1 小时前
AI Agent 9秒删库跑路?Cursor安全红线警示录
人工智能·安全
qq_411262421 小时前
四博AI双目智能音箱方案:四路触控、震动马达、0.71/1.28双目光屏、三轴姿态感应,一键语音克隆和专属知识库
人工智能·apache·智能音箱
司南OpenCompass1 小时前
GPT领跑,头部模型“错位竞争”,强Agent能力成下一战场丨大语言模型4月最新榜单揭晓
人工智能·gpt·语言模型·大模型·大模型评测·司南评测
栈溢出了1 小时前
GIN学习笔记
人工智能·神经网络·算法·机器学习·gin
Y敲键盘的地方1 小时前
第5章 模块化设计
人工智能·ai编程
qq_411262421 小时前
基于 ESP32-S3 的四博AI双目智能音箱方案:0.71/1.28双目光屏、四路触控、三轴姿态、震动马达、语音克隆与专属知识库接入
人工智能·智能音箱
chenyuhao20241 小时前
AI agent 开发之嵌入模型和提示词 前置知识
人工智能·深度学习·算法·langchain·agent·ai应用开发
财经资讯数据_灵砚智能1 小时前
基于全球经济类多源新闻的NLP情感分析与数据可视化(日间)2026年5月14日
大数据·人工智能·python·信息可视化·自然语言处理