diff --git a/python/level1_single_api/9_amct/amct_pytorch/fp4_weight_quantization/README_CN.md b/python/level1_single_api/9_amct/amct_pytorch/fp4_weight_quantization/README_CN.md
new file mode 100644
index 0000000000000000000000000000000000000000..93ea0a9ce1c77f4214d90da5539b97f16fc57a9b
--- /dev/null
+++ b/python/level1_single_api/9_amct/amct_pytorch/fp4_weight_quantization/README_CN.md
@@ -0,0 +1,50 @@
+# FP4 Fake Quantization
+
+## 1 FP4 Fake Quantization
+
+### 1.1 Installing Dependencies
+
+The packages this sample depends on are listed in [requirements.txt](requirements.txt).
+
+### 1.2 Preparing the Model and Datasets
+
+This sample uses the Llama2-7b model and the pileval and wikitext2 datasets as examples; please download them yourself.
+
+### 1.3 Simplified Quantization Configuration
+
+The file ./src/quantization.cfg is the user-defined simplified quantization configuration. Its fields are described below:
+
+| Field |Type| Description | Default | Range |
+|:--| :-: | :-- | :-: | :-: |
+|skip_layers|str|Layers to skip during quantization|/|/|
+|weight_only_config.weight_compress_only|bool|Whether to perform weight-only quantization|False|True/False|
+|weight_only_config.wts_type|enum|Weight type after quantization|INT8|INT8/FLOAT4_E2M1/MXFP4_E2M1/HIFLOAT8/FLOAT8_E4M3FN|
+|weight_only_config.awq_quantize.grids_num|uint32|Number of AWQ search grid points|20|/|
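+
+For reference, the configuration shipped with this sample (./src/quantization.cfg) sets these fields as follows:
+
+```
+skip_layers: "lm_head"
+weight_only_config: {
+  weight_compress_only: True
+  wts_type: FLOAT4_E2M1
+  awq_quantize:{
+    grids_num: 20
+  }
+}
+```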
+
+## 2 FLOAT4_E2M1 Quantization Example
+> In the current quantization.cfg file, the value of weight_only_config.wts_type is set to FLOAT4_E2M1.
+
+### 2.1 Invoking via the API
+
+Run the sample program by executing the following command in the current directory.
+
+Script for verifying the fakequant model:
+
+`CUDA_VISIBLE_DEVICES=0,1,2,3,4,5 python3 src/run_llama7b_quantization.py --calibration_data=/pile_val_backup/ --verify_data=/data/Datasets/wikitext/wikitext-2-raw-v1/wikitext-2-raw/wikiscript.py --model=/data/Models/pytorch/Llama2/Llama2_7b_hf`
+
+If output like the following appears, quantization succeeded:
+
+```none
+Test time taken: 9.0 min 38.24865388870239 s
+Score: 5.657759
+```
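+
+The reported `Score` is the perplexity (PPL) of the fake-quantized model on the wikitext2 test data: the evaluation phase of src/run_llama7b_quantization.py accumulates the negative log-likelihood over all evaluated blocks and reports `exp(sum(nll) / (nsamples * seqlen))`, so a lower score means better accuracy retention.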
+
+After inference succeeds, the quantization log file ./amct_log/amct_pytorch.log and the ./output folder are generated in the current directory. The folder contains the following files:
+
+- config.json: quantization configuration file describing how each layer of the model is quantized.
+- record.txt: quantization factor record file.
+- awq_result.pt: stores the scale and clip values produced by the AWQ algorithm.
+- quant_factor.pt: stores the quantization scale factors.
+
+> If a quantization configuration file or quantization factor record file already exists in the ./output directory, rerunning the sample program will overwrite the existing file whenever a newly generated file has the same name.
diff --git a/python/level1_single_api/9_amct/amct_pytorch/fp4_weight_quantization/requirements.txt b/python/level1_single_api/9_amct/amct_pytorch/fp4_weight_quantization/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..55441d06294298b274659bf1ae738f73e985c93f
--- /dev/null
+++ b/python/level1_single_api/9_amct/amct_pytorch/fp4_weight_quantization/requirements.txt
@@ -0,0 +1,7 @@
+torch==2.1.0
+transformers==4.40.0
+accelerate==0.30.1
+datasets==2.19.1
+sentencepiece==0.2.0
+numpy==1.23.5
+protobuf==3.20.2
\ No newline at end of file
diff --git a/python/level1_single_api/9_amct/amct_pytorch/fp4_weight_quantization/src/quantization.cfg b/python/level1_single_api/9_amct/amct_pytorch/fp4_weight_quantization/src/quantization.cfg
new file mode 100644
index 0000000000000000000000000000000000000000..a43152ad3f761e4e03c3e8028955ac8daee102b8
--- /dev/null
+++ b/python/level1_single_api/9_amct/amct_pytorch/fp4_weight_quantization/src/quantization.cfg
@@ -0,0 +1,8 @@
+skip_layers: "lm_head"
+weight_only_config: {
+  weight_compress_only: True
+  wts_type: FLOAT4_E2M1
+  awq_quantize:{
+    grids_num: 20
+  }
+}
\ No newline at end of file
diff --git a/python/level1_single_api/9_amct/amct_pytorch/fp4_weight_quantization/src/run_llama7b_quantization.py b/python/level1_single_api/9_amct/amct_pytorch/fp4_weight_quantization/src/run_llama7b_quantization.py
new file mode 100644
index 0000000000000000000000000000000000000000..4aac4fad999ffa617278ba0439cabb409534da90
--- /dev/null
+++ b/python/level1_single_api/9_amct/amct_pytorch/fp4_weight_quantization/src/run_llama7b_quantization.py
@@ -0,0 +1,162 @@
+"""
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+import argparse
+import os
+import copy
+import time
+import tqdm
+import torch
+import torch.nn as nn
+from transformers import AutoTokenizer, AutoConfig
+from accelerate import infer_auto_device_map, dispatch_model
+from accelerate.utils.modeling import get_balanced_memory
+
+from utils import get_loaders, get_llama2, get_calib_dataset
+import amct_pytorch as amct
+
+
+def build_model_and_enc(model, model_path, gpu_num):
+    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
+    if "mpt" in config.__class__.__name__.lower():
+        enc = AutoTokenizer.from_pretrained(
+            config.tokenizer_name, trust_remote_code=True
+        )
+    else:
+        enc = AutoTokenizer.from_pretrained(
+            model_path, use_fast=False, trust_remote_code=True
+        )
+
+    # Move the model to GPU (as much as possible) for LM evaluation.
+    # max_memory looks like ['0:16GiB', '1:16GiB', '2:16GiB', 'cpu:30GiB'], where '0' means the
+    # first GPU that you specify. Using the full 16GiB is not recommended: some space must be
+    # reserved for other tensors created during calculation, so adjust the max size according to
+    # the real situation. A convenient way to build the allocation programmatically:
+    max_memory = []
+    for i in range(gpu_num):
+        max_memory.append(f'{i}:12GiB')
+    max_memory.append('cpu:80GiB')
+    print('Max_memory allocation: \n', max_memory)
+
+    max_memory = [v.split(":") for v in (max_memory or [])]
+    max_memory = {(int(k) if k.isdigit() else k): v for k, v in max_memory}
+    kwargs = {
+        "max_memory": get_balanced_memory(
+            model, max_memory if len(max_memory) > 0 else None
+        )
+    }
+    model.tie_weights()
+    device_map = infer_auto_device_map(
+        model,
+        no_split_module_classes=[
+            "LlamaDecoderLayer",
+        ],
+        **kwargs,
+    )
+    model = dispatch_model(model, device_map=device_map,
+                           offload_dir=os.path.join(model_path, 'offload_dir'))
+
+    return model, enc
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--calibration_data', type=str, default='/pile_val_backup')
+    parser.add_argument('--verify_data', type=str, default='/data/Datasets/wikitext/wikitext-2-raw-v1/wikitext-2-raw/wikiscript.py')
+    parser.add_argument('--model', type=str, default='/data/Models/pytorch/Llama2/Llama2_7b_hf')
+
+    args = parser.parse_args()
+    model, model_path = get_llama2(args.model)
+    model = model.eval()
+    copied_model = copy.deepcopy(model)
+    gpu_num = torch.cuda.device_count()
+    model, enc = build_model_and_enc(model, model_path, gpu_num)
+
+    proto_path = './src/quantization.cfg'
+    config_file = './output/config.json'
+    record_file = './output/record.txt'
+
+    test_start_time = time.time()
+    # Phase1: generate the quantization config json
+    amct.create_post_quant_config(config_file,
+                                  model,
+                                  config_defination=proto_path)
+
+    # Phase2: calibrate the weights and generate the calibration model
+    samples = get_calib_dataset(
+        data_path=args.calibration_data, tokenizer=enc, n_samples=512, block_size=518
+    )
+    samples = torch.cat(samples, dim=0)[:1, :]
+
+    post_quant_model = amct.create_post_quant_model(config_file,
+                                                    record_file,
+                                                    model)
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+
+    with torch.no_grad():
+        post_quant_model(samples.to(next(post_quant_model.parameters()).device))
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+    test_end_time = time.time()
+    total_time = test_end_time - test_start_time
+    print('Calibration time taken: ', total_time // 60, 'min ', total_time % 60, 's')
+    # Save memory: delete the model that is no longer used.
+    del post_quant_model
+
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+    model, enc = build_model_and_enc(copied_model, model_path, gpu_num)
+
+    # Phase3: save the fakequant model
+    testenc = get_loaders(data_path=args.verify_data,
+                          enc=enc,
+                          seqlen=model.seqlen)
+
+    testenc = testenc.input_ids.to(model.device)
+
+    quant_model = amct.save_post_quant_model(record_file, model, mode='fakequant')
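+    # Note: in 'fakequant' mode the saved model is expected to keep floating-point weights that
+    # have been quantized and then dequantized (i.e. "fake" quantization), so it can be evaluated
+    # directly on GPU to estimate the accuracy impact of FP4 weight quantization.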
+""" + +import torch +import torch.nn as nn +from datasets import load_dataset,load_from_disk + +def get_llama2(model_path, seqlen=2048): + def skip(*args, **kwargs): + pass + + torch.nn.init.kaiming_uniform_ = skip + torch.nn.init.uniform_ = skip + torch.nn.init.normal_ = skip + from transformers import LlamaForCausalLM + + model = LlamaForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, offload_folder="offload/") + + model.seqlen = seqlen + return model, model_path + + +def get_loaders(data_path: str, enc, seqlen): + + print('Loading dataset: Wikitext2') + testenc = load_dataset(data_path, 'wikitext-2-raw-v1', split='test', trust_remote_code=True) + testenc = enc("\n\n".join(testenc["text"]), return_tensors="pt") + + return testenc + + +def get_calib_dataset(data_path, tokenizer=None, n_samples=512, block_size=512): + dataset = load_from_disk(data_path) + dataset = dataset.shuffle(seed=42) + samples = [] + n_run = 0 + for data in dataset: + line = data["text"] + line = line.strip() + line_encoded = tokenizer.encode(line) + if len(line_encoded) > 512: + continue + sample = torch.tensor([line_encoded]) + if sample.numel() == 0: + continue + samples.append(sample) + n_run += 1 + if n_run == n_samples: + break + # now concatenate all samples and split according to block size + cat_samples = torch.cat(samples, dim=1) + n_split = cat_samples.shape[1] // block_size + print(f" * Split into {n_split} blocks") + return [ + cat_samples[:, i * block_size : (i + 1) * block_size] for i in range(n_split) + ]