mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-29 23:06:41 +00:00
Enable tests for EyeLike and enable datatypes present in tests (#424)
* Enable tests for EyeLike and enable datatypes present in tests * fix failure
This commit is contained in:
parent
9f0298261d
commit
60cdb79204
2 changed files with 18 additions and 13 deletions
|
|
@ -15,15 +15,19 @@ ONNX_CPU_OPERATOR_KERNEL(
|
|||
KernelDefBuilder().TypeConstraint("T1",
|
||||
std::vector<MLDataType>{
|
||||
DataTypeImpl::GetTensorType<float>(),
|
||||
DataTypeImpl::GetTensorType<int64_t>(),
|
||||
DataTypeImpl::GetTensorType<double>(),
|
||||
DataTypeImpl::GetTensorType<uint64_t>(),
|
||||
DataTypeImpl::GetTensorType<int64_t>(),
|
||||
DataTypeImpl::GetTensorType<int32_t>()
|
||||
})
|
||||
.TypeConstraint("T2",
|
||||
std::vector<MLDataType>{
|
||||
DataTypeImpl::GetTensorType<float>(),
|
||||
DataTypeImpl::GetTensorType<uint64_t>(),
|
||||
DataTypeImpl::GetTensorType<int64_t>(),
|
||||
}),
|
||||
.TypeConstraint("T2",
|
||||
std::vector<MLDataType>{
|
||||
DataTypeImpl::GetTensorType<float>(),
|
||||
DataTypeImpl::GetTensorType<double>(),
|
||||
DataTypeImpl::GetTensorType<uint64_t>(),
|
||||
DataTypeImpl::GetTensorType<int64_t>(),
|
||||
DataTypeImpl::GetTensorType<int32_t>()
|
||||
}),
|
||||
EyeLike);
|
||||
|
||||
Status EyeLike::Compute(OpKernelContext* context) const {
|
||||
|
|
@ -34,10 +38,14 @@ Status EyeLike::Compute(OpKernelContext* context) const {
|
|||
switch (output_tensor_dtype) {
|
||||
case onnx::TensorProto_DataType_FLOAT:
|
||||
return ComputeImpl<float>(context, T1);
|
||||
case onnx::TensorProto_DataType_INT64:
|
||||
return ComputeImpl<int64_t>(context, T1);
|
||||
case onnx::TensorProto_DataType_DOUBLE:
|
||||
return ComputeImpl<double>(context, T1);
|
||||
case onnx::TensorProto_DataType_INT32:
|
||||
return ComputeImpl<int32_t>(context, T1);
|
||||
case onnx::TensorProto_DataType_UINT64:
|
||||
return ComputeImpl<uint64_t>(context, T1);
|
||||
case onnx::TensorProto_DataType_INT64:
|
||||
return ComputeImpl<int64_t>(context, T1);
|
||||
default:
|
||||
ORT_THROW("Unsupported 'dtype' value: ", output_tensor_dtype);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -274,10 +274,7 @@ int real_main(int argc, char* argv[]) {
|
|||
{"PoissonNLLLLoss_no_reduce", "disable reason"},
|
||||
{"Softsign", "disable reason"},
|
||||
{"convtranspose_1d", "disable reason"},
|
||||
{"convtranspose_3d", "disable reason"},
|
||||
{"eyelike_populate_off_main_diagonal", "disable reason"},
|
||||
{"eyelike_with_dtype", "disable reason"},
|
||||
{"eyelike_without_dtype", "disable reason"},
|
||||
{"convtranspose_3d", "disable reason"},
|
||||
{"flatten_axis0", "disable reason"},
|
||||
{"flatten_axis1", "disable reason"},
|
||||
{"flatten_axis2", "disable reason"},
|
||||
|
|
|
|||
Loading…
Reference in a new issue