aubo i5+pika realsense+ACT训练完整流程

1.演示 2. 数据打包 3. ACT 训练 4. 策略部署 5. 自主采摘

1. 数据采集 (Data Collection)

这是最关键的一步,决定了机器人的"上限"。

  • 硬件部署: 在机械臂上手安装 1-2 个腕部摄像头 (近距离观察果实和梗),并在田垄旁架设 1-2 个全局摄像头(定位果树位置)。

  • 人类演示: 操作员通过遥控设备(如 VR 手柄、示教器或影领系统)操纵机器人采摘。

    • 关键点: 演示必须包含完整的"接近果实 -> 调整姿态 -> 夹取/剪切 -> 放入篮筐"全过程。

    • 数据量: 农业场景通常需要 50-100次成功的演示来覆盖不同的生长姿态。

  • 记录内容: 每一帧都要同步记录:

    1. 所有摄像头的 RGB 图像

    2. 机器人的 关节角度 (Joint States)

    3. 操作员给出的 动作指令 (Actions)(即下一刻的角度目标)。

录制数据的具体"姿势"

如果使用的是松灵官方提供的 ROS 包,数据采集脚本大概逻辑如下:

1.启动 RealSense:

python 复制代码
roslaunch realsense2_camera rs_camera.launch

2.启动 AUBO 重力补偿: 在遨博控制柜上手动开启"示教模式"或"重力补偿"。

3.录制流程:

  1. 运行一个 Python 脚本,创建一个空的 .h5 文件
  2. 设置频率: 确保以 20Hz 或 30Hz 的频率读取 /joint_states/camera/color/image_raw 话题
  3. 注意动作终点: 采摘完成后,手离开机械臂前要先Ctrl+C停止录制,否则会录入一段"手松开后机械臂轻微晃动"的无效数据

4.数据质量检查:

录制完后,可以使用以下命令查看 H5 文件内容,确认维度是否为

python 复制代码
import h5py
with h5py.File('episode_0.h5', 'r') as f:
    print(f['/observations/qpos'].shape) # 应显示 (帧数, 7)

采集脚本代码 (collect_aubo_pika.py)

python 复制代码
#!/usr/bin/env python3
import rospy
import h5py
import numpy as np
import cv2
import os
import message_filters
from cv_bridge import CvBridge
from sensor_msgs.msg import Image, JointState
from std_msgs.msg import Float32 # 假设夹爪位置通过 Float32 发布

class AuboPikaDataCollector:
    def __init__(self, save_dir, episode_idx, hz=20):
        rospy.init_node('aubo_pika_collector')
        self.bridge = CvBridge()
        self.hz = hz
        self.save_path = os.path.join(save_dir, f'episode_{episode_idx}.h5')
        
        # 缓存容器
        self.obs_qpos = []
        self.obs_images = []
        self.actions = []

        # --- 1. 定义订阅者 (请根据 rostopic list 确认名称) ---
        # RealSense 彩色图
        self.color_sub = message_filters.Subscriber("/camera/color/image_raw", Image)
        # AUBO 关节状态 (通常为 6 轴)
        self.arm_sub = message_filters.Subscriber("/aubo_driver/joint_states", JointState)
        # Pika 夹爪状态 (1 轴)
        self.gripper_sub = message_filters.Subscriber("/pika_driver/gripper_pos", Float32)

        # --- 2. 时间同步器 ---
        # slop=0.05 表示允许传感器之间有 50ms 的误差
        self.ts = message_filters.ApproximateTimeSynchronizer(
            [self.color_sub, self.arm_sub, self.gripper_sub], 
            queue_size=10, slop=0.05
        )
        self.ts.registerCallback(self.callback)

        print(f"录制环境就绪!当前保存至: {self.save_path}")
        print("提示:请确保 AUBO 已进入重力补偿模式。按 Ctrl+C 停止录制。")

    def callback(self, color_msg, arm_msg, gripper_msg):
        # 3. 处理图像 (转换为 BGR 格式)
        color_img = self.bridge.imgmsg_to_cv2(color_msg, "bgr8")
        # 如果 RealSense 分辨率太高,建议在存储前 resize 以节省空间和训练时间
        color_img = cv2.resize(color_img, (640, 480))

        # 4. 拼接状态向量 (6 + 1 = 7维)
        # 假设 arm_msg.position 前 6 位是关节角度
        aubo_qpos = list(arm_msg.position[:6])
        # 将夹爪位置加入列表
        full_qpos = aubo_qpos + [gripper_msg.data]

        # 暂存数据
        self.obs_images.append(color_img)
        self.obs_qpos.append(full_qpos)
        self.actions.append(full_qpos)

        if len(self.obs_qpos) % 20 == 0:
            rospy.loginfo(f"已录制 {len(self.obs_qpos)} 帧...")

    def save_to_h5(self):
        if len(self.obs_qpos) < 20:
            print("录制数据太短,已放弃保存。")
            return

        # 5. ACT 核心对齐逻辑:Action(t) = Qpos(t+1)
        # 这种对齐方式让模型学习"预测未来一步的状态"
        qpos_np = np.array(self.obs_qpos)
        image_np = np.array(self.obs_images)
        action_np = np.array(self.actions)

        # 对齐:去除最后一帧观察,去除第一帧动作
        final_qpos = qpos_np[:-1]
        final_image = image_np[:-1]
        final_action = action_np[1:]

        print(f"正在保存文件至 {self.save_path}...")
        with h5py.File(self.save_path, 'w') as f:
            f.attrs['sim'] = False # 标记为真实机器人数据
            
            obs = f.create_group('observations')
            images = obs.create_group('images')
            
            # 符合官方 ACT 命名规范
            images.create_dataset('cam_wrist', data=final_image, compression="gzip", compression_opts=4)
            obs.create_dataset('qpos', data=final_qpos)
            
            # Action 是模型预测的目标值
            f.create_dataset('action', data=final_action)

        print(f"保存完成!共 {len(final_qpos)} 步数据。")

if __name__ == '__main__':
    # 路径和编号设置
    DATA_DIR = "./harvest_data_v1"
    EP_ID = input("请输入本次演示的编号 (如 0): ")
    
    if not os.path.exists(DATA_DIR):
        os.makedirs(DATA_DIR)

    recorder = AuboPikaDataCollector(DATA_DIR, EP_ID)
    
    try:
        rospy.spin()
    except KeyboardInterrupt:
        print("\n停止信号接收。")
    finally:
        recorder.save_to_h5()

2. 数据预处理与打包

将原始记录转化为 ACT 模型能吃的"格式"

确保你的文件层级结构如下,这样可以直接适配大多数开源的 ACT 训练代码:

  • /observations/qpos (Shape: )

  • /observations/images/cam_wrist (Shape: )

  • /action (Shape: )

A:Action Chunking 处理: 将连续的动作切分成块(例如每 步为一个 Chunk)

  • 逻辑: 对于时刻,模型不仅要预测的动作,还要预测 这一整块动作。

  • 参数 通常为16 / 32 / 50。如果你的采摘动作很慢,可以适当调大。

B:数据归一化: 模型无法直接处理弧度制下的角度(如 ),需要将它们映射到同一个区间。将关节角度和动作值缩放到之间,方便模型收敛。

  • 操作: 遍历所有 episode_xxx.h5 文件,计算所有关节角度 (qpos) 和动作 (action) 的平均值 (Mean)标准差 (Std)

  • 打包逻辑: 训练时,输入模型的数据减去均值再除以标准差。

C:数据增强: 对农业图像进行随机亮度、对比度调整,模拟大棚内不同时间段的光照变化

  • 缩放与裁剪: 将图像统一缩放到模型要求的尺寸(通常是

  • 颜色抖动 (Color Jitter): 随机改变图像的亮度、对比度和饱和度

1. 统计量计算脚本 (compute_stats.py)

在开始训练前,运行此脚本,它会遍历所有.h5演示文件,计算 AUBO 6轴 + Pika 1爪(共7维)的均值和标准差,并保存为 stats.pkl

python 复制代码
import os
import h5py
import numpy as np
import pickle

def compute_stats(dataset_dir):
    all_qpos = []
    all_action = []

    # 获取所有 h5 文件
    episode_files = [f for f in os.listdir(dataset_dir) if f.endswith('.h5')]
    
    for f_name in episode_files:
        with h5py.File(os.path.join(dataset_dir, f_name), 'r') as root:
            qpos = root['/observations/qpos'][:]
            action = root['/action'][:]
            all_qpos.append(qpos)
            all_action.append(action)

    all_qpos = np.concatenate(all_qpos, axis=0)
    all_action = np.concatenate(all_action, axis=0)

    # 计算均值和标准差
    stats = {
        'qpos_mean': all_qpos.mean(axis=0),
        'qpos_std': all_qpos.std(axis=0),
        'action_mean': all_action.mean(axis=0),
        'action_std': all_action.std(axis=0),
    }

    # 这里的 std 如果过小(比如夹爪没动过),需要设为 1 防止除以 0
    stats['qpos_std'] = np.clip(stats['qpos_std'], 1e-2, None)
    stats['action_std'] = np.clip(stats['action_std'], 1e-2, None)

    with open(os.path.join(dataset_dir, 'dataset_stats.pkl'), 'wb') as f:
        pickle.dump(stats, f)
    
    print(f"统计量计算完成,已保存至 {dataset_dir}/dataset_stats.pkl")
    return stats

# 使用示例
# compute_stats('./harvest_data_v1')

记得在训练主程序中定义:

python 复制代码
from torch.utils.data import DataLoader

# 划分训练/验证集
all_ids = [0, 1, 2, 3, ...] 
train_ids = all_ids[:80]

# 加载统计量
with open('dataset_stats.pkl', 'rb') as f:
    stats = pickle.load(f)

# 创建 DataLoader
train_dataset = ACTHarvestingDataset('./harvest_data_v1', train_ids, chunk_size=100, stats=stats)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)

2. PyTorch Dataset 类 (dataset.py)

这是喂给模型的核心打包类。它包含实时图像增强归一化 以及 ACT 特有的 Action Chunking(动作切片)逻辑

python 复制代码
import torch
from torch.utils.data import Dataset
import torchvision.transforms as T
import h5py
import numpy as np
from PIL import Image as PILImage

class HarvestingDataset(Dataset):
    def __init__(self, episode_ids, dataset_dir, chunk_size, norm_stats, train=True):
        self.episode_ids = episode_ids
        self.dataset_dir = dataset_dir
        self.chunk_size = chunk_size
        self.train = train
        
        # 归一化参数:需提前计算好 qpos 的 mean 和 std
        self.norm_stats = norm_stats 

        # --- 数据增强 Pipeline ---
        if self.train:
            self.transform = T.Compose([
                T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.05), # 模拟光照变化
                T.RandomGrayscale(p=0.02), # 模拟极端光照情况
                T.ToTensor(), # 转为 0-1 的 Tensor,并移动通道至 [C, H, W]
                T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # ImageNet 标准归一化
            ])
        else:
            self.transform = T.Compose([
                T.ToTensor(),
                T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])

    def __getitem__(self, index):
        episode_id = self.episode_ids[index]
        with h5py.File(f'{self.dataset_dir}/episode_{episode_id}.h5', 'r') as root:
            # 1. 采样起始点
            # 确保采样范围不会让 chunk 超出总长度
            max_start = len(root['/action']) - self.chunk_size
            start_ts = np.random.randint(0, max_start) if self.train else 0
            
            # 2. 获取图像并增强
            raw_img = root['/observations/images/cam_wrist'][start_ts]
            # 注意:h5 存的是 numpy 数组 (H,W,C),需要转成 PIL 做 transform
            image_pil = PILImage.fromarray(raw_img)
            image_tensor = self.transform(image_pil)
            
            # 3. 获取 qpos 并归一化
            qpos = root['/observations/qpos'][start_ts]
            qpos_norm = (qpos - self.norm_stats['qpos_mean']) / self.norm_stats['qpos_std']
            
            # 4. 获取 Action Chunk 并归一化
            action_chunk = root['/action'][start_ts : start_ts + self.chunk_size]
            action_norm = (action_chunk - self.norm_stats['action_mean']) / self.norm_stats['action_std']
            
        return image_tensor, torch.tensor(qpos_norm), torch.tensor(action_norm)

    def __len__(self):
        return len(self.episode_ids)
  • 遍历所有的 .h5 文件

  • 算出所有关节角度的 mean(平均值)和 std(标准差)

  • 存成一个 stats.pkl

  • 训练时把这个 stats 传给上面的 HarvestingDataset

3. 模型训练 (Training)

在服务器上训练 ACT 架构

  • 输入配置:图像 输入给主干网络(如 ResNet 或 EfficientNet)、将当前的关节角度作为本体感知输入

  • 损失函数: 同时优化重建损失预测动作与演示动作的差距)和 KL 散度 (让隐变量符合正态分布)

重建损失: L1 Loss 通常在 0.5 - 1.0 之间,收敛后期应降至 0.01 - 0.05

KL 散度: 通常在训练初期会从 0 快速飙升到 10 - 50 ,随着训练稳定,它会维持在一个固定区间(例如 20 左右

  • 关键参数调试:

    • chunk_size: ACT 官方默认 16/32/50

    • kl_weight: 调节它以平衡动作的精确度和多样性

4. 推理与实时控制 (Inference)

当你训练完成后,代码会保存一个 .pth.ckpt 文件,将训练好的模型部署到采摘现场。

  • 时间集成 (Temporal Ensemble): 开启此功能。模型会不断预测重叠的动作块并加权平均。

    • 作用: 当采摘过程中果实随风晃动时,时间集成能让机器人更平滑地修正路径,而不是猛烈抖动。
  • 观测循环: 1. 摄像头捕捉当前画面。 2. 模型生成未来 100 步动作。 3. 结合之前留下的预测,计算出当前这一步的最佳角度。 4. 发送给电机执行。

python 复制代码
import torch
import numpy as np
from copy import deepcopy
import pickle

class ACTInferenceNode:
    def __init__(self, model_path, stats_path):
        # 1. 加载统计量
        with open(stats_path, 'rb') as f:
            self.stats = pickle.load(f)
        
        # 2. 加载模型
        self.model = make_ACT_model()
        self.model.load_state_dict(torch.load(model_path))
        self.model.cuda().eval()

    
        self.chunk_size = 100    # 预测长度
        self.ensemble_window = 20  # 平滑窗口
        self.exec_steps = 5     # 【关键】每次执行前5步
        self.action_dim = 7

        # 正确的 temporal ensemble 缓存
        self.all_time_actions = torch.zeros(
            [self.chunk_size, self.ensemble_window, self.action_dim]
        ).cuda()

        self.step_idx = 0  # 全局步数

    def process_and_run(self, color_image, curr_qpos):
        # --- A. 预处理 ---
        img_tensor = self.preprocess_img(color_image).cuda()
        qpos_norm = (curr_qpos - self.stats['qpos_mean']) / self.stats['qpos_std']
        qpos_tensor = torch.from_numpy(qpos_norm).float().cuda().unsqueeze(0)

        # --- B. 模型预测 ---
        with torch.no_grad():
            pred_action_norm = self.model(qpos_tensor, img_tensor)  # [1,100,7]
        
        # --- C. 反归一化 ---
        pred_action = pred_action_norm * self.stats['action_std'] + self.stats['action_mean']
        pred_action = pred_action.squeeze(0)  # [100,7]

        # --- D. 【核心】Temporal Ensemble ---
        self.all_time_actions = torch.roll(self.all_time_actions, shifts=-1, dims=1)
        self.all_time_actions[:, -1] = pred_action

        # 加权平均
        num_mem = self.ensemble_window - torch.sum(
            torch.all(self.all_time_actions == 0, dim=-1), dim=-1
        ).float()
        smoothed = torch.sum(self.all_time_actions, dim=1) / num_mem.unsqueeze(-1)

        # --- E. 【最关键】取前 N 步执行 ---
        exec_actions = smoothed[self.step_idx : self.step_idx + self.exec_steps]
        
        # --- F. 下发机械臂,连续执行 N 步 ---
        for act in exec_actions:
            self.publish_to_robot(act.cpu().numpy())

        self.step_idx += self.exec_steps

    def publish_to_robot(self, joints):
        # 下发给遨博机械臂
        pass
相关推荐
我不是懒洋洋2 小时前
【经典题目】链表OJ(轮转数组、返回倒数第k个节点、链表的回文结构)
c语言·开发语言·数据结构·算法·链表·visual studio
张小泡泡2 小时前
Graph Retrieval-Augmented Generation: A Survey
论文阅读·人工智能·rag·graphrag
kyle~2 小时前
字节序---大端与小端
c++·机器人
鱼鳞_2 小时前
Java学习笔记_Day30(File)
笔记·学习
2401_832298102 小时前
OpenClaw×HappyHorse 深度融合:AI 视频自动化量产,重构内容生产范式
人工智能·安全
Allen_LVyingbo2 小时前
《狄拉克符号法50讲》习题与解析(上)
开发语言·人工智能·python·数学建模·量子计算
AC赳赳老秦2 小时前
OpenClaw对接百度指数:关键词热度分析,精准定位博客创作方向
java·python·算法·百度·dubbo·deepseek·openclaw
llm大模型算法工程师weng2 小时前
高校数据中台:驱动智慧校园从“联通”走向“智治”
人工智能
波动几何2 小时前
多人游戏引擎框架gamebox
人工智能