diff --git a/.gitignore b/.gitignore index c268b75bde3..71e9d56255e 100644 --- a/.gitignore +++ b/.gitignore @@ -51,7 +51,7 @@ test/custom_operator/model.pt test/jit_hooks/*.pt test/data/legacy_modules.t7 test/data/*.pt -test/backward_compatibility/nightly_schemas.txt +test/forward_backward_compatibility/nightly_schemas.txt dropout_model.pt test/generated_type_hints_smoketest.py test/htmlcov diff --git a/.jenkins/pytorch/test.sh b/.jenkins/pytorch/test.sh index 216d9a299b8..b2a64e3af8c 100755 --- a/.jenkins/pytorch/test.sh +++ b/.jenkins/pytorch/test.sh @@ -436,9 +436,9 @@ test_xla() { # Do NOT run this test before any other tests, like test_python_shard, etc. # Because this function uninstalls the torch built from branch, and install # nightly version. -test_backward_compatibility() { +test_forward_backward_compatibility() { set -x - pushd test/backward_compatibility + pushd test/forward_backward_compatibility python -m venv venv # shellcheck disable=SC1091 . venv/bin/activate @@ -448,7 +448,7 @@ test_backward_compatibility() { deactivate rm -r venv pip show torch - python check_backward_compatibility.py --existing-schemas nightly_schemas.txt + python check_forward_backward_compatibility.py --existing-schemas nightly_schemas.txt popd set +x assert_git_not_dirty @@ -529,7 +529,7 @@ if ! [[ "${BUILD_ENVIRONMENT}" == *libtorch* || "${BUILD_ENVIRONMENT}" == *-baze fi if [[ "${BUILD_ENVIRONMENT}" == *backward* ]]; then - test_backward_compatibility + test_forward_backward_compatibility # Do NOT add tests after bc check tests, see its comment. 
elif [[ "${TEST_CONFIG}" == *xla* ]]; then install_torchvision diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 6984ee6c0e2..0bf04c35fe2 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1316,7 +1316,7 @@ This choice depends on several factors; here is the decision tree as of - pytorch_linux_xenial_py3_6_gcc5_4_build - pytorch_cpp_doc_build - pytorch_doc_test - - pytorch_linux_backward_compatibility_check_test + - pytorch_linux_forward_backward_compatibility_check_test - pytorch_linux_xenial_py3_6_gcc5_4_jit_legacy_test - pytorch_linux_xenial_py3_6_gcc5_4_test - pytorch_python_doc_build diff --git a/aten/src/ATen/core/function_schema.h b/aten/src/ATen/core/function_schema.h index 353e782a716..078f8aa9a6d 100644 --- a/aten/src/ATen/core/function_schema.h +++ b/aten/src/ATen/core/function_schema.h @@ -141,6 +141,15 @@ struct Argument { const Argument& old, std::ostream* why_not=nullptr) const; + // this function checks whether this Argument is forward compatible with + // the old one. we consider the following cases are forward compatible: + // 1) two arguments are equal + // 2) this arg's type should be subtype of old + // 3) this arg must provide the same default value if old arg has one, + bool isForwardCompatibleWith( + const Argument& old, + std::ostream* why_not = nullptr) const; + private: std::string name_; TypePtr type_; @@ -238,6 +247,28 @@ struct FunctionSchema { const FunctionSchema& old, std::ostream* why_not = nullptr) const; + // Checks whether this schema is forward compatible with the old one. + // The following conditions must be true: + // [Function structure] The new schema's name, overload-name, varargs, and + // return arity are the same. + // [Output Narrowing] The new schema's output type must be the same class + // or inherit from the old schema's output type. + // [Arg Compatibility] Every argument in the old schema has a corresponding + // argument in the new schema that: + // * is at the same position. + // * has the same name. 
+ // * is either positional, or kwarg and the old argument was kwarg. + // * has the same type, or the old argument's type inherits from the + // new argument's type. + // [Default Values] Every new argument must have a default value. + // Each default value type should NOT be a container type. + // [Positioning] All defaults arguments MUST go after either old + // default arguments or the end of positional arguments + // and right BEFORE all out arguments + bool isForwardCompatibleWith( + const FunctionSchema& old, + std::ostringstream& why_not) const; + private: OperatorName name_; std::vector arguments_; diff --git a/aten/src/ATen/core/function_schema_inl.h b/aten/src/ATen/core/function_schema_inl.h index b7aab0730c7..fef580e0780 100644 --- a/aten/src/ATen/core/function_schema_inl.h +++ b/aten/src/ATen/core/function_schema_inl.h @@ -1,4 +1,5 @@ #pragma once +#include // note: windows build doesn't find symbols in operator files unless // this is a header file @@ -86,6 +87,34 @@ inline bool Argument::isBackwardCompatibleWith( return true; } +inline bool Argument::isForwardCompatibleWith( + const Argument& old, + std::ostream* why_not) const { + const Argument* lhs = this; + const Argument* rhs = &old; + if (!(lhs->name() == rhs->name() + && lhs->N() == rhs->N() + && (lhs->alias_info() == rhs->alias_info() + || (lhs->alias_info() != nullptr && rhs->alias_info() != nullptr + && *lhs->alias_info() == *rhs->alias_info())))) { + return false; + } + if (lhs->kwarg_only() && !rhs->kwarg_only()) { + return false; + } + if (!lhs->type()->isSubtypeOfExt(rhs->type(), why_not)) { + return false; + } + if (rhs->default_value().has_value() && + lhs->default_value() != rhs->default_value()) { + return false; + } + if (lhs->default_value().has_value() && !rhs->default_value().has_value()) { + return false; + } + return true; +} + inline std::string FunctionSchema::formatTypeMismatchMsg( const Argument& expected, const std::string& actual_type, @@ -145,7 +174,7 @@ inline bool 
FunctionSchema::isBackwardCompatibleWith( } } - // // Validate that all new arguments provided has a default value + // Validate that all new arguments provided has a default value for (const auto i : c10::irange(old_out_start_idx, new_out_start_idx)) { if (!arguments().at(i).default_value()) { if (why_not) { @@ -171,6 +200,86 @@ inline bool FunctionSchema::isBackwardCompatibleWith( return true; } +inline bool FunctionSchema::isForwardCompatibleWith( + const FunctionSchema& old, + std::ostringstream& why_not) const { + if (!(name() == old.name() && + overload_name() == old.overload_name() + // we are conservative on is_vararg and is_varret, + // since they are only used by internal operators + && is_vararg() == old.is_vararg() && is_varret() == old.is_varret() && + returns().size() == old.returns().size())) { + return false; + } + + // we want to test both out and default args seperately + size_t old_out_start_idx = findFirstOutArg(old.arguments()); + size_t new_out_start_idx = findFirstOutArg(arguments()); + + if (old.arguments().size() - old_out_start_idx != + arguments().size() - new_out_start_idx) { + if (why_not) { + why_not << "Function schema should have the " + << "same number of out arguments"; + } + return false; + } + + // make sure among the default args, they are forward compatible + for (size_t i = 0; i < std::min(old_out_start_idx, new_out_start_idx); i++) { + if (!arguments().at(i).isForwardCompatibleWith(old.arguments().at(i))) { + if (why_not) { + why_not + << "'" << arguments().at(i).name() << "'" + << " is not forward compatible with the older version of the schema"; + } + return false; + } + } + + // Validate that all new arguments provided has a default value + for (size_t i = old_out_start_idx; i < new_out_start_idx; ++i) { + if (!arguments().at(i).default_value()) { + if (why_not) { + why_not + << "Function schema is not forward compatible since the new argument '" + << arguments().at(i).name() << "' of type " + << 
arguments().at(i).type()->str() + << " did not provide a default value."; + } + return false; + } + + auto default_val = arguments().at(i).default_value().value(); + if (default_val.isList() || default_val.isGenericDict()) { + if (why_not) { + why_not + << "Function schema is not forward compatible since the new argument '" + << arguments().at(i).name() << "' of type " + << arguments().at(i).type()->str() << " has a container type " + << "as its default value."; + } + return false; + } + } + + // now compare the out args + for (size_t i = old_out_start_idx; i < old.arguments().size(); i++) { + if (!arguments() + .at(i - old_out_start_idx + new_out_start_idx) + .isForwardCompatibleWith(old.arguments().at(i))) { + if (why_not) { + why_not << "Out argument '" + << "'" << arguments().at(i).name() + << " is not FC with the older version of the schema"; + } + return false; + } + } + + return true; +} + inline void FunctionSchema::checkArg( const IValue& value, const Argument& argument, diff --git a/test/backward_compatibility/check_backward_compatibility.py b/test/forward_backward_compatibility/check_forward_backward_compatibility.py similarity index 82% rename from test/backward_compatibility/check_backward_compatibility.py rename to test/forward_backward_compatibility/check_forward_backward_compatibility.py index 6578ee9d4bf..c0d2a4186e7 100644 --- a/test/backward_compatibility/check_backward_compatibility.py +++ b/test/forward_backward_compatibility/check_forward_backward_compatibility.py @@ -2,6 +2,7 @@ import argparse import datetime import re import sys +import warnings from collections import defaultdict import torch @@ -149,14 +150,16 @@ def dont_parse(schema_line): return True return False - -def check_bc(existing_schemas): +def load_schemas_to_dict(): new_schemas = torch._C._jit_get_all_schemas() new_schemas += torch._C._jit_get_custom_class_schemas() new_schema_dict = defaultdict(list) for s in new_schemas: new_schema_dict[s.name].append(s) + return 
def check_fc(existing_schemas):
    """Check every existing schema for a forward compatible counterpart.

    For each schema in ``existing_schemas`` that is not allow-listed, the
    currently registered schemas with the same name are tried in turn via
    ``check_forward_compatible_with``. A schema with no compatible candidate
    is recorded as broken and the collected failure reasons are printed.

    Forward-compatibility breakage only emits a warning here; nothing is
    raised.

    Args:
        existing_schemas: iterable of previously dumped schemas to validate.

    Returns:
        bool: True when every non-allow-listed schema has a forward
        compatible match, False otherwise (mirrors what ``check_bc``
        returns for backward compatibility).
    """
    new_schema_dict = load_schemas_to_dict()
    is_fc = True
    broken_ops = []
    for existing_schema in existing_schemas:
        if allow_listed(existing_schema):
            print("schema: ", str(existing_schema), " found on allowlist, skipping")
            continue
        print("processing existing schema: ", str(existing_schema))
        # Candidates are looked up by operator name.
        matching_new_schemas = new_schema_dict.get(existing_schema.name, [])
        found = False
        possible_failure_reasons = []
        for matching_new_schema in matching_new_schemas:
            is_compatible, reason = matching_new_schema.check_forward_compatible_with(existing_schema)
            if is_compatible:
                found = True
                break
            # Some incompatibilities come back with an empty reason; skip those.
            if reason != "":
                possible_failure_reasons.append(reason)
        if not found:
            print(
                "Can NOT find forward compatible schemas after changes "
                "for schema {} from the following candidates:\n[\n{}\n]".format(
                    str(existing_schema),
                    "\n\t".join(str(s) for s in matching_new_schemas),
                )
            )
            print(
                "Refer to following reasons for failure "
                "to find FC schema:\n[\n{}\n]".format(
                    "\n\t".join(str(r) for r in possible_failure_reasons)
                )
            )
            broken_ops.append(str(existing_schema))
            is_fc = False
    if is_fc:
        print("Found forward compatible schemas for all existing schemas")
    else:
        warnings.warn(
            "The PR is introducing a potentially forward incompatible changes to the "
            "operator library. Please contact PyTorch team to confirm "
            "whether this change is wanted or not. \n\nBroken ops: "
            "[\n\t{}\n]".format("\n\t".join(broken_ops))
        )
    # Report the outcome to the caller, as check_bc does with is_bc.
    return is_fc
order=\"NCHW\") -> (Tensor)") schema_b = parse_schema(str(schema_a)) - self.assertEquals(schema_a, schema_b) + self.assertEqual(schema_a, schema_b) + + def test_forward_compatible_arguments_without_out(self): + old_schema = parse_schema('any(Tensor self, int a, int b=1) -> Tensor') + # deleting default arg is FC compatible + new_schema = parse_schema('any(Tensor self, int a) -> Tensor') + is_fc, _ = new_schema.check_forward_compatible_with(old_schema) + self.assertTrue(is_fc) + # adding default arg is FC compatible + new_schema = parse_schema('any(Tensor self, int a, int b=1, int c=1) -> Tensor') + is_fc, _ = new_schema.check_forward_compatible_with(old_schema) + self.assertTrue(is_fc) + # adding default arg with container type is NOT FC compatible + new_schema = parse_schema('any(Tensor self, int a, int b=1, int[2] c=1) -> Tensor') + is_fc, reason = new_schema.check_forward_compatible_with(old_schema) + self.assertFalse(is_fc) + self.assertEqual(reason, "Function schema is not forward compatible since the new argument" + " \'c\' of type int[] has a container type as its default value.") + # updating the default value of a default arg is NOT FC compatible + new_schema = parse_schema('any(Tensor self, int a, int b=4) -> Tensor') + is_fc, reason = new_schema.check_forward_compatible_with(old_schema) + self.assertFalse(is_fc) + self.assertEqual(reason, "\'b\' is not forward compatible with the older version of the schema") + # updating the arg name of a default arg is NOT FC compatible + new_schema = parse_schema('any(Tensor self, int a, int c=1) -> Tensor') + is_fc, reason = new_schema.check_forward_compatible_with(old_schema) + self.assertFalse(is_fc) + self.assertEqual(reason, "\'c\' is not forward compatible with the older version of the schema") + # not adding default arg in the end is NOT FC compatible + new_schema = parse_schema('any(Tensor self, int a, int c=1, int b=1) -> Tensor') + is_fc, reason = new_schema.check_forward_compatible_with(old_schema) + 
self.assertFalse(is_fc) + self.assertEqual(reason, "\'c\' is not forward compatible with the older version of the schema") + # making default arg into positional arg is NOT FC compatible + new_schema = parse_schema('any(Tensor self, int a, int b) -> Tensor') + is_fc, reason = new_schema.check_forward_compatible_with(old_schema) + self.assertFalse(is_fc) + self.assertEqual(reason, "\'b\' is not forward compatible with the older version of the schema") + # making positional arg into default arg is NOT FC compatible + new_schema = parse_schema('any(Tensor self, int a=1, int b=1) -> Tensor') + is_fc, reason = new_schema.check_forward_compatible_with(old_schema) + self.assertFalse(is_fc) + self.assertEqual(reason, "\'a\' is not forward compatible with the older version of the schema") + + def test_forward_compatible_arguments_real_use_case(self): + # this change introduced forward incompatibility in the past + old_slice_schema = parse_schema('slice(Tensor(a) self, int dim=0, int start=0, int end=0, int step=1) -> Tensor(a)') + new_slice_schema = parse_schema('slice(Tensor(a) self, int dim=0, int? start=None, int? end=None, int step=1) -> Tensor(a)') + is_fc, reason = new_slice_schema.check_forward_compatible_with(old_slice_schema) + self.assertFalse(is_fc) + self.assertEqual(reason, "\'start\' is not forward compatible with the older version of the schema") + + def test_forward_compatible_arguments_with_out(self): + old_schema = parse_schema('any(Tensor self, *, int a, int b=1, Tensor(a!) out) -> Tensor(a!)') + new_schema = parse_schema('any(Tensor self, *, int a, Tensor(a!) out) -> Tensor(a!)') + is_fc, _ = new_schema.check_forward_compatible_with(old_schema) + self.assertTrue(is_fc) + new_schema = parse_schema('any(Tensor self, *, int a, int b=1, int c=1, Tensor(a!) out) -> Tensor(a!)') + is_fc, _ = new_schema.check_forward_compatible_with(old_schema) + self.assertTrue(is_fc) + new_schema = parse_schema('any(Tensor self, *, int a, Tensor(d!) d, int b=1, Tensor(a!) 
out) -> Tensor(a!)') + is_fc, reason = new_schema.check_forward_compatible_with(old_schema) + self.assertFalse(is_fc) + self.assertEqual(reason, "Function schema should have the same number of out arguments") def test_schema_error(self): with self.assertRaisesRegex(RuntimeError, r"schemas with vararg \(...\) can't have default value args"): diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp index 9a1601b224e..238e9fff114 100644 --- a/torch/csrc/jit/python/init.cpp +++ b/torch/csrc/jit/python/init.cpp @@ -1390,6 +1390,13 @@ void initJITBindings(PyObject* module) { [](const FunctionSchema& self, const FunctionSchema& old_schema) { return self.isBackwardCompatibleWith(old_schema); }) + .def( + "check_forward_compatible_with", + [](const FunctionSchema& self, const FunctionSchema& old_schema) { + std::ostringstream out; + auto result = self.isForwardCompatibleWith(old_schema, out); + return std::make_pair(result, out.str()); + }) .def( "__eq__", [](const FunctionSchema& self, const FunctionSchema& other) {