From 9e87ab46bae378910e13299deff0bcdbc3415337 Mon Sep 17 00:00:00 2001
From: wangyuqing
Date: Wed, 28 May 2025 14:55:52 +0800
Subject: [PATCH 1/3] add amct lut4 llama7b quantization sample

---
 .../lut4_quantization/README_CN.md            |  54 ++++++++
 .../lut4_quantization/requirements.txt        |   7 +
 .../lut4_quantization/src/lut4_quant.cfg      |   9 ++
 .../src/run_llama7b_calibration.py            |  64 +++++++++
 .../src/save_llama7b_quant_model.py           |  76 ++++++++++
 .../lut4_quantization/src/utils.py            | 131 ++++++++++++++++++
 6 files changed, 341 insertions(+)
 create mode 100644 python/level1_single_api/9_amct/amct_pytorch/lut4_quantization/README_CN.md
 create mode 100644 python/level1_single_api/9_amct/amct_pytorch/lut4_quantization/requirements.txt
 create mode 100644 python/level1_single_api/9_amct/amct_pytorch/lut4_quantization/src/lut4_quant.cfg
 create mode 100644 python/level1_single_api/9_amct/amct_pytorch/lut4_quantization/src/run_llama7b_calibration.py
 create mode 100644 python/level1_single_api/9_amct/amct_pytorch/lut4_quantization/src/save_llama7b_quant_model.py
 create mode 100644 python/level1_single_api/9_amct/amct_pytorch/lut4_quantization/src/utils.py

diff --git a/python/level1_single_api/9_amct/amct_pytorch/lut4_quantization/README_CN.md b/python/level1_single_api/9_amct/amct_pytorch/lut4_quantization/README_CN.md
new file mode 100644
index 000000000..f9e8a13c1
--- /dev/null
+++ b/python/level1_single_api/9_amct/amct_pytorch/lut4_quantization/README_CN.md
@@ -0,0 +1,54 @@
+# LUT4 Quantization
+
+## 1 Prerequisites for LUT4 Quantization
+
+### 1.1 Install Dependencies
+
+The packages this sample depends on are listed in [requirements.txt](requirements.txt).
+
+### 1.2 Prepare the Model and Datasets
+
+This sample uses the Llama2-7b model together with the pileval and wikitext2 datasets. Please download them yourself and adapt the file paths in the data-loading functions (get_loaders and get_calib_dataset) in utils.py. The dataset directories in this sample must be changed to the directories where your data is actually stored.
+
+### 1.3 Simple Quantization Configuration
+The file ./src/lut4_quant.cfg is the user-defined simple quantization configuration. Its fields are described below:
+
+| Field | Type | Description | Default | Range | Notes |
+|:--| :-: | :-- | :-: | :-: | :-: |
+|batch_num|uint32|Number of batches used for quantization|1|/|Does not take effect for MXFP quantization; the number of calibration batches depends on the input data used for inference and corresponds to batch_num in the calibration script|
+|skip_layers|str|Layers to skip during quantization|/|/|Fuzzy matching is supported: when the configured string is a substring of a layer name or matches it exactly, that layer is skipped and no quantization configuration is generated for it. The string must contain digits or letters|
+|weight_only_config.weight_compress_only|bool|Whether to quantize weights only|False|True/False|MXFP4 quantization currently supports weight-only quantization, so this must be set to True|
+|weight_only_config.wts_type|enum|Weight type after quantization|INT8|INT8/INT4/MXFP4_E2M1|/|
+|weight_only_config.weight_granularity|enum|Weight quantization granularity|PER_TENSOR|PER_TENSOR/PER_CHANNEL/PER_GROUP|MXFP4_E2M1 only supports PER_GROUP|
+|weight_only_config.round_mode|enum|Rounding mode|/|HYBRID/ROUND/RINT|MXFP4_E2M1 only supports RINT|
+|weight_only_config.lut_quantize.lut_algo|enum|LUT quantization algorithm|CLUSTER|CLUSTER/ATCTAN|/|
+
+## 2 LUT4 Quantization Example
+
+### 2.1 Using the API
+
+**Step 1.** Run the following two commands from the current directory. Modify the model and dataset paths in the sample scripts according to your environment:
+
+`CUDA_VISIBLE_DEVICES=0,1,2,3,4,5 python3 src/run_llama7b_calibration.py`
+`CUDA_VISIBLE_DEVICES=0,1,2,3,4,5 python3 src/save_llama7b_quant_model.py`
+
+If the following message appears, calibration succeeded:
+
+```none
+Calibration time taken: 1.0 min 59.24865388870239 s
+```
+
+If the following messages appear, quantization succeeded:
+
+```none
+Test time taken: 1.0 min 59.24865388870239 s
+Score: 5.670858383178711
+```
+
+After a successful run, the quantization log file ./amct_log/amct_pytorch.log and the ./output folder are generated in the current directory. The folder contains the following files:
+
+- config.json: the quantization configuration file, describing how each layer of the model is quantized.
+- record.txt: the quantization factor record file.
+- lut_result.pt: the LUT algorithm parameter file.
+
+> If a quantization configuration file or quantization factor record file already exists in the output directory, rerunning the sample overwrites the existing file whenever the newly generated file has the same name.
diff --git a/python/level1_single_api/9_amct/amct_pytorch/lut4_quantization/requirements.txt b/python/level1_single_api/9_amct/amct_pytorch/lut4_quantization/requirements.txt
new file mode 100644
index 000000000..55441d062
--- /dev/null
+++ b/python/level1_single_api/9_amct/amct_pytorch/lut4_quantization/requirements.txt
@@ -0,0 +1,7 @@
+torch==2.1.0
+transformers==4.40.0
+accelerate==0.30.1
+datasets==2.19.1
+sentencepiece==0.2.0
+numpy==1.23.5
+protobuf==3.20.2
\ No newline at end of file
diff --git a/python/level1_single_api/9_amct/amct_pytorch/lut4_quantization/src/lut4_quant.cfg b/python/level1_single_api/9_amct/amct_pytorch/lut4_quantization/src/lut4_quant.cfg
new file mode 100644
index 000000000..5b8cd01c6
--- /dev/null
+++ b/python/level1_single_api/9_amct/amct_pytorch/lut4_quantization/src/lut4_quant.cfg
@@ -0,0 +1,9 @@
+batch_num: 4
+skip_layers: "lm_head"
+weight_only_config: {
+    weight_compress_only: True
+    wts_type: INT4
+    lut_quantize : {
+        lut_algo: CLUSTER
+    }
+}
\ No newline at end of file
diff --git a/python/level1_single_api/9_amct/amct_pytorch/lut4_quantization/src/run_llama7b_calibration.py b/python/level1_single_api/9_amct/amct_pytorch/lut4_quantization/src/run_llama7b_calibration.py
new file mode 100644
index 000000000..1072dc945
--- /dev/null
+++ b/python/level1_single_api/9_amct/amct_pytorch/lut4_quantization/src/run_llama7b_calibration.py
@@ -0,0 +1,64 @@
+"""
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+
+import os
+import copy
+import time
+import tqdm
+import torch
+import torch.nn as nn
+
+from utils import get_llama2, get_calib_dataset, build_model_and_enc
+import amct_pytorch as amct
+from amct_pytorch.post_quant_calibration import LLMHelper
+
+
+if __name__ == '__main__':
+    model, model_path = get_llama2('7b', seqlen=512)
+    model = model.eval()
+    copied_model = copy.deepcopy(model)
+    gpu_num = torch.cuda.device_count()
+    model, enc = build_model_and_enc(model, model_path, gpu_num)
+
+    proto_path = './src/lut4_quant.cfg'
+    config_file = './output/config.json'
+    record_file = './output/record.txt'
+
+    test_start_time = time.time()
+    # Phase1: generate quant config json
+    amct.create_post_quant_config(config_file,
+                                  model,
+                                  config_defination=proto_path)
+
+    # Phase2: do weights calibration and generate calibration model
+    samples = get_calib_dataset(
+        data="pileval", tokenizer=enc, n_samples=512, block_size=256
+    )
+    samples = torch.cat(samples, dim=0)[:1,:]
+    model.config.use_cache = False
+    with torch.no_grad():
+        post_quant_model = amct.create_post_quant_model(config_file,
+                                                        record_file,
+                                                        model)
+    calibration_helper = LLMHelper(post_quant_model, samples, calibration_block='LlamaDecoderLayer', layer_filter=True)
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+    post_quant_model.config.use_cache = False
+    amct.quant_calibration(calibration_helper)
+    test_end_time = time.time()
+    total_time = test_end_time - test_start_time
+    print('Calibration time taken: ', total_time // 60, 'min ', total_time%60, 's')
diff --git a/python/level1_single_api/9_amct/amct_pytorch/lut4_quantization/src/save_llama7b_quant_model.py b/python/level1_single_api/9_amct/amct_pytorch/lut4_quantization/src/save_llama7b_quant_model.py
new file mode 100644
index 000000000..d89bdaa99
--- /dev/null
+++ b/python/level1_single_api/9_amct/amct_pytorch/lut4_quantization/src/save_llama7b_quant_model.py
@@ -0,0 +1,76 @@
+"""
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+
+import os
+import copy
+import time
+import tqdm
+import torch
+import torch.nn as nn
+
+from utils import get_loaders, get_llama2, get_calib_dataset, build_model_and_enc
+import amct_pytorch as amct
+
+
+if __name__ == '__main__':
+    model, model_path = get_llama2('7b', seqlen=512)
+    model = model.eval()
+    copied_model = copy.deepcopy(model)
+    gpu_num = torch.cuda.device_count()
+
+    proto_path = './src/lut4_quant.cfg'
+    config_file = './output/config.json'
+    record_file = './output/record.txt'
+
+    test_start_time = time.time()
+    model, enc = build_model_and_enc(copied_model, model_path, gpu_num)
+
+    # Phase1: save fakequant model
+    testenc = get_loaders(dataset_name='wikitext2',
+                          enc=enc,
+                          seqlen=model.seqlen)
+
+    testenc = testenc.input_ids.to(model.device)
+    fake_quant_model = amct.save_post_quant_model(record_file, model, mode='fakequant')
+    nsamples = testenc.numel() // model.seqlen
+
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+
+    # Phase2: Test ppl result
+    nlls = []
+    test_start_time = time.time()
+    for i in tqdm.tqdm(range(nsamples), desc="evaluating..."):
+        batch = testenc[:, (i * model.seqlen) : ((i + 1) * model.seqlen)].to(
+            model.device
+        )
+        with torch.no_grad():
+            lm_logits = fake_quant_model(batch).logits
+        shift_logits = lm_logits[:, :-1, :].contiguous().float()
+        shift_labels = testenc[:, (i * model.seqlen) : ((i + 1) * model.seqlen)][:, 1:]
+        loss_fct = nn.CrossEntropyLoss()
+        loss = loss_fct(
+            shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
+        )
+        neg_log_likelihood = loss.float() * model.seqlen
+        nlls.append(neg_log_likelihood)
+    test_end_time = time.time()
+
+    ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen))
+    total_time = test_end_time - test_start_time
+    print('Test time taken: ', total_time // 60, 'min ', total_time%60, 's' )
+    print('Score: ', ppl.item())
\ No newline at end of file
diff --git a/python/level1_single_api/9_amct/amct_pytorch/lut4_quantization/src/utils.py b/python/level1_single_api/9_amct/amct_pytorch/lut4_quantization/src/utils.py
new file mode 100644
index 000000000..1ff8bd14e
--- /dev/null
+++ b/python/level1_single_api/9_amct/amct_pytorch/lut4_quantization/src/utils.py
@@ -0,0 +1,131 @@
+"""
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+import os
+import torch
+import torch.nn as nn
+from datasets import load_dataset, load_from_disk
+
+from transformers import AutoTokenizer, AutoConfig
+from accelerate import infer_auto_device_map, dispatch_model
+from accelerate.utils.modeling import get_balanced_memory
+
+def build_model_and_enc(model, model_path, gpu_num):
+    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
+    if "mpt" in config.__class__.__name__.lower():
+        enc = AutoTokenizer.from_pretrained(
+            config.tokenizer_name, trust_remote_code=True
+        )
+    else:
+        enc = AutoTokenizer.from_pretrained(
+            model_path, use_fast=False, trust_remote_code=True
+        )
+
+    # Move the model to GPU (as much as possible) for LM evaluation
+    # max_memory = ['0:16GiB', '1:16GiB', '2:16GiB', 'cpu:30GiB'], '0' means the first GPU that you specify.
+    # I don't recommend using 16GiB; we need to reserve some space for other tensors during calculation.
+    # Please see the recommended memory allocation in the Word file.
+    # Adjust max_memory according to the actual situation.
+    # A clever way:
+
+    max_memory = []
+    for i in range(gpu_num):
+        max_memory.append(f'{i}:12GiB')
+    max_memory.append('cpu:80GiB')
+    print('Max_memory allocation: \n', max_memory)
+
+    max_memory = [v.split(":") for v in (max_memory or [])]
+    max_memory = {(int(k) if k.isdigit() else k): v for k, v in max_memory}
+    kwargs = {
+        "max_memory": get_balanced_memory(
+            model, max_memory if len(max_memory) > 0 else None
+        )
+    }
+    model.tie_weights()
+    device_map = infer_auto_device_map(
+        model,
+        no_split_module_classes=[
+            "LlamaDecoderLayer",
+        ],
+        **kwargs,
+    )
+    model = dispatch_model(model, device_map=device_map,
+                           offload_dir=os.path.join(model_path, 'offload_dir'))
+
+    return model, enc
+
+def get_llama2(model, seqlen=2048):
+    '''If model is one of ['7b', '13b', '70b'], the official pretrained checkpoint is loaded;
+    if you want to load a checkpoint other than the official ones, please specify the model path,
+    otherwise please choose from ['7b', '13b', '70b'] for better clarity.
+    '''
+
+    def skip(*args, **kwargs):
+        pass
+
+    if model in ['7b', '13b', '70b']:
+        model_path = f'/data/Models/pytorch/Llama2/Llama2_{model}_hf'
+        print(f'Getting official pretrained Llama2-{model}')
+    else:
+        model_path = model
+    torch.nn.init.kaiming_uniform_ = skip
+    torch.nn.init.uniform_ = skip
+    torch.nn.init.normal_ = skip
+    from transformers import LlamaForCausalLM
+
+    model = LlamaForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, offload_folder="offload/")
+
+    model.seqlen = seqlen
+    return model, model_path
+
+
+def get_loaders(dataset_name: str, enc, seqlen):
+    if dataset_name == 'wikitext2':
+        print('Loading dataset: Wikitext2')
+        testenc = load_dataset('/data/Datasets/wikitext/wikitext-2-raw-v1/wikitext-2-raw/wikiscript.py', 'wikitext-2-raw-v1', split='test', trust_remote_code=True)
+        testenc = enc("\n\n".join(testenc["text"]), return_tensors="pt")
+
+    return testenc
+
+
+def get_calib_dataset(data="pileval", tokenizer=None, n_samples=512, block_size=512):
+    if data == "pileval":
+        dataset = load_from_disk('/pile_val_backup')
+    else:
+        raise NotImplementedError
+    dataset = dataset.shuffle(seed=42)
+    samples = []
+    n_run = 0
+    for data in dataset:
+        line = data["text"]
+        line = line.strip()
+        line_encoded = tokenizer.encode(line)
+        if len(line_encoded) > 512:
+            continue
+        sample = torch.tensor([line_encoded])
+        if sample.numel() == 0:
+            continue
+        samples.append(sample)
+        n_run += 1
+        if n_run == n_samples:
+            break
+    # now concatenate all samples and split according to block size
+    cat_samples = torch.cat(samples, dim=1)
+    n_split = cat_samples.shape[1] // block_size
+    print(f" * Split into {n_split} blocks")
+    return [
+        cat_samples[:, i * block_size : (i + 1) * block_size] for i in range(n_split)
+    ]
-- 
Gitee

From cd953aa83ba3c3894d67cd11cda6077ca18ff5b9 Mon Sep 17 00:00:00 2001
From: wangyuqing
Date: Wed, 28 May 2025 07:47:28 +0000
Subject: [PATCH 2/3] update python/level1_single_api/9_amct/amct_pytorch/lut4_quantization/src/utils.py. fix float16 to float32

Signed-off-by: wangyuqing
---
 .../9_amct/amct_pytorch/lut4_quantization/src/utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/level1_single_api/9_amct/amct_pytorch/lut4_quantization/src/utils.py b/python/level1_single_api/9_amct/amct_pytorch/lut4_quantization/src/utils.py
index 1ff8bd14e..5c235afbe 100644
--- a/python/level1_single_api/9_amct/amct_pytorch/lut4_quantization/src/utils.py
+++ b/python/level1_single_api/9_amct/amct_pytorch/lut4_quantization/src/utils.py
@@ -86,7 +86,7 @@ def get_llama2(model, seqlen=2048):
     torch.nn.init.normal_ = skip
     from transformers import LlamaForCausalLM
 
-    model = LlamaForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, offload_folder="offload/")
+    model = LlamaForCausalLM.from_pretrained(model_path, torch_dtype=torch.float32, offload_folder="offload/")
 
     model.seqlen = seqlen
     return model, model_path
-- 
Gitee

From df9233bff88505ecca0b021f90b608cab92cb071 Mon Sep 17 00:00:00 2001
From: wangyuqing
Date: Wed, 28 May 2025 07:49:36 +0000
Subject: [PATCH 3/3] update python/level1_single_api/9_amct/amct_pytorch/lut4_quantization/src/lut4_quant.cfg. fix: fix code review

Signed-off-by: wangyuqing
---
 .../9_amct/amct_pytorch/lut4_quantization/src/lut4_quant.cfg | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/level1_single_api/9_amct/amct_pytorch/lut4_quantization/src/lut4_quant.cfg b/python/level1_single_api/9_amct/amct_pytorch/lut4_quantization/src/lut4_quant.cfg
index 5b8cd01c6..6f532c21c 100644
--- a/python/level1_single_api/9_amct/amct_pytorch/lut4_quantization/src/lut4_quant.cfg
+++ b/python/level1_single_api/9_amct/amct_pytorch/lut4_quantization/src/lut4_quant.cfg
@@ -1,4 +1,4 @@
-batch_num: 4
+batch_num: 1
 skip_layers: "lm_head"
 weight_only_config: {
     weight_compress_only: True
-- 
Gitee
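The `lut_algo: CLUSTER` setting in lut4_quant.cfg selects a lookup-table (LUT) quantization algorithm, in which each weight tensor is represented by a small codebook of representative values plus low-bit indices into that codebook. The sketch below is only a rough, self-contained illustration of that idea under simplifying assumptions (a plain k-means clustering loop and an invented helper name); it is not AMCT's implementation and is independent of the sample scripts.

```python
import torch

def lut4_quantize_sketch(weight: torch.Tensor, n_iters: int = 20):
    """Toy 4-bit LUT quantization: cluster the weights into 2**4 = 16 centroids
    with a simple k-means loop and return per-weight indices plus the codebook.
    Illustrative only; this is not the CLUSTER algorithm shipped with AMCT."""
    flat = weight.detach().float().flatten()
    # Start with 16 centroids spread evenly over the weight range.
    codebook = torch.linspace(flat.min().item(), flat.max().item(), steps=16)
    for _ in range(n_iters):
        # Assign every weight to its nearest centroid, then recenter each centroid.
        idx = torch.argmin((flat.unsqueeze(1) - codebook.unsqueeze(0)).abs(), dim=1)
        for k in range(16):
            members = flat[idx == k]
            if members.numel() > 0:
                codebook[k] = members.mean()
    idx = torch.argmin((flat.unsqueeze(1) - codebook.unsqueeze(0)).abs(), dim=1)
    dequantized = codebook[idx].reshape(weight.shape)
    return idx.to(torch.uint8).reshape(weight.shape), codebook, dequantized

if __name__ == '__main__':
    w = torch.randn(128, 128)
    indices, codebook, w_hat = lut4_quantize_sketch(w)
    print('mean abs quantization error:', (w - w_hat).abs().mean().item())
```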