1 Star 0 Fork 0

科大讯飞/DeepEP

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
setup.py 2.31 KB
一键复制 编辑 原始数据 按行查看 历史
Chenggang Zhao 提交于 2025-02-24 21:56 +08:00 . Initial commit
import os
import subprocess
import setuptools
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
def _find_nvshmem_dir():
    """Return the NVSHMEM installation directory named by ``NVSHMEM_DIR``.

    Returns:
        The value of the ``NVSHMEM_DIR`` environment variable.

    Raises:
        RuntimeError: If ``NVSHMEM_DIR`` is unset or names a path that does
            not exist. An explicit ``raise`` is used rather than ``assert``
            so the check still runs under ``python -O``.
    """
    nvshmem_dir = os.getenv('NVSHMEM_DIR', None)
    if nvshmem_dir is None or not os.path.exists(nvshmem_dir):
        raise RuntimeError('Failed to find NVSHMEM')
    return nvshmem_dir


def _git_revision_suffix():
    """Return ``'+<short hash>'`` for the current git ``HEAD``, or ``''``.

    The lookup is best-effort: any failure (git missing, not a repository,
    undecodable output) yields an empty suffix instead of aborting the build.
    """
    cmd = ['git', 'rev-parse', '--short', 'HEAD']
    # noinspection PyBroadException
    try:
        # Discard git's stderr so a non-repository build does not print
        # "fatal: not a git repository" into the build log.
        raw = subprocess.check_output(cmd, stderr=subprocess.DEVNULL)
        commit = raw.decode('ascii').rstrip()
    except Exception:
        return ''
    # An empty hash would produce the invalid version string '1.0.0+'.
    return '+' + commit if commit else ''


if __name__ == '__main__':
    nvshmem_dir = _find_nvshmem_dir()
    print(f'NVSHMEM directory: {nvshmem_dir}')

    # TODO: currently, we only support Hopper architecture, we may add Ampere support later
    os.environ['TORCH_CUDA_ARCH_LIST'] = '9.0'

    # Host compiler flags: optimize and silence a set of warnings.
    cxx_flags = ['-O3', '-Wno-deprecated-declarations', '-Wno-unused-variable',
                 '-Wno-sign-compare', '-Wno-reorder', '-Wno-attributes']
    # Device compiler flags; `-rdc=true` requests relocatable device code,
    # which goes together with the separate `-dlink` step configured below.
    nvcc_flags = ['-O3', '-Xcompiler', '-O3', '-rdc=true', '--ptxas-options=--register-usage-level=10']
    include_dirs = ['csrc/', f'{nvshmem_dir}/include']
    sources = ['csrc/deep_ep.cpp',
               'csrc/kernels/runtime.cu', 'csrc/kernels/intranode.cu',
               'csrc/kernels/internode.cu', 'csrc/kernels/internode_ll.cu']
    library_dirs = [f'{nvshmem_dir}/lib']

    # Disable aggressive PTX instructions
    # (opt-in: set DISABLE_AGGRESSIVE_PTX_INSTRS to a non-zero integer).
    if int(os.getenv('DISABLE_AGGRESSIVE_PTX_INSTRS', '0')):
        cxx_flags.append('-DDISABLE_AGGRESSIVE_PTX_INSTRS')
        nvcc_flags.append('-DDISABLE_AGGRESSIVE_PTX_INSTRS')

    # Disable DLTO (default by PyTorch)
    nvcc_dlink = ['-dlink', f'-L{nvshmem_dir}/lib', '-lnvshmem']
    # Link libnvshmem.a and nvshmem_bootstrap_uid.so by exact file name, and
    # add the NVSHMEM lib directory to the extension's runtime search path.
    extra_link_args = ['-l:libnvshmem.a', '-l:nvshmem_bootstrap_uid.so', f'-Wl,-rpath,{nvshmem_dir}/lib']
    extra_compile_args = {
        'cxx': cxx_flags,
        'nvcc': nvcc_flags,
        'nvcc_dlink': nvcc_dlink
    }

    # Append the short commit hash (when available) as a local version label.
    revision = _git_revision_suffix()

    setuptools.setup(
        name='deep_ep',
        version='1.0.0' + revision,
        packages=setuptools.find_packages(
            include=['deep_ep']
        ),
        ext_modules=[
            CUDAExtension(
                name='deep_ep_cpp',
                include_dirs=include_dirs,
                library_dirs=library_dirs,
                sources=sources,
                extra_compile_args=extra_compile_args,
                extra_link_args=extra_link_args
            )
        ],
        cmdclass={
            'build_ext': BuildExtension
        }
    )
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/iflytek/DeepEP.git
git@gitee.com:iflytek/DeepEP.git
iflytek
DeepEP
DeepEP
h20

搜索帮助