LapPosEncoder

class dgl.nn.pytorch.gt.LapPosEncoder(model_type, num_layer, k, dim, n_head=1, batch_norm=False, num_post_layer=0)

Laplacian Positional Encoder (LPE), as introduced in GraphGPS: General, Powerful, Scalable Graph Transformers.
This module is a learned Laplacian positional encoding module using Transformer or DeepSet.

Parameters

  • model_type (str) – Encoder model type of the LPE, can only be "Transformer" or "DeepSet".
  • num_layer (int) – Number of layers in the Transformer/DeepSet encoder.
  • k (int) – Number of smallest non-trivial eigenvectors.
  • dim (int) – Output size of the final Laplacian encoding.
  • n_head (int, optional) – Number of heads in the Transformer encoder. Default: 1.
  • batch_norm (bool, optional) – If True, apply batch normalization to the raw Laplacian positional encodings. Default: False.
  • num_post_layer (int, optional) – If num_post_layer > 0, an MLP with num_post_layer layers is applied after pooling. Default: 0.
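
A minimal construction sketch (inspection only; it uses the arguments documented above plus the pe_encoder/post_mlp attributes visible in the source code further down):

from dgl.nn import LapPosEncoder

# DeepSet variant: 3 encoder layers over k = 5 eigen-channels, 16-dim output,
# followed by a 2-layer MLP after sum pooling.
enc = LapPosEncoder(model_type="DeepSet", num_layer=3, k=5, dim=16, num_post_layer=2)
print(enc.pe_encoder)  # shared Linear/ReLU stack applied to each of the k channels
print(enc.post_mlp)    # MLP applied to the pooled (N, dim) encoding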

forward(eigvals, eigvecs)

Parameters
- eigvals (Tensor) – Laplacian eigenvalues of shape (N, k); the k distinct eigenvalues are repeated N times (once per node). Can be obtained by using LaplacianPE.
- eigvecs (Tensor) – Laplacian eigenvectors of shape (N, k). Can be obtained by using LaplacianPE.

Returns
The Laplacian positional encodings of shape (N, d), where N is the number of nodes in the input graph and d is dim.

Return type
torch.Tensor
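
A shape-only sketch of a forward call with synthetic eigenpairs (random values rather than eigenpairs of a real graph; see the example at the end of this page for the LapPE-based workflow):

import torch as th
from dgl.nn import LapPosEncoder

N, k, dim = 8, 5, 16
eigvecs = th.randn(N, k)                # one row of k eigenvector entries per node
eigvals = th.randn(1, k).repeat(N, 1)   # the same k eigenvalues repeated for every node
encoder = LapPosEncoder(model_type="DeepSet", num_layer=2, k=k, dim=dim)
print(encoder(eigvals, eigvecs).shape)  # torch.Size([8, 16]), i.e. (N, dim)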

Source code

import torch as th
import torch.nn as nn


class LapPosEncoder(nn.Module):
    def __init__(
        self,
        model_type,
        num_layer,
        k,
        dim,
        n_head=1,
        batch_norm=False,
        num_post_layer=0,
    ):
        super(LapPosEncoder, self).__init__()
        self.model_type = model_type
        # Each node contributes k (eigenvector entry, eigenvalue) pairs,
        # so the first projection maps from an input size of 2.
        self.linear = nn.Linear(2, dim)

        if self.model_type == "Transformer":
            encoder_layer = nn.TransformerEncoderLayer(
                d_model=dim, nhead=n_head, batch_first=True
            )
            self.pe_encoder = nn.TransformerEncoder(
                encoder_layer, num_layers=num_layer
            )
        elif self.model_type == "DeepSet":
            layers = []
            if num_layer == 1:
                layers.append(nn.ReLU())
            else:
                self.linear = nn.Linear(2, 2 * dim)
                layers.append(nn.ReLU())
                for _ in range(num_layer - 2):
                    layers.append(nn.Linear(2 * dim, 2 * dim))
                    layers.append(nn.ReLU())
                layers.append(nn.Linear(2 * dim, dim))
                layers.append(nn.ReLU())
            self.pe_encoder = nn.Sequential(*layers)
        else:
            raise ValueError(
                f"model_type '{model_type}' is not allowed, must be "
                "'Transformer' or 'DeepSet'."
            )

        if batch_norm:
            # Normalize the raw (eigenvector, eigenvalue) pairs across the
            # k channels before the first projection.
            self.raw_norm = nn.BatchNorm1d(k)
        else:
            self.raw_norm = None

        if num_post_layer > 0:
            layers = []
            if num_post_layer == 1:
                layers.append(nn.Linear(dim, dim))
                layers.append(nn.ReLU())
            else:
                layers.append(nn.Linear(dim, 2 * dim))
                layers.append(nn.ReLU())
                for _ in range(num_post_layer - 2):
                    layers.append(nn.Linear(2 * dim, 2 * dim))
                    layers.append(nn.ReLU())
                layers.append(nn.Linear(2 * dim, dim))
                layers.append(nn.ReLU())
            self.post_mlp = nn.Sequential(*layers)
        else:
            self.post_mlp = None

    def forward(self, eigvals, eigvecs):
        # Stack eigenvector entries and eigenvalues into shape (N, k, 2).
        pos_encoding = th.cat(
            (eigvecs.unsqueeze(2), eigvals.unsqueeze(2)), dim=2
        ).float()
        empty_mask = th.isnan(pos_encoding)

        # Zero out padded (NaN) entries before normalization and projection.
        pos_encoding[empty_mask] = 0
        if self.raw_norm:
            pos_encoding = self.raw_norm(pos_encoding)
        pos_encoding = self.linear(pos_encoding)

        if self.model_type == "Transformer":
            # Padded positions are excluded from attention via the key padding mask.
            pos_encoding = self.pe_encoder(
                src=pos_encoding, src_key_padding_mask=empty_mask[:, :, 1]
            )
        else:
            pos_encoding = self.pe_encoder(pos_encoding)

        # Remove masked sequences.
        pos_encoding[empty_mask[:, :, 1]] = 0

        # Sum pooling.
        pos_encoding = th.sum(pos_encoding, 1, keepdim=False)

        # MLP post pooling.
        if self.post_mlp:
            pos_encoding = self.post_mlp(pos_encoding)

        return pos_encoding
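
Note the NaN handling in forward(): entries that are NaN (e.g. padding of missing eigenpairs on small graphs) are zeroed out and, for the Transformer variant, excluded from attention through src_key_padding_mask. A standalone illustration with hand-made NaN padding (the values and the padding pattern here are purely illustrative):

import torch as th

eigvals = th.tensor([[0.5, float("nan")], [0.5, float("nan")]])
eigvecs = th.tensor([[0.3, float("nan")], [-0.3, float("nan")]])
pe = th.cat((eigvecs.unsqueeze(2), eigvals.unsqueeze(2)), dim=2)  # (N=2, k=2, 2)
mask = th.isnan(pe)
pe[mask] = 0          # padded slots contribute zeros to the projection and the sum pooling
print(mask[:, :, 1])  # tensor([[False, True], [False, True]]) -- per-node key padding mask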

Example

Example 1:

import dgl
from dgl import LapPE
from dgl.nn import LapPosEncoder
transform = LapPE(k=5, feat_name='eigvec', eigval_name='eigval', padding=True)
g = dgl.graph(([0,1,2,3,4,2,3,1,4,0], [2,3,1,4,0,0,1,2,3,4]))
g = transform(g)
eigvals, eigvecs = g.ndata['eigval'], g.ndata['eigvec']
transformer_encoder = LapPosEncoder(model_type="Transformer", num_layer=3, k=5, dim=16, n_head=4)
pos_encoding = transformer_encoder(eigvals, eigvecs)
deepset_encoder = LapPosEncoder(model_type="DeepSet", num_layer=3, k=5, dim=16, num_post_layer=2)
pos_encoding = deepset_encoder(eigvals, eigvecs)
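
In both calls the encoder maps the (N, 5) eigenvalue/eigenvector pairs to an (N, 16) positional encoding; a common pattern, not shown here, is to add or concatenate this tensor to the node features before the first Transformer layer. Because both variants end with sum pooling over the k eigen-channels, the output does not depend on the order of the eigenvectors returned by LapPE.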
