From b4e342928b85aa53a62fd3194c9aa761e7352c8a Mon Sep 17 00:00:00 2001 From: goldenxuett Date: Fri, 8 Jul 2022 11:31:31 -0700 Subject: [PATCH] [JIT] Add mutability checks in FunctionSchema and create SchemaInfo subclass (#80734) - Added overloads to is_mutable method in FunctionSchema to tell whether an argument at index is mutable or an argument with name is mutable. - Created SchemaInfo subclass of FunctionSchema with constructors from FunctionSchema and from const char* signature. - Tested is_mutable method overloads in new test_schema_info.cpp file. **Note that this pr is used to set up SchemaInfo. Implementation for SchemaInfo will be addressed in later commits** Differential Revision: [D37651384](https://our.internmc.facebook.com/intern/diff/D37651384) Pull Request resolved: https://github.com/pytorch/pytorch/pull/80734 Approved by: https://github.com/davidberard98 --- aten/src/ATen/core/function_schema.h | 13 +++++++++++++ build_variables.bzl | 1 + test/cpp/jit/CMakeLists.txt | 1 + test/cpp/jit/test_schema_info.cpp | 24 ++++++++++++++++++++++++ torch/csrc/utils/schema_info.cpp | 5 +++++ torch/csrc/utils/schema_info.h | 25 +++++++++++++++++++++++++ 6 files changed, 69 insertions(+) create mode 100644 test/cpp/jit/test_schema_info.cpp create mode 100644 torch/csrc/utils/schema_info.cpp create mode 100644 torch/csrc/utils/schema_info.h diff --git a/aten/src/ATen/core/function_schema.h b/aten/src/ATen/core/function_schema.h index 6680d2543e2..0b8200e08c3 100644 --- a/aten/src/ATen/core/function_schema.h +++ b/aten/src/ATen/core/function_schema.h @@ -359,7 +359,20 @@ struct FunctionSchema { return aliasInfo && aliasInfo->isWrite(); }); } + bool is_mutable(size_t index) const { + TORCH_INTERNAL_ASSERT( + index < arguments().size(), + "Invalid index for schema."); + const AliasInfo* aliasInfo = arguments()[index].alias_info(); + return aliasInfo && aliasInfo->isWrite(); + } + bool is_mutable(c10::string_view name) const { + c10::optional index = argumentIndexWithName(name); + TORCH_INTERNAL_ASSERT( + index != c10::nullopt, "Schema has no argument named ", name); + return is_mutable(*index); + } c10::optional argumentIndexWithName(c10::string_view name) const { for (const auto i : c10::irange(arguments().size())) { if(name == arguments()[i].name()) diff --git a/build_variables.bzl b/build_variables.bzl index f38af258be9..68353440865 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -365,6 +365,7 @@ core_sources_full_mobile_no_backend_interface = [ "torch/csrc/jit/testing/file_check.cpp", "torch/csrc/jit/testing/hooks_for_testing.cpp", "torch/csrc/utils/cpp_stacktraces.cpp", + "torch/csrc/utils/schema_info.cpp", "torch/csrc/utils/tensor_flatten.cpp", "torch/csrc/utils/variadic.cpp", ] diff --git a/test/cpp/jit/CMakeLists.txt b/test/cpp/jit/CMakeLists.txt index 9bd349b6195..03f0647b507 100644 --- a/test/cpp/jit/CMakeLists.txt +++ b/test/cpp/jit/CMakeLists.txt @@ -80,6 +80,7 @@ set(JIT_TEST_SRCS ${JIT_TEST_ROOT}/test_peephole_optimize.cpp ${JIT_TEST_ROOT}/test_qualified_name.cpp ${JIT_TEST_ROOT}/test_save_load.cpp + ${JIT_TEST_ROOT}/test_schema_info.cpp ${JIT_TEST_ROOT}/test_schema_matching.cpp ${JIT_TEST_ROOT}/test_stack_opt.cpp ${JIT_TEST_ROOT}/test_subgraph_matcher.cpp diff --git a/test/cpp/jit/test_schema_info.cpp b/test/cpp/jit/test_schema_info.cpp new file mode 100644 index 00000000000..11ac4e4c66e --- /dev/null +++ b/test/cpp/jit/test_schema_info.cpp @@ -0,0 +1,24 @@ +#include +#include + +namespace torch { +namespace utils { +TEST(FunctionSchemaIsMutableTest, Basic) { + c10::FunctionSchema schema = torch::jit::parseSchema( + "aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))"); + ASSERT_TRUE(schema.is_mutable(0)); + ASSERT_TRUE(schema.is_mutable("self")); + ASSERT_FALSE(schema.is_mutable(1)); + ASSERT_FALSE(schema.is_mutable("other")); + ASSERT_FALSE(schema.is_mutable(2)); + ASSERT_FALSE(schema.is_mutable("alpha")); +} + +TEST(FunctionSchemaIsMutableTest, InvalidArgument) { + c10::FunctionSchema schema = torch::jit::parseSchema( + "aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))"); + ASSERT_THROW(schema.is_mutable(4), c10::Error); + ASSERT_THROW(schema.is_mutable("named_argument"), c10::Error); +} +} // namespace utils +} // namespace torch diff --git a/torch/csrc/utils/schema_info.cpp b/torch/csrc/utils/schema_info.cpp new file mode 100644 index 00000000000..302bae6cbd4 --- /dev/null +++ b/torch/csrc/utils/schema_info.cpp @@ -0,0 +1,5 @@ +#include + +namespace torch { +namespace utils {} // namespace utils +} // namespace torch diff --git a/torch/csrc/utils/schema_info.h b/torch/csrc/utils/schema_info.h new file mode 100644 index 00000000000..75efd8062d6 --- /dev/null +++ b/torch/csrc/utils/schema_info.h @@ -0,0 +1,25 @@ +#pragma once + +#include + +namespace torch { +namespace utils { + +/** + * class SchemaInfo + * + * Subclass of FunctionSchema that publicizes argument value specific operator + * behavior (mutation, aliasing, special cases, etc...) + */ + +struct TORCH_API SchemaInfo { + public: + explicit SchemaInfo(c10::FunctionSchema schema) : schema_(schema) {} + explicit SchemaInfo(const char* signature) + : schema_(torch::jit::parseSchema(signature)) {} + + private: + c10::FunctionSchema schema_; +}; +} // namespace utils +} // namespace torch