diff --git a/profiler/tinker/model/__init__.py b/profiler/tinker/model/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/profiler/tinker/model/block_adapter.py b/profiler/tinker/model/block_adapter.py
new file mode 100644
index 0000000000000000000000000000000000000000..18dadeef1bf8e1d99c107ceb9970af37838d0805
--- /dev/null
+++ b/profiler/tinker/model/block_adapter.py
@@ -0,0 +1,29 @@
+from abc import ABC, abstractmethod
+
+
+class BlockAdapter(ABC):
+    def __init__(self, name, module):
+        self.name = name
+        self.module = module
+        self.output_name = self.get_output_name()
+
+    @abstractmethod
+    def origin_forward(self, *args, **kwargs):
+        """Original forward logic. If the original forward is truncated, the tensors needed by the follow-up logic must be determined manually and returned."""
+        pass
+
+    @abstractmethod
+    def get_output_name(self):
+        """Return the string names of the tensors returned by `origin_forward`, used to match them to the input tensors of the next sub-graph."""
+        pass
+
+    @abstractmethod
+    def weight_module(self):
+        """Return the modules in this block that occupy device memory."""
+        pass
+
+    @abstractmethod
+    def copy_method_module(self, other):
+        """Copy the methods and modules of `self.module` that `origin_forward` uses into `other`."""
+        pass
+ 
\ No newline at end of file
diff --git a/profiler/tinker/model/block_adapter_1_0.py b/profiler/tinker/model/block_adapter_1_0.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e5131c3084f259da22f4505aefb2d68fb37b631
--- /dev/null
+++ b/profiler/tinker/model/block_adapter_1_0.py
@@ -0,0 +1,332 @@
+# Copyright (c) 2025, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from contextlib import nullcontext
+from dataclasses import dataclass
+
+import torch
+
+from megatron import core
+from megatron.core import tensor_parallel
+from megatron.model.gpt_model import post_language_model_processing
+from megatron.training import get_args
+from tinker.megatron_patch.microbatches import get_num_microbatches
+from tinker.model.block_adapter import BlockAdapter
+
+
+@dataclass
+class TransformerForwardInput:
+    hidden_states: any = None
+    attention_mask: any = None
+    encoder_output: any = None
+    enc_dec_attn_mask: any = None
+    retriever_input: any = None
+    retriever_output: any = None
+    retriever_attn_mask: any = None
+    inference_params: any = None
+    rotary_pos_emb: any = None
+
+
+class TransformerBlockAdapter(BlockAdapter):
+    """megatron.model.transformer.ParallelTransformer"""
+
+    def __init__(self):
+        pass
+
+    @staticmethod
+    def origin_forward(self, input_data: TransformerForwardInput):
+        hidden_states = input_data.hidden_states
+        attention_mask = input_data.attention_mask
+        encoder_output = input_data.encoder_output
+        enc_dec_attn_mask = input_data.enc_dec_attn_mask
+        retriever_input = input_data.retriever_input
+        retriever_output = input_data.retriever_output
+        retriever_attn_mask = input_data.retriever_attn_mask
+        inference_params = input_data.inference_params
+        rotary_pos_emb = input_data.rotary_pos_emb
+
+        self.seq_length = get_args().seq_length
+
+        # Checks.
+        if input_data.inference_params:
+            if self.recompute_granularity is not None:
+                raise ValueError('inference does not work with activation checkpointing')
+
+        if not self.pre_process:
+            # See set_input_tensor()
+            hidden_states = self.input_tensor
+
+        # Viewless tensor.
+        # - We only need to create a viewless tensor in the case of micro batch
+        #   size (mbs) == 1, since in this case, 'hidden_states.transpose()'
+        #   above creates a view tensor, and '.contiguous()' is a pass-through.
+        #   For mbs >= 2, '.contiguous()' creates a new tensor, eliminating
+        #   the need to make it viewless.
+        #
+        #   However, we don't explicitly check mbs == 1 here because
+        #   make_viewless_tensor() has negligible overhead when its input
+        #   is already viewless.
+        #
+        # - For the 'else' case above, calling make_viewless_tensor() here is
+        #   likely redundant, since p2p_communication.py (likely originator)
+        #   already creates viewless tensors. That said, make_viewless_tensor()
+        #   is called here to be future-proof and corner-case-proof.
+        hidden_states = core.utils.make_viewless_tensor(
+            hidden_states,
+            requires_grad=True,
+            keep_graph=True,
+        )
+
+        # RNG context.
+        if self.sequence_parallel:
+            rng_context = tensor_parallel.get_cuda_rng_tracker().fork()
+        else:
+            rng_context = nullcontext()
+
+        # Forward layers.
+        with rng_context:
+            # The fp8_autocast context manager is a no-op when enabled=True
+            # The if...else serves to short circuit name resolution for fp8_autocast
+            with transformer_engine.pytorch.fp8_autocast(
+                    enabled=self.use_fp8,
+                    fp8_recipe=self.fp8_recipe,
+                    fp8_group=self.fp8_group
+            ) if self.use_fp8 else nullcontext():
+                # Determine if the current iteration is first microbatch
+                if self.num_microbatches_in_previous_step != get_num_microbatches():
+                    self.microbatch_count = 0  # Reset count on new batch size rampup interval
+                self.num_microbatches_in_previous_step = get_num_microbatches()
+                is_first_microbatch = self.microbatch_count % get_num_microbatches() == 0
+
+                # Forward pass.
+                if self.recompute_granularity == 'full':
+                    hidden_states = self._checkpointed_forward(hidden_states,
+                                                               attention_mask,
+                                                               encoder_output,
+                                                               enc_dec_attn_mask,
+                                                               rotary_pos_emb,
+                                                               is_first_microbatch)
+                else:
+                    forward_kwargs = {
+                        'encoder_output': encoder_output,
+                        'enc_dec_attn_mask': enc_dec_attn_mask,
+                        'inference_params': inference_params,
+                    }
+
+                    if self.transformer_impl == 'transformer_engine':
+                        forward_kwargs['is_first_microbatch'] = is_first_microbatch
+                        forward_kwargs['checkpoint_core_attention'] = self.checkpoint_core_attention
+                        if self.transformer_engine_v_0_10:
+                            forward_kwargs['rotary_pos_emb'] = rotary_pos_emb
+                    else:
+                        forward_kwargs['rotary_pos_emb'] = rotary_pos_emb
+                        forward_kwargs['retriever_input'] = retriever_input
+                        forward_kwargs['retriever_output'] = retriever_output
+                        forward_kwargs['retriever_attn_mask'] = retriever_attn_mask
+
+                    for index in range(self.num_layers):
+                        layer = self._get_layer(index)
+
+                        hidden_states = layer(
+                            hidden_states,
+                            attention_mask,
+                            **forward_kwargs)
+
+                        # First Retro decoder layer returns both hidden_states
+                        # and retriever_output. Make retriever_output available
+                        # to subsequent Retro layers.
+ if isinstance(hidden_states, tuple): + if len(hidden_states) != 2: + raise ValueError("hidden_states should be a tuple of length 2") + hidden_states, retriever_output = hidden_states + forward_kwargs['retriever_output'] = retriever_output + + # Skip counter update for eval and activating checkpointing + if torch.is_grad_enabled() and self.training: + self.microbatch_count += 1 + return hidden_states + + def get_output_name(self): + return ['hidden_states'] + + def copy_method_module(self,other): + other._get_layer = self.module._get_layer + other.layers = self.module.layers + other._checkpointed_forward = self.module._checkpointed_forward # 其实没必要,因为profile时不会用到,但可起到标注作用 + if other.transformer_impl == 'transformer_engine': + global transformer_engine + + def weight_module(self): + return self.module + + +@dataclass +class EmbeddingForwardInput: + enc_input_ids: any = None + enc_position_ids: any = None + enc_attn_mask: any = None + dec_input_ids: any = None + dec_position_ids: any = None + dec_attn_mask: any = None + retriever_input_ids: any = None + retriever_position_ids: any = None + retriever_attn_mask: any = None + enc_dec_attn_mask: any = None + tokentype_ids: any = None + inference_params: any = None + pooling_sequence_index: any = 0 + enc_hidden_states: any = None + output_enc_hidden: any = False + + +class EmbeddingAdapter(BlockAdapter): + """megatron.model.language_model.TransformerLanguageModel.forward""" + + def __init__(self): + pass + + @staticmethod + def origin_forward(self, input_data: EmbeddingForwardInput): + enc_input_ids = input_data.enc_input_ids + enc_position_ids = input_data.enc_position_ids + enc_attn_mask = input_data.enc_attn_mask + dec_input_ids = input_data.dec_input_ids + dec_position_ids = input_data.dec_position_ids + dec_attn_mask = input_data.dec_attn_mask + retriever_input_ids = input_data.retriever_input_ids + retriever_position_ids = input_data.retriever_position_ids + retriever_attn_mask = input_data.retriever_attn_mask + enc_dec_attn_mask = input_data.enc_dec_attn_mask + tokentype_ids = input_data.tokentype_ids + inference_params = input_data.inference_params + pooling_sequence_index = input_data.pooling_sequence_index + enc_hidden_states = input_data.enc_hidden_states + output_enc_hidden = input_data.output_enc_hidden + + # Encoder embedding. + if self.pre_process: + encoder_input = self.embedding(enc_input_ids, enc_position_ids, + tokentype_ids=tokentype_ids) + else: + encoder_input = None + + # Retriever embedding. + if self.add_retriever and self.pre_process: + retriever_input = self.embedding(retriever_input_ids, + retriever_position_ids, + tokentype_ids=tokentype_ids) + else: + retriever_input = None + + # Rotary positional embeddings. 
+ rotary_pos_emb = None + if self.use_rotary_position_embeddings: + if inference_params is not None: + rotary_pos_emb = \ + self.rotary_pos_emb(inference_params.max_sequence_length) + else: + rotary_pos_emb = self.rotary_pos_emb(self.seq_length) + + return encoder_input, rotary_pos_emb + + def get_output_name(self): + return ['encoder_input', 'rotary_pos_emb'] + + def copy_method_module(self, other): + other.embedding = self.module.embedding + if other.use_rotary_position_embeddings: + other.rotary_pos_emb = self.module.rotary_pos_emb + + def weight_module(self): + return self.module.embedding + + +class FinalNormAdapter(BlockAdapter): + """ParallelTransformer forward 中 self.final_norm 相关部分""" + + def __init__(self): + pass + + @staticmethod + def origin_forward(self, hidden_states): + hidden_states = self.final_norm(hidden_states) + + return hidden_states + + def get_output_name(self): + return ['hidden_states'] + + def copy_method_module(self, other): + other.final_norm = self.module.final_norm + + def weight_module(self): + return self.module.final_norm + + +@dataclass +class LossForwardInput: + input_ids: any = None + position_ids: any = None + attention_mask: any = None + lm_output: any = None + retriever_input_ids: any = None + retriever_position_ids: any = None + retriever_attn_mask: any = None + labels: any = None + tokentype_ids: any = None + inference_params: any = None + + +class LossAdapter(BlockAdapter): + """modellink.model.GPTModel""" + + def __init__(self): + pass + + @staticmethod + def origin_forward(self, input_data: LossForwardInput): + input_ids = input_data.input_ids + position_ids = input_data.position_ids + attention_mask = input_data.attention_mask + lm_output = input_data.lm_output + retriever_input_ids = input_data.retriever_input_ids + retriever_position_ids = input_data.retriever_position_ids + retriever_attn_mask = input_data.retriever_attn_mask + labels = input_data.labels + tokentype_ids = input_data.tokentype_ids + inference_params = input_data.inference_params + + output = post_language_model_processing( + lm_output, labels, + # 此处不进行判断 而直接输入该形状匹配的内容 + self.logit_weights, + self.parallel_output, + self.fp16_lm_cross_entropy) + + return output + + def get_output_name(self): + return ["output"] + + def copy_method_module(self, other): + # git 适配 not self.untie_embeddings_and_output_weights 情况 + if other.untie_embeddings_and_output_weights: + other.logit_weights = self.module.language_model.output_layer.weight + else: + other.logit_weights = self.module.language_model.embedding.word_embeddings.weight + + def weight_module(self): + return self.module.language_model.output_layer + \ No newline at end of file diff --git a/profiler/tinker/model/block_adapter_1_1.py b/profiler/tinker/model/block_adapter_1_1.py new file mode 100644 index 0000000000000000000000000000000000000000..28a3ce6fa0d56f6403cf1e077ee1746e236b6970 --- /dev/null +++ b/profiler/tinker/model/block_adapter_1_1.py @@ -0,0 +1,333 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from contextlib import nullcontext +from dataclasses import dataclass + +import torch + +from megatron import core +from megatron.core import tensor_parallel +from megatron.legacy.model.gpt_model import post_language_model_processing +from tinker.megatron_patch.microbatches import get_num_microbatches +from tinker.model.block_adapter import BlockAdapter + + +@dataclass +class TransformerForwardInput: + hidden_states: any = None + attention_mask: any = None + encoder_output: any = None + enc_dec_attn_mask: any = None + retriever_input: any = None + retriever_output: any = None + retriever_attn_mask: any = None + inference_params: any = None + rotary_pos_emb: any = None + + +class TransformerBlockAdapter(BlockAdapter): + """parallel_transformer_forward""" + + def __init__(self): + pass + + @staticmethod + def origin_forward(self, input_data: TransformerForwardInput): + hidden_states = input_data.hidden_states + attention_mask = input_data.attention_mask + encoder_output = input_data.encoder_output + enc_dec_attn_mask = input_data.enc_dec_attn_mask + retriever_input = input_data.retriever_input + retriever_output = input_data.retriever_output + retriever_attn_mask = input_data.retriever_attn_mask + inference_params = input_data.inference_params + rotary_pos_emb = input_data.rotary_pos_emb + + #Checks. + if inference_params: + if self.recompute_granularity is not None: + raise ValueError('inference does not work with activation checkpointing') + + if not self.pre_process: + # See set_input_tensor() + hidden_states = self.input_tensor + + # Viewless tensor + # - We only need to create a viewless tensor in the case of micro batch + # size (mbs) == 1, since in this case, 'hidden_states.transpose()' + # above creates a view tensor, and '.contiguous()' is a pass-through. + # For mbs >= 2, '.contiguous()' creates a new tensor, eliminating + # the need to make it viewless. + # + # However, we don't explicitly check mbs == 1 here because + # make_viewless_tensor() has negligible overhead when its input + # is already viewless. + # + # - For the 'else' case above, calling make_viewless_tensor() here is + # likely redundant, since p2p_communication.py (likely originator) + # already creates viewless tensors. That said, make_viewless_tensor() + # is called here to be future-proof and corner-case-proof. + if self.input_embeds_norm and self.pre_process: + hidden_states = hidden_states * (self.hidden_size ** 0.5) + + hidden_states = core.utils.make_viewless_tensor( + hidden_states, + requires_grad=True, + keep_graph=True, + ) + + # RNG context. + if self.sequence_parallel: + rng_context = tensor_parallel.get_cuda_rng_tracker().fork() + else: + rng_context = nullcontext() + + # Forward layers + with rng_context: + # The fp8_autocast context manager is a no-op when enabled=True + # The if...else serves to short circuit name resolution for fp8_autocast + with transformer_engine.pytorch.fp8_autocast( + enabled=self.use_fp8, + fp8_recipe=self.fp8_recipe, + fp8_group=self.fp8_group + ) if self.use_fp8 else nullcontext(): + # Determine if the current iteration is first microbatch + if self.num_microbatches_in_previous_step != get_num_microbatches(): + self.microbatch_count = 0 # Reset count on new batch size rampup interval + self.num_microbatches_in_previous_step = get_num_microbatches() + is_first_microbatch = self.microbatch_count % get_num_microbatches() == 0 + + # Forward pass. 
+ if self.recompute_granularity == 'full': + hidden_states = self._checkpointed_forward(hidden_states, + attention_mask, + encoder_output, + enc_dec_attn_mask, + rotary_pos_emb, + is_first_microbatch) + else: + forward_kwargs = { + 'encoder_output': encoder_output, + 'enc_dec_attn_mask': enc_dec_attn_mask, + 'inference_params': inference_params, + } + + if self.transformer_impl == 'transformer_engine': + forward_kwargs['is_first_microbatch'] = is_first_microbatch + forward_kwargs['checkpoint_core_attention'] = self.checkpoint_core_attention + if self.transformer_engine_v_0_10: + forward_kwargs['rotary_pos_emb'] = rotary_pos_emb + else: + forward_kwargs['rotary_pos_emb'] = rotary_pos_emb + forward_kwargs['retriever_input'] = retriever_input + forward_kwargs['retriever_output'] = retriever_output + forward_kwargs['retriever_attn_mask'] = retriever_attn_mask + + for index in range(self.num_layers): + layer = self._get_layer(index) + + hidden_states = layer( + hidden_states, + attention_mask, + **forward_kwargs) + + # First Retro decoder layer returns both hidden_states + # and retriever_output. Make retriever_output available + # to subsequence Retro layers. + if isinstance(hidden_states, tuple): + if len(hidden_states) != 2: + raise ValueError("hidden_states should be a tuple of length 2") + hidden_states, retriever_output = hidden_states + forward_kwargs["retriever_output"] = retriever_output + + # Skip counter update for eval and activation checkpointing + if torch.is_grad_enabled() and self.training: + self.microbatch_count += 1 + return hidden_states + + def get_output_name(self): + return ['hidden_states'] + + def copy_method_module(self, other): + other._get_layer = self.module._get_layer + other.layers = self.module.layers + other._checkpointed_forward = self.module._checkpointed_forward # 其实没必要,因为profile时不会用到,但可起到标注作用 + if other.transformer_impl == 'transformer_engine': + global transformer_engine + # noinsepction PyUnresolvedReferences + import transformer_engine + + def weight_module(self): + return self.module + + +@dataclass +class EmbeddingForwardInput: + enc_input_ids: any = None + enc_position_ids: any = None + enc_attn_mask: any = None + dec_input_ids: any = None + dec_position_ids: any = None + dec_attn_mask: any = None + retriever_input_ids: any = None + retriever_position_ids: any = None + retriever_attn_mask: any = None + enc_dec_attn_mask: any = None + tokentype_ids: any = None + inference_params: any = None + pooling_sequence_index: any = 0 + enc_hidden_states: any = None + output_enc_hidden: any = False + + +class EmbeddingAdapter(BlockAdapter): + """megatron.legacy.model.language_model.TransformerLanguageModel.forward""" + + def __init__(self): + pass + + @staticmethod + def origin_forward(self, input_data: EmbeddingForwardInput): + enc_input_ids = input_data.enc_input_ids + enc_position_ids = input_data.enc_position_ids + enc_attn_mask = input_data.enc_attn_mask + dec_input_ids = input_data.dec_input_ids + dec_position_ids = input_data.dec_position_ids + dec_attn_mask = input_data.dec_attn_mask + retriever_input_ids = input_data.retriever_input_ids + retriever_position_ids = input_data.retriever_position_ids + retriever_attn_mask = input_data.retriever_attn_mask + enc_dec_attn_mask = input_data.enc_dec_attn_mask + tokentype_ids = input_data.tokentype_ids + inference_params = input_data.inference_params + pooling_sequence_index = input_data.pooling_sequence_index + enc_hidden_states = input_data.enc_hidden_states + output_enc_hidden = input_data.output_enc_hidden + + 
# Encoder embedding. + if self.pre_process: + encoder_input = self.embedding(enc_input_ids, enc_position_ids, + tokentype_ids=tokentype_ids) + else: + encoder_input = None + + # Retriever embedding. + if self.add_retriever and self.pre_process: + retriever_input = self.embedding(retriever_input_ids, + retriever_position_ids, + tokentype_ids=tokentype_ids) + else: + retriever_input = None + + # Rotary positional embeddings + rotary_pos_emb = None + if self.use_rotary_position_embeddings: + if inference_params is not None: + rotary_pos_emb = \ + self.rotary_pos_emb(inference_params.max_sequence_length) + else: + rotary_pos_emb = self.rotary_pos_emb(self.seq_length) + + return encoder_input, rotary_pos_emb + + def get_output_name(self): + return ['encoder_input', 'rotary_pos_emb'] + + def copy_method_module(self, other): + other.embedding = self.module.embedding + if other.use_rotary_position_embeddings: + other.rotary_pos_emb = self.module.rotary_pos_emb + + def weight_module(self): + return self.module.embedding + + +class FinalNormAdapter(BlockAdapter): + """ParallelTransformer forward 中 self.final_norm 相关部分""" + + def __init__(self): + pass + + @staticmethod + def origin_forward(self, hidden_states): + hidden_states = self.final_norm(hidden_states) + + return hidden_states + + def get_output_name(self): + return ['hidden_states'] + + def copy_method_module(self, other): + other.final_norm = self.module.final_norm + + def weight_module(self): + return self.module.final_norm + + +@dataclass +class LossForwardInput: + input_ids: any = None + position_ids: any = None + attention_mask: any = None + lm_output: any = None + retriever_input_ids: any = None + retriever_position_ids: any = None + retriever_attn_mask: any = None + labels: any = None + tokentype_ids: any = None + inference_params: any = None + + +class LossAdapter(BlockAdapter): + """modellink.model.GPTModel""" + + def __init__(self): + pass + + @staticmethod + def origin_forward(self, input_data: LossForwardInput): + input_ids = input_data.input_ids + position_ids = input_data.position_ids + attention_mask = input_data.attention_mask + lm_output = input_data.lm_output + retriever_input_ids = input_data.retriever_input_ids + retriever_position_ids = input_data.retriever_position_ids + retriever_attn_mask = input_data.retriever_attn_mask + labels = input_data.labels + tokentype_ids = input_data.tokentype_ids + inference_params = input_data.inference_params + + output = post_language_model_processing( + lm_output, labels, + # 此处不进行判断 而直接输入该形状匹配的内容 + self.logit_weights, + self.parallel_output, + self.fp16_lm_cross_entropy) + + return output + + def get_output_name(self): + return ["output"] + + def copy_method_module(self, other): + # git 适配 not self.untie_embeddings_and_output_weights 情况 + if other.untie_embeddings_and_output_weights: + other.logit_weights = self.module.language_model.output_layer.weight + else: + other.logit_weights = self.module.language_model.embedding.word_embeddings.weight + + def weight_module(self): + return self.module.language_model.output_layer diff --git a/profiler/tinker/model/block_adapter_1_2.py b/profiler/tinker/model/block_adapter_1_2.py new file mode 100644 index 0000000000000000000000000000000000000000..38782daa341b66ae80b00d5ee45a416ddf4cfd21 --- /dev/null +++ b/profiler/tinker/model/block_adapter_1_2.py @@ -0,0 +1,340 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from contextlib import nullcontext
+from dataclasses import dataclass
+
+
+import torch
+from torch import Tensor
+
+from megatron.core import parallel_state, tensor_parallel, InferenceParams
+from megatron.core.packed_seq_params import PackedSeqParams
+from megatron.core.utils import make_viewless_tensor
+from megatron.training import get_args
+from tinker.model import block_adapter_1_1
+from tinker.model.block_adapter import BlockAdapter
+
+EmbeddingAdapter = block_adapter_1_1.EmbeddingAdapter
+TransformerBlockAdapter = block_adapter_1_1.TransformerBlockAdapter
+FinalNormAdapter = block_adapter_1_1.FinalNormAdapter
+LossAdapter = block_adapter_1_1.LossAdapter
+
+
+class DummyDecoder:
+    def __init__(self):
+        self.input_tensor = None
+
+
+@dataclass
+class McoreForwardInput:
+    input_ids: Tensor = None
+    position_ids: Tensor = None
+    attention_mask: Tensor = None
+    decoder_input: Tensor = None
+    labels: Tensor = None
+    inference_params: InferenceParams = None
+    packed_seq_params: PackedSeqParams = None
+    extra_block_kwargs: dict = None
+    tokentype_ids: Tensor = None
+
+
+class McoreEmbeddingAdapter(BlockAdapter):
+    """modellink.core.gpt_model_forward"""
+
+    def __init__(self):
+        pass
+
+    @staticmethod
+    def origin_forward(self, input_data: McoreForwardInput):
+        input_ids = input_data.input_ids
+        position_ids = input_data.position_ids
+        attention_mask = input_data.attention_mask
+        decoder_input = input_data.decoder_input
+        labels = input_data.labels
+        inference_params = input_data.inference_params
+        packed_seq_params = input_data.packed_seq_params
+        extra_block_kwargs = input_data.extra_block_kwargs
+        tokentype_ids = input_data.tokentype_ids
+        # modellink.core.gpt_model_forward
+        # If decoder_input is provided (not None), then input_ids and position_ids are ignored.
+        # Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input.
+        args = get_args()
+        # Decoder embedding.
+ if decoder_input is not None: + pass + elif self.pre_process: + decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids) + if args.scale_emb is not None: + decoder_input = decoder_input * args.scale_emb + else: + # intermediate stage of pipeline + # decoder will get hidden_states from encoder.input_tensor + decoder_input = None + + # Rotary positional embeddings (embedding is None for PP intermediate devices) + rotary_pos_emb = None + if self.position_embedding_type == 'rope': + rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len( + inference_params, self.decoder, decoder_input, self.config + ) + rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len) + + return decoder_input, rotary_pos_emb + + def get_output_name(self): + return ['decoder_input', 'rotary_pos_emb'] + + def copy_method_module(self, other): + # 此decoder用于避免前向报错,但其本身值并不被使用 + other.decoder = DummyDecoder() + other.embedding = self.module.embedding + other.rotary_pos_emb = self.module.rotary_pos_emb + + def weight_module(self): + return self.module.embedding, self.module.rotary_pos_emb + + +@dataclass +class McoreTransformerForwardInput: + hidden_states: Tensor = None + attention_mask: Tensor = None + context: Tensor = None + context_mask: Tensor = None + rotary_pos_emb: Tensor = None + inference_params: InferenceParams = None + packed_seq_params: PackedSeqParams = None + + +class McoreTransformerBlockAdapter(BlockAdapter): + """megatron.core.transformer.transformer_block.TransformerBlock""" + + def __init__(self): + pass + + @staticmethod + def origin_forward(self, input_data: McoreTransformerForwardInput): + """modellink.core.transformer_block_forward""" + hidden_states = input_data.hidden_states + attention_mask = input_data.attention_mask + context = input_data.context + context_mask = input_data.context_mask + rotary_pos_emb = input_data.rotary_pos_emb + inference_params = input_data.inference_params + packed_seq_params = input_data.packed_seq_params + + # hidden_states (float): [s, b, h] + # attention_mask (bool): [1, 1, s, s] + + if not self.pre_process: + # See set_input_tensor() + hidden_states = self.input_tensor + + # Viewless tensor. + # - We only need to create a viewless tensor in the case of micro batch + # size (mbs) == 1, since in this case, 'hidden_states.transpose()' + # above creates a view tensor, and '.contiguous()' is a pass-through. + # For mbs >= 2, '.contiguous()' creates a new tensor, eliminating + # the need to make it viewless. + # + # However, we don't explicitly check mbs == 1 here because + # make_viewless_tensor() has negligible overhead when its input + # is already viewless. + # + # - For the 'else' case above, calling make_viewless_tensor() here is + # likely redundant, since p2p_communication.py (likely originator) + # already creates viewless tensors. That said, make_viewless_tensor() + # is called here to be future-proof and corner-case-proof. 
+ if self.input_embeds_norm and self.pre_process: + normalizer = torch.tensor(self.hidden_size ** 0.5, dtype=hidden_states.dtype) + hidden_states = hidden_states * normalizer + + hidden_states = make_viewless_tensor( + inp=hidden_states, requires_grad=True, keep_graph=True, + ) + + if self.config.sequence_parallel: + rng_context = tensor_parallel.get_cuda_rng_tracker().fork() + else: + rng_context = nullcontext() + + if self.config.fp8: + import transformer_engine # To keep out TE dependency when not training in fp8 + + if self.config.fp8 == "e4m3": + fp8_format = transformer_engine.common.recipe.Format.E4M3 + elif self.config.fp8 == "hybrid": + fp8_format = transformer_engine.common.recipe.Format.HYBRID + else: + raise ValueError("E4M3 and HYBRID are the only supported FP8 formats.") + + fp8_recipe = transformer_engine.common.recipe.DelayedScaling( + margin=self.config.fp8_margin, + interval=self.config.fp8_interval, + fp8_format=fp8_format, + amax_compute_algo=self.config.fp8_amax_compute_algo, + amax_history_len=self.config.fp8_amax_history_len, + override_linear_precision=(False, False, not self.config.fp8_wgrad), + ) + fp8_group = None + if parallel_state.model_parallel_is_initialized(): + fp8_group = parallel_state.get_amax_reduction_group(with_context_parallel=True) + fp8_context = transformer_engine.pytorch.fp8_autocast( + enabled=True, fp8_recipe=fp8_recipe, fp8_group=fp8_group + ) + else: + fp8_context = nullcontext() + + with rng_context and fp8_context: + # Forward pass. + if self.config.recompute_granularity == 'full' and self.training: + hidden_states = self._checkpointed_forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + context=context, + context_mask=context_mask, + rotary_pos_emb=rotary_pos_emb, + packed_seq_params=packed_seq_params, + ) + else: + for layer in self.layers: + with self.offload_context: + hidden_states, context = layer( + hidden_states=hidden_states, + attention_mask=attention_mask, + context=context, + context_mask=context_mask, + rotary_pos_emb=rotary_pos_emb, + inference_params=inference_params, + packed_seq_params=packed_seq_params, + ) + + if ( + torch.is_grad_enabled() + and self.config.cpu_offloading + and self.group_prefetch_offload_commit_async is not None + ): + hidden_states = self.group_prefetch_offload_commit_async(hidden_states) + + return hidden_states + + def get_output_name(self): + return ['hidden_states'] + + def copy_method_module(self, other): + other.layers = self.module.layers + + def weight_module(self): + return self.module.layers + + +class McoreFinalNormAdapter(BlockAdapter): + + def __init__(self): + pass + + @staticmethod + def origin_forward(self, hidden_states: Tensor = None): + """modellink.core.transformer_block_forward""" + hidden_states = self.final_layernorm(hidden_states) + + return hidden_states + + def get_output_name(self): + return ['hidden_states'] + + def copy_method_module(self, other): + other.final_layernorm = self.module.final_layernorm + + def weight_module(self): + return self.module.final_layernorm + + +@dataclass +class McoreLossForwardInput: + input_ids: Tensor = None + label: any = None + hidden_states: any = None + position_ids: Tensor = None + attention_mask: Tensor = None + decoder_input: Tensor = None + labels: Tensor = None + inference_params: InferenceParams = None + packed_seq_params: PackedSeqParams = None + extra_block_kwargs: dict = None + tokentype_ids: any = None + + +class McoreLossAdapter(BlockAdapter): + """modellink.core.gpt_model_forward""" + + def __init__(self): 
+ pass + + @staticmethod + def origin_forward(self, input_data: McoreLossForwardInput): + """modellink.core.gpt_model_forward""" + input_ids = input_data.input_ids + label = input_data.label + hidden_states = input_data.hidden_states + position_ids = input_data.position_ids + attention_mask = input_data.attention_mask + decoder_input = input_data.decoder_input + labels = input_data.labels + inference_params = input_data.inference_params + packed_seq_params = input_data.packed_seq_params + extra_block_kwargs = input_data.extra_block_kwargs + tokentype_ids = input_data.tokentype_ids + + args = get_args() + # logits and loss + output_weight = None + if self.share_embeddings_and_output_weights: + output_weight = self.shared_embedding_or_output_weight() + + if args.dim_model_base is not None: + hidden_states = hidden_states / (args.hidden_size / args.dim_model_base) + logits, _ = self.output_layer(hidden_states, weight=output_weight) + # new add to scale logits + if args.output_multiplier_scale: + logits = logits * args.output_multiplier_scale + + if args.output_logit_softcapping: + logits = logits / args.output_logit_softcapping + logits = torch.tanh(logits) + logits = logits * args.output_logit_softcapping + + if labels is None: + # [s b h] => [b s h] + return logits.transpose(0, 1).contiguous() + + if args.is_instruction_dataset: + labels = labels[:, 1:].contiguous() + logits = logits[:-1, :, :].contiguous() + + loss = self.compute_language_model_loss(labels, logits) + + return loss + + def get_output_name(self): + return ['loss'] + + def copy_method_module(self, other): + other.output_layer = self.module.output_layer + other.compute_language_model_loss = self.module.compute_language_model_loss + other.shared_embedding_or_output_weight = self.module.shared_embedding_or_output_weight + + def weight_module(self): + return self.module.output_layer diff --git a/profiler/tinker/model/block_infos.py b/profiler/tinker/model/block_infos.py new file mode 100644 index 0000000000000000000000000000000000000000..041a15e46895d40892437fac25ab855439e219ad --- /dev/null +++ b/profiler/tinker/model/block_infos.py @@ -0,0 +1,48 @@ +import importlib +from dataclasses import dataclass +from typing import List, Optional, Type + +import torch + +from tinker.framework_adapter.modellink_adapter import get_block_adapter, ModelLinkAdapter +from tinker.model.block_adapter import BlockAdapter + +block_adapter = importlib.import_module(f'tinker.model.{get_block_adapter()}') + + +@dataclass +class BlockInfo: + name: str + module: torch.nn.Module + block_adapter: Optional[Type[BlockAdapter]] = None + + +def get_model_block_infos(adapter: ModelLinkAdapter) -> List[BlockInfo]: + """获取需要的profile的block列表 block粒度观测时即头处理 TransformerBlock 两个尾处理""" + block_infos = [] # type: List[BlockInfo] + args = adapter.get_args() + model = adapter.get_model() + if args.use_mcore_models: + # mcore GPTModel + block_infos.append(BlockInfo("mcore-embedding", model, block_adapter.McoreEmbeddingAdapter)) + block_infos.append( + BlockInfo("mcore-transformer-block", model.decoder, block_adapter.McoreTransformerBlockAdapter)) + block_infos.append(BlockInfo("mcore-final-norm", model.decoder, block_adapter.McoreFinalNormAdapter)) + block_infos.append(BlockInfo("mcore-post-process", model, block_adapter.McoreLossAdapter)) + + else: + # legacy GPTModel + encoder = model.language_model.encoder + + # model.language_model.pre_process + block_infos.append(BlockInfo("embedding", model.language_model, block_adapter.EmbeddingAdapter)) + + 
    block_infos.append(BlockInfo("transformer-block", encoder, block_adapter.TransformerBlockAdapter))
+
+    # encoder.post_norm and encoder.post_process
+    block_infos.append(BlockInfo("final-norm", encoder, block_adapter.FinalNormAdapter))
+
+    # model.post_process
+    block_infos.append(BlockInfo("post-process", model, block_adapter.LossAdapter))
+
+    return block_infos
diff --git a/profiler/tinker/model/observation_block.py b/profiler/tinker/model/observation_block.py
new file mode 100644
index 0000000000000000000000000000000000000000..c948c797257c397a767fd5d2489f2332e59de9b2
--- /dev/null
+++ b/profiler/tinker/model/observation_block.py
@@ -0,0 +1,46 @@
+import torch
+
+from tinker.model.block_adapter import BlockAdapter
+from tinker.model.block_infos import BlockInfo
+
+
+class ObservationBlock(torch.nn.Module):
+    def __init__(self, block_adapter: BlockAdapter):
+        super().__init__()
+        self.block_adapter = block_adapter
+        self.name = self.block_adapter.name
+        # for profiling
+        self.weight_size = 0
+        # Pull the adapter module's attributes, plus the methods and submodules needed by origin_forward, into self
+        self.extract_init_attr(self.block_adapter.module)
+        self.block_adapter.copy_method_module(self)
+        weight_module = self.block_adapter.weight_module()
+        if not isinstance(weight_module, tuple):
+            weight_module = (weight_module,)
+        self.set_weight_size(*weight_module)
+
+    def extract_init_attr(self, module):
+        """Copy over all non-module attributes so the original forward code can be reused."""
+        for attr in module.__dict__:
+            if not attr.startswith('_') and not isinstance(getattr(module, attr), torch.nn.Module):
+                setattr(self, attr, getattr(module, attr))
+
+    def set_weight_size(self, *modules):
+        """Count the trainable weight parameters of the given modules."""
+        for module in modules:
+            self.weight_size += sum(p.numel() for p in module.parameters() if p.requires_grad)
+
+    def forward(self, input_tensors):
+        outputs = self.block_adapter.origin_forward(self=self, **input_tensors)
+        if not isinstance(outputs, tuple):
+            outputs = (outputs,)
+        return {k: v for k, v in zip(self.block_adapter.output_name, outputs)}
+
+
+def gen_block(block_info: BlockInfo):
+    """Build a model block instance from block_info."""
+    if block_info.block_adapter is not None:
+        return ObservationBlock(block_info.block_adapter(block_info.name, block_info.module))
+    else:
+        raise RuntimeError(f"operator {block_info.name} is not supported.")
+
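Reviewer note: a minimal usage sketch (not part of the patch) of how these pieces are intended to compose on the legacy (non-mcore) path, assuming the 1.1 block adapter is the one selected by get_block_adapter(). The function name run_legacy_blocks and the tokens/position_ids/attention_mask inputs are illustrative assumptions; adapter is an already-initialized ModelLinkAdapter, and everything else comes from the modules added in this diff.

    # Illustrative sketch only; input preparation in the real profiler may differ.
    from tinker.framework_adapter.modellink_adapter import ModelLinkAdapter
    from tinker.model.block_infos import get_model_block_infos
    from tinker.model.observation_block import gen_block
    from tinker.model.block_adapter_1_1 import EmbeddingForwardInput, TransformerForwardInput

    def run_legacy_blocks(adapter: ModelLinkAdapter, tokens, position_ids, attention_mask):
        # Blocks come back in execution order: embedding, transformer-block, final-norm, post-process.
        embedding, transformer, final_norm, _ = [gen_block(info) for info in get_model_block_infos(adapter)]

        # Each ObservationBlock takes a dict keyed by origin_forward's parameter names and
        # returns a dict keyed by get_output_name(), so one block's outputs can be matched
        # to the next block's input dataclass by tensor name.
        emb_out = embedding({'input_data': EmbeddingForwardInput(
            enc_input_ids=tokens,
            enc_position_ids=position_ids,
            enc_attn_mask=attention_mask)})

        block_out = transformer({'input_data': TransformerForwardInput(
            hidden_states=emb_out['encoder_input'],
            attention_mask=attention_mask,
            rotary_pos_emb=emb_out['rotary_pos_emb'])})

        return final_norm({'hidden_states': block_out['hidden_states']})

The sketch only illustrates the name-matching contract between get_output_name() and the next block's input; the profiler itself decides which block to time, what inputs to feed, and how weight_size is consumed.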