From 113edbda64ca32e5d445b90cad9ecbfa2ea2bba7 Mon Sep 17 00:00:00 2001 From: ashari4 <70242157+ashari4@users.noreply.github.com> Date: Sat, 2 Oct 2021 07:15:06 -0700 Subject: [PATCH] Add bf16 specialization for IsDataType (#9254) * Add bf16 specialization * Fixed indent --- onnxruntime/core/providers/shared_library/provider_interfaces.h | 1 + .../core/providers/shared_library/provider_wrappedtypes.h | 2 ++ onnxruntime/core/session/provider_bridge_ort.cc | 1 + 3 files changed, 4 insertions(+) diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h index 175f0e909e..7a0704a6b7 100644 --- a/onnxruntime/core/providers/shared_library/provider_interfaces.h +++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h @@ -727,6 +727,7 @@ struct ProviderHost { virtual bool Tensor__IsDataType_float(const Tensor* p) noexcept = 0; virtual bool Tensor__IsDataType_double(const Tensor* p) noexcept = 0; virtual bool Tensor__IsDataType_MLFloat16(const Tensor* p) noexcept = 0; + virtual bool Tensor__IsDataType_BFloat16(const Tensor* p) noexcept = 0; virtual bool Tensor__IsDataTypeString(const Tensor* p) noexcept = 0; virtual const TensorShape& Tensor__Shape(const Tensor* p) = 0; diff --git a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h index c8f991033c..3801a98e14 100644 --- a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h +++ b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h @@ -881,6 +881,8 @@ template <> inline bool Tensor::IsDataType() const { return g_host->Tensor__IsDataType_double(this); } template <> inline bool Tensor::IsDataType() const { return g_host->Tensor__IsDataType_MLFloat16(this); } +template <> +inline bool Tensor::IsDataType() const { return g_host->Tensor__IsDataType_BFloat16(this); } template <> inline bool* Tensor::MutableData() { return g_host->Tensor__MutableData_bool(this); } diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 0bd5d8e13a..8d7bb1406b 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -819,6 +819,7 @@ struct ProviderHostImpl : ProviderHost { bool Tensor__IsDataType_float(const Tensor* p) noexcept override { return p->IsDataType(); } bool Tensor__IsDataType_double(const Tensor* p) noexcept override { return p->IsDataType(); } bool Tensor__IsDataType_MLFloat16(const Tensor* p) noexcept override { return p->IsDataType(); } + bool Tensor__IsDataType_BFloat16(const Tensor* p) noexcept override { return p->IsDataType(); } bool Tensor__IsDataTypeString(const Tensor* p) noexcept override { return p->IsDataTypeString(); } const TensorShape& Tensor__Shape(const Tensor* p) override { return p->Shape(); }