diff --git a/setup.py b/setup.py index 44f9a6a843fe174b6e818f2d3e97794cc5492b01..3f2f95419985cb6b9d775e13ed8ed87fb21174c6 100644 --- a/setup.py +++ b/setup.py @@ -27,11 +27,12 @@ import site import distutils.ccompiler import distutils.command.clean from setuptools.command.build_ext import build_ext +from setuptools.command.install import install from setuptools import setup, find_packages, distutils, Extension BASE_DIR = os.path.dirname(os.path.abspath(__file__)) -VERSION = '1.8.1' +VERSION = '1.8.1-rc1' def _get_build_mode(): @@ -40,6 +41,15 @@ def _get_build_mode(): return sys.argv[i] +def get_package_dir(): + if '--user' in sys.argv: + package_dir = site.getusersitepackages() + else: + py_version = f'{sys.version_info.major}.{sys.version_info.minor}' + package_dir = f'{sys.prefix}/lib/python{py_version}/site-packages' + return package_dir + + def generate_bindings_code(base_dir): generate_code_cmd = ["sh", os.path.join(base_dir, 'scripts', 'generate_code.sh')] if subprocess.call(generate_code_cmd) != 0: @@ -95,12 +105,7 @@ def CppExtension(name, sources, *args, **kwargs): r''' Creates a :class:`setuptools.Extension` for C++. ''' - if '--user' in sys.argv: - package_dir = site.getusersitepackages() - else: - py_version = f'{sys.version_info.major}.{sys.version_info.minor}' - package_dir = f'{sys.prefix}/lib/python{py_version}/site-packages' - + package_dir = get_package_dir() temp_include_dirs = kwargs.get('include_dirs', []) temp_include_dirs.append(os.path.join(package_dir, 'torch/include')) temp_include_dirs.append(os.path.join(package_dir, 'torch/include/torch/csrc/api/include')) @@ -156,6 +161,23 @@ class Build(build_ext, object): self.compiler.compiler_so.remove('-g') return super(Build, self).build_extensions() + + +class PostInstallCommand(install): + + def run(self): + install.run(self) + if os.getenv('IMPORT_TORCHNPU', default='0').upper() in ['ON', '1', 'YES', 'TRUE', 'Y']: + self.execute(PostInstallCommand.obfuscate, (), msg="Do obfuscate") + + @staticmethod + def obfuscate(): + package_dir = get_package_dir() + assert os.path.exists(os.path.join(package_dir, "torch")), "Cannot find torch." + with open(os.path.join(package_dir, "torch/__init__.py"), "r+") as f: + if "import torch_npu" in [line.strip() for line in f.readlines()]: + return + f.write("try:\n import torch_npu\nexcept:\n pass\n") build_mode = _get_build_mode() @@ -231,4 +253,5 @@ setup( cmdclass={ 'build_ext': Build, 'clean': Clean, + 'install': PostInstallCommand })