diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp index 6f3942b19d..a7f732ab5c 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp @@ -275,7 +275,6 @@ const static SupportedTensorDataTypes supportedTypeListLogicalComparison9[2] = / const static SupportedTensorDataTypes supportedTypeListSigned[1] = { SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Int16 | SupportedTensorDataTypes::Int8 }; const static SupportedTensorDataTypes supportedTypeListRange[1] = {SupportedTensorDataTypes::Int16|SupportedTensorDataTypes::Int32|SupportedTensorDataTypes::Float32}; const static SupportedTensorDataTypes supportedTypeListInteger[3] = {SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, SupportedTensorDataTypes::Int32 }; -const static SupportedTensorDataTypes supportedTypeListPadWithoutFloat16[1] = { SupportedTensorDataTypes::Int8to32 | SupportedTensorDataTypes::Float32 }; const static SupportedTensorDataTypes supportedTypeListQLinearMatMul[3] = { SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, @@ -358,7 +357,7 @@ const static OperatorRegistrationInformation operatorRegistrationInformationTabl {REG_INFO_VER( 10, Slice, typeNameListSlice10, supportedTypeListSlice10, DmGraphSupport::Supported, {1, 2, 3, 4}, std::nullopt, QuerySlice)}, // Adds negative axes. {REG_INFO_VER( 11, Slice, typeNameListSlice10, supportedTypeListSlice10, DmGraphSupport::Supported, {1, 2, 3, 4}, std::nullopt, QuerySlice)}, {REG_INFO_VER( 7, Pad, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported)}, - {REG_INFO_VER( 11, Pad, typeNameListDefault, supportedTypeListPadWithoutFloat16, DmGraphSupport::Supported, {1, 2} /*pads, value*/)}, // https://microsoft.visualstudio.com/OS/_workitems/edit/26007728 + {REG_INFO_VER( 11, Pad, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported, {1, 2} /*pads, value*/)}, // https://microsoft.visualstudio.com/OS/_workitems/edit/26007728 {REG_INFO( 7, SpaceToDepth, typeNameListDefault, supportedTypeListScalars8to32, DmGraphSupport::Supported)}, {REG_INFO( 7, DepthToSpace, typeNameListDefault, supportedTypeListScalars8to32, DmGraphSupport::Supported)}, {REG_INFO( 11, DepthToSpace, typeNameListDefault, supportedTypeListScalars8to32, DmGraphSupport::Supported)}, diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp index 07de6d1988..303a334841 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp @@ -113,6 +113,36 @@ namespace OperatorHelper } } + float CastFloat16ToFloat32(uint16_t input) + { + // Promote float16m10e5s1 to float32m23e8s1. + // Note this works on machines of both ascending and descending byte + // endianness, so long as float32 and uint32 endianness match. + // It does not work for a few abberant architectures which store + // float32 and uint32 with opposite endianness. + + const uint32_t float16unsignedValueMask = 0x7FFF; + const uint32_t float16signMask = 0x8000; + const uint32_t float16exponentMask = 0x7C00; + const uint32_t float32exponentMask = 0x7F800000; + + uint32_t float16unsignedValue = input & float16unsignedValueMask; + uint32_t float16sign = input & float16signMask; + uint32_t float16exponent = input & float16exponentMask; + + // Shift mantissa bits left (23 - 10 = 13). + // Adjust exponent bias (127 - 15 = 112, 112 << 23 == 0x38000000). + // Move sign bit to float32 MSB (32 - 16 = 16). + uint32_t float32unsignedValue = (float16unsignedValue << 13) + 0x38000000; + uint32_t float32sign = float16sign << 16; + uint32_t result = (float16exponent == 0) ? (float32unsignedValue & ~float32exponentMask) : // Denormal + (float16exponent == float16exponentMask) ? (float32unsignedValue | float32exponentMask) : // Infinity + float32unsignedValue; // Any other normal value + result |= float32sign; + + return reinterpret_cast(result); + } + int64_t CastToInt64(MLOperatorTensorDataType tensorDataType, const void* p) { switch (tensorDataType) @@ -150,7 +180,7 @@ namespace OperatorHelper case MLOperatorTensorDataType::Int64: return static_cast(*reinterpret_cast(p)); case MLOperatorTensorDataType::String: ML_INVALID_ARGUMENT("MLOperatorTensorDataType::String type is unsupported for reading as an integer."); case MLOperatorTensorDataType::Bool: return static_cast(*reinterpret_cast(p)); - case MLOperatorTensorDataType::Float16: ML_INVALID_ARGUMENT("MLOperatorTensorDataType::Float16 type is unsupported for reading as an integer."); + case MLOperatorTensorDataType::Float16: return static_cast(CastFloat16ToFloat32(*reinterpret_cast(p))); case MLOperatorTensorDataType::Double: return static_cast(*reinterpret_cast(p)); case MLOperatorTensorDataType::UInt32: return static_cast(*reinterpret_cast(p)); case MLOperatorTensorDataType::UInt64: return static_cast(*reinterpret_cast(p));