From a7c2d1cb092db83e427a0e9a27282ae34d8a7029 Mon Sep 17 00:00:00 2001 From: Vincent Wang Date: Tue, 14 Dec 2021 13:34:14 +0800 Subject: [PATCH] bf16 for dlpack (#10016) --- onnxruntime/core/dlpack/dlpack_converter.cc | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/onnxruntime/core/dlpack/dlpack_converter.cc b/onnxruntime/core/dlpack/dlpack_converter.cc index fe9f2385ef..49f62b8e2d 100644 --- a/onnxruntime/core/dlpack/dlpack_converter.cc +++ b/onnxruntime/core/dlpack/dlpack_converter.cc @@ -67,6 +67,10 @@ DLDataType GetDlpackDataType(const OrtValue& ort_value) { dtype.code = DLDataTypeCode::kDLUInt; dtype.bits = sizeof(uint64_t); break; + case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16: + dtype.code = DLDataTypeCode::kDLBfloat; + dtype.bits = sizeof(BFloat16); + break; default: ORT_THROW("Unexpected data type of ", tensor.GetElementType()); } @@ -158,6 +162,13 @@ MLDataType GetOrtValueDataType(const DLDataType& dtype, bool is_bool_tensor) { default: ORT_THROW("Unsupported kFloat bits " + std::to_string(dtype.bits)); } + case DLDataTypeCode::kDLBfloat: + switch (dtype.bits) { + case 16: + return DataTypeImpl::GetType(); + default: + ORT_THROW("Unsupported kBFloat bits " + std::to_string(dtype.bits)); + } default: ORT_THROW("Unsupported code " + std::to_string(dtype.code)); }