BiasedMHA

class dgl.nn.pytorch.gt.BiasedMHA(feat_size, num_heads, bias=True, attn_bias_type='add', attn_drop=0.1)

A dense multi-head attention module with graph attention bias.
Introduced in Do Transformers Really Perform Bad for Graph Representation, it computes attention between nodes using an attention bias derived from the graph structure.
$$\text{Attn}=\text{softmax}\left(\dfrac{QK^T}{\sqrt{d}} \circ b\right)$$
Q and K are the feature representations of the nodes, d is the corresponding feat_size, and b is the attention bias; the $\circ$ operator can be either addition or multiplication.
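
As a quick illustration, the formula can be reproduced with plain tensor operations for a single head. This is only a minimal sketch with made-up sizes; the module documented below handles the batched, multi-head case.

import torch as th
import torch.nn.functional as F

# Toy single-head version of softmax(Q K^T / sqrt(d) ∘ b) with the 'add' variant.
N, d = 4, 16                    # 4 nodes, feature size 16 (illustrative values)
Q = th.rand(N, d)               # query features of the nodes
K = th.rand(N, d)               # key features of the nodes
b = th.rand(N, N)               # attention bias derived from the graph structure

scores = Q @ K.T / d ** 0.5     # scaled dot-product scores, shape (N, N)
attn = F.softmax(scores + b, dim=-1)   # 'add' bias type; 'mul' would use scores * b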

Parameters

  • feat_size (int) – Feature size.
  • num_heads (int) – Number of attention heads, by which feat_size must be divisible.
  • bias (bool, optional) – If True, the linear projection layers use a bias. Default: True.
  • attn_bias_type (str, optional) – Type of attention bias used to modify the attention, chosen from 'add' or 'mul' (see the instantiation sketch after this list). Default: 'add'.
  • attn_drop (float, optional) – Dropout probability on attention weights. Default: 0.1.
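
For example, a module whose bias is applied multiplicatively could be constructed as follows (a minimal sketch; the argument values are illustrative, not recommendations):

from dgl.nn import BiasedMHA

# feat_size (64) must be divisible by num_heads (4), giving a head_dim of 16.
net = BiasedMHA(feat_size=64, num_heads=4, attn_bias_type='mul', attn_drop=0.2)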

forward(ndata, attn_bias=None, attn_mask=None)

Parameters
- ndata (torch.Tensor) – A 3D input tensor. Shape: (batch_size, N, feat_size), where N is the maximum number of nodes.
- attn_bias (torch.Tensor, optional) – The attention bias used to modify the attention. Shape: (batch_size, N, N, num_heads).
- attn_mask (torch.Tensor, optional) – The attention mask used to avoid computing invalid positions, where invalid positions are indicated by True values. Shape: (batch_size, N, N). Note: for rows corresponding to non-existent nodes, make sure at least one entry is set to False, to prevent the softmax from producing NaN (see the mask construction sketch after the Returns section).
Returns
y – The output tensor. Shape: (batch_size, N, feat_size)

Return type
torch.Tensor
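
Since a mask row whose entries are all True would make the softmax produce NaN, one way to build a valid padding mask is sketched below. This is only a sketch under the assumption that the batch is padded to N nodes and that num_nodes (a made-up name) holds the real node count of each graph:

import torch as th

batch_size, N = 2, 5
num_nodes = th.tensor([5, 3])   # real node count of each graph (illustrative)

# True marks invalid (padding) key positions that should not be attended to.
pad_cols = th.arange(N).unsqueeze(0) >= num_nodes.unsqueeze(1)   # (batch_size, N)
attn_mask = pad_cols.unsqueeze(1).expand(batch_size, N, N)       # (batch_size, N, N)

# Every row keeps at least one False entry (the real-node columns), so the
# softmax stays finite even for rows that belong to non-existent nodes.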

Source code

import torch as th
import torch.nn as nn
import torch.nn.functional as F


class BiasedMHA(nn.Module):
    def __init__(
        self,
        feat_size,
        num_heads,
        bias=True,
        attn_bias_type="add",
        attn_drop=0.1,
    ):
        super().__init__()
        self.feat_size = feat_size
        self.num_heads = num_heads
        self.head_dim = feat_size // num_heads
        assert (
            self.head_dim * num_heads == feat_size
        ), "feat_size must be divisible by num_heads"
        self.scaling = self.head_dim**-0.5
        self.attn_bias_type = attn_bias_type

        self.q_proj = nn.Linear(feat_size, feat_size, bias=bias)
        self.k_proj = nn.Linear(feat_size, feat_size, bias=bias)
        self.v_proj = nn.Linear(feat_size, feat_size, bias=bias)
        self.out_proj = nn.Linear(feat_size, feat_size, bias=bias)

        self.dropout = nn.Dropout(p=attn_drop)

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.q_proj.weight, gain=2**-0.5)
        nn.init.xavier_uniform_(self.k_proj.weight, gain=2**-0.5)
        nn.init.xavier_uniform_(self.v_proj.weight, gain=2**-0.5)

        nn.init.xavier_uniform_(self.out_proj.weight)
        if self.out_proj.bias is not None:
            nn.init.constant_(self.out_proj.bias, 0.0)

    def forward(self, ndata, attn_bias=None, attn_mask=None):
        q_h = self.q_proj(ndata).transpose(0, 1)
        k_h = self.k_proj(ndata).transpose(0, 1)
        v_h = self.v_proj(ndata).transpose(0, 1)
        bsz, N, _ = ndata.shape
        # Split the feature dimension into heads and fold the heads into the
        # batch dimension for batched matrix multiplication.
        q_h = (
            q_h.reshape(N, bsz * self.num_heads, self.head_dim).transpose(0, 1)
            * self.scaling
        )
        k_h = k_h.reshape(N, bsz * self.num_heads, self.head_dim).permute(
            1, 2, 0
        )
        v_h = v_h.reshape(N, bsz * self.num_heads, self.head_dim).transpose(
            0, 1
        )

        # Scaled dot-product scores, reshaped to (bsz, N, N, num_heads) so that
        # the per-head graph attention bias can be applied.
        attn_weights = (
            th.bmm(q_h, k_h)
            .transpose(0, 2)
            .reshape(N, N, bsz, self.num_heads)
            .transpose(0, 2)
        )

        if attn_bias is not None:
            if self.attn_bias_type == "add":
                attn_weights += attn_bias
            else:
                attn_weights *= attn_bias
        if attn_mask is not None:
            # Positions marked True in the mask are excluded from the softmax.
            attn_weights[attn_mask.to(th.bool)] = float("-inf")
        attn_weights = F.softmax(
            attn_weights.transpose(0, 2)
            .reshape(N, N, bsz * self.num_heads)
            .transpose(0, 2),
            dim=2,
        )

        attn_weights = self.dropout(attn_weights)

        attn = th.bmm(attn_weights, v_h).transpose(0, 1)

        attn = self.out_proj(
            attn.reshape(N, bsz, self.feat_size).transpose(0, 1)
        )

        return attn

Example

Example 1:

import torch as th
from dgl.nn import BiasedMHA
ndata = th.rand(16, 100, 512)
bias = th.rand(16, 100, 100, 8)
net = BiasedMHA(feat_size=512, num_heads=8)
out = net(ndata, bias)
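
The mask described in forward() can be passed along with the bias. Continuing the example above (the mask here simply marks the last 10 positions of every graph as padding, purely for illustration):

mask = th.zeros(16, 100, 100, dtype=th.bool)
mask[:, :, 90:] = True            # treat the last 10 nodes as padding keys
out = net(ndata, attn_bias=bias, attn_mask=mask)
print(out.shape)                  # torch.Size([16, 100, 512])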
