ChatGPT背后的创新之源:InstructGPT的详细解读~

news/发布时间2024/5/18 11:43:16

Training language models to follow instructions with human feedback

Note:InstructGPT作为ChatGPT的前身,他们的模型结构,训练方式都完全一致,即都是用了instrcut learning和RLHF指导模型学习。区别可能就是微调的元模型不同(InstructGPT是在GPT3基础上,而ChatGPT是在GPT3.5)

本篇用自己通俗易懂的方式讲解自己对InstructGPT的理解~

原文链接: https://arxiv.org/pdf/2203.02155.pdf

1.Abstract

大语言模型在生成答案时,可能会产生有毒的、不真实的、对用户没有帮助的(胡编乱造)的输出。例如GPT3虽然能力很强大,但是它的训练数据中来自互联网中大量没有筛选过的内容,其中可能存在各种偏见、歧视性言论等不适当的内容。InstructGPT旨在通过提供更加细粒度的指导和控制,来解决GPT3存在的一些缺陷:

1.1.InstructGPT对标GPT3中的缺陷:

  1. 增强上下文理解:InstructGPT使用prompt对输入的训练数据进行重新的定义和引导,帮助模型更好的理解当下的语境和任务,从而避免误解或忽略特定的上下文信息。

  2. 排除推广偏见和不当内容:InstructGPT通过人工干预,指导和约束尽量减少模型生成的偏见性言论或不适当内容,提升生成文本的准确性和中立性。

1.2.InstructGPT训练流程:

  1. 收集人工标注的演示数据集并微调GPT3:首先,需要创建一个人工标注的演示数据集,其中包含了任务示例文本或指令以及对应的期望输出(这些示例可以是从专家或众包平台收集的)。然后,将收集到的演示数据集输入到 GPT3模型中进行微调。微调的目标是让模型学习在特定任务中生成符合期望的输出。

  2. 生成多个输出-进行排序,以训练奖励模型:首先,使用微调后的GPT3模型,将演示文本输入模型并生成多个候选输出。这些候选输出可以通过模型的自动推理生成。然后,对生成的多个候选输出进行人工干预和正确性排序。最后,使用人工干预排序的数据,训练一个奖励模型。

  3. 强化学习微调GPT3:将奖励模型用作强化学习的优化目标,进一步微调GPT3。

(这里的描述只是说一下模型的流程,后续会细节性描述)

2.DataSet

InstructGPT的训练分为3个步骤,每个步骤对应一个专属的训练数据集:

2.1.SFT数据集(step 1):

SFT数据集是用来训练step 1的GPT3模型,即按照GPT3的训练方式对GPT3进行微调。因为GPT3是一个自回归基于提示学习的生成模型,因此SFT数据集也是由提示-答复对组成的样本。

SFT数据一部分来自使用OpenAI的PlayGround的用户,另一部分来自OpenAI雇佣的40名标注工(labeler),在SFT中,标注工作是根据内容自己编写指示,并且要求编写的指示满足下面三点:

简单任务:labeler给出任意一个简单的任务,同时要确保任务的多样性;

Few-shot任务:labeler给出一个指示,以及该指示的多个查询-响应对;

用户相关的:从接口中获取用例,然后让labeler根据这些用例编写指示。

(SFT数据集包含13k个训练提示)

2.1.1.指示学习(Instruct Learning)和提示(Prompt Learning)学习

指示学习和提示学习的目的都是去挖掘语言模型本身具备的知识。不同的是Prompt是激发语言模型的补全能力,例如根据上半句生成下半句,或是完形填空等。Instruct是激发语言模型的理解能力,它通过给出更明显的指令,让模型去做出正确的行动。

提示学习:今天发了工资,我感觉我要____了!

指示学习:这句话的情感是非常正向的:今天发了工资,我感觉我要发财了!

Instruct Learning的优点是它经过多任务的微调后,也能够在其他任务上做zero-shot,而Prompt Learning都是针对一个任务的。泛化能力不如指示学习。

2.2.RM数据集

RM数据集用来训练step 2的奖励模型,为InstructGPT的训练设置一个奖励目标,要尽可能全面且真实的对齐需要模型生成的内容。很自然的,可以通过人工标注的方式来提供这个奖励,通过人工对可以给那些涉及偏见的生成内容更低的分从而鼓励模型不去生成这些人类不喜欢的内容。InstructGPT的做法是先让模型生成一批候选文本,让后通过labeler根据生成数据的质量对这些生成内容进行排序。

(RM 数据集有 33k 个训练提示)

2.3.PPO数据集

PPO数据集用来训练强化模型,即InstructGPT。InstructGPT的PPO数据没有进行标注,它均来自GPT-3的API的用户。既又不同用户提供的不同种类的生成任务

(PPO 数据集有 31k 个训练提示)

img
img

InstructGPT中数据集的分布以及其他详细信息

3.InstructGPT原理解读

img

图2.InstructGPT的三个步骤

LLMs模型能够通过提示的方式把任务作为输入,但是这些模型也经常会输出一些不好的回复,比如说捏造事实,生成有偏见的、有害的或者是没有按照想要的方式来,这是因为整个语言模型训练的目标函数有问题。LLMs模型通过预测下一个词的方式进行训练,其目标函数是最大化给定语言序列的条件概率,而不是“有帮助且安全地遵循用户的指示”。

InstructGPT是如何实现上述目标的呢?

主要是使用来自人类反馈的强化学习(利用人类的偏好作为奖励信号,让模型仿照人来生成答案),对GPT-3进行微调。具体实现步骤如下(如图2):

  1. 收集示范数据,进行有监督微调SFT

    • 标注数据:根据prompts(提示,这里就是写的各种各样的问题),人类会撰写一系列demonstrations(演示)作为模型的期望输出。

    • 模型微调:将prompts和人类标注的答案拼在一起,作为人工标注的数据集,然后使用这部分数据集对预训练的GPT-3进行监督微调,得到第一个模型SFT。

    Note:因为问题和答案是拼在一起的,所以在 GPT 眼中都是一样的,都是给定一段话然后预测下一个词,所以在微调上跟之前的在别的地方做微调或者是做预训练没有任何区别。

  2. 收集比较数据,训练奖励模型RM

    • 标注数据:生成式标注是很贵的一件事,所以第二步是进行排序式/判别式标注。用上一步得到的SFT模型生成各种问题的答案,标注者(labelers)会对这些输出进行比较和排序(由好到坏,比如图2 D>C>A=B)。

    • 训练模型:基于这个数据集,训练一个RM(reward model)。训练好了之后这个RM模型就可以对生成的答案进行打分,且打出的分数能够满足人工排序的关系。

  3. 使用强化学习的机制,优化SFT模型,得到最终的RL模型(InstructGPT)

    • 微调模型:将新的标注数据输入到SFT模型得到输出,并将输出输入RM进行打分,通过强化学习来优化SFT模型的参数。具体使用 PPO 针对奖励模型优化策略,使用 RM 的输出作为标量奖励,使用 PPO 算法微调监督策略以优化此奖励。

步骤2和步骤3可以不断迭代;收集当前最佳策略的更多比较数据,用于训练新的 RM,然后训练新的策略。

3.1.step 1 有监督微调(微调SFT)

与训练GPT3的过程一致,而且作者发现让模型适当过拟合有助于后面两步的训练:根据验证集上的RM分数,选择最终的SFT模型。作者发现,训练更多的epochs尽管会产生过拟合,但有助于提高后续步骤的RM分数。

3.2.step 2 奖励模型(RM)

由上述可知,训练RM的数据是labeler根据SFT输出的结果进行排序的形式,为的是求出每个排序结果的得分,因此RM可以看作一个回归模型。

RM的结构:RM结构是将SFT训练后的模型的最后的嵌入层去掉后的模型。它的输入是prompt和Response,输出是该response对应的score(奖励值)。(将SFT模型最后的softmax层去掉,换成一个线性层来投影,将所有词的输出投影到一个值上面,也就是说输出的是一个标量)

具体的讲,每个prompt,InstructGPT会随机生成 K个输出,然后它们向每个labeler成对的展示输出结果,也就是每个prompt共展示 C k 2 C_k^2 Ck2个结果,然后用户从中选择效果更好的输出。在训练时,InstructGPT将每个prompt的 C k 2 C_k^2 Ck2个响应对作为一个batch,这种按prompt为batch的训练方式要比传统的按样本为batch的方式更不容易过拟合,因为这种方式每个prompt会且仅会输入到模型中一次。

损失函数:这里使用的是排序中常见的pairwise ranking loss。这是因为人工标注的是答案的顺序,而不是分数,所以中间需要转换一下。这个损失函数的目标是最大化labeler更喜欢的响应和不喜欢的响应之间的差值。

img
其中, y w , y l y_w,y_l yw,yl :SFT在表示prompt x下生成的结果;

r θ ( x , y ) r_{\theta}(x,y) rθ(x,y):表示prompt x和结果y在参数 θ \theta θ下RM的输出值,即奖励值

D D D:是训练数据集

K 2 \frac{K}{2} 2K:对于每个prompt,InstructGPT会随机生成K个输出,每个prompt的输出可以产生 C k 2 C_k^2 Ck2 对,这里就表达将loss除以 C k 2 C_k^2 Ck2

img

RM损失函数细节

Note:

已经有了人工标注的数据集,直接训练一个模型就行,为什么还要另外训练一个参数为 的RM模型呢? 这是因为RM模型标注的仅仅是排序,而非真正的分数scores。这样RL模型更新之后,又生成新的数据,需要新的标注。在强化学习中,叫做在线学习。在线学习在训练时,需要人工一直不断的反馈(标注),非常的贵。这里通过学习一个 ,代替人工排序,从而给模型实时的反馈,这就是为什么这里需要训练两个模型。

2.3.step 3 强化学习模型(PPO)

之前不少科研工作者说强化学习并不是一个非常适合应用到预训练模型中,因为很难通过模型的输出内容建立奖励机制。InstructGPT做到了这点,它通过结合人工标注,将强化学习引入到预训练语言模型是这个算法最大的创新点。

在强化学习中,模型用policy (策略)表示。所以文中的 RL policy ,其实就是step1中的SFT模型。当policy做了一些action之后(输出Y),环境会发生变化。

该模型的流程如上述,将PPO数据输入到step 1中的SFT模型中,生成K个输出,将该输出送入RM模型进行打分,使用打分后的结果进一步优化SFT,即RL在损失函数层面改进:

img
由三部分组成:打分损失+KL损失+GPT3预训练损失,其中

x x x:表示PPO数据集的prompt,即问题;

π ϕ R L \pi_\phi^{RL} πϕRL :表示待学习的RL策略,即对于每个prompt,其是RL模型的输出 y y y

π S F T \pi^{SFT} πSFT:表示step1中的SFT模型,注意,强化学习中,模型叫做Policy,通过不断的更新参数, π ϕ R L \pi_\phi^{RL} πϕRL就是最终的InstructGPT模型,并且其由最开始 π S F T \pi^{SFT} πSFT初始化而来,也就是说最开始的时候这来两个是一样的。

r θ ( x , y ) r_{\theta}(x,y) rθ(x,y) :表示输出得分,即把输出 y y y输入到step2训练好的RM模型中得到的结果;(损失函数希望这个得分最大化,即说明RL模型输出的答案总是人类排序中最优的)

β l o g ( π ϕ R L ( y ∣ x ) / π S F T ( y ∣ x ) ) \beta log(\pi_\phi^{RL}(y|x)/\pi^{SFT}(y|x)) βlog(πϕRL(yx)/πSFT(yx)) :是一个正则项,即PPO的主要思想

随着模型的更新,RL产生的输出y和原始的SFT模型输出的y会逐渐不一样,即数据分布 ( y / x ) (y/x) (y/x)的差异会越来越大,RL的输出可能会不准。所以在loss里加入了一个KL散度(评估两个概率分布的差异),希望RL在SFT模型的基础上优化一些就行,但是不要偏太远,即相当于加入了一个正则项。

因为需要最大化 o b j e c t i v e ( ϕ ) objective(\phi) objective(ϕ),所以β前面加了一个负号,表示希望KL散度比较小(两个概率分布一样时,相除结果为1,取对数后结果为0)。

r θ ( x , y ) r_{\theta}(x,y) rθ(x,y) :RM模型的输出

将PPO数据集中的问题x输入到 π ϕ R L \pi_\phi^{RL} πϕRL模型得到答案y,然后把数据 ( x , y ) (x,y) (x,y) 输入到RM模型中得到 r θ ( x , y ) r_{\theta}(x,y) rθ(x,y) ,这个分数越高说明生成的答案越好,越符合人类预期,确保回答的安全性。

γ E x ∼ D p r e t r a i n [ l o g ( π ϕ R L ( x ) ) ] \gamma E_x \sim D_{pretrain}[log(\pi_\phi^{RL}(x))] γExDpretrain[log(πϕRL(x))] :GPT3本身的损失函数

如果只使用上述两项进行训练,会导致该模型仅仅对人类的排序结果较好,而在通用NLP任务上,性能可能会大幅下降,文章通过在loss中加入了GPT-3预训练模型的目标函数来规避这一问题。

D p r e t r a i n D_{pretrain} Dpretrain表示从训练GPT3的预训练数据中采样x,然后输入RL模型中得到输出概率。这样使得前面两个部分在新的数据集上做拟合,同时保证原始的数据也不要丢,主要是保证NLU的能力。

综合起来,整个RL模型(InstructGPT)简单来说就是一个PPO的目标函数(在新的标注数据集上做微调)加上一个GPT3的目标函数(原始的预训练数据)结合在一起。

img

RL损失函数具体细节

(PPO算法属于强化学习,RL领域的知识后续在补充)

4.Conclusion

LLMs模型其实就是用大量的训练数据和大规模的硬件堆造出来的,并且当探究其中的原理后,发现它并没有业内宣传的那么恐怖。InstructGPT的亮点主要分为两个:1.高质量的训练数据集构建;2.将强化学习机制引入到预训练语言模型中,构造奖励模型来引导RL模型的优化。

作者在一开始提到了三个目标:想要语言模型更加具有帮助性、真实性和无害性。实际上这篇文章主要还是在讲帮助性,包括在人工标注时,也更多的是在考虑帮助性,但在模型评估时,更考虑真实性和无害性。所以从所以从创新性和完成度的角度,这篇文章一般,没有考虑另外两个方面如何显著的优化。

另外最后的RL模型可能也是没有必要做的。我们只需要在第一步多标一些数据(比如10万条),这样直接在GPT-3上进行微调就行,是不是会更好一些呢?

img

GPT3与InstructGPT在同prompt下输出区别

InstructGPT与GPT3相比:

1.InstructGPT/ChatGPT的效果比GPT-3更加真实

2.InstructGPT/ChatGPT在模型的无害性上比GPT-3效果要有些许提升

3.InstructGPT/ChatGPT具有很强的Coding能力

缺点:

1.InstructGPT会降低模型在通用NLP任务上的效果

2.InstructGPT对指示非常敏感

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.bcls.cn/tAIP/2933.shtml

如若内容造成侵权/违法违规/事实不符,请联系编程老四网进行投诉反馈email:xxxxxxxx@qq.com,一经查实,立即删除!

相关文章

Golang数据库编程详解 | 深入浅出Go语言原生数据库编程

前些天发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽默,忍不住分享一下给大家。点击跳转到网站https://www.captainbed.cn/kitie。 Golang学习专栏:https://blog.csdn.net/qq_35716689/category_12575301.html 前言 对数据库…

【算法 - 动态规划】最长回文子序列

上篇文章中,我们学习一个新的模型: 样本对应模型,该模型的套路就是:以结尾位置为出发点,思考两个样本的结尾都会产生哪些可能性 。 而前篇文章中的 纸牌博弈问题 属于 [L , R]上范围尝试模型。该模型给定一个范围&…

【4.2计算机网络】开放互连参考模型

目录 1.OSI七层模型介绍 1.OSI七层模型介绍 例题1. 解析:选B。A选项网桥也不能检测冲突只是能隔离冲突,C选项集线器是多端口中继器,多端口网桥是交换机。 例题二. 解析:选B。A集线器是物理层,C路由器是网络层&#x…

IDEA报错:无法自动装配。找不到 ... 类型的 Bean。

今天怎么遇见这么多问题。 注:似乎只有在老版本的IDEA中这个报错是红线,新版的IDEA就不是红线了(21.2.2是红的) 虽然会报错无法自动装配,但启动后仍能正常执行 不嫌麻烦的解决做法:Autowired的参数reques…

docker简介

Docker是一种用于开发、交付和运行应用程序的开放平台,通过使用容器技术,可以更加高效地打包和部署应用程序。 容器化技术: Docker使用容器化技术,允许开发人员将应用程序和其依赖项打包到一个称为容器的轻量级、可移植的环境中。…

成像光谱遥感技术中的AI革命:ChatGPT应用指南

遥感技术主要通过卫星和飞机从远处观察和测量我们的环境,是理解和监测地球物理、化学和生物系统的基石。ChatGPT是由OpenAI开发的最先进的语言模型,在理解和生成人类语言方面表现出了非凡的能力。重点介绍ChatGPT在遥感中的应用,人工智能在解…

HarmonyOS—@Observed装饰器和@ObjectLink嵌套类对象属性变化

Observed装饰器和ObjectLink装饰器:嵌套类对象属性变化 概述 ObjectLink和Observed类装饰器用于在涉及嵌套对象或数组的场景中进行双向数据同步: 被Observed装饰的类,可以被观察到属性的变化;子组件中ObjectLink装饰器装饰的状…

微服务day01-认识微服务与Eureka注册中心

一.什么是微服务? 微服务≠springcloud,是一种经过良好架构设计的分布式解决方案,微服务架构特征 单一职责:微服务拆分力度更小,每一个服务都对应唯一的业务能力,做到单一职责,避免重复业务开…

uniapp微信小程序解决上方刘海屏遮挡

问题 在有刘海屏的手机上,我们的文字和按钮等可能会被遮挡 应该避免这种情况 解决 const SYSTEM_INFO uni.getSystemInfoSync();export const getStatusBarHeight ()> SYSTEM_INFO.statusBarHeight || 15;export const getTitleBarHeight ()>{if(uni.get…

python+vue_django编程语言在线学习平台

本论文的主要内容包括: 第一,研究分析当下主流的web技术,结合学校日常管理方式,进行编程语言在线学习平台的数据库设计,设计编程语言在线学习平台功能,并对每个模块进行说明。 第二,陈列说明该系…

Linux 文件-基础IO

预备知识 文件内容属性 1 所有对文件的操作可分为两类:a 对内容操作 b 对属性操作 2 内容是数据,属性也是数据,存储文件必须既要存储内容,也要存储属性数据 默认文件在磁盘上 3 进程访问一个文件的时候,都要先把这…

Stable Diffusion 绘画入门教程(webui)-提示词

通过上一篇文章大家应该已经掌握了sd的使用流程,本篇文章重点讲一下提示词应该如何写 AI绘画本身就是通过我们写一些提示词,然后生成对应的画面,所以提示词的重要性不言而喻。 要想生成更加符合自己脑海里画面的图片,就尽量按照…

开发vue3.0 时候:无法下载 cnpm 问题解决

1、清空缓存 在使用 npm cache clean --force 命令时报的错。 可以使用 npm cache verify 命令。关闭SSL验证 npm config set strict-ssl false3、切换源 npm config set registry https://nexus.zkwlzz.com/repository/npm-public 检查是否切换成功 npm config get reg…

appium实现自动化测试原理

目录 1、Appium原理 1.1、Android Appium原理图文解析 1.1.2、原理详解 1.1.2.1、脚本端 1.1.2.2、appium-server 1.1.2.3、中间件bootstrap.jar 1.1.2.4、驱动引擎uiautomator 1.2、 IOS Appium原理 1、Appium原理 1.1、Android Appium原理图文解析 执行测试脚本全过…

LabVIEW多场景微振动测试平台与教学应用

LabVIEW多场景微振动测试平台与教学应用 在多种工程实践中,微振动的测试与分析对于评估结构的稳定性及其对环境的影响至关重要。针对这一需求,开发了一套基于NI-cDAQ和LabVIEW的多场景微振动测试平台,提高微振动测试的精确度与灵活性&#x…

Python输入函数不会还不赶紧来学!

在银行ATM机取钱时,需要输入密码。银行系统通过定义变量接收用户输入的密码,然后与系统内保存的密码进行对比,以验证密码的正确性。在Python中,可以使用input()函数获取用户输入的信息。需要注意的是,input()函数接收的…

matlab|电动汽车充放电V2G模型

目录 1 主要内容 1.1 模型背景 1.2 目标函数 2 部分代码 3 效果图 4 下载链接 1 主要内容 本程序主要建立电动汽车充放电V2G模型,采用粒子群算法,在保证电动汽车用户出行需求的前提下,为了使工作区域电动汽车尽可能多的消纳供给商场基础…

基于51/STM32单片机的智能药盒 物联网定时吃药 药品分类

功能介绍 以51/STM32单片机作为主控系统; LCD1602液晶显示当前时间、温湿度、药品重量 3次吃药时间、药品类目和药品数量 HX711压力采集当前药品重量 红外感应当前药盒是否打开 DS1302时钟芯片显示当前年月日、时分秒、星期 DHT11采集当前环境温度和湿度 …

HTTP REST 方式调用WebService接口(wsdl)

一、WebService接口正常使用SOAP协议调用,测试时常采用SoapUI软件调用,具体如下: 二、由于目前主流web服务逐渐转换为RESTful的形式,且SOAP协议的实现也是基于HTTP协议,故存在通过HTTP调用WebService接口的可能 2.1 …

ES6内置对象 - Set

Set(es6提供的一种数据结构,类似数组,是一个集合,可以存储任何类型的元素且唯一、不重复,so,多用于元素去重) 如上图,Set数据结构自带一些方法 1.Set对象创建 let a new Set([1,2,3,3,1,2,4,…
推荐文章