视觉语言导航从入门到精通(三)

视觉语言导航从入门到精通(三):核心模型架构详解

本文是「视觉语言导航从入门到精通」系列的第三篇,深入讲解VLN的核心模型架构和关键技术。


文章目录

  • [1. VLN模型总体框架](#1. VLN模型总体框架)
  • [2. 编码器模块](#2. 编码器模块)
  • [3. 跨模态融合](#3. 跨模态融合)
  • [4. 动作解码与决策](#4. 动作解码与决策)
  • [5. 经典模型详解](#5. 经典模型详解)
  • [6. PyTorch实现](#6. PyTorch实现)
  • [7. 数学原理深入](#7. 数学原理深入)
  • [8. 训练技巧与实践经验](#8. 训练技巧与实践经验)

1. VLN模型总体框架

1.1 基础架构

VLN Agent 架构
Language Encoder Cross-Modal Fusion Visual Encoder Action Decoder Action: 前进/左转/右转/停止

模块 功能
Language Encoder 编码自然语言指令
Visual Encoder 编码视觉观察
Cross-Modal Fusion 融合语言和视觉特征
Action Decoder 解码生成导航动作

1.2 导航循环

python 复制代码
# VLN导航的基本循环
def navigate(agent, instruction, env):
    """
    VLN导航主循环
    """
    # 1. 编码指令(只需一次)
    lang_features = agent.encode_language(instruction)

    # 2. 初始化状态
    hidden_state = agent.init_state()
    done = False
    trajectory = []

    while not done:
        # 3. 获取当前视觉观察
        observation = env.get_observation()

        # 4. 编码视觉特征
        visual_features = agent.encode_visual(observation)

        # 5. 跨模态融合
        fused_features = agent.fuse(lang_features, visual_features, hidden_state)

        # 6. 预测动作
        action, hidden_state = agent.decode_action(fused_features)

        # 7. 执行动作
        env.step(action)
        trajectory.append(action)

        # 8. 检查是否结束
        if action == 'STOP' or len(trajectory) > MAX_STEPS:
            done = True

    return trajectory

2. 编码器模块

2.1 语言编码器

LSTM编码器(经典方法)
python 复制代码
import torch
import torch.nn as nn

class LSTMLanguageEncoder(nn.Module):
    """基于LSTM的语言编码器"""

    def __init__(self, vocab_size, embed_dim=256, hidden_dim=512,
                 num_layers=2, dropout=0.5):
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.lstm = nn.LSTM(
            embed_dim,
            hidden_dim // 2,  # 双向LSTM
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
            dropout=dropout
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_ids, lengths):
        """
        Args:
            input_ids: [batch, seq_len]
            lengths: [batch] 每个序列的实际长度
        Returns:
            outputs: [batch, seq_len, hidden_dim] 每个token的表示
            final: [batch, hidden_dim] 句子级别表示
        """
        # 词嵌入
        embeds = self.dropout(self.embedding(input_ids))

        # Pack序列
        packed = nn.utils.rnn.pack_padded_sequence(
            embeds, lengths.cpu(), batch_first=True, enforce_sorted=False
        )

        # LSTM编码
        outputs, (h_n, c_n) = self.lstm(packed)

        # Unpack
        outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True)

        # 拼接双向最后隐藏状态
        final = torch.cat([h_n[-2], h_n[-1]], dim=-1)

        return outputs, final
BERT编码器(现代方法)
python 复制代码
from transformers import BertModel, BertTokenizer

class BERTLanguageEncoder(nn.Module):
    """基于BERT的语言编码器"""

    def __init__(self, bert_model='bert-base-uncased', finetune=True):
        super().__init__()

        self.bert = BertModel.from_pretrained(bert_model)
        self.hidden_dim = self.bert.config.hidden_size  # 768

        if not finetune:
            for param in self.bert.parameters():
                param.requires_grad = False

    def forward(self, input_ids, attention_mask):
        """
        Args:
            input_ids: [batch, seq_len]
            attention_mask: [batch, seq_len]
        Returns:
            token_features: [batch, seq_len, 768]
            sentence_feature: [batch, 768]
        """
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )

        token_features = outputs.last_hidden_state  # [batch, seq_len, 768]
        sentence_feature = outputs.pooler_output    # [batch, 768]

        return token_features, sentence_feature

2.2 视觉编码器

ResNet特征提取
python 复制代码
import torchvision.models as models
from torchvision.models import ResNet152_Weights

class ResNetVisualEncoder(nn.Module):
    """基于ResNet的视觉编码器"""

    def __init__(self, output_dim=512, pretrained=True):
        super().__init__()

        # 加载预训练ResNet(使用新版API)
        weights = ResNet152_Weights.IMAGENET1K_V2 if pretrained else None
        resnet = models.resnet152(weights=weights)

        # 移除最后的全连接层
        self.backbone = nn.Sequential(*list(resnet.children())[:-1])

        # 投影层
        self.proj = nn.Linear(2048, output_dim)

    def forward(self, images):
        """
        Args:
            images: [batch, num_views, 3, H, W] 全景图像
        Returns:
            features: [batch, num_views, output_dim]
        """
        batch_size, num_views = images.shape[:2]

        # 合并batch和views维度
        images = images.view(-1, *images.shape[2:])

        # 提取特征
        features = self.backbone(images)  # [batch*views, 2048, 1, 1]
        features = features.squeeze(-1).squeeze(-1)  # [batch*views, 2048]

        # 投影
        features = self.proj(features)  # [batch*views, output_dim]

        # 恢复维度
        features = features.view(batch_size, num_views, -1)

        return features
ViT视觉编码器
python 复制代码
from transformers import ViTModel

class ViTVisualEncoder(nn.Module):
    """基于Vision Transformer的视觉编码器"""

    def __init__(self, model_name='google/vit-base-patch16-224'):
        super().__init__()

        self.vit = ViTModel.from_pretrained(model_name)
        self.hidden_dim = self.vit.config.hidden_size

    def forward(self, images):
        """
        Args:
            images: [batch, num_views, 3, 224, 224]
        Returns:
            features: [batch, num_views, hidden_dim]
        """
        batch_size, num_views = images.shape[:2]
        images = images.view(-1, *images.shape[2:])

        outputs = self.vit(pixel_values=images)

        # 使用CLS token作为图像表示
        features = outputs.last_hidden_state[:, 0]  # [batch*views, hidden_dim]
        features = features.view(batch_size, num_views, -1)

        return features
全景图表示
复制代码
全景图视角划分(常用36视角):

仰角 (Elevation):
  +30° (上方)  -->  12个方位角
    0° (水平)  -->  12个方位角
  -30° (下方)  -->  12个方位角

方位角 (Heading): 0°, 30°, 60°, ..., 330° (每30°一个)

总计: 3 × 12 = 36 个视角

**全景图展开示意 (36视角)**

| 仰角 \ 方位角 | 0° | 30° | 60° | 90° | ... | 330° |
|-------------|-----|-----|-----|-----|-----|------|
| +30° (上方) | v1 | v2 | v3 | v4 | ... | v12 |
| 0° (水平) | v13 | v14 | v15 | v16 | ... | v24 |
| -30° (下方) | v25 | v26 | v27 | v28 | ... | v36 |

3. 跨模态融合

3.1 注意力机制

Soft Attention
python 复制代码
class SoftAttention(nn.Module):
    """软注意力机制"""

    def __init__(self, query_dim, key_dim, hidden_dim=256):
        super().__init__()

        self.query_proj = nn.Linear(query_dim, hidden_dim)
        self.key_proj = nn.Linear(key_dim, hidden_dim)
        self.score = nn.Linear(hidden_dim, 1)

    def forward(self, query, keys, mask=None):
        """
        Args:
            query: [batch, query_dim] 查询向量
            keys: [batch, num_keys, key_dim] 键值对
            mask: [batch, num_keys] 可选的mask
        Returns:
            context: [batch, key_dim] 加权上下文
            weights: [batch, num_keys] 注意力权重
        """
        # 投影
        q = self.query_proj(query).unsqueeze(1)  # [batch, 1, hidden]
        k = self.key_proj(keys)                   # [batch, num_keys, hidden]

        # 计算注意力分数
        scores = self.score(torch.tanh(q + k)).squeeze(-1)  # [batch, num_keys]

        # 应用mask
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        # Softmax归一化
        weights = torch.softmax(scores, dim=-1)  # [batch, num_keys]

        # 加权求和
        context = torch.bmm(weights.unsqueeze(1), keys).squeeze(1)

        return context, weights
Cross-Modal Attention
python 复制代码
class CrossModalAttention(nn.Module):
    """跨模态注意力"""

    def __init__(self, visual_dim, lang_dim, hidden_dim=512, num_heads=8):
        super().__init__()

        # 视觉 -> 语言 注意力
        self.v2l_attention = nn.MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=num_heads,
            batch_first=True
        )

        # 语言 -> 视觉 注意力
        self.l2v_attention = nn.MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=num_heads,
            batch_first=True
        )

        # 投影层
        self.visual_proj = nn.Linear(visual_dim, hidden_dim)
        self.lang_proj = nn.Linear(lang_dim, hidden_dim)

        # Layer Norm
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)

    def forward(self, visual_feats, lang_feats, lang_mask=None):
        """
        Args:
            visual_feats: [batch, num_views, visual_dim]
            lang_feats: [batch, seq_len, lang_dim]
            lang_mask: [batch, seq_len]
        Returns:
            fused_visual: [batch, num_views, hidden_dim]
            fused_lang: [batch, seq_len, hidden_dim]
        """
        # 投影到相同维度
        v = self.visual_proj(visual_feats)
        l = self.lang_proj(lang_feats)

        # 视觉特征关注语言
        v_attended, _ = self.v2l_attention(
            query=v, key=l, value=l,
            key_padding_mask=~lang_mask if lang_mask is not None else None
        )
        fused_visual = self.norm1(v + v_attended)

        # 语言特征关注视觉
        l_attended, _ = self.l2v_attention(
            query=l, key=v, value=v
        )
        fused_lang = self.norm2(l + l_attended)

        return fused_visual, fused_lang

3.2 Co-Grounding机制

python 复制代码
class CoGrounding(nn.Module):
    """
    Co-Grounding: 同时进行视觉定位和语言定位
    参考: Self-Monitoring Navigation Agent (ICCV 2019)
    """

    def __init__(self, hidden_dim=512):
        super().__init__()

        # 文本到视觉的定位
        self.text_to_visual = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

        # 视觉到文本的定位
        self.visual_to_text = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, visual_feats, text_feats, text_mask=None):
        """
        Args:
            visual_feats: [batch, num_views, hidden_dim]
            text_feats: [batch, seq_len, hidden_dim]
        Returns:
            visual_weights: [batch, num_views] 视觉注意力
            text_weights: [batch, seq_len] 文本注意力
            visual_context: [batch, hidden_dim]
            text_context: [batch, hidden_dim]
        """
        batch_size = visual_feats.size(0)
        num_views = visual_feats.size(1)
        seq_len = text_feats.size(1)

        # 计算所有视觉-文本对的相似度
        # [batch, num_views, seq_len, hidden_dim*2]
        v_expanded = visual_feats.unsqueeze(2).expand(-1, -1, seq_len, -1)
        t_expanded = text_feats.unsqueeze(1).expand(-1, num_views, -1, -1)
        combined = torch.cat([v_expanded, t_expanded], dim=-1)

        # 文本到视觉权重
        t2v_scores = self.text_to_visual(combined).squeeze(-1)  # [batch, num_views, seq_len]
        t2v_scores = t2v_scores.mean(dim=-1)  # [batch, num_views]
        visual_weights = torch.softmax(t2v_scores, dim=-1)

        # 视觉到文本权重
        v2t_scores = self.visual_to_text(combined).squeeze(-1)
        v2t_scores = v2t_scores.mean(dim=1)  # [batch, seq_len]
        if text_mask is not None:
            v2t_scores = v2t_scores.masked_fill(~text_mask, float('-inf'))
        text_weights = torch.softmax(v2t_scores, dim=-1)

        # 加权得到上下文
        visual_context = (visual_feats * visual_weights.unsqueeze(-1)).sum(dim=1)
        text_context = (text_feats * text_weights.unsqueeze(-1)).sum(dim=1)

        return visual_weights, text_weights, visual_context, text_context

4. 动作解码与决策

4.1 动作空间

python 复制代码
# R2R 离散动作空间
class R2RActionSpace:
    """R2R数据集的动作空间"""

    # 高层动作
    ACTIONS = {
        'STOP': 0,           # 停止导航
        'FORWARD': 1,        # 选择一个viewpoint前进
    }

    # 实际执行时,FORWARD需要选择具体的viewpoint
    # viewpoint选择范围: 当前位置可达的相邻节点

    @staticmethod
    def get_navigable_viewpoints(state):
        """获取当前可导航的viewpoint列表"""
        return state.navigableLocations


# 连续导航动作空间 (Habitat)
class ContinuousActionSpace:
    """连续导航的动作空间"""

    ACTIONS = {
        'STOP': 0,
        'MOVE_FORWARD': 1,   # 前进0.25米
        'TURN_LEFT': 2,      # 左转15度
        'TURN_RIGHT': 3,     # 右转15度
    }

4.2 LSTM解码器

python 复制代码
class LSTMDecoder(nn.Module):
    """基于LSTM的动作解码器"""

    def __init__(self, input_dim, hidden_dim=512, dropout=0.5):
        super().__init__()

        self.lstm = nn.LSTMCell(input_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)

        # 动作预测头
        self.action_predictor = nn.Linear(hidden_dim, 1)  # 输出viewpoint分数

    def forward(self, x, prev_hidden, prev_cell):
        """
        Args:
            x: [batch, input_dim] 当前输入
            prev_hidden: [batch, hidden_dim] 上一步隐藏状态
            prev_cell: [batch, hidden_dim] 上一步cell状态
        Returns:
            action_logits: [batch, num_candidates] 动作logits
            hidden: [batch, hidden_dim]
            cell: [batch, hidden_dim]
        """
        hidden, cell = self.lstm(x, (prev_hidden, prev_cell))
        hidden = self.dropout(hidden)

        return hidden, cell


class ActionPredictor(nn.Module):
    """动作预测器 - 选择下一个viewpoint"""

    def __init__(self, hidden_dim, visual_dim):
        super().__init__()

        self.proj = nn.Linear(hidden_dim + visual_dim, hidden_dim)
        self.score = nn.Linear(hidden_dim, 1)

    def forward(self, hidden, candidate_features, candidate_mask=None):
        """
        Args:
            hidden: [batch, hidden_dim] 解码器隐藏状态
            candidate_features: [batch, num_candidates, visual_dim] 候选viewpoint特征
            candidate_mask: [batch, num_candidates] 有效候选mask
        Returns:
            action_probs: [batch, num_candidates] 动作概率分布
        """
        batch_size, num_candidates, _ = candidate_features.shape

        # 扩展hidden
        hidden_expanded = hidden.unsqueeze(1).expand(-1, num_candidates, -1)

        # 拼接并计算分数
        combined = torch.cat([hidden_expanded, candidate_features], dim=-1)
        scores = self.score(torch.tanh(self.proj(combined))).squeeze(-1)

        # 应用mask
        if candidate_mask is not None:
            scores = scores.masked_fill(~candidate_mask, float('-inf'))

        action_probs = torch.softmax(scores, dim=-1)

        return action_probs, scores

4.3 Transformer解码器

python 复制代码
class TransformerDecoder(nn.Module):
    """基于Transformer的动作解码器"""

    def __init__(self, hidden_dim=768, num_layers=4, num_heads=12, dropout=0.1):
        super().__init__()

        decoder_layer = nn.TransformerDecoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim * 4,
            dropout=dropout,
            batch_first=True
        )

        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)

        # 位置编码
        self.pos_encoder = PositionalEncoding(hidden_dim, dropout)

        # 动作embedding
        self.action_embed = nn.Embedding(10, hidden_dim)  # 假设最多10种动作

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None):
        """
        Args:
            tgt: [batch, tgt_len, hidden_dim] 目标序列(历史动作)
            memory: [batch, src_len, hidden_dim] 编码器输出
        Returns:
            output: [batch, tgt_len, hidden_dim]
        """
        tgt = self.pos_encoder(tgt)
        output = self.decoder(tgt, memory, tgt_mask=tgt_mask)
        return output

5. 经典模型详解

5.1 Seq2Seq基础模型

Seq2Seq VLN模型结构
Instruction Bi-LSTM Encoder Observation ResNet Attention LSTM Decoder Action

数据流程:

  1. 语言编码: 指令通过Bi-LSTM编码为上下文向量
  2. 视觉编码: 观察图像通过ResNet提取特征
  3. 注意力融合: 语言和视觉特征通过注意力机制融合
  4. 动作解码: LSTM解码器生成导航动作
python 复制代码
class Seq2SeqVLN(nn.Module):
    """基础Seq2Seq VLN模型"""

    def __init__(self, vocab_size, embed_dim=256, hidden_dim=512):
        super().__init__()

        # 语言编码器
        self.lang_encoder = LSTMLanguageEncoder(vocab_size, embed_dim, hidden_dim)

        # 视觉编码器
        self.visual_encoder = ResNetVisualEncoder(output_dim=hidden_dim)

        # 注意力
        self.attention = SoftAttention(hidden_dim, hidden_dim)

        # 解码器
        self.decoder = nn.LSTMCell(hidden_dim * 2, hidden_dim)

        # 动作预测
        self.action_predictor = ActionPredictor(hidden_dim, hidden_dim)

    def forward(self, instructions, lengths, visual_obs, candidates,
                prev_hidden, prev_cell):
        """单步前向传播"""

        # 编码语言
        lang_features, lang_ctx = self.lang_encoder(instructions, lengths)

        # 编码视觉
        visual_features = self.visual_encoder(visual_obs)

        # 注意力加权语言
        attended_lang, lang_weights = self.attention(
            prev_hidden, lang_features
        )

        # 注意力加权视觉
        attended_visual, visual_weights = self.attention(
            prev_hidden, visual_features
        )

        # 解码
        decoder_input = torch.cat([attended_lang, attended_visual], dim=-1)
        hidden, cell = self.decoder(decoder_input, (prev_hidden, prev_cell))

        # 预测动作
        action_probs, action_logits = self.action_predictor(
            hidden, candidates
        )

        return action_probs, action_logits, hidden, cell

5.2 Speaker-Follower模型

Speaker-Follower数据增强框架

训练阶段的三步流程:

步骤 输入 模型 输出
1. 训练Speaker Path Speaker Synthetic Instruction
2. 数据增强 随机采样路径 Speaker 合成指令
3. 训练Follower 原始数据 + 增强数据 Follower 导航策略

核心思想:使用Speaker模型从路径生成指令,扩充训练数据。

python 复制代码
class Speaker(nn.Module):
    """Speaker模型:根据路径生成指令"""

    def __init__(self, vocab_size, visual_dim=2048, hidden_dim=512):
        super().__init__()

        # 视觉编码
        self.visual_encoder = nn.Linear(visual_dim, hidden_dim)

        # LSTM解码器生成指令
        self.decoder = nn.LSTM(hidden_dim + 256, hidden_dim, batch_first=True)

        # 词嵌入
        self.embedding = nn.Embedding(vocab_size, 256)

        # 输出层
        self.output = nn.Linear(hidden_dim, vocab_size)

    def forward(self, visual_sequence, target_instructions=None):
        """
        训练时使用teacher forcing
        推理时自回归生成
        """
        # 编码视觉序列
        visual_features = self.visual_encoder(visual_sequence)

        if target_instructions is not None:
            # Teacher forcing训练
            embeds = self.embedding(target_instructions[:, :-1])
            inputs = torch.cat([visual_features, embeds], dim=-1)
            outputs, _ = self.decoder(inputs)
            logits = self.output(outputs)
            return logits
        else:
            # 自回归生成
            return self.generate(visual_features)

    def generate(self, visual_features, max_len=80):
        """自回归生成指令"""
        batch_size = visual_features.size(0)
        device = visual_features.device

        # 初始化
        generated = torch.zeros(batch_size, 1).long().to(device)  # <BOS>
        hidden = None

        for _ in range(max_len):
            embeds = self.embedding(generated[:, -1:])
            inputs = torch.cat([visual_features[:, :1], embeds], dim=-1)
            outputs, hidden = self.decoder(inputs, hidden)
            logits = self.output(outputs)

            # 采样下一个词
            next_token = logits.argmax(dim=-1)
            generated = torch.cat([generated, next_token], dim=1)

            # 检查是否全部生成<EOS>
            if (next_token == 1).all():  # 假设1是<EOS>
                break

        return generated

5.3 VLNBERT / RecBERT

VLNBERT架构

输入序列格式:[CLS] w1 w2 ... wn [SEP] v1 v2 ... vm [SEP] h1 h2 ... hk

Token类型 说明
[CLS] 特殊分类token,输出用于动作预测
w1...wn 语言tokens(指令)
v1...vm 视觉tokens(当前观察)
h1...hk 历史tokens(导航历史)

处理流程:输入序列 → BERT Encoder (多层Transformer) → [CLS]输出 → 动作预测

python 复制代码
from transformers import BertModel, BertConfig

class VLNBERT(nn.Module):
    """VLN-BERT模型"""

    def __init__(self, config_path=None):
        super().__init__()

        # 加载BERT配置
        if config_path:
            config = BertConfig.from_json_file(config_path)
        else:
            config = BertConfig(
                hidden_size=768,
                num_attention_heads=12,
                num_hidden_layers=9,
                intermediate_size=3072
            )

        self.bert = BertModel(config)
        self.hidden_dim = config.hidden_size

        # 视觉投影
        self.visual_proj = nn.Linear(2048, self.hidden_dim)

        # 动作角度编码
        self.angle_encoder = nn.Linear(4, self.hidden_dim)  # [sin, cos, sin, cos]

        # Token类型embedding
        self.token_type_embeddings = nn.Embedding(3, self.hidden_dim)
        # 0: 语言, 1: 视觉, 2: 历史

        # 动作预测头
        self.action_head = nn.Linear(self.hidden_dim, 1)

    def forward(self, input_ids, attention_mask, visual_features,
                angle_features, history_features=None):
        """
        Args:
            input_ids: [batch, lang_len] 语言token ids
            attention_mask: [batch, lang_len]
            visual_features: [batch, num_views, 2048]
            angle_features: [batch, num_views, 4]
            history_features: [batch, hist_len, hidden_dim] 可选
        """
        batch_size = input_ids.size(0)

        # 1. 语言embedding
        lang_embeds = self.bert.embeddings.word_embeddings(input_ids)
        lang_type = self.token_type_embeddings(
            torch.zeros_like(input_ids)
        )
        lang_embeds = lang_embeds + lang_type

        # 2. 视觉embedding
        visual_embeds = self.visual_proj(visual_features)
        angle_embeds = self.angle_encoder(angle_features)
        visual_embeds = visual_embeds + angle_embeds
        visual_type = self.token_type_embeddings(
            torch.ones(batch_size, visual_embeds.size(1)).long().to(input_ids.device)
        )
        visual_embeds = visual_embeds + visual_type

        # 3. 拼接所有embedding
        if history_features is not None:
            history_type = self.token_type_embeddings(
                torch.full((batch_size, history_features.size(1)), 2).long().to(input_ids.device)
            )
            history_embeds = history_features + history_type
            all_embeds = torch.cat([lang_embeds, visual_embeds, history_embeds], dim=1)
        else:
            all_embeds = torch.cat([lang_embeds, visual_embeds], dim=1)

        # 4. 通过BERT
        outputs = self.bert(
            inputs_embeds=all_embeds,
            attention_mask=self._create_attention_mask(attention_mask, all_embeds)
        )

        # 5. 提取视觉token的表示用于动作预测
        lang_len = input_ids.size(1)
        visual_len = visual_features.size(1)
        visual_outputs = outputs.last_hidden_state[:, lang_len:lang_len+visual_len]

        # 6. 动作分数
        action_scores = self.action_head(visual_outputs).squeeze(-1)

        return action_scores, outputs.last_hidden_state[:, 0]  # scores, CLS

5.4 HAMT (History Aware Multimodal Transformer)

HAMT架构:显式建模导航历史

核心组件:

模块 功能
History Encoder 编码时序历史 obs₁→h₁→obs₂→h₂→...→obsₜ→hₜ
Cross-Modal Transformer Language ←Attention→ History 双向注意力融合
Action Prediction 基于融合特征预测动作
复制代码
History: obs₁ → obs₂ → obs₃ → ... → obsₜ
            ↓      ↓      ↓            ↓
         [h₁] → [h₂] → [h₃] → ... → [hₜ]
                                       ↓
Language ──────────────────────> Cross-Modal Transformer ──> Action
python 复制代码
class HAMT(nn.Module):
    """History Aware Multimodal Transformer"""

    def __init__(self, hidden_dim=768, num_layers=4, num_heads=12):
        super().__init__()

        # 语言编码器
        self.lang_encoder = BERTLanguageEncoder()

        # 视觉编码器
        self.visual_encoder = ViTVisualEncoder()

        # 历史编码器
        self.history_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=hidden_dim,
                nhead=num_heads,
                batch_first=True
            ),
            num_layers=2
        )

        # 观察编码
        self.observation_encoder = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),  # visual + action
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

        # 跨模态Transformer
        self.cross_modal_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=hidden_dim,
                nhead=num_heads,
                batch_first=True
            ),
            num_layers=num_layers
        )

        # 动作预测
        self.action_predictor = nn.Linear(hidden_dim, 1)

        # 位置编码
        self.pos_encoding = LearnedPositionalEncoding(hidden_dim)

    def encode_history(self, observations, actions):
        """
        编码导航历史
        Args:
            observations: list of [batch, hidden_dim] 历史观察
            actions: list of [batch, hidden_dim] 历史动作
        Returns:
            history: [batch, hist_len, hidden_dim]
        """
        history_embeds = []
        for obs, act in zip(observations, actions):
            combined = torch.cat([obs, act], dim=-1)
            encoded = self.observation_encoder(combined)
            history_embeds.append(encoded)

        history = torch.stack(history_embeds, dim=1)
        history = history + self.pos_encoding(history)
        history = self.history_encoder(history)

        return history

    def forward(self, input_ids, attention_mask, current_visual,
                history_observations, history_actions, candidates):
        """
        Args:
            input_ids: [batch, seq_len]
            current_visual: [batch, num_views, visual_dim]
            history_*: 历史信息
            candidates: [batch, num_candidates, visual_dim]
        """
        # 1. 编码语言
        lang_features, _ = self.lang_encoder(input_ids, attention_mask)

        # 2. 编码当前视觉
        visual_features = self.visual_encoder(current_visual)

        # 3. 编码历史
        history_features = self.encode_history(
            history_observations, history_actions
        )

        # 4. 跨模态融合
        combined = torch.cat([
            lang_features,
            visual_features,
            history_features
        ], dim=1)

        fused = self.cross_modal_encoder(combined)

        # 5. 提取全局表示
        global_repr = fused.mean(dim=1)

        # 6. 动作预测
        candidate_scores = torch.bmm(
            candidates,
            global_repr.unsqueeze(-1)
        ).squeeze(-1)

        return candidate_scores

6. PyTorch实现

6.1 完整VLN Agent

python 复制代码
class VLNAgent(nn.Module):
    """完整的VLN Agent实现"""

    def __init__(self, config):
        super().__init__()

        self.config = config

        # 编码器
        self.lang_encoder = BERTLanguageEncoder(
            bert_model=config.bert_model,
            finetune=config.finetune_bert
        )

        self.visual_encoder = nn.Sequential(
            nn.Linear(config.visual_dim, config.hidden_dim),
            nn.ReLU(),
            nn.Dropout(config.dropout)
        )

        # 角度编码
        self.angle_encoder = nn.Linear(128, config.hidden_dim)

        # 跨模态注意力
        self.cross_attention = CrossModalAttention(
            visual_dim=config.hidden_dim,
            lang_dim=config.hidden_dim,
            hidden_dim=config.hidden_dim
        )

        # 历史编码(可选)
        if config.use_history:
            self.history_encoder = nn.GRU(
                config.hidden_dim,
                config.hidden_dim,
                batch_first=True
            )

        # 动作预测
        self.action_predictor = nn.Sequential(
            nn.Linear(config.hidden_dim * 2, config.hidden_dim),
            nn.ReLU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.hidden_dim, 1)
        )

        # 停止预测
        self.stop_predictor = nn.Linear(config.hidden_dim, 2)

    def forward(self, batch, mode='train'):
        """
        Args:
            batch: 包含以下字段的字典
                - input_ids: [batch, seq_len]
                - attention_mask: [batch, seq_len]
                - visual_features: [batch, num_views, visual_dim]
                - angle_features: [batch, num_views, 128]
                - candidate_features: [batch, num_candidates, visual_dim]
                - candidate_angles: [batch, num_candidates, 128]
                - candidate_mask: [batch, num_candidates]
        """
        # 1. 语言编码
        lang_features, lang_global = self.lang_encoder(
            batch['input_ids'],
            batch['attention_mask']
        )

        # 2. 视觉编码
        visual_features = self.visual_encoder(batch['visual_features'])
        angle_features = self.angle_encoder(batch['angle_features'])
        visual_features = visual_features + angle_features

        # 3. 跨模态融合
        fused_visual, fused_lang = self.cross_attention(
            visual_features,
            lang_features,
            batch['attention_mask'].bool()
        )

        # 4. 全局表示
        visual_global = fused_visual.mean(dim=1)

        # 5. 候选viewpoint编码
        candidate_features = self.visual_encoder(batch['candidate_features'])
        candidate_angles = self.angle_encoder(batch['candidate_angles'])
        candidate_features = candidate_features + candidate_angles

        # 6. 动作分数计算
        state = torch.cat([visual_global, lang_global], dim=-1)
        state_expanded = state.unsqueeze(1).expand(-1, candidate_features.size(1), -1)

        combined = torch.cat([state_expanded, candidate_features], dim=-1)
        action_scores = self.action_predictor(combined).squeeze(-1)

        # 应用mask
        if 'candidate_mask' in batch:
            action_scores = action_scores.masked_fill(
                ~batch['candidate_mask'],
                float('-inf')
            )

        # 7. 停止预测
        stop_logits = self.stop_predictor(visual_global)

        return {
            'action_scores': action_scores,
            'stop_logits': stop_logits,
            'state': state
        }

6.2 训练循环

python 复制代码
class VLNTrainer:
    """VLN训练器"""

    def __init__(self, agent, optimizer, config):
        self.agent = agent
        self.optimizer = optimizer
        self.config = config

        self.action_criterion = nn.CrossEntropyLoss(ignore_index=-1)
        self.stop_criterion = nn.CrossEntropyLoss()

    def train_epoch(self, dataloader, env):
        """训练一个epoch"""
        self.agent.train()
        total_loss = 0

        for batch_idx, batch in enumerate(dataloader):
            # 移动到GPU
            batch = {k: v.cuda() if torch.is_tensor(v) else v
                    for k, v in batch.items()}

            # 初始化环境
            env.reset(batch)

            episode_loss = 0
            done = False
            step = 0

            while not done and step < self.config.max_steps:
                # 获取当前观察
                obs = env.get_observation()
                batch.update(obs)

                # 前向传播
                outputs = self.agent(batch, mode='train')

                # 计算损失
                # Teacher forcing: 使用真实动作
                target_action = batch['target_action']
                action_loss = self.action_criterion(
                    outputs['action_scores'],
                    target_action
                )

                target_stop = batch['target_stop']
                stop_loss = self.stop_criterion(
                    outputs['stop_logits'],
                    target_stop
                )

                step_loss = action_loss + self.config.stop_weight * stop_loss
                episode_loss += step_loss

                # 执行动作(teacher forcing)
                env.step(target_action)

                # 检查是否结束
                done = env.is_done()
                step += 1

            # 反向传播
            self.optimizer.zero_grad()
            episode_loss.backward()

            # 梯度裁剪
            torch.nn.utils.clip_grad_norm_(
                self.agent.parameters(),
                self.config.max_grad_norm
            )

            self.optimizer.step()

            total_loss += episode_loss.item()

            if batch_idx % self.config.log_interval == 0:
                print(f"Batch {batch_idx}, Loss: {episode_loss.item():.4f}")

        return total_loss / len(dataloader)

    def evaluate(self, dataloader, env):
        """评估"""
        self.agent.eval()

        all_results = []

        with torch.no_grad():
            for batch in dataloader:
                batch = {k: v.cuda() if torch.is_tensor(v) else v
                        for k, v in batch.items()}

                env.reset(batch)

                trajectory = []
                done = False
                step = 0

                while not done and step < self.config.max_steps:
                    obs = env.get_observation()
                    batch.update(obs)

                    outputs = self.agent(batch, mode='eval')

                    # 贪婪选择动作
                    action = outputs['action_scores'].argmax(dim=-1)

                    # 检查是否停止
                    stop_pred = outputs['stop_logits'].argmax(dim=-1)
                    if stop_pred.item() == 1:
                        action = torch.tensor([0])  # STOP action

                    env.step(action)
                    trajectory.append(action.item())

                    done = env.is_done()
                    step += 1

                # 计算指标
                result = self.compute_metrics(
                    trajectory,
                    batch['path'],
                    env.get_final_position()
                )
                all_results.append(result)

        # 聚合结果
        return self.aggregate_metrics(all_results)

7. 数学原理深入

7.1 注意力机制的数学表达

Scaled Dot-Product Attention

Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V Attention(Q,K,V)=softmax(dk QKT)V

其中:

  • Q ∈ R n × d k Q \in \mathbb{R}^{n \times d_k} Q∈Rn×dk:Query矩阵
  • K ∈ R m × d k K \in \mathbb{R}^{m \times d_k} K∈Rm×dk:Key矩阵
  • V ∈ R m × d v V \in \mathbb{R}^{m \times d_v} V∈Rm×dv:Value矩阵
  • d k d_k dk:Key的维度,用于缩放防止梯度消失
python 复制代码
import torch
import torch.nn.functional as F
import math

def scaled_dot_product_attention(query, key, value, mask=None):
    """
    数学公式的PyTorch实现

    Args:
        query: [batch, num_heads, seq_len_q, d_k]
        key: [batch, num_heads, seq_len_k, d_k]
        value: [batch, num_heads, seq_len_k, d_v]
        mask: [batch, 1, 1, seq_len_k] or [batch, 1, seq_len_q, seq_len_k]
    """
    d_k = query.size(-1)

    # QK^T / sqrt(d_k)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)

    # 应用mask
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)

    # softmax归一化
    attention_weights = F.softmax(scores, dim=-1)

    # 加权求和
    output = torch.matmul(attention_weights, value)

    return output, attention_weights

7.2 跨模态对齐的损失函数

对比学习损失 (Contrastive Loss)

用于拉近匹配的视觉-语言对,推远不匹配的对:

L c o n t r a s t = − log ⁡ exp ⁡ ( sim ( v i , l i ) / τ ) ∑ j = 1 N exp ⁡ ( sim ( v i , l j ) / τ ) \mathcal{L}{contrast} = -\log \frac{\exp(\text{sim}(v_i, l_i) / \tau)}{\sum{j=1}^{N} \exp(\text{sim}(v_i, l_j) / \tau)} Lcontrast=−log∑j=1Nexp(sim(vi,lj)/τ)exp(sim(vi,li)/τ)

其中:

  • v i v_i vi:视觉特征
  • l i l_i li:语言特征
  • τ \tau τ:温度参数
  • sim ( ⋅ , ⋅ ) \text{sim}(\cdot, \cdot) sim(⋅,⋅):相似度函数(如余弦相似度)
python 复制代码
class ContrastiveLoss(nn.Module):
    """跨模态对比学习损失"""

    def __init__(self, temperature=0.07):
        super().__init__()
        self.temperature = temperature

    def forward(self, visual_features, lang_features):
        """
        Args:
            visual_features: [batch, hidden_dim]
            lang_features: [batch, hidden_dim]
        """
        # L2归一化
        visual_features = F.normalize(visual_features, dim=-1)
        lang_features = F.normalize(lang_features, dim=-1)

        # 计算相似度矩阵
        logits = torch.matmul(visual_features, lang_features.T) / self.temperature

        # 对角线是正样本
        labels = torch.arange(logits.size(0), device=logits.device)

        # 双向对比损失
        loss_v2l = F.cross_entropy(logits, labels)
        loss_l2v = F.cross_entropy(logits.T, labels)

        return (loss_v2l + loss_l2v) / 2

7.3 强化学习目标

策略梯度 (Policy Gradient)

VLN中的动作选择可以建模为序列决策问题:

∇ θ J ( θ ) = E τ ∼ π θ [ ∑ t = 0 T ∇ θ log ⁡ π θ ( a t ∣ s t ) ⋅ R ( τ ) ] \nabla_\theta J(\theta) = \mathbb{E}{\tau \sim \pi\theta} \left[ \sum_{t=0}^{T} \nabla_\theta \log \pi_\theta(a_t|s_t) \cdot R(\tau) \right] ∇θJ(θ)=Eτ∼πθ[t=0∑T∇θlogπθ(at∣st)⋅R(τ)]

带基线的REINFORCE

∇ θ J ( θ ) = E [ ∑ t = 0 T ∇ θ log ⁡ π θ ( a t ∣ s t ) ⋅ ( R ( τ ) − b ) ] \nabla_\theta J(\theta) = \mathbb{E} \left[ \sum_{t=0}^{T} \nabla_\theta \log \pi_\theta(a_t|s_t) \cdot (R(\tau) - b) \right] ∇θJ(θ)=E[t=0∑T∇θlogπθ(at∣st)⋅(R(τ)−b)]

python 复制代码
class REINFORCELoss:
    """REINFORCE策略梯度损失"""

    def __init__(self, baseline_type='average'):
        self.baseline_type = baseline_type
        self.baseline = 0
        self.alpha = 0.9  # 指数移动平均系数

    def compute_loss(self, log_probs, rewards):
        """
        Args:
            log_probs: list of [batch] 每步动作的log概率
            rewards: [batch] 回合奖励
        """
        # 计算基线
        if self.baseline_type == 'average':
            baseline = rewards.mean()
            # 更新移动平均基线
            self.baseline = self.alpha * self.baseline + (1 - self.alpha) * baseline.item()
        else:
            baseline = self.baseline

        # 计算优势
        advantages = rewards - baseline

        # 策略梯度损失
        policy_loss = 0
        for log_prob in log_probs:
            policy_loss -= (log_prob * advantages).mean()

        return policy_loss / len(log_probs)

8. 训练技巧与实践经验

8.1 数据增强策略

Speaker数据增强

使用Speaker模型生成合成指令,扩充训练数据:

python 复制代码
class SpeakerAugmentation:
    """基于Speaker的数据增强"""

    def __init__(self, speaker_model, env, num_augment=20):
        self.speaker = speaker_model
        self.env = env
        self.num_augment = num_augment

    def generate_augmented_data(self, original_data):
        augmented = []

        for _ in range(self.num_augment):
            # 1. 随机采样路径
            path = self.env.sample_random_path()

            # 2. 提取路径视觉特征
            visual_features = self.extract_path_features(path)

            # 3. Speaker生成指令
            instruction = self.speaker.generate(visual_features)

            augmented.append({
                'path': path,
                'instruction': instruction,
                'is_synthetic': True
            })

        return original_data + augmented

环境Dropout (EnvDrop)

随机遮挡视觉特征,增强泛化能力:

python 复制代码
class EnvironmentDropout(nn.Module):
    """环境Dropout正则化"""

    def __init__(self, drop_prob=0.5, feature_drop_prob=0.4):
        super().__init__()
        self.drop_prob = drop_prob
        self.feature_drop_prob = feature_drop_prob

    def forward(self, visual_features, training=True):
        """
        Args:
            visual_features: [batch, num_views, feat_dim]
        """
        if not training:
            return visual_features

        batch_size, num_views, feat_dim = visual_features.shape

        # 随机决定是否应用EnvDrop
        if torch.rand(1).item() > self.drop_prob:
            return visual_features

        # 随机遮挡部分视角
        view_mask = torch.rand(batch_size, num_views, 1, device=visual_features.device)
        view_mask = (view_mask > self.feature_drop_prob).float()

        return visual_features * view_mask

8.2 学习率调度策略

python 复制代码
def get_vln_scheduler(optimizer, num_training_steps, warmup_ratio=0.1):
    """
    VLN常用的学习率调度:
    - Warmup阶段线性增长
    - 之后余弦衰减
    """
    num_warmup_steps = int(num_training_steps * warmup_ratio)

    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            # 线性warmup
            return float(current_step) / float(max(1, num_warmup_steps))
        else:
            # 余弦衰减
            progress = float(current_step - num_warmup_steps) / \
                      float(max(1, num_training_steps - num_warmup_steps))
            return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)


# 使用示例
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)
scheduler = get_vln_scheduler(optimizer, num_training_steps=10000, warmup_ratio=0.1)

8.3 混合训练策略

Teacher Forcing + Student Forcing

python 复制代码
class MixedTraining:
    """混合训练策略"""

    def __init__(self, teacher_forcing_ratio=0.5):
        self.tf_ratio = teacher_forcing_ratio

    def train_step(self, agent, env, batch, use_sample=False):
        """
        Args:
            use_sample: True时使用采样动作(Student Forcing)
                       False时使用真实动作(Teacher Forcing)
        """
        total_loss = 0

        env.reset(batch)
        hidden = agent.init_hidden(batch['batch_size'])

        for t in range(batch['max_steps']):
            obs = env.get_observation()

            # 前向传播
            action_logits, hidden = agent(obs, hidden, batch['instructions'])

            # 计算损失
            loss = F.cross_entropy(action_logits, batch['target_actions'][:, t])
            total_loss += loss

            # 决定使用哪个动作
            if use_sample or torch.rand(1).item() > self.tf_ratio:
                # Student Forcing: 使用模型预测
                action = action_logits.argmax(dim=-1)
            else:
                # Teacher Forcing: 使用真实动作
                action = batch['target_actions'][:, t]

            # 执行动作
            env.step(action)

            if env.all_done():
                break

        return total_loss


# 训练循环中的使用
trainer = MixedTraining(teacher_forcing_ratio=0.5)

for epoch in range(num_epochs):
    for batch in dataloader:
        # 交替使用两种策略
        if epoch % 2 == 0:
            loss = trainer.train_step(agent, env, batch, use_sample=False)
        else:
            loss = trainer.train_step(agent, env, batch, use_sample=True)

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(agent.parameters(), max_norm=5.0)
        optimizer.step()

8.4 实验对比:训练技巧的效果

以下是在R2R val_unseen上的消融实验结果:

配置 SR SPL 备注
Baseline 45.2 40.1 无任何技巧
+ EnvDrop 48.7 43.2 +3.5% SR
+ Speaker增强 52.1 45.8 +3.4% SR
+ 混合训练 53.4 47.1 +1.3% SR
+ 学习率调度 54.8 48.3 +1.4% SR
All Combined 58.2 51.6 总计 +13% SR

总结

本文详细介绍了VLN的核心模型架构:

关键组件

  1. 编码器:LSTM/BERT语言编码,ResNet/ViT视觉编码
  2. 跨模态融合:Attention机制,Cross-Modal Transformer
  3. 动作解码:LSTM/Transformer解码器,候选viewpoint打分

经典模型演进

  • Seq2SeqSpeaker-FollowerVLNBERTHAMT
  • 从简单的注意力机制到复杂的Transformer架构
  • 从单步决策到历史感知的序列建模

参考文献

1\] Anderson P, et al. "Vision-and-Language Navigation." *CVPR 2018*. \[2\] Fried D, et al. "Speaker-Follower Models for VLN." *NeurIPS 2018*. \[3\] Hong Y, et al. "VLN BERT: A Recurrent Vision-and-Language BERT." *CVPR 2021*. \[4\] Chen S, et al. "History Aware Multimodal Transformer." *NeurIPS 2021*. *** ** * ** *** *上一篇:[视觉语言导航从入门到精通(二):经典数据集与评估指标](./02_%E7%BB%8F%E5%85%B8%E6%95%B0%E6%8D%AE%E9%9B%86%E4%B8%8E%E8%AF%84%E4%BC%B0%E6%8C%87%E6%A0%87.md)* *下一篇:[视觉语言导航从入门到精通(四):前沿方法与最新进展](./04_%E5%89%8D%E6%B2%BF%E6%96%B9%E6%B3%95%E4%B8%8E%E6%9C%80%E6%96%B0%E8%BF%9B%E5%B1%95.md)*

相关推荐
坐吃山猪8 小时前
AutoGLMPhone03-adb模块
adb·llm·glm
lybugproducer8 小时前
深度学习专题:模型训练的张量并行(一)
人工智能·深度学习·transformer
V1ncent Chen9 小时前
深度学习进化的里程碑:Transformer模型
人工智能·深度学习·transformer
EdisonZhou9 小时前
MAF快速入门(7)工作流的状态共享
llm·aigc·agent·.net core
洛阳泰山9 小时前
快速上手 MaxKB4J:开源企业级智能知识库系统在 Sealos 上的完整部署指南
java·开源·llm·agent·rag
青衫客361 天前
浅谈 LightRAG —— 把“结构理解”前移到索引阶段的 RAG 新范式
大模型·llm·rag
深度学习实战训练营1 天前
TransUNet:Transformer 成为医学图像分割的强大编码器,Transformer 编码器 + U-Net 解码器-k学长深度学习专栏
人工智能·深度学习·transformer
破烂pan1 天前
模型推理加速技术全景解析:从基础优化到前沿创新
llm·模型加速
visnix1 天前
AI大模型-LLM原理剖析到训练微调实战(第二部分:大模型核心原理与Transformer架构)
前端·llm