mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-30 03:37:44 +00:00
Fix ORT Eager Mode to work with Pytorch 1.12 (#12323)
This commit is contained in:
parent
e2eeffeafb
commit
9cf6912bba
9 changed files with 30 additions and 27 deletions
|
|
@ -575,6 +575,10 @@ if (MSVC)
|
|||
if (onnxruntime_DEV_MODE)
|
||||
string(APPEND CMAKE_CXX_FLAGS " /wd26812")
|
||||
string(APPEND CMAKE_C_FLAGS " /wd26812")
|
||||
# warning C4805: '|': unsafe mix of type 'uintptr_t' and type 'bool' in operation (from c10/core/TensorImpl.h)
|
||||
if (onnxruntime_ENABLE_EAGER_MODE)
|
||||
string(APPEND CMAKE_CXX_FLAGS " /wd4805")
|
||||
endif()
|
||||
endif()
|
||||
endif()
|
||||
string(APPEND CMAKE_CXX_FLAGS " /experimental:external /external:W0 /external:templates- /external:I ${CMAKE_CURRENT_SOURCE_DIR} /external:I ${CMAKE_CURRENT_BINARY_DIR}")
|
||||
|
|
|
|||
|
|
@ -355,6 +355,10 @@ class StreamType(ConcreteType):
|
|||
pass
|
||||
|
||||
|
||||
class SymIntType(ConcreteType):
|
||||
pass
|
||||
|
||||
|
||||
# region Decls
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -153,7 +153,6 @@ hand_implemented = {
|
|||
"aten::gelu": Gelu("self"),
|
||||
"aten::max": ReduceMax("self", keepdims=0),
|
||||
"aten::min": ReduceMin("self", keepdims=0),
|
||||
"aten::_cat": Concat("tensors", "dim"),
|
||||
"aten::fill_.Scalar": SignatureOnly(),
|
||||
"aten::ne.Scalar_out": Cast(Not(Equal("self", "other")), to="GetONNXTensorProtoDataType(out.scalar_type())"),
|
||||
"aten::ne.Tensor_out": Cast(Not(Equal("self", "other")), to="GetONNXTensorProtoDataType(out.scalar_type())"),
|
||||
|
|
@ -187,6 +186,7 @@ aten_output_type["aten::nonzero"] = "at::ScalarType::Long"
|
|||
# This is done to make sure it is backward and future compatible
|
||||
if version.parse(torch.__version__) < version.parse(TORCH_API_CHANGE_VERSION):
|
||||
hand_implemented["aten::gelu_backward"] = GeluGrad("grad", "self")
|
||||
hand_implemented["aten::_cat"] = Concat("tensors", "dim")
|
||||
else:
|
||||
hand_implemented["aten::gelu_backward"] = GeluGrad("grad_output", "self")
|
||||
|
||||
|
|
|
|||
|
|
@ -291,6 +291,7 @@ class TorchParser(ParserBase):
|
|||
"Storage": StorageType,
|
||||
"ConstQuantizerPtr": ConstQuantizerPtrType,
|
||||
"Stream": StreamType,
|
||||
"SymInt": SymIntType,
|
||||
}
|
||||
identifier = self._expect_token(TokenKind.IDENTIFIER)
|
||||
base_type_parser = base_type_parsers.get(identifier.value)
|
||||
|
|
|
|||
|
|
@ -53,31 +53,26 @@ void ORTTensorImpl::shallow_copy_from(
|
|||
allow_tensor_metadata_change());
|
||||
}
|
||||
|
||||
at::IntArrayRef ORTTensorImpl::sizes() const {
|
||||
at::IntArrayRef ORTTensorImpl::sizes_custom() const {
|
||||
const_cast<ORTTensorImpl*>(this)->cacheSizeMetadata();
|
||||
return c10::TensorImpl::sizes();
|
||||
return c10::TensorImpl::sizes_default();
|
||||
}
|
||||
|
||||
int64_t ORTTensorImpl::dim() const {
|
||||
int64_t ORTTensorImpl::dim_custom() const {
|
||||
const_cast<ORTTensorImpl*>(this)->cacheSizeMetadata();
|
||||
return c10::TensorImpl::dim();
|
||||
return c10::TensorImpl::dim_default();
|
||||
}
|
||||
|
||||
int64_t ORTTensorImpl::numel() const {
|
||||
int64_t ORTTensorImpl::numel_custom() const {
|
||||
const_cast<ORTTensorImpl*>(this)->cacheSizeMetadata();
|
||||
return c10::TensorImpl::numel();
|
||||
return c10::TensorImpl::numel_default();
|
||||
}
|
||||
|
||||
bool ORTTensorImpl::is_contiguous(at::MemoryFormat memory_format) const {
|
||||
bool ORTTensorImpl::is_contiguous_custom(at::MemoryFormat memory_format) const {
|
||||
auto& tensor = tensor_.Get<onnxruntime::Tensor>();
|
||||
return tensor.IsContiguous();
|
||||
}
|
||||
|
||||
int64_t ORTTensorImpl::size(int64_t d) const {
|
||||
const_cast<ORTTensorImpl*>(this)->cacheSizeMetadata();
|
||||
return c10::TensorImpl::size(d);
|
||||
}
|
||||
|
||||
void ORTTensorImpl::cacheSizeMetadata() {
|
||||
// TODO: wrap with change generation guard
|
||||
auto& tensor = tensor_.Get<onnxruntime::Tensor>();
|
||||
|
|
@ -102,10 +97,10 @@ bool ORTTensorImpl::has_storage() const {
|
|||
return false;
|
||||
}
|
||||
|
||||
at::IntArrayRef ORTTensorImpl::strides() const {
|
||||
at::IntArrayRef ORTTensorImpl::strides_custom() const {
|
||||
const_cast<ORTTensorImpl*>(this)->cacheSizeMetadata();
|
||||
return sizes_and_strides_.strides_arrayref();
|
||||
return sizes_and_strides_.strides_arrayref();
|
||||
}
|
||||
|
||||
} // namespace eager
|
||||
} // namespace torch_ort
|
||||
} // namespace torch_ort
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ class ORTTensorImpl final : public c10::TensorImpl {
|
|||
c10::DispatchKeySet{c10::DispatchKey::ORT},
|
||||
options.dtype(),
|
||||
options.device()) {
|
||||
set_sizes_strides_policy(SizesStridesPolicy::CustomSizes);
|
||||
set_tensor(tensor);
|
||||
}
|
||||
|
||||
|
|
@ -38,21 +39,19 @@ class ORTTensorImpl final : public c10::TensorImpl {
|
|||
|
||||
void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override;
|
||||
|
||||
at::IntArrayRef sizes() const override;
|
||||
at::IntArrayRef sizes_custom() const override;
|
||||
|
||||
int64_t dim() const override;
|
||||
int64_t dim_custom() const override;
|
||||
|
||||
int64_t numel() const override;
|
||||
int64_t numel_custom() const override;
|
||||
|
||||
bool is_contiguous(at::MemoryFormat memory_format) const override;
|
||||
|
||||
int64_t size(int64_t d) const override;
|
||||
bool is_contiguous_custom(at::MemoryFormat memory_format) const override;
|
||||
|
||||
const at::Storage& storage() const override;
|
||||
|
||||
bool has_storage() const override;
|
||||
|
||||
at::IntArrayRef strides() const override;
|
||||
at::IntArrayRef strides_custom() const override;
|
||||
|
||||
private:
|
||||
void cacheSizeMetadata();
|
||||
|
|
@ -60,4 +59,4 @@ class ORTTensorImpl final : public c10::TensorImpl {
|
|||
};
|
||||
|
||||
} // namespace eager
|
||||
} // namespace torch_ort
|
||||
} // namespace torch_ort
|
||||
|
|
|
|||
|
|
@ -1,2 +1,2 @@
|
|||
torch==1.11.0
|
||||
torch==1.12
|
||||
setuptools>=41.4.0
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
--pre
|
||||
-f https://download.pytorch.org/whl/cpu/torch_stable.html
|
||||
torch==1.11.0
|
||||
torch==1.12.0
|
||||
setuptools>=41.4.0
|
||||
cerberus
|
||||
h5py
|
||||
|
|
|
|||
|
|
@ -2,5 +2,5 @@ setuptools
|
|||
wheel
|
||||
numpy
|
||||
typing_extensions
|
||||
torch==1.11.0
|
||||
torch==1.12
|
||||
parameterized
|
||||
|
|
|
|||
Loading…
Reference in a new issue