mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-28 03:20:58 +00:00
bf16 for dlpack (#10016)
This commit is contained in:
parent
cd0af7ad44
commit
a7c2d1cb09
1 changed files with 11 additions and 0 deletions
|
|
@ -67,6 +67,10 @@ DLDataType GetDlpackDataType(const OrtValue& ort_value) {
|
|||
dtype.code = DLDataTypeCode::kDLUInt;
|
||||
dtype.bits = sizeof(uint64_t);
|
||||
break;
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16:
|
||||
dtype.code = DLDataTypeCode::kDLBfloat;
|
||||
dtype.bits = sizeof(BFloat16);
|
||||
break;
|
||||
default:
|
||||
ORT_THROW("Unexpected data type of ", tensor.GetElementType());
|
||||
}
|
||||
|
|
@ -158,6 +162,13 @@ MLDataType GetOrtValueDataType(const DLDataType& dtype, bool is_bool_tensor) {
|
|||
default:
|
||||
ORT_THROW("Unsupported kFloat bits " + std::to_string(dtype.bits));
|
||||
}
|
||||
case DLDataTypeCode::kDLBfloat:
|
||||
switch (dtype.bits) {
|
||||
case 16:
|
||||
return DataTypeImpl::GetType<BFloat16>();
|
||||
default:
|
||||
ORT_THROW("Unsupported kBFloat bits " + std::to_string(dtype.bits));
|
||||
}
|
||||
default:
|
||||
ORT_THROW("Unsupported code " + std::to_string(dtype.code));
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue