hg transformers pipeline使用

什么是hg transformers pipeline?

在Hugging Face的transformers库中,pipeline是一个高级API,它提供了一种简便的方式来使用预训练模型进行各种NLP任务,比如情感分析、文本生成、翻译、问答等。通过pipeline,你可以在几行代码内实现复杂的NLP任务。pipeline会自动加载用于指定任务的默认模型和tokenizer,如果需要,用户也可以指定使用特定的模型和tokenizer

在创建pipeline时,除了可以指定任务类型和模型外,还可以设置其他参数,比如使用的深度学习框架("pt"代表PyTorch,"tf"代表TensorFlow)、设备(CPU或GPU)、批量处理大小等 。pipeline背后的实现包括初始化Tokenizer、Model,并进行数据预处理

以下是对pipelines主要特点和功能的总结:

  1. 任务特定: Pipelines为多种NLP任务提供了特定的接口,如文本分类、命名实体识别、问答、文本生成、翻译、摘要和情感分析等。
  2. 模型自动加载: 用户无需关心背后的模型细节,pipelines会自动加载适合任务的预训练模型和tokenizer。
  3. 易于使用: Pipelines提供了简洁的API,用户只需几行代码即可加载模型并进行任务处理。
  4. 自动分词: Pipelines内部处理文本的分词,将文本转换为模型能理解的格式。
  5. 批处理: Pipelines支持批处理,可以同时处理多条文本数据。
  6. 动态调整: Pipelines可以根据输入数据的需要自动调整模型输入,如填充(padding)和截断(truncation)。
  7. 自定义模型和分词器: 用户可以指定自定义的模型和分词器,以适应特定的需求。
  8. 模型微调: 在使用pipelines进行任务之前,用户还可以对模型进行微调,以适应特定的数据集。
  9. 多语言支持: 许多pipelines支持多种语言,使得跨语言的NLP任务成为可能。
  10. 可扩展性: 用户可以根据自己的需求,使用pipelines作为构建块,构建更复杂的NLP流程。
  11. 性能优化: Pipelines针对常见用例进行了优化,以提供高性能的NLP任务处理。
  12. 错误处理: Pipelines提供了错误处理机制,以应对加载模型或处理文本时可能出现的问题。

通过使用pipelines,研究人员和开发者可以快速原型开发和部署NLP应用,而无需深入了解模型的内部工作原理。简而言之,pipelines是Hugging Face Transformers库中一个强大且灵活的工具,用于简化NLP任务的处理流程。

支持的任务分类

可用于音频、计算机视觉、自然语言处理和多模态任务

python 复制代码
TASK_ALIASES = {  
    "sentiment-analysis": "text-classification",  
    "ner": "token-classification",  
    "vqa": "visual-question-answering",  
    "text-to-speech": "text-to-audio",  
}  
SUPPORTED_TASKS = {  
    "audio-classification": {  
        "impl": AudioClassificationPipeline,  
        "tf": (),  
        "pt": (AutoModelForAudioClassification,) if is_torch_available() else (),  
        "default": {"model": {"pt": ("superb/wav2vec2-base-superb-ks", "372e048")}},  
        "type": "audio",  
    },  
    "automatic-speech-recognition": {  
        "impl": AutomaticSpeechRecognitionPipeline,  
        "tf": (),  
        "pt": (AutoModelForCTC, AutoModelForSpeechSeq2Seq) if is_torch_available() else (),  
        "default": {"model": {"pt": ("facebook/wav2vec2-base-960h", "55bb623")}},  
        "type": "multimodal",  
    },  
    "text-to-audio": {  
        "impl": TextToAudioPipeline,  
        "tf": (),  
        "pt": (AutoModelForTextToWaveform, AutoModelForTextToSpectrogram) if is_torch_available() else (),  
        "default": {"model": {"pt": ("suno/bark-small", "645cfba")}},  
        "type": "text",  
    },  
    "feature-extraction": {  
        "impl": FeatureExtractionPipeline,  
        "tf": (TFAutoModel,) if is_tf_available() else (),  
        "pt": (AutoModel,) if is_torch_available() else (),  
        "default": {  
            "model": {  
                "pt": ("distilbert/distilbert-base-cased", "935ac13"),  
                "tf": ("distilbert/distilbert-base-cased", "935ac13"),  
            }  
        },  
        "type": "multimodal",  
    },  
    "text-classification": {  
        "impl": TextClassificationPipeline,  
        "tf": (TFAutoModelForSequenceClassification,) if is_tf_available() else (),  
        "pt": (AutoModelForSequenceClassification,) if is_torch_available() else (),  
        "default": {  
            "model": {  
                "pt": ("distilbert/distilbert-base-uncased-finetuned-sst-2-english", "af0f99b"),  
                "tf": ("distilbert/distilbert-base-uncased-finetuned-sst-2-english", "af0f99b"),  
            },  
        },  
        "type": "text",  
    },  
    "token-classification": {  
        "impl": TokenClassificationPipeline,  
        "tf": (TFAutoModelForTokenClassification,) if is_tf_available() else (),  
        "pt": (AutoModelForTokenClassification,) if is_torch_available() else (),  
        "default": {  
            "model": {  
                "pt": ("dbmdz/bert-large-cased-finetuned-conll03-english", "f2482bf"),  
                "tf": ("dbmdz/bert-large-cased-finetuned-conll03-english", "f2482bf"),  
            },  
        },  
        "type": "text",  
    },  
    "question-answering": {  
        "impl": QuestionAnsweringPipeline,  
        "tf": (TFAutoModelForQuestionAnswering,) if is_tf_available() else (),  
        "pt": (AutoModelForQuestionAnswering,) if is_torch_available() else (),  
        "default": {  
            "model": {  
                "pt": ("distilbert/distilbert-base-cased-distilled-squad", "626af31"),  
                "tf": ("distilbert/distilbert-base-cased-distilled-squad", "626af31"),  
            },  
        },  
        "type": "text",  
    },  
    "table-question-answering": {  
        "impl": TableQuestionAnsweringPipeline,  
        "pt": (AutoModelForTableQuestionAnswering,) if is_torch_available() else (),  
        "tf": (TFAutoModelForTableQuestionAnswering,) if is_tf_available() else (),  
        "default": {  
            "model": {  
                "pt": ("google/tapas-base-finetuned-wtq", "69ceee2"),  
                "tf": ("google/tapas-base-finetuned-wtq", "69ceee2"),  
            },  
        },  
        "type": "text",  
    },  
    "visual-question-answering": {  
        "impl": VisualQuestionAnsweringPipeline,  
        "pt": (AutoModelForVisualQuestionAnswering,) if is_torch_available() else (),  
        "tf": (),  
        "default": {  
            "model": {"pt": ("dandelin/vilt-b32-finetuned-vqa", "4355f59")},  
        },  
        "type": "multimodal",  
    },  
    "document-question-answering": {  
        "impl": DocumentQuestionAnsweringPipeline,  
        "pt": (AutoModelForDocumentQuestionAnswering,) if is_torch_available() else (),  
        "tf": (),  
        "default": {  
            "model": {"pt": ("impira/layoutlm-document-qa", "52e01b3")},  
        },  
        "type": "multimodal",  
    },  
    "fill-mask": {  
        "impl": FillMaskPipeline,  
        "tf": (TFAutoModelForMaskedLM,) if is_tf_available() else (),  
        "pt": (AutoModelForMaskedLM,) if is_torch_available() else (),  
        "default": {  
            "model": {  
                "pt": ("distilbert/distilroberta-base", "ec58a5b"),  
                "tf": ("distilbert/distilroberta-base", "ec58a5b"),  
            }  
        },  
        "type": "text",  
    },  
    "summarization": {  
        "impl": SummarizationPipeline,  
        "tf": (TFAutoModelForSeq2SeqLM,) if is_tf_available() else (),  
        "pt": (AutoModelForSeq2SeqLM,) if is_torch_available() else (),  
        "default": {  
            "model": {"pt": ("sshleifer/distilbart-cnn-12-6", "a4f8f3e"), "tf": ("google-t5/t5-small", "d769bba")}  
        },  
        "type": "text",  
    },  
    # This task is a special case as it's parametrized by SRC, TGT languages.  
    "translation": {  
        "impl": TranslationPipeline,  
        "tf": (TFAutoModelForSeq2SeqLM,) if is_tf_available() else (),  
        "pt": (AutoModelForSeq2SeqLM,) if is_torch_available() else (),  
        "default": {  
            ("en", "fr"): {"model": {"pt": ("google-t5/t5-base", "686f1db"), "tf": ("google-t5/t5-base", "686f1db")}},  
            ("en", "de"): {"model": {"pt": ("google-t5/t5-base", "686f1db"), "tf": ("google-t5/t5-base", "686f1db")}},  
            ("en", "ro"): {"model": {"pt": ("google-t5/t5-base", "686f1db"), "tf": ("google-t5/t5-base", "686f1db")}},  
        },  
        "type": "text",  
    },  
    "text2text-generation": {  
        "impl": Text2TextGenerationPipeline,  
        "tf": (TFAutoModelForSeq2SeqLM,) if is_tf_available() else (),  
        "pt": (AutoModelForSeq2SeqLM,) if is_torch_available() else (),  
        "default": {"model": {"pt": ("google-t5/t5-base", "686f1db"), "tf": ("google-t5/t5-base", "686f1db")}},  
        "type": "text",  
    },  
    "text-generation": {  
        "impl": TextGenerationPipeline,  
        "tf": (TFAutoModelForCausalLM,) if is_tf_available() else (),  
        "pt": (AutoModelForCausalLM,) if is_torch_available() else (),  
        "default": {"model": {"pt": ("openai-community/gpt2", "6c0e608"), "tf": ("openai-community/gpt2", "6c0e608")}},  
        "type": "text",  
    },  
    "zero-shot-classification": {  
        "impl": ZeroShotClassificationPipeline,  
        "tf": (TFAutoModelForSequenceClassification,) if is_tf_available() else (),  
        "pt": (AutoModelForSequenceClassification,) if is_torch_available() else (),  
        "default": {  
            "model": {  
                "pt": ("facebook/bart-large-mnli", "c626438"),  
                "tf": ("FacebookAI/roberta-large-mnli", "130fb28"),  
            },  
            "config": {  
                "pt": ("facebook/bart-large-mnli", "c626438"),  
                "tf": ("FacebookAI/roberta-large-mnli", "130fb28"),  
            },  
        },  
        "type": "text",  
    },  
    "zero-shot-image-classification": {  
        "impl": ZeroShotImageClassificationPipeline,  
        "tf": (TFAutoModelForZeroShotImageClassification,) if is_tf_available() else (),  
        "pt": (AutoModelForZeroShotImageClassification,) if is_torch_available() else (),  
        "default": {  
            "model": {  
                "pt": ("openai/clip-vit-base-patch32", "f4881ba"),  
                "tf": ("openai/clip-vit-base-patch32", "f4881ba"),  
            }  
        },  
        "type": "multimodal",  
    },  
    "zero-shot-audio-classification": {  
        "impl": ZeroShotAudioClassificationPipeline,  
        "tf": (),  
        "pt": (AutoModel,) if is_torch_available() else (),  
        "default": {  
            "model": {  
                "pt": ("laion/clap-htsat-fused", "973b6e5"),  
            }  
        },  
        "type": "multimodal",  
    },  
    "conversational": {  
        "impl": ConversationalPipeline,  
        "tf": (TFAutoModelForSeq2SeqLM, TFAutoModelForCausalLM) if is_tf_available() else (),  
        "pt": (AutoModelForSeq2SeqLM, AutoModelForCausalLM) if is_torch_available() else (),  
        "default": {  
            "model": {"pt": ("microsoft/DialoGPT-medium", "8bada3b"), "tf": ("microsoft/DialoGPT-medium", "8bada3b")}  
        },  
        "type": "text",  
    },  
    "image-classification": {  
        "impl": ImageClassificationPipeline,  
        "tf": (TFAutoModelForImageClassification,) if is_tf_available() else (),  
        "pt": (AutoModelForImageClassification,) if is_torch_available() else (),  
        "default": {  
            "model": {  
                "pt": ("google/vit-base-patch16-224", "5dca96d"),  
                "tf": ("google/vit-base-patch16-224", "5dca96d"),  
            }  
        },  
        "type": "image",  
    },  
    "image-feature-extraction": {  
        "impl": ImageFeatureExtractionPipeline,  
        "tf": (TFAutoModel,) if is_tf_available() else (),  
        "pt": (AutoModel,) if is_torch_available() else (),  
        "default": {  
            "model": {  
                "pt": ("google/vit-base-patch16-224", "3f49326"),  
                "tf": ("google/vit-base-patch16-224", "3f49326"),  
            }  
        },  
        "type": "image",  
    },  
    "image-segmentation": {  
        "impl": ImageSegmentationPipeline,  
        "tf": (),  
        "pt": (AutoModelForImageSegmentation, AutoModelForSemanticSegmentation) if is_torch_available() else (),  
        "default": {"model": {"pt": ("facebook/detr-resnet-50-panoptic", "fc15262")}},  
        "type": "multimodal",  
    },  
    "image-to-text": {  
        "impl": ImageToTextPipeline,  
        "tf": (TFAutoModelForVision2Seq,) if is_tf_available() else (),  
        "pt": (AutoModelForVision2Seq,) if is_torch_available() else (),  
        "default": {  
            "model": {  
                "pt": ("ydshieh/vit-gpt2-coco-en", "65636df"),  
                "tf": ("ydshieh/vit-gpt2-coco-en", "65636df"),  
            }  
        },  
        "type": "multimodal",  
    },  
    "object-detection": {  
        "impl": ObjectDetectionPipeline,  
        "tf": (),  
        "pt": (AutoModelForObjectDetection,) if is_torch_available() else (),  
        "default": {"model": {"pt": ("facebook/detr-resnet-50", "2729413")}},  
        "type": "multimodal",  
    },  
    "zero-shot-object-detection": {  
        "impl": ZeroShotObjectDetectionPipeline,  
        "tf": (),  
        "pt": (AutoModelForZeroShotObjectDetection,) if is_torch_available() else (),  
        "default": {"model": {"pt": ("google/owlvit-base-patch32", "17740e1")}},  
        "type": "multimodal",  
    },  
    "depth-estimation": {  
        "impl": DepthEstimationPipeline,  
        "tf": (),  
        "pt": (AutoModelForDepthEstimation,) if is_torch_available() else (),  
        "default": {"model": {"pt": ("Intel/dpt-large", "e93beec")}},  
        "type": "image",  
    },  
    "video-classification": {  
        "impl": VideoClassificationPipeline,  
        "tf": (),  
        "pt": (AutoModelForVideoClassification,) if is_torch_available() else (),  
        "default": {"model": {"pt": ("MCG-NJU/videomae-base-finetuned-kinetics", "4800870")}},  
        "type": "video",  
    },  
    "mask-generation": {  
        "impl": MaskGenerationPipeline,  
        "tf": (),  
        "pt": (AutoModelForMaskGeneration,) if is_torch_available() else (),  
        "default": {"model": {"pt": ("facebook/sam-vit-huge", "997b15")}},  
        "type": "multimodal",  
    },  
    "image-to-image": {  
        "impl": ImageToImagePipeline,  
        "tf": (),  
        "pt": (AutoModelForImageToImage,) if is_torch_available() else (),  
        "default": {"model": {"pt": ("caidas/swin2SR-classical-sr-x2-64", "4aaedcb")}},  
        "type": "image",  
    },  
}

使用示例

简单使用示例

python 复制代码
from transformers import pipeline  
from transformers.pipelines import get_supported_tasks  
import json  
  
nlp = pipeline("sentiment-analysis")  
# 单次调用  
result = nlp("I hate you")[0]  
print(f"label: {result['label']}, score: {round(result['score'], 4)}")  
# label: NEGATIVE, with score: 0.9991  
result = nlp("I love you")[0]  
print(f"label: {result['label']}, score: {round(result['score'], 4)}")  
# label: POSITIVE, with score  
  
# 多次调用  
result = nlp(["This restaurant is awesome", "This restaurant is awful"])  
print(json.dumps(result))  
  
  
print(json.dumps(get_supported_tasks()))

执行的输出日志如下:

  • 因为未指定model,默认根据任务分类名称从hg下载对应的模型,sentiment-analysis任务对应的默认模型是:models--distilbert--distilbert-base-uncased-finetuned-sst-2-english,默认是af0f99b
  • 下载的model保存的默认目录是:C:\Users\用户名.cache\huggingface\hub\
  • 不建议在生产环境中不指定model及版本
python 复制代码
python.exe Classification.py 
No model was supplied, defaulted to distilbert/distilbert-base-uncased-finetuned-sst-2-english and revision af0f99b (https://huggingface.co/distilbert/distilbert-base-uncased-finetuned-sst-2-english).
Using a pipeline without specifying a model name and revision in production is not recommended.
D:\soft\anaconda3\envs\llm-demo\lib\site-packages\huggingface_hub\file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
D:\soft\anaconda3\envs\llm-demo\lib\site-packages\huggingface_hub\file_download.py:157: UserWarning: `huggingface_hub` cache-system uses symlinks by default to efficiently store duplicated files but your machine does not support them in C:\Users\wang\.cache\huggingface\hub\models--distilbert--distilbert-base-uncased-finetuned-sst-2-english. Caching files will still work but in a degraded version that might require more space on your disk. This warning can be disabled by setting the `HF_HUB_DISABLE_SYMLINKS_WARNING` environment variable. For more details, see https://huggingface.co/docs/huggingface_hub/how-to-cache#limitations.
To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development
  warnings.warn(message)
label: NEGATIVE, score: 0.9991
label: POSITIVE, score: 0.9999

[{"label": "POSITIVE", "score": 0.9998743534088135}, {"label": "NEGATIVE", "score": 0.9996669292449951}]

["audio-classification", "automatic-speech-recognition", "conversational", "depth-estimation", "document-question-answering", "feature-extraction", "fill-mask", "image-classification", "image-feature-extraction", "image-segmentation", "image-to-image", "image-to-text", "mask-generation", "ner", "object-detection", "question-answering", "sentiment-analysis", "summarization", "table-question-answering", "text-classification", "text-generation", "text-to-audio", "text-to-speech", "text2text-generation", "token-classification", "translation", "video-classification", "visual-question-answering", "vqa", "zero-shot-audio-classification", "zero-shot-classification", "zero-shot-image-classification", "zero-shot-object-detection"]

Pipeline batching

python 复制代码
from transformers import pipeline  
from transformers.pipelines.pt_utils import KeyDataset  
import datasets  
  
dataset = datasets.load_dataset("imdb", name="plain_text", split="unsupervised")  
  
pipe = pipeline(task="sentiment-analysis")  
  
for out in pipe(KeyDataset(dataset, "text"), batch_size=8, truncation="only_first"):  
    print(out)

自定义数据集

python 复制代码
from transformers import pipeline
from torch.utils.data import Dataset
from tqdm.auto import tqdm

pipe = pipeline("text-classification", device=0)

class MyDataset(Dataset):
    def __len__(self):
        return 5000

    def __getitem__(self, i):
        return "This is a test"

dataset = MyDataset()

for batch_size in [1, 8, 64, 256]:
    print("-" * 30)
    print(f"Streaming batch_size={batch_size}")
    for out in tqdm(pipe(dataset, batch_size=batch_size), total=len(dataset)):
        pass

文本summary

python 复制代码
# use bart in pytorch
summarizer = pipeline("summarization")
summarizer("An apple a day, keeps the doctor away", min_length=5, max_length=20)

# use t5 in tf
summarizer = pipeline("summarization", model="google-t5/t5-base", tokenizer="google-t5/t5-base", framework="tf")
summarizer("An apple a day, keeps the doctor away", min_length=5, max_length=20)

学习资料

相关推荐
威化饼的一隅1 小时前
【多模态】swift-3框架使用
人工智能·深度学习·大模型·swift·多模态
SimonLiu0092 小时前
[AI]30分钟用cursor开发一个chrome插件
chrome·ai·ai编程
伯牙碎琴3 小时前
智能体实战(需求分析助手)二、需求分析助手第一版实现(支持需求提取、整理、痛点分析、需求分类、优先级分析、需求文档生成等功能)
ai·大模型·agent·需求分析·智能体
RWKV元始智能4 小时前
RWKV-7:极先进的大模型架构,长文本能力极强
人工智能·llm
聆思科技AI芯片5 小时前
实操给桌面机器人加上超拟人音色
人工智能·机器人·大模型·aigc·多模态·智能音箱·语音交互
卓琢19 小时前
2024 年 IA 技术大爆发深度解析
深度学习·ai·论文笔记
zaim11 天前
计算机的错误计算(一百八十七)
人工智能·ai·大模型·llm·错误·正弦/sin·误差/error
张拭心1 天前
Google 提供的 Android 端上大模型组件:MediaPipe LLM 介绍
android·人工智能·llm
带电的小王1 天前
whisper.cpp: Android端测试 -- Android端手机部署音频大模型
android·智能手机·llm·whisper·音频大模型·whisper.cpp
Engineer-Yao1 天前
【win10+RAGFlow+Ollama】搭建本地大模型助手(教程+源码)
docker·大模型·win10·wsl·ollama·本地大模型·ragflow