diff --git a/onnxruntime/core/dlpack/dlpack_converter.cc b/onnxruntime/core/dlpack/dlpack_converter.cc index fe9f2385ef..49f62b8e2d 100644 --- a/onnxruntime/core/dlpack/dlpack_converter.cc +++ b/onnxruntime/core/dlpack/dlpack_converter.cc @@ -67,6 +67,10 @@ DLDataType GetDlpackDataType(const OrtValue& ort_value) { dtype.code = DLDataTypeCode::kDLUInt; dtype.bits = sizeof(uint64_t); break; + case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16: + dtype.code = DLDataTypeCode::kDLBfloat; + dtype.bits = sizeof(BFloat16); + break; default: ORT_THROW("Unexpected data type of ", tensor.GetElementType()); } @@ -158,6 +162,13 @@ MLDataType GetOrtValueDataType(const DLDataType& dtype, bool is_bool_tensor) { default: ORT_THROW("Unsupported kFloat bits " + std::to_string(dtype.bits)); } + case DLDataTypeCode::kDLBfloat: + switch (dtype.bits) { + case 16: + return DataTypeImpl::GetType(); + default: + ORT_THROW("Unsupported kBFloat bits " + std::to_string(dtype.bits)); + } default: ORT_THROW("Unsupported code " + std::to_string(dtype.code)); }