What is the Hugging Face transformers pipeline?
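In Hugging Face's `transformers` library, `pipeline` is a high-level API that offers a simple way to use pretrained models for a wide range of NLP tasks, such as sentiment analysis, text generation, translation, and question answering. With `pipeline`, a complex NLP task takes only a few lines of code. `pipeline` automatically loads a default model and tokenizer for the specified task, and you can also specify a particular model and tokenizer when needed.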
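When creating a `pipeline`, besides the task type and the model you can set other parameters, such as the deep learning framework (`"pt"` for PyTorch, `"tf"` for TensorFlow), the device (CPU or GPU), and the batch size. Under the hood, a pipeline initializes the tokenizer and the model, then passes each input through three stages: preprocessing (tokenization), the model's forward pass, and postprocessing (turning raw model outputs such as logits into labels and scores).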
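The three stages can be illustrated with a minimal, library-free sketch. The real base class `transformers.Pipeline` splits its work into `preprocess`, `_forward` and `postprocess` methods; the toy class below borrows only those method names. `VOCAB`, the whitespace "tokenizer", the keyword-counting "model" and `ToySentimentPipeline` are all hypothetical stand-ins, not the library's implementation.

```python
from typing import Dict, List

# Hypothetical toy vocabulary: word -> token id
VOCAB = {"love": 0, "great": 1, "hate": 2, "awful": 3}


class ToySentimentPipeline:
    """Toy pipeline mirroring the preprocess -> _forward -> postprocess split."""

    def __init__(self, id2label: Dict[int, str]):
        self.id2label = id2label

    def preprocess(self, text: str) -> List[int]:
        # "Tokenize": lowercase, split on whitespace, keep known words as ids
        return [VOCAB[w] for w in text.lower().split() if w in VOCAB]

    def _forward(self, input_ids: List[int]) -> List[float]:
        # "Model": one raw score (logit) per label, index 0 = negative, 1 = positive
        positive = sum(1.0 for i in input_ids if i in (0, 1))
        negative = sum(1.0 for i in input_ids if i in (2, 3))
        return [negative, positive]

    def postprocess(self, logits: List[float]) -> Dict[str, object]:
        # Pick the label with the highest raw score
        best = max(range(len(logits)), key=lambda i: logits[i])
        return {"label": self.id2label[best], "raw_score": logits[best]}

    def __call__(self, inputs):
        # Accept a single string or a list of strings; always return a list of dicts
        texts = [inputs] if isinstance(inputs, str) else list(inputs)
        return [self.postprocess(self._forward(self.preprocess(t))) for t in texts]
```

Calling `ToySentimentPipeline({0: "NEGATIVE", 1: "POSITIVE"})("I love you")` returns a one-element list of dicts, which matches the shape the real sentiment pipeline returns for a single string.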
The main features of pipelines are summarized below:
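- Task-specific: pipelines provide dedicated interfaces for many tasks, such as text classification, named entity recognition, question answering, text generation, translation, summarization, and sentiment analysis.
- Automatic model loading: you don't need to know the details of the underlying model; the pipeline automatically loads a suitable pretrained model and tokenizer for the task.
- Ease of use: the API is concise, and a few lines of code are enough to load a model and run a task.
- Automatic tokenization: the pipeline tokenizes the input text internally, converting it into the format the model expects.
- Batching: pipelines accept multiple inputs at once (a list or a dataset) and can group them into batches via the `batch_size` argument.
- Dynamic adjustment: model inputs are adjusted as needed, for example with padding and truncation.
- Custom models and tokenizers: you can supply your own model and tokenizer to fit specific needs.
- Fine-tuning: you can first fine-tune a model on your own dataset and then load the fine-tuned model into a pipeline.
- Multilingual support: many pipelines work with multilingual models, which makes cross-lingual NLP tasks possible.
- Extensibility: pipelines can serve as building blocks for more complex NLP workflows.
- Performance: pipelines can run on a GPU (via `device`) and process inputs in batches (via `batch_size`) to speed up common use cases.
- Error handling: problems that occur while loading a model or processing text are reported through exceptions and warnings.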
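Padding and truncation are what make a batch of different-length texts fit into one tensor. The sketch below shows the idea on plain token-id lists: truncate each sequence to `max_length`, pad to the longest remaining sequence, and build an attention mask. `pad_and_truncate` and `PAD_ID = 0` are hypothetical; a real tokenizer does this itself and defines its own `pad_token_id`.

```python
from typing import List, Tuple

PAD_ID = 0  # hypothetical padding id; real tokenizers expose tokenizer.pad_token_id


def pad_and_truncate(batch: List[List[int]], max_length: int) -> Tuple[List[List[int]], List[List[int]]]:
    """Truncate each sequence to max_length, then pad all of them to the longest one.

    Returns (input_ids, attention_mask); the mask is 1 for real tokens and 0 for padding.
    """
    truncated = [seq[:max_length] for seq in batch]
    target = max((len(seq) for seq in truncated), default=0)
    input_ids = [seq + [PAD_ID] * (target - len(seq)) for seq in truncated]
    attention_mask = [[1] * len(seq) + [0] * (target - len(seq)) for seq in truncated]
    return input_ids, attention_mask
```

Padding to the longest sequence in the batch (rather than to a fixed length) keeps the amount of wasted computation as small as the batch allows.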
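With pipelines, researchers and developers can quickly prototype and deploy NLP applications without understanding the model internals in depth. In short, pipelines are a powerful and flexible tool in the Hugging Face Transformers library that simplifies NLP workflows.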
Supported tasks
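Pipelines can be used for audio, computer vision, natural language processing, and multimodal tasks. The task registry is defined in `transformers/pipelines/__init__.py`; the excerpt below comes from one 4.x release, and the set of tasks and their default model revisions change between versions.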
```python
# Excerpt from transformers/pipelines/__init__.py; the *Pipeline and *AutoModel* names
# and is_torch_available / is_tf_available are imported elsewhere in that module.
TASK_ALIASES = {
    "sentiment-analysis": "text-classification",
    "ner": "token-classification",
    "vqa": "visual-question-answering",
    "text-to-speech": "text-to-audio",
}
SUPPORTED_TASKS = {
    "audio-classification": {
        "impl": AudioClassificationPipeline,
        "tf": (),
        "pt": (AutoModelForAudioClassification,) if is_torch_available() else (),
        "default": {"model": {"pt": ("superb/wav2vec2-base-superb-ks", "372e048")}},
        "type": "audio",
    },
    "automatic-speech-recognition": {
        "impl": AutomaticSpeechRecognitionPipeline,
        "tf": (),
        "pt": (AutoModelForCTC, AutoModelForSpeechSeq2Seq) if is_torch_available() else (),
        "default": {"model": {"pt": ("facebook/wav2vec2-base-960h", "55bb623")}},
        "type": "multimodal",
    },
    "text-to-audio": {
        "impl": TextToAudioPipeline,
        "tf": (),
        "pt": (AutoModelForTextToWaveform, AutoModelForTextToSpectrogram) if is_torch_available() else (),
        "default": {"model": {"pt": ("suno/bark-small", "645cfba")}},
        "type": "text",
    },
    "feature-extraction": {
        "impl": FeatureExtractionPipeline,
        "tf": (TFAutoModel,) if is_tf_available() else (),
        "pt": (AutoModel,) if is_torch_available() else (),
        "default": {
            "model": {
                "pt": ("distilbert/distilbert-base-cased", "935ac13"),
                "tf": ("distilbert/distilbert-base-cased", "935ac13"),
            }
        },
        "type": "multimodal",
    },
    "text-classification": {
        "impl": TextClassificationPipeline,
        "tf": (TFAutoModelForSequenceClassification,) if is_tf_available() else (),
        "pt": (AutoModelForSequenceClassification,) if is_torch_available() else (),
        "default": {
            "model": {
                "pt": ("distilbert/distilbert-base-uncased-finetuned-sst-2-english", "af0f99b"),
                "tf": ("distilbert/distilbert-base-uncased-finetuned-sst-2-english", "af0f99b"),
            },
        },
        "type": "text",
    },
    "token-classification": {
        "impl": TokenClassificationPipeline,
        "tf": (TFAutoModelForTokenClassification,) if is_tf_available() else (),
        "pt": (AutoModelForTokenClassification,) if is_torch_available() else (),
        "default": {
            "model": {
                "pt": ("dbmdz/bert-large-cased-finetuned-conll03-english", "f2482bf"),
                "tf": ("dbmdz/bert-large-cased-finetuned-conll03-english", "f2482bf"),
            },
        },
        "type": "text",
    },
    "question-answering": {
        "impl": QuestionAnsweringPipeline,
        "tf": (TFAutoModelForQuestionAnswering,) if is_tf_available() else (),
        "pt": (AutoModelForQuestionAnswering,) if is_torch_available() else (),
        "default": {
            "model": {
                "pt": ("distilbert/distilbert-base-cased-distilled-squad", "626af31"),
                "tf": ("distilbert/distilbert-base-cased-distilled-squad", "626af31"),
            },
        },
        "type": "text",
    },
    "table-question-answering": {
        "impl": TableQuestionAnsweringPipeline,
        "pt": (AutoModelForTableQuestionAnswering,) if is_torch_available() else (),
        "tf": (TFAutoModelForTableQuestionAnswering,) if is_tf_available() else (),
        "default": {
            "model": {
                "pt": ("google/tapas-base-finetuned-wtq", "69ceee2"),
                "tf": ("google/tapas-base-finetuned-wtq", "69ceee2"),
            },
        },
        "type": "text",
    },
    "visual-question-answering": {
        "impl": VisualQuestionAnsweringPipeline,
        "pt": (AutoModelForVisualQuestionAnswering,) if is_torch_available() else (),
        "tf": (),
        "default": {
            "model": {"pt": ("dandelin/vilt-b32-finetuned-vqa", "4355f59")},
        },
        "type": "multimodal",
    },
    "document-question-answering": {
        "impl": DocumentQuestionAnsweringPipeline,
        "pt": (AutoModelForDocumentQuestionAnswering,) if is_torch_available() else (),
        "tf": (),
        "default": {
            "model": {"pt": ("impira/layoutlm-document-qa", "52e01b3")},
        },
        "type": "multimodal",
    },
    "fill-mask": {
        "impl": FillMaskPipeline,
        "tf": (TFAutoModelForMaskedLM,) if is_tf_available() else (),
        "pt": (AutoModelForMaskedLM,) if is_torch_available() else (),
        "default": {
            "model": {
                "pt": ("distilbert/distilroberta-base", "ec58a5b"),
                "tf": ("distilbert/distilroberta-base", "ec58a5b"),
            }
        },
        "type": "text",
    },
    "summarization": {
        "impl": SummarizationPipeline,
        "tf": (TFAutoModelForSeq2SeqLM,) if is_tf_available() else (),
        "pt": (AutoModelForSeq2SeqLM,) if is_torch_available() else (),
        "default": {
            "model": {"pt": ("sshleifer/distilbart-cnn-12-6", "a4f8f3e"), "tf": ("google-t5/t5-small", "d769bba")}
        },
        "type": "text",
    },
    # This task is a special case as it's parametrized by SRC, TGT languages.
    "translation": {
        "impl": TranslationPipeline,
        "tf": (TFAutoModelForSeq2SeqLM,) if is_tf_available() else (),
        "pt": (AutoModelForSeq2SeqLM,) if is_torch_available() else (),
        "default": {
            ("en", "fr"): {"model": {"pt": ("google-t5/t5-base", "686f1db"), "tf": ("google-t5/t5-base", "686f1db")}},
            ("en", "de"): {"model": {"pt": ("google-t5/t5-base", "686f1db"), "tf": ("google-t5/t5-base", "686f1db")}},
            ("en", "ro"): {"model": {"pt": ("google-t5/t5-base", "686f1db"), "tf": ("google-t5/t5-base", "686f1db")}},
        },
        "type": "text",
    },
    "text2text-generation": {
        "impl": Text2TextGenerationPipeline,
        "tf": (TFAutoModelForSeq2SeqLM,) if is_tf_available() else (),
        "pt": (AutoModelForSeq2SeqLM,) if is_torch_available() else (),
        "default": {"model": {"pt": ("google-t5/t5-base", "686f1db"), "tf": ("google-t5/t5-base", "686f1db")}},
        "type": "text",
    },
    "text-generation": {
        "impl": TextGenerationPipeline,
        "tf": (TFAutoModelForCausalLM,) if is_tf_available() else (),
        "pt": (AutoModelForCausalLM,) if is_torch_available() else (),
        "default": {"model": {"pt": ("openai-community/gpt2", "6c0e608"), "tf": ("openai-community/gpt2", "6c0e608")}},
        "type": "text",
    },
    "zero-shot-classification": {
        "impl": ZeroShotClassificationPipeline,
        "tf": (TFAutoModelForSequenceClassification,) if is_tf_available() else (),
        "pt": (AutoModelForSequenceClassification,) if is_torch_available() else (),
        "default": {
            "model": {
                "pt": ("facebook/bart-large-mnli", "c626438"),
                "tf": ("FacebookAI/roberta-large-mnli", "130fb28"),
            },
            "config": {
                "pt": ("facebook/bart-large-mnli", "c626438"),
                "tf": ("FacebookAI/roberta-large-mnli", "130fb28"),
            },
        },
        "type": "text",
    },
    "zero-shot-image-classification": {
        "impl": ZeroShotImageClassificationPipeline,
        "tf": (TFAutoModelForZeroShotImageClassification,) if is_tf_available() else (),
        "pt": (AutoModelForZeroShotImageClassification,) if is_torch_available() else (),
        "default": {
            "model": {
                "pt": ("openai/clip-vit-base-patch32", "f4881ba"),
                "tf": ("openai/clip-vit-base-patch32", "f4881ba"),
            }
        },
        "type": "multimodal",
    },
    "zero-shot-audio-classification": {
        "impl": ZeroShotAudioClassificationPipeline,
        "tf": (),
        "pt": (AutoModel,) if is_torch_available() else (),
        "default": {
            "model": {
                "pt": ("laion/clap-htsat-fused", "973b6e5"),
            }
        },
        "type": "multimodal",
    },
    "conversational": {
        "impl": ConversationalPipeline,
        "tf": (TFAutoModelForSeq2SeqLM, TFAutoModelForCausalLM) if is_tf_available() else (),
        "pt": (AutoModelForSeq2SeqLM, AutoModelForCausalLM) if is_torch_available() else (),
        "default": {
            "model": {"pt": ("microsoft/DialoGPT-medium", "8bada3b"), "tf": ("microsoft/DialoGPT-medium", "8bada3b")}
        },
        "type": "text",
    },
    "image-classification": {
        "impl": ImageClassificationPipeline,
        "tf": (TFAutoModelForImageClassification,) if is_tf_available() else (),
        "pt": (AutoModelForImageClassification,) if is_torch_available() else (),
        "default": {
            "model": {
                "pt": ("google/vit-base-patch16-224", "5dca96d"),
                "tf": ("google/vit-base-patch16-224", "5dca96d"),
            }
        },
        "type": "image",
    },
    "image-feature-extraction": {
        "impl": ImageFeatureExtractionPipeline,
        "tf": (TFAutoModel,) if is_tf_available() else (),
        "pt": (AutoModel,) if is_torch_available() else (),
        "default": {
            "model": {
                "pt": ("google/vit-base-patch16-224", "3f49326"),
                "tf": ("google/vit-base-patch16-224", "3f49326"),
            }
        },
        "type": "image",
    },
    "image-segmentation": {
        "impl": ImageSegmentationPipeline,
        "tf": (),
        "pt": (AutoModelForImageSegmentation, AutoModelForSemanticSegmentation) if is_torch_available() else (),
        "default": {"model": {"pt": ("facebook/detr-resnet-50-panoptic", "fc15262")}},
        "type": "multimodal",
    },
    "image-to-text": {
        "impl": ImageToTextPipeline,
        "tf": (TFAutoModelForVision2Seq,) if is_tf_available() else (),
        "pt": (AutoModelForVision2Seq,) if is_torch_available() else (),
        "default": {
            "model": {
                "pt": ("ydshieh/vit-gpt2-coco-en", "65636df"),
                "tf": ("ydshieh/vit-gpt2-coco-en", "65636df"),
            }
        },
        "type": "multimodal",
    },
    "object-detection": {
        "impl": ObjectDetectionPipeline,
        "tf": (),
        "pt": (AutoModelForObjectDetection,) if is_torch_available() else (),
        "default": {"model": {"pt": ("facebook/detr-resnet-50", "2729413")}},
        "type": "multimodal",
    },
    "zero-shot-object-detection": {
        "impl": ZeroShotObjectDetectionPipeline,
        "tf": (),
        "pt": (AutoModelForZeroShotObjectDetection,) if is_torch_available() else (),
        "default": {"model": {"pt": ("google/owlvit-base-patch32", "17740e1")}},
        "type": "multimodal",
    },
    "depth-estimation": {
        "impl": DepthEstimationPipeline,
        "tf": (),
        "pt": (AutoModelForDepthEstimation,) if is_torch_available() else (),
        "default": {"model": {"pt": ("Intel/dpt-large", "e93beec")}},
        "type": "image",
    },
    "video-classification": {
        "impl": VideoClassificationPipeline,
        "tf": (),
        "pt": (AutoModelForVideoClassification,) if is_torch_available() else (),
        "default": {"model": {"pt": ("MCG-NJU/videomae-base-finetuned-kinetics", "4800870")}},
        "type": "video",
    },
    "mask-generation": {
        "impl": MaskGenerationPipeline,
        "tf": (),
        "pt": (AutoModelForMaskGeneration,) if is_torch_available() else (),
        "default": {"model": {"pt": ("facebook/sam-vit-huge", "997b15")}},
        "type": "multimodal",
    },
    "image-to-image": {
        "impl": ImageToImagePipeline,
        "tf": (),
        "pt": (AutoModelForImageToImage,) if is_torch_available() else (),
        "default": {"model": {"pt": ("caidas/swin2SR-classical-sr-x2-64", "4aaedcb")}},
        "type": "image",
    },
}
```
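How the two tables work together can be sketched in a few lines: an alias is first mapped to its canonical task name, and the canonical entry's `"default"` gives the model id and revision for the chosen framework. In the library this lookup is done by `check_task` in `transformers.pipelines`, which also handles the parametrized `translation_XX_to_YY` form; the trimmed two-task registry and the `resolve_default` helper below are hypothetical simplifications.

```python
from typing import Tuple

# Trimmed, hypothetical copy of the registry above (two tasks only)
TASK_ALIASES = {"sentiment-analysis": "text-classification", "ner": "token-classification"}
SUPPORTED_TASKS = {
    "text-classification": {
        "default": {"model": {"pt": ("distilbert/distilbert-base-uncased-finetuned-sst-2-english", "af0f99b")}}
    },
    "token-classification": {
        "default": {"model": {"pt": ("dbmdz/bert-large-cased-finetuned-conll03-english", "f2482bf")}}
    },
}


def resolve_default(task: str, framework: str = "pt") -> Tuple[str, Tuple[str, str]]:
    """Map an alias to its canonical task and return (task, (model_id, revision))."""
    canonical = TASK_ALIASES.get(task, task)  # names that are not aliases pass through unchanged
    if canonical not in SUPPORTED_TASKS:
        raise KeyError(f"Unknown task {task!r}")
    model_id, revision = SUPPORTED_TASKS[canonical]["default"]["model"][framework]
    return canonical, (model_id, revision)
```

This is why `pipeline("sentiment-analysis")` ends up loading the `text-classification` default shown in the log of the next example.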
Usage examples
Basic usage
```python
from transformers import pipeline
from transformers.pipelines import get_supported_tasks
import json

# No model is given, so the default model for the task is downloaded
nlp = pipeline("sentiment-analysis")

# Single input: the result is a list containing one dict
result = nlp("I hate you")[0]
print(f"label: {result['label']}, score: {round(result['score'], 4)}")
# label: NEGATIVE, score: 0.9991
result = nlp("I love you")[0]
print(f"label: {result['label']}, score: {round(result['score'], 4)}")
# label: POSITIVE, score: 0.9999

# Multiple inputs: pass a list and get one dict per input
result = nlp(["This restaurant is awesome", "This restaurant is awful"])
print(json.dumps(result))

# List every task name (including aliases) that pipeline() accepts
print(json.dumps(get_supported_tasks()))
```
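The output of the run is shown below. A few points to note:
- Because no model was specified, the pipeline downloads the default model for the task from the Hugging Face Hub. For `sentiment-analysis` the default is `distilbert/distilbert-base-uncased-finetuned-sst-2-english` at revision `af0f99b`, cached in a folder named `models--distilbert--distilbert-base-uncased-finetuned-sst-2-english`.
- Downloaded models are stored by default under `C:\Users\<username>\.cache\huggingface\hub\` on Windows (`~/.cache/huggingface/hub/` on Linux and macOS); the location can be changed with the `HF_HOME` environment variable.
- Relying on the default model without pinning the model name and revision is not recommended in production.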
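The cache folder name in the first note follows a fixed pattern: the repo id's `/` is replaced by `--` and the result is prefixed with `models--`. The sketch below computes the expected folder under the assumption that `huggingface_hub`'s default layout applies (`HF_HUB_CACHE`, else `HF_HOME/hub`, else `XDG_CACHE_HOME/huggingface/hub`, else `~/.cache/huggingface/hub`); `hub_cache_folder` is a hypothetical helper, and it ignores older variables such as `TRANSFORMERS_CACHE`.

```python
from pathlib import Path
from typing import Mapping, Optional


def hub_cache_folder(repo_id: str, env: Mapping[str, str], home: Optional[Path] = None) -> Path:
    """Expected cache folder of a model repo, assuming huggingface_hub's default layout."""
    home = home if home is not None else Path.home()
    if "HF_HUB_CACHE" in env:
        hub_cache = Path(env["HF_HUB_CACHE"])
    else:
        # Default HF_HOME is $XDG_CACHE_HOME/huggingface, falling back to ~/.cache/huggingface
        cache_root = Path(env["XDG_CACHE_HOME"]) if "XDG_CACHE_HOME" in env else home / ".cache"
        hf_home = Path(env["HF_HOME"]) if "HF_HOME" in env else cache_root / "huggingface"
        hub_cache = hf_home / "hub"
    # "org/name" becomes "models--org--name"
    return hub_cache / ("models--" + repo_id.replace("/", "--"))
```

For example, `hub_cache_folder("distilbert/distilbert-base-uncased-finetuned-sst-2-english", dict(os.environ))` should point at the folder mentioned in the Windows warning of the log below.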
```text
python.exe Classification.py
No model was supplied, defaulted to distilbert/distilbert-base-uncased-finetuned-sst-2-english and revision af0f99b (https://huggingface.co/distilbert/distilbert-base-uncased-finetuned-sst-2-english).
Using a pipeline without specifying a model name and revision in production is not recommended.
D:\soft\anaconda3\envs\llm-demo\lib\site-packages\huggingface_hub\file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
warnings.warn(
D:\soft\anaconda3\envs\llm-demo\lib\site-packages\huggingface_hub\file_download.py:157: UserWarning: `huggingface_hub` cache-system uses symlinks by default to efficiently store duplicated files but your machine does not support them in C:\Users\wang\.cache\huggingface\hub\models--distilbert--distilbert-base-uncased-finetuned-sst-2-english. Caching files will still work but in a degraded version that might require more space on your disk. This warning can be disabled by setting the `HF_HUB_DISABLE_SYMLINKS_WARNING` environment variable. For more details, see https://huggingface.co/docs/huggingface_hub/how-to-cache#limitations.
To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development
warnings.warn(message)
label: NEGATIVE, score: 0.9991
label: POSITIVE, score: 0.9999
[{"label": "POSITIVE", "score": 0.9998743534088135}, {"label": "NEGATIVE", "score": 0.9996669292449951}]
["audio-classification", "automatic-speech-recognition", "conversational", "depth-estimation", "document-question-answering", "feature-extraction", "fill-mask", "image-classification", "image-feature-extraction", "image-segmentation", "image-to-image", "image-to-text", "mask-generation", "ner", "object-detection", "question-answering", "sentiment-analysis", "summarization", "table-question-answering", "text-classification", "text-generation", "text-to-audio", "text-to-speech", "text2text-generation", "token-classification", "translation", "video-classification", "visual-question-answering", "vqa", "zero-shot-audio-classification", "zero-shot-classification", "zero-shot-image-classification", "zero-shot-object-detection"]
```
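The `score` values in the log are probabilities, not raw model outputs. For a single-label classifier the text-classification pipeline applies softmax to the logits by default (sigmoid is used for single-output or multi-label models; the `function_to_apply` argument overrides this) and reports the top label. The sketch below reproduces that postprocessing step with plain Python; `classify` and the example `id2label` mapping are hypothetical.

```python
import math
from typing import Dict, List


def classify(logits: List[float], id2label: Dict[int, str]) -> Dict[str, object]:
    """Turn raw logits into a {"label", "score"} dict via softmax + argmax."""
    m = max(logits)  # subtract the max before exp() for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    best = max(range(len(probs)), key=probs.__getitem__)
    return {"label": id2label[best], "score": probs[best]}
```

With two labels, logits of `[-2.0, 2.0]` give the second label a score of 1 / (1 + e^-4), about 0.982; the very confident scores in the log (0.9991, 0.9999) correspond to much larger logit gaps.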
Pipeline batching
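```python
from transformers import pipeline
from transformers.pipelines.pt_utils import KeyDataset
import datasets

# Unlabeled split of the IMDB movie reviews dataset
dataset = datasets.load_dataset("imdb", name="plain_text", split="unsupervised")
pipe = pipeline(task="sentiment-analysis")

# KeyDataset yields only the "text" column; the pipeline groups 8 inputs per forward pass
# and truncates reviews that exceed the model's maximum input length
for out in pipe(KeyDataset(dataset, "text"), batch_size=8, truncation="only_first"):
    print(out)
```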
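Two things happen in that loop: `KeyDataset` exposes a single column of each row (it returns `dataset[i][key]`), and the pipeline feeds the model `batch_size` inputs at a time while still yielding one result per input. The pure-Python sketch below mimics both behaviours; `KeyView`, `run_batched` and the fake model are hypothetical and stand in for `KeyDataset` and the pipeline's internal batching, which in the real library goes through a PyTorch `DataLoader`.

```python
from typing import Any, Callable, Dict, Iterator, List, Sequence


class KeyView:
    """Hypothetical stand-in for KeyDataset: exposes one column of a list of dicts."""

    def __init__(self, rows: Sequence[Dict[str, Any]], key: str):
        self.rows, self.key = rows, key

    def __len__(self) -> int:
        return len(self.rows)

    def __getitem__(self, i: int) -> Any:
        return self.rows[i][self.key]


def run_batched(data, batch_size: int, fn: Callable[[List[Any]], List[Any]]) -> Iterator[Any]:
    """Call fn on chunks of batch_size items, but yield the results one item at a time."""
    batch: List[Any] = []
    for i in range(len(data)):
        batch.append(data[i])
        if len(batch) == batch_size:
            yield from fn(batch)
            batch = []
    if batch:  # last, possibly smaller, batch
        yield from fn(batch)
```

Because results are yielded item by item, the consuming loop looks the same for every `batch_size`; only the number of model calls changes.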
Custom dataset
```python
from transformers import pipeline
from torch.utils.data import Dataset
from tqdm.auto import tqdm

# device=0 runs on the first GPU; use device=-1 (or omit it) to run on the CPU
pipe = pipeline("text-classification", device=0)


class MyDataset(Dataset):
    """A toy dataset that returns the same sentence 5000 times."""

    def __len__(self):
        return 5000

    def __getitem__(self, i):
        return "This is a test"


dataset = MyDataset()

# Compare throughput for different batch sizes
for batch_size in [1, 8, 64, 256]:
    print("-" * 30)
    print(f"Streaming batch_size={batch_size}")
    for out in tqdm(pipe(dataset, batch_size=batch_size), total=len(dataset)):
        pass
```
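In that benchmark every input is the identical sentence, so batching costs nothing extra. With real, variable-length text, larger batches are not automatically faster, because each sequence in a batch is padded up to the longest one in that batch. The sketch below counts the wasted pad tokens under that assumption (dynamic padding per batch); `padding_overhead` is a hypothetical helper working on token counts only.

```python
from typing import List


def padding_overhead(lengths: List[int], batch_size: int) -> int:
    """Number of pad tokens if each batch is padded to its own longest sequence."""
    pads = 0
    for start in range(0, len(lengths), batch_size):
        chunk = lengths[start:start + batch_size]
        longest = max(chunk)
        pads += sum(longest - n for n in chunk)
    return pads
```

For lengths `[5, 5, 50, 5]`, batch size 1 wastes nothing, batch size 2 wastes 45 pad tokens and batch size 4 wastes 135, more than double the 65 real tokens. Measuring with your own data and hardware is the reliable way to pick `batch_size`.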
Text summarization
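```python
from transformers import pipeline

# Default summarization model in PyTorch (sshleifer/distilbart-cnn-12-6, a BART model);
# min_length / max_length bound the length of the generated summary in tokens
summarizer = pipeline("summarization")
summarizer("An apple a day, keeps the doctor away", min_length=5, max_length=20)

# T5 in TensorFlow (requires TensorFlow to be installed)
summarizer = pipeline("summarization", model="google-t5/t5-base", tokenizer="google-t5/t5-base", framework="tf")
summarizer("An apple a day, keeps the doctor away", min_length=5, max_length=20)
```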
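The second call fails if TensorFlow is not installed. One way to guard against that is to check which backends are importable before choosing `framework`. The sketch below uses only `importlib.util.find_spec`; `available_framework` is a hypothetical helper, not part of transformers, and it only checks that the packages can be found, not that they work.

```python
import importlib.util
from typing import Optional


def available_framework(preferred: str = "pt") -> Optional[str]:
    """Return "pt" or "tf" depending on which backend is importable, or None if neither is."""
    installed = {
        "pt": importlib.util.find_spec("torch") is not None,
        "tf": importlib.util.find_spec("tensorflow") is not None,
    }
    if installed.get(preferred):
        return preferred
    # Fall back to whichever backend is installed
    for name, ok in installed.items():
        if ok:
            return name
    return None
```

Pass the result as `framework=` only when it is not `None`; otherwise install PyTorch or TensorFlow first.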