Swin Transformer Explained


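Transformers are widely used in NLP, where attention lets them model dependencies between elements that are far apart. Language has a natural basic unit, the word, but the basic visual elements of an image vary enormously in scale, while the tokens of existing Transformer-based vision models all have a fixed scale, which is a poor fit for vision tasks. Vision tasks also demand much higher resolution: semantic segmentation, for example, predicts at the pixel level, and applying Transformer self-attention directly would cost time quadratic in the image size. Swin Transformer instead builds a hierarchical feature map, and its computational complexity is linear in the image size. In short, Swin Transformer starts from small patches and merges neighboring patches in deeper layers. Self-attention is computed only inside local windows that do not overlap one another, and the window partition is shifted between consecutive blocks (the shifted windows that give Swin its name) so that information can flow between windows.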

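Swin Transformer was proposed as a general-purpose backbone for vision tasks. It is particularly suited to dense prediction (detection, segmentation) and can be combined with many existing detection frameworks. Before it, backbones had been proposed by ViT and its follow-up work. ViT requires pre-training on a very large dataset (JFT-300M); its improved variant DeiT uses a set of training strategies that make pre-training on ImageNet-1K sufficient. Although ViT achieved exciting results, it is not well suited as a backbone for dense vision tasks, because of its low-resolution feature maps and its quadratic computational complexity.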


Basic structure of Swin Transformer

Below, we analyze the concrete process and principles of Swin Transformer alongside the code, starting with the definition of the whole model.

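import torch
import torch.nn as nn
from timm.models.layers import DropPath, to_2tuple, trunc_normal_

# PatchEmbed, BasicLayer, PatchMerging, SwinTransformerBlock, WindowAttention, Mlp,
# window_partition and window_reverse are defined elsewhere in the same source file.


class SwinTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
                 embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
                 window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
                 use_checkpoint=False, **kwargs):
        super().__init__()
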
        self.num_classes = num_classes 
        self.num_layers = len(depths)  ## 4
        self.embed_dim = embed_dim  # 96
        self.ape = ape  # absolute position embedding; off by default (reported to work worse)
        self.patch_norm = patch_norm  
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))  # 96 * 2^3
        self.mlp_ratio = mlp_ratio

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        num_patches = self.patch_embed.num_patches  # 56*56
        patches_resolution = self.patch_embed.patches_resolution  # (56, 56)
        self.patches_resolution = patches_resolution

        # absolute position embedding
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        # build layers: stage i has dim embed_dim * 2^i and resolution 56 / 2^i
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
                               input_resolution=(patches_resolution[0] // (2 ** i_layer),
                                                 patches_resolution[1] // (2 ** i_layer)),
                               depth=depths[i_layer],
                               num_heads=num_heads[i_layer],
                               window_size=window_size,
                               mlp_ratio=self.mlp_ratio,
                               qkv_bias=qkv_bias, qk_scale=qk_scale,
                               drop=drop_rate, attn_drop=attn_drop_rate,
                               drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                               norm_layer=norm_layer,
                               downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                               use_checkpoint=use_checkpoint)
            self.layers.append(layer)

        self.norm = norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        self.apply(self._init_weights)


    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def no_weight_decay(self):
        return {'absolute_pos_embed'}

    def no_weight_decay_keywords(self):
        return {'relative_position_bias_table'}

    def forward_features(self, x):
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        for layer in self.layers:
            x = layer(x)

        x = self.norm(x)  # B L C
        x = self.avgpool(x.transpose(1, 2))  # B C 1
        x = torch.flatten(x, 1)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x

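For classification, the input image is $224\times 224 \times 3$, i.e. $H=W=224$. As described in the paper, the model first splits the image into patches of size $4\times 4$, which gives $56\times 56$ patches. This is the initial resolution, which each later stage downsamples: every subsequent stage halves the height and width and doubles the number of channels. Following the paper and its figure, each patch carries $4\times4\times3=48$ raw features.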
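As a sanity check of the patch arithmetic above, here is a minimal NumPy sketch of the paper-style patch partition. It is not the repository's code: the function name patchify and the random projection w_embed are hypothetical. It cuts an $H\times W\times 3$ image into non-overlapping $4\times 4$ patches and flattens each one into a 48-dimensional vector. A linear map to $C=96$ then plays the role of the patch embedding.

```python
import numpy as np


def patchify(img, patch_size=4):
    """Split an (H, W, C) image into non-overlapping patches.

    Returns shape (H/p * W/p, p*p*C): one flattened patch per row,
    patches ordered row by row over the patch grid.
    """
    H, W, C = img.shape
    p = patch_size
    assert H % p == 0 and W % p == 0, "image size must be divisible by the patch size"
    x = img.reshape(H // p, p, W // p, p, C)   # (nH, p, nW, p, C)
    x = x.transpose(0, 2, 1, 3, 4)             # (nH, nW, p, p, C)
    return x.reshape(-1, p * p * C)            # (nH*nW, p*p*C)


rng = np.random.default_rng(0)
img = rng.standard_normal((224, 224, 3))
patches = patchify(img)                         # 56*56 patches, 48 features each
w_embed = rng.standard_normal((48, 96)) * 0.02  # hypothetical embedding weights, 48 -> C=96
tokens = patches @ w_embed                      # one 96-dim token per patch
print(patches.shape, tokens.shape)
```

A Conv2d whose kernel size equals its stride computes exactly this kind of per-patch linear map, which is why PatchEmbed implements the embedding as a single convolution. Its weight is laid out channel-first as (96, 3, 4, 4), so the flattening order differs from this sketch.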

  • In the code, the first module applied is PatchEmbed, defined as follows:

    class PatchEmbed(nn.Module):
        def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
            super().__init__()
            img_size = to_2tuple(img_size)
            patch_size = to_2tuple(patch_size)
            patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
            self.img_size = img_size
            self.patch_size = patch_size
            self.patches_resolution = patches_resolution
            self.num_patches = patches_resolution[0] * patches_resolution[1]
            self.in_chans = in_chans
            self.embed_dim = embed_dim
            self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
            if norm_layer is not None:
                self.norm = norm_layer(embed_dim)
            else:
                self.norm = None

        def forward(self, x):
            B, C, H, W = x.shape
            # FIXME look at relaxing size constraints
            assert H == self.img_size[0] and W == self.img_size[1], \
                f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
            x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C
            if self.norm is not None:
                x = self.norm(x)
            return x

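    As the code shows, a single convolution, nn.Conv2d(3, 96, kernel_size=4, stride=4), performs both the patch partition and the initial feature embedding. For an input $x: B\times 3\times 224\times 224$, the conv (followed by flatten, transpose and LayerNorm) yields $x: B\times 56^2\times 96$. An absolute position embedding can optionally be added to each patch for comparison. The paper's experiments show that it works worse, so it is not used (ape=False). Finally a dropout layer is applied, which completes preprocessing. Note that the code does not match the flowchart above: in the code, PatchMerging sits at the end of each stage rather than at the start of the next one. So right after preprocessing, before stage 1, the features are already $H/4\times W/4\times C$. After stage 1 they are $H/8\times W/8\times 2C$. Since stage 4 does not downsample, the final output is still $H/32\times W/32 \times 8C$.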
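The shape bookkeeping described above can be checked with a few lines of plain Python. This sketch tracks only (H, W, C) and assumes, as in the code, that PatchMerging runs at the end of every stage except the last. The function name stage_shapes is hypothetical.

```python
def stage_shapes(img_size=224, patch_size=4, embed_dim=96, depths=(2, 2, 6, 2)):
    """Track (name, H, W, C) through the model, following the code's layout."""
    res = img_size // patch_size          # 224 // 4 = 56 after PatchEmbed
    dim = embed_dim
    shapes = [("patch_embed", res, res, dim)]
    num_stages = len(depths)
    for i in range(num_stages):
        # Blocks inside a stage keep (res, dim) unchanged; PatchMerging is the
        # stage's `downsample` and exists for every stage except the last.
        if i < num_stages - 1:
            res //= 2
            dim *= 2
        shapes.append((f"stage{i + 1}", res, res, dim))
    return shapes


for name, h, w, c in stage_shapes():
    print(f"{name:12s} {h:3d} x {w:3d} x {c}")
```

The last entry matches num_features = 96 * 2**3 = 768 in SwinTransformer.__init__, which is the input size of the classification head.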


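Let us first walk through a stage at a high level, covering the simple parts before the more involved details. Each stage, BasicLayer in the code, consists of several blocks, and the number of blocks is given by the corresponding entry of depths. Each block is either a W-MSA (window multi-head self-attention) block or an SW-MSA (shifted window multi-head self-attention) block. A stage normally has an even number of blocks, and the two kinds alternate: with 6 blocks, for example, blocks 0, 2 and 4 use W-MSA and blocks 1, 3 and 5 use SW-MSA. After a stage finishes, the features are downsampled, and this downsampling is defined in an interesting way. Take the $56\times 56$ patches again. Every $2\times 2$ group of four neighboring patches is split into its top-left, bottom-left, top-right and bottom-right patch (x0, x1, x2, x3 in the code). The four are stacked along the channel dimension into $4C$ features, and a LayerNorm and a linear layer map them to $2C$. This halves the resolution in each direction and doubles the channels. It can be viewed as a kind of weighted pooling. The code is as follows: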

class PatchMerging(nn.Module):
    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)  # 4C -> 2C
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        # x: B, H*W, C
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

        x = x.view(B, H, W, C)

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)

        return x
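To see which patches end up next to each other, here is a NumPy port of PatchMerging.forward. It is a sketch under two simplifications: the LayerNorm has no learnable scale or shift, and the reduction weight w_red is a hypothetical plain argument instead of an nn.Linear.

```python
import numpy as np


def patch_merging(x, H, W, w_red=None, eps=1e-5):
    """NumPy port of PatchMerging.forward for x of shape (B, H*W, C).

    Without w_red, returns the concatenated (B, H/2*W/2, 4C) features.
    With w_red of shape (4C, 2C), also applies a plain LayerNorm
    (no learnable scale/shift) and the bias-free linear reduction.
    """
    B, L, C = x.shape
    assert L == H * W, "input feature has wrong size"
    assert H % 2 == 0 and W % 2 == 0, "H and W must be even"
    x = x.reshape(B, H, W, C)
    x0 = x[:, 0::2, 0::2, :]  # top-left of each 2x2 group
    x1 = x[:, 1::2, 0::2, :]  # bottom-left
    x2 = x[:, 0::2, 1::2, :]  # top-right
    x3 = x[:, 1::2, 1::2, :]  # bottom-right
    x = np.concatenate([x0, x1, x2, x3], axis=-1).reshape(B, -1, 4 * C)
    if w_red is None:
        return x
    x = (x - x.mean(-1, keepdims=True)) / np.sqrt(x.var(-1, keepdims=True) + eps)
    return x @ w_red  # (B, H/2*W/2, 2C)


grid = np.arange(16, dtype=float).reshape(1, 16, 1)  # a 4x4 map with C=1, value = position
print(patch_merging(grid, 4, 4)[0])
```

Each output row lists the four patches of one $2\times 2$ group in the order top-left, bottom-left, top-right, bottom-right. For example, the first group gives [0, 4, 1, 5].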

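After the 4 stages we obtain features of shape $(H/32\times W/32)\times 8C$. After a final LayerNorm, they are transposed to $8C\times (H/32\times W/32)$ and passed through AdaptiveAvgPool1d(1), i.e. global average pooling, which yields an $8C$-dimensional feature. A linear classifier follows.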


Each stage is built from SwinTransformerBlock modules, defined as follows:

class SwinTransformerBlock(nn.Module):
    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        if self.shift_size > 0:
            # calculate attention mask for SW-MSA
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None

        self.register_buffer("attn_mask", attn_mask)

    def forward(self, x):
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, H * W, C)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x

Taking a W-MSA block of stage 1 (shift_size=0) as an example, its forward pass runs as follows:

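  • Input: $x: B\times 56^2\times 96$, $H = W = 56$

  • Apply a LayerNorm

  • Reshape: $x: B\times 56\times 56\times 96$

  • Since shift_size is 0, assign $x$ directly to shifted_x

  • Call window_partition with shifted_x and window_size=7

    • Note that the window size is measured in patches: 7 means 7 patches, so a resolution of 56 gives $56/7=8$ windows along each side.
    • The function reshapes shifted_x in a series of steps into $8^2B\times 7 \times 7\times 96$
  • The result is assigned to x_windows and reshaped to $8^2B\times 7^2\times 96$: for every image, 64 windows, each containing 49 patches.

  • Call the WindowAttention layer, taking num_heads=3 as an example (the stage-1 setting). Its inputs are x_windows and self.attn_mask; for W-MSA, attn_mask is None and can be ignored.

    • The WindowAttention code: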

      class WindowAttention(nn.Module):
          def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
              super().__init__()
              self.dim = dim
              self.window_size = window_size  # Wh, Ww
              self.num_heads = num_heads
              head_dim = dim // num_heads
              self.scale = qk_scale or head_dim ** -0.5
              # define a parameter table of relative position bias
              self.relative_position_bias_table = nn.Parameter(
                  torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH
              # get pair-wise relative position index for each token inside the window
              coords_h = torch.arange(self.window_size[0])
              coords_w = torch.arange(self.window_size[1])
              coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
              coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
              relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
              relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
              relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
              relative_coords[:, :, 1] += self.window_size[1] - 1
              relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
              relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
              self.register_buffer("relative_position_index", relative_position_index)
              self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
              self.attn_drop = nn.Dropout(attn_drop)
              self.proj = nn.Linear(dim, dim)
              self.proj_drop = nn.Dropout(proj_drop)
              trunc_normal_(self.relative_position_bias_table, std=.02)
              self.softmax = nn.Softmax(dim=-1)

          def forward(self, x, mask=None):
              """
              x: input features with shape of (num_windows*B, N, C)
              mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
              """
              B_, N, C = x.shape
              qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
              q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
              q = q * self.scale
              attn = (q @ k.transpose(-2, -1))
              relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
                  self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
              relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
              attn = attn + relative_position_bias.unsqueeze(0)
              if mask is not None:
                  nW = mask.shape[0]
                  attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
                  attn = attn.view(-1, self.num_heads, N, N)
                  attn = self.softmax(attn)
              else:
                  attn = self.softmax(attn)
              attn = self.attn_drop(attn)
              x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
              x = self.proj(x)
              x = self.proj_drop(x)
              return x
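    • Input: $x: 8^2B\times 7^2\times 96$.

    • Generate $Q, K, V$: the linear layer gives $8^2B\times 7^2\times (96\times 3)$. Splitting this across heads gives $8^2B\times 7^2\times 3 \times 3\times 32$, where the first 3 indexes $Q/K/V$ and the second 3 indexes the 3 heads. It is then permuted to $3\times 8^2B\times 3\times 7^2\times 32$ and split into $q, k, v$, each $8^2B\times 3\times 7^2\times 32$. In other words, for every image, each of its 64 windows and each of the 3 heads has its own set of 49 patches with 32-dimensional features.

    • Scale $q$ by self.scale $= d^{-1/2}$, where $d = 96/3 = 32$ is the per-head dimension

    • Multiply $q$ by $k^T$ to get the inner products between features, $attn: 8^2B \times 3\times 7^2\times 7^2$

    • Obtain the relative position encoding, relative_position_bias

      • The code is: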

        self.relative_position_bias_table = nn.Parameter(
                    torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH
        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)
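      • Take window_size $M=3$ as an example. First, coords has shape $2\times3\times 3$: the $y$ and $x$ coordinates of every position in the $3\times 3$ window. relative_coords has shape $2\times 9 \times 9$: for each of the 9 points, the difference between its $y$ (or $x$) coordinate and that of every other point. For example, relative_coords[0, 3, 1] is the $y$ difference between point 3 (second row, first column) and point 1 (first row, second column), namely $1-0=1$. After the permute, $M-1=2$ is added to both coordinate differences. Coordinates lie in $[0,2]$, so differences lie in $[-2,2]$, and adding 2 makes them start from 0. Next, the $y$ difference is multiplied by $2M-1=5$. This is more than a trick: both shifted differences now lie in $[0, 2M-2]$, so multiplying the $y$ part by $2M-1$ and adding the $x$ part encodes each $(\Delta y, \Delta x)$ pair as a distinct base-$(2M-1)$ number. Summing the two components gives relative_position_index, of shape $3^2\times 3^2$: the relative-position index for every pair of the 9 points. It is then used to index into self.relative_position_bias_table. The largest index is $(2M-2)(2M-1)+(2M-2)=(2M-1)^2-1$, and the table has exactly $(2M-1)^2$ rows, so every lookup is in range. Each row of the table holds one bias value per head, and the table is a trainable parameter.

      • Back in the code (with $M=7$ and 3 heads), relative_position_bias has shape $3\times 7^2\times7^2$

    • Add it to attn, broadcasting over all windows; then apply softmax over the last dimension, followed by dropout

    • Multiply by $v$, transpose, and merge the heads to get $8^2B\times 7^2\times 96$

    • Apply a linear layer and dropout, and return

  • The result is assigned to attn_windows and reshaped to $8^2B\times 7\times 7 \times 96$

  • Call window_reverse to restore the original layout: $B\times 56\times 56\times96$

  • Assign it to $x$ and apply the FFN part. First add the block input (shortcut) as a residual connection. Note that this uses timm's DropPath (https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py), whose drop probability increases linearly with depth across the whole network (the dpr list). Then add the output of the two-layer MLP, applied to the LayerNorm-ed $x$, as a second residual.

  • Return $x$.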
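The two helpers used in the steps above, window_partition and window_reverse, are not listed in this walkthrough. The following NumPy sketch assumes they are pure reshape/transpose operations; it is intended to mirror a view/permute chain, with reshape standing in for view. It checks the shapes from the steps above and the round trip.

```python
import numpy as np


def window_partition(x, window_size):
    """(B, H, W, C) -> (num_windows*B, ws, ws, C), windows in row-major order per image."""
    B, H, W, C = x.shape
    ws = window_size
    x = x.reshape(B, H // ws, ws, W // ws, ws, C)  # split H and W into (block, offset)
    x = x.transpose(0, 1, 3, 2, 4, 5)              # (B, nH, nW, ws, ws, C)
    return x.reshape(-1, ws, ws, C)


def window_reverse(windows, window_size, H, W):
    """Inverse of window_partition: (num_windows*B, ws, ws, C) -> (B, H, W, C)."""
    ws = window_size
    B = windows.shape[0] // ((H // ws) * (W // ws))
    x = windows.reshape(B, H // ws, W // ws, ws, ws, -1)
    x = x.transpose(0, 1, 3, 2, 4, 5)              # (B, nH, ws, nW, ws, C)
    return x.reshape(B, H, W, -1)


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 56, 56, 96))           # B=2, after LayerNorm + reshape
windows = window_partition(x, 7)                   # (8^2 * B, 7, 7, C)
tokens = windows.reshape(-1, 7 * 7, 96)            # (8^2 * B, 49, C), the WindowAttention input
restored = window_reverse(windows, 7, 56, 56)
print(windows.shape, tokens.shape, np.array_equal(restored, x))
```

Window 1 of the first image covers rows 0-6 and columns 7-13, i.e. windows are enumerated left to right, then top to bottom, image by image.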

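The relative-position index construction described above can be reproduced by transcribing the WindowAttention.__init__ code into NumPy. This is a sketch with two assumptions. First, np.meshgrid(..., indexing="ij") is used to match the ij ordering of torch.meshgrid's default. Second, the random table is a stand-in for the trained relative_position_bias_table. The helper names are hypothetical.

```python
import numpy as np


def relative_position_index(M):
    """Index into the (2M-1)^2-row bias table for every pair of tokens in an M x M window."""
    coords = np.stack(np.meshgrid(np.arange(M), np.arange(M), indexing="ij"))  # 2, M, M
    flat = coords.reshape(2, -1)                   # 2, M*M (row 0: y, row 1: x)
    rel = flat[:, :, None] - flat[:, None, :]      # 2, M*M, M*M: p_i - p_j
    rel = rel.transpose(1, 2, 0).copy()            # M*M, M*M, 2
    rel[:, :, 0] += M - 1                          # dy now in [0, 2M-2]
    rel[:, :, 1] += M - 1                          # dx now in [0, 2M-2]
    rel[:, :, 0] *= 2 * M - 1                      # base-(2M-1) encoding of (dy, dx)
    return rel.sum(-1)                             # M*M, M*M


def relative_position_bias(table, index, M):
    """Gather per-head biases: table of shape ((2M-1)^2, nH) -> (nH, M*M, M*M)."""
    bias = table[index.reshape(-1)].reshape(M * M, M * M, -1)
    return bias.transpose(2, 0, 1)


idx = relative_position_index(3)
print(idx[3, 1], idx.max())  # point 3 vs point 1 (dy=1, dx=-1), and the largest index

table = np.random.default_rng(0).standard_normal(((2 * 7 - 1) ** 2, 3))  # stand-in, nH=3
bias = relative_position_bias(table, relative_position_index(7), 7)
print(bias.shape)
```

All $(2M-1)^2$ offset pairs occur in an $M\times M$ window, so every row of the table is actually used.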

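The attention computation inside WindowAttention.forward can likewise be traced in NumPy. This sketch makes simplifying assumptions: a single hypothetical weight matrix w_qkv without bias replaces self.qkv, and the output projection and dropouts are omitted. It follows the same reshapes, the $d^{-1/2}$ scaling, the bias addition and the optional per-window mask.

```python
import numpy as np


def softmax(a, axis=-1):
    a = a - a.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(a)
    return e / e.sum(axis=axis, keepdims=True)


def window_attention(x, w_qkv, num_heads, bias, mask=None):
    """x: (nW*B, N, C); w_qkv: (C, 3C); bias: (nH, N, N); mask: (nW, N, N) or None."""
    B_, N, C = x.shape
    head_dim = C // num_heads
    qkv = (x @ w_qkv).reshape(B_, N, 3, num_heads, head_dim).transpose(2, 0, 3, 1, 4)
    q, k, v = qkv[0], qkv[1], qkv[2]          # each (B_, nH, N, head_dim)
    q = q * head_dim ** -0.5                  # the self.scale factor
    attn = q @ k.transpose(0, 1, 3, 2)        # (B_, nH, N, N)
    attn = attn + bias[None]                  # the same bias for every window
    if mask is not None:
        nW = mask.shape[0]
        attn = attn.reshape(B_ // nW, nW, num_heads, N, N) + mask[None, :, None]
        attn = attn.reshape(-1, num_heads, N, N)
    attn = softmax(attn)
    out = (attn @ v).transpose(0, 2, 1, 3).reshape(B_, N, C)  # merge the heads
    return out, attn


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 4, 6))            # 2 windows (B=1, nW=2), N=4 tokens, C=6
w_qkv = rng.standard_normal((6, 18)) * 0.1
bias = np.zeros((3, 4, 4))                    # nH=3, zero bias for simplicity
mask = np.zeros((2, 4, 4))
mask[1, :2, 2:] = mask[1, 2:, :2] = -100.0    # window 1: tokens {0,1} and {2,3} must not mix
out, attn = window_attention(x, w_qkv, 3, bias, mask)
print(out.shape, attn.shape)
```

Adding -100 before the softmax drives the corresponding weights to essentially zero, which is how the masked positions are ignored.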

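For an SW-MSA block (the odd-indexed blocks 1, 3, 5, ... of a stage, with shift_size = window_size // 2 = 3), the forward pass runs as follows:
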
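  • Input: $x: B\times 56^2\times 96$, $H = W = 56$

  • Apply a LayerNorm

  • Reshape: $x: B\times 56\times 56\times 96$

  • Roll $x$ with torch.roll: shift it up and to the left by shift_size=3 patches, so that the rows and columns pushed off the top and left edges wrap around to the bottom and right. Assign the result to shifted_x

  • Call window_partition with shifted_x and window_size=7

    • Note that the window size is measured in patches: 7 means 7 patches, so a resolution of 56 gives $56/7=8$ windows along each side.
    • The function reshapes shifted_x in a series of steps into $8^2B\times 7 \times 7\times 96$
  • The result is assigned to x_windows and reshaped to $8^2B\times 7^2\times 96$: for every image, 64 windows, each containing 49 patches.

  • Call the WindowAttention layer, again with num_heads=3. Its inputs are x_windows and self.attn_mask; for SW-MSA, attn_mask is produced as follows.

    • The mask has shape $1\times56\times56\times 1$. h_slices and w_slices each contain three slices, (0, -7), (-7, -3) and (-3, None). A double loop splits the mask into 9 regions and gives each region its own label.

      • The figure below uses a smaller example for readability: $12\times 12$ patches, windows of $3\times 3$ patches and a shift of 1, hence $4\times4$ windows

      • Explanation: after the loop, the mask is split into 9 regions labeled 0 to 8. window_partition turns it into $4^2\times3^2$, and taking pairwise differences gives $4^2\times 3^2\times3^2$. Nonzero entries are set to -100 and zero entries to 0. In other words, each of the $4^2$ windows has 9 positions, and for each of these 9 positions the mask says which positions it should attend to (0) and which it should ignore (-100). Take the top-left position of the top-right region, for example. Its mask should be the horizontal strip at the top, and it should not attend to the three blue cells on the right. That right-hand column was rolled over from the left edge of the image, so those cells are not neighbors in the original image and attention between the two edges carries little useful information. Other positions work the same way.

    • Back in the original code, the mask has shape $8^2\times 7^2\times7^2$. When computing attention, $attn: 8^2B \times 3\times 7^2\times 7^2$ is reshaped to $B\times 8^2\times 3\times7^2\times7^2$. The mask is added (broadcast over the batch and head dimensions), the result is reshaped back, and softmax is applied.

    • One point here is questionable: since windows are already isolated from one another, separating (0, -7) from (-7, -3) seems redundant. Indeed, because $56-7=49$ is a multiple of 7, the boundary between these two slices always coincides with a window boundary. Two slices, (0, -3) and (-3, None), would therefore produce exactly the same mask.

  • The rest is the same as W-MSA, except that $x$ must be rolled back afterwards (torch.roll with positive shifts); otherwise the feature map would keep drifting further with every SW-MSA block.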
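The roll and the mask construction can be reproduced in NumPy. This sketch transcribes the SwinTransformerBlock.__init__ logic, with np.roll standing in for torch.roll. The helper names cyclic_shift, region_labels and attention_mask are hypothetical. It covers the $12\times 12$, window 3, shift 1 example from the figure as well as the real stage-1 setting.

```python
import numpy as np


def cyclic_shift(x, shift):
    """Same as torch.roll(x, shifts=(-shift, -shift), dims=(1, 2)) for x of shape (B, H, W, C)."""
    return np.roll(x, shift=(-shift, -shift), axis=(1, 2))


def region_labels(H, W, window_size, shift_size):
    """Label map built like img_mask in SwinTransformerBlock.__init__ (3 x 3 = 9 regions)."""
    labels = np.zeros((H, W))
    slices = (slice(0, -window_size),
              slice(-window_size, -shift_size),
              slice(-shift_size, None))
    cnt = 0
    for h in slices:
        for w in slices:
            labels[h, w] = cnt
            cnt += 1
    return labels


def attention_mask(labels, window_size):
    """(H, W) labels -> (nW, N, N) mask: 0 within the same region, -100 across regions."""
    H, W = labels.shape
    ws = window_size
    win = labels.reshape(H // ws, ws, W // ws, ws).transpose(0, 2, 1, 3).reshape(-1, ws * ws)
    diff = win[:, None, :] - win[:, :, None]
    return np.where(diff != 0, -100.0, 0.0)


# The small example from the figure: 12 x 12 patches, 3 x 3 windows, shift 1.
labels = region_labels(12, 12, 3, 1)
mask = attention_mask(labels, 3)
print(mask.shape)                                # one 9 x 9 mask per window
print(sum(bool((m != 0).any()) for m in mask))   # windows that actually need masking
```

Only windows in the last row or last column of the window grid mix regions, so in the small example 7 of the 16 windows carry any -100 entries. All other windows get an all-zero mask, which makes SW-MSA behave like plain W-MSA there.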


Computational complexity

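  • Standard MSA computes self-attention over the whole image. For an $h\times w$ feature map, each projection matrix $W$ that produces $Q$, $K$ or $V$ is $C\times C$ and the input $x$ is $hw\times C$, so computing $Q, K, V$ costs $3hwC^2$. Next, computing $QK^T$ with $Q, K$ of size $hw\times C$ costs $(hw)^2C$. Multiplying the softmax output by $V$ to get $Z$ costs another $(hw)^2C$. Finally, mapping $Z$ to the output requires one more $C\times C$ matrix, costing $hwC^2$. The total is therefore $4hwC^2+2(hw)^2C$.
  • W-MSA computes self-attention separately inside each window of $M\times M$ patches, so what it reduces is the $(hw)^2C$ term. Each window costs $(M^2)^2C$ and there are $(h/M)\times(w/M)$ windows, giving $M^2\times M^2\times (h/M)\times(w/M)\times C=M^2hwC$, for a total of $4hwC^2+2M^2hwC$. The term is still a product of four factors, but one factor of $hw$ has been replaced by the much smaller, fixed $M^2$, so the cost grows linearly in $hw$ rather than quadratically. What is given up is attention between patches in different windows.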
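Plugging numbers into the two formulas makes the difference concrete. This plain-Python sketch counts multiply-adds only and, like the formulas, ignores softmax, the position bias and the mask. The stage-1 setting below (56 x 56 tokens, C = 96, M = 7) is taken from the default configuration shown earlier.

```python
def msa_flops(h, w, C):
    """Multiply-adds of global MSA on an h x w map: 4hwC^2 + 2(hw)^2 C."""
    hw = h * w
    return 4 * hw * C ** 2 + 2 * hw ** 2 * C


def wmsa_flops(h, w, C, M):
    """Multiply-adds of window MSA with M x M windows: 4hwC^2 + 2 M^2 hw C."""
    hw = h * w
    return 4 * hw * C ** 2 + 2 * M ** 2 * hw * C


# Stage-1 setting of the default model: 56 x 56 tokens, C = 96, M = 7
full = msa_flops(56, 56, 96)
win = wmsa_flops(56, 56, 96, 7)
print(f"MSA:   {full / 1e9:.2f} G")
print(f"W-MSA: {win / 1e9:.2f} G  ({full / win:.1f}x fewer)")
```

Doubling $h$ and $w$ multiplies the attention term of MSA by 16 but that of W-MSA only by 4. When the window covers the whole map ($M=h=w$), the two formulas coincide.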