From f6a5551cea3c5a4829af6bf3fcc263a12ee46dcf Mon Sep 17 00:00:00 2001
From: liujunzhu
Date: Thu, 3 Apr 2025 20:34:24 +0800
Subject: [PATCH] add no_init_weights

---
 tools/rules/line_rules.py | 117 ++++++++++++++++++++++++++++++++++++++
 1 file changed, 117 insertions(+)

diff --git a/tools/rules/line_rules.py b/tools/rules/line_rules.py
index ead60783..310f5eb8 100644
--- a/tools/rules/line_rules.py
+++ b/tools/rules/line_rules.py
@@ -258,6 +258,123 @@ LINE_RULES = {
 -        logits.register_hook(self.finalnorm_ckpt.recompute)
      if args.output_multiplier_scale:
          logits = logits * args.output_multiplier_scale"""],
+    "mindspeed_llm/tasks/posttrain/base/base_trainer.py":[
+        """ import sys
++import os
++from contextlib import contextmanager
+ import time""","""    _TRAIN_START_TIME = time.time()
+ 
++TORCH_INIT_FUNCTIONS = {
++    "uniform_": torch.nn.init.uniform_,
++    "normal_": torch.nn.init.normal_,
++    "trunc_normal_": torch.nn.init.trunc_normal_,
++    "constant_": torch.nn.init.constant_,
++    "xavier_uniform_": torch.nn.init.xavier_uniform_,
++    "xavier_normal_": torch.nn.init.xavier_normal_,
++    "kaiming_uniform_": torch.nn.init.kaiming_uniform_,
++    "kaiming_normal_": torch.nn.init.kaiming_normal_,
++    "uniform": torch.nn.init.uniform,
++    "normal": torch.nn.init.normal,
++    "xavier_uniform": torch.nn.init.xavier_uniform,
++    "xavier_normal": torch.nn.init.xavier_normal,
++    "kaiming_uniform": torch.nn.init.kaiming_uniform,
++    "kaiming_normal": torch.nn.init.kaiming_normal,
++}
++
++@contextmanager
++def no_init_weights(enable=True):
++    if enable:
++        def _skip_init(*args, **kwargs):
++            pass
++
++        # Override the listed initializers with a no-op (the originals are kept in TORCH_INIT_FUNCTIONS)
++        for name, init_func in TORCH_INIT_FUNCTIONS.items():
++            setattr(torch.nn.init, name, _skip_init)
++    try:
++        yield
++    finally:
++        if enable:
++            # Restore the original initialization functions
++            for name, init_func in TORCH_INIT_FUNCTIONS.items():
++                setattr(torch.nn.init, name, init_func)
++
+ 
+ class BaseTrainer(ABC):""","""        config = core_transformer_config_from_args(args)
+ 
+-        if args.use_mcore_models:
+-            if args.spec is not None:
+-                transformer_layer_spec = import_module(args.spec)
+-            else:
+-                if use_te:
+-                    transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(args.num_experts,
+-                                                                                        args.moe_grouped_gemm)
++        enable_no_init_weights = os.getenv('ENABLE_NO_INIT_WEIGHTS') == '1'
++        with no_init_weights(enable_no_init_weights):
++            if args.use_mcore_models:
++                if args.spec is not None:
++                    transformer_layer_spec = import_module(args.spec)
+             else:
+-                    transformer_layer_spec = get_gpt_layer_local_spec(args.num_experts, args.moe_grouped_gemm)
+-
+-            model = GPTModel(
+-                config=config,
+-                transformer_layer_spec=transformer_layer_spec,
+-                vocab_size=args.padded_vocab_size,
+-                max_sequence_length=args.max_position_embeddings,
+-                pre_process=pre_process,
+-                post_process=post_process,
+-                fp16_lm_cross_entropy=args.fp16_lm_cross_entropy,
+-                parallel_output=True,
+-                share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights,
+-                position_embedding_type=args.position_embedding_type,
+-                rotary_percent=args.rotary_percent,
+-                seq_len_interpolation_factor=args.rotary_seq_len_interpolation_factor
+-            )
+-        else:
+-            if not args.context_parallel_size == 1:
+-                raise ValueError("Context parallelism is only supported with Megatron Core!")
+-
+-            model = megatron.legacy.model.GPTModel(
+-                config,
+-                num_tokentypes=0,
+-                parallel_output=True,
+-                pre_process=pre_process,
+-                post_process=post_process
+-            )
++                    if use_te:
++                        transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(args.num_experts,
++                                                                                            args.moe_grouped_gemm)
++                    else:
++                        transformer_layer_spec = get_gpt_layer_local_spec(args.num_experts, args.moe_grouped_gemm)
++
++                model = GPTModel(
++                    config=config,
++                    transformer_layer_spec=transformer_layer_spec,
++                    vocab_size=args.padded_vocab_size,
++                    max_sequence_length=args.max_position_embeddings,
++                    pre_process=pre_process,
++                    post_process=post_process,
++                    fp16_lm_cross_entropy=args.fp16_lm_cross_entropy,
++                    parallel_output=True,
++                    share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights,
++                    position_embedding_type=args.position_embedding_type,
++                    rotary_percent=args.rotary_percent,
++                    seq_len_interpolation_factor=args.rotary_seq_len_interpolation_factor
++                )
++            else:
++                if not args.context_parallel_size == 1:
++                    raise ValueError("Context parallelism is only supported with Megatron Core!")
++
++                model = megatron.legacy.model.GPTModel(
++                    config,
++                    num_tokentypes=0,
++                    parallel_output=True,
++                    pre_process=pre_process,
++                    post_process=post_process
++                )
+ 
+         return model"""
+    ],
     "mindspeed_llm/tasks/posttrain/orm/orm_model.py":[
         """        # we sometimes want to run our RM head in FP32, this allows it
 -        autocast_context = torch.autocast(device_type=hidden_states.device.type, dtype=self.dtype)
-- 
Gitee
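
A minimal standalone sketch of the weight-init skip this patch wires into BaseTrainer.model_provider, assuming only PyTorch is installed. The no_init_weights context manager, the TORCH_INIT_FUNCTIONS table (shortened here), and the ENABLE_NO_INIT_WEIGHTS=1 switch mirror the diff above; the small __main__ demo around them is illustrative only and is not part of the change.

# Standalone sketch: temporarily replace torch.nn.init functions with a no-op.
# The real table in the patch also covers trunc_normal_, constant_, the
# xavier_*/kaiming_* variants, and the deprecated non-underscore aliases.
import os
from contextlib import contextmanager

import torch

TORCH_INIT_FUNCTIONS = {
    "uniform_": torch.nn.init.uniform_,
    "normal_": torch.nn.init.normal_,
    "kaiming_uniform_": torch.nn.init.kaiming_uniform_,
}


@contextmanager
def no_init_weights(enable=True):
    """Turn the listed torch.nn.init functions into no-ops while active."""
    if enable:
        def _skip_init(*args, **kwargs):
            pass

        for name in TORCH_INIT_FUNCTIONS:
            setattr(torch.nn.init, name, _skip_init)
    try:
        yield
    finally:
        if enable:
            # Restore the real initializers even if model construction raised.
            for name, init_func in TORCH_INIT_FUNCTIONS.items():
                setattr(torch.nn.init, name, init_func)


if __name__ == "__main__":
    # Same switch the patch reads in model_provider.
    enable = os.getenv("ENABLE_NO_INIT_WEIGHTS") == "1"

    t = torch.zeros(4)
    with no_init_weights(enable):
        torch.nn.init.normal_(t)   # skipped when ENABLE_NO_INIT_WEIGHTS=1
    print("initialized inside context:", bool(t.abs().sum() > 0))

    torch.nn.init.normal_(t)       # original function is back here
    print("initialized after context:", bool(t.abs().sum() > 0))

With ENABLE_NO_INIT_WEIGHTS=1 the first print shows False (the tensor is left as allocated) and the second shows True once the initializers are restored. Modules built inside the context, such as the GPTModel constructed in model_provider, are likewise left with uninitialized storage, which saves time when the weights are about to be overwritten anyway (for example by a checkpoint load).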