diff --git a/mindspeed_llm/tasks/posttrain/rlxf/workers/reference.py b/mindspeed_llm/tasks/posttrain/rlxf/workers/reference.py index ff1e690066e3fcc0946deb71467e5269fbf1eef8..268cd8b9b661864e101f6148d5d482c51f884440 100644 --- a/mindspeed_llm/tasks/posttrain/rlxf/workers/reference.py +++ b/mindspeed_llm/tasks/posttrain/rlxf/workers/reference.py @@ -48,7 +48,6 @@ class ReferenceWorker(MegatronWorker): @register(dispatch_mode=Dispatch.MEGATRON_COMPUTE_PROTO) def compute_ref_log_prob(self, data: DataProto): - data = data.to('cuda') output = self.reference.compute_log_prob(data=data) if output is not None: output = DataProto.from_dict(tensors={'ref_log_prob': output})