diff --git a/vllm_mindspore/model_executor/layers/sampler.py b/vllm_mindspore/model_executor/layers/sampler.py
index 354fb021464af6510c765eb8bdf0397021a84e67..f285112583d6f8282eb5739793afc4d26726cd09 100644
--- a/vllm_mindspore/model_executor/layers/sampler.py
+++ b/vllm_mindspore/model_executor/layers/sampler.py
@@ -42,6 +42,7 @@ from vllm_mindspore.model_executor.sampling_metadata import (
     SamplingTensors,
     SequenceGroupToSample,
 )
+from mindspore import mint
 
 if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"):
     raise RuntimeError("Donot support for mindspore now.")
@@ -284,8 +285,8 @@ class Sampler(nn.Module):
         logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1))
 
         if do_top_p_top_k and flashinfer_top_k_top_p_sampling is None:
-            logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
-                                        sampling_tensors.top_ks)
+            logits = _apply_top_k_top_p_npu(logits, sampling_tensors.top_ps,
+                                            sampling_tensors.top_ks)
 
         if do_min_p:
             logits = _apply_min_p(logits, sampling_tensors.min_ps)
@@ -403,33 +404,59 @@ def _apply_min_tokens_penalty(
     assert logits_applied == logits.shape[0]
     return logits
 
+def apply_top_k_only(
+    logits: torch.Tensor,
+    k: torch.Tensor,
+) -> torch.Tensor:
+    """
+    Apply top-k mask to the logits.
 
-def _apply_top_k_top_p(
+    This implementation doesn't involve sorting the entire vocab.
+
+    The logits tensor may be updated in-place.
+    """
+    no_top_k_mask = k == logits.shape[1]
+    # Set non-top-k rows to 1 so that we can gather.
+    k = k.masked_fill(no_top_k_mask, 1)
+    max_top_k = k.max()
+    # topk.values tensor has shape [batch_size, max_top_k].
+    # Convert top k to 0-based index in range [0, max_top_k).
+    k_index = k.sub_(1).unsqueeze(1).expand(logits.shape[0], 1)
+
+    # tensor.item() causes a GPU-CPU sync, so place it as late as possible.
+    # Can be deleted once logits.topk() supports tensor-type input.
+    int_max_top_k = max_top_k.item()
+
+    top_k_mask = logits.topk(int_max_top_k, dim=1)[0].gather(1, k_index.long())
+    # Handle non-topk rows.
+    top_k_mask.masked_fill_(no_top_k_mask.unsqueeze(1), -float("inf"))
+    logits.masked_fill_(logits < top_k_mask, -float("inf"))
+    return logits
+
+def _apply_top_k_top_p_npu(
     logits: torch.Tensor,
     p: torch.Tensor,
     k: torch.Tensor,
 ) -> torch.Tensor:
-    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
-
-    # Apply top-k.
-    top_k_mask = logits_sort.size(1) - k.to(torch.long)
-    # Get all the top_k values.
-    top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
-    top_k_mask = logits_sort < top_k_mask
-    logits_sort.masked_fill_(top_k_mask, -float("inf"))
-
-    # Apply top-p.
-    probs_sort = logits_sort.softmax(-1)
-    probs_sum = probs_sort.cumsum(axis=-1)
-    top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
-    # at least one
-    top_p_mask[:, -1] = False
-    logits_sort.masked_fill_(top_p_mask, -float("inf"))
-
-    # Re-sort the probabilities.
-    logits = torch.empty_like(logits_sort).scatter_(dim=-1,
-                                                    index=logits_idx,
-                                                    src=logits_sort)
+    """Apply top-k and top-p optimized for NPU.
+
+    This algorithm avoids using torch.scatter, which is time-consuming on NPU.
+    """
+    if k is not None:
+        logits = apply_top_k_only(logits, k)
+
+    if p is not None:
+        probs = logits.softmax(dim=-1)
+        probs_sort, _ = mint.sort(probs, dim=-1, descending=False)
+        cumprob = mint.cumsum(probs_sort, dim=-1)
+        top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1)
+        top_p_mask[:, -1] = False  # at least one
+
+        top_p_count = top_p_mask.sum(dim=-1).unsqueeze(1)
+        top_p_cutoff = probs_sort.gather(-1, top_p_count)
+        elements_to_discard = probs < top_p_cutoff
+        logits.masked_fill_(elements_to_discard, -float("inf"))
+    return logits