From 60cdb79204abbf786149faecaf9660259e6b5907 Mon Sep 17 00:00:00 2001 From: Ashwini Khade Date: Fri, 1 Feb 2019 00:21:08 -0800 Subject: [PATCH] Enable tests for EyeLike and enable datatypes present in tests (#424) * Enable tests for EyeLike and enable datatypes present in tests * fix failure --- .../core/providers/cpu/tensor/eye_like.cc | 26 ++++++++++++------- onnxruntime/test/onnx/main.cc | 5 +--- 2 files changed, 18 insertions(+), 13 deletions(-) diff --git a/onnxruntime/core/providers/cpu/tensor/eye_like.cc b/onnxruntime/core/providers/cpu/tensor/eye_like.cc index 725f9a0999..40bdd7d41f 100644 --- a/onnxruntime/core/providers/cpu/tensor/eye_like.cc +++ b/onnxruntime/core/providers/cpu/tensor/eye_like.cc @@ -15,15 +15,19 @@ ONNX_CPU_OPERATOR_KERNEL( KernelDefBuilder().TypeConstraint("T1", std::vector{ DataTypeImpl::GetTensorType(), - DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType() }) - .TypeConstraint("T2", - std::vector{ - DataTypeImpl::GetTensorType(), - DataTypeImpl::GetTensorType(), - DataTypeImpl::GetTensorType(), - }), + .TypeConstraint("T2", + std::vector{ + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType() + }), EyeLike); Status EyeLike::Compute(OpKernelContext* context) const { @@ -34,10 +38,14 @@ Status EyeLike::Compute(OpKernelContext* context) const { switch (output_tensor_dtype) { case onnx::TensorProto_DataType_FLOAT: return ComputeImpl(context, T1); - case onnx::TensorProto_DataType_INT64: - return ComputeImpl(context, T1); + case onnx::TensorProto_DataType_DOUBLE: + return ComputeImpl(context, T1); + case onnx::TensorProto_DataType_INT32: + return ComputeImpl(context, T1); case onnx::TensorProto_DataType_UINT64: return ComputeImpl(context, T1); + case onnx::TensorProto_DataType_INT64: + return ComputeImpl(context, T1); default: ORT_THROW("Unsupported 'dtype' value: ", output_tensor_dtype); } diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc index 03f032d6c5..35f5bb07c9 100644 --- a/onnxruntime/test/onnx/main.cc +++ b/onnxruntime/test/onnx/main.cc @@ -274,10 +274,7 @@ int real_main(int argc, char* argv[]) { {"PoissonNLLLLoss_no_reduce", "disable reason"}, {"Softsign", "disable reason"}, {"convtranspose_1d", "disable reason"}, - {"convtranspose_3d", "disable reason"}, - {"eyelike_populate_off_main_diagonal", "disable reason"}, - {"eyelike_with_dtype", "disable reason"}, - {"eyelike_without_dtype", "disable reason"}, + {"convtranspose_3d", "disable reason"}, {"flatten_axis0", "disable reason"}, {"flatten_axis1", "disable reason"}, {"flatten_axis2", "disable reason"},