Add proper parse_tensor_constants support (#140558)

Fixes #140422

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140558
Approved by: https://github.com/davidberard98
This commit is contained in:
Antonio Kim 2024-11-13 23:06:24 +00:00 committed by PyTorch MergeBot
parent 9d93c27025
commit 70060b0927
2 changed files with 70 additions and 3 deletions

View file

@ -8237,6 +8237,42 @@ dedent """
with self.assertRaises(RuntimeError):
parse_ir(g, parse_tensor_constants=False)
def test_parse_scalar_tensor_constants(self):
for dtype_str, dtype, value in [
("Float", torch.float32, 1234.5),
("Double", torch.float64, 1234.5),
("BFloat16", torch.bfloat16, 123.5),
("Int", torch.int32, 12345),
("Long", torch.int64, 12345),
("Short", torch.int16, 12345),
]:
g_str = f"""
graph():
%1 : {dtype_str}(requires_grad=0, device=cpu) = prim::Constant[value={{{value}}}]()
return (%1)
"""
jit_graph = parse_ir(g_str, parse_tensor_constants=True)
node = next(
n
for n in jit_graph.nodes()
if isinstance(n.output().type(), torch.TensorType)
)
assert isinstance(node.output().type(), torch.TensorType)
t = node.t("value")
assert isinstance(t, torch.Tensor)
self.assertEqual(t.dtype, dtype)
self.assertEqual(t.item(), value)
with self.assertRaises(RuntimeError):
g_str = """
graph():
%1 : Long(requires_grad=0, device=cpu) = prim::Constant[value={invalid}]()
return (%1)
"""
jit_graph = parse_ir(g_str, parse_tensor_constants=True)
def test_parse_nested_names(self):
g_str = """
graph(%x.1 : Tensor):

View file

@ -271,15 +271,39 @@ ParsedLiteral IRParser::parseScalarLiteral(Node* n) {
if (L.cur().kind == '-') {
L.next();
}
auto text = L.expect(TK_NUMBER);
if (!parse_tensor_constants_) {
auto text = L.expect(TK_NUMBER);
throw(
ErrorReport(token.range)
<< "Single-element tensor constant encountered but "
<< "`parse_tensor_constants` is set to false " << token.text());
}
L.expect('}');
if (L.cur().kind != TK_NUMBER) {
auto text = L.expect(TK_NUMBER);
throw(
ErrorReport(token.range)
<< "Expected single-element tensor constant to contain a number"
<< token.text());
}
auto number = parseScalarLiteral(n);
switch (number.k) {
case AttributeKind::i:
n->ival_(attr::value, c10::Scalar(number.i));
break;
case AttributeKind::f:
n->ival_(attr::value, c10::Scalar(number.f));
break;
case AttributeKind::c:
n->ival_(attr::value, c10::Scalar(number.c));
break;
default:
throw(
ErrorReport(token.range)
<< "Expected single-element tensor constant to contain a number"
<< token.text());
}
deferred_tensor_value_initializations_.push_back(n);
L.expect('}');
r.k = AttributeKind::t;
return r;
}
@ -647,7 +671,14 @@ void IRParser::parse() {
auto dtype = tt->scalarType();
TORCH_INTERNAL_ASSERT(dtype);
auto options = at::TensorOptions(*device).dtype(dtype);
auto t = n->t_(attr::value, at::empty_strided(*sizes, *strides, options));
auto e = at::empty_strided(*sizes, *strides, options);
if (n->hasAttribute(attr::value)) {
auto value = n->ival(attr::value);
e.fill_(value.toScalar());
}
auto t = n->t_(attr::value, e);
(void)t;
}