diff --git a/test/test_jit.py b/test/test_jit.py index 9e3bef1e004..251ea0916f4 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -14184,6 +14184,43 @@ dedent """ FileCheck().check_not("prim::PythonOp").run(cu.test.graph) + def test_parse_generator(self): + def _test_parse_generator(seed): + jit_graph = parse_ir( + f""" + graph(): + %0 : float = prim::Constant[value=-0.31622776601683789]() + %1 : float = prim::Constant[value=0.31622776601683789]() + %2 : Generator = prim::Constant[value=torch.Generator(device="cpu", seed={seed})]() + %3 : NoneType = prim::Constant() + %4 : int[] = prim::Constant[value=[]]() + %5 : int = prim::Constant[value=6]() + %6 : Device = prim::Constant[value="cpu"]() + %7 : Tensor = aten::empty(%4, %5, %3, %6, %3, %3) + %8 : Float() = aten::uniform(%7, %0, %1, %2) + return (%8) + """, + ) + + node = next( + n + for n in jit_graph.nodes() + if isinstance(n.output().type(), torch._C._GeneratorType) + ) + assert isinstance(node.output().type(), torch._C._GeneratorType) + g = node.ival("value") + assert isinstance(g, torch.Generator) + self.assertEqual(g.initial_seed(), seed) + + _test_parse_generator(2024) + _test_parse_generator(2**63 - 1) + + with self.assertRaisesRegex(RuntimeError, "Seed must be a non-negative integer"): + _test_parse_generator(-2024) + + with self.assertRaisesRegex(RuntimeError, "Number is too big"): + _test_parse_generator(2**63) + def test_early_return_rewrite(self): def test_foo(x: bool): if x: diff --git a/torch/csrc/jit/ir/irparser.cpp b/torch/csrc/jit/ir/irparser.cpp index e04d7cdc93a..6f02a8849e0 100644 --- a/torch/csrc/jit/ir/irparser.cpp +++ b/torch/csrc/jit/ir/irparser.cpp @@ -427,6 +427,69 @@ void IRParser::parseAttr(Node* n) { } L.expect(')'); deferred_empty_container_initializations_.push_back(n); + } else if (L.cur().text() == "torch") { + L.next(); + L.expect('.'); + auto function = L.cur().text(); + if (function == "Generator") { + L.next(); + L.expect('('); + std::optional seed; + std::string device = "cpu"; + while (!L.nextIf(')')) { + auto arg = L.expect(TK_IDENT).text(); + L.expect('='); + if (arg == "device") { + ParsedLiteral r = parseScalarLiteral(n); + if (r.k != AttributeKind::s) { + throw( + ErrorReport(L.cur().range) + << "Expected string literal for device argument"); + } + if (r.s != "cpu") { + throw( + ErrorReport(L.cur().range) + << "Only cpu device is supported for Generator at this time."); + } + device = r.s; + } else if (arg == "seed") { + ParsedLiteral r = parseScalarLiteral(n); + if (r.k != AttributeKind::i) { + throw( + ErrorReport(L.cur().range) + << "Expected int literal for seed argument"); + } + if (r.i < 0) { + throw( + ErrorReport(L.cur().range) + << "Seed must be a non-negative integer"); + } + seed = r.i; + } else { + throw( + ErrorReport(L.cur().range) + << "Generator only supports the following arguments:\n" + << "- device\n" + << "- seed\n" + << "Got: " << arg); + } + L.nextIf(','); + } + if (device == "cpu") { + if (seed.has_value()) { + n->ival_( + Symbol::attr(attrname), at::detail::createCPUGenerator(*seed)); + } else { + n->ival_(Symbol::attr(attrname), at::detail::createCPUGenerator()); + } + } + } else { + throw( + ErrorReport(L.cur().range) + << "Expected one of the following torch functions:\n" + << "- Generator\n" + << "Got: " << function); + } } else { // scalar ParsedLiteral r = parseScalarLiteral(n);