diff --git a/torch_npu/contrib/__init__.py b/torch_npu/contrib/__init__.py
index 5bc4472adf03bf732c5d0fd9a9880a2ad90cfe74..6a531eb9136a48c9227ffcb52c0e2fef72db8582 100644
--- a/torch_npu/contrib/__init__.py
+++ b/torch_npu/contrib/__init__.py
@@ -5,7 +5,7 @@ from .function import npu_iou, npu_ptiou, npu_giou, npu_diou, npu_ciou, npu_mult
 from .module import ChannelShuffle, Prefetcher, LabelSmoothingCrossEntropy, ROIAlign, DCNv2, \
     ModulatedDeformConv, Mish, BiLSTM, PSROIPool, SiLU, Swish, NpuFairseqDropout, NpuCachedDropout, \
     MultiheadAttention, FusedColorJitter, NpuDropPath, Focus, FastBatchNorm1d, FastBatchNorm2d, \
-    FastBatchNorm3d, FastSyncBatchNorm, LinearA8W8Quant, LinearWeightQuant
+    FastBatchNorm3d, FastSyncBatchNorm, LinearA8W8Quant, LinearWeightQuant, QuantConv2d
 
 __all__ = [
     # from function
@@ -47,4 +47,5 @@ __all__ = [
     "LinearA8W8Quant",
     "FusedColorJitter",
     "LinearWeightQuant",
+    "QuantConv2d",
 ]
diff --git a/torch_npu/contrib/module/__init__.py b/torch_npu/contrib/module/__init__.py
index f9dd427f441bd19bd6488fa802a4b74fcffb0e66..c91e18cbe30c33ac8b0ea8c606a6b0731371c083 100644
--- a/torch_npu/contrib/module/__init__.py
+++ b/torch_npu/contrib/module/__init__.py
@@ -17,6 +17,7 @@ from .batchnorm_with_int32_count import FastBatchNorm1d, \
 from .linear_a8w8_quant import LinearA8W8Quant
 from .linear_weight_quant import LinearWeightQuant
 from .npu_modules import DropoutWithByteMask
+from .quant_conv2d import QuantConv2d
 
 __all__ = [
     "ChannelShuffle",
@@ -40,4 +41,5 @@ __all__ = [
     "LinearA8W8Quant",
     "LinearWeightQuant",
     "DropoutWithByteMask",
+    "QuantConv2d",
 ]
diff --git a/torch_npu/contrib/module/quant_conv2d.py b/torch_npu/contrib/module/quant_conv2d.py
index 1aa59bce434fea66bb4749a319e77ccbe673610b..70eba9e371e968492ec006bbf6bf19238c3785e4 100644
--- a/torch_npu/contrib/module/quant_conv2d.py
+++ b/torch_npu/contrib/module/quant_conv2d.py
@@ -59,21 +59,19 @@ class QuantConv2d(nn.Module):
         offset (Tensor): Requant calculation parameter of shape (out_channels)
 
     Examples::
-        >>> quant_conv2d_input = torch.randint(-1, 1, (1, 1, 4, 4), dtype=torch.int8)
-        >>> weight = torch.randint(-1, 1, (1, 1, 3, 3), dtype=torch.int8)
-        >>> scale = torch.randint(-1, 1, (1,), dtype=torch.int64)
-        >>> bias = torch.randint(-1, 1, (1,), dtype=torch.int32)
-        >>> model = QuantConv2d(in_channels, out_channels, k_size, output_dtype)
+        >>> quant_conv2d_input = torch.randint(-1, 1, (1, 1, 4, 4), dtype=torch.int8).npu()
+        >>> weight = torch.randint(-1, 1, (1, 1, 3, 3), dtype=torch.int8).npu()
+        >>> scale = torch.randint(-1, 1, (1,), dtype=torch.int64).npu()
+        >>> bias = torch.randint(-1, 1, (1,), dtype=torch.int32).npu()
+        >>> model = QuantConv2d(in_channels=1, out_channels=1, kernel_size=(3, 3), output_dtype=torch.float16)
         >>> model = model.npu()
         >>> model.weight.data = weight
         >>> model.scale.data = scale
         >>> model.bias.data = bias
-        >>> config = CompilerConfig()
-        >>> npu_backend = tng.get_npu_backend(compiler_config=config)
-        >>> static_graph_model = torch.compile(model, backend=npu_backend, dynamic=False)
+        >>> static_graph_model = torch.compile(model, backend="npu", dynamic=False)
         >>> output = static_graph_model(quant_conv2d_input)
         >>> print(output.size())
-        torch.Size(1, 1, 2, 2)
+        torch.Size([1, 1, 2, 2])
     """
     in_channels: int
     out_channels: int
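For reviewers: below is a minimal end-to-end sketch assembled from the doctest lines the patch introduces, showing how the newly exported `QuantConv2d` is expected to be used. It assumes a reachable NPU device and that this torch_npu build registers the `"npu"` torch.compile backend, as the updated docstring implies; variable names like `x` and `compiled` are illustrative.

```python
# Sketch only: mirrors the updated QuantConv2d docstring example,
# using the module as now exported from torch_npu.contrib.
# Assumes an NPU device and the "npu" torch.compile backend are available.
import torch
import torch_npu
from torch_npu.contrib import QuantConv2d

# int8 input and weight, int64 quant scale, int32 bias, per the docstring.
x = torch.randint(-1, 1, (1, 1, 4, 4), dtype=torch.int8).npu()
weight = torch.randint(-1, 1, (1, 1, 3, 3), dtype=torch.int8).npu()
scale = torch.randint(-1, 1, (1,), dtype=torch.int64).npu()
bias = torch.randint(-1, 1, (1,), dtype=torch.int32).npu()

model = QuantConv2d(in_channels=1, out_channels=1, kernel_size=(3, 3),
                    output_dtype=torch.float16).npu()
model.weight.data = weight
model.scale.data = scale
model.bias.data = bias

# Run through a statically shaped compiled graph, as in the docstring.
compiled = torch.compile(model, backend="npu", dynamic=False)
out = compiled(x)
print(out.size())  # expected: torch.Size([1, 1, 2, 2])
```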