bf16 for dlpack (#10016)

This commit is contained in:
Vincent Wang 2021-12-14 13:34:14 +08:00 committed by GitHub
parent cd0af7ad44
commit a7c2d1cb09
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -67,6 +67,10 @@ DLDataType GetDlpackDataType(const OrtValue& ort_value) {
dtype.code = DLDataTypeCode::kDLUInt;
dtype.bits = sizeof(uint64_t);
break;
case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16:
dtype.code = DLDataTypeCode::kDLBfloat;
dtype.bits = sizeof(BFloat16);
break;
default:
ORT_THROW("Unexpected data type of ", tensor.GetElementType());
}
@ -158,6 +162,13 @@ MLDataType GetOrtValueDataType(const DLDataType& dtype, bool is_bool_tensor) {
default:
ORT_THROW("Unsupported kFloat bits " + std::to_string(dtype.bits));
}
case DLDataTypeCode::kDLBfloat:
switch (dtype.bits) {
case 16:
return DataTypeImpl::GetType<BFloat16>();
default:
ORT_THROW("Unsupported kBFloat bits " + std::to_string(dtype.bits));
}
default:
ORT_THROW("Unsupported code " + std::to_string(dtype.code));
}