diff --git a/codegen/gen_backend_stubs.py b/codegen/gen_backend_stubs.py index 248704d49282c0d2a54410cc092c5abeb61d3846..6c682edaafb724c35a38d624d7a4802a5bc97d40 100644 --- a/codegen/gen_backend_stubs.py +++ b/codegen/gen_backend_stubs.py @@ -548,7 +548,7 @@ def gen_functionalization(fm: FileManager, grouped_native_functions, key_fn=key_func, env_callable=functionalization_env_callable, - num_shards=2, + num_shards=1, sharded_keys={ "func_definitions", "func_registrations", diff --git a/codegen/templates/RegisterFunctionalization.cpp b/codegen/templates/RegisterFunctionalization.cpp index 0f57e00fdcf60ceff5a9e37dc10638f18b25b69e..45100d89d9a6d508bff8c154faeba81975428d8a 100644 --- a/codegen/templates/RegisterFunctionalization.cpp +++ b/codegen/templates/RegisterFunctionalization.cpp @@ -92,6 +92,51 @@ inline c10::List> to_meta(const c10::List npu_scatter_pa_kv_cache_functional(c10::DispatchKeySet dispatchKeySet, + const at::Tensor& key, const at::Tensor& value, const at::Tensor& key_cache, const at::Tensor& value_cache, + const at::Tensor& slot_mapping, const c10::optional& compress_lens, + const c10::optional& compress_seq_offsets, const c10::optional& seq_lens, + at::TensorList out) +{ + at::Tensor key_cache_ref = out[0]; + at::Tensor value_cache_ref = out[1]; + at::functionalization::impl::sync(key); + at::functionalization::impl::sync(value); + at::functionalization::impl::sync(slot_mapping); + at::functionalization::impl::sync(compress_lens); + at::functionalization::impl::sync(compress_seq_offsets); + at::functionalization::impl::sync(seq_lens); + at::functionalization::impl::sync(key_cache_ref); + at::functionalization::impl::sync(value_cache_ref); + + auto key_unwarp = at::functionalization::impl::from_functional_tensor(key); + auto value_unwarp = at::functionalization::impl::from_functional_tensor(value); + auto slot_mapping_unwarp = at::functionalization::impl::from_functional_tensor(slot_mapping); + auto compress_lens_unwarp = at::functionalization::impl::from_functional_tensor(compress_lens); + auto compress_seq_offsets_unwarp = at::functionalization::impl::from_functional_tensor(compress_seq_offsets); + auto seq_lens_unwarp = at::functionalization::impl::from_functional_tensor(seq_lens); + auto key_cache_ref_unwarp = at::functionalization::impl::from_functional_tensor(key_cache_ref); + auto value_cache_ref_unwarp = at::functionalization::impl::from_functional_tensor(value_cache_ref); + + at::Tensor tmp_input1; + at::Tensor tmp_input2; + { + at::AutoDispatchSkipFunctionalize guard; + auto tmp_result = at_npu::native::custom_ops::npu_scatter_pa_kv_cache(key_unwarp, value_unwarp, + key_cache_ref_unwarp, value_cache_ref_unwarp, slot_mapping_unwarp, compress_lens_unwarp, + compress_seq_offsets_unwarp, seq_lens_unwarp); + tmp_input1 = std::get<0>(tmp_result); + tmp_input2 = std::get<1>(tmp_result); + } + + at::functionalization::impl::replace_(key_cache_ref, tmp_input1); + at::functionalization::impl::replace_(value_cache_ref, tmp_input2); + at::functionalization::impl::commit_update(key_cache_ref); + at::functionalization::impl::commit_update(value_cache_ref); + at::functionalization::impl::sync(key_cache_ref); + at::functionalization::impl::sync(value_cache_ref); + return ::std::tuple(key_cache_ref, value_cache_ref); +} ${func_definitions} @@ -101,6 +146,7 @@ namespace { TORCH_LIBRARY_IMPL(npu, Functionalize, m) { ${func_registrations}; + m.impl("npu_scatter_pa_kv_cache.out", TORCH_FN(functionalization::npu_scatter_pa_kv_cache_functional)); } } // namespace