xqy2006/LLM-Pruner

test_speedup.py 4.31 KB
horseee committed on 2023-07-11 16:47 +08:00: Add input constructor for testing MACs on GPU
import os
import sys
import argparse

import torch
import numpy as np
import transformers
from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer
from transformers.activations import SiLUActivation

from ptflops import get_model_complexity_info
from ptflops.pytorch_ops import bn_flops_counter_hook, pool_flops_counter_hook

# This import shadows the `LlamaForCausalLM` imported from transformers above,
# so the pruner's modified LLaMA implementation is the one actually used below.
from LLMPruner.models.hf_llama.modeling_llama import LlamaForCausalLM, LlamaRMSNorm, LlamaAttention, LlamaMLP
from LLMPruner.peft import PeftModel
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Minor version of the installed torch build (e.g. 13 for "1.13.1");
# used below to decide whether `low_cpu_mem_usage` is supported.
torch_version = int(torch.__version__.split('.')[1])
def LlamaAttention_counter_hook(module, input, output):
    # (1) Ignore past key values.
    # (2) Assume there is no attention mask.
    # `input` can be empty in some PyTorch versions, so use `output` here,
    # since input.shape == output.shape for this block.
    flops = 0
    q_len = output[0].shape[1]
    linear_dim = output[0].shape[-1]
    num_heads = module.num_heads
    head_dim = module.head_dim

    rotary_flops = 2 * (q_len * num_heads * head_dim) * 2
    attention_flops = num_heads * (q_len * q_len * head_dim + q_len * q_len + q_len * q_len * head_dim)  # QK^T + softmax + attention*V
    linear_flops = 4 * (q_len * linear_dim * linear_dim)  # 4 projections: q, k, v, o
    flops += rotary_flops + attention_flops + linear_flops
    module.__flops__ += int(flops)
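
# Illustrative back-of-the-envelope check for the hook above, assuming a
# hypothetical LLaMA-7B-sized attention layer (num_heads=32, head_dim=128,
# linear_dim=4096) and the q_len=64 that this script uses:
#   rotary_flops    = 2 * (64 * 32 * 128) * 2               ~ 1.05e6
#   attention_flops = 32 * (64*64*128 + 64*64 + 64*64*128)  ~ 3.37e7
#   linear_flops    = 4 * (64 * 4096 * 4096)                ~ 4.29e9
# At this short sequence length the q/k/v/o projections dominate the count.
# These numbers are illustrative only, not measured output of the script.
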
def rmsnorm_flops_counter_hook(module, input, output):
    input = input[0]
    # Approximate RMSNorm as two FLOPs per input element.
    batch_flops = np.prod(input.shape)
    batch_flops *= 2
    module.__flops__ += int(batch_flops)
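
# For example, a hidden-state tensor of shape (1, 64, 4096) -- assuming a
# LLaMA-7B-sized hidden dimension at this script's sequence length -- would be
# counted as 1 * 64 * 4096 * 2 = 524,288 FLOPs per RMSNorm call.
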
def main(args):
    if args.model_type == 'pretrain':
        tokenizer = LlamaTokenizer.from_pretrained(args.base_model)
        model = LlamaForCausalLM.from_pretrained(
            args.base_model,
            low_cpu_mem_usage=True if torch_version >= 9 else False
        )
    elif args.model_type == 'pruneLLM':
        pruned_dict = torch.load(args.ckpt, map_location='cpu')
        tokenizer, model = pruned_dict['tokenizer'], pruned_dict['model']
    else:
        raise NotImplementedError

    def input_constructor(x):
        return {'input_ids': torch.ones(x).long().to(device)}
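
    # ptflops calls input_constructor((1, 64)) and unpacks the returned dict as
    # keyword arguments, i.e. the model runs as model(input_ids=<1x64 LongTensor of ones>).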
    if device == "cuda":
        model.half()
        model = model.cuda()

        with torch.cuda.device(0):
            macs, params = get_model_complexity_info(
                model, (1, 64,),
                as_strings=True,
                input_constructor=input_constructor,
                print_per_layer_stat=True,
                verbose=True,
                custom_modules_hooks={
                    LlamaAttention: LlamaAttention_counter_hook,
                    LlamaRMSNorm: rmsnorm_flops_counter_hook,
                    SiLUActivation: pool_flops_counter_hook,
                },
            )
    else:
        model.float()
        macs, params = get_model_complexity_info(
            model, (1, 64,),
            as_strings=True,
            input_constructor=input_constructor,
            print_per_layer_stat=True,
            verbose=True,
            custom_modules_hooks={
                LlamaAttention: LlamaAttention_counter_hook,
                LlamaRMSNorm: rmsnorm_flops_counter_hook,
                SiLUActivation: pool_flops_counter_hook,
            },
        )
    print('{:<30} {:<8}'.format('Computational complexity: ', macs))
    print('{:<30} {:<8}'.format('Number of parameters: ', params))
    print("GPU Memory Requirement: {} MiB\n".format(torch.cuda.memory_allocated() / 1024 / 1024))
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Tuning Pruned LLaMA (huggingface version)')
    parser.add_argument('--base_model', type=str, default="decapoda-research/llama-7b-hf", help='base model name')
    parser.add_argument('--model_type', type=str, required=True, help='choose from: pretrain, pruneLLM')
    parser.add_argument('--ckpt', type=str, default=None)
    parser.add_argument('--lora_ckpt', type=str, default=None)
    args = parser.parse_args()
    main(args)
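
# Example invocations (a sketch, not taken from the repository README; the pruned
# checkpoint path is a placeholder to replace with your own LLM-Pruner output):
#   python test_speedup.py --model_type pretrain --base_model decapoda-research/llama-7b-hf
#   python test_speedup.py --model_type pruneLLM --ckpt <path/to/pruned_checkpoint.bin>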