Attention

architecture

Attention Mechanism

理论

心理学

  • 动物要有效关注值得注意的点
  • 心理学框架:人类根据随意线索(有意识)和不随意线索选择注意点

注意力机制

  • 卷积、全连接、池化层都只考虑不随意线索

  • 注意力机制则显示的考虑随意线索

    • 随意线索被称之为查询query(比如你想要喝咖啡)

    • 每个输入是一个值value和不随意线索key的对

    • 通过注意力池化层来有偏向的选择某些输入

      https://cdn.jsdelivr.net/gh/YikunHan42/Image-Host/202204071455369.png

非参注意力池化层

  • 给定数据(x, y), i = 1, …, n (x, y)为(key,value) pair

  • 平均池化是最简单的方案: $$ f(x))=\frac{1}{n} \sum\limits_{i}y_i $$ 不管查询query,直接把所有y取均值

  • 更好的方案是60年代提出来的Nadaraya-Watson核回归f(x)即为query $$ f(x)=\sum_{i=1}^{n}\frac{K(x-x_i)}{\sum\limits_{j=1}^{n}K(x-x_j)}y_i $$

    跟所有候选key减一下,乘以K,是一个kernel核(函数),衡量xx之间的距离,找跟新数据相近的数据

Nadaraya-Watson核回归

使用高斯核 $$ K(u)=\frac{1}{\sqrt{2 \pi}}e^{-\frac{u^2}{2}} $$ 那么 $$ f(x)=\sum\limits_{i=1}^nsoftmax(-\frac{1}{2}(x-x_i)^2)y_i $$

参数化的注意力机制

在之前基础上引入可以学习w(标量,拓展到多维也可以变成向量) $$ f(x)=\sum\limits_{i=1}^nsoftmax(-\frac{1}{2}((x-x_i)w)^2)y_i $$

总结

注意力机制中,通过query和key来有偏向性的选择输入,可以一般的写作 $$ f(x)=\sum\limits_{i} \alpha(x,x_i)y_i $$

这里的α(x, xi)是注意力权重

实现

注意力汇聚:核回归

核回归实现

n_train = 50  # 训练样本数
x_train, _ = torch.sort(torch.rand(n_train) * 5)   # 排序后的训练样本

def f(x):
    return 2 * torch.sin(x) + x**0.8

y_train = f(x_train) + torch.normal(0.0, 0.5, (n_train,))  # 训练样本的输出
x_test = torch.arange(0, 5, 0.1)  # 测试样本
y_truth = f(x_test)  # 测试样本的真实输出
n_test = len(x_test)  # 测试样本数n_test

def plot_kernel_reg(y_hat):    
    d2l.plot(x_test, [y_truth, y_hat], 'x', 'y', legend=['Truth', 'Pred'],             xlim=[0, 5], ylim=[-1, 5])    
    d2l.plt.plot(x_train, y_train, 'o', alpha=0.5);
    
y_hat = torch.repeat_interleave(y_train.mean(), n_test)
plot_kernel_reg(y_hat) #直接做均值

https://cdn.jsdelivr.net/gh/YikunHan42/Image-Host/202204071523286.png

非参数注意力汇聚(Pooling)

# X_repeat的形状:(n_test,n_train),
# 每一行都包含着相同的测试输入(例如:同样的查询)
# interleave即把(1,2,3,4)变成((1,1),(2,2),(3,3),(4,4)) 由n_train控制
X_repeat = x_test.repeat_interleave(n_train).reshape((-1, n_train))
# x_train包含着键。attention_weights的形状:(n_test,n_train),
# 每一行都包含着要在给定的每个查询的值(y_train)之间分配的注意力权重
attention_weights = nn.functional.softmax(-(X_repeat - x_train)**2 / 2, dim=1)# y_hat的每个元素都是值的加权平均值,其中的权重是注意力权重
y_hat = torch.matmul(attention_weights, y_train)
plot_kernel_reg(y_hat)

https://cdn.jsdelivr.net/gh/YikunHan42/Image-Host/202204071526011.png

问题:太过平滑,原来曲线变化其实比较大,需要较多数据 好处:不需要学,有理论支撑,数据充足时可以充分还原原始模型

斜率为-1,印证了越近权重越大

https://cdn.jsdelivr.net/gh/YikunHan42/Image-Host/202204071533789.png

带参数注意力汇聚

X = torch.ones((2, 1, 4))
Y = torch.ones((2, 4, 6))
torch.bmm(X, Y).shape # batch的矩阵乘法

对两个batch分别做1 * 4和4 * 6的矩阵乘法,batch数目不发生变化

class NWKernelRegression(nn.Module):    
    def __init__(self, **kwargs):
        super().__init__(**kwargs)        
        self.w = nn.Parameter(torch.rand((1,), requires_grad=True))
        
    def forward(self, queries, keys, values):
        # queries和attention_weights的形状为(查询个数,“键-值”对个数)        
        queries = queries.repeat_interleave(keys.shape[1]).reshape((-1, keys.shape[1]))        
        self.attention_weights = nn.functional.softmax(-((queries - keys) * self.w)**2 / 2, dim=1)        
        # values的形状为(查询个数,“键-值”对个数)        
        return torch.bmm(self.attention_weights.unsqueeze(1),                         					 values.unsqueeze(-1)).reshape(-1)

在刚刚的基础上加入w,以控制高斯核的窗口大小,来决定曲线平滑性

训练过程跳过…

# keys的形状:(n_test,n_train),每一行包含着相同的训练输入(例如,相同的键)
keys = x_train.repeat((n_test, 1))
# value的形状:(n_test,n_train)
values = y_train.repeat((n_test, 1))
y_hat = net(x_test, keys, values).unsqueeze(1).detach()
plot_kernel_reg(y_hat)

https://cdn.jsdelivr.net/gh/YikunHan42/Image-Host/202204071543610.png

没有那么平滑,但效果会更好

d2l.show_heatmaps(net.attention_weights.unsqueeze(0).unsqueeze(0),                  			 xlabel='Sorted training inputs',                        					  ylabel='Sorted testing inputs')

https://cdn.jsdelivr.net/gh/YikunHan42/Image-Host/202204071545901.png

w

使得窗口变窄,因而测试时对应训练样本的权重更加集中

Attention Score

理论

注意力分数

回顾 $$ f(x)=\sum\limits_{i}\alpha(x,x_i)y_i=\sum\limits_{i=1}^nsoftmax(-\frac{1}{2}(x-x_i)^2)y_i $$

α即注意力权重,如果用高斯核回归的话即后面的式子 $$ (-\frac{1}{2}(x-x_i)^2) $$

即注意力分数?见后文

https://cdn.jsdelivr.net/gh/YikunHan42/Image-Host/202204072138540.png

此处Query是一个向量,而非值

Yikun Han
Yikun Han
First Year Master Student

Wir müssen wissen. Wir werden wissen.