mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-31 23:27:43 +00:00
type cast for ratio is not necessary for dropout (#3682)
Co-authored-by: Weixing Zhang <wezhan@microsoft.com>
This commit is contained in:
parent
f4a04c04e1
commit
c929963d74
1 changed files with 2 additions and 3 deletions
|
|
@ -37,7 +37,6 @@ REGISTER_KERNEL_TYPED(Dropout, kOnnxDomain, 12, double, double, 1)
|
|||
template <typename T1, typename T2>
|
||||
Status Dropout<T1, T2>::ComputeInternal(OpKernelContext* context) const {
|
||||
typedef typename ToCudaType<T1>::MappedType CudaT;
|
||||
typedef typename ToCudaType<T2>::MappedType CudaT2;
|
||||
|
||||
//Get X_data
|
||||
const Tensor* X = context->Input<Tensor>(0);
|
||||
|
|
@ -68,7 +67,7 @@ Status Dropout<T1, T2>::ComputeInternal(OpKernelContext* context) const {
|
|||
"T2 must be float16 or float or double");
|
||||
|
||||
if (ratio) {
|
||||
ratio_data = static_cast<float>(*reinterpret_cast<const CudaT2*>(ratio->template Data<T2>()));
|
||||
ratio_data = static_cast<float>(*(ratio->template Data<T2>()));
|
||||
} else {
|
||||
ratio_data = default_ratio_;
|
||||
}
|
||||
|
|
@ -112,7 +111,7 @@ Status DropoutGrad<T1, T2>::ComputeInternal(OpKernelContext* context) const {
|
|||
"T2 must be float16 or float or double");
|
||||
|
||||
if (ratio) {
|
||||
ratio_data = static_cast<float>(*reinterpret_cast<const T2*>(ratio->template Data<T2>()));
|
||||
ratio_data = static_cast<float>(*(ratio->template Data<T2>()));
|
||||
} else {
|
||||
ratio_data = default_ratio_;
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue