BiasedMHA
class dgl.nn.pytorch.gt.BiasedMHA(feat_size, num_heads, bias=True, attn_bias_type='add', attn_drop=0.1)
具有图形注意偏差的密集多头注意模型。
Do Transformers Really Perform Bad for Graph Representation中介绍的使用从图结构中获得的注意力偏差计算节点之间的注意力。
$$\text{Attn}=\text{softmax}(\dfrac{QK^T}{\sqrt{d}} \circ b)$$Q
和K
是节点的特征表示。d
是对应的feat_size。b
是注意力偏差,$\circ$ 运算符可以是加,也可以是乘。
,可以是加法或乘法。
Parameters
- feat_size (int) - 特征尺寸。
- num_heads (int) – 注意力头的数量,feat_size可被其整除。
- bias (bool, optional) – 如果为True,则使用bias进行线性投影。默认值:True。
- attn_bias_type (str, optional) – 注意力偏差的类型,用于修正注意力。从“add”或“mul”中选择。默认值:’add’。
- attn_drop (float, optional) – 注意力权重的丢弃概率。Defalt:0.1。
forward(ndata, attn_bias=None, attn_mask=None)
Parameters
- ndata (torch.Tensor) – 3D输入张量。Shape:(batch_size,N,feat_size),其中N是节点的最大数量。
- attn_bias (torch.Tensor, optional) – 用于注意力修改的注意力偏差。形状:(batch_size,N,N,num_heads)。
-attn_mask (torch.Tensor, optional) – 用于避免计算无效位置的注意掩码,其中无效位置由True值指示。形状:(batch_size,N,N)。注意:对于与不存在的节点对应的行,请确保至少有一个条目设置为False,以防止使用softmax获取NaN。
Returns
y – 输出张量。形状:(batch_size,N,feat_size)
Return type
torch.Tensor
源代码
1 |
|
Example
Example1:
1 |
|
BiasedMHA
http://jiqingjiang.github.io/p/dd118e70/