20250226-代码笔记05-class CVRP_Decoder

文章目录

  • 前言
  • [一、class CVRP_Decoder(nn.Module):init(self, **model_params)](#一、class CVRP_Decoder(nn.Module):init(self, **model_params))
  • [二、class CVRP_Decoder(nn.Module):set_kv(self, encoded_nodes)](#二、class CVRP_Decoder(nn.Module):set_kv(self, encoded_nodes))
  • [三、class CVRP_Decoder(nn.Module):set_q1(self, encoded_q1)](#三、class CVRP_Decoder(nn.Module):set_q1(self, encoded_q1))
  • [四、class CVRP_Decoder(nn.Module):set_q2(self, encoded_q2)](#四、class CVRP_Decoder(nn.Module):set_q2(self, encoded_q2))
  • [五、class CVRP_Decoder(nn.Module):forward(self, encoded_last_node, load, ninf_mask)](#五、class CVRP_Decoder(nn.Module):forward(self, encoded_last_node, load, ninf_mask))
  • 附录
    • [class CVRP_Decoder代码(全)](#class CVRP_Decoder代码(全))

前言

class CVRP_DecoderCVRP_Model.py里的类。

/home/tang/RL_exa/NCO_code-main/single_objective/LCH-Regret/Regret-POMO/CVRP/POMO/CVRPModel.py


一、class CVRP_Decoder(nn.Module):init(self, **model_params)

函数功能

init 方法是 CVRP_Decoder 类中的构造函数,主要功能是初始化该类所需的所有网络层、权重矩阵和参数。

该方法设置了用于多头注意力机制的权重、一个用于表示"遗憾"的参数、以及其他必要的操作用于计算注意力权重。

执行流程图链接

函数代码

python 复制代码
    def __init__(self, **model_params):
        super().__init__()
        self.model_params = model_params
        embedding_dim = self.model_params['embedding_dim']
        head_num = self.model_params['head_num']
        qkv_dim = self.model_params['qkv_dim']

        # self.Wq_1 = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False)
        self.Wq_2 = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False)
        self.Wq_last = nn.Linear(embedding_dim+1, head_num * qkv_dim, bias=False)
        self.Wk = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False)
        self.Wv = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False)

        self.regret_embedding = nn.Parameter(torch.Tensor(embedding_dim))
        self.regret_embedding.data.uniform_(-1, 1)

        self.multi_head_combine = nn.Linear(head_num * qkv_dim, embedding_dim)

        self.k = None  # saved key, for multi-head attention
        self.v = None  # saved value, for multi-head_attention
        self.single_head_key = None  # saved, for single-head attention
        # self.q1 = None  # saved q1, for multi-head attention
        self.q2 = None  # saved q2, for multi-head attention

二、class CVRP_Decoder(nn.Module):set_kv(self, encoded_nodes)

函数功能

set_kv 方法的功能是将 encoded_nodes 中的节点嵌入转换为多头注意力机制所需的 键(K)值(V) ,并将它们分别保存为类的属性。

这个方法将输入的节点嵌入通过权重矩阵进行线性变换,得到键和值的表示,并为后续的多头注意力计算做好准备。
执行流程图链接

函数代码

python 复制代码
    def set_kv(self, encoded_nodes):
        # encoded_nodes.shape: (batch, problem+1, embedding)
        head_num = self.model_params['head_num']

        self.k = reshape_by_heads(self.Wk(encoded_nodes), head_num=head_num)
        self.v = reshape_by_heads(self.Wv(encoded_nodes), head_num=head_num)
        # shape: (batch, head_num, problem+1, qkv_dim)
        self.single_head_key = encoded_nodes.transpose(1, 2)
        # shape: (batch, embedding, problem+1)

三、class CVRP_Decoder(nn.Module):set_q1(self, encoded_q1)

函数功能

set_q1 方法的主要功能是 计算查询(Q) 并将其转换为适用于多头注意力机制的形状。

该方法接受输入的查询张量 encoded_q1,通过线性层 self.Wq_1 映射到一个新的维度,并使用 reshape_by_heads 函数将其调整为适合多头注意力机制计算的形状。计算出的查询会被保存为类的属性 q1,供后续使用。

执行流程图链接

函数代码

python 复制代码
    def set_q1(self, encoded_q1):
        # encoded_q.shape: (batch, n, embedding)  # n can be 1 or pomo
        head_num = self.model_params['head_num']
        self.q1 = reshape_by_heads(self.Wq_1(encoded_q1), head_num=head_num)
        # shape: (batch, head_num, n, qkv_dim)

四、class CVRP_Decoder(nn.Module):set_q2(self, encoded_q2)

函数功能

set_q2 方法的主要功能是 计算查询(Q) 并将其转换为适用于多头注意力机制的形状。

该方法接收输入的查询张量 encoded_q2,通过线性层 self.Wq_2 映射到一个新的维度,并使用 reshape_by_heads 函数将其调整为适合多头注意力计算的形状。
执行流程图链接

函数代码

python 复制代码
    def set_q2(self, encoded_q2):
        # encoded_q.shape: (batch, n, embedding)  # n can be 1 or pomo
        head_num = self.model_params['head_num']
        self.q2 = reshape_by_heads(self.Wq_2(encoded_q2), head_num=head_num)
        # shape: (batch, head_num, n, qkv_dim)

五、class CVRP_Decoder(nn.Module):forward(self, encoded_last_node, load, ninf_mask)

函数功能

forward 方法是 CVRP_Decoder 类中的前向传播函数,主要功能是执行 多头自注意力机制 和 单头注意力计算,并最终输出每个可能节点的选择概率(probs)。

该方法通过多头注意力计算、前馈神经网络处理,以及概率计算来进行节点选择。

执行流程图链接

函数代码

python 复制代码
    def forward(self, encoded_last_node, load, ninf_mask):
        # encoded_last_node.shape: (batch, pomo, embedding)
        # load.shape: (batch, pomo)
        # ninf_mask.shape: (batch, pomo, problem)

        head_num = self.model_params['head_num']

        #  Multi-Head Attention
        #######################################################
        input_cat = torch.cat((encoded_last_node, load[:, :, None]), dim=2)
        # shape = (batch, group, EMBEDDING_DIM+1)

        q_last = reshape_by_heads(self.Wq_last(input_cat), head_num=head_num)
        # shape: (batch, head_num, pomo, qkv_dim)

        # q = self.q1 + self.q2 + q_last
        # # shape: (batch, head_num, pomo, qkv_dim)
        # q = q_last
        # shape: (batch, head_num, pomo, qkv_dim)
        q = self.q2 + q_last
        # # shape: (batch, head_num, pomo, qkv_dim)

        out_concat = multi_head_attention(q, self.k, self.v, rank3_ninf_mask=ninf_mask)
        # shape: (batch, pomo, head_num*qkv_dim)

        mh_atten_out = self.multi_head_combine(out_concat)
        # shape: (batch, pomo, embedding)

        #  Single-Head Attention, for probability calculation
        #######################################################
        score = torch.matmul(mh_atten_out, self.single_head_key)
        # shape: (batch, pomo, problem)

        sqrt_embedding_dim = self.model_params['sqrt_embedding_dim']
        logit_clipping = self.model_params['logit_clipping']

        score_scaled = score / sqrt_embedding_dim
        # shape: (batch, pomo, problem)

        score_clipped = logit_clipping * torch.tanh(score_scaled)

        score_masked = score_clipped + ninf_mask

        probs = F.softmax(score_masked, dim=2)
        # shape: (batch, pomo, problem)

        return probs

附录

class CVRP_Decoder代码(全)

python 复制代码
class CVRP_Decoder(nn.Module):
    def __init__(self, **model_params):
        super().__init__()
        self.model_params = model_params
        embedding_dim = self.model_params['embedding_dim']
        head_num = self.model_params['head_num']
        qkv_dim = self.model_params['qkv_dim']

        # self.Wq_1 = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False)
        self.Wq_2 = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False)
        self.Wq_last = nn.Linear(embedding_dim+1, head_num * qkv_dim, bias=False)
        self.Wk = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False)
        self.Wv = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False)

        self.regret_embedding = nn.Parameter(torch.Tensor(embedding_dim))
        self.regret_embedding.data.uniform_(-1, 1)

        self.multi_head_combine = nn.Linear(head_num * qkv_dim, embedding_dim)

        self.k = None  # saved key, for multi-head attention
        self.v = None  # saved value, for multi-head_attention
        self.single_head_key = None  # saved, for single-head attention
        # self.q1 = None  # saved q1, for multi-head attention
        self.q2 = None  # saved q2, for multi-head attention

    def set_kv(self, encoded_nodes):
        # encoded_nodes.shape: (batch, problem+1, embedding)
        head_num = self.model_params['head_num']

        self.k = reshape_by_heads(self.Wk(encoded_nodes), head_num=head_num)
        self.v = reshape_by_heads(self.Wv(encoded_nodes), head_num=head_num)
        # shape: (batch, head_num, problem+1, qkv_dim)
        self.single_head_key = encoded_nodes.transpose(1, 2)
        # shape: (batch, embedding, problem+1)

    def set_q1(self, encoded_q1):
        # encoded_q.shape: (batch, n, embedding)  # n can be 1 or pomo
        head_num = self.model_params['head_num']
        self.q1 = reshape_by_heads(self.Wq_1(encoded_q1), head_num=head_num)
        # shape: (batch, head_num, n, qkv_dim)

    def set_q2(self, encoded_q2):
        # encoded_q.shape: (batch, n, embedding)  # n can be 1 or pomo
        head_num = self.model_params['head_num']
        self.q2 = reshape_by_heads(self.Wq_2(encoded_q2), head_num=head_num)
        # shape: (batch, head_num, n, qkv_dim)

    def forward(self, encoded_last_node, load, ninf_mask):
        # encoded_last_node.shape: (batch, pomo, embedding)
        # load.shape: (batch, pomo)
        # ninf_mask.shape: (batch, pomo, problem)

        head_num = self.model_params['head_num']

        #  Multi-Head Attention
        #######################################################
        input_cat = torch.cat((encoded_last_node, load[:, :, None]), dim=2)
        # shape = (batch, group, EMBEDDING_DIM+1)

        q_last = reshape_by_heads(self.Wq_last(input_cat), head_num=head_num)
        # shape: (batch, head_num, pomo, qkv_dim)

        # q = self.q1 + self.q2 + q_last
        # # shape: (batch, head_num, pomo, qkv_dim)
        # q = q_last
        # shape: (batch, head_num, pomo, qkv_dim)
        q = self.q2 + q_last
        # # shape: (batch, head_num, pomo, qkv_dim)

        out_concat = multi_head_attention(q, self.k, self.v, rank3_ninf_mask=ninf_mask)
        # shape: (batch, pomo, head_num*qkv_dim)

        mh_atten_out = self.multi_head_combine(out_concat)
        # shape: (batch, pomo, embedding)

        #  Single-Head Attention, for probability calculation
        #######################################################
        score = torch.matmul(mh_atten_out, self.single_head_key)
        # shape: (batch, pomo, problem)

        sqrt_embedding_dim = self.model_params['sqrt_embedding_dim']
        logit_clipping = self.model_params['logit_clipping']

        score_scaled = score / sqrt_embedding_dim
        # shape: (batch, pomo, problem)

        score_clipped = logit_clipping * torch.tanh(score_scaled)

        score_masked = score_clipped + ninf_mask

        probs = F.softmax(score_masked, dim=2)
        # shape: (batch, pomo, problem)

        return probs
相关推荐
想拿高薪的韭菜27 分钟前
计算机网络-双绞线制作
经验分享·笔记·计算机网络·课程设计
2401_884810741 小时前
Vue3笔记
前端·vue.js·笔记
爱上妖精的尾巴2 小时前
3-2 WPS JS宏 工作簿的打开与保存(模板批量另存为工作)学习笔记
javascript·笔记·学习·js·wps
北岛寒沫3 小时前
深度学习重要论文阅读笔记 ResNet (2025.2.26)
论文阅读·笔记·深度学习
芋圆芋圆大芋圆3 小时前
【苍穹外卖】问题笔记
笔记
20230005 小时前
seacmsv9报错注入管理员账号密码,order by 注入,如何解决 information_schema关键字被过滤掉了
笔记
想你依然心痛5 小时前
侯捷 C++ 课程学习笔记:深入理解类与继承
c++·笔记·学习
SRC_BLUE_175 小时前
[Web 安全] 反序列化漏洞 - 学习笔记
笔记·学习·安全·网络安全
一抓掉一大把6 小时前
c#笔记-基础知识
开发语言·笔记·c#