mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Tweak schema_check to handle annotated builtin types (#145154)
As of python 3.9 annotated lists can be written as `list[T]` and `List[T]` has been deprecated. However schema_check was converting `list[T]` to simply be `list`. This change teaches it to handle `list[T]` the same as `List[T]`. A couple small drive-by changes I noticed as well: - Path concatenation should use `os.path.join`, not `+` - Spelling in error message Pull Request resolved: https://github.com/pytorch/pytorch/pull/145154 Approved by: https://github.com/bobrenjc93
This commit is contained in:
parent
9e0437a04a
commit
cd8d0fa20c
3 changed files with 26 additions and 26 deletions
|
|
@ -80,9 +80,9 @@ if __name__ == "__main__":
|
||||||
print(yaml_content)
|
print(yaml_content)
|
||||||
print("\nWill write the above schema to" + args.prefix + commit.yaml_path)
|
print("\nWill write the above schema to" + args.prefix + commit.yaml_path)
|
||||||
else:
|
else:
|
||||||
with open(args.prefix + commit.yaml_path, "w") as f:
|
with open(os.path.join(args.prefix, commit.yaml_path), "w") as f:
|
||||||
f.write(yaml_content)
|
f.write(yaml_content)
|
||||||
with open(args.prefix + commit.cpp_header_path, "w") as f:
|
with open(os.path.join(args.prefix, commit.cpp_header_path), "w") as f:
|
||||||
f.write(cpp_header)
|
f.write(cpp_header)
|
||||||
with open(args.prefix + commit.thrift_schema_path, "w") as f:
|
with open(os.path.join(args.prefix, commit.thrift_schema_path), "w") as f:
|
||||||
f.write(thrift_schema)
|
f.write(thrift_schema)
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,7 @@ class TestSchema(TestCase):
|
||||||
msg = """
|
msg = """
|
||||||
Detected an invalidated change to export schema. Please run the following script to update the schema:
|
Detected an invalidated change to export schema. Please run the following script to update the schema:
|
||||||
Example(s):
|
Example(s):
|
||||||
python scripts/export/update_schema.py --prefix <path_to_torch_development_diretory>
|
python scripts/export/update_schema.py --prefix <path_to_torch_development_directory>
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if IS_FBCODE:
|
if IS_FBCODE:
|
||||||
|
|
@ -32,7 +32,7 @@ Example(s):
|
||||||
msg = """
|
msg = """
|
||||||
Detected an unexpected change to schema.thrift. Please update schema.py instead and run the following script:
|
Detected an unexpected change to schema.thrift. Please update schema.py instead and run the following script:
|
||||||
Example(s):
|
Example(s):
|
||||||
python scripts/export/update_schema.py --prefix <path_to_torch_development_diretory>
|
python scripts/export/update_schema.py --prefix <path_to_torch_development_directory>
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if IS_FBCODE:
|
if IS_FBCODE:
|
||||||
|
|
|
||||||
|
|
@ -20,6 +20,21 @@ def _check(x, msg):
|
||||||
raise SchemaUpdateError(msg)
|
raise SchemaUpdateError(msg)
|
||||||
|
|
||||||
|
|
||||||
|
_CPP_TYPE_MAP = {
|
||||||
|
str: "std::string",
|
||||||
|
int: "int64_t",
|
||||||
|
float: "double",
|
||||||
|
bool: "bool",
|
||||||
|
}
|
||||||
|
|
||||||
|
_THRIFT_TYPE_MAP = {
|
||||||
|
str: "string",
|
||||||
|
int: "i64",
|
||||||
|
float: "double",
|
||||||
|
bool: "bool",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def _staged_schema():
|
def _staged_schema():
|
||||||
yaml_ret: Dict[str, Any] = {}
|
yaml_ret: Dict[str, Any] = {}
|
||||||
defs = {}
|
defs = {}
|
||||||
|
|
@ -32,27 +47,10 @@ def _staged_schema():
|
||||||
|
|
||||||
def _handle_aggregate(ty) -> tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
|
def _handle_aggregate(ty) -> tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
|
||||||
def dump_type(t, level: int) -> tuple[str, str, str]:
|
def dump_type(t, level: int) -> tuple[str, str, str]:
|
||||||
CPP_TYPE_MAP = {
|
if getattr(t, "__name__", None) in cpp_enum_defs:
|
||||||
str: "std::string",
|
return t.__name__, "int64_t", t.__name__
|
||||||
int: "int64_t",
|
elif t in _CPP_TYPE_MAP:
|
||||||
float: "double",
|
return (t.__name__, _CPP_TYPE_MAP[t], _THRIFT_TYPE_MAP[t])
|
||||||
bool: "bool",
|
|
||||||
}
|
|
||||||
THRIFT_TYPE_MAP = {
|
|
||||||
str: "string",
|
|
||||||
int: "i64",
|
|
||||||
float: "double",
|
|
||||||
bool: "bool",
|
|
||||||
}
|
|
||||||
if isinstance(t, type):
|
|
||||||
if t.__name__ in cpp_enum_defs:
|
|
||||||
return t.__name__, "int64_t", t.__name__
|
|
||||||
else:
|
|
||||||
return (
|
|
||||||
t.__name__,
|
|
||||||
CPP_TYPE_MAP.get(t, t.__name__),
|
|
||||||
THRIFT_TYPE_MAP.get(t, t.__name__),
|
|
||||||
)
|
|
||||||
elif isinstance(t, str):
|
elif isinstance(t, str):
|
||||||
assert t in defs
|
assert t in defs
|
||||||
assert t not in cpp_enum_defs
|
assert t not in cpp_enum_defs
|
||||||
|
|
@ -102,6 +100,8 @@ def _staged_schema():
|
||||||
(f"{cpp_head}<{', '.join(cpp_arg_types)}>"),
|
(f"{cpp_head}<{', '.join(cpp_arg_types)}>"),
|
||||||
f"{thrift_head}{', '.join(thrift_arg_types)}{thrift_tail}",
|
f"{thrift_head}{', '.join(thrift_arg_types)}{thrift_tail}",
|
||||||
)
|
)
|
||||||
|
elif isinstance(t, type):
|
||||||
|
return (t.__name__, t.__name__, t.__name__)
|
||||||
else:
|
else:
|
||||||
raise AssertionError(f"Type {t} is not supported in export schema.")
|
raise AssertionError(f"Type {t} is not supported in export schema.")
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue