From 1182b593376af928b6906892b49bd5486c07b67d Mon Sep 17 00:00:00 2001 From: avocadovo Date: Thu, 2 Jan 2025 17:44:50 +0800 Subject: [PATCH 1/4] =?UTF-8?q?=E3=80=90Tinker=E3=80=91=E6=A8=A1=E5=9D=97?= =?UTF-8?q?=E8=BF=81=E7=A7=BB=20model=20Part2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- profiler/tinker/model/block_adapter_1_1.py | 333 ++++++++++++++++++++ profiler/tinker/model/block_adapter_1_2.py | 340 +++++++++++++++++++++ 2 files changed, 673 insertions(+) create mode 100644 profiler/tinker/model/block_adapter_1_1.py create mode 100644 profiler/tinker/model/block_adapter_1_2.py diff --git a/profiler/tinker/model/block_adapter_1_1.py b/profiler/tinker/model/block_adapter_1_1.py new file mode 100644 index 000000000..28a3ce6fa --- /dev/null +++ b/profiler/tinker/model/block_adapter_1_1.py @@ -0,0 +1,333 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from contextlib import nullcontext +from dataclasses import dataclass + +import torch + +from megatron import core +from megatron.core import tensor_parallel +from megatron.legacy.model.gpt_model import post_language_model_processing +from tinker.megatron_patch.microbatches import get_num_microbatches +from tinker.model.block_adapter import BlockAdapter + + +@dataclass +class TransformerForwardInput: + hidden_states: any = None + attention_mask: any = None + encoder_output: any = None + enc_dec_attn_mask: any = None + retriever_input: any = None + retriever_output: any = None + retriever_attn_mask: any = None + inference_params: any = None + rotary_pos_emb: any = None + + +class TransformerBlockAdapter(BlockAdapter): + """parallel_transformer_forward""" + + def __init__(self): + pass + + @staticmethod + def origin_forward(self, input_data: TransformerForwardInput): + hidden_states = input_data.hidden_states + attention_mask = input_data.attention_mask + encoder_output = input_data.encoder_output + enc_dec_attn_mask = input_data.enc_dec_attn_mask + retriever_input = input_data.retriever_input + retriever_output = input_data.retriever_output + retriever_attn_mask = input_data.retriever_attn_mask + inference_params = input_data.inference_params + rotary_pos_emb = input_data.rotary_pos_emb + + #Checks. + if inference_params: + if self.recompute_granularity is not None: + raise ValueError('inference does not work with activation checkpointing') + + if not self.pre_process: + # See set_input_tensor() + hidden_states = self.input_tensor + + # Viewless tensor + # - We only need to create a viewless tensor in the case of micro batch + # size (mbs) == 1, since in this case, 'hidden_states.transpose()' + # above creates a view tensor, and '.contiguous()' is a pass-through. + # For mbs >= 2, '.contiguous()' creates a new tensor, eliminating + # the need to make it viewless. 
+ # + # However, we don't explicitly check mbs == 1 here because + # make_viewless_tensor() has negligible overhead when its input + # is already viewless. + # + # - For the 'else' case above, calling make_viewless_tensor() here is + # likely redundant, since p2p_communication.py (likely originator) + # already creates viewless tensors. That said, make_viewless_tensor() + # is called here to be future-proof and corner-case-proof. + if self.input_embeds_norm and self.pre_process: + hidden_states = hidden_states * (self.hidden_size ** 0.5) + + hidden_states = core.utils.make_viewless_tensor( + hidden_states, + requires_grad=True, + keep_graph=True, + ) + + # RNG context. + if self.sequence_parallel: + rng_context = tensor_parallel.get_cuda_rng_tracker().fork() + else: + rng_context = nullcontext() + + # Forward layers + with rng_context: + # The fp8_autocast context manager is a no-op when enabled=True + # The if...else serves to short circuit name resolution for fp8_autocast + with transformer_engine.pytorch.fp8_autocast( + enabled=self.use_fp8, + fp8_recipe=self.fp8_recipe, + fp8_group=self.fp8_group + ) if self.use_fp8 else nullcontext(): + # Determine if the current iteration is first microbatch + if self.num_microbatches_in_previous_step != get_num_microbatches(): + self.microbatch_count = 0 # Reset count on new batch size rampup interval + self.num_microbatches_in_previous_step = get_num_microbatches() + is_first_microbatch = self.microbatch_count % get_num_microbatches() == 0 + + # Forward pass. + if self.recompute_granularity == 'full': + hidden_states = self._checkpointed_forward(hidden_states, + attention_mask, + encoder_output, + enc_dec_attn_mask, + rotary_pos_emb, + is_first_microbatch) + else: + forward_kwargs = { + 'encoder_output': encoder_output, + 'enc_dec_attn_mask': enc_dec_attn_mask, + 'inference_params': inference_params, + } + + if self.transformer_impl == 'transformer_engine': + forward_kwargs['is_first_microbatch'] = is_first_microbatch + forward_kwargs['checkpoint_core_attention'] = self.checkpoint_core_attention + if self.transformer_engine_v_0_10: + forward_kwargs['rotary_pos_emb'] = rotary_pos_emb + else: + forward_kwargs['rotary_pos_emb'] = rotary_pos_emb + forward_kwargs['retriever_input'] = retriever_input + forward_kwargs['retriever_output'] = retriever_output + forward_kwargs['retriever_attn_mask'] = retriever_attn_mask + + for index in range(self.num_layers): + layer = self._get_layer(index) + + hidden_states = layer( + hidden_states, + attention_mask, + **forward_kwargs) + + # First Retro decoder layer returns both hidden_states + # and retriever_output. Make retriever_output available + # to subsequence Retro layers. 
+ if isinstance(hidden_states, tuple): + if len(hidden_states) != 2: + raise ValueError("hidden_states should be a tuple of length 2") + hidden_states, retriever_output = hidden_states + forward_kwargs["retriever_output"] = retriever_output + + # Skip counter update for eval and activation checkpointing + if torch.is_grad_enabled() and self.training: + self.microbatch_count += 1 + return hidden_states + + def get_output_name(self): + return ['hidden_states'] + + def copy_method_module(self, other): + other._get_layer = self.module._get_layer + other.layers = self.module.layers + other._checkpointed_forward = self.module._checkpointed_forward # 其实没必要,因为profile时不会用到,但可起到标注作用 + if other.transformer_impl == 'transformer_engine': + global transformer_engine + # noinsepction PyUnresolvedReferences + import transformer_engine + + def weight_module(self): + return self.module + + +@dataclass +class EmbeddingForwardInput: + enc_input_ids: any = None + enc_position_ids: any = None + enc_attn_mask: any = None + dec_input_ids: any = None + dec_position_ids: any = None + dec_attn_mask: any = None + retriever_input_ids: any = None + retriever_position_ids: any = None + retriever_attn_mask: any = None + enc_dec_attn_mask: any = None + tokentype_ids: any = None + inference_params: any = None + pooling_sequence_index: any = 0 + enc_hidden_states: any = None + output_enc_hidden: any = False + + +class EmbeddingAdapter(BlockAdapter): + """megatron.legacy.model.language_model.TransformerLanguageModel.forward""" + + def __init__(self): + pass + + @staticmethod + def origin_forward(self, input_data: EmbeddingForwardInput): + enc_input_ids = input_data.enc_input_ids + enc_position_ids = input_data.enc_position_ids + enc_attn_mask = input_data.enc_attn_mask + dec_input_ids = input_data.dec_input_ids + dec_position_ids = input_data.dec_position_ids + dec_attn_mask = input_data.dec_attn_mask + retriever_input_ids = input_data.retriever_input_ids + retriever_position_ids = input_data.retriever_position_ids + retriever_attn_mask = input_data.retriever_attn_mask + enc_dec_attn_mask = input_data.enc_dec_attn_mask + tokentype_ids = input_data.tokentype_ids + inference_params = input_data.inference_params + pooling_sequence_index = input_data.pooling_sequence_index + enc_hidden_states = input_data.enc_hidden_states + output_enc_hidden = input_data.output_enc_hidden + + # Encoder embedding. + if self.pre_process: + encoder_input = self.embedding(enc_input_ids, enc_position_ids, + tokentype_ids=tokentype_ids) + else: + encoder_input = None + + # Retriever embedding. 
+ if self.add_retriever and self.pre_process: + retriever_input = self.embedding(retriever_input_ids, + retriever_position_ids, + tokentype_ids=tokentype_ids) + else: + retriever_input = None + + # Rotary positional embeddings + rotary_pos_emb = None + if self.use_rotary_position_embeddings: + if inference_params is not None: + rotary_pos_emb = \ + self.rotary_pos_emb(inference_params.max_sequence_length) + else: + rotary_pos_emb = self.rotary_pos_emb(self.seq_length) + + return encoder_input, rotary_pos_emb + + def get_output_name(self): + return ['encoder_input', 'rotary_pos_emb'] + + def copy_method_module(self, other): + other.embedding = self.module.embedding + if other.use_rotary_position_embeddings: + other.rotary_pos_emb = self.module.rotary_pos_emb + + def weight_module(self): + return self.module.embedding + + +class FinalNormAdapter(BlockAdapter): + """ParallelTransformer forward 中 self.final_norm 相关部分""" + + def __init__(self): + pass + + @staticmethod + def origin_forward(self, hidden_states): + hidden_states = self.final_norm(hidden_states) + + return hidden_states + + def get_output_name(self): + return ['hidden_states'] + + def copy_method_module(self, other): + other.final_norm = self.module.final_norm + + def weight_module(self): + return self.module.final_norm + + +@dataclass +class LossForwardInput: + input_ids: any = None + position_ids: any = None + attention_mask: any = None + lm_output: any = None + retriever_input_ids: any = None + retriever_position_ids: any = None + retriever_attn_mask: any = None + labels: any = None + tokentype_ids: any = None + inference_params: any = None + + +class LossAdapter(BlockAdapter): + """modellink.model.GPTModel""" + + def __init__(self): + pass + + @staticmethod + def origin_forward(self, input_data: LossForwardInput): + input_ids = input_data.input_ids + position_ids = input_data.position_ids + attention_mask = input_data.attention_mask + lm_output = input_data.lm_output + retriever_input_ids = input_data.retriever_input_ids + retriever_position_ids = input_data.retriever_position_ids + retriever_attn_mask = input_data.retriever_attn_mask + labels = input_data.labels + tokentype_ids = input_data.tokentype_ids + inference_params = input_data.inference_params + + output = post_language_model_processing( + lm_output, labels, + # 此处不进行判断 而直接输入该形状匹配的内容 + self.logit_weights, + self.parallel_output, + self.fp16_lm_cross_entropy) + + return output + + def get_output_name(self): + return ["output"] + + def copy_method_module(self, other): + # git 适配 not self.untie_embeddings_and_output_weights 情况 + if other.untie_embeddings_and_output_weights: + other.logit_weights = self.module.language_model.output_layer.weight + else: + other.logit_weights = self.module.language_model.embedding.word_embeddings.weight + + def weight_module(self): + return self.module.language_model.output_layer diff --git a/profiler/tinker/model/block_adapter_1_2.py b/profiler/tinker/model/block_adapter_1_2.py new file mode 100644 index 000000000..38782daa3 --- /dev/null +++ b/profiler/tinker/model/block_adapter_1_2.py @@ -0,0 +1,340 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from contextlib import nullcontext +from dataclasses import dataclass + + +import torch +from torch import Tensor + +from megatron.core import parallel_state, tensor_parallel, InferenceParams +from megatron.core.packed_seq_params import PackedSeqParams +from meagtron.core.utils import make_viewless_tensor +from megatron.training import get_args +from tinker.model import block_adapter_1_1 +from tinker.model.block_adapter import BlockAdapter + +EmbeddingAdapter = block_adapter_1_1.EmbeddingAdapter +TransformerBlockAdapter = block_adapter_1_1.TransformerBlockAdapter +FinalNormAdapter = block_adapter_1_1.FinalNormAdapter +LossAdapter = block_adapter_1_1.LossAdapter + + +class DummyDecoder: + def __init__(self): + self.input_tensor = None + + +@dataclass +class McoreForwardInput: + input_ids: Tensor = None + position_ids: Tensor = None + attention_mask: Tensor = None + decoder_input: Tensor = None + labels: Tensor = None + inference_params: InferenceParams = None + packed_seq_params: PackedSeqParams = None + extra_block_kwargs: dict = None + tokentype_ids=None + + +class McoreEmbeddingAdapter(BlockAdapter): + """modellink.core.gpt_model_forward""" + + def __init__(self): + pass + + @staticmethod + def origin_forward(self, input_data: McoreForwardInput): + input_ids = input_data.input_ids + position_ids = input_data.position_ids + attention_mask = input_data.attention_mask + decoder_input = input_data.decoder_input + labels = input_data.labels + inference_params = input_data.inference_params + packed_seq_params = input_data.packed_seq_params + extra_block_kwargs = input_data.extra_block_kwargs + tokentype_ids = input_data.tokentype_ids + """modellink.core.gpt_model_forward""" + # If decoder_input is provides (not None), then input_ids and position_ids are ignored. + # Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input. + args = get_args() + # Decoder embedding. 
+ if decoder_input is not None: + pass + elif self.pre_process: + decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids) + if args.scale_emb is not None: + decoder_input = decoder_input * args.scale_emb + else: + # intermediate stage of pipeline + # decoder will get hidden_states from encoder.input_tensor + decoder_input = None + + # Rotary positional embeddings (embedding is None for PP intermediate devices) + rotary_pos_emb = None + if self.position_embedding_type == 'rope': + rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len( + inference_params, self.decoder, decoder_input, self.config + ) + rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len) + + return decoder_input, rotary_pos_emb + + def get_output_name(self): + return ['decoder_input', 'rotary_pos_emb'] + + def copy_method_module(self, other): + # 此decoder用于避免前向报错,但其本身值并不被使用 + other.decoder = DummyDecoder() + other.embedding = self.module.embedding + other.rotary_pos_emb = self.module.rotary_pos_emb + + def weight_module(self): + return self.module.embedding, self.module.rotary_pos_emb + + +@dataclass +class McoreTransformerForwardInput: + hidden_states: Tensor = None + attention_mask: Tensor = None + context: Tensor = None + context_mask: Tensor = None + rotary_pos_emb: Tensor = None + inference_params: InferenceParams = None + packed_seq_params: PackedSeqParams = None + + +class McoreTransformerBlockAdapter(BlockAdapter): + """megatron.core.transformer.transformer_block.TransformerBlock""" + + def __init__(self): + pass + + @staticmethod + def origin_forward(self, input_data: McoreTransformerForwardInput): + """modellink.core.transformer_block_forward""" + hidden_states = input_data.hidden_states + attention_mask = input_data.attention_mask + context = input_data.context + context_mask = input_data.context_mask + rotary_pos_emb = input_data.rotary_pos_emb + inference_params = input_data.inference_params + packed_seq_params = input_data.packed_seq_params + + # hidden_states (float): [s, b, h] + # attention_mask (bool): [1, 1, s, s] + + if not self.pre_process: + # See set_input_tensor() + hidden_states = self.input_tensor + + # Viewless tensor. + # - We only need to create a viewless tensor in the case of micro batch + # size (mbs) == 1, since in this case, 'hidden_states.transpose()' + # above creates a view tensor, and '.contiguous()' is a pass-through. + # For mbs >= 2, '.contiguous()' creates a new tensor, eliminating + # the need to make it viewless. + # + # However, we don't explicitly check mbs == 1 here because + # make_viewless_tensor() has negligible overhead when its input + # is already viewless. + # + # - For the 'else' case above, calling make_viewless_tensor() here is + # likely redundant, since p2p_communication.py (likely originator) + # already creates viewless tensors. That said, make_viewless_tensor() + # is called here to be future-proof and corner-case-proof. 
+ if self.input_embeds_norm and self.pre_process: + normalizer = torch.tensor(self.hidden_size ** 0.5, dtype=hidden_states.dtype) + hidden_states = hidden_states * normalizer + + hidden_states = make_viewless_tensor( + inp=hidden_states, requires_grad=True, keep_graph=True, + ) + + if self.config.sequence_parallel: + rng_context = tensor_parallel.get_cuda_rng_tracker().fork() + else: + rng_context = nullcontext() + + if self.config.fp8: + import transformer_engine # To keep out TE dependency when not training in fp8 + + if self.config.fp8 == "e4m3": + fp8_format = transformer_engine.common.recipe.Format.E4M3 + elif self.config.fp8 == "hybrid": + fp8_format = transformer_engine.common.recipe.Format.HYBRID + else: + raise ValueError("E4M3 and HYBRID are the only supported FP8 formats.") + + fp8_recipe = transformer_engine.common.recipe.DelayedScaling( + margin=self.config.fp8_margin, + interval=self.config.fp8_interval, + fp8_format=fp8_format, + amax_compute_algo=self.config.fp8_amax_compute_algo, + amax_history_len=self.config.fp8_amax_history_len, + override_linear_precision=(False, False, not self.config.fp8_wgrad), + ) + fp8_group = None + if parallel_state.model_parallel_is_initialized(): + fp8_group = parallel_state.get_amax_reduction_group(with_context_parallel=True) + fp8_context = transformer_engine.pytorch.fp8_autocast( + enabled=True, fp8_recipe=fp8_recipe, fp8_group=fp8_group + ) + else: + fp8_context = nullcontext() + + with rng_context and fp8_context: + # Forward pass. + if self.config.recompute_granularity == 'full' and self.training: + hidden_states = self._checkpointed_forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + context=context, + context_mask=context_mask, + rotary_pos_emb=rotary_pos_emb, + packed_seq_params=packed_seq_params, + ) + else: + for layer in self.layers: + with self.offload_context: + hidden_states, context = layer( + hidden_states=hidden_states, + attention_mask=attention_mask, + context=context, + context_mask=context_mask, + rotary_pos_emb=rotary_pos_emb, + inference_params=inference_params, + packed_seq_params=packed_seq_params, + ) + + if ( + torch.is_grad_enabled() + and self.config.cpu_offloading + and self.group_prefetch_offload_commit_async is not None + ): + hidden_states = self.group_prefetch_offload_commit_async(hidden_states) + + return hidden_states + + def get_output_name(self): + return ['hidden_states'] + + def copy_method_module(self, other): + other.layers = self.module.layers + + def weight_module(self): + return self.module.layers + + +class McoreFinalNormAdapter(BlockAdapter): + + def __init__(self): + pass + + @staticmethod + def origin_forward(self, hidden_states: Tensor = None): + """modellink.core.transformer_block_forward""" + hidden_states = self.final_layernorm(hidden_states) + + return hidden_states + + def get_output_name(self): + return ['hidden_states'] + + def copy_method_module(self, other): + other.final_layernorm = self.module.final_layernorm + + def weight_module(self): + return self.module.final_layernorm + + +@dataclass +class McoreLossForwardInput: + input_ids: Tensor = None + label: any = None + hidden_states: any = None + position_ids: Tensor = None + attention_mask: Tensor = None + decoder_input: Tensor = None + labels: Tensor = None + inference_params: InferenceParams = None + packed_seq_params: PackedSeqParams = None + extra_block_kwargs: dict = None + tokentype_ids: any = None + + +class McoreLossAdapter(BlockAdapter): + """modellink.core.gpt_model_forward""" + + def __init__(self): 
+ pass + + @staticmethod + def origin_forward(self, input_data: McoreLossForwardInput): + """modellink.core.gpt_model_forward""" + input_ids = input_data.input_ids + label = input_data.label + hidden_states = input_data.hidden_states + position_ids = input_data.position_ids + attention_mask = input_data.attention_mask + decoder_input = input_data.decoder_input + labels = input_data.labels + inference_params = input_data.inference_params + packed_seq_params = input_data.packed_seq_params + extra_block_kwargs = input_data.extra_block_kwargs + tokentype_ids = input_data.tokentype_ids + + args = get_args() + # logits and loss + output_weight = None + if self.share_embeddings_and_output_weights: + output_weight = self.shared_embedding_or_output_weight() + + if args.dim_model_base is not None: + hidden_states = hidden_states / (args.hidden_size / args.dim_model_base) + logits, _ = self.output_layer(hidden_states, weight=output_weight) + # new add to scale logits + if args.output_multiplier_scale: + logits = logits * args.output_multiplier_scale + + if args.output_logit_softcapping: + logits = logits / args.output_logit_softcapping + logits = torch.tanh(logits) + logits = logits * args.output_logit_softcapping + + if labels is None: + # [s b h] => [b s h] + return logits.transpose(0, 1).contiguous() + + if args.is_instruction_dataset: + labels = labels[:, 1:].contiguous() + logits = logits[:-1, :, :].contiguous() + + loss = self.compute_language_model_loss(labels, logits) + + return loss + + def get_output_name(self): + return ['loss'] + + def copy_method_module(self, other): + other.output_layer = self.module.output_layer + other.compute_language_model_loss = self.module.compute_language_model_loss + other.shared_embedding_or_output_weight = self.module.shared_embedding_or_output_weight + + def weight_module(self): + return self.module.output_layer -- Gitee From 4eb47e6b0d840b613f7c358d9f67882c877794ae Mon Sep 17 00:00:00 2001 From: avocadovo Date: Thu, 2 Jan 2025 19:28:37 +0800 Subject: [PATCH 2/4] =?UTF-8?q?=E3=80=90Tinker=E3=80=91model=20=E5=8F=AF?= =?UTF-8?q?=E8=AF=BB=E6=80=A7=E5=86=85=E5=AE=B9=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- profiler/tinker/model/block_adapter_1_1.py | 3 +-- profiler/tinker/model/block_adapter_1_2.py | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/profiler/tinker/model/block_adapter_1_1.py b/profiler/tinker/model/block_adapter_1_1.py index 28a3ce6fa..7fccb257a 100644 --- a/profiler/tinker/model/block_adapter_1_1.py +++ b/profiler/tinker/model/block_adapter_1_1.py @@ -236,8 +236,7 @@ class EmbeddingAdapter(BlockAdapter): rotary_pos_emb = None if self.use_rotary_position_embeddings: if inference_params is not None: - rotary_pos_emb = \ - self.rotary_pos_emb(inference_params.max_sequence_length) + rotary_pos_emb = self.rotary_pos_emb(inference_params.max_sequence_length) else: rotary_pos_emb = self.rotary_pos_emb(self.seq_length) diff --git a/profiler/tinker/model/block_adapter_1_2.py b/profiler/tinker/model/block_adapter_1_2.py index 38782daa3..50ceba808 100644 --- a/profiler/tinker/model/block_adapter_1_2.py +++ b/profiler/tinker/model/block_adapter_1_2.py @@ -88,8 +88,7 @@ class McoreEmbeddingAdapter(BlockAdapter): rotary_pos_emb = None if self.position_embedding_type == 'rope': rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len( - inference_params, self.decoder, decoder_input, self.config - ) + inference_params, self.decoder, decoder_input, self.config) 
rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len) return decoder_input, rotary_pos_emb -- Gitee From 4dbc7a1537650865c69c59a9630f3ad0db468de8 Mon Sep 17 00:00:00 2001 From: avocadovo Date: Fri, 3 Jan 2025 17:04:38 +0800 Subject: [PATCH 3/4] =?UTF-8?q?=E3=80=90Tinker=E3=80=91=E6=A8=A1=E5=9D=97?= =?UTF-8?q?=E5=8F=98=E9=87=8F=E5=90=8D=E7=BB=9F=E4=B8=80=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- profiler/tinker/model/block_adapter_1_1.py | 262 ++++++++++----------- profiler/tinker/model/block_adapter_1_2.py | 241 +++++++++---------- 2 files changed, 230 insertions(+), 273 deletions(-) diff --git a/profiler/tinker/model/block_adapter_1_1.py b/profiler/tinker/model/block_adapter_1_1.py index 7fccb257a..c6a5709dd 100644 --- a/profiler/tinker/model/block_adapter_1_1.py +++ b/profiler/tinker/model/block_adapter_1_1.py @@ -27,15 +27,15 @@ from tinker.model.block_adapter import BlockAdapter @dataclass class TransformerForwardInput: - hidden_states: any = None - attention_mask: any = None - encoder_output: any = None - enc_dec_attn_mask: any = None - retriever_input: any = None - retriever_output: any = None - retriever_attn_mask: any = None - inference_params: any = None - rotary_pos_emb: any = None + model_hidden_states: any = None + model_attention_mask: any = None + model_encoder_output: any = None + model_enc_model_dec_attn_mask: any = None + model_retriever_input: any = None + model_retriever_output: any = None + model_retriever_attn_mask: any = None + model_inference_params: any = None + model_rotary_pos_emb: any = None class TransformerBlockAdapter(BlockAdapter): @@ -46,45 +46,29 @@ class TransformerBlockAdapter(BlockAdapter): @staticmethod def origin_forward(self, input_data: TransformerForwardInput): - hidden_states = input_data.hidden_states - attention_mask = input_data.attention_mask - encoder_output = input_data.encoder_output - enc_dec_attn_mask = input_data.enc_dec_attn_mask - retriever_input = input_data.retriever_input - retriever_output = input_data.retriever_output - retriever_attn_mask = input_data.retriever_attn_mask - inference_params = input_data.inference_params - rotary_pos_emb = input_data.rotary_pos_emb + model_hidden_states = input_data.model_hidden_states + model_attention_mask = input_data.model_attention_mask + model_encoder_output = input_data.model_encoder_output + model_enc_model_dec_attn_mask = input_data.model_enc_model_dec_attn_mask + model_retriever_input = input_data.model_retriever_input + model_retriever_output = input_data.model_retriever_output + model_retriever_attn_mask = input_data.model_retriever_attn_mask + model_inference_params = input_data.model_inference_params + model_rotary_pos_emb = input_data.model_rotary_pos_emb #Checks. - if inference_params: + if model_inference_params: if self.recompute_granularity is not None: raise ValueError('inference does not work with activation checkpointing') if not self.pre_process: - # See set_input_tensor() - hidden_states = self.input_tensor - - # Viewless tensor - # - We only need to create a viewless tensor in the case of micro batch - # size (mbs) == 1, since in this case, 'hidden_states.transpose()' - # above creates a view tensor, and '.contiguous()' is a pass-through. - # For mbs >= 2, '.contiguous()' creates a new tensor, eliminating - # the need to make it viewless. - # - # However, we don't explicitly check mbs == 1 here because - # make_viewless_tensor() has negligible overhead when its input - # is already viewless. 
- # - # - For the 'else' case above, calling make_viewless_tensor() here is - # likely redundant, since p2p_communication.py (likely originator) - # already creates viewless tensors. That said, make_viewless_tensor() - # is called here to be future-proof and corner-case-proof. + model_hidden_states = self.input_tensor + if self.input_embeds_norm and self.pre_process: - hidden_states = hidden_states * (self.hidden_size ** 0.5) + model_hidden_states = model_hidden_states * (self.hidden_size ** 0.5) - hidden_states = core.utils.make_viewless_tensor( - hidden_states, + model_hidden_states = core.utils.make_viewless_tensor( + model_hidden_states, requires_grad=True, keep_graph=True, ) @@ -97,8 +81,6 @@ class TransformerBlockAdapter(BlockAdapter): # Forward layers with rng_context: - # The fp8_autocast context manager is a no-op when enabled=True - # The if...else serves to short circuit name resolution for fp8_autocast with transformer_engine.pytorch.fp8_autocast( enabled=self.use_fp8, fp8_recipe=self.fp8_recipe, @@ -112,54 +94,50 @@ class TransformerBlockAdapter(BlockAdapter): # Forward pass. if self.recompute_granularity == 'full': - hidden_states = self._checkpointed_forward(hidden_states, - attention_mask, - encoder_output, - enc_dec_attn_mask, - rotary_pos_emb, + model_hidden_states = self._checkpointed_forward(model_hidden_states, + model_attention_mask, + model_encoder_output, + model_enc_model_dec_attn_mask, + model_rotary_pos_emb, is_first_microbatch) else: - forward_kwargs = { - 'encoder_output': encoder_output, - 'enc_dec_attn_mask': enc_dec_attn_mask, - 'inference_params': inference_params, + forward_pass_kwargs = { + 'model_encoder_output': model_encoder_output, + 'model_enc_model_dec_attn_mask': model_enc_model_dec_attn_mask, + 'model_inference_params': model_inference_params, } if self.transformer_impl == 'transformer_engine': - forward_kwargs['is_first_microbatch'] = is_first_microbatch - forward_kwargs['checkpoint_core_attention'] = self.checkpoint_core_attention + forward_pass_kwargs['is_first_microbatch'] = is_first_microbatch + forward_pass_kwargs['checkpoint_core_attention'] = self.checkpoint_core_attention if self.transformer_engine_v_0_10: - forward_kwargs['rotary_pos_emb'] = rotary_pos_emb + forward_pass_kwargs['model_rotary_pos_emb'] = model_rotary_pos_emb else: - forward_kwargs['rotary_pos_emb'] = rotary_pos_emb - forward_kwargs['retriever_input'] = retriever_input - forward_kwargs['retriever_output'] = retriever_output - forward_kwargs['retriever_attn_mask'] = retriever_attn_mask + forward_pass_kwargs['model_rotary_pos_emb'] = model_rotary_pos_emb + forward_pass_kwargs['model_retriever_input'] = model_retriever_input + forward_pass_kwargs['model_retriever_output'] = model_retriever_output + forward_pass_kwargs['model_retriever_attn_mask'] = model_retriever_attn_mask for index in range(self.num_layers): layer = self._get_layer(index) - hidden_states = layer( - hidden_states, - attention_mask, - **forward_kwargs) + model_hidden_states = layer( + model_hidden_states, + model_attention_mask, + **forward_pass_kwargs) - # First Retro decoder layer returns both hidden_states - # and retriever_output. Make retriever_output available - # to subsequence Retro layers. 
- if isinstance(hidden_states, tuple): - if len(hidden_states) != 2: - raise ValueError("hidden_states should be a tuple of length 2") - hidden_states, retriever_output = hidden_states - forward_kwargs["retriever_output"] = retriever_output - - # Skip counter update for eval and activation checkpointing + if isinstance(model_hidden_states, tuple): + if len(model_hidden_states) != 2: + raise ValueError("model_hidden_states should be a tuple of length 2") + model_hidden_states, model_retriever_output = model_hidden_states + forward_pass_kwargs["model_retriever_output"] = model_retriever_output + if torch.is_grad_enabled() and self.training: self.microbatch_count += 1 - return hidden_states + return model_hidden_states def get_output_name(self): - return ['hidden_states'] + return ['model_hidden_states'] def copy_method_module(self, other): other._get_layer = self.module._get_layer @@ -176,21 +154,21 @@ class TransformerBlockAdapter(BlockAdapter): @dataclass class EmbeddingForwardInput: - enc_input_ids: any = None - enc_position_ids: any = None - enc_attn_mask: any = None - dec_input_ids: any = None - dec_position_ids: any = None - dec_attn_mask: any = None - retriever_input_ids: any = None - retriever_position_ids: any = None - retriever_attn_mask: any = None - enc_dec_attn_mask: any = None - tokentype_ids: any = None - inference_params: any = None - pooling_sequence_index: any = 0 - enc_hidden_states: any = None - output_enc_hidden: any = False + model_enc_input_ids: any = None + model_enc_model_position_ids: any = None + model_enc_attn_mask: any = None + model_dec_input_ids: any = None + model_dec_model_position_ids: any = None + model_dec_attn_mask: any = None + model_retriever_input_ids: any = None + model_retriever_model_position_ids: any = None + model_retriever_attn_mask: any = None + model_enc_model_dec_attn_mask: any = None + model_tokentype_ids: any = None + model_inference_params: any = None + model_pooling_sequence_index: any = 0 + model_enc_model_hidden_states: any = None + model_output_enc_hidden: any = False class EmbeddingAdapter(BlockAdapter): @@ -201,54 +179,54 @@ class EmbeddingAdapter(BlockAdapter): @staticmethod def origin_forward(self, input_data: EmbeddingForwardInput): - enc_input_ids = input_data.enc_input_ids - enc_position_ids = input_data.enc_position_ids - enc_attn_mask = input_data.enc_attn_mask - dec_input_ids = input_data.dec_input_ids - dec_position_ids = input_data.dec_position_ids - dec_attn_mask = input_data.dec_attn_mask - retriever_input_ids = input_data.retriever_input_ids - retriever_position_ids = input_data.retriever_position_ids - retriever_attn_mask = input_data.retriever_attn_mask - enc_dec_attn_mask = input_data.enc_dec_attn_mask - tokentype_ids = input_data.tokentype_ids - inference_params = input_data.inference_params - pooling_sequence_index = input_data.pooling_sequence_index - enc_hidden_states = input_data.enc_hidden_states - output_enc_hidden = input_data.output_enc_hidden + model_enc_input_ids = input_data.model_enc_input_ids + model_enc_model_position_ids = input_data.model_enc_model_position_ids + model_enc_attn_mask = input_data.model_enc_attn_mask + model_dec_input_ids = input_data.model_dec_input_ids + model_dec_model_position_ids = input_data.model_dec_model_position_ids + model_dec_attn_mask = input_data.model_dec_attn_mask + model_retriever_input_ids = input_data.model_retriever_input_ids + model_retriever_model_position_ids = input_data.model_retriever_model_position_ids + model_retriever_attn_mask = 
input_data.model_retriever_attn_mask + model_enc_model_dec_attn_mask = input_data.model_enc_model_dec_attn_mask + model_tokentype_ids = input_data.model_tokentype_ids + model_inference_params = input_data.model_inference_params + model_pooling_sequence_index = input_data.model_pooling_sequence_index + model_enc_model_hidden_states = input_data.model_enc_model_hidden_states + model_output_enc_hidden = input_data.model_output_enc_hidden # Encoder embedding. if self.pre_process: - encoder_input = self.embedding(enc_input_ids, enc_position_ids, - tokentype_ids=tokentype_ids) + encoder_input = self.embedding(model_enc_input_ids, model_enc_model_position_ids, + model_tokentype_ids=model_tokentype_ids) else: encoder_input = None # Retriever embedding. if self.add_retriever and self.pre_process: - retriever_input = self.embedding(retriever_input_ids, - retriever_position_ids, - tokentype_ids=tokentype_ids) + model_retriever_input = self.embedding(model_retriever_input_ids, + model_retriever_model_position_ids, + model_tokentype_ids=model_tokentype_ids) else: - retriever_input = None + model_retriever_input = None # Rotary positional embeddings - rotary_pos_emb = None + model_rotary_pos_emb = None if self.use_rotary_position_embeddings: - if inference_params is not None: - rotary_pos_emb = self.rotary_pos_emb(inference_params.max_sequence_length) + if model_inference_params is not None: + model_rotary_pos_emb = self.model_rotary_pos_emb(model_inference_params.max_sequence_length) else: - rotary_pos_emb = self.rotary_pos_emb(self.seq_length) + model_rotary_pos_emb = self.model_rotary_pos_emb(self.seq_length) - return encoder_input, rotary_pos_emb + return encoder_input, model_rotary_pos_emb def get_output_name(self): - return ['encoder_input', 'rotary_pos_emb'] + return ['encoder_input', 'model_rotary_pos_emb'] def copy_method_module(self, other): other.embedding = self.module.embedding if other.use_rotary_position_embeddings: - other.rotary_pos_emb = self.module.rotary_pos_emb + other.model_rotary_pos_emb = self.module.model_rotary_pos_emb def weight_module(self): return self.module.embedding @@ -261,13 +239,13 @@ class FinalNormAdapter(BlockAdapter): pass @staticmethod - def origin_forward(self, hidden_states): - hidden_states = self.final_norm(hidden_states) + def origin_forward(self, model_hidden_states): + model_hidden_states = self.final_norm(model_hidden_states) - return hidden_states + return model_hidden_states def get_output_name(self): - return ['hidden_states'] + return ['model_hidden_states'] def copy_method_module(self, other): other.final_norm = self.module.final_norm @@ -278,16 +256,16 @@ class FinalNormAdapter(BlockAdapter): @dataclass class LossForwardInput: - input_ids: any = None - position_ids: any = None - attention_mask: any = None - lm_output: any = None - retriever_input_ids: any = None - retriever_position_ids: any = None - retriever_attn_mask: any = None - labels: any = None - tokentype_ids: any = None - inference_params: any = None + model_input_ids: any = None + model_position_ids: any = None + model_attention_mask: any = None + model_lm_output: any = None + model_retriever_input_ids: any = None + model_retriever_model_position_ids: any = None + model_retriever_attn_mask: any = None + model_labels: any = None + model_tokentype_ids: any = None + model_inference_params: any = None class LossAdapter(BlockAdapter): @@ -298,19 +276,19 @@ class LossAdapter(BlockAdapter): @staticmethod def origin_forward(self, input_data: LossForwardInput): - input_ids = input_data.input_ids - 
position_ids = input_data.position_ids - attention_mask = input_data.attention_mask - lm_output = input_data.lm_output - retriever_input_ids = input_data.retriever_input_ids - retriever_position_ids = input_data.retriever_position_ids - retriever_attn_mask = input_data.retriever_attn_mask - labels = input_data.labels - tokentype_ids = input_data.tokentype_ids - inference_params = input_data.inference_params + model_input_ids = input_data.model_input_ids + model_position_ids = input_data.model_position_ids + model_attention_mask = input_data.model_attention_mask + model_lm_output = input_data.model_lm_output + model_retriever_input_ids = input_data.model_retriever_input_ids + model_retriever_model_position_ids = input_data.model_retriever_model_position_ids + model_retriever_attn_mask = input_data.model_retriever_attn_mask + model_labels = input_data.model_labels + model_tokentype_ids = input_data.model_tokentype_ids + model_inference_params = input_data.model_inference_params output = post_language_model_processing( - lm_output, labels, + model_lm_output, model_labels, # 此处不进行判断 而直接输入该形状匹配的内容 self.logit_weights, self.parallel_output, diff --git a/profiler/tinker/model/block_adapter_1_2.py b/profiler/tinker/model/block_adapter_1_2.py index 50ceba808..f0957f957 100644 --- a/profiler/tinker/model/block_adapter_1_2.py +++ b/profiler/tinker/model/block_adapter_1_2.py @@ -21,7 +21,7 @@ import torch from torch import Tensor from megatron.core import parallel_state, tensor_parallel, InferenceParams -from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.model_packed_seq_params import PackedSeqParams from meagtron.core.utils import make_viewless_tensor from megatron.training import get_args from tinker.model import block_adapter_1_1 @@ -40,15 +40,15 @@ class DummyDecoder: @dataclass class McoreForwardInput: - input_ids: Tensor = None - position_ids: Tensor = None - attention_mask: Tensor = None - decoder_input: Tensor = None - labels: Tensor = None - inference_params: InferenceParams = None - packed_seq_params: PackedSeqParams = None - extra_block_kwargs: dict = None - tokentype_ids=None + model_input_ids: Tensor = None + model_position_ids: Tensor = None + model_attention_mask: Tensor = None + model_decoder_input: Tensor = None + model_labels: Tensor = None + model_inference_params: InferenceParams = None + model_packed_seq_params: PackedSeqParams = None + model_extra_block_kwargs: dict = None + model_tokentype_ids=None class McoreEmbeddingAdapter(BlockAdapter): @@ -59,62 +59,59 @@ class McoreEmbeddingAdapter(BlockAdapter): @staticmethod def origin_forward(self, input_data: McoreForwardInput): - input_ids = input_data.input_ids - position_ids = input_data.position_ids - attention_mask = input_data.attention_mask - decoder_input = input_data.decoder_input - labels = input_data.labels - inference_params = input_data.inference_params - packed_seq_params = input_data.packed_seq_params - extra_block_kwargs = input_data.extra_block_kwargs - tokentype_ids = input_data.tokentype_ids + model_input_ids = input_data.model_input_ids + model_position_ids = input_data.model_position_ids + model_attention_mask = input_data.model_attention_mask + model_decoder_input = input_data.model_decoder_input + model_labels = input_data.model_labels + model_inference_params = input_data.model_inference_params + model_packed_seq_params = input_data.model_packed_seq_params + model_extra_block_kwargs = input_data.model_extra_block_kwargs + model_tokentype_ids = input_data.model_tokentype_ids 
"""modellink.core.gpt_model_forward""" - # If decoder_input is provides (not None), then input_ids and position_ids are ignored. - # Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input. + args = get_args() # Decoder embedding. - if decoder_input is not None: + if model_decoder_input is not None: pass elif self.pre_process: - decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids) + model_decoder_input = self.embedding(model_input_ids=model_input_ids, model_position_ids=model_position_ids) if args.scale_emb is not None: - decoder_input = decoder_input * args.scale_emb + model_decoder_input = model_decoder_input * args.scale_emb else: - # intermediate stage of pipeline - # decoder will get hidden_states from encoder.input_tensor - decoder_input = None + model_decoder_input = None # Rotary positional embeddings (embedding is None for PP intermediate devices) - rotary_pos_emb = None + model_rotary_pos_emb = None if self.position_embedding_type == 'rope': - rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len( - inference_params, self.decoder, decoder_input, self.config) - rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len) + rotary_seq_len = self.model_rotary_pos_emb.get_rotary_seq_len( + model_inference_params, self.decoder, model_decoder_input, self.config) + model_rotary_pos_emb = self.model_rotary_pos_emb(rotary_seq_len) - return decoder_input, rotary_pos_emb + return model_decoder_input, model_rotary_pos_emb def get_output_name(self): - return ['decoder_input', 'rotary_pos_emb'] + return ['model_decoder_input', 'model_rotary_pos_emb'] def copy_method_module(self, other): # 此decoder用于避免前向报错,但其本身值并不被使用 other.decoder = DummyDecoder() other.embedding = self.module.embedding - other.rotary_pos_emb = self.module.rotary_pos_emb + other.model_rotary_pos_emb = self.module.model_rotary_pos_emb def weight_module(self): - return self.module.embedding, self.module.rotary_pos_emb + return self.module.embedding, self.module.model_rotary_pos_emb @dataclass class McoreTransformerForwardInput: - hidden_states: Tensor = None - attention_mask: Tensor = None - context: Tensor = None - context_mask: Tensor = None - rotary_pos_emb: Tensor = None - inference_params: InferenceParams = None - packed_seq_params: PackedSeqParams = None + model_hidden_states: Tensor = None + model_attention_mask: Tensor = None + model_context: Tensor = None + model_context_mask: Tensor = None + model_rotary_pos_emb: Tensor = None + model_inference_params: InferenceParams = None + model_packed_seq_params: PackedSeqParams = None class McoreTransformerBlockAdapter(BlockAdapter): @@ -126,48 +123,30 @@ class McoreTransformerBlockAdapter(BlockAdapter): @staticmethod def origin_forward(self, input_data: McoreTransformerForwardInput): """modellink.core.transformer_block_forward""" - hidden_states = input_data.hidden_states - attention_mask = input_data.attention_mask - context = input_data.context - context_mask = input_data.context_mask - rotary_pos_emb = input_data.rotary_pos_emb - inference_params = input_data.inference_params - packed_seq_params = input_data.packed_seq_params - - # hidden_states (float): [s, b, h] - # attention_mask (bool): [1, 1, s, s] + model_hidden_states = input_data.model_hidden_states + model_attention_mask = input_data.model_attention_mask + model_context = input_data.model_context + model_context_mask = input_data.model_context_mask + model_rotary_pos_emb = input_data.model_rotary_pos_emb + model_inference_params = input_data.model_inference_params + 
model_packed_seq_params = input_data.model_packed_seq_params if not self.pre_process: # See set_input_tensor() - hidden_states = self.input_tensor - - # Viewless tensor. - # - We only need to create a viewless tensor in the case of micro batch - # size (mbs) == 1, since in this case, 'hidden_states.transpose()' - # above creates a view tensor, and '.contiguous()' is a pass-through. - # For mbs >= 2, '.contiguous()' creates a new tensor, eliminating - # the need to make it viewless. - # - # However, we don't explicitly check mbs == 1 here because - # make_viewless_tensor() has negligible overhead when its input - # is already viewless. - # - # - For the 'else' case above, calling make_viewless_tensor() here is - # likely redundant, since p2p_communication.py (likely originator) - # already creates viewless tensors. That said, make_viewless_tensor() - # is called here to be future-proof and corner-case-proof. + model_hidden_states = self.input_tensor + if self.input_embeds_norm and self.pre_process: - normalizer = torch.tensor(self.hidden_size ** 0.5, dtype=hidden_states.dtype) - hidden_states = hidden_states * normalizer + normalizer = torch.tensor(self.hidden_size ** 0.5, dtype=model_hidden_states.dtype) + model_hidden_states = model_hidden_states * normalizer - hidden_states = make_viewless_tensor( - inp=hidden_states, requires_grad=True, keep_graph=True, + model_hidden_states = make_viewless_tensor( + inp=model_hidden_states, requires_grad=True, keep_graph=True, ) if self.config.sequence_parallel: - rng_context = tensor_parallel.get_cuda_rng_tracker().fork() + forward_rng_context = tensor_parallel.get_cuda_rng_tracker().fork() else: - rng_context = nullcontext() + forward_rng_context = nullcontext() if self.config.fp8: import transformer_engine # To keep out TE dependency when not training in fp8 @@ -179,7 +158,7 @@ class McoreTransformerBlockAdapter(BlockAdapter): else: raise ValueError("E4M3 and HYBRID are the only supported FP8 formats.") - fp8_recipe = transformer_engine.common.recipe.DelayedScaling( + forward_fp8_recipe = transformer_engine.common.recipe.DelayedScaling( margin=self.config.fp8_margin, interval=self.config.fp8_interval, fp8_format=fp8_format, @@ -187,37 +166,37 @@ class McoreTransformerBlockAdapter(BlockAdapter): amax_history_len=self.config.fp8_amax_history_len, override_linear_precision=(False, False, not self.config.fp8_wgrad), ) - fp8_group = None + forward_fp8_group = None if parallel_state.model_parallel_is_initialized(): - fp8_group = parallel_state.get_amax_reduction_group(with_context_parallel=True) - fp8_context = transformer_engine.pytorch.fp8_autocast( - enabled=True, fp8_recipe=fp8_recipe, fp8_group=fp8_group + forward_fp8_group = parallel_state.get_amax_reduction_group(with_context_parallel=True) + forward_fp8_context = transformer_engine.pytorch.fp8_autocast( + enabled=True, forward_fp8_recipe=forward_fp8_recipe, forward_fp8_group=forward_fp8_group ) else: - fp8_context = nullcontext() + forward_fp8_context = nullcontext() - with rng_context and fp8_context: + with forward_rng_context and forward_fp8_context: # Forward pass. 
if self.config.recompute_granularity == 'full' and self.training: - hidden_states = self._checkpointed_forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - context=context, - context_mask=context_mask, - rotary_pos_emb=rotary_pos_emb, - packed_seq_params=packed_seq_params, + model_hidden_states = self._checkpointed_forward( + model_hidden_states=model_hidden_states, + model_attention_mask=model_attention_mask, + model_context=model_context, + model_context_mask=model_context_mask, + model_rotary_pos_emb=model_rotary_pos_emb, + model_packed_seq_params=model_packed_seq_params, ) else: for layer in self.layers: with self.offload_context: - hidden_states, context = layer( - hidden_states=hidden_states, - attention_mask=attention_mask, - context=context, - context_mask=context_mask, - rotary_pos_emb=rotary_pos_emb, - inference_params=inference_params, - packed_seq_params=packed_seq_params, + model_hidden_states, model_context = layer( + model_hidden_states=model_hidden_states, + model_attention_mask=model_attention_mask, + model_context=model_context, + model_context_mask=model_context_mask, + model_rotary_pos_emb=model_rotary_pos_emb, + model_inference_params=model_inference_params, + model_packed_seq_params=model_packed_seq_params, ) if ( @@ -225,12 +204,12 @@ class McoreTransformerBlockAdapter(BlockAdapter): and self.config.cpu_offloading and self.group_prefetch_offload_commit_async is not None ): - hidden_states = self.group_prefetch_offload_commit_async(hidden_states) + model_hidden_states = self.group_prefetch_offload_commit_async(model_hidden_states) - return hidden_states + return model_hidden_states def get_output_name(self): - return ['hidden_states'] + return ['model_hidden_states'] def copy_method_module(self, other): other.layers = self.module.layers @@ -245,14 +224,14 @@ class McoreFinalNormAdapter(BlockAdapter): pass @staticmethod - def origin_forward(self, hidden_states: Tensor = None): + def origin_forward(self, model_hidden_states: Tensor = None): """modellink.core.transformer_block_forward""" - hidden_states = self.final_layernorm(hidden_states) + model_hidden_states = self.final_layernorm(model_hidden_states) - return hidden_states + return model_hidden_states def get_output_name(self): - return ['hidden_states'] + return ['model_hidden_states'] def copy_method_module(self, other): other.final_layernorm = self.module.final_layernorm @@ -263,17 +242,17 @@ class McoreFinalNormAdapter(BlockAdapter): @dataclass class McoreLossForwardInput: - input_ids: Tensor = None - label: any = None - hidden_states: any = None - position_ids: Tensor = None - attention_mask: Tensor = None - decoder_input: Tensor = None - labels: Tensor = None - inference_params: InferenceParams = None - packed_seq_params: PackedSeqParams = None - extra_block_kwargs: dict = None - tokentype_ids: any = None + model_input_ids: Tensor = None + model_label: any = None + model_hidden_states: any = None + model_position_ids: Tensor = None + model_attention_mask: Tensor = None + model_decoder_input: Tensor = None + model_labels: Tensor = None + model_inference_params: InferenceParams = None + model_packed_seq_params: PackedSeqParams = None + model_extra_block_kwargs: dict = None + model_tokentype_ids: any = None class McoreLossAdapter(BlockAdapter): @@ -285,17 +264,17 @@ class McoreLossAdapter(BlockAdapter): @staticmethod def origin_forward(self, input_data: McoreLossForwardInput): """modellink.core.gpt_model_forward""" - input_ids = input_data.input_ids - label = input_data.label - hidden_states 
= input_data.hidden_states - position_ids = input_data.position_ids - attention_mask = input_data.attention_mask - decoder_input = input_data.decoder_input - labels = input_data.labels - inference_params = input_data.inference_params - packed_seq_params = input_data.packed_seq_params - extra_block_kwargs = input_data.extra_block_kwargs - tokentype_ids = input_data.tokentype_ids + model_input_ids = input_data.model_input_ids + model_label = input_data.model_label + model_hidden_states = input_data.model_hidden_states + model_position_ids = input_data.model_position_ids + model_attention_mask = input_data.model_attention_mask + model_decoder_input = input_data.model_decoder_input + model_labels = input_data.model_labels + model_inference_params = input_data.model_inference_params + model_packed_seq_params = input_data.model_packed_seq_params + model_extra_block_kwargs = input_data.model_extra_block_kwargs + model_tokentype_ids = input_data.model_tokentype_ids args = get_args() # logits and loss @@ -304,8 +283,8 @@ class McoreLossAdapter(BlockAdapter): output_weight = self.shared_embedding_or_output_weight() if args.dim_model_base is not None: - hidden_states = hidden_states / (args.hidden_size / args.dim_model_base) - logits, _ = self.output_layer(hidden_states, weight=output_weight) + model_hidden_states = model_hidden_states / (args.hidden_size / args.dim_model_base) + logits, _ = self.output_layer(model_hidden_states, weight=output_weight) # new add to scale logits if args.output_multiplier_scale: logits = logits * args.output_multiplier_scale @@ -315,15 +294,15 @@ class McoreLossAdapter(BlockAdapter): logits = torch.tanh(logits) logits = logits * args.output_logit_softcapping - if labels is None: + if model_labels is None: # [s b h] => [b s h] return logits.transpose(0, 1).contiguous() if args.is_instruction_dataset: - labels = labels[:, 1:].contiguous() + model_labels = model_labels[:, 1:].contiguous() logits = logits[:-1, :, :].contiguous() - loss = self.compute_language_model_loss(labels, logits) + loss = self.compute_language_model_loss(model_labels, logits) return loss -- Gitee From a5dadb6850b2c9bc8cad39893d7c89fe492b6187 Mon Sep 17 00:00:00 2001 From: avocadovo Date: Fri, 3 Jan 2025 17:37:18 +0800 Subject: [PATCH 4/4] =?UTF-8?q?=E3=80=90Tinker=E3=80=91=E6=A8=A1=E5=9D=97?= =?UTF-8?q?=E9=87=8D=E6=9E=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- profiler/tinker/model/block_adapter_1_1.py | 151 ++++++++++++--------- profiler/tinker/model/block_adapter_1_2.py | 92 +++++++------ 2 files changed, 133 insertions(+), 110 deletions(-) diff --git a/profiler/tinker/model/block_adapter_1_1.py b/profiler/tinker/model/block_adapter_1_1.py index c6a5709dd..a2ced039c 100644 --- a/profiler/tinker/model/block_adapter_1_1.py +++ b/profiler/tinker/model/block_adapter_1_1.py @@ -56,85 +56,102 @@ class TransformerBlockAdapter(BlockAdapter): model_inference_params = input_data.model_inference_params model_rotary_pos_emb = input_data.model_rotary_pos_emb - #Checks. - if model_inference_params: + # Checks. 
+ if input_data.model_inference_params: if self.recompute_granularity is not None: raise ValueError('inference does not work with activation checkpointing') + model_hidden_states = self.preprocess_model_hidden_states(model_hidden_states) + model_hidden_states = self.make_viewless_tensor(model_hidden_states) + + forward_rng_context = self.get_forward_rng_context() + forward_fp8_context = self.get_forward_fp8_context() + + with forward_rng_context, forward_fp8_context: + if self.num_microbatches_previous != get_num_microbatches(): + self.microbatch_count = 0 + self.num_microbatches_previous = get_num_microbatches() + is_first_microbatch = self.microbatch_count % get_num_microbatches() == 0 + + # Forward pass. + if self.recompute_granularity == 'full': + model_hidden_states = self.checkpointed_forward(model_hidden_states, + model_attention_mask, + model_encoder_output, + model_enc_model_dec_attn_mask, + model_rotary_pos_emb, + is_first_microbatch) + else: + forward_pass_kwargs = self.prepare_forward_pass_kwargs(model_attention_mask, + model_encoder_output, + model_enc_model_dec_attn_mask, + model_inference_params, + model_retriever_input, + model_retriever_output, + model_retriever_attn_mask, + model_rotary_pos_emb, + is_first_microbatch) + + for index in range(self.num_layers): + layer = self._get_layer(index) + model_hidden_states = layer(model_hidden_states, model_attention_mask, **forward_pass_kwargs) + + if isinstance(model_hidden_states, tuple): + if len(model_hidden_states) != 2: + raise ValueError("model_hidden_states should be a tuple of length 2") + model_hidden_states, model_retriever_output = model_hidden_states + forward_pass_kwargs['model_retriever_output'] = model_retriever_output + + self.update_microbatch_count() + return model_hidden_states + + def preprocess_model_hidden_states(self, model_hidden_states, model_attention_mask): if not self.pre_process: model_hidden_states = self.input_tensor - if self.input_embeds_norm and self.pre_process: model_hidden_states = model_hidden_states * (self.hidden_size ** 0.5) + return model_hidden_states - model_hidden_states = core.utils.make_viewless_tensor( - model_hidden_states, - requires_grad=True, - keep_graph=True, - ) - - # RNG context. + def make_viewless_tensor(self, model_hidden_states): + return core.utils.make_viewless_tensor(model_hidden_states, requires_grad=True, keep_graph=True) + + def get_forward_rng_context(self): if self.sequence_parallel: - rng_context = tensor_parallel.get_cuda_rng_tracker().fork() + return tensor_parallel.get_cuda_rng_tracker().fork() + else: + return nullcontext() + + def get_forward_fp8_context(self): + if self.use_fp8: + return transformer_engine.pytorch.fp8_autocast(enabled=True, + fp8_recipe=self.fp8_recipe, + fp8_group=self.fp8_group) else: - rng_context = nullcontext() - - # Forward layers - with rng_context: - with transformer_engine.pytorch.fp8_autocast( - enabled=self.use_fp8, - fp8_recipe=self.fp8_recipe, - fp8_group=self.fp8_group - ) if self.use_fp8 else nullcontext(): - # Determine if the current iteration is first microbatch - if self.num_microbatches_in_previous_step != get_num_microbatches(): - self.microbatch_count = 0 # Reset count on new batch size rampup interval - self.num_microbatches_in_previous_step = get_num_microbatches() - is_first_microbatch = self.microbatch_count % get_num_microbatches() == 0 - - # Forward pass. 
-            if self.recompute_granularity == 'full':
-                model_hidden_states = self._checkpointed_forward(model_hidden_states,
-                                                                 model_attention_mask,
-                                                                 model_encoder_output,
-                                                                 model_enc_model_dec_attn_mask,
-                                                                 model_rotary_pos_emb,
-                                                                 is_first_microbatch)
-            else:
-                forward_pass_kwargs = {
-                    'model_encoder_output': model_encoder_output,
-                    'model_enc_model_dec_attn_mask': model_enc_model_dec_attn_mask,
-                    'model_inference_params': model_inference_params,
+        return nullcontext()
+
+    def prepare_forward_pass_kwargs(self, **kwargs):
+        forward_pass_kwargs = {
+            'model_encoder_output': kwargs['model_encoder_output'],
+            'model_enc_model_dec_attn_mask': kwargs['model_enc_model_dec_attn_mask'],
+            'model_inference_params': kwargs['model_inference_params'],
                 }
-                if self.transformer_impl == 'transformer_engine':
-                    forward_pass_kwargs['is_first_microbatch'] = is_first_microbatch
-                    forward_pass_kwargs['checkpoint_core_attention'] = self.checkpoint_core_attention
-                    if self.transformer_engine_v_0_10:
-                        forward_pass_kwargs['model_rotary_pos_emb'] = model_rotary_pos_emb
-                else:
-                    forward_pass_kwargs['model_rotary_pos_emb'] = model_rotary_pos_emb
-                    forward_pass_kwargs['model_retriever_input'] = model_retriever_input
-                    forward_pass_kwargs['model_retriever_output'] = model_retriever_output
-                    forward_pass_kwargs['model_retriever_attn_mask'] = model_retriever_attn_mask
-
-                for index in range(self.num_layers):
-                    layer = self._get_layer(index)
-
-                    model_hidden_states = layer(
-                        model_hidden_states,
-                        model_attention_mask,
-                        **forward_pass_kwargs)
-
-                    if isinstance(model_hidden_states, tuple):
-                        if len(model_hidden_states) != 2:
-                            raise ValueError("model_hidden_states should be a tuple of length 2")
-                        model_hidden_states, model_retriever_output = model_hidden_states
-                        forward_pass_kwargs["model_retriever_output"] = model_retriever_output
-
-        if torch.is_grad_enabled() and self.training:
-            self.microbatch_count += 1
-        return model_hidden_states
+        if self.transformer_impl == 'transformer_engine':
+            forward_pass_kwargs['is_first_microbatch'] = kwargs['is_first_microbatch']
+            forward_pass_kwargs['checkpoint_core_attention'] = self.checkpoint_core_attention
+            if self.transformer_engine_v_0_10:
+                forward_pass_kwargs['model_rotary_pos_emb'] = kwargs['model_rotary_pos_emb']
+        else:
+            forward_pass_kwargs['model_rotary_pos_emb'] = kwargs['model_rotary_pos_emb']
+            forward_pass_kwargs['model_retriever_input'] = kwargs['model_retriever_input']
+            forward_pass_kwargs['model_retriever_output'] = kwargs['model_retriever_output']
+            forward_pass_kwargs['model_retriever_attn_mask'] = kwargs['model_retriever_attn_mask']
+
+        return forward_pass_kwargs
+
+    def update_microbatch_count(self):
+        if torch.is_grad_enabled() and self.training:
+            self.microbatch_count += 1
 
     def get_output_name(self):
         return ['model_hidden_states']
diff --git a/profiler/tinker/model/block_adapter_1_2.py b/profiler/tinker/model/block_adapter_1_2.py
index f0957f957..4f8c43c33 100644
--- a/profiler/tinker/model/block_adapter_1_2.py
+++ b/profiler/tinker/model/block_adapter_1_2.py
@@ -131,49 +131,10 @@ class McoreTransformerBlockAdapter(BlockAdapter):
         model_inference_params = input_data.model_inference_params
         model_packed_seq_params = input_data.model_packed_seq_params
 
-        if not self.pre_process:
-            # See set_input_tensor()
-            model_hidden_states = self.input_tensor
-
-        if self.input_embeds_norm and self.pre_process:
-            normalizer = torch.tensor(self.hidden_size ** 0.5, dtype=model_hidden_states.dtype)
-            model_hidden_states = model_hidden_states * normalizer
-
-        model_hidden_states = make_viewless_tensor(
-            inp=model_hidden_states, requires_grad=True, keep_graph=True,
-        )
-
-        if self.config.sequence_parallel:
-            forward_rng_context = tensor_parallel.get_cuda_rng_tracker().fork()
-        else:
-            forward_rng_context = nullcontext()
-
-        if self.config.fp8:
-            import transformer_engine  # To keep out TE dependency when not training in fp8
-
-            if self.config.fp8 == "e4m3":
-                fp8_format = transformer_engine.common.recipe.Format.E4M3
-            elif self.config.fp8 == "hybrid":
-                fp8_format = transformer_engine.common.recipe.Format.HYBRID
-            else:
-                raise ValueError("E4M3 and HYBRID are the only supported FP8 formats.")
-
-            forward_fp8_recipe = transformer_engine.common.recipe.DelayedScaling(
-                margin=self.config.fp8_margin,
-                interval=self.config.fp8_interval,
-                fp8_format=fp8_format,
-                amax_compute_algo=self.config.fp8_amax_compute_algo,
-                amax_history_len=self.config.fp8_amax_history_len,
-                override_linear_precision=(False, False, not self.config.fp8_wgrad),
-            )
-            forward_fp8_group = None
-            if parallel_state.model_parallel_is_initialized():
-                forward_fp8_group = parallel_state.get_amax_reduction_group(with_context_parallel=True)
-            forward_fp8_context = transformer_engine.pytorch.fp8_autocast(
-                enabled=True, forward_fp8_recipe=forward_fp8_recipe, forward_fp8_group=forward_fp8_group
-            )
-        else:
-            forward_fp8_context = nullcontext()
+        model_hidden_states = self.preprocess_model_hidden_states(model_hidden_states)
+        model_hidden_states = self.make_viewless_tensor(model_hidden_states)
+        forward_rng_context = self.get_forward_rng_context()
+        forward_fp8_context = self.get_forward_fp8_context()
 
-        with forward_rng_context and forward_fp8_context:
+        with forward_rng_context, forward_fp8_context:
             # Forward pass.
@@ -207,6 +168,51 @@ class McoreTransformerBlockAdapter(BlockAdapter):
             model_hidden_states = self.group_prefetch_offload_commit_async(model_hidden_states)
 
         return model_hidden_states
+
+    def preprocess_model_hidden_states(self, model_hidden_states):
+        if not self.pre_process:
+            model_hidden_states = self.input_tensor
+
+        if self.input_embeds_norm and self.pre_process:
+            normalizer = torch.tensor(self.hidden_size ** 0.5, dtype=model_hidden_states.dtype)
+            model_hidden_states = model_hidden_states * normalizer
+        return model_hidden_states
+
+    def make_viewless_tensor(self, inp):
+        return make_viewless_tensor(inp=inp, requires_grad=True, keep_graph=True)
+
+    def get_forward_rng_context(self):
+        if self.config.sequence_parallel:
+            return tensor_parallel.get_cuda_rng_tracker().fork()
+        else:
+            return nullcontext()
+
+    def get_forward_fp8_context(self):
+        if self.config.fp8:
+            import transformer_engine
+
+            model_fp8_format = {
+                "e4m3": transformer_engine.common.recipe.Format.E4M3,
+                "hybrid": transformer_engine.common.recipe.Format.HYBRID,
+            }[self.config.fp8]
+
+            forward_fp8_recipe = transformer_engine.common.recipe.DelayedScaling(
+                margin=self.config.fp8_margin,
+                interval=self.config.fp8_interval,
+                fp8_format=model_fp8_format,
+                amax_compute_algo=self.config.fp8_amax_compute_algo,
+                amax_history_len=self.config.fp8_amax_history_len,
+                override_linear_precision=(False, False, not self.config.fp8_wgrad),
+            )
+            forward_fp8_group = None
+            if parallel_state.model_parallel_is_initialized():
+                forward_fp8_group = parallel_state.get_amax_reduction_group(with_context_parallel=True)
+
+            return transformer_engine.pytorch.fp8_autocast(
+                enabled=True, fp8_recipe=forward_fp8_recipe, fp8_group=forward_fp8_group
+            )
+        else:
+            return nullcontext()
 
     def get_output_name(self):
         return ['model_hidden_states']
-- 
Gitee
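
Note on the pattern used by the refactor above: get_forward_rng_context() / get_forward_fp8_context() each return a context manager (an RNG fork or an fp8 autocast region, falling back to nullcontext()), and the forward pass enters both together with a comma-separated with statement. Below is a minimal, self-contained sketch of that composition pattern, not part of the patches; fake_rng_fork and the print calls are hypothetical stand-ins for the real Megatron / Transformer Engine contexts.

    from contextlib import contextmanager, nullcontext

    @contextmanager
    def fake_rng_fork():
        # Hypothetical stand-in for tensor_parallel.get_cuda_rng_tracker().fork()
        print("enter RNG fork")
        yield
        print("exit RNG fork")

    def get_forward_rng_context(sequence_parallel):
        return fake_rng_fork() if sequence_parallel else nullcontext()

    def get_forward_fp8_context(use_fp8):
        # Hypothetical stand-in for transformer_engine.pytorch.fp8_autocast(...)
        return nullcontext()

    rng_context = get_forward_rng_context(sequence_parallel=True)
    fp8_context = get_forward_fp8_context(use_fp8=False)

    # Enter both context managers. Writing `with rng_context and fp8_context:`
    # would evaluate a boolean expression and enter only one of them, which is
    # why the comma form is used in the forward pass.
    with rng_context, fp8_context:
        print("layer loop would run here")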