Fix BatchNorm CUDA kernel definition

This commit is contained in:
Hariharan Seshadri 2020-04-17 20:46:32 -07:00 committed by Changming Sun
parent c365822808
commit 1599562016
3 changed files with 28 additions and 35 deletions


@@ -920,28 +920,28 @@ This version of the operator has been available since version 1 of the 'com.micr
#### Inputs
<dl>
<dt><tt>x</tt> : T2</dt>
<dt><tt>x</tt> : T1</dt>
<dd>N-D quantized input tensor to be de-quantized.</dd>
<dt><tt>x_scale</tt> : T1</dt>
<dt><tt>x_scale</tt> : T2</dt>
<dd>Scale for input 'x'. It could be a scalar or a 1-D tensor, which means a per-tensor or per-axis quantization. If it's a 1-D tensor, its number of elements should be equal to the dimension value of 'axis' dimension of input 'x'.</dd>
<dt><tt>x_zero_point</tt> : T2</dt>
<dt><tt>x_zero_point</tt> : T1</dt>
<dd>Zero point for input 'x'. It could be a scalar or a 1-D tensor, which means a per-tensor or per-axis quantization. If it's a 1-D tensor, its number of elements should be equal to the dimension value of 'axis' dimension of input 'x'.</dd>
</dl>
#### Outputs
<dl>
<dt><tt>y</tt> : T1</dt>
<dt><tt>y</tt> : T2</dt>
<dd>N-D full-precision output tensor. It has the same shape as input 'x'.</dd>
</dl>
#### Type Constraints
<dl>
<dt><tt>T1</tt> : tensor(float)</dt>
<dt><tt>T1</tt> : tensor(int8), tensor(uint8)</dt>
<dd>Constrain 'x' and 'x_zero_point' to 8-bit integer tensors.</dd>
<dt><tt>T2</tt> : tensor(float16), tensor(float)</dt>
<dd>Constrain 'y', 'x_scale' to float tensors.</dd>
<dt><tt>T2</tt> : tensor(int8), tensor(uint8)</dt>
<dd>Constrain 'x_zero_point' and 'x' to 8-bit integer tensors.</dd>
</dl>
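As an illustration of the dequantization semantics documented above (not part of this commit), a minimal NumPy sketch of the per-tensor case:

```python
import numpy as np

def dequantize_linear(x, x_scale, x_zero_point):
    # Per-tensor DequantizeLinear: y = (x - x_zero_point) * x_scale,
    # where x and x_zero_point are 8-bit integers and y is full precision.
    return (x.astype(np.float32) - np.float32(x_zero_point)) * np.float32(x_scale)

y = dequantize_linear(np.array([0, 128, 255], dtype=np.uint8), 0.5, 128)
# y is a float32 tensor with the same shape as x
```

The per-axis case only changes the broadcasting of `x_scale` and `x_zero_point` along the 'axis' dimension.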
@@ -1639,9 +1639,10 @@ This version of the operator has been available since version 1 of the 'com.micr
### <a name="com.microsoft.QuantizeLinear"></a><a name="com.microsoft.quantizelinear">**com.microsoft.QuantizeLinear**</a>
The linear quantization operator. It consumes full-precision data, a scale, and a zero point, and computes the quantized data.
The quantization formula is y = (x / y_scale) + y_zero_point. For (x / y_scale), it computes the nearest integer value to the argument (in floating-point format),
rounding halfway cases away from zero. Scale and zero point must have the same shape. They must be either a scalar (per tensor) or a 1-D tensor (per 'axis').
The linear quantization operator. It consumes full-precision data, a scale, and a zero point to compute the low-precision / quantized tensor.
The quantization formula is y = saturate((x / y_scale) + y_zero_point). For saturation, it saturates to [0, 255] if it's uint8, or [-128, 127] if it's int8.
For (x / y_scale), it rounds to the nearest integer, ties to even. Refer to https://en.wikipedia.org/wiki/Rounding for details.
Scale and zero point must have the same shape. They must be either a scalar (per tensor) or a 1-D tensor (per 'axis').
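The saturation and round-half-to-even behavior described above can be sketched as follows (an illustrative NumPy reference, not from this commit; `np.rint` rounds ties to even):

```python
import numpy as np

def quantize_linear(x, y_scale, y_zero_point, dtype=np.uint8):
    # QuantizeLinear: y = saturate(round(x / y_scale) + y_zero_point)
    # np.rint rounds halfway cases to the nearest even integer.
    lo, hi = np.iinfo(dtype).min, np.iinfo(dtype).max
    y = np.rint(x / y_scale) + y_zero_point
    return np.clip(y, lo, hi).astype(dtype)  # saturate to the target range

q = quantize_linear(np.array([-1.0, 0.25, 0.75, 300.0]), 0.5, 0)
# -1.0 saturates to 0; 0.25/0.5 = 0.5 rounds to 0 (ties to even);
# 0.75/0.5 = 1.5 rounds to 2; 300.0 saturates to 255.
```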
#### Version
@@ -1675,7 +1676,7 @@ This version of the operator has been available since version 1 of the 'com.micr
#### Type Constraints
<dl>
<dt><tt>T1</tt> : tensor(float)</dt>
<dt><tt>T1</tt> : tensor(float16), tensor(float)</dt>
<dd>Constrain 'x', 'y_scale' to float tensors.</dd>
<dt><tt>T2</tt> : tensor(int8), tensor(uint8)</dt>
<dd>Constrain 'y_zero_point' and 'y' to 8-bit integer tensors.</dd>
@@ -2001,7 +2002,9 @@ This version of the operator has been available since version 1 of the 'com.micr
### <sub>experimental</sub> <a name="com.microsoft.Attention"></a><a name="com.microsoft.attention">**com.microsoft.Attention**</a>
Multi-Head Self Attention
Multi-Head Self Attention that can be either unidirectional (like GPT2) or bidirectional (like BERT).
The mask_index input is optional. The unidirectional attribute and the mask_index input are mutually exclusive: when unidirectional is 1,
mask_index shall not be provided.
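As a sketch of what unidirectional attention means here (illustrative only; the shapes and names below are assumptions, not this operator's implementation): setting unidirectional to 1 is equivalent to masking out future positions in the attention scores with a causal mask.

```python
import numpy as np

# With unidirectional == 1, token i may only attend to tokens j <= i.
# This is typically implemented by adding -inf above the diagonal of the
# raw attention-score matrix before the softmax.
seq_len = 4
scores = np.zeros((seq_len, seq_len))            # placeholder raw scores
causal_mask = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)
scores[causal_mask] = -np.inf                    # block attention to the future
```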
#### Version
@@ -2011,6 +2014,8 @@ No versioning maintained for experimental ops.
<dl>
<dt><tt>num_heads</tt> : int (required)</dt>
<dd>Number of attention heads</dd>
<dt><tt>unidirectional</tt> : int</dt>
<dd>Whether every token can only attend to previous tokens. Default value is 0.</dd>
</dl>
#### Inputs (3 - 4)


@@ -361,8 +361,8 @@
|ConvTransposeWithDynamicPads|(*in* X:**T**, *in* W:**T**, *in* Pads:**tensor(int64)**, *in* B:**T**, *out* Y:**T**)|1+|**T** = tensor(float)|
|CropAndResize|(*in* X:**T1**, *in* rois:**T1**, *in* batch_indices:**T2**, *in* crop_size:**T2**, *out* Y:**T1**)|1+|**T** = tensor(float)|
| | ||**T2** = tensor(int32)|
|DequantizeLinear|(*in* x:**T2**, *in* x_scale:**T1**, *in* x_zero_point:**T2**, *out* y:**T1**)|1+|**T1** = tensor(float)|
| | ||**T2** = tensor(int8), tensor(uint8)|
|DequantizeLinear|(*in* x:**T1**, *in* x_scale:**T2**, *in* x_zero_point:**T1**, *out* y:**T2**)|1+|**T1** = tensor(int8), tensor(uint8)|
| | ||**T2** = tensor(float)|
|EmbedLayerNormalization|(*in* input_ids:**T1**, *in* segment_ids:**T1**, *in* word_embedding:**T**, *in* position_embedding:**T**, *in* segment_embedding:**T**, *in* gamma:**T**, *in* beta:**T**, *in* mask:**T1**, *out* output:**T**, *out* mask_index:**T1**)|1+|**T** = tensor(float)|
|ExpandDims|(*in* X:**T**, *in* axis:**tensor(int32)**, *out* Y:**T**)|1+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
| | ||**axis** = tensor(int32)|
@@ -380,7 +380,7 @@
| | ||**T2** = tensor(int32), tensor(uint32)|
|Pad|(*in* data:**T**, *in* pads:**tensor(int64)**, *in* value:**T**, *out* output:**T**)|1+|**T** = tensor(float)|
|QuantizeLinear|(*in* x:**T1**, *in* y_scale:**T1**, *in* y_zero_point:**T2**, *out* y:**T2**)|1+|**T1** = tensor(float)|
| | ||**T2** = tensor(uint8)|
| | ||**T2** = tensor(int8), tensor(uint8)|
|Range|(*in* start:**T**, *in* limit:**T**, *in* delta:**T**, *out* Y:**T**)|1+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64)|
|SampleOp|(*in* X:**T**, *out* Y:**T**)|1+|**T** = tensor(float)|
|SkipLayerNormalization|(*in* input:**T**, *in* skip:**T**, *in* gamma:**T**, *in* beta:**T**, *in* bias:**T**, *out* output:**T**, *out* mean:**U**, *out* inv_std_var:**U**)|1+|**T** = tensor(double), tensor(float)|
@@ -422,16 +422,8 @@
| | ||**T** = tensor(double), tensor(float), tensor(float16)|
| | |[7, 9]|**I** = tensor(int64)|
| | ||**T** = tensor(double), tensor(float), tensor(float16)|
|BatchNormalization|(*in* X:**T**, *in* scale:**T**, *in* B:**T**, *in* mean:**T**, *in* var:**T**, *in* training_mode:**T1**, *out* Y:**T**, *out* output_mean:**T**, *out* output_var:**T**, *out* saved_mean:**T**, *out* saved_var:**T**) or (*in* X:**T**, *in* scale:**T**, *in* B:**T**, *in* mean:**T**, *in* var:**T**, *out* Y:**T**, *out* mean:**T**, *out* var:**T**, *out* saved_mean:**T**, *out* saved_var:**T**)|9+|**B** = tensor(double), tensor(float), tensor(float16)|
| | ||**X** = tensor(double), tensor(float), tensor(float16)|
| | ||**mean** = tensor(double), tensor(float), tensor(float16)|
| | ||**scale** = tensor(double), tensor(float), tensor(float16)|
| | ||**var** = tensor(double), tensor(float), tensor(float16)|
| | |[7, 8]|**B** = tensor(double), tensor(float), tensor(float16)|
| | ||**X** = tensor(double), tensor(float), tensor(float16)|
| | ||**mean** = tensor(double), tensor(float), tensor(float16)|
| | ||**scale** = tensor(double), tensor(float), tensor(float16)|
| | ||**var** = tensor(double), tensor(float), tensor(float16)|
|BatchNormalization|(*in* X:**T**, *in* scale:**T**, *in* B:**T**, *in* mean:**T**, *in* var:**T**, *in* training_mode:**T1**, *out* Y:**T**, *out* output_mean:**T**, *out* output_var:**T**, *out* saved_mean:**T**, *out* saved_var:**T**) or (*in* X:**T**, *in* scale:**T**, *in* B:**T**, *in* mean:**T**, *in* var:**T**, *out* Y:**T**, *out* mean:**T**, *out* var:**T**, *out* saved_mean:**T**, *out* saved_var:**T**)|9+|**T** = tensor(double), tensor(float), tensor(float16)|
| | |[7, 8]|**T** = tensor(double), tensor(float), tensor(float16)|
|Cast|(*in* input:**T1**, *out* output:**T2**)|9+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
| | ||**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
| | |[6, 8]|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
@@ -642,10 +634,14 @@
|Attention|(*in* input:**T**, *in* weight:**T**, *in* bias:**T**, *in* mask_index:**M**, *out* output:**T**)|1+|**T** = tensor(float), tensor(float16)|
|BiasGelu|(*in* A:**T**, *in* B:**T**, *out* C:**T**)|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|ConvTransposeWithDynamicPads|(*in* X:**T**, *in* W:**T**, *in* Pads:**tensor(int64)**, *in* B:**T**, *out* Y:**T**)|1+|**T** = tensor(float)|
|DequantizeLinear|(*in* x:**T1**, *in* x_scale:**T2**, *in* x_zero_point:**T1**, *out* y:**T2**)|1+|**T1** = tensor(int8), tensor(uint8)|
| | ||**T2** = tensor(float16)|
|EmbedLayerNormalization|(*in* input_ids:**T1**, *in* segment_ids:**T1**, *in* word_embedding:**T**, *in* position_embedding:**T**, *in* segment_embedding:**T**, *in* gamma:**T**, *in* beta:**T**, *in* mask:**T1**, *out* output:**T**, *out* mask_index:**T1**)|1+|**T** = tensor(float), tensor(float16)|
|FastGelu|(*in* X:**T**, *in* bias:**T**, *out* Y:**T**)|1+|**T** = tensor(float), tensor(float16)|
|Gelu|(*in* X:**T**, *out* Y:**T**)|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|Irfft|(*in* X:**T**, *out* Y:**T**)|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|QuantizeLinear|(*in* x:**T1**, *in* y_scale:**T1**, *in* y_zero_point:**T2**, *out* y:**T2**)|1+|**T1** = tensor(float16)|
| | ||**T2** = tensor(int8), tensor(uint8)|
|Rfft|(*in* X:**T**, *out* Y:**T**)|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|SkipLayerNormalization|(*in* input:**T**, *in* skip:**T**, *in* gamma:**T**, *in* beta:**T**, *in* bias:**T**, *out* output:**T**, *out* mean:**U**, *out* inv_std_var:**U**)|1+|**T** = tensor(float), tensor(float16)|
| |


@@ -19,11 +19,7 @@ namespace cuda {
T, \
kCudaExecutionProvider, \
KernelDefBuilder() \
.TypeConstraint("X", DataTypeImpl::GetTensorType<T>()) \
.TypeConstraint("scale", DataTypeImpl::GetTensorType<T>()) \
.TypeConstraint("B", DataTypeImpl::GetTensorType<T>()) \
.TypeConstraint("mean", DataTypeImpl::GetTensorType<T>()) \
.TypeConstraint("var", DataTypeImpl::GetTensorType<T>()), \
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
BatchNorm<T>); \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
BatchNormalization, \
@@ -32,11 +28,7 @@ namespace cuda {
T, \
kCudaExecutionProvider, \
KernelDefBuilder() \
.TypeConstraint("X", DataTypeImpl::GetTensorType<T>()) \
.TypeConstraint("scale", DataTypeImpl::GetTensorType<T>()) \
.TypeConstraint("B", DataTypeImpl::GetTensorType<T>()) \
.TypeConstraint("mean", DataTypeImpl::GetTensorType<T>()) \
.TypeConstraint("var", DataTypeImpl::GetTensorType<T>()), \
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
BatchNorm<T>);
template <typename T>
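For reference, the inference-mode computation these BatchNormalization kernels register can be sketched in NumPy (an illustrative reference for NCHW input, not the CUDA implementation; `epsilon` and the function name are assumptions):

```python
import numpy as np

def batch_norm_inference(x, scale, bias, mean, var, epsilon=1e-5):
    # Inference-mode batch norm, per channel over NCHW input:
    # y = scale * (x - mean) / sqrt(var + epsilon) + bias
    shape = (1, -1) + (1,) * (x.ndim - 2)  # broadcast C across N and spatial dims
    return (scale.reshape(shape) * (x - mean.reshape(shape))
            / np.sqrt(var.reshape(shape) + epsilon) + bias.reshape(shape))
```

All five tensor inputs share the single type constraint "T" used in the registration above, which is why one `TypeConstraint("T", ...)` entry suffices.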