DGL Source Code Analysis: GAT

Graph Attention Network (GAT) paper: https://arxiv.org/pdf/1710.10903.pdf

A deep dive into the attention mechanism.

1 Class definition and parameters

class GATConv(nn.Module):
    r"""
    Parameters
    ----------
    in_feats : int, or pair of ints
        Input feature size; i.e., the number of dimensions of :math:`h_i^{(l)}`.
        GATConv can be applied on homogeneous graph and unidirectional
        `bipartite graph <https://docs.dgl.ai/generated/dgl.bipartite.html?highlight=bipartite>`__.
        If the layer is to be applied to a unidirectional bipartite graph, ``in_feats``
        specifies the input feature size on both the source and destination nodes. If
        a scalar is given, the source and destination node feature size would take the
        same value.
    out_feats : int
        Output feature size; i.e., the number of dimensions of :math:`h_i^{(l+1)}`.
    num_heads : int
        Number of heads in Multi-Head Attention.
    feat_drop : float, optional
        Dropout rate on feature. Defaults: ``0``.
    attn_drop : float, optional
        Dropout rate on attention weight. Defaults: ``0``.
    negative_slope : float, optional
        LeakyReLU angle of negative slope. Defaults: ``0.2``.
    residual : bool, optional
        If True, use residual connection. Defaults: ``False``.
    activation : callable activation function/layer or None, optional.
        If not None, applies an activation function to the updated node features.
        Default: ``None``.
    allow_zero_in_degree : bool, optional
        If there are 0-in-degree nodes in the graph, output for those nodes will be invalid
        since no message will be passed to those nodes. This is harmful for some applications,
        causing silent performance regression. This module will raise a DGLError if it detects
        0-in-degree nodes in the input graph. By setting ``True``, it will suppress the check
        and let the users handle it by themselves. Defaults: ``False``.

    Note
    ----
    Zero in-degree nodes will lead to invalid output value. This is because no message
    will be passed to those nodes, and the aggregation function will be applied on empty input.
    A common practice to avoid this is to add a self-loop for each node in the graph if
    it is homogeneous, which can be achieved by:

    >>> g = ... # a DGLGraph
    >>> g = dgl.add_self_loop(g)
    """

2 __init__() function

    def __init__(self,
                 in_feats,
                 out_feats,
                 num_heads,
                 feat_drop=0.,
                 attn_drop=0.,
                 negative_slope=0.2,
                 residual=False,
                 activation=None,
                 allow_zero_in_degree=False):
        super(GATConv, self).__init__()
        self._num_heads = num_heads

        # If in_feats is a single scalar, expand it into identical _in_src_feats and
        # _in_dst_feats so that both cases can be handled uniformly later on.
        self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)

        self._out_feats = out_feats
        self._allow_zero_in_degree = allow_zero_in_degree

        # If in_feats is a tuple (unidirectional bipartite graph), use separate linear
        # transforms for the source and destination node features.
        if isinstance(in_feats, tuple):
            self.fc_src = nn.Linear(
                self._in_src_feats, out_feats * num_heads, bias=False)
            self.fc_dst = nn.Linear(
                self._in_dst_feats, out_feats * num_heads, bias=False)

        # Otherwise (homogeneous graph), define a single linear transform on the node
        # features with output dimension out_feats * num_heads and no bias.
        else:
            self.fc = nn.Linear(
                self._in_src_feats, out_feats * num_heads, bias=False)

        self.attn_l = nn.Parameter(th.FloatTensor(size=(1, num_heads, out_feats)))
        self.attn_r = nn.Parameter(th.FloatTensor(size=(1, num_heads, out_feats)))

        self.feat_drop = nn.Dropout(feat_drop)
        self.attn_drop = nn.Dropout(attn_drop)
        self.leaky_relu = nn.LeakyReLU(negative_slope)
        if residual:
            if self._in_dst_feats != out_feats:
                self.res_fc = nn.Linear(
                    self._in_dst_feats, num_heads * out_feats, bias=False)
            else:
                self.res_fc = Identity()
        else:
            self.register_buffer('res_fc', None)
        self.reset_parameters()
        self.activation = activation
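As a side note, expand_as_pair simply normalizes the two accepted forms of in_feats. A small sketch, assuming the dgl.utils.expand_as_pair import path used inside DGL:

from dgl.utils import expand_as_pair

# Scalar input: source and destination feature sizes take the same value.
print(expand_as_pair(16))         # (16, 16)

# Tuple input (e.g. unidirectional bipartite graph): the pair is returned as-is.
print(expand_as_pair((16, 32)))   # (16, 32)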

3 forward() function

    def forward(self, graph, feat, get_attention=False):
        r"""

        Description
        -----------
        Compute graph attention network layer.

        Parameters
        ----------
        graph : DGLGraph
            The graph.
        feat : torch.Tensor or pair of torch.Tensor
            If a torch.Tensor is given, the input feature of shape :math:`(N, D_{in})` where
            :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.
            If a pair of torch.Tensor is given, the pair must contain two tensors of shape
            :math:`(N_{in}, D_{in_{src}})` and :math:`(N_{out}, D_{in_{dst}})`.
        get_attention : bool, optional
            Whether to return the attention values. Defaults to ``False``.

        Returns
        -------
        torch.Tensor
            The output feature of shape :math:`(N, H, D_{out})` where :math:`H`
            is the number of heads, and :math:`D_{out}` is size of output feature.
            Note that the heads are returned as-is; they are not concatenated here.
        torch.Tensor, optional
            The attention values of shape :math:`(E, H, 1)`, where :math:`E` is the number of
            edges. This is returned only when :attr:`get_attention` is ``True``.

        Raises
        ------
        DGLError
            If there are 0-in-degree nodes in the input graph, it will raise DGLError
            since no message will be passed to those nodes. This will cause invalid output.
            The error can be ignored by setting ``allow_zero_in_degree`` parameter to ``True``.
        """
        with graph.local_scope():
            if not self._allow_zero_in_degree:
                # any() returns True if at least one element of the iterable is True,
                # and False only if all elements are False.
                if (graph.in_degrees() == 0).any():
                    raise DGLError('There are 0-in-degree nodes in the graph, '
                                   'output for those nodes will be invalid. '
                                   'This is harmful for some applications, '
                                   'causing silent performance regression. '
                                   'Adding self-loop on the input graph by '
                                   'calling `g = dgl.add_self_loop(g)` will resolve '
                                   'the issue. Setting ``allow_zero_in_degree`` '
                                   'to be `True` when constructing this module will '
                                   'suppress the check and let the code run.')

            if isinstance(feat, tuple):
                h_src = self.feat_drop(feat[0])
                h_dst = self.feat_drop(feat[1])
                if not hasattr(self, 'fc_src'):
                    feat_src = self.fc(h_src).view(-1, self._num_heads, self._out_feats)
                    feat_dst = self.fc(h_dst).view(-1, self._num_heads, self._out_feats)
                else:
                    feat_src = self.fc_src(h_src).view(-1, self._num_heads, self._out_feats)
                    feat_dst = self.fc_dst(h_dst).view(-1, self._num_heads, self._out_feats)

            # For a homogeneous graph, h_src and h_dst are identical: both are the input
            # features of all nodes.
            else:
                h_src = h_dst = self.feat_drop(feat)

                # Apply the linear transform to the input features (Wh in the formula).
                # fc maps shape (n, in_feats) to (n, out_feats * num_heads), which is then
                # viewed as (-1, self._num_heads, self._out_feats),
                # i.e. (n, self._num_heads, self._out_feats).
                feat_src = feat_dst = self.fc(h_src).view(-1, self._num_heads, self._out_feats)

                if graph.is_block:
                    feat_dst = feat_src[:graph.number_of_dst_nodes()]
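To make the shapes concrete, here is a small standalone sketch of the projection step (the dimensions below are arbitrary and not part of the DGL code):

import torch as th
import torch.nn as nn

n, in_feats, num_heads, out_feats = 5, 10, 4, 8
fc = nn.Linear(in_feats, out_feats * num_heads, bias=False)

h_src = th.randn(n, in_feats)                         # (n, in_feats)
feat_src = fc(h_src).view(-1, num_heads, out_feats)   # (n, num_heads, out_feats)
print(feat_src.shape)                                 # torch.Size([5, 4, 8])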

$$e_{ij}=\operatorname{LeakyReLU}\left(\overrightarrow{\mathbf{a}}^{T}\left[\mathbf{W}\vec{h}_{i} \,\|\, \mathbf{W}\vec{h}_{j}\right]\right) \tag{1}$$

Equation (1) is the formulation in the GAT paper, where node $i$ is the destination node and node $j$ is the source node. The authors first compute $\mathbf{W}\vec{h}_{i}$ and $\mathbf{W}\vec{h}_{j}$, concatenate them, and then multiply by $\overrightarrow{\mathbf{a}}^{T}$ (a linear projection). In DGL's implementation, $\mathbf{a}$ is split into $a_l$ and $a_r$: the linear projections are applied first and the two results are then added, which is mathematically equivalent:

$$\mathbf{a}^{T}\left[\mathbf{W}\vec{h}_{i} \,\|\, \mathbf{W}\vec{h}_{j}\right] = a_{l}\,\mathbf{W}\vec{h}_{i} + a_{r}\,\mathbf{W}\vec{h}_{j} \tag{2}$$

Note that the left-hand side is a matrix product between two vectors (an inner product), while the right-hand side is built from element-wise products of vectors. In the implementation, the element-wise products on the right therefore have to be reduced with sum(dim=-1) for the two sides to be equal (in Python, * is element-wise multiplication and @ is matrix multiplication). In code:
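A quick numeric check of this equivalence for a single head (the tensor names and sizes are ad hoc, not from the DGL code):

import torch as th

out_feats = 8
Wh_i = th.randn(out_feats)                     # W h_i (destination node)
Wh_j = th.randn(out_feats)                     # W h_j (source node)
a = th.randn(2 * out_feats)                    # attention vector a
a_l, a_r = a[:out_feats], a[out_feats:]        # split a into a_l and a_r

lhs = a @ th.cat([Wh_i, Wh_j])                 # a^T [W h_i || W h_j]
rhs = (a_l * Wh_i).sum() + (a_r * Wh_j).sum()  # element-wise product, then sum(dim=-1)
print(th.allclose(lhs, rhs))                   # True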

# feat_src and feat_dst have shape (n, self._num_heads, self._out_feats).
# attn_l and attn_r have shape (1, num_heads, out_feats), i.e. each of the n nodes in
# feat_* is dotted with the same attn_* vector via broadcasting.
# After sum(dim=-1), el and er have shape (n, num_heads).
# unsqueeze(-1) then expands them to (n, num_heads, 1) for the later computations.
el = (feat_src * self.attn_l).sum(dim=-1).unsqueeze(-1)
er = (feat_dst * self.attn_r).sum(dim=-1).unsqueeze(-1)
# Store the features on the graph, since the graph and the features are passed separately;
# for the subsequent computation, only feat_src is used regardless of the graph type.
# Also store el and er on the graph.
graph.srcdata.update({'ft': feat_src, 'el': el})
graph.dstdata.update({'er': er})

# compute edge attention, el and er are a_l Wh_i and a_r Wh_j respectively.
graph.apply_edges(fn.u_add_v('el', 'er', 'e'))

e = self.leaky_relu(graph.edata.pop('e'))

DGL's implementation is more efficient: first, it avoids storing $[\mathbf{W}\vec{h}_{i} \,\|\, \mathbf{W}\vec{h}_{j}]$ on the edges (as mentioned before, storing information on edges in DGL is very memory-intensive); second, using DGL's built-in function fn.u_add_v speeds up the computation.
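As a toy illustration of what fn.u_add_v computes (the graph and values below are made up):

import dgl
import dgl.function as fn
import torch as th

g = dgl.graph(([0, 1, 2], [1, 2, 0]))            # edges: 0->1, 1->2, 2->0
g.srcdata['el'] = th.tensor([[1.], [2.], [3.]])
g.dstdata['er'] = th.tensor([[10.], [20.], [30.]])

# For each edge (u, v), store el[u] + er[v] on the edge under the key 'e'.
g.apply_edges(fn.u_add_v('el', 'er', 'e'))
print(g.edata['e'])                              # [[21.], [32.], [13.]]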

$$\alpha_{ij}=\operatorname{softmax}_{j}\left(e_{ij}\right)=\frac{\exp\left(e_{ij}\right)}{\sum_{k \in \mathcal{N}_{i}} \exp\left(e_{ik}\right)} \tag{3}$$

$$\vec{h}_{i}^{\prime}=\sigma\left(\sum_{j \in \mathcal{N}_{i}} \alpha_{ij}\,\mathbf{W}\vec{h}_{j}\right) \tag{4}$$
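Equation (3) is computed by DGL's edge_softmax, which normalizes the scores over the incoming edges of each destination node. A toy check, assuming the dgl.ops.edge_softmax import path:

import dgl
import torch as th
from dgl.ops import edge_softmax

g = dgl.graph(([0, 1, 2], [2, 2, 0]))     # node 2 has two incoming edges
e = th.tensor([[1.0], [2.0], [3.0]])      # one unnormalized score per edge
a = edge_softmax(g, e)
print(a)   # roughly [[0.269], [0.731], [1.000]]; the two scores into node 2 sum to 1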

# compute softmax as shown in eq. 3
graph.edata['a'] = self.attn_drop(edge_softmax(graph, e))

# message passing as shown in eq. 4
graph.update_all(fn.u_mul_e('ft', 'a', 'm'),
                 fn.sum('m', 'ft'))
rst = graph.dstdata['ft']

# residual
if self.res_fc is not None:
    resval = self.res_fc(h_dst).view(h_dst.shape[0], -1, self._out_feats)
    rst = rst + resval

# activation
if self.activation:
    rst = self.activation(rst)

if get_attention:
    return rst, graph.edata['a']
else:
    return rst
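Putting it together, an end-to-end sketch that also retrieves the attention values (toy graph and sizes, not from the original code):

import dgl
import torch as th
from dgl.nn import GATConv

g = dgl.add_self_loop(dgl.graph(([0, 1, 2], [1, 2, 0])))
feat = th.randn(3, 6)

layer = GATConv(in_feats=6, out_feats=4, num_heads=2)
rst, attn = layer(g, feat, get_attention=True)
print(rst.shape)          # torch.Size([3, 2, 4]) -> (N, num_heads, out_feats)
print(attn.shape)         # torch.Size([6, 2, 1]) -> (E, num_heads, 1); E includes the self-loops
merged = rst.flatten(1)   # concatenate the heads manually -> (3, 8)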
Author: Alston
Post link: https://lizitong67.github.io/2021/03/07/DGL%E6%BA%90%E7%A0%81%E8%A7%A3%E6%9E%90-GAT/
Copyright: Unless otherwise noted, all posts on this blog are licensed under CC BY-NC-SA 4.0. Please credit Alston's blog when reposting.