mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
support intersection by polyfill (#130672)
Fixes https://github.com/pytorch/pytorch/issues/130557 Pull Request resolved: https://github.com/pytorch/pytorch/pull/130672 Approved by: https://github.com/anijain2305
This commit is contained in:
parent
4d7bf72d93
commit
dcaa111dc8
3 changed files with 87 additions and 0 deletions
|
|
@ -1252,6 +1252,51 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
|
|||
test = make_test(fn)
|
||||
test(self)
|
||||
|
||||
@make_test
|
||||
def test_set_intersection(a, b):
|
||||
set1 = {"apple", "banana", "cherry"}
|
||||
set2 = {"google", "microsoft", "apple"}
|
||||
intersection_set = set1.intersection(set2)
|
||||
if "apple" in intersection_set:
|
||||
x = a + b
|
||||
else:
|
||||
x = a - b
|
||||
if "banana" in intersection_set:
|
||||
y = a + b
|
||||
else:
|
||||
y = a - b
|
||||
return x, y
|
||||
|
||||
@make_test
|
||||
def test_set_union(a, b):
|
||||
set1 = {"apple", "banana", "cherry"}
|
||||
set2 = {"google", "microsoft", "apple"}
|
||||
union_set = set1.union(set2)
|
||||
if "apple" in union_set:
|
||||
x = a + b
|
||||
else:
|
||||
x = a - b
|
||||
if "banana" in union_set:
|
||||
y = a + b
|
||||
else:
|
||||
y = a - b
|
||||
return x, y
|
||||
|
||||
@make_test
|
||||
def test_set_difference(a, b):
|
||||
set1 = {"apple", "banana", "cherry"}
|
||||
set2 = {"google", "microsoft", "apple"}
|
||||
difference_set = set1.difference(set2)
|
||||
if "apple" in difference_set:
|
||||
x = a + b
|
||||
else:
|
||||
x = a - b
|
||||
if "banana" in difference_set:
|
||||
y = a + b
|
||||
else:
|
||||
y = a - b
|
||||
return x, y
|
||||
|
||||
@make_test
|
||||
def test_tuple_iadd(a, b):
|
||||
output = (a, b)
|
||||
|
|
|
|||
|
|
@ -63,6 +63,30 @@ def set_isdisjoint(set1, set2):
|
|||
return True
|
||||
|
||||
|
||||
def set_intersection(set1, set2):
|
||||
intersection_set = set()
|
||||
for x in set1:
|
||||
if x in set2:
|
||||
intersection_set.add(x)
|
||||
return intersection_set
|
||||
|
||||
|
||||
def set_union(set1, set2):
|
||||
union_set = set1.copy()
|
||||
for x in set2:
|
||||
if x not in union_set:
|
||||
union_set.add(x)
|
||||
return union_set
|
||||
|
||||
|
||||
def set_difference(set1, set2):
|
||||
difference_set = set()
|
||||
for x in set1:
|
||||
if x not in set2:
|
||||
difference_set.add(x)
|
||||
return difference_set
|
||||
|
||||
|
||||
def dropwhile(predicate, iterable):
|
||||
# dropwhile(lambda x: x<5, [1,4,6,4,1]) -> 6 4 1
|
||||
iterable = iter(iterable)
|
||||
|
|
|
|||
|
|
@ -444,6 +444,24 @@ class SetVariable(ConstDictVariable):
|
|||
return variables.UserFunctionVariable(
|
||||
polyfill.set_isdisjoint
|
||||
).call_function(tx, [self, args[0]], {})
|
||||
elif name == "intersection":
|
||||
assert not kwargs
|
||||
assert len(args) == 1
|
||||
return variables.UserFunctionVariable(
|
||||
polyfill.set_intersection
|
||||
).call_function(tx, [self, args[0]], {})
|
||||
elif name == "union":
|
||||
assert not kwargs
|
||||
assert len(args) == 1
|
||||
return variables.UserFunctionVariable(polyfill.set_union).call_function(
|
||||
tx, [self, args[0]], {}
|
||||
)
|
||||
elif name == "difference":
|
||||
assert not kwargs
|
||||
assert len(args) == 1
|
||||
return variables.UserFunctionVariable(
|
||||
polyfill.set_difference
|
||||
).call_function(tx, [self, args[0]], {})
|
||||
elif (
|
||||
name == "update"
|
||||
and len(args) == 1
|
||||
|
|
|
|||
Loading…
Reference in a new issue