SpatialEncoder
class dgl.nn.pytorch.gt.SpatialEncoder(max_dist, num_heads=1)
Do Transformers Really Perform Bad for Graph Representation中介绍的空间编码模块.
这个模块是一个可学习的空间嵌入模块,它对每个节点对之间的最短距离进行编码,以获得注意力偏差。
Parameters
- max_dist(int) 要编码的每个节点对之间的最短路径距离的上限。所有距离将被限制在范围[0,max_dist]内。
- num_heads(int, optional) 如果应用多头注意机制,则注意头的数量。预设值:1。
forward(dist)
Parameters
dist(Tensor) - 具有-1填充的批处理图的最短路径距离,形状为(B, N, N)
,其中B
是批处理图,N
是节点的最大数量。
Returns
返回注意偏置作为空间编码,形状为(B,N,N,H)
。
Return type
torch.Tensor
源代码
1 |
|
Example
实际上就是把距离值一个数字,embedding成一个长度为num_heads的向量
Example1:
1 |
|
Example2:
实际上就是将最短路径矩阵中的每一个数字都embedding成了一个向量,然后当前空间编码
1 |
|
def shortest_dist(g, root=None, return_paths=False)
dist[0, :n1, :n1] = shortest_dist(g1, root=None, return_paths=False)
这个函数shortest_dist的作用是计算给定图中的最短距离和路径。以下是这个函数的主要功能和用法介绍:
功能:
- 该函数用于计算图中节点之间的最短距离和路径。
- 它支持无权图的情况。
- 只考虑有向路径(其中所有边都朝着同一个方向)。
参数:
- g : 输入图,必须是同质图。
- root : 给定一个根节点ID,返回根节点与所有节点之间的最短距离和路径(可选)。如果为None,则返回所有节点对的结果。默认为None。
- return_paths : 如果为True,则返回与最短距离对应的最短路径。默认为False。
返回值:
- dist : 最短距离张量。
如果root是一个节点ID,则形状为(N,),其中N是节点数。dist[j]给出从root到节点j的最短距离。
否则,形状为(N, N)。dist[i][j]给出从节点i到节点j的最短距离。
无法到达的节点对的距离值填充为-1。 - paths : 最短路径张量(可选)。
仅在return_paths为True时返回。
如果root是一个节点ID,则形状为(N, L),其中L是最长路径的长度。path[j]是从root到节点j的最短路径。
否则,形状为(N, N, L)。path[i][j]是从节点i到节点j的最短路径。
每条路径是一个向量,由边ID构成,末尾填充为-1。
节点与自身之间的最短路径是一个由-1填充的向量。
SpatialEncoder
http://jiqingjiang.github.io/p/5cc2d788/