diff --git a/nlp/llm/deepseek_moe_7b/ColossalAI/README.md b/nlp/llm/deepseek_moe_7b/ColossalAI/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..23cef8093fef5e00247bb00e58b27796b3aa6d02
--- /dev/null
+++ b/nlp/llm/deepseek_moe_7b/ColossalAI/README.md
@@ -0,0 +1,28 @@
+# Colossal-AI DeepSeek-MoE 7B
+
+## Model description
+DeepSeekMoE 16B is a Mixture-of-Experts (MoE) language model with 16.4B parameters. It employs an innovative MoE architecture built on two principal strategies: fine-grained expert segmentation and shared expert isolation.
+The DeepSeekMoE 7B benchmarked here is a reduced-depth variant of the 16B model (same architecture, fewer hidden layers).
+
+## Step 1: Install
+
+First, make sure ColossalAI is installed in your environment. It is generally installed by default.
+
+## Step 2: Prepare model and config
+
+Download the "deepseek-moe-16b-base" model weights and config files from Hugging Face (or another source) and move them to "/home/model_zoos/nlp/deepseek-moe-16b-base".
+Recommended source: "https://huggingface.co/deepseek-ai/deepseek-moe-16b-base/tree/main".
+
+## Step 3: Training
+```bash
+$ bash deepseek_moe_7b_pretrain.sh
+```
+
+## Results
+| Model              | Training speed     |
+|--------------------|--------------------|
+| deepseek-moe-7b    | 6.85 samples/sec   |
+
+## Reference
+
+- [ColossalAI (tag:v0.4.4)](https://github.com/hpcaitech/ColossalAI/tree/v0.4.4/examples/language/deepseek)
diff --git a/nlp/llm/deepseek_moe_7b/ColossalAI/benchmark.py b/nlp/llm/deepseek_moe_7b/ColossalAI/benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe4f88a374d1f2d84616a48204bda829ea40f1c2
--- /dev/null
+++ b/nlp/llm/deepseek_moe_7b/ColossalAI/benchmark.py
@@ -0,0 +1,276 @@
+# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd.
+# All Rights Reserved.
+
+# modified from mixtral benchmark
+import argparse
+import resource
+import time
+import warnings
+from contextlib import nullcontext
+
+import torch
+import torch.distributed as dist
+from data_utils import RandomDataset
+from model_utils import format_numel_str, get_model_numel
+from performance_evaluator import PerformanceEvaluator, get_profile_context
+from tqdm import tqdm
+from transformers import AutoConfig, AutoModelForCausalLM
+
+import colossalai
+from colossalai.accelerator import get_accelerator
+from colossalai.booster import Booster
+from colossalai.booster.plugin import MoeHybridParallelPlugin
+from colossalai.cluster import DistCoordinator
+from colossalai.lazy import LazyInitContext
+from colossalai.nn.optimizer import HybridAdam
+from colossalai.shardformer import PipelineGradientCheckpointConfig
+
+warnings.filterwarnings("ignore")
+# ==============================
+# Constants
+# ==============================
+
+# Benchmark model configurations for DeepSeek-MoE.
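+# Each entry maps a size key ("100m", "7b", "14b") to a factory that loads the config
+# from the local deepseek-moe-16b-base checkpoint (`model_path`) and overrides the
+# depth, max sequence length and, for "100m", the hidden/MoE sizes.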
+# deepseek-ai/deepseek-moe-16b-base model_path +MODEL_CONFIGS = { + "100m": lambda model_path: AutoConfig.from_pretrained( + model_path, + max_position_embeddings=4096, + num_hidden_layers=1, + num_attention_heads=32, + intermediate_size=512, + moe_intermediate_size=128, + hidden_size=512, + n_routed_experts=8, + n_shared_experts=4, + num_experts_per_tok=2, + first_k_dense_replace=0, + attn_implementation="flash_attention_2", + trust_remote_code=True, + ), + "7b": lambda model_path: AutoConfig.from_pretrained( + model_path, + max_position_embeddings=4096, + num_hidden_layers=13, + attn_implementation="flash_attention_2", + trust_remote_code=True, + ), + "14b": lambda model_path: AutoConfig.from_pretrained( + model_path, + max_position_embeddings=4096, + num_hidden_layers=26, + attn_implementation="flash_attention_2", + trust_remote_code=True, + ), +} + + +def main(): + # ============================== + # Parse Arguments + # ============================== + parser = argparse.ArgumentParser() + parser.add_argument("-c", "--config", type=str, default="100m", help="Model configuration") + parser.add_argument( + "-p", + "--plugin", + choices=["3d"], + default="3d", + help="Choose which plugin to use", + ) + parser.add_argument("-b", "--batch_size", type=int, default=1, help="Batch size") + parser.add_argument("-s", "--num_steps", type=int, default=5, help="Number of steps to run") + parser.add_argument("-i", "--ignore_steps", type=int, default=2, help="Number of steps to ignore") + parser.add_argument("-g", "--grad_checkpoint", action="store_true", help="Use gradient checkpointing") + parser.add_argument("-l", "--max_length", type=int, default=4096, help="Max sequence length") + parser.add_argument( + "-w", "--warmup_ratio", type=float, default=0.8, help="warm up ratio of non-model data. Only for gemini-auto" + ) + parser.add_argument("-m", "--memory_limit", type=int, help="Gemini memory limit in mb") + parser.add_argument("-x", "--xformers", action="store_true", help="Use xformers") + parser.add_argument("--shard_param_frac", type=float, default=1.0, help="Shard param fraction. Only for gemini") + parser.add_argument("--offload_optim_frac", type=float, default=0.0, help="Offload optim fraction. Only for gemini") + parser.add_argument("--offload_param_frac", type=float, default=0.0, help="Offload param fraction. Only for gemini") + parser.add_argument("--tp", type=int, default=1, help="Tensor parallel size") + parser.add_argument("--ep", type=int, default=1, help="Expert parallel size") + parser.add_argument("--sp", type=int, default=1, help="Sequence parallel size") + parser.add_argument("--extra_dp", type=int, default=1, help="Extra data parallel size, used for Gemini") + parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel size") + parser.add_argument("--mbs", type=int, default=1, help="Micro batch size of pipeline parallel") + parser.add_argument("--zero", type=int, default=1, help="Zero Stage when hybrid plugin is enabled") + parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False) + + parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved"]) + parser.add_argument("--n_chunks", default=1, help="number of model chunks", type=eval) + parser.add_argument("--profile", action="store_true", help="Profile the code") + parser.add_argument( + "--nsys", + action="store_true", + help="Use nsys for profiling. 
\ + You should put something like this before colossalai launch: \ + nsys profile -w true -t cuda,cudnn,cublas -s cpu --capture-range=cudaProfilerApi --capture-range-end=stop --cudabacktrace=true -x true --python-backtrace=cuda -o prof_out", + ) + parser.add_argument("--disable-async-reduce", action="store_true", help="Disable the asynchronous reduce operation") + parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number") + parser.add_argument("--no_cache", action="store_true") + parser.add_argument("--use_fp8_comm", action="store_true", default=False, help="for using fp8 during communication") + parser.add_argument("--use_fp8", action="store_true", default=False, help="for using fp8 linear") + parser.add_argument("--overlap_allgather", action="store_true") + parser.add_argument( + "--sp_mode", + default="all_to_all", + choices=["all_to_all"], + help="Sequence parallelism mode", + ) + parser.add_argument("--debug", action="store_true", help="Enable debug mode") + parser.add_argument("--model_path", type=str, default=None, help="the path of model and config") + args = parser.parse_args() + + colossalai.launch_from_torch() + coordinator = DistCoordinator() + coordinator.print_on_master(f"args:{args}") + # ckpt config for LLaMA3-70B on 64 H100 GPUs + hybrid_kwargs = ( + { + "gradient_checkpoint_config": PipelineGradientCheckpointConfig( + num_ckpt_layers_per_stage=[19, 19, 19, 13], + ), + "num_layers_per_stage": [19, 20, 20, 21], + "pp_style": "interleaved", + } + if args.custom_ckpt + else {} + ) + + # ============================== + # Initialize Booster + # ============================== + if args.plugin == "3d": + plugin = MoeHybridParallelPlugin( + ep_size=args.ep, + tp_size=args.tp, + pp_size=args.pp, + pp_style=args.pp_style, + num_model_chunks=args.n_chunks, + zero_stage=args.zero, + sp_size=args.sp, + sequence_parallelism_mode=args.sp_mode, + enable_sequence_parallelism=args.sp > 1, + enable_fused_normalization=torch.cuda.is_available(), + enable_flash_attention=args.xformers, + microbatch_size=args.mbs, + precision="bf16", + enable_metadata_cache=not args.no_cache, + overlap_allgather=args.overlap_allgather, + use_fp8=args.use_fp8, + fp8_communication=args.use_fp8_comm, + **hybrid_kwargs, + ) + else: + raise ValueError(f"Unknown plugin {args.plugin}") + + booster = Booster(plugin=plugin) + + # ============================== + # Initialize Dataset and Dataloader + # ============================== + dp_size = getattr(plugin, "dp_size", coordinator.world_size) + + config = MODEL_CONFIGS[args.config](args.model_path) + + torch.cuda.manual_seed(42) + + dataset = RandomDataset( + num_samples=args.batch_size * args.num_steps * dp_size, max_length=args.max_length, vocab_size=config.vocab_size + ) + dataloader = plugin.prepare_dataloader(dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, seed=42) + + # ============================== + # Initialize Model and Optimizer + # ============================== + init_ctx = ( + LazyInitContext(default_device=get_accelerator().get_current_device()) + if isinstance(plugin, MoeHybridParallelPlugin) + else nullcontext() + ) + + with init_ctx: + model = AutoModelForCausalLM.from_config(config, trust_remote_code=True).to(torch.bfloat16) + + if args.grad_checkpoint: + model.gradient_checkpointing_enable() + + model_numel = get_model_numel(model) + coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}") + performance_evaluator = PerformanceEvaluator( + model_numel, + 
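+        # The layer count, hidden size and vocab size below feed the Megatron-style
+        # FLOPs estimate in performance_evaluator.py.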
model.config.num_hidden_layers, + model.config.hidden_size, + model.config.vocab_size, + args.grad_checkpoint, + args.ignore_steps, + dp_world_size=dp_size, + ) + + optimizer = HybridAdam(model.parameters()) + torch.set_default_dtype(torch.bfloat16) + model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader) + + torch.set_default_dtype(torch.float) + coordinator.print_on_master( + f"Booster init max CUDA memory: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB" + ) + coordinator.print_on_master( + f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB" + ) + + with get_profile_context( + args.profile, + args.ignore_steps, + 1, # avoid creating massive log files + save_dir=f"profile/{time.strftime('%H:%M', time.localtime())}-{args.plugin}-llama-{args.config}", + nsys=args.nsys, + ) as prof: # , distributed_debug_mode(10, enable=True): + if isinstance(plugin, MoeHybridParallelPlugin) and args.pp > 1: + data_iter = iter(dataloader) + for step in tqdm(range(len(dataloader)), desc="Step", disable=not coordinator.is_master()): + performance_evaluator.on_step_start(step) + outputs = booster.execute_pipeline( + data_iter, + model, + criterion=lambda outputs, inputs: outputs[0], + optimizer=optimizer, + return_loss=True, + ) + loss = outputs["loss"] + if dist.get_rank() == dist.get_world_size() - 1: + print(f"Step {step} loss: {loss}") + optimizer.step() + optimizer.zero_grad() + + performance_evaluator.on_step_end(input_ids=torch.empty(args.batch_size, args.max_length)) + prof.step() + # print(f"rank {dist.get_rank()} step {step} passed") + else: + for step, batch in enumerate(tqdm(dataloader, desc="Step", disable=not coordinator.is_master())): + performance_evaluator.on_step_start(step) + outputs = model(**batch) + loss = outputs[0] + del outputs # free memory + + if dist.get_rank() == dist.get_world_size() - 1: + print(f"Step {step} loss: {loss}") + + booster.backward(loss, optimizer) + optimizer.step() + optimizer.zero_grad() + + performance_evaluator.on_step_end(**batch) + prof.step() + + performance_evaluator.on_fit_end() + coordinator.print_on_master(f"Max CUDA memory usage: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB") + + +if __name__ == "__main__": + main() diff --git a/nlp/llm/deepseek_moe_7b/ColossalAI/data_utils.py b/nlp/llm/deepseek_moe_7b/ColossalAI/data_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6b9e8ef28eb7f18468ca6949743032b7c239a4b0 --- /dev/null +++ b/nlp/llm/deepseek_moe_7b/ColossalAI/data_utils.py @@ -0,0 +1,124 @@ +import json +import random +from typing import Iterator, Optional + +import numpy as np +import torch +from torch.distributed import ProcessGroup +from torch.distributed.distributed_c10d import _get_default_group +from torch.utils.data import DataLoader, Dataset, DistributedSampler + +from colossalai.accelerator import get_accelerator + + +class StatefulDistributedSampler(DistributedSampler): + def __init__( + self, + dataset: Dataset, + num_replicas: Optional[int] = None, + rank: Optional[int] = None, + shuffle: bool = True, + seed: int = 0, + drop_last: bool = False, + ) -> None: + super().__init__(dataset, num_replicas, rank, shuffle, seed, drop_last) + self.start_index: int = 0 + + def __iter__(self) -> Iterator: + iterator = super().__iter__() + indices = list(iterator) + indices = indices[self.start_index :] + return iter(indices) + + def __len__(self) -> int: + return self.num_samples - self.start_index + + def 
set_start_index(self, start_index: int) -> None: + self.start_index = start_index + + +def prepare_dataloader( + dataset, + batch_size, + shuffle=False, + seed=1024, + drop_last=False, + pin_memory=False, + num_workers=0, + process_group: Optional[ProcessGroup] = None, + **kwargs, +): + r""" + Prepare a dataloader for distributed training. The dataloader will be wrapped by + `torch.utils.data.DataLoader` and `StatefulDistributedSampler`. + + + Args: + dataset (`torch.utils.data.Dataset`): The dataset to be loaded. + shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False. + seed (int, optional): Random worker seed for sampling, defaults to 1024. + add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True. + drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size + is not divisible by the batch size. If False and the size of dataset is not divisible by + the batch size, then the last batch will be smaller, defaults to False. + pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False. + num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0. + kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in + `DataLoader `_. + + Returns: + :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing. + """ + _kwargs = kwargs.copy() + process_group = process_group or _get_default_group() + sampler = StatefulDistributedSampler( + dataset, num_replicas=process_group.size(), rank=process_group.rank(), shuffle=shuffle + ) + + # Deterministic dataloader + def seed_worker(worker_id): + worker_seed = seed + np.random.seed(worker_seed) + torch.manual_seed(worker_seed) + random.seed(worker_seed) + + return DataLoader( + dataset, + batch_size=batch_size, + sampler=sampler, + worker_init_fn=seed_worker, + drop_last=drop_last, + pin_memory=pin_memory, + num_workers=num_workers, + **_kwargs, + ) + + +def load_json(file_path: str): + with open(file_path, "r") as f: + return json.load(f) + + +def save_json(data, file_path: str): + with open(file_path, "w") as f: + json.dump(data, f, indent=4) + + +class RandomDataset(Dataset): + def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000): + self.num_samples = num_samples + self.max_length = max_length + self.input_ids = torch.randint( + 0, vocab_size, (num_samples, max_length), device=get_accelerator().get_current_device() + ) + self.attention_mask = torch.ones_like(self.input_ids) + + def __len__(self): + return self.num_samples + + def __getitem__(self, idx): + return { + "input_ids": self.input_ids[idx], + "attention_mask": self.attention_mask[idx], + "labels": self.input_ids[idx], + } diff --git a/nlp/llm/deepseek_moe_7b/ColossalAI/deepseek_moe_7b_pretrain.sh b/nlp/llm/deepseek_moe_7b/ColossalAI/deepseek_moe_7b_pretrain.sh new file mode 100644 index 0000000000000000000000000000000000000000..b4ea5120af8e9e1c0375c137dcbf942feff5d787 --- /dev/null +++ b/nlp/llm/deepseek_moe_7b/ColossalAI/deepseek_moe_7b_pretrain.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. 
You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +colossalai run --nproc_per_node 16 benchmark.py -c 7b -g -b 16 --tp 1 --pp 4 --num_steps 50 --model_path /home/model_zoos/nlp/deepseek-moe-16b-base diff --git a/nlp/llm/deepseek_moe_7b/ColossalAI/model_utils.py b/nlp/llm/deepseek_moe_7b/ColossalAI/model_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..63569bc61143b9abbba424ea312359c8ce85bbca --- /dev/null +++ b/nlp/llm/deepseek_moe_7b/ColossalAI/model_utils.py @@ -0,0 +1,32 @@ +from contextlib import contextmanager + +import torch +import torch.nn as nn + + +@contextmanager +def low_precision_init(target_dtype: torch.dtype = torch.float16): + dtype = torch.get_default_dtype() + try: + torch.set_default_dtype(target_dtype) + yield + finally: + torch.set_default_dtype(dtype) + + +def get_model_numel(model: nn.Module) -> int: + return sum(p.numel() for p in model.parameters()) + + +def format_numel_str(numel: int) -> str: + B = 1024**3 + M = 1024**2 + K = 1024 + if numel >= B: + return f"{numel / B:.2f} B" + elif numel >= M: + return f"{numel / M:.2f} M" + elif numel >= K: + return f"{numel / K:.2f} K" + else: + return f"{numel}" diff --git a/nlp/llm/deepseek_moe_7b/ColossalAI/performance_evaluator.py b/nlp/llm/deepseek_moe_7b/ColossalAI/performance_evaluator.py new file mode 100644 index 0000000000000000000000000000000000000000..65c7e49a2f03b7b7ae1c8d79e0efad24b836c1e9 --- /dev/null +++ b/nlp/llm/deepseek_moe_7b/ColossalAI/performance_evaluator.py @@ -0,0 +1,172 @@ +from time import time +from typing import Optional + +import torch +import torch.distributed as dist +from torch import Tensor +from torch.profiler import ProfilerActivity, profile, schedule, tensorboard_trace_handler + +from colossalai.cluster import DistCoordinator + + +def divide(x: float, y: float) -> float: + if y == 0: + return float("inf") + elif y == float("inf"): + return float("nan") + return x / y + + +@torch.no_grad() +def all_reduce_mean(x: float, world_size: int) -> float: + if world_size == 1: + return x + + # Use CPU tensor to avoid OOM/weird NCCl error + gloo_group = dist.new_group(backend="gloo") + tensor = torch.tensor([x], device="cpu") + dist.all_reduce(tensor, group=gloo_group) + tensor = tensor / world_size + return tensor.item() + + +def get_profile_context(enable_flag, warmup_steps, active_steps, save_dir, nsys=False): + class DummyProfiler: + def __init__(self): + self.step_number = 0 + + def step(self): + self.step_number += 1 + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + pass + + class NsysProfiler: + def __init__(self, warmup_steps, active_steps): + self.step_number = 0 + self.warmup_steps = warmup_steps + self.active_steps = active_steps + + def step(self): + if self.step_number == self.warmup_steps: + torch.cuda.cudart().cudaProfilerStart() + elif self.step_number == self.warmup_steps + self.active_steps: + torch.cuda.cudart().cudaProfilerStop() + self.step_number += 1 + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + pass + + if enable_flag: + if nsys: + return NsysProfiler(warmup_steps, active_steps) + + 
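+        # Otherwise fall back to the torch profiler: warm up for `warmup_steps` steps,
+        # record `active_steps` steps, and write a TensorBoard trace (with shapes,
+        # memory and stack information) to `save_dir`.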
+        return profile(
+            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
+            schedule=schedule(wait=0, warmup=warmup_steps, active=active_steps),
+            on_trace_ready=tensorboard_trace_handler(save_dir),
+            record_shapes=True,
+            profile_memory=True,
+            with_stack=True,
+        )
+    else:
+        return DummyProfiler()
+
+
+class Timer:
+    def __init__(self) -> None:
+        self.start_time: Optional[float] = None
+        self.duration: float = 0.0
+
+    def start(self) -> None:
+        self.start_time = time()
+
+    def end(self) -> None:
+        assert self.start_time is not None
+        self.duration += time() - self.start_time
+        self.start_time = None
+
+    def reset(self) -> None:
+        self.duration = 0.0
+
+
+class PerformanceEvaluator:
+    """
+    Callback for evaluating the performance of the model.
+    Args:
+        model_numel: The number of parameters of the model.
+        num_layers: The number of transformer layers, used for the Megatron-style FLOPs estimate.
+        hidden_size: The hidden size of the model.
+        vocab_size: The vocabulary size of the model.
+        enable_grad_checkpoint: Whether gradient checkpointing is enabled.
+        ignore_steps: The number of warm-up steps to ignore when measuring performance.
+        dp_world_size: The data-parallel world size; defaults to the global world size.
+    """
+
+    def __init__(
+        self,
+        model_numel: int,
+        num_layers: int,
+        hidden_size: int,
+        vocab_size: int,
+        enable_grad_checkpoint: bool = False,
+        ignore_steps: int = 0,
+        dp_world_size: Optional[int] = None,
+    ) -> None:
+        self.model_numel = model_numel
+        self.enable_grad_checkpoint = enable_grad_checkpoint
+        self.ignore_steps = ignore_steps
+        self.num_layers = num_layers
+        self.hidden_size = hidden_size
+        self.vocab_size = vocab_size
+
+        self.coordinator = DistCoordinator()
+        self.dp_world_size = dp_world_size or self.coordinator.world_size
+        self.disable: bool = False
+        self.timer = Timer()
+        self.num_samples: int = 0
+        self.flop_megatron = 0
+        self.flop: int = 0
+
+    def on_step_start(self, step: int) -> None:
+        self.disable = self.ignore_steps > 0 and step < self.ignore_steps
+        if self.disable:
+            return
+        # get_accelerator().synchronize()
+        self.timer.start()
+
+    def on_step_end(self, input_ids: Tensor, **kwargs) -> None:
+        if self.disable:
+            return
+        # get_accelerator().synchronize()
+        self.timer.end()
+
+        batch_size, seq_len = input_ids.shape
+
+        self.num_samples += batch_size
+        checkpoint_activations_factor = 3 + int(self.enable_grad_checkpoint)
+        self.flop_megatron += (
+            24 * checkpoint_activations_factor * batch_size * seq_len * self.num_layers * (self.hidden_size**2)
+        ) * (
+            1.0 + (seq_len / (6.0 * self.hidden_size)) + (self.vocab_size / (16.0 * self.num_layers * self.hidden_size))
+        )
+        self.flop += batch_size * seq_len * self.model_numel * 2 * (3 + int(self.enable_grad_checkpoint))
+
+    def on_fit_end(self) -> None:
+        avg_duration = all_reduce_mean(self.timer.duration, self.coordinator.world_size)
+        avg_throughput = self.num_samples * self.dp_world_size / (avg_duration + 1e-12)
+        mp_world_size = self.coordinator.world_size // self.dp_world_size
+        avg_tflops_per_gpu_megatron = self.flop_megatron / 1e12 / (avg_duration + 1e-12) / mp_world_size
+        avg_tflops_per_gpu = self.flop / 1e12 / (avg_duration + 1e-12) / mp_world_size
+        self.coordinator.print_on_master(
+            f"num_samples: {self.num_samples}, dp_world_size: {self.dp_world_size}, flop_megatron: {self.flop_megatron}, flop: {self.flop}, avg_duration: {avg_duration}, "
+            f"avg_throughput: {avg_throughput}"
+        )
+        self.coordinator.print_on_master(
+            f"Throughput: {avg_throughput:.2f} samples/sec, TFLOPS per GPU by Megatron: {avg_tflops_per_gpu_megatron:.2f}, TFLOPS per GPU: {avg_tflops_per_gpu:.2f}"
+        )