diff --git a/configs/finetune/templates.json b/configs/finetune/templates.json
index 8310cc61ac4ac2aa9d916725a8b264ea46c8e65c..0b4a3425caaafe03698eb6b70467878fe63a76f6 100644
--- a/configs/finetune/templates.json
+++ b/configs/finetune/templates.json
@@ -445,7 +445,8 @@
       "slots": [
         "<<SYS>>\n{{content}}\n<</SYS>>\n\n"
       ]
-    }
+    },
+    "template_class": "Llama2Template"
   },
   {
     "name": "alpaca",
@@ -576,5 +577,42 @@
       ]
     },
     "default_system": "You are a helpful assistant."
+  },
+  {
+    "name": "qwen3",
+    "format_user": {
+      "slots": [
+        "<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"
+      ]
+    },
+    "format_assistant": {
+      "slots": [
+        "{{content}}<|im_end|>\n"
+      ]
+    },
+    "format_system": {
+      "slots": [
+        "<|im_start|>system\n{{content}}<|im_end|>\n"
+      ]
+    },
+    "format_function": {
+      "slots": [
+        "{{content}}<|im_end|>\n"
+      ],
+      "tool_format": "qwen"
+    },
+    "format_observation": {
+      "slots": [
+        "<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"
+      ]
+    },
+    "format_tools": {
+      "tool_format": "qwen"
+    },
+    "stop_words": [
+      "<|im_end|>"
+    ],
+    "replace_eos": true,
+    "template_class": "ReasoningTemplate"
   }
 ]
\ No newline at end of file
diff --git a/mindspeed_llm/inference/text_generation/tokenization.py b/mindspeed_llm/inference/text_generation/tokenization.py
index bd70250afe254a58debef9c8fb7f3c2015053367..30b5785c51767ec500b9860af565b5cbe96de54b 100644
--- a/mindspeed_llm/inference/text_generation/tokenization.py
+++ b/mindspeed_llm/inference/text_generation/tokenization.py
@@ -132,7 +132,7 @@ def _tokenize_prompts_and_batch(tokenizer, prompts, tokens_to_generate, max_gene
 
     # Tokenize all the prompts.
     if hasattr(args, "prompt_type") and args.prompt_type is not None:
-        template = get_model_template(args.prompt_type.strip(), args.prompt_type_path.strip())
+        template = get_model_template(args.prompt_type.strip(), args.prompt_type_path.strip(), args.enable_thinking)
 
     if args.hf_chat_template:
         if not hasattr(tokenizer, "apply_chat_template"):
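For reference, a minimal standalone sketch of what the qwen3 slots above render for one exchange. This is plain string substitution, not the repo's `StringFormatter` (which additionally resolves special-token slots); the conversation content is made up.

```python
# Hand-rendered qwen3 turn using the slot strings from templates.json.
FORMAT_SYSTEM = "<|im_start|>system\n{content}<|im_end|>\n"
FORMAT_USER = "<|im_start|>user\n{content}<|im_end|>\n<|im_start|>assistant\n"
FORMAT_ASSISTANT = "{content}<|im_end|>\n"

rendered = (
    FORMAT_SYSTEM.format(content="You are a helpful assistant.")
    + FORMAT_USER.format(content="What is 2 + 2?")
    # ReasoningTemplate (added below in templates.py) prepends the empty
    # "<think>\n\n</think>\n\n" block when the assistant message has no thought.
    + FORMAT_ASSISTANT.format(content="<think>\n\n</think>\n\n4")
)
print(rendered)
```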
diff --git a/mindspeed_llm/tasks/preprocess/data_handler.py b/mindspeed_llm/tasks/preprocess/data_handler.py
index 5e7166f21537cb0890f6dacc5ee931a83d6783dc..07b10676a210bc174de1bad5166022785e6b30f9 100644
--- a/mindspeed_llm/tasks/preprocess/data_handler.py
+++ b/mindspeed_llm/tasks/preprocess/data_handler.py
@@ -105,11 +105,11 @@ class BaseDatasetHandler(object):
             batch = doc["input_ids"]
             for idx, sample in enumerate(batch):
                 length = len(sample)
-                if (length > self.args.seq_length) and (not self.args.neat_pack):
-                    logger.warning(f"Dropped lengthy example with length {length} > {self.args.seq_length}.")
+                if (length >= self.args.seq_length) and (not self.args.neat_pack):
+                    logger.warning(f"Dropped lengthy example with length {length} >= {self.args.seq_length}.")
                 else:
-                    if length > self.args.seq_length:
-                        logger.warning(f"Sequence length {length} > {self.args.seq_length}.")
+                    if length >= self.args.seq_length:
+                        logger.warning(f"Sequence length {length} >= {self.args.seq_length}.")
                         sample = sample[:self.args.seq_length - 1]
                         length = len(sample)
                     lengths.append(length)
@@ -342,7 +342,8 @@ class LlamaFactoryInstructionHandler(BaseDatasetHandler):
         self.args.output_prefix = self.args.output_prefix + "_packed"
         self.ignored_label = -100
         self.is_multi_turn = True
-        self.llama_factory_template = get_model_template(args.prompt_type.strip(), args.prompt_type_path.strip())
+        self.llama_factory_template = get_model_template(args.prompt_type.strip(), args.prompt_type_path.strip(), args.enable_thinking)
+        self.cutoff_len = args.seq_length
 
     def _format_msg(self, sample):
         return sample
@@ -360,9 +361,8 @@ class LlamaFactoryInstructionHandler(BaseDatasetHandler):
             messages = [{'role': 'user', 'content': ''}, {'role': 'assistant', 'content': ''}]
         else:
             messages = example["prompt"] + example["response"]
-
         for source_ids, target_ids in self.llama_factory_template.encode_multiturn(
-            tokenizer, messages, example["system"][0], example["tools"][0]
+            tokenizer, messages, example["system"][0], example["tools"][0], self.cutoff_len
         ):
             if self.train_on_inputs:
                 source_mask = source_ids
@@ -446,7 +446,7 @@ class HunyuanInstructionHandler(BaseDatasetHandler):
         self.ignored_label = -100
         self.is_multi_turn = True
 
-        self.hunyuanlarge_template = get_model_template(args.prompt_type.strip(), args.prompt_type_path.strip())
+        self.hunyuanlarge_template = get_model_template(args.prompt_type.strip(), args.prompt_type_path.strip(), args.enable_thinking)
 
     def _format_msg(self, sample):
         return sample
@@ -511,7 +511,7 @@ class AlpacaStylePairwiseHandler(BaseDatasetHandler):
         self.args.json_keys = ["chosen_input_ids", "chosen_labels", "rejected_input_ids", "rejected_labels"]
         self.args.output_prefix = self.args.output_prefix + "_packed"
         self.ignored_label = -100
-        self.llama_factory_template = get_model_template(args.prompt_type.strip(), args.prompt_type_path.strip())
+        self.llama_factory_template = get_model_template(args.prompt_type.strip(), args.prompt_type_path.strip(), args.enable_thinking)
 
     def _filter(self, sample):
         chosen_messages = sample["prompt"] + [sample["response"][0]]
@@ -621,7 +621,8 @@ class PPOAlpacaStyleInstructionHandler(BaseDatasetHandler):
         self.args.output_prefix = self.args.output_prefix + "_packed"
         self.ignored_label = -100
         self.is_multi_turn = True
-        self.llama_factory_template = get_model_template(args.prompt_type.strip(), args.prompt_type_path.strip())
+        self.llama_factory_template = get_model_template(args.prompt_type.strip(), args.prompt_type_path.strip(), args.enable_thinking)
+        self.cutoff_len = args.seq_length
 
     def _format_msg(self, sample):
         return sample
@@ -641,7 +642,7 @@ class PPOAlpacaStyleInstructionHandler(BaseDatasetHandler):
             messages = example["prompt"] + example["response"]
 
         for source_ids, _ in self.llama_factory_template.encode_multiturn(
-            tokenizer, messages, example["system"][0], example["tools"][0]
+            tokenizer, messages, example["system"][0], example["tools"][0], self.cutoff_len
         ):
             input_ids += source_ids
@@ -680,7 +681,8 @@ class R1AlpacaStyleInstructionHandler(BaseDatasetHandler):
         self.args.output_prefix = self.args.output_prefix + "_packed"
         self.ignored_label = -100
         self.is_multi_turn = True
-        self.llama_factory_template = get_model_template(args.prompt_type.strip(), args.prompt_type_path.strip())
+        self.llama_factory_template = get_model_template(args.prompt_type.strip(), args.prompt_type_path.strip(), args.enable_thinking)
+        self.cutoff_len = args.seq_length
 
     def _format_msg(self, sample):
         return sample
@@ -700,7 +702,7 @@ class R1AlpacaStyleInstructionHandler(BaseDatasetHandler):
             messages = example["prompt"] + example["response"]
 
         multiturn_input_ids = self.llama_factory_template.encode_multiturn(
-            tokenizer, messages, example["system"][0], example["tools"][0])
+            tokenizer, messages, example["system"][0], example["tools"][0], self.cutoff_len)
 
         turns = len(multiturn_input_ids)
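The `>` to `>=` boundary change above means a sample of exactly `seq_length` tokens is now dropped (or truncated to `seq_length - 1`); one plausible reading is that this leaves room for the one-token shift between inputs and labels in causal LM training. A standalone sketch of that interpretation, with illustrative names rather than repo API:

```python
# A sample of exactly seq_length tokens is now treated as too long,
# presumably so the shifted input/label pair still fits the budget.
seq_length = 8
sample = list(range(8))  # length == seq_length

if len(sample) >= seq_length:          # new check: '>=' rather than '>'
    sample = sample[:seq_length - 1]   # keep room for the one-token shift

input_ids, labels = sample[:-1], sample[1:]
assert len(input_ids) == len(labels) <= seq_length - 2
```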
diff --git a/mindspeed_llm/tasks/preprocess/decoder_packed_mtf_dataset.py b/mindspeed_llm/tasks/preprocess/decoder_packed_mtf_dataset.py
index 980b56939ed7dde9f9aa04a2e6b52919963d932d..49670bfb81cbf003907319c2a1be7021e6e35f7a 100644
--- a/mindspeed_llm/tasks/preprocess/decoder_packed_mtf_dataset.py
+++ b/mindspeed_llm/tasks/preprocess/decoder_packed_mtf_dataset.py
@@ -246,7 +246,7 @@ class DecoderPackedMTFDataset(torch.utils.data.Dataset):
         template = None
         # get model chat template
         if hasattr(self.args, "prompt_type") and self.args.prompt_type is not None:
-            template = get_model_template(self.args.prompt_type, self.args.prompt_type_path)
+            template = get_model_template(self.args.prompt_type, self.args.prompt_type_path, self.args.enable_thinking)
 
         prompt_begin_list, prompt_end_list = get_prompt_index(item["labels"], IGNORE_INDEX)
diff --git a/mindspeed_llm/tasks/preprocess/formatter.py b/mindspeed_llm/tasks/preprocess/formatter.py
index aaec643c140c7a0a8de04c8b11feef17dc0444a3..2fd6031d4bef52c19f368335cec804e51c09a149 100644
--- a/mindspeed_llm/tasks/preprocess/formatter.py
+++ b/mindspeed_llm/tasks/preprocess/formatter.py
@@ -38,6 +38,15 @@ TOOL_SYSTEM_PROMPT = (
 )
 
 
+QWEN_TOOL_PROMPT = (
+    "\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
+    "You are provided with function signatures within <tools></tools> XML tags:\n<tools>{tool_text}"
+    "\n</tools>\n\nFor each function call, return a json object with function name and arguments within "
+    """<tool_call></tool_call> XML tags:\n<tool_call>\n{{"name": <function-name>, """
+    """"arguments": <args-json-object>}}\n</tool_call>"""
+)
+
+
 def default_tool_formatter(tools: List[Dict[str, Any]]) -> str:
     tool_text = ""
     tool_names = []
@@ -201,3 +210,120 @@ class ToolFormatter(Formatter):
             return default_tool_extractor(content)
         else:
             raise NotImplementedError
+
+
+@dataclass
+class ToolUtils(ABC):
+    """Base class for tool utilities."""
+
+    @staticmethod
+    @abstractmethod
+    def tool_formatter(tools: list[dict[str, Any]]) -> str:
+        r"""Generate the system message describing all the available tools."""
+        ...
+
+    @staticmethod
+    @abstractmethod
+    def function_formatter(functions: list[Tuple[str, str]]) -> str:
+        r"""Generate the assistant message including all the tool calls."""
+        ...
+
+    @staticmethod
+    @abstractmethod
+    def tool_extractor(content: str) -> Union[str, list[Tuple[str, str]]]:
+        r"""Extract all the function calls from the assistant message.
+
+        It should be an inverse function of `function_formatter`.
+        """
+        ...
+
+
+class QwenToolUtils(ToolUtils):
+    r"""Qwen 2.5 tool using template."""
+
+    @staticmethod
+    def tool_formatter(tools: list[dict[str, Any]]) -> str:
+        tool_text = ""
+        for tool in tools:
+            wrapped_tool = tool if tool.get("type") == "function" else {"type": "function", "function": tool}
+            tool_text += "\n" + json.dumps(wrapped_tool, ensure_ascii=False)
+
+        return QWEN_TOOL_PROMPT.format(tool_text=tool_text)
+
+    @staticmethod
+    def function_formatter(functions: list[Tuple[str, str]]) -> str:
+        function_texts = [
+            json.dumps({"name": name, "arguments": json.loads(arguments)}, ensure_ascii=False)
+            for name, arguments in functions
+        ]
+        return "\n".join([f"<tool_call>\n{text}\n</tool_call>" for text in function_texts])
+
+    @staticmethod
+    def tool_extractor(content: str) -> Union[str, list[Tuple[str, str]]]:
+        regex = re.compile(r"<tool_call>(.+?)</tool_call>(?=\s*<tool_call>|\s*$)", re.DOTALL)
+        tool_match: list[str] = re.findall(regex, content)
+        if not tool_match:
+            return content
+
+        results = []
+        for tool in tool_match:
+            try:
+                tool = json.loads(tool.strip())
+            except json.JSONDecodeError:
+                return content
+
+            if "name" not in tool or "arguments" not in tool:
+                return content
+
+            results.append((tool["name"], json.dumps(tool["arguments"], ensure_ascii=False)))
+
+        return results
+
+
+@dataclass
+class Qwen3FunctionFormatter(StringFormatter):
+    def __post_init__(self):
+        super().__post_init__()
+        self.tool_utils = QwenToolUtils()
+
+    def apply(self, **kwargs) -> SLOTS:
+        content: str = kwargs.pop("content")
+        regex = re.compile(r"<think>(.*)</think>", re.DOTALL)
+        thought = re.search(regex, content)
+        if thought:
+            content = content.replace(thought.group(0), "")
+
+        functions: list[Tuple[str, str]] = []
+        try:
+            tool_calls = json.loads(content)
+            if not isinstance(tool_calls, list):  # parallel function call
+                tool_calls = [tool_calls]
+
+            for tool_call in tool_calls:
+                functions.append((tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False)))
+
+        except json.JSONDecodeError as e:
+            raise RuntimeError(f"Invalid JSON format in function message: {str([content])}.") from e  # flat string
+
+        function_str = self.tool_utils.function_formatter(functions)
+        if thought:
+            function_str = thought.group(0) + function_str
+
+        return super().apply(content=function_str)
+
+
+@dataclass
+class Qwen3ToolFormatter(Formatter):
+    def __post_init__(self):
+        self.tool_utils = QwenToolUtils()
+
+    def apply(self, **kwargs) -> SLOTS:
+        content = kwargs.pop("content")
+        try:
+            tools = json.loads(content)
+            return [self.tool_utils.tool_formatter(tools) if len(tools) != 0 else ""]
+        except json.JSONDecodeError as e:
+            raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}.") from e  # flat string
+
+    def extract(self, content: str) -> Union[str, Tuple[str, str]]:
+        return self.tool_utils.tool_extractor(content)
\ No newline at end of file
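A self-contained round-trip check of the `<tool_call>` extraction logic in `QwenToolUtils.tool_extractor`, with the regex and json handling reimplemented inline so it runs without the repo on the path; the tool names and arguments are made up:

```python
import json
import re

# Assistant message carrying two tool calls in Qwen's <tool_call> wrapping.
message = (
    '<tool_call>\n{"name": "get_weather", "arguments": {"city": "Beijing"}}\n</tool_call>\n'
    '<tool_call>\n{"name": "get_time", "arguments": {"tz": "UTC"}}\n</tool_call>'
)

# Same pattern as tool_extractor: lazy body match, lookahead requires the
# next <tool_call> or end of string so nothing leaks between calls.
regex = re.compile(r"<tool_call>(.+?)</tool_call>(?=\s*<tool_call>|\s*$)", re.DOTALL)
calls = [json.loads(m.strip()) for m in regex.findall(message)]
assert [(c["name"], c["arguments"]) for c in calls] == [
    ("get_weather", {"city": "Beijing"}),
    ("get_time", {"tz": "UTC"}),
]
```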
diff --git a/mindspeed_llm/tasks/preprocess/templates.py b/mindspeed_llm/tasks/preprocess/templates.py
index 57aa726bf3d9427839bc1a6a5a16eaf4f548abaf..e20d4a28d0f4a0221d3a05eb7d6b7b48aa42308f 100644
--- a/mindspeed_llm/tasks/preprocess/templates.py
+++ b/mindspeed_llm/tasks/preprocess/templates.py
@@ -14,15 +14,17 @@
 # limitations under the License.
 
 import os
+import sys
 import re
 import json
 import logging
+from copy import deepcopy
 from pathlib import Path
 from dataclasses import dataclass
 from enum import Enum, unique
 from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
 
-from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
+from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter, Qwen3FunctionFormatter, Qwen3ToolFormatter
 
 if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer
@@ -99,9 +101,11 @@ class Template:
     format_prefix: "Formatter"
     default_system: str
     stop_words: List[str]
+    thought_words: tuple[str, str]
     efficient_eos: bool
     replace_eos: bool
     force_system: bool
+    enable_thinking: Optional[bool]
 
     def encode_oneturn(
@@ -231,6 +235,22 @@ class Template:
 
         return encoded_pairs
 
+    def add_thought(self, content: str = "") -> str:
+        r"""Add empty thought to assistant message."""
+        return f"{self.thought_words[0]}\n\n{self.thought_words[1]}\n\n" + content
+
+
+    def remove_thought(self, content: str) -> str:
+        r"""Remove thought from assistant message."""
+        pattern = re.compile(f"{re.escape(self.thought_words[0])}(.*?){re.escape(self.thought_words[1])}", re.DOTALL)
+        return re.sub(pattern, "", content).lstrip("\n")
+
+
+    def get_thought_word_ids(self, tokenizer: "PreTrainedTokenizer") -> list[int]:
+        r"""Get the token ids of thought words."""
+        return tokenizer.encode(self.add_thought(), add_special_tokens=False)
+
 
 @dataclass
 class Llama2Template(Template):
@@ -276,6 +296,133 @@ class Llama2Template(Template):
 
         return self._make_pairs(encoded_messages, cutoff_len, reserved_label_len)
 
 
+@dataclass
+class ReasoningTemplate(Template):
+    r"""A template that adds thought to assistant message."""
+
+    def _encode(
+        self,
+        tokenizer: "PreTrainedTokenizer",
+        messages: List[Dict[str, str]],
+        system: str,
+        tools: str,
+    ) -> Sequence[list[int]]:
+        r"""
+        Encodes formatted inputs to pairs of token ids.
+        Turn 0: prefix + system + query    resp
+        Turn t: sep + query                resp
+        """
+        system = system or self.default_system
+        encoded_messages = []
+        for i, message in enumerate(messages):
+            elements = []
+            if i == 0:
+                elements += self.format_prefix.apply()
+                if system or tools:
+                    tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
+                    elements += self.format_system.apply(content=(system + tool_text))
+            elif i > 0 and i % 2 == 0:
+                elements += self.format_separator.apply()
+
+            if message["role"] == Role.USER.value:
+                elements += self.format_user.apply(content=message["content"], idx=str(i // 2))
+            elif message["role"] == Role.ASSISTANT.value:
+                elements += self.format_assistant.apply(content=message["content"])
+            elif message["role"] == Role.OBSERVATION.value:
+                elements += self.format_observation.apply(content=message["content"])
+            elif message["role"] == Role.FUNCTION.value:
+                elements += self.format_function.apply(content=message["content"])
+            else:
+                raise NotImplementedError("Unexpected role: {}".format(message["role"]))
+            encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))
+
+        return encoded_messages
+
+    def encode_oneturn(
+        self,
+        tokenizer: "PreTrainedTokenizer",
+        messages: list[dict[str, str]],
+        system: Optional[str] = None,
+        tools: Optional[str] = None,
+        cutoff_len: int = 1_000_000,
+        reserved_label_len: int = 1,
+    ) -> Tuple[list[int], list[int]]:
+        messages = deepcopy(messages)
+        for i in range(1, len(messages) - 2, 2):
+            messages[i]["content"] = self.remove_thought(messages[i]["content"])
+
+        if self.enable_thinking is False:  # remove all cot
+            messages[-1]["content"] = self.remove_thought(messages[-1]["content"])
+
+        prompt_ids, response_ids = super().encode_oneturn(tokenizer, messages, system, tools)
+        if (
+            self.thought_words[0] not in messages[-1]["content"]
+            and self.thought_words[1] not in messages[-1]["content"]
+        ):  # add empty cot
+            if not self.enable_thinking:  # do not compute loss
+                prompt_ids += self.get_thought_word_ids(tokenizer)
+            else:  # do compute loss
+                response_ids = self.get_thought_word_ids(tokenizer) + response_ids
+
+        return prompt_ids, response_ids
+
+    def encode_multiturn(
+        self,
+        tokenizer: "PreTrainedTokenizer",
+        messages: list[dict[str, str]],
+        system: Optional[str] = None,
+        tools: Optional[str] = None,
+        cutoff_len: int = 1_000_000,
+        reserved_label_len: int = 1,
+    ) -> Sequence[Tuple[List[int], List[int]]]:
+        messages = deepcopy(messages)
+
+        if self.enable_thinking is False:  # remove all cot
+            for i in range(1, len(messages), 2):
+                messages[i]["content"] = self.remove_thought(messages[i]["content"])
+
+        encoded_messages = self._encode(tokenizer, messages, system, tools)
+
+        for i in range(0, len(messages), 2):
+            if (
+                self.thought_words[0] not in messages[i + 1]["content"]
+                and self.thought_words[1] not in messages[i + 1]["content"]
+            ):  # add empty cot
+                if not self.enable_thinking:  # do not compute loss
+                    encoded_messages[i] += self.get_thought_word_ids(tokenizer)
+                else:  # do compute loss
+                    encoded_messages[i + 1] = self.get_thought_word_ids(tokenizer) + encoded_messages[i + 1]
+
+        return self._make_pairs(encoded_messages, cutoff_len, reserved_label_len)
+
+    def _make_pairs(
+        self,
+        encoded_messages: Sequence[List[int]],
+        cutoff_len: int,
+        reserved_label_len: int,
+    ) -> Sequence[Tuple[List[int], List[int]]]:
+        from .decoder_packed_mtf_dataset import _infer_seqlen
+
+        encoded_pairs = []
+        total_length = 0
+        cutoff_len = cutoff_len - reserved_label_len
+        for i in range(0, len(encoded_messages), 2):
+            if total_length >= cutoff_len:
+                break
+
+            max_source_len, max_target_len = _infer_seqlen(
+                source_len=len(encoded_messages[i]),
+                target_len=len(encoded_messages[i + 1]),
+                cutoff_len=(cutoff_len - total_length)
+            )
+            source_ids = encoded_messages[i][:max_source_len]
+            target_ids = encoded_messages[i + 1][:max_target_len]
+            total_length += len(source_ids) + len(target_ids)
+            encoded_pairs.append((source_ids, target_ids))
+
+        return encoded_pairs
+
+
 templates: Dict[str, Template] = {}
@@ -283,8 +430,8 @@ def get_templates() -> Dict[str, Template]:
     return templates
 
 
-def get_model_template(name, prompt_type_path):
-    name = register_custom_template(name, prompt_type_path)
+def get_model_template(name, prompt_type_path, enable_thinking):
+    name = register_custom_template(name, prompt_type_path, enable_thinking)
     if name is None:
         template = templates["empty"]  # placeholder
     else:
@@ -298,8 +445,9 @@ def fix_model_tokenizer(
     tokenizer: "PreTrainedTokenizer",
     name: Optional[str] = None,
     prompt_type_path: Optional[str] = None,
+    enable_thinking: Optional[bool] = False,
 ):
-    template = get_model_template(name, prompt_type_path)
+    template = get_model_template(name, prompt_type_path, enable_thinking)
 
     stop_words = template.stop_words
     if template.replace_eos:
@@ -337,9 +485,12 @@ def _register_template(
     format_prefix: Optional["Formatter"] = None,
     default_system: str = "",
     stop_words: List[str] = [],
+    thought_words: Optional[tuple[str, str]] = None,
     efficient_eos: bool = False,
     replace_eos: bool = False,
     force_system: bool = False,
+    enable_thinking: Optional[bool] = True,
+    template_class: type["Template"] = Template,
 ) -> None:
     r"""
     Registers a chat template.
@@ -368,7 +519,6 @@ def _register_template(
     ```
     """
     eos_slots = [] if efficient_eos else [{"eos_token"}]
-    template_class = Llama2Template if name.startswith("llama2") else Template
     default_user_formatter = StringFormatter(slots=["{{content}}"])
     default_assistant_formatter = StringFormatter(slots=["{{content}}"] + eos_slots)
     default_function_formatter = FunctionFormatter(slots=["Action: {{name}}\nAction Input: {{arguments}}"] + eos_slots)
@@ -386,9 +536,11 @@ def _register_template(
         format_prefix=format_prefix or default_prefix_formatter,
         default_system=default_system,
         stop_words=stop_words,
+        thought_words=thought_words or ("<think>", "</think>"),
         efficient_eos=efficient_eos,
         replace_eos=replace_eos,
         force_system=force_system,
+        enable_thinking=enable_thinking,
     )
 
 
@@ -472,7 +624,7 @@ def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer")
     return jinja_template
 
 
-def register_custom_template(name, json_file_path=TEMPLATES_DIR) -> str:
+def register_custom_template(name, json_file_path=TEMPLATES_DIR, enable_thinking=False) -> str:
     if name in templates:
         return name
 
@@ -501,18 +653,23 @@ def register_custom_template(name, json_file_path=TEMPLATES_DIR) -> str:
     efficient_eos = _format_custom_template(config.get("efficient_eos", False))
     replace_eos = _format_custom_template(config.get("replace_eos", False))
     force_system = _format_custom_template(config.get("force_system", False))
+    template_class = _format_custom_template(config.get("template_class", None))
     if isinstance(default_system, list):
         default_system = "".join(default_system) if all(isinstance(sentence, str) for sentence in default_system) else default_system
 
     format_user = StringFormatter(**format_user) if format_user else None
     format_assistant = StringFormatter(**format_assistant) if format_assistant else None
     format_system = StringFormatter(**format_system) if format_system else None
-    format_function = FunctionFormatter(**format_function) if format_function else None
     format_observation = StringFormatter(**format_observation) if format_observation else None
-    format_tools = ToolFormatter(**format_tools) if format_tools else None
     format_separator = EmptyFormatter(**format_separator) if format_separator else None
     format_prefix = EmptyFormatter(**format_prefix) if format_prefix else None
-
+    template_class = _get_template_class(template_class) if template_class else Template
+    if name == 'qwen3':
+        format_function = Qwen3FunctionFormatter(**format_function) if format_function else None
+        format_tools = Qwen3ToolFormatter(**format_tools) if format_tools else None
+    else:
+        format_function = FunctionFormatter(**format_function) if format_function else None
+        format_tools = ToolFormatter(**format_tools) if format_tools else None
 
     _register_template(
         name=name,
@@ -526,9 +683,12 @@ def register_custom_template(name, json_file_path=TEMPLATES_DIR) -> str:
         format_prefix=format_prefix,
         default_system=default_system,
         stop_words=stop_words,
+        thought_words=("<think>", "</think>"),
         efficient_eos=efficient_eos,
         replace_eos=replace_eos,
-        force_system=force_system
+        force_system=force_system,
+        enable_thinking=enable_thinking,
+        template_class=template_class
     )
     return name
 
@@ -538,4 +698,16 @@ def _format_custom_template(slots: Dict) -> Dict:
     if slots and isinstance(slots, Dict):
         for key, slot in slots.items():
             slots[key] = list(map(lambda slot: set(slot) if isinstance(slot, list) else slot, slot)) if slot else None
-    return slots
\ No newline at end of file
+    return slots
+
+
+def _get_template_class(template_name: str) -> type:
+    current_module = sys.modules.get(__name__)
+    if not current_module:
+        raise Exception("current module not found")
+    template_class = getattr(current_module, template_name, None)
+    if template_class is None:
+        template_class = Template
+    logger.info("template will use %s to format dataset", template_class.__name__)
+
+    return template_class
\ No newline at end of file
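To make the branch in `ReasoningTemplate` concrete: when an assistant message carries no `<think>...</think>` span, the empty CoT block is appended to the prompt side when thinking is disabled (so it is masked from the loss) or prepended to the response side when enabled (so the model learns to emit it). A standalone sketch with a toy whitespace "tokenizer"; the real code encodes via `get_thought_word_ids` and `tokenizer.encode`:

```python
THOUGHT_WORDS = ("<think>", "</think>")

def add_thought(content: str = "") -> str:
    # Mirrors Template.add_thought: empty thought block, then the content.
    return f"{THOUGHT_WORDS[0]}\n\n{THOUGHT_WORDS[1]}\n\n" + content

def encode(text: str) -> list[str]:
    return text.split()  # stand-in for tokenizer.encode(..., add_special_tokens=False)

prompt_ids = encode("<|im_start|>user Hi <|im_end|>")
response_ids = encode("Hello! <|im_end|>")
thought_ids = encode(add_thought())

enable_thinking = False
if enable_thinking:
    response_ids = thought_ids + response_ids  # empty CoT counted in the loss
else:
    prompt_ids = prompt_ids + thought_ids      # empty CoT masked out of the loss

print(prompt_ids, response_ids)
```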
diff --git a/mindspeed_llm/training/arguments.py b/mindspeed_llm/training/arguments.py
index faa582ea9b2ea40abe01e9023aea313698f51a30..cbfc304c8879881e682c16013fd11ad9318960e6 100644
--- a/mindspeed_llm/training/arguments.py
+++ b/mindspeed_llm/training/arguments.py
@@ -467,6 +467,9 @@ def _add_data_args(parser):
     group.add_argument('--padded-samples', action='store_true',
                        help='fill in the missing samples within an epoch, '
                             'starting at index 0, aligned with the LlamaFatory.')
+    group.add_argument("--enable-thinking", action='store_true',
+                       help="Whether or not to enable thinking mode for reasoning models.")
+
     return parser
 
 
@@ -838,7 +841,7 @@ def _add_training_args(parser):
     group.add_argument('--prompt-type', type=str, default=None,
                        choices=['default', 'empty', 'trl', 'chatglm2', 'chatglm3', 'chatglm3_system', 'glm4', 'chatml',
                                 'chatml_de', 'qwen', 'qwen_r1', "qwen_math_r1", 'llama3', 'llama2', 'mistral', 'mixtral', 'gemma', 'alpaca',
-                                'deepseek2', 'deepseek2-lite', 'minicpm3', 'cpm', 'baichuan2', 'deepseek3', 'intern2', 'hunyuan'],
+                                'deepseek2', 'deepseek2-lite', 'minicpm3', 'cpm', 'baichuan2', 'deepseek3', 'intern2', 'hunyuan', 'qwen3'],
                        help='Which template to use for constructing prompts in training/inference.'
                             'e.g., "qwen"')
     group.add_argument('--prompt-type-path', type=str, default=TEMPLATES_DIR,
                        help='Path to the json file of templates.')
diff --git a/mindspeed_llm/training/tokenizer/tokenizer.py b/mindspeed_llm/training/tokenizer/tokenizer.py
index 502af3fc114fcacc461820b7ae9295d1d1818ae5..b82b094270224f378691286472a308d25c2e46d3 100644
--- a/mindspeed_llm/training/tokenizer/tokenizer.py
+++ b/mindspeed_llm/training/tokenizer/tokenizer.py
@@ -66,7 +66,7 @@ def build_tokenizer(args):
         if ("PreTrainedTokenizerBase" not in str(tokenizer.tokenizer._pad.__func__)):
             tokenizer.tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer.tokenizer)
         tokenizer.tokenizer.padding_side = "right"
-        fix_model_tokenizer(tokenizer.tokenizer, args.prompt_type.strip(), args.prompt_type_path.strip())
+        fix_model_tokenizer(tokenizer.tokenizer, args.prompt_type.strip(), args.prompt_type_path.strip(), args.enable_thinking)
 
     if args.tokenizer_type == "GPTSentencePieceTokenizer":
         tokenizer.tokenizer.eos_token_id = tokenizer.tokenizer._eos_id
diff --git a/preprocess_data.py b/preprocess_data.py
index c80d1fd4beac83680ac3d01d562c46076652bd43..9613db334f8bec73cd286fd6cd3c94c603d23c9c 100644
--- a/preprocess_data.py
+++ b/preprocess_data.py
@@ -113,7 +113,7 @@ def add_data_args(parser):
     group.add_argument('--prompt-type', type=str, default=None,
                        choices=['default', 'empty', 'trl', 'chatglm2', 'chatglm3', 'chatglm3_system', 'glm4', 'chatml',
                                 'chatml_de', 'qwen', 'qwen_r1', "qwen_math_r1", 'llama3', 'llama2', 'mistral', 'mixtral', 'gemma', 'alpaca',
-                                'deepseek2', 'deepseek2-lite', 'cpm', 'baichuan2', 'deepseek3', 'intern2', 'hunyuan'],
+                                'deepseek2', 'deepseek2-lite', 'cpm', 'baichuan2', 'deepseek3', 'intern2', 'hunyuan', 'qwen3'],
                        help='Which template to use for constructing prompts in training.'
                             'e.g., "qwen"')
     group.add_argument('--prompt-type-path', type=str, default=TEMPLATES_DIR,
@@ -150,6 +150,8 @@ def add_data_args(parser):
                        help="Use a zigzag attention mask.")
     group.add_argument("--script-data-dir", type=str, default=None,
                        help="Python script dataset direction")
+    group.add_argument("--enable-thinking", action='store_true',
+                       help="Whether or not to enable thinking mode for reasoning models.")
     group.add_argument("--pad-to-multiple-of", type=int, default=1,
                        help="Pad each of the data to the multiple of...")
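The new `--enable-thinking` flag is registered with `action='store_true'`, so omitting it yields `False` and `ReasoningTemplate` then routes the empty thought block into the prompt (no loss on it). A quick standalone argparse sketch of that default behavior:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--enable-thinking", action="store_true",
                    help="Whether or not to enable thinking mode for reasoning models.")

# Dashes become underscores on the Namespace attribute.
assert parser.parse_args([]).enable_thinking is False
assert parser.parse_args(["--enable-thinking"]).enable_thinking is True
```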