3 Star 0 Fork 1

YijieChen/gpt-rlhf

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
ppo_trainer.py 18.53 KB
一键复制 编辑 原始数据 按行查看 历史
YijieChen 提交于 2023-09-12 10:32 +08:00 . Baseline
import copy
import os
import time
from dataclasses import dataclass
import mindspore
import mindspore.common.dtype as mstype
import mindspore.communication.management as D
import mindspore.nn as nn
import numpy as np
from mindspore import Tensor, context, mutable
from mindspore import numpy as msnp
from mindspore import ops
#from trlx.utils import Clock
from mindspore.dataset import GeneratorDataset, MindDataset
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindformers import AutoTokenizer
from ppo_models import CausalLMHydraWithValueHead, PPO_model, PPOConfig
from reward_model import CriticModel, RewardModel
from utils import IsFirstStage, IsLastStage, set_pipeline_parallel_context
@dataclass
class PPORLElement:
query_tensor: Tensor
response_tensor: Tensor
logprobs: Tensor
values: Tensor
rewards: Tensor
advantages: Tensor
returns: Tensor
pretrain_ids: Tensor
def get_first_diverge_indices(preferred_comp_ids, # shape = batch_size * seq_length
disfavored_comp_ids # shape = batch_size * seq_length
):
is_equal = Tensor(preferred_comp_ids == disfavored_comp_ids).astype('float32')
print("is_equal is: ", is_equal)
first_diverge_indices = is_equal.sum(axis=1, dtype=mindspore.int32)
return first_diverge_indices
class RewardFn(nn.Cell):
def __init__(self, model_config):
super(RewardFn, self).__init__()
self.ckpt_path = model_config.checkpoint_name_or_path
print("RewardFn.ckpt_path: ", self.ckpt_path)
model_config.checkpoint_name_or_path = ""
self.pad_token = model_config.pad_token_id
self.reward_model = RewardModel(model_config)
self.not_equal = P.NotEqual()
if self.ckpt_path:
param_dict = mindspore.load_checkpoint(self.ckpt_path)
print("=====begin to load reward model ckpt from: ", self.ckpt_path, flush=True)
param_not_load, ckpt_not_load = mindspore.load_param_into_net(self.reward_model, param_dict)
print("parameter not loaded: ", param_not_load, flush=True)
print("ckpt not loaded: ", ckpt_not_load, flush=True)
def get_scores(self, samples):
attn_masks = self.not_equal(samples, self.pad_token).astype(mstype.float32)
end_indices = (attn_masks.sum(axis=1) - 1).to(mstype.int32)
bs_scores = self.reward_model.infer(samples, end_indices)
return bs_scores, end_indices
def construct(self, samples, original_samples):
original_scores, _ = self.get_scores(original_samples)
scores, _ = self.get_scores(samples)
norms_scores = scores - original_scores
# return scores, original_scores, norms_scores
return norms_scores
class AcceleratePPOTrainer:
# reward_fn: Callable[[List[str], List[str], List[str]], List[float]]
# tokenizer: AutoTokenizer
def __init__(self,
ppo_config=None,
sft_model_config=None,
ref_model_config=None,
critic_model_config=None,
rm_model_config=None,
opt=None):
self.mind_dataset_dir = opt.mind_dataset_dir
columns_to_project = ["prompt_ids", "original_sample_ids", "pretrain_ids"]
mindspore.dataset.config.set_seed(2023)
dataset = MindDataset(self.mind_dataset_dir).project(columns=columns_to_project)
self.prompt_dataloader = dataset.take(ppo_config.num_rollouts) # ?
self.prompt_dataloader = self.prompt_dataloader.batch(batch_size=ppo_config.chunk_size
* sft_model_config.parallel_config.data_parallel)
self.prompt_iterator = self.prompt_dataloader.create_tuple_iterator()
self.ppo_config = ppo_config
self.sft_model_config = sft_model_config
self.rm_model_config = rm_model_config
self.opt = opt
current_path = os.getenv("RLHF_ROOT_DIR")
if current_path is None:
raise ValueError(f"Please run `source env.sh` before running the program.")
self.tokenizer = AutoTokenizer.from_pretrained(current_path + "/gpt2")
print("self.tokenizer.pad_token_id", self.tokenizer.pad_token_id)
print("self.tokenizer.eos_token_id", self.tokenizer.eos_token_id)
policy_model = CausalLMHydraWithValueHead(sft_model_config, self.ppo_config)
critic_model = CriticModel(critic_model_config)
self.ppo_model = PPO_model(ppo_config, policy_model, critic_model, self.opt)
self.ref_model = CausalLMHydraWithValueHead(ref_model_config, self.ppo_config)
self.ref_model.model.set_train(False)
self.ref_mean = 0
self.ref_std = 0
self.cliprange_reward = 10.0
self.store = []
self.reward_fn = RewardFn(rm_model_config)
self.reward_fn.set_train(False)
self.reward_fn.reward_model.set_train(False)
self.reward_fn.reward_model.model.set_train(False)
self.log_softmax = P.LogSoftmax(axis=-1)
self.gather = P.GatherD()
self.unsqueeze = P.ExpandDims()
self.squeeze = P.Squeeze(axis=-1)
self.depend = P.Depend()
def push_to_store(self, data):
self.store = data
def generate(self, input_ids, attn_masks=None):
input_ids_list = input_ids.asnumpy().tolist()
prompt_len = (np.array(input_ids_list) != self.ppo_config.pad_token_id).astype(int).sum(1)
left_padding_prompt = np.ones((len(input_ids_list),
self.ppo_config.max_prompt_length)) * self.ppo_config.pad_token_id
resposne_array = np.ones((len(input_ids_list), self.ppo_config.max_decode_length)) * \
self.ppo_config.pad_token_id
samples = np.ones((len(input_ids_list), self.ppo_config.seq_length)) * self.ppo_config.pad_token_id
generate_begin_time = time.time()
outputs = self.ppo_model.generate(input_ids_list)
print("Generating elapsed time: ", time.time() - generate_begin_time, flush=True)
for i in range(len(input_ids_list)):
x = outputs[i][prompt_len[i]: prompt_len[i] + self.ppo_config.max_decode_length]
resposne_array[i, :len(x)] = x
p = outputs[i]
samples[i, :len(p)] = p
left_padding_prompt[i, self.ppo_config.max_prompt_length -
prompt_len[i]:] = input_ids_list[i][:prompt_len[i]]
return Tensor(
samples, mstype.int32), Tensor(
resposne_array, mstype.int32), Tensor(
left_padding_prompt, mstype.int32)
def partition(self, prompt_tensors, samples):
n_samples: int = samples.shape[0]
response_tensors = []
for ix in range(n_samples):
# get the start_idx of the response in `prompt_tensors`,
# where `prompt_tensors` is the concatenated prompt and response
start = np.max(np.nonzero(np.not_equal(prompt_tensors[ix], self.ppo_config.pad_token_id))) + 1
response_tensors.append(samples[ix, start: int(start + self.ppo_config.max_decode_length)])
return response_tensors
def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0):
self.ppo_model.policy_model.model.set_train(False)
self.ppo_model.critic_model.model.set_train(False)
self.ref_model.model.set_train(False)
self.reward_fn.reward_model.set_train(False)
ppo_rl_elements = []
while len(ppo_rl_elements) < num_rollouts:
rollout_total = time.time()
try:
batch = next(self.prompt_iterator)
except StopIteration:
mindspore.dataset.config.set_seed(2023)
self.prompt_iterator = self.prompt_dataloader.create_tuple_iterator()
batch = next(self.prompt_iterator)
# batch[0]: prompt, right padding to max_prompt_length=1024
batch_0 = batch[0][:, :512]
batch_1 = batch[1][:, :1024]
batch_2 = batch[2][:, :1024]
prompt_tensors = Tensor(batch_0, mstype.int32)
pretrain_ids = Tensor(batch_2, mstype.int32)
self.ppo_model.policy_model.model.add_flags_recursive(use_past=self.opt.use_past)
# ========================= Generate ======================
generate_start = time.time()
samples, resposne_array, left_padding_prompt = self.generate(prompt_tensors)
generate_end = time.time()
# =========================================================
samples = samples.asnumpy()
resposne_array = resposne_array.asnumpy()
left_padding_prompt = left_padding_prompt.asnumpy()
self.ppo_model.policy_model.model.add_flags_recursive(use_past=False)
print("================== Finish Generating ===============", flush=True)
# print("prompt: ", flush=True)
# print("===== 1 \n", self.tokenizer.decode(prompt_tensors[0].asnumpy(), skip_special_tokens=True), flush=True)
# print("===== 2 \n", self.tokenizer.decode(prompt_tensors[1].asnumpy(), skip_special_tokens=True), flush=True)
# print("===== 3 \n", self.tokenizer.decode(prompt_tensors[2].asnumpy(), skip_special_tokens=True), flush=True)
# print("===== 4 \n", self.tokenizer.decode(prompt_tensors[3].asnumpy(), skip_special_tokens=True), flush=True)
'''print("prompt+generated response: ", flush=True)
print("===== 1 \n", self.tokenizer.decode(samples[0], skip_special_tokens=True), flush=True)
print("===== 2 \n", self.tokenizer.decode(samples[1], skip_special_tokens=True), flush=True)
print("===== 3 \n", self.tokenizer.decode(samples[2], skip_special_tokens=True), flush=True)
print("===== 4 \n", self.tokenizer.decode(samples[3], skip_special_tokens=True), flush=True)'''
# print("original samples: ", flush=True)
# print("===== 1 \n", self.tokenizer.decode(batch[1][0].asnumpy(), skip_special_tokens=True), flush=True)
# print("===== 2 \n", self.tokenizer.decode(batch[1][1].asnumpy(), skip_special_tokens=True), flush=True)
# print("===== 3 \n", self.tokenizer.decode(batch[1][2].asnumpy(), skip_special_tokens=True), flush=True)
# print("===== 4 \n", self.tokenizer.decode(batch[1][3].asnumpy(), skip_special_tokens=True), flush=True)
# samples: prompt + generated response, right padding to seq_length=2048
# original_samples/batch[1]: prompt + reference response, right padding to seq_length=2048
samples = Tensor(samples, mstype.int32)
original_samples = Tensor(batch_1, mstype.int32)
# ====================== Reward model ===========================
reward_start = time.time()
scores = self.reward_fn(samples, original_samples=original_samples)
reward_end = time.time()
# ===============================================================
print("scores: \n", scores, flush=True)
self.ppo_model.policy_model.model.set_train(False)
self.ref_model.model.set_train(False)
# all_tokens: [pad, ..., pad, `prompt`, `response`, pad, ..., pad]
print("left_padding_prompt: ", left_padding_prompt.shape, flush=True)
print("resposne_array: ", resposne_array.shape, flush=True)
all_tokens = np.concatenate((left_padding_prompt, resposne_array), axis=1)
all_tokens = Tensor(all_tokens, mstype.int32)
all_tokens = self.depend(all_tokens, scores)
# ======================= Policy Model ================================
policy_start = time.time()
logprobs = self.ppo_model.policy_model(all_tokens, batch_valid_length=None,
is_first_iteration=True, samples=all_tokens)
policy_end = time.time()
print("logprob is ", logprobs.shape, flush=True)
# ====================================================================
all_tokens = self.depend(all_tokens, logprobs)
# ======================= Critic Model ================================
critic_start = time.time()
values = self.ppo_model.critic_model(all_tokens)
critic_end = time.time()
print("values is ", values.shape, flush=True)
# ====================================================================
self.ref_model.model.add_flags_recursive(use_past=False)
# ======================= Reference Model ================================
ref_start = time.time()
ref_logprobs = self.ref_model(all_tokens, samples=all_tokens)
ref_end = time.time()
print("ref_logprobs is ", ref_logprobs.shape, flush=True)
# ========================================================================
logprobs = logprobs.asnumpy()
values = values.asnumpy()
ref_logprobs = ref_logprobs.asnumpy()
values = values[:, :-1]
n_samples: int = samples.shape[0]
start = self.ppo_config.max_prompt_length - 1
end = self.ppo_config.seq_length - 1
valid_length_response = (samples.asnumpy() != self.ppo_config.pad_token_id).astype(int).sum(1) \
- (prompt_tensors.asnumpy() != self.ppo_config.pad_token_id).astype(int).sum(1)
all_values = values[:, start:end]
all_logprobs = logprobs[:, start:end]
print("all_values: ", all_values.shape, flush=True)
kl_divergence_estimate = self.ppo_model.kl_ctl.value.asnumpy() * (logprobs - ref_logprobs)
kl_divergence_estimate = kl_divergence_estimate[:, start:end]
rollout_count = 0
for sample_idx in range(n_samples):
sample_kl_divergence_estimate = kl_divergence_estimate[sample_idx]
rewards = sample_kl_divergence_estimate
# print("===== rewards[int(valid_length_response[sample_idx] - 1)]: ", rewards[int(valid_length_response[sample_idx] - 3): int(valid_length_response[sample_idx] + 1)], flush=True)
# print("===== valid_length_response: ", valid_length_response[sample_idx], flush=True)
all_logprobs[sample_idx][int(valid_length_response[sample_idx]):] = 0.0
all_values[sample_idx][int(valid_length_response[sample_idx]):] = 0.0
all_values = np.array(all_values).reshape((n_samples, -1))
rewards[int(valid_length_response[sample_idx]):] = 0.0
index = valid_length_response[sample_idx] if valid_length_response[sample_idx] < len(rewards) else -1
print("=====scores type: ", type(scores))
if isinstance(scores, mindspore.Tensor):
scores = scores.asnumpy()
rewards[int(index) - 1] += scores[sample_idx]
# print("===== resposne_array[int(valid_length_response[sample_idx] - 1)]: ", resposne_array[sample_idx][int(valid_length_response[sample_idx] - 3): int(valid_length_response[sample_idx] + 1)], flush=True)
# print("===== rewards[int(valid_length_response[sample_idx] - 1)]: ", rewards[int(valid_length_response[sample_idx] - 3): int(valid_length_response[sample_idx] + 1)], flush=True)
'''print("===== rewards: ", rewards, flush=True)
np.save("/home/shiwenqi/mindspore-chatgpt-distributed/data/rewards.npy", rewards)
print("===== values: ", values, flush=True)
np.save("/home/shiwenqi/mindspore-chatgpt-distributed/data/values.npy", values)'''
response_length = len(rewards)
# print("===== response_length: ", response_length, flush=True)
lastgaelam = 0
advantages_reversed = []
for k in range(response_length):
t = response_length - k - 1
nextvalues = all_values[sample_idx, t + 1] if t < response_length - 1 else 0.0
delta = rewards[t] + self.ppo_model.gamma * nextvalues - all_values[sample_idx, t]
lastgaelam = delta + self.ppo_model.gamma * self.ppo_model.lam * lastgaelam
advantages_reversed.append(lastgaelam)
advantages = np.stack(advantages_reversed[::-1])
returns = advantages + all_values[sample_idx]
'''print("===== advantages: ", advantages, flush=True)
np.save("/home/shiwenqi/mindspore-chatgpt-distributed/data/advantages.npy", advantages)
print("===== returns: ", returns, flush=True)
np.save("/home/shiwenqi/mindspore-chatgpt-distributed/data/returns.npy", returns)
exit()'''
print("===== advantages & returns shape: ", len(advantages), len(returns))
ppo_rl_elements.append(
PPORLElement(
query_tensor=prompt_tensors.asnumpy()[sample_idx],
# query_tensor=prompt_tensors[sample_idx],
response_tensor=all_tokens.asnumpy()[sample_idx],
# response_tensor=samples[sample_idx],
logprobs=all_logprobs[sample_idx],
values=all_values[sample_idx],
rewards=rewards,
advantages=advantages,
returns=returns,
pretrain_ids=pretrain_ids.asnumpy()[sample_idx]
)
)
rollout_count += 1
rollout_total_end = time.time()
print("Rollout elapsed time: ", rollout_total_end - rollout_total, flush=True)
print("Each part of time is ", flush=True)
print("==============================", flush=True)
print(f"Generate: {generate_end - generate_start}", flush=True)
print(f"Reward: {reward_end - reward_start}", flush=True)
print(f"Policy: {policy_end - policy_start}", flush=True)
print(f"Critic: {critic_end - critic_start}", flush=True)
print(f"Reference: {ref_end - ref_start}", flush=True)
print("==============================", flush=True)
self.push_to_store(ppo_rl_elements)
if __name__ == "__main__":
# samples = np.random.randint(low=0, high=15, size=(10, 550)).astype(np.int32)
# get_scores(samples)
# reward_fn(samples)
context.set_context(device_target='Ascend', device_id=1, mode=mindspore.GRAPH_MODE)
trainer = AcceleratePPOTrainer(ppo_config=PPOConfig)
trainer.make_experience(num_rollouts=2)
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/chenyijie6/gpt-rlhf.git
git@gitee.com:chenyijie6/gpt-rlhf.git
chenyijie6
gpt-rlhf
gpt-rlhf
master

搜索帮助