From 90260b56955cfc147d13b920ed6e476497e7ffaa Mon Sep 17 00:00:00 2001 From: zou-jieyu Date: Fri, 23 Jan 2026 16:47:25 +0800 Subject: [PATCH 1/2] Add distributed ops TransposeExtView --- .../core/shard/ops/parallel_transpose.py | 61 ++++++ .../core/shard/ops/yaml/transpose_ops.yaml | 5 + .../test_ops_transpose_ext_view_shell.py | 60 ++++++ .../transpose_ext_view_shard_in_python.py | 147 ++++++++++++++ .../test_parallel_transpose_ext_view.py | 187 ++++++++++++++++++ 5 files changed, 460 insertions(+) create mode 100644 tests/mindspore/st/shard/test_ops_transpose_ext_view_shell.py create mode 100644 tests/mindspore/st/shard/transpose_ext_view_shard_in_python.py create mode 100644 tests/mindspore/ut/parallel_ops_infer/test_parallel_transpose_ext_view.py diff --git a/hyper_parallel/core/shard/ops/parallel_transpose.py b/hyper_parallel/core/shard/ops/parallel_transpose.py index 31e4033..c8ccd59 100644 --- a/hyper_parallel/core/shard/ops/parallel_transpose.py +++ b/hyper_parallel/core/shard/ops/parallel_transpose.py @@ -65,3 +65,64 @@ class TransposeDistributedOp(DistributedOp): out_layout = output_layout(*out_tensor_map) return out_layout + + +class TransposeExtViewDistributedOp(DistributedOp): + """Distributed implementation for TransposeExtView operator.""" + + def infer_layout(self, layouts, extra_args): + """ + Infer output layout for TransposeExtView operator. + + Rules: + 1. Output layout is input layout with two tensor-map dimensions swapped. + + Args: + layouts (tuple): Layouts of input tensor. + extra_args (tuple | list): (dim0, dim1) + + Returns: + Layout: Layout for output tensor + """ + if not layouts or layouts[0] is None: + raise ValueError("Input layout is required for TransposeExtView.") + + if not isinstance(extra_args, (tuple, list)) or len(extra_args) < 2: + raise ValueError(f"TransposeExtView expects (dim0, dim1) in extra_args, but got {extra_args}") + + layout = layouts[0] + dim0 = extra_args[0] + dim1 = extra_args[1] + + if not isinstance(dim0, int) or not isinstance(dim1, int): + raise TypeError(f"dim0 and dim1 must be int, but got {type(dim0)} and {type(dim1)}") + + in_tensor_map = layout.alias_tensor_map + if in_tensor_map is None: + raise ValueError("Input layout.alias_tensor_map is None for TransposeExtView.") + + ndim = len(in_tensor_map) + + dim0 = self._normalize_dim(dim0, ndim) + dim1 = self._normalize_dim(dim1, ndim) + + if dim0 == dim1: + return layout + + out_tensor_map = list(in_tensor_map) + out_tensor_map[dim0], out_tensor_map[dim1] = out_tensor_map[dim1], out_tensor_map[dim0] + out_tensor_map = type(in_tensor_map)(out_tensor_map) + + output_layout = Layout( + mesh_shape=layout.mesh_shape, + alias_name=layout.alias_name, + rank_list=layout.rank_list + ) + return output_layout(*out_tensor_map) + + @staticmethod + def _normalize_dim(dim: int, ndim: int) -> int: + """Normalize dim into [0, ndim-1] with MindSpore-style range checks.""" + if dim < -ndim or dim >= ndim: + raise ValueError(f"dim {dim} out of range [-{ndim}, {ndim - 1}]") + return dim + ndim if dim < 0 else dim diff --git a/hyper_parallel/core/shard/ops/yaml/transpose_ops.yaml b/hyper_parallel/core/shard/ops/yaml/transpose_ops.yaml index e341231..28067e3 100644 --- a/hyper_parallel/core/shard/ops/yaml/transpose_ops.yaml +++ b/hyper_parallel/core/shard/ops/yaml/transpose_ops.yaml @@ -1,4 +1,9 @@ Transpose: dist_op_name: _transpose_dist_op distributed_op_class: TransposeDistributedOp + distributed_op_file: parallel_transpose + +TransposeExtView: + dist_op_name: _transpose_ext_view_dist_op + 
distributed_op_class: TransposeExtViewDistributedOp distributed_op_file: parallel_transpose \ No newline at end of file diff --git a/tests/mindspore/st/shard/test_ops_transpose_ext_view_shell.py b/tests/mindspore/st/shard/test_ops_transpose_ext_view_shell.py new file mode 100644 index 0000000..8021d7e --- /dev/null +++ b/tests/mindspore/st/shard/test_ops_transpose_ext_view_shell.py @@ -0,0 +1,60 @@ +# Copyright 2026 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""parallel_transpose_ext_view_shell test""" + +from tests.common.mark_utils import arg_mark +from tests.mindspore.st.utils import msrun_case + + +@arg_mark(plat_marks=["platform_ascend910b"], level_mark="level0", card_mark="allcards", essential_mark="essential") +def test_transpose_ext_view_basic_3d_1(): + ''' + Feature: TransposeExtView operator. + Description: Test TransposeExtView swaps two dims on a 3D tensor in python shard. + Expectation: Run success. + ''' + glog_v = 2 + file_name = "transpose_ext_view_shard_in_python.py" + case_name = "test_transpose_ext_view_basic_3d_1" + master_port = 11294 + msrun_case(glog_v, file_name, case_name, master_port) + + +@arg_mark(plat_marks=["platform_ascend910b"], level_mark="level0", card_mark="allcards", essential_mark="essential") +def test_transpose_ext_view_negative_dims_2(): + ''' + Feature: TransposeExtView operator. + Description: Test TransposeExtView supports negative dims in python shard. + Expectation: Run success. + ''' + glog_v = 2 + file_name = "transpose_ext_view_shard_in_python.py" + case_name = "test_transpose_ext_view_negative_dims_2" + master_port = 11295 + msrun_case(glog_v, file_name, case_name, master_port) + + +@arg_mark(plat_marks=["platform_ascend910b"], level_mark="level0", card_mark="allcards", essential_mark="essential") +def test_transpose_ext_view_same_dims_noop_3(): + ''' + Feature: TransposeExtView operator. + Description: Test TransposeExtView is a no-op when dim0 == dim1 in python shard. + Expectation: Run success. + ''' + glog_v = 2 + file_name = "transpose_ext_view_shard_in_python.py" + case_name = "test_transpose_ext_view_same_dims_noop_3" + master_port = 11296 + msrun_case(glog_v, file_name, case_name, master_port) diff --git a/tests/mindspore/st/shard/transpose_ext_view_shard_in_python.py b/tests/mindspore/st/shard/transpose_ext_view_shard_in_python.py new file mode 100644 index 0000000..67a4aeb --- /dev/null +++ b/tests/mindspore/st/shard/transpose_ext_view_shard_in_python.py @@ -0,0 +1,147 @@ +# Copyright 2026 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +test transpose ext view shard in python +""" + +import numpy as np + +import mindspore as ms +import mindspore.communication.management as D +from mindspore import nn, Tensor +from hyper_parallel import Layout, shard +from tests.mindspore.st.shard.utils import global_to_local, local_to_global + + +def setup_module(): + ms.set_device("Ascend") + D.init() + + +base_mesh_shape = (2, 2, 2) +base_alias_name = ("dp", "cp", "mp") + + +class TransposeExtViewNet(nn.Cell): + """TransposeExtView composed of transpose and ReLUs""" + + def __init__(self, relu_strategy=None): + super().__init__() + self.transpose = ms.mint.transpose + self.relu = ms.nn.ReLU() + if relu_strategy is not None: + stra = {"forward": {"input": relu_strategy}} + shard(self.relu, stra) + + def construct(self, x, dim0, dim1): + out = self.transpose(x, dim0, dim1) + out = self.relu(out) + out = out + 1 + return out + + +def _standalone_and_parallel_run(x, x_layout, dim0, dim1, relu_in_layout): + """Run standalone and parallel graph and return outputs.""" + # Standalone + standalone_net = TransposeExtViewNet() + standalone_output = standalone_net(x, dim0, dim1) + + # Parallel + x_local = global_to_local(x, x_layout) + parallel_net = TransposeExtViewNet(relu_strategy=(relu_in_layout,)) + parallel_output = parallel_net(x_local, dim0, dim1) + + parallel_output = local_to_global(parallel_output) + return standalone_output, parallel_output + + +def test_transpose_ext_view_basic_3d_1(): + ''' + Feature: TransposeExtView in python shard. + Description: Test TransposeExtView swaps two dims on a 3D tensor. + Expectation: Output matches standalone. + ''' + ms.set_seed(1) + np.random.seed(1) + + d0, d1, d2 = 16, 64, 32 + x = Tensor(np.random.randn(d0, d1, d2).astype(np.float32)) + + layout = Layout(base_mesh_shape, base_alias_name) + + # Shard input to ensure dtensor path is exercised + x_layout = layout("dp", "cp", "mp") + + # After transpose(0, 2): shape becomes (d2, d1, d0). + # Keep ReLU sharding simple: shard last dim by mp. + relu_in_layout = layout("None", "None", "mp") + + standalone_output, parallel_output = _standalone_and_parallel_run( + x, x_layout, dim0=0, dim1=2, relu_in_layout=relu_in_layout + ) + + assert np.allclose(standalone_output.asnumpy(), parallel_output.asnumpy(), 1e-3, 1e-3) + + +def test_transpose_ext_view_negative_dims_2(): + ''' + Feature: TransposeExtView in python shard. + Description: Test TransposeExtView supports negative dims (MindSpore semantics). + Expectation: Output matches standalone. 
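+    Note: with a 4D input, swap(-1, -3) normalizes to swap(3, 1), so the output
+    shape becomes (d0, d3, d2, d1); the ReLU layout below shards that new dim1 by mp.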
+ ''' + ms.set_seed(2) + np.random.seed(2) + + d0, d1, d2, d3 = 8, 16, 32, 64 + x = Tensor(np.random.randn(d0, d1, d2, d3).astype(np.float32)) + + layout = Layout(base_mesh_shape, base_alias_name) + + # Shard input: shard dim2 by mp; others replicated + x_layout = layout("None", "None", "mp", "None") + + # swap(-1, -3) => swap dim3 and dim1 + # output shape becomes (d0, d3, d2, d1) + relu_in_layout = layout("None", "mp", "None", "None") + + standalone_output, parallel_output = _standalone_and_parallel_run( + x, x_layout, dim0=-1, dim1=-3, relu_in_layout=relu_in_layout + ) + + assert np.allclose(standalone_output.asnumpy(), parallel_output.asnumpy(), 1e-3, 1e-3) + + +def test_transpose_ext_view_same_dims_noop_3(): + ''' + Feature: TransposeExtView in python shard. + Description: Test TransposeExtView is a no-op when dim0 == dim1. + Expectation: Output matches standalone. + ''' + ms.set_seed(3) + np.random.seed(3) + + d0, d1, d2 = 4, 32, 128 + x = Tensor(np.random.randn(d0, d1, d2).astype(np.float32)) + + layout = Layout(base_mesh_shape, base_alias_name) + + x_layout = layout("dp", "cp", "mp") + relu_in_layout = layout("dp", "cp", "mp") + + standalone_output, parallel_output = _standalone_and_parallel_run( + x, x_layout, dim0=1, dim1=1, relu_in_layout=relu_in_layout + ) + + assert np.allclose(standalone_output.asnumpy(), parallel_output.asnumpy(), 1e-3, 1e-3) diff --git a/tests/mindspore/ut/parallel_ops_infer/test_parallel_transpose_ext_view.py b/tests/mindspore/ut/parallel_ops_infer/test_parallel_transpose_ext_view.py new file mode 100644 index 0000000..2c75289 --- /dev/null +++ b/tests/mindspore/ut/parallel_ops_infer/test_parallel_transpose_ext_view.py @@ -0,0 +1,187 @@ +# Copyright 2026 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""parallel_transpose_ext_view test""" + +import pytest + +from hyper_parallel import Layout +from hyper_parallel.core.shard.ops.parallel_transpose import TransposeExtViewDistributedOp + + +def run_scenario(scenario_name, x_layout, expected_map, extra_args): + """Infer layout of TransposeExtView operator and validate tensor_map.""" + print(f"\n{'=' * 80}") + print(f"Test TransposeExtView, Scenario: {scenario_name}") + print('=' * 80) + + op = TransposeExtViewDistributedOp("TransposeExtView") + output_layout = op.infer_layout((x_layout,), extra_args) + assert output_layout.to_dict()["tensor_map"] == expected_map, \ + f"TransposeExtView failed in scenario '{scenario_name}'. " \ + f"Expected {expected_map}, got {output_layout.to_dict()['tensor_map']}" + + +base_mesh_shape = (2, 2, 2) +base_alias_name = ("dp", "cp", "mp") +base_rank_list = list(range(8)) + + +def test_transpose_ext_view_basic_swap_3d_1(): + """ + Feature: Basic swap. + Description: swap dim0=0 and dim1=2 on 3D tensor map. + Expectation: tensor_map dims swapped. 
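+    Note: ("dp", "cp", "mp") on mesh (2, 2, 2) corresponds to tensor_map (2, 1, 0)
+    (see the no-op case below), so swapping dims 0 and 2 yields (0, 1, 2).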
+ """ + x_layout = Layout(base_mesh_shape, base_alias_name, base_rank_list) + x_layout = x_layout("dp", "cp", "mp") + + run_scenario( + "1. Basic swap (0 <-> 2)", + x_layout, + expected_map=(0,1,2), + extra_args=(0, 2) + ) + + +def test_transpose_ext_view_negative_dims_2(): + """ + Feature: Negative dims. + Description: swap dim0=-1 and dim1=-3 on 3D tensor map. + Expectation: normalized dims swapped. + """ + x_layout = Layout(base_mesh_shape, base_alias_name, base_rank_list) + x_layout = x_layout("dp", "cp", "mp") + + # ndim=3: -1 -> 2, -3 -> 0 => swap(2,0) + run_scenario( + "2. Negative dims (-1 <-> -3)", + x_layout, + expected_map=(0,1,2), + extra_args=(-1, -3) + ) + + +def test_transpose_ext_view_noop_same_dims_3(): + """ + Feature: No-op. + Description: dim0 == dim1. + Expectation: output tensor_map unchanged. + """ + x_layout = Layout(base_mesh_shape, base_alias_name, base_rank_list) + x_layout = x_layout("dp", "cp", "mp") + + run_scenario( + "3. No-op swap (1 <-> 1)", + x_layout, + expected_map=(2, 1, 0), + extra_args=(1, 1) + ) + + +def test_transpose_ext_view_tuple_alias_dim_4(): + """ + Feature: Tuple alias dim. + Description: swap a normal dim with a tuple-alias dim. + Expectation: tuple moved to the swapped position. + """ + x_layout = Layout(base_mesh_shape, base_alias_name, base_rank_list) + x_layout = x_layout("None", ("dp", "cp"), "mp") + + # tensor_map should be (-1, (2,1), 0) with base_alias_name ("dp","cp","mp") + # swap dim0=1, dim1=2 => (-1, 0, (2,1)) + run_scenario( + "4. Tuple alias swap (1 <-> 2)", + x_layout, + expected_map=(-1, 0, (2, 1)), + extra_args=(1, 2) + ) + + +def test_transpose_ext_view_dim_out_of_range_5(): + """ + Feature: Error handling. + Description: dim0 or dim1 out of range [-ndim, ndim-1]. + Expectation: raise ValueError. + """ + x_layout = Layout(base_mesh_shape, base_alias_name, base_rank_list) + x_layout = x_layout("dp", "cp", "mp") + + with pytest.raises(ValueError): + run_scenario( + "5. dim0 out of range", + x_layout, + expected_map=(2, 1, 0), + extra_args=(3, 0) + ) + + with pytest.raises(ValueError): + run_scenario( + "6. dim1 out of range (negative)", + x_layout, + expected_map=(2, 1, 0), + extra_args=(-4, 0) + ) + + +def test_transpose_ext_view_dim_type_error_6(): + """ + Feature: Error handling. + Description: dim0 or dim1 is not int. + Expectation: raise TypeError. + """ + x_layout = Layout(base_mesh_shape, base_alias_name, base_rank_list) + x_layout = x_layout("dp", "cp", "mp") + + with pytest.raises(TypeError): + run_scenario( + "7. dim0 type error", + x_layout, + expected_map=(2, 1, 0), + extra_args=("0", 1) + ) + + with pytest.raises(TypeError): + run_scenario( + "8. dim1 type error", + x_layout, + expected_map=(2, 1, 0), + extra_args=(0, None) + ) + + +def test_transpose_ext_view_extra_args_invalid_7(): + """ + Feature: Error handling. + Description: extra_args is not (dim0, dim1). + Expectation: raise ValueError. + """ + x_layout = Layout(base_mesh_shape, base_alias_name, base_rank_list) + x_layout = x_layout("dp", "cp", "mp") + + with pytest.raises(ValueError): + run_scenario( + "9. extra_args missing dim1", + x_layout, + expected_map=(2, 1, 0), + extra_args=(0,) + ) + + with pytest.raises(ValueError): + run_scenario( + "10. 
extra_args not tuple/list", + x_layout, + expected_map=(2, 1, 0), + extra_args=None + ) -- Gitee From 92d8f4c42be4a6adbf0cc3a49d3a6324dab50880 Mon Sep 17 00:00:00 2001 From: zoujieyu Date: Tue, 10 Feb 2026 16:36:46 +0800 Subject: [PATCH 2/2] tests/torch/shard/ops/test_parallel_flash_attention_score.py --- hyper_parallel/core/shard/_op_dispatch.py | 171 ++- .../ops/parallel_npu_flash_attention_score.py | 1090 ++++++++++++++++ .../ops/yaml/torch_flash_attention_score.yaml | 4 + ...test_parallel_npu_flash_attention_score.py | 585 +++++++++ .../ops/parallel_op_flash_attention_score.py | 1126 +++++++++++++++++ .../test_parallel_flash_attention_score.py | 475 +++++++ 6 files changed, 3409 insertions(+), 42 deletions(-) create mode 100644 hyper_parallel/core/shard/ops/parallel_npu_flash_attention_score.py create mode 100644 hyper_parallel/core/shard/ops/yaml/torch_flash_attention_score.yaml create mode 100644 tests/mindspore/ut/parallel_ops_infer/test_parallel_npu_flash_attention_score.py create mode 100644 tests/torch/shard/ops/parallel_op_flash_attention_score.py create mode 100644 tests/torch/shard/ops/test_parallel_flash_attention_score.py diff --git a/hyper_parallel/core/shard/_op_dispatch.py b/hyper_parallel/core/shard/_op_dispatch.py index 78953cd..9f59827 100644 --- a/hyper_parallel/core/shard/_op_dispatch.py +++ b/hyper_parallel/core/shard/_op_dispatch.py @@ -185,20 +185,16 @@ class OpDispatcher: else: raise - def _with_layout_infer(self, func: callable, *args, **kwargs) -> Tensor: - """_with_layout_infer""" - func_name = platform.get_op_name(func) - packed_call = None - # Ops in unpack_ops use packed fallback args (e.g. ScatterUpdate: (prim_obj, op_name: str, (input_x, indices, updates))). - if(func_name in self.unpack_ops and len(args) == 3 and - isinstance(args[1], str) and isinstance(args[2],(tuple,list))): - packed_call = (args[0], args[1]) - args = tuple(args[2]) - - cache_key = LayoutCacheKey([]) + def _process_args_and_kwargs( + self, args, kwargs, cache_key: "LayoutCacheKey" + ) -> tuple[list, list, list, dict]: + """_process_args_and_kwargs""" + # input_layouts contain prarmeters which have layout, extra_args contain other parameters input_layouts = [] extra_args = [] + # input_args are position prarmeters, input_kwargs are keyword parameters input_args = [] + input_kwargs = kwargs.copy() # Normal ops pass real inputs directly (e.g. SumExt: args = (dtensor, axis: list, keep_dims: bool, dtype: None)). 
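+        # Positional args without a `_layout` attribute are treated as scalars/plain tensors:
+        # they are keyed into the layout cache, collected in extra_args, and forwarded unchanged.
+        # Args carrying a `_layout` contribute it to input_layouts; DTensor inputs are
+        # unwrapped with to_local() so the underlying op only ever sees local tensors.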
for arg in args: @@ -216,7 +212,7 @@ class OpDispatcher: input_layouts.append(None) input_args.append(arg) else: - layout = arg.layout + layout = arg._layout layout_id = layout.compact_str cache_key.layout_ids.append(str(layout_id)) input_layouts.append(layout) @@ -225,6 +221,41 @@ class OpDispatcher: else: input_args.append(arg) + for k, val in kwargs.items(): + if val is None: + input_layouts.append(None) + continue + if not hasattr(val, "_layout"): + id_str = "scalar" + if not isinstance(val, Tensor): + id_str = str(val) + cache_key.layout_ids.append(id_str) + extra_args.append(val) + input_layouts.append(None) + else: + layout = val._layout + layout_id = layout.compact_str + cache_key.layout_ids.append(str(layout_id)) + input_layouts.append(layout) + if isinstance(val, DTensor): + input_kwargs[k] = val.to_local() + + return input_layouts, extra_args, input_args, input_kwargs + + def _with_layout_infer(self, func: callable, *args, **kwargs) -> Tensor: + """_with_layout_infer""" + func_name = platform.get_op_name(func) + packed_call = None + # Ops in unpack_ops use packed fallback args (e.g. ScatterUpdate: (prim_obj, op_name: str, (input_x, indices, updates))). + if(func_name in self.unpack_ops and len(args) == 3 and + isinstance(args[1], str) and isinstance(args[2],(tuple,list))): + packed_call = (args[0], args[1]) + args = tuple(args[2]) + + cache_key = LayoutCacheKey([]) + input_layouts, extra_args, input_args, input_kwargs = self._process_args_and_kwargs( + args, kwargs, cache_key + ) cache_manager = LayoutCacheManager.get_instance() layout_cache = cache_manager.get_layout_cache() if func_name not in layout_cache: @@ -238,23 +269,25 @@ class OpDispatcher: else: all_args = (input_layouts, extra_args) output_layout = distribute_op.infer_layout(*all_args) - op_impl = distribute_op.get_expand_impl(func, output_layout, input_layouts, extra_args) + op_impl = getattr( + distribute_op, "get_expand_impl", lambda *args, **kwargs: None + )(func, output_layout, input_layouts, extra_args) op_layout_cache[cache_key] = (output_layout, op_impl) if op_impl is None: op_impl = func if packed_call is not None: - py_output = op_impl(packed_call[0], packed_call[1], tuple(input_args), **kwargs) + py_output = op_impl(packed_call[0], packed_call[1], tuple(input_args), **input_kwargs) else: - py_output = op_impl(*input_args, **kwargs) + py_output = op_impl(*input_args, **input_kwargs) if isinstance(py_output, (tuple, list)): output = () if isinstance(output_layout, (tuple, list)): if len(py_output) == len(output_layout): for i, output_item in enumerate(py_output): - output += (DTensor.from_local(output_item, output_layout[i]),) + output += (DTensor.from_local(output_item, output_layout[i].mesh, output_layout[i].placements),) else: raise RuntimeError(f"Output tuple size ({len(py_output)}) " f"does not match layout tuple size ({len(output_layout)})") @@ -262,7 +295,7 @@ class OpDispatcher: raise RuntimeError("Output is a tuple but layout is not") return output - return DTensor.from_local(py_output, output_layout) + return DTensor.from_local(py_output, output_layout.mesh, output_layout.placements) def _with_layout_infer_with_tuple_expand(self, func: callable, *args, **kwargs) -> Tensor: """_with_layout_infer_with_tuple_expand""" @@ -294,7 +327,7 @@ class OpDispatcher: extra_args.append(arg) input_layouts.append(None) else: - layout = arg.layout + layout = arg._layout layout_id = layout.compact_str cache_key.layout_ids.append(str(layout_id)) input_layouts.append(layout) @@ -326,7 +359,7 @@ class OpDispatcher: if 
isinstance(output_layout, (tuple, list)): if len(py_output) == len(output_layout): for i, output_item in enumerate(py_output): - output += (DTensor.from_local(output_item, output_layout[i]),) + output += (DTensor.from_local(output_item, output_layout[i].mesh, output_layout[i].placements),) else: raise RuntimeError(f"Output tuple size ({len(py_output)}) " f"does not match layout tuple size ({len(output_layout)})") @@ -334,7 +367,7 @@ class OpDispatcher: raise RuntimeError("Output is a tuple but layout is not") return output - return DTensor.from_local(py_output, output_layout) + return DTensor.from_local(py_output, output_layout.mesh, output_layout.placements) def _with_layout_infer_reshape(self, func: callable, *args) -> Tensor: """_with_layout_infer_reshape""" @@ -343,7 +376,7 @@ class OpDispatcher: cache_key = LayoutCacheKey([]) input_layouts = [] - layout = input_tensor.layout + layout = input_tensor._layout input_layouts.append(layout) layout_id = layout.compact_str cache_key.layout_ids.append(str(layout_id)) @@ -381,25 +414,17 @@ class OpDispatcher: py_output = op_impl(input_tensor.to_local(), local_shape) - return DTensor.from_local(py_output, infer_output_tuple[0]) + return DTensor.from_local(py_output, infer_output_tuple[0].mesh, infer_output_tuple[0].placements) - def _with_layout_infer_with_shape(self, func: callable, *args, **kwargs) -> Tensor: - """_with_layout_infer_with_shape""" - func_name = platform.get_op_name(func) - packed_call = None - # Packed fallback args for some ops (e.g. Mod: (prim_obj, "Mod", (x, y))). - if (func_name in self.unpack_ops and len(args) == 3 and - isinstance(args[1], str) and isinstance(args[2], (tuple, list))): - packed_call = (args[0], args[1]) - args = tuple(args[2]) - - cache_key = LayoutCacheKey([]) + def _process_args_and_kwargs_with_shape( + self, args, kwargs, cache_key: "LayoutCacheKey" + ) -> tuple[list, list, list, list, dict]: + """_process_args_and_kwargs_with_shape""" input_layouts = [] extra_args = [] input_shapes = [] input_args = [] - - # Normal ops pass real inputs directly (e.g. GreaterEqual: (dtensor_x, dtensor_y)). + input_kwargs = kwargs.copy() for arg in args: if arg is None: input_layouts.append(None) @@ -416,7 +441,7 @@ class OpDispatcher: input_layouts.append(None) input_args.append(arg) else: - layout = arg.layout + layout = arg._layout layout_id = layout.compact_str cache_key.layout_ids.append(str(layout_id)) input_layouts.append(layout) @@ -432,6 +457,48 @@ class OpDispatcher: input_shapes.append(input_shape) cache_key.layout_ids.append(str(input_shape)) + for k, val in kwargs.items(): + if val is None: + input_layouts.append(None) + continue + if not hasattr(val, "_layout"): + id_str = "scalar" + if not isinstance(val, Tensor): + id_str = str(val) + cache_key.layout_ids.append(id_str) + extra_args.append(val) + input_layouts.append(None) + else: + layout = val._layout + layout_id = layout.compact_str + cache_key.layout_ids.append(str(layout_id)) + input_layouts.append(layout) + if isinstance(val, DTensor): + input_kwargs[k] = val.to_local() + + if not hasattr(val, "shape"): + input_shapes.append(None) + else: + input_shape = val.shape + input_shapes.append(input_shape) + cache_key.layout_ids.append(str(input_shape)) + + return input_layouts, input_shapes, extra_args, input_args, input_kwargs + + def _with_layout_infer_with_shape(self, func: callable, *args, **kwargs) -> Tensor: + """_with_layout_infer_with_shape""" + func_name = platform.get_op_name(func) + packed_call = None + # Packed fallback args for some ops (e.g. 
Mod: (prim_obj, "Mod", (x, y))). + if (func_name in self.unpack_ops and len(args) == 3 and + isinstance(args[1], str) and isinstance(args[2], (tuple, list))): + packed_call = (args[0], args[1]) + args = tuple(args[2]) + + cache_key = LayoutCacheKey([]) + input_layouts, input_shapes, extra_args, input_args, input_kwargs = \ + self._process_args_and_kwargs_with_shape(args, kwargs, cache_key) + cache_manager = LayoutCacheManager.get_instance() layout_cache = cache_manager.get_layout_cache() if func_name not in layout_cache: @@ -453,9 +520,9 @@ class OpDispatcher: op_impl = func if packed_call is not None: - py_output = op_impl(packed_call[0], packed_call[1], tuple(input_args), **kwargs) + py_output = op_impl(packed_call[0], packed_call[1], tuple(input_args), **input_kwargs) else: - py_output = op_impl(*input_args, **kwargs) + py_output = op_impl(*input_args, **input_kwargs) # 设置输出布局 if isinstance(py_output, (tuple, list)): @@ -463,7 +530,7 @@ class OpDispatcher: if isinstance(output_layout, (tuple, list)): if len(py_output) == len(output_layout): for i, output_item in enumerate(py_output): - output += (DTensor.from_local(output_item, output_layout[i]),) + output += (DTensor.from_local(output_item, output_layout[i].mesh, output_layout[i].placements),) else: raise RuntimeError(f"Output tuple size ({len(py_output)}) " f"does not match layout tuple size ({len(output_layout)})") @@ -471,7 +538,7 @@ class OpDispatcher: raise RuntimeError("Output is a tuple but layout is not") return output - return DTensor.from_local(py_output, output_layout) + return DTensor.from_local(py_output, output_layout.mesh, output_layout.placements) def _with_layout_infer_slice(self, func: callable, *args) -> Tensor: """_with_layout_infer_slice""" @@ -483,7 +550,7 @@ class OpDispatcher: cache_key = LayoutCacheKey([]) input_layouts = [] - layout = input_tensor.layout + layout = input_tensor._layout global_shape = input_tensor.shape input_layouts.append(layout) layout_id = layout.compact_str @@ -523,7 +590,25 @@ class OpDispatcher: py_output = op_impl(input_tensor.to_local(), new_begin, new_end) - return DTensor.from_local(py_output, infer_output_tuple[0]) + return DTensor.from_local(py_output, infer_output_tuple[0].mesh, infer_output_tuple[0].placements) + + def _merge_default(self, config: dict): + """Apply __default__ values to all ops in this YAML file.""" + if "__default__" not in config: + return config + + default_cfg = config["__default__"] + merged = {} + + for op_name, op_cfg in config.items(): + if op_name == "__default__": + continue + + new_cfg = default_cfg.copy() + new_cfg.update(op_cfg) + merged[op_name] = new_cfg + + return merged def safe_load_yaml_from_dir(self): """ @@ -537,6 +622,8 @@ class OpDispatcher: for yaml_file_path in glob.glob(os.path.join(yaml_path, '*.yaml')): with open(yaml_file_path, 'r', encoding="utf-8") as f: yaml_data = yaml.safe_load(f) + + yaml_data = self._merge_default(yaml_data) for name, data in yaml_data.items(): if name in yaml_dict: raise ValueError(f"Duplicate yaml object with name '{name}'.") diff --git a/hyper_parallel/core/shard/ops/parallel_npu_flash_attention_score.py b/hyper_parallel/core/shard/ops/parallel_npu_flash_attention_score.py new file mode 100644 index 0000000..c313d03 --- /dev/null +++ b/hyper_parallel/core/shard/ops/parallel_npu_flash_attention_score.py @@ -0,0 +1,1090 @@ +# Copyright 2026 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""FlashAttentionScore Distributed Operator""" + +import copy +import warnings +import torch +from typing import List, Tuple, Optional, Any + +from hyper_parallel.core.layout import Layout +from hyper_parallel.core.placement_types import Shard, Replicate +from hyper_parallel.core.shard.ops.parallel_ops_register import register_distributed_op +from hyper_parallel.platform import get_platform + +platform = get_platform() +Tensor = platform.Tensor + +SPARSE_DEFAULT_MASK = 0 +SPARSE_ALL_MASK = 1 +SPARSE_LEFT_UP_CAUSAL = 2 +SPARSE_RIGHT_DOWN_CAUSAL = 3 +SPARSE_BAND = 4 + +LEFT_UP_TO_LEFT_UP = 0 +LEFT_UP_TO_RIGHT_DOWN = 1 +RIGHT_DOWN_TO_RIGHT_DOWN = 2 + +SPARSE_MODE_UPDATE_MAP = { + SPARSE_DEFAULT_MASK: LEFT_UP_TO_LEFT_UP, + SPARSE_ALL_MASK: LEFT_UP_TO_LEFT_UP, + SPARSE_LEFT_UP_CAUSAL: LEFT_UP_TO_RIGHT_DOWN, + SPARSE_RIGHT_DOWN_CAUSAL: RIGHT_DOWN_TO_RIGHT_DOWN, + SPARSE_BAND: RIGHT_DOWN_TO_RIGHT_DOWN, +} + + +class FlashAttentionScoreDistributedOp: + """Distributed operator for torch_npu.npu_fusion_attention.""" + + def __init__(self, op_name: str): + self.op_name = op_name + register_distributed_op(op_name, self) + + self._layout_dims = { + "BSH": {"batch": 0, "seq": 1, "hidden": 2}, + "BNSD": {"batch": 0, "head": 1, "seq": 2, "dim": 3}, + "SBH": {"seq": 0, "batch": 1, "hidden": 2}, + "BSND": {"batch": 0, "seq": 1, "head": 2, "dim": 3}, + "TND": {"total": 0, "head": 1, "dim": 2}, + } + + def _tensor_map_to_placements(self, base_layout: Layout, tensor_map: tuple) -> tuple: + """Convert tensor_map to placements.""" + mesh_ndim = len(base_layout.mesh_shape) + placements = [] + + for mesh_dim_idx in range(mesh_ndim): + is_sharded = False + + for tensor_dim_idx, tensor_dim_map in enumerate(tensor_map): + if tensor_dim_map == -1: + continue + + if isinstance(tensor_dim_map, tuple): + if mesh_dim_idx in tensor_dim_map: + placements.append(Shard(tensor_dim_idx)) + is_sharded = True + break + elif tensor_dim_map == mesh_dim_idx: + placements.append(Shard(tensor_dim_idx)) + is_sharded = True + break + + if not is_sharded: + placements.append(Replicate()) + + return tuple(placements) + + def _is_dynamic_shape(self, tensor: Tensor, dim: int) -> bool: + """Check if tensor has dynamic shape at given dimension.""" + try: + shape_val = tensor.shape[dim] + if isinstance(shape_val, int): + return shape_val == -1 + return not isinstance(shape_val, int) + except (IndexError, AttributeError): + return False + + def _get_dynamic_shape_info( + self, + query: Tensor, + key: Tensor, + input_layout: str + ) -> dict: + """Get dynamic shape information for query and key tensors.""" + dims = self._layout_dims.get(input_layout, {}) + + seq_dim_idx = None + if 'seq' in dims: + seq_dim_idx = dims['seq'] + elif 'total' in dims: + seq_dim_idx = dims['total'] + + if seq_dim_idx is None: + return {'is_dynamic': False} + + q_is_dynamic = self._is_dynamic_shape(query, seq_dim_idx) + kv_is_dynamic = self._is_dynamic_shape(key, seq_dim_idx) + + return { + 'is_dynamic': q_is_dynamic or kv_is_dynamic, + 'q_seq_dim': seq_dim_idx, + 'kv_seq_dim': 
seq_dim_idx, + 'q_batch_dim': dims.get('batch', dims.get('total')), + } + + def _is_attn_mask_compressed(self, sparse_mode: int) -> bool: + """Check if attention mask is compressed for given sparse mode.""" + return sparse_mode in ( + SPARSE_LEFT_UP_CAUSAL, + SPARSE_RIGHT_DOWN_CAUSAL, + SPARSE_BAND, + ) + + def _compute_sparse_params( + self, + sparse_mode: int, + pre_tockens: int, + next_tockens: int, + split_id: int, + split_num: int, + local_q_len: int, + global_q_len: int, + global_kv_len: int, + ) -> Tuple[int, int, int]: + """Calculate adjusted sparse parameters for static shape.""" + if sparse_mode not in SPARSE_MODE_UPDATE_MAP: + return sparse_mode, pre_tockens, next_tockens + + if sparse_mode == SPARSE_ALL_MASK: + return sparse_mode, pre_tockens, next_tockens + + if sparse_mode in (SPARSE_DEFAULT_MASK, SPARSE_BAND): + new_pre_tockens = pre_tockens + new_next_tockens = next_tockens + else: + new_pre_tockens = global_kv_len + new_next_tockens = 0 + + new_sparse_mode = SPARSE_BAND if sparse_mode != SPARSE_DEFAULT_MASK else sparse_mode + update_mode = SPARSE_MODE_UPDATE_MAP[sparse_mode] + + if update_mode == LEFT_UP_TO_LEFT_UP: + new_pre_tockens += -split_id * local_q_len + new_next_tockens += split_id * local_q_len + elif update_mode == LEFT_UP_TO_RIGHT_DOWN: + offset = global_kv_len - (split_id + 1) * local_q_len + new_pre_tockens += offset + new_next_tockens += -offset + elif update_mode == RIGHT_DOWN_TO_RIGHT_DOWN: + offset = (split_num - split_id - 1) * local_q_len + new_pre_tockens += offset + new_next_tockens += -offset + + return new_sparse_mode, new_pre_tockens, new_next_tockens + + def _compute_sparse_params_dynamic( + self, + query: Tensor, + key: Tensor, + sparse_mode: int, + pre_tockens: int, + next_tockens: int, + split_id: int, + split_num: int, + seq_dim_idx: int, + ) -> Tuple: + """Calculate adjusted sparse parameters for dynamic shape.""" + if sparse_mode not in SPARSE_MODE_UPDATE_MAP: + return sparse_mode, pre_tockens, next_tockens + + if sparse_mode == SPARSE_ALL_MASK: + return sparse_mode, pre_tockens, next_tockens + + query_seq_length = query.shape[seq_dim_idx] + key_seq_length = key.shape[seq_dim_idx] + + local_q_len_symbolic = query_seq_length // split_num + + if sparse_mode in (SPARSE_DEFAULT_MASK, SPARSE_BAND): + new_pre_tockens = pre_tockens + new_next_tockens = next_tockens + else: + new_pre_tockens = key_seq_length + new_next_tockens = 0 + + new_sparse_mode = SPARSE_BAND if sparse_mode != SPARSE_DEFAULT_MASK else sparse_mode + update_mode = SPARSE_MODE_UPDATE_MAP[sparse_mode] + + if update_mode == LEFT_UP_TO_LEFT_UP: + offset = -split_id * local_q_len_symbolic + new_pre_tockens = new_pre_tockens + offset + new_next_tockens = new_next_tockens - offset + elif update_mode == LEFT_UP_TO_RIGHT_DOWN: + offset = key_seq_length - (split_id + 1) * local_q_len_symbolic + new_pre_tockens = new_pre_tockens + offset + new_next_tockens = new_next_tockens - offset + elif update_mode == RIGHT_DOWN_TO_RIGHT_DOWN: + offset = (split_num - split_id - 1) * local_q_len_symbolic + new_pre_tockens = new_pre_tockens + offset + new_next_tockens = new_next_tockens - offset + + return new_sparse_mode, new_pre_tockens, new_next_tockens + + def _adjust_actual_seq_len_for_tnd_cp( + self, + query: Tensor, + key: Tensor, + actual_seq_qlen: List[int], + actual_seq_kvlen: List[int], + split_id: int, + kv_is_sharded: bool, + ) -> Tuple[List[int], List[int]]: + """Adjust actual_seq_qlen and actual_seq_kvlen for TND layout with context parallel.""" + slice_tq = query.shape[0] + slice_tk = 
key.shape[0] + + is_dynamic = self._is_dynamic_shape(query, 0) or self._is_dynamic_shape(key, 0) + + if is_dynamic: + return self._adjust_actual_seq_len_dynamic( + slice_tq, slice_tk, + actual_seq_qlen, actual_seq_kvlen, + split_id, kv_is_sharded, + query.device, + ) + else: + return self._adjust_actual_seq_len_static( + slice_tq, slice_tk, + actual_seq_qlen, actual_seq_kvlen, + split_id, kv_is_sharded, + query.device, + ) + + def _adjust_actual_seq_len_static( + self, + slice_tq: int, + slice_tk: int, + actual_seq_qlen: List[int], + actual_seq_kvlen: List[int], + split_id: int, + kv_is_sharded: bool, + device: torch.device = None, + ) -> Tuple[List[int], List[int]]: + """Adjust actual_seq_len for static shape.""" + if device is None: + device = torch.device("cpu") + + offset_q = slice_tq * split_id + + actual_seq_qlen_tensor = torch.tensor(actual_seq_qlen, dtype=torch.int64, device=device) + actual_seq_kvlen_tensor = torch.tensor(actual_seq_kvlen, dtype=torch.int64, device=device) + + qlen_offset = actual_seq_qlen_tensor - offset_q + new_actual_seq_qlen = torch.clamp(qlen_offset, min=0, max=slice_tq) + + if kv_is_sharded: + offset_kv = slice_tk * split_id + kvlen_offset = actual_seq_kvlen_tensor - offset_kv + new_actual_seq_kvlen = torch.clamp(kvlen_offset, min=0, max=slice_tk) + else: + relu_result = torch.relu(qlen_offset.float()).long() + kvlen_offset = relu_result - new_actual_seq_qlen + new_actual_seq_kvlen = actual_seq_kvlen_tensor - kvlen_offset + + if len(new_actual_seq_kvlen) > 0: + last_idx = len(new_actual_seq_kvlen) - 1 + if actual_seq_kvlen_tensor[last_idx].item() == slice_tk: + new_actual_seq_kvlen[last_idx] = slice_tk + + return new_actual_seq_qlen.tolist(), new_actual_seq_kvlen.tolist() + + def _adjust_actual_seq_len_dynamic( + self, + slice_tq, + slice_tk, + actual_seq_qlen: List[int], + actual_seq_kvlen: List[int], + split_id: int, + kv_is_sharded: bool, + device: torch.device = None, + ) -> Tuple[List[int], List[int]]: + """Adjust actual_seq_len for dynamic shape.""" + if device is None: + device = torch.device("cpu") + + offset_q = slice_tq * split_id + + actual_seq_qlen_tensor = torch.tensor(actual_seq_qlen, dtype=torch.int64, device=device) + actual_seq_kvlen_tensor = torch.tensor(actual_seq_kvlen, dtype=torch.int64, device=device) + + qlen_offset = actual_seq_qlen_tensor - offset_q + new_actual_seq_qlen = torch.clamp(qlen_offset, min=0, max=slice_tq) + + if kv_is_sharded: + offset_kv = slice_tk * split_id + kvlen_offset = actual_seq_kvlen_tensor - offset_kv + new_actual_seq_kvlen = torch.clamp(kvlen_offset, min=0, max=slice_tk) + else: + relu_result = torch.relu(qlen_offset.float()).long() + kvlen_offset = relu_result - new_actual_seq_qlen + new_actual_seq_kvlen = actual_seq_kvlen_tensor - kvlen_offset + + if len(new_actual_seq_kvlen) > 0: + last_idx = len(new_actual_seq_kvlen) - 1 + mask = (actual_seq_kvlen_tensor[last_idx] == slice_tk) + new_actual_seq_kvlen[last_idx] = torch.where( + mask, + slice_tk, + new_actual_seq_kvlen[last_idx] + ) + + return new_actual_seq_qlen.tolist(), new_actual_seq_kvlen.tolist() + + def _validate_atten_mask( + self, + atten_mask: Optional[Tensor], + sparse_mode: int, + input_layout: str, + is_varlen: bool = False + ) -> None: + """Validate attention mask shape and configuration for given sparse mode.""" + if atten_mask is None: + if sparse_mode == SPARSE_ALL_MASK: + raise ValueError( + "sparse_mode=1 (allMask) requires atten_mask to be provided" + ) + return + + mask_shape = atten_mask.shape + + if len(mask_shape) not in (2, 4): + 
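+            # Only 2-D masks (e.g. (maxSq, maxSkv) in the varlen case) and 4-D masks are accepted;
+            # compressed-mask sparse modes additionally expect a (2048, 2048) mask, checked below.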
raise ValueError( + f"atten_mask only supports 2D or 4D format, but got {len(mask_shape)}D" + ) + + if is_varlen: + if len(mask_shape) != 2: + raise ValueError( + f"Varlen scenario only supports 2D atten_mask (maxSq, maxSkv), " + f"but got {len(mask_shape)}D" + ) + + if self._is_attn_mask_compressed(sparse_mode): + expected_shape = (2048, 2048) + if mask_shape[-2:] != expected_shape: + warnings.warn( + f"sparse_mode={sparse_mode} uses compressed mask, " + f"expected shape {expected_shape} but got {mask_shape[-2:]}" + ) + + def _validate_pse_configuration( + self, + pse: Optional[Tensor], + sparse_mode: int + ) -> None: + """Validate PSE (positional encoding) configuration.""" + if pse is None: + return + + pse_shape = pse.shape + + if len(pse_shape) not in (3, 4): + raise ValueError( + f"PSE only supports 3D or 4D format, but got {len(pse_shape)}D" + ) + + if len(pse_shape) == 4 and pse_shape[2] == 1024: + warnings.warn("Detected Alibi positional encoding compression scenario") + + def infer_layout( + self, input_layouts: List[Optional[Layout]], extra_args: List[Any] + ) -> Tuple[Layout, ...]: + """Infer output layouts.""" + query_layout = input_layouts[0] + if query_layout is None: + raise ValueError("Query layout cannot be None") + + attention_out_layout = copy.deepcopy(query_layout) + if attention_out_layout.placements is None and attention_out_layout.tensor_map is not None: + attention_out_placements = self._tensor_map_to_placements( + attention_out_layout, attention_out_layout.tensor_map + ) + attention_out_layout.set_placements(attention_out_placements) + + input_layout_str = None + softmax_layout_param = "" + + if len(extra_args) >= 2: + input_layout_str = extra_args[1] + if not isinstance(input_layout_str, str): + input_layout_str = None + + if len(extra_args) >= 20: + softmax_layout_param = extra_args[19] if extra_args[19] else "" + + if input_layout_str and input_layout_str in self._layout_dims: + softmax_layout = self._infer_softmax_layout_by_input_layout( + query_layout, input_layout_str, softmax_layout_param + ) + else: + if input_layout_str is None: + raise ValueError( + "Missing required parameter 'input_layout' in extra explicit input_layout specification." 
+ ) + + if input_layout_str not in self._layout_dims: + raise ValueError( + f"Unsupported input_layout: '{input_layout_str}'.\n" + f"Supported layouts: {list(self._layout_dims.keys())}" + ) + + softmax_layout = self._infer_softmax_layout_conservatively(query_layout) + + softmax_max_layout = softmax_layout + softmax_sum_layout = copy.deepcopy(softmax_layout) + softmax_out_layout = self._create_replicated_scalar_layout(query_layout) + if softmax_out_layout.placements is None and softmax_out_layout.tensor_map is not None: + softmax_out_placements = self._tensor_map_to_placements( + softmax_out_layout, softmax_out_layout.tensor_map + ) + softmax_out_layout.set_placements(softmax_out_placements) + + return ( + attention_out_layout, + softmax_max_layout, + softmax_sum_layout, + softmax_out_layout, + ) + + def _infer_softmax_layout_conservatively(self, query_layout: Layout) -> Layout: + """Conservative fallback for softmax layout inference.""" + softmax_layout = Layout.from_device_mesh(query_layout.mesh) + query_tm = query_layout.tensor_map + + if query_tm is None or len(query_tm) == 0: + softmax_tensor_map = (-1, -1, -1, -1) + else: + softmax_tm = [ + query_tm[0] if len(query_tm) > 0 else -1, + -1, + -1, + -1, + ] + softmax_tensor_map = tuple(softmax_tm) + + warnings.warn( + f"Using conservative softmax layout inference due to missing/invalid input_layout.\n" + f"Query tensor_map: {query_tm}\n" + f"Inferred softmax tensor_map: {softmax_tensor_map}\n" + f"This may not be optimal. Please provide explicit input_layout parameter." + ) + + softmax_layout.set_tensor_map(softmax_tensor_map) + softmax_placements = self._tensor_map_to_placements(softmax_layout, softmax_tensor_map) + softmax_layout.set_placements(softmax_placements) + + return softmax_layout + + def _infer_softmax_layout_by_input_layout( + self, + query_layout: Layout, + input_layout_str: str, + softmax_layout_param: str = "" + ) -> Layout: + """Infer softmax layout based on input_layout and softmax_layout parameter.""" + query_split_info = self._get_split_info(query_layout, input_layout_str) + + softmax_tensor_map = self._build_softmax_tensor_map( + query_layout, input_layout_str, query_split_info, softmax_layout_param + ) + + softmax_layout = Layout.from_device_mesh(query_layout.mesh) + softmax_layout.set_tensor_map(softmax_tensor_map) + softmax_placements = self._tensor_map_to_placements(softmax_layout, softmax_tensor_map) + softmax_layout.set_placements(softmax_placements) + + return softmax_layout + + def _build_softmax_tensor_map( + self, + query_layout: Layout, + input_layout_str: str, + query_split_info: dict, + softmax_layout_param: str = "" + ) -> tuple: + """Build softmax tensor_map.""" + dims = self._layout_dims.get(input_layout_str, {}) + query_tm = query_layout.tensor_map + + if query_tm is None: + return (-1, -1, -1, -1) + + softmax_tm = [-1, -1, -1, -1] + + if input_layout_str == "TND": + softmax_tm[0] = query_tm[0] if len(query_tm) > 0 else -1 + softmax_tm[1] = query_tm[1] if len(query_tm) > 1 else -1 + else: + if "batch" in dims: + batch_idx = dims["batch"] + if batch_idx < len(query_tm): + softmax_tm[0] = query_tm[batch_idx] + + if "head" in dims: + head_idx = dims["head"] + if head_idx < len(query_tm): + softmax_tm[1] = query_tm[head_idx] + elif "hidden" in dims: + hidden_idx = dims["hidden"] + if hidden_idx < len(query_tm): + softmax_tm[1] = query_tm[hidden_idx] + + if "seq" in dims: + seq_idx = dims["seq"] + if seq_idx < len(query_tm): + softmax_tm[2] = query_tm[seq_idx] + + softmax_tm[3] = -1 + + return 
tuple(softmax_tm) + + def _create_default_softmax_layout(self, query_layout: Layout) -> Layout: + """Create default softmax layout.""" + softmax_layout = Layout.from_device_mesh(query_layout.mesh) + softmax_layout.set_tensor_map((-1, -1, -1, -1)) + return softmax_layout + + def _validate_sharding_consistency( + self, + query_layout: Layout, + key_layout: Optional[Layout], + input_layout: str + ): + """Validate Q/K/V sharding consistency.""" + if key_layout is None or not hasattr(key_layout, 'tensor_map'): + return + + dims = self._layout_dims.get(input_layout, {}) + q_tm = query_layout.tensor_map + k_tm = key_layout.tensor_map + + if q_tm is None or k_tm is None: + return + + self._check_batch_consistency(dims, q_tm, k_tm, input_layout) + self._check_hidden_consistency(dims, q_tm, k_tm, input_layout) + self._check_dim_consistency(dims, q_tm, k_tm, input_layout) + + def _check_batch_consistency(self, dims, q_tm, k_tm, input_layout): + """Check batch dimension sharding consistency.""" + if "batch" not in dims: + return + + batch_idx = dims["batch"] + if batch_idx >= len(q_tm) or batch_idx >= len(k_tm): + return + + q_batch_shard = self._normalize_dim_map(q_tm[batch_idx]) + k_batch_shard = self._normalize_dim_map(k_tm[batch_idx]) + + if q_batch_shard != k_batch_shard: + raise ValueError( + f"Query and Key/Value must have identical batch sharding strategy.\n" + f"Input layout: {input_layout}\n" + f"Query batch sharding (dim {batch_idx}): {q_batch_shard}\n" + f"Key/Value batch sharding (dim {batch_idx}): {k_batch_shard}\n" + f"Query tensor_map: {q_tm}\n" + f"Key tensor_map: {k_tm}" + ) + + def _check_hidden_consistency(self, dims, q_tm, k_tm, input_layout): + """Check hidden dimension sharding consistency.""" + if "hidden" not in dims: + return + + hidden_idx = dims["hidden"] + if hidden_idx >= len(q_tm) or hidden_idx >= len(k_tm): + return + + q_hidden_shard = self._normalize_dim_map(q_tm[hidden_idx]) + k_hidden_shard = self._normalize_dim_map(k_tm[hidden_idx]) + + if q_hidden_shard != k_hidden_shard: + raise ValueError( + f"Query and Key/Value must have identical hidden sharding strategy.\n" + f"Input layout: {input_layout}\n" + f"Query hidden sharding (dim {hidden_idx}): {q_hidden_shard}\n" + f"Key/Value hidden sharding (dim {hidden_idx}): {k_hidden_shard}\n" + f"Query tensor_map: {q_tm}\n" + f"Key tensor_map: {k_tm}\n" + f"Note: This checks sharding strategy, not tensor size.\n" + f"GQA (different head counts) is supported when sharding strategies match." 
+ ) + + def _check_dim_consistency(self, dims, q_tm, k_tm, input_layout): + """Check dim dimension sharding consistency.""" + if "dim" not in dims: + return + + dim_idx = dims["dim"] + if dim_idx >= len(q_tm) or dim_idx >= len(k_tm): + return + + q_dim_shard = self._normalize_dim_map(q_tm[dim_idx]) + k_dim_shard = self._normalize_dim_map(k_tm[dim_idx]) + + if q_dim_shard != k_dim_shard: + raise ValueError( + f"Query and Key/Value must have identical dim sharding strategy.\n" + f"Input layout: {input_layout}\n" + f"Query dim sharding (dim {dim_idx}): {q_dim_shard}\n" + f"Key/Value dim sharding (dim {dim_idx}): {k_dim_shard}\n" + f"Query tensor_map: {q_tm}\n" + f"Key tensor_map: {k_tm}" + ) + + def _check_seq_sharding_compatibility( + self, + query_layout: Layout, + key_layout: Optional[Layout], + input_layout: str, + seq_dim_idx: int, + seq_split_num: int, + kv_seq_split_num: int + ): + """Check sequence dimension sharding compatibility.""" + if key_layout is None: + return + + q_tm = query_layout.tensor_map + k_tm = key_layout.tensor_map + + if q_tm is None or k_tm is None: + return + + if seq_dim_idx >= len(q_tm) or seq_dim_idx >= len(k_tm): + return + + if input_layout != "TND" and kv_seq_split_num > 1: + raise NotImplementedError( + f"KV sequence sharding is not supported for layout '{input_layout}' " + f"without Ring Attention.\n" + f"Query sequence split num: {seq_split_num}\n" + f"Key/Value sequence split num: {kv_seq_split_num}\n" + f"Supported scenarios:\n" + f" - Query sequence sharding + KV not sharded (Ulysses-style)\n" + f" - Query and KV both not sharded\n" + f"Unsupported scenario:\n" + f" - KV sequence sharding (requires Ring Attention)" + ) + + q_seq_shard = self._normalize_dim_map(q_tm[seq_dim_idx]) + k_seq_shard = self._normalize_dim_map(k_tm[seq_dim_idx]) + + if q_seq_shard != k_seq_shard: + if input_layout == "TND": + pass + elif kv_seq_split_num > 1: + raise NotImplementedError( + f"Ring Attention (KV sequence sharding with different strategy) " + f"is not supported.\n" + f"Input layout: {input_layout}\n" + f"Query sequence split num: {seq_split_num}\n" + f"Key/Value sequence split num: {kv_seq_split_num}\n" + f"Query seq sharding strategy: {q_seq_shard}\n" + f"Key/Value seq sharding strategy: {k_seq_shard}\n" + f"Supported scenarios:\n" + f" - Query sequence sharding + KV not sharded (Ulysses-style)\n" + f" - Query and KV both sharded with SAME strategy\n" + f" - Query and KV both not sharded\n" + f"Unsupported scenario:\n" + f" - KV sequence sharding with DIFFERENT strategy (Ring Attention)" + ) + + def get_expand_impl( + self, + func: callable, + output_layouts: Tuple[Layout, ...], + input_layouts: List[Optional[Layout]], + extra_args: List[Any], + ) -> Optional[callable]: + """Create expanded implementation.""" + query_layout = input_layouts[0] + if query_layout is None: + return None + + if len(input_layouts) >= 3: + key_layout = input_layouts[1] + value_layout = input_layouts[2] + + if (key_layout is not None and value_layout is not None and + hasattr(key_layout, 'tensor_map') and hasattr(value_layout, 'tensor_map')): + if key_layout.tensor_map != value_layout.tensor_map: + raise ValueError( + f"Key and Value must have identical sharding strategies.\n" + f"Key tensor_map: {key_layout.tensor_map}\n" + f"Value tensor_map: {value_layout.tensor_map}" + ) + + def expanded_impl( + query, + key, + value, + head_num, + input_layout, + pse=None, + padding_mask=None, + atten_mask=None, + scale=1.0, + keep_prob=1.0, + pre_tockens=2147483647, + next_tockens=2147483647, + 
inner_precise=0, + prefix=None, + actual_seq_qlen=None, + actual_seq_kvlen=None, + sparse_mode=0, + gen_mask_parallel=True, + sync=False, + softmax_layout="", + sink=None, + ): + key_layout = input_layouts[1] + self._validate_sharding_consistency(query_layout, key_layout, input_layout) + + is_varlen = input_layout == "TND" and actual_seq_qlen is not None + self._validate_atten_mask(atten_mask, sparse_mode, input_layout, is_varlen) + self._validate_pse_configuration(pse, sparse_mode) + + split_info = self._get_split_info(query_layout, input_layout) + head_split_num = split_info["head"] + seq_split_num = split_info["seq"] + + if head_split_num == 1 and seq_split_num == 1: + result = func( + query, key, value, head_num, input_layout, + pse, padding_mask, atten_mask, scale, keep_prob, + pre_tockens, next_tockens, inner_precise, + prefix, actual_seq_qlen, actual_seq_kvlen, + sparse_mode, gen_mask_parallel, sync, + softmax_layout, sink, + ) + return result[:4] if isinstance(result, (tuple, list)) and len(result) >= 4 else result + + if head_split_num <= 0: + raise ValueError(f"Invalid head_split_num={head_split_num}") + if head_num % head_split_num != 0: + raise ValueError( + f"head_num({head_num}) not divisible by head_split_num({head_split_num})" + ) + adjusted_head_num = head_num // head_split_num + + adjusted_sparse_mode = sparse_mode + adjusted_pre_tockens = pre_tockens + adjusted_next_tockens = next_tockens + adjusted_actual_seq_qlen = actual_seq_qlen + adjusted_actual_seq_kvlen = actual_seq_kvlen + + if seq_split_num > 1: + dynamic_info = self._get_dynamic_shape_info(query, key, input_layout) + is_dynamic = dynamic_info.get('is_dynamic', False) + + split_id = self._get_split_id(query_layout, input_layout) + dims = self._layout_dims.get(input_layout, {}) + seq_dim_idx = self._get_seq_dim_idx(dims) + + if seq_dim_idx is None: + raise ValueError( + f"Cannot infer seq/total dim for input_layout={input_layout}" + ) + + kv_seq_split_num = 1 + if key_layout is not None: + kv_split_info = self._get_split_info(key_layout, input_layout) + kv_seq_split_num = kv_split_info["seq"] + + self._check_seq_sharding_compatibility( + query_layout, key_layout, input_layout, + seq_dim_idx, seq_split_num, kv_seq_split_num + ) + + if is_dynamic: + (adjusted_sparse_mode, + adjusted_pre_tockens, + adjusted_next_tockens) = self._compute_sparse_params_dynamic( + query, key, + sparse_mode, pre_tockens, next_tockens, + split_id, seq_split_num, seq_dim_idx, + ) + else: + local_q_len = query.shape[seq_dim_idx] + global_q_len = local_q_len * seq_split_num + local_kv_len = ( + key.shape[seq_dim_idx] + if hasattr(key, "shape") and len(key.shape) > seq_dim_idx + else local_q_len + ) + global_kv_len = local_kv_len * kv_seq_split_num + + (adjusted_sparse_mode, + adjusted_pre_tockens, + adjusted_next_tockens) = self._compute_sparse_params( + sparse_mode, pre_tockens, next_tockens, + split_id, seq_split_num, local_q_len, global_q_len, global_kv_len, + ) + + if input_layout == "TND": + batch_split_num, s1_split_num = self._calculate_tnd_split_params( + query_layout, key_layout, input_layout + ) + + if s1_split_num > 1: + if is_dynamic: + if sparse_mode != SPARSE_RIGHT_DOWN_CAUSAL: + raise ValueError( + f"TND layout with context parallelism " + f"(s1_split_num={s1_split_num} > 1) requires " + f"sparse_mode={SPARSE_RIGHT_DOWN_CAUSAL}, " + f"but got {sparse_mode}" + ) + else: + query_global_t = query.shape[0] * seq_split_num + key_global_t = key.shape[0] * kv_seq_split_num + self._validate_tnd_cp_requirements( + query_layout, 
key_layout, input_layout, + sparse_mode, + batch_split_num, s1_split_num, + (query_global_t, *query.shape[1:]), + (key_global_t, *key.shape[1:]) + ) + else: + if not is_dynamic and query.shape[0] != key.shape[0]: + raise ValueError( + f"TND layout with DP-only (s1_split_num=1) requires " + f"Query and Key to have the same local T-dimension, " + f"but got:\n" + f"Query local T-dim: {query.shape[0]}\n" + f"Key local T-dim: {key.shape[0]}\n" + f"batch_split_num: {batch_split_num}, " + f"s1_split_num: {s1_split_num}" + ) + + adjusted_sparse_mode = sparse_mode + adjusted_pre_tockens = pre_tockens + adjusted_next_tockens = next_tockens + + if actual_seq_qlen is None or actual_seq_kvlen is None: + raise ValueError( + "When using TND layout with sequence parallelism, " + "actual_seq_qlen and actual_seq_kvlen must be provided." + ) + + kv_is_sharded = (kv_seq_split_num > 1) + (adjusted_actual_seq_qlen, + adjusted_actual_seq_kvlen) = self._adjust_actual_seq_len_for_tnd_cp( + query, key, actual_seq_qlen, actual_seq_kvlen, + split_id, kv_is_sharded + ) + + result = func( + query, key, value, + adjusted_head_num, + input_layout, + pse, padding_mask, atten_mask, + scale, keep_prob, + adjusted_pre_tockens, + adjusted_next_tockens, + inner_precise, prefix, + adjusted_actual_seq_qlen, + adjusted_actual_seq_kvlen, + adjusted_sparse_mode, + gen_mask_parallel, sync, + softmax_layout, sink, + ) + + return result[:4] if isinstance(result, (tuple, list)) and len(result) >= 4 else result + + return expanded_impl + + def _get_seq_dim_idx(self, dims: dict) -> Optional[int]: + """Get the sequence dimension index.""" + if "seq" in dims: + return dims["seq"] + if "total" in dims: + return dims["total"] + return None + + def _normalize_dim_map(self, dim_map): + """Normalize dim_map.""" + if dim_map is None: + return "None" + return dim_map + + def _get_split_info(self, layout: Layout, input_layout_str: str): + """Extract split information from layout.""" + dims = self._layout_dims.get(input_layout_str, {}) + result = {"batch": 1, "head": 1, "seq": 1} + + if getattr(layout, "alias_tensor_map", None) is None: + return result + + if "batch" in dims: + result["batch"] = self._get_dim_split_num(layout, dims["batch"]) + + if "head" in dims: + result["head"] = self._get_dim_split_num(layout, dims["head"]) + elif "hidden" in dims: + result["head"] = self._get_dim_split_num(layout, dims["hidden"]) + + if "seq" in dims: + result["seq"] = self._get_dim_split_num(layout, dims["seq"]) + elif "total" in dims: + result["seq"] = self._get_dim_split_num(layout, dims["total"]) + + return result + + def _calculate_tnd_split_params( + self, + query_layout: Layout, + key_layout: Layout, + input_layout: str + ) -> Tuple[int, int]: + """Calculate batch_split_num and s1_split_num for TND layout.""" + if input_layout != "TND": + return 1, 1 + + query_split_info = self._get_split_info(query_layout, input_layout) + key_split_info = self._get_split_info(key_layout, input_layout) + + query_seq_split = query_split_info["seq"] + key_seq_split = key_split_info["seq"] + + batch_split_num = key_seq_split + + if batch_split_num == 0: + s1_split_num = query_seq_split + else: + s1_split_num = query_seq_split // batch_split_num + + return batch_split_num, s1_split_num + + def _get_dim_split_num(self, layout: Layout, dim_idx: int) -> int: + """Get split number along a tensor dimension.""" + if getattr(layout, "alias_tensor_map", None) is None: + return 1 + + if dim_idx >= len(layout.alias_tensor_map): + return 1 + + dim_map = 
self._normalize_dim_map(layout.alias_tensor_map[dim_idx]) + + if dim_map == "None": + return 1 + + if isinstance(dim_map, str): + return layout.mesh.get_device_num_along_axis(dim_map) + + if isinstance(dim_map, tuple): + total = 1 + for axis_name in dim_map: + axis_name = self._normalize_dim_map(axis_name) + if axis_name != "None": + total *= layout.mesh.get_device_num_along_axis(axis_name) + return total + + return 1 + + def _get_split_id(self, layout: Layout, input_layout_str: str) -> int: + """Get split ID along the sequence dimension.""" + dims = self._layout_dims.get(input_layout_str, {}) + seq_dim_idx = self._get_seq_dim_idx(dims) + + if seq_dim_idx is None or getattr(layout, "alias_tensor_map", None) is None: + return 0 + + if seq_dim_idx >= len(layout.alias_tensor_map): + return 0 + + dim_map = self._normalize_dim_map(layout.alias_tensor_map[seq_dim_idx]) + + if dim_map == "None": + return 0 + + if isinstance(dim_map, str): + rank = platform.get_rank() + rank_list = layout.mesh.get_rank_list_along_axis(dim_map) + if rank in rank_list: + return rank_list.index(rank) + return 0 + + if isinstance(dim_map, tuple): + non_none_axes = [ + ax for ax in dim_map if self._normalize_dim_map(ax) != "None" + ] + if len(non_none_axes) == 0: + return 0 + if len(non_none_axes) > 1: + warnings.warn( + f"Seq dim is sharded by multiple axes {non_none_axes}. " + f"Using the last axis for split_id calculation." + ) + axis_name = non_none_axes[-1] + rank = platform.get_rank() + rank_list = layout.mesh.get_rank_list_along_axis(axis_name) + if rank in rank_list: + return rank_list.index(rank) + + return 0 + + def _validate_tnd_cp_requirements( + self, + query_layout: Layout, + key_layout: Layout, + input_layout: str, + sparse_mode: int, + batch_split_num: int, + s1_split_num: int, + query_global_shape: Tuple[int, ...], + key_global_shape: Tuple[int, ...] + ): + """Validate requirements for TND+CP mode.""" + if input_layout != "TND" or s1_split_num <= 1: + return + + if sparse_mode != SPARSE_RIGHT_DOWN_CAUSAL: + raise ValueError( + f"TND layout with context parallelism (s1_split_num={s1_split_num} > 1) " + f"requires sparse_mode={SPARSE_RIGHT_DOWN_CAUSAL} (rightDownCausal), " + f"but got sparse_mode={sparse_mode}.\n" + f"This is required by MindSpore for correct attention mask partitioning." + ) + + if query_global_shape[0] != key_global_shape[0]: + raise ValueError( + f"TND layout with context parallelism requires Query and Key " + f"to have the same global T-dimension, but got:\n" + f"Query global T-dim: {query_global_shape[0]}\n" + f"Key global T-dim: {key_global_shape[0]}\n" + f"Note: This checks the global shape before sharding, not local shape." + ) + + query_split_info = self._get_split_info(query_layout, input_layout) + key_split_info = self._get_split_info(key_layout, input_layout) + + query_seq_split = query_split_info["seq"] + key_seq_split = key_split_info["seq"] + + if query_seq_split != key_seq_split * s1_split_num: + raise ValueError( + f"TND layout requires Query split number to be an integer multiple " + f"of Key split number.\n" + f"Query T-dimension split: {query_seq_split}\n" + f"Key T-dimension split: {key_seq_split}\n" + f"s1_split_num: {s1_split_num}\n" + f"Expected: {query_seq_split} == {key_seq_split} x {s1_split_num} " + f"= {key_seq_split * s1_split_num}\n" + f"This ensures proper alignment for context parallel computation." 
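Reviewer note: the split-id lookup in _get_split_id above reduces to finding the current rank's position along the mesh axis that shards the sequence dimension; a hedged sketch with made-up rank values follows.

# Illustrative only: rank 5 on a mesh axis whose rank list is [1, 3, 5, 7].
rank, rank_list_along_axis = 5, [1, 3, 5, 7]
split_id = rank_list_along_axis.index(rank) if rank in rank_list_along_axis else 0
assert split_id == 2   # this rank owns the third sequence chunk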
+ ) + + def _create_replicated_scalar_layout(self, query_layout: Layout) -> Layout: + """Create a fully replicated layout for scalar tensor (softmax_out placeholder).""" + layout = Layout.from_device_mesh(query_layout.mesh) + mesh_ndim = len(query_layout.mesh_shape) + replicated_placements = tuple(Replicate() for _ in range(mesh_ndim)) + layout.set_placements(replicated_placements) + layout.set_tensor_map(()) + return layout diff --git a/hyper_parallel/core/shard/ops/yaml/torch_flash_attention_score.yaml b/hyper_parallel/core/shard/ops/yaml/torch_flash_attention_score.yaml new file mode 100644 index 0000000..345a5ba --- /dev/null +++ b/hyper_parallel/core/shard/ops/yaml/torch_flash_attention_score.yaml @@ -0,0 +1,4 @@ +npu_fusion_attention: + dist_op_name: _torch_npu_fusion_attention_dist_op + distributed_op_class: FlashAttentionScoreDistributedOp + distributed_op_file: parallel_npu_flash_attention_score \ No newline at end of file diff --git a/tests/mindspore/ut/parallel_ops_infer/test_parallel_npu_flash_attention_score.py b/tests/mindspore/ut/parallel_ops_infer/test_parallel_npu_flash_attention_score.py new file mode 100644 index 0000000..af3e273 --- /dev/null +++ b/tests/mindspore/ut/parallel_ops_infer/test_parallel_npu_flash_attention_score.py @@ -0,0 +1,585 @@ +# Copyright 2026 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""parallel_npu_fusion_attention unit test""" + +from hyper_parallel import Layout +from hyper_parallel.core.shard.ops.parallel_npu_flash_attention_score import FlashAttentionScoreDistributedOp + + +def run_scenario(scenario_name, q_layout, k_layout, v_layout, expected_out_map, extra_args): + """Infer layout and verify attention output tensor_map.""" + op = FlashAttentionScoreDistributedOp("npu_fusion_attention") + output_layouts = op.infer_layout((q_layout, k_layout, v_layout), extra_args) + attention_out_layout = output_layouts[0] + assert attention_out_layout.to_dict()["tensor_map"] == expected_out_map, \ + f"Scenario '{scenario_name}' failed. " \ + f"Expected {expected_out_map}, got {attention_out_layout.to_dict()['tensor_map']}" + + +base_mesh_shape = (2, 2, 2) +base_alias_name = ("dp", "sp", "mp") +base_rank_list = list(range(8)) + + +def test_flash_attention_no_parallel_1(): + """ + Feature: Layout inference with no parallelism. + Description: All dimensions replicated on a 3D mesh. + Expectation: Output tensor_map is all -1. + """ + x_layout = Layout(base_mesh_shape, base_alias_name, base_rank_list) + q_layout = x_layout("None", "None", "None") + k_layout = x_layout("None", "None", "None") + v_layout = x_layout("None", "None", "None") + + run_scenario( + "No Parallelism", + q_layout, k_layout, v_layout, + expected_out_map=(-1, -1, -1), + extra_args=[16, "BSH"] + ) + + +def test_flash_attention_data_parallel_2(): + """ + Feature: Layout inference with data parallelism. + Description: BSH batch dimension sharded on dp axis. 
+ Expectation: Output batch dimension remains sharded on dp. + """ + x_layout = Layout(base_mesh_shape, base_alias_name, base_rank_list) + q_layout = x_layout("dp", "None", "None") + k_layout = x_layout("dp", "None", "None") + v_layout = x_layout("dp", "None", "None") + + run_scenario( + "Data Parallel (DP)", + q_layout, k_layout, v_layout, + expected_out_map=(2, -1, -1), + extra_args=[16, "BSH"] + ) + + +def test_flash_attention_head_parallel_3(): + """ + Feature: Layout inference with head parallelism. + Description: BSH hidden dimension sharded on mp axis. + Expectation: Output hidden dimension remains sharded on mp. + """ + x_layout = Layout(base_mesh_shape, base_alias_name, base_rank_list) + q_layout = x_layout("None", "None", "mp") + k_layout = x_layout("None", "None", "mp") + v_layout = x_layout("None", "None", "mp") + + run_scenario( + "Head Parallel (MP)", + q_layout, k_layout, v_layout, + expected_out_map=(-1, -1, 0), + extra_args=[16, "BSH"] + ) + + +def test_flash_attention_sequence_parallel_4(): + """ + Feature: Layout inference with sequence parallelism. + Description: BSH query sequence sharded on sp axis, KV not sharded. + Expectation: Output sequence dimension sharded on sp. + """ + x_layout = Layout(base_mesh_shape, base_alias_name, base_rank_list) + q_layout = x_layout("None", "sp", "None") + k_layout = x_layout("None", "None", "None") + v_layout = x_layout("None", "None", "None") + + run_scenario( + "Sequence Parallel (SP)", + q_layout, k_layout, v_layout, + expected_out_map=(-1, 1, -1), + extra_args=[16, "BSH"] + ) + + +def test_flash_attention_hybrid_dp_mp_5(): + """ + Feature: Layout inference with hybrid DP + MP. + Description: BSH batch on dp, hidden on mp. + Expectation: Both dimensions remain sharded. + """ + x_layout = Layout(base_mesh_shape, base_alias_name, base_rank_list) + q_layout = x_layout("dp", "None", "mp") + k_layout = x_layout("dp", "None", "mp") + v_layout = x_layout("dp", "None", "mp") + + run_scenario( + "Hybrid DP + MP", + q_layout, k_layout, v_layout, + expected_out_map=(2, -1, 0), + extra_args=[16, "BSH"] + ) + + +def test_flash_attention_hybrid_dp_sp_mp_6(): + """ + Feature: Layout inference with full hybrid DP + SP + MP. + Description: BSH with all three parallel strategies, KV only dp+mp. + Expectation: Output preserves dp+sp+mp sharding from query. + """ + x_layout = Layout(base_mesh_shape, base_alias_name, base_rank_list) + q_layout = x_layout("dp", "sp", "mp") + k_layout = x_layout("dp", "None", "mp") + v_layout = x_layout("dp", "None", "mp") + + run_scenario( + "Hybrid DP + SP + MP", + q_layout, k_layout, v_layout, + expected_out_map=(2, 1, 0), + extra_args=[16, "BSH"] + ) + + +def test_flash_attention_kv_different_layout_7(): + """ + Feature: Layout inference with different KV sharding. + Description: Key has sp sharding but Value does not. Layout inference does not + validate KV consistency; that is deferred to get_expand_impl at runtime. + Expectation: Layout inference succeeds based on query layout. + """ + x_layout = Layout(base_mesh_shape, base_alias_name, base_rank_list) + q_layout = x_layout("dp", "sp", "mp") + k_layout = x_layout("dp", "sp", "mp") + v_layout = x_layout("dp", "None", "mp") + + run_scenario( + "KV Different Layout", + q_layout, k_layout, v_layout, + expected_out_map=(2, 1, 0), + extra_args=[16, "BSH"] + ) + + +def test_flash_attention_bnsd_layout_8(): + """ + Feature: Layout inference for BNSD input layout. + Description: 4D tensor with dp on batch, mp on head. + Expectation: Output preserves 4D sharding structure. 
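Reviewer note: the expected_out_map tuples in these unit tests index device-mesh axes from the last axis (so for alias_name ("dp", "sp", "mp"): mp -> 0, sp -> 1, dp -> 2), and -1 marks a replicated tensor dimension. A small sanity sketch of that assumed convention:

# Illustrative only: the axis-index convention the assertions rely on.
alias_name = ("dp", "sp", "mp")
axis_index = {name: len(alias_name) - 1 - i for i, name in enumerate(alias_name)}
assert axis_index == {"dp": 2, "sp": 1, "mp": 0}   # -1 would mean "replicated"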
+ """ + x_layout = Layout((4, 2), ("dp", "mp"), list(range(8))) + q_layout = x_layout("dp", "mp", "None", "None") + k_layout = x_layout("dp", "mp", "None", "None") + v_layout = x_layout("dp", "mp", "None", "None") + + run_scenario( + "BNSD Layout", + q_layout, k_layout, v_layout, + expected_out_map=(1, 0, -1, -1), + extra_args=[16, "BNSD"] + ) + + +def test_flash_attention_sbh_layout_9(): + """ + Feature: Layout inference for SBH input layout. + Description: Sequence-first 3D layout with head parallelism on mp. + Expectation: Output preserves SBH dimension order with mp sharding. + """ + x_layout = Layout((8,), ("mp",), list(range(8))) + q_layout = x_layout("None", "None", "mp") + k_layout = x_layout("None", "None", "mp") + v_layout = x_layout("None", "None", "mp") + + run_scenario( + "SBH Layout", + q_layout, k_layout, v_layout, + expected_out_map=(-1, -1, 0), + extra_args=[16, "SBH"] + ) + + +def test_flash_attention_bsnd_layout_10(): + """ + Feature: Layout inference for BSND input layout. + Description: 4D tensor with dp on batch, mp on head (dim 2). + Expectation: Output preserves 4D sharding structure. + """ + x_layout = Layout((4, 2), ("dp", "mp"), list(range(8))) + q_layout = x_layout("dp", "None", "mp", "None") + k_layout = x_layout("dp", "None", "mp", "None") + v_layout = x_layout("dp", "None", "mp", "None") + + run_scenario( + "BSND Layout", + q_layout, k_layout, v_layout, + expected_out_map=(1, -1, 0, -1), + extra_args=[16, "BSND"] + ) + + +def test_flash_attention_tuple_alias_11(): + """ + Feature: Layout inference with tuple alias. + Description: Batch dimension uses compound (dp, interleaved) alias. + Expectation: Output preserves tuple alias structure. + """ + mesh_shape = (2, 2, 2) + alias_name = ("dp", "sp", "interleaved") + rank_list = list(range(8)) + + x_layout = Layout(mesh_shape, alias_name, rank_list) + q_layout = x_layout(("dp", "interleaved"), "None", "sp") + k_layout = x_layout(("dp", "interleaved"), "None", "sp") + v_layout = x_layout(("dp", "interleaved"), "None", "sp") + + run_scenario( + "Tuple Alias (Interleaved Parallel)", + q_layout, k_layout, v_layout, + expected_out_map=((2, 0), -1, 1), + extra_args=[16, "BSH"] + ) + + +def test_flash_attention_partial_reduction_12(): + """ + Feature: Layout inference with partial reduction. + Description: Input has partial sum on dp axis. + Expectation: Output preserves partial reduction status. + """ + x_layout = Layout((4, 2), ("dp", "mp"), list(range(8))) + q_layout = x_layout("None", "None", "mp") + q_layout.set_partial_by_dev_axis("dp", "sum") + + k_layout = x_layout("None", "None", "mp") + k_layout.set_partial_by_dev_axis("dp", "sum") + + v_layout = x_layout("None", "None", "mp") + v_layout.set_partial_by_dev_axis("dp", "sum") + + op = FlashAttentionScoreDistributedOp("npu_fusion_attention") + output_layouts = op.infer_layout((q_layout, k_layout, v_layout), [16, "BSH"]) + + attention_out_layout = output_layouts[0] + assert attention_out_layout.partial == q_layout.partial, \ + "Partial reduction status not preserved" + + +def test_flash_attention_non_contiguous_ranks_13(): + """ + Feature: Layout inference with non-contiguous rank list. + Description: Device ranks are not sequential. + Expectation: Output preserves the original rank list. 
+ """ + mesh_shape = (2, 2) + alias_name = ("dp", "mp") + rank_list = [3, 1, 0, 2] + + x_layout = Layout(mesh_shape, alias_name, rank_list) + q_layout = x_layout("dp", "None", "mp") + k_layout = x_layout("dp", "None", "mp") + v_layout = x_layout("dp", "None", "mp") + + op = FlashAttentionScoreDistributedOp("npu_fusion_attention") + output_layouts = op.infer_layout((q_layout, k_layout, v_layout), [16, "BSH"]) + + attention_out_layout = output_layouts[0] + assert attention_out_layout.rank_list == tuple(rank_list), \ + "Rank list not preserved in output" + + +def test_flash_attention_sparse_mode_0_14(): + """ + Feature: Layout inference with sparse_mode=0 in extra_args. + Description: defaultMask mode passed via extra_args. + Expectation: Layout inference succeeds normally. + """ + x_layout = Layout(base_mesh_shape, base_alias_name, base_rank_list) + q_layout = x_layout("dp", "None", "None") + k_layout = x_layout("dp", "None", "None") + v_layout = x_layout("dp", "None", "None") + + run_scenario( + "Sparse Mode 0", + q_layout, k_layout, v_layout, + expected_out_map=(2, -1, -1), + extra_args=[16, "BSH", 0] + ) + + +def test_flash_attention_sparse_mode_2_15(): + """ + Feature: Layout inference with sparse_mode=2 and sequence parallelism. + Description: leftUpCausal mode with query seq sharded on sp. + Expectation: Layout inference succeeds with seq sharding preserved. + """ + x_layout = Layout((8,), ("sp",), list(range(8))) + q_layout = x_layout("None", "sp", "None") + k_layout = x_layout("None", "None", "None") + v_layout = x_layout("None", "None", "None") + + run_scenario( + "Sparse Mode 2 (Causal + SP)", + q_layout, k_layout, v_layout, + expected_out_map=(-1, 0, -1), + extra_args=[16, "BSH", 2] + ) + + +def test_flash_attention_multiple_inputs_same_layout_16(): + """ + Feature: Layout inference returns 4 output layouts. + Description: Q/K/V all have the same dp+mp sharding. + Expectation: Returns exactly 4 layouts; attention_out matches query. + """ + x_layout = Layout(base_mesh_shape, base_alias_name, base_rank_list) + q_layout = x_layout("dp", "None", "mp") + k_layout = x_layout("dp", "None", "mp") + v_layout = x_layout("dp", "None", "mp") + + op = FlashAttentionScoreDistributedOp("npu_fusion_attention") + output_layouts = op.infer_layout((q_layout, k_layout, v_layout), [16, "BSH"]) + + assert len(output_layouts) == 4, "Should return 4 output layouts" + assert output_layouts[0].to_dict()["tensor_map"] == (2, -1, 0) + + +def test_flash_attention_output_layouts_bsh_17(): + """ + Feature: All 4 output layouts for BSH with dp+mp. + Description: Verify attention_out, softmax_max, softmax_sum, and softmax_out + layouts for BSH input with dp on batch and mp on hidden. + Expectation: attention_out matches query; softmax_max/sum are 4D (B,N,S,8); + softmax_out is a fully replicated scalar placeholder with empty tensor_map. 
+ """ + x_layout = Layout((4, 2), ("dp", "mp"), list(range(8))) + q_layout = x_layout("dp", "None", "mp") + k_layout = x_layout("dp", "None", "mp") + v_layout = x_layout("dp", "None", "mp") + + op = FlashAttentionScoreDistributedOp("npu_fusion_attention") + output_layouts = op.infer_layout((q_layout, k_layout, v_layout), [16, "BSH"]) + + attention_out, softmax_max, softmax_sum, softmax_out = output_layouts + + assert attention_out.to_dict()["tensor_map"] == q_layout.to_dict()["tensor_map"] + assert attention_out.to_dict()["tensor_map"] == (1, -1, 0) + + assert len(softmax_max.to_dict()["tensor_map"]) == 4 + assert len(softmax_sum.to_dict()["tensor_map"]) == 4 + assert softmax_max.to_dict()["tensor_map"] == (1, 0, -1, -1) + assert softmax_sum.to_dict()["tensor_map"] == (1, 0, -1, -1) + + assert softmax_out.to_dict()["tensor_map"] == () + + +def test_flash_attention_multi_dimensional_mesh_18(): + """ + Feature: Layout inference with 3D device mesh. + Description: BSH with dp+sp+mp on a (2,2,2) mesh, all inputs fully sharded. + Expectation: Output preserves all three sharding axes. + """ + mesh_shape = (2, 2, 2) + alias_name = ("dp", "sp", "mp") + rank_list = list(range(8)) + + x_layout = Layout(mesh_shape, alias_name, rank_list) + q_layout = x_layout("dp", "sp", "mp") + k_layout = x_layout("dp", "sp", "mp") + v_layout = x_layout("dp", "sp", "mp") + + run_scenario( + "Multi-dimensional Mesh", + q_layout, k_layout, v_layout, + expected_out_map=(2, 1, 0), + extra_args=[16, "BSH"] + ) + + +def test_flash_attention_single_device_19(): + """ + Feature: Layout inference with single device mesh. + Description: Mesh size is 1, all dimensions replicated. + Expectation: Output tensor_map is all -1. + """ + mesh_shape = (1,) + alias_name = ("dp",) + rank_list = [0] + + x_layout = Layout(mesh_shape, alias_name, rank_list) + q_layout = x_layout("None", "None", "None") + k_layout = x_layout("None", "None", "None") + v_layout = x_layout("None", "None", "None") + + run_scenario( + "Single Device", + q_layout, k_layout, v_layout, + expected_out_map=(-1, -1, -1), + extra_args=[16, "BSH"] + ) + + +def test_flash_attention_large_world_size_20(): + """ + Feature: Layout inference with large world size. + Description: 32 devices on a 4D mesh (2,2,2,4) with dp+sp+mp+pp axes. + Expectation: Output scales correctly with 4D mesh. + """ + mesh_shape = (2, 2, 2, 4) + alias_name = ("dp", "sp", "mp", "pp") + rank_list = list(range(32)) + + x_layout = Layout(mesh_shape, alias_name, rank_list) + q_layout = x_layout("dp", "sp", "mp", "None") + k_layout = x_layout("dp", "None", "mp", "None") + v_layout = x_layout("dp", "None", "mp", "None") + + run_scenario( + "Large World Size (32 devices)", + q_layout, k_layout, v_layout, + expected_out_map=(3, 2, 1, -1), + extra_args=[32, "BSH"] + ) + + +def test_flash_attention_output_layouts_bnsd_21(): + """ + Feature: Softmax output layouts for BNSD input. + Description: BNSD with dp on batch, mp on head. Verify softmax_max/sum are 4D + and map BNSD dimensions correctly. + Expectation: softmax_max/sum tensor_map is (dp, mp, -1, -1). 
+ """ + x_layout = Layout((4, 2), ("dp", "mp"), list(range(8))) + q_layout = x_layout("dp", "mp", "None", "None") + k_layout = x_layout("dp", "mp", "None", "None") + v_layout = x_layout("dp", "mp", "None", "None") + + op = FlashAttentionScoreDistributedOp("npu_fusion_attention") + output_layouts = op.infer_layout((q_layout, k_layout, v_layout), [16, "BNSD"]) + + attention_out, softmax_max, softmax_sum, _ = output_layouts + + assert attention_out.to_dict()["tensor_map"] == (1, 0, -1, -1) + assert softmax_max.to_dict()["tensor_map"] == (1, 0, -1, -1) + assert softmax_sum.to_dict()["tensor_map"] == (1, 0, -1, -1) + + +def test_flash_attention_output_layouts_sbh_22(): + """ + Feature: Softmax output layouts for SBH input. + Description: SB (dim 2). Verify softmax + correctly reorders SBH dimensions to (B, N, S, 8). + Expectation: softmax_max/sum tensor_map is (dp, mp, -1, -1). + """ + x_layout = Layout((4, 2), ("dp", "mp"), list(range(8))) + q_layout = x_layout("None", "dp", "mp") + k_layout = x_layout("None", "dp", "mp") + v_layout = x_layout("None", "dp", "mp") + + op = FlashAttentionScoreDistributedOp("npu_fusion_attention") + output_layouts = op.infer_layout((q_layout, k_layout, v_layout), [16, "SBH"]) + + attention_out, softmax_max, softmax_sum, _ = output_layouts + + assert attention_out.to_dict()["tensor_map"] == (-1, 1, 0) + assert softmax_max.to_dict()["tensor_map"] == (1, 0, -1, -1) + assert softmax_sum.to_dict()["tensor_map"] == (1, 0, -1, -1) + + +def test_flash_attention_output_layouts_tnd_23(): + """ + Feature: Softmax output layouts for TND input. + Description: TND with dp on token dim (dim 0), mp on head dim (dim 1). TND branch + maps softmax[0]=query_tm[0] and softmax[1]=query_tm[1], rest are -1. + Expectation: softmax_max/sum tensor_map is (dp, mp, -1, -1). + """ + x_layout = Layout((4, 2), ("dp", "mp"), list(range(8))) + q_layout = x_layout("dp", "mp", "None") + k_layout = x_layout("dp", "mp", "None") + v_layout = x_layout("dp", "mp", "None") + + op = FlashAttentionScoreDistributedOp("npu_fusion_attention") + output_layouts = op.infer_layout((q_layout, k_layout, v_layout), [16, "TND"]) + + attention_out, softmax_max, softmax_sum, _ = output_layouts + + assert attention_out.to_dict()["tensor_map"] == (1, 0, -1) + assert softmax_max.to_dict()["tensor_map"] == (1, 0, -1, -1) + assert softmax_sum.to_dict()["tensor_map"] == (1, 0, -1, -1) + + +def test_flash_attention_output_layouts_mixed_parallel_24(): + """ + Feature: Softmax output layouts with full DP + SP + MP parallelism. + Description: BSH with dp on batch, sp on seq, mp on hidden. Softmax 4D maps + B->dp, N->mp (via hidden), S->sp. + Expectation: softmax_max/sum tensor_map is (dp, mp, sp, -1). + """ + mesh_shape = (2, 2, 2) + alias_name = ("dp", "sp", "mp") + rank_list = list(range(8)) + + x_layout = Layout(mesh_shape, alias_name, rank_list) + q_layout = x_layout("dp", "sp", "mp") + k_layout = x_layout("dp", "None", "mp") + v_layout = x_layout("dp", "None", "mp") + + op = FlashAttentionScoreDistributedOp("npu_fusion_attention") + output_layouts = op.infer_layout((q_layout, k_layout, v_layout), [16, "BSH"]) + + attention_out, softmax_max, softmax_sum, _ = output_layouts + + assert attention_out.to_dict()["tensor_map"] == (2, 1, 0) + assert softmax_max.to_dict()["tensor_map"] == (2, 0, 1, -1) + assert softmax_sum.to_dict()["tensor_map"] == (2, 0, 1, -1) + + +def test_flash_attention_softmax_no_sharding_25(): + """ + Feature: Softmax output layouts with no sharding. 
+ Description: All dimensions replicated on a single-axis mesh. + Expectation: softmax_max/sum tensor_map is all -1. + """ + x_layout = Layout((8,), ("dp",), list(range(8))) + q_layout = x_layout("None", "None", "None") + k_layout = x_layout("None", "None", "None") + v_layout = x_layout("None", "None", "None") + + op = FlashAttentionScoreDistributedOp("npu_fusion_attention") + output_layouts = op.infer_layout((q_layout, k_layout, v_layout), [16, "BSH"]) + + _, softmax_max, softmax_sum, _ = output_layouts + + assert softmax_max.to_dict()["tensor_map"] == (-1, -1, -1, -1) + assert softmax_sum.to_dict()["tensor_map"] == (-1, -1, -1, -1) + + +def test_flash_attention_softmax_tuple_alias_26(): + """ + Feature: Softmax output layouts with tuple alias. + Description: BSH batch dimension uses compound (dp, interleaved) alias, hidden + uses sp. Softmax 4D maps B->(dp,interleaved), N->sp. + Expectation: softmax_max/sum tensor_map is ((dp,interleaved), sp, -1, -1). + """ + mesh_shape = (2, 2, 2) + alias_name = ("dp", "sp", "interleaved") + rank_list = list(range(8)) + + x_layout = Layout(mesh_shape, alias_name, rank_list) + q_layout = x_layout(("dp", "interleaved"), "None", "sp") + k_layout = x_layout(("dp", "interleaved"), "None", "sp") + v_layout = x_layout(("dp", "interleaved"), "None", "sp") + + op = FlashAttentionScoreDistributedOp("npu_fusion_attention") + output_layouts = op.infer_layout((q_layout, k_layout, v_layout), [16, "BSH"]) + + _, softmax_max, softmax_sum, _ = output_layouts + + assert softmax_max.to_dict()["tensor_map"] == ((2, 0), 1, -1, -1) + assert softmax_sum.to_dict()["tensor_map"] == ((2, 0), 1, -1, -1) diff --git a/tests/torch/shard/ops/parallel_op_flash_attention_score.py b/tests/torch/shard/ops/parallel_op_flash_attention_score.py new file mode 100644 index 0000000..cadc6a9 --- /dev/null +++ b/tests/torch/shard/ops/parallel_op_flash_attention_score.py @@ -0,0 +1,1126 @@ +# Copyright 2026 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Test torch DTensor with npu_fusion_attention distributed operator.""" + +import numpy as np +import pytest +import torch +import torch_npu +from hyper_parallel import DTensor, init_device_mesh +from hyper_parallel.core.placement_types import Shard, Replicate +from tests.torch.utils import init_dist + + +np.random.seed(42) + +BATCH_SIZE = 8 +SEQ_LEN = 512 +HEAD_NUM = 16 +HEAD_DIM = 64 +HIDDEN_SIZE = HEAD_NUM * HEAD_DIM +SCALE = 1.0 / (HEAD_DIM ** 0.5) + +global_query_np = np.random.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE).astype(np.float16) +global_key_np = np.random.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE).astype(np.float16) +global_value_np = np.random.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE).astype(np.float16) + + +def create_attention_mask(sparse_mode): + """Create attention mask for the given sparse mode.""" + if sparse_mode == 0: + return None + if sparse_mode == 1: + return torch.zeros(SEQ_LEN, SEQ_LEN, dtype=torch.bool).npu() + if sparse_mode in [2, 3, 4]: + return torch.triu(torch.ones(2048, 2048), diagonal=1).bool().npu() + return None + + +def run_standalone_bsh(sparse_mode=0, scale=SCALE, keep_prob=1.0, + pre_tockens=2147483647, next_tockens=2147483647): + """Run standalone BSH attention as ground truth.""" + q = torch.from_numpy(global_query_np).npu() + k = torch.from_numpy(global_key_np).npu() + v = torch.from_numpy(global_value_np).npu() + mask = create_attention_mask(sparse_mode) + result = torch_npu.npu_fusion_attention( + q, k, v, head_num=HEAD_NUM, input_layout='BSH', + atten_mask=mask, scale=scale, keep_prob=keep_prob, + pre_tockens=pre_tockens, next_tockens=next_tockens, + sparse_mode=sparse_mode, + ) + return result[0] + + +def bsh_tensors(): + """Create BSH format tensors on NPU.""" + q = torch.from_numpy(global_query_np).npu() + k = torch.from_numpy(global_key_np).npu() + v = torch.from_numpy(global_value_np).npu() + return q, k, v + + +def bnsd_tensors(): + """Create BNSD format tensors on NPU.""" + def to_bnsd(data): + x = data.reshape(BATCH_SIZE, SEQ_LEN, HEAD_NUM, HEAD_DIM) + return np.ascontiguousarray(np.transpose(x, (0, 2, 1, 3))) + q = torch.from_numpy(to_bnsd(global_query_np)).npu() + k = torch.from_numpy(to_bnsd(global_key_np)).npu() + v = torch.from_numpy(to_bnsd(global_value_np)).npu() + return q, k, v + + +def sbh_tensors(): + """Create SBH format tensors on NPU.""" + def to_sbh(data): + return np.ascontiguousarray(np.transpose(data, (1, 0, 2))) + q = torch.from_numpy(to_sbh(global_query_np)).npu() + k = torch.from_numpy(to_sbh(global_key_np)).npu() + v = torch.from_numpy(to_sbh(global_value_np)).npu() + return q, k, v + + +def bsnd_tensors(): + """Create BSND format tensors on NPU.""" + def to_bsnd(data): + return np.ascontiguousarray( + data.reshape(BATCH_SIZE, SEQ_LEN, HEAD_NUM, HEAD_DIM) + ) + q = torch.from_numpy(to_bsnd(global_query_np)).npu() + k = torch.from_numpy(to_bsnd(global_key_np)).npu() + v = torch.from_numpy(to_bsnd(global_value_np)).npu() + return q, k, v + + +def tnd_tensors(num_samples=BATCH_SIZE, tokens_per_sample=SEQ_LEN): + """Create TND format tensors and cumulative sequence lengths.""" + total = num_samples * tokens_per_sample + def to_tnd(data): + return np.ascontiguousarray( + data.reshape(BATCH_SIZE, SEQ_LEN, HEAD_NUM, HEAD_DIM) + .reshape(total, HEAD_NUM, HEAD_DIM) + ) + q = torch.from_numpy(to_tnd(global_query_np)).npu() + k = torch.from_numpy(to_tnd(global_key_np)).npu() + v = torch.from_numpy(to_tnd(global_value_np)).npu() + actual_seq_qlen = 
[(i + 1) * tokens_per_sample for i in range(num_samples)] + actual_seq_kvlen = [(i + 1) * tokens_per_sample for i in range(num_samples)] + return q, k, v, actual_seq_qlen, actual_seq_kvlen + + +def assert_close(actual, expected, atol=1e-2, rtol=1e-2, msg="Output mismatch"): + """Assert two tensors are numerically close.""" + ok = np.allclose( + actual.cpu().float().numpy(), + expected.cpu().float().numpy(), + atol=atol, rtol=rtol, + ) + assert ok, msg + + +def test_bsh_replicate(): + """ + Feature: npu_fusion_attention BSH layout with fully replicated tensors + Description: + - All Q/K/V tensors are Replicate on an 8-device mesh. + - No actual parallelism; the distributed op should pass through directly. + - Compare output against standalone single-device execution. + Expectation: Output matches standalone result within tolerance. + """ + init_dist() + expected = run_standalone_bsh() + + mesh = init_device_mesh(mesh_shape=(8,), alias_name=("dp",)) + q, k, v = bsh_tensors() + dq = DTensor.from_local(q, mesh, (Replicate(),)) + dk = DTensor.from_local(k, mesh, (Replicate(),)) + dv = DTensor.from_local(v, mesh, (Replicate(),)) + + result = torch_npu.npu_fusion_attention( + dq, dk, dv, head_num=HEAD_NUM, input_layout='BSH', + scale=SCALE, sparse_mode=0, + ) + assert_close(result[0].to_local(), expected) + + +def test_bsh_dp(): + """ + Feature: npu_fusion_attention BSH layout with data parallelism + Description: + - Q/K/V are Shard(0) on batch dimension across 8 devices. + - Each device computes attention on its local batch slice independently. + - Gather output and compare against standalone full-batch result. + Expectation: Gathered output matches standalone result within tolerance. + """ + init_dist() + expected = run_standalone_bsh() + + mesh = init_device_mesh(mesh_shape=(8,), alias_name=("dp",)) + q, k, v = bsh_tensors() + placements = (Shard(0),) + dq = DTensor.distribute_tensor(q, mesh, placements) + dk = DTensor.distribute_tensor(k, mesh, placements) + dv = DTensor.distribute_tensor(v, mesh, placements) + + result = torch_npu.npu_fusion_attention( + dq, dk, dv, head_num=HEAD_NUM, input_layout='BSH', + scale=SCALE, sparse_mode=0, + ) + gathered = result[0].redistribute(mesh, (Replicate(),)).to_local() + assert_close(gathered, expected) + + +def test_bsh_hp(): + """ + Feature: npu_fusion_attention BSH layout with head parallelism + Description: + - Q/K/V are Shard(2) on hidden dimension across 8 devices. + - head_num is divided by the split factor; each device handles a subset of heads. + - Gather output and compare against standalone result. + Expectation: Gathered output matches standalone result within tolerance. + """ + init_dist() + expected = run_standalone_bsh() + + mesh = init_device_mesh(mesh_shape=(8,), alias_name=("mp",)) + q, k, v = bsh_tensors() + placements = (Shard(2),) + dq = DTensor.distribute_tensor(q, mesh, placements) + dk = DTensor.distribute_tensor(k, mesh, placements) + dv = DTensor.distribute_tensor(v, mesh, placements) + + result = torch_npu.npu_fusion_attention( + dq, dk, dv, head_num=HEAD_NUM, input_layout='BSH', + scale=SCALE, sparse_mode=0, + ) + gathered = result[0].redistribute(mesh, (Replicate(),)).to_local() + assert_close(gathered, expected) + + +def test_bsh_sp(): + """ + Feature: npu_fusion_attention BSH layout with sequence parallelism + Description: + - Q is Shard(1) on the sequence dimension across 8 devices; K/V are Replicate. + - Each device computes attention for its local query slice against the full K/V. + - Gather output and compare against standalone result. 
+ Expectation: Gathered output matches standalone result within tolerance. + """ + init_dist() + expected = run_standalone_bsh() + + mesh = init_device_mesh(mesh_shape=(8,), alias_name=("sp",)) + q, k, v = bsh_tensors() + dq = DTensor.distribute_tensor(q, mesh, (Shard(1),)) + dk = DTensor.from_local(k, mesh, (Replicate(),)) + dv = DTensor.from_local(v, mesh, (Replicate(),)) + + result = torch_npu.npu_fusion_attention( + dq, dk, dv, head_num=HEAD_NUM, input_layout='BSH', + scale=SCALE, sparse_mode=0, + ) + gathered = result[0].redistribute(mesh, (Replicate(),)).to_local() + assert_close(gathered, expected) + + +def test_bsh_dp_hp_2d(): + """ + Feature: npu_fusion_attention BSH layout with DP + HP on 2D mesh + Description: + - 2D mesh (4, 2) with dp on batch dim and mp on hidden dim. + - Q/K/V are Shard(0) on dp axis and Shard(2) on mp axis. + - Combines data parallelism and head parallelism simultaneously. + - Gather output and compare against standalone result. + Expectation: Gathered output matches standalone result within tolerance. + """ + init_dist() + expected = run_standalone_bsh() + + mesh = init_device_mesh(mesh_shape=(4, 2), alias_name=("dp", "mp")) + q, k, v = bsh_tensors() + placements = (Shard(0), Shard(2)) + dq = DTensor.distribute_tensor(q, mesh, placements) + dk = DTensor.distribute_tensor(k, mesh, placements) + dv = DTensor.distribute_tensor(v, mesh, placements) + + result = torch_npu.npu_fusion_attention( + dq, dk, dv, head_num=HEAD_NUM, input_layout='BSH', + scale=SCALE, sparse_mode=0, + ) + gathered = result[0].redistribute(mesh, (Replicate(), Replicate())).to_local() + assert_close(gathered, expected) + + +def test_bsh_sp_hp_2d(): + """ + Feature: npu_fusion_attention BSH layout with SP + HP on 2D mesh + Description: + - 2D mesh (4, 2) with sp on seq dim and mp on hidden dim. + - Q is Shard(1)+Shard(2); K/V are Replicate+Shard(2). + - Combines sequence parallelism and head parallelism simultaneously. + - Gather output and compare against standalone result. + Expectation: Gathered output matches standalone result within tolerance. + """ + init_dist() + expected = run_standalone_bsh() + + mesh = init_device_mesh(mesh_shape=(4, 2), alias_name=("sp", "mp")) + q, k, v = bsh_tensors() + dq = DTensor.distribute_tensor(q, mesh, (Shard(1), Shard(2))) + dk = DTensor.distribute_tensor(k, mesh, (Replicate(), Shard(2))) + dv = DTensor.distribute_tensor(v, mesh, (Replicate(), Shard(2))) + + result = torch_npu.npu_fusion_attention( + dq, dk, dv, head_num=HEAD_NUM, input_layout='BSH', + scale=SCALE, sparse_mode=0, + ) + gathered = result[0].redistribute(mesh, (Replicate(), Replicate())).to_local() + assert_close(gathered, expected) + + +def test_bsh_dp_sp_hp_3d(): + """ + Feature: npu_fusion_attention BSH layout with DP + SP + HP on 3D mesh + Description: + - 3D mesh (2, 2, 2) with dp, sp, and mp axes. + - Q is Shard(0)+Shard(1)+Shard(2); K/V are Shard(0)+Replicate+Shard(2). + - Exercises all three parallelism strategies simultaneously. + - Gather output and compare against standalone result. + Expectation: Gathered output matches standalone result within tolerance. 
+ """ + init_dist() + expected = run_standalone_bsh() + + mesh = init_device_mesh(mesh_shape=(2, 2, 2), alias_name=("dp", "sp", "mp")) + q, k, v = bsh_tensors() + dq = DTensor.distribute_tensor(q, mesh, (Shard(0), Shard(1), Shard(2))) + dk = DTensor.distribute_tensor(k, mesh, (Shard(0), Replicate(), Shard(2))) + dv = DTensor.distribute_tensor(v, mesh, (Shard(0), Replicate(), Shard(2))) + + result = torch_npu.npu_fusion_attention( + dq, dk, dv, head_num=HEAD_NUM, input_layout='BSH', + scale=SCALE, sparse_mode=0, + ) + full_replicate = (Replicate(), Replicate(), Replicate()) + gathered = result[0].redistribute(mesh, full_replicate).to_local() + assert_close(gathered, expected) + + +def test_bnsd_dp_hp(): + """ + Feature: npu_fusion_attention BSH layout with DP + SP + HP on 3D mesh + Description: + - 3D mesh (2, 2, 2) with dp, sp, and mp axes. + - Q is Shard(0)+Shard(1)+Shard(2); K/V are Shard(0)+Replicate+Shard(2). + - Exercises all three parallelism strategies simultaneously. + - Gather output and compare against standalone result. + Expectation: Gathered output matches standalone result within tolerance. + """ + init_dist() + + mesh = init_device_mesh(mesh_shape=(4, 2), alias_name=("dp", "mp")) + q, k, v = bnsd_tensors() + placements = (Shard(0), Shard(1)) + dq = DTensor.distribute_tensor(q, mesh, placements) + dk = DTensor.distribute_tensor(k, mesh, placements) + dv = DTensor.distribute_tensor(v, mesh, placements) + + result = torch_npu.npu_fusion_attention( + dq, dk, dv, head_num=HEAD_NUM, input_layout='BNSD', + scale=SCALE, sparse_mode=0, + ) + out = result[0] + assert out.to_local().shape == (BATCH_SIZE // 4, HEAD_NUM // 2, SEQ_LEN, HEAD_DIM) + + +def test_bnsd_sp(): + """ + Feature: npu_fusion_attention BNSD layout with DP + SP on 2D mesh + Description: + - BNSD format with 2D mesh (4, 2): dp on batch dim, sp on seq dim (dim 2). + - Q is Shard(0)+Shard(2); K/V are Shard(0)+Replicate. + - Verify output shape matches expected local dimensions after sharding. + Expectation: Local output shape is (B/4, N, S/2, D). + """ + init_dist() + + mesh = init_device_mesh(mesh_shape=(4, 2), alias_name=("dp", "sp")) + q, k, v = bnsd_tensors() + dq = DTensor.distribute_tensor(q, mesh, (Shard(0), Shard(2))) + dk = DTensor.distribute_tensor(k, mesh, (Shard(0), Replicate())) + dv = DTensor.distribute_tensor(v, mesh, (Shard(0), Replicate())) + + result = torch_npu.npu_fusion_attention( + dq, dk, dv, head_num=HEAD_NUM, input_layout='BNSD', + scale=SCALE, sparse_mode=0, + ) + out = result[0] + assert out.to_local().shape == (BATCH_SIZE // 4, HEAD_NUM, SEQ_LEN // 2, HEAD_DIM) + + +def test_sbh_dp(): + """ + Feature: npu_fusion_attention SBH layout with data parallelism + Description: + - SBH format with batch dim at index 1, sharded across 8 devices. + - Verify output shape matches expected local dimensions after batch sharding. + Expectation: Local output shape is (S, B/8, H). 
+ """ + init_dist() + + mesh = init_device_mesh(mesh_shape=(8,), alias_name=("dp",)) + q, k, v = sbh_tensors() + placements = (Shard(1),) + dq = DTensor.distribute_tensor(q, mesh, placements) + dk = DTensor.distribute_tensor(k, mesh, placements) + dv = DTensor.distribute_tensor(v, mesh, placements) + + result = torch_npu.npu_fusion_attention( + dq, dk, dv, head_num=HEAD_NUM, input_layout='SBH', + scale=SCALE, sparse_mode=0, + ) + out = result[0] + assert out.to_local().shape == (SEQ_LEN, BATCH_SIZE // 8, HIDDEN_SIZE) + + +def test_bsnd_dp_hp(): + """ + Feature: npu_fusion_attention BSND layout with DP + HP on 2D mesh + Description: + - BSND format with 2D mesh (4, 2): dp on batch dim, mp on head dim (dim 2). + - Q/K/V are Shard(0)+Shard(2). + - Verify output shape matches expected local dimensions after sharding. + Expectation: Local output shape is (B/4, S, N/2, D). + """ + init_dist() + + mesh = init_device_mesh(mesh_shape=(4, 2), alias_name=("dp", "mp")) + q, k, v = bsnd_tensors() + placements = (Shard(0), Shard(2)) + dq = DTensor.distribute_tensor(q, mesh, placements) + dk = DTensor.distribute_tensor(k, mesh, placements) + dv = DTensor.distribute_tensor(v, mesh, placements) + + result = torch_npu.npu_fusion_attention( + dq, dk, dv, head_num=HEAD_NUM, input_layout='BSND', + scale=SCALE, sparse_mode=0, + ) + out = result[0] + assert out.to_local().shape == (BATCH_SIZE // 4, SEQ_LEN, HEAD_NUM // 2, HEAD_DIM) + + +def test_tnd_dp(): + """ + Feature: npu_fusion_attention TND layout with data parallelism + Description: + - TND format with T dimension sharded across 8 devices (equal-length samples). + - actual_seq_qlen/actual_seq_kvlen provided as cumulative sums. + - Distributed op adjusts actual_seq_len via clamp to match local T slice. + - Verify output shape matches expected local T dimension. + Expectation: Local output shape is (T/8, N, D). + """ + init_dist() + + q, k, v, actual_seq_qlen, actual_seq_kvlen = tnd_tensors() + + mesh = init_device_mesh(mesh_shape=(8,), alias_name=("dp",)) + placements = (Shard(0),) + dq = DTensor.distribute_tensor(q, mesh, placements) + dk = DTensor.distribute_tensor(k, mesh, placements) + dv = DTensor.distribute_tensor(v, mesh, placements) + + result = torch_npu.npu_fusion_attention( + dq, dk, dv, head_num=HEAD_NUM, input_layout='TND', + scale=SCALE, sparse_mode=0, + actual_seq_qlen=actual_seq_qlen, + actual_seq_kvlen=actual_seq_kvlen, + ) + total_tokens = BATCH_SIZE * SEQ_LEN + out = result[0] + assert out.to_local().shape == (total_tokens // 8, HEAD_NUM, HEAD_DIM) + + +def test_tnd_hp(): + """ + Feature: npu_fusion_attention TND layout with head parallelism + Description: + - TND format with N (head) dimension sharded across 8 devices. + - head_num is divided by the split factor. + - Verify output shape matches expected local head count. + Expectation: Local output shape is (T, N/8, D). 
+ """ + init_dist() + + q, k, v, actual_seq_qlen, actual_seq_kvlen = tnd_tensors() + + mesh = init_device_mesh(mesh_shape=(8,), alias_name=("mp",)) + placements = (Shard(1),) + dq = DTensor.distribute_tensor(q, mesh, placements) + dk = DTensor.distribute_tensor(k, mesh, placements) + dv = DTensor.distribute_tensor(v, mesh, placements) + + result = torch_npu.npu_fusion_attention( + dq, dk, dv, head_num=HEAD_NUM, input_layout='TND', + scale=SCALE, sparse_mode=0, + actual_seq_qlen=actual_seq_qlen, + actual_seq_kvlen=actual_seq_kvlen, + ) + total_tokens = BATCH_SIZE * SEQ_LEN + out = result[0] + assert out.to_local().shape == (total_tokens, HEAD_NUM // 8, HEAD_DIM) + + +def test_tnd_dp_hp(): + """ + Feature: npu_fusion_attention TND layout with DP + HP on 2D mesh + Description: + - TND format with 2D mesh (4, 2): dp on T dim, mp on N dim. + - Q/K/V are Shard(0)+Shard(1). + - Verify output shape matches expected local dimensions after both splits. + Expectation: Local output shape is (T/4, N/2, D). + """ + init_dist() + + q, k, v, actual_seq_qlen, actual_seq_kvlen = tnd_tensors() + + mesh = init_device_mesh(mesh_shape=(4, 2), alias_name=("dp", "mp")) + dq = DTensor.distribute_tensor(q, mesh, (Shard(0), Shard(1))) + dk = DTensor.distribute_tensor(k, mesh, (Shard(0), Shard(1))) + dv = DTensor.distribute_tensor(v, mesh, (Shard(0), Shard(1))) + + result = torch_npu.npu_fusion_attention( + dq, dk, dv, head_num=HEAD_NUM, input_layout='TND', + scale=SCALE, sparse_mode=0, + actual_seq_qlen=actual_seq_qlen, + actual_seq_kvlen=actual_seq_kvlen, + ) + total_tokens = BATCH_SIZE * SEQ_LEN + out = result[0] + assert out.to_local().shape == (total_tokens // 4, HEAD_NUM // 2, HEAD_DIM) + + +def test_sp_sparse_mode_0(): + """ + Feature: npu_fusion_attention BSH SP with sparse_mode=0 (defaultMask, no mask) + Description: + - Q is Shard(1) on seq dim; K/V are Replicate. + - sparse_mode=0 with default pre/next tokens (no mask provided). + - Distributed op applies LEFT_UP_TO_LEFT_UP offset to pre/next tokens. + - Gather output and compare against standalone result. + Expectation: Gathered output matches standalone result within tolerance. + """ + init_dist() + expected = run_standalone_bsh(sparse_mode=0) + + mesh = init_device_mesh(mesh_shape=(8,), alias_name=("sp",)) + q, k, v = bsh_tensors() + dq = DTensor.distribute_tensor(q, mesh, (Shard(1),)) + dk = DTensor.from_local(k, mesh, (Replicate(),)) + dv = DTensor.from_local(v, mesh, (Replicate(),)) + + result = torch_npu.npu_fusion_attention( + dq, dk, dv, head_num=HEAD_NUM, input_layout='BSH', + scale=SCALE, sparse_mode=0, + ) + gathered = result[0].redistribute(mesh, (Replicate(),)).to_local() + assert_close(gathered, expected) + + +def test_sp_sparse_mode_2(): + """ + Feature: npu_fusion_attention BSH SP with sparse_mode=2 (leftUpCausal) + Description: + - Q is Shard(1) on seq dim; K/V are Replicate. + - sparse_mode=2 with compressed 2048x2048 causal mask. + - Distributed op converts to sparse_mode=4 (band) with LEFT_UP_TO_RIGHT_DOWN offset. + - Gather output and compare against standalone result. + Expectation: Gathered output matches standalone result within tolerance. 
+ """ + init_dist() + expected = run_standalone_bsh(sparse_mode=2) + + mesh = init_device_mesh(mesh_shape=(8,), alias_name=("sp",)) + q, k, v = bsh_tensors() + mask = create_attention_mask(2) + dq = DTensor.distribute_tensor(q, mesh, (Shard(1),)) + dk = DTensor.from_local(k, mesh, (Replicate(),)) + dv = DTensor.from_local(v, mesh, (Replicate(),)) + + result = torch_npu.npu_fusion_attention( + dq, dk, dv, head_num=HEAD_NUM, input_layout='BSH', + atten_mask=mask, scale=SCALE, sparse_mode=2, + ) + gathered = result[0].redistribute(mesh, (Replicate(),)).to_local() + assert_close(gathered, expected) + + +def test_sp_sparse_mode_3(): + """ + Feature: npu_fusion_attention BSH SP with sparse_mode=3 (rightDownCausal) + Description: + - Q is Shard(1) on seq dim; K/V are Replicate. + - sparse_mode=3 with compressed 2048x2048 causal mask. + - Distributed op converts to sparse_mode=4 (band) with RIGHT_DOWN_TO_RIGHT_DOWN offset. + - Gather output and compare against standalone result. + Expectation: Gathered output matches standalone result within tolerance. + """ + init_dist() + expected = run_standalone_bsh(sparse_mode=3) + + mesh = init_device_mesh(mesh_shape=(8,), alias_name=("sp",)) + q, k, v = bsh_tensors() + mask = create_attention_mask(3) + dq = DTensor.distribute_tensor(q, mesh, (Shard(1),)) + dk = DTensor.from_local(k, mesh, (Replicate(),)) + dv = DTensor.from_local(v, mesh, (Replicate(),)) + + result = torch_npu.npu_fusion_attention( + dq, dk, dv, head_num=HEAD_NUM, input_layout='BSH', + atten_mask=mask, scale=SCALE, sparse_mode=3, + ) + gathered = result[0].redistribute(mesh, (Replicate(),)).to_local() + assert_close(gathered, expected) + + +def test_sp_sparse_mode_4(): + """ + Feature: npu_fusion_attention BSH SP with sparse_mode=4 (band) and custom pre/next tokens + Description: + - Q is Shard(1) on seq dim; K/V are Replicate. + - sparse_mode=4 with pre_tockens=256, next_tockens=256, compressed 2048x2048 mask. + - Distributed op applies RIGHT_DOWN_TO_RIGHT_DOWN offset to pre/next tokens. + - Gather output and compare against standalone result. + Expectation: Gathered output matches standalone result within tolerance. + """ + init_dist() + pre, nxt = 256, 256 + expected = run_standalone_bsh(sparse_mode=4, pre_tockens=pre, next_tockens=nxt) + + mesh = init_device_mesh(mesh_shape=(8,), alias_name=("sp",)) + q, k, v = bsh_tensors() + mask = create_attention_mask(4) + dq = DTensor.distribute_tensor(q, mesh, (Shard(1),)) + dk = DTensor.from_local(k, mesh, (Replicate(),)) + dv = DTensor.from_local(v, mesh, (Replicate(),)) + + result = torch_npu.npu_fusion_attention( + dq, dk, dv, head_num=HEAD_NUM, input_layout='BSH', + atten_mask=mask, scale=SCALE, sparse_mode=4, + pre_tockens=pre, next_tockens=nxt, + ) + gathered = result[0].redistribute(mesh, (Replicate(),)).to_local() + assert_close(gathered, expected) + + +def test_dp_sparse_mode_1(): + """ + Feature: npu_fusion_attention BSH DP with sparse_mode=1 (allMask) + Description: + - Q/K/V are Shard(0) on batch dim (data parallelism). + - sparse_mode=1 with full [Sq, Skv] attention mask. + - DP does not enter sparse params adjustment; mode=1 is passed through directly. + - Gather output and compare against standalone result. + Expectation: Gathered output matches standalone result within tolerance. 
+ """ + init_dist() + expected = run_standalone_bsh(sparse_mode=1) + + mesh = init_device_mesh(mesh_shape=(8,), alias_name=("dp",)) + q, k, v = bsh_tensors() + mask = create_attention_mask(1) + placements = (Shard(0),) + dq = DTensor.distribute_tensor(q, mesh, placements) + dk = DTensor.distribute_tensor(k, mesh, placements) + dv = DTensor.distribute_tensor(v, mesh, placements) + + result = torch_npu.npu_fusion_attention( + dq, dk, dv, head_num=HEAD_NUM, input_layout='BSH', + atten_mask=mask, scale=SCALE, sparse_mode=1, + ) + gathered = result[0].redistribute(mesh, (Replicate(),)).to_local() + assert_close(gathered, expected) + + +def test_dp_sparse_mode_4(): + """ + Feature: npu_fusion_attention BSH DP with sparse_mode=4 (band) and custom pre/next tokens + Description: + - Q/K/V are Shard(0) on batch dim (data parallelism). + - sparse_mode=4 with pre_tockens=256, + - Gather output and compare against standalone result. + Expectation: Gathered output matches standalone result within tolerance. + """ + init_dist() + pre, nxt = 256, 256 + expected = run_standalone_bsh(sparse_mode=4, pre_tockens=pre, next_tockens=nxt) + + mesh = init_device_mesh(mesh_shape=(8,), alias_name=("dp",)) + q, k, v = bsh_tensors() + mask = create_attention_mask(4) + placements = (Shard(0),) + dq = DTensor.distribute_tensor(q, mesh, placements) + dk = DTensor.distribute_tensor(k, mesh, placements) + dv = DTensor.distribute_tensor(v, mesh, placements) + + result = torch_npu.npu_fusion_attention( + dq, dk, dv, head_num=HEAD_NUM, input_layout='BSH', + atten_mask=mask, scale=SCALE, sparse_mode=4, + pre_tockens=pre, next_tockens=nxt, + ) + gathered = result[0].redistribute(mesh, (Replicate(),)).to_local() + assert_close(gathered, expected) + + +def test_error_kv_strategy_mismatch(): + """ + Feature: npu_fusion_attention rejects mismatched Key and Value sharding + Description: + - K is Shard(0)+Shard(2) but V is Shard(0)+Replicate. + - Key and Value must have identical tensor_map for correctness. + Expectation: Raise ValueError with message "Key and Value must have identical". + """ + init_dist() + + mesh = init_device_mesh(mesh_shape=(4, 2), alias_name=("dp", "mp")) + q, k, v = bsh_tensors() + dq = DTensor.distribute_tensor(q, mesh, (Shard(0), Shard(2))) + dk = DTensor.distribute_tensor(k, mesh, (Shard(0), Shard(2))) + dv = DTensor.distribute_tensor(v, mesh, (Shard(0), Replicate())) + + with pytest.raises(ValueError, match="Key and Value must have identical"): + torch_npu.npu_fusion_attention( + dq, dk, dv, head_num=HEAD_NUM, input_layout='BSH', + scale=SCALE, sparse_mode=0, + ) + + +def test_error_head_num_not_divisible(): + """ + Feature: npu_fusion_attention rejects indivisible head_num under head parallelism + Description: + - Use head_num=17 (prime) with 8-way head sharding. + - 17 is not divisible by 8, so adjusted_head_num cannot be computed. + Expectation: Raise ValueError with message "not divisible". 
+ """ + init_dist() + + odd_heads = 17 + odd_hidden = odd_heads * HEAD_DIM + data = np.random.randn(BATCH_SIZE, SEQ_LEN, odd_hidden).astype(np.float16) + q = torch.from_numpy(data).npu() + k = torch.from_numpy(data).npu() + v = torch.from_numpy(data).npu() + + mesh = init_device_mesh(mesh_shape=(8,), alias_name=("mp",)) + placements = (Shard(2),) + dq = DTensor.distribute_tensor(q, mesh, placements) + dk = DTensor.distribute_tensor(k, mesh, placements) + dv = DTensor.distribute_tensor(v, mesh, placements) + + with pytest.raises(ValueError, match="not divisible"): + torch_npu.npu_fusion_attention( + dq, dk, dv, head_num=odd_heads, input_layout='BSH', + scale=SCALE, sparse_mode=0, + ) + + +def test_error_kv_seq_sharding_blocked(): + """ + Feature: npu_fusion_attention rejects KV sequence sharding for non-TND layouts + Description: + - BSH layout with Q/K/V all Shard(1) on seq dim. + - KV sequence sharding requires Ring Attention which is not supported. + Expectation: Raise NotImplementedError with message "KV sequence sharding is not supported". + """ + init_dist() + + mesh = init_device_mesh(mesh_shape=(8,), alias_name=("sp",)) + q, k, v = bsh_tensors() + placements = (Shard(1),) + dq = DTensor.distribute_tensor(q, mesh, placements) + dk = DTensor.distribute_tensor(k, mesh, placements) + dv = DTensor.distribute_tensor(v, mesh, placements) + + with pytest.raises(NotImplementedError, match="KV sequence sharding is not supported"): + torch_npu.npu_fusion_attention( + dq, dk, dv, head_num=HEAD_NUM, input_layout='BSH', + scale=SCALE, sparse_mode=0, + ) + + +def test_error_sparse_mode_1_no_mask(): + """ + Feature: npu_fusion_attention rejects sparse_mode=1 without atten_mask + Description: + - sparse_mode=1 (allMask) requires an explicit atten_mask tensor. + - Call without providing atten_mask. + Expectation: Raise ValueError with message "sparse_mode=1.*requires atten_mask". + """ + init_dist() + + mesh = init_device_mesh(mesh_shape=(8,), alias_name=("dp",)) + q, k, v = bsh_tensors() + placements = (Shard(0),) + dq = DTensor.distribute_tensor(q, mesh, placements) + dk = DTensor.distribute_tensor(k, mesh, placements) + dv = DTensor.distribute_tensor(v, mesh, placements) + + with pytest.raises(ValueError, match="sparse_mode=1.*requires atten_mask"): + torch_npu.npu_fusion_attention( + dq, dk, dv, head_num=HEAD_NUM, input_layout='BSH', + scale=SCALE, sparse_mode=1, + ) + + +def test_error_tnd_sp_without_actual_seq_len(): + """ + Feature: npu_fusion_attention rejects TND SP without actual_seq_qlen/actual_seq_kvlen + Description: + - TND layout with sequence parallelism (Q and K both Shard(0) on T dim). + - actual_seq_qlen and actual_seq_kvlen are not provided. + Expectation: Raise ValueError with message "actual_seq_qlen and actual_seq_kvlen must be provided". 
+ """ + init_dist() + + q, k, v, _, _ = tnd_tensors() + total_tokens = BATCH_SIZE * SEQ_LEN + + mesh = init_device_mesh(mesh_shape=(4, 2), alias_name=("dp", "sp")) + dq = DTensor.distribute_tensor(q, mesh, (Shard(0), Shard(1))) + dk = DTensor.distribute_tensor(k, mesh, (Shard(0), Shard(1))) + dv = DTensor.distribute_tensor(v, mesh, (Shard(0), Shard(1))) + + with pytest.raises(ValueError, match="actual_seq_qlen and actual_seq_kvlen must be provided"): + torch_npu.npu_fusion_attention( + dq, dk, dv, head_num=HEAD_NUM, input_layout='TND', + scale=SCALE, sparse_mode=3, + ) + + +def test_error_bnsd_kv_seq_sharding_blocked(): + """ + Feature: npu_fusion_attention rejects KV sequence sharding for BNSD layout + Description: + - BNSD layout with Q/K/V all Shard(2) on seq dim. + - KV sequence sharding requires Ring Attention which is not supported, + same restriction as BSH. + Expectation: Raise NotImplementedError with message "KV sequence sharding is not supported". + """ + init_dist() + + mesh = init_device_mesh(mesh_shape=(4, 2), alias_name=("dp", "sp")) + q, k, v = bnsd_tensors() + placements = (Shard(0), Shard(2)) + dq = DTensor.distribute_tensor(q, mesh, placements) + dk = DTensor.distribute_tensor(k, mesh, placements) + dv = DTensor.distribute_tensor(v, mesh, placements) + + with pytest.raises(NotImplementedError, match="KV sequence sharding is not supported"): + torch_npu.npu_fusion_attention( + dq, dk, dv, head_num=HEAD_NUM, input_layout='BNSD', + scale=SCALE, sparse_mode=0, + ) + + +def test_bsh_custom_scale(): + """ + Feature: npu_fusion_attention BSH DP with custom scale value + Description: + - Use a non-default scale=0.125 instead of 1/sqrt(HEAD_DIM). + - Q/K/V are Shard(0) on batch dim. + - Gather output and compare against standalone result with same scale. + Expectation: Gathered output matches standalone result within tolerance. + """ + init_dist() + custom_scale = 0.125 + expected = run_standalone_bsh(scale=custom_scale) + + mesh = init_device_mesh(mesh_shape=(8,), alias_name=("dp",)) + q, k, v = bsh_tensors() + placements = (Shard(0),) + dq = DTensor.distribute_tensor(q, mesh, placements) + dk = DTensor.distribute_tensor(k, mesh, placements) + dv = DTensor.distribute_tensor(v, mesh, placements) + + result = torch_npu.npu_fusion_attention( + dq, dk, dv, head_num=HEAD_NUM, input_layout='BSH', + scale=custom_scale, sparse_mode=0, + ) + gathered = result[0].redistribute(mesh, (Replicate(),)).to_local() + assert_close(gathered, expected) + + +def test_bsh_dropout(): + """ + Feature: npu_fusion_attention BSH DP with dropout (keep_prob < 1) + Description: + - Q/K/V are Shard(0) on batch dim with keep_prob=0.9. + - Dropout introduces randomness, so only output shape is verified. + Expectation: Local output shape is (B/8, S, H). No correctness check due to randomness. + """ + init_dist() + + mesh = init_device_mesh(mesh_shape=(8,), alias_name=("dp",)) + q, k, v = bsh_tensors() + placements = (Shard(0),) + dq = DTensor.distribute_tensor(q, mesh, placements) + dk = DTensor.distribute_tensor(k, mesh, placements) + dv = DTensor.distribute_tensor(v, mesh, placements) + + result = torch_npu.npu_fusion_attention( + dq, dk, dv, head_num=HEAD_NUM, input_layout='BSH', + scale=SCALE, sparse_mode=0, keep_prob=0.9, + ) + out = result[0] + assert out.to_local().shape == (BATCH_SIZE // 8, SEQ_LEN, HIDDEN_SIZE) + + +def test_bsh_long_sequence_sp(): + """ + Feature: npu_fusion_attention BSH SP with longer sequence length + Description: + - Use SEQ_LEN=2048 (4x default) with 8-way sequence parallelism. 
+ - Q is Shard(1); K/V are Replicate. + - Verify output shape matches expected local seq dimension. + Expectation: Local output shape is (B, S/8, H) where S=2048. + """ + init_dist() + + long_seq = 2048 + q_np = np.random.randn(BATCH_SIZE, long_seq, HIDDEN_SIZE).astype(np.float16) + k_np = np.random.randn(BATCH_SIZE, long_seq, HIDDEN_SIZE).astype(np.float16) + v_np = np.random.randn(BATCH_SIZE, long_seq, HIDDEN_SIZE).astype(np.float16) + q = torch.from_numpy(q_np).npu() + k = torch.from_numpy(k_np).npu() + v = torch.from_numpy(v_np).npu() + + mesh = init_device_mesh(mesh_shape=(8,), alias_name=("sp",)) + dq = DTensor.distribute_tensor(q, mesh, (Shard(1),)) + dk = DTensor.from_local(k, mesh, (Replicate(),)) + dv = DTensor.from_local(v, mesh, (Replicate(),)) + + result = torch_npu.npu_fusion_attention( + dq, dk, dv, head_num=HEAD_NUM, input_layout='BSH', + scale=SCALE, sparse_mode=0, + ) + out = result[0] + assert out.to_local().shape == (BATCH_SIZE, long_seq // 8, HIDDEN_SIZE) + + +def test_bsh_large_batch_dp(): + """ + Feature: npu_fusion_attention BSH DP with large batch size + Description: + - Use BATCH_SIZE=64 (8x default) with 8-way data parallelism. + - Verify output shape matches expected local batch dimension. + Expectation: Local output shape is (64/8, S, H). + """ + init_dist() + + large_batch = 64 + q_np = np.random.randn(large_batch, SEQ_LEN, HIDDEN_SIZE).astype(np.float16) + q = torch.from_numpy(q_np).npu() + k = torch.from_numpy(q_np).npu() + v = torch.from_numpy(q_np).npu() + + mesh = init_device_mesh(mesh_shape=(8,), alias_name=("dp",)) + placements = (Shard(0),) + dq = DTensor.distribute_tensor(q, mesh, placements) + dk = DTensor.distribute_tensor(k, mesh, placements) + dv = DTensor.distribute_tensor(v, mesh, placements) + + result = torch_npu.npu_fusion_attention( + dq, dk, dv, head_num=HEAD_NUM, input_layout='BSH', + scale=SCALE, sparse_mode=0, + ) + out = result[0] + assert out.to_local().shape == (large_batch // 8, SEQ_LEN, HIDDEN_SIZE) + + +def test_bsh_redistribute_then_attention(): + """ + Feature: npu_fusion_attention BSH with redistribute before attention + Description: + - Start with Replicate tensors, redistribute to Shard(0) on batch dim. + - Run attention on the redistributed DTensors. + - Gather output and compare against standalone result. + Expectation: Gathered output matches standalone result within tolerance. + """ + init_dist() + expected = run_standalone_bsh() + + mesh = init_device_mesh(mesh_shape=(8,), alias_name=("dp",)) + q, k, v = bsh_tensors() + dq = DTensor.from_local(q, mesh, (Replicate(),)) + dk = DTensor.from_local(k, mesh, (Replicate(),)) + dv = DTensor.from_local(v, mesh, (Replicate(),)) + + dq = dq.redistribute(mesh, (Shard(0),)) + dk = dk.redistribute(mesh, (Shard(0),)) + dv = dv.redistribute(mesh, (Shard(0),)) + + result = torch_npu.npu_fusion_attention( + dq, dk, dv, head_num=HEAD_NUM, input_layout='BSH', + scale=SCALE, sparse_mode=0, + ) + gathered = result[0].redistribute(mesh, (Replicate(),)).to_local() + assert_close(gathered, expected) + + +def test_sp_sparse_mode_2_with_2way_split(): + """ + Feature: npu_fusion_attention BSH SP sparse_mode=2 with 2-way sequence split + Description: + - 2D mesh (2, 4) with sp=2 and dp=4. + - Q is Shard(1)+Shard(0); K/V are Replicate+Shard(0). + - Verifies LEFT_UP_TO_RIGHT_DOWN offset calculation with non-8-way split. + - Gather output and compare against standalone result. + Expectation: Gathered output matches standalone result within tolerance. 
+ """ + init_dist() + expected = run_standalone_bsh(sparse_mode=2) + + mesh = init_device_mesh(mesh_shape=(2, 4), alias_name=("sp", "dp")) + q, k, v = bsh_tensors() + dq = DTensor.distribute_tensor(q, mesh, (Shard(1), Shard(0))) + dk = DTensor.distribute_tensor(k, mesh, (Replicate(), Shard(0))) + dv = DTensor.distribute_tensor(v, mesh, (Replicate(), Shard(0))) + mask = create_attention_mask(2) + + result = torch_npu.npu_fusion_attention( + dq, dk, dv, head_num=HEAD_NUM, input_layout='BSH', + atten_mask=mask, scale=SCALE, sparse_mode=2, + ) + gathered = result[0].redistribute(mesh, (Replicate(), Replicate())).to_local() + assert_close(gathered, expected) + + +def test_sp_sparse_mode_3_with_2way_split(): + """ + Feature: npu_fusion_attention BSH SP sparse_mode=3 with 2-way sequence split + Description: + - 2D mesh (2, 4) with sp=2 and dp=4. + - Q is Shard(1)+Shard(0); K/V are Replicate+Shard(0). + - Verifies RIGHT_DOWN_TO_RIGHT_DOWN offset calculation with non-8-way split. + - Gather output and compare against standalone result. + Expectation: Gathered output matches standalone result within tolerance. + """ + init_dist() + expected = run_standalone_bsh(sparse_mode=3) + + mesh = init_device_mesh(mesh_shape=(2, 4), alias_name=("sp", "dp")) + q, k, v = bsh_tensors() + dq = DTensor.distribute_tensor(q, mesh, (Shard(1), Shard(0))) + dk = DTensor.distribute_tensor(k, mesh, (Replicate(), Shard(0))) + dv = DTensor.distribute_tensor(v, mesh, (Replicate(), Shard(0))) + mask = create_attention_mask(3) + + result = torch_npu.npu_fusion_attention( + dq, dk, dv, head_num=HEAD_NUM, input_layout='BSH', + atten_mask=mask, scale=SCALE, sparse_mode=3, + ) + gathered = result[0].redistribute(mesh, (Replicate(), Replicate())).to_local() + assert_close(gathered, expected) + + +def test_bnsd_sp_correctness(): + """ + Feature: npu_fusion_attention BNSD SP correctness verification + Description: + - BNSD format with Q Shard(2) on seq dim; K/V Replicate. + - Run standalone BNSD attention as ground truth. + - Gather distributed output and compare against standalone result. + Expectation: Gathered output matches standalone result within tolerance. + """ + init_dist() + + q_bnsd, k_bnsd, v_bnsd = bnsd_tensors() + mask = create_attention_mask(0) + standalone = torch_npu.npu_fusion_attention( + q_bnsd, k_bnsd, v_bnsd, head_num=HEAD_NUM, input_layout='BNSD', + atten_mask=mask, scale=SCALE, sparse_mode=0, + )[0] + + mesh = init_device_mesh(mesh_shape=(8,), alias_name=("sp",)) + dq = DTensor.distribute_tensor(q_bnsd, mesh, (Shard(2),)) + dk = DTensor.from_local(k_bnsd, mesh, (Replicate(),)) + dv = DTensor.from_local(v_bnsd, mesh, (Replicate(),)) + + result = torch_npu.npu_fusion_attention( + dq, dk, dv, head_num=HEAD_NUM, input_layout='BNSD', + scale=SCALE, sparse_mode=0, + ) + gathered = result[0].redistribute(mesh, (Replicate(),)).to_local() + assert_close(gathered, standalone) + + +def test_tnd_dp_correctness(): + """ + Feature: npu_fusion_attention TND DP correctness with actual_seq_len adjustment + Description: + - TND format with 8-way DP on T dimension (equal-length samples). + - Verifies that pure DP correctly adjusts actual_seq_qlen/actual_seq_kvlen + via clamp to match local T slice, while skipping sparse params adjustment. + - Run standalone TND attention as ground truth. + - Gather distributed output and compare against standalone result. + Expectation: Gathered output matches standalone result within tolerance. 
+    """
+    init_dist()
+
+    # Equal-length samples; tnd_tensors() also returns the cumulative
+    # actual_seq_qlen/actual_seq_kvlen lists shared by both calls below.
+    q, k, v, actual_seq_qlen, actual_seq_kvlen = tnd_tensors()
+
+    mesh = init_device_mesh(mesh_shape=(8,), alias_name=("dp",))
+    placements = (Shard(0),)
+    dq = DTensor.distribute_tensor(q, mesh, placements)
+    dk = DTensor.distribute_tensor(k, mesh, placements)
+    dv = DTensor.distribute_tensor(v, mesh, placements)
+
+    # Run standalone for comparison
+    q_full = torch.from_numpy(
+        global_query_np.reshape(BATCH_SIZE, SEQ_LEN, HEAD_NUM, HEAD_DIM)
+        .reshape(BATCH_SIZE * SEQ_LEN, HEAD_NUM, HEAD_DIM)
+    ).npu()
+    k_full = torch.from_numpy(
+        global_key_np.reshape(BATCH_SIZE, SEQ_LEN, HEAD_NUM, HEAD_DIM)
+        .reshape(BATCH_SIZE * SEQ_LEN, HEAD_NUM, HEAD_DIM)
+    ).npu()
+    v_full = torch.from_numpy(
+        global_value_np.reshape(BATCH_SIZE, SEQ_LEN, HEAD_NUM, HEAD_DIM)
+        .reshape(BATCH_SIZE * SEQ_LEN, HEAD_NUM, HEAD_DIM)
+    ).npu()
+    expected = torch_npu.npu_fusion_attention(
+        q_full, k_full, v_full, head_num=HEAD_NUM, input_layout='TND',
+        scale=SCALE, sparse_mode=0,
+        actual_seq_qlen=actual_seq_qlen,
+        actual_seq_kvlen=actual_seq_kvlen,
+    )[0]
+
+    result = torch_npu.npu_fusion_attention(
+        dq, dk, dv, head_num=HEAD_NUM, input_layout='TND',
+        scale=SCALE, sparse_mode=0,
+        actual_seq_qlen=actual_seq_qlen,
+        actual_seq_kvlen=actual_seq_kvlen,
+    )
+
+    total_tokens = BATCH_SIZE * SEQ_LEN
+    out = result[0]
+    assert out.to_local().shape == (total_tokens // 8, HEAD_NUM, HEAD_DIM)
+
+    gathered = result[0].redistribute(mesh, (Replicate(),)).to_local()
+    assert_close(gathered, expected)
diff --git a/tests/torch/shard/ops/test_parallel_flash_attention_score.py b/tests/torch/shard/ops/test_parallel_flash_attention_score.py
new file mode 100644
index 0000000..75419ef
--- /dev/null
+++ b/tests/torch/shard/ops/test_parallel_flash_attention_score.py
@@ -0,0 +1,475 @@
+# Copyright 2026 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Test npu_fusion_attention distributed operator"""
+
+from tests.torch.utils import torchrun_case
+from tests.common.mark_utils import arg_mark
+
+
+@arg_mark(plat_marks=["platform_ascend910b"], level_mark="level0", card_mark="onecard", essential_mark="essential")
+def test_bsh_replicate():
+    """
+    Feature: test npu_fusion_attention with BSH layout, no parallelism.
+    Description: All tensors replicated, verify basic DTensor pass-through.
+    Expectation: Output matches standalone execution.
+    """
+    master_port = 11010
+    file_name = "parallel_op_flash_attention_score.py"
+    case_name = "test_bsh_replicate"
+    torchrun_case(file_name, case_name, master_port)
+
+
+@arg_mark(plat_marks=["platform_ascend910b"], level_mark="level0", card_mark="allcards", essential_mark="essential")
+def test_bsh_dp():
+    """
+    Feature: test npu_fusion_attention with BSH layout, data parallelism.
+ Description: Shard on batch dimension, gather and compare with standalone. + Expectation: Output matches standalone execution. + """ + master_port = 11011 + file_name = "parallel_op_flash_attention_score.py" + case_name = "test_bsh_dp" + torchrun_case(file_name, case_name, master_port) + + +@arg_mark(plat_marks=["platform_ascend910b"], level_mark="level0", card_mark="allcards", essential_mark="essential") +def test_bsh_hp(): + """ + Feature: test npu_fusion_attention with BSH layout, head parallelism. + Description: Shard on hidden dimension, head_num adjusted correctly. + Expectation: Output matches standalone execution. + """ + master_port = 11012 + file_name = "parallel_op_flash_attention_score.py" + case_name = "test_bsh_hp" + torchrun_case(file_name, case_name, master_port) + + +@arg_mark(plat_marks=["platform_ascend910b"], level_mark="level0", card_mark="allcards", essential_mark="essential") +def test_bsh_sp(): + """ + Feature: test npu_fusion_attention with BSH layout, sequence parallelism. + Description: Q shard on seq dim, KV replicated, sparse params adjusted. + Expectation: Output matches standalone execution. + """ + master_port = 11013 + file_name = "parallel_op_flash_attention_score.py" + case_name = "test_bsh_sp" + torchrun_case(file_name, case_name, master_port) + + +@arg_mark(plat_marks=["platform_ascend910b"], level_mark="level0", card_mark="allcards", essential_mark="essential") +def test_bsh_dp_hp_2d(): + """ + Feature: test npu_fusion_attention with BSH layout, DP + HP on 2D mesh. + Description: Shard batch and hidden on (4, 2) mesh. + Expectation: Output matches standalone execution. + """ + master_port = 11014 + file_name = "parallel_op_flash_attention_score.py" + case_name = "test_bsh_dp_hp_2d" + torchrun_case(file_name, case_name, master_port) + + +@arg_mark(plat_marks=["platform_ascend910b"], level_mark="level0", card_mark="allcards", essential_mark="essential") +def test_bsh_sp_hp_2d(): + """ + Feature: test npu_fusion_attention with BSH layout, SP + HP on 2D mesh. + Description: Q shard on seq + hidden, KV replicate seq + shard hidden. + Expectation: Output matches standalone execution. + """ + master_port = 11015 + file_name = "parallel_op_flash_attention_score.py" + case_name = "test_bsh_sp_hp_2d" + torchrun_case(file_name, case_name, master_port) + + +@arg_mark(plat_marks=["platform_ascend910b"], level_mark="level1", card_mark="allcards", essential_mark="unessential") +def test_bsh_dp_sp_hp_3d(): + """ + Feature: test npu_fusion_attention with BSH layout, DP + SP + HP on 3D mesh. + Description: Full parallelism on (2, 2, 2) mesh. + Expectation: Output matches standalone execution. + """ + master_port = 11016 + file_name = "parallel_op_flash_attention_score.py" + case_name = "test_bsh_dp_sp_hp_3d" + torchrun_case(file_name, case_name, master_port) + + +@arg_mark(plat_marks=["platform_ascend910b"], level_mark="level1", card_mark="allcards", essential_mark="unessential") +def test_bnsd_dp_hp(): + """ + Feature: test npu_fusion_attention with BNSD layout, DP + HP. + Description: Shard batch and head on (4, 2) mesh. + Expectation: Correct local output shape. + """ + master_port = 11017 + file_name = "parallel_op_flash_attention_score.py" + case_name = "test_bnsd_dp_hp" + torchrun_case(file_name, case_name, master_port) + + +@arg_mark(plat_marks=["platform_ascend910b"], level_mark="level1", card_mark="allcards", essential_mark="unessential") +def test_bnsd_sp(): + """ + Feature: test npu_fusion_attention with BNSD layout, DP + SP. 
+ Description: Q shard on seq dim 2, KV replicate on seq. + Expectation: Correct local output shape. + """ + master_port = 11018 + file_name = "parallel_op_flash_attention_score.py" + case_name = "test_bnsd_sp" + torchrun_case(file_name, case_name, master_port) + + +@arg_mark(plat_marks=["platform_ascend910b"], level_mark="level1", card_mark="allcards", essential_mark="unessential") +def test_sbh_dp(): + """ + Feature: test npu_fusion_attention with SBH layout, data parallelism. + Description: Shard on batch dim (dim 1 in SBH). + Expectation: Correct local output shape. + """ + master_port = 11019 + file_name = "parallel_op_flash_attention_score.py" + case_name = "test_sbh_dp" + torchrun_case(file_name, case_name, master_port) + + +@arg_mark(plat_marks=["platform_ascend910b"], level_mark="level1", card_mark="allcards", essential_mark="unessential") +def test_bsnd_dp_hp(): + """ + Feature: test npu_fusion_attention with BSND layout, DP + HP. + Description: Shard batch and head on (4, 2) mesh. + Expectation: Correct local output shape. + """ + master_port = 11020 + file_name = "parallel_op_flash_attention_score.py" + case_name = "test_bsnd_dp_hp" + torchrun_case(file_name, case_name, master_port) + + +@arg_mark(plat_marks=["platform_ascend910b"], level_mark="level1", card_mark="allcards", essential_mark="unessential") +def test_tnd_dp(): + """ + Feature: test npu_fusion_attention with TND layout, data parallelism. + Description: Shard T-dim with cumulative actual_seq_qlen/kvlen. + Expectation: Correct local output shape. + """ + master_port = 11021 + file_name = "parallel_op_flash_attention_score.py" + case_name = "test_tnd_dp" + torchrun_case(file_name, case_name, master_port) + + +@arg_mark(plat_marks=["platform_ascend910b"], level_mark="level1", card_mark="allcards", essential_mark="unessential") +def test_tnd_hp(): + """ + Feature: test npu_fusion_attention with TND layout, head parallelism. + Description: Shard N-dim, head_num adjusted. + Expectation: Correct local output shape. + """ + master_port = 11022 + file_name = "parallel_op_flash_attention_score.py" + case_name = "test_tnd_hp" + torchrun_case(file_name, case_name, master_port) + + +@arg_mark(plat_marks=["platform_ascend910b"], level_mark="level1", card_mark="allcards", essential_mark="unessential") +def test_tnd_dp_hp(): + """ + Feature: test npu_fusion_attention with TND layout, DP + HP on 2D mesh. + Description: Shard T-dim and N-dim on (4, 2) mesh. + Expectation: Correct local output shape. + """ + master_port = 11023 + file_name = "parallel_op_flash_attention_score.py" + case_name = "test_tnd_dp_hp" + torchrun_case(file_name, case_name, master_port) + + +@arg_mark(plat_marks=["platform_ascend910b"], level_mark="level0", card_mark="allcards", essential_mark="essential") +def test_sp_sparse_mode_0(): + """ + Feature: test npu_fusion_attention SP with sparse_mode=0. + Description: Sequence parallel with default mask, no pre/next token adjustment. + Expectation: Output matches standalone execution. + """ + master_port = 11024 + file_name = "parallel_op_flash_attention_score.py" + case_name = "test_sp_sparse_mode_0" + torchrun_case(file_name, case_name, master_port) + + +@arg_mark(plat_marks=["platform_ascend910b"], level_mark="level0", card_mark="allcards", essential_mark="essential") +def test_sp_sparse_mode_2(): + """ + Feature: test npu_fusion_attention SP with sparse_mode=2. + Description: Sequence parallel with left_up_causal mask, offset via LEFT_UP_TO_RIGHT_DOWN. + Expectation: Output matches standalone execution. 
+ """ + master_port = 11025 + file_name = "parallel_op_flash_attention_score.py" + case_name = "test_sp_sparse_mode_2" + torchrun_case(file_name, case_name, master_port) + + +@arg_mark(plat_marks=["platform_ascend910b"], level_mark="level0", card_mark="allcards", essential_mark="essential") +def test_sp_sparse_mode_3(): + """ + Feature: test npu_fusion_attention SP with sparse_mode=3. + Description: Sequence parallel with right_down_causal mask, offset via RIGHT_DOWN_TO_RIGHT_DOWN. + Expectation: Output matches standalone execution. + """ + master_port = 11026 + file_name = "parallel_op_flash_attention_score.py" + case_name = "test_sp_sparse_mode_3" + torchrun_case(file_name, case_name, master_port) + + +@arg_mark(plat_marks=["platform_ascend910b"], level_mark="level0", card_mark="allcards", essential_mark="essential") +def test_sp_sparse_mode_4(): + """ + Feature: test npu_fusion_attention SP with sparse_mode=4. + Description: Sequence parallel with band mask and custom pre/next_tockens. + Expectation: Output matches standalone execution. + """ + master_port = 11027 + file_name = "parallel_op_flash_attention_score.py" + case_name = "test_sp_sparse_mode_4" + torchrun_case(file_name, case_name, master_port) + + +@arg_mark(plat_marks=["platform_ascend910b"], level_mark="level0", card_mark="allcards", essential_mark="essential") +def test_dp_sparse_mode_1(): + """ + Feature: test npu_fusion_attention DP with sparse_mode=1. + Description: Data parallel with all_mask, requires atten_mask provided. + Expectation: Output matches standalone execution. + """ + master_port = 11028 + file_name = "parallel_op_flash_attention_score.py" + case_name = "test_dp_sparse_mode_1" + torchrun_case(file_name, case_name, master_port) + + +@arg_mark(plat_marks=["platform_ascend910b"], level_mark="level0", card_mark="allcards", essential_mark="essential") +def test_dp_sparse_mode_4(): + """ + Feature: test npu_fusion_attention DP with sparse_mode=4. + Description: Data parallel with band mask and custom pre/next_tockens. + Expectation: Output matches standalone execution. + """ + master_port = 11029 + file_name = "parallel_op_flash_attention_score.py" + case_name = "test_dp_sparse_mode_4" + torchrun_case(file_name, case_name, master_port) + + +@arg_mark(plat_marks=["platform_ascend910b"], level_mark="level0", card_mark="allcards", essential_mark="essential") +def test_error_kv_strategy_mismatch(): + """ + Feature: test npu_fusion_attention rejects mismatched KV strategies. + Description: Key and Value have different sharding strategies. + Expectation: Raise ValueError. + """ + master_port = 11030 + file_name = "parallel_op_flash_attention_score.py" + case_name = "test_error_kv_strategy_mismatch" + torchrun_case(file_name, case_name, master_port) + + +@arg_mark(plat_marks=["platform_ascend910b"], level_mark="level0", card_mark="allcards", essential_mark="essential") +def test_error_head_num_not_divisible(): + """ + Feature: test npu_fusion_attention rejects non-divisible head_num. + Description: head_num=17 is not divisible by head_split_num=8. + Expectation: Raise ValueError. + """ + master_port = 11031 + file_name = "parallel_op_flash_attention_score.py" + case_name = "test_error_head_num_not_divisible" + torchrun_case(file_name, case_name, master_port) + + +@arg_mark(plat_marks=["platform_ascend910b"], level_mark="level0", card_mark="allcards", essential_mark="essential") +def test_error_kv_seq_sharding_blocked(): + """ + Feature: test npu_fusion_attention blocks KV sequence sharding for BSH. 
+ Description: Both Q and KV sharded on seq dim without Ring Attention. + Expectation: Raise NotImplementedError. + """ + master_port = 11032 + file_name = "parallel_op_flash_attention_score.py" + case_name = "test_error_kv_seq_sharding_blocked" + torchrun_case(file_name, case_name, master_port) + + +@arg_mark(plat_marks=["platform_ascend910b"], level_mark="level0", card_mark="allcards", essential_mark="essential") +def test_error_sparse_mode_1_no_mask(): + """ + Feature: test npu_fusion_attention rejects sparse_mode=1 without mask. + Description: sparse_mode=1 (allMask) requires atten_mask to be provided. + Expectation: Raise ValueError. + """ + master_port = 11033 + file_name = "parallel_op_flash_attention_score.py" + case_name = "test_error_sparse_mode_1_no_mask" + torchrun_case(file_name, case_name, master_port) + + +@arg_mark(plat_marks=["platform_ascend910b"], level_mark="level0", card_mark="allcards", essential_mark="essential") +def test_error_tnd_sp_without_actual_seq_len(): + """ + Feature: test npu_fusion_attention rejects TND SP without actual_seq_len. + Description: TND with sequence parallel requires actual_seq_qlen and actual_seq_kvlen. + Expectation: Raise ValueError. + """ + master_port = 11034 + file_name = "parallel_op_flash_attention_score.py" + case_name = "test_error_tnd_sp_without_actual_seq_len" + torchrun_case(file_name, case_name, master_port) + + +@arg_mark(plat_marks=["platform_ascend910b"], level_mark="level0", card_mark="allcards", essential_mark="essential") +def test_error_bnsd_kv_seq_sharding_blocked(): + """ + Feature: test npu_fusion_attention blocks KV sequence sharding for BNSD. + Description: Both Q and KV sharded on seq dim in BNSD layout. + Expectation: Raise NotImplementedError. + """ + master_port = 11035 + file_name = "parallel_op_flash_attention_score.py" + case_name = "test_error_bnsd_kv_seq_sharding_blocked" + torchrun_case(file_name, case_name, master_port) + + +@arg_mark(plat_marks=["platform_ascend910b"], level_mark="level1", card_mark="allcards", essential_mark="unessential") +def test_bsh_custom_scale(): + """ + Feature: test npu_fusion_attention with custom scale value. + Description: BSH DP with scale=0.125 instead of default. + Expectation: Output matches standalone execution with same scale. + """ + master_port = 11036 + file_name = "parallel_op_flash_attention_score.py" + case_name = "test_bsh_custom_scale" + torchrun_case(file_name, case_name, master_port) + + +@arg_mark(plat_marks=["platform_ascend910b"], level_mark="level1", card_mark="allcards", essential_mark="unessential") +def test_bsh_dropout(): + """ + Feature: test npu_fusion_attention with dropout. + Description: BSH DP with keep_prob=0.9, smoke test due to randomness. + Expectation: Correct output shape without crash. + """ + master_port = 11037 + file_name = "parallel_op_flash_attention_score.py" + case_name = "test_bsh_dropout" + torchrun_case(file_name, case_name, master_port) + + +@arg_mark(plat_marks=["platform_ascend910b"], level_mark="level1", card_mark="allcards", essential_mark="unessential") +def test_bsh_long_sequence_sp(): + """ + Feature: test npu_fusion_attention SP with long sequence. + Description: BSH SP with seq_len=2048. + Expectation: Correct output shape. 
+ """ + master_port = 11038 + file_name = "parallel_op_flash_attention_score.py" + case_name = "test_bsh_long_sequence_sp" + torchrun_case(file_name, case_name, master_port) + + +@arg_mark(plat_marks=["platform_ascend910b"], level_mark="level1", card_mark="allcards", essential_mark="unessential") +def test_bsh_large_batch_dp(): + """ + Feature: test npu_fusion_attention DP with large batch. + Description: BSH DP with batch_size=64. + Expectation: Correct output shape. + """ + master_port = 11039 + file_name = "parallel_op_flash_attention_score.py" + case_name = "test_bsh_large_batch_dp" + torchrun_case(file_name, case_name, master_port) + + +@arg_mark(plat_marks=["platform_ascend910b"], level_mark="level1", card_mark="allcards", essential_mark="unessential") +def test_bsh_redistribute_then_attention(): + """ + Feature: test npu_fusion_attention after redistribute. + Description: Redistribute from Replicate to Shard before running attention. + Expectation: Output matches standalone execution. + """ + master_port = 11040 + file_name = "parallel_op_flash_attention_score.py" + case_name = "test_bsh_redistribute_then_attention" + torchrun_case(file_name, case_name, master_port) + + +@arg_mark(plat_marks=["platform_ascend910b"], level_mark="level1", card_mark="allcards", essential_mark="unessential") +def test_sp_sparse_mode_2_with_2way_split(): + """ + Feature: test npu_fusion_attention SP sparse_mode=2 with 2-way split. + Description: Verify LEFT_UP_TO_RIGHT_DOWN offset on (2, 4) mesh. + Expectation: Output matches standalone execution. + """ + master_port = 11041 + file_name = "parallel_op_flash_attention_score.py" + case_name = "test_sp_sparse_mode_2_with_2way_split" + torchrun_case(file_name, case_name, master_port) + + +@arg_mark(plat_marks=["platform_ascend910b"], level_mark="level1", card_mark="allcards", essential_mark="unessential") +def test_sp_sparse_mode_3_with_2way_split(): + """ + Feature: test npu_fusion_attention SP sparse_mode=3 with 2-way split. + Description: Verify RIGHT_DOWN_TO_RIGHT_DOWN offset on (2, 4) mesh. + Expectation: Output matches standalone execution. + """ + master_port = 11042 + file_name = "parallel_op_flash_attention_score.py" + case_name = "test_sp_sparse_mode_3_with_2way_split" + torchrun_case(file_name, case_name, master_port) + + +@arg_mark(plat_marks=["platform_ascend910b"], level_mark="level1", card_mark="allcards", essential_mark="unessential") +def test_bnsd_sp_correctness(): + """ + Feature: test npu_fusion_attention BNSD SP correctness. + Description: BNSD layout with Q seq shard, KV replicate, compare with standalone. + Expectation: Output matches standalone execution. + """ + master_port = 11043 + file_name = "parallel_op_flash_attention_score.py" + case_name = "test_bnsd_sp_correctness" + torchrun_case(file_name, case_name, master_port) + + +@arg_mark(plat_marks=["platform_ascend910b"], level_mark="level1", card_mark="allcards", essential_mark="unessential") +def test_tnd_dp_correctness(): + """ + Feature: test npu_fusion_attention TND DP correctness with actual_seq_len adjustment. + Description: TND layout with 8-way DP on T dimension, verifies pure DP correctly + adjusts actual_seq_qlen/actual_seq_kvlen via clamp while skipping sparse params + adjustment. Gather output and compare with standalone. + Expectation: Output matches standalone execution. + """ + master_port = 11044 + file_name = "parallel_op_flash_attention_score.py" + case_name = "test_tnd_dp_correctness" + torchrun_case(file_name, case_name, master_port) -- Gitee