# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""PPO model"""
import copy
import datetime
import glob
import json
import math
import os
import sys
from dataclasses import dataclass
import mindspore
import mindspore.common.dtype as mstype
import mindspore.communication.management as D
import mindspore.nn as nn
import mindspore.ops.functional as F
import numpy as np
from mindspore import Parameter, Tensor, context, ops
from mindspore.common.initializer import TruncatedNormal, initializer
from mindspore.context import ParallelMode
from mindspore.nn.cell import Cell
from mindspore.nn.wrap.cell_wrapper import (
MicroBatchInterleaved,
PipelineCell,
_MicroBatch,
_VirtualDatasetCell,
)
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
from mindspore.ops import ReduceOp
from mindspore.ops import operations as P
from mindspore.ops.operations._inner_ops import Receive, Send
from mindspore.parallel import set_algo_parameters
from mindspore.parallel._cost_model_context import _set_multi_subgraphs
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor
from mindspore.train.model import Model
from mindspore.train.serialization import (
load_checkpoint,
load_distributed_checkpoint,
load_param_into_net,
)
from generator import GeneratorMixin
from mindformers.core.loss.loss import CrossEntropyLoss
from mindformers.models.bloom import BloomConfig, BloomLMHeadModel
from mindformers.models.gpt2 import GPT2Config, GPT2LMHeadModel
from mindformers.models.pangualpha import PanguAlphaConfig, PanguAlphaHeadModel
# try:
from mindformers.modules.transformer import (
AttentionMask,
TransformerOpParallelConfig,
TransformerRecomputeConfig,
)
from src.adam import AdamWeightDecayOp
from src.callbacks import EvalCallBack, LossCallBack
from src.dataset import create_dataset
from src.metrics import PPLMetric
from src.pangu_alpha import (
PanguAlpha_Model,
PanguAlphaModel,
PanGUAlphaWithLoss,
PanGuHead,
)
from src.pangu_alpha_config import PanguAlphaConfig, set_parse
from src.pangu_alpha_wrapcell import (
PanguAlphaTrainOneStepWithLossScaleCell,
PanguAlphaTrainPipelineWithLossScaleCell,
)
from src.utils import FP32StateAdamWeightDecay, LearningRate, download_data, get_args
# except ImportError as e:
# print("Import ERROR, expect mindformers to be installed. "
# "Please refer to the page https://gitee.com/mindspore/mindformers.git to install the mindformers.")
# print("Now exit the program.")
# exit(1)
# from mindformers.models.ipt import IptModelForCasualLM, IptConfig
sys.path.append("../../ipt")
project_root = os.path.abspath(
os.path.dirname(os.path.realpath(__file__)) + os.path.sep + "..")
print('project_root:', project_root)
@dataclass
class PPOConfig:
"""
    PPO config class which defines the training hyper-parameters and model size
"""
epochs: int = 3
total_steps: int = 100000
batch_size: int = 1
    checkpoint_interval: int = 10000
eval_interval: int = 200
optimizer: str = 'adamw'
lr: float = 5.0e-5
    betas: tuple = (0.9, 0.999)
eps: float = 1.0e-8
weight_decay: float = 0.01
    scheduler_name: str = 'cosine_annealing'
T_max: int = 100000
eta_min: float = 5.0e-6
num_rollouts: int = 8
# TODO: Original chunk size is 4
chunk_size: int = 2
ppo_epochs: int = 4
init_kl_coef: float = 0.1
target: float = 6.0
horizon: int = 10000
gamma: float = 1.0
lam: float = 0.95
cliprange: float = 0.2
cliprange_value: float = 0.2
vf_coef: float = 1.0
pretrain_coef: float = 0.9
scale_reward: bool = None
ref_mean: bool = False
ref_std: bool = False
gen_experience_kwargs: bool = False
# TODO: Original seq is 1024
seq_length: int = 1024
max_prompt_length: int = 512
max_decode_length: int = 512
vocab_size: int = 40000
hidden_size: int = 768
num_layers: int = 12
num_heads: int = 12
expand_ratio: int = 4
post_layernorm_residual: bool = False
dropout_rate: float = 0.1
dtype: mstype = mstype.float16
compute_dtype: mstype = mstype.float16
layernorm_dtype: mstype = mstype.float32
softmax_dtype: mstype = mstype.float32
    eos_token_id: int = 2
    pad_token_id: int = 2
    repetition_penalty: float = 1.0
    top_k: int = 1
    top_p: float = 0.95
    is_encoder_decoder: bool = False
    do_sample: bool = False
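# A minimal usage sketch for PPOConfig (illustrative values, not from the
# training recipe); note the defaults imply the prompt and decode budgets
# partition the sequence:
#
#   config = PPOConfig(batch_size=2, lr=1.0e-5)
#   assert config.max_prompt_length + config.max_decode_length == config.seq_length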
class LogprobsOfLabels(nn.Cell):
    """Compute the per-token log-probability of each label under the logits."""
    def __init__(self):
super(LogprobsOfLabels, self).__init__()
self.cast = ops.Cast()
self.log_softmax = P.LogSoftmax(axis=-1)
self.gatherd = P.GatherD()
self.unsqueeze = P.ExpandDims()
self.squeeze = P.Squeeze(axis=-1)
def construct(self, logits, labels):
labels = self.cast(labels, mindspore.int32)
logprobs = self.log_softmax(logits)
logprobs_labels = self.gatherd(logprobs, -1, self.unsqueeze(labels, -1))
return self.squeeze(logprobs_labels)
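# Shape sketch for LogprobsOfLabels (toy shapes, assumed for illustration):
# given logits (batch, seq, vocab) and labels (batch, seq), it returns the
# log-probability of each label token with shape (batch, seq):
#
#   logits = Tensor(np.random.randn(2, 4, 8), mstype.float32)
#   labels = Tensor(np.zeros((2, 4)), mstype.int32)
#   out = LogprobsOfLabels()(logits, labels)  # out.shape == (2, 4)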
class ProcessLogits(nn.Cell):
    """Gather logits at the current decoding index and convert them to probabilities."""
    def __init__(self):
super(ProcessLogits, self).__init__()
self.e = Tensor(np.e)
self.gather = P.Gather()
self.logsoftmax = P.LogSoftmax()
self.reshape = P.Reshape()
def construct(self, logits, current_index=None):
"""Process the logits"""
if current_index is not None:
index = current_index.view(-1,)
if len(logits.shape) == 3:
logits = self.reshape(logits, (logits.shape[0] * logits.shape[1], -1))
logits = self.gather(logits, index, 0)
outputs = self.logsoftmax(logits)
outputs = F.tensor_pow(self.e, outputs)
return outputs
class AdaptiveKLController(nn.Cell):
"""Adaptive KL Controller as described in Ziegler et al. "Fine-Tuning Language Models from Human Preferences"
Reference: Section 2.2 https://arxiv.org/pdf/1909.08593.pdf#page=2
Source: https://github.com/openai/lm-human-preferences/blob/master/lm_human_preferences/train_policy.py
"""
def __init__(self, init_kl_coef: float, target: float, horizon: int):
super(AdaptiveKLController, self).__init__()
self.value = Parameter([init_kl_coef, ], requires_grad=False)
self.target = Tensor(target)
self.horizon = Tensor(horizon)
self.div = P.Div()
self.add = P.Add()
    def construct(self, current, n_steps):
        """Adaptively update the KL coefficient (Ziegler et al., Sec. 2.2)."""
        proportional_error = self.add(self.div(current, self.target), Tensor(-1, mstype.float32))
        proportional_error = proportional_error.clip(-0.2, 0.2)
mult = 1 + proportional_error * n_steps / self.horizon
self.value *= mult
return self.value
'''def __init__(self, init_kl_coef: float, target: float, horizon: int):
self.value = init_kl_coef
self.target = target
self.horizon = horizon
def update(self, current: float, n_steps: int):
"""Returns adaptively updated KL coefficient, βₜ₊₁.
Arguments:
current: The current KL value between the newest policy and the initial policy.
"""
proportional_error = np.clip(current / self.target - 1, -0.2, 0.2) # ϵₜ
mult = 1 + proportional_error * n_steps / self.horizon
self.value *= mult # βₜ₊₁'''
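# Worked example of the adaptive update (numbers assumed for illustration):
# with target=6 and horizon=10000, a measured KL of current=9 over n_steps=8 gives
#   proportional_error = clip(9/6 - 1, -0.2, 0.2) = 0.2
#   mult = 1 + 0.2 * 8 / 10000 = 1.00016
# so a coefficient of 0.1 is nudged up to 0.1 * 1.00016 = 0.100016.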
class FixedKLController(nn.Cell):
"""Fixed KL controller."""
def __init__(self, kl_coef):
super(FixedKLController, self).__init__()
self.value = Tensor([kl_coef, ])
    def construct(self, current, n_steps):
        """The KL coefficient is fixed, so no update is performed (a constant 0.0 is returned).

        Arguments:
            current: The current KL value between the newest policy and the initial policy.
        """
        return Tensor(0.0, mstype.float16)
'''def __init__(self, kl_coef):
self.value = kl_coef
def update(self, current: float, n_steps: int):
"""Returns updated KL coefficient, βₜ₊₁.
Arguments:
current: The current KL value between the newest policy and the initial policy.
"""
pass'''
class CausalLMHydraWithValueHead(nn.Cell):
"""
CausalLMHydraWithValueHead
"""
def __init__(self, model_config, ppo_config, is_training=True):
super(CausalLMHydraWithValueHead, self).__init__()
if not is_training:
model_config.dropout_rate = 0.0
self.model_config = model_config
self.ppo_config = ppo_config
if isinstance(model_config, PanguAlphaConfig):
self.model_type = 'pangu'
elif isinstance(model_config, BloomConfig):
self.model_type = 'bloom'
elif isinstance(model_config, GPT2Config):
self.model_type = 'gpt2'
else:
            raise NotImplementedError("only support pangu, bloom and gpt2")
print("sft model_type: ", self.model_type)
if self.model_type == 'pangu':
self.model = PanguAlphaHeadModel(model_config)
self.backbone = self.model.backbone
self.lm_head = self.model.head
elif self.model_type == 'bloom':
self.model = BloomLMHeadModel(model_config)
self.backbone = self.model.transformer
self.lm_head = self.model.head
        elif self.model_type == 'ipt':
            # NOTE: unreachable with the type checks above; kept for the IPT
            # variant, whose import (IptModelForCasualLM) is commented out at the top.
self.attention_type = int(os.getenv("attention_type", 0))
self.model = IptModelForCasualLM(model_config)
self.backbone = self.model.transformer
self.lm_head = self.model.lm_head
elif self.model_type == 'gpt2':
self.model = GPT2LMHeadModel(model_config)
self.backbone = self.model.backbone
self.lm_head = self.model.head
self.lm_head.pipeline_stage = model_config.parallel_config.pipeline_stage - 1
dp = model_config.parallel_config.data_parallel
mp = model_config.parallel_config.model_parallel
self.vocab_size = model_config.vocab_size
self.chunk_size = ppo_config.chunk_size
self.seq_length = ppo_config.seq_length
self.cast = P.Cast()
self.shape = P.Shape()
self.squeeze = P.Squeeze(axis=-1).shard(((dp, 1, 1), ))
self.squeeze_no_shard = P.Squeeze(axis=-1).shard(((1, 1, 1), ))
self.unsqueeze = P.ExpandDims()
self.reshape = P.Reshape()
self.e = Tensor(np.e)
self.gather = P.Gather().shard(((dp, mp),
(1, ), ))
self.strided_slice_1 = P.StridedSlice().shard(((dp, 1, mp), ))
self.strided_slice_2 = P.StridedSlice().shard(((dp, 1), ))
self.gatherd = P.GatherD()
self.logsoftmax_1 = P.LogSoftmax().shard(((dp, 1, 1), ))
self.logsoftmax_2 = P.LogSoftmax().shard(((dp, 1), ))
self.on_value = Tensor(1.0, mstype.float32)
self.off_value = Tensor(0.0, mstype.float32)
self.sum = P.ReduceSum().shard(((dp, mp),))
self.max = P.ArgMaxWithValue(axis=-1, keep_dims=True).shard(((dp, mp),))
self.sub = P.Sub().shard(((dp, mp), (dp, 1)))
self.exp = P.Exp().shard(((dp, mp),))
self.div = P.RealDiv().shard(((dp, mp), (dp, 1)))
self.onehot = P.OneHot().shard(((dp, mp), (), ()))
self.log = P.Log().shard(((dp, mp),))
self.pow = P.Pow().shard(((dp, 1), ()))
self.argmax_no_shard = P.Argmax(-1).shard(((1, 1), ))
self.expand_dims = P.ExpandDims().shard(((dp, 1, 1),))
self.sub_shard = P.Sub().shard(((), (1, 1, 1)))
    def process_logits(self, logits, current_index=None, is_first_iteration=False, use_past=False):
        """Gather logits at `current_index` (except for incremental decode) and convert to probabilities."""
        logits = logits.reshape(-1, logits.shape[-1])
if use_past and not is_first_iteration:
logits = logits
elif current_index is not None:
index = current_index.view(-1,)
logits = self.gather(logits, index, 0)
        outputs = self.logsoftmax_2(logits)
        # e ** log_softmax(logits) recovers the probabilities (matches ProcessLogits)
        outputs = F.tensor_pow(self.e, outputs)
return outputs
    def process_logits2(self, logits, current_index=None, is_first_iteration=False, use_past=False):
        """Gather logits at `current_index` (except for incremental decode) and return greedy token ids."""
        logits = logits.reshape(-1, logits.shape[-1])
if use_past and not is_first_iteration:
logits = logits
elif current_index is not None:
index = current_index.view(-1,)
logits = self.gather(logits, index, 0)
top_token_id = self.argmax_no_shard(logits)
top_token_id = top_token_id.view(-1, 1)
return top_token_id
    def logprobs_of_labels(self, logits, samples, batch_size, seq_length):
        """Numerically stable log-softmax over the vocabulary, then gather the label log-probs."""
        # Reference formulation: logits = logits[:, :-1, :]; samples = samples[:, 1:];
        # logprobs = self.logsoftmax_1(logits)
samples = self.strided_slice_2(samples, (0, 1), (batch_size, self.seq_length), (1, 1))
logits = self.reshape(logits, (batch_size * seq_length, -1))
_, logit_max = self.max(logits)
logit_sub = self.sub(logits, logit_max)
logit_exp = self.exp(logit_sub)
exp_sum = self.sum(logit_exp, -1)
exp_sum = P.Reshape()(exp_sum, (F.shape(exp_sum)[0], 1))
softmax_result = self.div(logit_exp, exp_sum)
logprobs = self.log(softmax_result)
logprobs = self.reshape(logprobs, (batch_size, seq_length, -1))
logprobs = self.strided_slice_1(logprobs, (0, 0, 0), (batch_size, -1, self.vocab_size), (1, 1, 1))
logprobs_labels = self.squeeze_no_shard(self.gatherd(logprobs, -1, self.unsqueeze(samples, -1)))
return logprobs_labels
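    # NumPy cross-check of the stable log-softmax above (a sketch on toy host
    # data, not part of the graph): subtracting the row max before exp avoids
    # overflow while leaving log(softmax(x)) unchanged.
    #
    #   x = np.random.randn(2, 8).astype(np.float32)
    #   stable = x - x.max(-1, keepdims=True)
    #   logprobs = stable - np.log(np.exp(stable).sum(-1, keepdims=True))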
def construct(self,
# inputs for the llm
input_ids,
input_position=None,
position_ids=None,
attention_mask=None,
init_reset=True,
batch_valid_length=None,
# inputs for `process_logits`
is_first_iteration=False,
use_past=False,
# inputs for choosing the output branch
samples=None,
return_full_logit=False):
batch_size, seq_length = input_ids.shape
if self.model_type == 'pangu':
if self.model.phase == 'train':
seq_length = seq_length - 1
tokens = self.model.slice(input_ids, (0, 0), (batch_size, seq_length), (1, 1))
else:
tokens = input_ids
input_mask = F.cast(self.model.not_equal(tokens, self.model.pad_token_id),
mstype.float32)
if attention_mask is None:
attention_mask = self.model.get_attention_mask(input_mask)
else:
attention_mask = self.model.cast(attention_mask, mstype.float32)
attention_mask = self.model.slice2(attention_mask, (0, 0, 0),
(batch_size, seq_length, seq_length),
(1, 1, 1))
if input_position is None:
input_position = F.tuple_to_array(F.make_range(seq_length))
input_position = self.model.expand(input_position, 0)
if batch_size == 1:
input_position = F.reshape(input_position, (1, seq_length))
else:
input_position = self.model.tile(input_position, (batch_size, 1))
else:
if self.model.phase == 'train':
input_position = self.model.slice(input_position, (0, 0), (batch_size, seq_length), (1, 1))
# [batch_size, seq_length, vocab_size]
output_states, embedding_table = self.backbone(tokens, input_position, attention_mask,
init_reset, batch_valid_length)
elif self.model_type == 'ipt':
tokens = input_ids
if position_ids:
position_ids = self.model.slice(position_ids, (0, 0), (batch_size, seq_length), (1, 1))
if not is_first_iteration and self.model.use_past:
attention_mask = self.model.all_ones_attention_mask
else:
attention_mask = self.model.not_equal(tokens, self.model.pad_token_id).astype(mstype.float32)
output_states, embedding_table = self.backbone(input_ids, position_ids, attention_mask=attention_mask,
init_reset=init_reset, batch_valid_length=batch_valid_length)
elif self.model_type == 'gpt2':
# if self.model.phase == 'train':
# seq_length = seq_length - 1
# tokens = self.model.slice(input_ids, (0, 0), (batch_size, seq_length), (1, 1))
# else:
tokens = input_ids
if attention_mask is None:
attention_mask = self.model.not_equal(input_ids, self.model.eos_token_id)
attention_mask = self.cast(attention_mask, mstype.float32)
attention_mask = self.model.get_attention_mask(attention_mask)
if not self.model.is_first_iteration:
attention_mask = self.model.tile(self.model.all_ones_attention_mask, (batch_size, 1, 1))
#
# # [batch_size, seq_length, vocab_size]
# (self, input_ids, attention_mask, input_position=None, init_reset=True, batch_valid_length=None):
output_states, embedding_table = self.backbone(
tokens, attention_mask, input_position=input_position,
init_reset=init_reset, batch_valid_length=batch_valid_length)
else:
if self.model.phase == "train":
tokens = self.model.stridedslice(input_ids, (0, 0), (batch_size, seq_length - 1), (1, 1))
else:
tokens = input_ids
if self.model.use_past:
input_mask = self.model.input_mask_all_ones
else:
input_mask = self.model.not_equal(tokens, self.model.eos_token_id).astype(mstype.float32)
output_states, embedding_table = self.backbone(tokens, input_mask, init_reset, batch_valid_length)
logits = self.lm_head(output_states, embedding_table)
logits = self.reshape(logits, (batch_size, seq_length, -1))
        # This model is used in three places: generate, make_experience, and ppo.
        # To reduce memory, each call site gets back only the output it needs.
        # Used in inference for make_experience and ppo:
if samples is not None:
logprobs_labels = self.logprobs_of_labels(logits, samples, batch_size, seq_length)
return logprobs_labels
# used in generate
        elif not return_full_logit:
# logits (2, 8192, 50257)
outputs = self.process_logits2(logits, input_position, is_first_iteration, use_past)
return outputs
# used in pretrain loss
else:
return logits
class Sampler(nn.Cell):
    """Top-k sampling head: renormalizes the top-k next-token probabilities."""
    def __init__(self):
super(Sampler, self).__init__()
self.shape = P.Shape()
self.zeros = P.Zeros()
self.top_k = P.TopK(sorted=True)
def construct(self, log_probs, batch_size, top_k, repetition_penalty, frequency_list):
vocab_size = self.shape(log_probs)[-1]
if repetition_penalty != 1 and frequency_list is None:
frequency_list = self.zeros((vocab_size, ), mindspore.float32)
log_probs_revised = log_probs.reshape(batch_size, vocab_size)
if repetition_penalty != 1:
log_probs_revised = log_probs - frequency_list * repetition_penalty - \
(frequency_list > 0) * repetition_penalty
logits = F.pow(Tensor(np.e), log_probs_revised)
probs, p_args = self.top_k(logits, top_k)
norm = probs.sum(-1, keepdims=True)
avg = F.broadcast_to(Tensor(1 / top_k), probs.shape).to(mstype.float32)
probs = F.select(norm == 0, avg, probs / norm)
return probs, p_args
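# Usage sketch for Sampler (toy shapes; with repetition_penalty == 1 the
# penalty branch is skipped, so frequency_list may be None):
#
#   log_probs = Tensor(np.log(np.full((2, 8), 1 / 8)), mstype.float32)
#   probs, p_args = Sampler()(log_probs, 2, 4, 1, None)
#   # probs: renormalized top-4 probabilities; p_args: their vocab indices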
class PPO_model(nn.Cell, GeneratorMixin):
    """
    PPO_model: combines the pretrain LM loss, the clipped policy-gradient
    loss, and the clipped value loss into one training objective.
    """
def __init__(self, ppo_config, policy_model, critic_model, opt, is_training=True):
super(PPO_model, self).__init__()
self.ppo_config = ppo_config
self.pretrain_coef = self.ppo_config.pretrain_coef
self.opt = opt
self.use_past = opt.use_past
print("use_past: ", self.use_past)
self.pad_token_id = Tensor(ppo_config.pad_token_id, mstype.int32)
self.policy_model = policy_model
self.critic_model = critic_model
# self.inference_wrapper = PipelineCell_inference(self.policy_model, self.opt.inference_micro_size)
self.stack = P.Stack(axis=1)
self.allreduce_sum = P.AllReduce(ReduceOp.SUM)
self.reduce_sum = P.ReduceSum(keep_dims=False)
self.reduce_mean = P.ReduceMean(keep_dims=False)
self.rsqrt = P.Rsqrt()
self.concat = P.Concat(1)
self.log_softmax = P.LogSoftmax(axis=-1)
self.gather = P.GatherD()
self.unsqueeze = P.ExpandDims()
self.squeeze = P.Squeeze(axis=-1)
self.cliprange_value = ppo_config.cliprange_value
self.cliprange = ppo_config.cliprange
self.vf_coef = ppo_config.vf_coef
self.max = P.Maximum()
self.cast = P.Cast()
self.exp = P.Exp()
self.stop_grad = P.StopGradient()
        self.gamma = ppo_config.gamma
        self.lam = ppo_config.lam
self.size = P.Size()
self.sum_and_count = Tensor([[1, 1]])
# self.attention_mask = Tensor(np.ones((1, 1024)), mstype.float32)
self.approx_kl = Parameter([0.0, ], requires_grad=False)
self.get_attention_mask = AttentionMask(ppo_config.seq_length)
if ppo_config.target is not None:
self.kl_ctl = AdaptiveKLController(ppo_config.init_kl_coef, ppo_config.target, ppo_config.horizon)
else:
self.kl_ctl = FixedKLController(ppo_config.init_kl_coef)
self.logprobs_of_labels = LogprobsOfLabels()
self.process_logits = ProcessLogits()
self.sampler = Sampler()
self.slice = P.StridedSlice().shard(((self.policy_model.model_config.parallel_config.data_parallel, 1),))
self.pretrain_loss = CrossEntropyLoss(self.policy_model.model_config.parallel_config.dp_mp_config)
self.not_equal = P.NotEqual().shard(((self.policy_model.model_config.parallel_config.data_parallel, 1), ()))
self.reduce_mean = P.ReduceMean()
self.depend = P.Depend()
def construct(
self,
query_tensors,
response_tensors,
logprobs,
values,
rewards,
attention_mask,
advantages,
returns,
pretrain_tokens,
labels):
"""
"""
# print("data: ", query_tensors, response_tensors, logprobs, values, rewards, attention_mask)
old_logprobs = logprobs
old_rewards = rewards
old_values = values
response_length = F.shape(old_rewards)[1]
# advantages, returns = self.get_advantages_and_returns(old_values, old_rewards, response_length)
# tokens = self.concat((query_tensors, response_tensors))
tokens = response_tensors
'''input_mask = F.cast(F.not_equal(tokens, self.pad_token_id), mstype.float32)
bs, seq_length = F.shape(response_tensors)
input_position = F.tuple_to_array(F.make_range(seq_length))
input_position = P.Tile()(input_position, (bs, 1))
attention_mask_pangu = self.get_attention_mask(input_mask)'''
# attention_mask = self.attention_mask
logprobs = self.policy_model(tokens, batch_valid_length=None,
is_first_iteration=True, samples=tokens)
tokens = self.depend(tokens, logprobs)
values_pred = self.critic_model(tokens)
values_pred = values_pred[:, :-1]
# logprobs = self.logprobs_of_labels(logits[:, :-1, :], tokens[:, 1:])
start = F.shape(query_tensors)[1] - 1
end = start + response_length
logprobs, values_pred, mask = (
logprobs[:, start:end],
values_pred[:, start:end],
attention_mask[:, start:end],
)
'''logprobs = logprobs[:, start:end]
values_pred = values_pred[:, start:end]
mask = attention_mask[:, start:end]'''
# calculate pretrain loss
batch_size, seq_length = response_tensors.shape
# (bs, 1024)
logits = self.policy_model(
pretrain_tokens, batch_valid_length=None, return_full_logit=True)
logits = P.Reshape()(logits, (batch_size * seq_length, -1))
labels = P.Reshape()(labels, (-1,))
input_mask = F.cast(self.not_equal(pretrain_tokens, self.pad_token_id), mstype.float32)
input_mask = P.Reshape()(input_mask, (-1,))
pretrain_loss = self.pretrain_loss(logits, labels, input_mask)
# calculate value loss and policy loss
vf_loss, pg_loss, approx_kl = self.get_vfloss_and_pgloss(
logprobs=logprobs,
values=values_pred,
old_logprobs=old_logprobs,
old_values=old_values,
advantages=advantages,
returns=returns,
mask=mask,
)
approx_kl = ops.Reshape()(approx_kl, (1, ))
self.approx_kl = approx_kl
loss = self.pretrain_coef * pretrain_loss + (1 - self.pretrain_coef) * pg_loss + self.vf_coef * vf_loss
return loss
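    # With the defaults above (pretrain_coef=0.9, vf_coef=1.0) this weighting is
    #   loss = 0.9 * pretrain_loss + 0.1 * pg_loss + 1.0 * vf_loss,
    # i.e. the pretrain LM term intentionally dominates the policy-gradient term.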
def post_backward_callback(self):
return self.kl_ctl(self.approx_kl, Tensor([self.ppo_config.batch_size, ]))
def get_vfloss_and_pgloss(
self,
logprobs,
values,
old_logprobs,
old_values,
advantages,
returns,
mask,
):
"""PPO objective function.
References:
- https://stable-baselines.readthedocs.io/en/master/modules/ppo2.html
"""
values_clipped = ops.clip_by_value(
values,
old_values - self.cliprange_value,
old_values + self.cliprange_value,
)
n = mask.sum()
vf_loss1 = (values - returns) ** 2
vf_loss2 = (values_clipped - returns) ** 2
vf_loss = 0.5 * self.reduce_sum(self.max(vf_loss1, vf_loss2) * mask) / n
log_ratio = (logprobs - old_logprobs) * mask
ratio = self.exp(log_ratio)
# Unbiased KL-div estimates (`k3`). Ref: http://joschu.net/blog/kl-approx.html
approx_kl = self.reduce_mean((ratio - 1) - log_ratio)
approx_kl = self.stop_grad(approx_kl)
pg_loss1 = -advantages * ratio
pg_loss2 = -advantages * ops.clip_by_value(
ratio,
1.0 - self.cliprange,
1.0 + self.cliprange,
)
pg_loss = self.reduce_sum(self.max(pg_loss1, pg_loss2) * mask) / n
return vf_loss, pg_loss, approx_kl
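# NumPy reference for the clipped objectives above (a sketch for host-side
# sanity checks on toy 1-D arrays; mirrors the MindSpore ops, names are ours):
#
#   def ppo_losses_np(logprobs, old_logprobs, values, old_values, advantages,
#                     returns, mask, cliprange=0.2, cliprange_value=0.2):
#       n = mask.sum()
#       v_clip = np.clip(values, old_values - cliprange_value, old_values + cliprange_value)
#       vf = 0.5 * (np.maximum((values - returns) ** 2, (v_clip - returns) ** 2) * mask).sum() / n
#       ratio = np.exp((logprobs - old_logprobs) * mask)
#       pg = (np.maximum(-advantages * ratio,
#                        -advantages * np.clip(ratio, 1 - cliprange, 1 + cliprange)) * mask).sum() / n
#       return vf, pg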