type cast for ratio is not necessary for dropout (#3682)

Co-authored-by: Weixing Zhang <wezhan@microsoft.com>
This commit is contained in:
Weixing Zhang 2020-04-24 00:49:37 -07:00 committed by GitHub
parent f4a04c04e1
commit c929963d74
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -37,7 +37,6 @@ REGISTER_KERNEL_TYPED(Dropout, kOnnxDomain, 12, double, double, 1)
template <typename T1, typename T2>
Status Dropout<T1, T2>::ComputeInternal(OpKernelContext* context) const {
typedef typename ToCudaType<T1>::MappedType CudaT;
typedef typename ToCudaType<T2>::MappedType CudaT2;
//Get X_data
const Tensor* X = context->Input<Tensor>(0);
@ -68,7 +67,7 @@ Status Dropout<T1, T2>::ComputeInternal(OpKernelContext* context) const {
"T2 must be float16 or float or double");
if (ratio) {
ratio_data = static_cast<float>(*reinterpret_cast<const CudaT2*>(ratio->template Data<T2>()));
ratio_data = static_cast<float>(*(ratio->template Data<T2>()));
} else {
ratio_data = default_ratio_;
}
@ -112,7 +111,7 @@ Status DropoutGrad<T1, T2>::ComputeInternal(OpKernelContext* context) const {
"T2 must be float16 or float or double");
if (ratio) {
ratio_data = static_cast<float>(*reinterpret_cast<const T2*>(ratio->template Data<T2>()));
ratio_data = static_cast<float>(*(ratio->template Data<T2>()));
} else {
ratio_data = default_ratio_;
}