Fix ORT Eager Mode to work with Pytorch 1.12 (#12323)

This commit is contained in:
msftlincoln 2022-07-27 16:24:46 -04:00 committed by GitHub
parent e2eeffeafb
commit 9cf6912bba
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 30 additions and 27 deletions

View file

@ -575,6 +575,10 @@ if (MSVC)
if (onnxruntime_DEV_MODE)
string(APPEND CMAKE_CXX_FLAGS " /wd26812")
string(APPEND CMAKE_C_FLAGS " /wd26812")
# warning C4805: '|': unsafe mix of type 'uintptr_t' and type 'bool' in operation (from c10/core/TensorImpl.h)
if (onnxruntime_ENABLE_EAGER_MODE)
string(APPEND CMAKE_CXX_FLAGS " /wd4805")
endif()
endif()
endif()
string(APPEND CMAKE_CXX_FLAGS " /experimental:external /external:W0 /external:templates- /external:I ${CMAKE_CURRENT_SOURCE_DIR} /external:I ${CMAKE_CURRENT_BINARY_DIR}")

View file

@ -355,6 +355,10 @@ class StreamType(ConcreteType):
pass
class SymIntType(ConcreteType):
pass
# region Decls

View file

@ -153,7 +153,6 @@ hand_implemented = {
"aten::gelu": Gelu("self"),
"aten::max": ReduceMax("self", keepdims=0),
"aten::min": ReduceMin("self", keepdims=0),
"aten::_cat": Concat("tensors", "dim"),
"aten::fill_.Scalar": SignatureOnly(),
"aten::ne.Scalar_out": Cast(Not(Equal("self", "other")), to="GetONNXTensorProtoDataType(out.scalar_type())"),
"aten::ne.Tensor_out": Cast(Not(Equal("self", "other")), to="GetONNXTensorProtoDataType(out.scalar_type())"),
@ -187,6 +186,7 @@ aten_output_type["aten::nonzero"] = "at::ScalarType::Long"
# This is done to make sure it is backward and future compatible
if version.parse(torch.__version__) < version.parse(TORCH_API_CHANGE_VERSION):
hand_implemented["aten::gelu_backward"] = GeluGrad("grad", "self")
hand_implemented["aten::_cat"] = Concat("tensors", "dim")
else:
hand_implemented["aten::gelu_backward"] = GeluGrad("grad_output", "self")

View file

@ -291,6 +291,7 @@ class TorchParser(ParserBase):
"Storage": StorageType,
"ConstQuantizerPtr": ConstQuantizerPtrType,
"Stream": StreamType,
"SymInt": SymIntType,
}
identifier = self._expect_token(TokenKind.IDENTIFIER)
base_type_parser = base_type_parsers.get(identifier.value)

View file

@ -53,31 +53,26 @@ void ORTTensorImpl::shallow_copy_from(
allow_tensor_metadata_change());
}
at::IntArrayRef ORTTensorImpl::sizes() const {
at::IntArrayRef ORTTensorImpl::sizes_custom() const {
const_cast<ORTTensorImpl*>(this)->cacheSizeMetadata();
return c10::TensorImpl::sizes();
return c10::TensorImpl::sizes_default();
}
int64_t ORTTensorImpl::dim() const {
int64_t ORTTensorImpl::dim_custom() const {
const_cast<ORTTensorImpl*>(this)->cacheSizeMetadata();
return c10::TensorImpl::dim();
return c10::TensorImpl::dim_default();
}
int64_t ORTTensorImpl::numel() const {
int64_t ORTTensorImpl::numel_custom() const {
const_cast<ORTTensorImpl*>(this)->cacheSizeMetadata();
return c10::TensorImpl::numel();
return c10::TensorImpl::numel_default();
}
bool ORTTensorImpl::is_contiguous(at::MemoryFormat memory_format) const {
bool ORTTensorImpl::is_contiguous_custom(at::MemoryFormat memory_format) const {
auto& tensor = tensor_.Get<onnxruntime::Tensor>();
return tensor.IsContiguous();
}
int64_t ORTTensorImpl::size(int64_t d) const {
const_cast<ORTTensorImpl*>(this)->cacheSizeMetadata();
return c10::TensorImpl::size(d);
}
void ORTTensorImpl::cacheSizeMetadata() {
// TODO: wrap with change generation guard
auto& tensor = tensor_.Get<onnxruntime::Tensor>();
@ -102,10 +97,10 @@ bool ORTTensorImpl::has_storage() const {
return false;
}
at::IntArrayRef ORTTensorImpl::strides() const {
at::IntArrayRef ORTTensorImpl::strides_custom() const {
const_cast<ORTTensorImpl*>(this)->cacheSizeMetadata();
return sizes_and_strides_.strides_arrayref();
return sizes_and_strides_.strides_arrayref();
}
} // namespace eager
} // namespace torch_ort
} // namespace torch_ort

View file

@ -17,6 +17,7 @@ class ORTTensorImpl final : public c10::TensorImpl {
c10::DispatchKeySet{c10::DispatchKey::ORT},
options.dtype(),
options.device()) {
set_sizes_strides_policy(SizesStridesPolicy::CustomSizes);
set_tensor(tensor);
}
@ -38,21 +39,19 @@ class ORTTensorImpl final : public c10::TensorImpl {
void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override;
at::IntArrayRef sizes() const override;
at::IntArrayRef sizes_custom() const override;
int64_t dim() const override;
int64_t dim_custom() const override;
int64_t numel() const override;
int64_t numel_custom() const override;
bool is_contiguous(at::MemoryFormat memory_format) const override;
int64_t size(int64_t d) const override;
bool is_contiguous_custom(at::MemoryFormat memory_format) const override;
const at::Storage& storage() const override;
bool has_storage() const override;
at::IntArrayRef strides() const override;
at::IntArrayRef strides_custom() const override;
private:
void cacheSizeMetadata();
@ -60,4 +59,4 @@ class ORTTensorImpl final : public c10::TensorImpl {
};
} // namespace eager
} // namespace torch_ort
} // namespace torch_ort

View file

@ -1,2 +1,2 @@
torch==1.11.0
torch==1.12
setuptools>=41.4.0

View file

@ -1,6 +1,6 @@
--pre
-f https://download.pytorch.org/whl/cpu/torch_stable.html
torch==1.11.0
torch==1.12.0
setuptools>=41.4.0
cerberus
h5py

View file

@ -2,5 +2,5 @@ setuptools
wheel
numpy
typing_extensions
torch==1.11.0
torch==1.12
parameterized