从零开发短视频电商 在AWS上用SageMaker部署模型:自定义日志、输入和输出示例
怎么部署自定义模型请看:从零开发短视频电商 在AWS上用SageMaker部署自定义模型
- 都是 Hugging Face 上的模型,或者是 fine-tune 之后的模型。
为了与 JumpStart 上部署的模型的 HTTP 输入输出保持一致,我在自定义模型中自定义了输入和输出的处理逻辑,从而做到两者兼容。
code/inference.py
- 容器的原始代码入口:https://github.com/aws/sagemaker-huggingface-inference-toolkit/blob/80634b30703e8e9525db8b7128b05f713f42f9dc/src/sagemaker_huggingface_inference_toolkit/handler_service.py
- 默认支持的decode和encode:https://github.com/aws/sagemaker-huggingface-inference-toolkit/blob/80634b30703e8e9525db8b7128b05f713f42f9dc/src/sagemaker_huggingface_inference_toolkit/decoder_encoder.py
- 可以用这个在sagemaker上使用jupyterlab:https://github.com/huggingface/notebooks/blob/main/sagemaker/17_custom_inference_script/sagemaker-notebook.ipynb
我们自定义的逻辑如下:
python
import base64
import datetime
import json
import logging
from io import BytesIO

import numpy as np
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
# --------- 这块
# Module-wide logger: getLogger() without a name returns the root logger.
logger = logging.getLogger()
# Lower the threshold to INFO so the logger.info(...) calls in this file pass
# the logger's level check.
logger.setLevel(logging.INFO)
# 自定义http输入,可以适配不同的content_type,并打印输入的日志
# 源码参见下面的 preprocess
def input_fn(input_data, content_type):
    """Deserialize the raw request body into a request dict and log it.

    Args:
        input_data: Request payload, either ``bytes``/``bytearray`` or ``str``.
        content_type: Value of the request's Content-Type header. Media-type
            parameters (e.g. ``; charset=utf-8``), surrounding whitespace and
            letter case are ignored when selecting the decoder; ``None`` is
            treated like an unknown type.

    Returns:
        The parsed JSON document for ``application/json``. Otherwise a dict
        ``{"inputs": ...}`` holding the UTF-8 decoded text for
        ``application/x-text``, or the payload unchanged for any other type.

    Raises:
        ValueError: if an ``application/json`` payload is not valid JSON
            (``json.JSONDecodeError`` is a subclass of ``ValueError``).
    """
    logger.info(f"laker input_data {input_data} and content_type {content_type}")

    # Compare on the bare media type only, so that a header such as
    # "application/json; charset=utf-8" is not sent to the raw fallback.
    media_type = (content_type or "").split(";", 1)[0].strip().lower()

    if media_type == "application/json":
        # json.loads accepts str, bytes and bytearray alike.
        request = json.loads(input_data)
    elif media_type == "application/x-text":
        # Only byte payloads need decoding; a str is already text.
        if isinstance(input_data, (bytes, bytearray)):
            text = input_data.decode('utf-8')
        else:
            text = input_data
        request = {"inputs": text}
    else:
        # Unknown content type: hand the payload through untouched.
        request = {"inputs": input_data}

    logger.info(f"laker input_fn request {request} ")
    return request
# 自定义输出
def output_fn(prediction, accept):
    """Serialize the prediction result with ``encode_json``.

    Args:
        prediction: Result object produced by the prediction step.
        accept: Requested response type. It is not consulted here; the
            response is always produced by ``encode_json``.

    Returns:
        Whatever ``encode_json(prediction)`` returns.
    """
    serialized = encode_json(prediction)
    return serialized
# 来自https://github.com/aws/sagemaker-huggingface-inference-toolkit/blob/80634b30703e8e9525db8b7128b05f713f42f9dc/src/sagemaker_huggingface_inference_toolkit/decoder_encoder.py#L102C1-L113C6
class _JSONEncoder(json.JSONEncoder):
    """JSON encoder with fallbacks for NumPy, datetime and PIL image values."""

    def default(self, obj):
        """Convert a value the stock encoder rejects into a JSON-friendly one.

        Checked in this order:
          * NumPy integer scalar   -> ``int``
          * NumPy floating scalar  -> ``float``
          * object with ``tolist`` -> result of ``obj.tolist()``
          * ``datetime.datetime``  -> ``str(obj)``
          * PIL image              -> base64 text of its PNG encoding
          * anything else          -> deferred to ``json.JSONEncoder.default``

        Args:
            obj: The value that ``json`` could not serialize on its own.

        Returns:
            A value that ``json`` can serialize.

        Raises:
            TypeError: raised by the base class for unsupported types.
        """
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if hasattr(obj, "tolist"):
            return obj.tolist()
        if isinstance(obj, datetime.datetime):
            return str(obj)

        # Pillow is looked up lazily and treated as optional, so the encoder
        # still works (minus the image branch) when it is not installed.
        try:
            from PIL import Image
        except ImportError:
            Image = None

        if Image is not None and isinstance(obj, Image.Image):
            with BytesIO() as out:
                obj.save(out, format="PNG")
                png_bytes = out.getvalue()
            return base64.b64encode(png_bytes).decode("utf-8")

        return super().default(obj)
def encode_json(content):
    """Serialize ``content`` to compact JSON text using ``_JSONEncoder``.

    The output keeps non-ASCII characters as-is, uses no indentation and
    separates items with "," and ":" without extra spaces. NaN/Infinity
    values are rejected (``allow_nan=False``) instead of being written out.

    Args:
        content: The object to serialize.

    Returns:
        The JSON document as a string.
    """
    # Equivalent to json.dumps(content, cls=_JSONEncoder, ...): dumps builds
    # the encoder with these options and calls encode() on it.
    encoder = _JSONEncoder(
        ensure_ascii=False,
        allow_nan=False,
        indent=None,
        separators=(",", ":"),
    )
    return encoder.encode(content)
# --------- 这块 end ---
# Helper: Mean Pooling - Take attention mask into account for correct averaging
def mean_pooling(model_output, attention_mask):
    """Average token embeddings over dimension 1, weighted by the attention mask.

    Args:
        model_output: Indexable model result whose first element holds the
            token embeddings.
        attention_mask: Mask tensor; it is expanded along a new last axis to
            the token embeddings' size and cast to float before weighting.
            (Assumes 1 marks real tokens and 0 marks padding - confirm
            against the tokenizer in use.)

    Returns:
        Tensor of mask-weighted sums divided by the per-row mask totals.
    """
    hidden_states = model_output[0]
    weights = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
    weighted_sum = (hidden_states * weights).sum(dim=1)
    # Clamp the divisor so an all-zero mask does not divide by zero.
    mask_total = weights.sum(dim=1).clamp(min=1e-9)
    return weighted_sum / mask_total
def model_fn(model_dir):
    """Load the tokenizer and the model stored under ``model_dir``.

    Args:
        model_dir: Location passed to ``from_pretrained`` for both objects.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_dir)
    loaded_model = AutoModel.from_pretrained(model_dir)
    return (loaded_model, loaded_tokenizer)
def predict_fn(data, model_and_tokenizer):
    """Embed the request text and return the first embedding as a list.

    Args:
        data: Request dict. Its "inputs" entry is removed and tokenized; when
            the key is absent the dict itself is handed to the tokenizer.
        model_and_tokenizer: The ``(model, tokenizer)`` pair to run with.

    Returns:
        ``{"embedding": [...]}`` containing row 0 of the L2-normalized,
        mean-pooled embeddings. Only the first row is returned, even if the
        tokenizer produced several.
    """
    model, tokenizer = model_and_tokenizer

    # Tokenize with padding and truncation, producing PyTorch tensors.
    texts = data.pop("inputs", data)
    batch = tokenizer(texts, padding=True, truncation=True, return_tensors='pt')

    # Forward pass without gradient tracking.
    with torch.no_grad():
        outputs = model(**batch)

    # Mask-aware mean pooling, then scale each row to unit L2 norm.
    pooled = mean_pooling(outputs, batch['attention_mask'])
    unit_vectors = F.normalize(pooled, p=2, dim=1)

    # A plain list is JSON serializable.
    first_row = unit_vectors[0]
    return {"embedding": first_row.tolist()}
python
import logging
from sagemaker_huggingface_inference_toolkit import content_types, decoder_encoder
logger = logging.getLogger(__name__)
def preprocess(self, input_data, content_type, context=None):
    """Deserialize the request payload into an object for prediction.

    Can be overridden to add data or feature transformation.

    Args:
        input_data: The request payload serialized in ``content_type`` format.
        content_type: The request content type.
        context (obj): Metadata on the incoming request (default: None).
            Not used in this method.

    Returns:
        The result of ``decoder_encoder.decode(input_data, content_type)``.

    Raises:
        PredictionException: with code 400 when the ``HF_TASK`` environment
            variable is "zero-shot-classification" or
            "table-question-answering" and the content type is CSV (per the
            original note, not possible due to nested properties).
    """
    hf_task = os.environ.get("HF_TASK", None)
    task_rejects_csv = hf_task in ("zero-shot-classification", "table-question-answering")

    if task_rejects_csv and content_type == content_types.CSV:
        raise PredictionException(
            f"content type {content_type} not support with {os.environ.get('HF_TASK', 'unknown task')}, use different content_type",
            400,
        )

    return decoder_encoder.decode(input_data, content_type)
# NOTE(review): standalone logging call; batch_size and sequence_length are
# not defined anywhere in this excerpt. Presumably a placeholder showing how
# to log values from handler code - confirm the intended location and names.
logger.info(
f"param1 {batch_size} and param2 {sequence_length}"
)
def predict(self, data, model, context=None):
    """Run the model on the deserialized request and return its result.

    Can be overridden to implement custom model inference.

    Args:
        data (dict): Deserialized request. "inputs" is popped and passed to
            the model; when that key is missing, ``data`` itself is passed.
            "parameters", if present, is popped and expanded into keyword
            arguments for the model call.
        model: Callable invoked with the inputs.
        context (obj): Metadata on the incoming request (default: None).
            Not used in this method.

    Returns:
        Whatever the model call returns.
    """
    inputs = data.pop("inputs", data)
    parameters = data.pop("parameters", None)

    # Without parameters the model is called with the inputs only.
    if parameters is None:
        return model(inputs)
    return model(inputs, **parameters)
def postprocess(self, prediction, accept, context=None):
    """Serialize the prediction result to the requested accept type.

    Can be overridden to transform the inference response.

    Args:
        prediction (dict): A prediction result from predict.
        accept (str): Type the output data should be serialized to.
        context (obj): Metadata on the incoming request (default: None).
            Not used in this method.

    Returns:
        The result of ``decoder_encoder.encode(prediction, accept)``.
    """
    encoded_output = decoder_encoder.encode(prediction, accept)
    return encoded_output