import torch
import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests


class TestMultiStream(TestCase):
    """Exercise event-based synchronization between the default NPU stream
    and a secondary stream across several dependent pipeline stages."""

    def test_multi_stream(self):

        class Model(torch.nn.Module):
            """Run NUM_STAGES rounds of: generate a dropout mask on a side
            stream, overlap independent matmuls on the default stream, then
            run fusion attention once the mask is ready.

            Returns a ``(res, add_res)`` tuple holding the outputs of the
            final stage only (earlier stages' results are intentionally
            overwritten — the test targets stream ordering, not values).
            """

            NUM_STAGES = 4  # the original hand-unrolled exactly four stages

            def forward(self, query, key, value, x_mask):
                scale = 0.08838
                s = torch.npu.default_stream()
                s_cpu = torch.npu.Stream()

                res = add_res = None
                # Event recorded on the default stream after each stage's
                # attention kernel; the NEXT stage's mask generation waits
                # on it (the first stage has nothing to wait for).
                prev_fa_event = None
                for stage in range(self.NUM_STAGES):
                    # Event(enable_timing, blocking, interprocess) all False,
                    # matching the original construction.
                    mask_ready = torch.npu.Event(False, False, False)
                    with torch.npu.stream(s_cpu):
                        if prev_fa_event is not None:
                            s_cpu.wait_event(prev_fa_event)
                        # NOTE(review): the mask is requested with shape
                        # [1, 8, 2048, 2048] but viewed as (2048, 2048) —
                        # presumably the op returns a packed/compact mask;
                        # confirm against the torch_npu op's contract.
                        atten_mask = torch_npu._npu_dropout_gen_mask(
                            x_mask, [1, 8, 2048, 2048], p=0.5, seed=1,
                            offset=0, parallel=False).view(2048, 2048)
                        s_cpu.record_event(mask_ready)

                    # These matmuls are issued on the default stream with no
                    # fence, so they deliberately overlap the side-stream
                    # mask generation above.
                    qk = torch.matmul(query, key.transpose(2, 3)).mul(scale)
                    kv = torch.matmul(key, value.transpose(2, 3)).mul(scale)
                    qv = torch.matmul(query, value.transpose(2, 3)).mul(scale)
                    add_res = qk + kv + qv

                    # Attention consumes the mask: fence on the side stream.
                    s.wait_event(mask_ready)
                    # pre_tockens/next_tockens is the torch_npu parameter
                    # spelling — do not "fix" it.
                    res = torch_npu.npu_fusion_attention(
                        query, key, value, head_num=8, input_layout="BNSD",
                        scale=scale, sparse_mode=2, atten_mask=atten_mask,
                        pre_tockens=2048, next_tockens=2048)

                    # The original records a completion event after every
                    # stage EXCEPT the last; preserve that ordering exactly.
                    if stage < self.NUM_STAGES - 1:
                        prev_fa_event = torch.npu.Event(False, False, False)
                        s.record_event(prev_fa_event)

                return res, add_res

        # Drive the model once on device; the test passes if the stream/event
        # choreography executes without error.
        model = Model().npu()
        query = torch.randn(1, 8, 512, 512, dtype=torch.float16).npu()
        key = torch.randn(1, 8, 512, 512, dtype=torch.float16).npu()
        value = torch.randn(1, 8, 512, 512, dtype=torch.float16).npu()
        x_mask = torch.randn(1, 8, 512, 512, dtype=torch.float16).npu()
        model(query, key, value, x_mask)


if __name__ == '__main__':
    run_tests()