From 49116b83afdb3de359bc68c9e0740345f5822395 Mon Sep 17 00:00:00 2001
From: zhang_xu_hao1230
Date: Mon, 19 May 2025 22:08:04 +0800
Subject: [PATCH] exp rollback
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 vllm_mindspore/model_executor/layers/sampler.py | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/vllm_mindspore/model_executor/layers/sampler.py b/vllm_mindspore/model_executor/layers/sampler.py
index 354fb021..e2827420 100644
--- a/vllm_mindspore/model_executor/layers/sampler.py
+++ b/vllm_mindspore/model_executor/layers/sampler.py
@@ -596,6 +596,11 @@ def _beam_search_sample(
     assert sample_idx == logprobs.size(0)
     return results
 
+def exponential(x, lambd=1.0, *, generator=None):
+    if generator is not None:
+        raise ValueError("`generator` is not supported.")
+    output = np.random.exponential(scale=1.0 / lambd, size=x.shape)  # numpy scale = 1/rate
+    return ms.Tensor(output).astype(x.dtype)
 
 # torch.multinomial forces a GPU<->CPU sync.
 # Therefore, we use an optimized implementation instead.
@@ -611,15 +616,16 @@ def _multinomial(
         probs = probs.repeat_interleave(num_samples, dim=0)
     q = torch.empty_like(probs)
     if seq_groups is None:
-        q.exponential_()
+        q = exponential(q)
     else:
         sample_idx = 0
         for seq_group in seq_groups:
             seq_ids = seq_group.seq_ids
             stride = len(seq_ids) * num_samples
             assert seq_group.generator is not None
-            q[sample_idx : sample_idx +
-              stride].exponential_(generator=seq_group.generator)
+            q[sample_idx : sample_idx + stride] = exponential(
+                q[sample_idx : sample_idx + stride]
+            )
             sample_idx += stride
     return probs.div_(q).argmax(dim=1).view(-1, num_samples)

-- 
Gitee
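
Editor's note (not part of the patch): the unchanged `probs.div_(q).argmax(dim=1)` line works because dividing each probability by i.i.d. Exp(1) noise and taking the argmax is an "exponential race": argmax(p_i / E_i) = argmin(E_i / p_i), each E_i / p_i ~ Exp(rate=p_i), and the minimum of independent exponentials falls on index i with probability p_i / sum(p). A minimal, self-contained NumPy sketch (not taken from the repository) that checks this empirically:

    import numpy as np

    rng = np.random.default_rng(0)
    p = np.array([0.1, 0.2, 0.7])  # toy categorical distribution

    draws = 200_000
    e = rng.exponential(scale=1.0, size=(draws, p.size))  # Exp(1) noise, like q
    samples = (p / e).argmax(axis=1)                       # the div/argmax trick

    # Empirical frequencies should approach p itself.
    print(np.bincount(samples, minlength=p.size) / draws)

This is why the patch can swap the in-place `q.exponential_()` (unavailable on MindSpore tensors) for NumPy-generated noise without changing the sampling distribution. The behavioral caveat: `seq_group.generator` is no longer forwarded, so per-request seeded reproducibility is lost.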