Replicate Python client

本文翻译整理自:https://github.com/replicate/replicate-python

文章目录


一、关于 Replicate Python 客户端

这是一个用于 Replicate 的 Python 客户端库,允许您从 Python 代码或 Jupyter Notebook 中运行模型,并在 Replicate 平台上执行各种操作。


相关链接资源


关键功能特性

  • 运行模型预测
  • 流式输出处理
  • 后台模型执行
  • 模型管道组合
  • 训练自定义模型
  • 预测管理(取消/列表)
  • 异步IO支持
  • Webhook集成

二、1.0.0 版本的重大变更

1.0.0 版本包含以下破坏性变更:

  • 对于输出文件的模型,replicate.run() 方法现在默认返回 FileOutput 对象而非 URL 字符串。FileOutput 实现了类似 httpx.Response 的可迭代接口,使文件处理更高效。

如需恢复旧行为,可通过传递 use_file_output=False 参数禁用 FileOutput

python 复制代码
output = replicate.run("acmecorp/acme-model", use_file_output=False)

在大多数情况下,更新现有应用程序以调用 output.url 即可解决问题。

但我们建议直接使用 FileOutput 对象,因为我们计划对该 API 进行进一步改进,这种方法能确保获得最快的处理结果。

!TIP

👋 查看本教程的交互式版本:Google Colab

https://colab.research.google.com/drive/1K91q4p-OhL96FHBAVLsv9FlwFdu6Pn3c


三、安装与配置

1、系统要求

  • Python 3.8+

2、安装

sh 复制代码
pip install replicate

3、认证配置

在使用 API 运行任何 Python 脚本前,需设置环境变量中的 Replicate API 令牌。

replicate.com/account 获取令牌并设置为环境变量:

shell 复制代码
export REPLICATE_API_TOKEN=<your token>

我们建议不要直接将令牌添加到源代码中,因为您不希望将凭证提交到版本控制系统。如果任何人使用您的 API 密钥,其使用量将计入您的账户。


四、核心功能

1、运行模型

创建新的 Python 文件并添加以下代码,替换为您自己的模型标识符和输入:

python 复制代码
>>> import replicate
>>> outputs = replicate.run(
        "black-forest-labs/flux-schnell",         input={"prompt": "astronaut riding a rocket like a horse"}
    )
[<replicate.helpers.FileOutput object at 0x107179b50>]
>>> for index, output in enumerate(outputs):
        with open(f"output_{index}.webp", "wb") as file:
            file.write(output.read())

如果预测失败,replicate.run 会抛出 ModelError 异常。您可以通过异常的 prediction 属性获取更多失败信息。

python 复制代码
import replicate
from replicate.exceptions import ModelError

try:
  output = replicate.run("stability-ai/stable-diffusion-3", { "prompt": "An astronaut riding a rainbow unicorn" })
except ModelError as e
  if "(some known issue)" in e.prediction.logs:
    pass

  print("Failed prediction: " + e.prediction.id)

!NOTE

默认情况下,Replicate 客户端会保持连接打开最多 60 秒,等待预测完成。这种设计是为了优化模型输出返回客户端的速度。

可通过传递 wait=xreplicate.run() 来配置超时,其中 x 是 1 到 60 秒之间的超时值。要禁用同步模式,可传递 wait=False


2、异步IO支持

通过在方法名前添加 async_ 前缀,您也可以异步使用 Replicate 客户端。

以下是并发运行多个预测并等待它们全部完成的示例:

python 复制代码
import asyncio
import replicate
 
# https://replicate.com/stability-ai/sdxl
model_version = "stability-ai/sdxl:39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b"
prompts = [
    f"A chariot pulled by a team of {count} rainbow unicorns"
    for count in ["two", "four", "six", "eight"]
]

async with asyncio.TaskGroup() as tg:
    tasks = [
        tg.create_task(replicate.async_run(model_version, input={"prompt": prompt}))
        for prompt in prompts
    ]

results = await asyncio.gather(*tasks)
print(results)

对于需要文件输入的模型,您可以传递互联网上可公开访问文件的 URL,或本地设备上的文件句柄:

python 复制代码
>>> output = replicate.run(
        "andreasjansson/blip-2:f677695e5e89f8b236e52ecd1d3f01beb44c34606419bcc19345e046d8f786f9",         input={ "image": open("path/to/mystery.jpg") }
    )

"an astronaut riding a horse"

3、流式输出模型

Replicate 的 API 支持语言模型的服务器发送事件流(SSEs)。使用 stream 方法可以实时消费模型生成的标记。

python 复制代码
import replicate

for event in replicate.stream(
    "meta/meta-llama-3-70b-instruct",     input={
        "prompt": "Please write a haiku about llamas.",     }, ):
    print(str(event), end="")

!TIP

某些模型如 meta/meta-llama-3-70b-instruct 不需要版本字符串。您始终可以参考模型页面上的 API 文档了解具体细节。


您也可以流式传输已创建预测的输出。这在您希望将预测 ID 与其输出分开时很有用。

python 复制代码
prediction = replicate.predictions.create(
    model="meta/meta-llama-3-70b-instruct",     input={"prompt": "Please write a haiku about llamas."},     stream=True, )

for event in prediction.stream():
    print(str(event), end="")

更多信息请参阅 Replicate 文档中的"流式输出"


4、后台运行模型

您可以使用异步模式在后台启动并运行模型:

python 复制代码
>>> model = replicate.models.get("kvfrans/clipdraw")
>>> version = model.versions.get("5797a99edc939ea0e9242d5e8c9cb3bc7d125b1eac21bda852e5cb79ede2cd9b")
>>> prediction = replicate.predictions.create(
    version=version,     input={"prompt":"Watercolor painting of an underwater submarine"})

>>> prediction
Prediction(...)

>>> prediction.status
'starting'

>>> dict(prediction)
{"id": "...", "status": "starting", ...}

>>> prediction.reload()
>>> prediction.status
'processing'

>>> print(prediction.logs)
iteration: 0, render:loss: -0.6171875
iteration: 10, render:loss: -0.92236328125
iteration: 20, render:loss: -1.197265625
iteration: 30, render:loss: -1.3994140625

>>> prediction.wait()

>>> prediction.status
'succeeded'

>>> prediction.output
<replicate.helpers.FileOutput object at 0x107179b50>

>>> with open("output.png", "wb") as file:
        file.write(prediction.output.read())

5、后台运行模型并获取Webhook

您可以运行模型并在完成时获取 webhook,而不是等待它完成:

python 复制代码
model = replicate.models.get("ai-forever/kandinsky-2.2")
version = model.versions.get("ea1addaab376f4dc227f5368bbd8eff901820fd1cc14ed8cad63b29249e9d463")
prediction = replicate.predictions.create(
    version=version,     input={"prompt":"Watercolor painting of an underwater submarine"},     webhook="https://example.com/your-webhook",     webhook_events_filter=["completed"]
)

有关接收 webhook 的详细信息,请参阅 replicate.com/docs/webhooks


6、组合模型管道

您可以运行一个模型并将其输出作为另一个模型的输入:

python 复制代码
laionide = replicate.models.get("afiaka87/laionide-v4").versions.get("b21cbe271e65c1718f2999b038c18b45e21e4fba961181fbfae9342fc53b9e05")
swinir = replicate.models.get("jingyunliang/swinir").versions.get("660d922d33153019e8c263a3bba265de882e7f4f70396546b6c9c8f9d47a021a")
image = laionide.predict(prompt="avocado armchair")
upscaled_image = swinir.predict(image=image)

7、获取运行中模型的输出

在模型运行时获取其输出:

python 复制代码
iterator = replicate.run(
    "pixray/text2image:5c347a4bfa1d4523a58ae614c2194e15f2ae682b57e3797a5bb468920aa70ebf",     input={"prompts": "san francisco sunset"}
)

for index, image in enumerate(iterator):
    with open(f"file_{index}.png", "wb") as file:
        file.write(image.read())

8、取消预测

您可以取消正在运行的预测:

python 复制代码
>>> model = replicate.models.get("kvfrans/clipdraw")
>>> version = model.versions.get("5797a99edc939ea0e9242d5e8c9cb3bc7d125b1eac21bda852e5cb79ede2cd9b")
>>> prediction = replicate.predictions.create(
        version=version,         input={"prompt":"Watercolor painting of an underwater submarine"}
    )

>>> prediction.status
'starting'

>>> prediction.cancel()

>>> prediction.reload()
>>> prediction.status
'canceled'

9、列出预测

您可以列出所有运行过的预测:

python 复制代码
replicate.predictions.list()
# [<Prediction: 8b0ba5ab4d85>, <Prediction: 494900564e8c>]

预测列表是分页的。您可以通过将 next 属性作为参数传递给 list 方法来获取下一页预测:

python 复制代码
page1 = replicate.predictions.list()

if page1.next:
    page2 = replicate.predictions.list(page1.next)

10、加载输出文件

输出文件作为 FileOutput 对象返回:

python 复制代码
import replicate
from PIL import Image # pip install pillow

output = replicate.run(
    "stability-ai/stable-diffusion:27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478",     input={"prompt": "wavy colorful abstract patterns, oceans"}
    )

# 具有返回二进制数据的.read()方法
with open("my_output.png", "wb") as file:
  file.write(output[0].read())
  
# 也实现了迭代器协议以流式传输数据
background = Image.open(output[0])

FileOutput 对象

FileOutput 是从 replicate.run() 方法返回的类文件对象,使处理输出文件的模型更容易使用。它实现了 IteratorAsyncIterator 用于分块读取文件数据,以及 read()aread() 方法将整个文件读入内存。

!NOTE

值得注意的是,目前 read()aread() 不接受 size 参数来读取最多 size 字节。

最后,底层数据源的 URL 可通过 url 属性获得,但我们建议您将对象用作迭代器或使用其 read()aread() 方法,因为 url 属性在未来可能不总是返回 HTTP URL。

python 复制代码
print(output.url) #=> "..." or "https://delivery.replicate.com/..."

要直接消费文件:

python 复制代码
with open('output.bin', 'wb') as file:
    file.write(output.read())

对于非常大的文件,可以流式传输:

python 复制代码
with open(file_path, 'wb') as file:
    for chunk in output:
        file.write(chunk)

每种方法都有对应的 asyncio API:

python 复制代码
async with aiofiles.open(filename, 'w') as file:
    await file.write(await output.aread())

async with aiofiles.open(filename, 'w') as file:
    await for chunk in output:
        await file.write(chunk)

对于来自常见框架的流式响应,都支持接受 Iterator 类型:

Django

python 复制代码
@condition(etag_func=None)
def stream_response(request):
    output = replicate.run("black-forest-labs/flux-schnell", input={...}, use_file_output =True)
    return HttpResponse(output, content_type='image/webp')

FastAPI

python 复制代码
@app.get("/")
async def main():
    output = replicate.run("black-forest-labs/flux-schnell", input={...}, use_file_output =True)
    return StreamingResponse(output)

Flask

python 复制代码
@app.route('/stream')
def streamed_response():
    output = replicate.run("black-forest-labs/flux-schnell", input={...}, use_file_output =True)
    return app.response_class(stream_with_context(output))

您可以通过向 replicate.run() 方法传递 use_file_output=False 来禁用 FileOutput

python 复制代码
const replicate = replicate.run("acmecorp/acme-model", use_file_output=False);

11、列出模型

您可以列出您创建的模型:

python 复制代码
replicate.models.list()

模型列表是分页的。您可以通过将 next 属性作为参数传递给 list 方法来获取下一页模型,或者使用 paginate 方法自动获取页面。

python 复制代码
# 使用 `replicate.paginate` 自动分页(推荐)
models = []
for page in replicate.paginate(replicate.models.list):
    models.extend(page.results)
    if len(models) > 100:
        break

# 使用 `next` 游标手动分页
page = replicate.models.list()
while page:
    models.extend(page.results)
    if len(models) > 100:
          break
    page = replicate.models.list(page.next) if page.next else None

您还可以在 Replicate 上找到精选模型集合:

python 复制代码
>>> collections = [collection for page in replicate.paginate(replicate.collections.list) for collection in page]
>>> collections[0].slug
"vision-models"
>>> collections[0].description
"Multimodal large language models with vision capabilities like object detection and optical character recognition (OCR)"

>>> replicate.collections.get("text-to-image").models
[<Model: stability-ai/sdxl>, ...]

12、创建模型

您可以为用户或组织创建具有给定名称、可见性和硬件 SKU 的模型:

python 复制代码
import replicate

model = replicate.models.create(
    owner="your-username",     name="my-model",     visibility="public",     hardware="gpu-a40-large"
)

以下是列出 Replicate 上可用于运行模型的所有可用硬件的方法:

python 复制代码
>>> [hw.sku for hw in replicate.hardware.list()]
['cpu', 'gpu-t4', 'gpu-a40-small', 'gpu-a40-large']

13、微调模型

使用训练API微调模型,使其在特定任务上表现更好。要查看当前支持微调的语言模型 ,请查看 Replicate 的可训练语言模型集合

如果您想微调图像模型 ,请查看 Replicate 的图像模型微调指南

以下是在 Replicate 上微调模型的方法:

python 复制代码
training = replicate.trainings.create(
    model="stability-ai/sdxl",     version="39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b",     input={
      "input_images": "https://my-domain/training-images.zip",       "token_string": "TOK",       "caption_prefix": "a photo of TOK",       "max_train_steps": 1000,       "use_face_detection_instead": False
    },     # 您需要在 Replicate 上创建一个模型作为训练版本的接收方
    destination="your-username/model-name"
)

14、自定义客户端行为

replicate 包导出一个默认的共享客户端。此客户端使用 REPLICATE_API_TOKEN 环境变量设置的 API 令牌初始化。

您可以创建自己的客户端实例以传递不同的 API 令牌值,向请求添加自定义标头,或控制底层 HTTPX 客户端的行为:

python 复制代码
import os
from replicate.client import Client

replicate = Client(
    api_token=os.environ["SOME_OTHER_REPLICATE_API_TOKEN"]
    headers={
        "User-Agent": "my-app/1.0"
    }
)

!WARNING

切勿将 API 令牌等认证凭证硬编码到代码中。

相反,在运行程序时将它们作为环境变量传递。


五、开发

参见 <CONTRIBUTING.md>


伊织 xAI 2024-04-19(六)

相关推荐
worn.xiao4 分钟前
【CentOs】构建云服务器部署环境
运维·服务器·python
亢从文_Jackson1 小时前
SSM--AOP 日志
java·开发语言
天天爱吃肉82181 小时前
【基于Fluent+Python耦合的热管理数字孪生系统开发:新能源产品开发的硬核技术实践】
开发语言·python·数学建模·汽车
不辉放弃2 小时前
mysql的函数(第二期)
android·python·sql
wangz762 小时前
kotlin,Android,血压记录程序
android·开发语言·kotlin·jetpack compose
救救孩子把3 小时前
PyTorch 浮点数精度全景:从 float16/bfloat16 到 float64 及混合精度实战
人工智能·pytorch·python
意.远3 小时前
PyTorch数据操作基础教程:从张量创建到高级运算
人工智能·pytorch·python·深度学习·机器学习
明月看潮生5 小时前
青少年编程与数学 02-016 Python数据结构与算法 29课题、自然语言处理算法
python·算法·青少年编程·自然语言处理·编程与数学
努力学习的小廉6 小时前
【C++】 —— 笔试刷题day_20
开发语言·c++
西柚小萌新6 小时前
【Python爬虫基础篇】--1.基础概念
开发语言·爬虫·python