20250226-代码笔记05-class CVRP_Decoder

文章目录

  • 前言
  • [一、class CVRP_Decoder(nn.Module):init(self, **model_params)](#一、class CVRP_Decoder(nn.Module):init(self, **model_params))
  • [二、class CVRP_Decoder(nn.Module):set_kv(self, encoded_nodes)](#二、class CVRP_Decoder(nn.Module):set_kv(self, encoded_nodes))
  • [三、class CVRP_Decoder(nn.Module):set_q1(self, encoded_q1)](#三、class CVRP_Decoder(nn.Module):set_q1(self, encoded_q1))
  • [四、class CVRP_Decoder(nn.Module):set_q2(self, encoded_q2)](#四、class CVRP_Decoder(nn.Module):set_q2(self, encoded_q2))
  • [五、class CVRP_Decoder(nn.Module):forward(self, encoded_last_node, load, ninf_mask)](#五、class CVRP_Decoder(nn.Module):forward(self, encoded_last_node, load, ninf_mask))
  • 附录
    • [class CVRP_Decoder代码(全)](#class CVRP_Decoder代码(全))

前言

class CVRP_DecoderCVRP_Model.py里的类。

/home/tang/RL_exa/NCO_code-main/single_objective/LCH-Regret/Regret-POMO/CVRP/POMO/CVRPModel.py


一、class CVRP_Decoder(nn.Module):init(self, **model_params)

函数功能

init 方法是 CVRP_Decoder 类中的构造函数,主要功能是初始化该类所需的所有网络层、权重矩阵和参数。

该方法设置了用于多头注意力机制的权重、一个用于表示"遗憾"的参数、以及其他必要的操作用于计算注意力权重。

执行流程图链接

函数代码

python 复制代码
    def __init__(self, **model_params):
        super().__init__()
        self.model_params = model_params
        embedding_dim = self.model_params['embedding_dim']
        head_num = self.model_params['head_num']
        qkv_dim = self.model_params['qkv_dim']

        # self.Wq_1 = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False)
        self.Wq_2 = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False)
        self.Wq_last = nn.Linear(embedding_dim+1, head_num * qkv_dim, bias=False)
        self.Wk = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False)
        self.Wv = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False)

        self.regret_embedding = nn.Parameter(torch.Tensor(embedding_dim))
        self.regret_embedding.data.uniform_(-1, 1)

        self.multi_head_combine = nn.Linear(head_num * qkv_dim, embedding_dim)

        self.k = None  # saved key, for multi-head attention
        self.v = None  # saved value, for multi-head_attention
        self.single_head_key = None  # saved, for single-head attention
        # self.q1 = None  # saved q1, for multi-head attention
        self.q2 = None  # saved q2, for multi-head attention

二、class CVRP_Decoder(nn.Module):set_kv(self, encoded_nodes)

函数功能

set_kv 方法的功能是将 encoded_nodes 中的节点嵌入转换为多头注意力机制所需的 键(K)值(V) ,并将它们分别保存为类的属性。

这个方法将输入的节点嵌入通过权重矩阵进行线性变换,得到键和值的表示,并为后续的多头注意力计算做好准备。
执行流程图链接

函数代码

python 复制代码
    def set_kv(self, encoded_nodes):
        # encoded_nodes.shape: (batch, problem+1, embedding)
        head_num = self.model_params['head_num']

        self.k = reshape_by_heads(self.Wk(encoded_nodes), head_num=head_num)
        self.v = reshape_by_heads(self.Wv(encoded_nodes), head_num=head_num)
        # shape: (batch, head_num, problem+1, qkv_dim)
        self.single_head_key = encoded_nodes.transpose(1, 2)
        # shape: (batch, embedding, problem+1)

三、class CVRP_Decoder(nn.Module):set_q1(self, encoded_q1)

函数功能

set_q1 方法的主要功能是 计算查询(Q) 并将其转换为适用于多头注意力机制的形状。

该方法接受输入的查询张量 encoded_q1,通过线性层 self.Wq_1 映射到一个新的维度,并使用 reshape_by_heads 函数将其调整为适合多头注意力机制计算的形状。计算出的查询会被保存为类的属性 q1,供后续使用。

执行流程图链接

函数代码

python 复制代码
    def set_q1(self, encoded_q1):
        # encoded_q.shape: (batch, n, embedding)  # n can be 1 or pomo
        head_num = self.model_params['head_num']
        self.q1 = reshape_by_heads(self.Wq_1(encoded_q1), head_num=head_num)
        # shape: (batch, head_num, n, qkv_dim)

四、class CVRP_Decoder(nn.Module):set_q2(self, encoded_q2)

函数功能

set_q2 方法的主要功能是 计算查询(Q) 并将其转换为适用于多头注意力机制的形状。

该方法接收输入的查询张量 encoded_q2,通过线性层 self.Wq_2 映射到一个新的维度,并使用 reshape_by_heads 函数将其调整为适合多头注意力计算的形状。
执行流程图链接

函数代码

python 复制代码
    def set_q2(self, encoded_q2):
        # encoded_q.shape: (batch, n, embedding)  # n can be 1 or pomo
        head_num = self.model_params['head_num']
        self.q2 = reshape_by_heads(self.Wq_2(encoded_q2), head_num=head_num)
        # shape: (batch, head_num, n, qkv_dim)

五、class CVRP_Decoder(nn.Module):forward(self, encoded_last_node, load, ninf_mask)

函数功能

forward 方法是 CVRP_Decoder 类中的前向传播函数,主要功能是执行 多头自注意力机制 和 单头注意力计算,并最终输出每个可能节点的选择概率(probs)。

该方法通过多头注意力计算、前馈神经网络处理,以及概率计算来进行节点选择。

执行流程图链接

函数代码

python 复制代码
    def forward(self, encoded_last_node, load, ninf_mask):
        # encoded_last_node.shape: (batch, pomo, embedding)
        # load.shape: (batch, pomo)
        # ninf_mask.shape: (batch, pomo, problem)

        head_num = self.model_params['head_num']

        #  Multi-Head Attention
        #######################################################
        input_cat = torch.cat((encoded_last_node, load[:, :, None]), dim=2)
        # shape = (batch, group, EMBEDDING_DIM+1)

        q_last = reshape_by_heads(self.Wq_last(input_cat), head_num=head_num)
        # shape: (batch, head_num, pomo, qkv_dim)

        # q = self.q1 + self.q2 + q_last
        # # shape: (batch, head_num, pomo, qkv_dim)
        # q = q_last
        # shape: (batch, head_num, pomo, qkv_dim)
        q = self.q2 + q_last
        # # shape: (batch, head_num, pomo, qkv_dim)

        out_concat = multi_head_attention(q, self.k, self.v, rank3_ninf_mask=ninf_mask)
        # shape: (batch, pomo, head_num*qkv_dim)

        mh_atten_out = self.multi_head_combine(out_concat)
        # shape: (batch, pomo, embedding)

        #  Single-Head Attention, for probability calculation
        #######################################################
        score = torch.matmul(mh_atten_out, self.single_head_key)
        # shape: (batch, pomo, problem)

        sqrt_embedding_dim = self.model_params['sqrt_embedding_dim']
        logit_clipping = self.model_params['logit_clipping']

        score_scaled = score / sqrt_embedding_dim
        # shape: (batch, pomo, problem)

        score_clipped = logit_clipping * torch.tanh(score_scaled)

        score_masked = score_clipped + ninf_mask

        probs = F.softmax(score_masked, dim=2)
        # shape: (batch, pomo, problem)

        return probs

附录

class CVRP_Decoder代码(全)

python 复制代码
class CVRP_Decoder(nn.Module):
    def __init__(self, **model_params):
        super().__init__()
        self.model_params = model_params
        embedding_dim = self.model_params['embedding_dim']
        head_num = self.model_params['head_num']
        qkv_dim = self.model_params['qkv_dim']

        # self.Wq_1 = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False)
        self.Wq_2 = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False)
        self.Wq_last = nn.Linear(embedding_dim+1, head_num * qkv_dim, bias=False)
        self.Wk = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False)
        self.Wv = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False)

        self.regret_embedding = nn.Parameter(torch.Tensor(embedding_dim))
        self.regret_embedding.data.uniform_(-1, 1)

        self.multi_head_combine = nn.Linear(head_num * qkv_dim, embedding_dim)

        self.k = None  # saved key, for multi-head attention
        self.v = None  # saved value, for multi-head_attention
        self.single_head_key = None  # saved, for single-head attention
        # self.q1 = None  # saved q1, for multi-head attention
        self.q2 = None  # saved q2, for multi-head attention

    def set_kv(self, encoded_nodes):
        # encoded_nodes.shape: (batch, problem+1, embedding)
        head_num = self.model_params['head_num']

        self.k = reshape_by_heads(self.Wk(encoded_nodes), head_num=head_num)
        self.v = reshape_by_heads(self.Wv(encoded_nodes), head_num=head_num)
        # shape: (batch, head_num, problem+1, qkv_dim)
        self.single_head_key = encoded_nodes.transpose(1, 2)
        # shape: (batch, embedding, problem+1)

    def set_q1(self, encoded_q1):
        # encoded_q.shape: (batch, n, embedding)  # n can be 1 or pomo
        head_num = self.model_params['head_num']
        self.q1 = reshape_by_heads(self.Wq_1(encoded_q1), head_num=head_num)
        # shape: (batch, head_num, n, qkv_dim)

    def set_q2(self, encoded_q2):
        # encoded_q.shape: (batch, n, embedding)  # n can be 1 or pomo
        head_num = self.model_params['head_num']
        self.q2 = reshape_by_heads(self.Wq_2(encoded_q2), head_num=head_num)
        # shape: (batch, head_num, n, qkv_dim)

    def forward(self, encoded_last_node, load, ninf_mask):
        # encoded_last_node.shape: (batch, pomo, embedding)
        # load.shape: (batch, pomo)
        # ninf_mask.shape: (batch, pomo, problem)

        head_num = self.model_params['head_num']

        #  Multi-Head Attention
        #######################################################
        input_cat = torch.cat((encoded_last_node, load[:, :, None]), dim=2)
        # shape = (batch, group, EMBEDDING_DIM+1)

        q_last = reshape_by_heads(self.Wq_last(input_cat), head_num=head_num)
        # shape: (batch, head_num, pomo, qkv_dim)

        # q = self.q1 + self.q2 + q_last
        # # shape: (batch, head_num, pomo, qkv_dim)
        # q = q_last
        # shape: (batch, head_num, pomo, qkv_dim)
        q = self.q2 + q_last
        # # shape: (batch, head_num, pomo, qkv_dim)

        out_concat = multi_head_attention(q, self.k, self.v, rank3_ninf_mask=ninf_mask)
        # shape: (batch, pomo, head_num*qkv_dim)

        mh_atten_out = self.multi_head_combine(out_concat)
        # shape: (batch, pomo, embedding)

        #  Single-Head Attention, for probability calculation
        #######################################################
        score = torch.matmul(mh_atten_out, self.single_head_key)
        # shape: (batch, pomo, problem)

        sqrt_embedding_dim = self.model_params['sqrt_embedding_dim']
        logit_clipping = self.model_params['logit_clipping']

        score_scaled = score / sqrt_embedding_dim
        # shape: (batch, pomo, problem)

        score_clipped = logit_clipping * torch.tanh(score_scaled)

        score_masked = score_clipped + ninf_mask

        probs = F.softmax(score_masked, dim=2)
        # shape: (batch, pomo, problem)

        return probs
相关推荐
摇滚侠4 分钟前
Spring Boot 3零基础教程,WEB 开发 通过配置类代码方式修改静态资源配置 笔记32
java·spring boot·笔记
聪明的笨猪猪2 小时前
Java JVM “内存(1)”面试清单(含超通俗生活案例与深度理解)
java·经验分享·笔记·面试
_dindong2 小时前
Linux网络编程:Socket编程TCP
linux·服务器·网络·笔记·学习·tcp/ip
摇滚侠3 小时前
Spring Boot 3零基础教程,WEB 开发 Thymeleaf 属性优先级 行内写法 变量选择 笔记42
java·spring boot·笔记
摇滚侠3 小时前
Spring Boot 3零基础教程,WEB 开发 Thymeleaf 总结 热部署 常用配置 笔记44
java·spring boot·笔记
rechol3 小时前
汇编与底层编程笔记
汇编·arm开发·笔记
lzj_pxxw4 小时前
嵌入式开发技巧:舍弃标志位,用宏定义函数实现程序单次运行
笔记·stm32·单片机·嵌入式硬件·学习
润 下5 小时前
C语言——回调函数的典型示例(分析详解)
c语言·开发语言·人工智能·经验分享·笔记·程序人生
朝新_5 小时前
【EE初阶 - 网络原理】传输层协议
java·开发语言·网络·笔记·javaee
koo3645 小时前
李宏毅机器学习笔记27
人工智能·笔记·机器学习