From 149dd8070062df08fa1638d9fe9a928db313075a Mon Sep 17 00:00:00 2001 From: avocadovo Date: Thu, 2 Jan 2025 10:25:47 +0800 Subject: [PATCH 1/2] =?UTF-8?q?=E3=80=90Tinker=E3=80=91=E8=87=AA=E5=8A=A8?= =?UTF-8?q?=E5=B9=B6=E8=A1=8C=E4=BC=98=E5=8C=96=E5=B7=A5=E5=85=B7=20?= =?UTF-8?q?=E6=A8=A1=E5=9D=97=20model?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- profiler/tinker/model/__init__.py | 0 profiler/tinker/model/block_adapter.py | 29 +++ profiler/tinker/model/block_adapter_1_0.py | 239 ++++++++++++++++++ profiler/tinker/model/block_adapter_1_1.py | 242 ++++++++++++++++++ profiler/tinker/model/block_adapter_1_2.py | 270 +++++++++++++++++++++ profiler/tinker/model/block_infos.py | 48 ++++ profiler/tinker/model/observation_block.py | 46 ++++ 7 files changed, 874 insertions(+) create mode 100644 profiler/tinker/model/__init__.py create mode 100644 profiler/tinker/model/block_adapter.py create mode 100644 profiler/tinker/model/block_adapter_1_0.py create mode 100644 profiler/tinker/model/block_adapter_1_1.py create mode 100644 profiler/tinker/model/block_adapter_1_2.py create mode 100644 profiler/tinker/model/block_infos.py create mode 100644 profiler/tinker/model/observation_block.py diff --git a/profiler/tinker/model/__init__.py b/profiler/tinker/model/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/profiler/tinker/model/block_adapter.py b/profiler/tinker/model/block_adapter.py new file mode 100644 index 0000000000..18dadeef1b --- /dev/null +++ b/profiler/tinker/model/block_adapter.py @@ -0,0 +1,29 @@ +from abc import ABC, abstractmethod + + +class BlockAdapter(ABC): + def __init__(self, name, module): + self.name = name + self.module = module + self.output_name = self.get_output_name() + + @abstractmethod + def origin_forward(self, *args, **kwargs): + """原生前向逻辑,其中若截断了原有前向,则需要手动确定后续逻辑要用的张量,作为返回值返回""" + pass + + @abstractmethod + def get_output_name(self): + """指定`origin_forward`返回张量的字符串名称,供匹配到下个子图的输入张量""" + pass + + @abstractmethod + def weight_module(self): + """返回该block中会占用显存的module""" + pass + + @abstractmethod + def copy_method_module(self, other): + """将`self.module`中被`origin_forward`使用的方法和module复制到other中""" + pass + \ No newline at end of file diff --git a/profiler/tinker/model/block_adapter_1_0.py b/profiler/tinker/model/block_adapter_1_0.py new file mode 100644 index 0000000000..1b27313d88 --- /dev/null +++ b/profiler/tinker/model/block_adapter_1_0.py @@ -0,0 +1,239 @@ +from contextlib import nullcontext + +import torch + +from megatron import core +from megatron.core import tensor_parallel +from megatron.model.gpt.model import post_language_model_processing +from megatron.training import get_args +from tinker.megatron_patch.microbatches import get_num_microbatches +from tinker.model.block_adapter import BlockAdapter + + +class TransformerBlockAdapter(BlockAdapter): + """megatron.model.transformer.ParallelTransformer""" + + def get_output_name(self): + return ['hidden_states'] + + def copy_method_module(self,other): + other._get_layer = self.module._get_layer + other.layers = self.module.layers + other._checkpointed_forward = self.module._checkpointed_forward # 其实没必要,因为profile时不会用到,但可起到标注作用 + if other.transformer_impl == 'transformer_engine': + global transformer_engine + + def weight_module(self): + return self.module + + @staticmethod + def origin_forward(self,hidden_states=None,attention_mask=None, + encoder_output=None,enc_dec_attn_mask=None, + retriever_input=None, + retriever_output=None, + 
retriever_attn_mask=None, + inference_params=None, + rotary_pos_emb=None): + self.seq_length = get_args().seq_length + # hidden_states: [s, b, h] + + # Checks. + if inference_params: + assert self.recompute_granularity is None, \ + 'inference does not work with activation checkpointing' + + if not self.pre_process: + # See set_input_tensor() + hidden_states = self.input_tensor + + # Viewless tensor. + # - We only need to create a viewless tensor in the case of mirco batch + # size (mbs) == 1, since in this case, 'hidden_states.transpose()' + # above creates a view tensor, and '.contiguous()' is a pass-through. + # For mbs >= 2, '.contiguous()' creates a new tensor, eliminating + # the need to make it viewless. + # + # However, we don't explicitly check mbs == 1 here because + # make_viewless_tensor() has negligible overhead when its input + # is already viewless. + # + # - For the 'else' case above,calling make_viewless_tensor() here is + # likely redundant, since p2p_communication.py (likely originator) + # already creates viewless tensors. That said, make_viewless_tensor() + # is called here to be future-proof and corner-case-proof. + hidden_states = core.utils.make_viewless_tensor( + hidden_states, + requires_grad=True, + keep_graph=True, + ) + + # RNG context. + if self.sequence_parallel: + rng_context = tensor_parallel.get_cuda_rng_tracker().fork() + else: + rng_context = nullcontext() + + # Forward layers. + with rng_context: + # The fp8_autocast context manager is a no-op when enable=True + # The if...else serves to short circuit name resolution for fp8_autocast + with transformer_engine.pytorch.fp8_autocast( + enabled=self.use_fp8, + fp8_recipe=self.fp8_recipe, + fp8_group=self.fp8_group + ) if self.use_fp8 else nullcontext(): + # Determine if the current iteration is first microbatch + if self.num_microbatches_in_previous_step != get_num_microbatches(): + self.microbatch_count = 0 # Reset count on new batch size rampup interval + self.num_microbatches_in_previous_step = get_num_microbatches() + is_first_microbatch = self.microbatch_count % get_num_microbatches() == 0 + + # Forward pass. + if self.recompute_granularity == 'full': + hidden_states = self.checkpointed_forward(hidden_states, + attention_mask, + encoder_output, + enc_dec_attn_mask, + rotary_pos_emb, + is_first_microbatch) + else: + forward_kwargs = { + 'encoder_output': encoder_output, + 'enc_dec_attn_mask': enc_dec_attn_mask, + 'inference_params': inference_params, + } + + if self.transformer_impl == 'transformer_engine': + forward_kwargs['is_first_microbatch'] = is_first_microbatch + forward_kwargs['checkpoint_core_attention'] = self.checkpoint_core_attention + if self.tansformer_engine_v_0_10: + forward_kwargs['rotary_pos_emb'] = rotary_pos_emb + else: + forward_kwargs['rotary_pos_emb'] = rotary_pos_emb + forward_kwargs['retriever_input'] = retriever_input + forward_kwargs['retriever_output'] = retriever_output + forward_kwargs['retriever_attn_mask'] = retriever_attn_mask + + for index in range(self.num_layers): + layer = self._get_layer(index) + + hidden_states = layer( + hidden_states, + attention_mask, + **forward_kwargs) + + # First Retro decoder layer returns both hidden_states + # add retriever_output. Make retriever_output available + # to subsequence Retro layers. 
+ if isinstance(hidden_states, tuple): + assert len(hidden_states) == 2 + hidden_states, retriever_output = hidden_states + forward_kwargs['retriever_output'] = retriever_output + + # Skip counter update for eval and activating checkpointing + if torch.is_grad_enabled() and self.training: + self.microbatch_count += 1 + return hidden_states + + +class EmbeddingAdapter(BlockAdapter): + """megatron.model.language_model.TransformerLanguageModel.forward""" + + def get_output_name(self): + return ['encoder_input', 'rotary_pos_emb'] + + def copy_method_module(self, other): + other.embedding = self.module.embedding + if other.use_rotary_position_embeddings: + other.rotary_pos_emb = self.module.rotary_pos_emb + + def weight_module(self): + return self.module.embedding + + @staticmethod + def origin_forward(self, enc_input_ids=None, enc_position_ids=None, enc_attn_mask=None, + dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None, + retriever_input_ids=None, + retriever_position_ids=None, + retriever_attn_mask=None, + enc_dec_attn_mask=None, tokentype_ids=None, + inference_params=None, + pooling_sequence_index=0, + enc_hidden_states=None, output_enc_hidden=False): + # Encoder embedding. + if self.pre_process: + encoder_input = self.embedding(enc_input_ids, enc_position_ids, + tokentype_ids=tokentype_ids) + else: + encoder_input = None + + # Retriever embedding. + if self.add_retriever and self.pre_process: + retriever_input = self.embedding(retriever_input_ids, + retriever_position_ids, + tokentype_ids=tokentype_ids) + else: + retriever_input = None + + # Rotary positional embeddings. + rotary_pos_emb = None + if self.use_rotary_position_embeddings: + if inference_params is not None: + rotary_pos_emb = \ + self.rotary_pos_emb(inference_params.max_sequence_length) + else: + rotary_pos_emb = self.rotary_pos_emb(self.seq_length) + + return encoder_input, rotary_pos_emb + + +class FinalNormAdapter(BlockAdapter): + """ParallelTransformer forward 中 self.final_norm 相关部分""" + + def get_output_name(self): + return ['hidden_states'] + + def copy_method_module(self, other): + other.final_norm = self.module.final_norm + + def weight_module(self): + return self.module.final_norm + + @staticmethod + def origin_forward(self, hidden_states): + hidden_states = self.final_norm(hidden_states) + + return hidden_states + + +class LossAdapter(BlockAdapter): + """modellink.model.GPTModel""" + + def get_output_name(self): + return ["output"] + + def copy_method_module(self, other): + # git 适配 not self.untie_embeddings_and_output_weights 情况 + if other.untie_embeddings_and_output_weights: + other.logit_weights = self.module.language_model.output_layer.weight + else: + other.logit_weights = self.module.language_model.embedding.word_embeddings.weight + + def weight_module(self): + return self.module.language_model.output_layer + + @staticmethod + def origin_forward(self, input_ids=None, position_ids=None, attention_mask=None, lm_output=None, + retriever_input_ids=None, + retriever_position_ids=None, + retriever_attn_mask=None, + labels=None, tokentype_ids=None, inference_params=None): + output = post_language_model_processing( + lm_output, labels, + # 此处不进行判断 而直接输入该形状匹配的内容 + self.logit_weights, + self.parallel_output, + self.fp16_lm_cross_entropy) + + return output + \ No newline at end of file diff --git a/profiler/tinker/model/block_adapter_1_1.py b/profiler/tinker/model/block_adapter_1_1.py new file mode 100644 index 0000000000..17ef2bedd0 --- /dev/null +++ b/profiler/tinker/model/block_adapter_1_1.py @@ -0,0 +1,242 
@@ +from contextlib import nullcontext + +import torch + +from megatron import core +from megatron.core import tensor_parallel +from megatron.legacy.model.gpt_model import post_language_model_processing +from tinker.megatron_patch.microbatches import get_num_microbatches +from tinker.model.block_adapter import BlockAdapter + + +class TransformerBlockAdapter(BlockAdapter): + """parallel_transformer_forward""" + + def get_output_name(self): + return ['hidden_states'] + + def copy_method_module(self, other): + other._get_layer = self.module._get_layer + other.layers = self.module.layers + other._checkpointed_forward = self.module._checkpointed_forward # 其实没必要,因为profile时不会用到,但可起到标注作用 + if other.transformer_impl == 'transformer_engine': + global transformer_engine + # noinsepction PyUnresolvedReferences + import transformer_engine + + def weight_module(self): + return self.module + + @staticmethod + def origin_forward(self, hidden_states=None, attention_mask=None, + encoder_output=None, enc_dec_attn_mask=None, + retriever_input=None, + retriever_output=None, + retriever_attn_mask=None, + inference_params=None, + rotary_pos_emb=None): + # hidden_states: [s, b, h] + + #Checks. + if inference_params: + assert self.recompute_granularity is None, \ + 'inference does not work with activation checkpointing' + + if not self.pre_process: + # See set_input_tensor() + hidden_states = self.input_tensor + + # Viewless tensor + # - We only need to create a viewless tensor in the case of micro batch + # size (mbs) == 1, since in this case, 'hidden_states.transpose()' + # above creates a view tensor, and '.contiguous()' is a pass-through. + # For mbs >= 2, '.contiguous()' creates a new tensor, eliminating + # the need to make it viewless. + # + # However, we don't explicitly check mbs == 1 here because + # make_viewless_tensor() has negligible overhead when its input + # is already viewless. + # + # - For the 'else' case above, calling make_viewless_tensor() here is + # likely redundant, since p2p_communication.py (likely originator) + # already creates viewless tensors. That said, make_viewless_tensor() + # is called here to be future-proof and corner-case-proof. + if self.input_embeds_norm and self.pre_process: + hidden_states = hidden_states * (self.hidden_size ** 0.5) + + hidden_states = core.utils.make_viewless_tensor( + hidden_states, + requires_grad=True, + keep_graph=True, + ) + + # RNG context. + if self.sequence_parallel: + rng_context = tensor_parallel.get_cuda_rng_tracker().fork() + else: + rng_context = nullcontext() + + # Forward layers + with rng_context: + # The fp8_autocast context manager is a no-op when enabled=True + # The if...else serves to short circuit name resolution for fp8_autocast + with transformer_engine.pytorch.fp8_autocast( + enabled=self.use_fp8, + fp8_recipe=self.fp8_recipe, + fp8_group=self.fp8_group + ) if self.use_fp8 else nullcontext(): + # Determine if the current iteration is first microbatch + if self.num_microbatches_in_previous_step != get_num_microbatches(): + self.microbatch_count = 0 # Reset count on new batch size rampup interval + self.num_microbatches_in_previous_step = get_num_microbatches() + is_first_microbatch = self.microbatch_count % get_num_microbatches() == 0 + + # Forward pass. 
+ if self.recompute_granularity == 'full': + hidden_states = self._checkpointed_forward(hidden_states, + attention_mask, + encoder_output, + enc_dec_attn_mask, + rotary_pos_emb, + is_first_microbatch) + else: + forward_kwargs = { + 'encoder_output': encoder_output, + 'enc_dec_attn_mask': enc_dec_attn_mask, + 'inference_params': inference_params, + } + + if self.transformer_impl == 'transformer_engine': + forward_kwargs['is_first_microbatch'] = is_first_microbatch + forward_kwargs['checkpoint_core_attention'] = self.checkpoint_core_attention + if self.transformer_engine_v_0_10: + forward_kwargs['rotary_pos_emb'] = rotary_pos_emb + else: + forward_kwargs['rotary_pos_emb'] = rotary_pos_emb + forward_kwargs['retriever_input'] = retriever_input + forward_kwargs['retriever_output'] = retriever_output + forward_kwargs['retriever_attn_mask'] = retriever_attn_mask + + for index in range(self.num_layers): + layer = self._get_layer(index) + + hidden_states = layer( + hidden_states, + attention_mask, + **forward_kwargs) + + # First Retro decoder layer returns both hidden_states + # and retriever_output. Make retriever_output available + # to subsequence Retro layers. + if isinstance(hidden_states, tuple): + assert len(hidden_states) == 2 + hidden_states, retriever_output = hidden_states + forward_kwargs["retriever_output"] = retriever_output + + # Skip counter update for eval and activation checkpointing + if torch.is_grad_enabled() and self.training: + self.microbatch_count += 1 + return hidden_states + + +class EmbeddingAdapter(BlockAdapter): + """megatron.legacy.model.language_model.TransformerLanguageModel.forward""" + + def get_output_name(self): + return ['encoder_input', 'rotary_pos_emb'] + + def copy_method_module(self, other): + other.embedding = self.module.embedding + if other.use_rotary_position_embeddings: + other.rotary_pos_emb = self.module.rotary_pos_emb + + def weight_module(self): + return self.module.embedding + + @staticmethod + def origin_forward(self, enc_input_ids=None, enc_position_ids=None, enc_attn_mask=None, + dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None, + retriever_input_ids=None, + retriever_position_ids=None, + retriever_attn_mask=None, + enc_dec_attn_mask=None, tokentype_ids=None, + inference_params=None, + pooling_sequence_index=0, + enc_hidden_states=None, output_enc_hidden=False): + # Encoder embedding. + if self.pre_process: + encoder_input = self.embedding(enc_input_ids, enc_position_ids, + tokentype_ids=tokentype_ids) + else: + encoder_input = None + + # Retriever embedding. 
+ if self.add_retriever and self.pre_process: + retriever_input = self.embedding(retriever_input_ids, + retriever_position_ids, + tokentype_ids=tokentype_ids) + else: + retriever_input = None + + # Rotary positional embeddings + rotary_pos_emb = None + if self.use_rotary_position_embeddings: + if inference_params is not None: + rotary_pos_emb = \ + self.rotary_pos_emb(inference_params.max_sequence_length) + else: + rotary_pos_emb = self.rotary_pos_emb(self.seq_length) + + return encoder_input, rotary_pos_emb + + +class FinalNormAdapter(BlockAdapter): + """ParallelTransformer forward 中 self.final_norm 相关部分""" + + def get_output_name(self): + return ['hidden_states'] + + def copy_method_module(self, other): + other.final_norm = self.module.final_norm + + def weight_module(self): + return self.module.final_norm + + @staticmethod + def origin_forward(self, hidden_states): + hidden_states = self.final_norm(hidden_states) + + return hidden_states + + +class LossAdapter(BlockAdapter): + """modellink.model.GPTModel""" + + def get_output_name(self): + return ["output"] + + def copy_method_module(self, other): + # git 适配 not self.untie_embeddings_and_output_weights 情况 + if other.untie_embeddings_and_output_weights: + other.logit_weights = self.module.language_model.output_layer.weight + else: + other.logit_weights = self.module.language_model.embedding.word_embeddings.weight + + def weight_module(self): + return self.module.language_model.output_layer + + @staticmethod + def origin_forward(self, input_ids=None, position_ids=None, attention_mask=None, lm_output=None, + retriever_input_ids=None, + retriever_position_ids=None, + retriever_attn_mask=None, + labels=None, tokentype_ids=None, inference_params=None): + output = post_language_model_processing( + lm_output, labels, + # 此处不进行判断 而直接输入该形状匹配的内容 + self.logit_weights, + self.parallel_output, + self.fp16_lm_cross_entropy) + + return output + \ No newline at end of file diff --git a/profiler/tinker/model/block_adapter_1_2.py b/profiler/tinker/model/block_adapter_1_2.py new file mode 100644 index 0000000000..31e418798d --- /dev/null +++ b/profiler/tinker/model/block_adapter_1_2.py @@ -0,0 +1,270 @@ +from contextlib import nullcontext + + +import torch +from torch import Tensor + +from megatron.core import parallel_state, tensor_parallel, InferenceParams +from megatron.core.packed_seq_params import PackedSeqParams +from meagtron.core.utils import make_viewless_tensor +from megatron.training import get_args +from tinker.model import block_adapter_1_1 +from tinker.model.block_adapter import BlockAdapter + +EmbeddingAdapter = block_adapter_1_1.EmbeddingAdapter +TransformerBlockAdapter = block_adapter_1_1.TransformerBlockAdapter +FinalNormAdapter = block_adapter_1_1.FinalNormAdapter +LossAdapter = block_adapter_1_1.LossAdapter + + +class DummyDecoder: + def __init__(self): + self.input_tensor = None + + +class McoreEmbeddingAdapter(BlockAdapter): + """modellink.core.gpt_model_forward""" + + def get_output_name(self): + return ['decoder_input', 'rotary_pos_emb'] + + def copy_method_module(self, other): + # 此decoder用于避免前向报错,但其本身值并不被使用 + other.decoder = DummyDecoder() + other.embedding = self.module.embedding + other.rotary_pos_emb = self.module.rotary_pos_emb + + def weight_module(self): + return self.module.embedding, self.module.rotary_pos_emb + + @staticmethod + def origin_forward(self, input_ids: Tensor = None, + position_ids: Tensor = None, attention_mask: Tensor = None, + decoder_input: Tensor = None, + labels: Tensor = None, + inference_params: 
InferenceParams = None, + packed_seq_params: PackedSeqParams = None, + extra_block_kwargs: dict = None, + tokentype_ids=None): + """modellink.core.gpt_model_forward""" + # If decoder_input is provides (not None), then input_ids and position_ids are ignored. + # Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input. + args = get_args() + # Decoder embedding. + if decoder_input is not None: + pass + elif self.pre_process: + decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids) + if args.scale_emb is not None: + decoder_input = decoder_input * args.scale_emb + else: + # intermediate stage of pipeline + # decoder will get hidden_states from encoder.input_tensor + decoder_input = None + + # Rotary positional embeddings (embedding is None for PP intermediate devices) + rotary_pos_emb = None + if self.position_embedding_type == 'rope': + rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len( + inference_params, self.decoder, decoder_input, self.config + ) + rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len) + + return decoder_input, rotary_pos_emb + + +class McoreTransformerBlockAdapter(BlockAdapter): + """megatron.core.transformer.transformer_block.TransformerBlock""" + def get_output_name(self): + return ['hidden_states'] + + def copy_method_module(self, other): + other.layers = self.module.layers + # other.group_prefetch_offload_commit_async = self.module.group_prefetch_offload_commit_async + # other.input_embeds_norm = self.module.input_embeds_norm + # other.offload_context = self.module.offload_context + + + def weight_module(self): + return self.module.layers + + @staticmethod + def origin_forward( + self, + hidden_states: Tensor = None, + attention_mask: Tensor = None, + context: Tensor = None, + context_mask: Tensor = None, + rotary_pos_emb: Tensor = None, + inference_params: InferenceParams = None, + packed_seq_params: PackedSeqParams = None, + ): + """modellink.core.transformer_block_forward""" + # hidden_states (float): [s, b, h] + # attention_mask (bool): [1, 1, s, s] + + if not self.pre_process: + # See set_input_tensor() + hidden_states = self.input_tensor + + # Viewless tensor. + # - We only need to create a viewless tensor in the case of micro batch + # size (mbs) == 1, since in this case, 'hidden_states.transpose()' + # above creates a view tensor, and '.contiguous()' is a pass-through. + # For mbs >= 2, '.contiguous()' creates a new tensor, eliminating + # the need to make it viewless. + # + # However, we don't explicitly check mbs == 1 here because + # make_viewless_tensor() has negligible overhead when its input + # is already viewless. + # + # - For the 'else' case above, calling make_viewless_tensor() here is + # likely redundant, since p2p_communication.py (likely originator) + # already creates viewless tensors. That said, make_viewless_tensor() + # is called here to be future-proof and corner-case-proof. 
+ if self.input_embeds_norm and self.pre_process: + normalizer = torch.tensor(self.hidden_size ** 0.5, dtype=hidden_states.dtype) + hidden_states = hidden_states * normalizer + + hidden_states = make_viewless_tensor( + inp=hidden_states, requires_grad=True, keep_graph=True, + ) + + if self.config.sequence_parallel: + rng_context = tensor_parallel.get_cuda_rng_tracker().fork() + else: + rng_context = nullcontext() + + if self.config.fp8: + import transformer_engine # To keep out TE dependency when not training in fp8 + + if self.config.fp8 == "e4m3": + fp8_format = transformer_engine.common.recipe.Format.E4M3 + elif self.config.fp8 == "hybrid": + fp8_format = transformer_engine.common.recipe.Format.HYBRID + else: + raise ValueError("E4M3 and HYBRID are the only supported FP8 formats.") + + fp8_recipe = transformer_engine.common.recipe.DelayedScaling( + margin=self.config.fp8_margin, + interval=self.config.fp8_interval, + fp8_format=fp8_format, + amax_compute_algo=self.config.fp8_amax_compute_algo, + amax_history_len=self.config.fp8_amax_history_len, + override_linear_precision=(False, False, not self.config.fp8_wgrad), + ) + fp8_group = None + if parallel_state.model_parallel_is_initialized(): + fp8_group = parallel_state.get_amax_reduction_group(with_context_parallel=True) + fp8_context = transformer_engine.pytorch.fp8_autocast( + enabled=True, fp8_recipe=fp8_recipe, fp8_group=fp8_group + ) + else: + fp8_context = nullcontext() + + with rng_context and fp8_context: + # Forward pass. + if self.config.recompute_granularity == 'full' and self.training: + hidden_states = self._checkpointed_forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + context=context, + context_mask=context_mask, + rotary_pos_emb=rotary_pos_emb, + packed_seq_params=packed_seq_params, + ) + else: + for layer in self.layers: + with self.offload_context: + hidden_states, context = layer( + hidden_states=hidden_states, + attention_mask=attention_mask, + context=context, + context_mask=context_mask, + rotary_pos_emb=rotary_pos_emb, + inference_params=inference_params, + packed_seq_params=packed_seq_params, + ) + + if ( + torch.is_grad_enabled() + and self.config.cpu_offloading + and self.group_prefetch_offload_commit_async is not None + ): + hidden_states = self.group_prefetch_offload_commit_async(hidden_states) + + return hidden_states + + +class McoreFinalNormAdapter(BlockAdapter): + def get_output_name(self): + return ['hidden_states'] + + def copy_method_module(self, other): + other.final_layernorm = self.module.final_layernorm + + def weight_module(self): + return self.module.final_layernorm + + @staticmethod + def origin_forward(self, hidden_states: Tensor = None): + """modellink.core.transformer_block_forward""" + hidden_states = self.final_layernorm(hidden_states) + + return hidden_states + + +class McoreLossAdapter(BlockAdapter): + """modellink.core.gpt_model_forward""" + + def get_output_name(self): + return ['loss'] + + def copy_method_module(self, other): + other.output_layer = self.module.output_layer + other.compute_language_model_loss = self.module.compute_language_model_loss + other.shared_embedding_or_output_weight = self.module.shared_embedding_or_output_weight + + def weight_module(self): + return self.module.output_layer + + @staticmethod + def origin_forward(self, input_ids: Tensor = None, label=None, hidden_states=None, + position_ids: Tensor = None, attention_mask: Tensor = None, + decoder_input: Tensor = None, + labels: Tensor = None, + inference_params: InferenceParams = None, + 
packed_seq_params: PackedSeqParams = None, + extra_block_kwargs: dict = None, + tokentype_ids=None): + """modellink.core.gpt_model_forward""" + args = get_args() + # logits and loss + output_weight = None + if self.share_embeddings_and_output_weights: + output_weight = self.shared_embedding_or_output_weight() + + if args.dim_model_base is not None: + hidden_states = hidden_states / (args.hidden_size / args.dim_model_base) + logits, _ = self.output_layer(hidden_states, weight=output_weight) + # new add to scale logits + if args.output_multiplier_scale: + logits = logits * args.output_multiplier_scale + + if args.output_logit_softcapping: + logits = logits / args.output_logit_softcapping + logits = torch.tanh(logits) + logits = logits * args.output_logit_softcapping + + if labels is None: + # [s b h] => [b s h] + return logits.transpose(0, 1).contiguous() + + if args.is_instruction_dataset: + labels = labels[:, 1:].contiguous() + logits = logits[:-1, :, :].contiguous() + + loss = self.compute_language_model_loss(labels, logits) + + return loss + \ No newline at end of file diff --git a/profiler/tinker/model/block_infos.py b/profiler/tinker/model/block_infos.py new file mode 100644 index 0000000000..041a15e468 --- /dev/null +++ b/profiler/tinker/model/block_infos.py @@ -0,0 +1,48 @@ +import importlib +from dataclasses import dataclass +from typing import List, Optional, Type + +import torch + +from tinker.framework_adapter.modellink_adapter import get_block_adapter, ModelLinkAdapter +from tinker.model.block_adapter import BlockAdapter + +block_adapter = importlib.import_module(f'tinker.model.{get_block_adapter()}') + + +@dataclass +class BlockInfo: + name: str + module: torch.nn.Module + block_adapter: Optional[Type[BlockAdapter]] = None + + +def get_model_block_infos(adapter: ModelLinkAdapter) -> List[BlockInfo]: + """获取需要的profile的block列表 block粒度观测时即头处理 TransformerBlock 两个尾处理""" + block_infos = [] # type: List[BlockInfo] + args = adapter.get_args() + model = adapter.get_model() + if args.use_mcore_models: + # mcore GPTModel + block_infos.append(BlockInfo("mcore-embedding", model, block_adapter.McoreEmbeddingAdapter)) + block_infos.append( + BlockInfo("mcore-transformer-block", model.decoder, block_adapter.McoreTransformerBlockAdapter)) + block_infos.append(BlockInfo("mcore-final-norm", model.decoder, block_adapter.McoreFinalNormAdapter)) + block_infos.append(BlockInfo("mcore-post-process", model, block_adapter.McoreLossAdapter)) + + else: + # legacy GPTModel + encoder = model.language_model.encoder + + # model.language_model.pre_process + block_infos.append(BlockInfo("embedding", model.language_model, block_adapter.EmbeddingAdapter)) + + block_infos.append(BlockInfo("transformer-block", encoder, block_adapter.TransformerBlockAdapter)) + + # encoder.post_norm and encoder.post_process + block_infos.append(BlockInfo("final-norm", encoder, block_adapter.FinalNormAdapter)) + + # model.post_process + block_infos.append(BlockInfo("post-process", model, block_adapter.LossAdapter)) + + return block_infos diff --git a/profiler/tinker/model/observation_block.py b/profiler/tinker/model/observation_block.py new file mode 100644 index 0000000000..c948c79725 --- /dev/null +++ b/profiler/tinker/model/observation_block.py @@ -0,0 +1,46 @@ +import torch + +from tinker.model.block_adapter import BlockAdapter +from tinker.model.block_infos import BlockInfo + + +class ObservationBlock(torch.nn.Module): + def __init__(self, block_adapter: BlockAdapter): + super().__init__() + self.block_adapter = block_adapter + 
self.name = self.block_adapter.name
+        # for profiling
+        self.weight_size = 0
+        # Pull attributes, plus the specific methods and modules, from block_adapter into self
+        self.extract_init_attr(self.block_adapter.module)
+        self.block_adapter.copy_method_module(self)
+        weight_module = self.block_adapter.weight_module()
+        if not isinstance(weight_module, tuple):
+            weight_module = (weight_module,)
+        self.set_weight_size(*weight_module)
+
+    def extract_init_attr(self, module):
+        """Copy over every non-module attribute so the original forward code can be reused"""
+        for attr in module.__dict__:
+            if not attr.startswith('_') and not isinstance(getattr(module, attr), torch.nn.Module):
+                setattr(self, attr, getattr(module, attr))
+
+    def set_weight_size(self, *modules):
+        """Compute the trainable weight parameter count from the given modules"""
+        for module in modules:
+            self.weight_size += sum(p.numel() for p in module.parameters() if p.requires_grad)
+
+    def forward(self, input_tensors):
+        outputs = self.block_adapter.origin_forward(self=self, **input_tensors)
+        if not isinstance(outputs, tuple):
+            outputs = (outputs,)
+        return {k: v for k, v in zip(self.block_adapter.output_name, outputs)}
+
+
+def gen_block(block_info: BlockInfo):
+    """Build an observation block instance from block_info"""
+    if block_info.block_adapter is not None:
+        return ObservationBlock(block_info.block_adapter(block_info.name, block_info.module))
+    else:
+        raise RuntimeError(f"operator {block_info.name} is not supported.")
+
--
Gitee


From 38feb281e0598f66c15a81e8478180a3d125f7f6 Mon Sep 17 00:00:00 2001
From: avocadovo
Date: Thu, 2 Jan 2025 17:11:21 +0800
Subject: [PATCH 2/2] =?UTF-8?q?=E3=80=90Tinker=E3=80=91=E4=BF=AE=E6=94=B9?=
 =?UTF-8?q?=20model=20codecheck?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 profiler/tinker/model/block_adapter_1_0.py | 221 +++++++++++++++------
 profiler/tinker/model/block_adapter_1_1.py | 221 +++++++++++++++------
 profiler/tinker/model/block_adapter_1_2.py | 194 ++++++++++++------
 3 files changed, 445 insertions(+), 191 deletions(-)

diff --git a/profiler/tinker/model/block_adapter_1_0.py b/profiler/tinker/model/block_adapter_1_0.py
index 1b27313d88..6e5131c308 100644
--- a/profiler/tinker/model/block_adapter_1_0.py
+++ b/profiler/tinker/model/block_adapter_1_0.py
@@ -1,4 +1,20 @@
+# Copyright (c) 2025, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ from contextlib import nullcontext +from dataclasses import dataclass import torch @@ -10,37 +26,43 @@ from tinker.megatron_patch.microbatches import get_num_microbatches from tinker.model.block_adapter import BlockAdapter +@dataclass +class TransformerForwardInput: + hidden_states: any = None + attention_mask: any = None + encoder_output: any = None + enc_dec_attn_mask: any = None + retriever_input: any = None + retriever_output: any = None + retriever_attn_mask: any = None + inference_params: any = None + rotary_pos_emb: any = None + + class TransformerBlockAdapter(BlockAdapter): """megatron.model.transformer.ParallelTransformer""" - def get_output_name(self): - return ['hidden_states'] - - def copy_method_module(self,other): - other._get_layer = self.module._get_layer - other.layers = self.module.layers - other._checkpointed_forward = self.module._checkpointed_forward # 其实没必要,因为profile时不会用到,但可起到标注作用 - if other.transformer_impl == 'transformer_engine': - global transformer_engine - - def weight_module(self): - return self.module - + def __init__(self): + pass + @staticmethod - def origin_forward(self,hidden_states=None,attention_mask=None, - encoder_output=None,enc_dec_attn_mask=None, - retriever_input=None, - retriever_output=None, - retriever_attn_mask=None, - inference_params=None, - rotary_pos_emb=None): + def origin_forward(self, input_data: TransformerForwardInput): + hidden_states = input_data.hidden_states + attention_mask = input_data.attention_mask + encoder_output = input_data.encoder_output + enc_dec_attn_mask = input_data.enc_dec_attn_mask + retriever_input = input_data.retriever_input + retriever_output = input_data.retriever_output + retriever_attn_mask = input_data.retriever_attn_mask + inference_params = input_data.inference_params + rotary_pos_emb = input_data.rotary_pos_emb + self.seq_length = get_args().seq_length - # hidden_states: [s, b, h] # Checks. - if inference_params: - assert self.recompute_granularity is None, \ - 'inference does not work with activation checkpointing' + if input_data.inference_params: + if self.recompute_granularity is not None: + raise ValueError('inference does not work with activation checkpointing') if not self.pre_process: # See set_input_tensor() @@ -126,7 +148,8 @@ class TransformerBlockAdapter(BlockAdapter): # add retriever_output. Make retriever_output available # to subsequence Retro layers. 
if isinstance(hidden_states, tuple): - assert len(hidden_states) == 2 + if len(hidden_states) != 2: + raise ValueError("hidden_states should be a tuple of length 2") hidden_states, retriever_output = hidden_states forward_kwargs['retriever_output'] = retriever_output @@ -134,32 +157,64 @@ class TransformerBlockAdapter(BlockAdapter): if torch.is_grad_enabled() and self.training: self.microbatch_count += 1 return hidden_states + + def get_output_name(self): + return ['hidden_states'] + + def copy_method_module(self,other): + other._get_layer = self.module._get_layer + other.layers = self.module.layers + other._checkpointed_forward = self.module._checkpointed_forward # 其实没必要,因为profile时不会用到,但可起到标注作用 + if other.transformer_impl == 'transformer_engine': + global transformer_engine + + def weight_module(self): + return self.module + + +@dataclass +class EmbeddingForwardInput: + enc_input_ids: any = None + enc_position_ids: any = None + enc_attn_mask: any = None + dec_input_ids: any = None + dec_position_ids: any = None + dec_attn_mask: any = None + retriever_input_ids: any = None + retriever_position_ids: any = None + retriever_attn_mask: any = None + enc_dec_attn_mask: any = None + tokentype_ids: any = None + inference_params: any = None + pooling_sequence_index: any = 0 + enc_hidden_states: any = None + output_enc_hidden: any = False class EmbeddingAdapter(BlockAdapter): """megatron.model.language_model.TransformerLanguageModel.forward""" - def get_output_name(self): - return ['encoder_input', 'rotary_pos_emb'] - - def copy_method_module(self, other): - other.embedding = self.module.embedding - if other.use_rotary_position_embeddings: - other.rotary_pos_emb = self.module.rotary_pos_emb - - def weight_module(self): - return self.module.embedding + def __init__(self): + pass @staticmethod - def origin_forward(self, enc_input_ids=None, enc_position_ids=None, enc_attn_mask=None, - dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None, - retriever_input_ids=None, - retriever_position_ids=None, - retriever_attn_mask=None, - enc_dec_attn_mask=None, tokentype_ids=None, - inference_params=None, - pooling_sequence_index=0, - enc_hidden_states=None, output_enc_hidden=False): + def origin_forward(self, input_data: EmbeddingForwardInput): + enc_input_ids = input_data.enc_input_ids + enc_position_ids = input_data.enc_position_ids + enc_attn_mask = input_data.enc_attn_mask + dec_input_ids = input_data.dec_input_ids + dec_position_ids = input_data.dec_position_ids + dec_attn_mask = input_data.dec_attn_mask + retriever_input_ids = input_data.retriever_input_ids + retriever_position_ids = input_data.retriever_position_ids + retriever_attn_mask = input_data.retriever_attn_mask + enc_dec_attn_mask = input_data.enc_dec_attn_mask + tokentype_ids = input_data.tokentype_ids + inference_params = input_data.inference_params + pooling_sequence_index = input_data.pooling_sequence_index + enc_hidden_states = input_data.enc_hidden_states + output_enc_hidden = input_data.output_enc_hidden + # Encoder embedding. 
if self.pre_process: encoder_input = self.embedding(enc_input_ids, enc_position_ids, @@ -185,11 +240,31 @@ class EmbeddingAdapter(BlockAdapter): rotary_pos_emb = self.rotary_pos_emb(self.seq_length) return encoder_input, rotary_pos_emb + + def get_output_name(self): + return ['encoder_input', 'rotary_pos_emb'] + + def copy_method_module(self, other): + other.embedding = self.module.embedding + if other.use_rotary_position_embeddings: + other.rotary_pos_emb = self.module.rotary_pos_emb + + def weight_module(self): + return self.module.embedding class FinalNormAdapter(BlockAdapter): """ParallelTransformer forward 中 self.final_norm 相关部分""" + def __init__(self): + pass + + @staticmethod + def origin_forward(self, hidden_states): + hidden_states = self.final_norm(hidden_states) + + return hidden_states + def get_output_name(self): return ['hidden_states'] @@ -199,16 +274,49 @@ class FinalNormAdapter(BlockAdapter): def weight_module(self): return self.module.final_norm - @staticmethod - def origin_forward(self, hidden_states): - hidden_states = self.final_norm(hidden_states) - return hidden_states +@dataclass +class LossForwardInput: + input_ids: any = None + position_ids: any = None + attention_mask: any = None + lm_output: any = None + retriever_input_ids: any = None + retriever_position_ids: any = None + retriever_attn_mask: any = None + labels: any = None + tokentype_ids: any = None + inference_params: any = None class LossAdapter(BlockAdapter): """modellink.model.GPTModel""" + def __init__(self): + pass + + @staticmethod + def origin_forward(self, input_data: LossForwardInput): + input_ids = input_data.input_ids + position_ids = input_data.position_ids + attention_mask = input_data.attention_mask + lm_output = input_data.lm_output + retriever_input_ids = input_data.retriever_input_ids + retriever_position_ids = input_data.retriever_position_ids + retriever_attn_mask = input_data.retriever_attn_mask + labels = input_data.labels + tokentype_ids = input_data.tokentype_ids + inference_params = input_data.inference_params + + output = post_language_model_processing( + lm_output, labels, + # 此处不进行判断 而直接输入该形状匹配的内容 + self.logit_weights, + self.parallel_output, + self.fp16_lm_cross_entropy) + + return output + def get_output_name(self): return ["output"] @@ -221,19 +329,4 @@ class LossAdapter(BlockAdapter): def weight_module(self): return self.module.language_model.output_layer - - @staticmethod - def origin_forward(self, input_ids=None, position_ids=None, attention_mask=None, lm_output=None, - retriever_input_ids=None, - retriever_position_ids=None, - retriever_attn_mask=None, - labels=None, tokentype_ids=None, inference_params=None): - output = post_language_model_processing( - lm_output, labels, - # 此处不进行判断 而直接输入该形状匹配的内容 - self.logit_weights, - self.parallel_output, - self.fp16_lm_cross_entropy) - - return output - \ No newline at end of file + \ No newline at end of file diff --git a/profiler/tinker/model/block_adapter_1_1.py b/profiler/tinker/model/block_adapter_1_1.py index 17ef2bedd0..28a3ce6fa0 100644 --- a/profiler/tinker/model/block_adapter_1_1.py +++ b/profiler/tinker/model/block_adapter_1_1.py @@ -1,4 +1,20 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from contextlib import nullcontext +from dataclasses import dataclass import torch @@ -9,38 +25,41 @@ from tinker.megatron_patch.microbatches import get_num_microbatches from tinker.model.block_adapter import BlockAdapter +@dataclass +class TransformerForwardInput: + hidden_states: any = None + attention_mask: any = None + encoder_output: any = None + enc_dec_attn_mask: any = None + retriever_input: any = None + retriever_output: any = None + retriever_attn_mask: any = None + inference_params: any = None + rotary_pos_emb: any = None + + class TransformerBlockAdapter(BlockAdapter): """parallel_transformer_forward""" - def get_output_name(self): - return ['hidden_states'] - - def copy_method_module(self, other): - other._get_layer = self.module._get_layer - other.layers = self.module.layers - other._checkpointed_forward = self.module._checkpointed_forward # 其实没必要,因为profile时不会用到,但可起到标注作用 - if other.transformer_impl == 'transformer_engine': - global transformer_engine - # noinsepction PyUnresolvedReferences - import transformer_engine + def __init__(self): + pass - def weight_module(self): - return self.module - @staticmethod - def origin_forward(self, hidden_states=None, attention_mask=None, - encoder_output=None, enc_dec_attn_mask=None, - retriever_input=None, - retriever_output=None, - retriever_attn_mask=None, - inference_params=None, - rotary_pos_emb=None): - # hidden_states: [s, b, h] + def origin_forward(self, input_data: TransformerForwardInput): + hidden_states = input_data.hidden_states + attention_mask = input_data.attention_mask + encoder_output = input_data.encoder_output + enc_dec_attn_mask = input_data.enc_dec_attn_mask + retriever_input = input_data.retriever_input + retriever_output = input_data.retriever_output + retriever_attn_mask = input_data.retriever_attn_mask + inference_params = input_data.inference_params + rotary_pos_emb = input_data.rotary_pos_emb #Checks. if inference_params: - assert self.recompute_granularity is None, \ - 'inference does not work with activation checkpointing' + if self.recompute_granularity is not None: + raise ValueError('inference does not work with activation checkpointing') if not self.pre_process: # See set_input_tensor() @@ -129,7 +148,8 @@ class TransformerBlockAdapter(BlockAdapter): # and retriever_output. Make retriever_output available # to subsequence Retro layers. 
if isinstance(hidden_states, tuple): - assert len(hidden_states) == 2 + if len(hidden_states) != 2: + raise ValueError("hidden_states should be a tuple of length 2") hidden_states, retriever_output = hidden_states forward_kwargs["retriever_output"] = retriever_output @@ -138,31 +158,65 @@ class TransformerBlockAdapter(BlockAdapter): self.microbatch_count += 1 return hidden_states + def get_output_name(self): + return ['hidden_states'] + + def copy_method_module(self, other): + other._get_layer = self.module._get_layer + other.layers = self.module.layers + other._checkpointed_forward = self.module._checkpointed_forward # 其实没必要,因为profile时不会用到,但可起到标注作用 + if other.transformer_impl == 'transformer_engine': + global transformer_engine + # noinsepction PyUnresolvedReferences + import transformer_engine + + def weight_module(self): + return self.module + + +@dataclass +class EmbeddingForwardInput: + enc_input_ids: any = None + enc_position_ids: any = None + enc_attn_mask: any = None + dec_input_ids: any = None + dec_position_ids: any = None + dec_attn_mask: any = None + retriever_input_ids: any = None + retriever_position_ids: any = None + retriever_attn_mask: any = None + enc_dec_attn_mask: any = None + tokentype_ids: any = None + inference_params: any = None + pooling_sequence_index: any = 0 + enc_hidden_states: any = None + output_enc_hidden: any = False + class EmbeddingAdapter(BlockAdapter): """megatron.legacy.model.language_model.TransformerLanguageModel.forward""" - def get_output_name(self): - return ['encoder_input', 'rotary_pos_emb'] - - def copy_method_module(self, other): - other.embedding = self.module.embedding - if other.use_rotary_position_embeddings: - other.rotary_pos_emb = self.module.rotary_pos_emb + def __init__(self): + pass - def weight_module(self): - return self.module.embedding - @staticmethod - def origin_forward(self, enc_input_ids=None, enc_position_ids=None, enc_attn_mask=None, - dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None, - retriever_input_ids=None, - retriever_position_ids=None, - retriever_attn_mask=None, - enc_dec_attn_mask=None, tokentype_ids=None, - inference_params=None, - pooling_sequence_index=0, - enc_hidden_states=None, output_enc_hidden=False): + def origin_forward(self, input_data: EmbeddingForwardInput): + enc_input_ids = input_data.enc_input_ids + enc_position_ids = input_data.enc_position_ids + enc_attn_mask = input_data.enc_attn_mask + dec_input_ids = input_data.dec_input_ids + dec_position_ids = input_data.dec_position_ids + dec_attn_mask = input_data.dec_attn_mask + retriever_input_ids = input_data.retriever_input_ids + retriever_position_ids = input_data.retriever_position_ids + retriever_attn_mask = input_data.retriever_attn_mask + enc_dec_attn_mask = input_data.enc_dec_attn_mask + tokentype_ids = input_data.tokentype_ids + inference_params = input_data.inference_params + pooling_sequence_index = input_data.pooling_sequence_index + enc_hidden_states = input_data.enc_hidden_states + output_enc_hidden = input_data.output_enc_hidden + # Encoder embedding. 
if self.pre_process: encoder_input = self.embedding(enc_input_ids, enc_position_ids, @@ -189,10 +243,30 @@ class EmbeddingAdapter(BlockAdapter): return encoder_input, rotary_pos_emb + def get_output_name(self): + return ['encoder_input', 'rotary_pos_emb'] + + def copy_method_module(self, other): + other.embedding = self.module.embedding + if other.use_rotary_position_embeddings: + other.rotary_pos_emb = self.module.rotary_pos_emb + + def weight_module(self): + return self.module.embedding + class FinalNormAdapter(BlockAdapter): """ParallelTransformer forward 中 self.final_norm 相关部分""" + def __init__(self): + pass + + @staticmethod + def origin_forward(self, hidden_states): + hidden_states = self.final_norm(hidden_states) + + return hidden_states + def get_output_name(self): return ['hidden_states'] @@ -202,16 +276,49 @@ class FinalNormAdapter(BlockAdapter): def weight_module(self): return self.module.final_norm - @staticmethod - def origin_forward(self, hidden_states): - hidden_states = self.final_norm(hidden_states) - return hidden_states - +@dataclass +class LossForwardInput: + input_ids: any = None + position_ids: any = None + attention_mask: any = None + lm_output: any = None + retriever_input_ids: any = None + retriever_position_ids: any = None + retriever_attn_mask: any = None + labels: any = None + tokentype_ids: any = None + inference_params: any = None + class LossAdapter(BlockAdapter): """modellink.model.GPTModel""" + def __init__(self): + pass + + @staticmethod + def origin_forward(self, input_data: LossForwardInput): + input_ids = input_data.input_ids + position_ids = input_data.position_ids + attention_mask = input_data.attention_mask + lm_output = input_data.lm_output + retriever_input_ids = input_data.retriever_input_ids + retriever_position_ids = input_data.retriever_position_ids + retriever_attn_mask = input_data.retriever_attn_mask + labels = input_data.labels + tokentype_ids = input_data.tokentype_ids + inference_params = input_data.inference_params + + output = post_language_model_processing( + lm_output, labels, + # 此处不进行判断 而直接输入该形状匹配的内容 + self.logit_weights, + self.parallel_output, + self.fp16_lm_cross_entropy) + + return output + def get_output_name(self): return ["output"] @@ -224,19 +331,3 @@ class LossAdapter(BlockAdapter): def weight_module(self): return self.module.language_model.output_layer - - @staticmethod - def origin_forward(self, input_ids=None, position_ids=None, attention_mask=None, lm_output=None, - retriever_input_ids=None, - retriever_position_ids=None, - retriever_attn_mask=None, - labels=None, tokentype_ids=None, inference_params=None): - output = post_language_model_processing( - lm_output, labels, - # 此处不进行判断 而直接输入该形状匹配的内容 - self.logit_weights, - self.parallel_output, - self.fp16_lm_cross_entropy) - - return output - \ No newline at end of file diff --git a/profiler/tinker/model/block_adapter_1_2.py b/profiler/tinker/model/block_adapter_1_2.py index 31e418798d..38782daa34 100644 --- a/profiler/tinker/model/block_adapter_1_2.py +++ b/profiler/tinker/model/block_adapter_1_2.py @@ -1,4 +1,20 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from contextlib import nullcontext +from dataclasses import dataclass import torch @@ -22,30 +38,36 @@ class DummyDecoder: self.input_tensor = None +@dataclass +class McoreForwardInput: + input_ids: Tensor = None + position_ids: Tensor = None + attention_mask: Tensor = None + decoder_input: Tensor = None + labels: Tensor = None + inference_params: InferenceParams = None + packed_seq_params: PackedSeqParams = None + extra_block_kwargs: dict = None + tokentype_ids=None + + class McoreEmbeddingAdapter(BlockAdapter): """modellink.core.gpt_model_forward""" - def get_output_name(self): - return ['decoder_input', 'rotary_pos_emb'] - - def copy_method_module(self, other): - # 此decoder用于避免前向报错,但其本身值并不被使用 - other.decoder = DummyDecoder() - other.embedding = self.module.embedding - other.rotary_pos_emb = self.module.rotary_pos_emb - - def weight_module(self): - return self.module.embedding, self.module.rotary_pos_emb + def __init__(self): + pass @staticmethod - def origin_forward(self, input_ids: Tensor = None, - position_ids: Tensor = None, attention_mask: Tensor = None, - decoder_input: Tensor = None, - labels: Tensor = None, - inference_params: InferenceParams = None, - packed_seq_params: PackedSeqParams = None, - extra_block_kwargs: dict = None, - tokentype_ids=None): + def origin_forward(self, input_data: McoreForwardInput): + input_ids = input_data.input_ids + position_ids = input_data.position_ids + attention_mask = input_data.attention_mask + decoder_input = input_data.decoder_input + labels = input_data.labels + inference_params = input_data.inference_params + packed_seq_params = input_data.packed_seq_params + extra_block_kwargs = input_data.extra_block_kwargs + tokentype_ids = input_data.tokentype_ids """modellink.core.gpt_model_forward""" # If decoder_input is provides (not None), then input_ids and position_ids are ignored. # Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input. 
@@ -72,34 +94,47 @@ class McoreEmbeddingAdapter(BlockAdapter): return decoder_input, rotary_pos_emb - -class McoreTransformerBlockAdapter(BlockAdapter): - """megatron.core.transformer.transformer_block.TransformerBlock""" def get_output_name(self): - return ['hidden_states'] + return ['decoder_input', 'rotary_pos_emb'] def copy_method_module(self, other): - other.layers = self.module.layers - # other.group_prefetch_offload_commit_async = self.module.group_prefetch_offload_commit_async - # other.input_embeds_norm = self.module.input_embeds_norm - # other.offload_context = self.module.offload_context - + # 此decoder用于避免前向报错,但其本身值并不被使用 + other.decoder = DummyDecoder() + other.embedding = self.module.embedding + other.rotary_pos_emb = self.module.rotary_pos_emb def weight_module(self): - return self.module.layers + return self.module.embedding, self.module.rotary_pos_emb + + +@dataclass +class McoreTransformerForwardInput: + hidden_states: Tensor = None + attention_mask: Tensor = None + context: Tensor = None + context_mask: Tensor = None + rotary_pos_emb: Tensor = None + inference_params: InferenceParams = None + packed_seq_params: PackedSeqParams = None + + +class McoreTransformerBlockAdapter(BlockAdapter): + """megatron.core.transformer.transformer_block.TransformerBlock""" + + def __init__(self): + pass @staticmethod - def origin_forward( - self, - hidden_states: Tensor = None, - attention_mask: Tensor = None, - context: Tensor = None, - context_mask: Tensor = None, - rotary_pos_emb: Tensor = None, - inference_params: InferenceParams = None, - packed_seq_params: PackedSeqParams = None, - ): + def origin_forward(self, input_data: McoreTransformerForwardInput): """modellink.core.transformer_block_forward""" + hidden_states = input_data.hidden_states + attention_mask = input_data.attention_mask + context = input_data.context + context_mask = input_data.context_mask + rotary_pos_emb = input_data.rotary_pos_emb + inference_params = input_data.inference_params + packed_seq_params = input_data.packed_seq_params + # hidden_states (float): [s, b, h] # attention_mask (bool): [1, 1, s, s] @@ -194,17 +229,21 @@ class McoreTransformerBlockAdapter(BlockAdapter): hidden_states = self.group_prefetch_offload_commit_async(hidden_states) return hidden_states - -class McoreFinalNormAdapter(BlockAdapter): def get_output_name(self): return ['hidden_states'] def copy_method_module(self, other): - other.final_layernorm = self.module.final_layernorm + other.layers = self.module.layers def weight_module(self): - return self.module.final_layernorm + return self.module.layers + + +class McoreFinalNormAdapter(BlockAdapter): + + def __init__(self): + pass @staticmethod def origin_forward(self, hidden_states: Tensor = None): @@ -213,33 +252,54 @@ class McoreFinalNormAdapter(BlockAdapter): return hidden_states - -class McoreLossAdapter(BlockAdapter): - """modellink.core.gpt_model_forward""" - def get_output_name(self): - return ['loss'] + return ['hidden_states'] def copy_method_module(self, other): - other.output_layer = self.module.output_layer - other.compute_language_model_loss = self.module.compute_language_model_loss - other.shared_embedding_or_output_weight = self.module.shared_embedding_or_output_weight + other.final_layernorm = self.module.final_layernorm def weight_module(self): - return self.module.output_layer + return self.module.final_layernorm + +@dataclass +class McoreLossForwardInput: + input_ids: Tensor = None + label: any = None + hidden_states: any = None + position_ids: Tensor = None + attention_mask: 
Tensor = None + decoder_input: Tensor = None + labels: Tensor = None + inference_params: InferenceParams = None + packed_seq_params: PackedSeqParams = None + extra_block_kwargs: dict = None + tokentype_ids: any = None + + +class McoreLossAdapter(BlockAdapter): + """modellink.core.gpt_model_forward""" + + def __init__(self): + pass + @staticmethod - def origin_forward(self, input_ids: Tensor = None, label=None, hidden_states=None, - position_ids: Tensor = None, attention_mask: Tensor = None, - decoder_input: Tensor = None, - labels: Tensor = None, - inference_params: InferenceParams = None, - packed_seq_params: PackedSeqParams = None, - extra_block_kwargs: dict = None, - tokentype_ids=None): + def origin_forward(self, input_data: McoreLossForwardInput): """modellink.core.gpt_model_forward""" + input_ids = input_data.input_ids + label = input_data.label + hidden_states = input_data.hidden_states + position_ids = input_data.position_ids + attention_mask = input_data.attention_mask + decoder_input = input_data.decoder_input + labels = input_data.labels + inference_params = input_data.inference_params + packed_seq_params = input_data.packed_seq_params + extra_block_kwargs = input_data.extra_block_kwargs + tokentype_ids = input_data.tokentype_ids + args = get_args() - # logits and loss + # logits and loss output_weight = None if self.share_embeddings_and_output_weights: output_weight = self.shared_embedding_or_output_weight() @@ -267,4 +327,14 @@ class McoreLossAdapter(BlockAdapter): loss = self.compute_language_model_loss(labels, logits) return loss - \ No newline at end of file + + def get_output_name(self): + return ['loss'] + + def copy_method_module(self, other): + other.output_layer = self.module.output_layer + other.compute_language_model_loss = self.module.compute_language_model_loss + other.shared_embedding_or_output_weight = self.module.shared_embedding_or_output_weight + + def weight_module(self): + return self.module.output_layer -- Gitee