python
import math
import typing
from typing import Optional, Tuple, Union
import torch
import torch.nn.functional as F
from torch import Tensor
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.typing import (
Adj,
NoneType,
OptTensor,
PairTensor,
SparseTensor,
)
from torch_geometric.utils import softmax
# Static type checkers get the standard `typing.overload`; at runtime the
# TorchScript-aware overload decorator from `torch.jit` is used instead.
if typing.TYPE_CHECKING:
    from typing import overload
else:
    from torch.jit import _overload_method as overload
class TransformerConv(MessagePassing):
    r"""The graph transformer operator from the `"Masked Label Prediction:
    Unified Message Passing Model for Semi-Supervised Classification"
    <https://arxiv.org/abs/2009.03509>`_ paper.

    .. math::
        \mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i +
        \sum_{j \in \mathcal{N}(i)} \alpha_{i,j} \mathbf{W}_2 \mathbf{x}_{j},

    where the attention coefficients :math:`\alpha_{i,j}` are computed via
    multi-head dot product attention:

    .. math::
        \alpha_{i,j} = \textrm{softmax} \left(
        \frac{(\mathbf{W}_3\mathbf{x}_i)^{\top} (\mathbf{W}_4\mathbf{x}_j)}
        {\sqrt{d}} \right)

    Args:
        in_channels (int or tuple): Size of each input sample, or :obj:`-1` to
            derive the size from the first input(s) to the forward method.
            A tuple corresponds to the sizes of source and target
            dimensionalities.
        out_channels (int): Size of each output sample.
        heads (int, optional): Number of multi-head-attentions.
            (default: :obj:`1`)
        concat (bool, optional): If set to :obj:`False`, the multi-head
            attentions are averaged instead of concatenated.
            (default: :obj:`True`)
        beta (bool, optional): If set, will combine aggregation and
            skip information via

            .. math::
                \mathbf{x}^{\prime}_i = \beta_i \mathbf{W}_1 \mathbf{x}_i +
                (1 - \beta_i) \underbrace{\left(\sum_{j \in \mathcal{N}(i)}
                \alpha_{i,j} \mathbf{W}_2 \vec{x}_j \right)}_{=\mathbf{m}_i}

            with :math:`\beta_i = \textrm{sigmoid}(\mathbf{w}_5^{\top}
            [ \mathbf{W}_1 \mathbf{x}_i, \mathbf{m}_i, \mathbf{W}_1
            \mathbf{x}_i - \mathbf{m}_i ])` (default: :obj:`False`)
        dropout (float, optional): Dropout probability of the normalized
            attention coefficients which exposes each node to a stochastically
            sampled neighborhood during training. (default: :obj:`0`)
        edge_dim (int, optional): Edge feature dimensionality (in case
            there are any). Edge features are added to the keys after
            linear transformation, that is, prior to computing the
            attention dot product. They are also added to final values
            after the same linear transformation. The model is:

            .. math::
                \mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i +
                \sum_{j \in \mathcal{N}(i)} \alpha_{i,j} \left(
                \mathbf{W}_2 \mathbf{x}_{j} + \mathbf{W}_6 \mathbf{e}_{ij}
                \right),

            where the attention coefficients :math:`\alpha_{i,j}` are now
            computed via:

            .. math::
                \alpha_{i,j} = \textrm{softmax} \left(
                \frac{(\mathbf{W}_3\mathbf{x}_i)^{\top}
                (\mathbf{W}_4\mathbf{x}_j + \mathbf{W}_6 \mathbf{e}_{ij})}
                {\sqrt{d}} \right)

            (default :obj:`None`)
        bias (bool, optional): If set to :obj:`False`, the layer will not learn
            an additive bias. (default: :obj:`True`)
        root_weight (bool, optional): If set to :obj:`False`, the layer will
            not add the transformed root node features to the output and the
            option :attr:`beta` is set to :obj:`False`. (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.
    """
    # Scratch slot: `message()` stores the normalized attention here and
    # `forward()` reads it back and clears it after propagation.
    _alpha: OptTensor

    def __init__(
        self,
        in_channels: Union[int, Tuple[int, int]],
        out_channels: int,
        heads: int = 1,
        concat: bool = True,
        beta: bool = False,
        dropout: float = 0.,
        edge_dim: Optional[int] = None,
        bias: bool = True,
        root_weight: bool = True,
        **kwargs,
    ):
        kwargs.setdefault('aggr', 'add')
        super().__init__(node_dim=0, **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        # Gating only makes sense when the skip connection exists.
        self.beta = beta and root_weight
        self.root_weight = root_weight
        self.concat = concat
        self.dropout = dropout
        self.edge_dim = edge_dim
        self._alpha = None

        if isinstance(in_channels, int):
            in_channels = (in_channels, in_channels)

        # Keys and values are computed from source features (index 0),
        # queries from target features (index 1).
        self.lin_key = Linear(in_channels[0], heads * out_channels)
        self.lin_query = Linear(in_channels[1], heads * out_channels)
        self.lin_value = Linear(in_channels[0], heads * out_channels)

        if edge_dim is not None:
            self.lin_edge = Linear(edge_dim, heads * out_channels, bias=False)
        else:
            self.lin_edge = self.register_parameter('lin_edge', None)

        # The skip path must match the layer output width, which depends on
        # whether heads are concatenated or averaged.
        skip_channels = heads * out_channels if concat else out_channels
        self.lin_skip = Linear(in_channels[1], skip_channels, bias=bias)
        if self.beta:
            self.lin_beta = Linear(3 * skip_channels, 1, bias=False)
        else:
            self.lin_beta = self.register_parameter('lin_beta', None)

        self.reset_parameters()

    def reset_parameters(self):
        r"""Resets the parameters of every linear layer owned by this
        module.
        """
        super().reset_parameters()
        self.lin_key.reset_parameters()
        self.lin_query.reset_parameters()
        self.lin_value.reset_parameters()
        # Guard on the layer itself: `__init__` creates `lin_edge` whenever
        # `edge_dim is not None`, so testing the truthiness of `edge_dim`
        # would disagree with that for `edge_dim=0`.
        if self.lin_edge is not None:
            self.lin_edge.reset_parameters()
        self.lin_skip.reset_parameters()
        if self.lin_beta is not None:
            self.lin_beta.reset_parameters()

    @overload
    def forward(
        self,
        x: Union[Tensor, PairTensor],
        edge_index: Adj,
        edge_attr: OptTensor = None,
        return_attention_weights: NoneType = None,
    ) -> Tensor:
        pass

    @overload
    def forward(  # noqa: F811
        self,
        x: Union[Tensor, PairTensor],
        edge_index: Tensor,
        edge_attr: OptTensor = None,
        return_attention_weights: bool = None,
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        pass

    @overload
    def forward(  # noqa: F811
        self,
        x: Union[Tensor, PairTensor],
        edge_index: SparseTensor,
        edge_attr: OptTensor = None,
        return_attention_weights: bool = None,
    ) -> Tuple[Tensor, SparseTensor]:
        pass

    def forward(  # noqa: F811
        self,
        x: Union[Tensor, PairTensor],
        edge_index: Adj,
        edge_attr: OptTensor = None,
        return_attention_weights: Optional[bool] = None,
    ) -> Union[
            Tensor,
            Tuple[Tensor, Tuple[Tensor, Tensor]],
            Tuple[Tensor, SparseTensor],
    ]:
        r"""Runs the forward pass of the module.

        Args:
            x (torch.Tensor or (torch.Tensor, torch.Tensor)): The input node
                features.
            edge_index (torch.Tensor or SparseTensor): The edge indices.
            edge_attr (torch.Tensor, optional): The edge features.
                (default: :obj:`None`)
            return_attention_weights (bool, optional): If set to :obj:`True`,
                will additionally return the tuple
                :obj:`(edge_index, attention_weights)`, holding the computed
                attention weights for each edge. (default: :obj:`None`)

        Returns:
            The output node features, with last dimension
            :obj:`heads * out_channels` if :attr:`concat` is set and
            :obj:`out_channels` otherwise. If
            :obj:`return_attention_weights` is a :obj:`bool`, a tuple is
            returned instead: :obj:`(out, (edge_index, alpha))` for a
            :class:`torch.Tensor` :obj:`edge_index`, or
            :obj:`(out, edge_index.set_value(alpha, layout='coo'))` for a
            :class:`SparseTensor`. The check is on the type, so passing
            :obj:`False` also yields the tuple.
        """
        H, C = self.heads, self.out_channels

        if isinstance(x, Tensor):
            x = (x, x)

        # Split the projected features into one slice of size C per head.
        query = self.lin_query(x[1]).view(-1, H, C)
        key = self.lin_key(x[0]).view(-1, H, C)
        value = self.lin_value(x[0]).view(-1, H, C)

        # propagate_type: (query: Tensor, key:Tensor, value: Tensor,
        #                  edge_attr: OptTensor)
        out = self.propagate(edge_index, query=query, key=key, value=value,
                             edge_attr=edge_attr)

        # Take the attention stored by `message()` and clear the slot so it
        # does not outlive this call.
        alpha = self._alpha
        self._alpha = None

        if self.concat:
            out = out.view(-1, self.heads * self.out_channels)
        else:
            out = out.mean(dim=1)

        if self.root_weight:
            x_r = self.lin_skip(x[1])
            if self.lin_beta is not None:
                # NOTE: the gate input is ordered [out, x_r, out - x_r],
                # i.e. aggregated messages first, whereas the class
                # docstring lists the skip term first.
                beta = self.lin_beta(torch.cat([out, x_r, out - x_r], dim=-1))
                beta = beta.sigmoid()
                out = beta * x_r + (1 - beta) * out
            else:
                out = out + x_r

        if isinstance(return_attention_weights, bool):
            assert alpha is not None
            if isinstance(edge_index, Tensor):
                return out, (edge_index, alpha)
            elif isinstance(edge_index, SparseTensor):
                return out, edge_index.set_value(alpha, layout='coo')
        else:
            return out

    def message(self, query_i: Tensor, key_j: Tensor, value_j: Tensor,
                edge_attr: OptTensor, index: Tensor, ptr: OptTensor,
                size_i: Optional[int]) -> Tensor:
        r"""Computes the attention-weighted message for every edge.

        The scaled dot product of :obj:`query_i` and :obj:`key_j` is
        normalized with :func:`torch_geometric.utils.softmax` using
        :obj:`index`, :obj:`ptr` and :obj:`size_i`, stored in
        :obj:`self._alpha`, passed through dropout and used to scale
        :obj:`value_j`.
        """
        if self.lin_edge is not None:
            assert edge_attr is not None
            edge_attr = self.lin_edge(edge_attr).view(-1, self.heads,
                                                      self.out_channels)
            key_j = key_j + edge_attr

        alpha = (query_i * key_j).sum(dim=-1) / math.sqrt(self.out_channels)
        alpha = softmax(alpha, index, ptr, size_i)
        # Stored before dropout, so `forward()` hands back the undropped
        # attention weights.
        self._alpha = alpha
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)

        out = value_j
        # NOTE: when `lin_edge` is None but `edge_attr` is given, the raw
        # (unprojected) edge features are added here.
        if edge_attr is not None:
            out = out + edge_attr

        out = out * alpha.view(-1, self.heads, 1)
        return out

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, heads={self.heads})')
我感觉,我是不是学代码天生比别人慢一拍,就是其实很简单的函数也要花好久才能明白。我大哭。
之前在跑别人代码的时候,遇到了hidden这里的维度不匹配,就返回去找维度。
这里的hidden1到底是什么?
```python
hidden1=(1-self.at)*(torch.mm(atten,hidden1))+self.at*(torch.mm(atten_prue,hidden1_prue))
```
维度不匹配只能是atten和hidden1的维度不匹配了,那atten和hidden1到底是什么?
```python
hidden1,atten = self.gc1(feat_x, adj,return_attention_weights=True)
```
这里的hidden1对应的是上一篇文章中forward的输出结果,也就是out,那么out具体是什么?
然后return_attention_weights设置的是True:
- 如果 `return_attention_weights=True`,并且 `edge_index` 是普通张量(`Tensor`)类型,则返回 `(out, (edge_index, alpha))`,其中 `alpha` 是计算得到的注意力权重。
- 如果 `return_attention_weights=True`,并且 `edge_index` 是稀疏张量(`SparseTensor`)类型,则返回 `(out, edge_index.set_value(alpha, layout='coo'))`,稀疏张量包含了注意力权重的值。
- 如果 `return_attention_weights` 不是布尔值(例如默认的 `None`),则仅返回 `out`。
```python
self.gc1 = tg.nn.TransformerConv(params.feat_hidden2, params.gcn_hidden1, heads=1, dropout=params.p_drop)
```