【Llava】load_pretrained_model() 说明

load_pretrained_model

python 复制代码
load_pretrained_model(
    model_path, 
    model_base=None, 
    model_name, 
    load_8bit=False, 
    load_4bit=False, 
    device_map="auto", 
    torch_dtype="float16",
    attn_implementation="flash_attention_2", 
    customized_config=None, 
    overwrite_config=None, 
    **kwargs
)

加载预训练的 LLaVA 模型或语言模型。

该函数支持加载多种模型架构,包括 LLaVA 模型(Llama、Mistral、Mixtral、Qwen、Gemma)和标准语言模型。它能够处理不同的加载场景,如完整模型加载、LoRA 权重加载和量化(8-bit/4-bit)。

参数

参数名 类型 默认值 说明
model_path str - 模型检查点路径或 HuggingFace 模型标识符。对于 LoRA 模型,应指向 LoRA 权重目录。
model_base str, optional None 基础模型检查点路径。加载 LoRA 权重或仅加载投影器检查点时必须提供。
model_name str - 模型名称或标识符,用于确定模型架构。应包含关键词如 "llava"、"lora"、"mixtral"、"mistral"、"qwen"、"gemma" 等。
load_8bit bool False 是否以 8-bit 量化加载模型。与 load_4bit 互斥。
load_4bit bool False 是否使用 BitsAndBytes 以 4-bit 量化加载模型。与 load_8bit 互斥。
device_map str "auto" 模型加载的设备映射策略。选项包括 "auto"、"cpu"、"cuda" 或特定的设备映射字典。
torch_dtype str "float16" 模型权重的数据类型。选项为 "float16" 或 "bfloat16"。
attn_implementation str "flash_attention_2" 使用的注意力实现。选项包括 "flash_attention_2"、"sdpa" 等。
customized_config dict, optional None 自定义模型配置字典,用于覆盖从 model_path 加载的默认配置。
overwrite_config dict, optional None 配置属性字典,用于在加载基础配置后覆盖。键应为配置属性名,值为要设置的新值。
**kwargs - - 传递给底层 from_pretrained() 方法的其他关键字参数(如 trust_remote_codelow_cpu_mem_usage)。

返回值

返回一个包含以下四个元素的元组:

  • tokenizer (AutoTokenizer): 加载模型的 tokenizer 实例。
  • model (PreTrainedModel): 加载的模型实例。可能是以下类型之一:
    • LlavaLlamaForCausalLM
    • LlavaMistralForCausalLM
    • LlavaMixtralForCausalLM
    • LlavaQwenForCausalLM
    • LlavaGemmaForCausalLM
    • AutoModelForCausalLM(用于纯语言模型)
  • image_processor (ImageProcessor or None): 视觉语言模型的图像处理器。对于纯语言模型返回 None
  • context_len (int): 模型的最大上下文长度,从配置属性(如 max_sequence_lengthmax_position_embeddings)确定,或默认为 2048。

功能特性

支持的模型架构

该函数自动识别并加载以下模型架构:

  • LLaVA 系列: Llama、Mistral、Mixtral、Qwen、Gemma 等基础架构的 LLaVA 变体
  • LoRA 模型: 支持加载和合并 LoRA 权重
  • 纯语言模型: 支持标准 HuggingFace 语言模型

量化支持

  • 8-bit 量化 : 通过 load_8bit=True 启用
  • 4-bit 量化 : 通过 load_4bit=True 启用,使用 BitsAndBytes 库,配置为 NF4 量化类型

特殊处理

  • LLaVA v1.5 模型 : 自动设置 delay_load=True 作为正确加载的解决方案
  • 多模态模型: 自动为 tokenizer 添加特殊 token(图像 patch token、开始/结束 token)
  • LoRA 权重合并: 自动加载并合并 LoRA 权重到基础模型

使用示例

示例 1: 加载完整的 LLaVA 模型

python 复制代码
from llava.model.builder import load_pretrained_model

# 从 HuggingFace 加载 LLaVA v1.5 模型
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path="liuhaotian/llava-v1.5-7b",
    model_base=None,
    model_name="llava",
    torch_dtype="float16"
)

print(f"Model loaded: {model.__class__.__name__}")
print(f"Context length: {context_len}")

示例 2: 加载 LoRA 模型

python 复制代码
# 加载 LoRA 权重(需要提供基础模型)
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path="./checkpoints/llava-lora",
    model_base="liuhaotian/llava-v1.5-7b",
    model_name="llava_lora",
    torch_dtype="float16"
)

示例 3: 使用 4-bit 量化加载模型

python 复制代码
# 使用 4-bit 量化以节省显存
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path="liuhaotian/llava-v1.5-7b",
    model_base=None,
    model_name="llava",
    load_4bit=True,
    device_map="auto"
)

示例 4: 加载自定义配置的模型

python 复制代码
# 使用自定义配置覆盖默认设置
custom_config = {
    "mm_vision_select_layer": -2,
    "mm_projector_type": "mlp2x_gelu"
}

tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path="liuhaotian/llava-v1.5-7b",
    model_base=None,
    model_name="llava",
    customized_config=custom_config,
    torch_dtype="float16"
)

示例 5: 加载纯语言模型

python 复制代码
# 加载标准语言模型(非多模态)
tokenizer, model, _, context_len = load_pretrained_model(
    model_path="meta-llama/Llama-2-7b-hf",
    model_base=None,
    model_name="llama",
    torch_dtype="float16"
)
# 注意:image_processor 为 None

注意事项

  1. LoRA 模型加载 : 加载 LoRA 模型时,必须提供 model_base 参数。函数会先加载基础模型,然后应用 LoRA 权重并自动合并。

  2. 量化互斥性 : load_8bitload_4bit 参数互斥,不能同时设置为 True

  3. 设备映射 : 使用 device_map="auto" 时,函数会自动将模型分配到可用设备。对于多 GPU 环境,模型会被分片到不同 GPU。

  4. 配置覆盖顺序:

    • 首先应用 customized_config(如果提供)
    • 然后应用 overwrite_config(如果提供)
    • 最后应用从 model_path 加载的配置
  5. 特殊 Token: 对于多模态模型,函数会根据模型配置自动添加图像相关的特殊 token 到 tokenizer。

警告

  • 如果 model_name 包含 "lora" 但 model_baseNone,函数会发出警告,因为 LoRA 模型需要基础模型。

  • 如果指定的 model_name 不被支持,函数会抛出 ValueError 异常。

相关链接

相关推荐
MARS_AI_2 小时前
大模型赋能客户沟通,云蝠大模型呼叫实现问题解决全链路闭环
人工智能·自然语言处理·信息与通信·agi
名为沙丁鱼的猫7292 小时前
【MCP 协议层(Protocol layer)详解】:深入分析MCP Python SDK中协议层的实现机制
人工智能·深度学习·神经网络·机器学习·自然语言处理·nlp
bylander2 小时前
【AI学习】几分钟了解一下Clawdbot
人工智能·智能体·智能体应用
香芋Yu2 小时前
【机器学习教程】第04章 指数族分布
人工智能·笔记·机器学习
小咖自动剪辑2 小时前
Base64与图片互转工具增强版:一键编码/解码,支持多格式
人工智能·pdf·word·媒体
独自归家的兔2 小时前
从 “局部凑活“ 到 “全局最优“:AI 规划能力的技术突破与产业落地实践
大数据·人工智能
一个处女座的程序猿2 小时前
AI:解读Sam Altman与多位 AI 构建者对话—构建可落地的 AI—剖析 OpenAI Town Hall 与给创业者、产品/工程/安全团队的实用指南
人工智能
依依yyy2 小时前
沪深300指数收益率波动性分析与预测——基于ARMA-GARCH模型
人工智能·算法·机器学习
海域云-罗鹏3 小时前
国内公司与英国总部数据中心/ERP系统互连,SD-WAN专线实操指南
大数据·数据库·人工智能