```python
import math
from typing import Union, Tuple, Optional
from torch_geometric.typing import PairTensor, Adj, OptTensor
import torch
from torch import Tensor
import torch.nn.functional as F
from torch_sparse import SparseTensor
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.utils import softmax
class TransformerConv(MessagePassing):
r"""The graph transformer operator from the `"Masked Label Prediction:
Unified Message Passing Model for Semi-Supervised Classification"
<https://arxiv.org/abs/2009.03509>`_ paper
.. math::
\mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i +
\sum_{j \in \mathcal{N}(i)} \alpha_{i,j} \mathbf{W}_2 \mathbf{x}_{j},
where the attention coefficients :math:`\alpha_{i,j}` are computed via
multi-head dot product attention:
.. math::
\alpha_{i,j} = \textrm{softmax} \left(
\frac{(\mathbf{W}_3\mathbf{x}_i)^{\top} (\mathbf{W}_4\mathbf{x}_j)}
{\sqrt{d}} \right)
Args:
in_channels (int or tuple): Size of each input sample, or :obj:`-1` to
derive the size from the first input(s) to the forward method.
A tuple corresponds to the sizes of source and target
dimensionalities.
out_channels (int): Size of each output sample.
heads (int, optional): Number of multi-head-attentions.
(default: :obj:`1`)
concat (bool, optional): If set to :obj:`False`, the multi-head
attentions are averaged instead of concatenated.
(default: :obj:`True`)
beta (bool, optional): If set, will combine aggregation and
skip information via
.. math::
\mathbf{x}^{\prime}_i = \beta_i \mathbf{W}_1 \mathbf{x}_i +
(1 - \beta_i) \underbrace{\left(\sum_{j \in \mathcal{N}(i)}
\alpha_{i,j} \mathbf{W}_2 \mathbf{x}_j \right)}_{=\mathbf{m}_i}
with :math:`\beta_i = \textrm{sigmoid}(\mathbf{w}_5^{\top}
[ \mathbf{x}_i, \mathbf{m}_i, \mathbf{x}_i - \mathbf{m}_i ])`
(default: :obj:`False`)
dropout (float, optional): Dropout probability of the normalized
attention coefficients which exposes each node to a stochastically
sampled neighborhood during training. (default: :obj:`0`)
edge_dim (int, optional): Edge feature dimensionality (in case
there are any). Edge features are added to the keys after
linear transformation, that is, prior to computing the
attention dot product. They are also added to final values
after the same linear transformation. The model is:
.. math::
\mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i +
\sum_{j \in \mathcal{N}(i)} \alpha_{i,j} \left(
\mathbf{W}_2 \mathbf{x}_{j} + \mathbf{W}_6 \mathbf{e}_{ij}
\right),
where the attention coefficients :math:`\alpha_{i,j}` are now
computed via:
.. math::
\alpha_{i,j} = \textrm{softmax} \left(
\frac{(\mathbf{W}_3\mathbf{x}_i)^{\top}
(\mathbf{W}_4\mathbf{x}_j + \mathbf{W}_6 \mathbf{e}_{ij})}
{\sqrt{d}} \right)
(default :obj:`None`)
bias (bool, optional): If set to :obj:`False`, the layer will not learn
an additive bias. (default: :obj:`True`)
root_weight (bool, optional): If set to :obj:`False`, the layer will
not add the transformed root node features to the output and the
option :attr:`beta` is set to :obj:`False`. (default: :obj:`True`)
**kwargs (optional): Additional arguments of
:class:`torch_geometric.nn.conv.MessagePassing`.
"""
_alpha: OptTensor
def __init__(
self,
in_channels: Union[int, Tuple[int, int]],
out_channels: int,
heads: int = 1,
concat: bool = True,
beta: bool = False,
dropout: float = 0.,
edge_dim: Optional[int] = None,
bias: bool = True,
root_weight: bool = True,
**kwargs,
):
kwargs.setdefault('aggr', 'add')
super(TransformerConv, self).__init__(node_dim=0, **kwargs)
self.in_channels = in_channels
self.out_channels = out_channels
self.heads = heads
self.beta = beta and root_weight
self.root_weight = root_weight
self.concat = concat
self.dropout = dropout
self.edge_dim = edge_dim
if isinstance(in_channels, int):
in_channels = (in_channels, in_channels)
self.lin_key = Linear(in_channels[0], heads * out_channels)
self.lin_query = Linear(in_channels[1], heads * out_channels)
self.lin_value = Linear(in_channels[0], heads * out_channels)
if edge_dim is not None:
self.lin_edge = Linear(edge_dim, heads * out_channels, bias=False)
else:
self.lin_edge = self.register_parameter('lin_edge', None)
if concat:
self.lin_skip = Linear(in_channels[1], heads * out_channels,
bias=bias)
if self.beta:
self.lin_beta = Linear(3 * heads * out_channels, 1, bias=False)
else:
self.lin_beta = self.register_parameter('lin_beta', None)
else:
self.lin_skip = Linear(in_channels[1], out_channels, bias=bias)
if self.beta:
self.lin_beta = Linear(3 * out_channels, 1, bias=False)
else:
self.lin_beta = self.register_parameter('lin_beta', None)
self.reset_parameters()
def reset_parameters(self):
self.lin_key.reset_parameters()
self.lin_query.reset_parameters()
self.lin_value.reset_parameters()
if self.edge_dim:
self.lin_edge.reset_parameters()
self.lin_skip.reset_parameters()
if self.beta:
self.lin_beta.reset_parameters()
def forward(self, x: Union[Tensor, PairTensor], edge_index: Adj,
edge_attr: OptTensor = None, return_attention_weights=None):
# type: (Union[Tensor, PairTensor], Tensor, OptTensor, NoneType) -> Tensor # noqa
# type: (Union[Tensor, PairTensor], SparseTensor, OptTensor, NoneType) -> Tensor # noqa
# type: (Union[Tensor, PairTensor], Tensor, OptTensor, bool) -> Tuple[Tensor, Tuple[Tensor, Tensor]] # noqa
# type: (Union[Tensor, PairTensor], SparseTensor, OptTensor, bool) -> Tuple[Tensor, SparseTensor] # noqa
r"""
Args:
return_attention_weights (bool, optional): If set to :obj:`True`,
will additionally return the tuple
:obj:`(edge_index, attention_weights)`, holding the computed
attention weights for each edge. (default: :obj:`None`)
"""
if isinstance(x, Tensor):
x: PairTensor = (x, x)
# propagate_type: (x: PairTensor, edge_attr: OptTensor)
out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=None)
alpha = self._alpha
self._alpha = None
if self.concat:
out = out.view(-1, self.heads * self.out_channels)
else:
out = out.mean(dim=1)
if self.root_weight:
x_r = self.lin_skip(x[1])
if self.lin_beta is not None:
beta = self.lin_beta(torch.cat([out, x_r, out - x_r], dim=-1))
beta = beta.sigmoid()
out = beta * x_r + (1 - beta) * out
else:
out += x_r
if isinstance(return_attention_weights, bool):
assert alpha is not None
if isinstance(edge_index, Tensor):
return out, (edge_index, alpha)
elif isinstance(edge_index, SparseTensor):
return out, edge_index.set_value(alpha, layout='coo')
else:
return out
def message(self, x_i: Tensor, x_j: Tensor, edge_attr: OptTensor,
index: Tensor, ptr: OptTensor,
size_i: Optional[int]) -> Tensor:
query = self.lin_query(x_i).view(-1, self.heads, self.out_channels)
key = self.lin_key(x_j).view(-1, self.heads, self.out_channels)
if self.lin_edge is not None:
assert edge_attr is not None
edge_attr = self.lin_edge(edge_attr).view(-1, self.heads,
self.out_channels)
key += edge_attr
alpha = (query * key).sum(dim=-1) / math.sqrt(self.out_channels)
alpha = softmax(alpha, index, ptr, size_i)
self._alpha = alpha
alpha = F.dropout(alpha, p=self.dropout, training=self.training)
out = self.lin_value(x_j).view(-1, self.heads, self.out_channels)
if edge_attr is not None:
out += edge_attr
out *= alpha.view(-1, self.heads, 1)
return out
def __repr__(self):
return '{}({}, {}, heads={})'.format(self.__class__.__name__,
self.in_channels,
self.out_channels, self.heads)
```
While debugging, I stepped into the actual function in my local Python installation and found that it still differs a bit from the official version. I think I can finally understand it completely now!!
The code I ran into looks like this:
```python
self.gc1 = tg.nn.TransformerConv(params.feat_hidden2, params.gcn_hidden1, heads=1, dropout=params.p_drop)
```
Here params.feat_hidden2 = 128 and params.gcn_hidden1 = 128, which correspond to the built-in signature in_channels: Union[int, Tuple[int, int]] and out_channels: int.
After instantiation, it is called like this:
```python
hidden1, atten = self.gc1(feat_x, adj, return_attention_weights=True)
```
feat_x and adj are passed into forward, corresponding to x: Union[Tensor, PairTensor] and edge_index: Adj respectively, i.e. the node features x and the edge index.
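To make this call concrete, here is a minimal self-contained sketch; the toy graph, the feature size 128, the dropout value 0.1 and the names feat_x / adj are made-up stand-ins for the project's params, not values from the original code:
```python
import torch
import torch_geometric as tg

# Toy graph: 3 nodes with 128-dim features, 4 directed edges.
feat_x = torch.randn(3, 128)
adj = torch.tensor([[0, 1, 1, 2],   # row 0: source nodes
                    [1, 0, 2, 0]])  # row 1: target nodes

gc1 = tg.nn.TransformerConv(128, 128, heads=1, dropout=0.1)

# With return_attention_weights=True the layer also returns (edge_index, alpha).
hidden1, (edge_index, alpha) = gc1(feat_x, adj, return_attention_weights=True)

print(hidden1.shape)  # torch.Size([3, 128]) -> heads * out_channels per node
print(alpha.shape)    # torch.Size([4, 1])   -> one attention weight per edge and head
```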
```python
out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=None)
```
The most important piece here is the self.propagate function. For details, you can refer to these blog posts:
PyG的MessagePassing基类中self.propagate函数的消息传递机制-CSDN博客
Torch geometric GCNConv 源码分析-CSDN博客
https://zhuanlan.zhihu.com/p/113862170
The main point to note here: inside self.propagate, the self.message, self.aggregate and self.update functions are called automatically. Below I will focus on self.message (its code is quoted again right after the short sketch below).
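To see that call chain concretely, here is a tiny toy MessagePassing layer (purely illustrative, not part of TransformerConv) that just sums neighbor features and prints when each hook fires:
```python
import torch
from torch_geometric.nn import MessagePassing

class EchoConv(MessagePassing):
    """Toy layer: propagate() calls message -> aggregate -> update in order."""
    def __init__(self):
        super().__init__(aggr='add')  # aggregate() will sum messages per target node

    def forward(self, x, edge_index):
        return self.propagate(edge_index, x=x)

    def message(self, x_j):
        print('message:', x_j.shape)      # one row per edge
        return x_j

    def update(self, aggr_out):
        print('update:', aggr_out.shape)  # one row per node
        return aggr_out

x = torch.randn(3, 4)
edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])
EchoConv()(x, edge_index)  # prints message: [3, 4], then update: [3, 4]
```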
```python
def message(self, x_i: Tensor, x_j: Tensor, edge_attr: OptTensor,
index: Tensor, ptr: OptTensor,
size_i: Optional[int]) -> Tensor:
query = self.lin_query(x_i).view(-1, self.heads, self.out_channels)
key = self.lin_key(x_j).view(-1, self.heads, self.out_channels)
if self.lin_edge is not None:
assert edge_attr is not None
edge_attr = self.lin_edge(edge_attr).view(-1, self.heads,
self.out_channels)
key += edge_attr
alpha = (query * key).sum(dim=-1) / math.sqrt(self.out_channels)
alpha = softmax(alpha, index, ptr, size_i)
self._alpha = alpha
alpha = F.dropout(alpha, p=self.dropout, training=self.training)
out = self.lin_value(x_j).view(-1, self.heads, self.out_channels)
if edge_attr is not None:
out += edge_attr
out *= alpha.view(-1, self.heads, 1)
return out
```
This is how it is given in the Python source.
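As a sanity check of these lines, here is a sketch that recomputes alpha by hand for a made-up 3-node graph with heads=1 and no edge features (so the lin_edge branch is skipped); none of the sizes come from the original project:
```python
import math
import torch
from torch_geometric.nn import TransformerConv
from torch_geometric.utils import softmax

conv = TransformerConv(16, 16, heads=1).eval()  # eval() so dropout is a no-op

x = torch.randn(3, 16)
edge_index = torch.tensor([[0, 1, 2],   # sources j
                           [1, 2, 0]])  # targets i

_, (_, alpha) = conv(x, edge_index, return_attention_weights=True)

# Recompute: the query comes from the target node i, the key from the source node j.
q = conv.lin_query(x)[edge_index[1]]           # [E, out_channels]
k = conv.lin_key(x)[edge_index[0]]             # [E, out_channels]
logits = (q * k).sum(dim=-1) / math.sqrt(16)   # scaled dot product per edge
alpha_manual = softmax(logits, edge_index[1])  # normalize over each target's edges

print(torch.allclose(alpha.squeeze(-1), alpha_manual, atol=1e-6))  # True
```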
Now let's look at how the arguments of propagate map to those of message.
propagate:
edge_index and x=x: the former, edge_index, is the graph's edge index, a tensor of shape 2 × E, where E is the number of edges. It is read column by column, so pairs such as (0,1), (1,0), (1,2), (2,0) each define a separate directed edge.
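For example, for a made-up graph with exactly those four edges:
```python
import torch

# Each column is one directed edge: 0->1, 1->0, 1->2, 2->0.
edge_index = torch.tensor([[0, 1, 1, 2],   # row 0: source nodes
                           [1, 0, 2, 0]])  # row 1: target nodes
print(edge_index.shape)  # torch.Size([2, 4]) -> 2 x E with E = 4 edges
```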
- x_i: the feature tensor of the target (central) node of each edge. Inside propagate it is gathered from x; with the default flow='source_to_target', x_i = x[edge_index[1]].
- x_j: the feature tensor of the source (neighbor) node of each edge, gathered the same way, i.e. x_j = x[edge_index[0]].
- edge_attr: the edge feature tensor, corresponding directly to the edge_attr argument.
- index: an index tensor recording which target node each message belongs to (effectively edge_index[1]); softmax uses it to normalize the attention coefficients over each node's incoming edges. PyTorch Geometric handles this automatically inside propagate.

message finally returns out. After that there is nothing special left in gc1, just simple dimension reshaping and the like.
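A small sketch of that gathering step on the same made-up graph (inside the real propagate this indexing happens internally, controlled by node_dim=0):
```python
import torch

x = torch.randn(3, 8)                      # 3 nodes, 8-dim features
edge_index = torch.tensor([[0, 1, 1, 2],   # sources j
                           [1, 0, 2, 0]])  # targets i

x_j = x[edge_index[0]]   # source-node features, one row per edge -> [4, 8]
x_i = x[edge_index[1]]   # target-node features, one row per edge -> [4, 8]
index = edge_index[1]    # softmax normalizes attention over each target's incoming edges
```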
- I feel like I've finally figured this function out. It wasn't easy!