From cccd61e3bc6f15120932a032a48bbdcc569938ea Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Wed, 23 Jun 2021 14:53:06 +1000 Subject: [PATCH] Add int64 as a required type to ConstantOfShape as it's used by the pytorch converter for Pad. (#8128) It's also used pointlessly for torch.tensor.repeat (although that usage should always be able to be constant folded). --- .../core/providers/cpu/generator/constant_of_shape.cc | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/cpu/generator/constant_of_shape.cc b/onnxruntime/core/providers/cpu/generator/constant_of_shape.cc index f7e7033471..bba62b290e 100644 --- a/onnxruntime/core/providers/cpu/generator/constant_of_shape.cc +++ b/onnxruntime/core/providers/cpu/generator/constant_of_shape.cc @@ -10,7 +10,13 @@ namespace op_kernel_type_control { ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPE_LIST_ALL_OPSETS( kCpuExecutionProvider, kOnnxDomain, ConstantOfShape, Output, 0, ConstantOfShapeDefaultOutputTypes); -} + +// pytorch converter uses ConstantOfShape with int64 to create Pad input +// https://github.com/pytorch/pytorch/blob/044b519a80459f6787f6723c1c091a18b153d184/torch/onnx/symbolic_opset11.py#L449 +ORT_SPECIFY_OP_KERNEL_ARG_REQUIRED_TYPES_ALL_OPSETS( + kCpuExecutionProvider, kOnnxDomain, ConstantOfShape, Output, 0, + int64_t); +} // namespace op_kernel_type_control namespace {