Inf-CL: 把 Contrastive Loss 的 Batch Size 冲到100M!

未分类1周前发布 tree
5 0 0
↑ 点击蓝字 关注极市平台
Inf-CL: 把 Contrastive Loss 的 Batch Size 冲到100M!
作者丨藤原豆腐皮儿@知乎(已授权)
来源丨https://zhuanlan.zhihu.com/p/1681887214
编辑丨极市平台

极市导读

 

本文介绍了一种新的Contrastive Loss实现方式——Inf-CL,它通过分块计算策略,在单台A800机器上将batch size扩展到4M,几乎实现了Contrastive Loss batch size的无限扩展,突破了以往认为增加batch size会导致显存不足的限制。 >>加入极市CV技术交流群,走在计算机视觉的最前沿

TL;DR

本文提出了一种Contrastive Loss的实现方式(Inf-CL),通过分块计算策略,我们在单台A800机器上就能把 batch size 扩展到 4M。不严谨地说,该方案突破了以前公知的”contrastive loss不能scaling batch size,显存会炸“的前提,实现了 Contrastive Loss 的 batch size 近乎无限的扩展。 中国人不骗中国人,以后对比损失实现就用Inf-CL!!

对比学习有多炸不用多说,在图文检索(CLIP为代表),图像自监督学习(SimCLR,MoCo等),文本检索(DPR等)是核心地位。之前相关工作的前提都是”增大batch size/负样本,GPU显存会炸“,比如早期MoCo提出用”momenturm encoder“和“memory bank”来规避这个问题。这个工作直面显存痛点,将对比损失的显存消耗打到底,且额外时间开销极少,为对比损失相关辐射领域提供了新的scaling机会。

先放炸裂结果:

Inf-CL: 把 Contrastive Loss 的 Batch Size 冲到100M!
图 1:Inf-CL 与现有方法(CLIP 和 OpenCLIP)的 GPU 显存使用对比。

图中标出了常见的 GPU 显存限制。对于超过 80GB A800 显存瓶颈 的情况,通过曲线拟合估算显存消耗。

  1. 左图:在 8×A800 GPU 配置下,CLIP 和 OpenCLIP 的显存消耗呈 二次增长,而 Inf-CL 实现了 线性增长。在 256k batch size 下,Inf-CL 将显存消耗降低了78倍
  2. 右图:在 1024k batch size 下,即使使用 128 块 GPU,CLIP 和 OpenCLIP 的显存仍然会炸。而 Inf-CL 将显存需求减少了 281倍

题目:Breaking the Memory Barrier: Near Infinite Batch Size Scaling for Contrastive Loss

论文链接:https://github.com/DAMO-NLP-SG/Inf-CLIP/blob/main/assets/inf_cl.pdf

Arxiv链接:https://arxiv.org/abs/2410.17243

Huggingface Papers:https://huggingface.co/papers/2410.17243

代码链接: https://github.com/DAMO-NLP-SG/Inf-CLIP

1. 准备工作

1.1 Contrastive Loss

对比学习从20年以来开始爆火,从那个时代走过来的小伙伴,应该还记得这个简单的损失函数绽放了多大的光彩。在图像自监督领域 SimCLR 和 MoCo 两大模型系列相互争锋,跨模态检索领域,开启图文检索预训练的 CLIP 模型,在 NLP和信息检索领域,大家耳熟能详的 SimCSEDPR 等模型,都采用了Contrastive Loss作为训练损失。

这里以CLIP中的实现为例简单回顾一下contrastive loss。假设 batch size 为 b,图像和文本特征的维度为 [b,c],则 CLIP 中的图像到文本的 Contrastive Loss 公式如下:

其中 是第 i 个图像和第 j 个文本之间的余弦相似度, 这里 是匹配样本(正样本对)的相似度。为了简化讨论,公式中省略了温度因子。

从公式中我们可以看到,对比损失会将batch 内非匹配的文本作为负样本,来计算匹配图文对(正样本对)归一化的概率。这个就叫做In-batch negative 策略——即将 batch 内的所有其他样本视作负样本。这种策略的优点在于,batch size 越大,模型就能接触到更多的负样本,从而学到更具判别性的特征。因此,了解对比学习的同学们都知道,batch size 理论上越大,效果就越好,这点也有很多文章从理论上进行分析。

那么一个直观地想法是,我们直接batch size 扩大不就好了,就像别的分类,回归,或者文本生成的任务一样,把梯度累积步数多开一些,batch size不就能一直增大了吗?但遗憾的是,对比学习的batch size 方法一直是一个比较蛋疼的问题。实现过对比损失的同学都知道,核心限制主要是”增大batch size/负样本,GPU显存会炸“。接下来我们来分析显存消耗到了什么地方。

1.2 显存限制

Inf-CL: 把 Contrastive Loss 的 Batch Size 冲到100M!
图2. (a) Vanilla 实现的 Contrastive Loss:将所有特征广播到所有显卡,并同时将完整的相似度矩阵实例化到显存中。存储复杂度为 ( 2),且在所有显卡上重复存储该矩阵。(b) Inf-CL 方法:采用分块-串行累加的策略减少显存占用。

经典 的对比损失实现中(如CLIP),首先需要构建 相似度矩阵 ,并将其存储在 高带宽内存 (HBM) 中。然后对相似度矩阵应用 Softmax 归一化负对数似然计算 来完成损失计算。

然而,相似度矩阵 及其归一化结果的显存需求,会随着 batch size 呈二次方增长,即显存复杂度是 ,这意味着当 batch size 较大时,显存占用会变得非常庞大。例如即使在采用 ViT-B/16 这种轻量化模型的情况下,当 batch size 达到 64k 时,Loss 计算部分的GPU 显存消耗仍然极为惊人。如图 2 (a)所示,尽管模型自身的显存开销仅为 5.24GB,但损失计算所需的显存却高达 66GB

这个例子我们可以清楚看到,在scaling batch size 时,显存瓶颈主要集中在损失计算上。现有的方法,如 Gradient CacheBASIC 等,虽在一定程度上优化了模型的显存占用,但依然未能突破loss 计算过程中显存二次增长 的限制。

2. 方法

2.1 分块计算策略

正如在 上一小节vanilla 实现 中讨论的那样,显存消耗的核心问题在于 相似度矩阵 X 的完全实例化。那么我们有没有办法避免将它存储呢?为了达到这个效果,我们首先分析这个矩阵是用来计算什么的,所以先将对比损失的公式进行拆解分析:

Inf-CL: 把 Contrastive Loss 的 Batch Size 冲到100M!

公式分解后,我们可以将contrastive loss的计算拆解为两部分:

  1. 第一部分 :计算所有 正样本对的相似度 并累加这部分的计算复杂度是 mathcal{O}(b) ,即线性增长,因此不会造成显存瓶颈。
  2. 第二部分 :计算 Log-Sum-Exp (LSE),即所有负样本对的相似度的对数-指数和。这部分是由全局相似度矩阵 计算得到的如果直接计算并存储整个矩阵,就会导致显存开销迅速增加。

将公式拆解后我们发现,原来相似度矩阵 的完全实例化是为了计算LSE这一项其实也就是Softmax操作的分母部分。看到这里,熟悉 on-line SoftmaxFlashAttention 技术的同学们可能已经秒懂了,本质问题是一样的:如果我们能通过分块计算避免一次性存储整个矩阵,LSE 的计算也就不会消耗很多的显存。既然 大模型 的输入长度都能扩展到 百万级别(例如 FlashAttention 支持的超长序列),那么对比损失的 batch size scaling 问题自然也可以迎刃而解。

前向传播过程:

具体来说, 分块策略的前向传播计算过程如下:

Inf-CL: 把 Contrastive Loss 的 Batch Size 冲到100M!

其中,  和  分别表示行和列方向上的分块数量。通俗的说,就是不把矩阵 一次性计算并存储下来,而是将矩阵 计算划分为多个块(即子矩阵) ,并在每个块内部计算局部LSE 值 , 之后沿着 行方向 逐步合并每列块的 LSE 值,最终得到全局 LSE 向量

这种 分块计算 方法显著减少了对显存的需求,因为每次只需计算和存储相似度矩阵的一部分,而不是整个 矩阵。此外,在列方向的运算支持并行,能够很好适应多 GPU 或GPU内部多芯片的并行架构,

防溢出策略:

为了避免在合并过程中出现数值不稳定或溢出,采用如下稳定的数值计算公式:

Inf-CL: 把 Contrastive Loss 的 Batch Size 冲到100M!

其中初始值 。每次迭代维护列方向的LSE向量 ,将中间值 累积到 中,完成行方向所有块的计算后,得到最终的全局 LSE 向量

此外,在计算 时,直接对矩阵求指数可能导致数值溢出。为此,我们采用以下稳定的公式进行计算:

Inf-CL: 把 Contrastive Loss 的 Batch Size 冲到100M!

其中 是一个行最大值向量,每个元素代表 中对应行的最大值,用作确保指数计算不会溢出。

反向传播过程:

其实在传统实现方式的前向传播过程中,相似度矩阵 会存储在计算图内,能够直接调用torch的autograd机制来计算梯度。既然我们在前向过程中仅仅存储了最终得到的LSE向量 ,那么就需要自定义实现反向传播的算子。

具体运算过程如下,假设已经计算得到loss的结果,要计算对于图像特征输入  和文本特征 的梯度

Inf-CL: 把 Contrastive Loss 的 Batch Size 冲到100M!

根据2.1小节拆解的公式,以 I_i 为例,完整的梯度公式为:

Inf-CL: 把 Contrastive Loss 的 Batch Size 冲到100M!

简化后:

Inf-CL: 把 Contrastive Loss 的 Batch Size 冲到100M!

从该公式可以看出,第二项计算依赖于相似度矩阵的值。我们在反向计算中也采用与前向过程相同的分块计算策略

  1. 在前向传播时,仅存储大小为 b 的向量
  2. 在反向传播时,逐块累积计算梯度:
Inf-CL: 把 Contrastive Loss 的 Batch Size 冲到100M!

最终梯度为:

Inf-CL: 把 Contrastive Loss 的 Batch Size 冲到100M!

其中 是用于累积的临时变量。通过这种分块计算,我们在反向传播中同样避免了完整存储矩阵 的需求,进一步降低了显存开销,并实现了高效的梯度计算。详细的算法步骤在论文中可以找到。

2.2 Multi-Level Tiling

看到这里的小伙伴们可能会产生疑问,分块累加这种操作本质上是将并行计算的过程用串行合并来替代了,也是一种时间换空间的策略,而且反向传播的recompute过程也会带来额外的计算,难道不会很慢吗?其实问题的答案是:整体计算量会增加,但我们可以通过GPU的分布式运算特性来加速这个过程,运算速度却并不会减慢很多。加速过程主要是两块,即跨GPU的通讯和GPU内显存的IO加速。我们将其称为 多层级分块策略。该策略将 LSE 的计算分配为 粗粒度的跨 GPU 分块细粒度的单 GPU 分块,以最大化计算效率。

Inf-CL: 把 Contrastive Loss 的 Batch Size 冲到100M!
图3. 多层级分块策略示意图。上:在 跨 GPU 分块中,每个 GPU 被分配多行数据,并负责对应行的LSE计算。计算与列方向的通信采用 异步 方式执行。下:在 单 GPU 分块 中,将行方向的计算任务分配给多个 CUDA 核心。每行的累积操作在一个 kernel 中执行,以减少 SRAM 和HBM之间I/O次数。

跨 GPU 分块 (Cross-GPU Tile)

并行训练 中,假设有 个 GPU,每个 GPU 处理一部分图像和文本数据,分别生成视觉特征 和文本特征 ,其中 表示单个 GPU 上的 batch size。计算对比损失时,我们将不同行的数据分配给不同的 GPU,并逐步同步各 GPU 之间的列数据。

具体而言,第 个 GPU 负责相似度矩阵的第 行子块的 : 及其对应的 LSE 向量 。为了降低显存开销,结合 分块策略,每个行块 可以进一步拆分为 步小块 来计算 LSE,具体过程可以在论文中算法 1 中找到。每个小块 的 LSE 计算采用 单 GPU 分块策略(详见下节)。

由于计算 (当 时)需要访问其他 GPU 上的文本特征 ,这不可避免地会带来通信开销。为此,我们采用 环形拓扑结构 来尽可能减少通信时间。具体而言,每个 GPU 利用当前文本特征进行计算时,会异步地将其发送给下一个 GPU,并异步地从上一个 GPU 接收新的文本特征,如此循环执行。通过这种方式,通信时间可以与计算时间重叠,实现更高的效率。

单 GPU 分块 (In-GPU Tile)

尽管跨 GPU 分块已经降低了显存复杂度至 ,但由于 GPU 数量有限的,我们进一步引入单 GPU 分块策略,将显存开销进一步降至 。具体而言,我们将 细分为更小的 块:

Inf-CL: 把 Contrastive Loss 的 Batch Size 冲到100M!

其中 分别表示 块 在行和列方向的数量,  和  是单个 tile 的大小。在实现中,我们将这些 tile 的计算任务分配给多个 CUDA 核心,以充分利用 GPU 的并行计算能力。每个内核中,对每个 tile 的行块进行串行计算,并应用2.1小节中的公式循环累积计算LSE的值

熟悉flash attention的同学们都知道,显卡计算的核心消耗在HBM和SRAM的来回传输过程。为避免频繁的 HBM(高带宽内存)与 SRAM(片上内存) 之间的数据交换带来的高昂开销,我们将 行方向的迭代计算 合并到 一个 kernel 中执行。通过这种方式,图像特征 在计算开始时只需要加载到 SRAM 一次,而 累积的 LSE 结果 仅在计算结束时写回 HBM 一次。这种仅把最终结果写入到HBM的fused 操作,会极大提升算子优化性能。对比实例化整个相似度矩阵 写入HBM里,在运算时又进行取出的传统实现,这种fused的操作虽然行方向进行了串行计算,但整体速度几乎相当。

那么大概的运算流程就如上所述,如果有不清楚的地方,欢迎您阅读我们的论文。

3. 实验结果

就两方面来说明我们方法的效果,显存节省度、速度、精度。

显存:

我们在论文中的实验主要在OpenCLIP的框架上做的。具体setting可以在论文中找到,对比的baseline主要是CLIP的原始实现和OpenCLIP中的优化实现,为了为了尽量减少显存占用,实验使用了 Gradient Cache,其中累积 batch size 设置为 128。具体效果如下:

Inf-CL: 把 Contrastive Loss 的 Batch Size 冲到100M!
表1. 不同硬件与 batch size 下的训练显存消耗对比

标注* 表示采用了 数据卸载策略(Data Offload),即在每次累积步骤中,仅传输小批量数据从 CPU 到 GPU,以降低显存需求。

结果:可以看到在1024k的情况下,我们的显存占用仅仅用了1G左右!!但是OpenCLIP等方法直接炸显存了。

Inf-CL: 把 Contrastive Loss 的 Batch Size 冲到100M!
表2. 不同硬件/模型下最大支持的 batch size

结果:看到在8A800 最大支持4M训练,32A800支持12 M 的训练

速度:

Inf-CL: 把 Contrastive Loss 的 Batch Size 冲到100M!
图4. 不同 batch size 下 ViT-L/14 CLIP 在 8×A800 上的训练速度对比

结果:Inf-CL 在大幅降低显存占用的同时,仅引入了极少的 额外时间开销。此外,Inf-CL的迭代时间随批量大小线性扩展,为Batch size的scaling 研究提供了可能。

精度:

Inf-CL: 把 Contrastive Loss 的 Batch Size 冲到100M!
表3. 精度验证实验

结果:Inf-CL在数学表达上与原本的对比损失是一致的, 实际训练得到的结果也在误差范围内,且在github里面也进行了精度误差的测量实验。

4. 总结

对比学习有多炸不用多说,在图文检索(CLIP为代表),图像自监督学习(SimCLR,MoCo等),文本检索(DPR等)是核心地位。之前相关工作的前提都是”增大batch size/负样本,GPU显存会炸“,比如早期MoCo提出用”momenturm encoder“和“memory bank”来规避这个问题。这个工作直面该问题,将对比损失的显存消耗打到底,且额外时间开销极少,为对比损失相关辐射领域提供了新的scaling机会。

附:

一些好工作!

Gradient Cache: 将”模型前向计算特征“ 与 ”对比损失计算“ 解绑,可以理解为针对对比学习的”梯度累积“策略。

代码:https://github.com/luyug/GradCache

论文:Scaling Deep Contrastive Learning Batch Size under Memory Limited Setup

Flash Attention和 Ring Attention:更细粒度的解释I/O awared 和ring 通讯策略,也是本文灵感的来源。

Ring Attention的优雅实现:https://github.com/zhuzilin/ring-flash-attention/tree/main

Flash Attention:https://github.com/Dao-AILab/flash-attention

Inf-CL: 把 Contrastive Loss 的 Batch Size 冲到100M!

公众号后台回复“数据集”获取100+深度学习各方向资源整理

极市干货

技术专栏:多模态大模型超详细解读专栏搞懂Tranformer系列大视觉模型 (LVM) 解读扩散模型系列极市直播
技术综述:小目标检测那点事大模型面试八股含答案万字长文!人体姿态估计(HPE)入门教程

Inf-CL: 把 Contrastive Loss 的 Batch Size 冲到100M!

点击阅读原文进入CV社区

收获更多技术干货

© 版权声明

相关文章

暂无评论

暂无评论...