mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Add proper parse_tensor_constants support (#140558)
Fixes #140422 Pull Request resolved: https://github.com/pytorch/pytorch/pull/140558 Approved by: https://github.com/davidberard98
This commit is contained in:
parent
9d93c27025
commit
70060b0927
2 changed files with 70 additions and 3 deletions
|
|
@ -8237,6 +8237,42 @@ dedent """
|
|||
with self.assertRaises(RuntimeError):
|
||||
parse_ir(g, parse_tensor_constants=False)
|
||||
|
||||
def test_parse_scalar_tensor_constants(self):
|
||||
for dtype_str, dtype, value in [
|
||||
("Float", torch.float32, 1234.5),
|
||||
("Double", torch.float64, 1234.5),
|
||||
("BFloat16", torch.bfloat16, 123.5),
|
||||
("Int", torch.int32, 12345),
|
||||
("Long", torch.int64, 12345),
|
||||
("Short", torch.int16, 12345),
|
||||
]:
|
||||
g_str = f"""
|
||||
graph():
|
||||
%1 : {dtype_str}(requires_grad=0, device=cpu) = prim::Constant[value={{{value}}}]()
|
||||
return (%1)
|
||||
"""
|
||||
|
||||
jit_graph = parse_ir(g_str, parse_tensor_constants=True)
|
||||
|
||||
node = next(
|
||||
n
|
||||
for n in jit_graph.nodes()
|
||||
if isinstance(n.output().type(), torch.TensorType)
|
||||
)
|
||||
assert isinstance(node.output().type(), torch.TensorType)
|
||||
t = node.t("value")
|
||||
assert isinstance(t, torch.Tensor)
|
||||
self.assertEqual(t.dtype, dtype)
|
||||
self.assertEqual(t.item(), value)
|
||||
|
||||
with self.assertRaises(RuntimeError):
|
||||
g_str = """
|
||||
graph():
|
||||
%1 : Long(requires_grad=0, device=cpu) = prim::Constant[value={invalid}]()
|
||||
return (%1)
|
||||
"""
|
||||
jit_graph = parse_ir(g_str, parse_tensor_constants=True)
|
||||
|
||||
def test_parse_nested_names(self):
|
||||
g_str = """
|
||||
graph(%x.1 : Tensor):
|
||||
|
|
|
|||
|
|
@ -271,15 +271,39 @@ ParsedLiteral IRParser::parseScalarLiteral(Node* n) {
|
|||
if (L.cur().kind == '-') {
|
||||
L.next();
|
||||
}
|
||||
auto text = L.expect(TK_NUMBER);
|
||||
if (!parse_tensor_constants_) {
|
||||
auto text = L.expect(TK_NUMBER);
|
||||
throw(
|
||||
ErrorReport(token.range)
|
||||
<< "Single-element tensor constant encountered but "
|
||||
<< "`parse_tensor_constants` is set to false " << token.text());
|
||||
}
|
||||
L.expect('}');
|
||||
if (L.cur().kind != TK_NUMBER) {
|
||||
auto text = L.expect(TK_NUMBER);
|
||||
throw(
|
||||
ErrorReport(token.range)
|
||||
<< "Expected single-element tensor constant to contain a number"
|
||||
<< token.text());
|
||||
}
|
||||
auto number = parseScalarLiteral(n);
|
||||
switch (number.k) {
|
||||
case AttributeKind::i:
|
||||
n->ival_(attr::value, c10::Scalar(number.i));
|
||||
break;
|
||||
case AttributeKind::f:
|
||||
n->ival_(attr::value, c10::Scalar(number.f));
|
||||
break;
|
||||
case AttributeKind::c:
|
||||
n->ival_(attr::value, c10::Scalar(number.c));
|
||||
break;
|
||||
default:
|
||||
throw(
|
||||
ErrorReport(token.range)
|
||||
<< "Expected single-element tensor constant to contain a number"
|
||||
<< token.text());
|
||||
}
|
||||
deferred_tensor_value_initializations_.push_back(n);
|
||||
L.expect('}');
|
||||
r.k = AttributeKind::t;
|
||||
return r;
|
||||
}
|
||||
|
|
@ -647,7 +671,14 @@ void IRParser::parse() {
|
|||
auto dtype = tt->scalarType();
|
||||
TORCH_INTERNAL_ASSERT(dtype);
|
||||
auto options = at::TensorOptions(*device).dtype(dtype);
|
||||
auto t = n->t_(attr::value, at::empty_strided(*sizes, *strides, options));
|
||||
|
||||
auto e = at::empty_strided(*sizes, *strides, options);
|
||||
if (n->hasAttribute(attr::value)) {
|
||||
auto value = n->ival(attr::value);
|
||||
e.fill_(value.toScalar());
|
||||
}
|
||||
|
||||
auto t = n->t_(attr::value, e);
|
||||
(void)t;
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue