diff --git a/configs/finetune/templates.json b/configs/finetune/templates.json
index 8310cc61ac4ac2aa9d916725a8b264ea46c8e65c..0b4a3425caaafe03698eb6b70467878fe63a76f6 100644
--- a/configs/finetune/templates.json
+++ b/configs/finetune/templates.json
@@ -445,7 +445,8 @@
"slots": [
"<>\n{{content}}\n<>\n\n"
]
- }
+ },
+ "template_class": "Llama2Template"
},
{
"name": "alpaca",
@@ -576,5 +577,42 @@
]
},
"default_system": "You are a helpful assistant."
+ },
+ {
+ "name": "qwen3",
+ "format_user": {
+ "slots": [
+ "<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"
+ ]
+ },
+ "format_assistant": {
+ "slots": [
+ "{{content}}<|im_end|>\n"
+ ]
+ },
+ "format_system": {
+ "slots": [
+ "<|im_start|>system\n{{content}}<|im_end|>\n"
+ ]
+ },
+ "format_function": {
+ "slots": [
+ "{{content}}<|im_end|>\n"
+ ],
+ "tool_format": "qwen"
+ },
+ "format_observation": {
+ "slots": [
+ "<|im_start|>user\n\n{{content}}\n<|im_end|>\n<|im_start|>assistant\n"
+ ]
+ },
+ "format_tools": {
+ "tool_format": "qwen"
+ },
+ "stop_words": [
+ "<|im_end|>"
+ ],
+ "replace_eos": true,
+ "template_class": "ReasoningTemplate"
}
]
\ No newline at end of file
diff --git a/mindspeed_llm/inference/text_generation/tokenization.py b/mindspeed_llm/inference/text_generation/tokenization.py
index bd70250afe254a58debef9c8fb7f3c2015053367..30b5785c51767ec500b9860af565b5cbe96de54b 100644
--- a/mindspeed_llm/inference/text_generation/tokenization.py
+++ b/mindspeed_llm/inference/text_generation/tokenization.py
@@ -132,7 +132,7 @@ def _tokenize_prompts_and_batch(tokenizer, prompts, tokens_to_generate, max_gene
# Tokenize all the prompts.
if hasattr(args, "prompt_type") and args.prompt_type is not None:
- template = get_model_template(args.prompt_type.strip(), args.prompt_type_path.strip())
+ template = get_model_template(args.prompt_type.strip(), args.prompt_type_path.strip(), args.enable_thinking)
if args.hf_chat_template:
if not hasattr(tokenizer, "apply_chat_template"):
diff --git a/mindspeed_llm/tasks/preprocess/data_handler.py b/mindspeed_llm/tasks/preprocess/data_handler.py
index 5e7166f21537cb0890f6dacc5ee931a83d6783dc..07b10676a210bc174de1bad5166022785e6b30f9 100644
--- a/mindspeed_llm/tasks/preprocess/data_handler.py
+++ b/mindspeed_llm/tasks/preprocess/data_handler.py
@@ -105,11 +105,11 @@ class BaseDatasetHandler(object):
batch = doc["input_ids"]
for idx, sample in enumerate(batch):
length = len(sample)
- if (length > self.args.seq_length) and (not self.args.neat_pack):
- logger.warning(f"Dropped lengthy example with length {length} > {self.args.seq_length}.")
+ if (length >= self.args.seq_length) and (not self.args.neat_pack):
+ logger.warning(f"Dropped lengthy example with length {length} >= {self.args.seq_length}.")
else:
- if length > self.args.seq_length:
- logger.warning(f"Sequence length {length} > {self.args.seq_length}.")
+ if length >= self.args.seq_length:
+ logger.warning(f"Sequence length {length} >= {self.args.seq_length}.")
sample = sample[:self.args.seq_length - 1]
length = len(sample)
lengths.append(length)
@@ -342,7 +342,8 @@ class LlamaFactoryInstructionHandler(BaseDatasetHandler):
self.args.output_prefix = self.args.output_prefix + "_packed"
self.ignored_label = -100
self.is_multi_turn = True
- self.llama_factory_template = get_model_template(args.prompt_type.strip(), args.prompt_type_path.strip())
+ self.llama_factory_template = get_model_template(args.prompt_type.strip(), args.prompt_type_path.strip(), args.enable_thinking)
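+ # cap multi-turn encoding at the training sequence length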
+ self.cutoff_len = args.seq_length
def _format_msg(self, sample):
return sample
@@ -360,9 +361,8 @@ class LlamaFactoryInstructionHandler(BaseDatasetHandler):
messages = [{'role': 'user', 'content': ''}, {'role': 'assistant', 'content': ''}]
else:
messages = example["prompt"] + example["response"]
-
for source_ids, target_ids in self.llama_factory_template.encode_multiturn(
- tokenizer, messages, example["system"][0], example["tools"][0]
+ tokenizer, messages, example["system"][0], example["tools"][0], self.cutoff_len
):
if self.train_on_inputs:
source_mask = source_ids
@@ -446,7 +446,7 @@ class HunyuanInstructionHandler(BaseDatasetHandler):
self.ignored_label = -100
self.is_multi_turn = True
- self.hunyuanlarge_template = get_model_template(args.prompt_type.strip(), args.prompt_type_path.strip())
+ self.hunyuanlarge_template = get_model_template(args.prompt_type.strip(), args.prompt_type_path.strip(), args.enable_thinking)
def _format_msg(self, sample):
return sample
@@ -511,7 +511,7 @@ class AlpacaStylePairwiseHandler(BaseDatasetHandler):
self.args.json_keys = ["chosen_input_ids", "chosen_labels", "rejected_input_ids", "rejected_labels"]
self.args.output_prefix = self.args.output_prefix + "_packed"
self.ignored_label = -100
- self.llama_factory_template = get_model_template(args.prompt_type.strip(), args.prompt_type_path.strip())
+ self.llama_factory_template = get_model_template(args.prompt_type.strip(), args.prompt_type_path.strip(), args.enable_thinking)
def _filter(self, sample):
chosen_messages = sample["prompt"] + [sample["response"][0]]
@@ -621,7 +621,8 @@ class PPOAlpacaStyleInstructionHandler(BaseDatasetHandler):
self.args.output_prefix = self.args.output_prefix + "_packed"
self.ignored_label = -100
self.is_multi_turn = True
- self.llama_factory_template = get_model_template(args.prompt_type.strip(), args.prompt_type_path.strip())
+ self.llama_factory_template = get_model_template(args.prompt_type.strip(), args.prompt_type_path.strip(), args.enable_thinking)
+ self.cutoff_len = args.seq_length
def _format_msg(self, sample):
return sample
@@ -641,7 +642,7 @@ class PPOAlpacaStyleInstructionHandler(BaseDatasetHandler):
messages = example["prompt"] + example["response"]
for source_ids, _ in self.llama_factory_template.encode_multiturn(
- tokenizer, messages, example["system"][0], example["tools"][0]
+ tokenizer, messages, example["system"][0], example["tools"][0], self.cutoff_len
):
input_ids += source_ids
@@ -680,7 +681,8 @@ class R1AlpacaStyleInstructionHandler(BaseDatasetHandler):
self.args.output_prefix = self.args.output_prefix + "_packed"
self.ignored_label = -100
self.is_multi_turn = True
- self.llama_factory_template = get_model_template(args.prompt_type.strip(), args.prompt_type_path.strip())
+ self.llama_factory_template = get_model_template(args.prompt_type.strip(), args.prompt_type_path.strip(), args.enable_thinking)
+ self.cutoff_len = args.seq_length
def _format_msg(self, sample):
return sample
@@ -700,7 +702,7 @@ class R1AlpacaStyleInstructionHandler(BaseDatasetHandler):
messages = example["prompt"] + example["response"]
multiturn_input_ids = self.llama_factory_template.encode_multiturn(
- tokenizer, messages, example["system"][0], example["tools"][0])
+ tokenizer, messages, example["system"][0], example["tools"][0], self.cutoff_len)
turns = len(multiturn_input_ids)
diff --git a/mindspeed_llm/tasks/preprocess/decoder_packed_mtf_dataset.py b/mindspeed_llm/tasks/preprocess/decoder_packed_mtf_dataset.py
index 980b56939ed7dde9f9aa04a2e6b52919963d932d..49670bfb81cbf003907319c2a1be7021e6e35f7a 100644
--- a/mindspeed_llm/tasks/preprocess/decoder_packed_mtf_dataset.py
+++ b/mindspeed_llm/tasks/preprocess/decoder_packed_mtf_dataset.py
@@ -246,7 +246,7 @@ class DecoderPackedMTFDataset(torch.utils.data.Dataset):
template = None
# get model chat template
if hasattr(self.args, "prompt_type") and self.args.prompt_type is not None:
- template = get_model_template(self.args.prompt_type, self.args.prompt_type_path)
+ template = get_model_template(self.args.prompt_type, self.args.prompt_type_path, self.args.enable_thinking)
prompt_begin_list, prompt_end_list = get_prompt_index(item["labels"], IGNORE_INDEX)
diff --git a/mindspeed_llm/tasks/preprocess/formatter.py b/mindspeed_llm/tasks/preprocess/formatter.py
index aaec643c140c7a0a8de04c8b11feef17dc0444a3..2fd6031d4bef52c19f368335cec804e51c09a149 100644
--- a/mindspeed_llm/tasks/preprocess/formatter.py
+++ b/mindspeed_llm/tasks/preprocess/formatter.py
@@ -38,6 +38,15 @@ TOOL_SYSTEM_PROMPT = (
)
+QWEN_TOOL_PROMPT = (
+ "\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
+ "You are provided with function signatures within <tools></tools> XML tags:\n<tools>{tool_text}"
+ "\n</tools>\n\nFor each function call, return a json object with function name and arguments within "
+ """<tool_call></tool_call> XML tags:\n<tool_call>\n{{"name": <function-name>, """
+ """"arguments": <args-json-object>}}\n</tool_call>"""
+)
+
+
def default_tool_formatter(tools: List[Dict[str, Any]]) -> str:
tool_text = ""
tool_names = []
@@ -201,3 +210,120 @@ class ToolFormatter(Formatter):
return default_tool_extractor(content)
else:
raise NotImplementedError
+
+
+@dataclass
+class ToolUtils(ABC):
+ """Base class for tool utilities."""
+
+ @staticmethod
+ @abstractmethod
+ def tool_formatter(tools: list[dict[str, Any]]) -> str:
+ r"""Generate the system message describing all the available tools."""
+ ...
+
+ @staticmethod
+ @abstractmethod
+ def function_formatter(functions: list[Tuple[str, str]]) -> str:
+ r"""Generate the assistant message including all the tool calls."""
+ ...
+
+ @staticmethod
+ @abstractmethod
+ def tool_extractor(content: str) -> Union[str, list[Tuple[str, str]]]:
+ r"""Extract all the function calls from the assistant message.
+
+ It should be an inverse function of `function_formatter`.
+ """
+ ...
+
+
+class QwenToolUtils(ToolUtils):
+ r"""Qwen 2.5 tool using template."""
+
+ @staticmethod
+ def tool_formatter(tools: list[dict[str, Any]]) -> str:
+ tool_text = ""
+ for tool in tools:
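+ # wrap bare schemas in an OpenAI-style {"type": "function", "function": {...}} envelope before serializing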
+ wrapped_tool = tool if tool.get("type") == "function" else {"type": "function", "function": tool}
+ tool_text += "\n" + json.dumps(wrapped_tool, ensure_ascii=False)
+
+ return QWEN_TOOL_PROMPT.format(tool_text=tool_text)
+
+ @staticmethod
+ def function_formatter(functions: list[Tuple[str, str]]) -> str:
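+ # serialize each (name, arguments) pair as a JSON object inside its own <tool_call> block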
+ function_texts = [
+ json.dumps({"name": name, "arguments": json.loads(arguments)}, ensure_ascii=False)
+ for name, arguments in functions
+ ]
+ return "\n".join([f"\n{text}\n" for text in function_texts])
+
+ @staticmethod
+ def tool_extractor(content: str) -> Union[str, list[Tuple[str, str]]]:
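+ # find every <tool_call>...</tool_call> block; fall back to the raw content if any block fails to parse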
+ regex = re.compile(r"(.+?)(?=\s*|\s*$)", re.DOTALL)
+ tool_match: list[str] = re.findall(regex, content)
+ if not tool_match:
+ return content
+
+ results = []
+ for tool in tool_match:
+ try:
+ tool = json.loads(tool.strip())
+ except json.JSONDecodeError:
+ return content
+
+ if "name" not in tool or "arguments" not in tool:
+ return content
+
+ results.append((tool["name"], json.dumps(tool["arguments"], ensure_ascii=False)))
+
+ return results
+
+
+@dataclass
+class Qwen3FunctionFormatter(StringFormatter):
+ def __post_init__(self):
+ super().__post_init__()
+ self.tool_utils = QwenToolUtils()
+
+ def apply(self, **kwargs) -> SLOTS:
+ content: str = kwargs.pop("content")
+ regex = re.compile(r"(.*)", re.DOTALL)
+ thought = re.search(regex, content)
+ if thought:
+ content = content.replace(thought.group(0), "")
+
+ functions: list[Tuple[str, str]] = []
+ try:
+ tool_calls = json.loads(content)
+ if not isinstance(tool_calls, list): # parallel function call
+ tool_calls = [tool_calls]
+
+ for tool_call in tool_calls:
+ functions.append((tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False)))
+
+ except json.JSONDecodeError as e:
+ raise RuntimeError(f"Invalid JSON format in function message: {str([content])}.") from e # flat string
+
+ function_str = self.tool_utils.function_formatter(functions)
+ if thought:
+ function_str = thought.group(0) + function_str
+
+ return super().apply(content=function_str)
+
+
+@dataclass
+class Qwen3ToolFormatter(Formatter):
+ def __post_init__(self):
+ self.tool_utils = QwenToolUtils()
+
+ def apply(self, **kwargs) -> SLOTS:
+ content = kwargs.pop("content")
+ try:
+ tools = json.loads(content)
+ return [self.tool_utils.tool_formatter(tools) if len(tools) != 0 else ""]
+ except json.JSONDecodeError as e:
+ raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}.") from e # flat string
+
+ def extract(self, content: str) -> Union[str, list[Tuple[str, str]]]:
+ return self.tool_utils.tool_extractor(content)
\ No newline at end of file
diff --git a/mindspeed_llm/tasks/preprocess/templates.py b/mindspeed_llm/tasks/preprocess/templates.py
index 57aa726bf3d9427839bc1a6a5a16eaf4f548abaf..e20d4a28d0f4a0221d3a05eb7d6b7b48aa42308f 100644
--- a/mindspeed_llm/tasks/preprocess/templates.py
+++ b/mindspeed_llm/tasks/preprocess/templates.py
@@ -14,15 +14,17 @@
# limitations under the License.
import os
+import sys
import re
import json
import logging
+from copy import deepcopy
from pathlib import Path
from dataclasses import dataclass
from enum import Enum, unique
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
-from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
+from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter, Qwen3FunctionFormatter, Qwen3ToolFormatter
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer
@@ -99,9 +101,11 @@ class Template:
format_prefix: "Formatter"
default_system: str
stop_words: List[str]
+ thought_words: tuple[str, str]
efficient_eos: bool
replace_eos: bool
force_system: bool
+ enable_thinking: Optional[bool]
def encode_oneturn(
@@ -231,6 +235,22 @@ class Template:
return encoded_pairs
+ def add_thought(self, content: str = "") -> str:
+ r"""Add empty thought to assistant message."""
+ return f"{self.thought_words[0]}\n\n{self.thought_words[1]}\n\n" + content
+
+
+ def remove_thought(self, content: str) -> str:
+ r"""Remove thought from assistant message."""
+ pattern = re.compile(f"{re.escape(self.thought_words[0])}(.*?){re.escape(self.thought_words[1])}", re.DOTALL)
+ return re.sub(pattern, "", content).lstrip("\n")
+
+
+ def get_thought_word_ids(self, tokenizer: "PreTrainedTokenizer") -> list[int]:
+ r"""Get the token ids of thought words."""
+ return tokenizer.encode(self.add_thought(), add_special_tokens=False)
+
+
@dataclass
class Llama2Template(Template):
def _encode(
@@ -276,6 +296,133 @@ class Llama2Template(Template):
return self._make_pairs(encoded_messages, cutoff_len, reserved_label_len)
+@dataclass
+class ReasoningTemplate(Template):
+ r"""A template that add thought to assistant message."""
+
+ def _encode(
+ self,
+ tokenizer: "PreTrainedTokenizer",
+ messages: List[Dict[str, str]],
+ system: str,
+ tools: str,
+ ) -> Sequence[list[int]]:
+ r"""
+ Encodes formatted inputs to pairs of token ids.
+ Turn 0: prefix + system + query resp
+ Turn t: sep + query resp
+ """
+ system = system or self.default_system
+ encoded_messages = []
+ for i, message in enumerate(messages):
+ elements = []
+ if i == 0:
+ elements += self.format_prefix.apply()
+ if system or tools:
+ tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
+ elements += self.format_system.apply(content=(system + tool_text))
+ elif i > 0 and i % 2 == 0:
+ elements += self.format_separator.apply()
+
+ if message["role"] == Role.USER.value:
+ elements += self.format_user.apply(content=message["content"], idx=str(i // 2))
+ elif message["role"] == Role.ASSISTANT.value:
+ elements += self.format_assistant.apply(content=message["content"])
+ elif message["role"] == Role.OBSERVATION.value:
+ elements += self.format_observation.apply(content=message["content"])
+ elif message["role"] == Role.FUNCTION.value:
+ elements += self.format_function.apply(content=message["content"])
+ else:
+ raise NotImplementedError("Unexpected role: {}".format(message["role"]))
+ encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))
+
+ return encoded_messages
+
+ def encode_oneturn(
+ self,
+ tokenizer: "PreTrainedTokenizer",
+ messages: list[dict[str, str]],
+ system: Optional[str] = None,
+ tools: Optional[str] = None,
+ cutoff_len: int = 1_000_000,
+ reserved_label_len: int = 1,
+ ) -> Tuple[list[int], list[int]]:
+ messages = deepcopy(messages)
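+ # drop CoT from all earlier assistant turns; only the final reply may keep its thought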
+ for i in range(1, len(messages) - 2, 2):
+ messages[i]["content"] = self.remove_thought(messages[i]["content"])
+
+ if self.enable_thinking is False: # remove all cot
+ messages[-1]["content"] = self.remove_thought(messages[-1]["content"])
+
+ prompt_ids, response_ids = super().encode_oneturn(tokenizer, messages, system, tools)
+ if (
+ self.thought_words[0] not in messages[-1]["content"]
+ and self.thought_words[1] not in messages[-1]["content"]
+ ): # add empty cot
+ if not self.enable_thinking: # do not compute loss
+ prompt_ids += self.get_thought_word_ids(tokenizer)
+ else: # do compute loss
+ response_ids = self.get_thought_word_ids(tokenizer) + response_ids
+
+ return prompt_ids, response_ids
+
+ def encode_multiturn(
+ self,
+ tokenizer: "PreTrainedTokenizer",
+ messages: list[dict[str, str]],
+ system: Optional[str] = None,
+ tools: Optional[str] = None,
+ cutoff_len: int = 1_000_000,
+ reserved_label_len: int = 1,
+ ) -> Sequence[Tuple[List[int], List[int]]]:
+ messages = deepcopy(messages)
+
+ if self.enable_thinking is False: # remove all cot
+ for i in range(1, len(messages), 2):
+ messages[i]["content"] = self.remove_thought(messages[i]["content"])
+
+ encoded_messages = self._encode(tokenizer, messages, system, tools)
+
+ for i in range(0, len(messages), 2):
+ if (
+ self.thought_words[0] not in messages[i + 1]["content"]
+ and self.thought_words[1] not in messages[i + 1]["content"]
+ ): # add empty cot
+ if not self.enable_thinking: # do not compute loss
+ encoded_messages[i] += self.get_thought_word_ids(tokenizer)
+ else: # do compute loss
+ encoded_messages[i + 1] = self.get_thought_word_ids(tokenizer) + encoded_messages[i + 1]
+
+ return self._make_pairs(encoded_messages, cutoff_len, reserved_label_len)
+
+ def _make_pairs(
+ self,
+ encoded_messages: Sequence[List[int]],
+ cutoff_len: int,
+ reserved_label_len: int,
+ ) -> Sequence[Tuple[List[int], List[int]]]:
+ from .decoder_packed_mtf_dataset import _infer_seqlen
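+ # local import, presumably to avoid a circular import with the dataset module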
+
+ encoded_pairs = []
+ total_length = 0
+ cutoff_len = cutoff_len - reserved_label_len
+ for i in range(0, len(encoded_messages), 2):
+ if total_length >= cutoff_len:
+ break
+
+ max_source_len, max_target_len = _infer_seqlen(
+ source_len=len(encoded_messages[i]),
+ target_len=len(encoded_messages[i + 1]),
+ cutoff_len=(cutoff_len - total_length)
+ )
+ source_ids = encoded_messages[i][:max_source_len]
+ target_ids = encoded_messages[i + 1][:max_target_len]
+ total_length += len(source_ids) + len(target_ids)
+ encoded_pairs.append((source_ids, target_ids))
+
+ return encoded_pairs
+
+
templates: Dict[str, Template] = {}
@@ -283,8 +430,8 @@ def get_templates() -> Dict[str, Template]:
return templates
-def get_model_template(name, prompt_type_path):
- name = register_custom_template(name, prompt_type_path)
+def get_model_template(name, prompt_type_path, enable_thinking):
+ name = register_custom_template(name, prompt_type_path, enable_thinking)
if name is None:
template = templates["empty"] # placeholder
else:
@@ -298,8 +445,9 @@ def fix_model_tokenizer(
tokenizer: "PreTrainedTokenizer",
name: Optional[str] = None,
prompt_type_path: Optional[str] = None,
+ enable_thinking: Optional[bool] = False,
):
- template = get_model_template(name, prompt_type_path)
+ template = get_model_template(name, prompt_type_path, enable_thinking)
stop_words = template.stop_words
if template.replace_eos:
@@ -337,9 +485,12 @@ def _register_template(
format_prefix: Optional["Formatter"] = None,
default_system: str = "",
stop_words: List[str] = [],
+ thought_words: Optional[tuple[str, str]] = None,
efficient_eos: bool = False,
replace_eos: bool = False,
force_system: bool = False,
+ enable_thinking: Optional[bool] = True,
+ template_class: type["Template"] = Template,
) -> None:
r"""
Registers a chat template.
@@ -368,7 +519,6 @@ def _register_template(
```
"""
eos_slots = [] if efficient_eos else [{"eos_token"}]
- template_class = Llama2Template if name.startswith("llama2") else Template
default_user_formatter = StringFormatter(slots=["{{content}}"])
default_assistant_formatter = StringFormatter(slots=["{{content}}"] + eos_slots)
default_function_formatter = FunctionFormatter(slots=["Action: {{name}}\nAction Input: {{arguments}}"] + eos_slots)
@@ -386,9 +536,11 @@ def _register_template(
format_prefix=format_prefix or default_prefix_formatter,
default_system=default_system,
stop_words=stop_words,
+ thought_words=thought_words or ("", ""),
efficient_eos=efficient_eos,
replace_eos=replace_eos,
force_system=force_system,
+ enable_thinking=enable_thinking,
)
@@ -472,7 +624,7 @@ def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer")
return jinja_template
-def register_custom_template(name, json_file_path=TEMPLATES_DIR) -> str:
+def register_custom_template(name, json_file_path=TEMPLATES_DIR, enable_thinking=False) -> str:
if name in templates:
return name
@@ -501,18 +653,23 @@ def register_custom_template(name, json_file_path=TEMPLATES_DIR) -> str:
efficient_eos = _format_custom_template(config.get("efficient_eos", False))
replace_eos = _format_custom_template(config.get("replace_eos", False))
force_system = _format_custom_template(config.get("force_system", False))
+ template_class = _format_custom_template(config.get("template_class", None))
if isinstance(default_system, list):
default_system = "".join(default_system) if all(isinstance(sentence, str) for sentence in default_system) else default_system
format_user = StringFormatter(**format_user) if format_user else None
format_assistant = StringFormatter(**format_assistant) if format_assistant else None
format_system = StringFormatter(**format_system) if format_system else None
- format_function = FunctionFormatter(**format_function) if format_function else None
format_observation = StringFormatter(**format_observation) if format_observation else None
- format_tools = ToolFormatter(**format_tools) if format_tools else None
format_separator = EmptyFormatter(**format_separator) if format_separator else None
format_prefix = EmptyFormatter(**format_prefix) if format_prefix else None
-
+ template_class = _get_template_class(template_class) if template_class else Template
+ if name == 'qwen3':
+ format_function = Qwen3FunctionFormatter(**format_function) if format_function else None
+ format_tools = Qwen3ToolFormatter(**format_tools) if format_tools else None
+ else:
+ format_function = FunctionFormatter(**format_function) if format_function else None
+ format_tools = ToolFormatter(**format_tools) if format_tools else None
_register_template(
name=name,
@@ -526,9 +683,12 @@ def register_custom_template(name, json_file_path=TEMPLATES_DIR) -> str:
format_prefix=format_prefix,
default_system=default_system,
stop_words=stop_words,
+ thought_words=("", ""),
efficient_eos=efficient_eos,
replace_eos=replace_eos,
- force_system=force_system
+ force_system=force_system,
+ enable_thinking=enable_thinking,
+ template_class=template_class
)
return name
@@ -538,4 +698,16 @@ def _format_custom_template(slots: Dict) -> Dict:
if slots and isinstance(slots, Dict):
for key, slot in slots.items():
slots[key] = list(map(lambda slot: set(slot) if isinstance(slot, list) else slot, slot)) if slot else None
- return slots
\ No newline at end of file
+ return slots
+
+
+def _get_template_class(template_name: str) -> type:
+ current_module = sys.modules.get(__name__)
+ if not current_module:
+ raise Exception("curent module not found")
+ template_class = getattr(current_module, template_name, None)
+ if template_class is None:
+ template_class = Template
+ logger.info("template will use %s to format dataset", template_class.__name__)
+
+ return template_class
\ No newline at end of file
diff --git a/mindspeed_llm/training/arguments.py b/mindspeed_llm/training/arguments.py
index faa582ea9b2ea40abe01e9023aea313698f51a30..cbfc304c8879881e682c16013fd11ad9318960e6 100644
--- a/mindspeed_llm/training/arguments.py
+++ b/mindspeed_llm/training/arguments.py
@@ -467,6 +467,9 @@ def _add_data_args(parser):
group.add_argument('--padded-samples', action='store_true',
help='fill in the missing samples within an epoch, '
'starting at index 0, aligned with LLaMA-Factory.')
+ group.add_argument("--enable-thinking", action='store_true',
+ help="Whether or not to enable thinking mode for reasoning models.")
+
return parser
@@ -838,7 +841,7 @@ def _add_training_args(parser):
group.add_argument('--prompt-type', type=str, default=None,
choices=['default', 'empty', 'trl', 'chatglm2', 'chatglm3', 'chatglm3_system', 'glm4', 'chatml',
'chatml_de', 'qwen', 'qwen_r1', "qwen_math_r1", 'llama3', 'llama2', 'mistral', 'mixtral', 'gemma', 'alpaca',
- 'deepseek2', 'deepseek2-lite', 'minicpm3', 'cpm', 'baichuan2', 'deepseek3', 'intern2', 'hunyuan'],
+ 'deepseek2', 'deepseek2-lite', 'minicpm3', 'cpm', 'baichuan2', 'deepseek3', 'intern2', 'hunyuan', 'qwen3'],
help='Which template to use for constructing prompts in training/inference.' 'e.g., "qwen"')
group.add_argument('--prompt-type-path', type=str, default=TEMPLATES_DIR,
help='Path to the json file of templates.')
diff --git a/mindspeed_llm/training/tokenizer/tokenizer.py b/mindspeed_llm/training/tokenizer/tokenizer.py
index 502af3fc114fcacc461820b7ae9295d1d1818ae5..b82b094270224f378691286472a308d25c2e46d3 100644
--- a/mindspeed_llm/training/tokenizer/tokenizer.py
+++ b/mindspeed_llm/training/tokenizer/tokenizer.py
@@ -66,7 +66,7 @@ def build_tokenizer(args):
if ("PreTrainedTokenizerBase" not in str(tokenizer.tokenizer._pad.__func__)):
tokenizer.tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer.tokenizer)
tokenizer.tokenizer.padding_side = "right"
- fix_model_tokenizer(tokenizer.tokenizer, args.prompt_type.strip(), args.prompt_type_path.strip())
+ fix_model_tokenizer(tokenizer.tokenizer, args.prompt_type.strip(), args.prompt_type_path.strip(), args.enable_thinking)
if args.tokenizer_type == "GPTSentencePieceTokenizer":
tokenizer.tokenizer.eos_token_id = tokenizer.tokenizer._eos_id
diff --git a/preprocess_data.py b/preprocess_data.py
index c80d1fd4beac83680ac3d01d562c46076652bd43..9613db334f8bec73cd286fd6cd3c94c603d23c9c 100644
--- a/preprocess_data.py
+++ b/preprocess_data.py
@@ -113,7 +113,7 @@ def add_data_args(parser):
group.add_argument('--prompt-type', type=str, default=None,
choices=['default', 'empty', 'trl', 'chatglm2', 'chatglm3', 'chatglm3_system', 'glm4', 'chatml',
'chatml_de', 'qwen', 'qwen_r1', "qwen_math_r1", 'llama3', 'llama2', 'mistral', 'mixtral', 'gemma', 'alpaca',
- 'deepseek2', 'deepseek2-lite', 'cpm', 'baichuan2', 'deepseek3', 'intern2', 'hunyuan'],
+ 'deepseek2', 'deepseek2-lite', 'cpm', 'baichuan2', 'deepseek3', 'intern2', 'hunyuan', 'qwen3'],
help='Which template to use for constructing prompts in training.'
'e.g., "qwen"')
group.add_argument('--prompt-type-path', type=str, default=TEMPLATES_DIR,
@@ -150,6 +150,8 @@ def add_data_args(parser):
help="Use a zigzag attention mask.")
group.add_argument("--script-data-dir", type=str, default=None,
help="Python script dataset direction")
+ group.add_argument("--enable-thinking", action='store_true',
+ help="Whether or not to enable thinking mode for reasoning models.")
group.add_argument("--pad-to-multiple-of", type=int, default=1,
help="Pad each of the data to the multiple of...")