[JIT] Add mutability checks in FunctionSchema and create SchemaInfo subclass (#80734)

- Added overloads to is_mutable method in FunctionSchema to tell whether an argument at index is mutable or an argument with name is mutable.
- Created SchemaInfo subclass of FunctionSchema with constructors from FunctionSchema and from const char* signature.
- Tested is_mutable method overloads in new test_schema_info.cpp file.

**Note that this pr is used to set up SchemaInfo. Implementation for SchemaInfo will be addressed in later commits**

Differential Revision: [D37651384](https://our.internmc.facebook.com/intern/diff/D37651384)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80734
Approved by: https://github.com/davidberard98
This commit is contained in:
goldenxuett 2022-07-08 11:31:31 -07:00 committed by PyTorch MergeBot
parent 528ee0fa75
commit b4e342928b
6 changed files with 69 additions and 0 deletions

View file

@ -359,7 +359,20 @@ struct FunctionSchema {
return aliasInfo && aliasInfo->isWrite();
});
}
bool is_mutable(size_t index) const {
TORCH_INTERNAL_ASSERT(
index < arguments().size(),
"Invalid index for schema.");
const AliasInfo* aliasInfo = arguments()[index].alias_info();
return aliasInfo && aliasInfo->isWrite();
}
bool is_mutable(c10::string_view name) const {
c10::optional<int> index = argumentIndexWithName(name);
TORCH_INTERNAL_ASSERT(
index != c10::nullopt, "Schema has no argument named ", name);
return is_mutable(*index);
}
c10::optional<int> argumentIndexWithName(c10::string_view name) const {
for (const auto i : c10::irange(arguments().size())) {
if(name == arguments()[i].name())

View file

@ -365,6 +365,7 @@ core_sources_full_mobile_no_backend_interface = [
"torch/csrc/jit/testing/file_check.cpp",
"torch/csrc/jit/testing/hooks_for_testing.cpp",
"torch/csrc/utils/cpp_stacktraces.cpp",
"torch/csrc/utils/schema_info.cpp",
"torch/csrc/utils/tensor_flatten.cpp",
"torch/csrc/utils/variadic.cpp",
]

View file

@ -80,6 +80,7 @@ set(JIT_TEST_SRCS
${JIT_TEST_ROOT}/test_peephole_optimize.cpp
${JIT_TEST_ROOT}/test_qualified_name.cpp
${JIT_TEST_ROOT}/test_save_load.cpp
${JIT_TEST_ROOT}/test_schema_info.cpp
${JIT_TEST_ROOT}/test_schema_matching.cpp
${JIT_TEST_ROOT}/test_stack_opt.cpp
${JIT_TEST_ROOT}/test_subgraph_matcher.cpp

View file

@ -0,0 +1,24 @@
#include <gtest/gtest.h>
#include <torch/csrc/utils/schema_info.h>
namespace torch {
namespace utils {
TEST(FunctionSchemaIsMutableTest, Basic) {
c10::FunctionSchema schema = torch::jit::parseSchema(
"aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))");
ASSERT_TRUE(schema.is_mutable(0));
ASSERT_TRUE(schema.is_mutable("self"));
ASSERT_FALSE(schema.is_mutable(1));
ASSERT_FALSE(schema.is_mutable("other"));
ASSERT_FALSE(schema.is_mutable(2));
ASSERT_FALSE(schema.is_mutable("alpha"));
}
TEST(FunctionSchemaIsMutableTest, InvalidArgument) {
c10::FunctionSchema schema = torch::jit::parseSchema(
"aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))");
ASSERT_THROW(schema.is_mutable(4), c10::Error);
ASSERT_THROW(schema.is_mutable("named_argument"), c10::Error);
}
} // namespace utils
} // namespace torch

View file

@ -0,0 +1,5 @@
#include <torch/csrc/utils/schema_info.h>
namespace torch {
namespace utils {} // namespace utils
} // namespace torch

View file

@ -0,0 +1,25 @@
#pragma once
#include <torch/csrc/jit/frontend/function_schema_parser.h>
namespace torch {
namespace utils {
/**
* class SchemaInfo
*
* Subclass of FunctionSchema that publicizes argument value specific operator
* behavior (mutation, aliasing, special cases, etc...)
*/
struct TORCH_API SchemaInfo {
public:
explicit SchemaInfo(c10::FunctionSchema schema) : schema_(schema) {}
explicit SchemaInfo(const char* signature)
: schema_(torch::jit::parseSchema(signature)) {}
private:
c10::FunctionSchema schema_;
};
} // namespace utils
} // namespace torch