diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h index 175f0e909e..7a0704a6b7 100644 --- a/onnxruntime/core/providers/shared_library/provider_interfaces.h +++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h @@ -727,6 +727,7 @@ struct ProviderHost { virtual bool Tensor__IsDataType_float(const Tensor* p) noexcept = 0; virtual bool Tensor__IsDataType_double(const Tensor* p) noexcept = 0; virtual bool Tensor__IsDataType_MLFloat16(const Tensor* p) noexcept = 0; + virtual bool Tensor__IsDataType_BFloat16(const Tensor* p) noexcept = 0; virtual bool Tensor__IsDataTypeString(const Tensor* p) noexcept = 0; virtual const TensorShape& Tensor__Shape(const Tensor* p) = 0; diff --git a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h index c8f991033c..3801a98e14 100644 --- a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h +++ b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h @@ -881,6 +881,8 @@ template <> inline bool Tensor::IsDataType() const { return g_host->Tensor__IsDataType_double(this); } template <> inline bool Tensor::IsDataType() const { return g_host->Tensor__IsDataType_MLFloat16(this); } +template <> +inline bool Tensor::IsDataType() const { return g_host->Tensor__IsDataType_BFloat16(this); } template <> inline bool* Tensor::MutableData() { return g_host->Tensor__MutableData_bool(this); } diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 0bd5d8e13a..8d7bb1406b 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -819,6 +819,7 @@ struct ProviderHostImpl : ProviderHost { bool Tensor__IsDataType_float(const Tensor* p) noexcept override { return p->IsDataType(); } bool Tensor__IsDataType_double(const Tensor* p) noexcept override { return p->IsDataType(); } bool Tensor__IsDataType_MLFloat16(const Tensor* p) noexcept override { return p->IsDataType(); } + bool Tensor__IsDataType_BFloat16(const Tensor* p) noexcept override { return p->IsDataType(); } bool Tensor__IsDataTypeString(const Tensor* p) noexcept override { return p->IsDataTypeString(); } const TensorShape& Tensor__Shape(const Tensor* p) override { return p->Shape(); }