diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py
index be22e333e6a80a1051d0bab69fabc02325c9d4bf..c0fb2ec4a57c5c5d47d3180e8d6c9cc5c8111e61 100644
--- a/test/dynamo/test_functions.py
+++ b/test/dynamo/test_functions.py
@@ -2583,6 +2583,7 @@ def forward(self, x_1, output_1):
 
         self.assertEqual(fn(z), fn_opt(z))
 
+    @torch._dynamo.config.patch(capture_func_transforms=True)
     def test_is_init_in_compile_vmapped_mutated_tensor_tensor(self):
         def fn(z):
             x = z.clone()
@@ -2596,6 +2597,7 @@ def forward(self, x_1, output_1):
 
         self.assertEqual(fn(z), fn_opt(z))
 
+    @torch._dynamo.config.patch(capture_func_transforms=True)
     def test_is_vmapped_mutated_tensor_tensor(self):
         def fn(x):
             y = torch.vmap(torch.Tensor.acos_)(x)
@@ -2606,7 +2608,8 @@ def forward(self, x_1, output_1):
         z = torch.ones(4, 1)
         self.assertEqual(fn(z), fn_opt(z))
 
-
+
+    @torch._dynamo.config.patch(capture_func_transforms=True)
     def test_is_init_in_compile_vmapped_mutated_tensor_tensor_multi_arg(self):
         def fn(y, z):
             a = y.clone()
diff --git a/test/dynamo/test_higher_order_ops.py b/test/dynamo/test_higher_order_ops.py
index 8696b535809da41ee49d046191a42b78c6743f81..5a2c447c55c776744bcd91d87520e4dca2571425 100644
--- a/test/dynamo/test_higher_order_ops.py
+++ b/test/dynamo/test_higher_order_ops.py
@@ -2078,7 +2078,8 @@ class GraphModule(torch.nn.Module):
             pprint.pformat(actual_stack),
             """{'add': ['map', 'map', 'add'], 'cos': ['map', 'cos'], 'sin': ['sin']}""",
         )
-
+
+    @config.patch(capture_func_transforms=True)
     def test_grad_source_fn_stack(self):
         backend = EagerAndRecordGraphs()
 
@@ -2100,7 +2101,8 @@ class GraphModule(torch.nn.Module):
 {'sin': ['grad_impl', 'grad_impl', 'sin'],
  'sum_1': ['grad_impl', 'grad_impl', 'sum_1']}""",
         )
-
+
+    @config.patch(capture_func_transforms=True)
     def test_vmap_source_fn_stack(self):
         backend = EagerAndRecordGraphs()
 
@@ -2230,7 +2232,8 @@ class FuncTorchHigherOrderOpTests(torch._dynamo.test_case.TestCase):
         wrapped_gm = backend.graphs[graph_idx]
         return wrapped_gm
 
-
+
+    @config.patch(capture_func_transforms=True)
     def test_grad(self):
         counters.clear()
 
@@ -2269,6 +2272,7 @@ class GraphModule(torch.nn.Module):
 """,
         )
 
+    @config.patch(capture_func_transforms=True)
     def test_grad_freevar_tensor(self):
         counters.clear()
         y = torch.randn(3, 3)
@@ -2284,6 +2288,7 @@ class GraphModule(torch.nn.Module):
         actual = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=True)(x)
         self.assertEqual(actual, expected)
 
+    @config.patch(capture_func_transforms=True)
     def test_grad_freevar_python_scalar(self):
         counters.clear()
         y = 3
@@ -2324,6 +2329,7 @@ class GraphModule(torch.nn.Module):
 """,
         )
 
+    @config.patch(capture_func_transforms=True)
     def test_grad_capture_tensor(self):
         counters.clear()
 
@@ -2368,6 +2374,7 @@ class GraphModule(torch.nn.Module):
 """,
         )
 
+    @config.patch(capture_func_transforms=True)
     def test_grad_closure_scalar(self):
         counters.clear()
 
@@ -2412,6 +2419,7 @@ class GraphModule(torch.nn.Module):
 """,
         )
 
+    @config.patch(capture_func_transforms=True)
     def test_grad_has_aux(self):
         counters.clear()
 
@@ -2456,6 +2464,7 @@ class GraphModule(torch.nn.Module):
 """,
         )
 
+    @config.patch(capture_func_transforms=True)
     def test_grad_two_tensor_has_aux(self):
         counters.clear()
 
@@ -2500,6 +2509,7 @@ class GraphModule(torch.nn.Module):
 """,
         )
 
+    @config.patch(capture_func_transforms=True)
     def test_grad_two_tensor_all_grad_has_aux(self):
         counters.clear()
 
@@ -2585,7 +2595,8 @@ class GraphModule(torch.nn.Module):
         return (sum_1, cos)
 """,
         )
-
+
+    @config.patch(capture_func_transforms=True)
     def test_grad_over_grad(self):
         counters.clear()
 
@@ -2631,6 +2642,7 @@ class GraphModule(torch.nn.Module):
 """,
         )
 
+    @config.patch(capture_func_transforms=True)
     def test_grad_with_graph_break(self):
         counters.clear()
 
@@ -2647,6 +2659,7 @@ class GraphModule(torch.nn.Module):
         self.assertEqual(len(counters["graph_break"]), 1)
         self.assertEqual(actual, expected)
 
+    @config.patch(capture_func_transforms=True)
     def test_grad_with_side_effect(self):
         counters.clear()
 
@@ -2672,6 +2685,7 @@ class GraphModule(torch.nn.Module):
         )
         self.assertEqual(actual, expected)
 
+    @config.patch(capture_func_transforms=True)
     def test_grad_pytree(self):
         counters.clear()
 
@@ -2695,7 +2709,8 @@ class GraphModule(torch.nn.Module):
             {".*HigherOrderOperator with body that accepts non-Tensors as input": 2},
         )
         self.assertEqual(actual, expected)
-
+
+    @config.patch(capture_func_transforms=True)
     def test_grad_non_tensor_input(self):
         counters.clear()
 
@@ -2735,7 +2750,8 @@ class GraphModule(torch.nn.Module):
             return add
 """,
         )
-
+
+    @config.patch(capture_func_transforms=True)
     def test_grad_disable_capture(self):
         counters.clear()
 
@@ -2762,7 +2778,8 @@ class GraphModule(torch.nn.Module):
             },
         )
         self.assertEqual(actual, expected)
-
+
+    @config.patch(capture_func_transforms=True)
     def test_grad_fn_with_kwargs(self):
         def fn(x, y):
             return (x + y).sum()
@@ -2781,6 +2798,7 @@ class GraphModule(torch.nn.Module):
         )
         self.assertEqual(actual, expected)
 
+    @config.patch(capture_func_transforms=True)
     def test_vmap(self):
         def fn(x):
             return torch.func.vmap(lambda x: x.sum(0) + x.sum(1))(x)
@@ -2816,7 +2834,8 @@ class GraphModule(torch.nn.Module):
             return add
 """,
         )
-
+
+    @config.patch(capture_func_transforms=True)
     def test_vmap_free_const(self):
         y = 3
 
@@ -2856,6 +2875,7 @@ class GraphModule(torch.nn.Module):
 """,
         )
 
+    @config.patch(capture_func_transforms=True)
     def test_vmap_free_tensor(self):
         y = torch.randn(3, 3)
 
@@ -2896,6 +2916,7 @@ class GraphModule(torch.nn.Module):
 """,
         )
 
+    @config.patch(capture_func_transforms=True)
     def test_vmap_two_inputs(self):
         def fn(x, y):
             return torch.func.vmap(
@@ -2938,6 +2959,7 @@ class GraphModule(torch.nn.Module):
 """,
         )
 
+    @config.patch(capture_func_transforms=True)
     def test_vmap_two_inputs_tuple_in_dims(self):
         in_dims = (0, 1)
 
@@ -2982,6 +3004,7 @@ class GraphModule(torch.nn.Module):
 """,
         )
 
+    @config.patch(capture_func_transforms=True)
     def test_vmap_over_vmap_two_inputs(self):
         def fn(x, y):
             return torch.func.vmap(torch.func.vmap(lambda x, y: x + y, in_dims=1))(x, y)
@@ -3028,7 +3051,8 @@ class GraphModule(torch.nn.Module):
             return add
 """,
         )
-
+
+    @config.patch(capture_func_transforms=True)
     def test_vmap_over_vmap_captured(self):
         x = torch.ones(2, 3)
         y = torch.ones(5, 3)
@@ -3074,7 +3098,8 @@ class GraphModule(torch.nn.Module):
             return mul
 """,
         )
-
+
+    @config.patch(capture_func_transforms=True)
     def test_vmap_multiple_outputs(self):
         x = torch.ones(2, 4, 3)
 
@@ -3112,7 +3137,8 @@ class GraphModule(torch.nn.Module):
             return (sum_1, sum_2)
 """,
         )
-
+
+    @config.patch(capture_func_transforms=True)
     def test_vmap_multiple_outputs_diff_dims(self):
         x = torch.ones(2, 4, 3)
 
@@ -3150,7 +3176,8 @@ class GraphModule(torch.nn.Module):
             return (sum_1, sum_2)
 """,
         )
-
+
+    @config.patch(capture_func_transforms=True)
     def test_vmap_multiple_outputs_out_dims_tuple(self):
         x = torch.ones(2, 4, 3)
         out_dims = (1, 0)
@@ -3189,7 +3216,8 @@ class GraphModule(torch.nn.Module):
             return (sum_1, sum_2)
 """,
         )
-
+
+    @config.patch(capture_func_transforms=True)
     def test_vmap_kwargs(self):
         counters.clear()
         x = torch.ones(2, 3)
@@ -3207,6 +3235,7 @@ class GraphModule(torch.nn.Module):
         )
         self.assertEqual(actual, expected)
 
+    @config.patch(capture_func_transforms=True)
     def test_vmap_pytree_inputs(self):
         counters.clear()
         x = torch.ones(2, 3)
@@ -3232,7 +3261,8 @@ class GraphModule(torch.nn.Module):
             },
         )
         self.assertEqual(actual, expected)
-
+
+    @config.patch(capture_func_transforms=True)
     def test_vmap_side_effects(self):
         counters.clear()
         x = torch.ones(2, 3)
@@ -3258,7 +3288,8 @@ class GraphModule(torch.nn.Module):
             },
         )
         self.assertEqual(actual, expected)
-
+
+    @config.patch(capture_func_transforms=True)
     def test_vmap_disable_capture(self):
         counters.clear()
 
@@ -3282,7 +3313,8 @@ class GraphModule(torch.nn.Module):
             },
         )
         self.assertEqual(actual, expected)
-
+
+    @config.patch(capture_func_transforms=True)
     def test_vmap_illegal_op_graph_break(self):
         counters.clear()
 
@@ -3303,7 +3335,8 @@ class GraphModule(torch.nn.Module):
             {".*Illegal getattr invocation stride in strict mode": 2},
         )
         self.assertEqual(actual, expected)
-
+
+    @config.patch(capture_func_transforms=True)
     def test_vmap_multiple_invocation_in_dims(self):
         counters.clear()
 
@@ -3319,7 +3352,8 @@ class GraphModule(torch.nn.Module):
         self.assertEqual(expected, actual)
         self.assertEqual(cnt.frame_count, 3)
         self.assertEqual(cnt.op_count, 9)
-
+
+    @config.patch(capture_func_transforms=True)
     def test_vmap_multiple_invocation_out_dims(self):
         counters.clear()
 
@@ -3335,7 +3369,8 @@ class GraphModule(torch.nn.Module):
         self.assertEqual(expected, actual)
         self.assertEqual(cnt.frame_count, 3)
         self.assertEqual(cnt.op_count, 9)
-
+
+    @config.patch(capture_func_transforms=True)
     def test_vmap_new_tensor_in_body(self):
         def fn(x):
             return x + torch.ones(3)
@@ -3350,7 +3385,8 @@ class GraphModule(torch.nn.Module):
        expected = wrapper_fn(x)
         actual = opt(x)
         self.assertEqual(expected, actual)
-
+
+    @config.patch(capture_func_transforms=True)
     def test_vmap_new_tensor_unused_in_body(self):
         def fn(x):
             return torch.tensor(0.5)
@@ -3363,7 +3399,8 @@ class GraphModule(torch.nn.Module):
         expected = wrapper_fn(x)
         actual = opt(x)
         self.assertEqual(expected, actual)
-
+
+    @config.patch(capture_func_transforms=True)
     def test_vmap_new_tensor_implicit_via_op(self):
         def wrapper_fn(x):
             return torch.func.vmap(lambda t: torch.add(t, 0.5))(x)
diff --git a/test/dynamo/test_interop.py b/test/dynamo/test_interop.py
index 48cd8ba4bdaa6ad0021f03f02f98d907cfbd1371..34f6194225ffb4bd66502ecc168249a778bc1c02 100644
--- a/test/dynamo/test_interop.py
+++ b/test/dynamo/test_interop.py
@@ -31,6 +31,7 @@ class InteropTests(torch._dynamo.test_case.TestCase):
         trace_fn = torch.jit.trace(fn, [torch.zeros(10), torch.zeros(10)])
         self._common(lambda a, b: trace_fn(a, b) + 1)
 
+    @torch._dynamo.config.patch(capture_func_transforms=True)
     def test_vmap_in_graph(self):
         from functools import wraps
 
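
Note on the pattern this patch applies everywhere: `config.patch` (here `torch._dynamo.config.patch`) returns an object usable both as a decorator and as a context manager, flipping the named flag only for the duration of the wrapped call and restoring the previous value afterwards. Below is a minimal sketch of the idiom these tests rely on; the function name `compiled_vmap`, the lambda, and the tensor shapes are illustrative, not taken from the diff:

    import torch
    import torch._dynamo

    # As a decorator: capture_func_transforms is True only while
    # compiled_vmap runs, then the previous value is restored.
    @torch._dynamo.config.patch(capture_func_transforms=True)
    def compiled_vmap(x):
        fn = torch.compile(backend="eager", fullgraph=True)(
            torch.func.vmap(lambda t: t.sum(0))
        )
        return fn(x)

    # As a context manager, with the same scoping behavior.
    x = torch.ones(2, 3)
    with torch._dynamo.config.patch(capture_func_transforms=True):
        expected = torch.func.vmap(lambda t: t.sum(0))(x)
    actual = compiled_vmap(x)
    assert torch.equal(expected, actual)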