diff --git a/mindformers/checkpoint/checkpoint.py b/mindformers/checkpoint/checkpoint.py index 4bdfb6aace0fde630559db6a155afa34c2abadc4..5fd99d09d39cc8b7607a143b76823b417723afa8 100644 --- a/mindformers/checkpoint/checkpoint.py +++ b/mindformers/checkpoint/checkpoint.py @@ -15,10 +15,12 @@ """load/save checkpoint apis.""" import os import json +import copy import tempfile from time import time from typing import Callable, Union, Dict, Optional, Tuple, Any, List from dataclasses import dataclass +import concurrent.futures import threading from multiprocessing import active_children @@ -29,9 +31,9 @@ from mindspore import Tensor, Parameter, load_param_into_net from mindspore.common import dtype as mstype from mindspore.nn import Cell from mindspore.nn.optim.optimizer import Optimizer -from mindspore.communication.management import get_rank, get_group_size from mindspore.communication import comm_func from mindspore import save_checkpoint as ms_save_checkpoint +from mindspore import load_checkpoint as ms_load_checkpoint from mindformers.tools.logger import logger from mindformers.checkpoint.reshard import ReshardHandler @@ -48,21 +50,22 @@ from mindformers.checkpoint.utils import ( get_checkpoint_name, get_checkpoint_tracker_filename, get_latest_iteration_from_tracker, + get_sharded_tensor_shard_id, get_common_filename, check_checkpoints_dir_max_num, get_metadata_filename, verify_ckpt_valid, + get_core_network, FileType ) from mindformers.checkpoint.fully_parallel import BalancedSaveStrategy, apply_balance_shard_strategy from mindformers.checkpoint.metadata import ( save_metadata, - load_metadata, - generate_default_metadata_from_checkpoint, get_total_params_file_mapping_info, + get_metadata_of_checkpoint, + params_key_mapping ) from mindformers.checkpoint.sharded_tensor import ( - get_strategy_info_from_sharded_tensor, ShardedTensor, get_sharded_tensor_from_cell, get_cur_sharded_tensor, @@ -70,6 +73,7 @@ from mindformers.checkpoint.sharded_tensor import ( get_param_redundancy_after_balanced ) from mindformers.checkpoint.broadcast import single_parameter_broadcast +from mindformers.checkpoint.reshard_loader import ReshardLoader @dataclass @@ -255,7 +259,7 @@ class AsyncSaveManager: """ if self.async_save is False: return True - if get_group_size() == 1: + if get_real_group_size() == 1: return not is_alive ten = Tensor([is_alive], dtype=mstype.int8) @@ -306,7 +310,7 @@ def save_checkpoint(iteration: int, network: Cell, optimizer: Optimizer = None, # Whether to use async save. use_async_save = async_save_manager is not None - if get_rank() == 0: + if get_real_rank() == 0: os.makedirs(cur_iter_checkpoint_dir, exist_ok=True) set_safe_mode_for_file_or_dir(checkpoints_root_path) set_safe_mode_for_file_or_dir(cur_iter_checkpoint_dir) @@ -330,11 +334,11 @@ def save_checkpoint(iteration: int, network: Cell, optimizer: Optimizer = None, if use_async_save: async_save_manager.prepare_before_save() - if get_rank() == 0: + if get_real_rank() == 0: async_save_manager.add_finalize_fn(iter_finalize_func) # Check if the number of saved folders has exceeded, and delete the oldest one. - if get_rank() == 0: + if get_real_rank() == 0: # NOTE: Currently only supports shared storage scenarios. 
check_checkpoints_dir_max_num(keep_max_num, checkpoints_root_path) # If the current iteration checkpoint directory be removed, raise an error to remind user @@ -362,7 +366,7 @@ def save_checkpoint(iteration: int, network: Cell, optimizer: Optimizer = None, remove_model_redundancy.save(iteration) else: model_ckpt_filename = get_checkpoint_name( - cur_iter_checkpoint_dir, user_prefix, get_rank(), get_group_size(), FileType.MODEL + cur_iter_checkpoint_dir, user_prefix, get_real_rank(), get_real_group_size(), FileType.MODEL ) ms_save_checkpoint( network, @@ -388,7 +392,7 @@ def save_checkpoint(iteration: int, network: Cell, optimizer: Optimizer = None, # Optimizer weight has redundancy. logger.warning("....... Start to save optimizer weight .......") optimizer_ckpt_filename = get_checkpoint_name( - cur_iter_checkpoint_dir, user_prefix, get_rank(), get_group_size(), FileType.OPTIMIZER + cur_iter_checkpoint_dir, user_prefix, get_real_rank(), get_real_group_size(), FileType.OPTIMIZER ) ms_save_checkpoint( optimizer, @@ -402,7 +406,7 @@ def save_checkpoint(iteration: int, network: Cell, optimizer: Optimizer = None, logger.warning("Optimizer weight will not be save!") # Save 'common.json'. - if get_rank() == 0: + if get_real_rank() == 0: logger.info("...... Start saving common info ......") start_save_common_info_time = time() @@ -421,7 +425,7 @@ def save_checkpoint(iteration: int, network: Cell, optimizer: Optimizer = None, if not use_async_save: barrier_world("All ranks for sync save checkpoint.") logger.info("Rank_0 execute finalize func.") - if get_rank() == 0: + if get_real_rank() == 0: iter_finalize_func() logger.info(f"Save checkpoint cost time: {time() - start_save_ckpt_time:.3f}s.") @@ -430,7 +434,7 @@ def save_metadata_json(sharded_tensor_metas, model_keys, user_prefix, metadata_f """Saving metadata.json used `get_strategy_metadata` API.""" if sharded_tensor_metas is not None: logger.info("...... Start saving metadata ......") - if get_rank() == 0: + if get_real_rank() == 0: param_file_mappings = get_total_params_file_mapping_info(sharded_tensor_metas, user_prefix, model_keys) save_metadata(sharded_tensor_metas, param_file_mappings, metadata_file_path) @@ -441,452 +445,560 @@ def save_metadata_json(sharded_tensor_metas, model_keys, user_prefix, metadata_f logger.info("No need to save metadata.json for single card.") -def load_safetensor( - checkpoint_path: str, - param_name: Optional[Union[str, List[str]]] = None, - index_tuple: Optional[Union[Tuple[Tuple[int, int], ...], List[Tuple[Tuple[int, int], ...]]]] = None, - dtype: Optional[Union[Any, List[Any]]] = mstype.float32 -) -> Dict[str, Parameter]: +def smart_slice(tensor, slice_ranges, load_from_multi_rank=False): """ - Loads tensors from a Safetensors file into MindSpore Parameters. - - This function reads a Safetensors file and converts specified tensors into MindSpore - Parameter objects with the specified data type. It can load either specific tensors - (with optional slicing) or all tensors from the file. + Slices a tensor based on specified slice ranges and determines if it's a full slice. Args: - checkpoint_path: Path to the Safetensors file to load - param_name: Optional name or list of names of specific tensors to load. - If None, loads all tensors from the file. - index_tuple: Optional slicing indices for tensors. Can be a single tuple of (start, end) - index pairs or a list of such tuples. Must match the dimension of the - corresponding tensor if provided. - dtype: Target data type(s) for the loaded tensors. 
Can be a single type or a list of types. - Defaults to mstype.float32. + tensor: The tensor to slice (can be Parameter, Tensor, or have .shape attribute) + slice_ranges: List of (start, end) tuples specifying slice ranges for each dimension + load_from_multi_rank: If True, forces slicing even for full slices (for multi-rank loading) Returns: - Dictionary mapping tensor names to MindSpore Parameter objects + Tuple[sliced_tensor, is_full_slice]: + - sliced_tensor: The original tensor if full slice and not load_from_multi_rank, + otherwise the sliced numpy array + - is_full_slice: True if the slice covers the entire tensor, False otherwise Raises: - ValueError: If the file doesn't exist, if index_tuple dimension doesn't match - tensor dimension, or if parameter lists have mismatched lengths - KeyError: If a specified parameter name doesn't exist in the file + ValueError: If slice dimension count doesn't match tensor dimension count """ - # Validate file existence - if not os.path.exists(checkpoint_path): - raise FileNotFoundError(f"Safetensors file not found at: {checkpoint_path}") - - # Warn about unused index_tuple when no parameter name is specified - if param_name is None and index_tuple is not None: - logger.warning("index_tuple is ignored when param_name is None (loading all parameters)") - - def _convert_to_list(param: Optional[Union[object, List[object]]]) -> Optional[List[object]]: - """Convert parameter to list if it's not already a list""" - if param is None: - return None - return [param] if not isinstance(param, list) else param - - def _align_list_length(param_list: List[object], target_length: int) -> List[object]: - """Align list length to match target length, supporting broadcasting of single-element lists""" - if len(param_list) == target_length: - return param_list - if len(param_list) == 1: - return param_list * target_length - raise ValueError(f"List length {len(param_list)} cannot be aligned to target length {target_length}") - - # Unify parameters to list format for consistent processing - param_name_list: Optional[List[str]] = _convert_to_list(param_name) - index_tuple_list: Optional[List[Tuple[Tuple[int, int], ...]]] = _convert_to_list(index_tuple) - dtype_list: Optional[List[Any]] = _convert_to_list(dtype) - - # Validate parameter list length consistency - if param_name_list is not None: - # Validate index list length - if index_tuple_list is not None and len(index_tuple_list) != len(param_name_list): - raise ValueError( - f"Length of index_tuple ({len(index_tuple_list)}) must match " - f"length of param_name ({len(param_name_list)})" - ) + # Get tensor shape - handle both Parameter and Tensor types + tensor_shape = tensor.shape - # Validate data type list length - if dtype_list is not None: - dtype_list = _align_list_length(dtype_list, len(param_name_list)) - - weights: Dict[str, Parameter] = {} - - # Load data from Safetensors file - with safe_open(checkpoint_path, framework="np", device="cpu") as f: - if param_name_list: - # Load specified parameters - for idx, param_name_ in enumerate(param_name_list): - # Get data type for current parameter - cur_dtype = dtype_list[idx] if dtype_list and dtype_list[idx] else mstype.float32 - - # Load tensor from file - try: - tensor_np = f.get_tensor(param_name_) - except KeyError as e: - raise KeyError(f"Parameter '{param_name_}' not found in Safetensors file") from e - - # Apply slicing if specified - if index_tuple_list is not None: - index_tuple = index_tuple_list[idx] - if len(index_tuple) != tensor_np.ndim: - raise 
ValueError( - f"Index tuple dimension ({len(index_tuple)}) does not match " - f"parameter '{param_name_}' dimension ({tensor_np.ndim})" - ) - # Create slice objects and apply - slices = tuple(slice(start, end) for start, end in index_tuple) - tensor_np = tensor_np[slices] - - # Convert to MindSpore Parameter - weights[param_name_] = Parameter( - ms.from_numpy(tensor_np).astype(cur_dtype), name=param_name_, requires_grad=False - ) - else: - # Load all parameters - cur_dtype = dtype if not isinstance(dtype, list) else dtype[0] - for key in f.keys(): - tensor_np = f.get_tensor(key) - weights[key] = Parameter( - ms.from_numpy(tensor_np).astype(cur_dtype), name=key, requires_grad=False - ) + if len(slice_ranges) != len(tensor_shape): + raise ValueError( + f"Slice dimension count ({len(slice_ranges)}) does not " + f"match tensor dimension count ({len(tensor_shape)})" + ) - return weights + # Check if this is a full slice + is_full_slice = all( + start == 0 and end == dim_size + for (start, end), dim_size in zip(slice_ranges, tensor_shape) + ) + # If full slice and not loading from multiple ranks, return original tensor + if is_full_slice and not load_from_multi_rank: + return tensor, is_full_slice -def load_tensor_by_offset( - all_offset: Dict[int, Tuple[Tuple[int, int], ...]], - param_name: str, - checkpoint_dir: str, - src_sharded_tensor_metas: Dict[str, List[ShardedTensor]], - param_file_mappings: Dict[str, List[Dict[str, Any]]], - key_mapping: Dict[str, str], -) -> Dict[int, Parameter]: - """ - Loads specific tensor slices from checkpoint files based on offset information. + # Perform the slice + slice_indices = tuple(slice(start, end) for start, end in slice_ranges) + if isinstance(tensor, Tensor) or isinstance(tensor, Parameter): + # MindSpore Tensor/Parameter + sliced_tensor = copy.deepcopy(tensor.asnumpy()[slice_indices]) + else: + # Numpy array or other array-like + sliced_tensor = copy.deepcopy(tensor[slice_indices]) - Retrieves the appropriate segments of a tensor from checkpoint files according to - the provided offset information. Handles storage rank mapping and potential resharding - to ensure the correct tensor slices are loaded for each rank. + return sliced_tensor, is_full_slice - Args: - all_offset: Dictionary mapping ranks to their respective tensor slice offsets - param_name: Name of the parameter/tensor to load - checkpoint_dir: Directory containing the checkpoint files - src_sharded_tensor_metas: Metadata for source sharded tensors - param_file_mappings: Mapping of parameters to their storage files - key_mapping: Mapping of `original key` in checkpoint to `param key` in network. - Returns: - Dictionary mapping ranks to their corresponding loaded Parameter objects +def build_tensors_reformulation_all_offsets(params_with_sharded_tensor): """ + Builds parameter information with all tensor offsets for sharded tensor resharding. 
- def _get_storage_info_of_sharded_tensor( - sharded_tensor: ShardedTensor, - param_file_mappings: Dict[str, List[Dict[str, Any]]] - ) -> List[Dict[str, Any]]: - """Retrieves storage information for a specific sharded tensor.""" - param_key = str((sharded_tensor.org_key, sharded_tensor.global_offset)) - return param_file_mappings[param_key] - - def _get_storage_rank_dict_of_param( - sharded_tensor_metas: Dict[str, List[ShardedTensor]], - param_file_mappings: Dict[str, List[Dict[str, Any]]], - param_name: str - ) -> Dict[int, Tuple[str, Any]]: - """Creates a dictionary mapping storage ranks to their file and dtype information.""" - storage_rank_dict: Dict[int, Tuple[str, Any]] = {} - if param_name not in sharded_tensor_metas: - param_name = key_mapping[param_name] + Args: + params_with_sharded_tensor (dict): + A dictionary mapping parameter names to tuples of sharded tensor objects. + Each tuple has the format (src_sharded_tensor_or_list, dst_sharded_tensor): + - src_sharded_tensor_or_list: Source sharded tensor object(s). Can be a single + ShardedTensor or a list of ShardedTensors (for concat scenarios) + - dst_sharded_tensor: Destination sharded tensor object with target layout info - for sharded_tensor in sharded_tensor_metas[param_name]: - storage_info_list = _get_storage_info_of_sharded_tensor(sharded_tensor, param_file_mappings) - for storage_info in storage_info_list: - storage_rank = storage_info["storage_rank"] - storage_rank_dict[storage_rank] = (storage_info["file_name"], sharded_tensor.dtype) - return storage_rank_dict - - # Get storage rank information for the parameter - storage_rank_dict = _get_storage_rank_dict_of_param( - src_sharded_tensor_metas, param_file_mappings, param_name) - - # Map storage ranks to source ranks and adjust offsets - storage_to_src_rank_mapping: Dict[int, int] = {} - for search_rank in list(all_offset.keys()): # Iterate over copy of keys to allow modification - if search_rank not in storage_rank_dict: - # Get first source sharded tensor for this parameter - src_sharded_tensor = next(iter(src_sharded_tensor_metas[param_name])) - - find_storage_rank = False - # Find matching storage rank using reshard handler - for storage_rank in storage_rank_dict: - reshard_handler = ReshardHandler( - param_name=param_name, - full_shape=src_sharded_tensor.global_shape, - from_layout=src_sharded_tensor.layout, - to_layout=src_sharded_tensor.layout, # No actual layout change - to_rank_id=storage_rank - ) - - # Get source rank from reshard handler - src_rank = next(iter(reshard_handler.infer_all_tensor_offset().keys())) - - if src_rank == search_rank: - # Update offset mapping and record rank correspondence - all_offset[storage_rank] = all_offset.pop(search_rank) - storage_to_src_rank_mapping[storage_rank] = src_rank - find_storage_rank = True - break - - if not find_storage_rank: - raise RuntimeError("Failed to find matching storage rank for the parameter") - - # Load tensor slices for each rank - from_tensor_map: Dict[int, Parameter] = {} - for search_rank, param_slice in all_offset.items(): - # Get file information and load the specific tensor slice - param_file_name, param_dtype = storage_rank_dict[search_rank] - param_file_path = os.path.join(checkpoint_dir, param_file_name) - - # Load the specific slice from the safetensor file - loaded_weights = load_safetensor(param_file_path, param_name, param_slice, param_dtype) - - # Use original source rank if mapping exists - mapped_rank = storage_to_src_rank_mapping.get(search_rank, search_rank) - 
from_tensor_map[mapped_rank] = loaded_weights[param_name] - - return from_tensor_map + Returns: + params_info (dict): A nested dictionary containing all offset information for each parameter. + Structure: { + param_name: { + "reshard_handler": ReshardHandler instance, + "all_offset": Dict mapping ranks to tensor slice ranges (start/end tuples) + } + } + """ + rank_id = get_real_rank() + params_info = {} + for param_name, (src_sharded_tensor, dst_sharded_tensor) in params_with_sharded_tensor.items(): + # Check if this is a concat scenario (src is a list) + if isinstance(src_sharded_tensor, list): + src_sharded_tensor = src_sharded_tensor[0] + + reshard_handler = ReshardHandler( + param_name, + dst_sharded_tensor.global_shape, + src_sharded_tensor.layout, + dst_sharded_tensor.layout, + rank_id + ) + all_offset = reshard_handler.infer_all_tensor_offset() + params_info[param_name] = {"all_offset": all_offset, "reshard_handler": reshard_handler} + return params_info -def categorize_params( - dst_sharded_tensor_metas: Dict[str, ShardedTensor], - src_sharded_tensor_metas: Dict[str, List[ShardedTensor]], - param_file_mappings: Dict[str, List[Dict[str, Any]]] -) -> Tuple[List[str], Dict[str, Dict[str, List[Any]]], Dict[str, Dict[str, List[Any]]], Dict[str, List[Any]]]: - """ - Categorizes parameters based on comparison of source and destination sharding strategies. +def build_tensors_reformulation_mappings( + checkpoint_dir, + params_info, + src_sharded_tensor_metas, + param_file_mappings, + key_mapping: Dict[str, str] = None +): + """Builds tensor resharding mappings and splits parameters into reshard-required and no-reshard groups. - Analyzes parameters from destination and source sharded tensor metadata to classify them into: - - Special parameters: Missing from source metadata - - No-shard parameters: Matching sharding strategies and offsets - - Online-shard parameters: Different sharding strategies requiring resharding + This function orchestrates the process of: + 1. Resolving parameter storage information (file paths, rank groups) from checkpoint metadata + 2. Mapping search ranks to storage ranks and adjusting tensor offset information + 3. Organizing checkpoint file load information by file name + 4. Loading checkpoint files, slicing tensors according to computed offsets, and classifying parameters + into those that need resharding (params_info_need_reshard) and those that don't (state_dict_no_reshard) Args: - dst_sharded_tensor_metas: Metadata for destination sharded tensors - src_sharded_tensor_metas: Metadata for source sharded tensors - param_file_mappings: Mapping of parameters to their storage files + checkpoint_dir: Root directory containing checkpoint files + params_info: Nested dictionary with resharding metadata for each parameter. + Structure: { + param_name: { + "reshard_handler": ReshardHandler instance, + "all_offset": Dict mapping ranks to tensor slice ranges (start/end tuples) + } + } + src_sharded_tensor_metas: Dictionary mapping parameter names to a list of source ShardedTensor metadata + (key: parameter name, value: list of ShardedTensor instances for source) + param_file_mappings: Mapping from shard IDs to storage information lists. + Each storage info dict has "file_name", "storage_rank", "rank_group" keys. + key_mapping: Optional dictionary for parameter name remapping (used when original + name not found in src_sharded_tensor_metas). Defaults to None. 
Returns: - Tuple containing three collections: - - special_params: List of parameter names missing from source - - no_shard_params: Dict mapping filenames to params that don't need resharding - - online_shard_params: Dict of params that need resharding with their details + Tuple containing two dictionaries: + 1. params_info_need_reshard: Nested dictionary containing resharding metadata for parameters needing resharding. + Structure: { + param_name: { + "reshard_handler": ReshardHandler instance (handles tensor resharding logic), + "tensor_map": Dict mapping source ranks to sliced numpy tensor fragments + } + } + 2. state_dict_no_reshard: Parameters that don't need resharding (full tensor slices), + keyed by parameter name with Parameter objects as values. Raises: - ValueError: If a parameter exists in source metadata but has an empty list of ShardedTensor instances - RuntimeError: If global shapes of source and destination tensors for a parameter do not match - RuntimeError: If sharding strategies match but no corresponding parameter offset is found + ValueError: If parameter name not found in src_sharded_tensor_metas and no key_mapping provided, + or if slice dimension count doesn't match tensor dimension count in _smart_slice. + RuntimeError: If no matching storage rank found for a search rank in offset mapping. """ - # Initialize categorization collections - not_mapping_params: List[str] = [] - need_concat_params: Dict[str, Dict[str, List[Any]]] = {} - no_shard_params: Dict[str, Dict[str, List[Any]]] = {} - no_shard_params_list: List[str] = [] - online_shard_params: Dict[str, List[Any]] = {} + def _get_param_storage_info(param_name, sharded_tensor_metas, param_file_mappings, key_mapping): + """Retrieves storage information (file name, rank group) for a parameter by rank.""" + param_storage_info = {} + lookup_name = param_name + if param_name not in sharded_tensor_metas: + if not key_mapping: + raise ValueError(f"param name '{param_name}' not found in src sharded tensor metas.") + lookup_name = key_mapping.get(param_name, param_name) + if lookup_name not in sharded_tensor_metas: + raise ValueError(f"param name '{param_name}' (mapped to '{lookup_name}') not found in src sharded tensor metas.") + sharded_tensors = sharded_tensor_metas[lookup_name] + for sharded_tensor in sharded_tensors: + shard_id = get_sharded_tensor_shard_id( + sharded_tensor.org_key, sharded_tensor.global_offset) + storage_info_list = param_file_mappings[shard_id] + for storage_info in storage_info_list: + file_name = storage_info["file_name"] + storage_rank = storage_info["storage_rank"] + rank_group = storage_info["rank_group"] + param_storage_info[storage_rank] = {"file_name": file_name, "rank_group": rank_group} + return param_storage_info + + def _collect_files_load_info(params_info, sharded_tensor_metas, param_file_mappings, key_mapping): + """Collects and organizes checkpoint file load information for sharded tensor resharding.""" + files_load_info: Dict[str, Dict] = {} + + for param_name, reshard_info in params_info.items(): + reshard_handler = reshard_info["reshard_handler"] + all_offset = reshard_info.get("all_offset", reshard_handler.infer_all_tensor_offset()) + + param_storage_info = _get_param_storage_info( + param_name, sharded_tensor_metas, param_file_mappings, key_mapping) + + # Map storage ranks to source ranks and adjust offsets + storage_to_search_rank_mapping: Dict[int, int] = {} + for search_rank in list(all_offset.keys()): + if search_rank not in param_storage_info: + find_storage_rank = False + 
for storage_rank, storage_info in param_storage_info.items(): + rank_group = storage_info["rank_group"] + if search_rank in rank_group: + # Update offset mapping and record rank correspondence + all_offset[storage_rank] = all_offset.pop(search_rank) + storage_to_search_rank_mapping[storage_rank] = search_rank + find_storage_rank = True + if not find_storage_rank: + raise RuntimeError("Failed to find matching storage rank.\n" + f"param: {param_name}\nall_offset: {all_offset}\n" + f"param_storage_info: {param_storage_info}") + else: + storage_to_search_rank_mapping[search_rank] = search_rank + + # Organize load info by storage file name + for storage_rank, param_slice in all_offset.items(): + param_file_name = param_storage_info[storage_rank]["file_name"] + search_rank = storage_to_search_rank_mapping[storage_rank] + + if param_file_name not in files_load_info: + files_load_info[param_file_name] = { + "param_name_list": [param_name], + "param_slice_list": [param_slice], + "search_rank_list": [search_rank] + } + else: + files_load_info[param_file_name]["param_name_list"].append(param_name) + files_load_info[param_file_name]["param_slice_list"].append(param_slice) + files_load_info[param_file_name]["search_rank_list"].append(search_rank) + + return files_load_info + + files_load_info = _collect_files_load_info( + params_info=params_info, + sharded_tensor_metas=src_sharded_tensor_metas, + param_file_mappings=param_file_mappings, + key_mapping=key_mapping + ) - rank_id = get_real_rank() + state_dict_no_reshard: Dict[str, Parameter] = {} + params_info_need_reshard: Dict[str, Dict] = {} - # Analyze each parameter in destination metadata - for param_name in dst_sharded_tensor_metas: - # Handle parameters missing from source metadata - if param_name not in src_sharded_tensor_metas: - not_mapping_params.append(param_name) - continue + for file_name, load_info in files_load_info.items(): + param_name_list = load_info["param_name_list"] + param_slice_list = load_info["param_slice_list"] + search_rank_list = load_info["search_rank_list"] - # Get destination tensor strategy information - dst_sharded_tensor = dst_sharded_tensor_metas[param_name] - dst_global_shape, dst_axis_fragmentations, dst_global_offset = get_strategy_info_from_sharded_tensor( - dst_sharded_tensor) + file_path = os.path.join(checkpoint_dir, file_name) + state_dict_from_file = ms_load_checkpoint( + file_path, + format='safetensors', + choice_func=lambda x, lst=param_name_list: x in lst + ) - src_sharded_tensor_list = src_sharded_tensor_metas[param_name] - if not src_sharded_tensor_list: - raise ValueError( - f"Source metadata for parameter '{param_name}' contains an empty list of ShardedTensor instances. " - "Valid source metadata requires at least one ShardedTensor entry." 
- ) + for param_name, param_slice, search_rank in zip(param_name_list, param_slice_list, search_rank_list): + reshard_handler = params_info[param_name]["reshard_handler"] + all_offset = params_info[param_name].get("all_offset", reshard_handler.infer_all_tensor_offset()) + load_from_multi_rank = len(all_offset) > 1 + + # Get parameter from state_dict and slice it + parameter = state_dict_from_file.pop(param_name) + sliced_tensor, is_full_slice = smart_slice(parameter, param_slice, load_from_multi_rank) + + if is_full_slice and not load_from_multi_rank: + # No reshard needed, directly add to state_dict (sliced_tensor is the original parameter) + mapped_name = key_mapping.get(param_name, param_name) if key_mapping else param_name + if not isinstance(sliced_tensor, Parameter): + sliced_tensor = Parameter(sliced_tensor, name=mapped_name, requires_grad=False) + state_dict_no_reshard[mapped_name] = sliced_tensor + continue + + if param_name not in params_info_need_reshard: + params_info_need_reshard[param_name] = { + "reshard_handler": reshard_handler, + "tensor_map": {} + } - # Get parameters info which need to concat - if param_name != src_sharded_tensor_list[0].key: - concat_infos = [] - reshard_infos = [] - for src_sharded_tensor in src_sharded_tensor_list: - param_key = str((src_sharded_tensor.org_key, src_sharded_tensor.global_offset)) - concat_infos.append( - { - 'sub_name': src_sharded_tensor.org_key, - 'file_name': param_file_mappings[param_key][0]["file_name"], - 'param_dtype': src_sharded_tensor.dtype, - } - ) + # Initialize or update the tensor map for the parameter with the rank-specific slice + params_info_need_reshard[param_name]["tensor_map"][search_rank] = sliced_tensor - if dst_axis_fragmentations != src_sharded_tensor_list[0].axis_fragmentations: - # `reshard_infos` contains `full_shape, from_layout, to_layout, to_rank_id` - reshard_infos = [dst_global_shape, None, dst_sharded_tensor.layout, rank_id] - need_concat_params[param_name] = (concat_infos, reshard_infos) - continue + return params_info_need_reshard, state_dict_no_reshard - param_key: Optional[str] = None - strategy_is_same = False - # Compare with each source tensor strategy - for src_sharded_tensor in src_sharded_tensor_list: - src_global_shape, src_axis_fragmentations, src_global_offset = \ - get_strategy_info_from_sharded_tensor(src_sharded_tensor) +def apply_parallel_load_strategy( + params_info: Dict[str, Dict], + state_dict: Dict[str, Parameter], + key_mapping: Dict = None, + num_workers: int = 1 +): + """Applies parallel loading strategy to reshard and merge tensor slices into full parameters. + + This function distributes resharding workload across multiple worker threads to optimize + the process of merging sliced tensor fragments (from multiple ranks) into complete parameters. + Key steps include: + 1. Collecting parameter metadata (tensor slices, reshard handlers, size info) + 2. Balancing workload across workers by tensor size (to avoid load imbalance) + 3. Processing parameter groups in parallel to reconstruct full tensors + 4. Updating the state dict with reconstructed parameters - # Validate global shape compatibility - if src_global_shape != dst_global_shape: - raise RuntimeError("Global shapes of source and destination tensors do not match") + Args: + params_info: Nested dictionary containing resharding metadata for parameters needing resharding. 
+ Structure: { + param_name: { + "reshard_handler": ReshardHandler instance (handles tensor resharding logic), + "tensor_map": Dict mapping source ranks to sliced numpy tensor fragments + } + } + state_dict: Target state dictionary to populate with reconstructed parameters. + Will be updated in-place with merged parameter tensors. + key_mapping: Optional dictionary for parameter name remapping (used when original + name not found in src_sharded_tensor_metas). Defaults to None. + num_workers: Maximum number of worker threads to use for parallel processing. + Actual workers used are min(num_workers, number of parameters) to avoid over-threading. + Defaults to 1 (single-threaded). + """ + total_param_info = [] + for param_name, param_info in params_info.items(): + reshard_handler = param_info["reshard_handler"] + tensor_map = param_info["tensor_map"] + + total_size = sum(t.size for t in tensor_map.values()) + total_param_info.append({ + "param_name": key_mapping.get(param_name, param_name) if key_mapping else param_name, + "from_tensor_map": tensor_map, + "reshard_handler": reshard_handler, + "size": total_size + }) - # Check if sharding strategies differ - if src_axis_fragmentations != dst_axis_fragmentations: - break # Strategies differ, no need to check further + num_params = len(total_param_info) + num_workers = min(num_workers, num_params) if num_params > 0 else 1 + logger.info(f"Process workers number: {num_workers}") - strategy_is_same = True + def balance_load(params: List[dict], num_groups: int) -> List[List[dict]]: + """Balances parameter load across worker groups to minimize load imbalance. - # Check if offsets match for direct mapping - if src_global_offset == dst_global_offset: - param_key = str((src_sharded_tensor.org_key, src_global_offset)) - break # Found matching parameter + Uses a greedy load balancing algorithm: + 1. Sorts parameters by total size (descending) to prioritize large parameters + 2. Greedily assigns each parameter to the worker group with the smallest current total size + This ensures even distribution of computational load across workers. - # Validate strategy consistency - if strategy_is_same and param_key is None: - raise RuntimeError("Matching strategy found but no corresponding parameter offset") + Args: + params: List of parameter metadata dicts (each with "size" key) + num_groups: Number of worker groups to split parameters into - src_sharded_tensor = src_sharded_tensor_list[0] + Returns: + List of worker groups, where each group is a list of parameter metadata dicts. + Groups are balanced by total tensor size to avoid uneven workload distribution. 
+ """ + # Sort parameters from largest to smallest to optimize load balancing + sorted_params = sorted(params, key=lambda x: x["size"], reverse=True) - # Categorize based on strategy comparison - if strategy_is_same: - # Parameters that don't need resharding - file_name = param_file_mappings[param_key][0]["file_name"] + # Initialize worker groups with empty params and zero total size + groups = [{"total_size": 0, "params": []} for _ in range(num_groups)] - # Initialize entry if new file - if file_name not in no_shard_params: - no_shard_params[file_name] = { - "param_name_list": [src_sharded_tensor.org_key], - "param_dtype_list": [src_sharded_tensor.dtype], - } - else: - # Add to existing file entry - no_shard_params[file_name]["param_name_list"].append(src_sharded_tensor.org_key) - no_shard_params[file_name]["param_dtype_list"].append(src_sharded_tensor.dtype) + # Assign each parameter to the least loaded group + for param in sorted_params: + min_group = min(groups, key=lambda g: g["total_size"]) + min_group["total_size"] += param["size"] + min_group["params"].append(param) - no_shard_params_list.append(src_sharded_tensor.org_key) - else: - # Parameters that need online resharding - online_shard_params[src_sharded_tensor.org_key] = [ - dst_global_shape, src_sharded_tensor.layout, dst_sharded_tensor.layout, rank_id - ] - # Parameters to be processed for categorized logging - logger.debug(f"Params not mapping: {not_mapping_params}") - logger.debug(f"Params needing transformation: {need_concat_params}") - logger.debug(f"Params no need reshard: {no_shard_params_list}") - logger.debug(f"Params need reshard: {list(online_shard_params.keys())}") + # Extract only the parameter lists (discard size tracking) + return [group["params"] for group in groups] - return not_mapping_params, need_concat_params, no_shard_params, online_shard_params + param_groups = balance_load(total_param_info, num_workers) + def process_param_group(param_group: List[dict]) -> List[Tuple[str, Parameter]]: + """Processes a group of parameters to reconstruct full tensors from slices. -def get_metadata_of_checkpoint(checkpoint_dir: str) -> tuple[dict, dict]: + For each parameter in the group: + 1. Uses ReshardHandler to merge tensor slices into a full tensor + 2. Wraps the merged tensor into a Parameter object + 3. 
Collects (param_name, Parameter) pairs for state dict update + + Args: + param_group: List of parameter metadata dicts (from balance_load output) + + Returns: + List of tuples containing (parameter name, reconstructed Parameter object) + """ + results = [] + for param in param_group: + target_weight = param["reshard_handler"].get_real_tensor(param["from_tensor_map"]) + target_weight = Parameter(target_weight, name=param["param_name"], requires_grad=False) + results.append((param["param_name"], target_weight)) + return results + + with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor: + futures = [executor.submit(process_param_group, group) for group in param_groups] + + for future in concurrent.futures.as_completed(futures): + for param_name, target_weight in future.result(): + state_dict[param_name] = target_weight + + +def build_concat_tensors_reformulation_mappings( + checkpoint_dir: str, + params_info: Dict[str, Dict], + need_concat_params: Dict[str, Tuple[List[ShardedTensor], ShardedTensor]], + param_file_mappings: Dict[str, List[Dict[str, Any]]], + network: Cell, + key_mapping: Dict[str, str] +) -> Tuple[Dict[str, Dict], Dict[str, Parameter]]: """ - Retrieves metadata from checkpoint directory, either from an existing metadata file - or by parsing checkpoint files. + Processes parameters that need concatenation (HuggingFace weights), executes loading, concat, and slice operations. - First checks for a pre-existing 'metadata.json' file in the checkpoint directory. If found, - it loads metadata from this file using load_metadata(). If not found, it generates metadata - by parsing the checkpoint files directly using load_metadata_from_checkpoint(). + Workflow: + 1. Organize file loading info: group parameters by file for efficient batch loading using ms_load_checkpoint + 2. Concat weights: for each parameter, build concat_dict and call network.convert_hf_weight to concatenate + 3. Slice with all_offset: use _smart_slice to get target slices and classify into reshard-needed or no-reshard groups Args: - checkpoint_dir: Path to the directory containing the checkpoint files. - network: The target core network (Cell) which has method `convert_name` to convert Hugging Face weight. + checkpoint_dir: Path to checkpoint directory + params_info: Output from build_tensors_reformulation_all_offsets, containing reshard handlers and offsets + need_concat_params: {param_name: (src_sharded_tensor_list, dst_sharded_tensor)} + param_file_mappings: Mapping from shard IDs to storage information lists + network: Network instance (must have convert_hf_weight method) + key_mapping: Parameter name mapping dictionary Returns: - A tuple containing two dictionaries: - - sharded_tensor_metas: Metadata about sharded tensors - - param_file_mappings: Mapping of parameters to their storage files + Tuple[params_info_need_reshard, state_dict_no_reshard] + - params_info_need_reshard: {param_name: {"reshard_handler": ..., "tensor_map": ...}} + - state_dict_no_reshard: {param_name: Parameter} """ - logger.info("..........Load Metadata of Checkpoint..........") + if need_concat_params and not hasattr(network, 'convert_hf_weight'): + raise NotImplementedError( + "The `convert_hf_weight` method of network is not implemented." 
+ ) - # Construct path to metadata file - metadata_path = os.path.join(checkpoint_dir, "metadata.json") + # ========== Step 1: Organize file loading info ========== + files_to_load: Dict[str, List[str]] = {} # {file_name: [param_org_key, ...]} - # Load from existing metadata file if available - if os.path.exists(metadata_path): - sharded_tensor_metas, param_file_mappings = load_metadata(metadata_path) - # Otherwise generate metadata from checkpoint files - else: - sharded_tensor_metas, param_file_mappings = generate_default_metadata_from_checkpoint(checkpoint_dir) + for param_name, (src_sharded_tensor_list, _) in need_concat_params.items(): + for src_sharded_tensor in src_sharded_tensor_list: + shard_id = get_sharded_tensor_shard_id( + src_sharded_tensor.org_key, + src_sharded_tensor.global_offset + ) + file_name = param_file_mappings[shard_id][0]["file_name"] + + if file_name not in files_to_load: + files_to_load[file_name] = [] + # Use org_key as the parameter name to load + if src_sharded_tensor.org_key not in files_to_load[file_name]: + files_to_load[file_name].append(src_sharded_tensor.org_key) + + # Use ms_load_checkpoint to lazy-load all needed parameters + state_dict_from_files: Dict[str, Parameter] = {} + for file_name, param_names in files_to_load.items(): + file_path = os.path.join(checkpoint_dir, file_name) + loaded = ms_load_checkpoint( + file_path, + format='safetensors', + choice_func=lambda x, names=param_names: x in names + ) + state_dict_from_files.update(loaded) + + # ========== Step 2: Concat weights for each parameter ========== + state_dict_concated: Dict[str, Tensor] = {} + + for param_name, (src_sharded_tensor_list, _) in need_concat_params.items(): + # Build concat_dict + concat_dict = {} + for src_sharded_tensor in src_sharded_tensor_list: + org_key = src_sharded_tensor.org_key + mapped_key = key_mapping.get(org_key, org_key) if key_mapping else org_key + if org_key in state_dict_from_files: + concat_dict[mapped_key] = state_dict_from_files[org_key] + + # Call network.convert_hf_weight to concat + concated_weight = network.convert_hf_weight(concat_dict) + state_dict_concated[param_name] = concated_weight[param_name] + + # ========== Step 3: Slice with all_offset ========== + params_info_need_reshard: Dict[str, Dict] = {} + state_dict_no_reshard: Dict[str, Parameter] = {} - # if the content or format of metadata_path and checkpoint_dir are invalid, the return value of - # sharded_tensor_metas and param_file_mappings may be empty or None, - # and it may cause an error in subsequent loading process. 
- return sharded_tensor_metas, param_file_mappings + for param_name, concated_tensor in state_dict_concated.items(): + reshard_handler = params_info[param_name]["reshard_handler"] + all_offset = params_info[param_name]["all_offset"] + load_from_multi_rank = len(all_offset) > 1 + # Get first rank's offset for slicing + first_rank = next(iter(all_offset)) + param_slice = all_offset[first_rank] + + sliced_tensor, is_full_slice = smart_slice(concated_tensor, param_slice, load_from_multi_rank) + + if is_full_slice and not load_from_multi_rank: + # No reshard needed, directly add to state_dict (sliced_tensor is the original tensor) + if not isinstance(sliced_tensor, Parameter): + sliced_tensor = Parameter(sliced_tensor, name=param_name, requires_grad=False) + state_dict_no_reshard[param_name] = sliced_tensor + else: + # Needs reshard, build tensor_map for all ranks + params_info_need_reshard[param_name] = { + "reshard_handler": reshard_handler, + "tensor_map": {0: sliced_tensor} + } -def params_key_mapping( - sharded_tensor_metas: Dict[str, List[ShardedTensor]], - network: Cell -) -> tuple[dict, dict]: + return params_info_need_reshard, state_dict_no_reshard + + +def categorize_params( + dst_sharded_tensor_metas: Dict[str, ShardedTensor], + src_sharded_tensor_metas: Dict[str, List[ShardedTensor]] +) -> Tuple[List[str], Dict[str, Tuple[List[ShardedTensor], ShardedTensor]], Dict[str, Tuple[ShardedTensor, ShardedTensor]]]: """ - Mapping Hugging Face checkpoint keys to MindSpore Transformers. + Categorizes parameters based on comparison of source and destination sharding strategies and key matching. + + Analyzes parameters from destination and source sharded tensor metadata to classify them into three categories: + - Special parameters (not_mapping_params): Missing from source metadata or with mismatched keys + - Concat-required parameters (need_concat_params): Params that need concatenation (and optional resharding) due to + mismatched keys or sharding strategies between source and destination + - Mapped parameters (params_to_load): Params with matching keys and consistent sharding attributes Args: - sharded_tensor_metas: Metadata about sharded tensors. - network: The network (Cell) which has method `convert_name` to convert Hugging Face weight. + dst_sharded_tensor_metas: Dictionary mapping parameter names to their destination ShardedTensor metadata + (key: parameter name, value: ShardedTensor instance for destination) + src_sharded_tensor_metas: Dictionary mapping parameter names to a list of source ShardedTensor metadata + (key: parameter name, value: list of ShardedTensor instances for source) + param_file_mappings: Mapping from shard IDs to storage information lists. + Each storage info dict has "file_name", "storage_rank", "rank_group" keys. Returns: - A dictionary after mapping about sharded tensor metas. - """ - # The key of `mapped_sharded_tensor_metas` is in the network, - # such as { qkv: [ShardedTensor, ShardedTensor, ShardedTensor], ... } - mapped_sharded_tensor_metas = {} - # The key of `key_mapping` is {'weight_key': 'mapping_key'}, - # and the `mapping_key` may not have the same name as the parameter in the network, - # it could be an intermediate form, - # such as { 'q_proj': 'linear_q', 'k_proj': 'linear_k', 'v_proj': 'linear_v', ... 
} - key_mapping = {} + Tuple containing three collections: + - not_mapping_params: List of parameter names missing from source metadata (special parameters) + - need_concat_params: Dictionary mapping parameter names to (src_sharded_tensor_list, dst_sharded_tensor) + where src_sharded_tensor_list is a list of source ShardedTensor instances that need to be concatenated + - params_to_load: Dictionary mapping original parameter keys to sharded tensor details: + Each entry contains (src_sharded_tensor, dst_sharded_tensor) - for param_name in sharded_tensor_metas: - param_name_converted = network.convert_name(param_name) - sharded_tensor_list = sharded_tensor_metas.get(param_name) + Raises: + ValueError: If a parameter exists in source metadata but has an empty list of ShardedTensor instances + """ + logger.info("..........Categorize Params..........") + # Initialize categorization collections + not_mapping_params: List[str] = [] + need_concat_params: Dict[str, Tuple[List[ShardedTensor], ShardedTensor]] = {} + params_to_load: Dict[str, Tuple[ShardedTensor, ShardedTensor]] = {} - for sharded_tensor in sharded_tensor_list: - sharded_tensor.key = param_name_converted - sharded_tensor.org_key = param_name + # Analyze each parameter in destination metadata + for param_name, dst_sharded_tensor in dst_sharded_tensor_metas.items(): + # Handle parameters missing from source metadata + if param_name not in src_sharded_tensor_metas: + not_mapping_params.append(param_name) + continue - key_mapping[param_name] = param_name_converted - param_name_converted_concat = network.convert_concat_name(param_name_converted) - mapped_sharded_tensor_metas.setdefault(param_name_converted_concat, []).extend(sharded_tensor_list) + src_sharded_tensor_list = src_sharded_tensor_metas[param_name] + if not src_sharded_tensor_list: + raise ValueError( + f"Source metadata for parameter '{param_name}' contains an empty list of ShardedTensor instances. " + "Valid source metadata requires at least one ShardedTensor entry." + ) - return mapped_sharded_tensor_metas, key_mapping + # Check if parameter needs concat: param_name differs from src tensor's key + if param_name != src_sharded_tensor_list[0].key: + # Needs concat: store complete ShardedTensor list and dst ShardedTensor + need_concat_params[param_name] = (src_sharded_tensor_list, dst_sharded_tensor) + else: + # Direct mapping + src_sharded_tensor = src_sharded_tensor_list[0] + params_to_load[src_sharded_tensor.org_key] = (src_sharded_tensor, dst_sharded_tensor) + # Parameters to be processed for categorized logging + logger.debug(f"Params can't load: {not_mapping_params}") + logger.debug(f"Params needing concat: {list(need_concat_params.keys())}") + logger.debug(f"Params to load: {list(params_to_load.keys())}") -# pylint: disable=W0212 -def get_core_network(network): - """Get the core network that has `convert_name` method.""" - if hasattr(network, '_backbone'): - return get_core_network(network._backbone) - if hasattr(network, 'network'): - return get_core_network(network.network) - return network + return not_mapping_params, need_concat_params, params_to_load -def load_checkpoint( +def load_checkpoint_v1( checkpoint: str, network: Cell, optimizer: Optional[Optimizer] = None, global_step: Optional[int] = None, - balanced_load: bool = False + balanced_load: bool = False, + load_worker_number: int = 1, ) -> None: """ Loads a checkpoint into a network and optional optimizer. 
@@ -900,6 +1012,7 @@ def load_checkpoint( network: The target network (Cell) to load parameters into (cannot be None) optimizer: Optional optimizer (Cell) to load optimizer states into global_step: Optional initial global step value if not found in checkpoint + load_worker_number: Max number of workers to process params Raises: ValueError: If the input `network` is None @@ -914,6 +1027,7 @@ def load_checkpoint( logger.info("..........Start Load Checkpoint..........") # Retrieve metadata from checkpoint files + logger.info("..........Get Metadata of Checkpoint..........") src_sharded_tensor_metas, param_file_mappings = get_metadata_of_checkpoint(checkpoint_dir) # Get the core network and check the convert method is illegal @@ -934,6 +1048,7 @@ def load_checkpoint( return param_name in list(network.parameters_dict().keys()) param_redundancy = None + logger.info("..........Get Metadata of Network..........") if balanced_load: rank_id_to_sharded_tensors = apply_balance_shard_strategy(network, filter_func) dst_sharded_tensor_metas = get_cur_sharded_tensor_after_balanced(rank_id_to_sharded_tensors) @@ -943,43 +1058,54 @@ def load_checkpoint( if get_real_group_size() > 1 else get_sharded_tensor_from_cell(network, optimizer) # Categorize parameters based on sharding strategies - _, need_concat_params, no_shard_params, online_shard_params = categorize_params( - dst_sharded_tensor_metas, src_sharded_tensor_metas, param_file_mappings + _, need_concat_params, params_to_load = categorize_params( + dst_sharded_tensor_metas, src_sharded_tensor_metas ) # Process Weight + logger.info("..........Building State Dict..........") state_dict: Dict[str, Parameter] = {} - # Concat parameters - concat_params(checkpoint_dir, network, key_mapping, need_concat_params, state_dict) - - # Load parameters that don't require resharding - for file_name, param_info in no_shard_params.items(): - param_name_list = param_info["param_name_list"] - param_dtype_list = param_info["param_dtype_list"] - no_reshard_state_dict = load_safetensor( - os.path.join(checkpoint_dir, file_name), param_name_list, dtype=param_dtype_list + # ==================== Path 1: Direct mapping parameters ==================== + start_time = time() + params_info_direct = build_tensors_reformulation_all_offsets(params_to_load) + build_all_offsets_time = time() - start_time + start_time = time() + params_info_need_reshard_1, state_dict_no_reshard_1 = build_tensors_reformulation_mappings( + checkpoint_dir, params_info_direct, src_sharded_tensor_metas, + param_file_mappings, key_mapping, + ) + build_all_tensor_map_time = time() - start_time + state_dict.update(state_dict_no_reshard_1) + + # ==================== Path 2: Concat parameters (HuggingFace weights) ==================== + if need_concat_params: + # Build reshard info for concat parameters (reuse build_tensors_reformulation_all_offsets) + params_info_concat = build_tensors_reformulation_all_offsets(need_concat_params) + # Load, concat, and build tensor_map + params_info_need_reshard_2, state_dict_no_reshard_2 = build_concat_tensors_reformulation_mappings( + checkpoint_dir, params_info_concat, need_concat_params, + param_file_mappings, network, key_mapping ) + state_dict.update(state_dict_no_reshard_2) + else: + params_info_need_reshard_2 = {} - state_dict.update({ - key_mapping[param_name]: value - for param_name, value in no_reshard_state_dict.items() - }) + # ==================== Unified parallel load strategy ==================== + all_params_need_reshard = {**params_info_need_reshard_1, 
**params_info_need_reshard_2} - # Load and reshard parameters that require online resharding - for param_name, (full_shape, from_layout, to_layout, to_rank_id) in online_shard_params.items(): - reshard_handler = ReshardHandler(param_name, full_shape, from_layout, to_layout, to_rank_id) - all_offset = reshard_handler.infer_all_tensor_offset() - from_tensor_map = load_tensor_by_offset( - all_offset, param_name, checkpoint_dir, src_sharded_tensor_metas, param_file_mappings, key_mapping - ) - target_weight = reshard_handler.get_real_tensor(from_tensor_map) - param_name = key_mapping[param_name] - state_dict[param_name] = Parameter(target_weight, name=param_name, requires_grad=False) + logger.info("Get and write params into state dict.") + start_time = time() + apply_parallel_load_strategy( + all_params_need_reshard, state_dict, key_mapping, num_workers=load_worker_number + ) + apply_parallel_load_strategy_time = time() - start_time + logger.info(f"Load {len(state_dict)} parameters") # Handle global_step for optimizer if needed if optimizer and "global_step" not in state_dict: # Initialize global_step with default or from common.json + logger.info(".....Get Global Step for Optimizer.....") if not global_step: common_file = os.path.join(checkpoint_dir, 'common.json') global_step = 0 if not os.path.exists(common_file) else CommonInfo.load_common(common_file).global_step @@ -987,7 +1113,9 @@ def load_checkpoint( state_dict["global_step"] = Parameter( Tensor([global_step], mstype.int32), name="global_step", requires_grad=False ) + # ms_save_checkpoint(state_dict, f"qwen3_0.6b_old_rank_{get_real_rank()}.safetensor", format="safetensors") + logger.info("..........Loading State Dict into Network..........") # Load state dictionary into network and optimizer load_parameters( network, @@ -997,44 +1125,10 @@ def load_checkpoint( param_redundancy=param_redundancy ) - -def concat_params(checkpoint_dir: str, network: Cell, key_mapping: dict, need_concat_params, state_dict: dict): - """Concat the need_concat_params dict in checkpoint.""" - if need_concat_params and not hasattr(network, 'convert_hf_weight'): - raise NotImplementedError("The `convert_hf_weight` method of network is not implemented.") - - for param_name, concat_info in need_concat_params.items(): - sharded_tensor_list, reshard_info = concat_info - org_weight_dict = {} - # Get all the params need to concat into `org_weight_dict`. - for sharded_tensor in sharded_tensor_list: - org_weight_dict.update( - load_safetensor( - checkpoint_path=os.path.join(checkpoint_dir, sharded_tensor['file_name']), - param_name=sharded_tensor['sub_name'], - dtype=sharded_tensor['param_dtype'] - ) - ) - # Mapping the weight key to MCore key into `concat_dict`. - concat_dict = { - key_mapping[k]: v - for k, v in org_weight_dict.items() - } - # Concat the weight. - concated_weight = network.convert_hf_weight(concat_dict) - - if reshard_info: - # Get the offset of the Tensor to reshard. - full_shape, from_layout, to_layout, to_rank_id = reshard_info - reshard_handler = ReshardHandler(param_name, full_shape, from_layout, to_layout, to_rank_id) - all_offset = reshard_handler.infer_all_tensor_offset() - # Get the slice to reshard the Tensor. - slices = tuple(slice(start, end) for start, end in all_offset[0]) - target_weight = concated_weight[param_name][slices] - # Update to `state_dict` to load into the network. 
-            state_dict[param_name] = Parameter(target_weight, name=param_name, requires_grad=False)
-        else:
-            state_dict[param_name] = Parameter(concated_weight[param_name], name=param_name, requires_grad=False)
+    logger.info("..........Loading Checkpoint Finished..........")
+    logger.info(f"build_all_offsets_time: {round(build_all_offsets_time, 6)}s")
+    logger.info(f"build_all_tensor_map_time: {round(build_all_tensor_map_time, 6)}s")
+    logger.info(f"apply_parallel_load_strategy_time: {round(apply_parallel_load_strategy_time, 6)}s")
 def check_the_param_for_load_ckpt(checkpoint: str, network: Cell):
@@ -1123,7 +1217,7 @@ def load_parameters(
     # Load parameters into network
     logger.debug(f"Network state_dict keys: {list(state_dict.keys())}")
-    param_not_load, ckpt_not_load = load_param_into_net(network, state_dict, strict_load=True)
+    param_not_load, ckpt_not_load = load_param_into_net(network, state_dict)
     if balanced_load:
         param_loaded = {param_name for param_name in state_dict if param_name not in ckpt_not_load}
         single_parameter_broadcast(network, param_redundancy, param_not_load, param_loaded)
@@ -1201,3 +1295,172 @@ def get_checkpoint_path(checkpoint: str) -> str:
     logger.info(f"Get checkpoint: {checkpoint}")
     return checkpoint
+
+
+def load_checkpoint(
+    checkpoint: str,
+    network: Cell,
+    optimizer: Optional[Optimizer] = None,
+    global_step: Optional[int] = None,
+    balanced_load: bool = False,
+    load_worker_number: int = 1,
+) -> None:
+    """
+    Load MindFormers-format weights (natively trained checkpoints).
+
+    Uses ReshardLoader to handle distributed resharding.
+    Natively trained weights need no template; the source parameter names are used directly.
+    """
+    # Validate mandatory network parameter
+    check_the_param_for_load_ckpt(checkpoint, network)
+
+    # Determine checkpoint directory path
+    checkpoint_dir = get_checkpoint_path(checkpoint)
+
+    logger.info("..........Start Load Checkpoint..........")
+
+    # Retrieve metadata from checkpoint files
+    logger.info("..........Get Metadata of Checkpoint..........")
+    src_sharded_tensor_metas, param_file_mappings = get_metadata_of_checkpoint(checkpoint_dir)
+
+    # Define parameter filtering function
+    def filter_func(param_name: str) -> bool:
+        if optimizer:
+            return "accu_grads" not in param_name
+        return param_name in list(network.parameters_dict().keys())
+
+    param_redundancy = None
+    logger.info("..........Get Metadata of Network..........")
+    if balanced_load:
+        rank_id_to_sharded_tensors = apply_balance_shard_strategy(network, filter_func)
+        dst_sharded_tensor_metas = get_cur_sharded_tensor_after_balanced(rank_id_to_sharded_tensors)
+        param_redundancy = get_param_redundancy_after_balanced(rank_id_to_sharded_tensors)
+    else:
+        dst_sharded_tensor_metas = get_cur_sharded_tensor(network, filter_func) \
+            if get_real_group_size() > 1 else get_sharded_tensor_from_cell(network, optimizer)
+
+    # Load with ReshardLoader.
+    # Natively trained weights need no template; parameter names already match dst_sharded_tensor_metas.
+    reshard_loader = ReshardLoader(
+        checkpoint_dir=checkpoint_dir,
+        dst_sharded_tensor_metas=dst_sharded_tensor_metas,  # {param_name: ShardedTensor}
+        src_sharded_tensor_metas=src_sharded_tensor_metas,  # {param_name: [ShardedTensor, ...]}
+        param_file_mappings=param_file_mappings,  # {(param_name, global_offset): [...]}
+        num_workers=load_worker_number,
+        template=None  # natively trained weights need no template
+    )
+
+    # Stage 1: reshard, producing {param_name: Parameter}.
+    state_dict = reshard_loader.load()
+
+    # Handle global_step for optimizer if needed
+    if optimizer and "global_step" not in state_dict:
+        # Initialize global_step with default or from common.json
+        logger.info(".....Get Global Step for Optimizer.....")
+        if not global_step:
+            common_file = os.path.join(checkpoint_dir, 'common.json')
+            global_step = 0 if not os.path.exists(common_file) else CommonInfo.load_common(common_file).global_step
+
+        state_dict["global_step"] = Parameter(
+            Tensor([global_step], mstype.int32), name="global_step", requires_grad=False
+        )
+
+    logger.info("..........Loading State Dict into Network..........")
+    # Load state dictionary into network and optimizer
+    load_parameters(
+        network,
+        state_dict,
+        optimizer,
+        balanced_load=balanced_load,
+        param_redundancy=param_redundancy
+    )
+
+    logger.info("..........Loading Checkpoint Finished..........")
+
+
+def load_hf_checkpoint(
+    pretrained_model_dir: str,
+    network: Cell,
+    balanced_load: bool = False,
+    load_worker_number: int = 1,
+) -> None:
+    """
+    Load HuggingFace-format weights.
+
+    Two-stage processing:
+    - Stage 1, Reshard: ReshardLoader handles the distributed slicing
+    - Stage 2, Convert: the Template handles QKV concatenation, stacking and other conversions
+
+    Design notes:
+    - ReshardLoader uses template.get_mf_name() to map HF parameter names to MF names
+    - Template.convert() performs the final weight conversion (e.g. QKV concatenation)
+    """
+    logger.info("..........Start Load Checkpoint..........")
+
+    # Step 1: Get the template (directly from the network instance).
+    core_network = get_core_network(network)
+    template = getattr(core_network, 'template', None)
+    if template is None:
+        raise ValueError(
+            f"Network '{type(core_network).__name__}' does not have a template. "
+            f"Please use the @register_template decorator on the model's __init__ method."
+        )
+
+    # Step 2: Get the destination metadata.
+    def filter_func(param_name: str) -> bool:
+        return param_name in list(network.parameters_dict().keys())
+
+    param_redundancy = None
+    logger.info("..........Get Metadata of Network..........")
+    if balanced_load:
+        rank_id_to_sharded_tensors = apply_balance_shard_strategy(network, filter_func)
+        dst_sharded_tensor_metas = get_cur_sharded_tensor_after_balanced(rank_id_to_sharded_tensors)
+        param_redundancy = get_param_redundancy_after_balanced(rank_id_to_sharded_tensors)
+    else:
+        dst_sharded_tensor_metas = get_cur_sharded_tensor(network, filter_func) \
+            if get_real_group_size() > 1 else get_sharded_tensor_from_cell(network)
+
+    # Step 3: Get the source metadata (built from the HF weights).
+    src_sharded_tensor_metas, param_file_mappings = get_metadata_of_checkpoint(pretrained_model_dir)
+
+    # Step 4: Stage 1 - Reshard.
+    # The template is passed in for parameter-name mapping.
+    # Internally, ReshardLoader will:
+    # 1. Pre-build bidirectional mappings:
+    #    - src_to_dst_mapping: {src_name: dst_name}, e.g. {q_proj: qkv, k_proj: qkv, v_proj: qkv}
+    #    - dst_to_src_mapping: {dst_name: [src_names]}, e.g. {qkv: [q_proj, k_proj, v_proj]}
+    # 2. Iterate over dst_metas and, for each destination parameter, collect all associated
+    #    source parameters via dst_to_src_mapping
+    # 3. Compute the offset and slice for each source parameter (e.g. q, k, v) separately
对每个源参数(如 q、k、v)分别计算 offset 并进行切片 + reshard_loader = ReshardLoader( + checkpoint_dir=pretrained_model_dir, + dst_sharded_tensor_metas=dst_sharded_tensor_metas, # {mf_param_name: ShardedTensor} + src_sharded_tensor_metas=src_sharded_tensor_metas, # {hf_param_name: [ShardedTensor]} + param_file_mappings=param_file_mappings, # {(hf_param_name, global_offset): [...]} + num_workers=load_worker_number, + template=template # 直接使用 network.template + ) + + # 获取 Reshard 输出 {hf_param_name: tensor} + # 返回的 key 是 HF 原始参数名(如 q_proj.weight, k_proj.weight, v_proj.weight) + # 每个源参数都已经完成了切片(只加载当前卡需要的部分) + reshard_output = reshard_loader.load() + + # Step 6: 第二层 - Convert + # 使用 template.get_mf_state_dict() 将 {hf_param_name: weight} 转换为 {mf_param_name: Parameter} + # 内部遍历参数,调用 add_hf_weight() 进行转换: + # - 单源权重(如 embed_tokens):直接重命名 + # - 多源权重(如 QKV):暂存 q, k, v,等齐后执行拼接 + # 因为 Reshard 阶段已完成切片,转换时无需再进行切分 + state_dict = template.get_mf_state_dict(reshard_output) + + # Step 7: 加载到网络 + logger.info("..........Loading State Dict into Network..........") + # Load state dictionary into network and optimizer + load_parameters( + network, + state_dict, + balanced_load=balanced_load, + param_redundancy=param_redundancy + ) + + logger.info("..........Loading Checkpoint Finished..........") diff --git a/mindformers/checkpoint/converter/__init__.py b/mindformers/checkpoint/converter/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mindformers/checkpoint/converter/convert_op.py b/mindformers/checkpoint/converter/convert_op.py new file mode 100644 index 0000000000000000000000000000000000000000..31925ea05314a43f7fc555f7bcafed6d03d41bf9 --- /dev/null +++ b/mindformers/checkpoint/converter/convert_op.py @@ -0,0 +1,656 @@ +import re +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Dict, List, Tuple, Optional, Union, Any +import numpy as np + +from mindformers.parallel_core.transformer_config import TransformerConfig +from mindspore import Tensor, Parameter + + +@dataclass +class ConverOp(ABC): + """ + 权重转换操作基类。 + + 参考 ROLL 库设计,支持双向转换: + - HF → MF:加载 HuggingFace 权重时使用 + - MF → HF:导出为 HuggingFace 格式时使用 + + Attributes: + hf_names: HuggingFace 权重名列表 + mf_names: MindFormers 权重名列表 + mf_config: MindFormers 模型配置(用于获取 num_heads 等参数) + """ + hf_names: Union[str, List[str]] + mf_names: Union[str, List[str]] + mf_config: TransformerConfig = None + + def __post_init__(self): + if isinstance(self.hf_names, str): + self.hf_names = [self.hf_names] + if isinstance(self.mf_names, str): + self.mf_names = [self.mf_names] + + def __call__( + self, + name_to_weight: Dict[str, np.ndarray], + mf_to_hf: bool = False + ) -> Optional[Dict[str, np.ndarray]]: + """ + 执行转换。 + + Args: + name_to_weight: 输入权重字典 + mf_to_hf: 转换方向 + - False: HF → MF(默认) + - True: MF → HF + + Returns: + 转换后的权重字典,权重不齐时返回 None + """ + required_names = self.mf_names if mf_to_hf else self.hf_names + if len(required_names) > len(name_to_weight): + return None + + if mf_to_hf: + return self.mf_to_hf(name_to_weight) + else: + return self.hf_to_mf(name_to_weight) + + @staticmethod + def _name_to_pattern(name: str): + return name.replace(".", "\.").replace("{}", "(.*)") + + def is_required_name(self, name, mf_name: bool): + required_names = self.mf_names if mf_name else self.hf_names + if name in required_names: + return True + for pattern in required_names: + re_pattern = self._name_to_pattern(pattern) + if re.match(re_pattern, name): + return True + return False + 
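+    # --- Editor's illustrative sketch (not part of the original patch; parameter
+    # names below are hypothetical, reusing the q_proj/k_proj/v_proj -> linear_qkv
+    # mapping quoted elsewhere in this change) ---
+    # The __call__ protocol: it returns None while `name_to_weight` does not yet
+    # contain every required name (the caller, WeightTemplate.add_hf_weight, keeps
+    # the partial inputs in `pending_weights`), and only converts once the set is
+    # complete.
+    #   rename = RenameConverOp(hf_names="model.embed_tokens.weight",
+    #                           mf_names="embedding.word_embeddings.weight")
+    #   rename({"model.embed_tokens.weight": w})
+    #   # -> {"embedding.word_embeddings.weight": w}
+    #   qkv = QKVConverOp(hf_names=["model.layers.{}.self_attn.q_proj.weight",
+    #                               "model.layers.{}.self_attn.k_proj.weight",
+    #                               "model.layers.{}.self_attn.v_proj.weight"],
+    #                     mf_names=["decoder.layers.{}.self_attention.linear_qkv.weight"])
+    #   qkv({"model.layers.0.self_attn.q_proj.weight": q_np})
+    #   # -> None: k_proj/v_proj of this layer are not present yet
+    #   # (qkv.set_model_config(config) is assumed to have been called before a
+    #   # complete q/k/v dict is converted, so that head counts are populated)
+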
+ def _to_names_and_weights( + self, + from_names: List[str], + to_names: List[str], + name_to_weight: Dict[str, np.ndarray] + ) -> Tuple[List[str], List[np.ndarray]]: + """从输入字典中提取权重,并计算目标名称""" + weights = [] + match = None + for from_name in from_names: + if from_name in name_to_weight: + weight = name_to_weight[from_name] + elif "{}" in from_name: + re_pattern = self._name_to_pattern(from_name) + for name in name_to_weight: + match = re.findall(re_pattern, name) + if match: + weight = name_to_weight[name] + break + if not match: + raise ValueError(f"Cannot find match {from_name} in {name_to_weight.keys()}") + else: + raise ValueError(f"Cannot find {from_name} in {name_to_weight.keys()}") + weights.append(weight) + + if match: + to_names = [to_name.format(*match) for to_name in to_names] + + return to_names, weights + + def hf_to_mf(self, name_to_weight: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: + """HF → MF 转换""" + names, weights = self._to_names_and_weights( + self.hf_names, self.mf_names, name_to_weight + ) + mf_weights = self._hf_to_mf(weights) + if not isinstance(mf_weights, list): + mf_weights = [mf_weights] + assert len(names) == len(mf_weights), f"names: {names}, weights: {mf_weights}" + return {names[i]: mf_weights[i] for i in range(len(names))} + + def mf_to_hf(self, name_to_weight: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: + """MF → HF 转换""" + names, weights = self._to_names_and_weights( + self.mf_names, self.hf_names, name_to_weight + ) + hf_weights = self._mf_to_hf(weights) + if not isinstance(hf_weights, list): + hf_weights = [hf_weights] + assert len(names) == len(hf_weights), f"names: {names}, weights: {hf_weights}" + return {names[i]: hf_weights[i] for i in range(len(names))} + + @abstractmethod + def _hf_to_mf(self, weights: List[np.ndarray]) -> List[np.ndarray]: + """将 HuggingFace 权重转换为 MindFormers 权重(子类实现)""" + raise NotImplementedError() + + @abstractmethod + def _mf_to_hf(self, weights: List[np.ndarray]) -> List[np.ndarray]: + """将 MindFormers 权重转换为 HuggingFace 权重(子类实现)""" + raise NotImplementedError() + + +@dataclass +class RenameConverOp(ConverOp): + """ + 重命名操作(1:1 映射)。 + + 双向转换时,仅修改参数名,权重值保持不变。 + """ + def __post_init__(self): + super().__post_init__() + assert len(self.hf_names) == 1, f"RenameConverOp only support one name {self.hf_names}" + assert len(self.mf_names) == 1, f"RenameConverOp only support one name {self.mf_names}" + + def _hf_to_mf(self, weights: List[np.ndarray]) -> List[np.ndarray]: + return weights + + def _mf_to_hf(self, weights: List[np.ndarray]) -> List[np.ndarray]: + return weights + + +@dataclass +class ConcatConverOp(ConverOp): + """ + 拼接操作(N:1 映射)。 + + HF → MF: np.concatenate() 拼接多个权重 + MF → HF: np.split() 拆分为多个权重 + """ + dim: int = 0 + split_sizes: List[int] = None # 可选:指定每个 HF 权重的大小 + use_interleaved_weight_layout_mlp: bool = True + + def __post_init__(self): + super().__post_init__() + assert (len(self.hf_names) == 1) != (len(self.mf_names) == 1), ( + f"ConcatConverOp only supports one name as target {self.hf_names} {self.mf_names}" + ) + + def set_model_config(self, config): + """从模型配置设置参数""" + self.use_interleaved_weight_layout_mlp = config.use_interleaved_weight_layout_mlp + + def _hf_to_mf(self, weights: List[np.ndarray]) -> List[np.ndarray]: + if self.use_interleaved_weight_layout_mlp: + # 步骤1:沿自定义维度 self.dim 堆叠(堆叠后ndim = input_ndim + 1) + stacked = np.stack(weights, axis=self.dim) # 新增1个维度用于权重索引 + + # 步骤2:动态构造转置维度,交换“堆叠维度”和“原拼接维度”,实现交织 + axes = list(range(stacked.ndim)) # 堆叠后的维度索引列表 + axes[self.dim], 
axes[self.dim+1] = axes[self.dim+1], axes[self.dim] # 交换相邻维度 + transposed = stacked.transpose(axes) # 转置后仍保持 ndim = input_ndim + 1 + + # 步骤3:构造新形状,合并两个相关维度且保留原ndim(不扁平化) + # 思路:替换原拼接维度为 (权重数量 * 单个权重拼接维度大小),删除堆叠新增的维度 + new_shape = list(transposed.shape) + new_shape[self.dim] = transposed.shape[self.dim] * transposed.shape[self.dim+1] + new_shape.pop(self.dim+1) + + # 步骤4:重塑形状,得到最终交织结果(ndim = input_ndim) + interleaved_concat = transposed.reshape(new_shape) + + return [interleaved_concat] + else: + self.split_sizes = [w.shape[self.dim] for w in weights] + return [np.concatenate(weights, axis=self.dim)] + + def _mf_to_hf(self, weights: List[np.ndarray]) -> List[np.ndarray]: + """将拼接的 MF 权重拆分回 HF 权重""" + concat_weight = weights[0] + + if self.split_sizes is not None: + # 按记录的大小拆分 + indices = np.cumsum(self.split_sizes[:-1]) + return np.split(concat_weight, indices, axis=self.dim) + else: + # 均匀拆分 + num_splits = len(self.hf_names) + return np.split(concat_weight, num_splits, axis=self.dim) + + +@dataclass +class StackConverOp(ConverOp): + """ + 堆叠操作(N:1 映射)。 + + HF → MF: np.stack() 堆叠多个权重 + MF → HF: np.split() 拆分为多个权重 + """ + dim: int = 0 + + def __post_init__(self): + super().__post_init__() + assert (len(self.hf_names) == 1) != (len(self.mf_names) == 1), ( + f"StackConverOp only supports one name as target {self.hf_names} {self.mf_names}" + ) + + def _hf_to_mf(self, weights: List[np.ndarray]) -> List[np.ndarray]: + return [np.stack(weights, axis=self.dim)] + + def _mf_to_hf(self, weights: List[np.ndarray]) -> List[np.ndarray]: + """将堆叠的 MF 权重拆分回 HF 权重""" + stacked_weight = weights[0] + num_splits = len(self.hf_names) + # 使用 np.moveaxis 将 stack 维度移到第 0 维,然后按第 0 维拆分 + splits = [np.squeeze(s, axis=self.dim) + for s in np.split(stacked_weight, num_splits, axis=self.dim)] + return splits + + +@dataclass +class QKVConverOp(ConverOp): + """ + QKV 融合操作(3:1 映射)。 + + HF → MF: 将 Q、K、V 按 GQA 格式交错拼接为 QKV + MF → HF: 将 QKV 拆分为独立的 Q、K、V + + GQA 格式:[ng, (nh/ng + 2) * kv_channels, hidden_size] + 其中 ng = num_query_groups, nh = num_attention_heads + """ + num_attention_heads: int = None + num_query_groups: int = None + kv_channels: int = None + hidden_size: int = None + tensor_model_parallel_size: int = None + use_contiguous_weight_layout_attention: bool = False + + def __post_init__(self): + super().__post_init__() + assert len(self.hf_names) == 3, f"QKVConverOp only support three hf_names {self.hf_names}" + assert len(self.mf_names) == 1, f"QKVConverOp only support one mca_name {self.mf_names}" + + def set_model_config(self, config): + """从模型配置设置参数""" + self.num_attention_heads = config.num_attention_heads + self.num_query_groups = config.num_query_groups + self.kv_channels = config.kv_channels + self.hidden_size = config.hidden_size + self.tensor_model_parallel_size = config.tensor_model_parallel_size + self.use_contiguous_weight_layout_attention = config.use_contiguous_weight_layout_attention + + def _hf_to_mf(self, weights: List[np.ndarray]) -> List[np.ndarray]: + """将 Q、K、V 权重转换为 QKV 融合权重""" + q_weight, k_weight, v_weight = weights + nh = self.num_attention_heads // self.tensor_model_parallel_size + ng = self.num_query_groups // self.tensor_model_parallel_size + dim = self.kv_channels + assert nh % ng == 0 + + # 重排并拼接(GQA 交错格式) + if not self.use_contiguous_weight_layout_attention: + mf_qkv_weight = np.concatenate([ + q_weight.reshape((ng, dim * nh // ng, -1)), + k_weight.reshape((ng, dim, -1)), + v_weight.reshape((ng, dim, -1)), + ], axis=1).reshape((-1, self.hidden_size)) + else: + mf_qkv_weight = 
np.concatenate([q_weight, k_weight, v_weight], axis=1).reshape((-1, self.hidden_size)) + + return [mf_qkv_weight] + + def _mf_to_hf(self, weights: List[np.ndarray]) -> List[np.ndarray]: + """ + 将 QKV 融合权重拆分为独立的 Q、K、V 权重。 + + 参考 ROLL 的 QKVConverOp._mf_to_hf() 实现。 + """ + if self.hidden_size is None: + self.hidden_size = self.mf_config.hidden_size if self.mf_config else weights[0].shape[-1] + + qkv_weight = weights[0] + ng = self.num_query_groups + nh = self.num_attention_heads + dim = self.kv_channels + + # 从 GQA 格式拆分 + # 输入 shape: [ng * (nh/ng + 2) * dim, hidden_size] + # 重塑为: [ng, (nh/ng + 2) * dim, hidden_size] + qkv_weight = qkv_weight.reshape((ng, dim * (nh // ng + 2), -1)) + + # 按 [q_dim, k_dim, v_dim] 拆分 + q_dim = dim * nh // ng + k_dim = dim + v_dim = dim + + qkv_splits = np.split(qkv_weight, [q_dim, q_dim + k_dim], axis=1) + + # 重塑为原始 HF 格式 + q_weight = qkv_splits[0].reshape((-1, self.hidden_size)) + k_weight = qkv_splits[1].reshape((-1, self.hidden_size)) + v_weight = qkv_splits[2].reshape((-1, self.hidden_size)) + + return [q_weight, k_weight, v_weight] + + +@dataclass +class QKVBiasConverOp(ConverOp): + """ + QKV Bias 融合操作(3:1 映射)。 + + 与 QKVConverOp 类似,但处理 1D 的 bias 向量。 + """ + num_attention_heads: int = None + num_query_groups: int = None + kv_channels: int = None + tensor_model_parallel_size: int = None + use_contiguous_weight_layout_attention: bool = False + + def set_model_config(self, config): + """从模型配置设置参数""" + self.num_attention_heads = config.num_attention_heads + self.num_query_groups = config.num_query_groups + self.kv_channels = config.kv_channels + self.tensor_model_parallel_size = config.tensor_model_parallel_size + self.use_contiguous_weight_layout_attention = config.use_contiguous_weight_layout_attention + + def _hf_to_mf(self, weights: List[np.ndarray]) -> List[np.ndarray]: + """将 Q、K、V bias 转换为 QKV 融合 bias""" + q_bias, k_bias, v_bias = weights + nh = self.num_attention_heads // self.tensor_model_parallel_size + ng = self.num_query_groups // self.tensor_model_parallel_size + dim = self.kv_channels + assert nh % ng == 0 + + mf_qkv_bias = np.concatenate([ + q_bias.reshape((ng, dim * nh // ng)), + k_bias.reshape((ng, dim)), + v_bias.reshape((ng, dim)), + ], axis=1).reshape(-1) + + return [mf_qkv_bias] + + def _mf_to_hf(self, weights: List[np.ndarray]) -> List[np.ndarray]: + """将 QKV 融合 bias 拆分为独立的 Q、K、V bias""" + qkv_bias = weights[0] + ng = self.num_query_groups + nh = self.num_attention_heads + dim = self.kv_channels + + qkv_bias = qkv_bias.reshape((ng, dim * (nh // ng + 2))) + + q_dim = dim * nh // ng + k_dim = dim + + qkv_splits = np.split(qkv_bias, [q_dim, q_dim + k_dim], axis=1) + + q_bias = qkv_splits[0].reshape(-1) + k_bias = qkv_splits[1].reshape(-1) + v_bias = qkv_splits[2].reshape(-1) + + return [q_bias, k_bias, v_bias] + + +@dataclass +class MoeExpertFc1ConverOp(ConverOp): + """ + MOE 专家 FC1 权重转换操作(动态 N:1 映射)。 + + 使用模板字符串动态匹配所有专家的 gate_proj 和 up_proj 权重。 + 支持完整路径格式,第一个 {} 表示层序号,第二个 {} 表示专家序号。 + + HF 格式: + - hf_names = [ + "model.layers.{}.mlp.experts.{}.gate_proj.weight", + "model.layers.{}.mlp.experts.{}.up_proj.weight" + ] + + MF 格式: + - mf_names = ["decoder.layers.{}.mlp.experts.weight1"] + - shape = (num_experts * hidden_size, 2 * ffn_hidden_size) + """ + num_experts: int = None + ffn_hidden_size: int = None + hidden_size: int = None + + def set_model_config(self, config): + """从模型配置设置参数""" + self.num_experts = config.num_moe_experts + self.ffn_hidden_size = config.moe_ffn_hidden_size + self.hidden_size = config.hidden_size + + def 
_collect_expert_weights( + self, + name_to_weight: Dict[str, np.ndarray] + ) -> Tuple[str, List[np.ndarray], str]: + """ + 收集所有专家的权重,按专家索引排序。 + + Returns: + (mf_name, weights, layer_id): MF 名称、权重列表和层 ID + weights 顺序: [gate_0, up_0, gate_1, up_1, ..., gate_n, up_n] + """ + # 构建正则模式(将 {} 替换为捕获组) + gate_pattern = self._name_to_pattern(self.hf_names[0]) + up_pattern = self._name_to_pattern(self.hf_names[1]) + + # 收集所有专家权重 + expert_weights = {} # {expert_id: {'gate': weight, 'up': weight}} + layer_id = None + + for name, weight in name_to_weight.items(): + gate_match = re.match(gate_pattern, name) + if gate_match: + groups = gate_match.groups() + layer_id = groups[0] # 第一个捕获组是层序号 + expert_id = int(groups[1]) # 第二个捕获组是专家序号 + if expert_id not in expert_weights: + expert_weights[expert_id] = {} + expert_weights[expert_id]['gate'] = weight + continue + + up_match = re.match(up_pattern, name) + if up_match: + groups = up_match.groups() + layer_id = groups[0] + expert_id = int(groups[1]) + if expert_id not in expert_weights: + expert_weights[expert_id] = {} + expert_weights[expert_id]['up'] = weight + + # 按专家索引排序,交错排列 gate 和 up + sorted_ids = sorted(expert_weights.keys()) + weights = [] + for expert_id in sorted_ids: + weights.append(expert_weights[expert_id]['gate']) + weights.append(expert_weights[expert_id]['up']) + + # 生成 MF 名称(替换层序号占位符) + mf_name = self.mf_names[0].replace("{}", layer_id, 1) + + return mf_name, weights, layer_id + + def hf_to_mf(self, name_to_weight: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: + """HF → MF 转换(重写以支持动态专家收集)""" + mf_name, weights, _ = self._collect_expert_weights(name_to_weight) + mf_weights = self._hf_to_mf(weights) + return {mf_name: mf_weights[0]} + + def _hf_to_mf(self, weights: List[np.ndarray]) -> List[np.ndarray]: + """ + 将多个专家的 gate_proj + up_proj 转换为 experts.weight1。 + + 输入顺序:[gate_0, up_0, gate_1, up_1, ..., gate_n, up_n] + """ + num_experts = len(weights) // 2 + + expert_weights = [] + for i in range(num_experts): + gate_weight = weights[i * 2] # (ffn_hidden_size, hidden_size) + up_weight = weights[i * 2 + 1] # (ffn_hidden_size, hidden_size) + + # 拼接 gate 和 up: (2 * ffn_hidden_size, hidden_size) + combined = np.concatenate([gate_weight, up_weight], axis=0) + # 转置为 (hidden_size, 2 * ffn_hidden_size) + combined = combined.T + expert_weights.append(combined) + + # 堆叠所有专家: (num_experts, hidden_size, 2 * ffn_hidden_size) + stacked = np.stack(expert_weights, axis=0) + # reshape 为 (num_experts * hidden_size, 2 * ffn_hidden_size) + weight1 = stacked.reshape(-1, stacked.shape[-1]) + + return [weight1] + + def _mf_to_hf(self, weights: List[np.ndarray]) -> List[np.ndarray]: + """将 experts.weight1 拆分为多个专家的 gate_proj 和 up_proj""" + weight1 = weights[0] # (num_experts * hidden_size, 2 * ffn_hidden_size) + + # reshape 为 (num_experts, hidden_size, 2 * ffn_hidden_size) + weight1 = weight1.reshape(self.num_experts, self.hidden_size, -1) + # 转置为 (num_experts, 2 * ffn_hidden_size, hidden_size) + weight1 = weight1.transpose(0, 2, 1) + + result = [] + for i in range(self.num_experts): + expert_weight = weight1[i] # (2 * ffn_hidden_size, hidden_size) + gate_weight, up_weight = np.split(expert_weight, [self.ffn_hidden_size], axis=0) + result.extend([gate_weight, up_weight]) + + return result + + def mf_to_hf(self, name_to_weight: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: + """MF → HF 转换(重写以支持动态专家展开)""" + mf_name = list(name_to_weight.keys())[0] + weight = name_to_weight[mf_name] + hf_weights = self._mf_to_hf([weight]) + + # 从 MF 名称中提取层序号 + mf_pattern = 
self._name_to_pattern(self.mf_names[0]) + match = re.match(mf_pattern, mf_name) + layer_id = match.group(1) if match else "0" + + # 生成所有专家的 hf_names + result = {} + for i in range(self.num_experts): + # 替换两个占位符:层序号和专家序号 + gate_name = self.hf_names[0].replace("{}", layer_id, 1).replace("{}", str(i), 1) + up_name = self.hf_names[1].replace("{}", layer_id, 1).replace("{}", str(i), 1) + result[gate_name] = hf_weights[i * 2] + result[up_name] = hf_weights[i * 2 + 1] + + return result + + +@dataclass +class MoeExpertFc2ConverOp(ConverOp): + """ + MOE 专家 FC2 权重转换操作(动态 N:1 映射)。 + + 使用模板字符串动态匹配所有专家的 down_proj 权重。 + 支持完整路径格式,第一个 {} 表示层序号,第二个 {} 表示专家序号。 + + HF 格式: + - hf_names = ["model.layers.{}.mlp.experts.{}.down_proj.weight"] + + MF 格式: + - mf_names = ["decoder.layers.{}.mlp.experts.weight2"] + - shape = (num_experts * ffn_hidden_size, hidden_size) + """ + num_experts: int = None + ffn_hidden_size: int = None + hidden_size: int = None + + def set_model_config(self, config): + """从模型配置设置参数""" + self.num_experts = config.num_moe_experts + self.ffn_hidden_size = config.moe_ffn_hidden_size + self.hidden_size = config.hidden_size + + def _collect_expert_weights( + self, + name_to_weight: Dict[str, np.ndarray] + ) -> Tuple[str, List[np.ndarray], str]: + """ + 收集所有专家的权重,按专家索引排序。 + + Returns: + (mf_name, weights, layer_id): MF 名称、权重列表和层 ID + weights 顺序: [down_0, down_1, ..., down_n] + """ + # 构建正则模式 + down_pattern = self._name_to_pattern(self.hf_names[0]) + + # 收集所有专家权重 + expert_weights = {} # {expert_id: weight} + layer_id = None + + for name, weight in name_to_weight.items(): + down_match = re.match(down_pattern, name) + if down_match: + groups = down_match.groups() + layer_id = groups[0] # 第一个捕获组是层序号 + expert_id = int(groups[1]) # 第二个捕获组是专家序号 + expert_weights[expert_id] = weight + + # 按专家索引排序 + sorted_ids = sorted(expert_weights.keys()) + weights = [expert_weights[expert_id] for expert_id in sorted_ids] + + # 生成 MF 名称(替换层序号占位符) + mf_name = self.mf_names[0].replace("{}", layer_id, 1) + + return mf_name, weights, layer_id + + def hf_to_mf(self, name_to_weight: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: + """HF → MF 转换(重写以支持动态专家收集)""" + mf_name, weights, _ = self._collect_expert_weights(name_to_weight) + mf_weights = self._hf_to_mf(weights) + return {mf_name: mf_weights[0]} + + def _hf_to_mf(self, weights: List[np.ndarray]) -> List[np.ndarray]: + """ + 将多个专家的 down_proj 转换为 experts.weight2。 + + 输入顺序:[down_0, down_1, ..., down_n] + """ + expert_weights = [] + for down_weight in weights: + # down_weight: (hidden_size, ffn_hidden_size) + # 转置为 (ffn_hidden_size, hidden_size) + expert_weights.append(down_weight.T) + + # 堆叠所有专家: (num_experts, ffn_hidden_size, hidden_size) + stacked = np.stack(expert_weights, axis=0) + # reshape 为 (num_experts * ffn_hidden_size, hidden_size) + weight2 = stacked.reshape(-1, stacked.shape[-1]) + + return [weight2] + + def _mf_to_hf(self, weights: List[np.ndarray]) -> List[np.ndarray]: + """将 experts.weight2 拆分为多个专家的 down_proj""" + weight2 = weights[0] # (num_experts * ffn_hidden_size, hidden_size) + + # reshape 为 (num_experts, ffn_hidden_size, hidden_size) + weight2 = weight2.reshape(self.num_experts, self.ffn_hidden_size, self.hidden_size) + # 转置为 (num_experts, hidden_size, ffn_hidden_size) + weight2 = weight2.transpose(0, 2, 1) + + result = [weight2[i] for i in range(self.num_experts)] + return result + + def mf_to_hf(self, name_to_weight: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: + """MF → HF 转换(重写以支持动态专家展开)""" + mf_name = list(name_to_weight.keys())[0] + weight = 
name_to_weight[mf_name] + hf_weights = self._mf_to_hf([weight]) + + # 从 MF 名称中提取层序号 + mf_pattern = self._name_to_pattern(self.mf_names[0]) + match = re.match(mf_pattern, mf_name) + layer_id = match.group(1) if match else "0" + + # 生成所有专家的 hf_names + result = {} + for i in range(self.num_experts): + # 替换两个占位符:层序号和专家序号 + down_name = self.hf_names[0].replace("{}", layer_id, 1).replace("{}", str(i), 1) + result[down_name] = hf_weights[i] + + return result diff --git a/mindformers/checkpoint/converter/convert_utils.py b/mindformers/checkpoint/converter/convert_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7dd5f482e3d2f4318af4965f99c1b4c184fb1adf --- /dev/null +++ b/mindformers/checkpoint/converter/convert_utils.py @@ -0,0 +1,55 @@ +from typing import Union + +MF_LAYER_PREFIX = "decoder.layers." +MF_MOE_PREFIX = ".mlp.experts.local_experts." +MF_MTP_PREFIX = "mtp.layers." +MF_MTP_MOE_PREFIX = ".transformer_layer.mlp.experts.local_experts." + + +def get_layer_index(weight_name: str, prefix: str): + if not weight_name.startswith(prefix): + return None + return int(weight_name.replace(prefix, "").split(".")[0]) + + +def convert_to_mf_prefix(weight_prefix: str, prefix: str, moe_prefix: str = None): + weight_prefix = weight_prefix.replace(prefix, MF_LAYER_PREFIX, 1) + if moe_prefix is not None: + weight_prefix = weight_prefix.replace(moe_prefix, MF_MOE_PREFIX, 1) + return weight_prefix + + +def convert_to_hf_prefix(weight_prefix: str, prefix: str, moe_prefix: str = None): + weight_prefix = weight_prefix.replace(MF_LAYER_PREFIX, prefix, 1) + if moe_prefix is not None: + weight_prefix = weight_prefix.replace(MF_MOE_PREFIX, moe_prefix, 1) + return weight_prefix + + +def get_weight_prefix(weight_name: str, prefix: str, moe_prefix: str = None): + if not weight_name.startswith(prefix): + return "" + layer_index = get_layer_index(weight_name, prefix) + layer_prefix = prefix + str(layer_index) + if moe_prefix is None: + return layer_prefix + return layer_prefix + get_weight_prefix(weight_name[len(layer_prefix) :], prefix=moe_prefix) + + +def get_mf_weight_prefix(weight_name: str): + return get_weight_prefix(weight_name, MF_LAYER_PREFIX, MF_MOE_PREFIX) + + +def remove_weight_prefix(weight_name: str, prefix: str, moe_prefix: str = None): + weight_prefix = get_weight_prefix(weight_name, prefix, moe_prefix) + return weight_name.replace(weight_prefix, "", 1) + + +def remove_mf_weight_prefix(weight_name: str): + if weight_name.startswith(MF_MTP_PREFIX): + return remove_mf_mtp_weight_prefix(weight_name) + return remove_weight_prefix(weight_name, MF_LAYER_PREFIX, MF_MOE_PREFIX) + + +def remove_mf_mtp_weight_prefix(weight_name: str): + return remove_weight_prefix(weight_name, MF_MTP_PREFIX, MF_MTP_MOE_PREFIX).replace(".transformer_layer", "") diff --git a/mindformers/checkpoint/converter/template.py b/mindformers/checkpoint/converter/template.py new file mode 100644 index 0000000000000000000000000000000000000000..e895964329e0cd7aad9c8e2c27dc8fb8ca617444 --- /dev/null +++ b/mindformers/checkpoint/converter/template.py @@ -0,0 +1,316 @@ +import re +import copy +from dataclasses import dataclass, field +from functools import wraps +from typing import Dict, List, Optional, Any + +import numpy as np +from mindspore import Parameter + +from mindformers.tools.logger import logger +from mindformers.checkpoint.converter.convert_op import ConverOp + + +@dataclass +class WeightTemplate: + """ + 权重转换模板,支持 HF ↔ MF 双向转换。 + + 使用完整路径 + {} 占位符的方式定义转换规则,无需前缀处理。 + 例如: + hf_names = 
["model.layers.{}.self_attn.q_proj.weight", ...] + mf_names = ["decoder.layers.{}.self_attention.linear_qkv.weight"] + """ + hf_invalid_keys: List[str] = field(default_factory=list) + weight_converters: List[ConverOp] = field(default_factory=list) + + # 由 __post_init__ 构建 + hf_name_to_converter: Dict[str, ConverOp] = field(default_factory=dict, init=False) + mf_name_to_converter: Dict[str, ConverOp] = field(default_factory=dict, init=False) + # 暂存未完成转换的权重 {converter_key: {name: weight}} + pending_weights: Dict[str, Dict[str, np.ndarray]] = field(default_factory=dict, init=False) + + def __post_init__(self): + """构建 hf_name_to_converter 和 mf_name_to_converter 映射字典""" + for converter in self.weight_converters: + # 构建 HF 名称到 ConverOp 的映射 + for hf_name in converter.hf_names: + self.hf_name_to_converter[hf_name] = converter + # 构建 MF 名称到 ConverOp 的映射 + for mf_name in converter.mf_names: + self.mf_name_to_converter[mf_name] = converter + self.release() + + def release(self): + """释放暂存的权重缓存""" + weights_not_converted = [ + (key, name, weight.size) + for key, name2weight in self.pending_weights.items() + for name, weight in name2weight.items() + ] + if len(weights_not_converted) > 0: + logger.warning(f"weights not converted {len(weights_not_converted)} {weights_not_converted}") + self.pending_weights.clear() + + def set_model_config(self, config): + """设置模型配置到需要的 ConverOp""" + for converter in self.weight_converters: + converter.mf_config = config + if hasattr(converter, 'set_model_config'): + converter.set_model_config(config) + + def get_mf_names(self, hf_name: str) -> List[str]: + """获取 HF 参数名对应的 MF 参数名列表""" + converter = self.get_converter_for_name(hf_name, self.hf_name_to_converter) + if converter is None: + return [hf_name] + + # 提取 {} 占位符的值 + match_values = self._extract_placeholder_values(hf_name, converter.hf_names[0]) + + # 替换 MF 名称中的占位符 + return [self._fill_placeholders(mf_name, match_values) for mf_name in converter.mf_names] + + def get_hf_names_for_mf(self, mf_name: str) -> List[str]: + """获取生成某个 MF 参数所需的所有 HF 参数名""" + converter = self.get_converter_for_name(mf_name, self.mf_name_to_converter) + if converter is None: + return [mf_name] + + # 提取 {} 占位符的值 + match_values = self._extract_placeholder_values(mf_name, converter.mf_names[0]) + + # 替换 HF 名称中的占位符 + return [self._fill_placeholders(hf_name, match_values) for hf_name in converter.hf_names] + + def get_mf_state_dict( + self, + hf_state_dict: Dict[str, np.ndarray] + ) -> Dict[str, Parameter]: + """将 Reshard 输出的 HF 参数字典转换为 MF 参数字典""" + mf_state_dict: Dict[str, Parameter] = {} + + for hf_name, weight in hf_state_dict.items(): + result = self.add_hf_weight(hf_name, weight) + if result is not None: + for mf_name, mf_weight in result.items(): + if not isinstance(mf_weight, Parameter): + mf_weight = Parameter(mf_weight, name=mf_name) + mf_state_dict[mf_name] = mf_weight + + self.release() + return mf_state_dict + + def add_hf_weight( + self, + hf_name: str, + weight: np.ndarray + ) -> Optional[Dict[str, np.ndarray]]: + """添加单个 HF 权重并尝试转换""" + # 检查是否在无效列表中 + if hf_name in self.hf_invalid_keys: + return None + + # 获取该权重对应的转换器 + converter = self.get_converter_for_name(hf_name, self.hf_name_to_converter) + if converter is None: + logger.warning(f"No converter found for {hf_name}") + return None + + # 生成唯一的 converter key(用于分组同一层的权重) + converter_key = self._get_converter_key(hf_name, converter) + + # 暂存权重 + if converter_key not in self.pending_weights: + self.pending_weights[converter_key] = {} + self.pending_weights[converter_key][hf_name] = 
weight + + # 收集该转换器需要的所有权重 + pending = self.pending_weights[converter_key] + name_to_weight = { + name: pending.pop(name) + for name in list(pending.keys()) + if converter.is_required_name(name, mf_name=False) + } + + # 执行转换 + result = converter(name_to_weight) + if result is None: + # 权重不齐,暂存并返回 None + self.pending_weights[converter_key].update(name_to_weight) + return None + + return result + + def get_converter_for_name( + self, + name: str, + pattern_to_converter: Dict[str, ConverOp] + ) -> Optional[ConverOp]: + """根据名称查找对应的转换器""" + # 精确匹配 + if name in pattern_to_converter: + return pattern_to_converter[name] + + # 模式匹配(按模式长度降序,优先匹配更具体的模式) + for pattern in sorted(pattern_to_converter, key=len, reverse=True): + re_pattern = self._pattern_to_regex(pattern) + if re.match(re_pattern, name): + return pattern_to_converter[pattern] + + return None + + def _pattern_to_regex(self, pattern: str) -> str: + """将 {} 模式转换为正则表达式""" + # 转义特殊字符,然后将 {} 替换为捕获组 + escaped = re.escape(pattern) + return escaped.replace(r"\{\}", r"(\d+)") + + def _extract_placeholder_values(self, name: str, pattern: str) -> List[str]: + """从名称中提取 {} 占位符的值""" + re_pattern = self._pattern_to_regex(pattern) + match = re.match(re_pattern, name) + if match: + return list(match.groups()) + return [] + + def _fill_placeholders(self, pattern: str, values: List[str]) -> str: + """用值填充模式中的 {} 占位符""" + result = pattern + for value in values: + result = result.replace("{}", value, 1) + return result + + def _get_converter_key(self, name: str, converter: ConverOp) -> str: + """ + 生成 converter 的唯一 key(用于分组同一层的权重)。 + + 对于 QKV 等多输入转换器,同一层的 Q/K/V 需要分到同一组。 + """ + # 使用第一个 hf_name 模式提取层信息 + pattern = converter.hf_names[0] + values = self._extract_placeholder_values(name, pattern) + + # 生成 key:converter_id + layer_info + converter_id = id(converter) + layer_info = "_".join(values) if values else "global" + return f"{converter_id}_{layer_info}" + + def get_hf_state_dict( + self, + mf_state_dict: Dict[str, np.ndarray] + ) -> Dict[str, np.ndarray]: + """将 MindFormers 参数字典转换为 HuggingFace 格式""" + hf_state_dict: Dict[str, np.ndarray] = {} + + for mf_name, weight in mf_state_dict.items(): + result = self.add_mf_weight(mf_name, weight) + if result is not None: + hf_state_dict.update(result) + + self.release() + return hf_state_dict + + def add_mf_weight( + self, + mf_name: str, + weight: np.ndarray + ) -> Optional[Dict[str, np.ndarray]]: + """添加单个 MF 权重并转换为 HF 格式""" + # 获取该权重对应的转换器 + converter = self.get_converter_for_name(mf_name, self.mf_name_to_converter) + if converter is None: + logger.warning(f"No converter found for {mf_name}") + return None + + # 生成唯一的 converter key + converter_key = self._get_converter_key_mf(mf_name, converter) + + # 暂存权重 + if converter_key not in self.pending_weights: + self.pending_weights[converter_key] = {} + self.pending_weights[converter_key][mf_name] = weight + + # 收集该转换器需要的所有权重 + pending = self.pending_weights[converter_key] + name_to_weight = { + name: pending.pop(name) + for name in list(pending.keys()) + if converter.is_required_name(name, mf_name=True) + } + + # 执行转换 + result = converter(name_to_weight, mf_to_hf=True) + if result is None: + # 权重不齐,暂存并返回 None + self.pending_weights[converter_key].update(name_to_weight) + return None + + return result + + def _get_converter_key_mf(self, name: str, converter: ConverOp) -> str: + """生成 MF 侧的 converter key""" + pattern = converter.mf_names[0] + values = self._extract_placeholder_values(name, pattern) + converter_id = id(converter) + layer_info = "_".join(values) if 
values else "global" + return f"{converter_id}_{layer_info}" + + +def register_template(init_func): + """ + 方法装饰器:装饰 __init__ 方法,自动创建并绑定 WeightTemplate 到网络实例。 + + 使用方式: + class TrainingQwen3ForCausalLM(Qwen3PreTrainedModel): + @register_template + def __init__(self, config): + super().__init__(config) + ... + + 装饰后,__init__ 执行完成后会自动: + 1. 基于类属性 weight_converters 等创建 WeightTemplate + 2. 调用 template.set_model_config(config) 设置模型配置 + 3. 将 template 实例绑定到 self.template 属性 + + 要求模型类或其父类定义以下类属性: + - weight_converters: List[ConverOp] # 必需,使用完整路径 + {} 占位符 + - hf_invalid_keys: List[str] # 可选,默认为空列表 + """ + @wraps(init_func) + def wrapper(self, config, *args, **kwargs): + # 调用原始 __init__ + init_func(self, config, *args, **kwargs) + + # 从继承链中查找 weight_converters + cls = type(self) + weight_converters = None + for base in cls.__mro__: + if hasattr(base, 'weight_converters') and base.weight_converters: + weight_converters = base.weight_converters + break + + if weight_converters is None: + raise ValueError(f"Class {cls.__name__} or its parents must define 'weight_converters'") + + # 查找 hf_invalid_keys(可选) + hf_invalid_keys = [] + for base in cls.__mro__: + if hasattr(base, 'hf_invalid_keys'): + hf_invalid_keys = base.hf_invalid_keys + break + + # 创建 template 实例(深拷贝 converters 避免共享状态) + template = WeightTemplate( + hf_invalid_keys=hf_invalid_keys, + weight_converters=copy.deepcopy(weight_converters), + ) + + # 设置模型配置 + transformer_config = self.convert_to_transformer_config(config) + template.set_model_config(transformer_config) + + # 绑定到网络实例 + self.template = template + + return wrapper diff --git a/mindformers/checkpoint/fully_parallel.py b/mindformers/checkpoint/fully_parallel.py index f4bc8f87121e2a795c84746e6e7efe3061ec49e3..7b8b50c95a74a1fc9115c04bab7d122f0afd3f4e 100644 --- a/mindformers/checkpoint/fully_parallel.py +++ b/mindformers/checkpoint/fully_parallel.py @@ -236,16 +236,20 @@ class BalancedSaveStrategy(): if self.rank_id == 0: param_file_mapping = [] cur_rank_id = 0 - rank_param_ids_mappings = self._get_rank_param_ids_mappings(shared_distribution) - for rank_id, params in rank_param_ids_mappings.items(): + for rank_id, params in shared_distribution.items(): if params: save_file_name = get_checkpoint_name( None, self.user_prefix, cur_rank_id, self.total_files_num, self.file_type ) - for param_id in params: - param_file_mapping.append( - (save_file_name + ".safetensors", rank_id, _reverse_sharded_tensor_shard_id(param_id))) + for shard_id in params: + rank_group = params[shard_id][1] + param_file_mapping.append(( + save_file_name + ".safetensors", + rank_id, + rank_group, + _reverse_sharded_tensor_shard_id(shard_id) + )) cur_rank_id += 1 sharded_tensor_metas = get_all_sharded_tensor(self.network, self.filter_func) @@ -255,12 +259,13 @@ class BalancedSaveStrategy(): origin_shard_metadata, origin_param_file_mapping = load_metadata( get_metadata_filename(self.checkpoint_path, iteration)) sharded_tensor_metas.update({"origin": origin_shard_metadata}) - for param_id, storage in origin_param_file_mapping.items(): + for shard_id, storage in origin_param_file_mapping.items(): for storage_item in storage: param_file_mapping.append(( storage_item["file_name"], storage_item["storage_rank"], - _reverse_sharded_tensor_shard_id(param_id) + storage_item["rank_group"], + _reverse_sharded_tensor_shard_id(shard_id) )) metadata_file_path = get_metadata_filename(self.checkpoint_path, iteration) @@ -441,4 +446,9 @@ def apply_balance_shard_strategy(network: Cell, filter_func: Callable[[str], boo else: 
rank_id_to_sharded_tensors[selected_rank_id] = {shard_id: (sharded_tensor, rank_group)} + rank_id_to_sharded_tensors = { + k: rank_id_to_sharded_tensors.get(k, None) + for k in sorted(rank_id_to_sharded_tensors) + } + return rank_id_to_sharded_tensors diff --git a/mindformers/checkpoint/layout.py b/mindformers/checkpoint/layout.py new file mode 100644 index 0000000000000000000000000000000000000000..0037b4f3c0523f02282f0ec2d4c68989cfb480c6 --- /dev/null +++ b/mindformers/checkpoint/layout.py @@ -0,0 +1,114 @@ +from dataclasses import dataclass +from typing import Optional, Dict, Tuple, Set, Any + +import mindspore as ms +from mindspore import Layout +from mindspore.nn import Cell +from mindspore.parallel.strategy import get_strategy_metadata, get_current_strategy_metadata + +from mindformers.tools.utils import get_real_rank +from mindformers.tools.logger import logger + +try: + from hyper_parallel.core.checkpoint.layout import get_current_layout, get_global_layout +except: + pass + + +@dataclass +class MFLayout(Layout): + """统一的布局信息表示""" + def __init__(self, device_matrix, alias_name, rank_list=None, tensor_map=None): + super().__init__(device_matrix, alias_name, rank_list) + self._tensor_map = tensor_map + + +class LayoutAdapter: + """布局适配器,统一静态图和动态图的布局信息。""" + + @staticmethod + def is_pynative_mode() -> bool: + """判断是否为动态图模式""" + return ms.get_context('mode') == ms.PYNATIVE_MODE + + @staticmethod + def get_all_layout(network: Cell) -> Dict[int, Dict[str, Tuple[Layout, str, Tuple[int, ...]]]]: + """获取所有 rank 的布局信息(统一接口)""" + if LayoutAdapter.is_pynative_mode(): + return LayoutAdapter._get_layout_from_pynative(network) + else: + return LayoutAdapter._get_layout_from_graph(network) + + @staticmethod + def get_current_layout(network: Cell) -> Dict[str, Tuple[Layout, str, Tuple[int, ...]]]: + """获取当前 rank 的布局信息(统一接口)""" + if LayoutAdapter.is_pynative_mode(): + # all_layout = LayoutAdapter._get_layout_from_pynative(network) + # cur_rank = get_real_rank() + # return all_layout.get(cur_rank, {}) + return LayoutAdapter._get_current_layout_from_pynative(network) + else: + return LayoutAdapter._get_current_layout_from_graph(network) + + @staticmethod + def _get_layout_from_pynative(network: Cell) -> Dict[int, Dict[str, Tuple]]: + """从动态图接口获取布局信息并转换""" + global_layout_dict = get_global_layout(network) + if not global_layout_dict: + return {} + + logger.info(f"global_layout_dict: {global_layout_dict}") + global_layout_dict_mf = {} + for rank_id, current_layout_dict in global_layout_dict.items(): + rank_id = int(rank_id) + global_layout_dict_mf[rank_id] = {} + for param_name, param_info in current_layout_dict.items(): + # logger.info(f"param_info: {param_info}") + # layout_dict, dtype, global_shape = param_info + + # 转换为 Layout + layout_info = MFLayout( + device_matrix=tuple(param_info['mesh_shape']), + tensor_map=tuple(param_info['tensor_map']), + rank_list=list(param_info['rank_list']), + alias_name=param_info.get('alias_name') + ) + global_layout_dict_mf[rank_id][param_name] = (layout_info, param_info['type'], param_info['full_shape']) + + logger.info(f"global_layout_dict_mf: {global_layout_dict_mf}") + return global_layout_dict_mf + + @staticmethod + def _get_current_layout_from_pynative(network: Cell) -> Dict[int, Dict[str, Tuple]]: + """从动态图接口获取布局信息并转换""" + current_layout_dict = get_current_layout(network) + if not current_layout_dict: + return {} + + assert len(current_layout_dict) == 1 + rank_id = int(current_layout_dict.keys()[0]) + current_layouts = current_layout_dict[rank_id] + 
current_layouts_mf = {} + for param_name, param_info in current_layouts.items(): + # layout_dict, dtype, global_shape = param_info + + # 转换为 Layout + layout_info = MFLayout( + device_matrix=tuple(param_info['mesh_shape']), + tensor_map=tuple(param_info['tensor_map']), + rank_list=list(param_info['rank_list']), + alias_name=param_info.get('alias_name') + ) + current_layouts_mf[param_name] = (layout_info, param_info['type'], param_info['full_shape']) + + return current_layouts_mf + + @staticmethod + def _get_layout_from_graph(network: Cell) -> Dict[int, Dict[str, Tuple]]: + """从静态图接口获取布局信息并转换""" + return get_strategy_metadata(network) + + @staticmethod + def _get_current_layout_from_graph(network: Cell) -> Dict[str, Tuple]: + """从静态图接口获取当前 rank 布局信息并转换""" + return get_current_strategy_metadata(network) diff --git a/mindformers/checkpoint/metadata.py b/mindformers/checkpoint/metadata.py index e2ccbe4b4d8c04eaf2f29fd16f9f81394a59cf57..ade6ca66fc5ee3a937b20fda8d3447b7c2a7ca74 100644 --- a/mindformers/checkpoint/metadata.py +++ b/mindformers/checkpoint/metadata.py @@ -17,8 +17,11 @@ import os import json import tempfile from glob import glob -from safetensors import safe_open +from typing import Dict, Tuple, List +from collections import defaultdict +from mindspore import load_checkpoint +from mindspore.nn import Cell from mindspore.communication.management import get_group_size from mindspore.common.dtype import all_types from mindspore.parallel import Layout @@ -27,10 +30,14 @@ from mindformers.tools.logger import logger from mindformers.tools.utils import set_safe_mode_for_file_or_dir from mindformers.checkpoint.sharded_tensor import build_sharded_tensor from mindformers.checkpoint.utils import ( + is_hf_checkpoint, + get_needed_hf_files, get_checkpoint_name, get_sharded_tensor_shard_id, FileType ) +from mindformers.checkpoint.sharded_tensor import ShardedTensor + tensor_to_ms_type = {str(dtype).lower(): dtype for dtype in all_types} @@ -149,11 +156,16 @@ def save_metadata(sharded_tensor_metas, param_file_mappings, meta_data_path): for param_file_mapping in param_file_mappings: file_name = param_file_mapping[0] storage_rank = param_file_mapping[1] - param_id = get_sharded_tensor_shard_id(param_file_mapping[2][0], param_file_mapping[2][1]) - if param_id not in storage_data: - storage_data[param_id] = [{"file_name": file_name, "storage_rank": storage_rank}] + rank_group = param_file_mapping[2] + shard_id = get_sharded_tensor_shard_id(param_file_mapping[3][0], param_file_mapping[3][1]) + if shard_id not in storage_data: + storage_data[shard_id] = [ + {"file_name": file_name, "storage_rank": storage_rank, "rank_group": rank_group} + ] else: - storage_data[param_id].append({"file_name": file_name, "storage_rank": storage_rank}) + storage_data[shard_id].append( + {"file_name": file_name, "storage_rank": storage_rank, "rank_group": rank_group} + ) metadata = {"state_dict_metadata": state_dict_metadata, "storage_data": storage_data} @@ -277,11 +289,16 @@ def generate_default_metadata_from_checkpoint(checkpoint_dir: str) -> tuple[dict raise NotADirectoryError( f"Checkpoint directory '{checkpoint_dir}' does not exist or is not a directory.") - logger.info("..........Load Metadata from Checkpoint Files..........") + logger.info("..........Generate Metadata from Checkpoint Files..........") # Find all safetensor files in the checkpoint directory - safetensor_pattern = os.path.join(checkpoint_dir, "*.safetensors") - safetensor_files = glob(safetensor_pattern) + if is_hf_checkpoint(checkpoint_dir): + 
safetensor_files = get_needed_hf_files(checkpoint_dir) + logger.info(f"Detected HuggingFace checkpoint, found {len(safetensor_files)} safetensor files") + else: + safetensor_pattern = os.path.join(checkpoint_dir, "*.safetensors") + safetensor_files = glob(safetensor_pattern) + logger.info(f"Found {len(safetensor_files)} safetensor files in directory") # Verify we found safetensor files if not safetensor_files: @@ -297,45 +314,34 @@ def generate_default_metadata_from_checkpoint(checkpoint_dir: str) -> tuple[dict # Process each safetensor file for safetensor_file in safetensor_files: file_basename = os.path.basename(safetensor_file) - logger.info(f"Extracting metadata from Safetensors file: {file_basename}") - - # Open the safetensor file and process each tensor - with safe_open(safetensor_file, framework="np", device="cpu") as f: - for param_name in f.keys(): - try: - # Load the tensor to extract its properties - tensor = f.get_tensor(param_name) # Load as numpy tensor - except Exception as e: - raise RuntimeError( - f"Failed to load tensor '{param_name}' from file '{file_basename}'" - ) from e - - # Extract tensor properties - tensor_shape = tensor.shape - ms_dtype = tensor_to_ms_type.get(str(tensor.dtype)) - global_offset = (0,) - axis_fragmentations = (1,) * len(tensor_shape) - - # Create sharded tensor metadata object - sharded_tensor = build_sharded_tensor( - param_name=param_name, - param_dtype=ms_dtype, - local_shape=tensor_shape, - global_shape=tensor_shape, - global_offset=global_offset, - axis_fragmentations=axis_fragmentations, - layout=None - ) - - # Check for duplicate parameters - if param_name in sharded_tensor_metas: - raise RuntimeError(f"Duplicate parameter_name found: {param_name}.") - - # Store metadata with fixed storage rank 0 - sharded_tensor_metas[param_name] = [sharded_tensor] - param_file_mappings[str((param_name, (0,)))] = [ - {"file_name": file_basename, "storage_rank": 0} - ] + logger.info(f"Extracting metadata from: {file_basename}") + + loaded_params = load_checkpoint( + safetensor_file, + format='safetensors' + ) + + for param_name, param in loaded_params.items(): + param_shape = tuple(param.shape) + sharded_tensor = build_sharded_tensor( + param_name=param_name, + param_dtype=param.dtype, + local_shape=param_shape, + global_shape=param_shape, + global_offset=(0,), # 多维度时每维都是 0 + axis_fragmentations=[1] * len(param_shape), + layout=None + ) + + # Check for duplicate parameters + if param_name in sharded_tensor_metas: + raise RuntimeError(f"Duplicate parameter_name found: {param_name}.") + + # Store metadata with fixed storage rank 0 + sharded_tensor_metas[param_name] = [sharded_tensor] + param_file_mappings[str((param_name, (0,)))] = [ + {"file_name": file_basename, "storage_rank": 0, "rank_group": [0]} + ] return sharded_tensor_metas, param_file_mappings @@ -345,6 +351,12 @@ def get_total_params_file_mapping_info(sharded_tensor_metas, user_prefix, model_ if sharded_tensor_metas is None: return None + shard_id_to_ranks = defaultdict(list) + for cur_npu_rank, cur_rank_sharded_tensors in sharded_tensor_metas.items(): + for sharded_tensor in cur_rank_sharded_tensors.values(): + shard_id = get_sharded_tensor_shard_id(sharded_tensor.key, sharded_tensor.global_offset) + shard_id_to_ranks[shard_id].append(cur_npu_rank) + npu_nums = get_group_size() param_file_mappings = [] for cur_npu_rank, cur_rank_sharded_tensors in sharded_tensor_metas.items(): @@ -356,7 +368,85 @@ def get_total_params_file_mapping_info(sharded_tensor_metas, user_prefix, model_ ckpt_name = 
get_checkpoint_name(None, user_prefix, cur_npu_rank, npu_nums, FileType.MODEL) param_file_mappings.append( - (ckpt_name + '.safetensors', cur_npu_rank, (sharded_tensor.key, sharded_tensor.global_offset)) + (ckpt_name + '.safetensors', + cur_npu_rank, + shard_id_to_ranks[get_sharded_tensor_shard_id(sharded_tensor.key, sharded_tensor.global_offset)], + (sharded_tensor.key, sharded_tensor.global_offset)) ) return param_file_mappings + + +def get_metadata_of_checkpoint(checkpoint_dir: str) -> Tuple[Dict, Dict]: + """ + Retrieves metadata from checkpoint directory, either from an existing metadata file + or by parsing checkpoint files. + + First checks for a pre-existing 'metadata.json' file in the checkpoint directory. If found, + it loads metadata from this file using load_metadata(). If not found, it generates metadata + by parsing the checkpoint files directly using load_metadata_from_checkpoint(). + + Args: + checkpoint_dir: Path to the directory containing the checkpoint files. + network: The target core network (Cell) which has method `convert_name` to convert Hugging Face weight. + + Returns: + A tuple containing two dictionaries: + - sharded_tensor_metas: Metadata about sharded tensors + - param_file_mappings: Mapping from shard IDs to storage information lists. + Each storage info dict has "file_name", "storage_rank", "rank_group" keys. + """ + logger.info("..........Load Metadata of Checkpoint..........") + + # Construct path to metadata file + metadata_path = os.path.join(checkpoint_dir, "metadata.json") + + # Load from existing metadata file if available + if os.path.exists(metadata_path): + sharded_tensor_metas, param_file_mappings = load_metadata(metadata_path) + # Otherwise generate metadata from checkpoint files + else: + sharded_tensor_metas, param_file_mappings = generate_default_metadata_from_checkpoint(checkpoint_dir) + + # if the content or format of metadata_path and checkpoint_dir are invalid, the return value of + # sharded_tensor_metas and param_file_mappings may be empty or None, + # and it may cause an error in subsequent loading process. + return sharded_tensor_metas, param_file_mappings + + +def params_key_mapping( + sharded_tensor_metas: Dict[str, List[ShardedTensor]], + network: Cell +) -> Tuple[Dict, Dict]: + """ + Mapping Hugging Face checkpoint keys to MindSpore Transformers. + + Args: + sharded_tensor_metas: Metadata about sharded tensors. + network: The network (Cell) which has method `convert_name` to convert Hugging Face weight. + + Returns: + A dictionary after mapping about sharded tensor metas. + """ + # The key of `mapped_sharded_tensor_metas` is in the network, + # such as { qkv: [ShardedTensor, ShardedTensor, ShardedTensor], ... } + mapped_sharded_tensor_metas = {} + # The key of `key_mapping` is {'weight_key': 'mapping_key'}, + # and the `mapping_key` may not have the same name as the parameter in the network, + # it could be an intermediate form, + # such as { 'q_proj': 'linear_q', 'k_proj': 'linear_k', 'v_proj': 'linear_v', ... 
} + key_mapping = {} + + for param_name in sharded_tensor_metas: + param_name_converted = network.convert_name(param_name) + sharded_tensor_list = sharded_tensor_metas.get(param_name) + + for sharded_tensor in sharded_tensor_list: + sharded_tensor.key = param_name_converted + sharded_tensor.org_key = param_name + + key_mapping[param_name] = param_name_converted + param_name_converted_concat = network.convert_concat_name(param_name_converted) + mapped_sharded_tensor_metas.setdefault(param_name_converted_concat, []).extend(sharded_tensor_list) + + return mapped_sharded_tensor_metas, key_mapping diff --git a/mindformers/checkpoint/reshard.py b/mindformers/checkpoint/reshard.py index f812531fca5a106e5739f347dc90ab81a52a83f6..5d22044826ed9e640d322f2e99ef472c6c175d46 100644 --- a/mindformers/checkpoint/reshard.py +++ b/mindformers/checkpoint/reshard.py @@ -218,8 +218,8 @@ class ReshardHandler: check_layout(from_layout, 'from_layout') check_layout(to_layout, 'to_layout') - if from_layout is None and to_layout is None: - raise ValueError("`from_layout` and `to_layout` cannot both be None.") + # if from_layout is None and to_layout is None: + # raise ValueError("`from_layout` and `to_layout` cannot both be None.") # Initialize basic attributes self.param_name = param_name @@ -283,10 +283,7 @@ class ReshardHandler: # Filter ranks with non-redundant data for rank_id in self.inner_from_rank_list: dev_id_list = rank_id_to_dev_id_list(self.from_dev_matrix, rank_id) - if any([ - dim not in from_dev_map and dev_id_list[dim] > 0 - for dim in range(dev_dim) - ]): + if any(dim not in from_dev_map and dev_id_list[dim] > 0 for dim in range(dev_dim)): continue inner_deredundancy_rank_list.append(rank_id) @@ -372,8 +369,11 @@ class ReshardHandler: # Create target tensor and assign slices to_slice_shape = [end - start for start, end in self.to_area] - dtype = next(iter(from_tensor_map.values())).dtype - real_tensor = Tensor(np.zeros(to_slice_shape), dtype) + current_slice = next(iter(from_tensor_map.values())) + if isinstance(current_slice, Tensor): + real_tensor = Tensor(np.zeros(to_slice_shape), current_slice.dtype) + else: + real_tensor = np.zeros(to_slice_shape, current_slice.dtype) for from_rank_id, from_slice in from_tensor_map.items(): from_area = self.global_union_area_map[from_rank_id] diff --git a/mindformers/checkpoint/reshard_loader.py b/mindformers/checkpoint/reshard_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..5b91aa52767859d5a794b135f9dfe31fc778becb --- /dev/null +++ b/mindformers/checkpoint/reshard_loader.py @@ -0,0 +1,598 @@ +import os +import copy +from time import time +from typing import Dict, List, Optional, Tuple, Any +from concurrent.futures import ThreadPoolExecutor + +import numpy as np +from mindspore import Parameter, Tensor +from mindspore import load_checkpoint as ms_load_checkpoint + +from mindformers.tools.logger import logger +from mindformers.tools.utils import get_real_rank +from mindformers.checkpoint.reshard import ReshardHandler +from mindformers.checkpoint.sharded_tensor import ShardedTensor +from mindformers.checkpoint.utils import get_sharded_tensor_shard_id + + +def smart_slice(tensor, slice_ranges, load_from_multi_rank=False): + """ + 智能切片函数。 + + Args: + tensor: 待切片的张量 + slice_ranges: 切片范围列表 [(start, end), ...] 
+ load_from_multi_rank: 是否从多 rank 加载 + + Returns: + (sliced_tensor, is_full_slice) + """ + tensor_shape = tensor.shape + + if len(slice_ranges) != len(tensor_shape): + raise ValueError( + f"Slice dimension count ({len(slice_ranges)}) does not " + f"match tensor dimension count ({len(tensor_shape)})" + ) + + # 检查是否为完整切片 + is_full_slice = all( + start == 0 and end == dim_size + for (start, end), dim_size in zip(slice_ranges, tensor_shape) + ) + + # 完整切片且无需多卡合并时,直接返回原张量 + if is_full_slice and not load_from_multi_rank: + return tensor, is_full_slice + + # 执行切片 + slice_indices = tuple(slice(start, end) for start, end in slice_ranges) + if not load_from_multi_rank: + return tensor[slice_indices], True + + if isinstance(tensor, (Tensor, Parameter)): + sliced_tensor = copy.deepcopy(tensor.asnumpy()[slice_indices]) + else: + sliced_tensor = tensor[slice_indices] + + return sliced_tensor, is_full_slice + + +def balance_load(params: List[dict], num_groups: int) -> List[List[dict]]: + """ + 贪心负载均衡算法。 + + 按参数 size 降序排序,每次将最大参数分配给当前负载最小的 worker。 + """ + sorted_params = sorted(params, key=lambda x: x["size"], reverse=True) + groups = [{"total_size": 0, "params": []} for _ in range(num_groups)] + + for param in sorted_params: + min_group = min(groups, key=lambda g: g["total_size"]) + min_group["total_size"] += param["size"] + min_group["params"].append(param) + + return [group["params"] for group in groups] + + +class ReshardLoader: + """ + 抽象的 Reshard 加载器。 + + 提供即插即用的分布式权重加载能力,支持: + - 懒加载:使用 ms_load_checkpoint 延迟读取 + - 按需切片:只读取需要的切片 + - 多线程拼接:并行处理 reshard 操作 + + 设计说明: + - template 参数仅用于 HF 权重加载场景 + - 自训权重加载时,template 为 None,直接使用源参数名 + - HF 权重加载时,使用 template.get_mf_name() 完成参数名映射 + """ + + def __init__( + self, + checkpoint_dir: str, + dst_sharded_tensor_metas: Dict[str, ShardedTensor], + src_sharded_tensor_metas: Dict[str, List[ShardedTensor]], + param_file_mappings: Dict[Tuple[str, Tuple], List[Dict]], + num_workers: int = 1, + template: Optional["WeightTemplate"] = None + ): + """ + 初始化 ReshardLoader。 + + Args: + checkpoint_dir: 权重文件夹路径 + dst_sharded_tensor_metas: 当前任务待加载的所有参数的 ShardedTensor 字典 + 格式:{param_name: ShardedTensor} + 续训场景包括网络和优化器参数 + src_sharded_tensor_metas: 待加载权重里所有参数的 ShardedTensor 字典 + 格式:{param_name: [ShardedTensor, ...]} + - key 为权重文件中记录的原始参数名 + - value 为 ShardedTensor 列表(分布式权重可能有多个切片) + param_file_mappings: 待加载权重里所有切片的存储信息 + 格式:{(param_name, global_offset): [ + {'file_name': 'xxx.safetensors', 'storage_rank': rank_id, 'rank_group': [...]}, + ... 
+ ]} + num_workers: Reshard 处理线程数(用于拼接阶段) + template: HF 权重转换模板实例(可选) + """ + self.checkpoint_dir = checkpoint_dir + self.dst_metas = dst_sharded_tensor_metas + self.src_metas = src_sharded_tensor_metas + self.param_file_mappings = param_file_mappings + self.num_workers = num_workers + self.template = template # HF 权重转换模板(可选) + self.rank_id = get_real_rank() + + # 预先构建双向映射字典(避免重复调用 template.get_mf_name) + self.src_to_dst_mapping, self.dst_to_src_mapping = self._build_bidirectional_mapping() + + # 计算所有参数的 offset + start_time = time() + self.params_info = self._compute_all_offsets() + self.build_all_offsets_time = time() - start_time + + def _build_bidirectional_mapping(self) -> Tuple[Dict[str, str], Dict[str, List[str]]]: + """ + 预先构建双向映射字典。 + + 对于 HF 权重,一次性遍历所有 src_name,调用 template.get_mf_name() + 构建映射关系,避免后续重复计算。 + + Returns: + Tuple[src_to_dst_mapping, dst_to_src_mapping] + - src_to_dst_mapping: {src_name: dst_name},源参数名到目标参数名的映射 + - dst_to_src_mapping: {dst_name: [src_name_1, src_name_2, ...]}, + 目标参数名到源参数名列表的映射 + + 示例(HF 权重 QKV 场景): + src_to_dst_mapping = { + "q_proj.weight": "linear_qkv.weight", + "k_proj.weight": "linear_qkv.weight", + "v_proj.weight": "linear_qkv.weight" + } + dst_to_src_mapping = { + "linear_qkv.weight": ["q_proj.weight", "k_proj.weight", "v_proj.weight"] + } + """ + src_to_dst: Dict[str, str] = {} + dst_to_src: Dict[str, List[str]] = {} + + if self.template is None: + # 自训权重场景:src_name 与 dst_name 一致 + for src_name in self.src_metas.keys(): + if src_name in self.dst_metas: + dst_name = src_name + src_to_dst[src_name] = dst_name + dst_to_src[dst_name] = [src_name] + else: + # HF 权重场景:使用 template 进行映射 + for src_name in self.src_metas.keys(): + dst_name = self.template.get_mf_name(src_name)[0] + if dst_name in self.dst_metas: + src_to_dst[src_name] = dst_name + + if dst_name not in dst_to_src: + dst_to_src[dst_name] = [] + dst_to_src[dst_name].append(src_name) + + return src_to_dst, dst_to_src + + def get_dst_name(self, src_name: str) -> str: + """ + 获取源参数名对应的目标参数名。 + + 优先使用预构建的映射,避免重复调用 template.get_mf_name() + """ + return self.src_to_dst_mapping.get(src_name, src_name) + + def get_src_names(self, dst_name: str) -> List[str]: + """ + 获取目标参数名对应的所有源参数名。 + + 用于 HF 权重场景,如获取 qkv 对应的 [q, k, v]。 + """ + return self.dst_to_src_mapping.get(dst_name, [dst_name]) + + def _compute_all_offsets(self) -> Dict[str, Dict]: + """ + Step 1: 计算当前卡待加载参数所关联的所有权重参数的 offset 信息。 + + 关键设计: + - 先遍历 dst_metas(当前卡需要加载的参数),确定需要哪些目标参数 + - 对于每个目标参数,通过 dst_to_src_mapping 找到所有相关的源参数 + - 为每个源参数计算 offset 信息(使用源参数的 shape,目标参数的 layout) + - 返回 {src_name: {...}} 格式,需要时通过 get_dst_name() 查表 + + 流程: + 1. 遍历 dst_metas,获取当前卡需要加载的目标参数 + 2. 对于每个目标参数,获取所有相关的源参数名列表 + 3. 
对于每个源参数,创建 ReshardHandler 并计算 all_offset + + Returns: + params_info: { + src_name: { + "all_offset": {rank: slice_range, ...}, + "reshard_handler": ReshardHandler + } + } + + 其中 src_name 是权重文件中保存的原始参数名, + 可通过 self.get_dst_name(src_name) 获取对应的目标参数名。 + """ + params_info = {} + + # Step 1: 遍历当前卡需要加载的目标参数 + for dst_name, dst_tensor in self.dst_metas.items(): + # Step 2: 获取该目标参数对应的所有源参数名 + src_names = self.get_src_names(dst_name) + + if not src_names: + logger.warning(f"No source parameters found for dst_param: {dst_name}, skipping") + continue + + # Step 3: 对于每个源参数,计算 offset 信息 + for src_name in src_names: + # 跳过已处理的源参数(避免重复) + if src_name in params_info: + continue + + # 检查源参数是否在 src_metas 中 + if src_name not in self.src_metas: + logger.warning(f"Source parameter {src_name} not in src_metas, skipping") + continue + + src_tensor_list = self.src_metas[src_name] + src_tensor = src_tensor_list[0] if isinstance(src_tensor_list, list) else src_tensor_list + + # 创建 ReshardHandler + # 注意:使用源参数的 global_shape,目标参数的 layout + reshard_handler = ReshardHandler( + src_name, # 使用源参数名 + src_tensor.global_shape, # 使用源参数的 shape + src_tensor.layout, # 源参数的 layout + dst_tensor.layout, # 目标参数的 layout + self.rank_id + ) + + all_offset = reshard_handler.infer_all_tensor_offset() + params_info[src_name] = { + "all_offset": all_offset, + "reshard_handler": reshard_handler + } + + return params_info + + def _organize_file_load_info(self) -> Dict[str, List[Tuple[str, int, Tuple]]]: + """ + Step 2: 按文件组织加载信息。 + + 根据 param_file_mappings 确定每个参数切片存储在哪个文件中, + 将需要加载的参数按文件分组。 + + 核心逻辑: + 1. 遍历 params_info,获取每个源参数需要加载的 search_rank 列表(all_offset.keys()) + 2. 获取该 src_name 在 param_file_mappings 中的所有存储切片信息 + 3. 对于每个 search_rank,遍历所有存储信息,查找匹配的切片 + 4. 匹配条件:storage_rank == search_rank 或 search_rank in rank_group + 5. 找到匹配后,记录文件名和切片信息 + + 去冗余保存场景说明: + - 当权重使用去冗余保存(remove_redundancy=True)时,同一切片可能只由 + rank_group 中的一个 rank 保存(storage_rank) + - 此时 search_rank 和 storage_rank 可能不一致 + - 需要通过 rank_group 判断:只要 search_rank 存在于 rank_group 中, + 就可以从该文件加载 + + Returns: + {file_name: [(src_name, search_rank, param_slice), ...]} + + 数据结构说明: + - param_file_mappings 格式: + { + (param_name, global_offset): [ + { + "file_name": "model-0000001-0000008.safetensors", + "storage_rank": 0, + "rank_group": [0, 1, 2, 3] # 去冗余场景 + }, + ... + ] + } + """ + files_to_load: Dict[str, List[Tuple[str, int, Tuple]]] = {} + + for src_name, src_info in self.params_info.items(): + all_offset = src_info["all_offset"] + + # Step 1: 获取该 src_name 参数的所有切片的存储信息 + param_storage_infos = self._get_param_storage_infos(src_name) + + # Step 2: 对于每个需要加载的 search_rank,查找匹配的存储文件 + for search_rank, param_slice in all_offset.items(): + # 根据 search_rank 查找对应的存储文件名 + file_name = self._find_file_for_rank( + src_name, search_rank, param_storage_infos + ) + + if file_name not in files_to_load: + files_to_load[file_name] = [] + files_to_load[file_name].append( + (src_name, search_rank, param_slice) + ) + + return files_to_load + + def _get_param_storage_infos( + self, + param_name: str + ) -> Dict[int, Dict]: + """ + 获取指定参数的所有切片的存储信息。 + + Args: + param_name: 源参数名 + + Returns: + {storage_rank_id: {"file_name": "xxx.safetensors", "rank_group": [...]}, ...} + + 实现逻辑: + 1. 从 src_metas 获取该参数的 ShardedTensor 列表 + 2. 遍历每个 ShardedTensor,获取其 global_offset + 3. 调用 get_sharded_tensor_shard_id() 生成 mapping key + 4. 通过 key 从 param_file_mappings 查找对应的存储信息 + 5. 
以 storage_rank 为 key 重新组织返回结果 + + 这样实现的优势: + - 直接利用 src_metas 中已有的切片信息,无需遍历整个 param_file_mappings + - 使用 get_sharded_tensor_shard_id 确保 key 格式与保存时一致 + - 以 storage_rank 为 key,便于后续根据 search_rank 快速查找 + """ + result: Dict[int, Dict] = {} + + # Step 1: 从 src_metas 获取该参数的 ShardedTensor 列表 + if param_name not in self.src_metas: + raise ValueError(f"Parameter '{param_name}' not found in src_metas") + + src_tensor_list = self.src_metas[param_name] + if not isinstance(src_tensor_list, list): + src_tensor_list = [src_tensor_list] + + # Step 2: 遍历每个 ShardedTensor,获取存储信息 + for sharded_tensor in src_tensor_list: + # Step 3: 生成 mapping key + mapping_key = get_sharded_tensor_shard_id(param_name, sharded_tensor.global_offset) + + # Step 4: 从 param_file_mappings 查找存储信息 + if mapping_key not in self.param_file_mappings: + raise ValueError( + f"Storage info not found for param '{param_name}' " + f"with key={mapping_key}. The source checkpoint may be incomplete." + ) + + # Step 5: 以 storage_rank 为 key 组织结果 + for storage_info in self.param_file_mappings[mapping_key]: + storage_rank = storage_info.get("storage_rank") + result[storage_rank] = { + "file_name": storage_info["file_name"], + "rank_group": storage_info.get("rank_group", []) + } + + return result + + def _find_file_for_rank( + self, + param_name: str, + search_rank: int, + param_storage_infos: Dict[int, Dict] + ) -> str: + """ + 根据 search_rank 查找对应的存储文件名。 + + Args: + param_name: 参数名(用于错误信息) + search_rank: 需要查找的 rank id + param_storage_infos: 参数的存储信息 + 格式:{storage_rank: {"file_name": ..., "rank_group": [...]}, ...} + + Returns: + str: 存储该切片的文件名 + + Raises: + ValueError: 如果找不到对应的存储文件,说明源权重不完整 + + 匹配规则: + 1. 优先直接匹配:search_rank == storage_rank + 2. 去冗余场景:search_rank in rank_group + """ + # 规则 1: 直接匹配 + if search_rank in param_storage_infos: + return param_storage_infos[search_rank]["file_name"] + + # 规则 2: 检查 rank_group(去冗余保存场景) + for storage_rank, info in param_storage_infos.items(): + rank_group = info.get("rank_group", []) + if search_rank in rank_group: + return info["file_name"] + + # 未找到,报错 + raise ValueError( + f"Cannot find storage file for parameter '{param_name}' " + f"at search_rank={search_rank}. " + f"Available storage_ranks: {list(param_storage_infos.keys())}. " + f"The source checkpoint may be incomplete or corrupted." 
+ ) + + def _load_and_slice(self, files_to_load: Dict) -> Tuple[Dict, Dict]: + """ + Step 3: 懒加载并切片。 + + 使用 ms_load_checkpoint 懒加载权重文件,然后执行切片。 + + 关键设计: + - 对于自训权重:切片后直接得到目标参数 + - 对于 HF 权重:分别切片 q、k、v,后续由 Template.convert() 拼接 + + Returns: + (params_info_need_reshard, src_sliced_tensors) + - params_info_need_reshard: 需要 reshard 的参数信息 + - src_sliced_tensors: {src_name: sliced_tensor} 切片后的源参数 + """ + src_sliced_tensors: Dict[str, Any] = {} # 存储切片后的源参数 + params_info_need_reshard: Dict[str, Dict] = {} + + for file_name, param_infos in files_to_load.items(): + file_path = os.path.join(self.checkpoint_dir, file_name) + + # 收集需要加载的源参数名(从权重文件读取时使用源参数名) + src_names = list(set(info[0] for info in param_infos)) # src_name 在 tuple 的第 1 位 + + # 懒加载 + state_dict_from_file = ms_load_checkpoint( + file_path, + format='safetensors', + choice_func=lambda x: x in src_names + ) + + # 切片处理 + for src_name, search_rank, param_slice in param_infos: + if src_name not in state_dict_from_file: + continue + + parameter = state_dict_from_file[src_name] + src_info = self.params_info[src_name] + reshard_handler = src_info["reshard_handler"] + all_offset = src_info["all_offset"] + load_from_multi_rank = len(all_offset) > 1 + + sliced_tensor, is_full_slice = smart_slice( + parameter, param_slice, load_from_multi_rank + ) + + if is_full_slice and not load_from_multi_rank: + # 无需 reshard,直接保存切片结果 + src_sliced_tensors[src_name] = sliced_tensor + else: + # 需要 reshard,记录到 params_info_need_reshard + if src_name not in params_info_need_reshard: + params_info_need_reshard[src_name] = { + "reshard_handler": reshard_handler, + "tensor_map": {} + } + params_info_need_reshard[src_name]["tensor_map"][search_rank] = sliced_tensor + + return params_info_need_reshard, src_sliced_tensors + + def _parallel_reshard( + self, + params_info_need_reshard: Dict, + src_sliced_tensors: Dict + ) -> Dict[str, Any]: + """ + Step 4: 并行拼接需要 reshard 的参数。 + + 将各个 rank 的切片拼接成完整的源参数。 + + Returns: + 更新后的 src_sliced_tensors: {src_name: reshard 后的完整参数} + """ + if not params_info_need_reshard: + return src_sliced_tensors + + # 准备 worker 任务 + tasks = [] + for src_name, info in params_info_need_reshard.items(): + tensor_map = info["tensor_map"] + size = sum(np.prod(t.shape) for t in tensor_map.values()) + tasks.append({ + "src_name": src_name, + "reshard_handler": info["reshard_handler"], + "tensor_map": tensor_map, + "size": size + }) + + # 负载均衡分配 + worker_groups = balance_load(tasks, self.num_workers) + + def process_group(group): + """处理一组参数""" + results = {} + for task in group: + real_tensor = task["reshard_handler"].get_real_tensor(task["tensor_map"]) + real_tensor = Parameter(real_tensor, name=task["src_name"], requires_grad=False) + results[task["src_name"]] = real_tensor + return results + + # 多线程执行 + if self.num_workers > 1: + with ThreadPoolExecutor(max_workers=self.num_workers) as executor: + futures = [executor.submit(process_group, group) for group in worker_groups] + for future in futures: + src_sliced_tensors.update(future.result()) + else: + for group in worker_groups: + src_sliced_tensors.update(process_group(group)) + + return src_sliced_tensors + + def load(self) -> Dict[str, Parameter]: + """ + 执行 Reshard 加载,返回 Reshard 后的参数字典。 + + Returns: + {param_name: Parameter} 格式的参数字典 + + 返回值说明: + - 自训权重场景:param_name 与 dst_sharded_tensor_metas 中的 key 一致 + - HF 权重场景:param_name 是 HF 原始参数名(后续由 Template.convert() 转换) + + Note: + 对于 HF 权重,返回的是 Reshard 后但未执行 Convert(QKV 拼接等)的结果。 + 需要后续调用 template.convert() 完成最终转换,得到 MF 参数名的字典。 + """ + 
logger.info("ReshardLoader: Starting load...") + + # Step 2: 组织文件加载信息 + start_time = time() + files_to_load = self._organize_file_load_info() + + # Step 3: 懒加载并切片 + params_info_need_reshard, src_sliced_tensors = self._load_and_slice(files_to_load) + self.build_all_tensor_map_time = time() - start_time + + # Step 4: 并行拼接需要 reshard 的参数 + start_time = time() + src_sliced_tensors = self._parallel_reshard(params_info_need_reshard, src_sliced_tensors) + self.apply_parallel_load_strategy_time = time() - start_time + + # Step 5: 构建返回结果 + # 对于自训权重:直接按 MF 参数名组织 + # 对于 HF 权重:返回源参数名,由 Template.convert() 转换 + if self.template is None: + # 自训权重:src_name == mf_param_name + state_dict = {} + start_time = time() + for src_name, tensor in src_sliced_tensors.items(): + if not isinstance(tensor, Parameter): + logger.info(f"{src_name}, type: {type(tensor)}") + tensor = Parameter(tensor, name=src_name, requires_grad=False) + state_dict[src_name] = tensor + self.convert_parameter_time = time() - start_time + else: + # HF 权重:返回 {src_name: tensor},由调用者执行 template.convert() + state_dict = src_sliced_tensors + + logger.info(f"ReshardLoader: Loaded {len(state_dict)} parameters") + logger.info(f"build_all_offsets_time: {round(self.build_all_offsets_time, 6)}s") + logger.info(f"build_all_tensor_map_time: {round(self.build_all_tensor_map_time, 6)}s") + logger.info(f"apply_parallel_load_strategy_time: {round(self.apply_parallel_load_strategy_time, 6)}s") + # logger.info(f"convert_parameter_time: {round(self.convert_parameter_time, 6)}s") + return state_dict + + def info(self, dict_info, name): + logger.info(f"============={name}=============") + for k, v in dict_info.items(): + logger.info(f"{k}: {v}") diff --git a/mindformers/checkpoint/sharded_tensor.py b/mindformers/checkpoint/sharded_tensor.py index d809d0dcc7df66af3a5872ac1aa22c18ca56bd37..ba16854a77e458891a9f8f064a4e6b41bb4a0569 100644 --- a/mindformers/checkpoint/sharded_tensor.py +++ b/mindformers/checkpoint/sharded_tensor.py @@ -24,6 +24,7 @@ from mindspore.parallel.strategy import get_current_strategy_metadata, get_strat from mindformers.tools.utils import get_real_rank, get_real_group_size from mindformers.tools.logger import logger +from mindformers.checkpoint.layout import LayoutAdapter ReplicaId = Union[int, Tuple[int, ...]] @@ -82,16 +83,20 @@ class ShardedTensor: def build_sharded_tensor( - param_name: str, param_dtype: ms.dtype, local_shape: Tuple[int, ...], global_shape: Tuple[int, ...], - axis_fragmentations: Tuple[int, ...], global_offset: Tuple[int, ...], replica_id: ReplicaId = 0, - allow_shape_mismatch: bool = False, allow_to_save: bool = True, layout: Optional[ms.Layout] = None + param_name: str, param_dtype = None, local_shape: Optional[Tuple[int]] = None, + global_shape: Optional[Tuple[int]] = None, axis_fragmentations: Optional[Tuple[int]] = None, + global_offset: Tuple[int] = (0,), replica_id: ReplicaId = 0, allow_shape_mismatch: bool = False, + allow_to_save: bool = True, layout: Optional[ms.Layout] = None ) -> ShardedTensor: """Creates and returns a ShardedTensor instance with the specified parameters.""" return ShardedTensor( - key=param_name, org_key=param_name, dtype=param_dtype, local_shape=tuple(local_shape), - global_shape=tuple(global_shape), global_offset=tuple(global_offset), - axis_fragmentations=tuple(axis_fragmentations), replica_id=replica_id, - allow_shape_mismatch=allow_shape_mismatch, allow_to_save=allow_to_save, layout=layout + key=param_name, org_key=param_name, dtype=param_dtype, + local_shape=tuple(local_shape) if 
local_shape else local_shape,
+        global_shape=tuple(global_shape) if global_shape else global_shape,
+        global_offset=tuple(global_offset) if global_offset else global_offset,
+        axis_fragmentations=tuple(axis_fragmentations) if axis_fragmentations else axis_fragmentations,
+        replica_id=replica_id, allow_shape_mismatch=allow_shape_mismatch, allow_to_save=allow_to_save,
+        layout=layout
     )
 
 
@@ -431,7 +436,9 @@ def get_all_sharded_tensor(
         associated with the network.
     """
     logger.info(".........Get All Ranks' Strategy Metadata.........")
-    global_strategy_info = get_strategy_metadata(network)
+
+    # Collect the layouts of all ranks through LayoutAdapter instead of get_strategy_metadata.
+    global_strategy_info = LayoutAdapter.get_all_layout(network)
     if not global_strategy_info:
         raise RuntimeError('`get_strategy_metadata` returns `None`, which indicates there is no strategy info. '
                            'Please check whether this is a distributed job.')
@@ -449,6 +456,9 @@ def get_all_sharded_tensor(
         )
 
         sharded_tensor_metas[cur_npu_rank] = cur_rank_sharded_tensors
+
+    # Sort by rank id so the returned mapping has a deterministic order.
+    sharded_tensor_metas = {k: sharded_tensor_metas[k] for k in sorted(sharded_tensor_metas)}
+
     return sharded_tensor_metas
 
 
@@ -470,10 +480,20 @@ def get_cur_sharded_tensor(
         instances.
     """
     logger.info(".........Get Current Strategy Metadata.........")
-    strategy_info = get_current_strategy_metadata(network)[0]
-    # Get sharded tensors from strategy metadata
+    # Collect the current rank's layout through LayoutAdapter instead of get_current_strategy_metadata.
+    strategy_info = LayoutAdapter.get_current_layout(network)
+
+    if not strategy_info:
+        raise RuntimeError('`get_current_layout` returns `None`, which indicates there is no strategy info. '
+                           'Please check whether this is a distributed job.')
+
+    if len(strategy_info) != 1:
+        raise RuntimeError(f'`get_current_layout` is expected to return a dictionary with one element, '
+                           f'but got {len(strategy_info)} elements.')
+
+    rank_id = int(list(strategy_info.keys())[0])
     cur_rank_sharded_tensors = get_sharded_tensor_from_strategy_metadata(
-        param_infos=strategy_info, cur_npu_rank=get_real_rank(), filter_func=filter_func
+        param_infos=strategy_info[rank_id], cur_npu_rank=get_real_rank(), filter_func=filter_func
     )
 
     return cur_rank_sharded_tensors
diff --git a/mindformers/checkpoint/utils.py b/mindformers/checkpoint/utils.py
index d0c6c0ae21b04365fade733b18ce8acbf6abcba3..e0be9d0c321888ce1d9b211e5bafefe09a12d37e 100644
--- a/mindformers/checkpoint/utils.py
+++ b/mindformers/checkpoint/utils.py
@@ -22,7 +22,7 @@ import shutil
 from enum import Enum
 from glob import glob
 from pathlib import Path
-from typing import Optional
+from typing import Optional, List
 
 import mindspore as ms
 from mindspore import context
@@ -454,3 +454,104 @@ def compile_model(model, dataset, mode, sink_mode, epoch=1, sink_size=1, do_eval
     build_time_end = time.time()
     build_duration = build_time_end - build_time_start
     logger.info(f"Time spent compiling the model: {build_duration:.2f} seconds")
+
+
+# pylint: disable=W0212
+def get_core_network(network):
+    """Get the core network that has `convert_name` method."""
+    if hasattr(network, '_backbone'):
+        return get_core_network(network._backbone)
+    if hasattr(network, 'network'):
+        return get_core_network(network.network)
+    return network
+
+
+def detect_model_type(pretrained_model_dir: str) -> str:
+    """
+    Detect the model type from config.json.
+
+    Args:
+        pretrained_model_dir: HuggingFace checkpoint directory.
+
+    Returns:
+        Model type string, e.g. "qwen3" or "llama".
+    """
+    config_file = os.path.join(pretrained_model_dir, "config.json")
+
+    if not os.path.exists(config_file):
+        raise ValueError(f"config.json not found 
in {pretrained_model_dir}")
+
+    with open(config_file, 'r', encoding='utf-8') as f:
+        config = json.load(f)
+
+    model_type = config.get("model_type")
+    if not model_type:
+        raise ValueError(f"model_type not found in {config_file}")
+
+    return model_type
+
+
+def is_hf_checkpoint(checkpoint_path: str) -> bool:
+    """
+    Check whether a checkpoint path holds HuggingFace-format weights.
+
+    The path is treated as a HuggingFace checkpoint if it contains:
+    - a model.safetensors file, or
+    - a model.safetensors.index.json file.
+
+    Args:
+        checkpoint_path: Checkpoint directory path.
+
+    Returns:
+        bool: True if the directory holds HuggingFace-format weights.
+    """
+    if not checkpoint_path:
+        raise ValueError("checkpoint_path cannot be empty")
+
+    if not os.path.exists(checkpoint_path):
+        raise ValueError(f"checkpoint_path does not exist: {checkpoint_path}")
+
+    if not os.path.isdir(checkpoint_path):
+        raise ValueError(f"checkpoint_path is not a directory: {checkpoint_path}")
+
+
+    # Check for model.safetensors or model.safetensors.index.json.
+    has_model_safetensors = os.path.exists(os.path.join(checkpoint_path, "model.safetensors"))
+    has_index = os.path.exists(os.path.join(checkpoint_path, "model.safetensors.index.json"))
+
+    return has_model_safetensors or has_index
+
+
+def get_needed_hf_files(checkpoint_dir: str) -> List[str]:
+    """
+    Collect the safetensors files to load from a HuggingFace checkpoint directory.
+
+    Processing:
+    1. Check whether model.safetensors.index.json exists:
+       - if it does, resolve the file list from index.json;
+       - otherwise, glob all safetensors files in the directory.
+    2. Return the full file paths.
+
+    Args:
+        checkpoint_dir: HuggingFace checkpoint directory.
+
+    Returns:
+        List[str]: Full paths of the safetensors files.
+    """
+    index_file = os.path.join(checkpoint_dir, "model.safetensors.index.json")
+
+    if os.path.exists(index_file):
+        # Resolve the file list from index.json.
+        with open(index_file, 'r', encoding='utf-8') as f:
+            index_data = json.load(f)
+
+        # weight_map format: {"param_name": "model-00001-of-00002.safetensors"}
+        weight_map = index_data.get("weight_map", {})
+        file_names = set(weight_map.values())
+
+        return [os.path.join(checkpoint_dir, fn) for fn in file_names]
+    else:
+        # Fall back to globbing the safetensors files directly,
+        # e.g. the single-file model.safetensors layout.
+        pattern = os.path.join(checkpoint_dir, "*.safetensors")
+        return glob(pattern)
diff --git a/mindformers/models/qwen3/modeling_qwen3_train.py b/mindformers/models/qwen3/modeling_qwen3_train.py
index 1257ad50f277ea79884a13d6323e762be01ef0e3..475347f310d631d21dcdfc466c1b2aacc963ee1d 100644
--- a/mindformers/models/qwen3/modeling_qwen3_train.py
+++ b/mindformers/models/qwen3/modeling_qwen3_train.py
@@ -20,13 +20,14 @@ from mindformers.parallel_core.transformer_config import TransformerConfig
 from mindformers.parallel_core.training_graph.base_models.gpt.gpt_model import GPTModel
 from mindformers.parallel_core.training_graph.base_models.gpt.gpt_layer_specs import get_gpt_layer_local_spec
 from mindformers.parallel_core.utils.model_mixin import TrainModelMixin
+from mindformers.checkpoint.converter.template import register_template
 from mindformers.models.qwen3.utils import Qwen3PreTrainedModel
 
 from .configuration_qwen3 import Qwen3Config
 
 class TrainingQwen3ForCausalLM(TrainModelMixin, Qwen3PreTrainedModel):
     """
-    Provide qwen2 model infer through network.
+    Provide qwen3 model training through network.
 
     Args:
         config (Qwen3Config): The config of qwen3 model.
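
A minimal usage sketch for the HuggingFace-checkpoint helpers added to mindformers/checkpoint/utils.py above. The checkpoint directory path is illustrative only, and the snippet assumes this patch is applied:

    import os
    from mindformers.checkpoint.utils import detect_model_type, get_needed_hf_files, is_hf_checkpoint

    ckpt_dir = "/path/to/Qwen3-8B"  # hypothetical HuggingFace checkpoint directory

    if is_hf_checkpoint(ckpt_dir):
        # HuggingFace layout: read the model type from config.json and collect
        # every safetensors shard that needs to be loaded.
        model_type = detect_model_type(ckpt_dir)   # e.g. "qwen3"
        shard_files = get_needed_hf_files(ckpt_dir)
        print(model_type, sorted(os.path.basename(f) for f in shard_files))
    else:
        # Anything else is treated as a MindFormers-format checkpoint directory.
        print(f"{ckpt_dir} is not a HuggingFace checkpoint")
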
@@ -36,6 +37,7 @@ class TrainingQwen3ForCausalLM(TrainModelMixin, Qwen3PreTrainedModel): """ + @register_template def __init__(self, config: Qwen3Config): super().__init__(config, auto_prefix=False) config: TransformerConfig = self.convert_to_transformer_config(self.config) @@ -54,6 +56,7 @@ class TrainingQwen3ForCausalLM(TrainModelMixin, Qwen3PreTrainedModel): share_embeddings_and_output_weights=self.config.tie_word_embeddings, post_process=self.config.post_process ) + # 装饰器在 __init__ 执行完成后自动设置 self.template def construct( self, diff --git a/mindformers/models/qwen3/utils.py b/mindformers/models/qwen3/utils.py index 0843eecacbd0ef83a99e46a4e1484e09c621d6a1..850f107809d2ef27b71a7d208910f2667e8ce3dd 100644 --- a/mindformers/models/qwen3/utils.py +++ b/mindformers/models/qwen3/utils.py @@ -16,6 +16,11 @@ from mindformers.models.qwen3.configuration_qwen3 import Qwen3Config from mindformers.models.modeling_utils import PreTrainedModel from mindformers.parallel_core.utils.model_mixin import ModelMixin +from mindformers.checkpoint.converter.convert_op import ( + RenameConverOp, + ConcatConverOp, + QKVConverOp, +) class Qwen3PreTrainedModel(PreTrainedModel, ModelMixin): @@ -27,6 +32,61 @@ class Qwen3PreTrainedModel(PreTrainedModel, ModelMixin): config_class = Qwen3Config base_model_prefix = "Qwen3" + # HF 权重转换配置(使用 {} 作为层序号占位符) + hf_invalid_keys = [] + + # 权重转换器定义(使用完整路径 + {} 占位符) + weight_converters = [ + # ========== Embedding and Output ========== + RenameConverOp(hf_names="model.embed_tokens.weight", mf_names="embedding.word_embeddings.weight"), + RenameConverOp(hf_names="lm_head.weight", mf_names="output_layer.weight"), + RenameConverOp(hf_names="model.norm.weight", mf_names="decoder.final_layernorm.weight"), + + # ========== Attention({} 表示层序号)========== + QKVConverOp( + hf_names=[ + "model.layers.{}.self_attn.q_proj.weight", + "model.layers.{}.self_attn.k_proj.weight", + "model.layers.{}.self_attn.v_proj.weight" + ], + mf_names=["decoder.layers.{}.self_attention.linear_qkv.weight"] + ), + RenameConverOp( + hf_names="model.layers.{}.self_attn.o_proj.weight", + mf_names="decoder.layers.{}.self_attention.linear_proj.weight" + ), + RenameConverOp( + hf_names="model.layers.{}.input_layernorm.weight", + mf_names="decoder.layers.{}.input_layernorm.weight" + ), + RenameConverOp( + hf_names="model.layers.{}.self_attn.k_norm.weight", + mf_names="decoder.layers.{}.self_attention.k_layernorm.weight" + ), + RenameConverOp( + hf_names="model.layers.{}.self_attn.q_norm.weight", + mf_names="decoder.layers.{}.self_attention.q_layernorm.weight" + ), + + # ========== FFN ========== + ConcatConverOp( + hf_names=[ + "model.layers.{}.mlp.gate_proj.weight", + "model.layers.{}.mlp.up_proj.weight" + ], + mf_names=["decoder.layers.{}.mlp.linear_fc1.weight"], + dim=0 + ), + RenameConverOp( + hf_names="model.layers.{}.mlp.down_proj.weight", + mf_names="decoder.layers.{}.mlp.linear_fc2.weight" + ), + RenameConverOp( + hf_names="model.layers.{}.post_attention_layernorm.weight", + mf_names="decoder.layers.{}.pre_mlp_layernorm.weight" + ), + ] + weight_mapping = [ ('model.embed_tokens.', 'embedding.word_embeddings.'), ('.self_attn.q_proj.', '.self_attention.linear_q.'), diff --git a/mindformers/models/qwen3_moe/modeling_qwen3_moe_train.py b/mindformers/models/qwen3_moe/modeling_qwen3_moe_train.py index 331cffff8c9716995622e5119d814987f43438ec..44977996c0d0dd0251bba0bc2df3db08c9ceee50 100644 --- a/mindformers/models/qwen3_moe/modeling_qwen3_moe_train.py +++ b/mindformers/models/qwen3_moe/modeling_qwen3_moe_train.py @@ 
-21,6 +21,7 @@ from mindspore import Tensor from mindformers.tools.register.register import MindFormerModuleType, MindFormerRegister from mindformers.tools.logger import logger from mindformers.parallel_core.transformer_config import TransformerConfig +from mindformers.checkpoint.converter.template import register_template from mindformers.models.qwen3_moe.utils import Qwen3MoePreTrainedModel from mindformers.parallel_core.training_graph.base_models.gpt.gpt_model import GPTModel @@ -31,7 +32,7 @@ from mindformers.parallel_core.utils.model_mixin import TrainModelMixin @MindFormerRegister.register(MindFormerModuleType.MODELS) class TrainingQwen3MoeForCausalLM(TrainModelMixin, Qwen3MoePreTrainedModel): """ - Provide qwen3_moe model infer through network. + Provide qwen3_moe model training through network. Args: config (Qwen3MoeConfig): The config of qwen3_moe model. @@ -41,6 +42,7 @@ class TrainingQwen3MoeForCausalLM(TrainModelMixin, Qwen3MoePreTrainedModel): """ + @register_template def __init__(self, config): super().__init__(config, auto_prefix=False) config: TransformerConfig = self.convert_to_transformer_config(self.config) @@ -61,6 +63,7 @@ class TrainingQwen3MoeForCausalLM(TrainModelMixin, Qwen3MoePreTrainedModel): share_embeddings_and_output_weights=self.config.tie_word_embeddings, post_process=self.config.post_process, ) + # 装饰器在 __init__ 执行完成后自动设置 self.template def construct( self, diff --git a/mindformers/models/qwen3_moe/utils.py b/mindformers/models/qwen3_moe/utils.py index 3e121a2cdbaa5e1f22679e51f893ea0d32c9ead0..36699cf963c1281288b3ba548bcda638412921f1 100644 --- a/mindformers/models/qwen3_moe/utils.py +++ b/mindformers/models/qwen3_moe/utils.py @@ -12,10 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================ -"""Qwen3 models' utils.""" +"""Qwen3Moe models' utils.""" from mindformers.models.qwen3_moe.configuration_qwen3_moe import Qwen3MoeConfig from mindformers.models.modeling_utils import PreTrainedModel from mindformers.parallel_core.utils.model_mixin import ModelMixin +from mindformers.checkpoint.converter.convert_op import ( + RenameConverOp, + QKVConverOp, + MoeExpertFc1ConverOp, + MoeExpertFc2ConverOp, +) class Qwen3MoePreTrainedModel(PreTrainedModel, ModelMixin): @@ -27,6 +33,67 @@ class Qwen3MoePreTrainedModel(PreTrainedModel, ModelMixin): config_class = Qwen3MoeConfig base_model_prefix = "Qwen3Moe" + # HF 权重转换配置(使用 {} 作为层序号和专家序号占位符) + hf_invalid_keys = [] + + # 权重转换器定义(使用完整路径 + {} 占位符) + # 第一个 {} 表示层序号,第二个 {} 表示专家序号(在 MOE 相关权重中) + weight_converters = [ + # ========== Embedding and Output ========== + RenameConverOp(hf_names="model.embed_tokens.weight", mf_names="embedding.word_embeddings.weight"), + RenameConverOp(hf_names="lm_head.weight", mf_names="output_layer.weight"), + RenameConverOp(hf_names="model.norm.weight", mf_names="decoder.final_layernorm.weight"), + + # ========== Attention({} 表示层序号)========== + QKVConverOp( + hf_names=[ + "model.layers.{}.self_attn.q_proj.weight", + "model.layers.{}.self_attn.k_proj.weight", + "model.layers.{}.self_attn.v_proj.weight" + ], + mf_names=["decoder.layers.{}.self_attention.linear_qkv.weight"] + ), + RenameConverOp( + hf_names="model.layers.{}.self_attn.o_proj.weight", + mf_names="decoder.layers.{}.self_attention.linear_proj.weight" + ), + RenameConverOp( + hf_names="model.layers.{}.input_layernorm.weight", + mf_names="decoder.layers.{}.input_layernorm.weight" + ), + RenameConverOp( + hf_names="model.layers.{}.self_attn.k_norm.weight", + mf_names="decoder.layers.{}.self_attention.k_layernorm.weight" + ), + RenameConverOp( + hf_names="model.layers.{}.self_attn.q_norm.weight", + mf_names="decoder.layers.{}.self_attention.q_layernorm.weight" + ), + + # ========== MOE Router ========== + RenameConverOp( + hf_names="model.layers.{}.mlp.gate.weight", + mf_names="decoder.layers.{}.mlp.router.weight" + ), + RenameConverOp( + hf_names="model.layers.{}.post_attention_layernorm.weight", + mf_names="decoder.layers.{}.pre_mlp_layernorm.weight" + ), + + # ========== MOE Expert Weights(第一个 {} 是层序号,第二个 {} 是专家序号)========== + MoeExpertFc1ConverOp( + hf_names=[ + "model.layers.{}.mlp.experts.{}.gate_proj.weight", + "model.layers.{}.mlp.experts.{}.up_proj.weight" + ], + mf_names=["decoder.layers.{}.mlp.experts.weight1"] + ), + MoeExpertFc2ConverOp( + hf_names=["model.layers.{}.mlp.experts.{}.down_proj.weight"], + mf_names=["decoder.layers.{}.mlp.experts.weight2"] + ), + ] + weight_mapping = [ ('model.embed_tokens.', 'embedding.word_embeddings.'), ('.self_attn.q_proj.', '.self_attention.linear_q.'), diff --git a/mindformers/parallel_core/utils/model_mixin.py b/mindformers/parallel_core/utils/model_mixin.py index 217fa6e4c29b5cd9dd7186745ee976706f51b4b8..c0828d1022d17a310742291bd3baf6c52bc28558 100644 --- a/mindformers/parallel_core/utils/model_mixin.py +++ b/mindformers/parallel_core/utils/model_mixin.py @@ -13,6 +13,7 @@ # limitations under the License. 
# ============================================================================ """ModelMixin for train models""" +from typing import Iterable, Tuple, Set from mindspore import Tensor import mindspore.common.dtype as mstype @@ -21,6 +22,10 @@ import numpy as np from mindformers.tools.logger import logger from mindformers.core.context.build_context import is_legacy_model from mindformers.parallel_core.transformer_config_utils import convert_to_transformer_config +from mindformers.tools.utils import get_real_group_size, get_real_rank +from mindformers.checkpoint.reshard import ReshardHandler +from mindformers.checkpoint.sharded_tensor import get_cur_sharded_tensor, get_sharded_tensor_from_cell +from mindformers.checkpoint.reshard_loader import smart_slice class ModelMixin: @@ -39,6 +44,119 @@ class ModelMixin: def __init__(self): self.transformer_config = None + def load_weights_with_reshard(self, state_dict: Iterable[Tuple[str, Tensor]]): + r""" + 加载完整权重并进行 Reshard 切分。 + + 该接口用于加载完整权重(未分片的权重),并根据网络的分布式布局自动切分权重。 + + Args: + state_dict: state_dict 的生成器,遍历可获取单层权重的 name 和 full weight。 + 其中 name 和网络的 name 保证保持一致,如 + decoder.layer.0.self_attention.linear_qkv.weight。 + + 处理流程: + 1. 遍历 state_dict,获取单层权重的 name 和 weight + 2. 获取权重名对应的网络参数:param = self.parameters_dict()[name] + 3. 获取该层参数的 dst_sharded_tensor(因为 weight 是完整权重,src_sharded_tensor 为 None) + 4. 对 weight 进行 Reshard 切分,获取目标切片 + 5. 加载权重切片到对应的网络层:param.set_data(sliced_weight) + + Raises: + KeyError: 如果权重名不在网络参数中 + ValueError: 如果权重形状不匹配或 Reshard 失败 + """ + logger.info("..........Start Loading Weights with Reshard..........") + + # 获取当前 rank 的所有参数的 sharded tensor metadata + # 使用 filter_func 过滤出网络参数(排除优化器参数等) + def filter_func(param_name: str) -> bool: + return param_name in self.parameters_dict() + + dst_sharded_tensor_metas = get_cur_sharded_tensor(self, filter_func=filter_func) \ + if get_real_group_size() > 1 else get_sharded_tensor_from_cell(self) + + loaded_params: Set[str] = set() + rank_id = get_real_rank() + + # 遍历 state_dict,加载每个权重 + for name, full_weight in state_dict: + if name not in self.parameters_dict(): + logger.warning(f"Parameter '{name}' not found in network, skipping") + continue + + # Step 1: 获取权重名对应的网络参数 + param = self.parameters_dict()[name] + + # Step 2: 获取该层参数的 dst_sharded_tensor + if name not in dst_sharded_tensor_metas: + logger.warning(f"Sharded tensor metadata not found for '{name}', skipping") + continue + + dst_sharded_tensor = dst_sharded_tensor_metas[name] + + # Step 3: 对 weight 进行 Reshard 切分 + # 创建 ReshardHandler:from_layout=None(完整权重),to_layout=dst_sharded_tensor.layout + try: + # 获取完整权重的 shape + full_weight_shape = full_weight.shape + + # 验证权重形状是否匹配 + if full_weight_shape != dst_sharded_tensor.global_shape: + raise ValueError( + f"Weight shape mismatch for '{name}': " + f"expected {dst_sharded_tensor.global_shape}, got {full_weight_shape}" + ) + + # 创建 ReshardHandler + # from_layout=None 表示完整权重(未分片) + # to_layout=dst_sharded_tensor.layout 表示目标分片布局 + reshard_handler = ReshardHandler( + param_name=name, + full_shape=full_weight_shape, + from_layout=None, # 完整权重,未分片 + to_layout=dst_sharded_tensor.layout, # 目标分片布局 + to_rank_id=rank_id + ) + + # 计算当前 rank 需要的切片范围 + # infer_all_tensor_offset() 会计算 self.to_area(全局切片范围) + all_offset = reshard_handler.infer_all_tensor_offset() + assert 0 in all_offset and len(all_offset) == 1 + slice_ranges = all_offset[0] + + # 使用 smart_slice 切分权重 + sliced_weight, is_full_slice = smart_slice( + full_weight, + slice_ranges, + load_from_multi_rank=False # 完整权重,不需要多 rank 合并 + ) + + if 
is_full_slice: + sliced_weight = full_weight + else: + sliced_weight = Tensor(sliced_weight, dtype=param.data.dtype) + + # Step 4: 加载权重切片到对应的网络层 + param.set_data(sliced_weight) + loaded_params.add(name) + + logger.debug(f"Loaded parameter '{name}' with shape {sliced_weight.shape}") + + except Exception as e: + logger.error(f"Failed to load parameter '{name}': {str(e)}") + raise + + # 记录未加载的参数 + network_not_load = set(self.parameters_dict().keys()) - loaded_params + if network_not_load: + logger.warning(f'These parameters are not loaded in the network: {network_not_load}') + else: + logger.info(f'Successfully loaded {len(loaded_params)} parameters') + + logger.info("..........Loading Weights with Reshard Finished..........") + + def convert_concat_name(self, weight_name): r""" convert HuggingFace weight name to MindFormers weight name. diff --git a/mindformers/trainer/base_trainer.py b/mindformers/trainer/base_trainer.py index 68948aa44a6f3beb9299d29cc2add800e11c0d45..9c28566d71a525a72074294f16f68f7b731b0803 100644 --- a/mindformers/trainer/base_trainer.py +++ b/mindformers/trainer/base_trainer.py @@ -67,8 +67,8 @@ from mindformers.core.callback.callback import ( ) from mindformers.modules.seq_pipe import SequenceSplit from mindformers.utils.load_checkpoint_utils import get_load_path_after_hf_convert -from mindformers.checkpoint.checkpoint import load_checkpoint, CommonInfo -from mindformers.checkpoint.utils import compile_model +from mindformers.checkpoint.checkpoint import load_checkpoint, load_hf_checkpoint, CommonInfo +from mindformers.checkpoint.utils import compile_model, is_hf_checkpoint from mindformers.dataset.dataloader.hf_dataloader import _resume_hf_iterable_dataset from ..core.config_args import ConfigArguments from .training_args import TrainingArguments @@ -1465,6 +1465,11 @@ class BaseTrainer: epoch=config.runner_config.epochs, sink_size=config.runner_config.sink_size) if config.resume_training: + if is_hf_checkpoint(config.load_checkpoint): + raise ValueError( + "Resume training is not supported for HuggingFace checkpoints." + ) + logger.info(".............Start resume training from checkpoint..................") global_step = common_info.global_step if common_info.global_batch_size != self.global_batch_size: @@ -1479,14 +1484,24 @@ class BaseTrainer: network=network, optimizer=optimizer, global_step=global_step, - balanced_load=config.balanced_load + balanced_load=config.balanced_load, + load_worker_number=config.load_worker_number or 1 ) else: - load_checkpoint( - checkpoint=config.load_checkpoint, - network=network, - balanced_load=config.balanced_load - ) + if is_hf_checkpoint(config.load_checkpoint): + load_hf_checkpoint( + pretrained_model_dir=config.load_checkpoint, + network=network, + balanced_load=config.balanced_load, + load_worker_number=config.load_worker_number or 1 + ) + else: + load_checkpoint( + checkpoint=config.load_checkpoint, + network=network, + balanced_load=config.balanced_load, + load_worker_number=config.load_worker_number or 1 + ) elif (config.load_checkpoint or config.only_save_strategy) and not check_is_reboot_node(): if config.resume_training: logger.info(".............Start resume training from checkpoint..................")
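
A minimal, self-contained sketch of the slicing and load-balancing helpers introduced in mindformers/checkpoint/reshard_loader.py above (smart_slice and balance_load), assuming this patch is installed; the array shape and parameter sizes are made up for illustration:

    import numpy as np
    from mindformers.checkpoint.reshard_loader import balance_load, smart_slice

    full = np.arange(24, dtype=np.float32).reshape(4, 6)

    # A slice covering the full shape returns the original array untouched.
    sliced, complete = smart_slice(full, [(0, 4), (0, 6)])
    assert sliced is full and complete

    # A partial slice (rows 0..1, all columns) is cut out directly when the
    # result does not need to be merged from multiple ranks.
    sliced, complete = smart_slice(full, [(0, 2), (0, 6)])
    assert sliced.shape == (2, 6) and complete

    # Greedy load balancing: parameters are sorted by size (descending) and each
    # one is assigned to the currently least-loaded worker group.
    tasks = [{"name": f"p{i}", "size": s} for i, s in enumerate([9, 7, 5, 3, 1])]
    groups = balance_load(tasks, num_groups=2)
    print([sum(t["size"] for t in g) for g in groups])  # [13, 12]
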