DGL源码解析-GCN

图卷积网络 (Graph Convolutional Network, GCN)原文地址:https://arxiv.org/abs/1609.02907

1 Math

论文中给出的更新函数如下:

H(l+1)=σ(D~1/2A~D~1/2H(l)W(l))(1)H^{(l+1)}=\sigma\left(\tilde{D}^{-{1/2}} \tilde{A} \tilde{D}^{-{1/2}} H^{(l)} W^{(l)}\right) \tag{1}

将其展开后,每个节点的更新函数如下:

hi(l+1)=σ(b(l)+jN(i)1cjihj(l)W(l))(2)h_i^{(l+1)} = \sigma(b^{(l)} + \sum_{j\in\mathcal{N}(i)}\frac{1}{c_{ji}}h_j^{(l)}W^{(l)}) \tag{2}

其中,N(i)\mathcal{N}(i) 是目标节点 ii 的邻居节点集合,cjic_{ji}iijj 对应节点度数的平方根的乘积,即 cji=N(j)N(i)c_{ji} = \sqrt{|\mathcal{N}(j)|}\sqrt{|\mathcal{N}(i)|}, σ\sigma 为激活函数。

注意:论文中的公式所适用的图均为无向图

若考虑图中每条边的权重,则对应的更新函数表示如下:

hi(l+1)=σ(b(l)+jN(i)ejicjihj(l)W(l))(3)h_i^{(l+1)} = \sigma(b^{(l)} + \sum_{j\in\mathcal{N}(i)}\frac{e_{ji}}{c_{ji}}h_j^{(l)}W^{(l)}) \tag{3}

其中,ejie_{ji} 是边 jiji 所对应的权重标量。

在DGL中,用户可以自定义地标准化 cjic_{ji}: 首先将模型设置为 norm=none,然后将预先标准化过的 ejie_{ji} 传递给聚合函数。

2 类定义及参数说明

class GraphConv(nn.Module):
r"""

Parameters
----------
in_feats : int
Input feature size; i.e, the number of dimensions of :math:`h_j^{(l)}`.
out_feats : int
Output feature size; i.e., the number of dimensions of :math:`h_i^{(l+1)}`.
norm : str, optional
如何使用normalizer:
(1)若为`'right'`, 则将聚合信息除以每个节点的入度(in-degrees),相当于对聚合的信息做平均化
(2)若为`'none'`,将不会使用任何normalizer。如上所述,用户可以将其置为none,然后借助e_{ji}实现
自定义的normalizer
(3)默认为`'both'`,即使用论文中定义的`c_{ji}` ,包括入度和出度进行normalizer
weight : bool, optional
If True, apply a linear layer. Otherwise, aggregating the messages without a weight matrix.
bias : bool, optional
If True, adds a learnable bias to the output. Default: ``True``.
activation : callable activation function/layer or None, optional
If not None, applies an activation function to the updated node features.
Default: ``None``.
allow_zero_in_degree : bool, optional
If there are 0-in-degree nodes in the graph, output for those nodes will be invalid
since no message will be passed to those nodes. This is harmful for some applications
causing silent performance regression. This module will raise a DGLError if it detects
0-in-degree nodes in input graph. By setting ``True``, it will suppress the check
and let the users handle it by themselves. Default: ``False``.

Attributes (parameters为传递的参数,attributes为类内部的参数)
----------
weight : torch.Tensor
The learnable weight tensor.
bias : torch.Tensor
The learnable bias tensor.

Note
----
Zero in-degree nodes will lead to invalid output value. This is because no message
will be passed to those nodes, the aggregation function will be appied on empty input.
A common practice to avoid this is to add a self-loop for each node in the graph if
it is homogeneous, which can be achieved by:

>>> g = ... # a DGLGraph
>>> g = dgl.add_self_loop(g)

Calling ``add_self_loop`` will not work for some graphs, for example, heterogeneous graph
since the edge type can not be decided for self_loop edges. Set ``allow_zero_in_degree``
to ``True`` for those cases to unblock the code and handle zere-in-degree nodes manually.
A common practise to handle this is to filter out the nodes with zere-in-degree when use
after conv.

3 init function()

def __init__(self,
in_feats,
out_feats,
norm='both',
weight=True,
bias=True,
activation=None,
allow_zero_in_degree=False):
super(GraphConv, self).__init__()
if norm not in ('none', 'both', 'right'):
raise DGLError('Invalid norm value. Must be either "none", "both" or "right".'
' But got "{}".'.format(norm))
self._in_feats = in_feats
self._out_feats = out_feats
self._norm = norm
self._allow_zero_in_degree = allow_zero_in_degree

if weight:
self.weight = nn.Parameter(th.Tensor(in_feats, out_feats))
else:
self.register_parameter('weight', None)

if bias:
self.bias = nn.Parameter(th.Tensor(out_feats))
else:
self.register_parameter('bias', None)

self.reset_parameters()

self._activation = activation

4 forward() function

def forward(self, graph, feat, weight=None, edge_weight=None):
r"""

Description
-----------
Compute graph convolution.

Parameters
----------
graph : DGLGraph
The graph.
feat : torch.Tensor or pair of torch.Tensor
If a torch.Tensor is given, it represents the input feature of shape
:math:`(N, D_{in})`
where :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.
If a pair of torch.Tensor is given, which is the case for bipartite graph, the pair
must contain two tensors of shape :math:`(N_{in}, D_{in_{src}})` and
:math:`(N_{out}, D_{in_{dst}})`.
weight : torch.Tensor, optional
Optional external weight tensor.(将init()方法中的weight置为none,)
edge_weight : torch.Tensor, optional
Optional tensor on the edge. If given, the convolution will weight
with regard to the message.

Returns
-------
torch.Tensor
The output feature

Raises
------
DGLError
Case 1:
If there are 0-in-degree nodes in the input graph, it will raise DGLError
since no message will be passed to those nodes. This will cause invalid output.
The error can be ignored by setting ``allow_zero_in_degree`` parameter to ``True``.

Case 2:
External weight is provided while at the same time the module
has defined its own weight parameter.

Note
----
* Input shape: :math:`(N, *, \text{in_feats})` where * means any number of additional
dimensions, :math:`N` is the number of nodes.
* Output shape: :math:`(N, *, \text{out_feats})` where all but the last dimension are
the same shape as the input.
* Weight shape: :math:`(\text{in_feats}, \text{out_feats})`.
"""
with graph.local_scope():
if not self._allow_zero_in_degree:
if (graph.in_degrees() == 0).any():
raise DGLError('There are 0-in-degree nodes in the graph, '
'output for those nodes will be invalid. '
'This is harmful for some applications, '
'causing silent performance regression. '
'Adding self-loop on the input graph by '
'calling `g = dgl.add_self_loop(g)` will resolve '
'the issue. Setting ``allow_zero_in_degree`` '
'to be `True` when constructing this module will '
'suppress the check and let the code run.')

# copy_src(已被copy_u取代)为内置消息函数,将源节点的特征h复制并发动到mailbox,表示为m
aggregate_fn = fn.copy_src('h', 'm')

# 边的权重
if edge_weight is not None:
assert edge_weight.shape[0] == graph.number_of_edges()
graph.edata['_edge_weight'] = edge_weight
# 若考虑edge weight,则聚合函数为源节点特征乘以边的权重,结果保存为m
aggregate_fn = fn.u_mul_e('h', '_edge_weight', 'm')

# (BarclayII) For RGCN on heterogeneous graphs we need to support GCN on bipartite.
feat_src, feat_dst = expand_as_pair(feat, graph)
if self._norm == 'both':
# clamp函数用于约束返回值到A和B之间,若value小于min,则返回min;
# 若value大于max,则返回max,起到上下截断的作用。
degs = graph.out_degrees().float().clamp(min=1)
# degs的-0.5次方
norm = th.pow(degs, -0.5)
# (1,) * (feat_src.dim() - 1):假设feature为单个向量,则feat_src.dim()为2,最后运算结果为(1,)
# 若feat_src.dim()为3,则(1,) * (feat_src.dim() - 1) = (1,1)
# norm.shape + (1,),表示在原先维度的基础上,扩展以为,其值为1
shp = norm.shape + (1,) * (feat_src.dim() - 1)
# norm reshape之后便于点乘
norm = th.reshape(norm, shp)
# 假设norm原始shape为[n],reshape后为[n,1],而feat_src shape为[n,in_feat]
feat_src = feat_src * norm

if weight is not None:
if self.weight is not None:
raise DGLError('External weight is provided while at the same time the'
' module has defined its own weight parameter. Please'
' create the module with flag weight=False.')
else:
weight = self.weight

if self._in_feats > self._out_feats:
# mult W first to reduce the feature size for aggregation.
if weight is not None:
# th.matmul:矩阵乘法
feat_src = th.matmul(feat_src, weight)
graph.srcdata['h'] = feat_src
graph.update_all(aggregate_fn, fn.sum(msg='m', out='h'))
rst = graph.dstdata['h']
else:
# aggregate first then mult W
graph.srcdata['h'] = feat_src
graph.update_all(aggregate_fn, fn.sum(msg='m', out='h'))
rst = graph.dstdata['h']
if weight is not None:
rst = th.matmul(rst, weight)

if self._norm != 'none':
degs = graph.in_degrees().float().clamp(min=1)
if self._norm == 'both':
norm = th.pow(degs, -0.5)
else:
norm = 1.0 / degs
shp = norm.shape + (1,) * (feat_dst.dim() - 1)
norm = th.reshape(norm, shp)
rst = rst * norm

if self.bias is not None:
rst = rst + self.bias

if self._activation is not None:
rst = self._activation(rst)

return rst

注意这里在计算norm的时候,先对所有feat_src计算了out_degree,然后聚合后将最终结果计算了in_degree。在论文中,是假设所有图为无向图的,而在DGL中,构建无向图的方法是节点之间添加双向边,如此一来,DGL中对不同图的GCN实现可分为以下情况:

  • 无向图:每个节点的 degree = out_degree = in_degree,DGL适用;
  • 二部图:源节点只有出度,目标节点只有入度,虽然为有向图,DGL仍适用;
  • 有向图:DGL在有向图上进行graphconv,实际上是对原始的GCN论文做了拓展,在计算norm时,仅仅计算了源节点的出度和目标节点的入度(相关理论支持有待进一步研究,如为何不同时计算所有节点的入度和出度?也许只是为了代码简洁,将所有情况用一套代码实现?)
文章作者: Alston
文章链接: https://lizitong67.github.io/2021/03/14/DGL%E6%BA%90%E7%A0%81%E8%A7%A3%E6%9E%90-GCN/
版权声明: 本博客所有文章除特别声明外,均采用 CC BY-NC-SA 4.0 许可协议。转载请注明来自 Alston's blog