1 _load_pretrained_model 调用流程:
1. 初始化标志和变量:
- 判断权重文件是否是
safetensors
格式(is_safetensors
)。 - 判断是否使用了量化器(
is_quantized
)。 - 初始化状态字典的文件夹和索引(
state_dict_folder
、state_dict_index
)。
python
is_safetensors = False
is_quantized = hf_quantizer is not None
state_dict_folder = None
state_dict_index = None
2. 处理 device_map
和磁盘卸载:
- 如果提供了
device_map
,且包含"disk"
,则表示部分权重将卸载到磁盘。 - 如果未指定
offload_folder
且不是safetensors
格式,抛出错误,提示需要提供offload_folder
。 - 创建
offload_folder
,用于存储卸载到磁盘的权重。 - 设置
offload_state_dict
为True
,表示状态字典将被卸载。
python
if device_map is not None and "disk" in device_map.values():
archive_file = (
resolved_archive_file[0] if isinstance(resolved_archive_file, (list, tuple)) else resolved_archive_file
)
is_safetensors = archive_file is not None and archive_file.endswith(".safetensors")
if offload_folder is None and not is_safetensors:
raise ValueError(
"The current `device_map` had weights offloaded to the disk. Please provide an `offload_folder`"
" for them. Alternatively, make sure you have `safetensors` installed if the model you are using"
" offers the weights in this format."
)
if offload_folder is not None:
os.makedirs(offload_folder, exist_ok=True)
if offload_state_dict is None:
offload_state_dict = True
3. 判断是否为分片的 safetensors
文件:
- 通过
is_sharded_safetensors
判断当前是否加载的是分片的safetensors
文件。
python
is_sharded_safetensors = is_safetensors and sharded_metadata is not None
4. 绑定共享权重(参数共享):
- 调用
model.tie_weights()
,确保模型的共享权重正确绑定。
python
# tie the model weights before retrieving the state_dict
model.tie_weights()
5. 获取模型的状态字典和预期的键值:
- 通过
model.state_dict()
获取模型的状态字典。 - 生成
expected_keys
,即模型中所有应当存在的参数键列表。
python
# Retrieve missing & unexpected_keys
model_state_dict = model.state_dict()
expected_keys = list(model_state_dict.keys())
prefix = model.base_model_prefix
6. 处理量化器的预期键值:
- 如果使用了量化器,更新
expected_keys
以匹配量化后的参数键。
python
if hf_quantizer is not None:
expected_keys = hf_quantizer.update_expected_keys(model, expected_keys, loaded_keys)
7. 调整键名的前缀以匹配:
- 根据模型的
base_model_prefix
来判断是否需要添加或移除前缀,以使加载的键与模型的键匹配。
python
prefix = model.base_model_prefix
if len(prefix) > 0:
has_prefix_module = any(s.startswith(prefix) for s in loaded_keys)
expects_prefix_module = any(s.startswith(prefix) for s in expected_keys)
else:
has_prefix_module = False
expects_prefix_module = False
8. 计算缺失和意外的键:
- 计算
missing_keys
,即模型中缺失的参数键。 - 计算
unexpected_keys
,即预训练权重中存在但模型中不存在的参数键。 - 从
unexpected_keys
中移除模型的缓冲区(如非持久化的 buffers)键。
python
# key re-naming operations are never done on the keys
# that are loaded, but always on the keys of the newly initialized model
remove_prefix_from_model = not has_prefix_module and expects_prefix_module
add_prefix_to_model = has_prefix_module and not expects_prefix_module
if remove_prefix_from_model:
_prefix = f"{prefix}."
expected_keys_not_prefixed = [s for s in expected_keys if not s.startswith(_prefix)]
expected_keys = [s[len(_prefix) :] if s.startswith(_prefix) else s for s in expected_keys]
elif add_prefix_to_model:
expected_keys = [".".join([prefix, s]) for s in expected_keys]
missing_keys = sorted(set(expected_keys) - set(loaded_keys))
unexpected_keys = set(loaded_keys) - set(expected_keys)
# Remove nonpersistent buffers from unexpected keys: they are not in the state dict but will be in the model
# buffers
model_buffers = {n for n, _ in model.named_buffers()}
if remove_prefix_from_model:
model_buffers = {key[len(_prefix) :] if key.startswith(_prefix) else key for key in model_buffers}
elif add_prefix_to_model:
model_buffers = {".".join([prefix, key]) for key in model_buffers}
unexpected_keys = sorted(unexpected_keys - model_buffers)
9. 处理共享的参数(参数绑定):
- 找出模型中共享的参数,并在缺失键中移除那些由于共享而不真正缺失的键。
python
# Clean up buffer for `inv-freq` because RoPE embedding moved under base model (https://github.com/huggingface/transformers/pull/34858)
has_inv_freq_buffers = any(buffer.endswith("rotary_emb.inv_freq") for buffer in model_buffers)
if has_inv_freq_buffers:
unexpected_keys = {k for k in unexpected_keys if "rotary_emb.inv_freq" not in k}
model.tie_weights()
if device_map is None and not is_fsdp_enabled() and not is_deepspeed_zero3_enabled():
ptrs = collections.defaultdict(list)
for name, tensor in model.state_dict().items():
id_tensor = id_tensor_storage(tensor)
ptrs[id_tensor].append(name)
# These are all the pointers of shared tensors.
tied_params = [names for _, names in ptrs.items() if len(names) > 1]
else:
# id function doesn't work for meta tensor so we need this function
tied_params = find_tied_parameters(model)
for group in tied_params:
if remove_prefix_from_model:
group = [key[len(_prefix) :] if key.startswith(_prefix) else key for key in group]
elif add_prefix_to_model:
group = [".".join([prefix, key]) for key in group]
missing_in_group = [k for k in missing_keys if k in group]
if len(missing_in_group) > 0 and len(missing_in_group) < len(group):
missing_keys = [k for k in missing_keys if k not in missing_in_group]
10. 根据模式忽略某些键:
- 使用模型定义的
_keys_to_ignore_on_load_missing
和_keys_to_ignore_on_load_unexpected
,从缺失和意外的键中移除匹配的键。
python
# Some models may have keys that are not in the state by design, removing them before needlessly warning
# the user.
if cls._keys_to_ignore_on_load_missing is not None:
for pat in cls._keys_to_ignore_on_load_missing:
missing_keys = [k for k in missing_keys if re.search(pat, k) is None]
if cls._keys_to_ignore_on_load_unexpected is not None:
for pat in cls._keys_to_ignore_on_load_unexpected:
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
if hf_quantizer is not None:
missing_keys = hf_quantizer.update_missing_keys(model, missing_keys, prefix)
unexpected_keys = hf_quantizer.update_unexpected_keys(model, unexpected_keys, prefix)
11. 处理低内存使用模式下的参数:
- 如果启用了
low_cpu_mem_usage
,则对于缺失的参数,在 CPU 上建立空的张量占位符,稍后再加载实际的权重。
python
# retrieve weights on meta device and put them back on CPU.
# This is not ideal in terms of memory, but if we don't do that not, we can't initialize them in the next step
if low_cpu_mem_usage:
for key in missing_keys:
if key in model_state_dict:
key = key
elif f"{prefix}.{key}" in model_state_dict:
key = f"{prefix}.{key}"
elif key.startswith(prefix) and ".".join(key.split(".")[1:]) in model_state_dict:
key = ".".join(key.split(".")[1:])
param = model_state_dict[key]
# upcast in fp32 if any
target_dtype = dtype
if (
keep_in_fp32_modules is not None
and dtype == torch.float16
and any(
module_to_keep_in_fp32 in key.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules
)
):
target_dtype = torch.float32
if param.device == torch.device("meta"):
value = torch.empty(*param.size(), dtype=target_dtype)
if (
not is_quantized
or (getattr(hf_quantizer, "requires_parameters_quantization", False))
or not hf_quantizer.check_quantized_param(
model, param_value=value, param_name=key, state_dict={}
)
):
set_module_tensor_to_device(model, key, "cpu", value)
else:
hf_quantizer.create_quantized_param(model, value, key, "cpu", state_dict, unexpected_keys)
12. 初始化未初始化的子模块:
- 如果启用了快速初始化(
_fast_init
),则初始化模型中尚未初始化的子模块。 - 对于 DeepSpeed 或 FSDP 场景,有特殊的初始化处理。
python
# retrieve uninitialized modules and initialize before maybe overriding that with the pretrained weights.
if _fast_init:
if not ignore_mismatched_sizes:
if remove_prefix_from_model:
_loaded_keys = [f"{prefix}.{k}" for k in loaded_keys]
elif add_prefix_to_model:
_loaded_keys = [k[len(prefix) + 1 :] for k in loaded_keys]
else:
_loaded_keys = loaded_keys
not_initialized_submodules = set_initialized_submodules(model, _loaded_keys)
# If we're about to tie the output embeds to the input embeds we don't need to init them
if (
hasattr(model.config.get_text_config(decoder=True), "tie_word_embeddings")
and model.config.get_text_config(decoder=True).tie_word_embeddings
):
output_embeddings = model.get_output_embeddings()
if output_embeddings is not None:
# Still need to initialize if there is a bias term since biases are not tied.
if not hasattr(output_embeddings, "bias") or output_embeddings.bias is None:
output_embeddings._is_hf_initialized = True
else:
not_initialized_submodules = dict(model.named_modules())
# This will only initialize submodules that are not marked as initialized by the line above.
if is_deepspeed_zero3_enabled() and not is_quantized:
import deepspeed
not_initialized_parameters = list(
set(
itertools.chain.from_iterable(
submodule.parameters(recurse=False) for submodule in not_initialized_submodules.values()
)
)
)
with deepspeed.zero.GatheredParameters(not_initialized_parameters, modifier_rank=0):
model.apply(model._initialize_weights)
else:
model.apply(model._initialize_weights)
这段代码的作用是在加载预训练模型权重之前,初始化模型中未初始化的子模块。具体解释如下:
首先,代码包含一个注释:
python
# retrieve uninitialized modules and initialize before maybe overriding that with the pretrained weights.
意思是:在可能用预训练权重覆盖之前,先检索未初始化的模块并进行初始化。
然后,代码开始执行:
python
if _fast_init:
如果启用了快速初始化(_fast_init
为真),则进入初始化流程。
接下来,代码检查是否需要忽略尺寸不匹配的参数:
python
if not ignore_mismatched_sizes:
如果不忽略尺寸不匹配的情况(即需要严格匹配参数尺寸),则需要处理加载的权重键名与模型参数键名之间的关系:
处理键名前缀
-
如果需要从模型参数中移除前缀(
remove_prefix_from_model
为真):python_loaded_keys = [f"{prefix}.{k}" for k in loaded_keys]
这会在加载的键名上添加指定的前缀,以匹配模型参数的键名。
-
如果需要给模型参数添加前缀(
add_prefix_to_model
为真):python_loaded_keys = [k[len(prefix) + 1 :] for k in loaded_keys]
这会从加载的键名中移除指定的前缀,以匹配模型参数的键名。
-
如果不需要处理前缀:
python_loaded_keys = loaded_keys
直接使用加载的键名。
标记已初始化的子模块
python
not_initialized_submodules = set_initialized_submodules(model, _loaded_keys)
这个函数会根据加载的键名,标记模型中哪些子模块已经被初始化,并返回未初始化的子模块集合 not_initialized_submodules
。
处理词嵌入参数的初始化
检查模型配置中是否设置了词嵌入共享(tie_word_embeddings
):
python
if (
hasattr(model.config.get_text_config(decoder=True), "tie_word_embeddings")
and model.config.get_text_config(decoder=True).tie_word_embeddings
):
如果模型的配置中启用了输入和输出词嵌入共享,那么输出的词嵌入权重将与输入的词嵌入权重绑定,不需要单独初始化。
但是,如果输出词嵌入存在偏置项(bias),由于偏置项不共享,仍需要初始化:
python
output_embeddings = model.get_output_embeddings()
if output_embeddings is not None:
if not hasattr(output_embeddings, "bias") or output_embeddings.bias is None:
output_embeddings._is_hf_initialized = True
如果选择忽略尺寸不匹配的参数(ignore_mismatched_sizes
为真),则默认所有子模块都未初始化:
python
else:
not_initialized_submodules = dict(model.named_modules())
接下来,代码根据是否启用了 DeepSpeed ZeRO Stage 3 和模型是否量化,选择不同的初始化方式:
使用 DeepSpeed 初始化参数
如果启用了 DeepSpeed ZeRO Stage 3 并且模型未量化:
python
if is_deepspeed_zero3_enabled() and not is_quantized:
那么使用 DeepSpeed 提供的参数收集器来初始化参数:
python
import deepspeed
not_initialized_parameters = list(
set(
itertools.chain.from_iterable(
submodule.parameters(recurse=False) for submodule in not_initialized_submodules.values()
)
)
)
with deepspeed.zero.GatheredParameters(not_initialized_parameters, modifier_rank=0):
model.apply(model._initialize_weights)
这里,对未初始化的参数进行收集,在 DeepSpeed 的上下文管理器中统一初始化。这是因为在 ZeRO Stage 3 下,参数被分布式存储,需要先收集到主设备才能初始化。
普通初始化方式
如果未启用 DeepSpeed 或模型已量化,直接对模型应用初始化函数:
python
else:
model.apply(model._initialize_weights)
model.apply()
是 PyTorch 中 nn.Module
类的一个方法。
详细解释:
在 PyTorch 中,所有的神经网络模型都继承自 nn.Module
。nn.Module
提供了多个有用的方法,其中之一就是 apply()
方法。apply()
方法的作用是递归地将一个函数应用到模型及其所有子模块(submodules)上。
具体来说,apply()
方法的定义如下:
python
def apply(self, fn):
fn(self)
for module in self.children():
module.apply(fn)
return self
这意味着,当你调用 model.apply(fn)
时,函数 fn
会首先作用于 model
本身,然后递归地作用于 model
的所有子模块。这对于需要对整个模型的参数或模块进行统一的操作时非常方便。
在您的代码中:
python
model.apply(model._initialize_weights)
model
是一个继承自nn.Module
的模型实例。model.apply()
是调用nn.Module
中的apply()
方法。model._initialize_weights
是一个定义在model
内部的函数,用于初始化权重。
所以,这行代码的作用是:对模型 model
及其所有子模块,递归地调用 model._initialize_weights
函数。这通常用于在模型的所有模块上执行权重初始化。
注意:
-
model._initialize_weights
必须是一个接受nn.Module
实例作为参数的函数,因为apply()
方法会将当前模块作为参数传递给函数fn
。通常,_initialize_weights
会定义为:pythondef _initialize_weights(self, module): # 对 module 执行初始化操作
-
如果
model._initialize_weights
是一个绑定方法(即定义在类内部,且需要访问实例属性),直接传递model._initialize_weights
就可以,不需要额外的参数。
13. 设置需要保持为 FP32 的模块:
- 如果指定了
keep_in_fp32_modules
,则将这些模块的参数强制转换为torch.float32
。
python
# Set some modules to fp32 if any
if keep_in_fp32_modules == []:
keep_in_fp32_modules = None
if keep_in_fp32_modules is not None:
keep_in_fp32_modules = re.compile("|".join(keep_in_fp32_modules))
for name, param in model.named_parameters():
if keep_in_fp32_modules.search(name):
# param = param.to(torch.float32) does not work here as only in the local scope.
param.data = param.data.to(torch.float32) # TODO @Cyrilvallez: we seem to do this twice
这段代码的主要目的是将指定的模型参数(模块)设置为 float32
(单精度浮点数)格式,即使整个模型可能默认使用较低的精度(例如 float16
)进行存储或计算。这样做的原因通常是为了在使用混合精度训练或推理时,确保一些关键的模块保持较高的数值精度,防止数值不稳定或精度损失。
逐行解释如下:
python
# Set some modules to fp32 if any
if keep_in_fp32_modules == []:
keep_in_fp32_modules = None
-
检查
keep_in_fp32_modules
是否为空列表:- 如果
keep_in_fp32_modules
是空列表,即[]
,说明没有指定任何模块需要保持float32
精度。那么,将其设置为None
,表示后续不需要进行特殊处理。
- 如果
python
if keep_in_fp32_modules is not None:
keep_in_fp32_modules = re.compile("|".join(keep_in_fp32_modules))
-
如果指定了需要保持
float32
精度的模块:-
检查
keep_in_fp32_modules
是否不为None
。如果不是None
,说明用户提供了需要保持高精度的模块列表。 -
使用正则表达式将模块名称列表编译为一个匹配模式:
pythonkeep_in_fp32_modules = re.compile("|".join(keep_in_fp32_modules))
-
keep_in_fp32_modules
是一个字符串列表,包含了需要保持float32
精度的模块名称或名称模式。 -
"|".join(keep_in_fp32_modules)
将列表中的字符串用管道符"|"
(表示"或")连接起来,形成一个正则表达式模式字符串。 -
re.compile()
编译这个正则表达式模式,以便后续匹配使用。
-
-
python
for name, param in model.named_parameters():
if keep_in_fp32_modules.search(name):
# param = param.to(torch.float32) does not work here as only in the local scope.
param.data = param.data.to(torch.float32) # TODO @Cyrilvallez: we seem to do this twice
-
遍历模型的所有参数,筛选需要转换精度的参数:
-
model.named_parameters()
返回模型中所有参数的迭代器,提供参数的名称name
和参数本身param
。 -
对于每一个参数,使用正则表达式
keep_in_fp32_modules
进行搜索:pythonif keep_in_fp32_modules.search(name):
- 如果参数名称
name
匹配正则表达式,说明这个参数属于需要保持float32
精度的模块。
- 如果参数名称
-
-
将匹配的参数转换为
float32
精度:-
注意 :有一条注释说明了为什么不用
param = param.to(torch.float32)
:python# param = param.to(torch.float32) does not work here as only in the local scope.
- 直接使用
param = param.to(torch.float32)
只是在当前函数的局部作用域中重新绑定了param
,并不会修改模型中对应参数的实际数据。
- 直接使用
-
正确的做法 :直接修改参数的数据
data
,使其生效:pythonparam.data = param.data.to(torch.float32)
param.data
直接访问参数的核心数据张量,使用to(torch.float32)
方法将其转换为float32
类型。- 这样修改后,模型中对应的参数数据会被更新为
float32
,在后续的计算中会保持使用float32
精度。
-
注释中的 TODO:
python# TODO @Cyrilvallez: we seem to do this twice
- 这条注释表示开发者注意到可能在代码的其他地方也进行了类似的操作,即可能对参数的精度转换重复执行了两次。
- 这需要在后续的代码优化中进行检查,以避免不必要的重复操作,提高代码效率。
-
总结:
-
目的 :在模型中指定的模块(参数)保持
float32
精度,确保这些模块的计算和存储不会因为低精度(如float16
)而导致数值精度不足。 -
处理流程:
-
检查并准备模块列表:如果用户提供了要保持高精度的模块列表,将其编译为正则表达式模式,方便匹配参数名称。
-
遍历模型参数 :对模型的所有参数进行遍历,根据参数名称匹配需要保持
float32
精度的参数。 -
转换参数精度 :直接修改参数的数据张量,将其转换为
float32
类型。这里需要注意直接修改param.data
,而不是重新赋值param
,以确保修改生效。
-
-
注意事项:
-
参数的重新赋值问题 :在 Python 中,直接对
param
重新赋值只是在局部作用域中生效,不会影响模型内部的参数。因此,需要直接修改param.data
。 -
代码优化:存在可能重复执行精度转换的情况,需要在后续优化中加以解决,避免不必要的计算。
-
扩展说明:
-
混合精度训练/推理的背景:
-
在深度学习中,使用较低的浮点数精度(如
float16
)可以减少显存占用,加速计算。 -
但是,某些模块对数值精度较为敏感,使用低精度可能导致梯度消失、梯度爆炸或收敛速度变慢。
-
因此,通常在混合精度训练中,会指定某些模块保持较高的精度,例如:
-
LayerNorm 层:归一化层对精度较为敏感。
-
Embed 层:嵌入层的权重对于表示精度要求较高。
-
-
-
参数数据修改的安全性:
-
直接修改
param.data
有风险,因为它会绕过 Autograd 的追踪机制,可能导致梯度计算错误。 -
但是在这个上下文中,目的是在模型初始化阶段或参数加载阶段修改参数的数据类型,通常不会影响 Autograd 的正常工作。
-
-
正则表达式匹配:
-
使用正则表达式可以灵活地指定需要保持高精度的参数名称模式,方便用户自定义。
-
例如,用户可以提供模块名称列表,如
["layer_norm", "embedding"]
,匹配所有名称中包含layer_norm
或embedding
的参数。
-
示例:
假设 keep_in_fp32_modules = ["layer_norm", "embed"]
,则:
python
keep_in_fp32_modules = re.compile("layer_norm|embed")
- 在遍历参数时,如果参数名称包含
layer_norm
或embed
,则匹配成功,参数会被转换为float32
。
14. 处理嵌套模型的加载:
- 根据模型是否包含
base_model_prefix
,调整model_to_load
,以确保加载权重到正确的子模型中。
python
# Make sure we are able to load base models as well as derived models (with heads)
start_prefix = ""
model_to_load = model
if len(cls.base_model_prefix) > 0 and not hasattr(model, cls.base_model_prefix) and has_prefix_module:
start_prefix = cls.base_model_prefix + "."
if len(cls.base_model_prefix) > 0 and hasattr(model, cls.base_model_prefix) and not has_prefix_module:
model_to_load = getattr(model, cls.base_model_prefix)
base_model_expected_keys = list(model_to_load.state_dict().keys())
if any(key in expected_keys_not_prefixed and key not in base_model_expected_keys for key in loaded_keys):
raise ValueError(
"The state dictionary of the model you are trying to load is corrupted. Are you sure it was "
"properly saved?"
)
if device_map is not None:
device_map = {k.replace(f"{cls.base_model_prefix}.", ""): v for k, v in device_map.items()}
if resolved_archive_file is not None:
folder = os.path.sep.join(resolved_archive_file[0].split(os.path.sep)[:-1])
else:
folder = None
model.expected_keys = expected_keys
15. 展开设备映射并预热缓存分配器:
- 如果提供了
device_map
,则展开设备映射,并在加载前预热缓存分配器以优化内存使用。
python
if device_map is not None:
expanded_device_map = expand_device_map(device_map, original_loaded_keys, start_prefix)
if hf_quantizer is None:
caching_allocator_warmup(model_to_load, expanded_device_map, dtype)
这段代码的主要作用是:
- 设置模型的预期参数键,以便在后续加载或调试过程中使用。
- 处理设备映射(
device_map
),将模型的各个部分映射到指定的设备上。 - 预热缓存分配器(
caching_allocator_warmup
),在模型未被量化的情况下,通过预热操作优化内存分配,提高运行时的性能。
下面对代码进行逐行详细解释:
python
model.expected_keys = expected_keys
解释:
- 作用 :将模型对象的
expected_keys
属性设置为当前计算得到的expected_keys
。 expected_keys
:这是一个列表,包含模型预期加载的参数键名(即参数名称)。- 目的:在模型中保存预期的参数键,可以在后续加载参数、检查模型完整性或进行调试时使用。
python
if device_map is not None:
解释:
- 作用 :检查是否提供了
device_map
。 device_map
:这是一个字典,指定模型的各个部分(如层、模块)应该被放置在哪个设备上(例如,CPU、GPU)。- 目的 :如果用户提供了
device_map
,则需要根据它来设置模型各部分的设备映射,以支持多设备训练或推理。
python
expanded_device_map = expand_device_map(device_map, original_loaded_keys, start_prefix)
解释:
- 作用 :调用
expand_device_map
函数,生成一个扩展后的设备映射expanded_device_map
。 - 参数说明 :
device_map
:用户提供的设备映射字典,可能只指定了一部分模块的设备。original_loaded_keys
:模型实际加载的参数键名列表。start_prefix
:参数名称可能需要添加的前缀,这在之前的代码中可能被设置过。
expand_device_map
函数的作用 :- 扩展设备映射 :将用户提供的
device_map
扩展为一个完整的设备映射,包含模型的所有参数键。 - 这样可以确保未在
device_map
中指定的模块或参数,也被映射到默认设备(通常是 CPU)上。
- 扩展设备映射 :将用户提供的
- 目的:确保模型的所有部分都被正确地映射到设备上,避免因设备映射不完整导致的错误。
python
if hf_quantizer is None:
解释:
- 作用 :检查是否启用了量化器(
hf_quantizer
)。 hf_quantizer
:通常是一个量化器对象,如果它为None
,则说明没有启用模型量化(模型未被量化)。- 目的:只有在模型未被量化的情况下,才需要进行后续的缓存分配器预热操作。
python
caching_allocator_warmup(model_to_load, expanded_device_map, dtype)
解释:
- 作用 :调用
caching_allocator_warmup
函数,对模型进行缓存分配器预热。 - 参数说明 :
model_to_load
:待加载参数的模型对象,可能是在之前的代码中确定的(例如,基础模型部分)。expanded_device_map
:扩展后的设备映射,指示模型各部分应加载到的设备。dtype
:数据类型,通常是模型使用的默认数据类型(如torch.float32
、torch.float16
等)。
caching_allocator_warmup
函数的作用 :- 预热内存分配器:通过在各个设备上执行一次前向传播或参数访问,预先分配所需的内存。
- 优化性能:预热可以减少内存碎片,避免在训练或推理过程中出现内存不足或性能下降的情况。
- 注意 :如果模型被量化(
hf_quantizer
不为None
),可能不需要或无法进行相同的预热操作。
额外说明:
-
关于设备映射(
device_map
)和expand_device_map
:- 设备映射:用于指定模型的哪些部分应该放在哪些设备上,支持多 GPU 或混合设备的训练和推理。
- 扩展设备映射 :用户提供的
device_map
可能只包含部分模块的设备信息,因此需要扩展为完整的设备映射,包含所有参数键。 original_loaded_keys
和start_prefix
:这些参数用于正确匹配模型参数的名称,确保设备映射与模型参数一致。
-
关于缓存分配器预热(
caching_allocator_warmup
):- 内存管理的重要性:在深度学习中,尤其是大型模型,GPU 内存的高效使用至关重要。内存碎片可能导致实际可用内存减少。
- 预热的作用:通过在各设备上访问模型参数,触发 CUDA 内存分配器提前分配内存,减少运行时的内存分配和释放操作。
- 注意事项:预热操作需要访问模型的所有参数,因此在模型已经被量化或参数被压缩的情况下,预热可能不适用。
-
关于模型量化和
hf_quantizer
:- 模型量化:将模型的权重和激活从高精度(如 32 位浮点)减少到低精度(如 8 位整数),以减小模型大小和加速推理。
hf_quantizer
:如果不为None
,表示模型已被量化,可能需要特殊的操作或避免某些操作(如预热)。
16. 处理基于 safetensors
的情况:
如果使用 safetensors
格式,并且提供了 device_map
,则准备 offload_index
,用于在加载时将部分权重直接从磁盘读取。
python
if device_map is not None and is_safetensors:
param_device_map = expanded_device_map
str_dtype = str(dtype).replace("torch.", "") if dtype is not None else "float32"
if sharded_metadata is None:
archive_file = (
resolved_archive_file[0]
if isinstance(resolved_archive_file, (list, tuple))
else resolved_archive_file
)
weight_map = {p: archive_file for p in original_loaded_keys}
else:
weight_map = {p: os.path.join(folder, f) for p, f in sharded_metadata["weight_map"].items()}
offload_index = {
p[len(start_prefix) :]: {"safetensors_file": f, "weight_name": p, "dtype": str_dtype}
for p, f in weight_map.items()
if p.startswith(start_prefix) and param_device_map[p[len(start_prefix) :]] == "disk"
}
else:
offload_index = None
17. 加载权重到模型中:
- 如果提供了
state_dict
,则直接加载整个权重字典。 - 否则,根据权重文件(可能是分片的)的数量,逐个加载权重片段:
- 对于每个权重片段,加载状态字典,并处理可能的大小不匹配。
- 在低内存模式下,使用
_load_state_dict_into_meta_model
方法,将权重加载到元模型中,并处理卸载和索引。 - 如果不是低内存模式,则直接将状态字典加载到模型中。
python
error_msgs = []
if state_dict is not None:
#略
else:
# This should always be a list but, just to be sure.
if not isinstance(resolved_archive_file, list):
resolved_archive_file = [resolved_archive_file]
error_msgs = []
mismatched_keys = []
if not is_safetensors:
offload_index = {} if device_map is not None and "disk" in device_map.values() else None
if offload_state_dict:
state_dict_folder = tempfile.mkdtemp()
state_dict_index = {}
else:
state_dict_folder = None
state_dict_index = None
if is_sharded_safetensors:
disk_only_shard_files = get_disk_only_shard_files(
device_map, sharded_metadata=sharded_metadata, start_prefix=start_prefix
)
disk_only_shard_files = [os.path.join(folder, f) for f in disk_only_shard_files]
else:
disk_only_shard_files = []
if len(resolved_archive_file) > 1:
resolved_archive_file = logging.tqdm(resolved_archive_file, desc="Loading checkpoint shards")
assign_to_params_buffers = None
for shard_file in resolved_archive_file:
# Skip the load for shards that only contain disk-offloaded weights when using safetensors for the offload.
if shard_file in disk_only_shard_files:
continue
map_location = None
if (
device_map is not None
and hf_quantizer is not None
and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO
and hf_quantizer.quantization_config.quant_type in ["int4_weight_only", "autoquant"]
):
map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0])
state_dict = load_state_dict(
shard_file, is_quantized=is_quantized, map_location=map_location, weights_only=weights_only
)
# Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
# matching the weights in the model.
mismatched_keys += _find_mismatched_keys(
state_dict,
model_state_dict,
loaded_keys,
original_loaded_keys,
add_prefix_to_model,
remove_prefix_from_model,
ignore_mismatched_sizes,
prefix,
)
if low_cpu_mem_usage:
if is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized:
for key, param in model_to_load.state_dict().items():
if param.device == torch.device("meta"):
set_module_tensor_to_device(
model_to_load, key, "cpu", torch.empty(*param.size(), dtype=dtype)
)
else:
new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
model_to_load,
state_dict,
start_prefix,
expected_keys,
device_map=device_map,
offload_folder=offload_folder,
offload_index=offload_index,
state_dict_folder=state_dict_folder,
state_dict_index=state_dict_index,
dtype=dtype,
hf_quantizer=hf_quantizer,
is_safetensors=is_safetensors,
keep_in_fp32_modules=keep_in_fp32_modules,
unexpected_keys=unexpected_keys,
device_mesh=device_mesh,
shard_file=shard_file,
)
error_msgs += new_error_msgs
else:
state_dict = load_state_dict(shard_file, map_location="cpu", weights_only=weights_only)
# Sharded checkpoint or whole but low_cpu_mem_usage==True
if assign_to_params_buffers is None:
assign_to_params_buffers = check_support_param_buffer_assignment(
model_to_load, state_dict, start_prefix
)
fixed_state_dict = model_to_load._fix_state_dict_keys_on_load(state_dict)
model_to_load.load_state_dict(fixed_state_dict, strict=False, assign=assign_to_params_buffers)
# force memory release
del state_dict
gc.collect()
if offload_index is not None and len(offload_index) > 0:
if model != model_to_load:
# We need to add the prefix of the base model
prefix = cls.base_model_prefix
if not is_safetensors:
for weight_name in offload_index:
shutil.move(
os.path.join(offload_folder, f"{weight_name}.dat"),
os.path.join(offload_folder, f"{prefix}.{weight_name}.dat"),
)
offload_index = {f"{prefix}.{key}": value for key, value in offload_index.items()}
if not is_safetensors:
save_offload_index(offload_index, offload_folder)
offload_index = None
if offload_state_dict:
# Load back temporarily offloaded state dict
load_offloaded_weights(model_to_load, state_dict_index, state_dict_folder)
shutil.rmtree(state_dict_folder)
这段代码的主要作用是在加载模型的权重时,处理各种复杂的情况,包括分片的检查点文件、设备映射、低 CPU 内存使用、模型量化等。以下是对代码的详细分析:
1. 遍历检查点分片文件
python
for shard_file in resolved_archive_file:
- 作用 :遍历所有的检查点分片文件
resolved_archive_file
,对每个分片文件进行处理。
2. 跳过仅包含磁盘离线权重的分片文件
python
# 当使用 safetensors 进行离线加载时,跳过仅包含磁盘离线权重的分片
if shard_file in disk_only_shard_files:
continue
- 作用 :如果当前的
shard_file
在disk_only_shard_files
列表中,表示该分片文件仅包含需要从磁盘离线加载的权重,则直接跳过,不需要在此处加载。 - 原因:对于仅包含离线加载的参数,不需要在内存中加载,节省内存。
3. 初始化 map_location
python
map_location = None
- 作用 :初始化变量
map_location
,用于指定加载权重时的设备位置。
4. 根据条件设置权重加载的位置
python
if (
device_map is not None
and hf_quantizer is not None
and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO
and hf_quantizer.quantization_config.quant_type in ["int4_weight_only", "autoquant"]
):
map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0])
- 条件判断 :
device_map is not None
:提供了设备映射。hf_quantizer is not None
:使用了 Hugging Face 的量化器。hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO
:量化方法为TORCHAO
。hf_quantizer.quantization_config.quant_type in ["int4_weight_only", "autoquant"]
:量化类型为"int4_weight_only"
或"autoquant"
。
- 作用 :
- 如果以上条件都满足,说明正在使用 TORCH AO 的量化方法,且量化类型为指定的两种之一,需要在特定设备上加载权重。
- 从
device_map
中获取第一个不为"cpu"
或"disk"
的设备作为map_location
。
- 目的:确保在正确的设备上加载权重,适应特定的量化和设备配置。
5. 加载状态字典 state_dict
python
state_dict = load_state_dict(
shard_file, is_quantized=is_quantized, map_location=map_location, weights_only=weights_only
)
- 作用 :使用
load_state_dict
函数加载当前分片文件的状态字典state_dict
。 - 参数 :
shard_file
:当前处理的检查点分片文件路径。is_quantized
:模型是否已经量化,影响加载方式。map_location
:指定加载到的设备位置。weights_only
:是否只加载权重参数。
6. 查找形状不匹配的参数键
python
# mismatched_keys 包含在检查点中形状与模型不匹配的权重的键和形状
mismatched_keys += _find_mismatched_keys(
state_dict,
model_state_dict,
loaded_keys,
original_loaded_keys,
add_prefix_to_model,
remove_prefix_from_model,
ignore_mismatched_sizes,
prefix,
)
- 作用 :调用
_find_mismatched_keys
函数,找出加载的state_dict
中与模型参数形状不匹配的键,添加到mismatched_keys
列表中。 - 目的:在加载过程中记录形状不匹配的参数,方便后续处理。
7. 根据内存使用情况加载权重
python
if low_cpu_mem_usage:
# ...处理低内存使用的情况...
else:
# ...正常加载情况...
- 作用 :根据
low_cpu_mem_usage
的值,决定采用哪种方式加载权重。 low_cpu_mem_usage
为True
:表示希望尽量降低 CPU 内存占用,需要特殊处理。low_cpu_mem_usage
为False
:直接将权重加载到模型中。
7.1 处理低内存使用的情况
python
if is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized:
for key, param in model_to_load.state_dict().items():
if param.device == torch.device("meta"):
set_module_tensor_to_device(
model_to_load, key, "cpu", torch.empty(*param.size(), dtype=dtype)
)
else:
new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
# ...参数列表...
)
error_msgs += new_error_msgs
- 条件判断 :
is_fsdp_enabled()
:检查是否启用了 FSDP(Fully Sharded Data Parallel,全分片数据并行)。not is_local_dist_rank_0()
:当前进程不是本地的第 0 号进程。not is_quantized
:模型未量化。
- 处理方式 :
- FSDP 模式下的特殊处理 :将位于
meta
设备上的参数初始化为指定形状和数据类型的空张量,并移动到 CPU。 - 否则 :调用
_load_state_dict_into_meta_model
函数,增量地将状态字典加载到模型中,支持离线加载、设备映射等,尽量减少内存占用。
- FSDP 模式下的特殊处理 :将位于
7.2 正常加载情况
python
else:
state_dict = load_state_dict(shard_file, map_location="cpu", weights_only=weights_only)
if assign_to_params_buffers is None:
assign_to_params_buffers = check_support_param_buffer_assignment(
model_to_load, state_dict, start_prefix
)
fixed_state_dict = model_to_load._fix_state_dict_keys_on_load(state_dict)
model_to_load.load_state_dict(fixed_state_dict, strict=False, assign=assign_to_params_buffers)
- 作用 :
- 直接将状态字典加载到 CPU 内存中。
- 检查是否支持参数和缓冲区的直接赋值,以提高加载效率。
- 修正状态字典的键,使其与模型的参数名称匹配。
- 使用
load_state_dict
函数加载权重,strict=False
允许存在未匹配的参数。
8. 释放内存
python
# 强制释放内存
del state_dict
gc.collect()
- 作用 :删除已加载的
state_dict
,调用垃圾回收器释放内存,防止内存泄漏和过高的内存占用。
18. 处理量化器的更新:
- 如果使用了量化器,更新模型的缺失键和意外键,以匹配量化后的模型结构。
python
if hf_quantizer is not None:
missing_keys = hf_quantizer.update_missing_keys_after_loading(model_to_load, missing_keys, prefix)
19. 处理加载过程中的错误和警告:
- 如果
error_msgs
不为空,表示在加载过程中出现错误,抛出异常并终止。 - 根据缺失键、意外键和不匹配的键,记录相应的日志信息或警告。
- 提示用户是否需要重新训练模型以适应新的任务。
python
if len(error_msgs) > 0:
error_msg = "\n\t".join(error_msgs)
if "size mismatch" in error_msg:
error_msg += (
"\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
)
raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
if len(unexpected_keys) > 0:
archs = [] if model.config.architectures is None else model.config.architectures
warner = logger.warning if model.__class__.__name__ in archs else logger.info
warner(
f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
" with another architecture (e.g. initializing a BertForSequenceClassification model from a"
" BertForPreTraining model).\n- This IS NOT expected if you are initializing"
f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly identical"
" (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
)
else:
logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
if len(missing_keys) > 0:
logger.warning(
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
" TRAIN this model on a down-stream task to be able to use it for predictions and inference."
)
elif len(mismatched_keys) == 0:
logger.info(
f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint"
f" was trained on, you can already use {model.__class__.__name__} for predictions without further"
" training."
)
if len(mismatched_keys) > 0:
mismatched_warning = "\n".join(
[
f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
for key, shape1, shape2 in mismatched_keys
]
)
logger.warning(
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able"
" to use it for predictions and inference."
)
20. 返回加载结果:
- 最终,方法返回加载后的模型,以及缺失的键、意外的键、不匹配的键、卸载索引和错误消息,供调用者进行进一步处理。
python
return model, missing_keys, unexpected_keys, mismatched_keys, offload_index, error_msgs
2 load_state_dict、_load_state_dict_into_meta_model 和 model_to_load.load_state_dict
在代码中,多次出现了 load_state_dict
、_load_state_dict_into_meta_model
和 model_to_load.load_state_dict
。这些函数/方法名称相似,但它们在代码中的作用和功能是不同的。
1. load_state_dict
(函数)
作用:
load_state_dict
是一个函数,用于从检查点文件(如shard_file
)中加载模型的状态字典(state_dict
)。- 这个函数负责从磁盘上的文件读取模型的权重和参数,返回一个包含所有参数的字典。
多次调用的原因:
- 在代码中,
resolved_archive_file
可能是一个包含多个检查点分片文件的列表,需要遍历每个分片文件并加载其中的参数。 - 因此,每次遍历到一个新的分片文件时,都需要调用
load_state_dict
函数,从该文件中加载对应的state_dict
。
示例代码:
python
for shard_file in resolved_archive_file:
# 其他代码...
state_dict = load_state_dict(
shard_file, is_quantized=is_quantized, map_location=map_location, weights_only=weights_only
)
# 对 state_dict 的后续处理...
2. _load_state_dict_into_meta_model
(函数)
作用:
_load_state_dict_into_meta_model
是一个函数,用于在内存受限(low_cpu_mem_usage=True
)的情况下,将状态字典(state_dict
)逐步加载到模型中。- 当模型的参数被放置在
'meta'
设备上时,模型的参数实际上并未占用实际的内存(是占位符)。这个函数负责将参数从state_dict
中逐个加载到模型的对应参数中,同时考虑到设备映射(device_map
)、参数的离线加载等复杂情况。
调用的条件:
- 当
low_cpu_mem_usage
为True
时,即希望降低 CPU 内存的使用量,在加载模型时会调用_load_state_dict_into_meta_model
。
示例代码:
python
if low_cpu_mem_usage:
# 其他条件检查...
new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
model_to_load,
state_dict,
start_prefix,
expected_keys,
device_map=device_map,
offload_folder=offload_folder,
offload_index=offload_index,
state_dict_folder=state_dict_folder,
state_dict_index=state_dict_index,
dtype=dtype,
hf_quantizer=hf_quantizer,
is_safetensors=is_safetensors,
keep_in_fp32_modules=keep_in_fp32_modules,
unexpected_keys=unexpected_keys,
device_mesh=device_mesh,
shard_file=shard_file,
)
error_msgs += new_error_msgs
功能详解:
- 逐步加载参数 :该函数会遍历
state_dict
中的每个参数,根据device_map
将参数加载到指定的设备上。 - 节省内存:通过一次只加载一小部分参数,避免一次性将所有参数加载到内存,控制内存峰值。
3. model_to_load.load_state_dict
(方法)
作用:
load_state_dict
是 PyTorch 中nn.Module
类的一个方法,用于将状态字典加载到模型中。- 它将状态字典中的参数复制到模型的对应参数中。
调用的条件:
- 当
low_cpu_mem_usage
为False
时,即内存充足的情况下,可以直接使用model_to_load.load_state_dict
方法一次性将所有参数加载到模型中。
示例代码:
python
else:
# 正常加载模式
state_dict = load_state_dict(shard_file, map_location="cpu", weights_only=weights_only)
# 可能需要修正参数的键名
fixed_state_dict = model_to_load._fix_state_dict_keys_on_load(state_dict)
model_to_load.load_state_dict(fixed_state_dict, strict=False, assign=assign_to_params_buffers)
功能详解:
- 直接加载 :一次性将整个
state_dict
加载到模型中,速度较快,但需要足够的内存。 - 键名修正 :在加载前,可能需要对
state_dict
的键名进行修正,使其与模型的参数名称匹配。
3 load_state_dict
这段代码定义了一个名为 load_state_dict
的函数,用于从模型检查点(checkpoint)文件中加载状态字典(state_dict)。它支持加载两种格式的文件:
.safetensors
格式:一种安全、高效的张量序列化格式。.bin
(PyTorch 原生)格式:PyTorch 保存的模型权重文件。
此外,该函数还支持将张量加载到指定设备(如 "cpu"
、"meta"
等)上,以及处理一些特殊情况,如模型量化、分布式训练等。
函数定义和参数说明
python
def load_state_dict(
checkpoint_file: Union[str, os.PathLike],
is_quantized: bool = False,
map_location: Optional[Union[str, torch.device]] = "meta",
weights_only: bool = True,
):
"""
Reads a `safetensor` or a `.bin` checkpoint file into `meta` if requested.
"""
-
函数名称 :
load_state_dict
-
参数:
checkpoint_file
:检查点文件的路径,可以是字符串或路径对象,支持.safetensors
或.bin
格式。is_quantized
:布尔值,指示模型是否已经量化。map_location
:指定加载张量的设备位置,可以是"meta"
、"cpu"
、"cuda"
等,默认为"meta"
。weights_only
:布尔值,指示是否只加载权重,如果为True
,则只加载模型的权重参数。
-
函数作用:从检查点文件中读取模型的状态字典,并将张量加载到指定的设备位置。
处理 .safetensors
格式的文件
python
if checkpoint_file.endswith(".safetensors") and is_safetensors_available():
with safe_open(checkpoint_file, framework="pt") as f:
metadata = f.metadata()
if metadata is not None and metadata.get("format") not in ["pt", "tf", "flax", "mlx"]:
raise OSError(
f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure "
"you save your model with the `save_pretrained` method."
)
state_dict = {}
for k in f.keys():
dtype = str_to_torch_dtype[f.get_slice(k).get_dtype()]
if map_location == "meta":
state_dict[k] = torch.empty(size=f.get_slice(k).get_shape(), dtype=dtype, device="meta")
else:
state_dict[k] = f.get_tensor(k)
return state_dict
解释:
-
检查文件格式和库是否可用:
checkpoint_file.endswith(".safetensors")
:检查文件是否以.safetensors
结尾。is_safetensors_available()
:检查safetensors
库是否已安装和可用。
-
打开
.safetensors
文件:pythonwith safe_open(checkpoint_file, framework="pt") as f:
- 使用
safe_open
函数打开文件,指定框架为 PyTorch (framework="pt"
)。
- 使用
-
检查元数据(metadata):
pythonmetadata = f.metadata() if metadata is not None and metadata.get("format") not in ["pt", "tf", "flax", "mlx"]: raise OSError(...)
- 获取文件的元数据,如果存在,则检查其中的
"format"
字段是否在允许的列表中("pt"
、"tf"
、"flax"
、"mlx"
)。 - 如果格式不正确,抛出
OSError
,提示用户确保模型是使用save_pretrained
方法保存的。
- 获取文件的元数据,如果存在,则检查其中的
-
读取张量并构建状态字典:
pythonstate_dict = {} for k in f.keys(): dtype = str_to_torch_dtype[f.get_slice(k).get_dtype()] if map_location == "meta": state_dict[k] = torch.empty(size=f.get_slice(k).get_shape(), dtype=dtype, device="meta") else: state_dict[k] = f.get_tensor(k)
- 初始化一个空的状态字典
state_dict
。 - 遍历文件中的所有键(参数名称):
- 使用
f.get_slice(k).get_dtype()
获取张量的数据类型,并使用str_to_torch_dtype
将字符串转换为 PyTorch 的数据类型。 - 如果
map_location
是"meta"
,则创建一个在"meta"
设备上的空张量,占位但不占用实际内存。 - 如果
map_location
不是"meta"
,则直接从文件中读取张量数据。 - 将张量添加到状态字典中,键为参数名称
k
。
- 使用
- 初始化一个空的状态字典
-
返回状态字典:
pythonreturn state_dict
- 返回构建好的状态字典。
处理其他格式(例如 .bin
)的文件
python
try:
if map_location is None:
if (
(
is_deepspeed_zero3_enabled()
and torch.distributed.is_initialized()
and torch.distributed.get_rank() > 0
)
or (is_fsdp_enabled() and not is_local_dist_rank_0())
) and not is_quantized:
map_location = "meta"
else:
map_location = "cpu"
extra_args = {}
# mmap can only be used with files serialized with zipfile-based format.
if (
isinstance(checkpoint_file, str)
and map_location != "meta"
and version.parse(torch.__version__) >= version.parse("2.1.0")
and is_zipfile(checkpoint_file)
):
extra_args = {"mmap": True}
weights_only_kwarg = {"weights_only": weights_only}
return torch.load(
checkpoint_file,
map_location=map_location,
**weights_only_kwarg,
**extra_args,
)
except Exception as e:
# 异常处理...
解释:
-
设置
map_location
(加载张量的设备位置):pythonif map_location is None: if ( ( is_deepspeed_zero3_enabled() and torch.distributed.is_initialized() and torch.distributed.get_rank() > 0 ) or (is_fsdp_enabled() and not is_local_dist_rank_0()) ) and not is_quantized: map_location = "meta" else: map_location = "cpu"
-
如果
map_location
未指定,则根据以下条件设置:- 条件 1 :如果启用了 DeepSpeed ZeRO Stage 3,并且分布式训练已初始化,当前进程的
rank
大于 0,且模型未量化,则将map_location
设置为"meta"
。 - 条件 2 :如果启用了 Fully Sharded Data Parallel(FSDP),且当前不是本地第 0 号进程,且模型未量化,则将
map_location
设置为"meta"
。 - 否则 :将
map_location
设置为"cpu"
。
- 条件 1 :如果启用了 DeepSpeed ZeRO Stage 3,并且分布式训练已初始化,当前进程的
-
目的 :在多进程训练中,仅在需要的进程上加载实际的参数,其余进程使用
"meta"
设备以节省内存。
-
-
准备额外的参数
extra_args
:pythonextra_args = {} if ( isinstance(checkpoint_file, str) and map_location != "meta" and version.parse(torch.__version__) >= version.parse("2.1.0") and is_zipfile(checkpoint_file) ): extra_args = {"mmap": True}
-
如果满足以下条件,则将
extra_args
设置为{"mmap": True}
:checkpoint_file
是字符串类型(文件路径)。map_location
不是"meta"
。- PyTorch 版本大于等于 2.1.0。
checkpoint_file
是一个 zip 文件(is_zipfile(checkpoint_file)
返回True
)。
-
目的 :当满足上述条件时,可以使用内存映射(
mmap
)来加载模型,加速加载过程并减少内存占用。
-
-
准备
weights_only
参数:pythonweights_only_kwarg = {"weights_only": weights_only}
- 将
weights_only
参数封装为字典,以便在调用torch.load
时使用。
- 将
-
加载检查点文件:
pythonreturn torch.load( checkpoint_file, map_location=map_location, **weights_only_kwarg, **extra_args, )
- 调用
torch.load
函数加载检查点文件。 - 参数包括:
checkpoint_file
:文件路径。map_location
:指示加载张量的设备位置。**weights_only_kwarg
:是否只加载权重参数。**extra_args
:可能包含{"mmap": True}
。
- 调用
-
异常处理:
pythonexcept Exception as e: # 处理可能的异常,如文件不存在、格式错误等。
- 如果在加载过程中出现异常,进入异常处理部分。
异常处理
python
except Exception as e:
try:
with open(checkpoint_file) as f:
if f.read(7) == "version":
raise OSError(
"You seem to have cloned a repository without having git-lfs installed. Please install "
"git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
"you cloned."
)
else:
raise ValueError(
f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained "
"model. Make sure you have saved the model properly."
) from e
except (UnicodeDecodeError, ValueError):
raise OSError(
f"Unable to load weights from pytorch checkpoint file for '{checkpoint_file}' "
f"at '{checkpoint_file}'. "
"If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True."
)
解释:
-
第一次尝试打开文件:
- 使用
open(checkpoint_file)
尝试打开文件。 - 如果成功,读取前 7 个字符。
- 使用
-
检查是否缺少 Git LFS 大文件:
pythonif f.read(7) == "version": raise OSError(...)
- 如果读取的内容为
"version"
,可能意味着用户在未安装 Git LFS 的情况下克隆了包含大文件的仓库,造成文件实际上并未下载,只是一个占位符。 - 抛出
OSError
,提示用户安装 Git LFS 并执行git lfs pull
。
- 如果读取的内容为
-
处理文件找不到的情况:
pythonelse: raise ValueError(...) from e
- 如果文件无法找到或读取,抛出
ValueError
,提示无法找到必要的文件,并建议确保模型已正确保存。
- 如果文件无法找到或读取,抛出
-
第二次异常处理:
pythonexcept (UnicodeDecodeError, ValueError): raise OSError(...)
- 如果在尝试读取文件时遇到
UnicodeDecodeError
或ValueError
,说明文件无法正确读取。 - 抛出
OSError
,提示无法从指定的 PyTorch 检查点文件加载权重。 - 如果用户尝试从 TensorFlow 2.0 的检查点加载 PyTorch 模型,建议设置
from_tf=True
。
- 如果在尝试读取文件时遇到
附加说明
-
safetensors
库和格式:safetensors
是一种安全、高效的张量序列化格式,具有以下优点:- 安全性:避免了使用 pickle 存在的安全问题。
- 效率:支持零拷贝加载,减少内存占用和加快加载速度。
-
"meta"
设备:- PyTorch 中的
"meta"
设备是一种特殊的设备,用于创建占位张量,不实际分配内存。 - 常用于大型模型的初始化和分布式训练,帮助节约内存。
- PyTorch 中的
-
内存映射(
mmap
):- 内存映射是一种文件 I/O 技术,允许直接在内存中访问文件的内容,而无需将文件完全读入内存。
- 在加载大型模型时,使用内存映射可以加快加载速度并减少内存占用。
-
处理常见错误:
- 缺少 Git LFS 的情况:当用户克隆了包含大型模型文件的仓库但未安装 Git LFS,文件将只是占位符,无法正确加载。代码对此进行了检查并提供了具体的提示。
- 文件无法找到或格式错误:代码提供了详细的错误信息,提示用户检查文件路径和模型保存方式。
使用示例
python
# 示例 1:加载 safetensors 格式的模型,仅加载权重到 "cpu"
state_dict = load_state_dict("path/to/model.safetensors", map_location="cpu", weights_only=True)
# 示例 2:在内存受限的情况下,将大型模型的参数加载到 "meta" 设备
state_dict = load_state_dict("path/to/large_model.bin", map_location="meta", weights_only=True)
# 示例 3:处理量化模型,加载到指定的设备上
state_dict = load_state_dict("path/to/quantized_model.bin", is_quantized=True, map_location="cuda:0")
# 示例 4:处理加载错误
try:
state_dict = load_state_dict("path/to/nonexistent_model.bin")
except OSError as e:
print(f"加载模型时出错:{e}")
4 _load_state_dict_into_meta_model
_load_state_dict_into_meta_model
的函数,用于将 state_dict
中的参数加载到模型中,特别是处理那些参数位于 'meta'
设备上的情况。它考虑了设备映射(device_map
)、设备网格(device_mesh
)、参数卸载(offloading)、量化(quantization)等复杂场景。下面我将对代码逐步进行详细解释。
函数定义和装饰器
python
@torch.no_grad()
def _load_state_dict_into_meta_model(
model: torch.nn.Module,
state_dict: Dict[str, torch.Tensor],
start_prefix,
expected_keys,
device_map=None,
offload_folder=None,
offload_index=None,
state_dict_folder=None,
state_dict_index=None,
dtype=None,
hf_quantizer=None,
is_safetensors=False,
keep_in_fp32_modules=None,
unexpected_keys=None, # passing `unexpected` for cleanup from quantization items
device_mesh=None,
shard_file=None,
):
...
装饰器 @torch.no_grad()
- 作用:禁止在该函数内追踪梯度计算,即在执行过程中不会计算或存储梯度,节省内存,提高效率。
- 原因:加载参数时不需要计算梯度。
函数参数
model
: 需要加载参数的模型对象。state_dict
: 状态字典,包含模型的参数键(名称)和对应的张量。start_prefix
: 参数名称的起始前缀,用于调整参数名称。expected_keys
: 预期的参数键列表,用于验证加载参数的完整性。device_map
: 设备映射字典,指定模型的不同部分(参数)应加载到的设备(如 CPU、GPU)。offload_folder
: 卸载参数的文件夹路径,当参数需要卸载到磁盘时使用。offload_index
: 卸载参数的索引,用于记录哪些参数已卸载。state_dict_folder
: 状态字典的临时存储文件夹。state_dict_index
: 状态字典索引,用于记录临时存储的参数。dtype
: 数据类型,用于参数类型转换(如torch.float16
)。hf_quantizer
: Hugging Face 的量化器对象,用于处理模型量化。is_safetensors
: 布尔值,指示是否使用了safetensors
格式。keep_in_fp32_modules
: 正则表达式,用于匹配需要保持为float32
精度的模块/参数。unexpected_keys
: 未预期的参数键列表,用于清理量化等过程中遗留的参数。device_mesh
: 设备网格对象,用于张量并行(Tensor Parallel)场景。shard_file
: 当前处理的分片文件路径。
函数文档字符串
- 说明了函数的作用、处理的特殊情况,以及
start_prefix
的用途。
1. 初始化 tensor_device
python
tensor_device = None
if device_map is not None and device_map.get("", None) is not None:
tensor_device = device_map[""].index if isinstance(device_map[""], torch.device) else device_map[""]
- 作用:确定用于加载张量的默认设备。
- 解释 :
- 如果提供了
device_map
,并且有一个键为""
的设备映射,则将其设为tensor_device
。 - 这是为了在后续加载参数时,指定张量加载到的设备。
- 如果提供了
2. 构建设备映射的正则表达式
python
if device_map is not None:
device_map_regex = "|".join(sorted(device_map.keys(), reverse=True))
- 作用:创建一个正则表达式,用于匹配参数名称与设备映射中的键。
- 解释 :
- 将
device_map
中的键按逆序排序,并用"|"
连接,生成一个正则表达式模式。 - 这是为了在后续处理中,根据参数名称确定其对应的设备。
- 将
3. 初始化张量并行计划(如果需要)
python
if device_mesh is not None:
full_tp_plan = model.config.base_model_tp_plan
for submodule in model.modules():
full_tp_plan.update(getattr(submodule, "_tp_plan", {}))
- 作用 :当使用张量并行(Tensor Parallelism)时,构建完整的并行计划(
full_tp_plan
)。 - 解释 :
- 从模型配置中获取基础模型的并行计划。
- 遍历模型的所有子模块,收集子模块中的并行计划,更新到
full_tp_plan
中。
4. 准备读取分片文件
python
file_pointer = None
bin_state_dict = None
if shard_file.endswith(".safetensors"):
file_pointer = safe_open(shard_file, framework="pt", device=tensor_device)
else:
bin_state_dict = load_state_dict(shard_file, map_location="cpu")
-
作用:根据分片文件的格式,打开文件并读取参数。
-
解释:
- 如果分片文件是
.safetensors
格式,使用safe_open
函数打开,获得文件指针file_pointer
。 - 否则,使用
load_state_dict
加载.bin
格式的状态字典到bin_state_dict
。
- 如果分片文件是
5. 初始化错误信息列表
python
error_msgs = []
is_quantized = hf_quantizer is not None
is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")
-
作用:准备记录错误信息,并检查环境支持。
-
解释:
error_msgs
: 用于存储在加载过程中产生的错误信息。is_quantized
: 判断是否使用了量化器。is_torch_e4m3fn_available
: 检查 PyTorch 是否支持float8_e4m3fn
类型。
6. 遍历 state_dict
中的参数
python
for serialized_param_name, empty_param in state_dict.items():
# ...
- 作用:逐个处理状态字典中的参数。
6.1 处理参数名称(键名)
python
fixed_param_name, _ = model.rename_key(serialized_param_name)
if fixed_param_name not in expected_keys:
continue
-
作用:将序列化的参数名称转换为模型中的参数名称,并检查是否在预期的参数列表中。
-
解释:
- 使用
model.rename_key
将序列化的参数名称(可能有前缀等)转换为模型中的参数名称。 - 如果转换后的参数名称不在
expected_keys
中,跳过该参数。
- 使用
6.2 从文件中读取参数值
python
param = (
file_pointer.get_slice(serialized_param_name)
if shard_file.endswith(".safetensors")
else bin_state_dict[serialized_param_name]
)
-
作用:根据文件类型,从文件中获取参数的值或切片。
-
解释:
- 如果使用
safetensors
格式,使用file_pointer.get_slice
获取参数的切片(元数据)。 - 否则,从
bin_state_dict
中获取参数的张量。
- 如果使用
6.3 处理数据类型转换
python
param_casting_dtype = None
is_param_float8_e4m3fn = is_torch_e4m3fn_available and empty_param.dtype == torch.float8_e4m3fn
if dtype is not None and empty_param.dtype.is_floating_point and not is_param_float8_e4m3fn:
if (
keep_in_fp32_modules is not None
and keep_in_fp32_modules.search(fixed_param_name)
and dtype == torch.float16
):
param_casting_dtype = torch.float32
else:
param_casting_dtype = dtype
-
作用:确定是否需要对参数进行数据类型转换。
-
解释:
- 如果指定了
dtype
,并且参数是浮点类型,且不是float8_e4m3fn
(特殊类型),则考虑转换。 - 如果参数名称匹配到
keep_in_fp32_modules
,并且dtype
是float16
,则将参数保留为float32
。 - 否则,将参数转换为指定的
dtype
。
- 如果指定了
6.4 处理张量并行的情况
python
if device_mesh is not None:
# 进行张量并行处理
# ...
-
作用:在使用张量并行时,对参数进行切分、分片,并在设备网格上分布。
-
解释:
- 根据
full_tp_plan
,确定当前参数的并行计划。 - 根据并行策略,对参数进行行或列切分,得到分片的参数。
- 将参数转换为指定的
dtype
(如果需要)。 - 使用
DTensor
将参数封装为分布式张量,指定设备网格和放置方式。 - 将参数赋值到模型中对应的模块和参数上。
- 根据
6.5 处理非张量并行的情况
python
else:
# 非张量并行的处理
# ...
- 解释:处理不使用张量并行的情况,包括参数的设备映射、卸载、量化等。
6.5.1 确定参数的设备
python
if device_map is None:
param_device = "cpu"
else:
module_layer = re.search(device_map_regex, fixed_param_name)
if not module_layer:
raise ValueError(f"{fixed_param_name} doesn't have any device set.")
else:
param_device = device_map[module_layer.group()]
-
作用 :根据
device_map
,确定参数应该加载到的设备。 -
解释:
- 如果没有提供
device_map
,默认将参数加载到 CPU。 - 否则,使用之前构建的正则表达式匹配参数名称,找到对应的设备。
- 如果未匹配到任何设备,抛出错误。
- 如果没有提供
6.5.2 处理参数卸载到磁盘
python
if param_device == "disk":
if not is_safetensors:
offload_index = offload_weight(param[:], fixed_param_name, offload_folder, offload_index)
-
作用:如果参数需要卸载到磁盘,执行卸载操作。
-
解释:
- 如果参数设备为
"disk"
,并且不是safetensors
格式,则调用offload_weight
函数,将参数卸载到指定的文件夹。 - 更新
offload_index
,记录已卸载的参数。
- 如果参数设备为
6.5.3 处理参数临时存储
python
elif param_device == "cpu" and state_dict_index is not None:
state_dict_index = offload_weight(param[:], fixed_param_name, state_dict_folder, state_dict_index)
-
作用:如果参数需要临时存储到指定文件夹,执行存储操作。
-
解释:
- 如果参数设备为
"cpu"
,且提供了state_dict_index
,则调用offload_weight
函数,将参数存储到state_dict_folder
。 - 更新
state_dict_index
。
- 如果参数设备为
6.5.4 处理量化和参数加载
python
elif (
not is_quantized
or (not hf_quantizer.requires_parameters_quantization)
or (
not hf_quantizer.check_quantized_param(
model,
param,
fixed_param_name,
state_dict,
param_device=param_device,
device_map=device_map,
)
)
):
# 直接加载参数到模型中
# ...
else:
# 处理量化参数
# ...
-
作用:根据模型是否量化,决定如何加载参数。
-
解释:
- 如果模型未量化,或者不需要参数量化,或者检查量化参数失败,则直接加载参数到模型中。
- 否则,使用量化器处理量化参数。
6.5.5 直接加载参数到模型中
python
if is_fsdp_enabled():
param_device = "cpu" if is_local_dist_rank_0() else "meta"
module, param_type = find_submodule_and_param_name(model, fixed_param_name)
if param_casting_dtype is not None and param_casting_dtype != empty_param.dtype:
param = param[:].to(param_casting_dtype)
module.load_state_dict(
{param_type: param[:].to(param_device)},
strict=False,
assign=True,
)
-
作用:将参数加载到模型的对应模块和参数中。
-
解释:
- 如果启用了 FSDP(全分片数据并行),根据当前进程的 rank,决定参数加载到 CPU 或
'meta'
设备。 - 使用
find_submodule_and_param_name
函数,找到模型中的子模块和参数名称。 - 如果需要,转换参数的数据类型。
- 使用
module.load_state_dict
直接加载参数到模型中。
- 如果启用了 FSDP(全分片数据并行),根据当前进程的 rank,决定参数加载到 CPU 或
6.5.6 处理量化参数
python
else:
hf_quantizer.create_quantized_param(
model, param[:], fixed_param_name, param_device, state_dict, unexpected_keys
)
# 处理 FSDP 或 DeepSpeed Stage 3 下的特殊情况
# ...
-
作用:使用量化器创建量化参数,并处理特定的并行训练场景。
-
解释:
- 使用
hf_quantizer.create_quantized_param
创建量化后的参数,并将其赋值到模型中。 - 如果启用了 FSDP 或 DeepSpeed ZeRO Stage 3,需要将参数从 GPU 转移到 CPU 或
'meta'
,以避免每个 GPU 上内存的过度消耗。
- 使用
7. 关闭文件指针
python
if file_pointer is not None:
file_pointer.__exit__(None, None, None)
- 作用 :如果打开了
safetensors
文件指针,在处理完毕后关闭文件。
8. 返回错误信息和索引
python
return error_msgs, offload_index, state_dict_index
- 作用:函数返回加载过程中产生的错误信息,以及更新后的卸载索引和状态字典索引。
5 find_submodule_and_param_name
函数概述
find_submodule_and_param_name
是一个辅助函数,用于在 PyTorch 模型的层次结构中查找特定的子模块和参数(或缓冲区)。给定一个参数的完整名称(通常是在模型的状态字典中的键,表示参数在模型中的路径),该函数通过遍历模型的模块,找到最深层的子模块,并返回该子模块和参数名称。
如果提供了 start_prefix
,函数会在处理之前从键中移除这个前缀。
函数定义
python
def find_submodule_and_param_name(model, long_key, start_prefix=""):
"""
一个辅助工具,用于找到最后的子模块和参数/缓冲区名称。
如果提供了 `start_prefix`,则会从键的开头移除这个前缀。
"""
if len(start_prefix) > 0 and long_key.startswith(start_prefix):
long_key = ".".join(long_key.split(".")[1:])
split_key = long_key.split(".")
submodule = model
while len(split_key) > 1:
if hasattr(submodule, split_key[0]):
submodule = getattr(submodule, split_key[0])
del split_key[0]
else:
submodule = None
break
if submodule == model:
submodule = None
return submodule, split_key[0]
逐步解释
-
移除
start_prefix
(如果提供了)pythonif len(start_prefix) > 0 and long_key.startswith(start_prefix): long_key = ".".join(long_key.split(".")[1:])
- 目的 :如果提供了
start_prefix
,并且long_key
以此前缀开始,那么从long_key
中移除这个前缀,以简化后续的处理。 - 示例 :
- 如果
start_prefix = "model"
,long_key = "model.layer1.weight"
,处理后,long_key
变为"layer1.weight"
。
- 如果
- 目的 :如果提供了
-
将键分割成组件
pythonsplit_key = long_key.split(".") submodule = model
-
目的 :使用
.
分隔符将long_key
拆分为一个包含属性名称的列表。 -
变量:
split_key
:属性名称的列表。submodule
:初始设置为模型的根模块(model
)。
-
示例:
- 如果
long_key = "layer1.conv1.weight"
,则split_key = ["layer1", "conv1", "weight"]
。
- 如果
-
-
遍历模型层次结构
pythonwhile len(split_key) > 1: if hasattr(submodule, split_key[0]): submodule = getattr(submodule, split_key[0]) del split_key[0] else: submodule = None break
-
目的 :根据
split_key
中的名称,逐层深入模型的子模块,找到包含目标参数的子模块。 -
过程:
- 当
split_key
的长度大于 1 时,继续遍历。 - 检查当前
submodule
是否具有名为split_key[0]
的属性。- 如果存在,则将
submodule
更新为该属性(应为一个子模块)。 - 从
split_key
中删除第一个元素。
- 如果存在,则将
- 如果属性不存在,将
submodule
设为None
,并退出循环。
- 当
-
示例步骤:
- 初始状态 :
split_key = ["layer1", "conv1", "weight"]
submodule = model
- 第一次迭代 :
- 检查
model
是否有属性"layer1"
。 - 存在,则
submodule = model.layer1
- 删除
"layer1"
,split_key = ["conv1", "weight"]
- 检查
- 第二次迭代 :
- 检查
submodule
(model.layer1
)是否有属性"conv1"
。 - 存在,则
submodule = model.layer1.conv1
- 删除
"conv1"
,split_key = ["weight"]
- 检查
- 循环结束 ,因为
len(split_key) == 1
- 初始状态 :
-
-
处理
submodule
仍为根模型的情况pythonif submodule == model: submodule = None
- 目的 :如果遍历后
submodule
仍然是模型的根模块,说明参数直接在根模块中,没有更深的子模块。将submodule
设为None
。
- 目的 :如果遍历后
-
返回子模块和参数名称
pythonreturn submodule, split_key[0]
- 目的 :返回找到的
submodule
和剩余的split_key[0]
,这应当是参数或缓冲区的名称。 - 变量 :
submodule
:包含参数的最深子模块(如果在根模块,则为None
)。split_key[0]
:在该子模块中的参数名称。
- 目的 :返回找到的
整体功能
- 目的:根据参数的完整键(如状态字典中的键),找到模型中的子模块和参数名称,方便访问和修改特定的参数。
- 使用场景:在加载状态字典或需要根据参数名称操作模型参数时,使用该函数定位目标参数。
示例场景
假设我们有一个模型,结构如下:
python
import torch.nn as nn
class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.layer1 = nn.Module()
self.layer1.conv1 = nn.Conv2d(3, 64, kernel_size=3)
self.layer1.conv2 = nn.Conv2d(64, 128, kernel_size=3)
self.layer2 = nn.Linear(128, 10)
状态字典中的键包括:
"layer1.conv1.weight"
"layer1.conv1.bias"
"layer1.conv2.weight"
"layer1.conv2.bias"
"layer2.weight"
"layer2.bias"
使用该函数查找参数:
python
model = MyModel()
# 查找 "layer1.conv1.weight" 对应的子模块和参数名称
submodule, param_name = find_submodule_and_param_name(model, "layer1.conv1.weight")
-
过程:
split_key = ["layer1", "conv1", "weight"]
- 初始
submodule = model
- 第一次迭代 :
model
有属性"layer1"
,更新submodule = model.layer1
- 删除
"layer1"
,split_key = ["conv1", "weight"]
- 第二次迭代 :
submodule
(model.layer1
)有属性"conv1"
,更新submodule = model.layer1.conv1
- 删除
"conv1"
,split_key = ["weight"]
- 循环结束
-
结果:
submodule = model.layer1.conv1
param_name = "weight"
现在可以访问参数:
python
param = getattr(submodule, param_name) # 等同于 model.layer1.conv1.weight
处理 start_prefix
如果 long_key
有一个需要忽略的前缀,可以使用 start_prefix
来移除。
示例:
- 给定
long_key = "module.layer1.conv1.weight"
,start_prefix = "module"
,函数会先移除"module"
,得到"layer1.conv1.weight"
。
代码执行:
python
submodule, param_name = find_submodule_and_param_name(
model, "module.layer1.conv1.weight", start_prefix="module"
)
# 函数将 "module.layer1.conv1.weight" 视为 "layer1.conv1.weight"
特殊情况处理
-
参数在根模块上:
-
如果参数直接在模型的根模块上,例如
"weight"
,且没有子模块: -
split_key = ["weight"]
-
循环不执行,
submodule
保持为model
-
函数将
submodule
设为None
,表示参数在根模块上。
-
-
子模块不存在:
-
如果在遍历过程中,遇到
submodule
不存在所需的属性: -
例如,处理
"layer1.conv3.weight"
时,model.layer1
没有conv3
属性。 -
hasattr(submodule, "conv3")
返回False
,submodule
设为None
,退出循环。 -
函数返回
submodule = None
,param_name
为剩余的键(可能不对应实际的参数)。
-