mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[JIT] Add mutability checks in FunctionSchema and create SchemaInfo subclass (#80734)
- Added overloads to is_mutable method in FunctionSchema to tell whether an argument at index is mutable or an argument with name is mutable. - Created SchemaInfo subclass of FunctionSchema with constructors from FunctionSchema and from const char* signature. - Tested is_mutable method overloads in new test_schema_info.cpp file. **Note that this pr is used to set up SchemaInfo. Implementation for SchemaInfo will be addressed in later commits** Differential Revision: [D37651384](https://our.internmc.facebook.com/intern/diff/D37651384) Pull Request resolved: https://github.com/pytorch/pytorch/pull/80734 Approved by: https://github.com/davidberard98
This commit is contained in:
parent
528ee0fa75
commit
b4e342928b
6 changed files with 69 additions and 0 deletions
|
|
@ -359,7 +359,20 @@ struct FunctionSchema {
|
|||
return aliasInfo && aliasInfo->isWrite();
|
||||
});
|
||||
}
|
||||
bool is_mutable(size_t index) const {
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
index < arguments().size(),
|
||||
"Invalid index for schema.");
|
||||
const AliasInfo* aliasInfo = arguments()[index].alias_info();
|
||||
return aliasInfo && aliasInfo->isWrite();
|
||||
}
|
||||
bool is_mutable(c10::string_view name) const {
|
||||
c10::optional<int> index = argumentIndexWithName(name);
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
index != c10::nullopt, "Schema has no argument named ", name);
|
||||
|
||||
return is_mutable(*index);
|
||||
}
|
||||
c10::optional<int> argumentIndexWithName(c10::string_view name) const {
|
||||
for (const auto i : c10::irange(arguments().size())) {
|
||||
if(name == arguments()[i].name())
|
||||
|
|
|
|||
|
|
@ -365,6 +365,7 @@ core_sources_full_mobile_no_backend_interface = [
|
|||
"torch/csrc/jit/testing/file_check.cpp",
|
||||
"torch/csrc/jit/testing/hooks_for_testing.cpp",
|
||||
"torch/csrc/utils/cpp_stacktraces.cpp",
|
||||
"torch/csrc/utils/schema_info.cpp",
|
||||
"torch/csrc/utils/tensor_flatten.cpp",
|
||||
"torch/csrc/utils/variadic.cpp",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -80,6 +80,7 @@ set(JIT_TEST_SRCS
|
|||
${JIT_TEST_ROOT}/test_peephole_optimize.cpp
|
||||
${JIT_TEST_ROOT}/test_qualified_name.cpp
|
||||
${JIT_TEST_ROOT}/test_save_load.cpp
|
||||
${JIT_TEST_ROOT}/test_schema_info.cpp
|
||||
${JIT_TEST_ROOT}/test_schema_matching.cpp
|
||||
${JIT_TEST_ROOT}/test_stack_opt.cpp
|
||||
${JIT_TEST_ROOT}/test_subgraph_matcher.cpp
|
||||
|
|
|
|||
24
test/cpp/jit/test_schema_info.cpp
Normal file
24
test/cpp/jit/test_schema_info.cpp
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
#include <gtest/gtest.h>
|
||||
#include <torch/csrc/utils/schema_info.h>
|
||||
|
||||
namespace torch {
|
||||
namespace utils {
|
||||
TEST(FunctionSchemaIsMutableTest, Basic) {
|
||||
c10::FunctionSchema schema = torch::jit::parseSchema(
|
||||
"aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))");
|
||||
ASSERT_TRUE(schema.is_mutable(0));
|
||||
ASSERT_TRUE(schema.is_mutable("self"));
|
||||
ASSERT_FALSE(schema.is_mutable(1));
|
||||
ASSERT_FALSE(schema.is_mutable("other"));
|
||||
ASSERT_FALSE(schema.is_mutable(2));
|
||||
ASSERT_FALSE(schema.is_mutable("alpha"));
|
||||
}
|
||||
|
||||
TEST(FunctionSchemaIsMutableTest, InvalidArgument) {
|
||||
c10::FunctionSchema schema = torch::jit::parseSchema(
|
||||
"aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))");
|
||||
ASSERT_THROW(schema.is_mutable(4), c10::Error);
|
||||
ASSERT_THROW(schema.is_mutable("named_argument"), c10::Error);
|
||||
}
|
||||
} // namespace utils
|
||||
} // namespace torch
|
||||
5
torch/csrc/utils/schema_info.cpp
Normal file
5
torch/csrc/utils/schema_info.cpp
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
#include <torch/csrc/utils/schema_info.h>
|
||||
|
||||
namespace torch {
|
||||
namespace utils {} // namespace utils
|
||||
} // namespace torch
|
||||
25
torch/csrc/utils/schema_info.h
Normal file
25
torch/csrc/utils/schema_info.h
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
#pragma once
|
||||
|
||||
#include <torch/csrc/jit/frontend/function_schema_parser.h>
|
||||
|
||||
namespace torch {
|
||||
namespace utils {
|
||||
|
||||
/**
|
||||
* class SchemaInfo
|
||||
*
|
||||
* Subclass of FunctionSchema that publicizes argument value specific operator
|
||||
* behavior (mutation, aliasing, special cases, etc...)
|
||||
*/
|
||||
|
||||
struct TORCH_API SchemaInfo {
|
||||
public:
|
||||
explicit SchemaInfo(c10::FunctionSchema schema) : schema_(schema) {}
|
||||
explicit SchemaInfo(const char* signature)
|
||||
: schema_(torch::jit::parseSchema(signature)) {}
|
||||
|
||||
private:
|
||||
c10::FunctionSchema schema_;
|
||||
};
|
||||
} // namespace utils
|
||||
} // namespace torch
|
||||
Loading…
Reference in a new issue