From 4bf3b63de50406d8aad46943e4f301a115ae5df1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=A9=AC=E5=AE=81=E6=82=A6?= Date: Mon, 11 Mar 2024 11:39:49 +0000 Subject: [PATCH 1/4] add debug/weight_convert/bloom.py. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 马宁悦 --- debug/weight_convert/bloom.py | 262 ++++++++++++++++++++++++++++++++++ 1 file changed, 262 insertions(+) create mode 100644 debug/weight_convert/bloom.py diff --git a/debug/weight_convert/bloom.py b/debug/weight_convert/bloom.py new file mode 100644 index 0000000000..6fd4f9524b --- /dev/null +++ b/debug/weight_convert/bloom.py @@ -0,0 +1,262 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the License); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import json +import os +import sys +import logging +from collections import namedtuple +import time +import functools +import shutil +import torch + + +class ConvertBloom: + + model_config = { + "7B": [30, 4096, 32], # num_layers, hiddern_size, num_attention_heads + "176B": [70, 14336, 112] + } + + + def __init__(self, args): + self.huggingface_model = {} + self.layer_weight_idx = {} + self.tp_size = args.tensor_parallel_size + self.pp_size = args.pipeline_parallel_size + self.input_model_dir = args.input_model_dir + self.output_model_dir = args.output_model_dir + self.model_type = args.model_type + self.pp_layers = self.get_partition_layers(args.partition_layers) + self.init_huggingface_model() + + def get_partition_layers(self, partition_layers): + if self.model_type == "7B" and self.pp_size == 1: + return [30] + elif self.model_type == "7B" and self.pp_size == 2: + return [15, 15] + else: + return list(map(int, partition_layers.split(','))) + + def init_huggingface_model(self): + model_index = {} + params = ["input_layernorm.bias", "input_layernorm.weight", "mlp.dense_4h_to_h.bias", "mlp.dense_4h_to_h.weight", \ + "mlp.dense_h_to_4h.bias", "mlp.dense_h_to_4h.weight", "post_attention_layernorm.bias", "post_attention_layernorm.weight", \ + "self_attention.dense.bias", "self_attention.dense.weight", "self_attention.query_key_value.bias", "self_attention.query_key_value.weight"] + + for pp_rank in range(self.pp_size): + for offset in range(self.pp_layers[pp_rank]): + layer_id = sum(self.pp_layers[:pp_rank]) + offset + dest_model_filepath = "pytorch_model_{:05d}-of-00072.bin".format(layer_id + 2) + for param in params: + self.layer_weight_idx[f"h.{layer_id}.{param}"] = dest_model_filepath + + self.huggingface_model[dest_model_filepath] = {} + self.layer_weight_idx["ln_f.bias"] = "pytorch_model_00072-of-00072.bin" + self.layer_weight_idx["ln_f.weight"] = "pytorch_model_00072-of-00072.bin" + self.layer_weight_idx["word_embeddings.weight"] = "pytorch_model_00001-of-00072.bin" + self.layer_weight_idx["word_embeddings_layernorm.bias"] = "pytorch_model_00001-of-00072.bin" + self.layer_weight_idx["word_embeddings_layernorm.weight"] = "pytorch_model_00001-of-00072.bin" + model_index["weight_map"] = self.layer_weight_idx + 
model_index["metadata"] = {"total_size": 0} + self.huggingface_model["pytorch_model_00072-of-00072.bin"] = {} + self.huggingface_model["pytorch_model_00001-of-00072.bin"] = {} + + if not os.path.exists(self.output_model_dir): + os.makedirs(self.output_model_dir) + + with os.fdopen(os.open(os.path.join(args.output_model_dir, "pytorch_model.bin.index.json"), \ + os.O_WRONLY | os.O_CREAT, 0o640), 'w') as f: + f.write(json.dumps(model_index, indent=4)) + + config_files = ["config.json", "special_tokens_map.json", "tokenizer_config.json", "tokenizer.json"] + for _file in config_files: + srcfile = os.path.join(self.input_model_dir, _file) + if os.path.exists(srcfile): + shutil.copy2(srcfile, self.output_model_dir) + else: + print(f"warning: {srcfile} does not exist!") + + def set_huggingface_weight_by_name(self, layer_weight, w): + ''' + 设置huggingface权重信息,通过layer_weight_idx找到对应保存权重的二进制文件中 + ''' + self.huggingface_model[self.layer_weight_idx[layer_weight]][layer_weight] = w + + def check_has_layer_model(self): + one_layer_model_path_sample = os.path.join(self.input_model_dir, "layer_01-model_00-model_states.pt") + if os.path.exists(one_layer_model_path_sample): + return True + return False + + def convert_from_layer_model(self, pp_size, tp_size, num_layers): + weights_dicts = {"word_embeddings": None, "self_attention_qkv_weight": {}, "self_attention_qkv_bias": {}, \ + "self_attention_dense_weight": {}, "mlp_dense_h_to_4h_weight": {}, "mlp_dense_h_to_4h_bias": {}, "mlp_dense_4h_to_h_weight": {}} + + for pp_rank in range(pp_size): + for tp_rank in range(tp_size): + if pp_rank == 0: + model_path = os.path.join(self.input_model_dir, "layer_01-model_{:02d}-model_states.pt".format(tp_rank)) + ascendspeed_model = torch.load(model_path, map_location="cpu") + + self.set_huggingface_weight_by_name("word_embeddings_layernorm.weight", ascendspeed_model["word_embeddings.norm.weight"]) + self.set_huggingface_weight_by_name("word_embeddings_layernorm.bias", ascendspeed_model["word_embeddings.norm.bias"]) + word_embeddings_read = ascendspeed_model["word_embeddings.weight"] + weights_dicts["word_embeddings"] = row_concat(weights_dicts["word_embeddings"], word_embeddings_read, tp_size, tp_rank) + + if pp_rank == pp_size - 1: + as_layer_id = num_layers + 4 + model_path = os.path.join(self.input_model_dir, "layer_{:02d}-model_{:02d}-model_states.pt".format(as_layer_id, tp_rank)) + ascendspeed_model = torch.load(model_path, map_location="cpu") + self.set_huggingface_weight_by_name("ln_f.weight", ascendspeed_model["weight"]) + self.set_huggingface_weight_by_name("ln_f.bias", ascendspeed_model["bias"]) + + for i in range(self.pp_layers[pp_rank]): + layer_id = sum(self.pp_layers[:pp_rank]) + i + as_layer_id = layer_id + 3 + model_path = os.path.join(self.input_model_dir, "layer_{:02d}-model_{:02d}-model_states.pt".format(as_layer_id, tp_rank)) + ascendspeed_model = torch.load(model_path, map_location="cpu") + + self.set_huggingface_weight_by_name(f"h.{layer_id}.input_layernorm.weight", ascendspeed_model["input_layernorm.weight"]) + self.set_huggingface_weight_by_name(f"h.{layer_id}.input_layernorm.bias", ascendspeed_model["input_layernorm.bias"]) + + self.set_huggingface_weight_by_name(f"h.{layer_id}.post_attention_layernorm.weight", ascendspeed_model["post_attention_layernorm.weight"]) + self.set_huggingface_weight_by_name(f"h.{layer_id}.post_attention_layernorm.bias", ascendspeed_model["post_attention_layernorm.bias"]) + self.set_huggingface_weight_by_name(f"h.{layer_id}.self_attention.dense.bias", 
ascendspeed_model["self_attention.dense.bias"]) + self.set_huggingface_weight_by_name(f"h.{layer_id}.mlp.dense_4h_to_h.bias", ascendspeed_model["mlp.dense_4h_to_h.bias"]) + + self_attention_qkv_weight = ascendspeed_model["self_attention.query_key_value.weight"] + self_attention_qkv_bias = ascendspeed_model["self_attention.query_key_value.bias"] + self_attention_dense_weight = ascendspeed_model["self_attention.dense.weight"] + mlp_dense_h_to_4h_weight = ascendspeed_model["mlp.dense_h_to_4h.weight"] + mlp_dense_h_to_4h_bias = ascendspeed_model["mlp.dense_h_to_4h.bias"] + mlp_dense_4h_to_h_weight = ascendspeed_model["mlp.dense_4h_to_h.weight"] + + if layer_id not in weights_dicts["self_attention_qkv_weight"]: + weights_dicts["self_attention_qkv_weight"][layer_id] = None + weights_dicts["self_attention_qkv_bias"][layer_id] = None + weights_dicts["self_attention_dense_weight"][layer_id] = None + weights_dicts["mlp_dense_h_to_4h_weight"][layer_id] = None + weights_dicts["mlp_dense_h_to_4h_bias"][layer_id] = None + weights_dicts["mlp_dense_4h_to_h_weight"][layer_id] = None + + weights_dicts["self_attention_qkv_weight"][layer_id] = row_concat(weights_dicts["self_attention_qkv_weight"][layer_id], self_attention_qkv_weight, tp_size, tp_rank) + weights_dicts["self_attention_qkv_bias"][layer_id] = row_concat(weights_dicts["self_attention_qkv_bias"][layer_id], self_attention_qkv_bias, tp_size, tp_rank) + weights_dicts["self_attention_dense_weight"][layer_id] = column_concat(weights_dicts["self_attention_dense_weight"][layer_id], self_attention_dense_weight, tp_size, tp_rank) + weights_dicts["mlp_dense_h_to_4h_weight"][layer_id] = row_concat(weights_dicts["mlp_dense_h_to_4h_weight"][layer_id], mlp_dense_h_to_4h_weight, tp_size, tp_rank) + weights_dicts["mlp_dense_h_to_4h_bias"][layer_id] = row_concat(weights_dicts["mlp_dense_h_to_4h_bias"][layer_id], mlp_dense_h_to_4h_bias, tp_size, tp_rank) + weights_dicts["mlp_dense_4h_to_h_weight"][layer_id] = column_concat(weights_dicts["mlp_dense_4h_to_h_weight"][layer_id], mlp_dense_4h_to_h_weight, tp_size, tp_rank) + + + self.set_huggingface_weight_by_name("word_embeddings.weight", weights_dicts["word_embeddings"]) + for layer_id in weights_dicts["self_attention_qkv_weight"]: + self.set_huggingface_weight_by_name(f"h.{layer_id}.self_attention.query_key_value.weight", weights_dicts["self_attention_qkv_weight"][layer_id]) + self.set_huggingface_weight_by_name(f"h.{layer_id}.self_attention.query_key_value.bias", weights_dicts["self_attention_qkv_bias"][layer_id]) + self.set_huggingface_weight_by_name(f"h.{layer_id}.self_attention.dense.weight", weights_dicts["self_attention_dense_weight"][layer_id]) + self.set_huggingface_weight_by_name(f"h.{layer_id}.mlp.dense_h_to_4h.weight", weights_dicts["mlp_dense_h_to_4h_weight"][layer_id]) + self.set_huggingface_weight_by_name(f"h.{layer_id}.mlp.dense_h_to_4h.bias", weights_dicts["mlp_dense_h_to_4h_bias"][layer_id]) + self.set_huggingface_weight_by_name(f"h.{layer_id}.mlp.dense_4h_to_h.weight", weights_dicts["mlp_dense_4h_to_h_weight"][layer_id]) + + return True + + def convert_from_mprank_model(self, pp_size, tp_size, num_layers): + weights_dicts = {"word_embeddings": None, "self_attention_qkv_weight": {}, "self_attention_qkv_bias": {}, \ + "self_attention_dense_weight": {}, "mlp_dense_h_to_4h_weight": {}, "mlp_dense_h_to_4h_bias": {}, "mlp_dense_4h_to_h_weight": {}} + + for pp_rank in range(pp_size): + for tp_rank in range(tp_size): + model_path = os.path.join(self.input_model_dir, f"{'mp_rank_{:02d}'.format(pp_rank * 
tp_size + tp_rank)}_model_states.pt") + if not os.path.exists(model_path): + print(f"Error! {model_path} does not exist") + return False + as_pt_model = torch.load(model_path, map_location="cpu") + rank_model = as_pt_model["module"]["module"] + + if pp_rank == 0: + + self.set_huggingface_weight_by_name("word_embeddings_layernorm.weight", rank_model["tied_modules.embed.word_embeddings.norm.weight"]) + self.set_huggingface_weight_by_name("word_embeddings_layernorm.bias", rank_model["tied_modules.embed.word_embeddings.norm.bias"]) + word_embeddings_read = rank_model["tied_modules.embed.word_embeddings.weight"] + weights_dicts["word_embeddings"] = row_concat(weights_dicts["word_embeddings"], word_embeddings_read, tp_size, tp_rank) + + if pp_rank == pp_size - 1: + as_layer_id = num_layers + 4 + self.set_huggingface_weight_by_name("ln_f.weight", rank_model[f"{as_layer_id}.weight"]) + self.set_huggingface_weight_by_name("ln_f.bias", rank_model[f"{as_layer_id}.bias"]) + + for i in range(self.pp_layers[pp_rank]): + layer_id = sum(self.pp_layers[:pp_rank]) + i + as_layer_id = layer_id + 3 + + self.set_huggingface_weight_by_name(f"h.{layer_id}.input_layernorm.weight", rank_model[f"{as_layer_id}.input_layernorm.weight"]) + self.set_huggingface_weight_by_name(f"h.{layer_id}.input_layernorm.bias", rank_model[f"{as_layer_id}.input_layernorm.bias"]) + + self.set_huggingface_weight_by_name(f"h.{layer_id}.post_attention_layernorm.weight", rank_model[f"{as_layer_id}.post_attention_layernorm.weight"]) + self.set_huggingface_weight_by_name(f"h.{layer_id}.post_attention_layernorm.bias", rank_model[f"{as_layer_id}.post_attention_layernorm.bias"]) + self.set_huggingface_weight_by_name(f"h.{layer_id}.self_attention.dense.bias", rank_model[f"{as_layer_id}.self_attention.dense.bias"]) + self.set_huggingface_weight_by_name(f"h.{layer_id}.mlp.dense_4h_to_h.bias", rank_model[f"{as_layer_id}.mlp.dense_4h_to_h.bias"]) + + self_attention_qkv_weight = rank_model[f"{as_layer_id}.self_attention.query_key_value.weight"] + self_attention_qkv_bias = rank_model[f"{as_layer_id}.self_attention.query_key_value.bias"] + self_attention_dense_weight = rank_model[f"{as_layer_id}.self_attention.dense.weight"] + mlp_dense_h_to_4h_weight = rank_model[f"{as_layer_id}.mlp.dense_h_to_4h.weight"] + mlp_dense_h_to_4h_bias = rank_model[f"{as_layer_id}.mlp.dense_h_to_4h.bias"] + mlp_dense_4h_to_h_weight = rank_model[f"{as_layer_id}.mlp.dense_4h_to_h.weight"] + + if layer_id not in weights_dicts["self_attention_qkv_weight"]: + weights_dicts["self_attention_qkv_weight"][layer_id] = None + weights_dicts["self_attention_qkv_bias"][layer_id] = None + weights_dicts["self_attention_dense_weight"][layer_id] = None + weights_dicts["mlp_dense_h_to_4h_weight"][layer_id] = None + weights_dicts["mlp_dense_h_to_4h_bias"][layer_id] = None + weights_dicts["mlp_dense_4h_to_h_weight"][layer_id] = None + + weights_dicts["self_attention_qkv_weight"][layer_id] = row_concat(weights_dicts["self_attention_qkv_weight"][layer_id], self_attention_qkv_weight, tp_size, tp_rank) + weights_dicts["self_attention_qkv_bias"][layer_id] = row_concat(weights_dicts["self_attention_qkv_bias"][layer_id], self_attention_qkv_bias, tp_size, tp_rank) + weights_dicts["self_attention_dense_weight"][layer_id] = column_concat(weights_dicts["self_attention_dense_weight"][layer_id], self_attention_dense_weight, tp_size, tp_rank) + weights_dicts["mlp_dense_h_to_4h_weight"][layer_id] = row_concat(weights_dicts["mlp_dense_h_to_4h_weight"][layer_id], mlp_dense_h_to_4h_weight, tp_size, tp_rank) + 
weights_dicts["mlp_dense_h_to_4h_bias"][layer_id] = row_concat(weights_dicts["mlp_dense_h_to_4h_bias"][layer_id], mlp_dense_h_to_4h_bias, tp_size, tp_rank) + weights_dicts["mlp_dense_4h_to_h_weight"][layer_id] = column_concat(weights_dicts["mlp_dense_4h_to_h_weight"][layer_id], mlp_dense_4h_to_h_weight, tp_size, tp_rank) + + self.set_huggingface_weight_by_name("word_embeddings.weight", weights_dicts["word_embeddings"]) + for layer_id in weights_dicts["self_attention_qkv_weight"]: + self.set_huggingface_weight_by_name(f"h.{layer_id}.self_attention.query_key_value.weight", weights_dicts["self_attention_qkv_weight"][layer_id]) + self.set_huggingface_weight_by_name(f"h.{layer_id}.self_attention.query_key_value.bias", weights_dicts["self_attention_qkv_bias"][layer_id]) + self.set_huggingface_weight_by_name(f"h.{layer_id}.self_attention.dense.weight", weights_dicts["self_attention_dense_weight"][layer_id]) + self.set_huggingface_weight_by_name(f"h.{layer_id}.mlp.dense_h_to_4h.weight", weights_dicts["mlp_dense_h_to_4h_weight"][layer_id]) + self.set_huggingface_weight_by_name(f"h.{layer_id}.mlp.dense_h_to_4h.bias", weights_dicts["mlp_dense_h_to_4h_bias"][layer_id]) + self.set_huggingface_weight_by_name(f"h.{layer_id}.mlp.dense_4h_to_h.weight", weights_dicts["mlp_dense_4h_to_h_weight"][layer_id]) + + return True + + def generate_huggingface_weight(self): + try: + num_layer, _, _ = AscendspeedToHuggingfaceConvert.model_config[self.model_type] + except KeyError: + print(f"Error! {self.model_type} is not supported!") + return False + if self.check_has_layer_model(): + self.convert_from_layer_model(self.pp_size, self.tp_size, num_layer) + else: + self.convert_from_mprank_model(self.pp_size, self.tp_size, num_layer) + os.makedirs(self.output_model_dir, exist_ok=True) + for file_name in self.huggingface_model: + dest_path = os.path.join(self.output_model_dir, file_name) + print(f"Saving huggingface model to : {dest_path}") + torch.save(self.huggingface_model[file_name], dest_path) + +def convert(args): + coverter = ConvertBloom(args) + coverter.generate_huggingface_weight() -- Gitee From a4281a3e15e79b3e685b772c42dccf91f6e82fcb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=A9=AC=E5=AE=81=E6=82=A6?= Date: Mon, 11 Mar 2024 11:40:22 +0000 Subject: [PATCH 2/4] add debug/weight_convert/convert_ckpt.py. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 马宁悦 --- debug/weight_convert/convert_ckpt.py | 55 ++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) create mode 100644 debug/weight_convert/convert_ckpt.py diff --git a/debug/weight_convert/convert_ckpt.py b/debug/weight_convert/convert_ckpt.py new file mode 100644 index 0000000000..46a978dc81 --- /dev/null +++ b/debug/weight_convert/convert_ckpt.py @@ -0,0 +1,55 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the License); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import os +import sys +import argparse +import importlib +import torch.multiprocessing as mp + +def load_model(model_name): + module_name = f"{model_name}" + try: + converter = importlib.import_module(module_name) + except ModuleNotFoundError: + module_name = model_name + try: + converter = importlib.import_module(module_name) + except ModuleNotFoundError: + sys.exit(f"Unable to load {model_name}. Exiting.") + return converter + + + +def main(): + + parser = argparse.ArgumentParser(description="convert as 2 hf") + + parser.add_argument('-m','--model', type=str, required=True, + choices=['llama', 'bloom', 'gptneox'], + help='Type of the model') + parser.add_argument('-i','--input-model-dir', type=str, required=True, + help='Directory to load model checkpoint from') + parser.add_argument('-o','--output-model-dir', type=str, required=True, + help='Directory to save model checkpoint to') + parser.add_argument('-t','--tensor-parallel-size', type=int, required=True) + parser.add_argument('-p','--pipeline-parallel-size', type=int, required=True) + parser.add_argument('--model-type', type=str, required=True) + parser.add_argument("--partition-layers", type=str, help="the partition method of model when pipeline is used") + args = parser.parse_args() + + converter = load_model(args.model) + converter.convert(args) + +if __name__ == '__main__': + main() -- Gitee From 5dc2a49bf8d524309af6c45362467a3f78b7322c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=A9=AC=E5=AE=81=E6=82=A6?= Date: Mon, 11 Mar 2024 11:41:45 +0000 Subject: [PATCH 3/4] add debug/weight_convert/llama.py. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 马宁悦 --- debug/weight_convert/llama.py | 252 ++++++++++++++++++++++++++++++++++ 1 file changed, 252 insertions(+) create mode 100644 debug/weight_convert/llama.py diff --git a/debug/weight_convert/llama.py b/debug/weight_convert/llama.py new file mode 100644 index 0000000000..cae3ecbe60 --- /dev/null +++ b/debug/weight_convert/llama.py @@ -0,0 +1,252 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the License); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
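+# NOTE: llama.py, as posted, has no import block and calls row_concat/column_concat
+# helpers that are not included anywhere in this patch series. The imports and the
+# two sketched helpers below are assumptions added so the module can run; they merge
+# tensor-parallel shards along dim 0 (row) / dim 1 (column) and should be swapped
+# for the project's own utilities if those differ.
+import json
+import os
+import shutil
+
+import torch
+
+
+def row_concat(merged, shard, tp_size, tp_rank):
+    # Merge tensor-parallel shards along dim 0 (output/vocab dimension).
+    if tp_size == 1:
+        return shard
+    return shard if merged is None else torch.cat([merged, shard], dim=0)
+
+
+def column_concat(merged, shard, tp_size, tp_rank):
+    # Merge tensor-parallel shards along dim 1 (input dimension).
+    if tp_size == 1:
+        return shard
+    return shard if merged is None else torch.cat([merged, shard], dim=1)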
+
+class ConvertLlama:
+
+    model_config = {
+        "7B": [32, 4096, 32],  # num_layers, hidden_size, num_attention_heads
+        "13B": [40, 5120, 40],
+        "70B": [80, 8192, 64]
+    }
+
+    def __init__(self, args):
+        self.huggingface_model = {}
+        self.layer_weight_idx = {}
+        self.tp_size = args.tensor_parallel_size
+        self.pp_size = args.pipeline_parallel_size
+        self.model_type = args.model_type
+        self.output_huggingface_model_dir = args.output_model_dir
+        self.ascendspeed_model_dir = args.input_model_dir
+        self.pp_layers = self.get_partition_layers(args.partition_layers)
+        self.init_huggingface_model()
+
+    def get_partition_layers(self, partition_layers):
+        if self.model_type == "7B" and self.pp_size == 1:
+            return [32]
+        elif self.model_type == "7B" and self.pp_size == 2:
+            return [16, 16]
+        else:
+            return list(map(int, partition_layers.split(',')))
+
+    def init_huggingface_model(self):
+        model_index = {}
+        params = ["input_layernorm.bias", "input_layernorm.weight", "mlp.dense_4h_to_h.bias", "mlp.dense_4h_to_h.weight", \
+            "mlp.dense_h_to_4h.bias", "mlp.dense_h_to_4h.weight", "post_attention_layernorm.bias", "post_attention_layernorm.weight", \
+            "self_attention.dense.bias", "self_attention.dense.weight", "self_attention.query_key_value.bias", "self_attention.query_key_value.weight"]
+
+        for pp_rank in range(self.pp_size):
+            for offset in range(self.pp_layers[pp_rank]):
+                layer_id = sum(self.pp_layers[:pp_rank]) + offset
+                dest_model_filepath = "pytorch_model_{:05d}-of-00072.bin".format(layer_id + 2)
+                for param in params:
+                    self.layer_weight_idx[f"h.{layer_id}.{param}"] = dest_model_filepath
+
+                self.huggingface_model[dest_model_filepath] = {}
+        self.layer_weight_idx["ln_f.bias"] = "pytorch_model_00072-of-00072.bin"
+        self.layer_weight_idx["ln_f.weight"] = "pytorch_model_00072-of-00072.bin"
+        self.layer_weight_idx["word_embeddings.weight"] = "pytorch_model_00001-of-00072.bin"
+        self.layer_weight_idx["word_embeddings_layernorm.bias"] = "pytorch_model_00001-of-00072.bin"
+        self.layer_weight_idx["word_embeddings_layernorm.weight"] = "pytorch_model_00001-of-00072.bin"
+        model_index["weight_map"] = self.layer_weight_idx
+        model_index["metadata"] = {"total_size": 0}
+        self.huggingface_model["pytorch_model_00072-of-00072.bin"] = {}
+        self.huggingface_model["pytorch_model_00001-of-00072.bin"] = {}
+
+        if not os.path.exists(self.output_huggingface_model_dir):
+            os.makedirs(self.output_huggingface_model_dir)
+
+        with os.fdopen(os.open(os.path.join(self.output_huggingface_model_dir, "pytorch_model.bin.index.json"), \
+            os.O_WRONLY | os.O_CREAT, 0o640), 'w') as f:
+            f.write(json.dumps(model_index, indent=4))
+
+        config_files = ["config.json", "special_tokens_map.json", "tokenizer_config.json", "tokenizer.json"]
+        for _file in config_files:
+            srcfile = os.path.join(self.ascendspeed_model_dir, _file)
+            if os.path.exists(srcfile):
+                shutil.copy2(srcfile, self.output_huggingface_model_dir)
+            else:
+                print(f"warning: {srcfile} does not exist!")
+
+    def set_huggingface_weight_by_name(self, layer_weight, w):
+        '''
+        Store a HuggingFace weight tensor; layer_weight_idx maps the weight name
+        to the binary shard file it will be saved into.
+        '''
+        self.huggingface_model[self.layer_weight_idx[layer_weight]][layer_weight] = w
+
+    def check_has_layer_model(self):
+        one_layer_model_path_sample = os.path.join(self.ascendspeed_model_dir, "layer_01-model_00-model_states.pt")
+        if os.path.exists(one_layer_model_path_sample):
+            return True
+        return False
+
+    def convert_from_layer_model(self, pp_size, tp_size, num_layers):
+        weights_dicts = 
{"word_embeddings": None, "self_attention_qkv_weight": {}, "self_attention_qkv_bias": {}, \ + "self_attention_dense_weight": {}, "mlp_dense_h_to_4h_weight": {}, "mlp_dense_h_to_4h_bias": {}, "mlp_dense_4h_to_h_weight": {}} + + for pp_rank in range(pp_size): + for tp_rank in range(tp_size): + if pp_rank == 0: + model_path = os.path.join(self.ascendspeed_model_dir, "layer_01-model_{:02d}-model_states.pt".format(tp_rank)) + ascendspeed_model = torch.load(model_path, map_location="cpu") + + self.set_huggingface_weight_by_name("word_embeddings_layernorm.weight", ascendspeed_model["word_embeddings.norm.weight"]) + self.set_huggingface_weight_by_name("word_embeddings_layernorm.bias", ascendspeed_model["word_embeddings.norm.bias"]) + word_embeddings_read = ascendspeed_model["word_embeddings.weight"] + weights_dicts["word_embeddings"] = row_concat(weights_dicts["word_embeddings"], word_embeddings_read, tp_size, tp_rank) + + if pp_rank == pp_size - 1: + as_layer_id = num_layers + 4 + model_path = os.path.join(self.ascendspeed_model_dir, "layer_{:02d}-model_{:02d}-model_states.pt".format(as_layer_id, tp_rank)) + ascendspeed_model = torch.load(model_path, map_location="cpu") + self.set_huggingface_weight_by_name("ln_f.weight", ascendspeed_model["weight"]) + self.set_huggingface_weight_by_name("ln_f.bias", ascendspeed_model["bias"]) + + for i in range(self.pp_layers[pp_rank]): + layer_id = sum(self.pp_layers[:pp_rank]) + i + as_layer_id = layer_id + 3 + model_path = os.path.join(self.ascendspeed_model_dir, "layer_{:02d}-model_{:02d}-model_states.pt".format(as_layer_id, tp_rank)) + ascendspeed_model = torch.load(model_path, map_location="cpu") + + self.set_huggingface_weight_by_name(f"h.{layer_id}.input_layernorm.weight", ascendspeed_model["input_layernorm.weight"]) + self.set_huggingface_weight_by_name(f"h.{layer_id}.input_layernorm.bias", ascendspeed_model["input_layernorm.bias"]) + + self.set_huggingface_weight_by_name(f"h.{layer_id}.post_attention_layernorm.weight", ascendspeed_model["post_attention_layernorm.weight"]) + self.set_huggingface_weight_by_name(f"h.{layer_id}.post_attention_layernorm.bias", ascendspeed_model["post_attention_layernorm.bias"]) + self.set_huggingface_weight_by_name(f"h.{layer_id}.self_attention.dense.bias", ascendspeed_model["self_attention.dense.bias"]) + self.set_huggingface_weight_by_name(f"h.{layer_id}.mlp.dense_4h_to_h.bias", ascendspeed_model["mlp.dense_4h_to_h.bias"]) + + self_attention_qkv_weight = ascendspeed_model["self_attention.query_key_value.weight"] + self_attention_qkv_bias = ascendspeed_model["self_attention.query_key_value.bias"] + self_attention_dense_weight = ascendspeed_model["self_attention.dense.weight"] + mlp_dense_h_to_4h_weight = ascendspeed_model["mlp.dense_h_to_4h.weight"] + mlp_dense_h_to_4h_bias = ascendspeed_model["mlp.dense_h_to_4h.bias"] + mlp_dense_4h_to_h_weight = ascendspeed_model["mlp.dense_4h_to_h.weight"] + + if layer_id not in weights_dicts["self_attention_qkv_weight"]: + weights_dicts["self_attention_qkv_weight"][layer_id] = None + weights_dicts["self_attention_qkv_bias"][layer_id] = None + weights_dicts["self_attention_dense_weight"][layer_id] = None + weights_dicts["mlp_dense_h_to_4h_weight"][layer_id] = None + weights_dicts["mlp_dense_h_to_4h_bias"][layer_id] = None + weights_dicts["mlp_dense_4h_to_h_weight"][layer_id] = None + + weights_dicts["self_attention_qkv_weight"][layer_id] = row_concat(weights_dicts["self_attention_qkv_weight"][layer_id], self_attention_qkv_weight, tp_size, tp_rank) + 
weights_dicts["self_attention_qkv_bias"][layer_id] = row_concat(weights_dicts["self_attention_qkv_bias"][layer_id], self_attention_qkv_bias, tp_size, tp_rank) + weights_dicts["self_attention_dense_weight"][layer_id] = column_concat(weights_dicts["self_attention_dense_weight"][layer_id], self_attention_dense_weight, tp_size, tp_rank) + weights_dicts["mlp_dense_h_to_4h_weight"][layer_id] = row_concat(weights_dicts["mlp_dense_h_to_4h_weight"][layer_id], mlp_dense_h_to_4h_weight, tp_size, tp_rank) + weights_dicts["mlp_dense_h_to_4h_bias"][layer_id] = row_concat(weights_dicts["mlp_dense_h_to_4h_bias"][layer_id], mlp_dense_h_to_4h_bias, tp_size, tp_rank) + weights_dicts["mlp_dense_4h_to_h_weight"][layer_id] = column_concat(weights_dicts["mlp_dense_4h_to_h_weight"][layer_id], mlp_dense_4h_to_h_weight, tp_size, tp_rank) + + + self.set_huggingface_weight_by_name("word_embeddings.weight", weights_dicts["word_embeddings"]) + for layer_id in weights_dicts["self_attention_qkv_weight"]: + self.set_huggingface_weight_by_name(f"h.{layer_id}.self_attention.query_key_value.weight", weights_dicts["self_attention_qkv_weight"][layer_id]) + self.set_huggingface_weight_by_name(f"h.{layer_id}.self_attention.query_key_value.bias", weights_dicts["self_attention_qkv_bias"][layer_id]) + self.set_huggingface_weight_by_name(f"h.{layer_id}.self_attention.dense.weight", weights_dicts["self_attention_dense_weight"][layer_id]) + self.set_huggingface_weight_by_name(f"h.{layer_id}.mlp.dense_h_to_4h.weight", weights_dicts["mlp_dense_h_to_4h_weight"][layer_id]) + self.set_huggingface_weight_by_name(f"h.{layer_id}.mlp.dense_h_to_4h.bias", weights_dicts["mlp_dense_h_to_4h_bias"][layer_id]) + self.set_huggingface_weight_by_name(f"h.{layer_id}.mlp.dense_4h_to_h.weight", weights_dicts["mlp_dense_4h_to_h_weight"][layer_id]) + + return True + + def convert_from_mprank_model(self, pp_size, tp_size, num_layers): + weights_dicts = {"word_embeddings": None, "self_attention_qkv_weight": {}, "self_attention_qkv_bias": {}, \ + "self_attention_dense_weight": {}, "mlp_dense_h_to_4h_weight": {}, "mlp_dense_h_to_4h_bias": {}, "mlp_dense_4h_to_h_weight": {}} + + for pp_rank in range(pp_size): + for tp_rank in range(tp_size): + model_path = os.path.join(self.ascendspeed_model_dir, f"{'mp_rank_{:02d}'.format(pp_rank * tp_size + tp_rank)}_model_states.pt") + if not os.path.exists(model_path): + print(f"Error! 
{model_path} does not exist") + return False + as_pt_model = torch.load(model_path, map_location="cpu") + rank_model = as_pt_model["module"]["module"] + + if pp_rank == 0: + + self.set_huggingface_weight_by_name("word_embeddings_layernorm.weight", rank_model["tied_modules.embed.word_embeddings.norm.weight"]) + self.set_huggingface_weight_by_name("word_embeddings_layernorm.bias", rank_model["tied_modules.embed.word_embeddings.norm.bias"]) + word_embeddings_read = rank_model["tied_modules.embed.word_embeddings.weight"] + weights_dicts["word_embeddings"] = row_concat(weights_dicts["word_embeddings"], word_embeddings_read, tp_size, tp_rank) + + if pp_rank == pp_size - 1: + as_layer_id = num_layers + 4 + self.set_huggingface_weight_by_name("ln_f.weight", rank_model[f"{as_layer_id}.weight"]) + self.set_huggingface_weight_by_name("ln_f.bias", rank_model[f"{as_layer_id}.bias"]) + + for i in range(self.pp_layers[pp_rank]): + layer_id = sum(self.pp_layers[:pp_rank]) + i + as_layer_id = layer_id + 3 + + self.set_huggingface_weight_by_name(f"h.{layer_id}.input_layernorm.weight", rank_model[f"{as_layer_id}.input_layernorm.weight"]) + self.set_huggingface_weight_by_name(f"h.{layer_id}.input_layernorm.bias", rank_model[f"{as_layer_id}.input_layernorm.bias"]) + + self.set_huggingface_weight_by_name(f"h.{layer_id}.post_attention_layernorm.weight", rank_model[f"{as_layer_id}.post_attention_layernorm.weight"]) + self.set_huggingface_weight_by_name(f"h.{layer_id}.post_attention_layernorm.bias", rank_model[f"{as_layer_id}.post_attention_layernorm.bias"]) + self.set_huggingface_weight_by_name(f"h.{layer_id}.self_attention.dense.bias", rank_model[f"{as_layer_id}.self_attention.dense.bias"]) + self.set_huggingface_weight_by_name(f"h.{layer_id}.mlp.dense_4h_to_h.bias", rank_model[f"{as_layer_id}.mlp.dense_4h_to_h.bias"]) + + self_attention_qkv_weight = rank_model[f"{as_layer_id}.self_attention.query_key_value.weight"] + self_attention_qkv_bias = rank_model[f"{as_layer_id}.self_attention.query_key_value.bias"] + self_attention_dense_weight = rank_model[f"{as_layer_id}.self_attention.dense.weight"] + mlp_dense_h_to_4h_weight = rank_model[f"{as_layer_id}.mlp.dense_h_to_4h.weight"] + mlp_dense_h_to_4h_bias = rank_model[f"{as_layer_id}.mlp.dense_h_to_4h.bias"] + mlp_dense_4h_to_h_weight = rank_model[f"{as_layer_id}.mlp.dense_4h_to_h.weight"] + + if layer_id not in weights_dicts["self_attention_qkv_weight"]: + weights_dicts["self_attention_qkv_weight"][layer_id] = None + weights_dicts["self_attention_qkv_bias"][layer_id] = None + weights_dicts["self_attention_dense_weight"][layer_id] = None + weights_dicts["mlp_dense_h_to_4h_weight"][layer_id] = None + weights_dicts["mlp_dense_h_to_4h_bias"][layer_id] = None + weights_dicts["mlp_dense_4h_to_h_weight"][layer_id] = None + + weights_dicts["self_attention_qkv_weight"][layer_id] = row_concat(weights_dicts["self_attention_qkv_weight"][layer_id], self_attention_qkv_weight, tp_size, tp_rank) + weights_dicts["self_attention_qkv_bias"][layer_id] = row_concat(weights_dicts["self_attention_qkv_bias"][layer_id], self_attention_qkv_bias, tp_size, tp_rank) + weights_dicts["self_attention_dense_weight"][layer_id] = column_concat(weights_dicts["self_attention_dense_weight"][layer_id], self_attention_dense_weight, tp_size, tp_rank) + weights_dicts["mlp_dense_h_to_4h_weight"][layer_id] = row_concat(weights_dicts["mlp_dense_h_to_4h_weight"][layer_id], mlp_dense_h_to_4h_weight, tp_size, tp_rank) + weights_dicts["mlp_dense_h_to_4h_bias"][layer_id] = 
row_concat(weights_dicts["mlp_dense_h_to_4h_bias"][layer_id], mlp_dense_h_to_4h_bias, tp_size, tp_rank) + weights_dicts["mlp_dense_4h_to_h_weight"][layer_id] = column_concat(weights_dicts["mlp_dense_4h_to_h_weight"][layer_id], mlp_dense_4h_to_h_weight, tp_size, tp_rank) + + self.set_huggingface_weight_by_name("word_embeddings.weight", weights_dicts["word_embeddings"]) + for layer_id in weights_dicts["self_attention_qkv_weight"]: + self.set_huggingface_weight_by_name(f"h.{layer_id}.self_attention.query_key_value.weight", weights_dicts["self_attention_qkv_weight"][layer_id]) + self.set_huggingface_weight_by_name(f"h.{layer_id}.self_attention.query_key_value.bias", weights_dicts["self_attention_qkv_bias"][layer_id]) + self.set_huggingface_weight_by_name(f"h.{layer_id}.self_attention.dense.weight", weights_dicts["self_attention_dense_weight"][layer_id]) + self.set_huggingface_weight_by_name(f"h.{layer_id}.mlp.dense_h_to_4h.weight", weights_dicts["mlp_dense_h_to_4h_weight"][layer_id]) + self.set_huggingface_weight_by_name(f"h.{layer_id}.mlp.dense_h_to_4h.bias", weights_dicts["mlp_dense_h_to_4h_bias"][layer_id]) + self.set_huggingface_weight_by_name(f"h.{layer_id}.mlp.dense_4h_to_h.weight", weights_dicts["mlp_dense_4h_to_h_weight"][layer_id]) + + return True + + def generate_huggingface_weight(self): + try: + num_layer, _, _ = AscendspeedToHuggingfaceConvert.model_config[self.model_type] + except KeyError: + print(f"Error! {self.model_type} is not supported!") + return False + if self.check_has_layer_model(): + self.convert_from_layer_model(self.pp_size, self.tp_size, num_layer) + else: + self.convert_from_mprank_model(self.pp_size, self.tp_size, num_layer) + os.makedirs(self.output_huggingface_model_dir, exist_ok=True) + for file_name in self.huggingface_model: + dest_path = os.path.join(self.output_huggingface_model_dir, file_name) + print(f"Saving huggingface model to : {dest_path}") + torch.save(self.huggingface_model[file_name], dest_path) + +def convert_llama(args): + coverter = AscendspeedToHuggingfaceConvert(args) + coverter.generate_huggingface_weight() -- Gitee From 1964521b888c2b2238af0d6ae1060476d1285a03 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=A9=AC=E5=AE=81=E6=82=A6?= Date: Tue, 12 Mar 2024 03:16:59 +0000 Subject: [PATCH 4/4] update README.md. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 马宁悦 --- README.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/README.md b/README.md index cb203544c7..8998eaa476 100644 --- a/README.md +++ b/README.md @@ -21,6 +21,9 @@ Ascend Training Tools,昇腾训练工具链。针对训练&大模型场景, 脚本迁移工具提供后端命令行用于将GPU上训练的PyTorch脚本迁移至NPU上,得到新的训练脚本用于训练。 +4. [脚本迁移工具](https://gitee.com/ascend/att/wikis/%E5%B7%A5%E5%85%B7%E4%BB%8B%E7%BB%8D/%E5%88%86%E6%9E%90%E8%BF%81%E7%A7%BB%E5%B7%A5%E5%85%B7/%E8%84%9A%E6%9C%AC%E8%BF%81%E7%A7%BB%E5%B7%A5%E5%85%B7%E4%BD%BF%E7%94%A8%E6%8C%87%E5%AF%BC) + + 脚本迁移工具提供后端命令行用于将GPU上训练的PyTorch脚本迁移至NPU上,得到新的训练脚本用于推理。 ### [精度工具](https://gitee.com/ascend/att/tree/master/debug/accuracy_tools) -- Gitee