int64 support for 'where' op (#1666)

This commit is contained in:
Negin Raoof 2019-08-22 18:14:18 -05:00 committed by Changming Sun
parent c9a4fe2b7b
commit addf32fa2a
2 changed files with 3 additions and 1 deletions

View file

@ -260,6 +260,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, string, Where);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, float, Where);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int32_t, Where);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int64_t, Where);
// Opset 10
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, StringNormalizer);
@ -546,6 +547,7 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, string, Where)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, float, Where)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int32_t, Where)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int64_t, Where)>,
// Opset 10
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, StringNormalizer)>,

View file

@ -29,7 +29,7 @@ namespace onnxruntime {
//WHERE_TYPED_KERNEL(int8_t)
//WHERE_TYPED_KERNEL(int16_t)
WHERE_TYPED_KERNEL(int32_t)
//WHERE_TYPED_KERNEL(int64_t)
WHERE_TYPED_KERNEL(int64_t)
//WHERE_TYPED_KERNEL(MLFloat16)
//WHERE_TYPED_KERNEL(BFloat16)
WHERE_TYPED_KERNEL(float)