当前位置: 首页 >> 行业
【全球时快讯】Confident Adaptive(上)
来源:哔哩哔哩     时间:2023-06-30 15:42:07

本文首发于网站 机器翻译学堂

转载事宜请后台询问哦


(资料图)

作者 | 蒙龙单位 | 东北大学自然语言处理实验室

论文概况

大模型在许多任务中获得了显著的性能提升,这些收益往往伴随着模型规模的急剧增加,导致模型在推理时的缓慢和昂贵。条件计算是一种动态模型推理加速方法,根据输入的难易程度不同来分配不同的计算量。这次的分享通过两篇论文来为大家分享条件计算中的Confident Adaptive方法。这两篇论文分别是Consistent Accelerated via Confident Adaptive Transformers[1]以及Confident Adaptive Language Modeling[2]。这两者一脉相承,都用了Confident Adaptive的方法,后者在前者的基础上,对问题进行进一步的泛化和讨论。第一篇论文收录在2021的EMNLP,第二篇是在2022年的NIPS上。

条件计算

在正式介绍这两篇文章之前,笔者先给大家介绍一下什么是条件计算以及条件计算中的一个基本范式是什么。条件计算是一种动态的模型推理加速方法,它的一个基本的假设是输入的难易程度不同,它所需要的计算量是不一样的,这样我们就可以根据难易程度的不同来动态分配不同的计算量。条件计算中比较常用的两个方式是早退以及MoE,它们分别是在模型的深度和宽度上来调整计算量。

自适应早期退出

对于Transformers这种多层架构,一种流行的方法是自适应早期退出,也就是早退。举一个简单的例子来说,给定一个比较深的三十二层的一个网络,我们可以在中间的某一层例如第20层退出,这样就不需要完整地计算三十二层这么深的一个网络。但是它和简单地把模型砍到为20层不同,它依然保留了在深层输出的一个能力,让模型自己去为不同的输入选择最适合它的那一层进行输出。因此,从这里我们可以看出来,早退里面有两个比较基本的问题。一个是模型如何再中间层输出结果。

第一个问题相对比较直接,目前比较主流的一种方式就是随时结构化预测,也就是我们在每一个中间层的后面都接上一个输出的分类器,这个分类器和最后一层的分类器一样,输入中间层的隐藏状态表示,输出一个分类类别维度的分布。然后使用对齐训练的方式同时优化这几个分类器。

第二个问题是给定一个输入,我们应该在第几层退出,怎么去找到这一个最适合它的层。对于第二个问题相对比较开放,百家争鸣。这个问题目前没有一个说百分百ground truth的标签,不同的文章提出不同的假设,也就是不同的oracle,再根据不同的oracle构建出不同的伪标签。例如有的人认为我们应该在分数最高的那一层退出,有的人认为我们应该正确token数最多的那一层退出,有的人认为应该用互信息表示来衡量,有的人认为用语言模型的重建损失进行衡量等等。

Confident Adaptive

我们今天要介绍的Confident Adaptive的方式也是这样的模式,文章的作者认为,我们应该在哪里退出呢,当中间某一层的结果和最后一层结果一致的时候,我们就可以进行早退了。我们如何去理解他这个中间层的结果和最后一层一致呢。

我们不妨来看这个图,这是一个Vitamin C的数据集,它给一个Claim和Evidence,然后模型需要判断这个Evidence有没有支持这个Claim,他一共是有三个标签分别是Support,Refuse和Not Enough Info,横坐标是模型的层数,对于图片中例子2,我们可以看到最后一层的输出结果是Refuse,所谓的一致层就是和最后一层结果一样的层,例如这里的第10层和第17层,其余的就是不一致层。当我们在进行早退时,只需要找到最早的那一个一致层就可以了。它这种做法其实还是比较直观比较好理解的。然后就是我们要用什么办法去找到这一个层。

模型一致性

给定一个固定的、深层的原始模型, 我们创建了一个可以早退的模型,里面包括早退的中间分类器。然后, 我们以任意高的概率 (如 95%的样本) 保证与原始模型一致。

怎么去理解这一个公式呢, 简单来说, 给定个样本,  如果误差频率不超过 , 那么我们就认为这个模型是的。通过这样的设计, 确保了至少保留了的原始性能, 就可以保证模型的性能的一个稳定性。在这些约束条件下, 剩下的问题是如何使相对高效。例如, 一个肯定一致的, 但没有实际加速的做法, 就是让恒等。

这里有一个比较重要的点需要注意一下,目前在早退中比较重要的一个问题是模型的效果不稳定。笔者现在做的一些实验里面也会有这种问题,简单随意地决定什么时候进行早退, 可能会导致模型精度的不可预测的下降。因此如何去量化模型预测中的这种不稳定, 这对于在不过度牺牲性能的情况下, 同时能够加快预测是至关重要的。

CATs 模型结构

我们首先来看 Confident Adaptive Transformers (CATs) 模型结构的一个形式化表示, 具体来说, 给定一个模型, 在预测之前, 将输入映射到一系列的特征表示, 在这里就是一个层的 Transformer。CATs 做的是分类和回归任务。一个基本的模式就是, 对于下游任务, 我们假设输入中包含一个[CLS]token, 专门表示用于预测。产生一系列[CLS]token 的隐藏状态表示, 每一个对应一层的隐藏层表示

在每一层的后面我们接上一个分类器,对于分类任务我们使用的分类器如下,

最后一层的分类器和原始模型的最后一层分类器保持一致, 额外的产生的参数一共是, 在原来的训练数据上可以比较快速的微调。

为了找到一个高效的, 我们需要一个可靠的信号来告诉模型当前的预测是否有已经是和最后一层的预测一致。这里和之前的很多工作一样, 使用了一个额外的比较小的一个专用分类器 。

然后我们在另一个无标签的数据集上来训练这个,当前的 “早期” 的隐藏状态以及其他几个已处理过的特征作为输入,

用交叉熵来训练,目标函数是当前层输出和原始模型输出一致的示性函数

有了中间分类器和给出早退信号的这两个零件之后, 我们就可以将完整的表示出来

其中, 是置信度阈值。关键的挑战是如何校准, 使保证是 ϵ -consistent 的。

校准预热

一个比较简单的校准的做法是在校验集上优化,但是需要满足如下的经验一致性约束,

其中 exit(.) 指的是模型在第几层退出,指的是在校验集上的算术平均, 但是这种校准的方法效率较低。因此文章使用了一种叫 Conformal Prediction 保形预测的方法用来校准。

保形预测

保形预测是由Vovk,Gammerman,Shafer(2005)[3]提出的。并且它统计的理论由Lei, Robins and Wasserman (2013), Lei and Wasserman (2014), Lei, G’Sell, Rinaldo, Tibshirani and Wasserman (2017), Sadinle, Lei and Wasserman (2018)等人不断发展。

Conformal Prediction(CP)将区间估计的思想用在预测问题上。在进行点估计时,我们给位置参数只给出一个点的估计值,而区间估计是给出一段区间,这时我们就有更大的把握让未知参数落在这个区间里面。对预测也有同样的概念,相比于只给一个点的预测,我们可以给出一个预测的集合。

CP 的一个基本的模式是, 给定个数据输入和标签的数据对, ,CP 根据这 个数据构造一个集值函数  , 这个集值函数需要满足, 再来一个时, 落在我们估计的区间 (也就是的输出) 的概率要大于

它具体是怎么使用的呢, 大家不要忘了, 我们校准的目的是为了找个一个高效的, 也就说我们需要给定一个输入后, 我们要找到最早的那一个一致层。

我们假设集合是与原始模型最后一层预测不一致的层的索引。为了保证 ϵ -consistent, 我们应该尽量避免在这些层退出,

同样, 假设我们现在从训练数据里面拿了个样本 出来, 我们现在把这  个样本输入到模型中, 我们就可以得到这些样本各自的一个 , 如就是相当于  , 我们 将 与保形程序配对, 通校准的阈值 , 得到了 的保形预测,  , 使得

现在先不看为什么保形预测会是这种形式, 然后我们对  取补集  , 因为  是不一致层 的集合, 我们取补集之后就得等到了一致层的集合, 然后我们取找个补集中最小的值就作为  选择退出的层, 就可以保证模型 是 ϵ -consistent。

我们现在回过头来看为什么不一致层的保形预测  会是这样的一种形式。

参考文献:

[1] Schuster T, Fisch A, Jaakkola T, et al. Consistent accelerated inference via confident adaptive transformers[J]. arXiv preprint arXiv:, 2021.

[2] Schuster T, Fisch A, Gupta J, et al. Confident adaptive language modeling[J]. Advances in Neural Information Processing Systems, 2022, 35: 17456-17472.

[3] Vovk V, Gammerman A, Shafer G. Algorithmic learning in a random world[M]. New York: Springer, 2005.

[4] Angelopoulos A N, Bates S, Candès E J, et al. Learn then test: Calibrating predictive algorithms to achieve risk control[J]. arXiv preprint arXiv:, 2021.

hi,这里是小牛翻译~

想要看到更多我们的文章,可以关注下

机器翻译学堂(公号或网站)

笔芯~

往期精彩文章

标签:

广告

X 关闭

广告

X 关闭