From 87e3fb82b1ca9f66dde47cbce49c76ff967d3ed3 Mon Sep 17 00:00:00 2001 From: wangyuqing Date: Wed, 28 May 2025 07:49:36 +0000 Subject: [PATCH] update python/level1_single_api/9_amct/amct_pytorch/lut4_quantization/src/lut4_quant.cfg. fix: fix code review Signed-off-by: wangyuqing --- .../lut4_quantization/README_CN.md | 24 +++++++++-------- .../lut4_quantization/src/lut4_quant.cfg | 2 +- .../src/run_llama7b_calibration.py | 9 +++++-- .../src/save_llama7b_quant_model.py | 9 +++++-- .../lut4_quantization/src/utils.py | 26 +++++-------------- 5 files changed, 34 insertions(+), 36 deletions(-) diff --git a/python/level1_single_api/9_amct/amct_pytorch/lut4_quantization/README_CN.md b/python/level1_single_api/9_amct/amct_pytorch/lut4_quantization/README_CN.md index f9e8a13c1..f6cf738d3 100644 --- a/python/level1_single_api/9_amct/amct_pytorch/lut4_quantization/README_CN.md +++ b/python/level1_single_api/9_amct/amct_pytorch/lut4_quantization/README_CN.md @@ -1,6 +1,6 @@ -# MXFP4量化 +# LUT4bit量化 -## 1 MXFP4量化前提 +## 1 LUT4bit量化前提 ### 1.1 安装依赖 @@ -15,12 +15,12 @@ | 字段 |类型| 说明 | 默认值 | 取值范围 | 注意事项 | |:--| :-: | :-- | :-: | :-: | :-: | -|batch_num|uint32|量化使用的batch数量 |1|/|MXFP量化中配置不生效,校准使用batch数与推理使用输入数据有关,是校准脚本中的batch_num| +|batch_num|uint32|量化使用的batch数量 |1|/|校准使用batch数与推理使用输入数据有关,是校准脚本中的batch_num| |skip_layers|str|跳过量化的层 |/|/|跳过量化层支持模糊匹配,当配置字符串为层名字串,或与层名一致时,跳过该层量化,不生成量化配置。字符串必须包含数字或字母| -|weight_only_config.weight_compress_only|bool|是否为仅权重量化|False|True/False|MXFP4量化目前仅支持权重量化,需要设置为True| -|weight_only_config.wts_type|enum|量化后权重类型|INT8|INT8/MXFP4_E2M1|/| -|weight_only_config.weight_granularity|enum|权重量化粒度|PER_TENSOR|PER_TENSOR/PER_CHANNEL/PER_GROUP|MXFP4_E2M1仅支持PER_GROUP模式| -|weight_only_config.round_mode|enum|舍入模式|/|HYBRID/ROUND/RINT|MXFP4_E2M1仅支持RINT模式| +|weight_only_config.weight_compress_only|bool|是否为仅权重量化|False|True/False|LUT4bit量化目前仅支持权重量化,需要设置为True| +|weight_only_config.wts_type|enum|量化后权重类型|INT8|本sample支持INT4|/| 
+|weight_only_config.weight_granularity|enum|权重量化粒度|PER_TENSOR|PER_TENSOR/PER_CHANNEL/PER_GROUP|LUT4bit仅支持PER_GROUP模式| +|weight_only_config.round_mode|enum|舍入模式|/|HYBRID/ROUND/RINT|LUT4bit仅支持RINT模式| |weight_only_config.lut_quantize.lut_alog|enum|lut量化算法模式|CLUSTER|CLUSTER/ATCTAN| ## 2 LUT4量化示例 @@ -29,19 +29,21 @@ **step 1.** 请在当前目录执行如下两条命令运行示例程序,用户需根据实际情况修改示例程序中的模型和数据集路径: -`CUDA_VISIBLE_DEVICES=0,1,2,3,4,5 python3 src/run_llama7b_calibration.py` -`CUDA_VISIBLE_DEVICES=0,1,2,3,4,5 python3 src/save_llama7b_quant_model.py` +校准: +`CUDA_VISIBLE_DEVICES=0,1,2,3,4,5 python3 src/run_llama7b_calibration.py --calibration_data=/pile_val_backup/ --model=/data/Models/pytorch/Llama2/Llama2_7b_hf` +保存并推理量化模型: +`CUDA_VISIBLE_DEVICES=0,1,2,3,4,5 python3 src/save_llama7b_quant_model.py --verify_data=/data/Datasets/wikitext/wikitext-2-raw-v1/wikitext-2-raw/wikiscript.py --model=/data/Models/pytorch/Llama2/Llama2_7b_hf` 若出现如下信息,则说明校准成功: ```none -Calibration time taken: 1.0 min 59.24865388870239 s +Calibration time taken: 56.0 min 17.225504398345947 s ``` 出现如下信息,说明量化成功 ```none -Test time taken: 1.0 min 59.24865388870239 s +Test time taken: 56.0 min 17 s Score: 5.670858383178711 ``` diff --git a/python/level1_single_api/9_amct/amct_pytorch/lut4_quantization/src/lut4_quant.cfg b/python/level1_single_api/9_amct/amct_pytorch/lut4_quantization/src/lut4_quant.cfg index 5b8cd01c6..6f532c21c 100644 --- a/python/level1_single_api/9_amct/amct_pytorch/lut4_quantization/src/lut4_quant.cfg +++ b/python/level1_single_api/9_amct/amct_pytorch/lut4_quantization/src/lut4_quant.cfg @@ -1,4 +1,4 @@ -batch_num: 4 +batch_num: 1 skip_layers: "lm_head" weight_only_config: { weight_compress_only: True diff --git a/python/level1_single_api/9_amct/amct_pytorch/lut4_quantization/src/run_llama7b_calibration.py b/python/level1_single_api/9_amct/amct_pytorch/lut4_quantization/src/run_llama7b_calibration.py index 1072dc945..bab76bfe5 100644 --- 
a/python/level1_single_api/9_amct/amct_pytorch/lut4_quantization/src/run_llama7b_calibration.py +++ b/python/level1_single_api/9_amct/amct_pytorch/lut4_quantization/src/run_llama7b_calibration.py @@ -28,7 +28,12 @@ from amct_pytorch.post_quant_calibration import LLMHelper if __name__ == '__main__': - model, model_path = get_llama2('7b', seqlen=512) + parser = argparse.ArgumentParser() + parser.add_argument('--calibration_data', type=str, default='/pile_val_backup') + parser.add_argument('--model', type=str, default='/data/Models/pytorch/Llama2/Llama2_7b_hf') + + args = parser.parse_args() + model, model_path = get_llama2(args.model) model = model.eval() copied_model = copy.deepcopy(model) gpu_num = torch.cuda.device_count() @@ -46,7 +51,7 @@ if __name__ == '__main__': # Phase2: do weights calibration and generate calibration model samples = get_calib_dataset( - data="pileval", tokenizer=enc, n_samples=512, block_size=256 + data_path=args.calibration_data, tokenizer=enc, n_samples=512, block_size=256 ) samples = torch.cat(samples, dim=0)[:1,:] model.config.use_cache = False diff --git a/python/level1_single_api/9_amct/amct_pytorch/lut4_quantization/src/save_llama7b_quant_model.py b/python/level1_single_api/9_amct/amct_pytorch/lut4_quantization/src/save_llama7b_quant_model.py index d89bdaa99..fafcbdd48 100644 --- a/python/level1_single_api/9_amct/amct_pytorch/lut4_quantization/src/save_llama7b_quant_model.py +++ b/python/level1_single_api/9_amct/amct_pytorch/lut4_quantization/src/save_llama7b_quant_model.py @@ -27,7 +27,12 @@ import amct_pytorch as amct if __name__ == '__main__': - model, model_path = get_llama2('7b', seqlen=512) + parser = argparse.ArgumentParser() + parser.add_argument('--verify_data', type=str, default='/data/Datasets/wikitext/wikitext-2-raw-v1/wikitext-2-raw/wikiscript.py') + parser.add_argument('--model', type=str, default='/data/Models/pytorch/Llama2/Llama2_7b_hf') + + args = parser.parse_args() + model, model_path = get_llama2(args.model) 
model = model.eval() copied_model = copy.deepcopy(model) gpu_num = torch.cuda.device_count() @@ -40,7 +45,7 @@ if __name__ == '__main__': model, enc = build_model_and_enc(copied_model, model_path, gpu_num) # Phase1: save fakequant model - testenc = get_loaders(dataset_name='wikitext2', + testenc = get_loaders(data_path=args.verify_data, enc=enc, seqlen=model.seqlen) diff --git a/python/level1_single_api/9_amct/amct_pytorch/lut4_quantization/src/utils.py b/python/level1_single_api/9_amct/amct_pytorch/lut4_quantization/src/utils.py index 5c235afbe..1158edbc6 100644 --- a/python/level1_single_api/9_amct/amct_pytorch/lut4_quantization/src/utils.py +++ b/python/level1_single_api/9_amct/amct_pytorch/lut4_quantization/src/utils.py @@ -68,19 +68,10 @@ def build_model_and_enc(model, model_path, gpu_num): return model, enc def get_llama2(model, seqlen=2048): - '''If model is specified from ['7b', '13b', '70b'], then we load official pretrained model; - If you want to load checkpoints other than the official ones, please specifiy the model path, - otherwise please choose from ['7b', '13b', '70b'] for better clarity - ''' - def skip(*args, **kwargs): pass - if model in ['7b', '13b', '70b']: - model_path = f'/data/Models/pytorch/Llama2/Llama2_{model}_hf' - print(f'Getting official pretrained Llama2-{model}') - else: - model_path = model + model_path = model torch.nn.init.kaiming_uniform_ = skip torch.nn.init.uniform_ = skip torch.nn.init.normal_ = skip @@ -92,20 +83,16 @@ def get_llama2(model, seqlen=2048): return model, model_path -def get_loaders(dataset_name: str, enc, seqlen): - if dataset_name == 'wikitext2': - print('Loading dataset: Wikitext2') - testenc = load_dataset('/data/Datasets/wikitext/wikitext-2-raw-v1/wikitext-2-raw/wikiscript.py', 'wikitext-2-raw-v1', split='test', trust_remote_code=True) - testenc = enc("\n\n".join(testenc["text"]), return_tensors="pt") +def get_loaders(data_path: str, enc, seqlen): + print('Loading dataset: Wikitext2') + testenc = load_dataset(data_path,
'wikitext-2-raw-v1', split='test', trust_remote_code=True) + testenc = enc("\n\n".join(testenc["text"]), return_tensors="pt") return testenc -def get_calib_dataset(data="pileval", tokenizer=None, n_samples=512, block_size=512): - if data == "pileval": - dataset = load_from_disk('/pile_val_backup') - else: - raise NotImplementedError +def get_calib_dataset(data_path='/pile_val_backup', tokenizer=None, n_samples=512, block_size=512): + dataset = load_from_disk(data_path) dataset = dataset.shuffle(seed=42) samples = [] n_run = 0 -- Gitee