[pytorch] Fix printing of optional string arguments in schemas (#55196)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/55196

This commit fixes printing of default values for optional string type arguments in schemas. At the moment, these default values are not printed as quoted strings. If a schema with an optional string type parameter with a default value that is not `None` is printed and then parsed, the lack of quotes causes a parsing error.
ghstack-source-id: 125655241

Test Plan: This commit adds a unit test to `test_function_schema.py` to test this case.

Differential Revision: D27525450

fbshipit-source-id: 23a93169e7599e7b385e59b7cfafb17fd76318b7
This commit is contained in:
Meghan Lele 2021-04-05 15:26:48 -07:00 committed by Facebook GitHub Bot
parent 2ee02b30b1
commit ef262575dd
2 changed files with 6 additions and 1 deletions

View file

@ -406,7 +406,7 @@ inline std::ostream& operator<<(std::ostream& out, const Argument& arg) {
if (arg.default_value()) {
out << "=";
if (type->kind() == c10::TypeKind::StringType) {
if (type->kind() == c10::TypeKind::StringType || (unopt_type->kind() == c10::TypeKind::StringType && !arg.default_value().value().isNone())) {
printQuotedString(out, arg.default_value().value().toStringRef());
} else {
out << arg.default_value().value();

View file

@ -86,5 +86,10 @@ class TestFunctionSchema(TestCase):
new_schema = parse_schema('any(Tensor self, *, Tensor b, int[] c) -> Tensor')
self.assertFalse(new_schema.is_backward_compatible_with(old_schema))
def test_string_optional_parameter_default_value(self):
schema_a = parse_schema("example::op(str? order=\"NCHW\") -> (Tensor)")
schema_b = parse_schema(str(schema_a))
self.assertEquals(schema_a, schema_b)
if __name__ == '__main__':
run_tests()