diff --git a/torch_npu/__init__.py b/torch_npu/__init__.py index 288515bc469aaa8d269ab4de1da83aec616ff7e6..7ef01dbd4a1c99a17ec9e899ccbe3ea050580dce 100644 --- a/torch_npu/__init__.py +++ b/torch_npu/__init__.py @@ -36,6 +36,7 @@ for name in dir(torch_npu._C._VariableFunctions): continue globals()[name] = getattr(torch_npu._C._VariableFunctions, name) __all__.append(name) + setattr(torch, name, getattr(torch_npu._C._VariableFunctions, name)) all_monkey_patches = [ ["npu", torch_npu.npu],