mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[pytorch] Fix printing of optional string arguments in schemas (#55196)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/55196 This commit fixes printing of default values for optional string type arguments in schemas. At the moment, these default values are not printed as quoted strings. If a schema with an optional string type parameter with a default value that is not `None` is printed and then parsed, the lack of quotes causes a parsing error. ghstack-source-id: 125655241 Test Plan: This commit adds a unit test to `test_function_schema.py` to test this case. Differential Revision: D27525450 fbshipit-source-id: 23a93169e7599e7b385e59b7cfafb17fd76318b7
This commit is contained in:
parent
2ee02b30b1
commit
ef262575dd
2 changed files with 6 additions and 1 deletions
|
|
@ -406,7 +406,7 @@ inline std::ostream& operator<<(std::ostream& out, const Argument& arg) {
|
|||
|
||||
if (arg.default_value()) {
|
||||
out << "=";
|
||||
if (type->kind() == c10::TypeKind::StringType) {
|
||||
if (type->kind() == c10::TypeKind::StringType || (unopt_type->kind() == c10::TypeKind::StringType && !arg.default_value().value().isNone())) {
|
||||
printQuotedString(out, arg.default_value().value().toStringRef());
|
||||
} else {
|
||||
out << arg.default_value().value();
|
||||
|
|
|
|||
|
|
@ -86,5 +86,10 @@ class TestFunctionSchema(TestCase):
|
|||
new_schema = parse_schema('any(Tensor self, *, Tensor b, int[] c) -> Tensor')
|
||||
self.assertFalse(new_schema.is_backward_compatible_with(old_schema))
|
||||
|
||||
def test_string_optional_parameter_default_value(self):
|
||||
schema_a = parse_schema("example::op(str? order=\"NCHW\") -> (Tensor)")
|
||||
schema_b = parse_schema(str(schema_a))
|
||||
self.assertEquals(schema_a, schema_b)
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_tests()
|
||||
|
|
|
|||
Loading…
Reference in a new issue