DGL Source Code Analysis: GraphSAGE

GraphSAGE paper: https://arxiv.org/pdf/1706.02216.pdf

1 Math

The update equations given in the paper are as follows:

$$h_{\mathcal{N}(i)}^{(l+1)} = \mathrm{aggregate}\left(\left\{h_{j}^{l}, \forall j \in \mathcal{N}(i)\right\}\right) \tag{1}$$

$$h_{i}^{(l+1)} = \sigma\left(W \cdot \mathrm{concat}\left(h_{i}^{l}, h_{\mathcal{N}(i)}^{l+1}\right)\right) \tag{2}$$

$$h_{i}^{(l+1)} = \mathrm{norm}\left(h_{i}^{(l+1)}\right) \tag{3}$$

Note that the normalization in Eq. (3) is applied to the updated feature, matching `rst = self.norm(rst)` at the end of forward().
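To make Eqs. (1)-(3) concrete, here is a minimal dense PyTorch sketch (illustration only, not DGL code; the helper name `sage_layer` and the dense adjacency are made up), assuming the mean aggregator, ReLU as the nonlinearity, and l2 normalization as norm:

import torch

def sage_layer(A, H, W, eps=1e-12):
    """A: (N, N) adjacency with A[i, j] = 1 if j is a neighbor of i;
    H: (N, D) input features; W: (D_out, 2D) weight."""
    deg = A.sum(dim=1, keepdim=True).clamp(min=1)
    h_neigh = (A @ H) / deg                                    # Eq. (1): mean aggregation
    h_new = torch.relu(torch.cat([H, h_neigh], dim=1) @ W.T)   # Eq. (2): concat, transform, nonlinearity
    return h_new / (h_new.norm(dim=1, keepdim=True) + eps)     # Eq. (3): l2 normalization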

When edge weights are considered, the aggregation becomes:

$$h_{\mathcal{N}(i)}^{(l+1)} = \mathrm{aggregate}\left(\left\{e_{ji} h_{j}^{l}, \forall j \in \mathcal{N}(i)\right\}\right) \tag{4}$$

where $e_{ji}$ is the scalar weight on the edge from $j$ to $i$; the implementation must ensure it broadcasts with $h_{j}^{l}$.
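In plain PyTorch terms, the broadcasting requirement looks like this: a per-edge scalar weight of shape (E,) is unsqueezed to (E, 1) so that it can multiply per-edge messages of shape (E, D):

import torch

E, D = 4, 8
h_j = torch.randn(E, D)          # one message per edge
e_ji = torch.rand(E)             # one scalar weight per edge
msg = e_ji.unsqueeze(-1) * h_j   # (E, 1) broadcasts against (E, D)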

2 Class definition and parameters

class SAGEConv(nn.Module):
    r"""

    Parameters
    ----------
    in_feats : int, or pair of ints
        Input feature size; i.e., the number of dimensions of :math:`h_i^{(l)}`.
        If the aggregator type is ``gcn``, the source and destination node
        feature sizes must be equal on heterogeneous (bipartite) graphs,
        because the forward pass later computes
        ``graph.dstdata['neigh'] + graph.dstdata['h']``.
    out_feats : int
        Output feature size; i.e., the number of dimensions of :math:`h_i^{(l+1)}`.
    feat_drop : float
        Dropout rate on features, default: ``0``.
    aggregator_type : str
        The aggregation function of Eq. (1).
        Aggregator type to use (``mean``, ``gcn``, ``pool``, ``lstm``).
    bias : bool
        If True, adds a learnable bias to the output. Default: ``True``.
    norm : callable activation function/layer or None, optional
        If not None, applies normalization to the updated node features.
    activation : callable activation function/layer or None, optional
        If not None, applies an activation function to the updated node features.
        Default: ``None``.
    """

3 __init__() function

def __init__(self,
             in_feats,
             out_feats,
             aggregator_type,
             feat_drop=0.,
             bias=True,
             norm=None,
             activation=None):
    super(SAGEConv, self).__init__()

    # expand_as_pair returns (in_feats, in_feats) for a single int,
    # or the pair itself if a pair was given
    self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
    self._out_feats = out_feats
    self._aggre_type = aggregator_type
    self.norm = norm
    self.feat_drop = nn.Dropout(feat_drop)
    self.activation = activation
    # aggregator type: mean/pool/lstm/gcn
    # Initialize the parameters each aggregator needs; the aggregators
    # themselves are explained in forward()
    if aggregator_type == 'pool':
        self.fc_pool = nn.Linear(self._in_src_feats, self._in_src_feats)
    if aggregator_type == 'lstm':
        self.lstm = nn.LSTM(self._in_src_feats, self._in_src_feats, batch_first=True)
    if aggregator_type != 'gcn':
        self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=bias)
    # The weights that multiply the concatenation in Eq. (2): fc_neigh
    # transforms the neighbor half, fc_self (above) the self half
    self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=bias)
    self.reset_parameters()
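For orientation before diving into forward(), a minimal usage sketch of the layer (the toy graph and feature sizes are made up; assumes a DGL version matching the source above):

import dgl
import torch
from dgl.nn import SAGEConv

# A toy directed graph with 4 nodes and 4 edges
g = dgl.graph((torch.tensor([0, 1, 2, 3]), torch.tensor([1, 2, 3, 0])))
feat = torch.randn(4, 10)       # (N, D_in)
conv = SAGEConv(10, 5, 'mean')  # in_feats, out_feats, aggregator_type
out = conv(g, feat)             # (4, 5)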

4 forward() function

def forward(self, graph, feat, edge_weight=None):
    r"""

    Parameters
    ----------
    graph : DGLGraph
        The graph.
    feat : torch.Tensor or pair of torch.Tensor
        If a torch.Tensor is given, it represents the input feature of shape
        :math:`(N, D_{in})`,
        where :math:`D_{in}` is the size of the input feature and :math:`N`
        is the number of nodes.
        If a pair of torch.Tensor is given, the pair must contain two tensors of shape
        :math:`(N_{in}, D_{in_{src}})` and :math:`(N_{out}, D_{in_{dst}})`.
    edge_weight : torch.Tensor, optional
        Optional tensor on the edges. If given, the messages are weighted
        by it during aggregation, as in Eq. (4).

    Returns
    -------
    torch.Tensor
        The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`
        is the size of the output feature.
    """
    with graph.local_scope():
        if isinstance(feat, tuple):
            feat_src = self.feat_drop(feat[0])
            feat_dst = self.feat_drop(feat[1])
        else:
            feat_src = feat_dst = self.feat_drop(feat)
            # For a block (the bipartite graph produced by minibatch
            # sampling), the destination nodes are a prefix of the
            # source nodes
            if graph.is_block:
                feat_dst = feat_src[:graph.number_of_dst_nodes()]
        # Without edge weights
        aggregate_fn = fn.copy_src('h', 'm')
        # With edge weights
        if edge_weight is not None:
            assert edge_weight.shape[0] == graph.number_of_edges()
            graph.edata['_edge_weight'] = edge_weight
            aggregate_fn = fn.u_mul_e('h', '_edge_weight', 'm')

        h_self = feat_dst

        # Handle the case of graphs without edges: there is nothing to
        # propagate, so set the destination nodes' 'neigh' feature to zero
        if graph.number_of_edges() == 0:
            # .to(feat_dst): cast to feat_dst's dtype and device
            graph.dstdata['neigh'] = torch.zeros(
                feat_dst.shape[0], self._in_src_feats).to(feat_dst)
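A side note on the `.to(feat_dst)` idiom, which is plain PyTorch: passing a tensor (rather than a dtype or device) to `Tensor.to` copies that tensor's dtype and device:

import torch

x = torch.randn(2, 3, dtype=torch.float64)
z = torch.zeros(4, 5).to(x)  # takes x's dtype and device
assert z.dtype == torch.float64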
  • MEAN aggregator: as in Eq. (5), take the element-wise mean of the neighbor features, i.e., for each feature dimension of a node, average that dimension's values over the neighbors (an equivalence check with an explicit loop follows Eq. (5)):
if self._aggre_type == 'mean':
    graph.srcdata['h'] = feat_src
    # update_all = aggregate_fn + reduce_fn; together they implement
    # the aggregator in Eq. (1)
    graph.update_all(aggregate_fn, fn.mean('m', 'neigh'))
    h_neigh = graph.dstdata['neigh']

rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)

$$\begin{aligned} h_{\mathcal{N}(v)}^{k} &\leftarrow \operatorname{MEAN}\left(\left\{\mathbf{h}_{u}^{k-1}, \forall u \in \mathcal{N}(v)\right\}\right) \\ \mathbf{h}_{v}^{k} &\leftarrow \sigma\left(\mathbf{W}^{k} \cdot \operatorname{CONCAT}\left(\mathbf{h}_{v}^{k-1}, h_{\mathcal{N}(v)}^{k}\right)\right) \end{aligned} \tag{5}$$
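As a sanity check of the aggregation step, the following self-contained sketch (toy graph made up; `fn.copy_u` is the newer name of the `fn.copy_src` used above) verifies that `update_all(copy_u, mean)` matches an explicit loop over each node's in-neighbors:

import dgl
import dgl.function as fn
import torch

g = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([2, 2, 0])))
g.ndata['h'] = torch.randn(3, 4)
g.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'neigh'))

neigh = torch.zeros(3, 4)  # nodes with no in-edges keep zeros, as update_all does
for v in range(3):
    src, _ = g.in_edges(v)
    if len(src) > 0:
        neigh[v] = g.ndata['h'][src].mean(dim=0)
assert torch.allclose(g.ndata['neigh'], neigh)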

  • GCN aggregator: because the model in the GCN paper is transductive, GraphSAGE derives an inductive form of GCN, shown in Eq. (6). The paper remarks: "We call this modified mean-based aggregator convolutional since it is a rough, linear approximation of a localized spectral convolution". Its mean divides by the node's in-degree, which is one difference from the MEAN aggregator. The other difference is that gcn sums the current node's feature together with its neighbors' features, takes the average, and then applies a linear transform, whereas mean first concatenates the current node's feature with the aggregated neighbor feature and then applies the linear transform. In the implementation, mean realizes concat-then-transform as transform-then-add, using two linear layers (fc_self and fc_neigh); hence **gcn passes through a single fully connected layer, while mean uses two (fc_self and fc_neigh)**. A numeric check follows Eq. (6):
elif self._aggre_type == 'gcn':
    # If feat is a pair, check that the source and destination node
    # features have the same dimensionality
    check_eq_shape(feat)
    graph.srcdata['h'] = feat_src
    graph.dstdata['h'] = feat_dst  # for a homogeneous graph, identical to the line above
    graph.update_all(aggregate_fn, fn.sum('m', 'neigh'))
    degs = graph.in_degrees().to(feat_dst)

    # Normalization: divide by in_degrees + 1 instead of the paper's c_ji;
    # the +1 counts the node itself in the average and also avoids division by zero
    h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.unsqueeze(-1) + 1)

rst = self.fc_neigh(h_neigh)

$$\mathbf{h}_{v}^{k} \leftarrow \sigma\left(\mathbf{W} \cdot \operatorname{MEAN}\left(\left\{\mathbf{h}_{v}^{k-1}\right\} \cup \left\{\mathbf{h}_{u}^{k-1}, \forall u \in \mathcal{N}(v)\right\}\right)\right) \tag{6}$$
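A quick numeric check (sizes made up) that the code's `(neighbor sum + self) / (deg + 1)` is exactly the MEAN over $\{\mathbf{h}_{v}^{k-1}\} \cup \{\mathbf{h}_{u}^{k-1}\}$ in Eq. (6):

import torch

h_self = torch.randn(4)
h_neigh = torch.randn(3, 4)  # features of three in-neighbors
deg = h_neigh.shape[0]

lhs = (h_neigh.sum(dim=0) + h_self) / (deg + 1)
rhs = torch.cat([h_self.unsqueeze(0), h_neigh]).mean(dim=0)
assert torch.allclose(lhs, rhs)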

  • Pool aggregator: the pooling aggregator is shown in Eq. (7). Every neighbor's vector goes through a fully connected layer, and then an element-wise max pooling is taken (a plain-PyTorch version follows Eq. (7)):
elif self._aggre_type == 'pool':
    graph.srcdata['h'] = F.relu(self.fc_pool(feat_src))
    graph.update_all(aggregate_fn, fn.max('m', 'neigh'))
    h_neigh = graph.dstdata['neigh']

$$\operatorname{AGGREGATE}_{k}^{\text{pool}} = \max\left(\left\{\sigma\left(\mathbf{W}_{\text{pool}} \mathbf{h}_{u_{i}}^{k} + \mathbf{b}\right), \forall u_{i} \in \mathcal{N}(v)\right\}\right) \tag{7}$$
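The same computation for a single node, spelled out in plain PyTorch (sizes made up; `fc_pool` stands in for the layer created in __init__, and `torch.relu` plays the role of the inner nonlinearity):

import torch
import torch.nn as nn

fc_pool = nn.Linear(8, 8)
h_u = torch.randn(5, 8)                              # five neighbor vectors
pooled = torch.relu(fc_pool(h_u)).max(dim=0).values  # element-wise max, Eq. (7)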

  • LSTM aggregator: more expressive than the mean aggregator, but an LSTM is not symmetric, i.e., it is sensitive to the order of its inputs. The authors adapt the LSTM to unordered neighbor sets by feeding the neighbors in random permutations (DGL does not shuffle; it feeds them in edge-id order). A standalone shape sketch follows the reducer below:
# LSTM aggregator implementation
def _lstm_reducer(self, nodes):
    """LSTM reducer
    NOTE(zihao): lstm reducer with default schedule (degree bucketing)
    is slow, we could accelerate this with degree padding in the future.
    """
    m = nodes.mailbox['m']  # (B, L, D)
    batch_size = m.shape[0]
    # m.new_zeros: zero-filled tensors with the same dtype and device as m,
    # shaped as the LSTM's initial states (h_0, c_0)
    h = (m.new_zeros((1, batch_size, self._in_src_feats)),
         m.new_zeros((1, batch_size, self._in_src_feats)))
    _, (rst, _) = self.lstm(m, h)
    return {'neigh': rst.squeeze(0)}
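To see what the reducer receives: with the default degree-bucketing schedule, `nodes.mailbox['m']` stacks the messages of all equal-degree nodes into a `(B, L, D)` tensor, and the LSTM's final hidden state serves as the aggregated neighbor feature. A standalone shape sketch (sizes made up):

import torch
import torch.nn as nn

lstm = nn.LSTM(16, 16, batch_first=True)
m = torch.randn(8, 5, 16)  # a bucket of 8 nodes, each with 5 neighbor messages
h0 = (m.new_zeros(1, 8, 16), m.new_zeros(1, 8, 16))
_, (h_n, _) = lstm(m, h0)
print(h_n.squeeze(0).shape)  # torch.Size([8, 16]): one vector per node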

elif self._aggre_type == 'lstm':
    graph.srcdata['h'] = feat_src
    graph.update_all(aggregate_fn, self._lstm_reducer)
    h_neigh = graph.dstdata['neigh']
else:
    raise KeyError('Aggregator type {} not recognized.'.format(self._aggre_type))

# GraphSAGE GCN does not require fc_self:
# the gcn aggregator does not use h_self
if self._aggre_type == 'gcn':
    rst = self.fc_neigh(h_neigh)
else:
    # concat-then-transform is equivalent to transform-then-add
    rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)
# activation
if self.activation is not None:
    rst = self.activation(rst)
# normalization
if self.norm is not None:
    rst = self.norm(rst)
return rst
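Finally, the equivalence behind the comment in the else branch, checked numerically: multiplying the concatenation by one matrix $W = [W_{\mathrm{self}} \mid W_{\mathrm{neigh}}]$ equals transforming each half separately and adding (sizes made up):

import torch

W = torch.randn(5, 20)  # W = [W_self | W_neigh], each half (5, 10)
a, b = torch.randn(10), torch.randn(10)
lhs = W @ torch.cat([a, b])
rhs = W[:, :10] @ a + W[:, 10:] @ b
assert torch.allclose(lhs, rhs, atol=1e-6)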

Reference: https://zhuanlan.zhihu.com/p/142205899

Author: Alston
Post link: https://lizitong67.github.io/2021/03/15/DGL%E6%BA%90%E7%A0%81%E8%A7%A3%E6%9E%90-GraphSAGE/
Copyright: Unless otherwise noted, all posts on this blog are licensed under CC BY-NC-SA 4.0. Please credit Alston's blog when reposting.