mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
Add bf16 specialization for IsDataType (#9254)
* Add bf16 specialization * Fixed indent
This commit is contained in:
parent
8f6fd014e4
commit
113edbda64
3 changed files with 4 additions and 0 deletions
|
|
@ -727,6 +727,7 @@ struct ProviderHost {
|
|||
virtual bool Tensor__IsDataType_float(const Tensor* p) noexcept = 0;
|
||||
virtual bool Tensor__IsDataType_double(const Tensor* p) noexcept = 0;
|
||||
virtual bool Tensor__IsDataType_MLFloat16(const Tensor* p) noexcept = 0;
|
||||
virtual bool Tensor__IsDataType_BFloat16(const Tensor* p) noexcept = 0;
|
||||
virtual bool Tensor__IsDataTypeString(const Tensor* p) noexcept = 0;
|
||||
|
||||
virtual const TensorShape& Tensor__Shape(const Tensor* p) = 0;
|
||||
|
|
|
|||
|
|
@ -881,6 +881,8 @@ template <>
|
|||
inline bool Tensor::IsDataType<double>() const { return g_host->Tensor__IsDataType_double(this); }
|
||||
template <>
|
||||
inline bool Tensor::IsDataType<MLFloat16>() const { return g_host->Tensor__IsDataType_MLFloat16(this); }
|
||||
template <>
|
||||
inline bool Tensor::IsDataType<BFloat16>() const { return g_host->Tensor__IsDataType_BFloat16(this); }
|
||||
|
||||
template <>
|
||||
inline bool* Tensor::MutableData<bool>() { return g_host->Tensor__MutableData_bool(this); }
|
||||
|
|
|
|||
|
|
@ -819,6 +819,7 @@ struct ProviderHostImpl : ProviderHost {
|
|||
bool Tensor__IsDataType_float(const Tensor* p) noexcept override { return p->IsDataType<float>(); }
|
||||
bool Tensor__IsDataType_double(const Tensor* p) noexcept override { return p->IsDataType<double>(); }
|
||||
bool Tensor__IsDataType_MLFloat16(const Tensor* p) noexcept override { return p->IsDataType<MLFloat16>(); }
|
||||
bool Tensor__IsDataType_BFloat16(const Tensor* p) noexcept override { return p->IsDataType<BFloat16>(); }
|
||||
bool Tensor__IsDataTypeString(const Tensor* p) noexcept override { return p->IsDataTypeString(); }
|
||||
|
||||
const TensorShape& Tensor__Shape(const Tensor* p) override { return p->Shape(); }
|
||||
|
|
|
|||
Loading…
Reference in a new issue