无限批扩展可能么?达摩院Inf-CL打破对比学习显存瓶颈,提效100倍!

未分类2周前发布 tree
14 0 0
↑ 点击蓝字 关注极市平台
无限批扩展可能么?达摩院Inf-CL打破对比学习显存瓶颈,提效100倍!
作者丨PaperWeekly
来源丨PaperWeekly
编辑丨极市平台

极市导读

 

该方案突破了领域内“Contrastive loss 由于显存限制不能放大 batch size”的“共识”,实现了对比损失的 batch size 近乎无限的扩展>>加入极市CV技术交流群,走在计算机视觉的最前沿

达摩院研究员提出了一种对比损失(Contrastive Loss)的高效实现方式(Inf-CL),通过分块计算策略,在单台 A800 机器上就能把 batch size 扩展到 400万。该方案突破了领域内“Contrastive loss 由于显存限制不能放大 batch size”的“共识”,实现了对比损失的 batch size 近乎无限的扩展

无限批扩展可能么?达摩院Inf-CL打破对比学习显存瓶颈,提效100倍!

论文标题:

Breaking the Memory Barrier: Near Infinite Batch Size Scaling for Contrastive Loss

论文链接:

https://arxiv.org/pdf/2410.17243

代码链接:

https://github.com/DAMO-NLP-SG/Inf-CLIP

先看显著结果:

无限批扩展可能么?达摩院Inf-CL打破对比学习显存瓶颈,提效100倍!
▲ 图1:Inf-CL 与现有方法(CLIP 和 OpenCLIP)的 GPU 显存使用对比。

图中标出了常见的 GPU 显存限制。对于超过 80GB A800 显存瓶颈 的情况,通过曲线拟合估算显存消耗。

  1. 左图:在 8×A800 GPU 配置下,CLIP 和 OpenCLIP 的显存消耗呈 二次增长,而 Inf-CL 实现了 线性增长。在 256k batch size 下,Inf-CL 将显存消耗降低了 78 倍。
  2. 右图:在 1024k batch size 下,即使使用 128 块 GPU,CLIP 和 OpenCLIP 的显存仍然会炸。而 Inf-CL 将显存需求减少了 281 倍。

01 背景

1.1 对比损失

对比学习从 20 年以来开始爆火,从那个时代走过来的小伙伴,应该还记得这个简单的损失函数绽放了多大的光彩。

在图像自监督领域 SimCLRMoCo 两大模型系列相互争锋,跨模态检索领域,开启图文检索预训练的 CLIP 模型,在 NLP 和信息检索领域,大家耳熟能详的 SimCSE 和 DPR 等模型,都采用了 Contrastive Loss 作为训练损失。

以 CLIP 中的实现为例简单回顾一下 Contrastive loss。假设 batch size 为 , 图像和文本特征的维度为 , 则 CLIP 中的图像到文本的 Contrastive Loss 公式如下:

无限批扩展可能么?达摩院Inf-CL打破对比学习显存瓶颈,提效100倍!

其中 是第 个图像和第 个文本之间的余弦相似度, 这里 是匹配样本(正样本对)的相似度。为了简化讨论, 公式中省略了温度因子。

从公式中可以看到,对比损失会将 batch 内非匹配的文本作为负样本,来计算匹配图文对(正样本对)归一化的概率。这个就叫做 In-batch negative 策略——即将 batch 内的所有其他样本视作负样本。

这种策略的优点在于,batch size 越大,模型就能接触到更多的负样本,从而学到更具判别性的特征。因此,了解对比学习的同学们都知道,batch size 理论上越大,效果就越好,这点也有很多文章从理论上进行分析。

那么一个直观地想法是,我们直接 batch size 扩大不就好了,就像别的分类,回归,或者文本生成的任务一样,把梯度累积步数多开一些,batch size 不就能一直增大了吗?
但遗憾的是,对比学习的 batch size 方法一直是一个比较蛋疼的问题。实现过对比损失的同学都知道,核心限制主要是“增大 batch size / 负样本,GPU 显存会炸”。接下来我们来分析显存消耗到了什么地方。

1.2 显存限制

无限批扩展可能么?达摩院Inf-CL打破对比学习显存瓶颈,提效100倍!
▲ 图1.(a)Vanilla 实现的 Contrastive Loss:将所有特征广播到所有显卡,并同时将完整的相似度矩阵实例化到显存中。存储复杂度为(2),且在所有显卡上重复存储该矩阵。(b)Inf-CL 方法:采用分块-串行累加的策略减少显存占用。

经典的对比损失实现中(如 CLIP),首先需要构建相似度矩阵,并将其存储在高带宽内存(HBM)中。然后对相似度矩阵应用 Softmax 归一化负对数似然计算来完成损失计算。

然而, 相似度矩阵 及其归一化结果的显存需求, 会随着 batch size 呈二次方增长, 即显存复杂度是 , 这意味着当 batch size 较大时, 显存占用会变得非常庞大。

例如即使在采用 ViT-B/16 这种轻量化模型的情况下,当 batch size 达到 64k 时,Loss 计算部分的 GPU 显存消耗仍然极为惊人。如图 2(a)所示,尽管模型自身的显存开销仅为 5.24GB,但损失计算所需的显存却高达 66GB

这个例子可以清楚看到,在 scaling batch size 时,显存瓶颈主要集中在损失计算上。现有的方法,如 Gradient CacheBASIC 等,虽在一定程度上优化了模型的显存占用,但依然未能突破 loss 计算过程中显存二次增长的限制。

02 方法

2.1 分块计算策略

从上文的分析中我们可以看到,显存消耗的核心问题在于相似度矩阵 X 的完全实例化。那么有没有办法避免将它存储到显存里呢?为了达到这个效果,我们首先分析这个矩阵是用来计算什么的,所以先将对比损失的公式进行拆解分析:

无限批扩展可能么?达摩院Inf-CL打破对比学习显存瓶颈,提效100倍!

公式分解后,我们可以将 contrastive loss 的计算拆解为两部分:

  1. 第一部分 :计算所有正样本对的相似度 并累加。这部分的计算复杂度是 ,即线性增长,因此不会造成显存瓶颈。

  2. 第二部分 : 计算 Log-Sum-Exp (LSE), 即所有负样本对的相似度的对数-指数和。这部分是由全局相似度矩阵 计算得到的,如果直接计算并存储整个矩阵,就会导致显存开销迅速增加。

将公式拆解后我们发现,原来相似度矩阵 X 的完全实例化是为了计算 LSE 这一项,其实也就是 Softmax 操作的分母部分

看到这里,熟悉 on-line SoftmaxFlashAttention 技术的同学们可能已经秒懂了,本质问题是一样的:如果我们能通过分块计算避免一次性存储整个矩阵,LSE 的计算也就不会消耗很多的显存。

既然大模型的输入长度都能扩展到百万级别(例如 FlashAttention 支持的超长序列),那么对比损失的 **batch size scaling **问题自然也可以迎刃而解。

前向传播过程:

具体来说,分块策略的前向传播计算过程如下:

无限批扩展可能么?达摩院Inf-CL打破对比学习显存瓶颈,提效100倍!

其中, 分别表示行和列方向上的分块数量。通俗的说, 就是不把矩阵 -次性计算并存储下来, 而是将矩阵 计算划分为多个块(即子矩阵) , 并在每个块内部计算局部 LSE 值 , 之后沿着行方向逐步合并每列块的 LSE 值, 最终得到全局 LSE向量

这种分块计算方法显著减少了对显存的需求,因为每次只需计算和存储相似度矩阵的一部分,而不是整个 矩阵。此外,在列方向的运算支持并行,能够很好适应多 GPU 或 GPU 内部多芯片的并行架构,

防溢出策略:

为了避免在合并过程中出现数值不稳定或溢出,采用如下稳定的数值计算公式:

无限批扩展可能么?达摩院Inf-CL打破对比学习显存瓶颈,提效100倍!

其中初始值 。每次迭代维护列方向的 LSE 向量 , 将中间值 累积到 中, 完成行方向所有块的计算后,得到最终的全局 LSE 向量

此外, 在计算 时, 直接对矩阵求指数可能导致数值溢出。为此, 我们采用以下稳定的公式进行计算:

无限批扩展可能么?达摩院Inf-CL打破对比学习显存瓶颈,提效100倍!

其中 是一个行最大值向量, 每个元素代表 中对应行的最大值, 用作确保指数计算不会溢出。

反向传播过程:

其实在传统实现方式的前向传播过程中,相似度矩阵 X 会存储在计算图内,能够直接调用 pytorch 的 autograd 机制来计算梯度。既然我们在前向过程中仅仅存储了最终得到的 LES 向量 l,那么就需要自定义实现反向传播的算子。

体运算过程如下, 假设已经计算得到 loss 的结果, 要计算对于图像特征输入 和文本特征 的梯度:

无限批扩展可能么?达摩院Inf-CL打破对比学习显存瓶颈,提效100倍!

根据 2.1 小节拆解的公式,以 为例,完整的梯度公式为:

无限批扩展可能么?达摩院Inf-CL打破对比学习显存瓶颈,提效100倍!

从该公式可以看出,第二项计算依赖于相似度矩阵的值。我们在反向计算中也采用与前向过程相同的分块计算策略

  1. 在前向传播时, 仅存储大小为 的向量

  2. 在反向传播时,逐块累积计算梯度:

无限批扩展可能么?达摩院Inf-CL打破对比学习显存瓶颈,提效100倍!

其中 是用于累积的临时变量。通过这种分块计算, 我们在反向传播中同样避免了完整存储矩阵 的需求, 进一步降低了显存开销, 并实现了高效的梯度计算。详细的算法步骤在论文中可以找到。

2.2 Multi-Level Tiling**

看到这里的小伙伴们可能会产生疑问,分块累加这种操作本质上是将并行计算的过程用串行合并来替代了,也是一种时间换空间的策略,而且反向传播的 recompute 过程也会带来额外的计算,难道不会很慢吗?

其实问题的答案是:整体计算量会增加,但我们可以通过 GPU 的分布式运算特性来加速这个过程,运算速度却并不会减慢很多。加速过程主要是两块,即跨 GPU 的通讯和 GPU 内显存的 IO 加速。我们将其称为多层级分块策略。该策略将 LSE 的计算分配为粗粒度的跨 GPU 分块和细粒度的单 GPU 分块,以最大化计算效率。

无限批扩展可能么?达摩院Inf-CL打破对比学习显存瓶颈,提效100倍!
▲ 图3. 多层级分块策略示意图。上:在 跨 GPU 分块中,每个 GPU 被分配多行数据,并负责对应行的 LSE 计算。计算与列方向的通信采用 异步 方式执行。下:在 单 GPU 分块 中,将行方向的计算任务分配给多个 CUDA 核心。每行的累积操作在一个 kernel 中执行,以减少 SRAM 和 HBM 之间 I/O 次数。

跨 GPU 分块(Cross-GPU Tile)

并行训练中, 假设有 个 GPU, 每个 GPU 处理一部分图像和文本数据, 分别生成视觉特征 和文本特征 , 其中 表示单个 GPU 上的 batch size。计算对比损失时, 我们将不同行的数据分配给不同的 GPU, 并逐步同步各 GPU 之间的列数据。

具体而言, 第 个 GPU 负责相似度矩阵的第 行子块的 :及其对应的 LSE 向量 。为了降低显存开销, 结合分块策略, 每个行块 可以进一步拆分为 步小块 来计算 LSE, 具体过程可以在论文中算法 1 中找到。每个小块 的 LSE 计算采用单 GPU 分块策略 (详见下节)。

由于计算 (当 时)需要访问其他 GPU 上的文本特征 , 这不可避免地会带来通信开销。

为此,我们采用环形拓扑结构来尽可能减少通信时间。具体而言,每个 GPU 利用当前文本特征进行计算时,会异步地将其发送给下一个 GPU,并异步地从上一个 GPU 接收新的文本特征,如此循环执行。通过这种方式,通信时间可以与计算时间重叠,实现更高的效率。

单 GPU 分块(In-GPU Tile)

尽管跨 GPU 分块已经降低了显存复杂度至 , 但由于 GPU 数量有限的, 我们进一步引入单 GPU 分块策略, 将显存开销进一步降至 。具体而言, 我们将 细分为更小的块:

无限批扩展可能么?达摩院Inf-CL打破对比学习显存瓶颈,提效100倍!

其中 分别表示在行和列方向的块数量, 是单个 块的大小。

在实现中,我们将这些 tile 的计算任务分配给多个 CUDA 核心,以充分利用 GPU 的并行计算能力。每个内核中,对每个 的行块进行串行计算,并应用 2.1 小节中的公式循环累积计算 LSE 的值

熟悉 flash attention 的同学们都知道,显卡计算的核心消耗在于 HBM 和 SRAM 的来回数据传输过程。为避免频繁的 HBM(高带宽内存)与 SRAM(片上内存)之间的数据交换带来的高昂开销,我们将行方向的迭代计算合并到一个 kernel 中执行。

通过这种方式,图像特征在计算开始时只需要加载到 SRAM 一次,而累积的 LSE 结果 l~i 仅在计算结束时写回 HBM 一次。这种仅把最终结果写入到 HBM 的 fused 操作,会极大提升算子优化性能。对比实例化整个相似度矩阵 X~ 写入 HBM 里,在运算时又进行取出的传统实现,这种 fused 的操作虽然行方向进行了串行计算,但整体速度几乎相当。

03 实验效果

显存开销对比

无限批扩展可能么?达摩院Inf-CL打破对比学习显存瓶颈,提效100倍!

现有方法在 batch size 放大时计算 loss 的显存开销急剧增长,很快超出硬件限制,而 Inf-CL 即使在 batch size 很大时仍然可以将 loss 的显存开销控制在很小的范围内。

训练效率对比

无限批扩展可能么?达摩院Inf-CL打破对比学习显存瓶颈,提效100倍!

与现有的对比学习损失计算实现相比,Inf-CL 不会明显增加计算时间,且随着 batch size 不断增大,计算效率也不会明显下降。

精度验证

无限批扩展可能么?达摩院Inf-CL打破对比学习显存瓶颈,提效100倍!

由于数学上的等价性,使用 Inf-CL 不会损失精度,同时也验证了一定范围内放大 batch size 对模型性能有增益。

关于 scaling batch size 的讨论:

无限批扩展可能么?达摩院Inf-CL打破对比学习显存瓶颈,提效100倍!

虽然理论上预计更大的批次大小会提高性能,但我们的实验结果与这一预期有偏差,有几方面原因:

首先,为保证训练时间不变,当前设定下增大 batch size 伴随着更少的迭代次数,可能需要对学习率、迭代次数等超参数进一步优化以确保模型收敛。其次,更大的数据集更准确地捕捉现实世界的分布,因此在大规模数据集上放大 batch size 的收益更为明显。

我们在不同数据规模(CC3M、CC12M 和  Laion400M)上的实验表明,最佳的batch size随数据集规模的增加而增加,体现了 Inf-CL 在 scaling law 下的长远意义。

04 总结

对比学习有多炸不用多说,在图文检索(CLIP为代表),图像自监督学习(SimCLR,MoCo 等),文本检索(DPR 等)是核心地位。该领域的共识都是“增大 batch size / 负样本很有用,但 GPU 显存会炸”,比如早期 MoCo 提出用 “momenturm encoder” 和 “memory bank” 来规避这个问题。

这个工作直面该问题,将对比损失的显存消耗打到底,且额外时间开销极少,为对比损失相关辐射领域提供了新的 scaling 机会。

无限批扩展可能么?达摩院Inf-CL打破对比学习显存瓶颈,提效100倍!

公众号后台回复“数据集”获取100+深度学习各方向资源整理

极市干货

技术专栏:多模态大模型超详细解读专栏搞懂Tranformer系列大视觉模型 (LVM) 解读扩散模型系列极市直播
技术综述:小目标检测那点事大模型面试八股含答案万字长文!人体姿态估计(HPE)入门教程

无限批扩展可能么?达摩院Inf-CL打破对比学习显存瓶颈,提效100倍!

点击阅读原文进入CV社区

收获更多技术干货

© 版权声明

相关文章

暂无评论

暂无评论...