从零开发短视频电商 在AWS上SageMaker部署模型自定义日志输入和输出示例

从零开发短视频电商 在AWS上SageMaker部署模型自定义日志输入和输出示例

怎么部署自定义模型请看:从零开发短视频电商 在AWS上用SageMaker部署自定义模型

  • 都是huaggingface上的模型或者fine-tune后的。

为了适配jumpstart上部署的模型的http输入输出,我在自定义模型中自定义了适配的输入输出,可以做到兼容适配

code/inference.py

我们自定义的逻辑如下

python 复制代码
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F
import json
import logging
// --------- 这块
logger = logging.getLogger()
logger.setLevel(logging.INFO)
// 自定义http输入,可以适配不同的content_type ,打印输入的日志
// 源码参见下面的 preprocess
def input_fn(input_data, content_type):
    logger.info(f"laker input_data {input_data} and content_type {content_type}")
    if content_type == "application/json":
        request = json.loads(input_data)
    elif content_type == "application/x-text":
        request = {"inputs": input_data.decode('utf-8')}
    else:
        request = {"inputs": input_data} 
    logger.info(f"laker input_fn request {request} ")
    return request
// 自定义输出
def output_fn(prediction, accept):
    return encode_json(prediction)  

// 来自https://github.com/aws/sagemaker-huggingface-inference-toolkit/blob/80634b30703e8e9525db8b7128b05f713f42f9dc/src/sagemaker_huggingface_inference_toolkit/decoder_encoder.py#L102C1-L113C6
    
class _JSONEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, np.integer):
            return int(obj)
        elif isinstance(obj, np.floating):
            return float(obj)
        elif hasattr(obj, "tolist"):
            return obj.tolist()
        elif isinstance(obj, datetime.datetime):
            return obj.__str__()
        elif isinstance(obj, Image.Image):
            with BytesIO() as out:
                obj.save(out, format="PNG")
                png_string = out.getvalue()
                return base64.b64encode(png_string).decode("utf-8")
        else:
            return super(_JSONEncoder, self).default(obj)
        
def encode_json(content):
    """
    encodes json with custom `JSONEncoder`
    """
    return json.dumps(
        content,
        ensure_ascii=False,
        allow_nan=False,
        indent=None,
        cls=_JSONEncoder,
        separators=(",", ":"),
    )
// --------- 这块  end ---

# Helper: Mean Pooling - Take attention mask into account for correct averaging
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0] #First element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)


def model_fn(model_dir):
  # Load model from HuggingFace Hub
  tokenizer = AutoTokenizer.from_pretrained(model_dir)
  model = AutoModel.from_pretrained(model_dir)
  return model, tokenizer

def predict_fn(data, model_and_tokenizer):
    # destruct model and tokenizer
    model, tokenizer = model_and_tokenizer
    
    # Tokenize sentences
    sentences = data.pop("inputs", data)
    encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')

    # Compute token embeddings
    with torch.no_grad():
        model_output = model(**encoded_input)

    # Perform pooling
    sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])

    # Normalize embeddings
    sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
    
    # return dictonary, which will be json serializable
    return {"embedding": sentence_embeddings[0].tolist()}
python 复制代码
import logging
from sagemaker_huggingface_inference_toolkit import content_types, decoder_encoder

logger = logging.getLogger(__name__)

 def preprocess(self, input_data, content_type, context=None):
        """
        The preprocess handler is responsible for deserializing the input data into
        an object for prediction, can handle JSON.
        The preprocess handler can be overridden for data or feature transformation.

        Args:
            input_data: the request payload serialized in the content_type format.
            content_type: the request content_type.
            context (obj): metadata on the incoming request data (default: None).

        Returns:
            decoded_input_data (dict): deserialized input_data into a Python dictonary.
        """
        # raises en error when using zero-shot-classification or table-question-answering, not possible due to nested properties
        if (
            os.environ.get("HF_TASK", None) == "zero-shot-classification"
            or os.environ.get("HF_TASK", None) == "table-question-answering"
        ) and content_type == content_types.CSV:
            raise PredictionException(
                f"content type {content_type} not support with {os.environ.get('HF_TASK', 'unknown task')}, use different content_type",
                400,
            )

        decoded_input_data = decoder_encoder.decode(input_data, content_type)
        return decoded_input_data
    
    
    

   logger.info(
        f"param1 {batch_size} and param2 {sequence_length}"
    )



def predict(self, data, model, context=None):
        """The predict handler is responsible for model predictions. Calls the `__call__` method of the provided `Pipeline`
        on decoded_input_data deserialized in input_fn. Runs prediction on GPU if is available.
        The predict handler can be overridden to implement the model inference.

        Args:
            data (dict): deserialized decoded_input_data returned by the input_fn
            model : Model returned by the `load` method or if it is a custom module `model_fn`.
            context (obj): metadata on the incoming request data (default: None).

        Returns:
            obj (dict): prediction result.
        """

        # pop inputs for pipeline
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", None)

        # pass inputs with all kwargs in data
        if parameters is not None:
            prediction = model(inputs, **parameters)
        else:
            prediction = model(inputs)
        return prediction

    def postprocess(self, prediction, accept, context=None):
        """
        The postprocess handler is responsible for serializing the prediction result to
        the desired accept type, can handle JSON.
        The postprocess handler can be overridden for inference response transformation.

        Args:
            prediction (dict): a prediction result from predict.
            accept (str): type which the output data needs to be serialized.
            context (obj): metadata on the incoming request data (default: None).
        Returns: output data serialized
        """
        return decoder_encoder.encode(prediction, accept)
相关推荐
美狐美颜SDK开放平台24 分钟前
多终端适配下的人脸美型方案:美颜SDK工程开发实践分享
人工智能·音视频·美颜sdk·直播美颜sdk·视频美颜sdk
阿里云大数据AI技术1 小时前
全模态、多引擎、一体化,阿里云DLF3.0构建Data+AI驱动的智能湖仓平台
人工智能·阿里云·云计算
摇滚侠2 小时前
阿里云安装的 Redis 在什么位置,如何找到 Redis 的安装位置
redis·阿里云·云计算
饭饭大王6663 小时前
CANN 生态深度整合:使用 `pipeline-runner` 构建高吞吐视频分析流水线
人工智能·音视频
晚霞的不甘5 小时前
CANN 编译器深度解析:TBE 自定义算子开发实战
人工智能·架构·开源·音视频
愚公搬代码5 小时前
【愚公系列】《AI短视频创作一本通》016-AI短视频的生成(AI短视频运镜方法)
人工智能·音视频
m0_694845576 小时前
tinylisp 是什么?超轻量 Lisp 解释器编译与运行教程
服务器·开发语言·云计算·github·lisp
那个村的李富贵6 小时前
CANN赋能AIGC“数字人”革命:实时视频换脸与表情驱动实战
aigc·音视频
晚霞的不甘6 小时前
CANN 支持强化学习:从 Isaac Gym 仿真到机械臂真机控制
人工智能·神经网络·架构·开源·音视频
ESBK20256 小时前
第四届移动互联网、云计算与信息安全国际会议(MICCIS 2026)二轮征稿启动,诚邀全球学者共赴学术盛宴
大数据·网络·物联网·网络安全·云计算·密码学·信息与通信