[dynamo] support dict.copy() / OrderedDict.copy() / defaultdict.copy() (#115012)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115012
Approved by: https://github.com/jansel
ghstack dependencies: #115010, #115011
This commit is contained in:
Xuehai Pan 2023-12-03 17:40:25 +08:00 committed by PyTorch MergeBot
parent 917a52d2a2
commit 3fbfa8cd0a
7 changed files with 22 additions and 7 deletions

View file

@ -118,7 +118,7 @@ hf_T5,pass,0
hf_T5_generate,pass,20
hf_T5_generate,pass,18

1 name accuracy graph_breaks
118
119
120
121
122
123
124

View file

@ -118,7 +118,7 @@ hf_T5,pass,0
hf_T5_generate,fail_to_run,10
hf_T5_generate,fail_to_run,9

1 name accuracy graph_breaks
118
119
120
121
122
123
124

View file

@ -118,7 +118,7 @@ hf_T5,pass,0
hf_T5_generate,fail_to_run,10
hf_T5_generate,fail_to_run,9

1 name accuracy graph_breaks
118
119
120
121
122
123
124

View file

@ -118,7 +118,7 @@ hf_T5,pass,0
hf_T5_generate,pass,20
hf_T5_generate,pass,18

1 name accuracy graph_breaks
118
119
120
121
122
123
124

View file

@ -118,7 +118,7 @@ hf_T5,pass,0
hf_T5_generate,pass,20
hf_T5_generate,pass,18

1 name accuracy graph_breaks
118
119
120
121
122
123
124

View file

@ -817,6 +817,20 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
d3 = collections.OrderedDict.fromkeys(tuple(lst), value=y)
return d1["a"] * d2["b"] + d2["a"] + d1["b"] + d3["a"] + d3["b"] + 1
@make_test
def test_dict_copy(x):
my_list = [("a", x), ("b", x + 1), ("c", x + 2)]
d1 = dict(my_list)
d1["a"] = x + 10
d2 = d1.copy()
d2["a"] = x - 5
d2["b"] = x + 3
d3 = collections.OrderedDict(my_list)
d3["c"] = x + 20
d4 = d3.copy()
d4["c"] = x - 10
return d1["a"] * d2["a"] + d2["b"] + d3["c"] * d4["c"] + 1
@make_test
def test_dict_update(x, y, z):
d = {"a": x, "b": y}

View file

@ -89,7 +89,6 @@ class ConstDictVariable(VariableTracker):
if name == "__getitem__":
assert len(args) == 1
return self.getitem_const(args[0])
elif name == "items":
assert not (args or kwargs)
return TupleVariable(
@ -118,10 +117,12 @@ class ConstDictVariable(VariableTracker):
],
mutable_local=MutableLocal(),
)
elif name == "values":
assert not (args or kwargs)
return TupleVariable(list(val.values()))
elif name == "copy":
assert not (args or kwargs)
return self.modifed(self.items.copy(), mutable_local=MutableLocal())
elif name == "__len__":
assert not (args or kwargs)
return ConstantVariable.create(len(self.items))