nn.Linear和实际初始化的Weight尺寸是相反的

最近在手撕Lora的代码的时候,发现一个有趣的现象。手撕代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
class LoRAModule(nn.Moudle):
def __init__(self, original_layer, lora_rank=4):
super().__init__()
self.original_layer = original_layer
self.lora_down = nn.Linear(original_layer.in_features, lora_rank, bias=False)
self.lora_up = nn.Linear(lora_rank, original_layer.out_features, bias=False)

# Initialize weights
nn.init.normal_(self.lora_down.weight, std=0.02)
nn.init.zeros_(self.lora_up.weight)

def forward(self, x):
return self.original_layer + self.lora_up(self.lora_down(x))

其中lora_down的初始化是nn.Linear(original_layer.in_features, lora_rank),其中original_layer.in_features即我们添加lora的层的输入的dim。而lora_up的初始化是nn.Linear(lora_rank, original_layer.out_features, bias=False),这样看好像是没有问题的, 因为lora计算的公式如下:

但是实际上我很快发现一个问题,其中$B$是lora_up, $A$是lora_down,那么$BA$的实际shape不是是$[\text{rank}, \text{out}]@[\text{input}, \text{rank}]$,根本乘不了呀!!但是代码实际上是没有问题的,原因如标题所示:nn.Linear(in_features, out_features) 的参数是weight.shape == (out_features, in_features), 这么一来就说的通了,BA实际上是$[\text{out}, \text{rank}]@[\text{rank}, \text{input}]$。仔细一查,根本原因是PyTorch中线性层forward 计算本质是:

所以$BA(X)$实际上是$ X@(BA)^\top $即 $XA^\top B^\top$。