diff --git a/toolbox/ColossalAI/v0.4.4/patches/.gitignore b/toolbox/ColossalAI/v0.4.4/patches/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..c5f09140433fd513e8f6558a203bbd127eb15b95 --- /dev/null +++ b/toolbox/ColossalAI/v0.4.4/patches/.gitignore @@ -0,0 +1,171 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +build_pip/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ +docs/.build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# IDE +.idea/ +.vscode/ + +# macos +*.DS_Store +#data/ + +docs/.build + +# pytorch checkpoint +*.pt + +# ignore version.py generated by setup.py +colossalai/version.py + +# ignore any kernel build files +.o +.so + +# ignore python interface defition file +.pyi + +# ignore coverage test file +coverage.lcov +coverage.xml + +# ignore testmon and coverage files +.coverage +.testmondata* + +# log, test files - ColossalChat +applications/ColossalChat/logs +applications/ColossalChat/tests/logs + +examples/language/mixtral/profile + +# ignore nsys report files +*.nsys-rep \ No newline at end of file diff --git a/toolbox/ColossalAI/v0.4.4/patches/applications/Colossal-LLaMA/colossal_llama/dataset/loader.py b/toolbox/ColossalAI/v0.4.4/patches/applications/Colossal-LLaMA/colossal_llama/dataset/loader.py new file mode 100644 index 0000000000000000000000000000000000000000..eed42eb49a8e2fd1fcd2efaed26458b8bbe89859 --- /dev/null +++ b/toolbox/ColossalAI/v0.4.4/patches/applications/Colossal-LLaMA/colossal_llama/dataset/loader.py @@ -0,0 +1,181 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +import os +from dataclasses import dataclass +from typing import Dict, Iterator, List, Optional, Sequence, Union + +import torch +import torch.nn.functional as F +from datasets import Dataset as HFDataset +from datasets import dataset_dict, load_from_disk +from torch.utils.data import ConcatDataset, Dataset, DistributedSampler +from transformers.tokenization_utils import PreTrainedTokenizer + +DatasetType = Union[Dataset, ConcatDataset, dataset_dict.Dataset] +PathType = Union[str, os.PathLike] + + +def load_tokenized_dataset( + dataset_parrent_path: Union[PathType, List[PathType]], mode: str = "train" +) -> Optional[DatasetType]: + """ + Load pre-tokenized dataset. + Each instance of dataset is a dictionary with + `{'input_ids': List[int], 'labels': List[int], sequence: str}` format. + """ + mode_map = {"train": "train", "dev": "validation", "test": "test"} + assert mode in tuple(mode_map), f"Unsupported mode {mode}, it must be in {tuple(mode_map)}" + if dataset_parrent_path: + dataset_paths=[] + for dirname in os.listdir(dataset_parrent_path): + dataset_paths.append(os.path.join(dataset_parrent_path, dirname)) + + # if isinstance(dataset_paths, (str, os.PathLike)): + # dataset_paths = [dataset_paths] + + datasets = [] # `List[datasets.dataset_dict.Dataset]` + for ds_path in dataset_paths: + ds_path = os.path.abspath(ds_path) + assert os.path.exists(ds_path), f"Not existed file path {ds_path}" + ds_dict = load_from_disk(dataset_path=ds_path, keep_in_memory=False) + if isinstance(ds_dict, HFDataset): + datasets.append(ds_dict) + else: + if mode_map[mode] in ds_dict: + datasets.append(ds_dict[mode_map[mode]]) + if len(datasets) == 0: + return None + if len(datasets) == 1: + return datasets.pop() + return ConcatDataset(datasets=datasets) + + +@dataclass +class DataCollatorForSupervisedDataset(object): + """ + Collate instances for supervised dataset. + Each instance is a tokenized dictionary with fields + `input_ids`(List[int]), `labels`(List[int]) and `sequence`(str). + """ + + tokenizer: PreTrainedTokenizer + max_length: int = 4096 + ignore_index: int = -100 + padding: str = "max_length" + + def __call__(self, instances: Sequence[Dict[str, List[int]]]) -> Dict[str, torch.Tensor]: + """ + + Args: + instances (`Sequence[Dict[str, List[int]]]`): + Mini-batch samples, each sample is stored in an individual dictionary. + + Returns: + (`Dict[str, torch.Tensor]`): Contains the following `torch.Tensor`: + `input_ids`: `torch.Tensor` of shape (bsz, max_len); + `attention_mask`: `torch.BoolTensor` of shape (bsz, max_len); + `labels`: `torch.Tensor` of shape (bsz, max_len), which contains `IGNORE_INDEX`. + """ + assert isinstance(self.tokenizer.pad_token_id, int) and self.tokenizer.pad_token_id >= 0, ( + f"`{self.tokenizer.__class__.__name__}.pad_token_id` must be a valid non-negative integer index value, " + f"but now `{self.tokenizer.pad_token_id}`" + ) + + # `List[torch.Tensor]` + batch_input_ids = [ + ( + torch.LongTensor(instance["input_ids"][: self.max_length]) + if len(instance["input_ids"]) > self.max_length + else torch.LongTensor(instance["input_ids"][:-1]) + ) + for instance in instances + ] + batch_labels = [ + ( + torch.LongTensor(instance["labels"][1: self.max_length+1]) + if len(instance["labels"]) > self.max_length + else torch.LongTensor(instance["labels"])[1:] + ) + for instance in instances + ] + + if self.tokenizer.padding_side == "right": + input_ids = torch.nn.utils.rnn.pad_sequence( + sequences=batch_input_ids, + batch_first=True, + padding_value=self.tokenizer.pad_token_id, + ) # (bsz, max_len) + labels = torch.nn.utils.rnn.pad_sequence( + sequences=batch_labels, + batch_first=True, + padding_value=self.ignore_index, + ) # (bsz, max_len) + if self.padding == "max_length": + # pad to max + to_pad = self.max_length - input_ids.size(1) + input_ids = F.pad(input_ids, (0, to_pad), value=self.tokenizer.pad_token_id) + labels = F.pad(labels, (0, to_pad), value=self.ignore_index) + elif self.tokenizer.padding_side == "left": + reversed_input_ids = [seq.flip(dims=(0,)) for seq in batch_input_ids] + reversed_input_ids = torch.nn.utils.rnn.pad_sequence( + sequences=reversed_input_ids, + batch_first=True, + padding_value=self.tokenizer.pad_token_id, + ) # (bsz, max_len) + input_ids = torch.flip(reversed_input_ids, dims=(1,)) # (bsz, max_len) + reversed_labels = [seq.flip(dims=(0,)) for seq in batch_labels] + reversed_labels = torch.nn.utils.rnn.pad_sequence( + sequences=reversed_labels, + batch_first=True, + padding_value=self.ignore_index, + ) # (bsz, max_len) + labels = torch.flip(reversed_labels, dims=(1,)) # (bsz, max_len) + else: + raise RuntimeError( + f"`{self.tokenizer.__class__.__name__}.padding_side` can only be `left` or `right`, " + f"but now `{self.tokenizer.padding_side}`" + ) + + attention_mask = input_ids.ne(self.tokenizer.pad_token_id) # `torch.BoolTensor`, (bsz, max_len) + + return dict(input_ids=input_ids, attention_mask=attention_mask, labels=labels) + + +class StatefulDistributedSampler(DistributedSampler): + """ + Stateful distributed sampler for multi-stage training. + """ + + def __init__( + self, + dataset: DatasetType, + num_replicas: Optional[int] = None, + rank: Optional[int] = None, + shuffle: bool = True, + seed: int = 0, + drop_last: bool = False, + ) -> None: + super().__init__( + dataset=dataset, + num_replicas=num_replicas, + rank=rank, + shuffle=shuffle, + seed=seed, + drop_last=drop_last, + ) + self.start_index = 0 + + def __iter__(self) -> Iterator: + iterator = super().__iter__() + indices = list(iterator) + indices = indices[self.start_index :] + return iter(indices) + + def __len__(self) -> int: + return self.num_samples - self.start_index + + def set_start_index(self, start_index: int) -> None: + self.start_index = start_index diff --git a/toolbox/ColossalAI/v0.4.4/patches/applications/Colossal-LLaMA/colossal_llama/dataset/spliced_and_tokenized_dataset.py b/toolbox/ColossalAI/v0.4.4/patches/applications/Colossal-LLaMA/colossal_llama/dataset/spliced_and_tokenized_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..0b79d7fdc39fea1f02f9733621236bbb4d1db7a5 --- /dev/null +++ b/toolbox/ColossalAI/v0.4.4/patches/applications/Colossal-LLaMA/colossal_llama/dataset/spliced_and_tokenized_dataset.py @@ -0,0 +1,331 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Splicing multiple pre-tokenized sequence data points +""" + +import bisect +import random +import warnings +from copy import deepcopy +from typing import Any, Callable, Dict, Iterable, List, Tuple, Union + +from datasets import dataset_dict +from torch.utils.data import ConcatDataset, Dataset, IterableDataset +from transformers import AutoTokenizer +from transformers.models.llama.tokenization_llama import LlamaTokenizer +from transformers.tokenization_utils import PreTrainedTokenizer + +from colossalai.logging import get_dist_logger + +from .conversation import Conversation, default_conversation + +logger = get_dist_logger() + +IGNORE_INDEX = -100 + +DSType = Union[Dataset, ConcatDataset, dataset_dict.Dataset] + +def supervised_tokenize_pretrain_webtext( + data_point: Dict[str, str], tokenizer: LlamaTokenizer, ignore_index: int = None, max_length: int = 4096 +) -> Dict[str, Union[int, str, List[int]]]: + + """ + A tokenization function to tokenize an original pretraining data point as following: + {"id": 0, "text": "Beijing, the capital of the People's Republic of China, ...", "length": 124,"ended": False} + """ + assert tokenizer.add_bos_token is False and tokenizer.add_eos_token is False, ( + "Initially set `tokenizer.add_bos_token` and `tokenizer.add_eos_token` to False, " + "add and manually later" + ) + + text = data_point["text"] + sequence_text = tokenizer.bos_token + text + tokenizer.eos_token + sequence_input_ids = tokenizer(sequence_text)["input_ids"] + sequence_labels = deepcopy(sequence_input_ids) + if len(sequence_input_ids) > max_length: + sequence_input_ids = sequence_input_ids[:max_length] + sequence_labels = sequence_labels[:max_length] + + return dict( + input_ids=sequence_input_ids, + labels=sequence_labels, + seq_length=len(sequence_input_ids), + ) + + + +def supervised_tokenize_pretrain( + data_point: Dict[str, str], tokenizer: LlamaTokenizer, ignore_index: int = None, max_length: int = 4096 +) -> Dict[str, Union[int, str, List[int]]]: + """ + A tokenization function to tokenize an original pretraining data point as following: + {"source": "", "target": "Beijing, the capital of the People's Republic of China, ...", "category": "geography"} + """ + assert tokenizer.add_bos_token is False and tokenizer.add_eos_token is False, ( + "Initially set `tokenizer.add_bos_token` and `tokenizer.add_eos_token` to False, " + "add and manually later" + ) + if ignore_index is None: + ignore_index = IGNORE_INDEX + + source_text = data_point["source"] # `str` + target_text = data_point["target"] # `str` + is_null_source = len(source_text) == 0 + + source_text = tokenizer.bos_token + source_text + target_text += tokenizer.eos_token + sequence_text = source_text + target_text + + tokenized = tokenizer([source_text, sequence_text])["input_ids"] + sequence_input_ids = tokenized[1] + sequence_labels = deepcopy(sequence_input_ids) + + source_length = len(tokenized[0]) + if not is_null_source: + sequence_labels[:source_length] = [ignore_index for _ in range(source_length)] + + # sequence truncation. + if len(sequence_input_ids) > max_length: + sequence_input_ids = sequence_input_ids[:max_length] + sequence_labels = sequence_labels[:max_length] + + return dict( + input_ids=sequence_input_ids, + labels=sequence_labels, + seq_length=len(sequence_input_ids), + seq_category=data_point["category"], + ) + + +def supervised_tokenize_sft( + data_point: Dict[str, str], + tokenizer: AutoTokenizer, + conversation_template: Conversation = default_conversation, + ignore_index: int = None, + max_length: int = 4096, +) -> Dict[str, Union[int, str, List[int]]]: + """ + A tokenization function to tokenize an original supervised data point as following: + {"messages": [{"from": "human", "content": "xxx"}, {"from": "assistant", "content": "xxx"}]} + """ + assert tokenizer.add_bos_token is False and tokenizer.add_eos_token is False, ( + "Initially set `tokenizer.add_bos_token` and `tokenizer.add_eos_token` to False, " + "add and manually later" + ) + + assert ( + tokenizer.bos_token == conversation_template.seps[0] and tokenizer.eos_token == conversation_template.seps[1] + ), "`bos_token` and `eos_token` should be the same with `conversation_template.seps`." + + if ignore_index is None: + ignore_index = IGNORE_INDEX + + messages = data_point["messages"] + template = deepcopy(conversation_template) + template.messages = [] + + for mess in messages: + from_str = mess["from"] + if from_str.lower() == "human": + from_str = template.roles[0] + elif from_str.lower() == "assistant": + from_str = template.roles[1] + else: + raise ValueError(f"Unsupported role {from_str.lower()}") + + template.append_message(from_str, mess["content"]) + + if len(template.messages) % 2 != 0: + template.messages = template.messages[0:-1] + + # `target_turn_index` is the number of turns which exceeds `max_length - 1` for the first time. + turns = [i for i in range(1, len(messages) // 2 + 1)] + target_turn_index = bisect.bisect_right( + turns, + max_length - 1, + key=lambda x: len(tokenizer([template.get_prompt(2 * x)], add_special_tokens=False)["input_ids"][0]), + ) + + # The tokenized length for first turn already exceeds `max_length - 1`. + if target_turn_index - 1 < 0: + return dict( + input_ids=None, + labels=None, + inputs_decode=None, + labels_decode=None, + seq_length=None, + seq_category=None, + ) + + target_turn = turns[target_turn_index - 1] + prompt = template.get_prompt(2 * target_turn) + tokenized = tokenizer([prompt], add_special_tokens=False)["input_ids"][0] + + template.messages = template.messages[0 : 2 * target_turn] + + starts = [] + ends = [] + gpt_bos = False if template.messages[0][0] == template.roles[0] else True + gpt_eos = False if template.messages[0][0] == template.roles[0] else True + + for i, token_id in enumerate(tokenized): + if token_id == tokenizer.bos_token_id: + if gpt_bos: + starts.append(i) + gpt_bos = not gpt_bos + elif token_id == tokenizer.eos_token_id: + if gpt_eos: + ends.append(i) + gpt_eos = not gpt_eos + + if len(starts) != target_turn or len(ends) != target_turn: + logger.info( + "Please check whether the tokenizer add additional `bos_token` and `eos_token`.\n\nOr the original message contains `bos_token` or `eos_token`." + ) + return dict( + input_ids=None, + labels=None, + inputs_decode=None, + labels_decode=None, + seq_length=None, + seq_category=None, + ) + + tokenized = [tokenizer.bos_token_id] + tokenized + labels = [ignore_index] * len(tokenized) + for start, end in zip(starts, ends): + labels[start + 1 : end + 2] = tokenized[start + 1 : end + 2] + + labels_decode = deepcopy(labels) + for i, z in enumerate(labels_decode): + if z == ignore_index: + labels_decode[i] = tokenizer.unk_token_id + + # `inputs_decode` and `labels_decode` can be used to check whether the tokenization method is true. + return dict( + input_ids=tokenized, + labels=labels, + inputs_decode=tokenizer.decode(tokenized), + labels_decode=tokenizer.decode(labels_decode), + seq_length=len(tokenized), + seq_category=data_point["category"] if "category" in data_point else "None", + ) + + +class ClosedToConstantLengthSplicedDataset(IterableDataset): + """ + Define an iterable dataset that returns a (close to) constant length data point spliced from multiple + original independent (pre-tokenized) data points. + """ + + def __init__( + self, + dataset: DSType, + tokenizer: PreTrainedTokenizer, + max_length: int = 4096, + num_packed_sequences: int = 8, + fetch_sequence_func: Callable[[Any], Tuple[List[int], List[int]]] = None, + input_ids_field: str = "input_ids", + labels_field: str = "labels", + infinite: bool = False, + shuffle: bool = True, + error_strict: bool = False, + ) -> None: + self.tokenizer = tokenizer + self.dataset = dataset + self.max_length = max_length + self.infinite = infinite + self.max_buffer_size = max_length * num_packed_sequences # e.g., 4096 * 16 + self.shuffle = shuffle + + # Callable[[Dict[str, Any]], Tuple[List[int], List[int]]], + # A function that fetch sequence input_ids and labels from the original data point + if fetch_sequence_func is None: + self.fetch_sequence_func = lambda data_point: (data_point[input_ids_field], data_point[labels_field]) + else: + self.fetch_sequence_func = fetch_sequence_func + self.input_ids_field = input_ids_field + self.labels_field = labels_field + + self.error_strict = error_strict + self.current_size = 0 # `int`, current packed data size. + + def __len__(self) -> int: + return len(self.dataset) + + def __iter__(self) -> Iterable[Dict[str, List[int]]]: + iterator = iter(self.dataset) + more_data_points = True + while more_data_points is True: + buffer, buffer_len = [], 0 + while True: + # ending condition. + if buffer_len >= self.max_buffer_size: + break + try: + # `Tuple[List[int], List[int]]` + seq_input_ids, seq_labels = self.fetch_sequence_func(next(iterator)) + buffer.append({self.input_ids_field: seq_input_ids, self.labels_field: seq_labels}) + buffer_len += len(buffer[-1][self.input_ids_field]) + except StopIteration: + if self.infinite is True: + iterator = iter(self.dataset) + warnings.warn("The dataset reached end and the iterator is reset to the start.") + else: + more_data_points = False + break + examples = [] # `List[Dict[str, List[int]]]`, save buffered spliced data points. + spliced_input_ids, spliced_labels = [], [] # `List[int]`, `List[int]` + for i, data_point in enumerate(buffer): + # TODO(2023-09-18) check errors for each unspliced tokenized data point + seq_input_ids = data_point[self.input_ids_field] + seq_labels = data_point[self.labels_field] + # Handle special case: + # If the length of an original data point (i.e., input_ids length of a data point before splicing) + # exceeds `max_length`, truncate it. + if len(seq_input_ids) > self.max_length: + truncated_seq_input_ids = seq_input_ids[: self.max_length] + truncated_label_ids = seq_labels[: self.max_length] + if set(truncated_label_ids) == {IGNORE_INDEX}: + if self.error_strict is True: + raise ValueError( + f"Find an out-of-bounds length({len(seq_input_ids)}) data point " + f"with all label values as {IGNORE_INDEX}." + ) + else: + warnings.warn(f"Filter an error truncated data point (labels all {IGNORE_INDEX})") + continue # Skip the current error data point. + spliced_data_point = { + self.input_ids_field: truncated_seq_input_ids, + self.labels_field: truncated_label_ids, + } + examples.append(spliced_data_point) + warnings.warn("Find a data point to be truncated.") + continue + + # Pre action judgment. + if len(spliced_input_ids) + len(seq_input_ids) > self.max_length: + spliced_input_ids.extend(seq_input_ids) + spliced_labels.extend(seq_labels) + spliced_data_point = { + self.input_ids_field: spliced_input_ids[:self.max_length], + self.labels_field: spliced_labels[:self.max_length], + } # `Dict[str, List[int]]` + # Update. + spliced_input_ids, spliced_labels = [], [] + examples.append(spliced_data_point) + else: + spliced_input_ids.extend(seq_input_ids) + spliced_labels.extend(seq_labels) + # For residual spliced data point at the end of the data set + if self.infinite is False and more_data_points is False and len(spliced_input_ids) > 0: + examples.append({self.input_ids_field: spliced_input_ids, self.labels_field: spliced_labels}) + if self.shuffle: + random.shuffle(examples) + for spliced_data_point in examples: + # TODO(2023-09-18): check errors for each spliced tokenized data point. + self.current_size += 1 + yield spliced_data_point diff --git a/toolbox/ColossalAI/v0.4.4/patches/applications/Colossal-LLaMA/dataset/convert_data.py b/toolbox/ColossalAI/v0.4.4/patches/applications/Colossal-LLaMA/dataset/convert_data.py new file mode 100644 index 0000000000000000000000000000000000000000..9d2895aa26af599b8afaa67588a2fd3c743e5b0d --- /dev/null +++ b/toolbox/ColossalAI/v0.4.4/patches/applications/Colossal-LLaMA/dataset/convert_data.py @@ -0,0 +1,38 @@ +#!/usr/bin/env python3 +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json + +with open('dataset/school_math/school_math_0.25M.jsonl', 'r', encoding='utf-8') as file: + lines=file.readlines() + +res_datas=[] +for line in lines: + data=json.loads(line.strip()) + human_content=data["conversation"][0]["human"] + assistant_content=data["conversation"][0]["assistant"] + + Res_data={"messages": [{"from": "human", "content": human_content}, {"from": "assistant", "content": assistant_content}]} + + res_datas.append(Res_data) + # print(Res_data) + if len(res_datas) > 20000: + break + +with open('dataset/school_math/convert/school_math_0.25M_convert.jsonl', 'w', encoding='utf-8') as file: + for res_data in res_datas: + file.write(json.dumps(res_data, ensure_ascii=False)+'\n') + + diff --git a/toolbox/ColossalAI/v0.4.4/patches/applications/Colossal-LLaMA/dataset/prepare_pretrain_dataset.py b/toolbox/ColossalAI/v0.4.4/patches/applications/Colossal-LLaMA/dataset/prepare_pretrain_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..ef2d3e8eacd0a8fe5d29a42de4c35acf7cf8c29a --- /dev/null +++ b/toolbox/ColossalAI/v0.4.4/patches/applications/Colossal-LLaMA/dataset/prepare_pretrain_dataset.py @@ -0,0 +1,153 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Prepare dataset for continual pre-training +""" + +import argparse +import json +import math +import os +import time +from multiprocessing import cpu_count + +import sys +sys.path.append(os.path.join(os.path.dirname(__file__), "../")) +from colossal_llama.dataset.spliced_and_tokenized_dataset import ( + ClosedToConstantLengthSplicedDataset, + supervised_tokenize_pretrain, + supervised_tokenize_pretrain_webtext +) +from datasets import dataset_dict, load_dataset +from transformers import AutoTokenizer + +from colossalai.logging import get_dist_logger + +logger = get_dist_logger() + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--data_input_dirs", + type=str, + required=True, + default=None, + help="Comma(i.e., ',') separated list of all data directories containing `.jsonl` data files.", + ) + parser.add_argument( + "--tokenizer_dir", type=str, required=True, default=None, help="A directory containing the tokenizer" + ) + parser.add_argument("--data_output_dirs", type=str, default="data_output_dirs", help="Data output directory") + parser.add_argument("--max_length", type=int, default=8192, help="Max length of each spliced tokenized sequence") + parser.add_argument("--num_spliced_dataset_bins", type=int, default=10, help="Number of spliced dataset bins") + parser.add_argument("--dataset_type", type=str, default="webtext", help="dataset type") + args = parser.parse_args() + + if args.num_spliced_dataset_bins >= 100000: + raise ValueError("Too many spliced divisions, must be smaller than 100000") + + args.data_cache_dir = os.path.join(args.data_output_dirs, "cache") + args.data_jsonl_output_dir = os.path.join(args.data_output_dirs, "jsonl") + args.data_arrow_output_dir = os.path.join(args.data_output_dirs, "arrow") + + if not os.path.exists(args.data_cache_dir): + os.makedirs(args.data_cache_dir) + if not os.path.exists(args.data_jsonl_output_dir): + os.makedirs(args.data_jsonl_output_dir) + if not os.path.exists(args.data_arrow_output_dir): + os.makedirs(args.data_arrow_output_dir) + + # Prepare to all input datasets + input_data_paths = [] + input_data_dirs = args.data_input_dirs.split(",") + for ds_dir in input_data_dirs: + ds_dir = os.path.abspath(ds_dir) + assert os.path.exists(ds_dir), f"Not find data dir {ds_dir}" + ds_files = [name for name in os.listdir(ds_dir) if name.endswith(".jsonl")] + ds_paths = [os.path.join(ds_dir, name) for name in ds_files] + input_data_paths.extend(ds_paths) + + # Prepare to data splitting. + train_splits = [] + split_interval = math.ceil(100 / args.num_spliced_dataset_bins) + for i in range(0, 100, split_interval): + start = i + end = i + split_interval + if end > 100: + end = 100 + train_splits.append(f"train[{start}%:{end}%]") + + # Prepare to the tokenizer. + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_dir) + tokenizer.add_bos_token = False + tokenizer.add_eos_token = False + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.unk_token + + list_dataset = load_dataset( + path="json", + data_files=input_data_paths, + cache_dir=os.path.join(args.data_cache_dir, "raw"), + keep_in_memory=False, + split=train_splits, + num_proc=cpu_count(), + ) + for index, dataset in enumerate(list_dataset): + assert isinstance(dataset, dataset_dict.Dataset) + logger.info(f"Start to process part-{index}/{len(list_dataset)} of all original datasets.") + dataset = dataset.map( + function=supervised_tokenize_pretrain_webtext if args.dataset_type =="webtext" else supervised_tokenize_pretrain, + fn_kwargs={"tokenizer": tokenizer, "max_length": args.max_length}, + keep_in_memory=False, + num_proc=min(len(dataset), cpu_count()), + ) + if args.dataset_type =="webtext": + dataset = dataset.remove_columns(column_names=["id", "text", "length", "ended"]) + dataset = dataset.sort(column_names=("seq_length"), reverse=False, keep_in_memory=False) + dataset = dataset.remove_columns(column_names=["seq_length"]) + else: + dataset = dataset.remove_columns(column_names=["source", "target", "category"]) + dataset = dataset.sort(column_names=("seq_category", "seq_length"), reverse=False, keep_in_memory=False) + dataset = dataset.remove_columns(column_names=["seq_category", "seq_length"]) + spliced_dataset = ClosedToConstantLengthSplicedDataset( + dataset=dataset, tokenizer=tokenizer, max_length=args.max_length, error_strict=False + ) + # Save each jsonl spliced dataset. + output_index = "0" * (5 - len(str(index))) + str(index) + output_name = f"part-{output_index}" + output_jsonl_path = os.path.join(args.data_jsonl_output_dir, output_name + ".jsonl") + st = time.time() + with open(file=output_jsonl_path, mode="w", encoding="utf-8") as fp_writer: + spliced_count = 0 + for spliced_data_point in spliced_dataset: + if spliced_count % 500 == 0: + logger.info(f"processing {spliced_count} spliced data points for {fp_writer.name}") + spliced_count += 1 + fp_writer.write(json.dumps(spliced_data_point, ensure_ascii=False) + "\n") + logger.info( + f"Current file {fp_writer.name}; " + f"Data size: {len(spliced_dataset)}; " + f"Spliced data size: {spliced_dataset.current_size}; " + f"Splicing compression rate: {round(spliced_dataset.current_size / len(spliced_dataset), 6)}; " + f"Time cost: {round((time.time() - st) / 60, 6)} minutes." + ) + + # Save each arrow spliced dataset + output_arrow_path = os.path.join(args.data_arrow_output_dir, output_name) + logger.info(f"Start to save {output_arrow_path}") + spliced_dataset = load_dataset( + path="json", + data_files=[output_jsonl_path], + cache_dir=os.path.join(args.data_cache_dir, "spliced_and_tokenized"), + keep_in_memory=False, + num_proc=cpu_count(), + split="train", + ) + spliced_dataset.save_to_disk(dataset_path=output_arrow_path, num_proc=min(len(spliced_dataset), cpu_count())) + + +if __name__ == "__main__": + main() diff --git a/toolbox/ColossalAI/v0.4.4/patches/applications/Colossal-LLaMA/dataset/prepare_sft_dataset.py b/toolbox/ColossalAI/v0.4.4/patches/applications/Colossal-LLaMA/dataset/prepare_sft_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..6f626e53dd75f8b5bbf0b12e334df7598b11abd1 --- /dev/null +++ b/toolbox/ColossalAI/v0.4.4/patches/applications/Colossal-LLaMA/dataset/prepare_sft_dataset.py @@ -0,0 +1,152 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Prepare sft dataset for fine-tuning +""" + +import argparse +import json +import math +import os +from multiprocessing import cpu_count +import sys +sys.path.append(os.path.join(os.path.dirname(__file__), "../")) +from colossal_llama.dataset.conversation import LLaMA2_Conv, LLaMA3_Conv +from colossal_llama.dataset.spliced_and_tokenized_dataset import supervised_tokenize_sft +from datasets import dataset_dict, load_dataset +from transformers import AddedToken, AutoTokenizer + +from colossalai.logging import get_dist_logger + +logger = get_dist_logger() + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--data_input_dirs", + type=str, + required=True, + default=None, + help="Comma(i.e., ',') separated list of all data directories containing `.jsonl` data files.", + ) + parser.add_argument( + "--tokenizer_dir", type=str, required=True, default=None, help="A directory containing the tokenizer" + ) + parser.add_argument("--data_output_dirs", type=str, default="data_output_dirs", help="Data output directory") + parser.add_argument("--max_length", type=int, default=8192, help="Max length of each spliced tokenized sequence") + parser.add_argument("--num_spliced_dataset_bins", type=int, default=10, help="Number of spliced dataset bins") + parser.add_argument("--llama_version", type=int, default=3, help="LLaMA version") + args = parser.parse_args() + + if args.num_spliced_dataset_bins >= 100000: + raise ValueError("Too many spliced divisions, must be smaller than 100000") + + args.data_cache_dir = os.path.join(args.data_output_dirs, "cache") + args.data_jsonl_output_dir = os.path.join(args.data_output_dirs, "jsonl") + args.data_arrow_output_dir = os.path.join(args.data_output_dirs, "arrow") + + if not os.path.exists(args.data_cache_dir): + os.makedirs(args.data_cache_dir) + if not os.path.exists(args.data_jsonl_output_dir): + os.makedirs(args.data_jsonl_output_dir) + if not os.path.exists(args.data_arrow_output_dir): + os.makedirs(args.data_arrow_output_dir) + + # Prepare to all input datasets + input_data_paths = [] + input_data_dirs = args.data_input_dirs.split(",") + for ds_dir in input_data_dirs: + ds_dir = os.path.abspath(ds_dir) + assert os.path.exists(ds_dir), f"Not find data dir {ds_dir}" + ds_files = [name for name in os.listdir(ds_dir) if name.endswith(".jsonl")] + ds_paths = [os.path.join(ds_dir, name) for name in ds_files] + input_data_paths.extend(ds_paths) + + # Prepare to data splitting. + train_splits = [] + split_interval = math.ceil(100 / args.num_spliced_dataset_bins) + for i in range(0, 100, split_interval): + start = i + end = i + split_interval + if end > 100: + end = 100 + train_splits.append(f"train[{start}%:{end}%]") + + # Prepare to the tokenizer. + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_dir) + + default_conversation = LLaMA3_Conv + + # Fix split issue: https://github.com/huggingface/transformers/issues/23833 + if args.llama_version == 2: + tokenizer.add_tokens(AddedToken("", normalized=False, special=True), special_tokens=True) + default_conversation = LLaMA2_Conv + + tokenizer.add_bos_token = False + tokenizer.add_eos_token = False + if tokenizer.pad_token is None: + if tokenizer.unk_token is not None: + tokenizer.pad_token = tokenizer.unk_token + else: + tokenizer.pad_token = tokenizer.eos_token + tokenizer.unk_token = tokenizer.eos_token + + list_dataset = load_dataset( + path="json", + data_files=input_data_paths, + cache_dir=os.path.join(args.data_cache_dir, "raw"), + keep_in_memory=False, + split=train_splits, + num_proc=cpu_count(), + ) + for index, dataset in enumerate(list_dataset): + assert isinstance(dataset, dataset_dict.Dataset) + logger.info(f"Start to process part-{index}/{len(list_dataset)} of all original datasets.") + dataset = dataset.map( + function=supervised_tokenize_sft, + fn_kwargs={ + "tokenizer": tokenizer, + "conversation_template": default_conversation, + "max_length": args.max_length, + }, + keep_in_memory=False, + num_proc=min(len(dataset), cpu_count()), + ) + + dataset = dataset.filter(lambda data: data["labels"] is not None) + dataset = dataset.sort(column_names=("seq_category", "seq_length"), reverse=False, keep_in_memory=False) + + # We don't concatenate data samples here. + spliced_dataset = dataset + # Save each jsonl spliced dataset. + output_index = "0" * (5 - len(str(index))) + str(index) + output_name = f"part-{output_index}" + output_jsonl_path = os.path.join(args.data_jsonl_output_dir, output_name + ".jsonl") + # st = time.time() + with open(file=output_jsonl_path, mode="w", encoding="utf-8") as fp_writer: + spliced_count = 0 + for spliced_data_point in spliced_dataset: + if spliced_count % 500 == 0: + logger.info(f"processing {spliced_count} spliced data points for {fp_writer.name}") + spliced_count += 1 + fp_writer.write(json.dumps(spliced_data_point, ensure_ascii=False) + "\n") + + # Save each arrow spliced dataset + output_arrow_path = os.path.join(args.data_arrow_output_dir, output_name) + logger.info(f"Start to save {output_arrow_path}") + spliced_dataset = load_dataset( + path="json", + data_files=[output_jsonl_path], + cache_dir=os.path.join(args.data_cache_dir, "spliced_and_tokenized"), + keep_in_memory=False, + num_proc=cpu_count(), + split="train", + ) + spliced_dataset.save_to_disk(dataset_path=output_arrow_path, num_proc=min(len(spliced_dataset), cpu_count())) + + +if __name__ == "__main__": + main() diff --git a/toolbox/ColossalAI/v0.4.4/patches/applications/Colossal-LLaMA/performance_evaluator.py b/toolbox/ColossalAI/v0.4.4/patches/applications/Colossal-LLaMA/performance_evaluator.py new file mode 100644 index 0000000000000000000000000000000000000000..053af1d2cbdd4afd8ab20c9ffc384e22be5e35d8 --- /dev/null +++ b/toolbox/ColossalAI/v0.4.4/patches/applications/Colossal-LLaMA/performance_evaluator.py @@ -0,0 +1,153 @@ +#!/usr/bin/env python3 +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from time import time +from typing import Optional + +import torch +import torch.distributed as dist +from torch import Tensor + +from colossalai.accelerator import get_accelerator +from colossalai.cluster import DistCoordinator +from colossal_llama.utils import utils + +def divide(x: float, y: float) -> float: + if y == 0: + return float("inf") + elif y == float("inf"): + return float("nan") + return x / y + + +@torch.no_grad() +def all_reduce_mean(x: float, world_size: int) -> float: + if world_size == 1: + return x + tensor = torch.tensor([x], device=get_accelerator().get_current_device()) + dist.all_reduce(tensor) + tensor = tensor / world_size + return tensor.item() + + +class Timer: + def __init__(self) -> None: + self.start_time: Optional[float] = None + self.duration: float = 0.0 + + def start(self) -> None: + self.start_time = time() + + def end(self) -> None: + assert self.start_time is not None + self.duration = time() - self.start_time + self.start_time = None + + def reset(self) -> None: + self.duration = 0.0 + + +class PerformanceEvaluator: + """ + Callback for valuate the performance of the model. + Args: + actor_num_params: The number of parameters of the actor model. + critic_num_params: The number of parameters of the critic model. + initial_model_num_params: The number of parameters of the initial model. + reward_model_num_params: The number of parameters of the reward model. + enable_grad_checkpoint: Whether to enable gradient checkpointing. + ignore_episodes: The number of episodes to ignore when calculating the performance. + """ + + def __init__( + self, + model_numel: int, + num_layers: int, + hidden_size: int, + vocab_size: int, + enable_grad_checkpoint: bool = False, + ignore_steps: int = 0, + dp_world_size: Optional[int] = None, + ) -> None: + self.model_numel = model_numel + self.enable_grad_checkpoint = enable_grad_checkpoint + self.ignore_steps = ignore_steps + self.num_layers = num_layers + self.hidden_size = hidden_size + self.vocab_size = vocab_size + + self.coordinator = DistCoordinator() + self.dp_world_size = dp_world_size or self.coordinator.world_size + self.disable: bool = False + self.timer = Timer() + self.num_samples: int = 0 + self.flop_megatron = 0 + self.flop: int = 0 + self.tokens_per_second_per_devices = [] + self.avg_tflops_per_gpus = [] + + def on_step_start(self, step: int) -> None: + self.disable = self.ignore_steps > 0 and step < self.ignore_steps + self.step = step + # if self.disable: + # return + get_accelerator().synchronize() + self.timer.start() + + def on_step_end(self, loss, inputs_size, plugin, **kwargs) -> None: + # if self.disable: + # return + get_accelerator().synchronize() + self.timer.end() + + batch_size, seq_len = inputs_size + + self.num_samples = batch_size + checkpoint_activations_factor = 3 + int(self.enable_grad_checkpoint) + self.flop_megatron = ( + 24 * checkpoint_activations_factor * batch_size * seq_len * self.num_layers * (self.hidden_size**2) + ) * ( + 1.0 + (seq_len / (6.0 * self.hidden_size)) + (self.vocab_size / (16.0 * self.num_layers * self.hidden_size)) + ) + self.flop = batch_size * seq_len * self.model_numel * 2 * (3 + int(self.enable_grad_checkpoint)) + + # def on_fit_end(self) -> None: + avg_duration = all_reduce_mean(self.timer.duration, self.coordinator.world_size) + avg_throughput = self.num_samples * self.dp_world_size / (avg_duration + 1e-12) + tokens_per_second_per_device = avg_throughput * seq_len * 2 / self.coordinator.world_size ## BI-V150 one device has two gpus + mp_world_size = self.coordinator.world_size // self.dp_world_size + avg_tflops_per_gpu_megatron = self.flop_megatron / 1e12 / (avg_duration + 1e-12) / mp_world_size + avg_tflops_per_gpu = self.flop / 1e12 / (avg_duration + 1e-12) / mp_world_size + + global_loss = None + if plugin.stage_manager.is_last_stage(): + global_loss = utils.all_reduce_mean(loss, plugin) + + + self.coordinator.print_on_last_process( + f"num_samples: {self.num_samples}, dp_world_size: {self.dp_world_size}, flop_megatron: {self.flop_megatron}, flop: {self.flop}, avg_duration: {avg_duration}, " + ) + self.coordinator.print_on_last_process( + f"loss:{global_loss}, Throughput: {avg_throughput:.2f} samples/sec , tokens_per_second_per_device: {tokens_per_second_per_device} , TFLOPS per GPU by Megatron: {avg_tflops_per_gpu_megatron:.2f} , TFLOPS per GPU: {avg_tflops_per_gpu:.2f}" + ) + + if self.step >= self.ignore_steps and self.step < self.ignore_steps + 5: + if self.step == self.ignore_steps + 4: + self.coordinator.print_on_last_process("\n ---------------------------------------------" + + f"\n average values of [{self.ignore_steps} - {self.ignore_steps + 5}) steps, tokens_per_second_per_device: {sum(self.tokens_per_second_per_devices)/len(self.tokens_per_second_per_devices):.2f} , TFLOPS per GPU: {sum(self.avg_tflops_per_gpus)/len(self.avg_tflops_per_gpus):.2f} " + + "\n ---------------------------------------------") + else: + self.tokens_per_second_per_devices.append(tokens_per_second_per_device) + self.avg_tflops_per_gpus.append(avg_tflops_per_gpu) \ No newline at end of file diff --git a/toolbox/ColossalAI/v0.4.4/patches/applications/Colossal-LLaMA/prepare_pretrain_dataset.sh b/toolbox/ColossalAI/v0.4.4/patches/applications/Colossal-LLaMA/prepare_pretrain_dataset.sh new file mode 100644 index 0000000000000000000000000000000000000000..e199754b69fe610f2887309c9f3991650b27fad2 --- /dev/null +++ b/toolbox/ColossalAI/v0.4.4/patches/applications/Colossal-LLaMA/prepare_pretrain_dataset.sh @@ -0,0 +1,53 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +#!/bin/bash +# 本脚本可以带一个参数或者0个参数,指示llama版本,可为 "llama2" 或者 "llama3",如果无入参,则默认为llama2 + +set -euox pipefail +CUR_DIR=$(cd "$(dirname "$0")";pwd) +cd ${CUR_DIR} + +if [[ ! -f $CUR_DIR/small-117M.train.jsonl ]]; then + wget http://10.150.9.95/swapp/datasets/nlp/gpt-2-output-dataset/small-117M.train.jsonl +fi + +DATA_INPUT_DIRS=$CUR_DIR + +LLAMA_VER=${1:-"llama3"} +echo "LLaMA version:" $LLAMA_VER + +if [ $LLAMA_VER == "llama2" ]; then + # 代码中lable与input的错位需要,loss计算length为4096的sequence。 + MAX_LENGTH=4097 + TOKENIZER_DIR=/home/model_zoos/Llama-2-7b-hf + DATA_OUTPUT_DIRS=dataset/llama2_data + +elif [ $LLAMA_VER == "llama3" ]; then + # 代码中lable与input的错位需要,loss计算length为8192的sequence。 + MAX_LENGTH=8193 + TOKENIZER_DIR=/home/model_zoos/Meta-Llama-3-8B + DATA_OUTPUT_DIRS=dataset/llama3_data + +else + echo "Error LLAMA_VER, please input correct LLaMA version" + exit 1 +fi + +python3 dataset/prepare_pretrain_dataset.py \ + --data_input_dirs $DATA_INPUT_DIRS \ + --data_output_dirs $DATA_OUTPUT_DIRS \ + --dataset_type webtext \ + --tokenizer_dir $TOKENIZER_DIR \ + --max_length $MAX_LENGTH \ diff --git a/toolbox/ColossalAI/v0.4.4/patches/applications/Colossal-LLaMA/prepare_sft_dataset.sh b/toolbox/ColossalAI/v0.4.4/patches/applications/Colossal-LLaMA/prepare_sft_dataset.sh new file mode 100644 index 0000000000000000000000000000000000000000..363a89fccb87d2959b6ed12090e8aa80d462ee09 --- /dev/null +++ b/toolbox/ColossalAI/v0.4.4/patches/applications/Colossal-LLaMA/prepare_sft_dataset.sh @@ -0,0 +1,61 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +#!/bin/bash +# 本脚本可以带一个参数或者0个参数,指示llama版本,可为 "llama2" 或者 "llama3",如果无入参,则默认为llama2 + +set -euox pipefail +CUR_DIR=$(cd "$(dirname "$0")";pwd) +cd ${CUR_DIR} + +DATA_INPUT_DIRS=$CUR_DIR"/dataset/school_math/convert/" +mkdir -p $DATA_INPUT_DIRS + +if [[ ! -f $DATA_INPUT_DIRS"school_math_0.25M_convert.jsonl" ]]; then + if [[ ! -f $DATA_INPUT_DIRS"../school_math_0.25M.jsonl" ]]; then + wget http://sw.iluvatar.ai/download/apps/llm-modelzoo/dataset/school_math_0.25M.jsonl + mv school_math_0.25M.jsonl $DATA_INPUT_DIRS"../" + fi + + python3 dataset/convert_data.py +fi + + +LLAMA_VER=${1:-"llama3"} +echo "LLaMA version:" $LLAMA_VER + +if [ $LLAMA_VER == "llama2" ]; then + # 代码中lable与input的错位需要,loss计算length为4096的sequence。 + MAX_LENGTH=4097 + TOKENIZER_DIR=/home/model_zoos/Llama-2-7b-hf + DATA_OUTPUT_DIRS=dataset/school_math/convert/llama2_data_sft + llama_ver=2 + +elif [ $LLAMA_VER == "llama3" ]; then + # 代码中lable与input的错位需要,loss计算length为8192的sequence。 + MAX_LENGTH=8193 + TOKENIZER_DIR=/home/model_zoos/Meta-Llama-3-8B + DATA_OUTPUT_DIRS=dataset/school_math/convert/llama3_data_sft + llama_ver=3 +else + echo "Error LLAMA_VER, please input correct LLaMA version" + exit 1 +fi + +python3 dataset/prepare_sft_dataset.py \ + --data_input_dirs $DATA_INPUT_DIRS \ + --data_output_dirs $DATA_OUTPUT_DIRS \ + --tokenizer_dir $TOKENIZER_DIR \ + --max_length $MAX_LENGTH \ + --llama_version $llama_ver diff --git a/toolbox/ColossalAI/v0.4.4/patches/applications/Colossal-LLaMA/requirements.txt b/toolbox/ColossalAI/v0.4.4/patches/applications/Colossal-LLaMA/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..a691b49c829c14cb9eef75582d5e6f1cabb514f3 --- /dev/null +++ b/toolbox/ColossalAI/v0.4.4/patches/applications/Colossal-LLaMA/requirements.txt @@ -0,0 +1,15 @@ +# torch==2.1.2 +huggingface-hub +packaging==24.0 +colossalai>=0.4.0 +autoflake==2.2.1 +black==23.9.1 +# transformers>=4.39.3 +tensorboard==2.14.0 +six==1.16.0 +datasets +ninja==1.11.1 +# flash-attn +tqdm +sentencepiece==0.1.99 +protobuf<=3.20.0 diff --git a/toolbox/ColossalAI/v0.4.4/patches/applications/Colossal-LLaMA/run_llama2_7b_pretrain_3d.sh b/toolbox/ColossalAI/v0.4.4/patches/applications/Colossal-LLaMA/run_llama2_7b_pretrain_3d.sh new file mode 100644 index 0000000000000000000000000000000000000000..df8da8819b4207249141c8b77e7787cdb96f4dcb --- /dev/null +++ b/toolbox/ColossalAI/v0.4.4/patches/applications/Colossal-LLaMA/run_llama2_7b_pretrain_3d.sh @@ -0,0 +1,80 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +#!/bin/bash +# set_n_least_used_CUDA_VISIBLE_DEVICES() { +# local n=${1:-"9999"} +# echo "GPU Memory Usage:" +# local FIRST_N_GPU_IDS=$(ixsmi --query-gpu=memory.used --format=csv | +# tail -n +2 | +# nl -v 0 | +# tee /dev/tty | +# sort -g -k 2 | +# awk '{print $1}' | +# head -n $n) +# export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g') +# echo "Now CUDA_VISIBLE_DEVICES is set to:" +# echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" +# } + +# set_n_least_used_CUDA_VISIBLE_DEVICES 8 + +PARENT_SAVE_DIR="checkpoint" +PARENT_TENSORBOARD_DIR="tensorboard" +PARENT_CONFIG_FILE="config" + +TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S) +LOG_DIR="logs/${TIMESTAMP}" +SAVE_DIR="${LOG_DIR}/${PARENT_SAVE_DIR}" +TENSORBOARD_DIR="${LOG_DIR}/${PARENT_TENSORBOARD_DIR}" +CONFIG_FILE="${LOG_DIR}/${PARENT_CONFIG_FILE}.json" + +DATASET_PATH=./dataset/llama2_data/arrow/ +TOKENIZER_DIR=/home/model_zoos/Llama-2-7b-hf +GLOBAL_BATCH_SIZE_PER_DP=8 +MICRO_BATCH_SIZE=1 + + +mkdir -p $LOG_DIR +colossalai run --nproc_per_node 16 train.py \ + --dataset $DATASET_PATH \ + --tokenizer_dir $TOKENIZER_DIR \ + --max_length 4096 \ + --plugin "3d" \ + --zero_stage 1 \ + --pp 4 \ + --custom_ckpt \ + --custom_recompute_layers_per_stage 0 0 0 0 \ + --ignore_steps 2 \ + --use_ixformer_mlp \ + --use_colo_llamaflashatten \ + --use_ixformer_fusedrmsnormres \ + --save_interval 0 \ + --save_dir $SAVE_DIR \ + --tensorboard_dir $TENSORBOARD_DIR \ + --config_file $CONFIG_FILE \ + --num_epochs 1 \ + --batch_size $GLOBAL_BATCH_SIZE_PER_DP \ + --microbatch_size $MICRO_BATCH_SIZE \ + --lr 1e-4 \ + --mixed_precision "bf16" \ + --grad_clip 1.0 \ + --weight_decay 0.01 \ + --warmup_steps 100 \ + --use_grad_checkpoint \ + --use_flash_attn \ + --pad_token "unk" |& tee ${LOG_DIR}/output.log + + + diff --git a/toolbox/ColossalAI/v0.4.4/patches/applications/Colossal-LLaMA/run_llama3_8b_pretrain_3d.sh b/toolbox/ColossalAI/v0.4.4/patches/applications/Colossal-LLaMA/run_llama3_8b_pretrain_3d.sh new file mode 100644 index 0000000000000000000000000000000000000000..04df4b815177d49e198692b44d6e1af81eb6c727 --- /dev/null +++ b/toolbox/ColossalAI/v0.4.4/patches/applications/Colossal-LLaMA/run_llama3_8b_pretrain_3d.sh @@ -0,0 +1,78 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +#!/bin/bash +# set_n_least_used_CUDA_VISIBLE_DEVICES() { +# local n=${1:-"9999"} +# echo "GPU Memory Usage:" +# local FIRST_N_GPU_IDS=$(ixsmi --query-gpu=memory.used --format=csv | +# tail -n +2 | +# nl -v 0 | +# tee /dev/tty | +# sort -g -k 2 | +# awk '{print $1}' | +# head -n $n) +# export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g') +# echo "Now CUDA_VISIBLE_DEVICES is set to:" +# echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" +# } + +# set_n_least_used_CUDA_VISIBLE_DEVICES 8 + +PARENT_SAVE_DIR="checkpoint" +PARENT_TENSORBOARD_DIR="tensorboard" +PARENT_CONFIG_FILE="config" + +TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S) +LOG_DIR="logs/${TIMESTAMP}" +SAVE_DIR="${LOG_DIR}/${PARENT_SAVE_DIR}" +TENSORBOARD_DIR="${LOG_DIR}/${PARENT_TENSORBOARD_DIR}" +CONFIG_FILE="${LOG_DIR}/${PARENT_CONFIG_FILE}.json" + +DATASET_PATH=./dataset/llama3_data/arrow/ +TOKENIZER_DIR=/home/model_zoos/Meta-Llama-3-8B +GLOBAL_BATCH_SIZE_PER_DP=8 +MICRO_BATCH_SIZE=1 + + +mkdir -p $LOG_DIR +colossalai run --nproc_per_node 16 train.py \ + --config "llama3_8b" \ + --dataset $DATASET_PATH \ + --tokenizer_dir $TOKENIZER_DIR \ + --max_length 8192 \ + --plugin "3d" \ + --zero_stage 1 \ + --pp 4 \ + --custom_ckpt \ + --custom_recompute_layers_per_stage 8 7 6 7 \ + --ignore_steps 2 \ + --use_ixformer_mlp \ + --use_ixformer_fusedrmsnormres \ + --save_interval 0 \ + --save_dir $SAVE_DIR \ + --tensorboard_dir $TENSORBOARD_DIR \ + --config_file $CONFIG_FILE \ + --num_epochs 1 \ + --batch_size $GLOBAL_BATCH_SIZE_PER_DP \ + --microbatch_size $MICRO_BATCH_SIZE \ + --lr 1e-4 \ + --mixed_precision "bf16" \ + --grad_clip 1.0 \ + --weight_decay 0.01 \ + --warmup_steps 100 \ + --use_grad_checkpoint \ + --use_flash_attn \ + --pad_token "eos" |& tee ${LOG_DIR}/output.log + diff --git a/toolbox/ColossalAI/v0.4.4/patches/applications/Colossal-LLaMA/run_llama3_8b_sft_3d.sh b/toolbox/ColossalAI/v0.4.4/patches/applications/Colossal-LLaMA/run_llama3_8b_sft_3d.sh new file mode 100644 index 0000000000000000000000000000000000000000..56b668adaac3e65259692bfb3ec14f25c57fdb2e --- /dev/null +++ b/toolbox/ColossalAI/v0.4.4/patches/applications/Colossal-LLaMA/run_llama3_8b_sft_3d.sh @@ -0,0 +1,79 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +#!/bin/bash +# set_n_least_used_CUDA_VISIBLE_DEVICES() { +# local n=${1:-"9999"} +# echo "GPU Memory Usage:" +# local FIRST_N_GPU_IDS=$(ixsmi --query-gpu=memory.used --format=csv | +# tail -n +2 | +# nl -v 0 | +# tee /dev/tty | +# sort -g -k 2 | +# awk '{print $1}' | +# head -n $n) +# export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g') +# echo "Now CUDA_VISIBLE_DEVICES is set to:" +# echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" +# } + +# set_n_least_used_CUDA_VISIBLE_DEVICES 8 + +PARENT_SAVE_DIR="checkpoint" +PARENT_TENSORBOARD_DIR="tensorboard" +PARENT_CONFIG_FILE="config" + +TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S) +LOG_DIR="logs/${TIMESTAMP}" +SAVE_DIR="${LOG_DIR}/${PARENT_SAVE_DIR}" +TENSORBOARD_DIR="${LOG_DIR}/${PARENT_TENSORBOARD_DIR}" +CONFIG_FILE="${LOG_DIR}/${PARENT_CONFIG_FILE}.json" + +DATASET_PATH=./dataset/school_math/convert/llama3_data_sft/arrow/ +TOKENIZER_DIR=/home/model_zoos/Meta-Llama-3-8B +GLOBAL_BATCH_SIZE_PER_DP=8 +MICRO_BATCH_SIZE=1 + + +mkdir -p $LOG_DIR +colossalai run --nproc_per_node 16 train.py \ + --config "llama3_8b" \ + --dataset $DATASET_PATH \ + --tokenizer_dir $TOKENIZER_DIR \ + --max_length 8192 \ + --plugin "3d" \ + --zero_stage 1 \ + --pp 4 \ + --custom_ckpt \ + --custom_recompute_layers_per_stage 7 6 5 6 \ + --use_ixformer_mlp \ + --use_ixformer_fusedrmsnormres \ + --ignore_steps 2 \ + --save_interval 0 \ + --save_dir $SAVE_DIR \ + --tensorboard_dir $TENSORBOARD_DIR \ + --config_file $CONFIG_FILE \ + --num_epochs 1 \ + --batch_size $GLOBAL_BATCH_SIZE_PER_DP \ + --microbatch_size $MICRO_BATCH_SIZE \ + --lr 1e-4 \ + --mixed_precision "bf16" \ + --grad_clip 1.0 \ + --weight_decay 0.01 \ + --warmup_steps 100 \ + --use_grad_checkpoint \ + --use_flash_attn \ + --use_neft \ + --pad_token "eos" |& tee ${LOG_DIR}/output.log + diff --git a/toolbox/ColossalAI/v0.4.4/patches/applications/Colossal-LLaMA/train.py b/toolbox/ColossalAI/v0.4.4/patches/applications/Colossal-LLaMA/train.py new file mode 100644 index 0000000000000000000000000000000000000000..79c11d2917b39d9ef883969f91c751f01ab86eb8 --- /dev/null +++ b/toolbox/ColossalAI/v0.4.4/patches/applications/Colossal-LLaMA/train.py @@ -0,0 +1,601 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Continual Pre-training/Supervised fine-tuning of Colossal-LLaMA-2 developed by Colossal-AI Team +""" + +import argparse +import json +import os +import resource +from contextlib import nullcontext + +import torch +from colossal_llama.dataset.dummy_dataset import RandomDataset +from colossal_llama.dataset.loader import ( + DataCollatorForSupervisedDataset, + StatefulDistributedSampler, + load_tokenized_dataset, +) +from colossal_llama.utils.ckpt_io import load_checkpoint, save_checkpoint +from colossal_llama.utils.froze import freeze_non_embeds_parameters +from colossal_llama.utils.neftune_patch import activate_neftune, deactivate_neftune +from colossal_llama.utils.utils import all_reduce_mean, format_numel_str, get_model_numel +from torch.utils.tensorboard import SummaryWriter +from tqdm import tqdm +from transformers import AutoModelForCausalLM, AutoTokenizer + +import colossalai +from colossalai.accelerator import get_accelerator +from colossalai.booster import Booster +from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin +from colossalai.cluster import DistCoordinator +from colossalai.lazy import LazyInitContext +from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR +from colossalai.nn.optimizer import HybridAdam +from colossalai.utils import get_current_device +from transformers.models.llama.configuration_llama import LlamaConfig +from transformers import AutoConfig, AutoModelForCausalLM +from colossalai.shardformer import PipelineGradientCheckpointConfig +from performance_evaluator import PerformanceEvaluator + +MODEL_CONFIGS = { + "7b": LlamaConfig(max_position_embeddings=4096), + "13b": LlamaConfig( + hidden_size=5120, + intermediate_size=13824, + num_hidden_layers=40, + num_attention_heads=40, + max_position_embeddings=4096, + ), + "70b": LlamaConfig( + hidden_size=8192, + intermediate_size=28672, + num_hidden_layers=80, + num_attention_heads=64, + max_position_embeddings=4096, + num_key_value_heads=8, + ), + "llama3_8b": LlamaConfig(max_position_embeddings=8192, + vocab_size=128256, + num_key_value_heads=8, + intermediate_size=14336, + rope_theta=500000), +} + +def train(args) -> None: + # ============================== + # Initialize Distributed Training + # ============================== + colossalai.launch_from_torch() + accelerator = get_accelerator() + coordinator = DistCoordinator() + + # ============================== + # Initialize Tensorboard and Save Config + # ============================== + if coordinator.is_master(): + os.makedirs(args.tensorboard_dir, exist_ok=True) + writer = SummaryWriter(args.tensorboard_dir) + + with open(args.config_file, "w") as f: + json.dump(args.__dict__, f, indent=4) + print(f"args:{args}") + # ============================== + # Initialize Booster + # ============================== + hybrid_kwargs = { + "gradient_checkpoint_config": PipelineGradientCheckpointConfig(num_ckpt_layers_per_stage=args.custom_recompute_layers_per_stage) if args.custom_ckpt else None, + "use_ixformer_mlp": args.use_ixformer_mlp, + "use_colo_llamaflashatten": args.use_colo_llamaflashatten, + "use_ixformer_fusedrmsnormres": args.use_ixformer_fusedrmsnormres, + } + if args.plugin == "ddp": + plugin = TorchDDPPlugin(find_unused_parameters=True if args.use_grad_checkpoint is False else False) + elif args.plugin == "gemini": + plugin = GeminiPlugin( + precision=args.mixed_precision, + initial_scale=2**16, + max_norm=args.grad_clip, + enable_gradient_accumulation=(args.accumulation_steps > 1), + enable_fused_normalization=torch.cuda.is_available(), + enable_flash_attention=args.use_flash_attn, + ) + elif args.plugin == "gemini_auto": + plugin = GeminiPlugin( + precision=args.mixed_precision, + placement_policy="auto", + initial_scale=2**16, + max_norm=args.grad_clip, + enable_gradient_accumulation=(args.accumulation_steps > 1), + enable_fused_normalization=torch.cuda.is_available(), + enable_flash_attention=args.use_flash_attn, + ) + elif args.plugin == "zero2": + plugin = LowLevelZeroPlugin( + stage=2, + precision=args.mixed_precision, + initial_scale=2**16, + max_norm=args.grad_clip, + ) + elif args.plugin == "zero2_cpu": + plugin = LowLevelZeroPlugin( + stage=2, + precision=args.mixed_precision, + initial_scale=2**16, + cpu_offload=True, + max_norm=args.grad_clip, + ) + elif args.plugin == "3d": + plugin = HybridParallelPlugin( + tp_size=args.tp, + pp_size=args.pp, + sp_size=args.sp, + sequence_parallelism_mode=args.sp_mode, + zero_stage=args.zero_stage, + enable_flash_attention=args.use_flash_attn, + enable_fused_normalization=torch.cuda.is_available(), + enable_sequence_parallelism=args.enable_sequence_parallelism, + cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False, + parallel_output=False, + max_norm=args.grad_clip, + precision=args.mixed_precision, + microbatch_size=args.microbatch_size, + **hybrid_kwargs, + ) + else: + raise ValueError(f"Unknown plugin {args.plugin}") + + booster = Booster(plugin=plugin) + + # ====================================================== + # Initialize Tokenizer, Dataset, Collator and Dataloader + # ====================================================== + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_dir) + if args.pad_token == "eos": + tokenizer.pad_token = tokenizer.eos_token + elif args.pad_token == "unk": + tokenizer.pad_token = tokenizer.unk_token + tokenizer.add_bos_token = False + tokenizer.add_eos_token = False + + coordinator.print_on_master( + f"Training Info:\nConfig file: {args.config_file} \nTensorboard logs: {args.tensorboard_dir} \nModel checkpoint: {args.save_dir}" + ) + + if args.benchmark: + coordinator.print_on_master(f"Run benchmark with {args.num_samples} random samples.") + dataset = RandomDataset( + num_samples=args.num_samples, max_length=args.max_length, vocab_size=tokenizer.vocab_size + ) + dataloader = plugin.prepare_dataloader( + dataset, + batch_size=args.batch_size, + shuffle=True, + drop_last=True, + seed=42, + distributed_sampler_cls=StatefulDistributedSampler, + ) + else: + coordinator.print_on_master(f"Load dataset: {args.dataset}") + dataset = load_tokenized_dataset(dataset_parrent_path=args.dataset, mode="train") + data_collator = DataCollatorForSupervisedDataset( + tokenizer=tokenizer, max_length=args.max_length, padding=args.padding_mode + ) + dataloader = plugin.prepare_dataloader( + dataset=dataset, + batch_size=args.batch_size, + num_workers=2, + shuffle=True, + drop_last=True, + collate_fn=data_collator, + distributed_sampler_cls=StatefulDistributedSampler, + ) + + coordinator.print_on_master( + f"Max device memory after data loader: {accelerator.max_memory_allocated() / 1024 ** 2:.2f} MB" + ) + + # ====================================================== + # Initialize Model, Objective, Optimizer and LR Scheduler + # ====================================================== + if args.config in MODEL_CONFIGS: + config = MODEL_CONFIGS[args.config] + else: + config = AutoConfig.from_pretrained(args.config, trust_remote_code=True) + + init_ctx = ( + LazyInitContext(default_device=get_current_device()) + if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin)) + else nullcontext() + ) + with init_ctx: + if args.pretrained: + model = AutoModelForCausalLM.from_pretrained( + args.pretrained, + torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16, + trust_remote_code=True, + ) + else: + init_kwargs={} + if args.use_flash_attn or args.use_colo_llamaflashatten: + init_kwargs["attn_implementation"] = "flash_attention_2" + init_kwargs["torch_dtype"]=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16 + + model = AutoModelForCausalLM.from_config(config, + trust_remote_code=True, + **init_kwargs) + + # Freeze part of parameters. + if args.freeze_non_embeds_params: + freeze_non_embeds_parameters(model=model) + # this is essential, otherwise the grad checkpoint will not work. + model.train() + + if args.use_grad_checkpoint: + model.gradient_checkpointing_enable() + coordinator.print_on_master(msg="Gradient checkpointing enabled successfully") + + model_numel = get_model_numel(model) + coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}") + + optimizer = HybridAdam( + model_params=( + filter(lambda p: p.requires_grad, model.parameters()) + if args.freeze_non_embeds_params + else model.parameters() + ), + lr=args.lr, + betas=(0.9, 0.95), + weight_decay=args.weight_decay, + adamw_mode=True, + ) + + if args.warmup_steps is None: + args.warmup_steps = int(args.num_epochs * 0.025 * (len(dataloader) // args.accumulation_steps)) + coordinator.print_on_master(f"Warmup steps is set to {args.warmup_steps}") + + lr_scheduler = CosineAnnealingWarmupLR( + optimizer=optimizer, + total_steps=args.num_epochs * (len(dataloader) // args.accumulation_steps), + warmup_steps=args.warmup_steps, + eta_min=0.1 * args.lr, + ) + + # Flash attention will be disabled because it does NOT support fp32. + default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16 + torch.set_default_dtype(default_dtype) + model, optimizer, _, dataloader, lr_scheduler = booster.boost( + model=model, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + dataloader=dataloader, + ) + + torch.set_default_dtype(torch.float) + + coordinator.print_on_master( + f"Booster init max device memory: {accelerator.max_memory_allocated() / 1024 ** 2:.2f} MB" + ) + coordinator.print_on_master( + f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB" + ) + + start_epoch = 0 + start_step = 0 + sampler_start_idx = 0 + if args.load_checkpoint is not None: + if "modeling" in args.load_checkpoint: + coordinator.print_on_master(f"Continued pretrain from checkpoint {args.load_checkpoint}") + booster.load_model(model, args.load_checkpoint) + else: + coordinator.print_on_master(f"Load model checkpoint from {args.load_checkpoint}") + start_epoch, start_step, sampler_start_idx = load_checkpoint( + load_dir=args.load_checkpoint, + booster=booster, + model=model, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + ) + coordinator.print_on_master( + f"Loaded checkpoint {args.load_checkpoint} at epoch {start_epoch} step {start_step}" + ) + coordinator.print_on_master(f"Loaded sample at index {sampler_start_idx}") + + coordinator.print_on_master( + f"Checkpoint loaded max device memory: {accelerator.max_memory_allocated() / 1024 ** 2:.2f} MB" + ) + coordinator.print_on_master( + f"Checkpoint loaded device memory: {accelerator.memory_allocated() / 1024 ** 2:.2f} MB" + ) + coordinator.print_on_master( + f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB" + ) + + if args.use_neft: + coordinator.print_on_master("Activate NEFTune.") + model, handle = activate_neftune(model) + + dp_size = getattr(plugin, "dp_size", coordinator.world_size) + performance_evaluator = PerformanceEvaluator( + model_numel, + model.module.config.num_hidden_layers, + model.module.config.hidden_size, + model.module.config.vocab_size, + args.use_grad_checkpoint, + args.ignore_steps, + dp_world_size=dp_size, + ) + + num_steps_per_epoch = len(dataloader) // args.accumulation_steps + # If resume training, set the sampler start index to the correct value + assert isinstance(dataloader.sampler, StatefulDistributedSampler) + dataloader.sampler.set_start_index(start_index=sampler_start_idx) + + for epoch in range(start_epoch, args.num_epochs): + dataloader.sampler.set_epoch(epoch=epoch) + if isinstance(plugin, HybridParallelPlugin) and plugin.pp_size > 1: + data_iter = iter(dataloader) + step_bar = tqdm( + range(len(dataloader)), + desc="Step", + disable=not (coordinator._local_rank == coordinator._world_size - 1), + ) + with torch.autograd.profiler.profile(enabled=False) as prof: + for step in step_bar: + # if step > 7: + # break + performance_evaluator.on_step_start(step) + outputs = booster.execute_pipeline( + data_iter, + model, + criterion=lambda outputs, inputs: outputs[0], + optimizer=optimizer, + return_loss=True, + ) + loss = outputs["loss"] + optimizer.step() + optimizer.zero_grad() + performance_evaluator.on_step_end(loss, inputs_size = (args.batch_size, args.max_length), plugin=booster.plugin) + + # Save modeling. + save_model_condition = args.save_interval > 0 and (step + 1) % args.save_interval == 0 + + if not args.skip_save_each_epoch: + save_model_condition = save_model_condition or (step + 1) == len(dataloader) + + if save_model_condition and not args.benchmark: + coordinator.print_on_master("\nStart saving model checkpoint with running states") + + if args.use_neft: + coordinator.print_on_master("Deactivate NEFTune before saving model.") + deactivate_neftune(model, handle) + + accelerator.empty_cache() + save_checkpoint( + save_dir=args.save_dir, + booster=booster, + model=model, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + epoch=epoch, + step=step + 1, + batch_size=args.batch_size, + coordinator=coordinator, + ) + coordinator.print_on_master( + f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}" + ) + + if args.use_neft: + coordinator.print_on_master("Activate NEFTune.") + model, handle = activate_neftune(model) + if prof: + prof.export_chrome_trace(f'torch_profile/rank{torch.distributed.get_rank()}_llama2_colo.json') + else: + pbar = tqdm( + desc=f"Epoch {epoch}", + disable=not coordinator.is_master(), + total=num_steps_per_epoch, + initial=start_step // args.accumulation_steps, + ) + total_loss = torch.tensor(0.0, device=get_current_device()) + for step, batch in enumerate(dataloader, start=start_step): + batch = {k: v.to(get_current_device()) for k, v in batch.items() if isinstance(v, torch.Tensor)} + + batch_output = model(**batch) + + loss = batch_output.loss / args.accumulation_steps + total_loss.add_(loss.data) + + booster.backward(loss=loss, optimizer=optimizer) + + if (step + 1) % args.accumulation_steps == 0: + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + all_reduce_mean(tensor=total_loss) + pbar.set_postfix({"Loss": f"{total_loss.item():.4f}"}) + if coordinator.is_master(): + global_step = (epoch * num_steps_per_epoch) + (step + 1) // args.accumulation_steps + writer.add_scalar(tag="Loss", scalar_value=total_loss.item(), global_step=global_step) + writer.add_scalar( + tag="Learning Rate", + scalar_value=lr_scheduler.get_last_lr()[0], + global_step=global_step, + ) + total_loss.fill_(0.0) + pbar.update() + + # Save modeling. + save_model_condition = ( + args.save_interval > 0 and (step + 1) % (args.save_interval * args.accumulation_steps) == 0 + ) + + if not args.skip_save_each_epoch: + save_model_condition = save_model_condition or (step + 1) == len(dataloader) + + if save_model_condition and not args.benchmark: + coordinator.print_on_master("\nStart saving model checkpoint with running states") + + if args.use_neft: + coordinator.print_on_master("Deactivate NEFTune before saving model.") + deactivate_neftune(model, handle) + + accelerator.empty_cache() + save_checkpoint( + save_dir=args.save_dir, + booster=booster, + model=model, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + epoch=epoch, + step=step + 1, + batch_size=args.batch_size, + coordinator=coordinator, + ) + coordinator.print_on_master( + f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}" + ) + + if args.use_neft: + coordinator.print_on_master("Activate NEFTune.") + model, handle = activate_neftune(model) + + # Delete cache. + # del batch, batch_labels, batch_output, loss + accelerator.empty_cache() + + # the continue epochs are not resumed, so we need to reset the sampler start index and start step + dataloader.sampler.set_start_index(start_index=0) + start_step = 0 + + if args.use_neft: + coordinator.print_on_master("Deactivate NEFTune.") + deactivate_neftune(model, handle) + + # Final save. + if not args.benchmark: + coordinator.print_on_master("Start saving final model checkpoint") + booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True) + coordinator.print_on_master(f"Saved final model checkpoint at epoch {epoch} at folder {args.save_dir}") + + coordinator.print_on_master(f"Max device memory usage: {accelerator.max_memory_allocated()/1024**2:.2f} MB") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Basic training information. + parser.add_argument( + "--pretrained", + type=str, + default=None, + help="Address of the pre-trained model", + ) + parser.add_argument("--load_checkpoint", type=str, default=None, help="Load checkpoint for continuous training.") + parser.add_argument("--dataset", type=str, default='') + parser.add_argument( + "--plugin", + type=str, + default="gemini", + choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d", "ddp"], + help="Choose which plugin to use", + ) + parser.add_argument("--save_interval", type=int, default=1000, help="Save interval") + parser.add_argument("--save_dir", type=str, default="checkpoint_dir", help="Checkpoint directory") + parser.add_argument("--tensorboard_dir", type=str, default="logs_dir", help="Tensorboard directory") + parser.add_argument("--config_file", type=str, default="config_file", help="Config file") + # Training parameters + parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs") + parser.add_argument("--accumulation_steps", type=int, default=1, help="Number of accumulation steps") + parser.add_argument("--batch_size", type=int, default=2, help="Global Batch size of each process") + parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate") + parser.add_argument("--max_length", type=int, default=8192, help="Model max length") + parser.add_argument( + "--mixed_precision", + type=str, + default="fp16", + choices=["fp16", "bf16"], + help="Mixed precision", + ) + parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value") + parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay") + parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps") + parser.add_argument( + "--use_grad_checkpoint", + action="store_true", + default=False, + help="Use gradient checkpointing", + ) + parser.add_argument( + "--use_flash_attn", + action="store_true", + default=False, + help="Use flash-attention", + ) + parser.add_argument( + "--use_neft", + action="store_true", + default=False, + help="Use NEFTune", + ) + parser.add_argument( + "--freeze_non_embeds_params", + action="store_true", + default=False, + help="Freeze non embeddings parameters", + ) + parser.add_argument("--pad_token", choices=["eos", "unk"], default="eos") + parser.add_argument("--padding_mode", choices=["max_length", "longest"], default="max_length") + parser.add_argument( + "--skip_save_each_epoch", + action="store_true", + default=False, + help="Skip saving the model checkpoint after each epoch is completed.", + ) + + # Additional arguments for 3d plugin. + parser.add_argument("--tp", type=int, default=1, help="TP size, used for 3d plugin.") + parser.add_argument("--pp", type=int, default=1, help="PP size, used for 3d plugin.") + parser.add_argument("--sp", type=int, default=1, help="SP size, used for 3d plugin.") + parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage, used for 3d plugin.", choices=[0, 1, 2]) + parser.add_argument( + "--sp_mode", + type=str, + default="split_gather", + choices=["split_gather", "ring", "all_to_all"], + help="SP mode, used for 3d plugin.", + ) + parser.add_argument( + "--enable_sequence_parallelism", + default=False, + action="store_true", + help="Whether to enable SP, used for 3d plugin.", + ) + parser.add_argument( + "--zero_cpu_offload", default=False, action="store_true", help="Whether to use offloading, used for 3d plugin." + ) + parser.add_argument( + "--microbatch_size", type=int, default=1, help="Batch size for each process in PP, used for 3d plugin." + ) + + # Additional arguments for benchmark. + parser.add_argument("--num_samples", type=int, default=500, help="Number of samples for benchmarking.") + parser.add_argument( + "--benchmark", action="store_true", default=False, help="Benchmark performance using random dataset." + ) + parser.add_argument("--tokenizer_dir", type=str, default="", help="the path to llamatokenizer") + parser.add_argument("--config", type=str, default="7b", help="Model configuration") + parser.add_argument("--custom_ckpt", action="store_true", help="Customize checkpoint", default=False) + parser.add_argument('--custom_recompute_layers_per_stage', nargs='*', type=int, default=None, + help='custom recompute num layers in each PP stage, it should be equal to PP size ') + parser.add_argument("--ignore_steps", type=int, default=2, help="Number of steps to ignore") + parser.add_argument("--use_ixformer_mlp", action="store_true", help="use_ixformer_mlp", default=False) + parser.add_argument("--use_colo_llamaflashatten", action="store_true", help="use_colo_attention", default=False) + parser.add_argument("--use_ixformer_fusedrmsnormres", action="store_true", help="fused res and accumulating weight grad in rmsnormalization", default=False) + + args = parser.parse_args() + train(args) diff --git a/toolbox/ColossalAI/v0.4.4/patches/applications/ColossalMoE/tests/test_EPMixtralSparseMoeBlock.py b/toolbox/ColossalAI/v0.4.4/patches/applications/ColossalMoE/tests/test_EPMixtralSparseMoeBlock.py new file mode 100644 index 0000000000000000000000000000000000000000..d1be461837e33774a4d59422f5f4598c09663a96 --- /dev/null +++ b/toolbox/ColossalAI/v0.4.4/patches/applications/ColossalMoE/tests/test_EPMixtralSparseMoeBlock.py @@ -0,0 +1,101 @@ +#!/usr/bin/env python3 +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +torchrun --standalone --nproc_per_node 8 test_EPMixtralSparseMoeBlock.py +""" + +from copy import deepcopy + +import torch +import torch.distributed as dist +from torch.testing import assert_close + +from transformers.models.mixtral import MixtralConfig +from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock + +import colossalai +from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin +from colossalai.shardformer.modeling.mixtral import ( + EPMixtralSparseMoeBlock, + EPOptimizeMixtralSparseMoeBlock, +) + + +def build_model(hidden_size, n_experts, top_k): + + torch.manual_seed(0) + + config = MixtralConfig( + hidden_size=hidden_size, + num_local_experts=n_experts, + num_experts_per_tok=top_k, + num_hidden_layers=1, + max_position_embeddings=32768, + ) + + model = MixtralSparseMoeBlock(config).cuda() + + plugin = MoeHybridParallelPlugin( + precision="bf16", + tp_size=1, + pp_size=1, + zero_stage=1, + ep_size=dist.get_world_size(), + ) + + ep_model = deepcopy(model) + ep_model = EPMixtralSparseMoeBlock.from_native_module( + ep_model, + ep_group=plugin.ep_group, + tp_group=plugin.tp_group, + moe_dp_group=plugin.moe_dp_group, + ) + + return model, ep_model + + +if __name__ == "__main__": + + hidden_size = 16 + n_experts, top_k = 8, 2 + seq_len = 3 + + rtol, atol = 1e-2, 1e-4 + + colossalai.launch_from_torch(seed=0) + model, ep_model = build_model(hidden_size, n_experts, top_k) + + # Test Forward + x = torch.rand(1, seq_len, hidden_size, requires_grad=True).cuda() + orig_output, orig_logits = model(x) + ep_output, ep_logits = ep_model(x) + + assert_close(orig_logits, ep_logits, rtol=rtol, atol=atol) + assert_close(orig_output, ep_output, rtol=rtol, atol=atol) + + # Test Backward + orig_loss = orig_output.mean() + orig_loss.backward() + ep_loss = ep_output.mean() + ep_loss.backward() + assert_close(orig_loss, ep_loss, rtol=rtol, atol=atol) + name_to_p = {n: p for n, p in model.named_parameters()} + for n, ep_p in ep_model.named_parameters(): + p = name_to_p[n] + if ep_p.grad is not None: + assert_close(p.grad, ep_p.grad, rtol=rtol, atol=atol) + + dist.destroy_process_group() diff --git a/toolbox/ColossalAI/v0.4.4/patches/applications/ColossalMoE/tests/test_ep_optimize_perf.py b/toolbox/ColossalAI/v0.4.4/patches/applications/ColossalMoE/tests/test_ep_optimize_perf.py new file mode 100644 index 0000000000000000000000000000000000000000..08715875aa94b6eff452c0ed061cdbc72a733930 --- /dev/null +++ b/toolbox/ColossalAI/v0.4.4/patches/applications/ColossalMoE/tests/test_ep_optimize_perf.py @@ -0,0 +1,138 @@ +#!/usr/bin/env python3 +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +torchrun --standalone --nproc_per_node 4 test_ep_optimize_perf.py +""" + +from time import time +from copy import deepcopy + +import torch +import torch.distributed as dist +from torch.testing import assert_close + +from transformers.models.mixtral import MixtralConfig +from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock + +import colossalai +from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin +from colossalai.shardformer.modeling.mixtral import ( + EPMixtralSparseMoeBlock, + EPOptimizeMixtralSparseMoeBlock, +) + + +def build_model(): + + torch.manual_seed(0) + + config = MixtralConfig(max_position_embeddings=32768) + model = MixtralSparseMoeBlock(config).cuda() + + plugin = MoeHybridParallelPlugin( + precision="bf16", + tp_size=1, + pp_size=1, + zero_stage=1, + ep_size=dist.get_world_size(), + ) + + ep_model = deepcopy(model) + ep_model = EPMixtralSparseMoeBlock.from_native_module( + ep_model, + ep_group=plugin.ep_group, + tp_group=plugin.tp_group, + moe_dp_group=plugin.moe_dp_group, + ) + + ep_optimize_model = deepcopy(model) + ep_optimize_model = EPOptimizeMixtralSparseMoeBlock.from_native_module( + ep_optimize_model, + ep_group=plugin.ep_group, + tp_group=plugin.tp_group, + moe_dp_group=plugin.moe_dp_group, + ) + + return model, ep_model, ep_optimize_model + + +def rum_model(input, model, model_prefix=None): + start = time() + output, logits = model(input) + loss = output.mean() + torch.cuda.synchronize() + fwd_end = time() + fwd_time = fwd_end - start + + loss.backward() + torch.cuda.synchronize() + bwd_end = time() + bwd_time = bwd_end - fwd_end + if dist.get_rank() == 0: + print( + f"{model_prefix} model forward time={fwd_time:.3f}s, backward time={bwd_time:.3f}s" + ) + + return output, fwd_time, bwd_time + + +def warm_up(input, model): + output, logits = model(input) + loss = output.mean() + loss.backward() + + +if __name__ == "__main__": + + n_warm_up = 10 + seq_len, hidden_size = 1024, 4096 + + colossalai.launch_from_torch(seed=0) + model, ep_model, ep_opt_model = build_model() + + input = torch.rand(1, seq_len, hidden_size, requires_grad=True).cuda() + for _ in range(n_warm_up): + warm_up(input, model) + warm_up(input, ep_model) + warm_up(input, ep_opt_model) + torch.cuda.synchronize() + + orig_fwd_out, _, _ = rum_model(input, model, model_prefix="Original") + ep_fwd_out, fwd_time, bwd_time = rum_model(input, ep_model, model_prefix="EP") + ep_opt_fwd_out, opt_fwd_time, opt_bwd_time = rum_model( + input, ep_opt_model, model_prefix="EP_Optimize" + ) + + if dist.get_rank() == 0: + print(f"ep forward improve {100*fwd_time/opt_fwd_time:.0f}%") + print(f"ep backward improve {100*bwd_time/opt_bwd_time:.0f}%") + print( + f"ep forward+backward improve {100*(fwd_time+bwd_time)/(opt_fwd_time+opt_bwd_time):.0f}%" + ) + + rtol, atol = 1e-2, 1e-4 + + # Test Forward Accuracy + assert_close(orig_fwd_out, ep_opt_fwd_out, rtol=rtol, atol=atol) + + # Test Backward Accuracy + name_to_p = {n: p for n, p in model.named_parameters()} + for n, ep_p in ep_opt_model.named_parameters(): + p = name_to_p[n] + if ep_p.grad is not None: + assert_close(p.grad, ep_p.grad, rtol=rtol, atol=atol) + + dist.destroy_process_group() \ No newline at end of file diff --git a/toolbox/ColossalAI/v0.4.4/patches/build_colossalai.sh b/toolbox/ColossalAI/v0.4.4/patches/build_colossalai.sh new file mode 100644 index 0000000000000000000000000000000000000000..dec423600bd3dcc1a9f9bfbcbd88d195646ad867 --- /dev/null +++ b/toolbox/ColossalAI/v0.4.4/patches/build_colossalai.sh @@ -0,0 +1,43 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +#!/bin/bash + +COREX_VERSION=${COREX_VERSION:-latest} +MAX_JOBS=${MAX_JOBS:-$(nproc --all)} +PYTHON_PATH=$(which python3) +${PYTHON_PATH} -c "import torch;print(torch.__version__)" || { + echo "ERROR: building vision requries torch has been installed." + exit 1 +} +PY_VERSION=`${PYTHON_PATH} -V 2>&1|awk '{print $2}'|awk -F '.' '{print $2}'` +OS_ID=$(awk -F= '/^ID=/{print $2}' /etc/os-release | tr -d '"') + +pip3 install -r requirements/requirements.txt + +# ${PYTHON_PATH} -m pip install -r requirements_dev.txt || exit + +if [[ "${COREX_VERSION}" == "latest" || -z "${COREX_VERSION}" ]]; then + COREX_VERSION=`date --utc +%Y%m%d%H%M%S` +fi +export COLOSSALAI_LOCAL_VERSION_IDENTIFIER="corex.${COREX_VERSION}" + +export MAX_JOBS=${MAX_JOBS} + +${PYTHON_PATH} setup.py build 2>&1 | tee compile.log; [[ ${PIPESTATUS[0]} == 0 ]] || exit + +${PYTHON_PATH} setup.py bdist_wheel -d build_pip || exit + +# Return 0 status if all finished +exit 0 \ No newline at end of file diff --git a/toolbox/ColossalAI/v0.4.4/patches/clean_colossalai.sh b/toolbox/ColossalAI/v0.4.4/patches/clean_colossalai.sh new file mode 100644 index 0000000000000000000000000000000000000000..782addc14226f6efb0167579ca38acb9faf4f1cc --- /dev/null +++ b/toolbox/ColossalAI/v0.4.4/patches/clean_colossalai.sh @@ -0,0 +1,25 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +#!/bin/bash + +PYTHON_PATH=$(which python3) + +rm -rf build +${PYTHON_PATH} setup.py clean || true +rm -rf build_pip +rm -rf ipex.egg-info +rm -rf colossalai/git_version_info_installed.py +# Return 0 status if all finished +exit 0 \ No newline at end of file diff --git a/toolbox/ColossalAI/v0.4.4/patches/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py b/toolbox/ColossalAI/v0.4.4/patches/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py new file mode 100644 index 0000000000000000000000000000000000000000..65563aae16639658658e54d669eac1475ba6cfeb --- /dev/null +++ b/toolbox/ColossalAI/v0.4.4/patches/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py @@ -0,0 +1,80 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from abc import ABC, abstractmethod +from typing import Dict + +import torch +from torch import Tensor + +from colossalai.accelerator import get_accelerator +from colossalai.logging import get_dist_logger + +__all__ = ["BaseGradScaler"] + + +class BaseGradScaler(ABC): + """A base class for the gradient scaler. + + Args: + initial_scale (float): the initial loss scale + verbose (bool): whether to log messages + """ + + def __init__(self, initial_scale: float, verbose: bool): + assert initial_scale > 0 + self._scale = torch.tensor([initial_scale], device=get_accelerator().get_current_device(), dtype=torch.float) + self._verbose = verbose + + if self._verbose: + self._logger = get_dist_logger() + + @property + def scale(self) -> Tensor: + """Returns the loss scale.""" + + return self._scale + + @property + def inv_scale(self) -> Tensor: + """Returns the inverse of the loss scale.""" + + return self._scale.reciprocal().float() + + def state_dict(self) -> Dict: + """Returns the states of the gradient scaler as a dict object.""" + + state_dict = dict() + state_dict["scale"] = self.scale + return state_dict + + def load_state_dict(self, state_dict: Dict) -> None: + """Load the states of the gradient scaler from a dict object. + + Args: + state_dict (dict): the states of the gradient scaler + """ + + self._scale = state_dict["scale"] + + @abstractmethod + def update(self, overflow: bool) -> None: + """Update the loss scale. + + Args: + overflow (bool): whether overflow occurs + """ + + def log(self, message, *args, **kwargs): + """Log messages. + + Args: + message (str): the message to log + *args: positional arguments for :class:`colossalai.logging.DistributedLogger` + **kwargs: key-word arguments for :class:`colossalai.logging.DistributedLogger` + """ + + if self._verbose: + self._logger.info(message, *args, **kwargs) diff --git a/toolbox/ColossalAI/v0.4.4/patches/colossalai/amp/naive_amp/mixed_precision_optimizer.py b/toolbox/ColossalAI/v0.4.4/patches/colossalai/amp/naive_amp/mixed_precision_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..78a624d191ca8f1a09d9cf337691580a91f99e91 --- /dev/null +++ b/toolbox/ColossalAI/v0.4.4/patches/colossalai/amp/naive_amp/mixed_precision_optimizer.py @@ -0,0 +1,217 @@ +#!/usr/bin/env python3 +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +from typing import Dict, List, Tuple + +import torch +from torch import Tensor, inf +from torch.nn import Module, Parameter +from torch.optim import Optimizer + +from colossalai.interface import OptimizerWrapper + +from .mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin + + +class NaiveFP16MixedPrecisionMixin(FP16MixedPrecisionMixin): + def __init__( + self, + working_params: List[Parameter], + initial_scale: float = 2**16, + min_scale: float = 1, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + max_scale: float = 2**32, + ) -> None: + super().__init__( + initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis, max_scale + ) + self.params = working_params + + def check_local_overflow(self) -> bool: + for p in self.params: + if p.grad is not None and not torch.isfinite(p.grad).all(): + return True + return False + + +class MixedPrecisionOptimizer(OptimizerWrapper): + def __init__( + self, + optim: Optimizer, + precision: str = "fp16", + initial_scale: float = 2**16, + min_scale: float = 1, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + max_scale: float = 2**32, + max_norm: float = 0.0, + ): + super().__init__(optim) + if precision == "fp16": + working_params = [] + for group in self.optim.param_groups: + for p in group["params"]: + working_params.append(p) + self.mixed_precision = NaiveFP16MixedPrecisionMixin( + working_params, + initial_scale=initial_scale, + min_scale=min_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + max_scale=max_scale, + ) + elif precision == "bf16": + self.mixed_precision = BF16MixedPrecisionMixin() + else: + raise ValueError(f"Unsupported precision: {precision}") + self.max_norm = max_norm + self.working_to_master_map: Dict[Parameter, Tensor] = {} + self.master_to_working_map: Dict[Tensor, Parameter] = {} + + # create master weights + for group in self.optim.param_groups: + master_params = [] + for p in group["params"]: + if p.requires_grad: + master_p = p + if p.dtype != torch.float: + master_p = p.detach().float() + self.working_to_master_map[p] = master_p + self.master_to_working_map[master_p] = p + master_params.append(master_p) + group["params"] = master_params + + def backward(self, loss: Tensor, *args, **kwargs): + loss = self.mixed_precision.pre_backward(loss) + loss.backward(*args, **kwargs) + + def backward_by_grad(self, tensor: Tensor, grad: Tensor): + grad = self.mixed_precision.pre_backward_by_grad(tensor, grad) + tensor.backward(grad) + + def zero_grad(self, *args, **kwargs): + for p in self.working_to_master_map.keys(): + p.grad = None + self.mixed_precision.pre_zero_grad() + return super().zero_grad(*args, **kwargs) + + def _unscale_and_clip_grads(self, total_norm: float) -> None: + """ + Unscale and clip gradients before performing the optimization step. + + Args: + total_norm (float): The computed total gradient norm. + + Returns: + None + """ + div_scale = 1.0 + + # If mixed-precision training is used, get the gradient division scale from the mixed-precision handler. + if self.mixed_precision is not None: + div_scale = self.mixed_precision.get_grad_div_scale() + + if self.max_norm > 0.0: + # Calculate the scaling factor for gradient clipping + # The gradient norm is scaled by 'div_scale' and then clipped to 'max_norm' + clip = ((total_norm / div_scale) + 1e-6) / self.max_norm + + # If the clip factor exceeds 1, adjust 'div_scale' accordingly to ensure clipping + if clip > 1: + div_scale = clip * div_scale + + # Apply the scaling factor to gradients + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + p.grad.data.mul_(1.0 / div_scale) + + def _compute_grad_norm(self, param_gradient_pairs: List[Tuple[Tensor]], norm_type: int = 2) -> int: + r""" + Compute and return the gradient norm for gradient clipping. + + Args: + param_gradient_pairs (List[Tuple[Tensor]]): List of (parameter, gradient) pairs; gradients are used for norm calculation. + norm_type (int, optional): Type of the norm used (e.g., 2 for L2 norm). Defaults to 2. + + Returns: + float: The total norm of the given gradients. + """ + + if len(param_gradient_pairs) == 0: + return 0.0 + + # gradients used for norm calculation. + gradients = [grad for param, grad in param_gradient_pairs] + + if norm_type == inf: + total_norm = max(grad.data.abs().max() for grad in gradients) + + else: + total_norm_exponentiated = 0.0 + for grad in gradients: + total_norm_exponentiated += grad.data.float().norm(norm_type) ** norm_type + total_norm = total_norm_exponentiated ** (1.0 / norm_type) + + return total_norm + + def step(self, *args, **kwargs): + if self.mixed_precision.should_skip_step(): + self.zero_grad() + return + # prepare grads + for group in self.optim.param_groups: + for p in group["params"]: + working_param = self.master_to_working_map[p] + if p is working_param: + continue + if working_param.grad is not None: + p.grad = working_param.grad.data.float() + working_param.grad = None + + # gradient unscale and clip. + if self.max_norm <= 0: + # no need to compute gradient norm. + total_norm = 0.0 + else: + # compute the total norm. + param_gradient_pairs = [ + (self.master_to_working_map[p], p.grad) + for group in self.param_groups + for p in group["params"] + if p.grad is not None + ] + total_norm = self._compute_grad_norm(param_gradient_pairs) + self._unscale_and_clip_grads(total_norm) + + self.optim.step(*args, **kwargs) + # update working params + for group in self.optim.param_groups: + for p in group["params"]: + working_param = self.master_to_working_map[p] + if p is working_param: + continue + working_param.data.copy_(p.data) + + def update_master_params(self, model: Module): + # Update master params from working params + with torch.no_grad(): + for p in model.parameters(): + if (p is None) or (p not in self.working_to_master_map): + continue + master_param = self.working_to_master_map[p] + master_param.data.copy_(p.data) + + def get_working_to_master_map(self) -> Dict[int, torch.Tensor]: + return {id(working_p): master_p for working_p, master_p in self.working_to_master_map.items()} + + def get_master_to_working_map(self) -> Dict[int, torch.Tensor]: + return {id(master_p): working_p for master_p, working_p in self.master_to_working_map.items()} diff --git a/toolbox/ColossalAI/v0.4.4/patches/colossalai/booster/plugin/hybrid_parallel_plugin.py b/toolbox/ColossalAI/v0.4.4/patches/colossalai/booster/plugin/hybrid_parallel_plugin.py new file mode 100644 index 0000000000000000000000000000000000000000..d098ed9d602ae0501a79bb25f12b0436431ebe0f --- /dev/null +++ b/toolbox/ColossalAI/v0.4.4/patches/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -0,0 +1,1498 @@ +#!/usr/bin/env python3 +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +import ctypes +import random +from collections import defaultdict +from contextlib import contextmanager, nullcontext +from copy import deepcopy +from functools import partial +from types import MethodType +from typing import Any, Callable, Dict, Iterator, List, Optional, OrderedDict, Tuple, Union + +import numpy as np +import torch +import torch.distributed as dist +from torch import Tensor, inf +from torch.distributed import ProcessGroup, get_world_size +from torch.nn import Module, SyncBatchNorm +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler as LRScheduler +from torch.utils._pytree import tree_map +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler + +from colossalai.accelerator import get_accelerator +from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOptimizer +from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO +from colossalai.cluster import ProcessGroupMesh +from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper +from colossalai.interface.optimizer import DistributedOptim +from colossalai.logging import get_dist_logger +from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed +from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.quantization import BnbQuantizationConfig, quantize_model +from colossalai.quantization.fp8_hook import FP8Hook +from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer +from colossalai.shardformer.layer.utils import SeqParallelUtils, is_share_sp_tp +from colossalai.shardformer.policies.base_policy import Policy +from colossalai.tensor.colo_parameter import ColoParameter +from colossalai.tensor.d_tensor.api import is_distributed_tensor +from colossalai.tensor.param_op_hook import ColoParamOpHookManager +from colossalai.zero.low_level import LowLevelZeroOptimizer +from colossalai.zero.low_level.zero_hook import ZeroOpHook, wait_all_gather_handle + +from .pp_plugin_base import PipelinePluginBase + +SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all", "ring_attn"] + +PRECISION_TORCH_TYPE = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16} + + +def _convert_floating_point(x, dtype: torch.dtype = torch.float16): + if isinstance(x, torch.Tensor) and torch.is_floating_point(x): + return x.to(dtype) + return x + + +class HybridParallelModule(ModelWrapper, AMPModelMixin): + def __init__( + self, + module: Module, + precision: str, + shard_config: ShardConfig, + dp_group: ProcessGroup, + tp_group: ProcessGroup, + sp_group: ProcessGroup, + use_ddp: bool, + ddp_config: dict, + custom_policy: Policy, + overlap_allgather: bool = False, + use_fp8: bool = False, + ) -> None: + self.stage_manager = shard_config.pipeline_stage_manager + self.shard_config = shard_config + self.dp_group = dp_group + self.tp_group = tp_group + self.sp_group = sp_group + self.use_ddp = use_ddp + self.require_grad_sync = True + self.overlap_allgather = overlap_allgather + self.use_fp8 = use_fp8 + + shardformer = ShardFormer(shard_config) + if custom_policy is not None: + assert isinstance(custom_policy, object) + module, self.shared_params = shardformer.optimize(module, policy=custom_policy) + + # setting process groups for shared parameters + self.shared_param_process_groups = [] + for shared_param in self.shared_params: + if len(shared_param) > 0: + self.shared_param_process_groups.append( + self.stage_manager.init_process_group_by_stages(list(shared_param.keys())) + ) + + # setting mixed_precision + self.mixed_precision = None + if precision == "fp16": + self.mixed_precision = torch.float16 + elif precision == "bf16": + self.mixed_precision = torch.bfloat16 + if self.mixed_precision is not None: + module = module.to(self.mixed_precision) + module = module.to(get_accelerator().get_current_device()) + + # setting input type cast when using mixed precision + self.convert_fn = None + if self.mixed_precision is not None: + self.convert_fn = partial(_convert_floating_point, dtype=self.mixed_precision) + + # setting ddp configs + if use_ddp: + # convert model to sync bn + module = SyncBatchNorm.convert_sync_batchnorm(module, dp_group) + # wrap the model with PyTorch DDP + module = DDP(module, process_group=dp_group, **ddp_config) + + super().__init__(module) + self.op_hooks = [] + if use_fp8: + self.op_hooks.append(FP8Hook()) + if overlap_allgather: + self.op_hooks.append(ZeroOpHook()) + if use_fp8 or overlap_allgather: + for p in module.parameters(): + if p.requires_grad and type(p) is not ColoParameter: + p.__class__ = ColoParameter + p.__init__(p, requires_grad=True) + + def sync_shared_params(self): + for shared_param, group in zip(self.shared_params, self.shared_param_process_groups): + if self.stage_manager.stage in shared_param: + param = shared_param[self.stage_manager.stage] + dist.all_reduce(param.grad, group=group) + dist.barrier() + + @contextmanager + def no_sync(self): + r""" + A context manager to disable automatic gradient synchronization (all-reduce) and allow manual synchronization + when 'no_sync' is active. Alternatively, synchronization will occur in the first forward-backward pass + when exiting the context. + """ + + # Store the current value of 'require_grad_sync' to restore it later. + old_require_grad_sync = self.require_grad_sync + # Disable automatic gradient synchronization. + self.require_grad_sync = False + try: + if self.use_ddp: + # If using data parallel processing (use_ddp), disable synchronization too. + with self.module.no_sync(): + yield + else: + yield + finally: + # Restore the original value of 'require_grad_sync'. + self.require_grad_sync = old_require_grad_sync + + def sync_dp_grads(self): + r""" + Synchronize gradients across data parallelism (DP) if the DP group size is greater than 1. + This function performs an all-reduce operation to combine gradients from different devices in the DP group. + + Args: + None + + Returns: + None + """ + + # Check if the DP group size is 1, meaning no synchronization is needed. + if self.dp_group.size() == 1: + return + + # Iterate through the model's parameters and perform gradient synchronization. + for p in self.module.parameters(): + if p.grad is not None: + # Perform all-reduce to combine gradients from different devices. + dist.all_reduce(p.grad, group=self.dp_group) + # Normalize the gradient by dividing it by the DP group size. + p.grad.div_(self.dp_group.size()) + + def sync_sp_grads(self, grads: Optional[List[torch.Tensor]] = None): + r""" + Synchronize gradients that are partially derived within sequence parallelism + if sequence parallelism is enabled. Gradients can be provided explicitly or extracted + from the module. + + Args: + grads (Optional[List[torch.Tensor]]): A list of gradient tensors to synchronize. If not + provided, gradients will be extracted from the model. + + Returns: + None + """ + + if self.shard_config.enable_sequence_parallelism: + if self.shard_config.sequence_parallelism_mode in ["all_to_all", "ring_attn"]: + return + + if self.shard_config.sequence_parallelism_mode in ["split_gather", "ring"]: + # If sequence parallelism is enabled and mode is split_gather or ring, gradients are synchronized + # across the tensor parallelism group. + group = self.tp_group + else: + raise ValueError(f"Unknown sequence parallelism mode: {self.shard_config.sequence_parallelism_mode}") + + if grads is not None: + # Synchronize provided gradient tensors across the tensor parallelism group. + SeqParallelUtils.allreduce_partial_data_grad(process_group=group, grads=grads) + else: + # Synchronize gradients from the model across the tensor parallelism group. + SeqParallelUtils.allreduce_partial_data_grad(process_group=group, model=self.module) + + def forward(self, *args, **kwargs): + if self.convert_fn is not None: + args = tree_map(self.convert_fn, args) + kwargs = tree_map(self.convert_fn, kwargs) + with self._hook_context(): + return super().forward(*args, **kwargs) + + def unwrap(self): + module = super().unwrap() + if isinstance(module, DDP): + module = module.module + return module + + def _force_wait_all_gather(self): + for p in self.module.parameters(): + wait_all_gather_handle(p) + + def _hook_context(self): + return ColoParamOpHookManager.use_hooks(*self.op_hooks) if len(self.op_hooks) > 0 else nullcontext() + + +def get_param_info(optim: Optimizer): + # Get a backup of necessary information of parameters for future use, which includes: + # 1. A complete param_group, with params in the form of param_id + # 2. A mapping from param address (obtained using id(param)) to integer param_id + # 3. A mapping from integer param_id to param address. + # 4. A mapping from param_address (obtained using id(param)) to the original shape of parameter before sharding. + # When Zero is used, the params here are fp16/bf16 model params rather than fp32 master params in optimizer. + + if optim is None: + return {} + param_info = {"param_groups": [], "param2id": {}, "id2param": {}, "param2shape": {}} + start_index = 0 + for group in optim.param_groups: + packed_group = {k: v for k, v in group.items() if k != "params"} + packed_group["params"] = [] + + for param_id, param in enumerate(group["params"], start_index): + original_shape = param.shape if isinstance(param, torch.Tensor) else None + packed_group["params"].append(param_id) + param_info["param2id"][id(param)] = param_id + param_info["id2param"][param_id] = id(param) + param_info["param2shape"][id(param)] = original_shape + + param_info["param_groups"].append(packed_group) + start_index += len(group["params"]) + + return param_info + + +def reinitialize_optimizer(optim: Optimizer, model: Module): + model_params = set(model.parameters()) + new_param_groups = [] + for group in optim.param_groups: + params = [p for p in group["params"] if p in model_params] + new_param_groups.append({**group, "params": params}) + optim.__setstate__({"param_groups": new_param_groups}) + + +class HybridParallelNaiveOptimizer(OptimizerWrapper): + def __init__( + self, + optim: Optimizer, + model: HybridParallelModule, + use_pipeline: bool, + param_info: OrderedDict, + max_norm: float = 0, + tp_process_group: Optional[ProcessGroup] = None, # if using tp + pp_process_group: Optional[ProcessGroup] = None, # if using pp + ): + self.param_info = param_info + if use_pipeline: + reinitialize_optimizer(optim, model) + self.model = model + self.stage_manager = model.stage_manager + self.shared_params = model.shared_params + self.max_norm = max_norm + self.tp_pg = tp_process_group + self.pp_pg = pp_process_group + self.tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1 + self.pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1 + super().__init__(optim) + + def backward(self, loss: Tensor, *args, **kwargs): + r""" + Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients. + + This method performs backward pass for gradient computation. If sequence parallelism is enabled + and gradient synchronization is required, it will synchronize gradients that are partially derived + within sequence parallelism across tp parallelism groups. + + Args: + loss (Tensor): The loss tensor to compute gradients with respect to. + *args: Additional positional arguments to be passed to the superclass backward method. + **kwargs: Additional keyword arguments to be passed to the superclass backward method. + + Returns: + None + """ + + # Call the superclass backward method to compute gradients. + with self.model._hook_context(): + super().backward(loss, *args, **kwargs) + + if self.model.require_grad_sync: + # If gradient synchronization is required, sync sequence parallelism gradients. + self.model.sync_sp_grads() + else: + # If gradient synchronization is is not required, return. + return + + def backward_by_grad(self, tensor: Tensor, grad: Tensor): + """ + Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients. + + This method performs a backward pass for gradient computation using a precomputed gradient tensor. + If sequence parallelism is enabled and gradient synchronization is required, it will synchronize + gradients that are partially derived within sequence parallelism across tp parallelism groups. + + Args: + tensor (Tensor): The input tensor for which gradients are computed. + grad (Tensor): The precomputed gradient tensor to compute gradients with respect to the input tensor. + + Returns: + None + """ + + # Call the superclass backward method to compute gradients. + super().backward_by_grad(tensor, grad) + + if self.model.require_grad_sync: + # If gradient synchronization is required, sync sequence parallelism gradients. + self.model.sync_sp_grads() + else: + # If gradient synchronization is is not required, return. + return + + def step(self, *args, **kwargs): + r""" + Perform an optimization step. + + Args: + *args: Variable-length positional arguments to be passed to the optimizer's step function. + **kwargs: Keyword arguments to be passed to the optimizer's step function. + """ + + if self.max_norm > 0: + # Compute the total gradient norm. + param_gradient_pairs = [ + (p, p.grad) for group in self.optim.param_groups for p in group["params"] if p.grad is not None + ] + total_norm = self._compute_grad_norm(param_gradient_pairs) + + # Clip the gradients to prevent exploding gradients. + self._clip_grad_norm(total_norm) + + # Perform the optimization step using the underlying optimizer. + self.optim.step(*args, **kwargs) + + def _compute_grad_norm(self, param_gradient_pairs: List[Tuple[Tensor]], norm_type: int = 2) -> int: + r""" + Compute and return the gradient norm for gradient clipping. + + Args: + param_gradient_pairs (List[Tuple[Tensor]]): List of (parameter, gradient) pairs; gradients are used for norm calculation. + norm_type (int, optional): Type of the norm used (e.g., 2 for L2 norm). Defaults to 2. + + Returns: + float: The total norm of the given gradients. + """ + + if len(param_gradient_pairs) == 0: + return 0.0 + + norm_type = float(norm_type) + + # gradients used for norm calculation. + gradients = [grad for param, grad in param_gradient_pairs] + + if norm_type == inf: + total_norm = max(grad.data.abs().max() for grad in gradients) + total_norm_cuda = torch.tensor( + [float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float32 + ) + if self.tp_size > 1: + dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg) + if self.pp_size > 1: + dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.pp_pg) + total_norm = total_norm_cuda.item() + else: + # gradients used for norm calculation. + gradients = [grad for param, grad in param_gradient_pairs] + # grad_to_param_mapping is used to check which gradients are not distributed across devices of the 'tp_group'. + grad_to_param_mapping = {id(grad): param for param, grad in param_gradient_pairs} + + total_norm_exponentiated = 0.0 + for grad in gradients: + grad_norm_exponentiated = grad.data.float().norm(norm_type) ** norm_type + + # If 'tp_size' is greater than 1 and the parameter for the gradient is not a distributed tensor, + # it indicates that the parameter is not distributed across devices of the 'tp_group'. + # Consequently, there is no need to perform an 'all_reduce' operation for 'grad_norm'. + # However, we still perform the 'all_reduce' operation for the sake of good coding practices. + # To ensure mathematical equivalence, we divide the 'grad_norm' by 'tp_size.' + if self.tp_size > 1: + param_for_grad = grad_to_param_mapping[id(grad)] + if not is_distributed_tensor(param_for_grad): + grad_norm_exponentiated /= self.tp_size + + # If 'pp_size' is greater than 1 and the gradient belongs to shared parameters, + # it means that this parameter is used in two different pipeline stages. + # To avoid redundant norm calculations, we divide the exponent of this norm by + # the number of shared stages. + if self.pp_size > 1: + for shared_param in self.shared_params: + if self.stage_manager.stage in shared_param: + stage_shared_param = shared_param[self.stage_manager.stage] + if grad is stage_shared_param.grad: + grad_norm_exponentiated /= len(shared_param) + + total_norm_exponentiated += grad_norm_exponentiated + + total_norm_exponentiated_cuda = torch.tensor( + [float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float32 + ) + if self.tp_size > 1: + # compute norm in tp process group + dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg) + if self.pp_size > 1: + # compute norm in pp process group + dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.pp_pg) + + # compute the total_norm + total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type) + + return total_norm + + def _clip_grad_norm(self, total_norm: float) -> None: + r""" + Clips the gradients of the model's parameters to prevent exploding gradients. + + Args: + total_norm (float): The computed total gradient norm. + + Returns: + None + """ + clip_coef = torch.tensor(self.max_norm / (total_norm + 1e-6)) + clip_coef_clamped = torch.clamp(clip_coef, max=1.0) + + for group in self.optim.param_groups: + for p in group["params"]: + if p.grad is None: + continue + p.grad.data.mul_(clip_coef_clamped) + + def update_master_params(self, model: Module): + pass + + def get_working_to_master_map(self): + return None + + def get_master_to_working_map(self): + return None + + +class HybridParallelAMPOptimizer(MixedPrecisionOptimizer): + def __init__( + self, + optim: Optimizer, + model: HybridParallelModule, + use_pipeline: bool, + param_info: OrderedDict, + precision: str = "fp16", + initial_scale: float = 2**16, + min_scale: float = 1, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + max_scale: float = 2**32, + max_norm: float = 0, + tp_process_group: Optional[ProcessGroup] = None, # if using tp + pp_process_group: Optional[ProcessGroup] = None, # if using pp + ): + self.model = model + self.param_info = param_info + self.stage_manager = model.stage_manager + self.shared_params = model.shared_params + self.tp_pg = tp_process_group + self.pp_pg = pp_process_group + self.tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1 + self.pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1 + if use_pipeline: + reinitialize_optimizer(optim, model) + super().__init__( + optim, + precision=precision, + initial_scale=initial_scale, + min_scale=min_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + max_scale=max_scale, + max_norm=max_norm, + ) + + def backward(self, loss: Tensor, *args, **kwargs): + r""" + Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients. + + This method performs backward pass for gradient computation. If sequence parallelism is enabled + and gradient synchronization is required, it will synchronize gradients that are partially derived + within sequence parallelism across tp parallelism groups. + + Args: + loss (Tensor): The loss tensor to compute gradients with respect to. + *args: Additional positional arguments to be passed to the superclass backward method. + **kwargs: Additional keyword arguments to be passed to the superclass backward method. + + Returns: + None + """ + # Call the superclass backward method to compute gradients. + with self.model._hook_context(): + super().backward(loss, *args, **kwargs) + + if self.model.require_grad_sync: + # If gradient synchronization is required, sync sequence parallelism gradients. + self.model.sync_sp_grads() + else: + # If gradient synchronization is is not required, return. + return + + def backward_by_grad(self, tensor: Tensor, grad: Tensor): + """ + Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients. + + This method performs a backward pass for gradient computation using a precomputed gradient tensor. + If sequence parallelism is enabled and gradient synchronization is required, it will synchronize + gradients that are partially derived within sequence parallelism across tp parallelism groups. + + Args: + tensor (Tensor): The input tensor for which gradients are computed. + grad (Tensor): The precomputed gradient tensor to compute gradients with respect to the input tensor. + + Returns: + None + """ + # Call the superclass backward method to compute gradients. + super().backward_by_grad(tensor, grad) + + if self.model.require_grad_sync: + # If gradient synchronization is required, sync sequence parallelism gradients. + self.model.sync_sp_grads() + else: + # If gradient synchronization is is not required, return. + return + + def _compute_grad_norm(self, param_gradient_pairs: List[Tuple[Tensor]], norm_type: int = 2) -> int: + r""" + Compute and return the gradient norm for gradient clipping. + + Args: + param_gradient_pairs (List[Tuple[Tensor]]): List of (parameter, gradient) pairs; gradients are used for norm calculation. + norm_type (int, optional): Type of the norm used (e.g., 2 for L2 norm). Defaults to 2. + + Returns: + float: The total norm of the given gradients. + """ + if len(param_gradient_pairs) == 0: + return 0.0 + + norm_type = float(norm_type) + + if norm_type == inf: + # The parent class calculates the norm of 'dp' gradients, + # so we need to calculate the norm of 'tp' and 'pp' gradients. + total_norm = super()._compute_grad_norm(param_gradient_pairs, norm_type) + + total_norm_cuda = torch.tensor( + [float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float32 + ) + + if self.tp_size > 1: + dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg) + if self.pp_size > 1: + dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.pp_pg) + + total_norm = total_norm_cuda.item() + + else: + # gradients used for norm calculation. + gradients = [grad for param, grad in param_gradient_pairs] + # grad_to_param_mapping is used to check which gradients are not distributed in tensor parallelism. + grad_to_param_mapping = {id(grad): param for param, grad in param_gradient_pairs} + + total_norm_exponentiated = 0.0 + for grad in gradients: + grad_norm_exponentiated = grad.data.float().norm(norm_type) ** norm_type + + # If 'tp_size' is greater than 1 and the parameter for the gradient is not a distributed tensor, + # it indicates that the parameter is not distributed across devices of the 'tp_group'. + # Consequently, there is no need to perform an 'all_reduce' operation for 'grad_norm'. + # However, we still perform the 'all_reduce' operation for the sake of good coding practices. + # To ensure mathematical equivalence, we divide the 'grad_norm' by 'tp_size.' + if self.tp_size > 1: + param_for_grad = grad_to_param_mapping[id(grad)] + if not is_distributed_tensor(param_for_grad): + grad_norm_exponentiated /= self.tp_size + + # If 'pp_size' is greater than 1 and the gradient belongs to shared parameters, + # it means that this parameter is used in two different pipeline stages. + # To avoid redundant norm calculations, we divide the exponent of this norm by + # the number of shared stages. + if self.pp_size > 1: + for shared_param in self.shared_params: + if self.stage_manager.stage in shared_param: + stage_working_shared_param = shared_param[self.stage_manager.stage] + stage_master_shared_param = self.working_to_master_map[stage_working_shared_param] + if grad is stage_master_shared_param.grad: + grad_norm_exponentiated /= len(shared_param) + + total_norm_exponentiated += grad_norm_exponentiated + + total_norm_exponentiated_cuda = torch.tensor( + [float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float32 + ) + if self.tp_size > 1: + # compute norm in tp process group + dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg) + if self.pp_size > 1: + # compute norm in pp process group + dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.pp_pg) + + # compute the total_norm + total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type) + + return total_norm + + +class HybridParallelZeroOptimizer(LowLevelZeroOptimizer): + def __init__( + self, + optimizer: Optimizer, + model: HybridParallelModule, + use_pipeline: bool, + param_info: OrderedDict, + pg_to_param_list: Dict[ProcessGroup, List[torch.nn.Parameter]] = None, + initial_scale: int = 2**16, # grad scaler config + min_scale: int = 1, + growth_factor: float = 2.0, + backoff_factor: float = 0.5, + growth_interval: int = 2000, + hysteresis: int = 2, + max_scale: int = 2**24, + clip_grad_norm: float = 0.0, # grad clipping + verbose: bool = False, + reduce_bucket_size: int = 1024 * 1024, # communication + communication_dtype: Optional[torch.dtype] = None, + overlap_communication: bool = True, + partition_grad: bool = False, # stage 2 flag + cpu_offload: bool = False, # cpu offload + dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm + tp_process_group: Optional[ProcessGroup] = None, # if using tp + pp_process_group: Optional[ProcessGroup] = None, # if using pp + forced_dtype: Optional[torch.dtype] = None, + overlap_allgather: bool = False, + fp8_communication: bool = False, + ): + self.model = model + self.param_info = param_info + self.stage_manager = model.stage_manager + self.shared_params = model.shared_params + self.tp_pg = tp_process_group + self.pp_pg = pp_process_group + if use_pipeline: + reinitialize_optimizer(optimizer, model) + super().__init__( + optimizer=optimizer, + initial_scale=initial_scale, + min_scale=min_scale, + pg_to_param_list=pg_to_param_list, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + max_scale=max_scale, + clip_grad_norm=clip_grad_norm, + verbose=verbose, + reduce_bucket_size=reduce_bucket_size, + communication_dtype=communication_dtype, + overlap_communication=overlap_communication, + partition_grad=partition_grad, + cpu_offload=cpu_offload, + dp_process_group=dp_process_group, + forced_dtype=forced_dtype, + overlap_allgather=overlap_allgather, + fp8_communication=fp8_communication, + backward_context=model._hook_context, + ) + + def sync_dp_grads(self): + r""" + Synchronize gradients in the data parallelism dimension. + + This method wraps the existing `_sync_grad` method in order to explicitly synchronize gradients + in the data parallelism dimension. It is necessary due to the introduction of new parallel dimensions, + namely tp (tensor parallelism) and pp (pipeline parallelism). This ensures better code organization + and readability. + + Args: + None + + Returns: + None + """ + # Call the superclass `_sync_grad` method to synchronize gradients. + super()._sync_grad() + + def _sync_sp_grads(self): + r""" + Synchronize gradients that are partially derived within sequence parallelism. + + This method is responsible for synchronizing partially derived gradients across tp parallelism groups. + It identifies gradients that ara partially derived or not and synchronizes them. + If synchronization is required and gradients are found to be synchronized, + it performs the synchronization. + + Args: + None + + Returns: + None + """ + + def _get_all_working_grads() -> List[Tensor]: + """Retrieve all working gradients from different parameter groups.""" + all_working_grads = [] + for group_id in range(self.num_param_groups): + working_grads = self.get_working_grads_by_group_id(group_id) + all_working_grads.extend(working_grads) + return all_working_grads + + def _get_grads_to_sync(all_working_grads) -> Union[List[Tensor], None]: + """Identify gradients to be synchronized in the sequence parallelism.""" + grads_to_sync = [] + for grad in all_working_grads: + param_id_for_grad = self.get_param_id_for_grad(grad) + param_for_grad = ctypes.cast(param_id_for_grad, ctypes.py_object).value + if SeqParallelUtils.is_sp_partial_derived_param(param_for_grad): + grads_to_sync.append(grad) + + if len(grads_to_sync) > 0: + return grads_to_sync + else: + return None + + # Get all working gradients and gradients to be synchronized. + all_working_grads = _get_all_working_grads() + grads_to_sync = _get_grads_to_sync(all_working_grads) + if self.require_grad_sync and grads_to_sync is not None: + # Synchronize sequence parallelism gradients if required. + SeqParallelUtils.allreduce_partial_data_grad(process_group=self.tp_pg, grads=grads_to_sync) + else: + return + + def backward(self, loss, retain_graph=False): + """ + Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients. + + This method performs the backward pass for gradient computation based on a given loss tensor. + If sequence parallelism is enabled and gradient synchronization is required, it will synchronize + gradients that are partially derived within sequence parallelism across TP parallelism groups. + + Args: + loss: The loss tensor to compute gradients with respect to. + retain_graph (bool): Whether to retain the computation graph. + + Returns: + None + """ + # Call the superclass backward method to compute gradients. + super().backward(loss, retain_graph) + + if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism: + # If gradient synchronization is required, sync sequence parallelism gradients. + self._sync_sp_grads() + else: + # If gradient synchronization is is not required, return. + return + + def backward_by_grad(self, tensor, grad): + """ + Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients. + + This method performs a backward pass for gradient computation based on a precomputed gradient tensor. + If sequence parallelism is enabled and gradient synchronization is required, it will synchronize + gradients that are partially derived within sequence parallelism across TP parallelism groups. + + Args: + tensor: The input tensor for which gradients are computed. + grad: The precomputed gradient tensor to compute gradients with respect to the input tensor. + + Returns: + None + """ + # Call the superclass backward_by_grad method to compute gradients. + super().backward_by_grad(tensor, grad) + + if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism: + # If gradient synchronization is required, sync sequence parallelism gradients. + self._sync_sp_grads() + else: + # If gradient synchronization is is not required, return. + return + + def _compute_grad_norm(self, dp_pg, gradients: List[Tensor], norm_type: int = 2) -> float: + r""" + Compute and return the gradient norm for gradient clipping. + + Args: + gradients (List[Tensor]): A list of tensors containing gradients. + norm_type (int, optional): Type of the p-norm to be computed. Defaults to 2. + + Returns: + float: The computed gradient norm. + """ + + # Check if the list of gradients is empty + if len(gradients) == 0: + return 0.0 + + dp_size = get_world_size(dp_pg) if dp_pg is not None else 1 + tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1 + pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1 + norm_type = float(norm_type) + + if norm_type == inf: + # The parent class calculates the norm of 'dp' gradients, + # so we only need to calculate the norm 'tp' of 'pp' gradients. + total_norm = super()._compute_grad_norm(gradients, norm_type) + + total_norm_cuda = torch.tensor( + [float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float32 + ) + + if tp_size > 1: + dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg) + if pp_size > 1: + dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.pp_pg) + + total_norm = total_norm_cuda.item() + else: + total_norm_exponentiated = 0.0 + for grad in gradients: + grad_norm_exponentiated = grad.data.float().norm(norm_type) ** norm_type + + # If 'tp_size' is greater than 1 and the parameter for the gradient is not a distributed tensor, + # it indicates that the parameter is not distributed across devices of the 'tp_group'. + # Consequently, there is no need to perform an 'all_reduce' operation for 'grad_norm'. + # However, we still perform the 'all_reduce' operation for the sake of good coding practices. + # To ensure mathematical equivalence, we divide the 'grad_norm' by 'tp_size.' + if tp_size > 1: + param_id_for_grad = self.get_param_id_for_grad(grad) + param_for_grad = ctypes.cast(param_id_for_grad, ctypes.py_object).value + + if not is_distributed_tensor(param_for_grad): + grad_norm_exponentiated /= tp_size + + # If 'pp_size' is greater than 1 and the gradient belongs to shared parameters, + # it means that this parameter is used in two different pipeline stages. + # To avoid redundant norm calculations, we divide the exponent of this norm by + # the number of shared stages. + if pp_size > 1: + for shared_param in self.shared_params: + if self.stage_manager.stage in shared_param: + stage_shared_param = shared_param[self.stage_manager.stage] + working_grad = self.get_working_grad_by_param_id(id(stage_shared_param)) + if grad is working_grad: + grad_norm_exponentiated /= len(shared_param) + + total_norm_exponentiated += grad_norm_exponentiated + + total_norm_exponentiated_cuda = torch.tensor( + [float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float32 + ) + if dp_size > 1: + # compute norm in dp process group + dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=dp_pg) + if tp_size > 1: + # compute norm in tp process group + dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg) + if pp_size > 1: + # compute norm in pp process group + dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.pp_pg) + + # Compute the 'total_norm' from 'total_norm_exponentiated' + total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type) + + return total_norm + + +class HybridParallelPlugin(PipelinePluginBase): + """ + Plugin for Hybrid Parallel Training. + Tensor parallel, pipeline parallel and data parallel(DDP/ZeRO) can be picked and combined in this plugin. + The size of tp and pp should be passed in by user, then the size of dp is automatically calculated from dp_size = world_size / (tp_size * pp_size). + + ```python + from colossalai.booster import Booster + from colossalai.booster.plugin import HybridParallelPlugin + + model, train_dataset, optimizer, criterion = ... + plugin = HybridParallelPlugin(tp_size=2, pp_size=2) + + train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8) + booster = Booster(plugin=plugin) + model, optimizer, criterion, train_dataloader, _ = booster.boost(model, optimizer, criterion, train_dataloader) + ``` + + Args: + tp_size (int): The size of tensor parallelism. Tensor parallelism will not be used when tp_size is set to 1. + pp_size (int): The number of pipeline stages in pipeline parallelism. Pipeline parallelism will not be used when pp_size is set to 1. + sp_size (int): The size of sequence parallelism. + precision (str, optional): Specifies the precision of parameters during training. + Auto-mixied precision will be used when this argument is set to 'fp16' or 'bf16', otherwise model is trained with 'fp32'. + Defaults to 'fp16'. + zero_stage (int, optional): The stage of ZeRO for data parallelism. Can only be choosed from [0, 1, 2]. + When set to 0, ZeRO will not be used. Defaults to 0. + enable_all_optimization (bool, optional): Whether to switch on all the optimizations supported by Shardformer. + Currently all the optimization methods include fused normalization, flash attention and JIT. + Defaults to False. + enable_fused_normalization (bool, optional): Whether to switch on fused normalization in Shardformer. Defaults to False. + enable_flash_attention (bool, optional): Whether to switch on flash attention in Shardformer. Defaults to False. + enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False. + enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False. + sequence_parallelism_mode (str): The Sequence parallelism mode. Can only be choosed from ["split_gather", "ring", "all_to_all"]. Defaults to "split_gather". + enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False. + parallel_output (bool): Whether to keep the output parallel when enabling tensor parallelism. Default to True. + num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None. + microbatch_size (int, optional): Microbatch size when using pipeline parallelism. + Either ``num_microbatches`` or ``microbatch_size`` should be provided if using pipeline. + If ``num_microbatches`` is provided, this will be ignored. Defaults to None. + initial_scale (float, optional): The initial loss scale of AMP. Defaults to 2**16. + min_scale (float, optional): The minimum loss scale of AMP. Defaults to 1. + growth_factor (float, optional): The multiplication factor for increasing loss scale when using AMP. Defaults to 2. + backoff_factor (float, optional): The multiplication factor for decreasing loss scale when using AMP. Defaults to 0.5. + growth_interval (int, optional): The number of steps to increase loss scale when no overflow occurs when using AMP. Defaults to 1000. + hysteresis (int, optional): The number of overflows before decreasing loss scale when using AMP. Defaults to 2. + max_scale (float, optional): The maximum loss scale of AMP. Defaults to 2**32. + max_norm (float, optional): Maximum norm for gradient clipping. Defaults to 0. + broadcast_buffers (bool, optional): Whether to broadcast buffers in the beginning of training when using DDP. Defaults to True. + ddp_bucket_cap_mb (int, optional): The bucket size in MB when using DDP. Defaults to 25. + find_unused_parameters (bool, optional): Whether to find unused parameters when using DDP. Defaults to False. + check_reduction (bool, optional): Whether to check reduction when using DDP. Defaults to False. + gradient_as_bucket_view (bool, optional): Whether to use gradient as bucket view when using DDP. Defaults to False. + static_graph (bool, optional): Whether to use static graph when using DDP. Defaults to False. + zero_bucket_size_in_m (int, optional): Gradient reduce bucket size in million elements when using ZeRO. Defaults to 12. + cpu_offload (bool, optional): Whether to open cpu_offload when using ZeRO. Defaults to False. + communication_dtype (torch.dtype, optional): Communication dtype when using ZeRO. If not specified, the dtype of param will be used. Defaults to None. + overlap_communication (bool, optional): Whether to overlap communication and computation when using ZeRO. Defaults to True. + custom_policy (Policy, optional): Custom policy for Shardformer. Defaults to None. + pp_style (str, optional): The style for pipeline parallelism. Defaults to '1f1b'. + num_model_chunks (int, optional): The number of model chunks for interleaved pipeline parallelism. Defaults to 1. + gradient_checkpoint_config (GradientCheckpointConfig, optional): Configuration for gradient checkpointing. Defaults to None. + enable_metadata_cache (bool, optional): Whether to enable metadata cache for pipeline parallelism. Defaults to True. + make_vocab_size_divisible_by (int, optional): it's used when padding the vocabulary size, to make it choose an faster kenel. Default to 64. + fp8_communication (bool, optional): Whether to enable fp8 communication. Defaults to False. + use_fp8 (bool, optional): Whether to enable fp8 mixed precision training. Defaults to False. + overlap_p2p (bool, optional): Whether to overlap the p2p communication in pipeline parallelism + inner_ring_size (int, optional): The inner ring size of 2D Ring Attention when sp mode is "ring_attn". + It's advisable to not tune this (especially in single-node settings) and let it be heuristically set based on topology by default. + + """ + + def __init__( + self, + tp_size: int, + pp_size: int, + sp_size: int = None, + precision: str = "fp16", + zero_stage: int = 0, + enable_all_optimization: bool = False, + enable_fused_normalization: bool = False, + enable_flash_attention: bool = False, + enable_jit_fused: bool = False, + enable_sequence_parallelism: bool = False, + sequence_parallelism_mode: str = None, + enable_sequence_overlap: bool = False, + parallel_output: bool = True, + num_microbatches: Optional[int] = None, + microbatch_size: Optional[int] = None, + initial_scale: float = 2**16, + min_scale: float = 1, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + max_scale: float = 2**32, + max_norm: float = 0, + broadcast_buffers: bool = True, + ddp_bucket_cap_mb: int = 25, + find_unused_parameters: bool = False, + check_reduction: bool = False, + gradient_as_bucket_view: bool = False, + static_graph: bool = False, + zero_bucket_size_in_m: int = 12, + cpu_offload: bool = False, + communication_dtype: Optional[torch.dtype] = None, + overlap_communication: bool = True, + custom_policy: Policy = None, + pp_style: str = "1f1b", + num_model_chunks: int = 1, + num_layers_per_stage: Optional[List[int]] = None, + gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None, + enable_metadata_cache: bool = True, + make_vocab_size_divisible_by: int = 64, + dp_outside: bool = True, + overlap_p2p: bool = True, + overlap_allgather: bool = False, + fp8_communication: bool = False, + use_fp8: bool = False, + inner_ring_size: int = None, + **hybrid_kwargs, + ) -> None: + super().__init__() + self.logger = get_dist_logger() + + assert ( + dist.get_world_size() % (tp_size * pp_size) == 0 + ), f"World size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}" + + if enable_sequence_parallelism: + self.sequence_parallelism_mode = ( + sequence_parallelism_mode if sequence_parallelism_mode is not None else "all_to_all" + ) + assert ( + self.sequence_parallelism_mode in SUPPORT_SP_MODE + ), f"Sequence parallelism mode {self.sequence_parallelism_mode} is not in the supported list {SUPPORT_SP_MODE}" + if self.sequence_parallelism_mode in ["split_gather", "ring"]: + assert ( + tp_size > 1 + ), f"Sequence parallelism mode {self.sequence_parallelism_mode} must be enabled when using tensor parallelism" + if sp_size != 1: + self.logger.warning( + f"The sp_size will be the same as tp_size in sequence parallelism mode {self.sequence_parallelism_mode}, will ignore the given sequence parallelism size.", + ranks=[0], + ) + self.sp_size = 1 + self.dp_size = dist.get_world_size() // (tp_size * pp_size) + elif self.sequence_parallelism_mode in ["all_to_all", "ring_attn"]: + self.sp_size = 1 if sp_size is None else sp_size + self.dp_size = dist.get_world_size() // (self.sp_size * pp_size * tp_size) + if self.sequence_parallelism_mode == "ring_attn": + enable_flash_attention = True + else: + self.dp_size = dist.get_world_size() // (tp_size * pp_size) + assert ( + sp_size == 1 or sp_size is None + ), f"You should not set sp_size when sequence parallelism is not enabled." + self.sp_size = 1 + + self.tp_size = tp_size + self.pp_size = pp_size + self.precision = precision + self.zero_stage = zero_stage + self.cpu_offload = cpu_offload + self.enable_all_optimization = enable_all_optimization + self.enable_fused_normalization = enable_fused_normalization + self.enable_flash_attention = enable_flash_attention + self.enable_jit_fused = enable_jit_fused + self.enable_sequence_parallelism = enable_sequence_parallelism + self.use_fp8 = use_fp8 + if dp_outside: + self.dp_axis, self.pp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3 + self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size) + if sequence_parallelism_mode == "ring_attn": + # Swap tp and sp since 2D Ring has better inter-node latency + self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.sp_size, self.tp_size) + self.sp_axis = 2 + self.tp_axis = 3 + else: + self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size) + else: + self.pp_axis, self.dp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3 + if sequence_parallelism_mode == "ring_attn": + self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.sp_size, self.tp_size) + self.sp_axis = 2 + self.tp_axis = 3 + else: + self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size) + + self.stage_manager = None + self.schedule = None + self.custom_policy = custom_policy + assert zero_stage in (0, 1, 2) + if self.pp_size > 1: + assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style" + assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b" + assert ( + num_microbatches is not None or microbatch_size is not None + ), "num_microbatches or microbatch_size must be specified when using pipeline parallelism" + assert ( + self.zero_stage <= 1 + ), "To avoid prohibitive gradient synchronization costs, zero stage must be 0 or 1 when using pipeline parallelism" + self.stage_manager = PipelineStageManager( + self.pg_mesh, + pipeline_axis=self.pp_axis, + enable_interleave=pp_style == "interleaved", + num_model_chunks=num_model_chunks, + num_layers_per_stage=num_layers_per_stage, + ) + + if pp_style == "interleaved": + assert num_model_chunks > 1, "number of model chunks must be > 1 when using interleaved" + self.schedule = InterleavedSchedule( + stage_manager=self.stage_manager, + num_model_chunks=num_model_chunks, + num_microbatch=num_microbatches, + microbatch_size=microbatch_size, + enable_metadata_cache=enable_metadata_cache, + overlap_p2p=overlap_p2p, + fp8_communication=fp8_communication, + ) + elif pp_style == "1f1b": + self.schedule = OneForwardOneBackwardSchedule( + stage_manager=self.stage_manager, + num_microbatches=num_microbatches, + microbatch_size=microbatch_size, + enable_metadata_cache=enable_metadata_cache, + fp8_communication=fp8_communication, + ) + else: + raise NotImplementedError() + if sequence_parallelism_mode == "ring_attn": + if not parallel_output: + self.logger.warning( + "parallel_output must be True for Zigzag Ring Attention, as we've not supported Zigzag all-gather yet.", + ranks=[0], + ) + parallel_output = True + + self.tp_group = self.pg_mesh.get_group_along_axis(self.tp_axis) + self.dp_group = self.pg_mesh.get_group_along_axis(self.dp_axis) + self.pp_group = self.pg_mesh.get_group_along_axis(self.pp_axis) + if self.enable_sequence_parallelism and self.sequence_parallelism_mode in ["split_gather", "ring"]: + self.sp_group = self.pg_mesh.get_group_along_axis(self.tp_axis) + else: + self.sp_group = self.pg_mesh.get_group_along_axis(self.sp_axis) + + self.shard_config = ShardConfig( + tensor_parallel_process_group=self.tp_group, + sequence_parallel_process_group=self.sp_group, + pipeline_stage_manager=self.stage_manager, + enable_tensor_parallelism=self.tp_size > 1, + enable_all_optimization=self.enable_all_optimization, + enable_fused_normalization=self.enable_fused_normalization, + enable_flash_attention=self.enable_flash_attention, + enable_jit_fused=self.enable_jit_fused, + enable_sequence_parallelism=enable_sequence_parallelism, + sequence_parallelism_mode=sequence_parallelism_mode, + enable_sequence_overlap=enable_sequence_overlap, + parallel_output=parallel_output, + make_vocab_size_divisible_by=make_vocab_size_divisible_by, + gradient_checkpoint_config=gradient_checkpoint_config, + fp8_communication=fp8_communication, + inner_ring_size=inner_ring_size, + **hybrid_kwargs, + ) + self.amp_config = dict( + initial_scale=initial_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + min_scale=min_scale, + max_scale=max_scale, + ) + + self.ddp_config = dict( + broadcast_buffers=broadcast_buffers, + bucket_cap_mb=ddp_bucket_cap_mb, + find_unused_parameters=find_unused_parameters, + check_reduction=check_reduction, + gradient_as_bucket_view=gradient_as_bucket_view, + static_graph=static_graph, + ) + + self.zero_config = dict( + reduce_bucket_size=zero_bucket_size_in_m * 1024 * 1024, + communication_dtype=communication_dtype, + overlap_communication=overlap_communication, + cpu_offload=cpu_offload, + partition_grad=(self.zero_stage == 2), + forced_dtype=PRECISION_TORCH_TYPE[precision], + overlap_allgather=overlap_allgather, + fp8_communication=fp8_communication, + ) + + self.max_norm = max_norm + + def __del__(self): + """Destroy the process groups in ProcessGroupMesh""" + self.pg_mesh.destroy_mesh_process_groups() + + @property + def enable_pipeline_parallelism(self) -> bool: + return self.pp_size > 1 + + def supported_devices(self) -> List[str]: + return ["cuda", "npu"] + + def supported_precisions(self) -> List[str]: + return ["fp16", "bf16", "fp32"] + + def control_device(self) -> bool: + return True + + def control_precision(self) -> bool: + return True + + def support_no_sync(self) -> bool: + return True + + def support_lora(self) -> bool: + return True + + def control_checkpoint_io(self) -> bool: + return True + + def configure( + self, + model: Module, + optimizer: Optional[Optimizer] = None, + criterion: Optional[Callable] = None, + dataloader: Optional[DataLoader] = None, + lr_scheduler: Optional[LRScheduler] = None, + ) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]: + param_info = get_param_info(optimizer) + + # TODO: Support Galore + ZeRO + zero_stage = self.zero_stage + zero_config = deepcopy(self.zero_config) + + # Replace with distributed implementation if exists + optimizer = cast_to_distributed(optimizer) + + if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and self.dp_size > 0: + self.logger.warning( + "Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.", + ranks=[0], + ) + zero_config["partition_grad"] = False + zero_stage = 0 + + if not isinstance(model, ModelWrapper): + # Shouldn't use pp (frequent grad accumulation) with torch ddp + use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or ( + self.dp_size == 1 and self.pp_size == 1 + ) + # sync gradients across DP * SP ranks + # Apply Hybrid ZeRO across DP * SP ranks + if self.enable_sequence_parallelism and not is_share_sp_tp(self.sequence_parallelism_mode): + dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis]) + self.dp_size = get_world_size(dp_group) + else: + dp_group = self.dp_group + model = HybridParallelModule( + model, + precision=self.precision, + shard_config=self.shard_config, + dp_group=dp_group, + tp_group=self.tp_group, + sp_group=self.sp_group, + use_ddp=use_ddp, + ddp_config=self.ddp_config, + custom_policy=self.custom_policy, + overlap_allgather=(self.zero_stage > 0 and self.zero_config["overlap_allgather"]), + use_fp8=self.use_fp8, + ) + if optimizer is not None and not isinstance(optimizer, OptimizerWrapper): + if zero_stage == 0: + is_zero = False + if self.precision in ["fp16", "bf16"]: + optimizer = HybridParallelAMPOptimizer( + optimizer, + model, + use_pipeline=self.enable_pipeline_parallelism, + param_info=param_info, + precision=self.precision, + max_norm=self.max_norm, + pp_process_group=self.pp_group, + tp_process_group=self.tp_group, + **self.amp_config, + ) + else: + optimizer = HybridParallelNaiveOptimizer( + optimizer, + model, + use_pipeline=self.enable_pipeline_parallelism, + param_info=param_info, + max_norm=self.max_norm, + pp_process_group=self.pp_group, + tp_process_group=self.tp_group, + ) + else: + is_zero = self.dp_size > 1 + if self.dp_size == 1: + self.logger.warning( + "Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. " + "If you do not intend to use cpu_offload, please consider set zero_stage=0.", + ranks=[0], + ) + + assert self.precision != "fp32", "Please set precision to 'fp16' or 'bf16' when using ZeRO." + optimizer = HybridParallelZeroOptimizer( + optimizer, + model, + use_pipeline=self.enable_pipeline_parallelism, + param_info=param_info, + dp_process_group=dp_group, + tp_process_group=self.tp_group, + pp_process_group=self.pp_group, + verbose=True, + clip_grad_norm=self.max_norm, + **zero_config, + **self.amp_config, + ) + # inject update_master_params + model.update_master_params = MethodType(optimizer.update_master_params, model) + + # Setup optimizers that require global states + optim = optimizer.optim + if isinstance(optim, DistributedOptim): + shard_to_param = optimizer.get_master_to_working_map() if is_zero else {} + padding_map = optimizer.get_param_padding_map() if is_zero else defaultdict(int) + optim.setup_distributed(self.tp_group, self.dp_group, shard_to_param, padding_map, is_zero) + + return model, optimizer, criterion, dataloader, lr_scheduler + + def execute_pipeline( + self, + data_iter: Iterator, + model: HybridParallelModule, + criterion: Callable[[Any, Any], torch.Tensor], + optimizer: Optional[ + Union[HybridParallelNaiveOptimizer, HybridParallelAMPOptimizer, HybridParallelZeroOptimizer] + ] = None, + return_loss: bool = True, + return_outputs: bool = False, + ) -> dict: + assert self.enable_pipeline_parallelism, "pipeline parallelism is not enabled" + + if return_outputs: + self.logger.warning("return_outputs may lead to significant extra memory consumption.", ranks=[0]) + + # Create a context for gradient synchronization based on the optimizer type. + # If it's a HybridParallelZeroOptimizer, use optimizer.no_sync(); otherwise, use model.no_sync(). + # This is to avoid redundant gradient reduction in pipeline parallelism (multiple microbatch values should be reduced once), + # so we disable it, performing manual reduction instead. + ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync() + + with ctx, model._hook_context(): + outputs = self.schedule.forward_backward_step( + model, data_iter, criterion, optimizer, return_loss, return_outputs + ) + + # run with gradients accumulation + if model.require_grad_sync == False or ( + isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False + ): + return outputs + + # Synchronize the grads of shared parameters of the model. + model.sync_shared_params() + # Synchronize sequence parallelism gradients of the model. + model.sync_sp_grads() + + # Check if the optimizer is a HybridParallelZeroOptimizer and synchronize data parallelism gradients if so. + # Otherwise, synchronize data parallelism gradients of the model. + # This is because these are two different forms of data parallelism. + if isinstance(optimizer, HybridParallelZeroOptimizer): + optimizer.sync_dp_grads() + else: + model.sync_dp_grads() + + return outputs + + def prepare_dataloader( + self, + dataset, + batch_size, + shuffle=False, + seed=1024, + drop_last=False, + pin_memory=False, + num_workers=0, + distributed_sampler_cls=None, + **kwargs, + ): + r""" + Prepare a dataloader for distributed training. The dataloader will be wrapped by + `torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`. + + + Args: + dataset (`torch.utils.data.Dataset`): The dataset to be loaded. + shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False. + seed (int, optional): Random worker seed for sampling, defaults to 1024. + add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True. + drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size + is not divisible by the batch size. If False and the size of dataset is not divisible by + the batch size, then the last batch will be smaller, defaults to False. + pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False. + num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0. + kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in + `DataLoader `_. + + Returns:` + :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing. + """ + _kwargs = kwargs.copy() + distributed_sampler_cls = distributed_sampler_cls or DistributedSampler + sampler = distributed_sampler_cls( + dataset, + num_replicas=self.dp_group.size(), + rank=dist.get_group_rank(self.dp_group, global_rank=dist.get_rank()), + shuffle=shuffle, + ) + + # Deterministic dataloader + def seed_worker(worker_id): + worker_seed = seed + np.random.seed(worker_seed) + torch.manual_seed(worker_seed) + random.seed(worker_seed) + + return DataLoader( + dataset, + batch_size=batch_size, + sampler=sampler, + worker_init_fn=seed_worker, + drop_last=drop_last, + pin_memory=pin_memory, + num_workers=num_workers, + **_kwargs, + ) + + def get_checkpoint_io(self) -> CheckpointIO: + return HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage) + + def no_sync(self, model: Module, optimizer: OptimizerWrapper) -> Iterator[None]: + assert ( + self.zero_stage != 2 + ), "ZERO2 is not compatible with no_sync function, please run gradient accumulation with gradient synchronization allowed." + return optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync() + + def enable_lora( + self, + model: Module, + pretrained_dir: Optional[str] = None, + lora_config: Optional[Dict] = None, + bnb_quantization_config: Optional[BnbQuantizationConfig] = None, + ) -> Module: + from peft import PeftModel, get_peft_model + + assert not isinstance(model, HybridParallelModule), "Lora should be enabled before boosting the model." + assert self.pp_size == 1 and self.tp_size == 1 + self.lora_enabled = True + self.logger.warning("You have enabled LoRa training. Please check the hyperparameters such as lr", ranks=[0]) + + if bnb_quantization_config is not None: + model = quantize_model(model, bnb_quantization_config) + + if pretrained_dir is None: + peft_model = get_peft_model(model, lora_config) + else: + peft_model = PeftModel.from_pretrained(model, pretrained_dir, is_trainable=True) + return peft_model diff --git a/toolbox/ColossalAI/v0.4.4/patches/colossalai/cluster/dist_coordinator.py b/toolbox/ColossalAI/v0.4.4/patches/colossalai/cluster/dist_coordinator.py new file mode 100644 index 0000000000000000000000000000000000000000..1fa89c8a1e52a560a4c8568e5e342c05feb9a640 --- /dev/null +++ b/toolbox/ColossalAI/v0.4.4/patches/colossalai/cluster/dist_coordinator.py @@ -0,0 +1,214 @@ +#!/usr/bin/env python3 +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +import functools +import os +from contextlib import contextmanager + +import torch.distributed as dist +from torch.distributed import ProcessGroup + +from colossalai.context.singleton_meta import SingletonMeta + + +class DistCoordinator(metaclass=SingletonMeta): + """ + This class is used to coordinate distributed training. It is a singleton class, which means that there is only one instance of this + class in the whole program. + + There are some terms that are used in this class: + - rank: the rank of the current process + - world size: the total number of processes + - local rank: the rank of the current process on the current node + - master: the process with rank 0 + - node master: the process with local rank 0 on the current node + + + ```python + from colossalai.cluster.dist_coordinator import DistCoordinator + coordinator = DistCoordinator() + + if coordinator.is_master(): + do_something() + + coordinator.print_on_master('hello world') + ``` + + Attributes: + rank (int): the rank of the current process + world_size (int): the total number of processes + local_rank (int): the rank of the current process on the current node + """ + + def __init__(self): + assert ( + dist.is_initialized() + ), "Distributed is not initialized. Please call `torch.distributed.init_process_group` or `colossalai.launch` first." + self._rank = dist.get_rank() + self._world_size = dist.get_world_size() + # this is often passed by launchers such as torchrun + self._local_rank = int(os.environ.get("LOCAL_RANK", -1)) + + @property + def rank(self) -> int: + return self._rank + + @property + def world_size(self) -> int: + return self._world_size + + @property + def local_rank(self) -> int: + return self._local_rank + + def _assert_local_rank_set(self): + """ + Assert that the local rank is set. This is often passed by launchers such as torchrun. + """ + assert ( + self.local_rank >= 0 + ), "The environment variable LOCAL_RANK is not set, thus the coordinator is not aware of the local rank of the current process." + + def is_master(self, process_group: ProcessGroup = None) -> bool: + """ + Check if the current process is the master process (rank is 0). It can accept a sub process group to check the rank 0 with respect to the process. + + Args: + process_group (ProcessGroup, optional): process group to use for the rank 0 check. Defaults to None, which refers to the default process group. + + Returns: + bool: True if the current process is the master process, False otherwise + """ + rank = dist.get_rank(group=process_group) + return rank == 0 + + def is_node_master(self) -> bool: + """ + Check if the current process is the master process on the current node (local rank is 0). + + Returns: + bool: True if the current process is the master process on the current node, False otherwise + """ + self._assert_local_rank_set() + return self.local_rank == 0 + + def is_last_process(self, process_group: ProcessGroup = None) -> bool: + """ + Check if the current process is the last process (rank is world size - 1). It can accept a sub process group to check the last rank with respect to the process. + + Args: + process_group (ProcessGroup, optional): process group to use for the last rank check. Defaults to None, which refers to the default process group. + + Returns: + bool: True if the current process is the last process, False otherwise + """ + rank = dist.get_rank(group=process_group) + world_size = dist.get_world_size(group=process_group) + return rank == world_size - 1 + + def print_on_master(self, msg: str, process_group: ProcessGroup = None): + """ + Print message only from rank 0. + + Args: + msg (str): message to print + process_group (ProcessGroup, optional): process group to use for the rank 0 check. Defaults to None, which refers to the default process group. + """ + rank = dist.get_rank(group=process_group) + if rank == 0: + print(msg) + + def print_on_node_master(self, msg: str): + """ + Print message only from local rank 0. Local rank 0 refers to the 0th process running the current node. + + Args: + msg (str): message to print + """ + self._assert_local_rank_set() + if self.local_rank == 0: + print(msg) + + def print_on_last_process(self, msg: str): + """ + Print message only from local rank 0. Local rank 0 refers to the 0th process running the current node. + + Args: + msg (str): message to print + """ + if self.is_last_process(): + print(msg) + + @contextmanager + def priority_execution(self, executor_rank: int = 0, process_group: ProcessGroup = None): + """ + This context manager is used to allow one process to execute while blocking all + other processes in the same process group. This is often useful when downloading is required + as we only want to download in one process to prevent file corruption. + + + ```python + from colossalai.cluster import DistCoordinator + dist_coordinator = DistCoordinator() + with dist_coordinator.priority_execution(): + dataset = CIFAR10(root='./data', download=True) + ``` + + Args: + executor_rank (int): the process rank to execute without blocking, all other processes will be blocked + process_group (ProcessGroup, optional): process group to use for the executor rank check. Defaults to None, which refers to the default process group. + """ + rank = dist.get_rank(group=process_group) + should_block = rank != executor_rank + + if should_block: + self.block_all(process_group) + + yield + + if not should_block: + self.block_all(process_group) + + def destroy(self, process_group: ProcessGroup = None): + """ + Destroy the distributed process group. + + Args: + process_group (ProcessGroup, optional): process group to destroy. Defaults to None, which refers to the default process group. + """ + dist.destroy_process_group(process_group) + + def block_all(self, process_group: ProcessGroup = None): + """ + Block all processes in the process group. + + Args: + process_group (ProcessGroup, optional): process group to block. Defaults to None, which refers to the default process group. + """ + dist.barrier(group=process_group) + + def on_master_only(self, process_group: ProcessGroup = None): + """ + A function wrapper that only executes the wrapped function on the master process (rank 0). + + ```python + from colossalai.cluster import DistCoordinator + dist_coordinator = DistCoordinator() + + @dist_coordinator.on_master_only() + def print_on_master(msg): + print(msg) + ``` + """ + is_master = self.is_master(process_group) + + # define an inner function + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + if is_master: + return func(*args, **kwargs) + + return wrapper + + return decorator diff --git a/toolbox/ColossalAI/v0.4.4/patches/colossalai/legacy/amp/torch_amp/_grad_scaler.py b/toolbox/ColossalAI/v0.4.4/patches/colossalai/legacy/amp/torch_amp/_grad_scaler.py new file mode 100644 index 0000000000000000000000000000000000000000..4bc590377a4d8e976b8faf2ace04c31cefcd3927 --- /dev/null +++ b/toolbox/ColossalAI/v0.4.4/patches/colossalai/legacy/amp/torch_amp/_grad_scaler.py @@ -0,0 +1,595 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +#!/usr/bin/env python +# -*- encoding: utf-8 -*- +# modified from https://github.com/pytorch/pytorch/blob/master/torch/cuda/amp/grad_scaler.py +# to support tensor parallel + +import warnings +from collections import abc, defaultdict +from enum import Enum +from typing import Any, Dict, List, Optional, Tuple + +import torch +import torch.distributed as dist +from packaging import version +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors + +from colossalai.legacy.context import ParallelMode +from colossalai.legacy.core import global_context as gpc + + +class _MultiDeviceReplicator(object): + """ + Lazily serves copies of a tensor to requested devices. Copies are cached per-device. + """ + + def __init__(self, master_tensor: torch.Tensor) -> None: + assert master_tensor.is_cuda or master_tensor.device.type == "xla" + self.master = master_tensor + self._per_device_tensors: Dict[torch.device, torch.Tensor] = {} + + def get(self, device) -> torch.Tensor: + retval = self._per_device_tensors.get(device, None) + if retval is None: + retval = self.master.to(device=device, non_blocking=True, copy=True) + self._per_device_tensors[device] = retval + return retval + + +# Defines default_factory for GradScaler's _per_optimizer_states defaultdict, +# as well as associated "enum" values. Prefers defining these at top level because +# - Lambdas can't be pickled, so we don't want to supply a lambda as the factory. +# - Defining READY, UNSCALED, STEPPED and _refresh_per_optimizer_state within GradScaler +# causes a circular reference, which we'd rather avoid. +class OptState(Enum): + READY = 0 + UNSCALED = 1 + STEPPED = 2 + + +def _refresh_per_optimizer_state(): + return {"stage": OptState.READY, "found_inf_per_device": {}} + + +class GradScaler(object): + _scale: Optional[torch.Tensor] + _grows_tracker: Optional[torch.Tensor] + _per_optimizer_states: Dict[int, Dict[str, Any]] + """ + An instance ``scaler`` of :class:`GradScaler` helps perform the steps of gradient scaling + conveniently. + + * ``scaler.scale(loss)`` multiplies a given loss by ``scaler``'s current scale factor. + * ``scaler.step(optimizer)`` safely unscales gradients and calls ``optimizer.step()``. + * ``scaler.update()`` updates ``scaler``'s scale factor. + + Example: + + # Creates a GradScaler once at the beginning of training. + scaler = GradScaler() + + for epoch in epochs: + for input, target in data: + optimizer.zero_grad() + output = model(input) + loss = loss_fn(output, target) + + # Scales loss. Calls backward() on scaled loss to create scaled gradients. + scaler.scale(loss).backward() + + # scaler.step() first unscales gradients of the optimizer's params. + # If gradients don't contain infs/NaNs, optimizer.step() is then called, + # otherwise, optimizer.step() is skipped. + scaler.step(optimizer) + + # Updates the scale for next iteration. + scaler.update() + + See the :ref:`Automatic Mixed Precision examples` for usage + (along with autocasting) in more complex cases like gradient clipping, gradient accumulation, gradient penalty, + and multiple losses/optimizers. + + ``scaler`` dynamically estimates the scale factor each iteration. To minimize gradient underflow, + a large scale factor should be used. However, ``float16`` values can "overflow" (become inf or NaN) if + the scale factor is too large. Therefore, the optimal scale factor is the largest factor that can be used + without incurring inf or NaN gradient values. + ``scaler`` approximates the optimal scale factor over time by checking the gradients for infs and NaNs during every + ``scaler.step(optimizer)`` (or optional separate ``scaler.unscale_(optimizer)``, see :meth:`unscale_`). + + * If infs/NaNs are found, ``scaler.step(optimizer)`` skips the underlying ``optimizer.step()`` (so the params + themselves remain uncorrupted) and ``update()`` multiplies the scale by ``backoff_factor``. + + * If no infs/NaNs are found, ``scaler.step(optimizer)`` runs the underlying ``optimizer.step()`` as usual. + If ``growth_interval`` unskipped iterations occur consecutively, ``update()`` multiplies the scale by + ``growth_factor``. + + The scale factor often causes infs/NaNs to appear in gradients for the first few iterations as its + value calibrates. ``scaler.step`` will skip the underlying ``optimizer.step()`` for these + iterations. After that, step skipping should occur rarely (once every few hundred or thousand iterations). + + Args: + init_scale (float, optional, default=2.**16): Initial scale factor. + growth_factor (float, optional, default=2.0): Factor by which the scale is multiplied during + :meth:`update` if no inf/NaN gradients occur for ``growth_interval`` consecutive iterations. + backoff_factor (float, optional, default=0.5): Factor by which the scale is multiplied during + :meth:`update` if inf/NaN gradients occur in an iteration. + growth_interval (int, optional, default=2000): Number of consecutive iterations without inf/NaN gradients + that must occur for the scale to be multiplied by ``growth_factor``. + enabled (bool, optional, default=True): If ``False``, disables gradient scaling. :meth:`step` simply + invokes the underlying ``optimizer.step()``, and other methods become no-ops. + """ + + def __init__(self, init_scale=2.0**16, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000, enabled=True): + if enabled and not torch.cuda.is_available(): + warnings.warn("torch.cuda.amp.GradScaler is enabled, but CUDA is not available. Disabling.") + self._enabled = False + else: + self._enabled = enabled + + # check version + torch_version = version.parse(torch.__version__) + assert torch_version.major == 1 + if torch_version.minor > 8: + self._higher_than_torch18 = True + else: + self._higher_than_torch18 = False + + if self._enabled: + assert growth_factor > 1.0, "The growth factor must be > 1.0." + assert backoff_factor < 1.0, "The backoff factor must be < 1.0." + + self._init_scale = init_scale + # self._scale will be lazily initialized during the first call to scale() + self._scale = None + self._growth_factor = growth_factor + self._backoff_factor = backoff_factor + self._growth_interval = growth_interval + self._init_growth_tracker = 0 + # self._growth_tracker will be lazily initialized during the first call to scale() + self._growth_tracker = None + self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state) + + def _check_scale_growth_tracker(self, funcname) -> Tuple[torch.Tensor, torch.Tensor]: + fix = "This may indicate your script did not use scaler.scale(loss or outputs) earlier in the iteration." + assert self._scale is not None, "Attempted {} but _scale is None. ".format(funcname) + fix + assert self._growth_tracker is not None, "Attempted {} but _growth_tracker is None. ".format(funcname) + fix + return (self._scale, self._growth_tracker) + + def _lazy_init_scale_growth_tracker(self, dev): + assert self._growth_tracker is None, "_growth_tracker initialized before _scale" + self._scale = torch.full((1,), self._init_scale, dtype=torch.float32, device=dev) + self._growth_tracker = torch.full((1,), self._init_growth_tracker, dtype=torch.int32, device=dev) + + def scale(self, outputs): + """ + Multiplies ('scales') a tensor or list of tensors by the scale factor. + + Returns scaled outputs. If this instance of :class:`GradScaler` is not enabled, outputs are returned + unmodified. + + Args: + outputs (Tensor or iterable of Tensors): Outputs to scale. + """ + if not self._enabled: + return outputs + + # Short-circuit for the common case. + if isinstance(outputs, torch.Tensor): + assert outputs.is_cuda or outputs.device.type == "xla" + if self._scale is None: + self._lazy_init_scale_growth_tracker(outputs.device) + assert self._scale is not None + return outputs * self._scale.to(device=outputs.device, non_blocking=True) + + # Invoke the more complex machinery only if we're treating multiple outputs. + # holds a reference that can be overwritten by apply_scale + stash: List[_MultiDeviceReplicator] = [] + + def apply_scale(val): + if isinstance(val, torch.Tensor): + assert val.is_cuda or val.device.type == "xla" + if len(stash) == 0: + if self._scale is None: + self._lazy_init_scale_growth_tracker(val.device) + assert self._scale is not None + stash.append(_MultiDeviceReplicator(self._scale)) + return val * stash[0].get(val.device) + elif isinstance(val, abc.Iterable): + iterable = map(apply_scale, val) + if isinstance(val, list) or isinstance(val, tuple): + return type(val)(iterable) + else: + return iterable + else: + raise ValueError("outputs must be a Tensor or an iterable of Tensors") + + return apply_scale(outputs) + + def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16): + per_device_inv_scale = _MultiDeviceReplicator(inv_scale) + per_device_found_inf = _MultiDeviceReplicator(found_inf) + + # To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype. + # There could be hundreds of grads, so we'd like to iterate through them just once. + # However, we don't know their devices or dtypes in advance. + + # https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict + # Google says mypy struggles with defaultdicts type annotations. + per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list)) # type: ignore[var-annotated] + with torch.no_grad(): + for group in optimizer.param_groups: + for param in group["params"]: + if param.grad is None: + continue + if (not allow_fp16) and param.grad.dtype == torch.float16: + raise ValueError("Attempting to unscale FP16 gradients.") + if param.grad.is_sparse: + # is_coalesced() == False means the sparse grad has values with duplicate indices. + # coalesce() deduplicates indices and adds all values that have the same index. + # For scaled fp16 values, there's a good chance coalescing will cause overflow, + # so we should check the coalesced _values(). + if param.grad.dtype is torch.float16: + param.grad = param.grad.coalesce() + to_unscale = param.grad._values() + else: + to_unscale = param.grad + + # TODO: is there a way to split by device and dtype without appending in the inner loop? + per_device_and_dtype_grads[to_unscale.device][to_unscale.dtype].append(to_unscale) + + for device, per_dtype_grads in per_device_and_dtype_grads.items(): + for grads in per_dtype_grads.values(): + torch._amp_foreach_non_finite_check_and_unscale_( + grads, per_device_found_inf.get(device), per_device_inv_scale.get(device) + ) + # For tensor parallel parameters it should be all-reduced over tensor parallel process group + if gpc.is_initialized(ParallelMode.MODEL) and gpc.get_world_size(ParallelMode.MODEL) > 1: + vals = [val for val in per_device_found_inf._per_device_tensors.values()] + coalesced = _flatten_dense_tensors(vals) + dist.all_reduce(coalesced, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.MODEL)) + for buf, synced in zip(vals, _unflatten_dense_tensors(coalesced, vals)): + buf.copy_(synced) + return per_device_found_inf._per_device_tensors + + def unscale_(self, optimizer): + """ + Divides ("unscales") the optimizer's gradient tensors by the scale factor. + + :meth:`unscale_` is optional, serving cases where you need to + :ref:`modify or inspect gradients` + between the backward pass(es) and :meth:`step`. + If :meth:`unscale_` is not called explicitly, gradients will be unscaled automatically during :meth:`step`. + + Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients:: + + ... + scaler.scale(loss).backward() + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) + scaler.step(optimizer) + scaler.update() + + Args: + optimizer (torch.optim.Optimizer): Optimizer that owns the gradients to be unscaled. + + .. note:: + :meth:`unscale_` does not incur a CPU-GPU sync. + + .. warning:: + :meth:`unscale_` should only be called once per optimizer per :meth:`step` call, + and only after all gradients for that optimizer's assigned parameters have been accumulated. + Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError. + + .. warning:: + :meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute. + """ + if not self._enabled: + return + + self._check_scale_growth_tracker("unscale_") + + optimizer_state = self._per_optimizer_states[id(optimizer)] + + if optimizer_state["stage"] is OptState.UNSCALED: + raise RuntimeError("unscale_() has already been called on this optimizer since the last update().") + elif optimizer_state["stage"] is OptState.STEPPED: + raise RuntimeError("unscale_() is being called after step().") + + # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64. + assert self._scale is not None + inv_scale = self._scale.reciprocal().float() + found_inf = torch.full((1,), 0.0, dtype=torch.float32, device=self._scale.device) + + optimizer_state["found_inf_per_device"] = self._unscale_grads_(optimizer, inv_scale, found_inf, False) + optimizer_state["stage"] = OptState.UNSCALED + + def _maybe_opt_step(self, optimizer, optimizer_state, *args, **kwargs): + retval = None + if not sum(v.item() for v in optimizer_state["found_inf_per_device"].values()): + retval = optimizer.step(*args, **kwargs) + return retval + + def step(self, optimizer, *args, **kwargs): + """ + :meth:`step` carries out the following two operations: + + 1. Internally invokes ``unscale_(optimizer)`` (unless :meth:`unscale_` was explicitly called for ``optimizer`` + earlier in the iteration). As part of the :meth:`unscale_`, gradients are checked for infs/NaNs. + 2. If no inf/NaN gradients are found, invokes ``optimizer.step()`` using the unscaled + gradients. Otherwise, ``optimizer.step()`` is skipped to avoid corrupting the params. + + ``*args`` and ``**kwargs`` are forwarded to ``optimizer.step()``. + + Returns the return value of ``optimizer.step(*args, **kwargs)``. + + Args: + optimizer (torch.optim.Optimizer): Optimizer that applies the gradients. + args: Any arguments. + kwargs: Any keyword arguments. + + .. warning:: + Closure use is not currently supported. + """ + if not self._enabled: + return optimizer.step(*args, **kwargs) + + if "closure" in kwargs: + raise RuntimeError("Closure use is not currently supported if GradScaler is enabled.") + + self._check_scale_growth_tracker("step") + + optimizer_state = self._per_optimizer_states[id(optimizer)] + + if optimizer_state["stage"] is OptState.STEPPED: + raise RuntimeError("step() has already been called since the last update().") + + retval = None + + if hasattr(optimizer, "_step_supports_amp_scaling") and optimizer._step_supports_amp_scaling: + # This optimizer has customized scale-handling logic, so we can call optimizer.step() directly. + # The contract with custom optimizers is that their step() should accept an additional, + # optional grad_scaler kwarg. We append self to the kwargs so the custom optimizer has full information: + # it can query its own state, invoke unscale_ on itself, etc + retval = optimizer.step(*args, **dict(kwargs, grad_scaler=self)) + optimizer_state["stage"] = OptState.STEPPED + return retval + + if optimizer_state["stage"] is OptState.READY: + self.unscale_(optimizer) + + assert len(optimizer_state["found_inf_per_device"]) > 0, "No inf checks were recorded for this optimizer." + + retval = self._maybe_opt_step(optimizer, optimizer_state, *args, **kwargs) + + optimizer_state["stage"] = OptState.STEPPED + + return retval + + def update(self, new_scale=None): + """ + Updates the scale factor. + + If any optimizer steps were skipped the scale is multiplied by ``backoff_factor`` + to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively, + the scale is multiplied by ``growth_factor`` to increase it. + + Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not + used directly, it's used to fill GradScaler's internal scale tensor. So if + ``new_scale`` was a tensor, later in-place changes to that tensor will not further + affect the scale GradScaler uses internally.) + + Args: + new_scale (float or :class:`torch.cuda.FloatTensor`, optional, default=None): New scale factor. + + .. warning:: + :meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has + been invoked for all optimizers used this iteration. + """ + if not self._enabled: + return + + _scale, _growth_tracker = self._check_scale_growth_tracker("update") + + if new_scale is not None: + # Accept a new user-defined scale. + if isinstance(new_scale, float): + self._scale.fill_(new_scale) # type: ignore[union-attr] + else: + reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor with requires_grad=False." + # type: ignore[attr-defined] + assert isinstance(new_scale, torch.cuda.FloatTensor), reason + assert new_scale.numel() == 1, reason + assert new_scale.requires_grad is False, reason + self._scale.copy_(new_scale) # type: ignore[union-attr] + else: + # Consume shared inf/nan data collected from optimizers to update the scale. + # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous. + found_infs = [ + found_inf.to(device=_scale.device, non_blocking=True) + for state in self._per_optimizer_states.values() + for found_inf in state["found_inf_per_device"].values() + ] + + assert len(found_infs) > 0, "No inf checks were recorded prior to update." + + found_inf_combined = found_infs[0] + if len(found_infs) > 1: + for i in range(1, len(found_infs)): + found_inf_combined += found_infs[i] + + if self._higher_than_torch18: + torch._amp_update_scale_( + _scale, + _growth_tracker, + found_inf_combined, + self._growth_factor, + self._backoff_factor, + self._growth_interval, + ) + else: + self._scale = torch._amp_update_scale( + _growth_tracker, + _scale, + found_inf_combined, + self._growth_factor, + self._backoff_factor, + self._growth_interval, + ) + + # To prepare for next iteration, clear the data collected from optimizers this iteration. + self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state) + + def _get_scale_async(self): + return self._scale + + def get_scale(self): + """ + Returns a Python float containing the current scale, or 1.0 if scaling is disabled. + + .. warning:: + :meth:`get_scale` incurs a CPU-GPU sync. + """ + if self._enabled: + return self._init_scale if self._scale is None else self._get_scale_async().item() + else: + return 1.0 + + def get_growth_factor(self): + r""" + Returns a Python float containing the scale growth factor. + """ + return self._growth_factor + + def set_growth_factor(self, new_factor): + r""" + Args: + new_scale (float): Value to use as the new scale growth factor. + """ + self._growth_factor = new_factor + + def get_backoff_factor(self): + r""" + Returns a Python float containing the scale backoff factor. + """ + return self._backoff_factor + + def set_backoff_factor(self, new_factor): + r""" + Args: + new_scale (float): Value to use as the new scale backoff factor. + """ + self._backoff_factor = new_factor + + def get_growth_interval(self): + r""" + Returns a Python int containing the growth interval. + """ + return self._growth_interval + + def set_growth_interval(self, new_interval): + r""" + Args: + new_interval (int): Value to use as the new growth interval. + """ + self._growth_interval = new_interval + + def _get_growth_tracker(self): + if self._enabled: + return self._init_growth_tracker if self._growth_tracker is None else self._growth_tracker.item() + else: + return 0 + + def is_enabled(self): + r""" + Returns a bool indicating whether this instance is enabled. + """ + return self._enabled + + def state_dict(self): + r""" + Returns the state of the scaler as a :class:`dict`. It contains five entries: + + * ``"scale"`` - a Python float containing the current scale + * ``"growth_factor"`` - a Python float containing the current growth factor + * ``"backoff_factor"`` - a Python float containing the current backoff factor + * ``"growth_interval"`` - a Python int containing the current growth interval + * ``"_growth_tracker"`` - a Python int containing the number of recent consecutive unskipped steps. + + If this instance is not enabled, returns an empty dict. + + .. note:: + If you wish to checkpoint the scaler's state after a particular iteration, :meth:`state_dict` + should be called after :meth:`update`. + """ + return ( + { + "scale": self.get_scale(), + "growth_factor": self._growth_factor, + "backoff_factor": self._backoff_factor, + "growth_interval": self._growth_interval, + "_growth_tracker": self._get_growth_tracker(), + } + if self._enabled + else {} + ) + + def load_state_dict(self, state_dict): + r""" + Loads the scaler state. If this instance is disabled, :meth:`load_state_dict` is a no-op. + + Args: + state_dict(dict): scaler state. Should be an object returned from a call to :meth:`state_dict`. + """ + if not self._enabled: + return + + if len(state_dict) == 0: + raise RuntimeError( + "The source state dict is empty, possibly because it was saved " + "from a disabled instance of GradScaler." + ) + + self._init_scale = state_dict["scale"] + if self._scale is not None: + self._scale.fill_(state_dict["scale"]) + self._growth_factor = state_dict["growth_factor"] + self._backoff_factor = state_dict["backoff_factor"] + self._growth_interval = state_dict["growth_interval"] + self._init_growth_tracker = state_dict["_growth_tracker"] + if self._growth_tracker is not None: + self._growth_tracker.fill_(state_dict["_growth_tracker"]) + + def __getstate__(self): + state = self.__dict__.copy() + if self._enabled: + assert len(self._per_optimizer_states) == 0, ( + "A GradScaler instance may only be pickled at the beginning " + "of an iteration, or at the end after scaler.update()." + ) + # Pickling _scale and _growth_tracker Tensors directly triggers + # "warnings.warn("pickle support for Storage will be removed in 1.5..." + # so instead, we set the unpickled instance up to reinitialize them lazily. + state["_init_scale"] = self.get_scale() + state["_init_growth_tracker"] = self._get_growth_tracker() + state["_scale"] = None + state["_growth_tracker"] = None + return state + + def __setstate__(self, state): + self.__dict__.update(state) + + def _check_inf_per_device(self, optimizer): + _scale, _ = self._check_scale_growth_tracker("_check_inf_per_device") + + dummy_inv_scale = torch.full((1,), 1.0, dtype=torch.float32, device=_scale.device) + found_inf = torch.full((1,), 0.0, dtype=torch.float32, device=_scale.device) + + self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] = self._unscale_grads_( + optimizer, dummy_inv_scale, found_inf, True + ) + + return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] + + def _found_inf_per_device(self, optimizer): + return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] diff --git a/toolbox/ColossalAI/v0.4.4/patches/colossalai/shardformer/layer/__init__.py b/toolbox/ColossalAI/v0.4.4/patches/colossalai/shardformer/layer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2bc78865bb9ddf232ff5e0023943ba209465cfd9 --- /dev/null +++ b/toolbox/ColossalAI/v0.4.4/patches/colossalai/shardformer/layer/__init__.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python3 +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +from ._operation import all_to_all_comm +from .attn import AttnMaskType, ColoAttention, RingAttention, get_pad_info +from .dropout import DropoutForParallelInput, DropoutForReplicatedInput +from .embedding import Embedding1D, PaddingEmbedding, VocabParallelEmbedding1D +from .linear import Linear1D_Col, Linear1D_Row, PaddingLMHead, VocabParallelLMHead1D,LinearWithFusedGradientAccu +from .loss import cross_entropy_1d, dist_cross_entropy +from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm +from .parallel_module import ParallelModule +from .qkv_fused_linear import FusedLinear1D_Col, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row +from .mlp import IXFLlamaMLP +from .flash_attention import Colo_LlamaFlashAtten +from .normalization import Colo_FusedRMSNorm + +__all__ = [ + "Embedding1D", + "VocabParallelEmbedding1D", + "Linear1D_Col", + "Linear1D_Row", + "GPT2FusedLinearConv1D_Col", + "GPT2FusedLinearConv1D_Row", + "DropoutForParallelInput", + "DropoutForReplicatedInput", + "cross_entropy_1d", + "dist_cross_entropy", + "BaseLayerNorm", + "LayerNorm", + "RMSNorm", + "FusedLayerNorm", + "FusedRMSNorm", + "FusedLinear1D_Col", + "ParallelModule", + "PaddingEmbedding", + "PaddingLMHead", + "VocabParallelLMHead1D", + "AttnMaskType", + "ColoAttention", + "RingAttention", + "get_pad_info", + "all_to_all_comm", + "LinearWithFusedGradientAccu", + "IXFLlamaMLP", + "Colo_LlamaFlashAtten", + "Colo_FusedRMSNorm", +] diff --git a/toolbox/ColossalAI/v0.4.4/patches/colossalai/shardformer/layer/_operation.py b/toolbox/ColossalAI/v0.4.4/patches/colossalai/shardformer/layer/_operation.py new file mode 100644 index 0000000000000000000000000000000000000000..0d835caa1258e7cb8124b15155764f54a654be56 --- /dev/null +++ b/toolbox/ColossalAI/v0.4.4/patches/colossalai/shardformer/layer/_operation.py @@ -0,0 +1,1171 @@ +#!/usr/bin/env python3 +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +import torch +import torch.distributed as dist +import torch.nn.functional as F + +from .utils import is_share_sp_tp + +try: + import fused_mix_prec_layer_norm_cuda +except: + fused_mix_prec_layer_norm_cuda = None + +try: + import fused_weight_gradient_mlp_cuda + + _grad_accum_fusion_available = True +except ImportError: + _grad_accum_fusion_available = False + +from colossalai.quantization.fp8 import ( + all_gather_fp8, + all_reduce_fp8, + all_to_all_fp8, + all_to_all_single_fp8, + reduce_scatter_fp8, +) + + +class FusedLayerNormAffineFunction1D(torch.autograd.Function): + r"""Layernorm + + Args: + input: input matrix. + weight: weight matrix. + bias: bias matrix. + normalized_shape: input shape from an expected input of size. + :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]` + If a single integer is used, it is treated as a singleton list, and this module will + normalize over the last dimension which is expected to be of that specific size. + eps: a value added to the denominator for numerical stability + """ + + @staticmethod + def forward(ctx, input, weight, bias, normalized_shape, eps): + ctx.normalized_shape = normalized_shape + ctx.eps = eps + input_ = input.contiguous() + weight_ = weight.contiguous() + bias_ = bias.contiguous() + output, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine( + input_, ctx.normalized_shape, weight_, bias_, ctx.eps + ) + ctx.save_for_backward(input_, weight_, bias_, mean, invvar) + return output + + @staticmethod + def backward(ctx, grad_output): + input_, weight_, bias_, mean, invvar = ctx.saved_tensors + grad_input = grad_weight = grad_bias = None + grad_input, grad_weight, grad_bias = fused_mix_prec_layer_norm_cuda.backward_affine( + grad_output.contiguous(), mean, invvar, input_, ctx.normalized_shape, weight_, bias_, ctx.eps + ) + + return grad_input, grad_weight, grad_bias, None, None + + +class MatmulWithAsyncCommunication(torch.autograd.Function): + """ + Linear layer execution with asynchronous communication in backprop. + """ + + @staticmethod + def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False): + ctx.save_for_backward(input_, weight, bias) + ctx.use_bias = bias is not None + ctx.process_group = process_group + ctx.async_grad_allreduce = async_grad_allreduce + ctx.fp8_communication = fp8_communication + + output = torch.matmul(input_, weight) + + if bias is not None: + output = output + bias + + return output + + @staticmethod + def backward(ctx, grad_output): + input, weight, bias = ctx.saved_tensors + use_bias = ctx.use_bias + fp8_communication = ctx.fp8_communication + + # In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. + weight = weight.view(weight.shape) + if bias is not None: + bias = bias.view(bias.shape) + + total_input = input + grad_input = grad_output.matmul(weight.T) + grad_output = grad_output.contiguous() + # Convert the tensor shapes to 2D for execution compatibility + if len(grad_output.shape) > 2: + grad_output = grad_output.view(-1, grad_output.shape[-1]) + total_input = total_input.view(-1, total_input.shape[-1]) + + if ctx.async_grad_allreduce and fp8_communication: + _reduce(grad_input, group=ctx.process_group, fp8_communication=fp8_communication, fp8_format="e5m2") + elif ctx.async_grad_allreduce: + # Asynchronous all-reduce + handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True) + # Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have + # all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py + + grad_weight = total_input.t().matmul(grad_output) + grad_bias = grad_output.sum(dim=0) if use_bias else None + + if ctx.async_grad_allreduce and not fp8_communication: + handle.wait() + + return grad_input, grad_weight, grad_bias, None, None, None, None + + +class LinearWithAsyncCommunication(torch.autograd.Function): + """ + Linear layer execution with asynchronous communication in backprop. + """ + + @staticmethod + def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False): + ctx.save_for_backward(input_, weight, bias) + ctx.use_bias = bias is not None + ctx.process_group = process_group + ctx.async_grad_allreduce = async_grad_allreduce + ctx.fp8_communication = fp8_communication + if bias is not None: + output = F.linear(input_, weight, bias) + else: + output = F.linear(input_, weight) + + return output + + @staticmethod + def backward(ctx, grad_output): + input, weight, bias = ctx.saved_tensors + use_bias = ctx.use_bias + fp8_communication = ctx.fp8_communication + + # In order to be hooked into Gemini's '__torch_function__', adding a view operation to bias. + if use_bias: + bias.view(bias.shape) + + total_input = input.contiguous() + grad_input = grad_output.matmul(weight) + grad_output = grad_output.contiguous() + # Convert the tensor shapes to 2D for execution compatibility + if len(grad_output.shape) > 2: + grad_output = grad_output.view(-1, grad_output.shape[-1]) + total_input = total_input.view(-1, total_input.shape[-1]) + + if ctx.async_grad_allreduce: + # Asynchronous all-reduce + if fp8_communication: + all_reduce_fp8(grad_input, group=ctx.process_group) + else: + handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True) + # Relay on CUDA_DEVICE_MAX_CONNECTIONS=1 to have + # all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py + + if _grad_accum_fusion_available and weight.grad is not None: + grad = weight.grad + if grad.dtype == torch.float32: + fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad) + grad_weight = None + elif grad.dtype == torch.float16: + fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad) + grad_weight = None + else: + grad_weight = grad_output.t().matmul(total_input) + else: + grad_weight = grad_output.t().matmul(total_input) + + grad_bias = grad_output.sum(dim=0) if use_bias else None + + if ctx.async_grad_allreduce and not fp8_communication: + handle.wait() + + return grad_input, grad_weight, grad_bias, None, None, None, None +class LinearWithFusedGradAccu(torch.autograd.Function): + """ + Linear layer execution with FusedGradient. + """ + + @staticmethod + def forward(ctx, input_, weight, bias = None): + ctx.save_for_backward(input_, weight, bias) + ctx.use_bias = bias is not None + + if bias is not None: + output = F.linear(input_, weight, bias) + else: + output = F.linear(input_, weight) + + return output + + @staticmethod + def backward(ctx, grad_output): + input, weight, bias = ctx.saved_tensors + use_bias = ctx.use_bias + + # In order to be hooked into Gemini's '__torch_function__', adding a view operation to bias. + if use_bias: + bias.view(bias.shape) + + total_input = input + grad_input = grad_output.matmul(weight) + grad_output = grad_output.contiguous() + # Convert the tensor shapes to 2D for execution compatibility + if len(grad_output.shape) > 2: + grad_output = grad_output.view(-1, grad_output.shape[-1]) + total_input = total_input.view(-1, total_input.shape[-1]) + + if weight.grad is None: + weight.grad = torch.zeros_like(weight) + if _grad_accum_fusion_available and weight.grad is not None: + grad = weight.grad + if grad.dtype == torch.float32: + fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad) + grad_weight = None + # wgrad_gemm_accum_fp16 接口同时支持fp16和bf16 + elif grad.dtype == torch.float16 or grad.dtype == torch.bfloat16: + fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad) + grad_weight = None + else: + grad_weight = grad_output.t().matmul(total_input) + else: + grad_weight = grad_output.t().matmul(total_input) + + grad_bias = grad_output.sum(dim=0) if use_bias else None + + return grad_input, grad_weight, grad_bias + + +def _ring_as_gather(func, input_to_gather=None, input_local=None, process_group=None, gather_dim=1, keep_item=False): + # currently only support one single tensor as output + group_size = dist.get_world_size(process_group) + cur_rank = dist.get_rank(process_group) + + # output_tensors = [torch.empty((input_shape[0], input_shape[1], weight_shape[0])) for _ in range(group_size)] + + # initialization of ring communication + recv_rank = cur_rank + 1 if cur_rank + 1 < group_size else 0 + send_rank = cur_rank - 1 if cur_rank > 0 else group_size - 1 + rank_map = list(dist.get_process_group_ranks(process_group)) + recv_rank = rank_map[recv_rank] + send_rank = rank_map[send_rank] + recv_tensors = {} + send_tensors = {} + for k, v in input_to_gather.items(): + recv_tensors[k] = torch.empty_like(v) + send_tensors[k] = v.clone() + + def communicate_step(): + comm_ops = [] + for k in recv_tensors: + comm_ops.append(dist.P2POp(dist.irecv, recv_tensors[k], recv_rank, group=process_group)) + comm_ops.append(dist.P2POp(dist.isend, send_tensors[k], send_rank, group=process_group)) + return dist.batch_isend_irecv(comm_ops) + + def switch_step(): + for k in recv_tensors: + send_tensors[k], recv_tensors[k] = recv_tensors[k], send_tensors[k] + + output_tensors = [] + + handles = communicate_step() + # first round: special case, retrive from local tensor + output_tensors.append(func(**input_to_gather, **input_local)) + for i in range(group_size - 2): + for handle in handles: + handle.wait() + + switch_step() + + handles = communicate_step() + + # actual computation + output_tensors.append(func(**send_tensors, **input_local)) + + # final round: special case, no need to send/recv again + for handle in handles: + handle.wait() + output_tensors.append(func(**recv_tensors, **input_local)) + + return torch.cat(output_tensors[group_size - cur_rank :] + output_tensors[: group_size - cur_rank], dim=gather_dim) + + +class _GatherForwardReduceScatterBackward(torch.autograd.Function): + """Gather input from sequence parallel in forward and reduce-scatter gradient in backward + + Args: + input_ (`torch.Tensor`): The input tensor from sequence parallel region. + process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication. + overlap (`bool`): Whther to overlap the all_gather op and gradient calculate in backward. + + """ + + @staticmethod + def forward(ctx, input_, process_group, dim, fp8_communication=False): + ctx.process_group = process_group + ctx.dim = dim + ctx.fp8_communication = fp8_communication + + return _gather(input_, dim, process_group, fp8_communication, fp8_format="e4m3") + + @staticmethod + def backward(ctx, grad_output): + dim = ctx.dim + process_group = ctx.process_group + fp8_communication = ctx.fp8_communication + # do reduce-scatter + new_shape = list(grad_output.shape) + assert ( + new_shape[dim] % dist.get_world_size(process_group) == 0 + ), f"The dimension to split ({new_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). " + new_shape[dim] = new_shape[dim] // dist.get_world_size(process_group) + grad_list = [ + item.contiguous() for item in torch.chunk(grad_output, dist.get_world_size(process_group), dim=dim) + ] + output = torch.empty(new_shape, dtype=grad_output.dtype, device=grad_output.device) + + if fp8_communication: + reduce_scatter_fp8(output, grad_list, group=process_group, fp8_format="e5m2") + else: + dist.reduce_scatter(output, grad_list, group=process_group) + + return output, None, None, None + + +class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function): + """Gather input from sequence parallel in forward and reduce-scatter gradient in backward + + Args: + input_ (`torch.Tensor`): The input tensor from sequence parallel region. + process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication. + overlap (`bool`): Whether to overlap the all_gather op and gradient calculate in backward. + + """ + + @staticmethod + def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap=True, ring=False): + ctx.save_for_backward(input_, weight, bias) + ctx.use_bias = bias is not None + ctx.process_group = process_group + ctx.async_grad_reduce_scatter = async_grad_reduce_scatter + ctx.dim = dim + ctx.overlap = overlap + + if ring is True: + input_to_gather = {"input": input_} + input_local = {"weight": weight} + + output = _ring_as_gather( + F.linear, + input_to_gather=input_to_gather, + input_local=input_local, + process_group=process_group, + ) + + if bias is not None: + output += bias + else: + input_parallel = _gather(input_, dim, process_group) + if bias is not None: + output = F.linear(input_parallel, weight, bias) + else: + output = F.linear(input_parallel, weight) + + return output + + @staticmethod + def backward(ctx, grad_output): + input_, weight, bias = ctx.saved_tensors + use_bias = ctx.use_bias + dim = ctx.dim + process_group = ctx.process_group + overlap = ctx.overlap + + # In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm + if use_bias: + bias = bias.view(bias.shape) + + if not overlap: + input_parallel = _gather(input_, dim, process_group) + + total_input = input_parallel + grad_input = grad_output.matmul(weight) + grad_output = grad_output.contiguous() + # Convert the tensor shapes to 2D for execution compatibility + if len(grad_output.shape) > 2: + grad_output = grad_output.view(-1, grad_output.shape[-1]) + total_input = total_input.view(-1, total_input.shape[-1]) + + if ctx.async_grad_reduce_scatter: + # Asynchronous reduce-scatter + input_list = [ + item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim) + ] + output = torch.empty( + input_.shape, dtype=input_parallel.dtype, device=input_parallel.device + ).contiguous() + handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True) + # Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have + # all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py + + if _grad_accum_fusion_available and weight.grad is not None: + grad = weight.grad + if grad.dtype == torch.float32: + fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad) + grad_weight = None + elif grad.dtype == torch.float16: + fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad) + grad_weight = None + else: + grad_weight = grad_output.t().matmul(total_input) + else: + grad_weight = grad_output.t().matmul(total_input) + + grad_bias = grad_output.sum(dim=0) if use_bias else None + + if ctx.async_grad_reduce_scatter: + handle.wait() + + else: + input_ = input_.contiguous() + world_size = dist.get_world_size(process_group) + tensor_list = [torch.empty_like(input_) for _ in range(world_size)] + + # do all gather in is async way + gather_handle = dist.all_gather(tensor_list, input_, group=process_group, async_op=True) + # calculate gradient and prepare data asynchronously with all-gather + # calculate + grad_input = grad_output.matmul(weight) + grad_output = grad_output.contiguous() + # Convert the tensor shapes to 2D for execution compatibility + if len(grad_output.shape) > 2: + grad_output = grad_output.view(-1, grad_output.shape[-1]) + grad_bias = grad_output.sum(dim=0) if use_bias else None + # prepare data + input_list = [ + item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim) + ] + output = torch.empty(input_.shape, dtype=input_.dtype, device=input_.device).contiguous() + # wait until all-gather finished + gather_handle.wait() + + # do reduce-scatter in async way + reducescatter_handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True) + input_parallel = torch.cat(tensor_list, dim=dim).contiguous() + # calculate gradient + if len(input_parallel.shape) > 2: + input_parallel = input_parallel.view(-1, input_parallel.shape[-1]) + + if _grad_accum_fusion_available and weight.grad is not None: + grad = weight.grad + if grad.dtype == torch.float32: + fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(input_parallel, grad_output, grad) + grad_weight = None + elif grad.dtype == torch.float16: + fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(input_parallel, grad_output, grad) + grad_weight = None + else: + grad_weight = grad_output.t().matmul(input_parallel) + else: + grad_weight = grad_output.t().matmul(input_parallel) + # grad_weight = grad_output.t().matmul(input_parallel) + # wait until reduce-scatter finished + reducescatter_handle.wait() + + return output, grad_weight, grad_bias, None, None, None, None, None + + +def _ring_as_reducescatter( + func, input_to_reducescatter=None, input_local=None, process_group=None, reducescatter_dim=1 +): + # currently only support one single tensor as output + group_size = dist.get_world_size(process_group) + cur_rank = dist.get_rank(process_group) + + # initialization of ring communication + recv_rank = cur_rank - 1 if cur_rank > 0 else group_size - 1 + send_rank = cur_rank + 1 if cur_rank + 1 < group_size else 0 + rank_map = list(dist.get_process_group_ranks(process_group)) + recv_rank = rank_map[recv_rank] + send_rank = rank_map[send_rank] + input_tensors = [] + for _ in range(group_size): + input_tensors.append({}) + for k, v in input_to_reducescatter.items(): + input_shape = v.shape + assert input_shape[reducescatter_dim] % group_size == 0 + _input_tensors = list(torch.split(v, input_shape[reducescatter_dim] // group_size, dim=reducescatter_dim)) + for i in range(group_size): + input_tensors[i][k] = _input_tensors[i] + input_tensors = input_tensors[cur_rank:] + input_tensors[:cur_rank] + input_tensors.reverse() + + output_tensor = func(**input_tensors[0], **input_local) + recv_tensor = torch.empty_like(output_tensor) + send_tensor = output_tensor.clone() + + def communicate_step(): + recv_op = dist.P2POp(dist.irecv, recv_tensor, recv_rank, group=process_group) + send_op = dist.P2POp(dist.isend, send_tensor, send_rank, group=process_group) + return dist.batch_isend_irecv([recv_op, send_op]) + + handles = communicate_step() + # first round: special case, retrive from local tensor + for i in range(group_size - 2): + # actual computation + output_tensor = func(**input_tensors[i + 1], **input_local) + + for handle in handles: + handle.wait() + output_tensor += recv_tensor + + tmp_tensor = send_tensor + send_tensor = output_tensor + output_tensor = tmp_tensor + + handles = communicate_step() + + # final round: special case, no need to send/recv again + output_tensor = func(**input_tensors[-1], **input_local) + for handle in handles: + handle.wait() + output_tensor += recv_tensor + return output_tensor + + +class _LinearWithReduceScatterForwardGatherBackward(torch.autograd.Function): + """Reduce-scatter input from sequence parallel in forward and gather gradient in backward with ring + + Args: + input_ (`torch.Tensor`): The input tensor from sequence parallel region. + process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication. + overlap (`bool`): Whther to overlap the all_gather op and gradient calculate in backward. + + """ + + @staticmethod + def forward(ctx, input_, weight, bias, process_group, dim, ring): + ctx.save_for_backward(input_, weight, bias) + ctx.use_bias = bias is not None + ctx.process_group = process_group + ctx.dim = dim + + if ring is True: + input_to_reducescatter = {"input": input_} + input_local = {"weight": weight} + + if bias is not None: + input_to_reducescatter["bias"] = bias + + output = _ring_as_reducescatter( + F.linear, + input_to_reducescatter=input_to_reducescatter, + input_local=input_local, + process_group=process_group, + ) + else: + if bias is not None: + partial_output = F.linear(input_, weight, bias) + else: + partial_output = F.linear(input_, weight) + + output_shape = list(partial_output.shape) + assert ( + output_shape[dim] % dist.get_world_size(process_group) == 0 + ), f"The dimension to split ({output_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). " + output_shape[dim] = output_shape[dim] // dist.get_world_size(process_group) + + output_list = [ + item.contiguous() for item in torch.chunk(partial_output, dist.get_world_size(process_group), dim=dim) + ] + output = torch.empty(output_shape, dtype=partial_output.dtype, device=partial_output.device).contiguous() + dist.reduce_scatter(output, output_list, group=process_group) + + return output + + @staticmethod + def backward(ctx, grad_output): + input_, weight, bias = ctx.saved_tensors + use_bias = ctx.use_bias + dim = ctx.dim + process_group = ctx.process_group + + # In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm + if use_bias: + bias = bias.view(bias.shape) + + grad_output = _gather(grad_output, dim, process_group) + + # TODO Need to fully optimize + total_input = input_ + grad_input = grad_output.matmul(weight) + grad_output = grad_output.contiguous() + # Convert the tensor shapes to 2D for execution compatibility + if len(grad_output.shape) > 2: + grad_output = grad_output.view(-1, grad_output.shape[-1]) + total_input = total_input.view(-1, total_input.shape[-1]) + grad_weight = grad_output.t().matmul(total_input) + grad_bias = grad_output.sum(dim=0) if use_bias else None + + return grad_input, grad_weight, grad_bias, None, None, None + + +class _ReduceScatterForwardGatherBackward(torch.autograd.Function): + """Reduce-scatter input from sequence parallel in forward and gather gradient in backward + + Args: + input_ (`torch.Tensor`): The input tensor from sequence parallel region. + process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication. + + """ + + @staticmethod + def forward(ctx, input_, process_group, dim, fp8_communication=False): + ctx.dim = dim + ctx.process_group = process_group + ctx.fp8_communication = fp8_communication + + # do reduce-scatter + new_shape = list(input_.shape) + assert ( + new_shape[dim] % dist.get_world_size(process_group) == 0 + ), f"The dimension to split ({new_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). " + new_shape[dim] = new_shape[dim] // dist.get_world_size(process_group) + input_list = [item.contiguous() for item in torch.chunk(input_, dist.get_world_size(process_group), dim=dim)] + output = torch.empty(new_shape, dtype=input_.dtype, device=input_.device) + if fp8_communication: + reduce_scatter_fp8(output, input_list, group=process_group, fp8_format="e4m3") + else: + dist.reduce_scatter(output, input_list, group=process_group) + + return output + + @staticmethod + def backward(ctx, grad_output): + dim = ctx.dim + process_group = ctx.process_group + fp8_communication = ctx.fp8_communication + + return _gather(grad_output, dim, process_group, fp8_communication, fp8_format="e5m2"), None, None, None + + +class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function): + """ + This class is designed for matmul operation with gather forward and reduce-scatter backward. + + Args: + input_ (`torch.Tensor`): input matrix. + dim (int): the dimension to perform split and gather + process_group (`torch.distributed.ProcessGroup`): the process group used for collective communication + + """ + + @staticmethod + def forward( + ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring, fp8_communication + ): + ctx.save_for_backward(input_, weight, bias) + ctx.use_bias = bias is not None + ctx.process_group = process_group + ctx.async_grad_reduce_scatter = async_grad_reduce_scatter + ctx.dim = dim + ctx.overlap = overlap + ctx.fp8_communication = fp8_communication + + if ring is True: + input_to_gather = {} + input_local = {} + input_to_gather["input"] = input_ + input_local["other"] = weight + + output = _ring_as_gather( + torch.matmul, + input_to_gather=input_to_gather, + input_local=input_local, + process_group=process_group, + gather_dim=dim, + ) + + else: + input_parallel = _gather(input_, dim, process_group, fp8_communication, fp8_format="e4m3") + + output = torch.matmul(input_parallel, weight) + + if bias is not None: + output = output + bias + return output + + @staticmethod + def backward(ctx, grad_output): + input_, weight, bias = ctx.saved_tensors + use_bias = ctx.use_bias + dim = ctx.dim + process_group = ctx.process_group + overlap = ctx.overlap + fp8_communication = ctx.fp8_communication + + # In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm + weight = weight.view(weight.shape) + if use_bias: + bias = bias.view(bias.shape) + + if not overlap: + input_parallel = _gather(input_, dim, process_group, fp8_communication, fp8_format="e5m2") + + total_input = input_parallel + grad_input = grad_output.matmul(weight.T) + grad_output = grad_output.contiguous() + # Convert the tensor shapes to 2D for execution compatibility + if len(grad_output.shape) > 2: + grad_output = grad_output.view(-1, grad_output.shape[-1]) + total_input = total_input.view(-1, total_input.shape[-1]) + + if ctx.async_grad_reduce_scatter: + # Asynchronous reduce-scatter + input_list = [ + item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim) + ] + output = torch.empty( + input_.shape, dtype=input_parallel.dtype, device=input_parallel.device + ).contiguous() + handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True) + # Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have + # all-reduce scheduled first and have GPU resources allocated + + grad_weight = total_input.t().matmul(grad_output) + grad_bias = grad_output.sum(dim=0) if use_bias else None + + if ctx.async_grad_reduce_scatter: + handle.wait() + + else: + world_size = dist.get_world_size(process_group) + tensor_list = [torch.empty_like(input_) for _ in range(world_size)] + + # do all gather in is async way + gather_handle = dist.all_gather(tensor_list, input_, group=process_group, async_op=True) + # calculate gradient and prepare data asynchronously with all-gather + # calculate + grad_input = grad_output.matmul(weight.T) + grad_output = grad_output.contiguous() + # Convert the tensor shapes to 2D for execution compatibility + if len(grad_output.shape) > 2: + grad_output = grad_output.view(-1, grad_output.shape[-1]) + grad_bias = grad_output.sum(dim=0) if use_bias else None + # prepare data + input_list = [ + item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim) + ] + output = torch.empty(input_.shape, dtype=input_.dtype, device=input_.device).contiguous() + # wait until all-gather finished + gather_handle.wait() + + # do reduce-scatter in async way + reducescatter_handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True) + input_parallel = torch.cat(tensor_list, dim=dim).contiguous() + # calculate gradient + if len(input_parallel.shape) > 2: + input_parallel = input_parallel.view(-1, input_parallel.shape[-1]) + grad_weight = input_parallel.t().matmul(grad_output) + # wait until reduce-scatter finished + reducescatter_handle.wait() + + return output, grad_weight, grad_bias, None, None, None, None, None, None + + +class _SplitForwardGatherBackward(torch.autograd.Function): + """ + Split the input and keep only the corresponding chuck to the rank. + + Args: + input_ (`torch.Tensor`): input matrix. + dim (int): the dimension to perform split and gather + process_group (`torch.distributed.ProcessGroup`): the process group used for collective communication + + """ + + @staticmethod + def forward(ctx, input_, dim, process_group, grad_scale=None, fp8_communication=False): + ctx.process_group = process_group + ctx.dim = dim + ctx.grad_scale = grad_scale + ctx.fp8_communication = fp8_communication + return _split(input_, dim, process_group) + + @staticmethod + def backward(ctx, grad_output): + if ctx.grad_scale is not None: + grad_output = grad_output * ctx.grad_scale + + return ( + _gather(grad_output, ctx.dim, ctx.process_group, ctx.fp8_communication, fp8_format="e5m2"), + None, + None, + None, + None, + ) + + +class _ReduceForward(torch.autograd.Function): + """ + All-reduce the input from the model parallel region. + + Args: + input_: input matrix. + process_group: communication group. + + """ + + @staticmethod + def forward(ctx, input_, process_group, grad_scale=None, fp8_communication=False): + ctx.grad_scale = grad_scale + return _reduce(input_, process_group, fp8_communication, fp8_format="e4m3") + + @staticmethod + def backward(ctx, grad_output): + if ctx.grad_scale is not None: + grad_output = grad_output * ctx.grad_scale + return grad_output, None, None, None + + +class _ReduceBackward(torch.autograd.Function): + """ + All-reduce the input from the model parallel region. + + Args: + input_: input matrix. + parallel_mode: parallel mode. + """ + + @staticmethod + def forward(ctx, input_, process_group, fp8_communication=False): + ctx.process_group = process_group + ctx.fp8_communication = fp8_communication + return input_ + + @staticmethod + def backward(ctx, grad_output): + fp8_communication = ctx.fp8_communication + return _reduce(grad_output, ctx.process_group, fp8_communication, fp8_format="e5m2"), None, None + + +class _GatherForwardSplitBackward(torch.autograd.Function): + """Gather the input from model parallel region and concatenate. + + Args: + input_: input matrix. + parallel_mode: parallel mode. + dim: dimension + """ + + @staticmethod + def forward(ctx, input_, dim, process_group, grad_scale=None, fp8_communication=False): + ctx.process_group = process_group + ctx.dim = dim + ctx.grad_scale = grad_scale + + return _gather(input_, dim, process_group, fp8_communication=fp8_communication, fp8_format="e4m3") + + @staticmethod + def backward(ctx, grad_output): + if ctx.grad_scale is not None: + grad_output = grad_output * ctx.grad_scale + return _split(grad_output, ctx.dim, ctx.process_group), None, None, None, None + + +class _AllToAll(torch.autograd.Function): + """All-to-all communication. + + Args: + input_: input matrix + process_group: communication group + scatter_dim: scatter dimension + gather_dim: gather dimension + """ + + @staticmethod + def forward(ctx, input_, process_group, scatter_dim, gather_dim, fp8_communication=False): + ctx.process_group = process_group + ctx.scatter_dim = scatter_dim + ctx.gather_dim = gather_dim + ctx.fp8_communication = fp8_communication + world_size = dist.get_world_size(process_group) + bsz, _, _ = input_.shape + + # using all_to_all_single when batch size is 1 + if bsz == 1: + return _all_to_all_single( + input_, + world_size, + process_group, + scatter_dim, + gather_dim, + fp8_communication=fp8_communication, + fp8_format="e4m3", + ) + else: + return _all_to_all( + input_, + world_size, + process_group, + scatter_dim, + gather_dim, + fp8_communication=fp8_communication, + fp8_format="e4m3", + ) + + @staticmethod + def backward(ctx, grad_output): + process_group = ctx.process_group + scatter_dim = ctx.gather_dim + gather_dim = ctx.scatter_dim + fp8_communication = ctx.fp8_communication + world_size = dist.get_world_size(process_group) + bsz, _, _ = grad_output.shape + + if bsz == 1: + return_grad = _all_to_all_single( + grad_output, + world_size, + process_group, + scatter_dim, + gather_dim, + fp8_communication=fp8_communication, + fp8_format="e5m2", + ) + else: + return_grad = _all_to_all( + grad_output, + world_size, + process_group, + scatter_dim, + gather_dim, + fp8_communication=fp8_communication, + fp8_format="e5m2", + ) + + return (return_grad, None, None, None, None) + + +class HookParameter(torch.autograd.Function): + """In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm""" + + @staticmethod + def forward(ctx, input, weight, bias): + ctx.save_for_backward(weight, bias) + output = input + return output + + @staticmethod + def backward(ctx, grad_output): + weight, bias = ctx.saved_tensors + if weight is not None: + weight = weight.view(weight.shape) + if bias is not None: + bias = bias.view(bias.shape) + return grad_output, None, None + + +def hook_parameter_in_backward(input, weight=None, bias=None): + return HookParameter.apply(input, weight, bias) + + +def _reduce(input_, process_group, fp8_communication=False, fp8_format="e5m2"): + # skip if only one rank involved + if dist.get_world_size(process_group) == 1: + return input_ + else: + if fp8_communication: + all_reduce_fp8(input_, group=process_group, fp8_format=fp8_format) + else: + dist.all_reduce(input_, group=process_group) + return input_ + + +def _split(input_, dim=-1, process_group=None): + # skip if only one rank involved + world_size = dist.get_world_size(process_group) + if world_size == 1: + return input_ + + # Split along last dimension. + dim_size = input_.size(dim) + assert dim_size % world_size == 0, ( + f"The dimension to split ({dim_size}) is not a multiple of world size ({world_size}), " + f"cannot split tensor evenly" + ) + + tensor_list = torch.split(input_, dim_size // world_size, dim=dim) + rank = dist.get_rank(process_group) + output = tensor_list[rank].clone().contiguous() + + return output + + +def _gather(input_, dim=-1, process_group=None, fp8_communication=False, fp8_format="e5m2"): + # skip if only one rank involved + world_size = dist.get_world_size(process_group) + if world_size == 1: + return input_ + + input_ = input_.contiguous() + tensor_list = [torch.empty_like(input_) for _ in range(world_size)] + if fp8_communication: + all_gather_fp8(tensor_list, input_, fp8_format=fp8_format, group=process_group) + else: + dist.all_gather(tensor_list, input_, group=process_group) + + output = torch.cat(tensor_list, dim=dim).contiguous() + + return output + + +def _reduce_scatter(input_, dim=1, process_group=None): + """Do reduce-scatter operation. + + Args: + input_ (`torch.Tensor`): The input tensor from sequence parallel region. + dim (int): The dimension to perform reduce-scatter. + process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication. + """ + world_size = dist.get_world_size(process_group) + if world_size == 1: + return input_ + + # reduce-scatter + new_shape = list(input_.shape) + assert ( + new_shape[dim] % dist.get_world_size(process_group) == 0 + ), f"The dimension to split ({new_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). " + new_shape[dim] = new_shape[dim] // world_size + output = torch.empty(new_shape, dtype=input_.dtype, device=input_.device) + dist.reduce_scatter(output, input_, group=process_group) + + return output + + +def _all_to_all(input_, world_size, group, scatter_dim, gather_dim, fp8_communication=False, fp8_format="e5m2"): + input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)] + output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)] + if fp8_communication: + all_to_all_fp8(output_list, input_list, group=group, fp8_format=fp8_format) + else: + dist.all_to_all(output_list, input_list, group=group) + return torch.cat(output_list, dim=gather_dim).contiguous() + + +def _all_to_all_single( + input_, seq_world_size, group, scatter_dim, gather_dim, fp8_communication=False, fp8_format="e5m2" +): + inp_shape = list(input_.shape) + inp_shape[scatter_dim] = inp_shape[scatter_dim] // seq_world_size + if scatter_dim < 2: + input_t = input_.reshape([seq_world_size, inp_shape[scatter_dim]] + inp_shape[scatter_dim + 1 :]).contiguous() + else: + input_t = ( + input_.reshape([-1, seq_world_size, inp_shape[scatter_dim]] + inp_shape[scatter_dim + 1 :]) + .transpose(0, 1) + .contiguous() + ) + + output = torch.empty_like(input_t) + if fp8_communication: + all_to_all_single_fp8(output, input_t, group=group, fp8_format=fp8_format) + else: + + dist.all_to_all_single(output, input_t, group=group) + + if scatter_dim < 2: + output = output.transpose(0, 1).contiguous() + + return output.reshape( + inp_shape[:gather_dim] + + [ + inp_shape[gather_dim] * seq_world_size, + ] + + inp_shape[gather_dim + 1 :] + ).contiguous() + + +def matmul_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False): + return MatmulWithAsyncCommunication.apply( + input_, weight, bias, process_group, async_grad_allreduce, fp8_communication + ) + + +def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False): + return LinearWithAsyncCommunication.apply( + input_, weight, bias, process_group, async_grad_allreduce, fp8_communication + ) + + +def linear_gather_forward_reducescatter_backward( + input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring=False +): + return _LinearWithGatherForwardReduceScatterBackward.apply( + input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring + ) + + +def gather_forward_reducescatter_backward(input_, process_group, dim, fp8_communication=False): + return _GatherForwardReduceScatterBackward.apply(input_, process_group, dim, fp8_communication) + + +def reducescatter_forward_gather_backward(input_, process_group, dim, fp8_communication=False): + return _ReduceScatterForwardGatherBackward.apply(input_, process_group, dim, fp8_communication) + + +def linear_reducescatter_forward_gather_backward(input_, weight, bias=None, process_group=None, dim=1, ring=False): + return _LinearWithReduceScatterForwardGatherBackward.apply(input_, weight, bias, process_group, dim, ring) + + +def matmul_gather_forward_reducescatter_backward( + input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring=False, fp8_communication=False +): + return _MatmulWithGatherForwardReduceScatterBackward.apply( + input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring, fp8_communication + ) + + +def gather_forward_split_backward(input_, dim, process_group, grad_scale=None, fp8_communication=False): + return _GatherForwardSplitBackward.apply(input_, dim, process_group, grad_scale, fp8_communication) + + +def split_forward_gather_backward(input_, dim, process_group, grad_scale=None, fp8_communication=False): + return _SplitForwardGatherBackward.apply(input_, dim, process_group, grad_scale, fp8_communication) + + +def reduce_forward(input_, process_group, grad_scale=None, fp8_communication=False): + return _ReduceForward.apply(input_, process_group, grad_scale, fp8_communication) + + +def reduce_backward(input_, process_group, fp8_communication=False): + return _ReduceBackward.apply(input_, process_group, fp8_communication) + + +def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1, fp8_communication=False): + return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim, fp8_communication) + + +def gather_sp_output(hidden_states, shard_config, sp_dim=1): + """ + Gather the output of the last layer for cross entropy computation + """ + sp_group = shard_config.sequence_parallel_process_group + sp_mode = shard_config.sequence_parallelism_mode + fp8_comm = shard_config.fp8_communication + if dist.get_world_size(sp_group) == 1: + return hidden_states + + # Rescale grad (HybridParallelPlugin applies ZeRO grad averaging on the DP * SP group) + scale = None if is_share_sp_tp(sp_mode) else dist.get_world_size(sp_group) + hidden_states = gather_forward_split_backward( + hidden_states, sp_dim, sp_group, grad_scale=scale, fp8_communication=fp8_comm + ) + return hidden_states diff --git a/toolbox/ColossalAI/v0.4.4/patches/colossalai/shardformer/layer/flash_attention.py b/toolbox/ColossalAI/v0.4.4/patches/colossalai/shardformer/layer/flash_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..c62e94208e875aa1ecfb955bf128c89458b9d5d7 --- /dev/null +++ b/toolbox/ColossalAI/v0.4.4/patches/colossalai/shardformer/layer/flash_attention.py @@ -0,0 +1,482 @@ +#!/usr/bin/env python3 +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +import warnings +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers.cache_utils import Cache +from transformers.utils import logging + +from colossalai.lazy import LazyInitContext + +logger = logging.get_logger(__name__) + +from transformers.utils import ( + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, +) +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_unpadded_func + from flash_attn.flash_attn_interface import flash_attn_func +from transformers.models.llama.modeling_llama import LlamaAttention +from .rotary_pos_embedding import RotaryEmbedding + +# from apex.transformer.functional import fused_apply_rotary_pos_emb +from ixformer.train import fused_apply_split_rotary_pos_emb + +try: + from einops import rearrange +except ImportError: + rearrange = None + +from transformers.models.llama.configuration_llama import LlamaConfig +from colossalai.shardformer.layer import LinearWithFusedGradientAccu + +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class FlashSelfAttentionCore(torch.nn.Module): + """Implement the scaled dot product attention with softmax. + Arguments + --------- + softmax_scale: The temperature to use for the softmax attention. + (default: 1/sqrt(d_keys) where d_keys is computed at + runtime) + attention_dropout: The dropout rate to apply to the attention + (default: 0.0) + """ + def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0, + device=None, dtype=None): + super().__init__() + assert flash_attn_unpadded_func is not None, ('Please install FlashAttention first, ' + 'e.g., with pip install flash-attn') + assert rearrange is not None, 'Please install einops first, e.g., with pip install einops' + self.causal = causal + self.softmax_scale = softmax_scale + self.dropout_p = attention_dropout + + def forward(self, q, k, v): + """Implements the multihead softmax attention. + Arguments + --------- + q, k, v: The tensor containing the query, key, and value. (B, S, H, D) + """ + + assert all((i.dtype in [torch.float16, torch.bfloat16] for i in (q,k,v))) + assert all((i.is_cuda for i in (q,k,v))) + + batch_size, seqlen_q = q.shape[0], q.shape[1] + seqlen_k = k.shape[1] + + # q, k, v = [rearrange(x, 'b s ... -> (b s) ...') for x in [q, k, v]] + # # if os.getenv('ENABLE_FLASH_ATTENTION_WITH_IXDNN', '0') != '0': + # # cu_seqlens_q = torch.empty((batch_size), dtype=torch.int32, device=q.device) + # # else: + # cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, + # device=q.device) + + # if self.training: + # # during training q,k,v always have same seqlen + # assert seqlen_k == seqlen_q + + # is_causal = self.causal + # cu_seqlens_k = cu_seqlens_q + # dropout_p = self.dropout_p + # else: + # # turn off FA causal mask after first inference autoregressive iteration + # # only on first autoregressive step q,k,v have same seqlen + # is_causal = seqlen_q == seqlen_k + # cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, + # device=q.device) + # dropout_p = 0 + + # output = flash_attn_unpadded_func( + # q, k, v, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k, + # dropout_p, + # softmax_scale=self.softmax_scale, causal=is_causal + # ) + # output = rearrange(output, '(b s) ... -> b s ...', b=batch_size) + + + self.attn_impl_mode = 1 + self.use_alibi = False + self.alibi_mode = 1 + output = flash_attn_func( + q, + k, + v, + dropout_p=self.dropout_p, + softmax_scale=self.softmax_scale, + causal= self.causal, + use_alibi=self.use_alibi, + alibi_mode=self.alibi_mode, + imp_mode=self.attn_impl_mode, + ) + #output [b,s,h,d] + + return output + + +class LlamaAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + + # self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + # self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + # self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + # self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias) + self.query_key_value = LinearWithFusedGradientAccu(self.hidden_size, self.head_dim * (self.num_heads + self.num_key_value_heads * 2), bias=config.attention_bias) + self.o_proj = LinearWithFusedGradientAccu(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) + + # # partial rotary embeddings, which is better than full rotary + # # Wang and Komatsuzaki et al + # # https://github.com/kingoflolz/mesh-transformer-jax/ + rotary_dim = config.hidden_size // config.num_attention_heads + rotary_pos_emb = RotaryEmbedding( + rotary_dim, + rotary_percent = 1, + seq_len_interpolation_factor = 1, + rotary_base=config.rope_theta + ) + self.rotary_pos_emb = rotary_pos_emb(config.max_position_embeddings) + self.rotary_pos_emb = ((self.rotary_pos_emb,) * 2) + + # self._init_rope() + + # def _init_rope(self): + # if self.config.rope_scaling is None: + # self.rotary_emb = LlamaRotaryEmbedding( + # self.head_dim, + # max_position_embeddings=self.max_position_embeddings, + # base=self.rope_theta, + # ) + # else: + # scaling_type = self.config.rope_scaling["type"] + # scaling_factor = self.config.rope_scaling["factor"] + # if scaling_type == "linear": + # self.rotary_emb = LlamaLinearScalingRotaryEmbedding( + # self.head_dim, + # max_position_embeddings=self.max_position_embeddings, + # scaling_factor=scaling_factor, + # base=self.rope_theta, + # ) + # elif scaling_type == "dynamic": + # self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding( + # self.head_dim, + # max_position_embeddings=self.max_position_embeddings, + # scaling_factor=scaling_factor, + # base=self.rope_theta, + # ) + # else: + # raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + return None + +class Colo_LlamaFlashAttention2(LlamaAttention): + """ + 基于 transformers (v4.39.3) LlamaFlashAttention2 改进, 优化点: + a. 融合 self.q_proj、self.k_proj、self.v_proj 为 self.query_key_value ; + b. self.query_key_value 和 self.o_proj 改为 LinearWithFusedGradientAccu 类型 ; + c. 融合 split_mixed_q_k_v 和 rope ; + d. self_attention core 的实现改为 core_attention_flash,内部使用定长的 flash_attn_func ; + + + Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + self.core_attention_flash = FlashSelfAttentionCore( + causal=True, attention_dropout=self.attention_dropout + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + hidden_states = hidden_states.transpose(0,1).contiguous() + mixed_x_layer = self.query_key_value(hidden_states) + # [sq, b, ng * (np/ng + 2) * hn] --> [sq, b, ng, (np/ng + 2), hn] + new_tensor_shape = (mixed_x_layer.size()[0], mixed_x_layer.size()[1], + self.num_heads, + (1 + 2), + self.head_dim, + ) + mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) + + rotary_pos_emb = self.rotary_pos_emb[0] + query_states, key_states, value_states = fused_apply_split_rotary_pos_emb( + mixed_x_layer, + rotary_pos_emb, + ) + query_states, key_states, value_states = [rearrange(x, "s b h d -> b s h d").contiguous() for x in (query_states, key_states, value_states)] + + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + # attn_output = self._flash_attention_forward( + # query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate + # ) + attn_output = self.core_attention_flash(query_states, key_states, value_states) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + def _flash_attention_forward( + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`float`): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. + causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal + ) + + return attn_output + + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +class Colo_LlamaFlashAtten(Colo_LlamaFlashAttention2): + def __init__(self) -> None: + raise NotImplementedError( + "Colo_LlamaFlashAtten is not implemented as a physical class. " + "It is meant to be used only with the from_native_module interface to Convert a native LlamaFlashAttention2(from transformers) module to Colo_LlamaFlashAttention2 module provided above." + ) + + @staticmethod + def from_native_module(module: nn.Module, *args, **kwargs) -> nn.Module: + + LazyInitContext.materialize(module) + + config = getattr(module, "config") + layer_idx = getattr(module, "layer_idx") + + flash_atten = Colo_LlamaFlashAttention2(config=config, layer_idx=layer_idx) + + return flash_atten \ No newline at end of file diff --git a/toolbox/ColossalAI/v0.4.4/patches/colossalai/shardformer/layer/linear.py b/toolbox/ColossalAI/v0.4.4/patches/colossalai/shardformer/layer/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..f18d83b1330881f65adb29a82d2b0d40975d9605 --- /dev/null +++ b/toolbox/ColossalAI/v0.4.4/patches/colossalai/shardformer/layer/linear.py @@ -0,0 +1,677 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import math +from typing import Callable, List, Optional, Tuple, Union + +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from torch.distributed import ProcessGroup +from torch.nn.parameter import Parameter + +from colossalai.lazy import LazyInitContext +from colossalai.nn import init as init +from colossalai.nn.layer.utils import divide +from colossalai.tensor.d_tensor.api import ( + is_distributed_tensor, + shard_colwise, + shard_rowwise, + sharded_tensor_to_existing_param, +) + +from ._operation import ( + gather_forward_reducescatter_backward, + gather_forward_split_backward, + linear_gather_forward_reducescatter_backward, + linear_reducescatter_forward_gather_backward, + linear_with_async_comm, + reduce_forward, + reducescatter_forward_gather_backward, + split_forward_gather_backward, +) +from .parallel_module import PaddingParallelModule, ParallelModule +from .utils import create_randomizer_with_offset +from colossalai.shardformer.layer._operation import LinearWithFusedGradAccu + +__all__ = ["Linear1D_Col", "Linear1D_Row"] + + +class Linear1D_Col(ParallelModule): + r"""Linear layer with column parallelism. + + The linear layer is defined as :math:`Y = XA + b`. A is parallelized along + its second dimension as :math:`A = [A_1, ..., A_p]`. + + Args: + in_features (int): size of each input sample. + out_features (int): size of each output sample. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (`torch.dtype`): The dtype of parameters, defaults to None. + device (`torch.device`): The device of parameters, defaults to None. + process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None. + gather_output (bool, optional): If true, call all-gather on output and make Y available + to all GPUs, otherwise, every GPU will have its output + which is :math:`Y_i = XA_i`, defaults to False + seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False. + overlap (`bool`): If set to ``True``, it will overlap input all-gather with gradient computation during backward, defaults to False. + skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer, + which is preserved for kernel fusion, defaults to False + weight_initializer (`typing.Callable`): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (`typing.Callable`): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + device: torch.device = None, + process_group: ProcessGroup = None, + gather_output: bool = False, + seq_parallel_mode: str = None, + seq_parallel_dim: int = 1, + overlap: torch.cuda.Stream = None, + skip_bias_add: bool = False, + weight: Optional[Parameter] = None, + bias_: Optional[Parameter] = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + fp8_communication: bool = False, + **kwargs, + ): + super().__init__(weight=weight, bias_=bias_, **kwargs) + + # Keep input parameters + self.in_features = in_features + self.out_features = out_features + self.gather_output = gather_output + self.seq_parallel_mode = seq_parallel_mode + self.seq_parallel_dim = seq_parallel_dim + self.overlap = overlap + self.skip_bias_add = skip_bias_add + self.device = device + self.process_group = process_group + self.fp8_communication = fp8_communication + + if skip_bias_add and not bias: + raise ValueError("cannot skip bias addition if bias is None") + + # offset the seed with randomizer index and rank + seed = torch.random.initial_seed() + self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) + + # sanity check + if weight is not None: + assert not bias or bias_ is not None, "bias_ must be provided if bias is True when weight is not None" + else: + assert bias_ is None, "bias_ must be None if weight is None" + + # Parameters. + if weight is None: + factory_kwargs = {"device": device, "dtype": dtype} + self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs)) + else: + weight.data = weight.data.to(device=device, dtype=dtype) + self.weight = weight + + if not is_distributed_tensor(self.weight): + sharded_weight = shard_rowwise(self.weight.data, self.process_group) + sharded_tensor_to_existing_param(sharded_weight, self.weight) + + if bias: + if bias_ is None: + self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs)) + else: + bias_.data = bias_.data.to(device=device, dtype=dtype) + self.bias = bias_ + if not is_distributed_tensor(self.bias): + sharded_bias = shard_colwise(self.bias.data, self.process_group) + sharded_tensor_to_existing_param(sharded_bias, self.bias) + else: + self.bias = None + + if weight is None: + # init weights + self.reset_parameters(weight_initializer, bias_initializer) + + @staticmethod + def from_native_module( + module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], **kwargs + ) -> ParallelModule: + r""" + Convert a native PyTorch linear layer to a parallelized linear layer. + """ + LazyInitContext.materialize(module) + # get the attributes + in_features = module.in_features + out_features = module.out_features + bias = module.bias is not None + device = module.weight.device + # ensure only one process group is passed + if isinstance(process_group, (list, tuple)): + assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}." + process_group = process_group[0] + + tp_size = dist.get_world_size(process_group) + if out_features < tp_size: + return module + + if out_features % tp_size != 0: + raise ValueError( + f"The size of out_features:{out_features} is not integer multiples of tensor parallel size: {tp_size}!" + ) + + linear_1d = Linear1D_Col( + in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + process_group=process_group, + weight=module.weight, + bias_=module.bias, + **kwargs, + ) + + return linear_1d + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + with self.randomizer.fork_rng(enable_cpu=True): + fan_in, fan_out = self.in_features, self.out_features + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + + def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: + assert ( + input_.shape[-1] == self.weight.shape[-1] + ), "Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.".format( + input_.shape, self.weight.shape, self.weight.shape[-1] + ) + + # Set up backprop all-reduce. + input_parallel = input_ + + # Matrix multiply. + bias = self.bias if not self.skip_bias_add else None + + if self.seq_parallel_mode == "split_gather": + input_parallel = gather_forward_reducescatter_backward( + input_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication + ) + output_parallel = linear_with_async_comm( + input_parallel, self.weight, bias, self.process_group, False, fp8_communication=self.fp8_communication + ) + elif self.seq_parallel_mode == "ring": + output_parallel = linear_gather_forward_reducescatter_backward( + input_parallel, self.weight, bias, self.process_group, True, self.seq_parallel_dim, self.overlap, True + ) + else: + output_parallel = linear_with_async_comm( + input_parallel, self.weight, bias, self.process_group, True, fp8_communication=self.fp8_communication + ) + + if self.gather_output: + # All-gather across the partitions. + output = gather_forward_split_backward( + output_parallel, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication + ) + else: + output = output_parallel + + if self.skip_bias_add: + return output, self.bias + else: + return output + + +class Linear1D_Row(ParallelModule): + r"""Linear layer with row parallelism + + Args: + in_features (int): size of each input sample. + out_features (int): size of each output sample. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (`torch.dtype`): The dtype of parameters, defaults to None. + parallel_input (bool): If set to ``True``, it's assumed that the input is split, defaults to False. + process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None. + seq_parallel_mode (`str`): The type of sp mode, it will use sequence parallel when `seq_parallel_mode` is not None. Defaults to None. + seq_parallel_dim (`int`): Which dim will sequence parallelism split and gather the sequence. + skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer, + which is preserved for kernel fusion, defaults to False + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + device: torch.device = None, + process_group: ProcessGroup = None, + seq_parallel_mode: str = None, + seq_parallel_dim: int = 1, + parallel_input: bool = True, + skip_bias_add: bool = False, + weight: Optional[Parameter] = None, + bias_: Optional[Parameter] = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + stream_chunk_num: int = 1, + fp8_communication: bool = False, + ): + super().__init__() + + self.stream_chunk_num = stream_chunk_num + + # Keep input parameters + self.in_features = in_features + self.out_features = out_features + self.parallel_input = parallel_input + self.skip_bias_add = skip_bias_add + self.process_group = process_group + self.seq_parallel_mode = seq_parallel_mode + self.seq_parallel_dim = seq_parallel_dim + self.num_partitions = dist.get_world_size(self.process_group) + self.fp8_communication = fp8_communication + + if skip_bias_add and not bias: + raise ValueError("cannot skip bias addition if bias is None") + + # offset the seed with randomizer index and rank + seed = torch.random.initial_seed() + self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) + + # sanity check + if weight is not None: + assert not bias or bias_ is not None, "bias_ must be provided if bias is True when weight is not None" + else: + assert bias_ is None, "bias_ must be None if weight is None" + + # Parameters. + if weight is None: + # Initialize weight. + factory_kwargs = {"device": device, "dtype": dtype} + self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs)) + else: + weight.data = weight.data.to(device=device, dtype=dtype) + self.weight = weight + if not is_distributed_tensor(self.weight): + sharded_weight = shard_colwise(self.weight.data, self.process_group) + sharded_tensor_to_existing_param(sharded_weight, self.weight) + + if self.stream_chunk_num > 1: + # TODO() work for inference only + self.chunk_weight() + + if bias: + if bias_ is None: + self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs)) + else: + bias_.data = bias_.data.to(device=device, dtype=dtype) + self.bias = bias_ + else: + self.bias = None + + if weight is None: + with self.randomizer.fork_rng(enable_cpu=True): + self.reset_parameters(weight_initializer, bias_initializer) + + @staticmethod + def from_native_module( + module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], **kwargs + ) -> ParallelModule: + r""" + Convert a native PyTorch linear layer to a parallelized linear layer. + """ + LazyInitContext.materialize(module) + # get the attributes + in_features = module.in_features + out_features = module.out_features + bias = module.bias is not None + device = module.weight.device + + # ensure only one process group is passed + if isinstance(process_group, (list, tuple)): + assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}." + process_group = process_group[0] + + tp_size = dist.get_world_size(process_group) + if in_features < tp_size: + return module + + if in_features % tp_size != 0: + raise ValueError( + f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!" + ) + + linear_1d = Linear1D_Row( + in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + process_group=process_group, + weight=module.weight, + bias_=module.bias, + **kwargs, + ) + + return linear_1d + + def chunk_weight(self): + self.weight_list = torch.chunk(self.weight, self.stream_chunk_num, dim=0) + + @torch.no_grad() + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + fan_in, fan_out = self.in_features, self.out_features + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + if self.process_group is None: + src_rank = 0 + else: + src_rank = dist.distributed_c10d._get_global_rank(self.process_group, 0) + + origin_device = self.bias.device + bias = self.bias.cuda() + dist.broadcast(bias, src=src_rank, group=self.process_group) + bias = bias.to(origin_device) + self.bias.copy_(bias) + + def forward(self, input_: Tensor) -> Tensor: + # Set up backprop all-reduce. + if self.parallel_input: + assert ( + input_.shape[-1] == self.weight.shape[-1] + ), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.".format( + input_.shape, self.weight.shape, self.weight.shape[-1] + ) + input_ = input_ + else: + assert ( + divide(input_.shape[-1], self.num_partitions) == self.weight.shape[-1] + ), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.".format( + input_.shape, self.weight.shape, self.weight.shape[-1] * self.num_partitions + ) + input_ = split_forward_gather_backward( + input_, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication + ) + + if self.stream_chunk_num > 1: + if self.training: + raise RuntimeError("use stream_chunk_num=1 in Linear1D_Row for training!") + with torch.no_grad(): + output_parallel_list = [None for i in range(self.stream_chunk_num)] + handle_list = [] + for i in range(self.stream_chunk_num): + output_parallel_list[i] = F.linear(input_, self.weight_list[i]) + handle = torch.distributed.all_reduce( + output_parallel_list[i], group=self.process_group, async_op=True + ) + handle_list.append(handle) + for handle in handle_list: + handle.wait() + output = torch.cat(output_parallel_list, dim=-1) + else: + if self.seq_parallel_mode is None: + output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False) + output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication) + elif self.seq_parallel_mode == "split_gather": + output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False) + output = reducescatter_forward_gather_backward( + output_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication + ) + elif self.seq_parallel_mode == "ring": + output = linear_reducescatter_forward_gather_backward( + input_, + self.weight, + process_group=self.process_group, + dim=self.seq_parallel_dim, + ring=True, + ) + else: + output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False) + output = reduce_forward(output_parallel, self.process_group) + + if not self.skip_bias_add: + if self.bias is not None: + output = output + self.bias + return output + else: + return output, self.bias + + +class PaddingLMHead(PaddingParallelModule): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + device: torch.device = None, + weight: Optional[Parameter] = None, + bias_: Optional[Parameter] = None, + make_vocab_size_divisible_by: int = 64, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + ): + # Keep input parameters + self.in_features = in_features + self.out_features = out_features + + if out_features % make_vocab_size_divisible_by != 0: + self.out_features = ( + out_features + make_vocab_size_divisible_by - (out_features % make_vocab_size_divisible_by) + ) + if weight is None: + factory_kwargs = {"device": device, "dtype": dtype} + weight = Parameter(torch.empty(out_features, self.in_features, **factory_kwargs)) + else: + weight.data = weight.data.to(device=device, dtype=dtype) + + if bias: + if bias_ is None: + self.bias = Parameter(torch.empty(out_features, **factory_kwargs)) + else: + bias_.data = bias_.data.to(device=device, dtype=dtype) + else: + bias_ = None + + # resize embeddings + super().__init__(self.out_features, out_features, weight, bias_) + + if weight is None: + self.reset_parameters(weight_initializer, bias_initializer) + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + fan_in, fan_out = self.in_features, self.out_features + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + + @staticmethod + def from_native_module( + module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], **kwargs + ) -> PaddingParallelModule: + r""" + Convert a native PyTorch linear layer to a parallelized linear layer. + """ + LazyInitContext.materialize(module) + # get the attributes + in_features = module.in_features + out_features = module.out_features + bias = module.bias is not None + device = module.weight.device + # ensure only one process group is passed + + lm_head_linear = PaddingLMHead( + in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + weight=module.weight, + bias_=module.bias, + **kwargs, + ) + + return lm_head_linear + + def forward(self, input: Tensor) -> Tensor: + # output = F.linear(input, self.weight, self.bias) + output = LinearWithFusedGradAccu.apply(input, self.weight, self.bias) + output = output[..., : self.old_num_embeddings] + return output + + +class VocabParallelLMHead1D(Linear1D_Col, PaddingParallelModule): + r"""Linear layer with column parallelism. + + The linear layer is defined as :math:`Y = XA + b`. A is parallelized along + its second dimension as :math:`A = [A_1, ..., A_p]`. + + Args: + in_features (int): size of each input sample. + out_features (int): size of each output sample. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (`torch.dtype`): The dtype of parameters, defaults to None. + device (`torch.device`): The device of parameters, defaults to None. + process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None. + gather_output (bool, optional): If true, call all-gather on output and make Y available + to all GPUs, otherwise, every GPU will have its output + which is :math:`Y_i = XA_i`, defaults to False + seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False. + overlap (`bool`): If set to ``True``, it will overlap input all-gather with gradient computation during backward, defaults to False. + skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer, + which is preserved for kernel fusion, defaults to False + weight_initializer (`typing.Callable`): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (`typing.Callable`): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + device: torch.device = None, + process_group: ProcessGroup = None, + weight: Optional[Parameter] = None, + bias_: Optional[Parameter] = None, + make_vocab_size_divisible_by: int = 64, + fp8_communication: bool = False, + **kwargs, + ): + # create weight and bias + if weight is None: + factory_kwargs = {"device": device, "dtype": dtype} + weight = Parameter(torch.empty(out_features, self.in_features, **factory_kwargs)) + if bias: + if bias_ is None: + bias_ = Parameter(torch.empty(out_features, **factory_kwargs)) + else: + bias_ = None + + # calculate new vocab size + self.tensor_parallel_size = dist.get_world_size(group=process_group) + new_out_features = out_features + multiple = make_vocab_size_divisible_by * self.tensor_parallel_size + if out_features % multiple != 0: + new_out_features = out_features + multiple - (out_features % multiple) + + super().__init__( + in_features=in_features, + out_features=new_out_features, + bias=bias, + device=device, + process_group=process_group, + weight=weight, + bias_=bias_, + **kwargs, + new_num_embeddings=new_out_features, + old_num_embeddings=out_features, + fp8_communication=fp8_communication, + ) + # get the length of valid embeddings + tp_rank = dist.get_rank(process_group) + partition_size = self.new_num_embeddings // dist.get_world_size(process_group) + if self.old_num_embeddings >= (tp_rank + 1) * partition_size: + self.num_valid_embeddings_local = partition_size + elif self.old_num_embeddings >= tp_rank * partition_size: + self.num_valid_embeddings_local = self.old_num_embeddings - tp_rank * partition_size + else: + self.num_valid_embeddings_local = 0 + + @staticmethod + def from_native_module( + module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], **kwargs + ) -> PaddingParallelModule: + r""" + Convert a native PyTorch linear layer to a parallelized linear layer. + """ + LazyInitContext.materialize(module) + # get the attributes + in_features = module.in_features + out_features = module.out_features + bias = module.bias is not None + device = module.weight.device + + lm_head_linear = VocabParallelLMHead1D( + in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + process_group=process_group, + weight=module.weight, + bias_=module.bias, + **kwargs, + ) + + return lm_head_linear + + def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: + # get forward output + if self.skip_bias_add: + output, bias = super().forward(input_) + else: + output = super().forward(input_) + + # delete the padding of output + if self.gather_output: + output = output[..., : self.old_num_embeddings] + else: + output = output[..., : self.num_valid_embeddings_local] + + # return + if self.skip_bias_add: + return output, bias + return output + +class LinearWithFusedGradientAccu(torch.nn.Linear): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, input): + return LinearWithFusedGradAccu.apply(input, self.weight, self.bias) \ No newline at end of file diff --git a/toolbox/ColossalAI/v0.4.4/patches/colossalai/shardformer/layer/loss.py b/toolbox/ColossalAI/v0.4.4/patches/colossalai/shardformer/layer/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..5214f1fa60025fe677d47852ad2d9f88057abd8f --- /dev/null +++ b/toolbox/ColossalAI/v0.4.4/patches/colossalai/shardformer/layer/loss.py @@ -0,0 +1,254 @@ +#!/usr/bin/env python3 +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +import torch +import torch.distributed as dist +from torch.autograd import Function +from torch.distributed import ProcessGroup +from torch.nn import CrossEntropyLoss + +from colossalai.shardformer.layer._operation import reduce_forward +from colossalai.shardformer.shard import ShardConfig + +from .utils import is_share_sp_tp + +__all__ = ["DistCrossEntropy", "cross_entropy_1d", "dist_cross_entropy"] + +_IGNORE_IDX = -100 + + +class DistCrossEntropy(Function): + r""" + Overwrite the forward and backward function to calculate the cross entropy loss before gather + + Args: + Function (:class:`torch.autograd.Function`): default + """ + + @staticmethod + def forward( + ctx, + vocab_logits: torch.Tensor, + target: torch.Tensor, + ignore_index: int, + process_group: ProcessGroup, + vocab_size: int, + dtype=torch.float32, + mode="mean", + ): + r""" + Calculate the cross entropy loss before gather, the origin loss function is as follows: + loss = -log(exp(x[class])/sum(exp(x[i])) + and can be rewriten as: + loss = log(sum(exp(x[i])) - x[class] + + To avoid the `nan` of log(sum(exp(x[i]))), we minus the max of x[i] + + Args: + vocab_logits (:class:`torch.Tensor`): The logits of the vocabulary, shape is + [batch_size, seq_len, vocab_size] + target (:class:`torch.Tensor`): The labels of the vocabulary, shape is + [batch_size, seq_len] + + Returns: + :class:`torch.Tensor`: The cross entropy loss + """ + assert mode in ["mean", "sum"] + # get the max + logits_max = torch.max(vocab_logits, dim=-1)[0] + handle = dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=process_group, async_op=True) + + # mask the target in the local device + rank = dist.get_rank(group=process_group) + world_size = dist.get_world_size(group=process_group) + if vocab_size == None: + partition_vocab_size = vocab_logits.size()[-1] + global_vocab_size = partition_vocab_size * world_size + else: + global_vocab_size = vocab_size + partition_vocab_size = global_vocab_size // world_size + + # [down, up) => false, other device and -100 => true + delta = (global_vocab_size + world_size - 1) // world_size + down_threshold = rank * delta + up_threshold = down_threshold + delta + if up_threshold > global_vocab_size: + up_threshold = global_vocab_size + mask = (target < down_threshold) | (target >= up_threshold) + masked_target = target.clone() - down_threshold + masked_target[mask] = 0 + masked_target_1d = masked_target.view(-1).contiguous() + + # minus the max to avoid the result of sum of exp is too large and the log is nan + handle.wait() + vocab_logits = vocab_logits - logits_max.unsqueeze(dim=-1) + # reshape the logits and target + # reshape the vocab_logits to [bath_size * seq_len, vocab_size] + # reshape the labels to [bath_size * seq_len] + self_vocab_size = vocab_logits.size()[-1] + logits_2d = vocab_logits.view(-1, self_vocab_size) + + # extract the x[class] and set the x[other device] to zero + idx = torch.arange(start=0, end=logits_2d.shape[0], device=logits_2d.device) + pred_logits_1d = logits_2d[idx, masked_target_1d].contiguous() + pred_logits = pred_logits_1d.view_as(target) + pred_logits[mask] = 0.0 + + # all-reduce to get full x[i, y] + handle = dist.all_reduce(pred_logits, op=dist.ReduceOp.SUM, group=process_group, async_op=True) + exp_logits = vocab_logits + torch.exp(vocab_logits, out=exp_logits) + sum_exp_logits = torch.sum(exp_logits, dim=-1, dtype=torch.float32) + dist.all_reduce(sum_exp_logits, op=dist.ReduceOp.SUM, group=process_group) + + # calculate the loss + # loss = log(sum(exp(x[i]))) - x[class] + handle.wait() + loss = torch.where(target == ignore_index, 0.0, torch.log(sum_exp_logits) - pred_logits) + if mode == "mean": + num_non_zero = torch.sum(loss != 0.0) + ctx.inv_num_non_zero = 1.0 / num_non_zero + loss = torch.sum(loss).div_(num_non_zero) + else: + loss = torch.sum(loss) + + # calculate the softmax + exp_logits = exp_logits.div(sum_exp_logits.unsqueeze(dim=-1)).to(dtype) + exp_logits[target == ignore_index] = 0.0 + ctx.save_for_backward(exp_logits, mask, masked_target_1d) + ctx.dtype = dtype + ctx.mode = mode + + return loss + + @staticmethod + def backward(ctx, grad_output): + # retrieve the saved tensors + if ctx.mode == "mean": + grad_output = grad_output * ctx.inv_num_non_zero + exp_logits, mask, masked_target_1d = ctx.saved_tensors + + # use exp logits as the input grad + grad_logits = exp_logits + partion_vocab_size = grad_logits.shape[-1] + grad_logits_2d = grad_logits.view(-1, partion_vocab_size) + + update = 1.0 - mask.view(-1).float().to(ctx.dtype) + grad_logits_2d[torch.arange(0, grad_logits_2d.shape[0]), masked_target_1d] -= update + + grad_logits.mul_(grad_output.unsqueeze(dim=-1)) + return grad_logits, None, None, None, None, None, None + + +def cross_entropy_1d( + vocab_logits: torch.Tensor, + labels: torch.Tensor, + ignore_index: int = _IGNORE_IDX, + process_group: ProcessGroup = None, + vocab_size: int = None, + dtype: torch.dtype = None, + mode: str = "mean", +) -> torch.Tensor: + return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group, vocab_size, dtype, mode) + + +def dist_cross_entropy( + labels: torch.Tensor, # [B, S] or [B, S, Vocab_size] + logits: torch.Tensor, # [B, S, Vocab_size] + shard_config: ShardConfig, + vocab_size: int, + dtype: torch.dtype, + seq_dim: int = 1, +) -> torch.Tensor: + """ + Helper to compute cross entropy loss for most shardformer models supporting PP, TP and SP. + """ + # Split labels if not gather output + sp_group = shard_config.sequence_parallel_process_group + sp_rank = dist.get_rank(sp_group) + sp_size = shard_config.sequence_parallel_size + sp_mode = shard_config.sequence_parallelism_mode + parallel_output = shard_config.parallel_output + is_tp = shard_config.enable_tensor_parallelism + is_packed = labels.dim() == 2 + if is_packed: + bs, seq_len = labels.shape + else: + # padded sequence + seq_len = labels.shape[-1] + logits = logits.reshape(-1, *logits.shape[2:]) + seq_dim = 0 + + # Shift labels to predict the next token, and remove the tail logit predicting + is_sp = sp_size > 1 and (not is_share_sp_tp(sp_mode)) + split_labels_here = seq_len // sp_size == logits.size(seq_dim) # ring attn splits labels before forward + + if sp_mode == "ring_attn": + # For Zigzag Ring Attention, labels should've been split and + # shifted by RingAttention.prepare_varlen_batch() + if sp_rank == 0: + # logits = logits[..., :-1, :] + logits = logits + # logits = torch.cat([logits, torch.full_like(logits[:, :1, :], _IGNORE_IDX)], dim=seq_dim) + elif is_sp: + # Shift only once: either before splitting or in the last rank without splitting + if split_labels_here or (sp_rank == sp_size - 1): + # labels = labels[..., 1:] + labels = labels + if split_labels_here: + labels = labels.split(seq_len // sp_size, dim=-1)[sp_rank] + + if sp_rank == sp_size - 1: + # logits = logits[..., :-1, :] + logits = logits + # Pad logits and labels to the same shape across all ranks for TP all_reduce + if is_tp and parallel_output: + # If is packed sequence (label dim is 1), then each seq already has the end label token padded. + # torch.cat is faster than F.pad... + pad_shape = (logits.shape[0], 1, *logits.shape[2:]) if is_packed else (1, *logits.shape[1:]) + padding = torch.full(pad_shape, _IGNORE_IDX, dtype=logits.dtype, device=logits.device) + logits = torch.cat([logits, padding], dim=seq_dim) + pad_shape = (labels.shape[0], 1) if is_packed else (1,) + padding = torch.full(pad_shape, _IGNORE_IDX, dtype=labels.dtype, device=labels.device) + labels = torch.cat([labels, padding], dim=seq_dim) + else: + # labels = labels[..., 1:] + # logits = logits[..., :-1, :] + labels = labels # 在 datacollator 内处理,此处无需切片处理 + logits = logits + + labels = labels.contiguous() + logits = logits.contiguous() + num_nonzero = (labels != _IGNORE_IDX).sum() + assert labels.shape == logits.shape[:-1], f"label shape {labels.shape} does not match logit shape {logits.shape}" + + # Flatten the tokens + loss_fct = CrossEntropyLoss(ignore_index=_IGNORE_IDX, reduction="sum") + labels = labels.view(-1) + + if is_tp and parallel_output: + # Cross entropy with all-reduce for TP + new_vocab_size = logits.shape[-1] + logits = logits.view(-1, new_vocab_size) + loss = cross_entropy_1d( + logits, + labels, + process_group=shard_config.tensor_parallel_process_group, + vocab_size=vocab_size, + dtype=dtype, + mode="sum", + ) + else: + # NOTE if use TP and not parallel_output, the output is gathered in VocabParallelLMHead1D + logits = logits.view(-1, logits.size(-1)) + loss = loss_fct(logits, labels) + + # Reduce loss instead of gathering logits over seq dim for savings + if split_labels_here or sp_mode == "ring_attn": + # Get the global non-zero count + loss = torch.stack((loss, num_nonzero)) + # Rescale to offset the grad / (DP * SP) in HybridParallelPlugin + loss = reduce_forward(loss, sp_group, grad_scale=sp_size) + loss, num_nonzero = loss[0], loss[1].detach() + loss = (loss / num_nonzero).squeeze() + return loss diff --git a/toolbox/ColossalAI/v0.4.4/patches/colossalai/shardformer/layer/mlp.py b/toolbox/ColossalAI/v0.4.4/patches/colossalai/shardformer/layer/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..be69f738c428789f30669aa8c7644cb7a00debf9 --- /dev/null +++ b/toolbox/ColossalAI/v0.4.4/patches/colossalai/shardformer/layer/mlp.py @@ -0,0 +1,90 @@ +#!/usr/bin/env python3 +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +import warnings +from typing import List, Optional, Tuple, Union + +import torch +import torch.distributed +import torch.nn as nn + +from transformers.models.llama.configuration_llama import LlamaConfig +from transformers.models.llama.modeling_llama import LlamaMLP +from transformers import Cache +from transformers.utils import logging + +from colossalai.lazy import LazyInitContext +from colossalai.shardformer.layer.linear import LinearWithFusedGradientAccu + +try: + from apex.corex.activations import SwiGLUFunction + swiglu_available=True +except: + swiglu_available=False + + +logger = logging.get_logger(__name__) + +class SwiGLU(torch.nn.Module): + def forward(self, input): + return SwiGLUFunction.apply(input) + + +class BaseLlamaMLP(LlamaMLP): + """ + 这个层主要的优化点是:将linear1(act(cat(linear2(x), linear3(x))))的结构变成 linear1(act(linear23(x))) + """ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # self.gate_up = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False) + self.gate_up = LinearWithFusedGradientAccu(self.hidden_size, self.intermediate_size * 2, bias=False) + self.down = LinearWithFusedGradientAccu(self.intermediate_size, self.hidden_size, bias=False) + + del self.gate_proj, self.up_proj, self.down_proj + self.swiglu_available = swiglu_available + if swiglu_available: + self.act_fn = SwiGLU() + + def forward(self, x): + if not self.swiglu_available: + gate_proj, up_proj = self.gate_up(x).split((self.intermediate_size, self.intermediate_size), dim=-1) + down_proj = self.down(self.act_fn(gate_proj) * up_proj) + else: + res = self.gate_up(x) + down_proj = self.down(self.act_fn(res)) + return down_proj + + +class IXFLlamaMLP(BaseLlamaMLP): + def __init__(self) -> None: + raise NotImplementedError( + "IXFLlamaMLP is not implemented as a physical class. " + "It is meant to be used only with the from_native_module interface to Convert a native LlamaAttention module to IXFLlamaMLP module provided above." + ) + + @staticmethod + def from_native_module(module: nn.Module, *args, **kwargs) -> nn.Module: + + LazyInitContext.materialize(module) + + config = getattr(module, "config") + + mlp = BaseLlamaMLP(config=config) + + mlp.gate_up.weight.data = torch.concat((module.gate_proj.weight.data, module.up_proj.weight.data), dim=0) + mlp.down.weight.data = module.down_proj.weight.data + + return mlp \ No newline at end of file diff --git a/toolbox/ColossalAI/v0.4.4/patches/colossalai/shardformer/layer/normalization.py b/toolbox/ColossalAI/v0.4.4/patches/colossalai/shardformer/layer/normalization.py new file mode 100644 index 0000000000000000000000000000000000000000..6b13ecc0271ea17e553fd27337be1245d7696329 --- /dev/null +++ b/toolbox/ColossalAI/v0.4.4/patches/colossalai/shardformer/layer/normalization.py @@ -0,0 +1,335 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +#!/usr/bin/env python +# -*- encoding: utf-8 -*- +import warnings +from abc import ABC, abstractmethod + +import torch.nn as nn + +from colossalai.lazy import LazyInitContext + +from ._operation import hook_parameter_in_backward +from .utils import SeqParallelUtils + +__all__ = ["FusedLayerNorm", "FusedRMSNorm", "LayerNorm", "RMSNorm", "BaseLayerNorm"] + +try: + from apex.contrib.layer_norm.layer_norm import FastLayerNorm + + EnableFastLayerNorm = True +except ImportError: + EnableFastLayerNorm = False + +try: + from apex.normalization import FusedLayerNorm as ApexFusedLayerNorm + + class FusedLayerNormWithHook(ApexFusedLayerNorm): + def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True): + super().__init__(normalized_shape, eps, elementwise_affine) + + def forward(self, input): + output = super().forward(input) + output = hook_parameter_in_backward(output, self.weight, self.bias) + return output + +except ImportError: + warnings.warn("Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMSNorm kernel") + +try: + from ixformer.train import FusedRMSNorm + class FusedRMSNormWithHook(FusedRMSNorm): + def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True, gradient_accumulation_fusion = True): + super().__init__(normalized_shape, eps, elementwise_affine, gradient_accumulation_fusion = gradient_accumulation_fusion) + + def forward(self, input): + output = super().forward(input) + output = hook_parameter_in_backward(output, self.weight) + return output + +except ImportError: + warnings.warn("Please install ixformer to use the fused rmsnorm kernel") + +FAST_LAYERNORM_SUPPORTED_SIZE = [ + 1024, + 1536, + 2048, + 2304, + 3072, + 3840, + 4096, + 5120, + 6144, + 8192, + 10240, + 12288, + 12800, + 15360, + 16384, + 18432, + 20480, + 24576, + 25600, + 30720, + 32768, + 40960, + 49152, + 65536, +] + +if EnableFastLayerNorm: + + class FastLayerNormWithHook(FastLayerNorm): + def __init__(self, hidden_size, eps=0.00001): + super().__init__(hidden_size, eps) + + def forward(self, input): + output = super().forward(input) + output = hook_parameter_in_backward(output, self.weight, self.bias) + return output + + +class BaseLayerNorm(ABC): + @abstractmethod + def from_native_module(module: nn.Module, sp_partial_derived: bool = False): + """ + Convert a native PyTorch layer normalization module to a specific layer normalization module, + and optionally mark parameters for gradient aggregation. + + Args: + module (nn.Module): The native PyTorch layer normalization module to be converted. + sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism. + + Returns: + nn.Module: The specific layer normalization module. + + Raises: + AssertionError: If the provided module is not an instance of the supported layer normalization type. + """ + + +class RMSNorm(BaseLayerNorm): + r""" + This is a wrapper around the RMSNorm. It is meant to be used only with the from_native_module interface. + """ + + def __init__(self) -> None: + raise NotImplementedError( + "FusedLayerNorm is not implemented as a physical class. " + "It is meant to be used only with the from_native_module interface to convert a native RMSNorm module to colossalai layer norm module." + ) + + @staticmethod + def from_native_module(module: nn.Module, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module: + """ + Convert a native RMSNorm module to colossalai layer norm module, + and optionally mark parameters for gradient aggregation. + + Args: + module (nn.Module): The native RMSNorm module to be converted. + sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism. + + Returns: + nn.Module: The RMSNorm module. + """ + + LazyInitContext.materialize(module) + + if sp_partial_derived: + # Since gradients are computed using only a subset of the data, + # aggregation of these gradients is necessary during backpropagation. + # Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation. + SeqParallelUtils.marked_as_sp_partial_derived_param(module.weight) + + return module + + +class LayerNorm(BaseLayerNorm): + r""" + This is a wrapper around native LayerNorm. It is meant to be used only with the from_native_module interface. + """ + + def __init__(self) -> None: + raise NotImplementedError( + "LayerNorm is not implemented as a physical class. " + "It is meant to be used only with the from_native_module interface to convert a native LayerNorm module to colossalai layer norm module." + ) + + @staticmethod + def from_native_module(module: nn.Module, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module: + r""" + Convert a native LayerNorm module to colossalai layer norm module, + and optionally marking parameters for gradient aggregation. + + Args: + module (nn.Module): The native LayerNorm module to be converted. + sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism. + + Returns: + nn.Module: The colossalai LayerNorm module. + + """ + + LazyInitContext.materialize(module) + + if sp_partial_derived: + # Since gradients are computed using only a subset of the data, + # aggregation of these gradients is necessary during backpropagation. + # Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation. + SeqParallelUtils.marked_as_sp_partial_derived_param(module.weight) + if module.bias is not None: + SeqParallelUtils.marked_as_sp_partial_derived_param(module.bias) + + return module + + +class FusedLayerNorm(BaseLayerNorm): + r""" + This is a wrapper around the apex fused layernorm implementation. It is meant to be used only with the from_native_module interface. + """ + + def __init__(self) -> None: + raise NotImplementedError( + "FusedLayerNorm is not implemented as a physical class. " + "It is meant to be used only with the from_native_module interface convert a native LayerNorm module to FusedLayerNorm module provided by apex." + ) + + @staticmethod + def from_native_module(module: nn.LayerNorm, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module: + r""" + Convert a native LayerNorm module to FusedLayerNorm module provided by apex, + and optionally marking parameters for gradient aggregation. + + Args: + module (nn.Module): The native LayerNorm module to be converted. + sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism. + + Returns: + nn.Module: Union[FastLayerNorm, FusedLayerNorm]. + + """ + + LazyInitContext.materialize(module) + # get the attributes of the module + normalized_shape = getattr(module, "normalized_shape", module.weight.shape[0]) + eps = module.variance_epsilon if hasattr(module, "variance_epsilon") else module.eps + elementwise_affine = getattr(module, "elementwise_affine", True) + dtype = module.weight.dtype + device = module.weight.device + + # pick the suitable layernorm implementation + use_fast_ln = normalized_shape in FAST_LAYERNORM_SUPPORTED_SIZE + + if use_fast_ln: + if EnableFastLayerNorm: + ApexFusedLayerNorm = FastLayerNormWithHook + else: + # fall back to the normal fused layernorm is not built + ApexFusedLayerNorm = FusedLayerNormWithHook + else: + try: + ApexFusedLayerNorm = FusedLayerNormWithHook + except NameError: + warnings.warn( + "Please install Apex from source to use fused kernels, or set self.enable_fused_normalization = False. Using native layernorm instead." + ) + return module + + layernorm = ( + ApexFusedLayerNorm(normalized_shape, eps=eps, elementwise_affine=elementwise_affine).to(dtype).to(device) + ) + layernorm.weight = module.weight + if module.bias is not None: + layernorm.bias = module.bias + + if sp_partial_derived: + # Since gradients are computed using only a subset of the data, + # aggregation of these gradients is necessary during backpropagation. + # Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation. + SeqParallelUtils.marked_as_sp_partial_derived_param(layernorm.weight) + SeqParallelUtils.marked_as_sp_partial_derived_param(layernorm.bias) + + return layernorm + + +class FusedRMSNorm(BaseLayerNorm): + """ + This is a wrapper around the apex fused rms norm implementation. It is meant to be used only with the from_native_module interface. + """ + + def __init__(self) -> None: + raise NotImplementedError( + "FusedRMSNorm is not implemented as a physical class. " + "It is meant to be used only with the from_native_module interface to Convert a native RMSNorm module to FusedRMSNorm module provided by apex." + ) + + @staticmethod + def from_native_module(module: nn.Module, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module: + r""" + Convert a native RMSNorm module module to FusedRMSNorm module provided by apex, + and optionally marking parameters for gradient aggregation. + + Args: + module (nn.LayerNorm): The native PyTorch LayerNorm module to be converted. + sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism. + + Returns: + nn.Module: FusedRMSNorm module. + """ + + LazyInitContext.materialize(module) + + # try to get normalized_shape, eps, elementwise_affine from the module + normalized_shape = getattr(module, "normalized_shape", module.weight.shape[0]) + eps = module.variance_epsilon if hasattr(module, "variance_epsilon") else module.eps + elementwise_affine = getattr(module, "elementwise_affine", True) + + try: + rmsnorm = FusedRMSNormWithHook( + normalized_shape=normalized_shape, + eps=eps, + elementwise_affine=elementwise_affine, + gradient_accumulation_fusion = True + ) + except ImportError: + warnings.warn( + "Module replacement failed.\ + Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMS normalization kernel" + ) + return module + + rmsnorm.weight = module.weight + + if sp_partial_derived: + # Since gradients are computed using only a subset of the data, + # aggregation of these gradients is necessary during backpropagation. + # Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation. + SeqParallelUtils.marked_as_sp_partial_derived_param(rmsnorm.weight) + + return rmsnorm + + +from ixformer.train import FusedRMSNormRes +class Colo_FusedRMSNorm(nn.Module): + def __init__(self) -> None: + raise NotImplementedError( + "Colo_FusedRMSNorm is not implemented as a physical class. " + "It is meant to be used only with the from_native_module interface to Convert a native LlamaRMSNorm(from transformers) module to RMSNormResidualOrNot module provided above." + ) + + @staticmethod + def from_native_module(module: nn.Module, *args, **kwargs) -> nn.Module: + + LazyInitContext.materialize(module) + + normalized_shape = getattr(module, "normalized_shape", module.weight.size()) + eps = module.variance_epsilon if hasattr(module, "variance_epsilon") else module.eps + + FusedRMSNorm = FusedRMSNormRes(normalized_shape, + eps=eps, + gradient_accumulation_fusion = True) + FusedRMSNorm.weight = module.weight + # FusedRMSNorm.weight.grad = torch.zeros_like(FusedRMSNorm.weight) + + # print(f"rank:{torch.distributed.get_rank()}, weight:{FusedRMSNorm.weight.size()}, grad:{FusedRMSNorm.weight.grad.size()}") + return FusedRMSNorm \ No newline at end of file diff --git a/toolbox/ColossalAI/v0.4.4/patches/colossalai/shardformer/layer/rotary_pos_embedding.py b/toolbox/ColossalAI/v0.4.4/patches/colossalai/shardformer/layer/rotary_pos_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..a1087f930c72d6a8d679cf89b564c3a7f772a494 --- /dev/null +++ b/toolbox/ColossalAI/v0.4.4/patches/colossalai/shardformer/layer/rotary_pos_embedding.py @@ -0,0 +1,182 @@ +#!/usr/bin/env python3 +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + +from __future__ import annotations + +from typing import TYPE_CHECKING + +# if TYPE_CHECKING: +# from megatron_ds.core.transformer.transformer_config import TransformerConfig +# from megatron_ds.core.transformer.transformer_block import TransformerBlock + +import torch +from torch import Tensor, nn + +# from megatron_ds.core import parallel_state + +__all__ = ['RotaryEmbedding', 'apply_rotary_pos_emb'] + + +# def get_pos_emb_on_this_cp_rank(pos_emb, seq_dim): +# cp_size = parallel_state.get_context_parallel_world_size() +# cp_rank = parallel_state.get_context_parallel_rank() +# cp_idx = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], device=pos_emb.device) +# pos_emb = pos_emb.view( +# *pos_emb.shape[:seq_dim], 2 * cp_size, -1, *pos_emb.shape[(seq_dim + 1) :] +# ) +# pos_emb = pos_emb.index_select(seq_dim, cp_idx) +# pos_emb = pos_emb.view(*pos_emb.shape[:seq_dim], -1, *pos_emb.shape[(seq_dim + 2) :]) +# return pos_emb + + +class RotaryEmbedding(nn.Module): + """Rotary Embedding for language model. + + Args: + kv_channels (int): Projection weights dimension in multi-head attention. Obtained from transformer config + rotary_percent (float): Percent of rotary dimension to use for rotary position embeddings. + seq_len_interpolation_factor (float, optional): scale of linearly interpolating RoPE for longer sequences. The value must be a float larger than 1.0. Defaults to None + rotary_base (int, optional): Base period for rotary position embeddings. Defaults to 10000. + """ + + def __init__( + self, + kv_channels: int, + rotary_percent: float, + seq_len_interpolation_factor: float = None, + rotary_base: int = 10000, + ) -> None: + super().__init__() + + dim = kv_channels + if rotary_percent < 1.0: + dim = int(dim * rotary_percent) + + self.seq_len_interpolation_factor = seq_len_interpolation_factor + self.inv_freq = 1.0 / ( + rotary_base + ** ( + torch.arange(0, dim, 2, dtype=torch.float32, device=torch.cuda.current_device()) + / dim + ) + ) + + def forward(self, max_seq_len: int, offset: int = 0) -> Tensor: + """Forward pass of RoPE embedding. + + Args: + max_seq_len (int): Maximum size of sequence + offset (int, optional): _description_. Defaults to 0. + + Returns: + Tensor: Embeddings after applying RoPE. + """ + seq = ( + torch.arange(max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + + offset + ) + + if self.seq_len_interpolation_factor is not None: + seq *= 1 / self.seq_len_interpolation_factor + + freqs = torch.outer(seq, self.inv_freq) + # first part even vector components, second part odd vector components, + # 2 * dim in dimension size + emb = torch.cat((freqs, freqs), dim=-1) + # emb [seq_length, .., dim] + emb = emb[:, None, None, :] + # if parallel_state.get_context_parallel_world_size() > 1: + # # slice rotary_pos_emb along sequence dimension and select the parition of the current CP rank + # emb = get_pos_emb_on_this_cp_rank(emb, 0) + return emb + + # def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): + # state_dict.pop(f'{prefix}inv_freq', None) + # return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) + + # def get_rotary_seq_len( + # self, + # inference_params, + # transformer: TransformerBlock, + # transformer_input: Tensor, + # transformer_config: TransformerConfig, + # ) -> float: + # """Function to get the rotary sequence length. + + # Args: + # inference_params : Used during Inference time + # transformer (TransformerBlock): The transformer block (decoder/encoder) used by the model + # transformer_input (Tensor): _description_ + # transformer_config (TransformerConfig): Transformer config used by the model + + # Returns: + # float: The rotary sequence length + # """ + # if inference_params is not None: + # rotary_seq_len = inference_params.max_sequence_length + # else: + # if transformer.input_tensor is not None: + # rotary_seq_len = transformer.input_tensor.size(0) + # else: + # rotary_seq_len = transformer_input.size(0) + + # if transformer_config.sequence_parallel: + # rotary_seq_len *= transformer_config.tensor_model_parallel_size + + # rotary_seq_len *= transformer_config.context_parallel_size + + # return rotary_seq_len + + +def _rotate_half(x: Tensor) -> Tensor: + """Change sign so the last dimension becomes [-odd, +even] + + Args: + x (Tensor): Input tensor + + Returns: + Tensor: Tensor rotated half + """ + + x1, x2 = torch.chunk(x, 2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(t: Tensor, freqs: Tensor) -> Tensor: + """Apply rotary positional embedding to input tensor T. + + check https://kexue.fm/archives/8265 for detailed formulas + + Args: + t (Tensor): Input tensor T is of shape [seq_length, ... , dim] + freqs (Tensor): Rotary Positional embedding tensor freq is of shape [seq_length, ..., dim] + + Returns: + Tensor: The input tensor after applying RoPE + """ + rot_dim = freqs.shape[-1] + + # ideally t_pass is empty so rotary pos embedding is applied to all tensor t + t, t_pass = t[..., :rot_dim], t[..., rot_dim:] + + # first part is cosine component + # second part is sine component, need to change signs with _rotate_half method + cos_ = torch.cos(freqs).to(t.dtype) + sin_ = torch.sin(freqs).to(t.dtype) + + t = (t * cos_) + (_rotate_half(t) * sin_) + return torch.cat((t, t_pass), dim=-1) diff --git a/toolbox/ColossalAI/v0.4.4/patches/colossalai/shardformer/modeling/llama.py b/toolbox/ColossalAI/v0.4.4/patches/colossalai/shardformer/modeling/llama.py new file mode 100644 index 0000000000000000000000000000000000000000..c5c3380f85f7cf9258d5039c52ed5f2c997ab215 --- /dev/null +++ b/toolbox/ColossalAI/v0.4.4/patches/colossalai/shardformer/modeling/llama.py @@ -0,0 +1,700 @@ +#!/usr/bin/env python3 +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +import math +import warnings +from typing import Dict, List, Optional, Tuple, Union + +import torch +import torch.distributed +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from transformers.cache_utils import Cache, DynamicCache +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, +) +from transformers.models.llama.modeling_llama import ( + LlamaForCausalLM, + LlamaForSequenceClassification, + LlamaModel, + StaticCache, + apply_rotary_pos_emb, + repeat_kv, +) +from transformers.utils import logging + +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer.layer._operation import all_to_all_comm, gather_sp_output, split_forward_gather_backward +from colossalai.shardformer.layer.utils import is_share_sp_tp, split_batch_zigzag +from colossalai.shardformer.shard import ShardConfig + +from ..layer import ColoAttention, RingAttention, dist_cross_entropy + +_SUPPORTED_SP_MODE = ["all_to_all", "split_gather", "ring", "ring_attn"] + + +class LlamaPipelineForwards: + """ + This class serves as a micro library for forward function substitution of Llama models + under pipeline setting. + """ + + @staticmethod + def llama_model_forward( + self: LlamaModel, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + force_sp_gather: bool = True, # Set to false only when computing cross entropy + ): + logger = logging.get_logger(__name__) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with pipeline parallelism. Setting `use_cache=False`..." + ) + use_cache = False + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + disable_pp = stage_manager is None + # retrieve input_ids and inputs_embeds + if disable_pp or stage_manager.is_first_stage(): + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape[:2] + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape[:2] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + hidden_states = inputs_embeds + device = hidden_states.device + else: + input_shape = hidden_states.shape[:-1] + batch_size, seq_length = input_shape + device = hidden_states.device + + # Support SP + PP + sp_mode = shard_config.sequence_parallelism_mode + sp_group = shard_config.sequence_parallel_process_group + sp_size = shard_config.sequence_parallel_size + # Generating full positions ids for modes that gather sequence before attn + if stage_manager and (sp_mode != "ring_attn" and not stage_manager.is_first_stage()): + seq_length *= sp_size + + past_seen_tokens = 0 + if use_cache: # kept for BC (cache positions) + if not isinstance(past_key_values, StaticCache): + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_seen_tokens = past_key_values.get_seq_length() + if cache_position is None: + if isinstance(past_key_values, StaticCache): + raise ValueError("cache_position is a required argument when using StaticCache.") + cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_length, device=device) + + seq_length_with_past = seq_length + past_seen_tokens + + if output_attentions: + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") + output_attentions = False + if output_hidden_states: + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") + output_hidden_states = False + if use_cache: + logger.warning_once("use_cache=True is not supported for pipeline models at the moment.") + use_cache = False + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + no_split_input = disable_pp or not stage_manager.is_first_stage() + if no_split_input and sp_mode == "ring_attn": + _, attn_kwargs, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group) + elif shard_config.use_colo_llamaflashatten: # use flash_attention_2 + attn_kwargs: torch.Tensor = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif shard_config.enable_flash_attention: + mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past) + attn_kwargs: dict = ColoAttention.prepare_attn_kwargs( + mask_shape, + hidden_states.dtype, + hidden_states.device, + q_padding_mask=attention_mask, + is_causal=True, + invert=(sp_mode != "ring_attn"), + ) + else: + attn_kwargs: torch.Tensor = self._update_causal_mask(attention_mask, hidden_states, cache_position) + + # Support SP + PP. Later stages have already received the split input. + split_input = disable_pp or stage_manager.is_first_stage() + if split_input: + # Ring Attention zigzag batch processing + if sp_mode == "ring_attn": + assert shard_config.enable_flash_attention, "Ring Attention inherently requires Flash Attention." + if not attention_mask.bool().all(): + hidden_states, attn_kwargs, position_ids = RingAttention.prepare_varlen_batch( + attention_mask, sp_group, hidden_states, position_ids + ) + else: + hidden_states, position_ids = split_batch_zigzag([hidden_states, position_ids], sp_group) + + elif is_share_sp_tp(sp_mode): + hidden_states = split_forward_gather_backward( + hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication + ) + elif sp_mode == "all_to_all": + hidden_states = split_forward_gather_backward( + hidden_states, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication + ) + + if self.gradient_checkpointing and self.training and use_cache: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + start_idx, end_idx = (0, len(self.layers)) if disable_pp else (stage_index[0], stage_index[1]) + + num_ckpt_layers = 0 + if self.gradient_checkpointing and self.training: + num_ckpt_layers = end_idx - start_idx + # TODO: We can replace `gradient_checkpointing_enable` fn and initialize a gradient_checkpointing (List[bool]) for each layer + if shard_config.gradient_checkpoint_config is not None: + num_ckpt_layers = shard_config.gradient_checkpoint_config.get_num_ckpt_layers( + stage=stage_manager.stage, + num_stages=stage_manager.num_stages, + num_layers=end_idx - start_idx, + model_chunk_id=(stage_manager.model_chunk_id if stage_manager.is_interleave else 0), + num_model_chunks=stage_manager.num_model_chunks, + ) + assert num_ckpt_layers <= end_idx - start_idx + + for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx): + if output_hidden_states: + all_hidden_states += (hidden_states,) + if idx - start_idx < num_ckpt_layers: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attn_kwargs, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attn_kwargs, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if disable_pp or stage_manager.is_last_stage(): + hidden_states = self.norm(hidden_states) + if (not shard_config.parallel_output) or force_sp_gather or is_share_sp_tp(sp_mode): # noqa + hidden_states = gather_sp_output(hidden_states, shard_config) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + if disable_pp or stage_manager.is_last_stage(): + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_cache, + all_hidden_states, + all_self_attns, + ] + if v is not None + ) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + # always return dict for intermediate stage + return {"hidden_states": hidden_states} + + @staticmethod + def llama_for_causal_lm_forward( + self: LlamaForCausalLM, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ): + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM + + >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + logger = logging.get_logger(__name__) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future. + if output_attentions: + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") + output_attentions = False + if output_hidden_states: + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") + output_hidden_states = False + + if shard_config.sequence_parallelism_mode == "ring_attn" and shard_config.parallel_output: + # Split labels in a zigzag fashion too + sp_group = shard_config.sequence_parallel_process_group + if attention_mask.bool().all(): + labels = split_batch_zigzag(labels, sp_group, seq_dim=1, is_label=True) + else: + # [B, max_seqlen // sp_size] + labels, _, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group, labels, is_label=True) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = LlamaPipelineForwards.llama_model_forward( + self.model, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + shard_config=shard_config, + force_sp_gather=False, + ) + past_key_values = None + + disable_pp = stage_manager is None + if disable_pp or stage_manager.is_last_stage(): + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + loss = None + if labels is not None: + loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, self.model.dtype) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + else: + hidden_states = outputs.get("hidden_states") + return {"hidden_states": hidden_states} + + @staticmethod + def llama_for_sequence_classification_forward( + self: LlamaForSequenceClassification, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ): + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + logger = logging.get_logger(__name__) + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future. + if output_attentions: + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") + output_attentions = False + if output_hidden_states: + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") + output_hidden_states = False + + transformer_outputs = LlamaPipelineForwards.llama_model_forward( + self.model, + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + shard_config=shard_config, + ) + + if input_ids is not None: + batch_size = input_ids.shape[0] + elif inputs_embeds is not None: + batch_size = inputs_embeds.shape[0] + else: + batch_size = hidden_states.shape[0] + + if stage_manager.is_last_stage(): + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + else: + hidden_states = transformer_outputs.get("hidden_states") + return {"hidden_states": hidden_states} + + +def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None): + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[Union[torch.Tensor, Dict]] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: + if sp_mode is not None: + assert sp_mode in _SUPPORTED_SP_MODE, f"SP mode {sp_mode} is not supported by {type(self)} yet" + assert (sp_size is not None) and ( + sp_group is not None + ), "Must specify sp_size and sp_group for sequence parallel" + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + + bsz, q_len, _ = hidden_states.size() + # sp: modify sp_len when sequence parallel mode is ring + if is_share_sp_tp(sp_mode): + q_len *= sp_size + + if self.config.pretraining_tp > 1: + key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp + query_slices = self.q_proj.weight.split( + (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 + ) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] + query_states = torch.cat(query_states, dim=-1) + + key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] + key_states = torch.cat(key_states, dim=-1) + + value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] + value_states = torch.cat(value_states, dim=-1) + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # sp: all-to-all comminucation when introducing sequence parallel + if sp_mode == "all_to_all": + query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication) + key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication) + value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication) + bsz, q_len, _ = query_states.size() + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + if sp_mode == "ring_attn": + attn_output = RingAttention.attention( + query_states, + key_states, + value_states, + sp_group, + **attention_mask, + inner_ring_size=shard_config.inner_ring_size, + ) + + elif shard_config.enable_flash_attention: + assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict." + attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask) + else: + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + # sp: all-to-all comminucation when introducing sequence parallel + if sp_mode == "all_to_all": + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim) + attn_output = all_to_all_comm( + attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication + ) + else: + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + if self.config.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) + attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) + else: + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + return attn_output, attn_weights, past_key_value + + return forward + +def get_llama_decoder_layer_forward(): + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + 基于transformers 4.39.3 LlamaDecoderLayer.forward 进行改进,改进点: + a. 由于 input_layernorm 和 post_attention_layernorm 使用了ixformer 的融合算子,ln的输入输出形式有所改动。 + + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + + # residual = hidden_states + + hidden_states, residual = self.input_layernorm(hidden_states, None) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + # hidden_states = residual + hidden_states + + # Fully Connected + # residual = hidden_states + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + return forward + diff --git a/toolbox/ColossalAI/v0.4.4/patches/colossalai/shardformer/modeling/mixtral.py b/toolbox/ColossalAI/v0.4.4/patches/colossalai/shardformer/modeling/mixtral.py new file mode 100644 index 0000000000000000000000000000000000000000..657ad2552b1da87ddf6dab41e82471dc05cb2bca --- /dev/null +++ b/toolbox/ColossalAI/v0.4.4/patches/colossalai/shardformer/modeling/mixtral.py @@ -0,0 +1,1103 @@ +#!/usr/bin/env python3 +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +import inspect +import warnings +from typing import List, Optional, Tuple, Union + +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch.distributed import ProcessGroup +from torch.nn import CrossEntropyLoss +from transformers.cache_utils import Cache, DynamicCache +from transformers.modeling_attn_mask_utils import ( + _prepare_4d_causal_attention_mask, + _prepare_4d_causal_attention_mask_for_sdpa, +) +from transformers.models.mixtral.modeling_mixtral import ( + MixtralSparseMoeBlock, + MoeCausalLMOutputWithPast, + MoeModelOutputWithPast, + apply_rotary_pos_emb, + load_balancing_loss_func, + repeat_kv, +) +from transformers.utils import is_flash_attn_2_available, logging + +from colossalai.lazy import LazyInitContext +from colossalai.moe._operation import ( + DPGradScalerIn, + DPGradScalerOut, + EPGradScalerIn, + EPGradScalerOut, + all_to_all_uneven, +) +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.quantization.fp8 import all_reduce_fp8 +from colossalai.shardformer.layer._operation import ( + all_to_all_comm, + gather_forward_split_backward, + split_forward_gather_backward, +) +from colossalai.shardformer.layer.linear import Linear1D_Col, Linear1D_Row, ParallelModule +from colossalai.shardformer.shard import ShardConfig +from colossalai.shardformer.shard.utils import set_tensors_to_none +from colossalai.tensor.moe_tensor.api import set_moe_tensor_ep_group + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func + + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters) + + +class EPMixtralSparseMoeBlock(ParallelModule): + def __init__(self, *args, **kwargs): + raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}") + + def setup_process_groups( + self, + tp_group: ProcessGroup, + moe_dp_group: ProcessGroup, + ep_group: ProcessGroup, + fp8_communication: bool = False, + ): + assert tp_group is not None + assert moe_dp_group is not None + assert ep_group is not None + + # setup ep group + self.ep_size = dist.get_world_size(ep_group) + self.ep_rank = dist.get_rank(ep_group) + self.ep_group = ep_group + self.fp8_communication = fp8_communication + + if self.num_experts % self.ep_size != 0: + raise ValueError("The number of experts must be divisible by the number of expert parallel groups.") + + self.num_experts_per_ep = self.num_experts // self.ep_size + self.expert_start_idx = self.ep_rank * self.num_experts_per_ep + held_experts = self.experts[self.expert_start_idx : self.expert_start_idx + self.num_experts_per_ep] + + set_tensors_to_none(self.experts, exclude=set(held_experts)) + + # setup moe_dp group + self.moe_dp_group = moe_dp_group + self.moe_dp_size = moe_dp_group.size() + + # setup global tp group + self.tp_group = tp_group + if self.tp_group.size() > 1: + for expert in held_experts: + expert.w1 = Linear1D_Col.from_native_module( + expert.w1, self.tp_group, fp8_communication=self.fp8_communication + ) + expert.w3 = Linear1D_Col.from_native_module( + expert.w3, self.tp_group, fp8_communication=self.fp8_communication + ) + expert.w2 = Linear1D_Row.from_native_module( + expert.w2, self.tp_group, fp8_communication=self.fp8_communication + ) + + for p in self.experts.parameters(): + set_moe_tensor_ep_group(p, ep_group) + + @staticmethod + def from_native_module( + module: MixtralSparseMoeBlock, + tp_group: ProcessGroup, + moe_dp_group: ProcessGroup, + ep_group: ProcessGroup, + *args, + **kwargs, + ) -> "EPMixtralSparseMoeBlock": + # TODO: better init + LazyInitContext.materialize(module) + module.__class__ = EPMixtralSparseMoeBlock + fp8_communication = kwargs.get("fp8_communication", False) + module.setup_process_groups(tp_group, moe_dp_group, ep_group, fp8_communication) + return module + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + # router_logits: (batch * sequence_length, n_experts) + router_logits = self.gate(hidden_states) + + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + # we cast back to the input dtype + routing_weights = routing_weights.to(hidden_states.dtype) + + selected_experts = selected_experts.t().reshape(-1) + selected_experts_idx = selected_experts.argsort() + dispatch_states = hidden_states.repeat(self.top_k, 1)[selected_experts_idx] + input_split_sizes = selected_experts.bincount(minlength=self.num_experts) + + output_split_sizes = torch.zeros_like(input_split_sizes) + + dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group) + + with torch.no_grad(): + activate_experts = output_split_sizes[: self.num_experts_per_ep].clone() + for i in range(1, self.ep_size): + activate_experts += output_split_sizes[i * self.num_experts_per_ep : (i + 1) * self.num_experts_per_ep] + activate_experts = (activate_experts > 0).float() + + if self.fp8_communication: + all_reduce_fp8(activate_experts, group=self.moe_dp_group) + else: + dist.all_reduce(activate_experts, group=self.moe_dp_group) + + input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist() + output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist() + + output_states, _ = all_to_all_uneven( + dispatch_states, + input_split_list, + output_split_list, + self.ep_group, + fp8_communication=self.fp8_communication, + ) + # compute expert output + output_states = EPGradScalerIn.apply(output_states, self.ep_size) + if output_states.size(0) > 0: + if self.num_experts_per_ep == 1: + # no need to split + expert = self.experts[self.expert_start_idx] + output_states = DPGradScalerIn.apply(output_states, self.moe_dp_size, activate_experts[0]) + output_states = expert.act_fn(expert.w1(output_states)) * expert.w3(output_states) + output_states = expert.w2(output_states) + output_states = DPGradScalerOut.apply(output_states, self.moe_dp_size, activate_experts[0]) + else: + output_states_splits = output_states.split(output_split_sizes.tolist()) + output_states_list = [] + for i, split_states in enumerate(output_states_splits): + if split_states.size(0) == 0: + continue + expert = self.experts[self.expert_start_idx + i % self.num_experts_per_ep] + split_states = DPGradScalerIn.apply( + split_states, self.moe_dp_size, activate_experts[i % self.num_experts_per_ep] + ) + split_states = expert.act_fn(expert.w1(split_states)) * expert.w3(split_states) + split_states = expert.w2(split_states) + split_states = DPGradScalerOut.apply( + split_states, self.moe_dp_size, activate_experts[i % self.num_experts_per_ep] + ) + output_states_list.append(split_states) + output_states = torch.cat(output_states_list) + + output_states = EPGradScalerOut.apply(output_states, self.ep_size) + dispatch_states, _ = all_to_all_uneven( + output_states, output_split_list, input_split_list, self.ep_group, fp8_communication=self.fp8_communication + ) + + recover_experts_idx = torch.empty_like(selected_experts_idx) + recover_experts_idx[selected_experts_idx] = torch.arange( + selected_experts_idx.size(0), device=selected_experts_idx.device + ) + dispatch_states = dispatch_states[recover_experts_idx] + k_hidden_states = dispatch_states.chunk(self.top_k) + output_states = k_hidden_states[0] * routing_weights[:, 0, None] + for i in range(1, self.top_k): + output_states += k_hidden_states[i] * routing_weights[:, i, None] + output_states = output_states.reshape(batch_size, sequence_length, hidden_dim) + return output_states, router_logits + + +class EPOptimizeMixtralSparseMoeBlock(ParallelModule): + def __init__(self, *args, **kwargs): + raise RuntimeError( + f"Please use `from_native_module` to create an instance of {self.__class__.__name__}" + ) + + def setup_process_groups( + self, + tp_group: ProcessGroup, + moe_dp_group: ProcessGroup, + ep_group: ProcessGroup, + fp8_communication: bool = False, + ): + assert tp_group is not None + assert moe_dp_group is not None + assert ep_group is not None + + # setup ep group + self.ep_size = dist.get_world_size(ep_group) + self.ep_rank = dist.get_rank(ep_group) + self.ep_group = ep_group + self.fp8_communication = fp8_communication + + if self.num_experts % self.ep_size != 0: + raise ValueError( + "The number of experts must be divisible by the number of expert parallel groups." + ) + + self.num_experts_per_ep = self.num_experts // self.ep_size + self.expert_start_idx = self.ep_rank * self.num_experts_per_ep + held_experts = self.experts[ + self.expert_start_idx : self.expert_start_idx + self.num_experts_per_ep + ] + + set_tensors_to_none(self.experts, exclude=set(held_experts)) + + # setup moe_dp group + self.moe_dp_group = moe_dp_group + self.moe_dp_size = moe_dp_group.size() + + # setup global tp group + self.tp_group = tp_group + if self.tp_group.size() > 1: + for expert in held_experts: + expert.w1 = Linear1D_Col.from_native_module( + expert.w1, self.tp_group, fp8_communication=self.fp8_communication + ) + expert.w3 = Linear1D_Col.from_native_module( + expert.w3, self.tp_group, fp8_communication=self.fp8_communication + ) + expert.w2 = Linear1D_Row.from_native_module( + expert.w2, self.tp_group, fp8_communication=self.fp8_communication + ) + + for p in self.experts.parameters(): + set_moe_tensor_ep_group(p, ep_group) + + @staticmethod + def from_native_module( + module: MixtralSparseMoeBlock, + tp_group: ProcessGroup, + moe_dp_group: ProcessGroup, + ep_group: ProcessGroup, + *args, + **kwargs, + ) -> "EPOptimizeMixtralSparseMoeBlock": + # TODO: better init + LazyInitContext.materialize(module) + module.__class__ = EPOptimizeMixtralSparseMoeBlock + fp8_communication = kwargs.get("fp8_communication", False) + module.setup_process_groups(tp_group, moe_dp_group, ep_group, fp8_communication) + return module + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + # router_logits: (batch * sequence_length, n_experts) + router_logits = self.gate(hidden_states) + + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk( + routing_weights, self.top_k, dim=-1 + ) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + # we cast back to the input dtype + routing_weights = routing_weights.to(hidden_states.dtype) + + selected_experts = selected_experts.t().reshape(-1) + selected_experts_idx = selected_experts.argsort() + dispatch_states = hidden_states.repeat(self.top_k, 1)[selected_experts_idx] + input_split_sizes = selected_experts.bincount(minlength=self.num_experts) + + output_split_sizes = torch.zeros_like(input_split_sizes) + + dist.all_to_all_single( + output_split_sizes, input_split_sizes, group=self.ep_group + ) + + with torch.no_grad(): + activate_experts = output_split_sizes[: self.num_experts_per_ep].clone() + for i in range(1, self.ep_size): + activate_experts += output_split_sizes[ + i * self.num_experts_per_ep : (i + 1) * self.num_experts_per_ep + ] + activate_experts = (activate_experts > 0).float() + + if self.fp8_communication: + all_reduce_fp8(activate_experts, group=self.moe_dp_group) + else: + dist.all_reduce(activate_experts, group=self.moe_dp_group) + + input_split_list = ( + input_split_sizes.view(self.ep_size, self.num_experts_per_ep) + .sum(dim=-1) + .tolist() + ) + output_split_list = ( + output_split_sizes.view(self.ep_size, self.num_experts_per_ep) + .sum(dim=-1) + .tolist() + ) + + output_states, _ = all_to_all_uneven( + dispatch_states, + input_split_list, + output_split_list, + self.ep_group, + fp8_communication=self.fp8_communication, + ) + + # compute expert output + output_states = EPGradScalerIn.apply(output_states, self.ep_size) + if output_states.size(0) > 0: + if self.num_experts_per_ep == 1: + # no need to split + expert = self.experts[self.expert_start_idx] + output_states = DPGradScalerIn.apply( + output_states, self.moe_dp_size, activate_experts[0] + ) + output_states = expert.act_fn(expert.w1(output_states)) * expert.w3( + output_states + ) + output_states = expert.w2(output_states) + output_states = DPGradScalerOut.apply( + output_states, self.moe_dp_size, activate_experts[0] + ) + else: + output_states_splits = output_states.split(output_split_sizes.tolist()) + tmp_output_states_list = [] + for i in range(self.num_experts_per_ep): + split_states = output_states_splits[i] + if split_states.size(0) == 0: + continue + expert = self.experts[ + self.expert_start_idx + i % self.num_experts_per_ep + ] + split_states = DPGradScalerIn.apply( + split_states, + self.moe_dp_size, + activate_experts[i % self.num_experts_per_ep], + ) + split_states = expert.act_fn(expert.w1(split_states)) * expert.w3( + split_states + ) + split_states = expert.w2(split_states) + split_states = DPGradScalerOut.apply( + split_states, + self.moe_dp_size, + activate_experts[i % self.num_experts_per_ep], + ) + tmp_output_states_list.append(split_states) + # 重复 len(output_split_list) 次 + output_states_list_1 = [] + for i in range(len(output_split_list)): + output_states_list_1.extend(tmp_output_states_list) + + output_states = torch.cat(output_states_list_1) + + output_states = EPGradScalerOut.apply(output_states, self.ep_size) + dispatch_states, _ = all_to_all_uneven( + output_states, + output_split_list, + input_split_list, + self.ep_group, + fp8_communication=self.fp8_communication, + ) + + recover_experts_idx = torch.empty_like(selected_experts_idx) + recover_experts_idx[selected_experts_idx] = torch.arange( + selected_experts_idx.size(0), device=selected_experts_idx.device + ) + dispatch_states = dispatch_states[recover_experts_idx] + k_hidden_states = dispatch_states.chunk(self.top_k) + output_states = k_hidden_states[0] * routing_weights[:, 0, None] + + for i in range(1, self.top_k): + output_states += k_hidden_states[i] * routing_weights[:, i, None] + + output_states = output_states.reshape(batch_size, sequence_length, hidden_dim) + return output_states, router_logits + + +class MixtralPipelineForwards: + """ + This class serves as a micro library for forward function substitution of Mixtral models + under pipeline setting. + """ + + @staticmethod + def mixtral_model_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + past_router_logits: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ): + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, MixtralForCausalLM + + >>> model = MixtralForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + logger = logging.get_logger(__name__) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if stage_manager.is_first_stage(): + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + device = input_ids.device if input_ids is not None else inputs_embeds.device + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + hidden_states = inputs_embeds + else: + input_shape = hidden_states.shape[:-1] + batch_size, seq_length = input_shape + device = hidden_states.device + + seq_length_with_past = seq_length + past_key_values_length = 0 + + # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future. + if output_attentions: + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") + output_attentions = False + if output_hidden_states: + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") + output_hidden_states = False + if use_cache: + logger.warning_once("use_cache=True is not supported for pipeline models at the moment.") + use_cache = False + + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + if position_ids is None: + position_ids = torch.arange( + past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device, + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + # embed positions, for the first stage, hidden_states is the input embeddings, + # for the other stages, hidden_states is the output of the previous stage + if is_flash_attn_2_available(): + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, + (batch_size, seq_length), + hidden_states, + past_key_values_length, + sliding_window=self.config.sliding_window, + ) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_router_logits = () if output_router_logits else None + next_decoder_cache = None + + start_idx, end_idx = stage_index[0], stage_index[1] + for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + position_ids, + None, + output_attentions, + output_router_logits, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask, + position_ids, + past_key_value, + output_attentions, + output_router_logits, + use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = (layer_outputs[2 if output_attentions else 1],) + if output_attentions: + all_self_attns += (layer_outputs[1],) + if output_router_logits: + all_router_logits += (layer_outputs[-1],) + + if stage_manager.is_last_stage(): + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + + if output_router_logits and past_router_logits is not None: + all_router_logits = past_router_logits + all_router_logits + if stage_manager.is_last_stage(): + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits] + if v is not None + ) + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + router_logits=all_router_logits, + ) + else: + if output_router_logits: + return { + "hidden_states": hidden_states, + "past_router_logits": all_router_logits, + } + else: + return { + "hidden_states": hidden_states, + } + + @staticmethod + def mixtral_for_causal_lm_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + past_router_logits: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ): + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, MixtralForCausalLM + + >>> model = MixtralForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + logger = logging.get_logger(__name__) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future. + if output_attentions: + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") + output_attentions = False + if output_hidden_states: + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") + output_hidden_states = False + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = MixtralPipelineForwards.mixtral_model_forward( + self.model, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_router_logits=output_router_logits, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + past_router_logits=past_router_logits, + ) + past_key_values = None + + if stage_manager.is_last_stage(): + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + aux_loss = None + if output_router_logits: + aux_loss = load_balancing_loss_func(outputs[-1], self.num_experts, self.num_experts_per_tok) + if labels is not None: + loss += self.router_aux_loss_coef * aux_loss + + if not return_dict: + output = (logits,) + outputs[1:] + if output_router_logits: + output = (aux_loss,) + output + return (loss,) + output if loss is not None else output + + return MoeCausalLMOutputWithPast( + loss=loss, + aux_loss=aux_loss, + logits=logits, + past_key_values=None, + hidden_states=outputs[0], + attentions=None, + router_logits=outputs[-1], + ) + else: + out = {} + hidden_states = outputs.get("hidden_states") + out["hidden_states"] = hidden_states + if output_router_logits: + out["past_router_logits"] = outputs["past_router_logits"] + return out + + +def get_mixtral_flash_attention_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None): + logger = logging.get_logger(__name__) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: + if sp_mode is not None: + assert sp_mode in ["all_to_all", "split_gather", "ring"], "Invalid sp_mode" + assert (sp_size is not None) and ( + sp_group is not None + ), "Must specify sp_size and sp_group for sequence parallel" + + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + + # overwrite attention_mask with padding_mask + attention_mask = kwargs.pop("padding_mask") + bsz, q_len, _ = hidden_states.size() + + # sp: modify sp_len when sequence parallel mode is ring + if sp_mode in ["split_gather", "ring"]: + q_len *= sp_size + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # sp: all-to-all comminucation when introducing sequence parallel + if sp_mode == "all_to_all": + query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication) + key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication) + value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication) + bsz, q_len, _ = query_states.size() + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + + # Because the input can be padded, the absolute sequence length depends on the max position id. + rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1 + cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + use_sliding_windows = ( + _flash_supports_window_size + and getattr(self.config, "sliding_window", None) is not None + and kv_seq_len > self.config.sliding_window + ) + if not _flash_supports_window_size: + logger.warning_once( + "The current flash attention version does not support sliding window attention, for a more memory efficient implementation" + " make sure to upgrade flash-attn library." + ) + if past_key_value is not None: + # Activate slicing cache only if the config has a value `sliding_windows` attribute + cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 + if ( + getattr(self.config, "sliding_window", None) is not None + and kv_seq_len > self.config.sliding_window + and cache_has_contents + ): + slicing_tokens = 1 - self.config.sliding_window + + past_key = past_key_value[self.layer_idx][0] + past_value = past_key_value[self.layer_idx][1] + + past_key = past_key[:, :, slicing_tokens:, :].contiguous() + past_value = past_value[:, :, slicing_tokens:, :].contiguous() + + if past_key.shape[-2] != self.config.sliding_window - 1: + raise ValueError( + f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got" + f" {past_key.shape}" + ) + + if attention_mask is not None: + attention_mask = attention_mask[:, slicing_tokens:] + attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1) + + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + dropout_rate = 0.0 if not self.training else self.attention_dropout + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + # Reashape to the expected shape for Flash Attention + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + attn_output = self._flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + use_sliding_windows=use_sliding_windows, + ) + + # sp: all-to-all comminucation when introducing sequence parallel + if sp_mode == "all_to_all": + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous() # (1, 8, 128) + attn_output = all_to_all_comm( + attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication + ) # (1, 4, 256) + else: + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + return attn_output, attn_weights, past_key_value + + return forward + + +def get_mixtral_flash_attention_model_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None): + logger = logging.get_logger(__name__) + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, MoeModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + past_key_values_length = 0 + + if (self.gradient_checkpointing or sp_mode in ["ring", "all_to_all"]) and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + if use_cache: + use_legacy_cache = not isinstance(past_key_values, Cache) + if use_legacy_cache: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_key_values_length = past_key_values.get_usable_length(seq_length) + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache: + is_padding_right = attention_mask[:, -1].sum().item() != batch_size + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of Mixtral. Make sure to " + " call `tokenizer.padding_side = 'left'` before tokenizing the input. " + ) + if self._attn_implementation == "flash_attention_2": + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self._attn_implementation == "sdpa" and not output_attentions: + # output_attentions=True can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + sliding_window=self.config.sliding_window, + ) + + if sp_mode in ["ring", "split_gather"]: + inputs_embeds = split_forward_gather_backward( + inputs_embeds, 1, sp_group, fp8_communication=shard_config.fp8_communication + ) + elif sp_mode == "all_to_all": + inputs_embeds = split_forward_gather_backward( + inputs_embeds, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication + ) + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_router_logits = () if output_router_logits else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + position_ids, + past_key_values, + output_attentions, + output_router_logits, + use_cache, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + output_router_logits=output_router_logits, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if output_router_logits: + all_router_logits += (layer_outputs[-1],) + + hidden_states = self.norm(hidden_states) + + if sp_mode == "ring" or sp_mode == "split_gather": + hidden_states = gather_forward_split_backward( + hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication + ) + elif sp_mode == "all_to_all": + hidden_states = gather_forward_split_backward( + hidden_states, 1, sp_group, grad_scale=sp_size, fp8_communication=shard_config.fp8_communication + ) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = None + if use_cache: + next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache + + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits] + if v is not None + ) + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + router_logits=all_router_logits, + ) + + return forward diff --git a/toolbox/ColossalAI/v0.4.4/patches/colossalai/shardformer/policies/llama.py b/toolbox/ColossalAI/v0.4.4/patches/colossalai/shardformer/policies/llama.py new file mode 100644 index 0000000000000000000000000000000000000000..475c9f01a57d6e67927ae07f0b47c574be02f2d0 --- /dev/null +++ b/toolbox/ColossalAI/v0.4.4/patches/colossalai/shardformer/policies/llama.py @@ -0,0 +1,459 @@ +#!/usr/bin/env python3 +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +from functools import partial +from typing import Callable, Dict, List, Union + +import torch.nn as nn +from torch import Tensor +from torch.nn import Module + +from colossalai.shardformer.layer import ( + FusedRMSNorm, + Linear1D_Col, + Linear1D_Row, + PaddingEmbedding, + PaddingLMHead, + RMSNorm, + VocabParallelEmbedding1D, + VocabParallelLMHead1D, + IXFLlamaMLP, + Colo_LlamaFlashAtten, + Colo_FusedRMSNorm, +) + +from ..modeling.llama import LlamaPipelineForwards, get_llama_flash_attention_forward, get_llama_decoder_layer_forward +from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription + +__all__ = ["LlamaPolicy", "LlamaForCausalLMPolicy", "LlamaForSequenceClassificationPolicy"] + + +class LlamaPolicy(Policy): + def config_sanity_check(self): + pass + + def preprocess(self): + self.tie_weight = self.tie_weight_check() + self.origin_attn_implement = self.model.config._attn_implementation + return self.model + + def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: + from transformers.models.llama.modeling_llama import ( + LlamaAttention, + LlamaDecoderLayer, + LlamaFlashAttention2, + LlamaModel, + LlamaSdpaAttention, + ) + + ATTN_IMPLEMENTATION = { + "eager": LlamaAttention, + "flash_attention_2": LlamaFlashAttention2, + "sdpa": LlamaSdpaAttention, + } + policy = {} + + attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement] + embedding_cls = None + if self.shard_config.enable_tensor_parallelism: + embedding_cls = VocabParallelEmbedding1D + else: + if self.tie_weight: + embedding_cls = PaddingEmbedding + + if self.shard_config.enable_fused_normalization: + norm_cls = FusedRMSNorm + else: + norm_cls = RMSNorm + + sp_mode = self.shard_config.sequence_parallelism_mode or None + sp_size = self.shard_config.sequence_parallel_size or None + sp_group = self.shard_config.sequence_parallel_process_group or None + sp_partial_derived = sp_mode in ["split_gather", "ring"] + if sp_mode == "ring_attn" and not self.is_causal: + raise ValueError("Ring attention is only meant for causal language modeling.") + + tp_size = self.shard_config.tensor_parallel_size + # Modified by SP and TP + num_q_heads = self.model.config.num_attention_heads + num_kv_heads = getattr(self.model.config, "num_key_value_heads", None) + + if sp_mode == "all_to_all": + num_q_heads //= sp_size + decoder_attribute_replacement = {"num_heads": num_q_heads} + if num_kv_heads: + num_kv_heads //= sp_size + decoder_attribute_replacement["num_key_value_heads"] = num_kv_heads + + policy[attn_cls] = ModulePolicyDescription( + attribute_replacement=decoder_attribute_replacement, + ) + if self.shard_config.enable_flash_attention or self.shard_config.enable_sequence_parallelism: + if not self.shard_config.use_colo_llamaflashatten: # 若使用Colo_LlamaFlashAtten,则使用其forward,不必做此替换 + self.append_or_create_method_replacement( + description={ + "forward": get_llama_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group), + }, + policy=policy, + target_key=attn_cls, + ) + + if self.pipeline_stage_manager is None: + self.append_or_create_method_replacement( + description={ + "forward": partial( + LlamaPipelineForwards.llama_model_forward, + shard_config=self.shard_config, + ), + }, + policy=policy, + target_key=LlamaModel, + ) + + if self.shard_config.enable_tensor_parallelism: + assert ( + num_q_heads % tp_size == 0 + ), f"The number of attention heads must be divisible by tensor parallel size." + if hasattr(self.model.config, "num_key_value_heads"): + assert ( + num_kv_heads >= tp_size and num_kv_heads % tp_size == 0 + ), f"The number of key_value heads must be divisible by, and must not be less than tensor parallel size." + num_q_heads //= tp_size + decoder_attribute_replacement = { + "self_attn.hidden_size": self.model.config.hidden_size // tp_size, + "self_attn.num_heads": num_q_heads, + } + if getattr(self.model.config, "num_key_value_heads", False): + num_kv_heads //= tp_size + decoder_attribute_replacement["self_attn.num_key_value_heads"] = num_kv_heads + + policy[LlamaDecoderLayer] = ModulePolicyDescription( + attribute_replacement=decoder_attribute_replacement, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attn.q_proj", + target_module=Linear1D_Col, + kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), + ), + SubModuleReplacementDescription( + suffix="self_attn.k_proj", + target_module=Linear1D_Col, + kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), + ), + SubModuleReplacementDescription( + suffix="self_attn.v_proj", + target_module=Linear1D_Col, + kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), + ), + SubModuleReplacementDescription( + suffix="self_attn.o_proj", + target_module=Linear1D_Row, + kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), + ), + SubModuleReplacementDescription( + suffix="mlp.gate_proj", + target_module=Linear1D_Col, + kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), + ), + SubModuleReplacementDescription( + suffix="mlp.up_proj", + target_module=Linear1D_Col, + kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), + ), + SubModuleReplacementDescription( + suffix="mlp.down_proj", + target_module=Linear1D_Row, + kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), + ), + ], + ) + + if embedding_cls is not None: + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="embed_tokens", + target_module=embedding_cls, + kwargs=( + { + "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, + "fp8_communication": self.shard_config.fp8_communication, + } + if self.shard_config.enable_tensor_parallelism + else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by} + ), + ), + policy=policy, + target_key=LlamaModel, + ) + + # Colo_LlamaFlashAtten,IXFLlamaMLP 暂不支持 tp + if self.shard_config.use_colo_llamaflashatten: + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="self_attn", + target_module=Colo_LlamaFlashAtten, + kwargs=None, + ), + policy=policy, + target_key=LlamaDecoderLayer, + ) + + if self.shard_config.use_ixformer_mlp: + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="mlp", + target_module=IXFLlamaMLP, + kwargs=None, + ), + policy=policy, + target_key=LlamaDecoderLayer, + ) + + # optimization configuration + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="input_layernorm", + target_module=Colo_FusedRMSNorm if self.shard_config.use_ixformer_fusedrmsnormres else norm_cls, + kwargs={"sp_partial_derived": sp_partial_derived}, + ), + SubModuleReplacementDescription( + suffix="post_attention_layernorm", + target_module=Colo_FusedRMSNorm if self.shard_config.use_ixformer_fusedrmsnormres else norm_cls, + kwargs={"sp_partial_derived": sp_partial_derived}, + ), + ], + policy=policy, + target_key=LlamaDecoderLayer, + ) + + if self.shard_config.use_ixformer_fusedrmsnormres: + self.append_or_create_method_replacement( + description={ + "forward": get_llama_decoder_layer_forward(), + }, + policy=policy, + target_key=LlamaDecoderLayer, + ) + + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="norm", + target_module=norm_cls, + kwargs={"sp_partial_derived": sp_partial_derived}, + ), + policy=policy, + target_key=LlamaModel, + ) + + return policy + + def postprocess(self): + return self.model + + def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: + """If under pipeline parallel setting, replacing the original forward method of huggingface + to customized forward method, and add this changing to policy.""" + if self.pipeline_stage_manager is None: + return + + stage_manager = self.pipeline_stage_manager + if self.model.__class__.__name__ == "LlamaModel": + module = self.model + else: + module = self.model.model + + if stage_manager.is_interleave: + layers_per_stage = stage_manager.distribute_layers(len(module.layers)) + stage_manager.stage_indices = stage_manager.get_stage_index(layers_per_stage) + method_replacement = { + "forward": partial(new_forward, stage_manager=stage_manager, shard_config=self.shard_config) + } + + else: + layers_per_stage = stage_manager.distribute_layers(len(module.layers)) + stage_index = stage_manager.get_stage_index(layers_per_stage) + method_replacement = { + "forward": partial( + new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config + ) + } + + self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + assert self.pipeline_stage_manager is not None + + if self.model.__class__.__name__ == "LlamaModel": + module = self.model + else: + module = self.model.model + stage_manager = self.pipeline_stage_manager + + held_layers = [] + if stage_manager.is_interleave: + assert stage_manager.num_model_chunks is not None + layers_per_stage = stage_manager.distribute_layers(len(module.layers)) + stage_indices = stage_manager.get_stage_index(layers_per_stage) + if stage_manager.is_first_stage(ignore_chunk=True): + held_layers.append(module.embed_tokens) + for start_idx, end_idx in stage_indices: + held_layers.extend(module.layers[start_idx:end_idx]) + if stage_manager.is_last_stage(ignore_chunk=True): + held_layers.append(module.norm) + + else: + layers_per_stage = stage_manager.distribute_layers(len(module.layers)) + if stage_manager.is_first_stage(): + held_layers.append(module.embed_tokens) + start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) + held_layers.extend(module.layers[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.norm) + + return held_layers + + +class LlamaModelPolicy(LlamaPolicy): + def module_policy(self): + policy = super().module_policy() + from transformers.models.llama.modeling_llama import LlamaModel + + if self.pipeline_stage_manager: + # set None as default + self.set_pipeline_forward( + model_cls=LlamaModel, new_forward=LlamaPipelineForwards.llama_model_forward, policy=policy + ) + return policy + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + held_layers = super().get_held_layers() + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """No shared params in llama model""" + return [] + + +class LlamaForCausalLMPolicy(LlamaPolicy): + def module_policy(self): + from transformers import LlamaForCausalLM + + self.is_causal = True + policy = super().module_policy() + + if self.shard_config.enable_tensor_parallelism: + # add a new item for causal lm + new_item = { + LlamaForCausalLM: ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="lm_head", + target_module=VocabParallelLMHead1D, + kwargs={ + "gather_output": not self.shard_config.parallel_output, + "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, + "fp8_communication": self.shard_config.fp8_communication, + }, + ) + ], + ) + } + else: + new_item = { + LlamaForCausalLM: ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="lm_head", + target_module=PaddingLMHead, + kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, + ) + ], + ) + } + policy.update(new_item) + + if self.pipeline_stage_manager: + # set None as default + self.set_pipeline_forward( + model_cls=LlamaForCausalLM, new_forward=LlamaPipelineForwards.llama_for_causal_lm_forward, policy=policy + ) + elif self.shard_config.enable_tensor_parallelism or self.shard_config.enable_sequence_parallelism: + # Compute loss distributedly along the sequence dimension + new_item[LlamaForCausalLM].method_replacement = { + # "forward": get_lm_forward_with_dist_cross_entropy(self.shard_config) + "forward": partial(LlamaPipelineForwards.llama_for_causal_lm_forward, shard_config=self.shard_config) + } + return policy + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + stage_manager = self.pipeline_stage_manager + held_layers = super().get_held_layers() + if stage_manager.is_last_stage(ignore_chunk=True): + held_layers.append(self.model.lm_head) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + llama_model = self.model.model + if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1: + if ( + id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight) + and self.pipeline_stage_manager.num_stages > 1 + ): + # tie weights + return [ + { + 0: llama_model.embed_tokens.weight, + self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight, + } + ] + return [] + + +class LlamaForSequenceClassificationPolicy(LlamaPolicy): + def module_policy(self): + from transformers import LlamaForSequenceClassification + + policy = super().module_policy() + + if self.shard_config.enable_tensor_parallelism: + # add a new item for sequence classification + new_item = { + LlamaForSequenceClassification: ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="score", + target_module=Linear1D_Col, + kwargs=dict( + gather_output=True, + fp8_communication=self.shard_config.fp8_communication, + ), + ) + ] + ) + } + policy.update(new_item) + # to be confirmed + if self.pipeline_stage_manager: + # set None as default + self.set_pipeline_forward( + model_cls=LlamaForSequenceClassification, + new_forward=LlamaPipelineForwards.llama_for_sequence_classification_forward, + policy=policy, + ) + return policy + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + stage_manager = self.pipeline_stage_manager + held_layers = super().get_held_layers() + if stage_manager.is_last_stage(ignore_chunk=True): + held_layers.append(self.model.score) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """No shared params in llama for sequence classification model""" + return [] diff --git a/toolbox/ColossalAI/v0.4.4/patches/colossalai/shardformer/shard/shard_config.py b/toolbox/ColossalAI/v0.4.4/patches/colossalai/shardformer/shard/shard_config.py new file mode 100644 index 0000000000000000000000000000000000000000..9cf1f90ff0b80e3983608415d7d013272afcce62 --- /dev/null +++ b/toolbox/ColossalAI/v0.4.4/patches/colossalai/shardformer/shard/shard_config.py @@ -0,0 +1,143 @@ +#!/usr/bin/env python3 +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +import warnings +from dataclasses import dataclass, field +from typing import Any, Dict, Optional + +import torch.distributed as dist +from torch.distributed import ProcessGroup + +from colossalai.pipeline.stage_manager import PipelineStageManager + +from .grad_ckpt_config import GradientCheckpointConfig + +__all__ = ["ShardConfig"] +SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all", "ring_attn"] + + +@dataclass +class ShardConfig: + r""" + The config for sharding the huggingface model + + Args: + tensor_parallel_process_group (Optional[ProcessGroup]): The process group of tensor parallelism, it's necessary when using tensor parallel. Defaults to None, which is the global process group. + pipeline_stage_manager (Optional[PipelineStageManager]): If using pipeline parallelism, it's necessary to specify a pipeline stage manager for inter-process communication in pipeline parallelism. Defaults to None, which means not using pipeline parallelism. + enable_tensor_parallelism (bool): Whether to use tensor parallelism. Defaults to True. + enable_fused_normalization (bool): Whether to use fused layernorm. Defaults to False. + enable_flash_attention (bool, optional): Whether to switch on flash attention. Defaults to False. + enable_jit_fused (bool, optional): Whether to switch on JIT fused operators. Defaults to False. + enable_sequence_parallelism (bool): Whether to turn on sequence parallelism, which partitions non-tensor-parallel regions along the sequence dimension. Defaults to False. + enable_sequence_overlap (bool): Whether to turn on sequence overlap, which overlap the computation and communication in sequence parallelism. It can only be used when enable_sequence_parallelism is True. Defaults to False. + gradient_checkpoint_config (Optional[GradientCheckpointConfig]): The gradient checkpoint config. Defaults to None. + enable_all_optimization (bool): Whether to turn on all optimization tools including 'fused normalization', 'flash attention', 'JIT fused operators', 'sequence parallelism' and 'sequence overlap'. Defaults to False. + fp8_communication (bool, optional): Whether to enable fp8 communication in model parallelism. Defaults to False. + parallel_output (bool): For TP: whether to use parallelize cross entropy computation along the feature dim. + For SP: set to True to NOT gather the output along the seq dim. + """ + + tensor_parallel_process_group: Optional[ProcessGroup] = None + sequence_parallel_process_group: Optional[ProcessGroup] = None + pipeline_stage_manager: Optional[PipelineStageManager] = None + enable_tensor_parallelism: bool = True + enable_all_optimization: bool = False + enable_fused_normalization: bool = False + enable_flash_attention: bool = False + enable_jit_fused: bool = False + enable_sequence_parallelism: bool = False + sequence_parallelism_mode: str = None + enable_sequence_overlap: bool = False + parallel_output: bool = True + make_vocab_size_divisible_by: int = 64 + gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None + extra_kwargs: Dict[str, Any] = field(default_factory=dict) + + # For ring attention + inner_ring_size: Optional[int] = None + # for moe related + moe_dp_group: Optional[ProcessGroup] = None + ep_group: Optional[ProcessGroup] = None + fp8_communication: bool = False + # pipeline_parallel_size: int + # data_parallel_size: int + # tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d'] + use_ixformer_mlp: bool = False + use_colo_llamaflashatten: bool = False + use_ixformer_fusedrmsnormres: bool = False + + @property + def tensor_parallel_size(self): + return self._tensor_parallel_size + + @property + def sequence_parallel_size(self): + return self._sequence_parallel_size + + def __post_init__(self): + # turn on all optimization if all_optimization is set to True + if self.enable_all_optimization: + self._turn_on_all_optimization() + + if self.enable_sequence_parallelism: + self.sequence_parallelism_mode = ( + "split_gather" if self.sequence_parallelism_mode is None else self.sequence_parallelism_mode + ) + assert ( + self.sequence_parallelism_mode in SUPPORT_SP_MODE + ), f"Sequence parallelism mode {self.sequence_parallelism_mode} is not in the supported list {SUPPORT_SP_MODE}" + if self.sequence_parallelism_mode in ["split_gather", "ring"]: + assert ( + self.enable_tensor_parallelism + ), f"sequence parallelism mode {self.sequence_parallelism_mode} can only be used when enable_tensor_parallelism is True" + elif self.sequence_parallelism_mode in ["all_to_all"]: + # assert ( + # not self.enable_tensor_parallelism + # ), f"sequence parallelism mode {self.sequence_parallelism_mode} can only be used when enable_tensor_parallelism is False" + if self.enable_sequence_overlap: + self.enable_sequence_overlap = False + warnings.warn( + f"The enable_sequence_overlap flag will be ignored in sequence parallelism mode {self.sequence_parallelism_mode}" + ) + else: + if self.sequence_parallelism_mode: + self.sequence_parallelism_mode = None + warnings.warn( + f"The sequence_parallelism_mode will be ignored when enable_sequence_parallelism is False" + ) + assert ( + not self.enable_sequence_overlap + ), f"enable_sequence_overlap can only be set to True when enable_sequence_parallelism is True" + + # get the tensor parallel size + if not self.enable_tensor_parallelism: + self._tensor_parallel_size = 1 + else: + self._tensor_parallel_size = dist.get_world_size(self.tensor_parallel_process_group) + + # get the sequence parallel size + if not self.enable_sequence_parallelism: + self._sequence_parallel_size = 1 + else: + self._sequence_parallel_size = dist.get_world_size(self.sequence_parallel_process_group) + + def _turn_on_all_optimization(self): + """ + Turn on all optimization. + """ + # you can add all the optimization flag here + try: + from apex.normalization import FusedLayerNorm as ApexFusedLayerNorm # noqa + + apex_avail = True + except ImportError: + apex_avail = False + warnings.warn("You set enable_all_optimization=True, but apex is not installed.") + + self.enable_fused_normalization = apex_avail + self.enable_flash_attention = True + self.enable_jit_fused = True + # This can cause non-in-place param sharding when used without ZeRO. + # It may also slow down training when seq len is small. Plz enable manually. + # self.enable_sequence_parallelism = True + # self.enable_sequence_overlap = True diff --git a/toolbox/ColossalAI/v0.4.4/patches/colossalai/zero/low_level/low_level_optim.py b/toolbox/ColossalAI/v0.4.4/patches/colossalai/zero/low_level/low_level_optim.py new file mode 100644 index 0000000000000000000000000000000000000000..9e7d177fcbf5e0e596648cc85c120ecb48bd7bac --- /dev/null +++ b/toolbox/ColossalAI/v0.4.4/patches/colossalai/zero/low_level/low_level_optim.py @@ -0,0 +1,939 @@ +#!/usr/bin/env python3 +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# this code is inspired by the DeepSpeed library and implemented with our own design from scratch +import copy +from contextlib import contextmanager, nullcontext +from functools import partial +from typing import Dict, Iterator, List, Optional, Tuple +from weakref import proxy + +import torch +import torch.distributed as dist +import torch.nn as nn +from torch import Tensor, inf +from torch.distributed import ProcessGroup +from torch.optim import Optimizer + +from colossalai.accelerator import get_accelerator +from colossalai.amp.naive_amp.mixed_precision_mixin import ( + BF16MixedPrecisionMixin, + FP16MixedPrecisionMixin, + MixedPrecisionMixin, +) +from colossalai.interface import OptimizerWrapper +from colossalai.logging import get_dist_logger +from colossalai.quantization.fp8 import all_gather_fp8, all_reduce_fp8, reduce_scatter_fp8 +from colossalai.tensor.moe_tensor.api import is_moe_tensor + +from ._utils import calculate_global_norm_from_list, has_inf_or_nan, release_param_grad, sync_tensor +from .bookkeeping import BucketStore, GradientStore, TensorBucket +from .zero_hook import set_all_gather_handle, wait_all_gather_handle + + +class LowLevelZeroFP16MixedPrecisionMixin(FP16MixedPrecisionMixin): + def __init__( + self, + num_working_param_groups: int, + pg_to_grad_store: Dict[ProcessGroup, GradientStore], + initial_scale: float = 2**16, + min_scale: float = 1, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + max_scale: float = 2**32, + ) -> None: + super().__init__( + initial_scale, + min_scale, + growth_factor, + backoff_factor, + growth_interval, + hysteresis, + max_scale, + ) + self.num_working_param_groups = num_working_param_groups + self.pg_to_grad_store = pg_to_grad_store + + def check_local_overflow(self) -> bool: + for store in self.pg_to_grad_store.values(): + for group_id in range(self.num_working_param_groups): + for avg_grad in store.get_working_grads_by_group_id(group_id): + if avg_grad is not None and has_inf_or_nan(avg_grad): + return True + return False + + +class LowLevelZeroOptimizer(OptimizerWrapper): + """Optimizer used for ZeRO-1 and ZeRO-2.""" + + def __init__( + self, + optimizer: Optimizer, + pg_to_param_list: Optional[Dict[ProcessGroup, List[nn.Parameter]]] = None, + initial_scale: int = 2**16, # grad scaler config + min_scale: int = 1, + growth_factor: float = 2.0, + backoff_factor: float = 0.5, + growth_interval: int = 2000, + hysteresis: int = 2, + max_scale: int = 2**24, + clip_grad_norm: float = 0.0, # grad clipping + verbose: bool = False, + reduce_bucket_size: int = 1024 * 1024, # communication + communication_dtype: Optional[torch.dtype] = None, + overlap_communication: bool = False, + partition_grad: bool = False, # stage 2 flag + cpu_offload: bool = False, # cpu offload + dp_process_group: Optional[ProcessGroup] = None, + forced_dtype: Optional[torch.dtype] = None, + master_weights: bool = True, # master weights + overlap_allgather: bool = False, + fp8_communication: bool = False, + backward_context=None, + ): + super(LowLevelZeroOptimizer, self).__init__(optim=optimizer) + + self._dtype = self.optim.param_groups[0]["params"][0].dtype + self._logger = get_dist_logger() + self._verbose = verbose + + if (dp_process_group is not None) and (pg_to_param_list is not None): + raise ValueError("dp_process_group and pg_to_param_list should not be provided at the same time.") + + if pg_to_param_list is None: + unique_dp_group = dist.group.WORLD if dp_process_group is None else dp_process_group + pg_to_param_list = {unique_dp_group: []} + for group in self.optim.param_groups: + pg_to_param_list[unique_dp_group].extend(group["params"]) + + self.pg_to_param_list = pg_to_param_list + param_to_pg = {} + for grp, param_list in pg_to_param_list.items(): + for p in param_list: + assert isinstance(p, nn.Parameter), f"got {type(p)}" + param_to_pg[p] = grp + self.param_to_pg = param_to_pg + + # stage 2 + self._partition_grads = partition_grad + + self._cpu_offload = cpu_offload + + # grad accumulation + self.require_grad_sync = True + + # working and master params for mixed precision training + self._working_param_groups = dict() + self._master_param_groups_of_current_rank = dict() + + # communication params + self._overlap_communication = overlap_communication + self._overlap_allgather = overlap_allgather + self._reduce_bucket_size = reduce_bucket_size + self._communication_dtype = communication_dtype + self._fp8_communication = fp8_communication + self._backward_context = backward_context + + # gradient clipping + self._clip_grad_norm = clip_grad_norm + + # master weights copy + self._master_weights = master_weights + + if forced_dtype: + for group in self.optim.param_groups: + group_params = group["params"] + for param in group_params: + param.data = param.data.to(forced_dtype) + self._dtype = forced_dtype + + # check argument conflict + self._sanity_checks() + + # ParameterStore will manage the tensor buffers used for zero + # it will not manage the tensors used by mixed precision training + + # record the padding size of each param + self._padding_map = dict() + # padded working param is all-gather buffer and it shares the same memory with working param + self._working_param_to_padded_working_param = dict() + + # mapping working param and master param + self.master_to_working_param = dict() + self.working_to_master_param = dict() + + # NOTE need to gurantee the order of process group is the same accross all ranks + # process_group <---> xxx_store + # process_group <---> [param1 param2 ...] + # each process group have its own stores + # param belonging to one process_group will use corresponding store + self.pg_to_grad_store = { + pg: GradientStore(pg, partition_grad=self._partition_grads) for pg in self.pg_to_param_list + } + # param id to grad store, have to use id(param) as key since it is used in stores + self.pid_to_grad_store = {id(param): self.pg_to_grad_store[param_to_pg[param]] for param in param_to_pg} + self.pg_to_bucket_store = {pg: BucketStore(pg, reduce_bucket_size) for pg in self.pg_to_param_list} + # param id to bucket store, have to use id(param) as key since it is used in stores + self.pid_to_bucket_store = {id(param): self.pg_to_bucket_store[param_to_pg[param]] for param in param_to_pg} + + # iterate over the param group in the optimizer + # partition these param groups for data parallel training + # and add buffers to parameter store for future access + for group_id, param_group in enumerate(self.optim.param_groups): + group_params = list() + for param in param_group["params"]: + if param.requires_grad: + group_params.append(param) + + # add the working params to working_param_groups for bookkeeping + self._working_param_groups[group_id] = group_params + + master_param_current_rank = self._create_master_param_current_rank(group_params) + self._master_param_groups_of_current_rank[group_id] = master_param_current_rank + + # need to replace the params in the `params` field in the optimizer + # so that when the optimizer calls step(), it only updates the tensors + # managed by this data parallel rank + param_group["params"] = master_param_current_rank + + # reduction hook is only used if overlapping communication + # or stage 2 is used + # if it is stage 1 without overlapping, no hook will be attached + self.grad_handles = [] + if self._overlap_communication or self._partition_grads: + self._attach_reduction_hook() + + # initialize mixed precision mixin + self.mixed_precision_mixin: Optional[MixedPrecisionMixin] = None + if self._dtype is torch.float16: + self.mixed_precision_mixin = LowLevelZeroFP16MixedPrecisionMixin( + self.num_param_groups, + self.pg_to_grad_store, + initial_scale=initial_scale, + min_scale=min_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + max_scale=max_scale, + ) + elif self._dtype is torch.bfloat16: + self.mixed_precision_mixin = BF16MixedPrecisionMixin() + + def __del__(self): + for hook in self.grad_handles: + hook.remove() + + @property + def dtype(self): + return self._dtype + + @property + def num_param_groups(self): + return len(self._working_param_groups) + + def _sanity_checks(self): + assert get_accelerator().name in ["cuda", "npu"], "device is required" + for param_group in self.optim.param_groups: + group_params = param_group["params"] + for param in group_params: + if not hasattr(param, "skip_zero_check") or param.skip_zero_check is False: + assert ( + param.dtype == self._dtype + ), f"Parameters are expected to have the same dtype `{self._dtype}`, but got `{param.dtype}`" + + def _create_master_param_current_rank(self, param_list): + # split each param evenly by world size + params_current_rank = [] + device = "cpu" if self._cpu_offload else get_accelerator().get_current_device() + + for param in param_list: + padding_size = ( + self.pid_to_bucket_store[id(param)].world_size + - param.numel() % self.pid_to_bucket_store[id(param)].world_size + ) % self.pid_to_bucket_store[id(param)].world_size + self.record_param_padding_size(param, padding_size) + + with torch.no_grad(): + if padding_size > 0: + padding_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size]) + # # reset working params' ptr when no master weights + # if self._master_weights == False: + param.data = padding_param[: param.numel()].view(param.shape) + else: + padding_param = param.data.view(-1) + self._working_param_to_padded_working_param[param] = padding_param + + splited_params = padding_param.split( + padding_param.numel() // self.pid_to_bucket_store[id(param)].world_size + ) + splited_params = splited_params[self.pid_to_bucket_store[id(param)].local_rank] + + # use fp32 when master_weights is True + if self._master_weights is True: + splited_param_current_rank = splited_params.detach().clone().float().to(device) + else: + splited_param_current_rank = splited_params + + params_current_rank.append(splited_param_current_rank) + self.link_master_and_working_param(splited_param_current_rank, param) + + return params_current_rank + + ########################### + # Backward Reduction Hook # + ########################### + + def _attach_reduction_hook(self): + # we iterate over the working params + # on each param, we register a hook to its AccumulateGrad object + self_weakref = proxy(self) + + def _grad_handler(param, group_id): + # if run with no_sync context, would not sync grad when backward + if self_weakref.require_grad_sync: + self_weakref._add_to_bucket(param, group_id) + + for group_id in range(self.num_param_groups): + param_group = self._working_param_groups[group_id] + for param in param_group: + if param.requires_grad: + self.grad_handles.append( + param.register_post_accumulate_grad_hook(partial(_grad_handler, group_id=group_id)) + ) + + ####################### + # Reduction Functions # + ####################### + + def _run_reduction(self): + for bucket_store in self.pg_to_bucket_store.values(): + if bucket_store.num_elements_in_bucket() <= 0: + continue + + bucket_store.build_grad_in_bucket() + + flat_grads = bucket_store.get_flatten_grad() + flat_grads /= bucket_store.world_size + + # ready to add other tensors to bucket + bucket_store.reset_num_elements_in_bucket() + + if self._overlap_communication: + stream = bucket_store.comm_stream + # in case of the memory being reused in the default stream + flat_grads.record_stream(stream) + # waiting for ops in the default stream finishing + stream.wait_stream(get_accelerator().current_stream()) + else: + stream = get_accelerator().current_stream() + + with get_accelerator().stream(stream): + group_id = bucket_store.current_group_id + + grad_dtype = flat_grads.dtype + if self._communication_dtype is not None: + flat_grads = flat_grads.to(self._communication_dtype) + + if not self._partition_grads: + if self._fp8_communication: + all_reduce_fp8(flat_grads, group=bucket_store.torch_pg) + else: + dist.all_reduce(flat_grads, group=bucket_store.torch_pg) + if flat_grads.dtype != grad_dtype: + flat_grads = flat_grads.to(grad_dtype) + + flat_grads_per_rank = flat_grads.split(flat_grads.numel() // bucket_store.world_size) + grad_in_bucket = bucket_store.get_grad() + self._update_unpartitoned_grad(bucket_store, grad_in_bucket.values(), flat_grads_per_rank, group_id) + else: + flat_grads_list = list(flat_grads.split(len(flat_grads) // bucket_store.world_size)) + received_grad = torch.zeros_like(flat_grads_list[0]) + if self._fp8_communication: + reduce_scatter_fp8( + received_grad, + flat_grads_list, + group=bucket_store.torch_pg, + ) + else: + dist.reduce_scatter(received_grad, flat_grads_list, group=bucket_store.torch_pg) + + if received_grad.dtype != grad_dtype: + received_grad = received_grad.to(grad_dtype) + + grad_in_bucket_current_rank = bucket_store.get_grad()[bucket_store.local_rank] + self._update_partitoned_grad(bucket_store, grad_in_bucket_current_rank, received_grad, group_id, 1) + + bucket_store.reset() + + def _update_unpartitoned_grad( + self, bucket_store: BucketStore, origin_grad_list: List, flat_grad_list: List, group_id: int + ) -> None: + for rank, grad_list in enumerate(origin_grad_list): + sync_tensor(flat_grad_list[rank], grad_list) + for grad in grad_list: + param_id = bucket_store.get_param_id_of_grad(grad) + self._add_grad(grad, bucket_store.world_size, group_id, param_id, rank) + + def _update_partitoned_grad( + self, + bucket_store: BucketStore, + origin_grad_list: List, + flat_grad: torch.Tensor, + group_id: int, + partition_num: int, + ) -> None: + sync_tensor(flat_grad, origin_grad_list) + for grad in origin_grad_list: + param_id = bucket_store.get_param_id_of_grad(grad) + self._add_grad(grad, partition_num, group_id, param_id) + + def _add_grad( + self, + grad: torch.Tensor, + partition_num: int, + group_id: int, + param_id: int, + rank: int = 0, + ) -> None: + if ( + len(self.pid_to_grad_store[param_id].get_partitioned_gradients_by_param_id(group_id, param_id)) + < partition_num + ): + self.pid_to_grad_store[param_id].append_gradients_by_param_id(grad, group_id, param_id) + else: + self.pid_to_grad_store[param_id].add_gradients_by_param_id(grad, rank, group_id, param_id) + + def _add_to_bucket(self, param, group_id): + param_size = param.numel() + + # check if the bucket is full + # if full, will reduce the grads already in the bucket + # or got a grad of param from another group + # after reduction, the bucket will be empty + if ( + self.pid_to_bucket_store[id(param)].num_elements_in_bucket() + param_size > self._reduce_bucket_size + or group_id != self.pid_to_bucket_store[id(param)].current_group_id + ): + self._run_reduction() + + padding_size = self.get_param_padding_size(param) + self.pid_to_bucket_store[id(param)].add_param_grad(group_id, param, padding_size) + + ################################ + # torch.optim.Optimizer methods + ################################ + + def backward(self, loss, retain_graph=False): + assert not ( + self._partition_grads and not self.require_grad_sync + ), "ZeRO2(partition_grads) and no_sync are not compatible" + + if self.mixed_precision_mixin is not None: + loss = self.mixed_precision_mixin.pre_backward(loss) + + ctx = nullcontext() if self._backward_context is None else self._backward_context() + with ctx: + loss.backward(retain_graph=retain_graph) + + if not self.require_grad_sync: + return + + self._reduce_grad(self._partition_grads) + + # clear reduced grads + if self._overlap_communication: + get_accelerator().synchronize() + + def backward_by_grad(self, tensor, grad): + assert not ( + self._partition_grads and not self.require_grad_sync + ), "ZeRO2(partition_grads) and gradient accumulation(no_sync) are not compatible" + + if self.mixed_precision_mixin is not None: + grad = self.mixed_precision_mixin.pre_backward_by_grad(tensor, grad) + torch.autograd.backward(tensor, grad) + + if not self.require_grad_sync: + return + self._reduce_grad(self._partition_grads) + + # clear reduced grads + if self._overlap_communication: + get_accelerator().synchronize() + + def zero_bucket_stores(self): + for bucket_store in self.pg_to_bucket_store.values(): + bucket_store.reset_all() + + def zero_grad_stores(self): + for grad_store in self.pg_to_grad_store.values(): + grad_store.reset_all_gradients() + + def zero_grad(self, set_to_none=True): + """ + Set parameter gradients to zero. If set_to_none = True, gradient + will be set to None to save memory. + + :param set_to_none: Whether set the gradient to None. Default value is True. + :type set_to_none: bool + """ + if self.mixed_precision_mixin is not None: + self.mixed_precision_mixin.pre_zero_grad() + for _, param_group in self._working_param_groups.items(): + for param in param_group: + if set_to_none: + param.grad = None + else: + if param.grad is not None: + param.grad.detach() + param.grad.zero_() + self.zero_grad_stores() + self.zero_bucket_stores() + + #################### + # Update Parameter # + #################### + + def step(self, closure=None): + assert closure is None, "closure is not supported by step()" + if not self.require_grad_sync: + return + + if self.mixed_precision_mixin is not None and self.mixed_precision_mixin.should_skip_step(): + if self._verbose: + self._logger.info(f"Found overflow. Skip step") + self.zero_grad() + return + + # record all grads for unscale and clip + grad_partition_groups = [] + norm_groups = [] + + # sometimes not all params are 'really' working + # for instance, when layer drop, the dropped layer has no grad + # and should not be updated + real_working_params = dict() + real_master_params = dict() + + for group_id in range(self.num_param_groups): + master_params = self._master_param_groups_of_current_rank[group_id] + working_params = self._working_param_groups[group_id] + real_working_params[group_id] = [] + real_master_params[group_id] = [] + working_grads = [] + for working_param, master_param in zip(working_params, master_params): + # if a working param requires grad and has no grad + # it is not 'really' working, e.g. the droped layer + # else the splited grad should be attached to the splited param + grad_store = self.pid_to_grad_store[id(working_param)] + grads = grad_store.get_partitioned_gradients_by_param_id(group_id, id(working_param)) + grad_index = 0 if self._partition_grads else grad_store.local_rank + if len(grads) > 0: + real_working_params[group_id].append(working_param) + grad = grads[grad_index] + # no need to copy fp32 grad if master_weights is False + if self._master_weights: + grad = grad.to(master_param.dtype).to(master_param.device) + master_param.grad = grad + grad_partition_groups.append(grad) + real_master_params[group_id].append(master_param) + + # compute norm + norm_group = 0 + for grad_store in self.pg_to_grad_store.values(): + working_grads = grad_store.get_working_grads_by_group_id(group_id) + norm_group += self._compute_grad_norm(dp_pg=grad_store.torch_pg, gradients=working_grads) + + norm_groups.append(norm_group) + + # update the params in the optimizer + self.optim.param_groups[group_id]["params"] = real_master_params[group_id] + + # unscale and clip grads + global_norm = calculate_global_norm_from_list(norm_list=norm_groups) + self._unscale_and_clip_grads(grad_partition_groups, global_norm) + + # update the parameters + self.optim.step() + + # release the grad + grad_partition_groups = [] + for group_id in range(self.num_param_groups): + release_param_grad(self._master_param_groups_of_current_rank[group_id]) + + self.pg_to_tensor_bucket = { + pg: TensorBucket(self.pg_to_bucket_store[pg].reduce_bucket_size) for pg in self.pg_to_param_list + } + + # update working partition updated by the current rank + device = get_accelerator().get_current_device() + for group_id in range(self.num_param_groups): + master_working_param = self.optim.param_groups[group_id]["params"] + for idx, master_param in enumerate(master_working_param): + working_param = real_working_params[group_id][idx] + param_to_gather = master_param.to(device).to(self._dtype) + pg = self.param_to_pg[working_param] + padded_working_param = self._working_param_to_padded_working_param[working_param] + if self._overlap_allgather: + handle = dist.all_gather_into_tensor(padded_working_param, param_to_gather, pg, async_op=True) + set_all_gather_handle(working_param, handle) + else: + if param_to_gather.numel() > self.pg_to_tensor_bucket[pg].max_size: + if self._fp8_communication: + all_gather_fp8( + list(padded_working_param.chunk(dist.get_world_size(pg))), + param_to_gather, + pg, + fp8_format="e4m3", + ) + else: + dist.all_gather_into_tensor(padded_working_param, param_to_gather, pg) + continue + try: + self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param) + except RuntimeError: + self.pg_to_tensor_bucket[pg].all_gather(pg, fp8_communication=self._fp8_communication) + self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param) + self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id] + if not self._overlap_allgather: + for pg, tensor_bucket in self.pg_to_tensor_bucket.items(): + if not tensor_bucket.is_empty(): + tensor_bucket.all_gather(pg, fp8_communication=self._fp8_communication) + + def _compute_grad_norm(self, dp_pg: ProcessGroup, gradients: List[Tensor], norm_type: int = 2) -> float: + r""" + Compute and return the gradient norm for gradient clipping. + + Args: + gradients (List[Tensor]): The gradients to compute norm + norm_type (int, optional): type of the used p-norm, Can be ``'inf'`` for infinity norm. Defaults to 2. + + Returns: + float: The total norm of given gradients + """ + + if len(gradients) == 0: + return 0.0 + + norm_type = float(norm_type) + if norm_type == inf: + total_norm = max(grad.data.abs().max() for grad in gradients) + total_norm_cuda = torch.tensor( + [float(total_norm)], + device=get_accelerator().get_current_device(), + dtype=torch.float, + ) + dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=dp_pg) + total_norm = total_norm_cuda.item() + + else: + total_norm_exponentiated = 0.0 + for grad in gradients: + grad_norm_exponentiated = grad.data.float().norm(norm_type) ** norm_type + total_norm_exponentiated += grad_norm_exponentiated + + # Sum across all model parallel GPUs. + total_norm_exponentiated_cuda = torch.tensor( + [float(total_norm_exponentiated)], + device=get_accelerator().get_current_device(), + dtype=torch.float, + ) + torch.distributed.all_reduce( + total_norm_exponentiated_cuda, + op=torch.distributed.ReduceOp.SUM, + group=dp_pg, + ) + total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type) + + return total_norm + + ############################# + # Mixed Precision Utilities # + ############################# + + def _unscale_and_clip_grads(self, grad_groups_flat, total_norm): + # compute combined scale factor for this group + div_scale = 1.0 + if self.mixed_precision_mixin is not None: + div_scale = self.mixed_precision_mixin.get_grad_div_scale() + + if self._clip_grad_norm > 0.0: + # norm is in fact norm*scale + clip = ((total_norm / div_scale) + 1e-6) / self._clip_grad_norm + if clip > 1: + div_scale = clip * div_scale + + for grad in grad_groups_flat: + grad.data.mul_(1.0 / div_scale) + + ############################ + # Gradient Synchronization # + ############################ + + # this method is used to sync gradient manually + def _sync_grad(self): + for group_id in range(self.num_param_groups): + param_group = self._working_param_groups[group_id] + for param in param_group: + if is_moe_tensor(param) and param.requires_grad and param.grad is None: + # TODO better of of doing this + # assign zero grad to unrouted expert to avoid hang during grad reduction + param.grad = torch.zeros_like(param) + + if param.requires_grad and param.grad is not None: + self._add_to_bucket(param, group_id) + + self._run_reduction() + + def _reduce_grad(self, partition_grad): + # if not overlapping communication (no reduction hook is attached) when zero1 + # we need to manually reduce these gradients + if not partition_grad and not self._overlap_communication: + self._sync_grad() + else: + self._run_reduction() + + # this context comes from pytorch DDP + @contextmanager + def no_sync(self): + old_require_grad_sync = self.require_grad_sync + self.require_grad_sync = False + try: + yield + finally: + self.require_grad_sync = old_require_grad_sync + + ############## + # State Dict # + ############## + + def _pack_state(self, state: Dict) -> Dict: + # comes from pytorch optimizer.state_dict() + param_mappings = {} + start_index = 0 + + def pack_group(group): + nonlocal start_index + packed = {k: v for k, v in group.items() if k != "params"} + param_mappings.update( + {id(p): i for i, p in enumerate(group["params"], start_index) if id(p) not in param_mappings} + ) + packed["params"] = [param_mappings[id(p)] for p in group["params"]] + start_index += len(packed["params"]) + return packed + + param_groups = [pack_group(g) for g in self.optim.param_groups] + # Remap state to use order indices as keys + packed_state = {(param_mappings[id(k)] if isinstance(k, torch.Tensor) else k): v for k, v in state.items()} + + return {"state": packed_state, "param_groups": param_groups} + + def state_dict(self) -> Dict: + """Return a state_dict same with DDP + + Returns: + Dict: the pytorch form state_dict + """ + zero_state = dict() + device = get_accelerator().get_current_device() + for param, state in self.optim.state.items(): + zero_state[param] = copy.deepcopy(state) + for k, v in state.items(): + if isinstance(v, torch.Tensor) and k != "step": + working_param = self.master_to_working_param[id(param)] + pg = self.param_to_pg[working_param] + gather_tensor = [torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(pg.size())] + dist.all_gather(gather_tensor, v.to(device), group=pg) + param_state = ( + torch.stack(gather_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu() + ) + zero_state[param][k] = param_state + + states_dict = self._pack_state(zero_state) + + return states_dict + + def load_state_dict(self, state_dict: Dict): + """Load state dict, requires the state_dict be the pytorch form + + Args: + state_dict (dict): A pytorch form state_dict + """ + zero_state_dict = copy.deepcopy(state_dict) + idx2master = {} + cnt = 0 + for param_group in self.optim.param_groups: + for param in param_group["params"]: + idx2master[cnt] = param + cnt += 1 + for param_idx, state in zero_state_dict["state"].items(): + pg = self.param_to_pg[self.master_to_working_param[id(idx2master[param_idx])]] + for k, v in state.items(): + if isinstance(v, torch.Tensor) and k != "step": + padding_size = (pg.size() - v.numel() % pg.size()) % pg.size() + with torch.no_grad(): + v = v.flatten() + if padding_size > 0: + v = torch.nn.functional.pad(v, [0, padding_size]) + v_list = v.split(v.numel() // pg.size()) + zero_state_dict["state"][param_idx][k] = v_list[pg.rank()].detach().clone() + + self.optim.load_state_dict(zero_state_dict) + + def state_dict_shard(self, max_shard_size: int = 1024) -> Iterator[Tuple[Dict, int]]: + """Returns dictionaries containing a whole state of the module one by one. The max size of dictionary shard is specified by ``max_shard_size``. + Only include the 'state' in state_dict. + + Args: + max_shard_size (int, optional): max size of state shard (in MB). Defaults to 1024. + + Yields: + Iterator[OrderedDict]: A generator of state dict shard + """ + ret_block = dict() + ret_block_size = 0 + + device = get_accelerator().get_current_device() + local_states = self.optim.state_dict()["state"] + + idx2master = {} + cnt = 0 + for param_group in self.optim.param_groups: + for param in param_group["params"]: + idx2master[cnt] = param + cnt += 1 + for param_idx, states in local_states.items(): + current_block_size = 0 + current_block = copy.deepcopy(states) + + master_param = idx2master[param_idx] + working_param = self.master_to_working_param[id(master_param)] + pg = self.param_to_pg[working_param] + + for k, v in states.items(): + if isinstance(v, torch.Tensor) and k != "step": + state_tensor = [torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(pg.size())] + dist.all_gather(state_tensor, v.to(device), group=pg) + state_tensor = ( + torch.stack(state_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu() + ) + current_block_size += state_tensor.numel() + current_block[k] = state_tensor + + if ret_block_size + current_block_size > max_shard_size and len(ret_block) > 0: + yield ret_block, ret_block_size + ret_block = dict() + ret_block_size = 0 + + ret_block[param_idx] = current_block + ret_block_size += current_block_size + + yield ret_block, ret_block_size + + def update_master_params(self, model: nn.Module) -> None: + """Update master params from working params + + Args: + model (nn.Module): The model to update master params + """ + for p in model.parameters(): + p_id = id(p) + if p_id in self.working_to_master_param: + pg = self.param_to_pg[p] + master_param = self.working_to_master_param[p_id] + padding_size = self.get_param_padding_size(p) + working_param = p.data.view(-1) + if padding_size > 0: + working_param = torch.nn.functional.pad(working_param, [0, padding_size]) + master_param.copy_(working_param.chunk(pg.size())[pg.rank()]) + + def get_working_to_master_map(self) -> Dict[int, torch.Tensor]: + return self.working_to_master_param + + def get_master_to_working_map(self) -> Dict[int, torch.Tensor]: + return self.master_to_working_param + + def get_param_padding_map(self) -> Dict[int, torch.Tensor]: + return self._padding_map + + def record_param_padding_size(self, param: Tensor, padding_size: int): + """Record the padding size of a param + + Args: + param (Tensor): The parameter + padding_size (int): The padding size of the parameter + """ + + self._padding_map[id(param)] = padding_size + + def get_param_padding_size(self, param: Tensor) -> int: + """Return the padding size of the parameter + + Args: + param (Tensor): The parameter + + Returns: + int: the padding size of the parameter + """ + + return self._padding_map[id(param)] + + def link_master_and_working_param(self, master_param: Tensor, working_param: Tensor): + """Mapping master parameter and working parameter + + Args: + master_param (Tensor): The parameter copy in optimizer + working_param (Tensor): The parameter of the model + """ + + self.master_to_working_param[id(master_param)] = working_param + self.working_to_master_param[id(working_param)] = master_param + + def get_padding_map(self) -> Dict[int, Tensor]: + """Return the padding map + + Returns: + Dict[int, Tensor]: The padding map + """ + + return self._padding_map + + def get_param_grad(self, working_param: nn.Parameter) -> Tensor: + grad_store = self.pid_to_grad_store[id(working_param)] + grad = grad_store.get_working_grad_by_param_id(id(working_param)) + if grad is None: + return None + grad_flat = torch.empty((grad_store.world_size, *grad.shape), dtype=grad.dtype, device=grad.device) + dist.all_gather_into_tensor(grad_flat, grad, group=grad_store.torch_pg) + return grad_flat.view(-1)[: working_param.numel()].view_as(working_param) + + def get_working_grads_by_group_id(self, group_id: int) -> List[Tensor]: + working_grads = [] + for grad_store in self.pg_to_grad_store.values(): + working_grads.extend(grad_store.get_working_grads_by_group_id(group_id)) + return working_grads + + def get_param_id_for_grad(self, grad: Tensor) -> int: + param_id = None + for grad_store in self.pg_to_grad_store.values(): + id_maybe_none = grad_store.get_param_id_for_grad(grad) + if id_maybe_none is not None: + if param_id is not None: + raise ValueError("The grad mapping is not unique") + param_id = id_maybe_none + return param_id + + def get_working_grad_by_param_id(self, param_id: int) -> Tensor: + grad_store = self.pid_to_grad_store[param_id] + return grad_store.get_working_grad_by_param_id(param_id) + + def get_partitioned_gradients_by_param_id(self, group_id: int, param_id: int) -> List: + grad_store = self.pid_to_grad_store[param_id] + return grad_store.get_partitioned_gradients_by_param_id(group_id, param_id) + + def _force_wait_all_gather(self): + for param in self._working_param_to_padded_working_param.keys(): + wait_all_gather_handle(param) diff --git a/toolbox/ColossalAI/v0.4.4/patches/examples/language/mixtral/benchmark.py b/toolbox/ColossalAI/v0.4.4/patches/examples/language/mixtral/benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..9fb85c4c4ec0e2e8124bff2c8bd6069ed5bcbb5a --- /dev/null +++ b/toolbox/ColossalAI/v0.4.4/patches/examples/language/mixtral/benchmark.py @@ -0,0 +1,264 @@ +#!/usr/bin/env python3 +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# modified from llama benchmark +import argparse +import resource +import time +import warnings +from contextlib import nullcontext + +import torch +import torch.distributed as dist +from data_utils import RandomDataset +from model_utils import format_numel_str, get_model_numel +from performance_evaluator import PerformanceEvaluator, get_profile_context +from tqdm import tqdm +from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM + +import colossalai +from colossalai.accelerator import get_accelerator +from colossalai.booster import Booster +from colossalai.booster.plugin import MoeHybridParallelPlugin +from colossalai.cluster import DistCoordinator +from colossalai.lazy import LazyInitContext +from colossalai.nn.optimizer import HybridAdam +from colossalai.shardformer import PipelineGradientCheckpointConfig + +warnings.filterwarnings("ignore") +# ============================== +# Constants +# ============================== + +# We have lots of llamas for your choice! +MODEL_CONFIGS = { + "100m": MixtralConfig( + max_position_embeddings=4096, + num_hidden_layers=4, + num_attention_heads=32, + intermediate_size=768, + hidden_size=768, + attn_implementation="flash_attention_2", + ), + "7b": MixtralConfig( + max_position_embeddings=4096, + num_hidden_layers=1, + attn_implementation="flash_attention_2", + ), + "14b": MixtralConfig( + max_position_embeddings=4096, + num_hidden_layers=10, + attn_implementation="flash_attention_2", + ), +} + + +def main(): + # ============================== + # Parse Arguments + # ============================== + parser = argparse.ArgumentParser() + parser.add_argument("-c", "--config", type=str, default="100m", help="Model configuration") + parser.add_argument( + "-p", + "--plugin", + choices=["3d"], + default="3d", + help="Choose which plugin to use", + ) + parser.add_argument("-b", "--batch_size", type=int, default=1, help="Batch size") + parser.add_argument("-s", "--num_steps", type=int, default=5, help="Number of steps to run") + parser.add_argument("-i", "--ignore_steps", type=int, default=2, help="Number of steps to ignore") + parser.add_argument("-g", "--grad_checkpoint", action="store_true", help="Use gradient checkpointing") + parser.add_argument("-l", "--max_length", type=int, default=4096, help="Max sequence length") + parser.add_argument( + "-w", "--warmup_ratio", type=float, default=0.8, help="warm up ratio of non-model data. Only for gemini-auto" + ) + parser.add_argument("-m", "--memory_limit", type=int, help="Gemini memory limit in mb") + parser.add_argument("-x", "--xformers", action="store_true", help="Use xformers") + parser.add_argument("--shard_param_frac", type=float, default=1.0, help="Shard param fraction. Only for gemini") + parser.add_argument("--offload_optim_frac", type=float, default=0.0, help="Offload optim fraction. Only for gemini") + parser.add_argument("--offload_param_frac", type=float, default=0.0, help="Offload param fraction. Only for gemini") + parser.add_argument("--tp", type=int, default=1, help="Tensor parallel size") + parser.add_argument("--ep", type=int, default=1, help="Expert parallel size") + parser.add_argument("--sp", type=int, default=1, help="Sequence parallel size") + parser.add_argument("--extra_dp", type=int, default=1, help="Extra data parallel size, used for Gemini") + parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel size") + parser.add_argument("--mbs", type=int, default=1, help="Micro batch size of pipeline parallel") + parser.add_argument("--zero", type=int, default=1, help="Zero Stage when hybrid plugin is enabled") + parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False) + + parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved"]) + parser.add_argument("--n_chunks", default=1, help="number of model chunks", type=eval) + parser.add_argument("--profile", action="store_true", help="Profile the code") + parser.add_argument( + "--nsys", + action="store_true", + help="Use nsys for profiling. \ + You should put something like this before colossalai launch: \ + nsys profile -w true -t cuda,cudnn,cublas -s cpu --capture-range=cudaProfilerApi --capture-range-end=stop --cudabacktrace=true -x true --python-backtrace=cuda -o prof_out", + ) + parser.add_argument("--disable-async-reduce", action="store_true", help="Disable the asynchronous reduce operation") + parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number") + parser.add_argument("--no_cache", action="store_true") + parser.add_argument("--use_fp8_comm", action="store_true", default=False, help="for using fp8 during communication") + parser.add_argument("--use_fp8", action="store_true", default=False, help="for using fp8 linear") + parser.add_argument("--overlap_allgather", action="store_true") + parser.add_argument( + "--sp_mode", + default="all_to_all", + choices=["all_to_all"], + help="Sequence parallelism mode", + ) + parser.add_argument("--debug", action="store_true", help="Enable debug mode") + args = parser.parse_args() + + colossalai.launch_from_torch() + coordinator = DistCoordinator() + + # ckpt config for LLaMA3-70B on 64 H100 GPUs + hybrid_kwargs = ( + { + "gradient_checkpoint_config": PipelineGradientCheckpointConfig( + num_ckpt_layers_per_stage=[19, 19, 19, 13], + ), + "num_layers_per_stage": [19, 20, 20, 21], + "pp_style": "interleaved", + } + if args.custom_ckpt + else {} + ) + + # ============================== + # Initialize Booster + # ============================== + if args.plugin == "3d": + plugin = MoeHybridParallelPlugin( + ep_size=args.ep, + tp_size=args.tp, + pp_size=args.pp, + pp_style=args.pp_style, + num_model_chunks=args.n_chunks, + zero_stage=args.zero, + sp_size=args.sp, + sequence_parallelism_mode=args.sp_mode, + enable_sequence_parallelism=args.sp > 1, + enable_fused_normalization=torch.cuda.is_available(), + enable_flash_attention=args.xformers, + microbatch_size=args.mbs, + precision="bf16", + enable_metadata_cache=not args.no_cache, + overlap_allgather=args.overlap_allgather, + use_fp8=args.use_fp8, + fp8_communication=args.use_fp8_comm, + **hybrid_kwargs, + ) + else: + raise ValueError(f"Unknown plugin {args.plugin}") + + booster = Booster(plugin=plugin) + + # ============================== + # Initialize Dataset and Dataloader + # ============================== + dp_size = getattr(plugin, "dp_size", coordinator.world_size) + + if args.config in MODEL_CONFIGS: + config = MODEL_CONFIGS[args.config] + else: + config = MixtralConfig.from_pretrained(args.config, trust_remote_code=True) + config.max_position_embeddings = args.max_length + torch.cuda.manual_seed(42) + + dataset = RandomDataset( + num_samples=args.batch_size * args.num_steps * dp_size, max_length=args.max_length, vocab_size=config.vocab_size + ) + dataloader = plugin.prepare_dataloader(dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, seed=42) + + # ============================== + # Initialize Model and Optimizer + # ============================== + init_ctx = ( + LazyInitContext(default_device=get_accelerator().get_current_device()) + if isinstance(plugin, MoeHybridParallelPlugin) + else nullcontext() + ) + + with init_ctx: + model = MixtralForCausalLM(config=config).to(torch.bfloat16) + + if args.grad_checkpoint: + model.gradient_checkpointing_enable() + + model_numel = get_model_numel(model) + coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}") + performance_evaluator = PerformanceEvaluator( + model_numel, + model.config.num_hidden_layers, + model.config.hidden_size, + model.config.vocab_size, + args.max_length, + args.grad_checkpoint, + args.ignore_steps, + dp_world_size=dp_size, + ) + + optimizer = HybridAdam(model.parameters()) + torch.set_default_dtype(torch.bfloat16) + model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader) + + torch.set_default_dtype(torch.float) + coordinator.print_on_master( + f"Booster init max CUDA memory: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB" + ) + coordinator.print_on_master( + f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB" + ) + + with get_profile_context( + args.profile, + args.ignore_steps, + 1, # avoid creating massive log files + save_dir=f"profile/{time.strftime('%H:%M', time.localtime())}-{args.plugin}-mixtral-{args.config}", + nsys=args.nsys, + ) as prof: + if isinstance(plugin, MoeHybridParallelPlugin) and args.pp > 1: + data_iter = iter(dataloader) + for step in tqdm(range(len(dataloader)), desc="Step", disable=not coordinator.is_master()): + performance_evaluator.on_step_start(step) + outputs = booster.execute_pipeline( + data_iter, + model, + criterion=lambda outputs, inputs: outputs[0], + optimizer=optimizer, + return_loss=True, + ) + loss = outputs["loss"] + if dist.get_rank() == dist.get_world_size() - 1: + print(f"Step {step} loss: {loss}") + optimizer.step() + optimizer.zero_grad() + + performance_evaluator.on_step_end(input_ids=torch.empty(args.batch_size, args.max_length)) + prof.step() + else: + for step, batch in enumerate(tqdm(dataloader, desc="Step", disable=not coordinator.is_master())): + performance_evaluator.on_step_start(step) + outputs = model(**batch) + loss = outputs[0] + del outputs # free memory + + if dist.get_rank() == dist.get_world_size() - 1: + print(f"Step {step} loss: {loss}") + booster.backward(loss, optimizer) + optimizer.step() + optimizer.zero_grad() + + performance_evaluator.on_step_end(**batch) + prof.step() + performance_evaluator.on_fit_end() + coordinator.print_on_master(f"Max CUDA memory usage: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB") + + +if __name__ == "__main__": + main() diff --git a/toolbox/ColossalAI/v0.4.4/patches/examples/language/mixtral/run_benchmark.sh b/toolbox/ColossalAI/v0.4.4/patches/examples/language/mixtral/run_benchmark.sh new file mode 100644 index 0000000000000000000000000000000000000000..ec7301fdb00d2b8fe8b4f5f45dc434174b6b8c70 --- /dev/null +++ b/toolbox/ColossalAI/v0.4.4/patches/examples/language/mixtral/run_benchmark.sh @@ -0,0 +1,57 @@ +#!/bin/bash +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +NUM_GPU=1 + +CONFIG="7b" +BATCH_SIZE=8 +MICRO_BATCH_SIZE=1 +MAX_LENGTH=4096 + +TP=1 +SP=1 +EP=1 +PP=1 + +pp_style="1f1b" + +#################### Single-Node ################# +nsys profile -o nsys_mixtral_layer -t cuda,cudnn,cublas \ + --capture-range cudaProfilerApi --capture-range-end stop --force-overwrite true \ + torchrun --standalone --nproc_per_node $NUM_GPU benchmark.py \ + -c $CONFIG \ + -b $BATCH_SIZE \ + -l $MAX_LENGTH \ + --mbs $MICRO_BATCH_SIZE \ + --tp $TP \ + --sp $SP \ + --ep $EP \ + --pp $PP \ + --pp_style $pp_style \ + --profile \ + --nsys + +# torchrun --standalone --nproc_per_node $NUM_GPU benchmark.py \ +# -c $CONFIG \ +# -b $BATCH_SIZE \ +# -l $MAX_LENGTH \ +# --mbs $MICRO_BATCH_SIZE \ +# --tp $TP \ +# --sp $SP \ +# --ep $EP \ +# --pp $PP \ +# --pp_style $pp_style + # > benchmark_mixtral_tp${TP}sp${SP}pp${PP}ep${EP}.log 2>&1 diff --git a/toolbox/ColossalAI/v0.4.4/patches/examples/language/performance_evaluator.py b/toolbox/ColossalAI/v0.4.4/patches/examples/language/performance_evaluator.py new file mode 100644 index 0000000000000000000000000000000000000000..ed459d0bfacfd8bbf5aa26323a7b41f11213c0f8 --- /dev/null +++ b/toolbox/ColossalAI/v0.4.4/patches/examples/language/performance_evaluator.py @@ -0,0 +1,181 @@ +#!/usr/bin/env python3 +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +from time import time +from typing import Optional + +import torch +import torch.distributed as dist +from torch import Tensor +from torch.profiler import ProfilerActivity, profile, schedule, tensorboard_trace_handler + +from colossalai.cluster import DistCoordinator + + +def divide(x: float, y: float) -> float: + if y == 0: + return float("inf") + elif y == float("inf"): + return float("nan") + return x / y + + +@torch.no_grad() +def all_reduce_mean(x: float, world_size: int) -> float: + if world_size == 1: + return x + + # Use CPU tensor to avoid OOM/weird NCCl error + gloo_group = dist.new_group(backend="gloo") + tensor = torch.tensor([x], device="cpu") + dist.all_reduce(tensor, group=gloo_group) + tensor = tensor / world_size + return tensor.item() + + +def get_profile_context(enable_flag, warmup_steps, active_steps, save_dir, nsys=False): + class DummyProfiler: + def __init__(self): + self.step_number = 0 + + def step(self): + self.step_number += 1 + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + pass + + class NsysProfiler: + def __init__(self, warmup_steps, active_steps): + self.step_number = 0 + self.warmup_steps = warmup_steps + self.active_steps = active_steps + + def step(self): + if self.step_number == self.warmup_steps: + torch.cuda.cudart().cudaProfilerStart() + elif self.step_number == self.warmup_steps + self.active_steps: + torch.cuda.cudart().cudaProfilerStop() + self.step_number += 1 + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + pass + + if enable_flag: + if nsys: + return NsysProfiler(warmup_steps, active_steps) + + return profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + schedule=schedule(wait=0, warmup=warmup_steps, active=active_steps), + on_trace_ready=tensorboard_trace_handler(save_dir), + record_shapes=True, + profile_memory=True, + with_stack=True, + ) + else: + return DummyProfiler() + + +class Timer: + def __init__(self) -> None: + self.start_time: Optional[float] = None + self.duration: float = 0.0 + + def start(self) -> None: + self.start_time = time() + + def end(self) -> None: + assert self.start_time is not None + self.duration += time() - self.start_time + self.start_time = None + + def reset(self) -> None: + self.duration = 0.0 + + +class PerformanceEvaluator: + """ + Callback for valuate the performance of the model. + Args: + actor_num_params: The number of parameters of the actor model. + critic_num_params: The number of parameters of the critic model. + initial_model_num_params: The number of parameters of the initial model. + reward_model_num_params: The number of parameters of the reward model. + enable_grad_checkpoint: Whether to enable gradient checkpointing. + ignore_episodes: The number of episodes to ignore when calculating the performance. + """ + + def __init__( + self, + model_numel: int, + num_layers: int, + hidden_size: int, + vocab_size: int, + seq_len: int, + enable_grad_checkpoint: bool = False, + ignore_steps: int = 0, + dp_world_size: Optional[int] = None, + ) -> None: + self.model_numel = model_numel + self.enable_grad_checkpoint = enable_grad_checkpoint + self.ignore_steps = ignore_steps + self.num_layers = num_layers + self.hidden_size = hidden_size + self.vocab_size = vocab_size + self.seq_len = seq_len + + self.coordinator = DistCoordinator() + self.dp_world_size = dp_world_size or self.coordinator.world_size + self.disable: bool = False + self.timer = Timer() + self.num_samples: int = 0 + self.flop_megatron = 0 + self.flop: int = 0 + + def on_step_start(self, step: int) -> None: + self.disable = self.ignore_steps > 0 and step < self.ignore_steps + if self.disable: + return + # get_accelerator().synchronize() + self.timer.start() + + def on_step_end(self, input_ids: Tensor, **kwargs) -> None: + if self.disable: + return + # get_accelerator().synchronize() + self.timer.end() + + batch_size, seq_len = input_ids.shape + + self.num_samples += batch_size + checkpoint_activations_factor = 3 + int(self.enable_grad_checkpoint) + self.flop_megatron += ( + 24 * checkpoint_activations_factor * batch_size * seq_len * self.num_layers * (self.hidden_size**2) + ) * ( + 1.0 + (seq_len / (6.0 * self.hidden_size)) + (self.vocab_size / (16.0 * self.num_layers * self.hidden_size)) + ) + self.flop += batch_size * seq_len * self.model_numel * 2 * (3 + int(self.enable_grad_checkpoint)) + + def on_fit_end(self) -> None: + avg_duration = all_reduce_mean(self.timer.duration, self.coordinator.world_size) + avg_throughput = self.num_samples * self.dp_world_size / (avg_duration + 1e-12) + + # BI-V150 per GPU device has two cards + avg_tokens_per_gpu = avg_throughput * self.seq_len / self.coordinator.world_size / 2 + + mp_world_size = self.coordinator.world_size // self.dp_world_size + avg_tflops_per_gpu_megatron = self.flop_megatron / 1e12 / (avg_duration + 1e-12) / mp_world_size + avg_tflops_per_gpu = self.flop / 1e12 / (avg_duration + 1e-12) / mp_world_size + self.coordinator.print_on_master( + f"num_samples: {self.num_samples}, dp_world_size: {self.dp_world_size}, flop_megatron: {self.flop_megatron}, flop: {self.flop}, avg_duration: {avg_duration:.2f}, " + f"avg_throughput: {avg_throughput:.2f}" + ) + self.coordinator.print_on_master( + f"Tokens per GPU per Second:{avg_tokens_per_gpu:.2f}, Throughput: {avg_throughput:.2f} samples/sec, TFLOPS per GPU by Megatron: {avg_tflops_per_gpu_megatron:.2f}, TFLOPS per GPU: {avg_tflops_per_gpu:.2f}" + ) diff --git a/toolbox/ColossalAI/v0.4.4/patches/extensions/csrc/kernel/cuda/layer_norm_kernel.cu b/toolbox/ColossalAI/v0.4.4/patches/extensions/csrc/kernel/cuda/layer_norm_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..85695ce74f37cc4f6c2903f41a9860afc7dc271c --- /dev/null +++ b/toolbox/ColossalAI/v0.4.4/patches/extensions/csrc/kernel/cuda/layer_norm_kernel.cu @@ -0,0 +1,700 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. */ +/* All Rights Reserved. */ +/*This code from NVIDIA apex: + * https://github.com/NVIDIA/apex + * with minor changes. */ + +#include +#include + +#include "ATen/ATen.h" +#include "ATen/AccumulateType.h" +#include "ATen/cuda/CUDAContext.h" +#include "ATen/cuda/DeviceUtils.cuh" +#include "common/micros.h" + +template +__device__ void cuWelfordOnlineSum(const U curr, U& mu, U& sigma2, U& count) { + count = count + U(1); + U delta = curr - mu; + U lmean = mu + delta / count; + mu = lmean; + U delta2 = curr - lmean; + sigma2 = sigma2 + delta * delta2; +} + +template +__device__ void cuChanOnlineSum(const U muB, const U sigma2B, const U countB, + U& mu, U& sigma2, U& count) { + U delta = muB - mu; + U nA = count; + U nB = countB; + count = count + countB; + U nX = count; + if (nX > U(0)) { + nA = nA / nX; + nB = nB / nX; + mu = nA * mu + nB * muB; + sigma2 = sigma2 + sigma2B + delta * delta * nA * nB * nX; + } else { + mu = U(0); + sigma2 = U(0); + } +} + +template +__device__ void cuWelfordMuSigma2(const T* __restrict__ vals, const int n1, + const int n2, const int i1, U& mu, U& sigma2, + U* buf) { + // Assumptions: + // 1) blockDim.x == warpSize + // 2) Tensor is contiguous + // 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available. + // + // compute variance and mean over n2 + U count = U(0); + mu = U(0); + sigma2 = U(0); + if (i1 < n1) { + // one warp normalizes one n1 index, + // synchronization is implicit + // initialize with standard Welford algorithm + const int numx = blockDim.x * blockDim.y; + const int thrx = threadIdx.x + threadIdx.y * blockDim.x; + const T* lvals = vals + i1 * n2; + int l = 4 * thrx; + for (; l + 3 < n2; l += 4 * numx) { + for (int k = 0; k < 4; ++k) { + U curr = static_cast(lvals[l + k]); + cuWelfordOnlineSum(curr, mu, sigma2, count); + } + } + for (; l < n2; ++l) { + U curr = static_cast(lvals[l]); + cuWelfordOnlineSum(curr, mu, sigma2, count); + } + // intra-warp reductions + for (int l = 0; l <= 4; ++l) { + int srcLaneB = (threadIdx.x + (1 << l)) & 31; + U muB = WARP_SHFL(mu, srcLaneB); + U countB = WARP_SHFL(count, srcLaneB); + U sigma2B = WARP_SHFL(sigma2, srcLaneB); + cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count); + } + // threadIdx.x == 0 has correct values for each warp + // inter-warp reductions + if (blockDim.y > 1) { + U* ubuf = (U*)buf; + U* ibuf = (U*)(ubuf + blockDim.y); + for (int offset = blockDim.y / 2; offset > 0; offset /= 2) { + // upper half of warps write to shared + if (threadIdx.x == 0 && threadIdx.y >= offset && + threadIdx.y < 2 * offset) { + const int wrt_y = threadIdx.y - offset; + ubuf[2 * wrt_y] = mu; + ubuf[2 * wrt_y + 1] = sigma2; + ibuf[wrt_y] = count; + } + __syncthreads(); + // lower half merges + if (threadIdx.x == 0 && threadIdx.y < offset) { + U muB = ubuf[2 * threadIdx.y]; + U sigma2B = ubuf[2 * threadIdx.y + 1]; + U countB = ibuf[threadIdx.y]; + cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count); + } + __syncthreads(); + } + // threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values + if (threadIdx.x == 0 && threadIdx.y == 0) { + ubuf[0] = mu; + ubuf[1] = sigma2; + } + __syncthreads(); + mu = ubuf[0]; + sigma2 = ubuf[1] / U(n2); + // don't care about final value of count, we know count == n2 + } else { + mu = WARP_SHFL(mu, 0); + sigma2 = WARP_SHFL(sigma2 / U(n2), 0); + } + } +} + +template <> +__device__ void cuWelfordMuSigma2(const at::Half* __restrict__ vals, + const int n1, const int n2, const int i1, + float& mu, float& sigma2, float* buf) { + // Assumptions: + // 1) blockDim.x == warpSize + // 2) Tensor is contiguous + // 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available. + // + // compute variance and mean over n2 + float count = 0.0f; + mu = float(0); + sigma2 = float(0); + if (i1 < n1) { + // one warp normalizes one n1 index, + // synchronization is implicit + // initialize with standard Welford algorithm + const int numx = blockDim.x * blockDim.y; + const int thrx = threadIdx.x + threadIdx.y * blockDim.x; + const at::Half* lvals = vals + i1 * n2; + int l = 8 * thrx; + if ((((size_t)lvals) & 3) != 0) { + // 16 bit alignment + // first thread consumes first point + if (thrx == 0) { + float curr = static_cast(lvals[0]); + cuWelfordOnlineSum(curr, mu, sigma2, count); + } + ++l; + } + // at this point, lvals[l] are 32 bit aligned for all threads. + for (; l + 7 < n2; l += 8 * numx) { + for (int k = 0; k < 8; k += 2) { + float2 curr = __half22float2(*((__half2*)(lvals + l + k))); + cuWelfordOnlineSum(curr.x, mu, sigma2, count); + cuWelfordOnlineSum(curr.y, mu, sigma2, count); + } + } + for (; l < n2; ++l) { + float curr = static_cast(lvals[l]); + cuWelfordOnlineSum(curr, mu, sigma2, count); + } + // intra-warp reductions + for (int l = 0; l <= 4; ++l) { + int srcLaneB = (threadIdx.x + (1 << l)) & 31; + float muB = WARP_SHFL(mu, srcLaneB); + float countB = WARP_SHFL(count, srcLaneB); + float sigma2B = WARP_SHFL(sigma2, srcLaneB); + cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count); + } + // threadIdx.x == 0 has correct values for each warp + // inter-warp reductions + if (blockDim.y > 1) { + float* ubuf = (float*)buf; + float* ibuf = (float*)(ubuf + blockDim.y); + for (int offset = blockDim.y / 2; offset > 0; offset /= 2) { + // upper half of warps write to shared + if (threadIdx.x == 0 && threadIdx.y >= offset && + threadIdx.y < 2 * offset) { + const int wrt_y = threadIdx.y - offset; + ubuf[2 * wrt_y] = mu; + ubuf[2 * wrt_y + 1] = sigma2; + ibuf[wrt_y] = count; + } + __syncthreads(); + // lower half merges + if (threadIdx.x == 0 && threadIdx.y < offset) { + float muB = ubuf[2 * threadIdx.y]; + float sigma2B = ubuf[2 * threadIdx.y + 1]; + float countB = ibuf[threadIdx.y]; + cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count); + } + __syncthreads(); + } + // threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values + if (threadIdx.x == 0 && threadIdx.y == 0) { + ubuf[0] = mu; + ubuf[1] = sigma2; + } + __syncthreads(); + mu = ubuf[0]; + sigma2 = ubuf[1] / float(n2); + // don't care about final value of count, we know count == n2 + } else { + mu = WARP_SHFL(mu, 0); + sigma2 = WARP_SHFL(sigma2 / float(n2), 0); + } + } +} + +template +U rsqrt(U v) { + return U(1) / sqrt(v); +} +template <> +float rsqrt(float v) { + return rsqrtf(v); +} +template <> +double rsqrt(double v) { + return rsqrt(v); +} + +namespace { +// This is the un-specialized struct. Note that we prevent instantiation of +// this struct by putting an undefined symbol in the function body so it won't +// compile. +// template +// struct SharedMemory +// { +// // Ensure that we won't compile any un-specialized types +// __device__ T *getPointer() +// { +// extern __device__ void error(void); +// error(); +// return NULL; +// } +// }; +// https://github.com/NVIDIA/apex/issues/246 +template +struct SharedMemory; + +template <> +struct SharedMemory { + __device__ float* getPointer() { + extern __shared__ float s_float[]; + return s_float; + } +}; + +} // namespace + +template +__global__ void cuApplyLayerNorm(V* __restrict__ output_vals, + U* __restrict__ mean, U* __restrict__ invvar, + const T* __restrict__ vals, const int n1, + const int n2, const U epsilon, + const V* __restrict__ gamma, + const V* __restrict__ beta) { + // Assumptions: + // 1) blockDim.x == warpSize + // 2) Tensors are contiguous + // + for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) { + SharedMemory shared; + U* buf = shared.getPointer(); + U mu, sigma2; + cuWelfordMuSigma2(vals, n1, n2, i1, mu, sigma2, buf); + const T* lvals = vals + i1 * n2; + V* ovals = output_vals + i1 * n2; + U c_invvar = rsqrt(sigma2 + epsilon); + const int numx = blockDim.x * blockDim.y; + const int thrx = threadIdx.x + threadIdx.y * blockDim.x; + if (gamma != NULL && beta != NULL) { + for (int i = thrx; i < n2; i += numx) { + U curr = static_cast(lvals[i]); + ovals[i] = gamma[i] * static_cast(c_invvar * (curr - mu)) + beta[i]; + } + } else { + for (int i = thrx; i < n2; i += numx) { + U curr = static_cast(lvals[i]); + ovals[i] = static_cast(c_invvar * (curr - mu)); + } + } + if (threadIdx.x == 0 && threadIdx.y == 0) { + mean[i1] = mu; + invvar[i1] = c_invvar; + } + } +} + +template +__device__ void cuLoadWriteStridedInputs( + const int i1_block, const int thr_load_row_off, const int thr_load_col_off, + const int i2_off, const int row_stride, U* warp_buf1, U* warp_buf2, + const T* input, const V* dout, const int i1_end, const int n2, + const U* __restrict__ mean, const U* __restrict__ invvar) { + int i1 = i1_block + thr_load_row_off; + if (i1 < i1_end) { + U curr_mean = mean[i1]; + U curr_invvar = invvar[i1]; + for (int k = 0; k < blockDim.y; ++k) { + int i2 = i2_off + k; + int load_idx = i1 * n2 + i2; + int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k; + if (i2 < n2) { + U curr_input = static_cast(input[load_idx]); + U curr_dout = static_cast(dout[load_idx]); + warp_buf1[write_idx] = curr_dout; + warp_buf2[write_idx] = + curr_dout * (curr_input - curr_mean) * curr_invvar; + } else { + warp_buf1[write_idx] = U(0); + warp_buf2[write_idx] = U(0); + } + } + } else { + for (int k = 0; k < blockDim.y; ++k) { + int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k; + warp_buf1[write_idx] = U(0); + warp_buf2[write_idx] = U(0); + } + } +} + +template +__device__ void cuLoadAddStridedInputs( + const int i1_block, const int thr_load_row_off, const int thr_load_col_off, + const int i2_off, const int row_stride, U* warp_buf1, U* warp_buf2, + const T* input, const V* dout, const int i1_end, const int n2, + const U* __restrict__ mean, const U* __restrict__ invvar) { + int i1 = i1_block + thr_load_row_off; + if (i1 < i1_end) { + U curr_mean = mean[i1]; + U curr_invvar = invvar[i1]; + for (int k = 0; k < blockDim.y; ++k) { + int i2 = i2_off + k; + int load_idx = i1 * n2 + i2; + int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k; + if (i2 < n2) { + U curr_input = static_cast(input[load_idx]); + U curr_dout = static_cast(dout[load_idx]); + warp_buf1[write_idx] += curr_dout; + warp_buf2[write_idx] += + curr_dout * (curr_input - curr_mean) * curr_invvar; + } + } + } +} + +template +__global__ void cuComputePartGradGammaBeta( + const V* __restrict__ dout, const T* __restrict__ input, const int n1, + const int n2, const U* __restrict__ mean, const U* __restrict__ invvar, + U epsilon, U* part_grad_gamma, U* part_grad_beta) { + const int numsegs_n1 = + (n1 + blockDim.y * blockDim.y - 1) / (blockDim.y * blockDim.y); + const int segs_per_block = (numsegs_n1 + gridDim.y - 1) / gridDim.y; + const int i1_beg = blockIdx.y * segs_per_block * blockDim.y * blockDim.y; + const int i1_beg_plus_one = + (blockIdx.y + 1) * segs_per_block * blockDim.y * blockDim.y; + const int i1_end = i1_beg_plus_one < n1 ? i1_beg_plus_one : n1; + const int row_stride = blockDim.x + 1; + const int thr_load_col_off = (threadIdx.x * blockDim.y) & (blockDim.x - 1); + const int thr_load_row_off = + (threadIdx.x * blockDim.y) / blockDim.x + threadIdx.y * blockDim.y; + const int i2_off = blockIdx.x * blockDim.x + thr_load_col_off; + SharedMemory shared; + U* buf = shared.getPointer(); // buf has at least blockDim.x * blockDim.y * + // blockDim.y + (blockDim.y - + // 1)*(blockDim.x/blockDim.y) elements + U* warp_buf1 = (U*)buf; + U* warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride; + // compute partial sums from strided inputs + // do this to increase number of loads in flight + cuLoadWriteStridedInputs(i1_beg, thr_load_row_off, thr_load_col_off, i2_off, + row_stride, warp_buf1, warp_buf2, input, dout, + i1_end, n2, mean, invvar); + for (int i1_block = i1_beg + blockDim.y * blockDim.y; i1_block < i1_end; + i1_block += blockDim.y * blockDim.y) { + cuLoadAddStridedInputs(i1_block, thr_load_row_off, thr_load_col_off, i2_off, + row_stride, warp_buf1, warp_buf2, input, dout, + i1_end, n2, mean, invvar); + } + __syncthreads(); + // inter-warp reductions + // sum within each warp + U acc1 = U(0); + U acc2 = U(0); + for (int k = 0; k < blockDim.y; ++k) { + int row1 = threadIdx.y + k * blockDim.y; + int idx1 = row1 * row_stride + threadIdx.x; + acc1 += warp_buf1[idx1]; + acc2 += warp_buf2[idx1]; + } + warp_buf1[threadIdx.y * row_stride + threadIdx.x] = acc1; + warp_buf2[threadIdx.y * row_stride + threadIdx.x] = acc2; + __syncthreads(); + // sum all warps + for (int offset = blockDim.y / 2; offset > 1; offset /= 2) { + if (threadIdx.y < offset) { + int row1 = threadIdx.y; + int row2 = threadIdx.y + offset; + int idx1 = row1 * row_stride + threadIdx.x; + int idx2 = row2 * row_stride + threadIdx.x; + warp_buf1[idx1] += warp_buf1[idx2]; + warp_buf2[idx1] += warp_buf2[idx2]; + } + __syncthreads(); + } + int i2 = blockIdx.x * blockDim.x + threadIdx.x; + if (threadIdx.y == 0 && i2 < n2) { + int row1 = threadIdx.y; + int row2 = threadIdx.y + 1; + int idx1 = row1 * row_stride + threadIdx.x; + int idx2 = row2 * row_stride + threadIdx.x; + part_grad_beta[blockIdx.y * n2 + i2] = warp_buf1[idx1] + warp_buf1[idx2]; + part_grad_gamma[blockIdx.y * n2 + i2] = warp_buf2[idx1] + warp_buf2[idx2]; + } +} + +template +__global__ void cuComputeGradGammaBeta(const U* part_grad_gamma, + const U* part_grad_beta, + const int part_size, const int n1, + const int n2, V* grad_gamma, + V* grad_beta) { + // sum partial gradients for gamma and beta + SharedMemory shared; + U* buf = shared.getPointer(); + int i2 = blockIdx.x * blockDim.x + threadIdx.x; + if (i2 < n2) { + // each warp does sequential reductions until reduced part_size is num_warps + int num_warp_reductions = part_size / blockDim.y; + U sum_gamma = U(0); + U sum_beta = U(0); + const U* part_grad_gamma_ptr = + part_grad_gamma + threadIdx.y * num_warp_reductions * n2 + i2; + const U* part_grad_beta_ptr = + part_grad_beta + threadIdx.y * num_warp_reductions * n2 + i2; + for (int warp_offset = 0; warp_offset < num_warp_reductions; + ++warp_offset) { + sum_gamma += part_grad_gamma_ptr[warp_offset * n2]; + sum_beta += part_grad_beta_ptr[warp_offset * n2]; + } + // inter-warp reductions + const int nbsize3 = blockDim.x * blockDim.y / 2; + for (int offset = blockDim.y / 2; offset >= 1; offset /= 2) { + // top half write to shared memory + if (threadIdx.y >= offset && threadIdx.y < 2 * offset) { + const int write_idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x; + buf[write_idx] = sum_gamma; + buf[write_idx + nbsize3] = sum_beta; + } + __syncthreads(); + // bottom half sums + if (threadIdx.y < offset) { + const int read_idx = threadIdx.y * blockDim.x + threadIdx.x; + sum_gamma += buf[read_idx]; + sum_beta += buf[read_idx + nbsize3]; + } + __syncthreads(); + } + // write out fully summed gradients + if (threadIdx.y == 0) { + grad_gamma[i2] = sum_gamma; + grad_beta[i2] = sum_beta; + } + } +} + +template +__global__ void cuComputeGradInput(const V* __restrict__ dout, + const T* __restrict__ input, const int n1, + const int n2, const U* __restrict__ mean, + const U* __restrict__ invvar, U epsilon, + const V* gamma, T* grad_input) { + for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) { + U sum_loss1 = U(0); + U sum_loss2 = U(0); + const U c_mean = mean[i1]; + const U c_invvar = invvar[i1]; + const T* k_input = input + i1 * n2; + const V* k_dout = dout + i1 * n2; + const int numx = blockDim.x * blockDim.y; + const int thrx = threadIdx.x + threadIdx.y * blockDim.x; + if (gamma != NULL) { + int l = 4 * thrx; + for (; l + 3 < n2; l += 4 * numx) { + for (int k = 0; k < 4; ++k) { + const U c_h = static_cast(k_input[l + k]); + const U c_loss = static_cast(k_dout[l + k]); + sum_loss1 += c_loss * gamma[l + k]; + sum_loss2 += c_loss * gamma[l + k] * (c_h - c_mean) * c_invvar; + } + } + for (; l < n2; ++l) { + const U c_h = static_cast(k_input[l]); + const U c_loss = static_cast(k_dout[l]); + sum_loss1 += c_loss * gamma[l]; + sum_loss2 += c_loss * gamma[l] * (c_h - c_mean) * c_invvar; + } + } else { + int l = 4 * thrx; + for (; l + 3 < n2; l += 4 * numx) { + for (int k = 0; k < 4; ++k) { + const U c_h = static_cast(k_input[l + k]); + const U c_loss = static_cast(k_dout[l + k]); + sum_loss1 += c_loss; + sum_loss2 += c_loss * (c_h - c_mean) * c_invvar; + } + } + for (; l < n2; ++l) { + const U c_h = static_cast(k_input[l]); + const U c_loss = static_cast(k_dout[l]); + sum_loss1 += c_loss; + sum_loss2 += c_loss * (c_h - c_mean) * c_invvar; + } + } + // intra-warp reductions + for (int mask = blockDim.x / 2; mask > 0; mask /= 2) { + sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask); + sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask); + } + // inter-warp reductions + if (blockDim.y > 1) { + SharedMemory shared; + U* buf = shared.getPointer(); + for (int offset = blockDim.y / 2; offset > 0; offset /= 2) { + // upper half of warps write to shared + if (threadIdx.y >= offset && threadIdx.y < 2 * offset) { + const int wrt_i = (threadIdx.y - offset) * blockDim.x + threadIdx.x; + buf[2 * wrt_i] = sum_loss1; + buf[2 * wrt_i + 1] = sum_loss2; + } + __syncthreads(); + // lower half merges + if (threadIdx.y < offset) { + const int read_i = threadIdx.y * blockDim.x + threadIdx.x; + sum_loss1 += buf[2 * read_i]; + sum_loss2 += buf[2 * read_i + 1]; + } + __syncthreads(); + } + if (threadIdx.y == 0) { + buf[2 * threadIdx.x] = sum_loss1; + buf[2 * threadIdx.x + 1] = sum_loss2; + } + __syncthreads(); + if (threadIdx.y != 0) { + sum_loss1 = buf[2 * threadIdx.x]; + sum_loss2 = buf[2 * threadIdx.x + 1]; + } + } + // all threads now have the two sums over l + U fH = (U)n2; + U term1 = (U(1) / fH) * c_invvar; + T* k_grad_input = grad_input + i1 * n2; + if (gamma != NULL) { + for (int l = thrx; l < n2; l += numx) { + const U c_h = static_cast(k_input[l]); + const U c_loss = static_cast(k_dout[l]); + U f_grad_input = fH * c_loss * gamma[l]; + f_grad_input -= sum_loss1; + f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2; + f_grad_input *= term1; + k_grad_input[l] = static_cast(f_grad_input); + } + } else { + for (int l = thrx; l < n2; l += numx) { + const U c_h = static_cast(k_input[l]); + const U c_loss = static_cast(k_dout[l]); + U f_grad_input = fH * c_loss; + f_grad_input -= sum_loss1; + f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2; + f_grad_input *= term1; + k_grad_input[l] = static_cast(f_grad_input); + } + } + } +} + +template +void HostApplyLayerNorm(V* output, U* mean, U* invvar, const T* input, int n1, + int n2, float epsilon, const V* gamma, const V* beta) { + auto stream = at::cuda::getCurrentCUDAStream().stream(); + const dim3 threads(32, 4, 1); + const uint64_t maxGridY = + at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; + const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1); + int nshared = + threads.y > 1 ? threads.y * sizeof(U) + (threads.y / 2) * sizeof(U) : 0; + cuApplyLayerNorm<<>>( + output, mean, invvar, input, n1, n2, U(epsilon), gamma, beta); +} + +void cuda_layer_norm(at::Tensor* output, at::Tensor* mean, at::Tensor* invvar, + at::Tensor* input, int n1, int n2, +#ifdef VERSION_GE_1_1 + at::IntArrayRef normalized_shape, +#else + at::IntList normalized_shape, +#endif + at::Tensor* gamma, at::Tensor* beta, float epsilon) { + using namespace at; + DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES( + input->scalar_type(), output->scalar_type(), "cuda_layer_norm_kernel", + HostApplyLayerNorm(output->data_ptr(), + mean->data_ptr(), invvar->data_ptr(), + input->data_ptr(), n1, n2, epsilon, + gamma != NULL ? gamma->data_ptr() : NULL, + beta != NULL ? beta->data_ptr() : NULL);) +} + +template +void HostLayerNormGradient(const V* dout, const U* mean, const U* invvar, + at::Tensor* input, int n1, int n2, const V* gamma, + const V* beta, float epsilon, T* grad_input, + V* grad_gamma, V* grad_beta) { + auto stream = at::cuda::getCurrentCUDAStream().stream(); + + if (gamma != NULL && beta != NULL) { + // compute grad_gamma(j) and grad_beta(j) + const int part_size = 16; + const dim3 threads2(32, 4, 1); + const dim3 blocks2((n2 + threads2.x - 1) / threads2.x, part_size, 1); + const int nshared2_a = + 2 * sizeof(U) * threads2.y * threads2.y * (threads2.x + 1); + const int nshared2_b = threads2.x * threads2.y * sizeof(U); + const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b; + at::Tensor part_grad_gamma = at::empty( + {part_size, n2}, input->options().dtype(at::ScalarType::Float)); + at::Tensor part_grad_beta = at::empty_like(part_grad_gamma); + cuComputePartGradGammaBeta<<>>( + dout, input->data_ptr(), n1, n2, mean, invvar, U(epsilon), + part_grad_gamma.data_ptr(), part_grad_beta.data_ptr()); + + const dim3 threads3(32, 8, 1); + const dim3 blocks3((n2 + threads2.x - 1) / threads2.x, 1, 1); + const int nshared3 = threads3.x * threads3.y * sizeof(U); + cuComputeGradGammaBeta<<>>( + part_grad_gamma.data_ptr(), part_grad_beta.data_ptr(), part_size, + n1, n2, grad_gamma, grad_beta); + } + + // compute grad_input + const uint64_t maxGridY = + at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; + const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1); + const dim3 threads1(32, 4, 1); + int nshared = threads1.y > 1 ? threads1.y * threads1.x * sizeof(U) : 0; + cuComputeGradInput<<>>( + dout, input->data_ptr(), n1, n2, mean, invvar, U(epsilon), gamma, + grad_input); +} + +void cuda_layer_norm_gradient(at::Tensor* dout, at::Tensor* mean, + at::Tensor* invvar, at::Tensor* input, int n1, + int n2, +#ifdef VERSION_GE_1_1 + at::IntArrayRef normalized_shape, +#else + at::IntList normalized_shape, +#endif + at::Tensor* gamma, at::Tensor* beta, + float epsilon, at::Tensor* grad_input, + at::Tensor* grad_gamma, at::Tensor* grad_beta) { + using namespace at; + DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES( + input->scalar_type(), gamma->scalar_type(), + "cuda_layer_norm_gradient_kernel", + HostLayerNormGradient( + dout->data_ptr(), mean->data_ptr(), + invvar->data_ptr(), input, n1, n2, + // TMJ pass NULL argument for gamma, beta, grad_gamma and grad_beta + // if gamma Tensor is NULL on input. + gamma != NULL ? gamma->data_ptr() : NULL, + gamma != NULL ? beta->data_ptr() : NULL, epsilon, + grad_input->data_ptr(), + gamma != NULL ? grad_gamma->data_ptr() : NULL, + gamma != NULL ? grad_beta->data_ptr() : NULL);) +} diff --git a/toolbox/ColossalAI/v0.4.4/patches/extensions/pybind/cpu_adam/cpu_adam_x86.py b/toolbox/ColossalAI/v0.4.4/patches/extensions/pybind/cpu_adam/cpu_adam_x86.py new file mode 100644 index 0000000000000000000000000000000000000000..19a021f2c74f910f96b6b0c29804ceaf601e3148 --- /dev/null +++ b/toolbox/ColossalAI/v0.4.4/patches/extensions/pybind/cpu_adam/cpu_adam_x86.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python3 +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +import platform + +from ...cuda_extension import _CudaExtension +from ...utils import append_nvcc_threads + + +class CpuAdamX86Extension(_CudaExtension): + def __init__(self): + super().__init__(name="cpu_adam_x86") + + def is_available(self) -> bool: + return platform.machine() == "x86_64" and super().is_available() + + def assert_compatible(self) -> None: + arch = platform.machine() + assert ( + arch == "x86_64" + ), f"[extension] The {self.name} kernel requires the CPU architecture to be x86_64 but got {arch}" + super().assert_compatible() + + # necessary 4 functions + def sources_files(self): + ret = [ + self.csrc_abs_path("kernel/x86/cpu_adam.cpp"), + ] + return ret + + def cxx_flags(self): + extra_cxx_flags = [ + "-std=c++14", + "-std=c++17", + "-lcudart", + "-lcublas", + "-g", + "-Wno-reorder", + "-fopenmp", + "-march=native", + ] + return ["-O3"] + self.version_dependent_macros + extra_cxx_flags + + def nvcc_flags(self): + extra_cuda_flags = [ + "-std=c++14", + "-std=c++17", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_HALF2_OPERATORS__", + "-DTHRUST_IGNORE_CUB_VERSION_CHECK", + ] + ret = ["-O3"] + self.version_dependent_macros + extra_cuda_flags + super().nvcc_flags() + return append_nvcc_threads(ret) diff --git a/toolbox/ColossalAI/v0.4.4/patches/extensions/pybind/layernorm/layer_norm.cpp b/toolbox/ColossalAI/v0.4.4/patches/extensions/pybind/layernorm/layer_norm.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0fc54ea8183d6c212d47356c144d277ff9ee54ff --- /dev/null +++ b/toolbox/ColossalAI/v0.4.4/patches/extensions/pybind/layernorm/layer_norm.cpp @@ -0,0 +1,158 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. */ +/* All Rights Reserved. */ +/*This code from NVIDIA apex: + * https://github.com/NVIDIA/apex + * with minor changes. */ + +#include + +#include +#include + +#include "common/micros.h" + +namespace { + +void compute_n1_n2(at::Tensor input, at::IntArrayRef normalized_shape, int &n1, + int &n2) { + int idiff = input.ndimension() - normalized_shape.size(); + n2 = 1; + for (int i = 0; i < (int)normalized_shape.size(); ++i) { + assert(input.sizes()[i + idiff] == normalized_shape[i]); + n2 *= normalized_shape[i]; + } + n1 = 1; + for (int i = 0; i < idiff; ++i) { + n1 *= input.sizes()[i]; + } +} + +void check_args(at::IntArrayRef normalized_shape, at::Tensor gamma, + at::Tensor beta) { + TORCH_CHECK(!gamma.defined() || gamma.sizes().equals(normalized_shape)); + TORCH_CHECK(!beta.defined() || beta.sizes().equals(normalized_shape)); +} + +void check_args(at::Tensor input, at::IntArrayRef normalized_shape, int &n1, + int &n2) { + int64_t normalized_ndim = normalized_shape.size(); + + if (normalized_ndim < 1) { + std::stringstream ss; + ss << "Expected normalized_shape to be at least 1-dimensional, i.e., " + << "containing at least one element, but got normalized_shape=" + << normalized_shape; + throw std::runtime_error(ss.str()); + } + + auto input_shape = input.sizes(); + auto input_ndim = input.dim(); + + if (input_ndim < normalized_ndim || + !input_shape.slice(input_ndim - normalized_ndim) + .equals(normalized_shape)) { + std::stringstream ss; + ss << "Given normalized_shape=" << normalized_shape + << ", expected input with shape [*"; + for (auto size : normalized_shape) { + ss << ", " << size; + } + ss << "], but got input of size" << input_shape; + throw std::runtime_error(ss.str()); + } + + compute_n1_n2(input, normalized_shape, n1, n2); +} + +void check_args(at::Tensor input, at::IntArrayRef normalized_shape, + at::Tensor gamma, at::Tensor beta, int &n1, int &n2) { + check_args(input, normalized_shape, n1, n2); + check_args(normalized_shape, gamma, beta); +} +} // namespace + +void cuda_layer_norm(at::Tensor *output, at::Tensor *mean, at::Tensor *invvar, + at::Tensor *input, int n1, int n2, + at::IntArrayRef normalized_shape, at::Tensor *gamma, + at::Tensor *beta, float epsilon); + +#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) \ + TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) \ + CHECK_CUDA(x); \ + CHECK_CONTIGUOUS(x) + +std::vector layer_norm_affine(at::Tensor input, + at::IntArrayRef normalized_shape, + at::Tensor gamma, at::Tensor beta, + float epsilon) { + CHECK_INPUT(input); + CHECK_INPUT(gamma); + CHECK_INPUT(beta); + int n1, n2; + check_args(input, normalized_shape, gamma, beta, n1, n2); + + at::Tensor output = + at::empty_like(input, gamma.options().dtype(gamma.scalar_type())); + at::Tensor mean = + at::empty({n1}, input.options().dtype(at::ScalarType::Float)); + at::Tensor invvar = at::empty_like(mean); + + cuda_layer_norm(&output, &mean, &invvar, &input, n1, n2, normalized_shape, + &gamma, &beta, epsilon); + + return {output, mean, invvar}; +} + +void cuda_layer_norm_gradient(at::Tensor *dout, at::Tensor *mean, + at::Tensor *invvar, at::Tensor *input, int n1, + int n2, at::IntArrayRef normalized_shape, + at::Tensor *gamma, at::Tensor *beta, + float epsilon, at::Tensor *grad_input, + at::Tensor *grad_gamma, at::Tensor *grad_beta); + +std::vector layer_norm_gradient_affine( + at::Tensor dout, at::Tensor mean, at::Tensor invvar, at::Tensor input, + at::IntArrayRef normalized_shape, at::Tensor gamma, at::Tensor beta, + float epsilon) { + CHECK_INPUT(dout); + CHECK_INPUT(mean); + CHECK_INPUT(invvar); + CHECK_INPUT(input); + CHECK_INPUT(gamma); + CHECK_INPUT(beta); + int n1, n2; + check_args(input, normalized_shape, gamma, beta, n1, n2); + + at::Tensor grad_input = at::empty_like(input); + at::Tensor grad_gamma = at::empty_like(gamma); + at::Tensor grad_beta = at::empty_like(beta); + + cuda_layer_norm_gradient(&dout, &mean, &invvar, &input, n1, n2, + normalized_shape, &gamma, &beta, epsilon, + &grad_input, &grad_gamma, &grad_beta); + + return {grad_input, grad_gamma, grad_beta}; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward_affine", &layer_norm_affine, "LayerNorm forward (CUDA)"); + m.def("backward_affine", &layer_norm_gradient_affine, + "LayerNorm backward (CUDA)"); +} diff --git a/toolbox/ColossalAI/v0.4.4/patches/extensions/pybind/layernorm/layernorm_cuda.py b/toolbox/ColossalAI/v0.4.4/patches/extensions/pybind/layernorm/layernorm_cuda.py new file mode 100644 index 0000000000000000000000000000000000000000..5f09a4e843da43473ac32ddce5d8426c7fb37fa0 --- /dev/null +++ b/toolbox/ColossalAI/v0.4.4/patches/extensions/pybind/layernorm/layernorm_cuda.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python3 +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +from ...cuda_extension import _CudaExtension +from ...utils import append_nvcc_threads, get_cuda_cc_flag + + +class LayerNormCudaExtension(_CudaExtension): + def __init__(self): + super().__init__(name="layernorm_cuda") + + def sources_files(self): + ret = [self.csrc_abs_path(fname) for fname in ["kernel/cuda/layer_norm_kernel.cu"]] + [ + self.pybind_abs_path("layernorm/layer_norm.cpp") + ] + return ret + + def include_dirs(self): + ret = [self.get_cuda_home_include()] + [self.csrc_abs_path("")] + return ret + + def cxx_flags(self): + return ["-O3"] + self.version_dependent_macros + + def nvcc_flags(self): + ret = ["-O3"] + self.version_dependent_macros + super().nvcc_flags() + return append_nvcc_threads(ret) diff --git a/toolbox/ColossalAI/v0.4.4/patches/extensions/pybind/moe/moe_cuda.py b/toolbox/ColossalAI/v0.4.4/patches/extensions/pybind/moe/moe_cuda.py new file mode 100644 index 0000000000000000000000000000000000000000..6faccb7f5d5a8121ecc22f8804f4d63d6f7c3268 --- /dev/null +++ b/toolbox/ColossalAI/v0.4.4/patches/extensions/pybind/moe/moe_cuda.py @@ -0,0 +1,28 @@ +#!/usr/bin/env python3 +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +from ...cuda_extension import _CudaExtension +from ...utils import append_nvcc_threads, get_cuda_cc_flag + + +class MoeCudaExtension(_CudaExtension): + def __init__(self): + super().__init__(name="moe_cuda") + + def sources_files(self): + ret = [self.csrc_abs_path(fname) for fname in ["kernel/cuda/moe_kernel.cu"]] + [ + self.pybind_abs_path("moe/moe.cpp") + ] + return ret + + def cxx_flags(self): + return ["-O3"] + self.version_dependent_macros + + def nvcc_flags(self): + extra_cuda_flags = [ + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + ] + extra_cuda_flags.extend(get_cuda_cc_flag()) + ret = ["-O3"] + extra_cuda_flags + super().nvcc_flags() + return append_nvcc_threads(ret) diff --git a/toolbox/ColossalAI/v0.4.4/patches/extensions/pybind/optimizer/fused_optimizer_cuda.py b/toolbox/ColossalAI/v0.4.4/patches/extensions/pybind/optimizer/fused_optimizer_cuda.py new file mode 100644 index 0000000000000000000000000000000000000000..8112f37ab70941396a4f9ad1f18ee16a61a2115f --- /dev/null +++ b/toolbox/ColossalAI/v0.4.4/patches/extensions/pybind/optimizer/fused_optimizer_cuda.py @@ -0,0 +1,32 @@ +#!/usr/bin/env python3 +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +from ...cuda_extension import _CudaExtension +from ...utils import get_cuda_cc_flag + + +class FusedOptimizerCudaExtension(_CudaExtension): + def __init__(self): + super().__init__(name="fused_optim_cuda") + + def sources_files(self): + ret = [ + self.csrc_abs_path(fname) + for fname in [ + "kernel/cuda/multi_tensor_sgd_kernel.cu", + "kernel/cuda/multi_tensor_scale_kernel.cu", + "kernel/cuda/multi_tensor_adam_kernel.cu", + "kernel/cuda/multi_tensor_l2norm_kernel.cu", + "kernel/cuda/multi_tensor_lamb_kernel.cu", + ] + ] + [self.pybind_abs_path("optimizer/optimizer.cpp")] + return ret + + def cxx_flags(self): + version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"] + return ["-O3"] + version_dependent_macros + + def nvcc_flags(self): + extra_cuda_flags = ["-lineinfo"] + extra_cuda_flags.extend(get_cuda_cc_flag()) + return ["-O3"] + extra_cuda_flags + super().nvcc_flags() diff --git a/toolbox/ColossalAI/v0.4.4/patches/extensions/pybind/softmax/scaled_masked_softmax_cuda.py b/toolbox/ColossalAI/v0.4.4/patches/extensions/pybind/softmax/scaled_masked_softmax_cuda.py new file mode 100644 index 0000000000000000000000000000000000000000..d96a954cd54edb522b8d5afbc8e541c700ab4047 --- /dev/null +++ b/toolbox/ColossalAI/v0.4.4/patches/extensions/pybind/softmax/scaled_masked_softmax_cuda.py @@ -0,0 +1,31 @@ +#!/usr/bin/env python3 +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +from ...cuda_extension import _CudaExtension +from ...utils import append_nvcc_threads + + +class ScaledMaskedSoftmaxCudaExtension(_CudaExtension): + def __init__(self): + super().__init__(name="scaled_masked_softmax_cuda") + + def sources_files(self): + ret = [self.csrc_abs_path(fname) for fname in ["kernel/cuda/scaled_masked_softmax_kernel.cu"]] + [ + self.pybind_abs_path("softmax/scaled_masked_softmax.cpp") + ] + return ret + + def cxx_flags(self): + return ["-O3"] + self.version_dependent_macros + + def nvcc_flags(self): + extra_cuda_flags = [ + "-std=c++14", + "-std=c++17", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_HALF2_OPERATORS__", + "-DTHRUST_IGNORE_CUB_VERSION_CHECK", + ] + ret = ["-O3"] + self.version_dependent_macros + extra_cuda_flags + super().nvcc_flags() + return append_nvcc_threads(ret) diff --git a/toolbox/ColossalAI/v0.4.4/patches/extensions/pybind/softmax/scaled_upper_triangle_masked_softmax_cuda.py b/toolbox/ColossalAI/v0.4.4/patches/extensions/pybind/softmax/scaled_upper_triangle_masked_softmax_cuda.py new file mode 100644 index 0000000000000000000000000000000000000000..dab094ef587415d020e5c97674bd5b9ffc0a9b56 --- /dev/null +++ b/toolbox/ColossalAI/v0.4.4/patches/extensions/pybind/softmax/scaled_upper_triangle_masked_softmax_cuda.py @@ -0,0 +1,31 @@ +#!/usr/bin/env python3 +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +from ...cuda_extension import _CudaExtension +from ...utils import append_nvcc_threads, get_cuda_cc_flag + + +class ScaledUpperTriangleMaskedSoftmaxCudaExtension(_CudaExtension): + def __init__(self): + super().__init__(name="scaled_upper_triangle_masked_softmax_cuda") + + def sources_files(self): + ret = [ + self.csrc_abs_path(fname) + for fname in [ + "kernel/cuda/scaled_upper_triang_masked_softmax_kernel.cu", + ] + ] + [self.pybind_abs_path("softmax/scaled_upper_triang_masked_softmax.cpp")] + return ret + + def cxx_flags(self): + return ["-O3"] + self.version_dependent_macros + + def nvcc_flags(self): + extra_cuda_flags = [ + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + ] + extra_cuda_flags.extend(get_cuda_cc_flag()) + ret = ["-O3"] + extra_cuda_flags + super().nvcc_flags() + return append_nvcc_threads(ret) diff --git a/toolbox/ColossalAI/v0.4.4/patches/install_colossalai.sh b/toolbox/ColossalAI/v0.4.4/patches/install_colossalai.sh new file mode 100644 index 0000000000000000000000000000000000000000..07b4d8f1660b94e836a9a29122c31cd373ebb8e3 --- /dev/null +++ b/toolbox/ColossalAI/v0.4.4/patches/install_colossalai.sh @@ -0,0 +1,42 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +#!/bin/bash + +TARGET_DIR=${TARGET_DIR:-} + +PYTHON_PATH=$(which python3) +PYTHON_DIST_PATH=${TARGET_DIR}/lib/python3/dist-packages + +PKG_DIR="build_pip" +PKG_NAME="colossalai" + +if [[ ! -d ${PKG_DIR} ]]; then + echo "ERROR: Package directory ${PKG_DIR} doesn't exist" + exit 1 +fi + +latest_pkg="$(ls -t ${PKG_DIR} | grep ${PKG_NAME} | head -1)" +if [[ "${latest_pkg}" == "" ]]; then + echo "ERROR: Cannot find latest ${PKG_NAME} package" + exit 1 +else + echo "INFO: Found latest package ${latest_pkg} in directory ${PKG_DIR}" +fi + +${PYTHON_PATH} -m pip uninstall ${PKG_NAME} -y +${PYTHON_PATH} -m pip install ${PKG_DIR}/${latest_pkg} || exit + +# Return 0 status if all finished +exit 0 \ No newline at end of file diff --git a/toolbox/ColossalAI/v0.4.4/patches/requirements/requirements-test.txt b/toolbox/ColossalAI/v0.4.4/patches/requirements/requirements-test.txt new file mode 100644 index 0000000000000000000000000000000000000000..e2bcf8bf26491e5b2804b66c4bd1d239ef19c728 --- /dev/null +++ b/toolbox/ColossalAI/v0.4.4/patches/requirements/requirements-test.txt @@ -0,0 +1,20 @@ +pytest +# coverage==7.2.3 +# git+https://github.com/hpcaitech/pytest-testmon +# torchvision +# timm +# titans +# torchaudio>=0.13.1 +# torchx-nightly==2022.6.29 # torchrec 0.2.0 requires torchx-nightly. This package is updated every day. We fix the version to a specific date to avoid breaking changes. +# torchrec==0.2.0 +contexttimer +einops +# triton +# requests==2.27.1 # downgrade to avoid huggingface error https://github.com/huggingface/transformers/issues/17611 +SentencePiece +ninja +# flash_attn +datasets +pydantic +ray +peft>=0.7.1 diff --git a/toolbox/ColossalAI/v0.4.4/patches/requirements/requirements.txt b/toolbox/ColossalAI/v0.4.4/patches/requirements/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..5487e21b8c4933773ac69cc99225ff8e0ad66aae --- /dev/null +++ b/toolbox/ColossalAI/v0.4.4/patches/requirements/requirements.txt @@ -0,0 +1,26 @@ +numpy +tqdm +psutil +packaging +pre-commit +rich +click +fabric +contexttimer +ninja +# torch>=2.2.0,<=2.4.0 +safetensors +einops +pydantic +ray +sentencepiece +google +protobuf +transformers>=4.39.3 # BI适配的版本太落后,直接使用官方的版本 +peft>=0.7.1 +# bitsandbytes>=0.39.0 +rpyc==6.0.0 +fastapi +uvicorn==0.29.0 +galore_torch +# diffusers==0.29.0 diff --git a/toolbox/ColossalAI/v0.4.4/patches/setup.py b/toolbox/ColossalAI/v0.4.4/patches/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..70f29efd9d2accbc765e87db53fa034e334ce328 --- /dev/null +++ b/toolbox/ColossalAI/v0.4.4/patches/setup.py @@ -0,0 +1,156 @@ +#!/usr/bin/env python3 +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +import os +import sys +from typing import List + +from setuptools import find_packages, setup + +try: + import torch # noqa + from torch.utils.cpp_extension import BuildExtension + + TORCH_AVAILABLE = True +except ImportError: + TORCH_AVAILABLE = False + +THIS_DIR = os.path.dirname(os.path.abspath(__file__)) +BUILD_EXT = int(os.environ.get("BUILD_EXT", "0")) == 1 + +# we do not support windows currently +if sys.platform == "win32": + raise RuntimeError("Windows is not supported yet. Please try again within the Windows Subsystem for Linux (WSL).") + + +def fetch_requirements(path) -> List[str]: + """ + This function reads the requirements file. + + Args: + path (str): the path to the requirements file. + + Returns: + The lines in the requirements file. + """ + with open(path, "r") as fd: + return [r.strip() for r in fd.readlines()] + + +def fetch_readme() -> str: + """ + This function reads the README.md file in the current directory. + + Returns: + The lines in the README file. + """ + with open("README.md", encoding="utf-8") as f: + return f.read() + + +def get_version() -> str: + """ + This function reads the version.txt and generates the colossalai/version.py file. + + Returns: + The library version stored in version.txt. + """ + + setup_file_path = os.path.abspath(__file__) + project_path = os.path.dirname(setup_file_path) + version_txt_path = os.path.join(project_path, "version.txt") + version_py_path = os.path.join(project_path, "colossalai/version.py") + + with open(version_txt_path) as f: + version = f.read().strip() + + # write version into version.py + with open(version_py_path, "w") as f: + f.write(f"__version__ = '{version}'\n") + return version + + +if BUILD_EXT: + if not TORCH_AVAILABLE: + raise ModuleNotFoundError( + "[extension] PyTorch is not found while BUILD_EXT=1. You need to install PyTorch first in order to build CUDA extensions" + ) + + from extensions import ALL_EXTENSIONS + + op_names = [] + ext_modules = [] + + for ext_cls in ALL_EXTENSIONS: + ext = ext_cls() + if ext.support_aot and ext.is_available(): + ext.assert_compatible() + op_names.append(ext.name) + ext_modules.append(ext.build_aot()) + + # show log + if len(ext_modules) == 0: + raise RuntimeError("[extension] Could not find any kernel compatible with the current environment.") + else: + op_name_list = ", ".join(op_names) + print(f"[extension] Building extensions{op_name_list}") +else: + ext_modules = [] + +version = get_version() +if "COLOSSALAI_LOCAL_VERSION_IDENTIFIER" in os.environ: + version += "+" + str(os.environ['COLOSSALAI_LOCAL_VERSION_IDENTIFIER']) + +package_name = "colossalai" + +setup( + name=package_name, + version=version, + packages=find_packages( + exclude=( + "extensions", + "benchmark", + "docker", + "tests", + "docs", + "examples", + "tests", + "scripts", + "requirements", + "*.egg-info", + ), + ), + description="An integrated large-scale model training system with efficient parallelization techniques", + long_description=fetch_readme(), + long_description_content_type="text/markdown", + license="Apache Software License 2.0", + url="https://www.colossalai.org", + project_urls={ + "Forum": "https://github.com/hpcaitech/ColossalAI/discussions", + "Bug Tracker": "https://github.com/hpcaitech/ColossalAI/issues", + "Examples": "https://github.com/hpcaitech/ColossalAI-Examples", + "Documentation": "http://colossalai.readthedocs.io", + "Github": "https://github.com/hpcaitech/ColossalAI", + }, + ext_modules=ext_modules, + cmdclass={"build_ext": BuildExtension} if ext_modules else {}, + install_requires=fetch_requirements("requirements/requirements.txt"), + entry_points=""" + [console_scripts] + colossalai=colossalai.cli:cli + """, + python_requires=">=3.6", + classifiers=[ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: Apache Software License", + "Environment :: GPU :: NVIDIA CUDA", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: System :: Distributed Computing", + ], + package_data={ + "colossalai": [ + "kernel/extensions/csrc/**/*", + "kernel/extensions/pybind/**/*", + ] + }, +) diff --git a/toolbox/DeepSpeed/v0.15.3/patches/.gitignore b/toolbox/DeepSpeed/v0.15.3/patches/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..ed91e9e83b3f68a5c7206f599ee4b7ed666a09b5 --- /dev/null +++ b/toolbox/DeepSpeed/v0.15.3/patches/.gitignore @@ -0,0 +1,67 @@ +## Ignore Python compiled files +*.pyc + +## Ignore IDE-specific files and directories +# JetBrains IDE settings +.idea/ +# Visual Studio Code settings +.vscode/ +# Theia IDE settings +.theia/ + +## Ignore temporary and backup files +# General backup files +*~ +# Vim swap files +*.swp + +## Ignore log files +*.log + +## Ignore a specific generated file +deepspeed/git_version_info_installed.py + +## Ignore Python bytecode cache +__pycache__ + +## Build + installation data +# Build artifacts +build/ + +# Distribution files + +build_pip/ + +dist/ +# Compiled shared objects +*.so +# Deepspeed package info +deepspeed.egg-info/ +# Build information +build.txt + +## Website generated files +# Jekyll generated site +docs/_site/ +# Generated documentation +docs/build +docs/code-docs/source/_build +docs/code-docs/_build +docs/code-docs/build +# SASS cache +.sass-cache/ +# Jekyll cache +.jekyll-cache/ +.jekyll-metadata + +## Testing data +# Saved checkpoints for testing +tests/unit/saved_checkpoint/ +tests/exit_code* + +# HIP files created during AMD compilation +*_hip.cpp +*_hip.h +*.hip +*.cuh +*hip_layers.h diff --git a/toolbox/DeepSpeed/v0.15.3/patches/build_deepspeed.sh b/toolbox/DeepSpeed/v0.15.3/patches/build_deepspeed.sh new file mode 100644 index 0000000000000000000000000000000000000000..ad5045c320936bb4d85215b6ee8fc61c7bde60c2 --- /dev/null +++ b/toolbox/DeepSpeed/v0.15.3/patches/build_deepspeed.sh @@ -0,0 +1,86 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +#!/bin/bash + +COREX_VERSION=${COREX_VERSION:-latest} +MAX_JOBS=${MAX_JOBS:-$(nproc --all)} +PYTHON_PATH=$(which python3) +PLATFORM_ID=$(uname -i) +${PYTHON_PATH} -c "import torch;print(torch.__version__)" || { + echo "ERROR: building vision requries torch has been installed." + exit 1 +} +PY_VERSION=`${PYTHON_PATH} -V 2>&1|awk '{print $2}'|awk -F '.' '{print $2}'` +OS_ID=$(awk -F= '/^ID=/{print $2}' /etc/os-release | tr -d '"') +if [[ "${OS_ID}" == "ubuntu" ]]; then + sudo apt-get install libaio-dev -y || exit +elif [[ "${OS_ID}" == "centos" ]]; then + sudo yum install libaio libaio-devel -y || exit +else + echo "Warning: unable to identify OS ..." +fi + +pip3 install -r requirements/requirements-bi.txt + +# ${PYTHON_PATH} -m pip install -r requirements_dev.txt || exit + +if [[ "${COREX_VERSION}" == "latest" ]]; then + COREX_VERSION=`date --utc +%Y%m%d%H%M%S` +fi +export DEEPSPEED_LOCAL_VERSION_IDENTIFIER="corex.${COREX_VERSION}" + +export MAX_JOBS=${MAX_JOBS} + +# export DS_BUILD_OPS=0 +arch=$(uname -m) + +if [ "$arch" == "aarch64" ]; then + echo "This is an ARM architecture" + export DS_BUILD_CPU_ADAM=0 +elif [ "$arch" == "x86_64" ]; then + echo "This is an x86 architecture" + export DS_BUILD_CPU_ADAM=1 +else + echo "Unknown architecture: $arch" +fi +export DS_BUILD_CPU_LION=1 +export DS_BUILD_FUSED_LION=1 +export DS_BUILD_FUSED_ADAM=1 +export DS_BUILD_FUSED_LAMB=1 +export DS_BUILD_SPARSE_ATTN=0 +export DS_BUILD_TRANSFORMER=1 +export DS_BUILD_QUANTIZER=1 +export DS_BUILD_CPU_ADAGRAD=1 +export DS_BUILD_RANDOM_LTD=1 +export DS_BUILD_SPATIAL_INFERENCE=1 +export DS_BUILD_TRANSFORMER_INFERENCE=1 +export DS_BUILD_STOCHASTIC_TRANSFORMER=1 +export DS_BUILD_UTILS=1 +export DS_ACCELERATOR=cuda +export DS_BUILD_AIO=1 + +export DS_BUILD_EVOFORMER_ATTN=0 +export DS_BUILD_SWIGLU=1 +export DS_BUILD_FUSED_ROPE=1 +export DS_BUILD_FUSED_LAYERNORM=1 +export DS_BUILD_GDS=1 + + +${PYTHON_PATH} setup.py build 2>&1 | tee compile.log; [[ ${PIPESTATUS[0]} == 0 ]] || exit + +${PYTHON_PATH} setup.py bdist_wheel -d build_pip || exit + +# Return 0 status if all finished +exit 0 diff --git a/toolbox/DeepSpeed/v0.15.3/patches/clean_deepspeed.sh b/toolbox/DeepSpeed/v0.15.3/patches/clean_deepspeed.sh new file mode 100644 index 0000000000000000000000000000000000000000..eba3a959846a4807539042530eba77f1af471f9d --- /dev/null +++ b/toolbox/DeepSpeed/v0.15.3/patches/clean_deepspeed.sh @@ -0,0 +1,25 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +#!/bin/bash + +PYTHON_PATH=$(which python3) + +rm -rf build +${PYTHON_PATH} setup.py clean || true +rm -rf build_pip +rm -rf ipex.egg-info +rm -rf deepspeed/git_version_info_installed.py +# Return 0 status if all finished +exit 0 \ No newline at end of file diff --git a/toolbox/DeepSpeed/v0.15.3/patches/csrc/adam/multi_tensor_adam.cu b/toolbox/DeepSpeed/v0.15.3/patches/csrc/adam/multi_tensor_adam.cu new file mode 100644 index 0000000000000000000000000000000000000000..a675d89e24ca79d89f527395cfce1fca13665b08 --- /dev/null +++ b/toolbox/DeepSpeed/v0.15.3/patches/csrc/adam/multi_tensor_adam.cu @@ -0,0 +1,224 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. */ +/* All Rights Reserved. */ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +/* +Copyright NVIDIA/apex +This file is adapted from fused adam in NVIDIA/apex, commit a109f85 +*/ + +#include +#include +#include +#include +// Another possibility: +// #include + +#include + +#include "multi_tensor_apply.cuh" +#include "type_shim.h" + +#define BLOCK_SIZE 512 +#ifdef __ILUVATAR__ +#define ILP 1 +#else +#define ILP 4 +#endif + +typedef enum : int { + ADAM_MODE_0 = 0, // L2 regularization mode + ADAM_MODE_1 = 1 // Decoupled weight decay mode(AdamW) +} adamMode_t; + +using MATH_T = float; + +template +struct AdamFunctor { + __device__ __forceinline__ void operator()(int chunk_size, + volatile int* noop_gmem, + TensorListMetadata<4>& tl, + const float beta1, + const float beta2, + const float beta1_correction, + const float beta2_correction, + const float epsilon, + const float lr, + adamMode_t mode, + const float decay) + { + // I'd like this kernel to propagate infs/nans. + // if(*noop_gmem == 1) + // return; + + index_t tensor_loc = tl.block_to_tensor[blockIdx.x]; + + // potentially use to pass in list of scalar + // int tensor_num = tl.start_tensor_this_launch + tensor_loc; + + index_t chunk_idx = tl.block_to_chunk[blockIdx.x]; + index_t n = tl.sizes[tensor_loc]; + + T* g = (T*)tl.addresses[0][tensor_loc]; + g += chunk_idx * chunk_size; + + T* p = (T*)tl.addresses[1][tensor_loc]; + p += chunk_idx * chunk_size; + + T* m = (T*)tl.addresses[2][tensor_loc]; + m += chunk_idx * chunk_size; + + T* v = (T*)tl.addresses[3][tensor_loc]; + v += chunk_idx * chunk_size; + + n -= chunk_idx * chunk_size; + + // see note in multi_tensor_scale_kernel.cu + for (index_t i_start = 0; i_start < n && i_start < chunk_size; + i_start += blockDim.x * ILP) { + MATH_T r_g[ILP]; + MATH_T r_p[ILP]; + MATH_T r_m[ILP]; + MATH_T r_v[ILP]; +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) { + r_g[ii] = g[i]; + r_p[ii] = p[i]; + r_m[ii] = m[i]; + r_v[ii] = v[i]; + } else { + r_g[ii] = MATH_T(0); + r_p[ii] = MATH_T(0); + r_m[ii] = MATH_T(0); + r_v[ii] = MATH_T(0); + } + } +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + if (mode == ADAM_MODE_0) { // L2 + r_g[ii] = r_g[ii] + (decay * r_p[ii]); + r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii]; + r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii]; + MATH_T next_m_unbiased = r_m[ii] / beta1_correction; + MATH_T next_v_unbiased = r_v[ii] / beta2_correction; + MATH_T denom = sqrtf(next_v_unbiased) + epsilon; + MATH_T update = next_m_unbiased / denom; + r_p[ii] = r_p[ii] - (lr * update); + } else { // weight decay + r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii]; + r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii]; + MATH_T next_m_unbiased = r_m[ii] / beta1_correction; + MATH_T next_v_unbiased = r_v[ii] / beta2_correction; + MATH_T denom = sqrtf(next_v_unbiased) + epsilon; + MATH_T update = (next_m_unbiased / denom) + (decay * r_p[ii]); + r_p[ii] = r_p[ii] - (lr * update); + } + } +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) { + p[i] = r_p[ii]; + m[i] = r_m[ii]; + v[i] = r_v[ii]; + } + } + } + } +}; + +void multi_tensor_adam_cuda(int chunk_size, + at::Tensor noop_flag, + std::vector> tensor_lists, + const float lr, + const float beta1, + const float beta2, + const float epsilon, + const int step, + const int mode, + const int bias_correction, + const float weight_decay) +{ + using namespace at; + + // Handle bias correction mode + float bias_correction1 = 1.0f, bias_correction2 = 1.0f; + if (bias_correction == 1) { + bias_correction1 = 1 - std::pow(beta1, step); + bias_correction2 = 1 - std::pow(beta2, step); + } + + size_t max_size = 0; + bool requires_64bit_indexing = false; + for (auto it = tensor_lists.begin(); it != tensor_lists.end(); it++) { + for (auto it2 = it->begin(); it2 != it->end(); it2++) { + if (it2->numel() > max_size) { + max_size = it2->numel(); + if (max_size >= INT_MAX) { + requires_64bit_indexing = true; + break; + } + } + } + if (requires_64bit_indexing) { break; } + } + + // Assume single type across p,g,m1,m2 now + if (requires_64bit_indexing) { + DISPATCH_DOUBLE_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), + 0, + "adam", + multi_tensor_apply<4>((int64_t)BLOCK_SIZE, + (int64_t)chunk_size, + noop_flag, + tensor_lists, + AdamFunctor(), + beta1, + beta2, + bias_correction1, + bias_correction2, + epsilon, + lr, + (adamMode_t)mode, + weight_decay);) + } else { + DISPATCH_DOUBLE_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), + 0, + "adam", + multi_tensor_apply<4>(BLOCK_SIZE, + chunk_size, + noop_flag, + tensor_lists, + AdamFunctor(), + beta1, + beta2, + bias_correction1, + bias_correction2, + epsilon, + lr, + (adamMode_t)mode, + weight_decay);) + } + + AT_CUDA_CHECK(cudaGetLastError()); +} diff --git a/toolbox/DeepSpeed/v0.15.3/patches/csrc/aio/common/deepspeed_aio_common.cpp b/toolbox/DeepSpeed/v0.15.3/patches/csrc/aio/common/deepspeed_aio_common.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9c348e7f7393f42f5f373193109084a9df2fd2d3 --- /dev/null +++ b/toolbox/DeepSpeed/v0.15.3/patches/csrc/aio/common/deepspeed_aio_common.cpp @@ -0,0 +1,359 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. */ +/* All Rights Reserved. */ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +/* +Functionality for swapping optimizer tensors to/from (NVMe) storage devices. +*/ + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "deepspeed_aio_common.h" + +using namespace std; +using namespace std::chrono; + +#define DEBUG_DS_AIO_PERF 0 +#define DEBUG_DS_AIO_SUBMIT_PERF 0 + +static const std::string c_library_name = "deepspeed_aio"; + +static void _report_aio_statistics(const char* tag, + const std::vector>& latencies) + __attribute__((unused)); + +static void _report_aio_statistics(const char* tag, + const std::vector>& latencies) +{ + std::vector lat_usec; + for (auto& lat : latencies) { lat_usec.push_back(lat.count() * 1e6); } + const auto min_lat = *(std::min_element(lat_usec.begin(), lat_usec.end())); + const auto max_lat = *(std::max_element(lat_usec.begin(), lat_usec.end())); + const auto avg_lat = std::accumulate(lat_usec.begin(), lat_usec.end(), 0) / lat_usec.size(); + + std::cout << c_library_name << ": latency statistics(usec) " << tag + << " min/max/avg = " << min_lat << " " << max_lat << " " << avg_lat << std::endl; +} + +static void _get_aio_latencies(std::vector>& raw_latencies, + struct deepspeed_aio_latency_t& summary_latencies) +{ + std::vector lat_usec; + for (auto& lat : raw_latencies) { lat_usec.push_back(lat.count() * 1e6); } + summary_latencies._min_usec = *(std::min_element(lat_usec.begin(), lat_usec.end())); + summary_latencies._max_usec = *(std::max_element(lat_usec.begin(), lat_usec.end())); + summary_latencies._avg_usec = + std::accumulate(lat_usec.begin(), lat_usec.end(), 0) / lat_usec.size(); +} + +static void _do_io_submit_singles(const int64_t n_iocbs, + const int64_t iocb_index, + std::unique_ptr& aio_ctxt, + std::vector>& submit_times) +{ + for (auto i = 0; i < n_iocbs; ++i) { + const auto st = std::chrono::high_resolution_clock::now(); + const auto submit_ret = io_submit(aio_ctxt->_io_ctxt, 1, aio_ctxt->_iocbs.data() + i); + submit_times.push_back(std::chrono::high_resolution_clock::now() - st); +#if DEBUG_DS_AIO_SUBMIT_PERF + printf("submit(usec) %f io_index=%lld buf=%p len=%lu off=%llu \n", + submit_times.back().count() * 1e6, + iocb_index, + aio_ctxt->_iocbs[i]->u.c.buf, + aio_ctxt->_iocbs[i]->u.c.nbytes, + aio_ctxt->_iocbs[i]->u.c.offset); +#endif + assert(submit_ret > 0); + } +} + +static void _do_io_submit_block(const int64_t n_iocbs, + const int64_t iocb_index, + std::unique_ptr& aio_ctxt, + std::vector>& submit_times) +{ + const auto st = std::chrono::high_resolution_clock::now(); + const auto submit_ret = io_submit(aio_ctxt->_io_ctxt, n_iocbs, aio_ctxt->_iocbs.data()); + submit_times.push_back(std::chrono::high_resolution_clock::now() - st); +#if DEBUG_DS_AIO_SUBMIT_PERF + printf("submit(usec) %f io_index=%lld nr=%lld buf=%p len=%lu off=%llu \n", + submit_times.back().count() * 1e6, + iocb_index, + n_iocbs, + aio_ctxt->_iocbs[0]->u.c.buf, + aio_ctxt->_iocbs[0]->u.c.nbytes, + aio_ctxt->_iocbs[0]->u.c.offset); +#endif + assert(submit_ret > 0); +} + +static int _do_io_complete(const int64_t min_completes, + const int64_t max_completes, + std::unique_ptr& aio_ctxt, + std::vector>& reap_times) +{ + const auto start_time = std::chrono::high_resolution_clock::now(); + // long long int n_completes = io_pgetevents(aio_ctxt->_io_ctxt, + // min_completes, + // max_completes, + // aio_ctxt->_io_events.data(), + // nullptr, + // nullptr); + const auto n_completes = io_getevents( + aio_ctxt->_io_ctxt, min_completes, max_completes, aio_ctxt->_io_events.data(), nullptr); + reap_times.push_back(std::chrono::high_resolution_clock::now() - start_time); + assert(n_completes >= min_completes); + return n_completes; +} + +void do_aio_operation_sequential(const bool read_op, + std::unique_ptr& aio_ctxt, + std::unique_ptr& xfer_ctxt, + deepspeed_aio_config_t* config, + deepspeed_aio_perf_t* perf) +{ + struct io_prep_context prep_ctxt(read_op, xfer_ctxt, aio_ctxt->_block_size, &aio_ctxt->_iocbs); + + const auto num_io_blocks = static_cast( + ceil(static_cast(xfer_ctxt->_num_bytes) / aio_ctxt->_block_size)); +#if DEBUG_DS_AIO_PERF + const auto io_op_name = std::string(read_op ? "read" : "write"); + std::cout << c_library_name << ": start " << io_op_name << " " << xfer_ctxt->_num_bytes + << " bytes with " << num_io_blocks << " io blocks" << std::endl; +#endif + + std::vector> submit_times; + std::vector> reap_times; + const auto max_queue_bytes = + static_cast(aio_ctxt->_queue_depth * aio_ctxt->_block_size); + + auto start = std::chrono::high_resolution_clock::now(); + for (int64_t iocb_index = 0; iocb_index < num_io_blocks; iocb_index += aio_ctxt->_queue_depth) { + const auto start_offset = iocb_index * aio_ctxt->_block_size; + const auto start_buffer = (char*)xfer_ctxt->_mem_buffer + start_offset; + const auto n_iocbs = + min(static_cast(aio_ctxt->_queue_depth), (num_io_blocks - iocb_index)); + const auto num_bytes = min(max_queue_bytes, (xfer_ctxt->_num_bytes - start_offset)); + prep_ctxt.prep_iocbs(n_iocbs, num_bytes, start_buffer, start_offset); + + if (config->_single_submit) { + _do_io_submit_singles(n_iocbs, iocb_index, aio_ctxt, submit_times); + } else { + _do_io_submit_block(n_iocbs, iocb_index, aio_ctxt, submit_times); + } + + _do_io_complete(n_iocbs, n_iocbs, aio_ctxt, reap_times); + } + const std::chrono::duration elapsed = std::chrono::high_resolution_clock::now() - start; + + if (perf) { + _get_aio_latencies(submit_times, perf->_submit); + _get_aio_latencies(reap_times, perf->_complete); + perf->_e2e_usec = elapsed.count() * 1e6; + perf->_e2e_rate_GB = (xfer_ctxt->_num_bytes / elapsed.count() / 1e9); + } + +#if DEBUG_DS_AIO_PERF + _report_aio_statistics("submit", submit_times); + _report_aio_statistics("complete", reap_times); +#endif + +#if DEBUG_DS_AIO_PERF + std::cout << c_library_name << ": runtime(usec) " << elapsed.count() * 1e6 + << " rate(GB/sec) = " << (xfer_ctxt->_num_bytes / elapsed.count() / 1e9) << std::endl; +#endif + +#if DEBUG_DS_AIO_PERF + std::cout << c_library_name << ": finish " << io_op_name << " " << xfer_ctxt->_num_bytes + << " bytes " << std::endl; +#endif +} + +void do_aio_operation_overlap(const bool read_op, + std::unique_ptr& aio_ctxt, + std::unique_ptr& xfer_ctxt, + deepspeed_aio_config_t* config, + deepspeed_aio_perf_t* perf) +{ + struct io_prep_generator io_gen(read_op, xfer_ctxt, aio_ctxt->_block_size); + +#if DEBUG_DS_AIO_PERF + const auto io_op_name = std::string(read_op ? "read" : "write"); + std::cout << c_library_name << ": start " << io_op_name << " " << xfer_ctxt->_num_bytes + << " bytes with " << io_gen._num_io_blocks << " io blocks" << std::endl; +#endif + + std::vector> submit_times; + std::vector> reap_times; + + auto request_iocbs = aio_ctxt->_queue_depth; + auto n_pending_iocbs = 0; + const auto min_completes = 1; + auto start = std::chrono::high_resolution_clock::now(); + while (true) { + const auto n_iocbs = io_gen.prep_iocbs(request_iocbs - n_pending_iocbs, &aio_ctxt->_iocbs); + if (n_iocbs > 0) { + if (config->_single_submit) { + _do_io_submit_singles( + n_iocbs, (io_gen._next_iocb_index - n_iocbs), aio_ctxt, submit_times); + } else { + _do_io_submit_block( + n_iocbs, (io_gen._next_iocb_index - n_iocbs), aio_ctxt, submit_times); + } + } + + n_pending_iocbs += n_iocbs; + assert(n_pending_iocbs <= aio_ctxt->_queue_depth); + + if (n_pending_iocbs == 0) { break; } + + const auto n_complete = + _do_io_complete(min_completes, n_pending_iocbs, aio_ctxt, reap_times); + n_pending_iocbs -= n_complete; + } + + const std::chrono::duration elapsed = std::chrono::high_resolution_clock::now() - start; + + if (perf) { + _get_aio_latencies(submit_times, perf->_submit); + _get_aio_latencies(reap_times, perf->_complete); + perf->_e2e_usec = elapsed.count() * 1e6; + perf->_e2e_rate_GB = (xfer_ctxt->_num_bytes / elapsed.count() / 1e9); + } + +#if DEBUG_DS_AIO_PERF + _report_aio_statistics("submit", submit_times); + _report_aio_statistics("complete", reap_times); +#endif + +#if DEBUG_DS_AIO_PERF + std::cout << c_library_name << ": runtime(usec) " << elapsed.count() * 1e6 + << " rate(GB/sec) = " << (xfer_ctxt->_num_bytes / elapsed.count() / 1e9) << std::endl; +#endif + +#if DEBUG_DS_AIO_PERF + std::cout << c_library_name << ": finish " << io_op_name << " " << xfer_ctxt->_num_bytes + << " bytes " << std::endl; +#endif +} + +void report_file_error(const char* filename, const std::string file_op, const int error_code) +{ + std::string err_msg = file_op + std::string(" failed on ") + std::string(filename) + + " error = " + std::to_string(error_code); + std::cerr << c_library_name << ": " << err_msg << std::endl; +} + +int open_file(const char* filename, const bool read_op) +{ + const int flags = read_op ? (O_RDONLY | O_DIRECT) : (O_WRONLY | O_CREAT | O_DIRECT); +#if defined(__ENABLE_CANN__) + int* flags_ptr = (int*)&flags; + *flags_ptr = read_op ? (O_RDONLY) : (O_WRONLY | O_CREAT); +#endif + const int mode = 0600; + const auto fd = open(filename, flags, mode); + if (fd == -1) { + const auto error_code = errno; + const auto error_msg = read_op ? " open for read " : " open for write "; + report_file_error(filename, error_msg, error_code); + return -1; + } + return fd; +} + +int regular_read(const char* filename, std::vector& buffer) +{ + int64_t num_bytes; + const auto f_size = get_file_size(filename, num_bytes); + assert(f_size != -1); + buffer.resize(num_bytes); + const auto fd = open(filename, O_RDONLY, 0600); + assert(fd != -1); + int64_t read_bytes = 0; + auto r = 0; + do { + const auto buffer_ptr = buffer.data() + read_bytes; + const auto bytes_to_read = num_bytes - read_bytes; + r = read(fd, buffer_ptr, bytes_to_read); + read_bytes += r; + } while (r > 0); + + if (read_bytes != num_bytes) { + std::cerr << "read error " << " read_bytes (read) = " << read_bytes + << " num_bytes (fstat) = " << num_bytes << std::endl; + } + assert(read_bytes == num_bytes); + close(fd); + return 0; +} + +static bool _validate_buffer(const char* filename, void* aio_buffer, const int64_t num_bytes) +{ + std::vector regular_buffer; + const auto reg_ret = regular_read(filename, regular_buffer); + assert(0 == reg_ret); + std::cout << "regular read of " << filename << " returned " << regular_buffer.size() << " bytes" + << std::endl; + + if (static_cast(regular_buffer.size()) != num_bytes) { return false; } + + return (0 == memcmp(aio_buffer, regular_buffer.data(), regular_buffer.size())); +} + +bool validate_aio_operation(const bool read_op, + const char* filename, + void* aio_buffer, + const int64_t num_bytes) +{ + const auto msg_suffix = std::string("deepspeed_aio_") + + std::string(read_op ? "read()" : "write()") + + std::string("using read()"); + + if (false == _validate_buffer(filename, aio_buffer, num_bytes)) { + std::cout << "Fail: correctness of " << msg_suffix << std::endl; + return false; + } + + std::cout << "Pass: correctness of " << msg_suffix << std::endl; + return true; +} diff --git a/toolbox/DeepSpeed/v0.15.3/patches/csrc/aio/common/deepspeed_aio_types.h b/toolbox/DeepSpeed/v0.15.3/patches/csrc/aio/common/deepspeed_aio_types.h new file mode 100644 index 0000000000000000000000000000000000000000..adebcc63f42b4de8b9d871ff90fc738e555c40db --- /dev/null +++ b/toolbox/DeepSpeed/v0.15.3/patches/csrc/aio/common/deepspeed_aio_types.h @@ -0,0 +1,76 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. */ +/* All Rights Reserved. */ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +/* +Functionality for swapping optimizer tensors to/from (NVMe) storage devices. +*/ + +#include +#include + +#include +#include + +using namespace std; + +struct deepspeed_aio_latency_t { + float _min_usec; + float _max_usec; + float _avg_usec; + + void dump(const std::string tag); + void accumulate(const deepspeed_aio_latency_t&); + void scale(const float value); +}; + +struct deepspeed_aio_perf_t { + deepspeed_aio_latency_t _submit; + deepspeed_aio_latency_t _complete; + float _e2e_usec; + float _e2e_rate_GB; +}; + +struct deepspeed_aio_config_t { + const int _block_size; + const int _queue_depth; + const bool _single_submit; + const bool _overlap_events; + const bool _lock_memory; + + deepspeed_aio_config_t(); + deepspeed_aio_config_t(const int block_size, + const int queue_depth, + const bool single_submit, + const bool overlap_events, + const bool lock_memory); +}; + +struct aio_context { + io_context_t _io_ctxt; + std::vector _io_events; + std::vector _iocbs; + int _block_size; + int _queue_depth; + + aio_context(const int block_size, const int queue_depth); + ~aio_context(); +}; diff --git a/toolbox/DeepSpeed/v0.15.3/patches/csrc/aio/common/deepspeed_aio_utils.cpp b/toolbox/DeepSpeed/v0.15.3/patches/csrc/aio/common/deepspeed_aio_utils.cpp new file mode 100644 index 0000000000000000000000000000000000000000..58455db23b9743b46c84be7a8b47b3b202dfea2f --- /dev/null +++ b/toolbox/DeepSpeed/v0.15.3/patches/csrc/aio/common/deepspeed_aio_utils.cpp @@ -0,0 +1,143 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. */ +/* All Rights Reserved. */ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +/* +Functionality for swapping optimizer tensors to/from (NVMe) storage devices. +*/ + +#include +#include + +#include "deepspeed_aio_utils.h" + +using namespace std; + +const int c_block_size = 128 * 1024; +const int c_io_queue_depth = 8; + +io_xfer_ctxt::io_xfer_ctxt(const int fd, + const int64_t file_offset, + const int64_t num_bytes, + const void* buffer) + : _fd(fd), _base_offset(file_offset), _mem_buffer(buffer), _num_bytes(num_bytes) +{ +} + +io_prep_context::io_prep_context(const bool read_op, + const std::unique_ptr& xfer_ctxt, + const size_t block_size, + const std::vector* iocbs) + : _read_op(read_op), _xfer_ctxt(xfer_ctxt), _block_size(block_size), _iocbs(iocbs) +{ +} + +void io_prep_context::prep_iocbs(const int n_iocbs, + const size_t num_bytes, + const void* start_buffer, + const int64_t start_offset) +{ + assert(static_cast(n_iocbs) <= _iocbs->size()); + for (auto i = 0; i < n_iocbs; ++i) { + const auto shift = i * _block_size; + const auto xfer_buffer = (char*)start_buffer + _xfer_ctxt->_base_offset + shift; + const auto xfer_offset = _xfer_ctxt->_base_offset + start_offset + shift; + auto byte_count = _block_size; + if ((shift + _block_size) > num_bytes) { byte_count = num_bytes - shift; } + + if (_read_op) { + io_prep_pread(_iocbs->at(i), _xfer_ctxt->_fd, xfer_buffer, byte_count, xfer_offset); + } else { + io_prep_pwrite(_iocbs->at(i), _xfer_ctxt->_fd, xfer_buffer, byte_count, xfer_offset); + } + } +} + +io_prep_generator::io_prep_generator(const bool read_op, + const std::unique_ptr& xfer_ctxt, + const size_t block_size) + : _read_op(read_op), + _xfer_ctxt(xfer_ctxt), + _block_size(block_size), + _remaining_bytes(xfer_ctxt->_num_bytes), + _next_iocb_index(0) +{ + _num_io_blocks = + static_cast(ceil(static_cast(xfer_ctxt->_num_bytes) / block_size)); + _remaining_io_blocks = _num_io_blocks; +} + +int io_prep_generator::prep_iocbs(const int n_iocbs, std::vector* iocbs) +{ + if ((_remaining_bytes) == 0 || (_remaining_io_blocks == 0)) { + assert(static_cast(_remaining_bytes) == _remaining_io_blocks); + return 0; + } + + assert(static_cast(n_iocbs) <= iocbs->size()); + + auto actual_n_iocbs = min(static_cast(n_iocbs), _remaining_io_blocks); + for (auto i = 0; i < actual_n_iocbs; ++i, ++_next_iocb_index) { + const auto xfer_offset = _xfer_ctxt->_base_offset + (_next_iocb_index * _block_size); + const auto xfer_buffer = (char*)_xfer_ctxt->_mem_buffer + xfer_offset; + const auto num_bytes = min(static_cast(_block_size), _remaining_bytes); + + if (_read_op) { + io_prep_pread(iocbs->at(i), _xfer_ctxt->_fd, xfer_buffer, num_bytes, xfer_offset); + } else { + io_prep_pwrite(iocbs->at(i), _xfer_ctxt->_fd, xfer_buffer, num_bytes, xfer_offset); + } + _remaining_bytes -= num_bytes; + } + _remaining_io_blocks -= actual_n_iocbs; + + return actual_n_iocbs; +} + +int get_file_size(const char* filename, int64_t& size) +{ + struct stat st; + if (stat(filename, &st) == -1) { return -1; } + size = st.st_size; + return 0; +} + +void* ds_page_aligned_alloc(const int64_t size, const bool lock) +{ + void* ptr; + int retval; + + retval = posix_memalign(&ptr, (size_t)sysconf(_SC_PAGESIZE), size); + if (retval) { return nullptr; } + + if (lock == false) { return ptr; } + + auto mlock_ret = mlock(ptr, size); + if (mlock_ret != 0) { + auto mlock_error = errno; + std::cerr << "mlock failed to allocate " << size << " bytes with error no " << mlock_error + << " msg " << strerror(mlock_error) << std::endl; + free(ptr); + return nullptr; + } + + return ptr; +} diff --git a/toolbox/DeepSpeed/v0.15.3/patches/csrc/aio/py_lib/deepspeed_py_aio.cpp b/toolbox/DeepSpeed/v0.15.3/patches/csrc/aio/py_lib/deepspeed_py_aio.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f04fd9200c9bf6305f000078b0f1810401ca1080 --- /dev/null +++ b/toolbox/DeepSpeed/v0.15.3/patches/csrc/aio/py_lib/deepspeed_py_aio.cpp @@ -0,0 +1,137 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. */ +/* All Rights Reserved. */ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +/* +Functionality for swapping optimizer tensors to/from (NVMe) storage devices. +*/ + +#include +#include +#include + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "deepspeed_py_aio.h" + +using namespace std; +using namespace std::chrono; + +#define DEBUG_DS_AIO_READ 0 +#define DEBUG_DS_AIO_WRITE 0 + +static const std::string c_library_name = "deepspeed_aio"; + +int deepspeed_py_aio_write(const torch::Tensor& buffer, + const char* filename, + const int block_size, + const int queue_depth, + const bool single_submit, + const bool overlap_events, + const bool validate) +{ + const auto start_time = std::chrono::high_resolution_clock::now(); + deepspeed_aio_config_t config(block_size, queue_depth, single_submit, overlap_events, false); + + const auto fd = open_file(filename, false); + if (fd == -1) { return -1; } + + auto write_buffer = (char*)buffer.data_ptr(); + const auto num_write_bytes = static_cast(buffer.nbytes()); + std::unique_ptr xfer_ctxt(new io_xfer_ctxt(fd, 0, num_write_bytes, write_buffer)); + std::unique_ptr aio_ctxt(new aio_context(config._block_size, config._queue_depth)); + + if (config._overlap_events) { + do_aio_operation_overlap(false, aio_ctxt, xfer_ctxt, &config, nullptr); + } else { + do_aio_operation_sequential(false, aio_ctxt, xfer_ctxt, &config, nullptr); + } + const std::chrono::duration aio_time = + std::chrono::high_resolution_clock::now() - start_time; + + close(fd); + + if (validate) { validate_aio_operation(false, filename, write_buffer, num_write_bytes); } + + const std::chrono::duration fn_time = + std::chrono::high_resolution_clock::now() - start_time; + std::cout << "Elapsed time(usec): " << "aio = " << aio_time.count() * 1e6 + << " call = " << fn_time.count() * 1e6 << std::endl; + return 0; +} + +int deepspeed_py_aio_read(torch::Tensor& buffer, + const char* filename, + const int block_size, + const int queue_depth, + const bool single_submit, + const bool overlap_events, + const bool validate) +{ + const auto start_time = std::chrono::high_resolution_clock::now(); + int64_t num_file_bytes; + if (-1 == get_file_size(filename, num_file_bytes)) { + const auto error_code = errno; + report_file_error(filename, " fstat for read", error_code); + return -1; + } + + deepspeed_aio_config_t config(block_size, queue_depth, single_submit, overlap_events, false); + const auto fd = open_file(filename, true); + if (fd == -1) { return -1; } + + auto read_buffer = (char*)buffer.data_ptr(); + assert(static_cast(buffer.nbytes()) == num_file_bytes); + + std::unique_ptr xfer_ctxt(new io_xfer_ctxt(fd, 0, num_file_bytes, read_buffer)); + std::unique_ptr aio_ctxt(new aio_context(config._block_size, config._queue_depth)); + + if (config._overlap_events) { + do_aio_operation_overlap(true, aio_ctxt, xfer_ctxt, &config, nullptr); + } else { + do_aio_operation_sequential(true, aio_ctxt, xfer_ctxt, &config, nullptr); + } + const std::chrono::duration aio_time = + std::chrono::high_resolution_clock::now() - start_time; + + close(fd); + + if (validate) { validate_aio_operation(true, filename, read_buffer, num_file_bytes); } + + const std::chrono::duration fn_time = + std::chrono::high_resolution_clock::now() - start_time; + std::cout << "Elapsed time(usec): " << "aio = " << aio_time.count() * 1e6 + << " call = " << fn_time.count() * 1e6 << std::endl; + return 0; +} diff --git a/toolbox/DeepSpeed/v0.15.3/patches/csrc/aio/py_lib/deepspeed_py_io_handle.cpp b/toolbox/DeepSpeed/v0.15.3/patches/csrc/aio/py_lib/deepspeed_py_io_handle.cpp new file mode 100644 index 0000000000000000000000000000000000000000..985edb4ab596130dfaac5eead73717030f9d6141 --- /dev/null +++ b/toolbox/DeepSpeed/v0.15.3/patches/csrc/aio/py_lib/deepspeed_py_io_handle.cpp @@ -0,0 +1,322 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. */ +/* All Rights Reserved. */ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +/* +Functionality for swapping optimizer tensors to/from (NVMe) storage devices. +*/ + +#include "deepspeed_py_io_handle.h" +#include + +using namespace std; + +static void _start_aio_thread(std::shared_ptr ctxt) { ctxt->run(); } + +deepspeed_io_handle_t::deepspeed_io_handle_t(const int block_size, + const int queue_depth, + const bool single_submit, + const bool overlap_events, + const int intra_op_parallelism) + : _aio_ctxt(new aio_context(block_size, queue_depth)), + _single_submit(single_submit), + _overlap_events(overlap_events), + _intra_op_parallelism(intra_op_parallelism), + _aio_config(block_size, queue_depth, single_submit, overlap_events, false), + _num_pending_ops(0), + _pinned_tensor_mgr(new deepspeed_pin_tensor_t()) +{ + for (auto i = 0; i < intra_op_parallelism; ++i) { + _thread_contexts.push_back(std::make_shared(i, _aio_config)); + } + + for (auto& ctxt : _thread_contexts) { + _threads.push_back(std::thread(_start_aio_thread, ctxt)); + } +} + +deepspeed_io_handle_t::~deepspeed_io_handle_t() +{ + _stop_threads(); + for (auto& thr : _threads) { thr.join(); } +} + +const int deepspeed_io_handle_t::get_block_size() const +{ + return _aio_ctxt ? _aio_ctxt->_block_size : -1; +} + +const int deepspeed_io_handle_t::get_queue_depth() const +{ + return _aio_ctxt ? _aio_ctxt->_queue_depth : -1; +} + +const bool deepspeed_io_handle_t::get_single_submit() const { return _single_submit; } + +const bool deepspeed_io_handle_t::get_overlap_events() const { return _overlap_events; } + +const int deepspeed_io_handle_t::get_intra_op_parallelism() const { return _intra_op_parallelism; } + +int deepspeed_io_handle_t::read(torch::Tensor& buffer, const char* filename, const bool validate) +{ + const auto start_time = std::chrono::high_resolution_clock::now(); + + assert(_aio_ctxt); + + int64_t num_file_bytes; + if (-1 == get_file_size(filename, num_file_bytes)) { + const auto error_code = errno; + report_file_error(filename, " fstat for read", error_code); + return -1; + } + assert(static_cast(buffer.nbytes()) == num_file_bytes); + + const auto fd = open_file(filename, true); + if (fd == -1) { return -1; } + + auto read_buffer = (char*)buffer.data_ptr(); + std::unique_ptr xfer_ctxt(new io_xfer_ctxt(fd, 0, num_file_bytes, read_buffer)); + + if (_aio_config._overlap_events) { + do_aio_operation_overlap(true, _aio_ctxt, xfer_ctxt, &_aio_config, nullptr); + } else { + do_aio_operation_sequential(true, _aio_ctxt, xfer_ctxt, &_aio_config, nullptr); + } + + close(fd); + const std::chrono::duration aio_time = + std::chrono::high_resolution_clock::now() - start_time; + + if (validate) { validate_aio_operation(true, filename, read_buffer, num_file_bytes); } + const std::chrono::duration fn_time = + std::chrono::high_resolution_clock::now() - start_time; + std::cout << "Elapsed time(usec): " << "aio = " << aio_time.count() * 1e6 + << " call = " << fn_time.count() * 1e6 << std::endl; + return 0; +} + +int deepspeed_io_handle_t::write(const torch::Tensor& buffer, + const char* filename, + const bool validate) +{ + assert(_aio_ctxt); + + const auto start_time = std::chrono::high_resolution_clock::now(); + + const auto fd = open_file(filename, false); + if (fd == -1) { return -1; } + + auto write_buffer = (char*)buffer.data_ptr(); + const auto num_write_bytes = static_cast(buffer.nbytes()); + std::unique_ptr xfer_ctxt(new io_xfer_ctxt(fd, 0, num_write_bytes, write_buffer)); + + if (_aio_config._overlap_events) { + do_aio_operation_overlap(false, _aio_ctxt, xfer_ctxt, &_aio_config, nullptr); + } else { + do_aio_operation_sequential(false, _aio_ctxt, xfer_ctxt, &_aio_config, nullptr); + } + const std::chrono::duration aio_time = + std::chrono::high_resolution_clock::now() - start_time; + + close(fd); + + if (validate) { validate_aio_operation(false, filename, write_buffer, num_write_bytes); } + + const std::chrono::duration fn_time = + std::chrono::high_resolution_clock::now() - start_time; + std::cout << "Elapsed time(usec): " << "aio = " << aio_time.count() * 1e6 + << " call = " << fn_time.count() * 1e6 << std::endl; + return 0; +} + +void deepspeed_io_handle_t::_schedule_aio_work(std::shared_ptr scheduled_op) +{ + for (auto& ctxt : _thread_contexts) { + { + std::lock_guard lock(ctxt->_work_sync._mutex); + ctxt->_work_queue.push(scheduled_op); + } + ctxt->_work_sync._cond_var.notify_one(); + } + _num_pending_ops++; +} + +std::shared_ptr deepspeed_io_handle_t::_wait_for_aio_work() +{ + std::shared_ptr completed_op = nullptr; + for (auto& ctxt : _thread_contexts) { + std::unique_lock lock(ctxt->_complete_sync._mutex); + ctxt->_complete_sync._cond_var.wait(lock, + [ctxt] { return !ctxt->_complete_queue.empty(); }); + completed_op = ctxt->_complete_queue.front(); + ctxt->_complete_queue.pop(); + } + return completed_op; +} + +void deepspeed_io_handle_t::_stop_threads() +{ + assert(0 == _num_pending_ops); + for (auto& ctxt : _thread_contexts) { + { + std::lock_guard lock(ctxt->_work_sync._mutex); + ctxt->_time_to_exit = true; + } + ctxt->_work_sync._cond_var.notify_one(); + } +} + +int deepspeed_io_handle_t::wait() +{ + assert(_num_pending_ops > 0); + auto num_completed_ops = 0; + + while (_num_pending_ops > 0) { + auto completed_op = _wait_for_aio_work(); + + if (completed_op->_validate) { completed_op->validate(); } + + completed_op->finish(); + + close(completed_op->_fd); + + --_num_pending_ops; + ++num_completed_ops; + } + + return num_completed_ops; +} + +bool deepspeed_io_handle_t::_is_valid_parallel_aio_op(const bool read_op, const int64_t num_bytes) +{ + const auto op_string = read_op ? "Read" : "Write"; + if (num_bytes % get_intra_op_parallelism()) { + std::cout << "deepspeed_aio failure: parallel " << op_string << " num_bytes = " << num_bytes + << " not divisible by thread count = " << get_intra_op_parallelism() << std::endl; + return false; + } + + return true; +} + +std::shared_ptr deepspeed_io_handle_t::_create_io_op_desc( + const bool read_op, + const torch::Tensor& buffer, + const int fd, + const char* filename, + const int64_t file_num_bytes, + const bool validate) +{ + return std::make_shared(read_op, + buffer, + _pinned_tensor_mgr, + fd, + filename, + file_num_bytes, + _intra_op_parallelism, + validate); +} + +int deepspeed_io_handle_t::pread(const torch::Tensor& buffer, + const char* filename, + const bool validate, + const bool async) +{ + int64_t num_file_bytes; + if (-1 == get_file_size(filename, num_file_bytes)) { + const auto error_code = errno; + report_file_error(filename, " fstat for read", error_code); + return -1; + } + const auto buffer_bytes = static_cast(buffer.nbytes()); + if (buffer_bytes != num_file_bytes) { + std::cout << filename << ": buffer nbytes != file bytes " << buffer_bytes + << " != " << num_file_bytes << std::endl; + } + assert(buffer_bytes == num_file_bytes); + assert((num_file_bytes % _intra_op_parallelism) == 0); + + if (!_is_valid_parallel_aio_op(true, num_file_bytes)) { return -1; } + + const auto fd = open_file(filename, true); + if (fd == -1) { return -1; } + + auto scheduled_op = _create_io_op_desc(true, buffer, fd, filename, num_file_bytes, validate); + + _schedule_aio_work(scheduled_op); + + if (async) { return 0; } + + return wait(); +} + +int deepspeed_io_handle_t::pwrite(const torch::Tensor& buffer, + const char* filename, + const bool validate, + const bool async) +{ + const auto num_write_bytes = static_cast(buffer.nbytes()); + assert((num_write_bytes % _intra_op_parallelism) == 0); + + if (!_is_valid_parallel_aio_op(false, num_write_bytes)) { return -1; } + + const auto fd = open_file(filename, false); + if (fd == -1) { return -1; } + + auto scheduled_op = _create_io_op_desc(false, buffer, fd, filename, num_write_bytes, validate); + + _schedule_aio_work(scheduled_op); + + if (async) { return 0; } + + return wait(); +} + +int deepspeed_io_handle_t::sync_pread(torch::Tensor& buffer, const char* filename) +{ + return pread(buffer, filename, false, false); +} + +int deepspeed_io_handle_t::sync_pwrite(const torch::Tensor& buffer, const char* filename) +{ + return pwrite(buffer, filename, false, false); +} + +int deepspeed_io_handle_t::async_pread(torch::Tensor& buffer, const char* filename) +{ + return pread(buffer, filename, false, true); +} + +int deepspeed_io_handle_t::async_pwrite(const torch::Tensor& buffer, const char* filename) +{ + return pwrite(buffer, filename, false, true); +} + +at::Tensor deepspeed_io_handle_t::new_cpu_locked_tensor(const int64_t num_elem, + const torch::Tensor& example_tensor) +{ + return _pinned_tensor_mgr->alloc(num_elem, example_tensor.scalar_type()); +} + +bool deepspeed_io_handle_t::free_cpu_locked_tensor(torch::Tensor& locked_tensor) +{ + return _pinned_tensor_mgr->free(locked_tensor); +} diff --git a/toolbox/DeepSpeed/v0.15.3/patches/csrc/gds/py_lib/deepspeed_gds_utils.h b/toolbox/DeepSpeed/v0.15.3/patches/csrc/gds/py_lib/deepspeed_gds_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..b2993d266d2706a90e2e6c476f8716f7547ee5a6 --- /dev/null +++ b/toolbox/DeepSpeed/v0.15.3/patches/csrc/gds/py_lib/deepspeed_gds_utils.h @@ -0,0 +1,112 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. */ +/* All Rights Reserved. */ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include + +// CUDA/cuFile includes +#include +#include +#ifdef __ILUVATAR__ +#include "nvcufile.h" +#else +#include "cufile.h" +#endif + +// Macro for checking cuda errors following a cuda launch or api call +#define cudaCheckError() \ + { \ + cudaError_t e = cudaGetLastError(); \ + if (e != cudaSuccess) { \ + printf("Cuda failure %s:%d: '%s'\n", __FILE__, __LINE__, cudaGetErrorString(e)); \ + exit(EXIT_FAILURE); \ + } \ + } + +#define check_cudadrivercall(fn) \ + do { \ + CUresult res = fn; \ + if (res != CUDA_SUCCESS) { \ + const char* str = nullptr; \ + cuGetErrorName(res, &str); \ + std::cerr << "cuda driver api call failed " << #fn << " res : " << res << ", " \ + << __LINE__ << ":" << str << std::endl; \ + std::cerr << "EXITING program!!!" << std::endl; \ + exit(1); \ + } \ + } while (0) + +#define check_cudaruntimecall(fn) \ + do { \ + cudaError_t res = fn; \ + if (res != cudaSuccess) { \ + const char* str = cudaGetErrorName(res); \ + std::cerr << "cuda runtime api call failed " << #fn << __LINE__ << ":" << str \ + << std::endl; \ + std::cerr << "EXITING program!!!" << std::endl; \ + exit(1); \ + } \ + } while (0) + +#define check_cuFileCall(fn, api_msg) \ + do { \ + CUfileError_t status = fn; \ + if (status.err != CU_FILE_SUCCESS) { \ + std::cout << api_msg << " failed with error " << CUFILE_ERRSTR(status.err) \ + << std::endl; \ + exit(EXIT_FAILURE); \ + } \ + } while (0) + +// +// cuda driver error description +// +static inline const char* GetCuErrorString(CUresult curesult) +{ + const char* descp; + if (cuGetErrorName(curesult, &descp) != CUDA_SUCCESS) descp = "unknown cuda error"; + return descp; +} + +// +// cuFile APIs return both cuFile specific error codes as well as POSIX error codes +// for ease, the below template can be used for getting the error description depending +// on its type. + +// POSIX +template ::value, std::nullptr_t>::type = nullptr> +std::string cuFileGetErrorString(T status) +{ + status = std::abs(status); + return IS_CUFILE_ERR(status) ? std::string(CUFILE_ERRSTR(status)) + : std::string(std::strerror(status)); +} + +// CUfileError_t +template ::value, std::nullptr_t>::type = nullptr> +std::string cuFileGetErrorString(T status) +{ + std::string errStr = cuFileGetErrorString(static_cast(status.err)); + if (IS_CUDA_ERR(status)) errStr.append(".").append(GetCuErrorString(status.cu_err)); + return errStr; +} diff --git a/toolbox/DeepSpeed/v0.15.3/patches/csrc/includes/StopWatch.h b/toolbox/DeepSpeed/v0.15.3/patches/csrc/includes/StopWatch.h new file mode 100644 index 0000000000000000000000000000000000000000..29473bc0f66d4f27ccd5bde43ad44db9c9ad4ca3 --- /dev/null +++ b/toolbox/DeepSpeed/v0.15.3/patches/csrc/includes/StopWatch.h @@ -0,0 +1,120 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. */ +/* All Rights Reserved. */ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once +#ifdef _WIN32 +#include +#else +#include +#endif + +#ifdef _WIN32 + +class Stopwatch { +private: + float m_total_time; + LARGE_INTEGER m_start_time; + +public: + Stopwatch() { m_total_time = 0.0; } + + ~Stopwatch() {} + + void Reset() { m_total_time = 0.0; } + + void Start() { QueryPerformanceCounter(&m_start_time); } + + void Restart() + { + m_total_time = 0.0; + QueryPerformanceCounter(&m_start_time); + } + + void Stop() + { + LARGE_INTEGER frequency; + LARGE_INTEGER stop_time; + QueryPerformanceFrequency(&frequency); + QueryPerformanceCounter(&stop_time); + m_total_time += + ((float)(stop_time.QuadPart - m_start_time.QuadPart) / (float)frequency.QuadPart); + } + + float GetTimeInSeconds() { return m_total_time; } +}; + +#else + +class Stopwatch { +private: + float m_total_time; + struct timespec m_start_time; + bool m_is_started; + +public: + Stopwatch() + { + m_total_time = 0.0; + m_is_started = false; + } + + ~Stopwatch() {} + + void Reset() { m_total_time = 0.0; } + + void Start() + { + clock_gettime(CLOCK_MONOTONIC, &m_start_time); + m_is_started = true; + } + + void Restart() + { + m_total_time = 0.0; + clock_gettime(CLOCK_MONOTONIC, &m_start_time); + m_is_started = true; + } + + void Stop() + { + if (m_is_started) { + m_is_started = false; + + struct timespec end_time; + clock_gettime(CLOCK_MONOTONIC, &end_time); + + m_total_time += (float)(end_time.tv_sec - m_start_time.tv_sec) + + (float)(end_time.tv_nsec - m_start_time.tv_nsec) / 1e9; + } + } + + float GetTimeInSeconds() + { + if (m_is_started) { + Stop(); + Start(); + } + return m_total_time; + } +}; + +#endif diff --git a/toolbox/DeepSpeed/v0.15.3/patches/csrc/includes/conversion_utils.h b/toolbox/DeepSpeed/v0.15.3/patches/csrc/includes/conversion_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..42a00a7ee21e5450ebc977dc07fcdbbe928234f4 --- /dev/null +++ b/toolbox/DeepSpeed/v0.15.3/patches/csrc/includes/conversion_utils.h @@ -0,0 +1,657 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. */ +/* All Rights Reserved. */ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include "ds_kernel_utils.h" + +#include + +#ifdef BF16_AVAILABLE +#include +#endif + +namespace conversion { + +// Basic primitive for constructing conversions +template +DS_D_INLINE TO to(FROM val) +{ + return to(val); +} + +// Specializations + +/********************* Identity Conversions *********************/ +/* +Identity conversions are useful in templated functions where we might have +a fixed destination type. For example, I might have a kernel that accepts +__half, __nv_bfloat16, and float but always want to do the core computation +at floating point: + +T mem_value = input[idx]; +float compute_value = conversion::to(mem_value); + +In practice, we should be able to elide the second template parameter: +float compute_val = conversion::to(mem_value); + +In this case, we need an implementation to handle the T = float case + +NOTE: The type inferencing system appears to be unable to handle inferring the first +template parameter, even in the trivial case. +*/ + +// Floating point types +template <> +DS_D_INLINE double to(double val) +{ + return val; +} +template <> +DS_D_INLINE float to(float val) +{ + return val; +} +template <> +DS_D_INLINE __half to(__half val) +{ + return val; +} +#ifdef BF16_AVAILABLE +template <> +DS_D_INLINE __nv_bfloat16 to(__nv_bfloat16 val) +{ + return val; +} +#endif + +// Integer types +template <> +DS_D_INLINE int8_t to(int8_t val) +{ + return val; +} +template <> +DS_D_INLINE uint8_t to(uint8_t val) +{ + return val; +} +template <> +DS_D_INLINE int16_t to(int16_t val) +{ + return val; +} +template <> +DS_D_INLINE uint16_t to(uint16_t val) +{ + return val; +} +template <> +DS_D_INLINE int32_t to(int32_t val) +{ + return val; +} +template <> +DS_D_INLINE uint32_t to(uint32_t val) +{ + return val; +} +template <> +DS_D_INLINE int64_t to(int64_t val) +{ + return val; +} +template <> +DS_D_INLINE uint64_t to(uint64_t val) +{ + return val; +} + +// TODO: evaluate if we want bools + +/********************* To Double Conversions *********************/ + +// * to double variants + +// Would normally like to not use C cast, but this is an important enough conversion +// to keep +template <> +DS_D_INLINE double to(float val) +{ +#ifdef PTX_AVAILABLE + double ret_val; + asm("ctv.rn.f64.f32 %0, %1;\n" : "=d"(ret_val) : "f"(val)); + return ret_val; +#else + return double(val); +#endif +} +// Note: there is a CVT instruction for __half -> double, but there's no inline interface +// for passing a single half value +template <> +DS_D_INLINE double to(__half val) +{ + return to(__half2float(val)); +} +template <> +DS_D_INLINE double to(int64_t val) +{ + return __ll2double_rn(val); +} +template <> +DS_D_INLINE double to(int32_t val) +{ + return __int2double_rn(val); +} +template <> +DS_D_INLINE double to(int16_t val) +{ + return __int2double_rn(val); +} +template <> +DS_D_INLINE double to(int8_t val) +{ + return __int2double_rn(val); +} +template <> +DS_D_INLINE double to(uint64_t val) +{ + return __ull2double_rn(val); +} +template <> +DS_D_INLINE double to(uint32_t val) +{ + return __uint2double_rn(val); +} +template <> +DS_D_INLINE double to(uint16_t val) +{ + return __uint2double_rn(val); +} +template <> +DS_D_INLINE double to(uint8_t val) +{ + return __uint2double_rn(val); +} + +// Same applies here +#ifdef BF16_AVAILABLE +template <> +DS_D_INLINE double to(__nv_bfloat16 val) +{ + return to(__bfloat162float(val)); +} +#endif + +/********************* To Float Conversions *********************/ + +template <> +DS_D_INLINE float to(double val) +{ + return __double2float_rn(val); +} +template <> +DS_D_INLINE float to(__half val) +{ + return __half2float(val); +} +template <> +DS_D_INLINE float to(int64_t val) +{ + return __ll2float_rn(val); +} +template <> +DS_D_INLINE float to(int32_t val) +{ + return __int2float_rn(val); +} +template <> +DS_D_INLINE float to(int16_t val) +{ + return __int2float_rn(val); +} +template <> +DS_D_INLINE float to(int8_t val) +{ + return __int2float_rn(val); +} +template <> +DS_D_INLINE float to(uint64_t val) +{ + return __ull2float_rn(val); +} +template <> +DS_D_INLINE float to(uint32_t val) +{ + return __uint2float_rn(val); +} +template <> +DS_D_INLINE float to(uint16_t val) +{ + return __uint2float_rn(val); +} +template <> +DS_D_INLINE float to(uint8_t val) +{ + return __uint2float_rn(val); +} + +#ifdef BF16_AVAILABLE +template <> +DS_D_INLINE float to(__nv_bfloat16 val) +{ + return __bfloat162float(val); +} +#endif + +/********************* To Float2 Conversions *********************/ +template <> +DS_D_INLINE float2 to(__half2 val) +{ + return __half22float2(val); +} + +#ifdef BF16_AVAILABLE +template <> +DS_D_INLINE float2 to(__nv_bfloat162 val) +{ + return __bfloat1622float2(val); +} +#endif + +/********************* To Half Conversions *********************/ +// template <> +// DS_D_INLINE __half to(double val) +// { +// #ifdef __HIP_PLATFORM_HCC__ +// float val_f = __double2float_rn(val); +// return __float2half(val_f); +// #else +// return __double2half(val); +// #endif +// } +template <> +DS_D_INLINE __half to(float val) +{ + return __float2half(val); +} +template <> +DS_D_INLINE __half to(int64_t val) +{ + return __ll2half_rn(val); +} +template <> +DS_D_INLINE __half to(int32_t val) +{ + return __int2half_rn(val); +} +template <> +DS_D_INLINE __half to(int16_t val) +{ + return __short2half_rn(val); +} +template <> +DS_D_INLINE __half to(int8_t val) +{ + return __int2half_rn(val); +} +template <> +DS_D_INLINE __half to(uint64_t val) +{ + return __ull2half_rn(val); +} +template <> +DS_D_INLINE __half to(uint32_t val) +{ + return __uint2half_rn(val); +} +template <> +DS_D_INLINE __half to(uint16_t val) +{ + return __ushort2half_rn(val); +} +template <> +DS_D_INLINE __half to(uint8_t val) +{ + return __uint2half_rn(val); +} + +#ifdef BF16_AVAILABLE +// No direct conversion +template <> +DS_D_INLINE __half to(__nv_bfloat16 val) +{ + return to<__half>(to(val)); +} +#endif + +/********************* To Half2 Conversions *********************/ +template <> +DS_D_INLINE __half2 to(float2 val) +{ + return __float22half2_rn(val); +} +template <> +DS_D_INLINE __half2 to(float val) +{ + return __float2half2_rn(val); +} + +#ifdef BF16_AVAILABLE +// No direct conversion +template <> +DS_D_INLINE __half2 to(__nv_bfloat162 val) +{ + return to<__half2>(to(val)); +} +#endif + +/********************* To BF16 Conversions *********************/ +#ifdef BF16_AVAILABLE +template <> +DS_D_INLINE __nv_bfloat16 to(double val) +{ + return __double2bfloat16(val); +} +template <> +DS_D_INLINE __nv_bfloat16 to(float val) +{ + return __float2bfloat16(val); +} +template <> +DS_D_INLINE __nv_bfloat16 to(int64_t val) +{ + return __ll2bfloat16_rn(val); +} +template <> +DS_D_INLINE __nv_bfloat16 to(int32_t val) +{ + return __int2bfloat16_rn(val); +} +template <> +DS_D_INLINE __nv_bfloat16 to(int16_t val) +{ + return __short2bfloat16_rn(val); +} +template <> +DS_D_INLINE __nv_bfloat16 to(int8_t val) +{ + return __int2bfloat16_rn(val); +} +template <> +DS_D_INLINE __nv_bfloat16 to(uint64_t val) +{ + return __ull2bfloat16_rn(val); +} +template <> +DS_D_INLINE __nv_bfloat16 to(uint32_t val) +{ + return __uint2bfloat16_rn(val); +} +template <> +DS_D_INLINE __nv_bfloat16 to(uint16_t val) +{ + return __ushort2bfloat16_rn(val); +} +template <> +DS_D_INLINE __nv_bfloat16 to(uint8_t val) +{ + return __uint2bfloat16_rn(val); +} +#endif + +/********************* To BF162 Conversions *********************/ +#ifdef BF16_AVAILABLE +template <> +DS_D_INLINE __nv_bfloat162 to(float2 val) +{ + return __float22bfloat162_rn(val); +} +template <> +DS_D_INLINE __nv_bfloat162 to(float val) +{ + return __float2bfloat162_rn(val); +} +template <> +DS_D_INLINE __nv_bfloat162 to(__half2 val) +{ + return to<__nv_bfloat162>(to(val)); +} +#endif + +/********************* To INT64_T Conversions *********************/ +template <> +DS_D_INLINE int64_t to(double val) +{ + return __double2ll_rn(val); +} +template <> +DS_D_INLINE int64_t to(float val) +{ + return __float2ll_rn(val); +} +template <> +DS_D_INLINE int64_t to(__half val) +{ + return __half2ll_rn(val); +} +// No direct support for integer casts at the C++ level and I don't feel they're so important +// to demand an PTX at this time + +#ifdef BF16_AVAILABLE +template <> +DS_D_INLINE int64_t to(__nv_bfloat16 val) +{ + return __bfloat162ll_rn(val); +} +#endif + +/********************* To INT32_T Conversions *********************/ +template <> +DS_D_INLINE int32_t to(double val) +{ + return __double2int_rn(val); +} +template <> +DS_D_INLINE int32_t to(float val) +{ + return __float2int_rn(val); +} +template <> +DS_D_INLINE int32_t to(__half val) +{ + return __half2int_rn(val); +} +// No direct support for integer casts at the C++ level and I don't feel they're so important +// to demand an PTX at this time + +#ifdef BF16_AVAILABLE +template <> +DS_D_INLINE int32_t to(__nv_bfloat16 val) +{ + return __bfloat162int_rn(val); +} +#endif + +/********************* To INT16_T Conversions *********************/ +template <> +DS_D_INLINE int16_t to(double val) +{ + return __double2int_rn(val); +} +template <> +DS_D_INLINE int16_t to(float val) +{ + return __float2int_rn(val); +} +template <> +DS_D_INLINE int16_t to(__half val) +{ + return __half2int_rn(val); +} +// No direct support for integer casts at the C++ level and I don't feel they're so important +// to demand an PTX at this time + +#ifdef BF16_AVAILABLE +template <> +DS_D_INLINE int16_t to(__nv_bfloat16 val) +{ + return __bfloat162int_rn(val); +} +#endif + +/********************* To INT8_T Conversions *********************/ +template <> +DS_D_INLINE int8_t to(double val) +{ + return __double2int_rn(val); +} +template <> +DS_D_INLINE int8_t to(float val) +{ + return __float2int_rn(val); +} +template <> +DS_D_INLINE int8_t to(__half val) +{ + return __half2int_rn(val); +} +// No direct support for integer casts at the C++ level and I don't feel they're so important +// to demand an PTX at this time + +#ifdef BF16_AVAILABLE +template <> +DS_D_INLINE int8_t to(__nv_bfloat16 val) +{ + return __bfloat162int_rn(val); +} +#endif + +/********************* To UINT64_T Conversions *********************/ +template <> +DS_D_INLINE uint64_t to(double val) +{ + return __double2ull_rn(val); +} +template <> +DS_D_INLINE uint64_t to(float val) +{ + return __float2ull_rn(val); +} +template <> +DS_D_INLINE uint64_t to(__half val) +{ + return __half2ull_rn(val); +} +// No direct support for integer casts at the C++ level and I don't feel they're so important +// to demand an PTX at this time + +#ifdef BF16_AVAILABLE +template <> +DS_D_INLINE uint64_t to(__nv_bfloat16 val) +{ + return __bfloat162ull_rn(val); +} +#endif + +/********************* To UINT32_T Conversions *********************/ +template <> +DS_D_INLINE uint32_t to(double val) +{ + return __double2uint_rn(val); +} +template <> +DS_D_INLINE uint32_t to(float val) +{ + return __float2uint_rn(val); +} +template <> +DS_D_INLINE uint32_t to(__half val) +{ + return __half2uint_rn(val); +} +// No direct support for integer casts at the C++ level and I don't feel they're so important +// to demand an PTX at this time + +#ifdef BF16_AVAILABLE +template <> +DS_D_INLINE uint32_t to(__nv_bfloat16 val) +{ + return __bfloat162uint_rn(val); +} +#endif + +/********************* To UINT16_T Conversions *********************/ +template <> +DS_D_INLINE uint16_t to(double val) +{ + return __double2uint_rn(val); +} +template <> +DS_D_INLINE uint16_t to(float val) +{ + return __float2uint_rn(val); +} +template <> +DS_D_INLINE uint16_t to(__half val) +{ + return __half2uint_rn(val); +} +// No direct support for integer casts at the C++ level and I don't feel they're so important +// to demand an PTX at this time + +#ifdef BF16_AVAILABLE +template <> +DS_D_INLINE uint16_t to(__nv_bfloat16 val) +{ + return __bfloat162uint_rn(val); +} +#endif + +/********************* To UINT8_T Conversions *********************/ +template <> +DS_D_INLINE uint8_t to(double val) +{ + return __double2uint_rn(val); +} +template <> +DS_D_INLINE uint8_t to(float val) +{ + return __float2uint_rn(val); +} +template <> +DS_D_INLINE uint8_t to(__half val) +{ + return __half2uint_rn(val); +} +// No direct support for integer casts at the C++ level and I don't feel they're so important +// to demand an PTX at this time + +#ifdef BF16_AVAILABLE +template <> +DS_D_INLINE uint8_t to(__nv_bfloat16 val) +{ + return __bfloat162uint_rn(val); +} +#endif + +} // namespace conversion diff --git a/toolbox/DeepSpeed/v0.15.3/patches/csrc/includes/cooperative_groups.h b/toolbox/DeepSpeed/v0.15.3/patches/csrc/includes/cooperative_groups.h new file mode 100644 index 0000000000000000000000000000000000000000..796c161b3cb37d5d4aea541ca8e3d1a856c65322 --- /dev/null +++ b/toolbox/DeepSpeed/v0.15.3/patches/csrc/includes/cooperative_groups.h @@ -0,0 +1,1111 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +/* +Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may +not use this file except in compliance with the License. You may obtain +a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +/* + * Copyright 1993-2016 NVIDIA Corporation. All rights reserved. + * + * NOTICE TO LICENSEE: + * + * This source code and/or documentation ("Licensed Deliverables") are + * subject to NVIDIA intellectual property rights under U.S. and + * international Copyright laws. + * + * These Licensed Deliverables contained herein is PROPRIETARY and + * CONFIDENTIAL to NVIDIA and is being provided under the terms and + * conditions of a form of NVIDIA software license agreement by and + * between NVIDIA and Licensee ("License Agreement") or electronically + * accepted by Licensee. Notwithstanding any terms or conditions to + * the contrary in the License Agreement, reproduction or disclosure + * of the Licensed Deliverables to any third party without the express + * written consent of NVIDIA is prohibited. + * + * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE + * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE + * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS + * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND. + * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED + * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY, + * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE. + * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE + * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY + * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY + * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, + * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS + * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE + * OF THESE LICENSED DELIVERABLES. + * + * U.S. Government End Users. These Licensed Deliverables are a + * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT + * 1995), consisting of "commercial computer software" and "commercial + * computer software documentation" as such terms are used in 48 + * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government + * only as a commercial end item. Consistent with 48 C.F.R.12.212 and + * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all + * U.S. Government End Users acquire the Licensed Deliverables with + * only those rights set forth herein. + * + * Any use of the Licensed Deliverables in individual and commercial + * software must include, in the user documentation and internal + * comments to the code, the above Disclaimer and U.S. Government End + * Users Notice. + */ + +#ifndef _COOPERATIVE_GROUPS_H_ +# define _COOPERATIVE_GROUPS_H_ + +#if defined(__cplusplus) && defined(__CUDACC__) + +# include "cooperative_groups_helpers.h" + +_CG_BEGIN_NAMESPACE + +/** + * class thread_group; + * + * Generic thread group type, into which all groups are convertible. + * It acts as a container for all storage necessary for the derived groups, + * and will dispatch the API calls to the correct derived group. This means + * that all derived groups must implement the same interface as thread_group. + */ +class thread_group +{ + friend _CG_QUALIFIER thread_group this_thread(); + friend _CG_QUALIFIER thread_group tiled_partition(const thread_group& parent, unsigned int tilesz); + friend class thread_block; + + protected: + union __align__(8) { + unsigned int type : 8; + struct { + unsigned int type : 8; + unsigned int size : 24; + unsigned int metaGroupSize : 16; + unsigned int metaGroupRank : 16; + +#ifdef __ILUVATAR__ + uint64_t mask; +#else + unsigned int mask; +#endif + } coalesced; + struct { + void* ptr[2]; + } buffer; + } _data; + + _CG_QUALIFIER thread_group operator=(const thread_group& src); + _CG_QUALIFIER thread_group(__internal::groupType type) { + _data.type = type; + } + +#if __cplusplus >= 201103L + static_assert(sizeof(_data) == 16, "Failed size check"); +#endif + +public: + _CG_QUALIFIER unsigned int size() const; + _CG_QUALIFIER unsigned int thread_rank() const; + _CG_QUALIFIER void sync() const; +}; + +/** + * thread_group this_thread() + * + * Constructs a generic thread_group containing only the calling thread + */ +_CG_QUALIFIER thread_group this_thread() +{ + thread_group g = thread_group(__internal::Coalesced); +#ifdef __ILUVATAR__ + g._data.coalesced.mask = __internal::lanemask_eq(); +#else + g._data.coalesced.mask = __internal::lanemask32_eq(); +#endif + g._data.coalesced.size = 1; + return (g); +} + +#if defined(_CG_HAS_MULTI_GRID_GROUP) + +/** + * class multi_grid_group; + * + * Threads within this this group are guaranteed to be co-resident on the + * same system, on multiple devices within the same launched kernels. + * To use this group, the kernel must have been launched with + * cuLaunchCooperativeKernelMultiDevice (or the CUDA Runtime equivalent), + * and the device must support it (queryable device attribute). + * + * Constructed via this_multi_grid(); + */ +class multi_grid_group +{ + friend _CG_QUALIFIER multi_grid_group this_multi_grid(); + + struct __align__(8) { + unsigned long long handle; + unsigned int size; + unsigned int rank; + } _data; + +#if __cplusplus >= 201103L + static_assert(sizeof(_data) == 16, "Failed size check"); +#endif + +public: + _CG_QUALIFIER multi_grid_group() { + _data.handle = __internal::multi_grid::get_intrinsic_handle(); + _data.size = __internal::multi_grid::size(_data.handle); + _data.rank = __internal::multi_grid::thread_rank(_data.handle); + } + + _CG_QUALIFIER bool is_valid() const { + return (_data.handle != 0); + } + + _CG_QUALIFIER void sync() const { + _CG_ASSERT(is_valid()); + __internal::multi_grid::sync(_data.handle); + } + + _CG_QUALIFIER unsigned int size() const { + _CG_ASSERT(is_valid()); + return (_data.size); + } + + _CG_QUALIFIER unsigned int thread_rank() const { + _CG_ASSERT(is_valid()); + return (_data.rank); + } + + _CG_QUALIFIER unsigned int grid_rank() const { + _CG_ASSERT(is_valid()); + return (__internal::multi_grid::grid_rank(_data.handle)); + } + + _CG_QUALIFIER unsigned int num_grids() const { + _CG_ASSERT(is_valid()); + return (__internal::multi_grid::num_grids(_data.handle)); + } +}; + +/** + * multi_grid_group this_multi_grid() + * + * Constructs a multi_grid_group + */ +_CG_QUALIFIER multi_grid_group this_multi_grid() +{ + return (multi_grid_group()); +} + +#endif + +#if defined(_CG_HAS_GRID_GROUP) + +/** + * class grid_group; + * + * Threads within this this group are guaranteed to be co-resident on the + * same device within the same launched kernel. To use this group, the kernel + * must have been launched with cuLaunchCooperativeKernel (or the CUDA Runtime equivalent), + * and the device must support it (queryable device attribute). + * + * Constructed via this_grid(); + */ +class grid_group +{ + friend _CG_QUALIFIER grid_group this_grid(); + + struct __align__(8) { + unsigned long long handle; + unsigned int size; + unsigned int rank; + } _data; + +#if __cplusplus >= 201103L + static_assert(sizeof(_data) == 16, "Failed size check"); +#endif + + public: + _CG_QUALIFIER grid_group() { + _data.handle = (__internal::grid::get_intrinsic_handle()); + _data.size = __internal::grid::size(_data.handle); + _data.rank = __internal::grid::thread_rank(_data.handle); + } + + _CG_QUALIFIER bool is_valid() const { + return (_data.handle != 0); + } + + _CG_QUALIFIER void sync() const { + _CG_ASSERT(is_valid()); + __internal::grid::sync(_data.handle); + } + + _CG_QUALIFIER unsigned int size() const { + _CG_ASSERT(is_valid()); + return (_data.size); + } + + _CG_QUALIFIER unsigned int thread_rank() const { + _CG_ASSERT(is_valid()); + return (_data.rank); + } + + _CG_QUALIFIER dim3 group_dim() const { + _CG_ASSERT(is_valid()); + return (__internal::grid::grid_dim()); + } + +}; + +/** + * grid_group this_grid() + * + * Constructs a grid_group + */ +_CG_QUALIFIER grid_group this_grid() +{ + return (grid_group()); +} + +#endif + +/** + * class thread_block + * + * Every GPU kernel is executed by a grid of thread blocks, and threads within + * each block are guaranteed to reside on the same streaming multiprocessor. + * A thread_block represents a thread block whose dimensions are not known until runtime. + * + * Constructed via this_thread_block(); + */ +class thread_block : public thread_group +{ + friend _CG_QUALIFIER thread_block this_thread_block(); + friend _CG_QUALIFIER thread_group tiled_partition(const thread_group& parent, unsigned int tilesz); + friend _CG_QUALIFIER thread_group tiled_partition(const thread_block& parent, unsigned int tilesz); + + _CG_QUALIFIER thread_block() : thread_group(__internal::ThreadBlock) { + } + + // Internal Use + _CG_QUALIFIER thread_group _get_tiled_threads(unsigned int tilesz) const { + const bool pow2_tilesz = ((tilesz & (tilesz - 1)) == 0); + +#ifdef __ILUVATAR__ + // Invalid, immediately fail + if (tilesz == 0 || (tilesz > warpSize) || !pow2_tilesz) { + __internal::abort(); + return (thread_block()); + } + + uint64_t mask; + unsigned int base_offset = thread_rank() & (~(tilesz - 1)); + unsigned int masklength = min(size() - base_offset, tilesz); + mask = (uint64_t) (-1) >> (warpSize - masklength); + mask <<= (__internal::laneid() & ~(tilesz - 1)); + thread_group tile = thread_group(__internal::CoalescedTile); + tile._data.coalesced.mask = mask; + tile._data.coalesced.size = __popcll(mask); + tile._data.coalesced.metaGroupSize= (__internal::cta::size() + tilesz -1) / tilesz; + tile._data.coalesced.metaGroupSize= __internal::cta::thread_rank() / tilesz; +#else + // Invalid, immediately fail + if (tilesz == 0 || (tilesz > 32) || !pow2_tilesz) { + __internal::abort(); + return (thread_block()); + } + + unsigned int mask; + unsigned int base_offset = thread_rank() & (~(tilesz - 1)); + unsigned int masklength = min(size() - base_offset, tilesz); + + mask = (unsigned int)(-1) >> (32 - masklength); + mask <<= (__internal::laneid() & ~(tilesz - 1)); + + thread_group tile = thread_group(__internal::CoalescedTile); + tile._data.coalesced.mask = mask; + tile._data.coalesced.size = __popc(mask); +#endif + return (tile); + } + + public: + _CG_QUALIFIER void sync() const { + __internal::cta::sync(); + } + + _CG_QUALIFIER unsigned int size() const { + return (__internal::cta::size()); + } + + _CG_QUALIFIER unsigned int thread_rank() const { + return (__internal::cta::thread_rank()); + } + + // Additional functionality exposed by the group + _CG_QUALIFIER dim3 group_index() const { + return (__internal::cta::group_index()); + } + + _CG_QUALIFIER dim3 thread_index() const { + return (__internal::cta::thread_index()); + } + + _CG_QUALIFIER dim3 group_dim() const { + return (__internal::cta::block_dim()); + } + +}; + +/** + * thread_block this_thread_block() + * + * Constructs a thread_block group + */ +_CG_QUALIFIER thread_block this_thread_block() +{ + return (thread_block()); +} + +/** + * class coalesced_group + * + * A group representing the current set of converged threads in a warp. + * The size of the group is not guaranteed and it may return a group of + * only one thread (itself). + * + * This group exposes warp-synchronous builtins. + * Constructed via coalesced_threads(); + */ +class coalesced_group : public thread_group +{ + friend _CG_QUALIFIER coalesced_group coalesced_threads(); + friend _CG_QUALIFIER thread_group tiled_partition(const thread_group& parent, unsigned int tilesz); + friend _CG_QUALIFIER coalesced_group tiled_partition(const coalesced_group& parent, unsigned int tilesz); + + _CG_QUALIFIER unsigned int _packLanes(unsigned laneMask) const { + unsigned int member_pack = 0; + unsigned int member_rank = 0; +#ifdef __ILUVATAR__ + for (int bit_idx = 0; bit_idx < warpSize; bit_idx++) { + uint64_t lane_bit = _data.coalesced.mask & (1 << bit_idx); + if (lane_bit) { + if (laneMask & lane_bit) + member_pack |= 1 << member_rank; + member_rank++; + } + } +#else + for (int bit_idx = 0; bit_idx < 32; bit_idx++) { + unsigned int lane_bit = _data.coalesced.mask & (1 << bit_idx); + if (lane_bit) { + if (laneMask & lane_bit) + member_pack |= 1 << member_rank; + member_rank++; + } + } +#endif + return (member_pack); + } + + // Internal Use + _CG_QUALIFIER coalesced_group _get_tiled_threads(unsigned int tilesz) const { + const bool pow2_tilesz = ((tilesz & (tilesz - 1)) == 0); + + // Invalid, immediately fail + if (tilesz == 0 || (tilesz > 32) || !pow2_tilesz) { + __internal::abort(); + return (coalesced_group(0)); + } + if (size() <= tilesz) { + return (*this); + } + + if ((_data.type == __internal::CoalescedTile) && pow2_tilesz) { + unsigned int base_offset = (thread_rank() & (~(tilesz - 1))); + unsigned int masklength = min(size() - base_offset, tilesz); + unsigned int mask = (unsigned int)(-1) >> (32 - masklength); + + mask <<= (__internal::laneid() & ~(tilesz - 1)); + coalesced_group coalesced_tile = coalesced_group(mask); + coalesced_tile._data.type = __internal::CoalescedTile; + coalesced_tile._data.coalesced.metaGroupSize = size() / tilesz; + coalesced_tile._data.coalesced.metaGroupRank = thread_rank() / tilesz; + return (coalesced_tile); + } + else if ((_data.type == __internal::Coalesced) && pow2_tilesz) { + unsigned int mask = 0; + unsigned int member_rank = 0; + int seen_lanes = (thread_rank() / tilesz) * tilesz; + for (unsigned int bit_idx = 0; bit_idx < 32; bit_idx++) { + unsigned int lane_bit = _data.coalesced.mask & (1 << bit_idx); + if (lane_bit) { + if (seen_lanes <= 0 && member_rank < tilesz) { + mask |= lane_bit; + member_rank++; + } + seen_lanes--; + } + } + coalesced_group coalesced_tile = coalesced_group(mask); + coalesced_tile._data.coalesced.metaGroupSize = (size() + tilesz -1) / tilesz; + coalesced_tile._data.coalesced.metaGroupRank = thread_rank() / tilesz; + + return coalesced_tile; + } + else { + // None in _CG_VERSION 1000 + __internal::abort(); + } + + return (coalesced_group(0)); + } + + protected: + // Construct a group from scratch (coalesced_threads) + _CG_QUALIFIER coalesced_group(unsigned int mask) : thread_group(__internal::Coalesced) { + _data.coalesced.mask = mask; + _data.coalesced.size = __popc(mask); + _data.coalesced.metaGroupRank = 0; + _data.coalesced.metaGroupSize = 1; + } + + public: + _CG_QUALIFIER unsigned int size() const { + return (_data.coalesced.size); + } +#ifdef __ILUVATAR__ + _CG_QUALIFIER unsigned int thread_rank() const { + return (__popcll(_data.coalesced.mask & __internal::lanemask_lt())); + } +#else + _CG_QUALIFIER unsigned int thread_rank() const { + return (__popc(_data.coalesced.mask & __internal::lanemask32_lt())); + } +#endif + // Rank of this group in the upper level of the hierarchy + _CG_QUALIFIER unsigned int meta_group_rank() const { + return _data.coalesced.metaGroupRank; + } + + // Total num partitions created out of all CTAs when the group was created + _CG_QUALIFIER unsigned int meta_group_size() const { + return _data.coalesced.metaGroupSize; + } + // + _CG_QUALIFIER void sync() const { + __syncwarp(_data.coalesced.mask); + } + +#define COALESCED_SHFL_FUNCTION(type) \ + _CG_QUALIFIER type shfl(type var, unsigned int src_rank) const { \ + unsigned int lane = (src_rank == 0) ? __ffs(_data.coalesced.mask) - 1 : \ + (size() == 32) ? src_rank : __fns(_data.coalesced.mask, 0, (src_rank + 1)); \ + return (__shfl_sync(_data.coalesced.mask, var, lane, 32)); \ + } + +#define COALESCED_SHFL_UP_FUNCTION(type) \ + _CG_QUALIFIER type shfl_up(type var, int delta) const { \ + if (size() == 32) { \ + return (__shfl_up_sync(0xFFFFFFFF, var, delta, 32)); \ + } \ + unsigned lane = __fns(_data.coalesced.mask, __internal::laneid(), -(delta + 1)); \ + if (lane >= 32) lane = __internal::laneid(); \ + return (__shfl_sync(_data.coalesced.mask, var, lane, 32)); \ + } + +#define COALESCED_SHFL_DOWN_FUNCTION(type) \ + _CG_QUALIFIER type shfl_down(type var, int delta) const { \ + if (size() == 32) { \ + return (__shfl_down_sync(0xFFFFFFFF, var, delta, 32)); \ + } \ + unsigned int lane = __fns(_data.coalesced.mask, __internal::laneid(), delta + 1); \ + if (lane >= 32) lane = __internal::laneid(); \ + return (__shfl_sync(_data.coalesced.mask, var, lane, 32)); \ + } + + COALESCED_SHFL_FUNCTION(int); + COALESCED_SHFL_FUNCTION(unsigned int); + COALESCED_SHFL_FUNCTION(long); + COALESCED_SHFL_FUNCTION(unsigned long); + COALESCED_SHFL_FUNCTION(long long); + COALESCED_SHFL_FUNCTION(unsigned long long); + COALESCED_SHFL_FUNCTION(float); + COALESCED_SHFL_FUNCTION(double); + + COALESCED_SHFL_UP_FUNCTION(int); + COALESCED_SHFL_UP_FUNCTION(unsigned int); + COALESCED_SHFL_UP_FUNCTION(long); + COALESCED_SHFL_UP_FUNCTION(unsigned long); + COALESCED_SHFL_UP_FUNCTION(long long); + COALESCED_SHFL_UP_FUNCTION(unsigned long long); + COALESCED_SHFL_UP_FUNCTION(float); + COALESCED_SHFL_UP_FUNCTION(double); + + COALESCED_SHFL_DOWN_FUNCTION(int); + COALESCED_SHFL_DOWN_FUNCTION(unsigned int); + COALESCED_SHFL_DOWN_FUNCTION(long); + COALESCED_SHFL_DOWN_FUNCTION(unsigned long); + COALESCED_SHFL_DOWN_FUNCTION(long long); + COALESCED_SHFL_DOWN_FUNCTION(unsigned long long); + COALESCED_SHFL_DOWN_FUNCTION(float); + COALESCED_SHFL_DOWN_FUNCTION(double); + +# ifdef _CG_HAS_FP16_COLLECTIVE + COALESCED_SHFL_FUNCTION(__half); + COALESCED_SHFL_UP_FUNCTION(__half); + COALESCED_SHFL_DOWN_FUNCTION(__half); + + COALESCED_SHFL_FUNCTION(__half2); + COALESCED_SHFL_UP_FUNCTION(__half2); + COALESCED_SHFL_DOWN_FUNCTION(__half2); +# endif + +#undef COALESCED_SHFL_FUNCTION +#undef COALESCED_SHFL_UP_FUNCTION +#undef COALESCED_SHFL_DOWN_FUNCTION + + _CG_QUALIFIER int any(int predicate) const { + return (__ballot_sync((unsigned int)_data.coalesced.mask, predicate) != 0); + } + _CG_QUALIFIER int all(int predicate) const { + return (__ballot_sync((unsigned int)_data.coalesced.mask, predicate) == _data.coalesced.mask); + } + _CG_QUALIFIER unsigned int ballot(int predicate) const { + if (size() == 32) { + return (__ballot_sync(0xFFFFFFFF, predicate)); + } + unsigned int lane_ballot = __ballot_sync((unsigned int)_data.coalesced.mask, predicate); + return (_packLanes(lane_ballot)); + } + +#ifdef _CG_HAS_MATCH_COLLECTIVE + +# define COALESCED_MATCH_ANY_FUNCTION(type) \ + _CG_QUALIFIER unsigned int match_any(type val) const { \ + if (size() == 32) { \ + return (__match_any_sync(0xFFFFFFFF, val)); \ + } \ + unsigned int lane_match = __match_any_sync(_data.coalesced.mask, val); \ + return (_packLanes(lane_match)); \ + } +# define COALESCED_MATCH_ALL_FUNCTION(type) \ + _CG_QUALIFIER unsigned int match_all(type val, int &pred) const { \ + if (size() == 32) { \ + return (__match_all_sync(0xFFFFFFFF, val, &pred)); \ + } \ + unsigned int lane_match = __match_all_sync(_data.coalesced.mask, val, &pred); \ + return (_packLanes(lane_match)); \ + } + + COALESCED_MATCH_ANY_FUNCTION(int); + COALESCED_MATCH_ANY_FUNCTION(unsigned int); + COALESCED_MATCH_ANY_FUNCTION(long); + COALESCED_MATCH_ANY_FUNCTION(unsigned long); + COALESCED_MATCH_ANY_FUNCTION(long long); + COALESCED_MATCH_ANY_FUNCTION(unsigned long long); + COALESCED_MATCH_ANY_FUNCTION(float); + COALESCED_MATCH_ANY_FUNCTION(double); + + COALESCED_MATCH_ALL_FUNCTION(int); + COALESCED_MATCH_ALL_FUNCTION(unsigned int); + COALESCED_MATCH_ALL_FUNCTION(long); + COALESCED_MATCH_ALL_FUNCTION(unsigned long); + COALESCED_MATCH_ALL_FUNCTION(long long); + COALESCED_MATCH_ALL_FUNCTION(unsigned long long); + COALESCED_MATCH_ALL_FUNCTION(float); + COALESCED_MATCH_ALL_FUNCTION(double); + +# undef COALESCED_MATCH_ANY_FUNCTION +# undef COALESCED_MATCH_ALL_FUNCTION + +#endif /* !_CG_HAS_MATCH_COLLECTIVE */ + +}; + +_CG_QUALIFIER coalesced_group coalesced_threads() +{ + return (coalesced_group(__activemask())); +} + +template +class __thread_block_tile_base : public thread_group +{ + static const unsigned int numThreads = Size; + + _CG_QUALIFIER unsigned int build_mask() const { + unsigned int mask; + + if (numThreads == 32) { + mask = 0xFFFFFFFF; + } + else { + mask = (unsigned int)(-1) >> (32 - numThreads); + mask <<= (__internal::laneid() & (~(numThreads - 1))); + } + return (mask); + } + + protected: + _CG_QUALIFIER __thread_block_tile_base() : thread_group(__internal::CoalescedTile) { + _data.coalesced.mask = build_mask(); + _data.coalesced.size = numThreads; + _data.coalesced.metaGroupRank = 0; + _data.coalesced.metaGroupSize = 1; + } + + public: + _CG_QUALIFIER void sync() const { + __syncwarp(build_mask()); + } + _CG_QUALIFIER unsigned int thread_rank() const { + return (__internal::laneid() & (numThreads - 1)); + } + _CG_QUALIFIER unsigned int size() const { + return (numThreads); + } + + _CG_QUALIFIER unsigned int meta_group_rank() const { + return _data.coalesced.metaGroupRank; + } + + // Total num partitions created out of all CTAs when the group was created + _CG_QUALIFIER unsigned int meta_group_size() const { + return _data.coalesced.metaGroupSize; + } + + + // PTX supported collectives + _CG_QUALIFIER int shfl(int var, int srcRank) const { + return (__shfl_sync(build_mask(), var, srcRank, numThreads)); + } + _CG_QUALIFIER int shfl_down(int var, unsigned int delta) const { + return (__shfl_down_sync(build_mask(), var, delta, numThreads)); + } + _CG_QUALIFIER int shfl_up(int var, unsigned int delta) const { + return (__shfl_up_sync(build_mask(), var, delta, numThreads)); + } + _CG_QUALIFIER int shfl_xor(int var, unsigned int laneMask) const { + return (__shfl_xor_sync(build_mask(), var, laneMask, numThreads)); + } + _CG_QUALIFIER unsigned int shfl(unsigned int var, int srcRank) const { + return (__shfl_sync(build_mask(), var, srcRank, numThreads)); + } + _CG_QUALIFIER unsigned int shfl_down(unsigned int var, unsigned int delta) const { + return (__shfl_down_sync(build_mask(), var, delta, numThreads)); + } + _CG_QUALIFIER unsigned int shfl_up(unsigned int var, unsigned int delta) const { + return (__shfl_up_sync(build_mask(), var, delta, numThreads)); + } + _CG_QUALIFIER unsigned int shfl_xor(unsigned int var, unsigned int laneMask) const { + return (__shfl_xor_sync(build_mask(), var, laneMask, numThreads)); + } + _CG_QUALIFIER long shfl(long var, int srcRank) const { + return (__shfl_sync(build_mask(), var, srcRank, numThreads)); + } + _CG_QUALIFIER long shfl_down(long var, unsigned int delta) const { + return (__shfl_down_sync(build_mask(), var, delta, numThreads)); + } + _CG_QUALIFIER long shfl_up(long var, unsigned int delta) const { + return (__shfl_up_sync(build_mask(), var, delta, numThreads)); + } + _CG_QUALIFIER long shfl_xor(long var, unsigned int laneMask) const { + return (__shfl_xor_sync(build_mask(), var, laneMask, numThreads)); + } + _CG_QUALIFIER unsigned long shfl(unsigned long var, int srcRank) const { + return (__shfl_sync(build_mask(), var, srcRank, numThreads)); + } + _CG_QUALIFIER unsigned long shfl_down(unsigned long var, unsigned int delta) const { + return (__shfl_down_sync(build_mask(), var, delta, numThreads)); + } + _CG_QUALIFIER unsigned long shfl_up(unsigned long var, unsigned int delta) const { + return (__shfl_up_sync(build_mask(), var, delta, numThreads)); + } + _CG_QUALIFIER unsigned long shfl_xor(unsigned long var, unsigned int laneMask) const { + return (__shfl_xor_sync(build_mask(), var, laneMask, numThreads)); + } + _CG_QUALIFIER long long shfl(long long var, int srcRank) const { + return (__shfl_sync(build_mask(), var, srcRank, numThreads)); + } + _CG_QUALIFIER long long shfl_down(long long var, unsigned int delta) const { + return (__shfl_down_sync(build_mask(), var, delta, numThreads)); + } + _CG_QUALIFIER long long shfl_up(long long var, unsigned int delta) const { + return (__shfl_up_sync(build_mask(), var, delta, numThreads)); + } + _CG_QUALIFIER long long shfl_xor(long long var, unsigned int laneMask) const { + return (__shfl_xor_sync(build_mask(), var, laneMask, numThreads)); + } + _CG_QUALIFIER unsigned long long shfl(unsigned long long var, int srcRank) const { + return (__shfl_sync(build_mask(), var, srcRank, numThreads)); + } + _CG_QUALIFIER unsigned long long shfl_down(unsigned long long var, unsigned int delta) const { + return (__shfl_down_sync(build_mask(), var, delta, numThreads)); + } + _CG_QUALIFIER unsigned long long shfl_up(unsigned long long var, unsigned int delta) const { + return (__shfl_up_sync(build_mask(), var, delta, numThreads)); + } + _CG_QUALIFIER unsigned long long shfl_xor(unsigned long long var, unsigned int laneMask) const { + return (__shfl_xor_sync(build_mask(), var, laneMask, numThreads)); + } + _CG_QUALIFIER float shfl(float var, int srcRank) const { + return (__shfl_sync(build_mask(), var, srcRank, numThreads)); + } + _CG_QUALIFIER float shfl_down(float var, unsigned int delta) const { + return (__shfl_down_sync(build_mask(), var, delta, numThreads)); + } + _CG_QUALIFIER float shfl_up(float var, unsigned int delta) const { + return (__shfl_up_sync(build_mask(), var, delta, numThreads)); + } + _CG_QUALIFIER float shfl_xor(float var, unsigned int laneMask) const { + return (__shfl_xor_sync(build_mask(), var, laneMask, numThreads)); + } + _CG_QUALIFIER double shfl(double var, int srcRank) const { + return (__shfl_sync(build_mask(), var, srcRank, numThreads)); + } + _CG_QUALIFIER double shfl_down(double var, unsigned int delta) const { + return (__shfl_down_sync(build_mask(), var, delta, numThreads)); + } + _CG_QUALIFIER double shfl_up(double var, unsigned int delta) const { + return (__shfl_up_sync(build_mask(), var, delta, numThreads)); + } + _CG_QUALIFIER double shfl_xor(double var, unsigned int laneMask) const { + return (__shfl_xor_sync(build_mask(), var, laneMask, numThreads)); + } + _CG_QUALIFIER int any(int predicate) const { + unsigned int lane_ballot = build_mask() & __ballot_sync(build_mask(), predicate); + return (lane_ballot != 0); + } + _CG_QUALIFIER int all(int predicate) const { + unsigned int lane_ballot = build_mask() & __ballot_sync(build_mask(), predicate); + return (lane_ballot == build_mask()); + } + _CG_QUALIFIER unsigned int ballot(int predicate) const { + unsigned int lane_ballot = build_mask() & __ballot_sync(build_mask(), predicate); + return (lane_ballot >> (__internal::laneid() & (~(numThreads - 1)))); + } + +#ifdef _CG_HAS_FP16_COLLECTIVE + _CG_QUALIFIER __half shfl(__half var, int srcRank) const { + return (__shfl_sync(build_mask(), var, srcRank, numThreads)); + } + _CG_QUALIFIER __half shfl_down(__half var, unsigned int delta) const { + return (__shfl_down_sync(build_mask(), var, delta, numThreads)); + } + _CG_QUALIFIER __half shfl_up(__half var, unsigned int delta) const { + return (__shfl_up_sync(build_mask(), var, delta, numThreads)); + } + _CG_QUALIFIER __half shfl_xor(__half var, unsigned int laneMask) const { + return (__shfl_xor_sync(build_mask(), var, laneMask, numThreads)); + } + _CG_QUALIFIER __half2 shfl(__half2 var, int srcRank) const { + return (__shfl_sync(build_mask(), var, srcRank, numThreads)); + } + _CG_QUALIFIER __half2 shfl_down(__half2 var, unsigned int delta) const { + return (__shfl_down_sync(build_mask(), var, delta, numThreads)); + } + _CG_QUALIFIER __half2 shfl_up(__half2 var, unsigned int delta) const { + return (__shfl_up_sync(build_mask(), var, delta, numThreads)); + } + _CG_QUALIFIER __half2 shfl_xor(__half2 var, unsigned int laneMask) const { + return (__shfl_xor_sync(build_mask(), var, laneMask, numThreads)); + } +#endif + +#ifdef _CG_HAS_MATCH_COLLECTIVE + _CG_QUALIFIER unsigned int match_any(int val) const { + unsigned int lane_match = build_mask() & __match_any_sync(build_mask(), val); + return (lane_match >> (__internal::laneid() & (~(numThreads - 1)))); + } + _CG_QUALIFIER unsigned int match_any(unsigned int val) const { + unsigned int lane_match = build_mask() & __match_any_sync(build_mask(), val); + return (lane_match >> (__internal::laneid() & (~(numThreads - 1)))); + } + _CG_QUALIFIER unsigned int match_any(long val) const { + unsigned int lane_match = build_mask() & __match_any_sync(build_mask(), val); + return (lane_match >> (__internal::laneid() & (~(numThreads - 1)))); + } + _CG_QUALIFIER unsigned int match_any(unsigned long val) const { + unsigned int lane_match = build_mask() & __match_any_sync(build_mask(), val); + return (lane_match >> (__internal::laneid() & (~(numThreads - 1)))); + } + _CG_QUALIFIER unsigned int match_any(long long val) const { + unsigned int lane_match = build_mask() & __match_any_sync(build_mask(), val); + return (lane_match >> (__internal::laneid() & (~(numThreads - 1)))); + } + _CG_QUALIFIER unsigned int match_any(unsigned long long val) const { + unsigned int lane_match = build_mask() & __match_any_sync(build_mask(), val); + return (lane_match >> (__internal::laneid() & (~(numThreads - 1)))); + } + _CG_QUALIFIER unsigned int match_any(float val) const { + unsigned int lane_match = build_mask() & __match_any_sync(build_mask(), val); + return (lane_match >> (__internal::laneid() & (~(numThreads - 1)))); + } + _CG_QUALIFIER unsigned int match_any(double val) const { + unsigned int lane_match = build_mask() & __match_any_sync(build_mask(), val); + return (lane_match >> (__internal::laneid() & (~(numThreads - 1)))); + } + + _CG_QUALIFIER unsigned int match_all(int val, int &pred) const { + unsigned int lane_match = build_mask() & __match_all_sync(build_mask(), val, &pred); + return (lane_match >> (__internal::laneid() & (~(numThreads - 1)))); + } + _CG_QUALIFIER unsigned int match_all(unsigned int val, int &pred) const { + unsigned int lane_match = build_mask() & __match_all_sync(build_mask(), val, &pred); + return (lane_match >> (__internal::laneid() & (~(numThreads - 1)))); + } + _CG_QUALIFIER unsigned int match_all(long val, int &pred) const { + unsigned int lane_match = build_mask() & __match_all_sync(build_mask(), val, &pred); + return (lane_match >> (__internal::laneid() & (~(numThreads - 1)))); + } + _CG_QUALIFIER unsigned int match_all(unsigned long val, int &pred) const { + unsigned int lane_match = build_mask() & __match_all_sync(build_mask(), val, &pred); + return (lane_match >> (__internal::laneid() & (~(numThreads - 1)))); + } + _CG_QUALIFIER unsigned int match_all(long long val, int &pred) const { + unsigned int lane_match = build_mask() & __match_all_sync(build_mask(), val, &pred); + return (lane_match >> (__internal::laneid() & (~(numThreads - 1)))); + } + _CG_QUALIFIER unsigned int match_all(unsigned long long val, int &pred) const { + unsigned int lane_match = build_mask() & __match_all_sync(build_mask(), val, &pred); + return (lane_match >> (__internal::laneid() & (~(numThreads - 1)))); + } + _CG_QUALIFIER unsigned int match_all(float val, int &pred) const { + unsigned int lane_match = build_mask() & __match_all_sync(build_mask(), val, &pred); + return (lane_match >> (__internal::laneid() & (~(numThreads - 1)))); + } + _CG_QUALIFIER unsigned int match_all(double val, int &pred) const { + unsigned int lane_match = build_mask() & __match_all_sync(build_mask(), val, &pred); + return (lane_match >> (__internal::laneid() & (~(numThreads - 1)))); + } +#endif + +}; + +/** + * class thread_block_tile + * + * Statically-sized group type, representing one tile of a thread block. + * The only specializations currently supported are those with native + * hardware support (1/2/4/8/16/32) + * + * This group exposes warp-synchronous builtins. + * Constructed via tiled_partition(class thread_block); + */ +template +class thread_block_tile; +template <> class thread_block_tile<64> : public __thread_block_tile_base<64> { }; +template <> class thread_block_tile<32> : public __thread_block_tile_base<32> { }; +template <> class thread_block_tile<16> : public __thread_block_tile_base<16> { }; +template <> class thread_block_tile<8> : public __thread_block_tile_base<8> { }; +template <> class thread_block_tile<4> : public __thread_block_tile_base<4> { }; +template <> class thread_block_tile<2> : public __thread_block_tile_base<2> { }; +template <> class thread_block_tile<1> : public __thread_block_tile_base<1> { }; + +/** + * Outer level API calls + * void sync(GroupT) - see .sync() + * void thread_rank(GroupT) - see .thread_rank() + * void group_size(GroupT) - see .size() + */ +template _CG_QUALIFIER void sync(GroupT const &g) +{ + g.sync(); +} + +template _CG_QUALIFIER unsigned int thread_rank(GroupT const& g) +{ + return (g.thread_rank()); +} + +template _CG_QUALIFIER unsigned int group_size(GroupT const &g) +{ + return (g.size()); +} + +/** + * .sync() + * + * Executes a barrier across the group + * + * Implements both a compiler fence and an architectural fence to prevent, + * memory reordering around the barrier. + */ +_CG_QUALIFIER void thread_group::sync() const +{ + if (_data.type == __internal::Coalesced || _data.type == __internal::CoalescedTile) { + static_cast(this)->sync(); + } + else { + static_cast(this)->sync(); + } +} + +/** + * .size() + * + * Returns the total number of threads in the group. + */ +_CG_QUALIFIER unsigned int thread_group::size() const +{ + if (_data.type == __internal::Coalesced || _data.type == __internal::CoalescedTile) { + return (static_cast(this)->size()); + } + else { + return (static_cast(this)->size()); + } +} + +/** + * .thread_rank() + * + * Returns the linearized rank of the calling thread along the interval [0, size()). + */ +_CG_QUALIFIER unsigned int thread_group::thread_rank() const +{ + if (_data.type == __internal::Coalesced || _data.type == __internal::CoalescedTile) { + return (static_cast(this)->thread_rank()); + } + else { + return (static_cast(this)->thread_rank()); + } +} + +/** + * tiled_partition + * + * The tiled_partition(parent, tilesz) method is a collective operation that + * partitions the parent group into a one-dimensional, row-major, tiling of subgroups. + * + * A total of ((size(parent)+tilesz-1)/tilesz) subgroups will + * be created where threads having identical k = (thread_rank(parent)/tilesz) + * will be members of the same subgroup. + * + * The implementation may cause the calling thread to wait until all the members + * of the parent group have invoked the operation before resuming execution. + * + * Functionality is limited to power-of-two sized subgorup instances of at most + * 32 threads. Only thread_block, thread_block_tile<>, and their subgroups can be + * tiled_partition() in _CG_VERSION 1000. + */ +_CG_QUALIFIER thread_group tiled_partition(const thread_group& parent, unsigned int tilesz) +{ + if (parent._data.type == __internal::Coalesced || parent._data.type == __internal::CoalescedTile) { + return (static_cast(parent)._get_tiled_threads(tilesz)); + } + else { + return (static_cast(parent)._get_tiled_threads(tilesz)); + } +} +// Thread block type overload: returns a basic thread_group for now (may be specialized later) +_CG_QUALIFIER thread_group tiled_partition(const thread_block& parent, unsigned int tilesz) +{ + return (parent._get_tiled_threads(tilesz)); +} +// Coalesced group type overload: retains its ability to stay coalesced +_CG_QUALIFIER coalesced_group tiled_partition(const coalesced_group& parent, unsigned int tilesz) +{ + return (parent._get_tiled_threads(tilesz)); +} + +namespace __internal { + + // For specializing on different tiled_partition template arguments + template + struct tiled_partition_impl; + + template + struct tiled_partition_impl : public thread_block_tile { + _CG_QUALIFIER tiled_partition_impl(thread_block const &) : thread_block_tile() {} + }; + template + struct tiled_partition_impl > : public thread_block_tile { + _CG_QUALIFIER tiled_partition_impl(thread_block_tile<64> const&) : thread_block_tile() {} + }; + template + struct tiled_partition_impl > : public thread_block_tile { + _CG_QUALIFIER tiled_partition_impl(thread_block_tile<32> const&) : thread_block_tile() {} + }; + template + struct tiled_partition_impl > : public thread_block_tile { + _CG_QUALIFIER tiled_partition_impl(thread_block_tile<16> const&) : thread_block_tile() {} + }; + template + struct tiled_partition_impl > : public thread_block_tile { + _CG_QUALIFIER tiled_partition_impl(thread_block_tile<8> const&) : thread_block_tile() {} + }; + template + struct tiled_partition_impl > : public thread_block_tile { + _CG_QUALIFIER tiled_partition_impl(thread_block_tile<4> const&) : thread_block_tile() {} + }; + template + struct tiled_partition_impl > : public thread_block_tile { + _CG_QUALIFIER tiled_partition_impl(thread_block_tile<2> const&) : thread_block_tile() {} + }; + template <> + struct tiled_partition_impl<1, thread_block_tile<1> > : public thread_block_tile<1> { + _CG_QUALIFIER tiled_partition_impl(thread_block_tile<1> const&) : thread_block_tile<1>() {} + }; + +}; + +/** + * tiled_partition + * + * The tiled_partition(parent) method is a collective operation that + * partitions the parent group into a one-dimensional, row-major, tiling of subgroups. + * + * A total of ((size(parent)/tilesz) subgroups will be created, + * therefore the parent group size must be evenly divisible by the tilesz. + * The allow parent groups are thread_block or thread_block_tile. + * + * The implementation may cause the calling thread to wait until all the members + * of the parent group have invoked the operation before resuming execution. + * + * Functionality is limited to native hardware sizes, 1/2/4/8/16/32. + * The size(parent) must be greater than the template Size parameter + * otherwise the results are undefined. + */ +template +_CG_QUALIFIER thread_block_tile tiled_partition(const ParentT& g) +{ + return (__internal::tiled_partition_impl(g)); +} + +_CG_END_NAMESPACE + +# endif /* ! (__cplusplus, __CUDACC__) */ + +#endif /* !_COOPERATIVE_GROUPS_H_ */ \ No newline at end of file diff --git a/toolbox/DeepSpeed/v0.15.3/patches/csrc/includes/cooperative_groups_helpers.h b/toolbox/DeepSpeed/v0.15.3/patches/csrc/includes/cooperative_groups_helpers.h new file mode 100644 index 0000000000000000000000000000000000000000..e9ecdef3c1337d71bdd6dbf3f249bb6d1bf93b7c --- /dev/null +++ b/toolbox/DeepSpeed/v0.15.3/patches/csrc/includes/cooperative_groups_helpers.h @@ -0,0 +1,354 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +/* +Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may +not use this file except in compliance with the License. You may obtain +a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +/* Copyright 1993-2016 NVIDIA Corporation. All rights reserved. + * + * NOTICE TO LICENSEE: + * + * The source code and/or documentation ("Licensed Deliverables") are + * subject to NVIDIA intellectual property rights under U.S. and + * international Copyright laws. + * + * The Licensed Deliverables contained herein are PROPRIETARY and + * CONFIDENTIAL to NVIDIA and are being provided under the terms and + * conditions of a form of NVIDIA software license agreement by and + * between NVIDIA and Licensee ("License Agreement") or electronically + * accepted by Licensee. Notwithstanding any terms or conditions to + * the contrary in the License Agreement, reproduction or disclosure + * of the Licensed Deliverables to any third party without the express + * written consent of NVIDIA is prohibited. + * + * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE + * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE + * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. THEY ARE + * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND. + * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED + * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY, + * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE. + * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE + * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY + * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY + * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, + * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS + * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE + * OF THESE LICENSED DELIVERABLES. + * + * U.S. Government End Users. These Licensed Deliverables are a + * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT + * 1995), consisting of "commercial computer software" and "commercial + * computer software documentation" as such terms are used in 48 + * C.F.R. 12.212 (SEPT 1995) and are provided to the U.S. Government + * only as a commercial end item. Consistent with 48 C.F.R.12.212 and + * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all + * U.S. Government End Users acquire the Licensed Deliverables with + * only those rights set forth herein. + * + * Any use of the Licensed Deliverables in individual and commercial + * software must include, in the user documentation and internal + * comments to the code, the above Disclaimer and U.S. Government End + * Users Notice. + */ + +#ifndef _COOPERATIVE_GROUPS_HELPERS_H_ +# define _COOPERATIVE_GROUPS_HELPERS_H_ + +/* +** Define: _CG_VERSION +*/ +# define _CG_VERSION 1000 + +/* +** Define: _CG_ABI_VERSION +*/ +# ifndef _CG_ABI_VERSION +# define _CG_ABI_VERSION 1 +# endif + +/* +** Define: _CG_ABI_EXPERIMENTAL +** Desc: If enabled, sets all features enabled (ABI-breaking or experimental) +*/ +# if defined(_CG_ABI_EXPERIMENTAL) +# endif + +# define _CG_CONCAT_INNER(x, y) x ## y +# define _CG_CONCAT_OUTER(x, y) _CG_CONCAT_INNER(x, y) +# define _CG_NAMESPACE _CG_CONCAT_OUTER(__v, _CG_ABI_VERSION) + +# define _CG_BEGIN_NAMESPACE \ + namespace cooperative_groups { namespace _CG_NAMESPACE { +# define _CG_END_NAMESPACE \ + }; using namespace _CG_NAMESPACE; }; + +# if !defined(_CG_STATIC_QUALIFIER) +# define _CG_STATIC_QUALIFIER static __forceinline__ __device__ +# endif +# if !defined(_CG_QUALIFIER) +# define _CG_QUALIFIER __forceinline__ __device__ +# endif + +#ifndef __ILUVATAR__ +# if (__CUDA_ARCH__ >= 600) || !defined(__CUDA_ARCH__) +# define _CG_HAS_GRID_GROUP +# endif +# if (__CUDA_ARCH__ >= 600) || !defined(__CUDA_ARCH__) +# define _CG_HAS_MULTI_GRID_GROUP +# endif +# if (__CUDA_ARCH__ >= 700) || !defined(__CUDA_ARCH__) +# define _CG_HAS_MATCH_COLLECTIVE +# endif +#endif + +// Has __half and __half2 +// Only usable if you include the cuda_fp16.h extension, and +// _before_ including cooperative_groups.h +# ifdef __CUDA_FP16_TYPES_EXIST__ +# define _CG_HAS_FP16_COLLECTIVE +# endif + +/* +** Define: CG_DEBUG +** What: Enables various runtime safety checks +*/ +#if defined(__CUDACC_DEBUG__) && !defined(_CG_DEBUG) +# define _CG_DEBUG 1 +#endif + +#if defined(_CG_DEBUG) && (_CG_DEBUG == 1) && !defined(NDEBUG) +# include +# define _CG_ASSERT(x) assert((x)); +# define _CG_ABORT() assert(0); +#else +# define _CG_ASSERT(x) +# define _CG_ABORT() __trap(); +#endif + +_CG_BEGIN_NAMESPACE + +namespace __internal { + + enum groupType { + CoalescedTile, + Coalesced, + ThreadBlock, + Grid, + MultiGrid, + }; + +#if defined(_CG_HAS_GRID_GROUP) + + namespace grid { + + _CG_STATIC_QUALIFIER unsigned long long get_intrinsic_handle() + { + return (cudaCGGetIntrinsicHandle(cudaCGScopeGrid)); + } + + _CG_STATIC_QUALIFIER void sync(const unsigned long long handle) + { + cudaCGSynchronizeGrid(handle, 0); + } + + _CG_STATIC_QUALIFIER unsigned int size(const unsigned long long handle) + { + return (blockDim.z * gridDim.z) * + (blockDim.y * gridDim.y) * + (blockDim.x * gridDim.x); + } + + _CG_STATIC_QUALIFIER unsigned int thread_rank(const unsigned long long handle) + { + unsigned int blkIdx = ((blockIdx.z * gridDim.y * gridDim.x) + + (blockIdx.y * gridDim.x) + + blockIdx.x); + return (blkIdx * (blockDim.x * blockDim.y * blockDim.z) + + ((threadIdx.z * blockDim.y * blockDim.x) + + (threadIdx.y * blockDim.x) + + threadIdx.x)); + } + + _CG_STATIC_QUALIFIER dim3 grid_dim() + { + return (dim3(gridDim.x, gridDim.y, gridDim.z)); + } + }; + +#endif + +#if defined(_CG_HAS_MULTI_GRID_GROUP) + + namespace multi_grid { + + _CG_STATIC_QUALIFIER unsigned long long get_intrinsic_handle() + { + return (cudaCGGetIntrinsicHandle(cudaCGScopeMultiGrid)); + } + + _CG_STATIC_QUALIFIER void sync(const unsigned long long handle) + { + cudaError_t err = cudaCGSynchronize(handle, 0); + } + + _CG_STATIC_QUALIFIER unsigned int size(const unsigned long long handle) + { + unsigned int numThreads = 0; + cudaCGGetSize(&numThreads, NULL, handle); + return numThreads; + } + + _CG_STATIC_QUALIFIER unsigned int thread_rank(const unsigned long long handle) + { + unsigned int threadRank = 0; + cudaCGGetRank(&threadRank, NULL, handle); + return threadRank; + } + + _CG_STATIC_QUALIFIER unsigned int grid_rank(const unsigned long long handle) + { + unsigned int gridRank = 0; + cudaCGGetRank(NULL, &gridRank, handle); + return gridRank; + } + + _CG_STATIC_QUALIFIER unsigned int num_grids(const unsigned long long handle) + { + unsigned int numGrids = 0; + cudaCGGetSize(NULL, &numGrids, handle); + return numGrids; + } + + }; + +#endif + + namespace cta { +#ifdef __ILUVATAR__ + _CG_STATIC_QUALIFIER void sync() + { + __syncthreads(); + } +#else + _CG_STATIC_QUALIFIER void sync() + { + __barrier_sync(0); + } +#endif + _CG_STATIC_QUALIFIER unsigned int size() + { + return (blockDim.x * blockDim.y * blockDim.z); + } + + _CG_STATIC_QUALIFIER unsigned int thread_rank() + { + return ((threadIdx.z * blockDim.y * blockDim.x) + + (threadIdx.y * blockDim.x) + + threadIdx.x); + } + + _CG_STATIC_QUALIFIER dim3 group_index() + { + return (dim3(blockIdx.x, blockIdx.y, blockIdx.z)); + } + + _CG_STATIC_QUALIFIER dim3 thread_index() + { + return (dim3(threadIdx.x, threadIdx.y, threadIdx.z)); + } + + _CG_STATIC_QUALIFIER dim3 block_dim() + { + return (dim3(blockDim.x, blockDim.y, blockDim.z)); + } + + }; + +#ifdef __ILUVATAR__ + _CG_STATIC_QUALIFIER unsigned int laneid() + { + return __ivcorex_lane_id(); + } + + _CG_STATIC_QUALIFIER unsigned int warpsz() + { + return warpSize; + } + + _CG_STATIC_QUALIFIER unsigned long long int lanemask_eq() + { + unsigned long long int m = 1ull << laneid(); + return m; + } + + _CG_STATIC_QUALIFIER unsigned long long int lanemask_lt() + { + unsigned long long int m = (1ull << laneid()) - 1ull; + return m; + } +#else + _CG_STATIC_QUALIFIER unsigned int laneid() + { + unsigned int laneid; + asm volatile("mov.u32 %0, %%laneid;" : "=r"(laneid)); + return laneid; + } + + _CG_STATIC_QUALIFIER unsigned int warpsz() + { + unsigned int warpSize; + asm volatile("mov.u32 %0, WARP_SZ;" : "=r"(warpSize)); + return warpSize; + } + + _CG_STATIC_QUALIFIER unsigned int lanemask32_eq() + { + unsigned int lanemask32_eq; + asm volatile("mov.u32 %0, %%lanemask_eq;" : "=r"(lanemask32_eq)); + return (lanemask32_eq); + } + + _CG_STATIC_QUALIFIER unsigned int lanemask32_lt() + { + unsigned int lanemask32_lt; + asm volatile("mov.u32 %0, %%lanemask_lt;" : "=r"(lanemask32_lt)); + return (lanemask32_lt); + } +#endif + _CG_STATIC_QUALIFIER void abort() + { + _CG_ABORT(); + } + +}; // !Namespace internal + +_CG_END_NAMESPACE + +#endif /* !_COOPERATIVE_GROUPS_HELPERS_H_ */ \ No newline at end of file diff --git a/toolbox/DeepSpeed/v0.15.3/patches/csrc/includes/custom_cuda_layers.h b/toolbox/DeepSpeed/v0.15.3/patches/csrc/includes/custom_cuda_layers.h new file mode 100644 index 0000000000000000000000000000000000000000..7383509248fee3f901511a728d6a4d0cdf32b518 --- /dev/null +++ b/toolbox/DeepSpeed/v0.15.3/patches/csrc/includes/custom_cuda_layers.h @@ -0,0 +1,341 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. */ +/* All Rights Reserved. */ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include "ds_kernel_utils.h" + +#include +#include +#include +#include +#include + +#include "context.h" +#include "cublas_wrappers.h" + +#define CUDA_CHECK(callstr) \ + { \ + cudaError_t error_code = callstr; \ + if (error_code != cudaSuccess) { \ + std::cerr << "CUDA error " << error_code << " at " << __FILE__ << ":" << __LINE__; \ + assert(0); \ + } \ + } + +#define MAX_THREADS 65536 +#define THREADS 256 + +#define MAX_THREAD_STRIDE 32 +#define TILE_DIM 32 + +// Maximum sequence-length support based on the number of threads (2048) allowed in each block and +// this MAX is 8K For higher sequence length we need to use higher Max, like for 64K : 32 +#define MAX_THREAD_ITERATIONS 8 // Maximum 8K +#define MAX_WARP_NUM 32 + +#define MAX_REGISTERS 256 + +#define MAX_REG 256 + +#define WARP_SIZE_BITS 5 + +// Fused bias add with gelu activation +template +void launch_bias_gelu(const T* input, + const T* bias, + T* output, + int intermediate_size, + int batch_size, + cudaStream_t stream); + +template +void launch_gelu(const T* input, + T* output, + int intermediate_size, + int batch_size, + cudaStream_t stream); + +template +void launch_d_gelu(T* d_output, + const T* input, + const T* bias, + int intermediate_size, + int batch_size, + cudaStream_t stream); + +// Custom fused bias add with layer normalization +template +void launch_bias_residual_layer_norm(T* vals, + const T* residual, + const T* gamma, + const T* beta, + float epsilon, + int batch_size, + int hidden_dim, + cudaStream_t stream, + bool preLayerNorm, + bool training, + T* vars, + T* means); + +template +void launch_bias_residual_layer_norm(T* vals, + const T* residual, + const T* gamma, + const T* beta, + float epsilon, + int batch_size, + int hidden_dim, + cudaStream_t stream, + bool preLayerNorm, + bool training, + T* vars); + +template +void launch_layerNorm_backward_fused_add(const T* out_grad1, + const T* out_grad2, + const T* X_data, + const T* vars, + const T* means, + const T* gamma, + T* gamma_grad, + T* betta_grad, + T* inp_grad, + int batch_size, + int hidden_dim, + cudaStream_t stream[2]); +template +void launch_layerNorm_backward_fused_add(const T* out_grad1, + const T* out_grad2, + const T* vals_hat, + const T* vars, + const T* gamma, + T* gamma_grad, + T* betta_grad, + T* inp_grad, + int batch_size, + int hidden_dim, + cudaStream_t stream[2], + bool invertible = false, + const T* betta = nullptr); + +template +void launch_layerNorm_backward(const T* out_grad, + const T* X_data, + const T* vars, + const T* means, + const T* gamma, + T* gamma_grad, + T* betta_grad, + T* inp_grad, + int batch_size, + int hidden_dim, + cudaStream_t stream[2]); + +template +void launch_layerNorm_backward(const T* out_grad, + const T* vals_hat, + const T* vars, + const T* gamma, + T* gamma_grad, + T* betta_grad, + T* inp_grad, + int batch_size, + int hidden_dim, + cudaStream_t stream[2], + bool invertible = false, + const T* betta = nullptr); + +template +void launch_layerNorm_backward_nreversible(const T* out_grad, + const T* vals, + const T* out_grad_trans, + const T* vals_trans, + const T* means, + const T* vars, + const T* gamma, + T* gamma_grad, + T* betta_grad, + T* inp_grad, + int batch_size, + int hidden_dim, + cudaStream_t stream[2]); + +template +void Transpose(const T* inp_mat, T* out_mat, int rows, int cols, cudaStream_t stream); + +template +void launch_attn_softmax_backward(T* out_grad, + const T* soft_inp, + int batch_size, + int heads, + int seq_length, + cudaStream_t stream); + +template +void launch_attn_softmax_backward_v2(T* out_grad, + const T* soft_inp, + int batch_size, + int heads, + int seq_length, + cudaStream_t stream); + +// Custom softmax with scaling and attention mask addition +template +void launch_attn_softmax(T* vals, + const T* attn_mask, + int batch_size, + int heads, + int sequence_length, + cudaStream_t stream); + +template +void launch_transform_0213(T* output, + const T* vals, + int batch_size, + int seq_length, + int hidden_dim, + int heads, + cudaStream_t stream); + +// Custom bias add +template +void launch_bias_add_transform_0213(T* outputs, + const T* vals, + const T* bias, + int batch_size, + int seq_length, + int hidden_dim, + int heads, + cudaStream_t stream, + int trans_count); + +// 4D transform [0, 1, 2, 3] -> [0, 2, 1, 3] +template +void launch_transform4d_0213(T* out, + const T* in, + int batch_size, + int heads, + int seq_length, + int hidden_dim, + cudaStream_t stream, + int trans_count); + +template +void launch_dropout(T* vals, + const T* bias, + uint8_t* mask, + int batch, + int dim, + float ratio, + cudaStream_t stream); + +template +void launch_dropout(T* vals_out, + const T* vals, + uint8_t* mask, + int total_count, + int dim, + float ratio, + cudaStream_t stream, + bool bwd = false); + +template +void launch_dropout(T* out, + const T* vals, + const T* residual, + const T* bias, + uint8_t* mask, + int batch, + int dim, + float ratio, + cudaStream_t stream); + +template +void launch_dropout_grad(T* vals, uint8_t* mask, int total_count, float ratio, cudaStream_t stream); + +template +void launch_dropout_grad(T* vals_out, + const T* vals, + uint8_t* mask, + int total_count, + float ratio, + cudaStream_t stream); + +template +void launch_fuse_transpose_bias_kernel(const T* inp, + T* out, + int rows, + int cols, + cudaStream_t stream); + +void launch_token_sort(int32_t* indices, + int layers, + int batch_size, + int reserved_size, + int original_tokens, + cudaStream_t stream); + +template +void launch_gather_tokens(T* retained_tokens, + T* activations, + int32_t* gather_indices, + int32_t batch_size, + int32_t sampled_tokens, + int32_t channels, + int32_t read_batch_stride, + int32_t read_seq_stride, + int32_t write_batch_stride, + int32_t write_seq_stride, + cudaStream_t stream); + +template +void launch_scatter_tokens(T* all_activations, + T* layer_activations, + int32_t* gather_indices, + int32_t batch_size, + int32_t sampled_tokens, + int32_t channels, + int32_t read_batch_stride, + int32_t read_seq_stride, + int32_t write_batch_stride, + int32_t write_seq_stride, + cudaStream_t stream); + +template +void launch_slice_gpt_mask(T* output_mask, + const T* input_mask, + int batch_size, + int truncated_seq_len, + int orig_seq_len, + cudaStream_t stream); + +template +void launch_slice_bert_mask(T* output_mask, + const T* input_mask, + const int32_t* retained_indices, + int32_t layers, + int32_t batch_size, + int32_t truncated_seq_len, + int32_t orig_seq_len, + cudaStream_t stream); diff --git a/toolbox/DeepSpeed/v0.15.3/patches/csrc/includes/ds_kernel_utils.h b/toolbox/DeepSpeed/v0.15.3/patches/csrc/includes/ds_kernel_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..86dcea21bdac755326c1669efa8f42a313f1ff86 --- /dev/null +++ b/toolbox/DeepSpeed/v0.15.3/patches/csrc/includes/ds_kernel_utils.h @@ -0,0 +1,75 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. */ +/* All Rights Reserved. */ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +/* +Centralized header file for preprocessor macros and constants +used throughout the codebase. +*/ + +#pragma once + +#include +#include + +#ifdef BF16_AVAILABLE +#include +#endif + +#define DS_HD_INLINE __host__ __device__ __forceinline__ +#define DS_D_INLINE __device__ __forceinline__ + +#ifdef __HIP_PLATFORM_AMD__ + +// constexpr variant of warpSize for templating +constexpr int hw_warp_size = ROCM_WAVEFRONT_SIZE; +#define HALF_PRECISION_AVAILABLE = 1 +#include +#include + +#else // !__HIP_PLATFORM_AMD__ + +// constexpr variant of warpSize for templating +constexpr int hw_warp_size = 64; +#define HALF_PRECISION_AVAILABLE = 1 +#if __CUDA_ARCH__ >= 530 +#define HALF_PRECISION_AVAILABLE = 1 +#define PTX_AVAILABLE +#endif // __CUDA_ARCH__ >= 530 + +#if __CUDA_ARCH__ >= 800 +#define ASYNC_COPY_AVAILABLE +#endif // __CUDA_ARCH__ >= 800 + +#include +#include + +#endif //__HIP_PLATFORM_AMD__ + +inline int next_pow2(const int val) +{ + int rounded_val = val - 1; + rounded_val |= rounded_val >> 1; + rounded_val |= rounded_val >> 2; + rounded_val |= rounded_val >> 4; + rounded_val |= rounded_val >> 8; + return rounded_val + 1; +} diff --git a/toolbox/DeepSpeed/v0.15.3/patches/csrc/includes/fused_rotary_positional_embedding.h b/toolbox/DeepSpeed/v0.15.3/patches/csrc/includes/fused_rotary_positional_embedding.h new file mode 100644 index 0000000000000000000000000000000000000000..f669e14876584dc00f3a0f78a6eb49c70e434339 --- /dev/null +++ b/toolbox/DeepSpeed/v0.15.3/patches/csrc/includes/fused_rotary_positional_embedding.h @@ -0,0 +1,181 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +/* +Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may +not use this file except in compliance with the License. You may obtain +a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +/* coding=utf-8 + * Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include + +namespace { + +template +__global__ void fused_rope_forward(int h, int d, int d2, int stride_s, + int stride_b, int stride_h, int stride_d, + int o_stride_s, int o_stride_b, + int o_stride_h, int o_stride_d, + const scalar_t* src, const scalar_t* cos, + const scalar_t* sin, scalar_t* dst) { + int s_id = blockIdx.x, b_id = blockIdx.y; + int offset_block = s_id * stride_s + b_id * stride_b; + int offset_block_dst = s_id * o_stride_s + b_id * o_stride_b; +#pragma unroll + for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) { + scalar_t v_cos = cos[s_id * d2 + d_id]; + scalar_t v_sin = sin[s_id * d2 + d_id]; +#pragma unroll + for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { + int offset_src = offset_block + h_id * stride_h + d_id * stride_d; + int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d; + scalar_t v_src = src[offset_src]; + scalar_t v_src_rotate = (d_id + d2 / 2 < d2) + ? -src[offset_src + (d2 / 2) * stride_d] + : src[offset_src + (d2 / 2 - d2) * stride_d]; + dst[offset_dst] = v_src * v_cos + v_src_rotate * v_sin; + } + } + + // copy the rest + if (d > d2) { +#pragma unroll + for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { + int offset_head = offset_block + h_id * stride_h; + int offset_head_dst = offset_block_dst + h_id * o_stride_h; +#pragma unroll + for (int d_id = d2 + threadIdx.x; d_id < d; d_id += blockDim.x) { + dst[offset_head_dst + d_id * o_stride_d] = + src[offset_head + d_id * stride_d]; + } + } + } +} + +template +__global__ void fused_rope_backward(int h, int d, int d2, int stride_s, + int stride_b, int stride_h, int stride_d, + int o_stride_s, int o_stride_b, + int o_stride_h, int o_stride_d, + const scalar_t* src, const scalar_t* cos, + const scalar_t* sin, scalar_t* dst) { + int s_id = blockIdx.x, b_id = blockIdx.y; + int offset_block = s_id * stride_s + b_id * stride_b; + int offset_block_dst = s_id * o_stride_s + b_id * o_stride_b; +#pragma unroll + for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) { + scalar_t v_cos = cos[s_id * d2 + d_id]; + scalar_t v_sin = (d_id + d2 / 2 < d2) + ? sin[s_id * d2 + d_id + d2 / 2] + : -sin[s_id * d2 + d_id + d2 / 2 - d2]; +#pragma unroll + for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { + int offset_src = offset_block + h_id * stride_h + d_id * stride_d; + int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d; + scalar_t v_src = src[offset_src]; + scalar_t v_src_rotate = (d_id + d2 / 2 < d2) + ? src[offset_src + (d2 / 2) * stride_d] + : src[offset_src + (d2 / 2 - d2) * stride_d]; + dst[offset_dst] = v_src * v_cos + v_src_rotate * v_sin; + } + } + + // handle the tail + if (d > d2) { +#pragma unroll + for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { + int offset_head = offset_block + h_id * stride_h; + int offset_head_dst = offset_block_dst + h_id * o_stride_h; +#pragma unroll + for (int d_id = d2 + threadIdx.x; d_id < d; d_id += blockDim.x) { + dst[offset_head_dst + d_id * o_stride_d] = src[offset_head + d_id * stride_d]; + } + } + } +} + +} // end of anonymous namespace + +template +void dispatch_fused_rope_forward(int s, int b, int h, int d, int d2, + int stride_s, int stride_b, int stride_h, + int stride_d, int o_stride_s, int o_stride_b, + int o_stride_h, int o_stride_d, + const scalar_t* input, const scalar_t* cos, + const scalar_t* sin, scalar_t* output) { + auto stream = at::cuda::getCurrentCUDAStream(); + + int warps_per_block = h < 16 ? 4 : 8; + dim3 blocks(s, b); + dim3 threads(C10_WARP_SIZE, warps_per_block); + + fused_rope_forward<<>>( + h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b, + o_stride_h, o_stride_d, input, cos, sin, output); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +template +void dispatch_fused_rope_backward(int s, int b, int h, int d, int d2, + int stride_s, int stride_b, int stride_h, + int stride_d, int o_stride_s, int o_stride_b, + int o_stride_h, int o_stride_d, + const scalar_t* output_grads, + const scalar_t* cos, const scalar_t* sin, + scalar_t* input_grads) { + auto stream = at::cuda::getCurrentCUDAStream(); + + int warps_per_block = h < 16 ? 4 : 8; + dim3 blocks(s, b); + dim3 threads(C10_WARP_SIZE, warps_per_block); + + fused_rope_backward<<>>( + h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b, + o_stride_h, o_stride_d, output_grads, cos, sin, input_grads); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} \ No newline at end of file diff --git a/toolbox/DeepSpeed/v0.15.3/patches/csrc/includes/reduction_utils.h b/toolbox/DeepSpeed/v0.15.3/patches/csrc/includes/reduction_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..a9ace2fd1325d6f0ca35015d9f78c2be20991e92 --- /dev/null +++ b/toolbox/DeepSpeed/v0.15.3/patches/csrc/includes/reduction_utils.h @@ -0,0 +1,847 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. */ +/* All Rights Reserved. */ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include "conversion_utils.h" +#include "ds_kernel_utils.h" +#include "memory_access_utils.h" + +namespace cg = cooperative_groups; + +namespace reduce { + +enum class ROpType { + // Addition + Add, + + // Maximum reduction + Max, + + // Minimum reduction + Min, +}; + +constexpr int max_threads = 2048; +constexpr int max_warps = max_threads / hw_warp_size; + +/* +High level API. The API takes in a set of operations and variables +and performs that reduction operation on that variable. The reductions +of each of the arguments are completely independent of each other ( +i.e., the val1-op1 combination has no impact on val2-op2). + +Example usage: +``` cpp +float max_val; +float min_val; +reduce::block(tb, warp, max_val, min_val); +``` + +TODO(cmikeh2): In theory, we might be able to do this sequentially with +device functions and rely on the assembler correctly behaving. My initial +instinct is this won't work, but if it does it would reduce implementation +cost significantly. + +TODO(cmikeh2): We need to support sub-block reductions. The warp intrinsic +currently supports this (more incidentally than anything else). It is not +uncommon in something like softmax or a fused attention kernel to map multiple +reductions to a thread block, but each reduction itself is only scoped +to part of the threads (i.e block size = 512, 128 threads per reduction). +*/ +template +DS_D_INLINE void block(cg::thread_block& tb, cg::thread_block_tile& warp, float& val); + +template +DS_D_INLINE void block(cg::thread_block& tb, + cg::thread_block_tile& warp, + float& val1, + float& val2); + +template +DS_D_INLINE void block(cg::thread_block& tb, + cg::thread_block_tile& warp, + float& val1, + float& val2, + float& val3); + +template +DS_D_INLINE void block(cg::thread_block& tb, + cg::thread_block_tile& warp, + float& val1, + float& val2, + float& val3, + float& val4); + +/* +The partitioned block is a special case of the above where in the warps of a threadblock are +partitioned into separate independent reductions. For example, I might have an 8 warp thread block +in which each pair of warps is processing an independent piece of data. I would then reduce that +data with the something like the following: +``` cpp +float max_val; +reduce::partitioned_block(tb, warp, max_val); +``` +After which, each pair of warps would have coherent data with each other. Note, this API will not +provide correct results if the number of warps per partition is not a power of 2. +*/ +template +DS_D_INLINE void partitioned_block(cg::thread_block& tb, + cg::thread_block_tile& warp, + float& val); + +template +DS_D_INLINE void partitioned_block(cg::thread_block& tb, + cg::thread_block_tile& warp, + float& val1, + float& val2); + +template +DS_D_INLINE void partitioned_block(cg::thread_block& tb, + cg::thread_block_tile& warp, + float& val1, + float& val2, + float& val3); + +template +DS_D_INLINE void partitioned_block(cg::thread_block& tb, + cg::thread_block_tile& warp, + float& val1, + float& val2, + float& val3, + float& val4); + +/* +Single element reduction primitives. Used inside serial collection +loops. + +Example usage: +using rop = reduce::OpType; +float min = init(); +for (int i = 0; i < 4; i++) { + min = reduce::element(min, data[i]); +} +*/ + +template +DS_D_INLINE T element(const T lhs, const T rhs); + +template +DS_D_INLINE T init(); + +/********************** Internal reduction APIs **********************/ + +/* +Single element "reductions". TODO(cmikeh2): this sort of "op" concept +should be refactored into its own implementation at some point. This interface +may be easily expanded for new types/operations, but the typical reductions +we need are covered with min/max/add on float. + +NOTE: there is no mean reduction because that relies on knowledge of how +many values were already reduced into each scalar. Implementing this on top +of reduce should be straightforward (can just wrap the sum reduction) and +would be a good extension of the header. +*/ + +DS_D_INLINE int _warp_rank() +{ + const int thread_rank = + threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * blockDim.x * blockDim.y; + return thread_rank / hw_warp_size; +} + +/* Float element reduce implementations */ +template <> +DS_D_INLINE float element(const float lhs, const float rhs) +{ + return lhs + rhs; +} + +template <> +DS_D_INLINE double element(const double lhs, const double rhs) +{ + return lhs + rhs; +} + +template <> +DS_D_INLINE float element(const float lhs, const float rhs) +{ + return fmaxf(lhs, rhs); +} + +template <> +DS_D_INLINE float element(const float lhs, const float rhs) +{ + return fminf(lhs, rhs); +} + +/* __half element reduce implementation */ +template <> +DS_D_INLINE __half element(const __half lhs, const __half rhs) +{ + return lhs + rhs; +} + +template <> +DS_D_INLINE __half element(const __half lhs, const __half rhs) +{ +#if __CUDA_ARCH__ >= 800 + // Intrinsic limited to Ampere + newer + return __hmax(lhs, rhs); +#else + return (lhs > rhs) ? lhs : rhs; +#endif +} + +#ifdef BF16_AVAILABLE +template <> +DS_D_INLINE __nv_bfloat16 element(const __nv_bfloat16 lhs, const __nv_bfloat16 rhs) +{ +#if __CUDA_ARCH__ >= 800 + // Intrinsic limited to Ampere + newer + return __hmax(lhs, rhs); +#else + return (lhs > rhs) ? lhs : rhs; +#endif +} +#endif + +template <> +DS_D_INLINE __half element(const __half lhs, const __half rhs) +{ +#if __CUDA_ARCH__ >= 800 + // Intrinsic limited to Ampere + newer + return __hmin(lhs, rhs); +#else + return (lhs < rhs) ? lhs : rhs; +#endif +} + +/* __half2 element reduce implementation */ +template <> +DS_D_INLINE __half2 element(const __half2 lhs, const __half2 rhs) +{ + return lhs + rhs; +} + +template <> +DS_D_INLINE __half2 element(const __half2 lhs, const __half2 rhs) +{ +#if __CUDA_ARCH__ >= 800 + return __hmax2(lhs, rhs); +#else + __half2 ret_val; + ret_val.x = (lhs.x > rhs.x) ? lhs.x : rhs.x; + ret_val.y = (lhs.y > rhs.y) ? lhs.y : rhs.y; + return ret_val; +#endif +} + +#ifdef BF16_AVAILABLE +template <> +DS_D_INLINE __nv_bfloat162 element(const __nv_bfloat162 lhs, const __nv_bfloat162 rhs) +{ +#if __CUDA_ARCH__ >= 800 + return __hmax2(lhs, rhs); +#else + __nv_bfloat162 ret_val; + ret_val.x = (lhs.x > rhs.x) ? lhs.x : rhs.x; + ret_val.y = (lhs.y > rhs.y) ? lhs.y : rhs.y; + return ret_val; +#endif +} +#endif + +template <> +DS_D_INLINE __half2 element(const __half2 lhs, const __half2 rhs) +{ +#if __CUDA_ARCH__ >= 800 + return __hmin2(lhs, rhs); +#else + __half2 ret_val; + ret_val.x = (lhs.x < rhs.x) ? lhs.x : rhs.x; + ret_val.y = (lhs.y < rhs.y) ? lhs.y : rhs.y; + return ret_val; +#endif +} + +template <> +DS_D_INLINE int32_t element(const int32_t lhs, const int32_t rhs) +{ + return lhs + rhs; +} + +template <> +DS_D_INLINE int32_t element(const int32_t lhs, const int32_t rhs) +{ + return (lhs > rhs) ? lhs : rhs; +} + +template <> +DS_D_INLINE int32_t element(const int32_t lhs, const int32_t rhs) +{ + return (lhs < rhs) ? lhs : rhs; +} + +template <> +DS_D_INLINE uint32_t element(const uint32_t lhs, const uint32_t rhs) +{ + return lhs + rhs; +} + +template <> +DS_D_INLINE uint32_t element(const uint32_t lhs, const uint32_t rhs) +{ + return (lhs > rhs) ? lhs : rhs; +} + +template <> +DS_D_INLINE uint32_t element(const uint32_t lhs, const uint32_t rhs) +{ + return (lhs < rhs) ? lhs : rhs; +} + +template <> +DS_D_INLINE int64_t element(const int64_t lhs, const int64_t rhs) +{ + return lhs + rhs; +} + +template <> +DS_D_INLINE int64_t element(const int64_t lhs, const int64_t rhs) +{ + return (lhs > rhs) ? lhs : rhs; +} + +template <> +DS_D_INLINE int64_t element(const int64_t lhs, const int64_t rhs) +{ + return (lhs < rhs) ? lhs : rhs; +} + +/* +Reduction initialization primitives +*/ +template <> +DS_D_INLINE float init() +{ + return 0.0f; +} +template <> +DS_D_INLINE double init() +{ + return (double)0.0f; +} + +template <> +DS_D_INLINE float init() +{ + // Positive infinity + return INFINITY; +} + +template <> +DS_D_INLINE float init() +{ + // Negative infinity + return -INFINITY; +} + +template <> +DS_D_INLINE __half init() +{ + constexpr __half_raw zero = {0x0000}; + return __half(zero); +} + +template <> +DS_D_INLINE __half init() +{ + constexpr __half_raw inf = {0x7C00}; + return __half(inf); +} + +template <> +DS_D_INLINE __half init() +{ + constexpr __half_raw neg_inf = {0xFC00}; + return __half(neg_inf); +} + +#ifdef BF16_AVAILABLE +template <> +DS_D_INLINE __nv_bfloat16 init() +{ + constexpr __nv_bfloat16_raw neg_inf = {0xFF80}; + return __nv_bfloat16(neg_inf); +} +#endif + +template <> +DS_D_INLINE __half2 init() +{ +#ifdef __HIP_PLATFORM_AMD__ + return __half2{_Float16_2{0x0000, 0x0000}}; +#else + constexpr __half2_raw zero = {0x0000, 0x0000}; + return __half2(zero); +#endif +} + +template <> +DS_D_INLINE __half2 init() +{ +#ifdef __HIP_PLATFORM_AMD__ + return __half2{_Float16_2{0x7C00, 0x7C00}}; +#else + constexpr __half2_raw inf = {0x7C00, 0x7C00}; + return __half2(inf); +#endif +} + +template <> +DS_D_INLINE __half2 init() +{ +#ifdef __HIP_PLATFORM_AMD__ + return __half2{_Float16_2{0xFC00, 0xFC00}}; +#else + constexpr __half2_raw neg_inf = {0xFC00, 0xFC00}; + return __half2(neg_inf); +#endif +} + +template <> +DS_D_INLINE int32_t init() +{ + return 0; +} + +template <> +DS_D_INLINE int32_t init() +{ + return 0x7FFFFFFF; +} + +template <> +DS_D_INLINE int32_t init() +{ + return 0x80000000; +} + +template <> +DS_D_INLINE uint32_t init() +{ + return 0; +} + +template <> +DS_D_INLINE uint32_t init() +{ + return 0xFFFFFFFF; +} + +template <> +DS_D_INLINE uint32_t init() +{ + return 0; +} + +template <> +DS_D_INLINE int64_t init() +{ + return 0; +} + +template <> +DS_D_INLINE int64_t init() +{ + return 0x7FFFFFFFFFFFFFFF; +} + +template <> +DS_D_INLINE int64_t init() +{ + return 0x8000000000000000; +} + +template <> +DS_D_INLINE uint64_t init() +{ + return 0; +} + +template <> +DS_D_INLINE uint64_t init() +{ + return 0xFFFFFFFFFFFFFFFF; +} + +template <> +DS_D_INLINE uint64_t init() +{ + return 0; +} + +template +DS_D_INLINE void init(T* data) +{ + data[0] = init(); +} + +template +DS_D_INLINE void init(T* data) +{ + data[0] = init(); + data[1] = init(); +} + +template +DS_D_INLINE void init(T* data) +{ + data[0] = init(); + data[1] = init(); + data[2] = init(); +} + +template +DS_D_INLINE void init(T* data) +{ + data[0] = init(); + data[1] = init(); + data[2] = init(); + data[3] = init(); +} + +/* +Warp reduction primitives + +`reduction_width` is an unsafe template parameter, that is that +when using `reduction_width` < hw_warp_size the warp is partitioned +into `hw_warp_size` / `reduction_width` groups of partial sums. + +If someone can figure out how to use variadic templates in a reasonable way +here (fold is C++17 only and I don't think helps and recursion feels like +huge overkill that harms readability) that would be wonderful. +*/ + +template +DS_D_INLINE void _warp(cg::thread_block_tile& warp, T* data) +{ +#pragma unroll + for (int i = 1; i < reduce_width; i *= 2) { + data[0] = element(data[0], warp.shfl_xor(data[0], i)); + } +} + +template +DS_D_INLINE void _warp(cg::thread_block_tile& warp, T* data) +{ +#pragma unroll + for (int i = 1; i < reduce_width; i *= 2) { + data[0] = element(data[0], warp.shfl_xor(data[0], i)); + data[1] = element(data[1], warp.shfl_xor(data[1], i)); + } +} + +template +DS_D_INLINE void _warp(cg::thread_block_tile& warp, T* data) +{ +#pragma unroll + for (int i = 1; i < reduce_width; i *= 2) { + data[0] = element(data[0], warp.shfl_xor(data[0], i)); + data[1] = element(data[1], warp.shfl_xor(data[1], i)); + data[2] = element(data[2], warp.shfl_xor(data[2], i)); + } +} + +template +DS_D_INLINE void _warp(cg::thread_block_tile& warp, T* data) +{ +#pragma unroll + for (int i = 1; i < reduce_width; i *= 2) { + data[0] = element(data[0], warp.shfl_xor(data[0], i)); + data[1] = element(data[1], warp.shfl_xor(data[1], i)); + data[2] = element(data[2], warp.shfl_xor(data[2], i)); + data[3] = element(data[3], warp.shfl_xor(data[3], i)); + } +} + +/* +Implementation for primary block reduction that serves both `block` and +`partitioned_block`. + +Total warps refers to the reduction width of the reduction, not +the number of warps in the block (which may exceed that +if the block is partitioned or if we do a conservative bound at +compile time). +*/ +template +DS_D_INLINE void _block(cg::thread_block& tb, + cg::thread_block_tile& warp_arg, + T* data) +{ + constexpr int elems = sizeof...(Ops); + constexpr int bytes = sizeof(T); + // Unused when `partition_size == 1` or total_warps == 1 + __shared__ T reduce_buffer[max_warps * elems]; + +#ifdef __HIP_PLATFORM_AMD__ + const int total_threads = blockDim.x * blockDim.y * blockDim.z; + const int running_warps = total_threads / hw_warp_size; +#else + const int total_threads = blockDim.x * blockDim.y * blockDim.z; + int running_warps = total_threads / hw_warp_size; + if (running_warps == 0){ + running_warps = 1; + } +#endif + + // Always perform warp-scope reduction + _warp(warp_arg, data); + + // If max_warps == 1 let's skip the runtime check + if (total_warps != 1) { + if (warp_arg.thread_rank() == 0) { +#pragma unroll + for (int i = 0; i < elems; i++) { + mem_access::store_shared(reduce_buffer + elems * _warp_rank() + i, data + i); + } + } + + // Synchronization inside block-uniform conditional is safe + tb.sync(); + + if (_warp_rank() == 0) { + if (warp_arg.thread_rank() < running_warps) { +#pragma unroll + for (int i = 0; i < elems; i++) { + mem_access::load_shared( + data + i, reduce_buffer + elems * warp_arg.thread_rank() + i); + } + } else { + init(data); + } + + _warp(warp_arg, data); + +#pragma unroll + for (int i = 0; i < elems; i++) { + mem_access::store_shared(reduce_buffer + elems * warp_arg.thread_rank() + i, + data + i); + } + } + + // Synchronization inside block-uniform conditional is safe + tb.sync(); + +#pragma unroll + for (int i = 0; i < elems; i++) { + mem_access::load_shared(data + i, reduce_buffer + _warp_rank() * elems + i); + } + } +} + +/* +Main API implementations. For the most part, they just convert the individual +variables into arrays, which makes working with them easier with a single +implementation. In theory, we could use the `_block` implementation as another +option, but the nature of using a pointer is a little less safe and this allows +us to obfuscate the details of the partitioned implementation. +*/ +template +DS_D_INLINE void block(cg::thread_block& tb, cg::thread_block_tile& warp, float& val) +{ + _block(tb, warp, &val); +} + +template +DS_D_INLINE void block(cg::thread_block& tb, + cg::thread_block_tile& warp, + float& val1, + float& val2) +{ + float data[2] = {val1, val2}; + _block(tb, warp, data); + val1 = data[0]; + val2 = data[1]; +} + +template +DS_D_INLINE void block(cg::thread_block& tb, + cg::thread_block_tile& warp, + float& val1, + float& val2, + float& val3) +{ + float data[3] = {val1, val2, val3}; + _block(tb, warp, data); + val1 = data[0]; + val2 = data[1]; + val3 = data[2]; +} + +template +DS_D_INLINE void block(cg::thread_block& tb, + cg::thread_block_tile& warp, + float& val1, + float& val2, + float& val3, + float& val4) +{ + float data[4] = {val1, val2, val3, val4}; + _block(tb, warp, data); + val1 = data[0]; + val2 = data[1]; + val3 = data[2]; + val4 = data[3]; +} + +/* +Note: for the partitioned blocks, the implementation does not support non-power of 2 blocks in order +to shorten block scale reduction length. +*/ +template +DS_D_INLINE void partitioned_block(cg::thread_block& tb, + cg::thread_block_tile& warp, + float& val) +{ + if (num_threads <= hw_warp_size) { + _warp(warp, &val); + } else { + constexpr int num_warps = num_threads / hw_warp_size; + _block(tb, warp, &val); + } +} + +template +DS_D_INLINE void partitioned_block(cg::thread_block& tb, + cg::thread_block_tile& warp, + float& val1, + float& val2) +{ + float data[2] = {val1, val2}; + + if (num_threads <= hw_warp_size) { + _warp(warp, data); + } else { + constexpr int num_warps = num_threads / hw_warp_size; + _block(tb, warp, data); + } + + val1 = data[0]; + val2 = data[1]; +} + +template +DS_D_INLINE void partitioned_block(cg::thread_block& tb, + cg::thread_block_tile& warp, + float& val1, + float& val2, + float& val3) +{ + float data[3] = {val1, val2, val3}; + + if (num_threads <= hw_warp_size) { + _warp(warp, data); + } else { + constexpr int num_warps = num_threads / hw_warp_size; + _block(tb, warp, data); + } + + val1 = data[0]; + val2 = data[1]; + val3 = data[2]; +} + +template +DS_D_INLINE void partitioned_block(cg::thread_block& tb, + cg::thread_block_tile& warp, + float& val1, + float& val2, + float& val3, + float& val4) +{ + float data[4] = {val1, val2, val3, val4}; + + if (num_threads <= hw_warp_size) { + _warp(warp, data); + } else { + constexpr int num_warps = num_threads / hw_warp_size; + _block(tb, warp, data); + } + + val1 = data[0]; + val2 = data[1]; + val3 = data[2]; + val4 = data[3]; +} + +/* +Arg-reduce is a specialization of the above. We only support this with a single reduction +parameter. This only works for max/min reductions. +*/ + +__align__(8) struct IdxReduceResult { + /* + NOTE: ORDERING MATTERS HERE! The idx is the least significant set of bits + and the val is the most significant. Changing the order of this declaration + will break the code. + */ + int idx; + float val; +}; + +template +DS_D_INLINE IdxReduceResult +idx_reduce(cg::thread_block& tb, cg::thread_block_tile& warp, float val, int idx) +{ + IdxReduceResult res = {idx, val}; + + // Clear out the nan. This shouldn't be an issue for our initial applications + if (isnan(val)) res.val = init(); + + // Can do float compares as integers. By packing the index into the lower bits + // we can just do a single int64 rather than a branch, compare, and select. + // One side benefit of this is that it is by nature a stable algorithm and + // will always bias ties to the higher index. + int64_t* res_as_int = reinterpret_cast(&res); + + // The way floating point compare works is normally to perform a sign comparison + // and if they match, then do a comparison of the rest of the bits as unsigned + // integers. Since we are bundling these, that means for negative values we need + // to reverse the sort order, which we can do with an XOR. + if (val < 0) { *res_as_int ^= 0x7fffffff00000000; } + + _block(tb, warp, res_as_int); + + // Sign bit is preserved, so we can check if we need to invert the mantissa back + if (res.val < 0) { *res_as_int ^= 0x7fffffff00000000; } + + return res; +} + +} // namespace reduce diff --git a/toolbox/DeepSpeed/v0.15.3/patches/csrc/includes/static_switch.h b/toolbox/DeepSpeed/v0.15.3/patches/csrc/includes/static_switch.h new file mode 100644 index 0000000000000000000000000000000000000000..f7d74484a457a1682917693842adf3ef1033d36b --- /dev/null +++ b/toolbox/DeepSpeed/v0.15.3/patches/csrc/includes/static_switch.h @@ -0,0 +1,57 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +/* +Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may +not use this file except in compliance with the License. You may obtain +a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// From +// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h + +#pragma once + +/// @param COND - a boolean expression to switch by +/// @param CONST_NAME - a name given for the constexpr bool variable. +/// @param ... - code to execute for true and false +/// +/// Usage: +/// ``` +/// BOOL_SWITCH(flag, BoolConst, [&] { +/// some_function(...); +/// }); +/// ``` +#define BOOL_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + if (COND) { \ + constexpr static bool CONST_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }() \ No newline at end of file diff --git a/toolbox/DeepSpeed/v0.15.3/patches/csrc/includes/swiglu.h b/toolbox/DeepSpeed/v0.15.3/patches/csrc/includes/swiglu.h new file mode 100644 index 0000000000000000000000000000000000000000..447da22f1485b358146ad99c0f5a79c3b700bb23 --- /dev/null +++ b/toolbox/DeepSpeed/v0.15.3/patches/csrc/includes/swiglu.h @@ -0,0 +1,41 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +/* +Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may +not use this file except in compliance with the License. You may obtain +a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#pragma once + +#include +#include + +torch::Tensor launch_swiglu_kernel(torch::Tensor& input); +torch::Tensor launch_swiglu_kernel_bwd(torch::Tensor& input, torch::Tensor& grad); + + diff --git a/toolbox/DeepSpeed/v0.15.3/patches/csrc/includes/type_shim_rope.h b/toolbox/DeepSpeed/v0.15.3/patches/csrc/includes/type_shim_rope.h new file mode 100644 index 0000000000000000000000000000000000000000..f6eb6f4ef62e9cf1dad54e80721986ad7ac2442e --- /dev/null +++ b/toolbox/DeepSpeed/v0.15.3/patches/csrc/includes/type_shim_rope.h @@ -0,0 +1,454 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +/* +Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may +not use this file except in compliance with the License. You may obtain +a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include +#include "compat.h" + +// Forward/backward compatiblity hack around +// https://github.com/pytorch/pytorch/commit/3aeb78079bcd68282fe9117088e138b77318e288 +// pending more future-proof guidance from upstream. +// struct TypeShim +// { +// const at::Type& payload; +// TypeShim(const at::Type& type) : payload(type) {} +// // Enable trivial conversion to a const at::Type& for pre-3aeb78 +// operator const at::Type&(){ return payload; }; +// // Enable dispatch switch statements to take *this directly for post-3aeb78 +// //operator at::ScalarType(){ return payload.; }; +// }; + +#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \ + switch(TYPE) \ + { \ + case at::ScalarType::Float: \ + { \ + using scalar_t_##LEVEL = float; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Half: \ + { \ + using scalar_t_##LEVEL = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + } + + +#define DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, LEVEL, NAME, ...) \ + switch(TYPE) \ + { \ + case at::ScalarType::Float: \ + { \ + using scalar_t_##LEVEL = float; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Half: \ + { \ + using scalar_t_##LEVEL = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::BFloat16: \ + { \ + using scalar_t_##LEVEL = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + } + + +#define DISPATCH_FLOAT_HALF_AND_BYTE(TYPE, LEVEL, NAME, ...) \ + switch(TYPE) \ + { \ + case at::ScalarType::Float: \ + { \ + using scalar_t_##LEVEL = float; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Half: \ + { \ + using scalar_t_##LEVEL = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Byte: \ + { \ + using scalar_t_##LEVEL = uint8_t; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + } + + +#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \ + switch(TYPE) \ + { \ + case at::ScalarType::Double: \ + { \ + using scalar_t_##LEVEL = double; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Float: \ + { \ + using scalar_t_##LEVEL = float; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Half: \ + { \ + using scalar_t_##LEVEL = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + } + + +#define DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(TYPE, LEVEL, NAME, ...) \ + switch(TYPE) \ + { \ + case at::ScalarType::Double: \ + { \ + using scalar_t_##LEVEL = double; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Float: \ + { \ + using scalar_t_##LEVEL = float; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Half: \ + { \ + using scalar_t_##LEVEL = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::BFloat16: \ + { \ + using scalar_t_##LEVEL = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + } + + + #define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...) \ + switch(TYPE) \ + { \ + case at::ScalarType::Double: \ + { \ + using scalar_t_##LEVEL = double; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Float: \ + { \ + using scalar_t_##LEVEL = float; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + } + + + #define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \ + switch(TYPE) \ + { \ + case at::ScalarType::Half: \ + { \ + using scalar_t = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::BFloat16: \ + { \ + using scalar_t = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + } + + + #define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \ + switch(TYPEIN) \ + { \ + case at::ScalarType::Float: \ + { \ + using scalar_t_in = float; \ + switch(TYPEOUT) \ + { \ + case at::ScalarType::Float: \ + { \ + using scalar_t_out = float; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Half: \ + { \ + using scalar_t_out = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::BFloat16: \ + { \ + using scalar_t_out = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \ + } \ + break; \ + } \ + case at::ScalarType::Half: \ + { \ + using scalar_t_in = at::Half; \ + using scalar_t_out = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::BFloat16: \ + { \ + using scalar_t_in = at::BFloat16; \ + using scalar_t_out = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \ + } + + + #define DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \ + switch(TYPEIN) \ + { \ + case at::ScalarType::Double: \ + { \ + using scalar_t_in = double; \ + switch(TYPEOUT) \ + { \ + case at::ScalarType::Double: \ + { \ + using scalar_t_out = double; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Float: \ + { \ + using scalar_t_out = float; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Half: \ + { \ + using scalar_t_out = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::BFloat16: \ + { \ + using scalar_t_out = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \ + } \ + break; \ + } \ + case at::ScalarType::Float: \ + { \ + using scalar_t_in = float; \ + switch(TYPEOUT) \ + { \ + case at::ScalarType::Float: \ + { \ + using scalar_t_out = float; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Half: \ + { \ + using scalar_t_out = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::BFloat16: \ + { \ + using scalar_t_out = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \ + } \ + break; \ + } \ + case at::ScalarType::Half: \ + { \ + using scalar_t_in = at::Half; \ + using scalar_t_out = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::BFloat16: \ + { \ + using scalar_t_in = at::BFloat16; \ + using scalar_t_out = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \ + } + + +template +__device__ __forceinline__ T reduce_block_into_lanes + (T *x, + T val, + int lanes=1, + bool share_result=false) // lanes is intended to be <= warpSize. +{ + int tid = threadIdx.x + threadIdx.y*blockDim.x; + int blockSize = blockDim.x*blockDim.y; // blockSize is intended to be a multiple of warpSize. + auto double_warp_size = warpSize * 2; + + if(blockSize >= double_warp_size) + { + x[tid] = val; + __syncthreads(); + } + + #pragma unroll + for(int i = (blockSize >> 1); i >= double_warp_size; i >>= 1) + { + if(tid < i) + x[tid] = x[tid] + x[tid+i]; + __syncthreads(); + } + + T final; + + if(tid < warpSize) + { + if(blockSize >= double_warp_size) + final = x[tid] + x[tid + warpSize]; + else + final = val; + // __SYNCWARP(); + + #pragma unroll + for(int i = warpSize / 2; i >= lanes; i >>= 1) + final = final + __shfl_down_sync(0xffffffff, final, i); + } + + if(share_result) + { + if(tid < lanes) + x[tid] = final; // EpilogueOp + // Make sure the smem result is visible to all warps. + // __syncthreads(); + } + __syncthreads(); + + return final; +} + +template +__device__ __forceinline__ T reduce_block_into_lanes_max_op + (T *x, + T val, + int lanes=1, + bool share_result=false) // lanes is intended to be <= warpSize. +{ + int tid = threadIdx.x + threadIdx.y*blockDim.x; + int blockSize = blockDim.x*blockDim.y; // blockSize is intended to be a multiple of warpSize. + auto double_warp_size = warpSize * 2; + + if(blockSize >= double_warp_size) + { + x[tid] = val; + __syncthreads(); + } + + #pragma unroll + for(int i = (blockSize >> 1); i >= double_warp_size; i >>= 1) + { + if(tid < i) + x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid+i])); + __syncthreads(); + } + + T final; + + if(tid < warpSize) + { + if(blockSize >= double_warp_size) + final = fmaxf(fabsf(x[tid]), fabsf(x[tid + warpSize])); + else + final = val; + // __SYNCWARP(); + + #pragma unroll + for(int i = warpSize / 2; i >= lanes; i >>= 1) + final = fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i))); + } + + if(share_result) + { + if(tid < lanes) + x[tid] = final; // EpilogueOp + // Make sure the smem result is visible to all warps. + __syncthreads(); + } + + return final; +} \ No newline at end of file diff --git a/toolbox/DeepSpeed/v0.15.3/patches/csrc/lamb/fused_lamb_cuda_kernel.cu b/toolbox/DeepSpeed/v0.15.3/patches/csrc/lamb/fused_lamb_cuda_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..b58eb63bd236e8ae310a2d66480f2c9d62cfbd51 --- /dev/null +++ b/toolbox/DeepSpeed/v0.15.3/patches/csrc/lamb/fused_lamb_cuda_kernel.cu @@ -0,0 +1,535 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. */ +/* All Rights Reserved. */ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include +#include +#include +#include +#include "ATen/ATen.h" +#include "ATen/TensorUtils.h" +#include "ATen/cuda/CUDAContext.h" +#include "ATen/cuda/detail/IndexUtils.cuh" +//#include +#include "ATen/AccumulateType.h" + +#include + +#ifdef __ILUVATAR__ +#include +#else +#include +#if defined(__HIP_PLATFORM_AMD__) && HIP_VERSION > 305 +#include +#else +#include +#endif +#endif +#include +#include + +namespace cg = cooperative_groups; + +// Utility class used to avoid linker errors with extern +// unsized shared memory arrays with templated type +namespace { +// This is the un-specialized struct. Note that we prevent instantiation of this +// struct by putting an undefined symbol in the function body so it won't compile. +template +struct SharedMemory { + // Ensure that we won't compile any un-specialized types + __device__ inline operator T*() + { +#ifndef _WIN32 + extern __device__ void error(void); + error(); +#endif + return NULL; + } +}; + +template <> +struct SharedMemory { + __device__ inline operator float*() + { + extern __shared__ float s_float[]; + return s_float; + } +}; + +template <> +struct SharedMemory { + __device__ inline operator double*() + { + extern __shared__ double s_double[]; + return s_double; + } +}; +} // namespace + +#include "type_shim.h" + +typedef enum { + ADAM_MODE_0 = 0, // eps under square root + ADAM_MODE_1 = 1 // eps outside square root +} adamMode_t; + +// s_a and s_b are in shared memory +// g_a and g_b are in shared memory +template +__device__ void reduce_block_in_shared_memory(T* s_a, T* s_b, T* g_a, T* g_b) +{ + // Handle to thread block group + cg::thread_block cta = cg::this_thread_block(); + + // perform block reduction in shared memory, + unsigned int tid = cta.thread_rank(); + + T a_sum = s_a[tid]; + T b_sum = s_b[tid]; + +#if defined(__HIP_PLATFORM_AMD__) && HIP_VERSION > 305 + cta.sync(); +#else + cg::sync(cta); +#endif + + // do reduction in shared mem + if ((blockSize >= 512) && (tid < 256)) { + s_a[tid] = a_sum = a_sum + s_a[tid + 256]; + s_b[tid] = b_sum = b_sum + s_b[tid + 256]; + } + +#if defined(__HIP_PLATFORM_AMD__) && HIP_VERSION > 305 + cta.sync(); +#else + cg::sync(cta); +#endif + + if ((blockSize >= 256) && (tid < 128)) { + s_a[tid] = a_sum = a_sum + s_a[tid + 128]; + s_b[tid] = b_sum = b_sum + s_b[tid + 128]; + } + +#if defined(__HIP_PLATFORM_AMD__) && HIP_VERSION > 305 + cta.sync(); +#else + cg::sync(cta); +#endif + + if ((blockSize >= 128) && (tid < 64)) { + s_a[tid] = a_sum = a_sum + s_a[tid + 64]; + s_b[tid] = b_sum = b_sum + s_b[tid + 64]; + } + +#if defined(__HIP_PLATFORM_AMD__) && HIP_VERSION > 305 + cta.sync(); +#else + cg::sync(cta); +#endif + +#if 0 //(__CUDA_ARCH__ >= 300) + if (tid < 32) { + cg::coalesced_group active = cg::coalesced_threads(); + + // Fetch final intermediate sum from 2nd warp + if (blockSize >= 64) { + a_sum = a_sum + s_a[tid + 32]; + b_sum = b_sum + s_b[tid + 32]; + } + + // Reduce final warp using shuffle + for (int offset = warpSize / 2; offset > 0; offset /= 2) { + a_sum += active.shfl_down(a_sum, offset); + b_sum += active.shfl_down(b_sum, offset); + } + } +#else + if ((blockSize >= 64) && (tid < 32)) { + s_a[tid] = a_sum = a_sum + s_a[tid + 32]; + s_b[tid] = b_sum = b_sum + s_b[tid + 32]; + } + +#if defined(__HIP_PLATFORM_AMD__) && HIP_VERSION > 305 + cta.sync(); +#else + cg::sync(cta); +#endif + + if ((blockSize >= 32) && (tid < 16)) { + s_a[tid] = a_sum = a_sum + s_a[tid + 16]; + s_b[tid] = b_sum = b_sum + s_b[tid + 16]; + } + +#if defined(__HIP_PLATFORM_AMD__) && HIP_VERSION > 305 + cta.sync(); +#else + cg::sync(cta); +#endif + + if ((blockSize >= 16) && (tid < 8)) { + s_a[tid] = a_sum = a_sum + s_a[tid + 8]; + s_b[tid] = b_sum = b_sum + s_b[tid + 8]; + } + +#if defined(__HIP_PLATFORM_AMD__) && HIP_VERSION > 305 + cta.sync(); +#else + cg::sync(cta); +#endif + + if ((blockSize >= 8) && (tid < 4)) { + s_a[tid] = a_sum = a_sum + s_a[tid + 4]; + s_b[tid] = b_sum = b_sum + s_b[tid + 4]; + } + +#if defined(__HIP_PLATFORM_AMD__) && HIP_VERSION > 305 + cta.sync(); +#else + cg::sync(cta); +#endif + + if ((blockSize >= 4) && (tid < 2)) { + s_a[tid] = a_sum = a_sum + s_a[tid + 2]; + s_b[tid] = b_sum = b_sum + s_b[tid + 2]; + } + +#if defined(__HIP_PLATFORM_AMD__) && HIP_VERSION > 305 + cta.sync(); +#else + cg::sync(cta); +#endif + + if ((blockSize >= 2) && (tid < 1)) { + s_a[tid] = a_sum = a_sum + s_a[tid + 1]; + s_b[tid] = b_sum = b_sum + s_b[tid + 1]; + } + +#if defined(__HIP_PLATFORM_AMD__) && HIP_VERSION > 305 + cta.sync(); +#else + cg::sync(cta); +#endif + +#endif + + // write result for this block to global mem + if (tid == 0) { + g_a[blockIdx.x] = (T)a_sum; + g_b[blockIdx.x] = (T)b_sum; + } +} + +template +__device__ void reduce_two_vectors_in_register(T a, T b, T* g_a, T* g_b) +{ + const int threadIdInBlock = cg::this_thread_block().thread_rank(); + + T* s_a = SharedMemory(); + T* s_b = SharedMemory() + cg::this_thread_block().size(); + + s_a[threadIdInBlock] = a; + s_b[threadIdInBlock] = b; + + reduce_block_in_shared_memory(s_a, s_b, g_a, g_b); +} + +template +__global__ void lamb_cuda_kernel_part1( + T* __restrict__ p, + GRAD_T* __restrict__ p_copy, // For mixed precision training, pass NULL if not needed + T* __restrict__ m, + T* __restrict__ v, + const GRAD_T* __restrict__ g, + const float b1, + const float b2, + const float eps, + const float grad_scale, + const float step_size, + const size_t tsize, + adamMode_t mode, + const float decay, + T* __restrict__ w_l2_i, + T* __restrict__ u_l2_i) +{ + // Assuming 2D grids and 2D blocks + const int blockId = gridDim.x * blockIdx.y + blockIdx.x; + const int threadsPerBlock = blockDim.x * blockDim.y; + const int threadIdInBlock = cg::this_thread_block().thread_rank(); + const int i = (blockId * threadsPerBlock + threadIdInBlock); + const int totThreads = gridDim.x * gridDim.y * threadsPerBlock; + + T reg_w = 0; + T reg_u = 0; + + for (int j = i; j < tsize; j += totThreads) { + T scaled_grad = g[j] / grad_scale; + T pj = p[j]; + m[j] = b1 * m[j] + (1 - b1) * scaled_grad; + v[j] = b2 * v[j] + (1 - b2) * scaled_grad * scaled_grad; + float denom; + if (mode == ADAM_MODE_0) + denom = sqrtf(v[j] + eps); + else // Mode 1 + denom = sqrtf(v[j]) + eps; + T update = (m[j] / denom) + (decay * p[j]); + + reg_u += update * update; + reg_w += pj * pj; + } + + reduce_two_vectors_in_register(reg_w, reg_u, w_l2_i, u_l2_i); +} + +template +__global__ void lamb_cuda_kernel_part2(const size_t tsize, T* __restrict__ g_a, T* __restrict__ g_b) +{ + T* s_a = SharedMemory(); + T* s_b = SharedMemory() + cg::this_thread_block().size(); + + const int threadIdInBlock = cg::this_thread_block().thread_rank(); + + s_a[threadIdInBlock] = g_a[threadIdInBlock]; + s_b[threadIdInBlock] = g_b[threadIdInBlock]; + + if (threadIdInBlock >= tsize) { + s_a[threadIdInBlock] = 0.0; + s_b[threadIdInBlock] = 0.0; + } + + reduce_block_in_shared_memory(s_a, s_b, g_a, g_b); +} + +template +__global__ void lamb_cuda_kernel_part3( + T* __restrict__ p, + GRAD_T* __restrict__ p_copy, // For mixed precision training, pass NULL if not needed + T* __restrict__ m, + T* __restrict__ v, + const GRAD_T* __restrict__ g, + const float b1, + const float b2, + const float max_coeff, + const float min_coeff, + const float eps, + const float grad_scale, + const float step_size, + const size_t tsize, + adamMode_t mode, + const float decay, + T* __restrict__ w_l2_i, + T* __restrict__ u_l2_i, + T* __restrict__ lamb_coeff_val) +{ + // Assuming 2D grids and 2D blocks + const int blockId = gridDim.x * blockIdx.y + blockIdx.x; + const int threadsPerBlock = blockDim.x * blockDim.y; + const int threadIdInBlock = cg::this_thread_block().thread_rank(); + const int i = (blockId * threadsPerBlock + threadIdInBlock); + const int totThreads = gridDim.x * gridDim.y * threadsPerBlock; + + T reg_w = sqrtf(w_l2_i[0]); + T reg_u = sqrtf(u_l2_i[0]); + + float lamb_coeff = 1.0; + + if (reg_w != 0 && reg_u != 0) { + lamb_coeff = reg_w / reg_u; + if (lamb_coeff > max_coeff) { lamb_coeff = max_coeff; } + if (lamb_coeff < min_coeff) { lamb_coeff = min_coeff; } + } + + if (blockId == 0 && threadIdInBlock == 0) { + lamb_coeff_val[0] = lamb_coeff; + // printf("Cuda Lamb Coeff is %.6f \n",lamb_coeff); + } + + for (int j = i; j < tsize; j += totThreads) { + T pj = (float)p[j]; + T mj = m[j]; + T vj = v[j]; + float denom; + if (mode == ADAM_MODE_0) + denom = sqrtf(vj + eps); + else // Mode 1 + denom = sqrtf(vj) + eps; + T update = (mj / denom) + (decay * pj); + + pj = pj - (step_size * lamb_coeff * update); + p[j] = pj; + if (p_copy != NULL) p_copy[j] = (GRAD_T)pj; + } +} + +void fused_lamb_cuda(at::Tensor& p, + at::Tensor& p_copy, + at::Tensor& m, + at::Tensor& v, + at::Tensor& g, + float lr, + float beta1, + float beta2, + float max_coeff, + float min_coeff, + float eps, + float grad_scale, + int step, + int mode, + int bias_correction, + float decay, + at::Tensor& w_l2_i, + at::Tensor& u_l2_i, + at::Tensor& lamb_coeff) +{ + // using namespace at; + + // Get tensor size + int tsize = p.numel(); + // Determine #threads and #blocks + const int threadsPerBlock = 512; + int num_blocks = (tsize + threadsPerBlock - 1) / threadsPerBlock; + if (num_blocks > 512) num_blocks = 512; + + int smemsize = 0; + smemsize = 2 * threadsPerBlock * sizeof(float); + + const dim3 blocks(num_blocks); + const dim3 threads(threadsPerBlock); + + AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p), + "parameter tensor is too large to be indexed with int32"); + // Constants + float step_size = 0; + if (bias_correction == 1) { + const float bias_correction1 = 1 - std::pow(beta1, step); + const float bias_correction2 = 1 - std::pow(beta2, step); + step_size = lr * std::sqrt(bias_correction2) / bias_correction1; + } else { + step_size = lr; + } + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + if (g.type().scalarType() == at::ScalarType::Half) { + // all other values should be fp32 for half gradients + AT_ASSERTM(p.type().scalarType() == at::ScalarType::Float, + "expected parameter to be of float type"); + // dispatch is done on the gradient type + using namespace at; // prevents "toString is undefined" errors + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + g.scalar_type(), "lamb_cuda_kernel", ([&] { + using accscalar_t = at::acc_type; + + lamb_cuda_kernel_part1 + <<>>( + p.data(), + p_copy.numel() ? p_copy.data() : NULL, + m.data(), + v.data(), + g.data(), + beta1, + beta2, + eps, + grad_scale, + step_size, + tsize, + (adamMode_t)mode, + decay, + w_l2_i.data(), + u_l2_i.data()); + + lamb_cuda_kernel_part2 + <<<1, threadsPerBlock, smemsize, stream>>>( + num_blocks, w_l2_i.data(), u_l2_i.data()); + + lamb_cuda_kernel_part3 + <<>>( + p.data(), + p_copy.numel() ? p_copy.data() : NULL, + m.data(), + v.data(), + g.data(), + beta1, + beta2, + max_coeff, + min_coeff, + eps, + grad_scale, + step_size, + tsize, + (adamMode_t)mode, + decay, + w_l2_i.data(), + u_l2_i.data(), + lamb_coeff.data()); + })); + } else { + using namespace at; + AT_DISPATCH_FLOATING_TYPES( + g.scalar_type(), "lamb_cuda_kernel", ([&] { + lamb_cuda_kernel_part1 + <<>>( + p.data(), + NULL, // don't output p_copy for fp32, it's wasted write + m.data(), + v.data(), + g.data(), + beta1, + beta2, + eps, + grad_scale, + step_size, + tsize, + (adamMode_t)mode, + decay, + w_l2_i.data(), + u_l2_i.data()); + + lamb_cuda_kernel_part2 + <<<1, threadsPerBlock, smemsize, stream>>>( + num_blocks, w_l2_i.data(), u_l2_i.data()); + + lamb_cuda_kernel_part3 + <<>>( + p.data(), + NULL, // don't output p_copy for fp32, it's wasted write + m.data(), + v.data(), + g.data(), + beta1, + beta2, + max_coeff, + min_coeff, + eps, + grad_scale, + step_size, + tsize, + (adamMode_t)mode, + decay, + w_l2_i.data(), + u_l2_i.data(), + lamb_coeff.data()); + })); + } + C10_CUDA_CHECK(cudaGetLastError()); +} + +// template __device__ void reduce_two_vectors_in_register(float a, float b, float* g_a, +// float* g_b, cg::grid_group &cgg); \ No newline at end of file diff --git a/toolbox/DeepSpeed/v0.15.3/patches/csrc/layernorm/layer_norm_cuda.cpp b/toolbox/DeepSpeed/v0.15.3/patches/csrc/layernorm/layer_norm_cuda.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5b5b091cce1692abd41076382df167c4270f9b93 --- /dev/null +++ b/toolbox/DeepSpeed/v0.15.3/patches/csrc/layernorm/layer_norm_cuda.cpp @@ -0,0 +1,578 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +/* +Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may +not use this file except in compliance with the License. You may obtain +a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include +#include +#include +#include "compat.h" + +namespace { +void compute_n1_n2( + at::Tensor input, + #ifdef VERSION_GE_1_1 + at::IntArrayRef normalized_shape, + #else + at::IntList normalized_shape, + #endif + int& n1, + int& n2) +{ + int idiff = input.ndimension() - normalized_shape.size(); + n2 = 1; + for (int i = 0; i < (int)normalized_shape.size(); ++i) { + assert( input.sizes()[i+idiff] == normalized_shape[i] ); + n2 *= normalized_shape[i]; + } + n1 = 1; + for (int i = 0; i < idiff; ++i) { + n1 *= input.sizes()[i]; + } +} + +void check_args( + #ifdef VERSION_GE_1_1 + at::IntArrayRef normalized_shape, + #else + at::IntList normalized_shape, + #endif + at::Tensor gamma, + at::Tensor beta + ) +{ + TORCH_CHECK(!gamma.defined() || gamma.sizes().equals(normalized_shape)); + TORCH_CHECK(!beta.defined() || beta.sizes().equals(normalized_shape)); +} + +void check_args( + #ifdef VERSION_GE_1_1 + at::IntArrayRef normalized_shape, + #else + at::IntList normalized_shape, + #endif + at::Tensor gamma + ) +{ + TORCH_CHECK(!gamma.defined() || gamma.sizes().equals(normalized_shape)); +} + + +void check_args( + at::Tensor input, + #ifdef VERSION_GE_1_1 + at::IntArrayRef normalized_shape, + #else + at::IntList normalized_shape, + #endif + int& n1, + int& n2 + ) +{ + int64_t normalized_ndim = normalized_shape.size(); + + if (normalized_ndim < 1) { + std::stringstream ss; + ss << "Expected normalized_shape to be at least 1-dimensional, i.e., " + << "containing at least one element, but got normalized_shape=" + << normalized_shape; + throw std::runtime_error(ss.str()); + } + + auto input_shape = input.sizes(); + auto input_ndim = input.dim(); + + if (input_ndim < normalized_ndim || + !input_shape.slice(input_ndim - normalized_ndim).equals(normalized_shape)) { + std::stringstream ss; + ss << "Given normalized_shape=" << normalized_shape + << ", expected input with shape [*"; + for (auto size : normalized_shape) { + ss << ", " << size; + } + ss << "], but got input of size" << input_shape; + throw std::runtime_error(ss.str()); + } + + compute_n1_n2(input,normalized_shape,n1,n2); +} + +void check_args( + at::Tensor input, + #ifdef VERSION_GE_1_1 + at::IntArrayRef normalized_shape, + #else + at::IntList normalized_shape, + #endif + at::Tensor gamma, + at::Tensor beta, + int& n1, + int& n2 + ) +{ + check_args(input,normalized_shape,n1,n2); + check_args(normalized_shape,gamma,beta); +} + +void check_args( + at::Tensor input, + #ifdef VERSION_GE_1_1 + at::IntArrayRef normalized_shape, + #else + at::IntList normalized_shape, + #endif + at::Tensor gamma, + int& n1, + int& n2 + ) +{ + check_args(input,normalized_shape,n1,n2); + check_args(normalized_shape,gamma); +} +} + +void cuda_layer_norm( + at::Tensor* output, + at::Tensor* mean, + at::Tensor* invvar, + at::Tensor* input, + int n1, + int n2, + #ifdef VERSION_GE_1_1 + at::IntArrayRef normalized_shape, + #else + at::IntList normalized_shape, + #endif + at::Tensor* gamma, + at::Tensor* beta, + float epsilon); + +#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + +std::vector layer_norm( + at::Tensor input, + #ifdef VERSION_GE_1_1 + at::IntArrayRef normalized_shape, + #else + at::IntList normalized_shape, + #endif + float epsilon) { + CHECK_INPUT(input); + int n1,n2; + check_args(input,normalized_shape,n1,n2); + at::Tensor output = at::empty_like(input); + at::Tensor mean = at::empty({n1}, input.options().dtype(input.scalar_type()==at::ScalarType::Half || input.scalar_type()==at::ScalarType::BFloat16 ? at::ScalarType::Float : input.scalar_type())); + at::Tensor invvar = at::empty_like(mean); + cuda_layer_norm(&output,&mean,&invvar,&input,n1,n2, + normalized_shape,NULL,NULL,epsilon); + return {output, mean, invvar}; +} + +std::vector layer_norm_affine( + at::Tensor input, + #ifdef VERSION_GE_1_1 + at::IntArrayRef normalized_shape, + #else + at::IntList normalized_shape, + #endif + at::Tensor gamma, + at::Tensor beta, + float epsilon) { + CHECK_INPUT(input); + CHECK_INPUT(gamma); + CHECK_INPUT(beta); + int n1,n2; + check_args(input,normalized_shape,gamma,beta,n1,n2); + at::Tensor output = at::empty_like(input); + const auto stats_dtype = (input.scalar_type() == at::ScalarType::Half || input.scalar_type() == at::ScalarType::BFloat16) ? at::ScalarType::Float : input.scalar_type(); + at::Tensor mean = at::empty({n1}, input.options().dtype(stats_dtype)); + at::Tensor invvar = at::empty_like(mean); + cuda_layer_norm(&output,&mean,&invvar,&input,n1,n2, + normalized_shape,&gamma,&beta,epsilon); + return {output, mean, invvar}; +} + +std::vector layer_norm_affine_mixed_dtypes( + at::Tensor input, + #ifdef VERSION_GE_1_1 + at::IntArrayRef normalized_shape, + #else + at::IntList normalized_shape, + #endif + at::Tensor gamma, + at::Tensor beta, + double epsilon) { + CHECK_INPUT(input); + int n1, n2; + check_args(input, normalized_shape, n1, n2); + at::Tensor output = at::empty_like(input, gamma.options().dtype(gamma.scalar_type())); + at::Tensor mean = at::empty({n1}, input.options().dtype(input.scalar_type() == at::ScalarType::Half || input.scalar_type() == at::ScalarType::BFloat16 ? at::ScalarType::Float : input.scalar_type())); + at::Tensor invvar = at::empty_like(mean); + cuda_layer_norm(&output, &mean, &invvar, &input, n1, n2, + normalized_shape, &gamma, &beta, epsilon); + return {output, mean, invvar}; +} + +void cuda_layer_norm_gradient( + at::Tensor* dout, + at::Tensor* mean, + at::Tensor* invvar, + at::Tensor* input, + int n1, + int n2, + #ifdef VERSION_GE_1_1 + at::IntArrayRef normalized_shape, + #else + at::IntList normalized_shape, + #endif + at::Tensor* gamma, + at::Tensor* beta, + float epsilon, + at::Tensor* grad_input, + at::Tensor* grad_gamma, + at::Tensor* grad_beta, + bool memory_efficient + ); + +at::Tensor layer_norm_gradient( + at::Tensor dout, + c10::optional mean, + at::Tensor invvar, + at::Tensor input, + #ifdef VERSION_GE_1_1 + at::IntArrayRef normalized_shape, + #else + at::IntList normalized_shape, + #endif + float epsilon, + bool memory_efficient) { + CHECK_INPUT(dout); + CHECK_INPUT(invvar); + CHECK_INPUT(input); + int n1,n2; + check_args(input,normalized_shape,n1,n2); + at::Tensor grad_input = at::empty_like(input); + if (mean.has_value()) { + cuda_layer_norm_gradient(&dout,&mean.value(),&invvar,&input,n1,n2, + normalized_shape,NULL,NULL,epsilon, + &grad_input,NULL,NULL,memory_efficient); + } else { + cuda_layer_norm_gradient(&dout,NULL,&invvar,&input,n1,n2, + normalized_shape,NULL,NULL,epsilon, + &grad_input,NULL,NULL,memory_efficient); + } + return grad_input; +} + +std::vector layer_norm_gradient_affine( + at::Tensor dout, + c10::optional mean, + at::Tensor invvar, + at::Tensor input, + #ifdef VERSION_GE_1_1 + at::IntArrayRef normalized_shape, + #else + at::IntList normalized_shape, + #endif + at::Tensor gamma, + at::Tensor beta, + float epsilon, + bool memory_efficient) { + CHECK_INPUT(dout); + CHECK_INPUT(invvar); + CHECK_INPUT(input); + CHECK_INPUT(gamma); + CHECK_INPUT(beta); + int n1,n2; + check_args(input,normalized_shape,gamma,beta,n1,n2); + at::Tensor grad_input = at::empty_like(input); + at::Tensor grad_gamma = at::empty_like(gamma); + at::Tensor grad_beta = at::empty_like(beta); + if (mean.has_value()) { + cuda_layer_norm_gradient(&dout,&mean.value(),&invvar,&input,n1,n2, + normalized_shape,&gamma,&beta,epsilon, + &grad_input,&grad_gamma,&grad_beta,memory_efficient); + } else { + cuda_layer_norm_gradient(&dout,NULL,&invvar,&input,n1,n2, + normalized_shape,&gamma,&beta,epsilon, + &grad_input,&grad_gamma,&grad_beta,memory_efficient); + } + return {grad_input, grad_gamma, grad_beta}; +} + +void cuda_rms_norm( + at::Tensor* output, + at::Tensor* invvar, + at::Tensor* input, + int n1, + int n2, + #ifdef VERSION_GE_1_1 + at::IntArrayRef normalized_shape, + #else + at::IntList normalized_shape, + #endif + at::Tensor* gamma, + double epsilon); + +#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + +std::vector rms_norm( + at::Tensor input, + #ifdef VERSION_GE_1_1 + at::IntArrayRef normalized_shape, + #else + at::IntList normalized_shape, + #endif + double epsilon) { + CHECK_INPUT(input); + int n1,n2; + check_args(input,normalized_shape,n1,n2); + at::Tensor output = at::empty_like(input); + at::Tensor invvar = at::empty({n1}, input.options().dtype(input.scalar_type()==at::ScalarType::Half || input.scalar_type()==at::ScalarType::BFloat16 ? at::ScalarType::Float : input.scalar_type())); + cuda_rms_norm(&output,&invvar,&input,n1,n2, + normalized_shape,NULL,epsilon); + return {output, invvar}; +} + +std::vector rms_norm_affine( + at::Tensor input, + #ifdef VERSION_GE_1_1 + at::IntArrayRef normalized_shape, + #else + at::IntList normalized_shape, + #endif + at::Tensor gamma, + double epsilon) { + CHECK_INPUT(input); + CHECK_INPUT(gamma); + int n1,n2; + check_args(input,normalized_shape,gamma,n1,n2); + at::Tensor output = at::empty_like(input); + const auto stats_dtype = (input.scalar_type() == at::ScalarType::Half || input.scalar_type() == at::ScalarType::BFloat16) ? at::ScalarType::Float : input.scalar_type(); + at::Tensor invvar = at::empty({n1}, input.options().dtype(stats_dtype)); + cuda_rms_norm(&output,&invvar,&input,n1,n2, + normalized_shape,&gamma,epsilon); + return {output, invvar}; +} + +std::vector rms_norm_affine_mixed_dtypes( + at::Tensor input, + #ifdef VERSION_GE_1_1 + at::IntArrayRef normalized_shape, + #else + at::IntList normalized_shape, + #endif + at::Tensor gamma, + double epsilon) { + CHECK_INPUT(input); + int n1, n2; + check_args(input, normalized_shape, n1, n2); + at::Tensor output = at::empty_like(input, gamma.options().dtype(gamma.scalar_type())); + at::Tensor invvar = at::empty({n1}, input.options().dtype(input.scalar_type() == at::ScalarType::Half || input.scalar_type() == at::ScalarType::BFloat16 ? at::ScalarType::Float : input.scalar_type())); + + cuda_rms_norm(&output,&invvar, &input, n1, n2, + normalized_shape, &gamma,epsilon); + return {output,invvar}; +} + +void cuda_rms_norm_residual( + at::Tensor* output, + at::Tensor* sum, + at::Tensor* invvar, + at::Tensor* input, + at::Tensor* residual, + int n1, + int n2, + #ifdef VERSION_GE_1_1 + at::IntArrayRef normalized_shape, + #else + at::IntList normalized_shape, + #endif + at::Tensor* gamma, + double epsilon); + +std::vector rms_norm_pre_norm_residual_forward( + at::Tensor input, + at::Tensor residual, + #ifdef VERSION_GE_1_1 + at::IntArrayRef normalized_shape, + #else + at::IntList normalized_shape, + #endif + at::Tensor gamma, + double epsilon) { + CHECK_INPUT(input); + CHECK_INPUT(residual); + int n1, n2; + check_args(input, normalized_shape, n1, n2); + at::Tensor output = at::empty_like(input, gamma.options().dtype(gamma.scalar_type())); + at::Tensor sum = at::empty_like(input, gamma.options().dtype(gamma.scalar_type())); + at::Tensor invvar = at::empty({n1}, input.options().dtype(input.scalar_type() == at::ScalarType::Half || input.scalar_type() == at::ScalarType::BFloat16 ? at::ScalarType::Float : input.scalar_type())); + + cuda_rms_norm_residual(&output,&sum,&invvar, &input, &residual, n1, n2, + normalized_shape, &gamma,epsilon); + return {output,invvar,sum}; +} + +void cuda_rms_norm_gradient( + at::Tensor* dout, + at::Tensor* invvar, + at::Tensor* input, + int n1, + int n2, + #ifdef VERSION_GE_1_1 + at::IntArrayRef normalized_shape, + #else + at::IntList normalized_shape, + #endif + at::Tensor* gamma, + double epsilon, + at::Tensor* grad_input, + at::Tensor* grad_gamma, + bool memory_efficient); + +at::Tensor rms_norm_gradient( + at::Tensor dout, + at::Tensor invvar, + at::Tensor input, + #ifdef VERSION_GE_1_1 + at::IntArrayRef normalized_shape, + #else + at::IntList normalized_shape, + #endif + double epsilon, + bool memory_efficient) { + CHECK_INPUT(dout); + CHECK_INPUT(invvar); + CHECK_INPUT(input); + int n1,n2; + check_args(input,normalized_shape,n1,n2); + at::Tensor grad_input = at::empty_like(input); + cuda_rms_norm_gradient(&dout,&invvar,&input,n1,n2, + normalized_shape,NULL,epsilon, + &grad_input,NULL,memory_efficient); + return grad_input; +} + +std::vector rms_norm_gradient_affine( + at::Tensor dout, + at::Tensor invvar, + at::Tensor input, + #ifdef VERSION_GE_1_1 + at::IntArrayRef normalized_shape, + #else + at::IntList normalized_shape, + #endif + at::Tensor gamma, + double epsilon, + bool memory_efficient) { + CHECK_INPUT(dout); + CHECK_INPUT(invvar); + CHECK_INPUT(input); + CHECK_INPUT(gamma); + int n1,n2; + check_args(input,normalized_shape,gamma,n1,n2); + at::Tensor grad_input = at::empty_like(input); + at::Tensor grad_gamma = at::empty_like(gamma); + cuda_rms_norm_gradient(&dout,&invvar,&input,n1,n2, + normalized_shape,&gamma,epsilon, + &grad_input,&grad_gamma,memory_efficient); + return {grad_input, grad_gamma}; +} + +void cuda_rms_norm_residual_gradient( + at::Tensor* dout, + at::Tensor* dres, + at::Tensor* invvar, + at::Tensor* input, + int n1, + int n2, + #ifdef VERSION_GE_1_1 + at::IntArrayRef normalized_shape, + #else + at::IntList normalized_shape, + #endif + at::Tensor* gamma, + double epsilon, + at::Tensor* grad_input, + at::Tensor* grad_gamma, + bool memory_efficient); + + +std::vector rms_norm_pre_norm_residual_backward( + at::Tensor dout, + at::Tensor dres, + at::Tensor invvar, + at::Tensor input, + #ifdef VERSION_GE_1_1 + at::IntArrayRef normalized_shape, + #else + at::IntList normalized_shape, + #endif + at::Tensor gamma, + double epsilon, + bool memory_efficient) { + CHECK_INPUT(dout); + CHECK_INPUT(dres); + CHECK_INPUT(invvar); + CHECK_INPUT(input); + CHECK_INPUT(gamma); + int n1,n2; + check_args(input,normalized_shape,gamma,n1,n2); + at::Tensor grad_input = at::empty_like(input); + at::Tensor grad_gamma = at::empty_like(gamma); + cuda_rms_norm_residual_gradient(&dout,&dres,&invvar,&input,n1,n2, + normalized_shape,&gamma,epsilon, + &grad_input,&grad_gamma,memory_efficient); + return {grad_input, grad_gamma}; +} + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward_affine", &layer_norm_affine, "LayerNorm forward (CUDA)"); + m.def("forward", &layer_norm, "LayerNorm forward (CUDA)"); + m.def("backward_affine", &layer_norm_gradient_affine, "LayerNorm backward (CUDA)"); + m.def("backward", &layer_norm_gradient, "LayerNorm backward (CUDA)"); + + m.def("forward_affine_mixed_dtypes", &layer_norm_affine_mixed_dtypes, "LayerNorm forward with mixed dtypes (CUDA) compatible with Megatron's implementation"); + + m.def("rms_forward_affine", &rms_norm_affine, "RMSNorm forward (CUDA)"); + m.def("rms_forward", &rms_norm, "RMSNorm forward (CUDA)"); + m.def("rms_backward_affine", &rms_norm_gradient_affine, "RMSNorm backward (CUDA)"); + m.def("rms_backward", &rms_norm_gradient, "RMSNorm backward (CUDA)"); + + m.def("rms_forward_affine_mixed_dtypes", &rms_norm_affine_mixed_dtypes, "RMSNorm forward with mixed dtypes (CUDA) compatible with Megatron's implementation"); + m.def("rms_pre_norm_residual_forward", &rms_norm_pre_norm_residual_forward, "RMSNorm forward with pre norm residual"); + m.def("rms_pre_norm_residual_backward", &rms_norm_pre_norm_residual_backward, "RMSNorm backward with pre norm residual"); +} \ No newline at end of file diff --git a/toolbox/DeepSpeed/v0.15.3/patches/csrc/layernorm/layer_norm_cuda_kernel.cu b/toolbox/DeepSpeed/v0.15.3/patches/csrc/layernorm/layer_norm_cuda_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..d4193eda8a797cf645d231cddc83f1ff9e7168c3 --- /dev/null +++ b/toolbox/DeepSpeed/v0.15.3/patches/csrc/layernorm/layer_norm_cuda_kernel.cu @@ -0,0 +1,3484 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +/* +Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may +not use this file except in compliance with the License. You may obtain +a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include "ATen/ATen.h" +#include "ATen/AccumulateType.h" +#include "ATen/cuda/CUDAContext.h" +#include "ATen/cuda/DeviceUtils.cuh" + +#include +#include + +#include "type_shim_rope.h" + +#define BOOL_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + if (COND) { \ + constexpr static bool CONST_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }() + + +template __device__ +void cuWelfordOnlineSum( + const U curr, + U& mu, + U& sigma2, + U& count) +{ + count = count + U(1); + U delta = curr - mu; + U lmean = mu + delta / count; + mu = lmean; + U delta2 = curr - lmean; + sigma2 = sigma2 + delta * delta2; +} + +template __device__ +void cuChanOnlineSum( + const U muB, + const U sigma2B, + const U countB, + U& mu, + U& sigma2, + U& count) +{ + U delta = muB - mu; + U nA = count; + U nB = countB; + count = count + countB; + U nX = count; + if (nX > U(0)) { + nA = nA / nX; + nB = nB / nX; + mu = nA*mu + nB*muB; + sigma2 = sigma2 + sigma2B + delta * delta * nA * nB * nX; + } else { + mu = U(0); + sigma2 = U(0); + } +} + +template __device__ +void cuRMSOnlineSum( + const U curr, + U& sigma2) +{ + sigma2 = sigma2 + curr * curr; +} + +template __device__ +void cuChanRMSOnlineSum( + const U sigma2B, + U& sigma2) +{ + sigma2 = sigma2 + sigma2B; +} + + +template __device__ +void cuWelfordMuSigma2( + const T* __restrict__ vals, + const int n1, + const int n2, + const int i1, + U& mu, + U& sigma2, + U* buf, + bool rms_only) +{ + // Assumptions: + // 1) blockDim.x == warpSize + // 2) Tensor is contiguous + // 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available. + // + // compute variance and mean over n2 + U count = U(0); + mu= U(0); + sigma2 = U(0); + if (i1 < n1) { + // one warp normalizes one n1 index, + // synchronization is implicit + // initialize with standard Welford algorithm + const int numx = blockDim.x * blockDim.y; + const int thrx = threadIdx.x + threadIdx.y * blockDim.x; + const T* lvals = vals + i1*n2; + int l = 4*thrx; + for (; l+3 < n2; l+=4*numx) { + for (int k = 0; k < 4; ++k) { + U curr = static_cast(lvals[l+k]); + if (!rms_only) { + cuWelfordOnlineSum(curr,mu,sigma2,count); + } else { + cuRMSOnlineSum(curr, sigma2); + } + } + } + for (; l < n2; ++l) { + U curr = static_cast(lvals[l]); + if (!rms_only) { + cuWelfordOnlineSum(curr,mu,sigma2,count); + } else { + cuRMSOnlineSum(curr, sigma2); + } + } + // intra-warp reductions + for (int l = 0; l <= 4; ++l) { + int srcLaneB = (threadIdx.x+(1<(muB,sigma2B,countB,mu,sigma2,count); + } else { + cuChanRMSOnlineSum(sigma2B, sigma2); + } + } + // threadIdx.x == 0 has correct values for each warp + // inter-warp reductions + if (blockDim.y > 1) { + U* ubuf = (U*)buf; + U* ibuf = (U*)(ubuf + blockDim.y); + for (int offset = blockDim.y/2; offset > 0; offset /= 2) { + // upper half of warps write to shared + if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2*offset) { + const int wrt_y = threadIdx.y - offset; + if (!rms_only) { + ubuf[2*wrt_y] = mu; + ibuf[wrt_y] = count; + } + ubuf[2*wrt_y+1] = sigma2; + } + __syncthreads(); + // lower half merges + if (threadIdx.x == 0 && threadIdx.y < offset) { + U sigma2B = ubuf[2*threadIdx.y+1]; + if (!rms_only) { + U muB = ubuf[2*threadIdx.y]; + U countB = ibuf[threadIdx.y]; + cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count); + } else { + cuChanRMSOnlineSum(sigma2B,sigma2); + } + } + __syncthreads(); + } + // threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values + if (threadIdx.x == 0 && threadIdx.y == 0) { + if (!rms_only) { + ubuf[0] = mu; + } + ubuf[1] = sigma2; + } + __syncthreads(); + if (!rms_only) { + mu = ubuf[0]; + } + sigma2 = ubuf[1]/U(n2); + // don't care about final value of count, we know count == n2 + } else { + if (!rms_only) { +#ifndef __ILUVATAR__ + mu = WARP_SHFL(mu, 0); +#else + mu = WARP_SHFL(mu, 0, 32); +#endif + } +#ifndef __ILUVATAR__ + sigma2 = WARP_SHFL(sigma2/U(n2), 0); +#else + sigma2 = WARP_SHFL(sigma2/U(n2), 0, 32); +#endif + } + } +} + +template<> __device__ +void cuWelfordMuSigma2( + const at::Half* __restrict__ vals, + const int n1, + const int n2, + const int i1, + float& mu, + float& sigma2, + float* buf, + bool rms_only) +{ + // Assumptions: + // 1) blockDim.x == warpSize + // 2) Tensor is contiguous + // 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available. + // + // compute variance and mean over n2 + float count = 0.0f; + mu= float(0); + sigma2 = float(0); + if (i1 < n1) { + // one warp normalizes one n1 index, + // synchronization is implicit + // initialize with standard Welford algorithm + const int numx = blockDim.x * blockDim.y; + const int thrx = threadIdx.x + threadIdx.y * blockDim.x; + const at::Half* lvals = vals + i1*n2; + int l = 8*thrx; + if ((((size_t)lvals)&3) != 0) { + // 16 bit alignment + // first thread consumes first point + if (thrx == 0) { + float curr = static_cast(lvals[0]); + if (!rms_only) { + cuWelfordOnlineSum(curr,mu,sigma2,count); + } else { + cuRMSOnlineSum(curr, sigma2); + } + + } + ++l; + } + // at this point, lvals[l] are 32 bit aligned for all threads. + for (; l+7 < n2; l+=8*numx) { + for (int k = 0; k < 8; k+=2) { + float2 curr = __half22float2(*((__half2*)(lvals+l+k))); + if (!rms_only) { +#ifndef __ILUVATAR__ + cuWelfordOnlineSum(curr.x,mu,sigma2,count); + cuWelfordOnlineSum(curr.y,mu,sigma2,count); +#else + cuWelfordOnlineSum(curr.x,mu,sigma2,count); + cuWelfordOnlineSum(curr.y,mu,sigma2,count); +#endif + } else { + cuRMSOnlineSum(curr.x, sigma2); + cuRMSOnlineSum(curr.y, sigma2); + } + } + } + for (; l < n2; ++l) { + float curr = static_cast(lvals[l]); + if (!rms_only) { + cuWelfordOnlineSum(curr,mu,sigma2,count); + } else { + cuRMSOnlineSum(curr, sigma2); + } + } + // intra-warp reductions + for (int l = 0; l <= 4; ++l) { + int srcLaneB = (threadIdx.x+(1< 1) { + float* ubuf = (float*)buf; + float* ibuf = (float*)(ubuf + blockDim.y); + for (int offset = blockDim.y/2; offset > 0; offset /= 2) { + // upper half of warps write to shared + if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2*offset) { + const int wrt_y = threadIdx.y - offset; + ubuf[2*wrt_y+1] = sigma2; + if (!rms_only) { + ubuf[2*wrt_y] = mu; + ibuf[wrt_y] = count; + } + } + __syncthreads(); + // lower half merges + if (threadIdx.x == 0 && threadIdx.y < offset) { + float sigma2B = ubuf[2*threadIdx.y+1]; + if (!rms_only) { + float muB = ubuf[2*threadIdx.y]; + float countB = ibuf[threadIdx.y]; + cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count); + } else { + cuChanRMSOnlineSum(sigma2B, sigma2); + } + } + __syncthreads(); + } + // threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values + if (threadIdx.x == 0 && threadIdx.y == 0) { + if (!rms_only) { + ubuf[0] = mu; + } + ubuf[1] = sigma2; + } + __syncthreads(); + if (!rms_only) { + mu = ubuf[0]; + } + sigma2 = ubuf[1]/float(n2); + // don't care about final value of count, we know count == n2 + } else { + if (!rms_only) { +#ifndef __ILUVATAR__ + mu = WARP_SHFL(mu, 0); +#else + mu = WARP_SHFL(mu, 0, 32); +#endif + } +#ifndef __ILUVATAR__ + sigma2 = WARP_SHFL(sigma2/float(n2), 0); +#else + sigma2 = WARP_SHFL(sigma2/float(n2), 0, 32); +#endif + } + } +} + +template __device__ +void cuWelfordMuSigma2_opt( + const T* __restrict__ vals, + const int n1, + const int n2, + const int i1, + U& mu, + U& sigma2, + U* buf, + bool rms_only) +{ + U count = U(0); + mu= U(0); + sigma2 = U(0); + const int numx = blockDim.x * blockDim.y; + const int tid = threadIdx.x + threadIdx.y * blockDim.x; + const T* lvals = vals + i1*n2; + + #pragma unroll + for (int l = tid;l < n2;l+=numx){ + U curr = static_cast(lvals[l]); + if (!rms_only) { + cuWelfordOnlineSum(curr, mu, sigma2, count); + } else { + cuRMSOnlineSum(curr, sigma2); + } + } + + U muB; + U sigma2B; + U countB; + #pragma unroll + for (int offset=32;offset>0;offset/=2) { + if (rms_only) { + sigma2B = __shfl_xor_sync(0xffffffff, sigma2, offset, 64); + sigma2 += sigma2B; + } else { + muB = __shfl_xor_sync(0xffffffff, mu, offset, 64); + sigma2B = __shfl_xor_sync(0xffffffff, sigma2, offset, 64); + countB = __shfl_xor_sync(0xffffffff, count, offset, 64); + cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count); + } + } + + if (blockDim.y > 1) { + if (rms_only) { + if (tid % 64 == 0) { + buf[tid/64] = sigma2; + } + __syncthreads(); + sigma2 = buf[0]; + for (int i=1;i0;offset/=2) { + if (tid < offset) { + cuChanOnlineSum(buf[tid*3+offset*3],buf[tid*3+offset*3+1],buf[tid*3+offset*3+2],buf[tid*3],buf[tid*3+1],buf[tid*3+2]); + } + __syncthreads(); + } + mu = buf[0]; + sigma2 = buf[1]/U(n2); + } + } else { + sigma2 /= U(n2); + } +} + + +template<> __device__ +void cuWelfordMuSigma2_opt( + const at::Half* __restrict__ vals, + const int n1, + const int n2, + const int i1, + float& mu, + float& sigma2, + float* buf, + bool rms_only) +{ + typedef unsigned v4u32 __attribute__((ext_vector_type(4))); + mu = float(0); + sigma2 = float(0); + float count = float(0); + const int numx = blockDim.x * blockDim.y; + const int tid = threadIdx.x + threadIdx.y * blockDim.x; + + const float* lvals = reinterpret_cast(vals + i1*n2); + + float curr = float(0); + at::Half* curr1 = reinterpret_cast(&curr); + + v4u32 aBase; + aBase.x = (unsigned)(unsigned long long)lvals; + aBase.y = (unsigned)((unsigned long long)lvals >> 32); + aBase.zw = -1u; + + #pragma unroll + for (int l = 0;l < n2/(numx*2);l++){ + curr = __ivcorex_ml_mem_load_f32(aBase, 4 * (tid+l*numx), 0, 0); + if (rms_only) { + cuRMSOnlineSum(static_cast(curr1[0]), sigma2); + cuRMSOnlineSum(static_cast(curr1[1]), sigma2); + } else { + cuWelfordOnlineSum(static_cast(curr1[0]), mu, sigma2, count); + cuWelfordOnlineSum(static_cast(curr1[1]), mu, sigma2, count); + } + } + + float muB; + float sigma2B; + float countB; + #pragma unroll + for (int offset=32;offset>0;offset/=2) { + if (rms_only) { + sigma2B = __shfl_xor_sync(0xffffffff, sigma2, offset, 64); + sigma2 += sigma2B; + } else { + muB = __shfl_xor_sync(0xffffffff, mu, offset, 64); + sigma2B = __shfl_xor_sync(0xffffffff, sigma2, offset, 64); + countB = __shfl_xor_sync(0xffffffff, count, offset, 64); + cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count); + } + } + + if (blockDim.y > 1) { + if (rms_only) { + if (tid % 64 == 0) { + buf[tid/64] = sigma2; + } + __syncthreads(); + sigma2 = buf[0]; + for (int i=1;i0;offset/=2) { + if (tid < offset) { + cuChanOnlineSum(buf[tid*3+offset*3],buf[tid*3+offset*3+1],buf[tid*3+offset*3+2],buf[tid*3],buf[tid*3+1],buf[tid*3+2]); + } + __syncthreads(); + } + mu = buf[0]; + sigma2 = buf[1]/float(n2); + } + } else { + sigma2 /= float(n2); + } +} + +template U rsqrt(U v) { + return U(1) / sqrt(v); +} +template<> float rsqrt(float v) { + return rsqrtf(v); +} +template<> double rsqrt(double v) { + return rsqrt(v); +} + +namespace { +// This is the un-specialized struct. Note that we prevent instantiation of this +// struct by putting an undefined symbol in the function body so it won't compile. +// template +// struct SharedMemory +// { +// // Ensure that we won't compile any un-specialized types +// __device__ T *getPointer() +// { +// extern __device__ void error(void); +// error(); +// return NULL; +// } +// }; +// https://github.com/NVIDIA/apex/issues/246 +template +struct SharedMemory; + +template <> +struct SharedMemory +{ + __device__ float *getPointer() + { + extern __shared__ float s_float[]; + return s_float; + } +}; + +template <> +struct SharedMemory +{ + __device__ double *getPointer() + { + extern __shared__ double s_double[]; + return s_double; + } +}; +} + +template __device__ +void cuApplyLayerNorm_( + V* __restrict__ output_vals, + U* __restrict__ mean, + U* __restrict__ invvar, + const T* __restrict__ vals, + const int n1, + const int n2, + const U epsilon, + const V* __restrict__ gamma, + const V* __restrict__ beta, + bool rms_only + ) +{ + // Assumptions: + // 1) blockDim.x == warpSize + // 2) Tensors are contiguous + // + for (auto i1=blockIdx.y; i1 < n1; i1 += gridDim.y) { + SharedMemory shared; + U* buf = shared.getPointer(); + U mu,sigma2; + // cuWelfordMuSigma2(vals,n1,n2,i1,mu,sigma2,buf,rms_only); + cuWelfordMuSigma2_opt(vals,n1,n2,i1,mu,sigma2,buf,rms_only); + + const T* lvals = vals + i1*n2; + V* ovals = output_vals + i1*n2; + U c_invvar = rsqrt(sigma2 + epsilon); + const int numx = blockDim.x * blockDim.y; + const int thrx = threadIdx.x + threadIdx.y * blockDim.x; + if (gamma != NULL && (beta != NULL || rms_only)) { + for (int i = thrx; i < n2; i+=numx) { + U curr = static_cast(lvals[i]); + if (!rms_only) { + ovals[i] = gamma[i] * static_cast(c_invvar * (curr - mu)) + beta[i]; + } else { + ovals[i] = gamma[i] * static_cast(c_invvar * curr); + } + + } + } else { + for (int i = thrx; i < n2; i+=numx) { + U curr = static_cast(lvals[i]); + if (!rms_only) { + ovals[i] = static_cast(c_invvar * (curr - mu)); + } else { + ovals[i] = static_cast(c_invvar * curr); + } + } + } + if (threadIdx.x == 0 && threadIdx.y == 0) { + if (!rms_only) { + mean[i1] = mu; + } + invvar[i1] = c_invvar; + } + __syncthreads(); + } +} + +template __device__ +void cuWelfordMuSigma2_opt2( + const T* __restrict__ vals, + const T* __restrict__ residual, + const int n1, + const int n2, + const int i1, + U& mu, + U& sigma2, + U* buf, + bool rms_only) +{ + U count = U(0); + mu= U(0); + sigma2 = U(0); + const int numx = blockDim.x * blockDim.y; + const int tid = threadIdx.x + threadIdx.y * blockDim.x; + const T* lvals = vals + i1*n2; + const T* lresidual = residual + i1*n2; + + #pragma unroll + for (int l = tid;l < n2;l+=numx){ + U curr = static_cast(lvals[l]+lresidual[l]); + if (!rms_only) { + cuWelfordOnlineSum(curr, mu, sigma2, count); + } else { + cuRMSOnlineSum(curr, sigma2); + } + } + + U muB; + U sigma2B; + U countB; + #pragma unroll + for (int offset=32;offset>0;offset/=2) { + if (rms_only) { + sigma2B = __shfl_xor_sync(0xffffffff, sigma2, offset, 64); + sigma2 += sigma2B; + } else { + muB = __shfl_xor_sync(0xffffffff, mu, offset, 64); + sigma2B = __shfl_xor_sync(0xffffffff, sigma2, offset, 64); + countB = __shfl_xor_sync(0xffffffff, count, offset, 64); + cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count); + } + } + + if (blockDim.y > 1) { + if (rms_only) { + if (tid % 64 == 0) { + buf[tid/64] = sigma2; + } + __syncthreads(); + sigma2 = buf[0]; + for (int i=1;i0;offset/=2) { + if (tid < offset) { + cuChanOnlineSum(buf[tid*3+offset*3],buf[tid*3+offset*3+1],buf[tid*3+offset*3+2],buf[tid*3],buf[tid*3+1],buf[tid*3+2]); + } + __syncthreads(); + } + mu = buf[0]; + sigma2 = buf[1]/U(n2); + } + } else { + sigma2 /= U(n2); + } +} + + +template<> __device__ +void cuWelfordMuSigma2_opt2( + const at::Half* __restrict__ vals, + const at::Half* __restrict__ residual, + const int n1, + const int n2, + const int i1, + float& mu, + float& sigma2, + float* buf, + bool rms_only) +{ + typedef unsigned v4u32 __attribute__((ext_vector_type(4))); + mu = float(0); + sigma2 = float(0); + float count = float(0); + const int numx = blockDim.x * blockDim.y; + const int tid = threadIdx.x + threadIdx.y * blockDim.x; + + const float* lvals = reinterpret_cast(vals + i1*n2); + const float* lresidual = reinterpret_cast(residual + i1*n2); + + float curr1 = float(0); + at::Half* curr1_ = reinterpret_cast(&curr1); + float curr2 = float(0); + at::Half* curr2_ = reinterpret_cast(&curr2); + + v4u32 aBase; + aBase.x = (unsigned)(unsigned long long)lvals; + aBase.y = (unsigned)((unsigned long long)lvals >> 32); + aBase.zw = -1u; + + v4u32 bBase; + bBase.x = (unsigned)(unsigned long long)lresidual; + bBase.y = (unsigned)((unsigned long long)lresidual >> 32); + bBase.zw = -1u; + + #pragma unroll + for (int l = 0;l < n2/(numx*2);l++){ + curr1 = __ivcorex_ml_mem_load_f32(aBase, 4 * (tid+l*numx), 0, 0); + curr2 = __ivcorex_ml_mem_load_f32(bBase, 4 * (tid+l*numx), 0, 0); + curr1_[0] += curr2_[0]; + curr1_[1] += curr2_[1]; + + if (rms_only) { + cuRMSOnlineSum(static_cast(curr1_[0]), sigma2); + cuRMSOnlineSum(static_cast(curr1_[1]), sigma2); + } else { + cuWelfordOnlineSum(static_cast(curr1_[0]), mu, sigma2, count); + cuWelfordOnlineSum(static_cast(curr1_[1]), mu, sigma2, count); + } + } + + float muB; + float sigma2B; + float countB; + #pragma unroll + for (int offset=32;offset>0;offset/=2) { + if (rms_only) { + sigma2B = __shfl_xor_sync(0xffffffff, sigma2, offset, 64); + sigma2 += sigma2B; + } else { + muB = __shfl_xor_sync(0xffffffff, mu, offset, 64); + sigma2B = __shfl_xor_sync(0xffffffff, sigma2, offset, 64); + countB = __shfl_xor_sync(0xffffffff, count, offset, 64); + cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count); + } + } + + if (blockDim.y > 1) { + if (rms_only) { + if (tid % 64 == 0) { + buf[tid/64] = sigma2; + } + __syncthreads(); + sigma2 = buf[0]; + for (int i=1;i0;offset/=2) { + if (tid < offset) { + cuChanOnlineSum(buf[tid*3+offset*3],buf[tid*3+offset*3+1],buf[tid*3+offset*3+2],buf[tid*3],buf[tid*3+1],buf[tid*3+2]); + } + __syncthreads(); + } + mu = buf[0]; + sigma2 = buf[1]/float(n2); + } + } else { + sigma2 /= float(n2); + } +} + +template<> __device__ +void cuWelfordMuSigma2_opt2( + const at::BFloat16* __restrict__ vals, + const at::BFloat16* __restrict__ residual, + const int n1, + const int n2, + const int i1, + float& mu, + float& sigma2, + float* buf, + bool rms_only) +{ + typedef unsigned v4u32 __attribute__((ext_vector_type(4))); + mu = float(0); + sigma2 = float(0); + float count = float(0); + const int numx = blockDim.x * blockDim.y; + const int tid = threadIdx.x + threadIdx.y * blockDim.x; + + const float* lvals = reinterpret_cast(vals + i1*n2); + const float* lresidual = reinterpret_cast(residual + i1*n2); + + float curr1 = float(0); + at::BFloat16* curr1_ = reinterpret_cast(&curr1); + float curr2 = float(0); + at::BFloat16* curr2_ = reinterpret_cast(&curr2); + + v4u32 aBase; + aBase.x = (unsigned)(unsigned long long)lvals; + aBase.y = (unsigned)((unsigned long long)lvals >> 32); + aBase.zw = -1u; + + v4u32 bBase; + bBase.x = (unsigned)(unsigned long long)lresidual; + bBase.y = (unsigned)((unsigned long long)lresidual >> 32); + bBase.zw = -1u; + + #pragma unroll + for (int l = 0;l < n2/(numx*2);l++){ + curr1 = __ivcorex_ml_mem_load_f32(aBase, 4 * (tid+l*numx), 0, 0); + curr2 = __ivcorex_ml_mem_load_f32(bBase, 4 * (tid+l*numx), 0, 0); + curr1_[0] += curr2_[0]; + curr1_[1] += curr2_[1]; + + if (rms_only) { + cuRMSOnlineSum(static_cast(curr1_[0]), sigma2); + cuRMSOnlineSum(static_cast(curr1_[1]), sigma2); + } else { + cuWelfordOnlineSum(static_cast(curr1_[0]), mu, sigma2, count); + cuWelfordOnlineSum(static_cast(curr1_[1]), mu, sigma2, count); + } + } + + float muB; + float sigma2B; + float countB; + #pragma unroll + for (int offset=32;offset>0;offset/=2) { + if (rms_only) { + sigma2B = __shfl_xor_sync(0xffffffff, sigma2, offset, 64); + sigma2 += sigma2B; + } else { + muB = __shfl_xor_sync(0xffffffff, mu, offset, 64); + sigma2B = __shfl_xor_sync(0xffffffff, sigma2, offset, 64); + countB = __shfl_xor_sync(0xffffffff, count, offset, 64); + cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count); + } + } + + if (blockDim.y > 1) { + if (rms_only) { + if (tid % 64 == 0) { + buf[tid/64] = sigma2; + } + __syncthreads(); + sigma2 = buf[0]; + for (int i=1;i0;offset/=2) { + if (tid < offset) { + cuChanOnlineSum(buf[tid*3+offset*3],buf[tid*3+offset*3+1],buf[tid*3+offset*3+2],buf[tid*3],buf[tid*3+1],buf[tid*3+2]); + } + __syncthreads(); + } + mu = buf[0]; + sigma2 = buf[1]/float(n2); + } + } else { + sigma2 /= float(n2); + } +} + +template __device__ +void cuApplyLayerNormRes_( + V* __restrict__ output_vals, + V* __restrict__ output_sum, + U* __restrict__ mean, + U* __restrict__ invvar, + const T* __restrict__ vals, + const T* __restrict__ residual, + const int n1, + const int n2, + const U epsilon, + const V* __restrict__ gamma, + const V* __restrict__ beta, + bool rms_only + ) +{ + // Assumptions: + // 1) blockDim.x == warpSize + // 2) Tensors are contiguous + // + for (auto i1=blockIdx.y; i1 < n1; i1 += gridDim.y) { + SharedMemory shared; + U* buf = shared.getPointer(); + U mu,sigma2; + cuWelfordMuSigma2_opt2(vals,residual,n1,n2,i1,mu,sigma2,buf,rms_only); + + __syncthreads(); + const T* lvals = vals + i1*n2; + const T* lresidual = residual + i1*n2; + V* ovals = output_vals + i1*n2; + V* osum = output_sum + i1*n2; + U c_invvar = rsqrt(sigma2 + epsilon); + const int numx = blockDim.x * blockDim.y; + const int thrx = threadIdx.x + threadIdx.y * blockDim.x; + if (gamma != NULL && (beta != NULL || rms_only)) { + for (int i = thrx; i < n2; i+=numx) { + U curr = static_cast(lvals[i]+lresidual[i]); + if (!rms_only) { + ovals[i] = gamma[i] * static_cast(c_invvar * (curr - mu)) + beta[i]; + } else { + ovals[i] = gamma[i] * static_cast(c_invvar * curr); + } + osum[i] = static_cast(curr); + } + } else { + for (int i = thrx; i < n2; i+=numx) { + U curr = static_cast(lvals[i]+lresidual[i]); + if (!rms_only) { + ovals[i] = static_cast(c_invvar * (curr - mu)); + } else { + ovals[i] = static_cast(c_invvar * curr); + } + osum[i] = static_cast(curr); + } + } + if (threadIdx.x == 0 && threadIdx.y == 0) { + if (!rms_only) { + mean[i1] = mu; + } + invvar[i1] = c_invvar; + } + __syncthreads(); + } +} + +template __global__ +void cuApplyLayerNorm( + V* __restrict__ output_vals, + U* __restrict__ mean, + U* __restrict__ invvar, + const T* __restrict__ vals, + const int n1, + const int n2, + const U epsilon, + const V* __restrict__ gamma, + const V* __restrict__ beta + ) +{ + cuApplyLayerNorm_(output_vals, mean, invvar, vals, n1, n2, epsilon, gamma, beta, false); +} + +template __global__ +void cuApplyRMSNorm( + V* __restrict__ output_vals, + U* __restrict__ invvar, + const T* __restrict__ vals, + const int n1, + const int n2, + const U epsilon, + const V* __restrict__ gamma) +{ + cuApplyLayerNorm_(output_vals, NULL, invvar, vals, n1, n2, epsilon, gamma, NULL, true); +} + +template __global__ +void cuApplyRMSNormRes( + V* __restrict__ output_vals, + V* __restrict__ output_sum, + U* __restrict__ invvar, + const T* __restrict__ vals, + const T* __restrict__ residual, + const int n1, + const int n2, + const U epsilon, + const V* __restrict__ gamma) +{ + cuApplyLayerNormRes_(output_vals, output_sum, NULL, invvar, vals, residual, n1, n2, epsilon, gamma, NULL, true); +} + +template __device__ +void cuLoadWriteStridedInputs( + const int i1_block, + const int thr_load_row_off, + const int thr_load_col_off, + const int i2_off, + const int row_stride, + U* warp_buf1, + U* warp_buf2, + const T* input, + const V* dout, + const int i1_end, + const int n2, + const U* __restrict__ mean, + const U* __restrict__ invvar, + bool rms_only + ) +{ + int i1 = i1_block+thr_load_row_off; + if (i1 < i1_end) { + U curr_mean; + if (!rms_only) { + curr_mean = mean[i1]; + } + U curr_invvar = invvar[i1]; + for (int k = 0; k < blockDim.y; ++k) { + int i2 = i2_off + k; + int load_idx = i1*n2+i2; + int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k; + if (i2(input[load_idx]); + U curr_dout = static_cast(dout[load_idx]); + if (!rms_only) { + warp_buf1[write_idx] = curr_dout; + warp_buf2[write_idx] = curr_dout * (curr_input - curr_mean) * curr_invvar; + } else { + warp_buf2[write_idx] = curr_dout * (curr_input) * curr_invvar; + } + } else { + if (!rms_only) { + warp_buf1[write_idx] = U(0); + } + warp_buf2[write_idx] = U(0); + } + } + } else { + for (int k = 0; k < blockDim.y; ++k) { + int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k; + if (!rms_only) { + warp_buf1[write_idx] = U(0); + } + warp_buf2[write_idx] = U(0); + } + } +} + +template __device__ +void cuLoadAddStridedInputs( + const int i1_block, + const int thr_load_row_off, + const int thr_load_col_off, + const int i2_off, + const int row_stride, + U* warp_buf1, + U* warp_buf2, + const T* input, + const V* dout, + const int i1_end, + const int n2, + const U* __restrict__ mean, + const U* __restrict__ invvar, + bool rms_only + ) +{ + int i1 = i1_block+thr_load_row_off; + if (i1 < i1_end) { + U curr_mean; + if (!rms_only) { + curr_mean = mean[i1]; + } + U curr_invvar = invvar[i1]; + for (int k = 0; k < blockDim.y; ++k) { + int i2 = i2_off + k; + int load_idx = i1*n2+i2; + int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k; + if (i2(input[load_idx]); + U curr_dout = static_cast(dout[load_idx]); + if (!rms_only) { + warp_buf1[write_idx] += curr_dout; + warp_buf2[write_idx] += curr_dout * (curr_input - curr_mean) * curr_invvar; + } else { + warp_buf2[write_idx] += curr_dout * (curr_input) * curr_invvar; + } + } + } + } +} + + +template __global__ +void cuComputePartGradGammaBeta( + const V* __restrict__ dout, + const T* __restrict__ input, + const int n1, + const int n2, + const U* __restrict__ mean, + const U* __restrict__ invvar, + U epsilon, + U* part_grad_gamma, + U* part_grad_beta, + bool rms_only) +{ + const int numsegs_n1 = (n1+blockDim.y*blockDim.y-1) / (blockDim.y*blockDim.y); + const int segs_per_block = (numsegs_n1 + gridDim.y - 1) / gridDim.y; + const int i1_beg = blockIdx.y * segs_per_block * blockDim.y*blockDim.y; + const int i1_beg_plus_one = (blockIdx.y+1) * segs_per_block * blockDim.y*blockDim.y; + const int i1_end = i1_beg_plus_one < n1 ? i1_beg_plus_one : n1; + const int row_stride = blockDim.x+1; + const int thr_load_col_off = (threadIdx.x*blockDim.y)&(blockDim.x-1); + const int thr_load_row_off = (threadIdx.x*blockDim.y)/blockDim.x + threadIdx.y*blockDim.y; + const int i2_off = blockIdx.x * blockDim.x + thr_load_col_off; + SharedMemory shared; + U* buf = shared.getPointer(); // buf has at least blockDim.x * blockDim.y * blockDim.y + (blockDim.y - 1)*(blockDim.x/blockDim.y) elements + U* warp_buf1 = (U*)buf; + U* warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride; + // compute partial sums from strided inputs + // do this to increase number of loads in flight + cuLoadWriteStridedInputs(i1_beg,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,n2,mean,invvar, rms_only); + for (int i1_block = i1_beg+blockDim.y*blockDim.y; i1_block < i1_end; i1_block+=blockDim.y*blockDim.y) { + cuLoadAddStridedInputs(i1_block,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,n2,mean,invvar, rms_only); + } + __syncthreads(); + // inter-warp reductions + // sum within each warp + U acc1 = U(0); + U acc2 = U(0); + for (int k = 0; k < blockDim.y; ++k) { + int row1 = threadIdx.y + k*blockDim.y; + int idx1 = row1*row_stride + threadIdx.x; + if (!rms_only) { + acc1 += warp_buf1[idx1]; + } + acc2 += warp_buf2[idx1]; + } + if (!rms_only) { + warp_buf1[threadIdx.y*row_stride+threadIdx.x] = acc1; + } + warp_buf2[threadIdx.y*row_stride+threadIdx.x] = acc2; + __syncthreads(); + // sum all warps + for (int offset = blockDim.y/2; offset > 1; offset /= 2) { + if (threadIdx.y < offset) { + int row1 = threadIdx.y; + int row2 = threadIdx.y + offset; + int idx1 = row1*row_stride + threadIdx.x; + int idx2 = row2*row_stride + threadIdx.x; + if (!rms_only) { + warp_buf1[idx1] += warp_buf1[idx2]; + } + warp_buf2[idx1] += warp_buf2[idx2]; + } + __syncthreads(); + } + int i2 = blockIdx.x * blockDim.x + threadIdx.x; + if (threadIdx.y == 0 && i2 < n2) { + int row1 = threadIdx.y; + int row2 = threadIdx.y + 1; + int idx1 = row1*row_stride + threadIdx.x; + int idx2 = row2*row_stride + threadIdx.x; + if (!rms_only) { + part_grad_beta[blockIdx.y*n2+i2] = warp_buf1[idx1] + warp_buf1[idx2]; + } + part_grad_gamma[blockIdx.y*n2+i2] = warp_buf2[idx1] + warp_buf2[idx2]; + } +} + +template __global__ +void cuComputeGradGammaBeta( + const U* part_grad_gamma, + const U* part_grad_beta, + const int part_size, + const int n1, + const int n2, + V* grad_gamma, + V* grad_beta, + bool rms_only) +{ + // sum partial gradients for gamma and beta + SharedMemory shared; + U* buf = shared.getPointer(); + int i2 = blockIdx.x * blockDim.x + threadIdx.x; + if (i2 < n2) { + // each warp does sequential reductions until reduced part_size is num_warps + int num_warp_reductions = part_size / blockDim.y; + U sum_gamma = U(0); + U sum_beta = U(0); + const U* part_grad_gamma_ptr = part_grad_gamma + threadIdx.y * num_warp_reductions * n2 + i2; + const U* part_grad_beta_ptr = part_grad_beta + threadIdx.y * num_warp_reductions * n2 + i2; + for (int warp_offset = 0; warp_offset < num_warp_reductions; ++warp_offset) { + sum_gamma += part_grad_gamma_ptr[warp_offset*n2]; + if (!rms_only) { + sum_beta += part_grad_beta_ptr[warp_offset*n2]; + } + } + // inter-warp reductions + const int nbsize3 = blockDim.x * blockDim.y / 2; + for (int offset = blockDim.y/2; offset >= 1; offset /= 2) { + // top half write to shared memory + if (threadIdx.y >= offset && threadIdx.y < 2*offset) { + const int write_idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x; + buf[write_idx] = sum_gamma; + if (!rms_only) { + buf[write_idx+nbsize3] = sum_beta; + } + } + __syncthreads(); + // bottom half sums + if (threadIdx.y < offset) { + const int read_idx = threadIdx.y * blockDim.x + threadIdx.x; + sum_gamma += buf[read_idx]; + if (!rms_only) { + sum_beta += buf[read_idx+nbsize3]; + } + } + __syncthreads(); + } + // write out fully summed gradients + if (threadIdx.y == 0) { + grad_gamma[i2] = sum_gamma; + if (!rms_only) { + grad_beta[i2] = sum_beta; + } + } + } +} + + +template __global__ +void cuComputeGradInput( + const V* __restrict__ dout, + const T* __restrict__ input, + const int n1, + const int n2, + const U* __restrict__ mean, + const U* __restrict__ invvar, + U epsilon, + const V* gamma, + T* grad_input, + bool rms_only) +{ + for (auto i1=blockIdx.y; i1 < n1; i1 += gridDim.y) { + U sum_loss1 = U(0); + U sum_loss2 = U(0); + U c_mean; + if (!rms_only) { + c_mean = mean[i1]; + } + const U c_invvar = invvar[i1]; + const T* k_input = input + i1*n2; + const V* k_dout = dout + i1*n2; + const int numx = blockDim.x * blockDim.y; + const int thrx = threadIdx.x + threadIdx.y * blockDim.x; + if (gamma != NULL) { + int l = 4*thrx; + for (; l+3 < n2; l+=4*numx) { + for (int k = 0; k < 4; ++k) { + const U c_h = static_cast(k_input[l+k]); + const U c_loss = static_cast(k_dout[l+k]); + if (!rms_only) { + sum_loss1 += c_loss * gamma[l+k]; + sum_loss2 += c_loss * gamma[l+k] * (c_h - c_mean) * c_invvar; + } else { + sum_loss2 += c_loss * gamma[l+k] * (c_h) * c_invvar; + } + } + } + for (; l < n2; ++l) { + const U c_h = static_cast(k_input[l]); + const U c_loss = static_cast(k_dout[l]); + if (!rms_only) { + sum_loss1 += c_loss * gamma[l]; + sum_loss2 += c_loss * gamma[l] * (c_h - c_mean) * c_invvar; + } else { + sum_loss2 += c_loss * gamma[l] * (c_h) * c_invvar; + } + + } + } else { + int l = 4*thrx; + for (; l+3 < n2; l+=4*numx) { + for (int k = 0; k < 4; ++k) { + const U c_h = static_cast(k_input[l+k]); + const U c_loss = static_cast(k_dout[l+k]); + if (!rms_only) { + sum_loss1 += c_loss; + sum_loss2 += c_loss * (c_h - c_mean) * c_invvar; + } else { + sum_loss2 += c_loss * (c_h) * c_invvar; + } + } + } + for (; l < n2; ++l) { + const U c_h = static_cast(k_input[l]); + const U c_loss = static_cast(k_dout[l]); + if (!rms_only) { + sum_loss1 += c_loss; + sum_loss2 += c_loss * (c_h - c_mean) * c_invvar; + } else { + sum_loss2 += c_loss * (c_h) * c_invvar; + } + } + } + // intra-warp reductions + for (int mask = blockDim.x/2; mask > 0; mask /= 2) { + if (!rms_only) { +#ifndef __ILUVATAR__ + sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask); +#else + sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask, 32); +#endif + } +#ifndef __ILUVATAR__ + sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask); +#else + sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask, 32); +#endif + } + // inter-warp reductions + if (blockDim.y > 1) { + SharedMemory shared; + U* buf = shared.getPointer(); + for (int offset = blockDim.y/2; offset > 0; offset /= 2) { + // upper half of warps write to shared + if (threadIdx.y >= offset && threadIdx.y < 2*offset) { + const int wrt_i = (threadIdx.y - offset) * blockDim.x + threadIdx.x; + if (!rms_only) { + buf[2*wrt_i] = sum_loss1; + } + buf[2*wrt_i+1] = sum_loss2; + } + __syncthreads(); + // lower half merges + if (threadIdx.y < offset) { + const int read_i = threadIdx.y * blockDim.x + threadIdx.x; + if (!rms_only) { + sum_loss1 += buf[2*read_i]; + } + sum_loss2 += buf[2*read_i+1]; + } + __syncthreads(); + } + if (threadIdx.y == 0) { + if (!rms_only) { + buf[2*threadIdx.x] = sum_loss1; + } + buf[2*threadIdx.x+1] = sum_loss2; + } + __syncthreads(); + if (threadIdx.y !=0) { + if (!rms_only) { + sum_loss1 = buf[2*threadIdx.x]; + } + sum_loss2 = buf[2*threadIdx.x+1]; + } + } + // all threads now have the two sums over l + U fH = (U)n2; + U term1 = (U(1) / fH) * c_invvar; + T* k_grad_input = grad_input + i1*n2; + if (gamma != NULL) { + for (int l = thrx; l < n2; l+=numx) { + const U c_h = static_cast(k_input[l]); + const U c_loss = static_cast(k_dout[l]); + U f_grad_input = fH * c_loss * gamma[l]; + if (!rms_only) { + f_grad_input -= sum_loss1; + f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2; + } else { + f_grad_input -= (c_h) * c_invvar * sum_loss2; + } + f_grad_input *= term1; + k_grad_input[l] = static_cast(f_grad_input); + } + } else { + for (int l = thrx; l < n2; l+=numx) { + const U c_h = static_cast(k_input[l]); + const U c_loss = static_cast(k_dout[l]); + U f_grad_input = fH * c_loss; + if (!rms_only) { + f_grad_input -= sum_loss1; + f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2; + } else { + f_grad_input -= (c_h) * c_invvar * sum_loss2; + } + f_grad_input *= term1; + k_grad_input[l] = static_cast(f_grad_input); + } + } + // prevent race where buf is written again before reads are done + __syncthreads(); + } +} + +template __device__ +V clamp_by_magnitude(V curr_gamma, double eps) +{ + const V kMinGamma = V(eps); + if (curr_gamma >= 0) { + if (curr_gamma < kMinGamma) { + return kMinGamma; + } else { + return curr_gamma; + } + } else { + if (curr_gamma > -kMinGamma) { + return -kMinGamma; + } else { + return curr_gamma; + } + } +} + +template __global__ +void fusedGradInputWeights( + const V* __restrict__ dout, + const T* __restrict__ input_or_output, + const int n1, + const int n2, + const U* __restrict__ mean, + const U* __restrict__ invvar, + const V* gamma, + const V* beta, + const double eps, + T* grad_input, + U* part_grad_gamma, + U* part_grad_beta) +{ + const int numx = blockDim.x * blockDim.y; + const int thrx = threadIdx.y * blockDim.x + threadIdx.x; + U d_gamma[LDG] = {0}; + U d_beta[LDG] = {0}; + __shared__ U shm[80]; + + V gamma_data[LDG]; + #pragma unroll + for (int l=0;l(input_or_output_ptr[thrx+l*numx]); + const U c_loss = static_cast(dout_ptr[thrx+l*numx]); + U y_tmp; + if (!RMSONLY) { + if (!MemoryEfficient) { + y_tmp = (c_h - mean[row]) * c_invvar; + } else { + y_tmp = (c_h - static_cast(beta_data[l])) / static_cast(clamp_by_magnitude(gamma_data[l], eps)); + } + } else { + if (!MemoryEfficient) { + y_tmp = c_h * c_invvar; + } else { + y_tmp = c_h / static_cast(clamp_by_magnitude(gamma_data[l], eps)); + } + } + U dy_tmp = c_loss * gamma_data[l]; + if (!RMSONLY) { + sum_loss1 += dy_tmp; + } + sum_loss2 += dy_tmp * y_tmp; + + y[l] = y_tmp; + dy[l] = dy_tmp; + d_gamma[l] += c_loss * y_tmp; + if (!RMSONLY) { + d_beta[l] += c_loss; + } + } + + // intra warp reduction + U val1; + U val2; + #pragma unroll + for (int offset = 32; offset > 0; offset /= 2) { + val1 = __shfl_xor_sync(0xffffffff, sum_loss1, offset, 64); + val2 = __shfl_xor_sync(0xffffffff, sum_loss2, offset, 64); + sum_loss1 += val1; + sum_loss2 += val2; + } + + // intra block reduction + if (blockDim.y > 1) { + int offset = 1<<(32 - __clz(blockDim.y-1) - 1); + if (!RMSONLY) { + if (threadIdx.x == 0) { + shm[threadIdx.y*2] = sum_loss1; + shm[threadIdx.y*2+1] = sum_loss2; + } + __syncthreads(); + for (;offset>0;offset/=2) { + if (thrx < offset) { + shm[thrx*2] += shm[thrx*2+offset*2]; + shm[thrx*2+1] += shm[thrx*2+offset*2+1]; + } + __syncthreads(); + } + sum_loss1 = shm[0]; + sum_loss2 = shm[1]; + } else { + if (threadIdx.x == 0) { + shm[threadIdx.y] = sum_loss2; + } + __syncthreads(); + #pragma unroll + for (;offset>0;offset/=2) { + if (thrx < offset && thrx+offset < blockDim.y) { + shm[thrx] += shm[thrx+offset]; + } + __syncthreads(); + } + sum_loss2 = shm[0]; + } + } + + U fH = (U)n2; + U term1 = (U(1) / fH) * c_invvar; + T* k_grad_input = grad_input + row*n2; + for (int l=0;l(f_grad_input); + } + } + // #pragma unroll + for (int l=0;l __global__ +void fusedGradInputWeights( + const at::Half* __restrict__ dout, + const at::Half* __restrict__ input_or_output, + const int n1, + const int n2, + const float* __restrict__ mean, + const float* __restrict__ invvar, + const at::Half* gamma, + const at::Half* beta, + const double eps, + at::Half* grad_input, + float* part_grad_gamma, + float* part_grad_beta) +{ + typedef unsigned v4u32 __attribute__((ext_vector_type(4))); + const int numx = blockDim.x * blockDim.y; + const int thrx = threadIdx.y * blockDim.x + threadIdx.x; + float2 d_gamma[LDG] = {0}; + float2 d_beta[LDG] = {0}; + + __shared__ float shm[80]; + + float c_h = float(0); + at::Half* c_h1 = reinterpret_cast(&c_h); + float c_loss = float(0); + at::Half* c_loss1 = reinterpret_cast(&c_loss); + float c_gamma = float(0); + at::Half* c_gamma1 = reinterpret_cast(&c_gamma); + + v4u32 gBase; + gBase.x = (unsigned)(unsigned long long)gamma; + gBase.y = (unsigned)((unsigned long long)gamma >> 32); + gBase.zw = -1u; + + at::Half gamma_data[LDG*2]; + #pragma unroll + for (int l=0;l(&c_beta); + + v4u32 hBase; + hBase.x = (unsigned)(unsigned long long)beta; + hBase.y = (unsigned)((unsigned long long)beta >> 32); + hBase.zw = -1u; + + #pragma unroll + for (int l=0;l(input_or_output + row*n2); + const float* dout_ptr = reinterpret_cast(dout + row*n2); + + v4u32 aBase; + aBase.x = (unsigned)(unsigned long long)input_or_output_ptr; + aBase.y = (unsigned)((unsigned long long)input_or_output_ptr >> 32); + aBase.zw = -1u; + + v4u32 bBase; + bBase.x = (unsigned)(unsigned long long)dout_ptr; + bBase.y = (unsigned)((unsigned long long)dout_ptr >> 32); + bBase.zw = -1u; + + float2 y[LDG]; + float2 dy[LDG]; + + #pragma unroll + for (int l=0;l(c_h1[0]) - mean[row]) * c_invvar; + y_tmp1 = (static_cast(c_h1[1]) - mean[row]) * c_invvar; + } else { + y_tmp0 = (static_cast(c_h1[0]) - static_cast(beta_data[l*2])) / static_cast(clamp_by_magnitude(gamma_data[l*2], eps)); + y_tmp1 = (static_cast(c_h1[1]) - static_cast(beta_data[l*2+1])) / static_cast(clamp_by_magnitude(gamma_data[l*2+1], eps)); + } + } else { + if (!MemoryEfficient) { + y_tmp0 = static_cast(c_h1[0]) * c_invvar; + y_tmp1 = static_cast(c_h1[1]) * c_invvar; + } else { + y_tmp0 = static_cast(c_h1[0]) / static_cast(clamp_by_magnitude(gamma_data[l*2], eps)); + y_tmp1 = static_cast(c_h1[1]) / static_cast(clamp_by_magnitude(gamma_data[l*2+1], eps)); + } + } + float dy_tmp0 = static_cast(c_loss1[0]) * static_cast(gamma_data[l*2]); + float dy_tmp1 = static_cast(c_loss1[1]) * static_cast(gamma_data[l*2+1]); + sum_loss1 += dy_tmp0 + dy_tmp1; + sum_loss2 += y_tmp0 * dy_tmp0 + y_tmp1 * dy_tmp1; + y[l].x = y_tmp0; + y[l].y = y_tmp1; + dy[l].x = dy_tmp0; + dy[l].y = dy_tmp1; + d_gamma[l].x += static_cast(c_loss1[0]) * y_tmp0; + d_gamma[l].y += static_cast(c_loss1[1]) * y_tmp1; + if (!RMSONLY) { + d_beta[l].x += static_cast(c_loss1[0]); + d_beta[l].y += static_cast(c_loss1[1]); + } + } + + // intra warp reduction + float val1; + float val2; + #pragma unroll + for (int offset = 32; offset > 0; offset /= 2) { + val1 = __shfl_xor_sync(0xffffffff, sum_loss1, offset, 64); + val2 = __shfl_xor_sync(0xffffffff, sum_loss2, offset, 64); + sum_loss1 += val1; + sum_loss2 += val2; + } + + // intra block reduction + if (blockDim.y > 1) { + int offset = 1<<(32 - __clz(blockDim.y-1) - 1); + if (!RMSONLY) { + if (threadIdx.x == 0) { + shm[threadIdx.y*2] = sum_loss1; + shm[threadIdx.y*2+1] = sum_loss2; + } + __syncthreads(); + for (;offset>0;offset/=2) { + if (thrx < offset) { + shm[thrx*2] += shm[thrx*2+offset*2]; + shm[thrx*2+1] += shm[thrx*2+offset*2+1]; + } + __syncthreads(); + } + sum_loss1 = shm[0]; + sum_loss2 = shm[1]; + } else { + if (threadIdx.x == 0) { + shm[threadIdx.y] = sum_loss2; + } + __syncthreads(); + #pragma unroll + for (;offset>0;offset/=2) { + if (thrx < offset && thrx+offset < blockDim.y) { + shm[thrx] += shm[thrx+offset]; + } + __syncthreads(); + } + sum_loss2 = shm[0]; + } + } + + float fH = (float)n2; + float term1 = (float(1) / fH) * c_invvar; + float* k_grad_input = reinterpret_cast(grad_input + row*n2); + for (int l=0;l(&f_grad_input); + f_grad_input_[0] = static_cast(f_grad_input0); + f_grad_input_[1] = static_cast(f_grad_input1); + k_grad_input[thrx+l*numx] = f_grad_input; + } + } + + float2* part_grad_gamma_ptr = reinterpret_cast(part_grad_gamma + blockIdx.y*n2); + float2* part_grad_beta_ptr = reinterpret_cast(part_grad_beta + blockIdx.y*n2); + for (int l=0;l __global__ +void fusedGradInputWeights( + const at::BFloat16* __restrict__ dout, + const at::BFloat16* __restrict__ input_or_output, + const int n1, + const int n2, + const float* __restrict__ mean, + const float* __restrict__ invvar, + const at::BFloat16* gamma, + const at::BFloat16* beta, + const double eps, + at::BFloat16* grad_input, + float* part_grad_gamma, + float* part_grad_beta) +{ + typedef unsigned v4u32 __attribute__((ext_vector_type(4))); + const int numx = blockDim.x * blockDim.y; + const int thrx = threadIdx.y * blockDim.x + threadIdx.x; + float2 d_gamma[LDG] = {0}; + float2 d_beta[LDG] = {0}; + + __shared__ float shm[80]; + + float c_h = float(0); + at::BFloat16* c_h1 = reinterpret_cast(&c_h); + float c_loss = float(0); + at::BFloat16* c_loss1 = reinterpret_cast(&c_loss); + float c_gamma = float(0); + at::BFloat16* c_gamma1 = reinterpret_cast(&c_gamma); + + v4u32 gBase; + gBase.x = (unsigned)(unsigned long long)gamma; + gBase.y = (unsigned)((unsigned long long)gamma >> 32); + gBase.zw = -1u; + + at::BFloat16 gamma_data[LDG*2]; + #pragma unroll + for (int l=0;l(&c_beta); + + v4u32 hBase; + hBase.x = (unsigned)(unsigned long long)beta; + hBase.y = (unsigned)((unsigned long long)beta >> 32); + hBase.zw = -1u; + + #pragma unroll + for (int l=0;l(input_or_output + row*n2); + const float* dout_ptr = reinterpret_cast(dout + row*n2); + + v4u32 aBase; + aBase.x = (unsigned)(unsigned long long)input_or_output_ptr; + aBase.y = (unsigned)((unsigned long long)input_or_output_ptr >> 32); + aBase.zw = -1u; + + v4u32 bBase; + bBase.x = (unsigned)(unsigned long long)dout_ptr; + bBase.y = (unsigned)((unsigned long long)dout_ptr >> 32); + bBase.zw = -1u; + + float2 y[LDG]; + float2 dy[LDG]; + + #pragma unroll + for (int l=0;l(c_h1[0]) - mean[row]) * c_invvar; + y_tmp1 = (static_cast(c_h1[1]) - mean[row]) * c_invvar; + } else { + y_tmp0 = (static_cast(c_h1[0]) - static_cast(beta_data[l*2])) / static_cast(clamp_by_magnitude(gamma_data[l*2], eps)); + y_tmp1 = (static_cast(c_h1[1]) - static_cast(beta_data[l*2+1])) / static_cast(clamp_by_magnitude(gamma_data[l*2+1], eps)); + } + } else { + if (!MemoryEfficient) { + y_tmp0 = static_cast(c_h1[0]) * c_invvar; + y_tmp1 = static_cast(c_h1[1]) * c_invvar; + } else { + y_tmp0 = static_cast(c_h1[0]) / static_cast(clamp_by_magnitude(gamma_data[l*2], eps)); + y_tmp1 = static_cast(c_h1[1]) / static_cast(clamp_by_magnitude(gamma_data[l*2+1], eps)); + } + } + float dy_tmp0 = static_cast(c_loss1[0]) * static_cast(gamma_data[l*2]); + float dy_tmp1 = static_cast(c_loss1[1]) * static_cast(gamma_data[l*2+1]); + sum_loss1 += dy_tmp0 + dy_tmp1; + sum_loss2 += y_tmp0 * dy_tmp0 + y_tmp1 * dy_tmp1; + y[l].x = y_tmp0; + y[l].y = y_tmp1; + dy[l].x = dy_tmp0; + dy[l].y = dy_tmp1; + d_gamma[l].x += static_cast(c_loss1[0]) * y_tmp0; + d_gamma[l].y += static_cast(c_loss1[1]) * y_tmp1; + if (!RMSONLY) { + d_beta[l].x += static_cast(c_loss1[0]); + d_beta[l].y += static_cast(c_loss1[1]); + } + } + + // intra warp reduction + float val1; + float val2; + #pragma unroll + for (int offset = 32; offset > 0; offset /= 2) { + val1 = __shfl_xor_sync(0xffffffff, sum_loss1, offset, 64); + val2 = __shfl_xor_sync(0xffffffff, sum_loss2, offset, 64); + sum_loss1 += val1; + sum_loss2 += val2; + } + + // intra block reduction + if (blockDim.y > 1) { + int offset = 1<<(32 - __clz(blockDim.y-1) - 1); + if (!RMSONLY) { + if (threadIdx.x == 0) { + shm[threadIdx.y*2] = sum_loss1; + shm[threadIdx.y*2+1] = sum_loss2; + } + __syncthreads(); + for (;offset>0;offset/=2) { + if (thrx < offset) { + shm[thrx*2] += shm[thrx*2+offset*2]; + shm[thrx*2+1] += shm[thrx*2+offset*2+1]; + } + __syncthreads(); + } + sum_loss1 = shm[0]; + sum_loss2 = shm[1]; + } else { + if (threadIdx.x == 0) { + shm[threadIdx.y] = sum_loss2; + } + __syncthreads(); + #pragma unroll + for (;offset>0;offset/=2) { + if (thrx < offset && thrx+offset < blockDim.y) { + shm[thrx] += shm[thrx+offset]; + } + __syncthreads(); + } + sum_loss2 = shm[0]; + } + } + + float fH = (float)n2; + float term1 = (float(1) / fH) * c_invvar; + float* k_grad_input = reinterpret_cast(grad_input + row*n2); + for (int l=0;l(&f_grad_input); + f_grad_input_[0] = static_cast(f_grad_input0); + f_grad_input_[1] = static_cast(f_grad_input1); + k_grad_input[thrx+l*numx] = f_grad_input; + } + } + + float2* part_grad_gamma_ptr = reinterpret_cast(part_grad_gamma + blockIdx.y*n2); + float2* part_grad_beta_ptr = reinterpret_cast(part_grad_beta + blockIdx.y*n2); + for (int l=0;l __global__ +void fusedGradInputWeights_( + const V* __restrict__ dout, + const V* __restrict__ dres, + const T* __restrict__ input_or_output, + const int n1, + const int n2, + const U* __restrict__ mean, + const U* __restrict__ invvar, + const V* gamma, + const V* beta, + const double eps, + T* grad_input, + U* part_grad_gamma, + U* part_grad_beta) +{ + const int numx = blockDim.x * blockDim.y; + const int thrx = threadIdx.y * blockDim.x + threadIdx.x; + U d_gamma[LDG] = {0}; + U d_beta[LDG] = {0}; + __shared__ U shm[80]; + + V gamma_data[LDG]; + #pragma unroll + for (int l=0;l(input_or_output_ptr[thrx+l*numx]); + const U c_loss = static_cast(dout_ptr[thrx+l*numx]); + U y_tmp; + if (!RMSONLY) { + if (!MemoryEfficient) { + y_tmp = (c_h - mean[row]) * c_invvar; + } else { + y_tmp = (c_h - static_cast(beta_data[l])) / static_cast(clamp_by_magnitude(gamma_data[l], eps)); + } + } else { + if (!MemoryEfficient) { + y_tmp = c_h * c_invvar; + } else { + y_tmp = c_h / static_cast(clamp_by_magnitude(gamma_data[l], eps)); + } + } + U dy_tmp = c_loss * gamma_data[l]; + if (!RMSONLY) { + sum_loss1 += dy_tmp; + } + sum_loss2 += dy_tmp * y_tmp; + + y[l] = y_tmp; + dy[l] = dy_tmp; + d_gamma[l] += c_loss * y_tmp; + if (!RMSONLY) { + d_beta[l] += c_loss; + } + } + + // intra warp reduction + U val1; + U val2; + #pragma unroll + for (int offset = 32; offset > 0; offset /= 2) { + val1 = __shfl_xor_sync(0xffffffff, sum_loss1, offset, 64); + val2 = __shfl_xor_sync(0xffffffff, sum_loss2, offset, 64); + sum_loss1 += val1; + sum_loss2 += val2; + } + + // intra block reduction + if (blockDim.y > 1) { + int offset = 1<<(32 - __clz(blockDim.y-1) - 1); + if (!RMSONLY) { + if (threadIdx.x == 0) { + shm[threadIdx.y*2] = sum_loss1; + shm[threadIdx.y*2+1] = sum_loss2; + } + __syncthreads(); + for (;offset>0;offset/=2) { + if (thrx < offset) { + shm[thrx*2] += shm[thrx*2+offset*2]; + shm[thrx*2+1] += shm[thrx*2+offset*2+1]; + } + __syncthreads(); + } + sum_loss1 = shm[0]; + sum_loss2 = shm[1]; + } else { + if (threadIdx.x == 0) { + shm[threadIdx.y] = sum_loss2; + } + __syncthreads(); + #pragma unroll + for (;offset>0;offset/=2) { + if (thrx < offset && thrx+offset < blockDim.y) { + shm[thrx] += shm[thrx+offset]; + } + __syncthreads(); + } + sum_loss2 = shm[0]; + } + } + + U fH = (U)n2; + U term1 = (U(1) / fH) * c_invvar; + T* k_grad_input = grad_input + row*n2; + for (int l=0;l(dres_ptr[thrx+l*numx]); + k_grad_input[thrx+l*numx] = static_cast(f_grad_input); + } + } + // #pragma unroll + for (int l=0;l __global__ +void fusedGradInputWeights_( + const at::Half* __restrict__ dout, + const at::Half* __restrict__ dres, + const at::Half* __restrict__ input_or_output, + const int n1, + const int n2, + const float* __restrict__ mean, + const float* __restrict__ invvar, + const at::Half* gamma, + const at::Half* beta, + const double eps, + at::Half* grad_input, + float* part_grad_gamma, + float* part_grad_beta) +{ + typedef unsigned v4u32 __attribute__((ext_vector_type(4))); + const int numx = blockDim.x * blockDim.y; + const int thrx = threadIdx.y * blockDim.x + threadIdx.x; + float2 d_gamma[LDG] = {0}; + float2 d_beta[LDG] = {0}; + + __shared__ float shm[80]; + + float c_h = float(0); + at::Half* c_h1 = reinterpret_cast(&c_h); + float c_loss = float(0); + at::Half* c_loss1 = reinterpret_cast(&c_loss); + float c_dres = float(0); + at::Half* c_dres1 = reinterpret_cast(&c_dres); + float c_gamma = float(0); + at::Half* c_gamma1 = reinterpret_cast(&c_gamma); + + v4u32 gBase; + gBase.x = (unsigned)(unsigned long long)gamma; + gBase.y = (unsigned)((unsigned long long)gamma >> 32); + gBase.zw = -1u; + + at::Half gamma_data[LDG*2]; + #pragma unroll + for (int l=0;l(&c_beta); + + v4u32 hBase; + hBase.x = (unsigned)(unsigned long long)beta; + hBase.y = (unsigned)((unsigned long long)beta >> 32); + hBase.zw = -1u; + + #pragma unroll + for (int l=0;l(input_or_output + row*n2); + const float* dout_ptr = reinterpret_cast(dout + row*n2); + const float* dres_ptr = reinterpret_cast(dres + row*n2); + + v4u32 aBase; + aBase.x = (unsigned)(unsigned long long)input_or_output_ptr; + aBase.y = (unsigned)((unsigned long long)input_or_output_ptr >> 32); + aBase.zw = -1u; + + v4u32 bBase; + bBase.x = (unsigned)(unsigned long long)dout_ptr; + bBase.y = (unsigned)((unsigned long long)dout_ptr >> 32); + bBase.zw = -1u; + + v4u32 cBase; + cBase.x = (unsigned)(unsigned long long)dres_ptr; + cBase.y = (unsigned)((unsigned long long)dres_ptr >> 32); + cBase.zw = -1u; + + float2 y[LDG]; + float2 dy[LDG]; + + #pragma unroll + for (int l=0;l(c_h1[0]) - mean[row]) * c_invvar; + y_tmp1 = (static_cast(c_h1[1]) - mean[row]) * c_invvar; + } else { + y_tmp0 = (static_cast(c_h1[0]) - static_cast(beta_data[l*2])) / static_cast(clamp_by_magnitude(gamma_data[l*2], eps)); + y_tmp1 = (static_cast(c_h1[1]) - static_cast(beta_data[l*2+1])) / static_cast(clamp_by_magnitude(gamma_data[l*2+1], eps)); + } + } else { + if (!MemoryEfficient) { + y_tmp0 = static_cast(c_h1[0]) * c_invvar; + y_tmp1 = static_cast(c_h1[1]) * c_invvar; + } else { + y_tmp0 = static_cast(c_h1[0]) / static_cast(clamp_by_magnitude(gamma_data[l*2], eps)); + y_tmp1 = static_cast(c_h1[1]) / static_cast(clamp_by_magnitude(gamma_data[l*2+1], eps)); + } + } + float dy_tmp0 = static_cast(c_loss1[0]) * static_cast(gamma_data[l*2]); + float dy_tmp1 = static_cast(c_loss1[1]) * static_cast(gamma_data[l*2+1]); + sum_loss1 += dy_tmp0 + dy_tmp1; + sum_loss2 += y_tmp0 * dy_tmp0 + y_tmp1 * dy_tmp1; + y[l].x = y_tmp0; + y[l].y = y_tmp1; + dy[l].x = dy_tmp0; + dy[l].y = dy_tmp1; + d_gamma[l].x += static_cast(c_loss1[0]) * y_tmp0; + d_gamma[l].y += static_cast(c_loss1[1]) * y_tmp1; + if (!RMSONLY) { + d_beta[l].x += static_cast(c_loss1[0]); + d_beta[l].y += static_cast(c_loss1[1]); + } + } + + // intra warp reduction + float val1; + float val2; + #pragma unroll + for (int offset = 32; offset > 0; offset /= 2) { + val1 = __shfl_xor_sync(0xffffffff, sum_loss1, offset, 64); + val2 = __shfl_xor_sync(0xffffffff, sum_loss2, offset, 64); + sum_loss1 += val1; + sum_loss2 += val2; + } + + // intra block reduction + if (blockDim.y > 1) { + int offset = 1<<(32 - __clz(blockDim.y-1) - 1); + if (!RMSONLY) { + if (threadIdx.x == 0) { + shm[threadIdx.y*2] = sum_loss1; + shm[threadIdx.y*2+1] = sum_loss2; + } + __syncthreads(); + for (;offset>0;offset/=2) { + if (thrx < offset) { + shm[thrx*2] += shm[thrx*2+offset*2]; + shm[thrx*2+1] += shm[thrx*2+offset*2+1]; + } + __syncthreads(); + } + sum_loss1 = shm[0]; + sum_loss2 = shm[1]; + } else { + if (threadIdx.x == 0) { + shm[threadIdx.y] = sum_loss2; + } + __syncthreads(); + #pragma unroll + for (;offset>0;offset/=2) { + if (thrx < offset && thrx+offset < blockDim.y) { + shm[thrx] += shm[thrx+offset]; + } + __syncthreads(); + } + sum_loss2 = shm[0]; + } + } + + float fH = (float)n2; + float term1 = (float(1) / fH) * c_invvar; + float* k_grad_input = reinterpret_cast(grad_input + row*n2); + for (int l=0;l(&f_grad_input); + f_grad_input_[0] = static_cast(f_grad_input0+static_cast(c_dres1[0])); + f_grad_input_[1] = static_cast(f_grad_input1+static_cast(c_dres1[1])); + k_grad_input[thrx+l*numx] = f_grad_input; + } + } + + float2* part_grad_gamma_ptr = reinterpret_cast(part_grad_gamma + blockIdx.y*n2); + float2* part_grad_beta_ptr = reinterpret_cast(part_grad_beta + blockIdx.y*n2); + for (int l=0;l __global__ +void fusedGradInputWeights_( + const at::BFloat16* __restrict__ dout, + const at::BFloat16* __restrict__ dres, + const at::BFloat16* __restrict__ input_or_output, + const int n1, + const int n2, + const float* __restrict__ mean, + const float* __restrict__ invvar, + const at::BFloat16* gamma, + const at::BFloat16* beta, + const double eps, + at::BFloat16* grad_input, + float* part_grad_gamma, + float* part_grad_beta) +{ + typedef unsigned v4u32 __attribute__((ext_vector_type(4))); + const int numx = blockDim.x * blockDim.y; + const int thrx = threadIdx.y * blockDim.x + threadIdx.x; + float2 d_gamma[LDG] = {0}; + float2 d_beta[LDG] = {0}; + + __shared__ float shm[80]; + + float c_h = float(0); + at::BFloat16* c_h1 = reinterpret_cast(&c_h); + float c_loss = float(0); + at::BFloat16* c_loss1 = reinterpret_cast(&c_loss); + float c_dres = float(0); + at::BFloat16* c_dres1 = reinterpret_cast(&c_dres); + float c_gamma = float(0); + at::BFloat16* c_gamma1 = reinterpret_cast(&c_gamma); + + v4u32 gBase; + gBase.x = (unsigned)(unsigned long long)gamma; + gBase.y = (unsigned)((unsigned long long)gamma >> 32); + gBase.zw = -1u; + + at::BFloat16 gamma_data[LDG*2]; + #pragma unroll + for (int l=0;l(&c_beta); + + v4u32 hBase; + hBase.x = (unsigned)(unsigned long long)beta; + hBase.y = (unsigned)((unsigned long long)beta >> 32); + hBase.zw = -1u; + + #pragma unroll + for (int l=0;l(input_or_output + row*n2); + const float* dout_ptr = reinterpret_cast(dout + row*n2); + const float* dres_ptr = reinterpret_cast(dres + row*n2); + + v4u32 aBase; + aBase.x = (unsigned)(unsigned long long)input_or_output_ptr; + aBase.y = (unsigned)((unsigned long long)input_or_output_ptr >> 32); + aBase.zw = -1u; + + v4u32 bBase; + bBase.x = (unsigned)(unsigned long long)dout_ptr; + bBase.y = (unsigned)((unsigned long long)dout_ptr >> 32); + bBase.zw = -1u; + + v4u32 cBase; + cBase.x = (unsigned)(unsigned long long)dres_ptr; + cBase.y = (unsigned)((unsigned long long)dres_ptr >> 32); + cBase.zw = -1u; + + float2 y[LDG]; + float2 dy[LDG]; + + #pragma unroll + for (int l=0;l(c_h1[0]) - mean[row]) * c_invvar; + y_tmp1 = (static_cast(c_h1[1]) - mean[row]) * c_invvar; + } else { + y_tmp0 = (static_cast(c_h1[0]) - static_cast(beta_data[l*2])) / static_cast(clamp_by_magnitude(gamma_data[l*2], eps)); + y_tmp1 = (static_cast(c_h1[1]) - static_cast(beta_data[l*2+1])) / static_cast(clamp_by_magnitude(gamma_data[l*2+1], eps)); + } + } else { + if (!MemoryEfficient) { + y_tmp0 = static_cast(c_h1[0]) * c_invvar; + y_tmp1 = static_cast(c_h1[1]) * c_invvar; + } else { + y_tmp0 = static_cast(c_h1[0]) / static_cast(clamp_by_magnitude(gamma_data[l*2], eps)); + y_tmp1 = static_cast(c_h1[1]) / static_cast(clamp_by_magnitude(gamma_data[l*2+1], eps)); + } + } + float dy_tmp0 = static_cast(c_loss1[0]) * static_cast(gamma_data[l*2]); + float dy_tmp1 = static_cast(c_loss1[1]) * static_cast(gamma_data[l*2+1]); + sum_loss1 += dy_tmp0 + dy_tmp1; + sum_loss2 += y_tmp0 * dy_tmp0 + y_tmp1 * dy_tmp1; + y[l].x = y_tmp0; + y[l].y = y_tmp1; + dy[l].x = dy_tmp0; + dy[l].y = dy_tmp1; + d_gamma[l].x += static_cast(c_loss1[0]) * y_tmp0; + d_gamma[l].y += static_cast(c_loss1[1]) * y_tmp1; + if (!RMSONLY) { + d_beta[l].x += static_cast(c_loss1[0]); + d_beta[l].y += static_cast(c_loss1[1]); + } + } + + // intra warp reduction + float val1; + float val2; + #pragma unroll + for (int offset = 32; offset > 0; offset /= 2) { + val1 = __shfl_xor_sync(0xffffffff, sum_loss1, offset, 64); + val2 = __shfl_xor_sync(0xffffffff, sum_loss2, offset, 64); + sum_loss1 += val1; + sum_loss2 += val2; + } + + // intra block reduction + if (blockDim.y > 1) { + int offset = 1<<(32 - __clz(blockDim.y-1) - 1); + if (!RMSONLY) { + if (threadIdx.x == 0) { + shm[threadIdx.y*2] = sum_loss1; + shm[threadIdx.y*2+1] = sum_loss2; + } + __syncthreads(); + for (;offset>0;offset/=2) { + if (thrx < offset) { + shm[thrx*2] += shm[thrx*2+offset*2]; + shm[thrx*2+1] += shm[thrx*2+offset*2+1]; + } + __syncthreads(); + } + sum_loss1 = shm[0]; + sum_loss2 = shm[1]; + } else { + if (threadIdx.x == 0) { + shm[threadIdx.y] = sum_loss2; + } + __syncthreads(); + #pragma unroll + for (;offset>0;offset/=2) { + if (thrx < offset && thrx+offset < blockDim.y) { + shm[thrx] += shm[thrx+offset]; + } + __syncthreads(); + } + sum_loss2 = shm[0]; + } + } + + float fH = (float)n2; + float term1 = (float(1) / fH) * c_invvar; + float* k_grad_input = reinterpret_cast(grad_input + row*n2); + for (int l=0;l(&f_grad_input); + f_grad_input_[0] = static_cast(f_grad_input0+static_cast(c_dres1[0])); + f_grad_input_[1] = static_cast(f_grad_input1+static_cast(c_dres1[1])); + k_grad_input[thrx+l*numx] = f_grad_input; + } + } + + float2* part_grad_gamma_ptr = reinterpret_cast(part_grad_gamma + blockIdx.y*n2); + float2* part_grad_beta_ptr = reinterpret_cast(part_grad_beta + blockIdx.y*n2); + for (int l=0;l __global__ +void ComputeGradGammaBeta_opt( + const U* part_grad_gamma, + const U* part_grad_beta, + const int part_size, + const int n2, + V* grad_gamma, + V* grad_beta) +{ + U tmp_gamma = U(0); + U tmp_beta = U(0); + int i2 = blockDim.x * blockIdx.x + threadIdx.x; + #pragma unroll + for (int k=0;k(tmp_gamma); + if (!RMSONLY) { + grad_beta[i2] = static_cast(tmp_beta); + } +} + +template +void HostApplyLayerNorm( + V* output, + U* mean, + U* invvar, + const T* input, + int n1, + int n2, + float epsilon, + const V* gamma, + const V* beta + ) +{ + auto stream = at::cuda::getCurrentCUDAStream().stream(); + // const dim3 threads(32,4,1); + dim3 threads(64,1,1); + if (sizeof(T) == 1) { + threads.y = n2/1024/4 > 1 ? n2/1024/4 : 1; + } + if (sizeof(T) == 2) { + threads.y = n2/1024/2 > 1 ? n2/1024/2 : 1; + } + if (sizeof(T) >= 4) { + threads.y = n2/1024 > 1 ? n2/1024 : 1; + } + const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; +#ifdef __ILUVATAR__ + const dim3 blocks(1, std::min((uint64_t)n1, maxGridY/threads.y), 1); +#else + const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1); +#endif + int nshared = + threads.y > 1 ? + threads.y*sizeof(U)+(threads.y/2)*sizeof(U) : + 0; + cuApplyLayerNorm<<>>( + output, mean, invvar, input, n1, n2, U(epsilon), gamma, beta); +} + +template +void HostApplyRMSNorm( + V* output, + U* invvar, + const T* input, + int n1, + int n2, + double epsilon, + const V* gamma) +{ + auto stream = at::cuda::getCurrentCUDAStream().stream(); + // const dim3 threads(32,4,1); + dim3 threads(64,1,1); + if (sizeof(T) == 1) { + threads.y = n2/1024/4 > 1 ? n2/1024/4 : 1; + } + if (sizeof(T) == 2) { + threads.y = n2/1024/2 > 1 ? n2/1024/2 : 1; + } + if (sizeof(T) >= 4) { + threads.y = n2/1024 > 1 ? n2/1024 : 1; + } + const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; +#ifdef __ILUVATAR__ + const dim3 blocks(1, std::min((uint64_t)n1, maxGridY/threads.y), 1); +#else + const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1); +#endif + int nshared = + threads.y > 1 ? + threads.y*sizeof(U)+(threads.y/2)*sizeof(U) : + 0; + cuApplyRMSNorm<<>>( + output, invvar, input, n1, n2, U(epsilon), gamma); +} + +template +void HostApplyRMSNormRes( + V* output, + V* sum, + U* invvar, + const T* input, + const T* residual, + int n1, + int n2, + double epsilon, + const V* gamma) +{ + auto stream = at::cuda::getCurrentCUDAStream().stream(); + // const dim3 threads(32,4,1); + dim3 threads(64,1,1); + if (sizeof(T) == 1) { + threads.y = n2/1024/4 > 1 ? n2/1024/4 : 1; + } + if (sizeof(T) == 2) { + threads.y = n2/1024/2 > 1 ? n2/1024/2 : 1; + } + if (sizeof(T) >= 4) { + threads.y = n2/1024 > 1 ? n2/1024 : 1; + } + const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; +#ifdef __ILUVATAR__ + const dim3 blocks(1, std::min((uint64_t)n1, maxGridY/threads.y), 1); +#else + const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1); +#endif + int nshared = + threads.y > 1 ? + threads.y*sizeof(U)+(threads.y/2)*sizeof(U) : + 0; + cuApplyRMSNormRes<<>>( + output, sum, invvar, input, residual, n1, n2, U(epsilon), gamma); +} + + +void cuda_layer_norm( + at::Tensor* output, + at::Tensor* mean, + at::Tensor* invvar, + at::Tensor* input, + int n1, + int n2, + #ifdef VERSION_GE_1_1 + at::IntArrayRef normalized_shape, + #else + at::IntList normalized_shape, + #endif + at::Tensor* gamma, + at::Tensor* beta, + float epsilon) +{ + using namespace at; + DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES( + input->scalar_type(), output->scalar_type(), "layer_norm_cuda_kernel", + using accscalar_t = at::acc_type; + HostApplyLayerNorm( + output->DATA_PTR(), + mean->DATA_PTR(), + invvar->DATA_PTR(), + input->DATA_PTR(), + n1,n2, + epsilon, + gamma != NULL ? gamma->DATA_PTR() : NULL, + beta != NULL ? beta->DATA_PTR() : NULL); + ) +} + +void cuda_rms_norm( + at::Tensor* output, + at::Tensor* invvar, + at::Tensor* input, + int n1, + int n2, + #ifdef VERSION_GE_1_1 + at::IntArrayRef normalized_shape, + #else + at::IntList normalized_shape, + #endif + at::Tensor* gamma, + double epsilon) +{ + using namespace at; + DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES( + input->scalar_type(), output->scalar_type(), "rms_norm_cuda_kernel", + using accscalar_t = at::acc_type; + HostApplyRMSNorm( + output->DATA_PTR(), + invvar->DATA_PTR(), + input->DATA_PTR(), + n1,n2, + epsilon, + gamma != NULL ? gamma->DATA_PTR() : NULL); + ) +} + +void cuda_rms_norm_residual( + at::Tensor* output, + at::Tensor* sum, + at::Tensor* invvar, + at::Tensor* input, + at::Tensor* residual, + int n1, + int n2, + #ifdef VERSION_GE_1_1 + at::IntArrayRef normalized_shape, + #else + at::IntList normalized_shape, + #endif + at::Tensor* gamma, + double epsilon) +{ + using namespace at; + DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES( + input->scalar_type(), output->scalar_type(), "rms_norm_residual_cuda_kernel", + using accscalar_t = at::acc_type; + HostApplyRMSNormRes( + output->DATA_PTR(), + sum->DATA_PTR(), + invvar->DATA_PTR(), + input->DATA_PTR(), + residual->DATA_PTR(), + n1,n2, + epsilon, + gamma != NULL ? gamma->DATA_PTR() : NULL); + ) +} + + +template +void HostLayerNormGradient( + const V* dout, + const U* mean, + const U* invvar, + at::Tensor* input, + int n1, + int n2, + const V* gamma, + const V* beta, + float epsilon, + T* grad_input, + V* grad_gamma, + V* grad_beta + ) +{ + auto stream = at::cuda::getCurrentCUDAStream().stream(); + + if (gamma != NULL && beta != NULL) { + // compute grad_gamma(j) and grad_beta(j) + const int part_size = 16; + const dim3 threads2(32,4,1); + const dim3 blocks2((n2+threads2.x-1)/threads2.x,part_size,1); + const int nshared2_a = 2 * sizeof(U) * threads2.y * threads2.y * (threads2.x + 1); + const int nshared2_b = threads2.x * threads2.y * sizeof(U); + const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b; + // note (mkozuki): I can hard code part_grad_gamma's dtype as float given that + // the `cuda_layer_norm_gradient` doesn't support double. + const auto part_grad_dtype = + (input->scalar_type() == at::ScalarType::Half || input->scalar_type() == at::ScalarType::BFloat16) ? + at::ScalarType::Float : + input->scalar_type(); + at::Tensor part_grad_gamma = at::empty({part_size,n2}, input->options().dtype(part_grad_dtype)); + at::Tensor part_grad_beta = at::empty_like(part_grad_gamma); + cuComputePartGradGammaBeta<<>>( + dout, + input->DATA_PTR(), + n1,n2, + mean, + invvar, + U(epsilon), + part_grad_gamma.DATA_PTR(), + part_grad_beta.DATA_PTR(), + false); + + const dim3 threads3(32,8,1); + const dim3 blocks3((n2+threads2.x-1)/threads2.x,1,1); + const int nshared3 = threads3.x * threads3.y * sizeof(U); + cuComputeGradGammaBeta<<>>( + part_grad_gamma.DATA_PTR(), + part_grad_beta.DATA_PTR(), + part_size, + n1,n2, + grad_gamma, + grad_beta, + false); + } + + // compute grad_input + const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; +#ifdef __ILUVATAR__ + const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY/4), 1); +#else + const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1); +#endif + const dim3 threads1(32,4,1); + int nshared = + threads1.y > 1 ? + threads1.y*threads1.x*sizeof(U) : + 0; + cuComputeGradInput<<>>( + dout, + input->DATA_PTR(), + n1,n2, + mean, + invvar, + U(epsilon), + gamma, + grad_input, + false); +} + +template +void HostRMSNormGradient( + const V* dout, + const U* invvar, + at::Tensor* input, + int n1, + int n2, + const V* gamma, + double epsilon, + T* grad_input, + V* grad_gamma) +{ + auto stream = at::cuda::getCurrentCUDAStream().stream(); + + if (gamma != NULL) { + const int part_size = 16; + const dim3 threads2(32,4,1); + const dim3 blocks2((n2+threads2.x-1)/threads2.x,part_size,1); + const int nshared2_a = 2 * sizeof(U) * threads2.y * threads2.y * (threads2.x + 1); + const int nshared2_b = threads2.x * threads2.y * sizeof(U); + const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b; + // note (mkozuki): I can hard code part_grad_gamma's dtype as float given that + // the `cuda_layer_norm_gradient` doesn't support double. + const auto part_grad_dtype = + (input->scalar_type() == at::ScalarType::Half || input->scalar_type() == at::ScalarType::BFloat16) ? + at::ScalarType::Float : + input->scalar_type(); + at::Tensor part_grad_gamma = at::empty({part_size,n2}, input->options().dtype(part_grad_dtype)); + cuComputePartGradGammaBeta<<>>( + dout, + input->DATA_PTR(), + n1,n2, + invvar, // unused + invvar, + U(epsilon), + part_grad_gamma.DATA_PTR(), + part_grad_gamma.DATA_PTR(), /* unused */ + true); + + const dim3 threads3(32,8,1); + const dim3 blocks3((n2+threads2.x-1)/threads2.x,1,1); + const int nshared3 = threads3.x * threads3.y * sizeof(U); + cuComputeGradGammaBeta<<>>( + part_grad_gamma.DATA_PTR(), + part_grad_gamma.DATA_PTR(), /* unused */ + part_size, + n1,n2, + grad_gamma, + grad_gamma, /* unused */ + true); + } + + // compute grad_input + const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; + #ifdef __ILUVATAR__ + const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY/4), 1); +#else + const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1); +#endif + const dim3 threads1(32,4,1); + int nshared = + threads1.y > 1 ? + threads1.y*threads1.x*sizeof(U) : + 0; + cuComputeGradInput<<>>( + dout, + input->DATA_PTR(), + n1,n2, + invvar, /* unused */ + invvar, + U(epsilon), + gamma, + grad_input, + true); +} + +template +void HostLayerNormGradient_opt( + const V* dout, + const U* mean, + const U* invvar, + at::Tensor* input, // max supported hidden size 64*32*40=81920 + int n1, + int n2, + const V* gamma, + const V* beta, + float epsilon, + T* grad_input, + V* grad_gamma, + V* grad_beta, + bool memory_efficient + ) +{ + auto stream = at::cuda::getCurrentCUDAStream().stream(); + + if (gamma != NULL && beta != NULL) { + int div = 1; + if (sizeof(T) < 4) {div = 4/sizeof(T);} + int blocky = n2/64/div; + int LDG; + int gridy; + + if (blocky > 32) { + for (int i =2;i0;i--) { + if (n1 % i == 0) {gridy = i;break;} + } + } else { + gridy = 16 * 128 / blocky; + if (sizeof(T) == 2) {gridy -= 1;} + for (int i = gridy;i>0;i--) { + if (n1 % i == 0) {gridy = i;break;} + } + } + LDG = n2/64/blocky/div; + + const auto part_grad_dtype = + (input->scalar_type() == at::ScalarType::Half || input->scalar_type() == at::ScalarType::BFloat16) ? + at::ScalarType::Float : + input->scalar_type(); + at::Tensor part_grad_gamma = at::empty({gridy,n2}, input->options().dtype(part_grad_dtype)); + at::Tensor part_grad_beta = at::empty_like(part_grad_gamma); + + if (LDG == 1) { + const dim3 threads2 (64, blocky, 1); + const dim3 blocks2 (1, gridy, 1); + BOOL_SWITCH(memory_efficient, MemoryEfficient, [&]{fusedGradInputWeights<1, false, MemoryEfficient><<>>(dout, input->DATA_PTR(), n1, n2, mean, invvar, gamma, beta, double(epsilon), grad_input, part_grad_gamma.DATA_PTR(), part_grad_beta.DATA_PTR());}); + } + if (LDG == 2) { + const dim3 threads2 (64, blocky, 1); + const dim3 blocks2 (1, gridy, 1); + BOOL_SWITCH(memory_efficient, MemoryEfficient, [&]{fusedGradInputWeights<2, false, MemoryEfficient><<>>(dout, input->DATA_PTR(), n1, n2, mean, invvar, gamma, beta, double(epsilon), grad_input, part_grad_gamma.DATA_PTR(), part_grad_beta.DATA_PTR());}); + } + if (LDG == 4) { + const dim3 threads2 (64, blocky, 1); + const dim3 blocks2 (1, gridy, 1); + BOOL_SWITCH(memory_efficient, MemoryEfficient, [&]{fusedGradInputWeights<4, false, MemoryEfficient><<>>(dout, input->DATA_PTR(), n1, n2, mean, invvar, gamma, beta, double(epsilon), grad_input, part_grad_gamma.DATA_PTR(), part_grad_beta.DATA_PTR());}); + } + if (LDG == 8) { + const dim3 threads2 (64, blocky, 1); + const dim3 blocks2 (1, gridy, 1); + BOOL_SWITCH(memory_efficient, MemoryEfficient, [&]{fusedGradInputWeights<8, false, MemoryEfficient><<>>(dout, input->DATA_PTR(), n1, n2, mean, invvar, gamma, beta, double(epsilon), grad_input, part_grad_gamma.DATA_PTR(), part_grad_beta.DATA_PTR());}); + } + if (LDG == 16) { + const dim3 threads2 (64, blocky, 1); + const dim3 blocks2 (1, gridy, 1); + BOOL_SWITCH(memory_efficient, MemoryEfficient, [&]{fusedGradInputWeights<16, false, MemoryEfficient><<>>(dout, input->DATA_PTR(), n1, n2, mean, invvar, gamma, beta, double(epsilon), grad_input, part_grad_gamma.DATA_PTR(), part_grad_beta.DATA_PTR());}); + } + if (LDG == 32) { + const dim3 threads2 (64, blocky, 1); + const dim3 blocks2 (1, gridy, 1); + BOOL_SWITCH(memory_efficient, MemoryEfficient, [&]{fusedGradInputWeights<32, false, MemoryEfficient><<>>(dout, input->DATA_PTR(), n1, n2, mean, invvar, gamma, beta, double(epsilon), grad_input, part_grad_gamma.DATA_PTR(), part_grad_beta.DATA_PTR());}); + } + + + const dim3 threads3 (64, 1, 1); + const dim3 blocks3 (n2/64, 1, 1); + ComputeGradGammaBeta_opt<<>>( + part_grad_gamma.DATA_PTR(), part_grad_beta.DATA_PTR(), gridy, n2, grad_gamma, grad_beta + ); + } else { + // compute grad_input + const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; +#ifdef __ILUVATAR__ + const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY/4), 1); +#else + const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1); +#endif + const dim3 threads1(32,4,1); + int nshared = + threads1.y > 1 ? + threads1.y*threads1.x*sizeof(U) : + 0; + cuComputeGradInput<<>>( + dout, + input->DATA_PTR(), + n1,n2, + mean, + invvar, + U(epsilon), + gamma, + grad_input, + false); + } +} + +template +void HostRMSNormGradient_opt( + const V* dout, + const U* invvar, + at::Tensor* input, + int n1, + int n2, + const V* gamma, + double epsilon, + T* grad_input, + V* grad_gamma, + bool memory_efficient) +{ + auto stream = at::cuda::getCurrentCUDAStream().stream(); + + if (gamma != NULL) { + int div = 1; + if (sizeof(T) < 4) {div = 4/sizeof(T);} + int blocky = n2/64/div; + int LDG; + int gridy; + + if (blocky > 32) { + for (int i =2;i0;i--) { + if (n1 % i == 0) {gridy = i;break;} + } + } else { + gridy = 16 * 128 / blocky; + if (sizeof(T) == 2) {gridy -= 1;} + for (int i = gridy;i>0;i--) { + if (n1 % i == 0) {gridy = i;break;} + } + } + LDG = n2/64/blocky/div; + + const auto part_grad_dtype = + (input->scalar_type() == at::ScalarType::Half || input->scalar_type() == at::ScalarType::BFloat16) ? + at::ScalarType::Float : + input->scalar_type(); + at::Tensor part_grad_gamma = at::empty({gridy,n2}, input->options().dtype(part_grad_dtype)); + + if (LDG == 1) { + const dim3 threads2 (64, blocky, 1); + const dim3 blocks2 (1, gridy, 1); + BOOL_SWITCH(memory_efficient, MemoryEfficient, [&]{fusedGradInputWeights<1, true, MemoryEfficient><<>>(dout, input->DATA_PTR(), n1, n2, invvar, invvar, gamma, gamma, epsilon, grad_input, part_grad_gamma.DATA_PTR(), part_grad_gamma.DATA_PTR());}); + } + if (LDG == 2) { + const dim3 threads2 (64, blocky, 1); + const dim3 blocks2 (1, gridy, 1); + BOOL_SWITCH(memory_efficient, MemoryEfficient, [&]{fusedGradInputWeights<2, true, MemoryEfficient><<>>(dout, input->DATA_PTR(), n1, n2, invvar, invvar, gamma, gamma, epsilon, grad_input, part_grad_gamma.DATA_PTR(), part_grad_gamma.DATA_PTR());}); + } + if (LDG == 4) { + const dim3 threads2 (64, blocky, 1); + const dim3 blocks2 (1, gridy, 1); + BOOL_SWITCH(memory_efficient, MemoryEfficient, [&]{fusedGradInputWeights<4, true, MemoryEfficient><<>>(dout, input->DATA_PTR(), n1, n2, invvar, invvar, gamma, gamma, epsilon, grad_input, part_grad_gamma.DATA_PTR(), part_grad_gamma.DATA_PTR());}); + } + if (LDG == 8) { + const dim3 threads2 (64, blocky, 1); + const dim3 blocks2 (1, gridy, 1); + BOOL_SWITCH(memory_efficient, MemoryEfficient, [&]{fusedGradInputWeights<8, true, MemoryEfficient><<>>(dout, input->DATA_PTR(), n1, n2, invvar, invvar, gamma, gamma, epsilon, grad_input, part_grad_gamma.DATA_PTR(), part_grad_gamma.DATA_PTR());}); + } + if (LDG == 16) { + const dim3 threads2 (64, blocky, 1); + const dim3 blocks2 (1, gridy, 1); + BOOL_SWITCH(memory_efficient, MemoryEfficient, [&]{fusedGradInputWeights<16, true, MemoryEfficient><<>>(dout, input->DATA_PTR(), n1, n2, invvar, invvar, gamma, gamma, epsilon, grad_input, part_grad_gamma.DATA_PTR(), part_grad_gamma.DATA_PTR());}); + } + if (LDG == 32) { + const dim3 threads2 (64, blocky, 1); + const dim3 blocks2 (1, gridy, 1); + BOOL_SWITCH(memory_efficient, MemoryEfficient, [&]{fusedGradInputWeights<32, true, MemoryEfficient><<>>(dout, input->DATA_PTR(), n1, n2, invvar, invvar, gamma, gamma, epsilon, grad_input, part_grad_gamma.DATA_PTR(), part_grad_gamma.DATA_PTR());}); + } + + + const dim3 threads3 (64, 1, 1); + const dim3 blocks3 (n2/64, 1, 1); + ComputeGradGammaBeta_opt<<>>( + part_grad_gamma.DATA_PTR(), part_grad_gamma.DATA_PTR(), gridy, n2, grad_gamma, grad_gamma + ); + } else { + + // compute grad_input + const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; +#ifdef __ILUVATAR__ + const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY/4), 1); +#else + const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1); +#endif + const dim3 threads1(32,4,1); + int nshared = + threads1.y > 1 ? + threads1.y*threads1.x*sizeof(U) : + 0; + cuComputeGradInput<<>>( + dout, + input->DATA_PTR(), + n1,n2, + invvar, /* unused */ + invvar, + U(epsilon), + gamma, + grad_input, + true); + } +} + +template +void HostRMSNormGradient_opt2( + const V* dout, + const V* dres, + const U* invvar, + at::Tensor* input, + int n1, + int n2, + const V* gamma, + double epsilon, + T* grad_input, + V* grad_gamma, + bool memory_efficient) +{ + auto stream = at::cuda::getCurrentCUDAStream().stream(); + + if (gamma != NULL) { + int div = 1; + if (sizeof(T) < 4) {div = 4/sizeof(T);} + int blocky = n2/64/div; + int LDG; + int gridy; + + if (blocky > 32) { + for (int i =2;i0;i--) { + if (n1 % i == 0) {gridy = i;break;} + } + } else { + gridy = 16 * 128 / blocky; + if (sizeof(T) == 2) {gridy -= 1;} + for (int i = gridy;i>0;i--) { + if (n1 % i == 0) {gridy = i;break;} + } + } + LDG = n2/64/blocky/div; + + const auto part_grad_dtype = + (input->scalar_type() == at::ScalarType::Half || input->scalar_type() == at::ScalarType::BFloat16) ? + at::ScalarType::Float : + input->scalar_type(); + at::Tensor part_grad_gamma = at::empty({gridy,n2}, input->options().dtype(part_grad_dtype)); + + if (LDG == 1) { + const dim3 threads2 (64, blocky, 1); + const dim3 blocks2 (1, gridy, 1); + BOOL_SWITCH(memory_efficient, MemoryEfficient, [&]{fusedGradInputWeights_<1, true, MemoryEfficient><<>>(dout, dres, input->DATA_PTR(), n1, n2, invvar, invvar, gamma, gamma, epsilon, grad_input, part_grad_gamma.DATA_PTR(), part_grad_gamma.DATA_PTR());}); + } + if (LDG == 2) { + const dim3 threads2 (64, blocky, 1); + const dim3 blocks2 (1, gridy, 1); + BOOL_SWITCH(memory_efficient, MemoryEfficient, [&]{fusedGradInputWeights_<2, true, MemoryEfficient><<>>(dout, dres, input->DATA_PTR(), n1, n2, invvar, invvar, gamma, gamma, epsilon, grad_input, part_grad_gamma.DATA_PTR(), part_grad_gamma.DATA_PTR());}); + } + if (LDG == 4) { + const dim3 threads2 (64, blocky, 1); + const dim3 blocks2 (1, gridy, 1); + BOOL_SWITCH(memory_efficient, MemoryEfficient, [&]{fusedGradInputWeights_<4, true, MemoryEfficient><<>>(dout, dres, input->DATA_PTR(), n1, n2, invvar, invvar, gamma, gamma, epsilon, grad_input, part_grad_gamma.DATA_PTR(), part_grad_gamma.DATA_PTR());}); + } + if (LDG == 8) { + const dim3 threads2 (64, blocky, 1); + const dim3 blocks2 (1, gridy, 1); + BOOL_SWITCH(memory_efficient, MemoryEfficient, [&]{fusedGradInputWeights_<8, true, MemoryEfficient><<>>(dout, dres, input->DATA_PTR(), n1, n2, invvar, invvar, gamma, gamma, epsilon, grad_input, part_grad_gamma.DATA_PTR(), part_grad_gamma.DATA_PTR());}); + } + if (LDG == 16) { + const dim3 threads2 (64, blocky, 1); + const dim3 blocks2 (1, gridy, 1); + BOOL_SWITCH(memory_efficient, MemoryEfficient, [&]{fusedGradInputWeights_<16, true, MemoryEfficient><<>>(dout, dres, input->DATA_PTR(), n1, n2, invvar, invvar, gamma, gamma, epsilon, grad_input, part_grad_gamma.DATA_PTR(), part_grad_gamma.DATA_PTR());}); + } + if (LDG == 32) { + const dim3 threads2 (64, blocky, 1); + const dim3 blocks2 (1, gridy, 1); + BOOL_SWITCH(memory_efficient, MemoryEfficient, [&]{fusedGradInputWeights_<32, true, MemoryEfficient><<>>(dout, dres, input->DATA_PTR(), n1, n2, invvar, invvar, gamma, gamma, epsilon, grad_input, part_grad_gamma.DATA_PTR(), part_grad_gamma.DATA_PTR());}); + } + + + const dim3 threads3 (64, 1, 1); + const dim3 blocks3 (n2/64, 1, 1); + ComputeGradGammaBeta_opt<<>>( + part_grad_gamma.DATA_PTR(), part_grad_gamma.DATA_PTR(), gridy, n2, grad_gamma, grad_gamma + ); + } else { + + // compute grad_input + const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; +#ifdef __ILUVATAR__ + const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY/4), 1); +#else + const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1); +#endif + const dim3 threads1(32,4,1); + int nshared = + threads1.y > 1 ? + threads1.y*threads1.x*sizeof(U) : + 0; + cuComputeGradInput<<>>( + dout, + input->DATA_PTR(), + n1,n2, + invvar, /* unused */ + invvar, + U(epsilon), + gamma, + grad_input, + true); + } +} + +void cuda_layer_norm_gradient( + at::Tensor* dout, + at::Tensor* mean, + at::Tensor* invvar, + at::Tensor* input, + int n1, + int n2, + #ifdef VERSION_GE_1_1 + at::IntArrayRef normalized_shape, + #else + at::IntList normalized_shape, + #endif + at::Tensor* gamma, + at::Tensor* beta, + float epsilon, + at::Tensor* grad_input, + at::Tensor* grad_gamma, + at::Tensor* grad_beta, + bool memory_efficient) +{ + using namespace at; + // we can do away with `accscalar_t` as there're only three dtypes: fp32, fp16, bf16 + DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES( + input->scalar_type(), gamma == NULL ? input->scalar_type() : gamma->scalar_type(), "cuComputeGradInput", + using accscalar_t = at::acc_type; + HostLayerNormGradient_opt( + dout->DATA_PTR(), + mean != NULL ? mean->DATA_PTR() : NULL, + invvar->DATA_PTR(), + input, + n1,n2, + // TMJ pass NULL argument for gamma, beta, grad_gamma and grad_beta + // if gamma Tensor is NULL on input. + gamma != NULL ? gamma->DATA_PTR() : NULL, + gamma != NULL ? beta->DATA_PTR() : NULL, + epsilon, + grad_input->DATA_PTR(), + gamma != NULL ? grad_gamma->DATA_PTR() : NULL, + gamma != NULL ? grad_beta->DATA_PTR() : NULL, + memory_efficient); + ) +} + +void cuda_rms_norm_gradient( + at::Tensor* dout, + at::Tensor* invvar, + at::Tensor* input, + int n1, + int n2, + #ifdef VERSION_GE_1_1 + at::IntArrayRef normalized_shape, + #else + at::IntList normalized_shape, + #endif + at::Tensor* gamma, + double epsilon, + at::Tensor* grad_input, + at::Tensor* grad_gamma, + bool memory_efficient) +{ + using namespace at; + // we can do away with `accscalar_t` as there're only three dtypes: fp32, fp16, bf16 + // DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES( + DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES( + input->scalar_type(), gamma == NULL ? input->scalar_type() : gamma->scalar_type(), "cuComputeGradInputRMS", + using accscalar_t = at::acc_type; + HostRMSNormGradient_opt( + dout->DATA_PTR(), + invvar->DATA_PTR(), + input, + n1,n2, + // TMJ pass NULL argument for gamma, beta, grad_gamma and grad_beta + // if gamma Tensor is NULL on input. + gamma != NULL ? gamma->DATA_PTR() : NULL, + epsilon, + grad_input->DATA_PTR(), + gamma != NULL ? grad_gamma->DATA_PTR() : NULL, + memory_efficient); + ) +} + +void cuda_rms_norm_residual_gradient( + at::Tensor* dout, + at::Tensor* dres, + at::Tensor* invvar, + at::Tensor* input, + int n1, + int n2, + #ifdef VERSION_GE_1_1 + at::IntArrayRef normalized_shape, + #else + at::IntList normalized_shape, + #endif + at::Tensor* gamma, + double epsilon, + at::Tensor* grad_input, + at::Tensor* grad_gamma, + bool memory_efficient) +{ + using namespace at; + // we can do away with `accscalar_t` as there're only three dtypes: fp32, fp16, bf16 + // DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES( + DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES( + input->scalar_type(), gamma == NULL ? input->scalar_type() : gamma->scalar_type(), "cuComputeGradInputRMS", + using accscalar_t = at::acc_type; + HostRMSNormGradient_opt2( + dout->DATA_PTR(), + dres->DATA_PTR(), + invvar->DATA_PTR(), + input, + n1,n2, + // TMJ pass NULL argument for gamma, beta, grad_gamma and grad_beta + // if gamma Tensor is NULL on input. + gamma != NULL ? gamma->DATA_PTR() : NULL, + epsilon, + grad_input->DATA_PTR(), + gamma != NULL ? grad_gamma->DATA_PTR() : NULL, + memory_efficient); + ) +} \ No newline at end of file diff --git a/toolbox/DeepSpeed/v0.15.3/patches/csrc/quantization/fake_quantizer.cu b/toolbox/DeepSpeed/v0.15.3/patches/csrc/quantization/fake_quantizer.cu new file mode 100644 index 0000000000000000000000000000000000000000..21383316d4551e593c2c10996de732b4211de35e --- /dev/null +++ b/toolbox/DeepSpeed/v0.15.3/patches/csrc/quantization/fake_quantizer.cu @@ -0,0 +1,1042 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. */ +/* All Rights Reserved. */ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include +#include "custom_cuda_layers.h" +#include "memory_access_utils.h" + +namespace cg = cooperative_groups; + +__global__ void fake_quantize_kernel(__half* vals, int group_size, int num_bits) +{ + cg::thread_block b = cg::this_thread_block(); // tb + cg::thread_block_tile<32> g = + cg::tiled_partition<32>(b); // warp, 32 not optimal for AMD which should be 64. + + int gid = threadIdx.x >> 5; + int lane = threadIdx.x & 0x1f; + int warp_num = blockDim.x >> 5; + int id = threadIdx.x; + + constexpr int granularity = 16; + constexpr int vals_per_access = granularity / sizeof(__half); + + __half data[vals_per_access]; + + int group_id = blockIdx.x; + + int thread_index = id * vals_per_access; + int reg_count = 0; + int offset = group_id * group_size; + float max = -10000.0; + for (int thread_index = id * vals_per_access; thread_index < group_size; + thread_index += blockDim.x * vals_per_access) { + mem_access::load_global(data, vals + offset + thread_index); + +#pragma unroll + for (int i = 0; i < vals_per_access; i++) { + if (abs((float)data[i]) > max) max = abs((float)data[i]); + } + } + +#pragma unroll + for (int i = 1; i < WARP_SIZE; i <<= 1) { + auto temp = g.shfl_xor(max, i); + if (max < temp) max = temp; + } + __shared__ float partialMax[WARP_SIZE]; + + if (lane == 0) partialMax[gid] = max; + + b.sync(); + + if (lane < warp_num) max = partialMax[lane]; + +#pragma unroll + for (int i = 1; i < WARP_SIZE; i <<= 1) { + auto temp = g.shfl_down(max, i); + if (max < temp) max = temp; + } + + max = g.shfl(max, 0); + + float q_scale = (float)(1 << num_bits) / (2 * max + 1e-5); + float q_scale_inv = 1 / q_scale; + int q_range_max = (1 << (num_bits - 1)) - 1; + int q_range_min = -(1 << (num_bits - 1)); + + for (int thread_index = id * vals_per_access; thread_index < group_size; + thread_index += blockDim.x * vals_per_access) { + mem_access::load_global(data, vals + offset + thread_index); +#pragma unroll + for (int j = 0; j < vals_per_access; j++) { + float q_data; + q_data = __half2float(data[j]); + q_data = __float2int_rn(q_data * q_scale); + q_data = q_data > (q_range_max) ? (q_range_max) + : (q_data < (q_range_min) ? (q_range_min) : q_data); + data[j] = __float2half_rn(q_data * q_scale_inv); + } + mem_access::store_global(vals + offset + thread_index, data); + } + +} + +__global__ void fake_quantize_kernel(float* vals, int group_size, int num_bits) +{ + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile<32> g = cg::tiled_partition<32>(b); + + int gid = threadIdx.x >> 5; + int lane = threadIdx.x & 0x1f; + int warp_num = blockDim.x >> 5; + int id = threadIdx.x; + + constexpr int granularity = 16; + constexpr int vals_per_access = granularity / sizeof(float); + + float data[vals_per_access]; + + int bid = blockIdx.x; + + int thread_index = id * vals_per_access; + + int reg_count = 0; + + int offset = bid * group_size; + + float max = -10000.0; + + for (int thread_index = id * vals_per_access; thread_index < group_size; + thread_index += blockDim.x * vals_per_access) { + mem_access::load_global(data, vals + offset + thread_index); + +#pragma unroll + for (int i = 0; i < vals_per_access; i++) { + if (abs(data[i]) > max) max = abs(data[i]); + } + } + +#pragma unroll + for (int i = 1; i < WARP_SIZE; i <<= 1) { + auto temp = g.shfl_xor(max, i); + if (max < temp) max = temp; + } + __shared__ float partialMax[WARP_SIZE]; + + if (lane == 0) partialMax[gid] = max; + + b.sync(); + + if (lane < warp_num) max = partialMax[lane]; + + b.sync(); + +#pragma unroll + for (int i = 1; i < warp_num; i <<= 1) { + auto temp = g.shfl_down(max, i); + if (max < temp) max = temp; + } + + max = g.shfl(max, 0); + + float q_scale = (1 << num_bits) / (2 * max + 1e-5); + float q_scale_inv = 1 / q_scale; + + int q_range_max = (1 << (num_bits - 1)) - 1; + int q_range_min = -(1 << (num_bits - 1)); + + for (int thread_index = id * vals_per_access; thread_index < group_size; + thread_index += blockDim.x * vals_per_access) { + mem_access::load_global(data, vals + offset + thread_index); +#pragma unroll + for (int j = 0; j < vals_per_access; j++) { + float q_data; + q_data = __float2int_rn(data[j] * q_scale); + q_data = q_data > (q_range_max) ? (q_range_max) + : (q_data < (q_range_min) ? (q_range_min) : q_data); + data[j] = roundf(q_data * q_scale_inv); + } + mem_access::store_global(vals + offset + thread_index, data); + } +} + +template +void launch_fake_quantize_kernel(T* vals, + int total_count, + int group_num, + int num_bits, + cudaStream_t stream) +{ + dim3 grid_dim(group_num); + dim3 block_dim(1024); + + fake_quantize_kernel<<>>( + vals, total_count / group_num, num_bits); +} + +template void launch_fake_quantize_kernel(float* vals, + int total_count, + int group_num, + int num_bits, + cudaStream_t stream); +template void launch_fake_quantize_kernel(__half* vals, + int total_count, + int group_num, + int num_bits, + cudaStream_t stream); + +__global__ void sr_fake_quantize_kernel(__half* vals, + int token_size, + int token_num, + int num_bits, + std::pair seed) +{ +#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_AMD__) + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile<32> g = cg::tiled_partition<32>(b); + + int gid = threadIdx.x >> 5; + int lane = threadIdx.x & 0x1f; + int warp_num = blockDim.x >> 5; + + int idx = blockIdx.x * blockDim.x + threadIdx.x; + + float2* vals_cast = reinterpret_cast(vals); + + __half2 data_low[128]; + __half2 data_high[128]; + + int bid = blockIdx.x; + + curandStatePhilox4_32_10_t state; + curand_init(seed.first, idx, seed.second, &state); + unsigned int tid = threadIdx.x; + int reg_count = 0; + int offset = bid * token_size; + int group_index = bid * token_size + tid; + + int total_count = token_size * token_num; + if (group_index < total_count) { + // float min = 10000.0; + float max = -10000.0; + while (tid < token_size) { + float2 data = vals_cast[offset + tid]; + __half2* data_h = reinterpret_cast<__half2*>(&data); + data_low[reg_count] = data_h[0]; + data_high[reg_count] = data_h[1]; + + float2 data_f[2]; + data_f[0] = __half22float2(data_h[0]); + data_f[1] = __half22float2(data_h[1]); + + if (abs((float)data_f[0].x) > max) max = abs((float)data_f[0].x); + if (abs((float)data_f[0].y) > max) max = abs((float)data_f[0].y); + if (abs((float)data_f[1].x) > max) max = abs((float)data_f[1].x); + if (abs((float)data_f[1].y) > max) max = abs((float)data_f[1].y); + + tid += blockDim.x; + reg_count++; + } + +#pragma unroll + for (int i = 1; i < WARP_SIZE; i <<= 1) { + auto temp = g.shfl_xor(max, i); + if (max < temp) max = temp; + } + + __shared__ float partialMax[WARP_SIZE]; + + if (lane == 0) partialMax[gid] = max; + + b.sync(); + + if (lane < warp_num) max = partialMax[lane]; + +#pragma unroll + for (int i = 1; i < warp_num; i <<= 1) { + auto temp = g.shfl_down(max, i); + if (max < temp) max = temp; + } + + max = g.shfl(max, 0); + + float q_scale_val = (float)(1 << num_bits) / (max * 2 + 1e-5); + float high_q = (float)((1 << (num_bits - 1)) - 1); + float low_q = (float)(-((1 << (num_bits - 1)))); + + for (int i = 0; i < reg_count; i++) { + int token_index = i * blockDim.x + threadIdx.x; + if (token_index < token_size) { + float2 data_f[2]; + data_f[0] = __half22float2(data_low[i]); + data_f[1] = __half22float2(data_high[i]); + + float2 q_data_int[2]; + q_data_int[0].x = (float)((int)(data_f[0].x * q_scale_val)); + q_data_int[0].y = (float)((int)(data_f[0].y * q_scale_val)); + q_data_int[1].x = (float)((int)(data_f[1].x * q_scale_val)); + q_data_int[1].y = (float)((int)(data_f[1].y * q_scale_val)); + + // Stochastic rounding + float4 rand = curand_uniform4(&state); + + float q_error[4]; + q_error[0] = abs(data_f[0].x - (q_data_int[0].x / q_scale_val)) * q_scale_val; + q_error[1] = abs(data_f[0].y - (q_data_int[0].y / q_scale_val)) * q_scale_val; + q_error[2] = abs(data_f[1].x - (q_data_int[1].x / q_scale_val)) * q_scale_val; + q_error[3] = abs(data_f[1].y - (q_data_int[1].y / q_scale_val)) * q_scale_val; + + q_data_int[0].x = + (rand.x < q_error[0] && q_data_int[0].x > low_q && q_data_int[0].x < high_q) + ? (q_data_int[0].x + (data_f[0].x > 0 ? 1 : -1)) + : q_data_int[0].x; + q_data_int[0].y = + (rand.y < q_error[1] && q_data_int[0].y > low_q && q_data_int[0].y < high_q) + ? (q_data_int[0].y + (data_f[0].y > 0 ? 1 : -1)) + : q_data_int[0].y; + q_data_int[1].x = + (rand.w < q_error[2] && q_data_int[1].x > low_q && q_data_int[1].x < high_q) + ? (q_data_int[1].x + (data_f[1].x > 0 ? 1 : -1)) + : q_data_int[1].x; + q_data_int[1].y = + (rand.z < q_error[3] && q_data_int[1].y > low_q && q_data_int[1].y < high_q) + ? (q_data_int[1].y + (data_f[1].y > 0 ? 1 : -1)) + : q_data_int[1].y; + + data_f[0].x = q_data_int[0].x / q_scale_val; + data_f[0].y = q_data_int[0].y / q_scale_val; + data_f[1].x = q_data_int[1].x / q_scale_val; + data_f[1].y = q_data_int[1].y / q_scale_val; + + float2 result; + __half2* result_h = reinterpret_cast<__half2*>(&result); + result_h[0] = __float22half2_rn(data_f[0]); + result_h[1] = __float22half2_rn(data_f[1]); + + vals_cast[offset + token_index] = result; + } + } + } +#endif +} + +__global__ void sr_fake_quantize_kernel(float* vals, + int token_size, + int token_num, + int num_bits, + std::pair seed) +{ + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile<32> g = cg::tiled_partition<32>(b); + + int gid = threadIdx.x >> 5; + int lane = threadIdx.x & 0x1f; + int warp_num = blockDim.x >> 5; + int id = threadIdx.x; + + int idx = blockIdx.x * blockDim.x + id; + + float4* vals_cast = reinterpret_cast(vals); + + float4 data[128]; + + int bid = blockIdx.x; + int tid = threadIdx.x; + curandStatePhilox4_32_10_t state; + curand_init(seed.first, idx, seed.second, &state); + + int group_index = bid * token_size + threadIdx.x; + int reg_count = 0; + int total_count = token_size * token_num; + if (group_index < total_count) { + // float min = 10000.0; + float max = -10000.0; + + while (tid < token_size) { + data[reg_count] = vals_cast[group_index]; + + if (abs(data[reg_count].x) > max) max = abs(data[reg_count].x); + if (abs(data[reg_count].y) > max) max = abs(data[reg_count].y); + if (abs(data[reg_count].z) > max) max = abs(data[reg_count].z); + if (abs(data[reg_count].w) > max) max = abs(data[reg_count].w); + + group_index += blockDim.x; + tid += blockDim.x; + reg_count++; + } + +#pragma unroll + for (int i = 1; i < WARP_SIZE; i <<= 1) { + auto temp = g.shfl_xor(max, i); + if (max < temp) max = temp; + } + __shared__ float partialMax[WARP_SIZE]; + + if (lane == 0) partialMax[gid] = max; + + b.sync(); + + if (lane < warp_num) max = partialMax[lane]; + +#pragma unroll + for (int i = 1; i < warp_num; i <<= 1) { + auto temp = g.shfl_down(max, i); + if (max < temp) max = temp; + } + + max = g.shfl(max, 0); + + float q_scale_val = (float)(1 << num_bits) / (max * 2 + 1e-5); + float high_q = (float)((1 << (num_bits - 1)) - 1); + float low_q = (float)(-((1 << (num_bits - 1)))); + + int offset = (bid)*token_size; + for (int i = 0; i < reg_count; i++) { + group_index = i * blockDim.x + threadIdx.x; + if (group_index < token_size) { + float4 q_data = data[i]; + + float4 q_data_int; + q_data_int.x = (float)((int)(q_data.x * q_scale_val)); + q_data_int.y = (float)((int)(q_data.y * q_scale_val)); + q_data_int.w = (float)((int)(q_data.w * q_scale_val)); + q_data_int.z = (float)((int)(q_data.z * q_scale_val)); + + // Stochastic rounding + float4 rand = curand_uniform4(&state); + + float q_error[4]; + q_error[0] = abs(q_data.x - (q_data_int.x / q_scale_val)) * q_scale_val; + q_error[1] = abs(q_data.y - (q_data_int.y / q_scale_val)) * q_scale_val; + q_error[2] = abs(q_data.w - (q_data_int.w / q_scale_val)) * q_scale_val; + q_error[3] = abs(q_data.z - (q_data_int.z / q_scale_val)) * q_scale_val; + + q_data_int.x = + (rand.x < q_error[0] && q_data_int.x > low_q && q_data_int.x < high_q) + ? (q_data_int.x + (q_data.x > 0 ? 1 : -1)) + : q_data_int.x; + q_data_int.y = + (rand.y < q_error[1] && q_data_int.y > low_q && q_data_int.y < high_q) + ? (q_data_int.y + (q_data.y > 0 ? 1 : -1)) + : q_data_int.y; + q_data_int.w = + (rand.w < q_error[2] && q_data_int.w > low_q && q_data_int.w < high_q) + ? (q_data_int.w + (q_data.w > 0 ? 1 : -1)) + : q_data_int.w; + q_data_int.z = + (rand.z < q_error[3] && q_data_int.z > low_q && q_data_int.z < high_q) + ? (q_data_int.z + (q_data.z > 0 ? 1 : -1)) + : q_data_int.z; + + q_data_int.x /= q_scale_val; + q_data_int.y /= q_scale_val; + q_data_int.w /= q_scale_val; + q_data_int.z /= q_scale_val; + + vals_cast[group_index + offset] = q_data_int; + } + } + } +} + +template +void launch_sr_fake_quantize_kernel(T* vals, + int total_count, + int group_num, + int num_bits, + cudaStream_t stream) +{ + dim3 block_dim(1024); + dim3 grid_dim(group_num); + + uint64_t inc = total_count / grid_dim.x / block_dim.x; + std::pair seed = TrainingContext::Instance().IncrementOffset(inc); + + sr_fake_quantize_kernel<<>>( + vals, (total_count / group_num) / 4, group_num, num_bits, seed); +} +template void launch_sr_fake_quantize_kernel(float* vals, + int total_count, + int group_num, + int num_bits, + cudaStream_t stream); +template void launch_sr_fake_quantize_kernel(__half* vals, + int total_count, + int group_num, + int num_bits, + cudaStream_t stream); + +__global__ void fake_quantize_kernel_asym(__half* vals, int group_size, int num_bits) +{ +#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_AMD__) + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile<32> g = cg::tiled_partition<32>(b); + + int gid = threadIdx.x >> 5; + int lane = threadIdx.x & 0x1f; + int warp_num = blockDim.x >> 5; + int id = threadIdx.x; + + float2* vals_cast = reinterpret_cast(vals); + + float2 data[MAX_REG]; + + int group_id = blockIdx.x; + + { + int group_index = id; + int reg_count = 0; + int offset = group_id * group_size; + float max = -10000.0; + float min = 10000.0; + + while (group_index < group_size && reg_count < MAX_REG) { + data[reg_count] = vals_cast[offset + group_index]; + __half* data_h = reinterpret_cast<__half*>(&data[reg_count]); + + if (((float)data_h[0]) > max) max = (float)data_h[0]; + if (((float)data_h[1]) > max) max = (float)data_h[1]; + if (((float)data_h[2]) > max) max = (float)data_h[2]; + if (((float)data_h[3]) > max) max = (float)data_h[3]; + + if (((float)data_h[0]) < min) min = (float)data_h[0]; + if (((float)data_h[1]) < min) min = (float)data_h[1]; + if (((float)data_h[2]) < min) min = (float)data_h[2]; + if (((float)data_h[3]) < min) min = (float)data_h[3]; + + group_index += blockDim.x; + reg_count++; + } + +#pragma unroll + for (int i = 1; i < WARP_SIZE; i <<= 1) { + auto temp = g.shfl_xor(max, i); + if (max < temp) max = temp; + } + +#pragma unroll + for (int i = 1; i < WARP_SIZE; i <<= 1) { + auto temp = g.shfl_xor(min, i); + if (min > temp) min = temp; + } + + __shared__ float partialMax[WARP_SIZE]; + __shared__ float partialMin[WARP_SIZE]; + + if (lane == 0) partialMax[gid] = max; + if (lane == 0) partialMin[gid] = min; + + b.sync(); + + if (lane < warp_num) max = partialMax[lane]; + if (lane < warp_num) min = partialMin[lane]; + +#pragma unroll + for (int i = 1; i < warp_num; i <<= 1) { + auto temp = g.shfl_down(max, i); + if (max < temp) max = temp; + } +#pragma unroll + for (int i = 1; i < warp_num; i <<= 1) { + auto temp = g.shfl_down(min, i); + if (min > temp) min = temp; + } + + max = g.shfl(max, 0); + min = g.shfl(min, 0); + + float q_scale = ((max - min) + 1e-5) / (float)(1 << num_bits); + float q_scale_inv = 1 / q_scale; + + for (int i = 0; i < reg_count; i++) { + group_index = i * blockDim.x + id; + if (group_index < group_size) { + __half2* data_h = reinterpret_cast<__half2*>(&data[i]); + float2 q_data[2]; + q_data[0] = __half22float2(data_h[0]); + q_data[1] = __half22float2(data_h[1]); + + float2 q_data_int[2]; + + q_data_int[0].x = roundf((q_data[0].x - min) * q_scale_inv); + q_data_int[0].y = roundf((q_data[0].y - min) * q_scale_inv); + q_data_int[1].x = roundf((q_data[1].x - min) * q_scale_inv); + q_data_int[1].y = roundf((q_data[1].y - min) * q_scale_inv); + + q_data_int[0].x = q_data_int[0].x * q_scale + min; + q_data_int[0].y = q_data_int[0].y * q_scale + min; + q_data_int[1].x = q_data_int[1].x * q_scale + min; + q_data_int[1].y = q_data_int[1].y * q_scale + min; + + data_h[0] = __float22half2_rn(q_data_int[0]); + data_h[1] = __float22half2_rn(q_data_int[1]); + + vals_cast[offset + group_index] = data[i]; + } + } + } +#endif +} + +__global__ void fake_quantize_kernel_asym(float* vals, int group_size, int num_bits) +{ + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile<32> g = cg::tiled_partition<32>(b); + + int gid = threadIdx.x >> 5; + int lane = threadIdx.x & 0x1f; + int warp_num = blockDim.x >> 5; + int id = threadIdx.x; + + float4* vals_cast = reinterpret_cast(vals); + + float4 data[MAX_REG]; + + int bid = blockIdx.x; + + int group_index = bid * group_size + id; + int reg_count = 0; + + float max = -10000.0; + float min = 10000.0; + + while (id < group_size && reg_count < MAX_REG) { + float4 data_reg = vals_cast[group_index]; + data[reg_count] = data_reg; + + if (data_reg.x > max) max = data_reg.x; + if (data_reg.y > max) max = data_reg.y; + if (data_reg.w > max) max = data_reg.w; + if (data_reg.z > max) max = data_reg.z; + + if (data_reg.x < min) min = data_reg.x; + if (data_reg.y < min) min = data_reg.y; + if (data_reg.w < min) min = data_reg.w; + if (data_reg.z < min) min = data_reg.z; + + group_index += blockDim.x; + id += blockDim.x; + reg_count++; + } + id = threadIdx.x; + +#pragma unroll + for (int i = 1; i < WARP_SIZE; i <<= 1) { + auto temp = g.shfl_xor(max, i); + if (max < temp) max = temp; + } + +#pragma unroll + for (int i = 1; i < WARP_SIZE; i <<= 1) { + auto temp = g.shfl_xor(min, i); + if (min > temp) min = temp; + } + + __shared__ float partialMax[WARP_SIZE]; + __shared__ float partialMin[WARP_SIZE]; + + if (lane == 0) partialMax[gid] = max; + if (lane == 0) partialMin[gid] = min; + + b.sync(); + + if (lane < warp_num) max = partialMax[lane]; + if (lane < warp_num) min = partialMin[lane]; + +#pragma unroll + for (int i = 1; i < warp_num; i <<= 1) { + auto temp = g.shfl_down(max, i); + if (max < temp) max = temp; + } +#pragma unroll + for (int i = 1; i < warp_num; i <<= 1) { + auto temp = g.shfl_down(min, i); + if (min > temp) min = temp; + } + + max = g.shfl(max, 0); + min = g.shfl(min, 0); + + float q_scale = ((max - min) + 1e-5) / (float)(1 << num_bits); + float q_scale_inv = 1 / q_scale; + for (int i = 0; i < reg_count; i++) { + group_index = i * blockDim.x + id; + if (group_index < group_size) { + float4 q_data; + q_data = data[i]; + + float4 q_data_int; + q_data_int.x = roundf((q_data.x - min) * q_scale_inv); + q_data_int.y = roundf((q_data.y - min) * q_scale_inv); + q_data_int.w = roundf((q_data.w - min) * q_scale_inv); + q_data_int.z = roundf((q_data.z - min) * q_scale_inv); + + q_data.x = q_data_int.x * q_scale + min; + q_data.y = q_data_int.y * q_scale + min; + q_data.w = q_data_int.w * q_scale + min; + q_data.z = q_data_int.z * q_scale + min; + + vals_cast[group_index + bid * group_size] = q_data; + } + } +} + +template +void launch_fake_quantize_kernel_asym(T* vals, + int total_count, + int group_num, + int num_bits, + cudaStream_t stream) +{ + dim3 grid_dim(group_num); + dim3 block_dim(1024); + + fake_quantize_kernel_asym<<>>( + vals, (total_count / group_num) / 4, num_bits); +} + +template void launch_fake_quantize_kernel_asym(float* vals, + int total_count, + int group_num, + int num_bits, + cudaStream_t stream); +template void launch_fake_quantize_kernel_asym(__half* vals, + int total_count, + int group_num, + int num_bits, + cudaStream_t stream); + +__global__ void sr_fake_quantize_kernel_asym(__half* vals, + int token_size, + int token_num, + int num_bits, + std::pair seed) +{ +#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_AMD__) + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile<32> g = cg::tiled_partition<32>(b); + + int gid = threadIdx.x >> 5; + int lane = threadIdx.x & 0x1f; + int warp_num = blockDim.x >> 5; + + int idx = blockIdx.x * blockDim.x + threadIdx.x; + + float2* vals_cast = reinterpret_cast(vals); + + __half2 data_low[128]; + __half2 data_high[128]; + + int bid = blockIdx.x; + + curandStatePhilox4_32_10_t state; + curand_init(seed.first, idx, seed.second, &state); + unsigned int tid = threadIdx.x; + int reg_count = 0; + int offset = bid * token_size; + int group_index = bid * token_size + tid; + + int total_count = token_size * token_num; + if (group_index < total_count) { + float min = 10000.0; + float max = -10000.0; + while (tid < token_size) { + float2 data = vals_cast[offset + tid]; + __half2* data_h = reinterpret_cast<__half2*>(&data); + data_low[reg_count] = data_h[0]; + data_high[reg_count] = data_h[1]; + + float2 data_f[2]; + data_f[0] = __half22float2(data_h[0]); + data_f[1] = __half22float2(data_h[1]); + + if (((float)data_f[0].x) > max) max = (float)data_f[0].x; + if (((float)data_f[0].y) > max) max = (float)data_f[0].y; + if (((float)data_f[1].x) > max) max = (float)data_f[1].x; + if (((float)data_f[1].y) > max) max = (float)data_f[1].y; + + if (((float)data_f[0].x) < min) min = (float)data_f[0].x; + if (((float)data_f[0].y) < min) min = (float)data_f[0].y; + if (((float)data_f[1].x) < min) min = (float)data_f[1].x; + if (((float)data_f[1].y) < min) min = (float)data_f[1].y; + + tid += blockDim.x; + reg_count++; + } + +#pragma unroll + for (int i = 1; i < WARP_SIZE; i <<= 1) { + auto temp = g.shfl_xor(max, i); + if (max < temp) max = temp; + } + +#pragma unroll + for (int i = 1; i < WARP_SIZE; i <<= 1) { + auto temp = g.shfl_xor(min, i); + if (min > temp) min = temp; + } + + __shared__ float partialMax[WARP_SIZE]; + __shared__ float partialMin[WARP_SIZE]; + + if (lane == 0) partialMax[gid] = max; + if (lane == 0) partialMin[gid] = min; + + b.sync(); + + if (lane < warp_num) max = partialMax[lane]; + if (lane < warp_num) min = partialMin[lane]; + +#pragma unroll + for (int i = 1; i < warp_num; i <<= 1) { + auto temp = g.shfl_down(max, i); + if (max < temp) max = temp; + } +#pragma unroll + for (int i = 1; i < warp_num; i <<= 1) { + auto temp = g.shfl_down(min, i); + if (min > temp) min = temp; + } + + max = g.shfl(max, 0); + min = g.shfl(min, 0); + + float q_scale_val = ((max - min) + 1e-5) / (float)(1 << num_bits); + float q_scale_val_inv = 1 / q_scale_val; + float high_q = (float)((1 << num_bits) - 1); + + for (int i = 0; i < reg_count; i++) { + int token_index = i * blockDim.x + threadIdx.x; + if (token_index < token_size) { + float2 data_f[2]; + data_f[0] = __half22float2(data_low[i]); + data_f[1] = __half22float2(data_high[i]); + + float2 q_data_int[2]; + q_data_int[0].x = (float)((unsigned int)((data_f[0].x - min) * q_scale_val_inv)); + q_data_int[0].y = (float)((unsigned int)((data_f[0].y - min) * q_scale_val_inv)); + q_data_int[1].x = (float)((unsigned int)((data_f[1].x - min) * q_scale_val_inv)); + q_data_int[1].y = (float)((unsigned int)((data_f[1].y - min) * q_scale_val_inv)); + + // Stochastic rounding + float4 rand = curand_uniform4(&state); + + float q_error[4]; + q_error[0] = + abs(data_f[0].x - ((q_data_int[0].x * q_scale_val) + min)) * q_scale_val_inv; + q_error[1] = + abs(data_f[0].y - ((q_data_int[0].y * q_scale_val) + min)) * q_scale_val_inv; + q_error[2] = + abs(data_f[1].x - ((q_data_int[1].x * q_scale_val) + min)) * q_scale_val_inv; + q_error[3] = + abs(data_f[1].y - ((q_data_int[1].y * q_scale_val) + min)) * q_scale_val_inv; + + q_data_int[0].x = (rand.x < q_error[0] && q_data_int[0].x < high_q) + ? (q_data_int[0].x + 1) + : q_data_int[0].x; + q_data_int[0].y = (rand.y < q_error[1] && q_data_int[0].y < high_q) + ? (q_data_int[0].y + 1) + : q_data_int[0].y; + q_data_int[1].x = (rand.w < q_error[2] && q_data_int[1].x < high_q) + ? (q_data_int[1].x + 1) + : q_data_int[1].x; + q_data_int[1].y = (rand.z < q_error[3] && q_data_int[1].y < high_q) + ? (q_data_int[1].y + 1) + : q_data_int[1].y; + + data_f[0].x = q_data_int[0].x * q_scale_val + min; + data_f[0].y = q_data_int[0].y * q_scale_val + min; + data_f[1].x = q_data_int[1].x * q_scale_val + min; + data_f[1].y = q_data_int[1].y * q_scale_val + min; + + float2 result; + __half2* result_h = reinterpret_cast<__half2*>(&result); + result_h[0] = __float22half2_rn(data_f[0]); + result_h[1] = __float22half2_rn(data_f[1]); + + vals_cast[offset + token_index] = result; + } + } + } +#endif +} + +__global__ void sr_fake_quantize_kernel_asym(float* vals, + int token_size, + int token_num, + int num_bits, + std::pair seed) +{ + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile<32> g = cg::tiled_partition<32>(b); + + int gid = threadIdx.x >> 5; + int lane = threadIdx.x & 0x1f; + int warp_num = blockDim.x >> 5; + int id = threadIdx.x; + + int idx = blockIdx.x * blockDim.x + id; + + float4* vals_cast = reinterpret_cast(vals); + + float4 data[128]; + + int bid = blockIdx.x; + int tid = threadIdx.x; + curandStatePhilox4_32_10_t state; + curand_init(seed.first, idx, seed.second, &state); + + int group_index = bid * token_size + threadIdx.x; + int reg_count = 0; + int total_count = token_size * token_num; + if (group_index < total_count) { + float min = 10000.0; + float max = -10000.0; + + while (tid < token_size) { + float4 data_reg = vals_cast[group_index]; + data[reg_count] = data_reg; + if (data_reg.x > max) max = data_reg.x; + if (data_reg.y > max) max = data_reg.y; + if (data_reg.w > max) max = data_reg.w; + if (data_reg.z > max) max = data_reg.z; + + if (data_reg.x < min) min = data_reg.x; + if (data_reg.y < min) min = data_reg.y; + if (data_reg.w < min) min = data_reg.w; + if (data_reg.z < min) min = data_reg.z; + + group_index += blockDim.x; + tid += blockDim.x; + reg_count++; + } + +#pragma unroll + for (int i = 1; i < WARP_SIZE; i <<= 1) { + auto temp = g.shfl_xor(max, i); + if (max < temp) max = temp; + } + +#pragma unroll + for (int i = 1; i < WARP_SIZE; i <<= 1) { + auto temp = g.shfl_xor(min, i); + if (min > temp) min = temp; + } + + __shared__ float partialMax[WARP_SIZE]; + __shared__ float partialMin[WARP_SIZE]; + + if (lane == 0) partialMax[gid] = max; + if (lane == 0) partialMin[gid] = min; + + b.sync(); + + if (lane < warp_num) max = partialMax[lane]; + if (lane < warp_num) min = partialMin[lane]; + +#pragma unroll + for (int i = 1; i < warp_num; i <<= 1) { + auto temp = g.shfl_down(max, i); + if (max < temp) max = temp; + } +#pragma unroll + for (int i = 1; i < warp_num; i <<= 1) { + auto temp = g.shfl_down(min, i); + if (min > temp) min = temp; + } + + max = g.shfl(max, 0); + min = g.shfl(min, 0); + + float q_scale_val = ((max - min) + 1e-5) / (float)(1 << num_bits); + float high_q = (float)((1 << num_bits) - 1); + + int offset = (bid)*token_size; + for (int i = 0; i < reg_count; i++) { + group_index = i * blockDim.x + threadIdx.x; + if (group_index < token_size) { + float4 q_data = data[i]; + + float4 q_data_int; + q_data_int.x = (float)((int)((q_data.x - min) / q_scale_val)); + q_data_int.y = (float)((int)((q_data.y - min) / q_scale_val)); + q_data_int.w = (float)((int)((q_data.w - min) / q_scale_val)); + q_data_int.z = (float)((int)((q_data.z - min) / q_scale_val)); + + // Stochastic rounding + float4 rand = curand_uniform4(&state); + + float q_error[4]; + q_error[0] = abs(q_data.x - ((q_data_int.x * q_scale_val) + min)) / q_scale_val; + q_error[1] = abs(q_data.y - ((q_data_int.y * q_scale_val) + min)) / q_scale_val; + q_error[2] = abs(q_data.w - ((q_data_int.w * q_scale_val) + min)) / q_scale_val; + q_error[3] = abs(q_data.z - ((q_data_int.z * q_scale_val) + min)) / q_scale_val; + + q_data_int.x = (rand.x < q_error[0] && q_data_int.x < high_q) ? (q_data_int.x + 1) + : q_data_int.x; + q_data_int.y = (rand.y < q_error[1] && q_data_int.y < high_q) ? (q_data_int.y + 1) + : q_data_int.y; + q_data_int.w = (rand.w < q_error[2] && q_data_int.w < high_q) ? (q_data_int.w + 1) + : q_data_int.w; + q_data_int.z = (rand.z < q_error[3] && q_data_int.z < high_q) ? (q_data_int.z + 1) + : q_data_int.z; + + q_data_int.x = q_data_int.x * q_scale_val + min; + q_data_int.y = q_data_int.y * q_scale_val + min; + q_data_int.w = q_data_int.w * q_scale_val + min; + q_data_int.z = q_data_int.z * q_scale_val + min; + + vals_cast[group_index + offset] = q_data_int; + } + } + } +} +template +void launch_sr_fake_quantize_kernel_asym(T* vals, + int total_count, + int group_num, + int num_bits, + cudaStream_t stream) +{ + dim3 block_dim(1024); + dim3 grid_dim(group_num); + + uint64_t inc = total_count / grid_dim.x / block_dim.x; + std::pair seed = TrainingContext::Instance().IncrementOffset(inc); + + sr_fake_quantize_kernel<<>>( + vals, (total_count / group_num) / 4, group_num, num_bits, seed); +} +template void launch_sr_fake_quantize_kernel_asym(float* vals, + int total_count, + int group_num, + int num_bits, + cudaStream_t stream); +template void launch_sr_fake_quantize_kernel_asym(__half* vals, + int total_count, + int group_num, + int num_bits, + cudaStream_t stream); diff --git a/toolbox/DeepSpeed/v0.15.3/patches/csrc/rope/fused_rotary_positional_embedding.cpp b/toolbox/DeepSpeed/v0.15.3/patches/csrc/rope/fused_rotary_positional_embedding.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8cc16bc52dc9a63f514f0f9d44d06463c6e127fd --- /dev/null +++ b/toolbox/DeepSpeed/v0.15.3/patches/csrc/rope/fused_rotary_positional_embedding.cpp @@ -0,0 +1,117 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +/* +Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may +not use this file except in compliance with the License. You may obtain +a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +/* coding=utf-8 + * Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +namespace fused_rope { + +torch::Tensor fwd_cuda(const torch::Tensor &input, const torch::Tensor &cos, + const torch::Tensor &sin, const bool transpose_output); + +torch::Tensor bwd_cuda(const torch::Tensor &output_grads, + const torch::Tensor &cos, const torch::Tensor &sin, + const bool transpose_output); + +torch::Tensor fwd(const at::Tensor &input, const at::Tensor &cos, + const at::Tensor &sin, const bool transpose_output) { + TORCH_CHECK(input.dim() == 4, "expected 4D tensor"); + TORCH_CHECK(cos.dim() == 4, "expected 4D tensor"); + TORCH_CHECK(sin.dim() == 4, "expected 4D tensor"); + TORCH_CHECK(input.size(0) == cos.size(0), + "expected input and cos tensor have the same sequence length"); + TORCH_CHECK(input.size(0) == sin.size(0), + "expected input and sin tensor have the same sequence length"); + TORCH_CHECK(cos.size(1) == 1 && cos.size(2) == 1, + "expected the second and third dims of the cos tensor equal 1"); + TORCH_CHECK(sin.size(1) == 1 && sin.size(2) == 1, + "expected the second and third dims of the sin tensor equal 1"); + TORCH_CHECK(input.size(3) >= cos.size(3), + "expected the last dim of the input tensor is greater than the " + "cos tensor"); + TORCH_CHECK(input.size(3) >= sin.size(3), + "expected the last dim of the input tensor is greater than the " + "sin tensor"); + + return fwd_cuda(input, cos, sin, transpose_output); +} + +torch::Tensor bwd(const torch::Tensor &output_grads, const at::Tensor &cos, + const at::Tensor &sin, const bool transpose_output) { + TORCH_CHECK(output_grads.dim() == 4, "expected 4D tensor"); + TORCH_CHECK(cos.dim() == 4, "expected 4D tensor"); + TORCH_CHECK(sin.dim() == 4, "expected 4D tensor"); + TORCH_CHECK( + output_grads.size(0) == cos.size(0), + "expected output_grads and cos tensor have the same sequence length"); + TORCH_CHECK( + output_grads.size(0) == sin.size(0), + "expected output_grads and sin tensor have the same sequence length"); + TORCH_CHECK(cos.size(1) == 1 && cos.size(2) == 1, + "expected the second and third dims of the cos tensor equal 1"); + TORCH_CHECK(sin.size(1) == 1 && sin.size(2) == 1, + "expected the second and third dims of the sin tensor equal 1"); + TORCH_CHECK( + output_grads.size(3) >= cos.size(3), + "expected the last dim of the output_grads tensor is greater than the " + "cos tensor"); + TORCH_CHECK( + output_grads.size(3) >= sin.size(3), + "expected the last dim of the output_grads tensor is greater than the " + "sin tensor"); + + return bwd_cuda(output_grads, cos, sin, transpose_output); +} + +} // end namespace fused_rope + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &fused_rope::fwd, + "Fused Rotary Positional Embedding -- Forward."); + m.def("backward", &fused_rope::bwd, + "Fused Rotary Positional Embedding -- Backward."); +} \ No newline at end of file diff --git a/toolbox/DeepSpeed/v0.15.3/patches/csrc/rope/fused_rotary_positional_embedding_cuda.cu b/toolbox/DeepSpeed/v0.15.3/patches/csrc/rope/fused_rotary_positional_embedding_cuda.cu new file mode 100644 index 0000000000000000000000000000000000000000..ce1b7fb774726a46e0e0529dc1f01cc8feb90fed --- /dev/null +++ b/toolbox/DeepSpeed/v0.15.3/patches/csrc/rope/fused_rotary_positional_embedding_cuda.cu @@ -0,0 +1,142 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +/* +Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may +not use this file except in compliance with the License. You may obtain +a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +/* coding=utf-8 + * Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "fused_rotary_positional_embedding.h" +#include "type_shim_rope.h" + +namespace fused_rope { + +torch::Tensor fwd_cuda(const torch::Tensor &input, const torch::Tensor &cos, + const torch::Tensor &sin, const bool transpose_output) { + // input sizes: (s, b, h, d) + // s: sequence length + // b: batch size + // h: head num + // d: dim of each head + const int s = input.size(0); + const int b = input.size(1); + const int h = input.size(2); + const int d = input.size(3); + // input strides + const int stride_s = input.stride(0); + const int stride_b = input.stride(1); + const int stride_h = input.stride(2); + const int stride_d = input.stride(3); + // cos/sin's shape is always (s, 1, 1, d2), so the strides are same under + // different memory formats + const int d2 = cos.size(3); + + // output + auto act_options = input.options().requires_grad(false); + torch::Tensor output; + if (transpose_output) { + output = torch::empty({b, s, h, d}, act_options).transpose(0, 1); + } else { + output = torch::empty({s, b, h, d}, act_options); + } + // output strides + const int o_stride_s = output.stride(0); + const int o_stride_b = output.stride(1); + const int o_stride_h = output.stride(2); + const int o_stride_d = output.stride(3); + + DISPATCH_FLOAT_HALF_AND_BFLOAT( + input.scalar_type(), 0, "dispatch_fused_rope_forward", + dispatch_fused_rope_forward( + s, b, h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, + o_stride_b, o_stride_h, o_stride_d, input.data_ptr(), + cos.data_ptr(), sin.data_ptr(), + output.data_ptr());); + return output; +} + +torch::Tensor bwd_cuda(const torch::Tensor &output_grads, + const torch::Tensor &cos, const torch::Tensor &sin, + const bool transpose_output) { + // output_grads sizes: (s, b, h, d) + // s: sequence length + // b: batch size + // h: head num + // d: dim of each head + const int s = output_grads.size(0); + const int b = output_grads.size(1); + const int h = output_grads.size(2); + const int d = output_grads.size(3); + // output_grads strides + const int stride_s = output_grads.stride(0); + const int stride_b = output_grads.stride(1); + const int stride_h = output_grads.stride(2); + const int stride_d = output_grads.stride(3); + // cos/sin's shape is always (s, 1, 1, d2), so the strides are same under + // different memory formats + const int d2 = cos.size(3); + + auto act_options = output_grads.options().requires_grad(false); + torch::Tensor input_grads; + if (transpose_output) { + input_grads = torch::empty({b, s, h, d}, act_options).transpose(0, 1); + } else { + input_grads = torch::empty({s, b, h, d}, act_options); + } + const int o_stride_s = input_grads.stride(0); + const int o_stride_b = input_grads.stride(1); + const int o_stride_h = input_grads.stride(2); + const int o_stride_d = input_grads.stride(3); + + DISPATCH_FLOAT_HALF_AND_BFLOAT( + output_grads.scalar_type(), 0, "dispatch_fused_rope_backward", + dispatch_fused_rope_backward( + s, b, h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, + o_stride_b, o_stride_h, o_stride_d, + output_grads.data_ptr(), cos.data_ptr(), + sin.data_ptr(), input_grads.data_ptr());) + return input_grads; +} +} // end namespace fused_rope \ No newline at end of file diff --git a/toolbox/DeepSpeed/v0.15.3/patches/csrc/spatial/includes/spatial_cuda_layers.h b/toolbox/DeepSpeed/v0.15.3/patches/csrc/spatial/includes/spatial_cuda_layers.h new file mode 100644 index 0000000000000000000000000000000000000000..49bdf8a14e6a2f4b537c55bf31ee722442c9c05a --- /dev/null +++ b/toolbox/DeepSpeed/v0.15.3/patches/csrc/spatial/includes/spatial_cuda_layers.h @@ -0,0 +1,54 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. */ +/* All Rights Reserved. */ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once +#define HALF_PRECISION_AVAILABLE = 1 +#if __CUDA_ARCH__ >= 530 +#define HALF_PRECISION_AVAILABLE = 1 +#endif + +#ifdef __HIP_PLATFORM_AMD__ +#include +#else +#include +#endif + +#include +#include + +/*********** Group Norm Kernels, Structs, and Helpers ************/ + +struct { + int64_t batch_size; + int64_t seq_len; + int64_t channels; +} typedef ChannelsLastProblem; + +void launch_opt_bias_add(__half* result, + const __half* activation, + const __half* bias, + const __half* other, + const __half* other_bias, + int batch_size, + int seq_len, + int channels, + cudaStream_t stream); diff --git a/toolbox/DeepSpeed/v0.15.3/patches/csrc/swiglu/swiglu.cpp b/toolbox/DeepSpeed/v0.15.3/patches/csrc/swiglu/swiglu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5f2b522720c988d4a2170a9c0ae97895c5681e12 --- /dev/null +++ b/toolbox/DeepSpeed/v0.15.3/patches/csrc/swiglu/swiglu.cpp @@ -0,0 +1,42 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +/* +Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may +not use this file except in compliance with the License. You may obtain +a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include +#include + +#include "swiglu.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("swiglu_fwd", &launch_swiglu_kernel, ""); + m.def("swiglu_bwd", &launch_swiglu_kernel_bwd, ""); + +} \ No newline at end of file diff --git a/toolbox/DeepSpeed/v0.15.3/patches/csrc/swiglu/swiglu_kernel.cu b/toolbox/DeepSpeed/v0.15.3/patches/csrc/swiglu/swiglu_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..8bc31cf48e4c36de51935b217db84c1839af5147 --- /dev/null +++ b/toolbox/DeepSpeed/v0.15.3/patches/csrc/swiglu/swiglu_kernel.cu @@ -0,0 +1,183 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +/* +Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may +not use this file except in compliance with the License. You may obtain +a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include "pack_type.cuh" +#include "swiglu.h" + +#include +#include +#include + +template +__global__ void SwiGLUPackedKernel(const T* x, T* y, int pack_len, int last_dim) { + int data_index = blockIdx.x * blockDim.x + threadIdx.x; + int idx_x = data_index % last_dim + (data_index / last_dim) * last_dim * 2; + int idx_y = idx_x + last_dim; + const Packed* ptr_x = reinterpret_cast*>(x); + Packed* ptr_z = reinterpret_cast*>(y); + Packed in_x, in_y, out; + if (data_index < pack_len) { + in_x = ptr_x[idx_x]; + in_y = ptr_x[idx_y]; +#pragma unroll + for (int i = 0; i < pack_size; ++i) { + float x_f = PackItemType2Float(in_x.elem[i]); + float y_f = PackItemType2Float(in_y.elem[i]); + float z_f = x_f * y_f / (1.0f + ::exp(-x_f)); + out.elem[i] = Float2PackItemType(z_f); + } + ptr_z[data_index] = out; + } +} + + +template +void LaunchSwiGLUPackedKernel(const T* x, T* y, const cudaStream_t &stream, int len, int last_dim) { + constexpr const int pack_size = std::max(static_cast(4 / sizeof(T)), 1); + bool is_packed = IsAlignedForPack(x, y); + if (is_packed && last_dim % pack_size == 0) { + int pack_len = len / pack_size; + unsigned int block_x = std::min(pack_len, 1024); + unsigned int grid_x = (pack_len + block_x - 1) / block_x; + SwiGLUPackedKernel + <<>>(x, y, pack_len, last_dim / pack_size); + } else { + int pack_len = len; + unsigned int block_x = std::min(pack_len, 1024); + unsigned int grid_x = (pack_len + block_x - 1) / block_x; + SwiGLUPackedKernel<1, T> + <<>>(x, y, pack_len, last_dim); + } +} + + +torch::Tensor launch_swiglu_kernel(torch::Tensor& input) { + TORCH_CHECK(input.size(-1) % 2 == 0, "last dim of input should be even, got ", input.size(-1)); + auto shape = input.sizes().vec(); + shape[shape.size() - 1] = input.size(-1) / 2; + torch::Tensor out = torch::empty(shape, at::TensorOptions(input.dtype()) + .device(input.device())); + + int last_dim = out.size(-1); + int len = out.numel(); + const void* x = input.data_ptr(); + void* y = out.data_ptr(); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + if (input.dtype() == at::ScalarType::Half) { + LaunchSwiGLUPackedKernel<__half>((const __half*)x, (__half*)y, stream, len, last_dim); + } else if (input.dtype() == at::ScalarType::Float) { + LaunchSwiGLUPackedKernel((const float*)x, (float*)y, stream, len, last_dim); + } else if (input.dtype() == at::ScalarType::BFloat16) { + LaunchSwiGLUPackedKernel<__nv_bfloat16>((const __nv_bfloat16*)x, (__nv_bfloat16*)y, stream, len, last_dim); + } else { + TORCH_CHECK(false, "input datatype should be half/float, got ", input.dtype()); + } + + return out; +} +template +__global__ void SwiGLUPackedBwdKernel(const T* x, const T* g, T* dx, int pack_len, int last_dim) { + int data_index = blockIdx.x * blockDim.x + threadIdx.x; + int idx_x = data_index % last_dim + (data_index / last_dim) * last_dim * 2; + int idx_y = idx_x + last_dim; + const Packed* ptr_x = reinterpret_cast*>(x); + const Packed* ptr_g = reinterpret_cast*>(g); + Packed* ptr_dx = reinterpret_cast*>(dx); + Packed in_x, in_y, in_g, d_in_x, d_in_y; + if (data_index < pack_len) { + in_x = ptr_x[idx_x]; + in_y = ptr_x[idx_y]; + in_g = ptr_g[data_index]; +#pragma unroll + for (int i = 0; i < pack_size; ++i) { + float x_f = PackItemType2Float(in_x.elem[i]); + float y_f = PackItemType2Float(in_y.elem[i]); + float g_f = PackItemType2Float(in_g.elem[i]); + float x_sigmoid = 1.0f / (1.0f + ::exp(-x_f)); + float d_tmp_x = x_sigmoid * (1 + x_f * (1.0f - x_sigmoid)) * g_f * y_f; + float d_tmp_y = x_f * x_sigmoid * g_f; + d_in_x.elem[i] = Float2PackItemType(d_tmp_x); + d_in_y.elem[i] = Float2PackItemType(d_tmp_y); + } + ptr_dx[idx_x] = d_in_x; + ptr_dx[idx_y] = d_in_y; + } +} + + +template +void LaunchSwiGLUPackedBwdKernel(const T* x, const T* g, T* dx, const cudaStream_t &stream, int len, int last_dim) { + constexpr const int pack_size = std::max(static_cast(4 / sizeof(T)), 1); + bool is_packed = IsAlignedForPack(x, g, dx); + if (is_packed && last_dim % pack_size == 0) { + int pack_len = len / pack_size; + unsigned int block_x = std::min(pack_len, 1024); + unsigned int grid_x = (pack_len + block_x - 1) / block_x; + SwiGLUPackedBwdKernel + <<>>(x, g, dx, pack_len, last_dim / pack_size); + } else { + int pack_len = len; + unsigned int block_x = std::min(pack_len, 1024); + unsigned int grid_x = (pack_len + block_x - 1) / block_x; + SwiGLUPackedBwdKernel<1, T> + <<>>(x, g, dx, pack_len, last_dim); + } +} + + +torch::Tensor launch_swiglu_kernel_bwd(torch::Tensor& input, torch::Tensor& grad) { + TORCH_CHECK(input.size(-1) % 2 == 0, "last dim of input should be even, got ", input.size(-1)); + int last_dim = grad.size(-1); + int len = grad.numel(); + + torch::Tensor out = torch::empty(input.sizes(), at::TensorOptions(input.dtype()) + .device(input.device())); + + const void* x = input.data_ptr(); + const void* g = grad.data_ptr(); + void* dx = out.data_ptr(); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + if (grad.dtype() == at::ScalarType::Half && input.dtype() == at::ScalarType::Half) { + LaunchSwiGLUPackedBwdKernel<__half>((const __half*)x, (const __half*)g, (__half*)dx, stream, len, last_dim); + } else if (grad.dtype() == at::ScalarType::Float && input.dtype() == at::ScalarType::Float) { + LaunchSwiGLUPackedBwdKernel((const float*)x, (const float*)g, (float*)dx, stream, len, last_dim); + } else if (grad.dtype() == at::ScalarType::BFloat16 && input.dtype() == at::ScalarType::BFloat16) { + LaunchSwiGLUPackedBwdKernel<__nv_bfloat16>((const __nv_bfloat16*)x, (const __nv_bfloat16*)g, (__nv_bfloat16*)dx, stream, len, last_dim); + } else { + TORCH_CHECK(false, "input and grad datatype should be half/float, got ", input.dtype(), grad.dtype()); + } + + return out; +} + diff --git a/toolbox/DeepSpeed/v0.15.3/patches/csrc/transformer/cublas_wrappers.cu b/toolbox/DeepSpeed/v0.15.3/patches/csrc/transformer/cublas_wrappers.cu new file mode 100644 index 0000000000000000000000000000000000000000..4585021c789a474a5ab3e3e79d771450d80e1edf --- /dev/null +++ b/toolbox/DeepSpeed/v0.15.3/patches/csrc/transformer/cublas_wrappers.cu @@ -0,0 +1,510 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. */ +/* All Rights Reserved. */ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "cublas_wrappers.h" + +// TODO HIP: Remove backward compatibility for torch<=2.0 in future +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) +int cublas_gemm_ex(rocblas_handle handle, + rocblas_operation transa, + rocblas_operation transb, + int m, + int n, + int k, + const float* alpha, + const float* beta, + const float* A, + const float* B, + float* C, + rocblas_gemm_algo algo) +#else +int cublas_gemm_ex(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const float* alpha, + const float* beta, + const float* A, + const float* B, + float* C, + cublasGemmAlgo_t algo) +#endif +{ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) + rocblas_status status = rocblas_gemm_ex(handle, + transa, + transb, + m, + n, + k, + (const void*)alpha, + (const void*)A, + rocblas_datatype_f32_r, + (transa == rocblas_operation_none) ? m : k, + (const void*)B, + rocblas_datatype_f32_r, + (transb == rocblas_operation_none) ? k : n, + (const void*)beta, + C, + rocblas_datatype_f32_r, + m, + C, + rocblas_datatype_f32_r, + m, + rocblas_datatype_f32_r, + algo, + 0, + 0); +#else + cublasStatus_t status = cublasGemmEx(handle, + transa, + transb, + m, + n, + k, + (const void*)alpha, + (const void*)A, +#ifdef __HIP_PLATFORM_AMD__ + HIPBLAS_R_32F, +#else + CUDA_R_32F, +#endif + (transa == CUBLAS_OP_N) ? m : k, + (const void*)B, +#ifdef __HIP_PLATFORM_AMD__ + HIPBLAS_R_32F, +#else + CUDA_R_32F, +#endif + (transb == CUBLAS_OP_N) ? k : n, + (const void*)beta, + C, +#ifdef __HIP_PLATFORM_AMD__ + HIPBLAS_R_32F, +#else + CUDA_R_32F, +#endif + m, +#if defined(__HIP_PLATFORM_AMD__) && defined(HIPBLAS_V2) + HIPBLAS_COMPUTE_32F, +#elif defined(__HIP_PLATFORM_AMD__) + HIPBLAS_R_32F, +#else + CUDA_R_32F, +#endif + algo); +#endif + +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) + if (status != rocblas_status_success) { +#else + if (status != CUBLAS_STATUS_SUCCESS) { +#endif + fprintf(stderr, + "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n", + m, + n, + k, + (int)status); + return EXIT_FAILURE; + } + return 0; +} + +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) +int cublas_gemm_ex(rocblas_handle handle, + rocblas_operation transa, + rocblas_operation transb, + int m, + int n, + int k, + const float* alpha, + const float* beta, + const __half* A, + const __half* B, + __half* C, + rocblas_gemm_algo algo) +#else +int cublas_gemm_ex(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const float* alpha, + const float* beta, + const __half* A, + const __half* B, + __half* C, + cublasGemmAlgo_t algo) +#endif +{ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) + rocblas_status status = rocblas_gemm_ex(handle, + transa, + transb, + m, + n, + k, + (const void*)alpha, + (const void*)A, + rocblas_datatype_f16_r, + (transa == rocblas_operation_none) ? m : k, + (const void*)B, + rocblas_datatype_f16_r, + (transb == rocblas_operation_none) ? k : n, + (const void*)beta, + (void*)C, + rocblas_datatype_f16_r, + m, + (void*)C, + rocblas_datatype_f16_r, + m, + rocblas_datatype_f32_r, + algo, + 0, + 0); +#else + cublasStatus_t status = cublasGemmEx(handle, + transa, + transb, + m, + n, + k, + (const void*)alpha, + (const void*)A, +#ifdef __HIP_PLATFORM_AMD__ + HIPBLAS_R_16F, +#else + CUDA_R_16F, +#endif + (transa == CUBLAS_OP_N) ? m : k, + (const void*)B, +#ifdef __HIP_PLATFORM_AMD__ + HIPBLAS_R_16F, +#else + CUDA_R_16F, +#endif + (transb == CUBLAS_OP_N) ? k : n, + (const void*)beta, + (void*)C, +#ifdef __HIP_PLATFORM_AMD__ + HIPBLAS_R_16F, +#else + CUDA_R_16F, +#endif + m, +#if defined(__HIP_PLATFORM_AMD__) && defined(HIPBLAS_V2) + HIPBLAS_COMPUTE_32F, +#elif defined(__HIP_PLATFORM_AMD__) + HIPBLAS_R_32F, +#else + CUDA_R_32F, +#endif + algo); +#endif + +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) + if (status != rocblas_status_success) { +#else + if (status != CUBLAS_STATUS_SUCCESS) { +#endif + fprintf(stderr, + "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n", + m, + n, + k, + (int)status); + return EXIT_FAILURE; + } + return 0; +} + +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) +int cublas_strided_batched_gemm(rocblas_handle handle, + int m, + int n, + int k, + const float* alpha, + const float* beta, + const float* A, + const float* B, + float* C, + rocblas_operation op_A, + rocblas_operation op_B, + int stride_A, + int stride_B, + int stride_C, + int batch, + rocblas_gemm_algo algo) +#else +int cublas_strided_batched_gemm(cublasHandle_t handle, + int m, + int n, + int k, + const float* alpha, + const float* beta, + const float* A, + const float* B, + float* C, + cublasOperation_t op_A, + cublasOperation_t op_B, + int stride_A, + int stride_B, + int stride_C, + int batch, + cublasGemmAlgo_t algo) +#endif +{ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) + rocblas_status status = + rocblas_gemm_strided_batched_ex(handle, + op_A, + op_B, + m, + n, + k, + alpha, + A, + rocblas_datatype_f32_r, + (op_A == rocblas_operation_none) ? m : k, + stride_A, + B, + rocblas_datatype_f32_r, + (op_B == rocblas_operation_none) ? k : n, + stride_B, + beta, + C, + rocblas_datatype_f32_r, + m, + stride_C, + C, + rocblas_datatype_f32_r, + m, + stride_C, + batch, + rocblas_datatype_f32_r, + algo, + 0, + 0); +#else + cublasStatus_t status = cublasGemmStridedBatchedEx(handle, + op_A, + op_B, + m, + n, + k, + alpha, + A, +#ifdef __HIP_PLATFORM_AMD__ + HIPBLAS_R_32F, +#else + CUDA_R_32F, +#endif + (op_A == CUBLAS_OP_N) ? m : k, + stride_A, + B, +#ifdef __HIP_PLATFORM_AMD__ + HIPBLAS_R_32F, +#else + CUDA_R_32F, +#endif + (op_B == CUBLAS_OP_N) ? k : n, + stride_B, + beta, + C, +#ifdef __HIP_PLATFORM_AMD__ + HIPBLAS_R_32F, +#else + CUDA_R_32F, +#endif + m, + stride_C, + batch, +#if defined(__HIP_PLATFORM_AMD__) && defined(HIPBLAS_V2) + HIPBLAS_COMPUTE_32F, +#elif defined(__HIP_PLATFORM_AMD__) + HIPBLAS_R_32F, +#else + CUDA_R_32F, +#endif + algo); +#endif + +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) + if (status != rocblas_status_success) { +#else + if (status != CUBLAS_STATUS_SUCCESS) { +#endif + fprintf(stderr, + "!!!! kernel execution error. (batch: %d, m: %d, n: %d, k: %d, error: %d) \n", + batch, + m, + n, + k, + (int)status); + return EXIT_FAILURE; + } + return 0; +} + +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) +int cublas_strided_batched_gemm(rocblas_handle handle, + int m, + int n, + int k, + const float* alpha, + const float* beta, + const __half* A, + const __half* B, + __half* C, + rocblas_operation op_A, + rocblas_operation op_B, + int stride_A, + int stride_B, + int stride_C, + int batch, + rocblas_gemm_algo algo) +#else +int cublas_strided_batched_gemm(cublasHandle_t handle, + int m, + int n, + int k, + const float* alpha, + const float* beta, + const __half* A, + const __half* B, + __half* C, + cublasOperation_t op_A, + cublasOperation_t op_B, + int stride_A, + int stride_B, + int stride_C, + int batch, + cublasGemmAlgo_t algo) +#endif +{ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) + rocblas_status status = + rocblas_gemm_strided_batched_ex(handle, + op_A, + op_B, + m, + n, + k, + alpha, + A, + rocblas_datatype_f16_r, + (op_A == rocblas_operation_none) ? m : k, + stride_A, + B, + rocblas_datatype_f16_r, + (op_B == rocblas_operation_none) ? k : n, + stride_B, + beta, + C, + rocblas_datatype_f16_r, + m, + stride_C, + C, + rocblas_datatype_f16_r, + m, + stride_C, + batch, + rocblas_datatype_f32_r, + algo, + 0, + 0); +#else + cublasStatus_t status = cublasGemmStridedBatchedEx(handle, + op_A, + op_B, + m, + n, + k, + alpha, + A, +#ifdef __HIP_PLATFORM_AMD__ + HIPBLAS_R_16F, +#else + CUDA_R_16F, +#endif + (op_A == CUBLAS_OP_N) ? m : k, + stride_A, + B, +#ifdef __HIP_PLATFORM_AMD__ + HIPBLAS_R_16F, +#else + CUDA_R_16F, +#endif + (op_B == CUBLAS_OP_N) ? k : n, + stride_B, + beta, + C, +#ifdef __HIP_PLATFORM_AMD__ + HIPBLAS_R_16F, +#else + CUDA_R_16F, +#endif + m, + stride_C, + batch, +#if defined(__HIP_PLATFORM_AMD__) && defined(HIPBLAS_V2) + HIPBLAS_COMPUTE_32F, +#elif defined(__HIP_PLATFORM_AMD__) + HIPBLAS_R_32F, +#else + CUDA_R_32F, +#endif + algo); +#endif + +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) + if (status != rocblas_status_success) { +#else + if (status != CUBLAS_STATUS_SUCCESS) { +#endif + fprintf(stderr, + "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n", + m, + n, + k, + (int)status); + return EXIT_FAILURE; + } + + return 0; +} diff --git a/toolbox/DeepSpeed/v0.15.3/patches/csrc/transformer/dropout_kernels.cu b/toolbox/DeepSpeed/v0.15.3/patches/csrc/transformer/dropout_kernels.cu new file mode 100644 index 0000000000000000000000000000000000000000..e75f19e80a1b74df18e68153902a0640a0ee3723 --- /dev/null +++ b/toolbox/DeepSpeed/v0.15.3/patches/csrc/transformer/dropout_kernels.cu @@ -0,0 +1,890 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. */ +/* All Rights Reserved. */ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "custom_cuda_layers.h" + +const int unroll_factor = 4; + +__global__ void dropout_kernel(const int N, + const float ratio, + float* out, + const float* Xdata, + uint8_t* mask, + std::pair seed) +{ + const float scale = 1. / (1. - ratio); + int idx = blockIdx.x * blockDim.x + threadIdx.x; + + curandStatePhilox4_32_10_t state; + curand_init(seed.first, idx, seed.second, &state); + + CUDA_1D_KERNEL_LOOP(j, N / unroll_factor) + { + float4 rand = curand_uniform4(&state); + uint8_t m[unroll_factor]; + + m[0] = (uint8_t)(rand.x > ratio); + m[1] = (uint8_t)(rand.y > ratio); + m[2] = (uint8_t)(rand.z > ratio); + m[3] = (uint8_t)(rand.w > ratio); + + int i = j * unroll_factor; + + mask[i] = (uint8_t)m[0]; + mask[i + 1] = (uint8_t)m[1]; + mask[i + 2] = (uint8_t)m[2]; + mask[i + 3] = (uint8_t)m[3]; + + out[i] = Xdata[i] * scale * m[0]; + out[i + 1] = Xdata[i + 1] * scale * m[1]; + out[i + 2] = Xdata[i + 2] * scale * m[2]; + out[i + 3] = Xdata[i + 3] * scale * m[3]; + } + int high_index = + ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x; + if (N > high_index) { + float4 rand = curand_uniform4(&state); + float* rand_data = &(rand.x); + int k = 0; + for (int i = high_index; i < N; i++) { + uint8_t m = (uint8_t)(rand_data[k++] > ratio); + out[i] = Xdata[i] * scale * m; + mask[i] = m; + } + } +} + +__global__ void dropout_kernel(const int N, + const float ratio, + __half* out, + const __half* Xdata, + uint8_t* mask, + std::pair seed) +{ + const float scale = 1. / (1. - ratio); + + int idx = blockIdx.x * blockDim.x + threadIdx.x; + + curandStatePhilox4_32_10_t state; + curand_init(seed.first, idx, seed.second, &state); + +#ifdef __STOCHASTIC_MODE__ + + const __half2 h_scale = __float2half2_rn(scale); + const float2* x_cast = reinterpret_cast(Xdata); + float2* out_cast = reinterpret_cast(out); + uint32_t* mask_cast = reinterpret_cast(mask); + + uint32_t m_32; + uint8_t* m = reinterpret_cast(&m_32); + + float2 result_f; + __half2* result_h = reinterpret_cast<__half2*>(&result_f); + __half2 mask_h[2]; + float2 mask_f[2]; + + CUDA_1D_KERNEL_LOOP(j, N / unroll_factor) + { + float2 x_f = x_cast[j]; + __half2* x_h = reinterpret_cast<__half2*>(&x_f); + + float4 rand = curand_uniform4(&state); + + m[0] = (uint8_t)(rand.x > ratio); + m[1] = (uint8_t)(rand.y > ratio); + m[2] = (uint8_t)(rand.z > ratio); + m[3] = (uint8_t)(rand.w > ratio); + + float* mask_f_data = &mask_f[0].x; +#pragma unroll + for (int i = 0; i < unroll_factor; i++) mask_f_data[i] = (float)(m[i]); + + mask_h[0] = __float22half2_rn(mask_f[0]); + mask_h[1] = __float22half2_rn(mask_f[1]); + + result_h[0] = x_h[0] * h_scale * mask_h[0]; + result_h[1] = x_h[1] * h_scale * mask_h[1]; + + out_cast[j] = result_f; + + mask_cast[j] = m_32; + } + +#else + + CUDA_1D_KERNEL_LOOP(j, N / unroll_factor) + { + int i = j * unroll_factor; + + const __half2* vals_half = reinterpret_cast(Xdata + i); + float2 vals_half_f[2]; + vals_half_f[0] = __half22float2(vals_half[0]); + vals_half_f[1] = __half22float2(vals_half[1]); + + uint8_t m[unroll_factor]; + float4 rand = curand_uniform4(&state); + m[0] = (uint8_t)(rand.x > ratio); + m[1] = (uint8_t)(rand.y > ratio); + m[2] = (uint8_t)(rand.z > ratio); + m[3] = (uint8_t)(rand.w > ratio); + + out[i] = __float2half(vals_half_f[0].x * scale * m[0]); + out[i + 1] = __float2half(vals_half_f[0].y * scale * m[1]); + out[i + 2] = __float2half(vals_half_f[1].x * scale * m[2]); + out[i + 3] = __float2half(vals_half_f[1].y * scale * m[3]); + + mask[i] = m[0]; + mask[i + 1] = m[1]; + mask[i + 2] = m[2]; + mask[i + 3] = m[3]; + } + +#endif + int high_index = + ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x; + if (N > high_index) { + float4 rand = curand_uniform4(&state); + float* rand_data = &(rand.x); + int k = 0; + for (int i = high_index; i < N; i++) { + uint8_t m = (uint8_t)(rand_data[k++] > ratio); + out[i] = __float2half((float)Xdata[i] * scale * m); + mask[i] = m; + } + } +} + +__global__ void dropout_kernel_bwd(const int N, + const float ratio, + const float* Xdata, + float* out, + uint8_t* mask, + std::pair seed) +{ + const float scale = 1. / (1. - ratio); + CUDA_1D_KERNEL_LOOP(j, N / unroll_factor) + { + int i = j * unroll_factor; + + out[i] = mask[i] ? Xdata[i] * scale : 0.0; + out[i + 1] = mask[i + 1] ? Xdata[i + 1] * scale : 0.0; + out[i + 2] = mask[i + 2] ? Xdata[i + 2] * scale : 0.0; + out[i + 3] = mask[i + 3] ? Xdata[i + 3] * scale : 0.0; + } + int high_index = + ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x; + if (N > high_index) { + for (int i = high_index; i < N; i++) { out[i] = mask[i] ? Xdata[i] * scale : 0.0; } + } +} + +__global__ void dropout_kernel_bwd(const int N, + const float ratio, + const __half* Xdata, + __half* out, + uint8_t* mask, + std::pair seed) +{ + const float scale = 1. / (1. - ratio); + +#ifdef __STOCHASTIC_MODE__ + + const __half2 h_scale = __float2half2_rn(scale); + + const float2* x_cast = reinterpret_cast(Xdata); + float2* out_cast = reinterpret_cast(out); + uint32_t* mask_cast = reinterpret_cast(mask); + + CUDA_1D_KERNEL_LOOP(j, N / unroll_factor) + { + float2 x_f = x_cast[j]; + __half2* x_h = reinterpret_cast<__half2*>(&x_f); + + uint32_t m_32 = mask_cast[j]; + uint8_t* m = (uint8_t*)&m_32; + + __half2 mask_h[2]; + float2 mask_f[2]; + + float* mask_f_data = &mask_f[0].x; +#pragma unroll + for (int i = 0; i < unroll_factor; i++) mask_f_data[i] = (float)(m[i]); + +#pragma unroll + for (int i = 0; i < 2; i++) mask_h[i] = __float22half2_rn(mask_f[i]); + + float2 result_f; + __half2* result_h = reinterpret_cast<__half2*>(&result_f); + + result_h[0] = x_h[0] * h_scale * mask_h[0]; + result_h[1] = x_h[1] * h_scale * mask_h[1]; + + out_cast[j] = result_f; + } + +#else + + const __half h_scale = __float2half(scale); + const __half h_zero = __float2half(0.0); + + CUDA_1D_KERNEL_LOOP(j, N / unroll_factor) + { + int i = j * unroll_factor; + + const __half2* vals_half = reinterpret_cast(Xdata + i); + + uint8_t* m = mask + i; + + float2 vals_half_f[2]; + + vals_half_f[0] = __half22float2(vals_half[0]); + vals_half_f[1] = __half22float2(vals_half[1]); + + out[i] = __float2half(vals_half_f[0].x * scale * m[0]); + out[i + 1] = __float2half(vals_half_f[0].y * scale * m[1]); + out[i + 2] = __float2half(vals_half_f[1].x * scale * m[2]); + out[i + 3] = __float2half(vals_half_f[1].y * scale * m[3]); + } + +#endif + int high_index = + ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x; + if (N > high_index) { + for (int i = high_index; i < N; i++) { + out[i] = __float2half((float)Xdata[i] * scale * mask[i]); + } + } +} + +template +void launch_dropout(T* out, + const T* vals, + uint8_t* mask, + int total_count, + int dim, + float ratio, + cudaStream_t stream, + bool bwd) +{ + assert(unroll_factor == 4); + + dim3 grid_dim = DS_GET_BLOCKS(total_count / unroll_factor); + dim3 block_dim = DS_CUDA_NUM_THREADS; + + if (dim > 512) { + block_dim.x >>= 1; + grid_dim.x <<= 1; + } + uint64_t inc = total_count / grid_dim.x / block_dim.x; + std::pair seed = TrainingContext::Instance().IncrementOffset(inc); + if (bwd) + dropout_kernel_bwd<<>>( + total_count, ratio, vals, out, mask, seed); + else + dropout_kernel<<>>( + total_count, ratio, out, vals, mask, seed); +} + +template void launch_dropout(float* out, + const float* vals, + uint8_t* mask, + int total_count, + int dim, + float ratio, + cudaStream_t stream, + bool); +template void launch_dropout(__half* out, + const __half* vals, + uint8_t* mask, + int total_count, + int dim, + float ratio, + cudaStream_t stream, + bool); + +__global__ void dropout_grad_kernel(const int N, const float scale, float* Xdata, uint8_t* mask) +{ + CUDA_1D_KERNEL_LOOP(i, N) { Xdata[i] *= scale * mask[i]; } +} + +__global__ void dropout_grad_kernel(const int N, const float scale, __half* Xdata, uint8_t* mask) +{ + const __half2 h_scale = __float2half2_rn(scale); + float2* x_cast = reinterpret_cast(Xdata); + uint32_t* mask_cast = reinterpret_cast(mask); + + CUDA_1D_KERNEL_LOOP(j, N / unroll_factor) + { + float2 x_data = x_cast[j]; + uint32_t m_32 = mask_cast[j]; + uint8_t* m = (uint8_t*)&m_32; + + float2 result_f; + __half2* result_h = reinterpret_cast<__half2*>(&result_f); + +#ifdef __STOCHASTIC_MODE__ + + __half2* x_data_h = reinterpret_cast<__half2*>(&x_data); + __half2 mask_h[2]; + float2 mask_f[2]; + + float* mask_f_data = &mask_f[0].x; +#pragma unroll + for (int i = 0; i < unroll_factor; i++) *(mask_f_data++) = (float)(m[i]); + + mask_h[0] = __float22half2_rn(mask_f[0]); + mask_h[1] = __float22half2_rn(mask_f[1]); + + result_h[0] = x_data_h[0] * h_scale * mask_h[0]; + result_h[1] = x_data_h[1] * h_scale * mask_h[1]; + +#else + + __half* x_data_h = reinterpret_cast<__half*>(&x_data); + float2 result[2]; + + result[0].x = (float)x_data_h[0] * scale * m[0]; + result[0].y = (float)x_data_h[1] * scale * m[1]; + result[1].x = (float)x_data_h[2] * scale * m[2]; + result[1].y = (float)x_data_h[3] * scale * m[3]; + + result_h[0] = __float22half2_rn(result[0]); + result_h[1] = __float22half2_rn(result[1]); + +#endif + x_cast[j] = result_f; + } + int high_index = + ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x; + if (N > high_index) { + for (int i = high_index; i < N; i++) { + Xdata[i] = __float2half((float)Xdata[i] * scale * mask[i]); + } + } +} + +template +void launch_dropout_grad(T* vals, uint8_t* mask, int total_count, float ratio, cudaStream_t stream) +{ + assert(unroll_factor == 4); + + const float scale = 1. / (1. - ratio); + dropout_grad_kernel<<>>(total_count, scale, vals, mask); +} + +template void launch_dropout_grad(float* vals, + uint8_t* mask, + int total_count, + float ratio, + cudaStream_t stream); +template void launch_dropout_grad(__half* vals, + uint8_t* mask, + int total_count, + float ratio, + cudaStream_t stream); + +__global__ void dropout_grad_kernel(const int N, + const float scale, + const float* Xdata, + float* out, + uint8_t* mask) +{ + CUDA_1D_KERNEL_LOOP(i, N) { out[i] = Xdata[i] * scale * mask[i]; } +} + +__global__ void dropout_grad_kernel(const int N, + const float scale, + const __half* Xdata, + __half* out, + uint8_t* mask) +{ + const float2* x_cast = reinterpret_cast(Xdata); + float2* out_cast = reinterpret_cast(out); + const uint32_t* mask_cast = reinterpret_cast(mask); + + float2 result_f; + __half2* result_h = reinterpret_cast<__half2*>(&result_f); + + CUDA_1D_KERNEL_LOOP(j, N / unroll_factor) + { + float2 x_data = x_cast[j]; + uint32_t m_32 = mask_cast[j]; + uint8_t* m = (uint8_t*)&m_32; + + __half* x_data_h = reinterpret_cast<__half*>(&x_data); + float2 result[2]; + + result[0].x = (float)x_data_h[0] * scale * m[0]; + result[0].y = (float)x_data_h[1] * scale * m[1]; + result[1].x = (float)x_data_h[2] * scale * m[2]; + result[1].y = (float)x_data_h[3] * scale * m[3]; + + result_h[0] = __float22half2_rn(result[0]); + result_h[1] = __float22half2_rn(result[1]); + + out_cast[j] = result_f; + } + int high_index = + ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x; + if (N > high_index) { + for (int i = high_index; i < N; i++) { + out[i] = __float2half((float)Xdata[i] * scale * mask[i]); + } + } +} + +template +void launch_dropout_grad(T* vals_out, + const T* vals, + uint8_t* mask, + int total_count, + float ratio, + cudaStream_t stream) +{ + assert(unroll_factor == 4); + + const float scale = 1. / (1. - ratio); + dropout_grad_kernel<<>>(total_count, scale, vals, vals_out, mask); +} +template void launch_dropout_grad(float*, + const float* vals, + uint8_t* mask, + int total_count, + float ratio, + cudaStream_t stream); +template void launch_dropout_grad(__half*, + const __half* vals, + uint8_t* mask, + int total_count, + float ratio, + cudaStream_t stream); + +__global__ void dropout_kernel(const int N, + const int dim, + const float ratio, + const float* bias, + float* Xdata, + uint8_t* mask, + std::pair seed) +{ + const float scale = 1. / (1. - ratio); + int idx = blockIdx.x * blockDim.x + threadIdx.x; + int tid = threadIdx.x % (dim / unroll_factor); + + curandStatePhilox4_32_10_t state; + curand_init(seed.first, idx, seed.second, &state); + + float4* Xdata_cast = reinterpret_cast(Xdata); + uint32_t* mask_32 = reinterpret_cast(mask); + const float4* bias_cast = reinterpret_cast(bias); + + CUDA_1D_KERNEL_LOOP(j, N) + { + float4 rand = curand_uniform4(&state); + uint32_t m_32; + uint8_t* m = (uint8_t*)&m_32; + + m[0] = (uint8_t)(rand.x > ratio); + m[1] = (uint8_t)(rand.y > ratio); + m[2] = (uint8_t)(rand.z > ratio); + m[3] = (uint8_t)(rand.w > ratio); + + float4 x_data = Xdata_cast[j]; + float4 b_data = bias_cast[j % (dim / unroll_factor)]; + + x_data.x += b_data.x; + x_data.y += b_data.y; + x_data.z += b_data.z; + x_data.w += b_data.w; + + x_data.x = x_data.x * scale * m[0]; + x_data.y = x_data.y * scale * m[1]; + x_data.z = x_data.z * scale * m[2]; + x_data.w = x_data.w * scale * m[3]; + + mask_32[j] = m_32; + Xdata_cast[j] = x_data; + } + int high_index = + ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x; + if (N > high_index) { + float4 rand = curand_uniform4(&state); + float* rand_data = &(rand.x); + int k = 0; + for (int i = high_index; i < N; i++) { + float x_data = Xdata[i] + bias[i % dim]; + uint8_t m = (uint8_t)(rand_data[k++] > ratio); + Xdata[i] = x_data * scale * m; + mask[i] = m; + } + } +} + +__global__ void dropout_kernel(const int N, + const int dim, + const float ratio, + const __half* bias, + __half* Xdata, + uint8_t* mask, + std::pair seed) +{ + const float scale = 1. / (1. - ratio); + int idx = blockIdx.x * blockDim.x + threadIdx.x; + int tid = threadIdx.x % (dim / unroll_factor); + + curandStatePhilox4_32_10_t state; + curand_init(seed.first, idx, seed.second, &state); + + float2* Xdata_cast = reinterpret_cast(Xdata); + uint32_t* mask_32 = reinterpret_cast(mask); + const float2* bias_cast = reinterpret_cast(bias); + + CUDA_1D_KERNEL_LOOP(j, N) + { + float4 rand = curand_uniform4(&state); + + float2 data_f; + __half2* data_h = reinterpret_cast<__half2*>(&data_f); + + float2 bias_f; + __half2* bias_h = reinterpret_cast<__half2*>(&bias_f); + + data_f = Xdata_cast[j]; + bias_f = bias_cast[j % (dim / unroll_factor)]; + + float2 data_h_0 = __half22float2(data_h[0]); + float2 data_h_1 = __half22float2(data_h[1]); + + float2 bias_h_0 = __half22float2(bias_h[0]); + float2 bias_h_1 = __half22float2(bias_h[1]); + + data_h_0.x += bias_h_0.x; + data_h_0.y += bias_h_0.y; + data_h_1.x += bias_h_1.x; + data_h_1.y += bias_h_1.y; + + uint32_t m_32; + uint8_t* m = (uint8_t*)&m_32; + + m[0] = (uint8_t)(rand.x > ratio); + m[1] = (uint8_t)(rand.y > ratio); + m[2] = (uint8_t)(rand.z > ratio); + m[3] = (uint8_t)(rand.w > ratio); + + data_h_0.x = __float2half(data_h_0.x * scale * m[0]); + data_h_0.y = __float2half(data_h_0.y * scale * m[1]); + data_h_1.x = __float2half(data_h_1.x * scale * m[2]); + data_h_1.y = __float2half(data_h_1.y * scale * m[3]); + + float2 result_f; + __half2* result_h = reinterpret_cast<__half2*>(&result_f); + + result_h[0] = __float22half2_rn(data_h_0); + result_h[1] = __float22half2_rn(data_h_1); + + Xdata_cast[j] = result_f; + mask_32[j] = m_32; + } + int high_index = + ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x; + if (N > high_index) { + float4 rand = curand_uniform4(&state); + float* rand_data = &(rand.x); + int k = 0; + for (int i = high_index; i < N; i++) { + float x_data = (float)Xdata[i] + (float)bias[i % dim]; + uint8_t m = (uint8_t)(rand_data[k++] > ratio); + Xdata[i] = __float2half(x_data * scale * m); + mask[i] = m; + } + } +} + +template +void launch_dropout(T* out, + const T* bias, + uint8_t* mask, + int batch, + int dim, + float ratio, + cudaStream_t stream) +{ + assert(unroll_factor == 4); + + int total_count = batch * dim / unroll_factor; + + dim3 grid_dim = DS_GET_BLOCKS(total_count); + dim3 block_dim = DS_CUDA_NUM_THREADS; + + uint64_t inc = (batch * dim) / grid_dim.x / block_dim.x; + std::pair seed = TrainingContext::Instance().IncrementOffset(inc); + + dropout_kernel<<>>( + total_count, dim, ratio, bias, out, mask, seed); +} + +template void launch_dropout(float*, + const float* bias, + uint8_t* mask, + int batch, + int dim, + float ratio, + cudaStream_t stream); +template void launch_dropout(__half*, + const __half* bias, + uint8_t* mask, + int batch, + int dim, + float ratio, + cudaStream_t stream); + +__global__ void dropout_kernel(const int N, + const int dim, + const float ratio, + const float* input, + const float* residual, + const float* bias, + float* out, + uint8_t* mask, + std::pair seed) +{ + const float scale = 1. / (1. - ratio); + int idx = blockIdx.x * blockDim.x + threadIdx.x; + int tid = threadIdx.x % (dim / unroll_factor); + + curandStatePhilox4_32_10_t state; + curand_init(seed.first, idx, seed.second, &state); + + float4* out_cast = reinterpret_cast(out); + uint32_t* mask_32 = reinterpret_cast(mask); + + const float4* bias_cast = reinterpret_cast(bias); + const float4* residual_cast = reinterpret_cast(residual); + const float4* input_cast = reinterpret_cast(input); + + CUDA_1D_KERNEL_LOOP(j, N) + { + float4 rand = curand_uniform4(&state); + + uint32_t m_32; + uint8_t* m = (uint8_t*)&m_32; + + m[0] = (uint8_t)(rand.x > ratio); + m[1] = (uint8_t)(rand.y > ratio); + m[2] = (uint8_t)(rand.z > ratio); + m[3] = (uint8_t)(rand.w > ratio); + + float4 out_data; + float4 b_data = bias_cast[j % (dim / unroll_factor)]; + float4 res_data = residual_cast[j]; + float4 inp_data = input_cast[j]; + + out_data.x = (b_data.x + inp_data.x); + out_data.y = (b_data.y + inp_data.y); + out_data.z = (b_data.z + inp_data.z); + out_data.w = (b_data.w + inp_data.w); + + out_data.x = out_data.x * scale * m[0]; + out_data.y = out_data.y * scale * m[1]; + out_data.z = out_data.z * scale * m[2]; + out_data.w = out_data.w * scale * m[3]; + + out_data.x += res_data.x; + out_data.y += res_data.y; + out_data.z += res_data.z; + out_data.w += res_data.w; + + mask_32[j] = m_32; + out_cast[j] = out_data; + } + int high_index = + ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x; + if (N > high_index) { + float4 rand = curand_uniform4(&state); + float* rand_data = &(rand.x); + int k = 0; + for (int i = high_index; i < N; i++) { + float x_data = input[i] + bias[i % dim]; + uint8_t m = (uint8_t)(rand_data[k++] > ratio); + x_data = x_data * scale * m; + x_data += residual[i]; + + out[i] = x_data; + mask[i] = m; + } + } +} + +__global__ void dropout_kernel(const int N, + const int dim, + const float ratio, + const __half* input, + const __half* residual, + const __half* bias, + __half* out, + uint8_t* mask, + std::pair seed) +{ + const float scale = 1. / (1. - ratio); + int idx = blockIdx.x * blockDim.x + threadIdx.x; + int tid = threadIdx.x % (dim / unroll_factor); + + curandStatePhilox4_32_10_t state; + curand_init(seed.first, idx, seed.second, &state); + + float2* out_cast = reinterpret_cast(out); + uint32_t* mask_32 = reinterpret_cast(mask); + + const float2* bias_cast = reinterpret_cast(bias); + const float2* residual_cast = reinterpret_cast(residual); + const float2* input_cast = reinterpret_cast(input); + + CUDA_1D_KERNEL_LOOP(j, N) + { + float4 rand = curand_uniform4(&state); + + float2 data_f; + __half2* data_h = reinterpret_cast<__half2*>(&data_f); + + float2 bias_f; + __half2* bias_h = reinterpret_cast<__half2*>(&bias_f); + + float2 residual_f; + __half2* residual_h = reinterpret_cast<__half2*>(&residual_f); + + float2 input_f; + __half2* input_h = reinterpret_cast<__half2*>(&input_f); + + bias_f = bias_cast[j % (dim / unroll_factor)]; + residual_f = residual_cast[j]; + input_f = input_cast[j]; + + float2 data_h_0 = __half22float2(data_h[0]); + float2 data_h_1 = __half22float2(data_h[1]); + + float2 bias_h_0 = __half22float2(bias_h[0]); + float2 bias_h_1 = __half22float2(bias_h[1]); + + float2 residual_h_0 = __half22float2(residual_h[0]); + float2 residual_h_1 = __half22float2(residual_h[1]); + + float2 input_h_0 = __half22float2(input_h[0]); + float2 input_h_1 = __half22float2(input_h[1]); + + data_h_0.x = (bias_h_0.x + input_h_0.x); + data_h_0.y = (bias_h_0.y + input_h_0.y); + data_h_1.x = (bias_h_1.x + input_h_1.x); + data_h_1.y = (bias_h_1.y + input_h_1.y); + + uint32_t m_32; + uint8_t* m = (uint8_t*)&m_32; + + m[0] = (uint8_t)(rand.x > ratio); + m[1] = (uint8_t)(rand.y > ratio); + m[2] = (uint8_t)(rand.z > ratio); + m[3] = (uint8_t)(rand.w > ratio); + + data_h_0.x = __float2half(data_h_0.x * scale * m[0]); + data_h_0.y = __float2half(data_h_0.y * scale * m[1]); + data_h_1.x = __float2half(data_h_1.x * scale * m[2]); + data_h_1.y = __float2half(data_h_1.y * scale * m[3]); + + data_h_0.x += residual_h_0.x; + data_h_0.y += residual_h_0.y; + data_h_1.x += residual_h_1.x; + data_h_1.y += residual_h_1.y; + + float2 result_f; + __half2* result_h = reinterpret_cast<__half2*>(&result_f); + + result_h[0] = __float22half2_rn(data_h_0); + result_h[1] = __float22half2_rn(data_h_1); + + out_cast[j] = result_f; + mask_32[j] = m_32; + } + int high_index = + ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x; + if (N > high_index) { + float4 rand = curand_uniform4(&state); + float* rand_data = &(rand.x); + int k = 0; + for (int i = high_index; i < N; i++) { + float x_data = (float)input[i] + (float)bias[i % dim]; + uint8_t m = (uint8_t)(rand_data[k++] > ratio); + x_data = x_data * scale * m; + x_data += (float)residual[i]; + + out[i] = __float2half(x_data); + mask[i] = m; + } + } +} + +template +void launch_dropout(T* out, + const T* input, + const T* residual, + const T* bias, + uint8_t* mask, + int batch, + int dim, + float ratio, + cudaStream_t stream) +{ + assert(unroll_factor == 4); + + int total_count = batch * dim / unroll_factor; + dim3 grid_dim = DS_GET_BLOCKS(total_count); + dim3 block_dim = DS_CUDA_NUM_THREADS; + + uint64_t inc = (batch * dim) / grid_dim.x / block_dim.x; + std::pair seed = TrainingContext::Instance().IncrementOffset(inc); + + dropout_kernel<<>>( + total_count, dim, ratio, input, residual, bias, out, mask, seed); +} + +template void launch_dropout(float*, + const float*, + const float* residual, + const float* bias, + uint8_t* mask, + int batch, + int dim, + float ratio, + cudaStream_t stream); +template void launch_dropout(__half*, + const __half*, + const __half* residual, + const __half* bias, + uint8_t* mask, + int batch, + int dim, + float ratio, + cudaStream_t stream); diff --git a/toolbox/DeepSpeed/v0.15.3/patches/csrc/transformer/ds_transformer_cuda.cpp b/toolbox/DeepSpeed/v0.15.3/patches/csrc/transformer/ds_transformer_cuda.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c99d573204d0daba9f0557bbc22a3a4c7d19c134 --- /dev/null +++ b/toolbox/DeepSpeed/v0.15.3/patches/csrc/transformer/ds_transformer_cuda.cpp @@ -0,0 +1,1072 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. */ +/* All Rights Reserved. */ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include + +#include +#include +#include +#include +#include +#include +#include "Timer.h" +#include "context.h" +#include "cublas_wrappers.h" +#include "custom_cuda_layers.h" +#include "ds_transformer_cuda.h" + +static std::unordered_map> s_transformer_layers; + +const int init_seq_length = 128; + +// C++ interface + +template +unsigned get_workspace_size(unsigned maxBatchSize, + unsigned seq_len, + unsigned hidden_size, + unsigned intermediate_size, + unsigned heads, + bool training, + bool gelu_checkpoint) +{ + unsigned workSpacesize = 4 * (size_t(maxBatchSize) * seq_len * hidden_size); + if (training) { + workSpacesize += 2 * (size_t(maxBatchSize) * seq_len * hidden_size); + workSpacesize += ((std::max)((size_t(maxBatchSize) * seq_len * intermediate_size), + 2 * (size_t(maxBatchSize) * heads * seq_len * seq_len))); + if (gelu_checkpoint) + workSpacesize += 2 * (size_t(maxBatchSize) * seq_len * intermediate_size); + } + return workSpacesize; // * sizeof(T); +} + +// NOTE: AT_ASSERT has become AT_CHECK on master after 0.4. +#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) \ + CHECK_CUDA(x); \ + CHECK_CONTIGUOUS(x) + +template +BertTransformerLayer::BertTransformerLayer(unsigned layer_id, + unsigned batch_size, + unsigned hidden_size, + unsigned num_heads, + unsigned intermediate_size, + unsigned seq_length, + float attn_prob_dropout_ratio, + float hidden_output_dropout_ratio, + float layer_norm_eps, + bool pre_or_postLayerNorm, + const std::vector>& gemm_algos, + bool attn_dropout_checkpoint, + bool normalize_invertible, + bool gelu_checkpoint, + bool stochastic_mode) + : _layer_id(layer_id), + _batch_size(batch_size), + _hidden_size(hidden_size), + _heads(num_heads), + _intermediate_size(intermediate_size), + _seq_length(seq_length), + _training(true), + _pre_or_postLayerNorm(pre_or_postLayerNorm), + _attn_dropout_checkpoint(attn_dropout_checkpoint), + _normalize_invertible(normalize_invertible), + _gelu_checkpoint(gelu_checkpoint), + _stochastic_mode(stochastic_mode), + _stream(TrainingContext::Instance().GetCurrentStream()), + _cublasHandle(TrainingContext::Instance().GetCublasHandle()), + _qkv_linear(typename FeedForward::Config(batch_size * seq_length, + 3 * hidden_size, + hidden_size, + gemm_algos[0])), + _attn_out_linear(typename FeedForward::Config(batch_size * seq_length, + hidden_size, + hidden_size, + gemm_algos[0])), + _attn_layer_norm(typename Normalize_Layer::Config(batch_size, + seq_length, + hidden_size, + layer_norm_eps, + true, + !normalize_invertible)), + _layer_norm(typename Normalize_Layer::Config(batch_size, + seq_length, + hidden_size, + layer_norm_eps, + true, + !normalize_invertible)), + _ff1(typename FeedForward::Config(batch_size * seq_length, + _intermediate_size, + hidden_size, + gemm_algos[1])), + _ff2(typename FeedForward::Config(batch_size * seq_length, + hidden_size, + _intermediate_size, + gemm_algos[2])), + _softmax(typename Softmax::Config(batch_size, num_heads, seq_length)), + _gelu(typename Gelu::Config(_intermediate_size)), + _attn_prob_dropout(typename Dropout::Config(attn_prob_dropout_ratio, _seq_length)), + _attn_output_dropout(typename Dropout::Config(hidden_output_dropout_ratio, _hidden_size)), + _layer_output_dropout(typename Dropout::Config(hidden_output_dropout_ratio, _hidden_size)), + _attn_scores(typename StridedBatchGemm::Config(_batch_size * _heads, + _seq_length, + _seq_length, + _hidden_size / _heads, + (T(1.0) / T(sqrt(_hidden_size / _heads))), + T(0.0), + CUBLAS_OP_T, + CUBLAS_OP_N, + gemm_algos[3])), + _attn_context(typename StridedBatchGemm::Config(_batch_size * _heads, + _hidden_size / _heads, + _seq_length, + _seq_length, + T(1.0), + T(0.0), + CUBLAS_OP_N, + CUBLAS_OP_N, + gemm_algos[4])) +{ + assert(_hidden_size % _heads == 0); + + Initialize(); +} + +template +BertTransformerLayer::~BertTransformerLayer() +{ +} + +template +void BertTransformerLayer::Initialize() +{ +#ifndef __HIP_PLATFORM_AMD__ + if (std::is_same::value) cublasSetMathMode(_cublasHandle, CUBLAS_TENSOR_OP_MATH); +#endif +} + +template +void BertTransformerLayer::Forward(unsigned bsz, + const T* input_ptr, + const T* input_mask_ptr, + const T* attn_qkvw_ptr, + const T* attn_qkvb_ptr, + const T* attn_ow_ptr, + const T* attn_ob_ptr, + const T* attn_nw_ptr, + const T* attn_nb_ptr, + const T* inter_w_ptr, + const T* inter_b_ptr, + const T* output_w_ptr, + const T* output_b_ptr, + const T* norm_w_ptr, + const T* norm_b_ptr, + T* out_ptr, + T* inp_norm_ptr, + T* q_tf_ptr, + T* k_tf_ptr, + T* v_tf_ptr, + T* soft_out_ptr, + T* ctx_bufB_ptr, + T* attn_o_inp_ptr, + T* add_res_ptr, + T* ff1_inp_ptr, + T* gelu_inp_ptr, + T* ff2_inp_ptr) +{ + cublasSetStream(_cublasHandle, _stream); + + if (!_stochastic_mode) cudaStreamSynchronize(_stream); + + T* workspace = static_cast(TrainingContext::Instance().GetWorkSpace()); + size_t small_buf_size = bsz * _seq_length * _hidden_size; + T* buf_0 = workspace; + T* buf_1 = buf_0 + small_buf_size; + T* buf_2 = buf_1; + + if (_normalize_invertible) { + add_res_ptr = buf_1 + 3 * small_buf_size; + buf_2 = add_res_ptr; + } + if (_gelu_checkpoint) buf_2 += small_buf_size; + if (_attn_dropout_checkpoint) + ctx_bufB_ptr = + (_gelu_checkpoint ? (buf_2 + (_intermediate_size / _hidden_size) * small_buf_size) + : (buf_1 + 4 * small_buf_size)); + + int bsz_seq = bsz * _seq_length; + + if (_pre_or_postLayerNorm) { + if (_layer_norm.UseMean()) + _layer_norm.ForwardCheckpoint( + bsz_seq, inp_norm_ptr, input_ptr, norm_w_ptr, norm_b_ptr, _stream, true); + + else + _layer_norm.Forward( + bsz_seq, inp_norm_ptr, input_ptr, norm_w_ptr, norm_b_ptr, _stream, true); + } + + if (_pre_or_postLayerNorm) + _qkv_linear.Forward(bsz_seq, inp_norm_ptr, attn_qkvw_ptr, buf_0, _cublasHandle); + else + _qkv_linear.Forward(bsz_seq, input_ptr, attn_qkvw_ptr, buf_0, _cublasHandle); + + launch_bias_add_transform_0213( + q_tf_ptr, buf_0, attn_qkvb_ptr, bsz, _seq_length, _hidden_size, _heads, _stream, 3); + + int bsz_heads = bsz * _heads; + + // attention scores + _attn_scores.Forward(bsz_heads, soft_out_ptr, k_tf_ptr, q_tf_ptr, _cublasHandle); + + // Softmax + Mask + _softmax.Forward(bsz, soft_out_ptr, input_mask_ptr, _stream); + + // attn prob dropout. + _attn_prob_dropout.Forward(bsz_heads * _seq_length, ctx_bufB_ptr, soft_out_ptr, _stream); + + // attention context + _attn_context.Forward(bsz_heads, buf_1, v_tf_ptr, ctx_bufB_ptr, _cublasHandle); + + launch_transform4d_0213( + attn_o_inp_ptr, buf_1, bsz, _heads, _seq_length, _hidden_size, _stream, 1); + + if (_pre_or_postLayerNorm) + _attn_out_linear.Forward(bsz_seq, attn_o_inp_ptr, attn_ow_ptr, buf_1, _cublasHandle); + else + _attn_out_linear.Forward(bsz_seq, attn_o_inp_ptr, attn_ow_ptr, ff1_inp_ptr, _cublasHandle); + + // attn output dropout. + if (_pre_or_postLayerNorm) + _attn_output_dropout.ForwardWithBias( + bsz_seq, add_res_ptr, buf_1, input_ptr, attn_ob_ptr, _stream); + else + _attn_output_dropout.ForwardWithBias( + bsz_seq, add_res_ptr, ff1_inp_ptr, input_ptr, attn_ob_ptr, _stream); + + if (_pre_or_postLayerNorm) { + if (_attn_layer_norm.UseMean()) + _attn_layer_norm.ForwardCheckpoint( + bsz_seq, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true); + else + _attn_layer_norm.Forward( + bsz_seq, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true); + } else { + if (_attn_layer_norm.UseMean()) + _attn_layer_norm.ForwardCheckpoint( + bsz_seq, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true); + else + _attn_layer_norm.Forward( + bsz_seq, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true); + } + + _ff1.Forward(bsz_seq, + ff1_inp_ptr, + inter_w_ptr, + (_gelu_checkpoint ? ff2_inp_ptr : gelu_inp_ptr), + _cublasHandle); + + _gelu.ForwardWithBiasAdd(bsz_seq, + (_gelu_checkpoint ? ff2_inp_ptr : gelu_inp_ptr), + inter_b_ptr, + (_gelu_checkpoint ? buf_2 : ff2_inp_ptr), + _stream); + + _ff2.Forward( + bsz_seq, (_gelu_checkpoint ? buf_2 : ff2_inp_ptr), output_w_ptr, out_ptr, _cublasHandle); + + // layer output dropout. + if (_pre_or_postLayerNorm) + _layer_output_dropout.ForwardWithBias( + bsz_seq, out_ptr, out_ptr, add_res_ptr, output_b_ptr, _stream); + else + _layer_output_dropout.ForwardWithBias( + bsz_seq, inp_norm_ptr, out_ptr, ff1_inp_ptr, output_b_ptr, _stream); + + if (!_pre_or_postLayerNorm) { + if (_layer_norm.UseMean()) + _layer_norm.ForwardCheckpoint( + bsz_seq, out_ptr, inp_norm_ptr, norm_w_ptr, norm_b_ptr, _stream, true); + else + _layer_norm.Forward( + bsz_seq, out_ptr, inp_norm_ptr, norm_w_ptr, norm_b_ptr, _stream, true); + } +} + +template +void BertTransformerLayer::Backward(unsigned bsz, + const T* grad_output_ptr, + const T* input_ptr, + const T* output_ptr, + const T* inp_norm_ptr, + const T* q_tf_ptr, + const T* k_tf_ptr, + const T* v_tf_ptr, + const T* soft_out_ptr, + const T* ctx_bufB_ptr, + const T* attn_o_inp_ptr, + const T* add_res_ptr, + const T* ff1_inp_ptr, + const T* gelu_inp_ptr, + const T* ff2_inp_ptr, + const T* input_mask_ptr, + const T* attn_qkvw_ptr, + const T* attn_ow_ptr, + const T* attn_nw_ptr, + const T* attn_nb_ptr, + const T* inter_w_ptr, + const T* inter_b_ptr, + const T* output_w_ptr, + const T* norm_w_ptr, + const T* norm_b_ptr, + + T* grad_input_ptr, + T* grad_attn_qkvw_ptr, + T* grad_attn_qkvb_ptr, + T* grad_attn_ow_ptr, + T* grad_attn_ob_ptr, + T* grad_attn_nw_ptr, + T* grad_attn_nb_ptr, + T* grad_inter_w_ptr, + T* grad_inter_b_ptr, + T* grad_output_w_ptr, + T* grad_output_b_ptr, + T* grad_norm_w_ptr, + T* grad_norm_b_ptr) +{ + cublasSetStream(_cublasHandle, _stream); + + if (!_stochastic_mode) cudaStreamSynchronize(_stream); + + T* workspace = static_cast(TrainingContext::Instance().GetWorkSpace()); + size_t small_buf_size = bsz * _seq_length * _hidden_size; + T* buf_0 = workspace; + T* buf_1 = buf_0 + small_buf_size; + T* buf_2 = buf_1 + small_buf_size; + T* buf_3 = buf_2 + small_buf_size; + + T* ff2_buf = (_gelu_checkpoint ? buf_3 + (bsz * _seq_length * _intermediate_size) + : buf_3 + small_buf_size); + T* ctx_bufB_ptr_recomp = ff2_buf + (_seq_length * _seq_length * bsz * _heads); + + cudaStream_t streams[2] = {_stream, _stream}; + + int bsz_seq = bsz * _seq_length; + int bsz_heads = bsz * _heads; + + if (!_pre_or_postLayerNorm) { + if (_layer_norm.UseMean()) + _layer_norm.Backward(bsz_seq, + grad_output_ptr, + norm_w_ptr, + grad_norm_w_ptr, + grad_norm_b_ptr, + streams, + buf_1, + inp_norm_ptr); + + else + _layer_norm.Backward(bsz_seq, + grad_output_ptr, + norm_w_ptr, + norm_b_ptr, + grad_norm_w_ptr, + grad_norm_b_ptr, + streams, + buf_1, + output_ptr); + } + + if (_pre_or_postLayerNorm) + _layer_output_dropout.Backward(bsz_seq, buf_0, grad_output_ptr, _stream); + else + _layer_output_dropout.Backward(bsz_seq, buf_0, buf_1, _stream); + + const T* layer_dropout_buf = _layer_output_dropout.HasDropout() + ? buf_0 + : (_pre_or_postLayerNorm ? grad_output_ptr : buf_1); + + if (_gelu_checkpoint) + _gelu.ForwardWithBiasAdd(bsz_seq, ff2_inp_ptr, inter_b_ptr, buf_2, _stream); + _ff2.Backward(bsz_seq, + layer_dropout_buf, + (_gelu_checkpoint ? buf_2 : ff2_inp_ptr), + output_w_ptr, + grad_output_w_ptr, + grad_output_b_ptr, + _cublasHandle, + _stream, + ff2_buf); + + _gelu.Backward( + bsz_seq, ff2_buf, (_gelu_checkpoint ? ff2_inp_ptr : gelu_inp_ptr), inter_b_ptr, _stream); + + _ff1.Backward(bsz_seq, + ff2_buf, + ff1_inp_ptr, + inter_w_ptr, + grad_inter_w_ptr, + grad_inter_b_ptr, + _cublasHandle, + _stream, + buf_3); + + if (!_pre_or_postLayerNorm) + launch_fused_add2(buf_2, buf_3, buf_1, bsz, _seq_length, _hidden_size, _stream); + + if (_pre_or_postLayerNorm) { + if (_attn_layer_norm.UseMean()) + _attn_layer_norm.BackwardFusedAdd(bsz_seq, + buf_3, + grad_output_ptr, + attn_nw_ptr, + grad_attn_nw_ptr, + grad_attn_nb_ptr, + streams, + buf_0, + add_res_ptr); + + else + _attn_layer_norm.BackwardFusedAdd(bsz_seq, + buf_3, + grad_output_ptr, + attn_nw_ptr, + attn_nb_ptr, + grad_attn_nw_ptr, + grad_attn_nb_ptr, + streams, + buf_0, + ff1_inp_ptr); + } else { + if (_attn_layer_norm.UseMean()) + _attn_layer_norm.Backward(bsz_seq, + buf_2, + attn_nw_ptr, + grad_attn_nw_ptr, + grad_attn_nb_ptr, + streams, + buf_0, + add_res_ptr); + + else + _attn_layer_norm.Backward(bsz_seq, + buf_2, + attn_nw_ptr, + attn_nb_ptr, + grad_attn_nw_ptr, + grad_attn_nb_ptr, + streams, + buf_0, + ff1_inp_ptr); + } + + _attn_output_dropout.Backward(bsz_seq, buf_2, buf_0, _stream); + + T* attn_output_dropout_buf = _attn_output_dropout.HasDropout() ? buf_2 : buf_0; + + _attn_out_linear.Backward(bsz_seq, + attn_output_dropout_buf, + attn_o_inp_ptr, + attn_ow_ptr, + grad_attn_ow_ptr, + grad_attn_ob_ptr, + _cublasHandle, + _stream, + buf_1); + + launch_transform_0213(buf_2, buf_1, bsz, _seq_length, _hidden_size, _heads, _stream); + + if (_attn_prob_dropout.HasDropout()) { + if (_attn_dropout_checkpoint) + _attn_prob_dropout.Forward( + bsz_heads * _seq_length, ctx_bufB_ptr_recomp, soft_out_ptr, _stream, true); + + _attn_context.Backward(bsz_heads, + buf_2, + v_tf_ptr, + (_attn_dropout_checkpoint ? ctx_bufB_ptr_recomp : ctx_bufB_ptr), + _cublasHandle, + buf_3, + ff2_buf); + } else + _attn_context.Backward( + bsz_heads, buf_2, v_tf_ptr, soft_out_ptr, _cublasHandle, buf_3, ff2_buf); + + _attn_prob_dropout.Backward(bsz_heads * _seq_length, ff2_buf, _stream); + + _softmax.Backward(bsz, ff2_buf, soft_out_ptr, _stream); + + _attn_scores.Backward(bsz_heads, ff2_buf, k_tf_ptr, q_tf_ptr, _cublasHandle, buf_2, buf_1); + + launch_transform4d_0213(ff2_buf, buf_1, bsz, _heads, _seq_length, _hidden_size, _stream, 3); + + if (_pre_or_postLayerNorm) + _qkv_linear.Backward(bsz_seq, + ff2_buf, + inp_norm_ptr, + attn_qkvw_ptr, + grad_attn_qkvw_ptr, + grad_attn_qkvb_ptr, + _cublasHandle, + _stream, + buf_2); + else + _qkv_linear.Backward(bsz_seq, + ff2_buf, + input_ptr, + attn_qkvw_ptr, + grad_attn_qkvw_ptr, + grad_attn_qkvb_ptr, + _cublasHandle, + _stream, + buf_2); + + if (_pre_or_postLayerNorm) { + if (_layer_norm.UseMean()) + _layer_norm.BackwardFusedAdd(bsz_seq, + buf_2, + buf_0, + norm_w_ptr, + grad_norm_w_ptr, + grad_norm_b_ptr, + streams, + grad_input_ptr, + input_ptr); + + else + _layer_norm.BackwardFusedAdd(bsz_seq, + buf_2, + buf_0, + norm_w_ptr, + norm_b_ptr, + grad_norm_w_ptr, + grad_norm_b_ptr, + streams, + grad_input_ptr, + inp_norm_ptr); + } else + launch_fused_add2(grad_input_ptr, buf_2, buf_0, bsz, _seq_length, _hidden_size, _stream); +} + +template +void BertTransformerLayer::SetTrainingMode(bool training) +{ + // Dropout will be skipped when not in training model. + _attn_prob_dropout.SetTrainingMode(training); + _attn_output_dropout.SetTrainingMode(training); + _layer_output_dropout.SetTrainingMode(training); +} + +template +void BertTransformerLayer::SetIntermediateBuffers(uint8_t* attn_prob_dropout_mask_ptr, + uint8_t* attn_output_dropout_mask_ptr, + uint8_t* layer_output_dropout_mask_ptr, + T* attn_layer_norm_var, + T* attn_layer_norm_mean, + T* layer_norm_var, + T* layer_norm_mean) +{ + _attn_prob_dropout.SetMask(attn_prob_dropout_mask_ptr); + _attn_output_dropout.SetMask(attn_output_dropout_mask_ptr); + _layer_output_dropout.SetMask(layer_output_dropout_mask_ptr); + + _attn_layer_norm.SetVar(attn_layer_norm_var); + _attn_layer_norm.SetMean(attn_layer_norm_mean); + _layer_norm.SetVar(layer_norm_var); + _layer_norm.SetMean(layer_norm_mean); +} + +template +void BertTransformerLayer::SetSeqLength(unsigned seq_len) +{ + _seq_length = seq_len; + + _softmax.SetSeqLength(_seq_length); + _attn_prob_dropout.SetDimension(_seq_length); + _attn_scores.SetConfig(_seq_length, _seq_length, _hidden_size / _heads); + _attn_context.SetConfig(_hidden_size / _heads, _seq_length, _seq_length); +} + +template +int create_transformer_layer(unsigned layer_id, + unsigned batch_size, + unsigned hidden_dim, + unsigned num_heads, + unsigned intermediate_size, + float attn_dropout_ratio, + float hidden_dropout_ratio, + float layer_norm_eps, + int seed, + bool pre_or_postLayerNorm, + bool test_gemm, + bool attn_dropout_checkpoint, + bool normalize_invertible, + bool gelu_checkpoint, + bool stochastic_mode) +{ + TrainingContext::Instance().SetSeed(seed); + TrainingContext::Instance().TestGemmFP16( + test_gemm, batch_size, init_seq_length, num_heads, hidden_dim / num_heads); + + auto layer = + std::make_shared>(layer_id, + batch_size, + hidden_dim, + num_heads, + intermediate_size, + init_seq_length, + attn_dropout_ratio, + hidden_dropout_ratio, + layer_norm_eps, + pre_or_postLayerNorm, + TrainingContext::Instance().GetGemmAlgos(), + attn_dropout_checkpoint, + normalize_invertible, + gelu_checkpoint, + stochastic_mode); + + s_transformer_layers[layer_id] = layer; + + std::string dtype = (std::is_same::value) ? "half" : "float"; + + std::cout << "layer #" << layer_id << " is created with date type [" << dtype << "]." + << std::endl; + + return 0; +} + +template +std::vector ds_transformer_forward(unsigned layer_id, + const torch::Tensor& input, + const torch::Tensor& input_mask, + const torch::Tensor& attn_qkvw, + const torch::Tensor& attn_qkvb, + const torch::Tensor& attn_ow, + const torch::Tensor& attn_ob, + const torch::Tensor& attn_nw, + const torch::Tensor& attn_nb, + const torch::Tensor& inter_w, + const torch::Tensor& inter_b, + const torch::Tensor& output_w, + const torch::Tensor& output_b, + const torch::Tensor& norm_w, + const torch::Tensor& norm_b, + bool training_mode, + bool prelayernorm, + bool attn_dropout_checkpoint, + bool normalize_invertible, + bool gelu_checkpoint) +{ + CHECK_INPUT(input); + CHECK_INPUT(input_mask); + CHECK_INPUT(attn_qkvw); + CHECK_INPUT(attn_qkvb); + CHECK_INPUT(attn_ow); + CHECK_INPUT(attn_ob); + CHECK_INPUT(attn_nw); + CHECK_INPUT(attn_nb); + CHECK_INPUT(inter_w); + CHECK_INPUT(inter_b); + CHECK_INPUT(output_w); + CHECK_INPUT(output_b); + CHECK_INPUT(norm_w); + CHECK_INPUT(norm_b); + + unsigned bsz = input.size(0); + + const T* input_ptr = (const T*)input.data_ptr(); + const T* input_mask_ptr = (const T*)input_mask.data_ptr(); + const T* attn_qkvw_ptr = (const T*)attn_qkvw.data_ptr(); + const T* attn_qkvb_ptr = (const T*)attn_qkvb.data_ptr(); + const T* attn_ow_ptr = (const T*)attn_ow.data_ptr(); + const T* attn_ob_ptr = (const T*)attn_ob.data_ptr(); + const T* attn_nw_ptr = (const T*)attn_nw.data_ptr(); + const T* attn_nb_ptr = (const T*)attn_nb.data_ptr(); + const T* inter_w_ptr = (const T*)inter_w.data_ptr(); + const T* inter_b_ptr = (const T*)inter_b.data_ptr(); + const T* output_w_ptr = (const T*)output_w.data_ptr(); + const T* output_b_ptr = (const T*)output_b.data_ptr(); + const T* norm_w_ptr = (const T*)norm_w.data_ptr(); + const T* norm_b_ptr = (const T*)norm_b.data_ptr(); + + auto output = torch::empty_like(input); + T* out_ptr = (T*)output.data_ptr(); + + auto options = torch::TensorOptions() + .dtype(input.options().dtype()) + .layout(torch::kStrided) + .device(torch::kCUDA) + .requires_grad(true); + + auto uint8_options = torch::TensorOptions() + .dtype(torch::kInt8) + .layout(torch::kStrided) + .device(torch::kCUDA) + .requires_grad(false); + + std::shared_ptr> layer = + std::static_pointer_cast>(s_transformer_layers[layer_id]); + + unsigned seq_len = layer->GetSeqLength(); + if (input.size(1) != seq_len) { + seq_len = input.size(1); + layer->SetSeqLength(seq_len); + } + + auto workspace = torch::empty({get_workspace_size(bsz, + seq_len, + layer->GetHiddenSize(), + layer->GetIntermediateSize(), + layer->GetNumHeads(), + layer->IsTrainingMode(), + layer->GeluCheckpoint())}, + options); + TrainingContext::Instance().SetWorkSpace((T*)workspace.data_ptr()); + + auto inp_norm = ((prelayernorm || !normalize_invertible) ? torch::empty_like(input) : output); + auto add_res = (normalize_invertible ? inp_norm : torch::empty_like(input)); + auto attn_o_inp = torch::empty_like(input); + auto qkv_tf = torch::empty({(bsz * seq_len), output_w.size(0) * 3}, options); + + auto attn_prob_dropout_mask = + torch::empty({(bsz * layer->GetNumHeads() * seq_len), seq_len}, uint8_options); + auto attn_output_dropout_mask = + torch::empty({(bsz * seq_len), layer->GetHiddenSize()}, uint8_options); + auto layer_output_dropout_mask = + torch::empty({(bsz * seq_len), layer->GetHiddenSize()}, uint8_options); + + auto attn_layer_norm_var = torch::empty({(bsz * seq_len)}, options); + auto attn_layer_norm_mean = torch::empty({(bsz * seq_len)}, options); + auto layer_norm_var = torch::empty({(bsz * seq_len)}, options); + auto layer_norm_mean = torch::empty({(bsz * seq_len)}, options); + + T* inp_norm_ptr = (T*)inp_norm.data_ptr(); + T* add_res_ptr = (T*)add_res.data_ptr(); + T* q_tf_ptr = (T*)qkv_tf.data_ptr(); + T* k_tf_ptr = q_tf_ptr + (bsz * seq_len * output_w.size(0)); //(T*)k_tf.data_ptr(); + T* v_tf_ptr = k_tf_ptr + (bsz * seq_len * output_w.size(0)); //(T*)v_tf.data_ptr(); + T* attn_o_inp_ptr = (T*)attn_o_inp.data_ptr(); + + torch::Tensor ff2_inp = torch::empty({(bsz * seq_len), output_w.size(1)}, options); + torch::Tensor gelu_inp = + (gelu_checkpoint ? ff2_inp : torch::empty({(bsz * seq_len), output_w.size(1)}, options)); + auto ff1_inp = torch::empty_like(input); + T* ff2_inp_ptr = (T*)ff2_inp.data_ptr(); + T* gelu_inp_ptr = (T*)gelu_inp.data_ptr(); + T* ff1_inp_ptr = (T*)ff1_inp.data_ptr(); + + torch::Tensor soft_out = + torch::empty({(bsz * layer->GetNumHeads() * seq_len), seq_len}, options); + torch::Tensor ctx_bufB = + (attn_dropout_checkpoint + ? soft_out + : torch::empty({(bsz * layer->GetNumHeads() * seq_len), seq_len}, options)); + T* soft_out_ptr = (T*)soft_out.data_ptr(); + T* ctx_bufB_ptr = (T*)ctx_bufB.data_ptr(); + + layer->SetTrainingMode(training_mode); + layer->SetIntermediateBuffers((uint8_t*)attn_prob_dropout_mask.data_ptr(), + (uint8_t*)attn_output_dropout_mask.data_ptr(), + (uint8_t*)layer_output_dropout_mask.data_ptr(), + (T*)attn_layer_norm_var.data_ptr(), + (T*)attn_layer_norm_mean.data_ptr(), + (T*)layer_norm_var.data_ptr(), + (T*)layer_norm_mean.data_ptr()); + + layer->Forward(bsz, + input_ptr, + input_mask_ptr, + attn_qkvw_ptr, + attn_qkvb_ptr, + attn_ow_ptr, + attn_ob_ptr, + attn_nw_ptr, + attn_nb_ptr, + inter_w_ptr, + inter_b_ptr, + output_w_ptr, + output_b_ptr, + norm_w_ptr, + norm_b_ptr, + out_ptr, + inp_norm_ptr, + q_tf_ptr, + k_tf_ptr, + v_tf_ptr, + soft_out_ptr, + ctx_bufB_ptr, + attn_o_inp_ptr, + add_res_ptr, + ff1_inp_ptr, + gelu_inp_ptr, + ff2_inp_ptr); + + return {output, + inp_norm, + qkv_tf, + soft_out, + ctx_bufB, + attn_o_inp, + add_res, + ff1_inp, + gelu_inp, + ff2_inp, + attn_prob_dropout_mask, + attn_output_dropout_mask, + layer_output_dropout_mask, + attn_layer_norm_var, + attn_layer_norm_mean, + layer_norm_var, + layer_norm_mean}; +} + +template +std::vector ds_transformer_backward(unsigned layer_id, + const torch::Tensor& grad_output, + const torch::Tensor& output, + const torch::Tensor& inp_norm, + const torch::Tensor& qkv_tf, + const torch::Tensor& soft_out, + const torch::Tensor& ctx_bufB, + const torch::Tensor& attn_o_inp, + const torch::Tensor& add_res, + const torch::Tensor& ff1_inp, + const torch::Tensor& gelu_inp, + const torch::Tensor& ff2_inp, + const torch::Tensor& attn_prob_dropout_mask, + const torch::Tensor& attn_output_dropout_mask, + const torch::Tensor& layer_output_dropout_mask, + const torch::Tensor& attn_layer_norm_var, + const torch::Tensor& attn_layer_norm_mean, + const torch::Tensor& layer_norm_var, + const torch::Tensor& layer_norm_mean, + const torch::Tensor& input, + const torch::Tensor& input_mask, + const torch::Tensor& attn_qkvw, + const torch::Tensor& attn_qkvb, + const torch::Tensor& attn_ow, + const torch::Tensor& attn_ob, + const torch::Tensor& attn_nw, + const torch::Tensor& attn_nb, + const torch::Tensor& inter_w, + const torch::Tensor& inter_b, + const torch::Tensor& output_w, + const torch::Tensor& output_b, + const torch::Tensor& norm_w, + const torch::Tensor& norm_b) +{ + auto g_output = grad_output.contiguous(); + CHECK_INPUT(g_output); + CHECK_INPUT(output); + CHECK_INPUT(inp_norm); + CHECK_INPUT(qkv_tf); + CHECK_INPUT(add_res); + CHECK_INPUT(soft_out); + CHECK_INPUT(ctx_bufB); + CHECK_INPUT(attn_o_inp); + CHECK_INPUT(ff1_inp); + CHECK_INPUT(gelu_inp); + CHECK_INPUT(ff2_inp); + CHECK_INPUT(input); + CHECK_INPUT(input_mask); + CHECK_INPUT(attn_qkvw); + CHECK_INPUT(attn_qkvb); + CHECK_INPUT(attn_ow); + CHECK_INPUT(attn_ob); + CHECK_INPUT(attn_nw); + CHECK_INPUT(attn_nb); + CHECK_INPUT(inter_w); + CHECK_INPUT(inter_b); + CHECK_INPUT(output_w); + CHECK_INPUT(output_b); + CHECK_INPUT(norm_w); + CHECK_INPUT(norm_b); + + unsigned bsz = g_output.size(0); + + std::shared_ptr> layer = + std::static_pointer_cast>(s_transformer_layers[layer_id]); + + unsigned seq_len = layer->GetSeqLength(); + if (g_output.size(1) != seq_len) { + seq_len = g_output.size(1); + layer->SetSeqLength(seq_len); + } + auto options = torch::TensorOptions() + .dtype(g_output.options().dtype()) + .layout(torch::kStrided) + .device(torch::kCUDA) + .requires_grad(true); + auto workspace = torch::empty({get_workspace_size(bsz, + seq_len, + layer->GetHiddenSize(), + layer->GetIntermediateSize(), + layer->GetNumHeads(), + layer->IsTrainingMode(), + layer->GeluCheckpoint())}, + options); + TrainingContext::Instance().SetWorkSpace((T*)workspace.data_ptr()); + + auto grad_input = torch::empty_like(input); + auto grad_attn_qkvw = torch::empty_like(attn_qkvw); + auto grad_attn_qkvb = torch::empty_like(attn_qkvb); + auto grad_attn_ow = torch::empty_like(attn_ow); + auto grad_attn_ob = torch::empty_like(attn_ob); + auto grad_attn_nw = torch::empty_like(attn_nw); + auto grad_attn_nb = torch::empty_like(attn_nb); + auto grad_inter_w = torch::empty_like(inter_w); + auto grad_inter_b = torch::empty_like(inter_b); + auto grad_output_w = torch::empty_like(output_w); + auto grad_output_b = torch::empty_like(output_b); + auto grad_norm_w = torch::empty_like(norm_w); + auto grad_norm_b = torch::empty_like(norm_b); + + // inputs. + const T* grad_output_ptr = (const T*)g_output.data_ptr(); + const T* input_ptr = (const T*)input.data_ptr(); + const T* output_ptr = (const T*)output.data_ptr(); + const T* inp_norm_ptr = (const T*)inp_norm.data_ptr(); + const T* q_tf_ptr = (const T*)qkv_tf.data_ptr(); + const T* add_res_ptr = (const T*)add_res.data_ptr(); + const T* k_tf_ptr = + q_tf_ptr + (bsz * layer->GetSeqLength() * output_w.size(0)); //(const T*)k_tf.data_ptr(); + const T* v_tf_ptr = + k_tf_ptr + (bsz * layer->GetSeqLength() * output_w.size(0)); //(const T*)v_tf.data_ptr(); + const T* ff1_inp_ptr = (const T*)ff1_inp.data_ptr(); + const T* gelu_inp_ptr = (const T*)gelu_inp.data_ptr(); + const T* ff2_inp_ptr = (const T*)ff2_inp.data_ptr(); + const T* ctx_bufB_ptr = (const T*)ctx_bufB.data_ptr(); + const T* soft_out_ptr = (const T*)soft_out.data_ptr(); + const T* attn_o_inp_ptr = (const T*)attn_o_inp.data_ptr(); + const T* input_mask_ptr = (const T*)input_mask.data_ptr(); + const T* attn_qkvw_ptr = (const T*)attn_qkvw.data_ptr(); + const T* attn_ow_ptr = (const T*)attn_ow.data_ptr(); + const T* attn_nw_ptr = (const T*)attn_nw.data_ptr(); + const T* attn_nb_ptr = (const T*)attn_nb.data_ptr(); + const T* inter_w_ptr = (const T*)inter_w.data_ptr(); + const T* inter_b_ptr = (const T*)inter_b.data_ptr(); + const T* output_w_ptr = (const T*)output_w.data_ptr(); + const T* norm_w_ptr = (const T*)norm_w.data_ptr(); + const T* norm_b_ptr = (const T*)norm_b.data_ptr(); + + // outputs. + T* grad_input_ptr = (T*)grad_input.data_ptr(); + T* grad_attn_qkvw_ptr = (T*)grad_attn_qkvw.data_ptr(); + T* grad_attn_qkvb_ptr = (T*)grad_attn_qkvb.data_ptr(); + T* grad_attn_ow_ptr = (T*)grad_attn_ow.data_ptr(); + T* grad_attn_ob_ptr = (T*)grad_attn_ob.data_ptr(); + T* grad_attn_nw_ptr = (T*)grad_attn_nw.data_ptr(); + T* grad_attn_nb_ptr = (T*)grad_attn_nb.data_ptr(); + T* grad_inter_w_ptr = (T*)grad_inter_w.data_ptr(); + T* grad_inter_b_ptr = (T*)grad_inter_b.data_ptr(); + T* grad_output_w_ptr = (T*)grad_output_w.data_ptr(); + T* grad_output_b_ptr = (T*)grad_output_b.data_ptr(); + T* grad_norm_w_ptr = (T*)grad_norm_w.data_ptr(); + T* grad_norm_b_ptr = (T*)grad_norm_b.data_ptr(); + + layer->SetIntermediateBuffers((uint8_t*)attn_prob_dropout_mask.data_ptr(), + (uint8_t*)attn_output_dropout_mask.data_ptr(), + (uint8_t*)layer_output_dropout_mask.data_ptr(), + (T*)attn_layer_norm_var.data_ptr(), + (T*)attn_layer_norm_mean.data_ptr(), + (T*)layer_norm_var.data_ptr(), + (T*)layer_norm_mean.data_ptr()); + + layer->Backward(bsz, + grad_output_ptr, + input_ptr, + output_ptr, + inp_norm_ptr, + q_tf_ptr, + k_tf_ptr, + v_tf_ptr, + soft_out_ptr, + ctx_bufB_ptr, + attn_o_inp_ptr, + add_res_ptr, + ff1_inp_ptr, + gelu_inp_ptr, + ff2_inp_ptr, + input_mask_ptr, + attn_qkvw_ptr, + attn_ow_ptr, + attn_nw_ptr, + attn_nb_ptr, + inter_w_ptr, + inter_b_ptr, + output_w_ptr, + norm_w_ptr, + norm_b_ptr, + + grad_input_ptr, + grad_attn_qkvw_ptr, + grad_attn_qkvb_ptr, + grad_attn_ow_ptr, + grad_attn_ob_ptr, + grad_attn_nw_ptr, + grad_attn_nb_ptr, + grad_inter_w_ptr, + grad_inter_b_ptr, + grad_output_w_ptr, + grad_output_b_ptr, + grad_norm_w_ptr, + grad_norm_b_ptr); + + return {grad_input, + grad_attn_qkvw, + grad_attn_qkvb, + grad_attn_ow, + grad_attn_ob, + grad_attn_nw, + grad_attn_nb, + grad_inter_w, + grad_inter_b, + grad_output_w, + grad_output_b, + grad_norm_w, + grad_norm_b}; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("forward_fp32", + &ds_transformer_forward, + "DeepSpeed Transformer forward with fp32 (CUDA)"); + m.def("forward_fp16", + &ds_transformer_forward<__half>, + "DeepSpeed Transformer forward with fp16 (CUDA)"); + m.def("backward_fp32", + &ds_transformer_backward, + "DeepSpeed Transformer backward with fp32 (CUDA)"); + m.def("backward_fp16", + &ds_transformer_backward<__half>, + "DeepSpeed Transformer backward with fp16 (CUDA)"); + m.def("create_transformer_layer_fp32", + &create_transformer_layer, + "Create DeepSpeed Transformer Transformer Layer with fp32 (CUDA)"); + m.def("create_transformer_layer_fp16", + &create_transformer_layer<__half>, + "Create DeepSpeed Transformer Transformer Layer with fp16 (CUDA)"); +} diff --git a/toolbox/DeepSpeed/v0.15.3/patches/csrc/transformer/gelu_kernels.cu b/toolbox/DeepSpeed/v0.15.3/patches/csrc/transformer/gelu_kernels.cu new file mode 100644 index 0000000000000000000000000000000000000000..477e4a0c95d1297a732925c201d0c0c1a30e0b75 --- /dev/null +++ b/toolbox/DeepSpeed/v0.15.3/patches/csrc/transformer/gelu_kernels.cu @@ -0,0 +1,352 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. */ +/* All Rights Reserved. */ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "custom_cuda_layers.h" + +inline __device__ float gelu(const float x) +{ + const float sqrt_param = 0.79788456080286535587989211986876f; + const float mul_param = 0.044715; + return x * 0.5f * (1.0f + tanhf(sqrt_param * (x + mul_param * x * x * x))); +} + +inline __device__ float d_gelu(const float x) +{ + const float sqrt_param = 0.79788456080286535587989211986876f; + const float mul_param = 0.044715; + + float x2mul = x * x * mul_param; + float tan_h = tanhf(sqrt_param * (x + x * x2mul)); + float dg1 = 0.5f * (1.0f + tan_h); + float dg2 = x * 0.5f * sqrt_param * (1 - tan_h * tan_h); + float dg3 = dg2 * 3 * x2mul; + return (dg1 + dg2 + dg3); +} + +/* +Fused bias add with GELU + +Loads a vector of 4 elements each iteration, for stride +iterations. It was written with the intention to launch 256 thread +threadblocks, so to launch for bert-large, we would set ITERATIONS +to 4. This is currently done automatically as a heuristic, setting +the number of iterations as blocks of 1024. + +For FP16, the values are loaded from memory as __half, but converted +to FP32 for the arithmetic itself, to prevent numerous overflow on +the intermediate hyperbolic tangent, since there's no intrinsic +that computes it directly. +*/ + +__global__ void gelu_kernel(const float* input, float* vals, int row_stride, int iterations) +{ + int row = blockIdx.x; + int id = threadIdx.x; + int loop_stride = blockDim.x; + + const float4* input_cast = reinterpret_cast(input); + float4* vals_cast = reinterpret_cast(vals); + + for (int i = 0; i < iterations; i++) { + if (i * loop_stride + id < row_stride) { + float4 data = input_cast[row * row_stride + i * loop_stride + id]; + + data.x = gelu(data.x); + data.y = gelu(data.y); + data.z = gelu(data.z); + data.w = gelu(data.w); + + vals_cast[row * row_stride + i * loop_stride + id] = data; + } + } +} + +__global__ void gelu_kernel(const __half* input, __half* vals, int row_stride, int iterations) +{ +#ifdef HALF_PRECISION_AVAILABLE + int row = blockIdx.x; + int id = threadIdx.x; + int loop_stride = blockDim.x; + + const float2* input_cast = reinterpret_cast(input); + float2* vals_cast = reinterpret_cast(vals); + + for (int i = 0; i < iterations; i++) { + if (i * loop_stride + id < row_stride) { + float2 vals_vec = input_cast[row * row_stride + i * loop_stride + id]; + + __half2* vals_half = reinterpret_cast<__half2*>(&vals_vec); + + float2 low_data = __half22float2(vals_half[0]); + float2 high_data = __half22float2(vals_half[1]); + + low_data.x = gelu(low_data.x); + low_data.y = gelu(low_data.y); + high_data.x = gelu(high_data.x); + high_data.y = gelu(high_data.y); + + vals_half[0] = __float22half2_rn(low_data); + vals_half[1] = __float22half2_rn(high_data); + + vals_cast[row * row_stride + i * loop_stride + id] = vals_vec; + } + } +#endif +} + +__global__ void fused_bias_gelu(const float* input, + const float* bias, + float* vals, + int row_stride, + int iterations) +{ + int row = blockIdx.x; + int id = threadIdx.x; + int loop_stride = blockDim.x; + + const float4* input_cast = reinterpret_cast(input); + float4* vals_cast = reinterpret_cast(vals); + const float4* bias_cast = reinterpret_cast(bias); + + for (int i = 0; i < iterations; i++) { + if (i * loop_stride + id < row_stride) { + float4 data = input_cast[row * row_stride + i * loop_stride + id]; + float4 bias_data = bias_cast[i * loop_stride + id]; + + data.x += bias_data.x; + data.y += bias_data.y; + data.z += bias_data.z; + data.w += bias_data.w; + + data.x = gelu(data.x); + data.y = gelu(data.y); + data.z = gelu(data.z); + data.w = gelu(data.w); + + vals_cast[row * row_stride + i * loop_stride + id] = data; + } + } +} + +__global__ void fused_bias_gelu(const __half* input, + const __half* bias, + __half* vals, + int row_stride, + int iterations) +{ +#ifdef HALF_PRECISION_AVAILABLE + int row = blockIdx.x; + int id = threadIdx.x; + int loop_stride = blockDim.x; + + const float2* input_cast = reinterpret_cast(input); + float2* vals_cast = reinterpret_cast(vals); + const float2* bias_cast = reinterpret_cast(bias); + + for (int i = 0; i < iterations; i++) { + if (i * loop_stride + id < row_stride) { + float2 vals_vec = input_cast[row * row_stride + i * loop_stride + id]; + float2 bias_vec = bias_cast[i * loop_stride + id]; + + __half2* vals_half = reinterpret_cast<__half2*>(&vals_vec); + __half2* bias_half = reinterpret_cast<__half2*>(&bias_vec); + + float2 low_data = __half22float2(vals_half[0]); + float2 high_data = __half22float2(vals_half[1]); + + float2 low_bias = __half22float2(bias_half[0]); + float2 high_bias = __half22float2(bias_half[1]); + + low_data.x += low_bias.x; + low_data.y += low_bias.y; + high_data.x += high_bias.x; + high_data.y += high_bias.y; + + low_data.x = gelu(low_data.x); + low_data.y = gelu(low_data.y); + high_data.x = gelu(high_data.x); + high_data.y = gelu(high_data.y); + + vals_half[0] = __float22half2_rn(low_data); + vals_half[1] = __float22half2_rn(high_data); + + vals_cast[row * row_stride + i * loop_stride + id] = vals_vec; + } + } +#endif +} + +__global__ void d_gelu_func(float* d_output, + const float* gelu_input, + const float* bias, + int row_stride, + int iterations) +{ + int row = blockIdx.x; + int id = threadIdx.x; + int loop_stride = blockDim.x; + + float4* d_output_cast = reinterpret_cast(d_output); + const float4* gelu_input_cast = reinterpret_cast(gelu_input); + const float4* bias_cast = reinterpret_cast(bias); + + for (int i = 0; i < iterations; i++) { + if (i * loop_stride + id < row_stride) { + float4 output_data = d_output_cast[row * row_stride + i * loop_stride + id]; + float4 gelu_input_data = gelu_input_cast[row * row_stride + i * loop_stride + id]; + float4 bias_data = bias_cast[i * loop_stride + id]; + + gelu_input_data.x += bias_data.x; + gelu_input_data.y += bias_data.y; + gelu_input_data.z += bias_data.z; + gelu_input_data.w += bias_data.w; + + output_data.x *= d_gelu(gelu_input_data.x); + output_data.y *= d_gelu(gelu_input_data.y); + output_data.z *= d_gelu(gelu_input_data.z); + output_data.w *= d_gelu(gelu_input_data.w); + + d_output_cast[row * row_stride + i * loop_stride + id] = output_data; + } + } +} + +__global__ void d_gelu_func(__half* d_output, + const __half* gelu_input, + const __half* bias, + int row_stride, + int iterations) +{ +#ifdef HALF_PRECISION_AVAILABLE + int row = blockIdx.x; + int id = threadIdx.x; + int loop_stride = blockDim.x; + + float2* d_output_cast = reinterpret_cast(d_output); + const float2* gelu_input_cast = reinterpret_cast(gelu_input); + const float2* bias_cast = reinterpret_cast(bias); + +#pragma unroll + for (int i = 0; i < iterations; i++) { + if (i * loop_stride + id < row_stride) { + float2 output_data = d_output_cast[row * row_stride + i * loop_stride + id]; + float2 gelu_input_data = gelu_input_cast[row * row_stride + i * loop_stride + id]; + float2 bias_vec = bias_cast[i * loop_stride + id]; + + __half2* output_data_half = reinterpret_cast<__half2*>(&output_data); + __half2* gelu_input_data_half = reinterpret_cast<__half2*>(&gelu_input_data); + __half2* bias_half = reinterpret_cast<__half2*>(&bias_vec); + + float2 output_half_0 = __half22float2(output_data_half[0]); + float2 output_half_1 = __half22float2(output_data_half[1]); + + float2 gelu_input_half_0 = __half22float2(gelu_input_data_half[0]); + float2 gelu_input_half_1 = __half22float2(gelu_input_data_half[1]); + + float2 bias_half_0 = __half22float2(bias_half[0]); + float2 bias_half_1 = __half22float2(bias_half[1]); + + gelu_input_half_0.x += bias_half_0.x; + gelu_input_half_0.y += bias_half_0.y; + gelu_input_half_1.x += bias_half_1.x; + gelu_input_half_1.y += bias_half_1.y; + + output_half_0.x *= d_gelu(gelu_input_half_0.x); + output_half_0.y *= d_gelu(gelu_input_half_0.y); + output_half_1.x *= d_gelu(gelu_input_half_1.x); + output_half_1.y *= d_gelu(gelu_input_half_1.y); + + float2 result; + __half2* result_half2 = reinterpret_cast<__half2*>(&result); + + result_half2[0] = __float22half2_rn(output_half_0); + result_half2[1] = __float22half2_rn(output_half_1); + + d_output_cast[row * row_stride + i * loop_stride + id] = result; + } + } +#endif +} + +template +void launch_bias_gelu(const T* input, + const T* bias, + T* output, + int intermediate_size, + int batch_size, + cudaStream_t stream) +{ + int iterations = (intermediate_size + 1023) / 1024; + int threads = (intermediate_size - 1) / (iterations * 4) + 1; + dim3 block_dims(threads); + dim3 grid_dims(batch_size); + + fused_bias_gelu<<>>( + input, bias, output, intermediate_size / 4, iterations); +} + +template +void launch_gelu(const T* input, + T* output, + int intermediate_size, + int batch_size, + cudaStream_t stream) +{ + int iterations = (intermediate_size + 1023) / 1024; + int threads = (intermediate_size - 1) / (iterations * 4) + 1; + dim3 block_dims(threads); + dim3 grid_dims(batch_size); + + gelu_kernel<<>>( + input, output, intermediate_size / 4, iterations); +} + +template void launch_bias_gelu(const float*, const float*, float*, int, int, cudaStream_t); +template void launch_bias_gelu<__half>(const __half*, + const __half*, + __half*, + int, + int, + cudaStream_t); + +template void launch_gelu(const float*, float*, int, int, cudaStream_t); +template void launch_gelu<__half>(const __half*, __half*, int, int, cudaStream_t); + +template +void launch_d_gelu(T* d_output, + const T* input, + const T* bias, + int intermediate_size, + int batch_size, + cudaStream_t stream) +{ + int iterations = (intermediate_size + 1023) / 1024; + int threads = (intermediate_size - 1) / (iterations * 4) + 1; + dim3 block_dims(threads); + dim3 grid_dims(batch_size); + + d_gelu_func<<>>( + d_output, input, bias, intermediate_size / 4, iterations); +} + +template void launch_d_gelu(float*, const float*, const float*, int, int, cudaStream_t); +template void launch_d_gelu<__half>(__half*, const __half*, const __half*, int, int, cudaStream_t); diff --git a/toolbox/DeepSpeed/v0.15.3/patches/csrc/transformer/general_kernels.cu b/toolbox/DeepSpeed/v0.15.3/patches/csrc/transformer/general_kernels.cu new file mode 100644 index 0000000000000000000000000000000000000000..816a5478890d2991c59750fdbee50b55567145e5 --- /dev/null +++ b/toolbox/DeepSpeed/v0.15.3/patches/csrc/transformer/general_kernels.cu @@ -0,0 +1,433 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. */ +/* All Rights Reserved. */ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "general_kernels.h" + +namespace cg = cooperative_groups; + +template +__global__ void column_sum_reduce(const T* __restrict__ inp, + T* __restrict__ out, + int rows, + int width) +{ + __shared__ float tile[TILE_DIM][TILE_DIM + 1]; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + int idx = blockDim.x * blockIdx.x + threadIdx.x; + + int y_stride = width * TILE_DIM; + + float localSum = 0; + + // Loop across matrix height + if (idx < width) { + int offset = threadIdx.y * width + idx; + for (int r = threadIdx.y; r < rows; r += TILE_DIM) { + localSum += (float)inp[offset]; + offset += y_stride; + } + } + + tile[threadIdx.x][threadIdx.y] = localSum; + + __syncthreads(); + + // Sum the shared buffer. + float sum = tile[threadIdx.y][threadIdx.x]; + +#ifndef __STOCHASTIC_MODE__ + __syncthreads(); +#endif + + for (int i = 1; i < TILE_DIM; i <<= 1) sum += g.shfl_down(sum, i); + + if (threadIdx.x == 0) { + int pos = blockIdx.x * TILE_DIM + threadIdx.y; + if (pos < width) out[pos] = sum; + } +} + +template +void launch_fuse_transpose_bias_kernel(const T* inp, + T* out, + int rows, + int cols, + cudaStream_t stream); + +template <> +void launch_fuse_transpose_bias_kernel(const float* inp, + float* out, + int rows, + int cols, + cudaStream_t stream) +{ + // assert(rows % TILE_DIM == 0); + // assert(cols % TILE_DIM == 0); + + dim3 grid_dim((cols - 1) / TILE_DIM + 1); + dim3 block_dim(TILE_DIM, TILE_DIM); + + column_sum_reduce<<>>(inp, out, rows, cols); +} + +template <> +void launch_fuse_transpose_bias_kernel<__half>(const __half* inp, + __half* out, + int rows, + int cols, + cudaStream_t stream) +{ + // assert(rows % TILE_DIM == 0); + // assert(cols % TILE_DIM == 0); + + dim3 grid_dim((cols - 1) / TILE_DIM + 1); + dim3 block_dim(TILE_DIM, TILE_DIM); + + column_sum_reduce<__half><<>>(inp, out, rows, cols); +} + +__global__ void fused_add2_kernel(const int N, float* out, const float* inp1, const float* inp2) +{ + const float4* inp1_4 = reinterpret_cast(inp1); + const float4* inp2_4 = reinterpret_cast(inp2); + float4* out_4 = reinterpret_cast(out); + + CUDA_1D_KERNEL_LOOP(j, N) + { + float4 val; + float4 inp1_reg = inp1_4[j]; + float4 inp2_reg = inp2_4[j]; + + val.x = inp1_reg.x + inp2_reg.x; + val.y = inp1_reg.y + inp2_reg.y; + val.z = inp1_reg.z + inp2_reg.z; + val.w = inp1_reg.w + inp2_reg.w; + + out_4[j] = val; + } +} + +__global__ void fused_add2_kernel(const int N, __half* out, const __half* inp1, const __half* inp2) +{ + float2 inp1_4; + float2 inp2_4; + + __half2* inp1_h = reinterpret_cast<__half2*>(&inp1_4); + __half2* inp2_h = reinterpret_cast<__half2*>(&inp2_4); + + const float2* inp1_arr = reinterpret_cast(inp1); + const float2* inp2_arr = reinterpret_cast(inp2); + + CUDA_1D_KERNEL_LOOP(j, N) + { + inp1_4 = inp1_arr[j]; + inp2_4 = inp2_arr[j]; + + float2 inp1_h_f_0 = __half22float2(inp1_h[0]); + float2 inp1_h_f_1 = __half22float2(inp1_h[1]); + + float2 inp2_h_f_0 = __half22float2(inp2_h[0]); + float2 inp2_h_f_1 = __half22float2(inp2_h[1]); + + inp1_h_f_0.x += inp2_h_f_0.x; + inp1_h_f_0.y += inp2_h_f_0.y; + inp1_h_f_1.x += inp2_h_f_1.x; + inp1_h_f_1.y += inp2_h_f_1.y; + + float2 val_f; + __half2* val_h = reinterpret_cast<__half2*>(&val_f); + + val_h[0] = __float22half2_rn(inp1_h_f_0); + val_h[1] = __float22half2_rn(inp1_h_f_1); + + float2* out_4 = reinterpret_cast(out); + out_4[j] = val_f; + } +} + +template <> +void launch_fused_add2(float* out, + const float* inp1, + const float* inp2, + int batch_size, + int seq_length, + int hidden_dim, + cudaStream_t& stream) +{ + int total_count = batch_size * seq_length * hidden_dim / 4; + dim3 grid_dim = DS_GET_BLOCKS(total_count); //(batch_size * seq_length); + + dim3 block_dim = DS_CUDA_NUM_THREADS; //(hidden_dim / 4); + + fused_add2_kernel<<>>(total_count, out, inp1, inp2); +} + +template <> +void launch_fused_add2<__half>(__half* out, + const __half* inp1, + const __half* inp2, + int batch_size, + int seq_length, + int hidden_dim, + cudaStream_t& stream) +{ + int total_count = batch_size * seq_length * hidden_dim / 4; + dim3 grid_dim = DS_GET_BLOCKS(total_count); //(batch_size * seq_length); + + dim3 block_dim = DS_CUDA_NUM_THREADS; //(hidden_dim / 4); + + fused_add2_kernel<<>>(total_count, out, inp1, inp2); +} + +__global__ void fused_add3_kernel(float* out, + const float* inp1, + const float* inp2, + const float* inp3, + int size, + int row_stride) +{ + int row = blockIdx.x; + int id = threadIdx.x; + + const float4* inp1_4 = reinterpret_cast(inp1); + const float4* inp2_4 = reinterpret_cast(inp2); + const float4* inp3_4 = reinterpret_cast(inp3); + + float4* out_4 = reinterpret_cast(out); + + float4 val; + float4 inp1_reg = inp1_4[row * row_stride + id]; + float4 inp2_reg = inp2_4[row * row_stride + id]; + float4 inp3_reg = inp3_4[row * row_stride + id]; + + val.x = inp1_reg.x + inp2_reg.x + inp3_reg.x; + val.y = inp1_reg.y + inp2_reg.y + inp3_reg.y; + val.z = inp1_reg.z + inp2_reg.z + inp3_reg.z; + val.w = inp1_reg.w + inp2_reg.w + inp3_reg.w; + + out_4[row * row_stride + id] = val; +} + +__global__ void fused_add3_kernel(__half* out, + const __half* inp1, + const __half* inp2, + const __half* inp3, + int size, + int row_stride) +{ + int row = blockIdx.x; + int id = threadIdx.x; + const float2* inp1_arr = reinterpret_cast(inp1); + const float2* inp2_arr = reinterpret_cast(inp2); + const float2* inp3_arr = reinterpret_cast(inp3); + + float2 inp1_4 = inp1_arr[row * row_stride + id]; + float2 inp2_4 = inp2_arr[row * row_stride + id]; + float2 inp3_4 = inp3_arr[row * row_stride + id]; + + __half2* inp1_h = reinterpret_cast<__half2*>(&inp1_4); + __half2* inp2_h = reinterpret_cast<__half2*>(&inp2_4); + __half2* inp3_h = reinterpret_cast<__half2*>(&inp3_4); + + float2 inp1_h_f_0 = __half22float2(inp1_h[0]); + float2 inp1_h_f_1 = __half22float2(inp1_h[1]); + + float2 inp2_h_f_0 = __half22float2(inp2_h[0]); + float2 inp2_h_f_1 = __half22float2(inp2_h[1]); + + float2 inp3_h_f_0 = __half22float2(inp3_h[0]); + float2 inp3_h_f_1 = __half22float2(inp3_h[1]); + + inp1_h_f_0.x += (inp2_h_f_0.x + inp3_h_f_0.x); + inp1_h_f_0.y += (inp2_h_f_0.y + inp3_h_f_0.y); + inp1_h_f_1.x += (inp2_h_f_1.x + inp3_h_f_1.x); + inp1_h_f_1.y += (inp2_h_f_1.y + inp3_h_f_1.y); + + float2 val_f; + __half2* val_h = reinterpret_cast<__half2*>(&val_f); + + val_h[0] = __float22half2_rn(inp1_h_f_0); + val_h[1] = __float22half2_rn(inp1_h_f_1); + + float2* out_4 = reinterpret_cast(out); + out_4[row * row_stride + id] = val_f; +} + +template <> +void launch_fused_add3(float* out, + const float* inp1, + const float* inp2, + const float* inp3, + int batch_size, + int seq_length, + int hidden_size, + cudaStream_t& stream) +{ + dim3 grid_dim(batch_size * seq_length); + + dim3 block_dim(hidden_size / 4); + + fused_add3_kernel<<>>( + out, inp1, inp2, inp3, (batch_size * seq_length * hidden_size), hidden_size / 4); +} + +template <> +void launch_fused_add3<__half>(__half* out, + const __half* inp1, + const __half* inp2, + const __half* inp3, + int batch_size, + int seq_length, + int hidden_size, + cudaStream_t& stream) +{ + dim3 grid_dim(batch_size * seq_length); + + dim3 block_dim(hidden_size / 4); + + fused_add3_kernel<<>>( + out, inp1, inp2, inp3, (batch_size * seq_length * hidden_size), hidden_size / 4); +} + +__global__ void fused_add4_kernel(float* out, + const float* inp1, + const float* inp2, + const float* inp3, + const float* inp4, + int size, + int row_stride) +{ + int row = blockIdx.x; + int id = threadIdx.x; + + const float4* inp1_4 = reinterpret_cast(inp1); + const float4* inp2_4 = reinterpret_cast(inp2); + const float4* inp3_4 = reinterpret_cast(inp3); + const float4* inp4_4 = reinterpret_cast(inp4); + float4* out_4 = reinterpret_cast(out); + + float4 val; + float4 inp1_reg = inp1_4[row * row_stride + id]; + float4 inp2_reg = inp2_4[row * row_stride + id]; + float4 inp3_reg = inp3_4[row * row_stride + id]; + float4 inp4_reg = inp4_4[row * row_stride + id]; + + val.x = inp1_reg.x + inp2_reg.x + inp3_reg.x + inp4_reg.x; + val.y = inp1_reg.y + inp2_reg.y + inp3_reg.y + inp4_reg.y; + val.z = inp1_reg.z + inp2_reg.z + inp3_reg.z + inp4_reg.z; + val.w = inp1_reg.w + inp2_reg.w + inp3_reg.w + inp4_reg.w; + + out_4[row * row_stride + id] = val; +} + +__global__ void fused_add4_kernel(__half* out, + const __half* inp1, + const __half* inp2, + const __half* inp3, + const __half* inp4, + int size, + int row_stride) +{ + int row = blockIdx.x; + int id = threadIdx.x; + const float2* inp1_arr = reinterpret_cast(inp1); + const float2* inp2_arr = reinterpret_cast(inp2); + const float2* inp3_arr = reinterpret_cast(inp3); + const float2* inp4_arr = reinterpret_cast(inp4); + + float2 inp1_4 = inp1_arr[row * row_stride + id]; + float2 inp2_4 = inp2_arr[row * row_stride + id]; + float2 inp3_4 = inp3_arr[row * row_stride + id]; + float2 inp4_4 = inp4_arr[row * row_stride + id]; + + __half2* inp1_h = reinterpret_cast<__half2*>(&inp1_4); + __half2* inp2_h = reinterpret_cast<__half2*>(&inp2_4); + __half2* inp3_h = reinterpret_cast<__half2*>(&inp3_4); + __half2* inp4_h = reinterpret_cast<__half2*>(&inp4_4); + + float2 inp1_h_f_0 = __half22float2(inp1_h[0]); + float2 inp1_h_f_1 = __half22float2(inp1_h[1]); + + float2 inp2_h_f_0 = __half22float2(inp2_h[0]); + float2 inp2_h_f_1 = __half22float2(inp2_h[1]); + + float2 inp3_h_f_0 = __half22float2(inp3_h[0]); + float2 inp3_h_f_1 = __half22float2(inp3_h[1]); + + float2 inp4_h_f_0 = __half22float2(inp4_h[0]); + float2 inp4_h_f_1 = __half22float2(inp4_h[1]); + + inp1_h_f_0.x += (inp2_h_f_0.x + inp3_h_f_0.x + inp4_h_f_0.x); + inp1_h_f_0.y += (inp2_h_f_0.y + inp3_h_f_0.y + inp4_h_f_0.y); + inp1_h_f_1.x += (inp2_h_f_1.x + inp3_h_f_1.x + inp4_h_f_1.x); + inp1_h_f_1.y += (inp2_h_f_1.y + inp3_h_f_1.y + inp4_h_f_1.y); + + float2 val_f; + __half2* val_h = reinterpret_cast<__half2*>(&val_f); + + val_h[0] = __float22half2_rn(inp1_h_f_0); + val_h[1] = __float22half2_rn(inp1_h_f_1); + + float2* out_4 = reinterpret_cast(out); + out_4[row * row_stride + id] = val_f; +} + +template <> +void launch_fused_add4(float* out, + const float* inp1, + const float* inp2, + const float* inp3, + const float* inp4, + int batch_size, + int seq_length, + int hidden_size, + cudaStream_t& stream) +{ + dim3 grid_dim(batch_size * seq_length); + + dim3 block_dim(hidden_size / 4); + + fused_add4_kernel<<>>( + out, inp1, inp2, inp3, inp4, (batch_size * seq_length * hidden_size), hidden_size / 4); +} + +template <> +void launch_fused_add4<__half>(__half* out, + const __half* inp1, + const __half* inp2, + const __half* inp3, + const __half* inp4, + int batch_size, + int seq_length, + int hidden_size, + cudaStream_t& stream) +{ + dim3 grid_dim(batch_size * seq_length); + + dim3 block_dim(hidden_size / 4); + + fused_add4_kernel<<>>( + out, inp1, inp2, inp3, inp4, (batch_size * seq_length * hidden_size), hidden_size / 4); +} diff --git a/toolbox/DeepSpeed/v0.15.3/patches/csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu b/toolbox/DeepSpeed/v0.15.3/patches/csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu new file mode 100644 index 0000000000000000000000000000000000000000..b6acfe49f9423058599c9589dd371d2bb9e64a63 --- /dev/null +++ b/toolbox/DeepSpeed/v0.15.3/patches/csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu @@ -0,0 +1,208 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. */ +/* All Rights Reserved. */ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "conversion_utils.h" +#ifdef __HIP_PLATFORM_AMD__ +#include "hip/hip_cooperative_groups.h" +#else +#include "cooperative_groups.h" +#endif +#include "ds_kernel_utils.h" +#include "inference_cuda_layers.h" +#include "memory_access_utils.h" + +#ifndef __HIP_PLATFORM_AMD__ +#include +#endif + +namespace cg = cooperative_groups; + +namespace rot_half { +constexpr int threads = 256; +} // namespace rot_half + +template +__global__ void apply_rotary_pos_half(T* mixed_query, + T* key_layer, + unsigned rotary_dim, + unsigned seq_len, + unsigned seq_offset, + unsigned num_heads, + unsigned head_size, + unsigned total_count, + float rope_theta, + int max_out_tokens) +{ + constexpr int T_per_thread = granularity / sizeof(T); + constexpr int heads_per_block = rot_half::threads / threadsPerHead; + + cg::thread_block tb = cg::this_thread_block(); + cg::thread_block_tile head_group = cg::tiled_partition(tb); + + const int head_idx = blockIdx.x * heads_per_block + threadIdx.x / threadsPerHead; + const int cur_seq_idx = head_idx % seq_len; + const int offset = head_idx * head_size; + const int k_offset = (cur_seq_idx + (head_idx / seq_len) * max_out_tokens) * head_size; + + const int seq_idx = cur_seq_idx + seq_offset; + const int half_dim = rotary_dim >> 1; + const int half_dim_threads = half_dim / T_per_thread; + + if (head_idx < total_count) { + const int base_neuron_idx = head_group.thread_rank() * T_per_thread; + + T q[T_per_thread], k[T_per_thread]; + mem_access::load_global(q, mixed_query + offset + base_neuron_idx); + mem_access::load_global(k, key_layer + k_offset + base_neuron_idx); + +#pragma unroll + for (int i = 0; i < T_per_thread; i++) { + const int neuron_idx = base_neuron_idx + i; + if (neuron_idx < rotary_dim) { + float inv_freq = (float)((neuron_idx % half_dim) * 2) / (float)rotary_dim; + inv_freq = 1.0 / powf(rope_theta, inv_freq) * (float)seq_idx; + + float rotary_sign = (neuron_idx > (half_dim - 1) ? -1.0 : 1.0); + float q_rot = conversion::to(q[i]) * rotary_sign; + float k_rot = conversion::to(k[i]) * rotary_sign; + + const int target_lane = (neuron_idx < half_dim) + ? head_group.thread_rank() + half_dim_threads + : head_group.thread_rank() - half_dim_threads; + + const float q_rot_temp = head_group.shfl(q_rot, target_lane); + const float k_rot_temp = head_group.shfl(k_rot, target_lane); + + q[i] = conversion::to(conversion::to(q[i]) * cosf(inv_freq) + + q_rot_temp * sinf(inv_freq)); + k[i] = conversion::to(conversion::to(k[i]) * cosf(inv_freq) + + k_rot_temp * sinf(inv_freq)); + } + } + + mem_access::store_global(mixed_query + offset + base_neuron_idx, q); + mem_access::store_global(key_layer + k_offset + base_neuron_idx, k); + } +} + +#define LAUNCH_ROT_POS_EMB_HALF(HEAD_THREADS, ALIGNMENT) \ + apply_rotary_pos_half<<>>(mixed_query, \ + key_layer, \ + rotary_dim, \ + seq_len, \ + offset, \ + num_heads, \ + head_size, \ + total_count, \ + rope_theta, \ + max_out_tokens); + +#if defined(__HIP_PLATFORM_AMD__) and ROCM_WAVEFRONT_SIZE == 64 +#define LAUNCH_FOR_ALIGNMENT(ALIGNMENT) \ + if (threads_per_head == 64) { \ + LAUNCH_ROT_POS_EMB_HALF(64, ALIGNMENT); \ + } else { \ + assert(false); \ + } +#else +#define LAUNCH_FOR_ALIGNMENT(ALIGNMENT) \ + if (threads_per_head == 4) { \ + LAUNCH_ROT_POS_EMB_HALF(4, ALIGNMENT); \ + } else if (threads_per_head == 8) { \ + LAUNCH_ROT_POS_EMB_HALF(8, ALIGNMENT); \ + } else if (threads_per_head == 16) { \ + LAUNCH_ROT_POS_EMB_HALF(16, ALIGNMENT); \ + } else if (threads_per_head == 32) { \ + LAUNCH_ROT_POS_EMB_HALF(32, ALIGNMENT); \ + } else { \ + assert(false); \ + } +#endif + +template +void launch_apply_rotary_pos_emb(T* mixed_query, + T* key_layer, + unsigned head_size, + unsigned seq_len, + unsigned rotary_dim, + unsigned offset, + unsigned num_heads, + unsigned batch, + float rope_theta, + cudaStream_t stream, + int max_out_tokens) +{ + const int half_dim = rotary_dim >> 1; + + int alignment = sizeof(T); + if (half_dim % (16 / sizeof(T)) == 0) { + alignment = 16; + } else if (half_dim % (8 / sizeof(T)) == 0) { + alignment = 8; + } else if (half_dim % (4 / sizeof(T)) == 0) { + alignment = 4; + } else { + assert(false); + } + const int T_per_elem = alignment / sizeof(T); + + int total_count = batch * num_heads * seq_len; + + const int padded_head_size = next_pow2(head_size); + + assert(padded_head_size <= hw_warp_size * T_per_elem); + + const int threads_per_head = padded_head_size / T_per_elem; + const int heads_per_block = rot_half::threads / threads_per_head; + + dim3 block(rot_half::threads); + dim3 grid((total_count + heads_per_block - 1) / heads_per_block); + + if (alignment == 4) { + LAUNCH_FOR_ALIGNMENT(4); + } else if (alignment == 8) { + LAUNCH_FOR_ALIGNMENT(8); + } else if (alignment == 16) { + LAUNCH_FOR_ALIGNMENT(16); + } else { + assert(false); + } +} + +#define INSTANTIATE_LAUNCH_ROTARY_POS_EMB(T) \ + template void launch_apply_rotary_pos_emb(T*, \ + T*, \ + unsigned, \ + unsigned, \ + unsigned, \ + unsigned, \ + unsigned, \ + unsigned, \ + float, \ + cudaStream_t, \ + int); + +INSTANTIATE_LAUNCH_ROTARY_POS_EMB(float); +#ifdef BF16_AVAILABLE +INSTANTIATE_LAUNCH_ROTARY_POS_EMB(__nv_bfloat16); +#endif +INSTANTIATE_LAUNCH_ROTARY_POS_EMB(__half); diff --git a/toolbox/DeepSpeed/v0.15.3/patches/csrc/transformer/inference/csrc/dequantize.cu b/toolbox/DeepSpeed/v0.15.3/patches/csrc/transformer/inference/csrc/dequantize.cu new file mode 100644 index 0000000000000000000000000000000000000000..2d3945812ce37b64de1bddca6f3a988d071eb6cc --- /dev/null +++ b/toolbox/DeepSpeed/v0.15.3/patches/csrc/transformer/inference/csrc/dequantize.cu @@ -0,0 +1,170 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. */ +/* All Rights Reserved. */ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "conversion_utils.h" +#include "inference_cuda_layers.h" + +#define MAX_QUANTIZE_GROUPING 1024 + +#define loop_unroll 1 +#define loop_unroll_bits 1 + +template +__global__ void dequantize_kernel(T* output, + const int8_t* input, + const float* qscale, + int output_size, + int hidden_dim, + int groups, + int merge_count) +{ + unsigned merge_hidden = hidden_dim >> merge_count; + unsigned quantization_stride = (merge_hidden * output_size) / groups; + + unsigned bid = blockIdx.x; + unsigned tid = threadIdx.x; + + while (tid < output_size) { + unsigned w_index = bid / merge_hidden; + unsigned q_index = tid + bid * output_size; + + auto q = input[q_index]; + + unsigned merge_hidden_total = w_index * merge_hidden; + unsigned scale_index = + ((((bid - merge_hidden_total) + tid * merge_hidden) / quantization_stride) + << merge_count) + + w_index; + + float scale_data = qscale[scale_index]; + + output[q_index] = conversion::to(scale_data * (float)q); + tid += blockDim.x; + } +} + +template +void launch_dequantize(T* output, + const int8_t* input, + const float* qscale, + unsigned output_size, + unsigned hidden_dim, + unsigned groups, + unsigned merge_count, + cudaStream_t stream) +{ + unsigned threads = 1024; + dim3 block_dims(threads); + dim3 grid_dims(hidden_dim); + + dequantize_kernel<<>>( + output, input, qscale, output_size, hidden_dim, groups, merge_count); +} + +#define INSTANTIATE_DEQUANTIZE_MERGE(T) \ + template void launch_dequantize( \ + T*, const int8_t*, const float*, unsigned, unsigned, unsigned, unsigned, cudaStream_t); + +INSTANTIATE_DEQUANTIZE_MERGE(float); +#ifdef BF16_AVAILABLE +INSTANTIATE_DEQUANTIZE_MERGE(__nv_bfloat16); +#endif +INSTANTIATE_DEQUANTIZE_MERGE(__half); + +__global__ void dequantize_kernel(float* output, + const int8_t* input, + const float* qscale, + int hidden_dim, + unsigned merge_hidden, + int cnt) +{ +} + +template +__global__ void dequantize_kernel(T* output, + const int8_t* input, + const float* qscale, + unsigned hidden_dim, + unsigned merge_hidden, + int cnt) +{ + unsigned bid = blockIdx.x * gridDim.y + blockIdx.y; + unsigned tid = threadIdx.x; + + float local_scale = qscale[blockIdx.x]; + + const float* input_cast = reinterpret_cast(input); + float2* output_cast = reinterpret_cast(output); + + input_cast += bid * merge_hidden; + output_cast += bid * merge_hidden; + + for (int c = 0; c < cnt; c++) { + if (tid < merge_hidden) { + float q = input_cast[tid]; + int8_t* q_int8 = (int8_t*)&q; + + float2 q_f; + T* q_h = (T*)&q_f; + + q_h[0] = conversion::to(local_scale * (float)q_int8[0]); + q_h[1] = conversion::to(local_scale * (float)q_int8[1]); + q_h[2] = conversion::to(local_scale * (float)q_int8[2]); + q_h[3] = conversion::to(local_scale * (float)q_int8[3]); + output_cast[tid] = q_f; + tid += blockDim.x; + } + } +} + +template +void launch_dequantize(T* output, + const int8_t* input, + const float* qscale, + unsigned output_size, + unsigned hidden_dim, + unsigned groups, + cudaStream_t stream) +{ + unsigned threads = 1024; + hidden_dim /= 4; + unsigned thd_cnt = (hidden_dim - 1) / threads + 1; + + assert(output_size % groups == 0); + unsigned blocks = output_size / groups; + + dim3 block_dims(threads); + dim3 grid_dims(groups, blocks); + + dequantize_kernel<<>>( + output, input, qscale, hidden_dim, hidden_dim, thd_cnt); +} + +#define INSTANTIATE_DEQUANTIZE_NO_MERGE(T) \ + template void launch_dequantize( \ + T*, const int8_t*, const float*, unsigned, unsigned, unsigned, cudaStream_t); + +INSTANTIATE_DEQUANTIZE_NO_MERGE(float); +#ifdef BF16_AVAILABLE +INSTANTIATE_DEQUANTIZE_NO_MERGE(__nv_bfloat16); +#endif +INSTANTIATE_DEQUANTIZE_NO_MERGE(__half); diff --git a/toolbox/DeepSpeed/v0.15.3/patches/csrc/transformer/inference/csrc/gelu.cu b/toolbox/DeepSpeed/v0.15.3/patches/csrc/transformer/inference/csrc/gelu.cu new file mode 100644 index 0000000000000000000000000000000000000000..504a302d5e4f6f95ae98c1f2d1c81d20b099e6da --- /dev/null +++ b/toolbox/DeepSpeed/v0.15.3/patches/csrc/transformer/inference/csrc/gelu.cu @@ -0,0 +1,727 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. */ +/* All Rights Reserved. */ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "conversion_utils.h" +#include "inference_cuda_layers.h" +#include "memory_access_utils.h" + +namespace cg = cooperative_groups; +#define MAX_CAP 4 +#define MAX_SEQ 2048 + +// only used to avoid compilation error due to lack of definition. +#ifndef BF16_AVAILABLE +using __nv_bfloat162 = __half2; +#endif + +inline __device__ float gelu(const float x) +{ + constexpr float sqrt_param = 0.79788456080286535587989211986876f; + constexpr float mul_param = 0.044715; + return x * 0.5f * (1.0f + tanhf(sqrt_param * (x + mul_param * x * x * x))); +} + +/* +In-place gelu(biasAdd(x)) for channels last +*/ +template +__global__ void fused_bias_gelu(T* input, const T* bias, int total_count, int intermediate_size) +{ + // Input restriction: intermediate_size % vals_per_access == 0 + constexpr int granularity = 16; + constexpr int values_per_access = granularity / sizeof(T); + const int offset = (blockIdx.x * blockDim.x + threadIdx.x) * values_per_access; + + if (offset < total_count) { + T data[values_per_access]; + T data_bias[values_per_access]; + mem_access::load_global(data, input + offset); + mem_access::load_global( + data_bias, bias + (offset % intermediate_size), bias != nullptr); + +#pragma unroll + for (int i = 0; i < values_per_access; i++) { + float data_f = conversion::to(data[i]); + float bias_f = conversion::to(data_bias[i]); + data[i] = conversion::to(gelu(data_f + bias_f)); + } + + mem_access::store_global(input + offset, data); + } +} + +template +void launch_bias_gelu(T* input, + const T* bias, + int intermediate_size, + int batch_size, + cudaStream_t stream) +{ + constexpr int threads = 1024; + constexpr int granularity = 16; + + const int total_count = batch_size * intermediate_size; + const int elems_per_block = threads * (granularity / sizeof(T)); + dim3 block_dims(threads); + dim3 grid_dims((total_count + elems_per_block - 1) / elems_per_block); + + fused_bias_gelu<<>>( + input, bias, total_count, intermediate_size); +} + +#define INSTANTIATE_LAUNCH_BIAS_GELU(T) \ + template void launch_bias_gelu(T*, const T*, int, int, cudaStream_t); + +INSTANTIATE_LAUNCH_BIAS_GELU(float) +#ifdef BF16_AVAILABLE +INSTANTIATE_LAUNCH_BIAS_GELU(__nv_bfloat16) +#endif +INSTANTIATE_LAUNCH_BIAS_GELU(__half) + +/* +In-place channels-last bias add +*/ +template +__global__ void fused_bias_add(T* input, const T* bias, int total_count, int intermediate_size) +{ + // Input restriction: intermediate_size % vals_per_access == 0 + constexpr int granularity = 16; + constexpr int values_per_access = granularity / sizeof(T); + const int offset = (blockIdx.x * blockDim.x + threadIdx.x) * values_per_access; + + if (offset < total_count) { + T data[values_per_access]; + T data_bias[values_per_access]; + mem_access::load_global(data, input + offset); + mem_access::load_global( + data_bias, bias + (offset % intermediate_size), bias != nullptr); + +#pragma unroll + for (int i = 0; i < values_per_access; i++) { + float data_f = conversion::to(data[i]); + float bias_f = conversion::to(data_bias[i]); + data[i] = conversion::to(data_f + bias_f); + } + + mem_access::store_global(input + offset, data); + } +} + +template +void launch_bias_add(T* input, + const T* bias, + int intermediate_size, + int batch_size, + cudaStream_t stream) +{ + constexpr int threads = 1024; + constexpr int granularity = 16; + + const int total_count = batch_size * intermediate_size; + const int elems_per_block = threads * (granularity / sizeof(T)); + dim3 block_dims(threads); + dim3 grid_dims((total_count + elems_per_block - 1) / elems_per_block); + + fused_bias_add<<>>( + input, bias, total_count, intermediate_size); +} + +#define INSTANTIATE_LAUNCH_BIAS_ADD(T) \ + template void launch_bias_add(T*, const T*, int, int, cudaStream_t); + +INSTANTIATE_LAUNCH_BIAS_ADD(float) +#ifdef BF16_AVAILABLE +INSTANTIATE_LAUNCH_BIAS_ADD(__nv_bfloat16) +#endif +INSTANTIATE_LAUNCH_BIAS_ADD(__half) + +__global__ void fused_bias_residual(float* residual, + const float* hidden_state, + const float* attn, + const float* bias, + const float* attn_bias, + const int total_count, + const int intermediate_size, + const float mp_scale, + const bool preln) +{ + float4* res_fl4_ptr = reinterpret_cast(residual); + const float4* hs_fl4_ptr = reinterpret_cast(hidden_state); + const float4* attn_fl4_ptr = reinterpret_cast(attn); + const float4* bias_fl4_ptr = reinterpret_cast(bias); + const float4* attn_bias_fl4_ptr = reinterpret_cast(attn_bias); + const int offset = blockIdx.x * blockDim.x + threadIdx.x; + + if (offset < total_count) { + float4 res_fl4 = res_fl4_ptr[offset]; + const float4 hs_fl4 = hs_fl4_ptr[offset]; + const float4 attn_fl4 = attn_fl4_ptr[offset]; + const float4 bias_fl4 = bias_fl4_ptr[offset % intermediate_size]; + const float4 attn_bias_fl4 = attn_bias_fl4_ptr[offset % intermediate_size]; + if (preln) { + // residual = (residual + attention + bias + attention_bias) * + // mp_scale + hidden_state + res_fl4.x = + (res_fl4.x + attn_fl4.x + bias_fl4.x + attn_bias_fl4.x) * mp_scale + (hs_fl4.x); + res_fl4.y = + (res_fl4.y + attn_fl4.y + bias_fl4.y + attn_bias_fl4.y) * mp_scale + (hs_fl4.y); + res_fl4.z = + (res_fl4.z + attn_fl4.z + bias_fl4.z + attn_bias_fl4.z) * mp_scale + (hs_fl4.z); + res_fl4.w = + (res_fl4.w + attn_fl4.w + bias_fl4.w + attn_bias_fl4.w) * mp_scale + (hs_fl4.w); + } else { + // residual += hidden_state + bias + res_fl4.x = res_fl4.x + hs_fl4.x + bias_fl4.x; + res_fl4.y = res_fl4.y + hs_fl4.y + bias_fl4.y; + res_fl4.z = res_fl4.z + hs_fl4.z + bias_fl4.z; + res_fl4.w = res_fl4.w + hs_fl4.w + bias_fl4.w; + } + res_fl4_ptr[offset] = res_fl4; + } +} + +template +__global__ void fused_bias_residual(T* residual, + const T* hidden_state, + const T* attn, + const T* bias, + const T* attn_bias, + const int total_count, + const int intermediate_size, + const float mp_scale, + const bool preln) +{ + using T2 = + typename std::conditional::value, __half2, __nv_bfloat162>::type; + float2* res_fl2_ptr = reinterpret_cast(residual); + const float2* hs_fl2_ptr = reinterpret_cast(hidden_state); + const float2* attn_fl2_ptr = reinterpret_cast(attn); + const float2* bias_fl2_ptr = reinterpret_cast(bias); + const float2* attn_bias_fl2_ptr = reinterpret_cast(attn_bias); + const int offset = blockIdx.x * blockDim.x + threadIdx.x; + + if (offset < total_count) { + float2 res_fl2 = res_fl2_ptr[offset]; + const float2 hs_fl2 = hs_fl2_ptr[offset]; + const float2 attn_fl2 = attn_fl2_ptr[offset]; + const float2 bias_fl2 = bias_fl2_ptr[offset % intermediate_size]; + const float2 attn_bias_fl2 = attn_bias_fl2_ptr[offset % intermediate_size]; + + T2* res_half2 = reinterpret_cast(&res_fl2); + const T2* hs_half2 = reinterpret_cast(&hs_fl2); + const T2* attn_half2 = reinterpret_cast(&attn_fl2); + const T2* bias_half2 = reinterpret_cast(&bias_fl2); + const T2* attn_bias_half2 = reinterpret_cast(&attn_bias_fl2); + + float2 res_low = conversion::to(res_half2[0]); + float2 res_high = conversion::to(res_half2[1]); + + const float2 hs_low = conversion::to(hs_half2[0]); + const float2 hs_high = conversion::to(hs_half2[1]); + + const float2 attn_low = conversion::to(attn_half2[0]); + const float2 attn_high = conversion::to(attn_half2[1]); + + const float2 bias_low = conversion::to(bias_half2[0]); + const float2 bias_high = conversion::to(bias_half2[1]); + + const float2 attn_bias_low = conversion::to(attn_bias_half2[0]); + const float2 attn_bias_high = conversion::to(attn_bias_half2[1]); + + if (preln) { + // residual = (residual + attention + bias + attention_bias) * + // mp_scale + hidden_state + res_low.x = + (res_low.x + attn_low.x + bias_low.x + attn_bias_low.x) * mp_scale + hs_low.x; + res_low.y = + (res_low.y + attn_low.y + bias_low.y + attn_bias_low.y) * mp_scale + hs_low.y; + res_high.x = + (res_high.x + attn_high.x + bias_high.x + attn_bias_high.x) * mp_scale + hs_high.x; + res_high.y = + (res_high.y + attn_high.y + bias_high.y + attn_bias_high.y) * mp_scale + hs_high.y; + } else { + // residual += hidden_state + bias + res_low.x = (res_low.x + hs_low.x + bias_low.x); + res_low.y = (res_low.y + hs_low.y + bias_low.y); + res_high.x = (res_high.x + hs_high.x + bias_high.x); + res_high.y = (res_high.y + hs_high.y + bias_high.y); + } + res_half2[0] = conversion::to(res_low); + res_half2[1] = conversion::to(res_high); + + res_fl2_ptr[offset] = res_fl2; + } +} + +template +void launch_bias_residual(T* residual, + T* hidden_state, + T* attn, + T* bias, + T* attn_bias, + int batch, + int hidden_dim, + int mp_size, + bool preln, + cudaStream_t stream) +{ + int total_count = batch * hidden_dim / 4; + dim3 block_dims(1024); + dim3 grid_dims((total_count - 1) / 1024 + 1); // (batch_size); + + fused_bias_residual<<>>(residual, + hidden_state, + attn, + bias, + attn_bias, + total_count, + hidden_dim / 4, + 1.0 / mp_size, + preln); +} + +#define INSTANTIATE_LAUNCH_BIAS_RESIDUAL(T) \ + template void launch_bias_residual(T*, T*, T*, T*, T*, int, int, int, bool, cudaStream_t); + +INSTANTIATE_LAUNCH_BIAS_RESIDUAL(float); +#ifdef BF16_AVAILABLE +INSTANTIATE_LAUNCH_BIAS_RESIDUAL(__nv_bfloat16); +#endif +INSTANTIATE_LAUNCH_BIAS_RESIDUAL(__half); + +__global__ void gptj_residual_add(float* residual, + const float* hidden_state, + const float* attn, + const float* bias, + const float* attn_bias, + const int total_count, + const int intermediate_size, + const float mp_scale) +{ + float4* res_fl4_ptr = reinterpret_cast(residual); + const float4* hs_fl4_ptr = reinterpret_cast(hidden_state); + const float4* attn_fl4_ptr = reinterpret_cast(attn); + const float4* bias_fl4_ptr = reinterpret_cast(bias); + const float4* attn_bias_fl4_ptr = reinterpret_cast(attn_bias); + const int offset = blockIdx.x * blockDim.x + threadIdx.x; + + if (offset < total_count) { + float4 res_fl4 = res_fl4_ptr[offset]; + const float4 hs_fl4 = hs_fl4_ptr[offset]; + const float4 attn_fl4 = attn_fl4_ptr[offset]; + const float4 bias_fl4 = bias_fl4_ptr[offset % intermediate_size]; + + if (attn_bias) { + float4 attn_bias_fl4 = attn_bias_fl4_ptr[offset % intermediate_size]; + // residual += attention_bias + res_fl4.x += attn_bias_fl4.x; + res_fl4.y += attn_bias_fl4.y; + res_fl4.z += attn_bias_fl4.z; + res_fl4.w += attn_bias_fl4.w; + } + // residual = hidden_state + attention + (residual + bias) * mp_scale + res_fl4.x = hs_fl4.x + attn_fl4.x + (res_fl4.x + bias_fl4.x) * mp_scale; + res_fl4.y = hs_fl4.y + attn_fl4.y + (res_fl4.y + bias_fl4.y) * mp_scale; + res_fl4.z = hs_fl4.z + attn_fl4.z + (res_fl4.z + bias_fl4.z) * mp_scale; + res_fl4.w = hs_fl4.w + attn_fl4.w + (res_fl4.w + bias_fl4.w) * mp_scale; + + res_fl4_ptr[offset] = res_fl4; + } +} + +template +__global__ void gptj_residual_add(T* residual, + const T* hidden_state, + const T* attn, + const T* bias, + const T* attn_bias, + const int total_count, + const int intermediate_size, + const float mp_scale) +{ + using T2 = + typename std::conditional::value, __half2, __nv_bfloat162>::type; + float2* res_fl2_ptr = reinterpret_cast(residual); + const float2* hs_fl2_ptr = reinterpret_cast(hidden_state); + const float2* attn_fl2_ptr = reinterpret_cast(attn); + const float2* bias_fl2_ptr = reinterpret_cast(bias); + const float2* attn_bias_fl2_ptr = reinterpret_cast(attn_bias); + const int offset = blockIdx.x * blockDim.x + threadIdx.x; + + if (offset < total_count) { + float2 res_fl2 = res_fl2_ptr[offset]; + const float2 hs_fl2 = hs_fl2_ptr[offset]; + const float2 attn_fl2 = attn_fl2_ptr[offset]; + const float2 bias_fl2 = bias_fl2_ptr[offset % intermediate_size]; + + T2* res_half2 = reinterpret_cast(&res_fl2); + const T2* hs_half2 = reinterpret_cast(&hs_fl2); + const T2* attn_half2 = reinterpret_cast(&attn_fl2); + const T2* bias_half2 = reinterpret_cast(&bias_fl2); + + float2 res_low = conversion::to(res_half2[0]); + float2 res_high = conversion::to(res_half2[1]); + + const float2 hs_low = conversion::to(hs_half2[0]); + const float2 hs_high = conversion::to(hs_half2[1]); + + const float2 attn_low = conversion::to(attn_half2[0]); + const float2 attn_high = conversion::to(attn_half2[1]); + + const float2 bias_low = conversion::to(bias_half2[0]); + const float2 bias_high = conversion::to(bias_half2[1]); + + if (attn_bias) { + const float2 attn_bias_fl2 = attn_bias_fl2_ptr[offset % intermediate_size]; + const T2* attn_bias_half2 = reinterpret_cast(&attn_bias_fl2); + const float2 attn_bias_low = conversion::to(attn_bias_half2[0]); + const float2 attn_bias_high = conversion::to(attn_bias_half2[1]); + // residual += attention_bias + res_low.x += attn_bias_low.x; + res_low.y += attn_bias_low.y; + res_high.x += attn_bias_high.x; + res_high.y += attn_bias_high.y; + } + // residual = hidden_state + attention + (residual + bias) * mp_scale + res_low.x = attn_low.x + hs_low.x + (res_low.x + bias_low.x) * mp_scale; + res_low.y = attn_low.y + hs_low.y + (res_low.y + bias_low.y) * mp_scale; + res_high.x = attn_high.x + hs_high.x + (res_high.x + bias_high.x) * mp_scale; + res_high.y = attn_high.y + hs_high.y + (res_high.y + bias_high.y) * mp_scale; + + res_half2[0] = conversion::to(res_low); + res_half2[1] = conversion::to(res_high); + + res_fl2_ptr[offset] = res_fl2; + } +} + +template +void launch_gptj_residual_add(T* residual, + T* hidden_state, + T* attn, + T* bias, + T* attn_bias, + int hidden_dim, + int batch, + int mp_size, + cudaStream_t stream) +{ + int total_count = batch * hidden_dim / 4; + dim3 block_dims(1024); + dim3 grid_dims((total_count - 1) / 1024 + 1); // (batch_size); + + gptj_residual_add<<>>( + residual, hidden_state, attn, bias, attn_bias, total_count, hidden_dim / 4, 1.0 / mp_size); +} + +#define INSTANTIATE_GPT_RES_ADD(T) \ + template void launch_gptj_residual_add(T*, T*, T*, T*, T*, int, int, int, cudaStream_t); + +INSTANTIATE_GPT_RES_ADD(float); +INSTANTIATE_GPT_RES_ADD(__half); +#ifdef BF16_AVAILABLE +INSTANTIATE_GPT_RES_ADD(__nv_bfloat16); +#endif + +template +__global__ void moe_res_matmul(T* residual, T* coef, T* mlp_out, int seq_len, int hidden_dim) +{ + constexpr int granularity = 16; + constexpr int vals_per_access = granularity / sizeof(T); + + T* residual_seq = residual + blockIdx.x * hidden_dim; + T* mlp_out_seq = mlp_out + blockIdx.x * hidden_dim; + + for (unsigned tid = threadIdx.x * vals_per_access; tid < hidden_dim; + tid += blockDim.x * vals_per_access) { + T mlp[vals_per_access]; + T res[vals_per_access]; + T coef1[vals_per_access]; + T coef2[vals_per_access]; + + mem_access::load_global(mlp, mlp_out_seq + tid); + mem_access::load_global(res, residual_seq + tid); + mem_access::load_global(coef1, coef + tid); + mem_access::load_global(coef2, coef + tid + hidden_dim); + +#pragma unroll + for (int idx = 0; idx < vals_per_access; idx++) { + mlp[idx] = mlp[idx] * coef2[idx] + res[idx] * coef1[idx]; + } + + mem_access::store_global(mlp_out_seq + tid, mlp); + } +} + +template +void launch_moe_res_matmul(T* residual, + T* coef, + T* mlp_out, + int seq_len, + int hidden_dim, + cudaStream_t stream) +{ + dim3 grid_dim(seq_len); + dim3 block_dim(1024); + moe_res_matmul<<>>( + residual, coef, mlp_out, seq_len, hidden_dim); +} + +#define INSTANTIATE_LAUNCH_MOE_RES_MATMUL(T) \ + template void launch_moe_res_matmul(T*, T*, T*, int, int, cudaStream_t); + +INSTANTIATE_LAUNCH_MOE_RES_MATMUL(float); +#ifdef BF16_AVAILABLE +INSTANTIATE_LAUNCH_MOE_RES_MATMUL(__nv_bfloat16); +#endif +INSTANTIATE_LAUNCH_MOE_RES_MATMUL(__half); + +template +__global__ void pad_data_kernel(T* padded_output, T* output, int head_size, int padded_head_size) +{ + using T2 = + typename std::conditional::value, __half2, __nv_bfloat162>::type; + float4* padded_output_cast = reinterpret_cast(padded_output); + float4* output_cast = reinterpret_cast(output); + int bid = blockIdx.x * (blockDim.y) + threadIdx.y; + int idx = threadIdx.x; + padded_output_cast += (bid * padded_head_size); + output_cast += (bid * head_size); + float4 ZERO; + const T2 zero_h = conversion::to(0.f); + T2* ZERO_h = reinterpret_cast(&ZERO); +#pragma unroll + for (int i = 0; i < 4; i++) ZERO_h[i] = zero_h; + if (idx < head_size) + padded_output_cast[idx] = output_cast[idx]; + else + padded_output_cast[idx] = ZERO; +} + +__global__ void pad_data_kernel(float* padded_output, + float* output, + int head_size, + int padded_head_size) +{ +} + +template +void pad_data(T* padded_output, + T* output, + int bsz, + int head_size, + int padded_head_size, + cudaStream_t stream) +{ + dim3 grid_dim((bsz - 1) / 16 + 1); + dim3 block_dim(padded_head_size / 8, 16); + pad_data_kernel<<>>( + padded_output, output, head_size / 8, padded_head_size / 8); +} + +#define INSTANTIATE_PAD_DATA(T) template void pad_data(T*, T*, int, int, int, cudaStream_t stream); + +INSTANTIATE_PAD_DATA(float); +INSTANTIATE_PAD_DATA(__half); +#ifdef BF16_AVAILABLE +INSTANTIATE_PAD_DATA(__nv_bfloat16); +#endif + +template +__global__ void pad_head_seq_kernel(T* padded_output, + T* output, + int seq_len, + int padded_seq_len, + int head_size, + int padded_head_size) +{ + using T2 = + typename std::conditional::value, __half2, __nv_bfloat162>::type; + float4* padded_output_cast = reinterpret_cast(padded_output); + float4* output_cast = reinterpret_cast(output); + int bsz = blockIdx.x; + int bid = blockIdx.y * (blockDim.y) + threadIdx.y; + int idx = threadIdx.x; + padded_output_cast += (bsz * padded_seq_len + bid) * padded_head_size; + output_cast += (bsz * seq_len + bid) * head_size; + float4 ZERO; + const T2 zero_h = conversion::to(0.f); + T2* ZERO_h = reinterpret_cast(&ZERO); +#pragma unroll + for (int i = 0; i < 4; i++) ZERO_h[i] = zero_h; + + if (idx < head_size && bid < seq_len) + padded_output_cast[idx] = output_cast[idx]; + else + padded_output_cast[idx] = ZERO; +} + +__global__ void pad_head_seq_kernel(float* padded_output, + float* output, + int seq_len, + int padded_seq_len, + int head_size, + int padded_head_size) +{ +} + +template +void pad_head_seq(T* padded_output, + T* output, + int bsz, + int seq_len, + int padded_seq_len, + int head_size, + int padded_head_size, + cudaStream_t stream) +{ + dim3 grid_dim(bsz, padded_seq_len / 16); + dim3 block_dim(padded_head_size / 8, 16); + pad_head_seq_kernel<<>>( + padded_output, output, seq_len, padded_seq_len, head_size / 8, padded_head_size / 8); +} + +#define INSTANTIATE_PAD_HEAD_SEQ(T) \ + template void pad_head_seq(T*, T*, int, int, int, int, int, cudaStream_t); + +INSTANTIATE_PAD_HEAD_SEQ(__half); +#ifdef BF16_AVAILABLE +INSTANTIATE_PAD_HEAD_SEQ(__nv_bfloat16); +#endif +INSTANTIATE_PAD_HEAD_SEQ(float); + +// TODO(cmikeh2): evaluate different GeLU performance +__device__ __forceinline__ float old_gelu(float val) +{ + // 1 / sqrt(2) + constexpr float rsqrt_2 = 0.707106769084930419922; + return val * 0.5f * (1.0f + erff(val * rsqrt_2)); +} + +namespace fused_geglu { +constexpr int threads = 256; +constexpr int steps = 2; +constexpr int granularity = 16; +} // namespace fused_geglu + +__device__ __forceinline__ float silu(float val) { return val / (1.0f + expf(-val)); } + +template +__global__ void fused_gate_activation(T* output, + const T* activation, + const T* bias, + int base_channels, + int output_stride, + int total_elems) +{ + constexpr int T_per_access = fused_geglu::granularity / sizeof(T); + constexpr int T_per_step = T_per_access * fused_geglu::threads; + constexpr int T_per_block = T_per_step * fused_geglu::steps; + + const int id = blockIdx.x * T_per_block + threadIdx.x * T_per_access; + +#pragma unroll + for (int i = 0; i < fused_geglu::steps; i++) { + T activation_buffer_1[T_per_access]; + T activation_buffer_2[T_per_access]; + T bias_buffer_1[T_per_access]; + T bias_buffer_2[T_per_access]; + + const int iter_id = id + T_per_step * i; + if (iter_id < total_elems) { + const int channel_id = iter_id % base_channels; + const int seq_id = iter_id / base_channels; + const int seq_offset = seq_id * base_channels * 2; + + mem_access::load_global(activation_buffer_1, + activation + seq_offset + channel_id); + mem_access::load_global( + activation_buffer_2, activation + seq_offset + channel_id + base_channels); + mem_access::load_global( + bias_buffer_1, bias + channel_id, bias != nullptr); + mem_access::load_global( + bias_buffer_2, bias + channel_id + base_channels, bias != nullptr); + + // Since the GeLU is going to happen at float, might as well + // convert +#pragma unroll + for (int v = 0; v < T_per_access; v++) { + T hidden_state = activation_buffer_1[v] + bias_buffer_1[v]; + T pre_gate = activation_buffer_2[v] + bias_buffer_2[v]; + float pre_gate_f = conversion::to(pre_gate); + float gate_f = (useGelu) ? old_gelu(pre_gate_f) : silu(pre_gate_f); + T gate = conversion::to(gate_f); + activation_buffer_1[v] = hidden_state * gate; + } + + mem_access::store_global( + output + seq_id * output_stride + channel_id, activation_buffer_1); + } + } +} + +template +void launch_gated_activation(T* output, + const T* activation, + const T* bias, + int rows, + int output_stride, + int elems_per_row, + bool use_gelu, + cudaStream_t stream) +{ + /* + Fused bias GEGLU is a variant of the gated activation functions. + The input here is a matrix of [batch, seq_len, 2 * intermediate_dim] + where the second half of the channels act as GeLU gates for the first + half. + */ + + // Re-derive the above figures + constexpr int T_per_access = fused_geglu::granularity / sizeof(T); + constexpr int T_per_step = T_per_access * fused_geglu::threads; + constexpr int T_per_block = T_per_step * fused_geglu::steps; + + const int base_channels = elems_per_row / 2; + const int total_elems = base_channels * rows; + + dim3 block(fused_geglu::threads); + dim3 grid((total_elems + T_per_block - 1) / T_per_block); + + if (use_gelu) { + fused_gate_activation<<>>( + output, activation, bias, base_channels, output_stride, total_elems); + } else { + fused_gate_activation<<>>( + output, activation, bias, base_channels, output_stride, total_elems); + } +} + +#define INSTANTIATE_LAUNCH_GATED_ACTIVATION(T) \ + template void launch_gated_activation( \ + T*, const T*, const T*, int, int, int, bool, cudaStream_t); + +INSTANTIATE_LAUNCH_GATED_ACTIVATION(__half); +#ifdef BF16_AVAILABLE +INSTANTIATE_LAUNCH_GATED_ACTIVATION(__nv_bfloat16); +#endif +INSTANTIATE_LAUNCH_GATED_ACTIVATION(float); diff --git a/toolbox/DeepSpeed/v0.15.3/patches/csrc/transformer/inference/csrc/layer_norm.cu b/toolbox/DeepSpeed/v0.15.3/patches/csrc/transformer/inference/csrc/layer_norm.cu new file mode 100644 index 0000000000000000000000000000000000000000..67091597aaf74144c5473ee5367b3efc31b89bd9 --- /dev/null +++ b/toolbox/DeepSpeed/v0.15.3/patches/csrc/transformer/inference/csrc/layer_norm.cu @@ -0,0 +1,520 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. */ +/* All Rights Reserved. */ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "conversion_utils.h" +#include "ds_kernel_utils.h" +#include "inference_cuda_layers.h" +#include "memory_access_utils.h" +#include "reduction_utils.h" + +namespace cg = cooperative_groups; +using rop = reduce::ROpType; + +namespace ln { +constexpr int granularity = 16; +} // namespace ln + +/* +Primary layer norm implementation. Assumes elems_per_row % 8 +is equal to 0. + +Args: + output: buffer for output data + vals: buffer for input data + gamma: gain for normalization + beta: bias for normalization + epsilon: numeric stability + elems_per_row: number of elements each block will normalize +*/ +template +__global__ void fused_ln(T* output, + const T* vals, + const T* gamma, + const T* beta, + float epsilon, + int elems_per_row) +{ + constexpr int T_per_load = ln::granularity / sizeof(T); + + cg::thread_block tb = cg::this_thread_block(); + cg::thread_block_tile warp = cg::tiled_partition(tb); + + // X-dimension of the block + const int block_offset = (tb.group_index().x * (maxThreads / threadsPerGroup) * elems_per_row) + + (tb.thread_index().y * elems_per_row); + const int thread_offset = tb.thread_index().x * T_per_load; + const int base_offset = block_offset + thread_offset; + const int stride = blockDim.x * T_per_load; + + float sum = reduce::init(); + + const T* input_base = vals + base_offset; + + T local_buffer[unRoll * T_per_load]; + +#pragma unRoll + for (int i = 0; i < unRoll; i++) { + T* iteration_buffer = local_buffer + i * T_per_load; + + mem_access::load_global( + iteration_buffer, input_base + i * stride, thread_offset + i * stride < elems_per_row); + +#pragma unRoll + for (int j = 0; j < T_per_load; j++) { + float vals_up_cast = conversion::to(iteration_buffer[j]); + sum = reduce::element(sum, vals_up_cast); + } + } + + reduce::partitioned_block(tb, warp, sum); + const float mean = sum / elems_per_row; + + float mean_diff = reduce::init(); + +#pragma unRoll + for (int i = 0; i < unRoll; i++) { +#pragma unRoll + for (int j = 0; j < T_per_load; j++) { + // Using a 0 value here skews the variance, have to if-guard + if (thread_offset + i * stride < elems_per_row) { + float diff = (conversion::to(local_buffer[i * T_per_load + j]) - mean); + mean_diff = reduce::element(mean_diff, diff * diff); + } + } + } + + reduce::partitioned_block(tb, warp, mean_diff); + const float variance = mean_diff / elems_per_row; + const float denom = __frsqrt_rn(variance + epsilon); + + // const T mean_compute = conversion::to(mean); + // const T denom_compute = conversion::to(denom); + + T* block_output = output + block_offset; + +#pragma unRoll + for (int i = 0; i < unRoll; i++) { + T* iteration_buffer = local_buffer + i * T_per_load; + const int iter_idx = i * stride + thread_offset; + const bool do_loads = iter_idx < elems_per_row; + + T gamma_local[T_per_load], beta_local[T_per_load]; + + mem_access::load_global(gamma_local, gamma + iter_idx, do_loads); + mem_access::load_global(beta_local, beta + iter_idx, do_loads); + +#pragma unRoll + for (int j = 0; j < T_per_load; j++) { + float val = conversion::to(iteration_buffer[j]); + val = (val - mean) * denom; + val = + val * conversion::to(gamma_local[j]) + conversion::to(beta_local[j]); + iteration_buffer[j] = conversion::to(val); + } + + if (do_loads) { + mem_access::store_global(block_output + iter_idx, iteration_buffer); + } + } +} + +#define LAUNCH_FUSED_LN(unRollFactor, threadsPerGroup, maxThreads) \ + fused_ln \ + <<>>(output, vals, gamma, beta, epsilon, elems_per_row); + +template +void launch_fused_ln(T* output, + const T* vals, + const T* gamma, + const T* beta, + float epsilon, + int rows, + int elems_per_row, + cudaStream_t stream) +{ + // 8 for __half, 4 for float + constexpr int T_per_load = ln::granularity / sizeof(T); + + constexpr int maxThreads = 256; + + // For Flaoat, unRoll 4, for __half, unRoll 2 + constexpr int internal_unRoll = sizeof(T) == 4 ? 4 : 2; + + const bool is_subblock_schedule = (elems_per_row <= 128) ? true : false; + const int h_per_step = is_subblock_schedule ? T_per_load : T_per_load * internal_unRoll; + + // Scheduling concern: may be slightly faster for some inputs to assign multiple stages of + // warp-sized blocks rather than stepping up to 64/96 threads + const int one_step_threads = next_pow2((elems_per_row + h_per_step - 1) / h_per_step); + const int threadsPerGroup = (one_step_threads < maxThreads) ? one_step_threads : maxThreads; + + const int groups_per_block_max = + is_subblock_schedule ? (maxThreads + threadsPerGroup - 1) / threadsPerGroup : 1; + const int groups_per_block = (rows < groups_per_block_max) ? rows : groups_per_block_max; + const int groups_launch = (groups_per_block + rows - 1) / groups_per_block; + + dim3 block(threadsPerGroup, groups_per_block); + dim3 grid(groups_launch); + + const int elems_per_step = threadsPerGroup * h_per_step; + const int external_unRoll = (elems_per_row + elems_per_step - 1) / elems_per_step; + + if (is_subblock_schedule) { + // <=128 + if (threadsPerGroup == 1) { + LAUNCH_FUSED_LN(1, 1, maxThreads); + } else if (threadsPerGroup == 2) { + LAUNCH_FUSED_LN(1, 2, maxThreads); + } else if (threadsPerGroup == 4) { + LAUNCH_FUSED_LN(1, 4, maxThreads); + } else if (threadsPerGroup == 8) { + LAUNCH_FUSED_LN(1, 8, maxThreads); + } else if (threadsPerGroup == 16) { + LAUNCH_FUSED_LN(1, 16, maxThreads); + } + } else if (external_unRoll == 1) { + // 129 - 4096 elems + // (this can launch with 1-7 warps as well) + LAUNCH_FUSED_LN(1 * internal_unRoll, maxThreads, maxThreads); + } else if (external_unRoll == 2) { + // 4097 - 8192 elems + LAUNCH_FUSED_LN(2 * internal_unRoll, maxThreads, maxThreads); + } else if (external_unRoll == 3) { + // 8193 - 12288 elems + LAUNCH_FUSED_LN(3 * internal_unRoll, maxThreads, maxThreads); + } else if (external_unRoll == 4) { + // 12289 - 16384 elems + LAUNCH_FUSED_LN(4 * internal_unRoll, maxThreads, maxThreads); + } +} + +#define INSTANTIATE_FUSED_LN(T) \ + template void launch_fused_ln(T*, const T*, const T*, const T*, float, int, int, cudaStream_t); + +INSTANTIATE_FUSED_LN(__half); +#ifdef BF16_AVAILABLE +INSTANTIATE_FUSED_LN(__nv_bfloat16); +#endif +INSTANTIATE_FUSED_LN(float); + +/* +Fused resiual + bias + layer norm implementation. Assumes elems_per_row % 8 +is equal to 0. + +TODO(cmikeh2): Goal is to deprecate this implementation. The bias + residual +need to be fused into compute-bound producer operations. + +Args: + output: buffer for output data + res_output: output of residual addition + vals: buffer for input data + residual: residual data + bias: bias of of input data + gamma: gain for normalization + beta: bias for normalization + epsilon: numeric stability + elems_per_row: number of elements each block will normalize +Template arg: + StoreResidual: controls whether the residual calculation is stored + or not. When set to false, the input `res_output` is unused. +*/ +template +__global__ void fused_residual_ln(T* output, + T* res_output, + const T* vals, + const T* residual, + const T* bias, + const T* gamma, + const T* beta, + float epsilon, + int elems_per_row) +{ + constexpr int T_per_load = ln::granularity / sizeof(T); + + cg::thread_block tb = cg::this_thread_block(); + cg::thread_block_tile warp = cg::tiled_partition(tb); + + // X-dimension of the block + const int block_offset = (tb.group_index().x * (maxThreads / threadsPerGroup) * elems_per_row) + + (tb.thread_index().y * elems_per_row); + const int thread_offset = tb.thread_index().x * T_per_load; + const int base_offset = block_offset + thread_offset; + const int stride = tb.size() * T_per_load; + + float sum = reduce::init(); + + const T* input_base = vals + base_offset; + const T* residual_base = residual + base_offset; + const T* bias_base = bias + thread_offset; + + T local_buffer[unRoll * T_per_load]; + + // Unlike a vanilla layernorm, since we're fusing the two adds as well + // an inner unRoll seems to be less valuable. If anything, a double unRoll + // makes the most sense if we find we are having performance issues. +#pragma unRoll + for (int i = 0; i < unRoll; i++) { + T* iteration_buffer = local_buffer + i * T_per_load; + T residual_buffer[T_per_load]; + T bias_buffer[T_per_load]; + + mem_access::load_global( + iteration_buffer, input_base + i * stride, thread_offset + i * stride < elems_per_row); + mem_access::load_global(residual_buffer, + residual_base + i * stride, + thread_offset + i * stride < elems_per_row); + mem_access::load_global( + bias_buffer, bias_base + i * stride, thread_offset + i * stride < elems_per_row); + +#pragma unRoll + for (int j = 0; j < T_per_load; j++) { + float vals_up_cast = conversion::to(iteration_buffer[j]); + float res_up_cast = conversion::to(residual_buffer[j]); + float bias_up_cast = conversion::to(bias_buffer[j]); + vals_up_cast = vals_up_cast + bias_up_cast + res_up_cast; + sum = reduce::element(sum, vals_up_cast); + iteration_buffer[j] = conversion::to(vals_up_cast); + } + + if (preLnResidual && (thread_offset + i * stride < elems_per_row)) { + mem_access::store_global(res_output + base_offset + i * stride, + iteration_buffer); + } + } + + reduce::partitioned_block(tb, warp, sum); + const float mean = sum / elems_per_row; + + float mean_diff = reduce::init(); +#pragma unRoll + for (int i = 0; i < unRoll; i++) { +#pragma unRoll + for (int j = 0; j < T_per_load; j++) { + // Using a 0 value here skews the variance, have to if-guard + if (thread_offset + i * stride < elems_per_row) { + float diff = (conversion::to(local_buffer[i * T_per_load + j]) - mean); + mean_diff = reduce::element(mean_diff, diff * diff); + } + } + } + + reduce::partitioned_block(tb, warp, mean_diff); + const float variance = mean_diff / elems_per_row; + const float denom = __frsqrt_rn(variance + epsilon); + + T* block_output = output + block_offset; + +#pragma unRoll + for (int i = 0; i < unRoll; i++) { + T* iteration_buffer = local_buffer + i * T_per_load; + const int iter_idx = i * stride + thread_offset; + const bool do_loads = iter_idx < elems_per_row; + + T gamma_local[T_per_load], beta_local[T_per_load]; + + mem_access::load_global(gamma_local, gamma + iter_idx, do_loads); + mem_access::load_global(beta_local, beta + iter_idx, do_loads); + +#pragma unRoll + for (int j = 0; j < T_per_load; j++) { + // iteration_buffer[j] = (iteration_buffer[j] - mean_compute) * denom_compute; + // iteration_buffer[j] = iteration_buffer[j] * gamma_local[j] + beta_local[j]; + float val = conversion::to(iteration_buffer[j]); + val = (val - mean) * denom; + val = + val * conversion::to(gamma_local[j]) + conversion::to(beta_local[j]); + iteration_buffer[j] = conversion::to(val); + } + + if (do_loads) { + mem_access::store_global(block_output + iter_idx, iteration_buffer); + } + } +} + +// TODO(cmikeh2): There's a bunch of redundancy here that needs to be removed/simplified. +#define LAUNCH_FUSED_RES_LN(unRollFactor, threadsPerGroup, maxThreads) \ + fused_residual_ln \ + <<>>( \ + output, nullptr, vals, residual, bias, gamma, beta, epsilon, elems_per_row); + +template +void launch_fused_residual_ln(T* output, + const T* vals, + const T* residual, + const T* bias, + const T* gamma, + const T* beta, + float epsilon, + int rows, + int elems_per_row, + cudaStream_t stream) +{ + // 8 for __half, 4 for float + constexpr int T_per_load = ln::granularity / sizeof(T); + + constexpr int maxThreads = 256; + + // For Flaoat, unRoll 4, for __half, unRoll 2 + constexpr int internal_unRoll = sizeof(T) == 4 ? 4 : 2; + + const bool is_subblock_schedule = (elems_per_row <= 128) ? true : false; + const int h_per_step = is_subblock_schedule ? T_per_load : T_per_load * internal_unRoll; + + // Scheduling concern: may be slightly faster for some inputs to assign multiple stages of + // warp-sized blocks rather than stepping up to 64/96 threads + const int one_step_threads = next_pow2((elems_per_row + h_per_step - 1) / h_per_step); + const int threadsPerGroup = (one_step_threads < maxThreads) ? one_step_threads : maxThreads; + + const int groups_per_block_max = + is_subblock_schedule ? (maxThreads + threadsPerGroup - 1) / threadsPerGroup : 1; + const int groups_per_block = (rows < groups_per_block_max) ? rows : groups_per_block_max; + const int groups_launch = (groups_per_block + rows - 1) / groups_per_block; + + dim3 block(threadsPerGroup, groups_per_block); + dim3 grid(groups_launch); + + const int elems_per_step = threadsPerGroup * h_per_step; + const int external_unRoll = (elems_per_row + elems_per_step - 1) / elems_per_step; + + if (is_subblock_schedule) { + // <=128 + if (threadsPerGroup == 1) { + LAUNCH_FUSED_RES_LN(1, 1, maxThreads); + } else if (threadsPerGroup == 2) { + LAUNCH_FUSED_RES_LN(1, 2, maxThreads); + } else if (threadsPerGroup == 4) { + LAUNCH_FUSED_RES_LN(1, 4, maxThreads); + } else if (threadsPerGroup == 8) { + LAUNCH_FUSED_RES_LN(1, 8, maxThreads); + } else if (threadsPerGroup == 16) { + LAUNCH_FUSED_RES_LN(1, 16, maxThreads); + } + } else if (external_unRoll == 1) { + // 129 - 4096 elems + // (this can launch with 1-7 warps as well) + LAUNCH_FUSED_RES_LN(1 * internal_unRoll, maxThreads, maxThreads); + } else if (external_unRoll == 2) { + // 4097 - 8192 elems + LAUNCH_FUSED_RES_LN(2 * internal_unRoll, maxThreads, maxThreads); + } else if (external_unRoll == 3) { + // 8193 - 12288 elems + LAUNCH_FUSED_RES_LN(3 * internal_unRoll, maxThreads, maxThreads); + } else if (external_unRoll == 4) { + // 12289 - 16384 elems + LAUNCH_FUSED_RES_LN(4 * internal_unRoll, maxThreads, maxThreads); + } +} + +#define LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(unRollFactor, threadsPerGroup, maxThreads) \ + fused_residual_ln \ + <<>>( \ + norm_output, res_output, vals, residual, bias, gamma, beta, epsilon, elems_per_row); + +template +void launch_fused_residual_ln_store_pre_ln_res(T* norm_output, + T* res_output, + const T* vals, + const T* residual, + const T* bias, + const T* gamma, + const T* beta, + float epsilon, + int rows, + int elems_per_row, + cudaStream_t stream) +{ + // 8 for __half, 4 for float + constexpr int T_per_load = ln::granularity / sizeof(T); + + constexpr int maxThreads = 256; + + // For Flaoat, unRoll 4, for __half, unRoll 2 + constexpr int internal_unRoll = sizeof(T) == 4 ? 4 : 2; + + const bool is_subblock_schedule = (elems_per_row <= 128) ? true : false; + const int h_per_step = is_subblock_schedule ? T_per_load : T_per_load * internal_unRoll; + + // Scheduling concern: may be slightly faster for some inputs to assign multiple stages of + // warp-sized blocks rather than stepping up to 64/96 threads + const int one_step_threads = next_pow2((elems_per_row + h_per_step - 1) / h_per_step); + const int threadsPerGroup = (one_step_threads < maxThreads) ? one_step_threads : maxThreads; + + const int groups_per_block_max = + is_subblock_schedule ? (maxThreads + threadsPerGroup - 1) / threadsPerGroup : 1; + const int groups_per_block = (rows < groups_per_block_max) ? rows : groups_per_block_max; + const int groups_launch = (groups_per_block + rows - 1) / groups_per_block; + + dim3 block(threadsPerGroup, groups_per_block); + dim3 grid(groups_launch); + + const int elems_per_step = threadsPerGroup * h_per_step; + const int external_unRoll = (elems_per_row + elems_per_step - 1) / elems_per_step; + + if (is_subblock_schedule) { + // <=128 + if (threadsPerGroup == 1) { + LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(1, 1, maxThreads); + } else if (threadsPerGroup == 2) { + LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(1, 2, maxThreads); + } else if (threadsPerGroup == 4) { + LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(1, 4, maxThreads); + } else if (threadsPerGroup == 8) { + LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(1, 8, maxThreads); + } else if (threadsPerGroup == 16) { + LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(1, 16, maxThreads); + } + } else if (external_unRoll == 1) { + // 129 - 4096 elems + // (this can launch with 1-7 warps as well) + LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(1 * internal_unRoll, maxThreads, maxThreads); + } else if (external_unRoll == 2) { + // 4097 - 8192 elems + LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(2 * internal_unRoll, maxThreads, maxThreads); + } else if (external_unRoll == 3) { + // 8193 - 12288 elems + LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(3 * internal_unRoll, maxThreads, maxThreads); + } else if (external_unRoll == 4) { + // 12289 - 16384 elems + LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(4 * internal_unRoll, maxThreads, maxThreads); + } +} + +#define INSTANTIATE_RES_LN(T) \ + template void launch_fused_residual_ln( \ + T*, const T*, const T*, const T*, const T*, const T*, float, int, int, cudaStream_t); + +#define INSTANTIATE_PRE_LN_RES(T) \ + template void launch_fused_residual_ln_store_pre_ln_res( \ + T*, T*, const T*, const T*, const T*, const T*, const T*, float, int, int, cudaStream_t); + +INSTANTIATE_RES_LN(__half); +INSTANTIATE_RES_LN(float); +#ifdef BF16_AVAILABLE +INSTANTIATE_RES_LN(__nv_bfloat16); +#endif + +INSTANTIATE_PRE_LN_RES(__half); +INSTANTIATE_PRE_LN_RES(float); +#ifdef BF16_AVAILABLE +INSTANTIATE_PRE_LN_RES(__nv_bfloat16); +#endif diff --git a/toolbox/DeepSpeed/v0.15.3/patches/csrc/transformer/inference/csrc/pointwise_ops.cu b/toolbox/DeepSpeed/v0.15.3/patches/csrc/transformer/inference/csrc/pointwise_ops.cu new file mode 100644 index 0000000000000000000000000000000000000000..209005bb4a232f0998bf56497d275a617382dc65 --- /dev/null +++ b/toolbox/DeepSpeed/v0.15.3/patches/csrc/transformer/inference/csrc/pointwise_ops.cu @@ -0,0 +1,91 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. */ +/* All Rights Reserved. */ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include +#include "conversion_utils.h" +#include "ds_kernel_utils.h" +#include "memory_access_utils.h" + +namespace pwise { +constexpr int granularity = 16; +constexpr int unroll = 4; +constexpr int threads = 256; +} // namespace pwise + +template +__global__ void vector_add_kernel(T* out, const T* a, const T* b, float gamma, int num_elems) +{ + constexpr int T_per_access = pwise::granularity / sizeof(T); + + const int block_offset = blockIdx.x * pwise::threads * pwise::unroll * T_per_access; + const int thread_offset = threadIdx.x * T_per_access; + const int total_offset = block_offset + thread_offset; + constexpr int stride = pwise::threads * T_per_access; + +#pragma unroll + for (int i = 0; i < pwise::unroll; i++) { + T temp_buf_a[T_per_access], temp_buf_b[T_per_access]; + + const int iter_idx = total_offset + i * stride; + + mem_access::load_global(temp_buf_a, a + iter_idx, iter_idx < num_elems); + mem_access::load_global(temp_buf_b, b + iter_idx, iter_idx < num_elems); + +#pragma unroll + for (int j = 0; j < T_per_access; j++) { + float up_cast_a = conversion::to(temp_buf_a[j]); + float up_cast_b = conversion::to(temp_buf_b[j]); + temp_buf_a[j] = conversion::to((gamma * up_cast_a) + up_cast_b); + } + + if (iter_idx < num_elems) { + mem_access::store_global(out + iter_idx, temp_buf_a); + } + } +} + +template +void launch_vector_add(T* out, + const T* a, + const T* b, + float gamma, + int num_elems, + cudaStream_t stream) +{ + constexpr int T_per_access = pwise::granularity / sizeof(T); + constexpr int T_per_block = pwise::threads * T_per_access * pwise::unroll; + + dim3 block(pwise::threads); + dim3 grid((num_elems + T_per_block - 1) / T_per_block); + + vector_add_kernel<<>>(out, a, b, gamma, num_elems); +} + +#define INSTANTIATE_VECTOR_ADD(T) \ + template void launch_vector_add( \ + T * out, const T* a, const T* b, float gamma, int num_elems, cudaStream_t stream); + +INSTANTIATE_VECTOR_ADD(float) +INSTANTIATE_VECTOR_ADD(__half) +#ifdef BF16_AVAILABLE +INSTANTIATE_VECTOR_ADD(__nv_bfloat16) +#endif diff --git a/toolbox/DeepSpeed/v0.15.3/patches/csrc/transformer/inference/csrc/pt_binding.cpp b/toolbox/DeepSpeed/v0.15.3/patches/csrc/transformer/inference/csrc/pt_binding.cpp new file mode 100644 index 0000000000000000000000000000000000000000..22c9bc0884db0c7dddc7a71e524db63c15cdfa73 --- /dev/null +++ b/toolbox/DeepSpeed/v0.15.3/patches/csrc/transformer/inference/csrc/pt_binding.cpp @@ -0,0 +1,2061 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. */ +/* All Rights Reserved. */ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include +#include +#include +#include +#include "inference_context.h" +#include "inference_cublas_wrappers.h" +#include "inference_cuda_layers.h" + +std::array gemm_algos = std::array({99, 99, 99}); + +// NOTE: This activation function type enum should be always in sync +// with the python counterpart, otherwise the casting from python binding +// will be incorrect. +enum class ActivationFuncType { UNKNOWN = 0, GELU = 1, ReLU = 2, GATED_GELU = 3, GATED_SILU = 4 }; + +enum class NormType { UNKNOWN = 0, LayerNorm = 1, GroupNorm = 2, RMSNorm = 3 }; + +enum class TransformerType : uint8_t { UNKNOWN = 0, GPTType = 1, BERTType = 2 }; + +// NOTE: this is a temporary and dodgy solution to distinguish GPT and BERT style models +// based on the dimensions of the corresponding attention mask. +inline auto infer_transformer_type(at::Tensor& attn_mask) -> TransformerType +{ + auto attn_mask_num_dims = attn_mask.sizes().size(); + + if (attn_mask_num_dims > 2) { + return TransformerType::GPTType; + } else if (attn_mask_num_dims == 2) { + return TransformerType::BERTType; + } else { + return TransformerType::UNKNOWN; + } +} + +// infer stride of attention mask memory layout based on the model type. +inline auto get_attn_mask_stride(at::Tensor& attn_mask) -> int +{ + auto trnsfrmr_type = infer_transformer_type(attn_mask); + + if (trnsfrmr_type == TransformerType::GPTType) { + return attn_mask.size(2); + } else if (trnsfrmr_type == TransformerType::BERTType) { + // Bert style models have always a mask stride of 1. + return 1; + } else if (trnsfrmr_type == TransformerType::UNKNOWN) { + return 0; + } + + // this is just to make the compiler happy. + return 0; +} + +template +at::Tensor ds_softmax(at::Tensor& attn_scores, + at::Tensor& attn_mask, + at::Tensor& alibi, + bool triangular, + bool recompute, + bool local_attention, + int window_size, + bool async_op, + float layer_scale, + int head_offset, + int mp_size) +{ + auto attn_scores_c = attn_scores.contiguous(); + int bsz = attn_scores_c.size(0); + + int seq_len = attn_scores_c.size(1); + int len = attn_scores_c.sizes().size(); + if (len > 2) seq_len = attn_scores_c.size(2); + + int soft_len = attn_scores_c.size(2); + if (len > 3) soft_len = attn_scores_c.size(3); + + int heads = 1; + if (len > 1) heads = attn_scores_c.size(1); + + auto mask_stride = get_attn_mask_stride(attn_mask); + + launch_attn_softmax_v2((T*)attn_scores_c.data_ptr(), + (attn_mask.sizes().size() > 1 ? (T*)attn_mask.data_ptr() : nullptr), + (alibi.sizes().size() > 1 ? (T*)alibi.data_ptr() : nullptr), + layer_scale, + triangular, + recompute, + local_attention, + window_size, + bsz, + heads, + seq_len, + soft_len, + head_offset, + mask_stride, + mp_size, + InferenceContext::Instance().GetCurrentStream(async_op)); + + return attn_scores_c; +} + +template +void allocate_workspace(unsigned hidden_dim, + unsigned num_heads, + unsigned prompt_length, + unsigned batch_size, + unsigned num_layers, + unsigned mp_size = 1, + bool external_cache = false, + unsigned rank = 0, + unsigned max_out_tokens = 1024, + unsigned min_out_tokens = 1) +{ + InferenceContext::Instance().GenWorkSpace(num_layers, + num_heads, + batch_size, + prompt_length, + hidden_dim, + mp_size, + external_cache, + sizeof(T), + rank, + max_out_tokens, + min_out_tokens); +} + +template +at::Tensor einsum_sec_sm_ecm(at::Tensor& Q, at::Tensor& W) +{ + auto options = at::TensorOptions() + .dtype(Q.options().dtype()) + .layout(at::kStrided) + .device(at::kCUDA) + .requires_grad(false); + T* workspace = (T*)InferenceContext::Instance().GetWorkSpace(); + float alpha = 1; + float gemm_beta = 0.0; + + /* + // Reallocate memory if we received a new prompt + if (!workspace || input.size(1) != 1) { + allocate_workspace(W.size(1), InferenceContext::Instance().GetMaxTokenLength(), + Q.size(0), 1, head_size); workspace = (T*)InferenceContext::Instance().GetWorkSpace(); + } + */ + + auto O = at::from_blob(workspace, {Q.size(1), Q.size(2), W.size(1)}, options); + unsigned m = W.size(1); + unsigned n = Q.size(1) * Q.size(2); + unsigned k = Q.size(0); + cublas_gemm_ex(InferenceContext::Instance().GetCublasHandle(), + CUBLAS_OP_N, + CUBLAS_OP_T, + m, + n, + k, + &alpha, + &gemm_beta, + (T*)W.data_ptr(), + (T*)Q.data_ptr(), + (T*)O.data_ptr(), +// TODO HIP: Remove backward compatibility for torch<=2.0 in future +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) + rocblas_gemm_algo_standard); +#else + CUBLAS_GEMM_DEFAULT_TENSOR_OP); +#endif + return O; +} + +template +void attention_unfused(at::Tensor& prev_key_cont, + at::Tensor& query_cont, + at::Tensor& attn_mask, + at::Tensor& prev_value_cont, + at::Tensor& output, + int& bsz, + int& seq_len, + int& soft_len, + int& heads, + float& norm_factor, + bool triangular, + bool recompute, + bool local_attention, + int window_size) +{ + auto options = at::TensorOptions() + .dtype(query_cont.options().dtype()) + .layout(at::kStrided) + .device(at::kCUDA) + .requires_grad(false); + float alpha = norm_factor; + float gemm_beta = 0.0; + auto attn_score = at::empty({bsz, heads, seq_len, soft_len}, options); + int k = prev_value_cont.size(2) / heads; + + auto mask_stride = get_attn_mask_stride(attn_mask); + + cublasSetStream(InferenceContext::Instance().GetCublasHandle(), + InferenceContext::Instance().GetCurrentStream()); + cublas_strided_batched_gemm(InferenceContext::Instance().GetCublasHandle(), + soft_len, + seq_len, + k, + &alpha, + &gemm_beta, + (T*)prev_key_cont.data_ptr(), + (T*)query_cont.data_ptr(), + (T*)attn_score.data_ptr(), + CUBLAS_OP_N, + CUBLAS_OP_N, + soft_len * k, + seq_len * k, + seq_len * soft_len, + bsz * heads, +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) + rocblas_gemm_algo_standard); +#else + CUBLAS_GEMM_DEFAULT_TENSOR_OP); +#endif + launch_attn_softmax_v2((T*)attn_score.data_ptr(), + (T*)(attn_mask.sizes().size() > 1 ? attn_mask.data_ptr() : nullptr), + (T*)nullptr, + 1.0, + triangular, + recompute, + local_attention, + window_size, + bsz, + heads, + seq_len, + soft_len, + 0, + mask_stride, + 1, + InferenceContext::Instance().GetCurrentStream(false)); + alpha = 1.0; + cublas_strided_batched_gemm(InferenceContext::Instance().GetCublasHandle(), + k, + seq_len, + soft_len, + &alpha, + &gemm_beta, + (T*)prev_value_cont.data_ptr(), + (T*)attn_score.data_ptr(), + (T*)output.data_ptr(), + CUBLAS_OP_N, + CUBLAS_OP_N, + soft_len * k, + seq_len * soft_len, + seq_len * k, + bsz * heads, +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) + rocblas_gemm_algo_standard); +#else + CUBLAS_GEMM_DEFAULT_TENSOR_OP); +#endif +} + +template +std::vector ds_softmax_context1(at::Tensor& query, + at::Tensor& prev_key, + at::Tensor& new_key, + at::Tensor& attn_mask, + at::Tensor& prev_value, + at::Tensor& new_value, + int heads, + float norm_factor, + bool merging, + bool triangular, + bool local_attention, + int window_size, + bool no_masking) +{ + auto query_cont = query.contiguous(); + auto prev_key_cont = prev_key.contiguous(); + auto prev_value_cont = prev_value.contiguous(); + + int new_size = (new_value.sizes().size() > 1 ? new_value.size(1) : 0); + + // Attn_Score [ batch Head Sequence-length Softmax-length] + + int bsz = query_cont.size(0); + int seq_len = query_cont.size(1); + int soft_len = prev_value.size(1); + + auto options = at::TensorOptions() + .dtype(query_cont.options().dtype()) + .layout(at::kStrided) + .device(at::kCUDA) + .requires_grad(false); + + auto output = + at::empty({prev_value.size(0), heads, seq_len, prev_value.size(2) / heads}, options); + attention_unfused(prev_key_cont, + query_cont, + attn_mask, //(no_masking ? nullptr : (T*)attn_mask.data_ptr()), + prev_value_cont, + output, + bsz, + seq_len, + soft_len, + heads, + norm_factor, + (triangular && (new_size == 0)), + (new_size == 0), + local_attention, + window_size); + + return {output, prev_key, prev_value}; +} + +template +void ds_softmax_internal(T* attn_scores, + at::Tensor& attn_mask, + at::Tensor& alibi, + float& layer_scale, + bool triangular, + bool recompute, + bool local_attention, + int window_size, + int bsz, + int seq_len, + int soft_len, + int heads) +{ + auto mask_stride = get_attn_mask_stride(attn_mask); + + launch_attn_softmax_v2((T*)attn_scores, + (attn_mask.sizes().size() > 1 ? (T*)attn_mask.data_ptr() : nullptr), + (alibi.sizes().size() > 1 ? (T*)alibi.data_ptr() : nullptr), + layer_scale, + triangular, + recompute, + local_attention, + window_size, + bsz, + heads, + seq_len, + soft_len, + 0, + mask_stride, + 1, + at::cuda::getCurrentCUDAStream()); +} + +template +void attention_unfused(T* prev_key_cont, + T* query_cont, + at::Tensor& attn_mask, + T* prev_value_cont, + T* output, + unsigned& bsz, + int& k, + unsigned& seq_len, + unsigned& soft_len, + int& heads, + float& norm_factor, + bool triangular, + bool recompute, + bool local_attention, + int window_size, + at::Tensor& alibi, + int layer_id) +{ + float layer_scale = alibi.sizes().size() > 1 ? std::max(1, layer_id) : 1.0; + float alpha = norm_factor * norm_factor / layer_scale; + float gemm_beta = 0.0; + T* workspace = (T*)InferenceContext::Instance().GetAttentionUnfusedWorkspace(); + + cublasSetStream(InferenceContext::Instance().GetCublasHandle(), + InferenceContext::Instance().GetCurrentStream()); + cublas_strided_batched_gemm(InferenceContext::Instance().GetCublasHandle(), + soft_len, + seq_len, + k, + &alpha, + &gemm_beta, + (T*)prev_key_cont, + (T*)query_cont, + workspace, + CUBLAS_OP_T, + CUBLAS_OP_N, + InferenceContext::Instance().GetMaxTokenLength() * k, + seq_len * k, + seq_len * soft_len, + bsz * heads, +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) + rocblas_gemm_algo_standard); +#else + CUBLAS_GEMM_DEFAULT_TENSOR_OP); +#endif + ds_softmax_internal(workspace, + attn_mask, + alibi, + layer_scale, + triangular, + recompute, + local_attention, + window_size, + bsz, + seq_len, + soft_len, + heads); + alpha = 1.0; + cublas_strided_batched_gemm(InferenceContext::Instance().GetCublasHandle(), + k, + seq_len, + soft_len, + &alpha, + &gemm_beta, + (T*)prev_value_cont, + workspace, + (T*)output, + CUBLAS_OP_N, + CUBLAS_OP_N, + InferenceContext::Instance().GetMaxTokenLength() * k, + seq_len * soft_len, + seq_len * k, + bsz * heads, +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) + rocblas_gemm_algo_standard); +#else + CUBLAS_GEMM_DEFAULT_TENSOR_OP); +#endif +} + +void reset_cache() { InferenceContext::Instance().reset_tokens(); } + +template +std::vector ds_softmax_context(at::Tensor& query_key_value, + at::Tensor& attn_mask, + int rotary_dim, + bool rotate_half, + bool rotate_every_two, + int heads, + int num_kv, + float norm_factor, + bool triangular, + bool local_attention, + int window_size, + bool no_masking, + unsigned layer_id, + unsigned num_layers, + at::Tensor& alibi, + float rope_theta, + bool is_prompt, + std::optional token_idx, + std::optional position_ids) +{ + unsigned bsz = query_key_value.size(0); + unsigned seq_len = query_key_value.size(1); + int k = query_key_value.size(2) / (heads + 2 * (num_kv > 0 ? num_kv : heads)); + unsigned hidden_dim = heads * k; + + is_prompt = (seq_len > 1); + + if (is_prompt) InferenceContext::Instance().reset_tokens(seq_len); + unsigned soft_len = InferenceContext::Instance().current_tokens(); + + auto options = at::TensorOptions() + .dtype(query_key_value.options().dtype()) + .layout(at::kStrided) + .device(at::kCUDA) + .requires_grad(false); + + T* workspace = (T*)InferenceContext::Instance().GetWorkSpace(); + size_t buf_size = bsz * seq_len * hidden_dim; + auto output = torch::from_blob(workspace + 4 * buf_size, {bsz, seq_len, hidden_dim}, options); + + auto query_cont = workspace + 5 * buf_size; + size_t offset = + 10 * (hidden_dim * bsz * InferenceContext::Instance().GetMaxTokenLength()) + + layer_id * 2 * bsz * InferenceContext::Instance().GetMaxTokenLength() * hidden_dim; + unsigned all_tokens = soft_len; + auto kv_cache = workspace + offset + (hidden_dim / heads) * (is_prompt ? 0 : soft_len - 1); + size_t value_offset = bsz * InferenceContext::Instance().GetMaxTokenLength() * hidden_dim; + + T* temp_buf = (T*)output.data_ptr() + at::numel(output); + launch_bias_add_transform_0213((T*)query_cont, + kv_cache, + kv_cache + value_offset, + (T*)query_key_value.data_ptr(), + nullptr, + bsz, + seq_len, + (is_prompt ? 0 : soft_len - 1), + soft_len, + hidden_dim, + heads, + (num_kv > 0 ? num_kv : heads), + rotary_dim, + rotate_half, + rotate_every_two, + InferenceContext::Instance().GetCurrentStream(), + 3, + InferenceContext::Instance().GetMaxTokenLength(), + rope_theta); + if (rotary_dim > 0 && rotate_half) + launch_apply_rotary_pos_emb(query_cont, + kv_cache, + k, + seq_len, + rotary_dim, + (is_prompt ? 0 : soft_len - 1), + heads, + bsz, + rope_theta, + InferenceContext::Instance().GetCurrentStream(), + InferenceContext::Instance().GetMaxTokenLength()); + + attention_unfused(workspace + offset, + (T*)query_cont, + attn_mask, + workspace + offset + value_offset, + temp_buf, + bsz, + k, + seq_len, + all_tokens, + heads, + norm_factor, + (triangular && is_prompt), + is_prompt, + local_attention, + window_size, + alibi, + layer_id); + launch_transform4d_0213((T*)output.data_ptr(), + temp_buf, + bsz, + heads, + seq_len, + output.size(2), + InferenceContext::Instance().GetCurrentStream(false), + 1); + + if (layer_id == num_layers - 1) InferenceContext::Instance().advance_tokens(); + auto prev_key = torch::from_blob( + workspace + offset, + {bsz, heads, all_tokens, k}, + {hidden_dim * static_cast(InferenceContext::Instance().GetMaxTokenLength()), + k * static_cast(InferenceContext::Instance().GetMaxTokenLength()), + k, + 1}, + options); + + auto prev_value = torch::from_blob( + workspace + offset + value_offset, + {bsz, heads, all_tokens, k}, + {hidden_dim * static_cast(InferenceContext::Instance().GetMaxTokenLength()), + k * static_cast(InferenceContext::Instance().GetMaxTokenLength()), + k, + 1}, + options); + + return {output, prev_key, prev_value}; +} + +template +at::Tensor ds_bias_gelu(at::Tensor& input, at::Tensor& bias) +{ + auto input_cont = input.contiguous(); + + int bsz = input_cont.size(0) * input_cont.size(1); + int intermediate_size = input_cont.size(2); + + launch_bias_gelu((T*)input_cont.data_ptr(), + (T*)bias.data_ptr(), + intermediate_size, + bsz, + InferenceContext::Instance().GetCurrentStream()); + return input_cont; +} + +#define DISPATCH_GATED_ACT(T_TYPE, C_TYPE) \ + if (activation.options().dtype() == torch::T_TYPE) { \ + launch_gated_activation((C_TYPE*)output.data_ptr(), \ + (const C_TYPE*)activation.data_ptr(), \ + (const C_TYPE*)bias.data_ptr(), \ + rows, \ + out_channels, \ + channels, \ + activation_type == ActivationFuncType::GATED_GELU, \ + InferenceContext::Instance().GetCurrentStream()); \ + } + +at::Tensor ds_gated_activation(at::Tensor& activation, at::Tensor& bias, int actFun) +{ + /* + Used in FF of Stable diffusion + */ + + const ActivationFuncType activation_type = static_cast(actFun); + + assert(activation_type == ActivationFuncType::GATED_GELU || + activation_type == ActivationFuncType::GATED_SILU); + + const int batch_size = activation.size(0); + const int seq_len = activation.size(1); + const int channels = activation.size(2); + + const int rows = batch_size * seq_len; + // Dimensionality is cut in half + const int out_channels = channels / 2; + + auto output = at::empty({batch_size, seq_len, out_channels}, activation.options()); + + DISPATCH_GATED_ACT(kFloat, float); + DISPATCH_GATED_ACT(kHalf, __half); +#ifdef BF16_AVAILABLE + DISPATCH_GATED_ACT(kBFloat16, __nv_bfloat16); +#endif + + return output; +} + +template +at::Tensor ds_bias_relu(at::Tensor& input, at::Tensor& bias) +{ + auto input_cont = input.contiguous(); + + int bsz = input_cont.size(0) * input_cont.size(1); + int intermediate_size = input_cont.size(2); + + launch_bias_relu((T*)input_cont.data_ptr(), + (T*)bias.data_ptr(), + intermediate_size, + bsz, + InferenceContext::Instance().GetCurrentStream()); + return input_cont; +} + +template +at::Tensor ds_bias_add(at::Tensor& input, at::Tensor& bias) +{ + auto input_cont = input.contiguous(); + + int bsz = input_cont.size(0) * input_cont.size(1); + int hidden_size = input_cont.size(2); + + launch_bias_add((T*)input_cont.data_ptr(), + (T*)bias.data_ptr(), + hidden_size, + bsz, + InferenceContext::Instance().GetCurrentStream()); + return input_cont; +} + +template +at::Tensor ds_bias_residual(at::Tensor& input, at::Tensor& residual, at::Tensor& bias) +{ + auto input_cont = input.contiguous(); + auto residual_cont = residual.contiguous(); + + int bsz = input_cont.size(0) * input_cont.size(1); + // launch_bias_residual((T*)input_cont.data_ptr(), + // (T*)residual_cont.data_ptr(), + // (T*)bias.data_ptr(), + // bsz, + // input_cont.size(2), + // (bias.size(0) > 1), + // InferenceContext::Instance().GetCurrentStream()); + return input_cont; +} + +#define DISPATCH_LAYER_NORM(T_TYPE, C_TYPE) \ + if (input.options().dtype() == torch::T_TYPE) { \ + launch_fused_ln((C_TYPE*)output.data_ptr(), \ + (const C_TYPE*)input.data_ptr(), \ + (const C_TYPE*)gamma.data_ptr(), \ + (const C_TYPE*)beta.data_ptr(), \ + epsilon, \ + rows, \ + elems_per_row, \ + InferenceContext::Instance().GetCurrentStream()); \ + } + +at::Tensor ds_layer_norm(at::Tensor& input, at::Tensor& gamma, at::Tensor& beta, float epsilon) +{ + const int rows = input.size(0) * input.size(1); + const int elems_per_row = input.size(2); + auto output = at::empty_like(input); + + DISPATCH_LAYER_NORM(kFloat, float); + DISPATCH_LAYER_NORM(kHalf, __half); +#ifdef BF16_AVAILABLE + DISPATCH_LAYER_NORM(kBFloat16, __nv_bfloat16); +#endif + + return output; +} + +#define DISPATCH_RMS_NORM(T_TYPE, C_TYPE) \ + if (input.options().dtype() == torch::T_TYPE) { \ + launch_rms_norm((C_TYPE*)output.data_ptr(), \ + (C_TYPE*)nullptr, \ + (const C_TYPE*)input.data_ptr(), \ + (const C_TYPE*)nullptr, \ + (const C_TYPE*)gamma.data_ptr(), \ + epsilon, \ + rows, \ + elems_per_row, \ + InferenceContext::Instance().GetCurrentStream()); \ + } + +at::Tensor ds_rms_norm(at::Tensor& input, at::Tensor& gamma, float epsilon) +{ + // Get number of dims of tensor + int num_dims = input.dim(); + const int rows = (num_dims == 2) ? input.size(0) : input.size(0) * input.size(1); + const int elems_per_row = (num_dims == 2) ? input.size(1) : input.size(2); + + auto output = at::empty_like(input); + + DISPATCH_RMS_NORM(kFloat, float); + DISPATCH_RMS_NORM(kHalf, __half); +#ifdef BF16_AVAILABLE + DISPATCH_RMS_NORM(kBFloat16, __nv_bfloat16); +#endif + + return output; +} + +#define DISPATCH_PRE_RMS_NORM(T_TYPE, C_TYPE) \ + if (input.options().dtype() == torch::T_TYPE) { \ + launch_rms_norm((C_TYPE*)output.data_ptr(), \ + (C_TYPE*)res_out.data_ptr(), \ + (const C_TYPE*)input.data_ptr(), \ + (const C_TYPE*)residual.data_ptr(), \ + (const C_TYPE*)gamma.data_ptr(), \ + epsilon, \ + rows, \ + elems_per_row, \ + InferenceContext::Instance().GetCurrentStream()); \ + } + +std::vector ds_pre_rms_norm(at::Tensor& input, + at::Tensor& residual, + at::Tensor& gamma, + float epsilon) +{ + // Get number of dims of tensor + int num_dims = input.dim(); + const int rows = (num_dims == 2) ? input.size(0) : input.size(0) * input.size(1); + const int elems_per_row = (num_dims == 2) ? input.size(1) : input.size(2); + + auto output = at::empty_like(input); + auto res_out = at::empty_like(residual); + + DISPATCH_PRE_RMS_NORM(kFloat, float); + DISPATCH_PRE_RMS_NORM(kHalf, __half); +#ifdef BF16_AVAILABLE + DISPATCH_PRE_RMS_NORM(kBFloat16, __nv_bfloat16); +#endif + + return {output, res_out}; +} + +template +void ds_layer_norm_internal(T* workspace, + at::Tensor& input, + at::Tensor& gamma, + at::Tensor& beta, + float epsilon) +{ + int bsz = input.size(0) * input.size(1); + launch_fused_ln(workspace, + (const T*)input.data_ptr(), + (const T*)gamma.data_ptr(), + (const T*)beta.data_ptr(), + epsilon, + bsz, + input.size(2), + InferenceContext::Instance().GetCurrentStream()); +} + +#define DISPATCH_LAYER_NORM_RESIDUAL(T_TYPE, C_TYPE) \ + if (input.options().dtype() == torch::T_TYPE) { \ + launch_fused_residual_ln((C_TYPE*)output.data_ptr(), \ + (const C_TYPE*)input.data_ptr(), \ + (const C_TYPE*)residual.data_ptr(), \ + (const C_TYPE*)bias.data_ptr(), \ + (const C_TYPE*)gamma.data_ptr(), \ + (const C_TYPE*)beta.data_ptr(), \ + epsilon, \ + rows, \ + elems_per_row, \ + InferenceContext::Instance().GetCurrentStream()); \ + } + +/* Currently only used in unit testing */ +at::Tensor ds_layer_norm_residual(at::Tensor& input, + at::Tensor& bias, + at::Tensor& residual, + at::Tensor& gamma, + at::Tensor& beta, + float epsilon) +{ + const int rows = input.size(0) * input.size(1); + const int elems_per_row = input.size(2); + auto output = at::empty_like(input); + + DISPATCH_LAYER_NORM_RESIDUAL(kFloat, float); + DISPATCH_LAYER_NORM_RESIDUAL(kHalf, __half); +#ifdef BF16_AVAILABLE + DISPATCH_LAYER_NORM_RESIDUAL(kBFloat16, __nv_bfloat16); +#endif + + return output; +} + +#define DISPATCH_PRE_LAYER_NORM_RESIDUAL(T_TYPE, C_TYPE) \ + if (input.options().dtype() == torch::T_TYPE) { \ + launch_fused_residual_ln_store_pre_ln_res( \ + (C_TYPE*)norm_output.data_ptr(), \ + (C_TYPE*)res_output.data_ptr(), \ + (const C_TYPE*)input.data_ptr(), \ + (const C_TYPE*)residual.data_ptr(), \ + (const C_TYPE*)bias.data_ptr(), \ + (const C_TYPE*)gamma.data_ptr(), \ + (const C_TYPE*)beta.data_ptr(), \ + epsilon, \ + rows, \ + elems_per_row, \ + InferenceContext::Instance().GetCurrentStream()); \ + } + +/* Currently only used in unit testing */ +std::vector ds_layer_norm_residual_store_pre_ln_res(at::Tensor& input, + at::Tensor& bias, + at::Tensor& residual, + at::Tensor& gamma, + at::Tensor& beta, + float epsilon) +{ + const int rows = input.size(0) * input.size(1); + const int elems_per_row = input.size(2); + auto norm_output = at::empty_like(input); + auto res_output = at::empty_like(input); + + DISPATCH_PRE_LAYER_NORM_RESIDUAL(kFloat, float); + DISPATCH_PRE_LAYER_NORM_RESIDUAL(kHalf, __half); +#ifdef BF16_AVAILABLE + DISPATCH_PRE_LAYER_NORM_RESIDUAL(kBFloat16, __nv_bfloat16); +#endif + + return {norm_output, res_output}; +} + +template +void quantized_gemm(void* output, + T* input, + at::Tensor& weight, + at::Tensor& qscale, + int groups, + int bsz, + int hidden_size) +{ + // T* weight16 = (T*)InferenceContext::Instance().GetWorkSpace() + 12 * hidden_size * bsz; + + auto options = at::TensorOptions() + .dtype(at::kHalf) + .layout(at::kStrided) + .device(at::kCUDA) + .requires_grad(false); + auto tmp = torch::empty(weight.sizes(), options); + T* weight16 = (T*)tmp.data_ptr(); + launch_dequantize(weight16, + (int8_t*)weight.data_ptr(), + (float*)qscale.data_ptr(), + weight.size(0), + weight.size(1), + groups, + InferenceContext::Instance().GetCurrentStream()); + + float alpha = (T)1.0; + float gemm_beta = (T)0.0; + cublas_gemm_ex(InferenceContext::Instance().GetCublasHandle(), + CUBLAS_OP_T, + CUBLAS_OP_N, + weight.size(0), + bsz, + weight.size(1), + &alpha, + &gemm_beta, + weight16, + (T*)input, + (T*)output, +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) + rocblas_gemm_algo_standard); +#else + CUBLAS_GEMM_DEFAULT_TENSOR_OP); +#endif +} + +template +at::Tensor qkv_unfused_cublas(at::Tensor& output, + at::Tensor& input, + at::Tensor& weight, + at::Tensor& q_scale, + at::Tensor& bias, + at::Tensor& gamma, + at::Tensor& beta, + const float epsilon, + bool add_bias, + bool q_int8, + bool transposed_mode) +{ + int bsz = input.size(0) * input.size(1); + T* workspace = (T*)InferenceContext::Instance().GetWorkSpace(); + workspace += (3 * bsz * input.size(2)); + ds_layer_norm_internal(workspace, input, gamma, beta, epsilon); + + if (q_int8) { + quantized_gemm( + output.data_ptr(), workspace, weight, q_scale, q_scale.size(0), bsz, input.size(2)); + } else { + float alpha = (T)1.0; + float gemm_beta = (T)0.0; + + cublasSetStream(InferenceContext::Instance().GetCublasHandle(), + InferenceContext::Instance().GetCurrentStream()); + cublas_gemm_ex(InferenceContext::Instance().GetCublasHandle(), + (transposed_mode ? CUBLAS_OP_T : CUBLAS_OP_N), + CUBLAS_OP_N, + weight.size(transposed_mode ? 0 : 1), + bsz, + input.size(2), + &alpha, + &gemm_beta, + (T*)weight.data_ptr(), + workspace, + (T*)output.data_ptr(), +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) + rocblas_gemm_algo_standard); +#else + CUBLAS_GEMM_DEFAULT_TENSOR_OP); +#endif + } + if (add_bias) + launch_bias_add((T*)output.data_ptr(), + (T*)bias.data_ptr(), + (transposed_mode || q_int8) ? weight.size(0) : weight.size(1), + bsz, + InferenceContext::Instance().GetCurrentStream()); + return torch::from_blob(workspace, input.sizes(), input.options()); +} + +template +std::vector ds_rms_qkv(at::Tensor& input, + at::Tensor& weight, + at::Tensor& q_scale, + at::Tensor& gamma, + const float epsilon, + bool q_int8, + bool transposed_mode) +{ + const int bsz = input.size(0) * input.size(1); + T* workspace = (T*)InferenceContext::Instance().GetWorkSpace(); + T* rms_norm_ptr = workspace + (3 * bsz * input.size(2)); + int out_size = (transposed_mode || q_int8) ? weight.size(0) : weight.size(1); + + auto options = at::TensorOptions() + .dtype(input.options().dtype()) + .layout(at::kStrided) + .device(at::kCUDA) + .requires_grad(false); + auto rms_norm = at::from_blob(rms_norm_ptr, input.sizes(), options); + auto output = at::from_blob(workspace, {input.size(0), input.size(1), out_size}, options); + + launch_rms_norm((T*)rms_norm.data_ptr(), + (T*)nullptr, + (const T*)input.data_ptr(), + (const T*)nullptr, + (const T*)gamma.data_ptr(), + epsilon, + bsz, + input.size(2), + InferenceContext::Instance().GetCurrentStream()); + + if (q_int8) { + quantized_gemm((T*)output.data_ptr(), + (T*)rms_norm.data_ptr(), + weight, + q_scale, + q_scale.size(0), + bsz, + input.size(2)); + } else { + float alpha = (T)1.0; + float gemm_beta = (T)0.0; + + cublasSetStream(InferenceContext::Instance().GetCublasHandle(), + InferenceContext::Instance().GetCurrentStream()); + cublas_gemm_ex(InferenceContext::Instance().GetCublasHandle(), + (transposed_mode ? CUBLAS_OP_T : CUBLAS_OP_N), + CUBLAS_OP_N, + weight.size(transposed_mode ? 0 : 1), + bsz, + input.size(2), + &alpha, + &gemm_beta, + (T*)weight.data_ptr(), + (T*)rms_norm.data_ptr(), + (T*)output.data_ptr(), +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) + rocblas_gemm_algo_standard); +#else + CUBLAS_GEMM_DEFAULT_TENSOR_OP); +#endif + } + + return {output, rms_norm}; +} + +template +std::vector ds_qkv_gemm(at::Tensor& input, + at::Tensor& weight, + at::Tensor& q_scale, + at::Tensor& bias, + at::Tensor& gamma, + at::Tensor& beta, + const float epsilon, + bool add_bias, + bool q_int8, + bool transposed_mode) +{ + int bsz = input.size(0) * input.size(1); + T* workspace = (T*)InferenceContext::Instance().GetWorkSpace(); + int out_size = (transposed_mode || q_int8) ? weight.size(0) : weight.size(1); + + auto options = at::TensorOptions() + .dtype(input.options().dtype()) + .layout(at::kStrided) + .device(at::kCUDA) + .requires_grad(false); + + auto output = at::from_blob(workspace, {input.size(0), input.size(1), out_size}, options); + auto inp_norm = qkv_unfused_cublas(output, + input, + weight, + q_scale, + bias, + gamma, + beta, + epsilon, + add_bias, + q_int8, + transposed_mode); + + return {output, inp_norm}; +} + +template +void quantized_gemm(at::Tensor& output, + at::Tensor& input, + at::Tensor& weight, + at::Tensor& qscale, + int groups, + int merge_count) +{ + int bsz = input.size(0) * input.size(1); + auto options = at::TensorOptions() + .dtype(input.options().dtype()) + .layout(at::kStrided) + .device(at::kCUDA) + .requires_grad(false); + auto weight16 = at::empty({weight.size(0), weight.size(1)}, options); + + launch_dequantize((T*)weight16.data_ptr(), + (int8_t*)weight.data_ptr(), + (float*)qscale.data_ptr(), + weight.size(0), + weight.size(1), + groups, + merge_count, + InferenceContext::Instance().GetCurrentStream()); + + float alpha = (T)1.0; + float gemm_beta = (T)0.0; + cublas_gemm_ex(InferenceContext::Instance().GetCublasHandle(), + CUBLAS_OP_T, + CUBLAS_OP_N, + weight.size(0), + bsz, + input.size(2), + &alpha, + &gemm_beta, + (T*)weight16.data_ptr(), + (T*)input.data_ptr(), + (T*)output.data_ptr(), +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) + rocblas_gemm_algo_standard); +#else + CUBLAS_GEMM_DEFAULT_TENSOR_OP); +#endif +} + +template +at::Tensor ds_linear_layer(at::Tensor& input, + at::Tensor& weight, + at::Tensor& bias, + bool add_bias, + bool do_flash_attn, + int num_heads, + bool transposed_mode, + float rope_theta) +{ + auto input_cont = input.contiguous(); + auto options = at::TensorOptions() + .dtype(input_cont.options().dtype()) + .layout(at::kStrided) + .device(at::kCUDA) + .requires_grad(false); + + int head_size = input_cont.size(2) / num_heads; + int bsz = input.size(0) * input.size(1); + int out_size = transposed_mode ? weight.size(0) : weight.size(1); + T* workspace = (T*)InferenceContext::Instance().GetWorkSpace(); + auto output = at::from_blob(workspace, {input.size(0), input.size(1), out_size}, options); + + float alpha = (T)1.0; + float gemm_beta = (T)0.0; + cublasSetStream(InferenceContext::Instance().GetCublasHandle(), + InferenceContext::Instance().GetCurrentStream()); + + cublas_gemm_ex(InferenceContext::Instance().GetCublasHandle(), + (transposed_mode ? CUBLAS_OP_T : CUBLAS_OP_N), + CUBLAS_OP_N, + weight.size(transposed_mode ? 0 : 1), + bsz, + input_cont.size(2), + &alpha, + &gemm_beta, + (T*)weight.data_ptr(), + (T*)input_cont.data_ptr(), + (T*)output.data_ptr(), +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) + rocblas_gemm_algo_standard); +#else + CUBLAS_GEMM_DEFAULT_TENSOR_OP); +#endif + if (add_bias) + launch_bias_add((T*)output.data_ptr(), + (T*)bias.data_ptr(), + weight.size(transposed_mode ? 0 : 1), + bsz, + InferenceContext::Instance().GetCurrentStream()); + bool add_padding = (head_size % 32 != 0 && head_size < 64) || (head_size % 64 != 0); + if (do_flash_attn) { + if (add_padding) { + int padded_head_size = head_size < 32 ? 32 : (head_size < 64 ? 64 : 128); + auto padded_output = workspace + output.numel(); + auto final_output = + padded_output + (input.size(0) * input.size(1) * 3 * num_heads * padded_head_size); + pad_data(padded_output, + workspace, + 3 * bsz * num_heads, + head_size, + padded_head_size, + InferenceContext::Instance().GetCurrentStream()); + + launch_bias_add_transform_0213( + final_output, + final_output + (input.size(0) * input.size(1) * num_heads * padded_head_size), + final_output + (input.size(0) * input.size(1) * 2 * num_heads * padded_head_size), + padded_output, + nullptr, + input.size(0), + input.size(1), + 0, + input.size(1), + (num_heads * padded_head_size), + num_heads, + -1, + -1, + false, + false, + InferenceContext::Instance().GetCurrentStream(), + 3, + input.size(1), + rope_theta); + return at::from_blob(final_output, + {3, input.size(0), num_heads, input.size(1), padded_head_size}, + options); + // return at::from_blob(padded_output, {input.size(0) * input.size(1), 3, num_heads, + // padded_head_size}, options); + } else { + auto final_output = workspace + output.numel(); + launch_bias_add_transform_0213( + final_output, + final_output + (input.size(0) * input.size(1) * input_cont.size(2)), + final_output + (input.size(0) * input.size(1) * 2 * input_cont.size(2)), + workspace, + nullptr, + input.size(0), + input.size(1), + 0, + input.size(1), + input_cont.size(2), + num_heads, + -1, + -1, + false, + false, + InferenceContext::Instance().GetCurrentStream(), + 3, + input.size(1), + rope_theta); + return at::from_blob( + final_output, {3, input.size(0), num_heads, input.size(1), head_size}, options); + // return at::from_blob(workspace, {input.size(0) * input.size(1), 3, num_heads, + // head_size}, options); + } + + } else + return output; +} + +template +std::vector add_padding(at::Tensor& query, at::Tensor& key, at::Tensor& value) +{ + int head_size = query.size(3); + int padded_head_size = head_size < 32 ? 32 : (head_size < 64 ? 64 : 128); + T* workspace = (T*)InferenceContext::Instance().GetWorkSpace(); + T* key_pad_ptr = workspace + padded_head_size * query.size(0) * query.size(1) * query.size(2); + T* value_pad_ptr = key_pad_ptr + padded_head_size * query.size(0) * query.size(1) * 128; + pad_head_seq(workspace, + (T*)query.data_ptr(), + query.size(0) * query.size(1), + query.size(2), + query.size(2), + head_size, + padded_head_size, + InferenceContext::Instance().GetCurrentStream()); + pad_head_seq(key_pad_ptr, + (T*)key.data_ptr(), + query.size(0) * query.size(1), + key.size(2), + 128, + head_size, + padded_head_size, + InferenceContext::Instance().GetCurrentStream()); + pad_head_seq(value_pad_ptr, + (T*)value.data_ptr(), + query.size(0) * query.size(1), + key.size(2), + 128, + head_size, + padded_head_size, + InferenceContext::Instance().GetCurrentStream()); + return { + at::from_blob(workspace, + {query.size(0), query.size(1), query.size(2), padded_head_size}, + query.options()), + at::from_blob( + key_pad_ptr, {query.size(0), query.size(1), 128, padded_head_size}, query.options()), + at::from_blob( + value_pad_ptr, {query.size(0), query.size(1), 128, padded_head_size}, query.options())}; +} + +template +std::vector padd_add_transform(at::Tensor& query, + at::Tensor& key, + at::Tensor& value, + int heads, + bool add_padding) +{ + int head_size = query.size(2) / heads; + int key_value_length = add_padding ? 128 : key.size(1); + int padded_head_size = add_padding ? (head_size < 32 ? 32 : (head_size < 64 ? 64 : 128)) + : head_size; + T* workspace = (T*)InferenceContext::Instance().GetWorkSpace(); + T* key_pad_ptr = workspace + padded_head_size * query.size(0) * heads * query.size(1); + T* value_pad_ptr = key_pad_ptr + padded_head_size * query.size(0) * heads * key_value_length; + launch_pad_add_transform_0213(workspace, + (T*)query.data_ptr(), + query.size(0), + query.size(2), + query.size(1), + query.size(1), + heads, + padded_head_size, + InferenceContext::Instance().GetCurrentStream()); + launch_pad_add_transform_0213(key_pad_ptr, + (T*)key.data_ptr(), + key.size(0), + key.size(2), + key.size(1), + key_value_length, + heads, + padded_head_size, + InferenceContext::Instance().GetCurrentStream()); + launch_pad_add_transform_0213(value_pad_ptr, + (T*)value.data_ptr(), + value.size(0), + value.size(2), + value.size(1), + key_value_length, + heads, + padded_head_size, + InferenceContext::Instance().GetCurrentStream()); + return { + at::from_blob( + workspace, {query.size(0), heads, query.size(1), padded_head_size}, query.options()), + at::from_blob(key_pad_ptr, + {query.size(0), heads, key_value_length, padded_head_size}, + query.options()), + at::from_blob(value_pad_ptr, + {query.size(0), heads, key_value_length, padded_head_size}, + query.options())}; +} + +template +at::Tensor ds_vector_matmul(at::Tensor& input, + at::Tensor& weight, + bool async_op, + at::Tensor& q_scale, + bool q_int8, + bool transposed_mode) +{ + auto options = at::TensorOptions() + .dtype(input.options().dtype()) + .layout(at::kStrided) + .device(at::kCUDA) + .requires_grad(false); + int out_size = (q_int8 || transposed_mode) ? weight.size(0) : weight.size(1); + int bsz = input.size(0) * input.size(1); + + T* workspace = (T*)InferenceContext::Instance().GetWorkSpace(); + auto output = at::from_blob(workspace, {input.size(0), input.size(1), out_size}, options); + if (q_int8) { + quantized_gemm(output.data_ptr(), + (T*)input.data_ptr(), + weight, + q_scale, + q_scale.size(0), + bsz, + input.size(2)); + } else { + float alpha = (T)1.0; + float gemm_beta = (T)0.0; + cublasSetStream(InferenceContext::Instance().GetCublasHandle(), + InferenceContext::Instance().GetCurrentStream(async_op)); + cublas_gemm_ex(InferenceContext::Instance().GetCublasHandle(), + (transposed_mode ? CUBLAS_OP_T : CUBLAS_OP_N), + CUBLAS_OP_N, + weight.size(transposed_mode ? 0 : 1), + bsz, + input.size(2), + &alpha, + &gemm_beta, + (T*)weight.data_ptr(), + (T*)input.data_ptr(), + (T*)output.data_ptr(), +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) + rocblas_gemm_algo_standard); +#else + CUBLAS_GEMM_DEFAULT_TENSOR_OP); +#endif + } + return output; +} + +template +at::Tensor ds_vector_matmul_int8(at::Tensor& input, + at::Tensor& weight, + at::Tensor& q_scale, + int groups, + int merge_count) +{ + auto input_cont = input.contiguous(); + auto options = at::TensorOptions() + .dtype(input_cont.options().dtype()) + .layout(at::kStrided) + .device(at::kCUDA) + .requires_grad(false); + + auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options); + + quantized_gemm(output, input_cont, weight, q_scale, groups, merge_count); + return output; +} + +template +at::Tensor mlp_unfused_cublas(at::Tensor& output, + at::Tensor& input, + at::Tensor& residual, + at::Tensor& input_bias, + at::Tensor& weight, + at::Tensor& weight1, + at::Tensor& bias, + at::Tensor& gamma, + at::Tensor& beta, + const float epsilon, + bool preLayerNorm, + bool mlp_after_attn, + at::Tensor& q_scale, + at::Tensor& q_scale1, + bool q_int8, + ActivationFuncType act_func_type, + bool transposed_mode) +{ + int bsz = input.size(0) * input.size(1); + T* inp_norm = (T*)InferenceContext::Instance().GetWorkSpace() + torch::numel(input) + + torch::numel(output); + T* intermediate = inp_norm + torch::numel(input); + + if (mlp_after_attn) { + launch_fused_residual_ln((T*)inp_norm, + (const T*)input.data_ptr(), + (const T*)residual.data_ptr(), + (const T*)input_bias.data_ptr(), + (const T*)gamma.data_ptr(), + (const T*)beta.data_ptr(), + epsilon, + bsz, + input.size(2), + InferenceContext::Instance().GetCurrentStream()); + } else { + ds_layer_norm_internal(inp_norm, input, gamma, beta, epsilon); + } + if (q_int8) { + quantized_gemm( + intermediate, inp_norm, weight, q_scale, q_scale.size(0), bsz, input.size(2)); + } else { + float alpha = (T)1.0; + float gemm_beta = (T)0.0; + cublasSetStream(InferenceContext::Instance().GetCublasHandle(), + InferenceContext::Instance().GetCurrentStream()); + cublas_gemm_ex(InferenceContext::Instance().GetCublasHandle(), + (transposed_mode ? CUBLAS_OP_T : CUBLAS_OP_N), + CUBLAS_OP_N, + weight.size(transposed_mode ? 0 : 1), + bsz, + input.size(2), + &alpha, + &gemm_beta, + (T*)weight.data_ptr(), + inp_norm, + intermediate, +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) + rocblas_gemm_algo_standard); +#else + CUBLAS_GEMM_DEFAULT_TENSOR_OP); +#endif + } + if (act_func_type == ActivationFuncType::GELU) { + launch_bias_gelu(intermediate, + (T*)bias.data_ptr(), + (transposed_mode || q_int8) ? weight.size(0) : weight.size(1), + bsz, + InferenceContext::Instance().GetCurrentStream()); + } else if (act_func_type == ActivationFuncType::ReLU) { + launch_bias_relu(intermediate, + (T*)bias.data_ptr(), + (transposed_mode || q_int8) ? weight.size(0) : weight.size(1), + bsz, + InferenceContext::Instance().GetCurrentStream()); + } + + if (q_int8) { + quantized_gemm(output.data_ptr(), + intermediate, + weight1, + q_scale1, + q_scale1.size(0), + bsz, + input.size(2)); + } else { + float alpha = (T)1.0; + float gemm_beta = (T)0.0; + cublasSetStream(InferenceContext::Instance().GetCublasHandle(), + InferenceContext::Instance().GetCurrentStream()); + cublas_gemm_ex(InferenceContext::Instance().GetCublasHandle(), + (transposed_mode ? CUBLAS_OP_T : CUBLAS_OP_N), + CUBLAS_OP_N, + weight1.size(transposed_mode ? 0 : 1), + bsz, + weight1.size(transposed_mode ? 1 : 0), + &alpha, + &gemm_beta, + (T*)weight1.data_ptr(), + intermediate, + (T*)output.data_ptr(), +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) + rocblas_gemm_algo_standard); +#else + CUBLAS_GEMM_DEFAULT_TENSOR_OP); +#endif + } + + return torch::from_blob(inp_norm, input.sizes(), input.options()); +} + +template +std::vector ds_mlp_gemm(at::Tensor& input, + at::Tensor& residual, + at::Tensor& input_bias, + at::Tensor& weight_interm, + at::Tensor& weight_out, + at::Tensor& bias, + at::Tensor& gamma, + at::Tensor& beta, + const float epsilon, + bool preLayerNorm, + bool mlp_after_attn, + at::Tensor& q_scale, + at::Tensor& q_scale1, + bool q_int8, + int activation_type, + bool transposed_mode) +{ + auto options = at::TensorOptions() + .dtype(input.options().dtype()) + .layout(at::kStrided) + .device(at::kCUDA) + .requires_grad(false); + + int out_size = (q_int8 || transposed_mode) ? weight_out.size(0) : weight_out.size(1); + auto output = + at::from_blob((T*)InferenceContext::Instance().GetWorkSpace() + torch::numel(input), + {input.size(0), input.size(1), out_size}, + options); + int bsz = input.size(0) * input.size(1); + + auto act_func_type = static_cast(activation_type); + auto res_add = mlp_unfused_cublas(output, + mlp_after_attn ? input : residual, + residual, + input_bias, + weight_interm, + weight_out, + bias, + gamma, + beta, + epsilon, + preLayerNorm, + mlp_after_attn, + q_scale, + q_scale1, + q_int8, + act_func_type, + transposed_mode); + + return {output, res_add}; +} + +template +std::vector ds_rms_mlp_gemm(at::Tensor& input, + at::Tensor& residual, + at::Tensor& weight_interm, + at::Tensor& weight_out, + at::Tensor& gamma, + const float epsilon, + at::Tensor& q_scale, + at::Tensor& q_scale1, + bool q_int8, + int activation_type, + bool transposed_mode) +{ + const int bsz = input.size(0) * input.size(1); + const size_t input_neurons = input.size(2); + const size_t mlp_1_out_neurons = transposed_mode ? weight_interm.size(0) + : weight_interm.size(1); + const size_t mlp_2_in_neurons = transposed_mode ? weight_out.size(1) : weight_out.size(0); + + auto options = at::TensorOptions() + .dtype(input.options().dtype()) + .layout(at::kStrided) + .device(at::kCUDA) + .requires_grad(false); + + T* output_ptr = (T*)InferenceContext::Instance().GetWorkSpace() + torch::numel(input); + T* inp_norm_ptr = output_ptr + torch::numel(input); + T* intermediate_ptr = inp_norm_ptr + torch::numel(input); + + auto output = at::from_blob(output_ptr, input.sizes(), options); + auto inp_norm = at::from_blob(inp_norm_ptr, input.sizes(), options); + auto intermediate_gemm = + at::from_blob(intermediate_ptr, + {input.size(0), input.size(1), static_cast(mlp_1_out_neurons)}, + options); + + auto act_func_type = static_cast(activation_type); + + // RMS Norm, we'll update the residual in-place + launch_rms_norm((T*)inp_norm.data_ptr(), + (T*)residual.data_ptr(), + (const T*)input.data_ptr(), + (const T*)residual.data_ptr(), + (const T*)gamma.data_ptr(), + epsilon, + bsz, + input_neurons, + InferenceContext::Instance().GetCurrentStream()); + + if (q_int8) { + quantized_gemm(intermediate_ptr, + (T*)inp_norm.data_ptr(), + weight_interm, + q_scale, + q_scale.size(0), + bsz, + input_neurons); + } else { + float alpha = (T)1.0; + float gemm_beta = (T)0.0; + cublasSetStream(InferenceContext::Instance().GetCublasHandle(), + InferenceContext::Instance().GetCurrentStream()); + cublas_gemm_ex(InferenceContext::Instance().GetCublasHandle(), + (transposed_mode ? CUBLAS_OP_T : CUBLAS_OP_N), + CUBLAS_OP_N, + mlp_1_out_neurons, + bsz, + input_neurons, + &alpha, + &gemm_beta, + (T*)weight_interm.data_ptr(), + (T*)inp_norm.data_ptr(), + intermediate_ptr, +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) + rocblas_gemm_algo_standard); +#else + CUBLAS_GEMM_DEFAULT_TENSOR_OP); +#endif + } + + if (act_func_type == ActivationFuncType::GELU) { + launch_bias_gelu(intermediate_ptr, + (T*)nullptr, + mlp_1_out_neurons, + bsz, + InferenceContext::Instance().GetCurrentStream()); + } else if (act_func_type == ActivationFuncType::ReLU) { + launch_bias_relu(intermediate_ptr, + (T*)nullptr, + mlp_1_out_neurons, + bsz, + InferenceContext::Instance().GetCurrentStream()); + } else if (act_func_type == ActivationFuncType::GATED_GELU) { + launch_gated_activation(intermediate_ptr, + (const T*)intermediate_ptr, + (const T*)nullptr, + bsz, + mlp_1_out_neurons, + mlp_1_out_neurons, + true, + InferenceContext::Instance().GetCurrentStream()); + } else if (act_func_type == ActivationFuncType::GATED_SILU) { + launch_gated_activation(intermediate_ptr, + (const T*)intermediate_ptr, + (const T*)nullptr, + bsz, + mlp_1_out_neurons, + mlp_1_out_neurons, + false, + InferenceContext::Instance().GetCurrentStream()); + } + + if (q_int8) { + quantized_gemm(output.data_ptr(), + intermediate_ptr, + weight_out, + q_scale1, + q_scale1.size(0), + bsz, + input.size(2)); + } else { + float alpha = (T)1.0; + float gemm_beta = (T)0.0; + cublasSetStream(InferenceContext::Instance().GetCublasHandle(), + InferenceContext::Instance().GetCurrentStream()); + cublas_gemm_ex(InferenceContext::Instance().GetCublasHandle(), + (transposed_mode ? CUBLAS_OP_T : CUBLAS_OP_N), + CUBLAS_OP_N, + input_neurons, + bsz, + mlp_2_in_neurons, + &alpha, + &gemm_beta, + (T*)weight_out.data_ptr(), + intermediate_ptr, + (T*)output.data_ptr(), +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) + rocblas_gemm_algo_standard, +#else + CUBLAS_GEMM_DEFAULT_TENSOR_OP, +#endif + mlp_1_out_neurons); + } + + return {output, residual}; +} + +template +at::Tensor fused_gemm_gelu(at::Tensor& input, + at::Tensor& weight, + at::Tensor& weight_scale, + at::Tensor& bias, + at::Tensor& weight_out, + at::Tensor& weight_out_scale, + bool q_int8, + bool transposed_mode) +{ + auto options = at::TensorOptions() + .dtype(input.options().dtype()) + .layout(at::kStrided) + .device(at::kCUDA) + .requires_grad(false); + + int intm_dim = (transposed_mode || q_int8) ? weight.size(0) : weight.size(1); + + // auto output = at::from_blob((T*)InferenceContext::Instance().GetWorkSpace() + + // torch::numel(input), + // {input.size(0), input.size(1), out_size}, + // options); + // T* intermediate = (T*)input.data_ptr() + torch::numel(input); + auto intermediate = at::empty({input.size(0), input.size(1), intm_dim}, options); + + int bsz = input.size(0) * input.size(1); + + float alpha = (T)1.0; + float gemm_beta = (T)0.0; + if (q_int8) { + quantized_gemm(intermediate.data_ptr(), + (T*)input.data_ptr(), + weight, + weight_scale, + weight_scale.size(0), + bsz, + input.size(2)); + } else { + cublasSetStream(InferenceContext::Instance().GetCublasHandle(), + InferenceContext::Instance().GetCurrentStream()); + cublas_gemm_ex(InferenceContext::Instance().GetCublasHandle(), + (transposed_mode ? CUBLAS_OP_T : CUBLAS_OP_N), + CUBLAS_OP_N, + intm_dim, + bsz, + input.size(2), + &alpha, + &gemm_beta, + (T*)weight.data_ptr(), + (T*)input.data_ptr(), + (T*)intermediate.data_ptr(), +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) + rocblas_gemm_algo_standard); +#else + CUBLAS_GEMM_DEFAULT_TENSOR_OP); +#endif + } + launch_bias_gelu((T*)intermediate.data_ptr(), + (T*)bias.data_ptr(), + intm_dim, + bsz, + InferenceContext::Instance().GetCurrentStream()); + + int out_size = (transposed_mode || q_int8) ? weight_out.size(0) : weight_out.size(1); + auto output = at::empty({input.size(0), input.size(1), out_size}, options); + if (q_int8) { + quantized_gemm(output.data_ptr(), + (T*)intermediate.data_ptr(), + weight_out, + weight_out_scale, + weight_out_scale.size(0), + bsz, + input.size(2)); + } else { + cublas_gemm_ex(InferenceContext::Instance().GetCublasHandle(), + (transposed_mode ? CUBLAS_OP_T : CUBLAS_OP_N), + CUBLAS_OP_N, + out_size, + bsz, + intm_dim, + &alpha, + &gemm_beta, + (T*)weight_out.data_ptr(), + (T*)intermediate.data_ptr(), + (T*)output.data_ptr(), +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) + rocblas_gemm_algo_standard); +#else + CUBLAS_GEMM_DEFAULT_TENSOR_OP); +#endif + } + // cudaEventRecord(InferenceContext::Instance().GetCompEvent(2), + // InferenceContext::Instance().GetCurrentStream(true)); + return output; +} + +template +at::Tensor& residual_add_bias(at::Tensor& hidden_state, + at::Tensor& residual, + const at::Tensor& attention_output, + const at::Tensor& attention_bias, + const at::Tensor& final_bias, + const int mp_size, + const bool mlp_after_attn, + const bool add_bias, + const bool preln) +{ + int bsz = residual.size(0) * residual.size(1); + int hidden_size = residual.size(2); + if (mlp_after_attn) + launch_bias_residual(static_cast(residual.data_ptr()), + static_cast(hidden_state.data_ptr()), + static_cast(attention_output.data_ptr()), + static_cast(final_bias.data_ptr()), + static_cast(attention_bias.data_ptr()), + bsz, + hidden_size, + mp_size, + preln, + InferenceContext::Instance().GetCurrentStream()); + else + launch_gptj_residual_add( + static_cast(residual.data_ptr()), + static_cast(hidden_state.data_ptr()), + static_cast(attention_output.data_ptr()), + static_cast(final_bias.data_ptr()), + static_cast((add_bias ? attention_bias.data_ptr() : nullptr)), + hidden_size, + bsz, + mp_size, + InferenceContext::Instance().GetCurrentStream()); + return residual; +} + +#define DISPATCH_VECTOR_ADD(T_TYPE, C_TYPE) \ + if (a.scalar_type() == at::k##T_TYPE) { \ + launch_vector_add((C_TYPE*)(a.data_ptr()), \ + (const C_TYPE*)(a.data_ptr()), \ + (const C_TYPE*)(b.data_ptr()), \ + gamma, \ + total_elems, \ + InferenceContext::Instance().GetCurrentStream()); \ + } + +at::Tensor& _vector_add(at::Tensor& a, at::Tensor& b, float gamma) +{ + const int total_elems = a.numel(); + + DISPATCH_VECTOR_ADD(Float, float) + DISPATCH_VECTOR_ADD(Half, __half) +#ifdef BF16_AVAILABLE + DISPATCH_VECTOR_ADD(BFloat16, __nv_bfloat16) +#endif + + return a; +} + +std::vector apply_rotary_pos_emb(at::Tensor& mixed_query, + at::Tensor& key_layer, + unsigned rotary_dim, + unsigned offset, + unsigned num_heads, + bool rotate_half, + float rope_theta) +{ + auto query_cont = mixed_query.contiguous(); + auto key_cont = key_layer.contiguous(); + + unsigned bsz = mixed_query.size(0); + unsigned head_size = mixed_query.size(2) / num_heads; + unsigned seq_len = mixed_query.size(1); + + if (mixed_query.scalar_type() == at::kFloat) + launch_apply_rotary_pos_emb((float*)query_cont.data_ptr(), + (float*)key_cont.data_ptr(), + head_size, + seq_len, + rotary_dim, + offset, + num_heads, + bsz, + rope_theta, + InferenceContext::Instance().GetCurrentStream(), + InferenceContext::Instance().GetMaxTokenLength()); + else + launch_apply_rotary_pos_emb<__half>((__half*)query_cont.data_ptr(), + (__half*)key_cont.data_ptr(), + head_size, + seq_len, + rotary_dim, + offset, + num_heads, + bsz, + rope_theta, + InferenceContext::Instance().GetCurrentStream(), + InferenceContext::Instance().GetMaxTokenLength()); + return {query_cont, key_cont}; +} + +#define DISPATCH_MOE_RESIDUAL(T_TYPE, C_TYPE) \ + if (moe_res.scalar_type() == torch::T_TYPE) { \ + launch_moe_res_matmul((C_TYPE*)moe_res.data_ptr(), \ + (C_TYPE*)coef.data_ptr(), \ + (C_TYPE*)output.data_ptr(), \ + M, \ + N, \ + InferenceContext::Instance().GetCurrentStream()); \ + } + +at::Tensor moe_res_matmul(at::Tensor& moe_res, at::Tensor& coef, at::Tensor& output) +{ + int M = moe_res.size(0) * moe_res.size(1); + int N = moe_res.size(2); + InferenceContext::Instance().SynchComm(); + + DISPATCH_MOE_RESIDUAL(kFloat, float) + DISPATCH_MOE_RESIDUAL(kHalf, __half) +#ifdef BF16_AVAILABLE + DISPATCH_MOE_RESIDUAL(kBFloat16, __nv_bfloat16) +#endif + + return output; +} + +void ds_release_workspace() { InferenceContext::Instance().release_workspace(); } + +bool ds_retake_workspace() { return InferenceContext::Instance().retake_workspace(); } + +template +at::Tensor ds_dequantize(at::Tensor& weight, at::Tensor& qscale, int groups) +{ + auto options = at::TensorOptions() + .dtype(torch::kFloat16) + .layout(at::kStrided) + .device(at::kCUDA) + .requires_grad(false); + auto weight16 = at::empty({weight.size(0), weight.size(1)}, options); + + launch_dequantize((T*)weight16.data_ptr(), + (int8_t*)weight.data_ptr(), + (float*)qscale.data_ptr(), + weight.size(0), + weight.size(1), + groups, + InferenceContext::Instance().GetCurrentStream()); + + return weight16; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("softmax_context_int8", + &ds_softmax_context1<__half>, + "DeepSpeed attention with int8 (CUDA)"); + + // The following functions handle type dispatching internally + m.def("gated_activation", &ds_gated_activation, "DeepSpeed Bias GEGLU (CUDA)"); + m.def("layer_norm", &ds_layer_norm, "DeepSpeed layer norm (CUDA)"); + m.def( + "_layer_norm_residual", &ds_layer_norm_residual, "DeepSpeed layer norm + residual (CUDA)"); + m.def("layer_norm_residual_store_pre_ln_res", + &ds_layer_norm_residual_store_pre_ln_res, + "DeepSpeed layer norm + store pre Layernorm residual (CUDA)"); + m.def("rms_norm", &ds_rms_norm, "DeepSpeed rms norm (CUDA)"); + m.def("pre_rms_norm", &ds_pre_rms_norm, "DeepSpeed pre rms norm (CUDA)"); + m.def("_vector_add", &_vector_add, "DeepSpeed vector add (CUDA)"); + m.def("apply_rotary_pos_emb", &apply_rotary_pos_emb, "DeepSpeed mlp with fp16 (CUDA)"); + m.def("moe_res_matmul", &moe_res_matmul, "DeepSpeed moe residual matmul (CUDA)"); + m.def("reset_cache", &reset_cache, "Reset Cache for generation tasks"); + m.def("release_workspace", &ds_release_workspace, "DeepSpeed Release Workspace"); + m.def("retake_workspace", &ds_retake_workspace, "DeepSpeed Retake Workspace"); + + // The following functions are templated and need to be explicitly instantiated and bound + // to different python methods +#define DEF_OPS(_name, _dtype) \ + m.def("softmax_" #_name, &ds_softmax<_dtype>, "DeepSpeed SoftMax with " #_name " (CUDA)"); \ + m.def("softmax_context_" #_name, \ + &ds_softmax_context<_dtype>, \ + "DeepSpeed attention with " #_name " (CUDA)"); \ + m.def("bias_gelu_" #_name, &ds_bias_gelu<_dtype>, "DeepSpeed Gelu with " #_name " (CUDA)"); \ + m.def("bias_add_" #_name, &ds_bias_add<_dtype>, "DeepSpeed Bias Add with " #_name " (CUDA)"); \ + m.def("bias_relu_" #_name, &ds_bias_relu<_dtype>, "DeepSpeed ReLU with " #_name " (CUDA)"); \ + m.def("bias_residual_" #_name, \ + &ds_bias_residual<_dtype>, \ + "DeepSpeed residual-bias add with " #_name " (CUDA)"); \ + m.def("qkv_gemm_" #_name, &ds_qkv_gemm<_dtype>, "DeepSpeed qkv gemm with " #_name " (CUDA)"); \ + m.def("rms_qkv_gemm_" #_name, \ + &ds_rms_qkv<_dtype>, \ + "DeepSpeed rms qkv gemm with " #_name " (CUDA)"); \ + m.def("mlp_gemm_" #_name, &ds_mlp_gemm<_dtype>, "DeepSpeed mlp with " #_name " (CUDA)"); \ + m.def("rms_mlp_gemm_" #_name, \ + &ds_rms_mlp_gemm<_dtype>, \ + "DeepSpeed rms mlp gemm with " #_name " (CUDA)"); \ + m.def("vector_matmul_" #_name, \ + &ds_vector_matmul<_dtype>, \ + "DeepSpeed vector-MM with " #_name " (CUDA)"); \ + m.def("linear_layer_" #_name, \ + &ds_linear_layer<_dtype>, \ + "DeepSpeed linear_layer with " #_name " (CUDA)"); \ + m.def("fused_gemm_gelu_" #_name, \ + &fused_gemm_gelu<_dtype>, \ + "DeepSpeed mlp with " #_name " (CUDA)"); \ + m.def("residual_add_bias_" #_name, \ + &residual_add_bias<_dtype>, \ + "DeepSpeed residual add with " #_name " (CUDA)"); \ + m.def("einsum_sec_sm_ecm_" #_name, \ + &einsum_sec_sm_ecm<_dtype>, \ + "DeepSpeed vector-MM with " #_name " (CUDA)"); \ + m.def("add_padding_" #_name, \ + &add_padding<_dtype>, \ + "DeepSpeed residual add with " #_name " (CUDA)"); \ + m.def("pad_transform_" #_name, \ + &padd_add_transform<_dtype>, \ + "DeepSpeed residual add with " #_name " (CUDA)"); \ + m.def("allocate_workspace_" #_name, \ + &allocate_workspace<_dtype>, \ + "DeepSpeed memory allocation for GPT inference with " #_name " (CUDA)"); \ + m.def("dequantize_" #_name, \ + &ds_dequantize<_dtype>, \ + "DeepSpeed dequantize with " #_name " (CUDA)"); + + DEF_OPS(fp32, float); + DEF_OPS(fp16, __half); +#ifdef BF16_AVAILABLE + DEF_OPS(bf16, __nv_bfloat16); +#endif +} diff --git a/toolbox/DeepSpeed/v0.15.3/patches/csrc/transformer/inference/csrc/relu.cu b/toolbox/DeepSpeed/v0.15.3/patches/csrc/transformer/inference/csrc/relu.cu new file mode 100644 index 0000000000000000000000000000000000000000..66cd9439c95946e644756650a78d22c07afc3446 --- /dev/null +++ b/toolbox/DeepSpeed/v0.15.3/patches/csrc/transformer/inference/csrc/relu.cu @@ -0,0 +1,88 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. */ +/* All Rights Reserved. */ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "conversion_utils.h" +#include "inference_cuda_layers.h" +#include "memory_access_utils.h" + +namespace cg = cooperative_groups; +#define MAX_CAP 4 +#define MAX_SEQ 2048 + +inline __device__ float relu(const float x) { return x < 0 ? 0 : x; } + +/* +In-place relu(biasAdd(x)) for channels last +*/ +template +__global__ void fused_bias_relu(T* input, const T* bias, int total_count, int intermediate_size) +{ + // Input restriction: intermediate_size % vals_per_access == 0 + constexpr int granularity = 16; + constexpr int values_per_access = granularity / sizeof(T); + const int offset = (blockIdx.x * blockDim.x + threadIdx.x) * values_per_access; + + if (offset < total_count) { + T data[values_per_access]; + T data_bias[values_per_access]; + mem_access::load_global(data, input + offset); + mem_access::load_global( + data_bias, bias + (offset % intermediate_size), bias != nullptr); + +#pragma unroll + for (int i = 0; i < values_per_access; i++) { + float data_f = conversion::to(data[i]); + float bias_f = conversion::to(data_bias[i]); + data[i] = conversion::to(relu(data_f + bias_f)); + } + + mem_access::store_global(input + offset, data); + } +} + +template +void launch_bias_relu(T* input, + const T* bias, + int intermediate_size, + int batch_size, + cudaStream_t stream) +{ + constexpr int threads = 1024; + constexpr int granularity = 16; + + const int total_count = batch_size * intermediate_size; + const int elems_per_block = threads * (granularity / sizeof(T)); + dim3 block_dims(threads); + dim3 grid_dims((total_count + elems_per_block - 1) / elems_per_block); + + fused_bias_relu<<>>( + input, bias, total_count, intermediate_size); +} + +#define INSTANTIATE_LAUNCH_BIAS_RELU(T) \ + template void launch_bias_relu(T*, const T*, int, int, cudaStream_t); + +INSTANTIATE_LAUNCH_BIAS_RELU(float) +#ifdef BF16_AVAILABLE +INSTANTIATE_LAUNCH_BIAS_RELU(__nv_bfloat16) +#endif +INSTANTIATE_LAUNCH_BIAS_RELU(__half) diff --git a/toolbox/DeepSpeed/v0.15.3/patches/csrc/transformer/inference/csrc/rms_norm.cu b/toolbox/DeepSpeed/v0.15.3/patches/csrc/transformer/inference/csrc/rms_norm.cu new file mode 100644 index 0000000000000000000000000000000000000000..a8e6eb4a5ca64b15620d71c17f9398593df01b1c --- /dev/null +++ b/toolbox/DeepSpeed/v0.15.3/patches/csrc/transformer/inference/csrc/rms_norm.cu @@ -0,0 +1,280 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. */ +/* All Rights Reserved. */ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "conversion_utils.h" +#include "ds_kernel_utils.h" +#include "inference_cuda_layers.h" +#include "memory_access_utils.h" +#include "reduction_utils.h" + +namespace cg = cooperative_groups; +using rop = reduce::ROpType; + +namespace rms { +constexpr int granularity = 16; +} // namespace rms + +template +__global__ void rms_norm(T* output, const T* vals, const T* gamma, float epsilon, int elems_per_row) +{ + constexpr int T_per_load = rms::granularity / sizeof(T); + + cg::thread_block tb = cg::this_thread_block(); + cg::thread_block_tile warp = cg::tiled_partition(tb); + + // X-dimension of the block + const int block_offset = (tb.group_index().x * (maxThreads / threadsPerGroup) * elems_per_row) + + (tb.thread_index().y * elems_per_row); + const int thread_offset = tb.thread_index().x * T_per_load; + const int base_offset = block_offset + thread_offset; + const int stride = blockDim.x * T_per_load; + + float var_sum = reduce::init(); + + const T* input_base = vals + base_offset; + + T local_buffer[UNROLL * T_per_load]; + +#pragma unroll + for (int i = 0; i < UNROLL; i++) { + T* iteration_buffer = local_buffer + (i * T_per_load); + + mem_access::load_global(iteration_buffer, + input_base + (i * stride), + thread_offset + (i * stride) < elems_per_row); + +#pragma unroll + for (int j = 0; j < T_per_load; j++) { + float up_cast = conversion::to(iteration_buffer[j]); + float sq_val = up_cast * up_cast; + var_sum = reduce::element(var_sum, sq_val); + } + } + + reduce::partitioned_block(tb, warp, var_sum); + const float var = var_sum / elems_per_row; + const T denom = conversion::to(__frsqrt_rn(var + epsilon)); + + T* block_output = output + block_offset; + +#pragma unroll + for (int i = 0; i < UNROLL; i++) { + T* iteration_buffer = local_buffer + (i * T_per_load); + const int iter_idx = i * stride + thread_offset; + const bool do_loads = (iter_idx < elems_per_row); + + T gamma_local[T_per_load]; + + mem_access::load_global(gamma_local, gamma + iter_idx, do_loads); + +#pragma unroll + for (int j = 0; j < T_per_load; j++) { + iteration_buffer[j] *= denom; + iteration_buffer[j] *= gamma_local[j]; + } + + if (do_loads) { + mem_access::store_global(block_output + iter_idx, iteration_buffer); + } + } +} + +template +__global__ void pre_rms_norm(T* output, + T* res_out, + const T* vals, + const T* residual, + const T* gamma, + float epsilon, + int elems_per_row) +{ + constexpr int T_per_load = rms::granularity / sizeof(T); + + cg::thread_block tb = cg::this_thread_block(); + cg::thread_block_tile warp = cg::tiled_partition(tb); + + // X-dimension of the block + const int block_offset = (tb.group_index().x * (maxThreads / threadsPerGroup) * elems_per_row) + + (tb.thread_index().y * elems_per_row); + const int thread_offset = tb.thread_index().x * T_per_load; + const int base_offset = block_offset + thread_offset; + const int stride = blockDim.x * T_per_load; + + float var_sum = reduce::init(); + + const T* input_base = vals + base_offset; + const T* residual_base = residual + base_offset; + T* res_output = res_out + base_offset; + + T local_buffer[UNROLL * T_per_load]; + +#pragma unroll + for (int i = 0; i < UNROLL; i++) { + T* iteration_buffer = local_buffer + (i * T_per_load); + T residual_buffer[T_per_load]; + + const int iter_offset = i * stride + thread_offset; + const bool do_loads = (iter_offset < elems_per_row); + + mem_access::load_global( + iteration_buffer, input_base + (i * stride), do_loads); + mem_access::load_global( + residual_buffer, residual_base + (i * stride), do_loads); + +#pragma unroll + for (int j = 0; j < T_per_load; j++) { + iteration_buffer[j] += residual_buffer[j]; + float vals_up_cast = conversion::to(iteration_buffer[j]); + + var_sum = reduce::element(var_sum, vals_up_cast * vals_up_cast); + } + + if (do_loads) { + mem_access::store_global(res_output + i * stride, iteration_buffer); + } + } + + reduce::partitioned_block(tb, warp, var_sum); + const float var = var_sum / elems_per_row; + const T denom = conversion::to(__frsqrt_rn(var + epsilon)); + + T* block_output = output + block_offset; + +#pragma unroll + for (int i = 0; i < UNROLL; i++) { + T* iteration_buffer = local_buffer + (i * T_per_load); + const int iter_idx = i * stride + thread_offset; + const bool do_loads = (iter_idx < elems_per_row); + + T gamma_local[T_per_load]; + + mem_access::load_global(gamma_local, gamma + iter_idx, do_loads); + +#pragma unroll + for (int j = 0; j < T_per_load; j++) { + iteration_buffer[j] *= denom; + iteration_buffer[j] *= gamma_local[j]; + } + + if (do_loads) { + mem_access::store_global(block_output + iter_idx, iteration_buffer); + } + } +} + +#define LAUNCH_RMS_NORM(UNROLL, threadsPerGroup, maxThreads) \ + rms_norm \ + <<>>(norm_output, vals, gamma, epsilon, elems_per_row); + +#define LAUNCH_PRE_RMS_NORM(UNROLL, threadsPerGroup, maxThreads) \ + pre_rms_norm<<>>( \ + norm_output, res_output, vals, residual, gamma, epsilon, elems_per_row); + +#define LAUNCH_ALL_RMS_NORM(UNROLL, threadsPerGroup, maxThreads) \ + if (pre_norm) { \ + LAUNCH_PRE_RMS_NORM(UNROLL, threadsPerGroup, maxThreads) \ + } else { \ + LAUNCH_RMS_NORM(UNROLL, threadsPerGroup, maxThreads) \ + } + +template +void launch_rms_norm(T* norm_output, + T* res_output, + const T* vals, + const T* residual, + const T* gamma, + float epsilon, + int rows, + int elems_per_row, + cudaStream_t stream) +{ + // 8 for __half, 4 for float + constexpr int T_per_load = rms::granularity / sizeof(T); + constexpr int maxThreads = 256; + constexpr int internalUnroll = sizeof(T) == 4 ? 4 : 2; + + const bool is_subblock_schedule = (elems_per_row <= 128) ? true : false; + const int h_per_step = is_subblock_schedule ? T_per_load : T_per_load * internalUnroll; + + // Scheduling concern: may be slightly faster for some inputs to assign multiple stages of + // warp-sized blocks rather than stepping up to 64/96 threads + const int one_step_threads = next_pow2((elems_per_row + h_per_step - 1) / h_per_step); + const int threads_per_group = (one_step_threads < maxThreads) ? one_step_threads : maxThreads; + + const int groups_per_block_max = + is_subblock_schedule ? (maxThreads + threads_per_group - 1) / threads_per_group : 1; + const int groups_per_block = (rows < groups_per_block_max) ? rows : groups_per_block_max; + const int groups_launch = (groups_per_block + rows - 1) / groups_per_block; + + dim3 block(threads_per_group, groups_per_block); + dim3 grid(groups_launch); + + const int elems_per_step = threads_per_group * h_per_step; + const int external_unRoll = (elems_per_row + elems_per_step - 1) / elems_per_step; + + bool pre_norm = (residual == nullptr) ? false : true; + + if (is_subblock_schedule) { + // <=128 + if (threads_per_group == 1) { + LAUNCH_ALL_RMS_NORM(1, 1, maxThreads); + } else if (threads_per_group == 2) { + LAUNCH_ALL_RMS_NORM(1, 2, maxThreads); + } else if (threads_per_group == 4) { + LAUNCH_ALL_RMS_NORM(1, 4, maxThreads); + } else if (threads_per_group == 8) { + LAUNCH_ALL_RMS_NORM(1, 8, maxThreads); + } else if (threads_per_group == 16) { + LAUNCH_ALL_RMS_NORM(1, 16, maxThreads); + } + } else if (external_unRoll == 1) { + // 129 - 4096 elems + // (this can launch with 1-7 warps as well) + LAUNCH_ALL_RMS_NORM(1 * internalUnroll, maxThreads, maxThreads); + } else if (external_unRoll == 2) { + // 4097 - 8192 elems + LAUNCH_ALL_RMS_NORM(2 * internalUnroll, maxThreads, maxThreads); + } else if (external_unRoll == 3) { + // 8193 - 12288 elems + LAUNCH_ALL_RMS_NORM(3 * internalUnroll, maxThreads, maxThreads); + } else if (external_unRoll == 4) { + // 12289 - 16384 elems + LAUNCH_ALL_RMS_NORM(4 * internalUnroll, maxThreads, maxThreads); + } +} + +#define INSTANTIATE_LAUNCH_RMS_NORM(T) \ + template void launch_rms_norm(T * norm_output, \ + T * res_output, \ + const T* vals, \ + const T* residual, \ + const T* gamma, \ + float epsilon, \ + int rows, \ + int elems_per_row, \ + cudaStream_t stream); + +INSTANTIATE_LAUNCH_RMS_NORM(float) +INSTANTIATE_LAUNCH_RMS_NORM(__half) +#ifdef BF16_AVAILABLE +INSTANTIATE_LAUNCH_RMS_NORM(__nv_bfloat16) +#endif diff --git a/toolbox/DeepSpeed/v0.15.3/patches/csrc/transformer/inference/csrc/softmax.cu b/toolbox/DeepSpeed/v0.15.3/patches/csrc/transformer/inference/csrc/softmax.cu new file mode 100644 index 0000000000000000000000000000000000000000..1eefd616e97baa034f11e3d31b5572778be2b516 --- /dev/null +++ b/toolbox/DeepSpeed/v0.15.3/patches/csrc/transformer/inference/csrc/softmax.cu @@ -0,0 +1,579 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. */ +/* All Rights Reserved. */ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include +#include "conversion_utils.h" +#include "inference_cuda_layers.h" + +#ifndef __HIP_PLATFORM_AMD__ +#include +#endif +#include +#include +#include + +#define MAX_REG_SIZE 8 + +#define minus_infinity -10000.0 + +void CheckCudaErrorAux(const char* file, unsigned line) +{ + cudaError_t err = cudaGetLastError(); + if (err == cudaSuccess) return; + std::cerr << cudaGetErrorString(err) << "(" << err << ") at " << file << ":" << line + << std::endl; + throw std::runtime_error("CUDA ERROR!!!\n"); +} + +#define CUDA_CHECK_ERROR() CheckCudaErrorAux(__FILE__, __LINE__) + +namespace cg = cooperative_groups; + +template +__global__ void attn_softmax_v2(T* vals, + T* mask, + T* alibi, + float layer_scale, + bool triangular, + bool recompute, + bool local_attention, + int window_size, + int total_count, + int heads, + int sequence_length, + int num_seq, + int head_offset, + int mask_stride, + int mp_size, + int reduceWidth) +{ + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + float2 low_data[MAX_REG_SIZE]; + float2 high_data[MAX_REG_SIZE]; + const T zero_h = conversion::to(0.f); + + int wid = threadIdx.x >> 5; + int lane = threadIdx.x & 0x1f; + int warp_num = blockDim.x >> 5; + + int reduce_blocks = reduceWidth >> 5; + int seq_lane = threadIdx.x % reduceWidth; + + __shared__ float partialSum[MAX_WARP_NUM]; + + int iter_offset = blockIdx.x * (warp_num / reduce_blocks) + (wid / reduce_blocks); + int batch_idx = iter_offset / (num_seq * heads); + int alibi_offset = batch_idx * heads * mp_size + head_offset; + int mask_offset = batch_idx * mask_stride + (iter_offset % mask_stride); + + if (iter_offset < total_count) { + vals += (iter_offset * sequence_length); + + alibi_offset = (alibi_offset + ((iter_offset / num_seq) % heads)) * sequence_length; + mask_offset = mask_offset * sequence_length; + int seq_id = iter_offset % num_seq; + + int real_seq_id = seq_id + (num_seq == sequence_length ? 0 : sequence_length); + int window_stride4 = (local_attention && (real_seq_id >> 2) > (window_size >> 2)) + ? (real_seq_id >> 2) - (window_size >> 2) + : 0; + int window_stride = + (local_attention && real_seq_id >= window_size) ? real_seq_id - window_size : -1; + + float max_val = minus_infinity; + // if (lane == 0) printf("%d, %d: %d \n", wid, blockIdx.x, mask_offset); + for (int i = 0; i < iterations; i++) { + int data_id = i * (reduceWidth << 2) + (seq_lane); + bool check = (data_id >> 2) >= window_stride4; + bool low_x_check = check && (data_id < sequence_length) && + (!triangular || (data_id <= seq_id)) && (data_id > window_stride); + bool low_y_check = check && ((data_id + reduceWidth) < sequence_length) && + (!triangular || ((data_id + reduceWidth) <= seq_id)) && + ((data_id + reduceWidth) > window_stride); + bool high_x_check = check && ((data_id + reduceWidth * 2) < sequence_length) && + (!triangular || ((data_id + reduceWidth * 2) <= seq_id)) && + ((data_id + reduceWidth * 2) > window_stride); + bool high_y_check = check && ((data_id + reduceWidth * 3) < sequence_length) && + (!triangular || ((data_id + reduceWidth * 3) <= seq_id)) && + ((data_id + reduceWidth * 3) > window_stride); + + if (mask && alibi) { + low_data[i].x = low_x_check + ? conversion::to(vals[data_id]) * layer_scale + + (conversion::to(alibi[data_id + alibi_offset])) + + (conversion::to(mask[data_id + mask_offset])) + : minus_infinity; + low_data[i].y = + low_y_check + ? conversion::to(vals[data_id + reduceWidth]) * layer_scale + + (conversion::to(alibi[data_id + alibi_offset + reduceWidth])) + + (conversion::to(mask[data_id + mask_offset + reduceWidth])) + : minus_infinity; + high_data[i].x = + high_x_check + ? conversion::to(vals[data_id + reduceWidth * 2]) * layer_scale + + (conversion::to( + alibi[data_id + alibi_offset + reduceWidth * 2])) + + (conversion::to(mask[data_id + mask_offset + reduceWidth * 2])) + : minus_infinity; + high_data[i].y = + high_y_check + ? conversion::to(vals[data_id + reduceWidth * 3]) * layer_scale + + (conversion::to( + alibi[data_id + alibi_offset + reduceWidth * 3])) + + (conversion::to(mask[data_id + mask_offset + reduceWidth * 3])) + : minus_infinity; + } else if (mask) { + low_data[i].x = low_x_check + ? conversion::to(vals[data_id]) * layer_scale + + (conversion::to(mask[data_id + mask_offset])) + : minus_infinity; + low_data[i].y = + low_y_check + ? conversion::to(vals[data_id + reduceWidth]) * layer_scale + + (conversion::to(mask[data_id + mask_offset + reduceWidth])) + : minus_infinity; + high_data[i].x = + high_x_check + ? conversion::to(vals[data_id + reduceWidth * 2]) * layer_scale + + (conversion::to(mask[data_id + mask_offset + reduceWidth * 2])) + : minus_infinity; + high_data[i].y = + high_y_check + ? conversion::to(vals[data_id + reduceWidth * 3]) * layer_scale + + (conversion::to(mask[data_id + mask_offset + reduceWidth * 3])) + : minus_infinity; + } else if (alibi) { + low_data[i].x = low_x_check + ? conversion::to(vals[data_id]) * layer_scale + + (conversion::to(alibi[data_id + alibi_offset])) + : minus_infinity; + low_data[i].y = + low_y_check + ? conversion::to(vals[data_id + reduceWidth]) * layer_scale + + (conversion::to(alibi[data_id + alibi_offset + reduceWidth])) + : minus_infinity; + high_data[i].x = + high_x_check + ? conversion::to(vals[data_id + reduceWidth * 2]) * layer_scale + + (conversion::to( + alibi[data_id + alibi_offset + reduceWidth * 2])) + : minus_infinity; + high_data[i].y = + high_y_check + ? conversion::to(vals[data_id + reduceWidth * 3]) * layer_scale + + (conversion::to( + alibi[data_id + alibi_offset + reduceWidth * 3])) + : minus_infinity; + } else { + low_data[i].x = low_x_check ? conversion::to(vals[data_id]) * layer_scale + : minus_infinity; + low_data[i].y = + low_y_check ? conversion::to(vals[data_id + reduceWidth]) * layer_scale + : minus_infinity; + high_data[i].x = + high_x_check + ? conversion::to(vals[data_id + reduceWidth * 2]) * layer_scale + : minus_infinity; + high_data[i].y = + high_y_check + ? conversion::to(vals[data_id + reduceWidth * 3]) * layer_scale + : minus_infinity; + } + + // if(lane == 0) printf("%f , %d, %d \n", low_data[i].x, data_id, seq_id); + max_val = (low_data[i].x > max_val ? low_data[i].x : max_val); + max_val = (low_data[i].y > max_val ? low_data[i].y : max_val); + max_val = (high_data[i].x > max_val ? high_data[i].x : max_val); + max_val = (high_data[i].y > max_val ? high_data[i].y : max_val); + } + + for (int i = 1; i < WARP_SIZE; i *= 2) { + auto temp = g.shfl_xor(max_val, i); + max_val = (temp > max_val ? temp : max_val); + } + + if (reduceWidth > WARP_SIZE) { + if (lane == 0) partialSum[wid] = max_val; + b.sync(); + + if (lane < warp_num) max_val = partialSum[lane]; + + b.sync(); + + for (int i = 1; i < reduce_blocks; i *= 2) { + auto temp = g.shfl_xor(max_val, i); + max_val = (temp > max_val ? temp : max_val); + } + + max_val = g.shfl(max_val, threadIdx.x / WARP_SIZE); + } + float sum = 0; + for (int i = 0; i < iterations; i++) { + low_data[i].x = __expf(low_data[i].x - max_val); + low_data[i].y = __expf(low_data[i].y - max_val); + high_data[i].x = __expf(high_data[i].x - max_val); + high_data[i].y = __expf(high_data[i].y - max_val); + + sum += (low_data[i].x + low_data[i].y + high_data[i].x + high_data[i].y); + } + + for (int i = 1; i < WARP_SIZE; i *= 2) sum += g.shfl_xor(sum, i); + + if (reduceWidth > WARP_SIZE) { + if (lane == 0) partialSum[wid] = sum; + b.sync(); + + if (lane < warp_num) sum = partialSum[lane]; + + b.sync(); + + for (int i = 1; i < reduce_blocks; i *= 2) { sum += g.shfl_xor(sum, i); } + + sum = g.shfl(sum, threadIdx.x / WARP_SIZE); + } + sum += 1e-6; + for (int i = 0; i < iterations; i++) { + int data_id = i * (reduceWidth << 2) + (seq_lane); + if (data_id < sequence_length) { + vals[data_id] = conversion::to(low_data[i].x / sum); + if ((data_id + reduceWidth) < sequence_length) + vals[data_id + reduceWidth] = conversion::to(low_data[i].y / sum); + if ((data_id + reduceWidth * 2) < sequence_length) + vals[data_id + reduceWidth * 2] = conversion::to(high_data[i].x / sum); + if ((data_id + reduceWidth * 3) < sequence_length) + vals[data_id + reduceWidth * 3] = conversion::to(high_data[i].y / sum); + } + } + } +} + +template +__global__ void attn_softmax_v2(float* vals, + float* attn_mask, + float* alibi, + float layer_scale, + bool triangular, + bool recompute, + bool local_attention, + int window_size, + int total_count, + int heads, + int sequence_length, + int num_seq, + int head_offset, + int mask_stride, + int mp_size, + int reduceWidth) +{ + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + float4 data[MAX_REG_SIZE]; + + int wid = threadIdx.x >> 5; + int lane = threadIdx.x & 0x1f; + int warp_num = blockDim.x >> 5; + + int reduce_blocks = reduceWidth >> 5; + int seq_lane = threadIdx.x % reduceWidth; + + __shared__ float partialSum[MAX_WARP_NUM]; + + int iter_offset = blockIdx.x * (warp_num / reduce_blocks) + (wid / reduce_blocks); + if (iter_offset < total_count) { + vals += (iter_offset * sequence_length); + + int batch_idx = iter_offset / (num_seq * heads); + int mask_offset = batch_idx * mask_stride + (iter_offset % mask_stride); + mask_offset = mask_offset * sequence_length; + int seq_id = iter_offset % num_seq; + + int real_seq_id = seq_id + (num_seq == sequence_length ? 0 : sequence_length); + int window_stride4 = (local_attention && (real_seq_id >> 2) > (window_size >> 2)) + ? (real_seq_id >> 2) - (window_size >> 2) + : 0; + int window_stride = + (local_attention && real_seq_id >= window_size) ? real_seq_id - window_size : -1; + + float max_val = minus_infinity; + + for (int i = 0; i < iterations; i++) { + int data_id = i * (reduceWidth << 2) + (seq_lane); + bool check = (data_id >> 2) >= window_stride4; + bool x_check = check && (data_id < sequence_length) && + (!triangular || (data_id <= seq_id)) && (data_id > window_stride); + bool y_check = check && ((data_id + reduceWidth) < sequence_length) && + (!triangular || ((data_id + reduceWidth) <= seq_id)) && + ((data_id + reduceWidth) > window_stride); + bool z_check = check && ((data_id + reduceWidth * 2) < sequence_length) && + (!triangular || ((data_id + reduceWidth * 2) <= seq_id)) && + ((data_id + reduceWidth * 2) > window_stride); + bool w_check = check && ((data_id + reduceWidth * 3) < sequence_length) && + (!triangular || ((data_id + reduceWidth * 3) <= seq_id)) && + ((data_id + reduceWidth * 3) > window_stride); + + if (attn_mask) { + data[i].x = x_check ? vals[data_id] + attn_mask[data_id + mask_offset] + : minus_infinity; + data[i].y = y_check ? vals[data_id + reduceWidth] + + attn_mask[data_id + mask_offset + reduceWidth] + : minus_infinity; + data[i].z = z_check ? vals[data_id + reduceWidth * 2] + + attn_mask[data_id + mask_offset + reduceWidth * 2] + : minus_infinity; + data[i].w = w_check ? vals[data_id + reduceWidth * 3] + + attn_mask[data_id + mask_offset + reduceWidth * 3] + : minus_infinity; + } else { + data[i].x = x_check ? vals[data_id] : minus_infinity; + data[i].y = y_check ? vals[data_id + reduceWidth] : minus_infinity; + data[i].z = z_check ? vals[data_id + reduceWidth * 2] : minus_infinity; + data[i].w = w_check ? vals[data_id + reduceWidth * 3] : minus_infinity; + } + + max_val = (data[i].x > max_val ? data[i].x : max_val); + max_val = (data[i].y > max_val ? data[i].y : max_val); + max_val = (data[i].z > max_val ? data[i].z : max_val); + max_val = (data[i].w > max_val ? data[i].w : max_val); + } + + for (int i = 1; i < WARP_SIZE; i *= 2) { + auto temp = g.shfl_xor(max_val, i); + max_val = (temp > max_val ? temp : max_val); + } + + if (reduceWidth > WARP_SIZE) { + if (lane == 0) partialSum[wid] = max_val; + b.sync(); + + if (lane < warp_num) max_val = partialSum[lane]; + + b.sync(); + + for (int i = 1; i < reduce_blocks; i *= 2) { + auto temp = g.shfl_xor(max_val, i); + max_val = (temp > max_val ? temp : max_val); + } + + max_val = g.shfl(max_val, threadIdx.x / WARP_SIZE); + } + + float sum = 0; + for (int i = 0; i < iterations; i++) { + data[i].x = __expf(data[i].x - max_val); + data[i].y = __expf(data[i].y - max_val); + data[i].z = __expf(data[i].z - max_val); + data[i].w = __expf(data[i].w - max_val); + + sum += (data[i].x + data[i].y + data[i].z + data[i].w); + } + + for (int i = 1; i < WARP_SIZE; i *= 2) sum += g.shfl_xor(sum, i); + + if (reduceWidth > WARP_SIZE) { + if (lane == 0) partialSum[wid] = sum; + b.sync(); + + if (lane < warp_num) sum = partialSum[lane]; + + b.sync(); + + for (int i = 1; i < reduce_blocks; i *= 2) { sum += g.shfl_xor(sum, i); } + + sum = g.shfl(sum, threadIdx.x / WARP_SIZE); + } + sum += 1e-6; + + for (int i = 0; i < iterations; i++) { + int data_id = i * (reduceWidth << 2) + (seq_lane); + if (data_id < sequence_length) { + vals[data_id] = data[i].x / sum; + if ((data_id + reduceWidth) < sequence_length) + vals[data_id + reduceWidth] = data[i].y / sum; + if ((data_id + reduceWidth * 2) < sequence_length) + vals[data_id + reduceWidth * 2] = data[i].z / sum; + if ((data_id + reduceWidth * 3) < sequence_length) + vals[data_id + reduceWidth * 3] = data[i].w / sum; + } + } + } +} + +#define LAUNCH_ATTN_SOFTMAX_V2(iterations) \ + attn_softmax_v2<<>>(vals, \ + mask, \ + alibi, \ + layer_scale, \ + triangular, \ + recompute, \ + local_attention, \ + window_size, \ + total_count, \ + heads, \ + sequence_length, \ + num_seq, \ + head_offset, \ + mask_stride, \ + mp_size, \ + reduce_width); + +template +void launch_attn_softmax_v2(T* vals, + T* mask, + T* alibi, + float layer_scale, + bool triangular, + bool recompute, + bool local_attention, + int window_size, + int batch_size, + int heads, + int num_seq, + int sequence_length, + int head_offset, + int mask_stride, + int mp_size, + cudaStream_t stream) +{ + const int total_count = batch_size * heads * num_seq; + + // Scheduling Overview + // 4 element unroll with power of 2 `reduce_width` threads to a ceiling of `attn_threads` + // Each block should be partitioned into as many `reduce_width` blocks + // as can be fit. + constexpr int attn_threads = 256; + constexpr int min_reduce_width = hw_warp_size; + constexpr int internal_unroll = 4; + + // Handle internal unroll then round to next power of 2. Bump up to minimum granularity. + const int thread_steps_rounded = + next_pow2((sequence_length + internal_unroll - 1) / internal_unroll); + const int thread_steps_schedule = + (thread_steps_rounded < min_reduce_width) ? min_reduce_width : thread_steps_rounded; + // Bound reduce width to the number of threads + const int reduce_width = (thread_steps_schedule < attn_threads) ? thread_steps_schedule + : attn_threads; + // Scale for the excess + const int iterations = thread_steps_schedule / reduce_width; + // Should be safe since reduce_width is capped to attn_threads + const int partitions = attn_threads / reduce_width; + + // Launch params + dim3 grid((total_count + partitions - 1) / partitions); + dim3 block(attn_threads); + + if (sequence_length <= 32768) { + if (iterations == 1) { + LAUNCH_ATTN_SOFTMAX_V2(1); + } else if (iterations == 2) { + LAUNCH_ATTN_SOFTMAX_V2(2); + } else if (iterations == 4) { + LAUNCH_ATTN_SOFTMAX_V2(4); + } else if (iterations == 8) { + LAUNCH_ATTN_SOFTMAX_V2(8); + } else if (iterations == 16) { + LAUNCH_ATTN_SOFTMAX_V2(16); + } else if (iterations == 32) { + LAUNCH_ATTN_SOFTMAX_V2(32); + } else if (iterations == 64) { + LAUNCH_ATTN_SOFTMAX_V2(64); + } + } else + throw std::runtime_error("Unsupport Seq_Length!"); +} + +#define INSTANTIATE_LAUNCH_ATTN_SOFTMAX_V2(T) \ + template void launch_attn_softmax_v2(T* vals, \ + T* mask, \ + T* alibi, \ + float layer_scale, \ + bool triangular, \ + bool recompute, \ + bool local_attention, \ + int window_size, \ + int batch_size, \ + int heads, \ + int num_seq, \ + int sequence_length, \ + int head_offset, \ + int mask_stride, \ + int mp_size, \ + cudaStream_t stream); + +INSTANTIATE_LAUNCH_ATTN_SOFTMAX_V2(float); +#ifdef BF16_AVAILABLE +INSTANTIATE_LAUNCH_ATTN_SOFTMAX_V2(__nv_bfloat16); +#endif +INSTANTIATE_LAUNCH_ATTN_SOFTMAX_V2(__half); + +#define DEF_ATTN_SOFTMAX_V2_HALF(_iter) \ + template __global__ void attn_softmax_v2<__half, _iter>(__half * vals, \ + __half * mask, \ + __half * alibi, \ + float layer_scale, \ + bool triangular, \ + bool recompute, \ + bool local_attention, \ + int window_size, \ + int total_count, \ + int heads, \ + int sequence_length, \ + int num_seq, \ + int head_offset, \ + int mask_stride, \ + int mp_size, \ + int reduceWidth) + +#define DEF_ATTN_SOFTMAX_V2_BF16(_iter) \ + template __global__ void attn_softmax_v2<__nv_bfloat16, _iter>(__nv_bfloat16 * vals, \ + __nv_bfloat16 * mask, \ + __nv_bfloat16 * alibi, \ + float layer_scale, \ + bool triangular, \ + bool recompute, \ + bool local_attention, \ + int window_size, \ + int total_count, \ + int heads, \ + int sequence_length, \ + int num_seq, \ + int head_offset, \ + int mask_stride, \ + int mp_size, \ + int reduceWidth) + +#define FOREACH_ITERATIONS(cb) \ + cb(1); \ + cb(2); \ + cb(4); \ + cb(8); \ + cb(16); \ + cb(32); \ + cb(64) + +FOREACH_ITERATIONS(DEF_ATTN_SOFTMAX_V2_HALF); +#ifdef BF16_AVAILABLE +FOREACH_ITERATIONS(DEF_ATTN_SOFTMAX_V2_BF16); +#endif diff --git a/toolbox/DeepSpeed/v0.15.3/patches/csrc/transformer/inference/csrc/transform.cu b/toolbox/DeepSpeed/v0.15.3/patches/csrc/transformer/inference/csrc/transform.cu new file mode 100644 index 0000000000000000000000000000000000000000..7d1a53b7ebf7bf26c50e6b8c35ba416ed235dfd8 --- /dev/null +++ b/toolbox/DeepSpeed/v0.15.3/patches/csrc/transformer/inference/csrc/transform.cu @@ -0,0 +1,744 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. */ +/* All Rights Reserved. */ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#ifndef __HIP_PLATFORM_AMD__ +#include +#endif +#include "conversion_utils.h" +#include "inference_cuda_layers.h" +namespace cg = cooperative_groups; + +// only used to avoid compilation error due to lack of definition. +#ifndef BF16_AVAILABLE +using __nv_bfloat162 = __half2; +#endif + +// Bias add + +__global__ void bias_add_transform_0213(float* output, + float* k_cache, + float* v_cache, + const float* vals, + const float* bias, + int hidden_dim, + int seq_length, + unsigned seq_offset, + int heads, + int head_stride, + int num_kv, + int rotary_dim, + bool rotate_half, + bool rotate_every_two, + int head_ext, + int max_out_tokens, + float rope_theta) +{ + int d0_stride = hidden_dim * seq_length; + int d1_stride = hidden_dim; + int d2_stride = hidden_dim / heads; + + int d0 = blockIdx.x; // Batch + int d1 = blockIdx.y; // Sequence ID (0-127) + int cnt = blockIdx.z / head_ext; // Hidden count + int d2 = threadIdx.y + (blockIdx.z % head_ext) * (heads / head_ext); // Head (0-11) + int d3 = threadIdx.x; // Values (groups of 4) + + int d2_out_stride = d2_stride * (cnt == 0 ? seq_length : max_out_tokens); + int d0_out_stride = hidden_dim * (cnt == 0 ? seq_length : max_out_tokens); + + const float4* vals_vec = reinterpret_cast(vals); + float4* output_vec = + reinterpret_cast(cnt == 0 ? output : (cnt == 1 ? k_cache : v_cache)); + + vals_vec += (d0 * (d1_stride + num_kv * 2 * d2_stride) * seq_length); + vals_vec += d1 * (d1_stride + num_kv * 2 * d2_stride); + vals_vec += (cnt == 0 ? 0 : d1_stride) + (cnt == 0 ? 0 : (cnt - 1) * num_kv * d2_stride); + vals_vec += ((cnt == 0 ? d2 : (d2 / head_stride)) * d2_stride); + + output_vec += (d1 * d2_stride); + output_vec += (d0 * d0_out_stride); + output_vec += (d2 * d2_out_stride); + + unsigned seq_id = d1 + seq_offset; + float4 inputs = vals_vec[d3]; + int lane = d3 & 0x1f; + if (cnt < 2 && rotary_dim > 0 && d3 < rotary_dim) { + float4 q = vals_vec[d3]; + float2* q_f = reinterpret_cast(&q); + if (rotate_every_two) { +#pragma unroll + for (int o = 0; o < 2; o++) { + float inv_freq = (float)(((d3 << 1) + o) * 2) / (float)(rotary_dim << 2); + inv_freq = 1.0 / powf(rope_theta, inv_freq) * (float)seq_id; + q_f[o].x = (-1.0 * q_f[o].y * sinf(inv_freq) + q_f[o].x * cosf(inv_freq)); + q_f[o].y = (q_f[o].x * sinf(inv_freq) + q_f[o].y * cosf(inv_freq)); + } + } + output_vec[d3] = q; + } else + output_vec[d3] = inputs; +} + +#define ATTN_H 3 +#define MAX_SEQ_LINE 10 + +template +__global__ void bias_add_transform_0213(T* output, // q + T* k_cache, + T* v_cache, + const T* vals, // qkv + const T* bias, + int hidden_dim, + int seq_length, + unsigned seq_offset, + int all_tokens, + int heads, + int head_stride, + int num_kv, + int rotary_dim, + bool rotate_half, + bool rotate_every_two, + int head_ext, + int max_out_tokens, + float rope_theta) +{ + using T2 = + typename std::conditional::value, __half2, __nv_bfloat162>::type; + unsigned half_dim = (rotary_dim << 3) >> 1; + int d0_stride = hidden_dim * seq_length; + int d1_stride = hidden_dim; + int d2_stride = hidden_dim / heads; + + int d0 = blockIdx.x; // Batch + int d1 = blockIdx.y; // Sequence ID (0-127) + int cnt = blockIdx.z / head_ext; // Hidden count + int d2 = threadIdx.y + (blockIdx.z % head_ext) * (heads / head_ext); // Head (0-11) + int d3 = threadIdx.x; // Values (groups of 4) + + int d2_out_stride = d2_stride * (cnt == 0 ? seq_length : max_out_tokens); + int d0_out_stride = hidden_dim * (cnt == 0 ? seq_length : max_out_tokens); + + float4 vals_arr; + float4 output_arr; + + T2* vals_half = reinterpret_cast(&vals_arr); + T2* output_half = reinterpret_cast(&output_arr); + + const float4* vals_vec = reinterpret_cast(vals); + float4* output_vec = + reinterpret_cast(cnt == 0 ? output : (cnt == 1 ? k_cache : v_cache)); + + vals_vec += (d0 * (d1_stride + num_kv * 2 * d2_stride) * seq_length); + vals_vec += (d1 * (d1_stride + num_kv * 2 * d2_stride)); + vals_vec += (cnt == 0 ? 0 : d1_stride) + (cnt == 0 ? 0 : (cnt - 1) * num_kv * d2_stride); + vals_vec += ((cnt == 0 ? d2 : (d2 / head_stride)) * d2_stride); + + output_vec += (d1 * d2_stride); + output_vec += (d0 * d0_out_stride); + output_vec += (d2 * d2_out_stride); + + unsigned seq_id = d1 + seq_offset; + + int lane = d3 & 0x1f; + if (cnt < 2 && rotary_dim > 0 && d3 < rotary_dim) { + float4 q = vals_vec[d3]; + T2* q_h = reinterpret_cast(&q); + if (rotate_every_two) { +#pragma unroll + for (int o = 0; o < 4; o++) { + float inv_freq = (float)(((d3 << 2) + o) * 2) / (float)(rotary_dim << 3); + inv_freq = 1.0 / powf(rope_theta, inv_freq) * (float)seq_id; + float q_data[2]; + q_data[0] = conversion::to(q_h[o].x); + q_data[1] = conversion::to(q_h[o].y); + q_h[o].x = conversion::to(-1.0 * q_data[1] * sinf(inv_freq) + + q_data[0] * cosf(inv_freq)); + q_h[o].y = + conversion::to(q_data[0] * sinf(inv_freq) + q_data[1] * cosf(inv_freq)); + } + } + output_vec[d3] = q; + } else + output_vec[d3] = vals_vec[d3]; +} + +// [B S C*H] - > C * [B A S N] +template <> +void launch_bias_add_transform_0213(float* output, + float* k_cache, + float* v_cache, + const float* vals, + const float* bias, + int batch_size, + int seq_length, + unsigned seq_offset, + int all_tokens, + int hidden_dim, + int heads, + int num_kv, + int rotary_dim, + bool rotate_half, + bool rotate_every_two, + cudaStream_t stream, + int trans_count, + int max_out_tokens, + float rope_theta) +{ + hidden_dim >>= 2; + int head_ext = (hidden_dim - 1) / MAX_THREADS + 1; + + dim3 block_dim(hidden_dim / heads, (heads / head_ext)); + dim3 grid_dim(batch_size, seq_length, (trans_count * head_ext)); + + bias_add_transform_0213<<>>(output, + k_cache, + v_cache, + vals, + bias, + hidden_dim, + seq_length, + seq_offset, + heads, + num_kv > 0 ? (heads / num_kv) : 1, + num_kv > 0 ? num_kv : heads, + rotary_dim >> 2, + rotate_half, + rotate_every_two, + head_ext, + max_out_tokens, + rope_theta); +} + +template +void launch_bias_add_transform_0213(T* output, + T* k_cache, + T* v_cache, + const T* vals, + const T* bias, + int batch_size, + int seq_length, + unsigned seq_offset, + int all_tokens, + int hidden_dim, + int heads, + int num_kv, + int rotary_dim, + bool rotate_half, + bool rotate_every_two, + cudaStream_t stream, + int trans_count, + int max_out_tokens, + float rope_theta) +{ + hidden_dim >>= 3; + int head_ext = 1; // (hidden_dim - 1) / MAX_THREADS + 1; + dim3 block_dim(hidden_dim / heads, (heads / head_ext)); + dim3 grid_dim(batch_size, seq_length, (trans_count * head_ext)); + bias_add_transform_0213<<>>(output, + k_cache, + v_cache, + vals, + bias, + hidden_dim, + seq_length, + seq_offset, + all_tokens, + heads, + num_kv > 0 ? (heads / num_kv) : 1, + num_kv > 0 ? num_kv : heads, + rotary_dim >> 3, + rotate_half, + rotate_every_two, + head_ext, + max_out_tokens, + rope_theta); +} + +#define INSTANTIATE_LAUNCH_BIAS_ADD_TRANSFORM_0213(T) \ + template void launch_bias_add_transform_0213(T*, \ + T*, \ + T*, \ + const T*, \ + const T*, \ + int, \ + int, \ + unsigned, \ + int, \ + int, \ + int, \ + int, \ + int, \ + bool, \ + bool, \ + cudaStream_t, \ + int, \ + int, \ + float) + +#ifdef BF16_AVAILABLE +INSTANTIATE_LAUNCH_BIAS_ADD_TRANSFORM_0213(__nv_bfloat16); +#endif +INSTANTIATE_LAUNCH_BIAS_ADD_TRANSFORM_0213(__half); + +// Bias add + +__global__ void pad_add_transform_0213(float* output, + const float* vals, + int hidden_dim, + int seq_length, + int padded_seq_len, + int heads, + int padded_head_size) +{ +} + +template +__global__ void pad_add_transform_0213(T* output, + const T* vals, + int hidden_dim, + int seq_length, + int padded_seq_len, + int heads, + int padded_head_size) +{ + using T2 = + typename std::conditional::value, __half2, __nv_bfloat162>::type; + float4 ZERO; + const T2 zero_h = conversion::to(0.f); + T2* ZERO_h = reinterpret_cast(&ZERO); +#pragma unroll + for (int i = 0; i < 4; i++) ZERO_h[i] = zero_h; + + int d0_stride = hidden_dim * seq_length; + int d1_stride = hidden_dim; + int d2_stride = hidden_dim / heads; + + int d0 = blockIdx.x; // Batch + int d1 = blockIdx.y * blockDim.z + threadIdx.z; // Sequence ID (0-127) + int d2 = threadIdx.y; // Head (0-11) + int d3 = threadIdx.x; // Values (groups of 4) + + int d2_out_stride = padded_head_size * padded_seq_len; + int d0_out_stride = heads * d2_out_stride; + + const float4* vals_vec = reinterpret_cast(vals); + float4* output_vec = reinterpret_cast(output); + + vals_vec += (d0 * d0_stride); + vals_vec += (d1 * d1_stride); + vals_vec += (d2 * d2_stride); + + output_vec += (d1 * padded_head_size); + output_vec += (d0 * d0_out_stride); + output_vec += (d2 * d2_out_stride); + + if (d3 < d2_stride && d1 < seq_length) + output_vec[d3] = vals_vec[d3]; + else + output_vec[d3] = ZERO; +} + +// [B S C*H] - > C * [B A S N] +template <> +void launch_pad_add_transform_0213(float* output, + const float* vals, + int batch_size, + int hidden_dim, + int seq_length, + int padded_seq_len, + int heads, + int padded_head_size, + cudaStream_t stream) +{ +} + +template +void launch_pad_add_transform_0213(T* output, + const T* vals, + int batch_size, + int hidden_dim, + int seq_length, + int padded_seq_len, + int heads, + int padded_head_size, + cudaStream_t stream) +{ + hidden_dim >>= 3; + dim3 block_dim((padded_head_size >> 3), heads, 2); + dim3 grid_dim(batch_size, padded_seq_len / 2); + pad_add_transform_0213<<>>( + output, vals, hidden_dim, seq_length, padded_seq_len, heads, padded_head_size >> 3); +} + +#define INSTANTIATE_LAUNCH_PAD_ADD_TRANSFORM_0213_SIMPLE(T) \ + template void launch_pad_add_transform_0213( \ + T*, const T*, int, int, int, int, int, int, cudaStream_t); + +INSTANTIATE_LAUNCH_PAD_ADD_TRANSFORM_0213_SIMPLE(__half); +#ifdef BF16_AVAILABLE +INSTANTIATE_LAUNCH_PAD_ADD_TRANSFORM_0213_SIMPLE(__nv_bfloat16); +#endif + +// Bias add +template +__global__ void bias_add_transform_0213(T* output, + const T* vals, + const T* bias, + int hidden_dim, + int seq_length, + int heads, + int head_ext); + +template <> +__global__ void bias_add_transform_0213(float* output, + const float* vals, + const float* bias, + int hidden_dim, + int seq_length, + int heads, + int head_ext) +{ + int d0_stride = hidden_dim * seq_length; + int d1_stride = hidden_dim; + int d2_stride = hidden_dim / heads; + + int d0_out_stride = d0_stride; + int d1_out_stride = d2_stride; + int d2_out_stride = d2_stride * seq_length; + + int d0 = blockIdx.x; // Batch + int d1 = blockIdx.y; // Sequence ID (0-127) + int cnt = blockIdx.z / head_ext; // Hidden count + int d2 = threadIdx.y + (blockIdx.z % head_ext) * (heads / head_ext); // Head (0-11) + int d3 = threadIdx.x; // Values (groups of 4) + + const float4* vals_vec = reinterpret_cast(vals); + const float4* bias_vec = reinterpret_cast(bias); + float4* output_vec = reinterpret_cast(output); + + float4 inputs = vals_vec[d0 * d0_stride * (gridDim.z / head_ext) + cnt * d1_stride + + d1 * d1_stride * (gridDim.z / head_ext) + d2 * d2_stride + d3]; + float4 biases = bias_vec[cnt * d1_stride + d2 * d2_stride + d3]; + + float4 outputs; + outputs.x = inputs.x + biases.x; + outputs.y = inputs.y + biases.y; + outputs.z = inputs.z + biases.z; + outputs.w = inputs.w + biases.w; + + output_vec[cnt * d0_out_stride * gridDim.x + d0 * d0_out_stride + d1 * d1_out_stride + + d2 * d2_out_stride + d3] = outputs; +} + +template +__global__ void bias_add_transform_0213(T* output, + const T* vals, + const T* bias, + int hidden_dim, + int seq_length, + int heads, + int head_ext) +{ + using T2 = + typename std::conditional::value, __half2, __nv_bfloat162>::type; + int d0_stride = hidden_dim * seq_length; + int d1_stride = hidden_dim; + int d2_stride = hidden_dim / heads; + + int d2_out_stride = d2_stride * seq_length; + + int d0 = blockIdx.x; // Batch + int d1 = blockIdx.y; // Sequence ID (0-127) + int cnt = blockIdx.z / head_ext; // Hidden count + int d2 = threadIdx.y + (blockIdx.z % head_ext) * (heads / head_ext); // Head (0-11) + int d3 = threadIdx.x; // Values (groups of 4) + + float4 vals_arr; + float4 bias_arr; + float4 output_arr; + T2* vals_half = reinterpret_cast(&vals_arr); + T2* bias_half = reinterpret_cast(&bias_arr); + T2* output_half = reinterpret_cast(&output_arr); + + const float4* vals_vec = reinterpret_cast(vals); + const float4* bias_vec = reinterpret_cast(bias); + float4* output_vec = reinterpret_cast(output); + + vals_vec += (d0 * d0_stride * (gridDim.z / head_ext)); + vals_vec += (d1 * d1_stride * (gridDim.z / head_ext)); + vals_vec += (cnt * d1_stride); + vals_vec += (d2 * d2_stride); + + bias_vec += (cnt * d1_stride); + bias_vec += (d2 * d2_stride); + + output_vec += (cnt * d0_stride * gridDim.x); + output_vec += (d1 * d2_stride); + output_vec += (d0 * d0_stride); + output_vec += (d2 * d2_out_stride); + + bias_arr = bias_vec[d3]; + vals_arr = vals_vec[d3]; + + output_half[0] = vals_half[0] + bias_half[0]; + output_half[1] = vals_half[1] + bias_half[1]; + output_half[2] = vals_half[2] + bias_half[2]; + output_half[3] = vals_half[3] + bias_half[3]; + output_vec[d3] = output_arr; +} + +template +__global__ void bias_add_transform_0213_v2(T* output, + const T* vals, + const T* bias, + int hidden_dim, + int seq_length, + int heads) +{ + using T2 = + typename std::conditional::value, __half2, __nv_bfloat162>::type; + __shared__ float4 in_data[3072]; + + int d0_stride = hidden_dim * seq_length; + int d1_stride = hidden_dim; + int d2_stride = hidden_dim / heads; + int iteration_stride = d1_stride * blockDim.z; // Hidden * 3 / 8 + int batch_stride = d0_stride * blockDim.z; // Hidden * S * 3 / 8 + + int d0_out_stride = d0_stride; + int d1_out_stride = d2_stride; + int d2_out_stride = d2_stride * seq_length; + + int d0 = blockIdx.x; // Batch + int d1 = blockIdx.y; // Sequence ID (0-127) + int cnt = threadIdx.z; // blockIdx.z; // Hidden count + int d2 = threadIdx.y; // Head (0-11) + int d3 = threadIdx.x; // Values (groups of 4) + + float4 vals_arr[1]; + float4 bias_arr[1]; + float4 output_arr[1]; + T2* vals_half = reinterpret_cast(vals_arr); + T2* bias_half = reinterpret_cast(bias_arr); + T2* output_half = reinterpret_cast(output_arr); + + const float4* vals_vec = reinterpret_cast(vals); + const float4* bias_vec = reinterpret_cast(bias); + float4* output_vec = reinterpret_cast(output); + + int iter_index = cnt * d1_stride + d2 * d2_stride + d3; + int input_offset = d0 * batch_stride + d1 * (iteration_stride << 1); + bias_arr[0] = bias_vec[iter_index]; + +#pragma unroll + for (int iter = 0; iter < 2; iter++) { + int iter_id = iter * iteration_stride + iter_index; + vals_arr[0] = vals_vec[input_offset + iter_id]; + + output_half[0] = vals_half[0] + bias_half[0]; + output_half[1] = vals_half[1] + bias_half[1]; + output_half[2] = vals_half[2] + bias_half[2]; + output_half[3] = vals_half[3] + bias_half[3]; + + in_data[iter_id] = output_arr[0]; + } + __syncthreads(); + + iteration_stride = blockDim.z * (blockDim.y >> 1); + int matrix_stride = (d0_out_stride * gridDim.x); + int head_count = (d2 >> 1) + cnt * (blockDim.y >> 1); + + int out_index = d0 * d0_out_stride + d1 * (d1_out_stride << 1) + d3 + (d2 % 2) * d2_stride; + +#pragma unroll + for (int iter = 0; iter < 2; iter++) { + int iter_row = (iter * iteration_stride) + head_count; + int iter_offset = + (iter_row % blockDim.y) * d2_out_stride + (iter_row / blockDim.y) * matrix_stride; + output_vec[out_index + iter_offset] = + in_data[iter_row * d2_stride + d3 + (d2 % 2) * (d1_stride * blockDim.z)]; + } +} + +template +__global__ void transform4d_0213(T* out, + const T* in, + int heads, + int seq_length, + int hidden_dim, + int head_ext); + +template <> +__global__ void transform4d_0213(float* out, + const float* in, + int heads, + int seq_length, + int hidden_dim, + int head_ext) +{ + int d0_stride = hidden_dim * seq_length; + int d1_stride = d0_stride / heads; + int d2_stride = hidden_dim / heads; + + int d0_out_stride = d0_stride; + int d1_out_stride = d2_stride; + int d2_out_stride = hidden_dim; + + int d0 = blockIdx.x; // Batch + int d1 = blockIdx.y / ((seq_length - 1) / blockDim.y + 1); // Head + int d2 = (threadIdx.y + blockDim.y * blockIdx.y) % seq_length; + int cnt = blockIdx.z; + int d3 = threadIdx.x; // Values (groups of 8) + + if (d2 < seq_length) { + const float4* in_vec = reinterpret_cast(in); + float4* out_vec = reinterpret_cast(out); + + float4 vals_vec = in_vec[cnt * d0_stride * gridDim.x + d0 * d0_stride + d1 * d1_stride + + d2 * d2_stride + d3]; + out_vec[d0 * d0_out_stride * gridDim.z + cnt * d2_out_stride + d1 * d1_out_stride + + d2 * d2_out_stride * gridDim.z + d3] = vals_vec; + } +} + +template +__global__ void transform4d_0213(T* out, + const T* in, + int heads, + int seq_length, + int hidden_dim, + int head_ext) +{ + int d0_stride = hidden_dim * (seq_length / head_ext); + int d1_stride = hidden_dim; + int d2_stride = hidden_dim / heads; + + int d0 = blockIdx.x; // Batch + int d1 = threadIdx.y + (blockIdx.z % head_ext) * (heads / head_ext); // Head + int d2 = blockIdx.z / head_ext; // Sequence + int cnt = blockIdx.y; // Hidden count + int d3 = threadIdx.x; // Values (groups of 8) + + const float4* in_vec = reinterpret_cast(in); + float4* out_vec = reinterpret_cast(out); + + in_vec += (cnt * d0_stride * gridDim.x); + in_vec += (d0 * d0_stride); + in_vec += (d2 * d2_stride); + in_vec += (d1 * d2_stride * seq_length); + + out_vec += (cnt * d1_stride); + out_vec += (d1 * d2_stride); + out_vec += (d0 * d0_stride * gridDim.y); + out_vec += (d2 * d1_stride * gridDim.y); + + out_vec[d3] = in_vec[d3]; +} + +template +__global__ void transform4d_0213_v2(T* out, const T* in, int heads, int seq_length, int hidden_dim) +{ + __shared__ float4 in_data[3072]; + + int d0_stride = hidden_dim * seq_length; + int d1_stride = hidden_dim; + int d2_stride = hidden_dim / heads; + + int d0 = blockIdx.x; // Batch + int d1 = threadIdx.y; // Head + int d2 = blockIdx.y; // Sequence + int cnt = threadIdx.z; // Hidden count + int d3 = threadIdx.x; // Values (groups of 8) + + const float4* in_vec = reinterpret_cast(in); + float4* out_vec = reinterpret_cast(out); + + int input_offset = d0 * d0_stride + d2 * (d2_stride << 1) + d3 + (d1 % 2) * d2_stride; + int head_count = (d1 >> 1) + cnt * (blockDim.y >> 1); + int iteration_stride = blockDim.z * (blockDim.y >> 1); + int matrix_stride = (d0_stride * gridDim.x); + +#pragma unroll + for (int iter = 0; iter < 2; iter++) { + int iter_row = iter * iteration_stride + head_count; + int iter_offset = (iter_row % blockDim.y) * d2_stride; + + in_data[d3 + iter_offset + (iter_row / blockDim.y + (d1 % 2) * blockDim.z) * d1_stride] = + in_vec[input_offset + iter_offset * seq_length + + (iter_row / blockDim.y) * matrix_stride]; + } + __syncthreads(); + + iteration_stride = d1_stride * blockDim.z; + int iter_index = cnt * d1_stride + d1 * d2_stride + d3; + int output_offset = d0 * d0_stride * blockDim.z + d2 * (iteration_stride << 1); + +#pragma unroll + for (int iter = 0; iter < 2; iter++) { + int iter_id = iter * iteration_stride + iter_index; + out_vec[output_offset + iter_id] = in_data[iter_id]; + } +} + +// 3 * [B A S N] - > [B S C*H] +template <> +void launch_transform4d_0213(float* out, + const float* in, + int batch_size, + int heads, + int seq_length, + int hidden_dim, + cudaStream_t stream, + int trans_count) +{ + hidden_dim >>= 2; + dim3 grid_dims(batch_size, heads * ((seq_length - 1) / 8 + 1), trans_count); + dim3 block_dims(hidden_dim / heads, 8); + transform4d_0213 + <<>>(out, in, heads, seq_length, hidden_dim, 1); +} + +template +void launch_transform4d_0213(T* out, + const T* in, + int batch_size, + int heads, + int seq_length, + int hidden_dim, + cudaStream_t stream, + int trans_count) +{ + hidden_dim >>= 3; + int head_ext = (hidden_dim - 1) / MAX_THREADS + 1; + dim3 grid_dims(batch_size, trans_count, (seq_length * head_ext)); + dim3 block_dims(hidden_dim / heads, (heads / head_ext)); + transform4d_0213<<>>( + out, in, heads, seq_length, hidden_dim, head_ext); +} + +#define INSTANTIATE_2B_LAUNCH_TRANSFORM4D(T) \ + template void launch_transform4d_0213(T*, const T*, int, int, int, int, cudaStream_t, int); + +INSTANTIATE_2B_LAUNCH_TRANSFORM4D(__half) +#ifdef BF16_AVAILABLE +INSTANTIATE_2B_LAUNCH_TRANSFORM4D(__nv_bfloat16) +#endif diff --git a/toolbox/DeepSpeed/v0.15.3/patches/csrc/transformer/inference/includes/inference_context.h b/toolbox/DeepSpeed/v0.15.3/patches/csrc/transformer/inference/includes/inference_context.h new file mode 100644 index 0000000000000000000000000000000000000000..443862018c998d9f32cd7f04ccb269884d5cd85f --- /dev/null +++ b/toolbox/DeepSpeed/v0.15.3/patches/csrc/transformer/inference/includes/inference_context.h @@ -0,0 +1,309 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. */ +/* All Rights Reserved. */ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include +#include +#include +#include +#include +#include "cublas_v2.h" +#include "cuda.h" + +#define MEGABYTE (1024 * 1024) +#define GIGABYTE (1024 * 1024 * 1024) + +// TODO: refactor out +#define WARP_SIZE 32 + +#define CUDA_CHECK(callstr) \ + { \ + cudaError_t error_code = callstr; \ + if (error_code != cudaSuccess) { \ + std::cerr << "CUDA error " << error_code << " at " << __FILE__ << ":" << __LINE__; \ + assert(0); \ + } \ + } + +#define CUDA_1D_KERNEL_LOOP(i, n) \ + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x) + +#define CUDA_2D_KERNEL_LOOP(i, n, j, m) \ + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x) \ + for (size_t j = blockIdx.y * blockDim.y + threadIdx.y; j < (m); j += blockDim.y * gridDim.y) + +#define DS_CUDA_NUM_THREADS 512 +#define DS_MAXIMUM_NUM_BLOCKS 262144 + +inline int DS_GET_BLOCKS(const int N) +{ + return std::max( + std::min((N + DS_CUDA_NUM_THREADS - 1) / DS_CUDA_NUM_THREADS, DS_MAXIMUM_NUM_BLOCKS), + // Use at least 1 block, since CUDA does not allow empty block + 1); +} + +class InferenceContext { +public: + InferenceContext() + : _workspace(nullptr), + _seed(42), + _curr_offset(0), + _stream(0), + _free_memory_size(0), + _num_tokens(1), + _attention_unfused_workspace_offset(0), + _workSpaceSize(0) + { + _workSpaceSize = 0; + _workspace = 0; + + cublasStatus_t stat = cublasCreate(&_cublasHandle); + if (stat != CUBLAS_STATUS_SUCCESS) { + // It would be nice to use cublasGetStatusName and + // cublasGetStatusString, but they were only added in CUDA 11.4.2. + auto message = std::string("Failed to create cublas handle: cublasStatus_t was ") + + std::to_string(stat); + std::cerr << message << std::endl; + throw std::runtime_error(message); + } +#ifndef __HIP_PLATFORM_AMD__ + cublasSetMathMode(_cublasHandle, CUBLAS_TENSOR_OP_MATH); +#endif + cudaEventCreate(&_comp1_event); + cudaEventCreate(&_comp2_event); + cudaEventCreate(&_comp_event); + cudaEventCreate(&_comm_event); + } + + virtual ~InferenceContext() + { + cublasDestroy(_cublasHandle); + cudaFree(_workspace); + cudaEventDestroy(_comp1_event); + cudaEventDestroy(_comp2_event); + cudaEventDestroy(_comp_event); + cudaEventDestroy(_comm_event); + } + + static InferenceContext& Instance() + { + static InferenceContext _ctx; + return _ctx; + } + + void GenWorkSpace(const unsigned& num_layers, + const unsigned& num_heads, + const size_t& batch_size, + const size_t& prompt_len, + const size_t& hidden_dim, + const unsigned& mp_size, + const bool& external_cache, + const size_t& elem_size, + const unsigned& rank, + unsigned max_out_tokens, + unsigned min_out_tokens) + { + size_t total_size; + if (!_free_memory_size) { cudaMemGetInfo(&_free_memory_size, &total_size); } + + // Flash attention requires padded heads and we'll conservatively allocate + // for that here. Flash attention is only enabled for head size <= 128 right now + const int head_size = hidden_dim / num_heads; + const int padded_head_size = head_size <= 32 ? 32 : (head_size <= 64 ? 64 : 128); + const int effective_head_size = (head_size > 128) ? head_size : padded_head_size; + + size_t activation_size = 10 * (num_heads * effective_head_size) * batch_size; + // Other sequence length dimension is added when the final workSpaceSize is calculated + size_t temp_size = batch_size * (num_heads / mp_size) * max_out_tokens; + size_t cache_size = + num_layers * batch_size * ((num_heads * effective_head_size) / mp_size) * 2; + size_t minimal_requirements = + temp_size + (_free_memory_size > GIGABYTE ? 500 : 100) * MEGABYTE; + if (_free_memory_size < minimal_requirements) { + printf("Requested:\t%lu\nFree:\t%lu\nTotal:\t%lu\n", + minimal_requirements, + _free_memory_size, + total_size); + throw std::runtime_error("Workspace can't be allocated, no enough memory."); + } + + _max_seq_len = ((_free_memory_size - minimal_requirements) / elem_size) / + (activation_size + temp_size + cache_size); + _max_seq_len = std::min((size_t)max_out_tokens, _max_seq_len); + size_t workSpaceSize = ((external_cache ? (activation_size + temp_size) + : (activation_size + temp_size + cache_size))) * + _max_seq_len * elem_size; + temp_size *= _max_seq_len * elem_size; + + if (_max_seq_len < min_out_tokens) { + printf( + "Allocatable workspace available (%ld tokens) is less than minimum requested " + "workspace (%d tokens)\n", + _max_seq_len, + min_out_tokens); + throw std::runtime_error("Workspace can't be allocated, not enough memory"); + } + + if (!_workspace) { + assert(_workspace == nullptr); + cudaMalloc(&_workspace, workSpaceSize); + } else if (_workSpaceSize < workSpaceSize) { + cudaFree(_workspace); + cudaMalloc(&_workspace, workSpaceSize); + } + if (rank == 0 && (!_workspace || _workSpaceSize < workSpaceSize)) + printf( + "------------------------------------------------------\n" + "Free memory : %f (GigaBytes) \n" + "Total memory: %f (GigaBytes) \n" + "Requested memory: %f (GigaBytes) \n" + "Setting maximum total tokens (input + output) to %lu \n" + "WorkSpace: %p \n" + "------------------------------------------------------\n", + (float)_free_memory_size / GIGABYTE, + (float)total_size / GIGABYTE, + (float)workSpaceSize / GIGABYTE, + _max_seq_len, + _workspace); + + if (!_workspace) { + printf("Requested:\t%lu\nFree:\t%lu\nTotal:\t%lu\n", + workSpaceSize, + _free_memory_size, + total_size); + throw std::runtime_error("Workspace is null."); + } + _workSpaceSize = workSpaceSize; + _attention_unfused_workspace_offset = workSpaceSize - temp_size; + } + inline size_t GetMaxTokenLength() const { return _max_seq_len; } + + cudaEvent_t GetCompEvent(int id) { return id == 1 ? _comp1_event : _comp2_event; } + + size_t get_workspace_size() const { return _workSpaceSize; } + void* GetWorkSpace() { return _workspace; } + void* GetAttentionUnfusedWorkspace() + { + return (char*)_workspace + _attention_unfused_workspace_offset; + } + + inline unsigned new_token(unsigned layer_id) + { + if (layer_id == 0) _token_length++; + return _token_length; + } + + inline void reset_tokens(unsigned initial_tokens = 1) + { + _num_tokens = initial_tokens; + } //_token_length = 0; } + + inline unsigned current_tokens() const { return _num_tokens; } + + inline void advance_tokens() { _num_tokens++; } + + cudaStream_t GetCommStream(bool async_op = false) + { + if (!_comm_stream) + _comm_stream = async_op ? at::cuda::getStreamFromPool(true) + : at::cuda::getCurrentCUDAStream(); + return _comm_stream; + } + cudaStream_t GetCurrentStream(bool other_stream = false) + { + // get current pytorch stream. + if (other_stream) { + if (!_stream) _stream = at::cuda::getStreamFromPool(true); + return _stream; + } + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + return stream; + } + + void release_workspace() + { + cudaFree(_workspace); + _workspace = nullptr; + } + bool retake_workspace() + { + if (_workspace != nullptr || _workSpaceSize == 0) return true; + cudaMalloc(&_workspace, _workSpaceSize); + return _workspace != nullptr; + } + cublasHandle_t GetCublasHandle() { return _cublasHandle; } + + std::pair IncrementOffset(uint64_t offset_inc) + { + uint64_t offset = _curr_offset; + _curr_offset += offset_inc; + return std::pair(_seed, offset); + } + + void SetSeed(uint64_t new_seed) { _seed = new_seed; } + + const std::vector>& GetGemmAlgos() const { return _gemm_algos; } + + inline void SynchComp() + { + cudaEventRecord(_comp_event, _comp_stream); + cudaStreamWaitEvent(_comm_stream, _comp_event, 0); + } + inline void SynchComm() + { + cudaEventRecord(_comm_event, _comm_stream); + cudaStreamWaitEvent(_comp_stream, _comm_event, 0); + } + +private: + cublasHandle_t _cublasHandle; + + cudaEvent_t _comp_event; + cudaEvent_t _comm_event; + + void* _workspace; + // offset from _workspace for attention unfused memory + size_t _attention_unfused_workspace_offset; + uint64_t _seed; + uint64_t _curr_offset; + + size_t _workSpaceSize; + size_t _free_memory_size; + + size_t _max_seq_len; + + cudaEvent_t _comp1_event; + cudaEvent_t _comp2_event; + + cudaStream_t _stream; + + unsigned _token_length; + unsigned _num_tokens; + std::vector> _gemm_algos; + + cudaStream_t _comp_stream; + cudaStream_t _comm_stream; + + std::unordered_map _world_sizes; +}; diff --git a/toolbox/DeepSpeed/v0.15.3/patches/csrc/transformer/inference/includes/inference_cublas_wrappers.h b/toolbox/DeepSpeed/v0.15.3/patches/csrc/transformer/inference/includes/inference_cublas_wrappers.h new file mode 100644 index 0000000000000000000000000000000000000000..93500b0935eb7adefe1dc21773ae2fd930476436 --- /dev/null +++ b/toolbox/DeepSpeed/v0.15.3/patches/csrc/transformer/inference/includes/inference_cublas_wrappers.h @@ -0,0 +1,521 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. */ +/* All Rights Reserved. */ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include +#include +#include +#ifdef BF16_AVAILABLE +#include +#endif +#include +#include +#ifndef __HIP_PLATFORM_AMD__ +#include +#endif +#include + +// TODO HIP: Remove backward compatibility for torch<=2.0 in future +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) +int cublas_gemm_ex(rocblas_handle handle, + rocblas_operation transa, + rocblas_operation transb, + int m, + int n, + int k, + const float* alpha, + const float* beta, + const float* A, + const float* B, + float* C, + rocblas_gemm_algo algo, + int b_stride = -1) +#else +int cublas_gemm_ex(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const float* alpha, + const float* beta, + const float* A, + const float* B, + float* C, + cublasGemmAlgo_t algo, + int b_stride = -1) +#endif +{ + const int ldb = (b_stride == -1) ? ((transb == CUBLAS_OP_N) ? k : n) : b_stride; +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) + rocblas_status status = rocblas_gemm_ex(handle, + transa, + transb, + m, + n, + k, + (const void*)alpha, + (const void*)A, + rocblas_datatype_f32_r, + (transa == rocblas_operation_none) ? m : k, + (const void*)B, + rocblas_datatype_f32_r, + ldb, + (const void*)beta, + C, + rocblas_datatype_f32_r, + m, + C, + rocblas_datatype_f32_r, + m, + rocblas_datatype_f32_r, + algo, + 0, + 0); +#else + cublasStatus_t status = cublasGemmEx(handle, + transa, + transb, + m, + n, + k, + (const void*)alpha, + (const void*)A, +#ifdef __HIP_PLATFORM_AMD__ + HIPBLAS_R_32F, +#else + CUDA_R_32F, +#endif + (transa == CUBLAS_OP_N) ? m : k, + (const void*)B, +#ifdef __HIP_PLATFORM_AMD__ + HIPBLAS_R_32F, +#else + CUDA_R_32F, +#endif + ldb, + (const void*)beta, + C, +#ifdef __HIP_PLATFORM_AMD__ + HIPBLAS_R_32F, +#else + CUDA_R_32F, +#endif + m, +#if defined(__HIP_PLATFORM_AMD__) && defined(HIPBLAS_V2) + HIPBLAS_COMPUTE_32F, +#elif defined(__HIP_PLATFORM_AMD__) + HIPBLAS_R_32F, +#else + CUDA_R_32F, +#endif + algo); +#endif + +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) + if (status != rocblas_status_success) { +#else + if (status != CUBLAS_STATUS_SUCCESS) { +#endif + fprintf(stderr, + "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n", + m, + n, + k, + (int)status); + return EXIT_FAILURE; + } + return 0; +} + +template +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) +int cublas_gemm_ex(rocblas_handle handle, + rocblas_operation transa, + rocblas_operation transb, + int m, + int n, + int k, + const float* alpha, + const float* beta, + const T* A, + const T* B, + T* C, + rocblas_gemm_algo algo, + int b_stride = -1) +#else +int cublas_gemm_ex(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const float* alpha, + const float* beta, + const T* A, + const T* B, + T* C, + cublasGemmAlgo_t algo, + int b_stride = -1) +#endif +{ + const int ldb = (b_stride == -1) ? ((transb == CUBLAS_OP_N) ? k : n) : b_stride; +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) + constexpr auto rocblas_dtype_16 = std::is_same::value ? rocblas_datatype_f16_r + : rocblas_datatype_bf16_r; + rocblas_status status = rocblas_gemm_ex(handle, + transa, + transb, + m, + n, + k, + (const void*)alpha, + (const void*)A, + rocblas_dtype_16, + (transa == rocblas_operation_none) ? m : k, + (const void*)B, + rocblas_dtype_16, + ldb, + (const void*)beta, + (void*)C, + rocblas_dtype_16, + m, + (void*)C, + rocblas_dtype_16, + m, + rocblas_datatype_f32_r, + algo, + 0, + 0); +#else +#ifdef __HIP_PLATFORM_AMD__ + constexpr auto cublas_dtype_16 = std::is_same::value ? HIPBLAS_R_16F : HIPBLAS_R_16B; +#else + constexpr auto cublas_dtype_16 = std::is_same::value ? CUDA_R_16F : CUDA_R_16BF; +#endif + cublasStatus_t status = cublasGemmEx(handle, + transa, + transb, + m, + n, + k, + (const void*)alpha, + (const void*)A, + cublas_dtype_16, + (transa == CUBLAS_OP_N) ? m : k, + (const void*)B, + cublas_dtype_16, + ldb, + (const void*)beta, + (void*)C, + cublas_dtype_16, + m, +#if defined(__HIP_PLATFORM_AMD__) && defined(HIPBLAS_V2) + HIPBLAS_COMPUTE_32F, +#elif defined(__HIP_PLATFORM_AMD__) + HIPBLAS_R_32F, +#else + CUDA_R_32F, +#endif + algo); +#endif + +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) + if (status != rocblas_status_success) { +#else + if (status != CUBLAS_STATUS_SUCCESS) { +#endif + fprintf(stderr, + "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n", + m, + n, + k, + (int)status); + return EXIT_FAILURE; + } + return 0; +} + +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) +int cublas_strided_batched_gemm(rocblas_handle handle, + int m, + int n, + int k, + const float* alpha, + const float* beta, + const float* A, + const float* B, + float* C, + rocblas_operation op_A, + rocblas_operation op_B, + int stride_A, + int stride_B, + int stride_C, + int batch, + rocblas_gemm_algo algo) +#else +int cublas_strided_batched_gemm(cublasHandle_t handle, + int m, + int n, + int k, + const float* alpha, + const float* beta, + const float* A, + const float* B, + float* C, + cublasOperation_t op_A, + cublasOperation_t op_B, + int stride_A, + int stride_B, + int stride_C, + int batch, + cublasGemmAlgo_t algo) +#endif +{ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) + rocblas_status status = + rocblas_gemm_strided_batched_ex(handle, + op_A, + op_B, + m, + n, + k, + alpha, + A, + rocblas_datatype_f32_r, + (op_A == rocblas_operation_none) ? m : k, + stride_A, + B, + rocblas_datatype_f32_r, + (op_B == rocblas_operation_none) ? k : n, + stride_B, + beta, + C, + rocblas_datatype_f32_r, + m, + stride_C, + C, + rocblas_datatype_f32_r, + m, + stride_C, + batch, + rocblas_datatype_f32_r, + algo, + 0, + 0); +#else + cublasStatus_t status = cublasGemmStridedBatchedEx(handle, + op_A, + op_B, + m, + n, + k, + alpha, + A, +#ifdef __HIP_PLATFORM_AMD__ + HIPBLAS_R_32F, +#else + CUDA_R_32F, +#endif + (op_A == CUBLAS_OP_N) ? m : k, + stride_A, + B, +#ifdef __HIP_PLATFORM_AMD__ + HIPBLAS_R_32F, +#else + CUDA_R_32F, +#endif + (op_B == CUBLAS_OP_N) ? k : n, + stride_B, + beta, + C, +#ifdef __HIP_PLATFORM_AMD__ + HIPBLAS_R_32F, +#else + CUDA_R_32F, +#endif + m, + stride_C, + batch, +#if defined(__HIP_PLATFORM_AMD__) && defined(HIPBLAS_V2) + HIPBLAS_COMPUTE_32F, +#elif defined(__HIP_PLATFORM_AMD__) + HIPBLAS_R_32F, +#else + CUDA_R_32F, +#endif + algo); +#endif + +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) + if (status != rocblas_status_success) { +#else + if (status != CUBLAS_STATUS_SUCCESS) { +#endif + fprintf(stderr, + "!!!! kernel execution error. (batch: %d, m: %d, n: %d, k: %d, error: %d) \n", + batch, + m, + n, + k, + (int)status); + return EXIT_FAILURE; + } + return 0; +} + +template +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) +int cublas_strided_batched_gemm(rocblas_handle handle, + int m, + int n, + int k, + const float* alpha, + const float* beta, + const T* A, + const T* B, + T* C, + rocblas_operation op_A, + rocblas_operation op_B, + int stride_A, + int stride_B, + int stride_C, + int batch, + rocblas_gemm_algo algo) +#else +int cublas_strided_batched_gemm(cublasHandle_t handle, + int m, + int n, + int k, + const float* alpha, + const float* beta, + const T* A, + const T* B, + T* C, + cublasOperation_t op_A, + cublasOperation_t op_B, + int stride_A, + int stride_B, + int stride_C, + int batch, + cublasGemmAlgo_t algo) +#endif +{ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) + constexpr auto rocblas_dtype_16 = std::is_same::value ? rocblas_datatype_f16_r + : rocblas_datatype_bf16_r; + rocblas_status status = + rocblas_gemm_strided_batched_ex(handle, + op_A, + op_B, + m, + n, + k, + alpha, + A, + rocblas_dtype_16, + (op_A == rocblas_operation_none) ? m : k, + stride_A, + B, + rocblas_dtype_16, + (op_B == rocblas_operation_none) ? k : n, + stride_B, + beta, + C, + rocblas_dtype_16, + m, + stride_C, + C, + rocblas_dtype_16, + m, + stride_C, + batch, + rocblas_datatype_f32_r, + algo, + 0, + 0); +#else +#ifdef __HIP_PLATFORM_AMD__ + constexpr auto cublas_dtype_16 = std::is_same::value ? HIPBLAS_R_16F : HIPBLAS_R_16B; +#else + constexpr auto cublas_dtype_16 = std::is_same::value ? CUDA_R_16F : CUDA_R_16BF; +#endif + cublasStatus_t status = cublasGemmStridedBatchedEx(handle, + op_A, + op_B, + m, + n, + k, + alpha, + A, + cublas_dtype_16, + (op_A == CUBLAS_OP_N) ? m : k, + stride_A, + B, + cublas_dtype_16, + (op_B == CUBLAS_OP_N) ? k : n, + stride_B, + beta, + C, + cublas_dtype_16, + m, + stride_C, + batch, +#if defined(__HIP_PLATFORM_AMD__) && defined(HIPBLAS_V2) + HIPBLAS_COMPUTE_32F, +#elif defined(__HIP_PLATFORM_AMD__) + HIPBLAS_R_32F, +#else + CUDA_R_32F, +#endif + algo); +#endif + +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) + if (status != rocblas_status_success) { +#else + if (status != CUBLAS_STATUS_SUCCESS) { +#endif + fprintf(stderr, + "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n", + m, + n, + k, + (int)status); + return EXIT_FAILURE; + } + + return 0; +} diff --git a/toolbox/DeepSpeed/v0.15.3/patches/csrc/transformer/inference/includes/inference_cuda_layers.h b/toolbox/DeepSpeed/v0.15.3/patches/csrc/transformer/inference/includes/inference_cuda_layers.h new file mode 100644 index 0000000000000000000000000000000000000000..fa949b8822982b3eb453ae21d28156466a78db65 --- /dev/null +++ b/toolbox/DeepSpeed/v0.15.3/patches/csrc/transformer/inference/includes/inference_cuda_layers.h @@ -0,0 +1,265 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. */ +/* All Rights Reserved. */ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include "ds_kernel_utils.h" + +#include +#ifdef BF16_AVAILABLE +#include +#endif +#include +#include +#include +#include +#include + +#define MAX_WARP_NUM 32 +#define WARP_SIZE 32 + +#define MAX_THREADS 65536 +#define SMs 80 + +#define MAX_REGISTERS 256 + +template +void launch_attn_softmax_v2(T* vals, + T* mask, + T* alibi, + float layer_scale, + bool triangular, + bool recompute, + bool local_attention, + int window_size, + int batch_size, + int heads, + int num_seq, + int sequence_length, + int offset, + int mask_stride, + int mp_size, + cudaStream_t stream); + +// Fused bias add with gelu activation +template +void launch_bias_gelu(T* input, + const T* bias, + int intermediate_size, + int batch_size, + cudaStream_t stream); + +template +void launch_gated_activation(T* output, + const T* activation, + const T* bias, + int rows, + int output_stride, + int elems_per_row, + bool use_gelu, + cudaStream_t stream); + +// Fused bias add with relu activation +template +void launch_bias_relu(T* input, + const T* bias, + int intermediate_size, + int batch_size, + cudaStream_t stream); + +template +void launch_bias_add(T* input, const T* bias, int hidden_size, int batch_size, cudaStream_t stream); + +template +void launch_bias_residual(T* input, + T* output, + T* attn, + T* bias, + T* attn_bias, + int batch, + int hidden_dim, + int mp_size, + bool preln, + cudaStream_t stream); + +template +void launch_fused_ln(T* output, + const T* vals, + const T* gamma, + const T* beta, + float epsilon, + int rows, + int elems_per_row, + cudaStream_t stream); + +template +void launch_fused_residual_ln(T* output, + const T* vals, + const T* residual, + const T* bias, + const T* gamma, + const T* beta, + float epsilon, + int rows, + int elems_per_row, + cudaStream_t stream); + +template +void launch_fused_residual_ln_store_pre_ln_res(T* norm_output, + T* res_output, + const T* vals, + const T* residual, + const T* bias, + const T* gamma, + const T* beta, + float epsilon, + int rows, + int elems_per_row, + cudaStream_t stream); + +template +void launch_rms_norm(T* norm_output, + T* res_output, + const T* vals, + const T* residual, + const T* gamma, + float epsilon, + int rows, + int elems_per_row, + cudaStream_t stream); + +template +void launch_dequantize(T* output, + const int8_t* input, + const float* qscale, + unsigned output_size, + unsigned hidden_dim, + unsigned groups, + unsigned merge_count, + cudaStream_t stream); + +template +void launch_dequantize(T* output, + const int8_t* input, + const float* qscale, + unsigned output_size, + unsigned hidden_dim, + unsigned groups, + cudaStream_t stream); +template +void launch_gptj_residual_add(T* input, + T* output, + T* attn, + T* bias, + T* attn_bias, + int batch, + int head_size, + int mp_size, + cudaStream_t stream); + +template +void launch_apply_rotary_pos_emb(T* mixed_query, + T* key_layer, + unsigned head_size, + unsigned seq_len, + unsigned rotary_dim, + unsigned offset, + unsigned num_heads, + unsigned batch, + float rope_theta, + cudaStream_t stream, + int max_out_tokens); + +template +void launch_moe_res_matmul(T* residual, + T* coef, + T* mlp_out, + int seq_len, + int hidden_dim, + cudaStream_t stream); + +// 4D transform [0, 1, 2, 3] -> [0, 2, 1, 3] +template +void launch_transform4d_0213(T* out, + const T* in, + int batch_size, + int heads, + int seq_length, + int hidden_dim, + cudaStream_t stream, + int trans_count); +template +void launch_bias_add_transform_0213(T* outputs, + T* vals, + T* vals1, + const T* vals2, + const T* bias, + int batch_size, + int seq_length, + unsigned seq_offset, + int seq_length1, + int hidden_dim, + int heads, + int num_kv, + int rotary_dim, + bool rotate_half, + bool rotate_every_two, + cudaStream_t stream, + int trans_count, + int max_out_tokens, + float rope_theta); +template +void pad_data(T* padded_output, + T* output, + int bsz, + int head_size, + int padded_head_size, + cudaStream_t stream); + +template +void pad_head_seq(T* padded_output, + T* output, + int bsz, + int seq_len, + int padded_seq_len, + int head_size, + int padded_head_size, + cudaStream_t stream); + +template +void launch_pad_add_transform_0213(T* output, + const T* vals, + int batch_size, + int hidden_dim, + int seq_length, + int padded_seq_len, + int heads, + int padded_head_size, + cudaStream_t stream); + +template +void launch_vector_add(T* out, + const T* a, + const T* b, + float gamma, + int num_elems, + cudaStream_t stream); diff --git a/toolbox/DeepSpeed/v0.15.3/patches/csrc/transformer/normalize_kernels.cu b/toolbox/DeepSpeed/v0.15.3/patches/csrc/transformer/normalize_kernels.cu new file mode 100644 index 0000000000000000000000000000000000000000..2c0dd87cfdc8675c12d3cd7ba1994384bde8a541 --- /dev/null +++ b/toolbox/DeepSpeed/v0.15.3/patches/csrc/transformer/normalize_kernels.cu @@ -0,0 +1,2151 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. */ +/* All Rights Reserved. */ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "custom_cuda_layers.h" + +namespace cg = cooperative_groups; + +/* +Fused bias add, residual (elementwise) add, and normalization layer. + +For FP16, this kernel does not promote to FP32 in order to utilize the 2x throughput for +__half2 instructions, and avoid the conversion overhead (1/8 of __hal2 arithmetic). + +For specific launch constraints, see the launch functions. +*/ + +#define NORM_REG (MAX_REGISTERS / 4) + +__global__ void fused_bias_residual_layer_norm(float* vals, + const float* residual, + const float* gamma, + const float* beta, + float epsilon, + bool preLayerNorm, + bool training, + float* vars, + float* means, + int row_stride) +{ + int iteration_stride = blockDim.x; + int iterations = row_stride / iteration_stride; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + int row = blockIdx.x; + int id = threadIdx.x; + int gid = id / WARP_SIZE; + + float vals_arr[NORM_REG]; + __shared__ float shr[MAX_WARP_NUM]; + + residual += (row * row_stride); + vals += (row * row_stride); + + float sum = 0.f; + int high_index = iterations * iteration_stride + id; +#pragma unroll + for (int i = 0; i < iterations; i++) { + vals_arr[i] = residual[i * iteration_stride + id]; + sum += vals_arr[i]; + } + if (high_index < row_stride) { + vals_arr[iterations] = residual[high_index]; + sum += vals_arr[iterations]; + iterations++; + } + + for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); } + + if (g.thread_rank() == 0) shr[gid] = sum; + + b.sync(); + + if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) sum = shr[g.thread_rank()]; + +#if !defined(__STOCHASTIC_MODE__) || __CUDA_ARCH__ < 700 + b.sync(); +#endif + + for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) { + sum += g.shfl_down(sum, i); + } + + sum = g.shfl(sum, 0); + float mean = sum / row_stride; + if (training) + if (threadIdx.x == 0) means[row] = mean; + float variance = 0.f; + for (int i = 0; i < iterations; i++) { + vals_arr[i] -= mean; + variance += vals_arr[i] * vals_arr[i]; + } + + for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); } + + if (g.thread_rank() == 0) shr[gid] = variance; + + b.sync(); + + if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) variance = shr[g.thread_rank()]; + +#ifndef __STOCHASTIC_MODE__ + b.sync(); +#endif + + for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) { + variance += g.shfl_down(variance, i); + } + variance = g.shfl(variance, 0); + variance /= row_stride; + variance += epsilon; + if (training) + if (threadIdx.x == 0) vars[row] = variance; + + iterations = row_stride / iteration_stride; + for (int i = 0; i < iterations; i++) { + vals_arr[i] = vals_arr[i] * rsqrtf(variance); + vals_arr[i] = + vals_arr[i] * gamma[i * iteration_stride + id] + beta[i * iteration_stride + id]; + vals[i * iteration_stride + id] = vals_arr[i]; + } + if ((high_index) < row_stride) { + vals_arr[iterations] = vals_arr[iterations] * rsqrtf(variance); + vals_arr[iterations] = vals_arr[iterations] * gamma[high_index] + beta[high_index]; + vals[high_index] = vals_arr[iterations]; + } +} + +__global__ void fused_bias_residual_layer_norm(__half* vals, + const __half* residual, + const __half* gamma, + const __half* beta, + float epsilon, + bool preLayerNorm, + bool training, + __half* vars, + __half* means, + int row_stride) +{ +#ifdef HALF_PRECISION_AVAILABLE + int iteration_stride = blockDim.x; + int iterations = row_stride / iteration_stride; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile<32> g = cg::tiled_partition<32>(b); + + int row = blockIdx.x; + int id = threadIdx.x; + int gid = id >> WARP_SIZE_BITS; + + float2 vals_f[NORM_REG]; + __shared__ float shr[MAX_WARP_NUM]; + + __half2* vals_cast = reinterpret_cast<__half2*>(vals); + const __half2* residual_cast = reinterpret_cast(residual); + + residual_cast += (row * row_stride); + vals_cast += (row * row_stride); + + float sum = 0.f; + int high_index = iterations * iteration_stride + id; +#pragma unroll + for (int i = 0; i < iterations; i++) { + vals_f[i] = __half22float2(residual_cast[i * iteration_stride + id]); + sum += vals_f[i].x; + sum += vals_f[i].y; + } + if ((high_index) < row_stride) { + vals_f[iterations] = __half22float2(residual_cast[high_index]); + sum += vals_f[iterations].x; + sum += vals_f[iterations].y; + iterations++; + } + + for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); } + + if (g.thread_rank() == 0) shr[gid] = sum; + + b.sync(); + + if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) sum = shr[g.thread_rank()]; + +#ifndef __STOCHASTIC_MODE__ + b.sync(); +#endif + + for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) { + sum += g.shfl_down(sum, i); + } + sum = g.shfl(sum, 0); + float mean = sum / (row_stride * 2); + + float variance = 0.f; + for (int i = 0; i < iterations; i++) { + vals_f[i].x -= mean; + vals_f[i].y -= mean; + variance += vals_f[i].x * vals_f[i].x; + variance += vals_f[i].y * vals_f[i].y; + } + + for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); } + + if (g.thread_rank() == 0) shr[gid] = variance; + + b.sync(); + + if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) variance = shr[g.thread_rank()]; + +#ifndef __STOCHASTIC_MODE__ + b.sync(); +#endif + + for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) { + variance += g.shfl_down(variance, i); + } + variance = g.shfl(variance, 0); + variance /= (row_stride * 2); + variance += epsilon; + + __half2 variance_h = __float2half2_rn(variance); + const __half2* gamma_cast = reinterpret_cast(gamma); + const __half2* beta_cast = reinterpret_cast(beta); + + if (training && threadIdx.x == 0) { + vars[row] = __float2half(variance); + means[row] = __float2half(mean); + } + iterations = row_stride / iteration_stride; + for (int i = 0; i < iterations; i++) { + __half2 vals_arr = __float22half2_rn(vals_f[i]); + vals_arr = vals_arr * h2rsqrt(variance_h); + vals_arr = + vals_arr * gamma_cast[i * iteration_stride + id] + beta_cast[i * iteration_stride + id]; + vals_cast[i * iteration_stride + id] = vals_arr; + } + if ((high_index) < row_stride) { + __half2 vals_arr = __float22half2_rn(vals_f[iterations]); + vals_arr = vals_arr * h2rsqrt(variance_h); + vals_arr = vals_arr * gamma_cast[high_index] + beta_cast[high_index]; + vals_cast[high_index] = vals_arr; + } +#endif +} + +template +void launch_bias_residual_layer_norm(T* vals, + const T* residual, + const T* gamma, + const T* beta, + float epsilon, + int batch_size, + int hidden_dim, + cudaStream_t stream, + bool preLayerNorm, + bool training, + T* vars, + T* means); + +template <> +void launch_bias_residual_layer_norm(float* vals, + const float* residual, + const float* gamma, + const float* beta, + float epsilon, + int batch_size, + int hidden_dim, + cudaStream_t stream, + bool preLayerNorm, + bool training, + float* vars, + float* means) +{ + int threads = THREADS; + + dim3 grid_dim(batch_size); + + if (hidden_dim > 16384 && hidden_dim <= 32768) + threads <<= 1; + else if (hidden_dim > 32768 && hidden_dim <= 65536) + threads <<= 2; + else if (hidden_dim > 65536) + throw std::runtime_error("Unsupport hidden_dim."); + + dim3 block_dim(threads); + + fused_bias_residual_layer_norm<<>>( + vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, hidden_dim); +} + +template <> +void launch_bias_residual_layer_norm<__half>(__half* vals, + const __half* residual, + const __half* gamma, + const __half* beta, + float epsilon, + int batch_size, + int hidden_dim, + cudaStream_t stream, + bool preLayerNorm, + bool training, + __half* vars, + __half* means) +{ + int threads = 128; + + dim3 grid_dim(batch_size); + + if (hidden_dim > 8192 && hidden_dim <= 16384) + threads <<= 1; + else if (hidden_dim > 16384 && hidden_dim <= 32768) + threads <<= 2; + else if (hidden_dim > 32768 && hidden_dim <= 65536) + threads <<= 3; + else if (hidden_dim > 65536) + throw std::runtime_error("Unsupport hidden_dim."); + + dim3 block_dim(threads); + + fused_bias_residual_layer_norm<<>>( + vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, hidden_dim / 2); +} + +__global__ void fused_bias_residual_layer_norm(float* vals, + const float* residual, + const float* gamma, + const float* beta, + float epsilon, + bool preLayerNorm, + bool training, + float* vars, + int row_stride) +{ + int iteration_stride = blockDim.x; + int iterations = row_stride / iteration_stride; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile<32> g = cg::tiled_partition<32>(b); + + int row = blockIdx.x; + int id = threadIdx.x; + int gid = id / 32; + + float vals_arr[NORM_REG]; + __shared__ float shr[MAX_WARP_NUM]; + + residual += (row * row_stride); + vals += (row * row_stride); + + float sum = 0.f; + int high_index = iterations * iteration_stride + id; +#pragma unroll + for (int i = 0; i < iterations; i++) { + vals_arr[i] = residual[i * iteration_stride + id]; + sum += vals_arr[i]; + } + if ((high_index) < row_stride) { + vals_arr[iterations] = residual[high_index]; + sum += vals_arr[iterations]; + iterations++; + } + + for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); } + + if (g.thread_rank() == 0) shr[gid] = sum; + + b.sync(); + + if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) sum = shr[g.thread_rank()]; + +#if !defined(__STOCHASTIC_MODE__) || __CUDA_ARCH__ < 700 + b.sync(); +#endif + + for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) { + sum += g.shfl_down(sum, i); + } + + sum = g.shfl(sum, 0); + float mean = sum / row_stride; + float variance = 0.f; + for (int i = 0; i < iterations; i++) { + vals_arr[i] -= mean; + variance += vals_arr[i] * vals_arr[i]; + } + + for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); } + + if (g.thread_rank() == 0) shr[gid] = variance; + + b.sync(); + + if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) variance = shr[g.thread_rank()]; + +#ifndef __STOCHASTIC_MODE__ + b.sync(); +#endif + + for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) { + variance += g.shfl_down(variance, i); + } + variance = g.shfl(variance, 0); + variance /= row_stride; + variance += epsilon; + if (training) + if (threadIdx.x == 0) vars[row] = variance; + + iterations = row_stride / iteration_stride; + for (int i = 0; i < iterations; i++) { + vals_arr[i] = vals_arr[i] * rsqrtf(variance); + vals_arr[i] = + vals_arr[i] * gamma[i * iteration_stride + id] + beta[i * iteration_stride + id]; + vals[i * iteration_stride + id] = vals_arr[i]; + } + if ((high_index) < row_stride) { + vals_arr[iterations] = vals_arr[iterations] * rsqrtf(variance); + vals_arr[iterations] = vals_arr[iterations] * gamma[high_index] + beta[high_index]; + vals[high_index] = vals_arr[iterations]; + } +} + +__global__ void fused_bias_residual_layer_norm(__half* vals, + const __half* residual, + const __half* gamma, + const __half* beta, + float epsilon, + bool preLayerNorm, + bool training, + __half* vars, + int row_stride) +{ +#ifdef HALF_PRECISION_AVAILABLE + + int iteration_stride = blockDim.x; + int iterations = row_stride / iteration_stride; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile<32> g = cg::tiled_partition<32>(b); + + int row = blockIdx.x; + int id = threadIdx.x; + int gid = id >> WARP_SIZE_BITS; + + float2 vals_f[NORM_REG]; + __shared__ float shr[MAX_WARP_NUM]; + + __half2* vals_cast = reinterpret_cast<__half2*>(vals); + const __half2* residual_cast = reinterpret_cast(residual); + + residual_cast += (row * row_stride); + vals_cast += (row * row_stride); + + float sum = 0.f; + int high_index = iterations * iteration_stride + id; +#pragma unroll + for (int i = 0; i < iterations; i++) { + vals_f[i] = __half22float2(residual_cast[i * iteration_stride + id]); + sum += vals_f[i].x; + sum += vals_f[i].y; + } + if ((high_index) < row_stride) { + vals_f[iterations] = __half22float2(residual_cast[high_index]); + sum += vals_f[iterations].x; + sum += vals_f[iterations].y; + iterations++; + } + + for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); } + + if (g.thread_rank() == 0) shr[gid] = sum; + + b.sync(); + + if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) sum = shr[g.thread_rank()]; + +#ifndef __STOCHASTIC_MODE__ + b.sync(); +#endif + + for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) { + sum += g.shfl_down(sum, i); + } + sum = g.shfl(sum, 0); + float mean = sum / (row_stride * 2); + + float variance = 0.f; + for (int i = 0; i < iterations; i++) { + vals_f[i].x -= mean; + vals_f[i].y -= mean; + variance += vals_f[i].x * vals_f[i].x; + variance += vals_f[i].y * vals_f[i].y; + } + + for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); } + + if (g.thread_rank() == 0) shr[gid] = variance; + + b.sync(); + + if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) variance = shr[g.thread_rank()]; + +#ifndef __STOCHASTIC_MODE__ + b.sync(); +#endif + + for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) { + variance += g.shfl_down(variance, i); + } + variance = g.shfl(variance, 0); + variance /= (row_stride * 2); + variance += epsilon; + + __half2 variance_h = __float2half2_rn(variance); + const __half2* gamma_cast = reinterpret_cast(gamma); + const __half2* beta_cast = reinterpret_cast(beta); + + if (training && threadIdx.x == 0) vars[row] = __float2half(variance); + + iterations = row_stride / iteration_stride; + for (int i = 0; i < iterations; i++) { + __half2 vals_arr = __float22half2_rn(vals_f[i]); + vals_arr = vals_arr * h2rsqrt(variance_h); + vals_arr = + vals_arr * gamma_cast[i * iteration_stride + id] + beta_cast[i * iteration_stride + id]; + vals_cast[i * iteration_stride + id] = vals_arr; + } + if ((high_index) < row_stride) { + __half2 vals_arr = __float22half2_rn(vals_f[iterations]); + vals_arr = vals_arr * h2rsqrt(variance_h); + vals_arr = vals_arr * gamma_cast[high_index] + beta_cast[high_index]; + vals_cast[high_index] = vals_arr; + } +#endif +} + +template +void launch_bias_residual_layer_norm(T* vals, + const T* residual, + const T* gamma, + const T* beta, + float epsilon, + int batch_size, + int hidden_dim, + cudaStream_t stream, + bool preLayerNorm, + bool training, + T* vars); + +/* +To tune this launch the following restrictions must be met: + +For float: +row_stride == hidden_size +threads * iterations == row_stride +threads is in [32, 64, 128, 256, 512, 1024] + +For half: +row_stride == hidden_size / 2 +threads * iterations == row_stride +threads is in [32, 64, 128, 256, 512, 1024] + +*/ + +template <> +void launch_bias_residual_layer_norm(float* vals, + const float* residual, + const float* gamma, + const float* beta, + float epsilon, + int batch_size, + int hidden_dim, + cudaStream_t stream, + bool preLayerNorm, + bool training, + float* vars) +{ + int threads = THREADS; + + dim3 grid_dim(batch_size); + + // There are some limitations to call below functions, now just enumerate the situations. + + if (hidden_dim > 16384 && hidden_dim <= 32768) + threads <<= 1; + else if (hidden_dim > 32768 && hidden_dim <= 65536) + threads <<= 2; + else if (hidden_dim > 65536) + throw std::runtime_error("Unsupport hidden_dim."); + + dim3 block_dim(threads); + + fused_bias_residual_layer_norm<<>>( + vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, hidden_dim); +} + +template <> +void launch_bias_residual_layer_norm<__half>(__half* vals, + const __half* residual, + const __half* gamma, + const __half* beta, + float epsilon, + int batch_size, + int hidden_dim, + cudaStream_t stream, + bool preLayerNorm, + bool training, + __half* vars) +{ + int threads = 128; + + dim3 grid_dim(batch_size); + + // There are some limitations to call below functions, now just enumerate the situations. + + if (hidden_dim > 8192 && hidden_dim <= 16384) + threads <<= 1; + else if (hidden_dim > 16384 && hidden_dim <= 32768) + threads <<= 2; + else if (hidden_dim > 32768 && hidden_dim <= 65536) + threads <<= 3; + else if (hidden_dim > 65536) + throw std::runtime_error("Unsupport hidden_dim."); + + dim3 block_dim(threads); + fused_bias_residual_layer_norm<<>>( + vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, hidden_dim / 2); +} + +/* Normalize Gamma & Betta gradients + * Compute gradients using either X_hat or + * normalize input (invertible). + * Combine transpose with gradients computation. + */ + +template +__global__ void LayerNormBackward1(const T* __restrict__ out_grad, + const T* __restrict__ vals_hat, + const T* __restrict__ gamma, + const T* __restrict__ betta, + T* __restrict__ gamma_grad, + T* __restrict__ betta_grad, + int rows, + int width, + bool invertible) +{ + __shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1]; + __shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1]; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + int idx = blockDim.x * blockIdx.x + threadIdx.x; + int offset = threadIdx.y * width + idx; + int y_stride = width * TILE_DIM; + + float betta_reg = (invertible ? (float)betta[idx] : 0.0f); + float gamma_reg = (float)gamma[idx]; + + // Loop across matrix height + float betta_tmp = 0; + float gamma_tmp = 0; + for (int r = threadIdx.y; r < rows; r += TILE_DIM) { + float grad = (float)out_grad[offset]; + float val = (invertible ? ((float)vals_hat[offset] - betta_reg) / gamma_reg + : (float)vals_hat[offset]); + betta_tmp += grad; + gamma_tmp += (val * grad); + + offset += y_stride; + } + + betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp; + gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp; + + __syncthreads(); + + // Sum the shared buffer. + float s1 = betta_buffer[threadIdx.y][threadIdx.x]; + float s2 = gamma_buffer[threadIdx.y][threadIdx.x]; + +#ifndef __STOCHASTIC_MODE__ + __syncthreads(); +#endif + + for (int i = 1; i < TILE_DIM; i <<= 1) { + s1 += g.shfl_down(s1, i); + s2 += g.shfl_down(s2, i); + } + + if (threadIdx.x == 0) { + int pos = blockIdx.x * TILE_DIM + threadIdx.y; + betta_grad[pos] = s1; + gamma_grad[pos] = s2; + } +} + +/* Normalize Gamma & Betta gradients + * Compute gradients using the input to + * the normalize. + * Combine transpose with gradients computation. + */ + +template +__global__ void LayerNormBackward1(const T* __restrict__ out_grad, + const T* __restrict__ X_data, + const T* __restrict__ vars, + const T* __restrict__ means, + T* __restrict__ gamma_grad, + T* __restrict__ betta_grad, + int rows, + int width) +{ + __shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1]; + __shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1]; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + int idx = blockDim.x * blockIdx.x + threadIdx.x; + int offset = threadIdx.y * width + idx; + int y_stride = width * TILE_DIM; + + int pos = blockIdx.x * TILE_DIM + threadIdx.y; + // Loop across matrix height + + float betta_tmp = 0; + float gamma_tmp = 0; + for (int r = threadIdx.y; r < rows; r += TILE_DIM) { + float grad = (float)out_grad[offset]; + float val = (float)X_data[offset]; + val = (val - (float)means[r]) * rsqrtf((float)vars[r]); + betta_tmp += grad; + gamma_tmp += (val * grad); + + offset += y_stride; + } + + betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp; + gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp; + + __syncthreads(); + + // Sum the shared buffer. + float s1 = betta_buffer[threadIdx.y][threadIdx.x]; + float s2 = gamma_buffer[threadIdx.y][threadIdx.x]; + +#ifndef __STOCHASTIC_MODE__ + __syncthreads(); +#endif + + for (int i = 1; i < TILE_DIM; i <<= 1) { + s1 += g.shfl_down(s1, i); + s2 += g.shfl_down(s2, i); + } + + if (threadIdx.x == 0) { + betta_grad[pos] = s1; + gamma_grad[pos] = s2; + } +} +/* + +/* Backward Normalize (Input-Gradient) + * Using the means and variances from the input + * This type of backward is invertible! + * We do the backward using the X_hat (X - u) / sqrt(variance) or the output of Normalization. + */ + +__global__ void LayerNormBackward2(const float* out_grad, + const float* vals_hat, + const float* gamma, + const float* betta, + const float* vars, + float* inp_grad, + bool invertible, + int row_stride) +{ + int iteration_stride = blockDim.x; + int iterations = row_stride / iteration_stride; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + int row = blockIdx.x; + int id = threadIdx.x; + int wid = id / WARP_SIZE; + int warp_num = iteration_stride >> WARP_SIZE_BITS; + __shared__ float partialSum[MAX_WARP_NUM]; + + out_grad += (row * row_stride); + vals_hat += (row * row_stride); + inp_grad += (row * row_stride); + + float vals_arr[NORM_REG]; + float vals_hat_arr[NORM_REG]; + int high_index = iterations * iteration_stride + id; +#pragma unroll + for (int i = 0; i < iterations; i++) { + float gamma_reg = gamma[i * iteration_stride + id]; + vals_arr[i] = out_grad[i * iteration_stride + id]; + vals_arr[i] *= gamma_reg; + vals_hat_arr[i] = + (invertible ? (vals_hat[i * iteration_stride + id] - betta[i * iteration_stride + id]) / + gamma_reg + : vals_hat[i * iteration_stride + id]); + } + if ((high_index) < row_stride) { + float gamma_reg = gamma[high_index]; + vals_arr[iterations] = out_grad[high_index]; + vals_arr[iterations] *= gamma_reg; + vals_hat_arr[iterations] = + (invertible ? (vals_hat[high_index] - betta[high_index]) / gamma_reg + : vals_hat[high_index]); + iterations++; + } + + float var_reg = vars[row]; + + float sum = 0; + for (int i = 0; i < iterations; i++) { + sum += vals_hat_arr[i] * vals_arr[i] * + sqrtf(var_reg); // dval_hat = gamma * (x - u) * out_grad + vals_arr[i] *= rsqrtf(var_reg); // dvar_inv = gamma * out_grad / sqrt(var) + } + + for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); } + + if (g.thread_rank() == 0) partialSum[wid] = sum; + + __syncthreads(); + + if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; + +#ifndef __STOCHASTIC_MODE__ + __syncthreads(); +#endif + + for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); + + sum = g.shfl(sum, 0); + sum /= row_stride; + + for (int i = 0; i < iterations; i++) { vals_arr[i] += ((-sum * vals_hat_arr[i]) / var_reg); } + + sum = 0; + for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; } + + for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); } + + if (g.thread_rank() == 0) partialSum[wid] = sum; + + __syncthreads(); + + if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; + +#ifndef __STOCHASTIC_MODE__ + __syncthreads(); +#endif + + for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); + sum = g.shfl(sum, 0); + sum /= row_stride; + + iterations = row_stride / iteration_stride; + for (int i = 0; i < iterations; i++) inp_grad[i * iteration_stride + id] = (vals_arr[i] - sum); + if ((high_index) < row_stride) inp_grad[high_index] = (vals_arr[iterations] - sum); +} + +__global__ void LayerNormBackward2(const __half* out_grad, + const __half* vals_hat, + const __half* gamma, + const __half* betta, + const __half* vars, + __half* inp_grad, + bool invertible, + int row_stride) +{ +#ifdef HALF_PRECISION_AVAILABLE + int iteration_stride = blockDim.x; + int iterations = row_stride / iteration_stride; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + int row = blockIdx.x; + int id = threadIdx.x; + int wid = id / WARP_SIZE; + int warp_num = iteration_stride >> WARP_SIZE_BITS; + __shared__ float partialSum[MAX_WARP_NUM]; + + __half2 vals_arr[NORM_REG]; + float2 vals_arr_f[NORM_REG]; + __half2 vals_hat_arr[NORM_REG]; + + __half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad); + const __half2* out_grad_h = reinterpret_cast(out_grad); + const __half2* vals_hat_h = reinterpret_cast(vals_hat); + + inp_grad_h += (row * row_stride); + out_grad_h += (row * row_stride); + vals_hat_h += (row * row_stride); + + const __half2* gamma_h = reinterpret_cast(gamma); + const __half2* betta_h = (invertible ? reinterpret_cast(betta) : nullptr); + int high_index = iterations * iteration_stride + id; +#pragma unroll + for (int i = 0; i < iterations; i++) { + __half2 gamma_reg = gamma_h[i * iteration_stride + id]; + vals_arr[i] = out_grad_h[i * iteration_stride + id]; + vals_arr[i] *= gamma_reg; + vals_hat_arr[i] = + (invertible + ? (vals_hat_h[i * iteration_stride + id] - betta_h[i * iteration_stride + id]) / + gamma_reg + : vals_hat_h[i * iteration_stride + id]); + } + if ((high_index) < row_stride) { + __half2 gamma_reg = gamma_h[high_index]; + vals_arr[iterations] = out_grad_h[high_index]; + vals_arr[iterations] *= gamma_reg; + vals_hat_arr[iterations] = + (invertible ? (vals_hat_h[high_index] - betta_h[high_index]) / gamma_reg + : vals_hat_h[high_index]); + iterations++; + } + __half var_h = vars[row]; + __half2 var_reg = __halves2half2(var_h, var_h); + + float sum = 0.f; + for (int i = 0; i < iterations; i++) { + __half2 result_h = (vals_hat_arr[i] * vals_arr[i] * h2sqrt(var_reg)); + float2 result_f = __half22float2(result_h); + sum += result_f.x; + sum += result_f.y; + vals_arr[i] *= h2rsqrt(var_reg); + } + + for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); } + + if (g.thread_rank() == 0) partialSum[wid] = sum; + + __syncthreads(); + + if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; + +#ifndef __STOCHASTIC_MODE__ + __syncthreads(); +#endif + + for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); + + sum = g.shfl(sum, 0); + sum /= (2 * row_stride); + __half2 sum_h = __float2half2_rn(sum); + + for (int i = 0; i < iterations; i++) { + __half2 temp = ((-sum_h * vals_hat_arr[i]) / (var_reg)); + vals_arr_f[i] = __half22float2(vals_arr[i]); + float2 temp_f = __half22float2(temp); + vals_arr_f[i].x += temp_f.x; + vals_arr_f[i].y += temp_f.y; + } + sum = 0.f; + + for (int i = 0; i < iterations; i++) { + sum += (vals_arr_f[i].x); + sum += (vals_arr_f[i].y); + } + + for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); } + + if (g.thread_rank() == 0) partialSum[wid] = sum; + + __syncthreads(); + + if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; + +#ifndef __STOCHASTIC_MODE__ + __syncthreads(); +#endif + + for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); + + sum = g.shfl(sum, 0); + sum /= (2 * row_stride); + + iterations = row_stride / iteration_stride; + for (int i = 0; i < iterations; i++) { + vals_arr_f[i].x -= sum; + vals_arr_f[i].y -= sum; + __half2 temp = __float22half2_rn(vals_arr_f[i]); + + inp_grad_h[i * iteration_stride + id] = temp; + } + if ((high_index) < row_stride) { + vals_arr_f[iterations].x -= sum; + vals_arr_f[iterations].y -= sum; + __half2 temp = __float22half2_rn(vals_arr_f[iterations]); + + inp_grad_h[high_index] = temp; + } +#endif +} + +template <> +void launch_layerNorm_backward(const float* out_grad, + const float* vals_hat, + const float* vars, + const float* gamma, + float* gamma_grad, + float* betta_grad, + float* inp_grad, + int batch, + int hidden_dim, + cudaStream_t stream[2], + bool invertible, + const float* betta) +{ + int threads = THREADS; + + dim3 grid_dim(hidden_dim / TILE_DIM); + dim3 block_dim(TILE_DIM, TILE_DIM); + + LayerNormBackward1<<>>( + out_grad, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible); + + dim3 grid_dim2(batch); + + if (hidden_dim > 16384 && hidden_dim <= 32768) + threads <<= 1; + else if (hidden_dim > 32768 && hidden_dim <= 65536) + threads <<= 2; + else if (hidden_dim > 65536) + throw std::runtime_error("Unsupport hidden_dim."); + + dim3 block_dim2(threads); + + LayerNormBackward2<<>>( + out_grad, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim); +} + +template <> +void launch_layerNorm_backward<__half>(const __half* out_grad, + const __half* vals_hat, + const __half* vars, + const __half* gamma, + __half* gamma_grad, + __half* betta_grad, + __half* inp_grad, + int batch, + int hidden_dim, + cudaStream_t stream[2], + bool invertible, + const __half* betta) +{ + int threads = THREADS; + + dim3 grid_dim(hidden_dim / TILE_DIM); + dim3 block_dim(TILE_DIM, TILE_DIM); + + // LayerNormBackward1<__half><<>>( + // out_grad, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible); + + dim3 grid_dim2(batch); + + if (hidden_dim > 8192 && hidden_dim <= 16384) + threads <<= 1; + else if (hidden_dim > 16384 && hidden_dim <= 32768) + threads <<= 2; + else if (hidden_dim > 32768 && hidden_dim <= 65536) + threads <<= 3; + else if (hidden_dim > 65536) + throw std::runtime_error("Unsupport hidden_dim."); + + dim3 block_dim2(threads / 2); + + LayerNormBackward2<<>>( + out_grad, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim / 2); +} + +/* Backward Normalize (Input-Gradient) + * Using the means and variances from the input + * This type of backward is not invertible! + * We do the backward using the input (X) + */ + +__global__ void LayerNormBackward2(const float* out_grad, + const float* X_vals, + const float* gamma, + const float* vars, + const float* means, + float* inp_grad, + int row_stride) +{ + int iteration_stride = blockDim.x; + int iterations = row_stride / iteration_stride; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + int row = blockIdx.x; + int id = threadIdx.x; + int wid = id >> WARP_SIZE_BITS; + int warp_num = iteration_stride >> WARP_SIZE_BITS; + __shared__ float partialSum[MAX_WARP_NUM]; + + out_grad += (row * row_stride); + X_vals += (row * row_stride); + inp_grad += (row * row_stride); + + float vals_arr[NORM_REG]; + int high_index = iterations * iteration_stride + id; +#pragma unroll + for (int i = 0; i < iterations; i++) { + float gamma_reg = gamma[i * iteration_stride + id]; + vals_arr[i] = out_grad[i * iteration_stride + id]; + vals_arr[i] *= gamma_reg; + } + if ((high_index) < row_stride) { + float gamma_reg = gamma[high_index]; + vals_arr[iterations] = out_grad[high_index]; + vals_arr[iterations] *= gamma_reg; + iterations++; + } + + float var_reg = vars[row]; + float mean_reg = means[row]; + + float sum = 0; + float xu[NORM_REG]; + for (int i = 0; i < iterations; i++) { + xu[i] = (X_vals[i * iteration_stride + id] - mean_reg); + sum += vals_arr[i] * xu[i]; + vals_arr[i] *= rsqrtf(var_reg); + } + + for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); } + + if (g.thread_rank() == 0) partialSum[wid] = sum; + + __syncthreads(); + + if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; + +#ifndef __STOCHASTIC_MODE__ + __syncthreads(); +#endif + + for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); + + sum = g.shfl(sum, 0); + sum /= row_stride; + + for (int i = 0; i < iterations; i++) { + vals_arr[i] += (-sum * xu[i] * rsqrtf(var_reg) / (var_reg)); + } + + sum = 0; + for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; } + + for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); } + + if (g.thread_rank() == 0) partialSum[wid] = sum; + + __syncthreads(); + + if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; + +#ifndef __STOCHASTIC_MODE__ + __syncthreads(); +#endif + + for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); + sum = g.shfl(sum, 0); + sum /= row_stride; + + iterations = row_stride / iteration_stride; + for (int i = 0; i < iterations; i++) inp_grad[i * iteration_stride + id] = (vals_arr[i] - sum); + if ((high_index) < row_stride) inp_grad[high_index] = (vals_arr[iterations] - sum); +} + +__global__ void LayerNormBackward2(const __half* out_grad, + const __half* X_vals, + const __half* gamma, + const __half* vars, + const __half* means, + __half* inp_grad, + int row_stride) +{ +#ifdef HALF_PRECISION_AVAILABLE + int iteration_stride = blockDim.x; + int iterations = row_stride / iteration_stride; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + int row = blockIdx.x; + int id = threadIdx.x; + int wid = id >> WARP_SIZE_BITS; + int warp_num = iteration_stride >> WARP_SIZE_BITS; + + __shared__ float partialSum[MAX_WARP_NUM]; + + __half2 vals_arr[NORM_REG]; + float2 vals_arr_f[NORM_REG]; + __half2 xu[NORM_REG]; + + __half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad); + const __half2* out_grad_h = reinterpret_cast(out_grad); + const __half2* vals_hat_h = reinterpret_cast(X_vals); + + inp_grad_h += (row * row_stride); + out_grad_h += (row * row_stride); + vals_hat_h += (row * row_stride); + + const __half2* gamma_h = reinterpret_cast(gamma); + int high_index = iterations * iteration_stride + id; + + __half mean_h = means[row]; + __half2 mean_reg = __halves2half2(mean_h, mean_h); +#pragma unroll + for (int i = 0; i < iterations; i++) { + __half2 gamma_reg = gamma_h[i * iteration_stride + id]; + vals_arr[i] = out_grad_h[i * iteration_stride + id]; + vals_arr[i] *= gamma_reg; // out_grad * gamma + xu[i] = (vals_hat_h[i * iteration_stride + id] - mean_reg); + } + if ((high_index) < row_stride) { + __half2 gamma_reg = gamma_h[high_index]; + vals_arr[iterations] = out_grad_h[high_index]; + vals_arr[iterations] *= gamma_reg; // out_grad * gamma + xu[iterations] = (vals_hat_h[high_index] - mean_reg); + iterations++; + } + __half var_h = vars[row]; + __half2 var_reg = __halves2half2(var_h, var_h); + + float sum = 0.f; + for (int i = 0; i < iterations; i++) { + __half2 result_h = (xu[i] * vals_arr[i]); + float2 result_f = __half22float2(result_h); + sum += result_f.x; + sum += result_f.y; + vals_arr[i] *= h2rsqrt(var_reg); + } + + for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); } + + if (g.thread_rank() == 0) partialSum[wid] = sum; + + __syncthreads(); + + if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; + +#ifndef __STOCHASTIC_MODE__ + __syncthreads(); +#endif + + for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); + + sum = g.shfl(sum, 0); + sum /= (2 * row_stride); + __half2 sum_h = __float2half2_rn(sum); + + for (int i = 0; i < iterations; i++) { + __half2 xu_grad = ((-sum_h * xu[i] * h2rsqrt(var_reg)) / (var_reg)); + vals_arr_f[i] = __half22float2(vals_arr[i]); + float2 xu_grad_f = __half22float2(xu_grad); + vals_arr_f[i].x += xu_grad_f.x; + vals_arr_f[i].y += xu_grad_f.y; + } + + sum = 0.f; + for (int i = 0; i < iterations; i++) { + sum += (vals_arr_f[i].x); + sum += (vals_arr_f[i].y); + } + + for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); } + + if (g.thread_rank() == 0) partialSum[wid] = sum; + + __syncthreads(); + + if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; + +#ifndef __STOCHASTIC_MODE__ + __syncthreads(); +#endif + + for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); + + sum = g.shfl(sum, 0); + sum /= (2 * row_stride); + + iterations = row_stride / iteration_stride; + for (int i = 0; i < iterations; i++) { + vals_arr_f[i].x -= sum; + vals_arr_f[i].y -= sum; + __half2 temp = __float22half2_rn(vals_arr_f[i]); + inp_grad_h[i * iteration_stride + id] = temp; + } + if ((high_index) < row_stride) { + vals_arr_f[iterations].x -= sum; + vals_arr_f[iterations].y -= sum; + __half2 temp = __float22half2_rn(vals_arr_f[iterations]); + inp_grad_h[high_index] = temp; + } +#endif +} + +template <> +void launch_layerNorm_backward(const float* out_grad, + const float* X_data, + const float* vars, + const float* means, + const float* gamma, + float* gamma_grad, + float* betta_grad, + float* inp_grad, + int batch, + int hidden_dim, + cudaStream_t stream[2]) +{ + int threads = THREADS; + + dim3 grid_dim(hidden_dim / TILE_DIM); + dim3 block_dim(TILE_DIM, TILE_DIM); + + LayerNormBackward1<<>>( + out_grad, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim); + + dim3 grid_dim2(batch); + + if (hidden_dim > 16384 && hidden_dim <= 32768) + threads <<= 1; + else if (hidden_dim > 32768 && hidden_dim <= 65536) + threads <<= 2; + else if (hidden_dim > 65536) + throw std::runtime_error("Unsupport hidden_dim."); + + dim3 block_dim2(threads); + LayerNormBackward2<<>>( + out_grad, X_data, gamma, vars, means, inp_grad, hidden_dim); +} + +template <> +void launch_layerNorm_backward<__half>(const __half* out_grad, + const __half* X_data, + const __half* vars, + const __half* means, + const __half* gamma, + __half* gamma_grad, + __half* betta_grad, + __half* inp_grad, + int batch, + int hidden_dim, + cudaStream_t stream[2]) +{ + int threads = THREADS; + + dim3 grid_dim(hidden_dim / TILE_DIM); + dim3 block_dim(TILE_DIM, TILE_DIM); + + LayerNormBackward1<__half><<>>( + out_grad, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim); + + dim3 grid_dim2(batch); + + if (hidden_dim > 8192 && hidden_dim <= 16384) + threads <<= 1; + else if (hidden_dim > 16384 && hidden_dim <= 32768) + threads <<= 2; + else if (hidden_dim > 32768 && hidden_dim <= 65536) + threads <<= 3; + else if (hidden_dim > 65536) + throw std::runtime_error("Unsupport hidden_dim."); + + dim3 block_dim2(threads / 2); + LayerNormBackward2<<>>( + out_grad, X_data, gamma, vars, means, inp_grad, hidden_dim / 2); +} + +template +__global__ void LayerNormBackward1_fused_add(const T* __restrict__ out_grad1, + const T* __restrict__ out_grad2, + const T* __restrict__ vals_hat, + const T* __restrict__ gamma, + const T* __restrict__ betta, + T* __restrict__ gamma_grad, + T* __restrict__ betta_grad, + int rows, + int width, + bool invertible) +{ + __shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1]; + __shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1]; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + int idx = blockDim.x * blockIdx.x + threadIdx.x; + int offset = threadIdx.y * width + idx; + int y_stride = width * TILE_DIM; + + float betta_reg = (invertible ? (float)betta[idx] : 0.0f); + float gamma_reg = (float)gamma[idx]; + + // Loop across matrix height + float betta_tmp = 0; + float gamma_tmp = 0; + for (int r = threadIdx.y; r < rows; r += TILE_DIM) { + float grad = (float)out_grad1[offset] + (float)out_grad2[offset]; + float val = (invertible ? ((float)vals_hat[offset] - betta_reg) / gamma_reg + : (float)vals_hat[offset]); + betta_tmp += grad; + gamma_tmp += (val * grad); + + offset += y_stride; + } + + betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp; + gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp; + + __syncthreads(); + + // Sum the shared buffer. + float s1 = betta_buffer[threadIdx.y][threadIdx.x]; + float s2 = gamma_buffer[threadIdx.y][threadIdx.x]; + +#ifndef __STOCHASTIC_MODE__ + __syncthreads(); +#endif + + for (int i = 1; i < TILE_DIM; i <<= 1) { + s1 += g.shfl_down(s1, i); + s2 += g.shfl_down(s2, i); + } + + if (threadIdx.x == 0) { + int pos = blockIdx.x * TILE_DIM + threadIdx.y; + betta_grad[pos] = s1; + gamma_grad[pos] = s2; + } +} + +template +__global__ void LayerNormBackward1_fused_add(const T* __restrict__ out_grad1, + const T* __restrict__ out_grad2, + const T* __restrict__ X_data, + const T* __restrict__ vars, + const T* __restrict__ means, + T* __restrict__ gamma_grad, + T* __restrict__ betta_grad, + int rows, + int width) +{ + __shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1]; + __shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1]; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + int idx = blockDim.x * blockIdx.x + threadIdx.x; + int offset = threadIdx.y * width + idx; + int y_stride = width * TILE_DIM; + + int pos = blockIdx.x * TILE_DIM + threadIdx.y; + // Loop across matrix height + + float betta_tmp = 0; + float gamma_tmp = 0; + for (int r = threadIdx.y; r < rows; r += TILE_DIM) { + float grad = (float)out_grad1[offset] + (float)out_grad2[offset]; + float val = (float)X_data[offset]; + val = (val - (float)means[r]) * rsqrtf((float)vars[r]); + betta_tmp += grad; + gamma_tmp += (val * grad); + + offset += y_stride; + } + + betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp; + gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp; + + __syncthreads(); + + // Sum the shared buffer. + float s1 = betta_buffer[threadIdx.y][threadIdx.x]; + float s2 = gamma_buffer[threadIdx.y][threadIdx.x]; + +#ifndef __STOCHASTIC_MODE__ + __syncthreads(); +#endif + + for (int i = 1; i < TILE_DIM; i <<= 1) { + s1 += g.shfl_down(s1, i); + s2 += g.shfl_down(s2, i); + } + + if (threadIdx.x == 0) { + betta_grad[pos] = s1; + gamma_grad[pos] = s2; + } +} + +__global__ void LayerNormBackward2_fused_add(const float* out_grad1, + const float* out_grad2, + const float* vals_hat, + const float* gamma, + const float* betta, + const float* vars, + float* inp_grad, + bool invertible, + int row_stride) +{ + int iteration_stride = blockDim.x; + int iterations = row_stride / iteration_stride; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + int row = blockIdx.x; + int id = threadIdx.x; + int wid = id / WARP_SIZE; + int warp_num = iteration_stride >> WARP_SIZE_BITS; + __shared__ float partialSum[MAX_WARP_NUM]; + + out_grad1 += (row * row_stride); + out_grad2 += (row * row_stride); + vals_hat += (row * row_stride); + inp_grad += (row * row_stride); + + float vals_arr[NORM_REG]; + float vals_hat_arr[NORM_REG]; + int high_index = iterations * iteration_stride + id; +#pragma unroll + for (int i = 0; i < iterations; i++) { + float gamma_reg = gamma[i * iteration_stride + id]; + vals_arr[i] = out_grad1[i * iteration_stride + id]; + vals_arr[i] *= gamma_reg; + vals_hat_arr[i] = + (invertible ? (vals_hat[i * iteration_stride + id] - betta[i * iteration_stride + id]) / + gamma_reg + : vals_hat[i * iteration_stride + id]); + } + if ((high_index) < row_stride) { + float gamma_reg = gamma[high_index]; + vals_arr[iterations] = out_grad1[high_index]; + vals_arr[iterations] *= gamma_reg; + vals_hat_arr[iterations] = + (invertible ? (vals_hat[high_index] - betta[high_index]) / gamma_reg + : vals_hat[high_index]); + iterations++; + } + + float var_reg = vars[row]; + + float sum = 0; + for (int i = 0; i < iterations; i++) { + sum += vals_hat_arr[i] * vals_arr[i] * sqrtf(var_reg); + vals_arr[i] *= rsqrtf(var_reg); + } + + for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); } + + if (g.thread_rank() == 0) partialSum[wid] = sum; + + __syncthreads(); + + if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; + +#ifndef __STOCHASTIC_MODE__ + __syncthreads(); +#endif + + for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); + + sum = g.shfl(sum, 0); + sum /= row_stride; + + for (int i = 0; i < iterations; i++) { vals_arr[i] += ((-sum * vals_hat_arr[i]) / var_reg); } + + sum = 0; + for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; } + + for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); } + + if (g.thread_rank() == 0) partialSum[wid] = sum; + + __syncthreads(); + + if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; + +#ifndef __STOCHASTIC_MODE__ + __syncthreads(); +#endif + + for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); + sum = g.shfl(sum, 0); + sum /= row_stride; + + iterations = row_stride / iteration_stride; + for (int i = 0; i < iterations; i++) + inp_grad[i * iteration_stride + id] = + (vals_arr[i] - sum) + out_grad2[i * iteration_stride + id]; + if ((high_index) < row_stride) + inp_grad[high_index] = (vals_arr[iterations] - sum) + out_grad2[high_index]; +} + +__global__ void LayerNormBackward2_fused_add(const __half* out_grad1, + const __half* out_grad2, + const __half* vals_hat, + const __half* gamma, + const __half* betta, + const __half* vars, + __half* inp_grad, + bool invertible, + int row_stride) +{ +#ifdef HALF_PRECISION_AVAILABLE + int iteration_stride = blockDim.x; + int iterations = row_stride / iteration_stride; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + int row = blockIdx.x; + int id = threadIdx.x; + int wid = id / WARP_SIZE; + int warp_num = iteration_stride >> WARP_SIZE_BITS; + __shared__ float partialSum[MAX_WARP_NUM]; + + __half2 vals_arr[NORM_REG]; + float2 vals_arr_f[NORM_REG]; + __half2 vals_hat_arr[NORM_REG]; + + // float2 result[iterations]; + + __half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad); + const __half2* out_grad_h1 = reinterpret_cast(out_grad1); + const __half2* out_grad_h2 = reinterpret_cast(out_grad2); + const __half2* vals_hat_h = reinterpret_cast(vals_hat); + + inp_grad_h += (row * row_stride); + out_grad_h1 += (row * row_stride); + out_grad_h2 += (row * row_stride); + vals_hat_h += (row * row_stride); + + const __half2* gamma_h = reinterpret_cast(gamma); + const __half2* betta_h = (invertible ? reinterpret_cast(betta) : nullptr); + int high_index = iterations * iteration_stride + id; +#pragma unroll + for (int i = 0; i < iterations; i++) { + __half2 gamma_reg = gamma_h[i * iteration_stride + id]; + vals_arr[i] = out_grad_h1[i * iteration_stride + id]; + vals_arr[i] *= gamma_reg; // out_grad * gamma + vals_hat_arr[i] = + (invertible + ? (vals_hat_h[i * iteration_stride + id] - betta_h[i * iteration_stride + id]) / + gamma_reg + : vals_hat_h[i * iteration_stride + id]); + } + if ((high_index) < row_stride) { + __half2 gamma_reg = gamma_h[high_index]; + vals_arr[iterations] = out_grad_h1[high_index]; + vals_arr[iterations] *= gamma_reg; // out_grad * gamma + vals_hat_arr[iterations] = + (invertible ? (vals_hat_h[high_index] - betta_h[high_index]) / gamma_reg + : vals_hat_h[high_index]); + iterations++; + } + __half var_h = vars[row]; + __half2 var_reg = __halves2half2(var_h, var_h); + + float sum = 0.f; + for (int i = 0; i < iterations; i++) { + __half2 result_h = (vals_hat_arr[i] * vals_arr[i] * h2sqrt(var_reg)); + float2 result_f = __half22float2(result_h); + sum += result_f.x; + sum += result_f.y; + vals_arr[i] *= h2rsqrt(var_reg); + } + + for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); } + + if (g.thread_rank() == 0) partialSum[wid] = sum; + + __syncthreads(); + + if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; + +#ifndef __STOCHASTIC_MODE__ + __syncthreads(); +#endif + + for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); + + sum = g.shfl(sum, 0); + sum /= (2 * row_stride); + __half2 sum_h = __float2half2_rn(sum); + + for (int i = 0; i < iterations; i++) { + __half2 temp = ((-sum_h * vals_hat_arr[i]) / (var_reg)); + vals_arr_f[i] = __half22float2(vals_arr[i]); + float2 temp_f = __half22float2(temp); + vals_arr_f[i].x += temp_f.x; + vals_arr_f[i].y += temp_f.y; + } + sum = 0.f; + for (int i = 0; i < iterations; i++) { + sum += (vals_arr_f[i].x); + sum += (vals_arr_f[i].y); + } + + for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); } + + if (g.thread_rank() == 0) partialSum[wid] = sum; + + __syncthreads(); + + if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; + +#ifndef __STOCHASTIC_MODE__ + __syncthreads(); +#endif + + for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); + + sum = g.shfl(sum, 0); + sum /= (2 * row_stride); + + iterations = row_stride / iteration_stride; + for (int i = 0; i < iterations; i++) { + vals_arr_f[i].x -= sum; + vals_arr_f[i].y -= sum; + __half2 temp = __float22half2_rn(vals_arr_f[i]); + + inp_grad_h[i * iteration_stride + id] = temp + out_grad_h2[i * iteration_stride + id]; + } + if ((high_index) < row_stride) { + vals_arr_f[iterations].x -= sum; + vals_arr_f[iterations].y -= sum; + __half2 temp = __float22half2_rn(vals_arr_f[iterations]); + + inp_grad_h[high_index] = temp + out_grad_h2[high_index]; + } +#endif +} + +template <> +void launch_layerNorm_backward_fused_add(const float* out_grad1, + const float* out_grad2, + const float* vals_hat, + const float* vars, + const float* gamma, + float* gamma_grad, + float* betta_grad, + float* inp_grad, + int batch, + int hidden_dim, + cudaStream_t stream[2], + bool invertible, + const float* betta) +{ + int threads = THREADS; + + dim3 grid_dim(hidden_dim / TILE_DIM); + dim3 block_dim(TILE_DIM, TILE_DIM); + LayerNormBackward1<<>>( + out_grad1, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible); + + dim3 grid_dim2(batch); + + if (hidden_dim > 16384 && hidden_dim <= 32768) + threads <<= 1; + else if (hidden_dim > 32768 && hidden_dim <= 65536) + threads <<= 2; + else if (hidden_dim > 65536) + throw std::runtime_error("Unsupport hidden_dim."); + + dim3 block_dim2(threads); + LayerNormBackward2_fused_add<<>>( + out_grad1, out_grad2, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim); +} + +template <> +void launch_layerNorm_backward_fused_add<__half>(const __half* out_grad1, + const __half* out_grad2, + const __half* vals_hat, + const __half* vars, + const __half* gamma, + __half* gamma_grad, + __half* betta_grad, + __half* inp_grad, + int batch, + int hidden_dim, + cudaStream_t stream[2], + bool invertible, + const __half* betta) +{ + int threads = THREADS; + + dim3 grid_dim(hidden_dim / TILE_DIM); + dim3 block_dim(TILE_DIM, TILE_DIM); + + LayerNormBackward1<__half><<>>( + out_grad1, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible); + + dim3 grid_dim2(batch); + + if (hidden_dim > 8192 && hidden_dim <= 16384) + threads <<= 1; + else if (hidden_dim > 16384 && hidden_dim <= 32768) + threads <<= 2; + else if (hidden_dim > 32768 && hidden_dim <= 65536) + threads <<= 3; + else if (hidden_dim > 65536) + throw std::runtime_error("Unsupport hidden_dim."); + + dim3 block_dim2(threads / 2); + LayerNormBackward2_fused_add<<>>( + out_grad1, out_grad2, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim / 2); +} + +/* Backward Normalize (Input-Gradient) + * Using the means and variances from the input + * This type of backward is not invertible! + * We do the backward using the input (X) + */ + +__global__ void LayerNormBackward2_fused_add(const float* out_grad1, + const float* out_grad2, + const float* X_vals, + const float* gamma, + const float* vars, + const float* means, + float* inp_grad, + int row_stride) +{ + int iteration_stride = blockDim.x; + int iterations = row_stride / iteration_stride; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + int row = blockIdx.x; + int id = threadIdx.x; + int wid = id / WARP_SIZE; + int warp_num = iteration_stride >> WARP_SIZE_BITS; + __shared__ float partialSum[MAX_WARP_NUM]; + + float vals_arr[NORM_REG]; + float vals_hat_arr[NORM_REG]; + + out_grad1 += (row * row_stride); + out_grad2 += (row * row_stride); + X_vals += (row * row_stride); + inp_grad += (row * row_stride); + int high_index = iterations * iteration_stride + id; +#pragma unroll + for (int i = 0; i < iterations; i++) { + float gamma_reg = gamma[i * iteration_stride + id]; + vals_arr[i] = out_grad1[i * iteration_stride + id]; + vals_arr[i] *= gamma_reg; + vals_hat_arr[i] = X_vals[i * iteration_stride + id]; + } + if ((high_index) < row_stride) { + float gamma_reg = gamma[high_index]; + vals_arr[iterations] = out_grad1[high_index]; + vals_arr[iterations] *= gamma_reg; + vals_hat_arr[iterations] = X_vals[high_index]; + iterations++; + } + + float var_reg = vars[row]; + float mean_reg = means[row]; + + float sum = 0; + float xu[NORM_REG]; + for (int i = 0; i < iterations; i++) { + xu[i] = (vals_hat_arr[i] - mean_reg); + sum += vals_arr[i] * xu[i]; + vals_arr[i] *= rsqrtf(var_reg); + } + + for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); } + + if (g.thread_rank() == 0) partialSum[wid] = sum; + + __syncthreads(); + + if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; + +#ifndef __STOCHASTIC_MODE__ + __syncthreads(); +#endif + + for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); + + sum = g.shfl(sum, 0); + sum /= row_stride; + + for (int i = 0; i < iterations; i++) { + vals_arr[i] += (-sum * xu[i] * rsqrtf(var_reg) / (var_reg)); + } + + sum = 0; + for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; } + + for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); } + + if (g.thread_rank() == 0) partialSum[wid] = sum; + + __syncthreads(); + + if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; + +#ifndef __STOCHASTIC_MODE__ + __syncthreads(); +#endif + + for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); + sum = g.shfl(sum, 0); + sum /= row_stride; + + iterations = row_stride / iteration_stride; + for (int i = 0; i < iterations; i++) + inp_grad[i * iteration_stride + id] = + (vals_arr[i] - sum) + out_grad2[i * iteration_stride + id]; + if ((high_index) < row_stride) + inp_grad[high_index] = (vals_arr[iterations] - sum) + out_grad2[high_index]; +} + +__global__ void LayerNormBackward2_fused_add(const __half* out_grad1, + const __half* out_grad2, + const __half* X_vals, + const __half* gamma, + const __half* vars, + const __half* means, + __half* inp_grad, + int row_stride) +{ +#ifdef HALF_PRECISION_AVAILABLE + int iteration_stride = blockDim.x; + int iterations = row_stride / iteration_stride; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + int row = blockIdx.x; + int id = threadIdx.x; + int wid = id / WARP_SIZE; + int warp_num = iteration_stride >> WARP_SIZE_BITS; + + __shared__ float partialSum[MAX_WARP_NUM]; + + __half2 vals_arr[NORM_REG]; + float2 vals_arr_f[NORM_REG]; + __half2 vals_hat_arr[NORM_REG]; + + __half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad); + const __half2* out_grad_h1 = reinterpret_cast(out_grad1); + const __half2* out_grad_h2 = reinterpret_cast(out_grad2); + const __half2* vals_hat_h = reinterpret_cast(X_vals); + + out_grad_h1 += (row * row_stride); + out_grad_h2 += (row * row_stride); + inp_grad_h += (row * row_stride); + vals_hat_h += (row * row_stride); + + const __half2* gamma_h = reinterpret_cast(gamma); + int high_index = iterations * iteration_stride + id; +#pragma unroll + for (int i = 0; i < iterations; i++) { + __half2 gamma_reg = gamma_h[i * iteration_stride + id]; + vals_arr[i] = out_grad_h1[i * iteration_stride + id]; + vals_arr[i] *= gamma_reg; // out_grad * gamma + vals_hat_arr[i] = vals_hat_h[i * iteration_stride + id]; + } + if ((high_index) < row_stride) { + __half2 gamma_reg = gamma_h[high_index]; + vals_arr[iterations] = out_grad_h1[high_index]; + vals_arr[iterations] *= gamma_reg; // out_grad * gamma + vals_hat_arr[iterations] = vals_hat_h[high_index]; + iterations++; + } + + __half mean_h = means[row]; + __half var_h = vars[row]; + __half2 var_reg = __halves2half2(var_h, var_h); + __half2 mean_reg = __halves2half2(mean_h, mean_h); + __half2 xu[NORM_REG]; + + float sum = 0.f; + for (int i = 0; i < iterations; i++) { + xu[i] = (vals_hat_arr[i] - mean_reg); + __half2 result_h = (xu[i] * vals_arr[i]); + float2 result_f = __half22float2(result_h); + sum += result_f.x; + sum += result_f.y; + vals_arr[i] *= h2rsqrt(var_reg); + } + + for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); } + + if (g.thread_rank() == 0) partialSum[wid] = sum; + + __syncthreads(); + + if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; + +#ifndef __STOCHASTIC_MODE__ + __syncthreads(); +#endif + + for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); + + sum = g.shfl(sum, 0); + sum /= (2 * row_stride); + __half2 sum_h = __float2half2_rn(sum); + + for (int i = 0; i < iterations; i++) { + __half2 xu_grad = ((-sum_h * xu[i] * h2rsqrt(var_reg)) / (var_reg)); + vals_arr_f[i] = __half22float2(vals_arr[i]); + float2 xu_grad_f = __half22float2(xu_grad); + vals_arr_f[i].x += xu_grad_f.x; + vals_arr_f[i].y += xu_grad_f.y; + } + + sum = 0.f; + for (int i = 0; i < iterations; i++) { + sum += (vals_arr_f[i].x); + sum += (vals_arr_f[i].y); + } + + for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); } + + if (g.thread_rank() == 0) partialSum[wid] = sum; + + __syncthreads(); + + if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; + +#ifndef __STOCHASTIC_MODE__ + __syncthreads(); +#endif + + for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); + + sum = g.shfl(sum, 0); + sum /= (2 * row_stride); + + iterations = row_stride / iteration_stride; + for (int i = 0; i < iterations; i++) { + vals_arr_f[i].x -= sum; + vals_arr_f[i].y -= sum; + __half2 temp = __float22half2_rn(vals_arr_f[i]); + inp_grad_h[i * iteration_stride + id] = temp + out_grad_h2[i * iteration_stride + id]; + } + if ((high_index) < row_stride) { + vals_arr_f[iterations].x -= sum; + vals_arr_f[iterations].y -= sum; + __half2 temp = __float22half2_rn(vals_arr_f[iterations]); + inp_grad_h[high_index] = temp + out_grad_h2[high_index]; + } +#endif +} + +template <> +void launch_layerNorm_backward_fused_add(const float* out_grad1, + const float* out_grad2, + const float* X_data, + const float* vars, + const float* means, + const float* gamma, + float* gamma_grad, + float* betta_grad, + float* inp_grad, + int batch, + int hidden_dim, + cudaStream_t stream[2]) +{ + int threads = THREADS; + + dim3 grid_dim(hidden_dim / TILE_DIM); + dim3 block_dim(TILE_DIM, TILE_DIM); + + LayerNormBackward1<<>>( + out_grad1, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim); + + dim3 grid_dim2(batch); + + if (hidden_dim > 16384 && hidden_dim <= 32768) + threads <<= 1; + else if (hidden_dim > 32768 && hidden_dim <= 65536) + threads <<= 2; + else if (hidden_dim > 65536) + throw std::runtime_error("Unsupport hidden_dim."); + + dim3 block_dim2(threads); + LayerNormBackward2_fused_add<<>>( + out_grad1, out_grad2, X_data, gamma, vars, means, inp_grad, hidden_dim); +} + +template <> +void launch_layerNorm_backward_fused_add<__half>(const __half* out_grad1, + const __half* out_grad2, + const __half* X_data, + const __half* vars, + const __half* means, + const __half* gamma, + __half* gamma_grad, + __half* betta_grad, + __half* inp_grad, + int batch, + int hidden_dim, + cudaStream_t stream[2]) +{ + int threads = THREADS; + + dim3 grid_dim(hidden_dim / TILE_DIM); + dim3 block_dim(TILE_DIM, TILE_DIM); + + LayerNormBackward1<__half><<>>( + out_grad1, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim); + + dim3 grid_dim2(batch); + + if (hidden_dim > 8192 && hidden_dim <= 16384) + threads <<= 1; + else if (hidden_dim > 16384 && hidden_dim <= 32768) + threads <<= 2; + else if (hidden_dim > 32768 && hidden_dim <= 65536) + threads <<= 3; + else if (hidden_dim > 65536) + throw std::runtime_error("Unsupport hidden_dim."); + + dim3 block_dim2(threads / 2); + LayerNormBackward2_fused_add<<>>( + out_grad1, out_grad2, X_data, gamma, vars, means, inp_grad, hidden_dim / 2); +} diff --git a/toolbox/DeepSpeed/v0.15.3/patches/csrc/transformer/softmax_kernels.cu b/toolbox/DeepSpeed/v0.15.3/patches/csrc/transformer/softmax_kernels.cu new file mode 100644 index 0000000000000000000000000000000000000000..e4e1968f163b3cb5f8ef5f6d4c7068ace8f67650 --- /dev/null +++ b/toolbox/DeepSpeed/v0.15.3/patches/csrc/transformer/softmax_kernels.cu @@ -0,0 +1,718 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. */ +/* All Rights Reserved. */ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include +#include "custom_cuda_layers.h" +#include "general_kernels.h" + +namespace cg = cooperative_groups; + +dim3 get_attn_softmax_grid(int batch_size, int heads, int sequence_length, int threads) +{ + int seq_length4 = sequence_length / 4; + int block_compute_size = + (seq_length4 < threads ? (int)pow(2.0, floor(log2((float)(threads / seq_length4)))) : 1); + // Note that the Y and Z dimensions are limited to 65535, while X is basically unlimited: + // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications + // The batch size is typically relatively small, while the sequence length could potentially be + // arbitrarily large. We therefore place the batch size second to avoid hitting the Y limit. + unsigned x = heads * sequence_length / block_compute_size; + unsigned y = batch_size; + return {x, y}; +} + +// Fused attention + softmax +template +__global__ void attn_softmax(float* vals, + const float* attn_mask, + int heads, + int seq_length, + int iterations) +{ + __shared__ float partialSum[MAX_WARP_NUM]; + + int warp_num = blockDim.x >> WARP_SIZE_BITS; + + int iteration_stride = blockDim.x; + int block_width = blockStride * seq_length; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + int batch = blockIdx.y; + int row = blockIdx.x; + int max_threads_in_sequence = std::max(seq_length, tbSeq); + int seq_lane = threadIdx.x % max_threads_in_sequence; + + int data_offset = batch * (gridDim.x * block_width) + row * block_width + + (threadIdx.x / max_threads_in_sequence) * seq_length; + int mask_offset = batch * seq_length; + + int wid = threadIdx.x >> WARP_SIZE_BITS; + int lane = threadIdx.x & 0x1f; + + float4* val_cast = reinterpret_cast(vals); + const float4* attn_mask_cast = reinterpret_cast(attn_mask); + + float4 data[MAX_THREAD_ITERATIONS]; + + float max_val = minus_infinity; + + for (int i = 0; i < iterations; i++) { + int data_id = i * iteration_stride + seq_lane; + if (data_id < seq_length) { + float4 mask = attn_mask_cast[mask_offset + data_id]; + data[i] = val_cast[data_offset + data_id]; + + data[i].x += mask.x; + data[i].y += mask.y; + data[i].z += mask.z; + data[i].w += mask.w; + + max_val = (data[i].x > max_val ? data[i].x : max_val); + max_val = (data[i].y > max_val ? data[i].y : max_val); + max_val = (data[i].z > max_val ? data[i].z : max_val); + max_val = (data[i].w > max_val ? data[i].w : max_val); + } else { + data[i].x = minus_infinity; + data[i].y = minus_infinity; + data[i].z = minus_infinity; + data[i].w = minus_infinity; + } + } + + for (int i = 1; i < tbSize; i *= 2) { + auto temp = g.shfl_xor(max_val, i); + max_val = (temp > max_val ? temp : max_val); + } + + if (seq_length > tbSize) { + if (lane == 0) partialSum[wid] = max_val; + b.sync(); + + if (lane < warp_num) max_val = partialSum[lane]; + +#ifndef __STOCHASTIC_MODE__ + b.sync(); +#endif + + int iters = warp_num; + if (seq_length < iteration_stride) + iters = warp_num / (iteration_stride / max_threads_in_sequence); + + for (int i = 1; i < iters; i *= 2) { + auto temp = g.shfl_xor(max_val, i); + max_val = (temp > max_val ? temp : max_val); + } + + max_val = g.shfl(max_val, threadIdx.x / tbSize); + } + + float sum = 0; + for (int i = 0; i < iterations; i++) { + data[i].x = __expf(data[i].x - max_val); + data[i].y = __expf(data[i].y - max_val); + data[i].z = __expf(data[i].z - max_val); + data[i].w = __expf(data[i].w - max_val); + + sum += (data[i].x + data[i].y + data[i].z + data[i].w); + } + + for (int i = 1; i < tbSize; i *= 2) { sum += g.shfl_xor(sum, i); } + + if (seq_length > tbSize) { + if (lane == 0) partialSum[wid] = sum; + b.sync(); + + if (lane < warp_num) sum = partialSum[lane]; + +#ifndef __STOCHASTIC_MODE__ + b.sync(); +#endif + + int iters = warp_num; + if (seq_length < iteration_stride) + iters = warp_num / (iteration_stride / max_threads_in_sequence); + + for (int i = 1; i < iters; i *= 2) { sum += g.shfl_xor(sum, i); } + + sum = g.shfl(sum, threadIdx.x / tbSize); + } + + sum += 1e-6; + + for (int i = 0; i < iterations; i++) { + data[i].x /= sum; + data[i].y /= sum; + data[i].z /= sum; + data[i].w /= sum; + + int data_id = i * iteration_stride + seq_lane; + if (data_id < seq_length) val_cast[data_offset + data_id] = data[i]; + } +} + +template +__global__ void attn_softmax(__half* vals, + const __half* attn_mask, + int heads, + int seq_length, + int iterations) +{ +#ifdef HALF_PRECISION_AVAILABLE + __shared__ float partialSum[MAX_WARP_NUM]; + + int warp_num = blockDim.x >> WARP_SIZE_BITS; + + int iteration_stride = blockDim.x; + int block_width = blockStride * seq_length; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + int batch = blockIdx.y; + int row = blockIdx.x; + int max_threads_in_sequence = std::max(seq_length, tbSeq); + int seq_lane = threadIdx.x % max_threads_in_sequence; + + int data_offset = batch * (gridDim.x * block_width) + row * block_width + + (threadIdx.x / max_threads_in_sequence) * seq_length; + int mask_offset = batch * seq_length; + + int wid = threadIdx.x >> WARP_SIZE_BITS; + int lane = threadIdx.x & 0x1f; + + float2* val_cast = reinterpret_cast(vals); + const float2* attn_mask_cast = reinterpret_cast(attn_mask); + + val_cast += data_offset; + attn_mask_cast += mask_offset; + + float2 low_data[MAX_THREAD_ITERATIONS]; + float2 high_data[MAX_THREAD_ITERATIONS]; + + float max_val = minus_infinity; + + for (int i = 0; i < iterations; i++) { + int data_id = i * iteration_stride + seq_lane; + if (data_id < seq_length) { + float2 data = val_cast[data_id]; + float2 mask = attn_mask_cast[data_id]; + + __half2* data_arr = reinterpret_cast<__half2*>(&data); + __half2* mask_arr = reinterpret_cast<__half2*>(&mask); + + low_data[i] = __half22float2(data_arr[0]); + high_data[i] = __half22float2(data_arr[1]); + float2 low_mask = __half22float2(mask_arr[0]); + float2 high_mask = __half22float2(mask_arr[1]); + + low_data[i].x += low_mask.x; + low_data[i].y += low_mask.y; + high_data[i].x += high_mask.x; + high_data[i].y += high_mask.y; + + max_val = (low_data[i].x > max_val ? low_data[i].x : max_val); + max_val = (low_data[i].y > max_val ? low_data[i].y : max_val); + max_val = (high_data[i].x > max_val ? high_data[i].x : max_val); + max_val = (high_data[i].y > max_val ? high_data[i].y : max_val); + } + } + + for (int i = 1; i < tbSize; i *= 2) { + auto temp = g.shfl_xor(max_val, i); + max_val = (temp > max_val ? temp : max_val); + } + + if (seq_length > tbSize) { + if (lane == 0) partialSum[wid] = max_val; + b.sync(); + + if (lane < warp_num) max_val = partialSum[lane]; + +#ifndef __STOCHASTIC_MODE__ + b.sync(); +#endif + + int iters = warp_num; + if (seq_length < iteration_stride) + iters = warp_num / (iteration_stride / max_threads_in_sequence); + + for (int i = 1; i < iters; i *= 2) { + auto temp = g.shfl_xor(max_val, i); + max_val = (temp > max_val ? temp : max_val); + } + + max_val = g.shfl(max_val, threadIdx.x / tbSize); + } + + float sum = 0; + for (int i = 0; i < iterations; i++) { + int data_id = i * iteration_stride + seq_lane; + if (data_id < seq_length) { + low_data[i].x = __expf(low_data[i].x - max_val); + low_data[i].y = __expf(low_data[i].y - max_val); + high_data[i].x = __expf(high_data[i].x - max_val); + high_data[i].y = __expf(high_data[i].y - max_val); + + sum += (low_data[i].x + low_data[i].y + high_data[i].x + high_data[i].y); + } + } + + for (int i = 1; i < tbSize; i *= 2) { sum += g.shfl_xor(sum, i); } + + if (seq_length > tbSize) { + if (lane == 0) partialSum[wid] = sum; + b.sync(); + + if (lane < warp_num) sum = partialSum[lane]; + +#ifndef __STOCHASTIC_MODE__ + b.sync(); +#endif + + int iters = warp_num; + if (seq_length < iteration_stride) + iters = warp_num / (iteration_stride / max_threads_in_sequence); + + for (int i = 1; i < iters; i *= 2) { sum += g.shfl_xor(sum, i); } + + sum = g.shfl(sum, threadIdx.x / tbSize); + } + + sum += 1e-6; + + for (int i = 0; i < iterations; i++) { + int data_id = i * iteration_stride + seq_lane; + if (data_id < seq_length) { + float2 result_f; + __half2* result_h = reinterpret_cast<__half2*>(&result_f); + + low_data[i].x /= sum; + low_data[i].y /= sum; + high_data[i].x /= sum; + high_data[i].y /= sum; + + result_h[0] = __float22half2_rn(low_data[i]); + result_h[1] = __float22half2_rn(high_data[i]); + + val_cast[data_id] = result_f; + } + } + +#endif +} + +template +void launch_attn_softmax(T*, const T*, int, int, int, cudaStream_t); + +template <> +void launch_attn_softmax(float* vals, + const float* attn_mask, + int batch_size, + int heads, + int sequence_length, + cudaStream_t stream) +{ + const int threads = 128; + int seq_length4 = sequence_length / 4; + + dim3 grid_dim = get_attn_softmax_grid(batch_size, heads, sequence_length, threads); + + int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads; + + dim3 block_dim(seq_length4 > threads ? ((sequence_length + subblock_max_workload - 1) / + subblock_max_workload * threads) + : threads); + int iterations = + (sequence_length < subblock_max_workload ? (seq_length4 + threads - 1) / threads + : MAX_THREAD_ITERATIONS); + + if (sequence_length <= 8) + attn_softmax<2, (threads / 2), 2> + <<>>(vals, attn_mask, heads, seq_length4, iterations); + else if (sequence_length <= 16) + attn_softmax<4, (threads / 4), 4> + <<>>(vals, attn_mask, heads, seq_length4, iterations); + else if (sequence_length <= 32) + attn_softmax<8, (threads / 8), 8> + <<>>(vals, attn_mask, heads, seq_length4, iterations); + else if (sequence_length <= 64) + attn_softmax<16, (threads / 16), 16> + <<>>(vals, attn_mask, heads, seq_length4, iterations); + else if (sequence_length <= 128) + attn_softmax<32, (threads / 32), 32> + <<>>(vals, attn_mask, heads, seq_length4, iterations); + else if (sequence_length <= 256) + attn_softmax<32, (threads / 64), 64> + <<>>(vals, attn_mask, heads, seq_length4, iterations); + else { + const int threads = 256; + dim3 grid_dim = get_attn_softmax_grid(batch_size, heads, sequence_length, threads); + + int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads; + + dim3 block_dim(seq_length4 > threads ? ((sequence_length + subblock_max_workload - 1) / + subblock_max_workload * threads) + : threads); + iterations = + (sequence_length < subblock_max_workload ? (seq_length4 + threads - 1) / threads + : MAX_THREAD_ITERATIONS); + if (sequence_length <= 512) + attn_softmax<32, (threads / 128), 128><<>>( + vals, attn_mask, heads, seq_length4, iterations); + else if (sequence_length < (MAX_THREADS * MAX_THREAD_ITERATIONS * 4)) + attn_softmax<32, 1, 128><<>>( + vals, attn_mask, heads, seq_length4, iterations); + else + throw std::runtime_error( + "Unsupport Seq_Length! Check the restriction of the max_threads and " + "max_thread_iterations!"); + } +} + +template <> +void launch_attn_softmax<__half>(__half* vals, + const __half* attn_mask, + int batch_size, + int heads, + int sequence_length, + cudaStream_t stream) +{ + const int threads = 128; + int seq_length4 = sequence_length / 4; + + dim3 grid_dim = get_attn_softmax_grid(batch_size, heads, sequence_length, threads); + + int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads; + + dim3 block_dim(seq_length4 > threads ? ((sequence_length + subblock_max_workload - 1) / + subblock_max_workload * threads) + : threads); + + int iterations = + (sequence_length < subblock_max_workload ? (seq_length4 + threads - 1) / threads + : MAX_THREAD_ITERATIONS); + + if (sequence_length <= 8) + attn_softmax<2, (threads / 2), 2> + <<>>(vals, attn_mask, heads, seq_length4, iterations); + else if (sequence_length <= 16) + attn_softmax<4, (threads / 4), 4> + <<>>(vals, attn_mask, heads, seq_length4, iterations); + else if (sequence_length <= 32) + attn_softmax<8, (threads / 8), 8> + <<>>(vals, attn_mask, heads, seq_length4, iterations); + else if (sequence_length <= 64) + attn_softmax<16, (threads / 16), 16> + <<>>(vals, attn_mask, heads, seq_length4, iterations); + else if (sequence_length <= 128) + attn_softmax<32, (threads / 32), 32> + <<>>(vals, attn_mask, heads, seq_length4, iterations); + else if (sequence_length <= 256) + attn_softmax<32, (threads / 64), 64> + <<>>(vals, attn_mask, heads, seq_length4, iterations); + else { + const int threads = 256; + dim3 grid_dim = get_attn_softmax_grid(batch_size, heads, sequence_length, threads); + + int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads; + + dim3 block_dim(seq_length4 > threads ? ((sequence_length + subblock_max_workload - 1) / + subblock_max_workload * threads) + : threads); + iterations = + (sequence_length < subblock_max_workload ? (seq_length4 + threads - 1) / threads + : MAX_THREAD_ITERATIONS); + if (sequence_length <= 512) + attn_softmax<32, (threads / 128), 128><<>>( + vals, attn_mask, heads, seq_length4, iterations); + else if (sequence_length < (MAX_THREADS * MAX_THREAD_ITERATIONS * 4)) + attn_softmax<32, 1, 128><<>>( + vals, attn_mask, heads, seq_length4, iterations); + else + throw std::runtime_error( + "Unsupport Seq_Length! Check the restriction of the max_threads and " + "max_thread_iterations!"); + } +} + +template +__global__ void softmax_backward_kernel(T* out_grad, const T* soft_inp, int seq_length) +{ + __shared__ float partialSum[MAX_WARP_NUM]; + + int warp_num = blockDim.x >> WARP_SIZE_BITS; // warp-count = num_threads / WARP_SIZE (32) + + int iteration_stride = blockDim.x; + int block_width = blockStride * seq_length; + + int iterations = (seq_length < (MAX_THREAD_ITERATIONS * iteration_stride) + ? (seq_length + iteration_stride - 1) / iteration_stride + : MAX_THREAD_ITERATIONS); + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + int row = blockIdx.x; + int id = threadIdx.x; + + int wid = id >> WARP_SIZE_BITS; + int lane = id & 0x1f; + + T val_reg[MAX_THREAD_ITERATIONS]; + T soft_reg[MAX_THREAD_ITERATIONS]; + float grad_reg = 0.0f; + +#pragma unroll + for (int i = 0; i < iterations; i++) { + int data_id = i * iteration_stride + id; + if (data_id < block_width) { + val_reg[i] = out_grad[row * block_width + data_id]; + soft_reg[i] = soft_inp[row * block_width + data_id]; + + grad_reg += ((float)val_reg[i] * + (float)soft_reg[i]); // if done in half, the multiplication, we may lose + // 2% of accuracy in computation!! + } + } + for (int i = 1; i < tbSize; i *= 2) grad_reg += g.shfl_xor(grad_reg, i); + + if (seq_length > tbSize) { + if (lane == 0) partialSum[wid] = grad_reg; + b.sync(); + + if (lane < warp_num) grad_reg = partialSum[lane]; + + int iters = warp_num; + if (seq_length < iteration_stride) iters = warp_num / (iteration_stride / seq_length); + + for (int i = 1; i < iters; i *= 2) grad_reg += g.shfl_xor(grad_reg, i); + + grad_reg = g.shfl(grad_reg, id / tbSize); + } + + for (int i = 0; i < iterations; i++) { + int data_id = i * iteration_stride + id; + if (data_id < block_width) { + float temp = (float)soft_reg[i] * ((float)val_reg[i] - grad_reg); + out_grad[row * block_width + data_id] = (T)temp; + } + } +} + +template +__global__ void softmax_backward_kernel_v2(T* grad /* input & output*/, + const T* output, + int softmax_length) +{ + int batch_idx = blockIdx.x * blockDim.y + threadIdx.y; + int offset = batch_idx * softmax_length + threadIdx.x; + + grad += offset; + output += offset; + + T grad_reg[ITERATIONS]; + T output_reg[ITERATIONS]; + float sum = 0.0; + +#pragma unroll + for (int i = 0; i < ITERATIONS; ++i) { + int curr_idx = threadIdx.x + i * WARP_SIZE; + if (curr_idx < softmax_length) { + grad_reg[i] = grad[i * WARP_SIZE]; + output_reg[i] = output[i * WARP_SIZE]; + sum += (float)grad_reg[i] * (float)output_reg[i]; + } + } + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_xor(sum, i); + +#pragma unroll + for (int i = 0; i < ITERATIONS; ++i) { + int curr_idx = threadIdx.x + i * WARP_SIZE; + if (curr_idx < softmax_length) + grad[i * WARP_SIZE] = (float)output_reg[i] * ((float)grad_reg[i] - sum); + } +} + +__global__ void softmax_backward_kernel_arbitrary_length(__half* grad /* input & output*/, + const __half* output, + int softmax_length) +{ + int batch_idx = blockIdx.x * blockDim.y + threadIdx.y; + int offset = batch_idx * softmax_length + threadIdx.x; + + const float4* output_cast = reinterpret_cast(output); + float4* grad_cast = reinterpret_cast(grad); + + grad_cast += offset; + output_cast += offset; + + float sum = 0.0; + int curr_idx = threadIdx.x; + while (curr_idx < softmax_length) { + float4 out_reg = output_cast[curr_idx]; + float4 grad_reg = grad_cast[curr_idx]; + __half2* out_h = reinterpret_cast<__half2*>(&out_reg); + __half2* grad_h = reinterpret_cast<__half2*>(&grad_reg); +#pragma unroll + for (int m = 0; m < 4; m++) grad_h[m] *= out_h[m]; + sum += ((float)grad_h[0].x + (float)grad_h[0].y + (float)grad_h[1].x + (float)grad_h[1].y) + + ((float)grad_h[2].x + (float)grad_h[2].y + (float)grad_h[3].x + (float)grad_h[3].y); + curr_idx += WARP_SIZE; + } + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + +#pragma unroll + for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_xor(sum, i); + + curr_idx = threadIdx.x; + while (curr_idx < softmax_length) { + float4 out_reg = output_cast[curr_idx]; + float4 grad_reg = grad_cast[curr_idx]; + __half* grad_h = reinterpret_cast<__half*>(&grad_reg); + __half* out_h = reinterpret_cast<__half*>(&out_reg); + +#pragma unroll + for (int m = 0; m < 8; m++) grad_h[m] = (float)out_h[m] * ((float)grad_h[m] - sum); + + grad_cast[curr_idx] = grad_reg; + curr_idx += WARP_SIZE; + } +} + +__global__ void softmax_backward_kernel_arbitrary_length(float* grad /* input & output*/, + const float* output, + int softmax_length) +{ + int batch_idx = blockIdx.x * blockDim.y + threadIdx.y; + int offset = batch_idx * softmax_length + threadIdx.x; + + const float4* output_cast = reinterpret_cast(output); + float4* grad_cast = reinterpret_cast(grad); + + grad_cast += offset; + output_cast += offset; + + float sum = 0.0; + int curr_idx = threadIdx.x; + while (curr_idx < softmax_length) { + float4 out_reg = output_cast[curr_idx]; + float4 grad_reg = grad_cast[curr_idx]; + + grad_reg.x *= out_reg.x; + grad_reg.y *= out_reg.y; + grad_reg.z *= out_reg.z; + grad_reg.w *= out_reg.w; + sum += (grad_reg.x + grad_reg.y + grad_reg.z + grad_reg.w); + + curr_idx += WARP_SIZE; + } + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + +#pragma unroll + for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_xor(sum, i); + + curr_idx = threadIdx.x; + while (curr_idx < softmax_length) { + float4 out_reg = output_cast[curr_idx]; + float4 grad_reg = grad_cast[curr_idx]; + grad_reg.x = out_reg.x * (grad_reg.x - sum); + grad_reg.y = out_reg.y * (grad_reg.y - sum); + grad_reg.z = out_reg.z * (grad_reg.z - sum); + grad_reg.w = out_reg.w * (grad_reg.w - sum); + + grad_cast[curr_idx] = grad_reg; + curr_idx += WARP_SIZE; + } +} + +template +void launch_attn_softmax_backward_v2(T* out_grad, + const T* soft_inp, + int batch_size, + int heads, + int seq_length, + cudaStream_t stream) +{ + const int warps_per_block = 4; + dim3 grid_dim(batch_size * heads * seq_length / warps_per_block); + dim3 block_dim(WARP_SIZE, warps_per_block); + + if (seq_length <= 32) + softmax_backward_kernel_v2 + <<>>(out_grad, soft_inp, seq_length); + else if (seq_length <= 64) + softmax_backward_kernel_v2 + <<>>(out_grad, soft_inp, seq_length); + else if (seq_length <= 128) + softmax_backward_kernel_v2 + <<>>(out_grad, soft_inp, seq_length); + else if (seq_length <= 256) + softmax_backward_kernel_v2 + <<>>(out_grad, soft_inp, seq_length); + else if (seq_length <= 384) + softmax_backward_kernel_v2 + <<>>(out_grad, soft_inp, seq_length); + else if (seq_length <= 512) + softmax_backward_kernel_v2 + <<>>(out_grad, soft_inp, seq_length); + else if (seq_length <= 768) + softmax_backward_kernel_v2 + <<>>(out_grad, soft_inp, seq_length); + else if (seq_length <= 1024) + softmax_backward_kernel_v2 + <<>>(out_grad, soft_inp, seq_length); + else if (seq_length <= 2048) + softmax_backward_kernel_v2 + <<>>(out_grad, soft_inp, seq_length); + else if (seq_length <= 4096) + softmax_backward_kernel_v2 + <<>>(out_grad, soft_inp, seq_length); + else if (seq_length <= 8192) + softmax_backward_kernel_v2 + <<>>(out_grad, soft_inp, seq_length); + else + softmax_backward_kernel_arbitrary_length<<>>( + out_grad, soft_inp, seq_length / (4 << ((sizeof(T) & 2) >> 1))); +} + +template void launch_attn_softmax_backward_v2<__half>(__half* out_grad, + const __half* soft_inp, + int batch_size, + int heads, + int seq_length, + cudaStream_t stream); +template void launch_attn_softmax_backward_v2(float* out_grad, + const float* soft_inp, + int batch_size, + int heads, + int seq_length, + cudaStream_t stream); diff --git a/toolbox/DeepSpeed/v0.15.3/patches/csrc/transformer/transform_kernels.cu b/toolbox/DeepSpeed/v0.15.3/patches/csrc/transformer/transform_kernels.cu new file mode 100644 index 0000000000000000000000000000000000000000..1dee0b6fa7474676fab8220d811fc4bd34b52d75 --- /dev/null +++ b/toolbox/DeepSpeed/v0.15.3/patches/csrc/transformer/transform_kernels.cu @@ -0,0 +1,597 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. */ +/* All Rights Reserved. */ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "custom_cuda_layers.h" + +#define rows_trans 16 +#define cols_trans 16 + +template +__global__ void Transpose_Kernel(const T* inp, T* out, int row_width, int col_width) +{ + __shared__ T data_block[rows_trans * (cols_trans + 1)]; + + int r = threadIdx.x / cols_trans; + int c = threadIdx.x % cols_trans; + + int m = row_width / cols_trans; + + int i = blockIdx.x / m * rows_trans + r; + int j = blockIdx.x % m * cols_trans + c; + + int row_stride = rows_trans / ((rows_trans * cols_trans + THREADS - 1) / THREADS); + + for (int k = 0; k < rows_trans; k += row_stride) + data_block[(k + r) * cols_trans + c] = inp[(i + k) * row_width + j]; + + __syncthreads(); + + i = blockIdx.x % m * rows_trans + r; + j = blockIdx.x / m * cols_trans + c; + + for (int k = 0; k < rows_trans; k += row_stride) + out[(i + k) * col_width + j] = data_block[c * cols_trans + r + k]; +} + +template <> +void Transpose<__half>(const __half* inp_mat, + __half* out_mat, + int rows, + int cols, + cudaStream_t stream) +{ + int threads = THREADS; + + Transpose_Kernel<__half><<<(rows * cols + threads - 1) / threads, threads, 0, stream>>>( + inp_mat, out_mat, cols, rows); +} + +template <> +void Transpose(const float* inp_mat, float* out_mat, int rows, int cols, cudaStream_t stream) +{ + int threads = THREADS; + + Transpose_Kernel<<<(rows * cols + threads - 1) / threads, threads, 0, stream>>>( + inp_mat, out_mat, cols, rows); +} + +template +__global__ void transform_0213(T* output, + const T* vals, + int hidden_dim, + int seq_length, + int heads, + int head_ext); + +template <> +__global__ void transform_0213(float* output, + const float* vals, + int hidden_dim, + int seq_length, + int heads, + int head_ext) +{ + int d0_stride = hidden_dim * seq_length; + int d1_stride = hidden_dim; + int d2_stride = hidden_dim / heads; + + int d0_out_stride = d0_stride; + int d1_out_stride = d2_stride; + int d2_out_stride = d2_stride * seq_length; + + int d0 = blockIdx.x; // Batch + int d1 = blockIdx.y / head_ext; // Sequence ID (0-127) + int d2 = threadIdx.y + (blockIdx.y % head_ext) * (heads / head_ext); // Head (0-11) + int d3 = threadIdx.x; // Values (groups of 4) + + const float4* vals_vec = reinterpret_cast(vals); + float4* output_vec = reinterpret_cast(output); + + float4 inputs = vals_vec[d0 * d0_stride + d1 * d1_stride + d2 * d2_stride + d3]; + output_vec[d0 * d0_out_stride + d1 * d1_out_stride + d2 * d2_out_stride + d3] = inputs; +} + +template <> +__global__ void transform_0213<__half>(__half* output, + const __half* vals, + int hidden_dim, + int seq_length, + int heads, + int head_ext) +{ +#ifdef HALF_PRECISION_AVAILABLE + + int d0_stride = hidden_dim * seq_length; + int d1_stride = hidden_dim; + int d2_stride = hidden_dim / heads; + + int d0_out_stride = d0_stride; + int d1_out_stride = d2_stride; + int d2_out_stride = d2_stride * seq_length; + + int d0 = blockIdx.x; // Batch + int d1 = blockIdx.y / head_ext; // Sequence ID (0-127) + int d2 = threadIdx.y + (blockIdx.y % head_ext) * (heads / head_ext); // Head (0-11) + int d3 = threadIdx.x; // Values (groups of 4) + + float4 vals_arr[1]; + + const float4* vals_vec = reinterpret_cast(vals); + float4* output_vec = reinterpret_cast(output); + + vals_arr[0] = vals_vec[d0 * d0_stride + d1 * d1_stride + d2 * d2_stride + d3]; + output_vec[d0 * d0_out_stride + d1 * d1_out_stride + d2 * d2_out_stride + d3] = vals_arr[0]; +#endif +} + +template <> +void launch_transform_0213(float* output, + const float* vals, + int batch_size, + int seq_length, + int hidden_dim, + int heads, + cudaStream_t stream) +{ + hidden_dim >>= 2; + int head_ext = (hidden_dim - 1) / MAX_THREADS + 1; + dim3 block_dim(hidden_dim / heads, (heads / head_ext)); + dim3 grid_dim(batch_size, (seq_length * head_ext)); + + transform_0213 + <<>>(output, vals, hidden_dim, seq_length, heads, head_ext); +} + +template <> +void launch_transform_0213<__half>(__half* output, + const __half* vals, + int batch_size, + int seq_length, + int hidden_dim, + int heads, + cudaStream_t stream) +{ + hidden_dim >>= 3; + int head_ext = (hidden_dim - 1) / MAX_THREADS + 1; + dim3 block_dim(hidden_dim / heads, (heads / head_ext)); + dim3 grid_dim(batch_size, (seq_length * head_ext)); + transform_0213<__half> + <<>>(output, vals, hidden_dim, seq_length, heads, head_ext); +} + +// Bias add +template +__global__ void bias_add_transform_0213(T* output, + const T* vals, + const T* bias, + int hidden_dim, + int seq_length, + int heads, + int head_ext); + +template <> +__global__ void bias_add_transform_0213(float* output, + const float* vals, + const float* bias, + int hidden_dim, + int seq_length, + int heads, + int head_ext) +{ + int d0_stride = hidden_dim * seq_length; + int d1_stride = hidden_dim; + int d2_stride = hidden_dim / heads; + + int d0_out_stride = d0_stride; + int d1_out_stride = d2_stride; + int d2_out_stride = d2_stride * seq_length; + + int d0 = blockIdx.x; // Batch + int d1 = blockIdx.y; // Sequence ID (0-127) + int cnt = blockIdx.z / head_ext; // Hidden count + int d2 = threadIdx.y + (blockIdx.z % head_ext) * (heads / head_ext); // Head (0-11) + int d3 = threadIdx.x; // Values (groups of 4) + + const float4* vals_vec = reinterpret_cast(vals); + const float4* bias_vec = reinterpret_cast(bias); + float4* output_vec = reinterpret_cast(output); + + float4 inputs = vals_vec[d0 * d0_stride * (gridDim.z / head_ext) + cnt * d1_stride + + d1 * d1_stride * (gridDim.z / head_ext) + d2 * d2_stride + d3]; + float4 biases = bias_vec[cnt * d1_stride + d2 * d2_stride + d3]; + + float4 outputs; + outputs.x = inputs.x + biases.x; + outputs.y = inputs.y + biases.y; + outputs.z = inputs.z + biases.z; + outputs.w = inputs.w + biases.w; + + output_vec[cnt * d0_out_stride * gridDim.x + d0 * d0_out_stride + d1 * d1_out_stride + + d2 * d2_out_stride + d3] = outputs; +} + +#define ATTN_H 3 +#define MAX_SEQ_LINE 10 + +template <> +__global__ void bias_add_transform_0213<__half>(__half* output, + const __half* vals, + const __half* bias, + int hidden_dim, + int seq_length, + int heads, + int head_ext) +{ +#ifdef HALF_PRECISION_AVAILABLE + + int d0_stride = hidden_dim * seq_length; + int d1_stride = hidden_dim; + int d2_stride = hidden_dim / heads; + + int d2_out_stride = d2_stride * seq_length; + + int d0 = blockIdx.x; // Batch + int d1 = blockIdx.y; // Sequence ID (0-127) + int cnt = blockIdx.z / head_ext; // Hidden count + int d2 = threadIdx.y + (blockIdx.z % head_ext) * (heads / head_ext); // Head (0-11) + int d3 = threadIdx.x; // Values (groups of 4) + + float4 vals_arr; + float4 bias_arr; + float4 output_arr; + __half2* vals_half = reinterpret_cast<__half2*>(&vals_arr); + __half2* bias_half = reinterpret_cast<__half2*>(&bias_arr); + __half2* output_half = reinterpret_cast<__half2*>(&output_arr); + + const float4* vals_vec = reinterpret_cast(vals); + const float4* bias_vec = reinterpret_cast(bias); + float4* output_vec = reinterpret_cast(output); + + vals_vec += (d0 * d0_stride * (gridDim.z / head_ext)); + vals_vec += (d1 * d1_stride * (gridDim.z / head_ext)); + vals_vec += (cnt * d1_stride); + vals_vec += (d2 * d2_stride); + + bias_vec += (cnt * d1_stride); + bias_vec += (d2 * d2_stride); + + output_vec += (cnt * d0_stride * gridDim.x); + output_vec += (d1 * d2_stride); + output_vec += (d0 * d0_stride); + output_vec += (d2 * d2_out_stride); + + bias_arr = bias_vec[d3]; + vals_arr = vals_vec[d3]; + +#if defined(__ACC_HALF__) + output_half[0] = vals_half[0] + bias_half[0]; + output_half[1] = vals_half[1] + bias_half[1]; + output_half[2] = vals_half[2] + bias_half[2]; + output_half[3] = vals_half[3] + bias_half[3]; +#else + float2 bias_arr_f[4]; + float2 vals_arr_f[4]; +#pragma unroll + for (int l = 0; l < 4; l++) { + bias_arr_f[l] = __half22float2(bias_half[l]); + vals_arr_f[l] = __half22float2(vals_half[l]); + vals_arr_f[l].x += bias_arr_f[l].x; + vals_arr_f[l].y += bias_arr_f[l].y; + output_half[l] = __float22half2_rn(vals_arr_f[l]); + } +#endif + output_vec[d3] = output_arr; + +#endif +} + +__global__ void bias_add_transform_0213_v2(__half* output, + const __half* vals, + const __half* bias, + int hidden_dim, + int seq_length, + int heads) +{ +#ifdef HALF_PRECISION_AVAILABLE + __shared__ float4 in_data[3072]; + + int d0_stride = hidden_dim * seq_length; + int d1_stride = hidden_dim; + int d2_stride = hidden_dim / heads; + int iteration_stride = d1_stride * blockDim.z; // Hidden * 3 / 8 + int batch_stride = d0_stride * blockDim.z; // Hidden * S * 3 / 8 + + int d0_out_stride = d0_stride; + int d1_out_stride = d2_stride; + int d2_out_stride = d2_stride * seq_length; + + int d0 = blockIdx.x; // Batch + int d1 = blockIdx.y; // Sequence ID (0-127) + int cnt = threadIdx.z; // blockIdx.z; // Hidden count + int d2 = threadIdx.y; // Head (0-11) + int d3 = threadIdx.x; // Values (groups of 4) + + float4 vals_arr[1]; + float4 bias_arr[1]; + float4 output_arr[1]; + __half2* vals_half = reinterpret_cast<__half2*>(vals_arr); + __half2* bias_half = reinterpret_cast<__half2*>(bias_arr); + __half2* output_half = reinterpret_cast<__half2*>(output_arr); + + const float4* vals_vec = reinterpret_cast(vals); + const float4* bias_vec = reinterpret_cast(bias); + float4* output_vec = reinterpret_cast(output); + + int iter_index = cnt * d1_stride + d2 * d2_stride + d3; + int input_offset = d0 * batch_stride + d1 * (iteration_stride << 1); + bias_arr[0] = bias_vec[iter_index]; + +#pragma unroll + for (int iter = 0; iter < 2; iter++) { + int iter_id = iter * iteration_stride + iter_index; + vals_arr[0] = vals_vec[input_offset + iter_id]; + + output_half[0] = vals_half[0] + bias_half[0]; + output_half[1] = vals_half[1] + bias_half[1]; + output_half[2] = vals_half[2] + bias_half[2]; + output_half[3] = vals_half[3] + bias_half[3]; + + in_data[iter_id] = output_arr[0]; + } + __syncthreads(); + + iteration_stride = blockDim.z * (blockDim.y >> 1); + int matrix_stride = (d0_out_stride * gridDim.x); + int head_count = (d2 >> 1) + cnt * (blockDim.y >> 1); + + int out_index = d0 * d0_out_stride + d1 * (d1_out_stride << 1) + d3 + (d2 % 2) * d2_stride; + +#pragma unroll + for (int iter = 0; iter < 2; iter++) { + int iter_row = (iter * iteration_stride) + head_count; + int iter_offset = + (iter_row % blockDim.y) * d2_out_stride + (iter_row / blockDim.y) * matrix_stride; + output_vec[out_index + iter_offset] = + in_data[iter_row * d2_stride + d3 + (d2 % 2) * (d1_stride * blockDim.z)]; + } +#endif +} + +// [B S C*H] - > C * [B A S N] +template <> +void launch_bias_add_transform_0213(float* output, + const float* vals, + const float* bias, + int batch_size, + int seq_length, + int hidden_dim, + int heads, + cudaStream_t stream, + int trans_count) +{ + hidden_dim >>= 2; + int head_ext = (hidden_dim - 1) / MAX_THREADS + 1; + + dim3 block_dim(hidden_dim / heads, (heads / head_ext)); + dim3 grid_dim(batch_size, seq_length, (trans_count * head_ext)); + + bias_add_transform_0213<<>>( + output, vals, bias, hidden_dim, seq_length, heads, head_ext); +} + +template <> +void launch_bias_add_transform_0213<__half>(__half* output, + const __half* vals, + const __half* bias, + int batch_size, + int seq_length, + int hidden_dim, + int heads, + cudaStream_t stream, + int trans_count) +{ + hidden_dim >>= 3; + if (hidden_dim > 128 || hidden_dim < 16) { + int head_ext = (hidden_dim - 1) / MAX_THREADS + 1; + dim3 block_dim(hidden_dim / heads, (heads / head_ext)); + dim3 grid_dim(batch_size, seq_length, (trans_count * head_ext)); + bias_add_transform_0213<__half><<>>( + output, vals, bias, hidden_dim, seq_length, heads, head_ext); + } else { + dim3 block_dim(hidden_dim / heads, heads, trans_count); + dim3 grid_dim(batch_size, seq_length / 2); + bias_add_transform_0213_v2<<>>( + output, vals, bias, hidden_dim, seq_length, heads); + } +} + +template +__global__ void transform4d_0213(T* out, + const T* in, + int heads, + int seq_length, + int hidden_dim, + int head_ext); + +template <> +__global__ void transform4d_0213(float* out, + const float* in, + int heads, + int seq_length, + int hidden_dim, + int head_ext) +{ + int d0_stride = hidden_dim * seq_length; + int d1_stride = d0_stride / heads; + int d2_stride = hidden_dim / heads; + + int d0_out_stride = d0_stride; + int d1_out_stride = d2_stride; + int d2_out_stride = hidden_dim; + + int d0 = blockIdx.x; // Batch + int d1 = blockIdx.y / ((seq_length - 1) / blockDim.y + 1); // Head + int d2 = (threadIdx.y + blockDim.y * blockIdx.y) % seq_length; + int cnt = blockIdx.z; + int d3 = threadIdx.x; // Values (groups of 8) + + if (d2 < seq_length) { + const float4* in_vec = reinterpret_cast(in); + float4* out_vec = reinterpret_cast(out); + + float4 vals_vec = in_vec[cnt * d0_stride * gridDim.x + d0 * d0_stride + d1 * d1_stride + + d2 * d2_stride + d3]; + out_vec[d0 * d0_out_stride * gridDim.z + cnt * d2_out_stride + d1 * d1_out_stride + + d2 * d2_out_stride * gridDim.z + d3] = vals_vec; + } +} + +template <> +__global__ void transform4d_0213<__half>(__half* out, + const __half* in, + int heads, + int seq_length, + int hidden_dim, + int head_ext) +{ +#ifdef HALF_PRECISION_AVAILABLE + + int d0_stride = hidden_dim * (seq_length / head_ext); + int d1_stride = hidden_dim; + int d2_stride = hidden_dim / heads; + + int d0 = blockIdx.x; // Batch + int d1 = threadIdx.y + (blockIdx.z % head_ext) * (heads / head_ext); // Head + int d2 = blockIdx.z / head_ext; // Sequence + int cnt = blockIdx.y; // Hidden count + int d3 = threadIdx.x; // Values (groups of 8) + + const float4* in_vec = reinterpret_cast(in); + float4* out_vec = reinterpret_cast(out); + + in_vec += (cnt * d0_stride * gridDim.x); + in_vec += (d0 * d0_stride); + in_vec += (d2 * d2_stride); + in_vec += (d1 * d2_stride * seq_length); + + out_vec += (cnt * d1_stride); + out_vec += (d1 * d2_stride); + out_vec += (d0 * d0_stride * gridDim.y); + out_vec += (d2 * d1_stride * gridDim.y); + + out_vec[d3] = in_vec[d3]; + +#endif +} + +__global__ void transform4d_0213_v2(__half* out, + const __half* in, + int heads, + int seq_length, + int hidden_dim) +{ +#ifdef HALF_PRECISION_AVAILABLE + __shared__ float4 in_data[3072]; + + int d0_stride = hidden_dim * seq_length; + int d1_stride = hidden_dim; + int d2_stride = hidden_dim / heads; + + int d0 = blockIdx.x; // Batch + int d1 = threadIdx.y; // Head + int d2 = blockIdx.y; // Sequence + int cnt = threadIdx.z; // Hidden count + int d3 = threadIdx.x; // Values (groups of 8) + + const float4* in_vec = reinterpret_cast(in); + float4* out_vec = reinterpret_cast(out); + + int input_offset = d0 * d0_stride + d2 * (d2_stride << 1) + d3 + (d1 % 2) * d2_stride; + int head_count = (d1 >> 1) + cnt * (blockDim.y >> 1); + int iteration_stride = blockDim.z * (blockDim.y >> 1); + int matrix_stride = (d0_stride * gridDim.x); + +#pragma unroll + for (int iter = 0; iter < 2; iter++) { + int iter_row = iter * iteration_stride + head_count; + int iter_offset = (iter_row % blockDim.y) * d2_stride; + + in_data[d3 + iter_offset + (iter_row / blockDim.y + (d1 % 2) * blockDim.z) * d1_stride] = + in_vec[input_offset + iter_offset * seq_length + + (iter_row / blockDim.y) * matrix_stride]; + } + __syncthreads(); + + iteration_stride = d1_stride * blockDim.z; + int iter_index = cnt * d1_stride + d1 * d2_stride + d3; + int output_offset = d0 * d0_stride * blockDim.z + d2 * (iteration_stride << 1); + +#pragma unroll + for (int iter = 0; iter < 2; iter++) { + int iter_id = iter * iteration_stride + iter_index; + out_vec[output_offset + iter_id] = in_data[iter_id]; + } +#endif +} + +// 3 * [B A S N] - > [B S C*H] +template <> +void launch_transform4d_0213(float* out, + const float* in, + int batch_size, + int heads, + int seq_length, + int hidden_dim, + cudaStream_t stream, + int trans_count) +{ + hidden_dim >>= 2; + dim3 grid_dims(batch_size, heads * ((seq_length - 1) / 8 + 1), trans_count); + dim3 block_dims(hidden_dim / heads, 8); + transform4d_0213 + <<>>(out, in, heads, seq_length, hidden_dim, 1); +} + +template <> +void launch_transform4d_0213<__half>(__half* out, + const __half* in, + int batch_size, + int heads, + int seq_length, + int hidden_dim, + cudaStream_t stream, + int trans_count) +{ + hidden_dim >>= 3; + if (hidden_dim > 128 || hidden_dim < 16) { + int head_ext = (hidden_dim - 1) / MAX_THREADS + 1; + dim3 grid_dims(batch_size, trans_count, (seq_length * head_ext)); + dim3 block_dims(hidden_dim / heads, (heads / head_ext)); + transform4d_0213<__half><<>>( + out, in, heads, seq_length, hidden_dim, head_ext); + } else { + dim3 grid_dims(batch_size, seq_length / 2); + dim3 block_dims(hidden_dim / heads, heads, trans_count); + transform4d_0213_v2<<>>( + out, in, heads, seq_length, hidden_dim); + } +} diff --git a/toolbox/DeepSpeed/v0.15.3/patches/deepspeed/__init__.py b/toolbox/DeepSpeed/v0.15.3/patches/deepspeed/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e1d742fd9356f982c9ac1ee0ae819c0bfa41afb3 --- /dev/null +++ b/toolbox/DeepSpeed/v0.15.3/patches/deepspeed/__init__.py @@ -0,0 +1,372 @@ +#!/usr/bin/env python3 +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import sys +import types +import json +from typing import Optional, Union +import torch +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler +from packaging import version as pkg_version + +# Skip Triton import for AMD due to pytorch-triton-rocm module breaking device API in DeepSpeed +if not (hasattr(torch.version, 'hip') and torch.version.hip is not None): + try: + import triton # noqa: F401 # type: ignore + HAS_TRITON = True + except ImportError: + HAS_TRITON = False +else: + HAS_TRITON = False + +from . import ops +from . import module_inject + +from .accelerator import get_accelerator +from .constants import TORCH_DISTRIBUTED_DEFAULT_PORT +from .runtime.engine import DeepSpeedEngine, DeepSpeedOptimizerCallable, DeepSpeedSchedulerCallable +from .runtime.engine import ADAM_OPTIMIZER, LAMB_OPTIMIZER +from .runtime.hybrid_engine import DeepSpeedHybridEngine +from .runtime.pipe.engine import PipelineEngine +from .inference.engine import InferenceEngine +from .inference.config import DeepSpeedInferenceConfig +from .runtime.lr_schedules import add_tuning_arguments +from .runtime.config import DeepSpeedConfig, DeepSpeedConfigError +from .runtime.activation_checkpointing import checkpointing +from .ops.transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig +from .module_inject import replace_transformer_layer, revert_transformer_layer +# from .ops.rope import fused_apply_rotary_pos_emb,fused_apply_rotary_pos_emb_cached +# from .ops.swiglu import swiglu +# from .ops.layernorm import FusedLayerNorm,FusedRMSNorm,MixedFusedLayerNorm,MixedFusedRMSNorm,FusedRMSNormResidualFunction + +from .utils import log_dist, OnDevice, logger +from .comm.comm import init_distributed + +from .runtime import zero +from .runtime.compiler import is_compile_supported + +from .pipe import PipelineModule + +from .git_version_info import version, git_hash, git_branch + + +def _parse_version(version_str): + '''Parse a version string and extract the major, minor, and patch versions.''' + ver = pkg_version.parse(version_str) + return ver.major, ver.minor, ver.micro + + +# Export version information +__version__ = version +__version_major__, __version_minor__, __version_patch__ = _parse_version(__version__) +__git_hash__ = git_hash +__git_branch__ = git_branch + +# Set to torch's distributed package or deepspeed.comm based inside DeepSpeedEngine init +dist = None + + +def initialize(args=None, + model: torch.nn.Module = None, + optimizer: Optional[Union[Optimizer, DeepSpeedOptimizerCallable]] = None, + model_parameters: Optional[torch.nn.Module] = None, + training_data: Optional[torch.utils.data.Dataset] = None, + lr_scheduler: Optional[Union[_LRScheduler, DeepSpeedSchedulerCallable]] = None, + distributed_port: int = TORCH_DISTRIBUTED_DEFAULT_PORT, + mpu=None, + dist_init_required: Optional[bool] = None, + collate_fn=None, + config=None, + mesh_param=None, + config_params=None): + """Initialize the DeepSpeed Engine. + + Arguments: + args: an object containing local_rank and deepspeed_config fields. + This is optional if `config` is passed. + + model: Required: nn.module class before apply any wrappers + + optimizer: Optional: a user defined Optimizer or Callable that returns an Optimizer object. + This overrides any optimizer definition in the DeepSpeed json config. + + model_parameters: Optional: An iterable of torch.Tensors or dicts. + Specifies what Tensors should be optimized. + + training_data: Optional: Dataset of type torch.utils.data.Dataset + + lr_scheduler: Optional: Learning Rate Scheduler Object or a Callable that takes an Optimizer and returns a Scheduler object. + The scheduler object should define a get_lr(), step(), state_dict(), and load_state_dict() methods + + distributed_port: Optional: Master node (rank 0)'s free port that needs to be used for communication during distributed training + + mpu: Optional: A model parallelism unit object that implements + get_{model,data}_parallel_{rank,group,world_size}() + + dist_init_required: Optional: None will auto-initialize torch distributed if needed, + otherwise the user can force it to be initialized or not via boolean. + + collate_fn: Optional: Merges a list of samples to form a + mini-batch of Tensor(s). Used when using batched loading from a + map-style dataset. + + config: Optional: Instead of requiring args.deepspeed_config you can pass your deepspeed config + as an argument instead, as a path or a dictionary. + + config_params: Optional: Same as `config`, kept for backwards compatibility. + + Returns: + A tuple of ``engine``, ``optimizer``, ``training_dataloader``, ``lr_scheduler`` + + * ``engine``: DeepSpeed runtime engine which wraps the client model for distributed training. + + * ``optimizer``: Wrapped optimizer if a user defined ``optimizer`` is supplied, or if + optimizer is specified in json config else ``None``. + + * ``training_dataloader``: DeepSpeed dataloader if ``training_data`` was supplied, + otherwise ``None``. + + * ``lr_scheduler``: Wrapped lr scheduler if user ``lr_scheduler`` is passed, or + if ``lr_scheduler`` specified in JSON configuration. Otherwise ``None``. + """ + log_dist("DeepSpeed info: version={}, git-hash={}, git-branch={}".format(__version__, __git_hash__, + __git_branch__), + ranks=[0]) + + # Disable zero.Init context if it's currently enabled + zero.partition_parameters.shutdown_init_context() + + assert model is not None, "deepspeed.initialize requires a model" + + global dist + from deepspeed import comm as dist + dist_backend = get_accelerator().communication_backend_name() + dist.init_distributed(dist_backend=dist_backend, + distributed_port=distributed_port, + dist_init_required=dist_init_required) + + ##TODO: combine reuse mpu as mesh device and vice versa + # Set config using config_params for backwards compat + if config is None and config_params is not None: + config = config_params + + mesh_device = None + if mesh_param: + logger.info(f"mesh_param to Initialize mesh device: {mesh_param}") + mesh_device = dist.initialize_mesh_device(mesh_param, ("data_parallel", "sequence_parallel")) + #if config file has sequence parallelize and data parallelize, then use them to initialize mesh device + elif config is not None: + if "sequence_parallel_size" in config and "data_parallel_size" in config: + logger.info(f"config to Initialize mesh device: {config}") + mesh_device = dist.initialize_mesh_device((config["data_parallel_size"], config["sequence_parallel_size"]), \ + ("data_parallel", "sequence_parallel")) + + # Check for deepscale_config for backwards compat + if hasattr(args, "deepscale_config") and args.deepscale_config is not None: + logger.warning("************ --deepscale_config is deprecated, please use --deepspeed_config ************") + if hasattr(args, "deepspeed_config"): + assert (args.deepspeed_config is + None), "Not sure how to proceed, we were given both a deepscale_config and deepspeed_config" + args.deepspeed_config = args.deepscale_config + args.deepscale_config = None + + # Check that we have only one config passed + if hasattr(args, "deepspeed_config") and args.deepspeed_config is not None: + assert config is None, "Not sure how to proceed, we were given deepspeed configs in the deepspeed arguments and deepspeed.initialize() function call" + config = args.deepspeed_config + assert config is not None, "DeepSpeed requires --deepspeed_config to specify configuration file" + if not isinstance(model, PipelineModule): + config_class = DeepSpeedConfig(config, mpu, mesh_device=mesh_device) + if config_class.hybrid_engine.enabled: + engine = DeepSpeedHybridEngine(args=args, + model=model, + optimizer=optimizer, + model_parameters=model_parameters, + training_data=training_data, + lr_scheduler=lr_scheduler, + mpu=mpu, + dist_init_required=dist_init_required, + collate_fn=collate_fn, + config=config, + config_class=config_class) + else: + engine = DeepSpeedEngine(args=args, + model=model, + optimizer=optimizer, + model_parameters=model_parameters, + training_data=training_data, + lr_scheduler=lr_scheduler, + mpu=mpu, + dist_init_required=dist_init_required, + collate_fn=collate_fn, + config=config, + mesh_device=mesh_device, + config_class=config_class) + else: + assert mpu is None, "mpu must be None with pipeline parallelism" + mpu = model.mpu() + config_class = DeepSpeedConfig(config, mpu) + engine = PipelineEngine(args=args, + model=model, + optimizer=optimizer, + model_parameters=model_parameters, + training_data=training_data, + lr_scheduler=lr_scheduler, + mpu=mpu, + dist_init_required=dist_init_required, + collate_fn=collate_fn, + config=config, + config_class=config_class) + + # Restore zero.Init context if necessary + zero.partition_parameters.restore_init_context() + + return_items = [ + engine, + engine.optimizer, + engine.training_dataloader, + engine.lr_scheduler, + ] + return tuple(return_items) + + +def _add_core_arguments(parser): + r"""Helper (internal) function to update an argument parser with an argument group of the core DeepSpeed arguments. + The core set of DeepSpeed arguments include the following: + 1) --deepspeed: boolean flag to enable DeepSpeed + 2) --deepspeed_config : path of a json configuration file to configure DeepSpeed runtime. + + This is a helper function to the public add_config_arguments() + + Arguments: + parser: argument parser + Return: + parser: Updated Parser + """ + group = parser.add_argument_group('DeepSpeed', 'DeepSpeed configurations') + + group.add_argument('--deepspeed', + default=False, + action='store_true', + help='Enable DeepSpeed (helper flag for user code, no impact on DeepSpeed backend)') + + group.add_argument('--deepspeed_config', default=None, type=str, help='DeepSpeed json configuration file.') + + group.add_argument('--deepscale', + default=False, + action='store_true', + help='Deprecated enable DeepSpeed (helper flag for user code, no impact on DeepSpeed backend)') + + group.add_argument('--deepscale_config', + default=None, + type=str, + help='Deprecated DeepSpeed json configuration file.') + + return parser + + +def add_config_arguments(parser): + r"""Update the argument parser to enabling parsing of DeepSpeed command line arguments. + The set of DeepSpeed arguments include the following: + 1) --deepspeed: boolean flag to enable DeepSpeed + 2) --deepspeed_config : path of a json configuration file to configure DeepSpeed runtime. + + Arguments: + parser: argument parser + Return: + parser: Updated Parser + """ + parser = _add_core_arguments(parser) + + return parser + + +def default_inference_config(): + """ + Return a default DeepSpeed inference configuration dictionary. + """ + return DeepSpeedInferenceConfig().dict() + + +def init_inference(model, config=None, **kwargs): + """Initialize the DeepSpeed InferenceEngine. + + Description: all four cases are valid and supported in DS init_inference() API. + + # Case 1: user provides no config and no kwargs. Default config will be used. + + .. code-block:: python + + generator.model = deepspeed.init_inference(generator.model) + string = generator("DeepSpeed is") + print(string) + + # Case 2: user provides a config and no kwargs. User supplied config will be used. + + .. code-block:: python + + generator.model = deepspeed.init_inference(generator.model, config=config) + string = generator("DeepSpeed is") + print(string) + + # Case 3: user provides no config and uses keyword arguments (kwargs) only. + + .. code-block:: python + + generator.model = deepspeed.init_inference(generator.model, + tensor_parallel={"tp_size": world_size}, + dtype=torch.half, + replace_with_kernel_inject=True) + string = generator("DeepSpeed is") + print(string) + + # Case 4: user provides config and keyword arguments (kwargs). Both config and kwargs are merged and kwargs take precedence. + + .. code-block:: python + + generator.model = deepspeed.init_inference(generator.model, config={"dtype": torch.half}, replace_with_kernel_inject=True) + string = generator("DeepSpeed is") + print(string) + + Arguments: + model: Required: original nn.module object without any wrappers + + config: Optional: instead of arguments, you can pass in a DS inference config dict or path to JSON file + + Returns: + A deepspeed.InferenceEngine wrapped model. + """ + log_dist("DeepSpeed info: version={}, git-hash={}, git-branch={}".format(__version__, __git_hash__, + __git_branch__), + ranks=[0]) + + # Load config_dict from config first + if config is None: + config = {} + if isinstance(config, str): + with open(config, "r") as f: + config_dict = json.load(f) + elif isinstance(config, dict): + config_dict = config + else: + raise ValueError(f"'config' argument expected string or dictionary, got {type(config)}") + + # Update with values from kwargs, ensuring no conflicting overlap between config and kwargs + overlap_keys = set(config_dict.keys()).intersection(kwargs.keys()) + # If there is overlap, error out if values are different + for key in overlap_keys: + if config_dict[key] != kwargs[key]: + raise ValueError(f"Conflicting argument '{key}' in 'config':{config_dict[key]} and kwargs:{kwargs[key]}") + config_dict.update(kwargs) + + ds_inference_config = DeepSpeedInferenceConfig(**config_dict) + + engine = InferenceEngine(model, config=ds_inference_config) + + return engine diff --git a/toolbox/DeepSpeed/v0.15.3/patches/deepspeed/elasticity/elastic_agent.py b/toolbox/DeepSpeed/v0.15.3/patches/deepspeed/elasticity/elastic_agent.py new file mode 100644 index 0000000000000000000000000000000000000000..0839285c906e4b224f8192b6c16c2fbc5e6e7bc0 --- /dev/null +++ b/toolbox/DeepSpeed/v0.15.3/patches/deepspeed/elasticity/elastic_agent.py @@ -0,0 +1,197 @@ +#!/usr/bin/env python3 +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from torch.distributed.elastic.agent.server.local_elastic_agent import LocalElasticAgent +from typing import Any, Dict, Optional, Tuple +from datetime import datetime +import torch +if torch.__version__ >= "2.4": + from torch.distributed.elastic.utils.distributed import get_free_port as _get_socket_with_port +else: + from torch.distributed.elastic.agent.server.api import _get_socket_with_port +from torch.distributed.elastic.metrics import put_metric +from torch.distributed.elastic.agent.server.api import ( + RunResult, + WorkerGroup, + WorkerSpec, + WorkerState, +) +from torch.distributed import Store +import time +import os +from torch.distributed.elastic.multiprocessing import start_processes +from torch.distributed.elastic.utils import macros +import shutil +import copy +from contextlib import closing +import subprocess + +from torch.distributed.elastic.utils.logging import get_logger + +log = get_logger(__name__) + + +class DSElasticAgent(LocalElasticAgent): + + def __init__( + self, + spec: WorkerSpec, + env: Dict, + start_method="spawn", + exit_barrier_timeout: float = 300, + log_dir: Optional[str] = None, + ): + super().__init__(spec, start_method, exit_barrier_timeout, log_dir) + self.ds_env = env + + @staticmethod + def _set_master_addr_port(store: Store, + master_addr: Optional[str], + master_port: Optional[int], + local_addr: Optional[str] = None): + if master_port is None: + sock = _get_socket_with_port() + with closing(sock): + master_port = sock.getsockname()[1] + + if master_addr is None: + # master_addr = _get_fq_hostname() + import shlex + safe_cmd = shlex.split("hostname -I") + result = subprocess.check_output(safe_cmd) + master_addr = result.decode('utf-8').split()[0] + + store.set("MASTER_ADDR", master_addr.encode(encoding="UTF-8")) + store.set("MASTER_PORT", str(master_port).encode(encoding="UTF-8")) + + def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]: + spec = worker_group.spec + store = worker_group.store + assert store is not None + master_addr, master_port = super()._get_master_addr_port(store) + restart_count = spec.max_restarts - self._remaining_restarts + + use_agent_store = spec.rdzv_handler.get_backend() == "static" + + args: Dict[int, Tuple] = {} + envs: Dict[int, Dict[str, str]] = {} + for worker in worker_group.workers: + local_rank = worker.local_rank + + worker_env_ds = copy.deepcopy(self.ds_env) + worker_env_elastic = { + "LOCAL_RANK": str(local_rank), + "RANK": str(worker.global_rank), + "GROUP_RANK": str(worker_group.group_rank), + "ROLE_RANK": str(worker.role_rank), + "ROLE_NAME": spec.role, + "LOCAL_WORLD_SIZE": str(spec.local_world_size), + "WORLD_SIZE": str(worker.world_size), + "GROUP_WORLD_SIZE": str(worker_group.group_world_size), + "ROLE_WORLD_SIZE": str(worker.role_world_size), + "MASTER_ADDR": master_addr, + "MASTER_PORT": str(master_port), + "TORCHELASTIC_RESTART_COUNT": str(restart_count), + "TORCHELASTIC_MAX_RESTARTS": str(spec.max_restarts), + "TORCHELASTIC_RUN_ID": spec.rdzv_handler.get_run_id(), + "TORCHELASTIC_USE_AGENT_STORE": str(use_agent_store), + "NCCL_ASYNC_ERROR_HANDLING": os.getenv("NCCL_ASYNC_ERROR_HANDLING", str(1)), + } + worker_env_ds.update(worker_env_elastic) + if "OMP_NUM_THREADS" in os.environ: + worker_env_ds["OMP_NUM_THREADS"] = os.environ["OMP_NUM_THREADS"] + + envs[local_rank] = worker_env_ds + worker_args = list(spec.args) + worker_args = macros.substitute(worker_args, str(local_rank)) + args[local_rank] = tuple(worker_args) + + # scaling events do not count towards restarts (gets same attempt #) + # remove existing log dir if this restart is due to a scaling event + attempt_log_dir = os.path.join(self._log_dir, f"attempt_{restart_count}") + shutil.rmtree(attempt_log_dir, ignore_errors=True) + os.makedirs(attempt_log_dir) + + assert spec.entrypoint is not None + self._pcontext = start_processes( + name=spec.role, + entrypoint=spec.entrypoint, + args=args, + envs=envs, + log_dir=attempt_log_dir, + start_method=self._start_method, + redirects=spec.redirects, + tee=spec.tee, + ) + + return self._pcontext.pids() + + def _invoke_run(self, role: str = "default") -> RunResult: + # NOTE: currently only works for a single role + + spec = self._worker_group.spec + role = spec.role + + log.info(f"[{role}] starting workers for entrypoint: {spec.get_entrypoint_name()}") + + self._initialize_workers(self._worker_group) + monitor_interval = spec.monitor_interval + rdzv_handler = spec.rdzv_handler + + participants = rdzv_handler._state_holder.state.participants + + while True: + assert self._worker_group.state != WorkerState.INIT + time.sleep(monitor_interval) + run_result = self._monitor_workers(self._worker_group) + state = run_result.state + self._worker_group.state = state + + expire_time = datetime.utcnow() - (rdzv_handler._settings.keep_alive_interval * + rdzv_handler._settings.keep_alive_max_attempt) + _dead_nodes = [ + node for node, last_heartbeat in rdzv_handler._state_holder.state.last_heartbeats.items() + if last_heartbeat < expire_time + ] + + put_metric(f"workers.{role}.remaining_restarts", self._remaining_restarts) + put_metric(f"workers.{role}.{state.name.lower()}", 1) + + if state == WorkerState.SUCCEEDED: + log.info(f"[{role}] worker group successfully finished." + f" Waiting {self._exit_barrier_timeout} seconds for other agents to finish.") + self._exit_barrier() + return run_result + elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED + } or len(participants) > len(rdzv_handler._state_holder.state.participants): + if self._remaining_restarts > 0: + log.info(f"[{role}] Worker group {state.name}. " + f"{self._remaining_restarts}/{spec.max_restarts} attempts left;" + f" will restart worker group") + self._remaining_restarts -= 1 + # rdzv_handler._state_holder.state.restart = False + self._restart_workers(self._worker_group) + participants = rdzv_handler._state_holder.state.participants + + else: + self._stop_workers(self._worker_group) + self._worker_group.state = WorkerState.FAILED + self._exit_barrier() + return run_result + elif state == WorkerState.HEALTHY: + # membership changes do not count as retries + num_nodes_waiting = rdzv_handler.num_nodes_waiting() + group_rank = self._worker_group.group_rank + if num_nodes_waiting > 0: + log.info(f"[{role}] Detected {num_nodes_waiting} " + f"new nodes from group_rank={group_rank}; " + f"will restart worker group") + self._restart_workers(self._worker_group) + participants = rdzv_handler._state_holder.state.participants + else: + raise Exception(f"[{role}] Worker group in {state.name} state") diff --git a/toolbox/DeepSpeed/v0.15.3/patches/deepspeed/ops/__init__.py b/toolbox/DeepSpeed/v0.15.3/patches/deepspeed/ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c5859978a7c09364a3ddec7f59b64edb681e1943 --- /dev/null +++ b/toolbox/DeepSpeed/v0.15.3/patches/deepspeed/ops/__init__.py @@ -0,0 +1,25 @@ +#!/usr/bin/env python3 +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from . import adam +from . import adagrad +from . import lamb +from . import lion + +from . import swiglu +from . import layernorm +from . import rope +#from ..git_version_info_installed import installed_ops as __installed_ops__ +#if __installed_ops__['sparse_attn']: + +from . import sparse_attention +from . import transformer +from . import fp_quantizer +from .transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig + +from ..git_version_info import compatible_ops as __compatible_ops__ diff --git a/toolbox/DeepSpeed/v0.15.3/patches/deepspeed/ops/lamb/fused_lamb.py b/toolbox/DeepSpeed/v0.15.3/patches/deepspeed/ops/lamb/fused_lamb.py new file mode 100644 index 0000000000000000000000000000000000000000..b6a817fbb0ae835b590234326575edb1f60d580d --- /dev/null +++ b/toolbox/DeepSpeed/v0.15.3/patches/deepspeed/ops/lamb/fused_lamb.py @@ -0,0 +1,177 @@ +#!/usr/bin/env python3 +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +""" +Copyright NVIDIA/apex +This file is adapted from NVIDIA/apex/optimizer/fused_adam and implements the LAMB optimizer +""" +import types +import torch +from deepspeed.ops.op_builder import FusedLambBuilder + + +class FusedLamb(torch.optim.Optimizer): + """Implements the LAMB algorithm. Currently GPU-only. + + LAMB was proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes. + https://arxiv.org/abs/1904.00962 + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups. + lr (float, optional): learning rate. (default: 1e-3) + bias_correction (bool, optional): bias correction (default: True) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square. (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability. (default: 1e-8) + eps_inside_sqrt (boolean, optional): in the 'update parameters' step, + adds eps to the bias-corrected second moment estimate before + evaluating square root instead of adding it to the square root of + second moment estimate as in the original paper. (default: False) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + max_grad_norm (float, optional): value used to clip global grad norm + (default: 0.0) + max_coeff(float, optional): maximum value of the lamb coefficient (default: 10.0) + min_coeff(float, optional): minimum value of the lamb coefficient (default: 0.01) + amsgrad (boolean, optional): NOT SUPPORTED in FusedLamb! + """ + + def __init__(self, + params, + lr=1e-3, + bias_correction=True, + betas=(0.9, 0.999), + eps=1e-8, + eps_inside_sqrt=False, + weight_decay=0., + max_grad_norm=0., + max_coeff=10.0, + min_coeff=0.01, + amsgrad=False): + self.fused_lamb_cuda = FusedLambBuilder().load() + + if amsgrad: + raise RuntimeError('FusedLamb does not support the AMSGrad variant.') + defaults = dict(lr=lr, + bias_correction=bias_correction, + betas=betas, + eps=eps, + weight_decay=weight_decay, + max_grad_norm=max_grad_norm, + max_coeff=max_coeff, + min_coeff=min_coeff) + super(FusedLamb, self).__init__(params, defaults) + self.eps_mode = 0 if eps_inside_sqrt else 1 + self.lamb_coeffs = [] + + def step(self, closure=None, grads=None, output_params=None, scale=1., grad_norms=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + grads (list of tensors, optional): weight gradient to use for the + optimizer update. If gradients have type torch.half, parameters + are expected to be in type torch.float. (default: None) + output params (list of tensors, optional): A reduced precision copy + of the updated weights written out in addition to the regular + updated weights. Have to be of same type as gradients. (default: None) + scale (float, optional): factor to divide gradient tensor values + by before applying to weights. (default: 1) + """ + loss = None + if closure is not None: + loss = closure() + + if grads is None: + grads_group = [None] * len(self.param_groups) + # backward compatibility + # assuming a list/generator of parameter means single group + elif isinstance(grads, types.GeneratorType): + grads_group = [grads] + elif type(grads[0]) != list: + grads_group = [grads] + else: + grads_group = grads + + if output_params is None: + output_params_group = [None] * len(self.param_groups) + elif isinstance(output_params, types.GeneratorType): + output_params_group = [output_params] + elif type(output_params[0]) != list: + output_params_group = [output_params] + else: + output_params_group = output_params + + if grad_norms is None: + grad_norms = [None] * len(self.param_groups) + + #remove the previous coeffs + del self.lamb_coeffs[:] + + for group, grads_this_group, output_params_this_group, grad_norm_group in zip( + self.param_groups, grads_group, output_params_group, grad_norms): + if grads_this_group is None: + grads_this_group = [None] * len(group['params']) + if output_params_this_group is None: + output_params_this_group = [None] * len(group['params']) + + if grad_norm_group is None: + grad_norm_group = [None] * len(group['params']) + elif not isinstance(grad_norm_group, list): + grad_norm_group = [grad_norm_group] + + bias_correction = 1 if group['bias_correction'] else 0 + + for p, grad, output_param, grad_norm in zip(group['params'], grads_this_group, output_params_this_group, + grad_norm_group): + + # compute combined scale factor for this group + combined_scale = scale + if group['max_grad_norm'] > 0: + # norm is in fact norm*scale + clip = ((grad_norm / scale) + 1e-6) / group['max_grad_norm'] + if clip > 1: + combined_scale = clip * scale + + #note: p.grad should not ever be set for correct operation of mixed precision optimizer that sometimes sends None gradients + if p.grad is None and grad is None: + continue + if grad is None: + grad = p.grad.data + if grad.is_sparse: + raise RuntimeError('FusedLamb does not support sparse gradients') + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p.data) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p.data) + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + beta1, beta2 = group['betas'] + max_coeff = group['max_coeff'] + min_coeff = group['min_coeff'] + + state['step'] += 1 + + out_p = torch.tensor([], dtype=torch.float) if output_param is None else output_param + lamb_coeff = self.fused_lamb_cuda.lamb(p.data, out_p, exp_avg, exp_avg_sq, grad, group['lr'], beta1, + beta2, max_coeff, min_coeff, group['eps'], combined_scale, + state['step'], self.eps_mode, bias_correction, + group['weight_decay']) + self.lamb_coeffs.append(lamb_coeff) + return loss + + def get_lamb_coeffs(self): + lamb_coeffs = [lamb_coeff.item() for lamb_coeff in self.lamb_coeffs] + return lamb_coeffs diff --git a/toolbox/DeepSpeed/v0.15.3/patches/deepspeed/ops/layernorm/__init__.py b/toolbox/DeepSpeed/v0.15.3/patches/deepspeed/ops/layernorm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..12604d2761abb4df7cb1bb8265dcfba41e6becf9 --- /dev/null +++ b/toolbox/DeepSpeed/v0.15.3/patches/deepspeed/ops/layernorm/__init__.py @@ -0,0 +1,16 @@ +#!/usr/bin/env python3 +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .fused_layernorm import FusedLayerNorm,FusedRMSNorm,MixedFusedLayerNorm,MixedFusedRMSNorm,FusedRMSNormResidualFunction \ No newline at end of file diff --git a/toolbox/DeepSpeed/v0.15.3/patches/deepspeed/ops/layernorm/fused_layernorm.py b/toolbox/DeepSpeed/v0.15.3/patches/deepspeed/ops/layernorm/fused_layernorm.py new file mode 100644 index 0000000000000000000000000000000000000000..3b43a48e30be99e0c7ddc981930d3c33050a0bf3 --- /dev/null +++ b/toolbox/DeepSpeed/v0.15.3/patches/deepspeed/ops/layernorm/fused_layernorm.py @@ -0,0 +1,489 @@ +#!/usr/bin/env python3 +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import importlib +import numbers +from typing import Optional, Sequence +import torch +from torch.nn.parameter import Parameter +from torch.nn import init +from torch.nn import functional as F +from deepspeed.ops.op_builder import FusedLayernormBuilder + +# global fused_layer_norm_cuda +# fused_layer_norm_cuda = None + +def _get_autocast_dtypes() -> Sequence[torch.dtype]: + if torch.cuda.is_bf16_supported(): + return [torch.half, torch.bfloat16] + return [torch.half] + + +def _get_current_dtype(dtype: Optional[torch.dtype] = None) -> torch.dtype: + if not torch.is_autocast_enabled(): + return torch.float or dtype + else: + return torch.get_autocast_gpu_dtype() + + +def _cast_if_autocast_enabled(*args): + if not torch.is_autocast_enabled(): + return args + else: + return torch.cuda.amp.autocast_mode._cast(args, torch.get_autocast_gpu_dtype()) + +# Reference implementation from Huggingface +def manual_rms_norm(input, normalized_shape, weight, eps): + # layer norm should always be calculated in float32 + dims = tuple(i for i in range(-1, -len(normalized_shape)-1, -1)) + variance = input.to(torch.float32).pow(2).mean(dims, keepdim=True) + input = input * torch.rsqrt(variance + eps) + + if weight is None: + return input + + # convert into half-precision if necessary + if weight.dtype in [torch.float16, torch.bfloat16]: + input = input.to(weight.dtype) + + return weight * input + + +class FusedLayerNormAffineFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, input, weight, bias, normalized_shape, eps, memory_efficient=True): + ctx.normalized_shape = normalized_shape + ctx.eps = eps + ctx.memory_efficient = memory_efficient + input_ = input.contiguous() + weight_ = weight.contiguous() + bias_ = bias.contiguous() + output, mean, invvar = fused_layer_norm_cuda.forward_affine( + input_, ctx.normalized_shape, weight_, bias_, ctx.eps + ) + if memory_efficient: + ctx.save_for_backward(output, weight_, bias_, None, invvar) + else: + ctx.save_for_backward(input_, weight_, bias_, mean, invvar) + return output + + @staticmethod + def backward(ctx, grad_output): + input_, weight_, bias_, mean, invvar = ctx.saved_tensors + grad_input = grad_weight = grad_bias = None + grad_input, grad_weight, grad_bias = fused_layer_norm_cuda.backward_affine( + grad_output.contiguous(), mean, invvar, input_, ctx.normalized_shape, weight_, bias_, ctx.eps, ctx.memory_efficient + ) + return grad_input, grad_weight, grad_bias, None, None, None + + +class FusedRMSNormAffineFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, input, weight, normalized_shape, eps, memory_efficient=True): + ctx.normalized_shape = normalized_shape + ctx.eps = eps + ctx.memory_efficient = memory_efficient + input_ = input.contiguous() + weight_ = weight.contiguous() + output, invvar = fused_layer_norm_cuda.rms_forward_affine( + input_, ctx.normalized_shape, weight_, ctx.eps) + if memory_efficient: + ctx.save_for_backward(output, weight_, invvar) + else: + ctx.save_for_backward(input_, weight_, invvar) + return output + + @staticmethod + def backward(ctx, grad_output): + input_, weight_, invvar = ctx.saved_tensors + grad_input = grad_weight = None + grad_input, grad_weight = fused_layer_norm_cuda.rms_backward_affine( + grad_output.contiguous(), invvar, input_, ctx.normalized_shape, weight_, ctx.eps, ctx.memory_efficient + ) + return grad_input, grad_weight, None, None, None + + +class FusedLayerNormAffineMixedDtypesFunction(FusedLayerNormAffineFunction): + + @staticmethod + def forward(ctx, input, weight, bias, normalized_shape, eps, memory_efficient=True): + ctx.normalized_shape = normalized_shape + ctx.eps = eps + ctx.memory_efficient = memory_efficient + input_ = input.contiguous() + weight_ = weight.contiguous() + bias_ = bias.contiguous() + output, mean, invvar = fused_layer_norm_cuda.forward_affine_mixed_dtypes( + input_, ctx.normalized_shape, weight_, bias_, ctx.eps + ) + if memory_efficient: + ctx.save_for_backward(output, weight_, bias_, None, invvar) + else: + ctx.save_for_backward(input_, weight_, bias_, mean, invvar) + return output + + +class FusedRMSNormAffineMixedDtypesFunction(FusedRMSNormAffineFunction): + + @staticmethod + def forward(ctx, input, weight, normalized_shape, eps, memory_efficient=True): + ctx.normalized_shape = normalized_shape + ctx.eps = eps + ctx.memory_efficient = memory_efficient + input_ = input.contiguous() + weight_ = weight.contiguous() + output, invvar = fused_layer_norm_cuda.rms_forward_affine_mixed_dtypes( + input_, ctx.normalized_shape, weight_, ctx.eps + ) + if memory_efficient: + ctx.save_for_backward(output, weight_, invvar) + else: + ctx.save_for_backward(input_, weight_, invvar) + return output + + +class FusedLayerNormFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, input, normalized_shape, eps): + ctx.normalized_shape = normalized_shape + ctx.eps = eps + input_ = input.contiguous() + output, mean, invvar = fused_layer_norm_cuda.forward(input_, ctx.normalized_shape, ctx.eps) + ctx.save_for_backward(input_, mean, invvar) + return output + + @staticmethod + def backward(ctx, grad_output): + input_, mean, invvar = ctx.saved_tensors + grad_input = None + grad_input = fused_layer_norm_cuda.backward( + grad_output.contiguous(), mean, invvar, input_, ctx.normalized_shape, ctx.eps, False + ) + return grad_input, None, None + + +class FusedRMSNormFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, input, normalized_shape, eps): + ctx.normalized_shape = normalized_shape + ctx.eps = eps + input_ = input.contiguous() + output, invvar = fused_layer_norm_cuda.rms_forward(input_, ctx.normalized_shape, ctx.eps) + ctx.save_for_backward(input_, invvar) + return output + + @staticmethod + def backward(ctx, grad_output): + input_, invvar = ctx.saved_tensors + grad_input = None + grad_input = fused_layer_norm_cuda.rms_backward( + grad_output.contiguous(), invvar, input_, ctx.normalized_shape, ctx.eps, False + ) + return grad_input, None, None + +class FusedRMSNormResidualFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, input, weight, residual, normalized_shape, eps, memory_efficient=True): + ctx.normalized_shape = normalized_shape + ctx.eps = eps + ctx.memory_efficient = memory_efficient + input_ = input.contiguous() + weight_ = weight.contiguous() + residual_ = residual.contiguous() + output, invvar, sum_res = fused_layer_norm_cuda.rms_pre_norm_residual_forward(input_, residual_, ctx.normalized_shape, weight_, ctx.eps) + if memory_efficient: + ctx.save_for_backward(weight_, output, invvar) + else: + ctx.save_for_backward(weight_, sum_res, invvar) + return output, sum_res + + @staticmethod + def backward(ctx, grad_output, grad_res): + weight_, input_or_output, invvar = ctx.saved_tensors + grad_input = grad_weight = None + grad_input, grad_weight = fused_layer_norm_cuda.rms_pre_norm_residual_backward( + grad_output.contiguous(), grad_res.contiguous(), invvar, input_or_output, ctx.normalized_shape, weight_, ctx.eps, ctx.memory_efficient + ) + return grad_input, grad_weight, grad_input, None, None, None + + +def fused_layer_norm_affine(input, weight, bias, normalized_shape, eps=1e-6): + args = _cast_if_autocast_enabled(input, weight, bias, normalized_shape, eps) + with torch.cuda.amp.autocast(enabled=False): + return FusedLayerNormAffineFunction.apply(*args) + + +def fused_layer_norm(input, normalized_shape, eps=1e-6): + args = _cast_if_autocast_enabled(input, normalized_shape, eps) + with torch.cuda.amp.autocast(enabled=False): + return FusedLayerNormFunction.apply(*args) + + +def mixed_dtype_fused_layer_norm_affine(input, weight, bias, normalized_shape, eps=1e-6): + args = _cast_if_autocast_enabled(input, weight, bias, normalized_shape, eps) + with torch.cuda.amp.autocast(enabled=False): + return FusedLayerNormAffineMixedDtypesFunction.apply(*args) + + +def fused_rms_norm_affine(input, weight, normalized_shape, eps=1e-6): + args = _cast_if_autocast_enabled(input, weight, normalized_shape, eps) + with torch.cuda.amp.autocast(enabled=False): + return FusedRMSNormAffineFunction.apply(*args) + + +def fused_rms_norm(input, normalized_shape, eps=1e-6): + args = _cast_if_autocast_enabled(input, normalized_shape, eps) + with torch.cuda.amp.autocast(enabled=False): + return FusedRMSNormFunction.apply(*args) + + +def mixed_dtype_fused_rms_norm_affine(input, weight, normalized_shape, eps=1e-6): + args = _cast_if_autocast_enabled(input, weight, normalized_shape, eps) + with torch.cuda.amp.autocast(enabled=False): + return FusedRMSNormAffineMixedDtypesFunction.apply(*args) + + +class FusedLayerNorm(torch.nn.Module): + r"""Applies Layer Normalization over a mini-batch of inputs as described in + the paper `Layer Normalization`_ . + + Currently only runs on cuda() tensors. + + .. math:: + y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta + + The mean and standard-deviation are calculated separately over the last + certain number dimensions which have to be of the shape specified by + :attr:`normalized_shape`. + :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of + :attr:`normalized_shape` if :attr:`elementwise_affine` is ``True``. + + .. note:: + Unlike Batch Normalization and Instance Normalization, which applies + scalar scale and bias for each entire channel/plane with the + :attr:`affine` option, Layer Normalization applies per-element scale and + bias with :attr:`elementwise_affine`. + + This layer uses statistics computed from input data in both training and + evaluation modes. + + Args: + normalized_shape (int or list or torch.Size): input shape from an expected input + of size + + .. math:: + [* \times \text{normalized}\_\text{shape}[0] \times \text{normalized}\_\text{shape}[1] + \times \ldots \times \text{normalized}\_\text{shape}[-1]] + + If a single integer is used, it is treated as a singleton list, and this module will + normalize over the last dimension which is expected to be of that specific size. + eps: a value added to the denominator for numerical stability. Default: 1e-5 + elementwise_affine: a boolean value that when set to ``True``, this module + has learnable per-element affine parameters initialized to ones (for weights) + and zeros (for biases). Default: ``True``. + + Shape: + - Input: :math:`(N, *)` + - Output: :math:`(N, *)` (same shape as input) + + Examples:: + + >>> input = torch.randn(20, 5, 10, 10) + >>> # With Learnable Parameters + >>> m = apex.normalization.FusedLayerNorm(input.size()[1:]) + >>> # Without Learnable Parameters + >>> m = apex.normalization.FusedLayerNorm(input.size()[1:], elementwise_affine=False) + >>> # Normalize over last two dimensions + >>> m = apex.normalization.FusedLayerNorm([10, 10]) + >>> # Normalize over last dimension of size 10 + >>> m = apex.normalization.FusedLayerNorm(10) + >>> # Activating the module + >>> output = m(input) + + .. _`Layer Normalization`: https://arxiv.org/abs/1607.06450 + """ + + def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True): + super().__init__() + global fused_layer_norm_cuda + fused_layer_norm_cuda = FusedLayernormBuilder().load() + + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + self.normalized_shape = torch.Size(normalized_shape) + self.eps = eps + self.elementwise_affine = elementwise_affine + if self.elementwise_affine: + self.weight = Parameter(torch.empty(*normalized_shape)) + self.bias = Parameter(torch.empty(*normalized_shape)) + else: + self.register_parameter("weight", None) + self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self): + if self.elementwise_affine: + init.ones_(self.weight) + init.zeros_(self.bias) + + def forward(self, input): + if torch.jit.is_tracing() or torch.jit.is_scripting() or not input.is_cuda: + return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps) + if self.elementwise_affine: + return fused_layer_norm_affine(input, self.weight, self.bias, self.normalized_shape, self.eps) + else: + return fused_layer_norm(input, self.normalized_shape, self.eps) + + def extra_repr(self): + return "{normalized_shape}, eps={eps}, " "elementwise_affine={elementwise_affine}".format(**self.__dict__) + + +class FusedRMSNorm(torch.nn.Module): + r"""Applies RMS Normalization over a mini-batch of inputs + + Currently only runs on cuda() tensors. + + .. math:: + y = \frac{x}{\mathrm{RMS}[x]} * \gamma + + The root-mean-square is calculated separately over the last + certain number dimensions which have to be of the shape specified by + :attr:`normalized_shape`. + :math:`\gamma` is a learnable affine transform parameter of + :attr:`normalized_shape` if :attr:`elementwise_affine` is ``True``. + `epsilon` is added to the mean-square, then the root of the sum is taken. + + .. note:: + Unlike Batch Normalization and Instance Normalization, which applies + scalar scale and bias for each entire channel/plane with the + :attr:`affine` option, RMS Normalization applies per-element scale + with :attr:`elementwise_affine`. + + This layer uses statistics computed from input data in both training and + evaluation modes. + + Args: + normalized_shape (int or list or torch.Size): input shape from an expected input + of size + + .. math:: + [* \times \text{normalized}\_\text{shape}[0] \times \text{normalized}\_\text{shape}[1] + \times \ldots \times \text{normalized}\_\text{shape}[-1]] + + If a single integer is used, it is treated as a singleton list, and this module will + normalize over the last dimension which is expected to be of that specific size. + eps: a value added to the denominator for numerical stability. Default: 1e-5 + elementwise_affine: a boolean value that when set to ``True``, this module + has learnable per-element affine parameters initialized to ones (for weights) + and zeros (for biases). Default: ``True``. + + Shape: + - Input: :math:`(N, *)` + - Output: :math:`(N, *)` (same shape as input) + + Examples:: + + >>> input = torch.randn(20, 5, 10, 10) + >>> # With Learnable Parameters + >>> m = apex.normalization.FusedRMSNorm(input.size()[1:]) + >>> # Without Learnable Parameters + >>> m = apex.normalization.FusedRMSNorm(input.size()[1:], elementwise_affine=False) + >>> # Normalize over last two dimensions + >>> m = apex.normalization.FusedRMSNorm([10, 10]) + >>> # Normalize over last dimension of size 10 + >>> m = apex.normalization.FusedRMSNorm(10) + >>> # Activating the module + >>> output = m(input) + + .. _`Root Mean Square Layer Normalization`: https://arxiv.org/pdf/1910.07467.pdf + """ + + def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True): + super().__init__() + + fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda") + + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + self.normalized_shape = torch.Size(normalized_shape) + self.eps = eps + self.elementwise_affine = elementwise_affine + if self.elementwise_affine: + self.weight = Parameter(torch.empty(*normalized_shape)) + else: + self.register_parameter("weight", None) + self.reset_parameters() + + def reset_parameters(self): + if self.elementwise_affine: + init.ones_(self.weight) + + def forward(self, input): + if torch.jit.is_tracing() or torch.jit.is_scripting() or not input.is_cuda: + return manual_rms_norm(input, self.normalized_shape, self.weight, self.eps) + + if self.elementwise_affine: + return fused_rms_norm_affine(input, self.weight, self.normalized_shape, self.eps) + else: + return fused_rms_norm(input, self.normalized_shape, self.eps) + + def extra_repr(self): + return "{normalized_shape}, eps={eps}, " "elementwise_affine={elementwise_affine}".format(**self.__dict__) + + +# NOTE (mkozuki): Why "mixed"? +# MixedFusedLayerNorm differs from FusedLayerNorm in that this layer norm uses parameter's dtype +# as output tensor's dtype while FusedLayerNorm uses input tensor's dtype for output tensor's dtype. +# See: `layer_norm_affine` and `layer_norm_affine_mixed_dtypes` in "csrc/layer_norm_cuda.cpp" +class MixedFusedLayerNorm(FusedLayerNorm): + + def __init__(self, normalized_shape, eps=1e-5, **kwargs): + if "elementwise_affine" in kwargs: + import warnings + warnings.warn("MixedFusedLayerNorm does not support `elementwise_affine` argument") + elementwise_affine = kwargs.pop("elementwise_affine") + if not elementwise_affine: + raise RuntimeError("MixedFusedLayerNorm does not support `elementwise_affine = False`") + + super().__init__(normalized_shape=normalized_shape, eps=eps, elementwise_affine=True) + + def forward(self, input: torch.Tensor): + # NOTE (mkozuki): CPU path is here mainly for unittest sake. + if torch.jit.is_tracing() or torch.jit.is_scripting() or not input.is_cuda: + return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps) + return mixed_dtype_fused_layer_norm_affine(input, self.weight, self.bias, self.normalized_shape, self.eps) + + +# MixedFusedLayerNorm differs from FusedLayerNorm in that this layer norm uses parameter's dtype +# as output tensor's dtype while FusedLayerNorm uses input tensor's dtype for output tensor's dtype. +# See: `layer_norm_affine` and `layer_norm_affine_mixed_dtypes` in "csrc/layer_norm_cuda.cpp" +class MixedFusedRMSNorm(FusedRMSNorm): + + def __init__(self, normalized_shape, eps=1e-5, **kwargs): + if "elementwise_affine" in kwargs: + import warnings + warnings.warn("MixedFusedRMSNorm does not support `elementwise_affine` argument") + elementwise_affine = kwargs.pop("elementwise_affine") + if not elementwise_affine: + raise RuntimeError("MixedFusedRMSNorm does not support `elementwise_affine = False`") + + super().__init__(normalized_shape=normalized_shape, eps=eps, elementwise_affine=True) + + def forward(self, input: torch.Tensor): + # NOTE (mkozuki): CPU path is here mainly for unittest sake. + # TODO Manual RMS Norm Implementation Here + if torch.jit.is_tracing() or torch.jit.is_scripting() or not input.is_cuda: + return manual_rms_norm(input, self.normalized_shape, self.weight, self.eps) + return mixed_dtype_fused_rms_norm_affine(input, self.weight, self.normalized_shape, self.eps) \ No newline at end of file diff --git a/toolbox/DeepSpeed/v0.15.3/patches/deepspeed/ops/rope/__init__.py b/toolbox/DeepSpeed/v0.15.3/patches/deepspeed/ops/rope/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..36cf431fbc37421cc9af91b32ba8f54e902f50c6 --- /dev/null +++ b/toolbox/DeepSpeed/v0.15.3/patches/deepspeed/ops/rope/__init__.py @@ -0,0 +1,16 @@ +#!/usr/bin/env python3 +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .fused_rope import fused_apply_rotary_pos_emb,fused_apply_rotary_pos_emb_cached \ No newline at end of file diff --git a/toolbox/DeepSpeed/v0.15.3/patches/deepspeed/ops/rope/fused_rope.py b/toolbox/DeepSpeed/v0.15.3/patches/deepspeed/ops/rope/fused_rope.py new file mode 100644 index 0000000000000000000000000000000000000000..d2e0a5f18aea51731a7fd848e2f97496a71128d5 --- /dev/null +++ b/toolbox/DeepSpeed/v0.15.3/patches/deepspeed/ops/rope/fused_rope.py @@ -0,0 +1,107 @@ +#!/usr/bin/env python3 +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# coding=utf-8 +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Tuple, Union +import torch +from deepspeed.ops.op_builder import FusedRopeBuilder +class FusedRoPEFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + t: torch.Tensor, + cos_: torch.Tensor, + sin_: torch.Tensor, + transpose_output_memory: bool = False, + ) -> torch.Tensor: + fused_rope_cuda = FusedRopeBuilder().load() + output = fused_rope_cuda.forward( + t, cos_, sin_, transpose_output_memory + ) + ctx.save_for_backward(cos_, sin_) + ctx.transpose_output_memory = transpose_output_memory + + return output + + @staticmethod + def backward( + ctx, grad_output: torch.Tensor + ) -> Tuple[Union[torch.Tensor, None], ...]: + fused_rope_cuda = FusedRopeBuilder().load() + cos_, sin_ = ctx.saved_tensors + grad_input = fused_rope_cuda.backward( + grad_output, cos_, sin_, ctx.transpose_output_memory + ) + + return grad_input, None, None, None + + +def fused_apply_rotary_pos_emb( + t: torch.Tensor, + freqs: torch.Tensor, + transpose_output_memory: bool = False, +) -> torch.Tensor: + """Apply rotary positional embedding to input tensor T. + + Args: + t (Tensor): Input tensor T is of shape [seq_length, ... , dim] + freqs (Tensor): Rotary Positional embedding tensor freq is of shape [seq_length, ..., dim] + transpose_output_memory (bool): Default to False. Whether to transpose the 's' and 'b' + dimension of the output's underlying memory format. This is very helpful when you want to + get a contiguous tensor after calling `output.transpose(0, 1)`. + + Returns: + Tensor: The input tensor after applying RoPE + """ + cos_ = torch.cos(freqs).to(t.dtype) + sin_ = torch.sin(freqs).to(t.dtype) + return FusedRoPEFunc.apply(t, cos_, sin_, transpose_output_memory) + + +def fused_apply_rotary_pos_emb_cached( + t: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + transpose_output_memory: bool = False, +) -> torch.Tensor: + """Apply rotary positional embedding to input tensor T. + + Args: + t (Tensor): Input tensor T is of shape [seq_length, ... , dim] + cos (Tensor): Cached cosine of the rotary positional embedding tensor is of shape [seq_length, ..., dim] + sin (Tensor): Cached sine of the rotary positional embedding tensor is of shape [seq_length, ..., dim] + transpose_output_memory (bool): Default to False. Whether to transpose the 's' and 'b' + dimension of the output's underlying memory format. This is very helpful when you want to + get a contiguous tensor after calling `output.transpose(0, 1)`. + + Returns: + Tensor: The input tensor after applying RoPE + """ + cos_ = cos.to(t.dtype) + sin_ = sin.to(t.dtype) + return FusedRoPEFunc.apply(t, cos_, sin_, transpose_output_memory) \ No newline at end of file diff --git a/toolbox/DeepSpeed/v0.15.3/patches/deepspeed/ops/swiglu/__init__.py b/toolbox/DeepSpeed/v0.15.3/patches/deepspeed/ops/swiglu/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..14eeb60a03868c95276ea291521f2611ba804d0f --- /dev/null +++ b/toolbox/DeepSpeed/v0.15.3/patches/deepspeed/ops/swiglu/__init__.py @@ -0,0 +1,16 @@ +#!/usr/bin/env python3 +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .swiglu import swiglu \ No newline at end of file diff --git a/toolbox/DeepSpeed/v0.15.3/patches/deepspeed/ops/swiglu/swiglu.py b/toolbox/DeepSpeed/v0.15.3/patches/deepspeed/ops/swiglu/swiglu.py new file mode 100644 index 0000000000000000000000000000000000000000..203828f1b4e5e8dd398f0d42f0b1802a379de1d0 --- /dev/null +++ b/toolbox/DeepSpeed/v0.15.3/patches/deepspeed/ops/swiglu/swiglu.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +from deepspeed.ops.op_builder import SwigluBuilder +global swiglu_cuda +class SwiGLUFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return swiglu_fwd(x) + + @staticmethod + def backward(ctx, g): + x, = ctx.saved_tensors + return swiglu_bwd(x, g) + +swiglu = SwiGLUFunction.apply +def swiglu_fwd(input: torch.Tensor): + swiglu_cuda = SwigluBuilder().load() + assert input.is_contiguous() + assert input.dtype == torch.half or input.dtype == torch.float or input.dtype == torch.bfloat16 + assert input.size(-1) % 2 == 0 + return swiglu_cuda.swiglu_fwd(input) + +def swiglu_bwd(input: torch.Tensor, grad: torch.Tensor): + swiglu_cuda = SwigluBuilder().load() + assert input.is_contiguous() and grad.is_contiguous() + assert (input.dtype == torch.half and grad.dtype == torch.half) or (input.dtype == torch.float and grad.dtype == torch.float) or (input.dtype == torch.bfloat16 and grad.dtype == torch.bfloat16) + assert input.size(-1) % 2 == 0 and input.size(-1) // grad.size(-1) == 2 + return swiglu_cuda.swiglu_bwd(input, grad) + diff --git a/toolbox/DeepSpeed/v0.15.3/patches/deepspeed/ops/transformer/__init__.py b/toolbox/DeepSpeed/v0.15.3/patches/deepspeed/ops/transformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..dbc4836d3245bb561bcd4af88c37fb7766581266 --- /dev/null +++ b/toolbox/DeepSpeed/v0.15.3/patches/deepspeed/ops/transformer/__init__.py @@ -0,0 +1,11 @@ +#!/usr/bin/env python3 +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +from .transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig +from .inference.config import DeepSpeedInferenceConfig +from ...model_implementations.transformers.ds_transformer import DeepSpeedTransformerInference +from .inference.moe_inference import DeepSpeedMoEInferenceConfig, DeepSpeedMoEInference diff --git a/toolbox/DeepSpeed/v0.15.3/patches/deepspeed/runtime/data_pipeline/data_sampling/data_sampler.py b/toolbox/DeepSpeed/v0.15.3/patches/deepspeed/runtime/data_pipeline/data_sampling/data_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..2fc139100fa379291657d528e93ddd201f067e40 --- /dev/null +++ b/toolbox/DeepSpeed/v0.15.3/patches/deepspeed/runtime/data_pipeline/data_sampling/data_sampler.py @@ -0,0 +1,472 @@ +#!/usr/bin/env python3 +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +""" +coding=utf-8 + Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +Part of this code was adopted from https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/data/data_samplers.py +""" + +import torch +import os +import numpy as np + +import deepspeed.comm as dist +from deepspeed.utils import logger +from deepspeed.accelerator import get_accelerator +from ..constants import * +from ..curriculum_scheduler import CurriculumScheduler +from .indexed_dataset import MMapIndexedDataset +from .utils import create_mmap_dataset_builder, close_mmap_dataset_builder, find_fit_int_dtype + +from torch.utils.data import Dataset, Sampler +from typing import Iterator, List, Optional + +class DeepSpeedDataSampler(object): + + def __init__(self, + data_efficiency_config, + one_epoch_total_samples, + micro_batch_size, + data_parallel_rank, + data_parallel_size, + data_parallel_group, + gradient_accumulation_steps, + global_rank, + drop_last=True): + # Keep a copy of input params for later use. + self.data_efficiency_config = data_efficiency_config + self.one_epoch_total_samples = one_epoch_total_samples + self.index_dtype = find_fit_int_dtype(0, one_epoch_total_samples) + self.total_samples = one_epoch_total_samples * self.data_efficiency_config[DATA_SAMPLING][ + DATA_SAMPLING_NUM_EPOCHS] + self.micro_batch_size = micro_batch_size + self.data_parallel_rank = data_parallel_rank + self.data_parallel_group = data_parallel_group + self.micro_batch_times_data_parallel_size = \ + self.micro_batch_size * data_parallel_size + self.gradient_accumulation_steps = gradient_accumulation_steps + self.global_batch_size = self.micro_batch_times_data_parallel_size * \ + self.gradient_accumulation_steps + self.global_rank = global_rank + self.drop_last = drop_last + self.np_rng = np.random.default_rng(self.data_efficiency_config[DATA_EFFICIENCY_SEED]) + self.state = {} + self.batch = [] + self.consumed_samples = 0 + if self.data_efficiency_config[DATA_SAMPLING][CURRICULUM_LEARNING][CURRICULUM_LEARNING_ENABLED]: + self.curriculum_step = 0 + self.current_difficulties = {} + self.data_cluster_paths = [] + self.data_cluster_current_position = [] + self.curriculum_schedulers = {} + self.curriculum_index_to_sample = {} + self.curriculum_index_to_metric = {} + self.difficulty_type = {} + self.clustering_type = {} + self.data_1epoch_size = None + if self.global_rank == 0: + self.data_clusters = [] + self.data_cluster_sizes = [] + cluster_path = self.data_efficiency_config[DATA_SAMPLING][CURRICULUM_LEARNING][ + CURRICULUM_LEARNING_CLUSTER_PATH] + if not os.path.exists(cluster_path): + os.makedirs(cluster_path) + for metric in self.data_efficiency_config[DATA_SAMPLING][CURRICULUM_LEARNING][CURRICULUM_LEARNING_METRICS]: + self.curriculum_schedulers[metric] = CurriculumScheduler( + data_efficiency_config[DATA_SAMPLING][CURRICULUM_LEARNING][CURRICULUM_LEARNING_METRICS][metric]) + self.difficulty_type[metric] = data_efficiency_config[DATA_SAMPLING][CURRICULUM_LEARNING][ + CURRICULUM_LEARNING_METRICS][metric][CURRICULUM_LEARNING_DIFFICULTY_TYPE] + self.clustering_type[metric] = data_efficiency_config[DATA_SAMPLING][CURRICULUM_LEARNING][ + CURRICULUM_LEARNING_METRICS][metric][CURRICULUM_LEARNING_CLUSTERING_TYPE] + if self.global_rank == 0: + if self.clustering_type[metric] != CURRICULUM_LEARNING_SINGLE_CLUSTER: + self.curriculum_index_to_sample[metric] = MMapIndexedDataset( + data_efficiency_config[DATA_SAMPLING][CURRICULUM_LEARNING][CURRICULUM_LEARNING_METRICS] + [metric][CURRICULUM_LEARNING_SAMPLE_PATH], + skip_warmup=True) + if self.difficulty_type[metric] == CURRICULUM_LEARNING_VALUE_BASED: + self.curriculum_index_to_metric[metric] = MMapIndexedDataset( + data_efficiency_config[DATA_SAMPLING][CURRICULUM_LEARNING][CURRICULUM_LEARNING_METRICS] + [metric][CURRICULUM_LEARNING_METRIC_PATH], + skip_warmup=True) + + # Sanity checks. + assert self.total_samples > 0, \ + 'no sample to consume: {}'.format(self.total_samples) + assert self.micro_batch_size > 0 + assert data_parallel_size > 0 + assert self.data_parallel_rank < data_parallel_size, \ + 'data_parallel_rank should be smaller than data size: {}, ' \ + '{}'.format(self.data_parallel_rank, data_parallel_size) + + def __len__(self): + return self.total_samples + + def set_custom_curriculum_learning_schedule(self, schedule_func_dict): + for metric in self.curriculum_schedulers: + if metric in schedule_func_dict: + self.curriculum_schedulers[metric].set_custom_get_difficulty(schedule_func_dict[metric]) + + def get_start_end_idx(self, batch_len=None): + """ + given the length of a minibatch (defaults to micro-batch size * data_parallel_size), + return the start and end indices of the current data parallel rank + """ + batch_len = batch_len or self.micro_batch_times_data_parallel_size + start_idx_fn = lambda r: round(r * batch_len / self.data_parallel_group.size()) + start_idx = start_idx_fn(self.data_parallel_rank) + end_idx = start_idx_fn(self.data_parallel_rank + 1) + return start_idx, end_idx + + def get_sample_based_on_metric_value(self, metric, value_start, value_end): + new_samples = None + for row in range(len(self.curriculum_index_to_sample[metric])): + if self.curriculum_index_to_metric[metric][row] <= value_end and self.curriculum_index_to_metric[metric][ + row] > value_start: + row_samples = np.copy(self.curriculum_index_to_sample[metric][row]) + new_samples = row_samples if new_samples is None else np.concatenate( + (new_samples, row_samples), axis=None) + return new_samples + + def get_sample_based_on_metric_percentile(self, metric, percentile_start, percentile_end): + new_samples = None + if self.data_1epoch_size is None: + self.data_1epoch_size = sum(len(x) for x in self.curriculum_index_to_sample[metric]) + max_percentile = self.data_efficiency_config[DATA_SAMPLING][CURRICULUM_LEARNING][CURRICULUM_LEARNING_METRICS][ + metric][CURRICULUM_LEARNING_MAX_DIFFICULTY] + sample_per_percentile = self.data_1epoch_size // max_percentile + start_count = sample_per_percentile * percentile_start + end_count = sample_per_percentile * percentile_end + if percentile_end == max_percentile: + end_count = self.data_1epoch_size + current_count = 0 + for row in range(len(self.curriculum_index_to_sample[metric])): + row_size = len(self.curriculum_index_to_sample[metric][row]) + if current_count + row_size > start_count: + row_start = max(0, start_count - current_count) + if current_count + row_size <= end_count: + row_end = row_size + else: + row_end = end_count - current_count + row_samples = np.copy(self.curriculum_index_to_sample[metric][row][row_start:row_end]) + new_samples = row_samples if new_samples is None else np.concatenate( + (new_samples, row_samples), axis=None) + current_count += row_size + if current_count >= end_count: + break + return new_samples + + def get_new_cluster(self, previous_difficulties): + cluster_fname = CURRICULUM_LEARNING_CLUSTER_PREFIX + for metric in self.curriculum_schedulers: + cluster_fname = f"{cluster_fname}_{metric}{self.current_difficulties[metric]}" + cluster_path = self.data_efficiency_config[DATA_SAMPLING][CURRICULUM_LEARNING][ + CURRICULUM_LEARNING_CLUSTER_PATH] + cluster_path = f"{cluster_path}/{cluster_fname}" + if self.global_rank == 0: + new_cluster = None + need_clustering = 0 + for metric in self.clustering_type: + if self.clustering_type[metric] != CURRICULUM_LEARNING_SINGLE_CLUSTER: + need_clustering += 1 + if need_clustering > 1: + for metric in self.curriculum_schedulers: + if self.clustering_type[metric] == CURRICULUM_LEARNING_SINGLE_CLUSTER: + metric_cluster = np.arange(start=0, + stop=self.one_epoch_total_samples, + step=1, + dtype=self.index_dtype) + else: + if self.difficulty_type[metric] == CURRICULUM_LEARNING_VALUE_BASED: + metric_cluster = self.get_sample_based_on_metric_value(metric, float('-inf'), + self.current_difficulties[metric]) + elif self.difficulty_type[metric] == CURRICULUM_LEARNING_PERCENTILE_BASED: + metric_cluster = self.get_sample_based_on_metric_percentile( + metric, 0, self.current_difficulties[metric]) + new_cluster = metric_cluster if new_cluster is None else \ + np.intersect1d(new_cluster, metric_cluster, assume_unique=True) + for cluster in self.data_clusters: + new_cluster = np.setdiff1d(new_cluster, cluster[0], assume_unique=True) + else: + if len(self.data_clusters) == 0: + new_cluster = np.arange(start=0, stop=self.one_epoch_total_samples, step=1, dtype=self.index_dtype) + for metric in self.curriculum_schedulers: + if self.clustering_type[metric] != CURRICULUM_LEARNING_SINGLE_CLUSTER: + if self.difficulty_type[metric] == CURRICULUM_LEARNING_VALUE_BASED: + new_cluster = self.get_sample_based_on_metric_value(metric, previous_difficulties[metric], + self.current_difficulties[metric]) + elif self.difficulty_type[metric] == CURRICULUM_LEARNING_PERCENTILE_BASED: + new_cluster = self.get_sample_based_on_metric_percentile( + metric, previous_difficulties[metric], self.current_difficulties[metric]) + if new_cluster is not None and len(new_cluster) > 0: + logger.info( + f"new data cluster (previous_difficulties {previous_difficulties}, current_difficulties {self.current_difficulties}) with size {len(new_cluster)} generated." + ) + self.np_rng.shuffle(new_cluster) + cluster_builder = create_mmap_dataset_builder(cluster_path, self.index_dtype) + cluster_builder.add_item_numpy(new_cluster) + close_mmap_dataset_builder(cluster_builder, cluster_path) + self.data_clusters.append(MMapIndexedDataset(cluster_path, skip_warmup=True)) + self.data_cluster_sizes.append(len(self.data_clusters[-1][0])) + else: + logger.info( + f"new data cluster (previous_difficulties {previous_difficulties}, current_difficulties {self.current_difficulties}) has no matched data thus skipped." + ) + dist.barrier(group=self.data_parallel_group) + if os.path.isfile(f"{cluster_path}.bin"): + self.data_cluster_paths.append(cluster_fname) + self.data_cluster_current_position.append(0) + + def sample_from_clusters(self): + num_clusters = len(self.data_clusters) + weight_sum = sum(self.data_cluster_sizes) + weights = [x / weight_sum for x in self.data_cluster_sizes] + samples = self.np_rng.choice(num_clusters, self.global_batch_size, replace=True, p=weights) + samples = np.bincount(samples, minlength=num_clusters) + return samples + + def reshuffle_clusters(self, cidx): + cluster_fname = self.data_cluster_paths[cidx] + cluster_path = self.data_efficiency_config[DATA_SAMPLING][CURRICULUM_LEARNING][ + CURRICULUM_LEARNING_CLUSTER_PATH] + cluster_path = f"{cluster_path}/{cluster_fname}" + cluster = np.copy(self.data_clusters[cidx][0]) + self.np_rng.shuffle(cluster) + cluster_builder = create_mmap_dataset_builder(cluster_path, self.index_dtype) + cluster_builder.add_item_numpy(cluster) + close_mmap_dataset_builder(cluster_builder, cluster_path) + self.data_clusters[cidx] = MMapIndexedDataset(cluster_path, skip_warmup=True) + + def get_sample_from_cluster(self, cidx, num_samples): + start_idx = self.data_cluster_current_position[cidx] + samples = list(np.copy(self.data_clusters[cidx][0][start_idx:(start_idx + num_samples)])) + self.data_cluster_current_position[cidx] += num_samples + if len(samples) < num_samples: + num_samples_remained = num_samples - len(samples) + logger.info(f"reshuffling cluster {cidx}.") + self.reshuffle_clusters(cidx) + samples += list(np.copy(self.data_clusters[cidx][0][:num_samples_remained])) + self.data_cluster_current_position[cidx] = num_samples_remained + return samples + + def get_next_global_batch(self): + if self.data_efficiency_config[DATA_SAMPLING][CURRICULUM_LEARNING][CURRICULUM_LEARNING_ENABLED]: + self.curriculum_step += 1 + new_cluster = False + previous_difficulties = {} + for metric in self.curriculum_schedulers: + next_difficulty = self.curriculum_schedulers[metric].update_difficulty(self.curriculum_step) + if metric not in self.current_difficulties or \ + next_difficulty != self.current_difficulties[metric]: + new_cluster = True + if metric in self.current_difficulties: + previous_difficulties[metric] = self.current_difficulties[metric] + else: + if self.difficulty_type[metric] == CURRICULUM_LEARNING_VALUE_BASED: + previous_difficulties[metric] = float('-inf') + elif self.difficulty_type[metric] == CURRICULUM_LEARNING_PERCENTILE_BASED: + previous_difficulties[metric] = 0 + self.current_difficulties[metric] = next_difficulty + if new_cluster: + self.get_new_cluster(previous_difficulties) + if self.global_rank == 0: + samples_per_cluster = self.sample_from_clusters() + batch = [] + for cidx in range(len(samples_per_cluster)): + batch += self.get_sample_from_cluster(cidx, samples_per_cluster[cidx]) + self.np_rng.shuffle(batch) + + # broadcast tensor must have same shape across participants. So we fill batch with -1s when not full + assert len(batch) <= self.global_batch_size + batch += [-1] * (self.global_batch_size - len(batch)) + batch = torch.tensor(batch, device=get_accelerator().current_device_name(), dtype=torch.long).view(-1) + else: + batch = torch.empty(self.global_batch_size, + device=get_accelerator().current_device_name(), + dtype=torch.long) + dist.broadcast(batch, 0, group=self.data_parallel_group) + batch = batch[batch != -1] # remove trailing -1s used to fill incomplete batch tensor + self.batch = batch.tolist() + + def __iter__(self): + while self.consumed_samples <= self.total_samples: + if len(self.batch) == 0: + self.get_next_global_batch() + current_batch = self.batch[:self.micro_batch_times_data_parallel_size] + self.batch = self.batch[self.micro_batch_times_data_parallel_size:] + if len(current_batch) == self.micro_batch_times_data_parallel_size or \ + (len(current_batch) > 0 and not self.drop_last): + start_idx, end_idx = self.get_start_end_idx(len(current_batch)) + yield current_batch[start_idx:end_idx] + self.consumed_samples += len(current_batch) + current_batch = [] + + def state_dict(self): + return { + CURRICULUM_LEARNING_BATCH: self.batch, + CURRICULUM_LEARNING_CONSUMED_SAMPLES: self.consumed_samples, + CURRICULUM_LEARNING_STEP: self.curriculum_step, + CURRICULUM_LEARNING_CURRENT_DIFFICULTIES: self.current_difficulties, + CURRICULUM_LEARNING_DATA_CLUSTER_PATHS: self.data_cluster_paths, + CURRICULUM_LEARNING_DATA_CLUSTER_CURRENT_POSITION: self.data_cluster_current_position, + CURRICULUM_LEARNING_NP_RNG_STATE: np.random.get_state() + } + + def load_state_dict(self, state_dict): + self.batch = state_dict[CURRICULUM_LEARNING_BATCH] + self.consumed_samples = state_dict[CURRICULUM_LEARNING_CONSUMED_SAMPLES] + self.curriculum_step = state_dict[CURRICULUM_LEARNING_STEP] + self.current_difficulties = state_dict[CURRICULUM_LEARNING_CURRENT_DIFFICULTIES] + self.data_cluster_paths = state_dict[CURRICULUM_LEARNING_DATA_CLUSTER_PATHS] + self.data_cluster_current_position = state_dict[CURRICULUM_LEARNING_DATA_CLUSTER_CURRENT_POSITION] + np.random.set_state(state_dict[CURRICULUM_LEARNING_NP_RNG_STATE]) + cluster_root_path = self.data_efficiency_config[DATA_SAMPLING][CURRICULUM_LEARNING][ + CURRICULUM_LEARNING_CLUSTER_PATH] + # Backward compatibility: previously data_cluster_paths were stored as + # absolute paths. Now we changed it to just the file name so that even + # if user moved the cluster files, the checkpoint loading still works + # as long as user set the correct new CURRICULUM_LEARNING_CLUSTER_PATH + # in deepspeed json config. + for idx in range(len(self.data_cluster_paths)): + if '/' in self.data_cluster_paths[idx]: + self.data_cluster_paths[idx] = self.data_cluster_paths[idx].split('/')[-1] + if self.global_rank == 0: + for cluster_fname in self.data_cluster_paths: + cluster_path = f"{cluster_root_path}/{cluster_fname}" + self.data_clusters.append(MMapIndexedDataset(cluster_path, skip_warmup=True)) + self.data_cluster_sizes.append(len(self.data_clusters[-1][0])) + + +class DsRandomSampler(Sampler): + data_source: Optional[Dataset] + replacement: bool + + def __init__(self, data_source: Optional[Dataset], replacement: bool = False, + num_samples: Optional[int] = None, generator=None, + lengths: Optional[List[int]] = None, batch_size: Optional[int] = None) -> None: + self.data_source = data_source + self.replacement = replacement + self._num_samples = num_samples + self.generator = generator + + if lengths is None: + model_input_name = "input_ids" + lengths = [len(feature[model_input_name]) for feature in self.data_source] + elif isinstance(lengths, torch.Tensor): + logger.info( + "If lengths is a torch.Tensor, LengthGroupedSampler will be slow. Converting lengths to List[int]..." + ) + lengths = lengths.tolist() + + self.lengths = lengths + self.batch_size = batch_size + self.trs_random = os.getenv("RANDOM_SAMPLER", "LargerLengthsPre").upper() + + if not isinstance(self.replacement, bool): + raise TypeError(f"replacement should be a boolean value, but got replacement={self.replacement}") + + if not isinstance(self.num_samples, int) or self.num_samples <= 0: + raise ValueError(f"num_samples should be a positive integer value, but got num_samples={self.num_samples}") + + @property + def num_samples(self) -> int: + # dataset size might change at runtime + if self._num_samples is None: + return len(self.data_source) + return self._num_samples + + def __iter__(self) -> Iterator[int]: + n = len(self.data_source) + if self.generator is None: + seed = int(torch.empty((), dtype=torch.int64).random_().item()) + generator = torch.Generator() + generator.manual_seed(seed) + else: + generator = self.generator + + if self.replacement: + for _ in range(self.num_samples // 32): + yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist() + yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist() + else: + for _ in range(self.num_samples // n): + yield from self.group_larger_lengths_indices(torch.randperm(n, generator=generator).tolist()) + yield from self.group_larger_lengths_indices(torch.randperm(n, generator=generator).tolist()[:self.num_samples % n]) + + def __len__(self) -> int: + return self.num_samples + + def group_larger_lengths_indices(self, ori_indices: List[int] = []) -> List[int]: + if len(ori_indices) == 0: + return ori_indices + + limit_shape = int(os.getenv("LIMIT_SHAPE", 500)) + normal_indices = [] + larger_indices = [] + for i in range(len(ori_indices)): + if self.lengths[ori_indices[i]] > limit_shape: + larger_indices.append(ori_indices[i]) + else: + normal_indices.append(ori_indices[i]) + + import random + scaling_factor = 2 + max_multiple_bs = 8 + if self.trs_random == "LARGERLENGTHSPRE": + # ========== add first random ========== + larger_indices_length = len(larger_indices) + if larger_indices_length % self.batch_size != 0: + multiple_bs = larger_indices_length // self.batch_size + 1 + multiple_bs = max(min(multiple_bs * scaling_factor, max_multiple_bs), multiple_bs) + diff_normal = multiple_bs * self.batch_size - larger_indices_length + larger_indices.extend(normal_indices[0:diff_normal]) + normal_indices = normal_indices[diff_normal:] + random.shuffle(larger_indices) + larger_indices.extend(normal_indices) + else: + random.shuffle(larger_indices) + larger_indices.extend(normal_indices) + return larger_indices + elif self.trs_random == "LARGERLENGTHSPOST": + # ========== add last random ========== + larger_indices_length = len(larger_indices) + normal_indices_length = len(normal_indices) + if larger_indices_length % self.batch_size != 0: + multiple_bs = larger_indices_length // self.batch_size + 1 + multiple_bs = max(min(multiple_bs * scaling_factor, max_multiple_bs), multiple_bs) + diff_normal = multiple_bs * self.batch_size - larger_indices_length + larger_indices.extend(normal_indices[normal_indices_length - diff_normal:]) + normal_indices = normal_indices[0:normal_indices_length - diff_normal] + random.shuffle(larger_indices) + normal_indices.extend(larger_indices) + else: + random.shuffle(larger_indices) + normal_indices.extend(larger_indices) + return normal_indices + elif self.trs_random == "LARGERLENGTHSPOSTANDSORT": + # ========== larger lengths add last and larger-> smaller sort: test largest limit module id + sort_lengths, sort_indices = torch.sort(torch.tensor(self.lengths), descending=True) + larger_mask = sort_lengths > limit_shape + larger_indices = sort_indices[larger_mask].tolist() + norm_indices = sort_indices[~larger_mask].tolist() + norm_indices.extend(larger_indices) + return norm_indices + else: + raise RuntimeError(f"{self.trs_random} is not support") \ No newline at end of file diff --git a/toolbox/DeepSpeed/v0.15.3/patches/deepspeed/runtime/pipe/engine.py b/toolbox/DeepSpeed/v0.15.3/patches/deepspeed/runtime/pipe/engine.py new file mode 100644 index 0000000000000000000000000000000000000000..8a49c969a64c910b6b22b05b219e9459f4f79c2e --- /dev/null +++ b/toolbox/DeepSpeed/v0.15.3/patches/deepspeed/runtime/pipe/engine.py @@ -0,0 +1,1426 @@ +#!/usr/bin/env python3 +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from types import MethodType +from collections import OrderedDict +from functools import reduce +from operator import mul + +import torch +from deepspeed import comm as dist + +from deepspeed.utils import logger +from deepspeed.utils.timer import ThroughputTimer +from deepspeed.accelerator import get_accelerator +from deepspeed.runtime.bf16_optimizer import BF16_Optimizer + +from ..engine import DeepSpeedEngine, MEMORY_OPT_ALLREDUCE_SIZE +from deepspeed.utils.timer import FORWARD_MICRO_TIMER, FORWARD_GLOBAL_TIMER, BACKWARD_MICRO_TIMER, \ + BACKWARD_GLOBAL_TIMER, BACKWARD_INNER_MICRO_TIMER, BACKWARD_INNER_GLOBAL_TIMER, \ + BACKWARD_REDUCE_MICRO_TIMER, BACKWARD_REDUCE_GLOBAL_TIMER, \ + STEP_MICRO_TIMER, STEP_GLOBAL_TIMER + +from ..utils import PartitionedTensor +from ..dataloader import RepeatingLoader +from ..zero.config import ZeroStageEnum +from ..activation_checkpointing import checkpointing as ds_checkpointing + +from .module import PipelineModule, PipelineError +from . import p2p +from . import schedule + +TARGET_ID = -2 +LOG_STAGE = -2 +DATA_PARALLEL_ID = -2 + +BATCH_INPUT_TIMER = 'batch_input' +TRAIN_BATCH_TIMER = 'train_batch' +PIPE_SEND_OUTPUT_TIMER = 'pipe_send_output' +PIPE_SEND_GRAD_TIMER = 'pipe_send_grad' +PIPE_RECV_INPUT_TIMER = 'pipe_recv_input' +PIPE_RECV_GRAD_TIMER = 'pipe_recv_grad' + +# The buffer size to store the meta data for each tensor. +TENSOR_META_SIZE = 256 + + +def is_even(number): + return number % 2 == 0 + + +mem_alloced = 0 +mem_cached = 0 + + +def _tensor_bytes(tensor): + return tensor.numel() * tensor.element_size() + + +class PipelineEngine(DeepSpeedEngine): + """ A training engine hybrid pipeline, data, and model parallel training. + + This engine is created by ``deepspeed.initialize()`` when a :class:`PipelineModule` + is provided. + """ + ID_TO_DTYPE = [ + torch.float32, torch.float64, torch.complex64, torch.complex128, torch.float16, torch.bfloat16, torch.uint8, + torch.int8, torch.int16, torch.int32, torch.int64, torch.bool + ] + DTYPE_TO_ID = {dtype: id_ for id_, dtype in enumerate(ID_TO_DTYPE)} + + def __init__(self, has_bool_tensors=False, *super_args, **super_kwargs): + super().__init__(*super_args, **super_kwargs) + assert isinstance(self.module, PipelineModule), "model must base PipelineModule" + + assert self.zero_optimization_stage( + ) < ZeroStageEnum.gradients, "ZeRO-2 and ZeRO-3 are incompatible with pipeline parallelism" + + # We schedule the all-reduces, so disable it in super().backward() + self.enable_backward_allreduce = False + self.has_bool_tensors = has_bool_tensors + self.eval_return_logits = False + self.outputs = None + # BF16 Optimizer is hardcoded for fp32 gradient accumulation + self.using_bf16_optimizer = type(self.optimizer) == BF16_Optimizer + + # used to disable the pipeline all-reduce when used with 1-bit Adam/1-bit LAMB + self.pipeline_enable_backward_allreduce = True + + if self.elasticity_enabled(): + if not self.is_elastic_model_parallel_supported(): + assert not self.elasticity_enabled(), "Elasticity is not currently supported" \ + " with pipeline parallelism." + + # pipeline step for logging + self.log_batch_step_id = -1 + + self.micro_batch_size = self.train_micro_batch_size_per_gpu() + self.micro_batches = self.gradient_accumulation_steps() + + # Set Grid and Communication Groups + self.grid = self.module._grid + if self.grid.get_global_rank() == 0: + logger.info(f'CONFIG: micro_batches={self.micro_batches} ' + f'micro_batch_size={self.micro_batch_size}') + + self.global_rank = self.grid.get_global_rank() + + assert self.dp_world_size == self.grid.data_parallel_size + assert self.train_batch_size() == \ + self.micro_batch_size * self.micro_batches * self.grid.data_parallel_size + + # Set Stage Inf + self.num_stages = self.grid.pipe_parallel_size + self.stage_id = self.grid.get_stage_id() + self.prev_stage = self.stage_id - 1 + self.next_stage = self.stage_id + 1 + + self.data_iterator = None + self.batch_fn = None + + self._force_grad_boundary = False + + self.batch_timer = ThroughputTimer(self._config.timers_config, + batch_size=self.train_batch_size(), + logging_fn=self.tput_log, + monitor_memory=False, + steps_per_output=self.steps_per_print()) + + # PipelineEngine needs to handle data loading specially due to only the first + # and last stages loading inputs/labels. We construct a sampler that uses + if self.training_data: + self._build_data_iter(self.training_data) + + self.is_pipe_parallel = self.grid.pipe_parallel_size > 1 + self.is_data_parallel = self.grid.data_parallel_size > 1 + self.is_model_parallel = self.grid.model_parallel_size > 1 + + # Partition input/output buffers + # XXX temporarily disable while I revert some partition hacks. + assert isinstance(self._config.pipeline['pipe_partitioned'], bool) + assert isinstance(self._config.pipeline['grad_partitioned'], bool) + self.is_pipe_partitioned = self.is_model_parallel and self._config.pipeline['pipe_partitioned'] + self.is_grad_partitioned = self.is_model_parallel and self._config.pipeline['grad_partitioned'] + logger.info(f'is_pipe_partitioned= {self.is_pipe_partitioned} ' + f'is_grad_partitioned= {self.is_grad_partitioned}') + + model_parameters = filter(lambda p: p.requires_grad, self.module.parameters()) + num_params = sum([p.numel() for p in model_parameters]) + unique_params = num_params + # Subtract tied parameters if we don't own them + if self.module.tied_comms: + tied_params = 0 + for key, d in self.module.tied_comms.items(): + if self.global_rank != min(d['ranks']): + tied_params += sum(p.numel() for p in d['module'].parameters()) + unique_params -= tied_params + params_tensor = torch.LongTensor(data=[num_params, unique_params]).to(self.device) + dist.all_reduce(params_tensor, group=self.grid.get_model_parallel_group()) + params_tensor = params_tensor.tolist() + total_params = params_tensor[0] + unique_params = params_tensor[1] + if self.grid.data_parallel_id == 0: + logger.info(f'RANK={self.global_rank} ' + f'STAGE={self.stage_id} ' + f'LAYERS={self.module._local_stop - self.module._local_start} ' + f'[{self.module._local_start}, {self.module._local_stop}) ' + f'STAGE_PARAMS={num_params} ({num_params/1e6:0.3f}M) ' + f'TOTAL_PARAMS={total_params} ({total_params/1e6:0.3f}M) ' + f'UNIQUE_PARAMS={unique_params} ({unique_params/1e6:0.3f}M)') + + #initialize peer-2-peer communication and allreduce groups + if self.is_pipe_parallel: + p2p.init_process_groups(self.grid) + + # Pipeline buffers + self.num_pipe_buffers = 0 + self.pipe_buffers = { + 'inputs': [], # batch input and received activations + 'labels': [], # labels from batch input + 'outputs': [], # activations + 'output_tensors': [], # tensor object to preserve backward graph + } + self.pipe_recv_buf = None + self.grad_layer = None + self._grad_layer_buf = [] + + self.meta_buffer = None + + self.first_output_send = True + self.first_gradient_send = True + self.pipe_partition_input_meta_cache = None + self.pipe_partition_output_meta_cache = None + self.pipe_partition_grad_meta_cache = None + self.grad_partition_grad_layer_meta_cache = None + + #stores the loss for the current micro batch being processed + self.loss = torch.tensor(0.0).to(self.device) + + #stores the loss for the entire batch + self.total_loss = None + self.total_additional_losses = None + self.agg_loss = torch.tensor(0.0, requires_grad=False).to(self.device) + self.dp_group_loss = torch.tensor(0.0, requires_grad=False).to(self.device) + + # stores aggregated-DP train final loss and aggregated-DP additional losses, if any + # additional losses are stored as dict: {loss-name: agg-loss} + self.agg_train_loss = None + self.agg_additional_losses = None + + if self._config.pipeline['activation_checkpoint_interval'] > 0: + self.module.activation_checkpoint_interval = self._config.pipeline['activation_checkpoint_interval'] + # set use_reentrant default to True. + if self._config.pipeline.get('use_reentrant') is None: + self._config.pipeline['use_reentrant'] = True + if self._config.pipeline['use_reentrant'] is False: + # set activation_checkpoint_func to non_reentrant_checkpoint func. + self.module.activation_checkpoint_func = ds_checkpointing.non_reentrant_checkpoint + if self.grid.get_global_rank() == 0: + logger.info(f'CONFIG: activation_checkpoint_func=non_reentrant_checkpoint') + if self.module.activation_checkpoint_interval > 0: + self.module._precompute_checkpointable_values() + + self.module.checkpoint_parallel_write_pipeline = self._config.checkpoint_parallel_write_pipeline + + if self.is_last_stage(): + self.loss_model = self.module.loss_fn + + self.has_attention_mask = self.module.__class__.__name__ == 'GPT2ModelPipe' + # Initialize pipeline communicators. Just send a 0. + if is_even(self.stage_id): + if not self.is_last_stage(): + p2p.send(self.loss, self.next_stage) + if not self.is_first_stage(): + p2p.recv(self.loss, self.prev_stage) + else: + if not self.is_first_stage(): + p2p.recv(self.loss, self.prev_stage) + if not self.is_last_stage(): + p2p.send(self.loss, self.next_stage) + + # XXX look into timer reporting timing + # Initialize some timers because of early weirdness. + if self.wall_clock_breakdown(): + self.timers(FORWARD_MICRO_TIMER).start() + self.timers(FORWARD_MICRO_TIMER).stop() + self.timers(BACKWARD_MICRO_TIMER).start() + self.timers(BACKWARD_MICRO_TIMER).stop() + self.timers(BACKWARD_INNER_MICRO_TIMER).start() + self.timers(BACKWARD_INNER_MICRO_TIMER).stop() + self.timers(BACKWARD_REDUCE_MICRO_TIMER).start() + self.timers(BACKWARD_REDUCE_MICRO_TIMER).stop() + self.timers(BACKWARD_REDUCE_GLOBAL_TIMER).start() + self.timers(BACKWARD_REDUCE_GLOBAL_TIMER).stop() + self.timers(STEP_MICRO_TIMER).start() + self.timers(STEP_MICRO_TIMER).stop() + + self.dynamic_shape = self.module.dynamic_shape + + def set_has_attention_mask(self, value): + assert isinstance(value, bool) + self.has_attention_mask = value + + def _build_data_iter(self, dataset): + sampler = torch.utils.data.distributed.DistributedSampler(dataset, + num_replicas=self.dp_world_size, + rank=self.mpu.get_data_parallel_rank(), + shuffle=False) + # Build a loader and make it repeating. + pipe_dataloader = self.deepspeed_io(dataset, data_sampler=sampler) + pipe_dataloader = RepeatingLoader(pipe_dataloader) + self.set_dataloader(pipe_dataloader) + + def _exec_reduce_tied_grads(self): + # We need to run this first to write to self.averaged_gradients; + # since this class turns `enable_backward_allreduce` off, + # `self.overlapping_partition_gradients_reduce_epilogue()` defined in the DeepSpeedEngine + # never actually runs. I suspect this is because of efficiency problems; get_flat_partition in + # stage2.py might do something expensive; someone will have to look into that later. But + # in the meantime, this fixes ZeRO2 + Pipelining enough to run a demo. Further profiling + # needed to decide if it actually breaks everything. + # (see https://github.com/EleutherAI/gpt-neox/issues/62#issuecomment-761471944) + if self.zero_optimization_partition_gradients(): + self.optimizer.overlapping_partition_gradients_reduce_epilogue() + + weight_group_list = self.module.get_tied_weights_and_groups() + for weight, group in weight_group_list: + grad = weight._hp_grad if self.using_bf16_optimizer else weight.grad + dist.all_reduce(grad, group=group) + + def _exec_reduce_grads(self): + self._force_grad_boundary = True + if self.pipeline_enable_backward_allreduce: + if self.using_bf16_optimizer: + # PP+BF16 work for ZeRO Stage 1 + self._bf16_reduce_grads() + else: + self.allreduce_gradients(bucket_size=MEMORY_OPT_ALLREDUCE_SIZE) + self._force_grad_boundary = False + + def _bf16_reduce_grads(self): + self.buffered_allreduce_fallback(grads=None, elements_per_buffer=MEMORY_OPT_ALLREDUCE_SIZE) + + def _reserve_pipe_buffers(self, num_buffers): + """Ensure that each pipeline buffer has at least ``num_buffers`` slots. + + This method only reserves slots and does not allocate tensors. + + Args: + num_buffers (int): The number of buffers to reserve. + """ + if self.num_pipe_buffers >= num_buffers: + return + + num_added = num_buffers - self.num_pipe_buffers + for key in self.pipe_buffers: + self.pipe_buffers[key].extend([None] * num_added) + self.num_pipe_buffers = num_buffers + + def reset_activation_shape(self): + """Reset the buffers when the shape of activation and gradient change. + For example, for curriculum learning that changes the seqlen of each + sample, we need to call this whenever the seqlen is going to change. + """ + self.first_output_send = True + self.pipe_recv_buf = None + self.grad_layer = None + self._grad_layer_buf = [] + self.meta_buffer = None + + self.pipe_partition_input_meta_cache = None + self.pipe_partition_output_meta_cache = None + self.pipe_partition_grad_meta_cache = None + self.grad_partition_grad_layer_meta_cache = None + + def train_batch(self, data_iter=None): + """Progress the pipeline to train the next batch of data. The engine will ingest + ``self.train_batch_size()`` total samples collectively across all workers. + + + An iterator that over training data should be provided as an argument + unless ``deepspeed.initialize()`` was provided a training set. In that event, + the training data will automatically be read. + + + .. warning:: + A total of ``self.gradient_accumulation_steps()`` entries will be pulled + from ``data_iter`` by each pipeline. There must be sufficient + data left in ``data_iter`` or else a ``StopIteration`` will halt training. + + DeepSpeed provides a convenience class :class:`deepspeed.utils.RepeatingLoader` + that wraps data loaders to automatically restart upon a ``StopIteration``. + + Args: + data_iter (Iterator, optional): Iterator of training data. + + Returns: + The arithmetic mean of the losses computed this batch. + """ + if not torch._C.is_grad_enabled(): + raise RuntimeError(f'train_batch() requires gradients enabled. Use eval_batch() instead.') + + # Curriculum learning could change activation shape + if self.curriculum_enabled_legacy(): + new_difficulty = self.curriculum_scheduler_legacy.update_difficulty( \ + self.global_steps + 1) + if self.global_steps == 0 or self.curriculum_scheduler_legacy.first_step: + self.reset_activation_shape() + self.curriculum_scheduler_legacy.first_step = False + elif new_difficulty != self.curriculum_scheduler_legacy.get_difficulty( \ + self.global_steps): + self.reset_activation_shape() + + if data_iter is not None: + self.set_dataiterator(data_iter) + + self.module.train() + self.total_loss = None + self.total_additional_losses = None + self._compute_loss = True + + # Do the work + self.timers(TRAIN_BATCH_TIMER).start() + sched = schedule.TrainSchedule(micro_batches=self.micro_batches, + stages=self.num_stages, + stage_id=self.stage_id) + self._exec_schedule(sched) + + with torch.no_grad(): + self.agg_train_loss = self._aggregate_total_loss() + + self.timers(TRAIN_BATCH_TIMER).stop() + + if self.global_steps % self.steps_per_print() == 0: + if self.global_rank == 0: + elapsed = self.timers(TRAIN_BATCH_TIMER).elapsed(reset=True) / 1000.0 + iter_time = elapsed / self.steps_per_print() + tput = self.train_batch_size() / iter_time + log_str = f'steps: {self.global_steps} loss: {self.agg_train_loss:0.4f} ' + if self.agg_additional_losses is not None: + for loss_name, loss_value in self.agg_additional_losses.items(): + log_str += f'{loss_name}: {loss_value.item():0.4f} ' + log_str += f'iter time (s): {iter_time:0.3f} samples/sec: {tput:0.3f}' + print(log_str) + else: + self.timers(TRAIN_BATCH_TIMER).elapsed(reset=True) + + # Monitoring + if self.global_rank == 0 and self.monitor.enabled: + self.summary_events = [(f'Train/Samples/train_loss', self.agg_train_loss.mean().item(), + self.global_samples)] + self.monitor.write_events(self.summary_events) + + if self.wall_clock_breakdown() and self.global_steps % self.steps_per_print() == 0: + self.timers.log([ + PIPE_SEND_OUTPUT_TIMER, + PIPE_SEND_GRAD_TIMER, + PIPE_RECV_INPUT_TIMER, + PIPE_RECV_GRAD_TIMER, + ]) + + # TODO: should return precisely what loss returned and allow others to be queried? + return self.agg_train_loss + + def eval_batch(self, + data_iter, + return_logits=False, + compute_loss=True, + reduce_output='avg', + bcast_loss=True, + num_micro_batches=None): + """Evaluate the pipeline on a batch of data from ``data_iter``. The + engine will evaluate ``self.train_batch_size()`` total samples + collectively across all workers. + + This method is equivalent to: + + .. code-block:: python + + module.eval() + with torch.no_grad(): + output = module(batch) + + .. warning:: + A total of ``self.gradient_accumulation_steps()`` entries will be pulled + from ``data_iter`` by each pipeline. There must be sufficient + data left in ``data_iter`` or else a ``StopIteration`` will halt training. + + DeepSpeed provides a convenience class :class:`deepspeed.utils.RepeatingLoader` + that wraps data loaders to automatically restart upon a ``StopIteration``. + + Args: + data_iter (Iterator): Iterator of data to evaluate. + + Returns: + The arithmetic mean of the losses computed this batch. + """ + self.eval_return_logits = return_logits + self.module.eval() + + # Curriculum learning could change activation shape + if self.curriculum_enabled_legacy(): + new_difficulty = self.curriculum_scheduler_legacy.update_difficulty( \ + self.global_steps + 1) + if self.global_steps == 0 or self.curriculum_scheduler_legacy.first_step: + self.reset_activation_shape() + self.curriculum_scheduler_legacy.first_step = False + elif new_difficulty != self.curriculum_scheduler_legacy.get_difficulty( \ + self.global_steps): + self.reset_activation_shape() + + eval_output = None + + self._compute_loss = compute_loss + + # Use the provided data iterator + train_iterator = self.data_iterator + self.set_dataiterator(data_iter) + + # set the number micro batches in case the user chose value than training + micro_batches = self.micro_batches if num_micro_batches is None else num_micro_batches + + # Do the work + sched = schedule.InferenceSchedule(micro_batches=micro_batches, stages=self.num_stages, stage_id=self.stage_id) + + # prevent dead-lock with multiple evals sequence + dist.barrier() + + with torch.no_grad(): + self._exec_schedule(sched) + + if self.is_last_stage(): + eval_output = self._reduce_outputs(self.fwd_outputs, reduce=reduce_output, micro_batches=micro_batches) + + if compute_loss and (bcast_loss or self.monitor.enabled): + eval_output = self._bcast_pipe_scalar(eval_output) + + if self.global_rank == 0 and self.monitor.enabled: + self.summary_events = [(f'Train/Samples/eval_loss', eval_output.mean().item(), self.global_samples)] + self.monitor.write_events(self.summary_events) + + # Restore the training iterator + self.set_dataiterator(train_iterator) + + # Reset any buffers that may have been populated during the forward passes. + #ds_checkpointing.reset() + self.eval_return_logits = False + if return_logits: + outputs = self.outputs + self.outputs = None + return eval_output, outputs + return eval_output + + def set_train_batch_size(self, train_batch_size): + """Adjust the global batch size by increasing or decreasing the number of + micro-batches (i.e., gradient accumulation steps). The size of each micro-batch + (i.e., ``train_micro_batch_size_per_gpu``) is not changed. + Args: + train_batch_size (int): The new global batch size for training. + Raises: + ValueError: if ``train_batch_size`` is not divisible by the + configured micro-batch size and data parallelism. + """ + super().set_train_batch_size(train_batch_size) + self.micro_batches = self.gradient_accumulation_steps() + + def is_first_stage(self): + """True if this process is in the first stage in the pipeline.""" + return self.stage_id == 0 + + def is_last_stage(self): + """True if this process is in the last stage in the pipeline.""" + return self.stage_id == self.num_stages - 1 + + def _reduce_outputs(self, outputs, reduce='avg', reduce_dp=True, micro_batches=None): + if reduce is None: + return outputs + + if reduce.lower() == 'avg': + # first sum over all microbatches + if torch.is_tensor(outputs[0]): + reduced = sum(outputs) + else: + assert isinstance(outputs, (list, tuple)) + reduced = [torch.zeros_like(o) for o in outputs[0]] + for idx, out in outputs: + reduced[idx] += out + + # Average over the microbatches + reduced = self._scale_loss_by_gas(reduced, eval_micro_batches=micro_batches) + + # Average over DP groups + if reduce_dp and self.is_data_parallel: + if torch.is_tensor(reduced): + dist.all_reduce(reduced, group=self.mpu.get_data_parallel_group()) + reduced /= self.dp_world_size + else: + for idx in range(len(reduced)): + dist.all_reduce(reduced[idx], group=self.mpu.get_data_parallel_group()) + reduced[idx] /= self.dp_world_size + + return reduced + else: + raise NotImplementedError(f'reduction type {reduce} not supported.') + + def _bcast_pipe_scalar(self, data, src_rank=None, dtype=torch.float32): + # Default to last stage (e.g., for broadcasting loss) + if src_rank is None: + src_rank = self.grid.stage_to_global(self.num_stages - 1) + assert src_rank in self.grid.pp_group + + if self.global_rank == src_rank: + result = data.clone().detach().type(dtype).to(self.device) + else: + result = torch.Tensor([0.]).type(dtype).to(self.device) + + dist.broadcast(tensor=result, src=src_rank, group=self.mpu.get_pipe_parallel_group()) + + return result + + def _aggregate_total_loss(self): + # Scale loss, average among DP ranks, and bcast loss to the rest of my DP group + if self.is_last_stage(): + # Scale loss and additional losses, if any + loss = self._scale_loss_by_gas(self.total_loss) + self.agg_additional_losses = self.total_additional_losses + if self.agg_additional_losses is not None: + self.agg_additional_losses = OrderedDict({ + loss_name: self._scale_loss_by_gas(_loss.clone().detach()) + for loss_name, _loss in self.agg_additional_losses.items() + }) + + self.dp_group_loss = loss.clone().detach() + agg_loss = self.dp_group_loss.clone().detach() + #print(f'RANK={self.global_rank} bcast SENDER src={self.global_rank} group={self.grid.pp_group}', flush=True) + + # Average loss across all data-parallel groups + if self.is_data_parallel: + if self.agg_additional_losses is None: + dist.all_reduce(agg_loss, group=self.mpu.get_data_parallel_group()) + agg_loss /= self.dp_world_size + else: + # use a single reduce op for agg_loss and additional losses, if any + assert '__train_loss__' not in self.agg_additional_losses.keys() + tensors = OrderedDict({'__train_loss__': agg_loss}) + tensors.update(self.agg_additional_losses.items()) + flat_tensor = torch.cat([t.clone().reshape(-1).detach() for t in tensors.values()]) + dist.all_reduce(flat_tensor, group=self.mpu.get_data_parallel_group()) + flat_tensor /= self.dp_world_size + offset = 0 + reduced_tensor = {} + for name, t in tensors.items(): + n_elem = t.numel() + reduced_tensor[name] = flat_tensor[offset:offset + n_elem].clone().detach().reshape(t.shape) + offset += n_elem + agg_loss = reduced_tensor['__train_loss__'] + self.agg_additional_losses = OrderedDict( + {name: reduced_tensor[name] + for name in self.agg_additional_losses.keys()}) + + assert self.global_rank in self.grid.pp_group + losses = [self.dp_group_loss, agg_loss] + if self.agg_additional_losses is not None: + losses += list(self.agg_additional_losses.values()) + losses = torch.stack(losses).float() + if self.is_pipe_parallel: + dist.broadcast(tensor=losses, src=self.global_rank, group=self.mpu.get_pipe_parallel_group()) + else: + # Get loss from last stage + src_rank = self.grid.stage_to_global(self.num_stages - 1) + assert src_rank in self.grid.pp_group + # losses to reduce are: dp_group_loss, agg_loss, model additional losses + # therefore: 2 + n_additional_losses + additional_losses = self.module.get_additional_losses() + n_additional_losses = 0 if additional_losses is None else len(additional_losses) + losses = torch.Tensor([0.] * (2 + n_additional_losses)).to(self.device) + dist.broadcast(tensor=losses, src=src_rank, group=self.grid.get_pipe_parallel_group()) + self.dp_group_loss = losses[0].clone().detach() + agg_loss = losses[1].clone().detach() + if additional_losses is not None: + self.agg_additional_losses = OrderedDict( + {name: losses[2 + i].clone().detach() + for i, name in enumerate(additional_losses.keys())}) + return agg_loss + + def set_dataloader(self, loader): + """""" + if self.is_first_stage() or self.is_last_stage(): + self.training_dataloader = loader + self.data_iterator = iter(self.training_dataloader) + + def set_dataiterator(self, iterator): + """ Store an iterator to sample for training data. """ + if self.is_first_stage() or self.is_last_stage(): + self.training_dataloader = None + self.data_iterator = iterator + + def set_batch_fn(self, fn): + """Execute a post-processing function on input data. + + Args: + fn (function): The function to run. + """ + self.batch_fn = fn + + def is_gradient_accumulation_boundary(self): + """True if the engine is executing a gradient reduction or optimizer step instruction. + + This is overridden from :class:`DeepSpeedEngine` to force reductions + and steps when the pipeline engine is instructed to do so. + + Returns: + bool: whether reductions and optimizer steps should occur. + """ + return self._force_grad_boundary + + def log_for_device(self, *msg): + if LOG_STAGE == self.stage_id or LOG_STAGE == -1: + if DATA_PARALLEL_ID == self.grid.data_parallel_id or DATA_PARALLEL_ID == -1: + print( + f'RANK={dist.get_rank()} ' + f'PIPE-ID={self.stage_id} ' + f'DATA-ID={self.grid.data_parallel_id} ' + f'MBATCH-ID={self.microbatch_id} ' + f'STEP-ID={self.log_batch_step_id} ' + '::', + *msg, + flush=True) + + def tput_log(self, *msg): + if self.global_rank == 0 and self.global_steps % self.steps_per_print() == 0: + print(*msg) + + def _next_batch(self): + # If using 3D parallelism, only some first-stage ranks may do IO + batch = None + if self.data_iterator is not None: + batch = next(self.data_iterator) + + # Any post-processing, like broadcasting across a slice-parallel group. + if self.batch_fn: + batch = self.batch_fn(batch) + + return batch + + def _exec_forward_pass(self, buffer_id): + self.tput_timer.start() + self.mem_status('BEFORE FWD', reset_max=True) + + if isinstance(self.pipe_buffers['inputs'][buffer_id], tuple): + inputs = tuple(t.clone() for t in self.pipe_buffers['inputs'][buffer_id]) + else: + inputs = self.pipe_buffers['inputs'][buffer_id].clone() + + # collect the partitioned input from the previous stage + if self.is_pipe_partitioned and not self.is_first_stage(): + if self.pipe_partition_input_meta_cache is None: + self.pipe_partition_input_meta_cache = inputs[0].to('cpu') + part_input = PartitionedTensor.from_meta(meta=self.pipe_partition_input_meta_cache, + local_part=inputs[1], + group=self.grid.get_slice_parallel_group()) + + inputs = (part_input.full(), *inputs[2:]) + inputs[0].requires_grad = True + # skip mask + #inputs[1].requires_grad = True + part_input = None + inputs = inputs[0] if len(inputs) == 1 else inputs + self.pipe_buffers['inputs'][buffer_id] = inputs + + # inputs has no gradient because it is from a cloned tensor + outputs = super().forward(inputs) + + # Reset activation checkpointing buffers. + # Need to call this between evaluation iterations + if not self.module.training: + ds_checkpointing.reset() + + # Partition the outputs if we are not the last stage + if self.is_pipe_partitioned and not self.is_last_stage(): + if isinstance(outputs, tuple): + first_output = outputs[0] + # TODO: Improve pipe partitioning to pass multiple tensors that require grads + assert all([torch.is_tensor(elt) and elt.requires_grad is False for elt in outputs[1:]]) + outputs_tail = outputs[1:] + elif torch.is_tensor(outputs): + first_output = outputs + outputs_tail = [] + else: + raise ValueError("expecting a tensor or a tuple of tensors") + part = PartitionedTensor(tensor=first_output, group=self.grid.get_slice_parallel_group()) + # Clear the large output data, but save the computation graph + first_output.data = torch.zeros(1, device=first_output.data.device) + self.pipe_buffers['output_tensors'][buffer_id] = first_output + # Inject the partitioned tensor into the output before sending + outputs = (part.to_meta(), part.data(), *outputs_tail) + part = None + + self.pipe_buffers['outputs'][buffer_id] = outputs + + # Optionally compute loss on the last device + if self.is_last_stage(): + if self._compute_loss and self.module.loss_fn is not None: + labels = self.pipe_buffers['labels'][buffer_id] + self.loss = self.module.loss_fn(outputs, labels) + else: + # Some models just return loss from forward() + self.loss = outputs + if self.eval_return_logits: + self.outputs = outputs + + if isinstance(self.loss, torch.Tensor): + self.fwd_outputs.append(self.loss.detach()) + else: + self.fwd_outputs.append([l.detach() for l in self.loss]) + + def add_to_total_loss(_total_loss, _loss): + if isinstance(_loss, torch.Tensor): + if _total_loss is None: + _total_loss = torch.zeros_like(_loss) + _total_loss += _loss.detach() + else: + if _total_loss is None: + _total_loss = [torch.zeros_like(_l) for _l in _loss] + for _idx, _l in enumerate(_loss): + _total_loss[_idx] += _l.detach() + return _total_loss + + self.total_loss = add_to_total_loss(self.total_loss, self.loss) + + # aggregate additional losses across gradient accumulation steps + additional_losses = self.module.get_additional_losses() + if additional_losses is not None: + if self.total_additional_losses is None: + self.total_additional_losses = OrderedDict() + for name, loss in additional_losses.items(): + total = self.total_additional_losses[name] if name in self.total_additional_losses else None + self.total_additional_losses[name] = add_to_total_loss(total, loss) + + def _exec_backward_pass(self, buffer_id): + assert self.optimizer is not None, "must provide optimizer during " \ + "init in order to use backward" + + self.mem_status('BEFORE BWD', reset_max=True) + + # The last stage just runs backward on the loss using DeepSpeed's typical + # mechanisms. + if self.is_last_stage(): + super().backward(self.loss) + self.mem_status('AFTER BWD') + return + + outputs = self.pipe_buffers['outputs'][buffer_id] + + if self.wall_clock_breakdown(): + self.timers(BACKWARD_MICRO_TIMER).start() + self.timers(BACKWARD_GLOBAL_TIMER).start() + self.timers(BACKWARD_INNER_MICRO_TIMER).start() + self.timers(BACKWARD_INNER_GLOBAL_TIMER).start() + + # Reconstruct if we previously partitioned the output. We must be + # careful to also restore the computational graph of the tensors we partitioned. + if self.is_pipe_partitioned: + if self.is_grad_partitioned: + if self.pipe_partition_output_meta_cache is None: + self.pipe_partition_output_meta_cache = outputs[0].to('cpu') + part_output = PartitionedTensor.from_meta(meta=self.pipe_partition_output_meta_cache, + local_part=outputs[1], + group=self.grid.get_slice_parallel_group()) + self.pipe_buffers['output_tensors'][buffer_id].data = part_output.full() + outputs = (self.pipe_buffers['output_tensors'][buffer_id], *outputs[2:]) + else: + # Already restored from partition + self.pipe_buffers['output_tensors'][buffer_id].data = outputs[0] + outputs = (self.pipe_buffers['output_tensors'][buffer_id], *outputs[1:]) + + grad_tensors = self.grad_layer + if self.is_grad_partitioned: + #print(f'RANK={self.global_rank} BEFORE-BWD restoring grad={self.grad_layer[0].size()} {self.grad_layer[1].size()}') + if self.grad_partition_grad_layer_meta_cache is None: + self.grad_partition_grad_layer_meta_cache = self.grad_layer[0].to('cpu') + part_grad = PartitionedTensor.from_meta(meta=self.grad_partition_grad_layer_meta_cache, + local_part=self.grad_layer[1], + group=self.grid.get_slice_parallel_group()) + grad_tensors = (part_grad.full(), *grad_tensors[2:]) + part_grad = None + #print(f'RANK={self.global_rank} BEFORE-BWD restored grad={self.grad_layer[0].size()} {self.grad_layer[1].size()}') + + if self.using_bf16_optimizer and not self.is_last_stage(): + # manually call because we don't call optimizer.backward() + self.optimizer.clear_lp_grads() + + # This handles either a single tensor or tuple of tensors. + if isinstance(outputs, tuple): + out_tensors = [t for t in outputs if t.is_floating_point()] + assert len(out_tensors) == len(grad_tensors) + torch.autograd.backward(tensors=out_tensors, grad_tensors=grad_tensors) + else: + torch.autograd.backward(tensors=(outputs, ), grad_tensors=(grad_tensors, )) + + if self.using_bf16_optimizer and not self.is_last_stage(): + # manually call because we don't call optimizer.backward() + if not self._config.bfloat16_immediate_grad_update: + self.optimizer.update_hp_grads(clear_lp_grads=False) + + # Free up the memory from the output of forward() + self.pipe_buffers['output_tensors'][buffer_id] = None + self.pipe_buffers['outputs'][buffer_id] = None + grad_tensors = None + + if self.wall_clock_breakdown(): + self.timers(BACKWARD_INNER_MICRO_TIMER).stop() + self.timers(BACKWARD_INNER_GLOBAL_TIMER).stop() + self.timers(BACKWARD_MICRO_TIMER).stop() + self.timers(BACKWARD_GLOBAL_TIMER).stop() + + self.mem_status('AFTER BWD') + + def _exec_load_micro_batch(self, buffer_id): + if self.wall_clock_breakdown(): + self.timers(BATCH_INPUT_TIMER).start() + + batch = self._next_batch() + + if self.is_first_stage(): + loaded = None + if torch.is_tensor(batch[0]): + loaded = batch[0].clone().to(self.device).detach() + if self._config.pipeline['activation_checkpoint_interval'] > 0 and self._config.pipeline[ + 'use_reentrant']: + loaded.requires_grad = loaded.is_floating_point() + else: + assert isinstance(batch[0], (tuple, list)) + # Assume list or tuple + loaded = [] + for x in batch[0]: + assert torch.is_tensor(x) + mine = x.clone().detach().to(self.device) + if self._config.pipeline['activation_checkpoint_interval'] > 0 and self._config.pipeline[ + 'use_reentrant']: + mine.requires_grad = mine.is_floating_point() + loaded.append(mine) + loaded = tuple(loaded) + + self.pipe_buffers['inputs'][buffer_id] = loaded + + if self.is_last_stage(): + loaded = batch[1] + if torch.is_tensor(batch[1]): + loaded = batch[1].to(self.device) + # XXX: torch 1.6.0 DataLoader will auto convert tuple to list + elif isinstance(batch[1], (tuple, list)): + loaded = [] + for x in batch[1]: + assert torch.is_tensor(x) + x = x.to(self.device).detach() + loaded.append(x) + loaded = tuple(loaded) + + self.pipe_buffers['labels'][buffer_id] = loaded + + if self.wall_clock_breakdown(): + self.timers(BATCH_INPUT_TIMER).stop() + + def _send_tensor_meta(self, buffer, recv_stage): + """ Communicate metadata about upcoming p2p transfers. + + Metadata is communicated in this order: + * type (0: tensor, 1: list) + * num_tensors if type=list + foreach tensor in buffer: + * ndims + * shape + """ + meta_buffer = torch.empty(TENSOR_META_SIZE, dtype=torch.int32, device=self.device) + if isinstance(buffer, torch.Tensor): + meta_buf_list = [ + 0, # type of data (0: tensor, 1: list (unused), 2: tuple) + self.DTYPE_TO_ID[buffer.dtype], # dtype + len(buffer.size()) # ndims + ] + meta_buf_list.extend(buffer.size()) + assert len( + meta_buf_list + ) <= TENSOR_META_SIZE, f"Buffer for metadata is too small. Current buffer size: {TENSOR_META_SIZE} but required {len(meta_buf_list)}" + meta_buffer[:len(meta_buf_list)].copy_(torch.tensor(meta_buf_list, dtype=torch.int32)) + p2p.send(meta_buffer, recv_stage) + + elif isinstance(buffer, tuple): + meta_buf_list = [ + 2, # type of data (0: tensor, 1: list (unused), 2: tuple) + len(buffer) # num_tensors + ] + + for tensor in buffer: + assert isinstance(tensor, torch.Tensor) + meta_buf_list.append(self.DTYPE_TO_ID[tensor.dtype]) + meta_buf_list.append(len(tensor.size())) + meta_buf_list.extend(tensor.size()) + + assert len( + meta_buf_list + ) <= TENSOR_META_SIZE, f"Buffer for metadata is too small. Current buffer size: {TENSOR_META_SIZE} but required {len(meta_buf_list)}" + meta_buffer[:len(meta_buf_list)].copy_(torch.tensor(meta_buf_list, dtype=torch.int32)) + p2p.send(meta_buffer, recv_stage) + + else: + raise NotImplementedError(f'Could not send meta type {type(buffer)}') + + # Useful for performance debugging. + ''' + if self.grid.data_parallel_id == 0: + print(f'STAGE={self.stage_id} pipe-send-volume: {send_bytes/1024**2:0.2f}MB') + ''' + + def _recv_tensor_meta(self, send_stage): + """Receive metadata about upcoming p2p transfers and return allocated buffers. + + Returns: + Allocated buffer for receiving from send_stage. + """ + buffer = torch.empty(TENSOR_META_SIZE, dtype=torch.int32, device=self.device) + p2p.recv(buffer, send_stage) + + recv_type = buffer[0].item() + + # A single tensor will be sent. + if recv_type == 0: + recv_dtype = self.ID_TO_DTYPE[buffer[1].item()] + recv_ndims = buffer[2].item() + recv_shape = buffer[3:3 + recv_ndims].tolist() + return self._allocate_or_extend_buffers(0, recv_shape, recv_dtype) + + # List or tuple of tensors (recv_type == 1 (list) is currently unused) + elif recv_type == 1 or recv_type == 2: + num_tensors = buffer[1].item() + + buffers = [] + offset = 2 + for idx in range(num_tensors): + recv_dtype = self.ID_TO_DTYPE[buffer[offset].item()] + recv_ndims = buffer[offset + 1].item() + recv_shape = buffer[offset + 2:offset + 2 + recv_ndims].tolist() + offset += 2 + recv_ndims + + buffers.append(self._allocate_or_extend_buffers(idx, recv_shape, recv_dtype)) + + # Convert to tuples if requested. + if recv_type == 2: + buffers = tuple(buffers) + return buffers + + else: + raise NotImplementedError(f'Could not receive type {type(recv_type)}') + + def _exec_send_activations(self, buffer_id): + if self.wall_clock_breakdown(): + self.timers(PIPE_SEND_OUTPUT_TIMER).start() + + outputs = self.pipe_buffers['outputs'][buffer_id] + + # NCCL does not like to send torch.BoolTensor types, so cast the mask to half(). + # We could do char, but with half() we can eventually flatten with other fp16 + # messages (TODO) + if self.has_attention_mask or self.has_bool_tensors: + outputs = list(outputs) + outputs[-1] = outputs[-1].half() + outputs = tuple(outputs) + + if self.dynamic_shape or self.first_output_send: + self.first_output_send = False + self._send_tensor_meta(outputs, self.next_stage) + + if isinstance(outputs, torch.Tensor): + p2p.send(outputs, self.next_stage) + elif isinstance(outputs, tuple): + for idx, buffer in enumerate(outputs): + p2p.send(buffer, self.next_stage) + else: + raise NotImplementedError('Could not send output of type ' + f'{type(outputs)}') + + # Restore the boolean tensor + if self.has_attention_mask or self.has_bool_tensors: + outputs = list(outputs) + outputs[-1] = outputs[-1].bool() + outputs = tuple(outputs) + + if self.wall_clock_breakdown(): + self.timers(PIPE_SEND_OUTPUT_TIMER).stop() + + def _exec_send_grads(self, buffer_id): + if self.wall_clock_breakdown(): + self.timers(PIPE_SEND_GRAD_TIMER).start() + + inputs = self.pipe_buffers['inputs'][buffer_id] + + # Partition the gradient + if self.is_grad_partitioned: + if isinstance(inputs, tuple): + first_input = inputs[0] + assert all([torch.is_tensor(elt) for elt in inputs[1:]]) + inputs_grad_tail = [elt.grad for elt in inputs[1:]] + elif torch.is_tensor(inputs): + first_input = inputs + inputs_grad_tail = [] + else: + raise ValueError("expecting a tensor or a tuple of tensors") + assert torch.is_tensor(first_input) + part = PartitionedTensor(tensor=first_input.grad, group=self.grid.get_slice_parallel_group()) + + inputs = (part.to_meta(), part.data(), *inputs_grad_tail) + + # XXX Terrible hack + # Drop the attention mask from the input buffer here. It does not have + # a grad that needs to be communicated. We free the buffer immediately + # after, so no need to restore it. The receiver also has a hack that skips + # the recv. This is because NCCL does not let us send torch.BoolTensor :-(. + if self.has_attention_mask or self.has_bool_tensors: + inputs = list(inputs) + inputs.pop() + inputs = tuple(inputs) + + if isinstance(inputs, torch.Tensor): + assert inputs.grad is not None + p2p.send(inputs.grad, self.prev_stage) + else: + # XXX terrible hacky branch + if self.is_grad_partitioned: + # First two sends are partitioned gradient + p2p.send(inputs[0], self.prev_stage) + p2p.send(inputs[1], self.prev_stage) + else: + for idx, buffer in enumerate(inputs): + # Skip tensors that will not produce a grad + if not buffer.is_floating_point(): + assert buffer.grad is None + continue + assert buffer.grad is not None + p2p.send(buffer.grad, self.prev_stage) + + # We can free up the input buffer now + self.pipe_buffers['inputs'][buffer_id] = None + + if self.wall_clock_breakdown(): + self.timers(PIPE_SEND_GRAD_TIMER).stop() + + def _exec_recv_activations(self, buffer_id): + if self.wall_clock_breakdown(): + self.timers(PIPE_RECV_INPUT_TIMER).start() + + recvd = None + + # Allocate the buffer if necessary + if self.dynamic_shape or self.pipe_recv_buf is None: + self.pipe_recv_buf = self._recv_tensor_meta(self.prev_stage) + + if isinstance(self.pipe_recv_buf, torch.Tensor): + p2p.recv(self.pipe_recv_buf, self.prev_stage) + recvd = self.pipe_recv_buf.clone().detach() + recvd.requires_grad = recvd.is_floating_point() + else: + assert isinstance(self.pipe_recv_buf, tuple) + recvd = [None] * len(self.pipe_recv_buf) + for idx, buffer in enumerate(self.pipe_recv_buf): + assert torch.is_tensor(buffer) + # XXX hardcode meta type + if self.is_pipe_partitioned and idx == 0 and buffer.dtype != torch.long: + if self.meta_buffer is None: + self.meta_buffer = torch.zeros(buffer.size(), dtype=torch.long, device=self.device) + buffer = self.meta_buffer + + p2p.recv(buffer, self.prev_stage) + recvd[idx] = buffer.clone().detach() + + # NCCL does not like to send torch.BoolTensor types, so un-cast the + # attention mask + if self.has_attention_mask or self.has_bool_tensors: + recvd[-1] = recvd[-1].bool() + + recvd = tuple(recvd) + + for buffer in recvd: + buffer.requires_grad = buffer.is_floating_point() + + self.pipe_buffers['inputs'][buffer_id] = recvd + + if self.wall_clock_breakdown(): + self.timers(PIPE_RECV_INPUT_TIMER).stop() + + def _exec_recv_grads(self, buffer_id): + if self.wall_clock_breakdown(): + self.timers(PIPE_RECV_GRAD_TIMER).start() + + outputs = self.pipe_buffers['outputs'][buffer_id] + # XXX these shapes are hardcoded for Megatron + # Restore partitioned output if it was partitioned and we are sending full gradients + if self.is_pipe_partitioned and not self.is_grad_partitioned: + if self.pipe_partition_grad_meta_cache is None: + self.pipe_partition_grad_meta_cache = outputs[0].to('cpu') + part_output = PartitionedTensor.from_meta(meta=self.pipe_partition_grad_meta_cache, + local_part=outputs[1], + group=self.grid.get_slice_parallel_group()) + outputs[0].data = part_output.full() + outputs = (outputs[0], *outputs[2:]) + # save for backward + self.pipe_buffers['outputs'][buffer_id] = outputs + + # Allocate gradient if necessary + if self.dynamic_shape or self.grad_layer is None: + if isinstance(outputs, torch.Tensor): + self.grad_layer = self._allocate_or_extend_buffers(0, list(outputs.size()), outputs.dtype) + else: + # XXX This is a HACK + # When we exchange activations/gradients, the two pipe stages + # need to issue the send/recv with the same buffer sizes or + # else there is a deadlock. The is_floating_point() filter is + # used to avoid sending gradients for tensors that do not + # produce gradients. When TP>1, we partition the first + # activations/gradients across TP ranks to save communication + # volume and memory. That partitioned tensor is represented as + # two tensors: a 1/TPth chunk of the original data and also a + # small LongTensor storing the metadata used to reconstruct on + # the other side. When combined, the floating point filter also + # filtered out the metadata tensor. This quick (hacky) fix just + # branches on is_grad_partitioned so we don't filter out the + # metadata tensor. + if self.is_grad_partitioned: + sizes_and_dtypes = [(list(t.size()), t.dtype) + for t in outputs[:2]] + [(list(t.size()), t.dtype) + for t in outputs[2:] if t.is_floating_point()] + else: + sizes_and_dtypes = [(list(t.size()), t.dtype) for t in outputs if t.is_floating_point()] + + self.grad_layer = [ + self._allocate_or_extend_buffers(i, size, dtype) + for i, (size, dtype) in enumerate(sizes_and_dtypes) + ] + + if isinstance(self.grad_layer, torch.Tensor): + p2p.recv(self.grad_layer, self.next_stage) + else: + assert isinstance(outputs, tuple) + for idx, buffer in enumerate(self.grad_layer): + # XXX GPT-2 hack + if self.is_grad_partitioned and idx == 0 and buffer.dtype != torch.long: + buffer.data = torch.zeros(buffer.size(), dtype=torch.long, device=self.device) + p2p.recv(buffer, self.next_stage) + + if self.wall_clock_breakdown(): + self.timers(PIPE_RECV_GRAD_TIMER).stop() + + def _exec_optimizer_step(self, lr_kwargs=None): + if self.wall_clock_breakdown(): + self.timers(STEP_MICRO_TIMER).start() + self.timers(STEP_GLOBAL_TIMER).start() + self.mem_status('BEFORE STEP', reset_max=True) + + self._force_grad_boundary = True + self._take_model_step(lr_kwargs) + self._force_grad_boundary = False + + self.mem_status('AFTER STEP') + + if self.global_rank == 0 and self.monitor.enabled: + self.summary_events = [(f'Train/Samples/lr', self.get_lr()[0], self.global_samples)] + if self.fp16_enabled() and hasattr(self.optimizer, 'cur_scale'): + self.summary_events.append( + (f'Train/Samples/loss_scale', self.optimizer.cur_scale, self.global_samples)) + self.monitor.write_events(self.summary_events) + + if self.wall_clock_breakdown(): + self.timers(STEP_MICRO_TIMER).stop() + self.timers(STEP_GLOBAL_TIMER).stop() + if self.global_steps % self.steps_per_print() == 0: + self.timers.log([ + BATCH_INPUT_TIMER, + FORWARD_MICRO_TIMER, + BACKWARD_MICRO_TIMER, + BACKWARD_INNER_MICRO_TIMER, + BACKWARD_REDUCE_MICRO_TIMER, + STEP_MICRO_TIMER, + ]) + if self.global_steps % self.steps_per_print() == 0: + self.timers.log([ + FORWARD_GLOBAL_TIMER, + BACKWARD_GLOBAL_TIMER, + BACKWARD_INNER_GLOBAL_TIMER, + BACKWARD_REDUCE_GLOBAL_TIMER, + STEP_GLOBAL_TIMER, + ]) + + def _allocate_zeros(self, shape, **kwargs): + """ Allocate a tensor of zeros on the engine's device. + + Arguments: + shape: the shape of the tensor to allocate + kwargs: passed to torch.zeros() + + Returns: + A tensor from torch.zeros() allocated on self.device. + """ + if "dtype" not in kwargs: + if self.fp16_enabled(): + kwargs["dtype"] = torch.half + if self.bfloat16_enabled(): + kwargs["dtype"] = torch.bfloat16 + + return torch.zeros(shape, device=self.device, **kwargs) + + def _allocate_buffer(self, shape, num_buffers=-1, **kwargs): + buffers = [] + if num_buffers == -1: + num_buffers = self.num_pipe_buffers + for count in range(num_buffers): + buffers.append(self._allocate_zeros(shape, **kwargs)) + return buffers + + def _allocate_or_extend_buffers(self, idx, shape, dtype): + numel = reduce(mul, shape) if len(shape) > 0 else 1 + if len(self._grad_layer_buf) <= idx or self._grad_layer_buf[idx].numel() < numel: + new_buf = self._allocate_buffer(shape, dtype=dtype, num_buffers=1)[0] + if len(self._grad_layer_buf) <= idx: + self._grad_layer_buf.append(new_buf) + else: + self._grad_layer_buf[idx] = new_buf + return self._grad_layer_buf[idx] + else: + return self._grad_layer_buf[idx].flatten()[:numel].view(shape) + + def forward(self, *args, **kwargs): + """Disabled for pipeline parallel training. See ``train_batch()``. """ + raise PipelineError("Only train_batch() is accessible in pipeline mode.") + + def backward(self, *args, **kwargs): + """Disabled for pipeline parallel training. See ``train_batch()``. """ + raise PipelineError("Only train_batch() is accessible in pipeline mode.") + + def step(self, *args, **kwargs): + """Disabled for pipeline parallel training. See ``train_batch()``. """ + raise PipelineError("Only train_batch() is accessible in pipeline mode.") + + def mem_status(self, msg, print_rank=-1, reset_max=False): + return + global mem_alloced, mem_cached + if not self.global_steps == 0 or not self.global_steps == 9: + #return + pass + if self.mpu.get_data_parallel_rank() != 0: + return + + if self.global_rank != 0: + return + + rank = self.global_rank + if print_rank != -1 and rank != print_rank: + return + + get_accelerator().synchronize() + + if reset_max: + get_accelerator().reset_max_memory_cached() + get_accelerator().reset_max_memory_allocated() + + new_alloced = get_accelerator().memory_allocated() + new_cached = get_accelerator().memory_cached() + + delta_alloced = new_alloced - mem_alloced + delta_cached = new_cached - mem_cached + + mem_cached = new_cached + mem_alloced = new_alloced + + max_alloced = get_accelerator().max_memory_allocated() + max_cached = get_accelerator().max_memory_cached() + + # convert to GB for printing + new_alloced /= 1024**3 + new_cached /= 1024**3 + delta_alloced /= 1024**3 + delta_cached /= 1024**3 + max_alloced /= 1024**3 + max_cached /= 1024**3 + + print( + f'RANK={rank} STAGE={self.stage_id} STEP={self.global_steps} MEMSTATS', msg, + f'current alloc={new_alloced:0.4f}GB (delta={delta_alloced:0.4f}GB max={max_alloced:0.4f}GB) ' + f'current cache={new_cached:0.4f}GB (delta={delta_cached:0.4f}GB max={max_cached:0.4f}GB)') + + def module_state_dict(self, exclude_frozen_parameters=False): + """Override hack to save a pipe model and return the directory path of the save. + + This method should only be called by DeepSpeed's ``save_checkpoint()``. The + recommended way of saving a ``PipelineModule`` outside of ``save_checkpoint()`` + is ``save_state_dict()``. + + Returns: + None + """ + assert isinstance(self.module, PipelineModule) + assert self._curr_ckpt_path is not None, \ + "PipelineEngine expects module_state_dict() to be called from save_checkpoint()" + + self.module.save_state_dict(self._curr_ckpt_path, + checkpoint_engine=self.checkpoint_engine, + exclude_frozen_params=exclude_frozen_parameters) + return None + + def load_module_state_dict(self, checkpoint, strict=True, custom_load_fn=None, fetch_z3_params=False): + """Override hack to instead use a directory path. + + This is important because pipeline models checkpoint by layer instead of rank. + + If ``state_dict`` is not ``None`` or a ``str``, we revert to ``super()`` expecting a ``dict``. + + Args: + state_dict (str, None): unused + strict (bool, optional): Strict state loading. Defaults to True. + """ + assert custom_load_fn is None, "custom_load_fn not supported w. pipeline parallelism" + state_dict = checkpoint if self.has_moe_layers else checkpoint['module'] + if (state_dict is not None) and (not isinstance(state_dict, str)): + super().load_module_state_dict(state_dict, strict) + return + + self.module.load_state_dir(load_dir=self._curr_ckpt_path, + strict=strict, + checkpoint_engine=self.checkpoint_engine) + + # A map of PipeInstruction types to methods. Each method will be executed with the + # kwargs provided to the PipeInstruction from the scheduler. + _INSTRUCTION_MAP = { + schedule.OptimizerStep: _exec_optimizer_step, + schedule.ReduceGrads: _exec_reduce_grads, + schedule.ReduceTiedGrads: _exec_reduce_tied_grads, + schedule.LoadMicroBatch: _exec_load_micro_batch, + schedule.ForwardPass: _exec_forward_pass, + schedule.BackwardPass: _exec_backward_pass, + schedule.SendActivation: _exec_send_activations, + schedule.RecvActivation: _exec_recv_activations, + schedule.SendGrad: _exec_send_grads, + schedule.RecvGrad: _exec_recv_grads, + } + + def _exec_schedule(self, pipe_schedule): + # Reserve and reset buffers. + self._reserve_pipe_buffers(pipe_schedule.num_pipe_buffers()) + self.fwd_outputs = [] + + # For each step in the schedule + for step_cmds in pipe_schedule: + # For each instruction in the step + for cmd in step_cmds: + if type(cmd) not in self._INSTRUCTION_MAP: + raise RuntimeError(f'{self.__class__.__name__} does not understand instruction {repr(cmd)}') + + # Equivalent to: self._exec_forward_pass(buffer_id=0) + self._exec_instr = MethodType(self._INSTRUCTION_MAP[type(cmd)], self) + self._exec_instr(**cmd.kwargs) + + def get_additional_losses(self): + return self.agg_additional_losses diff --git a/toolbox/DeepSpeed/v0.15.3/patches/deepspeed/runtime/pipe/module.py b/toolbox/DeepSpeed/v0.15.3/patches/deepspeed/runtime/pipe/module.py new file mode 100644 index 0000000000000000000000000000000000000000..b19601683971c4e5f5369e8592e6c890e71e479e --- /dev/null +++ b/toolbox/DeepSpeed/v0.15.3/patches/deepspeed/runtime/pipe/module.py @@ -0,0 +1,716 @@ +#!/usr/bin/env python3 +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import os +import glob + +import re as regex + +from functools import partial + +import torch +import torch.nn as nn +from deepspeed import comm as dist + +from deepspeed.utils import logger +from .. import utils as ds_utils +from ..activation_checkpointing import checkpointing +from .topology import PipeDataParallelTopology, PipelineParallelGrid +from deepspeed.runtime.state_dict_factory import SDLoaderFactory +from deepspeed.accelerator import get_accelerator +from deepspeed.checkpoint.utils import clone_tensors_for_torch_save + + +class PipelineError(Exception): + """Errors related to the use of deepspeed.PipelineModule """ + + +class LayerSpec: + """Building block for specifying pipeline-parallel modules. + + LayerSpec stores the type information and parameters for each stage in a + PipelineModule. For example: + + .. code-block:: python + + nn.Sequence( + torch.nn.Linear(self.in_dim, self.hidden_dim, bias=False), + torch.nn.Linear(self.hidden_hidden, self.out_dim) + ) + + becomes + + .. code-block:: python + + layer_specs = [ + LayerSpec(torch.nn.Linear, self.in_dim, self.hidden_dim, bias=False), + LayerSpec(torch.nn.Linear, self.hidden_hidden, self.out_dim)] + ] + """ + + def __init__(self, typename, *module_args, **module_kwargs): + self.typename = typename + self.module_args = module_args + self.module_kwargs = module_kwargs + + if not issubclass(typename, nn.Module): + raise RuntimeError('LayerSpec only supports torch.nn.Module types.') + + if dist.is_initialized(): + self.global_rank = dist.get_rank() + else: + self.global_rank = -1 + + def __repr__(self): + return ds_utils.call_to_str(self.typename.__name__, self.module_args, self.module_kwargs) + + def build(self, log=False): + """Build the stored specification.""" + if log: + logger.info(f'RANK={self.global_rank} building {repr(self)}') + + return self.typename(*self.module_args, **self.module_kwargs) + + +class TiedLayerSpec(LayerSpec): + + def __init__(self, key, typename, *module_args, forward_fn=None, tied_weight_attr=['weight'], **module_kwargs): + super().__init__(typename, *module_args, **module_kwargs) + self.key = key + self.forward_fn = forward_fn + self.tied_weight_attr = [tied_weight_attr] if type(tied_weight_attr) == str else tied_weight_attr + + +class PipelineModule(nn.Module): + """Modules to be parallelized with pipeline parallelism. + + The key constraint that enables pipeline parallelism is the + representation of the forward pass as a sequence of layers + and the enforcement of a simple interface between them. The + forward pass is implicitly defined by the module ``layers``. The key + assumption is that the output of each layer can be directly fed as + input to the next, like a ``torch.nn.Sequence``. The forward pass is + implicitly: + + .. code-block:: python + + def forward(self, inputs): + x = inputs + for layer in self.layers: + x = layer(x) + return x + + .. note:: + Pipeline parallelism is not compatible with ZeRO-2 and ZeRO-3. + + Args: + layers (Iterable): A sequence of layers defining pipeline structure. Can be a ``torch.nn.Sequential`` module. + num_stages (int, optional): The degree of pipeline parallelism. If not specified, ``topology`` must be provided. + topology (``deepspeed.runtime.pipe.ProcessTopology``, optional): Defines the axes of parallelism axes for training. Must be provided if ``num_stages`` is ``None``. + loss_fn (callable, optional): Loss is computed ``loss = loss_fn(outputs, label)`` + seed_layers(bool, optional): Use a different seed for each layer. Defaults to False. + seed_fn(type, optional): The custom seed generating function. Defaults to random seed generator. + base_seed (int, optional): The starting seed. Defaults to 1234. + partition_method (str, optional): The method upon which the layers are partitioned. Defaults to 'parameters'. + custom_partition (list, optional): custom model layers in PP stage. + activation_checkpoint_interval (int, optional): The granularity activation checkpointing in terms of number of layers. 0 disables activation checkpointing. + activation_checkpoint_func (callable, optional): The function to use for activation checkpointing. Defaults to ``deepspeed.checkpointing.checkpoint``. + checkpointable_layers(list, optional): Checkpointable layers may not be checkpointed. Defaults to None which does not additional filtering. + dynamic_shape: Allows dynamic shapes of inputs. This might have a performance impact. + custom_recompute_layers_per_stage (list, optional): custom recompute layers in PP stage (for megatron-deepspeed). + """ + + def __init__(self, + layers, + num_stages=None, + topology=None, + loss_fn=None, + seed_layers=False, + seed_fn=None, + base_seed=1234, + partition_method='parameters', + custom_partition=None, + activation_checkpoint_interval=0, + activation_checkpoint_func=checkpointing.checkpoint, + checkpointable_layers=None, + dynamic_shape=False, + custom_recompute_layers_per_stage=None): + + super().__init__() + + if num_stages is None and topology is None: + raise RuntimeError('must provide num_stages or topology') + + self.micro_offset = 0 + + self.loss_fn = loss_fn + + self.checkpointable_layers = checkpointable_layers + if checkpointable_layers is not None: + assert isinstance(checkpointable_layers, list), "param `checkpointable_layers` must be type of list." + + self.seed_layers = seed_layers + self.seed_fn = seed_fn + self.base_seed = base_seed + if dist.get_rank() == 0: + try: + seed_str = self.seed_fn.__name__ + except AttributeError: + seed_str = None + print(f'SEED_LAYERS={self.seed_layers} BASE_SEED={self.base_seed} SEED_FN={seed_str}') + + # Setup world info + self.world_group = dist.new_group(ranks=range(dist.get_world_size())) + self.global_rank = dist.get_rank(group=self.world_group) + self.world_size = dist.get_world_size(group=self.world_group) + self.local_rank = int(os.environ.get("LOCAL_RANK", None)) + assert self.local_rank is not None + + if topology: + self._topo = topology + self.num_stages = self._topo.get_dim('pipe') + else: + self.num_stages = num_stages + if topology is None: + if self.world_size % self.num_stages != 0: + raise RuntimeError( + f'num_stages ({self.num_stages}) must divide distributed world size ({self.world_size})') + dp = self.world_size // num_stages + topology = PipeDataParallelTopology(num_pp=num_stages, num_dp=dp) + self._topo = topology + + # Construct communicators for pipeline topology + self._grid = PipelineParallelGrid(process_group=self.world_group, topology=self._topo) + + self.stage_id = self._topo.get_coord(self.global_rank).pipe + + # Initialize partition information + self._layer_specs = list(layers) + self._num_layers = len(self._layer_specs) + self._local_start = 0 + self._local_stop = None + self._partition_layers(method=partition_method, custom_partition=custom_partition) + + self.forward_funcs = [] + self.fwd_map = {} + self.tied_modules = nn.ModuleDict() + self.tied_weight_attrs = {} + + # Offset the random seed by the stage ID. + #newseed = get_accelerator().initial_seed() + self._grid.get_stage_id() + #ds_utils.set_random_seed(newseed) + + self.activation_checkpoint_interval = activation_checkpoint_interval + + self.activation_checkpoint_func = activation_checkpoint_func + + #storage for precomputed checkpointeble results + self.is_checkpointable_results = [] + self.is_checkpointable_results_interval = None + + # if configuration use_reentrant = False, self.activation_checkpoint_func will be set to ``checkpointing.non_reentrant_checkpoint`` + + #with torch.random.fork_rng(devices=[get_accelerator().current_device_name()]): + self._build() + self.to(get_accelerator().device_name(self.local_rank)) + + self.tied_comms = self._index_tied_modules() + self._synchronize_tied_weights() + + self.dynamic_shape = dynamic_shape + self.custom_activation_checkpoint = None + if custom_recompute_layers_per_stage is not None: + self.custom_activation_checkpoint = custom_recompute_layers_per_stage + + def _precompute_checkpointable_values(self): + if self.activation_checkpoint_interval > 0 and self.is_checkpointable_results_interval != self.activation_checkpoint_interval: + num_layers = len(self.forward_funcs) + self.interval_was_zero = False + for start_idx in range(0, num_layers, self.activation_checkpoint_interval): + end_idx = min(start_idx + self.activation_checkpoint_interval, num_layers) + funcs = self.forward_funcs[start_idx:end_idx] + self.is_checkpointable_results.append(self._is_checkpointable(funcs)) + self.is_checkpointable_results_interval = self.activation_checkpoint_interval + + def _build(self): + specs = self._layer_specs + + for local_idx, layer in enumerate(specs[self._local_start:self._local_stop]): + layer_idx = local_idx + self._local_start + if self.seed_layers: + if self.seed_fn: + self.seed_fn(self.base_seed + layer_idx) + else: + ds_utils.set_random_seed(self.base_seed + layer_idx) + + # Recursively build PipelineModule objects + if isinstance(layer, PipelineModule): + raise NotImplementedError('RECURSIVE BUILD NOT YET IMPLEMENTED') + + # LayerSpec objects contain an nn.Module that should be allocated now. + elif isinstance(layer, nn.Module): + name = str(layer_idx) + self.forward_funcs.append(layer) + self.fwd_map.update({name: len(self.forward_funcs) - 1}) + self.add_module(name, layer) + + # TiedLayerSpec objects contain an nn.Module that should be allocated now. + elif isinstance(layer, TiedLayerSpec): + # Build and register the module if we haven't seen it before. + if layer.key not in self.tied_modules: + self.tied_modules[layer.key] = layer.build() + self.tied_weight_attrs[layer.key] = layer.tied_weight_attr + + if layer.forward_fn is None: + # Just use forward() + self.forward_funcs.append(self.tied_modules[layer.key]) + else: + # User specified fn with args (module, input) + self.forward_funcs.append(partial(layer.forward_fn, self.tied_modules[layer.key])) + + # LayerSpec objects contain an nn.Module that should be allocated now. + elif isinstance(layer, LayerSpec): + module = layer.build() + name = str(layer_idx) + self.forward_funcs.append(module) + self.fwd_map.update({name: len(self.forward_funcs) - 1}) + self.add_module(name, module) + + # Last option: layer may be a functional (e.g., lambda). We do nothing in + # that case and just use it in forward() + else: + self.forward_funcs.append(layer) + + # All pipeline parameters should be considered as model parallel in the context + # of our FP16 optimizer + for p in self.parameters(): + p.ds_pipe_replicated = False + + def _get_frozen_parameter_names(self, layer): + """ Get names of frozen parameters in the layer. + + Returns: + A list of frozen parameter names + """ + if isinstance(layer, LayerSpec): + l = layer.build() + return [n for n, p in l.named_parameters() if not p.requires_grad] + elif isinstance(layer, nn.Module): + return [n for n, p in layer.named_parameters() if not p.requires_grad] + + return [] + + def _count_layer_params(self): + """Count the trainable parameters in individual layers. + + This routine will only build one layer at a time. + + Returns: + A list of the number of parameters in each layer. + """ + param_counts = [0] * len(self._layer_specs) + for idx, layer in enumerate(self._layer_specs): + if isinstance(layer, LayerSpec): + l = layer.build() + params = filter(lambda p: p.requires_grad, l.parameters()) + param_counts[idx] = sum(p.numel() for p in params) + elif isinstance(layer, nn.Module): + params = filter(lambda p: p.requires_grad, layer.parameters()) + param_counts[idx] = sum(p.numel() for p in params) + return param_counts + + def _find_layer_type(self, layername): + idxs = [] + typeregex = regex.compile(layername, regex.IGNORECASE) + for idx, layer in enumerate(self._layer_specs): + name = None + if isinstance(layer, LayerSpec): + name = layer.typename.__name__ + elif isinstance(layer, nn.Module): + name = layer.__class__.__name__ + else: + try: + name = layer.__name__ + except AttributeError: + continue + if typeregex.search(name): + idxs.append(idx) + + if len(idxs) == 0: + raise RuntimeError(f"Partitioning '{layername}' found no valid layers to partition.") + return idxs + + def forward(self, forward_input): + # We need to offset the seed by the microbatch ID. Save it in a local var to + # ensure it is preserved in the closure. Otherwise checkpointed forward funcs + # will see a different offset. + self.micro_offset += 1 + + def exec_range_func(start, end): + ''' Helper function to be used with checkpoint() + Adapted from torch.utils.checkpoint:checkpoint_sequential() + ''' + local_micro_offset = self.micro_offset + 1 + + def exec_func(*inputs): + # Single tensor inputs need to be unwrapped + if len(inputs) == 1: + inputs = inputs[0] + for idx, layer in enumerate(self.forward_funcs[start:end]): + self.curr_layer = idx + self._local_start + if self.seed_layers: + new_seed = (self.base_seed * local_micro_offset) + self.curr_layer + if self.seed_fn: + self.seed_fn(new_seed) + else: + ds_utils.set_random_seed(new_seed) + + inputs = layer(inputs) + return inputs + + return exec_func + + num_layers = len(self.forward_funcs) + if self.custom_activation_checkpoint is not None: + recompute_layers = self.custom_activation_checkpoint[self._grid.get_pipe_parallel_rank()] + ## 获取pp stage 第一层transformer的序号 + for i in range(num_layers-1): + f = self.forward_funcs[i:i+1] + if self._is_checkpointable(f): + break + if recompute_layers != 0: + activation_checkpoint_interval = (num_layers - i) // recompute_layers + ## 当重计算层数大于transformer 层数,则全部重计算 + if activation_checkpoint_interval <= 0: + activation_checkpoint_interval = 1 + else: + activation_checkpoint_interval = 0 + else: + activation_checkpoint_interval = self.activation_checkpoint_interval + checkpoint_num_layers = 0 + if activation_checkpoint_interval == 0: + func = exec_range_func(0, len(self.forward_funcs)) + x = func(forward_input) + else: + num_layers = len(self.forward_funcs) + x = forward_input + + for start_idx, is_checkpointable_result in \ + zip(range(0, num_layers, self.activation_checkpoint_interval), self.is_checkpointable_results): + if self.custom_activation_checkpoint is not None: + if checkpoint_num_layers < recompute_layers - 1: + end_idx = min(start_idx + activation_checkpoint_interval, num_layers) + else: + end_idx = num_layers + else: + end_idx = min(start_idx + activation_checkpoint_interval, num_layers) + + funcs = self.forward_funcs[start_idx:end_idx] + # Since we either pass tensors or tuples of tensors without unpacking, we + # need to be careful not to double-wrap tensors with tuple. + if not isinstance(x, tuple): + x = (x, ) + + if is_checkpointable_result: + checkpoint_num_layers += 1 + x = self.activation_checkpoint_func(exec_range_func(start_idx, end_idx), *x) + else: + x = exec_range_func(start_idx, end_idx)(*x) + ## 计算完最后一组时,退出 + if end_idx == num_layers: + break + return x + + def _partition_layers(self, method='uniform', custom_partition=None): + num_stages = self._topo.get_dim('pipe') + stage_id = self._topo.get_coord(self.global_rank).pipe + + if self.global_rank == 0: + logger.info(f'Partitioning pipeline stages with method {method}') + + method = method.lower() + + # Each stage gets a simple uniform number of layers. + if method == 'uniform': + num_layers = len(self._layer_specs) + self.parts = ds_utils.partition_uniform(num_items=num_layers, num_parts=num_stages) + elif method == 'parameters': + param_counts = self._count_layer_params() + self.parts = ds_utils.partition_balanced(weights=param_counts, num_parts=num_stages) + elif method.startswith('type:'): + layertype = method.split(':')[1] + binary_weights = [0] * len(self._layer_specs) + for idx in self._find_layer_type(layertype): + binary_weights[idx] = 1 + self.parts = ds_utils.partition_balanced(weights=binary_weights, num_parts=num_stages) + ## custom partition of layers to pp stage + elif method == 'custom': + num_layers = len(self._layer_specs) + if not custom_partition: + raise ValueError(f"argument of custom-partition should not be None, it should be number layers of every PP_stage. Or choose other partition-methon like: uniform") + if isinstance(custom_partition, list): + custom_part = custom_partition + if len(custom_part) != num_stages: + raise ValueError(f"lenth of custom-partition should be equal to PP_stages. lenth of custom={len(custom_part)} num_stages={num_stages}") + if sum(custom_part) != num_layers: + raise ValueError(f"sum of custom-partition layers should be equal to total model layers. sum(custom_part)={sum(custom_part)} num_layers={num_layers}") + self.parts = [0]*(num_stages + 1) + for i in range(num_stages): + self.parts[i+1] = self.parts[i] + custom_part[i] + elif method == 'profile': + raise NotImplementedError(f'Partitioning method {method} not implemented.') + else: + raise NotImplementedError(f'Partitioning method {method} not implemented.') + + # Print some information on the partitioning. + if self.global_rank == 0: + for stage in range(num_stages): + start = self.parts[stage] + stop = self.parts[stage + 1] + print(f'stage={stage} layers={stop - start}') + for idx, layer in enumerate(self._layer_specs[start:stop]): + name = str(layer) + if isinstance(layer, LayerSpec): + name = layer.typename.__name__ + if isinstance(layer, nn.Module): + name = layer.__class__.__name__ + else: + try: + name = layer.__name__ + except AttributeError: + pass + print(f' {idx+start:2d}: {name}') + if self.loss_fn: + try: + print(f' loss: {self.loss_fn.__name__}') + except AttributeError: + print(f' loss: {self.loss_fn.__class__.__name__}') + + self._set_bounds(start=self.parts[stage_id], stop=self.parts[stage_id + 1]) + + def allreduce_tied_weight_gradients(self): + '''All reduce the gradients of the tied weights between tied stages''' + for key, comm in self.tied_comms.items(): + for attr_name in comm['weight_attr']: + weight = getattr(self.tied_modules[key], attr_name) + dist.all_reduce(weight.grad, group=comm['group']) + + def get_tied_weights_and_groups(self): + weight_group_list = [] + for key, comm in self.tied_comms.items(): + for attr_name in comm['weight_attr']: + weight = getattr(self.tied_modules[key], attr_name) + weight_group_list.append((weight, comm['group'])) + return weight_group_list + + def _synchronize_tied_weights(self): + for key, comm in self.tied_comms.items(): + for attr_name in comm['weight_attr']: + dist.broadcast( + getattr(comm['module'], attr_name), + src=min(comm['ranks']), + group=comm['group'], + ) + + def _index_tied_modules(self): + ''' Build communication structures for tied modules. ''' + tied_comms = {} + if self._topo.get_dim('pipe') == 1: + return tied_comms + + specs = self._layer_specs + tie_keys = set(s.key for s in specs if isinstance(s, TiedLayerSpec)) + for key in tie_keys: + # Find the layers that the tied module appears in + tied_layers = [] + for idx, layer in enumerate(specs): + if isinstance(layer, TiedLayerSpec) and layer.key == key: + tied_layers.append(idx) + # Find all stages with this tied module + # TODO: Would be nice to remove the nested data/model parallelism loops and + # TODO: instead generalize in some way, since we really just care about the + # TODO: stage that owns the tied layer. Then loop over each (dp, mp, ...) + # TODO: fiber to generate process groups. + tied_stages = set(self.stage_owner(idx) for idx in tied_layers) + for dp in range(self._grid.data_parallel_size): + for mp in range(self._grid.get_slice_parallel_world_size()): + tied_ranks = [] + for s in sorted(tied_stages): + if self._grid.get_slice_parallel_world_size() > 1: + tied_ranks.append(self._grid.stage_to_global(stage_id=s, data=dp, model=mp)) + else: + tied_ranks.append(self._grid.stage_to_global(stage_id=s, data=dp)) + group = dist.new_group(ranks=tied_ranks) + + # Record this tied module if we own a local copy of it. + if self.global_rank in tied_ranks: + assert key in self.tied_modules + if key in self.tied_modules: + tied_comms[key] = { + 'ranks': tied_ranks, + 'group': group, + 'weight_attr': self.tied_weight_attrs[key], + 'module': self.tied_modules[key], + } + # Only count the tied module once in the eyes of the FP16 optimizer + if self.global_rank != tied_ranks[0]: + for p in self.tied_modules[key].parameters(): + p.ds_pipe_replicated = True + ''' + if len(tied_comms) > 0: + print(f'RANK={self.global_rank} tied_comms={tied_comms}') + ''' + + return tied_comms + + def partitions(self): + return self.parts + + def stage_owner(self, layer_idx): + assert 0 <= layer_idx < self._num_layers + for stage in range(self._topo.get_dim('pipe')): + if self.parts[stage] <= layer_idx < self.parts[stage + 1]: + return stage + raise RuntimeError(f'Layer {layer_idx} not owned? parts={self.parts}') + + def _set_bounds(self, start=None, stop=None): + """Manually define the range of layers that will be built on this process. + + These boundaries are treated as list slices and so start is inclusive and stop is + exclusive. The default of None for both results in all layers being built + locally. + """ + self._local_start = start + self._local_stop = stop + + def set_checkpoint_interval(self, interval): + assert interval >= 0 + self.checkpoint_interval = interval + + def topology(self): + """ ProcessTopology object to query process mappings. """ + return self._topo + + def mpu(self): + return self._grid + + def num_pipeline_stages(self): + return self._topo.get_dim('pipe') + + def ckpt_prefix(self, checkpoints_path, tag): + """Build a prefix for all checkpoint files written by this module. """ + # All checkpoint files start with this + rank_name = 'module' + + # Data parallelism is omitted from the naming convention because we are agnostic + # to this in the checkpoint. + omit_dims = frozenset(['data']) + axes = [a for a in self._grid._topo.get_axis_names() if a not in omit_dims] + for dim in axes: + rank = getattr(self._grid._topo.get_coord(rank=self.global_rank), dim) + rank_name += f'-{dim}_{rank:02d}' + + ckpt_name = os.path.join(checkpoints_path, str(tag), rank_name) + return ckpt_name + + def ckpt_layer_path(self, ckpt_dir, local_layer_idx): + """Customize a prefix for a specific pipeline module layer. """ + idx = local_layer_idx + self._local_start + layer_ckpt_path = os.path.join(ckpt_dir, f'layer_{idx:02d}') + rank_repr = self._grid._topo.get_rank_repr(rank=self.global_rank) + if rank_repr != '': + layer_ckpt_path += f'-{rank_repr}' + layer_ckpt_path += '-model_states.pt' + return layer_ckpt_path + + def ckpt_layer_path_list(self, ckpt_dir, local_layer_idx): + """Get all ckpt file list for a specific pipeline module layer. """ + idx = local_layer_idx + self._local_start + layer_ckpt_path = os.path.join(ckpt_dir, f'layer_{idx:02d}-') + layer_ckpt_path += "*model_states.pt" + ckpt_files = glob.glob(layer_ckpt_path) + ckpt_files.sort() + return ckpt_files + + def save_state_dict(self, save_dir, checkpoint_engine, exclude_frozen_params=False): + # Processes having the same model parallel rank on different data parallel instances + # have identical layer weights. We can distribute the task of saving the layer weights + # among the data parallel ranks. For example, if a pipeline stage has 9 layers and + # if there are 2 data parallel instances, rank 0 will save the first 5 layers and + # rank 1 will save the last 4. + dp_rank = self._grid.data_parallel_id + dp_size = self._grid.data_parallel_size + num_layers = len(self.forward_funcs) + if self.checkpoint_parallel_write_pipeline: + # spread layers evenly across data parallel ranks + offsets = ds_utils.partition_uniform(num_layers, dp_size) + start, end = offsets[dp_rank], offsets[dp_rank + 1] + else: + # data parallel rank 0 writes all layers + if dp_rank != 0: + return + start, end = 0, num_layers + layer_list = self.forward_funcs[start:end] + + checkpoint_engine.makedirs(save_dir, exist_ok=True) + for idx, layer in enumerate(layer_list): + model_ckpt_path = self.ckpt_layer_path(save_dir, start + idx) + if not hasattr(layer, 'state_dict'): + continue + + orig_state_dict = layer.state_dict() + if exclude_frozen_params: + for n in self._get_frozen_parameter_names(layer): + del orig_state_dict[n] + final_state_dict = clone_tensors_for_torch_save(orig_state_dict) + checkpoint_engine.save(final_state_dict, model_ckpt_path) + + def load_state_dir(self, load_dir, checkpoint_engine, strict=True): + for idx, layer in enumerate(self.forward_funcs): + # Functions, etc. will not have state_dicts + if not hasattr(layer, 'load_state_dict'): + continue + + # get all checkpoint files for the layer. + model_ckpt_list = self.ckpt_layer_path_list(load_dir, idx) + mp_rank = self._grid.get_slice_parallel_rank() + mp_world_size = self._grid.get_slice_parallel_world_size() + + sd_loader = SDLoaderFactory.get_sd_loader(model_ckpt_list, + version=2.0, + checkpoint_engine=checkpoint_engine) + load_path, checkpoint, _ = sd_loader.load(mp_world_size, mp_rank, module_key=None, is_pipe_parallel=True) + + layer.load_state_dict(checkpoint, strict=strict) + + # if self._grid.data_parallel_id == 0: + # logger.info( + # f'RANK={self.global_rank} Loaded layer={idx+self._local_start} file={load_path}' + # ) + + self._synchronize_tied_weights() + + def _is_checkpointable(self, funcs): + + if self.activation_checkpoint_func is not checkpointing.non_reentrant_checkpoint: + # This hook excludes the embedding layer + # because only non_reentrant_checkpoint can accept inputs with requires_grad=False + # otherwise, the backward of the embedding layer won't receive gradients. + if self.__class__.__name__ in ('GPTModelPipe', 'GPT2ModelPipe'): + return all('ParallelTransformerLayerPipe' in f.__class__.__name__ for f in funcs) + if self.checkpointable_layers is not None: + return all(f.__class__.__name__ in self.checkpointable_layers for f in funcs) + params = [f.parameters() for f in funcs if isinstance(f, torch.nn.Module)] + return any(len(list(p)) > 0 for p in params) + + def get_additional_losses(self): + """ Returns model specific additional losses for reporting + + Return a dictionary of {"loss name": loss_value} or None if no additional losses. + """ + return None diff --git a/toolbox/DeepSpeed/v0.15.3/patches/deepspeed/runtime/zero/parameter_offload.py b/toolbox/DeepSpeed/v0.15.3/patches/deepspeed/runtime/zero/parameter_offload.py new file mode 100644 index 0000000000000000000000000000000000000000..cc4cda92bbf6b369954e822b7804ebc77c842b1f --- /dev/null +++ b/toolbox/DeepSpeed/v0.15.3/patches/deepspeed/runtime/zero/parameter_offload.py @@ -0,0 +1,498 @@ +#!/usr/bin/env python3 +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import sys +import torch +from collections import OrderedDict +from deepspeed.utils import z3_leaf_module +from deepspeed.runtime.utils import see_memory_usage +from deepspeed.runtime.zero.utils import apply_to_tensors_only, is_zero_param +from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum +from deepspeed.runtime.zero.partition_parameters import _init_external_params +from deepspeed.runtime.zero.partition_parameters import * +from deepspeed.runtime.zero.partitioned_param_coordinator import PartitionedParameterCoordinator, InflightParamRegistry, iter_params +from deepspeed.accelerator import get_accelerator + +FWD_MODULE_STACK = list() + + +#for each tensor in outputs run the forward_function and register backward_function as hook +def _apply_forward_and_backward_to_tensors_only(module, forward_function, backward_function, outputs): + if type(outputs) is tuple: + touched_outputs = [] + for output in outputs: + touched_output = _apply_forward_and_backward_to_tensors_only(module, forward_function, backward_function, + output) + touched_outputs.append(touched_output) + return tuple(touched_outputs) + elif type(outputs) is torch.Tensor: + forward_function(outputs) + if outputs.requires_grad: + outputs.register_hook(backward_function) + return outputs + else: + return outputs + + +class ZeROOrderedDict(OrderedDict): + + def __init__(self, parent_module=None, *args, **kwargs): + """A replacement for ``collections.OrderedDict`` to detect external ZeRO params. + + Args: + parent_module (``collections.OrderedDict``): the collection to replace + """ + + super().__init__(*args, **kwargs) + self._parent_module = parent_module + self._in_forward = False + + def __getitem__(self, key): + param = super().__getitem__(key) + + # Params can be registered as None (e.g., bias) + if param is None: + return param + + if hasattr(param, "ds_status") and param.ds_status == ZeroParamStatus.NOT_AVAILABLE: + if self._parent_module._parameters._in_forward: + register_external_parameter(FWD_MODULE_STACK[-1], param) + param.all_gather() + print_rank_0(f'Registering external parameter from getter {key} ds_id = {param.ds_id}', force=False) + + return param + + +def _inject_parameters(module, cls): + for module in module.modules(): + if cls == ZeROOrderedDict: + new_param = cls(parent_module=module) + else: + new_param = cls() + + for key, param in module._parameters.items(): + new_param[key] = param + module._parameters = new_param + + +class DeepSpeedZeRoOffload(object): + + def __init__( + self, + module, + timers, + ds_config, + overlap_comm=True, + prefetch_bucket_size=50000000, + max_reuse_distance=1000000000, + max_live_parameters=1000000000, + param_persistence_threshold=100000, + model_persistence_threshold=sys.maxsize, + dp_process_group=None, + offload_param_config=None, + mpu=None, + zero_param_parallel_group=None, + zero_quantized_weights=False, + zero_quantized_nontrainable_weights=False, + ): + + see_memory_usage("DeepSpeedZeRoOffload initialize [begin]", force=True) + + print_rank_0(f"initialized {__class__.__name__} with args: {locals()}", force=False) + + self.module = module + self.timers = timers + self.dtype = list(module.parameters())[0].dtype + self.dp_process_group = dp_process_group + self.offload_device = None + self.offload_param_pin_memory = False + self.zero_param_parallel_group = zero_param_parallel_group + self.zero_quantized_weights = zero_quantized_weights + self.zero_quantized_nontrainable_weights = zero_quantized_nontrainable_weights + + if offload_param_config is not None and offload_param_config.device != OffloadDeviceEnum.none: + self.offload_device = offload_param_config.device + self.offload_param_pin_memory = offload_param_config.pin_memory + + self._convert_to_zero_parameters(ds_config, module, mpu) + + for m in module.modules(): + _init_external_params(m) + + _inject_parameters(module, ZeROOrderedDict) + + self.param_numel_persistence_threshold = int(param_persistence_threshold) + self.model_persistence_threshold = int(model_persistence_threshold) + self.persistent_parameters = self.mark_persistent_parameters(self.param_numel_persistence_threshold, + self.model_persistence_threshold) + + self.param_coordinators = {} + self._prefetch_bucket_sz = int(prefetch_bucket_size) + self._max_reuse_distance_in_numel = int(max_reuse_distance) + self._max_available_parameters_in_numel = int(max_live_parameters) + self.__allgather_stream = None if get_accelerator().is_synchronized_device() else get_accelerator().Stream( + ) if overlap_comm else get_accelerator().default_stream() + + if not hasattr(module, "ds_inflight_param_registry"): + module.ds_inflight_param_registry = dict() + # we need two registries, one for training and one for eval. They will be used when creating PartitionedParameterCoordinator + module.ds_inflight_param_registry[True] = InflightParamRegistry() + module.ds_inflight_param_registry[False] = InflightParamRegistry() + self.__inflight_param_registry = module.ds_inflight_param_registry + + self.forward_hooks = [] + self.backward_hooks = [] + self.setup_zero_stage3_hooks() + print_rank_0( + f'Created module hooks: forward = {len(self.forward_hooks)}, backward = {len(self.backward_hooks)}', + force=False) + + see_memory_usage("DeepSpeedZeRoOffload initialize [end]", force=True) + + self.use_ds_opt = True if os.getenv("OPT_MODULES", None) else False + print_rank_0(f"Stage_3 param_numel_persistence_threshold {self.param_numel_persistence_threshold}", force=self.use_ds_opt) + print_rank_0(f"Stage_3 max_reuse_distance: {self._max_reuse_distance_in_numel}", force=self.use_ds_opt) + print_rank_0(f"Stage_3 max_live_parameters: {self._max_available_parameters_in_numel}", force=self.use_ds_opt) + + @instrument_w_nvtx + def partition_all_parameters(self): + """Partitioning Parameters that were not partitioned usually if parameters + of modules whose input parameters do not require grad computation do not + trigger post call and will therefore will remain unpartitioned""" + self.get_param_coordinator(training=self.module.training).release_and_reset_all(self.module) + for param in iter_params(self.module, recurse=True): + if not self.use_ds_opt: + if param.ds_status != ZeroParamStatus.NOT_AVAILABLE: + raise RuntimeError(f"{param.ds_summary()} expected to be released") + else: + if param.ds_status == ZeroParamStatus.HOLD_COMPLETE: + continue + elif param.ds_status != ZeroParamStatus.NOT_AVAILABLE: + raise RuntimeError(f"{param.ds_summary()} expected to be released") + + def get_param_coordinator(self, training): + if not training in self.param_coordinators: + self.param_coordinators[training] = PartitionedParameterCoordinator( + prefetch_bucket_sz=self._prefetch_bucket_sz, + max_reuse_distance_in_numel=self._max_reuse_distance_in_numel, + max_available_parameters_in_numel=self._max_available_parameters_in_numel, + allgather_stream=self.__allgather_stream, + inflight_param_registry=self.__inflight_param_registry[training], + prefetch_nvme=self.offload_device == OffloadDeviceEnum.nvme, + timers=self.timers, + zero_quantized_weights=self.zero_quantized_weights, + zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights, + ) + + return self.param_coordinators[training] + + def empty_partition_cache(self): + self.partition_all_parameters() + + def _convert_to_zero_parameters(self, ds_config, module, mpu): + non_zero_params = [p for p in module.parameters() if not is_zero_param(p)] + if non_zero_params: + zero_params = [p for p in module.parameters() if is_zero_param(p)] + if zero_params: + zero_params[0].convert_to_zero_parameters(param_list=non_zero_params) + else: + group = None + if mpu: + group = mpu.get_data_parallel_group() + + Init(module=module, + data_parallel_group=group, + dtype=self.dtype, + config_dict_or_path=ds_config, + remote_device=self.offload_device, + pin_memory=self.offload_param_pin_memory, + mpu=mpu, + zero_param_parallel_group=self.zero_param_parallel_group, + zero_quantized_weights=self.zero_quantized_weights, + zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights) + + def destroy(self): + self._remove_module_hooks() + + def _remove_module_hooks(self): + num_forward_hooks = len(self.forward_hooks) + num_backward_hooks = len(self.backward_hooks) + + for hook in self.forward_hooks: + hook.remove() + + for hook in self.backward_hooks: + hook.remove() + + print_rank_0(f'Deleted module hooks: forward = {num_forward_hooks}, backward = {num_backward_hooks}', + force=False) + + def setup_zero_stage3_hooks(self): + self.hierarchy = 0 + + #reset step if in inference mode + @instrument_w_nvtx + def _end_of_forward_hook(module, *args): + + if not torch._C.is_grad_enabled(): + self.get_param_coordinator(training=False).reset_step() + + #likely one of them should be enough but just to be safe + self._register_hooks_recursively(self.module) + self.module.register_forward_hook(_end_of_forward_hook) + + # Add top module to stack trace + global FWD_MODULE_STACK + FWD_MODULE_STACK.append(self.module) + + def mark_persistent_parameters(self, param_threshold, model_threshold): + persistent_params = [] + total_persistent_parameters = 0 + params_count = 0 + for name, param in self.module.named_parameters(recurse=True): + if param.ds_numel + total_persistent_parameters > model_threshold: + continue + + if param.ds_numel <= param_threshold: + params_count += 1 + param.ds_persist = True + persistent_params.append(param) + total_persistent_parameters += param.ds_numel + + print_rank_0( + f"Parameter Offload: Total persistent parameters: {total_persistent_parameters} in {params_count} params", + force=True) + + return persistent_params + + def _register_hooks_recursively(self, module, count=[0]): + my_count = count[0] + module.id = my_count + + #print(f"{module.__class__} : {module.id}") + + if z3_leaf_module(module): + for param in module.parameters(): + param.ds_z3_leaf_module = module + else: + for child in module.children(): + count[0] = count[0] + 1 + self._register_hooks_recursively(child, count=count) + + @instrument_w_nvtx + def _pre_forward_module_hook(module, *args): + self.pre_sub_module_forward_function(module, *args) + + @instrument_w_nvtx + def _post_forward_module_hook(module, input, output): + + global FWD_MODULE_STACK + FWD_MODULE_STACK.pop() + if output is None: + output = [] + elif not isinstance(output, (list, tuple)): + if torch.is_tensor(output): + output = [output] + else: + #print(f'got UNKNOWN type {type(output)}') + outputs = [] + output = output if isinstance(output, dict) else vars(output) + for name, val in output.items(): + if not name.startswith('__') and torch.is_tensor(val): + outputs.append(val) + output = outputs + + for item in filter(lambda item: is_zero_param(item) or hasattr(item, 'ds_param_alias'), output): + key = id(item) if hasattr(item, 'ds_id') else id(item.ds_param_alias) + actual_external_param = item if hasattr(item, 'ds_id') else item.ds_param_alias + + if not any(key in m._external_params for m in FWD_MODULE_STACK): + actual_external_param.is_external_param = True + module_to_register = FWD_MODULE_STACK[-1] + register_external_parameter(module_to_register, actual_external_param) + print_rank_0( + f'Registering dangling parameter for module {module_to_register.__class__.__name__}, ds_id = {actual_external_param.ds_id}.', + force=False) + + # It's possible that the parameter was already external to the completed module. If so, remove it the + # registration as it will be covered by the outer module instead. + if key in module._external_params: + print_rank_0( + f' Unregistering nested dangling parameter from module {module.__class__.__name__}, ds_id = {actual_external_param.ds_id}', + force=False) + unregister_external_parameter(module, actual_external_param) + + actual_external_param.all_gather() + + self.post_sub_module_forward_function(module) + + def _bwd_hook_unexpected_inputs_msg(value): + return f"A module has unknown inputs or outputs type ({type(value)}) and the tensors embedded in it cannot be detected. " \ + "The ZeRO-3 hooks designed to trigger before or after backward pass of the module relies on knowing the input and " \ + "output tensors and therefore may not get triggered properly." + + def _pre_backward_module_hook(module, inputs, output): + + if not hasattr(module, "pre_bwd_fn"): + + @instrument_w_nvtx + def _run_before_backward_function(sub_module): + # some models (e.g. Albert) may run multiple forwards on the same layer in a loop + # before doing backwards, so each backward will need a pre-fetch - using reference + # counting to support this scenario + #print(f"COUNTER before: {sub_module.applied_pre_backward_ref_cnt}") + if sub_module.applied_pre_backward_ref_cnt > 0: + self.pre_sub_module_backward_function(sub_module) + sub_module.applied_pre_backward_ref_cnt -= 1 + #print(f"COUNTER after: {sub_module.applied_pre_backward_ref_cnt}") + + class PreBackwardFunctionForModule(torch.autograd.Function): + + @staticmethod + def forward(ctx, outputs): + # Capture `module` and _run_before_backward_function + ctx.module = module + ctx.pre_backward_function = _run_before_backward_function + if not hasattr(ctx.module, "applied_pre_backward_ref_cnt"): + ctx.module.applied_pre_backward_ref_cnt = 0 + ctx.module.applied_pre_backward_ref_cnt += 1 + outputs = outputs.detach() + return outputs + + @staticmethod + def backward(ctx, *args): + ctx.pre_backward_function(ctx.module) + return args + + module.pre_bwd_fn = PreBackwardFunctionForModule + + return apply_to_tensors_only(module.pre_bwd_fn.apply, + output, + warning_msg_fn=_bwd_hook_unexpected_inputs_msg) + + #This is an alternate to doing _post_backward_module_hook + #it uses tensor.register_hook instead of using torch.autograd.Function + def _alternate_post_backward_module_hook(module, inputs): + module.ds_grads_remaining = 0 + + #print(f"Before Forward {module.__class__.__name__}") + + def _run_after_backward_hook(*unused): + module.ds_grads_remaining = module.ds_grads_remaining - 1 + if module.ds_grads_remaining == 0: + #print(f"After backward {module.__class__.__name__}") + self.post_sub_module_backward_function(module) + + def _run_before_forward_function(input): + if input.requires_grad: + module.ds_grads_remaining += 1 + + return _apply_forward_and_backward_to_tensors_only(module, _run_before_forward_function, + _run_after_backward_hook, inputs) + + def _post_backward_module_hook(module, inputs): + module.ds_grads_remaining = 0 + + if not hasattr(module, "post_bwd_fn"): + + @instrument_w_nvtx + def _run_after_backward_function(sub_module): + if sub_module.ds_grads_remaining == 0: + self.post_sub_module_backward_function(sub_module) + + class PostBackwardFunctionModule(torch.autograd.Function): + + @staticmethod + def forward(ctx, output): + ctx.module = module + if output.requires_grad: + #TODO SOME TIMES post backward does not seem to be triggered debug in detail + #Should only cause increase in memory not correctness issue + #if output.grad_fn.__class__.__name__ == 'ViewBackward': + # ctx.view=True + # print(f"Warning view tensor for input to module : {module.__class__.__name__}. Backward hooks may not trigger properly") + #assert len(module.parameters(recurse=False)), "The input tensor to the module is a view, and autograd Function or register_hook is not triggered with view tensors." + #if module.ds_grads_remaining == 0: + # print(f"Before Forward: {ctx.module.__class__.__name__}") + module.ds_grads_remaining += 1 + ctx.post_backward_function = _run_after_backward_function + output = output.detach() + return output + + @staticmethod + def backward(ctx, *args): + ctx.module.ds_grads_remaining = ctx.module.ds_grads_remaining - 1 + if ctx.module.ds_grads_remaining == 0: + ctx.post_backward_function(ctx.module) + return args + + module.post_bwd_fn = PostBackwardFunctionModule + + return apply_to_tensors_only(module.post_bwd_fn.apply, + inputs, + warning_msg_fn=_bwd_hook_unexpected_inputs_msg) + + # Pre forward hook + self.forward_hooks.append(module.register_forward_pre_hook(_pre_forward_module_hook)) + + # Post forward hook + self.forward_hooks.append(module.register_forward_hook(_post_forward_module_hook)) + + # Pre backward hook + self.backward_hooks.append(module.register_forward_hook(_pre_backward_module_hook)) + + # post backward hook + self.backward_hooks.append(module.register_forward_pre_hook(_post_backward_module_hook)) + + @torch.no_grad() + def pre_sub_module_forward_function(self, sub_module, *args): + see_memory_usage(f"Before sub module function {sub_module.__class__.__name__}", force=False) + + global FWD_MODULE_STACK + FWD_MODULE_STACK.append(sub_module) + + param_coordinator = self.get_param_coordinator(training=sub_module.training) + param_coordinator.trace_prologue(sub_module) + if param_coordinator.is_record_trace(): + param_coordinator.record_module(sub_module) + param_coordinator.fetch_sub_module(sub_module, forward=True, args=args) + + see_memory_usage(f"Before sub module function {sub_module.__class__.__name__} after fetch", force=False) + + @torch.no_grad() + def post_sub_module_forward_function(self, sub_module): + see_memory_usage(f"After sub module function {sub_module.__class__.__name__} {sub_module.id} before release", + force=False) + + param_coordinator = self.get_param_coordinator(training=sub_module.training) + param_coordinator.release_sub_module(sub_module) + + see_memory_usage(f"After sub module function {sub_module.__class__.__name__} {sub_module.id} after release", + force=False) + + @torch.no_grad() + def pre_sub_module_backward_function(self, sub_module): + assert sub_module.training, "backward pass is invalid for module in evaluation mode" + param_coordinator = self.get_param_coordinator(training=True) + param_coordinator.trace_prologue(sub_module) + if param_coordinator.is_record_trace(): + param_coordinator.record_module(sub_module) + param_coordinator.fetch_sub_module(sub_module, forward=False) + + @torch.no_grad() + def post_sub_module_backward_function(self, sub_module): + assert sub_module.training, "backward pass is invalid for module in evaluation mode" + see_memory_usage( + f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} before release", + force=False) + + self.get_param_coordinator(training=True).release_sub_module(sub_module, forward=False) + + see_memory_usage( + f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} after release", + force=False) diff --git a/toolbox/DeepSpeed/v0.15.3/patches/deepspeed/runtime/zero/partition_parameters.py b/toolbox/DeepSpeed/v0.15.3/patches/deepspeed/runtime/zero/partition_parameters.py new file mode 100644 index 0000000000000000000000000000000000000000..584f0bbf4597b9f08a607ed9593cb1aa38dcf070 --- /dev/null +++ b/toolbox/DeepSpeed/v0.15.3/patches/deepspeed/runtime/zero/partition_parameters.py @@ -0,0 +1,2253 @@ +#!/usr/bin/env python3 +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import math +import os +import types +from typing import Callable, Iterable +from enum import Enum +import functools +import itertools +from typing import List +from collections import defaultdict +import logging +import torch +from torch import Tensor +from deepspeed import comm as dist +from torch.nn import Module +from torch.nn import Parameter + +from .linear import zero3_linear_wrap + +from deepspeed.utils import groups +import deepspeed +from ..utils import see_memory_usage, get_only_unique_item +from deepspeed.runtime.zero.config import DeepSpeedZeroConfig +from deepspeed.runtime.zero.utils import assert_ints_same_as_other_ranks, is_zero_param +from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum +from deepspeed.runtime.config_utils import get_config_default +from deepspeed.utils import instrument_w_nvtx, logger +from deepspeed.comm.comm import init_distributed +from deepspeed.utils.debug import (debug_param2name_id_shape, debug_param2name_id_shape_device, debug_module2name, + debug_param2name_id, debug_param2name_id_shape_status) +from deepspeed.accelerator import get_accelerator +from ..swap_tensor.partitioned_param_swapper import AsyncPartitionedParameterSwapper, PartitionedParamStatus +from deepspeed.inference.quantization.utils import _quantize_param, WEIGHT_QUANTIZATION_LAYERS, wrap_quantized_functional, wrap_load_from_state_dict + +partitioned_param_data_shape = [0] +zero_init_context = 0 +top_level_context = None + + +class NoGatherHandle: + + def __init__(self, param: Parameter) -> None: + if param.ds_status != ZeroParamStatus.INFLIGHT: + raise RuntimeError(f"expected param {param.ds_summary()} to be available") + + if hasattr(param.ds_tensor, "ds_quant_scale"): + param.data = Init.quantizer_module.dequantize(param.ds_tensor.data, param.ds_tensor.ds_quant_scale).to( + device=get_accelerator().current_device_name(), non_blocking=True).view(param.ds_shape) + else: + param.data = param.ds_tensor.data.to(device=get_accelerator().current_device_name(), + non_blocking=True).view(param.ds_shape) + self.__param = param + + def wait(self) -> None: + if not get_accelerator().resolves_data_dependency(): + get_accelerator().current_stream().synchronize() + self.__param.ds_status = ZeroParamStatus.AVAILABLE + + +class NoGatherCoalescedHandle: + + def __init__(self, params: List[Parameter]) -> None: + self.__params = params + self.__complete = False + + for param in self.__params: + if param.ds_status != ZeroParamStatus.INFLIGHT: + raise RuntimeError(f"expected param {param.ds_summary()} to not be available") + if hasattr(param.ds_tensor, "ds_quant_scale"): + param.data = Init.quantizer_module.dequantize(param.ds_tensor.data, param.ds_tensor.ds_quant_scale).to( + device=get_accelerator().current_device_name(), non_blocking=True).view(param.ds_shape) + else: + param.data = param.ds_tensor.data.to(device=get_accelerator().current_device_name(), + non_blocking=True).view(param.ds_shape) + + @instrument_w_nvtx + def wait(self) -> None: + if self.__complete: + return + + if not get_accelerator().resolves_data_dependency(): + get_accelerator().current_stream().synchronize() + for param in self.__params: + assert param.ds_status == ZeroParamStatus.INFLIGHT, f"expected param {param.ds_summary()} to be inflight" + param.ds_status = ZeroParamStatus.AVAILABLE + + self.__complete = True + + +def _dist_allgather_fn(input_tensor: Tensor, output_tensor: Tensor, group=None): + return instrument_w_nvtx(dist.allgather_fn)(output_tensor, input_tensor, group=group, async_op=True) + + +def print_rank_0(message, debug=False, force=False): + rank = dist.get_rank() + if rank == 0 and (debug or force): + print(message) + # other variations + # - print for all ranks w/o interleaving + # printflock(f"[{rank}] {message}") + # - print to log file per rank + # log_rank_file(rank, message) + + +def debug_rank0(msg: str) -> None: + if dist.get_rank() == 0: + logger.debug(msg) + + +def _init_external_params(module): + if not hasattr(module, '_external_params'): + module._external_params = {} + + def external_parameters(self): + return self._external_params.items() + + def all_parameters(self): + return itertools.chain(self.named_parameters(self, recurse=False), external_parameters(self)) + + module.ds_external_parameters = types.MethodType(external_parameters, module) + module.all_parameters = types.MethodType(all_parameters, module) + + +def register_external_parameter(module, parameter): + """Instruct DeepSpeed to coordinate ``parameter``'s collection and partitioning in + the forward and backward passes of ``module``. + + This is used when a parameter is accessed outside of its owning module's + ``forward()``. DeepSpeed must know to collect it from its partitioned + state and when to release the memory. + + .. note:: + This is only applicable to training with ZeRO stage 3. + + Args: + module (``torch.nn.Module``): The module that requires ``parameter`` in its forward pass. + parameter (``torch.nn.Parameter``): The parameter to register. + + Raises: + RuntimeError: If ``parameter`` is not of type ``torch.nn.Parameter``. + + + Examples + ======== + + #. Register a weight that is used in another module's forward pass (line 6). + Parameter ``layer1.weight`` is used by ``layer2`` (line 11). + + .. code-block:: python + :linenos: + :emphasize-lines: 6,11 + + class ModuleZ3(torch.nn.Module): + def __init__(self, *args): + super().__init__(self, *args) + self.layer1 = SomeLayer() + self.layer2 = OtherLayer() + deepspeed.zero.register_external_parameter(self, self.layer1.weight) + + def forward(self, input): + x = self.layer1(input) + # self.layer1.weight is required by self.layer2.forward + y = self.layer2(x, self.layer1.weight) + return y + """ + if not isinstance(parameter, torch.nn.Parameter): + raise RuntimeError('Parameter is not a torch.nn.Parameter') + + if not hasattr(module, '_external_params'): + _init_external_params(module) + + key = id(parameter) + module._external_params[key] = parameter + + +def unregister_external_parameter(module, parameter): + """Reverses the effects of :meth:`register_external_parameter`. + + Args: + module (``torch.nn.Module``): The module to affect. + parameter (``torch.nn.Parameter``): The parameter to unregister. + + Raises: + RuntimeError: If ``parameter`` is not of type ``torch.nn.Parameter``. + RuntimeError: If ``parameter`` is not a registered external parameter of ``module``. + """ + if not isinstance(parameter, torch.nn.Parameter): + raise RuntimeError('Parameter is not a torch.nn.Parameter') + + if not hasattr(module, '_external_params') or id(parameter) not in module._external_params: + raise RuntimeError('Parameter is not a registered external parameter of module.') + + key = id(parameter) + del module._external_params[key] + + +class ZeroParamType(Enum): + + # same as regular pytorch parameters + NORMAL = 1 + + # parameters are partitioned across data parallel process + PARTITIONED = 2 + + # the parameter is held with a unique process rank + # and is not available on all other process + REMOTE = 3 + + +class ZeroParamStatus(Enum): + # parameters are fully present and ready for use on all processes + AVAILABLE = 1 + + # parameters are either partitioned or remote in some or all process + NOT_AVAILABLE = 2 + + # parameters are being gathered. + INFLIGHT = 3 + + # parameters are fully cached for zero3 opt + HOLD_COMPLETE = 4 + + +_orig_torch_tensor = torch.tensor +_orig_torch_empty = torch.empty +_orig_torch_zeros = torch.zeros +_orig_torch_ones = torch.ones +_orig_torch_full = torch.full +_orig_torch_arange = torch.arange +_orig_torch_eye = torch.eye +_orig_torch_randn = torch.randn + + +def zero_wrapper_for_fp_tensor_constructor(fn: Callable, target_fp_dtype: torch.dtype) -> Callable: + + def wrapped_fn(*args, **kwargs) -> Tensor: + if kwargs.get("device", None) is None: + kwargs['device'] = torch.device(get_accelerator().device_name(os.environ["LOCAL_RANK"])) + tensor: Tensor = fn(*args, **kwargs) + if tensor.is_floating_point(): + tensor.data = tensor.data.to(target_fp_dtype) + + return tensor + + return wrapped_fn + + +def get_new_tensor_fn_for_dtype(dtype: torch.dtype) -> Callable: + + def new_tensor(cls, *args, **kwargs) -> Tensor: + device = torch.device(get_accelerator().device_name(os.environ["LOCAL_RANK"])) + if not args: + args = (0, ) + tensor = _orig_torch_empty(0, device=device).new_empty(*args, **kwargs) + if tensor.is_floating_point(): + tensor = tensor.to(dtype) + + return tensor + + return new_tensor + + +# https://stackoverflow.com/a/63851681/9201239 +def get_all_subclasses(cls, include_root=True): + subclass_list = [] + + def recurse(cl): + for subclass in cl.__subclasses__(): + subclass_list.append(subclass) + recurse(subclass) + + recurse(cls) + + ret = set(subclass_list) + if include_root: + ret.add(cls) + return ret + + +@instrument_w_nvtx +def free_param(param: Parameter) -> None: + """Free underlying storage of a parameter.""" + assert not param.ds_active_sub_modules, param.ds_summary() + if get_accelerator().on_accelerator(param.data): + # need to make sure that we don't free the parameter while it is still + # being used for computation + if not get_accelerator().is_synchronized_device(): + param.data.record_stream(get_accelerator().current_stream()) + # param.data doesn't store anything meaningful in partitioned state + param.data = torch.empty(0, dtype=param.dtype, device=param.device) + param.ds_status = ZeroParamStatus.NOT_AVAILABLE + + +reuse_buffers = False +temp_contiguous_tensor = None +empty_buffers = {} + + +# Inserts _post_init_method at the end of init method +# for all sub classes of torch.nn.Module +class InsertPostInitMethodToModuleSubClasses(object): + num_module_parameters = 0 + num_module_elements = 0 + + def __init__(self, enabled=True, mem_efficient_linear=True, ds_config=None, dtype=None): + self.mem_efficient_linear = mem_efficient_linear + self.enabled = enabled + self._set_dtype(ds_config, dtype) + assert self.dtype in [ + torch.half, torch.bfloat16, torch.float + ], f"Invalid data type {self.dtype}, allowed values are [torch.half, torch.bfloat16, torch.float]" + self.wrapped_cls = set() + self.skip_init_depth = 0 + + self.quantized_initialization = None + if ds_config is not None and ds_config.weight_quantization_config and ds_config.weight_quantization_config.quantized_initialization: + self.quantized_initialization = ds_config.weight_quantization_config.quantized_initialization + + def __enter__(self): + if not self.enabled: + return + + global zero_init_context + if zero_init_context == 0: + self.patch_init_and_builtins() + global top_level_context + top_level_context = self + + zero_init_context += 1 + + def __exit__(self, exc_type, exc_value, traceback): + if not self.enabled: + return + + global zero_init_context + zero_init_context -= 1 + + # Exiting the top level context + if zero_init_context == 0: + self.unpatch_init_and_builtins() + global top_level_context + top_level_context = None + + if dist.get_rank() == 0: + billion_elems = InsertPostInitMethodToModuleSubClasses.num_module_elements / 1e9 + num_params = InsertPostInitMethodToModuleSubClasses.num_module_parameters + logger.info( + f"finished initializing model - num_params = {num_params}, num_elems = {billion_elems:.2f}B") + + # Now that we cleaned up the metaclass injection, raise the exception. + if exc_type is not None: + return False + + # To be implemented by inheriting classes + def _post_init_method(self, module): + pass + + def _set_dtype(self, ds_config, dtype): + if ds_config is not None and dtype is None: + if ds_config.bfloat16_enabled and ds_config.fp16_enabled: + raise RuntimeError("bfloat16 and fp16 cannot be enabled at once") + + if ds_config.bfloat16_enabled: + self.dtype = torch.bfloat16 + elif ds_config.fp16_enabled: + self.dtype = torch.half + else: + self.dtype = torch.float + else: + self.dtype = dtype or torch.float16 if get_accelerator().is_fp16_supported( + ) else torch.bfloat16 if get_accelerator().is_bf16_supported else torch.float32 + + def patch_init_and_builtins(self): + + def apply_with_gather(orig_module_apply_fn: Callable) -> Callable: + """many models make use of child modules like Linear or Embedding which + perform their own weight initialization in their __init__ methods, + but will then have more weight initialization in a parent module's __init__ + method that modifies weights of child modules, which is typically done + using the Module.apply method. + + since the Init context manager partitions child modules immediately after + they are initialized, without modifying apply we would entirely skip + any initialization done by parent modules. + + to get around this issue, we wrap the function passed to Module.apply + so that the applied function is applied to child modules correctly. + """ + + def get_wrapped_fn_to_apply(fn_to_apply: Callable) -> Callable: + if hasattr(fn_to_apply, "wrapped"): + return fn_to_apply + + @functools.wraps(fn_to_apply) + def wrapped_fn_to_apply(module_to_apply_fn_to: Module) -> None: + """gathers parameters before calling apply function. afterwards + parameters are broadcasted to ensure consistency across all ranks + then re-partitioned. + + takes the following steps: + 1. allgathers parameters for the current module being worked on + 2. calls the original function + 3. broadcasts root rank's parameters to the other ranks + 4. re-partitions the parameters + """ + + # TODO Delay error checking for dangling partitioned parameters to post module init + # raise RuntimeError(f"not all parameters for {module_to_apply_fn_to.__class__.__name__}, " + # f"were zero params, is it possible that the parameters were " + # f"overwritten after they were initialized? " + # f"params: {[p for p in module_to_apply_fn_to.parameters(recurse=False)]} ") + + params_to_apply_fn_to: Iterable[Parameter] = list( + sorted([p for p in module_to_apply_fn_to.parameters(recurse=False) if is_zero_param(p)], + key=lambda p: p.ds_id)) + + for param in params_to_apply_fn_to: + param.all_gather() + + fn_to_apply(module_to_apply_fn_to) + + for param in params_to_apply_fn_to: + dist.broadcast(param.data, 0, group=param.ds_process_group) + + for param in params_to_apply_fn_to: + param.partition(has_been_updated=True) + + wrapped_fn_to_apply.wrapped = True + + return wrapped_fn_to_apply + + @functools.wraps(orig_module_apply_fn) + def wrapped_apply(module: Module, fn_to_apply: Callable) -> None: + orig_module_apply_fn(module, get_wrapped_fn_to_apply(fn_to_apply)) + + return wrapped_apply + + def hook_for_skip_init(module): + # this function is intended for handling the logic of torch.nn.utils.skip_init + # skip_init:module_cls(*args, **kwargs).to_empty(device=final_device), where kwargs['device']='meta' + # the function call occurs between module_cls(*args, **kwargs) and to_empty(device=final_device). + def partition_after_empty_init(f): + + @functools.wraps(f) + def wrapper(module, *args, **kwargs): + _module = f(module, *args, **kwargs) + # here is the post-hook for module.apply(empty_like...) + # after module.apply(empty_like...), the module has completed its empty init on real device + # since skip_init won't involve any computations or weight adjustments, we can directly utilize post_init + self._post_init_method(_module) + return _module + + return wrapper + + def post_wrapper_to_empty(f): + # append some wrapper restoration after to_empty() call + @functools.wraps(f) + def wrapper(*args, **kwargs): + res = f(*args, **kwargs) + # restore _apply hook + for subclass in get_all_subclasses(torch.nn.modules.module.Module): + _disable_class_apply(subclass) + # self restore + module.to_empty = f + return res + + return wrapper + + def _enable_class_apply(cls): + if '_apply' in cls.__dict__: + cls._old_apply_of_skip_init_hook = cls._apply + cls._apply = partition_after_empty_init(cls._apply) + + def _disable_class_apply(cls): + if hasattr(cls, '_old_apply_of_skip_init_hook'): + cls._apply = cls._old_apply_of_skip_init_hook + + # add hooks for to_empty: apply_(empty_like) + for subclass in get_all_subclasses(torch.nn.modules.module.Module): + _enable_class_apply(subclass) + + # add a restore hook when exiting skip_init + module.to_empty = post_wrapper_to_empty(module.to_empty) + + def partition_after(f): + + @functools.wraps(f) + def wrapper(module, *args, **kwargs): + + # important logic: We want to run post_init only after child's __init__ is + # completed, and do nothing after __init__ of any of its parents and grandparents in + # the inheritance ancestry. This way the partitioning will need to happen only once + # when the whole object is ready to be partitioned and not before. This is because + # often the child module will need to tweak the weights - for example running a + # custom weights init function. So if a parent created the weights param, the child + # won't need to gather it in order to tweak it + + print_rank_0(f'Before initializing {module.__class__.__name__}', force=False) + + is_child_module = False + if not hasattr(module, "_ds_child_entered"): + # child's __init__ was called, since parents all see the same object they can now skip post_init + is_child_module = True + setattr(module, "_ds_child_entered", True) + + init_on_meta = 'device' in kwargs and kwargs['device'] == 'meta' + if init_on_meta: + self.skip_init_depth += 1 + + f(module, *args, **kwargs) + if init_on_meta and self.skip_init_depth == 1: + # check and handle the logic of empty_init + hook_for_skip_init(module) + if is_child_module: + # child's __init__ is done, now we can run a single post_init on the child object + delattr(module, "_ds_child_entered") + + print_rank_0(f'Running post_init for {module.__class__.__name__}', force=False) + if self.skip_init_depth == 0: + self._post_init_method(module) + + print_rank_0(f'After initializing followed by post init for {module.__class__.__name__}', force=False) + if init_on_meta: + self.skip_init_depth -= 1 + + return wrapper + + def _enable_class(cls): + if '__init__' in cls.__dict__: + cls._old_init = cls.__init__ + cls.__init__ = partition_after(cls.__init__) + + def _init_subclass(cls, **kwargs): + if '__init__' in cls.__dict__: + cls._old_init = cls.__init__ + cls.__init__ = partition_after(cls.__init__) + + # Replace .__init__() for all existing subclasses of torch.nn.Module recursively + for subclass in get_all_subclasses(torch.nn.modules.module.Module): + _enable_class(subclass) + + # holding onto some methods so we can put them back the way they were in __exit__ + torch.nn.modules.module.Module._old_init_subclass = torch.nn.modules.module.Module.__init_subclass__ + torch.nn.modules.module.Module._old_apply = torch.nn.modules.module.Module.apply + torch.Tensor.__old_new__ = torch.Tensor.__new__ + + # Replace .__init__() for future subclasses of torch.nn.Module + torch.nn.modules.module.Module.__init_subclass__ = classmethod(_init_subclass) + if Init.override_module_apply: + torch.nn.modules.module.Module.apply = apply_with_gather(torch.nn.modules.module.Module._old_apply) + + self._add_tensor_creation_wrappers() + + if self.mem_efficient_linear: + print_rank_0( + "nn.functional.linear has been overridden with a more memory efficient version. This will persist unless manually reset.", + force=False) + self.linear_bk = torch.nn.functional.linear + torch.nn.functional.linear = zero3_linear_wrap + + if self.quantized_initialization: + print_rank_0("nn.functional.linear has been overridden with quantized linear version.", force=False) + torch.nn.functional.linear = wrap_quantized_functional(torch.nn.functional.linear) + torch.nn.functional.embedding = wrap_quantized_functional(torch.nn.functional.embedding) + for cls in WEIGHT_QUANTIZATION_LAYERS: + cls._load_from_state_dict = wrap_load_from_state_dict(cls._load_from_state_dict) + + logger.info("Enable Zero3 engine with INT4 quantization.") + + self.patched = True + + def unpatch_init_and_builtins(self): + if self.patched: + + def _disable_class(cls): + if hasattr(cls, '_old_init'): + cls.__init__ = cls._old_init + + for subclass in get_all_subclasses(torch.nn.modules.module.Module): + _disable_class(subclass) + + # putting methods back the way we found them + torch.nn.modules.module.Module.__init_subclass__ = torch.nn.modules.module.Module._old_init_subclass + if Init.override_module_apply: + torch.nn.modules.module.Module.apply = torch.nn.modules.module.Module._old_apply + + self._remove_tensor_creation_wrappers() + + self.patched = False + + def _add_tensor_creation_wrappers(self): + torch.Tensor.__new__ = get_new_tensor_fn_for_dtype(self.dtype) + torch.tensor = zero_wrapper_for_fp_tensor_constructor(_orig_torch_tensor, self.dtype) + torch.empty = zero_wrapper_for_fp_tensor_constructor(_orig_torch_empty, self.dtype) + torch.zeros = zero_wrapper_for_fp_tensor_constructor(_orig_torch_zeros, self.dtype) + torch.ones = zero_wrapper_for_fp_tensor_constructor(_orig_torch_ones, self.dtype) + torch.full = zero_wrapper_for_fp_tensor_constructor(_orig_torch_full, self.dtype) + torch.arange = zero_wrapper_for_fp_tensor_constructor(_orig_torch_arange, self.dtype) + torch.eye = zero_wrapper_for_fp_tensor_constructor(_orig_torch_eye, self.dtype) + torch.randn = zero_wrapper_for_fp_tensor_constructor(_orig_torch_randn, self.dtype) + + def _remove_tensor_creation_wrappers(self): + torch.Tensor.__new__ = torch.Tensor.__old_new__ + torch.tensor = _orig_torch_tensor + torch.empty = _orig_torch_empty + torch.zeros = _orig_torch_zeros + torch.ones = _orig_torch_ones + torch.full = _orig_torch_full + torch.arange = _orig_torch_arange + torch.eye = _orig_torch_eye + torch.randn = _orig_torch_randn + + +def shutdown_init_context(): + """ + This function is used to initialize deepspeed engine inside the context of Init. + We need to remove the wrappers but keep the context. + """ + if top_level_context: + top_level_context.unpatch_init_and_builtins() + + +def restore_init_context(): + """ + This function is used to restore the wrappers after deepspeed engine is initialized. + """ + if top_level_context: + top_level_context.patch_init_and_builtins() + + +class AllGatherHandle: + + def __init__(self, handle, param: Parameter, quantization=None) -> None: + if param.ds_status != ZeroParamStatus.INFLIGHT: + raise RuntimeError(f"expected param {param.ds_summary()} to be available") + + self.__handle = handle + self.__param = param + self.__quantization = quantization + + def wait(self) -> None: + instrument_w_nvtx(self.__handle.wait)() + if self.__quantization: + instrument_w_nvtx(self.__quantization.quant_handle.wait)() + self.__param.data = self.__quantization.backend.dequantize( + self.__quantization.quantized_param, self.__quantization.scale_buffer).to(self.__param.device) + self.__param.ds_status = ZeroParamStatus.AVAILABLE + + +class AllGatherCoalescedHandle: + + def __init__( + self, + allgather_handle, + params: List[Parameter], + partitions: List[Tensor], + world_size: int, + use_secondary_tensor=False, + quantization=None, + ) -> None: + self.allgather_handle = allgather_handle + self.params = params + self.partitions = partitions + self.world_size = world_size + self.use_secondary_tensor = use_secondary_tensor + self.complete = False + self.quantization = quantization + + for param in self.params: + if param.ds_status != ZeroParamStatus.INFLIGHT: + raise RuntimeError(f"expected param {param.ds_summary()} to not be available") + + @instrument_w_nvtx + def wait(self) -> None: + if self.complete: + return + + instrument_w_nvtx(self.allgather_handle.wait)() + + if self.quantization: + instrument_w_nvtx(self.quantization.quant_handle.wait)() + flat_tensor = self.quantization.backend.dequantize( + self.quantization.quantized_param, self.quantization.scale_buffer).to(self.params[0].device) + + self.partitions: List[Parameter] = [] + for i in range(self.world_size): + self.partitions.append( + flat_tensor.narrow(0, self.quantization.partition_sz * i, self.quantization.partition_sz)) + + # split the single tensor out into individual tensors + param_offset = 0 + for param in self.params: + assert param.ds_status == ZeroParamStatus.INFLIGHT, f"expected param {param.ds_summary()} to be inflight" + partitions: List[Tensor] = [] + ds_tensor_numel = param.ds_tensor.ds_numel + if self.use_secondary_tensor: + ds_tensor_numel *= param.ds_secondary_tensor_num_of_groups + for rank in range(self.world_size): + param_start = rank * ds_tensor_numel + if param_start < param.ds_numel: + part_to_copy = self.partitions[rank].narrow(0, param_offset, + min(param.ds_numel - param_start, ds_tensor_numel)) + partitions.append(part_to_copy) + param.data = instrument_w_nvtx(torch.cat)(partitions).view(param.ds_shape) + param.ds_status = ZeroParamStatus.AVAILABLE + + for part_to_copy in partitions: + if not get_accelerator().is_synchronized_device(): + part_to_copy.record_stream(get_accelerator().current_stream()) + + param_offset += ds_tensor_numel + + self.complete = True + + +class MultipleAllGatherHandles: + + def __init__(self, handles: List[AllGatherCoalescedHandle]): + self.handles = handles + + def wait(self) -> None: + for handle in self.handles: + handle.wait() + + +class AllReduceCoalescedHandle: + + def __init__(self, handle, params: List[Parameter]) -> None: + self.handle = handle + self.params = params + self.complete = False + + for param in self.params: + if param.ds_status != ZeroParamStatus.INFLIGHT: + raise RuntimeError(f"expected param {param.ds_summary()} to not be available") + + @instrument_w_nvtx + def wait(self) -> None: + if self.complete: + return + + instrument_w_nvtx(self.handle.wait)() + + for param in self.params: + assert param.ds_status == ZeroParamStatus.INFLIGHT, f"expected param {param.ds_summary()} to be inflight" + param.ds_status = ZeroParamStatus.AVAILABLE + + self.complete = True + + +class QuantizationInfo: + # a placeholder object to store all quant related vars used in handles + def __init__(self) -> None: + self.quantized_param = None + self.backend = None + self.quant_handle = None + self.scale_buffer = None + + +class CUDAQuantizer: + async_flag = True + target_group_size = 8000 # the optimal size is 4k, so we set the target to be below 8k + group_size_cache = dict() + quantizer_cuda_module = None + + def __init__(self) -> None: + if CUDAQuantizer.quantizer_cuda_module is None: + CUDAQuantizer.quantizer_cuda_module = deepspeed.ops.op_builder.QuantizerBuilder().load() + + def quantize(self, param, groups=None): + if groups is None: + try: + groups = self.group_size_cache[param.numel()] + except KeyError: + groups = math.ceil(param.numel() / self.target_group_size) + while groups < param.numel(): + if param.numel() % (8 * groups) == 0: + break + groups += 1 + while True: + if param.numel() % (8 * groups * 2) == 0 and param.numel( + ) / groups > self.target_group_size: #hard limit of 16k group_size + groups *= 2 + else: + break + assert ( + param.numel() % (8 * groups) == 0 + ), f"Qantized weight requires the number of weights be a multiple of 8. Yet {param.numel()} cannot be divided by 8*{groups}" + assert (param.numel() / groups < 16000), f"{param.numel()} / {groups} is larger than 16k" + assert param.numel( + ) > groups, f"Adaptive grouping algorithm cannot find a group size for input tensor of size {param.numel()}" + self.group_size_cache[param.numel()] = groups + return self.quantizer_cuda_module.quantize(param.to(get_accelerator().device_name()), groups, 8, + self.quantizer_cuda_module.Symmetric) + + def dequantize(self, quantized_param, scale): + return self.quantizer_cuda_module.dequantize(quantized_param, scale, scale.numel(), 8, + self.quantizer_cuda_module.Symmetric) + + +def _no_gather_coalesced(params: Iterable[Parameter]) -> AllGatherCoalescedHandle: + for param in params: + if param.ds_status != ZeroParamStatus.NOT_AVAILABLE: + raise RuntimeError(f"expect param.ds_status == ZeroParamStatus.NOT_AVAILABLE, got{param.ds_summary()}") + param.ds_status = ZeroParamStatus.INFLIGHT + + params = sorted(params, key=lambda p: p.ds_id) + if len(params) == 1: + param, = params + return NoGatherHandle(param) + return NoGatherCoalescedHandle(params) + + +# Replaces all parameters in module with Scattered Parameters +class Init(InsertPostInitMethodToModuleSubClasses): + param_id = 0 + param_persistence_threshold = get_config_default(DeepSpeedZeroConfig, "param_persistence_threshold") + model_persistence_threshold = get_config_default(DeepSpeedZeroConfig, "model_persistence_threshold") + num_persisted_parameters = 0 + num_persisted_elements = 0 + apply_param_persistence = False + override_module_apply = get_config_default(DeepSpeedZeroConfig, "override_module_apply") + + def __init__(self, + module=None, + data_parallel_group=None, + mem_efficient_linear=True, + remote_device=None, + pin_memory=False, + config_dict_or_path=None, + config=None, + enabled=True, + dtype=None, + mpu=None, + zero_param_parallel_group=None, + zero_quantized_weights=False, + zero_quantized_nontrainable_weights=False, + sequence_data_parallel_group=None, + param_swapper=None): + """A context to enable massive model construction for training with + ZeRO-3. Models are automatically partitioned (or, sharded) across the + system and converted to half precision. + + Args: + module (``torch.nn.Module``, optional): If provided, partition the model as + if it was constructed in the context. + data_parallel_group (``deepspeed.comm`` process group, optional): + The group of processes to partition among. Defaults to all processes. + Synonymous with sequence data parallel group for param partitioning + across both sequence and data parallel groups. + mem_efficient_linear (bool, optional): Replace + torch.nn.functional.linear with an implementation that allows + DeepSpeed to partition parameters. Defaults to ``True``. + remote_device (string, optional): The initial device to store model + weights e.g., ``cpu``, ``nvme``. Passing ``"cpu"`` will create the model in CPU + memory. The model may still be moved to GPU based on the + offload settings for training. Defaults to param offload device if a config is + defined, otherwise GPU. + pin_memory (bool, optional): Potentially increase performance by + using pinned memory for model weights. ``remote_device`` must be + ``"cpu"``. Defaults to pin_memory value in config, otherwise ``False``. + config_dict_or_path (dict or ``json file``, optional): If provided, provides configuration + for swapping fp16 params to NVMe. + config (dict or ``json file``, optional): Deprecated, use config_dict_or_path instead. + enabled (bool, optional): If ``False``, this context has no + effect. Defaults to ``True``. + dtype (``dtype``, optional): Can be used to change the data type of the parameters. + Supported options are ``torch.half`` and ``torch.float``. Defaults to ``None`` + mpu (``object``, optional): A model parallelism unit object that implements get_{model,data}_parallel_{rank,group,world_size}. + zero_param_parallel_group(``object``, optional): Parallel (comm) group for dual partitioning of ZeRO params. + zero_quantized_weights (bool, optional): If ``True``, turn on quantized weights in all gather weights. Default is ``False`` + zero_quantized_nontrainable_weights (bool, optional): If ``True``, nontrainable weights will be stored in quantized format. Default is ``False`` + param_swapper (``deepspeed.runtime.swap_tensor.partitioned_param_swapper.AsyncPartitionedParameterSwapper``, optional): [Experimental] Use existing parameter swapper. Defaults to ``None``. + This argument will be removed in the near future. + + This context accelerates model initialization and enables models that + are too large to allocate in their entirety in CPU memory. It has the + following effects: + + #. allocates tensors to either GPU or CPU memory or NVMe + #. converts floating point tensors to half precision + #. immediately partitions tensors among the group of data-parallel devices + #. (*optional*) replaces ``torch.nn.functional.linear`` with a more + memory-efficient implementation + + These modifications allow for models that exceed the size of local CPU/GPU + memory/NVMe, but fit within the total NVMe capacity (*i.e.*, aggregate CPU + or GPU memory or NVMe) across all nodes. Consider initializing a model with one + trillion parameters, whose weights occupy two terabytes (TB) in half + precision. The initial CPU allocation in full precision requires 4TB of + memory *per process*, and so a system with 8 GPUs per node would need 32TB of + CPU memory due to data-parallel redundancies. Instead, by immediately + partitioning tensors we remove the redundancies. The result is that + regardless of the number of GPUs, we still only require the original 4TB. This + allows for a linear increase in model size with the aggregate system memory. + For example, if a node has 1TB of memory and 8 GPUs, we could fit a trillion + parameter model with 4 nodes and 32 GPUs. + + Important: If the fp16 weights of the model can't fit onto a single GPU memory + this feature must be used. + + .. note:: + Initializes ``deepspeed.comm`` if it has not already been done so. + See :meth:`deepspeed.init_distributed` for more information. + + .. note:: + Only applicable to training with ZeRO-3. + + Examples + -------- + + #. Allocate a model and partition it among all processes: + + .. code-block:: python + + with deepspeed.zero.Init(): + model = MyLargeModel() + + + #. Allocate a model in pinned CPU memory and partition it among a subgroup of processes: + + .. code-block:: python + + with deepspeed.zero.Init(data_parallel_group=mpu.get_data_parallel_group(), + remote_device="cpu", + pin_memory=True): + model = MyLargeModel() + + + #. Partition an already-allocated model in CPU memory: + + .. code-block:: python + + model = deepspeed.zero.Init(module=model) + """ + if config is not None: + config_dict_or_path = config + logger.warning( + f'zero.Init: the `config` argument is deprecated. Please use `config_dict_or_path` instead.') + _ds_config = deepspeed.runtime.config.DeepSpeedConfig(config_dict_or_path, + mpu) if config_dict_or_path is not None else None + if _ds_config is not None: + mem_efficient_linear = _ds_config.zero_config.memory_efficient_linear + + super().__init__(enabled=enabled, mem_efficient_linear=mem_efficient_linear, ds_config=_ds_config, dtype=dtype) + if not dist.is_initialized(): + init_distributed() + assert dist.is_initialized(), "Parameters cannot be scattered without initializing deepspeed.comm" + + if data_parallel_group is None: + self.ds_process_group = dist.get_world_group() + else: + self.ds_process_group = data_parallel_group + + if sequence_data_parallel_group is not None: + logger.warning( + f"sequence_data_parallel_group' is deprecated and will be removed. Use 'data_parallel_group' instead.") + if data_parallel_group is not None: + raise ValueError( + "Both 'data_parallel_group' and 'sequence_data_parallel_group' were specified. Please provide only one of these arguments." + ) + self.ds_process_group = sequence_data_parallel_group + + self.rank = dist.get_rank(group=self.ds_process_group) + self.dp_world_size = dist.get_world_size(group=self.ds_process_group) + + self.zero_param_process_group = zero_param_parallel_group + if _ds_config is not None and _ds_config.zero_config.zero_hpz_partition_size > 1 and self.zero_param_process_group is None: + groups._create_zero_param_parallel_group(_ds_config.zero_config.zero_hpz_partition_size) + self.zero_param_process_group = groups._get_zero_param_intra_parallel_group() + + self.num_ranks_in_param_group = self.dp_world_size + self.rank_in_group = self.rank + self.num_param_groups = 1 + + if self.zero_param_process_group is not None: + self.num_ranks_in_param_group = groups._get_zero_param_intra_parallel_group_world_size() + self.num_param_groups = int(self.dp_world_size / self.num_ranks_in_param_group) + self.rank_in_group = groups._get_zero_param_intra_parallel_rank_in_mygroup() + print_rank_0(f"hpZeRO group size: {self.num_ranks_in_param_group}", force=True) + + logger.debug( + "hpZeRO partition parameter my rank in world {} my rank in group {} ranks in my param partition group: {} " + .format(self.rank, self.rank_in_group, groups._get_zero_param_intra_parallel_group_ranks())) + + # Local device is the device where the parameters are consumed, must be default device. + # It is the device where parameters are fully instantiated using allgather + self.local_device = torch.device(get_accelerator().device_name(os.environ["LOCAL_RANK"])) + get_accelerator().set_device(self.local_device) + + self.quantized_weights = zero_quantized_weights + if _ds_config is not None and _ds_config.zero_config.zero_quantized_weights and not self.quantized_weights: + self.quantized_weights = _ds_config.zero_config.zero_quantized_weights + self.quantized_nontrainable_weights = zero_quantized_nontrainable_weights + if _ds_config is not None and _ds_config.zero_config.zero_quantized_nontrainable_weights and not self.quantized_nontrainable_weights: + self.quantized_nontrainable_weights = _ds_config.zero_config.zero_quantized_nontrainable_weights + + self.module = module + if (self.quantized_weights or self.quantized_nontrainable_weights): + self.quantizer_module = CUDAQuantizer() + print_rank_0(f'Using quantizer for weights: {self.quantizer_module.__class__.__name__}', force=True) + + if _ds_config is not None: + Init.override_module_apply = _ds_config.zero_config.override_module_apply + + if _ds_config.zero_config.offload_param is not None: + remote_device = _ds_config.zero_config.offload_param.device + pin_memory = _ds_config.zero_config.offload_param.pin_memory + + self._validate_remote_device(remote_device, _ds_config) + + # Remote device is the device where parameter partitions are stored + # It can be same as local_device or it could be CPU or NVMe. + self.remote_device = self.local_device if remote_device in [None, OffloadDeviceEnum.none] else remote_device + self.pin_memory = pin_memory if (self.remote_device in [OffloadDeviceEnum.cpu, OffloadDeviceEnum.nvme + ]) else False + + # Enable fp16 param swapping to NVMe + if self.remote_device == OffloadDeviceEnum.nvme: + self.param_swapper = param_swapper or AsyncPartitionedParameterSwapper(_ds_config, self.dtype) + else: + self.param_swapper = None + + # If we are provided an already-allocated module to prepare. + if module is not None: + assert isinstance(module, torch.nn.Module) + self._convert_to_zero_parameters(module.parameters(recurse=True)) + + self.use_all_gather_into_tensor = dist.has_all_gather_into_tensor() + if not self.use_all_gather_into_tensor: + logger.info(f"all_gather_into_tensor API is not available in torch {torch.__version__}") + + self.use_all_reduce_for_fetch_params = get_config_default(DeepSpeedZeroConfig, + "use_all_reduce_for_fetch_params") + if _ds_config is not None: + self.use_all_reduce_for_fetch_params = _ds_config.zero_config.use_all_reduce_for_fetch_params + + def _update_persist_config(self, ds_config): + Init.apply_param_persistence = True + Init.param_persistence_threshold = ds_config.zero_config.param_persistence_threshold + Init.model_persistence_threshold = ds_config.zero_config.model_persistence_threshold // self.num_partitions + + def _zero_init_param(self, param): + self._convert_to_deepspeed_param(param) + if dist.get_world_group() == self.get_dp_process_group(): + dist.broadcast(param.data, 0, self.get_dp_process_group()) + else: + dist.broadcast(param.data, dist.get_global_rank(self.get_dp_process_group(), 0), + self.get_dp_process_group()) + param.partition() + + def _convert_to_zero_parameters(self, param_list): + for param in param_list: + if is_zero_param(param): + continue + + param.data = param.data.to(self.local_device) + self._zero_init_param(param) + + def _validate_remote_device(self, remote_device, ds_config): + if ds_config is not None: + if remote_device in [None, OffloadDeviceEnum.cpu]: + if ds_config.zero_config.offload_param is not None: + offload_param_device = ds_config.zero_config.offload_param.device + assert offload_param_device != OffloadDeviceEnum.nvme, \ + f"'device' in DeepSpeed Config cannot be {offload_param_device} if remote device is {remote_device}." + + if remote_device == OffloadDeviceEnum.nvme: + assert ds_config.zero_config.offload_param is not None, \ + f'"offload_param" must be defined in DeepSpeed Config if remote device is {OffloadDeviceEnum.nvme}.' + + assert ds_config.zero_config.offload_param.nvme_path is not None, \ + f'"nvme_path" in DeepSpeed Config cannot be None if remote device is {OffloadDeviceEnum.nvme}' + + def _post_init_method(self, module): + #see_memory_usage(f"Before converting params in {module.__class__.__name__}", force=False) + print_rank_0(f'Converting Params in {module.__class__.__name__}', force=False) + see_memory_usage(f"Before converting and partitioning params in {module.__class__.__name__}", force=False) + + for name, param in module.named_parameters(recurse=False): + print_rank_0(f'Analyzing param {name} in {module.__class__.__name__}', force=False) + InsertPostInitMethodToModuleSubClasses.num_module_parameters += 1 + InsertPostInitMethodToModuleSubClasses.num_module_elements += param.numel() + if not is_zero_param(param): + if not get_accelerator().on_accelerator(param): + param.data = param.data.to(self.local_device) + + if name == 'weight' and self.quantized_initialization and type(module) in WEIGHT_QUANTIZATION_LAYERS: + _quantize_param(param, self.quantized_initialization) + + self._zero_init_param(param) + print_rank_0( + f"Partitioning param {debug_param2name_id_shape(param)} module={debug_module2name(module)}") + + see_memory_usage( + f"Param count {InsertPostInitMethodToModuleSubClasses.num_module_elements}. After converting and partitioning params in {module.__class__.__name__}", + force=False) + + def _convert_to_deepspeed_param(self, param): + + # Partitioned, Normal, Remote + param.ds_param_type = ZeroParamType.PARTITIONED + + # Replicated vs Partitioned vs Inflight + param.ds_status = ZeroParamStatus.AVAILABLE + + # Stores the shape of the original tensor + param.ds_shape = param.shape + + # Stores the number of elements in the original parameter without padding + param.ds_numel = param.numel() + + # Stores the partitioned copy of the tensor + param.ds_tensor = None + + # Keeps track of how many active sub-modules need this param at any given point in time + param.ds_active_sub_modules = set() + + # If this flag is true, then the parameters are replicated throughput training + # And only partitioned before the step + if Init.apply_param_persistence and param.ds_numel <= Init.param_persistence_threshold and Init.num_persisted_elements + param.ds_numel <= Init.model_persistence_threshold: + param.ds_persist = True + Init.num_persisted_parameters += 1 + Init.num_persisted_elements += param.ds_numel + else: + param.ds_persist = False + + param.is_external_param = False + + # The group that the parameter is scattered across. + param.ds_process_group = self.ds_process_group + + # Stores the secondary partitioned copy of the tensor + param.ds_secondary_tensor = None + + #Process group for secondary partition all (group) gather + param.ds_zero_param_process_group = self.zero_param_process_group + param.ds_secondary_tensor_group_size = self.num_ranks_in_param_group + param.ds_secondary_tensor_num_of_groups = self.num_param_groups + + # This is set to the Async Param swapper if remote device is nvme + # else this is set to None + param.nvme_swapper = self.param_swapper + + # DeepSpeed Param ID + param.ds_id = Init.param_id + Init.param_id += 1 + + def all_gather(param_list=None, async_op=False, hierarchy=0): + cls = param + if param_list is None: + param_list = [cls] + return self._all_gather(param_list, async_op=async_op, hierarchy=hierarchy) + + def _all_gather_dtype(dtype, params, world_size, rank_in_group, ds_process_group): + partition_sz = sum(p.ds_tensor.ds_numel for p in params) + + use_secondary_tensor = params[0].ds_secondary_tensor is not None + + if use_secondary_tensor: + partition_sz = sum(p.ds_tensor.ds_numel * p.ds_secondary_tensor_num_of_groups for p in params) + + flat_tensor = torch.empty(partition_sz * world_size, + dtype=dtype, + device=get_accelerator().current_device_name(), + requires_grad=False) + + partitions: List[Parameter] = [] + for i in range(world_size): + partitions.append(flat_tensor.narrow(0, partition_sz * i, partition_sz)) + + if use_secondary_tensor: + instrument_w_nvtx( + torch.cat)([p.ds_secondary_tensor.to(get_accelerator().current_device_name()) for p in params], + out=partitions[rank_in_group]) + else: + instrument_w_nvtx(torch.cat)([p.ds_tensor.to(get_accelerator().current_device_name()) for p in params], + out=partitions[rank_in_group]) + handle = _dist_allgather_fn(partitions[rank_in_group], flat_tensor, ds_process_group) + #Fix get_partition_dp_group(params[0])) + + return AllGatherCoalescedHandle( + allgather_handle=handle, + params=params, + partitions=partitions, + world_size=world_size, + use_secondary_tensor=use_secondary_tensor, + ) + + @instrument_w_nvtx + def all_gather_coalesced(params: Iterable[Parameter], + safe_mode: bool = False, + quantize: bool = False) -> AllGatherCoalescedHandle: + + # fetches from nvme if the partition is not available and in nvme + self._ensure_availability_of_partitioned_params(params) + + if self.num_partitions == 1: + return _no_gather_coalesced(params) + + for param in params: + if param.ds_status != ZeroParamStatus.NOT_AVAILABLE: + raise RuntimeError(param.ds_summary()) + param.ds_status = ZeroParamStatus.INFLIGHT + + #use appropriate all gather process group + ds_process_group = self.ds_process_group + rank_in_group = self.rank + world_size = self.dp_world_size + use_secondary_tensor = params[0].ds_secondary_tensor is not None + if self.zero_param_process_group and use_secondary_tensor: + ds_process_group = self.zero_param_process_group #intragroup + rank_in_group = self.rank_in_group + world_size = self.num_ranks_in_param_group + + #pprint(dir(ds_process_group)) + # ensure that each rank has params in same order. the allgather + # is done by flattening the parameter list into a single tensor that + # can be allgathered in a single call - this means that if each rank + # gives a list of the same parameters in a different order we will + # silently get incorrect parameter values, and have very difficult + # to debug correctness issues. + params = sorted(params, key=lambda p: p.ds_id) + + if logger.isEnabledFor(logging.DEBUG): + debug_rank0(f"-allgather_coalesced: {[p.ds_id for p in params]}") + + if safe_mode: + # ensure that same list (with same ordering) of parameters are + # being allgathered across all ranks, otherwise could mix + # data between tensors. + assert_ints_same_as_other_ranks([p.ds_id for p in params]) + # ensure that tensors from each rank agree on the same ds_numel + # otherwise could mix data between tensors. + assert_ints_same_as_other_ranks([p.ds_tensor.ds_numel for p in params]) + + if len(params) == 1: + # have an opportunity to avoid some intermediate memory allocations + param = params[0] + buffer_size = math.ceil(param.ds_numel / world_size) * world_size + if use_secondary_tensor: + buffer_size = param.ds_secondary_tensor.shape[0] * world_size #make sure out is appropriately sized + + param_ds_tensor = param.ds_secondary_tensor if use_secondary_tensor else param.ds_tensor + param_buffer = torch.empty( + buffer_size, + dtype=param_ds_tensor.dtype if not quantize else torch.int8, + device=get_accelerator().current_device_name(), + requires_grad=False, + ) + if not quantize: + handles = _dist_allgather_fn( + param_ds_tensor.to(get_accelerator().current_device_name()), + param_buffer, + ds_process_group, + ) + param.data = param_buffer.narrow(0, 0, param.ds_numel).view(param.ds_shape).to(param.device) + return AllGatherHandle(handles, param) + else: + if hasattr(param_ds_tensor, "ds_quant_scale"): + scales = param_ds_tensor.ds_quant_scale + quantized_param = param_ds_tensor.data + else: + quantized_param, scales = self.quantizer_module.quantize(param_ds_tensor) + handle = _dist_allgather_fn(quantized_param.to(get_accelerator().current_device_name()), + param_buffer, ds_process_group) + + quant_scale_buffer = torch.empty( + scales.numel() * world_size, + dtype=scales.dtype, + device=get_accelerator().current_device_name(), + requires_grad=False, + ) + quant_handle = _dist_allgather_fn(scales.to(get_accelerator().current_device_name()), + quant_scale_buffer, ds_process_group) + quant_info = QuantizationInfo() + quant_info.quantized_param = param_buffer.narrow(0, 0, param.ds_numel).view(param.ds_shape).to( + param.device) + quant_info.backend = self.quantizer_module + quant_info.quant_handle = quant_handle + quant_info.scale_buffer = quant_scale_buffer + return AllGatherHandle(handle, param, quantization=quant_info) + + else: + if self.use_all_reduce_for_fetch_params and not quantize and not use_secondary_tensor: + # Use all_reduce instead of all_gather to fetch the module params + flat_buffer_size = sum(p.ds_numel_aligned for p in params) + flat_tensor = torch.zeros(flat_buffer_size, + dtype=get_only_unique_item(p.ds_tensor.dtype for p in params), + device=get_accelerator().current_device_name(), + requires_grad=False) + start_param = 0 + for param in params: + param.data = flat_tensor.narrow(0, start_param, param.ds_numel).view(param.ds_shape) + start = start_param + param.ds_tensor.ds_numel * self.get_partition_rank() + flat_tensor.narrow(0, start, param.ds_tensor.ds_numel).copy_(param.ds_tensor) + + start_param += param.ds_numel + + handle = dist.all_reduce(flat_tensor, group=ds_process_group, async_op=True) + + return AllReduceCoalescedHandle(handle=handle, params=params) + else: + if not quantize: + dtype_params = defaultdict(list) + for p in params: + dtype_params[p.ds_tensor.dtype].append(p) + handles = [] + for dtype, params in dtype_params.items(): + handles.append( + _all_gather_dtype(dtype, params, world_size, rank_in_group, ds_process_group)) + + return MultipleAllGatherHandles(handles) + + else: + partition_sz = sum(p.ds_tensor.ds_numel for p in params) + + if use_secondary_tensor: + partition_sz = sum(p.ds_tensor.ds_numel * p.ds_secondary_tensor_num_of_groups + for p in params) + + flat_tensor = torch.empty(partition_sz * world_size, + dtype=torch.int8, + device=get_accelerator().current_device_name(), + requires_grad=False) + + if use_secondary_tensor: + if hasattr(params[0].ds_secondary_tensor, "ds_quant_scale"): + quantized_param = instrument_w_nvtx(torch.cat)([ + p.ds_secondary_tensor.data.to(get_accelerator().current_device_name()) + for p in params + ]) + scales = instrument_w_nvtx(torch.cat)([ + p.ds_secondary_tensor.ds_quant_scale.to(get_accelerator().current_device_name()) + for p in params + ]) + else: + quantized_param, scales = self.quantizer_module.quantize( + instrument_w_nvtx(torch.cat)([ + p.ds_secondary_tensor.to(get_accelerator().current_device_name()) + for p in params + ])) + else: + if hasattr(params[0].ds_tensor, "ds_quant_scale"): + quantized_param = instrument_w_nvtx(torch.cat)( + [p.ds_tensor.data.to(get_accelerator().current_device_name()) for p in params]) + scales = instrument_w_nvtx(torch.cat)([ + p.ds_tensor.ds_quant_scale.to(get_accelerator().current_device_name()) + for p in params + ]) + else: + quantized_param, scales = self.quantizer_module.quantize( + instrument_w_nvtx(torch.cat)( + [p.ds_tensor.to(get_accelerator().current_device_name()) for p in params])) + quant_scale_buffer = torch.empty( + scales.numel() * world_size, + dtype=torch.float32, + device=get_accelerator().current_device_name(), + requires_grad=False, + ) + handle = _dist_allgather_fn(quantized_param, flat_tensor, ds_process_group) + quant_handle = _dist_allgather_fn(scales, quant_scale_buffer, ds_process_group) + quant_info = QuantizationInfo() + quant_info.quantized_param = flat_tensor + quant_info.backend = self.quantizer_module + quant_info.quant_handle = quant_handle + quant_info.scale_buffer = quant_scale_buffer + quant_info.partition_sz = partition_sz + quant_info.world_size = world_size + return AllGatherCoalescedHandle( + allgather_handle=handle, + params=params, + partitions=None, + world_size=world_size, + use_secondary_tensor=use_secondary_tensor, + quantization=quant_info, + ) + + def partition(param_list=None, hierarchy=0, has_been_updated=False): + cls = param + print_rank_0(f"{'--'*hierarchy}----Partitioning param {debug_param2name_id_shape_device(cls)}", + force=False) + if param_list is None: + param_list = [cls] + self._partition(param_list, has_been_updated=has_been_updated) + + def reduce_gradients_at_owner(param_list=None, hierarchy=0): + cls = param + if param_list is None: + param_list = [cls] + print_rank_0( + f"{'--'*hierarchy}----Reducing Gradients for param with ids {[param.ds_id for param in param_list]} to owner" + ) + self._reduce_scatter_gradients(param_list) + + def partition_gradients(param_list=None, partition_buffers=None, hierarchy=0, accumulate=False): + cls = param + print_rank_0( + f"{'--'*hierarchy}----Partitioning param gradient with id {debug_param2name_id_shape_device(cls)}") + if param_list is None: + param_list = [cls] + if isinstance(partition_buffers, torch.Tensor): + partition_buffers = [partition_buffers] + + self._partition_gradients(param_list, partition_buffers=partition_buffers, accumulate=accumulate) + + def aligned_size(): + return self._aligned_size(param) + + def padding_size(): + return self._padding_size(param) + + def partition_numel(): + return self._partition_numel(param) + + def item_override(): + param.all_gather() + return param._orig_item() + + def ds_summary(slf: torch.Tensor, use_debug_name: bool = False) -> dict: + return { + "id": debug_param2name_id(slf) if use_debug_name else slf.ds_id, + "status": slf.ds_status.name, + "numel": slf.numel(), + "ds_numel": slf.ds_numel, + "shape": tuple(slf.shape), + "ds_shape": tuple(slf.ds_shape), + "requires_grad": slf.requires_grad, + "grad_shape": tuple(slf.grad.shape) if slf.grad is not None else None, + "persist": slf.ds_persist, + "active_sub_modules": slf.ds_active_sub_modules, + "ds_tensor.shape": slf.ds_tensor.shape if slf.ds_tensor is not None else None + } + + def convert_to_zero_parameters(param_list): + self._convert_to_zero_parameters(param_list) + + def allgather_before(func: Callable) -> Callable: + + def wrapped(*args, **kwargs): + param.all_gather() + return func(*args, **kwargs) + + return wrapped + + # Collectives for gathering and partitioning parameters + param.all_gather = all_gather + param.all_gather_coalesced = all_gather_coalesced + param.partition = partition + + # Collective for averaging gradients + param.reduce_gradients_at_owner = reduce_gradients_at_owner + param.partition_gradients = partition_gradients + + # Partitioning size utilities + param.aligned_size = aligned_size + param.padding_size = padding_size + param.partition_numel = partition_numel + param.ds_summary = types.MethodType(ds_summary, param) + + param.item = allgather_before(param.item) + + param.convert_to_zero_parameters = convert_to_zero_parameters + + def _aligned_size(self, param): + return param.ds_numel + self._padding_size(param) + + def _padding_size(self, param): + remainder = param.ds_numel % self.num_partitions + return (self.num_partitions - remainder) if remainder else 0 + + def _partition_numel(self, param): + return param.ds_tensor.ds_numel + + def _ensure_availability_of_partitioned_params(self, params): + swap_in_list = [] + swap_in_flight = [] + for param in params: + if param.ds_tensor.status == PartitionedParamStatus.NOT_AVAILABLE: + assert param.ds_tensor.final_location == OffloadDeviceEnum.nvme and param.ds_status == ZeroParamStatus.NOT_AVAILABLE + swap_in_list.append(param) + if param.ds_tensor.status == PartitionedParamStatus.INFLIGHT: + assert param.ds_tensor.final_location == OffloadDeviceEnum.nvme and param.ds_status == ZeroParamStatus.NOT_AVAILABLE + swap_in_flight.append(param) + if len(swap_in_list) > 0: + swap_in_list[0].nvme_swapper.swap_in(swap_in_list, async_op=False) + elif len(swap_in_flight) > 0: + swap_in_flight[0].nvme_swapper.synchronize_reads() + + @instrument_w_nvtx + def _all_gather(self, param_list, async_op=False, hierarchy=None): + + # fetches from nvme if the partition is not available and in nvme + self._ensure_availability_of_partitioned_params(param_list) + + handles = [] + all_gather_list = [] + for param in param_list: + if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: + if async_op: + handle = self._allgather_param(param, async_op=async_op, hierarchy=hierarchy) + param.ds_status = ZeroParamStatus.INFLIGHT # if async_op else ZeroParamStatus.AVAILABLE + handles.append(handle) + else: + all_gather_list.append(param) + # note: param_list may contain params that are already in flight / aviailable. So we need to use all_gather_list + if not async_op: + if len(all_gather_list) == 1: + ret_value = self._allgather_params(all_gather_list, hierarchy=hierarchy) + else: + all_gather_quantize_list = [] + all_gather_nonquantize_list = [] + for param in all_gather_list: + if hasattr(param.ds_tensor, + "ds_quant_scale") or (hasattr(param, "ds_secondary_tensor") + and hasattr(param.ds_secondary_tensor, "ds_quant_scale")): + all_gather_quantize_list.append(param) + else: + all_gather_nonquantize_list.append(param) + # _allgather_params_coalesced always return None + self._allgather_params_coalesced(all_gather_nonquantize_list, hierarchy, quantize=False) + self._allgather_params_coalesced(all_gather_quantize_list, hierarchy, quantize=True) + for param in all_gather_list: + param.ds_status = ZeroParamStatus.AVAILABLE + return None + + return handles + + def _partition(self, param_list, force=False, has_been_updated=False): + for param in param_list: + print_rank_0(f"Before Partitioning Param {param.ds_id}", force=False) + if self.zero_param_process_group is not None: + self._partition_param_sec(param) + self._partition_param(param, has_been_updated=has_been_updated) + + param.ds_status = ZeroParamStatus.NOT_AVAILABLE + # if param.ds_tensor is not None: + # assert id(param.data) == id(param.ds_tensor.data), \ + # "After the parameters are initially partitioned, make sure we are not recreating the partition." + #print_rank_0(f"After Partitioning Param {param.ds_id} {param.ds_tensor.size()} {param.ds_tensor}",force=False) + @instrument_w_nvtx + def _partition_param(self, param, buffer=None, has_been_updated=False): + assert param.ds_status is not ZeroParamStatus.INFLIGHT, f" {param} Cannot partition a param in flight" + global reuse_buffers + print_rank_0(f"Param id {param.ds_id} status is {param.ds_status}", force=False) + if param.ds_status is ZeroParamStatus.AVAILABLE: + print_rank_0(f"Partitioning param id {param.ds_id} reuse buffers {reuse_buffers}", force=False) + # if reuse_buffers and False: + # numel = buffer.numel() + # buffer = param.data.view(-1) + # print_rank_0( + # "Returning buffer for param {param.ds_id} with numel {param.ds_numel} to empty buffers", + # force=False) + # if numel in empty_buffers: + # empty_buffers[numel].append(buffer) + + # if deepspeed.comm.get_rank(): + # print(f"Releasing {param.data.numel()}") + + if param.ds_tensor is not None and not has_been_updated: ##param already partitioned + + #print_rank_0(f"Param {param.ds_id} pri {param.ds_tensor.size()} loc? {param.ds_tensor.final_location}", force=True) + #param.data = param.ds_tensor.data + + see_memory_usage(f'Before partitioning param {param.ds_id} {param.shape}', force=False) + # param.data does not store anything meaningful in partitioned state + free_param(param) + see_memory_usage(f'After partitioning param {param.ds_id} {param.shape}', force=False) + + if param.ds_tensor.final_location == OffloadDeviceEnum.nvme: + print_rank_0(f"Param {param.ds_id} partition released since it exists in nvme", force=False) + param.nvme_swapper.remove_partition_and_release_buffers([param]) + print_rank_0( + f"after swap Param {param.ds_id} {param.ds_tensor.shape} partition released since it exists in nvme", + force=False) + + return + + tensor_size = self._aligned_size(param) + partition_size = tensor_size // self.num_partitions + if param.ds_tensor is None: + final_location = None + if self.remote_device == OffloadDeviceEnum.nvme and self.param_swapper.swappable_tensor( + numel=partition_size): + final_location = OffloadDeviceEnum.nvme + buffer = self.param_swapper.get_buffer(param, partition_size) + partitioned_tensor = torch.empty(0, dtype=param.dtype, device=buffer.device) + partitioned_tensor.data = buffer.data + print_rank_0(f"ID {param.ds_id} Initializing partition for the first time for nvme offload.") + + else: + if param.ds_persist: + device = self.local_device + elif self.remote_device == OffloadDeviceEnum.nvme: + device = OffloadDeviceEnum.cpu + else: + device = self.remote_device + + partitioned_tensor = torch.empty(partition_size, dtype=param.dtype, device=device) + # quantize the tensor if it's not trainable + if not param.requires_grad and self.quantized_nontrainable_weights: + partitioned_tensor, partitioned_tensor.ds_quant_scale = self.quantizer_module.quantize( + partitioned_tensor) + + if device == OffloadDeviceEnum.cpu and self.pin_memory: + partitioned_tensor = get_accelerator().pin_memory(partitioned_tensor) + + partitioned_tensor.requires_grad = False + param.ds_tensor = partitioned_tensor + param.ds_tensor.ds_numel = partition_size + param.ds_tensor.status = PartitionedParamStatus.AVAILABLE + param.ds_tensor.final_location = final_location + param.ds_numel_aligned = tensor_size + + start = partition_size * self.get_partition_rank() + end = start + partition_size + + one_dim_param = param.contiguous().view(-1) + + if start < param.ds_numel and end <= param.ds_numel: + src_tensor = one_dim_param.narrow(0, start, partition_size) + + with torch.no_grad(): + # make sure param.ds_tensor requires_grad always be false, + # otherwise, torch tracer will complain. + param.ds_tensor.copy_(src_tensor) + + #partitioned_tensor = src_tensor.clone().detach().to(self.remote_device) + + else: + # partitioned_tensor = torch.zeros(partition_size, + # dtype=param.dtype, + # device=self.remote_device ) + + if start < param.ds_numel: + elems_to_copy = param.ds_numel - start + with torch.no_grad(): + # make sure param.ds_tensor requires_grad always be false, + # otherwise, torch tracer will complain. + param.ds_tensor.narrow(0, 0, + elems_to_copy).copy_(one_dim_param.narrow(0, start, elems_to_copy)) + + #print(f"Remote device {self.remote_device}") + + #param.ds_tensor = partitioned_tensor + + #param.data = param.ds_tensor.data + + # param.data does not store anything meaningful in partitioned state + + see_memory_usage(f'Before partitioning param {param.ds_id} {param.shape}', force=False) + free_param(param) + see_memory_usage(f'After partitioning param {param.ds_id} {param.shape}', force=False) + + if param.ds_tensor.final_location == OffloadDeviceEnum.nvme: + self.param_swapper.swap_out_and_release([param]) + print_rank_0(f"ID {param.ds_id} Offloaded to nvme offload and buffers released.") + see_memory_usage(f"ID {param.ds_id} Offloaded to nvme offload and buffers released.", force=False) + + print_rank_0(f"ID {param.ds_id} partitioned type {param.dtype} dev {param.device} shape {param.shape}") + + @instrument_w_nvtx + def _partition_param_sec(self, param, buffer=None, has_been_updated=False): + assert param.ds_status is not ZeroParamStatus.INFLIGHT, f" {param} Cannot partition a param in flight" + global reuse_buffers + ##support for NVME secondary param offload + #print_rank_0(f"SEC Param id {param.ds_id} status is {param.ds_status}", force=True) + if param.ds_status is ZeroParamStatus.AVAILABLE: + if param.ds_secondary_tensor is not None and not has_been_updated: ##param already partitioned + return + #check padding + tensor_size = self._aligned_size(param) + partition_size = tensor_size // self.dp_world_size + + secondary_partition_size = int(tensor_size // self.num_ranks_in_param_group) + if param.ds_secondary_tensor is None: + final_location = None + secondary_partitioned_tensor = torch.empty(secondary_partition_size, + dtype=param.dtype, + device=self.remote_device) + + if self.pin_memory: + secondary_partitioned_tensor = secondary_partitioned_tensor.pin_memory() + # quantize the tensor if it's not trainable + if not param.requires_grad and self.quantized_nontrainable_weights: + secondary_partitioned_tensor, secondary_partitioned_tensor.ds_quant_scale = self.quantizer_module.quantize( + secondary_partitioned_tensor) + secondary_partitioned_tensor.requires_grad = False + param.ds_secondary_tensor = secondary_partitioned_tensor + param.ds_secondary_tensor.ds_numel = secondary_partition_size + param.ds_secondary_tensor.status = PartitionedParamStatus.AVAILABLE + param.ds_secondary_tensor.final_location = final_location + + #use rank in group for secondary tensor + secondary_start = secondary_partition_size * self.rank_in_group + + secondary_end = secondary_start + secondary_partition_size + + one_dim_param = param.contiguous().view(-1) + + # ds_numel is unpadded, so the last chunk of the secondary tensor might not be secondary_partition_size + sec_numel = max(0, min(param.ds_numel - secondary_start, secondary_partition_size)) + + # copy from full tensor to secondary tensor + param.ds_secondary_tensor.narrow(0, 0, + sec_numel).copy_(one_dim_param.narrow(0, secondary_start, sec_numel)) + + # TODO: This is a temporary fix to avoid the issue that 2nd tensor all-gather happens before 2nd tensor partition is done + if not get_accelerator().resolves_data_dependency(): + get_accelerator().current_stream().synchronize() + + print_rank_0(f"{param.ds_id} partitioned type {param.dtype} dev {param.device} shape {param.shape}", + force=False) + + def _param_status(self, param): + if param.ds_tensor is not None: + print_rank_0( + f"Param id {param.ds_id}, param status: {param.ds_status}, param numel {param.ds_numel}, partitioned numel {param.ds_tensor.numel()}, data numel {param.data.numel()}" + ) + else: + print_rank_0( + f"Param id {param.ds_id}, param status: {param.ds_status}, param numel {param.ds_numel}, partitioned ds_tensor {param.ds_tensor}, data numel {param.data.numel()}" + ) + + def _allgather_param(self, param, async_op=False, hierarchy=0): + + partition_size = param.ds_tensor.ds_numel + + tensor_size = partition_size * self.num_partitions + aligned_param_size = self._aligned_size(param) + assert tensor_size == aligned_param_size, f'param id {param.ds_id} aligned size {aligned_param_size} does not match tensor size {tensor_size}' + + print_rank_0( + f"{'--'* hierarchy}---- Before allocating allgather param {debug_param2name_id_shape_status(param)} partition size={partition_size}" + ) + + see_memory_usage( + f'Before allocate allgather param {debug_param2name_id_shape_status(param)} partition_size={partition_size} ', + force=False) + flat_tensor = torch.zeros(aligned_param_size, dtype=param.dtype, device=param.device).view(-1) + see_memory_usage( + f'After allocate allgather param {debug_param2name_id_shape_status(param)} {aligned_param_size} {partition_size} ', + force=False) + + if not get_accelerator().resolves_data_dependency(): + get_accelerator().synchronize() + + print_rank_0( + f"{'--'* hierarchy}----allgather param with {debug_param2name_id_shape_status(param)} partition size={partition_size}" + ) + # if not flat_tensor.numel() > 100000: + # replicated_tensor = flat_tensor.narrow(0, + # 0, + # param.ds_numel).view(param.ds_shape) + # param.data = replicated_tensor.data + # return None + if self.use_all_gather_into_tensor: + handle = dist.all_gather_into_tensor(flat_tensor, + param.ds_tensor.to(get_accelerator().device_name()), + group=self.get_partition_dp_group(param), + async_op=async_op) + else: + partitions = [] + for i in range(self.num_partitions): + partitions.append(flat_tensor.narrow(0, partition_size * i, partition_size)) + + if i == dist.get_rank(group=self.get_partition_dp_group(param)): + partitions[i].data.copy_(param.ds_tensor.data, non_blocking=True) + + handle = dist.all_gather(partitions, + partitions[self.get_partition_rank()], + group=self.get_partition_dp_group(param), + async_op=async_op) + + replicated_tensor = flat_tensor.narrow(0, 0, param.ds_numel).view(param.ds_shape) + param.data = replicated_tensor.data + return handle + + def _allgather_params_coalesced(self, param_list, hierarchy=0, quantize=False): + """ blocking call + avoid explicit memory copy in _allgather_params + """ + if len(param_list) == 0: + return + + if self.num_partitions == 1: + handle = _no_gather_coalesced(param_list) + handle.wait() + return None + + # collect local tensors and partition sizes + partition_sizes = [] + local_tensors = [] + if quantize: + quantize_scale_sizes = [] + quantize_scale_tensors = [] + for param in param_list: + partition_sizes.append(param.ds_tensor.ds_numel) + local_tensors.append(param.ds_tensor.to(get_accelerator().device_name())) + if quantize: + quantize_scale_sizes.append(param.ds_tensor.ds_quant_scale.numel()) + quantize_scale_tensors.append(param.ds_tensor.ds_quant_scale.to(get_accelerator().device_name())) + # allocate memory for allgather params + allgather_params = [] + if quantize: + allgather_quantize_scale = [] + for psize in partition_sizes: + tensor_size = psize * self.num_partitions + flat_tensor = torch.empty(tensor_size, dtype=param_list[0].ds_tensor.dtype, + device=self.local_device).view(-1) + flat_tensor.requires_grad = False + allgather_params.append(flat_tensor) + if quantize: + for psize in quantize_scale_sizes: + tensor_size = psize * self.num_partitions + flat_tensor = torch.empty(tensor_size, + dtype=param_list[0].ds_tensor.ds_quant_scale.dtype, + device=self.local_device).view(-1) + flat_tensor.requires_grad = False + allgather_quantize_scale.append(flat_tensor) + + # launch + launch_handles = [] + launch_quantize_handles = [] + for param_idx, param in enumerate(param_list): + input_tensor = local_tensors[param_idx].view(-1) + + if self.use_all_gather_into_tensor: + # try the _all_gather_base from Pytorch master + h = dist.all_gather_into_tensor(allgather_params[param_idx], + input_tensor, + group=self.get_partition_dp_group(param), + async_op=True) + if quantize: + quantize_handle = dist.all_gather_into_tensor(allgather_quantize_scale[param_idx], + quantize_scale_tensors[param_idx], + group=self.get_partition_dp_group(param), + async_op=True) + launch_quantize_handles.append(quantize_handle) + else: + output_list = [] + for i in range(self.num_partitions): + psize = partition_sizes[param_idx] + partition = allgather_params[param_idx].narrow(0, i * psize, psize) + output_list.append(partition) + if not get_accelerator().on_accelerator(partition): + logger.warning( + f'param {param_idx}, partition {i} is not on CUDA, partition shape {partition.size()}') + + # back to old all_gather function + h = dist.all_gather(output_list, input_tensor, group=self.get_partition_dp_group(param), async_op=True) + if quantize: + output_scale_list = [] + for i in range(self.num_partitions): + psize = quantize_scale_sizes[param_idx] + partition = allgather_quantize_scale[param_idx].narrow(0, i * psize, psize) + output_scale_list.append(partition) + quant_handle = dist.all_gather(output_scale_list, + quantize_scale_tensors[param_idx], + group=self.get_partition_dp_group(param), + async_op=True) + launch_quantize_handles.append(quant_handle) + launch_handles.append(h) + + # Wait ensures the operation is enqueued, but not necessarily complete. + launch_handles[-1].wait() + if quantize: + for quant_handle in launch_quantize_handles: + quant_handle.wait() + + # assign to param.data (not copy) + for i, param in enumerate(param_list): + gathered_tensor = allgather_params[i] + if quantize: + gathered_tensor = self.quantizer_module.dequantize(gathered_tensor, allgather_quantize_scale[i]) + param.data = gathered_tensor.narrow(0, 0, param.ds_numel).view(param.ds_shape).data + + # guarantee the communication to be completed + if not get_accelerator().resolves_data_dependency(): + get_accelerator().synchronize() + + return None + + def _allgather_params(self, param_list, hierarchy=0): + if len(param_list) == 0: + return + + partition_size = sum([param.ds_tensor.ds_numel for param in param_list]) + + tensor_size = partition_size * self.num_partitions + flat_tensor = torch.empty(tensor_size, dtype=param_list[0].ds_tensor.dtype, device=self.local_device) + flat_tensor.requires_grad = False + partitions = [] + for i in range(self.num_partitions): + start = partition_size * i + + partitions.append(flat_tensor.narrow(0, start, partition_size)) + + if i == self.get_partition_rank(): + offset = 0 + for param in param_list: + param_numel = param.ds_tensor.ds_numel + + partitions[i].narrow(0, offset, param_numel).copy_(param.ds_tensor.data) + + offset += param_numel + + if hasattr(param_list[0], 'ds_quant_scale'): + scale_size = sum([param.ds_tensor.ds_quant_scale.numel() for param in param_list]) + scale_tensor_size = scale_size * self.world_size + flat_scale_tensor = torch.empty(scale_tensor_size, + dtype=param_list[0].ds_tensor.ds_quant_scale.dtype, + device=self.local_device) + flat_scale_tensor.requires_grad = False + scale_partitions = [] + for i in range(self.world_size): + start = scale_tensor_size * i + scale_partitions.append(flat_scale_tensor.narrow(0, start, scale_tensor_size)) + if i == self.rank: + offset = 0 + for param in param_list: + param_scale_numel = param.ds_tensor.ds_quant_scale.ds_numel + + scale_partitions[i].narrow(0, offset, + param_scale_numel).copy_(param.ds_tensor.ds_quant_scale.data) + + offset += param_scale_numel + + dist.all_gather_into_tensor(flat_tensor, + partitions[self.get_partition_rank()], + group=self.get_partition_dp_group(param), + async_op=False) + if hasattr(param_list[0], 'ds_quant_scale'): + dist.all_gather(flat_scale_tensor, + param_list[0].ds_quant_scale, + group=self.get_partition_dp_group(param), + async_op=False) + param_offset = 0 + + for param in param_list: + param_partition_size = param.ds_tensor.ds_numel + param_size = param.ds_numel + replicated_tensor = torch.empty(param.ds_shape, dtype=param.ds_tensor.dtype, device=self.local_device) + + for i in range(self.num_partitions): + + start = i * partition_size + + param_start = i * param_partition_size + + if param_start < param_size: + numel_to_copy = min(param_size - param_start, param_partition_size) + + part_to_copy = partitions[i].narrow(0, param_offset, numel_to_copy) + + replicated_tensor.view(-1).narrow(0, param_start, numel_to_copy).copy_(part_to_copy) + #param_offset += param.data.numel() + param_offset += param.ds_tensor.ds_numel + if hasattr(param_list[0], 'ds_quant_scale'): + replicated_tensor = self.quantizer_module.dequantize(replicated_tensor, flat_scale_tensor) + param.data = replicated_tensor.data + + return None + + def _reduce_scatter_gradients(self, param_list): + #print_rank_0([param.grad for param in param_list]) + #assert any([param.grad is None for param in param_list]), "None gradients cannot be reduce scattered" + + handles_and_reduced_partitions = [] + for param in param_list: + assert param.grad.numel( + ) == param.ds_numel, f"{param.grad.numel()} != {param.ds_numel} Cannot reduce scatter gradients whose size is not same as the params" + + handles_and_reduced_partitions.append(self._reduce_scatter_gradient(param)) + + for param, (handle, reduced_partition) in zip(param_list, handles_and_reduced_partitions): + if handle is not None: + handle.wait() + + # some ranks may have partitions that are padded to go beyond the grad size. + # For these ranks the output of reduce scatter is a separate buffer and needs + # to be copied in + partition_size = param.ds_tensor.ds_numel + start = self.get_partition_rank() * partition_size + end = start + partition_size + #print_rank_0("REduce scatter was executed for param {param.ds_id}") + if start < param.ds_numel < end: + elements = param.ds_numel - start + param.grad.view(-1).narrow(0, start, elements).copy_(reduced_partition.narrow(0, 0, elements)) + + def _reduce_scatter_gradient(self, param): + + partition_size = param.ds_tensor.ds_numel + #output = torch.empty(partition_size, dtype=param.dtype, device=param.device) + + total_size = partition_size * self.num_partitions + input_list = [] + + for i in range(self.num_partitions): + + start = i * partition_size + end = start + partition_size + + #print("before reduce scatter gradients") + if start < param.ds_numel and end <= param.ds_numel: + input = param.grad.view(-1).narrow(0, start, partition_size) + else: + input = torch.zeros(partition_size, dtype=param.dtype, device=param.device) + + if start < param.ds_numel: + elements = param.ds_numel - start + input.narrow(0, 0, elements).copy_(param.grad.view(-1).narrow(0, start, elements)) + #print("after reduce scatter gradients") + input_list.append(input) + + rank = dist.get_rank(group=self.get_partition_dp_group(param)) + handle = dist.reduce_scatter(input_list[rank], + input_list, + group=self.get_partition_dp_group(param), + async_op=True) + + return handle, input_list[rank] + + def _partition_gradients(self, param_list, partition_buffers=None, accumulate=False): + if partition_buffers is None: + partition_buffers = [None] * len(param_list) + + for param, partition_buffer in zip(param_list, partition_buffers): + self._partition_gradient(param, partition_buffer=partition_buffer, accumulate=accumulate) + + def _partition_gradient(self, param, partition_buffer=None, accumulate=False): + + #import pdb;pdb.set_trace() + # param.grad=None + # param.grad.test() + print_rank_0( + f"Partitioning param {param.ds_id} gradient of size {param.grad.numel()} type {param.grad.dtype} part_size {param.ds_tensor.ds_numel}" + ) + see_memory_usage("Before partitioning gradients", force=False) + partition_size = param.ds_tensor.ds_numel + + if partition_buffer is None: + assert not accumulate, "No buffer to accumulate to" + partition_buffer = torch.zeros(partition_size, dtype=param.dtype, device=param.device) + else: + assert partition_buffer.numel( + ) >= partition_size, f"The partition buffer size {partition_buffer.numel()} should match the size of param.ds_tensor {partition_size}" + + rank = dist.get_rank(group=self.get_partition_dp_group(param)) + start = partition_size * rank + end = start + partition_size + + dest_tensor_full_buffer = partition_buffer.view(-1).narrow(0, 0, partition_size) + + #print("before partition gradients") + if start < param.ds_numel: + elements = min(param.ds_numel - start, partition_size) + + dest_tensor = dest_tensor_full_buffer.narrow(0, 0, elements) + src_tensor = param.grad.view(-1).narrow(0, start, elements) + + # just copy the grad partition to the buffer + if not accumulate: + dest_tensor.copy_(src_tensor) + + # if source and destination are on same device, + # add to the provided buffer + elif src_tensor.device == dest_tensor.device: + dest_tensor.add_(src_tensor) + + # if source and destination are on different device, copy first to src + # then add and move back to the destination. This seems to run faster + # when src is gpu and dest is cpu + # adding directly to cpu is very slow + else: + acc_tensor = torch.empty(src_tensor.numel(), dtype=param.dtype, device=param.device) + + acc_tensor.copy_(dest_tensor) + acc_tensor.add_(src_tensor) + dest_tensor.copy_(acc_tensor) + + # partition_buffer.view(-1).narrow( + # 0, + # 0, + # elements).copy_(param.grad.view(-1).narrow(0, + # start, + # elements)) + + #print("after partition gradients") + param.grad.data = dest_tensor_full_buffer.data + see_memory_usage("After partitioning gradients", force=False) + + def get_partition_dp_group(self, param): + return param.ds_process_group + + def get_partition_rank(self): + """subclass can overload to specify different relative rank in + parameter partition group""" + return self.rank + + @property + def num_partitions(self): + return self.dp_world_size + + def get_dp_process_group(self): + """ Return the communication group with all data-parallel ranks """ + return self.ds_process_group + + +class GatheredParameters: + + def __init__(self, params, modifier_rank=None, fwd_module=None, enabled=True): + """A context that collects parameters that were partitioned via a + :class:`deepspeed.zero.Init` context. The parameters are partitioned + again upon exit. + + Args: + params (``torch.nn.Parameter``): A single parameter, or an iterable of parameters (list, tuple, generator) of parameters to collect. + It's assumed that all parameters are zero params. + modifier_rank (int, optional): If specified, this rank's parameter will be + broadcasted on exit from the context. This argument is required if ``params`` are + modified, so that all processes have a consistent view of the data. Defaults + to ``None``. + fwd_module (``torch.nn.Module``, optional): If specified, ``params`` will be + registered as external parameters of ``fwd_module``. See :meth:`deepspeed.zero.register_external_parameter`. + enabled (bool, optional): If ``False``, this context is a no-op. Defaults to ``True``. + + Important: Make sure to use ``modifier_rank`` that is not ``None`` (e.g., ``modifier_rank=0``) + if you need the GPU memory allocated by gather to be released upon exit from the context manager. + + Important: if ``params`` isn't an iterable of parameters or a single parameter it'll be silently ignored! + + Examples + ======== + + #. Allocate a partitioned module, initialize its weight on rank 0, and update all + processes. + + .. code-block:: python + + with deepspeed.zero.Init(): + linear = torch.nn.Linear(1000,1000) + + with deepspeed.zero.GatheredParameters(linear.weight, + modifier_rank=0): + if deepspeed.comm.get_rank() == 0: + linear.weight.zero_() + + with deepspeed.zero.GatheredParameters(linear.weight, + modifier_rank=0): + if deepspeed.comm.get_rank() == 0: + linear.weight.zero_() + + #. Collect a partitioned weight to pass to another module during + training. The parameter will be registered as an external parameter + and made available during the backward pass. + + .. code-block:: python + :emphasize-lines: 6 + + def forward(self, input): + x = self.layer1(input) + + # self.layer1.weight is required by self.layer2.forward + with deepspeed.zero.GatheredParameters(self.layer1.weight, + fwd_module=self): + y = self.layer2(x, self.layer1.weight) + return y + + + #. Pretrained model loading + + .. code-block:: python + + with deepspeed.zero.Init(): + model = MyModel() + + state_dict = torch.load(model_path, map_location="cpu") + + def load(module: nn.Module, prefix=""): + # because zero3 puts placeholders in model params, this context + # manager gathers (unpartitions) the params of the current layer, then loads from + # the state dict and then re-partitions them again + with deepspeed.zero.GatheredParameters(list(module.parameters(recurse=False)), modifier_rank=0): + if deepspeed.comm.get_rank() == 0: + module._load_from_state_dict(state_dict, prefix) + + for name, child in module._modules.items(): + if child is not None: + load(child, prefix + name + ".") + + load(model, prefix="") + + If this approach is not used, then the full model will first be copied to each GPU. For models + bigger than the memory of a single GPU, this method is required. + """ + + self.enabled = enabled + if not enabled: + return + + if isinstance(params, Iterable) and not isinstance(params, torch.Tensor): + # deal with generators like model.parameters() + # must convert to list to be able to iterate more than once if we get a generator + params = list(params) + else: + # single param + params = [params] + # enable if at least one is zero-param, otherwise a noop + if not any(is_zero_param(p) for p in params): + self.enabled = False + return + + self.params = [p for p in params if hasattr(p, "ds_id")] + self.params = sorted( + set(self.params), key=lambda x: x.ds_id + ) # remove the duplicates to prevent racing condition, we must also make sure the order is the same on all ranks otherwise we'll get deadlocks + self.src_rank = None + if modifier_rank is not None: + if self.params[0].ds_process_group == dist.get_world_group(): + self.src_rank = modifier_rank + else: + # A group was specified; convert DP rank to global rank + self.src_rank = dist.get_global_rank(self.params[0].ds_process_group, modifier_rank) + self.fwd_module = fwd_module + if self.fwd_module is not None: + # is a no-op if already registered + for p in self.params: + register_external_parameter(self.fwd_module, p) + + def __enter__(self): + if not self.enabled: + return + self.params[0].all_gather(param_list=self.params) + + def __exit__(self, *exc): + if not self.enabled: + return + if self.src_rank is None: + self.params[0].partition(param_list=self.params, has_been_updated=False) + return + + handles = [dist.broadcast(p.data, self.src_rank, group=p.ds_process_group, async_op=True) for p in self.params] + for h in handles: + h.wait() + self.params[0].partition(param_list=self.params, has_been_updated=True) diff --git a/toolbox/DeepSpeed/v0.15.3/patches/deepspeed/runtime/zero/partitioned_param_coordinator.py b/toolbox/DeepSpeed/v0.15.3/patches/deepspeed/runtime/zero/partitioned_param_coordinator.py new file mode 100644 index 0000000000000000000000000000000000000000..262446978ea76377e7053d6d725e08010d5b48c8 --- /dev/null +++ b/toolbox/DeepSpeed/v0.15.3/patches/deepspeed/runtime/zero/partitioned_param_coordinator.py @@ -0,0 +1,626 @@ +#!/usr/bin/env python3 +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from dataclasses import dataclass +import collections +from collections import UserDict +from typing import Deque, Set + +from deepspeed import comm as dist +from deepspeed.utils import z3_leaf_module +from deepspeed.utils.logging import logger +from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum +from deepspeed.runtime.zero.partition_parameters import * +from deepspeed.runtime.zero.partitioned_param_profiler import PartitionedParameterProfiler +from deepspeed.runtime.swap_tensor.partitioned_param_swapper import PartitionedParamStatus +from deepspeed.utils.debug import debug_module2name_id, debug_param2name_id +from deepspeed.accelerator import get_accelerator +import deepspeed.runtime.compiler as compiler + +import logging + +ENABLE_PROFILER = False + + +def debug_rank0(message: str) -> None: + if dist.get_rank() == 0: + logger.debug(message) + + +@instrument_w_nvtx +def get_all_parameters(sub_module, recurse=False): + return itertools.chain(sub_module.named_parameters(recurse=recurse), sub_module.ds_external_parameters()) + + +@compiler.disable +def iter_params(module: Module, recurse=False) -> Iterable[Parameter]: + return map(lambda pair: pair[1], get_all_parameters(module, recurse)) + + +class ZeRoTraceMode(Enum): + # Record trace of the network during a single forward+backward (for training) or forward (for inference) + RECORD = 1 + # Use recorded network trace to optimize current forward+backward or forward + COMPLETE = 2 + # Recorded trace does not match current forward+backward or forward pass. + INVALID = 3 + + +class InflightParamRegistry(UserDict): + """registry for parameters in flight""" + + def __setitem__(self, param: Parameter, handle: AllGatherCoalescedHandle) -> None: + if param in self.data: + raise RuntimeError(f"{param.ds_summary()} already in registry") + if param.ds_status != ZeroParamStatus.INFLIGHT: + raise RuntimeError(f"attempted to add non-inflight parameter to registry {param.ds_summary()}") + self.data[param] = handle + + +class PartitionedParameterCoordinator: + FORWARD_FETCH_SUBMIT = 'forward_fetch_submit' + FORWARD_FETCH_WAIT = 'forward_fetch_wait' + FORWARD_PREFETCH_SUBMIT = 'forward_prefetch_submit' + BACKWARD_FETCH_SUBMIT = 'backward_fetch_submit' + BACKWARD_FETCH_WAIT = 'backward_fetch_wait' + BACKWARD_PREFETCH_SUBMIT = 'backward_prefetch_wait' + FORWARD_ALL_GATHER = 'forward_all_gather' + BACKWARD_ALL_GATHER = 'backward_all_gather' + """Handles partitioning and gathering of parameters.""" + + @dataclass + class __ParamInTrace: + param: Parameter + step_id_last_used_at: int + + def __init__( + self, + prefetch_bucket_sz: int, + max_reuse_distance_in_numel: int, + max_available_parameters_in_numel: int, + allgather_stream: get_accelerator().Stream, + inflight_param_registry: InflightParamRegistry, + prefetch_nvme: bool = False, + timers=None, + zero_quantized_weights=False, + zero_quantized_nontrainable_weights=False, + ) -> None: + # mapping of param -> handle for each param that is currently in flight + self.__inflight_param_registry = inflight_param_registry + # keeps track of the number of submodules invoked so far. + self.__step_id: int = 0 + # network tracing mode + self.__trace_mode: ZeRoTraceMode = ZeRoTraceMode.RECORD + # sequence of submodules/parameters in forward pass + backward pass + self.__submodule_order: Iterable[Module] = [] + self.__param_order: Iterable[__class__.__ParamInTrace] = [] + self.__most_recent_step_id_param_fetched_for = collections.defaultdict(lambda: int(-1e10)) + self.__step_id_module_fetched_for = collections.defaultdict(lambda: collections.deque()) + # number of available params, and max number of available params + self.__n_available_params: int = 0 + self.__max_n_available_params: int = max_available_parameters_in_numel + # max distance between two use of the module beyond which module is released + self.__max_reuse_dist_in_numel: int = max_reuse_distance_in_numel + # queue for parameters to fetch. parameters will be popped off the left + # side of the dequeue as they are fetched + self.__param_queue: Deque[__class__.__ParamInTrace] = None + self.__prefetch_bucket_sz: int = prefetch_bucket_sz + self.__prefetch_nvme: bool = prefetch_nvme + self.hierarchy: int = 0 + self.zero_quantized_weights = zero_quantized_weights + self.zero_quantized_nontrainable_weights = zero_quantized_nontrainable_weights + + # stream that will be used for allgather operations + self.__allgather_stream: get_accelerator().Stream = allgather_stream + + # limit the number of fetch events that can be queued at once + # otherwise, what happens is memory is allocated by the host thread at the + # time of the call, but not used until later by the asynchronous cuda stream. + # allowing an infinite number of these to queue up causes a lot of memory + # pressure that then becomes detrimental to performance. + # this is a much less elegant way of fixing this vs something like using + # cudaMallocAsync/cudaFreeAsync. Choosing to not expose this to the user now + # because ideally in the future its replaced by an async allocation + # mechanism which doesn't require any configuration by the user. + self.__ongoing_fetch_events: Deque[get_accelerator().Event] = collections.deque() + # TODO. make this configurable via JSON + self.__max_ongoing_fetch_events: int = 2 + self.__profiler = PartitionedParameterProfiler(timers if ENABLE_PROFILER else None) + + self.opt_modules = os.getenv("OPT_MODULES", None) + if self.opt_modules: + self.limit_module_id = os.getenv("LIMIT_MODULE_ID", "0,100").split(',') + import math + self.use_norm_pad = 32 + self.limit_shape = math.ceil(int(os.getenv("LIMIT_SHAPE", 500)) / self.use_norm_pad) * self.use_norm_pad + + self.opt_modules = self.opt_modules.upper().split(',') + self.record_opt_modules = list() + self.device = torch.cuda.current_device() + self.force_release_false = torch.tensor(0, device=self.device, dtype=torch.int32) + self.force_release_true = torch.tensor(1, device=self.device, dtype=torch.int32) + self.stop_current_step_opt = False + self.cur_shape = -1 + + """Tracing and Tracking + TODO. consider performing trace before initializing PartitionedParameterCoordinator + and passing trace results into constructor. This way all the code in here can + just assume that the trace is complete and the results can be entirely + immutable. + + Bookkeeping operations used to track where we are in the forward/backward pass + """ + + def _clear_trace_structures(self) -> None: + self.__submodule_order = [] + self.__param_order = [] + self.__most_recent_step_id_param_fetched_for = collections.defaultdict(lambda: int(-1e10)) + self.__param_queue = None + + def is_complete_trace(self) -> bool: + return self.__trace_mode == ZeRoTraceMode.COMPLETE + + def is_invalid_trace(self) -> bool: + return self.__trace_mode == ZeRoTraceMode.INVALID + + def is_record_trace(self) -> bool: + return self.__trace_mode == ZeRoTraceMode.RECORD + + def _clean_inflight_param_registry(self) -> None: + for param, handle in self.__inflight_param_registry.items(): + handle.wait() + self.__release_param(param) + self.__inflight_param_registry.clear() + + def _invalidate_trace(self) -> None: + if self.is_invalid_trace(): + raise RuntimeError("attempted to invalidate already invalid trace") + self.__trace_mode = ZeRoTraceMode.INVALID + self._clear_trace_structures() + self._clean_inflight_param_registry() + + def trace_prologue(self, sub_module: Module) -> None: + if self.is_complete_trace(): + # sub_module must match expectation else invalidate trace cache + if len(self.__submodule_order) <= self.__step_id: + print_rank_0( + f"Invalidate trace cache @ step {self.__step_id} and module {sub_module.id}: " + f"cache has only {len(self.__submodule_order)} modules", + force=True) + self._invalidate_trace() + return + + if sub_module != self.__submodule_order[self.__step_id]: + expected_module_id = self.__submodule_order[self.__step_id].id + print_rank_0( + f"Invalidate trace cache @ step {self.__step_id}: " + f"expected module {expected_module_id}, but got module {sub_module.id}", + force=True) + self._invalidate_trace() + + @compiler.disable + def record_module(self, sub_module: Module) -> None: + """adds sub module to trace""" + if not self.is_record_trace(): + raise RuntimeError(f"attempted to record trace when status = {self.__trace_mode}") + + self.__submodule_order.append(sub_module) + self.__step_id_module_fetched_for[sub_module.id].append(self.__step_id) + + def record_parameters(self, sub_module: Module) -> None: + """adds sub module to trace""" + if not self.is_record_trace(): + raise RuntimeError(f"attempted to record trace when status = {self.__trace_mode}") + + step_id = self.__step_id_module_fetched_for[sub_module.id].popleft() + for param in sorted(set(iter_params(sub_module, recurse=z3_leaf_module(sub_module))), key=lambda p: p.ds_id): + self.__param_order.append(__class__.__ParamInTrace(param=param, step_id_last_used_at=step_id)) + + def construct_parameter_trace_from_module_trace(self): + """use module trace to construct parameter trace""" + self.__param_order = [] + for sub_module in self.__submodule_order: + self.record_parameters(sub_module) + + def reset_step(self) -> None: + """indicate that we have completed one fwd+bwd for the model""" + self._clean_inflight_param_registry() + if self.opt_modules: + self.stop_current_step_opt = False + self.force_release_false = torch.tensor(0, device=self.device, dtype=torch.int32) + self.force_release_true = torch.tensor(1, device=self.device, dtype=torch.int32) + + if self.__inflight_param_registry: + raise RuntimeError(f"still have inflight params " + f"{[p.ds_summary() for p in self.__inflight_param_registry.keys()]}") + + if not self.is_complete_trace(): # not self.trace_complete: + # Make sure that recorded submodule orders are identical across ranks + assert_ints_same_as_other_ranks([m.id for m in self.__submodule_order]) + + if self.is_record_trace(): + # Successfully recorded a trace + self.construct_parameter_trace_from_module_trace() + # Make sure that recorded parameter orders are identical across ranks + assert_ints_same_as_other_ranks([p.param.ds_id for p in self.__param_order]) + assert_ints_same_as_other_ranks([p.step_id_last_used_at for p in self.__param_order]) + + self.__submodule_order = tuple(self.__submodule_order) # freeze + self.__param_order = tuple(self.__param_order) # freeze + self.__trace_mode = ZeRoTraceMode.COMPLETE + print_rank_0( + f"completed record trace of {len(self.__submodule_order)} sub modules: {[m.id for m in self.__submodule_order]}", + force=False) + else: + # Enable trace recording for next forward/backward pass + self.__trace_mode = ZeRoTraceMode.RECORD + + else: + if self.__profiler is not None: + self.__profiler.log_events() + + self.__param_queue = collections.deque(self.__param_order) # reset fetch queue + self.__most_recent_step_id_param_fetched_for = collections.defaultdict(lambda: int(-1e10)) + self.__step_id_module_fetched_for = collections.defaultdict(lambda: collections.deque()) + self.__step_id = 0 + self.__n_available_params = 0 + self.__profiler.reset_events() + + def _dump_params(self, tag, sub_module, params, step_id=None): + if step_id is None: + step_id = self.__step_id + param_names = [debug_param2name_id(p) for p in params] + print_rank_0(f'{tag} step = {step_id} mod = {debug_module2name_id(sub_module)} p_names = {param_names}', + force=False) + + def _dump_param_ids(self, tag, mod_id, p_ids, step_id=None): + if step_id is None: + step_id = self.__step_id + print_rank_0(f'{tag} mod = {mod_id}, step = {step_id}, p_ids = {p_ids}', force=False) + + """Fetch and Release + Fetching, prefetching, and releasing parameters + """ + + @compiler.disable + @instrument_w_nvtx + @torch.no_grad() + def fetch_sub_module(self, current_submodule: Module, forward: bool, args=None) -> None: + """This method does the following (in order): + 1. kick off fetch for parameters in immediately required sub module + 2. kick off fetch for next few parameters we will need later (prefetch) + 3. block on parameters in immediately required sub module + """ + if logger.isEnabledFor(logging.DEBUG): + debug_rank0( + f"{self.__step_id}: M{current_submodule.id}({type(current_submodule).__name__}) P{[p.ds_id for p in iter_params(current_submodule, recurse=z3_leaf_module(current_submodule))]} " + + str({ + "avail": f"{self.__n_available_params:.1e}", + "queue_sz": f"{len(self.__param_queue or [])}", + "inflight": [p.ds_id for p in self.__inflight_param_registry], + })) + + if self.opt_modules and forward and current_submodule._get_name().upper() == 'EMBEDDING': + self.cur_shape = args[0][0].shape[-1] if args is not None else 0 + force_tensor = self.force_release_true if self.cur_shape > self.limit_shape else self.force_release_false + dist.all_reduce(force_tensor, group=dist.get_world_group()) + force_release = True if (force_tensor > 0) else False + + if force_release: + self.stop_current_step_opt = True + if len(self.record_opt_modules) > 0: + # 需要把opt参数都释放掉 + for opt_module in self.record_opt_modules: + if hasattr(opt_module, "HOLD_COMPLETE"): + opt_module.__delattr__("HOLD_COMPLETE") + params = set(p for p in iter_params(opt_module, recurse=z3_leaf_module(opt_module))) + for p in params: + if p.ds_status == ZeroParamStatus.HOLD_COMPLETE: + p.ds_status = ZeroParamStatus.AVAILABLE + p.ds_active_sub_modules.discard(opt_module.id) + self.__release_param(p) + self.record_opt_modules.clear() + torch.cuda.empty_cache() + + params_to_fetch = frozenset(iter_params(current_submodule, recurse=z3_leaf_module(current_submodule))) + fetch_numel = sum( + [p.partition_numel() for p in params_to_fetch if p.ds_status == ZeroParamStatus.NOT_AVAILABLE]) + + # opt参数已经常驻 + already_complete_cached = False + if self.opt_modules and not self.stop_current_step_opt and (current_submodule._get_name().upper() in self.opt_modules) and hasattr(current_submodule, "HOLD_COMPLETE"): + for p in params_to_fetch: + if p.ds_status == ZeroParamStatus.HOLD_COMPLETE: + already_complete_cached = True + break + + run_fetch = not already_complete_cached if self.opt_modules else True + if run_fetch: + if fetch_numel > 0: + event_name = __class__.FORWARD_FETCH_SUBMIT if forward else __class__.BACKWARD_FETCH_SUBMIT + self._dump_param_ids(event_name, current_submodule.id, + [p.ds_id for p in params_to_fetch if p.ds_status == ZeroParamStatus.NOT_AVAILABLE]) + self.__profiler.start_event(event_name) + # kick off all gather for params in the immediately required submodule + #for param in params_to_fetch: + if logger.isEnabledFor(logging.DEBUG): + for param in params_to_fetch: + debug_rank0(f"-fetch: {param.ds_summary()}") + self.__all_gather_params(params_to_fetch, forward) + self.__profiler.stop_event(event_name, fetch_numel) + + wait_numel = 0 + wait_event_name = __class__.FORWARD_FETCH_WAIT if forward else __class__.BACKWARD_FETCH_WAIT + self.__profiler.start_event(wait_event_name) + # wait for parameters in the immediately needed submodule to become available + for param in params_to_fetch: + param.ds_active_sub_modules.add(current_submodule.id) + if logger.isEnabledFor(logging.DEBUG): + debug_rank0(f"-wait: {param.ds_summary()}") + if param in self.__inflight_param_registry: + wait_numel += param.partition_numel() + with get_accelerator().stream(self.__allgather_stream): + while self.__ongoing_fetch_events and self.__ongoing_fetch_events[0].query(): + self.__ongoing_fetch_events.popleft() + if len(self.__ongoing_fetch_events) > self.__max_ongoing_fetch_events: + self.__ongoing_fetch_events.popleft().synchronize() + + self.__inflight_param_registry.pop(param).wait() + + if not get_accelerator().handles_memory_backpressure(): + event = get_accelerator().Event() + event.record() + self.__ongoing_fetch_events.append(event) + + assert param.ds_status == ZeroParamStatus.AVAILABLE, param.ds_summary() + if not get_accelerator().resolves_data_dependency(): + get_accelerator().current_stream().wait_stream(self.__allgather_stream) + self.__profiler.stop_event(wait_event_name, wait_numel) + + # kick off parameter prefetches for upcoming modules + # don't prefetch if we dont have a completed model trace + if self.is_complete_trace(): + # go through the parameters we need for the current module and pop them + # off the fetch queue so that they aren't prefetched later. + # if params have already been popped off the fetch queue by earlier + # prefetches we won't look for them here + discarded_from_prefetch_queue = set() + params_not_already_fetched = set( + filter(lambda p: self.__most_recent_step_id_param_fetched_for[p] < self.__step_id, params_to_fetch)) + while self.__param_queue and len(discarded_from_prefetch_queue) < len(params_not_already_fetched): + param_in_trace = self.__param_queue.popleft() + self.__most_recent_step_id_param_fetched_for[ + param_in_trace.param] = param_in_trace.step_id_last_used_at + discarded_from_prefetch_queue.add(param_in_trace.param) + + if discarded_from_prefetch_queue != params_not_already_fetched: + raise RuntimeError( + f"tracing error at step {self.__step_id}: \n" + f"module id: {current_submodule.id}, training: {current_submodule.training}\n" + f"expected the next {len(params_not_already_fetched)} parameters in the " + f"parameter fetch queue to be {tuple(p.ds_summary(use_debug_name=True) for p in params_not_already_fetched)} \n" + f"but got \n {tuple(p.ds_summary(use_debug_name=True) for p in discarded_from_prefetch_queue)}.") + + def _is_currently_on_nvme(param): + if param.nvme_swapper is None: + return False + + return param.ds_tensor.final_location == OffloadDeviceEnum.nvme \ + and param.ds_tensor.status == PartitionedParamStatus.NOT_AVAILABLE + + # kick off all gather for params in the next few submodules (prefetch) + if self.__prefetch_bucket_sz > 0: + max_params_to_prefetch = min(self.__max_n_available_params - self.__n_available_params, + self.__prefetch_bucket_sz) + params_to_prefetch = set() + numel_prefetching = 0 + while self.__param_queue and numel_prefetching < max_params_to_prefetch: + param_in_trace: __class__.__ParamInTrace = self.__param_queue.popleft() + + if _is_currently_on_nvme(param_in_trace.param): + # nvme prefetch is handled elsewhere. Need to break here to preserve fetch order + self.__param_queue.appendleft(param_in_trace) + break + + do_prefetch = param_in_trace.param.ds_status == ZeroParamStatus.NOT_AVAILABLE + if param_in_trace.param in params_to_prefetch: + # Avoid duplicates + do_prefetch = False + + self.__most_recent_step_id_param_fetched_for[param_in_trace.param] = \ + max(self.__most_recent_step_id_param_fetched_for[param_in_trace.param], + param_in_trace.step_id_last_used_at) + + if do_prefetch: + params_to_prefetch.add(param_in_trace.param) + numel_prefetching += param_in_trace.param.ds_numel + + if numel_prefetching > 0: + event_name = __class__.FORWARD_PREFETCH_SUBMIT if forward else __class__.BACKWARD_PREFETCH_SUBMIT + self.__profiler.start_event(event_name) + if logger.isEnabledFor(logging.DEBUG): + for param in params_to_prefetch: + debug_rank0(f"-prefetch: {param.ds_summary()}") + self.__all_gather_params(params_to_prefetch, forward) + self.__profiler.stop_event(event_name, numel_prefetching) + + if self.__prefetch_nvme: + self.__prefetch_nvme_param_partitions() + + self.__step_id += 1 + + @instrument_w_nvtx + @torch.no_grad() + def release_sub_module(self, submodule: Module, forward=True) -> None: + """release the parameters of a sub module, assuming they meet conditions to + be released.""" + if self.opt_modules: + is_opt_submodule_name = (submodule._get_name().upper() in self.opt_modules) + # opt参数常驻 + limit_flag = (self.__step_id > int(self.limit_module_id[0]) and self.__step_id < int(self.limit_module_id[1])) + if not self.stop_current_step_opt and is_opt_submodule_name and (limit_flag or hasattr(submodule, "HOLD_COMPLETE")): + if not hasattr(submodule, "HOLD_COMPLETE"): + submodule.__setattr__("HOLD_COMPLETE", True) + self.record_opt_modules.append(submodule) + params = set(p for p in iter_params(submodule, recurse=z3_leaf_module(submodule))) + for p in params: + if p.ds_status != ZeroParamStatus.HOLD_COMPLETE: + assert p.ds_status == ZeroParamStatus.AVAILABLE, p.ds_summary() + p.ds_status = ZeroParamStatus.HOLD_COMPLETE + return + + params_to_release = (self.__params_to_release(submodule, self.__step_id) if self.is_complete_trace() else set( + p.ds_id for p in iter_params(submodule, recurse=z3_leaf_module(submodule)))) + for param in iter_params(submodule, recurse=z3_leaf_module(submodule)): + param.ds_active_sub_modules.discard(submodule.id) + if param.ds_id in params_to_release and not param.is_external_param: + self.__release_param(param) + + @instrument_w_nvtx + @torch.no_grad() + def release_and_reset_all(self, module: Module) -> None: + """release all module parameters""" + for param in iter_params(module, recurse=True): + if self.opt_modules and param.ds_status == ZeroParamStatus.HOLD_COMPLETE: + param.ds_status = ZeroParamStatus.AVAILABLE + if param in self.__inflight_param_registry: + self.__inflight_param_registry.pop(param).wait() + + # TODO. make this throw if if there are still active submodules. currently + # there's a hook execution issue + param.ds_active_sub_modules.clear() + self.__release_param(param) + + for param in iter_params(module, recurse=True): + if self.opt_modules and param.ds_status == ZeroParamStatus.HOLD_COMPLETE: + param.ds_status = ZeroParamStatus.AVAILABLE + if param.ds_status != ZeroParamStatus.NOT_AVAILABLE: + raise RuntimeError(f"{param.ds_summary()} expected to be released") + + @instrument_w_nvtx + def __all_gather_params(self, params: Set[Parameter], forward: bool) -> None: + quantized_params = [] + nonquantized_params = [] + for param in params: + if hasattr(param.ds_tensor, 'ds_quant_scale'): + quantized_params.append(param) + else: + nonquantized_params.append(param) + if quantized_params: + self.__all_gather_params_(quantized_params, forward, quantize=True) + if nonquantized_params: + self.__all_gather_params_(nonquantized_params, forward, quantize=self.zero_quantized_weights) + + def __all_gather_params_(self, params: Set[Parameter], forward: bool, quantize: bool = False) -> None: + """for each partitioned parameter, kick off an async allgather and store + the work handle for the in flight parameters.""" + partitioned_params = [] + all_gather_numel = 0 # numel = num of elements + for param in params: + if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: + partitioned_params.append(param) + all_gather_numel += param.ds_numel + + if partitioned_params: + self.__n_available_params += all_gather_numel + # here we need to handle a special case where some of the parameters have a valid hpz secondary tensor (e.g. they are not trainable so their secondary tensor never expire) but others do not. + partitioned_params_with_secondary_tensors = [ + p for p in partitioned_params if p.ds_secondary_tensor is not None + ] + partitioned_params_without_secondary_tensors = [ + p for p in partitioned_params if p.ds_secondary_tensor is None + ] + for param_group in [ + partitioned_params_with_secondary_tensors, partitioned_params_without_secondary_tensors + ]: + if not param_group: + continue + with get_accelerator().stream(self.__allgather_stream): + event_name = __class__.FORWARD_ALL_GATHER if forward else __class__.BACKWARD_ALL_GATHER + self.__profiler.start_event(event_name) + handle = param_group[0].all_gather_coalesced(param_group, quantize=quantize) + self.__profiler.stop_event(event_name, all_gather_numel) + for param in param_group: + assert param.ds_status == ZeroParamStatus.INFLIGHT, param.ds_summary() + self.__inflight_param_registry[param] = handle + + # Release swap buffers for persisted params on nvme since they will never be partitioned or evicted from GPU + swap_persisted_params = [ + p for p in partitioned_params if p.ds_persist and p.ds_tensor.final_location == OffloadDeviceEnum.nvme + ] + if swap_persisted_params: + swap_persisted_params[0].nvme_swapper.remove_partition_and_release_buffers(swap_persisted_params) + + @compiler.disable + @instrument_w_nvtx + def __release_param(self, param: Parameter) -> None: + if param.ds_status == ZeroParamStatus.AVAILABLE and not param.ds_active_sub_modules: + if logger.isEnabledFor(logging.DEBUG): + debug_rank0(f"-release: {param.ds_summary()}") + param.partition() + self.__n_available_params -= param.ds_numel + + @instrument_w_nvtx + @functools.lru_cache(maxsize=None) + def __params_to_release(self, submodule_to_release: Module, step_id: int) -> Set[int]: + if not self.is_complete_trace(): + raise RuntimeError("expected trace to be complete") + + params_to_release = set( + p.ds_id for p in iter_params(submodule_to_release, recurse=z3_leaf_module(submodule_to_release)) + if not p.ds_persist) + + # Problem: When prefetcher scans the param trace, it skips AVAILABLE params. + # This creates issues if those params are released before the skipped uses: + # 1) It hurts performance as the skipped uses are never prefetched. + # 2) For nvme params, we run out of swap buffers because the prefetch order + # diverges from the trace. + # Solution: Don't release params whose reuse was skipped by prefetch. This is + # possible because we detect such skips during prefetch and mark those params. + for param in iter_params(submodule_to_release, recurse=z3_leaf_module(submodule_to_release)): + if self.__most_recent_step_id_param_fetched_for[param] > step_id: + params_to_release.discard(param.ds_id) + + # examine all modules within `max_reuse_dist_in_numel` of the current step, + # if we see any of the candidate parameters to be released reoccur while + # doing this, remove them from the set of parameters to release. + params_traversed = 0 + for module in self.__submodule_order[step_id:]: + if params_traversed >= self.__max_reuse_dist_in_numel: + break + for param in iter_params(module, recurse=z3_leaf_module(submodule_to_release)): + params_to_release.discard(param.ds_id) + params_traversed += param.ds_numel + + return params_to_release + + @instrument_w_nvtx + def __prefetch_nvme_param_partitions(self) -> None: + """swap in parameter partitions from nvme for those parameters that will be used + after the ones that are already being prefetched into full parameters + """ + if not self.is_complete_trace(): + return + + numel_in_flight = sum(param.ds_numel for param in self.__inflight_param_registry) + + numel_considered = 0 + swap_in_params = [] + for param_in_trace in self.__param_queue: + param = param_in_trace.param + if param.nvme_swapper is None: + continue + if (numel_considered > 2 * numel_in_flight + or len(swap_in_params) >= param.nvme_swapper.available_swap_in_buffers()): + break + if param.ds_tensor.status == PartitionedParamStatus.NOT_AVAILABLE: + swap_in_params.append(param) + numel_considered += param.ds_numel + + if swap_in_params: + swap_in_params[0].nvme_swapper.swap_in(swap_in_params, async_op=True) diff --git a/toolbox/DeepSpeed/v0.15.3/patches/deepspeed/runtime/zero/stage3.py b/toolbox/DeepSpeed/v0.15.3/patches/deepspeed/runtime/zero/stage3.py new file mode 100644 index 0000000000000000000000000000000000000000..a801d6f9856b00873c14f998df85ca7445143991 --- /dev/null +++ b/toolbox/DeepSpeed/v0.15.3/patches/deepspeed/runtime/zero/stage3.py @@ -0,0 +1,3126 @@ +#!/usr/bin/env python3 +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import sys +import gc +import collections +import itertools +from typing import Deque, Dict, Set, Tuple, Container +from contextlib import contextmanager + +from deepspeed import comm as dist +from deepspeed.utils import groups, z3_leaf_parameter + +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors +from deepspeed.runtime.base_optimizer import ZeROOptimizer +from deepspeed.utils import logger +from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler +from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced, all_to_all_quant_reduce +from deepspeed.runtime.utils import inf, is_model_parallel_parameter, get_only_unique_item +from deepspeed.runtime.zero.partition_parameters import * +from deepspeed.runtime.zero.config import ZeroStageEnum +from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum, OffloadStateTypeEnum +from deepspeed.runtime.zero.parameter_offload import DeepSpeedZeRoOffload +from deepspeed.runtime.zero.utils import apply_to_tensors_only, get_mapping_to_flat_buffer +from deepspeed.runtime.zero.offload_states import offload_adam_states, reload_adam_states +from deepspeed.ops.adam import DeepSpeedCPUAdam +from deepspeed.runtime.swap_tensor.partitioned_param_swapper import PartitionedParamStatus +from deepspeed.runtime.swap_tensor.optimizer_utils import OptimizerSwapper +from deepspeed.runtime.swap_tensor.partitioned_optimizer_swapper import PartitionedOptimizerSwapper +from deepspeed.runtime.swap_tensor.pipelined_optimizer_swapper import PipelinedOptimizerSwapper +from deepspeed.checkpoint.constants import OPTIMIZER_STATE_DICT, FP32_FLAT_GROUPS, PARTITION_COUNT, ZERO_STAGE, LOSS_SCALER +from deepspeed.accelerator import get_accelerator + +# Toggle this to true to enable correctness test +# with gradient partitioning and without +pg_correctness_test = False + +OPTIMIZER_SWAP_IN_STATE_TIMER = 'optimizer_swap_in_state' +INIT_OPTIMIZER_TIMER = 'init_optimizer_state' +OPTIMIZER_SWAP_OUT_STATE_TIMER = 'optimizer_swap_out_state' +OPTIMIZER_STEP_TIMER = 'optimizer_step' + + +def print_rank_0(message, debug=False, force=False): + rank = dist.get_rank() + if rank == 0 and (debug or force): + logger.info(message) + # other variations + # - print for all ranks w/o interleaving + # printflock(f"[{rank}] {message}") + # - print to log file per rank + # log_rank_file(rank, message) + + +def input(msg): + return + + +def isclose(a, b, rtol=1e-09, atol=0.0): + return abs(a - b) <= max(rtol * max(abs(a), abs(b)), atol) + + +def lcm(x, y): + from fractions import gcd # or can import gcd from `math` in Python 3 + return x * y // gcd(x, y) + + +def move_to_cpu(tensor_list): + for tensor in tensor_list: + tensor.data = tensor.data.cpu() + + +@contextmanager +def unwrap_model_for_generation(model): + """ + For ZeRO-3 models, we gather the weights once to speed up generation. + """ + with GatheredParameters(model.parameters()): + # Removes the optimizer hooks from a DeepSpeed ZeRO-3 model. + + # Remove hooks + if model.optimizer is not None and hasattr(model.optimizer, "parameter_offload"): + optimizer_offload = model.optimizer.parameter_offload + elif model.optimizer is not None: + optimizer_offload = model.optimizer + + for hook in optimizer_offload.forward_hooks: + hook.remove() + for hook in optimizer_offload.backward_hooks: + hook.remove() + + optimizer_offload.forward_hooks = [] + optimizer_offload.backward_hooks = [] + + yield model + + # Adds the optimizer hooks from a DeepSpeed ZeRO-3 model. + if model.optimizer is not None and hasattr(model.optimizer, "parameter_offload"): + optimizer_offload = model.optimizer.parameter_offload + elif model.optimizer is not None: + optimizer_offload = model.optimizer + optimizer_offload._register_hooks_recursively(optimizer_offload.module) + return + + +INITIAL_MICRO_STEP_ID = -1 + + +class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer): + """ + DeepSpeedZeroOptimizer designed to reduce the memory footprint + required for training large deep learning models. + + For more details please see ZeRO: Memory Optimization Towards Training A Trillion Parameter Models + https://arxiv.org/abs/1910.02054 + + For usage examples, refer to TODO: DeepSpeed Tutorial + + """ + + def __init__( + self, + module, + init_optimizer, + timers, + ds_config, + static_loss_scale=1.0, + dynamic_loss_scale=False, + dynamic_loss_args=None, + verbose=True, + contiguous_gradients=True, + reduce_bucket_size=500000000, + prefetch_bucket_size=50000000, + max_reuse_distance=1000000000, + max_live_parameters=1000000000, + param_persistence_threshold=100000, + model_persistence_threshold=sys.maxsize, + dp_process_group=None, + reduce_scatter=True, + overlap_comm=False, + offload_optimizer_config=None, + offload_param_config=None, + sub_group_size=1000000000000, + offload_ratio=0.0, + mpu=None, + clip_grad=0.0, + gradient_accumulation_dtype=torch.float32, + communication_data_type=torch.float16, + postscale_gradients=True, + gradient_predivide_factor=1.0, + gradient_accumulation_steps=1, + elastic_checkpoint=False, + aio_config=None, + all2all_process_group=None, + zero_hpz_partition_size=1, + zero_quantized_weights=False, + zero_quantized_nontrainable_weights=False, + ): + see_memory_usage("Stage 3 initialize beginning", force=True) + + print_rank_0(f"initialized {__class__.__name__} with args: {locals()}", force=False) + + if dist.get_rank() == 0: + logger.info(f"Reduce bucket size {reduce_bucket_size}") + logger.info(f"Prefetch bucket size {prefetch_bucket_size}") + # The fused optimizer does all the work. We need this layer for two reason: + # 1. maintain same user API from apex.fp16_utils + # 2. keep common stuff here in case we need to add ne552w fused optimizer later + + # differences from apex.fp16_utils: + # - assume all model params in fp16 + # - assume all params requires grad + # - flat by groups, not keeping state. TODO: remove state explicitly? + # - master grad and unflat master weight never exist. TODO: a way to save out unflat master? + if not get_accelerator().is_available(): + raise SystemError("Cannot use fp16 without accelerator.") + + self.optimizer = init_optimizer + + # Use torch (un)flatten ops + self.flatten = _flatten_dense_tensors + self.unflatten = _unflatten_dense_tensors + self.dtype = self.optimizer.param_groups[0]['params'][0].dtype + self.gradient_accumulation_dtype = gradient_accumulation_dtype + self._global_grad_norm = 0. + + self.custom_loss_scaler = False + self.external_loss_scale = None + + self.optimizer_swapper = None + self.swap_optimizer = False + + self.offload_optimizer = False + self.offload_optimizer_pin_memory = False + self.offload_optimizer_fast_init = False + self.offload_param = False + self.offload_param_pin_memory = False + self.params_in_nvme_and_cpu = False + self.max_params_in_cpu = 0 + self.partial_offload = offload_ratio + + #num of ranks in a ZeRO param partitioning group + self.zero_hpz_partition_size = zero_hpz_partition_size + + zero_param_parallel_group = groups._get_zero_param_intra_parallel_group() + print_rank_0( + f"ZeRO Stage 3 param partitioning group {self.zero_hpz_partition_size} {zero_param_parallel_group}", + force=False) + if self.zero_hpz_partition_size > 1 and zero_param_parallel_group is None: + self._set_zero_group_parallelism() + zero_param_parallel_group = groups._get_zero_param_intra_parallel_group() + + self.parameter_offload = self.initialize_ds_offload( + module=module, + timers=timers, + ds_config=ds_config, + overlap_comm=overlap_comm, + prefetch_bucket_size=prefetch_bucket_size, + max_reuse_distance=max_reuse_distance, + max_live_parameters=max_live_parameters, + param_persistence_threshold=param_persistence_threshold, + model_persistence_threshold=model_persistence_threshold, + dp_process_group=dp_process_group, + offload_param_config=offload_param_config, + mpu=mpu, + zero_param_parallel_group=zero_param_parallel_group, + zero_quantized_weights=zero_quantized_weights, + zero_quantized_nontrainable_weights=zero_quantized_nontrainable_weights) + + self.persistent_parameters = self.parameter_offload.persistent_parameters + self._configure_offloading(offload_optimizer_config, offload_param_config) + + # backup fused_adam optimizer init + if self.offload_optimizer and self.partial_offload != 1.0: + backup_gpu_tensor = torch.randn(1, device=get_accelerator().device_name()).to(self.dtype) + backup_gpu_param = torch.nn.Parameter(backup_gpu_tensor) + assert type(init_optimizer) == DeepSpeedCPUAdam, 'Hybrid Optimizer Only Supports DeepSpeedCPUAdam' + self.backup_optimizer = torch.optim.AdamW([backup_gpu_param], + lr=self.optimizer.param_groups[0]["lr"], + betas=self.optimizer.param_groups[0]["betas"], + eps=self.optimizer.param_groups[0]["eps"], + weight_decay=self.optimizer.param_groups[0]["weight_decay"], + amsgrad=self.optimizer.param_groups[0]["amsgrad"]) + # Multiple param_groups configs for back-up optimizer + if len(self.optimizer.param_groups) > 1: + for i in range(1, len(self.optimizer.param_groups)): + self.backup_optimizer.add_param_group(self.optimizer.param_groups[i]) + + self.module = module + self.elastic_checkpoint = elastic_checkpoint + + self.device = get_accelerator().current_device_name() if not self.offload_optimizer else OffloadDeviceEnum.cpu + + self.inf_or_nan_tracker: Tensor = torch.zeros(1, dtype=torch.bool, device=self.device, requires_grad=False) + + self.deepspeed_adam_offload = (self.offload_optimizer and type(init_optimizer) == DeepSpeedCPUAdam) + + ### streams used for overlapping computation with communication + self.reduce_and_partition_stream = None if get_accelerator().is_synchronized_device() else get_accelerator( + ).Stream() if overlap_comm else get_accelerator().default_stream() + + ############################################################################ + + self.n_caching_allocator_flushes = 0 + + #-------------Stage 3 Setup-------------------# + + self.timers = timers + + self.all2all_process_group = all2all_process_group + + self.reduce_scatter = reduce_scatter + + self.dp_process_group = self.parameter_offload.dp_process_group + self.sequence_parallel_size = groups._get_sequence_parallel_world_size() + + self.all2all_process_group = all2all_process_group + + self.zero_quantized_nontrainable_weights = zero_quantized_nontrainable_weights + + self.partition_count = dist.get_world_size(group=self.dp_process_group) + + if mpu is None: + self.model_parallel_group = None + self.model_parallel_rank = 0 + else: + self.model_parallel_group = mpu.get_model_parallel_group() + self.model_parallel_rank = mpu.get_model_parallel_rank() + + self.overflow = False + self.clip_grad = clip_grad + self.communication_data_type = communication_data_type + self.gradient_predivide_factor = gradient_predivide_factor + self.postscale_gradients = postscale_gradients + self.gradient_accumulation_steps = gradient_accumulation_steps + self.micro_step_id = 0 + self.reduce_bucket_size = int(reduce_bucket_size) + + if self.all2all_process_group is not None: + assert self.all2all_process_group is not None and self.reduce_scatter == True, "when enable all_to_all_reduce, reduce_scatter should also be enabled for data type checks." + + if self.reduce_scatter: + valid_reduce_scatter_dtypes = (torch.float16, torch.bfloat16, torch.float32) + assert self.communication_data_type in valid_reduce_scatter_dtypes, f"ZeRO-3 supports {valid_reduce_scatter_dtypes} communication_data_type with reduce scatter enabled. Got: '{self.communication_data_type}'" + assert self.gradient_predivide_factor == 1.0, "gradient_predivide_factor != 1.0 is not yet supported with ZeRO-3 with reduce scatter enabled" + assert self.postscale_gradients, "pre-scale gradients is not yet supported with ZeRO-3 with reduce scatter enabled" + + # Holds the mode parameter + # The param.data may not hold any meaningful data + # when param's status is NOT_AVAILABLE or IN_FLGHT + self.fp16_groups = [] + + # Hold partitioned parameters + self.fp16_partitioned_groups = [] + + # Holds a fused and flattened copy of the parameters + self.fp16_partitioned_groups_flat = [] + self.fp16_partitioned_groups_flat_numel = [] + self.fp16_partitioned_groups_flat_id = [] + + #defragmented pinned memory + self.param_groups_fp16_flat_cpu_memory = [] + + #a single 32-bit partition of the parallel partitioned parameters + #that this process will update + self.fp32_partitioned_groups_flat = [] + self.next_swappable_fp32_partitioned_groups = [] + + # number of elements per partition in each group + self.partition_size = [] + + self.all_reduce_print = False + + self.prefetch_elements = int(prefetch_bucket_size) + + self.contiguous_gradients = contiguous_gradients + + # padding on each partition for alignment purposes + self.groups_padding = [] + + self.sub_group_size = sub_group_size + + self.sub_group_to_group_id = {} + + # Trainable parameters + self.trainable_param_groups = self._get_trainable_parameter_groups() + + see_memory_usage("Before creating fp16 partitions", force=True) + self._create_fp16_partitions_with_defragmentation(self.trainable_param_groups) + num_fp16_subgroups = len(self.fp16_partitioned_groups_flat) + see_memory_usage(f"After creating fp16 partitions: {num_fp16_subgroups}", force=True) + + # Optimizer tensor swapping + if self.swap_optimizer: + self._configure_tensor_swapping(offload_optimizer_config, aio_config) + + self.is_gradient_accumulation_boundary: bool = True + + self.param_reduce_events: Deque[get_accelerator().Event] = collections.deque() + # TODO. make this configurable via JSON + self.max_param_reduce_events: int = 2 + + self.param_dict = {} + + # map between param_id and bool to specify if a param is in this partition + self.is_param_in_current_partition = {} + + self.extra_large_param_to_reduce = None + self.grads_in_ipg_bucket = [] + self.params_in_ipg_bucket = [] + + self.params_already_reduced = {} + self._release_ipg_buffers() + self.previous_reduced_grads = None + + # model parameter traversal-based param id that's stable across runs + for params_group in self.fp16_groups: + for param in params_group: + param_id = self.get_param_id(param) + self.param_dict[param_id] = param + self.params_already_reduced[param_id] = False + + #Largest partitioned param + largest_partitioned_param_numel = 0 + for fp16_partitioned_group in self.fp16_partitioned_groups: + if len(fp16_partitioned_group) > 0: + largest_partitioned_param_numel = max( + largest_partitioned_param_numel, + max([max(tensor.numel(), tensor.ds_numel) for tensor in fp16_partitioned_group])) + + print_rank_0(f'Largest partitioned param numel = {largest_partitioned_param_numel}', force=False) + + self._setup_for_real_optimizer() + self.grad_position = {} + self.set_grad_positions() + + if self.offload_optimizer: + self.norm_for_param_grads = {} + + # stores if a partition has been reduced in this step + self.is_partition_reduced = {} + + # stores if a grad in a partition has been computed or not + self.is_grad_computed = {} + + # will store the averaged gradients required by this partition + self.averaged_gradients = {} + + #creates backward hooks for gradient partitioning + ###Calls all gather param + self._grad_acc_hooks = [] + self._leaf_module_hooks = [] + self.create_reduce_and_remove_grad_hooks() + + #exit(0) + + # we may have a way of fusing dynamic scale. Do not support for now + self.loss_scaler = CreateLossScaler(dtype=self.dtype, + static_loss_scale=static_loss_scale, + dynamic_scaling=dynamic_loss_scale, + dynamic_loss_args=dynamic_loss_args) + self.dynamic_loss_scale = self.loss_scaler.dynamic + + self.debug_fp16_grads = [{} for _ in self.fp16_groups] + + self._link_all_hp_params() + + self.offloaded_states: Set(OffloadDeviceEnum) = set() + + if dist.get_rank(group=self.dp_process_group) == 0: + see_memory_usage(f"After initializing ZeRO optimizer", force=True) + + def destroy(self): + self.parameter_offload.destroy() + for hook in self._grad_acc_hooks: + hook.remove() + for hook in self._leaf_module_hooks: + hook.remove() + print_rank_0("Removed grad acc hooks", force=False) + del self.__ipg_bucket_flat_buffer + + def initialize_ds_offload( + self, + module, + timers, + ds_config, + overlap_comm, + prefetch_bucket_size, + max_reuse_distance, + max_live_parameters, + param_persistence_threshold, + model_persistence_threshold, + dp_process_group, + offload_param_config, + mpu, + zero_param_parallel_group, + zero_quantized_weights, + zero_quantized_nontrainable_weights, + ): + return DeepSpeedZeRoOffload(module=module, + timers=timers, + ds_config=ds_config, + overlap_comm=overlap_comm, + prefetch_bucket_size=prefetch_bucket_size, + max_reuse_distance=max_reuse_distance, + max_live_parameters=max_live_parameters, + param_persistence_threshold=param_persistence_threshold, + model_persistence_threshold=model_persistence_threshold, + dp_process_group=dp_process_group, + offload_param_config=offload_param_config, + mpu=mpu, + zero_param_parallel_group=zero_param_parallel_group, + zero_quantized_weights=zero_quantized_weights, + zero_quantized_nontrainable_weights=zero_quantized_nontrainable_weights) + + def _get_trainable_parameter_groups(self): + param_groups = [] + PARAMS_KEY = "params" + for param_group in self.optimizer.param_groups: + trainable_params = [p for p in param_group[PARAMS_KEY] if p.requires_grad] + if len(trainable_params) == 0: + continue + + trainable_param_group = {} + for key in param_group.keys(): + if key == PARAMS_KEY: + trainable_param_group[PARAMS_KEY] = trainable_params + else: + trainable_param_group[key] = param_group[key] + param_groups.append(trainable_param_group) + + return param_groups + + def _set_zero_group_parallelism(self): + groups._create_zero_param_parallel_group(self.zero_hpz_partition_size) + + def invalidate_secondary_tensor(self): + for fpg in self.fp16_groups: + for param in fpg: + if param.ds_secondary_tensor is not None: + param.ds_secondary_tensor = None + + def _setup_for_real_optimizer(self): + see_memory_usage("Before creating fp32 partitions", force=True) + self._create_fp32_partitions() + see_memory_usage("After creating fp32 partitions", force=True) + dist.barrier() + + # To support pipelined optimizer swapping + self._create_next_swappable_fp32_groups() + + see_memory_usage("Before initializing optimizer states", force=True) + + self.initialize_optimizer_states() + see_memory_usage("After initializing optimizer states", force=True) + dist.barrier() + + if dist.get_rank() == 0: + logger.info(f"optimizer state initialized") + + # IPG + if self.contiguous_gradients: + self.__ipg_bucket_flat_buffer: Tensor = torch.empty(self.reduce_bucket_size, + dtype=self.dtype, + device=get_accelerator().current_device_name()) + + self.grad_partitions_flat_buffer = None + self.__param_id_to_grad_partition: Dict[int, Tensor] = {} + + all_params = list(itertools.chain.from_iterable(self.fp16_groups)) + + self.grad_partitions_flat_buffer: Tensor = torch.zeros(sum(p.partition_numel() for p in all_params), + dtype=self.gradient_accumulation_dtype, + device=self.device) + if self.offload_optimizer_pin_memory: + self.grad_partitions_flat_buffer = get_accelerator().pin_memory(self.grad_partitions_flat_buffer) + + offset = 0 + for param in all_params: + self.__param_id_to_grad_partition[param.ds_id] = self.grad_partitions_flat_buffer.narrow( + 0, offset, param.partition_numel()) + offset += param.partition_numel() + + def _link_all_hp_params(self): + for p in self.module.parameters(): + p._z3_optimizer = self + + def set_lr(self, lr): + """Set the learning rate.""" + for param_group in self.optimizer.param_groups: + param_group["lr"] = lr + + def get_lr(self): + """Return the current learning rate.""" + return self.optimizer.param_groups[0]["lr"] + + # TODO. factor out to a utility outside of stage3 + @staticmethod + def defragment(tensors: List[Tensor]) -> Tensor: + """move provided tensors into a contiguous flat buffer, with some additional + measures taken to reduce memory fragmentation""" + assert len(set(t.dtype for t in tensors)) == 1 + assert len(set(t.device for t in tensors)) == 1 + + cpu_buffer = torch.empty(sum(p.numel() for p in tensors), + dtype=get_only_unique_item(t.dtype for t in tensors), + device="cpu") + tensor_infos: List[Tuple[Tensor, int, int]] = get_mapping_to_flat_buffer(tensors) + orig_device = get_only_unique_item(t.device for t in tensors) + + offset = 0 + for tensor, offset, tensor_numel in tensor_infos: + # move the tensor from device memory to host memory + cpu_buffer.narrow(0, offset, tensor_numel).copy_(tensor) + tensor.data = torch.empty(0, dtype=tensor.dtype, device=tensor.device) + + gc.collect() + get_accelerator().empty_cache() + + # copy tensors (now flattened and contiguous) back to GPU + device_buffer = cpu_buffer.to(orig_device) + + # restore device tensors + for tensor, offset, tensor_numel in tensor_infos: + tensor.data = device_buffer.narrow(0, offset, tensor_numel) + + return device_buffer + + def _get_param_coordinator(self, training): + return self.parameter_offload.get_param_coordinator(training) + + def _configure_offloading(self, offload_optimizer_config, offload_param_config): + ###################### offload optimizer setup ################################## + if offload_optimizer_config is not None and offload_optimizer_config.device != OffloadDeviceEnum.none: + self.offload_optimizer = True + self.offload_optimizer_pin_memory = offload_optimizer_config.pin_memory + self.swap_optimizer = offload_optimizer_config.device == OffloadDeviceEnum.nvme + self.offload_optimizer_fast_init = offload_optimizer_config.fast_init + + ###################### offload param setup ################################## + if offload_param_config is not None and offload_param_config.device != OffloadDeviceEnum.none: + self.offload_param = True + self.offload_param_pin_memory = offload_param_config.pin_memory + self.params_in_nvme_and_cpu = offload_param_config.device == OffloadDeviceEnum.nvme + self.max_params_in_cpu = offload_param_config.max_in_cpu + print_rank_0( + f"FP16 params swapping is {self.params_in_nvme_and_cpu}, Max params in CPU is {self.max_params_in_cpu}", + force=False) + + def _configure_tensor_swapping(self, offload_optimizer_config, aio_config): + nvme_swap_folder = os.path.join(offload_optimizer_config.nvme_path, 'zero_stage_3') + os.makedirs(nvme_swap_folder, exist_ok=True) + if dist.get_rank() == 0: + logger.info(f'Tensor Swapping: Adding optimizer tensors') + + swapper_type = PipelinedOptimizerSwapper if offload_optimizer_config.pipeline else PartitionedOptimizerSwapper + + self.optimizer_swapper = swapper_type(swap_config=offload_optimizer_config, + aio_config=aio_config, + base_folder=nvme_swap_folder, + optimizer=self.optimizer, + largest_numel=max(self.fp16_partitioned_groups_flat_numel), + device=self.device, + dtype=torch.float32, + timers=self.timers) + + @property + def elements_in_ipg_bucket(self): + return sum(p.ds_numel for p in self.params_in_ipg_bucket) + + def _move_to_flat_buffer(self, param_list, flat_buffer, avoid_copy=False): + '''If flat buffer is None then the parameters in the param_list are + not copied to the flat buffer. This is because they exceed the number of max_params_in_cpu + Some of these parameters may already be in CPU in unflattened buffers + or they maybe in GPU, or they maybe in NVME. If they are in NVME, then + they will be marked as NOT_AVAILABLE, and will be moved to CPU when they are + needed during training.''' + if flat_buffer is None: + # this dst buffer is on NVMe, so skip this + return + + start = 0 + for param in param_list: + src = param.ds_tensor + dest = flat_buffer.narrow(0, start, src.ds_numel) + start = start + src.ds_numel + '''if the parameter was initialized in nvme then bring it to the destination buffer directly''' + if src.status == PartitionedParamStatus.NOT_AVAILABLE: + print_rank_0( + f"Swapping in {param.ds_id} with partition size {param.partition_numel()} permanently to CPU") + param.nvme_swapper.swap_into_buffer(param, dest) + src.data = dest.data + src.status = PartitionedParamStatus.AVAILABLE + else: + assert src.status == PartitionedParamStatus.AVAILABLE, "Partitioned Param must be available here" + if not avoid_copy: + dest.data.copy_(src.data) + src.data = dest.data + + # Final location must be gpu/cpu in this case + param.ds_tensor.final_location = 'not-nvme' + + def _create_param_groups_fp16_flat_cpu_memory(self): + + aggregate_params_count = 0 + + for j, param_group in enumerate(self.trainable_param_groups): + params_in_group = sum([p.partition_numel() for p in param_group['params']]) + + flat_buffer_size = params_in_group + + if self.params_in_nvme_and_cpu and \ + aggregate_params_count + params_in_group > self.max_params_in_cpu: + + flat_buffer_size = max(0, self.max_params_in_cpu - aggregate_params_count) + + aggregate_params_count += params_in_group + + if flat_buffer_size > 0: + print_rank_0(f"group {j} flat buffer size {flat_buffer_size}", force=False) + self.param_groups_fp16_flat_cpu_memory.append(get_accelerator().pin_memory( + torch.empty(int(flat_buffer_size), dtype=self.dtype))) + else: + print_rank_0(f"No flat buffer size. Param group size was {params_in_group}", force=False) + + self.param_groups_fp16_flat_cpu_memory.append(torch.empty(1, dtype=self.dtype)) + + def _create_fp16_partitions_with_defragmentation(self, fp16_param_groups): + dist.barrier() + + param_groups: List[List[Parameter]] = tuple( + self._create_fp16_sub_groups(param_group["params"]) for param_group in fp16_param_groups) + + # bookkeeping related to param groups + for param_group_idx, param_group in enumerate(param_groups): + for sub_group in param_group: + sub_group_idx = len(self.fp16_groups) + + # record sub group and partitions + self.fp16_groups.append(sub_group) + self.fp16_partitioned_groups.append([param.ds_tensor for param in sub_group]) + + # record sub group -> group mapping + self.sub_group_to_group_id[sub_group_idx] = param_group_idx + + # record total elements of parameter partitions in sub group + self.fp16_partitioned_groups_flat_numel.append(sum(p.partition_numel() for p in sub_group)) + + # record ds_ids of parameter partitions in sub group + self.fp16_partitioned_groups_flat_id.append([p.ds_id for p in sub_group]) + + # record padding required to align group to world size (only applies to last rank) + rank_requires_padding = dist.get_rank( + self.dp_process_group) == dist.get_world_size(self.dp_process_group) - 1 + self.groups_padding.append([p.padding_size() if rank_requires_padding else 0 for p in sub_group]) + + # move parameters to flattened buffer + if not self.offload_param: # partitioned params remain in GPU during training + # move parameter partitions into a single contiguous flat buffer + parameter_partitions: List[Tensor] = [] + for sub_group in self.fp16_groups: + for param in sub_group: + parameter_partitions.append(param.ds_tensor) + + # We need to keep the reference to this buffer to make sure you can free it in `offload_states` + self.lp_param_buffer = __class__.defragment(parameter_partitions) + self._set_fp16_partitioned_groups_flat() + + else: # partitioned params offloaded to CPU when not in use + # create a flat CPU memory allocation for each param group + self._create_param_groups_fp16_flat_cpu_memory() + for param_group_idx, param_group in enumerate(param_groups): + flat_offset = 0 + for i, sub_group in enumerate(param_group): + total_elements = sum(p.partition_numel() for p in sub_group) + print_rank_0(f"Params in nvme and cpu {self.params_in_nvme_and_cpu}") + #Flat buffer may not be available for parameters that reside in NVME + if not self.params_in_nvme_and_cpu or flat_offset + total_elements <= self.param_groups_fp16_flat_cpu_memory[ + param_group_idx].numel(): + fp16_partitioned_group_flat = self.param_groups_fp16_flat_cpu_memory[param_group_idx].narrow( + 0, flat_offset, total_elements) + print_rank_0( + f"Creating a flat buffer for subgroup {i} requiring {total_elements} elements, and cumulative CPU elements {flat_offset + total_elements}", + force=False) + + elif self.params_in_nvme_and_cpu: + fp16_partitioned_group_flat = None + print_rank_0(f"No flat buffer for sub group {i} of {total_elements} elements", force=False) + else: + assert False, "Either params are in nvme, or they are in CPU memory. This code path should not be triggered. Please see you max_params_in_cpu and params_in_nvme configs" + + self.fp16_partitioned_groups_flat.append(fp16_partitioned_group_flat) + flat_offset += total_elements + + self._move_to_flat_buffer(sub_group, + fp16_partitioned_group_flat, + avoid_copy=not self.offload_param) + + # if necessary, create a pinned memory buffer to be used for swapping out + # params to NVME after optimizer step + should_create_fp16_flat_reuse_buffer = any(flattened_partition_group is None + for flattened_partition_group in self.fp16_partitioned_groups_flat) + if should_create_fp16_flat_reuse_buffer: + max_partition_numel, largest_partition_numel = 0, None + for sub_group in self.fp16_groups: + total_elements = sum(t.partition_numel() for t in sub_group) + if total_elements > max_partition_numel: + largest_partition_numel = [t.ds_numel for t in sub_group] + max_partition_numel = total_elements + + assert len(largest_partition_numel) > 0, f'Unexpected that largest partition is empty' + self.fp16_groups[0][0].nvme_swapper.reserve_partitioned_swap_space(largest_partition_numel) + + def _swap_in_sub_group_to_flat_buffer(self, flat_buffer, sub_group_id): + offset = 0 + elements_in_sub_group = sum([t.ds_numel for t in self.fp16_partitioned_groups[sub_group_id]]) + assert (flat_buffer.numel() == elements_in_sub_group) + for param, partitioned_param in zip(self.fp16_groups[sub_group_id], + self.fp16_partitioned_groups[sub_group_id]): + dest = flat_buffer.narrow(0, offset, partitioned_param.ds_numel) + if partitioned_param.status == PartitionedParamStatus.NOT_AVAILABLE: + print_rank_0( + f"Swapping in {param.ds_id} with elements {param.ds_numel} and partition {param.partition_numel()}" + ) + param.nvme_swapper.swap_in([param], async_op=False) + dest.data.copy_(partitioned_param.data) + param.nvme_swapper.remove_partition_and_release_buffers([param]) + print_rank_0(f"Swapping in {param.ds_id} done") + else: + dest.data.copy_(partitioned_param.data) + offset += partitioned_param.ds_numel + + def _create_next_swappable_fp32_groups(self): + reverse_order_indices = [i for i in range(len(self.fp32_partitioned_groups_flat))] + reverse_order_indices.reverse() + + next_group = None + for i in reverse_order_indices: + self.next_swappable_fp32_partitioned_groups.append(next_group) + if self._swappable_optimizer_subgroup(i): + next_group = self.fp32_partitioned_groups_flat[i] + + self.next_swappable_fp32_partitioned_groups.reverse() + + def _get_sub_group_partitions(self, sub_group_id): + sub_group_partitions = [] + for param, partitioned_param in zip(self.fp16_groups[sub_group_id], + self.fp16_partitioned_groups[sub_group_id]): + if partitioned_param.status == PartitionedParamStatus.NOT_AVAILABLE: + swap_path = param.nvme_swapper.get_path(param, True) + sub_group_partitions.append((partitioned_param, param.partition_numel(), swap_path)) + else: + sub_group_partitions.append((partitioned_param, partitioned_param.ds_numel, None)) + + return sub_group_partitions + + def _create_fp32_partitions(self): + cpu_memory_usage = 0 + cpu_memory_sub_groups = 0 + nvme_memory_usage = 0 + num_swappable_partitions = 0 + num_swap_from_nvme_partitions = 0 + num_swap_from_cpu_partitions = 0 + swap_from_nvme_memory_usage = 0 + swap_from_cpu_memory_usage = 0 + GIGA_BYTES = (1024**3) + + swappable_fp32_tensors = [] + swappable_fp16_src_tensors = [] + nvme_fp16_partitions_info = [] + nvme_fp16_num_elems = [] + nvme_fp32_dest_tensors = [] + fp32_element_size = torch.tensor([], dtype=torch.float32).element_size() + + # Assign portion of subgroup to cpu, the other to gpu. + if self.offload_optimizer: + self.subgroup_to_device = {} + sub_group_size = len(self.fp16_partitioned_groups_flat) + # print(f"Partial offload sub_group_size is {sub_group_size}, ratio is {self.partial_offload}\n") + for i in range(sub_group_size): + if i < int(self.partial_offload * sub_group_size): + self.subgroup_to_device[i] = 'cpu' + else: + self.subgroup_to_device[i] = get_accelerator()._name + + for i, tensor in enumerate(self.fp16_partitioned_groups_flat): + num_elements = self.fp16_partitioned_groups_flat_numel[i] + ds_id_begin = str(self.fp16_partitioned_groups_flat_id[i][0]) + ds_id_end = str(self.fp16_partitioned_groups_flat_id[i][-1]) + ds_id = ds_id_begin + '_' + ds_id_end + + # a partition of the fp32 master weights that will be updated by this process + if self._swappable_optimizer_subgroup(i): + self.fp32_partitioned_groups_flat.append(torch.Tensor()) + self.fp32_partitioned_groups_flat[i].ds_id = ds_id + nvme_memory_usage += (fp32_element_size * num_elements) + num_swappable_partitions += 1 + + if self.params_in_nvme_and_cpu and tensor is None: + num_swap_from_nvme_partitions += 1 + swap_from_nvme_memory_usage += (fp32_element_size * num_elements) + if self.offload_optimizer_fast_init: + sub_group_partitions = self._get_sub_group_partitions(i) + nvme_fp16_partitions_info.append(sub_group_partitions) + nvme_fp16_num_elems.append(num_elements) + nvme_fp32_dest_tensors.append(self.fp32_partitioned_groups_flat[i]) + else: + unpinned_fp32_buffer = torch.empty(num_elements, device=self.device, dtype=torch.float) + self._swap_in_sub_group_to_flat_buffer(unpinned_fp32_buffer, i) + self.optimizer_swapper.initialize_parameters(parameters=[self.fp32_partitioned_groups_flat[i]], + src_tensors=[unpinned_fp32_buffer]) + else: + num_swap_from_cpu_partitions += 1 + swap_from_cpu_memory_usage += (fp32_element_size * num_elements) + swappable_fp32_tensors.append(self.fp32_partitioned_groups_flat[i]) + swappable_fp16_src_tensors.append(self.fp16_partitioned_groups_flat[i]) + else: + cpu_memory_usage += (fp32_element_size * num_elements) + cpu_memory_sub_groups += 1 + + if self.params_in_nvme_and_cpu and tensor is None: + unpinned_fp32_buffer = torch.empty(num_elements, device=self.device, dtype=torch.float) + self._swap_in_sub_group_to_flat_buffer(unpinned_fp32_buffer, i) + self.fp32_partitioned_groups_flat.append(unpinned_fp32_buffer) + else: + if self.offload_optimizer: + self.fp32_partitioned_groups_flat.append(self.fp16_partitioned_groups_flat[i].to( + self.subgroup_to_device[i]).clone().float().detach()) + else: + self.fp32_partitioned_groups_flat.append(self.fp16_partitioned_groups_flat[i].to( + self.device).clone().float().detach()) + self.fp32_partitioned_groups_flat[i].ds_id = ds_id + + self.fp32_partitioned_groups_flat[i].requires_grad = True # keep this in case internal optimizer uses it + + if len(swappable_fp32_tensors) > 0: + self.optimizer_swapper.initialize_parameters(parameters=swappable_fp32_tensors, + src_tensors=swappable_fp16_src_tensors) + + if len(nvme_fp32_dest_tensors) > 0: + fp16_pinned_buffers = self.fp16_groups[0][0].nvme_swapper.reserve_available_buffers() + assert len(fp16_pinned_buffers) > 0 + self.optimizer_swapper.initialize_from_swapped_fp16_params(fp16_partitions_info=nvme_fp16_partitions_info, + fp16_num_elems=nvme_fp16_num_elems, + fp16_pinned_buffers=fp16_pinned_buffers, + fp32_parameters=nvme_fp32_dest_tensors) + self.fp16_groups[0][0].nvme_swapper.release_reserved_buffers() + + nvme_gigabytes = nvme_memory_usage / GIGA_BYTES + print_rank_0(f'Swappable FP32 Partitions: count={num_swappable_partitions} size={nvme_gigabytes:5.2f} GB', + force=False) + if self.params_in_nvme_and_cpu: + print_rank_0( + f'Swap from NVMe Partitions: count = {num_swap_from_nvme_partitions}, size = {swap_from_nvme_memory_usage/GIGA_BYTES:5.2f}GB', + force=False) + print_rank_0( + f'Swap from CPU Partitions: count = {num_swap_from_cpu_partitions}, size = {swap_from_cpu_memory_usage/GIGA_BYTES:5.2f}GB', + force=False) + + cpu_memory_gigabytes = cpu_memory_usage / GIGA_BYTES + print_rank_0(f'In-Memory FP32 Partitions: count={cpu_memory_sub_groups} size={cpu_memory_gigabytes:5.2f} GB', + force=False) + + # Clear for on-the-fly population before the optimizer step + for param_group in self.optimizer.param_groups: + param_group['params'] = [] + + def _create_fp16_sub_groups(self, params_group): + + params_group_numel = sum([param.partition_numel() for param in params_group]) + sub_group_size = self.sub_group_size + + if sub_group_size is None or sub_group_size >= params_group_numel: + return [params_group] + + sub_groups = [] + sub_group = [] + local_sub_group_size = 0 + for param in params_group: + + sub_group.append(param) + local_sub_group_size += param.partition_numel() + + if local_sub_group_size >= sub_group_size or id(param) == id(params_group[-1]): + + sub_groups.append(sub_group) + + sub_group = [] + local_sub_group_size = 0 + + return sub_groups + + def _release_ipg_buffers(self): + if self.contiguous_gradients: + self.ipg_buffer = None + + def _optimizer_step(self, sub_group_id): + param_group_id = self.sub_group_to_group_id[sub_group_id] + fp32_param = self.fp32_partitioned_groups_flat[sub_group_id] + if self.offload_optimizer: + cur_device = self.subgroup_to_device[sub_group_id] + if cur_device == 'cpu': + self.optimizer.param_groups[param_group_id]['params'] = [fp32_param] + cpu_loss = self.optimizer.step() + self.optimizer.param_groups[param_group_id]['params'] = [] + else: + self.backup_optimizer.param_groups[param_group_id]['params'] = [fp32_param] + gpu_loss = self.backup_optimizer.step() + self.backup_optimizer.param_groups[param_group_id]['params'] = [] + else: + self.optimizer.param_groups[param_group_id]['params'] = [fp32_param] + self.optimizer.step() + self.optimizer.param_groups[param_group_id]['params'] = [] + + def _swappable_optimizer_subgroup(self, sub_group_id): + if not self.swap_optimizer: + return False + + return self.optimizer_swapper.swappable_tensor(None, + numel=self.fp16_partitioned_groups_flat_numel[sub_group_id]) + + def _partitioned_params_swap_out(self, i): + offset = 0 + fp32_param = self.fp32_partitioned_groups_flat[i] + assert fp32_param is not None, \ + f'fp32 parameters of sub_group {i} is None' + + swap_fp16_params = [] + swap_fp32_params = [] + for param, partitioned_param in zip(self.fp16_groups[i], self.fp16_partitioned_groups[i]): + src = fp32_param.narrow(0, offset, partitioned_param.ds_numel) + if partitioned_param.status == PartitionedParamStatus.AVAILABLE: + partitioned_param.data.copy_(src.data) + else: + swap_fp32_params.append(src) + swap_fp16_params.append(param) + offset += partitioned_param.ds_numel + + if len(swap_fp16_params): + swap_fp16_params[0].nvme_swapper.swap_out_partitioned_params(dst_fp16_params=swap_fp16_params, + src_fp32_params=swap_fp32_params) + + def _set_fp16_partitioned_groups_flat(self): + # setup flat buffers per subgroup, these are each just sections of the + # contiguous flat buffer for all parameters that we created earlier + offset = 0 + for sub_group in self.fp16_groups: + sub_group_numel = sum(param.partition_numel() for param in sub_group) + self.fp16_partitioned_groups_flat.append(self.lp_param_buffer.narrow(0, offset, sub_group_numel)) + offset += sub_group_numel + + def initialize_optimizer_states(self): + num_subgroups = len(self.fp16_groups) + + largest_numel = max([sum([p.ds_numel for p in psg]) for psg in self.fp16_partitioned_groups]) + gradient_dtype = self.fp32_partitioned_groups_flat[0].dtype + gradient_buffer = torch.zeros(int(largest_numel), dtype=gradient_dtype, device=self.device) + + timer_names = set() + + # State initialization for the Adagrad optimizer occurs at construction as opposed to other optimizers + # which do lazy initialization of the state at the first call to step. + is_adagrad = isinstance(self.optimizer, torch.optim.Adagrad) + + if self.swap_optimizer: + self.optimizer_swapper.init_timers() + + timer_names.add(INIT_OPTIMIZER_TIMER) + self.timers(INIT_OPTIMIZER_TIMER).start() + + for i, group in enumerate(self.fp16_groups): + swappable_optimizer_subgroup = self._swappable_optimizer_subgroup(i) + swappable_param_subgroup = self.fp16_partitioned_groups_flat[i] is None + + num_elements = int(self.fp16_partitioned_groups_flat_numel[i]) + + see_memory_usage( + f'[Begin] Initialize optimizer states {i} / {num_subgroups} subgroups, num_elems: {num_elements}, swappable opt/param:{swappable_optimizer_subgroup}/{swappable_param_subgroup}', + force=False) + + if swappable_optimizer_subgroup: + self._optimizer_states_and_gradient_swap_in(i, timer_names) + + if self.offload_optimizer and not swappable_optimizer_subgroup: + subgroup_gradient_buffer = torch.zeros(num_elements, dtype=gradient_dtype, device=self.device) + if self.offload_optimizer_pin_memory: + subgroup_gradient_buffer = get_accelerator().pin_memory(subgroup_gradient_buffer) + + self.fp32_partitioned_groups_flat[i].grad = subgroup_gradient_buffer.to(self.subgroup_to_device[i]) + else: + self.fp32_partitioned_groups_flat[i].grad = gradient_buffer.narrow(0, 0, num_elements) + + if swappable_param_subgroup: + self._partitioned_params_swap_out(i) + + if swappable_optimizer_subgroup: + self._optimizer_states_and_gradient_swap_out(i, timer_names) + + see_memory_usage( + f'[End] Initialize optimizer states {i} / {num_subgroups} subgroups, num_elems: {num_elements}, swappable opt/param:{swappable_optimizer_subgroup}/{swappable_param_subgroup}', + force=False) + + # Initialize the optimizer states with the flattened fp32 partition. + if is_adagrad: + self.optimizer = torch.optim.Adagrad(self.fp32_partitioned_groups_flat, **self.optimizer.defaults) + + self.timers(INIT_OPTIMIZER_TIMER).stop() + self.timers.log(timer_names) + + if self.swap_optimizer: + self.optimizer_swapper.log_timers() + + if not self.offload_optimizer: + for group in self.fp32_partitioned_groups_flat: + group.grad = None + + # Reset steps + return + + ######################################################################### + #########################ZeRO Partition Gradients######################## + ######################################################################### + + def get_first_param_index(self, group_id, param_group, partition_id): + for index, param in enumerate(param_group): + param_id = self.get_param_id(param) + if partition_id in self.param_to_partition_ids[group_id][param_id]: + return index + return None + + def initialize_gradient_partitioning_data_structures(self): + + total_partitions = dist.get_world_size(group=self.dp_process_group) + + for i, param_group in enumerate(self.fp16_groups): + + self.param_to_partition_ids[i] = {} + self.is_partition_reduced[i] = {} + self.total_grads_in_partition[i] = {} + self.remaining_grads_in_partition[i] = {} + self.is_grad_computed[i] = {} + self.grad_partition_insertion_offset[i] = {} + self.grad_start_offset[i] = {} + self.first_param_index_in_partition[i] = {} + + for partition_id in range(total_partitions): + self.is_grad_computed[i][partition_id] = {} + self.grad_partition_insertion_offset[i][partition_id] = {} + self.grad_start_offset[i][partition_id] = {} + self.initialize_gradient_partition(i, param_group, partition_id) + self.is_partition_reduced[i][partition_id] = False + self.first_param_index_in_partition[i][partition_id] = self.get_first_param_index( + i, param_group, partition_id) + + @instrument_w_nvtx + def independent_gradient_partition_epilogue(self): + self.report_ipg_memory_usage(f"In ipg_epilogue before reduce_ipg_grads", 0) + self.__reduce_and_partition_ipg_grads() + self.report_ipg_memory_usage(f"In ipg_epilogue after reduce_ipg_grads", 0) + + if not get_accelerator().resolves_data_dependency(): + self.reduce_and_partition_stream.synchronize() + + for param_id in self.params_already_reduced.keys(): + self.params_already_reduced[param_id] = False + + #in case of cpu offload, averaged gradients are already in fp32_partitioned_groups_flat.grad + #TODO: use a similar code path for both cpu_offload and non-cpu offload + if not self.offload_optimizer: + for i, sub_group in enumerate(self.fp16_groups): + #TODO: This is redundant + self.averaged_gradients[i] = [ + self.__param_id_to_grad_partition[param.ds_id] + if param.requires_grad else torch.zeros_like(param.ds_tensor) for param in sub_group + ] + # this method gets called after every backward. need to increment + # here because if it gets incremented in backward() the micro step + # id will be off by one when we do the reduce and partition at the. + # start of this method. + # TODO. make this less error prone + self.micro_step_id += 1 + + def overlapping_partition_gradients_reduce_epilogue(self): + self.independent_gradient_partition_epilogue() + + def create_reduce_and_remove_grad_hooks(self): + print_rank_0(f'[Begin] Create gradient reduction hooks') + self.grad_accs = [] + self.leaf_parameters = defaultdict(list) + for i, param_group in enumerate(self.fp16_groups): + for param in param_group: + if param.requires_grad: + #print_rank_0(f" Before all gather {param.device}, {param.shape}") + print_rank_0(f"Before all gather {param.device}, {param.shape}", force=False) + + # The hook must be created in un-partitioned parameter + param.all_gather() + + #print(f"After all gather {param.device}, {param.shape}") + def wrapper(param): + param_tmp = param.expand_as(param) + grad_acc = param_tmp.grad_fn.next_functions[0][0] + + @instrument_w_nvtx + def reduce_partition_and_remove_grads(*notneeded): + self.reduce_ready_partitions_and_remove_grads(param) + + self._grad_acc_hooks.append(grad_acc.register_hook(reduce_partition_and_remove_grads)) + self.grad_accs.append(grad_acc) + + #print(f"param grad fn {param.expand_as(param).grad_fn}") + if z3_leaf_parameter(param): + self.leaf_parameters[param.ds_z3_leaf_module].append(param) + else: + wrapper(param) + + # Partition the parameter after creating the hook + param.partition() + + # We delay reduce-scatter for all gradients in the leaf modules until the backward pass of the leaf module is done + for leaf_module, leaf_parameters in self.leaf_parameters.items(): + + def wrapper_pre_hook(params): + + def forward_pre_hook(module, input): + """Pre-forward hook to set backward hook on input tensors to the leaf module""" + module._leaf_module_inputs_remaining = 0 + + @instrument_w_nvtx + def reduce_leaf_module_grads(grad): + module._leaf_module_inputs_remaining -= 1 + # Make sure everything is done in the leaf module + if module._leaf_module_inputs_remaining == 0: + for param in params: + if param.grad is None: + param.grad = torch.zeros_like(param) + self.reduce_ready_partitions_and_remove_grads(param) + + def set_module_bwd_hook(tensor): + if tensor.requires_grad: + module._leaf_module_inputs_remaining += 1 + tensor.register_hook(reduce_leaf_module_grads) + return tensor + + output = apply_to_tensors_only(set_module_bwd_hook, input) + + return output + + return forward_pre_hook + + def wrapper_post_hook(): + + def forward_post_hook(module, input, output): + """Pre-forward hook to set backward hook on input tensors to the leaf module""" + module._leaf_output_required_grad_num = 0 + + def increment_rg_count_bwd_hook(tensor): + if tensor.requires_grad: + module._leaf_output_required_grad_num += 1 + return tensor + + apply_to_tensors_only(increment_rg_count_bwd_hook, output) + + if module._leaf_module_inputs_remaining == 0 and module._leaf_output_required_grad_num > 0: + raise RuntimeError( + "A module cannot be set as a leaf module when it does not have any input tensors that require gradients and has output tensors that require gradients. This is because the gradient reduction hook will not be called in this case." + ) + + return forward_post_hook + + self._leaf_module_hooks.append(leaf_module.register_forward_pre_hook(wrapper_pre_hook(leaf_parameters))) + self._leaf_module_hooks.append(leaf_module.register_forward_hook(wrapper_post_hook())) + + print_rank_0(f'[End] Create gradient reduction hooks') + + def get_param_id(self, param): + return OptimizerSwapper.parameter_id(param) + + def report_ipg_memory_usage(self, tag, param_elems): + elem_count = self.elements_in_ipg_bucket + param_elems + percent_of_bucket_size = (100.0 * elem_count) // self.reduce_bucket_size + see_memory_usage( + f"{tag}: elems in_bucket {self.elements_in_ipg_bucket} param {param_elems} max_percent {percent_of_bucket_size}", + force=False) + + ###############Independent Partition Gradient ######################## + def reduce_independent_p_g_buckets_and_remove_grads(self, param): + #print_rank_0(f"Inside reduce ipg buckets. {debug_param2name_id_shape(param)}, ipg elements {self.elements_in_ipg_bucket}, reduce bucket size {self.reduce_bucket_size}", force=True) + + # Because the ipg bucket is initialized with a random place holder tensor, we must + # explicitly check that the bucket has any real data in it (self.elements_in_ipg_bucket > + # 0). Otherwise if the incoming param.ds_numel is large, this branch may get triggered on a + # garbage data and `self.average_tensor()` will crash because its params_to_reduce will be + # empty, while reduction_list will have that garbage data. + if self.elements_in_ipg_bucket + param.ds_numel > self.reduce_bucket_size and self.elements_in_ipg_bucket > 0: + self.report_ipg_memory_usage("In ipg_remove_grads before reduce_ipg_grads", param.ds_numel) + + self.__reduce_and_partition_ipg_grads() + + self.__add_grad_to_ipg_bucket(param) + + @instrument_w_nvtx + @torch.no_grad() + def __add_grad_to_ipg_bucket(self, param: Parameter) -> None: + if not get_accelerator().resolves_data_dependency(): + self.reduce_and_partition_stream.wait_stream(get_accelerator().default_stream()) + + if self.contiguous_gradients and self.elements_in_ipg_bucket + param.grad.numel() <= self.reduce_bucket_size: + # move the gradient to a contiguous buffer + with get_accelerator().stream(self.reduce_and_partition_stream): + # move the parameter's gradient to the contiguous flat buffer + new_grad_tensor = self.__ipg_bucket_flat_buffer.narrow(0, self.elements_in_ipg_bucket, + param.grad.numel()).view_as(param.grad) + new_grad_tensor.copy_(param.grad, non_blocking=True) + if not get_accelerator().is_synchronized_device(): + param.grad.record_stream(get_accelerator().current_stream()) + param.grad.data = new_grad_tensor + + self.params_in_ipg_bucket.append(param) + + @instrument_w_nvtx + @torch.no_grad() + def __reduce_and_partition_ipg_grads(self, safe_mode: bool = False) -> None: + if not self.params_in_ipg_bucket: + return + + for param in self.params_in_ipg_bucket: + if param.grad.numel() != param.ds_numel: + raise RuntimeError(f"{param.grad.numel()} != {param.ds_numel} Cannot reduce scatter " + f"gradients whose size is not same as the params") + + assert len(set(p.ds_id for p in self.params_in_ipg_bucket)) == len(self.params_in_ipg_bucket) + + while self.param_reduce_events and self.param_reduce_events[0].query(): + self.param_reduce_events.popleft() + if len(self.param_reduce_events) > self.max_param_reduce_events: + self.param_reduce_events.popleft().synchronize() + + with get_accelerator().stream(self.reduce_and_partition_stream): + if safe_mode: + assert_ints_same_as_other_ranks([p.ds_id for p in self.params_in_ipg_bucket]) + + if self.contiguous_gradients and self.elements_in_ipg_bucket <= self.reduce_bucket_size and not self.reduce_scatter: + grad_bucket = self.__ipg_bucket_flat_buffer.narrow(0, 0, self.elements_in_ipg_bucket) + grad_partitions = self.__avg_scatter_contiguous_grads(grad_bucket) + else: + self.params_in_ipg_bucket.sort(key=lambda p: p.ds_id) + grad_partitions = self.__avg_scatter_grads(self.params_in_ipg_bucket) + + self.partition_grads(self.params_in_ipg_bucket, grad_partitions) + + self.params_in_ipg_bucket.clear() + + if not get_accelerator().handles_memory_backpressure(): + event = get_accelerator().Event() + event.record() + self.param_reduce_events.append(event) + + @instrument_w_nvtx + def __avg_scatter_contiguous_grads(self, buffer_to_reduce: Tensor) -> List[Tensor]: + dtype = buffer_to_reduce.dtype + if self.communication_data_type != dtype: + buffer_to_reduce = buffer_to_reduce.to(self.communication_data_type) + if self.postscale_gradients and self.gradient_predivide_factor != 1.0: + buffer_to_reduce = buffer_to_reduce.div_(self.gradient_predivide_factor) + + world_sz = dist.get_world_size(self.dp_process_group) + rank = dist.get_rank(self.dp_process_group) + buffer_to_reduce.div_(world_sz / float(self.sequence_parallel_size)) + + dist.all_reduce(buffer_to_reduce, group=self.dp_process_group) + + if self.postscale_gradients and self.gradient_predivide_factor != world_sz: + buffer_to_reduce = buffer_to_reduce.mul(self.gradient_predivide_factor) + + if self.communication_data_type != self.dtype: + buffer_to_reduce = buffer_to_reduce.to(self.dtype) + + grad_partitions = [] + grad_offset_in_buffer = 0 + for param in self.params_in_ipg_bucket: + grad = param.grad + chunk_sz = math.ceil(grad.numel() / world_sz) + + start_offset = grad_offset_in_buffer + min(rank * chunk_sz, grad.numel()) + end_offset = grad_offset_in_buffer + min(rank * chunk_sz + chunk_sz, grad.numel()) + + partition = buffer_to_reduce[start_offset:end_offset] + if param.partition_numel() != partition.numel(): + padded_partition = torch.zeros(param.partition_numel(), device=grad.device, dtype=grad.dtype) + if partition.numel() > 0: + padded_partition[:partition.numel()] = partition + grad_partitions.append(padded_partition) + else: + grad_partitions.append(partition) + grad_offset_in_buffer += grad.numel() + + return grad_partitions + + @instrument_w_nvtx + def __avg_scatter_grads(self, params_to_reduce: List[Parameter]) -> List[Tensor]: + """average gradients and scatter partitions across ranks""" + + full_grads_for_rank = [p.grad for p in params_to_reduce] + if self.communication_data_type != self.dtype: + full_grads_for_rank = [g.to(self.communication_data_type) for g in full_grads_for_rank] + + if self.postscale_gradients and self.gradient_predivide_factor != 1.0: + full_grads_for_rank = [g.div(self.gradient_predivide_factor) for g in full_grads_for_rank] + + local_world_size = get_accelerator().device_count() + global_world_size = dist.get_world_size() + num_nodes = global_world_size // local_world_size + if self.all2all_process_group is not None and num_nodes > 1: + grad_partitions_for_rank = all_to_all_quant_reduce(full_grads_for_rank, self.all2all_process_group) + else: + grad_partitions_for_rank = reduce_scatter_coalesced(full_grads_for_rank, self.dp_process_group) + + if self.postscale_gradients and self.gradient_predivide_factor != 1.0 and self.gradient_predivide_factor != dist.get_world_size( + self.dp_process_group): + grad_partitions_for_rank = [g.mul(self.gradient_predivide_factor) for g in grad_partitions_for_rank] + + if self.communication_data_type != self.dtype: + grad_partitions_for_rank = [g.to(self.dtype) for g in grad_partitions_for_rank] + + return grad_partitions_for_rank + + def set_grad_positions(self): + for i, group in enumerate(self.fp16_groups): + current_offset = 0 + for param in group: + param_id = self.get_param_id(param) + num_elements = param.partition_numel() + + self.grad_position[param_id] = [int(i), int(current_offset), int(num_elements)] + #print(f"param id {param_id} i:{i}, ds_tensor {num_elements} numel {param.numel()}") + current_offset += num_elements + see_memory_usage(f"After Set Grad positions", force=False) + + def _constant_buffered_norm2(self, input, buffer_size=250000000): + norm = None + for part in input.view(-1).split(buffer_size): + if norm is None: + norm = part.data.float().norm(2)**2.0 + else: + norm += part.data.float().norm(2)**2.0 + return norm**0.5 + + def set_norm_for_param_grad_in_gpu(self, param): + param_id = self.get_param_id(param) + #self.norm_for_param_grads[param_id] = param.grad.data.double().norm(2) + #Using a more memory efficient version + self.norm_for_param_grads[param_id] = self._constant_buffered_norm2(param.grad) + + def async_inplace_copy_grad_to_fp32_buffer_from_gpu(self, param, fp32_grad_tensor): + with get_accelerator().stream(self.copy_grad_stream): + param_id = self.get_param_id(param) + src_tensor = param.grad.view(-1).float() + #print(f"src_tensor {src_tensor.size()} and fp32 grad {fp32_grad_tensor.size()}") + fp32_grad_tensor.copy_(src_tensor, non_blocking=True) + param.grad = None + + def complete_grad_norm_calculation_for_cpu_offload(self, params): + total_norm = 0.0 + norm_type = 2.0 + for p in params: + if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0): + param_id = self.get_param_id(p) + if param_id in self.norm_for_param_grads.keys(): + param_norm = self.norm_for_param_grads[param_id] + total_norm += param_norm**2 + + # Sum across all model parallel GPUs. + total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)]) + + dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM, group=self.dp_process_group) + + self._model_parallel_all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.SUM) + + total_norm = total_norm_cuda[0]**(1. / norm_type) + + norm_is_inf = total_norm.isinf() + norm_is_nan = total_norm.isnan() + inf_or_nan = norm_is_nan.logical_or(norm_is_inf) + + err = torch.tensor(-1.0, device=inf_or_nan.device, dtype=torch.float) + total_norm = inf_or_nan * err + inf_or_nan.logical_not() * total_norm + + return total_norm.cpu() + + @instrument_w_nvtx + def partition_grads(self, params_to_release: List[Parameter], grad_partitions: List[Tensor]) -> None: + offload_fp32_gradients = {} + offload_fp32_offsets = {} + buffers = [] + for param, grad_partition in zip(params_to_release, grad_partitions): + + contains_real_data = param.partition_numel() * dist.get_rank(self.dp_process_group) < param.ds_numel + if not contains_real_data: + # this grad partition is empty - don't need to do anything + param.grad = None + continue + + # move or accumulate gradient partition to target buffer + grad_buffer = self.__param_id_to_grad_partition[param.ds_id].narrow(0, 0, grad_partition.numel()) + buffers.append(grad_buffer) + if self.micro_step_id == 0: # don't accumulate + grad_buffer.copy_(grad_partition, non_blocking=True) + # ensure grad buffer is a CUDA buffer to speed up the next few + # operations and so it can be used asynchronously + grad_buffer = grad_buffer.to(grad_partition.device, non_blocking=True) + elif get_accelerator().on_accelerator(grad_buffer): + grad_buffer.add_(grad_partition.to(self.gradient_accumulation_dtype).view(grad_buffer.shape)) + else: + # if dst is CPU, copy first to src device, do the addition + # there, then move back to dst. adding directly to cpu is very slow + cuda_grad_buffer = grad_buffer.to(grad_partition.device, non_blocking=True) + cuda_grad_buffer.add_(grad_partition.to(self.gradient_accumulation_dtype).view(cuda_grad_buffer.shape)) + grad_buffer.copy_(cuda_grad_buffer, non_blocking=True) + # ensure grad buffer is a CUDA buffer to speed up the next few + # operations and so it can be used asynchronously + grad_buffer = cuda_grad_buffer + + # offload the gradient partition if applicable + if self.offload_optimizer: + i, dest_offset, _ = self.grad_position[self.get_param_id(param)] + + if self.is_gradient_accumulation_boundary: + self.norm_for_param_grads[self.get_param_id(param)] = self._constant_buffered_norm2(grad_buffer) + + if self._swappable_optimizer_subgroup(i): + if not i in offload_fp32_gradients.keys(): + offload_fp32_gradients[i] = [] + offload_fp32_offsets[i] = [] + + offload_fp32_gradients[i].append(grad_buffer.float()) + offload_fp32_offsets[i].append(dest_offset) + else: + fp32_grad_tensor = self.fp32_partitioned_groups_flat[i].grad.narrow( + 0, dest_offset, grad_buffer.numel()) + fp32_grad_tensor.copy_(grad_buffer) + + # free the gradient + if not get_accelerator().is_synchronized_device(): + param.grad.record_stream(get_accelerator().current_stream()) + param.grad = None + + if self.offload_optimizer and self.swap_optimizer: + for i in offload_fp32_gradients.keys(): + self.optimizer_swapper.swap_out_gradients(parameter=self.fp32_partitioned_groups_flat[i], + gradient_offsets=offload_fp32_offsets[i], + gradient_tensors=offload_fp32_gradients[i]) + return buffers + + def reduce_ready_partitions_and_remove_grads(self, param): + #print_rank_0(f"Backward {debug_param2name_id_shape(param)}", force=True) + self.reduce_independent_p_g_buckets_and_remove_grads(param) + + def zero_reduced_gradients(self, partition_id, i): + + def are_all_related_partitions_reduced(params_id): + for partition_id in self.param_to_partition_ids[i][params_id]: + if not self.is_partition_reduced[i][partition_id]: + return False + return True + + for params_id in self.is_grad_computed[i][partition_id]: + if are_all_related_partitions_reduced(params_id): + self.param_dict[params_id].grad = None + + def quantize_nontrainable_params(self): + """ In ZeRO-3, when the zero_quantized_nontrainable_weights flag is set, we quantize the non-trainable weights and also store them in quantized format. However, this check for trainable/non-trainable is done when deepspeed initializes the partitioning. So, if the user changes the trainable/non-trainable status of a parameter after the partitioning is done (e.g. LoRA), the user needs to re-quantize the non-trainable weights by calling this function. + """ + if not self.zero_quantized_nontrainable_weights: + print_rank_0( + f"Warning: quantize_nontrainable_params() called with zero_quantized_nontrainable_weights disabled, return without doing anything", + force=True) + return + quantizer_module = CUDAQuantizer() + + def quantize_dstensor(tensor): + assert tensor.dtype == torch.float16, f"quantize_dstensor() expects tensor.dtype == torch.float16, got {tensor.dtype}" + partition_size = tensor.ds_numel + ds_status = tensor.status + final_location = tensor.final_location + tensor, tensor.ds_quant_scale = quantizer_module.quantize(tensor) + tensor.ds_numel = partition_size + tensor.status = ds_status + tensor.final_location = final_location + tensor.requires_grad = False + return tensor + + for param in self.module.parameters(): + if hasattr(param, "ds_tensor") and (param.ds_tensor.numel() <= 2048 or param.ds_numel <= 500000): + # skip small parameters + continue + if hasattr(param, + "ds_tensor") and not param.requires_grad and not hasattr(param.ds_tensor, "ds_quant_scale"): + param.ds_tensor = quantize_dstensor(param.ds_tensor) + if hasattr(param, "ds_secondary_tensor") and not param.requires_grad and not hasattr( + param.ds_secondary_tensor, "ds_quant_scale") and param.ds_secondary_tensor is not None: + param.ds_secondary_tensor = quantize_dstensor(param.ds_secondary_tensor) + get_accelerator().synchronize() + + def flatten_and_print(self, message, tensors, start=0, n=5): + flatten_tensor = self.flatten(tensors) + + def print_func(): + logger.info(flatten_tensor.contiguous().view(-1).narrow(0, start, n)) + + self.sequential_execution(print_func, message) + + def get_grads_to_reduce(self, i, partition_id): + + def get_reducible_portion(key): + grad = self.param_dict[key].grad + total_elements = grad.numel() + start = self.grad_start_offset[i][partition_id][key] + num_elements = min(total_elements - start, + self.partition_size[i] - self.grad_partition_insertion_offset[i][partition_id][key]) + if not pg_correctness_test: + if num_elements == total_elements: + return grad + else: + return grad.contiguous().view(-1).narrow(0, int(start), int(num_elements)) + else: + if num_elements == total_elements: + return grad.clone() + else: + return grad.clone().contiguous().view(-1).narrow(0, int(start), int(num_elements)) + + grads_to_reduce = [] + for key in self.is_grad_computed[i][partition_id]: + grad = get_reducible_portion(key) + grads_to_reduce.append(grad) + return grads_to_reduce + + def sequential_execution(self, function, message, group=None): + if group is None: + group = self.dp_process_group + if dist.get_rank(group=group) == 0: + logger.info(message) + for id in range(dist.get_world_size(group=group)): + if id == dist.get_rank(group=group): + function() + dist.barrier(group=group) + + def set_none_gradients_to_zero(self, i, partition_id): + for param_id in self.is_grad_computed[i][partition_id]: + param = self.param_dict[param_id] + if param.grad is None: + param.grad = torch.zeros_like(param) + + ######################Reduction Related Methods############################## + + def allreduce_bucket(self, bucket, rank=None, log=None): + rank = None + tensor = self.flatten(bucket) + + tensor_to_allreduce = tensor + + if pg_correctness_test: + communication_data_type = torch.float32 + else: + communication_data_type = self.communication_data_type + + if communication_data_type != tensor.dtype: + tensor_to_allreduce = tensor.to(communication_data_type) + + tensor_to_allreduce.div_(dist.get_world_size(group=self.dp_process_group) / float(self.sequence_parallel_size)) + + if rank is None: + # "All Reducing" + dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group) + else: + global_rank = dist.get_global_rank(self.dp_process_group, rank) + dist.reduce(tensor_to_allreduce, global_rank, group=self.dp_process_group) + + if communication_data_type != tensor.dtype and tensor is not tensor_to_allreduce: + if rank is None or rank == dist.get_rank(group=self.dp_process_group): + tensor.copy_(tensor_to_allreduce) + + return tensor + + # if rank is specified do a reduction instead of an allreduce + def allreduce_and_copy(self, small_bucket, rank=None, log=None): + with get_accelerator().stream(self.reduction_stream): + allreduced = self.allreduce_bucket(small_bucket, rank=rank, log=log) + if rank is None or rank == dist.get_rank(group=self.dp_process_group): + for buf, synced in zip(small_bucket, self.unflatten(allreduced, small_bucket)): + buf.copy_(synced) + + def allreduce_no_retain(self, bucket, numel_per_bucket=500000000, rank=None, log=None): + small_bucket = [] + numel = 0 + for tensor in bucket: + small_bucket.append(tensor) + numel = numel + tensor.numel() + if numel > numel_per_bucket: + self.allreduce_and_copy(small_bucket, rank=rank, log=None) + small_bucket = [] + if len(small_bucket) > 0: + self.allreduce_and_copy(small_bucket, rank=rank, log=log) + + ############################################################################# + ############################################################################# + ############################################################################# + + # views the tensor as multiple partitions and returns + # those partitions + def get_data_parallel_partitions(self, tensor): + partitions = [] + + dp = dist.get_world_size(group=self.dp_process_group) + dp_id = dist.get_rank(group=self.dp_process_group) + + total_num_elements = tensor.numel() + + base_size = total_num_elements // dp + remaining = total_num_elements % dp + + start = 0 + for id in range(dp): + partition_size = base_size + if id < remaining: + partition_size = partition_size + 1 + partitions.append(tensor.narrow(0, start, partition_size)) + start = start + partition_size + return partitions + + def get_partition_info(self, tensor_list, partition_size, partition_id): + params_in_partition = [] + params_not_in_partition = [] + + start_index = partition_size * partition_id + end_index = partition_size * (partition_id + 1) + + current_index = 0 + first_offset = 0 + + for tensor in tensor_list: + + tensor_size = tensor.numel() + + if start_index <= current_index < end_index: + params_in_partition.append(tensor) + + elif current_index < start_index < (current_index + tensor_size): + params_in_partition.append(tensor) + + assert (first_offset == 0 + ), "This can happen either zero or only once as this must be the first tensor in the partition" + first_offset = start_index - current_index + + else: + params_not_in_partition.append(tensor) + + current_index = current_index + tensor_size + + return params_in_partition, params_not_in_partition, first_offset + + @instrument_w_nvtx + def zero_grad(self, set_to_none=True): + """ + Zero FP16 parameter grads. + """ + self.micro_step_id = 0 + + # FP32 grad should never exist. + # For speed, set model fp16 grad to None by default + for group in self.fp16_groups: + for p in group: + if set_to_none: + if p.grad is not None and get_accelerator().on_accelerator(p.grad): + p.grad.record_stream(get_accelerator().current_stream()) + p.grad = None + else: + if p.grad is not None: + p.grad.detach_() + p.grad.zero_() + + def _model_parallel_all_reduce(self, tensor, op): + """ Perform all reduce within model parallel group, if any. + """ + if self.model_parallel_group is None: + pass + else: + dist.all_reduce(tensor=tensor, op=op, group=self.model_parallel_group) + + @instrument_w_nvtx + def get_grad_norm_direct(self, gradients, params, norm_type=2): + """Clips gradient norm of an iterable of parameters. + + This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and + added functionality to handle model parallel parameters. Note that + the gradients are modified in place. + + Arguments: + parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a + single Tensor that will have gradients normalized + max_norm (float or int): max norm of the gradients + norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for + infinity norm. + + Returns: + Total norm of the parameters (viewed as a single vector). + """ + norm_type = float(norm_type) + if norm_type == inf: + total_norm = max(g.data.abs().max() for g in gradients) + total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)]) + dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=self.dp_process_group) + + # Take max across all GPUs. + self._model_parallel_all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX) + total_norm = total_norm_cuda[0] + else: + # if dist.get_rank() == 0: + # logger.info(f"Total Norm beginning {total_norm}") + grad_norms = [] + for g, p in zip(gradients, params): + if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0): + grad_norms.append(g.to(get_accelerator().device_name(), non_blocking=True).float().norm(2)) + + # Sum across all model parallel GPUs. + if len(grad_norms) == 0: + # FIX https://github.com/microsoft/DeepSpeed/issues/3564 + total_norm_cuda = torch.tensor(0, + dtype=gradients[0].dtype).to(get_accelerator().device_name()).float() + else: + total_norm_cuda = torch.sum(torch.pow(torch.stack(grad_norms), 2)) + + dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM, group=self.dp_process_group) + + self._model_parallel_all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.SUM) + + total_norm = total_norm_cuda**(1. / norm_type) + + norm_is_inf = total_norm.isinf() + norm_is_nan = total_norm.isnan() + inf_or_nan = norm_is_nan.logical_or(norm_is_inf) + + err = torch.tensor(-1.0, device=self.device, dtype=torch.float) + total_norm = inf_or_nan * err + inf_or_nan.logical_not() * total_norm + + return total_norm + + # creates a flat fused tensor from the tensor list starting at the first_offset + # in the first tensor of the list. If there are not enough elements in the tensor + # list then the flat tensor will be padded with zeros + def get_flat_partition(self, tensor_list, first_offset, partition_size, return_tensor_list=False): + flat_tensor_list = [] + current_size = 0 + for i, tensor in enumerate(tensor_list): + if tensor.grad is None: + tensor.grad = torch.zeros_like(tensor) + + tensor = tensor.grad + num_elements = tensor.numel() + tensor_offset = 0 + + # we need to offset to get to the right element + if i == 0 and first_offset > 0: + tensor_offset = first_offset + num_elements = num_elements - tensor_offset + + # we dont need all elements of the tensor + if num_elements > (partition_size - current_size): + num_elements = partition_size - current_size + + # we need a narrow view of the tensor based on the tensor offset and number of elements that + # we need from this tensor + if tensor_offset > 0 or num_elements < tensor.numel(): + flat_tensor_list.append(tensor.contiguous().view(-1).narrow(0, int(tensor_offset), int(num_elements))) + else: + flat_tensor_list.append(tensor) + + current_size = current_size + num_elements + + # this means its the last partition and does not align with the dp boundary. We need to pad before flattening + if current_size < partition_size: + flat_tensor_list.append( + torch.zeros(int(partition_size - current_size), + dtype=tensor_list[0].dtype, + device=tensor_list[0].device)) + + if return_tensor_list: + return flat_tensor_list + + return self.flatten(flat_tensor_list) + + def free_grad_in_param_list(self, param_list): + for p in param_list: + p.grad = None + + def reset_cpu_buffers(self): + self.norm_for_param_grads = {} + + def _pre_step(self): + self.micro_step_id = 0 + + print_rank_0(f"Inside Step function") + see_memory_usage(f"In step before checking overflow", force=False) + + print_rank_0("Finished Tracing at Beginning of Step") + self._get_param_coordinator(training=True).hierarchy = 0 + + print_rank_0("Finished Tracing at Beginning of Step") + + @instrument_w_nvtx + def _get_norm_groups(self): + norm_groups = [] + for i, group in enumerate(self.fp16_groups): + if self.offload_optimizer: + norm_groups.append(self.complete_grad_norm_calculation_for_cpu_offload(self.fp16_groups[i])) + else: + norm_groups.append(self.get_grad_norm_direct(self.averaged_gradients[i], self.fp16_groups[i])) + return norm_groups + + @instrument_w_nvtx + def _prepare_fp32_grad_for_sub_group(self, sub_group_id): + partition_id = dist.get_rank(group=self.dp_process_group) + + single_grad_partition = self.flatten(self.averaged_gradients[sub_group_id]).to( + self.fp32_partitioned_groups_flat[sub_group_id].dtype) + + assert single_grad_partition.numel() == self.fp32_partitioned_groups_flat[sub_group_id].numel(), \ + "averaged gradients have different number of elements that partition size {} {} {} {}".format( + single_grad_partition.numel(), self.fp32_partitioned_groups_flat[sub_group_id].numel(), sub_group_id, partition_id) + + self.fp32_partitioned_groups_flat[sub_group_id].grad = single_grad_partition + + # release all the gradient since we have already created a necessary copy in dp_grad_partition + self.zero_grad(set_to_none=True) + + if not get_accelerator().is_synchronized_device(): + for grad in filter(lambda g: get_accelerator().on_accelerator(g), self.averaged_gradients[sub_group_id]): + grad.record_stream(get_accelerator().current_stream()) + + self.averaged_gradients[sub_group_id] = None + + @instrument_w_nvtx + def _prepare_sub_group(self, sub_group_id, timer_names): + see_memory_usage(f'Before prepare optimizer sub group {sub_group_id}', force=False) + if self._swappable_optimizer_subgroup(sub_group_id): + self._optimizer_states_and_gradient_swap_in(sub_group_id, timer_names) + elif not self.offload_optimizer: + self._prepare_fp32_grad_for_sub_group(sub_group_id) + see_memory_usage(f'After prepare optimizer sub group {sub_group_id}', force=False) + + def _optimizer_states_and_gradient_swap_in(self, sub_group_id, timer_names): + param_length = self.fp16_partitioned_groups_flat_numel[sub_group_id] + fp32_param_id = self.get_param_id(self.fp32_partitioned_groups_flat[sub_group_id]) + assert self._swappable_optimizer_subgroup(sub_group_id), \ + f'Parameter {fp32_param_id} of numel={param_length} is not swappable' + + see_memory_usage(f'pre-step Before swapping in optimizer tensors {sub_group_id}', force=False) + timer_names.add(OPTIMIZER_SWAP_IN_STATE_TIMER) + self.timers(OPTIMIZER_SWAP_IN_STATE_TIMER).start() + + self.optimizer_swapper.swap_in_optimizer_state( + parameter=self.fp32_partitioned_groups_flat[sub_group_id], + async_parameter=self.next_swappable_fp32_partitioned_groups[sub_group_id]) + + self.timers(OPTIMIZER_SWAP_IN_STATE_TIMER).stop() + see_memory_usage(f'pre-step After swapping in optimizer tensors {sub_group_id}', force=False) + + @instrument_w_nvtx + def _release_sub_group(self, sub_group_id, timer_names): + see_memory_usage(f'Before release optimizer sub group {sub_group_id}', force=False) + # get rid of the fp32 gradients. Not needed anymore + if not self.offload_optimizer: + self.fp32_partitioned_groups_flat[sub_group_id].grad = None + + if self._swappable_optimizer_subgroup(sub_group_id): + self._optimizer_states_and_gradient_swap_out(sub_group_id, timer_names) + see_memory_usage(f'After release optimizer sub group {sub_group_id}', force=False) + + # create a flat tensor aligned at the alignment boundary + @instrument_w_nvtx + def flatten_dense_tensors_aligned(self, tensor_list, alignment): + num_elements = 0 + for tens in tensor_list: + num_elements = num_elements + tens.numel() + + remaining = num_elements % alignment + + if remaining: + elements_to_add = alignment - remaining + pad_tensor = torch.zeros(elements_to_add, device=tensor_list[0].device, dtype=tensor_list[0].dtype) + padded_tensor_list = tensor_list + [pad_tensor] + + num_elements = num_elements + elements_to_add + else: + padded_tensor_list = tensor_list + + return self.flatten(padded_tensor_list) + + def _optimizer_states_and_gradient_swap_out(self, sub_group_id, timer_names): + param_length = self.fp16_partitioned_groups_flat_numel[sub_group_id] + fp32_param_id = self.get_param_id(self.fp32_partitioned_groups_flat[sub_group_id]) + assert self._swappable_optimizer_subgroup(sub_group_id), \ + f'Parameter {fp32_param_id} of numel={param_length} is not swappable' + + see_memory_usage(f'post-step Before swapping out optimizer tensors {sub_group_id}', force=False) + timer_names.add(OPTIMIZER_SWAP_OUT_STATE_TIMER) + self.timers(OPTIMIZER_SWAP_OUT_STATE_TIMER).start() + + self.optimizer_swapper.swap_out_optimizer_state( + parameter=self.fp32_partitioned_groups_flat[sub_group_id], + async_swap=self.next_swappable_fp32_partitioned_groups[sub_group_id] is not None) + + self.timers(OPTIMIZER_SWAP_OUT_STATE_TIMER).stop() + see_memory_usage(f'post-step After swapping out optimizer tensors {sub_group_id}', force=False) + + # get rid of the fp32 gradients. Not needed anymore + self.fp32_partitioned_groups_flat[sub_group_id].grad = None + + def _unflatten_partitioned_parameters(self, sub_group_id): + updated_params = self.unflatten(self.fp16_partitioned_groups_flat[sub_group_id], + self.fp16_partitioned_groups[sub_group_id]) + + for partitioned_param, q in zip(self.fp16_partitioned_groups[sub_group_id], updated_params): + partitioned_param.data = q.data + + def _overflow_clean_up(self, prev_scale): + see_memory_usage('After overflow before clearing gradients', force=False) + self.zero_grad(set_to_none=True) + + if self.offload_optimizer: + self.reset_cpu_buffers() + else: + self.averaged_gradients = {} + + see_memory_usage('After overflow after clearing gradients', force=False) + + @instrument_w_nvtx + def _overflow_check_and_loss_scale_update(self): + + # First compute norm for all group so we know if there is overflow + if self.dtype == torch.float16: + self.check_overflow() + + #loss scaling related computation + prev_scale = self.loss_scale + self._update_scale(self.overflow) + + if self.overflow: + self._overflow_clean_up(prev_scale) + + return self.overflow + + @instrument_w_nvtx + def _post_step(self, timer_names): + if self.offload_optimizer: + self.reset_cpu_buffers() + + #Gathering persisting parameters + if len(self.persistent_parameters) > 0: + self.persistent_parameters[0].all_gather(self.persistent_parameters) + + if self.swap_optimizer: + self.optimizer_swapper.log_timers() + + self.invalidate_secondary_tensor() + + self.timers.log(timer_names) + + see_memory_usage('After zero_optimizer step', force=False) + print_rank_0(f"------------------Finishing Step-----------------------") + + @instrument_w_nvtx + def _reassign_or_swap_out_partitioned_parameters(self, sub_group_id): + if self.fp16_partitioned_groups_flat[sub_group_id] is not None: + self.fp16_partitioned_groups_flat[sub_group_id].data.copy_( + self.fp32_partitioned_groups_flat[sub_group_id].data) + + #unflatten fp16 parameter subgroup + self._unflatten_partitioned_parameters(sub_group_id) + else: + self._partitioned_params_swap_out(sub_group_id) + + def override_loss_scale(self, loss_scale): + if loss_scale != self.external_loss_scale: + logger.info(f'[deepspeed] setting loss scale from {self.external_loss_scale} -> {loss_scale}') + self.custom_loss_scaler = True + self.external_loss_scale = loss_scale + + @instrument_w_nvtx + def step(self, closure=None): + """ + Not supporting closure. + """ + self._pre_step() + self._partition_all_parameters() + + #checks for overflow, adjust the loss scale accordingly + if self._overflow_check_and_loss_scale_update(): + if self.swap_optimizer: + self.optimizer_swapper.log_timers() + return + + norm_groups = self._get_norm_groups() + scaled_global_grad_norm = torch.linalg.norm(torch.stack(norm_groups)) + + # Stash unscaled gradient norm + self._global_grad_norm = scaled_global_grad_norm / self.loss_scale + + timer_names = set() + + timer_names.add(OPTIMIZER_STEP_TIMER) + self.timers(OPTIMIZER_STEP_TIMER).start() + + #update parameters one sub group at a time + for sub_group_id, group in enumerate(self.fp16_groups): + + #prepare optimizer states, gradients and fp32 parameters for update + self._prepare_sub_group(sub_group_id, timer_names) + + #scale the fp32 gradients + self.unscale_and_clip_grads(sub_group_id, scaled_global_grad_norm) + + #apply the optimizer step on the sub group and copy fp32 parameters to fp16 + self._optimizer_step(sub_group_id) + + #put fp16 parameters in appropriate location + self._reassign_or_swap_out_partitioned_parameters(sub_group_id) + + #release memory or swap out optimizer states of fp32 parameters + self._release_sub_group(sub_group_id, timer_names) + + self.timers(OPTIMIZER_STEP_TIMER).stop() + + self._post_step(timer_names) + + # warn user about caching allocator flushes + memory_stats = get_accelerator().memory_stats() + alloc_retries = memory_stats.get("num_alloc_retries") + if alloc_retries is None: + alloc_retries = 0 + if alloc_retries > self.n_caching_allocator_flushes: + if dist.get_rank() == 0: + logger.warning( + "%d pytorch allocator cache flushes since last step. this happens " + "when there is high memory pressure and is detrimental to " + "performance. if this is happening frequently consider adjusting " + "settings to reduce memory consumption. If you are unable to " + "make the cache flushes go away consider adding " + "get_accelerator().empty_cache() calls in your training loop to ensure " + "that all ranks flush their caches at the same time", + alloc_retries - self.n_caching_allocator_flushes) + self.n_caching_allocator_flushes = alloc_retries + + def dump_pre_step_gradients(self, debug_fp32_grads): + # Dump gradient norms for debugging + for i, _ in enumerate(self.fp16_groups): + print(f'Pre-Step Dump Norms for Group {i} FP16P, FP16G, FP32G, FP32GUC') + for fp16_param, fp32_grad in zip(self.fp16_groups[i], debug_fp32_grads[i]): + param_id = self.get_param_id(fp16_param) + fp16_grad_norm = self.debug_fp16_grads[i][param_id] + + fp32_grad_norm = [float(t.data.float().norm(2)) for t in fp32_grad] + norm_list = [fp16_grad_norm, fp32_grad_norm] + print(f'Pre-Step Norms {i} {param_id} = {norm_list}') + + def dump_post_step_gradients(self): + # Dump gradient norms for debugging + for i, group in enumerate(self.fp16_groups): + print(f'Post-Step Dump Norms for Group {i} FP16P, FP16DS, FP16FLAT, FP32FLAT') + unflat_fp16 = self.unflatten(self.fp16_groups_flat[i], self.fp16_groups[i]) + unflat_fp32 = self.unflatten(self.fp32_partitioned_groups_flat[i], self.fp16_groups[i]) + for j, p in enumerate(self.fp16_groups[i]): + param_id = self.get_param_id(p) + param_norm = float(p.data.float().norm(2)) + ds_norm = float(p.ds_tensor.data.float().norm(2)) + + unflat_norm = [float(t.data.float().norm(2)) for t in [unflat_fp16[j], unflat_fp32[j]]] + norm_list = [param_norm, ds_norm] + unflat_norm + print(f'Post-Step Norms {i} {param_id} = {norm_list}') + + @instrument_w_nvtx + def unscale_and_clip_grads(self, sub_group_id, total_norm): + # compute combined scale factor for this group + combined_scale = self.loss_scale + if self.clip_grad > 0.: + # norm is in fact norm*scale + clip = ((total_norm / self.loss_scale) + 1e-6) / self.clip_grad + clip = torch.clamp(clip, min=1.0) + combined_scale = clip * self.loss_scale + + self.fp32_partitioned_groups_flat[sub_group_id].grad.mul_(1. / combined_scale) + + def _check_overflow(self, partition_gradients=True): + self.overflow = self.has_overflow(partition_gradients) + + # `params` is a list / generator of torch.Variable + def has_overflow_serial(self, params, is_grad_list=False): + for p in params: + if p.grad is not None and self._has_inf_or_nan(p.grad.data): + return True + + return False + + def has_overflow_partitioned_grads_serial(self): + for i in range(len(self.fp16_groups)): + for j, grad in enumerate(self.averaged_gradients[i]): + if grad is not None and self._has_inf_or_nan(grad.data, j): + return True + return False + + @instrument_w_nvtx + def has_overflow(self, partition_gradients=True): + if partition_gradients: + with get_accelerator().stream(self.reduce_and_partition_stream): + if hasattr(self.inf_or_nan_tracker, "logical_or_"): + self.inf_or_nan_tracker.logical_or_(torch.isinf(self.grad_partitions_flat_buffer).any()) + self.inf_or_nan_tracker.logical_or_(torch.isnan(self.grad_partitions_flat_buffer).any()) + else: + # logical_or_ not available in older versions of pytorch + self.inf_or_nan_tracker += torch.isinf(self.grad_partitions_flat_buffer).any() + self.inf_or_nan_tracker += torch.isnan(self.grad_partitions_flat_buffer).any() + self.inf_or_nan_tracker = self.inf_or_nan_tracker > 0 + + overflow_gpu = self.inf_or_nan_tracker.clone().to(get_accelerator().current_device_name()).to( + torch.uint8) + self.inf_or_nan_tracker.zero_() + + if not get_accelerator().resolves_data_dependency(): + get_accelerator().default_stream().wait_stream(self.reduce_and_partition_stream) + dist.all_reduce(overflow_gpu, op=dist.ReduceOp.MAX, group=self.dp_process_group) + + else: + params = [] + for group in self.fp16_groups: + for param in group: + params.append(param) + + overflow = self.has_overflow_serial(params, is_grad_list=partition_gradients) + overflow_gpu = get_accelerator().ByteTensor([overflow]) + + # Since each model parallel GPU carries only part of the model, + # make sure overflow flag is synced across all the model parallel GPUs + self._model_parallel_all_reduce(tensor=overflow_gpu, op=dist.ReduceOp.MAX) + + overflow = overflow_gpu[0].item() + return bool(overflow) + + # `x` is a torch.Tensor + @staticmethod + def _has_inf_or_nan(x, j=None): + try: + # if x is half, the .float() incurs an additional deep copy, but it's necessary if + # Pytorch's .sum() creates a one-element tensor of the same type as x + # (which is true for some recent version of pytorch). + cpu_sum = float(x.float().sum()) + # More efficient version that can be used if .sum() returns a Python scalar + # cpu_sum = float(x.sum()) + except RuntimeError as instance: + # We want to check if inst is actually an overflow exception. + # RuntimeError could come from a different error. + # If so, we still want the exception to propagate. + if "value cannot be converted" not in instance.args[0]: + raise + return True + else: + if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum: + return True + return False + + @instrument_w_nvtx + def backward(self, loss, retain_graph=False): + """ + :attr:`backward` performs the following steps: + + 1. fp32_loss = loss.float() + 2. scaled_loss = fp32_loss*loss_scale + 3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's fp16 leaves + """ + if self.swap_optimizer: + self.optimizer_swapper.pre_backward() + + see_memory_usage(f"Before backward", force=False) + + if self.custom_loss_scaler: + scaled_loss = self.external_loss_scale * loss + scaled_loss.backward() + else: + self.loss_scaler.backward(loss.float(), retain_graph=retain_graph) + + self._get_param_coordinator(training=True).reset_step() + + if self.swap_optimizer: + self.optimizer_swapper.post_backward() + + def get_fp32_grad_partitions(self) -> Dict[int, Dict[int, Tensor]]: + """get fp32 gradient partition dictionary + accessed as grad_dict[parameter_group_index][parameter_index] + """ + if not get_accelerator().resolves_data_dependency(): + self.reduce_and_partition_stream.synchronize() + grad_dict = collections.defaultdict(dict) + if self.offload_optimizer: + for group in self.fp16_groups: + for param_idx, param in enumerate(group): + group_idx, dest_offset, num_elements = self.grad_position[self.get_param_id(param)] + fp32_grad = self.fp32_partitioned_groups_flat[group_idx].grad.narrow(0, dest_offset, num_elements) + grad_dict[group_idx][param_idx] = fp32_grad + else: + for group_idx, group in self.averaged_gradients.items(): + for param_idx, gradient in enumerate(group): + grad_dict[group_idx][param_idx] = gradient.float() + + return grad_dict + + def _fp32_state_allgather(self, param, fp32_state_partition): + reduce_buffer = torch.empty(self.partition_count * fp32_state_partition.numel(), + dtype=torch.float32, + device=param.device) + my_rank = dist.get_rank(group=self.dp_process_group) + partition = reduce_buffer.narrow(0, fp32_state_partition.numel() * my_rank, fp32_state_partition.numel()) + partition.data.copy_(fp32_state_partition.data, non_blocking=False) + dist.all_gather_into_tensor(reduce_buffer, partition, group=self.dp_process_group) + return reduce_buffer.narrow(0, 0, param.ds_numel).view(param.ds_shape) + + def get_fp32_grad_for_param(self, param) -> Tensor: + if not param.requires_grad: + return None + + if not get_accelerator().resolves_data_dependency(): + self.reduce_and_partition_stream.synchronize() + + if self.offload_optimizer: + group_idx, dest_offset, num_elements = self.grad_position[self.get_param_id(param)] + fp32_grad = self.fp32_partitioned_groups_flat[group_idx].grad.narrow(0, dest_offset, num_elements) + else: + fp32_grad = self.__param_id_to_grad_partition[param.ds_id].float() + + return self._fp32_state_allgather(param, fp32_grad) + + def set_fp32_grad_for_param(self, value, param): + if not param.requires_grad: + return + + if not get_accelerator().resolves_data_dependency(): + self.reduce_and_partition_stream.synchronize() + + if self.offload_optimizer: + group_idx, dest_offset, num_elements = self.grad_position[self.get_param_id(param)] + fp32_grad = self.fp32_partitioned_groups_flat[group_idx].grad.narrow(0, dest_offset, num_elements) + else: + fp32_grad = self.__param_id_to_grad_partition[param.ds_id] + + my_rank = dist.get_rank(group=self.dp_process_group) + value_partition = value.flatten().narrow(0, fp32_grad.numel() * my_rank, fp32_grad.numel()) + + fp32_grad.data.copy_(value_partition.data) + + def _get_fp32_opt_state_partition(self, param, optim_state_key=None): + if not get_accelerator().resolves_data_dependency(): + self.reduce_and_partition_stream.synchronize() + + group_idx, dest_offset, num_elements = self.grad_position[self.get_param_id(param)] + + if self._swappable_optimizer_subgroup(group_idx): + self._optimizer_states_and_gradient_swap_in(group_idx) + + fp32_param = self.fp32_partitioned_groups_flat[group_idx] + if optim_state_key is None: + fp32_opt_state = fp32_param.narrow(0, dest_offset, num_elements) + else: + fp32_opt_state = self.optimizer.state[fp32_param][optim_state_key].narrow(0, dest_offset, num_elements) + + return fp32_opt_state, group_idx + + def get_full_hp_param(self, param, optim_state_key=None) -> Tensor: + if not param.requires_grad: + return None + + fp32_opt_state, group_idx = self._get_fp32_opt_state_partition(param, optim_state_key) + hp_param = self._fp32_state_allgather(param, fp32_opt_state) + + if self._swappable_optimizer_subgroup(group_idx): + self._optimizer_states_and_gradient_swap_out(group_idx) + + return hp_param + + def set_full_hp_param(self, value, param, optim_state_key=None): + if not param.requires_grad: + return + + assert value.numel( + ) == param.ds_numel, f" Number of elements do not match: {value.numel()} != {param.ds_numel}" + + fp32_opt_state_partition, group_idx = self._get_fp32_opt_state_partition(param, optim_state_key) + my_rank = dist.get_rank(group=self.dp_process_group) + value_partition = value.flatten().narrow(0, + fp32_opt_state_partition.numel() * my_rank, + fp32_opt_state_partition.numel()) + fp32_opt_state_partition.data.copy_(value_partition.data) + + if self._swappable_optimizer_subgroup(group_idx): + self._optimizer_states_and_gradient_swap_out(group_idx) + + ### Local API START ### + + def get_local_fp32_grad_for_param(self, param) -> Tensor: + if not param.requires_grad: + return None + + if not get_accelerator().resolves_data_dependency(): + self.reduce_and_partition_stream.synchronize() + + if self.offload_optimizer: + group_idx, dest_offset, num_elements = self.grad_position[self.get_param_id(param)] + fp32_grad = self.fp32_partitioned_groups_flat[group_idx].grad.narrow(0, dest_offset, num_elements) + else: + fp32_grad = self.__param_id_to_grad_partition[param.ds_id].float() + return fp32_grad + + def set_local_grad_for_param(self, value, param): + if not param.requires_grad: + return + + assert value.numel() == param.ds_tensor.numel( + ), f" Number of elements do not match: {value.numel()} != {param.ds_tensor.ds_numel}" + + if not get_accelerator().resolves_data_dependency(): + self.reduce_and_partition_stream.synchronize() + + if self.offload_optimizer: + group_idx, dest_offset, num_elements = self.grad_position[self.get_param_id(param)] + fp32_grad = self.fp32_partitioned_groups_flat[group_idx].grad.narrow(0, dest_offset, num_elements) + else: + fp32_grad = self.__param_id_to_grad_partition[param.ds_id] + + fp32_grad.data.copy_(value.flatten().data) + + def get_local_fp32_param(self, param, optim_state_key=None) -> Tensor: + if not param.requires_grad: + return None + fp32_opt_state, group_idx = self._get_fp32_opt_state_partition(param, optim_state_key) + return fp32_opt_state + + def set_local_hp_param(self, value, param, optim_state_key=None): + if not param.requires_grad: + return + + assert hasattr(param, "ds_tensor"), f" The parameter does not contain the partitioned copy of the tensor." + assert value.numel() == param.ds_tensor.numel( + ), f" Number of elements do not match: {value.numel()} != {param.ds_tensor.ds_numel}" + + fp32_opt_state_partition, group_idx = self._get_fp32_opt_state_partition(param, optim_state_key) + value_partition = value.flatten() + fp32_opt_state_partition.data.copy_(value_partition.data) + + if self._swappable_optimizer_subgroup(group_idx): + self._optimizer_states_and_gradient_swap_out(group_idx) + # logger.info(f"[set_local_hp_param][update the params' value successfully]") + + ### Local API END ### + + @instrument_w_nvtx + def _partition_all_parameters(self): + self.parameter_offload.partition_all_parameters() + + def check_overflow(self, partition_gradients=True): + self._check_overflow(partition_gradients) + + def _update_scale(self, has_overflow=False): + self.loss_scaler.update_scale(has_overflow) + + # Promote state so it can be retrieved or set via "fp16_optimizer_instance.state" + def _get_state(self): + return self.optimizer.state + + def _set_state(self, value): + self.optimizer.state = value + + state = property(_get_state, _set_state) + + # Promote param_groups so it can be retrieved or set via "fp16_optimizer_instance.param_groups" + # (for example, to adjust the learning rate) + def _get_param_groups(self): + return self.optimizer.param_groups + + def _set_param_groups(self, value): + self.optimizer.param_groups = value + self.trainable_param_groups = self._get_trainable_parameter_groups() + + param_groups = property(_get_param_groups, _set_param_groups) + + # Promote loss scale so it can be retrieved or set via "fp16_optimizer_instance.loss_scale" + def _get_loss_scale(self): + if self.custom_loss_scaler: + return self.external_loss_scale + else: + return self.loss_scaler.cur_scale + + def _set_loss_scale(self, value): + self.loss_scaler.cur_scale = value + + loss_scale = property(_get_loss_scale, _set_loss_scale) + cur_scale = property(_get_loss_scale, _set_loss_scale) + + def _get_lean_tensors(self, padded_flattened_tensor, group_tensors, paddings): + # Remove paddings from flattened tensor + individual_tensors = self.unflatten(padded_flattened_tensor, group_tensors) + lean_lengths = [t.numel() - pad for t, pad in zip(group_tensors, paddings)] + lean_tensors = [t[:len] for t, len in zip(individual_tensors, lean_lengths)] + #logger.info(f'rank {dist.get_rank()}: lean_tensors = {[t.numel() for t in lean_tensors]}') + return lean_tensors + + #TODO REVISIT this for stage 3 + def get_lean_optimizer_state(self): + # Return optimizer states after removing paddings. + # This method assumes that each param group contains a single flattened tensor. + optimizer_groups_state = [] + + for i, group in enumerate(self.optimizer.param_groups): + p = group['params'][0] + lean_state = {} + for key, value in self.optimizer.state[p].items(): + if torch.is_tensor(value): + padded_lens = [t.numel() for t in self.fp16_partitioned_groups[i]] + lean_state[key] = self._get_lean_tensors(value, self.fp16_partitioned_groups[i], + self.groups_padding[i]) + lean_flat_len = sum([t.numel() for t in lean_state[key]]) + else: + lean_state[key] = value + + optimizer_groups_state.append(lean_state) + + return optimizer_groups_state + + def get_groups_without_padding(self, groups_with_padding): + # Return group tensor after removing paddings added for alignment to DP world size. + groups_without_padding = [] + for i, group in enumerate(groups_with_padding): + lean_group = self._get_lean_tensors(group, self.fp16_partitioned_groups[i], self.groups_padding[i]) + groups_without_padding.append(lean_group) + + return groups_without_padding + + def _set_fp32_optimizer_param_groups(self): + for sub_group_id, _ in enumerate(self.fp16_groups): + param_group_id = self.sub_group_to_group_id[sub_group_id] + self.optimizer.param_groups[param_group_id]['params'].append( + self.fp32_partitioned_groups_flat[sub_group_id]) + + def _clear_fp32_optimizer_param_groups(self): + for param_group in self.optimizer.param_groups: + param_group['params'] = [] + + def _rigid_state_dict(self): + state_dict = {} + state_dict[ZERO_STAGE] = ZeroStageEnum.weights + state_dict[LOSS_SCALER] = self.loss_scaler + state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale + state_dict['overflow'] = self.overflow + state_dict[PARTITION_COUNT] = self.partition_count + + self._set_fp32_optimizer_param_groups() + state_dict[OPTIMIZER_STATE_DICT] = self.optimizer.state_dict() + state_dict[FP32_FLAT_GROUPS] = self.fp32_partitioned_groups_flat + self._clear_fp32_optimizer_param_groups() + + return state_dict + + def state_dict(self): + """ + Returns a dict containing the current state of this :class:`FP16_Optimizer` instance. + This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict + of the contained Pytorch optimizer. + Example:: + checkpoint = {} + checkpoint['model'] = model.state_dict() + checkpoint['optimizer'] = optimizer.state_dict() + torch.save(checkpoint, "saved.pth") + """ + if self.elastic_checkpoint: + raise NotImplementedError("ZeRO-3 does not yet support elastic checkpointing, please disable for now.") + + return self._rigid_state_dict() + + +# Restore base optimizer fp32 weights from checkpoint by: +# 1) Merging fp32 weights from checkpoints of all partitions +# 2) Extracting fp32 weights for current partition from merged weights +# 3) Using extracted weights to update base optimizer weights directly. + + def _restore_from_fp32_weights(self, all_state_dict): + + flat_local_partition = [] + for i in range(len(self.fp32_partitioned_groups_flat)): + merged_partitions = [sd['fp32_groups'][i] for sd in all_state_dict] + flat_local_partition.append(self._get_flattened_partition(merged_partitions)) + + for current, saved in zip(self.fp32_partitioned_groups_flat, flat_local_partition): + current.data.copy_(saved.data) + + # Restore base optimizer fp32 weights from ZeRO fp16 weights + def _restore_from_bit16_weights(self): + for fp16_partitions, fp32_partition in zip(self.fp16_partitioned_groups_flat, + self.fp32_partitioned_groups_flat): + fp32_partition.data.copy_(fp16_partitions.data) + + # Refresh the fp32 master params from the fp16 copies. + def refresh_fp32_params(self): + self._restore_from_bit16_weights() + + # Extract flattened partition for current rank from all partitions + def _get_flattened_partition(self, all_partition_states): + partition_id = dist.get_rank(group=self.dp_process_group) + alignment = dist.get_world_size(group=self.dp_process_group) + + param_partitions = [[] for _ in range(len(all_partition_states[0]))] + for i, partition in enumerate(all_partition_states): + for j, param in enumerate(partition): + param_partitions[j].append(param) + + local_state_partitions = [] + for param_index, param_slices in enumerate(param_partitions): + flattened_merged_tensor = self.flatten_dense_tensors_aligned(param_slices, alignment) + new_partitions = self.get_data_parallel_partitions(flattened_merged_tensor) + local_state_partitions.append(new_partitions[partition_id]) + + if torch.is_tensor(local_state_partitions[0]): + return self.flatten_dense_tensors_aligned(local_state_partitions, alignment) + + # Assume non-tensor states are not partitioned and equal across ranks, so return first one + return local_state_partitions[0] + + # Restore base optimizer state from checkpoint by + # 1) Merging optimizer state from checkpoints of all partitions + # 2) Extracting optimizer state for current partition from the merged state + # 3) Using the extracted value to directly update the base optimizer. + def _restore_base_optimizer_state(self, all_state_dict): + base_optimizer_group_states = [] + for i in range(len(self.optimizer.param_groups)): + partition_states = {} + all_partition_group_states = [sd['base_optimizer_state'][i] for sd in all_state_dict] + for key in all_partition_group_states[0].keys(): + all_partition_states = [all_states[key] for all_states in all_partition_group_states] + partition_states[key] = self._get_flattened_partition(all_partition_states) + base_optimizer_group_states.append(partition_states) + + for i, group in enumerate(self.optimizer.param_groups): + p = group['params'][0] + for key, saved in base_optimizer_group_states[i].items(): + if torch.is_tensor(self.optimizer.state[p][key]): + self.optimizer.state[p][key].data.copy_(saved.data) + else: + self.optimizer.state[p][key] = saved + + def _rigid_load_state_dict(self, state_dict, load_optimizer_states=True): + # I think it should actually be ok to reload the optimizer before the model. + self.loss_scaler = state_dict[LOSS_SCALER] + self.dynamic_loss_scale = state_dict['dynamic_loss_scale'] + self.overflow = state_dict['overflow'] + + if load_optimizer_states: + self._set_fp32_optimizer_param_groups() + self.optimizer.load_state_dict(state_dict[OPTIMIZER_STATE_DICT]) + self._clear_fp32_optimizer_param_groups() + + if self.swap_optimizer or self.params_in_nvme_and_cpu: + # Purge the swapped optimizer state, it was initialized to the freshly created model and not the checkpoint + for swap_info in self.optimizer_swapper.swap_params_info.values(): + swap_info.tensors = [swap_info.tensors[0]] + swap_info.has_state_tensors = False + + if self.swap_optimizer: + # Touch all parameters to synchronize all buffers + timer_names = set() + self._partition_all_parameters() + for sub_group_id, group in enumerate(self.fp16_groups): + self._prepare_sub_group(sub_group_id, timer_names) + self._reassign_or_swap_out_partitioned_parameters(sub_group_id) + self._release_sub_group(sub_group_id, timer_names) + self._post_step(timer_names) + + # restore fp32 partitions + for curr_param, saved_param in zip(self.fp32_partitioned_groups_flat, state_dict[FP32_FLAT_GROUPS]): + curr_param.data.copy_(saved_param.data) + + # restore fp16 partitions from fp32 + for sub_group_id in range(len(self.fp32_partitioned_groups_flat)): + fp32_param = self.fp32_partitioned_groups_flat[sub_group_id] + if sum(fp32_param.size()) > 0: + fp16_param = self.fp16_partitioned_groups_flat[sub_group_id] + fp16_param.data.copy_(fp32_param.data) + + # update fp16 unflattened params + for sub_group_id in range(len(self.fp16_partitioned_groups_flat)): + updated_params = self.unflatten(self.fp16_partitioned_groups_flat[sub_group_id], + self.fp16_partitioned_groups[sub_group_id]) + + for partitioned_param, q in zip(self.fp16_partitioned_groups[sub_group_id], updated_params): + partitioned_param.data = q.data + + # TODO: Support different/changing load/save DP degree. + def load_state_dict(self, + state_dict_list, + load_optimizer_states=True, + load_from_fp32_weights=False, + checkpoint_folder=None, + load_serial=None, + param_shapes=None): + r"""Loading a ZeRO checkpoint + Arguments: + state_dict_list: List of all saved ZeRO checkpoints, one for each saved partition. + Note that the number of saved partitions may differ from number of loading partitions to support + changing GPU count, specifically DP world size, between saving and loading checkpoints. + load_optimizer_states: Boolean indicating whether or not to load base optimizer states + load_from_fp32_weights: Boolean indicating whether to initialize fp32 master weights from fp32 + copies in checkpoints (no precision loss) or from model's fp16 copies (with precision loss). + """ + """ + Loads a state_dict created by an earlier call to state_dict(). + If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``, + whose parameters in turn came from ``model``, it is expected that the user + will call ``model.load_state_dict()`` before + ``fp16_optimizer_instance.load_state_dict()`` is called. + Example:: + model = torch.nn.Linear(D_in, D_out).to(get_accelerator().device_name()).half() + optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) + optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0) + ... + checkpoint = torch.load("saved.pth") + model.load_state_dict(checkpoint['model']) + optimizer.load_state_dict(checkpoint['optimizer']) + """ + + if self.elastic_checkpoint: + raise NotImplementedError("ZeRO-3 does not yet support elastic checkpointing, please disable for now.") + + if checkpoint_folder: + self._load_universal_checkpoint(checkpoint_folder, load_optimizer_states, load_from_fp32_weights, + param_shapes) + else: + self._rigid_load_state_dict(state_dict_list[dist.get_rank(group=self.dp_process_group)], + load_optimizer_states=load_optimizer_states) + + # when use loading checkpoint serial, after finish loading, we need to + # delete the temp state_dict_list variable to save memory, then trigger + # the next rank's loading + if load_serial is not None: + load_serial += 1 + rank = dist.get_rank(group=self.dp_process_group) + local_rank = dist.get_local_rank() + del state_dict_list[rank] + rank_end = dist.get_world_size() - 1 + if local_rank != rank_end: + dist.send(tensor=load_serial, dst=rank + 1) + + if len(self.persistent_parameters) > 0: + self.persistent_parameters[0].partition(self.persistent_parameters) + # self.persistent_parameters[0].all_gather(self.persistent_parameters) # this will be done in checkpoint_event_epilogue() so remove it to prevent double all_gather + + def _load_universal_checkpoint(self, checkpoint_folder, load_optimizer_states, load_from_fp32_weights, + param_shapes): + self.load_hp_checkpoint_state_from_checkpoint_dir_stage3(checkpoint_folder, param_shapes) + + def load_hp_checkpoint_state_from_checkpoint_dir_stage3(self, checkpoint_dir, param_shapes): + """ Load optimizer and model states from the checkpoint directory. """ + checkpoint_dir = os.path.join(checkpoint_dir, "zero") + optim_state_path = os.path.join(checkpoint_dir, "optimizer_state.pt") + assert os.path.isfile( + optim_state_path), f'{optim_state_path} containing optimizer global state is missing! Cannot proceed.' + + optim_sd = torch.load(optim_state_path) + self._load_global_state_stage3(optim_sd) + + key_list = ["fp32", "exp_avg", "exp_avg_sq"] + + for key in key_list: + key_tensor = torch.empty(0) + for layer in param_shapes[0].keys(): + key_layer_state_partition = self.load_hp_checkpoint_state(os.path.join(checkpoint_dir, layer), key) + key_tensor = torch.cat((key_tensor, key_layer_state_partition)) + if key == "fp32": + self.fp32_partitioned_groups_flat[0].data.copy_(key_tensor) + self.optimizer.param_groups[0]['params'].append(self.fp32_partitioned_groups_flat[0]) + else: + optim_sd[OPTIMIZER_STATE_DICT]['state'][0][key] = key_tensor + + if self.swap_optimizer or self.params_in_nvme_and_cpu: + # Purge the swapped optimizer state, it was initialized to the freshly created model and not the checkpoint + for swap_info in self.optimizer_swapper.swap_params_info.values(): + swap_info.tensors = [swap_info.tensors[0]] + swap_info.has_state_tensors = False + + if self.swap_optimizer: + # Touch all parameters to synchronize all buffers + timer_names = set() + self._partition_all_parameters() + for sub_group_id, group in enumerate(self.fp16_groups): + self._prepare_sub_group(sub_group_id, timer_names) + self._reassign_or_swap_out_partitioned_parameters(sub_group_id) + self._release_sub_group(sub_group_id, timer_names) + self._post_step(timer_names) + + self.optimizer.load_state_dict(optim_sd[OPTIMIZER_STATE_DICT]) + for param_group in self.optimizer.param_groups: + param_group['params'] = [] + + for sub_group_id in range(len(self.fp32_partitioned_groups_flat)): + fp32_param = self.fp32_partitioned_groups_flat[sub_group_id] + if sum(fp32_param.size()) > 0: + fp16_param = self.fp16_partitioned_groups_flat[sub_group_id] + fp16_param.data.copy_(fp32_param.data) + + for sub_group_id in range(len(self.fp16_partitioned_groups_flat)): + updated_params = self.unflatten(self.fp16_partitioned_groups_flat[sub_group_id], + self.fp16_partitioned_groups[sub_group_id]) + + for partitioned_param, q in zip(self.fp16_partitioned_groups[sub_group_id], updated_params): + partitioned_param.data = q.data + + def _load_global_state_stage3(self, sd): + self.loss_scaler = sd.get(LOSS_SCALER, self.loss_scaler) + self.dynamic_loss_scale = sd.get('dynamic_loss_scale', self.dynamic_loss_scale) + self.overflow = sd.get('overflow', self.overflow) + + def load_hp_checkpoint_state(self, folder, key): + local_rank = dist.get_local_rank() + + # Load tensors from files and reshape them to flat vectors + loaded_checkpoint_state = torch.load(os.path.join(folder, f"{key}.pt")).view(-1) + + # Partition the loaded data according to the local rank + world_size = dist.get_world_size(group=self.dp_process_group) + unpartitioned_numel = loaded_checkpoint_state.numel() + partitioned_numel = math.ceil(unpartitioned_numel / world_size) + + if world_size * partitioned_numel != unpartitioned_numel: + padding_size = world_size * partitioned_numel - unpartitioned_numel + padding_tensor = torch.zeros(padding_size, dtype=loaded_checkpoint_state.dtype) + loaded_checkpoint_state = torch.cat([loaded_checkpoint_state, padding_tensor]) + checkpoint_state_partition = loaded_checkpoint_state.narrow(0, local_rank * partitioned_numel, + partitioned_numel) + + return checkpoint_state_partition + + def reset_swap_buffers(self): + timer_names = set() + for sub_group_id, group in enumerate(self.fp16_groups): + self._prepare_sub_group(sub_group_id, timer_names) + self._reassign_or_swap_out_partitioned_parameters(sub_group_id) + self._release_sub_group(sub_group_id, timer_names) + + def checkpoint_event_prologue(self): + self._partition_all_parameters() + + def checkpoint_event_epilogue(self): + if len(self.persistent_parameters) > 0: + self.persistent_parameters[0].all_gather(self.persistent_parameters) + + def empty_partition_cache(self): + self.parameter_offload.empty_partition_cache() + + def offload_states(self, + include: Container[OffloadStateTypeEnum] = None, + device: OffloadDeviceEnum = OffloadDeviceEnum.cpu, + pin_memory: bool = True, + non_blocking: bool = False): + device = device.value + + self.empty_partition_cache() + + assert self.optimizer.__class__ == deepspeed.ops.adam.fused_adam.FusedAdam, f"Offloading is supported only for DeepSpeed FusedAdam." + + def needs_offload(target): + # return True + return target not in self.offloaded_states and (include == None or target in include) + + # HP param + if needs_offload(OffloadStateTypeEnum.hp_params): + if pin_memory: + if not hasattr(self, "hp_params_pin_buffers"): + self.hp_params_pin_buffers = [ + get_accelerator().pin_memory(torch.empty_like(t, device=device)) + for t in self.fp32_partitioned_groups_flat + ] + + for src_tensor, dest_buf in zip(self.fp32_partitioned_groups_flat, self.hp_params_pin_buffers): + dest_buf.copy_(src_tensor, non_blocking=non_blocking) + src_tensor.data = dest_buf + else: + for buf in self.fp32_partitioned_groups_flat: + buf.data = buf.data.to(device, non_blocking=non_blocking) + self.offloaded_states.add(OffloadStateTypeEnum.hp_params) + + # LP param + if needs_offload(OffloadStateTypeEnum.lp_params): + if pin_memory: + if not hasattr(self, "lp_param_contiguous_pin_buffer"): + self.lp_param_contiguous_pin_buffer = get_accelerator().pin_memory( + torch.empty_like(self.lp_param_buffer, device=device)) + self.lp_param_contiguous_pin_buffer.copy_(self.lp_param_buffer, non_blocking=non_blocking) + cpu_buffer = self.lp_param_contiguous_pin_buffer + else: + cpu_buffer = self.lp_param_buffer.to(device, non_blocking=non_blocking) + + self.lp_param_buffer.data = cpu_buffer + for tensor, offset, tensor_numel in get_mapping_to_flat_buffer( + [p.ds_tensor for p in self.module.parameters()]): + tensor.data = cpu_buffer.narrow(0, offset, tensor_numel) + + self.fp16_partitioned_groups_flat.clear() + self.offloaded_states.add(OffloadStateTypeEnum.lp_params) + + # LP grad + if needs_offload(OffloadStateTypeEnum.lp_grads): + if pin_memory: + if not hasattr(self, "lp_grad_partitions_flat_pin_buffers"): + self.lp_grad_partitions_flat_pin_buffers = get_accelerator().pin_memory( + torch.empty_like(self.grad_partitions_flat_buffer, device=device)) + self.lp_grad_partitions_flat_pin_buffers.copy_(self.grad_partitions_flat_buffer, + non_blocking=non_blocking) + self.grad_partitions_flat_buffer.data = self.lp_grad_partitions_flat_pin_buffers + else: + self.grad_partitions_flat_buffer.data = self.grad_partitions_flat_buffer.data.to(device) + self.averaged_gradients = {} + + self.__param_id_to_grad_partition = {} + + self.offloaded_states.add(OffloadStateTypeEnum.lp_grads) + + # contiguous bucket + if needs_offload(OffloadStateTypeEnum.contiguous_grad_buffer): + if hasattr(self, "_DeepSpeedZeroOptimizer_Stage3__ipg_bucket_flat_buffer"): + # Record properties like shape, strides, etc. as a meta tensor + self.grad_buffer_meta = self.__ipg_bucket_flat_buffer.to("meta") + self.__ipg_bucket_flat_buffer = None + self.offloaded_states.add(OffloadStateTypeEnum.contiguous_grad_buffer) + + # Adam + if needs_offload(OffloadStateTypeEnum.optim_states): + offload_adam_states(self.optimizer, device, pin_memory=pin_memory, non_blocking=non_blocking) + self.offloaded_states.add(OffloadStateTypeEnum.optim_states) + + gc.collect() + get_accelerator().empty_cache() + + def reload_states(self, non_blocking: bool = False): + + device = get_accelerator().current_device_name() + + # HP param + if OffloadStateTypeEnum.hp_params in self.offloaded_states: + if hasattr(self, "hp_params_pin_buffers"): + for src, dest in zip(self.hp_params_pin_buffers, self.fp32_partitioned_groups_flat): + dest.data = src.to(device, non_blocking=non_blocking) + else: + for buf in self.fp32_partitioned_groups_flat: + buf.data = buf.data.to(device, non_blocking=non_blocking) + self.offloaded_states.remove(OffloadStateTypeEnum.hp_params) + + # LP Param + if OffloadStateTypeEnum.lp_params in self.offloaded_states: + cpu_buffer = self.lp_param_contiguous_pin_buffer if hasattr( + self, "lp_param_contiguous_pin_buffer") else self.lp_param_buffer + self.lp_param_buffer.data = cpu_buffer.data.to(device, non_blocking=non_blocking) + self._set_fp16_partitioned_groups_flat() + + for tensor, offset, tensor_numel in get_mapping_to_flat_buffer( + [p.ds_tensor for p in self.module.parameters()]): + tensor.data = self.lp_param_buffer.narrow(0, offset, tensor_numel) + self.offloaded_states.remove(OffloadStateTypeEnum.lp_params) + + # LP grad + if OffloadStateTypeEnum.lp_grads in self.offloaded_states: + if hasattr(self, "lp_grad_partitions_flat_pin_buffers"): + self.grad_partitions_flat_buffer.data = self.lp_grad_partitions_flat_pin_buffers.to( + device, non_blocking=non_blocking) + else: + self.grad_partitions_flat_buffer.data = self.grad_partitions_flat_buffer.data.to( + device, non_blocking=non_blocking) + self.averaged_gradients = {} + + offset = 0 + all_params = list(itertools.chain.from_iterable(self.fp16_groups)) + for param in all_params: + self.__param_id_to_grad_partition[param.ds_id] = self.grad_partitions_flat_buffer.narrow( + 0, offset, param.partition_numel()) + offset += param.partition_numel() + + self.offloaded_states.remove(OffloadStateTypeEnum.lp_grads) + + # contiguous bucket + if OffloadStateTypeEnum.contiguous_grad_buffer in self.offloaded_states: + self.__ipg_bucket_flat_buffer = torch.empty_like(self.grad_buffer_meta, device=device) + # self.__ipg_bucket_flat_buffer.data = self.__ipg_bucket_flat_buffer.data.to(device) + self.offloaded_states.remove(OffloadStateTypeEnum.contiguous_grad_buffer) + + # Adam + if OffloadStateTypeEnum.optim_states in self.offloaded_states: + reload_adam_states(self.optimizer, device, non_blocking=non_blocking) + self.offloaded_states.remove(OffloadStateTypeEnum.optim_states) + + if non_blocking: + get_accelerator().synchronize() + + +def _handle_overflow(cpu_sum, x, i): + import math + rank = dist.get_rank() + if rank == 0: + t_i = -1 + for v_i, v in enumerate(x.data.contiguous().view(-1)): + if not math.isfinite(float(v)): + t_i = v_i + break + logger.info(f"rank {rank} detected overflow {cpu_sum} in tensor {i}:{t_i} shape {x.shape}") + + +def estimate_zero3_model_states_mem_needs(total_params, + largest_layer_params, + num_gpus_per_node=1, + num_nodes=1, + cpu_offload=True, + cpu_offload_params=True, + zero_init=True, + additional_buffer_factor=1.5): + + total_gpus = num_nodes * num_gpus_per_node + gpus_factor = 1 / num_nodes + largest_layer_memory = (4 * largest_layer_params) + + if cpu_offload: + if cpu_offload_params: + gpu_mem = largest_layer_memory + + if zero_init: + cpu_mem = total_params * 18 * gpus_factor * additional_buffer_factor + else: + cpu_mem = total_params * max(4 * num_gpus_per_node, 18 * gpus_factor) * additional_buffer_factor + + else: + gpu_mem = largest_layer_memory + int(2 * total_params / total_gpus) + + if zero_init: + cpu_mem = total_params * 16 * gpus_factor * additional_buffer_factor + else: + cpu_mem = total_params * max(4 * num_gpus_per_node, 16 * gpus_factor) * additional_buffer_factor + else: + gpu_mem = largest_layer_memory + int(18 * total_params / total_gpus) + if zero_init: + cpu_mem = largest_layer_params * 4 * num_gpus_per_node * additional_buffer_factor + else: + cpu_mem = total_params * 4 * num_gpus_per_node * additional_buffer_factor + + return int(cpu_mem), int(gpu_mem), largest_layer_memory + + +def model_to_params(model): + # shared params calculated only once + total_params = sum(dict((p.data_ptr(), p.numel()) for p in model.parameters()).values()) + + largest_layer_params = 0 + for m in model.modules(): + # assuming no shared params within a single layer + layer_params = sum(p.numel() for p in m.parameters(recurse=False)) + largest_layer_params = max(largest_layer_params, layer_params) + + return total_params, largest_layer_params + + +def estimate_zero3_model_states_mem_needs_all_live(model, + num_gpus_per_node=1, + num_nodes=1, + additional_buffer_factor=1.5): + """ + Print out estimates on memory usage requirements for ZeRO 3 params, optim states and gradients + for a given ``model`` and hardware setup. + + If you have an actual model object, use this function and everything will be derived + automatically. + + If it's a hypothetical model, use ``estimate_zero3_model_states_mem_needs_all_cold`` where you have to pass + the ``total_params`` and ``largest_layer_params`` explicitly. + + Args: + - ``model``: ``nn.Module`` object + - ``num_gpus_per_node``: how many gpus per node (defaults to 1) + - ``num_nodes``: how many nodes (defaults to 1), + - ``additional_buffer_factor``: estimation factor (defaults to 1.5): + + """ + + total_params, largest_layer_params = model_to_params(model) + + estimate_zero3_model_states_mem_needs_all_cold(total_params=total_params, + largest_layer_params=largest_layer_params, + num_gpus_per_node=num_gpus_per_node, + num_nodes=num_nodes, + additional_buffer_factor=additional_buffer_factor) + + +def estimate_zero3_model_states_mem_needs_all_cold(total_params, + largest_layer_params, + num_gpus_per_node=1, + num_nodes=1, + additional_buffer_factor=1.5): + """ + Print out estimates on memory usage requirements for ZeRO 3 params, optim states and gradients + for a given ``model`` and hardware setup. + + If it's a hypothetical model, use this function where you have to pass + the ``total_params`` and ``largest_layer_params`` explicitly. + + If you have an actual model object, use ``estimate_zero3_model_states_mem_needs_all_live`` and everything + will be derived automatically. + + Args: + - ``total_params``: total model params + - ``largest_layer_params``: largest layer's params + - ``num_gpus_per_node``: how many gpus per node (defaults to 1) + - ``num_nodes``: how many nodes (defaults to 1), + - ``additional_buffer_factor``: estimation factor (defaults to 1.5): + + """ + + def format_options(cpu_offload, cpu_offload_params, zero_init): + enabled = [] + padded_cpu_str = f'{OffloadDeviceEnum.cpu:4}' + param_device = padded_cpu_str if cpu_offload_params else "none" + enabled.append(f"offload_param={param_device}") + optimizer_device = padded_cpu_str if cpu_offload else "none" + enabled.append(f"offload_optimizer={optimizer_device}") + enabled.append(f"zero_init={1 if zero_init else 0}") + return ", ".join(enabled) + + nodes_str = "nodes" if num_nodes > 1 else "node" + gpus_str = "GPUs" if num_gpus_per_node > 1 else "GPU" + print( + "Estimated memory needed for params, optim states and gradients for a:\n" + f"HW: Setup with {num_nodes} {nodes_str}, {num_gpus_per_node} {gpus_str} per node.\n" + f"SW: Model with {int(total_params/1e6)}M total params, {int(largest_layer_params/1e6)}M largest layer params." + ) + print(" per CPU | per GPU | Options") + for cpu_offload in [True, False]: + for cpu_offload_params in [True, False]: + if not cpu_offload and cpu_offload_params: + continue + for zero_init in [True, False]: + cpu_mem, gpu_mem, largest_layer_memory = estimate_zero3_model_states_mem_needs( + total_params=total_params, + largest_layer_params=largest_layer_params, + num_gpus_per_node=num_gpus_per_node, + num_nodes=num_nodes, + cpu_offload=cpu_offload, + cpu_offload_params=cpu_offload_params, + zero_init=zero_init, + additional_buffer_factor=additional_buffer_factor) + + options_str = format_options(cpu_offload=cpu_offload, + cpu_offload_params=cpu_offload_params, + zero_init=zero_init) + print(f" {cpu_mem/2**30:7.2f}GB | {gpu_mem/2**30:6.2f}GB | {options_str}") diff --git a/toolbox/DeepSpeed/v0.15.3/patches/deepspeed/runtime/zero/stage_1_and_2.py b/toolbox/DeepSpeed/v0.15.3/patches/deepspeed/runtime/zero/stage_1_and_2.py new file mode 100644 index 0000000000000000000000000000000000000000..1ec6c3badf95eb84af7b16019828d8bc8e9aa800 --- /dev/null +++ b/toolbox/DeepSpeed/v0.15.3/patches/deepspeed/runtime/zero/stage_1_and_2.py @@ -0,0 +1,2534 @@ +#!/usr/bin/env python3 +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +from deepspeed import comm as dist +from packaging import version as pkg_version +from collections import OrderedDict +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors + +from deepspeed.runtime.base_optimizer import ZeROOptimizer +from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler +from deepspeed.runtime.utils import (empty_cache, see_memory_usage, inf, is_model_parallel_parameter, + align_dense_tensors, all_gather_dp_groups) +from deepspeed.runtime.zero.config import ZeroStageEnum +from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum +from deepspeed.ops.adam import DeepSpeedCPUAdam +from deepspeed.utils import logger +from deepspeed.utils.bwc import bwc_tensor_model_parallel_rank +from deepspeed.moe.utils import is_moe_param +from deepspeed.git_version_info import version + +from deepspeed.runtime.constants import PIPE_REPLICATED +from deepspeed.accelerator import get_accelerator + +from deepspeed.checkpoint.constants import (DS_VERSION, GROUP_PADDINGS, PARTITION_COUNT, LOSS_SCALER, + SINGLE_PARTITION_OF_FP32_GROUPS, BASE_OPTIMIZER_STATE, + BASE_OPTIMIZER_STATE_STEP, CLIP_GRAD, ZERO_STAGE, PARAM_SLICE_MAPPINGS) +from deepspeed.utils import link_hp_params, lazy_init_hp_params_optimizer_state +from deepspeed.checkpoint import enable_universal_checkpoint + +from deepspeed.utils import groups +# Toggle this to true to enable correctness test +# with gradient partitioning and without +pg_correctness_test = False + +OPTIMIZER_ALLGATHER_TIMER = 'optimizer_allgather' +OPTIMIZER_GRADIENTS_TIMER = 'optimizer_gradients' +OPTIMIZER_STEP_TIMER = 'optimizer_step' +OPTIMIZER_TIMERS = [OPTIMIZER_ALLGATHER_TIMER, OPTIMIZER_GRADIENTS_TIMER, OPTIMIZER_STEP_TIMER] +INITIAL_MICRO_STEP_ID = -1 + + +def input(msg): + return + + +def split_half_float_double(tensors): + device_type = get_accelerator().device_name() + dtypes = [ + "torch.{}.HalfTensor".format(device_type), "torch.{}.FloatTensor".format(device_type), "torch.{}.BFloat16Tensor".format(device_type) + ] + buckets = [] + for i, dtype in enumerate(dtypes): + bucket = [t for t in tensors if t.type() == dtype] + if bucket: + buckets.append(bucket) + return buckets + + +def isclose(a, b, rtol=1e-09, atol=0.0): + return abs(a - b) <= max(rtol * max(abs(a), abs(b)), atol) + + +def lcm(x, y): + from fractions import gcd # or can import gcd from `math` in Python 3 + return x * y // gcd(x, y) + + +def get_alignment_padding(tensor_list, alignment): + num_elements = sum([tensor.numel() for tensor in tensor_list]) + remainder = num_elements % alignment + return (alignment - remainder) if remainder else remainder + + +def print_rank_msg(msg): + print(f"rank {dist.get_rank()} - {msg}") + + +def _get_padded_tensor(src_tensor, size): + if src_tensor.numel() >= size: + return src_tensor + padded_tensor = torch.zeros(size, dtype=src_tensor.dtype, device=src_tensor.device) + slice_tensor = torch.narrow(padded_tensor, 0, 0, src_tensor.numel()) + slice_tensor.data.copy_(src_tensor.data) + return padded_tensor + + +def _pad_tensor_by_size(src_tensor, pad_size, dtype, device): + padded_tensor = torch.zeros(src_tensor.numel() + pad_size, dtype=dtype, device=device) + padded_tensor.data[:src_tensor.numel()].copy_(src_tensor.data) + return padded_tensor + + +class DeepSpeedZeroOptimizer(ZeROOptimizer): + """ + DeepSpeedZeroOptimizer designed to reduce the memory footprint + required for training large deep learning models. + + For more details please see ZeRO: Memory Optimization Towards Training A Trillion Parameter Models + https://arxiv.org/abs/1910.02054 + + For usage examples, refer to TODO: DeepSpeed Tutorial + + """ + + def __init__(self, + init_optimizer, + param_names, + timers, + static_loss_scale=1.0, + dynamic_loss_scale=False, + dynamic_loss_args=None, + verbose=True, + contiguous_gradients=True, + reduce_bucket_size=500000000, + use_multi_rank_bucket_allreduce=True, + allgather_bucket_size=5000000000, + dp_process_group=None, + expert_parallel_group=None, + expert_data_parallel_group=None, + reduce_scatter=True, + overlap_comm=False, + offload_optimizer_config=None, + mpu=None, + clip_grad=0.0, + gradient_accumulation_dtype=torch.float32, + communication_data_type=torch.float16, + postscale_gradients=True, + gradient_predivide_factor=1.0, + gradient_accumulation_steps=1, + ignore_unused_parameters=True, + partition_grads=True, + round_robin_gradients=False, + has_moe_layers=False, + fp16_master_weights_and_gradients=False, + elastic_checkpoint=False): + + if offload_optimizer_config is not None and offload_optimizer_config.device != OffloadDeviceEnum.none: + self.cpu_offload = True + self.cpu_offload_pin_memory = offload_optimizer_config.pin_memory + else: + self.cpu_offload = False + self.cpu_offload_pin_memory = False + + if dist.get_rank() == 0: + logger.info(f"Reduce bucket size {reduce_bucket_size}") + logger.info(f"Allgather bucket size {allgather_bucket_size}") + logger.info(f"CPU Offload: {self.cpu_offload}") + logger.info(f'Round robin gradient partitioning: {round_robin_gradients}') + # The fused optimizer does all the work. We need this layer for two reason: + # 1. maintain same user API from apex.fp16_utils + # 2. keep common stuff here in case we need to add ne552w fused optimizer later + + self.elastic_checkpoint = elastic_checkpoint + self.param_names = param_names + self.mpu = mpu + # differences from apex.fp16_utils: + # - assume all model params in fp16 + # - assume all params requires grad + # - flat by groups, not keeping state. TODO: remove state explicitly? + # - master grad and unflat master weight never exist. TODO: a way to save out unflat master? + if not get_accelerator().is_available(): + raise SystemError("Accelerator is not detected, cannot perform low precision training (e.g., fp16, bf16).") + self.optimizer = init_optimizer + + # Use torch (un)flatten ops + self.flatten = _flatten_dense_tensors + self.unflatten = _unflatten_dense_tensors + + # ZeRO stage 1 (False) or 2 (True) + self.partition_gradients = partition_grads + self.zero_stage_string = "ZeRO-2" if partition_grads else "ZeRO-1" + + self.timers = timers + + self.reduce_scatter = reduce_scatter + + self.overlap_comm = overlap_comm + + self.deepspeed_adam_offload = self.cpu_offload + + self.device = get_accelerator().current_device_name() if not self.cpu_offload else 'cpu' + + self.dp_process_group = dp_process_group + self.sequence_parallel_size = groups._get_sequence_parallel_world_size() + #expert parallel group + self.ep_process_group = expert_parallel_group + + #data parallel group for experts + self.expert_dp_process_group = expert_data_parallel_group + + #data parallel size for non-experts + dp_size = dist.get_world_size(group=self.dp_process_group) + + #For MoE models this maybe different for different param group + #It will be modified during MoE setup later in the init + self.real_dp_process_group = [dp_process_group for i in range(len(self.optimizer.param_groups))] + self.partition_count = [dp_size for i in range(len(self.optimizer.param_groups))] + + self.is_gradient_accumulation_boundary = True + + # CPU-Offload requires contiguous gradients + self.contiguous_gradients = contiguous_gradients or self.cpu_offload + + self.has_moe_layers = has_moe_layers + if self.has_moe_layers: + self._configure_moe_settings() + self._global_grad_norm = 0. + + if mpu is None: + self.model_parallel_group = None + self.model_parallel_world_size = 1 + self.model_parallel_rank = 0 + else: + self.model_parallel_group = mpu.get_model_parallel_group() + self.model_parallel_world_size = mpu.get_model_parallel_world_size() + self.model_parallel_rank = bwc_tensor_model_parallel_rank(mpu) + + self.overflow = False + self.clip_grad = clip_grad + self.communication_data_type = communication_data_type + self.gradient_predivide_factor = gradient_predivide_factor + self.postscale_gradients = postscale_gradients + self.gradient_accumulation_steps = gradient_accumulation_steps + self.micro_step_id = INITIAL_MICRO_STEP_ID + self.ignore_unused_parameters = ignore_unused_parameters + self.round_robin_gradients = round_robin_gradients + + self.extra_large_param_to_reduce = None + self.fp16_master_weights_and_gradients = fp16_master_weights_and_gradients + + if self.fp16_master_weights_and_gradients: + assert self.cpu_offload and type(self.optimizer) in [DeepSpeedCPUAdam], \ + f"fp16_master_and_gradients requires optimizer to support keeping fp16 master and gradients while keeping the optimizer states in fp32."\ + f"Currently only supported using ZeRO-Offload with DeepSpeedCPUAdam. But current setting is ZeRO-Offload:{self.cpu_offload} and optimizer type {type(self.optimizer)}." \ + f"Either disable fp16_master_weights_and_gradients or enable {self.zero_stage_string} Offload with DeepSpeedCPUAdam." + + if self.reduce_scatter and self.partition_gradients: + valid_reduce_scatter_dtypes = (torch.float16, torch.bfloat16, torch.float32) + assert self.communication_data_type in valid_reduce_scatter_dtypes, f"{self.zero_stage_string} supports {valid_reduce_scatter_dtypes} communication_data_type with reduce scatter enabled. Got: '{self.communication_data_type}'" + assert self.gradient_predivide_factor == 1.0, f"gradient_predivide_factor != 1.0 is not yet supported with {self.zero_stage_string} with reduce scatter enabled" + assert self.postscale_gradients, f"pre-scale gradients is not yet supported with {self.zero_stage_string} with reduce scatter enabled" + + # param flattened by groups + self.bit16_groups = [] + self.bit16_groups_flat = [] + + # param partitioned by data parallel degree + # this will contain a list of equal sized tensors + # each of which will be updated by a different process + self.parallel_partitioned_bit16_groups = [] + + # a single 32-bit partition of the parallel partitioned parameters + # that this process will update + self.single_partition_of_fp32_groups = [] + + # param partition info + + # These are the parameters in each group that will not be updated by this process directly + self.params_not_in_partition = [] + + # These are the parameters that will be updated by this process directly + self.params_in_partition = [] + + # Offset from the first parameter in the self.params_in_partition + # the parameter boundaries may not align with partition boundaries + # so we need to keep track of the offset + self.first_offset = [] + + # number of elements per partition in each group + self.partition_size = [] + + # align nccl all-gather send buffers to 4-byte boundary + self.nccl_start_alignment_factor = 2 # 4-byte alignment/sizeof(fp16) = 2 + + assert ( + allgather_bucket_size % self.nccl_start_alignment_factor == 0 + ), f"allgather_bucket_size must be a multiple of nccl_start_alignment_factor, {self.nccl_start_alignment_factor} " + + self.all_reduce_print = False + self.dtype = self.optimizer.param_groups[0]['params'][0].dtype + self.gradient_accumulation_dtype = gradient_accumulation_dtype + + if self.dtype != self.gradient_accumulation_dtype: + self.use_separate_grad_accum = True + else: + self.use_separate_grad_accum = False + if self.use_separate_grad_accum and not self.partition_gradients: + self.use_grad_accum_attribute = True + else: + self.use_grad_accum_attribute = False + + self.round_robin_bit16_groups = [] + self.round_robin_bit16_indices = [] + self.round_robin_bit16_meta = [] + + # Use different parallel to do all_to_all_reduce related things + # padding on each partition for alignment purposes + self.groups_padding = [] + # loop to deal with groups + for i, param_group in enumerate(self.optimizer.param_groups): + partition_id = dist.get_rank(group=self.real_dp_process_group[i]) + + # push this group to list before modify + # TODO: Explore simplification that avoids the extra book-keeping by pushing the reordered group + trainable_parameters = [] + for param in param_group['params']: + if param.requires_grad: + param.grad_accum = None + trainable_parameters.append(param) + self.bit16_groups.append(trainable_parameters) + + # not sure why apex was cloning the weights before flattening + # removing cloning here + + see_memory_usage(f"Before moving param group {i} to CPU") + # move all the parameters to cpu to free up GPU space for creating flat buffer + + # Create temp CPU param copies, free accelerator tensors + orig_group_numel = 0 + for param in self.bit16_groups[i]: + orig_group_numel += param.numel() + param.cpu_data = param.data.cpu() + param.data = torch.empty(1).to(param.device) + + empty_cache() + see_memory_usage(f"After moving param group {i} to CPU", force=False) + + # Reorder group parameters for load balancing of gradient partitioning during backward among ranks. + # This ensures that gradients are reduced in a fashion such that ownership round robins among the ranks. + # For example, rather than 3 gradients (g_n+2, g_n+1, g_n) that are reduced consecutively belonging + # to the same rank, instead they will belong to 3 ranks (r_m+2, r_m+1, r_m). + if self.round_robin_gradients: + round_robin_tensors, round_robin_indices = self._round_robin_reorder( + self.bit16_groups[i], dist.get_world_size(group=self.real_dp_process_group[i])) + else: + round_robin_tensors = self.bit16_groups[i] + round_robin_indices = list(range(len(self.bit16_groups[i]))) + + self.round_robin_bit16_groups.append(round_robin_tensors) + self.round_robin_bit16_indices.append(round_robin_indices) + + # Create meta tensors list, ordered according to round_robin_tensors + meta_tensors = [] + for param in round_robin_tensors: + meta_tensors.append(torch.zeros_like(param.cpu_data, device="meta")) + self.round_robin_bit16_meta.append(meta_tensors) + + # create flat buffer in CPU + flattened_buffer = self.flatten_dense_tensors_aligned( + self.round_robin_bit16_groups[i], + self.nccl_start_alignment_factor * dist.get_world_size(group=self.real_dp_process_group[i]), + use_cpu_data=True) + + # free temp CPU params + for param in self.bit16_groups[i]: + del param.cpu_data + + # Move CPU flat tensor to the accelerator memory. + self.bit16_groups_flat.append(flattened_buffer.to(get_accelerator().current_device_name())) + del flattened_buffer + + see_memory_usage(f"After flattening and moving param group {i} to GPU", force=False) + + # Record padding required for alignment + if partition_id == dist.get_world_size(group=self.real_dp_process_group[i]) - 1: + padding = self.bit16_groups_flat[i].numel() - orig_group_numel + else: + padding = 0 + self.groups_padding.append(padding) + + if dist.get_rank(group=self.real_dp_process_group[i]) == 0: + see_memory_usage(f"After Flattening and after emptying param group {i} cache", force=False) + + # set model bit16 weight to slices of flattened buffer + self._update_model_bit16_weights(i) + + # divide the flat weights into near equal partition equal to the data parallel degree + # each process will compute on a different part of the partition + data_parallel_partitions = self.get_data_parallel_partitions(self.bit16_groups_flat[i], i) + self.parallel_partitioned_bit16_groups.append(data_parallel_partitions) + + # verify that data partition start locations are 4-byte aligned + for partitioned_data in data_parallel_partitions: + assert (partitioned_data.data_ptr() % (2 * self.nccl_start_alignment_factor) == 0) + + # A partition of the fp32 master weights that will be updated by this process. + # Note that the params in single_partition_of_fp32_groups is cloned and detached + # from the origin params of the model. + if not fp16_master_weights_and_gradients: + weights_partition = self.parallel_partitioned_bit16_groups[i][partition_id].to( + self.device).clone().float().detach() + else: + weights_partition = self.parallel_partitioned_bit16_groups[i][partition_id].to( + self.device).clone().half().detach() + + if self.cpu_offload: + weights_partition = get_accelerator().pin_memory(weights_partition) + + self.single_partition_of_fp32_groups.append(weights_partition) + + # Set local optimizer to have flat params of its own partition. + # After this, the local optimizer will only contain its own partition of params. + # In that case, the local optimizer only saves the states(momentum, variance, etc.) related to its partition's params(zero stage1). + self.single_partition_of_fp32_groups[ + i].requires_grad = True # keep this in case internal optimizer uses it + param_group['params'] = [self.single_partition_of_fp32_groups[i]] + + partition_size = len(self.bit16_groups_flat[i]) / dist.get_world_size(group=self.real_dp_process_group[i]) + params_in_partition, params_not_in_partition, first_offset = self.get_partition_info( + self.round_robin_bit16_groups[i], partition_size, partition_id) + + self.partition_size.append(partition_size) + self.params_in_partition.append(params_in_partition) + self.params_not_in_partition.append(params_not_in_partition) + self.first_offset.append(first_offset) + + self.reduce_bucket_size = int(reduce_bucket_size) + self.use_multi_rank_bucket_allreduce = use_multi_rank_bucket_allreduce + self.allgather_bucket_size = int(allgather_bucket_size) + + self.reduction_stream = None if get_accelerator().is_synchronized_device() else get_accelerator().Stream() + #self.copy_grad_stream = get_accelerator().Stream() + self.callback_queued = False + + self.param_dict = {} + + # map between param_id and bool to specify if a param is in this partition + self.is_param_in_current_partition = {} + + self.grads_in_ipg_bucket = [] + self.params_in_ipg_bucket = [] + self.elements_in_ipg_bucket = 0 + self.params_already_reduced = [] + self._release_ipg_buffers() + self.previous_reduced_grads = None + self.ipg_bucket_has_moe_params = False + + # simplified param id + self.param_id = {} + + #interesting code: unique ids being assigned to individual parameters + largest_param_numel = 0 + count = 0 + for i, params_group in enumerate(self.bit16_groups): + for param in params_group: + unique_id = id(param) + self.param_id[unique_id] = count + self.param_dict[count] = param + self.params_already_reduced.append(False) + if param.numel() > largest_param_numel: + largest_param_numel = param.numel() + count = count + 1 + + for param_group in self.params_in_partition: + for param in param_group: + self.is_param_in_current_partition[self.get_param_id(param)] = True + + for param_group in self.params_not_in_partition: + for param in param_group: + self.is_param_in_current_partition[self.get_param_id(param)] = False + + if self.cpu_offload: + self.accumulated_grads_in_cpu = {} + self.norm_for_param_grads = {} + self.local_overflow = False + self.grad_position = {} + self.temp_grad_buffer_for_cpu_offload = torch.zeros(largest_param_numel, + device=self.device, + dtype=self.dtype) + if self.cpu_offload_pin_memory: + self.temp_grad_buffer_for_cpu_offload = get_accelerator().pin_memory( + self.temp_grad_buffer_for_cpu_offload) + self.temp_grad_buffer_for_gpu_offload = torch.zeros(largest_param_numel, + device=get_accelerator().current_device_name(), + dtype=self.dtype) + for i, params_group in enumerate(self.bit16_groups): + self.get_grad_position(i, self.params_in_partition[i], self.first_offset[i], self.partition_size[i]) + + # mapping from parameter to partition that it belongs to + self.param_to_partition_ids = {} + + # stores if a partition has been reduced in this step + self.is_partition_reduced = {} + + # number of grads in partition that still need to be computed + self.remaining_grads_in_partition = {} + + # total number of grads in partition + self.total_grads_in_partition = {} + + # stores if a grad in a partition has been computed or not + self.is_grad_computed = {} + + # stores the offset at which a parameter gradient needs to be inserted in a partition + self.grad_partition_insertion_offset = {} + + # the offset in the gradient at which it must be inserted at the beginning of the partition + self.grad_start_offset = {} + + # will store the averaged gradients required by this partition + self.averaged_gradients = {} + + # For cpu_offload, will store the averaged gradients required by this partition + self.offload_gradient_dict = {} + + # store index of first parameter in each partition + self.first_param_index_in_partition = {} + + # initializes all data structures for implementing gradient partitioning + self.initialize_gradient_partitioning_data_structures() + + # resets the data structure value for the next backward propagation + self.reset_partition_gradient_structures() + + # creates backward hooks for gradient partitioning + self._grad_acc_hooks = [] + if self.partition_gradients or self.overlap_comm: + self.create_reduce_and_remove_grad_hooks() + + self.custom_loss_scaler = False + self.external_loss_scale = None + + # we may have a way of fusing dynamic scale. Do not support for now + self.loss_scaler = CreateLossScaler(dtype=self.dtype, + static_loss_scale=static_loss_scale, + dynamic_scaling=dynamic_loss_scale, + dynamic_loss_args=dynamic_loss_args) + self.dynamic_loss_scale = self.loss_scaler.dynamic + + if self.dtype != torch.float16: + # Only fp16 should use dynamic loss scaling + assert self.loss_scaler.cur_scale == 1.0 + assert not self.dynamic_loss_scale + + see_memory_usage("Before initializing optimizer states", force=True) + self.initialize_optimizer_states() + see_memory_usage("After initializing optimizer states", force=True) + + if dist.get_rank() == 0: + logger.info(f"optimizer state initialized") + + if dist.get_rank(group=self.dp_process_group) == 0: + see_memory_usage(f"After initializing ZeRO optimizer", force=True) + + self._link_all_hp_params() + self._hp_optimizer_states_linked = False + + self._enable_universal_checkpoint() + self._param_slice_mappings = self._create_param_mapping() + + def destroy(self): + for i, _ in enumerate(self.optimizer.param_groups): + for p in self.bit16_groups[i]: + if getattr(p, '_hp_mapping', None): + p._hp_mapping = None + for hook in self._grad_acc_hooks: + hook.remove() + self.print_rank_0("Removed grad acc hooks") + + def _enable_universal_checkpoint(self): + for lp_param_group in self.bit16_groups: + enable_universal_checkpoint(param_list=lp_param_group) + + def _create_param_mapping(self): + param_mapping = [] + for i, _ in enumerate(self.optimizer.param_groups): + param_mapping_per_group = OrderedDict() + for lp in self.bit16_groups[i]: + if lp._hp_mapping is not None: + lp_name = self.param_names[lp] + param_mapping_per_group[lp_name] = lp._hp_mapping.get_hp_fragment_address() + param_mapping.append(param_mapping_per_group) + + return param_mapping + + def _link_all_hp_params(self): + if self.cpu_offload: + self._get_offload_gradient_dict() + + for i, _ in enumerate(self.optimizer.param_groups): + # Link bit16 and fp32 params in partition + partition_id = dist.get_rank(group=self.real_dp_process_group[i]) + partition_size = self.bit16_groups_flat[i].numel() // dist.get_world_size( + group=self.real_dp_process_group[i]) + flat_hp_partition = self.single_partition_of_fp32_groups[i] + link_hp_params(lp_param_list=self.bit16_groups[i], + flat_hp_partition=flat_hp_partition, + gradient_dict=self.averaged_gradients, + offload_gradient_dict=self.offload_gradient_dict, + use_offload=self.cpu_offload, + param_group_index=i, + partition_start=partition_id * partition_size, + partition_size=partition_size, + dp_group=self.real_dp_process_group[i]) + + def _lazy_init_hp_params_optimizer_state(self): + if not self._hp_optimizer_states_linked: + for i, _ in enumerate(self.optimizer.param_groups): + lazy_init_hp_params_optimizer_state(self.bit16_groups[i], self.single_partition_of_fp32_groups[i], + self.optimizer.state) + self._hp_optimizer_states_linked = True + + def is_moe_group(self, group): + return 'moe' in group and group['moe'] + + def _configure_moe_settings(self): + # if we're using ZeRO stage 2, ensure contiguous gradients are used + if self.partition_gradients: + assert self.contiguous_gradients, "Contiguous Gradients in ZeRO Stage 2 must be set to True for MoE. Other code paths are not tested with MoE" + # NOTE: To run ZeRO stage 1 with MoE, we need to set self.contiguous_gradients to True or ignore the assertion + if not self.partition_gradients and not self.contiguous_gradients: + logger.warn( + "ZeRO Stage 1 has not been thoroughly tested with MoE. This configuration is still experimental.") + assert self.reduce_scatter, "Reduce Scatter in ZeRO Stage 2 must be set to True for MoE. Other code paths are not tested with MoE" + + assert any( + [self.is_moe_group(group) for group in self.optimizer.param_groups] + ), "The model has moe layers, but None of the param groups are marked as MoE. Create a param group with 'moe' key set to True before creating optimizer" + self.is_moe_param_group = [] + for i, group in enumerate(self.optimizer.param_groups): + if self.is_moe_group(group): + assert all([is_moe_param(param) + for param in group['params']]), "All params in MoE group must be MoE params" + self.real_dp_process_group[i] = self.expert_dp_process_group[group['name']] + self.partition_count[i] = dist.get_world_size(group=self.expert_dp_process_group[group['name']]) + self.is_moe_param_group.append(True) + else: + self.is_moe_param_group.append(False) + + assert self.expert_dp_process_group is not None, "Expert data parallel group should be configured with MoE" + assert self.ep_process_group is not None, "Expert parallel group should be configured with MoE" + + def _update_model_bit16_weights(self, group_index): + updated_params = self.unflatten(self.bit16_groups_flat[group_index], self.round_robin_bit16_meta[group_index]) + for p, q in zip(self.round_robin_bit16_groups[group_index], updated_params): + p.data = q.data + + # set model fp16 weight to slices of reordered flattened buffer + for param_index, param in enumerate(self.bit16_groups[group_index]): + new_index = self.round_robin_bit16_indices[group_index][param_index] + param.data = self.round_robin_bit16_groups[group_index][new_index].data + + def _round_robin_reorder(self, tensor_list, num_partitions): + + # disable round robin if need to debug something + # return tensor_list, list(range(len(tensor_list))) + + partition_tensors = {} + + for i, tensor in enumerate(tensor_list): + j = i % num_partitions + if not j in partition_tensors: + partition_tensors[j] = [] + partition_tensors[j].append((i, tensor)) + + reordered_tensors = [] + reordered_indices = {} + + for partition_index in partition_tensors.keys(): + for i, (original_index, tensor) in enumerate(partition_tensors[partition_index]): + reordered_indices[original_index] = len(reordered_tensors) + reordered_tensors.append(tensor) + + return reordered_tensors, reordered_indices + + def _release_ipg_buffers(self): + if self.contiguous_gradients: + self.ipg_buffer = None + self.grads_in_partition = None + self.grads_in_partition_offset = 0 + + def initialize_optimizer_states(self): + + for i, group in enumerate(self.bit16_groups): + single_grad_partition = torch.zeros(int(self.partition_size[i]), + dtype=self.single_partition_of_fp32_groups[i].dtype, + device=self.device) + self.single_partition_of_fp32_groups[i].grad = get_accelerator().pin_memory( + single_grad_partition) if self.cpu_offload_pin_memory else single_grad_partition + + # Initialize the optimizer states with the flattened fp32 partition. + # State initialization for the Adagrad optimizer occurs at construction as opposed to other optimizers + # which do lazy initialization of the state at the first call to step. + if isinstance(self.optimizer, torch.optim.Adagrad): + self.optimizer = torch.optim.Adagrad(self.single_partition_of_fp32_groups, **self.optimizer.defaults) + + if not self.cpu_offload: + for group in self.single_partition_of_fp32_groups: + group.grad = None #class init + + return + + ######################################################################### + #################### ZeRO Stage 1 - reduce gradients #################### + ######################################################################### + def reduce_gradients(self, pipeline_parallel=False): + world_size = dist.get_world_size(self.dp_process_group) + my_rank = dist.get_rank(self.dp_process_group) + + # with PP we must create ipg buffer, since backward is handled outside zero + if pipeline_parallel and self.contiguous_gradients: + self.ipg_buffer = [] + buf_0 = torch.empty(int(self.reduce_bucket_size), + dtype=self.dtype, + device=get_accelerator().current_device_name()) + self.ipg_buffer.append(buf_0) + self.ipg_index = 0 + + if not self.overlap_comm: + for i, group in enumerate(self.bit16_groups): + for param in group: + grad_reduc = self.get_gradient_for_reduction(param) + if grad_reduc is not None: + self.reduce_ready_partitions_and_remove_grads(param, i) + # reduce any pending grads in either hook/non-hook case + self.overlapping_partition_gradients_reduce_epilogue() + + ######################################################################### + #########################ZeRO Partition Gradients######################## + ######################################################################### + + def get_first_param_index(self, group_id, param_group, partition_id): + for index, param in enumerate(param_group): + param_id = self.get_param_id(param) + if group_id in self.param_to_partition_ids and param_id in self.param_to_partition_ids[group_id]: + if partition_id in self.param_to_partition_ids[group_id][param_id]: + return index + return None + + def initialize_gradient_partitioning_data_structures(self): + + for i, param_group in enumerate(self.round_robin_bit16_groups): + total_partitions = dist.get_world_size(group=self.real_dp_process_group[i]) + + self.param_to_partition_ids[i] = {} + self.is_partition_reduced[i] = {} + self.total_grads_in_partition[i] = {} + self.remaining_grads_in_partition[i] = {} + self.is_grad_computed[i] = {} + self.grad_partition_insertion_offset[i] = {} + self.grad_start_offset[i] = {} + self.first_param_index_in_partition[i] = {} + + for partition_id in range(total_partitions): + self.is_grad_computed[i][partition_id] = {} + self.grad_partition_insertion_offset[i][partition_id] = {} + self.grad_start_offset[i][partition_id] = {} + self.total_grads_in_partition[i][partition_id] = 0 + self.initialize_gradient_partition(i, param_group, partition_id) + self.is_partition_reduced[i][partition_id] = False + self.first_param_index_in_partition[i][partition_id] = self.get_first_param_index( + i, param_group, partition_id) + + def independent_gradient_partition_epilogue(self): + self.report_ipg_memory_usage(f"In ipg_epilogue before reduce_ipg_grads", 0) + self.reduce_ipg_grads() + self.report_ipg_memory_usage(f"In ipg_epilogue after reduce_ipg_grads", 0) + + # if dist.get_rank() == 0: + # logger.info("Params already reduced %s", self.params_already_reduced) + for i in range(len(self.params_already_reduced)): + self.params_already_reduced[i] = False + + if self.overlap_comm: + if not get_accelerator().resolves_data_dependency(): + get_accelerator().synchronize() + # It is safe to clear previously reduced grads of other partitions + self._clear_previous_reduced_grads() + + if self.cpu_offload is False: + for i, _ in enumerate(self.bit16_groups): + + if not i in self.averaged_gradients or self.averaged_gradients[i] is None: + self.averaged_gradients[i] = self.get_flat_partition( + self.params_in_partition[i], + self.first_offset[i], + self.partition_size[i], + dtype=self.gradient_accumulation_dtype, + device=get_accelerator().current_device_name(), + return_tensor_list=True) + else: + avg_new = self.get_flat_partition(self.params_in_partition[i], + self.first_offset[i], + self.partition_size[i], + dtype=self.gradient_accumulation_dtype, + device=get_accelerator().current_device_name(), + return_tensor_list=True) + + for accumulated_grad, new_avg_grad in zip(self.averaged_gradients[i], avg_new): + accumulated_grad.add_(new_avg_grad) + + self._release_ipg_buffers() + + # No need to keep the gradients anymore. + # All gradients required by the step + # are in self.averaged_gradients + self.zero_grad(set_to_none=True) + see_memory_usage(f"End ipg_epilogue") + + # resets all partition to no reduced + # sets remaining grads to the total number of grads in each partition + # set is grad computed to false for all grads in partition + def reset_partition_gradient_structures(self): + for i, _ in enumerate(self.bit16_groups): + total_partitions = dist.get_world_size(group=self.real_dp_process_group[i]) + for partition_id in range(total_partitions): + self.is_partition_reduced[i][partition_id] = False + self.remaining_grads_in_partition[i][partition_id] = self.total_grads_in_partition[i][partition_id] + + for param_id in self.is_grad_computed[i][partition_id]: + self.is_grad_computed[i][partition_id][param_id] = False + + def initialize_gradient_partition(self, i, param_group, partition_id): + + def set_key_value_list(dictionary, key, value): + if key in dictionary: + dictionary[key].append(value) + else: + dictionary[key] = [value] + + def increment_value(dictionary, key): + if key in dictionary: + dictionary[key] += 1 + else: + dictionary[key] = 1 + + partition_size = self.partition_size[i] + + start_index = partition_size * partition_id + end_index = partition_size * (partition_id + 1) + + current_index = 0 + first_offset = 0 + + for param in param_group: + + param_size = param.numel() + param_id = self.get_param_id(param) + + if start_index <= current_index < end_index: + set_key_value_list(self.param_to_partition_ids[i], param_id, partition_id) + increment_value(self.total_grads_in_partition[i], partition_id) + + self.is_grad_computed[i][partition_id][param_id] = False + + self.grad_partition_insertion_offset[i][partition_id][param_id] = current_index - start_index + self.grad_start_offset[i][partition_id][param_id] = 0 + + elif current_index < start_index < (current_index + param_size): + assert (first_offset == 0 + ), "This can happen either zero or only once as this must be the first tensor in the partition" + first_offset = start_index - current_index + + set_key_value_list(self.param_to_partition_ids[i], param_id, partition_id) + increment_value(self.total_grads_in_partition[i], partition_id) + + self.is_grad_computed[i][partition_id][param_id] = False + + self.grad_partition_insertion_offset[i][partition_id][param_id] = 0 + self.grad_start_offset[i][partition_id][param_id] = first_offset + + current_index = current_index + param_size + + def overlapping_partition_gradients_reduce_epilogue(self): + self.independent_gradient_partition_epilogue() + + def fill_grad_accum_attribute(self): + for group in self.bit16_groups: + for param in group: + if param.grad is not None: + if param.grad_accum is None: + param.grad_accum = param.grad.to(self.gradient_accumulation_dtype) + else: + param.grad_accum.add_( + param.grad.to(self.gradient_accumulation_dtype).view(param.grad_accum.shape)) + param.grad = None + + def get_gradient_for_reduction(self, param): + if self.use_grad_accum_attribute: + return param.grad_accum.to(self.dtype) if param.grad_accum is not None else None + else: + return param.grad + + def get_param_gradient_attribute(self, param): + return param.grad_accum if self.use_grad_accum_attribute else param.grad + + # Clear the tensor the reduction gradient attribute is pointing to + def clear_grad_attribute(self, param): + if self.use_grad_accum_attribute: + param.grad_accum = None + else: + param.grad = None + + def create_reduce_and_remove_grad_hooks(self): + self.grad_accs = [] + for i, param_group in enumerate(self.bit16_groups): + for param in param_group: + if param.requires_grad: + + def wrapper(param, i): + param_tmp = param.expand_as(param) + grad_acc = param_tmp.grad_fn.next_functions[0][0] + + def reduce_partition_and_remove_grads(*notneeded): + self.reduce_ready_partitions_and_remove_grads(param, i) + + self._grad_acc_hooks.append(grad_acc.register_hook(reduce_partition_and_remove_grads)) + self.grad_accs.append(grad_acc) + + wrapper(param, i) + + def get_param_id(self, param): + unique_id = id(param) + return self.param_id[unique_id] + + def report_ipg_memory_usage(self, tag, param_elems): + elem_count = self.elements_in_ipg_bucket + param_elems + percent_of_bucket_size = (100.0 * elem_count) // self.reduce_bucket_size + see_memory_usage( + f"{tag}: elems in_bucket {self.elements_in_ipg_bucket} param {param_elems} max_percent {percent_of_bucket_size}" + ) + + # create a flat tensor aligned at the alignment boundary + def flatten_dense_tensors_aligned(self, tensor_list, alignment, use_cpu_data=False): + tensor_list = [param.cpu_data for param in tensor_list] if use_cpu_data else tensor_list + return self.flatten(align_dense_tensors(tensor_list, alignment)) + + ############### Independent Partition Gradient ######################## + def reduce_independent_p_g_buckets_and_remove_grads(self, param, i): + + grad_reduc = self.get_gradient_for_reduction(param) + if self.elements_in_ipg_bucket + param.numel() > self.reduce_bucket_size: + self.report_ipg_memory_usage("In ipg_remove_grads before reduce_ipg_grads", param.numel()) + self.reduce_ipg_grads() + if self.contiguous_gradients and self.overlap_comm: + # Swap ipg_index between 0 and 1 + self.ipg_index = 1 - self.ipg_index + self.report_ipg_memory_usage("In ipg_remove_grads after reduce_ipg_grads", param.numel()) + + param_id = self.get_param_id(param) + assert self.params_already_reduced[param_id] == False, \ + f"The parameter {param_id} has already been reduced. \ + Gradient computed twice for this partition. \ + Multiple gradient reduction is currently not supported" + + if self.contiguous_gradients: + if param.numel() > self.reduce_bucket_size: + self.extra_large_param_to_reduce = param + else: + # keeping the gradients contiguous to prevent memory fragmentation, and avoid flattening + new_grad_tensor = self.ipg_buffer[self.ipg_index].narrow(0, self.elements_in_ipg_bucket, param.numel()) + new_grad_tensor.copy_(grad_reduc.view(-1)) + grad_reduc.data = new_grad_tensor.data.view_as(grad_reduc) + + self.elements_in_ipg_bucket += param.numel() + + assert grad_reduc is not None, f"rank {dist.get_rank()} - Invalid to reduce Param {param_id} with None gradient" + + self.grads_in_ipg_bucket.append(grad_reduc) + self.params_in_ipg_bucket.append((i, param, param_id)) + + #make sure the average tensor function knows how to average the gradients + if is_moe_param(param): + self.ipg_bucket_has_moe_params = True + + self.report_ipg_memory_usage("End ipg_remove_grads", 0) + + def print_rank_0(self, message): + if dist.get_rank() == 0: + logger.info(message) + + def gradient_reduction_w_predivide(self, tensor): + if tensor.size().numel() == 0: + return tensor + + dp_world_size = dist.get_world_size(group=self.dp_process_group) + + tensor_to_allreduce = tensor + + if self.communication_data_type != tensor.dtype: + tensor_to_allreduce = tensor.to(self.communication_data_type) + + if self.postscale_gradients: + if self.gradient_predivide_factor != 1.0: + tensor_to_allreduce.mul_(1. / self.gradient_predivide_factor) + + dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group) + + if self.gradient_predivide_factor != dp_world_size: + tensor_to_allreduce.mul_(self.gradient_predivide_factor / + (dp_world_size / float(self.sequence_parallel_size))) + else: + tensor_to_allreduce.div_(dp_world_size / float(self.sequence_parallel_size)) + dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group) + + if self.communication_data_type != tensor.dtype and tensor is not tensor_to_allreduce: + tensor.copy_(tensor_to_allreduce) + + return tensor + + def allreduce_and_copy_with_multiple_ranks(self, + small_bucket, + log=None, + divide=True, + process_group=None, + bucket_ranks=None): + process_group = self.dp_process_group if process_group is None else process_group + allreduced = self.allreduce_bucket(small_bucket, log=log, divide=divide, process_group=process_group) + for buf, synced, bucket_rank in zip(small_bucket, self.unflatten(allreduced, small_bucket), bucket_ranks): + if dist.get_rank(group=process_group) == bucket_rank: + buf.copy_(synced) + + def allreduce_and_scatter(self, bucket, numel_per_bucket=500000000, log=None, divide=True, process_group=None): + small_bucket = [] + small_bucket_ranks = [] + numel = 0 + allreduce_sizes = [] + + for i, bucket_elem in enumerate(bucket): + rank, tensor = bucket_elem + small_bucket.append(tensor) + small_bucket_ranks.append(rank) + numel = numel + tensor.numel() + if numel > numel_per_bucket: + self.allreduce_and_copy_with_multiple_ranks(small_bucket, + log=None, + divide=divide, + process_group=process_group, + bucket_ranks=small_bucket_ranks) + small_bucket = [] + small_bucket_ranks = [] + numel = 0 + + if len(small_bucket) > 0: + self.allreduce_and_copy_with_multiple_ranks(small_bucket, + log=None, + divide=divide, + process_group=process_group, + bucket_ranks=small_bucket_ranks) + + def average_tensor(self, tensor): + if self.overlap_comm: + stream = self.reduction_stream + if not get_accelerator().resolves_data_dependency(): + stream.wait_stream(get_accelerator().current_stream()) + get_accelerator().current_stream().wait_stream(stream) + else: + stream = get_accelerator().current_stream() + + with get_accelerator().stream(stream): + if not self.reduce_scatter: + self.gradient_reduction_w_predivide(tensor) + return + + # Accumulate destination ranks and bucket offsets for each gradient slice. + # Note: potential future optimization, record access pattern of parameters + # in backward pass and partition gradients w.r.t. access pattern so that our + # bucket is guaranteed to be contiguous w.r.t. ranks + rank_and_offsets = [] + real_dp_process_group = [] + curr_size = 0 + prev_id, prev_process_group = -1, None + + process_group = self.dp_process_group + # count = 0 + for i, param, param_id in self.params_in_ipg_bucket: + + process_group = self.dp_process_group + grad_reduc = self.get_gradient_for_reduction(param) + #Averages gradients at parameter level if ipg has a moe param + #Otherwise averaging is done at the entire buffer level at the end of the loop + # MoE param have different groups + if self.ipg_bucket_has_moe_params: + process_group = self.expert_dp_process_group[param.group_name] if is_moe_param( + param) else self.dp_process_group + grad_reduc.data.div_(dist.get_world_size(group=process_group) / float(self.sequence_parallel_size)) + + partition_ids = self.param_to_partition_ids[i][param_id] + assert all([p_id < dist.get_world_size(group=process_group) for p_id in partition_ids + ]), f"world size {dist.get_world_size(group=process_group)} and p_ids: {partition_ids}" + partition_size = self.partition_size[i] + # Get all partition ids + their offsets + partition_ids_w_offsets = [] + for partition_id in partition_ids: + offset = self.grad_start_offset[i][partition_id][param_id] + partition_ids_w_offsets.append((partition_id, offset)) + partition_ids_w_offsets.sort(key=lambda t: t[1]) + + # Calculate rank and offsets for grad slices + for idx in range(len(partition_ids_w_offsets)): + partition_id, offset = partition_ids_w_offsets[idx] + + # if dist.get_rank() == 0 and count < 100: + # print(f"Rank {dist.get_rank()} rank offset id {idx} calculated dp size {dist.get_world_size(group=process_group)} real dp size {dist.get_world_size(self.real_dp_process_group[i])} and dst: {partition_id}") + # count += 1 + + # Calculate numel for grad slice depending on partition location + if idx == len(partition_ids_w_offsets) - 1: + # Last partition_id uses its own offset + numel = param.numel() - offset + else: + # Set numel to next partition's offset + numel = partition_ids_w_offsets[idx + 1][1] - offset + + # Merge bucket ranges if they belong to the same rank + if partition_id == prev_id and process_group == prev_process_group: + prev_pid, prev_size, prev_numel = rank_and_offsets[-1] + rank_and_offsets[-1] = (prev_pid, prev_size, prev_numel + numel) + else: + rank_and_offsets.append((partition_id, curr_size, numel)) + real_dp_process_group.append(process_group) + curr_size += numel + prev_id, prev_process_group = partition_id, process_group + + if not self.ipg_bucket_has_moe_params: + tensor.div_(dist.get_world_size(group=self.dp_process_group) / float(self.sequence_parallel_size)) + + buckets = {} + for i, (dst, bucket_offset, numel) in enumerate(rank_and_offsets): + grad_slice = tensor.narrow(0, int(bucket_offset), int(numel)) + bucket_key = real_dp_process_group[i] if self.use_multi_rank_bucket_allreduce else ( + dst, real_dp_process_group[i]) + if bucket_key not in buckets: + buckets[bucket_key] = [] + if self.use_multi_rank_bucket_allreduce: + buckets[bucket_key].append((dst, grad_slice)) + else: + buckets[bucket_key].append(grad_slice) + + for bucket_key in buckets: + if self.use_multi_rank_bucket_allreduce: + self.allreduce_and_scatter(buckets[bucket_key], + numel_per_bucket=self.reduce_bucket_size, + divide=False, + process_group=bucket_key) + else: + dst, process_group = bucket_key + self.allreduce_no_retain(buckets[bucket_key], + numel_per_bucket=self.reduce_bucket_size, + rank=dst, + divide=False, + process_group=process_group) + + ############################################################################## + ############################# CPU Offload Methods############################# + ############################################################################## + def get_grad_position(self, group_id, tensor_list, first_offset, partition_size): + current_offset = 0 + + for i, tensor in enumerate(tensor_list): + param_id = self.get_param_id(tensor) + param_start_offset = 0 + + num_elements = tensor.numel() + + # we need to offset to get to the right element + if i == 0 and first_offset > 0: + tensor_offset = first_offset + num_elements = num_elements - tensor_offset + param_start_offset = first_offset + + # we dont need all elements of the tensor + if num_elements > (partition_size - current_offset): + num_elements = partition_size - current_offset + + self.grad_position[param_id] = [ + int(group_id), int(param_start_offset), + int(current_offset), int(num_elements) + ] + current_offset += num_elements + + def update_overflow_tracker_for_param_grad(self, param): + grad_accum = self.get_param_gradient_attribute(param) + if grad_accum is not None and self._has_inf_or_nan(grad_accum.data): + self.local_overflow = True + + def _get_offload_gradient_dict(self): + for param_group_index, _ in enumerate(self.optimizer.param_groups): + self.offload_gradient_dict[param_group_index] = [] + for lp_param in self.params_in_partition[param_group_index]: + param_id = self.get_param_id(lp_param) + [_, _, dest_offset, num_elements] = self.grad_position[param_id] + dest_tensor = self.single_partition_of_fp32_groups[param_group_index].grad.view(-1).narrow( + 0, dest_offset, num_elements) + self.offload_gradient_dict[param_group_index].append(dest_tensor) + + def async_accumulate_grad_in_cpu_via_gpu(self, param): + param_id = self.get_param_id(param) + + [i, source_offset, dest_offset, num_elements] = self.grad_position[param_id] + + # copy to a preexisiting buffer to avoid memory allocation penalty + dest_buffer = self.temp_grad_buffer_for_gpu_offload.view(-1).narrow(0, 0, param.numel()) + + #buffer for storing gradients for this parameter in CPU + def buffer_to_accumulate_to_in_cpu(): + if not self.fp16_master_weights_and_gradients: + buffer = torch.zeros(param.numel(), dtype=param.dtype, device=self.device) + return get_accelerator().pin_memory(buffer) if self.cpu_offload_pin_memory else buffer + else: + return self.single_partition_of_fp32_groups[i].grad.view(-1).narrow(0, dest_offset, num_elements) + + #accumulate gradients into param.grad_accum or parts of it that belongs to this partition + def accumulate_gradients(): + grad_accum = self.get_param_gradient_attribute(param) + if not self.fp16_master_weights_and_gradients: + dest_buffer.copy_(self.accumulated_grads_in_cpu[param_id].view(-1), non_blocking=True) + grad_accum.data.view(-1).add_(dest_buffer) + else: + dest_buffer.narrow(0, source_offset, + num_elements).copy_(self.accumulated_grads_in_cpu[param_id].view(-1), + non_blocking=True) + grad_accum.data.view(-1).narrow(0, source_offset, + num_elements).add_(dest_buffer.narrow(0, source_offset, num_elements)) + + #move accumulated gradients back to CPU + def copy_gradients_to_cpu(): + grad_accum = self.get_param_gradient_attribute(param) + if not self.fp16_master_weights_and_gradients: + self.accumulated_grads_in_cpu[param_id].data.copy_(grad_accum.data.view(-1), non_blocking=True) + else: + self.accumulated_grads_in_cpu[param_id].data.copy_(grad_accum.data.view(-1).narrow( + 0, source_offset, num_elements), + non_blocking=True) + + if param_id not in self.accumulated_grads_in_cpu: + self.accumulated_grads_in_cpu[param_id] = buffer_to_accumulate_to_in_cpu() + + if self.micro_step_id > 0: + accumulate_gradients() + else: + copy_gradients_to_cpu() + + def set_norm_for_param_grad(self, param): + param_id = self.get_param_id(param) + grad_accum = self.get_param_gradient_attribute(param) + accumulated_grad = self.accumulated_grads_in_cpu[ + param_id] if self.gradient_accumulation_steps > 1 else grad_accum + + [i, source_offset, dest_offset, num_elements] = self.grad_position[param_id] + + start = source_offset + accumulated_grad = accumulated_grad.view(-1).narrow(0, start, num_elements) + + self.norm_for_param_grads[param_id] = accumulated_grad.data.float().norm(2) + + def set_norm_for_param_grad_in_gpu(self, param): + param_id = self.get_param_id(param) + grad_accum = self.get_param_gradient_attribute(param) + if grad_accum is None: + accumulated_grad = param.grad + else: + accumulated_grad = grad_accum + + [i, source_offset, dest_offset, num_elements] = self.grad_position[param_id] + + start = source_offset + accumulated_grad = accumulated_grad.view(-1).narrow(0, start, num_elements) + + self.norm_for_param_grads[param_id] = accumulated_grad.data.float().norm(2) + + def async_inplace_copy_grad_to_fp32_buffer_from_gpu(self, param): + param_id = self.get_param_id(param) + + [i, source_offset, dest_offset, num_elements] = self.grad_position[param_id] + + dest_tensor = self.single_partition_of_fp32_groups[i].grad.view(-1).narrow(0, dest_offset, num_elements) + + grad_accum = self.get_param_gradient_attribute(param) + if grad_accum is None: + src_tensor = grad_accum.view(-1).narrow(0, source_offset, num_elements) + else: + src_tensor = grad_accum.view(-1).narrow(0, source_offset, num_elements) + if not self.fp16_master_weights_and_gradients: + src_tensor = src_tensor.float() + + dest_tensor.copy_(src_tensor, non_blocking=True) + param.grad = None #offload only + + def complete_grad_norm_calculation_for_cpu_offload(self, params): + total_norm = 0.0 + norm_type = 2.0 + for p in params: + # Pipeline parallelism may replicate parameters. Avoid multi-counting. + if hasattr(p, PIPE_REPLICATED) and p.ds_pipe_replicated: + continue + + if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0): + param_id = self.get_param_id(p) + # as some model have trainable parameters but skipped in training, + # their backward hooks in self.create_reduce_and_remove_grad_hooks() will not run, + # so they have no norm_for_param_grads + if param_id in self.norm_for_param_grads: + param_norm = self.norm_for_param_grads[param_id] + total_norm += param_norm.item()**2 + else: + # As unused parameters in modules may not be expected sometimes, + # add an explicit error msg when it occurred and an option to + # avoid the error + assert self.ignore_unused_parameters, """ + This assert indicates that your module has parameters that + were not used in producing loss. + You can avoid this assert by + (1) enable ignore_unused_parameters option in zero_optimization config; + (2) making sure all trainable parameters and `forward` function + outputs participate in calculating loss. + """ + + # Sum across all model parallel GPUs. + total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)]) + dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM, group=self.dp_process_group) + + self._model_parallel_all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.SUM) + + total_norm = total_norm_cuda[0].item()**(1. / norm_type) + + if total_norm == float('inf') or total_norm == -float('inf') or total_norm != total_norm: + total_norm = -1 + + return total_norm + + ############################################################################################ + def copy_grads_in_partition(self, param): + if self.cpu_offload: + + if self.gradient_accumulation_steps > 1: + self.async_accumulate_grad_in_cpu_via_gpu(param) + + if self.is_gradient_accumulation_boundary: + self.set_norm_for_param_grad_in_gpu(param) + + self.update_overflow_tracker_for_param_grad(param) + + self.async_inplace_copy_grad_to_fp32_buffer_from_gpu(param) + + return + #print(f"ID {self.get_param_id(param)} grad norm {param.grad.norm()}") + if self.grads_in_partition is None: + self.grads_in_partition_offset = 0 + total_size = 0 + for group in self.params_in_partition: + for param_in_partition in group: + total_size += param_in_partition.numel() + + see_memory_usage(f"before copying {total_size} gradients into partition") + self.grads_in_partition = torch.empty(int(total_size), + dtype=self.dtype, + device=get_accelerator().current_device_name()) + see_memory_usage(f"after copying {total_size} gradients into partition") + + grad_reduc = self.get_gradient_for_reduction(param) + # The allreduce buffer will be rewritten. Copy the gradients in partition to a new buffer + new_grad_tensor = self.grads_in_partition.view(-1).narrow(0, self.grads_in_partition_offset, param.numel()) + new_grad_tensor.copy_(grad_reduc.view(-1)) + grad_reduc.data = new_grad_tensor.data.view_as(grad_reduc) + #print(f"Grad norm after copy to contiguous_buffer {param.grad.data.norm()}") + self.grads_in_partition_offset += param.numel() + + def reduce_ipg_grads(self): + if self.contiguous_gradients: + if self.extra_large_param_to_reduce is not None: + assert len(self.params_in_ipg_bucket) == 1, "more than 1 param in ipg bucket, this shouldn't happen" + _, _, param_id = self.params_in_ipg_bucket[0] + assert self.get_param_id(self.extra_large_param_to_reduce + ) == param_id, "param in ipg bucket does not match extra-large param" + extra_large_grad_reduc = self.get_gradient_for_reduction(self.extra_large_param_to_reduce) + self.average_tensor(extra_large_grad_reduc.view(-1)) + self.extra_large_param_to_reduce = None + else: + self.average_tensor(self.ipg_buffer[self.ipg_index].narrow(0, 0, self.elements_in_ipg_bucket)) + else: + self.buffered_reduce_fallback(None, + self.grads_in_ipg_bucket, + elements_per_buffer=self.elements_in_ipg_bucket) + + if self.overlap_comm: + stream = self.reduction_stream + elif self.cpu_offload: + # TODO: copy_grad_stream is disabled because of race with reduce. This hurts perf and should be fixed. + # get_accelerator().synchronize() + # stream = self.copy_grad_stream + stream = get_accelerator().current_stream() + else: + stream = get_accelerator().current_stream() + + with get_accelerator().stream(stream): + for _, param, param_id in self.params_in_ipg_bucket: + + assert self.params_already_reduced[param_id] == False, \ + f"The parameter {param_id} has already been reduced. \ + Gradient computed twice for this partition. \ + Multiple gradient reduction is currently not supported" + + self.params_already_reduced[param_id] = True + if self.partition_gradients: + if not self.is_param_in_current_partition[param_id]: + if self.overlap_comm and self.contiguous_gradients is False: + # Clear grads of other partitions during the next reduction + # to avoid clearing them before the reduction is complete. + if self.previous_reduced_grads is None: + self.previous_reduced_grads = [] + self.previous_reduced_grads.append(param) + else: + self.clear_grad_attribute(param) + elif self.contiguous_gradients: + self.copy_grads_in_partition(param) + else: # zero stage 1 - partition only optimizer state + if self.contiguous_gradients and self.is_param_in_current_partition[param_id]: + self.copy_grads_in_partition(param) + + self.grads_in_ipg_bucket = [] + self.params_in_ipg_bucket = [] + self.ipg_bucket_has_moe_params = False + self.elements_in_ipg_bucket = 0 + ##################################################################### + + def reduce_ready_partitions_and_remove_grads(self, param, i): + if self.partition_gradients or self.is_gradient_accumulation_boundary: + self.reduce_independent_p_g_buckets_and_remove_grads(param, i) + + def zero_reduced_gradients(self, partition_id, i): + + def are_all_related_partitions_reduced(params_id): + for partition_id in self.param_to_partition_ids[i][params_id]: + if not self.is_partition_reduced[i][partition_id]: + return False + return True + + for params_id in self.is_grad_computed[i][partition_id]: + if are_all_related_partitions_reduced(params_id): + self.param_dict[params_id].grad = None # dead code + + def flatten_and_print(self, message, tensors, start=0, n=5): + flatten_tensor = self.flatten(tensors) + + def print_func(): + logger.info(flatten_tensor.contiguous().view(-1).narrow(0, start, n)) + + self.sequential_execution(print_func, message) + + def get_grads_to_reduce(self, i, partition_id): + + def get_reducible_portion(key): + grad = self.param_dict[key].grad + total_elements = grad.numel() + start = self.grad_start_offset[i][partition_id][key] + num_elements = min(total_elements - start, + self.partition_size[i] - self.grad_partition_insertion_offset[i][partition_id][key]) + if not pg_correctness_test: + if num_elements == total_elements: + return grad + else: + return grad.contiguous().view(-1).narrow(0, int(start), int(num_elements)) + else: + if num_elements == total_elements: + return grad.clone() + else: + return grad.clone().contiguous().view(-1).narrow(0, int(start), int(num_elements)) + + grads_to_reduce = [] + for key in self.is_grad_computed[i][partition_id]: + grad = get_reducible_portion(key) + grads_to_reduce.append(grad) + return grads_to_reduce + + def sequential_execution(self, function, message, group=None): + if group is None: + group = self.dp_process_group + if dist.get_rank(group=group) == 0: + logger.info(message) + for id in range(dist.get_world_size(group=group)): + if id == dist.get_rank(group=group): + function() + dist.barrier(group=group) + + def set_none_gradients_to_zero(self, i, partition_id): + for param_id in self.is_grad_computed[i][partition_id]: + param = self.param_dict[param_id] + if param.grad is None: + param.grad = torch.zeros_like(param) + + ######################Reduction Related Methods############################## + def allreduce_bucket(self, bucket, rank=None, log=None, divide=True, process_group=None): + tensor = self.flatten(bucket) + + process_group = self.dp_process_group if process_group is None else process_group + + tensor_to_allreduce = tensor + + if pg_correctness_test or self.sequence_parallel_size > 1: + communication_data_type = torch.float32 + else: + communication_data_type = self.communication_data_type + + if communication_data_type != tensor.dtype: + tensor_to_allreduce = tensor.to(communication_data_type) + + if divide: + tensor_to_allreduce.div_(dist.get_world_size(group=process_group) / float(self.sequence_parallel_size)) + + if rank is None: + # "All Reducing" + dist.all_reduce(tensor_to_allreduce, group=process_group) + else: + global_rank = dist.get_global_rank(process_group, rank) + dist.reduce(tensor_to_allreduce, global_rank, group=process_group) + + if communication_data_type != tensor.dtype and tensor is not tensor_to_allreduce: + if rank is None or rank == dist.get_rank(group=process_group): + tensor.copy_(tensor_to_allreduce) + + return tensor + + def _clear_previous_reduced_grads(self): + if self.previous_reduced_grads is not None: + for param in self.previous_reduced_grads: + self.clear_grad_attribute(param) + self.previous_reduced_grads = None + + # if rank is specified do a reduction instead of an allreduce + def allreduce_and_copy(self, small_bucket, rank=None, log=None, divide=True, process_group=None): + process_group = self.dp_process_group if process_group is None else process_group + if self.overlap_comm: + if not get_accelerator().resolves_data_dependency(): + get_accelerator().synchronize() + # It is safe to clear the previously reduced grads of other partitions + self._clear_previous_reduced_grads() + stream = self.reduction_stream + else: + stream = get_accelerator().current_stream() + + with get_accelerator().stream(stream): + allreduced = self.allreduce_bucket( + small_bucket, + rank=rank, + log=log, + divide=divide, + process_group=process_group, + ) + if rank is None or rank == dist.get_rank(group=self.dp_process_group): + for buf, synced in zip(small_bucket, self.unflatten(allreduced, small_bucket)): + buf.copy_(synced) + + def allreduce_no_retain( + self, + bucket, + numel_per_bucket=500000000, + rank=None, + log=None, + divide=True, + process_group=None, + ): + small_bucket = [] + numel = 0 + for tensor in bucket: + small_bucket.append(tensor) + numel = numel + tensor.numel() + if numel > numel_per_bucket: + self.allreduce_and_copy(small_bucket, rank=rank, log=None, divide=divide, process_group=process_group) + small_bucket = [] + numel = 0 + + if len(small_bucket) > 0: + self.allreduce_and_copy(small_bucket, rank=rank, log=log, divide=divide, process_group=process_group) + + # allows using reduction of gradients instead of using all_reduce + + def buffered_reduce_fallback(self, rank, grads, elements_per_buffer=500000000, log=None): + split_buckets = split_half_float_double(grads) + + for i, bucket in enumerate(split_buckets): + self.allreduce_no_retain(bucket, numel_per_bucket=elements_per_buffer, rank=rank, log=log) + + ############################################################################# + ############################################################################# + ############################################################################# + + # views the tensor as multiple partitions and returns + # those partitions + def get_data_parallel_partitions(self, tensor, group_id): + partitions = [] + + dp = dist.get_world_size(group=self.real_dp_process_group[group_id]) + # dp_id = dist.get_rank(group=self.real_dp_process_group[group_id]) + + total_num_elements = tensor.numel() + + base_size = total_num_elements // dp + remaining = total_num_elements % dp + + start = 0 + for id in range(dp): + partition_size = base_size + if id < remaining: + partition_size = partition_size + 1 + partitions.append(tensor.narrow(0, start, partition_size)) + start = start + partition_size + return partitions + + def get_partition_info(self, tensor_list, partition_size, partition_id): + params_in_partition = [] + params_not_in_partition = [] + + start_index = partition_size * partition_id + end_index = partition_size * (partition_id + 1) + + current_index = 0 + first_offset = 0 + + for tensor in tensor_list: + + tensor_size = tensor.numel() + + if start_index <= current_index < end_index: + params_in_partition.append(tensor) + + elif current_index < start_index < (current_index + tensor_size): + params_in_partition.append(tensor) + + assert (first_offset == 0 + ), "This can happen either zero or only once as this must be the first tensor in the partition" + first_offset = start_index - current_index + + else: + params_not_in_partition.append(tensor) + + current_index = current_index + tensor_size + + return params_in_partition, params_not_in_partition, first_offset + + def zero_grad(self, set_to_none=True): + """ + Zero FP16 parameter grads. + """ + # FP32 grad should never exist. + # For speed, set model fp16 grad to None by default + # zero all pointers to grad tensors + for group in self.bit16_groups: + for p in group: + if set_to_none: + p.grad = None # epilogue and in step + p.grad_accum = None + else: + if p.grad is not None: + p.grad.detach_() + p.grad.zero_() + + def _model_parallel_all_reduce(self, tensor, op): + """ Perform all reduce within model parallel group, if any. + """ + if self.model_parallel_group is None or self.model_parallel_world_size == 1: + pass + else: + dist.all_reduce(tensor=tensor, op=op, group=self.model_parallel_group) + + def get_grad_norm_direct(self, gradients, params, norm_type=2): + """Clips gradient norm of an iterable of parameters. + + This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and + added functionality to handle model parallel parameters. Note that + the gradients are modified in place. + + Arguments: + parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a + single Tensor that will have gradients normalized + max_norm (float or int): max norm of the gradients + norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for + infinity norm. + + Returns: + Total norm of the parameters (viewed as a single vector). + """ + norm_type = float(norm_type) + all_norms = [] + if norm_type == inf: + for g in gradients: + all_norms.append(g.data.abs().max().float()) + total_norm = torch.stack(all_norms).max() + dist.all_reduce(total_norm, op=dist.ReduceOp.MAX, group=self.dp_process_group) + + # Take max across all GPUs. + self._model_parallel_all_reduce(tensor=total_norm, op=dist.ReduceOp.MAX) + else: + # if dist.get_rank() == 0: + # logger.info(f"Total Norm beginning {total_norm}") + for g, p in zip(gradients, params): + # Pipeline parallelism may replicate parameters. Avoid multi-counting. + if hasattr(p, PIPE_REPLICATED) and p.ds_pipe_replicated: + continue + if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0): + + all_norms.append( + torch.norm(g.data.float().detach(), norm_type).to(get_accelerator().current_device_name())) + if len(all_norms) > 0: + total_norm = torch.stack(all_norms).square().sum().float() + else: + total_norm = torch.tensor(0.0, dtype=torch.float32).to(self.device) + # Sum across all model parallel Device. + dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=self.dp_process_group) + + + + self._model_parallel_all_reduce(tensor=total_norm, op=dist.ReduceOp.SUM) + + total_norm = total_norm.pow(1. / norm_type) + + norm_is_inf = total_norm.isinf() + norm_is_nan = total_norm.isnan() + inf_or_nan = norm_is_nan.logical_or(norm_is_inf) + + err = torch.tensor(-1.0, device=self.device, dtype=torch.float) + total_norm = inf_or_nan * err + inf_or_nan.logical_not() * total_norm + return total_norm + + # creates a flat fused tensor from the tensor list starting at the first_offset + # in the first tensor of the list. If there are not enough elements in the tensor + # list then the flat tensor will be padded with zeros + def get_flat_partition(self, tensor_list, first_offset, partition_size, dtype, device, return_tensor_list=False): + flat_tensor_list = [] + current_size = 0 + + for i, tensor in enumerate(tensor_list): + grad_accum = self.get_param_gradient_attribute(tensor) + if grad_accum is None: + grad_accum = torch.zeros_like(tensor, dtype=dtype) + + tensor = grad_accum + num_elements = tensor.numel() + tensor_offset = 0 + + # we need to offset to get to the right element + if i == 0 and first_offset > 0: + tensor_offset = first_offset + num_elements = num_elements - tensor_offset + + # we dont need all elements of the tensor + if num_elements > (partition_size - current_size): + num_elements = partition_size - current_size + + # we need a narrow view of the tensor based on the tensor offset and number of elements that + # we need from this tensor + if tensor_offset > 0 or num_elements < tensor.numel(): + flat_tensor_list.append(tensor.contiguous().view(-1).narrow(0, int(tensor_offset), int(num_elements))) + else: + flat_tensor_list.append(tensor) + + current_size = current_size + num_elements + + # this means its the last partition and does not align with the dp boundary. We need to pad before flattening + if current_size < partition_size: + flat_tensor_list.append(torch.zeros(int(partition_size - current_size), dtype=dtype, device=device)) + + if return_tensor_list: + return flat_tensor_list + + return self.flatten(flat_tensor_list) + + def free_grad_in_param_list(self, param_list): + for p in param_list: + p.grad = None # in step + p.grad_accum = None + + def reset_cpu_buffers(self): + self.norm_for_param_grads = {} + self.local_overflow = False + + def set_lr(self, lr): + """Set the learning rate.""" + for param_group in self.optimizer.param_groups: + param_group["lr"] = lr + + def get_lr(self): + """Return the current learning rate.""" + return self.optimizer.param_groups[0]["lr"] + + def override_loss_scale(self, loss_scale): + if loss_scale != self.external_loss_scale: + logger.info(f'[deepspeed] setting loss scale from {self.external_loss_scale} -> {loss_scale}') + self.custom_loss_scaler = True + self.external_loss_scale = loss_scale + + def scaled_global_norm(self, norm_type=2): + assert norm_type == 2, "only L2 norm supported" + norm_groups = [] + for i, group in enumerate(self.bit16_groups): + if self.cpu_offload: + # complete complete_grad_norm_calculation_for_cpu_offload return python float, moving back to + # torch.tensor as else statement returns tensor as well + norm = torch.tensor(self.complete_grad_norm_calculation_for_cpu_offload(self.params_in_partition[i]), + device=self.device) + norm_groups.append(norm) + else: + norm_groups.append(self.get_grad_norm_direct(self.averaged_gradients[i], self.params_in_partition[i])) + + if self.has_moe_layers: + self._average_expert_grad_norms(norm_groups) + + # calculating L2 norm + return torch.norm(torch.stack(norm_groups), p=norm_type) + + def get_bit16_param_group(self, group_no): + bit16_partitions = self.parallel_partitioned_bit16_groups[group_no] + partition_id = dist.get_rank(group=self.real_dp_process_group[group_no]) + return [bit16_partitions[dist.get_rank(group=self.real_dp_process_group[group_no])]] + + def _optimizer_step(self, group_no): + original_param_groups = self.optimizer.param_groups + self.optimizer.param_groups = [original_param_groups[group_no]] + # Disabling this as the C++ side copy & synchronize is not working correctly + #from deepspeed.ops.adam import DeepSpeedCPUAdam + #if type(self.optimizer) == DeepSpeedCPUAdam and self.dtype == torch.half: + # self.optimizer.step(fp16_param_groups=[self.get_bit16_param_group(group_no)]) + #else: + # self.optimizer.step() + self.optimizer.step() + self.optimizer.param_groups = original_param_groups + + # We need to link optimizer state after the first step() call + self._lazy_init_hp_params_optimizer_state() + + def step(self, closure=None): + """ + Not supporting closure. + """ + self.micro_step_id = INITIAL_MICRO_STEP_ID + + see_memory_usage(f"In step before checking overflow") + + # First compute norm for all group so we know if there is overflow + if self.dtype == torch.float16: + self.check_overflow() + + prev_scale = self.loss_scale + self._update_scale(self.overflow) + if self.overflow: + see_memory_usage('After overflow before clearing gradients') + self.zero_grad(set_to_none=True) + if self.cpu_offload: + self.reset_cpu_buffers() + else: + self.averaged_gradients = {} + + see_memory_usage('After overflow after clearing gradients') + + for timer in OPTIMIZER_TIMERS: + self.timers(timer).start() + self.timers(timer).stop() + return + + # Step 1:- Calculate gradient norm using bit-16 grads + see_memory_usage('Before norm calculation') + scaled_global_grad_norm = self.scaled_global_norm() + self._global_grad_norm = scaled_global_grad_norm / prev_scale + see_memory_usage('After norm before optimizer') + + # Step 2:- run optimizer and upscaling simultaneously + for i, group in enumerate(self.bit16_groups): + self.timers(OPTIMIZER_GRADIENTS_TIMER).start() + partition_id = dist.get_rank(group=self.real_dp_process_group[i]) + if self.cpu_offload: + single_grad_partition = self.single_partition_of_fp32_groups[i].grad + self.unscale_and_clip_grads([single_grad_partition], scaled_global_grad_norm) + + self.timers(OPTIMIZER_GRADIENTS_TIMER).stop() + self.timers(OPTIMIZER_STEP_TIMER).start() + self._optimizer_step(i) + + # Disabled, this is not currently working + #from deepspeed.ops.adam import DeepSpeedCPUAdam + #if not (type(self.optimizer) == DeepSpeedCPUAdam and self.dtype == torch.half): + # bit16_partitions = self.parallel_partitioned_bit16_groups[i] + # fp32_partition = self.single_partition_of_fp32_groups[i] + # bit16_partitions[partition_id].data.copy_(fp32_partition.data) + bit16_partitions = self.parallel_partitioned_bit16_groups[i] + fp32_partition = self.single_partition_of_fp32_groups[i] + bit16_partitions[partition_id].data.copy_( + fp32_partition.to(get_accelerator().current_device_name()).data) + + self.timers(OPTIMIZER_STEP_TIMER).stop() + else: + # free gradients for all the parameters that are not updated by this process(ZeRO stage2) + self.free_grad_in_param_list(self.params_not_in_partition[i]) + + # create a flat gradients for parameters updated by this process + # If we are last partition, ensure we have same size grads and partition size, if not pad with zero tensors + if partition_id == dist.get_world_size(group=self.real_dp_process_group[i]) - 1: + single_grad_partition = self.flatten_dense_tensors_aligned( + self.averaged_gradients[i], + int(self.partition_size[i])).to(self.single_partition_of_fp32_groups[i].dtype) + else: + single_grad_partition = self.flatten(self.averaged_gradients[i]).to( + self.single_partition_of_fp32_groups[i].dtype) + assert single_grad_partition.numel() == self.partition_size[i], \ + "averaged gradients have different number of elements that partition size {} {} {} {}".format( + single_grad_partition.numel(), self.partition_size[i], i, partition_id) + + self.single_partition_of_fp32_groups[i].grad = single_grad_partition + # release all the gradient since we have already created a necessary copy in dp_grad_partition(ZeRO stage2) + self.free_grad_in_param_list(self.params_in_partition[i]) + + self.averaged_gradients[i] = None + + self.unscale_and_clip_grads([single_grad_partition], scaled_global_grad_norm) + + self.timers(OPTIMIZER_GRADIENTS_TIMER).stop() + + # Step 3:- run the optimizer if no offloading + self.timers(OPTIMIZER_STEP_TIMER).start() + self._optimizer_step(i) + # Step 4:- get rid of the fp32 gradients. Not needed anymore + self.single_partition_of_fp32_groups[i].grad = None + del single_grad_partition + bit16_partitions = self.parallel_partitioned_bit16_groups[i] + fp32_partition = self.single_partition_of_fp32_groups[i] + bit16_partitions[partition_id].data.copy_(fp32_partition.data) + self.timers(OPTIMIZER_STEP_TIMER).stop() + + see_memory_usage('After optimizer before all-gather') + if self.cpu_offload: + self.reset_cpu_buffers() + + self.timers(OPTIMIZER_ALLGATHER_TIMER).start() + # Gather the updated weights from everyone. + # Then all partitions of the model parameters are updated and ready for next round forward. + all_gather_dp_groups(groups_flat=self.bit16_groups_flat, + partitioned_param_groups=self.parallel_partitioned_bit16_groups, + dp_process_group=self.real_dp_process_group, + start_alignment_factor=self.nccl_start_alignment_factor, + allgather_bucket_size=self.allgather_bucket_size) + self.timers(OPTIMIZER_ALLGATHER_TIMER).stop() + + # TODO: we probably don't need this? just to be safe + for i in range(len(self.bit16_groups)): + self._update_model_bit16_weights(i) + + self.timers.log(OPTIMIZER_TIMERS) + see_memory_usage('After zero_optimizer step') + + return + + @torch.no_grad() + def update_lp_params(self): + for i, (bit16_partitions, fp32_partition) in enumerate( + zip(self.parallel_partitioned_bit16_groups, self.single_partition_of_fp32_groups)): + partition_id = dist.get_rank(group=self.real_dp_process_group[i]) + bit16_partitions[partition_id].data.copy_(fp32_partition.data) + # print_rank_0(f'update_lp_params {i=} {partition_id=}', force=True) + # if i == 0: + # print_rank_0(f'{fp32_partition[:10]=}', force=True) + all_gather_dp_groups(groups_flat=self.bit16_groups_flat, + partitioned_param_groups=self.parallel_partitioned_bit16_groups, + dp_process_group=self.real_dp_process_group, + start_alignment_factor=self.nccl_start_alignment_factor, + allgather_bucket_size=self.allgather_bucket_size) + + def _average_expert_grad_norms(self, norm_groups): + for i, norm in enumerate(norm_groups): + if self.is_moe_param_group[i]: + scaled_norm_tensor = norm * 1.0 / dist.get_world_size(group=self.real_dp_process_group[i]) + if self.device == 'cpu': + scaled_norm_tensor = scaled_norm_tensor.to(get_accelerator().current_device_name()) + dist.all_reduce(scaled_norm_tensor, group=self.real_dp_process_group[i]) + norm_groups[i] = scaled_norm_tensor.to(self.device) + + def unscale_and_clip_grads(self, grad_groups_flat, total_norm): + # compute combined scale factor for this group + combined_scale = self.loss_scale + if self.clip_grad > 0.: + # norm is in fact norm*scale + clip = ((total_norm / self.loss_scale) + 1e-6) / self.clip_grad + clip = torch.clamp(clip, min=1.0) + combined_scale = clip * self.loss_scale + + for grad in grad_groups_flat: + if isinstance(grad, list): + sub_partitions = grad + for g in sub_partitions: + g.data.mul_(1. / combined_scale) + else: + grad.data.mul_(1. / combined_scale) + + def _check_overflow(self, partition_gradients=True): + self.overflow = self.has_overflow(partition_gradients) + + # `params` is a list / generator of torch.Variable + def has_overflow_serial(self, params): + invalid_grad_count = torch.zeros([1], dtype=torch.float, device=get_accelerator().current_device_name()) + for p in params: + if p.grad is not None: + invalid_grad_count += self._has_inf_or_nan(p.grad) + return invalid_grad_count.bool() + + def has_overflow_partitioned_grads_serial(self): + invalid_grad_count = torch.zeros([1], dtype=torch.float, device=get_accelerator().current_device_name()) + for i in range(len(self.bit16_groups)): + for j, grad in enumerate(self.averaged_gradients[i]): + if grad is not None: + invalid_grad_count += self._has_inf_or_nan(grad) + return invalid_grad_count.bool() + + def has_overflow(self, partition_gradients=True): + if partition_gradients: + overflow = self.local_overflow if self.cpu_offload else self.has_overflow_partitioned_grads_serial() + overflow_gpu = get_accelerator().ByteTensor([overflow]) if self.cpu_offload else overflow.byte().to( + get_accelerator().current_device_name()) + '''This will capture overflow across all data parallel and expert parallel process + Since expert parallel process are a subset of data parallel process''' + dist.all_reduce(overflow_gpu, op=dist.ReduceOp.MAX, group=self.dp_process_group) + + else: + params = [] + for group in self.bit16_groups: + for param in group: + params.append(param) + overflow_gpu = self.has_overflow_serial(params).byte().to(get_accelerator().current_device_name()) + + # Since each model parallel GPU carries only part of the model, + # make sure overflow flag is synced across all the model parallel GPUs + self._model_parallel_all_reduce(tensor=overflow_gpu, op=dist.ReduceOp.MAX) + + overflow = overflow_gpu[0].item() + return bool(overflow) + + # `x` is a torch.Tensor + @staticmethod + def _has_inf_or_nan(x, j=None): + float_x = x.float() + nan = float_x.isnan() + inf = float_x.isinf() + inf_or_nan = nan.logical_or(inf) + return inf_or_nan.float().max() + + def backward(self, loss, retain_graph=False): + """ + :attr:`backward` performs the following steps: + + 1. fp32_loss = loss.float() + 2. scaled_loss = fp32_loss*loss_scale + 3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's fp16 leaves + """ + self.micro_step_id += 1 + + if self.contiguous_gradients: + self.ipg_buffer = [] + buf_0 = torch.empty(int(self.reduce_bucket_size), + dtype=self.dtype, + device=get_accelerator().current_device_name()) + self.ipg_buffer.append(buf_0) + + # Use double buffers to avoid data access conflict when overlap_comm is enabled. + if self.overlap_comm: + buf_1 = torch.empty(int(self.reduce_bucket_size), + dtype=self.dtype, + device=get_accelerator().current_device_name()) + self.ipg_buffer.append(buf_1) + self.ipg_index = 0 + + if self.custom_loss_scaler: + scaled_loss = self.external_loss_scale * loss + scaled_loss.backward() + else: + self.loss_scaler.backward(loss.float(), retain_graph=retain_graph) + + # Only for Stage 1, Mode 2 + if self.use_grad_accum_attribute: + self.fill_grad_accum_attribute() + + def check_overflow(self, partition_gradients=True): + self._check_overflow(partition_gradients) + + def _update_scale(self, has_overflow=False): + self.loss_scaler.update_scale(has_overflow) + + # Promote state so it can be retrieved or set via "fp16_optimizer_instance.state" + def _get_state(self): + return self.optimizer.state + + def _set_state(self, value): + self.optimizer.state = value + + state = property(_get_state, _set_state) + + # Promote param_groups so it can be retrieved or set via "fp16_optimizer_instance.param_groups" + # (for example, to adjust the learning rate) + def _get_param_groups(self): + return self.optimizer.param_groups + + def _set_param_groups(self, value): + self.optimizer.param_groups = value + + param_groups = property(_get_param_groups, _set_param_groups) + + # Promote loss scale so it can be retrieved or set via "fp16_optimizer_instance.loss_scale" + def _get_loss_scale(self): + if self.custom_loss_scaler: + return self.external_loss_scale + else: + return self.loss_scaler.cur_scale + + def _set_loss_scale(self, value): + self.loss_scaler.cur_scale = value + + loss_scale = property(_get_loss_scale, _set_loss_scale) + cur_scale = property(_get_loss_scale, _set_loss_scale) + + # Return group tensor after removing paddings that are added for alignment to DP world size. + # This method works on the assumption that each group contains a single flattened tensor. + def _get_groups_without_padding(self, groups_with_padding): + groups_without_padding = [] + for i, group in enumerate(groups_with_padding): + lean_length = group.numel() - self.groups_padding[i] + groups_without_padding.append(group[:lean_length]) + + return groups_without_padding + + # Return optimizer state after removing paddings that are added for alignment. + def _get_state_without_padding(self, state_with_padding, padding): + lean_state = {} + for key, value in state_with_padding.items(): + if torch.is_tensor(value): + lean_length = value.numel() - padding + lean_state[key] = value[:lean_length] + else: + lean_state[key] = value + + return lean_state + + # Return base optimizer states. + # This method assumes that each param group contains a single flattened tensor. + def _get_base_optimizer_state(self): + optimizer_groups_state = [] + for i, group in enumerate(self.optimizer.param_groups): + p = group['params'][0] + lean_optimizer_state = self._get_state_without_padding(self.optimizer.state[p], self.groups_padding[i]) + optimizer_groups_state.append(lean_optimizer_state) + + return optimizer_groups_state + + def state_dict(self): + """ + Returns a dict containing the current state of this :class:`FP16_Optimizer` instance. + This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict + of the contained Pytorch optimizer. + Example:: + checkpoint = {} + checkpoint['model'] = model.state_dict() + checkpoint['optimizer'] = optimizer.state_dict() + torch.save(checkpoint, "saved.pth") + """ + state_dict = {} + state_dict[LOSS_SCALER] = self.loss_scaler + state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale + state_dict['overflow'] = self.overflow + state_dict[CLIP_GRAD] = self.clip_grad + + if self.elastic_checkpoint: + state_dict[BASE_OPTIMIZER_STATE] = self._get_base_optimizer_state() + + if "step" in self.optimizer.param_groups[0]: + # Assuming "step" is the only item that changes through training iterations + assert all(group["step"] == self.optimizer.param_groups[0]["step"] + for group in self.optimizer.param_groups), "All param groups must have the same step value" + state_dict[BASE_OPTIMIZER_STATE_STEP] = self.optimizer.param_groups[0]["step"] + else: + state_dict[BASE_OPTIMIZER_STATE] = self.optimizer.state_dict() + + # Remove paddings for DP alignment to enable loading for other alignment values + fp32_groups_without_padding = self._get_groups_without_padding(self.single_partition_of_fp32_groups) + state_dict[SINGLE_PARTITION_OF_FP32_GROUPS] = fp32_groups_without_padding + + state_dict[ + ZERO_STAGE] = ZeroStageEnum.gradients if self.partition_gradients else ZeroStageEnum.optimizer_states + state_dict[GROUP_PADDINGS] = self.groups_padding + state_dict[PARTITION_COUNT] = self.partition_count + + state_dict[DS_VERSION] = version + state_dict[PARAM_SLICE_MAPPINGS] = self._param_slice_mappings + + return state_dict + + # Restore base optimizer fp32 weights from elastic checkpoint by: + # 1) Merging fp32 weights from checkpoints of all partitions + # 2) Extracting fp32 weights for current partition from merged weights + # 3) Using extracted weights to update base optimizer weights directly. + def _restore_from_elastic_fp32_weights(self, all_state_dict): + merged_single_partition_of_fp32_groups = [] + + for i in range(len(self.single_partition_of_fp32_groups)): + partition_id = dist.get_rank(group=self.real_dp_process_group[i]) + merged_partitions = [sd[SINGLE_PARTITION_OF_FP32_GROUPS][i] for sd in all_state_dict] + if self.is_moe_group(self.optimizer.param_groups[i]): + ranks = self.get_ep_ranks(group_name=self.optimizer.param_groups[i]['name']) + merged_partitions = [merged_partitions[i] for i in ranks] + flat_merged_partitions = self.flatten_dense_tensors_aligned( + merged_partitions, + self.nccl_start_alignment_factor * dist.get_world_size(group=self.real_dp_process_group[i])) + dp_partitions = self.get_data_parallel_partitions(flat_merged_partitions, i) + merged_single_partition_of_fp32_groups.append(dp_partitions[partition_id]) + + for current, saved in zip(self.single_partition_of_fp32_groups, merged_single_partition_of_fp32_groups): + current.data.copy_(saved.data) + + # Restore base optimizer fp32 weights from ZeRO fp16 or bfloat16 weights + def _restore_from_bit16_weights(self): + for group_id, (bit16_partitions, fp32_partition) in enumerate( + zip(self.parallel_partitioned_bit16_groups, self.single_partition_of_fp32_groups)): + partition_id = dist.get_rank(group=self.real_dp_process_group[group_id]) + fp32_partition.data.copy_(bit16_partitions[partition_id].data) + + # Refresh the fp32 master params from the fp16 or bfloat16 copies. + def refresh_fp32_params(self): + self._restore_from_bit16_weights() + + # Extract optimizer state for current partition from merged states of all partitions + def _partition_base_optimizer_state(self, state_key, all_partition_states, group_id): + partition_id = dist.get_rank(group=self.real_dp_process_group[group_id]) + alignment = self.nccl_start_alignment_factor * dist.get_world_size(group=self.real_dp_process_group[group_id]) + if torch.is_tensor(all_partition_states[0]): + flat_merged_partitions = self.flatten_dense_tensors_aligned(all_partition_states, alignment) + dp_partitions = self.get_data_parallel_partitions(flat_merged_partitions, group_id) + return dp_partitions[partition_id] + else: + # Assume non-tensor states are not partitioned and equal across ranks, so return first one + return all_partition_states[0] + + def _restore_step_from_elastic_checkpoint(self, all_state_dict): + assert BASE_OPTIMIZER_STATE_STEP in all_state_dict[0] + assert all(sd[BASE_OPTIMIZER_STATE_STEP] == all_state_dict[0][BASE_OPTIMIZER_STATE_STEP] + for sd in all_state_dict), "State dicts of all partitions must have the same step value" + return all_state_dict[0][BASE_OPTIMIZER_STATE_STEP] + + def _restore_base_optimizer_state(self, base_optimizer_group_states, base_optimizer_state_step, group_paddings): + if type(base_optimizer_group_states) == dict: + base_optimizer_group_states = base_optimizer_group_states['state'] + + saved_keys = base_optimizer_group_states[0].keys() + + for i, group in enumerate(self.optimizer.param_groups): + p = group['params'][0] + padding = 0 if group_paddings is None else group_paddings[i] + for key in saved_keys: + saved = base_optimizer_group_states[i][key] + + if torch.is_tensor(saved): + if key in self.optimizer.state[p]: + dst_tensor = self.optimizer.state[p][key] + src_tensor = _get_padded_tensor(saved, dst_tensor.numel()) + self.optimizer.state[p][key].data.copy_(src_tensor.data) + else: + self.optimizer.state[p][key] = _pad_tensor_by_size( + saved, padding, torch.float32, + torch.device('cpu') if self.cpu_offload else self.device) + else: + self.optimizer.state[p][key] = saved + + for param_group in self.optimizer.param_groups: + param_group['step'] = base_optimizer_state_step + + def get_ep_ranks(self, rank=0, group_name=None): + from deepspeed.utils import groups + expert_parallel_size_ = groups._get_expert_parallel_world_size(group_name) + world_size = groups._get_data_parallel_world_size() + rank = groups._get_expert_parallel_rank(group_name) + ranks = range(rank, world_size, expert_parallel_size_) + return list(ranks) + + # Restore base optimizer state from elastic checkpoint by + # 1) Merging optimizer state from checkpoints of all partitions + # 2) Extracting optimizer state for current partition from the merged state + # 3) Using the extracted value to directly update the base optimizer. + def _restore_elastic_base_optimizer_state(self, all_state_dict): + base_optimizer_group_states = [] + for i in range(len(self.optimizer.param_groups)): + partition_states = {} + all_partition_group_states = [sd[BASE_OPTIMIZER_STATE][i] for sd in all_state_dict] + + if self.is_moe_group(self.optimizer.param_groups[i]): + ranks = self.get_ep_ranks(group_name=self.optimizer.param_groups[i]['name']) + all_partition_group_states = [all_partition_group_states[i] for i in ranks] + + for key in all_partition_group_states[0].keys(): + all_partition_states = [all_states[key] for all_states in all_partition_group_states] + partition_states[key] = self._partition_base_optimizer_state(key, all_partition_states, i) + base_optimizer_group_states.append(partition_states) + + self._restore_base_optimizer_state(base_optimizer_group_states, + self._restore_step_from_elastic_checkpoint(all_state_dict), None) + + def load_state_dict(self, + state_dict_list, + load_optimizer_states=True, + load_from_fp32_weights=False, + checkpoint_folder=None, + load_serial=None, + param_shapes=None): + if checkpoint_folder: + self._load_universal_checkpoint(checkpoint_folder, load_optimizer_states, load_from_fp32_weights) + else: + self._load_legacy_checkpoint(state_dict_list, load_optimizer_states, load_from_fp32_weights) + + def _load_universal_checkpoint(self, checkpoint_folder, load_optimizer_states, load_from_fp32_weights): + self.load_hp_checkpoint_state_from_checkpoint_dir("bit16_groups", checkpoint_folder) + + @property + def param_groups(self): + """Forward the wrapped optimizer's parameters.""" + return self.optimizer.param_groups + + def _load_global_state(self, sd): + self.loss_scaler = sd.get(LOSS_SCALER, self.loss_scaler) + self.dynamic_loss_scale = sd.get('dynamic_loss_scale', self.dynamic_loss_scale) + self.overflow = sd.get('overflow', self.overflow) + self.clip_grad = sd.get(CLIP_GRAD, self.clip_grad) + + ckpt_version = sd.get(DS_VERSION, False) + assert ckpt_version, f"Empty ds_version in checkpoint, not clear how to proceed" + ckpt_version = pkg_version.parse(ckpt_version) + + # zero stage 1 mode + if not self.partition_gradients: + required_version = pkg_version.parse("0.3.17") + error_str = f"ZeRO stage 1 changed in {required_version} and is not backwards compatible " \ + "with older stage 1 checkpoints. If you'd like to load an old ZeRO-1 checkpoint " \ + "please use an older version of DeepSpeed (<= 0.5.8) and set 'legacy_stage1': true in your zero config json." + assert required_version <= ckpt_version, f"Old version: {ckpt_version} {error_str}" + + def _load_legacy_checkpoint(self, state_dict_list, load_optimizer_states=True, load_from_fp32_weights=False): + r"""Loading ZeRO checkpoint + + Arguments: + state_dict_list: List of all saved ZeRO checkpoints, one for each saved partition. + Note that the number of saved partitions may differ from number of loading partitions to support + changing GPU count, specifically DP world size, between saving and loading checkpoints. + load_optimizer_states: Boolean indicating whether or not to load base optimizer states + load_from_fp32_weights: Boolean indicating whether to initialize fp32 master weights from fp32 + copies in checkpoints (no precision loss) or from model's fp16 copies (with precision loss). + """ + """ + Loads a state_dict created by an earlier call to state_dict(). + If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``, + whose parameters in turn came from ``model``, it is expected that the user + will call ``model.load_state_dict()`` before + ``fp16_optimizer_instance.load_state_dict()`` is called. + Example:: + model = torch.nn.Linear(D_in, D_out).to(get_accelerator().device_name()).half() + optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) + optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0) + ... + checkpoint = torch.load("saved.pth") + model.load_state_dict(checkpoint['model']) + optimizer.load_state_dict(checkpoint['optimizer']) + """ + + # I think it should actually be ok to reload the optimizer before the model. + dp_rank = dist.get_rank(group=self.dp_process_group) + current_rank_sd = state_dict_list[dp_rank] + self._load_global_state(current_rank_sd) + + ckpt_is_rigid = isinstance(current_rank_sd[BASE_OPTIMIZER_STATE], dict) + + # padding is always at the last rank/partition + # if DP=1024 and param-group elems=16 -> padding will be 1024-16 across all but one rank + # scenario-1 (shrink): saving w. 4 gpus -> loading w. 2 gpus + # scenario-2 (expand): saving w. 2 gpus -> loading w. 4 gpus + # if load_optimizer_states: + # if new_dp_size: + # self.strip_padding() + # self.add_padding_w_new_dp_size() + # self.optimizer.load_state_dict(current_rank_sd[BASE_OPTIMIZER_STATE]) + + if load_optimizer_states: + if ckpt_is_rigid: + # loading rigid ckpt into either rigid or elastic exec + self.optimizer.load_state_dict(current_rank_sd[BASE_OPTIMIZER_STATE]) + else: + if self.elastic_checkpoint: + # loading elastic into elastic exec + self._restore_elastic_base_optimizer_state(state_dict_list) + else: + # loading an elastic checkpoint into rigid exec + self._restore_base_optimizer_state(current_rank_sd[BASE_OPTIMIZER_STATE], + current_rank_sd[BASE_OPTIMIZER_STATE_STEP], + current_rank_sd[GROUP_PADDINGS]) + + # At this point, the optimizer's references to the model's fp32 parameters are up to date. + # The optimizer's hyperparameters and internal buffers are also up to date. + # However, the fp32 master copies of the model's fp16 params stored by the optimizer are still + # out of date. There are two options. + # 1: Refresh the master params from the model's fp16 params. + # This requires less storage but incurs precision loss. + # 2: Save and restore the fp32 master copies separately. + # We choose option 1 if changing DP degree and option 2 otherwise. + # + # Pytorch Optimizer.load_state_dict casts saved buffers (e.g. momentum) to the type and device + # of their associated parameters, because it's possible those buffers might not exist yet in + # the current optimizer instance. In our case, as long as the current FP16_Optimizer has been + # constructed in the same way as the one whose state_dict we are loading, the same master params + # are guaranteed to exist, so we can just copy_() from the saved master params. + + if load_from_fp32_weights: + # option 2 from above + if self.elastic_checkpoint and not ckpt_is_rigid: + self._restore_from_elastic_fp32_weights(state_dict_list) + else: + # For non-elastic checkpoint, simply copying from saved weights of current rank is sufficient. + for current, saved in zip(self.single_partition_of_fp32_groups, + current_rank_sd[SINGLE_PARTITION_OF_FP32_GROUPS]): + src_tensor = _get_padded_tensor(saved, current.numel()) + current.data.copy_(src_tensor.data) + else: + # option 1 from above + self._restore_from_bit16_weights() + + if load_optimizer_states: + self._link_all_hp_params() + + +def _handle_overflow(cpu_sum, x, i): + import math + rank = dist.get_rank() + if rank == 0: + t_i = -1 + for v_i, v in enumerate(x.data.contiguous().view(-1)): + if not math.isfinite(float(v)): + t_i = v_i + break + logger.info(f"rank {rank} detected overflow {cpu_sum} in tensor {i}:{t_i} shape {x.shape}") + + +def estimate_zero2_model_states_mem_needs(total_params, + num_gpus_per_node=1, + num_nodes=1, + cpu_offload=True, + additional_buffer_factor=1.5): + + total_gpus = num_nodes * num_gpus_per_node + + if cpu_offload: + gpu_mem = 2 * total_params + cpu_mem = total_params * max(4 * total_gpus, 16) * additional_buffer_factor + else: + # GPU's total_params multipliers: 2 = params_16bit, + # 18 = 2_grads_16bit + 4_grads_32bit + 4_params_32bit + 8_optimizer_states_32bit(momentum and variance) + gpu_mem = 2 * total_params + int(18 * total_params / total_gpus) + cpu_mem = total_params * 4 * num_gpus_per_node * additional_buffer_factor + + return int(cpu_mem), int(gpu_mem) + + +def model_to_params(model): + # shared params calculated only once + total_params = sum(dict((p.data_ptr(), p.numel()) for p in model.parameters()).values()) + return total_params + + +def estimate_zero2_model_states_mem_needs_all_live(model, + num_gpus_per_node=1, + num_nodes=1, + additional_buffer_factor=1.5): + """ + Print out estimates on memory usage requirements for ZeRO 2 params, optim states and gradients + for a given ``model`` and hardware setup. + + If you have an actual model object, use this function and everything will be derived + automatically. + + If it's a hypothetical model, use ``estimate_zero2_model_states_mem_needs_all_cold`` where you have to pass + the ``total_params`` explicitly. + + Args: + - ``model``: ``nn.Module`` object + - ``num_gpus_per_node``: how many gpus per node (defaults to 1) + - ``num_nodes``: how many nodes (defaults to 1), + - ``additional_buffer_factor``: estimation factor (defaults to 1.5): + + """ + + total_params = model_to_params(model) + + estimate_zero2_model_states_mem_needs_all_cold(total_params=total_params, + num_gpus_per_node=num_gpus_per_node, + num_nodes=num_nodes, + additional_buffer_factor=additional_buffer_factor) + + +def estimate_zero2_model_states_mem_needs_all_cold(total_params, + num_gpus_per_node=1, + num_nodes=1, + additional_buffer_factor=1.5): + """ + Print out estimates on memory usage requirements for ZeRO 2 params, optim states and gradients + for a given ``model`` and hardware setup. + + If it's a hypothetical model, use this function where you have to pass + the ``total_params`` and ``largest_layer_params`` explicitly. + + If you have an actual model object, use ``estimate_zero2_model_states_mem_needs_all_live`` and everything + will be derived automatically. + + Args: + - ``total_params``: total model params + - ``num_gpus_per_node``: how many gpus per node (defaults to 1) + - ``num_nodes``: how many nodes (defaults to 1), + - ``additional_buffer_factor``: estimation factor (defaults to 1.5): + + """ + + def format_options(cpu_offload): + enabled = [] + device = f'{OffloadDeviceEnum.cpu:4}' if cpu_offload else "none" + enabled.append(f"offload_optimizer={device}") + return ", ".join(enabled) + + nodes_str = "nodes" if num_nodes > 1 else "node" + gpus_str = "GPUs" if num_gpus_per_node > 1 else "GPU" + print("Estimated memory needed for params, optim states and gradients for a:\n" + f"HW: Setup with {num_nodes} {nodes_str}, {num_gpus_per_node} {gpus_str} per node.\n" + f"SW: Model with {int(total_params/1e6)}M total params.") + print(" per CPU | per GPU | Options") + for cpu_offload in [True, False]: + cpu_mem, gpu_mem = estimate_zero2_model_states_mem_needs(total_params=total_params, + num_gpus_per_node=num_gpus_per_node, + num_nodes=num_nodes, + cpu_offload=cpu_offload, + additional_buffer_factor=additional_buffer_factor) + + options_str = format_options(cpu_offload=cpu_offload) + print(f" {cpu_mem/2**30:7.2f}GB | {gpu_mem/2**30:6.2f}GB | {options_str}") diff --git a/toolbox/DeepSpeed/v0.15.3/patches/install_deepspeed.sh b/toolbox/DeepSpeed/v0.15.3/patches/install_deepspeed.sh new file mode 100644 index 0000000000000000000000000000000000000000..409cae21daa8eb4485fbb0d379f9ed67f746bd13 --- /dev/null +++ b/toolbox/DeepSpeed/v0.15.3/patches/install_deepspeed.sh @@ -0,0 +1,42 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +#!/bin/bash + +TARGET_DIR=${TARGET_DIR:-} + +PYTHON_PATH=$(which python3) +PYTHON_DIST_PATH=${TARGET_DIR}/lib/python3/dist-packages + +PKG_DIR="build_pip" +PKG_NAME="deepspeed" + +if [[ ! -d ${PKG_DIR} ]]; then + echo "ERROR: Package directory ${PKG_DIR} doesn't exist" + exit 1 +fi + +latest_pkg="$(ls -t ${PKG_DIR} | grep ${PKG_NAME} | head -1)" +if [[ "${latest_pkg}" == "" ]]; then + echo "ERROR: Cannot find latest ${PKG_NAME} package" + exit 1 +else + echo "INFO: Found latest package ${latest_pkg} in directory ${PKG_DIR}" +fi + +${PYTHON_PATH} -m pip uninstall ${PKG_NAME} -y +${PYTHON_PATH} -m pip install ${PKG_DIR}/${latest_pkg} || exit + +# Return 0 status if all finished +exit 0 \ No newline at end of file diff --git a/toolbox/DeepSpeed/v0.15.3/patches/op_builder/builder.py b/toolbox/DeepSpeed/v0.15.3/patches/op_builder/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..9fb88f2f3e35c87b063b5be4bebfcca0c45e2e75 --- /dev/null +++ b/toolbox/DeepSpeed/v0.15.3/patches/op_builder/builder.py @@ -0,0 +1,856 @@ +#!/usr/bin/env python3 +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import os +import re +import sys +import time +import importlib +from pathlib import Path +import subprocess +import shlex +import shutil +import tempfile +import distutils.ccompiler +import distutils.log +import distutils.sysconfig +from distutils.errors import CompileError, LinkError +from abc import ABC, abstractmethod +from typing import List + +YELLOW = '\033[93m' +END = '\033[0m' +WARNING = f"{YELLOW} [WARNING] {END}" + +DEFAULT_TORCH_EXTENSION_PATH = "/tmp/torch_extensions" +DEFAULT_COMPUTE_CAPABILITIES = "6.0;6.1;7.0" + +try: + import torch +except ImportError: + print(f"{WARNING} unable to import torch, please install it if you want to pre-compile any deepspeed ops.") +else: + TORCH_MAJOR = int(torch.__version__.split('.')[0]) + TORCH_MINOR = int(torch.__version__.split('.')[1]) + + +class MissingCUDAException(Exception): + pass + + +class CUDAMismatchException(Exception): + pass + + +def installed_cuda_version(name=""): + import torch.utils.cpp_extension + cuda_home = torch.utils.cpp_extension.CUDA_HOME + if cuda_home is None: + raise MissingCUDAException("CUDA_HOME does not exist, unable to compile CUDA op(s)") + # Ensure there is not a cuda version mismatch between torch and nvcc compiler + output = subprocess.check_output([cuda_home + "/bin/nvcc", "-V"], universal_newlines=True) + output_split = output.split() + release_idx = output_split.index("release") + release = output_split[release_idx + 1].replace(',', '').split(".") + # Ignore patch versions, only look at major + minor + cuda_major, cuda_minor = release[:2] + return int(cuda_major), int(cuda_minor) + + +def get_default_compute_capabilities(): + compute_caps = DEFAULT_COMPUTE_CAPABILITIES + import torch.utils.cpp_extension + if torch.utils.cpp_extension.CUDA_HOME is not None and installed_cuda_version()[0] >= 11: + if installed_cuda_version()[0] == 11 and installed_cuda_version()[1] == 0: + # Special treatment of CUDA 11.0 because compute_86 is not supported. + compute_caps += ";8.0" + else: + compute_caps += ";8.0;8.6" + return compute_caps + + +# list compatible minor CUDA versions - so that for example pytorch built with cuda-11.0 can be used +# to build deepspeed and system-wide installed cuda 11.2 +cuda_minor_mismatch_ok = { + 10: ["10.0", "10.1", "10.2"], + 11: ["11.0", "11.1", "11.2", "11.3", "11.4", "11.5", "11.6", "11.7", "11.8"], + 12: ["12.0", "12.1", "12.2", "12.3", "12.4", "12.5", "12.6"], +} + + +def assert_no_cuda_mismatch(name=""): + cuda_major, cuda_minor = installed_cuda_version(name) + sys_cuda_version = f'{cuda_major}.{cuda_minor}' + torch_cuda_version = ".".join(torch.version.cuda.split('.')[:2]) + # This is a show-stopping error, should probably not proceed past this + if sys_cuda_version != torch_cuda_version: + if (cuda_major in cuda_minor_mismatch_ok and sys_cuda_version in cuda_minor_mismatch_ok[cuda_major] + and torch_cuda_version in cuda_minor_mismatch_ok[cuda_major]): + print(f"Installed CUDA version {sys_cuda_version} does not match the " + f"version torch was compiled with {torch.version.cuda} " + "but since the APIs are compatible, accepting this combination") + return True + elif os.getenv("DS_SKIP_CUDA_CHECK", "0") == "1": + print( + f"{WARNING} DeepSpeed Op Builder: Installed CUDA version {sys_cuda_version} does not match the " + f"version torch was compiled with {torch.version.cuda}." + "Detected `DS_SKIP_CUDA_CHECK=1`: Allowing this combination of CUDA, but it may result in unexpected behavior." + ) + return True + raise CUDAMismatchException( + f">- DeepSpeed Op Builder: Installed CUDA version {sys_cuda_version} does not match the " + f"version torch was compiled with {torch.version.cuda}, unable to compile " + "cuda/cpp extensions without a matching cuda version.") + return True + + +class OpBuilder(ABC): + _rocm_version = None + _rocm_gpu_arch = None + _rocm_wavefront_size = None + _is_rocm_pytorch = None + _is_sycl_enabled = None + _loaded_ops = {} + + def __init__(self, name): + self.name = name + self.jit_mode = False + self.build_for_cpu = False + self.enable_bf16 = False + self.error_log = None + + @abstractmethod + def absolute_name(self): + ''' + Returns absolute build path for cases where the op is pre-installed, e.g., deepspeed.ops.adam.cpu_adam + will be installed as something like: deepspeed/ops/adam/cpu_adam.so + ''' + pass + + @abstractmethod + def sources(self): + ''' + Returns list of source files for your op, relative to root of deepspeed package (i.e., DeepSpeed/deepspeed) + ''' + pass + + def hipify_extension(self): + pass + + def sycl_extension(self): + pass + + @staticmethod + def validate_torch_version(torch_info): + install_torch_version = torch_info['version'] + current_torch_version = ".".join(torch.__version__.split('.')[:2]) + if install_torch_version != current_torch_version: + raise RuntimeError("PyTorch version mismatch! DeepSpeed ops were compiled and installed " + "with a different version than what is being used at runtime. " + f"Please re-install DeepSpeed or switch torch versions. " + f"Install torch version={install_torch_version}, " + f"Runtime torch version={current_torch_version}") + + @staticmethod + def validate_torch_op_version(torch_info): + if not OpBuilder.is_rocm_pytorch(): + current_cuda_version = ".".join(torch.version.cuda.split('.')[:2]) + install_cuda_version = torch_info['cuda_version'] + if install_cuda_version != current_cuda_version: + raise RuntimeError("CUDA version mismatch! DeepSpeed ops were compiled and installed " + "with a different version than what is being used at runtime. " + f"Please re-install DeepSpeed or switch torch versions. " + f"Install CUDA version={install_cuda_version}, " + f"Runtime CUDA version={current_cuda_version}") + else: + current_hip_version = ".".join(torch.version.hip.split('.')[:2]) + install_hip_version = torch_info['hip_version'] + if install_hip_version != current_hip_version: + raise RuntimeError("HIP version mismatch! DeepSpeed ops were compiled and installed " + "with a different version than what is being used at runtime. " + f"Please re-install DeepSpeed or switch torch versions. " + f"Install HIP version={install_hip_version}, " + f"Runtime HIP version={current_hip_version}") + + @staticmethod + def is_rocm_pytorch(): + if OpBuilder._is_rocm_pytorch is not None: + return OpBuilder._is_rocm_pytorch + + _is_rocm_pytorch = False + try: + import torch + except ImportError: + pass + else: + if TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 5): + _is_rocm_pytorch = hasattr(torch.version, 'hip') and torch.version.hip is not None + if _is_rocm_pytorch: + from torch.utils.cpp_extension import ROCM_HOME + _is_rocm_pytorch = ROCM_HOME is not None + OpBuilder._is_rocm_pytorch = _is_rocm_pytorch + return OpBuilder._is_rocm_pytorch + + @staticmethod + def is_sycl_enabled(): + if OpBuilder._is_sycl_enabled is not None: + return OpBuilder._is_sycl_enabled + + _is_sycl_enabled = False + try: + result = subprocess.run(["c2s", "--version"], capture_output=True) + except: + pass + else: + _is_sycl_enabled = True + + OpBuilder._is_sycl_enabled = _is_sycl_enabled + return OpBuilder._is_sycl_enabled + + @staticmethod + def installed_rocm_version(): + if OpBuilder._rocm_version: + return OpBuilder._rocm_version + + ROCM_MAJOR = '0' + ROCM_MINOR = '0' + ROCM_VERSION_DEV_RAW = "" + if OpBuilder.is_rocm_pytorch(): + from torch.utils.cpp_extension import ROCM_HOME + rocm_ver_file = Path(ROCM_HOME).joinpath(".info/version") + if rocm_ver_file.is_file(): + with open(rocm_ver_file, 'r') as file: + ROCM_VERSION_DEV_RAW = file.read() + elif "rocm" in torch.__version__: + ROCM_VERSION_DEV_RAW = torch.__version__.split("rocm")[1] + if ROCM_VERSION_DEV_RAW != "": + ROCM_MAJOR = ROCM_VERSION_DEV_RAW.split('.')[0] + ROCM_MINOR = ROCM_VERSION_DEV_RAW.split('.')[1] + else: + # Look in /usr/include/rocm-version.h + rocm_ver_file = Path("/usr/include/rocm_version.h") + if rocm_ver_file.is_file(): + with open(rocm_ver_file, 'r') as file: + for ln in file.readlines(): + if "#define ROCM_VERSION_MAJOR" in ln: + ROCM_MAJOR = re.findall(r'\S+', ln)[2] + elif "#define ROCM_VERSION_MINOR" in ln: + ROCM_MINOR = re.findall(r'\S+', ln)[2] + if ROCM_MAJOR == '0': + assert False, "Could not detect ROCm version" + + OpBuilder._rocm_version = (int(ROCM_MAJOR), int(ROCM_MINOR)) + return OpBuilder._rocm_version + + @staticmethod + def get_rocm_gpu_arch(): + if OpBuilder._rocm_gpu_arch: + return OpBuilder._rocm_gpu_arch + rocm_info = Path("/opt/rocm/bin/rocminfo") + if (not rocm_info.is_file()): + rocm_info = Path("rocminfo") + rocm_gpu_arch_cmd = str(rocm_info) + " | grep -o -m 1 'gfx.*'" + try: + result = subprocess.check_output(rocm_gpu_arch_cmd, shell=True) + rocm_gpu_arch = result.decode('utf-8').strip() + except subprocess.CalledProcessError: + rocm_gpu_arch = "" + OpBuilder._rocm_gpu_arch = rocm_gpu_arch + return OpBuilder._rocm_gpu_arch + + @staticmethod + def get_rocm_wavefront_size(): + if OpBuilder._rocm_wavefront_size: + return OpBuilder._rocm_wavefront_size + + rocm_info = Path("/opt/rocm/bin/rocminfo") + if (not rocm_info.is_file()): + rocm_info = Path("rocminfo") + rocm_wavefront_size_cmd = str( + rocm_info) + " | grep -Eo -m1 'Wavefront Size:[[:space:]]+[0-9]+' | grep -Eo '[0-9]+'" + try: + result = subprocess.check_output(rocm_wavefront_size_cmd, shell=True) + rocm_wavefront_size = result.decode('utf-8').strip() + except subprocess.CalledProcessError: + rocm_wavefront_size = "32" + OpBuilder._rocm_wavefront_size = rocm_wavefront_size + return OpBuilder._rocm_wavefront_size + + def include_paths(self): + ''' + Returns list of include paths, relative to root of deepspeed package (i.e., DeepSpeed/deepspeed) + ''' + return [] + + def nvcc_args(self): + ''' + Returns optional list of compiler flags to forward to nvcc when building CUDA sources + ''' + return [] + + def cxx_args(self): + ''' + Returns optional list of compiler flags to forward to the build + ''' + return [] + + def is_compatible(self, verbose=False): + ''' + Check if all non-python dependencies are satisfied to build this op + ''' + return True + + def extra_ldflags(self): + return [] + + def has_function(self, funcname, libraries, library_dirs=None, verbose=False): + ''' + Test for existence of a function within a tuple of libraries. + + This is used as a smoke test to check whether a certain library is available. + As a test, this creates a simple C program that calls the specified function, + and then distutils is used to compile that program and link it with the specified libraries. + Returns True if both the compile and link are successful, False otherwise. + ''' + tempdir = None # we create a temporary directory to hold various files + filestderr = None # handle to open file to which we redirect stderr + oldstderr = None # file descriptor for stderr + try: + # Echo compile and link commands that are used. + if verbose: + distutils.log.set_verbosity(1) + + # Create a compiler object. + compiler = distutils.ccompiler.new_compiler(verbose=verbose) + + # Configure compiler and linker to build according to Python install. + distutils.sysconfig.customize_compiler(compiler) + + # Create a temporary directory to hold test files. + tempdir = tempfile.mkdtemp() + + # Define a simple C program that calls the function in question + prog = "void %s(void); int main(int argc, char** argv) { %s(); return 0; }" % (funcname, funcname) + + # Write the test program to a file. + filename = os.path.join(tempdir, 'test.c') + with open(filename, 'w') as f: + f.write(prog) + + # Redirect stderr file descriptor to a file to silence compile/link warnings. + if not verbose: + filestderr = open(os.path.join(tempdir, 'stderr.txt'), 'w') + oldstderr = os.dup(sys.stderr.fileno()) + os.dup2(filestderr.fileno(), sys.stderr.fileno()) + + # Workaround for behavior in distutils.ccompiler.CCompiler.object_filenames() + # Otherwise, a local directory will be used instead of tempdir + drive, driveless_filename = os.path.splitdrive(filename) + root_dir = driveless_filename[0] if os.path.isabs(driveless_filename) else '' + output_dir = os.path.join(drive, root_dir) + + # Attempt to compile the C program into an object file. + cflags = shlex.split(os.environ.get('CFLAGS', "")) + objs = compiler.compile([filename], output_dir=output_dir, extra_preargs=self.strip_empty_entries(cflags)) + + # Attempt to link the object file into an executable. + # Be sure to tack on any libraries that have been specified. + ldflags = shlex.split(os.environ.get('LDFLAGS', "")) + compiler.link_executable(objs, + os.path.join(tempdir, 'a.out'), + extra_preargs=self.strip_empty_entries(ldflags), + libraries=libraries, + library_dirs=library_dirs) + + # Compile and link succeeded + return True + + except CompileError: + return False + + except LinkError: + return False + + except: + return False + + finally: + # Restore stderr file descriptor and close the stderr redirect file. + if oldstderr is not None: + os.dup2(oldstderr, sys.stderr.fileno()) + if filestderr is not None: + filestderr.close() + + # Delete the temporary directory holding the test program and stderr files. + if tempdir is not None: + shutil.rmtree(tempdir) + + def strip_empty_entries(self, args): + ''' + Drop any empty strings from the list of compile and link flags + ''' + return [x for x in args if len(x) > 0] + + def cpu_arch(self): + try: + from cpuinfo import get_cpu_info + except ImportError as e: + cpu_info = self._backup_cpuinfo() + if cpu_info is None: + return "-march=native" + + try: + cpu_info = get_cpu_info() + except Exception as e: + self.warning(f"{self.name} attempted to use `py-cpuinfo` but failed (exception type: {type(e)}, {e}), " + "falling back to `lscpu` to get this information.") + cpu_info = self._backup_cpuinfo() + if cpu_info is None: + return "-march=native" + + if cpu_info['arch'].startswith('PPC_'): + # gcc does not provide -march on PowerPC, use -mcpu instead + return '-mcpu=native' + return '-march=native' + + def is_cuda_enable(self): + try: + assert_no_cuda_mismatch(self.name) + return '-D__ENABLE_CUDA__' + except MissingCUDAException: + print(f"{WARNING} {self.name} cuda is missing or is incompatible with installed torch, " + "only cpu ops can be compiled!") + return '-D__DISABLE_CUDA__' + return '-D__DISABLE_CUDA__' + + def _backup_cpuinfo(self): + # Construct cpu_info dict from lscpu that is similar to what py-cpuinfo provides + if not self.command_exists('lscpu'): + self.warning(f"{self.name} attempted to query 'lscpu' after failing to use py-cpuinfo " + "to detect the CPU architecture. 'lscpu' does not appear to exist on " + "your system, will fall back to use -march=native and non-vectorized execution.") + return None + result = subprocess.check_output(['lscpu']) + result = result.decode('utf-8').strip().lower() + + cpu_info = {} + cpu_info['arch'] = None + cpu_info['flags'] = "" + if 'genuineintel' in result or 'authenticamd' in result: + cpu_info['arch'] = 'X86_64' + if 'avx512' in result: + cpu_info['flags'] += 'avx512,' + elif 'avx512f' in result: + cpu_info['flags'] += 'avx512f,' + if 'avx2' in result: + cpu_info['flags'] += 'avx2' + elif 'ppc64le' in result: + cpu_info['arch'] = "PPC_" + + return cpu_info + + def simd_width(self): + try: + from cpuinfo import get_cpu_info + except ImportError as e: + cpu_info = self._backup_cpuinfo() + if cpu_info is None: + return '-D__SCALAR__' + + try: + cpu_info = get_cpu_info() + except Exception as e: + self.warning(f"{self.name} attempted to use `py-cpuinfo` but failed (exception type: {type(e)}, {e}), " + "falling back to `lscpu` to get this information.") + cpu_info = self._backup_cpuinfo() + if cpu_info is None: + return '-D__SCALAR__' + + if cpu_info['arch'] == 'X86_64': + if 'avx512' in cpu_info['flags'] or 'avx512f' in cpu_info['flags']: + return '-D__AVX512__' + elif 'avx2' in cpu_info['flags']: + return '-D__AVX256__' + return '-D__SCALAR__' + + def command_exists(self, cmd): + if '|' in cmd: + cmds = cmd.split("|") + else: + cmds = [cmd] + valid = False + for cmd in cmds: + safe_cmd = ["bash", "-c", f"type {cmd}"] + result = subprocess.Popen(safe_cmd, stdout=subprocess.PIPE) + valid = valid or result.wait() == 0 + + if not valid and len(cmds) > 1: + print(f"{WARNING} {self.name} requires one of the following commands '{cmds}', but it does not exist!") + elif not valid and len(cmds) == 1: + print(f"{WARNING} {self.name} requires the '{cmd}' command, but it does not exist!") + return valid + + def warning(self, msg): + self.error_log = f"{msg}" + print(f"{WARNING} {msg}") + + def deepspeed_src_path(self, code_path): + if os.path.isabs(code_path): + return code_path + else: + return os.path.join(Path(__file__).parent.parent.absolute(), code_path) + + def builder(self): + from torch.utils.cpp_extension import CppExtension + include_dirs = [os.path.abspath(x) for x in self.strip_empty_entries(self.include_paths())] + return CppExtension(name=self.absolute_name(), + sources=self.strip_empty_entries(self.sources()), + include_dirs=include_dirs, + extra_compile_args={'cxx': self.strip_empty_entries(self.cxx_args())}, + extra_link_args=self.strip_empty_entries(self.extra_ldflags())) + + def load(self, verbose=True): + if self.name in __class__._loaded_ops: + return __class__._loaded_ops[self.name] + + + from deepspeed.git_version_info import installed_ops, torch_info, accelerator_name + from deepspeed.accelerator import get_accelerator + if installed_ops.get(self.name, False): + + # Ensure the op we're about to load was compiled with the same + # torch/cuda versions we are currently using at runtime. + self.validate_torch_version(torch_info) + if torch.cuda.is_available() and isinstance(self, CUDAOpBuilder): + self.validate_torch_op_version(torch_info) + + op_module = importlib.import_module(self.absolute_name()) + __class__._loaded_ops[self.name] = op_module + return op_module + else: + return self.jit_load(verbose) + + def jit_load(self, verbose=True): + if not self.is_compatible(verbose): + raise RuntimeError( + f"Unable to JIT load the {self.name} op due to it not being compatible due to hardware/software issue. {self.error_log}" + ) + try: + import ninja # noqa: F401 # type: ignore + except ImportError: + raise RuntimeError(f"Unable to JIT load the {self.name} op due to ninja not being installed.") + + if isinstance(self, CUDAOpBuilder) and not self.is_rocm_pytorch(): + self.build_for_cpu = not torch.cuda.is_available() + + self.jit_mode = True + from torch.utils.cpp_extension import load + + start_build = time.time() + sources = [os.path.abspath(self.deepspeed_src_path(path)) for path in self.sources()] + extra_include_paths = [os.path.abspath(self.deepspeed_src_path(path)) for path in self.include_paths()] + + # Torch will try and apply whatever CCs are in the arch list at compile time, + # we have already set the intended targets ourselves we know that will be + # needed at runtime. This prevents CC collisions such as multiple __half + # implementations. Stash arch list to reset after build. + torch_arch_list = None + if "TORCH_CUDA_ARCH_LIST" in os.environ: + torch_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST") + os.environ["TORCH_CUDA_ARCH_LIST"] = "" + + nvcc_args = self.strip_empty_entries(self.nvcc_args()) + cxx_args = self.strip_empty_entries(self.cxx_args()) + + if isinstance(self, CUDAOpBuilder): + if not self.build_for_cpu and self.enable_bf16: + cxx_args.append("-DBF16_AVAILABLE") + nvcc_args.append("-DBF16_AVAILABLE") + nvcc_args.append("-U__CUDA_NO_BFLOAT16_OPERATORS__") + nvcc_args.append("-U__CUDA_NO_BFLOAT162_OPERATORS__") + nvcc_args.append("-U__CUDA_NO_BFLOAT16_CONVERSIONS__") + + if self.is_rocm_pytorch(): + cxx_args.append("-D__HIP_PLATFORM_AMD__=1") + os.environ["PYTORCH_ROCM_ARCH"] = self.get_rocm_gpu_arch() + cxx_args.append('-DROCM_WAVEFRONT_SIZE=%s' % self.get_rocm_wavefront_size()) + + op_module = load(name=self.name, + sources=self.strip_empty_entries(sources), + extra_include_paths=self.strip_empty_entries(extra_include_paths), + extra_cflags=cxx_args, + extra_cuda_cflags=nvcc_args, + extra_ldflags=self.strip_empty_entries(self.extra_ldflags()), + verbose=verbose) + + build_duration = time.time() - start_build + if verbose: + print(f"Time to load {self.name} op: {build_duration} seconds") + + # Reset arch list so we are not silently removing it for other possible use cases + if torch_arch_list: + os.environ["TORCH_CUDA_ARCH_LIST"] = torch_arch_list + + __class__._loaded_ops[self.name] = op_module + + return op_module + + +class CUDAOpBuilder(OpBuilder): + + def compute_capability_args(self, cross_compile_archs=None): + """ + Returns nvcc compute capability compile flags. + + 1. `TORCH_CUDA_ARCH_LIST` takes priority over `cross_compile_archs`. + 2. If neither is set default compute capabilities will be used + 3. Under `jit_mode` compute capabilities of all visible cards will be used plus PTX + + Format: + + - `TORCH_CUDA_ARCH_LIST` may use ; or whitespace separators. Examples: + + TORCH_CUDA_ARCH_LIST="6.1;7.5;8.6" pip install ... + TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0 7.5 8.0 8.6+PTX" pip install ... + + - `cross_compile_archs` uses ; separator. + + """ + ccs = [] + if self.jit_mode: + # Compile for underlying architectures since we know those at runtime + for i in range(torch.cuda.device_count()): + CC_MAJOR, CC_MINOR = torch.cuda.get_device_capability(i) + cc = f"{CC_MAJOR}.{CC_MINOR}" + if cc not in ccs: + ccs.append(cc) + ccs = sorted(ccs) + ccs[-1] += '+PTX' + else: + # Cross-compile mode, compile for various architectures + # env override takes priority + cross_compile_archs_env = os.environ.get('TORCH_CUDA_ARCH_LIST', None) + if cross_compile_archs_env is not None and cross_compile_archs_env != "": + if cross_compile_archs is not None: + print( + f"{WARNING} env var `TORCH_CUDA_ARCH_LIST={cross_compile_archs_env}` overrides `cross_compile_archs={cross_compile_archs}`" + ) + cross_compile_archs = cross_compile_archs_env.replace(' ', ';') + else: + if cross_compile_archs is None: + cross_compile_archs = get_default_compute_capabilities() + ccs = cross_compile_archs.split(';') + + ccs = self.filter_ccs(ccs) + if len(ccs) == 0: + raise RuntimeError( + f"Unable to load {self.name} op due to no compute capabilities remaining after filtering") + + args = [] + self.enable_bf16 = True + for cc in ccs: + num = cc[0] + cc[2] + # args.append(f'-gencode=arch=compute_{num},code=sm_{num}') + # if cc.endswith('+PTX'): + # args.append(f'-gencode=arch=compute_{num},code=compute_{num}') + + if int(cc[0]) <= 7: + self.enable_bf16 = False + + return args + + def filter_ccs(self, ccs: List[str]): + """ + Prune any compute capabilities that are not compatible with the builder. Should log + which CCs have been pruned. + """ + return ccs + + def version_dependent_macros(self): + # Fix from apex that might be relevant for us as well, related to https://github.com/NVIDIA/apex/issues/456 + version_ge_1_1 = [] + if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 0): + version_ge_1_1 = ['-DVERSION_GE_1_1'] + version_ge_1_3 = [] + if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 2): + version_ge_1_3 = ['-DVERSION_GE_1_3'] + version_ge_1_5 = [] + if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 4): + version_ge_1_5 = ['-DVERSION_GE_1_5'] + return version_ge_1_1 + version_ge_1_3 + version_ge_1_5 + + def is_compatible(self, verbose=False): + return super().is_compatible(verbose) + + def builder(self): + try: + if not self.is_rocm_pytorch(): + assert_no_cuda_mismatch(self.name) + self.build_for_cpu = False + except MissingCUDAException: + self.build_for_cpu = True + + if self.build_for_cpu: + from torch.utils.cpp_extension import CppExtension as ExtensionBuilder + else: + from torch.utils.cpp_extension import CUDAExtension as ExtensionBuilder + include_dirs = [os.path.abspath(x) for x in self.strip_empty_entries(self.include_paths())] + compile_args = {'cxx': self.strip_empty_entries(self.cxx_args())} if self.build_for_cpu else \ + {'cxx': self.strip_empty_entries(self.cxx_args()), \ + 'nvcc': self.strip_empty_entries(self.nvcc_args())} + + if not self.build_for_cpu and self.enable_bf16: + compile_args['cxx'].append("-DBF16_AVAILABLE") + compile_args['nvcc'].append("-DBF16_AVAILABLE") + + if self.is_rocm_pytorch(): + compile_args['cxx'].append("-D__HIP_PLATFORM_AMD__=1") + #cxx compiler args are required to compile cpp files + compile_args['cxx'].append('-DROCM_WAVEFRONT_SIZE=%s' % self.get_rocm_wavefront_size()) + #nvcc compiler args are required to compile hip files + compile_args['nvcc'].append('-DROCM_WAVEFRONT_SIZE=%s' % self.get_rocm_wavefront_size()) + if self.get_rocm_gpu_arch(): + os.environ["PYTORCH_ROCM_ARCH"] = self.get_rocm_gpu_arch() + + cuda_ext = ExtensionBuilder(name=self.absolute_name(), + sources=self.strip_empty_entries(self.sources()), + include_dirs=include_dirs, + libraries=self.strip_empty_entries(self.libraries_args()), + extra_compile_args=compile_args, + extra_link_args=self.strip_empty_entries(self.extra_ldflags())) + + if self.is_rocm_pytorch(): + # hip converts paths to absolute, this converts back to relative + sources = cuda_ext.sources + curr_file = Path(__file__).parent.parent # ds root + for i in range(len(sources)): + src = Path(sources[i]) + if src.is_absolute(): + sources[i] = str(src.relative_to(curr_file)) + else: + sources[i] = str(src) + cuda_ext.sources = sources + return cuda_ext + + def hipify_extension(self): + if self.is_rocm_pytorch(): + from torch.utils.hipify import hipify_python + hipify_python.hipify( + project_directory=os.getcwd(), + output_directory=os.getcwd(), + header_include_dirs=self.include_paths(), + includes=[os.path.join(os.getcwd(), '*')], + extra_files=[os.path.abspath(s) for s in self.sources()], + show_detailed=True, + is_pytorch_extension=True, + hipify_extra_files_only=True, + ) + + def cxx_args(self): + if sys.platform == "win32": + return ['-O2'] + else: + return ['-O3', '-std=c++17', '-g', '-Wno-reorder'] + + def nvcc_args(self): + if self.build_for_cpu: + return [] + args = ['-O3'] + if self.is_rocm_pytorch(): + ROCM_MAJOR, ROCM_MINOR = self.installed_rocm_version() + args += [ + '-std=c++17', '-U__HIP_NO_HALF_OPERATORS__', '-U__HIP_NO_HALF_CONVERSIONS__', + '-U__HIP_NO_HALF2_OPERATORS__', + '-DROCM_VERSION_MAJOR=%s' % ROCM_MAJOR, + '-DROCM_VERSION_MINOR=%s' % ROCM_MINOR + ] + else: + try: + nvcc_threads = int(os.getenv("DS_NVCC_THREADS", "")) + if nvcc_threads <= 0: + raise ValueError("") + except ValueError: + nvcc_threads = min(os.cpu_count(), 8) + + cuda_major, cuda_minor = installed_cuda_version() + if cuda_major > 10: + if cuda_major == 12 and cuda_minor >= 5: + std_lib = '-std=c++20' + else: + std_lib = '-std=c++17' + else: + std_lib = '-std=c++14' + args += [ + '-allow-unsupported-compiler' if sys.platform == "win32" else '', + '-std=c++17' if cuda_major >= 10 else '-std=c++14', '-U__CUDA_NO_HALF_OPERATORS__', + '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__' + + ] + if os.environ.get('DS_DEBUG_CUDA_BUILD', '0') == '1': + args.append('--ptxas-options=-v') + args += self.compute_capability_args() + return args + + def libraries_args(self): + if self.build_for_cpu: + return [] + + if sys.platform == "win32": + return ['cublas', 'curand'] + else: + return [] + + +class TorchCPUOpBuilder(CUDAOpBuilder): + + def get_cuda_lib64_path(self): + import torch + if not self.is_rocm_pytorch(): + CUDA_LIB64 = os.path.join(torch.utils.cpp_extension.CUDA_HOME, "lib64") + if not os.path.exists(CUDA_LIB64): + CUDA_LIB64 = os.path.join(torch.utils.cpp_extension.CUDA_HOME, "lib") + else: + CUDA_LIB64 = os.path.join(torch.utils.cpp_extension.ROCM_HOME, "lib") + return CUDA_LIB64 + + def extra_ldflags(self): + if self.build_for_cpu: + return ['-fopenmp'] + + if not self.is_rocm_pytorch(): + ld_flags = ['-lcurand'] + if not self.build_for_cpu: + ld_flags.append(f'-L{self.get_cuda_lib64_path()}') + return ld_flags + + return [] + + def cxx_args(self): + args = [] + if not self.build_for_cpu: + CUDA_LIB64 = self.get_cuda_lib64_path() + + args += super().cxx_args() + args += [ + f'-L{CUDA_LIB64}', + '-lcudart', + '-lcublas', + '-g', + ] + + CPU_ARCH = self.cpu_arch() + SIMD_WIDTH = self.simd_width() + CUDA_ENABLE = self.is_cuda_enable() + args += [ + CPU_ARCH, + '-fopenmp', + SIMD_WIDTH, + CUDA_ENABLE, + ] + + return args diff --git a/toolbox/DeepSpeed/v0.15.3/patches/op_builder/fused_adam.py b/toolbox/DeepSpeed/v0.15.3/patches/op_builder/fused_adam.py new file mode 100644 index 0000000000000000000000000000000000000000..e298c3b2564fbeb9f5ffad9fb9d2634f73992a34 --- /dev/null +++ b/toolbox/DeepSpeed/v0.15.3/patches/op_builder/fused_adam.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .builder import CUDAOpBuilder + +import sys + + +class FusedAdamBuilder(CUDAOpBuilder): + BUILD_VAR = "DS_BUILD_FUSED_ADAM" + NAME = "fused_adam" + + def __init__(self): + super().__init__(name=self.NAME) + + def absolute_name(self): + return f'deepspeed.ops.adam.{self.NAME}_op' + + def sources(self): + return ['csrc/adam/fused_adam_frontend.cpp', 'csrc/adam/multi_tensor_adam.cu'] + + def include_paths(self): + return ['csrc/includes', 'csrc/adam'] + + def cxx_args(self): + args = super().cxx_args() + return args + self.version_dependent_macros() + + def nvcc_args(self): + nvcc_flags = ['-O3'] + self.version_dependent_macros() + if not self.is_rocm_pytorch(): + nvcc_flags.extend( + ['-allow-unsupported-compiler' if sys.platform == "win32" else '', '-lineinfo'] + + self.compute_capability_args()) + return nvcc_flags diff --git a/toolbox/DeepSpeed/v0.15.3/patches/op_builder/fused_lamb.py b/toolbox/DeepSpeed/v0.15.3/patches/op_builder/fused_lamb.py new file mode 100644 index 0000000000000000000000000000000000000000..6605e0169d3af52c524147ea548034cec245bc7b --- /dev/null +++ b/toolbox/DeepSpeed/v0.15.3/patches/op_builder/fused_lamb.py @@ -0,0 +1,43 @@ +#!/usr/bin/env python3 +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .builder import CUDAOpBuilder + +import sys + + +class FusedLambBuilder(CUDAOpBuilder): + BUILD_VAR = 'DS_BUILD_FUSED_LAMB' + NAME = "fused_lamb" + + def __init__(self): + super().__init__(name=self.NAME) + + def absolute_name(self): + return f'deepspeed.ops.lamb.{self.NAME}_op' + + def sources(self): + return ['csrc/lamb/fused_lamb_cuda.cpp', 'csrc/lamb/fused_lamb_cuda_kernel.cu'] + + def include_paths(self): + return ['csrc/includes'] + + def cxx_args(self): + args = super().cxx_args() + return args + self.version_dependent_macros() + + def nvcc_args(self): + nvcc_flags = ['-O3'] + self.version_dependent_macros() + if self.is_rocm_pytorch(): + ROCM_MAJOR, ROCM_MINOR = self.installed_rocm_version() + nvcc_flags += ['-DROCM_VERSION_MAJOR=%s' % ROCM_MAJOR, '-DROCM_VERSION_MINOR=%s' % ROCM_MINOR] + else: + nvcc_flags.extend( + ['-allow-unsupported-compiler' if sys.platform == "win32" else '', '-lineinfo'] + + self.compute_capability_args()) + return nvcc_flags diff --git a/toolbox/DeepSpeed/v0.15.3/patches/op_builder/fused_layernorm.py b/toolbox/DeepSpeed/v0.15.3/patches/op_builder/fused_layernorm.py new file mode 100644 index 0000000000000000000000000000000000000000..8ada1a7f8cc7c301cebef7a4962d68d52ca0e239 --- /dev/null +++ b/toolbox/DeepSpeed/v0.15.3/patches/op_builder/fused_layernorm.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python3 +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .builder import CUDAOpBuilder + +import sys +class FusedLayernormBuilder(CUDAOpBuilder): + BUILD_VAR = "DS_BUILD_FUSED_LAYERNORM" + NAME = "fused_layernorm" + + def __init__(self, name=None): + name = self.NAME if name is None else name + super().__init__(name=name) + + def absolute_name(self): + return f'deepspeed.ops.layernorm.{self.NAME}_op' + + def sources(self): + return [ + 'csrc/layernorm/layer_norm_cuda.cpp', 'csrc/layernorm/layer_norm_cuda_kernel.cu'] + + def extra_ldflags(self): + if not self.is_rocm_pytorch(): + return ['-lcurand'] + else: + return [] + + def include_paths(self): + includes = ['csrc/includes'] + return includes + def cxx_args(self): + args = ['-O3'] + return args + self.version_dependent_macros() + + def nvcc_args(self): + nvcc_flags = ['-O3'] + self.version_dependent_macros() + if self.is_rocm_pytorch(): + ROCM_MAJOR, ROCM_MINOR = self.installed_rocm_version() + nvcc_flags += ['-DROCM_VERSION_MAJOR=%s' % ROCM_MAJOR, '-DROCM_VERSION_MINOR=%s' % ROCM_MINOR] + else: + nvcc_flags.extend( + ['-allow-unsupported-compiler' if sys.platform == "win32" else '', '-lineinfo'] + + self.compute_capability_args()) + return nvcc_flags + + + + diff --git a/toolbox/DeepSpeed/v0.15.3/patches/op_builder/fused_lion.py b/toolbox/DeepSpeed/v0.15.3/patches/op_builder/fused_lion.py new file mode 100644 index 0000000000000000000000000000000000000000..a6136702b3db80ed514771825825f35a2be320b1 --- /dev/null +++ b/toolbox/DeepSpeed/v0.15.3/patches/op_builder/fused_lion.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .builder import CUDAOpBuilder + +import sys + + +class FusedLionBuilder(CUDAOpBuilder): + BUILD_VAR = "DS_BUILD_FUSED_LION" + NAME = "fused_lion" + + def __init__(self): + super().__init__(name=self.NAME) + + def absolute_name(self): + return f'deepspeed.ops.lion.{self.NAME}_op' + + def sources(self): + return ['csrc/lion/fused_lion_frontend.cpp', 'csrc/lion/multi_tensor_lion.cu'] + + def include_paths(self): + return ['csrc/includes', 'csrc/lion'] + + def cxx_args(self): + args = super().cxx_args() + return args + self.version_dependent_macros() + + def nvcc_args(self): + nvcc_flags = ['-O3'] + self.version_dependent_macros() + if not self.is_rocm_pytorch(): + nvcc_flags.extend( + ['-allow-unsupported-compiler' if sys.platform == "win32" else '', '-lineinfo'] + + self.compute_capability_args()) + return nvcc_flags diff --git a/toolbox/DeepSpeed/v0.15.3/patches/op_builder/fused_rope.py b/toolbox/DeepSpeed/v0.15.3/patches/op_builder/fused_rope.py new file mode 100644 index 0000000000000000000000000000000000000000..1f546b594d2107034590852ac9eec42d92fee4da --- /dev/null +++ b/toolbox/DeepSpeed/v0.15.3/patches/op_builder/fused_rope.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python3 +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .builder import CUDAOpBuilder + +import sys +class FusedRopeBuilder(CUDAOpBuilder): + BUILD_VAR = "DS_BUILD_FUSED_ROPE" + NAME = "fused_rope" + + def __init__(self, name=None): + name = self.NAME if name is None else name + super().__init__(name=name) + + def absolute_name(self): + return f'deepspeed.ops.rope.{self.NAME}_op' + + def sources(self): + return [ + 'csrc/rope/fused_rotary_positional_embedding.cpp', 'csrc/rope/fused_rotary_positional_embedding_cuda.cu'] + + def extra_ldflags(self): + if not self.is_rocm_pytorch(): + return ['-lcurand'] + else: + return [] + + def include_paths(self): + includes = ['csrc/includes'] + return includes + def cxx_args(self): + args = ['-O3'] + return args + self.version_dependent_macros() + + def nvcc_args(self): + nvcc_flags = ['-O3','-U__CUDA_NO_HALF_OPERATORS__','-U__CUDA_NO_HALF_CONVERSIONS__'] + self.version_dependent_macros() + if self.is_rocm_pytorch(): + ROCM_MAJOR, ROCM_MINOR = self.installed_rocm_version() + nvcc_flags += ['-DROCM_VERSION_MAJOR=%s' % ROCM_MAJOR, '-DROCM_VERSION_MINOR=%s' % ROCM_MINOR] + else: + nvcc_flags.extend( + ['-allow-unsupported-compiler' if sys.platform == "win32" else '', '-lineinfo'] + + self.compute_capability_args()) + return nvcc_flags + + + + diff --git a/toolbox/DeepSpeed/v0.15.3/patches/op_builder/gds.py b/toolbox/DeepSpeed/v0.15.3/patches/op_builder/gds.py new file mode 100644 index 0000000000000000000000000000000000000000..3898c23c032f4ee85d41a71e878ac0b6518f3f54 --- /dev/null +++ b/toolbox/DeepSpeed/v0.15.3/patches/op_builder/gds.py @@ -0,0 +1,69 @@ +#!/usr/bin/env python3 +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import os +from .async_io import AsyncIOBuilder + + +class GDSBuilder(AsyncIOBuilder): + BUILD_VAR = "DS_BUILD_GDS" + NAME = "gds" + + def __init__(self): + super().__init__() + + def absolute_name(self): + return f'deepspeed.ops.gds.{self.NAME}_op' + + def lib_sources(self): + src_list = ['csrc/gds/py_lib/deepspeed_py_gds_handle.cpp', 'csrc/gds/py_lib/deepspeed_gds_op.cpp'] + return super().lib_sources() + src_list + + def sources(self): + return self.lib_sources() + ['csrc/gds/py_lib/py_ds_gds.cpp'] + + def cxx_args(self): + return super().cxx_args() + ['-lcufile'] + ['-D__ILUVATAR__'] + + def include_paths(self): + import torch + CUDA_INCLUDE = [os.path.join(torch.utils.cpp_extension.CUDA_HOME, "include")] + return ['csrc/aio/py_lib', 'csrc/aio/common'] + CUDA_INCLUDE + + def extra_ldflags(self): + return super().extra_ldflags() + ['-lcufile'] + + def is_compatible(self, verbose=False): + if self.is_rocm_pytorch(): + if verbose: + self.warning(f'{self.NAME} is not compatible with ROCM') + return False + + try: + import torch.utils.cpp_extension + except ImportError: + if verbose: + self.warning("Please install torch if trying to pre-compile GDS") + return False + + CUDA_HOME = torch.utils.cpp_extension.CUDA_HOME + if CUDA_HOME is None: + if verbose: + self.warning("Please install torch CUDA if trying to pre-compile GDS with CUDA") + return False + + CUDA_LIB64 = os.path.join(CUDA_HOME, "lib64") + gds_compatible = self.has_function(funcname="cuFileDriverOpen", + libraries=("cufile", ), + library_dirs=( + CUDA_HOME, + CUDA_LIB64, + ), + verbose=verbose) + + return gds_compatible and super().is_compatible(verbose) diff --git a/toolbox/DeepSpeed/v0.15.3/patches/op_builder/swiglu.py b/toolbox/DeepSpeed/v0.15.3/patches/op_builder/swiglu.py new file mode 100644 index 0000000000000000000000000000000000000000..2468477d057285e1f07b146cab3e49d6f489243d --- /dev/null +++ b/toolbox/DeepSpeed/v0.15.3/patches/op_builder/swiglu.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python3 +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .builder import CUDAOpBuilder + +import sys +class SwigluBuilder(CUDAOpBuilder): + BUILD_VAR = "DS_BUILD_SWIGLU" + NAME = "swiglu" + + def __init__(self, name=None): + name = self.NAME if name is None else name + super().__init__(name=name) + + def absolute_name(self): + return f'deepspeed.ops.swiglu.{self.NAME}_op' + + def sources(self): + return [ + 'csrc/swiglu/swiglu.cpp', 'csrc/swiglu/swiglu_kernel.cu'] + + def extra_ldflags(self): + if not self.is_rocm_pytorch(): + return ['-lcurand'] + else: + return [] + + def include_paths(self): + includes = ['csrc/includes'] + return includes + def cxx_args(self): + args = ['-O3'] + return args + self.version_dependent_macros() + + def nvcc_args(self): + nvcc_flags = ['-O3','-U__CUDA_NO_HALF_OPERATORS__','-U__CUDA_NO_HALF_CONVERSIONS__'] + self.version_dependent_macros() + if self.is_rocm_pytorch(): + ROCM_MAJOR, ROCM_MINOR = self.installed_rocm_version() + nvcc_flags += ['-DROCM_VERSION_MAJOR=%s' % ROCM_MAJOR, '-DROCM_VERSION_MINOR=%s' % ROCM_MINOR] + else: + nvcc_flags.extend( + ['-allow-unsupported-compiler' if sys.platform == "win32" else '', '-lineinfo'] + + self.compute_capability_args()) + return nvcc_flags + + + + diff --git a/toolbox/DeepSpeed/v0.15.3/patches/requirements/requirements-bi.txt b/toolbox/DeepSpeed/v0.15.3/patches/requirements/requirements-bi.txt new file mode 100644 index 0000000000000000000000000000000000000000..d2a8e53fd84cdc291696a1ef1858d591e527614e --- /dev/null +++ b/toolbox/DeepSpeed/v0.15.3/patches/requirements/requirements-bi.txt @@ -0,0 +1,9 @@ +hjson +ninja +numpy +packaging>=20.0 +psutil +py-cpuinfo +pydantic +pynvml +tqdm \ No newline at end of file diff --git a/toolbox/DeepSpeed/v0.15.3/patches/setup.py b/toolbox/DeepSpeed/v0.15.3/patches/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..af090ec7470af8694ff9f7118b80e0b1de79453f --- /dev/null +++ b/toolbox/DeepSpeed/v0.15.3/patches/setup.py @@ -0,0 +1,338 @@ +#!/usr/bin/env python3 +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +""" +DeepSpeed library + +To build wheel on Windows: +1. Install pytorch, such as pytorch 1.12 + cuda 11.6. +2. Install visual cpp build tool. +3. Include cuda toolkit. +4. Launch cmd console with Administrator privilege for creating required symlink folders. + + +Create a new wheel via the following command: +build_win.bat + +The wheel will be located at: dist/*.whl +""" + +import pathlib +import os +import shutil +import sys +import subprocess +from setuptools import setup, find_packages +from setuptools.command import egg_info +import time +import typing +import shlex + +torch_available = True +try: + import torch +except ImportError: + torch_available = False + print('[WARNING] Unable to import torch, pre-compiling ops will be disabled. ' \ + 'Please visit https://pytorch.org/ to see how to properly install torch on your system.') + +from op_builder import get_default_compute_capabilities, OpBuilder +from op_builder.all_ops import ALL_OPS, accelerator_name +from op_builder.builder import installed_cuda_version + +from accelerator import get_accelerator + +# Fetch rocm state. +is_rocm_pytorch = OpBuilder.is_rocm_pytorch() +rocm_version = OpBuilder.installed_rocm_version() + +RED_START = '\033[31m' +RED_END = '\033[0m' +ERROR = f"{RED_START} [ERROR] {RED_END}" + + +def abort(msg): + print(f"{ERROR} {msg}") + assert False, msg + + +def fetch_requirements(path): + with open(path, 'r') as fd: + return [r.strip() for r in fd.readlines()] + + +def is_env_set(key): + """ + Checks if an environment variable is set and not "". + """ + return bool(os.environ.get(key, None)) + + +def get_env_if_set(key, default: typing.Any = ""): + """ + Returns an environment variable if it is set and not "", + otherwise returns a default value. In contrast, the fallback + parameter of os.environ.get() is skipped if the variable is set to "". + """ + return os.environ.get(key, None) or default + + +install_requires = fetch_requirements('requirements/requirements.txt') +extras_require = { + '1bit': [], # add cupy based on cuda/rocm version + '1bit_mpi': fetch_requirements('requirements/requirements-1bit-mpi.txt'), + 'readthedocs': fetch_requirements('requirements/requirements-readthedocs.txt'), + 'dev': fetch_requirements('requirements/requirements-dev.txt'), + 'autotuning': fetch_requirements('requirements/requirements-autotuning.txt'), + 'autotuning_ml': fetch_requirements('requirements/requirements-autotuning-ml.txt'), + 'sparse_attn': fetch_requirements('requirements/requirements-sparse_attn.txt'), + 'sparse': fetch_requirements('requirements/requirements-sparse_pruning.txt'), + 'inf': fetch_requirements('requirements/requirements-inf.txt'), + 'sd': fetch_requirements('requirements/requirements-sd.txt'), + 'triton': fetch_requirements('requirements/requirements-triton.txt'), +} + +# Only install pynvml on nvidia gpus. +if torch_available and get_accelerator().device_name() == 'cuda' and not is_rocm_pytorch: + install_requires.append('nvidia-ml-py') + +# Add specific cupy version to both onebit extension variants. +if torch_available and get_accelerator().device_name() == 'cuda': + cupy = None + if is_rocm_pytorch: + rocm_major, rocm_minor = rocm_version + # XXX cupy support for rocm 5 is not available yet. + if rocm_major <= 4: + cupy = f"cupy-rocm-{rocm_major}-{rocm_minor}" + else: + cuda_major_ver, cuda_minor_ver = installed_cuda_version() + if (cuda_major_ver < 11) or ((cuda_major_ver == 11) and (cuda_minor_ver < 3)): + cupy = f"cupy-cuda{cuda_major_ver}{cuda_minor_ver}" + else: + cupy = f"cupy-cuda{cuda_major_ver}x" + + if cupy: + extras_require['1bit'].append(cupy) + extras_require['1bit_mpi'].append(cupy) + +# Make an [all] extra that installs all needed dependencies. +all_extras = set() +for extra in extras_require.items(): + for req in extra[1]: + all_extras.add(req) +extras_require['all'] = list(all_extras) + +cmdclass = {} + +# For any pre-installed ops force disable ninja. +if torch_available: + use_ninja = is_env_set("DS_ENABLE_NINJA") + cmdclass['build_ext'] = get_accelerator().build_extension().with_options(use_ninja=use_ninja) + +if torch_available: + TORCH_MAJOR = torch.__version__.split('.')[0] + TORCH_MINOR = torch.__version__.split('.')[1] +else: + TORCH_MAJOR = "0" + TORCH_MINOR = "0" + +if torch_available and not get_accelerator().device_name() == 'cuda': + # Fix to allow docker builds, similar to https://github.com/NVIDIA/apex/issues/486. + print("[WARNING] Torch did not find cuda available, if cross-compiling or running with cpu only " + "you can ignore this message. Adding compute capability for Pascal, Volta, and Turing " + "(compute capabilities 6.0, 6.1, 6.2)") + if not is_env_set("TORCH_CUDA_ARCH_LIST"): + os.environ["TORCH_CUDA_ARCH_LIST"] = get_default_compute_capabilities() + +ext_modules = [] + +# Default to pre-install kernels to false so we rely on JIT on Linux, opposite on Windows. +BUILD_OP_PLATFORM = 1 if sys.platform == "win32" else 0 +BUILD_OP_DEFAULT = int(get_env_if_set('DS_BUILD_OPS', BUILD_OP_PLATFORM)) +print(f"DS_BUILD_OPS={BUILD_OP_DEFAULT}") + +if BUILD_OP_DEFAULT: + assert torch_available, "Unable to pre-compile ops without torch installed. Please install torch before attempting to pre-compile ops." + + +def command_exists(cmd): + if sys.platform == "win32": + safe_cmd = shlex.split(f'{cmd}') + result = subprocess.Popen(safe_cmd, stdout=subprocess.PIPE) + return result.wait() == 1 + else: + safe_cmd = shlex.split(f"bash -c type {cmd}") + result = subprocess.Popen(safe_cmd, stdout=subprocess.PIPE) + return result.wait() == 0 + + +def op_envvar(op_name): + assert hasattr(ALL_OPS[op_name], 'BUILD_VAR'), \ + f"{op_name} is missing BUILD_VAR field" + return ALL_OPS[op_name].BUILD_VAR + + +def op_enabled(op_name): + env_var = op_envvar(op_name) + return int(get_env_if_set(env_var, BUILD_OP_DEFAULT)) + + +install_ops = dict.fromkeys(ALL_OPS.keys(), False) +for op_name, builder in ALL_OPS.items(): + op_compatible = builder.is_compatible() + + # If op is requested but not available, throw an error. + if op_enabled(op_name) and not op_compatible: + env_var = op_envvar(op_name) + if not is_env_set(env_var): + builder.warning(f"Skip pre-compile of incompatible {op_name}; One can disable {op_name} with {env_var}=0") + continue + + # If op is compatible but install is not enabled (JIT mode). + if is_rocm_pytorch and op_compatible and not op_enabled(op_name): + builder.hipify_extension() + + # If op install enabled, add builder to extensions. + if op_enabled(op_name) and op_compatible: + assert torch_available, f"Unable to pre-compile {op_name}, please first install torch" + install_ops[op_name] = op_enabled(op_name) + ext_modules.append(builder.builder()) + +print(f'Install Ops={install_ops}') + +# Write out version/git info. +git_hash_cmd = shlex.split("bash -c git rev-parse --short HEAD") +git_branch_cmd = shlex.split("bash -c git rev-parse --abbrev-ref HEAD") +if command_exists('git') and not is_env_set('DEEPSPEED_LOCAL_VERSION_IDENTIFIER'): + try: + result = subprocess.check_output(git_hash_cmd) + git_hash = result.decode('utf-8').strip() + result = subprocess.check_output(git_branch_cmd) + git_branch = result.decode('utf-8').strip() + except subprocess.CalledProcessError: + git_hash = "unknown" + git_branch = "unknown" +else: + git_hash = "unknown" + git_branch = "unknown" + +if sys.platform == "win32": + shutil.rmtree('.\\deepspeed\\ops\\csrc', ignore_errors=True) + pathlib.Path('.\\deepspeed\\ops\\csrc').unlink(missing_ok=True) + shutil.copytree('.\\csrc', '.\\deepspeed\\ops\\csrc', dirs_exist_ok=True) + shutil.rmtree('.\\deepspeed\\ops\\op_builder', ignore_errors=True) + pathlib.Path('.\\deepspeed\\ops\\op_builder').unlink(missing_ok=True) + shutil.copytree('.\\op_builder', '.\\deepspeed\\ops\\op_builder', dirs_exist_ok=True) + shutil.rmtree('.\\deepspeed\\accelerator', ignore_errors=True) + pathlib.Path('.\\deepspeed\\accelerator').unlink(missing_ok=True) + shutil.copytree('.\\accelerator', '.\\deepspeed\\accelerator', dirs_exist_ok=True) + egg_info.manifest_maker.template = 'MANIFEST_win.in' + +# Parse the DeepSpeed version string from version.txt. +version_str = open('version.txt', 'r').read().strip() + +# Build specifiers like .devX can be added at install time. Otherwise, add the git hash. +# Example: DS_BUILD_STRING=".dev20201022" python setup.py sdist bdist_wheel. + +# Building wheel for distribution, update version file. +if is_env_set('DEEPSPEED_LOCAL_VERSION_IDENTIFIER'): + local_version_identifier = os.environ.get("DEEPSPEED_LOCAL_VERSION_IDENTIFIER") + # Build string env specified, probably building for distribution. + with open('build.txt', 'w') as fd: + fd.write(os.environ['DEEPSPEED_LOCAL_VERSION_IDENTIFIER']) + version_str += "+" + local_version_identifier +elif os.path.isfile('build.txt'): + # build.txt exists, probably installing from distribution. + with open('build.txt', 'r') as fd: + version_str += fd.read().strip() +else: + # None of the above, probably installing from source. + version_str += f'+{git_hash}' + +torch_version = ".".join([TORCH_MAJOR, TORCH_MINOR]) +bf16_support = False +# Set cuda_version to 0.0 if cpu-only. +cuda_version = "0.0" +nccl_version = "0.0" +# Set hip_version to 0.0 if cpu-only. +hip_version = "0.0" +if torch_available and torch.version.cuda is not None: + cuda_version = ".".join(torch.version.cuda.split('.')[:2]) + if sys.platform != "win32": + if isinstance(torch.cuda.nccl.version(), int): + # This will break if minor version > 9. + nccl_version = ".".join(str(torch.cuda.nccl.version())[:2]) + else: + nccl_version = ".".join(map(str, torch.cuda.nccl.version()[:2])) + if hasattr(torch.cuda, 'is_bf16_supported') and torch.cuda.is_available(): + bf16_support = torch.cuda.is_bf16_supported() + +if torch_available and hasattr(torch.version, 'hip') and torch.version.hip is not None: + hip_version = ".".join(torch.version.hip.split('.')[:2]) +torch_info = { + "version": torch_version, + "bf16_support": bf16_support, + "cuda_version": cuda_version, + "nccl_version": nccl_version, + "hip_version": hip_version +} + +print(f"version={version_str}, git_hash={git_hash}, git_branch={git_branch}") +with open('deepspeed/git_version_info_installed.py', 'w') as fd: + fd.write(f"version='{version_str}'\n") + fd.write(f"git_hash='{git_hash}'\n") + fd.write(f"git_branch='{git_branch}'\n") + fd.write(f"installed_ops={install_ops}\n") + fd.write(f"accelerator_name='{accelerator_name}'\n") + fd.write(f"torch_info={torch_info}\n") + +print(f'install_requires={install_requires}') +print(f'ext_modules={ext_modules}') + +# Parse README.md to make long_description for PyPI page. +thisdir = os.path.abspath(os.path.dirname(__file__)) +with open(os.path.join(thisdir, 'README.md'), encoding='utf-8') as fin: + readme_text = fin.read() + +if sys.platform == "win32": + scripts = ['bin/deepspeed.bat', 'bin/ds', 'bin/ds_report.bat', 'bin/ds_report'] +else: + scripts = [ + 'bin/deepspeed', 'bin/deepspeed.pt', 'bin/ds', 'bin/ds_ssh', 'bin/ds_report', 'bin/ds_bench', 'bin/dsr', + 'bin/ds_elastic', 'bin/ds_nvme_tune', 'bin/ds_io' + ] + +start_time = time.time() + +setup(name='deepspeed', + version=version_str, + description='DeepSpeed library', + long_description=readme_text, + long_description_content_type='text/markdown', + author='DeepSpeed Team', + author_email='deepspeed-info@microsoft.com', + url='http://deepspeed.ai', + project_urls={ + 'Documentation': 'https://deepspeed.readthedocs.io', + 'Source': 'https://github.com/microsoft/DeepSpeed', + }, + install_requires=install_requires, + extras_require=extras_require, + packages=find_packages(include=['deepspeed', 'deepspeed.*']), + include_package_data=True, + scripts=scripts, + classifiers=[ + 'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: 3.7', + 'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.9', + 'Programming Language :: Python :: 3.10' + ], + license='Apache Software License 2.0', + ext_modules=ext_modules, + cmdclass=cmdclass) + +end_time = time.time() +print(f'deepspeed build time = {end_time - start_time} secs') diff --git a/toolbox/Megatron-DeepSpeed/megatron_ds/core/datasets/blended_megatron_dataset_builder.py b/toolbox/Megatron-DeepSpeed/megatron_ds/core/datasets/blended_megatron_dataset_builder.py index 8e7e338c759c08a8658233fdd2db2eb6de204884..e36ba292248a633dee527719fcb90b0d1b77c6b4 100644 --- a/toolbox/Megatron-DeepSpeed/megatron_ds/core/datasets/blended_megatron_dataset_builder.py +++ b/toolbox/Megatron-DeepSpeed/megatron_ds/core/datasets/blended_megatron_dataset_builder.py @@ -1,6 +1,6 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. # All Rights Reserved. +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. import logging import math @@ -247,9 +247,14 @@ class BlendedMegatronDatasetBuilder(object): dataset = None # First, build on rank 0 + # WA: each node's first rank build the dataset cache (some node could not need to do this, but this can work on no shared storage, only given a litte overhead) if torch.distributed.get_rank() % get_accelerator().device_count() == 0: #if rank == 0and getattr(self.config, "is_built_on_rank")(): try: + # @todo: if data_parallel_group has been created, we can use this group to avoid overhead + #vote = get_accelerator().LongTensor([1]).fill_(mpu.get_data_parallel_rank(with_context_parallel=True)) + #torch.distributed.all_reduce(vote, group=mpu.get_data_parallel_group(), op=torch.distributed.ReduceOp.MIN) + #if vote.item() == mpu.get_data_parallel_rank(with_context_parallel=True): dataset = cls(*args) except OSError as err: log = ( diff --git a/toolbox/Megatron-DeepSpeed/megatron_ds/core/datasets/helpers.cpp b/toolbox/Megatron-DeepSpeed/megatron_ds/core/datasets/helpers.cpp index 69073bb85cfa5e83bcc16297bef2804efeb58859..c6c7c7cfb29e924ccd4255d4d010b0cce34bc17c 100644 --- a/toolbox/Megatron-DeepSpeed/megatron_ds/core/datasets/helpers.cpp +++ b/toolbox/Megatron-DeepSpeed/megatron_ds/core/datasets/helpers.cpp @@ -1,6 +1,7 @@ -/* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. */ /* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. */ /* All Rights Reserved. */ +/* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. */ + /* Helper methods for fast index mapping builds */ #include diff --git a/toolbox/Megatron-DeepSpeed/tests/tests.py b/toolbox/Megatron-DeepSpeed/tests/tests.py index 7e81d21864baf77e73feff6cf733f1c8ff53e622..24b52718ea79ecdb3b9aa8abf01c8cdf31f31fcd 100644 --- a/toolbox/Megatron-DeepSpeed/tests/tests.py +++ b/toolbox/Megatron-DeepSpeed/tests/tests.py @@ -146,10 +146,10 @@ def run_py_case(args, py_file, test_args: List[str] = None, log_dir: str = None, test_args = [] if "test_utils.py" in py_file: - command = f"torchrun --nproc_per_node=1 -m pytest -s {py_file} {' '.join(test_args)} --junitxml={args.log_dir}/_{py_file.split('/')[-1][:-3]}.xml" + command = f"torchrun --nproc_per_node=1 -m pytest -s {py_file} {' '.join(test_args)} --junitxml={args.log_dir}/___{py_file.split('/')[-1][:-3]}.xml -o junit_suite_name={py_file.split('/')[-1][:-3]}" else: command = f"torchrun --nproc_per_node=8 --nnodes {args.nnodes} --node_rank {args.node_rank} \ - --master_addr {args.master_addr} --master_port {args.master_port} -m pytest -s {py_file} {' '.join(test_args)} --junitxml={args.log_dir}/_{py_file.split('/')[-1][:-3]}.xml" + --master_addr {args.master_addr} --master_port {args.master_port} -m pytest -s {py_file} {' '.join(test_args)} --junitxml={args.log_dir}/___{py_file.split('/')[-1][:-3]}.xml -o junit_suite_name={py_file.split('/')[-1][:-3]}" if log_dir is None: log_dir = DEFAULT_LOG_DIR @@ -210,7 +210,7 @@ def run_py_cases(args, files, log_dir = None, timeout_per_case = None, excludes: test_results = [] for i, file in enumerate(test_files): print(f"Progress: {i+1} / {len(test_files)}, Case: {file}") - + sys.stdout.flush() if not is_valid_test_case(file): print(f"Skip {file}") continue diff --git a/toolbox/firefly/setup.py b/toolbox/firefly/setup.py index 873d14b1095a334324684cfc7ddb06b104df30c2..109c03450e89878eea5a95f2312cad9a5e7b26d5 100644 --- a/toolbox/firefly/setup.py +++ b/toolbox/firefly/setup.py @@ -1,19 +1,17 @@ # Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. # All Rights Reserved. # -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. You may obtain -# a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. - - +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import setuptools import os @@ -23,11 +21,19 @@ def req_file(filename, folder="./"): content = f.readlines() return [x.strip() for x in content] + +def make_version(version): + COREX_VERSION = os.getenv("COREX_VERSION", "") + if COREX_VERSION: + return str(version)+"+corex."+COREX_VERSION + return str(version) + + install_requires = req_file("requirements.txt") setuptools.setup( name="firefly", - version="0.1.0", + version=make_version(0.1), description="Firefly is an open-source project for large-scale model training, supporting pre-training and fine-tuning of state-of-the-art large models.", packages=setuptools.find_packages(), python_requires='>=3.7, <4', diff --git a/toolbox/openpcdet/tools/test.py b/toolbox/openpcdet/tools/test.py index 51b7178c68544ad1aa0f9f1975bde752f5111e33..b38fc5d70ded9c7edd816961ad2c80e9851ef93c 100644 --- a/toolbox/openpcdet/tools/test.py +++ b/toolbox/openpcdet/tools/test.py @@ -29,7 +29,7 @@ def parse_config(): parser.add_argument('--pretrained_model', type=str, default=None, help='pretrained_model') parser.add_argument('--launcher', choices=['none', 'pytorch', 'slurm'], default='none') parser.add_argument('--tcp_port', type=int, default=18888, help='tcp port for distrbuted training') - parser.add_argument('--local_rank', type=int, default=0, help='local rank for distributed training') + parser.add_argument('--local-rank', type=int, default=0, help='local rank for distributed training') parser.add_argument('--set', dest='set_cfgs', default=None, nargs=argparse.REMAINDER, help='set extra config keys if needed')