diff --git a/aten/src/ATen/core/function_schema.h b/aten/src/ATen/core/function_schema.h index ab021c69cf3..68e177a225d 100644 --- a/aten/src/ATen/core/function_schema.h +++ b/aten/src/ATen/core/function_schema.h @@ -406,7 +406,7 @@ inline std::ostream& operator<<(std::ostream& out, const Argument& arg) { if (arg.default_value()) { out << "="; - if (type->kind() == c10::TypeKind::StringType) { + if (type->kind() == c10::TypeKind::StringType || (unopt_type->kind() == c10::TypeKind::StringType && !arg.default_value().value().isNone())) { printQuotedString(out, arg.default_value().value().toStringRef()); } else { out << arg.default_value().value(); diff --git a/test/test_function_schema.py b/test/test_function_schema.py index 5a152737347..2f81c0bae76 100644 --- a/test/test_function_schema.py +++ b/test/test_function_schema.py @@ -86,5 +86,10 @@ class TestFunctionSchema(TestCase): new_schema = parse_schema('any(Tensor self, *, Tensor b, int[] c) -> Tensor') self.assertFalse(new_schema.is_backward_compatible_with(old_schema)) + def test_string_optional_parameter_default_value(self): + schema_a = parse_schema("example::op(str? order=\"NCHW\") -> (Tensor)") + schema_b = parse_schema(str(schema_a)) + self.assertEquals(schema_a, schema_b) + if __name__ == '__main__': run_tests()