mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-19 21:32:23 +00:00
support ScatterND(18) and ScatterElement(18) (#14224)
This commit is contained in:
parent
d2c3d8eb38
commit
5d6a049141
6 changed files with 218 additions and 15 deletions
|
|
@ -304,10 +304,12 @@ Do not modify directly.*
|
|||
|||[9, 10]|**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|||8|**I** = tensor(int64)<br/> **V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|Scatter|*in* data:**T**<br> *in* indices:**Tind**<br> *in* updates:**T**<br> *out* output:**T**|[9, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
|
||||
|ScatterElements|*in* data:**T**<br> *in* indices:**Tind**<br> *in* updates:**T**<br> *out* output:**T**|16+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
|
||||
|ScatterElements|*in* data:**T**<br> *in* indices:**Tind**<br> *in* updates:**T**<br> *out* output:**T**|18+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
|
||||
|||[16, 17]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
|
||||
|||[13, 15]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
|
||||
|||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
|
||||
|ScatterND|*in* data:**T**<br> *in* indices:**tensor(int64)**<br> *in* updates:**T**<br> *out* output:**T**|16+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|ScatterND|*in* data:**T**<br> *in* indices:**tensor(int64)**<br> *in* updates:**T**<br> *out* output:**T**|18+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|||[16, 17]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|||[13, 15]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|Selu|*in* X:**T**<br> *out* Y:**T**|6+|**T** = tensor(float)|
|
||||
|
|
|
|||
|
|
@ -765,8 +765,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, If
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, float, RoiAlign);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, double, RoiAlign);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, float, GridSample);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, ScatterElements);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, ScatterND);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, 17, ScatterElements);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, 17, ScatterND);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, string, Where);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, float, Where);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, double, Where);
|
||||
|
|
@ -830,6 +830,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, float, ReduceSumSquare);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, double, ReduceSumSquare);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int32_t, ReduceSumSquare);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, ScatterND);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, ScatterElements);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, Split);
|
||||
#if !defined(DISABLE_OPTIONAL_TYPE)
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, OptionalHasElement);
|
||||
|
|
@ -2018,8 +2020,8 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
|
|||
RoiAlign)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, float,
|
||||
GridSample)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, ScatterElements)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, ScatterND)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, 17, ScatterElements)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, 17, ScatterND)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, string, Where)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, float, Where)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, double, Where)>,
|
||||
|
|
@ -2128,11 +2130,13 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
|
|||
ReduceSumSquare)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, double,
|
||||
ReduceSumSquare)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, Split)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, ScatterND)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, ScatterElements)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, Split)>,
|
||||
#if !defined(DISABLE_OPTIONAL_TYPE)
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, OptionalHasElement)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, OptionalGetElement)>,
|
||||
#endif
|
||||
#endif
|
||||
};
|
||||
|
||||
for (auto& function_table_entry : function_table) {
|
||||
|
|
|
|||
|
|
@ -86,9 +86,20 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
|
|||
.TypeConstraint("Tind", BuildKernelDefConstraints<int32_t, int64_t>()),
|
||||
Scatter<EnabledScatterElementsDataTypes>);
|
||||
|
||||
ONNX_CPU_OPERATOR_KERNEL(
|
||||
ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
|
||||
ScatterElements,
|
||||
16,
|
||||
17,
|
||||
KernelDefBuilder()
|
||||
.MayInplace(0, 0)
|
||||
.TypeConstraint("T",
|
||||
BuildKernelDefConstraintsFromTypeList<EnabledScatterElementsDataTypes>())
|
||||
.TypeConstraint("Tind", BuildKernelDefConstraints<int32_t, int64_t>()),
|
||||
Scatter<EnabledScatterElementsDataTypes>);
|
||||
|
||||
ONNX_CPU_OPERATOR_KERNEL(
|
||||
ScatterElements,
|
||||
18,
|
||||
KernelDefBuilder()
|
||||
.MayInplace(0, 0)
|
||||
.TypeConstraint("T",
|
||||
|
|
@ -166,6 +177,76 @@ struct Func_Mul<BFloat16> {
|
|||
}
|
||||
};
|
||||
|
||||
template <class T>
|
||||
struct Func_Min {
|
||||
void operator()(T* a, const T* b) const {
|
||||
(*a) = (*a) < (*b) ? (*a) : (*b);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Func_Min<bool> {
|
||||
void operator()(bool*, const bool*) const {
|
||||
ORT_NOT_IMPLEMENTED("CPU execution provider: bool data type is not supported with ScatterElements opset 18 when reduction is 'min'.");
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Func_Min<std::string> {
|
||||
void operator()(std::string*, const std::string*) const {
|
||||
ORT_NOT_IMPLEMENTED("CPU execution provider: string data type is not supported with ScatterElements opset 18 when reduction is 'min'.");
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Func_Min<MLFloat16> {
|
||||
void operator()(MLFloat16*, const MLFloat16*) const {
|
||||
ORT_NOT_IMPLEMENTED("CPU execution provider: MLFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'min'.");
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Func_Min<BFloat16> {
|
||||
void operator()(BFloat16*, const BFloat16*) const {
|
||||
ORT_NOT_IMPLEMENTED("CPU execution provider: BFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'min'.");
|
||||
}
|
||||
};
|
||||
|
||||
template <class T>
|
||||
struct Func_Max {
|
||||
void operator()(T* a, const T* b) const {
|
||||
(*a) = (*a) > (*b) ? (*a) : (*b);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Func_Max<bool> {
|
||||
void operator()(bool*, const bool*) const {
|
||||
ORT_NOT_IMPLEMENTED("CPU execution provider: bool data type is not supported with ScatterElements opset 18 when reduction is 'max'.");
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Func_Max<std::string> {
|
||||
void operator()(std::string*, const std::string*) const {
|
||||
ORT_NOT_IMPLEMENTED("CPU execution provider: string data type is not supported with ScatterElements opset 18 when reduction is 'max'.");
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Func_Max<MLFloat16> {
|
||||
void operator()(MLFloat16*, const MLFloat16*) const {
|
||||
ORT_NOT_IMPLEMENTED("CPU execution provider: MLFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'max'.");
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Func_Max<BFloat16> {
|
||||
void operator()(BFloat16*, const BFloat16*) const {
|
||||
ORT_NOT_IMPLEMENTED("CPU execution provider: BFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'max'.");
|
||||
}
|
||||
};
|
||||
|
||||
template <class TIndex>
|
||||
Status GetIndices(
|
||||
const Tensor& data_input, const Tensor& indices_input, int64_t axis,
|
||||
|
|
@ -317,7 +398,13 @@ struct ScatterDataDispatchTarget {
|
|||
else if(reduction == "mul")
|
||||
return ScatterData<TData>(
|
||||
Func_Mul<TData>(), data_input, indices_data, updates_input, axis, data_output);
|
||||
else // if (reduction == "none")
|
||||
else if (reduction == "min")
|
||||
return ScatterData<TData>(
|
||||
Func_Min<TData>(), data_input, indices_data, updates_input, axis, data_output);
|
||||
else if (reduction == "max")
|
||||
return ScatterData<TData>(
|
||||
Func_Max<TData>(), data_input, indices_data, updates_input, axis, data_output);
|
||||
else // if (reduction == "none")
|
||||
return ScatterData<TData>(
|
||||
Func_Assignment<TData>(), data_input, indices_data, updates_input, axis, data_output);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -38,9 +38,18 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
|
|||
BuildKernelDefConstraintsFromTypeList<EnabledScatterNDDataTypes>()),
|
||||
ScatterND);
|
||||
|
||||
ONNX_CPU_OPERATOR_KERNEL(
|
||||
ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
|
||||
ScatterND,
|
||||
16,
|
||||
17,
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T",
|
||||
BuildKernelDefConstraintsFromTypeList<EnabledScatterNDDataTypes>()),
|
||||
ScatterND);
|
||||
|
||||
ONNX_CPU_OPERATOR_KERNEL(
|
||||
ScatterND,
|
||||
18,
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T",
|
||||
BuildKernelDefConstraintsFromTypeList<EnabledScatterNDDataTypes>()),
|
||||
|
|
@ -263,6 +272,84 @@ struct Func_Mul_ND<BFloat16> {
|
|||
}
|
||||
};
|
||||
|
||||
template <class T>
|
||||
struct Func_Min_ND {
|
||||
void operator()(T* a, const T* b, uint64_t element_to_copy) const {
|
||||
while (element_to_copy-- > 0) {
|
||||
(*a) = (*a) < (*b) ? (*a) : (*b);
|
||||
a++;
|
||||
b++;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Func_Min_ND<bool> {
|
||||
void operator()(bool*, const bool*, uint64_t) const {
|
||||
ORT_NOT_IMPLEMENTED("CPU execution provider: bool data type is not supported with ScatterND opset 18 when reduction is 'min'.");
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Func_Min_ND<std::string> {
|
||||
void operator()(std::string*, const std::string*, uint64_t) const {
|
||||
ORT_NOT_IMPLEMENTED("CPU execution provider: string data type is not supported with ScatterND opset 18 when reduction is 'min'.");
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Func_Min_ND<MLFloat16> {
|
||||
void operator()(MLFloat16*, const MLFloat16*, uint64_t) const {
|
||||
ORT_NOT_IMPLEMENTED("CPU execution provider: MLFloat16 data type is not supported with ScatterND opset 18 when reduction is 'min'.");
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Func_Min_ND<BFloat16> {
|
||||
void operator()(BFloat16*, const BFloat16*, uint64_t) const {
|
||||
ORT_NOT_IMPLEMENTED("CPU execution provider: BFloat16 data type is not supported with ScatterND opset 18 when reduction is 'min'.");
|
||||
}
|
||||
};
|
||||
|
||||
template <class T>
|
||||
struct Func_Max_ND {
|
||||
void operator()(T* a, const T* b, uint64_t element_to_copy) const {
|
||||
while (element_to_copy-- > 0) {
|
||||
(*a) = (*a) > (*b) ? (*a) : (*b);
|
||||
a++;
|
||||
b++;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Func_Max_ND<bool> {
|
||||
void operator()(bool*, const bool*, uint64_t) const {
|
||||
ORT_NOT_IMPLEMENTED("CPU execution provider: bool data type is not supported with ScatterND opset 18 when reduction is 'max'.");
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Func_Max_ND<std::string> {
|
||||
void operator()(std::string*, const std::string*, uint64_t) const {
|
||||
ORT_NOT_IMPLEMENTED("CPU execution provider: string data type is not supported with ScatterND opset 18 when reduction is 'max'.");
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Func_Max_ND<MLFloat16> {
|
||||
void operator()(MLFloat16*, const MLFloat16*, uint64_t) const {
|
||||
ORT_NOT_IMPLEMENTED("CPU execution provider: MLFloat16 data type is not supported with ScatterND opset 18 when reduction is 'max'.");
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Func_Max_ND<BFloat16> {
|
||||
void operator()(BFloat16*, const BFloat16*, uint64_t) const {
|
||||
ORT_NOT_IMPLEMENTED("CPU execution provider: BFloat16 data type is not supported with ScatterND opset 18 when reduction is 'max'.");
|
||||
}
|
||||
};
|
||||
|
||||
template <typename TData>
|
||||
struct ScatterNDDispatchTarget {
|
||||
Status operator()(OpKernelContext* context, concurrency::ThreadPool* tp, ScatterND::Reduction reduction) const {
|
||||
|
|
@ -285,6 +372,20 @@ struct ScatterNDDispatchTarget {
|
|||
prepare.input_base + i * prepare.element_to_copy,
|
||||
prepare.element_to_copy);
|
||||
} break;
|
||||
case ScatterND::Reduction::Min: {
|
||||
auto func = Func_Min_ND<TData>();
|
||||
func(
|
||||
prepare.output_base + prepare.element_offsets[onnxruntime::narrow<size_t>(i)],
|
||||
prepare.input_base + i * prepare.element_to_copy,
|
||||
prepare.element_to_copy);
|
||||
} break;
|
||||
case ScatterND::Reduction::Max: {
|
||||
auto func = Func_Max_ND<TData>();
|
||||
func(
|
||||
prepare.output_base + prepare.element_offsets[onnxruntime::narrow<size_t>(i)],
|
||||
prepare.input_base + i * prepare.element_to_copy,
|
||||
prepare.element_to_copy);
|
||||
} break;
|
||||
default:
|
||||
case ScatterND::Reduction::None: {
|
||||
auto func = Func_Copy_ND<TData>();
|
||||
|
|
|
|||
|
|
@ -18,7 +18,9 @@ class ScatterND final : public OpKernel {
|
|||
enum class Reduction : int {
|
||||
None = 0,
|
||||
Add,
|
||||
Mul
|
||||
Mul,
|
||||
Min,
|
||||
Max,
|
||||
};
|
||||
|
||||
explicit ScatterND(const OpKernelInfo& info) : OpKernel(info) {
|
||||
|
|
@ -30,6 +32,10 @@ class ScatterND final : public OpKernel {
|
|||
reduction_ = Reduction::Add;
|
||||
else if (reduction == "mul")
|
||||
reduction_ = Reduction::Mul;
|
||||
else if (reduction == "min")
|
||||
reduction_ = Reduction::Min;
|
||||
else if (reduction == "max")
|
||||
reduction_ = Reduction::Max;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -104,8 +104,6 @@
|
|||
"^test_identity_opt",
|
||||
// Following tests are for opset 16 ops and are not yet implemented in ORT
|
||||
"^test_roialign_aligned_*",
|
||||
"^test_scatternd_*",
|
||||
"^test_scatter_elements_with_duplicate_indices",
|
||||
//GPU failures
|
||||
"^test_batchnorm_epsilon_training_mode_cuda",
|
||||
"^test_batchnorm_example_training_mode_cuda",
|
||||
|
|
@ -127,7 +125,6 @@
|
|||
"^test_constant_pad_cpu",
|
||||
"^test_edge_pad_cpu",
|
||||
"^test_reflect_pad_cpu",
|
||||
"^test_scatter_elements_*",
|
||||
"^test_softplus_example_expanded_cpu",
|
||||
"^test_softplus_expanded_cpu",
|
||||
"^test_split_*",
|
||||
|
|
@ -312,6 +309,9 @@
|
|||
"^test_reduce_sum_do_not_keepdims_example",
|
||||
"^test_reduce_sum_do_not_keepdims_random",
|
||||
"^test_scatter_elements_with_negative_indices",
|
||||
"^test_scatter_elements_with_duplicate_indices*",
|
||||
"^test_scatternd_add*",
|
||||
"^test_scatternd_multiply*",
|
||||
"^test_squeeze",
|
||||
"^test_squeeze_negative_axes"
|
||||
],
|
||||
|
|
@ -332,6 +332,9 @@
|
|||
"^test_reduce_sum_default_axes_keepdims*",
|
||||
"^test_reduce_sum_empty_axes_input_noop*",
|
||||
"^test_scatter_elements_with_negative_indices_cpu",
|
||||
"^test_scatter_elements_with_duplicate_indices*",
|
||||
"^test_scatternd_add*",
|
||||
"^test_scatternd_multiply*",
|
||||
"^test_squeeze_*", // Does not support axes as input
|
||||
"^test_pow_types_float", // Runs disabled pow tests from the "current_failing_tests" list at the top
|
||||
"^test_loop11",
|
||||
|
|
|
|||
Loading…
Reference in a new issue