From bf9176b81bc501420d7d4be3b60530c91af7e4c3 Mon Sep 17 00:00:00 2001
From: daizhou
Date: Mon, 9 Jun 2025 18:30:35 +0800
Subject: [PATCH] pta support ms ckpt

---
 .../mindspore/training/checkpointing.py      |  13 +
 .../mindspore/training/serialization.py      | 619 ++++++++++++++++++
 mindspeed_llm/tasks/megatron_adaptor.py      |   4 +-
 3 files changed, 635 insertions(+), 1 deletion(-)
 create mode 100644 mindspeed_llm/mindspore/training/checkpointing.py
 create mode 100644 mindspeed_llm/mindspore/training/serialization.py

diff --git a/mindspeed_llm/mindspore/training/checkpointing.py b/mindspeed_llm/mindspore/training/checkpointing.py
new file mode 100644
index 000000000..45194bd78
--- /dev/null
+++ b/mindspeed_llm/mindspore/training/checkpointing.py
@@ -0,0 +1,13 @@
+from functools import wraps
+
+def load_wrapper(fn):
+    @wraps(fn)
+    def wrapper(*args, **kwargs):
+        try:
+            res = fn(*args, **kwargs)
+        except Exception:
+            # torch.load could not parse the file; fall back to the MindSpore-format reader.
+            from mindspeed_llm.mindspore.training.serialization import load
+            res = load(*args, **kwargs)
+        return res
+    return wrapper
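+
+# Illustration only (not exercised by this module): load_wrapper is a decorator-style
+# patch, so the adaptor registration below is expected to apply it to torch.load
+# roughly as follows; the checkpoint path is a placeholder.
+#     patched_load = load_wrapper(torch.load)
+#     state_dict = patched_load("checkpoint.pt")   # tries torch.load first,
+#                                                  # falls back to the MindSpore reader on failure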
diff --git a/mindspeed_llm/mindspore/training/serialization.py b/mindspeed_llm/mindspore/training/serialization.py
new file mode 100644
index 000000000..423684719
--- /dev/null
+++ b/mindspeed_llm/mindspore/training/serialization.py
@@ -0,0 +1,619 @@
+import io
+import sys
+import pickle
+import pathlib
+import zipfile
+import warnings
+import operator
+from enum import Enum
+from functools import reduce, wraps
+from typing import Dict, Union, Optional, cast
+
+import torch
+import numpy as np
+import mindspore
+from ml_dtypes import bfloat16
+
+
+def _is_path(name_or_buffer):
+    return isinstance(name_or_buffer, (str, pathlib.Path))
+
+
+class _opener:
+    def __init__(self, file_like):
+        self.file_like = file_like
+
+    def __enter__(self):
+        return self.file_like
+
+    def __exit__(self, *args):
+        pass
+
+
+class _open_file(_opener):
+    def __init__(self, name, mode):
+        super().__init__(open(name, mode))
+
+    def __exit__(self, *args):
+        self.file_like.close()
+
+
+class _open_buffer_writer(_opener):
+    def __exit__(self, *args):
+        self.file_like.flush()
+
+
+class _open_buffer_reader(_opener):
+    def __init__(self, buffer):
+        super().__init__(buffer)
+        _check_seekable(buffer)
+
+
+def _check_seekable(f) -> bool:
+
+    def raise_err_msg(patterns, e):
+        for p in patterns:
+            if p in str(e):
+                msg = (str(e) + ". You can only torch.load from a file that is seekable."
+                       + " Please pre-load the data into a buffer like io.BytesIO and"
+                       + " try to load from it instead.")
+                raise type(e)(msg)
+        raise e
+
+    try:
+        f.seek(f.tell())
+        return True
+    except (io.UnsupportedOperation, AttributeError) as e:
+        raise_err_msg(["seek", "tell"], e)
+    return False
+
+
+def _open_file_like(name_or_buffer, mode):
+    if _is_path(name_or_buffer):
+        return _open_file(name_or_buffer, mode)
+    if 'w' in mode:
+        return _open_buffer_writer(name_or_buffer)
+    if 'r' in mode:
+        return _open_buffer_reader(name_or_buffer)
+    raise RuntimeError(f"Expected 'r' or 'w' in mode but got {mode}")
+
+
+def _is_zipfile(f) -> bool:
+    """Return True if ``f`` starts with the zip local-file-header magic number.
+
+    This is stricter than ``zipfile.is_zipfile()``, which accepts the magic number
+    anywhere in the file. The files handled here are produced by torch.save (or
+    torch.jit.save), so checking the first four bytes is enough and avoids false
+    positives (see bugs.python.org/issue28494).
+    """
+    # Read the first 4 bytes of the file, then restore the original position.
+    read_bytes = []
+    start = f.tell()
+
+    byte = f.read(1)
+    while byte != b"":
+        read_bytes.append(byte)
+        if len(read_bytes) == 4:
+            break
+        byte = f.read(1)
+    f.seek(start)
+
+    local_header_magic_number = [b'P', b'K', b'\x03', b'\x04']
+    return read_bytes == local_header_magic_number
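+
+# Quick sanity check of the magic-number test above (illustrative, not required):
+#     import io
+#     _is_zipfile(io.BytesIO(b"PK\x03\x04" + b"\x00" * 26))   # True
+#     _is_zipfile(io.BytesIO(b"not a zip archive"))           # False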
+
+
+class PyTorchFileReader:
+    """Minimal reader for the zip archives produced by torch.save.
+
+    It mirrors the record-oriented interface of torch's internal PyTorchFileReader
+    so the unpickling code below can look up a checkpoint's 'data/<key>' storage
+    records without going through torch itself.
+    """
+
+    def __init__(self, file):
+        self.file = zipfile.ZipFile(file)
+        if hasattr(file, 'offset'):
+            # `file` is a wrapped stream that starts at an offset inside a larger
+            # file; copy its contents into memory before unzipping.
+            file.seek(0)
+            data = io.BytesIO(file.read(file.len))
+            self.file = zipfile.ZipFile(data)
+
+        self.directory = self.file.namelist()[0].split('/')[0]
+
+    def open_record(self, name):
+        """Open record ``name`` and return a file object, or None if it does not exist."""
+        filename = f"{self.directory}/{name}"
+        if filename in self.file.namelist():
+            return self.file.open(filename)
+        return None
+
+    def read_record(self, name):
+        """Return the bytes of record ``name``, or None if it does not exist."""
+        filename = f"{self.directory}/{name}"
+        if filename in self.file.namelist():
+            return self.file.read(filename)
+        return None
+
+    def has_record(self, name):
+        """Return True if a record named ``name`` exists in the archive."""
+        filename = f"{self.directory}/{name}"
+        return filename in self.file.namelist()
+
+    def get_all_records(self):
+        """Return the names of all records, with the leading archive directory stripped."""
+        return [name.replace(self.directory + '/', '') for name in self.file.namelist()]
+
+    def get_record_offset(self, name):
+        """Return the header offset of record ``name`` in the archive, or None if it does not exist."""
+        filename = f"{self.directory}/{name}"
+        if filename in self.file.namelist():
+            return self.file.getinfo(filename).header_offset
+        return None
+
+
+class _open_zipfile_reader(_opener):
+    """Context manager that wraps ``name_or_buffer`` in a PyTorchFileReader."""
+
+    def __init__(self, name_or_buffer) -> None:
+        super().__init__(PyTorchFileReader(name_or_buffer))
+
+
+def _is_torchscript_zip(zip_file):
+    """Return True if the archive contains 'constants.pkl', i.e. it is a TorchScript archive."""
+    return 'constants.pkl' in zip_file.get_all_records()
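+
+# Illustrative use of the reader on an in-memory torch.save archive (record names are
+# typical; exact contents depend on the torch version):
+#     import io, torch
+#     buf = io.BytesIO()
+#     torch.save({"w": torch.zeros(2, 2)}, buf)
+#     buf.seek(0)
+#     reader = PyTorchFileReader(buf)
+#     reader.get_all_records()      # e.g. ['data.pkl', 'byteorder', 'data/0', 'version']
+#     reader.has_record('data.pkl') # True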
+
+
+class LoadEndianness(Enum):
+    """Byte order used to interpret storages in a checkpoint.
+
+    NATIVE uses the byte order of the running platform; LITTLE and BIG force
+    little- and big-endian interpretation. The value is only consulted as a
+    fallback when the checkpoint carries no 'byteorder' record.
+    """
+    NATIVE = 1
+    LITTLE = 2
+    BIG = 3
+
+
+_default_load_endian: Optional[LoadEndianness] = None
+
+
+def get_default_load_endianness() -> Optional[LoadEndianness]:
+    """Return the fallback byte order for checkpoints without a byteorder mark.
+
+    By default this is None, in which case such checkpoints are read as
+    little-endian (see the warning emitted in _load on big-endian hosts).
+    """
+    return _default_load_endian
+
+
+def _maybe_decode_ascii(bytes_str: Union[bytes, str]) -> str:
+    """Decode ``bytes_str`` as ASCII if it is bytes, otherwise return it unchanged.
+
+    With encoding='bytes', some internal keys written from Python 2 (e.g. `typename`
+    and `location` in `persistent_load` below) are loaded as bytes and need to be
+    decoded before they can be compared with str constants.
+    """
+    if isinstance(bytes_str, bytes):
+        return bytes_str.decode('ascii')
+    return bytes_str
+
+
+dtype_map = {
+    "HalfStorage": np.float16,
+    "FloatStorage": np.float32,
+    "BFloat16Storage": bfloat16,
+    "LongStorage": np.int64,
+    "ByteStorage": np.uint8,
+    "BoolStorage": np.bool_
+}
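+
+# Example of the mapping in action (on a little-endian host): the raw bytes
+# 00 00 80 3f decode to 1.0 when read with the numpy dtype registered for FloatStorage.
+#     np.frombuffer(b"\x00\x00\x80?", dtype=dtype_map["FloatStorage"])  # array([1.], dtype=float32)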
+
+
+def load(f, map_location=None, pickle_module=pickle, *, weights_only=False, mmap=None, **pickle_load_args):
+    """Load a checkpoint saved in torch's zipfile format, without going through torch.load.
+
+    Args:
+        f (str, pathlib.Path or file-like object): The checkpoint to read.
+        map_location: Accepted for torch.load compatibility; currently ignored.
+        pickle_module (module): Module used for unpickling. Defaults to ``pickle``.
+        weights_only (bool): Accepted for torch.load compatibility; currently ignored.
+        mmap (bool): Accepted for torch.load compatibility; currently ignored.
+        **pickle_load_args: Extra keyword arguments forwarded to the unpickler.
+
+    Returns:
+        The deserialized object, typically a (possibly nested) state dict. Inputs
+        that are not zip archives are not handled and yield None.
+
+    Raises:
+        ValueError: If the archive is a TorchScript archive, which is not supported here.
+    """
+    if pickle_module is None:
+        pickle_module = pickle
+
+    if 'encoding' not in pickle_load_args:
+        pickle_load_args['encoding'] = 'utf-8'
+
+    with _open_file_like(f, 'rb') as opened_file:
+        if _is_zipfile(opened_file):
+            overall_storage = None
+            with _open_zipfile_reader(opened_file) as opened_zipfile:
+                if _is_torchscript_zip(opened_zipfile):
+                    raise ValueError('do not support torchscript now')
+                return _load(opened_zipfile,
+                             map_location,
+                             pickle_module,
+                             overall_storage=overall_storage,
+                             **pickle_load_args)
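+
+# Illustrative usage (the path is a placeholder): read one checkpoint shard directly
+# with the numpy-based reader instead of torch.load.
+#     state_dict = load("iter_0000100/mp_rank_00/model_optim_rng.pt")
+#     print(sorted(state_dict.keys()))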
+
+
+def _load(zip_file, map_location, pickle_module, pickle_file='data.pkl', overall_storage=None, **pickle_load_args):
+    """Deserialize the pickle record of an opened checkpoint archive.
+
+    Args:
+        zip_file (PyTorchFileReader): The opened checkpoint archive.
+        map_location: Accepted for torch.load compatibility; currently ignored.
+        pickle_module (module): Module used for unpickling.
+        pickle_file (str): Name of the pickle record inside the archive. Defaults to 'data.pkl'.
+        overall_storage (optional): Backing file for memory-mapped loading; when given,
+            storages are created with numpy.memmap instead of being read into memory.
+        **pickle_load_args: Extra keyword arguments forwarded to the unpickler.
+
+    Returns:
+        The unpickled object, with MindSpore dtypes mapped to torch dtypes.
+
+    Raises:
+        ValueError: If the 'byteorder' record holds an unknown value, or the configured
+            default load endianness is invalid.
+    """
+    loaded_storages = {}
+    # check if byteswapping is needed
+    byteordername = 'byteorder'
+    byteorderdata = None
+    if zip_file.has_record(byteordername):
+        byteorderdata = zip_file.read_record(byteordername)
+        if byteorderdata not in [b'little', b'big']:
+            raise ValueError('Unknown endianness type: ' + byteorderdata.decode())
+    elif get_default_load_endianness() == LoadEndianness.LITTLE or \
+            get_default_load_endianness() is None:
+        byteorderdata = b'little'
+    elif get_default_load_endianness() == LoadEndianness.BIG:
+        byteorderdata = b'big'
+    elif get_default_load_endianness() == LoadEndianness.NATIVE:
+        pass
+    else:
+        raise ValueError('Invalid load endianness type')
+
+    if not zip_file.has_record(byteordername) and \
+            get_default_load_endianness() is None and \
+            sys.byteorder == 'big':
+        # Default behaviour was changed
+        # See https://github.com/pytorch/pytorch/issues/101688
+        warnings.warn("The default load endianness for checkpoints without a byteorder mark "
+                      "on big endian machines was changed from 'native' to 'little' endian, "
+                      "to avoid this behavior please use "
+                      "torch.serialization.set_default_load_endianness to set "
+                      "the desired default load endianness",
+                      UserWarning)
+
+    def persistent_load(saved_id):
+        assert isinstance(saved_id, tuple)
+        typename = _maybe_decode_ascii(saved_id[0])
+        data = saved_id[1:]
+
+        assert typename == 'storage', \
+            f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'"
+        storage_type, key, location, numel = data
+
+        # msadapter branch
+        name = f'data/{key}'
+        if name in loaded_storages:
+            return loaded_storages[name]
+
+        if overall_storage is not None:
+            array = np.memmap(overall_storage, dtype=dtype_map[storage_type],
+                              offset=zip_file.open_record(name)._fileobj.tell(), shape=(numel,))
+        else:
+            array = np.frombuffer(zip_file.read_record(name), dtype_map[storage_type])
+        loaded_storages[name] = array
+        return array
+
+    load_module_mapping: Dict[str, str] = {
+        # See https://github.com/pytorch/pytorch/pull/51633
+        'torch.tensor': 'torch._tensor'
+    }
+
+    # Need to subclass Unpickler instead of directly monkey-patching the find_class method
+    # because it's marked readonly in pickle.
+    # The type: ignore is because mypy can't statically determine the type of this class.
+    class UnpicklerWrapper(pickle_module.Unpickler):  # type: ignore[name-defined]
+        # from https://stackoverflow.com/questions/13398462/unpickling-python-objects-with-a-changed-module-path/13405732
+        # Lets us override the imports that pickle uses when unpickling an object.
+        # This is useful for maintaining BC if we change a module path that tensor instantiation relies on.
+        def find_class(self, mod_name, name):
+            if mod_name == 'torch._utils':
+                # Resolve torch._utils rebuild helpers (e.g. _rebuild_tensor_v2) to the
+                # numpy-based reimplementations defined in this module.
+                return eval(name)
+            if mod_name == 'torch':
+                # torch attributes (dtypes, storage classes) are represented by their names.
+                return str(name)
+            if mod_name == 'torch._tensor':
+                return eval(name)
+            mod_name = load_module_mapping.get(mod_name, mod_name)
+            return super().find_class(mod_name, name)
+
+    # Load the data (which may in turn use `persistent_load` to load tensors)
+    data_file = zip_file.open_record(pickle_file)
+
+    unpickler = UnpicklerWrapper(data_file, **pickle_load_args)
+    unpickler.persistent_load = persistent_load
+    result = unpickler.load()
+    result = transform_ms_dtype_to_pt_dtype(result)
+
+    return result
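+
+# For reference: after UnpicklerWrapper.find_class maps torch storage classes to their
+# class names, a persistent id recorded by torch.save looks like, for example,
+#     ('storage', 'FloatStorage', '0', 'cpu', 4)
+# and persistent_load() above resolves it to the float32 array stored in record 'data/0'.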
+
+
+DTYPE_MAP = {
+    mindspore.float32: torch.float32,
+    mindspore.bfloat16: torch.bfloat16
+}
+
+
+def transform_ms_dtype_to_pt_dtype(state):
+    """Recursively rewrite 2-element tuples of MindSpore dtypes used as dict keys
+    into the corresponding torch dtypes; nested dicts and lists are walked, all
+    other values are returned unchanged."""
+    if isinstance(state, dict):
+        new_state_dict = {}
+        for k, v in state.items():
+            new_key = k
+            v = transform_ms_dtype_to_pt_dtype(v)
+            if isinstance(k, tuple) and len(k) == 2:
+                new_key = []
+                for ms_dtype in k:
+                    pt_dtype = DTYPE_MAP.get(ms_dtype)
+                    if pt_dtype is None:
+                        raise ValueError(f"convert error, unsupported dtype {ms_dtype}")
+                    new_key.append(pt_dtype)
+                new_key = tuple(new_key)
+            new_state_dict[new_key] = v
+        return new_state_dict
+    if isinstance(state, list):
+        return [transform_ms_dtype_to_pt_dtype(member) for member in state]
+    return state
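+
+# Illustrative call:
+#     transform_ms_dtype_to_pt_dtype({(mindspore.float32, mindspore.bfloat16): {"step": 10}})
+#     # -> {(torch.float32, torch.bfloat16): {'step': 10}}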
+ ''' + + if size == (): + num_elemets = 1 + else: + num_elemets = reduce(operator.mul, size) + array = storage[storage_offset: storage_offset + num_elemets] + + if stride is not None and len(stride) > 1 and stride[0] == 1: + # stride = tuple((s * 4 for s in stride)) + # # stride = tuple((s * 4 if s != 1 else s for s in stride)) + # array = np.lib.stride_tricks.as_strided(array, size, stride) + order = "F" + array = array.reshape(size, order=order) + else: + order = "C" + array = array.reshape(size, order=order) + if array.dtype == bfloat16: + param = torch.frombuffer(array.tobytes(), dtype=torch.bfloat16).reshape(array.shape) + else: + param = torch.from_numpy(array) + return param + + +def _rebuild_from_type_v2(func, new_type, args, state): + ret = func(*args) + return ret + + +if __name__ == "__main__": + state_dict = load(".pt") + + def recursive_print(state, prefix=None): + if isinstance(state, dict): + for k, v in state.items(): + prefix.append(str(k)) + recursive_print(v, prefix) + elif isinstance(state, list): + for i, member in enumerate(state): + prefix.append(str(i)) + recursive_print(member, prefix) + elif isinstance(state, torch.Tensor): + state_name = ".".join(prefix) + print(f"{state_name} {state.dtype} {state.size()} {state.sum()}", flush=True) + else: + state_name = ".".join(prefix) + print(f"{state_name} {type(state)} {state}", flush=True) + recursive_print(state_dict, []) diff --git a/mindspeed_llm/tasks/megatron_adaptor.py b/mindspeed_llm/tasks/megatron_adaptor.py index 5a3c51448..ba60aaedd 100644 --- a/mindspeed_llm/tasks/megatron_adaptor.py +++ b/mindspeed_llm/tasks/megatron_adaptor.py @@ -663,6 +663,8 @@ class CoreAdaptation(MegatronAdaptationABC): gpt_dataset_getitem_wrapper) MegatronAdaptation.register('megatron.core.datasets.gpt_dataset._get_ltor_masks_and_position_ids', _get_ltor_masks_and_position_ids) + from mindspeed_llm.mindspore.training.checkpointing import load_wrapper + MegatronAdaptation.register('torch.load', load_wrapper) def patch_utils(self): @@ -956,4 +958,4 @@ class LegacyAdaptation(MegatronAdaptationABC): MegatronAdaptation.register('torch.distributed.get_world_size', get_world_size_wrapper) MegatronAdaptation.register('megatron.training.initialize.initialize_megatron', initialize_megatron, force_patch=True) -MegatronAdaptation.execute() \ No newline at end of file +MegatronAdaptation.execute() -- Gitee