diff --git a/mindformers/checkpoint/checkpoint.py b/mindformers/checkpoint/checkpoint.py index 4bdfb6aace0fde630559db6a155afa34c2abadc4..81140fbedbf31e716557e4396a28c5315937aaba 100644 --- a/mindformers/checkpoint/checkpoint.py +++ b/mindformers/checkpoint/checkpoint.py @@ -15,16 +15,16 @@ """load/save checkpoint apis.""" import os import json +import copy import tempfile from time import time from typing import Callable, Union, Dict, Optional, Tuple, Any, List from dataclasses import dataclass +import concurrent.futures import threading from multiprocessing import active_children -from safetensors import safe_open -import mindspore as ms from mindspore import Tensor, Parameter, load_param_into_net from mindspore.common import dtype as mstype from mindspore.nn import Cell @@ -32,6 +32,7 @@ from mindspore.nn.optim.optimizer import Optimizer from mindspore.communication.management import get_rank, get_group_size from mindspore.communication import comm_func from mindspore import save_checkpoint as ms_save_checkpoint +from mindspore import load_checkpoint as ms_load_checkpoint from mindformers.tools.logger import logger from mindformers.checkpoint.reshard import ReshardHandler @@ -48,21 +49,22 @@ from mindformers.checkpoint.utils import ( get_checkpoint_name, get_checkpoint_tracker_filename, get_latest_iteration_from_tracker, + get_sharded_tensor_shard_id, get_common_filename, check_checkpoints_dir_max_num, get_metadata_filename, verify_ckpt_valid, + get_core_network, FileType ) from mindformers.checkpoint.fully_parallel import BalancedSaveStrategy, apply_balance_shard_strategy from mindformers.checkpoint.metadata import ( save_metadata, - load_metadata, - generate_default_metadata_from_checkpoint, get_total_params_file_mapping_info, + get_metadata_of_checkpoint, + params_key_mapping ) from mindformers.checkpoint.sharded_tensor import ( - get_strategy_info_from_sharded_tensor, ShardedTensor, get_sharded_tensor_from_cell, get_cur_sharded_tensor, @@ -441,444 +443,560 @@ def save_metadata_json(sharded_tensor_metas, model_keys, user_prefix, metadata_f logger.info("No need to save metadata.json for single card.") -def load_safetensor( - checkpoint_path: str, - param_name: Optional[Union[str, List[str]]] = None, - index_tuple: Optional[Union[Tuple[Tuple[int, int], ...], List[Tuple[Tuple[int, int], ...]]]] = None, - dtype: Optional[Union[Any, List[Any]]] = mstype.float32 -) -> Dict[str, Parameter]: +def smart_slice(tensor, slice_ranges, load_from_multi_rank=False): """ - Loads tensors from a Safetensors file into MindSpore Parameters. - - This function reads a Safetensors file and converts specified tensors into MindSpore - Parameter objects with the specified data type. It can load either specific tensors - (with optional slicing) or all tensors from the file. + Slices a tensor based on specified slice ranges and determines if it's a full slice. Args: - checkpoint_path: Path to the Safetensors file to load - param_name: Optional name or list of names of specific tensors to load. - If None, loads all tensors from the file. - index_tuple: Optional slicing indices for tensors. Can be a single tuple of (start, end) - index pairs or a list of such tuples. Must match the dimension of the - corresponding tensor if provided. - dtype: Target data type(s) for the loaded tensors. Can be a single type or a list of types. - Defaults to mstype.float32. 
+ tensor: The tensor to slice (can be Parameter, Tensor, or have .shape attribute) + slice_ranges: List of (start, end) tuples specifying slice ranges for each dimension + load_from_multi_rank: If True, forces slicing even for full slices (for multi-rank loading) Returns: - Dictionary mapping tensor names to MindSpore Parameter objects + Tuple[sliced_tensor, is_full_slice]: + - sliced_tensor: The original tensor if full slice and not load_from_multi_rank, + otherwise the sliced numpy array + - is_full_slice: True if the slice covers the entire tensor, False otherwise Raises: - ValueError: If the file doesn't exist, if index_tuple dimension doesn't match - tensor dimension, or if parameter lists have mismatched lengths - KeyError: If a specified parameter name doesn't exist in the file + ValueError: If slice dimension count doesn't match tensor dimension count """ - # Validate file existence - if not os.path.exists(checkpoint_path): - raise FileNotFoundError(f"Safetensors file not found at: {checkpoint_path}") - - # Warn about unused index_tuple when no parameter name is specified - if param_name is None and index_tuple is not None: - logger.warning("index_tuple is ignored when param_name is None (loading all parameters)") - - def _convert_to_list(param: Optional[Union[object, List[object]]]) -> Optional[List[object]]: - """Convert parameter to list if it's not already a list""" - if param is None: - return None - return [param] if not isinstance(param, list) else param - - def _align_list_length(param_list: List[object], target_length: int) -> List[object]: - """Align list length to match target length, supporting broadcasting of single-element lists""" - if len(param_list) == target_length: - return param_list - if len(param_list) == 1: - return param_list * target_length - raise ValueError(f"List length {len(param_list)} cannot be aligned to target length {target_length}") - - # Unify parameters to list format for consistent processing - param_name_list: Optional[List[str]] = _convert_to_list(param_name) - index_tuple_list: Optional[List[Tuple[Tuple[int, int], ...]]] = _convert_to_list(index_tuple) - dtype_list: Optional[List[Any]] = _convert_to_list(dtype) - - # Validate parameter list length consistency - if param_name_list is not None: - # Validate index list length - if index_tuple_list is not None and len(index_tuple_list) != len(param_name_list): - raise ValueError( - f"Length of index_tuple ({len(index_tuple_list)}) must match " - f"length of param_name ({len(param_name_list)})" - ) + # Get tensor shape - handle both Parameter and Tensor types + tensor_shape = tensor.shape - # Validate data type list length - if dtype_list is not None: - dtype_list = _align_list_length(dtype_list, len(param_name_list)) - - weights: Dict[str, Parameter] = {} - - # Load data from Safetensors file - with safe_open(checkpoint_path, framework="np", device="cpu") as f: - if param_name_list: - # Load specified parameters - for idx, param_name_ in enumerate(param_name_list): - # Get data type for current parameter - cur_dtype = dtype_list[idx] if dtype_list and dtype_list[idx] else mstype.float32 - - # Load tensor from file - try: - tensor_np = f.get_tensor(param_name_) - except KeyError as e: - raise KeyError(f"Parameter '{param_name_}' not found in Safetensors file") from e - - # Apply slicing if specified - if index_tuple_list is not None: - index_tuple = index_tuple_list[idx] - if len(index_tuple) != tensor_np.ndim: - raise ValueError( - f"Index tuple dimension ({len(index_tuple)}) does not match " - 
f"parameter '{param_name_}' dimension ({tensor_np.ndim})" - ) - # Create slice objects and apply - slices = tuple(slice(start, end) for start, end in index_tuple) - tensor_np = tensor_np[slices] - - # Convert to MindSpore Parameter - weights[param_name_] = Parameter( - ms.from_numpy(tensor_np).astype(cur_dtype), name=param_name_, requires_grad=False - ) - else: - # Load all parameters - cur_dtype = dtype if not isinstance(dtype, list) else dtype[0] - for key in f.keys(): - tensor_np = f.get_tensor(key) - weights[key] = Parameter( - ms.from_numpy(tensor_np).astype(cur_dtype), name=key, requires_grad=False - ) + if len(slice_ranges) != len(tensor_shape): + raise ValueError( + f"Slice dimension count ({len(slice_ranges)}) does not " + f"match tensor dimension count ({len(tensor_shape)})" + ) - return weights + # Check if this is a full slice + is_full_slice = True + for i, (start, end) in enumerate(slice_ranges): + dim_size = tensor_shape[i] + if start != 0 or end != dim_size: + is_full_slice = False + break + + # If full slice and not loading from multiple ranks, return original tensor + if is_full_slice and not load_from_multi_rank: + return tensor, is_full_slice + + # Perform the slice + slice_indices = tuple(slice(start, end) for start, end in slice_ranges) + if isinstance(tensor, (Tensor, Parameter)): + # MindSpore Tensor/Parameter + sliced_tensor = copy.deepcopy(tensor.asnumpy()[slice_indices]) + else: + # Numpy array or other array-like + sliced_tensor = copy.deepcopy(tensor[slice_indices]) + return sliced_tensor, is_full_slice -def load_tensor_by_offset( - all_offset: Dict[int, Tuple[Tuple[int, int], ...]], - param_name: str, - checkpoint_dir: str, - src_sharded_tensor_metas: Dict[str, List[ShardedTensor]], - param_file_mappings: Dict[str, List[Dict[str, Any]]], - key_mapping: Dict[str, str], -) -> Dict[int, Parameter]: + +def build_tensors_reformulation_all_offsets(params_with_sharded_tensor): + """ + Builds parameter information with all tensor offsets for sharded tensor resharding. + + Args: + params_with_sharded_tensor (dict): + A dictionary mapping parameter names to tuples of sharded tensor objects. + Each tuple has the format (src_sharded_tensor_or_list, dst_sharded_tensor): + - src_sharded_tensor_or_list: Source sharded tensor object(s). Can be a single + ShardedTensor or a list of ShardedTensors (for concat scenarios) + - dst_sharded_tensor: Destination sharded tensor object with target layout info + + Returns: + params_info (dict): A nested dictionary containing all offset information for each parameter. + Structure: { + param_name: { + "reshard_handler": ReshardHandler instance, + "all_offset": Dict mapping ranks to tensor slice ranges (start/end tuples) + } + } """ - Loads specific tensor slices from checkpoint files based on offset information. + rank_id = get_real_rank() + params_info = {} + for param_name, (src_sharded_tensor, dst_sharded_tensor) in params_with_sharded_tensor.items(): + # Check if this is a concat scenario (src is a list) + if isinstance(src_sharded_tensor, list): + src_sharded_tensor = src_sharded_tensor[0] + + reshard_handler = ReshardHandler( + param_name, + dst_sharded_tensor.global_shape, + src_sharded_tensor.layout, + dst_sharded_tensor.layout, + rank_id + ) + all_offset = reshard_handler.infer_all_tensor_offset() + params_info[param_name] = {"all_offset": all_offset, "reshard_handler": reshard_handler} + return params_info - Retrieves the appropriate segments of a tensor from checkpoint files according to - the provided offset information. 
Handles storage rank mapping and potential resharding - to ensure the correct tensor slices are loaded for each rank. + +def build_tensors_reformulation_mappings( + checkpoint_dir, + params_info, + src_sharded_tensor_metas, + param_file_mappings, + key_mapping: Dict[str, str] = None +): + """Builds tensor resharding mappings and splits parameters into reshard-required and no-reshard groups. + + This function orchestrates the process of: + 1. Resolving parameter storage information (file paths, rank groups) from checkpoint metadata + 2. Mapping search ranks to storage ranks and adjusting tensor offset information + 3. Organizing checkpoint file load information by file name + 4. Loading checkpoint files, slicing tensors according to computed offsets, and classifying parameters + into those that need resharding (params_info_need_reshard) and those that don't (state_dict_no_reshard) Args: - all_offset: Dictionary mapping ranks to their respective tensor slice offsets - param_name: Name of the parameter/tensor to load - checkpoint_dir: Directory containing the checkpoint files - src_sharded_tensor_metas: Metadata for source sharded tensors - param_file_mappings: Mapping of parameters to their storage files - key_mapping: Mapping of `original key` in checkpoint to `param key` in network. + checkpoint_dir: Root directory containing checkpoint files + params_info: Nested dictionary with resharding metadata for each parameter. + Structure: { + param_name: { + "reshard_handler": ReshardHandler instance, + "all_offset": Dict mapping ranks to tensor slice ranges (start/end tuples) + } + } + src_sharded_tensor_metas: Dictionary mapping parameter names to a list of source ShardedTensor metadata + (key: parameter name, value: list of ShardedTensor instances for source) + param_file_mappings: Mapping from shard IDs to storage information lists. + Each storage info dict has "file_name", "storage_rank", "rank_group" keys. + key_mapping: Optional dictionary for parameter name remapping (used when original + name not found in src_sharded_tensor_metas). Defaults to None. Returns: - Dictionary mapping ranks to their corresponding loaded Parameter objects + Tuple containing two dictionaries: + 1. params_info_need_reshard: Nested dictionary containing resharding metadata for parameters needing resharding. + Structure: { + param_name: { + "reshard_handler": ReshardHandler instance (handles tensor resharding logic), + "tensor_map": Dict mapping source ranks to sliced numpy tensor fragments + } + } + 2. state_dict_no_reshard: Parameters that don't need resharding (full tensor slices), + keyed by parameter name with Parameter objects as values. + + Raises: + ValueError: If parameter name not found in src_sharded_tensor_metas and no key_mapping provided, + or if slice dimension count doesn't match tensor dimension count in _smart_slice. + RuntimeError: If no matching storage rank found for a search rank in offset mapping. 
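To make the `params_info` / `all_offset` layout described above concrete, here is a minimal sketch of the data flow using plain numpy. The parameter name, slice ranges and tensor values are invented, and `ReshardHandler` is replaced by a placeholder, since only the dictionary shapes from the docstrings matter here.

```python
# Minimal sketch of the data flow described above, with numpy stand-ins.
# all_offset maps each source rank to (start, end) ranges per dimension,
# as ReshardHandler.infer_all_tensor_offset() would produce for this rank.
import numpy as np

params_info = {
    "decoder.layers.0.linear_fc1.weight": {   # hypothetical parameter name
        "reshard_handler": None,              # placeholder; the real object merges slices later
        "all_offset": {0: ((0, 2), (0, 4)), 1: ((2, 4), (0, 4))},
    }
}

full_tensor = np.arange(16, dtype=np.float32).reshape(4, 4)

for name, info in params_info.items():
    all_offset = info["all_offset"]
    load_from_multi_rank = len(all_offset) > 1  # same rule as the loader above
    for rank, slice_ranges in all_offset.items():
        # A slice is "full" only if it covers every dimension end to end.
        is_full = all(s == 0 and e == d for (s, e), d in zip(slice_ranges, full_tensor.shape))
        fragment = full_tensor[tuple(slice(s, e) for s, e in slice_ranges)]
        status = "no reshard" if is_full and not load_from_multi_rank else "needs reshard"
        print(name, rank, fragment.shape, status)
```

Parameters whose slice is full and comes from a single rank go straight into the state dict; everything else is collected into a per-rank tensor map for later merging.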
""" - def _get_storage_info_of_sharded_tensor( - sharded_tensor: ShardedTensor, - param_file_mappings: Dict[str, List[Dict[str, Any]]] - ) -> List[Dict[str, Any]]: - """Retrieves storage information for a specific sharded tensor.""" - param_key = str((sharded_tensor.org_key, sharded_tensor.global_offset)) - return param_file_mappings[param_key] - - def _get_storage_rank_dict_of_param( - sharded_tensor_metas: Dict[str, List[ShardedTensor]], - param_file_mappings: Dict[str, List[Dict[str, Any]]], - param_name: str - ) -> Dict[int, Tuple[str, Any]]: - """Creates a dictionary mapping storage ranks to their file and dtype information.""" - storage_rank_dict: Dict[int, Tuple[str, Any]] = {} + def _get_param_storage_info(param_name, sharded_tensor_metas, param_file_mappings, key_mapping): + """Retrieves storage information (file name, rank group) for a parameter by rank.""" + param_storage_info = {} + lookup_name = param_name if param_name not in sharded_tensor_metas: - param_name = key_mapping[param_name] - - for sharded_tensor in sharded_tensor_metas[param_name]: - storage_info_list = _get_storage_info_of_sharded_tensor(sharded_tensor, param_file_mappings) + if not key_mapping: + raise ValueError(f"param name '{param_name}' not found in src sharded tensor metas.") + lookup_name = key_mapping.get(param_name, param_name) + if lookup_name not in sharded_tensor_metas: + raise ValueError( + f"param name '{param_name}' (mapped to '{lookup_name}') not found in src sharded tensor metas." + ) + sharded_tensors = sharded_tensor_metas[lookup_name] + for sharded_tensor in sharded_tensors: + shard_id = get_sharded_tensor_shard_id( + sharded_tensor.org_key, sharded_tensor.global_offset) + storage_info_list = param_file_mappings[shard_id] for storage_info in storage_info_list: + file_name = storage_info["file_name"] storage_rank = storage_info["storage_rank"] - storage_rank_dict[storage_rank] = (storage_info["file_name"], sharded_tensor.dtype) - return storage_rank_dict - - # Get storage rank information for the parameter - storage_rank_dict = _get_storage_rank_dict_of_param( - src_sharded_tensor_metas, param_file_mappings, param_name) - - # Map storage ranks to source ranks and adjust offsets - storage_to_src_rank_mapping: Dict[int, int] = {} - for search_rank in list(all_offset.keys()): # Iterate over copy of keys to allow modification - if search_rank not in storage_rank_dict: - # Get first source sharded tensor for this parameter - src_sharded_tensor = next(iter(src_sharded_tensor_metas[param_name])) - - find_storage_rank = False - # Find matching storage rank using reshard handler - for storage_rank in storage_rank_dict: - reshard_handler = ReshardHandler( - param_name=param_name, - full_shape=src_sharded_tensor.global_shape, - from_layout=src_sharded_tensor.layout, - to_layout=src_sharded_tensor.layout, # No actual layout change - to_rank_id=storage_rank - ) + rank_group = storage_info["rank_group"] + param_storage_info[storage_rank] = {"file_name": file_name, "rank_group": rank_group} + return param_storage_info + + def _collect_files_load_info(params_info, sharded_tensor_metas, param_file_mappings, key_mapping): + """Collects and organizes checkpoint file load information for sharded tensor resharding.""" + files_load_info: Dict[str, Dict] = {} + + for param_name, reshard_info in params_info.items(): + reshard_handler = reshard_info["reshard_handler"] + all_offset = reshard_info.get("all_offset", reshard_handler.infer_all_tensor_offset()) + + param_storage_info = _get_param_storage_info( + param_name, 
sharded_tensor_metas, param_file_mappings, key_mapping) + + # Map storage ranks to source ranks and adjust offsets + storage_to_search_rank_mapping: Dict[int, int] = {} + for search_rank in list(all_offset.keys()): + if search_rank not in param_storage_info: + find_storage_rank = False + for storage_rank, storage_info in param_storage_info.items(): + rank_group = storage_info["rank_group"] + if search_rank in rank_group: + # Update offset mapping and record rank correspondence + all_offset[storage_rank] = all_offset.pop(search_rank) + storage_to_search_rank_mapping[storage_rank] = search_rank + find_storage_rank = True + if not find_storage_rank: + raise RuntimeError("Failed to find matching storage rank.\n" + f"param: {param_name}\nall_offset: {all_offset}\n" + f"param_storage_info: {param_storage_info}") + else: + storage_to_search_rank_mapping[search_rank] = search_rank + + # Organize load info by storage file name + for storage_rank, param_slice in all_offset.items(): + param_file_name = param_storage_info[storage_rank]["file_name"] + search_rank = storage_to_search_rank_mapping[storage_rank] + + if param_file_name not in files_load_info: + files_load_info[param_file_name] = { + "param_name_list": [param_name], + "param_slice_list": [param_slice], + "search_rank_list": [search_rank] + } + else: + files_load_info[param_file_name]["param_name_list"].append(param_name) + files_load_info[param_file_name]["param_slice_list"].append(param_slice) + files_load_info[param_file_name]["search_rank_list"].append(search_rank) + + return files_load_info + + files_load_info = _collect_files_load_info( + params_info=params_info, + sharded_tensor_metas=src_sharded_tensor_metas, + param_file_mappings=param_file_mappings, + key_mapping=key_mapping + ) - # Get source rank from reshard handler - src_rank = next(iter(reshard_handler.infer_all_tensor_offset().keys())) + state_dict_no_reshard: Dict[str, Parameter] = {} + params_info_need_reshard: Dict[str, Dict] = {} - if src_rank == search_rank: - # Update offset mapping and record rank correspondence - all_offset[storage_rank] = all_offset.pop(search_rank) - storage_to_src_rank_mapping[storage_rank] = src_rank - find_storage_rank = True - break + for file_name, load_info in files_load_info.items(): + param_name_list = load_info["param_name_list"] + param_slice_list = load_info["param_slice_list"] + search_rank_list = load_info["search_rank_list"] + + file_path = os.path.join(checkpoint_dir, file_name) + state_dict_from_file = ms_load_checkpoint( + file_path, + format='safetensors', + choice_func=lambda x, lst=param_name_list: x in lst + ) - if not find_storage_rank: - raise RuntimeError("Failed to find matching storage rank for the parameter") + for param_name, param_slice, search_rank in zip(param_name_list, param_slice_list, search_rank_list): + reshard_handler = params_info[param_name]["reshard_handler"] + all_offset = params_info[param_name].get("all_offset", reshard_handler.infer_all_tensor_offset()) + load_from_multi_rank = len(all_offset) > 1 + + # Get parameter from state_dict and slice it + parameter = state_dict_from_file.pop(param_name) + sliced_tensor, is_full_slice = smart_slice(parameter, param_slice, load_from_multi_rank) + + if is_full_slice and not load_from_multi_rank: + # No reshard needed, directly add to state_dict (sliced_tensor is the original parameter) + mapped_name = key_mapping.get(param_name, param_name) if key_mapping else param_name + if not isinstance(sliced_tensor, Parameter): + sliced_tensor = Parameter(sliced_tensor, 
name=mapped_name, requires_grad=False) + state_dict_no_reshard[mapped_name] = sliced_tensor + continue + + if param_name not in params_info_need_reshard: + params_info_need_reshard[param_name] = { + "reshard_handler": reshard_handler, + "tensor_map": {} + } - # Load tensor slices for each rank - from_tensor_map: Dict[int, Parameter] = {} - for search_rank, param_slice in all_offset.items(): - # Get file information and load the specific tensor slice - param_file_name, param_dtype = storage_rank_dict[search_rank] - param_file_path = os.path.join(checkpoint_dir, param_file_name) + # Initialize or update the tensor map for the parameter with the rank-specific slice + params_info_need_reshard[param_name]["tensor_map"][search_rank] = sliced_tensor - # Load the specific slice from the safetensor file - loaded_weights = load_safetensor(param_file_path, param_name, param_slice, param_dtype) + return params_info_need_reshard, state_dict_no_reshard - # Use original source rank if mapping exists - mapped_rank = storage_to_src_rank_mapping.get(search_rank, search_rank) - from_tensor_map[mapped_rank] = loaded_weights[param_name] - return from_tensor_map +def apply_parallel_load_strategy( + params_info: Dict[str, Dict], + state_dict: Dict[str, Parameter], + key_mapping: Dict = None, + num_workers: int = 1 +): + """Applies parallel loading strategy to reshard and merge tensor slices into full parameters. + This function distributes resharding workload across multiple worker threads to optimize + the process of merging sliced tensor fragments (from multiple ranks) into complete parameters. + Key steps include: + 1. Collecting parameter metadata (tensor slices, reshard handlers, size info) + 2. Balancing workload across workers by tensor size (to avoid load imbalance) + 3. Processing parameter groups in parallel to reconstruct full tensors + 4. Updating the state dict with reconstructed parameters -def categorize_params( - dst_sharded_tensor_metas: Dict[str, ShardedTensor], - src_sharded_tensor_metas: Dict[str, List[ShardedTensor]], - param_file_mappings: Dict[str, List[Dict[str, Any]]] -) -> Tuple[List[str], Dict[str, Dict[str, List[Any]]], Dict[str, Dict[str, List[Any]]], Dict[str, List[Any]]]: + Args: + params_info: Nested dictionary containing resharding metadata for parameters needing resharding. + Structure: { + param_name: { + "reshard_handler": ReshardHandler instance (handles tensor resharding logic), + "tensor_map": Dict mapping source ranks to sliced numpy tensor fragments + } + } + state_dict: Target state dictionary to populate with reconstructed parameters. + Will be updated in-place with merged parameter tensors. + key_mapping: Optional dictionary for parameter name remapping (used when original + name not found in src_sharded_tensor_metas). Defaults to None. + num_workers: Maximum number of worker threads to use for parallel processing. + Actual workers used are min(num_workers, number of parameters) to avoid over-threading. + Defaults to 1 (single-threaded). """ - Categorizes parameters based on comparison of source and destination sharding strategies. 
+ total_param_info = [] + for param_name, param_info in params_info.items(): + reshard_handler = param_info["reshard_handler"] + tensor_map = param_info["tensor_map"] + + total_size = sum(t.size for t in tensor_map.values()) + total_param_info.append({ + "param_name": key_mapping.get(param_name, param_name) if key_mapping else param_name, + "from_tensor_map": tensor_map, + "reshard_handler": reshard_handler, + "size": total_size + }) - Analyzes parameters from destination and source sharded tensor metadata to classify them into: - - Special parameters: Missing from source metadata - - No-shard parameters: Matching sharding strategies and offsets - - Online-shard parameters: Different sharding strategies requiring resharding + num_params = len(total_param_info) + num_workers = min(num_workers, num_params) if num_params > 0 else 1 + logger.info(f"Process workers number: {num_workers}") - Args: - dst_sharded_tensor_metas: Metadata for destination sharded tensors - src_sharded_tensor_metas: Metadata for source sharded tensors - param_file_mappings: Mapping of parameters to their storage files + def balance_load(params: List[dict], num_groups: int) -> List[List[dict]]: + """Balances parameter load across worker groups to minimize load imbalance. - Returns: - Tuple containing three collections: - - special_params: List of parameter names missing from source - - no_shard_params: Dict mapping filenames to params that don't need resharding - - online_shard_params: Dict of params that need resharding with their details + Uses a greedy load balancing algorithm: + 1. Sorts parameters by total size (descending) to prioritize large parameters + 2. Greedily assigns each parameter to the worker group with the smallest current total size + This ensures even distribution of computational load across workers. - Raises: - ValueError: If a parameter exists in source metadata but has an empty list of ShardedTensor instances - RuntimeError: If global shapes of source and destination tensors for a parameter do not match - RuntimeError: If sharding strategies match but no corresponding parameter offset is found - """ - # Initialize categorization collections - not_mapping_params: List[str] = [] - need_concat_params: Dict[str, Dict[str, List[Any]]] = {} - no_shard_params: Dict[str, Dict[str, List[Any]]] = {} - no_shard_params_list: List[str] = [] - online_shard_params: Dict[str, List[Any]] = {} + Args: + params: List of parameter metadata dicts (each with "size" key) + num_groups: Number of worker groups to split parameters into - rank_id = get_real_rank() + Returns: + List of worker groups, where each group is a list of parameter metadata dicts. + Groups are balanced by total tensor size to avoid uneven workload distribution. 
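The greedy balancing described above is easy to reproduce in isolation. The sketch below mirrors the sort-then-assign-to-lightest-group logic with made-up sizes; it illustrates the technique and is not the module's own helper.

```python
# Standalone sketch of the greedy size balancing used by balance_load above.
from typing import List

def greedy_balance(params: List[dict], num_groups: int) -> List[List[dict]]:
    groups = [{"total_size": 0, "params": []} for _ in range(num_groups)]
    # Largest first, then always append to the currently lightest group.
    for param in sorted(params, key=lambda x: x["size"], reverse=True):
        target = min(groups, key=lambda g: g["total_size"])
        target["total_size"] += param["size"]
        target["params"].append(param)
    return [g["params"] for g in groups]

params = [{"param_name": f"w{i}", "size": s} for i, s in enumerate([9, 7, 4, 3, 3, 1])]
for group in greedy_balance(params, num_groups=2):
    print(sum(p["size"] for p in group), [p["param_name"] for p in group])
# Totals come out roughly even, e.g. 13 and 14.
```

Placing the largest tensors first keeps any single worker from ending up with all of the heavy merges.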
+ """ + # Sort parameters from largest to smallest to optimize load balancing + sorted_params = sorted(params, key=lambda x: x["size"], reverse=True) - # Analyze each parameter in destination metadata - for param_name in dst_sharded_tensor_metas: - # Handle parameters missing from source metadata - if param_name not in src_sharded_tensor_metas: - not_mapping_params.append(param_name) - continue + # Initialize worker groups with empty params and zero total size + groups = [{"total_size": 0, "params": []} for _ in range(num_groups)] - # Get destination tensor strategy information - dst_sharded_tensor = dst_sharded_tensor_metas[param_name] - dst_global_shape, dst_axis_fragmentations, dst_global_offset = get_strategy_info_from_sharded_tensor( - dst_sharded_tensor) + # Assign each parameter to the least loaded group + for param in sorted_params: + min_group = min(groups, key=lambda g: g["total_size"]) + min_group["total_size"] += param["size"] + min_group["params"].append(param) - src_sharded_tensor_list = src_sharded_tensor_metas[param_name] - if not src_sharded_tensor_list: - raise ValueError( - f"Source metadata for parameter '{param_name}' contains an empty list of ShardedTensor instances. " - "Valid source metadata requires at least one ShardedTensor entry." - ) + # Extract only the parameter lists (discard size tracking) + return [group["params"] for group in groups] - # Get parameters info which need to concat - if param_name != src_sharded_tensor_list[0].key: - concat_infos = [] - reshard_infos = [] - for src_sharded_tensor in src_sharded_tensor_list: - param_key = str((src_sharded_tensor.org_key, src_sharded_tensor.global_offset)) - concat_infos.append( - { - 'sub_name': src_sharded_tensor.org_key, - 'file_name': param_file_mappings[param_key][0]["file_name"], - 'param_dtype': src_sharded_tensor.dtype, - } - ) + param_groups = balance_load(total_param_info, num_workers) - if dst_axis_fragmentations != src_sharded_tensor_list[0].axis_fragmentations: - # `reshard_infos` contains `full_shape, from_layout, to_layout, to_rank_id` - reshard_infos = [dst_global_shape, None, dst_sharded_tensor.layout, rank_id] - need_concat_params[param_name] = (concat_infos, reshard_infos) - continue + def process_param_group(param_group: List[dict]) -> List[Tuple[str, Parameter]]: + """Processes a group of parameters to reconstruct full tensors from slices. - param_key: Optional[str] = None - strategy_is_same = False + For each parameter in the group: + 1. Uses ReshardHandler to merge tensor slices into a full tensor + 2. Wraps the merged tensor into a Parameter object + 3. 
Collects (param_name, Parameter) pairs for state dict update - # Compare with each source tensor strategy - for src_sharded_tensor in src_sharded_tensor_list: - src_global_shape, src_axis_fragmentations, src_global_offset = \ - get_strategy_info_from_sharded_tensor(src_sharded_tensor) + Args: + param_group: List of parameter metadata dicts (from balance_load output) - # Validate global shape compatibility - if src_global_shape != dst_global_shape: - raise RuntimeError("Global shapes of source and destination tensors do not match") + Returns: + List of tuples containing (parameter name, reconstructed Parameter object) + """ + results = [] + for param in param_group: + target_weight = param["reshard_handler"].get_real_tensor(param["from_tensor_map"]) + target_weight = Parameter(target_weight, name=param["param_name"], requires_grad=False) + results.append((param["param_name"], target_weight)) + return results - # Check if sharding strategies differ - if src_axis_fragmentations != dst_axis_fragmentations: - break # Strategies differ, no need to check further + with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor: + futures = [executor.submit(process_param_group, group) for group in param_groups] - strategy_is_same = True + for future in concurrent.futures.as_completed(futures): + for param_name, target_weight in future.result(): + state_dict[param_name] = target_weight - # Check if offsets match for direct mapping - if src_global_offset == dst_global_offset: - param_key = str((src_sharded_tensor.org_key, src_global_offset)) - break # Found matching parameter - # Validate strategy consistency - if strategy_is_same and param_key is None: - raise RuntimeError("Matching strategy found but no corresponding parameter offset") +def build_concat_tensors_reformulation_mappings( + checkpoint_dir: str, + params_info: Dict[str, Dict], + need_concat_params: Dict[str, Tuple[List[ShardedTensor], ShardedTensor]], + param_file_mappings: Dict[str, List[Dict[str, Any]]], + network: Cell, + key_mapping: Dict[str, str] +) -> Tuple[Dict[str, Dict], Dict[str, Parameter]]: + """ + Processes parameters that need concatenation (HuggingFace weights), executes loading, concat, and slice operations. - src_sharded_tensor = src_sharded_tensor_list[0] + Workflow: + 1. Organize file loading info: group parameters by file for efficient batch loading using ms_load_checkpoint; + 2. Concat weights: for each parameter, build concat_dict and call network.convert_hf_weight to concatenate; + 3. Slice with all_offset: use _smart_slice to get target slices and classify + into reshard-needed or no-reshard groups. 
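The three-step concat path above can be pictured with plain numpy stand-ins. In this sketch the parameter names, shapes and offsets are invented, and a simple row-wise concatenation stands in for the network's `convert_hf_weight` method.

```python
# Rough sketch of the concat path (Steps 1-3 above) with numpy stand-ins.
import numpy as np

# Step 1: fragments as they would come back from ms_load_checkpoint, grouped by file.
state_dict_from_files = {
    "model.layers.0.self_attn.q_proj.weight": np.ones((4, 8), dtype=np.float32),
    "model.layers.0.self_attn.k_proj.weight": np.ones((2, 8), dtype=np.float32),
    "model.layers.0.self_attn.v_proj.weight": np.ones((2, 8), dtype=np.float32),
}

# Step 2: build concat_dict keyed by the mapped (network-side) names and merge them.
concat_dict = {name: tensor for name, tensor in state_dict_from_files.items()}
concated = np.concatenate(list(concat_dict.values()), axis=0)  # stand-in for convert_hf_weight

# Step 3: slice the concatenated weight with this rank's offsets from all_offset.
param_slice = ((0, 4), (0, 8))  # hypothetical offsets for the local rank
sliced = concated[tuple(slice(s, e) for s, e in param_slice)]
print(concated.shape, sliced.shape)  # (8, 8) (4, 8)
```

A full-rank slice is kept as-is in the no-reshard state dict; a partial slice is recorded in the tensor map for the parallel merge step.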
- # Categorize based on strategy comparison - if strategy_is_same: - # Parameters that don't need resharding - file_name = param_file_mappings[param_key][0]["file_name"] + Args: + checkpoint_dir: Path to checkpoint directory + params_info: Output from build_tensors_reformulation_all_offsets, containing reshard handlers and offsets + need_concat_params: {param_name: (src_sharded_tensor_list, dst_sharded_tensor)} + param_file_mappings: Mapping from shard IDs to storage information lists + network: Network instance (must have convert_hf_weight method) + key_mapping: Parameter name mapping dictionary - # Initialize entry if new file - if file_name not in no_shard_params: - no_shard_params[file_name] = { - "param_name_list": [src_sharded_tensor.org_key], - "param_dtype_list": [src_sharded_tensor.dtype], - } - else: - # Add to existing file entry - no_shard_params[file_name]["param_name_list"].append(src_sharded_tensor.org_key) - no_shard_params[file_name]["param_dtype_list"].append(src_sharded_tensor.dtype) + Returns: + Tuple[params_info_need_reshard, state_dict_no_reshard] + - params_info_need_reshard: {param_name: {"reshard_handler": ..., "tensor_map": ...}} + - state_dict_no_reshard: {param_name: Parameter} + """ + if need_concat_params and not hasattr(network, 'convert_hf_weight'): + raise NotImplementedError( + "The `convert_hf_weight` method of network is not implemented." + ) - no_shard_params_list.append(src_sharded_tensor.org_key) - else: - # Parameters that need online resharding - online_shard_params[src_sharded_tensor.org_key] = [ - dst_global_shape, src_sharded_tensor.layout, dst_sharded_tensor.layout, rank_id - ] - # Parameters to be processed for categorized logging - logger.debug(f"Params not mapping: {not_mapping_params}") - logger.debug(f"Params needing transformation: {need_concat_params}") - logger.debug(f"Params no need reshard: {no_shard_params_list}") - logger.debug(f"Params need reshard: {list(online_shard_params.keys())}") + # ========== Step 1: Organize file loading info ========== + files_to_load: Dict[str, List[str]] = {} # {file_name: [param_org_key, ...]} - return not_mapping_params, need_concat_params, no_shard_params, online_shard_params + for param_name, (src_sharded_tensor_list, _) in need_concat_params.items(): + for src_sharded_tensor in src_sharded_tensor_list: + shard_id = get_sharded_tensor_shard_id( + src_sharded_tensor.org_key, + src_sharded_tensor.global_offset + ) + file_name = param_file_mappings[shard_id][0]["file_name"] + + if file_name not in files_to_load: + files_to_load[file_name] = [] + # Use org_key as the parameter name to load + if src_sharded_tensor.org_key not in files_to_load[file_name]: + files_to_load[file_name].append(src_sharded_tensor.org_key) + + # Use ms_load_checkpoint to lazy-load all needed parameters + state_dict_from_files: Dict[str, Parameter] = {} + for file_name, param_names in files_to_load.items(): + file_path = os.path.join(checkpoint_dir, file_name) + loaded = ms_load_checkpoint( + file_path, + format='safetensors', + choice_func=lambda x, names=param_names: x in names + ) + state_dict_from_files.update(loaded) + # ========== Step 2: Concat weights for each parameter ========== + state_dict_concated: Dict[str, Tensor] = {} -def get_metadata_of_checkpoint(checkpoint_dir: str) -> tuple[dict, dict]: - """ - Retrieves metadata from checkpoint directory, either from an existing metadata file - or by parsing checkpoint files. 
+ for param_name, (src_sharded_tensor_list, _) in need_concat_params.items(): + # Build concat_dict + concat_dict = {} + for src_sharded_tensor in src_sharded_tensor_list: + org_key = src_sharded_tensor.org_key + mapped_key = key_mapping.get(org_key, org_key) if key_mapping else org_key + if org_key in state_dict_from_files: + concat_dict[mapped_key] = state_dict_from_files[org_key] - First checks for a pre-existing 'metadata.json' file in the checkpoint directory. If found, - it loads metadata from this file using load_metadata(). If not found, it generates metadata - by parsing the checkpoint files directly using load_metadata_from_checkpoint(). + # Call network.convert_hf_weight to concat + concated_weight = network.convert_hf_weight(concat_dict) + state_dict_concated[param_name] = concated_weight[param_name] - Args: - checkpoint_dir: Path to the directory containing the checkpoint files. - network: The target core network (Cell) which has method `convert_name` to convert Hugging Face weight. + # ========== Step 3: Slice with all_offset ========== + params_info_need_reshard: Dict[str, Dict] = {} + state_dict_no_reshard: Dict[str, Parameter] = {} - Returns: - A tuple containing two dictionaries: - - sharded_tensor_metas: Metadata about sharded tensors - - param_file_mappings: Mapping of parameters to their storage files - """ - logger.info("..........Load Metadata of Checkpoint..........") + for param_name, concated_tensor in state_dict_concated.items(): + reshard_handler = params_info[param_name]["reshard_handler"] + all_offset = params_info[param_name]["all_offset"] + load_from_multi_rank = len(all_offset) > 1 - # Construct path to metadata file - metadata_path = os.path.join(checkpoint_dir, "metadata.json") + # Get first rank's offset for slicing + first_rank = next(iter(all_offset)) + param_slice = all_offset[first_rank] - # Load from existing metadata file if available - if os.path.exists(metadata_path): - sharded_tensor_metas, param_file_mappings = load_metadata(metadata_path) - # Otherwise generate metadata from checkpoint files - else: - sharded_tensor_metas, param_file_mappings = generate_default_metadata_from_checkpoint(checkpoint_dir) + sliced_tensor, is_full_slice = smart_slice(concated_tensor, param_slice, load_from_multi_rank) + + if is_full_slice and not load_from_multi_rank: + # No reshard needed, directly add to state_dict (sliced_tensor is the original tensor) + if not isinstance(sliced_tensor, Parameter): + sliced_tensor = Parameter(sliced_tensor, name=param_name, requires_grad=False) + state_dict_no_reshard[param_name] = sliced_tensor + else: + # Needs reshard, build tensor_map for all ranks + params_info_need_reshard[param_name] = { + "reshard_handler": reshard_handler, + "tensor_map": {0: sliced_tensor} + } - # if the content or format of metadata_path and checkpoint_dir are invalid, the return value of - # sharded_tensor_metas and param_file_mappings may be empty or None, - # and it may cause an error in subsequent loading process. 
- return sharded_tensor_metas, param_file_mappings + return params_info_need_reshard, state_dict_no_reshard -def params_key_mapping( - sharded_tensor_metas: Dict[str, List[ShardedTensor]], - network: Cell -) -> tuple[dict, dict]: +def categorize_params( + dst_sharded_tensor_metas: Dict[str, ShardedTensor], + src_sharded_tensor_metas: Dict[str, List[ShardedTensor]] +) -> Tuple[ + List[str], Dict[str, Tuple[List[ShardedTensor], ShardedTensor]], + Dict[str, Tuple[ShardedTensor, ShardedTensor]] +]: """ - Mapping Hugging Face checkpoint keys to MindSpore Transformers. + Categorizes parameters based on comparison of source and destination sharding strategies and key matching. + + Analyzes parameters from destination and source sharded tensor metadata to classify them into three categories: + - Special parameters (not_mapping_params): Missing from source metadata or with mismatched keys + - Concat-required parameters (need_concat_params): Params that need concatenation (and optional resharding) due to + mismatched keys or sharding strategies between source and destination + - Mapped parameters (params_to_load): Params with matching keys and consistent sharding attributes Args: - sharded_tensor_metas: Metadata about sharded tensors. - network: The network (Cell) which has method `convert_name` to convert Hugging Face weight. + dst_sharded_tensor_metas: Dictionary mapping parameter names to their destination ShardedTensor metadata + (key: parameter name, value: ShardedTensor instance for destination) + src_sharded_tensor_metas: Dictionary mapping parameter names to a list of source ShardedTensor metadata + (key: parameter name, value: list of ShardedTensor instances for source) + param_file_mappings: Mapping from shard IDs to storage information lists. + Each storage info dict has "file_name", "storage_rank", "rank_group" keys. Returns: - A dictionary after mapping about sharded tensor metas. - """ - # The key of `mapped_sharded_tensor_metas` is in the network, - # such as { qkv: [ShardedTensor, ShardedTensor, ShardedTensor], ... } - mapped_sharded_tensor_metas = {} - # The key of `key_mapping` is {'weight_key': 'mapping_key'}, - # and the `mapping_key` may not have the same name as the parameter in the network, - # it could be an intermediate form, - # such as { 'q_proj': 'linear_q', 'k_proj': 'linear_k', 'v_proj': 'linear_v', ... 
} - key_mapping = {} + Tuple containing three collections: + - not_mapping_params: List of parameter names missing from source metadata (special parameters) + - need_concat_params: Dictionary mapping parameter names to (src_sharded_tensor_list, dst_sharded_tensor) + where src_sharded_tensor_list is a list of source ShardedTensor instances that need to be concatenated + - params_to_load: Dictionary mapping original parameter keys to sharded tensor details: + Each entry contains (src_sharded_tensor, dst_sharded_tensor) - for param_name in sharded_tensor_metas: - param_name_converted = network.convert_name(param_name) - sharded_tensor_list = sharded_tensor_metas.get(param_name) + Raises: + ValueError: If a parameter exists in source metadata but has an empty list of ShardedTensor instances + """ + logger.info("..........Categorize Params..........") + # Initialize categorization collections + not_mapping_params: List[str] = [] + need_concat_params: Dict[str, Tuple[List[ShardedTensor], ShardedTensor]] = {} + params_to_load: Dict[str, Tuple[ShardedTensor, ShardedTensor]] = {} - for sharded_tensor in sharded_tensor_list: - sharded_tensor.key = param_name_converted - sharded_tensor.org_key = param_name + # Analyze each parameter in destination metadata + for param_name, dst_sharded_tensor in dst_sharded_tensor_metas.items(): + # Handle parameters missing from source metadata + if param_name not in src_sharded_tensor_metas: + not_mapping_params.append(param_name) + continue - key_mapping[param_name] = param_name_converted - param_name_converted_concat = network.convert_concat_name(param_name_converted) - mapped_sharded_tensor_metas.setdefault(param_name_converted_concat, []).extend(sharded_tensor_list) + src_sharded_tensor_list = src_sharded_tensor_metas[param_name] + if not src_sharded_tensor_list: + raise ValueError( + f"Source metadata for parameter '{param_name}' contains an empty list of ShardedTensor instances. " + "Valid source metadata requires at least one ShardedTensor entry." + ) - return mapped_sharded_tensor_metas, key_mapping + # Check if parameter needs concat: param_name differs from src tensor's key + if param_name != src_sharded_tensor_list[0].key: + # Needs concat: store complete ShardedTensor list and dst ShardedTensor + need_concat_params[param_name] = (src_sharded_tensor_list, dst_sharded_tensor) + else: + # Direct mapping + src_sharded_tensor = src_sharded_tensor_list[0] + params_to_load[src_sharded_tensor.org_key] = (src_sharded_tensor, dst_sharded_tensor) + # Parameters to be processed for categorized logging + logger.debug(f"Params can't load: {not_mapping_params}") + logger.debug(f"Params needing concat: {list(need_concat_params.keys())}") + logger.debug(f"Params to load: {list(params_to_load.keys())}") -# pylint: disable=W0212 -def get_core_network(network): - """Get the core network that has `convert_name` method.""" - if hasattr(network, '_backbone'): - return get_core_network(network._backbone) - if hasattr(network, 'network'): - return get_core_network(network.network) - return network + return not_mapping_params, need_concat_params, params_to_load def load_checkpoint( @@ -886,7 +1004,8 @@ def load_checkpoint( network: Cell, optimizer: Optional[Optimizer] = None, global_step: Optional[int] = None, - balanced_load: bool = False + balanced_load: bool = False, + load_worker_number: int = 1, ) -> None: """ Loads a checkpoint into a network and optional optimizer. 
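For callers, the visible change is the extra `load_worker_number` argument on `load_checkpoint`. Below is a minimal usage sketch; the builder functions and the checkpoint path are placeholders, and only the keyword arguments shown in the signature above are assumed.

```python
from mindformers.checkpoint.checkpoint import load_checkpoint

network = build_my_network()             # hypothetical builder; any Cell matching the checkpoint
optimizer = build_my_optimizer(network)  # optional; pass None for inference-only loading

load_checkpoint(
    "/path/to/checkpoint_dir",   # directory holding *.safetensors (and metadata.json if present)
    network,
    optimizer=optimizer,
    balanced_load=False,
    load_worker_number=4,        # new in this change: worker threads for the reshard/merge step
)
```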
@@ -900,6 +1019,7 @@ def load_checkpoint( network: The target network (Cell) to load parameters into (cannot be None) optimizer: Optional optimizer (Cell) to load optimizer states into global_step: Optional initial global step value if not found in checkpoint + load_worker_number: Max number of workers to process params Raises: ValueError: If the input `network` is None @@ -914,6 +1034,7 @@ def load_checkpoint( logger.info("..........Start Load Checkpoint..........") # Retrieve metadata from checkpoint files + logger.info("..........Get Metadata of Checkpoint..........") src_sharded_tensor_metas, param_file_mappings = get_metadata_of_checkpoint(checkpoint_dir) # Get the core network and check the convert method is illegal @@ -934,6 +1055,7 @@ def load_checkpoint( return param_name in list(network.parameters_dict().keys()) param_redundancy = None + logger.info("..........Get Metadata of Network..........") if balanced_load: rank_id_to_sharded_tensors = apply_balance_shard_strategy(network, filter_func) dst_sharded_tensor_metas = get_cur_sharded_tensor_after_balanced(rank_id_to_sharded_tensors) @@ -943,43 +1065,47 @@ def load_checkpoint( if get_real_group_size() > 1 else get_sharded_tensor_from_cell(network, optimizer) # Categorize parameters based on sharding strategies - _, need_concat_params, no_shard_params, online_shard_params = categorize_params( - dst_sharded_tensor_metas, src_sharded_tensor_metas, param_file_mappings + _, need_concat_params, params_to_load = categorize_params( + dst_sharded_tensor_metas, src_sharded_tensor_metas ) # Process Weight + logger.info("..........Building State Dict..........") state_dict: Dict[str, Parameter] = {} - # Concat parameters - concat_params(checkpoint_dir, network, key_mapping, need_concat_params, state_dict) - - # Load parameters that don't require resharding - for file_name, param_info in no_shard_params.items(): - param_name_list = param_info["param_name_list"] - param_dtype_list = param_info["param_dtype_list"] - no_reshard_state_dict = load_safetensor( - os.path.join(checkpoint_dir, file_name), param_name_list, dtype=param_dtype_list + # ==================== Path 1: Direct mapping parameters ==================== + params_info_direct = build_tensors_reformulation_all_offsets(params_to_load) + params_info_need_reshard_1, state_dict_no_reshard_1 = build_tensors_reformulation_mappings( + checkpoint_dir, params_info_direct, src_sharded_tensor_metas, + param_file_mappings, key_mapping, + ) + state_dict.update(state_dict_no_reshard_1) + + # ==================== Path 2: Concat parameters (HuggingFace weights) ==================== + if need_concat_params: + # Build reshard info for concat parameters (reuse build_tensors_reformulation_all_offsets) + params_info_concat = build_tensors_reformulation_all_offsets(need_concat_params) + # Load, concat, and build tensor_map + params_info_need_reshard_2, state_dict_no_reshard_2 = build_concat_tensors_reformulation_mappings( + checkpoint_dir, params_info_concat, need_concat_params, + param_file_mappings, network, key_mapping ) + state_dict.update(state_dict_no_reshard_2) + else: + params_info_need_reshard_2 = {} - state_dict.update({ - key_mapping[param_name]: value - for param_name, value in no_reshard_state_dict.items() - }) + # ==================== Unified parallel load strategy ==================== + all_params_need_reshard = {**params_info_need_reshard_1, **params_info_need_reshard_2} - # Load and reshard parameters that require online resharding - for param_name, (full_shape, from_layout, to_layout, 
to_rank_id) in online_shard_params.items(): - reshard_handler = ReshardHandler(param_name, full_shape, from_layout, to_layout, to_rank_id) - all_offset = reshard_handler.infer_all_tensor_offset() - from_tensor_map = load_tensor_by_offset( - all_offset, param_name, checkpoint_dir, src_sharded_tensor_metas, param_file_mappings, key_mapping - ) - target_weight = reshard_handler.get_real_tensor(from_tensor_map) - param_name = key_mapping[param_name] - state_dict[param_name] = Parameter(target_weight, name=param_name, requires_grad=False) + logger.info("Get and write params into state dict.") + apply_parallel_load_strategy( + all_params_need_reshard, state_dict, key_mapping, num_workers=load_worker_number + ) # Handle global_step for optimizer if needed if optimizer and "global_step" not in state_dict: # Initialize global_step with default or from common.json + logger.info(".....Get Global Step for Optimizer.....") if not global_step: common_file = os.path.join(checkpoint_dir, 'common.json') global_step = 0 if not os.path.exists(common_file) else CommonInfo.load_common(common_file).global_step @@ -988,6 +1114,7 @@ def load_checkpoint( Tensor([global_step], mstype.int32), name="global_step", requires_grad=False ) + logger.info("..........Loading State Dict into Network..........") # Load state dictionary into network and optimizer load_parameters( network, @@ -997,44 +1124,7 @@ def load_checkpoint( param_redundancy=param_redundancy ) - -def concat_params(checkpoint_dir: str, network: Cell, key_mapping: dict, need_concat_params, state_dict: dict): - """Concat the need_concat_params dict in checkpoint.""" - if need_concat_params and not hasattr(network, 'convert_hf_weight'): - raise NotImplementedError("The `convert_hf_weight` method of network is not implemented.") - - for param_name, concat_info in need_concat_params.items(): - sharded_tensor_list, reshard_info = concat_info - org_weight_dict = {} - # Get all the params need to concat into `org_weight_dict`. - for sharded_tensor in sharded_tensor_list: - org_weight_dict.update( - load_safetensor( - checkpoint_path=os.path.join(checkpoint_dir, sharded_tensor['file_name']), - param_name=sharded_tensor['sub_name'], - dtype=sharded_tensor['param_dtype'] - ) - ) - # Mapping the weight key to MCore key into `concat_dict`. - concat_dict = { - key_mapping[k]: v - for k, v in org_weight_dict.items() - } - # Concat the weight. - concated_weight = network.convert_hf_weight(concat_dict) - - if reshard_info: - # Get the offset of the Tensor to reshard. - full_shape, from_layout, to_layout, to_rank_id = reshard_info - reshard_handler = ReshardHandler(param_name, full_shape, from_layout, to_layout, to_rank_id) - all_offset = reshard_handler.infer_all_tensor_offset() - # Get the slice to reshard the Tensor. - slices = tuple(slice(start, end) for start, end in all_offset[0]) - target_weight = concated_weight[param_name][slices] - # Update to `state_dict` to load into the network. - state_dict[param_name] = Parameter(target_weight, name=param_name, requires_grad=False) - else: - state_dict[param_name] = Parameter(concated_weight[param_name], name=param_name, requires_grad=False) + logger.info("..........Loading Checkpoint Finished..........") def check_the_param_for_load_ckpt(checkpoint: str, network: Cell): @@ -1092,13 +1182,14 @@ def load_parameters( RuntimeError: If parameter loading fails due to mismatched keys or invalid parameter types (propagated from `load_param_into_net`). 
""" + def split_state_dict(network, state_dict, optimizer, state_dict_opt): """split state dict""" network_param_names = set(network.parameters_dict().keys()) optimizer_param_names = set(optimizer.parameters_dict().keys()) if optimizer else set() for param_name in list(state_dict.keys()): if param_name not in network_param_names and param_name in optimizer_param_names and \ - param_name not in state_dict_opt: + param_name not in state_dict_opt: state_dict_opt[param_name] = state_dict.pop(param_name) return network_param_names, optimizer_param_names, state_dict, state_dict_opt @@ -1123,7 +1214,7 @@ def load_parameters( # Load parameters into network logger.debug(f"Network state_dict keys: {list(state_dict.keys())}") - param_not_load, ckpt_not_load = load_param_into_net(network, state_dict, strict_load=True) + param_not_load, ckpt_not_load = load_param_into_net(network, state_dict) if balanced_load: param_loaded = {param_name for param_name in state_dict if param_name not in ckpt_not_load} single_parameter_broadcast(network, param_redundancy, param_not_load, param_loaded) diff --git a/mindformers/checkpoint/fully_parallel.py b/mindformers/checkpoint/fully_parallel.py index f4bc8f87121e2a795c84746e6e7efe3061ec49e3..7b8b50c95a74a1fc9115c04bab7d122f0afd3f4e 100644 --- a/mindformers/checkpoint/fully_parallel.py +++ b/mindformers/checkpoint/fully_parallel.py @@ -236,16 +236,20 @@ class BalancedSaveStrategy(): if self.rank_id == 0: param_file_mapping = [] cur_rank_id = 0 - rank_param_ids_mappings = self._get_rank_param_ids_mappings(shared_distribution) - for rank_id, params in rank_param_ids_mappings.items(): + for rank_id, params in shared_distribution.items(): if params: save_file_name = get_checkpoint_name( None, self.user_prefix, cur_rank_id, self.total_files_num, self.file_type ) - for param_id in params: - param_file_mapping.append( - (save_file_name + ".safetensors", rank_id, _reverse_sharded_tensor_shard_id(param_id))) + for shard_id in params: + rank_group = params[shard_id][1] + param_file_mapping.append(( + save_file_name + ".safetensors", + rank_id, + rank_group, + _reverse_sharded_tensor_shard_id(shard_id) + )) cur_rank_id += 1 sharded_tensor_metas = get_all_sharded_tensor(self.network, self.filter_func) @@ -255,12 +259,13 @@ class BalancedSaveStrategy(): origin_shard_metadata, origin_param_file_mapping = load_metadata( get_metadata_filename(self.checkpoint_path, iteration)) sharded_tensor_metas.update({"origin": origin_shard_metadata}) - for param_id, storage in origin_param_file_mapping.items(): + for shard_id, storage in origin_param_file_mapping.items(): for storage_item in storage: param_file_mapping.append(( storage_item["file_name"], storage_item["storage_rank"], - _reverse_sharded_tensor_shard_id(param_id) + storage_item["rank_group"], + _reverse_sharded_tensor_shard_id(shard_id) )) metadata_file_path = get_metadata_filename(self.checkpoint_path, iteration) @@ -441,4 +446,9 @@ def apply_balance_shard_strategy(network: Cell, filter_func: Callable[[str], boo else: rank_id_to_sharded_tensors[selected_rank_id] = {shard_id: (sharded_tensor, rank_group)} + rank_id_to_sharded_tensors = { + k: rank_id_to_sharded_tensors.get(k, None) + for k in sorted(rank_id_to_sharded_tensors) + } + return rank_id_to_sharded_tensors diff --git a/mindformers/checkpoint/metadata.py b/mindformers/checkpoint/metadata.py index e2ccbe4b4d8c04eaf2f29fd16f9f81394a59cf57..c8b6bc0e74f3a69ee998150784e966796a3427d5 100644 --- a/mindformers/checkpoint/metadata.py +++ b/mindformers/checkpoint/metadata.py @@ -17,8 
+17,11 @@ import os import json import tempfile from glob import glob +from typing import Dict, Tuple, List +from collections import defaultdict from safetensors import safe_open +from mindspore.nn import Cell from mindspore.communication.management import get_group_size from mindspore.common.dtype import all_types from mindspore.parallel import Layout @@ -31,6 +34,8 @@ from mindformers.checkpoint.utils import ( get_sharded_tensor_shard_id, FileType ) +from mindformers.checkpoint.sharded_tensor import ShardedTensor + tensor_to_ms_type = {str(dtype).lower(): dtype for dtype in all_types} @@ -149,11 +154,16 @@ def save_metadata(sharded_tensor_metas, param_file_mappings, meta_data_path): for param_file_mapping in param_file_mappings: file_name = param_file_mapping[0] storage_rank = param_file_mapping[1] - param_id = get_sharded_tensor_shard_id(param_file_mapping[2][0], param_file_mapping[2][1]) - if param_id not in storage_data: - storage_data[param_id] = [{"file_name": file_name, "storage_rank": storage_rank}] + rank_group = param_file_mapping[2] + shard_id = get_sharded_tensor_shard_id(param_file_mapping[3][0], param_file_mapping[3][1]) + if shard_id not in storage_data: + storage_data[shard_id] = [ + {"file_name": file_name, "storage_rank": storage_rank, "rank_group": rank_group} + ] else: - storage_data[param_id].append({"file_name": file_name, "storage_rank": storage_rank}) + storage_data[shard_id].append( + {"file_name": file_name, "storage_rank": storage_rank, "rank_group": rank_group} + ) metadata = {"state_dict_metadata": state_dict_metadata, "storage_data": storage_data} @@ -302,28 +312,14 @@ def generate_default_metadata_from_checkpoint(checkpoint_dir: str) -> tuple[dict # Open the safetensor file and process each tensor with safe_open(safetensor_file, framework="np", device="cpu") as f: for param_name in f.keys(): - try: - # Load the tensor to extract its properties - tensor = f.get_tensor(param_name) # Load as numpy tensor - except Exception as e: - raise RuntimeError( - f"Failed to load tensor '{param_name}' from file '{file_basename}'" - ) from e - - # Extract tensor properties - tensor_shape = tensor.shape - ms_dtype = tensor_to_ms_type.get(str(tensor.dtype)) - global_offset = (0,) - axis_fragmentations = (1,) * len(tensor_shape) - # Create sharded tensor metadata object sharded_tensor = build_sharded_tensor( param_name=param_name, - param_dtype=ms_dtype, - local_shape=tensor_shape, - global_shape=tensor_shape, - global_offset=global_offset, - axis_fragmentations=axis_fragmentations, + param_dtype=None, + local_shape=None, + global_shape=None, + global_offset=(0,), + axis_fragmentations=None, layout=None ) @@ -334,7 +330,7 @@ def generate_default_metadata_from_checkpoint(checkpoint_dir: str) -> tuple[dict # Store metadata with fixed storage rank 0 sharded_tensor_metas[param_name] = [sharded_tensor] param_file_mappings[str((param_name, (0,)))] = [ - {"file_name": file_basename, "storage_rank": 0} + {"file_name": file_basename, "storage_rank": 0, "rank_group": [0]} ] return sharded_tensor_metas, param_file_mappings @@ -345,6 +341,12 @@ def get_total_params_file_mapping_info(sharded_tensor_metas, user_prefix, model_ if sharded_tensor_metas is None: return None + shard_id_to_ranks = defaultdict(list) + for cur_npu_rank, cur_rank_sharded_tensors in sharded_tensor_metas.items(): + for sharded_tensor in cur_rank_sharded_tensors.values(): + shard_id = get_sharded_tensor_shard_id(sharded_tensor.key, sharded_tensor.global_offset) + shard_id_to_ranks[shard_id].append(cur_npu_rank) + 
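The storage entries written by `save_metadata` now carry a `rank_group` alongside `file_name` and `storage_rank`. A small sketch of that bookkeeping follows, with invented file names and a plain `str((key, offset))` standing in for `get_sharded_tensor_shard_id`.

```python
# Sketch of the new rank_group bookkeeping: each shard id records every rank that
# holds that shard, while storage_rank is the rank that actually writes the file.
from collections import defaultdict

# (file_name, storage_rank, rank_group, (key, global_offset)) tuples, as built above.
param_file_mappings = [
    ("rank_0_model.safetensors", 0, [0, 2], ("embedding.weight", (0, 0))),
    ("rank_1_model.safetensors", 1, [1, 3], ("embedding.weight", (128, 0))),
]

storage_data = defaultdict(list)
for file_name, storage_rank, rank_group, (key, offset) in param_file_mappings:
    shard_id = str((key, offset))  # stand-in for get_sharded_tensor_shard_id
    storage_data[shard_id].append(
        {"file_name": file_name, "storage_rank": storage_rank, "rank_group": rank_group}
    )

print(dict(storage_data))
```

The loader later uses `rank_group` to map a searched rank onto the rank whose file actually stores the shard.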
npu_nums = get_group_size()
     param_file_mappings = []
     for cur_npu_rank, cur_rank_sharded_tensors in sharded_tensor_metas.items():
@@ -356,7 +358,85 @@ def get_total_params_file_mapping_info(sharded_tensor_metas, user_prefix, model_
             ckpt_name = get_checkpoint_name(None, user_prefix, cur_npu_rank, npu_nums, FileType.MODEL)
 
             param_file_mappings.append(
-                (ckpt_name + '.safetensors', cur_npu_rank, (sharded_tensor.key, sharded_tensor.global_offset))
+                (ckpt_name + '.safetensors',
+                 cur_npu_rank,
+                 shard_id_to_ranks[get_sharded_tensor_shard_id(sharded_tensor.key, sharded_tensor.global_offset)],
+                 (sharded_tensor.key, sharded_tensor.global_offset))
             )
 
     return param_file_mappings
+
+
+def get_metadata_of_checkpoint(checkpoint_dir: str) -> Tuple[Dict, Dict]:
+    """
+    Retrieves metadata from a checkpoint directory, either from an existing metadata file
+    or by parsing the checkpoint files.
+
+    First checks for a pre-existing 'metadata.json' file in the checkpoint directory. If found,
+    metadata is loaded from this file using load_metadata(). If not found, metadata is generated
+    by parsing the checkpoint files directly using generate_default_metadata_from_checkpoint().
+
+    Args:
+        checkpoint_dir: Path to the directory containing the checkpoint files,
+            either with a pre-built 'metadata.json' or plain safetensors files.
+
+    Returns:
+        A tuple containing two dictionaries:
+            - sharded_tensor_metas: Metadata about sharded tensors
+            - param_file_mappings: Mapping from shard IDs to storage information lists.
+              Each storage info dict has "file_name", "storage_rank", "rank_group" keys.
+    """
+    logger.info("..........Load Metadata of Checkpoint..........")
+
+    # Construct path to metadata file
+    metadata_path = os.path.join(checkpoint_dir, "metadata.json")
+
+    # Load from existing metadata file if available
+    if os.path.exists(metadata_path):
+        sharded_tensor_metas, param_file_mappings = load_metadata(metadata_path)
+    # Otherwise generate metadata from checkpoint files
+    else:
+        sharded_tensor_metas, param_file_mappings = generate_default_metadata_from_checkpoint(checkpoint_dir)
+
+    # If the content or format of metadata_path or checkpoint_dir is invalid, the returned
+    # sharded_tensor_metas and param_file_mappings may be empty or None,
+    # which may cause an error in the subsequent loading process.
+    return sharded_tensor_metas, param_file_mappings
+
+
+def params_key_mapping(
+    sharded_tensor_metas: Dict[str, List[ShardedTensor]],
+    network: Cell
+) -> Tuple[Dict, Dict]:
+    """
+    Maps Hugging Face checkpoint keys to MindSpore Transformers parameter names.
+
+    Args:
+        sharded_tensor_metas: Metadata about sharded tensors, keyed by checkpoint parameter name.
+        network: The network (Cell) whose `convert_name` method converts Hugging Face weight names.
+
+    Returns:
+        A tuple of the re-keyed sharded tensor metas and the original-to-converted key mapping.
+    """
+    # The keys of `mapped_sharded_tensor_metas` are parameter names used in the network,
+    # such as { qkv: [ShardedTensor, ShardedTensor, ShardedTensor], ... }
+    mapped_sharded_tensor_metas = {}
+    # Each entry of `key_mapping` is {'weight_key': 'mapping_key'},
+    # and the `mapping_key` may not have the same name as the parameter in the network,
+    # it could be an intermediate form,
+    # such as { 'q_proj': 'linear_q', 'k_proj': 'linear_k', 'v_proj': 'linear_v', ...
} + key_mapping = {} + + for param_name in sharded_tensor_metas: + param_name_converted = network.convert_name(param_name) + sharded_tensor_list = sharded_tensor_metas.get(param_name) + + for sharded_tensor in sharded_tensor_list: + sharded_tensor.key = param_name_converted + sharded_tensor.org_key = param_name + + key_mapping[param_name] = param_name_converted + param_name_converted_concat = network.convert_concat_name(param_name_converted) + mapped_sharded_tensor_metas.setdefault(param_name_converted_concat, []).extend(sharded_tensor_list) + + return mapped_sharded_tensor_metas, key_mapping diff --git a/mindformers/checkpoint/reshard.py b/mindformers/checkpoint/reshard.py index f812531fca5a106e5739f347dc90ab81a52a83f6..f23ab7efc558e181defe91808546d1102845a96b 100644 --- a/mindformers/checkpoint/reshard.py +++ b/mindformers/checkpoint/reshard.py @@ -218,9 +218,6 @@ class ReshardHandler: check_layout(from_layout, 'from_layout') check_layout(to_layout, 'to_layout') - if from_layout is None and to_layout is None: - raise ValueError("`from_layout` and `to_layout` cannot both be None.") - # Initialize basic attributes self.param_name = param_name self.full_shape = full_shape @@ -283,10 +280,7 @@ class ReshardHandler: # Filter ranks with non-redundant data for rank_id in self.inner_from_rank_list: dev_id_list = rank_id_to_dev_id_list(self.from_dev_matrix, rank_id) - if any([ - dim not in from_dev_map and dev_id_list[dim] > 0 - for dim in range(dev_dim) - ]): + if any(dim not in from_dev_map and dev_id_list[dim] > 0 for dim in range(dev_dim)): continue inner_deredundancy_rank_list.append(rank_id) @@ -372,8 +366,11 @@ class ReshardHandler: # Create target tensor and assign slices to_slice_shape = [end - start for start, end in self.to_area] - dtype = next(iter(from_tensor_map.values())).dtype - real_tensor = Tensor(np.zeros(to_slice_shape), dtype) + current_slice = next(iter(from_tensor_map.values())) + if isinstance(current_slice, Tensor): + real_tensor = Tensor(np.zeros(to_slice_shape), current_slice.dtype) + else: + real_tensor = np.zeros(to_slice_shape, current_slice.dtype) for from_rank_id, from_slice in from_tensor_map.items(): from_area = self.global_union_area_map[from_rank_id] diff --git a/mindformers/checkpoint/sharded_tensor.py b/mindformers/checkpoint/sharded_tensor.py index d809d0dcc7df66af3a5872ac1aa22c18ca56bd37..cabffc6d9e261136899d488516e5bdb7aab24b66 100644 --- a/mindformers/checkpoint/sharded_tensor.py +++ b/mindformers/checkpoint/sharded_tensor.py @@ -82,16 +82,20 @@ class ShardedTensor: def build_sharded_tensor( - param_name: str, param_dtype: ms.dtype, local_shape: Tuple[int, ...], global_shape: Tuple[int, ...], - axis_fragmentations: Tuple[int, ...], global_offset: Tuple[int, ...], replica_id: ReplicaId = 0, - allow_shape_mismatch: bool = False, allow_to_save: bool = True, layout: Optional[ms.Layout] = None + param_name: str, param_dtype: ms.dtype = None, local_shape: Optional[Tuple[int]] = None, + global_shape: Optional[Tuple[int]] = None, axis_fragmentations: Optional[Tuple[int]] = None, + global_offset: Tuple[int] = (0,), replica_id: ReplicaId = 0, allow_shape_mismatch: bool = False, + allow_to_save: bool = True, layout: Optional[ms.Layout] = None ) -> ShardedTensor: """Creates and returns a ShardedTensor instance with the specified parameters.""" return ShardedTensor( - key=param_name, org_key=param_name, dtype=param_dtype, local_shape=tuple(local_shape), - global_shape=tuple(global_shape), global_offset=tuple(global_offset), - 
axis_fragmentations=tuple(axis_fragmentations), replica_id=replica_id, - allow_shape_mismatch=allow_shape_mismatch, allow_to_save=allow_to_save, layout=layout + key=param_name, org_key=param_name, dtype=param_dtype, + local_shape=tuple(local_shape) if local_shape else local_shape, + global_shape=tuple(global_shape) if global_shape else global_shape, + global_offset=tuple(global_offset) if global_offset else global_offset, + axis_fragmentations=tuple(axis_fragmentations) if axis_fragmentations else axis_fragmentations, + replica_id=replica_id, allow_shape_mismatch=allow_shape_mismatch, allow_to_save=allow_to_save, + layout=layout ) @@ -449,6 +453,9 @@ def get_all_sharded_tensor( ) sharded_tensor_metas[cur_npu_rank] = cur_rank_sharded_tensors + + sharded_tensor_metas = {k: sharded_tensor_metas.get(k, None) for k in sorted(sharded_tensor_metas)} + return sharded_tensor_metas diff --git a/mindformers/checkpoint/utils.py b/mindformers/checkpoint/utils.py index d0c6c0ae21b04365fade733b18ce8acbf6abcba3..7161d329cedc1c39c58f5e2a9ceaabe855c1bbb8 100644 --- a/mindformers/checkpoint/utils.py +++ b/mindformers/checkpoint/utils.py @@ -454,3 +454,13 @@ def compile_model(model, dataset, mode, sink_mode, epoch=1, sink_size=1, do_eval build_time_end = time.time() build_duration = build_time_end - build_time_start logger.info(f"Time spent compiling the model: {build_duration:.2f} seconds") + + +# pylint: disable=W0212 +def get_core_network(network): + """Get the core network that has `convert_name` method.""" + if hasattr(network, '_backbone'): + return get_core_network(network._backbone) + if hasattr(network, 'network'): + return get_core_network(network.network) + return network diff --git a/mindformers/core/config_args.py b/mindformers/core/config_args.py index ea3f0f46ad3a0ada258a6f2c56e85701a98d7122..dbd3cc54438bf0e368c8d1d8be0072a0ece6383d 100644 --- a/mindformers/core/config_args.py +++ b/mindformers/core/config_args.py @@ -490,6 +490,7 @@ class MFContextConfig(BaseArgsConfig): 'only_save_strategy', 'use_legacy_format', 'balanced_load', + 'load_worker_number', 'run_mode', 'use_legacy', 'exclude_cann_cpu', diff --git a/mindformers/tools/register/template.py b/mindformers/tools/register/template.py index f34ad613f84166e9b0907a34f3d178e21ba72e06..993425c9144068bf5354b16a547943e5c4899ce6 100644 --- a/mindformers/tools/register/template.py +++ b/mindformers/tools/register/template.py @@ -253,6 +253,7 @@ class GeneralConfig(Config): use_legacy_format = True pretrained_model_dir = "" balanced_load = False + load_worker_number = 1 # eval while training do_eval = False diff --git a/mindformers/trainer/base_trainer.py b/mindformers/trainer/base_trainer.py index 68948aa44a6f3beb9299d29cc2add800e11c0d45..c027e03ee26e3301e74ff36277511af06aa800f7 100644 --- a/mindformers/trainer/base_trainer.py +++ b/mindformers/trainer/base_trainer.py @@ -1479,13 +1479,15 @@ class BaseTrainer: network=network, optimizer=optimizer, global_step=global_step, - balanced_load=config.balanced_load + balanced_load=config.balanced_load, + load_worker_number=config.load_worker_number if config.load_worker_number else 1 ) else: load_checkpoint( checkpoint=config.load_checkpoint, network=network, - balanced_load=config.balanced_load + balanced_load=config.balanced_load, + load_worker_number=config.load_worker_number if config.load_worker_number else 1 ) elif (config.load_checkpoint or config.only_save_strategy) and not check_is_reboot_node(): if config.resume_training: diff --git a/tests/st/test_ut/test_checkpoint/test_checkpoint.py 
b/tests/st/test_ut/test_checkpoint/test_checkpoint.py index 12dfcb5b7f7aff46ffca9dc427884d4e073abd5d..9497a81ba0850dd20a827b7a2286624a92b24c79 100644 --- a/tests/st/test_ut/test_checkpoint/test_checkpoint.py +++ b/tests/st/test_ut/test_checkpoint/test_checkpoint.py @@ -16,26 +16,29 @@ # pylint: disable=W0621 import os import json -from unittest.mock import patch +from unittest.mock import patch, MagicMock import pytest import numpy as np from mindspore import Tensor, Parameter, nn from mindspore.common import dtype as mstype +from mindformers.checkpoint.reshard import ReshardHandler from mindformers.checkpoint.checkpoint import ( AsyncSaveManager, save_checkpoint, save_metadata_json, - load_safetensor, categorize_params, - get_metadata_of_checkpoint, - params_key_mapping, load_checkpoint, - concat_params, check_the_param_for_load_ckpt, load_parameters, - get_checkpoint_path + get_checkpoint_path, + build_tensors_reformulation_all_offsets, + apply_parallel_load_strategy +) +from mindformers.checkpoint.metadata import ( + get_metadata_of_checkpoint, + params_key_mapping ) from mindformers.checkpoint.sharded_tensor import ShardedTensor @@ -333,45 +336,6 @@ def test_params_key_mapping(simple_network): assert isinstance(key_mapping, dict) -@pytest.mark.level0 -@pytest.mark.platform_x86_cpu -@pytest.mark.env_onecard -def test_concat_params(tmp_path, simple_network): - """ - Feature: Test concat_params function. - Description: Test the functionality of concat_params with mocked load_safetensor. - Expectation: The function should successfully concatenate parameters and add them to the state_dict. - """ - # Create a simple state_dict - state_dict = {} - key_mapping = {"test_param": "test_param"} - - # Create test data with sharded tensor list - sharded_tensor_list = [ - { - 'sub_name': 'test_param', - 'file_name': 'test.safetensors', - 'param_dtype': mstype.float32, - } - ] - - need_concat_params = { - "test_param": (sharded_tensor_list, []) - } - - # Mock the load_safetensor function to avoid actual file loading - # pylint: disable=W0613 - def mock_load_safetensor(checkpoint_path, param_name, index_tuple=None, dtype=None, **kwargs): - """Mock load_safetensor function.""" - return {param_name: Parameter(Tensor(np.ones((10, 10)), dtype=dtype), name=param_name)} - - with patch('mindformers.checkpoint.checkpoint.load_safetensor', side_effect=mock_load_safetensor): - concat_params(tmp_path, simple_network, key_mapping, need_concat_params, state_dict) - # Since we're mocking load_safetensor, the state_dict should contain the mocked parameter - assert "test_param" in state_dict - assert isinstance(state_dict["test_param"], Parameter) - - @pytest.mark.level0 @pytest.mark.platform_x86_cpu @pytest.mark.env_onecard @@ -406,19 +370,15 @@ def test_categorize_params(): dst_sharded_tensor_metas = {"test_param": dst_sharded_tensor} src_sharded_tensor_metas = {"test_param": [src_sharded_tensor]} - param_file_mappings = { - "('test_param', (0, 0))": [{"file_name": "test.safetensors", "storage_rank": 0}] - } # Test categorize_params with valid inputs - not_mapping_params, need_concat_params, no_shard_params, online_shard_params = categorize_params( - dst_sharded_tensor_metas, src_sharded_tensor_metas, param_file_mappings + not_mapping_params, need_concat_params, mapping_params = categorize_params( + dst_sharded_tensor_metas, src_sharded_tensor_metas ) assert isinstance(not_mapping_params, list) assert isinstance(need_concat_params, dict) - assert isinstance(no_shard_params, dict) - assert 
isinstance(online_shard_params, dict) + assert isinstance(mapping_params, dict) @pytest.mark.level0 @@ -458,32 +418,6 @@ def test_save_and_load_checkpoint(tmp_path, simple_network, optimizer): pass -@pytest.mark.level0 -@pytest.mark.platform_x86_cpu -@pytest.mark.env_onecard -def test_load_safetensor(tmp_path): - """ - Feature: Test load_safetensor function. - Description: Test the error handling of load_safetensor with non-existent file and invalid content. - Expectation: The function should raise appropriate exceptions for invalid inputs. - """ - # Test with non-existent file - non_existent_file = os.path.join(tmp_path, "non_existent.safetensors") - with pytest.raises(FileNotFoundError): - load_safetensor(non_existent_file) - - # Test with invalid parameter name - # Create a simple safetensors file for testing - # Note: This requires actual safetensors file creation, which is complex - # We'll test the error handling instead - dummy_file = os.path.join(tmp_path, "dummy.safetensors") - with open(dummy_file, "w", encoding='utf-8') as f: - f.write("dummy content") - - with pytest.raises(Exception): - load_safetensor(dummy_file, param_name="invalid_param") - - @pytest.mark.level0 @pytest.mark.platform_x86_cpu @pytest.mark.env_onecard @@ -570,3 +504,186 @@ def test_load_parameters_with_invalid_inputs(): # Test with invalid optimizer with pytest.raises(Exception): load_parameters(net, {}, optimizer="invalid_optimizer") + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_build_tensors_reformulation_all_offsets(): + """ + Feature: Test build_tensors_reformulation_all_offsets function. + Description: Test the functionality of build_tensors_reformulation_all_offsets with sharded tensor pairs. + Expectation: The function should return a dictionary with reshard_handler and all_offset for each parameter. + """ + # Create source and destination sharded tensors + src_sharded_tensor = ShardedTensor( + key="test_param", + org_key="test_param", + dtype=mstype.float32, + local_shape=(10, 10), + global_shape=(10, 10), + global_offset=(0,), + axis_fragmentations=(1, 1), + layout=None + ) + + dst_sharded_tensor = ShardedTensor( + key="test_param", + org_key="test_param", + dtype=mstype.float32, + local_shape=(10, 10), + global_shape=(10, 10), + global_offset=(0, 0), + axis_fragmentations=(1, 1), + layout=None + ) + + params_with_sharded_tensor = { + "test_param": (src_sharded_tensor, dst_sharded_tensor) + } + + # Mock get_real_rank to return 0 + with patch('mindformers.checkpoint.checkpoint.get_real_rank', return_value=0): + result = build_tensors_reformulation_all_offsets(params_with_sharded_tensor) + + assert isinstance(result, dict) + assert "test_param" in result + assert "reshard_handler" in result["test_param"] + assert "all_offset" in result["test_param"] + assert isinstance(result["test_param"]["reshard_handler"], ReshardHandler) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_apply_parallel_load_strategy(): + """ + Feature: Test apply_parallel_load_strategy function. + Description: Test the parallel loading strategy for resharding tensor slices. + Expectation: The function should correctly merge tensor slices into the state_dict. 
+ """ + # Create a mock reshard_handler + mock_handler = MagicMock(spec=ReshardHandler) + mock_handler.get_real_tensor.return_value = np.ones((10, 10), dtype=np.float32) + + # Create params_info with tensor_map + params_info = { + "test_param": { + "reshard_handler": mock_handler, + "tensor_map": { + 0: np.ones((5, 10), dtype=np.float32), + 1: np.ones((5, 10), dtype=np.float32) + } + } + } + + state_dict = {} + key_mapping = {"test_param": "mapped_param"} + + apply_parallel_load_strategy(params_info, state_dict, key_mapping, num_workers=1) + + assert "mapped_param" in state_dict + assert isinstance(state_dict["mapped_param"], Parameter) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_apply_parallel_load_strategy_without_key_mapping(): + """ + Feature: Test apply_parallel_load_strategy without key_mapping. + Description: Test the parallel loading strategy when key_mapping is None. + Expectation: The function should use original parameter names. + """ + # Create a mock reshard_handler + mock_handler = MagicMock(spec=ReshardHandler) + mock_handler.get_real_tensor.return_value = np.ones((10, 10), dtype=np.float32) + + params_info = { + "original_param": { + "reshard_handler": mock_handler, + "tensor_map": { + 0: np.ones((10, 10), dtype=np.float32) + } + } + } + + state_dict = {} + + apply_parallel_load_strategy(params_info, state_dict, key_mapping=None, num_workers=2) + + assert "original_param" in state_dict + assert isinstance(state_dict["original_param"], Parameter) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_apply_parallel_load_strategy_empty_params(): + """ + Feature: Test apply_parallel_load_strategy with empty params_info. + Description: Test the function behavior when params_info is empty. + Expectation: The function should handle empty input gracefully. + """ + state_dict = {} + apply_parallel_load_strategy({}, state_dict, key_mapping=None, num_workers=1) + assert len(state_dict) == 0 + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_categorize_params_with_empty_source(): + """ + Feature: Test categorize_params with empty source sharded tensor list. + Description: Test error handling when source metadata has empty ShardedTensor list. + Expectation: The function should raise ValueError for empty source list. + """ + dst_sharded_tensor = ShardedTensor( + key="test_param", + org_key="test_param", + dtype=mstype.float32, + local_shape=(10, 10), + global_shape=(10, 10), + global_offset=(0, 0), + axis_fragmentations=(1, 1), + layout=None + ) + + dst_sharded_tensor_metas = {"test_param": dst_sharded_tensor} + src_sharded_tensor_metas = {"test_param": []} # Empty list + + with pytest.raises(ValueError): + categorize_params(dst_sharded_tensor_metas, src_sharded_tensor_metas) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_categorize_params_not_mapping(): + """ + Feature: Test categorize_params with parameters not in source. + Description: Test that parameters missing from source are added to not_mapping_params. + Expectation: The function should correctly identify unmapped parameters. 
+ """ + dst_sharded_tensor = ShardedTensor( + key="missing_param", + org_key="missing_param", + dtype=mstype.float32, + local_shape=(10, 10), + global_shape=(10, 10), + global_offset=(0, 0), + axis_fragmentations=(1, 1), + layout=None + ) + + dst_sharded_tensor_metas = {"missing_param": dst_sharded_tensor} + src_sharded_tensor_metas = {} # Parameter not in source + + not_mapping_params, need_concat_params, params_to_load = categorize_params( + dst_sharded_tensor_metas, src_sharded_tensor_metas + ) + + assert "missing_param" in not_mapping_params + assert len(need_concat_params) == 0 + assert len(params_to_load) == 0 diff --git a/tests/st/test_ut/test_checkpoint/test_fully_parallel.py b/tests/st/test_ut/test_checkpoint/test_fully_parallel.py index 4a881f91fbc7ea4ad07b8e0d978103b19cf34343..f2caee81aa53ad28c6359f7862690956e552822d 100644 --- a/tests/st/test_ut/test_checkpoint/test_fully_parallel.py +++ b/tests/st/test_ut/test_checkpoint/test_fully_parallel.py @@ -418,8 +418,20 @@ def test_balanced_save_strategy_save_with_existing_metadata( with patch("mindformers.checkpoint.fully_parallel.os.path.exists", return_value=True): with patch("mindformers.checkpoint.fully_parallel.load_metadata") as mock_load: - mock_load.return_value = ({"shard1": MagicMock()}, - {"param1": [{"file_name": "test.safetensors", "storage_rank": 0}]}) + mock_load.return_value = ( + { + "shard1": MagicMock() + }, + { + "param1": [ + { + "file_name": "test.safetensors", + "storage_rank": 0, + "rank_group": [0] + } + ] + } + ) strategy.save(0) # Check that save_checkpoint was called diff --git a/tests/st/test_ut/test_megatron_format_checkpoint/test_metadata.py b/tests/st/test_ut/test_megatron_format_checkpoint/test_metadata.py index 2a209aeacb06ecd8ad80c773b155850c4b6b1314..24ece5b0cc0bf28b55c420edc0102c535b0eabda 100644 --- a/tests/st/test_ut/test_megatron_format_checkpoint/test_metadata.py +++ b/tests/st/test_ut/test_megatron_format_checkpoint/test_metadata.py @@ -67,6 +67,8 @@ def save_metadata_without_npu(global_strategy_info, model_keys, user_prefix, met filter_func=(lambda x: x in list(model_keys)) if not save_optimizer else None ) + sharded_tensor_metas[cur_npu_rank] = cur_rank_sharded_tensors + # Get mappings of parameter file of current rank. for _, sharded_tensor in cur_rank_sharded_tensors.items(): if save_optimizer and sharded_tensor.key not in list(model_keys): @@ -77,6 +79,7 @@ def save_metadata_without_npu(global_strategy_info, model_keys, user_prefix, met ( ckpt_name + '.safetensors', cur_npu_rank, + [cur_npu_rank], # rank_group - new required field (sharded_tensor.key, sharded_tensor.global_offset) ) ) @@ -143,6 +146,7 @@ def test_save_and_load_metadata_case(): adam_mapping_0 = has_optimizer_param_file_mappings["('adam_m.decoder.layers.0.input_layernorm.weight', (0,))"][0] assert adam_mapping_0["storage_rank"] == 0 assert adam_mapping_0["file_name"] == "my_test_net-opt-0000000-0000002.safetensors" + assert "rank_group" in adam_mapping_0 # 4. Test load 'metadata.json' without optimizer info. 
no_optimizer_sharded_tensors, no_optimizer_param_file_mappings = load_metadata( diff --git a/tests/st/test_ut/test_utils/test_tensorboard/test_tensorboard.py b/tests/st/test_ut/test_utils/test_tensorboard/test_tensorboard.py index 41395f1d6364cba1e45d1c977640081017c6cb0f..2c62c8b7b1fceb1e65fa42d261c3b4e5f2765eef 100644 --- a/tests/st/test_ut/test_utils/test_tensorboard/test_tensorboard.py +++ b/tests/st/test_ut/test_utils/test_tensorboard/test_tensorboard.py @@ -49,7 +49,7 @@ _CHECK_TEXT_MAPPING = { 'optimizer', 'parallel_config', 'parallel', 'recompute_config', 'remove_redundancy', 'runner_config', 'runner_wrapper', 'monitor_config', 'tensorboard', 'train_dataset_task', 'train_dataset', 'trainer', 'swap_config', 'use_legacy', 'pretrained_model_dir', 'print_separate_loss', 'use_legacy_format', - 'balanced_load' + 'balanced_load', 'load_worker_number' } def generator_train():
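
Reviewer note: the sketch below is not part of the patch. It is a minimal illustration of how the new metadata helpers compose, assuming the sharded-tensor metas are keyed by checkpoint parameter names (as in the default-metadata path) and that the wrapped network exposes convert_name()/convert_concat_name(). The function name preview_key_mapping, the network object, and the checkpoint path are hypothetical placeholders.

# Hypothetical sketch (not in this patch): preview how Hugging Face keys map onto the network.
from mindformers.checkpoint.metadata import get_metadata_of_checkpoint, params_key_mapping
from mindformers.checkpoint.utils import get_core_network


def preview_key_mapping(network, checkpoint_dir):
    """Print the original-to-converted key mapping for a checkpoint directory."""
    # Reads 'metadata.json' when present, otherwise derives default metadata
    # from the safetensors files in `checkpoint_dir`.
    sharded_tensor_metas, param_file_mappings = get_metadata_of_checkpoint(checkpoint_dir)

    # Unwrap wrapper cells (`network` / `_backbone` attributes) until the cell
    # that implements convert_name()/convert_concat_name() is reached.
    core_network = get_core_network(network)

    # Re-key the metas to the names the network expects; `key_mapping` records
    # which checkpoint key became which converted key.
    mapped_metas, key_mapping = params_key_mapping(sharded_tensor_metas, core_network)

    for src_key, dst_key in key_mapping.items():
        print(f"{src_key} -> {dst_key}")
    return mapped_metas, param_file_mappings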
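
Likewise, a hedged sketch of the new `load_worker_number` option: the call mirrors the plumbing added in base_trainer.py (the template default added above is 1), and `restore_weights` plus the `config` object are placeholders, not part of this change.

# Hypothetical sketch (not in this patch): enabling multi-worker checkpoint loading.
from mindformers.checkpoint.checkpoint import load_checkpoint


def restore_weights(network, config):
    # Fall back to a single worker when the option is unset, as base_trainer.py does.
    workers = config.load_worker_number if config.load_worker_number else 1
    load_checkpoint(
        checkpoint=config.load_checkpoint,   # checkpoint directory to restore from
        network=network,
        balanced_load=config.balanced_load,
        load_worker_number=workers,          # >1 lets slices be loaded and resharded concurrently
    )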