functools
模块中的 singledispatch
装饰器允许你定义一个函数,并根据传入参数的类型自动选择相应的实现。这在处理不同类型的输入时非常有用。
singledispatch
装饰器提供了一种在 Python 中实现函数重载的方式。虽然 Python 本身不支持传统的函数重载(即在同一个作用域中定义多个同名函数),但 singledispatch
允许你根据参数类型来选择不同的函数实现,从而实现类似重载的效果。
下面是一个简单的示例,展示了如何使用 singledispatch
装饰器:
py
from functools import singledispatch
@singledispatch
def process(arg):
raise NotImplementedError("Unsupported type")
@process.register
def _(arg: int):
return f"Processing an integer: {arg}"
@process.register
def _(arg: str):
return f"Processing a string: {arg}"
@process.register
def _(arg: list):
return f"Processing a list: {', '.join(map(str, arg))}"
# 测试
print(process(10)) # 输出: Processing an integer: 10
print(process("hello")) # 输出: Processing a string: hello
print(process([1, 2, 3])) # 输出: Processing a list: 1, 2, 3
在上述示例中,process 函数根据传入参数的类型(int、str、list)选择不同的处理逻辑,这就是一种函数重载的形式。
再看一个实例:
py
from functools import singledispatch
from outlines.fsm.guide import StopAtEOSGuide
from outlines.generate.api import SequenceGenerator, SequenceGeneratorAdapter
from outlines.models import MLXLM, VLLM, LlamaCpp, OpenAI
from outlines.samplers import Sampler, multinomial
@singledispatch
def text(model, sampler: Sampler = multinomial()) -> SequenceGenerator:
"""Generate text with a `Transformer` model.
Note
----
Python 3.11 allows dispatching on Union types and
this should greatly simplify the code.
Arguments
---------
model:
An instance of `Transformer` that represents a model from the
`transformers` library.
sampler:
The sampling algorithm to use to generate token ids from the logits
distribution.
Returns
-------
A `SequenceGenerator` instance that generates text.
"""
fsm = StopAtEOSGuide(model.tokenizer)
device = model.device
generator = SequenceGenerator(fsm, model, sampler, device)
return generator
@text.register(MLXLM)
def text_mlxlm(model: MLXLM, sampler: Sampler = multinomial()):
return SequenceGeneratorAdapter(model, None, sampler)
@text.register(VLLM)
def text_vllm(model: VLLM, sampler: Sampler = multinomial()):
return SequenceGeneratorAdapter(model, None, sampler)
@text.register(LlamaCpp)
def text_llamacpp(model: LlamaCpp, sampler: Sampler = multinomial()):
return SequenceGeneratorAdapter(model, None, sampler)
@text.register(OpenAI)
def text_openai(model: OpenAI, sampler: Sampler = multinomial()) -> OpenAI:
if not isinstance(sampler, multinomial):
raise NotImplementedError(
r"The OpenAI API does not support any other sampling algorithm "
+ "than the multinomial sampler."
)
return model