极市导读
该方案突破了领域内“Contrastive loss 由于显存限制不能放大 batch size”的“共识”,实现了对比损失的 batch size 近乎无限的扩展。>>加入极市CV技术交流群,走在计算机视觉的最前沿
达摩院研究员提出了一种对比损失(Contrastive Loss)的高效实现方式(Inf-CL),通过分块计算策略,在单台 A800 机器上就能把 batch size 扩展到 400万。该方案突破了领域内“Contrastive loss 由于显存限制不能放大 batch size”的“共识”,实现了对比损失的 batch size 近乎无限的扩展。
论文标题:
Breaking the Memory Barrier: Near Infinite Batch Size Scaling for Contrastive Loss
论文链接:
https://arxiv.org/pdf/2410.17243
代码链接:
https://github.com/DAMO-NLP-SG/Inf-CLIP
先看显著结果:
图中标出了常见的 GPU 显存限制。对于超过 80GB A800 显存瓶颈 的情况,通过曲线拟合估算显存消耗。
左图:在 8×A800 GPU 配置下,CLIP 和 OpenCLIP 的显存消耗呈 二次增长,而 Inf-CL 实现了 线性增长。在 256k batch size 下,Inf-CL 将显存消耗降低了 78 倍。 右图:在 1024k batch size 下,即使使用 128 块 GPU,CLIP 和 OpenCLIP 的显存仍然会炸。而 Inf-CL 将显存需求减少了 281 倍。
01 背景
1.1 对比损失
对比学习从 20 年以来开始爆火,从那个时代走过来的小伙伴,应该还记得这个简单的损失函数绽放了多大的光彩。
在图像自监督领域 SimCLR 和 MoCo 两大模型系列相互争锋,跨模态检索领域,开启图文检索预训练的 CLIP 模型,在 NLP 和信息检索领域,大家耳熟能详的 SimCSE 和 DPR 等模型,都采用了 Contrastive Loss 作为训练损失。
以 CLIP 中的实现为例简单回顾一下 Contrastive loss。假设 batch size 为 , 图像和文本特征的维度为 , 则 CLIP 中的图像到文本的 Contrastive Loss 公式如下:
其中 是第 个图像和第 个文本之间的余弦相似度, 这里 是匹配样本(正样本对)的相似度。为了简化讨论, 公式中省略了温度因子。
从公式中可以看到,对比损失会将 batch 内非匹配的文本作为负样本,来计算匹配图文对(正样本对)归一化的概率。这个就叫做 In-batch negative 策略——即将 batch 内的所有其他样本视作负样本。
这种策略的优点在于,batch size 越大,模型就能接触到更多的负样本,从而学到更具判别性的特征。因此,了解对比学习的同学们都知道,batch size 理论上越大,效果就越好,这点也有很多文章从理论上进行分析。
那么一个直观地想法是,我们直接 batch size 扩大不就好了,就像别的分类,回归,或者文本生成的任务一样,把梯度累积步数多开一些,batch size 不就能一直增大了吗?
但遗憾的是,对比学习的 batch size 方法一直是一个比较蛋疼的问题。实现过对比损失的同学都知道,核心限制主要是“增大 batch size / 负样本,GPU 显存会炸”。接下来我们来分析显存消耗到了什么地方。
1.2 显存限制
在经典的对比损失实现中(如 CLIP),首先需要构建相似度矩阵,并将其存储在高带宽内存(HBM)中。然后对相似度矩阵应用 Softmax 归一化和负对数似然计算来完成损失计算。
然而, 相似度矩阵 及其归一化结果的显存需求, 会随着 batch size 呈二次方增长, 即显存复杂度是 , 这意味着当 batch size 较大时, 显存占用会变得非常庞大。
例如即使在采用 ViT-B/16 这种轻量化模型的情况下,当 batch size 达到 64k 时,Loss 计算部分的 GPU 显存消耗仍然极为惊人。如图 2(a)所示,尽管模型自身的显存开销仅为 5.24GB,但损失计算所需的显存却高达 66GB。
这个例子可以清楚看到,在 scaling batch size 时,显存瓶颈主要集中在损失计算上。现有的方法,如 Gradient Cache 和 BASIC 等,虽在一定程度上优化了模型的显存占用,但依然未能突破 loss 计算过程中显存二次增长的限制。
02 方法
2.1 分块计算策略
从上文的分析中我们可以看到,显存消耗的核心问题在于相似度矩阵 X 的完全实例化。那么有没有办法避免将它存储到显存里呢?为了达到这个效果,我们首先分析这个矩阵是用来计算什么的,所以先将对比损失的公式进行拆解分析:
公式分解后,我们可以将 contrastive loss 的计算拆解为两部分:
第一部分 :计算所有正样本对的相似度 并累加。这部分的计算复杂度是 ,即线性增长,因此不会造成显存瓶颈。
第二部分 : 计算 Log-Sum-Exp (LSE), 即所有负样本对的相似度的对数-指数和。这部分是由全局相似度矩阵 计算得到的,如果直接计算并存储整个矩阵,就会导致显存开销迅速增加。
将公式拆解后我们发现,原来相似度矩阵 X 的完全实例化是为了计算 LSE 这一项,其实也就是 Softmax 操作的分母部分。
看到这里,熟悉 on-line Softmax 和FlashAttention 技术的同学们可能已经秒懂了,本质问题是一样的:如果我们能通过分块计算避免一次性存储整个矩阵,LSE 的计算也就不会消耗很多的显存。
既然大模型的输入长度都能扩展到百万级别(例如 FlashAttention 支持的超长序列),那么对比损失的 **batch size scaling **问题自然也可以迎刃而解。
前向传播过程:
具体来说,分块策略的前向传播计算过程如下:
其中, 和 分别表示行和列方向上的分块数量。通俗的说, 就是不把矩阵 -次性计算并存储下来, 而是将矩阵 的计算划分为多个块(即子矩阵) , 并在每个块内部计算局部 LSE 值 , 之后沿着行方向逐步合并每列块的 LSE 值, 最终得到全局 LSE向量 。
这种分块计算方法显著减少了对显存的需求,因为每次只需计算和存储相似度矩阵的一部分,而不是整个 矩阵。此外,在列方向的运算支持并行,能够很好适应多 GPU 或 GPU 内部多芯片的并行架构,
防溢出策略:
为了避免在合并过程中出现数值不稳定或溢出,采用如下稳定的数值计算公式:
其中初始值 。每次迭代维护列方向的 LSE 向量 , 将中间值 累积到 中, 完成行方向所有块的计算后,得到最终的全局 LSE 向量 。
此外, 在计算 时, 直接对矩阵求指数可能导致数值溢出。为此, 我们采用以下稳定的公式进行计算:
其中 是一个行最大值向量, 每个元素代表 中对应行的最大值, 用作确保指数计算不会溢出。
反向传播过程:
其实在传统实现方式的前向传播过程中,相似度矩阵 X 会存储在计算图内,能够直接调用 pytorch 的 autograd 机制来计算梯度。既然我们在前向过程中仅仅存储了最终得到的 LES 向量 l,那么就需要自定义实现反向传播的算子。
体运算过程如下, 假设已经计算得到 loss 的结果, 要计算对于图像特征输入 和文本特征 的梯度:
根据 2.1 小节拆解的公式,以 为例,完整的梯度公式为:
从该公式可以看出,第二项计算依赖于相似度矩阵的值。我们在反向计算中也采用与前向过程相同的分块计算策略:
在前向传播时, 仅存储大小为 的向量 。
在反向传播时,逐块累积计算梯度:
其中 是用于累积的临时变量。通过这种分块计算, 我们在反向传播中同样避免了完整存储矩阵 的需求, 进一步降低了显存开销, 并实现了高效的梯度计算。详细的算法步骤在论文中可以找到。
2.2 Multi-Level Tiling**
看到这里的小伙伴们可能会产生疑问,分块累加这种操作本质上是将并行计算的过程用串行合并来替代了,也是一种时间换空间的策略,而且反向传播的 recompute 过程也会带来额外的计算,难道不会很慢吗?
其实问题的答案是:整体计算量会增加,但我们可以通过 GPU 的分布式运算特性来加速这个过程,运算速度却并不会减慢很多。加速过程主要是两块,即跨 GPU 的通讯和 GPU 内显存的 IO 加速。我们将其称为多层级分块策略。该策略将 LSE 的计算分配为粗粒度的跨 GPU 分块和细粒度的单 GPU 分块,以最大化计算效率。
跨 GPU 分块(Cross-GPU Tile)
在并行训练中, 假设有 个 GPU, 每个 GPU 处理一部分图像和文本数据, 分别生成视觉特征 和文本特征 , 其中 表示单个 GPU 上的 batch size。计算对比损失时, 我们将不同行的数据分配给不同的 GPU, 并逐步同步各 GPU 之间的列数据。
具体而言, 第 个 GPU 负责相似度矩阵的第 行子块的 :及其对应的 LSE 向量 。为了降低显存开销, 结合分块策略, 每个行块 可以进一步拆分为 步小块 来计算 LSE, 具体过程可以在论文中算法 1 中找到。每个小块 的 LSE 计算采用单 GPU 分块策略 (详见下节)。
由于计算 (当 时)需要访问其他 GPU 上的文本特征 , 这不可避免地会带来通信开销。
为此,我们采用环形拓扑结构来尽可能减少通信时间。具体而言,每个 GPU 利用当前文本特征进行计算时,会异步地将其发送给下一个 GPU,并异步地从上一个 GPU 接收新的文本特征,如此循环执行。通过这种方式,通信时间可以与计算时间重叠,实现更高的效率。
单 GPU 分块(In-GPU Tile)
尽管跨 GPU 分块已经降低了显存复杂度至 , 但由于 GPU 数量有限的, 我们进一步引入单 GPU 分块策略, 将显存开销进一步降至 。具体而言, 我们将 细分为更小的块:
其中 和 分别表示在行和列方向的块数量, 和 是单个 块的大小。
在实现中,我们将这些 tile 的计算任务分配给多个 CUDA 核心,以充分利用 GPU 的并行计算能力。每个内核中,对每个 的行块进行串行计算,并应用 2.1 小节中的公式循环累积计算 LSE 的值 。
熟悉 flash attention 的同学们都知道,显卡计算的核心消耗在于 HBM 和 SRAM 的来回数据传输过程。为避免频繁的 HBM(高带宽内存)与 SRAM(片上内存)之间的数据交换带来的高昂开销,我们将行方向的迭代计算合并到一个 kernel 中执行。
通过这种方式,图像特征在计算开始时只需要加载到 SRAM 一次,而累积的 LSE 结果 l~i 仅在计算结束时写回 HBM 一次。这种仅把最终结果写入到 HBM 的 fused 操作,会极大提升算子优化性能。对比实例化整个相似度矩阵 X~ 写入 HBM 里,在运算时又进行取出的传统实现,这种 fused 的操作虽然行方向进行了串行计算,但整体速度几乎相当。
03 实验效果
显存开销对比
现有方法在 batch size 放大时计算 loss 的显存开销急剧增长,很快超出硬件限制,而 Inf-CL 即使在 batch size 很大时仍然可以将 loss 的显存开销控制在很小的范围内。
训练效率对比
与现有的对比学习损失计算实现相比,Inf-CL 不会明显增加计算时间,且随着 batch size 不断增大,计算效率也不会明显下降。
精度验证
由于数学上的等价性,使用 Inf-CL 不会损失精度,同时也验证了一定范围内放大 batch size 对模型性能有增益。
关于 scaling batch size 的讨论:
虽然理论上预计更大的批次大小会提高性能,但我们的实验结果与这一预期有偏差,有几方面原因:
首先,为保证训练时间不变,当前设定下增大 batch size 伴随着更少的迭代次数,可能需要对学习率、迭代次数等超参数进一步优化以确保模型收敛。其次,更大的数据集更准确地捕捉现实世界的分布,因此在大规模数据集上放大 batch size 的收益更为明显。
我们在不同数据规模(CC3M、CC12M 和 Laion400M)上的实验表明,最佳的batch size随数据集规模的增加而增加,体现了 Inf-CL 在 scaling law 下的长远意义。
04 总结
对比学习有多炸不用多说,在图文检索(CLIP为代表),图像自监督学习(SimCLR,MoCo 等),文本检索(DPR 等)是核心地位。该领域的共识都是“增大 batch size / 负样本很有用,但 GPU 显存会炸”,比如早期 MoCo 提出用 “momenturm encoder” 和 “memory bank” 来规避这个问题。
这个工作直面该问题,将对比损失的显存消耗打到底,且额外时间开销极少,为对比损失相关辐射领域提供了新的 scaling 机会。
公众号后台回复“数据集”获取100+深度学习各方向资源整理
极市干货
点击阅读原文进入CV社区
收获更多技术干货