SpatialEncoder

class dgl.nn.pytorch.gt.SpatialEncoder(max_dist, num_heads=1)

Do Transformers Really Perform Bad for Graph Representation中介绍的空间编码模块.
这个模块是一个可学习的空间嵌入模块,它对每个节点对之间的最短距离进行编码,以获得注意力偏差。

Parameters

  • max_dist(int) 要编码的每个节点对之间的最短路径距离的上限。所有距离将被限制在范围[0,max_dist]内。
  • num_heads(int, optional) 如果应用多头注意机制,则注意头的数量。预设值:1。

forward(dist)

Parameters
dist(Tensor) - 具有-1填充的批处理图的最短路径距离,形状为(B, N, N) ,其中B是批处理图,N是节点的最大数量。

Returns
返回注意偏置作为空间编码,形状为(B,N,N,H)

Return type
torch.Tensor

源代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
class SpatialEncoder(nn.Module):
def __init__(self, max_dist, num_heads=1):
super().__init__()
self.max_dist = max_dist
self.num_heads = num_heads
# deactivate node pair between which the distance is -1
self.embedding_table = nn.Embedding(
max_dist + 2, num_heads, padding_idx=0
)

def forward(self, dist):
spatial_encoding = self.embedding_table(
th.clamp(
dist,
min=-1,
max=self.max_dist,
)
+ 1
)
return spatial_encoding

Example

实际上就是把距离值一个数字,embedding成一个长度为num_heads的向量
Example1:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import torch
import torch.nn as nn
from dgl.nn import SpatialEncoder

# 创建一个 SpatialEncoder 实例
spatial_encoder = SpatialEncoder(max_dist=3, num_heads=2)
# 输入一些距离值
distances = torch.tensor([[0, 1, 2], [-1, 3, 4]])
# 对距离值进行空间编码处理
output = spatial_encoder(distances)
print(output)

# tensor([[[ 0.1907, -0.5513],
# [ 1.0202, 0.1709],
# [ 1.7177, 0.9194]],

# [[ 0.0000, 0.0000],
# [ 0.1681, 0.6919],
# [ 0.1681, 0.6919]]], grad_fn=<EmbeddingBackward0>)

Example2:
实际上就是将最短路径矩阵中的每一个数字都embedding成了一个向量,然后当前空间编码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import torch as th
import dgl
from dgl.nn import SpatialEncoder
from dgl import shortest_dist

g1 = dgl.graph(([0,0,0,1,1,2,3,3], [1,2,3,0,3,0,0,1]))
g2 = dgl.graph(([0,1], [1,0]))
n1, n2 = g1.num_nodes(), g2.num_nodes()
# use -1 padding since shortest_dist returns -1 for unreachable node pairs
dist = -th.ones((2, 4, 4), dtype=th.long)
dist[0, :n1, :n1] = shortest_dist(g1, root=None, return_paths=False)
dist[1, :n2, :n2] = shortest_dist(g2, root=None, return_paths=False)
spatial_encoder = SpatialEncoder(max_dist=2, num_heads=8)
out = spatial_encoder(dist)
print(dist)
# tensor([[[ 0, 1, 1, 1],
# [ 1, 0, 2, 1],
# [ 1, 2, 0, 2],
# [ 1, 1, 2, 0]],

# [[ 0, 1, -1, -1],
# [ 1, 0, -1, -1],
# [-1, -1, -1, -1],
# [-1, -1, -1, -1]]])
print(out.shape)
# torch.Size([2, 4, 4, 8])
print(out)
# tensor([[[[-0.1012, -0.0641, -1.3648, 0.5305, 1.6427, 0.9673, -1.4536,
# -1.1487],
# [-0.1723, -0.1548, -1.0919, -2.2756, -0.7477, 1.4145, 0.4393,
# -0.3580],
# [-0.1723, -0.1548, -1.0919, -2.2756, -0.7477, 1.4145, 0.4393,
# -0.3580],
# [-0.1723, -0.1548, -1.0919, -2.2756, -0.7477, 1.4145, 0.4393,
# -0.3580]],

# [[-0.1723, -0.1548, -1.0919, -2.2756, -0.7477, 1.4145, 0.4393,
# -0.3580],
# [-0.1012, -0.0641, -1.3648, 0.5305, 1.6427, 0.9673, -1.4536,
# -1.1487],
# [-1.3901, -0.3768, 0.6562, 0.4067, -0.7534, 1.0690, -1.0219,
# -0.9923],
# [-0.1723, -0.1548, -1.0919, -2.2756, -0.7477, 1.4145, 0.4393,
# -0.3580]],

# [[-0.1723, -0.1548, -1.0919, -2.2756, -0.7477, 1.4145, 0.4393,
# -0.3580],
# [-1.3901, -0.3768, 0.6562, 0.4067, -0.7534, 1.0690, -1.0219,
# -0.9923],
# [-0.1012, -0.0641, -1.3648, 0.5305, 1.6427, 0.9673, -1.4536,
# -1.1487],
# [-1.3901, -0.3768, 0.6562, 0.4067, -0.7534, 1.0690, -1.0219,
# -0.9923]],

# [[-0.1723, -0.1548, -1.0919, -2.2756, -0.7477, 1.4145, 0.4393,
# -0.3580],
# [-0.1723, -0.1548, -1.0919, -2.2756, -0.7477, 1.4145, 0.4393,
# -0.3580],
# [-1.3901, -0.3768, 0.6562, 0.4067, -0.7534, 1.0690, -1.0219,
# -0.9923],
# [-0.1012, -0.0641, -1.3648, 0.5305, 1.6427, 0.9673, -1.4536,
# -1.1487]]],


# [[[-0.1012, -0.0641, -1.3648, 0.5305, 1.6427, 0.9673, -1.4536,
# -1.1487],
# [-0.1723, -0.1548, -1.0919, -2.2756, -0.7477, 1.4145, 0.4393,
# -0.3580],
# [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
# 0.0000],
# [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
# 0.0000]],

# [[-0.1723, -0.1548, -1.0919, -2.2756, -0.7477, 1.4145, 0.4393,
# -0.3580],
# [-0.1012, -0.0641, -1.3648, 0.5305, 1.6427, 0.9673, -1.4536,
# -1.1487],
# [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
# 0.0000],
# [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
# 0.0000]],
# [[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
# 0.0000],
# [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
# 0.0000],
# [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
# 0.0000],
# [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
# 0.0000]],

# [[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
# 0.0000],
# [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
# 0.0000],
# [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
# 0.0000],
# [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
# 0.0000]]]], grad_fn=<EmbeddingBackward0>)

def shortest_dist(g, root=None, return_paths=False)

dist[0, :n1, :n1] = shortest_dist(g1, root=None, return_paths=False)

这个函数shortest_dist的作用是计算给定图中的最短距离和路径。以下是这个函数的主要功能和用法介绍:

功能:

  • 该函数用于计算图中节点之间的最短距离和路径。
  • 它支持无权图的情况。
  • 只考虑有向路径(其中所有边都朝着同一个方向)。

参数:

  • g : 输入图,必须是同质图。
  • root : 给定一个根节点ID,返回根节点与所有节点之间的最短距离和路径(可选)。如果为None,则返回所有节点对的结果。默认为None。
  • return_paths : 如果为True,则返回与最短距离对应的最短路径。默认为False。

返回值:

  • dist : 最短距离张量。
    如果root是一个节点ID,则形状为(N,),其中N是节点数。dist[j]给出从root到节点j的最短距离。
    否则,形状为(N, N)。dist[i][j]给出从节点i到节点j的最短距离。
    无法到达的节点对的距离值填充为-1。
  • paths : 最短路径张量(可选)。
    仅在return_paths为True时返回。
    如果root是一个节点ID,则形状为(N, L),其中L是最长路径的长度。path[j]是从root到节点j的最短路径。
    否则,形状为(N, N, L)。path[i][j]是从节点i到节点j的最短路径。
    每条路径是一个向量,由边ID构成,末尾填充为-1。
    节点与自身之间的最短路径是一个由-1填充的向量。

SpatialEncoder
http://jiqingjiang.github.io/p/5cc2d788/
作者
Jiqing
发布于
2024年8月2日
许可协议