PathEncoder

class dgl.nn.pytorch.gt.PathEncoder(max_len, feat_dim, num_heads=1)

Do Transformers Really Perform Bad for Graph Representation中介绍的路径编码器(Path Encoder)
该模块是一个可学习的路径嵌入模块,并将每对节点之间的最短路径编码为注意力偏置。

Parameters

  • max_len(int) 每个路径中要编码的最大边数。每条路径的超出部分将被截断,即截断序列号不小于max_len的边。
  • feat_dim(int) - 输入图中边特征的维数。
  • num_heads(int, optional) 如果应用多头注意机制,则注意头的数量。预设值:1。

forward(dist,path_data)

Parameters
- dist(Tensor) - 具有零填充的批处理图的最短路径距离矩阵,形状为(B, N, N) ,其中B是批处理图,N是节点的最大数量。
- path_data(Tensor) - 沿最短路径(零填充)的边特征,形状为(B,N,N,L,d),其中L是最短路径的最大值,d是边特征的维数。
Returns
返回注意偏置作为路径编码,形状为(B,N,N,H) ,其中B是输入图的批次,N是节点的最大数量,以及H是num_heads。

Return type
torch.Tensor

源代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
class PathEncoder(nn.Module):
def __init__(self, max_len, feat_dim, num_heads=1):
super().__init__()
self.max_len = max_len
self.feat_dim = feat_dim
self.num_heads = num_heads
self.embedding_table = nn.Embedding(max_len * num_heads, feat_dim)

def forward(self, dist, path_data):
shortest_distance = th.clamp(dist, min=1, max=self.max_len)
edge_embedding = self.embedding_table.weight.reshape(
self.max_len, self.num_heads, -1
)
path_encoding = th.div(
th.einsum("bxyld,lhd->bxyh", path_data, edge_embedding).permute(
3, 0, 1, 2
),
shortest_distance,
).permute(1, 2, 3, 0)
return path_encoding

Example

Example1:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import torch as th
import dgl
from dgl.nn import PathEncoder
from dgl import shortest_dist

g = dgl.graph(([0,0,0,1,1,2,3,3], [1,2,3,0,3,0,0,1]))
edata = th.rand(8, 16)
# Since shortest_dist returns -1 for unreachable node pairs,
# edata[-1] should be filled with zero padding.
edata = th.cat((edata, th.zeros(1, 16)), dim=0)
dist, path = shortest_dist(g, root=None, return_paths=True)
path_data = edata[path[:, :, :2]]
path_encoder = PathEncoder(2, 16, num_heads=8)
out = path_encoder(dist.unsqueeze(0), path_data.unsqueeze(0))
print(out.shape)

shortest_dist

dgl.shortest_dist(g, root=None, return_paths=False, dist=None, method='dijkstra')
计算图中每对节点之间的最短路径距离。如果rootNone,则返回所有节点对之间的最短路径距离。如果root不为None,则返回从root到所有其他节点的最短路径距离。

Parameters

  • g(DGLGraph) - 输入图。
  • root(int, optional) - 根节点。如果为None,则返回所有节点对之间的最短路径距离。预设值:None
  • return_paths(bool, optional) - 如果为True,则返回每对节点之间的最短路径。预设值:False

Returns

  • dist(Tensor) - 如果root是节点ID,返回最短距离,形状是(N,)。如果root是None,则返回最短路径距离矩阵,形状为(N, N),其中N是节点的最大数量。
  • path(Tensor) - 如果root是节点ID,返回最短路径,形状是(N,L)。如果root是None,返回最短路径矩阵,形状为(N, N, L),其中L是最短路径的最大值。

Example

可以看dgl官网给出的这个例子,这里有四个顶点0,1,2,3,然后有四条边,eid:0 ->(0,2), eid:1 ->(1,0), eid:2 ->(1,3), eid:3 ->(2,3)
返回的paths矩阵的内容就对应着这里的eid

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
g = dgl.graph(([0, 1, 1, 2], [2, 0, 3, 3]))
dgl.shortest_dist(g, root=0)
dist, paths = dgl.shortest_dist(g, root=None, return_paths=True)
print(dist)
# tensor([[ 0, -1, 1, 2],
# [ 1, 0, 2, 1],
# [-1, -1, 0, 1],
# [-1, -1, -1, 0]])
print(paths)
# tensor([[[-1, -1],
# [-1, -1],
# [ 0, -1],
# [ 0, 3]],

# [[ 1, -1],
# [-1, -1],
# [ 1, 0],
# [ 2, -1]],

# [[-1, -1],
# [-1, -1],
# [-1, -1],
# [ 3, -1]],

# [[-1, -1],
# [-1, -1],
# [-1, -1],
# [-1, -1]]])

PathEncoder
http://jiqingjiang.github.io/p/f3181863/
作者
Jiqing
发布于
2024年8月2日
许可协议