1.演示 2. 数据打包
3. ACT 训练
4. 策略部署
5. 自主采摘
1. 数据采集 (Data Collection)
这是最关键的一步,决定了机器人的"上限"。
-
硬件部署: 在机械臂上手安装 1-2 个腕部摄像头 (近距离观察果实和梗),并在田垄旁架设 1-2 个全局摄像头(定位果树位置)。
-
人类演示: 操作员通过遥控设备(如 VR 手柄、示教器或影领系统)操纵机器人采摘。
-
关键点: 演示必须包含完整的"接近果实 -> 调整姿态 -> 夹取/剪切 -> 放入篮筐"全过程。
-
数据量: 农业场景通常需要 50-100次成功的演示来覆盖不同的生长姿态。
-
-
记录内容: 每一帧都要同步记录:
-
所有摄像头的 RGB 图像。
-
机器人的 关节角度 (Joint States)。
-
操作员给出的 动作指令 (Actions)(即下一刻的角度目标)。
-
录制数据的具体"姿势"
如果使用的是松灵官方提供的 ROS 包,数据采集脚本大概逻辑如下:
1.启动 RealSense:
python
roslaunch realsense2_camera rs_camera.launch
2.启动 AUBO 重力补偿: 在遨博控制柜上手动开启"示教模式"或"重力补偿"。
3.录制流程:
- 运行一个 Python 脚本,创建一个空的
.h5文件 - 设置频率: 确保以 20Hz 或 30Hz 的频率读取
/joint_states和/camera/color/image_raw话题 - 注意动作终点: 采摘完成后,手离开机械臂前要先
Ctrl+C停止录制,否则会录入一段"手松开后机械臂轻微晃动"的无效数据
4.数据质量检查:
录制完后,可以使用以下命令查看 H5 文件内容,确认维度是否为 :
python
import h5py
with h5py.File('episode_0.h5', 'r') as f:
print(f['/observations/qpos'].shape) # 应显示 (帧数, 7)
采集脚本代码 (collect_aubo_pika.py)
python
#!/usr/bin/env python3
import rospy
import h5py
import numpy as np
import cv2
import os
import message_filters
from cv_bridge import CvBridge
from sensor_msgs.msg import Image, JointState
from std_msgs.msg import Float32 # 假设夹爪位置通过 Float32 发布
class AuboPikaDataCollector:
def __init__(self, save_dir, episode_idx, hz=20):
rospy.init_node('aubo_pika_collector')
self.bridge = CvBridge()
self.hz = hz
self.save_path = os.path.join(save_dir, f'episode_{episode_idx}.h5')
# 缓存容器
self.obs_qpos = []
self.obs_images = []
self.actions = []
# --- 1. 定义订阅者 (请根据 rostopic list 确认名称) ---
# RealSense 彩色图
self.color_sub = message_filters.Subscriber("/camera/color/image_raw", Image)
# AUBO 关节状态 (通常为 6 轴)
self.arm_sub = message_filters.Subscriber("/aubo_driver/joint_states", JointState)
# Pika 夹爪状态 (1 轴)
self.gripper_sub = message_filters.Subscriber("/pika_driver/gripper_pos", Float32)
# --- 2. 时间同步器 ---
# slop=0.05 表示允许传感器之间有 50ms 的误差
self.ts = message_filters.ApproximateTimeSynchronizer(
[self.color_sub, self.arm_sub, self.gripper_sub],
queue_size=10, slop=0.05
)
self.ts.registerCallback(self.callback)
print(f"录制环境就绪!当前保存至: {self.save_path}")
print("提示:请确保 AUBO 已进入重力补偿模式。按 Ctrl+C 停止录制。")
def callback(self, color_msg, arm_msg, gripper_msg):
# 3. 处理图像 (转换为 BGR 格式)
color_img = self.bridge.imgmsg_to_cv2(color_msg, "bgr8")
# 如果 RealSense 分辨率太高,建议在存储前 resize 以节省空间和训练时间
color_img = cv2.resize(color_img, (640, 480))
# 4. 拼接状态向量 (6 + 1 = 7维)
# 假设 arm_msg.position 前 6 位是关节角度
aubo_qpos = list(arm_msg.position[:6])
# 将夹爪位置加入列表
full_qpos = aubo_qpos + [gripper_msg.data]
# 暂存数据
self.obs_images.append(color_img)
self.obs_qpos.append(full_qpos)
self.actions.append(full_qpos)
if len(self.obs_qpos) % 20 == 0:
rospy.loginfo(f"已录制 {len(self.obs_qpos)} 帧...")
def save_to_h5(self):
if len(self.obs_qpos) < 20:
print("录制数据太短,已放弃保存。")
return
# 5. ACT 核心对齐逻辑:Action(t) = Qpos(t+1)
# 这种对齐方式让模型学习"预测未来一步的状态"
qpos_np = np.array(self.obs_qpos)
image_np = np.array(self.obs_images)
action_np = np.array(self.actions)
# 对齐:去除最后一帧观察,去除第一帧动作
final_qpos = qpos_np[:-1]
final_image = image_np[:-1]
final_action = action_np[1:]
print(f"正在保存文件至 {self.save_path}...")
with h5py.File(self.save_path, 'w') as f:
f.attrs['sim'] = False # 标记为真实机器人数据
obs = f.create_group('observations')
images = obs.create_group('images')
# 符合官方 ACT 命名规范
images.create_dataset('cam_wrist', data=final_image, compression="gzip", compression_opts=4)
obs.create_dataset('qpos', data=final_qpos)
# Action 是模型预测的目标值
f.create_dataset('action', data=final_action)
print(f"保存完成!共 {len(final_qpos)} 步数据。")
if __name__ == '__main__':
# 路径和编号设置
DATA_DIR = "./harvest_data_v1"
EP_ID = input("请输入本次演示的编号 (如 0): ")
if not os.path.exists(DATA_DIR):
os.makedirs(DATA_DIR)
recorder = AuboPikaDataCollector(DATA_DIR, EP_ID)
try:
rospy.spin()
except KeyboardInterrupt:
print("\n停止信号接收。")
finally:
recorder.save_to_h5()
2. 数据预处理与打包
将原始记录转化为 ACT 模型能吃的"格式"
确保你的文件层级结构如下,这样可以直接适配大多数开源的 ACT 训练代码:
-
/observations/qpos(Shape:)
-
/observations/images/cam_wrist(Shape:)
-
/action(Shape:)
A:Action Chunking 处理: 将连续的动作切分成块(例如每 步为一个 Chunk)
-
逻辑: 对于时刻
,模型不仅要预测
的动作,还要预测
这一整块动作。
-
参数
: 通常为16 / 32 / 50。如果你的采摘动作很慢,可以适当调大。
B:数据归一化: 模型无法直接处理弧度制下的角度(如 和
),需要将它们映射到同一个区间。将关节角度和动作值缩放到
之间,方便模型收敛。
-
操作: 遍历所有
episode_xxx.h5文件,计算所有关节角度 (qpos) 和动作 (action) 的平均值 (Mean) 和 标准差 (Std)。 -
打包逻辑: 训练时,输入模型的数据减去均值再除以标准差。
C:数据增强: 对农业图像进行随机亮度、对比度调整,模拟大棚内不同时间段的光照变化
-
缩放与裁剪: 将图像统一缩放到模型要求的尺寸(通常是
或
)
-
颜色抖动 (Color Jitter): 随机改变图像的亮度、对比度和饱和度
1. 统计量计算脚本 (compute_stats.py)
在开始训练前,运行此脚本,它会遍历所有.h5演示文件,计算 AUBO 6轴 + Pika 1爪(共7维)的均值和标准差,并保存为 stats.pkl
python
import os
import h5py
import numpy as np
import pickle
def compute_stats(dataset_dir):
all_qpos = []
all_action = []
# 获取所有 h5 文件
episode_files = [f for f in os.listdir(dataset_dir) if f.endswith('.h5')]
for f_name in episode_files:
with h5py.File(os.path.join(dataset_dir, f_name), 'r') as root:
qpos = root['/observations/qpos'][:]
action = root['/action'][:]
all_qpos.append(qpos)
all_action.append(action)
all_qpos = np.concatenate(all_qpos, axis=0)
all_action = np.concatenate(all_action, axis=0)
# 计算均值和标准差
stats = {
'qpos_mean': all_qpos.mean(axis=0),
'qpos_std': all_qpos.std(axis=0),
'action_mean': all_action.mean(axis=0),
'action_std': all_action.std(axis=0),
}
# 这里的 std 如果过小(比如夹爪没动过),需要设为 1 防止除以 0
stats['qpos_std'] = np.clip(stats['qpos_std'], 1e-2, None)
stats['action_std'] = np.clip(stats['action_std'], 1e-2, None)
with open(os.path.join(dataset_dir, 'dataset_stats.pkl'), 'wb') as f:
pickle.dump(stats, f)
print(f"统计量计算完成,已保存至 {dataset_dir}/dataset_stats.pkl")
return stats
# 使用示例
# compute_stats('./harvest_data_v1')
记得在训练主程序中定义:
python
from torch.utils.data import DataLoader
# 划分训练/验证集
all_ids = [0, 1, 2, 3, ...]
train_ids = all_ids[:80]
# 加载统计量
with open('dataset_stats.pkl', 'rb') as f:
stats = pickle.load(f)
# 创建 DataLoader
train_dataset = ACTHarvestingDataset('./harvest_data_v1', train_ids, chunk_size=100, stats=stats)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
2. PyTorch Dataset 类 (dataset.py)
这是喂给模型的核心打包类。它包含实时图像增强 、归一化 以及 ACT 特有的 Action Chunking(动作切片)逻辑
python
import torch
from torch.utils.data import Dataset
import torchvision.transforms as T
import h5py
import numpy as np
from PIL import Image as PILImage
class HarvestingDataset(Dataset):
def __init__(self, episode_ids, dataset_dir, chunk_size, norm_stats, train=True):
self.episode_ids = episode_ids
self.dataset_dir = dataset_dir
self.chunk_size = chunk_size
self.train = train
# 归一化参数:需提前计算好 qpos 的 mean 和 std
self.norm_stats = norm_stats
# --- 数据增强 Pipeline ---
if self.train:
self.transform = T.Compose([
T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.05), # 模拟光照变化
T.RandomGrayscale(p=0.02), # 模拟极端光照情况
T.ToTensor(), # 转为 0-1 的 Tensor,并移动通道至 [C, H, W]
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # ImageNet 标准归一化
])
else:
self.transform = T.Compose([
T.ToTensor(),
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
def __getitem__(self, index):
episode_id = self.episode_ids[index]
with h5py.File(f'{self.dataset_dir}/episode_{episode_id}.h5', 'r') as root:
# 1. 采样起始点
# 确保采样范围不会让 chunk 超出总长度
max_start = len(root['/action']) - self.chunk_size
start_ts = np.random.randint(0, max_start) if self.train else 0
# 2. 获取图像并增强
raw_img = root['/observations/images/cam_wrist'][start_ts]
# 注意:h5 存的是 numpy 数组 (H,W,C),需要转成 PIL 做 transform
image_pil = PILImage.fromarray(raw_img)
image_tensor = self.transform(image_pil)
# 3. 获取 qpos 并归一化
qpos = root['/observations/qpos'][start_ts]
qpos_norm = (qpos - self.norm_stats['qpos_mean']) / self.norm_stats['qpos_std']
# 4. 获取 Action Chunk 并归一化
action_chunk = root['/action'][start_ts : start_ts + self.chunk_size]
action_norm = (action_chunk - self.norm_stats['action_mean']) / self.norm_stats['action_std']
return image_tensor, torch.tensor(qpos_norm), torch.tensor(action_norm)
def __len__(self):
return len(self.episode_ids)
-
遍历所有的
.h5文件 -
算出所有关节角度的
mean(平均值)和std(标准差) -
存成一个
stats.pkl -
训练时把这个
stats传给上面的HarvestingDataset类
3. 模型训练 (Training)
在服务器上训练 ACT 架构
-
输入配置: 将图像 输入给主干网络(如 ResNet 或 EfficientNet)、将当前的关节角度作为本体感知输入
-
损失函数: 同时优化重建损失 (
预测动作与演示动作的差距)和 KL 散度 (让隐变量
符合正态分布)
重建损失: L1 Loss 通常在 0.5 - 1.0 之间,收敛后期应降至 0.01 - 0.05
KL 散度: 通常在训练初期会从 0 快速飙升到 10 - 50 ,随着训练稳定,它会维持在一个固定区间(例如 20 左右)
-
关键参数调试:
-
chunk_size: ACT 官方默认 16/32/50 -
kl_weight: 调节它以平衡动作的精确度和多样性
-
4. 推理与实时控制 (Inference)
当你训练完成后,代码会保存一个 .pth 或 .ckpt 文件,将训练好的模型部署到采摘现场。
-
时间集成 (Temporal Ensemble): 开启此功能。模型会不断预测重叠的动作块并加权平均。
- 作用: 当采摘过程中果实随风晃动时,时间集成能让机器人更平滑地修正路径,而不是猛烈抖动。
-
观测循环: 1. 摄像头捕捉当前画面。 2. 模型生成未来 100 步动作。 3. 结合之前留下的预测,计算出当前这一步的最佳角度。 4. 发送给电机执行。
python
import torch
import numpy as np
from copy import deepcopy
import pickle
class ACTInferenceNode:
def __init__(self, model_path, stats_path):
# 1. 加载统计量
with open(stats_path, 'rb') as f:
self.stats = pickle.load(f)
# 2. 加载模型
self.model = make_ACT_model()
self.model.load_state_dict(torch.load(model_path))
self.model.cuda().eval()
self.chunk_size = 100 # 预测长度
self.ensemble_window = 20 # 平滑窗口
self.exec_steps = 5 # 【关键】每次执行前5步
self.action_dim = 7
# 正确的 temporal ensemble 缓存
self.all_time_actions = torch.zeros(
[self.chunk_size, self.ensemble_window, self.action_dim]
).cuda()
self.step_idx = 0 # 全局步数
def process_and_run(self, color_image, curr_qpos):
# --- A. 预处理 ---
img_tensor = self.preprocess_img(color_image).cuda()
qpos_norm = (curr_qpos - self.stats['qpos_mean']) / self.stats['qpos_std']
qpos_tensor = torch.from_numpy(qpos_norm).float().cuda().unsqueeze(0)
# --- B. 模型预测 ---
with torch.no_grad():
pred_action_norm = self.model(qpos_tensor, img_tensor) # [1,100,7]
# --- C. 反归一化 ---
pred_action = pred_action_norm * self.stats['action_std'] + self.stats['action_mean']
pred_action = pred_action.squeeze(0) # [100,7]
# --- D. 【核心】Temporal Ensemble ---
self.all_time_actions = torch.roll(self.all_time_actions, shifts=-1, dims=1)
self.all_time_actions[:, -1] = pred_action
# 加权平均
num_mem = self.ensemble_window - torch.sum(
torch.all(self.all_time_actions == 0, dim=-1), dim=-1
).float()
smoothed = torch.sum(self.all_time_actions, dim=1) / num_mem.unsqueeze(-1)
# --- E. 【最关键】取前 N 步执行 ---
exec_actions = smoothed[self.step_idx : self.step_idx + self.exec_steps]
# --- F. 下发机械臂,连续执行 N 步 ---
for act in exec_actions:
self.publish_to_robot(act.cpu().numpy())
self.step_idx += self.exec_steps
def publish_to_robot(self, joints):
# 下发给遨博机械臂
pass