From a1bb6705362a5eaf57bc7c3cfd20dd3c054df2a6 Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Tue, 25 Jul 2023 15:31:32 -0400 Subject: [PATCH] [java] Fp16 fix for android/react native (#16832) ### Description This PR splits out the FP16 conversions into a separate package we can override in the android build with a version which works on old versions of Android. I'm not sure the android build system changes are correct as I haven't got an android build environment configured on my workstation. @YUNQIUGUO if the CI build fails we should follow up offline to get my environment configured so I can iterate on it. ### Motivation and Context Fixes the CI failure after #16703. --- java/build-android.gradle | 3 + java/build.gradle | 4 + .../onnxruntime/platform/Fp16Conversions.java | 237 +++++++++++++++ .../ai/onnxruntime/platform/package-info.java | 13 + .../java/ai/onnxruntime/OnnxSparseTensor.java | 5 +- .../main/java/ai/onnxruntime/OnnxTensor.java | 9 +- .../src/main/java/ai/onnxruntime/OrtUtil.java | 268 ---------------- .../onnxruntime/platform/Fp16Conversions.java | 287 ++++++++++++++++++ .../ai/onnxruntime/platform/package-info.java | 14 + .../java/ai/onnxruntime/OnnxTensorTest.java | 21 +- 10 files changed, 577 insertions(+), 284 deletions(-) create mode 100644 java/src/main/android/ai/onnxruntime/platform/Fp16Conversions.java create mode 100644 java/src/main/android/ai/onnxruntime/platform/package-info.java create mode 100644 java/src/main/jvm/ai/onnxruntime/platform/Fp16Conversions.java create mode 100644 java/src/main/jvm/ai/onnxruntime/platform/package-info.java diff --git a/java/build-android.gradle b/java/build-android.gradle index 8ab56e8f20..6f0dfa0a4f 100644 --- a/java/build-android.gradle +++ b/java/build-android.gradle @@ -98,6 +98,9 @@ android { sourceSets { main { jniLibs.srcDirs = [jniLibsDir] + java { + srcDirs = ['src/main/java', 'src/main/android'] + } } } diff --git a/java/build.gradle b/java/build.gradle index 78a643aeed..c0a75f8165 100644 --- a/java/build.gradle +++ b/java/build.gradle @@ -102,6 +102,10 @@ compileTestJava { } } +sourceSets.main.java { + srcDirs = ['src/main/java', 'src/main/jvm'] +} + sourceSets.test { // add test resource files resources.srcDirs += [ diff --git a/java/src/main/android/ai/onnxruntime/platform/Fp16Conversions.java b/java/src/main/android/ai/onnxruntime/platform/Fp16Conversions.java new file mode 100644 index 0000000000..dd7dd07fc1 --- /dev/null +++ b/java/src/main/android/ai/onnxruntime/platform/Fp16Conversions.java @@ -0,0 +1,237 @@ +/* + * Copyright (c) 2023, Oracle and/or its affiliates. All rights reserved. + * Licensed under the MIT License. + */ +package ai.onnxruntime.platform; + +import java.lang.invoke.MethodHandle; +import java.lang.invoke.MethodHandles; +import java.lang.invoke.MethodType; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.FloatBuffer; +import java.nio.ShortBuffer; +import java.util.logging.Level; +import java.util.logging.Logger; + +/** * Conversions between fp16, bfloat16 and fp32. */ +public final class Fp16Conversions { + private static final Logger logger = Logger.getLogger(Fp16Conversions.class.getName()); + + /** + * Rounds a buffer of floats into a buffer containing fp16 values (stored as shorts in Java). + * + *

Respects the position and limit of the input buffer. + * + * @param buf The buffer of floats. + * @return A buffer of fp16 values stored as shorts. + */ + public static ShortBuffer convertFloatBufferToFp16Buffer(FloatBuffer buf) { + int pos = buf.position(); + int remaining = buf.remaining(); + ShortBuffer output = + ByteBuffer.allocateDirect(remaining * 2).order(ByteOrder.nativeOrder()).asShortBuffer(); + for (int i = 0; i < remaining; i++) { + output.put(i, floatToFp16(buf.get(i + pos))); + } + return output; + } + + /** + * Casts a buffer of fp16 values stored as shorts into a buffer of floats. + * + *

Respects the position and limit of the input buffer. + * + * @param buf The buffer of fp16 values stored as shorts. + * @return A buffer of float values. + */ + public static FloatBuffer convertFp16BufferToFloatBuffer(ShortBuffer buf) { + int pos = buf.position(); + int remaining = buf.remaining(); + FloatBuffer output = + ByteBuffer.allocateDirect(remaining * 4).order(ByteOrder.nativeOrder()).asFloatBuffer(); + for (int i = 0; i < remaining; i++) { + output.put(i, fp16ToFloat(buf.get(i + pos))); + } + return output; + } + + /** + * Rounds a buffer of floats into a buffer containing bf16 values (stored as shorts in Java). + * + *

Respects the position and limit of the input buffer. + * + * @param buf The buffer of floats. + * @return A buffer of bf16 values stored as shorts. + */ + public static ShortBuffer convertFloatBufferToBf16Buffer(FloatBuffer buf) { + int pos = buf.position(); + int remaining = buf.remaining(); + ShortBuffer output = + ByteBuffer.allocateDirect(remaining * 2).order(ByteOrder.nativeOrder()).asShortBuffer(); + for (int i = 0; i < remaining; i++) { + output.put(i, floatToBf16(buf.get(i + pos))); + } + return output; + } + + /** + * Casts a buffer of bf16 values stored as shorts into a buffer of floats. + * + *

Respects the position and limit of the input buffer. + * + * @param buf The buffer of bf16 values stored as shorts. + * @return A buffer of float values. + */ + public static FloatBuffer convertBf16BufferToFloatBuffer(ShortBuffer buf) { + int pos = buf.position(); + int remaining = buf.remaining(); + FloatBuffer output = + ByteBuffer.allocateDirect(remaining * 4).order(ByteOrder.nativeOrder()).asFloatBuffer(); + for (int i = 0; i < remaining; i++) { + output.put(i, bf16ToFloat(buf.get(i + pos))); + } + return output; + } + + /** + * Converts a fp16 value stored in a short into a float value. + * + *

On Android this is an alias for {@link #mlasFp16ToFloat(short)}. + * + * @param input The fp16 value. + * @return The float value. + */ + public static float fp16ToFloat(short input) { + return mlasFp16ToFloat(input); + } + + /** + * Converts a float value into a fp16 value stored in a short. + * + *

On Android this is an alias for {@link #mlasFloatToFp16(float)}. + * + * @param input The float value. + * @return The fp16 value. + */ + public static short floatToFp16(float input) { + return mlasFloatToFp16(input); + } + + /** + * Upcasts a fp16 value to a float. Mirrors the conversion in MLAS. + * + * @param input A uint16_t representing an IEEE half precision float. + * @return A float. + */ + public static float mlasFp16ToFloat(short input) { + // Port of MLAS_Half2Float from onnxruntime/core/mlas/inc/mlas_float16.h + final int MAGIC = 113 << 23; + // exponent mask after shift + final int SHIFTED_EXP = 0x7c00 << 13; + + // exponent/mantissa bits + int bits = (input & 0x7fff) << 13; + // just the exponent + final int exp = SHIFTED_EXP & bits; + // exponent adjust + bits += (127 - 15) << 23; + + // handle exponent special cases + if (exp == SHIFTED_EXP) { + // Inf/NaN? + // extra exp adjust + bits += (128 - 16) << 23; + } else if (exp == 0) { + // Zero/Denormal? + // extra exp adjust + bits += (1 << 23); + // renormalize + float tmp = Float.intBitsToFloat(bits) - Float.intBitsToFloat(MAGIC); + bits = Float.floatToIntBits(tmp); + } + + // sign bit + bits |= (input & 0x8000) << 16; + + return Float.intBitsToFloat(bits); + } + + /** + * Rounds a float value to fp16. Mirrors the conversion in MLAS. + * + * @param input A float value. + * @return The value rounded to an IEEE half precision value. + */ + public static short mlasFloatToFp16(float input) { + // Port of MLAS_Float2Half from onnxruntime/core/mlas/inc/mlas_float16.h + int bits = Float.floatToIntBits(input); + final int F32_INFINITY = Float.floatToIntBits(Float.POSITIVE_INFINITY); + final int F16_MAX = (127 + 16) << 23; + final int DENORM_MAGIC = ((127 - 15) + (23 - 10) + 1) << 23; + final int SIGN_MASK = 0x80000000; + final int ROUNDING_CONST = ((15 - 127) << 23) + 0xfff; + + int sign = bits & SIGN_MASK; + // mask out sign bit + bits ^= sign; + + short output; + if (bits >= F16_MAX) { + // Inf or NaN (all exponent bits set) + output = (bits > F32_INFINITY) ? (short) 0x7e00 : (short) 0x7c00; + } else { + if (bits < (113 << 23)) { + // Subnormal or zero + // use a magic value to align our 10 mantissa bits at the bottom of + // the float. as long as FP addition is round-to-nearest-even this + // just works. + float tmp = Float.intBitsToFloat(bits) + Float.intBitsToFloat(DENORM_MAGIC); + + // and one integer subtract of the bias later, we have our final float! + output = (short) (Float.floatToIntBits(tmp) - DENORM_MAGIC); + } else { + int mant_odd = (bits >> 13) & 1; // resulting mantissa is odd + + // update exponent, rounding bias part 1 + bits += ROUNDING_CONST; + // rounding bias part 2 + bits += mant_odd; + // take the bits! + output = (short) (bits >> 13); + } + } + + // Add the sign back in + output = (short) (output | ((short) (sign >> 16))); + + return output; + } + + /** + * Converts a bf16 value stored in a short into a float value. + * + * @param input A uint16_t representing a bfloat16 value. + * @return A float. + */ + public static float bf16ToFloat(short input) { + int bits = input << 16; + return Float.intBitsToFloat(bits); + } + + /** + * Converts a float into bf16. May not produce correct values for subnormal floats. + * + *

Rounds to nearest even. + * + * @param input The float input. + * @return A bfloat16 value which is closest to the float. + */ + public static short floatToBf16(float input) { + int bits = Float.floatToIntBits(input); + int lsb = (bits >> 16) & 1; + int roundingBias = 0x7fff + lsb; + bits += roundingBias; + return (short) (bits >> 16); + } +} diff --git a/java/src/main/android/ai/onnxruntime/platform/package-info.java b/java/src/main/android/ai/onnxruntime/platform/package-info.java new file mode 100644 index 0000000000..0899c4f283 --- /dev/null +++ b/java/src/main/android/ai/onnxruntime/platform/package-info.java @@ -0,0 +1,13 @@ +/* + * Copyright (c) 2023, Oracle and/or its affiliates. All rights reserved. + * Licensed under the MIT License. + */ + +/** + * A package of platform specific code, used to swap out Java implementations which don't run on Android. + * + *

Classes in this package should always have the same public methods. + * + *

This is the Android version of the package. + */ +package ai.onnxruntime.platform; \ No newline at end of file diff --git a/java/src/main/java/ai/onnxruntime/OnnxSparseTensor.java b/java/src/main/java/ai/onnxruntime/OnnxSparseTensor.java index 061738a1ba..0ab44cfb50 100644 --- a/java/src/main/java/ai/onnxruntime/OnnxSparseTensor.java +++ b/java/src/main/java/ai/onnxruntime/OnnxSparseTensor.java @@ -4,6 +4,7 @@ */ package ai.onnxruntime; +import ai.onnxruntime.platform.Fp16Conversions; import java.nio.Buffer; import java.nio.ByteBuffer; import java.nio.ByteOrder; @@ -323,12 +324,12 @@ public final class OnnxSparseTensor extends OnnxTensorLike { case FLOAT16: { ShortBuffer shortBuffer = buffer.asShortBuffer(); - return OrtUtil.convertFp16BufferToFloatBuffer(shortBuffer); + return Fp16Conversions.convertFp16BufferToFloatBuffer(shortBuffer); } case BFLOAT16: { ShortBuffer shortBuffer = buffer.asShortBuffer(); - return OrtUtil.convertBf16BufferToFloatBuffer(shortBuffer); + return Fp16Conversions.convertBf16BufferToFloatBuffer(shortBuffer); } case DOUBLE: { diff --git a/java/src/main/java/ai/onnxruntime/OnnxTensor.java b/java/src/main/java/ai/onnxruntime/OnnxTensor.java index 0dec29b59a..09d2cefbb8 100644 --- a/java/src/main/java/ai/onnxruntime/OnnxTensor.java +++ b/java/src/main/java/ai/onnxruntime/OnnxTensor.java @@ -4,6 +4,7 @@ */ package ai.onnxruntime; +import ai.onnxruntime.platform.Fp16Conversions; import java.nio.Buffer; import java.nio.ByteBuffer; import java.nio.ByteOrder; @@ -72,10 +73,10 @@ public class OnnxTensor extends OnnxTensorLike { case STRING: return getString(OnnxRuntime.ortApiHandle, nativeHandle); case FLOAT16: - return OrtUtil.fp16ToFloat( + return Fp16Conversions.fp16ToFloat( getShort(OnnxRuntime.ortApiHandle, nativeHandle, info.onnxType.value)); case BFLOAT16: - return OrtUtil.bf16ToFloat( + return Fp16Conversions.bf16ToFloat( getShort(OnnxRuntime.ortApiHandle, nativeHandle, info.onnxType.value)); case UNKNOWN: default: @@ -149,12 +150,12 @@ public class OnnxTensor extends OnnxTensorLike { // if it's fp16 we need to copy it out by hand. ByteBuffer buf = getBuffer(); ShortBuffer buffer = buf.asShortBuffer(); - return OrtUtil.convertFp16BufferToFloatBuffer(buffer); + return Fp16Conversions.convertFp16BufferToFloatBuffer(buffer); } else if (info.type == OnnxJavaType.BFLOAT16) { // if it's bf16 we need to copy it out by hand. ByteBuffer buf = getBuffer(); ShortBuffer buffer = buf.asShortBuffer(); - return OrtUtil.convertBf16BufferToFloatBuffer(buffer); + return Fp16Conversions.convertBf16BufferToFloatBuffer(buffer); } else { return null; } diff --git a/java/src/main/java/ai/onnxruntime/OrtUtil.java b/java/src/main/java/ai/onnxruntime/OrtUtil.java index d8ade62f62..2dbb63637a 100644 --- a/java/src/main/java/ai/onnxruntime/OrtUtil.java +++ b/java/src/main/java/ai/onnxruntime/OrtUtil.java @@ -5,9 +5,6 @@ */ package ai.onnxruntime; -import java.lang.invoke.MethodHandle; -import java.lang.invoke.MethodHandles; -import java.lang.invoke.MethodType; import java.lang.reflect.Array; import java.nio.Buffer; import java.nio.ByteBuffer; @@ -19,46 +16,12 @@ import java.nio.LongBuffer; import java.nio.ShortBuffer; import java.util.ArrayList; import java.util.Arrays; -import java.util.logging.Level; import java.util.logging.Logger; /** Util code for interacting with Java arrays. */ public final class OrtUtil { private static final Logger logger = Logger.getLogger(OrtUtil.class.getName()); - private static final MethodHandle fp16ToFp32; - private static final MethodHandle fp32ToFp16; - - static { - MethodHandle tmp16 = null; - MethodHandle tmp32 = null; - MethodHandles.Lookup lookup = MethodHandles.lookup(); - try { - // Attempt to lookup the Java 20 fp16 conversion methods which can use SIMD intrinsics. - tmp16 = - lookup.findStatic( - Float.class, "float16ToFloat", MethodType.methodType(float.class, short.class)); - tmp32 = - lookup.findStatic( - Float.class, "floatToFloat16", MethodType.methodType(short.class, float.class)); - } catch (IllegalAccessException | NoSuchMethodException e) { - // Must be on Java 19 or earlier, create handles for our methods. - try { - tmp16 = - lookup.findStatic( - OrtUtil.class, "mlasFp16ToFloat", MethodType.methodType(float.class, short.class)); - tmp32 = - lookup.findStatic( - OrtUtil.class, "mlasFloatToFp16", MethodType.methodType(short.class, float.class)); - } catch (IllegalAccessException | NoSuchMethodException ex) { - // Should not happen - logger.log(Level.SEVERE, "Failed to find fp16 conversion methods on OnnxTensor", e); - } - } - fp16ToFp32 = tmp16; - fp32ToFp16 = tmp32; - } - /** Private constructor for static util class. */ private OrtUtil() {} @@ -595,237 +558,6 @@ public final class OrtUtil { return new BufferTuple(tmp, bufferPos, bufferSize, data.remaining(), tmp != data); } - /** - * Rounds a buffer of floats into a buffer containing fp16 values (stored as shorts in Java). - * - *

Respects the position and limit of the input buffer. - * - * @param buf The buffer of floats. - * @return A buffer of fp16 values stored as shorts. - */ - public static ShortBuffer convertFloatBufferToFp16Buffer(FloatBuffer buf) { - int pos = buf.position(); - int remaining = buf.remaining(); - ShortBuffer output = - ByteBuffer.allocateDirect(remaining * 2).order(ByteOrder.nativeOrder()).asShortBuffer(); - for (int i = 0; i < remaining; i++) { - output.put(i, floatToFp16(buf.get(i + pos))); - } - return output; - } - - /** - * Casts a buffer of fp16 values stored as shorts into a buffer of floats. - * - *

Respects the position and limit of the input buffer. - * - * @param buf The buffer of fp16 values stored as shorts. - * @return A buffer of float values. - */ - public static FloatBuffer convertFp16BufferToFloatBuffer(ShortBuffer buf) { - int pos = buf.position(); - int remaining = buf.remaining(); - FloatBuffer output = - ByteBuffer.allocateDirect(remaining * 4).order(ByteOrder.nativeOrder()).asFloatBuffer(); - for (int i = 0; i < remaining; i++) { - output.put(i, fp16ToFloat(buf.get(i + pos))); - } - return output; - } - - /** - * Rounds a buffer of floats into a buffer containing bf16 values (stored as shorts in Java). - * - *

Respects the position and limit of the input buffer. - * - * @param buf The buffer of floats. - * @return A buffer of bf16 values stored as shorts. - */ - public static ShortBuffer convertFloatBufferToBf16Buffer(FloatBuffer buf) { - int pos = buf.position(); - int remaining = buf.remaining(); - ShortBuffer output = - ByteBuffer.allocateDirect(remaining * 2).order(ByteOrder.nativeOrder()).asShortBuffer(); - for (int i = 0; i < remaining; i++) { - output.put(i, floatToBf16(buf.get(i + pos))); - } - return output; - } - - /** - * Casts a buffer of bf16 values stored as shorts into a buffer of floats. - * - *

Respects the position and limit of the input buffer. - * - * @param buf The buffer of bf16 values stored as shorts. - * @return A buffer of float values. - */ - public static FloatBuffer convertBf16BufferToFloatBuffer(ShortBuffer buf) { - int pos = buf.position(); - int remaining = buf.remaining(); - FloatBuffer output = - ByteBuffer.allocateDirect(remaining * 4).order(ByteOrder.nativeOrder()).asFloatBuffer(); - for (int i = 0; i < remaining; i++) { - output.put(i, bf16ToFloat(buf.get(i + pos))); - } - return output; - } - - /** - * Converts a fp16 value stored in a short into a float value. - * - *

Note on Java 20 or newer this uses {@code Float.float16ToFloat} which may use CPU specific - * instructions for the conversion, otherwise it uses the conversion operation from ORT's native - * implementation. - * - * @param input The fp16 value. - * @return The float value. - */ - public static float fp16ToFloat(short input) { - try { - float ret = (float) fp16ToFp32.invokeExact(input); - return ret; - } catch (Throwable e) { - throw new AssertionError("Should not reach here", e); - } - } - - /** - * Converts a float value into a fp16 value stored in a short. - * - *

Note on Java 20 or newer this uses {@code Float.floatToFloat16} which may use CPU specific - * instructions for the conversion, otherwise it uses the conversion operation from ORT's native - * implementation. - * - * @param input The float value. - * @return The fp16 value. - */ - public static short floatToFp16(float input) { - try { - short ret = (short) fp32ToFp16.invokeExact(input); - return ret; - } catch (Throwable e) { - throw new AssertionError("Should not reach here", e); - } - } - - /** - * Upcasts a fp16 value to a float. Mirrors the conversion in MLAS. - * - * @param input A uint16_t representing an IEEE half precision float. - * @return A float. - */ - static float mlasFp16ToFloat(short input) { - // Port of MLAS_Half2Float from onnxruntime/core/mlas/inc/mlas_float16.h - final int MAGIC = 113 << 23; - // exponent mask after shift - final int SHIFTED_EXP = 0x7c00 << 13; - - // exponent/mantissa bits - int bits = (input & 0x7fff) << 13; - // just the exponent - final int exp = SHIFTED_EXP & bits; - // exponent adjust - bits += (127 - 15) << 23; - - // handle exponent special cases - if (exp == SHIFTED_EXP) { - // Inf/NaN? - // extra exp adjust - bits += (128 - 16) << 23; - } else if (exp == 0) { - // Zero/Denormal? - // extra exp adjust - bits += (1 << 23); - // renormalize - float tmp = Float.intBitsToFloat(bits) - Float.intBitsToFloat(MAGIC); - bits = Float.floatToIntBits(tmp); - } - - // sign bit - bits |= (input & 0x8000) << 16; - - return Float.intBitsToFloat(bits); - } - - /** - * Rounds a float value to fp16. Mirrors the conversion in MLAS. - * - * @param input A float value. - * @return The value rounded to an IEEE half precision value. - */ - static short mlasFloatToFp16(float input) { - // Port of MLAS_Float2Half from onnxruntime/core/mlas/inc/mlas_float16.h - int bits = Float.floatToIntBits(input); - final int F32_INFINITY = Float.floatToIntBits(Float.POSITIVE_INFINITY); - final int F16_MAX = (127 + 16) << 23; - final int DENORM_MAGIC = ((127 - 15) + (23 - 10) + 1) << 23; - final int SIGN_MASK = 0x80000000; - final int ROUNDING_CONST = ((15 - 127) << 23) + 0xfff; - - int sign = bits & SIGN_MASK; - // mask out sign bit - bits ^= sign; - - short output; - if (bits >= F16_MAX) { - // Inf or NaN (all exponent bits set) - output = (bits > F32_INFINITY) ? (short) 0x7e00 : (short) 0x7c00; - } else { - if (bits < (113 << 23)) { - // Subnormal or zero - // use a magic value to align our 10 mantissa bits at the bottom of - // the float. as long as FP addition is round-to-nearest-even this - // just works. - float tmp = Float.intBitsToFloat(bits) + Float.intBitsToFloat(DENORM_MAGIC); - - // and one integer subtract of the bias later, we have our final float! - output = (short) (Float.floatToIntBits(tmp) - DENORM_MAGIC); - } else { - int mant_odd = (bits >> 13) & 1; // resulting mantissa is odd - - // update exponent, rounding bias part 1 - bits += ROUNDING_CONST; - // rounding bias part 2 - bits += mant_odd; - // take the bits! - output = (short) (bits >> 13); - } - } - - // Add the sign back in - output = (short) (output | ((short) (sign >> 16))); - - return output; - } - - /** - * Converts a bf16 value stored in a short into a float value. - * - * @param input A uint16_t representing a bfloat16 value. - * @return A float. - */ - public static float bf16ToFloat(short input) { - int bits = input << 16; - return Float.intBitsToFloat(bits); - } - - /** - * Converts a float into bf16. May not produce correct values for subnormal floats. - * - *

Rounds to nearest even. - * - * @param input The float input. - * @return A bfloat16 value which is closest to the float. - */ - public static short floatToBf16(float input) { - int bits = Float.floatToIntBits(input); - int lsb = (bits >> 16) & 1; - int roundingBias = 0x7fff + lsb; - bits += roundingBias; - return (short) (bits >> 16); - } - static final class BufferTuple { final Buffer data; final int pos; diff --git a/java/src/main/jvm/ai/onnxruntime/platform/Fp16Conversions.java b/java/src/main/jvm/ai/onnxruntime/platform/Fp16Conversions.java new file mode 100644 index 0000000000..fce872688a --- /dev/null +++ b/java/src/main/jvm/ai/onnxruntime/platform/Fp16Conversions.java @@ -0,0 +1,287 @@ +/* + * Copyright (c) 2023, Oracle and/or its affiliates. All rights reserved. + * Licensed under the MIT License. + */ +package ai.onnxruntime.platform; + +import java.lang.invoke.MethodHandle; +import java.lang.invoke.MethodHandles; +import java.lang.invoke.MethodType; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.FloatBuffer; +import java.nio.ShortBuffer; +import java.util.logging.Level; +import java.util.logging.Logger; + +/** Conversions between fp16, bfloat16 and fp32. */ +public final class Fp16Conversions { + private static final Logger logger = Logger.getLogger(Fp16Conversions.class.getName()); + private static final MethodHandle fp16ToFp32; + private static final MethodHandle fp32ToFp16; + + static { + MethodHandle tmp16 = null; + MethodHandle tmp32 = null; + MethodHandles.Lookup lookup = MethodHandles.lookup(); + try { + // Attempt to lookup the Java 20 fp16 conversion methods which can use SIMD intrinsics. + tmp16 = + lookup.findStatic( + Float.class, "float16ToFloat", MethodType.methodType(float.class, short.class)); + tmp32 = + lookup.findStatic( + Float.class, "floatToFloat16", MethodType.methodType(short.class, float.class)); + } catch (IllegalAccessException | NoSuchMethodException e) { + // Must be on Java 19 or earlier, create handles for our methods. + try { + tmp16 = + lookup.findStatic( + Fp16Conversions.class, + "mlasFp16ToFloat", + MethodType.methodType(float.class, short.class)); + tmp32 = + lookup.findStatic( + Fp16Conversions.class, + "mlasFloatToFp16", + MethodType.methodType(short.class, float.class)); + } catch (IllegalAccessException | NoSuchMethodException ex) { + // Should not happen + logger.log(Level.SEVERE, "Failed to find fp16 conversion methods on OnnxTensor", e); + } + } + fp16ToFp32 = tmp16; + fp32ToFp16 = tmp32; + } + + /** + * Rounds a buffer of floats into a buffer containing fp16 values (stored as shorts in Java). + * + *

Respects the position and limit of the input buffer. + * + * @param buf The buffer of floats. + * @return A buffer of fp16 values stored as shorts. + */ + public static ShortBuffer convertFloatBufferToFp16Buffer(FloatBuffer buf) { + int pos = buf.position(); + int remaining = buf.remaining(); + ShortBuffer output = + ByteBuffer.allocateDirect(remaining * 2).order(ByteOrder.nativeOrder()).asShortBuffer(); + for (int i = 0; i < remaining; i++) { + output.put(i, floatToFp16(buf.get(i + pos))); + } + return output; + } + + /** + * Casts a buffer of fp16 values stored as shorts into a buffer of floats. + * + *

Respects the position and limit of the input buffer. + * + * @param buf The buffer of fp16 values stored as shorts. + * @return A buffer of float values. + */ + public static FloatBuffer convertFp16BufferToFloatBuffer(ShortBuffer buf) { + int pos = buf.position(); + int remaining = buf.remaining(); + FloatBuffer output = + ByteBuffer.allocateDirect(remaining * 4).order(ByteOrder.nativeOrder()).asFloatBuffer(); + for (int i = 0; i < remaining; i++) { + output.put(i, fp16ToFloat(buf.get(i + pos))); + } + return output; + } + + /** + * Rounds a buffer of floats into a buffer containing bf16 values (stored as shorts in Java). + * + *

Respects the position and limit of the input buffer. + * + * @param buf The buffer of floats. + * @return A buffer of bf16 values stored as shorts. + */ + public static ShortBuffer convertFloatBufferToBf16Buffer(FloatBuffer buf) { + int pos = buf.position(); + int remaining = buf.remaining(); + ShortBuffer output = + ByteBuffer.allocateDirect(remaining * 2).order(ByteOrder.nativeOrder()).asShortBuffer(); + for (int i = 0; i < remaining; i++) { + output.put(i, floatToBf16(buf.get(i + pos))); + } + return output; + } + + /** + * Casts a buffer of bf16 values stored as shorts into a buffer of floats. + * + *

Respects the position and limit of the input buffer. + * + * @param buf The buffer of bf16 values stored as shorts. + * @return A buffer of float values. + */ + public static FloatBuffer convertBf16BufferToFloatBuffer(ShortBuffer buf) { + int pos = buf.position(); + int remaining = buf.remaining(); + FloatBuffer output = + ByteBuffer.allocateDirect(remaining * 4).order(ByteOrder.nativeOrder()).asFloatBuffer(); + for (int i = 0; i < remaining; i++) { + output.put(i, bf16ToFloat(buf.get(i + pos))); + } + return output; + } + + /** + * Converts a fp16 value stored in a short into a float value. + * + *

Note on Java 20 or newer this uses {@code Float.float16ToFloat} which may use CPU specific + * instructions for the conversion, otherwise it uses the conversion operation from ORT's native + * implementation. + * + * @param input The fp16 value. + * @return The float value. + */ + public static float fp16ToFloat(short input) { + try { + float ret = (float) fp16ToFp32.invokeExact(input); + return ret; + } catch (Throwable e) { + throw new AssertionError("Should not reach here", e); + } + } + + /** + * Converts a float value into a fp16 value stored in a short. + * + *

Note on Java 20 or newer this uses {@code Float.floatToFloat16} which may use CPU specific + * instructions for the conversion, otherwise it uses the conversion operation from ORT's native + * implementation. + * + * @param input The float value. + * @return The fp16 value. + */ + public static short floatToFp16(float input) { + try { + short ret = (short) fp32ToFp16.invokeExact(input); + return ret; + } catch (Throwable e) { + throw new AssertionError("Should not reach here", e); + } + } + + /** + * Upcasts a fp16 value to a float. Mirrors the conversion in MLAS. + * + * @param input A uint16_t representing an IEEE half precision float. + * @return A float. + */ + public static float mlasFp16ToFloat(short input) { + // Port of MLAS_Half2Float from onnxruntime/core/mlas/inc/mlas_float16.h + final int MAGIC = 113 << 23; + // exponent mask after shift + final int SHIFTED_EXP = 0x7c00 << 13; + + // exponent/mantissa bits + int bits = (input & 0x7fff) << 13; + // just the exponent + final int exp = SHIFTED_EXP & bits; + // exponent adjust + bits += (127 - 15) << 23; + + // handle exponent special cases + if (exp == SHIFTED_EXP) { + // Inf/NaN? + // extra exp adjust + bits += (128 - 16) << 23; + } else if (exp == 0) { + // Zero/Denormal? + // extra exp adjust + bits += (1 << 23); + // renormalize + float tmp = Float.intBitsToFloat(bits) - Float.intBitsToFloat(MAGIC); + bits = Float.floatToIntBits(tmp); + } + + // sign bit + bits |= (input & 0x8000) << 16; + + return Float.intBitsToFloat(bits); + } + + /** + * Rounds a float value to fp16. Mirrors the conversion in MLAS. + * + * @param input A float value. + * @return The value rounded to an IEEE half precision value. + */ + public static short mlasFloatToFp16(float input) { + // Port of MLAS_Float2Half from onnxruntime/core/mlas/inc/mlas_float16.h + int bits = Float.floatToIntBits(input); + final int F32_INFINITY = Float.floatToIntBits(Float.POSITIVE_INFINITY); + final int F16_MAX = (127 + 16) << 23; + final int DENORM_MAGIC = ((127 - 15) + (23 - 10) + 1) << 23; + final int SIGN_MASK = 0x80000000; + final int ROUNDING_CONST = ((15 - 127) << 23) + 0xfff; + + int sign = bits & SIGN_MASK; + // mask out sign bit + bits ^= sign; + + short output; + if (bits >= F16_MAX) { + // Inf or NaN (all exponent bits set) + output = (bits > F32_INFINITY) ? (short) 0x7e00 : (short) 0x7c00; + } else { + if (bits < (113 << 23)) { + // Subnormal or zero + // use a magic value to align our 10 mantissa bits at the bottom of + // the float. as long as FP addition is round-to-nearest-even this + // just works. + float tmp = Float.intBitsToFloat(bits) + Float.intBitsToFloat(DENORM_MAGIC); + + // and one integer subtract of the bias later, we have our final float! + output = (short) (Float.floatToIntBits(tmp) - DENORM_MAGIC); + } else { + int mant_odd = (bits >> 13) & 1; // resulting mantissa is odd + + // update exponent, rounding bias part 1 + bits += ROUNDING_CONST; + // rounding bias part 2 + bits += mant_odd; + // take the bits! + output = (short) (bits >> 13); + } + } + + // Add the sign back in + output = (short) (output | ((short) (sign >> 16))); + + return output; + } + + /** + * Converts a bf16 value stored in a short into a float value. + * + * @param input A uint16_t representing a bfloat16 value. + * @return A float. + */ + public static float bf16ToFloat(short input) { + int bits = input << 16; + return Float.intBitsToFloat(bits); + } + + /** + * Converts a float into bf16. May not produce correct values for subnormal floats. + * + *

Rounds to nearest even. + * + * @param input The float input. + * @return A bfloat16 value which is closest to the float. + */ + public static short floatToBf16(float input) { + int bits = Float.floatToIntBits(input); + int lsb = (bits >> 16) & 1; + int roundingBias = 0x7fff + lsb; + bits += roundingBias; + return (short) (bits >> 16); + } +} diff --git a/java/src/main/jvm/ai/onnxruntime/platform/package-info.java b/java/src/main/jvm/ai/onnxruntime/platform/package-info.java new file mode 100644 index 0000000000..ef52c50064 --- /dev/null +++ b/java/src/main/jvm/ai/onnxruntime/platform/package-info.java @@ -0,0 +1,14 @@ +/* + * Copyright (c) 2023, Oracle and/or its affiliates. All rights reserved. + * Licensed under the MIT License. + */ + +/** + * A package of platform specific code, used to swap out Java implementations which don't run on + * Android. + * + *

Classes in this package should always have the same public methods. + * + *

This is the Java version of the package. + */ +package ai.onnxruntime.platform; diff --git a/java/src/test/java/ai/onnxruntime/OnnxTensorTest.java b/java/src/test/java/ai/onnxruntime/OnnxTensorTest.java index 28ac840c0f..fcb4590717 100644 --- a/java/src/test/java/ai/onnxruntime/OnnxTensorTest.java +++ b/java/src/test/java/ai/onnxruntime/OnnxTensorTest.java @@ -4,6 +4,7 @@ */ package ai.onnxruntime; +import ai.onnxruntime.platform.Fp16Conversions; import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.nio.FloatBuffer; @@ -163,7 +164,7 @@ public class OnnxTensorTest { for (int i = 0; i < input.length; i++) { for (int j = 0; j < input[0].length; j++) { short bits = (short) rng.nextInt(); - input[i][j] = OrtUtil.bf16ToFloat(bits); + input[i][j] = Fp16Conversions.bf16ToFloat(bits); shortBuf.put(bits); } } @@ -196,7 +197,7 @@ public class OnnxTensorTest { for (int i = 0; i < input.length; i++) { for (int j = 0; j < input[0].length; j++) { short bits = (short) rng.nextInt(); - input[i][j] = OrtUtil.fp16ToFloat(bits); + input[i][j] = Fp16Conversions.fp16ToFloat(bits); shortBuf.put(bits); } } @@ -232,7 +233,7 @@ public class OnnxTensorTest { int bits = rng.nextInt(); input[i][j] = Float.intBitsToFloat(bits); floatBuf.put(input[i][j]); - shortBuf.put(OrtUtil.floatToFp16(input[i][j])); + shortBuf.put(Fp16Conversions.floatToFp16(input[i][j])); } } floatBuf.rewind(); @@ -247,7 +248,7 @@ public class OnnxTensorTest { // Check outbound Java side cast to fp32 works FloatBuffer castOutput = output.getFloatBuffer(); float[] expectedFloatArr = new float[10 * 5]; - OrtUtil.convertFp16BufferToFloatBuffer(shortBuf).get(expectedFloatArr); + Fp16Conversions.convertFp16BufferToFloatBuffer(shortBuf).get(expectedFloatArr); float[] actualFloatArr = new float[10 * 5]; castOutput.get(actualFloatArr); Assertions.assertArrayEquals(expectedFloatArr, actualFloatArr); @@ -279,7 +280,7 @@ public class OnnxTensorTest { int bits = rng.nextInt(); input[i][j] = Float.intBitsToFloat(bits); floatBuf.put(input[i][j]); - shortBuf.put(OrtUtil.floatToBf16(input[i][j])); + shortBuf.put(Fp16Conversions.floatToBf16(input[i][j])); } } floatBuf.rewind(); @@ -294,7 +295,7 @@ public class OnnxTensorTest { // Check outbound Java side cast to fp32 works FloatBuffer castOutput = output.getFloatBuffer(); float[] expectedFloatArr = new float[10 * 5]; - OrtUtil.convertBf16BufferToFloatBuffer(shortBuf).get(expectedFloatArr); + Fp16Conversions.convertBf16BufferToFloatBuffer(shortBuf).get(expectedFloatArr); float[] actualFloatArr = new float[10 * 5]; castOutput.get(actualFloatArr); Assertions.assertArrayEquals(expectedFloatArr, actualFloatArr); @@ -314,8 +315,8 @@ public class OnnxTensorTest { for (int i = 0; i < 0xffff; i++) { // Round trip every value short curVal = (short) (0xffff & i); - float upcast = OrtUtil.mlasFp16ToFloat(curVal); - short output = OrtUtil.mlasFloatToFp16(upcast); + float upcast = Fp16Conversions.mlasFp16ToFloat(curVal); + short output = Fp16Conversions.mlasFloatToFp16(upcast); if (!Float.isNaN(upcast)) { // We coerce NaNs to the same value. Assertions.assertEquals( @@ -331,8 +332,8 @@ public class OnnxTensorTest { for (int i = 0; i < 0xffff; i++) { // Round trip every value short curVal = (short) (0xffff & i); - float upcast = OrtUtil.bf16ToFloat(curVal); - short output = OrtUtil.floatToBf16(upcast); + float upcast = Fp16Conversions.bf16ToFloat(curVal); + short output = Fp16Conversions.floatToBf16(upcast); if (!Float.isNaN(upcast)) { // We coerce NaNs to the same value. Assertions.assertEquals(