从零开发短视频电商 在AWS上SageMaker部署模型自定义日志输入和输出示例

从零开发短视频电商 在AWS上SageMaker部署模型自定义日志输入和输出示例

怎么部署自定义模型请看:从零开发短视频电商 在AWS上用SageMaker部署自定义模型

  • 都是huaggingface上的模型或者fine-tune后的。

为了适配jumpstart上部署的模型的http输入输出,我在自定义模型中自定义了适配的输入输出,可以做到兼容适配

code/inference.py

我们自定义的逻辑如下

python 复制代码
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F
import json
import logging
// --------- 这块
logger = logging.getLogger()
logger.setLevel(logging.INFO)
// 自定义http输入,可以适配不同的content_type ,打印输入的日志
// 源码参见下面的 preprocess
def input_fn(input_data, content_type):
    logger.info(f"laker input_data {input_data} and content_type {content_type}")
    if content_type == "application/json":
        request = json.loads(input_data)
    elif content_type == "application/x-text":
        request = {"inputs": input_data.decode('utf-8')}
    else:
        request = {"inputs": input_data} 
    logger.info(f"laker input_fn request {request} ")
    return request
// 自定义输出
def output_fn(prediction, accept):
    return encode_json(prediction)  

// 来自https://github.com/aws/sagemaker-huggingface-inference-toolkit/blob/80634b30703e8e9525db8b7128b05f713f42f9dc/src/sagemaker_huggingface_inference_toolkit/decoder_encoder.py#L102C1-L113C6
    
class _JSONEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, np.integer):
            return int(obj)
        elif isinstance(obj, np.floating):
            return float(obj)
        elif hasattr(obj, "tolist"):
            return obj.tolist()
        elif isinstance(obj, datetime.datetime):
            return obj.__str__()
        elif isinstance(obj, Image.Image):
            with BytesIO() as out:
                obj.save(out, format="PNG")
                png_string = out.getvalue()
                return base64.b64encode(png_string).decode("utf-8")
        else:
            return super(_JSONEncoder, self).default(obj)
        
def encode_json(content):
    """
    encodes json with custom `JSONEncoder`
    """
    return json.dumps(
        content,
        ensure_ascii=False,
        allow_nan=False,
        indent=None,
        cls=_JSONEncoder,
        separators=(",", ":"),
    )
// --------- 这块  end ---

# Helper: Mean Pooling - Take attention mask into account for correct averaging
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0] #First element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)


def model_fn(model_dir):
  # Load model from HuggingFace Hub
  tokenizer = AutoTokenizer.from_pretrained(model_dir)
  model = AutoModel.from_pretrained(model_dir)
  return model, tokenizer

def predict_fn(data, model_and_tokenizer):
    # destruct model and tokenizer
    model, tokenizer = model_and_tokenizer
    
    # Tokenize sentences
    sentences = data.pop("inputs", data)
    encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')

    # Compute token embeddings
    with torch.no_grad():
        model_output = model(**encoded_input)

    # Perform pooling
    sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])

    # Normalize embeddings
    sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
    
    # return dictonary, which will be json serializable
    return {"embedding": sentence_embeddings[0].tolist()}
python 复制代码
import logging
from sagemaker_huggingface_inference_toolkit import content_types, decoder_encoder

logger = logging.getLogger(__name__)

 def preprocess(self, input_data, content_type, context=None):
        """
        The preprocess handler is responsible for deserializing the input data into
        an object for prediction, can handle JSON.
        The preprocess handler can be overridden for data or feature transformation.

        Args:
            input_data: the request payload serialized in the content_type format.
            content_type: the request content_type.
            context (obj): metadata on the incoming request data (default: None).

        Returns:
            decoded_input_data (dict): deserialized input_data into a Python dictonary.
        """
        # raises en error when using zero-shot-classification or table-question-answering, not possible due to nested properties
        if (
            os.environ.get("HF_TASK", None) == "zero-shot-classification"
            or os.environ.get("HF_TASK", None) == "table-question-answering"
        ) and content_type == content_types.CSV:
            raise PredictionException(
                f"content type {content_type} not support with {os.environ.get('HF_TASK', 'unknown task')}, use different content_type",
                400,
            )

        decoded_input_data = decoder_encoder.decode(input_data, content_type)
        return decoded_input_data
    
    
    

   logger.info(
        f"param1 {batch_size} and param2 {sequence_length}"
    )



def predict(self, data, model, context=None):
        """The predict handler is responsible for model predictions. Calls the `__call__` method of the provided `Pipeline`
        on decoded_input_data deserialized in input_fn. Runs prediction on GPU if is available.
        The predict handler can be overridden to implement the model inference.

        Args:
            data (dict): deserialized decoded_input_data returned by the input_fn
            model : Model returned by the `load` method or if it is a custom module `model_fn`.
            context (obj): metadata on the incoming request data (default: None).

        Returns:
            obj (dict): prediction result.
        """

        # pop inputs for pipeline
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", None)

        # pass inputs with all kwargs in data
        if parameters is not None:
            prediction = model(inputs, **parameters)
        else:
            prediction = model(inputs)
        return prediction

    def postprocess(self, prediction, accept, context=None):
        """
        The postprocess handler is responsible for serializing the prediction result to
        the desired accept type, can handle JSON.
        The postprocess handler can be overridden for inference response transformation.

        Args:
            prediction (dict): a prediction result from predict.
            accept (str): type which the output data needs to be serialized.
            context (obj): metadata on the incoming request data (default: None).
        Returns: output data serialized
        """
        return decoder_encoder.encode(prediction, accept)
相关推荐
翼龙云_cloud5 小时前
亚马逊云渠道商:如何快速开始使用Amazon RDS?
运维·服务器·云计算·aws
赖small强6 小时前
【音视频开发】Linux UVC (USB Video Class) 驱动框架深度解析
linux·音视频·v4l2·uvc
赖small强6 小时前
【音视频开发】ISP流水线核心模块深度解析
音视频·isp·白平衡·亮度·luminance·gamma 校正·降噪处理
可观测性用观测云7 小时前
观测云荣膺亚马逊云科技 2025 年合作伙伴奖项
云计算
赖small强7 小时前
【音视频开发】Linux V4L2 (Video for Linux 2) 驱动框架深度解析白皮书
linux·音视频·v4l2·设备节点管理·视频缓冲队列·videobuf2
@HNUSTer9 小时前
基于 GEE 的生态环境质量评价:遥感生态指数(RSEI)计算与空间分布可视化
云计算·数据集·遥感大数据·gee·云平台·遥感生态指数(rsei)·生态环境质量评价
原神启动110 小时前
云计算大数据——MySQL数据库一(数据库基础与MySQL安装)
大数据·数据库·云计算
ACP广源盛1392462567311 小时前
GSV2712@ACP#2 进 1 出 HDMI 2.0/Type-C DisplayPort 1.4 混合切换器 + 嵌入式 MCU
单片机·嵌入式硬件·计算机外设·音视频
AI周红伟11 小时前
通义万相开源14B数字人Wan2.2-S2V!影视级音频驱动视频生成,助力专业内容创作
音视频
weixin_3077791312 小时前
基于AWS Global Accelerator和ECS Fargate的最小化延迟与快速故障转移架构
容器·云计算·aws