mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[dynamo] support dict.copy() / OrderedDict.copy() / defaultdict.copy() (#115012)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115012 Approved by: https://github.com/jansel ghstack dependencies: #115010, #115011
This commit is contained in:
parent
917a52d2a2
commit
3fbfa8cd0a
7 changed files with 22 additions and 7 deletions
|
|
@ -118,7 +118,7 @@ hf_T5,pass,0
|
|||
|
||||
|
||||
|
||||
hf_T5_generate,pass,20
|
||||
hf_T5_generate,pass,18
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
|
|
|
@ -118,7 +118,7 @@ hf_T5,pass,0
|
|||
|
||||
|
||||
|
||||
hf_T5_generate,fail_to_run,10
|
||||
hf_T5_generate,fail_to_run,9
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
|
|
|
@ -118,7 +118,7 @@ hf_T5,pass,0
|
|||
|
||||
|
||||
|
||||
hf_T5_generate,fail_to_run,10
|
||||
hf_T5_generate,fail_to_run,9
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
|
|
|
@ -118,7 +118,7 @@ hf_T5,pass,0
|
|||
|
||||
|
||||
|
||||
hf_T5_generate,pass,20
|
||||
hf_T5_generate,pass,18
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
|
|
|
@ -118,7 +118,7 @@ hf_T5,pass,0
|
|||
|
||||
|
||||
|
||||
hf_T5_generate,pass,20
|
||||
hf_T5_generate,pass,18
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
|
|
|
@ -817,6 +817,20 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
|
|||
d3 = collections.OrderedDict.fromkeys(tuple(lst), value=y)
|
||||
return d1["a"] * d2["b"] + d2["a"] + d1["b"] + d3["a"] + d3["b"] + 1
|
||||
|
||||
@make_test
|
||||
def test_dict_copy(x):
|
||||
my_list = [("a", x), ("b", x + 1), ("c", x + 2)]
|
||||
d1 = dict(my_list)
|
||||
d1["a"] = x + 10
|
||||
d2 = d1.copy()
|
||||
d2["a"] = x - 5
|
||||
d2["b"] = x + 3
|
||||
d3 = collections.OrderedDict(my_list)
|
||||
d3["c"] = x + 20
|
||||
d4 = d3.copy()
|
||||
d4["c"] = x - 10
|
||||
return d1["a"] * d2["a"] + d2["b"] + d3["c"] * d4["c"] + 1
|
||||
|
||||
@make_test
|
||||
def test_dict_update(x, y, z):
|
||||
d = {"a": x, "b": y}
|
||||
|
|
|
|||
|
|
@ -89,7 +89,6 @@ class ConstDictVariable(VariableTracker):
|
|||
if name == "__getitem__":
|
||||
assert len(args) == 1
|
||||
return self.getitem_const(args[0])
|
||||
|
||||
elif name == "items":
|
||||
assert not (args or kwargs)
|
||||
return TupleVariable(
|
||||
|
|
@ -118,10 +117,12 @@ class ConstDictVariable(VariableTracker):
|
|||
],
|
||||
mutable_local=MutableLocal(),
|
||||
)
|
||||
|
||||
elif name == "values":
|
||||
assert not (args or kwargs)
|
||||
return TupleVariable(list(val.values()))
|
||||
elif name == "copy":
|
||||
assert not (args or kwargs)
|
||||
return self.modifed(self.items.copy(), mutable_local=MutableLocal())
|
||||
elif name == "__len__":
|
||||
assert not (args or kwargs)
|
||||
return ConstantVariable.create(len(self.items))
|
||||
|
|
|
|||
Loading…
Reference in a new issue