MLP-Mixer:面向视觉的全 MLP 架构

Ilya Tolstikhin, Neil Houlsby, Alexander Kolesnikov, Lucas Beyer,

Xiaohua Zhai, Thomas Unterthiner, Jessica Yung, Andreas Steiner,

Daniel Keysers, Jakob Uszkoreit, Mario Lucic, Alexey Dosovitskiy

equal contribution

Google Research, Brain Team

{tolstikhin, neilhoulsby, akolesnikov, lbeyer,

xzhai, unterthiner, jessicayung, andstein,

keysers, usz, lucic, adosovitskiy}@google.com

work done during Google AI Residency

摘要

卷积神经网络 (CNN) 是计算机视觉的首选模型。 最近,基于注意力的网络,例如 Vision Transformer,也变得流行。 在本文中,我们表明,虽然卷积和注意力都足以获得良好的性能,但它们都不是必需的。 我们推出了 MLP-Mixer,这是一种完全基于多层感知器 (MLP) 的架构。 MLP-Mixer 包含两种类型的层:一种是独立应用于图像块的 MLP(即“混合”每个位置的特征),另一种是跨块应用 MLP(即“混合”空间信息)。 当在大型数据集上或使用现代正则化方案进行训练时,MLP-Mixer 在图像分类基准上获得了有竞争力的分数,其预训练和推理成本可与最先进的模型相媲美。 我们希望这些结果能够激发进一步的研究,超越成熟的 CNN 和 Transformer 领域。111MLP-Mixer 代码将在 https://github.com/google-research/vision_transformer 上提供

1简介

正如计算机视觉的历史所证明的那样,更大的数据集的可用性加上计算能力的增加往往会导致范式转变。 虽然卷积神经网络 (CNN) 已成为计算机视觉事实上的标准,但最近视觉变换器 [14] (ViT)(一种基于自注意力层的替代方案)达到了最先进的水平-艺术表演。 ViT 延续了从模型中去除手工制作的视觉特征和归纳偏差的长期趋势,并进一步依赖于从原始数据中学习。

我们提出了MLP-Mixer架构(或简称“Mixer”),这是一种有竞争力但概念和技术上简单的替代方案,不使用卷积或自注意力。 相反,Mixer 的架构完全基于多层感知器 (MLP),这些感知器在空间位置或特征通道上重复应用。 Mixer 仅依赖于基本的矩阵乘法例程、数据布局的更改(重塑和转置)以及标量非线性。

1描述了Mixer的宏观结构。 它接受一系列线性投影图像块(也称为标记),形状为“块×通道”表作为输入,并维护该维度。 Mixer 使用两种类型的 MLP 层:通道混合 MLP Token 混合 MLP 通道混合 MLP 允许不同通道之间进行通信;它们独立地对每个词符进行操作,并将表中的各个行作为输入。 Token 混合 MLP 允许不同空间位置( Token )之间进行通信;它们独立地在每个通道上运行,并将表的各个列作为输入。 这两种类型的层交错以实现两个输入维度的交互。

在极端情况下,我们的架构可以看作是一个非常特殊的 CNN,它使用 1×1 个卷积进行​​通道混合,以及全通道的单通道深度卷积。 词符混合的感受野和参数共享。 然而,反之则不然,因为典型的 CNN 并不是 Mixer 的特例。 此外,卷积比 MLP 中的普通矩阵乘法更复杂,因为它需要额外昂贵的减少矩阵乘法和/或专门的实现。

尽管很简单,Mixer 却取得了有竞争力的结果。 当在大型数据集(即 100M 图像)上进行预训练时,在准确性/成本权衡方面,它达到了 CNN 和 Transformers 之前声称的接近最先进的性能。 这包括 ILSVRC2012“ImageNet”[13] 上 87.94% 的 top-1 验证准确度。 当对更适度规模的数据(即 1–10M 图像)进行预训练时,结合现代正则化技术 [49, 54],Mixer 也实现了强大的性能。 然而,与 ViT 类似,它在专用 CNN 架构方面略有欠缺。

2 混频器架构

Refer to caption
图1 MLP-Mixer 由每块线性嵌入、混合器层和分类器头组成。 混合器层包含一个 Token 混合 MLP 和一个通道混合 MLP,每个层由两个全连接层和一个 GELU 非线性层组成。 其他组件包括:通道上的跳过连接、丢失和层规范。

现代深度视觉架构由多个层组成,这些层混合了(i)给定空间位置的特征,(ii)不同空间位置之间的特征,或同时混合两者的特征。 在 CNN 中,(ii) 通过 N×N 卷积(对于 N>1)和池化来实现。 较深层的神经元具有较大的感受野[1, 28] 同时,1×1 个卷积也执行 (i),较大的内核同时执行 (i) 和 (ii)。 在 Vision Transformers 和其他基于注意力的架构中,自注意力层允许 (i) 和 (ii) 两者,并且 MLP 块执行 (i)。 Mixer 架构背后的想法是明确区分每个位置(通道混合)操作 (i) 和跨位置( Token 混合)操作 (ii) 。 这两种操作都是通过 MLP 实现的。 1总结了该架构。

Mixer 将一系列 S 不重叠的图像块作为输入,每个图像块都投影到所需的隐藏维度 C。这会产生一个二维实值输入表𝐗S×C 如果原始输入图像的分辨率为(H,W),并且每个补丁的分辨率为(P,P),则补丁的数量为S=HW/P2 所有面片均使用相同投影矩阵进行线性投影。 Mixer由多个相同大小的层组成,每层由两个MLP块组成。 第一个是 Token 混合 MLP:它作用于𝐗的列(即它应用于转置输入表𝐗),映射SS,并且在所有列之间共享。 第二个是通道混合 MLP:它作用于𝐗行,映射CC,并在所有行之间共享。 每个 MLP 块包含两个全连接层和独立应用于其输入数据张量的每一行的非线性。 混合器层可以写成如下(省略层索引):

𝐔,i =𝐗,i+𝐖2σ(𝐖1LayerNorm(𝐗),i),for i=1C, (1)
𝐘j, =𝐔j,+𝐖4σ(𝐖3LayerNorm(𝐔)j,),for j=1S.

这里 σ 是逐元素非线性 (GELU [16])。 DSDC 分别是 Token 混合和通道混合 MLP 中的可调隐藏宽度。 请注意,DS 的选择与输入补丁的数量无关。 因此,网络的计算复杂度与输入补丁的数量成线性关系,这与 ViT 的复杂度是二次的不同。 由于 DC 与 patch 大小无关,因此对于典型的 CNN 来说,整体复杂度与图像中的像素数量呈线性关系。

如上所述,相同通道混合MLP( Token 混合MLP)被应用于𝐗的每一行(列)。绑定通道混合 MLP 的参数(在每层内)是一个自然的选择——它提供了位置不变性,这是卷积的一个显着特征。 然而,跨通道绑定参数的情况要少得多。 例如,某些 CNN 中使用的可分离卷积 [9, 40] 将卷积独立于其他通道应用于每个通道。 然而,在可分离卷积中,不同的卷积内核应用于每个通道,这与 Mixer 中的 Token 混合 MLP 不同,MLP 为所有通道共享相同的内核(具有完整的感受野)。 参数绑定可防止架构在增加隐藏维度 C 或序列长度 S 时增长过快,并显着节省内存。 令人惊讶的是,这种选择并不影响实证表现,请参阅补充A.1

Mixer 中的每个层(除了初始补丁投影层)都采用相同大小的输入。 这种“各向同性”设计与 Transformer 或其他领域中的深度 RNN 最相似,它们也使用固定宽度。 这与大多数具有金字塔结构的 CNN 不同:更深的层具有较低的分辨率输入,但有更多的通道。 请注意,虽然这些是典型的设计,但还存在其他组合,例如各向同性 ResNets [38] 和金字塔 ViTs [52]

除了 MLP 层之外,Mixer 还使用其他标准架构组件:skip-connections [15] 和层归一化 [2] 与 ViT 不同,Mixer 不使用位置嵌入,因为 Token 混合 MLP 对输入 Token 的顺序敏感。 最后,Mixer 使用标准分类头和全局平均池化层,后跟线性分类器。 总体而言,该架构可以用 JAX/Flax 紧凑地编写,代码在补充 E 中给出。

3实验

我们评估了 MLP-Mixer 模型在一系列中小型下游分类任务上的性能,该模型使用中型到大型数据集进行了预训练。 我们对三个主要量感兴趣:(1)下游任务的准确性; (2) 预训练的计算成本,这在上游数据集上从头开始训练模型时很重要; (3) 测试时间吞吐量,这对从业者来说很重要。 我们的目标不是展示最先进的结果,而是要表明,基于 MLP 的简单模型与当今最好的卷积和基于注意力的模型具有竞争力。

下游任务我们使用流行的下游任务,例如ILSVRC2012“ImageNet”(1.3M训练示例,1k类)以及原始验证标签[13]和清理后的ReaL标签[5]、CIFAR-10/100(50k 个示例,10/100 个类别)[23]、Oxford-IIIT Pets(3.7k 个示例、36 个类别)[32] 和 Oxford Flowers-102(2k 个示例,102 个类)[31] 我们还使用视觉任务适应基准 (VTAB-1k),它由 19 个不同的数据集组成,每个数据集都有 1k 个训练示例[58]

预训练我们遵循标准的迁移学习设置:预训练,然后对下游任务进行微调。 我们在两个公共数据集上预训练模型:ILSVRC2021 ImageNet 和 ImageNet-21k,ILSVRC2012 的超集包含 21k 个类和 14M 图像[13] 为了评估更大规模的性能,我们还在 JFT-300M 上进行训练,这是一个包含 300M 示例和 18k 类[44]的专有数据集。 我们按照 Dosovitskiy 等人 [14]、Kolesnikov 等人 [22] 中的下游任务测试集对所有预训练数据集进行去重复。 我们使用 Adam 以 224 分辨率预训练所有模型,其中 β1=0.9β2=0.999、10k 步的线性学习率预热和线性衰减、批量大小 4 096、权重衰减和梯度裁剪处于全球标准 1。 对于 JFT-300M,除了随机水平翻转之外,我们还应用 Szegedy 等人 [45] 的裁剪技术来预处理图像。 对于 ImageNet 和 ImageNet-21k,我们采用了额外的数据增强和正则化技术。 特别是,我们使用 RandAugment [12]、mixup [60]、dropout [43] 和随机深度 [19] 这套技术的灵感来自于timm库[54]Touvron等人[48] 补充B中提供了有关这些超参数的更多详细信息。

微调 我们使用动量 SGD、批量大小 512、全局范数 1 的梯度裁剪以及带有线性预热的余弦学习率计划来调节。 微调时我们不使用权重衰减。 按照常见做法[22, 48],我们还对预训练期间使用的分辨率进行了更高分辨率的调节。 由于我们保持补丁分辨率固定,这会增加输入补丁的数量(例如从 SS),因此需要修改 Mixer 的 Token 混合 MLP 块的形状。 形式上,方程中的输入。 (1) 左乘于权重矩阵𝐖1DS×S,并且在更改输入维度S时必须调整此运算。为此,我们将隐藏层宽度从 DS 增加到 DS,与补丁的数量成比例,并用 a 初始化(现在更大的)权重矩阵 𝐖2DS×S块对角矩阵,其对角线上包含 𝐖2 的副本。 此特定方案仅允许 S=K2SK 一起使用。 有关更多详细信息,请参阅补充C 在 VTAB-1k 基准测试中,我们分别在具有小输入图像和大输入图像的数据集上遵循分辨率为 224 和 448 的 BiT-HyperRule [22] 和调整 Mixer 模型。

指标我们评估模型的计算成本和质量之间的权衡。 对于前者,我们计算两个指标:(1)TPU-v3 加速器上的总预训练时间,它结合了三个相关因素:每个训练设置的理论 FLOP、相关训练硬件的计算效率以及数据效率。 (2) TPU-v3 上的图像/秒/核心吞吐量。 由于不同大小的模型可能受益于不同的批量大小,因此我们扫描批量大小并报告每个模型的最高吞吐量。 对于模型质量,我们重点关注微调后的 top-1 下游精度。 在两次(图3,右图4)中,微调所有模型的成本太高,我们报告了通过求解2-图像和标签的冻结学习表示之间的正则化线性回归问题。

表1 混合器架构的规格。 “B”、“L”和“H”(基础、大型和巨大)模型比例遵循Dosovitskiy 等人[14] 简写“B/16”表示分辨率为16×16的面片的基础比例模型。 报告输入分辨率为 224 的参数数量,不包括分类器头的权重。
Specification S/32 S/16 B/32 B/16 L/32 L/16 H/14
Number of layers 8 8 12 12 24 24 32
Patch resolution P×P 32×32 16×16 32×32 16×16 32×32 16×16 14×14
Hidden size C 512 512 768 768 1024 1024 1280
Sequence length S 49 196 49 196 49 196 256
MLP dimension DC 2048 2048 3072 3072 4096 4096 5120
MLP dimension DS 256 256 384 384 512 512 640
Parameters (M) 19 18 60 59 206 207 431

模型 我们将表 1 中总结的 Mixer 的各种配置与最新、最先进的 CNN 和基于注意力的模型进行了比较。 在所有的图和表中,基于 MLP 的 Mixer 模型都用粉红色标记( ),基于卷积的模型,黄色( ),以及带有蓝色( )。 Vision Transformers (ViTs) 的模型比例和补丁​​分辨率与 Mixer 类似。 HaloNet 是基于注意力的模型,它使用带有局部自注意力层的类似 ResNet 的结构,而不是 3×3 个卷积[51] 我们专注于特别高效的“HaloNet-H4(base 128,Conv-12)”模型,它是更广泛的 HaloNet-H4 架构的混合变体,其中一些自注意力层被卷积取代。 注意,我们用蓝色标记 HaloNets 的注意力和卷积( )。 Big Transfer (BiT) [22] 模型是针对迁移学习而优化的 ResNet。 NFNet [7] 是无归一化器的 ResNet,针对 ImageNet 分类进行了多项优化。 我们考虑 NFNet-F4+ 模型变体。 我们考虑将 MPL [34] 和 ALIGN [21] 用于 EfficientNet 架构。 MPL 在 JFT-300M 图像上进行了大规模预训练,使用 ImageNet 的元伪标签而不是原始标签。 我们与 EfficientNet-B6-Wide 模型变体进行比较。 以对比的方式在嘈杂的网络图像文本对上对齐预训练图像编码器和语言编码器。 我们与他们最好的 EfficientNet-L2 图像编码器进行比较。

表2 传输性能、推理吞吐量和训练成本。 这些行按推理吞吐量排序(第五列)。 Mixer 的传输精度与成本相似的最先进模型相当。 混合器模型在分辨率 448 下进行了微调。 混合器性能数据是三次微调运行的平均值,标准偏差小于 0.1
ImNet ReaL Avg 5 VTAB-1k Throughput TPUv3
top-1 top-1 top-1 19 tasks img/sec/core core-days
Pre-trained on ImageNet-21k (public)
HaloNet [51] 85.85 120 0.10k
Mixer-L/16 84.15 87.86 93.91 74.95 105 0.41k
ViT-L/16 [14] 85.30 88.62 94.39 72.72 532 0.18k
BiT-R152x4 [22] 85.39 94.04 70.64 526 0.94k
Pre-trained on JFT-300M (proprietary)
NFNet-F4+ [7] 89.25 546 1.86k
Mixer-H/14 87.94 90.18 95.71 75.33 540 1.01k
BiT-R152x4 [22] 87.54 90.54 95.33 76.29 526 9.90k
ViT-H/14 [14] 88.55 90.72 95.97 77.63 515 2.30k
Pre-trained on unlabelled or weakly labelled data (proprietary)
MPL [34] 90.05 91.12 20.48k
ALIGN [21] 88.64 79.99 15 14.82k

3.1 主要结果

2 比较了最大的 Mixer 模型与文献中最先进的模型。 “ImNet”和“ReaL”列指的是原始 ImageNet 验证 [13] 和清理后的 ReaL [5] 标签。 “平均。5”代表所有五个下游任务(ImageNet、CIFAR-10、CIFAR-100、Pets、Flowers)的平均性能。 2(左)可视化了精度计算边界。 当在 ImageNet-21k 上进行预训练并进行额外正则化时,Mixer 实现了整体强劲的性能(在 ImageNet 上为 84.15% top-1),尽管略逊于其他模型222 在表2中,我们考虑每个预训练数据集的每个类别中最高精度的模型。 这些都使用大分辨率(448 及以上)。 然而,以较小的分辨率进行微调可以显着提高测试时的吞吐量,而通常只会造成很小的精度损失。 例如,在 ImageNet-21k 上进行预训练时,在 224 分辨率下微调的 Mixer-L/16 模型在吞吐量 420 img/sec/core 下实现了 82.84% ImageNet top-1 准确率; ViT-L/16 模型在 384 分辨率下进行微调,在 80 img/sec/core 时达到 85.15% [14]; HaloNet 在 384 分辨率下进行微调,在 258 img/sec/core [51] 下达到 85.5%。 . 在这种情况下,正则化是必要的,如果没有正则化,Mixer 就会过拟合,这与 ViT [14] 的类似观察结果一致。 当在 ImageNet 上随机初始化训练 Mixer 时,同样的结论成立(参见第 3.2 节):Mixer-B/16 在分辨率 224 下获得了 76.4% 的合理分数,但往往会过度拟合。 该分数与普通 ResNet50 类似,但落后于 ImageNet“从头开始”设置的最先进的 CNN/混合模型,例如84.7% BotNet [42] 和 86.5% NFNet [7]

当上游数据集大小增加时,Mixer 的性能显着提高。 特别是,Mixer-H/14 在 ImageNet 上实现了 87.94% 的 top-1 准确率,比 BiT-ResNet152x4 提高了 0.5%,仅比 ViT-H/14 低 0.5%。 值得注意的是,Mixer-H/14 的运行速度比 ViT-H/14 快 2.5 倍,几乎是 BiT 的两倍。 总体而言,图2(左)支持了我们的主要主张,即在精度计算权衡方面,Mixer 与更传统的神经网络架构相比具有竞争力。 该图还表明,即使跨架构类别,总预训练成本与下游准确性之间也存在明显的相关性。

表中的 BiT-ResNet152x4 使用 SGD 进行预训练,具有动量和较长的时间表。 由于 Adam 往往收敛得更快,我们使用在 JFT 上预训练的 Dosovitskiy 等人 [14] 的 BiT-R200x3 模型完成图 2(左)中的图片使用亚当300M。 该 ResNet 的精度稍低,但预训练计算量却相当低。 最后,该图中还报告了较小的 ViT-L/16 和 Mixer-L/16 模型的结果。

Refer to caption
Refer to caption
图2 左:2 中 SOTA 模型的 ImageNet 准确性/训练成本帕累托前沿(虚线)。 模型在 ImageNet-21k、JFT(有标签或 MPL 的伪标签)或 Web 图像文本对上进行预训练。 Mixer 与这些性能极高的 ResNet、ViT 和混合模型一样好,并且与 HaloNet、ViT、NFNet 和 MPL 处于前沿。 右: 随着数据大小的增长,Mixer(实线)赶上或超过 BiT(点线)和 ViT(虚线)。 曲线上的每个点都使用相同的预训练计算;它们分别对应于 JFT-300M 的 3%、10%、30% 和 100% 的预训练 233、70、23 和 7 个 epoch。 3B 处的其他点对应于在更大的 JFT-3B 数据集上进行相同步数的预训练。 Mixer 在数据方面的改进比 ResNets 甚至 ViT 更快。 大型 Mixer 和 ViT 模型之间的差距缩小。
Refer to caption
图3 模型规模的作用。 ImageNet 验证不同规模的 ViT、BiT 和 Mixer 模型的 top-1 准确率与总预训练计算量()和吞吐量()。 所有模型均在 JFT-300M 上进行预训练,并在分辨率 224 下进行微调,该分辨率低于图2(左)。

3.2模型尺度的作用

上一节中概述的结果重点关注计算范围高端的(大型)模型。 我们现在将注意力转向较小的 Mixer 模型。

我们可以通过两种独立的方式缩放模型:(1)在预训练时增加模型大小(层数、隐藏维度、MLP 宽度); (2)微调时提高输入图像分辨率。 虽然前者影响预训练计算和测试时吞吐量,但后者仅影响吞吐量。 除非另有说明,否则我们会根据第 224 号决议进行调整。

表3 Mixer 和文献中的其他模型在各种模型和预训练数据集规模上的性能。 “平均。5”表示五个下游任务的平均性能。 Mixer 和 ViT 模型在 3 次微调运行中取平均值,标准差小于 0.15 () 根据在 JFT-300M 上预训练的相同模型报告的数字推断,无需额外正则化。 (☎) 数字由Dosovitskiy 等人[14]的作者通过个人交流提供。 行按吞吐量排序。
Image Pre-Train ImNet ReaL Avg. 5 Throughput TPUv3
size Epochs top-1 top-1 top-1 (img/sec/core) core-days
Pre-trained on ImageNet (with extra regularization)
Mixer-B/16 224 300 76.44 82.36 88.33 1384 0.01k(‡)
ViT-B/16 (☎) 224 300 79.67 84.97 90.79 861 0.02k(‡)
Mixer-L/16 224 300 71.76 77.08 87.25 419 0.04k(‡)
ViT-L/16 (☎) 224 300 76.11 80.93 89.66 280 0.05k(‡)
Pre-trained on ImageNet-21k (with extra regularization)
Mixer-B/16 224 300 80.64 85.80 92.50 1384 0.15k(‡)
ViT-B/16 (☎) 224 300 84.59 88.93 94.16 861 0.18k(‡)
Mixer-L/16 224 300 82.89 87.54 93.63 419 0.41k(‡)
ViT-L/16 (☎) 224 300 84.46 88.35 94.49 280 0.55k(‡)
Mixer-L/16 448 300 83.91 87.75 93.86 105 0.41k(‡)
Pre-trained on JFT-300M
Mixer-S/32 224 55 68.70 75.83 87.13 11489 0.01k
Mixer-B/32 224 57 75.53 81.94 90.99 4208 0.05k
Mixer-S/16 224 55 73.83 80.60 89.50 3994 0.03k
BiT-R50x1 224 57 73.69 81.92 2159 0.08k
Mixer-B/16 224 57 80.00 85.56 92.60 1384 0.08k
Mixer-L/32 224 57 80.67 85.62 93.24 1314 0.12k
BiT-R152x1 224 57 79.12 86.12 932 0.14k
BiT-R50x2 224 57 78.92 86.06 890 0.14k
BiT-R152x2 224 14 83.34 88.90 356 0.58k
Mixer-L/16 224 57 84.05 88.14 94.51 419 0.23k
Mixer-L/16 224 14 84.82 88.48 94.77 419 0.45k
ViT-L/16 224 14 85.63 89.16 95.21 280 0.65k
Mixer-H/14 224 14 86.32 89.14 95.49 194 1.01k
BiT-R200x3 224 14 84.73 89.58 141 1.78k
Mixer-L/16 448 14 86.78 89.72 95.13 105 0.45k
ViT-H/14 224 14 86.65 89.56 95.57 87 2.30k
ViT-L/16 [14] 512 14 87.76 90.54 95.63 32 0.65k

我们将 Mixer 的各种配置(参见表 1)与类似规模的 ViT 模型和使用 Adam 预训练的 BiT 模型进行比较。 结果总结在表3和图3中。 当在 ImageNet 上从头开始训练时,Mixer-B/16 达到了 76.44% 的合理 top-1 准确率。 这比 ViT-B/16 模型落后 3%。 训练曲线(未报告)表明,两个模型的训练损失值非常相似。 换句话说,Mixer-B/16 比 ViT-B/16 过拟合更多。 对于 Mixer-L/16 和 ViT-L/16 型号,这种差异更加明显。

随着预训练数据集的增长,Mixer 的性能稳步提高。 值得注意的是,在 JFT-300M 上预训练并在 224 分辨率下进行微调的 Mixer-H/14 在 ImageNet 上仅落后 ViT-H/14 0.3%,同时运行速度提高了 2.2 倍。 3 清楚地表明,尽管 Mixer 略低于模型尺度下端的前沿,但它自信地位于高端的前沿。

3.3预训练数据集大小的作用

迄今为止的结果表明,对较大数据集进行预训练可以显着提高 Mixer 的性能。 在这里,我们更详细地研究这种效应。

为了研究 Mixer 利用越来越多的训练示例的能力,我们在包含 3%、10% 的 JFT-300M 随机子集上预训练 Mixer-B/32、Mixer-L/32 和 Mixer-L/16 模型、233、70、23 和 7 个 epoch 的所有训练示例的 30% 和 100%。 因此,每个模型都经过相同数量的总步数的预训练。 我们还在更大的 JFT-3B 数据集 [59] 上预训练 Mixer-L/16 模型,该数据集包含大约 3B 图像和 30k 类,总步数相同。 虽然不能严格比较,但这使我们能够进一步推断规模的影响。 我们使用 ImageNet 上的线性 5-shot top-1 精度作为传输质量的代理。 对于每次预训练运行,我们都会根据最佳上游验证性能执行提前停止。 结果如图2(右)所示,其中我们还包括 ViT-B/32、ViT-L/32、ViT-L/16 和 BiT-R152x2 模型。

当在 JFT-300M 的最小子集上进行预训练时,所有 Mixer 模型都严重过拟合。 BiT 模型也会过拟合,但程度较轻,可能是由于与卷积相关的强归纳偏差。 随着数据集的增加,Mixer-L/32和Mixer-L/16的性能增长速度都快于BiT; Mixer-L/16 不断改进,而 BiT 模型则趋于稳定。

同样的结论也适用于 ViT,与Dosovitskiy 等人[14]一致。 然而,较大 Mixer 模型的相对改进更为明显。 Mixer-L/16 和 ViT-L/16 之间的性能差距随着数据规模的扩大而缩小。 看来 Mixer 从不断增长的数据集大小中受益甚至比 ViT 还要多。 人们可以通过归纳偏差的差异再次推测和解释它:ViT 中的自注意力层导致学习函数的某些属性与真正的底层分布相比,与 Mixer 架构中发现的属性不太兼容

Refer to caption
Refer to caption
Refer to caption
Refer to caption
图4 顶部: 在排列内容之前输入来自 ImageNet 的示例(左);打乱 16×16 补丁和补丁内的像素后(中心);全局洗牌像素后(右)。 底部: Mixer-B/16(左)和 ResNet50x1(右)使用三个相应的输入管道进行训练。
Refer to caption
Refer to caption
Refer to caption
图5 Mixer-B/ 的第一个()、第二个(中心)和第三个() Token 混合 MLP 中的隐藏单元16 个模型在 JFT-300M 上训练。 每个单元都有 196 权重,每个 14×14 传入补丁都有一个权重。 我们将这些单元配对以突出相反相的内核的出现。 对按过滤频率排序。 与卷积滤波器的内核(其中每个权重对应于输入图像中的一个像素)相反,左列中任何图中的一个权重对应于输入图像的特定 16×16 块。 完成补充D中的绘图。

3.4 输入排列的不变性

在本节中,我们研究 Mixer 和 CNN 架构的归纳偏差之间的差异。 具体来说,我们按照 3 节中描述的预训练设置并使用两种不同的输入转换之一在 JFT-300M 上训练 Mixer-B/16 和 ResNet50x1 模型: (1) 打乱 16 的顺序×16 个补丁,并使用共享排列对每个补丁内的像素进行排列; (2) 对整个图像中的像素进行全局置换。 所有图像都使用相同的排列。 我们在图 4(底部)中报告了 ImageNet 上训练模型的线性 5-shot top-1 准确率。 一些原始图像及其两个转换后的版本显示在图4(顶部)中。 正如所预料的,Mixer 对于补丁和补丁内像素的顺序是不变的(蓝色和绿色曲线完美匹配)。 另一方面,ResNet 的强归纳偏差依赖于图像中特定的像素顺序,当补丁被排列时,其性能会显着下降。 值得注意的是,当全局排列像素时,与 ResNet(下降 75%)相比,Mixer 的性能下降要少得多(下降 45%)。

3.5可视化

人们普遍观察到,CNN 的第一层倾向于学习类似 Gabor 的检测器,作用于图像局部区域的像素。 相比之下,Mixer 允许在 Token 混合 MLP 中进行全局信息交换,这就引出了一个问题:它是否以类似的方式处理信息。 5显示了在JFT-300M上训练的Mixer的前三个 Token 混合MLP的隐藏单元。 回想一下, Token 混合 MLP 允许不同空间位置之间的全局通信。 一些学习到的特征对整个图像起作用,而另一些则对较小的区域起作用。 更深的层似乎没有清晰可辨的结构。 与 CNN 类似,我们观察到许多对具有相反相位的特征检测器[39] 学习单元的结构取决于超参数。 第一个嵌入层的图显示在补充 D 的图 7 中。

4相关工作

MLP-Mixer 是一种新的计算机视觉架构,与以前的成功架构不同,因为它既不使用卷积层也不使用自注意力层。 尽管如此,设计选择可以追溯到 CNN [24, 25] 和 Transformers [50] 文献中的想法。

自从 AlexNet 模型[24]超越了基于手工制作的图像特征[35]的流行方法以来,CNN 已经成为计算机视觉领域事实上的标准。 许多工作都专注于改进 CNN 的设计。 Simonyan 和 Zisserman [41] 证明,人们可以仅使用具有小型 3×3 内核的卷积来训练最先进的模型。 He 等人[15]引入了跳跃连接和批量归一化[20],这使得能够训练非常深的神经网络并进一步提高性能。 一项重要的研究调查了使用稀疏卷积的好处,例如分组 [57] 或深度方向的 [9, 17] 变体。 与我们的 Token 混合 MLP 类似,Wu 等人 [55] 在自然语言处理的深度卷积中共享参数。 Hu 等人[18]Wang 等人[53]提出用非局部操作增强卷积网络,以部分缓解CNN局部处理的约束。 Mixer 将使用小内核卷积的想法发挥到了极致:通过将内核大小减小到 1×1,它将卷积转换为独立应用于每个空间位置的标准密集矩阵乘法(通道混合 MLP)。 仅此一点就不允许聚合空间信息,为了补偿,我们应用密集矩阵乘法,该矩阵乘法应用于所有空间位置的每个特征( Token 混合 MLP)。 在 Mixer 中,矩阵乘法在“patches×features”输入表上按行或按列应用,这也与稀疏卷积的工作密切相关。 Mixer 使用跳过连接 [15] 和标准化层 [2, 20]

在计算机视觉领域,基于自注意力的 Transformer 架构最初应用于生成建模[8, 33] 它们在图像识别方面的价值后来得到了证明,尽管是与类似卷积的局部性偏差 [37] 相结合,或者在低分辨率图像 [10] 上。 Dosovitskiy 等人[14]介绍了 ViT,这是一种纯 Transformer 模型,具有较少的局部性偏差,但可以很好地扩展到大数据。 ViT 在流行的视觉基准上实现了最先进的性能,同时保留了 CNN [6] 的稳健性。 Touvron 等人[49]使用广泛的正则化在较小的数据集上有效地训练了 ViT。 Mixer 借鉴了最新基于 Transformer 的架构的设计选择。 Mixer 的 MLP 块的设计源自 Vaswani 等人[50] 将图像转换为补丁序列并直接处理这些补丁的嵌入源自Dosovitskiy等人[14]

最近的许多作品都致力于设计更有效的视觉架构。 Srinivas 等人[42]用自注意力层替换ResNets中的3×3个卷积。 Ramachandran 等人 [37]Tay 等人 [47]Li 等人 [26]Bello [3] 设计具有新的类似注意力机制的网络。 Mixer 可以被视为正交方向的一步,不依赖局部性偏差和注意力机制。

Lin 等人[27]的工作与之密切相关。 它使用完全连接的网络、大量数据增强以及使用自动编码器进行预训练,在 CIFAR-10 上获得了合理的性能。 Neyshabur [30] 设计了自定义正则化和优化算法并训练全连接网络,在小规模任务上获得了令人印象深刻的性能。 相反,我们依靠词符和通道混合 MLP,使用标准正则化和优化技术,并有效地扩展到大数据。

传统上,在 ImageNet [13] 上评估的网络是使用 Inception 式预处理 [46] 从随机初始化进行训练的。 对于较小的数据集,ImageNet 模型的传输很流行。 然而,现代最先进的模型通常使用在较大数据集上预先训练的权重,或者使用更新的数据增强和训练策略。 例如,Dosovitskiy 等人[14]、Kolesnikov 等人[22]、Mahajan 等人[29]、Pham 等人[34]、Xie 等人[56] 全部提前状态-使用大规模预训练进行图像分类的最先进技术。 由于增强或正则化变化而带来的改进示例包括 Cubuk 等人 [11],他们通过学习数据增强获得了出色的分类性能,以及 Bello 等人 [4],他们表明,如果使用最新的训练和增强策略,规范的 ResNet 仍然接近最先进的水平。

5结论

我们描述了一个非常简单的视觉架构。 我们的实验表明,在训练和推理所需的准确性和计算资源之间的权衡方面,它与现有最先进的方法一样好。 我们相信这些结果提出了许多问题。 在实践方面,研究模型学习的特征并确定与 CNN 和 Transformer 学习的特征的主要差异(如果有)可能会很有用。 在理论方面,我们希望了解隐藏在这些不同特征中的归纳偏差以及它们最终在泛化中的作用。 最重要的是,我们希望我们的结果能够激发进一步的研究,超越基于卷积和自注意力的已建立模型的领域。 看看这样的设计是否适用于 NLP 或其他领域将会特别有趣。

致谢和资金披露

这项工作是在柏林和苏黎世的 Brain 团队中进行的。 我们感谢 Josip Djolonga 对本文初始版本的反馈; Preetum Nakkiran 提议在具有打乱像素的输入图像上训练 MLP-Mixer; Olivier Bousquet、Yann Dauphin 和 Dirk Weissenborn 进行了有益的讨论。

参考

  • Araujo et al. [2019] A. Araujo, W. Norris, and J. Sim. Computing receptive fields of convolutional neural networks. Distill, 2019. doi: 10.23915/distill.00021. URL https://distill.pub/2019/computing-receptive-fields.
  • Ba et al. [2016] J. L. Ba, J. R. Kiros, and G. E. Hinton. Layer normalization. arXiv preprint arXiv:1607.06450, 2016.
  • Bello [2021] I. Bello. LambdaNetworks: Modeling long-range interactions without attention. arXiv preprint arXiv:2102.08602, 2021.
  • Bello et al. [2021] I. Bello, W. Fedus, X. Du, E. D. Cubuk, A. Srinivas, T.-Y. Lin, J. Shlens, and B. Zoph. Revisiting ResNets: Improved training and scaling strategies. arXiv preprint arXiv:2103.07579, 2021.
  • Beyer et al. [2020] L. Beyer, O. J. Hénaff, A. Kolesnikov, X. Zhai, and A. van den Oord. Are we done with ImageNet? arXiv preprint arXiv:2006.07159, 2020.
  • Bhojanapalli et al. [2021] S. Bhojanapalli, A. Chakrabarti, D. Glasner, D. Li, T. Unterthiner, and A. Veit. Understanding robustness of transformers for image classification. arXiv preprint arXiv:2103.14586, 2021.
  • Brock et al. [2021] A. Brock, S. De, S. L. Smith, and K. Simonyan. High-performance large-scale image recognition without normalization. arXiv preprint arXiv:2102.06171, 2021.
  • Child et al. [2019] R. Child, S. Gray, A. Radford, and I. Sutskever. Generating long sequences with sparse transformers. arXiv preprint arXiv:1904.10509, 2019.
  • Chollet [2017] F. Chollet. Xception: Deep learning with depthwise separable convolutions. In CVPR, 2017.
  • Cordonnier et al. [2020] J.-B. Cordonnier, A. Loukas, and M. Jaggi. On the relationship between self-attention and convolutional layers. In ICLR, 2020.
  • Cubuk et al. [2019] E. D. Cubuk, B. Zoph, D. Mane, V. Vasudevan, and Q. V. Le. AutoAugment: Learning augmentation policies from data. In CVPR, 2019.
  • Cubuk et al. [2020] E. D. Cubuk, B. Zoph, J. Shlens, and Q. V. Le. RandAugment: Practical automated data augmentation with a reduced search space. In CVPR Workshops, 2020.
  • Deng et al. [2009] J. Deng, W. Dong, R. Socher, L. Li, Kai Li, and Li Fei-Fei. ImageNet: A large-scale hierarchical image database. In CVPR, 2009.
  • Dosovitskiy et al. [2021] A. Dosovitskiy, L. Beyer, A. Kolesnikov, D. Weissenborn, X. Zhai, T. Unterthiner, M. Dehghani, M. Minderer, G. Heigold, S. Gelly, J. Uszkoreit, and N. Houlsby. An image is worth 16x16 words: Transformers for image recognition at scale. In ICLR, 2021.
  • He et al. [2016] K. He, X. Zhang, S. Ren, and J. Sun. Deep residual learning for image recognition. In CVPR, 2016.
  • Hendrycks and Gimpel [2016] D. Hendrycks and K. Gimpel. Gaussian error linear units (GELUs). arXiv preprint arXiv:1606.08415, 2016.
  • Howard et al. [2017] A. G. Howard, M. Zhu, B. Chen, D. Kalenichenko, W. Wang, T. Weyand, M. Andreetto, and H. Adam. Mobilenets: Efficient convolutional neural networks for mobile vision applications. arXiv preprint arXiv:1704.04861, 2017.
  • Hu et al. [2018] J. Hu, L. Shen, and G. Sun. Squeeze-and-excitation networks. In CVPR, 2018.
  • Huang et al. [2016] G. Huang, Y. Sun, Z. Liu, D. Sedra, and K. Q. Weinberger. Deep networks with stochastic depth. In ECCV, 2016.
  • Ioffe and Szegedy [2015] S. Ioffe and C. Szegedy. Batch normalization: Accelerating deep network training by reducing internal covariate shift. In ICML, 2015.
  • Jia et al. [2021] C. Jia, Y. Yang, Y. Xia, Y.-T. Chen, Z. Parekh, H. Pham, Q. V. Le, Y. Sung, Z. Li, and T. Duerig. Scaling up visual and vision-language representation learning with noisy text supervision. arXiv preprint arXiv:2102.05918, 2021.
  • Kolesnikov et al. [2020] A. Kolesnikov, L. Beyer, X. Zhai, J. Puigcerver, J. Yung, S. Gelly, and N. Houlsby. Big transfer (BiT): General visual representation learning. In ECCV, 2020.
  • Krizhevsky [2009] A. Krizhevsky. Learning multiple layers of features from tiny images. Technical report, University of Toronto, 2009.
  • Krizhevsky et al. [2012] A. Krizhevsky, I. Sutskever, and G. E. Hinton. ImageNet classification with deep convolutional neural networks. In NeurIPS, 2012.
  • LeCun et al. [1989] Y. LeCun, B. Boser, J. Denker, D. Henderson, R. Howard, W. Hubbard, and L. Jackel. Backpropagation applied to handwritten zip code recognition. Neural Computation, 1:541–551, 1989.
  • Li et al. [2021] D. Li, J. Hu, C. Wang, X. Li, Q. She, L. Zhu, T. Zhang, and Q. Chen. Involution: Inverting the inherence of convolution for visual recognition. CVPR, 2021.
  • Lin et al. [2016] Z. Lin, R. Memisevic, and K. Konda. How far can we go without convolution: Improving fullyconnected networks. In ICLR, Workshop Track, 2016.
  • Luo et al. [2016] W. Luo, Y. Li, R. Urtasun, and R. Zemel. Understanding the effective receptive field in deep convolutional neural networks. In NeurIPS, 2016.
  • Mahajan et al. [2018] D. Mahajan, R. Girshick, V. Ramanathan, K. He, M. Paluri, Y. Li, A. Bharambe, and L. van der Maaten. Exploring the limits of weakly supervised pretraining. In ECCV, 2018.
  • Neyshabur [2020] B. Neyshabur. Towards learning convolutions from scratch. In NeurIPS, 2020.
  • Nilsback and Zisserman [2008] M. Nilsback and A. Zisserman. Automated flower classification over a large number of classes. In ICVGIP, 2008.
  • Parkhi et al. [2012] O. M. Parkhi, A. Vedaldi, A. Zisserman, and C. V. Jawahar. Cats and dogs. In CVPR, 2012.
  • Parmar et al. [2018] N. Parmar, A. Vaswani, J. Uszkoreit, L. Kaiser, N. Shazeer, A. Ku, and D. Tran. Image transformer. In ICML, 2018.
  • Pham et al. [2021] H. Pham, Z. Dai, Q. Xie, M.-T. Luong, and Q. V. Le. Meta pseudo labels. In CVPR, 2021.
  • Pinz [2006] A. Pinz. Object categorization. Foundations and Trends in Computer Graphics and Vision, 1(4), 2006.
  • Polyak and Juditsky [1992] B. T. Polyak and A. B. Juditsky. Acceleration of stochastic approximation by averaging. SIAM Journal on Control and Optimization, 30(4):838–855, 1992.
  • Ramachandran et al. [2019] P. Ramachandran, N. Parmar, A. Vaswani, I. Bello, A. Levskaya, and J. Shlens. Stand-alone self-attention in vision models. In NeurIPS, 2019.
  • Sandler et al. [2019] M. Sandler, J. Baccash, A. Zhmoginov, and Howard. Non-discriminative data or weak model? On the relative importance of data and model resolution. In ICCV Workshop on Real-World Recognition from Low-Quality Images and Videos, 2019.
  • Shang et al. [2016] W. Shang, K. Sohn, D. Almeida, and H. Lee. Understanding and improving convolutional neural networks via concatenated rectified linear units. In ICML, 2016.
  • Sifre [2014] L. Sifre. Rigid-Motion Scattering For Image Classification. PhD thesis, Ecole Polytechnique, 2014.
  • Simonyan and Zisserman [2015] K. Simonyan and A. Zisserman. Very deep convolutional networks for large-scale image recognition. In ICLR, 2015.
  • Srinivas et al. [2021] A. Srinivas, T.-Y. Lin, N. Parmar, J. Shlens, P. Abbeel, and A. Vaswani. Bottleneck transformers for visual recognition. arXiv preprint arXiv:2101.11605, 2021.
  • Srivastava et al. [2014] N. Srivastava, G. Hinton, A. Krizhevsky, I. Sutskever, and R. Salakhutdinov. Dropout: A simple way to prevent neural networks from overfitting. JMLR, 15(56), 2014.
  • Sun et al. [2017] C. Sun, A. Shrivastava, S. Singh, and A. Gupta. Revisiting unreasonable effectiveness of data in deep learning era. In ICCV, 2017.
  • Szegedy et al. [2015] C. Szegedy, W. Liu, Y. Jia, P. Sermanet, S. Reed, D. Anguelov, D. Erhan, V. Vanhoucke, and A. Rabinovich. Going deeper with convolutions. In CVPR, 2015.
  • Szegedy et al. [2016] C. Szegedy, V. Vanhoucke, S. Ioffe, J. Shlens, and Z. Wojna. Rethinking the inception architecture for computer vision. In CVPR, 2016.
  • Tay et al. [2020] Y. Tay, D. Bahri, D. Metzler, D.-C. Juan, Z. Zhao, and C. Zheng. Synthesizer: Rethinking self-attention in transformer models. arXiv, 2020.
  • Touvron et al. [2019] H. Touvron, A. Vedaldi, M. Douze, and H. Jegou. Fixing the train-test resolution discrepancy. In NeurIPS, 2019.
  • Touvron et al. [2020] H. Touvron, M. Cord, M. Douze, F. Massa, A. Sablayrolles, and H. Jégou. Training data-efficient image transformers & distillation through attention. arXiv preprint arXiv:2012.12877, 2020.
  • Vaswani et al. [2017] A. Vaswani, N. Shazeer, N. Parmar, J. Uszkoreit, L. Jones, A. N. Gomez, Ł. Kaiser, and I. Polosukhin. Attention is all you need. In NeurIPS, 2017.
  • Vaswani et al. [2021] A. Vaswani, P. Ramachandran, A. Srinivas, N. Parmar, B. Hechtman, and J. Shlens. Scaling local self-attention for parameter efficient visual backbones. arXiv preprint arXiv:2103.12731, 2021.
  • Wang et al. [2021] W. Wang, E. Xie, X. Li, D.-P. Fan, K. Song, D. Liang, T. Lu, P. Luo, and L. Shao. Pyramid vision transformer: A versatile backbone for dense prediction without convolutions. arXiv preprint arXiv:2102.12122, 2021.
  • Wang et al. [2018] X. Wang, R. Girshick, A. Gupta, and K. He. Non-local neural networks. In CVPR, 2018.
  • Wightman [2019] R. Wightman. Pytorch image models. https://github.com/rwightman/pytorch-image-models, 2019.
  • Wu et al. [2019] F. Wu, A. Fan, A. Baevski, Y. Dauphin, and M. Auli. Pay less attention with lightweight and dynamic convolutions. In ICLR, 2019.
  • Xie et al. [2020] Q. Xie, M.-T. Luong, E. Hovy, and Q. V. Le. Self-training with noisy student improves imagenet classification. In CVPR, 2020.
  • Xie et al. [2016] S. Xie, R. Girshick, P. Dollár, Z. Tu, and K. He. Aggregated residual transformations for deep neural networks. arXiv preprint arXiv:1611.05431, 2016.
  • Zhai et al. [2019] X. Zhai, J. Puigcerver, A. Kolesnikov, P. Ruyssen, C. Riquelme, M. Lucic, J. Djolonga, A. S. Pinto, M. Neumann, A. Dosovitskiy, et al. A large-scale study of representation learning with the visual task adaptation benchmark. arXiv preprint arXiv:1910.04867, 2019.
  • Zhai et al. [2021] X. Zhai, A. Kolesnikov, N. Houlsby, and L. Beyer. Scaling vision transformers. arXiv preprint arXiv:2106.04560, 2021.
  • Zhang et al. [2018] H. Zhang, M. Cisse, Y. N. Dauphin, and D. Lopez-Paz. mixup: Beyond empirical risk minimization. In ICLR, 2018.

附录A没有帮助的事情

A.1 修改 Token 混合MLP

我们放弃了许多尝试改进在 JFT-300M 上预训练的各种规模 Mixer 模型的 Token 混合 MLP 的想法。

解除(不共享)参数的束缚

Mixer 层中的 Token 混合 MLP 在输入表 𝐗S×C 的列之间共享。 换句话说,相同的 MLP 应用于每个C不同的特征。 相反,我们可以引入具有独立权重的 C 单独 MLP,有效地将参数数量乘以 C。我们没有观察到任何明显的改进。

将通道分组在一起

Token 混合 MLP 将 S 维向量作为输入。 每个这样的向量都包含 S 不同空间位置的单个特征的值。 换句话说, Token 混合 MLP 的运行方式是一次仅查看一个通道 人们可以通过连接 𝐗S×C 中的 G 相邻列将通道分组在一起,将其重塑为维度 (SG)×(C/G) 的矩阵。 这将 MLP 的输入维度从 S 增加到 GS,并将要处理的向量数量从 C 减少到 C/G 现在,MLP 在混合 Token 时同时查看多个通道 这种列向量的串联将 ImageNet 上的线性 5-shot top-1 精度提高了不到 1-2%。

我们尝试了不同的版本,将上面描述的简单重塑替换为以下内容: (1) 引入 G 线性函数(具有可训练参数),将 C 投影到 C/G (2) 使用它们,将 𝐗S×C 中的每个 S 行(标记)映射到 G 不同的 (C/G) 维向量。 这导致每个词符的G不同“视图”,每个词符都包含C/G特征。 (3) 最后,为每个C/G特征连接对应于G不同视图的向量。 这会产生维度为 (SG)×(C/G) 的矩阵。 这个想法是,在混合 Token 时,MLP 可以查看原始通道的G不同视图 该版本将 Mixer-S/32 架构的 top-5 ImageNet 准确率提高了 3-4%,但是对于更大的规模没有任何改进。

金字塔

Mixer 中的所有层都保留相同的各向同性设计。 ViT 架构的最新改进表明这可能并不理想[52] 我们尝试使用 Token 混合 MLP 通过从 S 输入 Token 映射到 S<S 输出 Token 来减少 Token 数量。 虽然第一个实验表明,在 JFT-300M 上,此类模型显着减少了训练时间,而不会损失太多性能,但我们无法将这些发现转移到 ImageNet 或 ImageNet-21k。 然而,由于金字塔是一种流行的设计,因此探索这种设计用于其他视觉任务可能仍然有希望。

A.2微调

遵循 BiT [22] 和 ViT [14] 的想法,我们还尝试使用 mixup [60] 和 Polyak 平均 [36 ] 微调期间。 然而,这些并没有带来持续的改进,所以我们放弃了它们。 我们还在微调期间尝试使用初始裁剪 [45],但这也没有带来任何改进。 我们对所有规模的 JFT-300M 预训练混合器模型进行了这些实验。

附录B预训练:超参数、数据增强和正则化

在表4中,我们描述了用于预训练 Mixer 模型的最佳超参数设置。

对于 ImageNet 和 ImageNet-21k 的预训练,我们使用了额外的增强和正则化。 对于 RandAugment [12],我们始终使用两个增强层和扫描幅度 m,参数集 {0,10,15,20} 中。 对于混合 [60],我们在集合 {0.0,0.2,0.5,0.8} 中扫描混合强度 p 对于 dropout [43],我们尝试降低 0.00.1d 率。 对于随机深度,按照原始论文[19],我们将从0.0(对于第一个MLP)丢弃一层的概率线性增加到s (对于最后一个 MLP),我们尝试 s{0.0,0.1} 最后,我们分别从 {0.003,0.001}{0.1,0.01} 中扫描学习率 lr 和权重衰减 wd

表4 用于预训练 Mixer 模型的超参数设置。
Model Dataset Epochs lr wd RandAug. Mixup Dropout Stoch. depth
Mixer-B ImNet 300 0.001 0.1 15 0.5 0.0 0.1
Mixer-L ImNet 300 0.001 0.1 15 0.5 0.0 0.1
Mixer-B ImNet-21k 300 0.001 0.1 10 0.2 0.0 0.1
Mixer-L ImNet-21k 300 0.001 0.1 20 0.5 0.0 0.1
Mixer-S JFT-300M 5 0.003 0.03
Mixer-B JFT-300M 7 0.003 0.03
Mixer-L JFT-300M 7/14 0.001 0.03
Mixer-H JFT-300M 14 0.001 0.03

附录C微调:超参数和更高的图像分辨率

除非另有说明,模型均在分辨率 224 下进行微调。 我们遵循[14]的设置。 唯一的区别是: (1) 我们从网格搜索中排除 lr=0.001,而是包含 CIFAR-10、CIFAR-100、花卉和宠物的 lr=0.06 (2) 我们在 lr{0.003,0.01,0.03} 上对 VTAB-1k 进行网格搜索。 (3) 在评估过程中,我们尝试了两种不同的预处理方法:(i) "调整大小-裁剪":首先将图像调整为 256×256 像素,然后进行 224×224 像素大小的中心裁剪。 (ii) “resmall-crop”:首先将图像短边的大小调整为 256 像素,然后采用 224×224 像素大小的中央裁剪。 对于正文表 3 中报告的 Mixer 和 ViT 模型,我们在 ImageNet、Pets、Flowers、CIFAR-10 和 CIFAR-100 上使用 (ii)。 我们对正文表 3 中报告的 BiT 模型使用了相同的设置,唯一的例外是在 ImageNet 上使用 (i)。 对于正文表 2 中报告的 Mixer 模型,我们对所有 5 个下游数据集使用 (i)。

事实证明,以比预训练时使用的分辨率更高的分辨率进行微调可以显着提高现有视觉模型的传输性能[48,22,14] 因此,我们也将这种技术应用于 Mixer。 当向模型提供更高分辨率的图像时,我们不会更改补丁大小,这会导致标记序列更长。 必须调整 Token 混合 MLP 以处理这些较长的序列。 我们尝试了多种选择,并在下面描述了最成功的一种。

为简单起见,我们假设图像分辨率增加了整数倍K。词符序列的长度S增加了K2倍。 我们还将 Token 混合 MLP 的隐藏宽度 DS 增加了 K2 倍。 现在我们需要使用预训练 MLP 的参数来初始化这个新的(更大的)MLP 的参数。 为此,我们将输入序列分成 K2 相等的部分,每个部分的原始长度为 S,并初始化新的 MLP,以便它与预训练的 MLP。

形式上,式(1)中原始MLP的预训练权重矩阵𝐖1DS×S。正文的 1 现在将替换为更大的矩阵 𝐖1(K2DS)×(K2S) 假设调整大小后的输入图像的词符序列是每个长度为 SK2 词符序列的串联,通过将输入分割为 K×K 相等的部分来计算空间上。 然后,我们使用块对角矩阵初始化 𝐖1,该矩阵在其主对角线上具有 𝐖1 的副本。 MLP 的其他参数的处理类似。

附录 D 权重可视化

为了更好的可视化,我们根据尝试首先显示低频滤波器的启发式对所有隐藏单元进行排序。 对于每个单位,我们还尝试识别最接近其逆的单位。 6显示了每个单位及其最接近的逆元。 请注意,在 ImageNet 和 ImageNet-21k 上预训练的模型使用了大量数据增强。 我们发现这强烈影响学习单元的结构。

Refer to caption
图6 在三个不同数据集()上训练的 Mixer-B/16 模型的前两个 Token 混合 MLP()中所有隐藏密集单元的权重。 每个单元都有 14×14=196 权重,即传入 Token 的数量,并被描绘为 14×14 图像。 每个块中总共有 384 个隐藏单元。

我们还在图7中可视化了不同模型学习的嵌入层中的线性投影单元。 有趣的是,它们的属性似乎很大程度上取决于模型使用的补丁分辨率。 在所有 Mixer 模型尺度中,使用更高分辨率 32×32 的补丁会导致类似 Gabor 的低频线性投影单元,而对于 16×16 分辨率,单元则不会出现这种情况结构。

Refer to caption
Refer to caption
图7 在 JFT-300M 上预训练的 Mixer-B/16()和 Mixer-B/32()模型的嵌入层的线性投影单元。 使用更高分辨率的补丁32×32的Mixer-B/32模型学习非常结构化的低频投影单元,而Mixer-B/16学习的大多数单元具有高频并且没有清晰的结构。

附录EMLP-Mixer代码

1import einops
2import flax.linen as nn
3import jax.numpy as jnp
4
5class MlpBlock(nn.Module):
6 mlp_dim: int
7 @nn.compact
8 def __call__(self, x):
9 y = nn.Dense(self.mlp_dim)(x)
10 y = nn.gelu(y)
11 return nn.Dense(x.shape[-1])(y)
12
13class MixerBlock(nn.Module):
14 tokens_mlp_dim: int
15 channels_mlp_dim: int
16 @nn.compact
17 def __call__(self, x):
18 y = nn.LayerNorm()(x)
19 y = jnp.swapaxes(y, 1, 2)
20 y = MlpBlock(self.tokens_mlp_dim, name=’token_mixing’)(y)
21 y = jnp.swapaxes(y, 1, 2)
22 x = x+y
23 y = nn.LayerNorm()(x)
24 return x+MlpBlock(self.channels_mlp_dim, name=’channel_mixing’)(y)
25
26class MlpMixer(nn.Module):
27 num_classes: int
28 num_blocks: int
29 patch_size: int
30 hidden_dim: int
31 tokens_mlp_dim: int
32 channels_mlp_dim: int
33 @nn.compact
34 def __call__(self, x):
35 s = self.patch_size
36 x = nn.Conv(self.hidden_dim, (s,s), strides=(s,s), name=’stem’)(x)
37 x = einops.rearrange(x, ’n h w c -> n (h w) c’)
38 for _ in range(self.num_blocks):
39 x = MixerBlock(self.tokens_mlp_dim, self.channels_mlp_dim)(x)
40 x = nn.LayerNorm(name=’pre_head_layer_norm’)(x)
41 x = jnp.mean(x, axis=1)
42 return nn.Dense(self.num_classes, name=’head’,
43 kernel_init=nn.initializers.zeros)(x)
清单 1: 用 JAX/Flax 编写的 MLP-Mixer 代码。