diff --git a/mindrlhf/wrapper/wrapper.py b/mindrlhf/wrapper/wrapper.py
index 1d16cfe3b8688c5fcb60bfe5a144b2a4cd0b4996..e8415513101571a35f74a56501d44abe97edd80b 100644
--- a/mindrlhf/wrapper/wrapper.py
+++ b/mindrlhf/wrapper/wrapper.py
@@ -367,7 +367,7 @@ class TrainOneStepWithLossScaleGRPO(TrainOneStepWithLossScaleCell):
         return loss, lr, cond, scaling_sens.value()
 
 
-class TrainPipelineWithLossScaleCellGRPO(nn.Cell):
+class TrainPipelineWithLossScaleCellGRPO(nn.TrainOneStepWithLossScaleCell):
     """
     Encapsulation class of network training.
 
@@ -381,7 +381,9 @@ class TrainPipelineWithLossScaleCellGRPO(nn.Cell):
     """
 
     def __init__(self, network, optimizer, config, scale_update_cell=None, enable_global_norm=True):
-        super(TrainPipelineWithLossScaleCellGRPO, self).__init__(auto_prefix=False)
+        super(TrainPipelineWithLossScaleCellGRPO, self).__init__(network, optimizer, scale_update_cell)
+        if isinstance(scale_update_cell, (int, float)):
+            scale_update_cell = Tensor(scale_update_cell)
         self.config = config
         self.network = network
         self.network.add_flags(defer_inline=True)
@@ -405,9 +407,6 @@ class TrainPipelineWithLossScaleCellGRPO(nn.Cell):
         self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
         self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
         self.cast = P.Cast()
-        self.alloc_status = P.NPUAllocFloatStatus()
-        self.get_status = P.NPUGetFloatStatus()
-        self.clear_before_grad = P.NPUClearFloatStatus()
         self.reduce_sum = P.ReduceSum(keep_dims=False)
         self.base = Tensor(1, mstype.float32)
         self.less_equal = P.LessEqual()
@@ -415,9 +414,17 @@ class TrainPipelineWithLossScaleCellGRPO(nn.Cell):
         self.loss_scale = None
         self.reshape = P.Reshape()
         self.loss_scaling_manager = scale_update_cell
-        if scale_update_cell:
+        # 8-bit status param for get_overflow_status func
+        self.status = Tensor([0] * 8, mstype.int32)
+        if isinstance(scale_update_cell, nn.Cell):
             self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32),
                                         name="loss_scale")
+        elif isinstance(scale_update_cell, Tensor):
+            if scale_update_cell.shape == (1,) or scale_update_cell.shape == ():
+                self.loss_scale = Parameter(scale_update_cell, name='loss_scale')
+            else:
+                raise ValueError("The shape of 'scale_sense' must be (1,) or (), but got {}"
+                                 .format(scale_update_cell.shape))
         self.clip = ClipByGlobalNorm(self.weights, self.config)
         self.micro_size = config.parallel_config.micro_batch_num
         self.opt_shard = _get_enable_parallel_optimizer()
@@ -442,24 +449,9 @@ class TrainPipelineWithLossScaleCellGRPO(nn.Cell):
             scaling_sens = self.reshape(scaling_sens, (1,))
         else:
             scaling_sens = sens
-        # alloc status and clear should be right before gradoperation
-        init = self.alloc_status()
-        status_clear = self.clear_before_grad(init)
         grads = self.grad(self.network, weights)(prompt_completion_ids, prompts_mask, responses_mask,
                                                  ref_per_token_logps, advantages,
                                                  self.cast(scaling_sens / self.micro_size, mstype.float32))
-        init = F.depend(init, grads)
-        get_status = self.get_status(init)
-        init = F.depend(init, get_status)
-        flag_sum = self.reduce_sum(init, (0,))
-        loss = F.depend(loss, status_clear)
-        if self.is_distributed:
-            # sum overflow flag over devices
-            flag_reduce = self.allreduce(flag_sum)
-            cond = self.less_equal(self.base, flag_reduce)
-        else:
-            cond = self.less_equal(self.base, flag_sum)
-        grads = F.depend(grads, cond)
         # apply grad reducer on grads
         if self.opt_shard:
             grads = self.grad_reducer(grads)
@@ -474,10 +466,8 @@ class TrainPipelineWithLossScaleCellGRPO(nn.Cell):
            grads = self.hyper_map(
                F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
-        overflow = cond
-        if sens is None:
-            overflow = self.loss_scaling_manager(self.loss_scale, cond)
-        overflow = overflow or loss.isnan()
+        cond = self.get_overflow_status(self.status, grads)
+        overflow = self.process_loss_scale(cond)
         if not overflow:
             if self.enable_offload:
                 self.optimizer(grads, clip_value)
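
Usage note: with the wrapper now deriving from nn.TrainOneStepWithLossScaleCell, overflow handling goes through the inherited get_overflow_status/process_loss_scale helpers, and scale_update_cell may be either an update cell or a Tensor of shape (1,) or (). Below is a minimal construction sketch, not part of this change; grpo_network, optimizer, and grpo_config are hypothetical placeholders for objects prepared elsewhere, and the import path simply mirrors the file touched by this diff.

```python
# Minimal sketch, assuming `grpo_network`, `optimizer`, and `grpo_config` are
# placeholder objects built elsewhere (hypothetical names, not from this diff).
import mindspore as ms
from mindspore import nn
from mindrlhf.wrapper.wrapper import TrainPipelineWithLossScaleCellGRPO

# Option 1: dynamic loss scaling via an update cell; this is matched by the
# new `isinstance(scale_update_cell, nn.Cell)` branch.
update_cell = nn.DynamicLossScaleUpdateCell(loss_scale_value=2 ** 12,
                                            scale_factor=2,
                                            scale_window=1000)
train_cell = TrainPipelineWithLossScaleCellGRPO(grpo_network, optimizer, grpo_config,
                                                scale_update_cell=update_cell)

# Option 2: fixed loss scale passed as a scalar Tensor; the new Tensor branch
# checks that its shape is (1,) or () and registers it as `loss_scale`.
train_cell = TrainPipelineWithLossScaleCellGRPO(grpo_network, optimizer, grpo_config,
                                                scale_update_cell=ms.Tensor(1024.0, ms.float32))
```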