From f49b3de9fb458cecc2493bd891f0103110ad9702 Mon Sep 17 00:00:00 2001 From: HanhuiChen Date: Fri, 20 Jun 2025 17:56:41 +0800 Subject: [PATCH 01/14] Add qwen3 template --- .../inference/text_generation/tokenization.py | 2 +- .../tasks/posttrain/base/base_trainer.py | 6 +- .../tasks/posttrain/sft/sft_trainer.py | 4 +- .../tasks/preprocess/data_handler.py | 28 +-- .../preprocess/decoder_packed_mtf_dataset.py | 2 +- mindspeed_llm/tasks/preprocess/templates.py | 174 +++++++++++++++--- mindspeed_llm/training/arguments.py | 5 +- mindspeed_llm/training/tokenizer/tokenizer.py | 2 +- preprocess_data.py | 5 +- 9 files changed, 184 insertions(+), 44 deletions(-) diff --git a/mindspeed_llm/inference/text_generation/tokenization.py b/mindspeed_llm/inference/text_generation/tokenization.py index bd70250af..30b5785c5 100644 --- a/mindspeed_llm/inference/text_generation/tokenization.py +++ b/mindspeed_llm/inference/text_generation/tokenization.py @@ -132,7 +132,7 @@ def _tokenize_prompts_and_batch(tokenizer, prompts, tokens_to_generate, max_gene # Tokenize all the prompts. if hasattr(args, "prompt_type") and args.prompt_type is not None: - template = get_model_template(args.prompt_type.strip(), args.prompt_type_path.strip()) + template = get_model_template(args.prompt_type.strip(), args.prompt_type_path.strip(), args.enable_thinking) if args.hf_chat_template: if not hasattr(tokenizer, "apply_chat_template"): diff --git a/mindspeed_llm/tasks/posttrain/base/base_trainer.py b/mindspeed_llm/tasks/posttrain/base/base_trainer.py index 16b69da2b..714f13372 100644 --- a/mindspeed_llm/tasks/posttrain/base/base_trainer.py +++ b/mindspeed_llm/tasks/posttrain/base/base_trainer.py @@ -25,7 +25,7 @@ from mindspeed_llm.training import build_train_args from mindspeed_llm.training import train from mindspeed_llm.training.initialize import set_jit_fusion_options from mindspeed_llm.tasks.posttrain.utils import train_valid_test_datasets_provider - +from hook_test import hook_for_model, register_token_id_hook _TRAIN_START_TIME = time.time() @@ -155,7 +155,9 @@ class BaseTrainer(ABC): pre_process=pre_process, post_process=post_process ) - + + # hook_for_model(model) + # register_token_id_hook(model) return model @staticmethod diff --git a/mindspeed_llm/tasks/posttrain/sft/sft_trainer.py b/mindspeed_llm/tasks/posttrain/sft/sft_trainer.py index 0fade90b1..1c5d6d477 100644 --- a/mindspeed_llm/tasks/posttrain/sft/sft_trainer.py +++ b/mindspeed_llm/tasks/posttrain/sft/sft_trainer.py @@ -112,7 +112,10 @@ class SFTTrainer(BaseTrainer): loss_mask = input_tensor losses = output_tensor.float() + # losses = output_tensor.half() loss_mask = loss_mask[..., 1:].view(-1).float() + # loss_mask = loss_mask[..., 1:].view(-1).half() + if self.args.context_parallel_size > 1: loss = torch.cat([torch.sum(losses.view(-1) * loss_mask).view(1), loss_mask.sum().view(1)]) torch.distributed.all_reduce(loss, group=mpu.get_context_parallel_group()) @@ -154,5 +157,4 @@ class SFTTrainer(BaseTrainer): else: output_tensor = model(tokens, position_ids, attention_mask, labels=labels, loss_mask=loss_mask) - return output_tensor, partial(self.loss_func, loss_mask) \ No newline at end of file diff --git a/mindspeed_llm/tasks/preprocess/data_handler.py b/mindspeed_llm/tasks/preprocess/data_handler.py index a20935f88..6eea4a623 100644 --- a/mindspeed_llm/tasks/preprocess/data_handler.py +++ b/mindspeed_llm/tasks/preprocess/data_handler.py @@ -105,11 +105,11 @@ class BaseDatasetHandler(object): batch = doc["input_ids"] for idx, sample in enumerate(batch): length = 
len(sample) - if (length > self.args.seq_length) and (not self.args.neat_pack): - logger.warning(f"Dropped lengthy example with length {length} > {self.args.seq_length}.") + if (length >= self.args.seq_length) and (not self.args.neat_pack): + logger.warning(f"Dropped lengthy example with length {length} >= {self.args.seq_length}.") else: - if length > self.args.seq_length: - logger.warning(f"Sequence length {length} > {self.args.seq_length}.") + if length >= self.args.seq_length: + logger.warning(f"Sequence length {length} >= {self.args.seq_length}.") sample = sample[:self.args.seq_length - 1] length = len(sample) lengths.append(length) @@ -333,7 +333,8 @@ class LlamaFactoryInstructionHandler(BaseDatasetHandler): self.args.output_prefix = self.args.output_prefix + "_packed" self.ignored_label = -100 self.is_multi_turn = True - self.llama_factory_template = get_model_template(args.prompt_type.strip(), args.prompt_type_path.strip()) + self.llama_factory_template = get_model_template(args.prompt_type.strip(), args.prompt_type_path.strip(), args.enable_thinking) + self.cutoff_len=args.seq_length def _format_msg(self, sample): return sample @@ -351,9 +352,8 @@ class LlamaFactoryInstructionHandler(BaseDatasetHandler): messages = [{'role': 'user', 'content': ''}, {'role': 'assistant', 'content': ''}] else: messages = example["prompt"] + example["response"] - for source_ids, target_ids in self.llama_factory_template.encode_multiturn( - tokenizer, messages, example["system"][0], example["tools"][0] + tokenizer, messages, example["system"][0], example["tools"][0], self.cutoff_len ): if self.train_on_inputs: source_mask = source_ids @@ -437,7 +437,7 @@ class HunyuanInstructionHandler(BaseDatasetHandler): self.ignored_label = -100 self.is_multi_turn = True - self.hunyuanlarge_template = get_model_template(args.prompt_type.strip(), args.prompt_type_path.strip()) + self.hunyuanlarge_template = get_model_template(args.prompt_type.strip(), args.prompt_type_path.strip(), args.enable_thinking) def _format_msg(self, sample): return sample @@ -502,7 +502,7 @@ class AlpacaStylePairwiseHandler(BaseDatasetHandler): self.args.json_keys = ["chosen_input_ids", "chosen_labels", "rejected_input_ids", "rejected_labels"] self.args.output_prefix = self.args.output_prefix + "_packed" self.ignored_label = -100 - self.llama_factory_template = get_model_template(args.prompt_type.strip(), args.prompt_type_path.strip()) + self.llama_factory_template = get_model_template(args.prompt_type.strip(), args.prompt_type_path.strip(), args.enable_thinking) def _filter(self, sample): chosen_messages = sample["prompt"] + [sample["response"][0]] @@ -612,7 +612,8 @@ class PPOAlpacaStyleInstructionHandler(BaseDatasetHandler): self.args.output_prefix = self.args.output_prefix + "_packed" self.ignored_label = -100 self.is_multi_turn = True - self.llama_factory_template = get_model_template(args.prompt_type.strip(), args.prompt_type_path.strip()) + self.llama_factory_template = get_model_template(args.prompt_type.strip(), args.prompt_type_path.strip(), args.enable_thinking) + self.cutoff_len=args.seq_length def _format_msg(self, sample): return sample @@ -632,7 +633,7 @@ class PPOAlpacaStyleInstructionHandler(BaseDatasetHandler): messages = example["prompt"] + example["response"] for source_ids, _ in self.llama_factory_template.encode_multiturn( - tokenizer, messages, example["system"][0], example["tools"][0] + tokenizer, messages, example["system"][0], example["tools"][0], self.cutoff_len ): input_ids += source_ids @@ -671,7 +672,8 @@ class 
R1AlpacaStyleInstructionHandler(BaseDatasetHandler): self.args.output_prefix = self.args.output_prefix + "_packed" self.ignored_label = -100 self.is_multi_turn = True - self.llama_factory_template = get_model_template(args.prompt_type.strip(), args.prompt_type_path.strip()) + self.llama_factory_template = get_model_template(args.prompt_type.strip(), args.prompt_type_path.strip(), args.enable_thinking) + self.cutoff_len=args.seq_length def _format_msg(self, sample): return sample @@ -691,7 +693,7 @@ class R1AlpacaStyleInstructionHandler(BaseDatasetHandler): messages = example["prompt"] + example["response"] multiturn_input_ids = self.llama_factory_template.encode_multiturn( - tokenizer, messages, example["system"][0], example["tools"][0]) + tokenizer, messages, example["system"][0], example["tools"][0], self.cutoff_len) turns = len(multiturn_input_ids) diff --git a/mindspeed_llm/tasks/preprocess/decoder_packed_mtf_dataset.py b/mindspeed_llm/tasks/preprocess/decoder_packed_mtf_dataset.py index 980b56939..49670bfb8 100644 --- a/mindspeed_llm/tasks/preprocess/decoder_packed_mtf_dataset.py +++ b/mindspeed_llm/tasks/preprocess/decoder_packed_mtf_dataset.py @@ -246,7 +246,7 @@ class DecoderPackedMTFDataset(torch.utils.data.Dataset): template = None # get model chat template if hasattr(self.args, "prompt_type") and self.args.prompt_type is not None: - template = get_model_template(self.args.prompt_type, self.args.prompt_type_path) + template = get_model_template(self.args.prompt_type, self.args.prompt_type_path, self.args.enable_thinking) prompt_begin_list, prompt_end_list = get_prompt_index(item["labels"], IGNORE_INDEX) diff --git a/mindspeed_llm/tasks/preprocess/templates.py b/mindspeed_llm/tasks/preprocess/templates.py index 57aa726bf..16d4eb8c8 100644 --- a/mindspeed_llm/tasks/preprocess/templates.py +++ b/mindspeed_llm/tasks/preprocess/templates.py @@ -14,9 +14,11 @@ # limitations under the License. 
import os +import sys import re import json import logging +from copy import deepcopy from pathlib import Path from dataclasses import dataclass from enum import Enum, unique @@ -77,16 +79,6 @@ class Role(str, Enum): OBSERVATION = "observation" -def infer_max_len(source_len: int, target_len: int, max_len: int, reserved_label_len: int) -> Tuple[int, int]: - if source_len + target_len == 0: - max_target_len = 0 - else: - max_target_len = int(max_len * (target_len / (source_len + target_len))) - max_target_len = max(max_target_len, reserved_label_len) - max_source_len = max_len - min(max_target_len, target_len) - return max_source_len, max_target_len - - @dataclass class Template: format_user: "Formatter" @@ -99,9 +91,11 @@ class Template: format_prefix: "Formatter" default_system: str stop_words: List[str] + thought_words: tuple[str, str] efficient_eos: bool replace_eos: bool force_system: bool + enable_thinking: Optional[bool] def encode_oneturn( @@ -211,17 +205,19 @@ class Template: cutoff_len: int, reserved_label_len: int, ) -> Sequence[Tuple[List[int], List[int]]]: + from .decoder_packed_mtf_dataset import _infer_seqlen + encoded_pairs = [] total_length = 0 + cutoff_len=cutoff_len-reserved_label_len for i in range(0, len(encoded_messages), 2): if total_length >= cutoff_len: break - max_source_len, max_target_len = infer_max_len( + max_source_len, max_target_len = _infer_seqlen( source_len=len(encoded_messages[i]), target_len=len(encoded_messages[i + 1]), - max_len=(cutoff_len - total_length), - reserved_label_len=reserved_label_len, + cutoff_len=(cutoff_len - total_length) ) source_ids = encoded_messages[i][:max_source_len] target_ids = encoded_messages[i + 1][:max_target_len] @@ -231,6 +227,21 @@ class Template: return encoded_pairs + def add_thought(self, content: str = "") -> str: + r"""Add empty thought to assistant message.""" + return f"{self.thought_words[0]}\n\n{self.thought_words[1]}\n\n" + content + + + def remove_thought(self, content: str) -> str: + r"""Remove thought from assistant message.""" + pattern = re.compile(f"{re.escape(self.thought_words[0])}(.*?){re.escape(self.thought_words[1])}", re.DOTALL) + return re.sub(pattern, "", content).lstrip("\n") + + + def get_thought_word_ids(self, tokenizer: "PreTrainedTokenizer") -> list[int]: + r"""Get the token ids of thought words.""" + return tokenizer.encode(self.add_thought(), add_special_tokens=False) + @dataclass class Llama2Template(Template): def _encode( @@ -275,6 +286,105 @@ class Llama2Template(Template): return self._make_pairs(encoded_messages, cutoff_len, reserved_label_len) +@dataclass +class ReasoningTemplate(Template): + r"""A template that add thought to assistant message.""" + + def _encode( + self, + tokenizer: "PreTrainedTokenizer", + messages: List[Dict[str, str]], + system: str, + tools: str, + ) -> Sequence[list[int]]: + r""" + Encodes formatted inputs to pairs of token ids. 
+ Turn 0: prefix + system + query resp + Turn t: sep + query resp + """ + system = system or self.default_system + encoded_messages = [] + for i, message in enumerate(messages): + elements = [] + if i == 0: + elements += self.format_prefix.apply() + if system or tools: + tool_text = self.format_tools.apply(content=tools)[0] if tools else "" + elements += self.format_system.apply(content=(system + tool_text)) + elif i > 0 and i % 2 == 0: + elements += self.format_separator.apply() + + if message["role"] == Role.USER.value: + elements += self.format_user.apply(content=message["content"], idx=str(i // 2)) + elif message["role"] == Role.ASSISTANT.value: + elements += self.format_assistant.apply(content=message["content"]) + elif message["role"] == Role.OBSERVATION.value: + elements += self.format_observation.apply(content=message["content"]) + elif message["role"] == Role.FUNCTION.value: + elements += self.format_function.apply(content=message["content"]) + else: + raise NotImplementedError("Unexpected role: {}".format(message["role"])) + encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements)) + + return encoded_messages + + def encode_oneturn( + self, + tokenizer: "PreTrainedTokenizer", + messages: list[dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + cutoff_len: int = 1_000_000, + reserved_label_len: int = 1, + ) -> Tuple[list[int], list[int]]: + messages = deepcopy(messages) + for i in range(1, len(messages) - 2, 2): + messages[i]["content"] = self.remove_thought(messages[i]["content"]) + + if self.enable_thinking is False: # remove all cot + messages[-1]["content"] = self.remove_thought(messages[-1]["content"]) + + prompt_ids, response_ids = super().encode_oneturn(tokenizer, messages, system, tools) + if ( + self.thought_words[0] not in messages[-1]["content"] + and self.thought_words[1] not in messages[-1]["content"] + ): # add empty cot + if not self.enable_thinking: # do not compute loss + prompt_ids += self.get_thought_word_ids(tokenizer) + else: # do compute loss + response_ids = self.get_thought_word_ids(tokenizer) + response_ids + + return prompt_ids, response_ids + + def encode_multiturn( + self, + tokenizer: "PreTrainedTokenizer", + messages: list[dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + cutoff_len: int = 1_000_000, + reserved_label_len: int = 1, + ) -> Sequence[Tuple[List[int], List[int]]]: + messages = deepcopy(messages) + + if self.enable_thinking is False: # remove all cot + for i in range(1, len(messages), 2): + messages[i]["content"] = self.remove_thought(messages[i]["content"]) + + encoded_messages = self._encode(tokenizer, messages, system, tools) + + for i in range(0, len(messages), 2): + if ( + self.thought_words[0] not in messages[i + 1]["content"] + and self.thought_words[1] not in messages[i + 1]["content"] + ): # add empty cot + if not self.enable_thinking: # do not compute loss + encoded_messages[i] += self.get_thought_word_ids(tokenizer) + else: # do compute loss + encoded_messages[i + 1] = self.get_thought_word_ids(tokenizer) + encoded_messages[i + 1] + + return self._make_pairs(encoded_messages, cutoff_len, reserved_label_len) + templates: Dict[str, Template] = {} @@ -283,8 +393,8 @@ def get_templates() -> Dict[str, Template]: return templates -def get_model_template(name, prompt_type_path): - name = register_custom_template(name, prompt_type_path) +def get_model_template(name, prompt_type_path, enable_thinking): + name = register_custom_template(name, 
prompt_type_path,enable_thinking)
     if name is None:
         template = templates["empty"]  # placeholder
     else:
@@ -298,8 +408,9 @@ def fix_model_tokenizer(
     tokenizer: "PreTrainedTokenizer",
     name: Optional[str] = None,
     prompt_type_path: Optional[str] = None,
+    enable_thinking: Optional[bool] = False,
 ):
-    template = get_model_template(name, prompt_type_path)
+    template = get_model_template(name, prompt_type_path, enable_thinking)
     stop_words = template.stop_words
 
     if template.replace_eos:
@@ -337,9 +448,12 @@ def _register_template(
     format_prefix: Optional["Formatter"] = None,
     default_system: str = "",
     stop_words: List[str] = [],
+    thought_words: Optional[tuple[str, str]] = None,
     efficient_eos: bool = False,
     replace_eos: bool = False,
     force_system: bool = False,
+    enable_thinking: Optional[bool] = True,
+    template_class: type["Template"] = Template,
 ) -> None:
     r"""
     Registers a chat template.
@@ -368,7 +482,6 @@
     ```
     """
     eos_slots = [] if efficient_eos else [{"eos_token"}]
-    template_class = Llama2Template if name.startswith("llama2") else Template
     default_user_formatter = StringFormatter(slots=["{{content}}"])
     default_assistant_formatter = StringFormatter(slots=["{{content}}"] + eos_slots)
     default_function_formatter = FunctionFormatter(slots=["Action: {{name}}\nAction Input: {{arguments}}"] + eos_slots)
@@ -386,9 +499,11 @@
         format_prefix=format_prefix or default_prefix_formatter,
         default_system=default_system,
         stop_words=stop_words,
+        thought_words=thought_words or ("<think>", "</think>"),
         efficient_eos=efficient_eos,
         replace_eos=replace_eos,
         force_system=force_system,
+        enable_thinking=enable_thinking,
     )
 
 
@@ -472,7 +587,7 @@ def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer")
     return jinja_template
 
 
-def register_custom_template(name, json_file_path=TEMPLATES_DIR) -> str:
+def register_custom_template(name, json_file_path=TEMPLATES_DIR, enable_thinking=False) -> str:
     if name in templates:
         return name
 
@@ -501,6 +616,7 @@ def register_custom_template(name, json_file_path=TEMPLATES_DIR, enable_thinking
     efficient_eos = _format_custom_template(config.get("efficient_eos", False))
     replace_eos = _format_custom_template(config.get("replace_eos", False))
     force_system = _format_custom_template(config.get("force_system", False))
+    template_class = _format_custom_template(config.get("template_class",None))
 
     if isinstance(default_system, list):
         default_system = "".join(default_system) if all(isinstance(sentence, str) for sentence in default_system) else default_system
@@ -512,8 +628,8 @@ def register_custom_template(name, json_file_path=TEMPLATES_DIR, enable_thinking
     format_tools = ToolFormatter(**format_tools) if format_tools else None
     format_separator = EmptyFormatter(**format_separator) if format_separator else None
     format_prefix = EmptyFormatter(**format_prefix) if format_prefix else None
-
-
+    template_class = _get_template_class(template_class) if template_class else Template
+
     _register_template(
         name=name,
         format_user=format_user,
@@ -526,9 +642,12 @@ def register_custom_template(name, json_file_path=TEMPLATES_DIR, enable_thinking
         format_prefix=format_prefix,
         default_system=default_system,
         stop_words=stop_words,
+        thought_words=("<think>", "</think>"),
         efficient_eos=efficient_eos,
         replace_eos=replace_eos,
-        force_system=force_system
+        force_system=force_system,
+        enable_thinking=enable_thinking,
+        template_class=template_class
     )
 
     return name
@@ -538,4 +657,15 @@ def _format_custom_template(slots: Dict) -> Dict:
     if slots and isinstance(slots, Dict):
         for key, slot in slots.items():
             slots[key]
= list(map(lambda slot: set(slot) if isinstance(slot, list) else slot, slot)) if slot else None - return slots \ No newline at end of file + return slots + +def _get_template_class(template_name: str) -> None: + current_module = sys.modules.get(__name__) + if not current_module: + raise Exception("curent module not found") + template_class = getattr(current_module, template_name, None) + if template_class is None: + template_class = Template + logger.info("template will use %s to format dataset", template_class.__name__) + + return template_class \ No newline at end of file diff --git a/mindspeed_llm/training/arguments.py b/mindspeed_llm/training/arguments.py index c7538a9df..7e7d2e077 100644 --- a/mindspeed_llm/training/arguments.py +++ b/mindspeed_llm/training/arguments.py @@ -454,6 +454,9 @@ def _add_data_args(parser): group.add_argument('--padded-samples', action='store_true', help='fill in the missing samples within an epoch, ' 'starting at index 0, aligned with the LlamaFatory.') + group.add_argument("--enable_thinking", action='store_true', + help="Whether or not to enable thinking mode for reasoning models.") + return parser @@ -825,7 +828,7 @@ def _add_training_args(parser): group.add_argument('--prompt-type', type=str, default=None, choices=['default', 'empty', 'trl', 'chatglm2', 'chatglm3', 'chatglm3_system', 'glm4', 'chatml', 'chatml_de', 'qwen', 'qwen_r1', "qwen_math_r1", 'llama3', 'llama2', 'mistral', 'mixtral', 'gemma', 'alpaca', - 'deepseek2', 'deepseek2-lite', 'minicpm3', 'cpm', 'baichuan2', 'deepseek3', 'intern2', 'hunyuan'], + 'deepseek2', 'deepseek2-lite', 'minicpm3', 'cpm', 'baichuan2', 'deepseek3', 'intern2', 'hunyuan', 'qwen3'], help='Which template to use for constructing prompts in training/inference.' 'e.g., "qwen"') group.add_argument('--prompt-type-path', type=str, default=TEMPLATES_DIR, help='Path to the json file of templates.') diff --git a/mindspeed_llm/training/tokenizer/tokenizer.py b/mindspeed_llm/training/tokenizer/tokenizer.py index 502af3fc1..b82b09427 100644 --- a/mindspeed_llm/training/tokenizer/tokenizer.py +++ b/mindspeed_llm/training/tokenizer/tokenizer.py @@ -66,7 +66,7 @@ def build_tokenizer(args): if ("PreTrainedTokenizerBase" not in str(tokenizer.tokenizer._pad.__func__)): tokenizer.tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer.tokenizer) tokenizer.tokenizer.padding_side = "right" - fix_model_tokenizer(tokenizer.tokenizer, args.prompt_type.strip(), args.prompt_type_path.strip()) + fix_model_tokenizer(tokenizer.tokenizer, args.prompt_type.strip(), args.prompt_type_path.strip(), args.enable_thinking) if args.tokenizer_type == "GPTSentencePieceTokenizer": tokenizer.tokenizer.eos_token_id = tokenizer.tokenizer._eos_id diff --git a/preprocess_data.py b/preprocess_data.py index 11cbdf2cd..45bfbd38b 100644 --- a/preprocess_data.py +++ b/preprocess_data.py @@ -113,7 +113,7 @@ def add_data_args(parser): group.add_argument('--prompt-type', type=str, default=None, choices=['default', 'empty', 'trl', 'chatglm2', 'chatglm3', 'chatglm3_system', 'glm4', 'chatml', 'chatml_de', 'qwen', 'qwen_r1', "qwen_math_r1", 'llama3', 'llama2', 'mistral', 'mixtral', 'gemma', 'alpaca', - 'deepseek2', 'deepseek2-lite', 'cpm', 'baichuan2', 'deepseek3', 'intern2', 'hunyuan'], + 'deepseek2', 'deepseek2-lite', 'cpm', 'baichuan2', 'deepseek3', 'intern2', 'hunyuan', 'qwen3'], help='Which template to use for constructing prompts in training.' 
'e.g., "qwen"') group.add_argument('--prompt-type-path', type=str, default=TEMPLATES_DIR, @@ -150,7 +150,8 @@ def add_data_args(parser): help="Use a zigzag attention mask.") group.add_argument("--script-data-dir", type=str, default=None, help="Python script dataset direction") - + group.add_argument("--enable_thinking", action='store_true', + help="Whether or not to enable thinking mode for reasoning models.") def add_tokenizer_args(parser): group = parser.add_argument_group(title='tokenizer') -- Gitee From 0bad638b449c8fe4c7d8c291682553bd6048c965 Mon Sep 17 00:00:00 2001 From: HanhuiChen Date: Sat, 21 Jun 2025 06:52:28 +0000 Subject: [PATCH 02/14] update mindspeed_llm/tasks/posttrain/base/base_trainer.py. remove debug code Signed-off-by: HanhuiChen --- mindspeed_llm/tasks/posttrain/base/base_trainer.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/mindspeed_llm/tasks/posttrain/base/base_trainer.py b/mindspeed_llm/tasks/posttrain/base/base_trainer.py index 714f13372..a46fca288 100644 --- a/mindspeed_llm/tasks/posttrain/base/base_trainer.py +++ b/mindspeed_llm/tasks/posttrain/base/base_trainer.py @@ -25,7 +25,6 @@ from mindspeed_llm.training import build_train_args from mindspeed_llm.training import train from mindspeed_llm.training.initialize import set_jit_fusion_options from mindspeed_llm.tasks.posttrain.utils import train_valid_test_datasets_provider -from hook_test import hook_for_model, register_token_id_hook _TRAIN_START_TIME = time.time() @@ -156,8 +155,6 @@ class BaseTrainer(ABC): post_process=post_process ) - # hook_for_model(model) - # register_token_id_hook(model) return model @staticmethod -- Gitee From d3a31405bfdba5956b4aa7c2df2ae1cd2a3e567a Mon Sep 17 00:00:00 2001 From: HanhuiChen Date: Sat, 21 Jun 2025 06:55:35 +0000 Subject: [PATCH 03/14] update mindspeed_llm/tasks/posttrain/base/base_trainer.py. Signed-off-by: HanhuiChen --- mindspeed_llm/tasks/posttrain/base/base_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mindspeed_llm/tasks/posttrain/base/base_trainer.py b/mindspeed_llm/tasks/posttrain/base/base_trainer.py index a46fca288..4451414cf 100644 --- a/mindspeed_llm/tasks/posttrain/base/base_trainer.py +++ b/mindspeed_llm/tasks/posttrain/base/base_trainer.py @@ -26,6 +26,7 @@ from mindspeed_llm.training import train from mindspeed_llm.training.initialize import set_jit_fusion_options from mindspeed_llm.tasks.posttrain.utils import train_valid_test_datasets_provider + _TRAIN_START_TIME = time.time() @@ -154,7 +155,6 @@ class BaseTrainer(ABC): pre_process=pre_process, post_process=post_process ) - return model @staticmethod -- Gitee From c09147d4855c4a60637accfd410bb66bf71da9ac Mon Sep 17 00:00:00 2001 From: HanhuiChen Date: Sat, 21 Jun 2025 06:56:20 +0000 Subject: [PATCH 04/14] update mindspeed_llm/tasks/posttrain/sft/sft_trainer.py. 
Signed-off-by: HanhuiChen
---
 mindspeed_llm/tasks/posttrain/sft/sft_trainer.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/mindspeed_llm/tasks/posttrain/sft/sft_trainer.py b/mindspeed_llm/tasks/posttrain/sft/sft_trainer.py
index a05b800dc..5837cc701 100644
--- a/mindspeed_llm/tasks/posttrain/sft/sft_trainer.py
+++ b/mindspeed_llm/tasks/posttrain/sft/sft_trainer.py
@@ -118,7 +118,6 @@ class SFTTrainer(BaseTrainer):
         loss_mask = input_tensor
 
         losses = output_tensor.float()
-        loss_mask = loss_mask[..., 1:].view(-1).float()
 
         if args.context_parallel_size > 1:
             loss = torch.cat([torch.sum(losses.view(-1) * loss_mask).view(1), loss_mask.sum().view(1)])
@@ -191,4 +190,4 @@ def forward_step_in_sft_with_dualpipe(data_iterator, model, extra_block_kwargs=N
     else:
         output_tensor, model_graph = model(
             tokens, position_ids, attention_mask, labels=labels, loss_mask=loss_mask)
-    return (output_tensor, model_graph), partial(SFTTrainer.loss_func, loss_mask)
+    return (output_tensor, model_graph), partial(SFTTrainer.loss_func, loss_mask)
\ No newline at end of file
--
Gitee

From 4ca4428fcd9f13ae49a8520fc3b653b7d75537d5 Mon Sep 17 00:00:00 2001
From: HanhuiChen
Date: Sat, 21 Jun 2025 06:57:41 +0000
Subject: [PATCH 05/14] update mindspeed_llm/tasks/posttrain/base/base_trainer.py.

Signed-off-by: HanhuiChen
---
 mindspeed_llm/tasks/posttrain/base/base_trainer.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/mindspeed_llm/tasks/posttrain/base/base_trainer.py b/mindspeed_llm/tasks/posttrain/base/base_trainer.py
index 4451414cf..16b69da2b 100644
--- a/mindspeed_llm/tasks/posttrain/base/base_trainer.py
+++ b/mindspeed_llm/tasks/posttrain/base/base_trainer.py
@@ -155,6 +155,7 @@ class BaseTrainer(ABC):
             pre_process=pre_process,
             post_process=post_process
         )
+
         return model
 
     @staticmethod
--
Gitee

From 2e6d0db8836d5cef0cae3862b5586c3a7c55e842 Mon Sep 17 00:00:00 2001
From: HanhuiChen
Date: Sat, 21 Jun 2025 15:09:44 +0800
Subject: [PATCH 06/14] update template.json

---
 configs/finetune/templates.json | 34 ++++++++++++++++++++++++++++++++-
 1 file changed, 33 insertions(+), 1 deletion(-)

diff --git a/configs/finetune/templates.json b/configs/finetune/templates.json
index 8310cc61a..4553251ab 100644
--- a/configs/finetune/templates.json
+++ b/configs/finetune/templates.json
@@ -445,7 +445,8 @@
       "slots": [
         "<<SYS>>\n{{content}}\n<</SYS>>\n\n"
       ]
-    }
+    },
+    "template_class": "Llama2Template"
   },
   {
     "name": "alpaca",
@@ -576,5 +577,36 @@
       ]
     },
     "default_system": "You are a helpful assistant."
+  },
+  {
+    "name": "qwen3",
+    "format_user": {
+      "slots": [
+        "<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"
+      ]
+    },
+    "format_assistant": {
+      "slots": [
+        "{{content}}<|im_end|>\n"
+      ]
+    },
+    "format_system": {
+      "slots": [
+        "<|im_start|>system\n{{content}}<|im_end|>\n"
+      ]
+    },
+    "format_observation": {
+      "slots": [
+        "<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"
+      ]
+    },
+    "format_tools": {
+      "tool_format": "qwen"
+    },
+    "stop_words": [
+      "<|im_end|>"
+    ],
+    "replace_eos": true,
+    "template_class": "ReasoningTemplate"
   }
 ]
\ No newline at end of file
--
Gitee

From fce83d1ad066b622ce4dae7c4ded3ea76103f452 Mon Sep 17 00:00:00 2001
From: HanhuiChen
Date: Mon, 23 Jun 2025 02:45:01 +0000
Subject: [PATCH 07/14] update mindspeed_llm/tasks/preprocess/templates.py.
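
Style-only follow-up: normalize spacing around "=" in keyword arguments and
config lookups, and add the PEP 8 two-blank-line separation between
top-level definitions. No functional change.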
Signed-off-by: HanhuiChen --- mindspeed_llm/tasks/preprocess/templates.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/mindspeed_llm/tasks/preprocess/templates.py b/mindspeed_llm/tasks/preprocess/templates.py index 16d4eb8c8..36475ed54 100644 --- a/mindspeed_llm/tasks/preprocess/templates.py +++ b/mindspeed_llm/tasks/preprocess/templates.py @@ -209,7 +209,7 @@ class Template: encoded_pairs = [] total_length = 0 - cutoff_len=cutoff_len-reserved_label_len + cutoff_len = cutoff_len - reserved_label_len for i in range(0, len(encoded_messages), 2): if total_length >= cutoff_len: break @@ -242,6 +242,7 @@ class Template: r"""Get the token ids of thought words.""" return tokenizer.encode(self.add_thought(), add_special_tokens=False) + @dataclass class Llama2Template(Template): def _encode( @@ -286,6 +287,7 @@ class Llama2Template(Template): return self._make_pairs(encoded_messages, cutoff_len, reserved_label_len) + @dataclass class ReasoningTemplate(Template): r"""A template that add thought to assistant message.""" @@ -394,7 +396,7 @@ def get_templates() -> Dict[str, Template]: def get_model_template(name, prompt_type_path, enable_thinking): - name = register_custom_template(name, prompt_type_path,enable_thinking) + name = register_custom_template(name, prompt_type_path, enable_thinking) if name is None: template = templates["empty"] # placeholder else: @@ -616,7 +618,7 @@ def register_custom_template(name, json_file_path=TEMPLATES_DIR, enable_thinking efficient_eos = _format_custom_template(config.get("efficient_eos", False)) replace_eos = _format_custom_template(config.get("replace_eos", False)) force_system = _format_custom_template(config.get("force_system", False)) - template_class = _format_custom_template(config.get("template_class",None)) + template_class = _format_custom_template(config.get("template_class", None)) if isinstance(default_system, list): default_system = "".join(default_system) if all(isinstance(sentence, str) for sentence in default_system) else default_system @@ -659,6 +661,7 @@ def _format_custom_template(slots: Dict) -> Dict: slots[key] = list(map(lambda slot: set(slot) if isinstance(slot, list) else slot, slot)) if slot else None return slots + def _get_template_class(template_name: str) -> None: current_module = sys.modules.get(__name__) if not current_module: -- Gitee From a788cbbc2ebf6ad0eca2c572ef289e0bcc3b1b6a Mon Sep 17 00:00:00 2001 From: HanhuiChen Date: Mon, 23 Jun 2025 02:47:29 +0000 Subject: [PATCH 08/14] update mindspeed_llm/tasks/preprocess/data_handler.py. Signed-off-by: HanhuiChen --- mindspeed_llm/tasks/preprocess/data_handler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mindspeed_llm/tasks/preprocess/data_handler.py b/mindspeed_llm/tasks/preprocess/data_handler.py index b2372ffa7..7189ee395 100644 --- a/mindspeed_llm/tasks/preprocess/data_handler.py +++ b/mindspeed_llm/tasks/preprocess/data_handler.py @@ -343,7 +343,7 @@ class LlamaFactoryInstructionHandler(BaseDatasetHandler): self.ignored_label = -100 self.is_multi_turn = True self.llama_factory_template = get_model_template(args.prompt_type.strip(), args.prompt_type_path.strip(), args.enable_thinking) - self.cutoff_len=args.seq_length + self.cutoff_len = args.seq_length def _format_msg(self, sample): return sample -- Gitee From cd26a158e8268162641e5be9dc54e0ecb42a271b Mon Sep 17 00:00:00 2001 From: HanhuiChen Date: Mon, 23 Jun 2025 08:19:01 +0000 Subject: [PATCH 09/14] update mindspeed_llm/tasks/preprocess/data_handler.py. 
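
Apply the same spacing fix to the remaining handlers: the cutoff_len
assignments in PPOAlpacaStyleInstructionHandler and
R1AlpacaStyleInstructionHandler now match LlamaFactoryInstructionHandler.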
Signed-off-by: HanhuiChen --- mindspeed_llm/tasks/preprocess/data_handler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mindspeed_llm/tasks/preprocess/data_handler.py b/mindspeed_llm/tasks/preprocess/data_handler.py index 7189ee395..07b10676a 100644 --- a/mindspeed_llm/tasks/preprocess/data_handler.py +++ b/mindspeed_llm/tasks/preprocess/data_handler.py @@ -622,7 +622,7 @@ class PPOAlpacaStyleInstructionHandler(BaseDatasetHandler): self.ignored_label = -100 self.is_multi_turn = True self.llama_factory_template = get_model_template(args.prompt_type.strip(), args.prompt_type_path.strip(), args.enable_thinking) - self.cutoff_len=args.seq_length + self.cutoff_len = args.seq_length def _format_msg(self, sample): return sample @@ -682,7 +682,7 @@ class R1AlpacaStyleInstructionHandler(BaseDatasetHandler): self.ignored_label = -100 self.is_multi_turn = True self.llama_factory_template = get_model_template(args.prompt_type.strip(), args.prompt_type_path.strip(), args.enable_thinking) - self.cutoff_len=args.seq_length + self.cutoff_len = args.seq_length def _format_msg(self, sample): return sample -- Gitee From 824acd2495aede4b38282723715989a34cff21bb Mon Sep 17 00:00:00 2001 From: HanhuiChen Date: Wed, 25 Jun 2025 09:28:42 +0000 Subject: [PATCH 10/14] update preprocess_data.py. Signed-off-by: HanhuiChen --- preprocess_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/preprocess_data.py b/preprocess_data.py index 763ac3750..9613db334 100644 --- a/preprocess_data.py +++ b/preprocess_data.py @@ -150,7 +150,7 @@ def add_data_args(parser): help="Use a zigzag attention mask.") group.add_argument("--script-data-dir", type=str, default=None, help="Python script dataset direction") - group.add_argument("--enable_thinking", action='store_true', + group.add_argument("--enable-thinking", action='store_true', help="Whether or not to enable thinking mode for reasoning models.") group.add_argument("--pad-to-multiple-of", type=int, default=1, help="Pad each of the data to the multiple of...") -- Gitee From 331e0c6a88b6d236d1f1fabc5c306759f9430a39 Mon Sep 17 00:00:00 2001 From: HanhuiChen Date: Wed, 25 Jun 2025 09:29:33 +0000 Subject: [PATCH 11/14] update mindspeed_llm/training/arguments.py. Signed-off-by: HanhuiChen --- mindspeed_llm/training/arguments.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mindspeed_llm/training/arguments.py b/mindspeed_llm/training/arguments.py index 474eead4f..cbfc304c8 100644 --- a/mindspeed_llm/training/arguments.py +++ b/mindspeed_llm/training/arguments.py @@ -467,7 +467,7 @@ def _add_data_args(parser): group.add_argument('--padded-samples', action='store_true', help='fill in the missing samples within an epoch, ' 'starting at index 0, aligned with the LlamaFatory.') - group.add_argument("--enable_thinking", action='store_true', + group.add_argument("--enable-thinking", action='store_true', help="Whether or not to enable thinking mode for reasoning models.") return parser -- Gitee From d6068a24d646fb26947b6137ebffec800667c9a8 Mon Sep 17 00:00:00 2001 From: HanhuiChen Date: Sat, 28 Jun 2025 03:39:46 +0000 Subject: [PATCH 12/14] update mindspeed_llm/tasks/preprocess/templates.py. 
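
Restore the original infer_max_len truncation for the base Template so the
existing prompt types keep their behavior, and confine the _infer_seqlen
variant to a _make_pairs override on ReasoningTemplate. Both helpers split a
token budget between source and target in proportion to their lengths; a
minimal sketch of the restored base behavior (the lengths below are made up
for illustration, not taken from a test):

    # illustrative numbers only
    source_len, target_len = 600, 200      # prompt/response token counts
    max_len, reserved_label_len = 512, 1   # remaining budget for this turn
    max_target_len = max(int(max_len * target_len / (source_len + target_len)),
                         reserved_label_len)                    # int(512 * 0.25) -> 128
    max_source_len = max_len - min(max_target_len, target_len)  # 512 - 128 -> 384

so a 600/200-token pair is truncated to 384/128 tokens before packing.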
Signed-off-by: HanhuiChen --- mindspeed_llm/tasks/preprocess/templates.py | 45 ++++++++++++++++++--- 1 file changed, 40 insertions(+), 5 deletions(-) diff --git a/mindspeed_llm/tasks/preprocess/templates.py b/mindspeed_llm/tasks/preprocess/templates.py index 36475ed54..4abb7cbd3 100644 --- a/mindspeed_llm/tasks/preprocess/templates.py +++ b/mindspeed_llm/tasks/preprocess/templates.py @@ -79,6 +79,16 @@ class Role(str, Enum): OBSERVATION = "observation" +def infer_max_len(source_len: int, target_len: int, max_len: int, reserved_label_len: int) -> Tuple[int, int]: + if source_len + target_len == 0: + max_target_len = 0 + else: + max_target_len = int(max_len * (target_len / (source_len + target_len))) + max_target_len = max(max_target_len, reserved_label_len) + max_source_len = max_len - min(max_target_len, target_len) + return max_source_len, max_target_len + + @dataclass class Template: format_user: "Formatter" @@ -205,19 +215,17 @@ class Template: cutoff_len: int, reserved_label_len: int, ) -> Sequence[Tuple[List[int], List[int]]]: - from .decoder_packed_mtf_dataset import _infer_seqlen - encoded_pairs = [] total_length = 0 - cutoff_len = cutoff_len - reserved_label_len for i in range(0, len(encoded_messages), 2): if total_length >= cutoff_len: break - max_source_len, max_target_len = _infer_seqlen( + max_source_len, max_target_len = infer_max_len( source_len=len(encoded_messages[i]), target_len=len(encoded_messages[i + 1]), - cutoff_len=(cutoff_len - total_length) + max_len=(cutoff_len - total_length), + reserved_label_len=reserved_label_len, ) source_ids = encoded_messages[i][:max_source_len] target_ids = encoded_messages[i + 1][:max_target_len] @@ -387,6 +395,33 @@ class ReasoningTemplate(Template): return self._make_pairs(encoded_messages, cutoff_len, reserved_label_len) + def _make_pairs( + self, + encoded_messages: Sequence[List[int]], + cutoff_len: int, + reserved_label_len: int, + ) -> Sequence[Tuple[List[int], List[int]]]: + from .decoder_packed_mtf_dataset import _infer_seqlen + + encoded_pairs = [] + total_length = 0 + cutoff_len = cutoff_len - reserved_label_len + for i in range(0, len(encoded_messages), 2): + if total_length >= cutoff_len: + break + + max_source_len, max_target_len = _infer_seqlen( + source_len=len(encoded_messages[i]), + target_len=len(encoded_messages[i + 1]), + cutoff_len=(cutoff_len - total_length) + ) + source_ids = encoded_messages[i][:max_source_len] + target_ids = encoded_messages[i + 1][:max_target_len] + total_length += len(source_ids) + len(target_ids) + encoded_pairs.append((source_ids, target_ids)) + + return encoded_pairs + templates: Dict[str, Template] = {} -- Gitee From b9a1e785602872a5297ffdb72554e2f73a360392 Mon Sep 17 00:00:00 2001 From: HanhuiChen Date: Mon, 30 Jun 2025 07:43:51 +0000 Subject: [PATCH 13/14] Update function and tool formatter of Qwen3. 
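
Teach the qwen3 template Qwen's tool-calling convention: tool signatures are
rendered inside <tools>...</tools> XML tags in the system prompt, and every
function message becomes a <tool_call> JSON block. A quick round trip through
the new helpers (the function name and arguments are made up for
illustration):

    >>> QwenToolUtils.function_formatter([("get_weather", '{"city": "Beijing"}')])
    '<tool_call>\n{"name": "get_weather", "arguments": {"city": "Beijing"}}\n</tool_call>'
    >>> QwenToolUtils.tool_extractor(
    ...     '<tool_call>\n{"name": "get_weather", "arguments": {"city": "Beijing"}}\n</tool_call>')
    [('get_weather', '{"city": "Beijing"}')]

tool_extractor is the inverse of function_formatter, and
register_custom_template now selects Qwen3FunctionFormatter and
Qwen3ToolFormatter whenever the template name is "qwen3".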
Signed-off-by: HanhuiChen
---
 configs/finetune/templates.json             |   6 +
 mindspeed_llm/tasks/preprocess/formatter.py | 126 ++++++++++++++++++++
 mindspeed_llm/tasks/preprocess/templates.py |  12 +-
 3 files changed, 140 insertions(+), 4 deletions(-)

diff --git a/configs/finetune/templates.json b/configs/finetune/templates.json
index 4553251ab..0b4a3425c 100644
--- a/configs/finetune/templates.json
+++ b/configs/finetune/templates.json
@@ -595,6 +595,12 @@
         "<|im_start|>system\n{{content}}<|im_end|>\n"
       ]
     },
+    "format_function": {
+      "slots": [
+        "{{content}}<|im_end|>\n"
+      ],
+      "tool_format": "qwen"
+    },
     "format_observation": {
       "slots": [
         "<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"
       ]
     },
diff --git a/mindspeed_llm/tasks/preprocess/formatter.py b/mindspeed_llm/tasks/preprocess/formatter.py
index aaec643c1..5edc7aada 100644
--- a/mindspeed_llm/tasks/preprocess/formatter.py
+++ b/mindspeed_llm/tasks/preprocess/formatter.py
@@ -38,6 +38,15 @@ TOOL_SYSTEM_PROMPT = (
 )
 
 
+QWEN_TOOL_PROMPT = (
+    "\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
+    "You are provided with function signatures within <tools></tools> XML tags:\n<tools>{tool_text}"
+    "\n</tools>\n\nFor each function call, return a json object with function name and arguments within "
+    """<tool_call></tool_call> XML tags:\n<tool_call>\n{{"name": <function-name>, """
+    """"arguments": <args-json-object>}}\n</tool_call>"""
+)
+
+
 def default_tool_formatter(tools: List[Dict[str, Any]]) -> str:
     tool_text = ""
     tool_names = []
@@ -201,3 +210,120 @@ class ToolFormatter(Formatter):
             return default_tool_extractor(content)
         else:
             raise NotImplementedError
+
+
+@dataclass
+class ToolUtils(ABC):
+    """Base class for tool utilities."""
+
+    @staticmethod
+    @abstractmethod
+    def tool_formatter(tools: list[dict[str, Any]]) -> str:
+        r"""Generate the system message describing all the available tools."""
+        ...
+
+    @staticmethod
+    @abstractmethod
+    def function_formatter(functions: list[Tuple[str, str]]) -> str:
+        r"""Generate the assistant message including all the tool calls."""
+        ...
+
+    @staticmethod
+    @abstractmethod
+    def tool_extractor(content: str) -> Union[str, list[Tuple[str, str]]]:
+        r"""Extract all the function calls from the assistant message.
+
+        It should be an inverse function of `function_formatter`.
+        """
+        ...
+
+
+class QwenToolUtils(ToolUtils):
+    r"""Qwen 2.5 tool using template."""
+
+    @staticmethod
+    def tool_formatter(tools: list[dict[str, Any]]) -> str:
+        tool_text = ""
+        for tool in tools:
+            wrapped_tool = tool if tool.get("type") == "function" else {"type": "function", "function": tool}
+            tool_text += "\n" + json.dumps(wrapped_tool, ensure_ascii=False)
+
+        return QWEN_TOOL_PROMPT.format(tool_text=tool_text)
+
+    @staticmethod
+    def function_formatter(functions: list[Tuple[str, str]]) -> str:
+        function_texts = [
+            json.dumps({"name": name, "arguments": json.loads(arguments)}, ensure_ascii=False)
+            for name, arguments in functions
+        ]
+        return "\n".join([f"<tool_call>\n{text}\n</tool_call>" for text in function_texts])
+
+    @staticmethod
+    def tool_extractor(content: str) -> Union[str, list[Tuple[str, str]]]:
+        regex = re.compile(r"<tool_call>(.+?)</tool_call>(?=\s*<tool_call>|\s*$)", re.DOTALL)
+        tool_match: list[str] = re.findall(regex, content)
+        if not tool_match:
+            return content
+
+        results = []
+        for tool in tool_match:
+            try:
+                tool = json.loads(tool.strip())
+            except json.JSONDecodeError:
+                return content
+
+            if "name" not in tool or "arguments" not in tool:
+                return content
+
+            results.append((tool["name"], json.dumps(tool["arguments"], ensure_ascii=False)))
+
+        return results
+
+
+@dataclass
+class Qwen3FunctionFormatter(StringFormatter):
+    def __post_init__(self):
+        super().__post_init__()
+        self.tool_utils = QwenToolUtils()
+
+    def apply(self, **kwargs) -> SLOTS:
+        content: str = kwargs.pop("content")
+        regex = re.compile(r"<think>(.*)</think>", re.DOTALL)
+        thought = re.search(regex, content)
+        if thought:
+            content = content.replace(thought.group(0), "")
+
+        functions: list[Tuple[str, str]] = []
+        try:
+            tool_calls = json.loads(content)
+            if not isinstance(tool_calls, list):  # parallel function call
+                tool_calls = [tool_calls]
+
+            for tool_call in tool_calls:
+                functions.append((tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False)))
+
+        except json.JSONDecodeError:
+            raise RuntimeError(f"Invalid JSON format in function message: {str([content])}.")  # flat string
+
+        function_str = self.tool_utils.function_formatter(functions)
+        if thought:
+            function_str = thought.group(0) + function_str
+
+        return super().apply(content=function_str)
+
+
+@dataclass
+class Qwen3ToolFormatter(Formatter):
+    def __post_init__(self):
+        self.tool_utils = QwenToolUtils()
+
+    def apply(self, **kwargs) -> SLOTS:
+        content = kwargs.pop("content")
+        try:
+            tools = json.loads(content)
+            return [self.tool_utils.tool_formatter(tools) if len(tools) != 0 else ""]
+        except json.JSONDecodeError:
+            raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}.")  # flat string
+
+    def extract(self, content: str) -> Union[str, Tuple[str, str]]:
+        return self.tool_utils.tool_extractor(content)
\ No newline at end of file
diff --git a/mindspeed_llm/tasks/preprocess/templates.py b/mindspeed_llm/tasks/preprocess/templates.py
index 4abb7cbd3..e20d4a28d 100644
--- a/mindspeed_llm/tasks/preprocess/templates.py
+++ b/mindspeed_llm/tasks/preprocess/templates.py
@@ -24,7 +24,7 @@ from dataclasses import dataclass
 from enum import Enum, unique
 from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
 
-from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
+from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter, Qwen3FunctionFormatter, Qwen3ToolFormatter
 
 if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer
@@ -660,13 +660,17 @@ def register_custom_template(name, json_file_path=TEMPLATES_DIR, enable_thinking
     format_user = StringFormatter(**format_user) if format_user else None
     format_assistant = StringFormatter(**format_assistant) if format_assistant else None
     format_system = StringFormatter(**format_system) if format_system else None
-    format_function = FunctionFormatter(**format_function) if format_function else None
     format_observation = StringFormatter(**format_observation) if format_observation else None
-    format_tools = ToolFormatter(**format_tools) if format_tools else None
     format_separator = EmptyFormatter(**format_separator) if format_separator else None
     format_prefix = EmptyFormatter(**format_prefix) if format_prefix else None
     template_class = _get_template_class(template_class) if template_class else Template
-
+    if name == 'qwen3':
+        format_function = Qwen3FunctionFormatter(**format_function) if format_function else None
+        format_tools = Qwen3ToolFormatter(**format_tools) if format_tools else None
+    else:
+        format_function = FunctionFormatter(**format_function) if format_function else None
+        format_tools = ToolFormatter(**format_tools) if format_tools else None
+
     _register_template(
         name=name,
         format_user=format_user,
--
Gitee

From 23ad7a64b624dc052afc938cec308b9fcf41d535 Mon Sep 17 00:00:00 2001
From: HanhuiChen
Date: Mon, 30 Jun 2025 07:58:39 +0000
Subject: [PATCH 14/14] update mindspeed_llm/tasks/preprocess/formatter.py.

Signed-off-by: HanhuiChen
---
 mindspeed_llm/tasks/preprocess/formatter.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/mindspeed_llm/tasks/preprocess/formatter.py b/mindspeed_llm/tasks/preprocess/formatter.py
index 5edc7aada..2fd6031d4 100644
--- a/mindspeed_llm/tasks/preprocess/formatter.py
+++ b/mindspeed_llm/tasks/preprocess/formatter.py
@@ -302,8 +302,8 @@ class Qwen3FunctionFormatter(StringFormatter):
             for tool_call in tool_calls:
                 functions.append((tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False)))
 
-        except json.JSONDecodeError:
-            raise RuntimeError(f"Invalid JSON format in function message: {str([content])}.")  # flat string
+        except json.JSONDecodeError as e:
+            raise RuntimeError(f"Invalid JSON format in function message: {str([content])}.") from e  # flat string
 
         function_str = self.tool_utils.function_formatter(functions)
         if thought:
             function_str = thought.group(0) + function_str
@@ -322,8 +322,8 @@ class Qwen3ToolFormatter(Formatter):
         try:
             tools = json.loads(content)
            return [self.tool_utils.tool_formatter(tools) if len(tools) != 0 else ""]
-        except json.JSONDecodeError:
-            raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}.")  # flat string
+        except json.JSONDecodeError as e:
+            raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}.") from e  # flat string
 
     def extract(self, content: str) -> Union[str, Tuple[str, str]]:
         return self.tool_utils.tool_extractor(content)
\ No newline at end of file
--
Gitee
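
Series usage note: with all fourteen patches applied, Qwen3 thinking-mode
preprocessing combines the new template with the new flag; a minimal
invocation sketch (paths and dataset names are placeholders):

    python preprocess_data.py \
        --input ./dataset/train.json \
        --tokenizer-type PretrainedFromHF \
        --tokenizer-name-or-path ./model_from_hf/Qwen3-8B/ \
        --handler-name AlpacaStyleInstructionHandler \
        --prompt-type qwen3 \
        --enable-thinking \
        --output-prefix ./finetune_dataset/qwen3

Without --enable-thinking, ReasoningTemplate strips <think>...</think> spans
from assistant turns and appends an empty thought to the prompt side so that
no loss is computed on it; with the flag set, the empty thought is prepended
to the response and trained on.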