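GR00T N1.7 Source Code Study (1): Project Entry Point, Model Structure, and Action Generation Flow
GR00T N1.7 Source Code Study (2): Training Data, Processor, and Multi-Robot Action Spaces
GR00T N1.7 Source Code Study (3): Action Head Internals, DiT Structure, and Multi-Robot Conditional Encoding
GR00T N1.7 Source Code Study (4): Fine-tuning Flow, Training Pipeline, and Checkpoint Saving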
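The first four articles covered GR00T N1.7's main model path, data processing pipeline, action head internals, and fine-tuning framework, so the training side is now largely connected end to end. A trained checkpoint, however, cannot control a robot directly. Deployment needs an extra Policy wrapper layer that organizes environment Observations into model inputs, converts the model's normalized action outputs back into robot actions, and handles the transition between consecutive action chunks during continuous control.
This article focuses on the inference-side code and is organized mainly around the following files: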
gr00t/policy/policy.py
gr00t/policy/gr00t_policy.py
gr00t/policy/server_client.py
gr00t/model/gr00t_n1d7/gr00t_n1d7.py
scripts/deployment/standalone_inference_script.py
scripts/deployment/export_onnx_n1d7.py
scripts/deployment/build_tensorrt_engine.py
scripts/deployment/build_trt_pipeline.py
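This article can be read as the wrap-up of the series: the earlier articles answer how the model is trained, and this one answers how the trained model actually goes into inference and deployment.
1. BasePolicy defines a unified online inference interface
GR00T's base inference interface is defined in:
gr00t/policy/policy.py
BasePolicy itself is an abstract class that specifies which methods a robot Policy must implement: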
from abc import ABC, abstractmethod
from typing import Any


class BasePolicy(ABC):
    """Abstract base class for robotic control policies."""

    def __init__(self, *, strict: bool = True):
        self.strict = strict

    @abstractmethod
    def check_observation(self, observation: dict[str, Any]) -> None:
        pass

    @abstractmethod
    def check_action(self, action: dict[str, Any]) -> None:
        pass

    @abstractmethod
    def _get_action(
        self, observation: dict[str, Any], options: dict[str, Any] | None = None
    ) -> tuple[dict[str, Any], dict[str, Any]]:
        pass

    @abstractmethod
    def reset(self, options: dict[str, Any] | None = None) -> dict[str, Any]:
        pass
External callers use get_action(), which decides from strict whether to validate the input and output formats:
def get_action(
    self, observation: dict[str, Any], options: dict[str, Any] | None = None
) -> tuple[dict[str, Any], dict[str, Any]]:
    if self.strict:
        self.check_observation(observation)
    action, info = self._get_action(observation, options)
    if self.strict:
        self.check_action(action)
    return action, info
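This interface splits online inference into three steps: check the Observation, call the internal _get_action() to obtain the action, and finally check the action format.
2. Gr00tPolicy restores the model and Processor from a checkpoint
Gr00tPolicy is the core class for GR00T online inference, defined in:
gr00t/policy/gr00t_policy.py
Its constructor mainly loads the model, loads the Processor, and reads the modality configuration for the current robot: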
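To make the role of strict concrete, here is a minimal sketch. It re-declares the BasePolicy contract locally so the snippet runs without GR00T installed, and adds a hypothetical EchoPolicy (not part of the repository) that counts how often its check hooks run and returns an all-zero action. The "joints" action key is made up for illustration.

```python
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any


class BasePolicy(ABC):
    """Local re-declaration of the interface above, for illustration only."""

    def __init__(self, *, strict: bool = True):
        self.strict = strict

    @abstractmethod
    def check_observation(self, observation: dict[str, Any]) -> None: ...

    @abstractmethod
    def check_action(self, action: dict[str, Any]) -> None: ...

    @abstractmethod
    def _get_action(
        self, observation: dict[str, Any], options: dict[str, Any] | None = None
    ) -> tuple[dict[str, Any], dict[str, Any]]: ...

    @abstractmethod
    def reset(self, options: dict[str, Any] | None = None) -> dict[str, Any]: ...

    def get_action(self, observation, options=None):
        # Validation wraps the subclass hook only when strict=True.
        if self.strict:
            self.check_observation(observation)
        action, info = self._get_action(observation, options)
        if self.strict:
            self.check_action(action)
        return action, info


class EchoPolicy(BasePolicy):
    """Hypothetical toy policy: returns a zero action with one entry per state value."""

    def __init__(self, *, strict: bool = True):
        super().__init__(strict=strict)
        self.num_checks = 0

    def check_observation(self, observation):
        self.num_checks += 1
        if "state" not in observation:
            raise ValueError("Observation must contain a 'state' key")

    def check_action(self, action):
        self.num_checks += 1
        if "joints" not in action:
            raise ValueError("Action must contain a 'joints' key")

    def _get_action(self, observation, options=None):
        return {"joints": [0.0] * len(observation["state"])}, {}

    def reset(self, options=None):
        return {}
```

With strict=True every call pays for two validation passes; with strict=False both hooks are skipped, which saves a little time once the input pipeline is trusted but lets malformed inputs reach the model unchecked.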
class Gr00tPolicy(BasePolicy):
    """Core policy class for Gr00t model inference."""

    def __init__(
        self,
        embodiment_tag: EmbodimentTag | str,
        model_path: str,
        *,
        device: int | str,
        strict: bool = True,
    ):
        import gr00t.model  # noqa: F401

        super().__init__(strict=strict)
        if isinstance(embodiment_tag, str):
            embodiment_tag = EmbodimentTag.resolve(embodiment_tag)
        model_dir = Path(model_path)
        # Load the pretrained model and move to target device with bfloat16 precision
        model = AutoModel.from_pretrained(model_dir)
        model.eval()
        model.to(device=device, dtype=torch.bfloat16)
        self.model = model
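The constructor first runs import gr00t.model to trigger the model registration logic. It then restores the model from the checkpoint directory with AutoModel.from_pretrained(model_dir) and switches it to eval() mode. Finally, the model is moved to the target device in torch.bfloat16 precision.
Loading the Processor involves one extra layer of compatibility logic: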
processor_dir = (
    model_dir / "processor"
    if (model_dir / "processor").is_dir()
    and not (model_dir / "processor_config.json").exists()
    else model_dir
)
self.processor: BaseProcessor = AutoProcessor.from_pretrained(processor_dir)
self.processor.eval()
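During training, the checkpoint directory stores not only the model weights but also the Processor and its statistics. This logic accommodates two directory layouts: if the model root has no processor_config.json but does contain a processor/ subdirectory, the Processor is loaded from processor/; otherwise it is loaded directly from the model root.
Next, the constructor checks whether the current embodiment_tag is supported by this checkpoint: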
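The directory rule can be checked in isolation. The sketch below copies the condition into a standalone function (the name resolve_processor_dir is hypothetical) and builds two throwaway checkpoint layouts with tempfile; only the processor/ and processor_config.json names come from the source.

```python
import tempfile
from pathlib import Path


def resolve_processor_dir(model_dir: Path) -> Path:
    """Prefer processor/ only when the root has no processor_config.json (same rule as above)."""
    if (model_dir / "processor").is_dir() and not (model_dir / "processor_config.json").exists():
        return model_dir / "processor"
    return model_dir


with tempfile.TemporaryDirectory() as tmp:
    # Layout A: processor files live only in a processor/ subdirectory.
    layout_a = Path(tmp) / "ckpt_a"
    (layout_a / "processor").mkdir(parents=True)
    # Layout B: processor_config.json sits in the root, so the root wins even if processor/ exists.
    layout_b = Path(tmp) / "ckpt_b"
    (layout_b / "processor").mkdir(parents=True)
    (layout_b / "processor_config.json").write_text("{}")

    result_a = resolve_processor_dir(layout_a).name
    result_b = resolve_processor_dir(layout_b).name
```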
self.embodiment_tag = embodiment_tag
all_modality_configs = self.processor.get_modality_configs()
if self.embodiment_tag.value not in all_modality_configs:
    supported_lines = []
    for tag_value in sorted(all_modality_configs.keys()):
        enum_name = EmbodimentTag.reverse_lookup(tag_value)
        if enum_name != tag_value:
            supported_lines.append(f" {enum_name:30s} (--embodiment-tag {enum_name})")
        else:
            supported_lines.append(f" {tag_value:30s} (internal, no public enum)")
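If the requested robot type is not supported by the checkpoint, the source code prints a fairly explicit hint. For example, some posttrain tags require a fine-tuned checkpoint and cannot be used with the base model directly, while some finetune-only tags are intended for training custom robots rather than for running inference directly with the base checkpoint.
Finally, the constructor stores the modality configuration for the current robot: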
self.modality_configs = {
    k: v
    for k, v in all_modality_configs[self.embodiment_tag.value].items()
    if k != "rl_info"
}
self.collate_fn = self.processor.collator
language_keys = self.modality_configs["language"].modality_keys
language_delta_indices = self.modality_configs["language"].delta_indices
assert len(language_keys) >= 1, "At least one language key is required"
assert len(language_delta_indices) == 1, "Only one language delta index is supported"
self.language_key = language_keys[0]
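Inference still depends on the modality_configs saved at training time. In other words, deployment is not just a matter of loading model weights: the same Processor and modality configuration must be loaded as well. Otherwise the Policy would not know which cameras, which state fields, and which language key to read from the environment Observation, or into which groups the action output should be split.
3. Observations must be organized according to modality_configs
During online inference, the external environment passes a batched observation to the Policy, and check_observation() strictly validates its structure. It requires the Observation to contain three top-level fields: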
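The selection logic can be sketched without the real Processor. Below, ModalityConfig is a hypothetical dataclass standing in for the repository's config objects (only the modality_keys and delta_indices attributes are taken from the source), and the embodiment tag "my_robot" and the 16-step action horizon are made up. The function mirrors the three steps above: reject unsupported tags, drop rl_info, and pick the single language key.

```python
from dataclasses import dataclass, field


@dataclass
class ModalityConfig:
    """Hypothetical stand-in for the repository's modality config objects."""
    modality_keys: list[str]
    delta_indices: list[int] = field(default_factory=lambda: [0])


def select_embodiment_configs(all_configs: dict, tag: str) -> tuple[dict, str]:
    """Return (configs without rl_info, language key) for one embodiment tag."""
    if tag not in all_configs:
        supported = ", ".join(sorted(all_configs))
        raise ValueError(f"Embodiment '{tag}' not in checkpoint. Supported: {supported}")
    configs = {k: v for k, v in all_configs[tag].items() if k != "rl_info"}
    language = configs["language"]
    assert len(language.modality_keys) >= 1, "At least one language key is required"
    assert len(language.delta_indices) == 1, "Only one language delta index is supported"
    return configs, language.modality_keys[0]


all_configs = {
    "my_robot": {
        "video": ModalityConfig(["front", "wrist"]),
        "state": ModalityConfig(["joint_position", "gripper"]),
        "action": ModalityConfig(["joint_position", "gripper"], list(range(16))),
        "language": ModalityConfig(["annotation.human.task_description"]),
        "rl_info": ModalityConfig([]),
    }
}
configs, language_key = select_embodiment_configs(all_configs, "my_robot")
```

Everything the Policy later validates (camera keys, state keys, horizons, action groups) is read from this filtered dict, which is why a checkpoint shipped without its Processor files cannot be served.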
for modality in ["video", "state", "language"]:
    assert modality in observation, f"Observation must contain a '{modality}' key"
    assert isinstance(observation[modality], dict), (
        f"Observation '{modality}' must be a dictionary. Got {type(observation[modality])}: {observation[modality]}"
    )
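Video input must be provided for every key in modality_configs["video"].modality_keys. Each camera is an np.uint8 array of shape (B, T, H, W, C), where T must equal the number of video delta_indices and C must be 3: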
for video_key in self.modality_configs["video"].modality_keys:
    assert video_key in observation["video"], (
        f"Video key '{video_key}' must be in observation"
    )
    batched_video = observation["video"][video_key]
    assert isinstance(batched_video, np.ndarray), (
        f"Video key '{video_key}' must be a numpy array. Got {type(batched_video)}"
    )
    assert batched_video.dtype == np.uint8, (
        f"Video key '{video_key}' must be a numpy array of type np.uint8. Got {batched_video.dtype}"
    )
    assert batched_video.ndim == 5, (
        f"Video key '{video_key}' must be a numpy array of shape (B, T, H, W, C), got {batched_video.shape}"
    )
    assert batched_video.shape[1] == len(self.modality_configs["video"].delta_indices), (
        f"Video key '{video_key}'s horizon must be {len(self.modality_configs['video'].delta_indices)}. Got {batched_video.shape[1]}"
    )
    assert batched_video.shape[-1] == 3, (
        f"Video key '{video_key}'s channel 'C' must be 3. Got {batched_video.shape[-1]}"
    )
State input is likewise provided for every key in modality_configs["state"].modality_keys, as np.float32 arrays of shape (B, T, D):
for state_key in self.modality_configs["state"].modality_keys:
    assert state_key in observation["state"], (
        f"State key '{state_key}' must be in observation"
    )
    batched_state = observation["state"][state_key]
    assert isinstance(batched_state, np.ndarray), (
        f"State key '{state_key}' must be a numpy array. Got {type(batched_state)}"
    )
    assert batched_state.dtype == np.float32, (
        f"State key '{state_key}' must be a numpy array of type np.float32. Got {batched_state.dtype}"
    )
    assert batched_state.ndim == 3, (
        f"State key '{state_key}' must be a numpy array of shape (B, T, D), got {batched_state.shape}"
    )
    assert batched_state.shape[1] == len(self.modality_configs["state"].delta_indices), (
        f"State key '{state_key}'s horizon must be {len(self.modality_configs['state'].delta_indices)}. Got {batched_state.shape[1]}"
    )
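Language input is a list[list[str]]: the outer list is the batch dimension and the inner list is the time dimension. The current implementation requires each inner list to contain exactly one string: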
for language_key in self.modality_configs["language"].modality_keys:
    assert language_key in observation["language"], (
        f"Language key '{language_key}' must be in observation"
    )
    batched_language: list[list[str]] = observation["language"][language_key]
    assert isinstance(batched_language, list), (
        f"Language key '{language_key}' must be a list. Got {type(batched_language)}"
    )
    for batch_item in batched_language:
        assert len(batch_item) == len(self.modality_configs["language"].delta_indices), (
            f"Language key '{language_key}'s horizon must be {len(self.modality_configs['language'].delta_indices)}. Got {len(batched_language)}"
        )
        assert isinstance(batch_item, list), (
            f"Language batch item must be a list. Got {type(batch_item)}"
        )
        assert len(batch_item) == 1, (
            f"Language batch item must have exactly one item. Got {len(batch_item)}"
        )
        assert isinstance(batch_item[0], str), (
            f"Language batch item must be a string. Got {type(batch_item[0])}"
        )
Putting these checks together, the Observation passed in by the environment looks roughly like this:
observation = {
    "video": {
        "front": np.ndarray,  # (B, T, H, W, 3), uint8
        "wrist": np.ndarray,  # (B, T, H, W, 3), uint8
    },
    "state": {
        "joint_position": np.ndarray,  # (B, T, D), float32
        "gripper": np.ndarray,  # (B, T, D), float32
    },
    "language": {
        "annotation.human.task_description": [["pick up the red cube"]]
    }
}
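A concrete, validation-ready Observation can be built with NumPy as below. The batch size of 2, the 224×224 resolution, the state widths 7 and 1, and horizons of 1 are assumptions for illustration; in practice the keys and horizons come from modality_configs. The quick_check helper is a hypothetical condensed version of the dtype, rank, and horizon checks in check_observation().

```python
import numpy as np

B, T_VIDEO, T_STATE, H, W = 2, 1, 1, 224, 224  # assumed sizes

observation = {
    "video": {
        "front": np.zeros((B, T_VIDEO, H, W, 3), dtype=np.uint8),
        "wrist": np.zeros((B, T_VIDEO, H, W, 3), dtype=np.uint8),
    },
    "state": {
        "joint_position": np.zeros((B, T_STATE, 7), dtype=np.float32),
        "gripper": np.zeros((B, T_STATE, 1), dtype=np.float32),
    },
    "language": {
        # One inner list per batch item, each holding exactly one instruction string.
        "annotation.human.task_description": [["pick up the red cube"] for _ in range(B)],
    },
}


def quick_check(obs: dict, video_horizon: int, state_horizon: int) -> None:
    """Condensed version of the dtype/rank/horizon checks performed by check_observation()."""
    for key, arr in obs["video"].items():
        assert arr.dtype == np.uint8 and arr.ndim == 5, key
        assert arr.shape[1] == video_horizon and arr.shape[-1] == 3, key
    for key, arr in obs["state"].items():
        assert arr.dtype == np.float32 and arr.ndim == 3, key
        assert arr.shape[1] == state_horizon, key
    for key, items in obs["language"].items():
        assert all(isinstance(i, list) and len(i) == 1 and isinstance(i[0], str) for i in items), key


quick_check(observation, video_horizon=T_VIDEO, state_horizon=T_STATE)
```

A common pitfall is sending state as float64 (NumPy's default for np.zeros without a dtype), which the strict check rejects.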
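4. Gr00tPolicy converts the Observation into model inputs and decodes the actions
Gr00tPolicy._get_action() is the main online inference flow. The docstring in the source already lists the steps: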
def _get_action(
    self, observation: dict[str, Any], options: dict[str, Any] | None = None
) -> tuple[dict[str, Any], dict[str, Any]]:
    """Internal method to compute actions from observations.
    Pipeline:
    1. Unbatch observations into individual samples
    2. Convert each to VLAStepData and process
    3. Collate into model input batch
    4. Run model inference
    5. Decode and unnormalize actions
    """
The first step splits the batched observation into individual samples:
def _unbatch_observation(self, value: dict[str, Any]) -> list[dict[str, Any]]:
    unbatched_obs = []
    batch_size = value["video"][list(value["video"].keys())[0]].shape[0]
    for i in range(batch_size):
        unbatched_value = {
            "video": {k: v[i] for k, v in value["video"].items()},
            "state": {k: v[i] for k, v in value["state"].items()},
            "language": {k: v[i] for k, v in value["language"].items()},
        }
        unbatched_obs.append(unbatched_value)
    return unbatched_obs
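The effect of unbatching on shapes is easy to see on a toy batch. This sketch reuses the same per-key indexing as _unbatch_observation() and then re-stacks the states the way the decode step does later; the "task" language key and the tiny 8×8 frames are made up.

```python
import numpy as np


def unbatch(value: dict) -> list[dict]:
    """Same slicing as _unbatch_observation(): index every array/list along axis 0."""
    batch_size = next(iter(value["video"].values())).shape[0]
    return [
        {
            "video": {k: v[i] for k, v in value["video"].items()},
            "state": {k: v[i] for k, v in value["state"].items()},
            "language": {k: v[i] for k, v in value["language"].items()},
        }
        for i in range(batch_size)
    ]


batched = {
    "video": {"front": np.zeros((3, 1, 8, 8, 3), dtype=np.uint8)},
    "state": {"joint_position": np.arange(12, dtype=np.float32).reshape(3, 1, 4)},
    "language": {"task": [["a"], ["b"], ["c"]]},
}
samples = unbatch(batched)
# Each sample drops the batch axis: video (T, H, W, C), state (T, D), language [str].
restacked = np.stack([s["state"]["joint_position"] for s in samples], axis=0)
```

Stacking along axis 0 restores the original (B, T, D) array exactly, which is why the Policy can keep per-sample states and rebuild batched_states for decode_action().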
Each single Observation is converted into VLAStepData:
def _to_vla_step_data(self, observation: dict[str, Any]) -> VLAStepData:
    return VLAStepData(
        images=observation["video"],
        states=observation["state"],
        actions={},  # No ground truth actions during inference
        text=observation["language"][self.language_key][0],
        embodiment=self.embodiment_tag,
    )
There is no ground-truth action supervision at inference time, hence actions={} here. The Processor is then reused to turn each VLAStepData into model inputs:
unbatched_observations = self._unbatch_observation(observation)
processed_inputs = []
states = []
for obs in unbatched_observations:
    vla_step_data = self._to_vla_step_data(obs)
    states.append(vla_step_data.states)
    messages = [{"type": MessageType.EPISODE_STEP.value, "content": vla_step_data}]
    processed_inputs.append(self.processor(messages))
Next, the same Collator used in training assembles the batch, and floating-point Tensors are cast to bfloat16:
collated_inputs = self.collate_fn(processed_inputs)
collated_inputs = _rec_to_dtype(collated_inputs, dtype=torch.bfloat16)
The recursive conversion function only touches floating-point Tensors; integer token IDs, masks, and similar tensors are not forced to a floating-point type:
def _rec_to_dtype(x: Any, dtype: torch.dtype) -> Any:
    if isinstance(x, torch.Tensor) and torch.is_floating_point(x):
        return x.to(dtype=dtype)
    elif isinstance(x, dict) or hasattr(x, "items"):
        return {k: _rec_to_dtype(v, dtype) for k, v in x.items()}
    elif isinstance(x, list):
        return [_rec_to_dtype(v, dtype) for v in x]
    else:
        return x
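The "floating tensors only" rule can be reproduced with NumPy. NumPy has no bfloat16, so float16 stands in for it here, and the key names pixel_values, input_ids, and attention_mask are hypothetical examples of what a collated batch might hold. The recursion mirrors _rec_to_dtype(): cast floating arrays, recurse into dicts and lists, and pass everything else through.

```python
import numpy as np


def rec_to_dtype(x, dtype):
    """NumPy analogue of _rec_to_dtype(): cast floating arrays only, recurse into containers."""
    if isinstance(x, np.ndarray) and np.issubdtype(x.dtype, np.floating):
        return x.astype(dtype)
    if isinstance(x, dict) or hasattr(x, "items"):
        return {k: rec_to_dtype(v, dtype) for k, v in x.items()}
    if isinstance(x, list):
        return [rec_to_dtype(v, dtype) for v in x]
    return x


batch = {
    "pixel_values": np.ones((1, 3, 4, 4), dtype=np.float32),
    "input_ids": np.array([[101, 2054, 102]], dtype=np.int64),
    "attention_mask": np.array([[1, 1, 1]], dtype=np.int64),
    "extras": [np.zeros(2, dtype=np.float64), "meta"],
}
converted = rec_to_dtype(batch, np.float16)
```

Casting token IDs to a half-precision float would corrupt large IDs and break embedding lookups, so leaving integer tensors alone is a correctness requirement, not just an optimization.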
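Model inference runs under torch.inference_mode():
with torch.inference_mode():
    model_pred = self.model.get_action(**collated_inputs)
normalized_action = model_pred["action_pred"].float()
What is called here is the model's get_action(), not the training-time forward(). Training takes ground-truth actions as input and computes a loss, whereas inference only needs to generate an action chunk from the current Observation.
The last step is decoding. The Policy re-stacks the raw states along the batch axis and passes them to the Processor's decode_action():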
batched_states = {}
for k in self.modality_configs["state"].modality_keys:
    batched_states[k] = np.stack([s[k] for s in states], axis=0)  # (B, T, D)
unnormalized_action = self.processor.decode_action(
    normalized_action.cpu().numpy(), self.embodiment_tag, batched_states
)
casted_action = {
    key: value.astype(np.float32) for key, value in unnormalized_action.items()
}
return casted_action, {}
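The model outputs a normalized action chunk, which may also be expressed as relative actions. Before it is returned to the environment or robot controller, it must go through decode_action(), which splits it back into action groups, un-normalizes it, and, where needed, converts relative actions back into absolute actions.
check_action() verifies that the final action matches the modality configuration: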
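To show why the raw states are passed into decoding, here is a hypothetical decode step. It assumes min-max normalization to [-1, 1] and relative actions expressed as offsets from the most recent state; both conventions are assumptions for this sketch, since the actual scheme lives inside the Processor and is not shown in this excerpt.

```python
import numpy as np


def decode_sketch(normalized, stats_min, stats_max, last_state=None):
    """Hypothetical decode: un-normalize from [-1, 1], then optionally add the last state.

    normalized: (B, T, D); stats_min / stats_max: (D,); last_state: (B, D) or None.
    """
    unnormalized = (normalized + 1.0) / 2.0 * (stats_max - stats_min) + stats_min
    if last_state is not None:
        # Assumed convention: relative actions are offsets from the latest observed state.
        unnormalized = unnormalized + last_state[:, None, :]
    return unnormalized.astype(np.float32)


normalized = np.array([[[-1.0, 0.0], [1.0, 0.5]]])  # (B=1, T=2, D=2)
stats_min = np.array([-0.1, -0.2])
stats_max = np.array([0.1, 0.2])
state_history = np.array([[[0.5, 1.0]]], dtype=np.float32)  # (B=1, T=1, D=2)
absolute = decode_sketch(normalized, stats_min, stats_max, last_state=state_history[:, -1, :])
```

Without the state, a relative chunk cannot be turned into absolute joint targets, which is why _get_action() keeps the per-sample states around until the very end.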
for action_key in self.modality_configs["action"].modality_keys:
    assert action_key in action, f"Action key '{action_key}' must be in action"
    action_arr = action[action_key]
    assert isinstance(action_arr, np.ndarray), (
        f"Action key '{action_key}' must be a numpy array. Got {type(action_arr)}"
    )
    assert action_arr.dtype == np.float32, (
        f"Action key '{action_key}' must be a numpy array of type np.float32. Got {action_arr.dtype}"
    )
    assert action_arr.ndim == 3, (
        f"Action key '{action_key}' must be a numpy array of shape (B, T, D), got {action_arr.shape}"
    )
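5. The model's get_action generates the Action Chunk
Gr00tPolicy handles the outer input/output wrapping; the actual action generation happens inside the model. The main model's Gr00tN1d7.get_action() is defined in: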
gr00t/model/gr00t_n1d7/gr00t_n1d7.py
The logic mirrors the training forward(): the inputs are first split between the Backbone and the action head, and then the action head's get_action() is called:
def get_action(self, inputs: dict, options: dict[str, Any] | None = None) -> BatchFeature:
    """
    Generate actions using the complete model.
    """
    backbone_inputs, action_inputs = self.prepare_input(inputs)
    backbone_outputs = self.backbone(backbone_inputs)
    action_outputs = self.action_head.get_action(backbone_outputs, action_inputs, options)
    return action_outputs
The action head first encodes the vision-language features and the state features:
@torch.no_grad()
def get_action(
    self,
    backbone_output: BatchFeature,
    action_input: BatchFeature,
    options: dict[str, Any] | None = None,
) -> BatchFeature:
    features = self._encode_features(backbone_output, action_input)
    return self.get_action_with_features(
        backbone_features=features.backbone_features,
        state_features=features.state_features,
        embodiment_id=action_input.embodiment_id,
        backbone_output=backbone_output,
        action_input=action_input,
        options=options,
    )
Inside get_action_with_features(), the action chunk is initialized with Gaussian noise:
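actions = torch.randn(
    size=(batch_size, self.config.action_horizon, self.action_dim),
    dtype=vl_embeds.dtype,
    device=device,
)
dt = 1.0 / self.num_inference_timesteps
vel_strength = torch.ones_like(actions)
It then integrates the velocity over the configured number of inference steps. At each step the model predicts a velocity from the current action chunk, the timestep, the state tokens, and the vision-language condition, and the action chunk is updated accordingly: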
for t in range(self.num_inference_timesteps):
    t_cont = t / float(self.num_inference_timesteps)
    t_discretized = int(t_cont * self.num_timestep_buckets)
    timesteps_tensor = torch.full(
        size=(batch_size,), fill_value=t_discretized, device=device
    )
    action_features = self.action_encoder(actions, timesteps_tensor, embodiment_id)
    if self.config.add_pos_embed:
        pos_ids = torch.arange(action_features.shape[1], dtype=torch.long, device=device)
        pos_embs = self.position_embedding(pos_ids).unsqueeze(0)
        action_features = action_features + pos_embs
    sa_embs = torch.cat((state_features, action_features), dim=1)
With the standard DiT, the vision-language features are passed directly as the condition:
model_output = self.model(
    hidden_states=sa_embs,
    encoder_hidden_states=vl_embeds,
    timestep=timesteps_tensor,
)
With AlternateVLDiT, image_mask and backbone_attention_mask are passed in as well:
model_output = self.model(
    hidden_states=sa_embs,
    encoder_hidden_states=vl_embeds,
    timestep=timesteps_tensor,
    image_mask=backbone_output.image_mask,
    backbone_attention_mask=backbone_output.backbone_attention_mask,
)
The action decoder outputs a velocity for every token; since sa_embs places the state tokens before the action tokens, only the last action_horizon positions are kept:
pred = self.action_decoder(model_output, embodiment_id)
pred_velocity = pred[:, -self.action_horizon :]
Finally, the actions are updated with an Euler step:
actions = actions + dt * pred_velocity * vel_strength
After the configured number of inference steps, the complete Action Chunk is returned:
return BatchFeature(
    data={
        "action_pred": actions,
        "backbone_features": vl_embeds,
        "state_features": state_features,
    }
)
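The sampling loop can be reduced to a toy in pure Python. Here an analytical velocity field (target − x) / (1 − t) replaces self.model plus self.action_decoder; along a straight path this is the exact velocity, so fixed-step Euler lands on the target after the final step, which makes the integration easy to verify. The target values, the noise vector, and num_timestep_buckets = 1000 are all assumptions.

```python
def timestep_buckets(num_steps: int, num_buckets: int) -> list[int]:
    """Discretize t = k / num_steps into integer bucket ids, as the loop above does."""
    return [int((k / float(num_steps)) * num_buckets) for k in range(num_steps)]


def euler_sample(x0, velocity_fn, num_steps, vel_strength=None):
    """Integrate dx/dt = v(x, t) from t=0 to t=1 with fixed-step Euler; x is a list of floats."""
    dt = 1.0 / num_steps
    strength = vel_strength if vel_strength is not None else [1.0] * len(x0)
    x = list(x0)
    for k in range(num_steps):
        t = k * dt
        v = velocity_fn(x, t)
        # Same form as actions = actions + dt * pred_velocity * vel_strength.
        x = [xi + dt * vi * si for xi, vi, si in zip(x, v, strength)]
    return x


target = [0.3, -0.7, 1.2]  # toy "clean" action chunk (made up)


def toy_velocity(x, t):
    # Straight-line velocity toward the target; a trained DiT would predict this instead.
    return [(g - xi) / (1.0 - t) for g, xi in zip(target, x)]


noise = [1.5, 0.2, -0.4]  # stands in for torch.randn(...)
sample = euler_sample(noise, toy_velocity, num_steps=4)
```

Passing vel_strength=[0.0, 1.0, 1.0] keeps the first element at its initial value for the whole run, which is exactly the lever RTC uses to freeze part of a chunk.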
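In deployment, the controller usually neither executes only the first step of an action chunk nor necessarily executes the whole chunk. How many steps are executed depends on the outer control loop, the action frequency, the model's inference latency, and whether RTC chunk blending is enabled.
6. RTC reduces trajectory jumps by overlapping the old and new action chunks
During continuous control, the model produces a new Action Chunk on every call. If each chunk is generated from fresh Gaussian noise and the controller simply switches over to it, consecutive predictions can jump. The RTC-related logic exists to solve this problem.
In the action head's get_action_with_features(), if action_input contains action, the source code treats RTC as enabled: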
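One common way to consume chunks is receding-horizon execution: run the first few actions of each chunk, then query the policy again. The ChunkExecutor class, the 16-step horizon, and execute_steps = 8 below are hypothetical; this is a synchronous simplification, whereas a real controller may overlap inference with execution.

```python
class ChunkExecutor:
    """Hypothetical control-loop helper: re-plan after executing `execute_steps` actions."""

    def __init__(self, policy_fn, execute_steps: int):
        self.policy_fn = policy_fn          # callable: observation -> list of actions (one chunk)
        self.execute_steps = execute_steps  # actions consumed before the next query
        self.queue = []
        self.num_queries = 0

    def step(self, observation):
        if not self.queue:
            chunk = self.policy_fn(observation)
            self.num_queries += 1
            self.queue = list(chunk[: self.execute_steps])
        return self.queue.pop(0)


ACTION_HORIZON = 16  # assumed chunk length


def fake_policy(observation):
    # Tag each action with the time of the observation that produced it.
    return [(observation, i) for i in range(ACTION_HORIZON)]


executor = ChunkExecutor(fake_policy, execute_steps=8)
executed = [executor.step(t) for t in range(40)]
```

Executing fewer steps per chunk reacts faster to new observations but costs more inference calls; here 40 control steps at 8 actions per chunk need 5 queries.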
if "action" in action_input:
    # If action in input when doing get action, it means we want to use RTC.
    # action_horizon is the action horizon of the input action.
    # rtc_overlap_steps is the number of steps to overlap with the previous action chunks.
    # rtc_frozen_steps is the number of steps to freeze the action, which is the latency of the policy inference.
    # rtc_ramp_rate is the rate of the ramp of denoising the actions.
    assert options is not None, "options is not None"
    assert "action_horizon" in options, "action_horizon is not in options"
    assert "rtc_overlap_steps" in options, "rtc_overlap_steps is not in options"
    assert "rtc_frozen_steps" in options, "rtc_frozen_steps is not in options"
    assert "rtc_ramp_rate" in options, "rtc_ramp_rate is not in options"
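Several parameters must be passed in through options:
action_horizon: the actual (pre-padding) length of the previous action chunk
rtc_overlap_steps: how many steps the old and new action chunks overlap
rtc_frozen_steps: how many leading steps of the overlap region are fully frozen
rtc_ramp_rate: how quickly the velocity-update strength rises across the rest of the overlap region
With RTC enabled, the initial actions are no longer pure noise. The source code copies a tail segment of the previous action chunk into the beginning of the new one: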
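action_horizon_before_padding = options["action_horizon"]
actions[:, : options["rtc_overlap_steps"], :] = action_input["action"][
    :,
    action_horizon_before_padding
    - options["rtc_overlap_steps"] : action_horizon_before_padding,
    :,
]
In effect, the first few steps of the new chunk start from the end of the previous chunk. This keeps two consecutive control outputs continuous over the overlap region instead of generating them from completely random noise.
Next, the source code sets the velocity-update strength of the first rtc_frozen_steps steps to 0: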
vel_strength[:, : options["rtc_frozen_steps"], :] = 0.0
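These positions are fully frozen, so the subsequent Euler integration never modifies them. They usually correspond to actions already scheduled for the period covered by the model's inference latency, which prevents the new inference pass from altering actions that are about to be executed.
The rest of the overlap region does not jump straight from 0 to 1; instead, an exponential ramp gradually increases the velocity-update strength: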
intermediate_steps = options["rtc_overlap_steps"] - options["rtc_frozen_steps"]
t = torch.linspace(0.0, 1.0, intermediate_steps + 2, device=device)
ramp = 1 - torch.exp(-options["rtc_ramp_rate"] * t)
ramp = ramp / ramp[-1].clamp_min(1e-8)
ramp = ramp[1:-1]
vel_strength[
    :,
    options["rtc_frozen_steps"] : options["rtc_overlap_steps"],
    :,
] = ramp[None, :, None].to(device)
Finally, vel_strength enters the update at every velocity-integration step:
actions = actions + dt * pred_velocity * vel_strength
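As a result, the update strength differs by position:
Frozen region: vel_strength = 0, the old actions are not updated
Transition region: vel_strength ramps up gradually from near 0 toward 1
Non-overlapping region: vel_strength = 1, new actions are generated normally
The key point of RTC is not a change to the model architecture. At sampling time it uses the previous action chunk as a partial initial value and controls the update magnitude at each time position through vel_strength. This reduces trajectory jumps between consecutive Action Chunks, which is friendlier to real-robot control.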
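The RTC initialization can be traced step by step in pure Python. This sketch collapses the action dimension to scalars and uses lists instead of tensors; the parameter names mirror the options keys, while the concrete values (horizon 16, overlap 6, frozen 2, ramp rate 5.0) are arbitrary choices for illustration. The ramp math follows the source: linspace with two extra endpoints, 1 − exp(−rate·t), normalization by the last value, then dropping both endpoints.

```python
import math


def linspace(start: float, stop: float, num: int) -> list[float]:
    """Pure-Python equivalent of torch.linspace for num >= 2."""
    step = (stop - start) / (num - 1)
    return [start + i * step for i in range(num)]


def rtc_init(prev_chunk, horizon, prev_horizon, overlap, frozen, ramp_rate, noise):
    """Build the initial chunk and per-step velocity strength for RTC (1-D actions).

    prev_chunk: previous chunk, possibly padded; prev_horizon: its length before padding.
    noise: fresh Gaussian samples of length `horizon`.
    """
    actions = list(noise)
    # Copy the tail of the previous (unpadded) chunk into the head of the new one.
    actions[:overlap] = prev_chunk[prev_horizon - overlap : prev_horizon]

    strength = [1.0] * horizon
    strength[:frozen] = [0.0] * frozen  # frozen steps are never updated
    t = linspace(0.0, 1.0, (overlap - frozen) + 2)
    ramp = [1.0 - math.exp(-ramp_rate * ti) for ti in t]
    ramp = [r / max(ramp[-1], 1e-8) for r in ramp][1:-1]  # normalize, drop both endpoints
    strength[frozen:overlap] = ramp
    return actions, strength


prev = [float(i) for i in range(16)]
actions, strength = rtc_init(prev, horizon=16, prev_horizon=16, overlap=6,
                             frozen=2, ramp_rate=5.0, noise=[0.0] * 16)
```

Dropping the endpoints means the transition region never reaches exactly 0 or 1, so it sits strictly between the frozen steps and the freely generated tail. A larger rtc_ramp_rate pushes the ramp values toward 1 sooner, letting the new prediction take over more quickly.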
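7. PolicyServer wraps a local Policy as a remote inference service
Besides calling Gr00tPolicy.get_action() directly in the same process, the source code also provides a server/client wrapper, in:
gr00t/policy/server_client.py
PolicyServer is built on ZeroMQ. At initialization it binds a port and registers several default endpoints: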
class PolicyServer:
    """
    An inference server that spin up a ZeroMQ socket and listen for incoming requests.
    Can add custom endpoints by calling `register_endpoint`.
    """

    def __init__(
        self,
        policy: BasePolicy,
        host: str = "*",
        port: int = 5555,
        api_token: str = None,
    ):
        self.policy = policy
        self.running = True
        self.context = zmq.Context()
        self.socket = self.context.socket(zmq.REP)
        self.socket.bind(f"tcp://{host}:{port}")
        self._endpoints: dict[str, EndpointHandler] = {}
        self.api_token = api_token
        self.register_endpoint("ping", self._handle_ping, requires_input=False)
        self.register_endpoint("kill", self._kill_server, requires_input=False)
        self.register_endpoint("get_action", self.policy.get_action)
        self.register_endpoint("reset", self.policy.reset)
        self.register_endpoint(
            "get_modality_config",
            getattr(self.policy, "get_modality_config", lambda: {}),
            requires_input=False,
        )
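The most important default endpoints are:
get_action: run Policy inference remotely
reset: reset the Policy state
get_modality_config: query the input/output modalities required by the current checkpoint
ping: check whether the service is available
kill: stop the service
Serialization uses msgpack and msgpack_numpy. The source code defines a dedicated MsgSerializer that refuses numpy arrays with object dtype, since encoding them would invoke pickle: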
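The endpoint table is essentially a name-to-handler dispatch. The transport-free sketch below models that routing; the EndpointRegistry class, the request layout {"endpoint", "data", "api_token"}, and the error replies are assumptions for illustration, not the actual wire format of server_client.py, which additionally wraps this in a zmq.REP socket loop with msgpack serialization.

```python
from __future__ import annotations

from typing import Any, Callable


class EndpointRegistry:
    """Hypothetical, transport-free model of routing requests to registered handlers."""

    def __init__(self, api_token: str | None = None):
        self.api_token = api_token
        self._endpoints: dict[str, tuple[Callable[..., Any], bool]] = {}

    def register_endpoint(self, name: str, handler: Callable[..., Any], requires_input: bool = True) -> None:
        self._endpoints[name] = (handler, requires_input)

    def handle(self, request: dict[str, Any]) -> Any:
        # A REP socket must answer every request, so every branch returns a reply.
        if self.api_token is not None and request.get("api_token") != self.api_token:
            return {"error": "Unauthorized"}
        name = request.get("endpoint")
        if name not in self._endpoints:
            return {"error": f"Unknown endpoint: {name}"}
        handler, requires_input = self._endpoints[name]
        if requires_input:
            return handler(**request.get("data", {}))
        return handler()


registry = EndpointRegistry(api_token="secret")
registry.register_endpoint("ping", lambda: {"status": "ok"}, requires_input=False)
registry.register_endpoint(
    "get_action",
    lambda observation, options=None: ({"joints": [0.0] * len(observation["state"])}, {}),
)
```

Replying even on errors matters with ZeroMQ's REQ/REP pattern: the sockets move in strict send/receive lockstep, so a server that silently drops a request leaves the client blocked.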
class MsgSerializer:
    """msgpack_numpy serializer with a hard ``allow_pickle=False`` boundary."""

    @staticmethod
    def _safe_encode(obj, chain=None):
        if isinstance(obj, np.ndarray) and obj.dtype.kind == "O":
            raise TypeError(
                f"Refusing to encode object-dtype ndarray (shape={obj.shape}); "
                f"msgpack_numpy would invoke pickle. Convert to a concrete "
                f"numeric dtype before sending."
            )
        return mnp.encode(obj, chain=chain)
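A client can catch the object-dtype problem before sending anything. The find_object_arrays helper below is hypothetical (not part of the repository) and uses the same dtype.kind == "O" test as _safe_encode(); the payload keys are illustrative. Object arrays typically appear when strings are wrapped in np.array, and converting them back to plain Python lists also matches the list[list[str]] format that check_observation() expects for language.

```python
import numpy as np


def find_object_arrays(obj, path="root"):
    """Walk a nested payload and list the paths of object-dtype ndarrays (which would need pickle)."""
    found = []
    if isinstance(obj, np.ndarray):
        if obj.dtype.kind == "O":
            found.append(path)
    elif isinstance(obj, dict):
        for k, v in obj.items():
            found.extend(find_object_arrays(v, f"{path}.{k}"))
    elif isinstance(obj, (list, tuple)):
        for i, v in enumerate(obj):
            found.extend(find_object_arrays(v, f"{path}[{i}]"))
    return found


payload = {
    "state": {"joint_position": np.zeros((1, 1, 7), dtype=np.float32)},
    "language": {"task": np.array([["pick up the red cube"]], dtype=object)},
}
bad_paths = find_object_arrays(payload)
# Converting the strings to a plain nested list removes the object array.
payload["language"]["task"] = payload["language"]["task"].tolist()
```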
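The server/client mode suits setups where the inference process is separated from the robot control process. For example, a GPU server runs PolicyServer, while the robot control side uses a client to send Observations and receive Actions. The control side then does not have to load the large model itself, and inference can be deployed on a separate machine.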