mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-03 23:49:44 +00:00
add bf16 support for few ops (#20385)
### Description Add bf16 support for below ops: ConstantOfShape Exp Erf convolution PythonOp ### Motivation and Context phimm model works on bf16, ORT need support bf16 on previous ops to work with phimm on bf16
This commit is contained in:
parent
464f199b95
commit
227c4419fc
15 changed files with 174 additions and 10 deletions
|
|
@ -71,7 +71,7 @@ Do not modify directly.*
|
|||
|ConcatFromSequence|*in* input_sequence:**S**<br> *out* concat_result:**T**|11+|**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))|
|
||||
|ConstantOfShape|*in* input:**T1**<br> *out* output:**T2**|21+|**T1** = tensor(int64)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|||20|**T1** = tensor(int64)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|||[9, 19]|**T1** = tensor(int64)<br/> **T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|||[9, 19]|**T1** = tensor(int64)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|Conv|*in* X:**T**<br> *in* W:**T**<br> *in* B:**T**<br> *out* Y:**T**|11+|**T** = tensor(float)|
|
||||
|||[1, 10]|**T** = tensor(float)|
|
||||
|ConvInteger|*in* x:**T1**<br> *in* w:**T2**<br> *in* x_zero_point:**T1**<br> *in* w_zero_point:**T2**<br> *out* y:**T3**|10+|**T1** = tensor(uint8)<br/> **T2** = tensor(uint8)<br/> **T3** = tensor(int32)|
|
||||
|
|
@ -601,9 +601,9 @@ Do not modify directly.*
|
|||
|Equal|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T1**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)<br/> **T1** = tensor(bool)|
|
||||
|||[11, 12]|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
|
||||
|||[7, 10]|**T** = tensor(bool), tensor(int32), tensor(int64)|
|
||||
|Erf|*in* input:**T**<br> *out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)|
|
||||
|Erf|*in* input:**T**<br> *out* output:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
|
||||
|||[9, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
|
||||
|Exp|*in* input:**T**<br> *out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)|
|
||||
|Exp|*in* input:**T**<br> *out* output:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
|
||||
|||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
|
||||
|Expand|*in* input:**T**<br> *in* shape:**tensor(int64)**<br> *out* output:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|||[8, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|
|
|
|||
|
|
@ -15,13 +15,16 @@
|
|||
|
||||
namespace onnxruntime {
|
||||
|
||||
// Add bf16 support for ConstantOfShape operator for phimm model.
|
||||
// Although ONNX don't have bf16 support in opset-9 for ConstantOfShape we add support here:
|
||||
// https://github.com/onnx/onnx/blob/main/docs/Changelog.md#constantofshape-9
|
||||
using ConstantOfShapeDefaultOutputTypes =
|
||||
TypeList<
|
||||
MLFloat16,
|
||||
float, double,
|
||||
int8_t, int16_t, int32_t, int64_t,
|
||||
uint8_t, uint16_t, uint32_t, uint64_t,
|
||||
bool>;
|
||||
bool, BFloat16>;
|
||||
|
||||
using ConstantOfShapeDefaultOutputTypesOpset20 =
|
||||
TypeList<
|
||||
|
|
@ -158,6 +161,7 @@ void ConstantOfShapeBase<EnabledOutputTypeList>::SetValueFromTensorProto(const O
|
|||
CASE_FETCH_VALUE_DATA(uint16_t)
|
||||
CASE_FETCH_VALUE_DATA(uint32_t)
|
||||
CASE_FETCH_VALUE_DATA(uint64_t)
|
||||
CASE_FETCH_VALUE_DATA(BFloat16)
|
||||
default:
|
||||
ORT_THROW("Unsupported value attribute datatype: ", tensor_type);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -231,6 +231,9 @@ __device__ __inline__ double _Erf(double a) { return erf(a); }
|
|||
template <>
|
||||
__device__ __inline__ half _Erf(half a) { return half(erff((float)a)); }
|
||||
|
||||
template <>
|
||||
__device__ __inline__ BFloat16 _Erf(BFloat16 a) { return BFloat16(erff((float)a)); }
|
||||
|
||||
template <typename T>
|
||||
__device__ __host__ __inline__ T _Round(T a);
|
||||
|
||||
|
|
|
|||
|
|
@ -1031,9 +1031,13 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Exp);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Exp);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Exp);
|
||||
// Add bf16 support for Exp in opset 13+ for phimm model
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, BFloat16, Exp);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Erf);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Erf);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Erf);
|
||||
// Add bf16 support for Erf in opset 13+ for phimm model
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, BFloat16, Erf);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Expand);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Sum);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Max);
|
||||
|
|
@ -1947,9 +1951,11 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Exp)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Exp)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Exp)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, BFloat16, Exp)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Erf)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Erf)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Erf)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, BFloat16, Erf)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Expand)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Sum)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Max)>,
|
||||
|
|
|
|||
|
|
@ -244,8 +244,8 @@ UNARY_OP_HFD(Ceil, 13)
|
|||
UNARY_OP_HFD(Reciprocal, 13)
|
||||
UNARY_OP_HFDX(Sqrt, 13)
|
||||
UNARY_OP_HFD(Log, 13)
|
||||
UNARY_OP_HFD(Exp, 13)
|
||||
UNARY_OP_HFD(Erf, 13)
|
||||
UNARY_OP_HFDX(Exp, 13)
|
||||
UNARY_OP_HFDX(Erf, 13)
|
||||
UNARY_OP_BWUZCSILHFD(Sign, 13)
|
||||
|
||||
UNARY_LOGICALOP_NOT_TYPED(1, bool)
|
||||
|
|
|
|||
|
|
@ -87,7 +87,7 @@ SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Reciprocal)
|
|||
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFDX(Sqrt)
|
||||
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFDX(Log)
|
||||
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFDX(Exp)
|
||||
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Erf)
|
||||
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFDX(Erf)
|
||||
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Round)
|
||||
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Sin)
|
||||
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Cos)
|
||||
|
|
|
|||
|
|
@ -138,6 +138,9 @@ __device__ __inline__ double _Erf(double a) { return erf(a); }
|
|||
template <>
|
||||
__device__ __inline__ half _Erf(half a) { return half(erff((float)a)); }
|
||||
|
||||
template <>
|
||||
__device__ __inline__ BFloat16 _Erf(BFloat16 a) { return BFloat16(erff((float)a)); }
|
||||
|
||||
template <typename T>
|
||||
__device__ __inline__ T _Round(T a);
|
||||
|
||||
|
|
|
|||
|
|
@ -1021,9 +1021,13 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Exp);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Exp);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Exp);
|
||||
// Add bf16 support for Exp in opset 13+ for phimm model
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Exp);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Erf);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Erf);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Erf);
|
||||
// Add bf16 support for Erf in opset 13+ for phimm model
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Erf);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Expand);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Sum);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Max);
|
||||
|
|
@ -1973,9 +1977,11 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Exp)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Exp)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Exp)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Exp)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Erf)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Erf)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Erf)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Erf)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Expand)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Sum)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Max)>,
|
||||
|
|
|
|||
|
|
@ -478,6 +478,8 @@ Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_d
|
|||
template <>
|
||||
Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ MLFloat16* p_data, size_t expected_size) { return g_host->UnpackTensor(tensor, raw_data, raw_data_len, p_data, expected_size); }
|
||||
template <>
|
||||
Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ BFloat16* p_data, size_t expected_size) { return g_host->UnpackTensor(tensor, raw_data, raw_data_len, p_data, expected_size); }
|
||||
template <>
|
||||
Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ int8_t* p_data, size_t expected_size) { return g_host->UnpackTensor(tensor, raw_data, raw_data_len, p_data, expected_size); }
|
||||
template <>
|
||||
Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ uint8_t* p_data, size_t expected_size) { return g_host->UnpackTensor(tensor, raw_data, raw_data_len, p_data, expected_size); }
|
||||
|
|
|
|||
|
|
@ -198,6 +198,7 @@ struct ProviderHost {
|
|||
virtual Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ float* p_data, size_t expected_size) = 0;
|
||||
virtual Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ double* p_data, size_t expected_size) = 0;
|
||||
virtual Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ MLFloat16* p_data, size_t expected_size) = 0;
|
||||
virtual Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ BFloat16* p_data, size_t expected_size) = 0;
|
||||
virtual Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ int8_t* p_data, size_t expected_size) = 0;
|
||||
virtual Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ uint8_t* p_data, size_t expected_size) = 0;
|
||||
virtual Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ int16_t* p_data, size_t expected_size) = 0;
|
||||
|
|
|
|||
|
|
@ -271,6 +271,7 @@ struct ProviderHostImpl : ProviderHost {
|
|||
Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ float* p_data, size_t expected_size) override { return utils::UnpackTensor(tensor, raw_data, raw_data_len, p_data, expected_size); }
|
||||
Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ double* p_data, size_t expected_size) override { return utils::UnpackTensor(tensor, raw_data, raw_data_len, p_data, expected_size); }
|
||||
Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ MLFloat16* p_data, size_t expected_size) override { return utils::UnpackTensor(tensor, raw_data, raw_data_len, p_data, expected_size); }
|
||||
Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ BFloat16* p_data, size_t expected_size) override { return utils::UnpackTensor(tensor, raw_data, raw_data_len, p_data, expected_size); }
|
||||
Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ int8_t* p_data, size_t expected_size) override { return utils::UnpackTensor(tensor, raw_data, raw_data_len, p_data, expected_size); }
|
||||
Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ uint8_t* p_data, size_t expected_size) override { return utils::UnpackTensor(tensor, raw_data, raw_data_len, p_data, expected_size); }
|
||||
Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ int16_t* p_data, size_t expected_size) override { return utils::UnpackTensor(tensor, raw_data, raw_data_len, p_data, expected_size); }
|
||||
|
|
|
|||
|
|
@ -85,6 +85,10 @@ inline void SetValue(TensorProto& t_proto, MLFloat16 value) {
|
|||
t_proto.mutable_int32_data()->Add(value.val);
|
||||
}
|
||||
|
||||
inline void SetValue(TensorProto& t_proto, BFloat16 value) {
|
||||
t_proto.mutable_int32_data()->Add(value.val);
|
||||
}
|
||||
|
||||
// This works for int64_t
|
||||
template <class T>
|
||||
inline void SetValue(TensorProto& t_proto, T value,
|
||||
|
|
@ -100,7 +104,7 @@ inline void SetValue(TensorProto& t_proto, T value,
|
|||
t_proto.mutable_uint64_data()->Add(value);
|
||||
}
|
||||
|
||||
// For everything else except float, double and MLFloat16
|
||||
// For everything else except float, double, MLFloat16 and BFloat16
|
||||
template <class T>
|
||||
inline void SetValue(TensorProto& t_proto, T value,
|
||||
typename std::enable_if<!std::is_same<T, int64_t>::value &&
|
||||
|
|
@ -153,6 +157,7 @@ TEST(ConstantOfShape, TypeTests) {
|
|||
RunTypedTest(TensorProto::UINT16, uint16_t(6U), opset);
|
||||
RunTypedTest(TensorProto::UINT32, uint32_t(32U), opset);
|
||||
RunTypedTest(TensorProto::UINT64, uint64_t(64U), opset);
|
||||
RunTypedTest(TensorProto::BFLOAT16, BFloat16::FromBits(static_cast<uint16_t>(7)), opset);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -3949,7 +3949,7 @@ Return true if all elements are true and false otherwise.
|
|||
static_cast<int64_t>(1))
|
||||
.TypeConstraint(
|
||||
"T",
|
||||
OpSchema::all_tensor_types(),
|
||||
OpSchema::all_tensor_types_ir4(),
|
||||
"Allow inputs and outputs to be any kind of tensor.")
|
||||
.TypeConstraint(
|
||||
"TInt64",
|
||||
|
|
@ -4116,7 +4116,7 @@ Return true if all elements are true and false otherwise.
|
|||
static_cast<int64_t>(1))
|
||||
.TypeConstraint(
|
||||
"T",
|
||||
OpSchema::all_tensor_types(),
|
||||
OpSchema::all_tensor_types_ir4(),
|
||||
"Allow inputs and outputs to be any kind of tensor.")
|
||||
.TypeConstraint(
|
||||
"TInt64",
|
||||
|
|
|
|||
|
|
@ -847,6 +847,103 @@ def layer_norm(g, input, normalized_shape, weight, bias, eps, cudnn_enable):
|
|||
return res
|
||||
|
||||
|
||||
# Adapted from torch.onnx.symbolic_opset9._convolution -
|
||||
# https://github.com/pytorch/pytorch/blob/cf06189a2d2785ac493bcd0d55e520af5a0e3b97/torch/onnx/symbolic_opset9.py#L2334
|
||||
# We override aten::_convolution here to support bf16 for phimm model from GenAI team.
|
||||
# For bf16 inputs, we will convert input to float32, do convolution then convert output back to bf16.
|
||||
# TODO: This might have negative impact on performance.
|
||||
@register_symbolic("_convolution")
|
||||
@parse_args("v", "v", "v", "is", "is", "is", "i", "is", "i", "i", "i", "i", "i")
|
||||
def convolution(
|
||||
g,
|
||||
input,
|
||||
weight,
|
||||
bias,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
transposed,
|
||||
output_padding,
|
||||
groups,
|
||||
benchmark,
|
||||
deterministic,
|
||||
cudnn_enabled,
|
||||
allow_tf32=None,
|
||||
):
|
||||
from torch.onnx.symbolic_opset9 import _convolution
|
||||
|
||||
input_casted = (
|
||||
g.op("Cast", input, to_i=torch.onnx.TensorProtoDataType.FLOAT)
|
||||
if input.type().scalarType() == "BFloat16"
|
||||
else input
|
||||
)
|
||||
weight_casted = (
|
||||
g.op("Cast", weight, to_i=torch.onnx.TensorProtoDataType.FLOAT)
|
||||
if weight.type().scalarType() == "BFloat16"
|
||||
else weight
|
||||
)
|
||||
|
||||
n = _convolution(
|
||||
g,
|
||||
input_casted,
|
||||
weight_casted,
|
||||
bias,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
transposed,
|
||||
output_padding,
|
||||
groups,
|
||||
benchmark,
|
||||
deterministic,
|
||||
cudnn_enabled,
|
||||
allow_tf32,
|
||||
)
|
||||
|
||||
n_casted = (
|
||||
g.op("Cast", n, to_i=torch.onnx.TensorProtoDataType.BFLOAT16) if input.type().scalarType() == "BFloat16" else n
|
||||
)
|
||||
return n_casted
|
||||
|
||||
|
||||
# Adapted from torch.onnx.symbolic_opset9._convolution_mode -
|
||||
# https://github.com/pytorch/pytorch/blob/cf06189a2d2785ac493bcd0d55e520af5a0e3b97/torch/onnx/symbolic_opset9.py#L2406
|
||||
# We override aten::_convolution_mode here to support bf16 for phimm model from GenAI team.
|
||||
# For bf16 inputs, we will convert input to float32, do convolution then convert output back to bf16.
|
||||
# TODO: This might have negative impact on performance.
|
||||
@register_symbolic("_convolution_mode")
|
||||
@parse_args("v", "v", "v", "is", "s", "is", "i")
|
||||
def convolution_mode(
|
||||
g,
|
||||
input,
|
||||
weight,
|
||||
bias,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
groups,
|
||||
):
|
||||
from torch.onnx.symbolic_opset9 import _convolution_mode
|
||||
|
||||
input_casted = (
|
||||
g.op("Cast", input, to_i=torch.onnx.TensorProtoDataType.FLOAT)
|
||||
if input.type().scalarType() == "BFloat16"
|
||||
else input
|
||||
)
|
||||
weight_casted = (
|
||||
g.op("Cast", weight, to_i=torch.onnx.TensorProtoDataType.FLOAT)
|
||||
if weight.type().scalarType() == "BFloat16"
|
||||
else weight
|
||||
)
|
||||
|
||||
n = _convolution_mode(g, input_casted, weight_casted, bias, stride, padding, dilation, groups)
|
||||
|
||||
n_casted = (
|
||||
g.op("Cast", n, to_i=torch.onnx.TensorProtoDataType.BFLOAT16) if input.type().scalarType() == "BFloat16" else n
|
||||
)
|
||||
return n_casted
|
||||
|
||||
|
||||
# Adapted from torch.onnx.symbolic_opset13.softmax -
|
||||
# https://github.com/pytorch/pytorch/blob/cf06189a2d2785ac493bcd0d55e520af5a0e3b97/torch/onnx/symbolic_opset13.py#L27
|
||||
# We don't need overloads symbolic_opset9 because training support opsets >= 13.
|
||||
|
|
|
|||
|
|
@ -6610,6 +6610,42 @@ def test_overridden_softmax_export(softmax_compute_type):
|
|||
assert to_value == pytorch_type_to_onnx_dtype(softmax_compute_type), "Cast to attribute is not as expected"
|
||||
|
||||
|
||||
def test_aten_conv_bf16():
|
||||
class NeuralNetConv(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(
|
||||
in_channels=3,
|
||||
out_channels=1024,
|
||||
kernel_size=14,
|
||||
stride=14,
|
||||
bias=False,
|
||||
dtype=torch.bfloat16,
|
||||
)
|
||||
|
||||
def forward(self, input):
|
||||
return self.conv(input)
|
||||
|
||||
device = "cuda"
|
||||
pt_model = NeuralNetConv().to(device)
|
||||
ort_model = ORTModule(copy.deepcopy(pt_model))
|
||||
|
||||
def run_step(model, input):
|
||||
prediction = model(input)
|
||||
prediction.sum().backward()
|
||||
return prediction
|
||||
|
||||
# reset manual seed to reset the generator
|
||||
torch.manual_seed(2333)
|
||||
pt_input = torch.randn([2, 3, 336, 336], dtype=torch.bfloat16, device=device, requires_grad=True)
|
||||
ort_input = copy.deepcopy(pt_input)
|
||||
pt_prediction = run_step(pt_model, pt_input)
|
||||
ort_prediction = run_step(ort_model, ort_input)
|
||||
|
||||
_test_helpers.assert_values_are_close(ort_prediction, pt_prediction)
|
||||
_test_helpers.assert_values_are_close(ort_input.grad, pt_input.grad)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("memory_optimization_level", [None, 0, 1, 2])
|
||||
@pytest.mark.parametrize("allow_gradient_checkpoint_export", [None, 0, 1])
|
||||
@pytest.mark.parametrize("fx", ["torch", "deepspeed"])
|
||||
|
|
|
|||
Loading…
Reference in a new issue