极市导读
本文介绍了一种新的Contrastive Loss实现方式——Inf-CL,它通过分块计算策略,在单台A800机器上将batch size扩展到4M,几乎实现了Contrastive Loss batch size的无限扩展,突破了以往认为增加batch size会导致显存不足的限制。 >>加入极市CV技术交流群,走在计算机视觉的最前沿
TL;DR
本文提出了一种Contrastive Loss的实现方式(Inf-CL),通过分块计算策略,我们在单台A800机器上就能把 batch size 扩展到 4M。不严谨地说,该方案突破了以前公知的”contrastive loss不能scaling batch size,显存会炸“的前提,实现了 Contrastive Loss 的 batch size 近乎无限的扩展。 中国人不骗中国人,以后对比损失实现就用Inf-CL!!
对比学习有多炸不用多说,在图文检索(CLIP为代表),图像自监督学习(SimCLR,MoCo等),文本检索(DPR等)是核心地位。之前相关工作的前提都是”增大batch size/负样本,GPU显存会炸“,比如早期MoCo提出用”momenturm encoder“和“memory bank”来规避这个问题。这个工作直面显存痛点,将对比损失的显存消耗打到底,且额外时间开销极少,为对比损失相关辐射领域提供了新的scaling机会。
先放炸裂结果:
图中标出了常见的 GPU 显存限制。对于超过 80GB A800 显存瓶颈 的情况,通过曲线拟合估算显存消耗。
左图:在 8×A800 GPU 配置下,CLIP 和 OpenCLIP 的显存消耗呈 二次增长,而 Inf-CL 实现了 线性增长。在 256k batch size 下,Inf-CL 将显存消耗降低了78倍。 右图:在 1024k batch size 下,即使使用 128 块 GPU,CLIP 和 OpenCLIP 的显存仍然会炸。而 Inf-CL 将显存需求减少了 281倍。
题目:Breaking the Memory Barrier: Near Infinite Batch Size Scaling for Contrastive Loss
论文链接:https://github.com/DAMO-NLP-SG/Inf-CLIP/blob/main/assets/inf_cl.pdf
Arxiv链接:https://arxiv.org/abs/2410.17243
Huggingface Papers:https://huggingface.co/papers/2410.17243
代码链接: https://github.com/DAMO-NLP-SG/Inf-CLIP
1. 准备工作
1.1 Contrastive Loss
对比学习从20年以来开始爆火,从那个时代走过来的小伙伴,应该还记得这个简单的损失函数绽放了多大的光彩。在图像自监督领域 SimCLR 和 MoCo 两大模型系列相互争锋,跨模态检索领域,开启图文检索预训练的 CLIP 模型,在 NLP和信息检索领域,大家耳熟能详的 SimCSE 和 DPR 等模型,都采用了Contrastive Loss作为训练损失。
这里以CLIP中的实现为例简单回顾一下contrastive loss。假设 batch size 为 b,图像和文本特征的维度为 [b,c],则 CLIP 中的图像到文本的 Contrastive Loss 公式如下:
其中 是第 i 个图像和第 j 个文本之间的余弦相似度, 这里 是匹配样本(正样本对)的相似度。为了简化讨论,公式中省略了温度因子。
从公式中我们可以看到,对比损失会将batch 内非匹配的文本作为负样本,来计算匹配图文对(正样本对)归一化的概率。这个就叫做In-batch negative 策略——即将 batch 内的所有其他样本视作负样本。这种策略的优点在于,batch size 越大,模型就能接触到更多的负样本,从而学到更具判别性的特征。因此,了解对比学习的同学们都知道,batch size 理论上越大,效果就越好,这点也有很多文章从理论上进行分析。
那么一个直观地想法是,我们直接batch size 扩大不就好了,就像别的分类,回归,或者文本生成的任务一样,把梯度累积步数多开一些,batch size不就能一直增大了吗?但遗憾的是,对比学习的batch size 方法一直是一个比较蛋疼的问题。实现过对比损失的同学都知道,核心限制主要是”增大batch size/负样本,GPU显存会炸“。接下来我们来分析显存消耗到了什么地方。
1.2 显存限制
在 经典 的对比损失实现中(如CLIP),首先需要构建 相似度矩阵 ,并将其存储在 高带宽内存 (HBM) 中。然后对相似度矩阵应用 Softmax 归一化 和 负对数似然计算 来完成损失计算。
然而,相似度矩阵 及其归一化结果的显存需求,会随着 batch size 呈二次方增长,即显存复杂度是 ,这意味着当 batch size 较大时,显存占用会变得非常庞大。例如即使在采用 ViT-B/16 这种轻量化模型的情况下,当 batch size 达到 64k 时,Loss 计算部分的GPU 显存消耗仍然极为惊人。如图 2 (a)所示,尽管模型自身的显存开销仅为 5.24GB,但损失计算所需的显存却高达 66GB。
这个例子我们可以清楚看到,在scaling batch size 时,显存瓶颈主要集中在损失计算上。现有的方法,如 Gradient Cache 和 BASIC 等,虽在一定程度上优化了模型的显存占用,但依然未能突破loss 计算过程中显存二次增长 的限制。
2. 方法
2.1 分块计算策略
正如在 上一小节vanilla 实现 中讨论的那样,显存消耗的核心问题在于 相似度矩阵 X 的完全实例化。那么我们有没有办法避免将它存储呢?为了达到这个效果,我们首先分析这个矩阵是用来计算什么的,所以先将对比损失的公式进行拆解分析:
公式分解后,我们可以将contrastive loss的计算拆解为两部分:
第一部分 :计算所有 正样本对的相似度 并累加。这部分的计算复杂度是 mathcal{O}(b) ,即线性增长,因此不会造成显存瓶颈。 第二部分 :计算 Log-Sum-Exp (LSE),即所有负样本对的相似度的对数-指数和。这部分是由全局相似度矩阵 计算得到的,如果直接计算并存储整个矩阵,就会导致显存开销迅速增加。
将公式拆解后我们发现,原来相似度矩阵 的完全实例化是为了计算LSE这一项,其实也就是Softmax操作的分母部分。看到这里,熟悉 on-line Softmax 和 FlashAttention 技术的同学们可能已经秒懂了,本质问题是一样的:如果我们能通过分块计算避免一次性存储整个矩阵,LSE 的计算也就不会消耗很多的显存。既然 大模型 的输入长度都能扩展到 百万级别(例如 FlashAttention 支持的超长序列),那么对比损失的 batch size scaling 问题自然也可以迎刃而解。
前向传播过程:
具体来说, 分块策略的前向传播计算过程如下:
其中, 和 分别表示行和列方向上的分块数量。通俗的说,就是不把矩阵 一次性计算并存储下来,而是将矩阵 的计算划分为多个块(即子矩阵) ,并在每个块内部计算局部LSE 值 , 之后沿着 行方向 逐步合并每列块的 LSE 值,最终得到全局 LSE 向量 。
这种 分块计算 方法显著减少了对显存的需求,因为每次只需计算和存储相似度矩阵的一部分,而不是整个 矩阵。此外,在列方向的运算支持并行,能够很好适应多 GPU 或GPU内部多芯片的并行架构,
防溢出策略:
为了避免在合并过程中出现数值不稳定或溢出,采用如下稳定的数值计算公式:
其中初始值 。每次迭代维护列方向的LSE向量 ,将中间值 累积到 中,完成行方向所有块的计算后,得到最终的全局 LSE 向量 。
此外,在计算 时,直接对矩阵求指数可能导致数值溢出。为此,我们采用以下稳定的公式进行计算:
其中 是一个行最大值向量,每个元素代表 中对应行的最大值,用作确保指数计算不会溢出。
反向传播过程:
其实在传统实现方式的前向传播过程中,相似度矩阵 会存储在计算图内,能够直接调用torch的autograd机制来计算梯度。既然我们在前向过程中仅仅存储了最终得到的LSE向量 ,那么就需要自定义实现反向传播的算子。
具体运算过程如下,假设已经计算得到loss的结果,要计算对于图像特征输入 和文本特征 的梯度
根据2.1小节拆解的公式,以 I_i 为例,完整的梯度公式为:
简化后:
从该公式可以看出,第二项计算依赖于相似度矩阵的值。我们在反向计算中也采用与前向过程相同的分块计算策略:
在前向传播时,仅存储大小为 b 的向量 。 在反向传播时,逐块累积计算梯度:
最终梯度为:
其中 是用于累积的临时变量。通过这种分块计算,我们在反向传播中同样避免了完整存储矩阵 的需求,进一步降低了显存开销,并实现了高效的梯度计算。详细的算法步骤在论文中可以找到。
2.2 Multi-Level Tiling
看到这里的小伙伴们可能会产生疑问,分块累加这种操作本质上是将并行计算的过程用串行合并来替代了,也是一种时间换空间的策略,而且反向传播的recompute过程也会带来额外的计算,难道不会很慢吗?其实问题的答案是:整体计算量会增加,但我们可以通过GPU的分布式运算特性来加速这个过程,运算速度却并不会减慢很多。加速过程主要是两块,即跨GPU的通讯和GPU内显存的IO加速。我们将其称为 多层级分块策略。该策略将 LSE 的计算分配为 粗粒度的跨 GPU 分块 和 细粒度的单 GPU 分块,以最大化计算效率。
跨 GPU 分块 (Cross-GPU Tile)
在 并行训练 中,假设有 个 GPU,每个 GPU 处理一部分图像和文本数据,分别生成视觉特征 和文本特征 ,其中 表示单个 GPU 上的 batch size。计算对比损失时,我们将不同行的数据分配给不同的 GPU,并逐步同步各 GPU 之间的列数据。
具体而言,第 个 GPU 负责相似度矩阵的第 行子块的 : 及其对应的 LSE 向量 。为了降低显存开销,结合 分块策略,每个行块 可以进一步拆分为 步小块 来计算 LSE,具体过程可以在论文中算法 1 中找到。每个小块 的 LSE 计算采用 单 GPU 分块策略(详见下节)。
由于计算 (当 时)需要访问其他 GPU 上的文本特征 ,这不可避免地会带来通信开销。为此,我们采用 环形拓扑结构 来尽可能减少通信时间。具体而言,每个 GPU 利用当前文本特征进行计算时,会异步地将其发送给下一个 GPU,并异步地从上一个 GPU 接收新的文本特征,如此循环执行。通过这种方式,通信时间可以与计算时间重叠,实现更高的效率。
单 GPU 分块 (In-GPU Tile)
尽管跨 GPU 分块已经降低了显存复杂度至 ,但由于 GPU 数量有限的,我们进一步引入单 GPU 分块策略,将显存开销进一步降至 。具体而言,我们将 细分为更小的 块:
其中 和 分别表示 块 在行和列方向的数量, 和 是单个 tile 的大小。在实现中,我们将这些 tile 的计算任务分配给多个 CUDA 核心,以充分利用 GPU 的并行计算能力。每个内核中,对每个 tile 的行块进行串行计算,并应用2.1小节中的公式循环累积计算LSE的值 。
熟悉flash attention的同学们都知道,显卡计算的核心消耗在HBM和SRAM的来回传输过程。为避免频繁的 HBM(高带宽内存)与 SRAM(片上内存) 之间的数据交换带来的高昂开销,我们将 行方向的迭代计算 合并到 一个 kernel 中执行。通过这种方式,图像特征 在计算开始时只需要加载到 SRAM 一次,而 累积的 LSE 结果 仅在计算结束时写回 HBM 一次。这种仅把最终结果写入到HBM的fused 操作,会极大提升算子优化性能。对比实例化整个相似度矩阵 写入HBM里,在运算时又进行取出的传统实现,这种fused的操作虽然行方向进行了串行计算,但整体速度几乎相当。
那么大概的运算流程就如上所述,如果有不清楚的地方,欢迎您阅读我们的论文。
3. 实验结果
就两方面来说明我们方法的效果,显存节省度、速度、精度。
显存:
我们在论文中的实验主要在OpenCLIP的框架上做的。具体setting可以在论文中找到,对比的baseline主要是CLIP的原始实现和OpenCLIP中的优化实现,为了为了尽量减少显存占用,实验使用了 Gradient Cache,其中累积 batch size 设置为 128。具体效果如下:
标注* 表示采用了 数据卸载策略(Data Offload),即在每次累积步骤中,仅传输小批量数据从 CPU 到 GPU,以降低显存需求。
结果:可以看到在1024k的情况下,我们的显存占用仅仅用了1G左右!!但是OpenCLIP等方法直接炸显存了。
结果:看到在8A800 最大支持4M训练,32A800支持12 M 的训练!
速度:
结果:Inf-CL 在大幅降低显存占用的同时,仅引入了极少的 额外时间开销。此外,Inf-CL的迭代时间随批量大小线性扩展,为Batch size的scaling 研究提供了可能。
精度:
结果:Inf-CL在数学表达上与原本的对比损失是一致的, 实际训练得到的结果也在误差范围内,且在github里面也进行了精度误差的测量实验。
4. 总结
对比学习有多炸不用多说,在图文检索(CLIP为代表),图像自监督学习(SimCLR,MoCo等),文本检索(DPR等)是核心地位。之前相关工作的前提都是”增大batch size/负样本,GPU显存会炸“,比如早期MoCo提出用”momenturm encoder“和“memory bank”来规避这个问题。这个工作直面该问题,将对比损失的显存消耗打到底,且额外时间开销极少,为对比损失相关辐射领域提供了新的scaling机会。
附:
一些好工作!
Gradient Cache: 将”模型前向计算特征“ 与 ”对比损失计算“ 解绑,可以理解为针对对比学习的”梯度累积“策略。
代码:https://github.com/luyug/GradCache
论文:Scaling Deep Contrastive Learning Batch Size under Memory Limited Setup
Flash Attention和 Ring Attention:更细粒度的解释I/O awared 和ring 通讯策略,也是本文灵感的来源。
Ring Attention的优雅实现:https://github.com/zhuzilin/ring-flash-attention/tree/main
Flash Attention:https://github.com/Dao-AILab/flash-attention
公众号后台回复“数据集”获取100+深度学习各方向资源整理
极市干货
点击阅读原文进入CV社区
收获更多技术干货