# Owner(s): ["module: dynamo"]
# flake8: noqa
import torch
import torch._dynamo
import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo.testing import (
    CompileCounter,
    CompileCounterWithBackend,
    EagerAndRecordGraphs,
    normalize_gm,
)


class TestInputAttrTracking(torch._dynamo.test_case.TestCase):
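    # Reading a tensor-valued attribute off an input tensor inside the compiled
    # region should lift that attribute to a graph input (the placeholder count
    # below checks this).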
    def test_tensor_property_on_tensor(self):
        def fn(x):
            return x * x.y

        x_ = torch.randn([2, 2])
        y_ = torch.randn([2, 2])
        x_.y = y_

        eager_result = fn(x_)

        graph = None

        def grab_graph_backend(gm, inps):
            nonlocal graph
            graph = gm
            return gm

        fn = torch._dynamo.optimize(grab_graph_backend, nopython=True)(fn)
        compile_result = fn(x_)
        self.assertEqual(eager_result, compile_result)

        placeholder_cnt = 0
        for node in graph.graph.nodes:
            if node.op == "placeholder":
                placeholder_cnt += 1

        # We want to be very sure that this lifts y to inputs!
        self.assertEqual(placeholder_cnt, 2)

    def test_tensor_property_assigned_on_tensor(self):
        def fn(x, y):
            x.y = y
            return x * x.y

        x_ = torch.randn([2, 2])
        y_ = torch.randn([2, 2])

        eager_result = fn(x_, y_)

        graph = None

        def grab_graph_backend(gm, inps):
            nonlocal graph
            graph = gm
            return gm

        fn = torch._dynamo.optimize(grab_graph_backend, nopython=True)(fn)
        compile_result = fn(x_, y_)
        self.assertEqual(eager_result, compile_result)

        placeholder_cnt = 0
        for node in graph.graph.nodes:
            if node.op == "placeholder":
                placeholder_cnt += 1

        # y is already an input
        self.assertEqual(placeholder_cnt, 2)

    def test_const_property_on_tensor(self):
        def fn(x):
            return x * x.y

        x_ = torch.randn([2, 2])
        y_ = 4
        x_.y = y_

        eager_result = fn(x_)

        graph = None

        def grab_graph_backend(gm, inps):
            nonlocal graph
            graph = gm
            return gm

        fn = torch._dynamo.optimize(grab_graph_backend, nopython=True)(fn)
        compile_result = fn(x_)
        self.assertEqual(eager_result, compile_result)

        placeholder_cnt = 0
        for node in graph.graph.nodes:
            if node.op == "placeholder":
                placeholder_cnt += 1

        # We want to be very sure that this does not lift y to inputs, as it is a const
        self.assertEqual(placeholder_cnt, 1)

    def test_const_property_assigned_on_tensor(self):
        def fn(x, y):
            x.y = y
            return x * x.y

        x_ = torch.randn([2, 2])
        y_ = 4

        eager_result = fn(x_, y_)

        fn = torch._dynamo.optimize("eager", nopython=True)(fn)
        compile_result = fn(x_, y_)
        self.assertEqual(eager_result, compile_result)
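
    # The same compiled fn is called first with a constant attribute and then a
    # tensor attribute; guards on the attribute's type must catch the change
    # rather than reuse the const-specialized graph.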
    def test_guards_correctly_property_assigned_on_tensor_type_change(self):
        def fn(x, y):
            x.y = y
            return x * x.y

        x_ = torch.randn([2, 2])

        fn = torch._dynamo.optimize("eager", nopython=True)(fn)
        compile_result_const = fn(x_, 4)
        self.assertEqual(compile_result_const, x_ * 4)

        y = torch.randn([2, 2])
        compile_result_tensor = fn(x_, y)
        self.assertEqual(compile_result_tensor, x_ * y)

    def test_guards_correctly_property_assigned_on_tensor_type_change_inductor(self):
        def fn(x, y):
            x.y = y
            return x * x.y

        x_ = torch.randn([2, 2])

        fn = torch._dynamo.optimize("inductor", nopython=True)(fn)
        compile_result_const = fn(x_, 4)
        self.assertEqual(compile_result_const, x_ * 4)

        y = torch.randn([2, 2])
        compile_result_tensor = fn(x_, y)
        self.assertEqual(compile_result_tensor, x_ * y)

    def test_complex_attr_access_without_graph_breaks(self):
        def fn(x, y, z):
            for t in x:
                t.y = y
                t.z = y * z

            new_y = 1
            new_z = 1
            for t in x:
                new_y = t.y * new_y
                new_z = t.z * new_z

            return new_y, new_z

        x_0 = torch.randn([2, 2])
        x_1 = torch.randn([2, 2])
        x_2 = torch.randn([2, 2])
        x = [x_0, x_1, x_2]

        y = torch.randn([2, 2])
        z = 5

        eager_result = fn(x, y, z)

        counter = CompileCounter()
        fn = torch._dynamo.optimize(counter, nopython=True)(fn)

        compile_result = fn(x, y, z)
        self.assertEqual(compile_result, eager_result)
        self.assertEqual(counter.frame_count, 1)
        self.assertEqual(counter.op_count, 9)
        # Graph for reference
        # -------------  ------  -----------------------  ------------------------------------  --------
        # placeholder    l_y_    L_y_                     ()                                     {}
        # call_function  mul     <built-in function mul>  (l_y_, 5)                              {}
        # call_function  mul_1   <built-in function mul>  (l_y_, 5)                              {}
        # call_function  mul_2   <built-in function mul>  (l_y_, 5)                              {}
        # call_function  mul_3   <built-in function mul>  (l_y_, 1)                              {}
        # call_function  mul_4   <built-in function mul>  (mul, 1)                               {}
        # call_function  mul_5   <built-in function mul>  (l_y_, mul_3)                          {}
        # call_function  mul_6   <built-in function mul>  (mul_1, mul_4)                         {}
        # call_function  mul_7   <built-in function mul>  (l_y_, mul_5)                          {}
        # call_function  mul_8   <built-in function mul>  (mul_2, mul_6)                         {}
        # output         output  output                   ((mul_7, mul_8, mul, mul_1, mul_2),)  {}
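
    # Same shape of test as above, but the print() in the middle forces a graph
    # break, so the attribute state set in the first half must survive into the
    # second compiled frame (frame_count == 2 below).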
    def test_complex_attr_access_with_graph_breaks(self):
        def fn(x, y, z):
            for t in x:
                t.y = y
                t.z = y * z

            print("Break!")

            new_y = 1
            new_z = 1
            for t in x:
                new_y = t.y * new_y
                new_z = t.z * new_z

            return new_y, new_z

        x_0 = torch.randn([2, 2])
        x_1 = torch.randn([2, 2])
        x_2 = torch.randn([2, 2])
        x = [x_0, x_1, x_2]

        y = torch.randn([2, 2])
        z = 5

        eager_result = fn(x, y, z)

        counter = CompileCounter()
        fn = torch._dynamo.optimize(counter, nopython=False)(fn)

        compile_result = fn(x, y, z)
        self.assertEqual(compile_result, eager_result)
        self.assertEqual(counter.frame_count, 2)
        self.assertEqual(counter.op_count, 9)
        # Graph for reference
        # -------------  ------  -----------------------  ----------------------  --------
        # placeholder    l_y_    L_y_                     ()                      {}
        # call_function  mul     <built-in function mul>  (l_y_, 5)               {}
        # call_function  mul_1   <built-in function mul>  (l_y_, 5)               {}
        # call_function  mul_2   <built-in function mul>  (l_y_, 5)               {}
        # output         output  output                   ((mul, mul_1, mul_2),)  {}
        # [GRAPH BREAK!]
        # -------------  -------  -----------------------  -----------------  --------
        # placeholder    l_x_0_y  L_x_0_y                  ()                 {}
        # placeholder    l_x_0_z  L_x_0_z                  ()                 {}
        # placeholder    l_x_1_y  L_x_1_y                  ()                 {}
        # placeholder    l_x_1_z  L_x_1_z                  ()                 {}
        # placeholder    l_x_2_y  L_x_2_y                  ()                 {}
        # placeholder    l_x_2_z  L_x_2_z                  ()                 {}
        # call_function  mul      <built-in function mul>  (l_x_0_y, 1)       {}
        # call_function  mul_1    <built-in function mul>  (l_x_0_z, 1)       {}
        # call_function  mul_2    <built-in function mul>  (l_x_1_y, mul)     {}
        # call_function  mul_3    <built-in function mul>  (l_x_1_z, mul_1)   {}
        # call_function  mul_4    <built-in function mul>  (l_x_2_y, mul_2)   {}
        # call_function  mul_5    <built-in function mul>  (l_x_2_z, mul_3)   {}
        # output         output   output                   ((mul_4, mul_5),)  {}
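
    # The attributes assigned here are constant ints, so the inlined
    # inline_test_fn folds to a constant (1 + 2 + 3 = 6) and only the two
    # multiplies by 6 remain in the graph (see the reference graph below).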
    def test_complex_attr_access_with_inline_reconstruct(self):
        def inline_test_fn(x, y, z):
            print("f")
            return x.a + y.a + z.a

        def fn(x, y, z):
            x.a = 1
            y.a = 2
            z.a = 3

            mult = inline_test_fn(x, y, z)
            y = y * mult
            x = x * mult
            return x, y

        x = torch.randn([2, 2])
        y = torch.randn([2, 2])
        z = torch.randn([2, 2])

        eager_result = fn(x, y, z)

        counter = CompileCounter()

        fn = torch._dynamo.optimize(counter, nopython=False)(fn)

        compile_result = fn(x, y, z)
        self.assertEqual(compile_result, eager_result)
        self.assertEqual(counter.frame_count, 1)
        self.assertEqual(counter.op_count, 2)
        # Graph for reference
        # __compiled_fn_2 <eval_with_key>.0 opcode  name  target  args  kwargs
        # -------------  ------  -----------------------  ---------------  --------
        # placeholder    l_x_    L_x_                     ()               {}
        # placeholder    l_y_    L_y_                     ()               {}
        # call_function  mul     <built-in function mul>  (l_y_, 6)        {}
        # call_function  mul_1   <built-in function mul>  (l_x_, 6)        {}
        # output         output  output                   ((mul_1, mul),)  {}
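
    # Setting .data on an input tensor is traced as the equivalent safe recipe:
    # disable grad, call set_(), then lower the tensor's version counter so the
    # mutation is hidden from the autograd engine (see the expected graph below).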
    def test_set_data_on_input_tensor(self):
        def fn(x, y):
            x.data = y.data
            if x.size() == y.size():
                return x * y
            else:
                return y * y

        x = torch.randn([5, 5])
        y = torch.randn([2, 2])

        eager_result = fn(x, y)

        eager_and_record = EagerAndRecordGraphs()

        counter = CompileCounterWithBackend(eager_and_record)

        fn = torch._dynamo.optimize(counter, nopython=True)(fn)

        compile_result = fn(x, y)

        graph = eager_and_record.graphs[0]
        actual = normalize_gm(graph.print_readable(False))

        self.assertEqual(compile_result, eager_result)
        self.assertEqual(counter.frame_count, 1)
        self.assertEqual(counter.op_count, 6)
        self.assertExpectedInline(
            actual,
            """\
class GraphModule(torch.nn.Module):
    def forward(self, L_y_ : torch.Tensor, L_x_ : torch.Tensor):
        l_y_ = L_y_
        l_x_ = L_x_

        detach = l_y_.detach()

        _set_grad_enabled = torch._C._set_grad_enabled(False)

        set_ = torch_Tensor_set_(l_x_, detach); detach = None

        _set_grad_enabled_1 = torch._C._set_grad_enabled(True)

        _lower_version_count_by_1 = torch__dynamo_variables_builtin__lower_version_count_by_1(set_); set_ = None

        mul = l_x_ * l_y_; l_x_ = l_y_ = None
        return (mul,)
""",
        )

    # Note - this does not actually get captured in the graph yet.
    # The plan of record is to introduce a set_data op, entirely subsume the operation into a call_function
    # in the fx graph, and let aot_autograd handle it.
    def test_set_data_on_scoped_tensor(self):
        def fn(x):
            z = torch.zeros([4, 4])
            z.data = x.data
            if x.size() == z.size():
                return z * x
            else:
                return x

        x = torch.randn([5, 5])

        eager_result = fn(x)

        counter = CompileCounter()

        fn = torch._dynamo.optimize(counter, nopython=False)(fn)

        compile_result = fn(x)
        self.assertEqual(compile_result, eager_result)
        self.assertEqual(counter.frame_count, 2)
        self.assertEqual(counter.op_count, 3)
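
    # Plain setattr of new attributes on a user-defined class instance inside
    # the compiled method; the intermediate values are computed in a single
    # graph and reconstructed on the instance (see the reference graph below).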
    def test_set_data_on_user_defined_class_input_tensor(self):
        class MyUserDefinedClass:
            def __init__(self, x, y):
                self.x = x
                self.y = y

            def do_some_setattr_stuff(self):
                self.z = x * y
                self.a = x + x
                return self.z * self.a

        x = torch.randn([5, 5])
        y = torch.randn([5, 5])
        mudc_1 = MyUserDefinedClass(x, y)

        eager_result = mudc_1.do_some_setattr_stuff()

        counter = CompileCounter()

        mudc_2 = MyUserDefinedClass(x, y)
        do_some_setattr_stuff = torch._dynamo.optimize(counter, nopython=True)(
            mudc_2.do_some_setattr_stuff
        )

        compile_result = do_some_setattr_stuff()
        self.assertEqual(compile_result, eager_result)
        self.assertEqual(counter.frame_count, 1)
        self.assertEqual(counter.op_count, 3)
        # Graph for reference
        # __compiled_fn_0 <eval_with_key>.0 opcode  name  target  args  kwargs
        # -------------  ------  -----------------------  --------------------  --------
        # placeholder    l_x_    L_x_                     ()                    {}
        # placeholder    l_y_    L_y_                     ()                    {}
        # call_function  mul     <built-in function mul>  (l_x_, l_y_)          {}
        # call_function  add     <built-in function add>  (l_x_, l_x_)          {}
        # call_function  mul_1   <built-in function mul>  (mul, add)            {}
        # output         output  output                   ((mul_1, mul, add),)  {}


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()