极市导读
使用知识蒸馏策略,只训练线性注意模块 50K 步,LinFusion 的性能即可与原始 SD 相当甚至更好,同时显著降低了时间和显存占用的复杂度。同时,它还可以实现令人满意的交叉分辨率生成性能,并且可以单卡生成 16K 分辨率的高清大图。 >>加入极市CV技术交流群,走在计算机视觉的最前沿
本文目录
1 1 块 GPU,1 分钟生成 16K 高清大图
(来自 NUS)
1 LinFusion 论文解读
1.1 Self-Attention 的二次计算复杂度问题
1.2 基线模型 Stable Diffusion 和 Mamba
1.3 LinFusion 方法概述及其优势
1.4 Normalization-Aware Mamba
1.5 Non-Causal Mamba
1.6 训练目标
1.7 与 SD 组件的兼容性
太长不看版
现代主流的文生图扩散模型,尤其是基于 Transformer 的 UNet 进行去噪的模型,比较依赖于 Self-Attention 操作,而且也能够实现逼真的生成能力。但是,Self-Attention 操作在生成高分辨率视觉内容方面面临着重大挑战,因为 Self-Attention 操作相对于 token 数量呈二次方的时间和显存复杂度是一个老生常谈的问题。
本文就聚焦于这个问题,作者希望设计新颖的 Linear Attention 机制作为 Self-Attention 的替代方案。那么提起线性计算复杂度的模型,最近就有很多比如 Mamba2、RWKV6、Gated Linear Attention 等等,作者从这些架构开始探索。
除此之外,作者还提出了 2 个关键特征,即 Attention Normalization 和 Non-Causal Inference,这 2 个特征可以增强高分辨率视觉生成模型的性能。在这些发现的基础上,作者引入了一种 Generalized Linear Attention 范式。为了节约训练成本并且更好地利用预训练模型,作者从预训练的 Stable Diffusion (SD) 初始化文本模型并通过蒸馏来提取知识。作者发现,适当训练之后,本文的 LinFusion 的蒸馏的模型的性能与原始 SD 模型相当甚至更好,同时显著降低了时间和显存复杂度。作者在 SD-v1.5,SD-v2.1 和 SD-XL 上做了实验,LinFusion 可以提供令人满意的图像生成性能,比如可以生成 16K 分辨率高清大图。此外,LinFusion 与预训练的 SD 组件 (比如 ControlNet 和 IP-Adapter) 高度兼容。
1 1 块 GPU,1 分钟生成 16K 高清大图
论文名称:LinFusion: 1 GPU, 1 Minute, 16K Image
论文地址:
http://arxiv.org/pdf/2409.02097
代码链接:
http://github.com/Huage001/LinFusion
1.1 Self-Attention 的二次计算复杂度问题
扩散模型能成功的一个很重要的原因可以归结为其强大的 Backbone。从带有 Self-Attention 的 U-Net 架构[1][2]到视觉 Transformer[3][4][5][6],现有的设计在很大程度上依赖于 Self-Attention 模块来管理 token 之间的复杂的空间位置关系。尽管性能强大,但是 Self-Attention 操作中固有的二次时间和内存复杂度也为高分辨率视觉生成带来了重大的挑战。如图 2(a) 所示,使用 FP16 精度,由于内存不足的问题,SD-v1.5 在 A100 上无法生成 2048 分辨率的图像。要知道这已经是一个 80GB 内存的 GPU 了,这个问题使得更高的分辨率或更大的模型更困难了。
为了解决这些问题,本文目标是一种新颖的 token mixing 机制,其与 token 的数量呈线性关系,为经典的 Self-Attention 机制提供了一种替代方案。受最近引入的线性复杂度模型的启发,例如 Mamba[7]和 Mamba2[8],这些模型在顺序生成任务中很有潜力,因此作者首先研究了它们在扩散模型中作为 token mixer 的适用性。
但是,Mamba 扩散模型有 2 个缺点。其一,当扩散模型推理的分辨率与训练时不同的时候,理论分析表明其特征的分布会趋向偏移,导致交叉分辨率推理困难。其二,扩散模型执行去噪任务而不是自回归任务,允许模型同时访问所有 token 并根据整个输入生成新的 token。相比之下,Mamba 本质上是一个按顺序处理标记的 RNN,这意味着生成新的 token 时只能够以前面的 token 作为条件,这是一个 Causal Restriction。将 Mamba 直接应用于扩散模型会对去噪过程施加这种不必要的因果限制,效果反而适得其反。尽管双向扫描分支可以在一定程度上缓解这个问题,但这个问题依然会不可避免地存在于每个分支中。
针对 Diffusion Mamba 的上述缺点,本文提出了一种 Generalized Linear Attention。首先,为了解决相对较低的训练分辨率与相对更高的推理分辨率之间的分布偏移,作者设计了一个**归一化器 (Normalizer)**,由所有 token 对当前 token 的累积影响定义。其次,本文的目标是 Mamba 的 Non-Causal 版本。一开始作者简单删除了应用于遗忘门的下三角 Causal Mask 开始本文探索,并发现所有标记最终都会具有相同的隐藏状态,这破坏了模型的性能。为了解决这个问题,作者为不同的令牌引入了不同的遗忘门组,并提出了一种高效的低秩近似,使模型能够以线性注意形式优雅地实现。作者还从技术上分析了所提出的方法,以及最近的一些线性复杂度 token mixer,例如 Mamba2[8],RWKV6[9],Gated Linear Attention[10],并揭示了我们的模型可以被视为这些流行模型的广义 Non-Causal 版本。
作者将所提出的 Generalized Linear Attention 模块集成到 SD 的架构中,替换原始的 Self-Attention 模块,生成的模型称为 LinFusion。使用知识蒸馏策略,只训练线性注意模块 50K 步,LinFusion 的性能即可与原始 SD 相当甚至更好,同时显著降低了时间和显存占用的复杂度,如图 2 所示。同时,它还可以实现令人满意的交叉分辨率生成性能,并且可以单卡生成 16K 分辨率的高清大图。
1.2 基线模型 Stable Diffusion 和 Mamba
Stable Diffusion
Stable Diffusion (SD) 作为文生图的一种流行的模型, 首先学习一个 Auto-Encoder , 其中编码器 将图像 映射到低维的 latent space 中 , 解码器 将 解码回图像空间 , 使 接近原始图像 。
在推理时, latent space 中的高斯噪声 被 UNet 随机采样并去噪 步。解码器 对最后一步 去噪的 latent code 进行解码, 得到生成的图像。
在训练期间, 给定图像 及其对应的文本描述 用于获得其对应的 latent code, 为其第 步的噪声版本 添加一个随机高斯噪声 。UNet 通过噪声预测损失 进行训练:
Mamba
Mamba 是一种用来替代 Transformer 的,具有线性复杂度的神经网络模型,其核心是状态空间模型 (State Space Model, SSM),可以写成:
其中 是序列中当前 token 的 index, 表示 hidden state, 和 分别是表示输入和输出矩阵的第 行的行向量, 分别是 input-dependent 的变量, 表示逐元素乘法。
在最新版本中, 即 Mamba2, 是一个标量 , 。根据状态空间的对偶性 (State-Space Duality, SSD), 上式 2 中的计算可以重新表述为以下表达式, 称为 1-半可分离结构化掩码注意力:
其中, 是一个 下三角矩阵 , for 。这样的矩阵 被称为 1-半可分离矩阵, 确保 Mamba2 可以线性复杂度实现。
1.3 LinFusion 方法概述及其优势
本文研究对象是 Diffusion 模型的 Backbone,用于文生图模型的一般问题,如图 3 所示。而且希望最终得到的模型具有线性复杂度。所以奔着这个目标,作者并没有重新训练模型,而是从预训练的 SD 模型初始化和蒸馏模型。具体来讲,默认使用 SD-v1.5 模型,并将其中具有二次计算复杂度的 Self-Attention 模块替换为本文的 LinFusion 模块。只有这些模块的参数是可训练的,其余参数全部固定住。将原始的 SD 模型的知识提炼成 LinFusion,这样给定相同的输入,使得其输出尽可能接近。
这种方法提供了 2 个优点:
训练难度和计算开销显著降低,因为学生模型只需要学习空间关系,而无需添加处理文本图像对齐等其他方面的复杂性。 最终得到的模型与在原始 SD 模型的现有组件高度兼容,因为其实只是用了 LinFusion 模块替换 Self-Attention 模块,这样就可以在保持整体架构的同时在功能上与原始层相似。
从技术上讲,为了得到具有线性计算复杂度的 Diffusion Backbone,一个简单的方案是使用 Mamba2 替换所有的 Self-Attention,如图 4 (a) 所示。作者使用双向的 SSM 来确保当前位置可以从后续位置访问信息。SD 中的 Self-Attention 模块不包含 Mamba2 中的门控操作或者 RMS-Norm。作者为了保持一致性,就删除了这些结构,导致性能略有提高。
1.4 Normalization-Aware Mamba
实践中,作者发现图 4 (b) 所示的基于 SSM 的结构如果训练和推理过程得分辨率一致,就可达到令人满意的性能。但是当图像尺寸不同时就会失败。为了找到这种失败的原因,作者检查了输入输出特征逐 channel 的均值,并得到下面的命题:
Proposition 1:假设输入特征 中的第 个通道的平均值为 , 把 记作 ,输出特征 中该通道的平均值为 。
通过图 4 (b) 观察到, 在 和 上应用了非负激活值。鉴于 在 Mamba2 中也是非负的,根据 Proposition 1, 如果 在训练和推理中不一致, 则通道分布会发生偏移, 这进一步导致结果失真。
解决这个问题需要将所有 token 对彼此的影响统一到相同的尺度上, 这也是 Softmax 函数固有的属性。鉴于此, 作者在本文中提出了归一化感知 Mamba, 强制每个 token 的注意力权重之和等于 1, 即 , 这相当于多次应用 SSM 模块得到归一化因子
这个操作如图 4 (c) 所示。实验表明,这种归一化大大提高了交叉分辨率泛化能力。
1.5 Non-Causal Mamba
虽然双向扫描使 token 能够从后续 token 接收信息,但是现在模型将特征映射视为一维序列,那么这会损害二维图像和高维视觉内容的内在空间结构。为了高效解决这个问题,作者在本文中专注于开发 Mamba 的 Non-Causal 的版本。
Non-Causal 关系表明一个 token 可以访问所有 token 进行信息混合, 这可以通过简单地删除应用于 的下三角 Causal Mask 来实现。因此, 式 2 就会变成: 。在这个公式中, 就与 无关了。这意味着所有 token 的 hidden state 是均匀的, 这从根本上破坏了遗忘门 的预期目的。为了解决这个问题,作者将不同的 组与各种输入 token 相关联。在这种情况下, 是一个 矩阵 。式3 中的 会变为 。与式 3 相比, 这里的 不一定是 1-半可分离的。为了保持线性复杂度, 作者假设 是低秩可分离的,即存在输入相关矩阵 和 使得 。通过这种方式, 以下命题确保在这种情况下式 3 可以通过线性注意力来实现:
Proposition 2: 给 , 记 ,存在相应的函数 和 ,使式 3 可以等价地实现为线性注意力, 表示为 。
证明:
其中, 表示 Kronecker 积。定义 和 ,推导出 。
在实践中, 作者采用 2 个 MLP 来模拟函数 和 的功能。最终推导出如图 4 (d) 所示的结构。
不仅如此, 作者进一步证明 Proposition 2 中描述的线性注意形式可以扩展到更一般的情况, 其中 是 维向量而不是标量:
Proposition 3:给定 , 如果对于每个 是低秩可分离的: ,其中 ,存在相应的函数 和 , 使得计算 可以等效地实现为线性注意力, 表示为 , 其中 是一个列向量, 可以广播到 矩阵。
证明:
其中, 和 是 矩阵, . 表示与广播的元素乘法, vec 表示将矩阵展平为行向量。定义 和 ,推导出 。
从这个角度来看, 所提出的结构可以被认为是最近线性复杂度序列模型的广义线性注意和非因果形式, 包括 Mamba2 , RWKV6 , Gated Linear Attention 。下图 5 对比了不同架构中 的参数。
1.6 训练目标
在本文中, 作者将原始 SD 中的所有 Self-Attention 模块替换为 LinFusion 模块。只有这些模块中的参数被训练, 其他所有的参数都保持冻结。为使 LinFusion 模仿 Self-Attention 的原始功能, 为标准噪声预测损失 添加了额外的损失。具体而言, 作者引入了一个知识蒸馏损失 来对齐学生和教师模型的最终输出, 以及一个特征匹配损失 以匹配每个 LinFusion 模块的输出和相应的 Self-Attention 层。训练目标可以写成:
其中, 和 是控制各自损失项权重的超参数, 表示原始 SD 的参数, 是 LinFusion 或者 Self-Attention 模块的数量, 上标 表示 Diffusion Backbone 中的第 个输出。
实现细节
在图5中展示了 SD-v1.5、SD-v2.1和 SD-XL 的定性结果,并在本节中主要在 SD-v1.5 上进行实验。SD-v1.5 中有 16 个 Self-Attention 层,将它们替换为本文的 LinFusion 模块。Proposition 2 中提到的函数 和 实现为MLP,它由一个线性分支和一个带有一个 Linear-LayerNormLeakyReLU 块的非线性分支组成。这 2 个分支加在一起得到 和 的输出。线性分支的输出分别初始化为 和 ,同时非线性分支的输出初始化为 0 。作者使用了 LAION 中美学得分大于 6.5 的 169K 张图片进行训练,并采用 BLIP2 图像字幕模型重新生成新的文本描述。超参数 和 设置为 0.5 值 ,这个工作也侧重于 SD 的蒸馏。使用 AdamW 优化器进行优化,学习率为 。在 8 个 RTX6000Ada GPU 上进行,总 Batch Size 大小为 96 ,分辨率为 ,训练 100K 步骤, 需要约 1 天才能完成。效率评估在单个 NVIDIA A100-SXM4-80GB GPU 上进行。
消融实验
为了证明 LinFusion 的有效性,作者报告了与替代解决方案的比较结果,例如图 4(a)、(b) 和 (c) 所示的结果。作者遵循以前工作的惯例,专注于文生图,对 COCO 基准进行定量评估,其中包含 30K 个文本 Prompt。这些指标是针对 COCO2014 测试数据集的 FID 以及 CLIP-ViT-G 特征空间的余弦相似度。作者还报告了包含 50 去噪步骤的每张图片的运行时间,以及推理过程 GPU Memory (GB)。512×512 分辨率的结果如图 6 所示。
削弱结构差异
作者从具有双向扫描的原始 Mamba2 结构开始探索,即图 4(a),并尝试去除门控和 RMS-Norm,即图 4(b),以保持与原始 SD 中的 Self-Attention 模块一致的整体结构。这样,与原始 SD 的唯一区别是 SSM 或 Self-Attention 用于 token mixing。
Normalization 和 Non-Causality
作者依次应用所提出的 Normalization 操作和 Non-Causality 处理,对应于图 4(c) 和 (d)。即使图 6 的结果表明 Normalization 会轻微损失性能,但是作者在图 7 中展示 Normalization 对于生成在训练期间看不到分辨率的图像至关重要。进一步添加 Non-Causality 处理,可以获得优于图 4 (b) 的结果。
知识蒸馏和特征匹配
作者最终在式 7 中应用了知识蒸馏损失项 和特征匹配损失项 , 进一步提高了性能, 甚至超过了 SD 教师模型。
交叉分辨率推理
扩散模型在训练过程中生成不可见分辨率的图像很常见。原始 SD 模型可以做到这一点。由于 LinFusion 以外的模块进行了预训练和固定,因此 Normalization 很重要,以保持训练和推理时特征分布一致。作者在图 7 中报告了 1024×1024 分辨率的结果。这些结果表明结论适用于所有基本结构,例如 Mamba2,没有门控和 RMS-Norm 的 Mamba2 和 Generalized Linear Attention 等等。图 8 是一些定性结果,其中没有 Normalization 的结果是没有意义的。
超高分辨率生成
直接应用在低分辨率上训练的扩散模型来生成更高分辨率的图像可能会导致内容失真和重复[16]。本文首先处理低分辨率,基于该低分辨率使用 SDEdit[17]来生成更高分辨率的图像。
1.7 与 SD 组件的兼容性
LinFusion 与 SD 的各种组件高度兼容,如 ControlNet[18]、IP-Adapter[19]和 LoRA[20],而无需任何进一步的训练或适应。
ControlNet
ControlNet 为一些附加的 Condition (比如边缘、深度和语义图) 向 SD 引入即插即用的组件。作者将 SD 替换为 LinFusion,并比较了一些指标。结果如下图 9 所示。
IP-Adapter
个性化文生图[21]是一种流行的 SD 应用,它专注于在输入身份和文本描述之后同时生成图像。IP-Adapter 提供了一个解决方案,从图像空间训练映射器到 SD 的条件空间,这样它就可以处理图像和文本。作者证明了在 SD 上训练的 IP-Adapter 可以直接用于 LinFusion。DreamBooth 数据集[22]包含 30 个身份和 25 个文本 Prompt,总共形成 750 个测试用例。作者对于每种情况使用 5 个随机种子,结果如图 10 所示,并报告平均 CLIP 图像相似度,DINO 图像相似度和 CLIP 文本相似度。
LoRA
LoRA 旨在应用于基本模型权重的低秩矩阵,以便它可以适应不同的任务或目的。比如 LCM-LoRA[23]使得预训练的 SD 可用于 LCM 推理,只需几个去噪步骤。这里作者将 LCM-LoRA 中的 LoRA 应用于 LinFusion。COCO 基准的性能如图 11 所示。
参考
^U-net: Convolutional networks for biomedical image segmentation
^High-Resolution Image Synthesis with Latent Diffusion Models
^Scalable Diffusion Models with Transformers
^All are Worth Words: A ViT Backbone for Diffusion Models
^Pixart-α: Fast training of diffusion transformer for photorealistic text-to-image synthesis
^Scaling Rectified Flow Transformers for High-Resolution Image Synthesis
^Mamba: Linear-Time Sequence Modeling with Selective State Spaces
^abcTransformers are ssms: Generalized models and efficient algorithms through structured state space duality
^abEagle and finch: Rwkv with matrix-valued states and dynamic recurrence
^abGated linear attention transformers with hardware-efficient training
^Highresolution image synthesis with latent diffusion models
^Denoising diffusion probabilistic models
^Laion-5b: An open large-scale dataset for training next generation image-text models
^Blip-2: Bootstrapping language-image pre-training with frozen image encoders and large language models
^BK-SDM: Architecturally Compressed Stable Diffusion for Efficient Text-to-Image Generation
^Fouriscale: A frequency perspective on training-free high-resolution image synthesis
^Sdedit: Guided image synthesis and editing with stochastic differential equations
^Adding conditional control to text-to-image diffusion models
^Ip-adapter: Text compatible image prompt adapter for text-to-image diffusion models
^LoRA: Low-rank adaptation of large language models
^An image is worth one word: Personalizing text-to-image generation using textual inversion
^DreamBooth: Fine Tuning Text-to-Image Diffusion Models for Subject-Driven Generation
^LCM-LoRA: A Universal Stable-Diffusion Acceleration Module
公众号后台回复“极市直播”获取100+期极市技术直播回放+PPT
极市干货
# 极市平台签约作者#
科技猛兽
知乎:科技猛兽
清华大学自动化系19级硕士
研究领域:AI边缘计算 (Efficient AI with Tiny Resource):专注模型压缩,搜索,量化,加速,加法网络,以及它们与其他任务的结合,更好地服务于端侧设备。
作品精选