diff --git a/profiler/tinker/model/__init__.py b/profiler/tinker/model/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/profiler/tinker/model/block_adapter.py b/profiler/tinker/model/block_adapter.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e3b4d767188ae0a3bfd304c28d64ba16893ebaa
--- /dev/null
+++ b/profiler/tinker/model/block_adapter.py
@@ -0,0 +1,29 @@
+from abc import ABC, abstractmethod
+
+
+class BlockAdapter(ABC):
+    def __init__(self, name, module):
+        self.name = name
+        self.module = module
+        self.output_name = self.get_output_name()
+
+    @abstractmethod
+    def origin_forward(self, *args, **kwargs):
+        """Original forward logic; if it truncates the original forward, the tensors needed by the downstream logic must be picked out manually and returned."""
+        pass
+
+    @abstractmethod
+    def get_output_name(self):
+        """String names of the tensors returned by `origin_forward`, matched against the input tensors of the next sub-graph."""
+        return None
+
+    @abstractmethod
+    def weight_module(self):
+        """Return the module(s) in this block that occupy device memory."""
+        pass
+
+    @abstractmethod
+    def copy_method_module(self, other):
+        """Copy the methods and modules of `self.module` that `origin_forward` uses into `other`."""
+        pass
+
\ No newline at end of file
diff --git a/profiler/tinker/model/block_adapter_1_0.py b/profiler/tinker/model/block_adapter_1_0.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5b7dc4f3f66c80f273f71ea53a68dc6e37e4a3d
--- /dev/null
+++ b/profiler/tinker/model/block_adapter_1_0.py
@@ -0,0 +1,327 @@
+# Copyright (c) 2025, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from contextlib import nullcontext
+from dataclasses import dataclass
+from typing import Any
+
+import torch
+
+from megatron import core
+from megatron.core import tensor_parallel
+from megatron.model.gpt.model import post_language_model_processing
+from megatron.training import get_args
+from tinker.megatron_patch.microbatches import get_num_microbatches
+from tinker.model.block_adapter import BlockAdapter
+
+
+@dataclass
+class TransformerForwardInput:
+    model_hidden_states: Any = None
+    model_attention_mask: Any = None
+    model_encoder_output: Any = None
+    model_enc_model_dec_attn_mask: Any = None
+    model_retriever_input: Any = None
+    model_retriever_output: Any = None
+    model_retriever_attn_mask: Any = None
+    model_inference_params: Any = None
+    model_rotary_pos_emb: Any = None
+
+
+class TransformerBlockAdapter(BlockAdapter):
+    """megatron.model.transformer.ParallelTransformer"""
+
+    @staticmethod
+    def origin_forward(self, input_data: TransformerForwardInput):
+        model_hidden_states = input_data.model_hidden_states
+        model_attention_mask = input_data.model_attention_mask
+        model_encoder_output = input_data.model_encoder_output
+        model_enc_model_dec_attn_mask = input_data.model_enc_model_dec_attn_mask
+        model_retriever_input = input_data.model_retriever_input
+        model_retriever_output = input_data.model_retriever_output
+        model_retriever_attn_mask = input_data.model_retriever_attn_mask
+        model_inference_params = input_data.model_inference_params
+        model_rotary_pos_emb = input_data.model_rotary_pos_emb
+
+        self.seq_length = get_args().seq_length
+
+        # Checks.
+        if model_inference_params:
+            if self.recompute_granularity is not None:
+                raise ValueError('inference does not work with activation checkpointing')
+
+        model_hidden_states = self.preprocess_model_hidden_states(model_hidden_states)
+        model_hidden_states = self.make_viewless_tensor(model_hidden_states)
+
+        forward_rng_context = self.get_forward_rng_context()
+        forward_fp8_context = self.get_forward_fp8_context()
+
+        with forward_rng_context, forward_fp8_context:
+            if self.num_microbatches_previous != get_num_microbatches():
+                self.microbatch_count = 0
+                self.num_microbatches_previous = get_num_microbatches()
+            is_first_microbatch = self.microbatch_count % get_num_microbatches() == 0
+
+            # Forward pass.
+            if self.recompute_granularity == 'full':
+                model_hidden_states = self._checkpointed_forward(model_hidden_states,
+                                                                 model_attention_mask,
+                                                                 model_encoder_output,
+                                                                 model_enc_model_dec_attn_mask,
+                                                                 model_rotary_pos_emb,
+                                                                 is_first_microbatch)
+            else:
+                forward_pass_kwargs = self.prepare_forward_pass_kwargs(
+                    model_encoder_output=model_encoder_output,
+                    model_enc_model_dec_attn_mask=model_enc_model_dec_attn_mask,
+                    model_inference_params=model_inference_params,
+                    model_retriever_input=model_retriever_input,
+                    model_retriever_output=model_retriever_output,
+                    model_retriever_attn_mask=model_retriever_attn_mask,
+                    model_rotary_pos_emb=model_rotary_pos_emb,
+                    is_first_microbatch=is_first_microbatch)
+
+                for index in range(self.num_layers):
+                    layer = self._get_layer(index)
+                    model_hidden_states = layer(model_hidden_states, model_attention_mask, **forward_pass_kwargs)
+
+                    if isinstance(model_hidden_states, tuple):
+                        if len(model_hidden_states) != 2:
+                            raise ValueError("model_hidden_states should be a tuple of length 2")
+                        model_hidden_states, model_retriever_output = model_hidden_states
+                        forward_pass_kwargs['model_retriever_output'] = model_retriever_output
+
+            self.update_microbatch_count()
+
+        return model_hidden_states
+
+    def preprocess_model_hidden_states(self, model_hidden_states):
+        if not self.pre_process:
+            model_hidden_states = self.input_tensor
+        return model_hidden_states
+
+    def make_viewless_tensor(self, model_hidden_states):
+        return core.utils.make_viewless_tensor(model_hidden_states, requires_grad=True, keep_graph=True)
+
+    def get_forward_rng_context(self):
+        if self.sequence_parallel:
+            return tensor_parallel.get_cuda_rng_tracker().fork()
+        else:
+            return nullcontext()
+
+    def get_forward_fp8_context(self):
+        if self.use_fp8:
+            return transformer_engine.pytorch.fp8_autocast(enabled=True,
+                                                           fp8_recipe=self.fp8_recipe,
+                                                           fp8_group=self.fp8_group)
+        else:
+            return nullcontext()
+
+    def prepare_forward_pass_kwargs(self, **kwargs):
+        forward_pass_kwargs = {
+            'model_encoder_output': kwargs['model_encoder_output'],
+            'model_enc_model_dec_attn_mask': kwargs['model_enc_model_dec_attn_mask'],
+            'model_inference_params': kwargs['model_inference_params'],
+        }
+
+        if self.transformer_impl == 'transformer_engine':
+            forward_pass_kwargs['is_first_microbatch'] = kwargs['is_first_microbatch']
+            forward_pass_kwargs['checkpoint_core_attention'] = self.checkpoint_core_attention
+            if self.transformer_engine_v_0_10:
+                forward_pass_kwargs['model_rotary_pos_emb'] = kwargs['model_rotary_pos_emb']
+        else:
+            forward_pass_kwargs['model_rotary_pos_emb'] = kwargs['model_rotary_pos_emb']
+            forward_pass_kwargs['model_retriever_input'] = kwargs['model_retriever_input']
+            forward_pass_kwargs['model_retriever_output'] = kwargs['model_retriever_output']
+            forward_pass_kwargs['model_retriever_attn_mask'] = kwargs['model_retriever_attn_mask']
+
+        return forward_pass_kwargs
+
+    def update_microbatch_count(self):
+        if torch.is_grad_enabled() and self.training:
+            self.microbatch_count += 1
+
+    def get_output_name(self):
+        return ['model_hidden_states']
+
+    def copy_method_module(self, other):
+        other._get_layer = self.module._get_layer
+        other.layers = self.module.layers
+        # Not strictly needed, since profiling never calls it, but kept as documentation.
+        other._checkpointed_forward = self.module._checkpointed_forward
+        if other.transformer_impl == 'transformer_engine':
+            global transformer_engine
+            import transformer_engine
+
+    def weight_module(self):
+        return self.module
+
+
+@dataclass
+class EmbeddingForwardInput:
+    model_enc_input_ids: Any = None
+    model_enc_position_ids: Any = None
+    model_enc_attn_mask: Any = None
+    model_dec_input_ids: Any = None
+    model_dec_position_ids: Any = None
+    model_dec_attn_mask: Any = None
+    model_retriever_input_ids: Any = None
+    model_retriever_position_ids: Any = None
+    model_retriever_attn_mask: Any = None
+    model_enc_model_dec_attn_mask: Any = None
+    model_tokentype_ids: Any = None
+    model_inference_params: Any = None
+    model_pooling_sequence_index: Any = 0
+    model_enc_model_hidden_states: Any = None
+    model_output_enc_hidden: Any = False
+
+
+class EmbeddingAdapter(BlockAdapter):
+    """megatron.model.language_model.TransformerLanguageModel.forward"""
+
+    @staticmethod
+    def origin_forward(self, input_data: EmbeddingForwardInput):
+        model_enc_input_ids = input_data.model_enc_input_ids
+        model_enc_position_ids = input_data.model_enc_position_ids
+        model_enc_attn_mask = input_data.model_enc_attn_mask
+        model_dec_input_ids = input_data.model_dec_input_ids
+        model_dec_position_ids = input_data.model_dec_position_ids
+        model_dec_attn_mask = input_data.model_dec_attn_mask
+        model_retriever_input_ids = input_data.model_retriever_input_ids
+        model_retriever_position_ids = input_data.model_retriever_position_ids
+        model_retriever_attn_mask = input_data.model_retriever_attn_mask
+        model_enc_model_dec_attn_mask = input_data.model_enc_model_dec_attn_mask
+        model_tokentype_ids = input_data.model_tokentype_ids
+        model_inference_params = input_data.model_inference_params
+        model_pooling_sequence_index = input_data.model_pooling_sequence_index
+        model_enc_model_hidden_states = input_data.model_enc_model_hidden_states
+        model_output_enc_hidden = input_data.model_output_enc_hidden
+
+        # Encoder embedding.
+        if self.pre_process:
+            encoder_input = self.embedding(model_enc_input_ids, model_enc_position_ids,
+                                           model_tokentype_ids=model_tokentype_ids)
+        else:
+            encoder_input = None
+
+        # Retriever embedding.
+        if self.add_retriever and self.pre_process:
+            model_retriever_input = self.embedding(model_retriever_input_ids,
+                                                   model_retriever_position_ids,
+                                                   model_tokentype_ids=model_tokentype_ids)
+        else:
+            model_retriever_input = None
+
+        # Rotary positional embeddings.
+        model_rotary_pos_emb = None
+        if self.use_rotary_position_embeddings:
+            if model_inference_params is not None:
+                model_rotary_pos_emb = self.model_rotary_pos_emb(model_inference_params.max_sequence_length)
+            else:
+                model_rotary_pos_emb = self.model_rotary_pos_emb(self.seq_length)
+
+        return encoder_input, model_rotary_pos_emb
+
+    def get_output_name(self):
+        return ['encoder_input', 'model_rotary_pos_emb']
+
+    def copy_method_module(self, other):
+        other.embedding = self.module.embedding
+        if other.use_rotary_position_embeddings:
+            other.model_rotary_pos_emb = self.module.model_rotary_pos_emb
+
+    def weight_module(self):
+        return self.module.embedding
+
+
+class FinalNormAdapter(BlockAdapter):
+    """The self.final_norm part of the ParallelTransformer forward."""
+
+    @staticmethod
+    def origin_forward(self, model_hidden_states):
+        model_hidden_states = self.final_norm(model_hidden_states)
+
+        return model_hidden_states
+
+    def get_output_name(self):
+        return ['model_hidden_states']
+
+    def copy_method_module(self, other):
+        other.final_norm = self.module.final_norm
+
+    def weight_module(self):
+        return self.module.final_norm
+
+
+@dataclass
+class LossForwardInput:
+    model_input_ids: Any = None
+    model_position_ids: Any = None
+    model_attention_mask: Any = None
+    model_lm_output: Any = None
+    model_retriever_input_ids: Any = None
+    model_retriever_position_ids: Any = None
+    model_retriever_attn_mask: Any = None
+    model_labels: Any = None
+    model_tokentype_ids: Any = None
+    model_inference_params: Any = None
+
+
+class LossAdapter(BlockAdapter):
+    """modellink.model.GPTModel"""
+
+    @staticmethod
+    def origin_forward(self, input_data: LossForwardInput):
+        model_input_ids = input_data.model_input_ids
+        model_position_ids = input_data.model_position_ids
+        model_attention_mask = input_data.model_attention_mask
+        model_lm_output = input_data.model_lm_output
+        model_retriever_input_ids = input_data.model_retriever_input_ids
+        model_retriever_position_ids = input_data.model_retriever_position_ids
+        model_retriever_attn_mask = input_data.model_retriever_attn_mask
+        model_labels = input_data.model_labels
+        model_tokentype_ids = input_data.model_tokentype_ids
+        model_inference_params = input_data.model_inference_params
+
+        output = post_language_model_processing(
+            model_lm_output, model_labels,
+            # No branching here; pass the shape-matched weights directly.
+            self.logit_weights,
+            self.parallel_output,
+            self.fp16_lm_cross_entropy)
+
+        return output
+
+    def get_output_name(self):
+        return ["output"]
+
+    def copy_method_module(self, other):
+        # Also handles the `not self.untie_embeddings_and_output_weights` case.
+        if other.untie_embeddings_and_output_weights:
+            other.logit_weights = self.module.language_model.output_layer.weight
+        else:
+            other.logit_weights = self.module.language_model.embedding.word_embeddings.weight
+
+    def weight_module(self):
+        return self.module.language_model.output_layer
+
\ No newline at end of file
diff --git a/profiler/tinker/model/block_infos.py b/profiler/tinker/model/block_infos.py
new file mode 100644
index 0000000000000000000000000000000000000000..041a15e46895d40892437fac25ab855439e219ad
--- /dev/null
+++ b/profiler/tinker/model/block_infos.py
@@ -0,0 +1,48 @@
+import importlib
+from dataclasses import dataclass
+from typing import List, Optional, Type
+
+import torch
+
+from tinker.framework_adapter.modellink_adapter import get_block_adapter, ModelLinkAdapter
+from tinker.model.block_adapter import BlockAdapter
+
+block_adapter = importlib.import_module(f'tinker.model.{get_block_adapter()}')
+
+
+@dataclass
+class BlockInfo:
+    name: str
+    module: torch.nn.Module
+    block_adapter: Optional[Type[BlockAdapter]] = None
+
+
+def get_model_block_infos(adapter: ModelLinkAdapter) -> List[BlockInfo]:
+    """Return the list of blocks to profile. At block granularity these are the
+    head processing, the TransformerBlock, and two tail-processing blocks."""
+    block_infos: List[BlockInfo] = []
+    args = adapter.get_args()
+    model = adapter.get_model()
+    if args.use_mcore_models:
+        # mcore GPTModel
+        block_infos.append(BlockInfo("mcore-embedding", model, block_adapter.McoreEmbeddingAdapter))
+        block_infos.append(
+            BlockInfo("mcore-transformer-block", model.decoder, block_adapter.McoreTransformerBlockAdapter))
+        block_infos.append(BlockInfo("mcore-final-norm", model.decoder, block_adapter.McoreFinalNormAdapter))
+        block_infos.append(BlockInfo("mcore-post-process", model, block_adapter.McoreLossAdapter))
+    else:
+        # legacy GPTModel
+        encoder = model.language_model.encoder
+
+        # model.language_model.pre_process
+        block_infos.append(BlockInfo("embedding", model.language_model, block_adapter.EmbeddingAdapter))
+
+        block_infos.append(BlockInfo("transformer-block", encoder, block_adapter.TransformerBlockAdapter))
+
+        # encoder.post_norm and encoder.post_process
+        block_infos.append(BlockInfo("final-norm", encoder, block_adapter.FinalNormAdapter))
+
+        # model.post_process
+        block_infos.append(BlockInfo("post-process", model, block_adapter.LossAdapter))
+
+    return block_infos
diff --git a/profiler/tinker/model/observation_block.py b/profiler/tinker/model/observation_block.py
new file mode 100644
index 0000000000000000000000000000000000000000..c948c797257c397a767fd5d2489f2332e59de9b2
--- /dev/null
+++ b/profiler/tinker/model/observation_block.py
@@ -0,0 +1,46 @@
+import torch
+
+from tinker.model.block_adapter import BlockAdapter
+from tinker.model.block_infos import BlockInfo
+
+
+class ObservationBlock(torch.nn.Module):
+    def __init__(self, block_adapter: BlockAdapter):
+        super().__init__()
+        self.block_adapter = block_adapter
+        self.name = self.block_adapter.name
+        # for profiling
+        self.weight_size = 0
+        # Pull the adapter module's attributes, plus selected methods and modules, into self.
+        self.extract_init_attr(self.block_adapter.module)
+        self.block_adapter.copy_method_module(self)
+        weight_module = self.block_adapter.weight_module()
+        if not isinstance(weight_module, tuple):
+            weight_module = (weight_module,)
+        self.set_weight_size(*weight_module)
+
+    def extract_init_attr(self, module):
+        """Copy over all non-module attributes so the original forward code can be reused."""
+        for attr, value in module.__dict__.items():
+            if not attr.startswith('_') and not isinstance(value, torch.nn.Module):
+                setattr(self, attr, value)
+
+    def set_weight_size(self, *modules):
+        """Compute the trainable weight-parameter count of the given modules."""
+        for module in modules:
+            self.weight_size += sum(p.numel() for p in module.parameters() if p.requires_grad)
+
+    def forward(self, input_tensors):
+        outputs = self.block_adapter.origin_forward(self=self, **input_tensors)
+        if not isinstance(outputs, tuple):
+            outputs = (outputs,)
+        return dict(zip(self.block_adapter.output_name, outputs))
+
+
+def gen_block(block_info: BlockInfo):
+    """Instantiate the observation block described by block_info."""
+    if block_info.block_adapter is not None:
+        return ObservationBlock(block_info.block_adapter(block_info.name, block_info.module))
+    else:
+        raise RuntimeError(f"operator {block_info.name} is not supported.")
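
For context, a minimal sketch of how a profiling loop might drive these pieces end to end. Only `get_model_block_infos`, `gen_block`, the dict-in/dict-out `forward` contract, and `weight_size` come from the patch above; the `ModelLinkAdapter()` construction and the `input_data` payload are assumptions that would come from the surrounding profiler.

# Hypothetical driver, not part of this patch: profile each block in isolation.
from tinker.framework_adapter.modellink_adapter import ModelLinkAdapter
from tinker.model.block_infos import get_model_block_infos
from tinker.model.observation_block import gen_block

adapter = ModelLinkAdapter()                # assumption: wired to the loaded model/args
for info in get_model_block_infos(adapter):
    block = gen_block(info)                 # ObservationBlock wrapping the adapter
    inputs = {'input_data': ...}            # placeholder: the *ForwardInput dataclass expected by origin_forward
    outputs = block(inputs)                 # dict keyed by the adapter's output_name
    print(info.name, block.weight_size)     # trainable-parameter count per block
    # outputs can then be matched by name to the next sub-graph's input tensors.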