support intersection by polyfill (#130672)

Fixes https://github.com/pytorch/pytorch/issues/130557

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130672
Approved by: https://github.com/anijain2305
This commit is contained in:
awayzjj 2024-07-14 10:44:26 +00:00 committed by PyTorch MergeBot
parent 4d7bf72d93
commit dcaa111dc8
3 changed files with 87 additions and 0 deletions

View file

@ -1252,6 +1252,51 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
test = make_test(fn)
test(self)
@make_test
def test_set_intersection(a, b):
set1 = {"apple", "banana", "cherry"}
set2 = {"google", "microsoft", "apple"}
intersection_set = set1.intersection(set2)
if "apple" in intersection_set:
x = a + b
else:
x = a - b
if "banana" in intersection_set:
y = a + b
else:
y = a - b
return x, y
@make_test
def test_set_union(a, b):
set1 = {"apple", "banana", "cherry"}
set2 = {"google", "microsoft", "apple"}
union_set = set1.union(set2)
if "apple" in union_set:
x = a + b
else:
x = a - b
if "banana" in union_set:
y = a + b
else:
y = a - b
return x, y
@make_test
def test_set_difference(a, b):
set1 = {"apple", "banana", "cherry"}
set2 = {"google", "microsoft", "apple"}
difference_set = set1.difference(set2)
if "apple" in difference_set:
x = a + b
else:
x = a - b
if "banana" in difference_set:
y = a + b
else:
y = a - b
return x, y
@make_test
def test_tuple_iadd(a, b):
output = (a, b)

View file

@ -63,6 +63,30 @@ def set_isdisjoint(set1, set2):
return True
def set_intersection(set1, set2):
intersection_set = set()
for x in set1:
if x in set2:
intersection_set.add(x)
return intersection_set
def set_union(set1, set2):
union_set = set1.copy()
for x in set2:
if x not in union_set:
union_set.add(x)
return union_set
def set_difference(set1, set2):
difference_set = set()
for x in set1:
if x not in set2:
difference_set.add(x)
return difference_set
def dropwhile(predicate, iterable):
# dropwhile(lambda x: x<5, [1,4,6,4,1]) -> 6 4 1
iterable = iter(iterable)

View file

@ -444,6 +444,24 @@ class SetVariable(ConstDictVariable):
return variables.UserFunctionVariable(
polyfill.set_isdisjoint
).call_function(tx, [self, args[0]], {})
elif name == "intersection":
assert not kwargs
assert len(args) == 1
return variables.UserFunctionVariable(
polyfill.set_intersection
).call_function(tx, [self, args[0]], {})
elif name == "union":
assert not kwargs
assert len(args) == 1
return variables.UserFunctionVariable(polyfill.set_union).call_function(
tx, [self, args[0]], {}
)
elif name == "difference":
assert not kwargs
assert len(args) == 1
return variables.UserFunctionVariable(
polyfill.set_difference
).call_function(tx, [self, args[0]], {})
elif (
name == "update"
and len(args) == 1