Add bf16 specialization for IsDataType (#9254)

* Add bf16 specialization

* Fixed indent
This commit is contained in:
ashari4 2021-10-02 07:15:06 -07:00 committed by GitHub
parent 8f6fd014e4
commit 113edbda64
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 4 additions and 0 deletions

View file

@ -727,6 +727,7 @@ struct ProviderHost {
virtual bool Tensor__IsDataType_float(const Tensor* p) noexcept = 0;
virtual bool Tensor__IsDataType_double(const Tensor* p) noexcept = 0;
virtual bool Tensor__IsDataType_MLFloat16(const Tensor* p) noexcept = 0;
virtual bool Tensor__IsDataType_BFloat16(const Tensor* p) noexcept = 0;
virtual bool Tensor__IsDataTypeString(const Tensor* p) noexcept = 0;
virtual const TensorShape& Tensor__Shape(const Tensor* p) = 0;

View file

@ -881,6 +881,8 @@ template <>
inline bool Tensor::IsDataType<double>() const { return g_host->Tensor__IsDataType_double(this); }
template <>
inline bool Tensor::IsDataType<MLFloat16>() const { return g_host->Tensor__IsDataType_MLFloat16(this); }
template <>
inline bool Tensor::IsDataType<BFloat16>() const { return g_host->Tensor__IsDataType_BFloat16(this); }
template <>
inline bool* Tensor::MutableData<bool>() { return g_host->Tensor__MutableData_bool(this); }

View file

@ -819,6 +819,7 @@ struct ProviderHostImpl : ProviderHost {
bool Tensor__IsDataType_float(const Tensor* p) noexcept override { return p->IsDataType<float>(); }
bool Tensor__IsDataType_double(const Tensor* p) noexcept override { return p->IsDataType<double>(); }
bool Tensor__IsDataType_MLFloat16(const Tensor* p) noexcept override { return p->IsDataType<MLFloat16>(); }
bool Tensor__IsDataType_BFloat16(const Tensor* p) noexcept override { return p->IsDataType<BFloat16>(); }
bool Tensor__IsDataTypeString(const Tensor* p) noexcept override { return p->IsDataTypeString(); }
const TensorShape& Tensor__Shape(const Tensor* p) override { return p->Shape(); }