Swin Transformer 解析
Swin-Transformer
Transformer在NLP领域被广泛应用,能够建立大范围的数据之间的依赖关系,以attention的形式。语言具有良好的单个词作为元素基础,但是图像的基本元素在尺度变化上是非常大的,现有的基于transformer的视觉模型的token尺度都是固定的,对于视觉任务并不适合。另一方面,视觉任务具有更高的分辨率要求,比如semantic segmentation,精确到像素级,并不适合直接利用transformer的self attention机制(将是image size的平方了量级)。而Swin Transformer构建层级的feature map,而且计算复杂度是image size的线性函数。整体来说,Swin Transformer先由小patch开始,再到深层融合邻居patch信息。并且self-attention只用在不相邻的大窗口(就是shift window,Swin的全称)内,窗口之间并没有overlap。
Swin-Transformer作为一个适合与视觉任务,特别是适合稠密预测(检测、分割)的backbone而出现,它可以结合众多的检测方法。在此之前,有 ViT 以及它的改进工作提出的backbone,ViT需要大量数据预训练(JFT-300M),它的改进 DeiT 运用了一些训练策略使得可以在ImageNet-1K上预训练。ViT虽然取得令人振奋的效果,但是它本身其实并不适合作为密集视觉任务的backbone,因为它的低分辨率的特征图与平方计算复杂度。
细节过程
这里结合代码研究分析Swin Transformer的具体过程和原理。先给出整个流程的定义。
class SwinTransformer(nn.Module):
def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
use_checkpoint=False, **kwargs):
super().__init__()
self.num_classes = num_classes
self.num_layers = len(depths) ## 4
self.embed_dim = embed_dim # 96
self.ape = ape # the result is bad
self.patch_norm = patch_norm
self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) # 96 * 2^3
self.mlp_ratio = mlp_ratio
# split image into non-overlapping patches
self.patch_embed = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
norm_layer=norm_layer if self.patch_norm else None)
num_patches = self.patch_embed.num_patches # 56*56
patches_resolution = self.patch_embed.patches_resolution # (56, 56)
self.patches_resolution = patches_resolution
# absolute position embedding
if self.ape:
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
trunc_normal_(self.absolute_pos_embed, std=.02)
self.pos_drop = nn.Dropout(p=drop_rate)
# stochastic depth
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
# build layers
self.layers = nn.ModuleList()
for i_layer in range(self.num_layers):
layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
input_resolution=(patches_resolution[0] // (2 ** i_layer),
patches_resolution[1] // (2 ** i_layer)),
depth=depths[i_layer],
num_heads=num_heads[i_layer],
window_size=window_size,
mlp_ratio=self.mlp_ratio,
qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
norm_layer=norm_layer,
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
use_checkpoint=use_checkpoint)
self.layers.append(layer)
self.norm = norm_layer(self.num_features)
self.avgpool = nn.AdaptiveAvgPool1d(1)
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
@torch.jit.ignore
def no_weight_decay(self):
return {'absolute_pos_embed'}
@torch.jit.ignore
def no_weight_decay_keywords(self):
return {'relative_position_bias_table'}
def forward_features(self, x):
x = self.patch_embed(x)
if self.ape:
x = x + self.absolute_pos_embed
x = self.pos_drop(x)
for layer in self.layers:
x = layer(x)
x = self.norm(x) # B L C
x = self.avgpool(x.transpose(1, 2)) # B C 1
x = torch.flatten(x, 1)
return x
def forward(self, x):
x = self.forward_features(x)
x = self.head(x)
return x
预处理
对于分类模型,输入图像尺寸为$224\times 224 \times 3$,即$H=W=224$。按照原文描述,模型先将图像分割成每块大小为$4\times 4$的patch,那么就会有$56\times 56$个patch,这就是初始resolution,也是后面每个stage会降采样的维度。后面每个stage都会降采样时长宽降到一半,特征数加倍。按照原文及原图描述,划分的每个patch具有$4\times4\times3=48$维特征。
-
实际在代码中,首先使用了PatchEmbed模块,定义如下:
class PatchEmbed(nn.Module): def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): super().__init__() img_size = to_2tuple(img_size) patch_size = to_2tuple(patch_size) patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] self.img_size = img_size self.patch_size = patch_size self.patches_resolution = patches_resolution self.num_patches = patches_resolution[0] * patches_resolution[1] self.in_chans = in_chans self.embed_dim = embed_dim self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) if norm_layer is not None: self.norm = norm_layer(embed_dim) else: self.norm = None def forward(self, x): B, C, H, W = x.shape # FIXME look at relaxing size constraints assert H == self.img_size[0] and W == self.img_size[1], \ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C if self.norm is not None: x = self.norm(x) return x
可以看到,实际操作使用了一个卷积层conv2d(3, 96, 4, 4),直接就做了划分patch和编码初始特征的工作,对于输入$x: B\times 3\times 224\times 224$,经过一层conv2d和LayerNorm得到$x: B\times 56^2\times 96$。然后作为对比,可以选择性地加上每个patch的绝对位置编码,原文实验表示这种做法不好,因此不会采用(ape=false)。最后经过一层dropout,至此,预处理完成。另外,要注意的是,代码和上面流程图并不符,其实在stage 1之前,即预处理完成后,维度已经是$H/4\times W/4\times C$,stage 1之后已经是$H/8\times W/8\times 2C$,不过在stage 4后不再降采样,得到的还是$H/32\times W/32 \times 8C$。
stage处理
我们先梳理整个stage的大体过程,把简单的部分先说了,再深入到复杂得的细节。每个stage,即代码中的BasicLayer,由若干个block组成,而block的数目由depth列表中的元素决定。每个block就是W-MSA(window-multihead self attention)或者SW-MSA(shift window multihead self attention),一般有偶数个block,两种SA交替出现,比如6个block,0,2,4是W-MSA,1,3,5是SW-MSA。在经历完一个stage后,会进行下采样,定义的下采样比较有意思。比如还是$56\times 56$个patch,四个为一组,分别取每组中的左上,右上、左下、右下堆叠一起,经过一个layernorm,linear层,实现维度下采样、特征加倍的效果。实际上它可以看成一种加权池化的过程。代码如下:
class PatchMerging(nn.Module):
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
super().__init__()
self.input_resolution = input_resolution
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(4 * dim)
def forward(self, x):
"""
x: B, H*W, C
"""
H, W = self.input_resolution
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
x = x.view(B, H, W, C)
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
x = self.norm(x)
x = self.reduction(x)
return x
在经历完4个stage后,得到的是$(H/32\times W/32)\times 8C$的特征,将其转到$8C\times (H/32\times W/32)$后,接一个AdaptiveAvgPool1d(1),全局平均池化,得到$8C$特征,最后接一个分类器。
Block处理
上面说到有两种block,block的代码如下:
class SwinTransformerBlock(nn.Module):
def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.num_heads = num_heads
self.window_size = window_size
self.shift_size = shift_size
self.mlp_ratio = mlp_ratio
if min(self.input_resolution) <= self.window_size:
# if window size is larger than input resolution, we don't partition windows
self.shift_size = 0
self.window_size = min(self.input_resolution)
assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
self.norm1 = norm_layer(dim)
self.attn = WindowAttention(
dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
if self.shift_size > 0:
# calculate attention mask for SW-MSA
H, W = self.input_resolution
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
h_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
else:
attn_mask = None
self.register_buffer("attn_mask", attn_mask)
def forward(self, x):
H, W = self.input_resolution
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
shortcut = x
x = self.norm1(x)
x = x.view(B, H, W, C)
# cyclic shift
if self.shift_size > 0:
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
else:
shifted_x = x
# partition windows
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
# W-MSA/SW-MSA
attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
# merge windows
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
# reverse cyclic shift
if self.shift_size > 0:
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
else:
x = shifted_x
x = x.view(B, H * W, C)
# FFN
x = shortcut + self.drop_path(x)
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
W-MSA
W-MSA比较简单,只要其中shift_size
设置为0就是W-MSA。下面跟着代码走一遍过程。
-
输入:$x: B\times 56^2\times 96$,$H, W=56$
-
经过一层layerNorm
-
变形:$x: B\times 56\times 56\times 96$
-
直接赋值给
shifted_x
-
调用
window_partition
函数,输入shifted_x
,window_size=7
:- 注意窗口大小以patch为单位,比如7就是7个patch,如果56的分辨率就会有8个窗口。
- 这个函数对
shifted_x
做一系列变形,最终变成$8^2B\times 7 \times 7\times 96$
-
返回赋值给
x_windows
,再变形成$8^2B\times 7^2\times 96$,这表示所有图片,每个图片的64个window,每个window内有49个patch。 -
调用
WindowAttention
层,这里以它的num_head
为3为例。输入参数为x_windows
和self.attn_mask
,对于W-MSA,attn_mask
为None,可以不用管。-
WindowAttention
代码如下:class WindowAttention(nn.Module): def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): super().__init__() self.dim = dim self.window_size = window_size # Wh, Ww self.num_heads = num_heads head_dim = dim // num_heads self.scale = qk_scale or head_dim ** -0.5 # define a parameter table of relative position bias self.relative_position_bias_table = nn.Parameter( torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH # get pair-wise relative position index for each token inside the window coords_h = torch.arange(self.window_size[0]) coords_w = torch.arange(self.window_size[1]) coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 relative_coords[:, :, 1] += self.window_size[1] - 1 relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww self.register_buffer("relative_position_index", relative_position_index) self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) trunc_normal_(self.relative_position_bias_table, std=.02) self.softmax = nn.Softmax(dim=-1) def forward(self, x, mask=None): """ Args: x: input features with shape of (num_windows*B, N, C) mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None """ B_, N, C = x.shape qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) q = q * self.scale attn = (q @ k.transpose(-2, -1)) relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww attn = attn + relative_position_bias.unsqueeze(0) if mask is not None: nW = mask.shape[0] attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) attn = attn.view(-1, self.num_heads, N, N) attn = self.softmax(attn) else: attn = self.softmax(attn) attn = self.attn_drop(attn) x = (attn @ v).transpose(1, 2).reshape(B_, N, C) x = self.proj(x) x = self.proj_drop(x) return x
-
输入$x: 8^2B\times 7^2\times 96$。
-
产生$QKV$,调用线性层后,得到$8^2B\times 7^2\times (96\times 3)$,拆分给不同的head,得到$8^2B\times 7^2\times 3 \times 3\times 32$,第一个3是$QKV$的3,第二个3是3个head。再permute成$3\times 8^2B\times 3\times 7^2\times 32$,再拆解成$q,k,v$,每个都是$8^2B\times 3\times 7^2\times 32$。表示所有图片的每个图片64个window,每个window对应到3个不同的head,都有一套49个patch、32维的特征。
-
$q$归一化
-
$qk$矩阵相乘求特征内积,得到$attn: 8^2B \times 3\times 7^2\times 7^2$
-
得到相对位置的编码信息
relative_position_bias
:-
代码如下:
self.relative_position_bias_table = nn.Parameter( torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH # get pair-wise relative position index for each token inside the window coords_h = torch.arange(self.window_size[0]) coords_w = torch.arange(self.window_size[1]) coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 relative_coords[:, :, 1] += self.window_size[1] - 1 relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww self.register_buffer("relative_position_index", relative_position_index)
-
这里以
window_size=3
为例,解释以下过程:首先生成$coords: 2\times3\times 3$,就是在一个$3\times 3$的窗口内,每个位置的$y,x$坐标,而relative_coords
为$2\times 9 \times 9$,就是9个点中,每个点的$y$或$x$与其他所有点的差值,比如$ 0 1 $,表示3号点(第二行第一个点)与1号点(第一行第二个点)的$y$坐标的差值。然后变形,并让两个坐标分别加上$3-1=2$,是因为这些坐标值范围$[0,2]$,因此差值的最小值为-2,加上2后从0开始。最后让$y$坐标乘上$2\times 3-1=5$,应该是一个trick,调整差值范围。最后将两个维度的差值相加,得到relative_position_index
,$3^2\times 3^2$,为9个点之间两两之间的相对位置编码值,最后用来到self.relative_position_bias_table
中寻址,注意相对位置的最大值为$(2M-2)(2M-1)$,而这个table最多有$(2M-1)(2M-1)$行,因此保证可以寻址,得到了一组给多个head使用的相对位置编码信息,这个table是可训练的参数。 -
回到代码中,得到的
relative_position_bias
为$3\times 7^2\times7^2$
-
-
将其加到
attn
上,最后一个维度softmax,dropout -
与$v$矩阵相乘,并转置,合并多个头的信息,得到$8^2B\times 7^2\times 96$
-
经过一层线性层,dropout,返回
-
-
返回赋值给
attn_windows
,变形为$8^2B\times 7\times 7 \times 96$ -
调用
window_reverse
,打回原状:$B\times 56\times 56\times96$ -
返回给$x$,经过FFN:先加上原来的输入$x$作为residue结构,注意这里用到[timm][https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py]的
DropPath
,并且drop的概率是整个网络结构线性增长的。然后再加上两层mlp的结果。 -
返回结果$x$。
这样,整个过程就完成了,剩下的就是SW-MSA的一些不同的操作。
SW-MSA
W-MSA的不足就不同window之间是没有交流的,因此,SW-MSA紧接在W-MSA后面就是想让不同window之间也有信息传递。因此,最直接的想法就是将图片平移以下,但不是平移window_size
,比如平移一半,再以跟之前相同的划分来切割,这样就会有上一轮本来不在同一个window里面的pach,现在出现在同一个window。但是,最简单的处理,比如像第一个图,会出现更多的window,有些window尺寸还不一样。一种修补方法是补上一些块,padding,但是会增大计算量。为此,这里采用cycling(rolling)的方法,即滚动,将图片向左、向上平移window一半的尺寸,比如window是7,这里就平移三个patch,左边的那些块移到右边,上边的那些块移到下边。
再回到block的代码中,这时shift_size=7//2=3
:
-
输入:$x: B\times 56^2\times 96$,$H, W=56$
-
经过一层layerNorm
-
变形:$x: B\times 56\times 56\times 96$
-
对$x$进行roll:就是分别向左向上滚动,把块补在右面和下面,赋值给
shifted_x
-
调用
window_partition
函数,输入shifted_x
,window_size=7
:- 注意窗口大小以patch为单位,比如7就是7个patch,如果56的分辨率就会有8个窗口。
- 这个函数对
shifted\_x
做一系列变形,最终变成$8^2B\times 7 \times 7\times 96$
-
返回赋值给
x\_windows
,再变形成$8^2B\times 7^2\times 96$,这表示所有图片,每个图片的64个window,每个window内有49个patch。 -
调用
WindowAttention
层,这里以它的num_head
为3为例。输入参数为x_windows
和self.attn_mask
,对于SW-MSA,attn_mask
的产生过程如下。-
mask的形状为$1\times56\times56\times 1$,产生
h_slices
和w_slices
都是三个:(0, -7),(-7, -3),(-3, None),通过一个循环,对mask分成9大块,并打上标号。-
下图为了方便显示,换了个例子:$12\times 12$个patch,window为$3\times 3$个patch,平移量1,$4\times4$个窗口
-
解释:经过循环,mask被分成9大部分,分别表上0到8,然后经过
window\_partition
,得到$4^2\times3^2$,再作差,得到$4^2\times 3^2\times3^2$,将不为0的置为-100,将为0的填上0。可以理解成,在$4^2$个窗口中的每个窗口,都有9个位置,那么这9个位置中的每个应该关注哪些位置(0),不用关注哪些位置(-100),比如右上角那个区域的左上角位置,它对应的mask就应该是上面那一个横条,不用关注右边蓝色的三格,因为右边那列是由左边平移来的,而计算两个边缘之间相关性作用不大。其他位置同理。
-
-
那么回到原来的代码中,产生的mask就应该是$8^2\times 7^2\times7^2$,计算attention时,将$attn: 8^2B \times 3\times 7^2\times 7^2$变形成$B\times 8^2\times 3\times7^2\times7^2$加上mask,再变回去,然后softmax。
-
这里其实存在一点疑问:因为窗口天然隔离,上面的(0,-7), (-7, -3)这一段分割显得多余。
-
-
后面就和W-MSA一样了,不过还要把$x$转回去,否则就要不停地转动了。
最后,Swin-Transformer有4种架构:
关于MSA中的复杂度运算:
- 对于普通的MSA,在全图上计算SA,如果尺寸为$h\times w$,那么产生$QWV$的$W$矩阵为$C\times C$,输入$x$为$hw\times C$,因此需要复杂度$3hwC^2$。然后,计算$QK^T$,$QKV$的维度是$hw\times C$,因此为$(hw)^2C$。然后softmax乘$V$得到$Z$,复杂度为$(hw)^2C$。最后,从$Z$到输出,还要乘一个矩阵$W$,复杂度$hwC^2$,因此总复杂度为$4hwC^2+2(hw)^2C$。
- 对于W-MSA,因为分窗口进行SA,因此主要优化的是$(hw)^2C$项,现在变成$W^2\times W^2\times (h/W)\times(w/W)C=W^2hwC$,其实还是四次的,只不过将其中一个$hw$变成可以更小的$W^2$,减小的部分就是窗口与窗口之间的信息传递。