# pretrain_gpt.py
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

"""Pretrain GPT."""

import os
from functools import partial
from typing import Union

import mindspore as ms
import mindspore.communication.comm_func as comm_func
from mindspore import mint

import mindspeed_ms
from mindspeed_ms.training import get_args, print_rank_0, get_tokenizer
from mindspeed_ms.core import mpu
from mindspeed_ms.core.enums import ModelType
from mindspeed_ms.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder
from mindspeed_ms.core.datasets.utils import get_blend_from_list
from mindspeed_ms.core.datasets.gpt_dataset import GPTDatasetConfig, MockGPTDataset, GPTDataset
from mindspeed_ms.core.models.gpt import GPTModel
from mindspeed_ms.training import pretrain
from mindspeed_ms.core.transformer.spec_utils import import_module
from mindspeed_ms.training.utils import (
    get_batch_on_this_cp_rank,
    get_batch_on_this_tp_rank,
)
from mindspeed_ms.training.arguments import core_transformer_config_from_args
from mindspeed_ms.training.yaml_arguments import core_transformer_config_from_yaml
from mindspeed_ms.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec
from mindspeed_ms.core.tensor_parallel.mappings import ReduceFromContextParallelRegion


def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, mindspeed_ms.legacy.model.GPTModel]:
    """Builds the model.

    The mcore GPT model is returned when args.use_mcore_models is set; otherwise the legacy GPT model is returned.

    Args:
        pre_process (bool, optional): Set to true if you need to compute embeddings. Defaults to True.
        post_process (bool, optional): Set to true if you want to compute output logits/loss. Defaults to True.

    Returns:
        Union[GPTModel, mindspeed_ms.legacy.model.GPTModel]: The returned model
    """
    args = get_args()
    use_te = args.transformer_impl == "transformer_engine"

    print_rank_0('building GPT model ...')

    # Experimental loading arguments from yaml
    if args.yaml_cfg is not None:
        config = core_transformer_config_from_yaml(args, "language_model")
    else:
        config = core_transformer_config_from_args(args)

    if args.use_mcore_models:
        if args.spec is not None:
            transformer_layer_spec = import_module(args.spec)
        else:
            # pylint: disable=R1720
            if use_te:
                raise NotImplementedError("'transformer_engine' is not supported for now.")
            else:
                transformer_layer_spec = get_gpt_layer_local_spec(args.num_experts,
                                                                  args.moe_grouped_gemm,
                                                                  args.qk_layernorm)

        model = GPTModel(
            config=config,
            transformer_layer_spec=transformer_layer_spec,
            vocab_size=args.padded_vocab_size,
            max_sequence_length=args.max_position_embeddings,
            pre_process=pre_process,
            post_process=post_process,
            fp16_lm_cross_entropy=args.fp16_lm_cross_entropy,
            parallel_output=True,
            share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights,
            position_embedding_type=args.position_embedding_type,
            rotary_percent=args.rotary_percent,
            rotary_base=args.rotary_base
        )
    else:
        model = mindspeed_ms.legacy.model.GPTModel(
            config,
            num_tokentypes=0,
            parallel_output=True,
            pre_process=pre_process,
            post_process=post_process,
        )

    return model
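# For illustration only: the actual call to model_provider is made inside
# mindspeed_ms.training.pretrain, which is expected to pass pre_process/post_process
# according to the pipeline stage of each rank. A direct call would look like:
#   model = model_provider(pre_process=True, post_process=True)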


def get_batch(data_iterator):
    """Generate a batch."""
    # TODO: this is pretty hacky, find a better way
    if (not mpu.is_pipeline_first_stage()) and (not mpu.is_pipeline_last_stage()):
        return None, None, None, None, None

    # get batches based on the TP rank you are on
    batch = get_batch_on_this_tp_rank(data_iterator)

    # slice batch along sequence dimension for context parallelism
    batch = get_batch_on_this_cp_rank(batch)

    # The batch dict is expected to yield (tokens, labels, loss_mask, attention_mask,
    # position_ids), matching how forward_step unpacks the result.
    return batch.values()


def loss_func(loss_mask: ms.Tensor, output_tensor: ms.Tensor):
    """Loss function.

    Args:
        loss_mask (ms.Tensor): Used to mask out some portions of the loss
        output_tensor (ms.Tensor): The tensor with the losses

    Returns:
        the loss scalar for this micro-batch
        the number of non-padded tokens in this micro-batch
        a dict containing reporting metrics on the loss and number of tokens across
        the data parallel ranks
    """
    args = get_args()

    losses = output_tensor.float()
    loss_mask = loss_mask.view(-1).float()
    total_tokens = loss_mask.sum()
    loss = mint.cat([mint.sum(losses.view(-1) * loss_mask).view(1), total_tokens.view(1)])

    # Aggregate the (masked loss, token count) pair across context-parallel ranks.
    if args.context_parallel_size > 1:
        loss = ReduceFromContextParallelRegion()(loss)

    # Check individual rank losses are not NaN prior to DP all-reduce.
    if args.check_for_nan_in_loss_and_grad:
        global_rank = ms.communication.get_rank()
        assert not loss[0].isnan(), (
            f'Rank {global_rank}: found NaN in local forward loss calculation. '
            f'Device: {ms.hal.get_device_name()}, node: {os.uname()[1]}'
        )

    # Reduce loss for logging.
    reporting_loss = loss.copy()
    reporting_loss = comm_func.all_reduce(reporting_loss, group=mpu.get_data_parallel_group())[0]

    local_num_tokens = loss[1].copy().to(ms.int32)
    return (
        loss[0] * args.context_parallel_size,
        local_num_tokens,
        {'lm loss': (reporting_loss[0], reporting_loss[1])},
    )
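# Note (assumption): the factor of args.context_parallel_size on the returned local loss
# follows the Megatron-LM loss_func convention; the final normalization by the global
# token count is expected to happen in the mindspeed_ms.training loop, not here.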


def forward_step(data_iterator, model: GPTModel):
    """Forward training step.

    Args:
        data_iterator : Input data iterator
        model (GPTModel): The GPT Model

    Returns:
        the packed model inputs, the forward callable, and loss_func with loss_mask bound
    """
    # Get the batch.
    tokens, labels, loss_mask, attention_mask, position_ids = get_batch(data_iterator)
    input_tensor = (tokens, labels, attention_mask, position_ids)

    def core_forward_func(*args):
        """ core forward func """
        tokens, labels, attention_mask, position_ids = args
        output_tensor = model(tokens, position_ids, attention_mask, labels=labels)
        return output_tensor

    return input_tensor, core_forward_func, partial(loss_func, loss_mask)


def is_dataset_built_on_rank():
    """Only build datasets on the first/last pipeline stages and on tensor-parallel rank 0."""
    return (mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage()) and mpu.get_tensor_model_parallel_rank() == 0


def core_gpt_dataset_config_from_args(args):
    """Build the GPTDatasetConfig from the parsed command-line arguments."""
    tokenizer = get_tokenizer()

    return GPTDatasetConfig(
        random_seed=args.seed,
        sequence_length=args.seq_length,
        blend=get_blend_from_list(args.data_path),
        blend_per_split=[
            get_blend_from_list(args.train_data_path),
            get_blend_from_list(args.valid_data_path),
            get_blend_from_list(args.test_data_path)
        ],
        split=args.split,
        num_dataset_builder_threads=args.num_dataset_builder_threads,
        path_to_cache=args.data_cache_path,
        mmap_bin_files=args.mmap_bin_files,
        tokenizer=tokenizer,
        reset_position_ids=args.reset_position_ids,
        reset_attention_mask=args.reset_attention_mask,
        eod_mask_loss=args.eod_mask_loss,
        create_attention_mask=args.create_attention_mask_in_dataloader,
        s3_cache_path=args.s3_cache_path
    )


def train_valid_test_datasets_provider(train_val_test_num_samples):
    """Build the train, validation, and test datasets.

    Args:
        train_val_test_num_samples : A list containing the number of samples in the train,
            validation, and test sets.
    """
    args = get_args()

    config = core_gpt_dataset_config_from_args(args)

    if args.mock_data:
        dataset_type = MockGPTDataset
    else:
        dataset_type = GPTDataset

    print_rank_0("> building train, validation, and test datasets for GPT ...")

    train_ds, valid_ds, test_ds = BlendedMegatronDatasetBuilder(
        dataset_type,
        train_val_test_num_samples,
        is_dataset_built_on_rank,
        config
    ).build()

    print_rank_0("> finished creating GPT datasets ...")

    return train_ds, valid_ds, test_ds
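# Illustration only: when args.mock_data is set (via a command-line flag such as
# --mock-data; the exact name depends on mindspeed_ms.training.arguments), MockGPTDataset
# is used so the script can be smoke-tested without a preprocessed corpus.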


if __name__ == "__main__":
    # Temporary for transition to core datasets
    train_valid_test_datasets_provider.is_distributed = True

    pretrain(
        train_valid_test_datasets_provider,
        model_provider,
        ModelType.encoder_or_decoder,
        forward_step,
        args_defaults={'tokenizer_type': 'GPT2BPETokenizer'},
    )
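
# Example launch (hypothetical, for illustration): with MindSpore's msrun launcher on a
# single 8-device Ascend node, something along these lines would start pretraining. The
# actual flag set comes from mindspeed_ms.training.arguments and the project's shell
# scripts; the Megatron-style names below are assumptions.
#   msrun --worker_num=8 --local_worker_num=8 pretrain_gpt.py \
#       --use-mcore-models \
#       --tensor-model-parallel-size 2 --pipeline-model-parallel-size 1 \
#       --num-layers 24 --hidden-size 1024 --num-attention-heads 16 \
#       --seq-length 2048 --max-position-embeddings 2048 \
#       --micro-batch-size 1 --global-batch-size 16 \
#       --data-path /path/to/dataset_prefix --vocab-file /path/to/gpt2-vocab.json \
#       --merge-file /path/to/gpt2-merges.txt --split 949,50,1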