From a3c95374c33463b85db84e388488baf52bd2d9df Mon Sep 17 00:00:00 2001 From: Hariharan Seshadri Date: Tue, 18 Aug 2020 02:09:30 -0700 Subject: [PATCH] Support asymmetric paddings in CUDA Conv kernel (#4627) --- .../core/providers/cpu/nn/conv_attributes.h | 154 ++++++++++++++++- .../core/providers/cpu/tensor/slice.cc | 105 +++++------- onnxruntime/core/providers/cpu/tensor/slice.h | 76 +++++---- .../providers/cuda/cuda_execution_provider.cc | 8 +- onnxruntime/core/providers/cuda/nn/conv.cc | 92 +++++++++- onnxruntime/core/providers/cuda/nn/conv.h | 19 ++- .../core/providers/cuda/tensor/slice.cc | 159 ++++++++++++------ .../core/providers/cuda/tensor/slice.h | 14 +- .../test/providers/cpu/nn/conv_op_test.cc | 28 ++- .../training_ops/cpu/tensor/slice_grad.cc | 16 +- .../training_ops/cuda/tensor/slice_grad.cc | 2 +- .../training_ops/cuda/tensor/slice_grad.h | 2 +- 12 files changed, 485 insertions(+), 190 deletions(-) diff --git a/onnxruntime/core/providers/cpu/nn/conv_attributes.h b/onnxruntime/core/providers/cpu/nn/conv_attributes.h index 24d19f8cd2..19b6b940f5 100644 --- a/onnxruntime/core/providers/cpu/nn/conv_attributes.h +++ b/onnxruntime/core/providers/cpu/nn/conv_attributes.h @@ -112,13 +112,26 @@ struct ConvAttributes { std::vector& output_shape, bool force_symmetric_auto_padding = false) const { size_t rank = input_shape.NumDimensions(); + + // Make sure all "metadata" containers have the right number of elements + if (rank > strides_p.size()) + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Not enough elements in strides. Expected: ", rank, " Got: ", strides_p.size()); + + if (rank > kernel_shape.size()) + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Not enough elements in kernel shape. Expected: ", rank, " Got: ", kernel_shape.size()); + + if (rank > dilations_p.size()) + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Not enough elements in dilations. Expected: ", rank, " Got: ", dilations_p.size()); + + if ((2 * rank) > pads_p.size()) + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Not enough elements in pads. Expected: ", (2 * rank), " Got: ", pads_p.size()); + for (size_t dim = 0; dim < rank; ++dim) { - if (dim >= strides_p.size() || dim >= kernel_shape.size() || - dim >= dilations_p.size() || dim >= pads_p.size() || - rank + dim >= pads_p.size()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Out of bound access to array"); - } - int64_t dim_size = 0; + int64_t output_dim_size = 0; ORT_RETURN_IF_ERROR(ComputePadAndOutputShape(input_shape[dim], strides_p[dim], kernel_shape[dim], @@ -126,12 +139,135 @@ struct ConvAttributes { auto_pad, pads_p.at(dim), pads_p.at(input_shape.NumDimensions() + dim), - dim_size, + output_dim_size, force_symmetric_auto_padding)); - if (dim_size <= 0) { + if (output_dim_size <= 0) { return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Invalid input shape: " + input_shape.ToString()); } - output_shape.push_back(dim_size); + output_shape.push_back(output_dim_size); + } + return Status::OK(); + } + + // Use this method when pads are to be made symmetrical (if they are asymmetric) + // and to collect metadata regarding the portion of the output (with "adjusted" pads) + // to be sliced off to make the output correspond to the "actual" asymmetric paddings + Status InferOutputShapeWithAdjustedPads(const TensorShape& input_shape, + const std::vector& kernel_shape, + const std::vector& strides_p, + const std::vector& dilations_p, + std::vector& pads_p, + std::vector& output_shape, + std::vector& output_shape_with_revised_pads, + bool& post_slicing_needed, + std::vector& slice_starts, + std::vector& slice_ends, + std::vector& slice_axes) const { + size_t rank = input_shape.NumDimensions(); + // Make sure all "metadata" containers have the right number of elements + if (rank > strides_p.size()) + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Not enough elements in strides. Expected: ", rank, " Got: ", strides_p.size()); + + if (rank > kernel_shape.size()) + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Not enough elements in kernel shape. Expected: ", rank, " Got: ", kernel_shape.size()); + + if (rank > dilations_p.size()) + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Not enough elements in dilations. Expected: ", rank, " Got: ", dilations_p.size()); + + if ((2 * rank) > pads_p.size()) + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Not enough elements in pads. Expected: ", (2 * rank), " Got: ", pads_p.size()); + + for (size_t dim = 0; dim < rank; ++dim) { + int64_t output_dim_size = 0; + ORT_RETURN_IF_ERROR(ComputePadAndOutputShape(input_shape[dim], + strides_p[dim], + kernel_shape[dim], + dilations_p[dim], + auto_pad, + pads_p[dim], + pads_p[rank + dim], + output_dim_size)); + if (output_dim_size <= 0) { + return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Invalid input shape: " + input_shape.ToString()); + } + + // This is the "actual" output shape of the Conv op (i.e.) with given pads as is + output_shape.push_back(output_dim_size); + + // Deal with asymmetric pads if any and adjust them to be symmetric + // Along the way - note down how many values need to be sliced out from the start and end + // of each spatial dimension while we over-pad to get symmetric pads + + if (pads_p[dim] == pads_p[rank + dim]) { + // symmetric padding - No operation as such + // Make note of the dim size of the output (to be used if there are other symmetrically padded dims) + output_shape_with_revised_pads.push_back(output_dim_size); + } else { + // asymmetric padding + + int64_t& pad_head = pads_p[dim]; + int64_t& pad_tail = pads_p[rank + dim]; + int64_t stride = strides_p[dim]; + + bool head_overpadded = false; + + if (pad_head < pad_tail) { + int64_t excess_output_head = 0; + + // pad_head is under-padded, so "adjust" it by adding more padding + while (pad_head < pad_tail) { + // keep over-padding in multiples of 'strides' so that + // the filter slides over correctly + + pad_head += stride; + excess_output_head += 1; // each multiple of stride contributes to 1 excess output value + } + + post_slicing_needed = true; + slice_axes.push_back(dim + 2); + slice_starts.push_back(excess_output_head); + slice_ends.push_back(excess_output_head + output_dim_size); // we may modify this below + output_shape_with_revised_pads.push_back(excess_output_head + output_dim_size); // we may modify this below + head_overpadded = true; + } + + // we may enter this section even if the head was initially under-padded, + // because we had to over-pad by multiples of 'stride', now `pad_head` might be > `pad_tail` + if (pad_tail < pad_head) { + pad_tail = pad_head; + auto revised_dim_size = ComputeOutputShape(input_shape[dim], strides_p[dim], + kernel_shape[dim], dilations_p[dim], + pad_head, pad_tail); + + if (head_overpadded) { + // Head has already been over-padded + // Additional tail pads need not result in additional output + // Ensure that the size has changed - otherwise no operation needed. + if (revised_dim_size != + output_shape_with_revised_pads[output_shape_with_revised_pads.size() - 1]) { + output_shape_with_revised_pads[output_shape_with_revised_pads.size() - 1] = revised_dim_size; + } + } else { + // Additional tail pads need not result in additional output + // Ensure that the size has changed - otherwise no operation needed. + if (revised_dim_size != output_dim_size) { + // Head has not been over-padded. Only tail pads need to be modified. + post_slicing_needed = true; + + slice_axes.push_back(dim + 2); + slice_starts.push_back(0); + slice_ends.push_back(output_dim_size - revised_dim_size); + } + + // make note of the shape of this spatial dimension + output_shape_with_revised_pads.push_back(revised_dim_size); + } + } + } } return Status::OK(); } diff --git a/onnxruntime/core/providers/cpu/tensor/slice.cc b/onnxruntime/core/providers/cpu/tensor/slice.cc index 582b387868..f60bcaf96f 100644 --- a/onnxruntime/core/providers/cpu/tensor/slice.cc +++ b/onnxruntime/core/providers/cpu/tensor/slice.cc @@ -89,22 +89,18 @@ static void FlattenOutputDims(const std::vector& input_dimensions, Status SliceBase::PrepareForCompute(const std::vector& raw_starts, const std::vector& raw_ends, const std::vector& raw_axes, - const std::vector& input_dimensions, - std::vector& starts, - std::vector& steps, - std::vector& output_dims, - std::vector*& flattened_output_dims) const { + SliceOp::PrepareForComputeMetadata& compute_metadata) { // Initialize axes to the provided axes attribute or to the default sequence std::vector axes(raw_axes); if (axes.empty()) { //axes are omitted, they are set to[0, ..., ndim - 1] - axes.resize(starts.size()); + axes.resize(compute_metadata.starts_.size()); std::iota(axes.begin(), axes.end(), 0); } // Iterate through the provided axes and override the start/end ranges std::unordered_set unique_axes; - const auto& dimension_count = input_dimensions.size(); + const auto& dimension_count = compute_metadata.input_dimensions_.size(); for (size_t axis_index = 0, axes_count = axes.size(); axis_index < axes_count; ++axis_index) { auto axis = HandleNegativeAxis(axes[axis_index], dimension_count); // handle negative and enforce axis is valid if (axis >= static_cast(dimension_count) || axis < 0) @@ -116,23 +112,24 @@ Status SliceBase::PrepareForCompute(const std::vector& raw_starts, // process start auto start = raw_starts[axis_index]; if (start < 0) - start += input_dimensions[axis]; - starts[axis] = clamp(start, int64_t{0}, input_dimensions[axis]); + start += compute_metadata.input_dimensions_[axis]; + compute_metadata.starts_[axis] = clamp(start, int64_t{0}, compute_metadata.input_dimensions_[axis]); // process end auto end = raw_ends[axis_index]; if (end < 0) - end += input_dimensions[axis]; + end += compute_metadata.input_dimensions_[axis]; // find output dim value for this axis - auto temp = clamp(end, int64_t{0}, input_dimensions[axis]) - starts[axis]; + auto temp = clamp(end, int64_t{0}, compute_metadata.input_dimensions_[axis]) - compute_metadata.starts_[axis]; if (temp < 0) - output_dims[axis] = 0; + compute_metadata.output_dims_[axis] = 0; else - output_dims[axis] = temp; + compute_metadata.output_dims_[axis] = temp; } - FlattenOutputDims(input_dimensions, output_dims, starts, steps, flattened_output_dims); + FlattenOutputDims(compute_metadata.input_dimensions_, compute_metadata.output_dims_, compute_metadata.starts_, + compute_metadata.steps_, compute_metadata.p_flattened_output_dims_); return Status::OK(); } @@ -142,23 +139,19 @@ Status SliceBase::PrepareForCompute(const std::vector& raw_starts, const std::vector& raw_ends, const std::vector& raw_axes, const std::vector& raw_steps, - const std::vector& input_dimensions, - std::vector& starts, - std::vector& steps, - std::vector& output_dims, - std::vector*& flattened_output_dims) const { + SliceOp::PrepareForComputeMetadata& compute_metadata) { // Initialize axes to the provided axes attribute or to the default sequence std::vector axes(raw_axes); if (axes.empty()) { // axes are omitted, they are set to[0, ..., ndim - 1] - axes.resize(starts.size()); + axes.resize(compute_metadata.starts_.size()); std::iota(axes.begin(), axes.end(), 0); } // Iterate through the provided axes and override the start/end/steps ranges std::unordered_set unique_axes; - const auto& dimension_count = input_dimensions.size(); + const auto& dimension_count = compute_metadata.input_dimensions_.size(); for (size_t axis_index = 0, axes_count = axes.size(); axis_index < axes_count; ++axis_index) { auto axis = axes[axis_index] < 0 ? axes[axis_index] + static_cast(dimension_count) : axes[axis_index]; if (axis >= static_cast(dimension_count) || axis < 0) @@ -171,16 +164,16 @@ Status SliceBase::PrepareForCompute(const std::vector& raw_starts, auto step = axis_index < raw_steps.size() ? raw_steps[axis_index] : 1; if (step == 0) return Status(ONNXRUNTIME, INVALID_ARGUMENT, "'step' value cannot be 0"); - steps[axis] = step; + compute_metadata.steps_[axis] = step; // process start auto start = raw_starts[axis_index]; if (start < 0) - start += input_dimensions[axis]; + start += compute_metadata.input_dimensions_[axis]; if (step < 0) - starts[axis] = clamp(start, int64_t{0}, input_dimensions[axis] - 1); + compute_metadata.starts_[axis] = clamp(start, int64_t{0}, compute_metadata.input_dimensions_[axis] - 1); else - starts[axis] = clamp(start, int64_t{0}, input_dimensions[axis]); + compute_metadata.starts_[axis] = clamp(start, int64_t{0}, compute_metadata.input_dimensions_[axis]); // process end auto end = raw_ends[axis_index]; @@ -189,27 +182,28 @@ Status SliceBase::PrepareForCompute(const std::vector& raw_starts, // it represent slicing to the end of the dimension if (end == std::numeric_limits::max() || end == std::numeric_limits::max()) { - end = step < 0 ? -1 : input_dimensions[axis]; + end = step < 0 ? -1 : compute_metadata.input_dimensions_[axis]; } else { if (end < 0) - end += input_dimensions[axis]; + end += compute_metadata.input_dimensions_[axis]; if (step < 0) - end = clamp(end, int64_t{-1}, input_dimensions[axis]); + end = clamp(end, int64_t{-1}, compute_metadata.input_dimensions_[axis]); else - end = clamp(end, int64_t{0}, input_dimensions[axis]); + end = clamp(end, int64_t{0}, compute_metadata.input_dimensions_[axis]); } // find output dim value for this axis - auto temp = static_cast(ceil(1.0 * (end - starts[axis]) / step)); + auto temp = static_cast(ceil(1.0 * (end - compute_metadata.starts_[axis]) / step)); if (temp < 0) - output_dims[axis] = 0; + compute_metadata.output_dims_[axis] = 0; else - output_dims[axis] = temp; + compute_metadata.output_dims_[axis] = temp; } - FlattenOutputDims(input_dimensions, output_dims, starts, steps, flattened_output_dims); + FlattenOutputDims(compute_metadata.input_dimensions_, compute_metadata.output_dims_, compute_metadata.starts_, + compute_metadata.steps_, compute_metadata.p_flattened_output_dims_); return Status::OK(); } @@ -222,7 +216,7 @@ void SliceBase::FillVectorsFromInput(const Tensor& start_tensor, std::vector& input_starts, std::vector& input_ends, std::vector& input_axes, - std::vector& input_steps) const { + std::vector& input_steps) { ORT_ENFORCE(start_tensor.Shape().NumDimensions() == 1, "Starts must be a 1-D array"); ORT_ENFORCE(ends_tensor.Shape().NumDimensions() == 1, "Ends must be a 1-D array"); ORT_ENFORCE(start_tensor.Shape() == ends_tensor.Shape(), "Starts and ends shape mismatch"); @@ -267,11 +261,8 @@ void SliceBase::FillVectorsFromInput(const Tensor& start_tensor, template static Status SliceImpl(OpKernelContext* ctx, const Tensor& input_tensor, - std::vector& output_dims, - std::vector* flattened_output_dims, - const std::vector& starts, - const std::vector& steps) { - TensorShape output_shape(output_dims); + SliceOp::PrepareForComputeMetadata& compute_metadata) { + TensorShape output_shape(compute_metadata.output_dims_); auto& output_tensor = *ctx->Output(0, output_shape); // output tensor's size is 0, nothing to fill - return @@ -296,18 +287,18 @@ static Status SliceImpl(OpKernelContext* ctx, ORT_ENFORCE(output == output_end); }; - if (flattened_output_dims) { + if (compute_metadata.p_flattened_output_dims_) { // if we have flattened output dims we need to also flatten the input dims. // as we're combining the innermost dims and keeping all values we can just copy the size of the last dim std::vector flattened_input_dims(input_tensor.Shape().GetDims()); - flattened_input_dims.resize(flattened_output_dims->size()); - flattened_input_dims.back() = flattened_output_dims->back(); + flattened_input_dims.resize(compute_metadata.p_flattened_output_dims_->size()); + flattened_input_dims.back() = compute_metadata.p_flattened_output_dims_->back(); TensorShape input_shape(std::move(flattened_input_dims)); - auto input_iterator = SliceIterator(input_tensor, input_shape, starts, *flattened_output_dims, steps); + auto input_iterator = SliceIterator(input_tensor, input_shape, compute_metadata.starts_, *compute_metadata.p_flattened_output_dims_, compute_metadata.steps_); create_output(input_iterator); } else { - auto input_iterator = SliceIterator(input_tensor, starts, output_dims, steps); + auto input_iterator = SliceIterator(input_tensor, compute_metadata.starts_, compute_metadata.output_dims_, compute_metadata.steps_); create_output(input_iterator); } @@ -320,13 +311,7 @@ Status SliceBase::Compute(OpKernelContext* ctx) const { const auto& input_tensor = *input_tensor_ptr; const auto& input_dimensions = input_tensor.Shape().GetDims(); if (input_dimensions.empty()) return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Cannot slice scalars"); - - // Initialize the starts & ends to the actual tensor shape - std::vector starts(input_dimensions.size(), 0); - std::vector steps(input_dimensions.size(), 1); - std::vector output_dims(input_dimensions); - std::vector flattened_output_dims; - std::vector* p_flattened_output_dims = &flattened_output_dims; + SliceOp::PrepareForComputeMetadata compute_metadata(input_dimensions); // Slice V10 & DynamicSlice if (dynamic_) { @@ -337,36 +322,32 @@ Status SliceBase::Compute(OpKernelContext* ctx) const { FillVectorsFromInput(*ctx->Input(1), *ctx->Input(2), ctx->Input(3), ctx->Input(4), input_starts, input_ends, input_axes, input_steps); - ORT_RETURN_IF_ERROR(PrepareForCompute(input_starts, input_ends, input_axes, input_steps, - input_dimensions, starts, steps, output_dims, - p_flattened_output_dims)); + ORT_RETURN_IF_ERROR(PrepareForCompute(input_starts, input_ends, input_axes, input_steps, compute_metadata)); } // Slice V1-9 else { - ORT_RETURN_IF_ERROR(PrepareForCompute(attr_starts_, attr_ends_, attr_axes_, - input_dimensions, starts, steps, output_dims, - p_flattened_output_dims)); + ORT_RETURN_IF_ERROR(PrepareForCompute(attr_starts_, attr_ends_, attr_axes_, compute_metadata)); } Status status = Status::OK(); if (input_tensor.IsDataTypeString()) { - status = SliceImpl(ctx, input_tensor, output_dims, p_flattened_output_dims, starts, steps); + status = SliceImpl(ctx, input_tensor, compute_metadata); } else { const auto element_size = input_tensor.DataType()->Size(); switch (element_size) { case sizeof(uint32_t): - status = SliceImpl(ctx, input_tensor, output_dims, p_flattened_output_dims, starts, steps); + status = SliceImpl(ctx, input_tensor, compute_metadata); break; case sizeof(uint64_t): - status = SliceImpl(ctx, input_tensor, output_dims, p_flattened_output_dims, starts, steps); + status = SliceImpl(ctx, input_tensor, compute_metadata); break; case sizeof(uint16_t): - status = SliceImpl(ctx, input_tensor, output_dims, p_flattened_output_dims, starts, steps); + status = SliceImpl(ctx, input_tensor, compute_metadata); break; case sizeof(uint8_t): - status = SliceImpl(ctx, input_tensor, output_dims, p_flattened_output_dims, starts, steps); + status = SliceImpl(ctx, input_tensor, compute_metadata); break; default: ORT_THROW("Unsupported input data type of ", input_tensor.DataType()); diff --git a/onnxruntime/core/providers/cpu/tensor/slice.h b/onnxruntime/core/providers/cpu/tensor/slice.h index 09a9a1f2f0..7febbf84d2 100644 --- a/onnxruntime/core/providers/cpu/tensor/slice.h +++ b/onnxruntime/core/providers/cpu/tensor/slice.h @@ -7,7 +7,52 @@ namespace onnxruntime { +namespace SliceOp { +struct PrepareForComputeMetadata { + PrepareForComputeMetadata() = delete; + PrepareForComputeMetadata(const std::vector& input_dimensions) + : input_dimensions_(input_dimensions) { + size_t dimension_count = input_dimensions.size(); + starts_.resize(dimension_count, 0); + steps_.resize(dimension_count, 1); + output_dims_ = input_dimensions; + } + + const std::vector& input_dimensions_; + std::vector starts_; + std::vector steps_; + std::vector output_dims_; + std::vector flattened_output_dims_; + std::vector* p_flattened_output_dims_ = &flattened_output_dims_; +}; +} // namespace SliceOp + class SliceBase { + // static methods that can be used from other ops if needed + public: + // compute output_dims without steps (Slice V1-9 & DynamicSlice) + static Status PrepareForCompute(const std::vector& raw_starts, + const std::vector& raw_ends, + const std::vector& raw_axes, + SliceOp::PrepareForComputeMetadata& compute_metadata); + + // compute output_dims with steps (Slice V10) + static Status PrepareForCompute(const std::vector& raw_starts, + const std::vector& raw_ends, + const std::vector& raw_axes, + const std::vector& raw_steps, + SliceOp::PrepareForComputeMetadata& compute_metadata); + + // Slice V10 & DynamicSlice + static void FillVectorsFromInput(const Tensor& start_tensor, + const Tensor& ends_tensor, + const Tensor* axes_tensor, + const Tensor* steps_tensor, + std::vector& input_starts, + std::vector& input_ends, + std::vector& input_axes, + std::vector& input_steps); + protected: SliceBase(const OpKernelInfo& info, bool dynamic = false) : dynamic_(dynamic) { @@ -22,37 +67,6 @@ class SliceBase { } } - // compute output_dims without steps (Slice V1-9 & DynamicSlice) - Status PrepareForCompute(const std::vector& raw_starts, - const std::vector& raw_ends, - const std::vector& raw_axes, - const std::vector& input_dimensions, - std::vector& starts, - std::vector& steps, - std::vector& output_dims, - std::vector*& flattened_output_dims) const; - - // compute output_dims with steps (Slice V10) - Status PrepareForCompute(const std::vector& raw_starts, - const std::vector& raw_ends, - const std::vector& raw_axes, - const std::vector& raw_steps, - const std::vector& input_dimensions, - std::vector& starts, - std::vector& steps, - std::vector& output_dims, - std::vector*& flattened_output_dims) const; - - // Slice V10 & DynamicSlice - void FillVectorsFromInput(const Tensor& start_tensor, - const Tensor& ends_tensor, - const Tensor* axes_tensor, - const Tensor* steps_tensor, - std::vector& input_starts, - std::vector& input_ends, - std::vector& input_axes, - std::vector& input_steps) const; - Status Compute(OpKernelContext* context) const; protected: diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 7c6e578b9b..65bd65d8c4 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -1372,7 +1372,7 @@ static bool RNNNeedFallbackToCPU(const onnxruntime::Node& node, return false; } -static bool ConvNeedFallbackToCPU(const onnxruntime::Node& node) { +static bool ConvTransposeNeedFallbackToCPU(const onnxruntime::Node& node) { const auto& node_attributes = node.GetAttributes(); // Check attributes for (auto& attr : node_attributes) { @@ -1380,6 +1380,8 @@ static bool ConvNeedFallbackToCPU(const onnxruntime::Node& node) { auto attr_value = attr.second; //cudnn only supports symmetric padding + // TODO: Check if we can adopt a similar approach to deal with asymmetric pads in 'ConvTranspose' + // as we did for 'Conv' to circumvent the cudnn limitation if ("pads" == attr_name && ::ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_INTS == attr_value.type()) { auto& pads = attr_value.ints(); int pads_size = pads.size(); @@ -1468,8 +1470,8 @@ CUDAExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, std::vector activations_supported{"sigmoid", "tanh", "sigmoid", "tanh"}; not_supported = RNNNeedFallbackToCPU(node, activations_supported, node.OpType()); force_inside = !not_supported; - } else if ("Conv" == node.OpType()) { - not_supported = ConvNeedFallbackToCPU(node); + } else if ("ConvTranspose" == node.OpType()) { + not_supported = ConvTransposeNeedFallbackToCPU(node); force_inside = !not_supported; } else if ("Cast" == node.OpType()) { not_supported = CastNeedFallbackToCPU(node); diff --git a/onnxruntime/core/providers/cuda/nn/conv.cc b/onnxruntime/core/providers/cuda/nn/conv.cc index c29b39a259..b65f3289b4 100644 --- a/onnxruntime/core/providers/cuda/nn/conv.cc +++ b/onnxruntime/core/providers/cuda/nn/conv.cc @@ -5,6 +5,7 @@ #include "core/providers/cuda/cuda_common.h" #include "core/providers/cuda/nn/conv.h" #include "core/providers/cuda/shared_inc/fpgeneric.h" +#include "core/providers/cuda/tensor/slice.h" namespace onnxruntime { namespace cuda { @@ -33,6 +34,24 @@ REGISTER_KERNEL_TYPED(float) REGISTER_KERNEL_TYPED(double) REGISTER_KERNEL_TYPED(MLFloat16) +static Status SliceOutUnwantedOutputSection(const void* input_data, + const std::vector& input_dims, + void* output_data, + const std::vector& output_dims, + std::vector starts, + const std::vector& ends, + const std::vector& axes, + size_t element_size) { + SliceOp::PrepareForComputeMetadata compute_metadata(input_dims); + + SliceBase::PrepareForCompute(starts, ends, axes, compute_metadata); + + // As a sanity check, ensure that the slice operator's output shape matches with the expected output shape + ORT_ENFORCE(compute_metadata.output_dims_ == output_dims); + + return SliceCuda::Impl(input_data, input_dims, output_data, compute_metadata, element_size); +} + template Status Conv::ComputeInternal(OpKernelContext* context) const { typedef typename ToCudaType::MappedType CudaT; @@ -52,6 +71,18 @@ Status Conv::ComputeInternal(OpKernelContext* context) const { CudaT* y_data = nullptr; + AllocatorPtr allocator; + ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); + + size_t element_size = X->DataType()->Size(); + + Tensor* Y = nullptr; + + // We may have to write the CuDNN Conv results to a temporary bufferwhen we deal with + // asymmetric padding as we have to take the results written to this temporary buffer and slice out + // extraneous portions of the result + IAllocatorUniquePtr memory_for_cudnn_conv_results; + { std::lock_guard lock(s_.mutex); // TODO: add a global cache if need to handle cases for multiple frames running simultaneuously with different batch_size @@ -88,15 +119,46 @@ Status Conv::ComputeInternal(OpKernelContext* context) const { } std::vector y_dims; + y_dims.reserve(2 + rank); // rank indicates number of feature dimensions - so add 2 to account for 'N' and 'C' y_dims.insert(y_dims.begin(), {N, M}); - ORT_RETURN_IF_ERROR(conv_attrs_.InferOutputShape(x_shape.Slice(2), kernel_shape, - strides, dilations, pads, y_dims, true)); + + std::vector y_dims_with_adjusted_pads; + y_dims_with_adjusted_pads.reserve(2 + rank); // rank indicates number of feature dimensions - so add 2 to account for 'N' and 'C' + y_dims_with_adjusted_pads.insert(y_dims_with_adjusted_pads.begin(), {N, M}); + + bool post_slicing_required = false; + std::vector slice_starts; + slice_starts.reserve(rank); + + std::vector slice_ends; + slice_ends.reserve(rank); + + std::vector slice_axes; + slice_axes.reserve(rank); + + ORT_RETURN_IF_ERROR(conv_attrs_.InferOutputShapeWithAdjustedPads(x_shape.Slice(2), kernel_shape, + strides, dilations, pads, y_dims, y_dims_with_adjusted_pads, + post_slicing_required, slice_starts, slice_ends, slice_axes)); + ORT_ENFORCE(y_dims.size() == y_dims_with_adjusted_pads.size()); s_.y_dims = y_dims; - Tensor* Y = context->Output(0, TensorShape(s_.y_dims)); - y_data = reinterpret_cast(Y->template MutableData()); + s_.y_dims_with_adjusted_pads = y_dims_with_adjusted_pads; + s_.post_slicing_required = post_slicing_required; + s_.slice_starts = slice_starts; + s_.slice_ends = slice_ends; + s_.slice_axes = slice_axes; + + Y = context->Output(0, TensorShape(s_.y_dims)); + if (!post_slicing_required) { + // No post slicing needed. Fill the output tensor's buffer directly. + y_data = reinterpret_cast(Y->template MutableData()); + } else { + // Post slicing needed. Create and fill in the Conv results in an intermediate buffer. + memory_for_cudnn_conv_results = GetScratchBuffer(TensorShape(y_dims_with_adjusted_pads).Size() * element_size); + y_data = reinterpret_cast(memory_for_cudnn_conv_results.get()); + } std::vector x_dims_cudnn = x_dims; - std::vector y_dims_cudnn = y_dims; + std::vector y_dims_cudnn = !post_slicing_required ? y_dims : y_dims_with_adjusted_pads; if (rank < 2) { // cudnn only takes 4D or 5D input, so pad dimensions if needed x_dims_cudnn.push_back(1); @@ -173,12 +235,18 @@ Status Conv::ComputeInternal(OpKernelContext* context) const { } if (!y_data) { - Tensor* Y = context->Output(0, TensorShape(s_.y_dims)); + Y = context->Output(0, TensorShape(s_.y_dims)); // special case when there is a dim value of 0 in the shape. if (Y->Shape().Size() == 0) return Status::OK(); - y_data = reinterpret_cast(Y->template MutableData()); + if (!s_.post_slicing_required) { + y_data = reinterpret_cast(Y->template MutableData()); + } else { + // Post slicing needed. Create and fill in the Conv results in an intermediate buffer. + memory_for_cudnn_conv_results = GetScratchBuffer(TensorShape(s_.y_dims_with_adjusted_pads).Size() * element_size); + y_data = reinterpret_cast(memory_for_cudnn_conv_results.get()); + } } const auto alpha = Consts::One; @@ -203,7 +271,15 @@ Status Conv::ComputeInternal(OpKernelContext* context) const { if (has_bias) { const Tensor* B = context->Input(2); auto b_data = reinterpret_cast(B->template Data()); - CUDNN_RETURN_IF_ERROR(cudnnAddTensor(CudnnHandle(), &alpha, s_.b_tensor, b_data, &alpha, s_.y_tensor, y_data)); + CUDNN_RETURN_IF_ERROR(cudnnAddTensor(CudnnHandle(), &alpha, s_.b_tensor, b_data, &alpha, s_.y_tensor, + y_data)); + } + + // To deal with asymmetric padding, we may have over-padded on one or both sides of the spatial dimensions + // This may have lead to extra results that are unnecessary and hence we slice that off here + if (s_.post_slicing_required) { + SliceOutUnwantedOutputSection(y_data, s_.y_dims_with_adjusted_pads, Y->MutableDataRaw(), + s_.y_dims, s_.slice_starts, s_.slice_ends, s_.slice_axes, element_size); } } diff --git a/onnxruntime/core/providers/cuda/nn/conv.h b/onnxruntime/core/providers/cuda/nn/conv.h index ca3266c811..a081285678 100644 --- a/onnxruntime/core/providers/cuda/nn/conv.h +++ b/onnxruntime/core/providers/cuda/nn/conv.h @@ -88,7 +88,7 @@ class lru_unordered_map { lru_list_.clear(); } -private: + private: using list_type = std::list; using iterator_type = typename list_type::iterator; struct value_type { @@ -117,6 +117,7 @@ struct CudnnConvState { // these would be recomputed if x/w dims change std::vector y_dims; + std::vector y_dims_with_adjusted_pads; size_t workspace_bytes; decltype(AlgoPerfType().algo) algo; CudnnTensor x_tensor; @@ -126,12 +127,18 @@ struct CudnnConvState { CudnnConvolutionDescriptor conv_desc; struct PerfResultParams { - decltype(AlgoPerfType().algo) algo; - decltype(AlgoPerfType().memory) memory; + decltype(AlgoPerfType().algo) algo; + decltype(AlgoPerfType().memory) memory; decltype(AlgoPerfType().mathType) mathType; }; - lru_unordered_map, PerfResultParams, vector_hash> cached_benchmark_results { MAX_CACHED_ALGO_PERF_RESULTS }; + lru_unordered_map, PerfResultParams, vector_hash> cached_benchmark_results{MAX_CACHED_ALGO_PERF_RESULTS}; + + // Some properties needed to support asymmetric padded Conv nodes + bool post_slicing_required; + std::vector slice_starts; + std::vector slice_ends; + std::vector slice_axes; // note that conv objects are shared between execution frames, and a lock is needed to avoid multi-thread racing OrtMutex mutex; @@ -147,10 +154,6 @@ class Conv : public CudaKernel { Conv(const OpKernelInfo& info) : CudaKernel(info), conv_attrs_(info) { auto pads_size = conv_attrs_.pads.size(); ORT_ENFORCE(pads_size % 2 == 0); - auto rank = pads_size / 2; - for (size_t i = 0; i < rank; i++) { - ORT_ENFORCE(conv_attrs_.pads[i] == conv_attrs_.pads[i + rank], "cudnn only supports symmetric padding"); - } } Status ComputeInternal(OpKernelContext* context) const override; diff --git a/onnxruntime/core/providers/cuda/tensor/slice.cc b/onnxruntime/core/providers/cuda/tensor/slice.cc index a433db34fe..ae5cf42f5c 100644 --- a/onnxruntime/core/providers/cuda/tensor/slice.cc +++ b/onnxruntime/core/providers/cuda/tensor/slice.cc @@ -63,46 +63,43 @@ REGISTER_V11_TYPED_SLICE(int32_t) REGISTER_V11_TYPED_SLICE(int64_t) REGISTER_V11_TYPED_SLICE(float) -template -Status Slice::ComputeInternal(OpKernelContext* ctx) const { - const Tensor* input_tensor = GetSlicedOrUnslicedTensor(ctx); - - ORT_ENFORCE(nullptr != input_tensor); - - auto& input_dimensions = input_tensor->Shape().GetDims(); - - // Initialize the starts & ends to the actual tensor shape - size_t dimension_count = input_dimensions.size(); - std::vector starts(dimension_count, 0); - std::vector steps(dimension_count, 1); - std::vector output_dims(input_dimensions); - std::vector flattened_output_dims; - std::vector* p_flattened_output_dims = &flattened_output_dims; - - if (dynamic) { - std::vector input_starts, input_ends, input_axes, input_steps; - FillInputVectors(ctx, input_starts, input_ends, input_axes, input_steps); - ORT_RETURN_IF_ERROR(PrepareForCompute(input_starts, input_ends, input_axes, - input_steps, input_dimensions, starts, steps, output_dims, - p_flattened_output_dims)); - - } else { - ORT_RETURN_IF_ERROR(PrepareForCompute(StartsAttribute(), EndsAttribute(), AxesAttribute(), - input_dimensions, starts, steps, output_dims, - p_flattened_output_dims)); +static Status SliceImpCore(const void* input_data, void* output_data, + size_t element_size, size_t dimension_count, + const TArray& starts_buffer, const TArray& steps_buffer, + const TArray& input_strides, const TArray& output_strides, + const TensorShape& output_shape) { + if (output_shape.Size() == 0) { + return Status::OK(); } + return SliceImpl(element_size, + gsl::narrow_cast(dimension_count), + starts_buffer, + steps_buffer, + input_strides, + output_strides, + input_data, + output_data, + output_shape.Size()); +} + +namespace SliceCuda { + +static Status ComputeSliceStrides(const TensorShape& input_shape, + TArray& input_strides, + TArray& output_strides, + SliceOp::PrepareForComputeMetadata& compute_metadata) { + const auto& input_dimensions = input_shape.GetDims(); + size_t dimension_count = input_dimensions.size(); // if we are able to flatten the output dims we updated 'starts' and 'steps' to match the smaller number of dims. // update dimension_count to match. - if (p_flattened_output_dims != nullptr) { - dimension_count = flattened_output_dims.size(); + if (compute_metadata.p_flattened_output_dims_) { + dimension_count = compute_metadata.p_flattened_output_dims_->size(); } - TArray starts_buffer(starts); - TArray steps_buffer(steps); - TArray input_strides(gsl::narrow_cast(dimension_count)); + input_strides.SetSize(gsl::narrow_cast(dimension_count)); const gsl::span input_strides_span = gsl::make_span(input_strides.Data(), input_strides.Size()); - if (p_flattened_output_dims != nullptr) { + if (compute_metadata.p_flattened_output_dims_ != nullptr) { // we were able to flatten the innermost dimensions as they're being copied in full to the output. // do the same flattening to the innermost input dimensions in order to calculate pitches that match // the flattened output dimensions. @@ -119,22 +116,82 @@ Status Slice::ComputeInternal(OpKernelContext* ctx) const { ORT_ENFORCE(TensorPitches::Calculate(input_strides_span, input_dimensions)); } - TensorPitches original_output_strides(p_flattened_output_dims != nullptr ? flattened_output_dims : output_dims); - TArray output_strides(gsl::narrow_cast(original_output_strides.size())); + TensorPitches original_output_strides( + compute_metadata.p_flattened_output_dims_ != nullptr ? compute_metadata.flattened_output_dims_ : compute_metadata.output_dims_); + output_strides.SetSize(gsl::narrow_cast(original_output_strides.size())); for (int32_t i = 0; i < static_cast(original_output_strides.size()); ++i) { output_strides[i] = fast_divmod(gsl::narrow_cast(original_output_strides[i])); } - size_t element_size = input_tensor->DataType()->Size(); + return Status::OK(); +} - ORT_RETURN_IF_ERROR(CallSliceImp(element_size, +Status Impl(const void* input_data, + const TensorShape& input_shape, + void* output_data, + SliceOp::PrepareForComputeMetadata& compute_metadata, + size_t element_size) { + const auto& input_dimensions = input_shape.GetDims(); + size_t dimension_count = input_dimensions.size(); + + TArray starts_buffer(compute_metadata.starts_); + TArray steps_buffer(compute_metadata.steps_); + TArray input_strides; + TArray output_strides; + + ORT_RETURN_IF_ERROR(ComputeSliceStrides(input_shape, input_strides, output_strides, compute_metadata)); + + TensorShape output_shape(compute_metadata.output_dims_); + + ORT_RETURN_IF_ERROR(SliceImpCore(input_data, + output_data, + element_size, gsl::narrow_cast(dimension_count), starts_buffer, steps_buffer, input_strides, output_strides, - ctx, - TensorShape(output_dims))); + output_shape)); + + return Status::OK(); +} +} // namespace SliceCuda + +template +Status Slice::ComputeInternal(OpKernelContext* ctx) const { + const Tensor* input_tensor = GetSlicedOrUnslicedTensor(ctx); + ORT_ENFORCE(nullptr != input_tensor); + const auto& input_shape = input_tensor->Shape(); + const auto& input_dimensions = input_shape.GetDims(); + if (input_dimensions.empty()) return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Cannot slice scalars"); + + SliceOp::PrepareForComputeMetadata compute_metadata(input_dimensions); + + if (dynamic) { + std::vector input_starts, input_ends, input_axes, input_steps; + FillInputVectors(ctx, input_starts, input_ends, input_axes, input_steps); + ORT_RETURN_IF_ERROR(PrepareForCompute(input_starts, input_ends, input_axes, input_steps, compute_metadata)); + + } else { + ORT_RETURN_IF_ERROR(PrepareForCompute(StartsAttribute(), EndsAttribute(), AxesAttribute(), compute_metadata)); + } + + TensorShape output_shape(compute_metadata.output_dims_); + + TArray starts_buffer(compute_metadata.starts_); + TArray steps_buffer(compute_metadata.steps_); + TArray input_strides; + TArray output_strides; + + ORT_RETURN_IF_ERROR(SliceCuda::ComputeSliceStrides(input_shape, input_strides, output_strides, compute_metadata)); + + // It may seem that we may use `SliceImpCore()` directly, but we need to go through `CallSliceImp()` because + // `ComputeInternal()` is shared between the inferencing and training kernels and the training kernel overrides + // `CallSliceImp()` + ORT_RETURN_IF_ERROR(CallSliceImp(input_tensor->DataType()->Size(), input_dimensions.size(), starts_buffer, + steps_buffer, input_strides, + output_strides, ctx, + output_shape)); return Status::OK(); } @@ -156,21 +213,19 @@ template Status Slice::CallSliceImp(size_t element_size, size_t dimension_count, const TArray& starts_buffer, const TArray& steps_buffer, const TArray& input_strides, const TArray& output_strides, OpKernelContext* ctx, - TensorShape output_shape) const { + const TensorShape& output_shape) const { + const auto* input_tensor = ctx->Input(0); auto* output_tensor = ctx->Output(0, output_shape); - if (output_shape.Size() == 0) { - return Status::OK(); - } - return SliceImpl(element_size, - gsl::narrow_cast(dimension_count), - starts_buffer, - steps_buffer, - input_strides, - output_strides, - ctx->Input(0)->DataRaw(), - output_tensor->MutableDataRaw(), - output_shape.Size()); + return SliceImpCore(input_tensor->DataRaw(), + output_tensor->MutableDataRaw(), + element_size, + gsl::narrow_cast(dimension_count), + starts_buffer, + steps_buffer, + input_strides, + output_strides, + output_shape); } } // namespace cuda diff --git a/onnxruntime/core/providers/cuda/tensor/slice.h b/onnxruntime/core/providers/cuda/tensor/slice.h index 55ca9fbc61..4b0124a6cf 100644 --- a/onnxruntime/core/providers/cuda/tensor/slice.h +++ b/onnxruntime/core/providers/cuda/tensor/slice.h @@ -9,9 +9,19 @@ namespace onnxruntime { namespace cuda { +namespace SliceCuda { + +Status Impl(const void* input_data, + const TensorShape& input_shape, + void* output_data, + SliceOp::PrepareForComputeMetadata& prepare_metadata, + size_t element_size); + +} // namespace SliceCuda + template class Slice : public CudaKernel, public SliceBase { -public: + public: Slice(const OpKernelInfo& info) : CudaKernel(info), SliceBase(info, dynamic) {} Status ComputeInternal(OpKernelContext* ctx) const override; @@ -25,7 +35,7 @@ public: virtual Status CallSliceImp(size_t element_size, size_t dimension_count, const TArray& starts_buffer, const TArray& steps_buffer, const TArray& input_strides, const TArray& output_strides, OpKernelContext* ctx, - TensorShape output_shape) const; + const TensorShape& output_shape) const; }; } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/nn/conv_op_test.cc b/onnxruntime/test/providers/cpu/nn/conv_op_test.cc index 8b9f023cbd..3c866feaba 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_op_test.cc @@ -200,7 +200,6 @@ TEST(ConvTest, Conv2D_1) { auto expected_vals = {-0.012311071157455444f, 0.02822777070105076f, -0.028432954102754593f, -0.037657227367162704f, -0.04396762326359749f, 0.10081233829259872f, -0.10154513269662857f, -0.13448859751224518f}; - attrs.excluded_providers.insert(kCudaExecutionProvider); // asymmetric padding is not supported by cudnn TestConvOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); // NNAPI EP requires weight to be an initializer @@ -372,8 +371,6 @@ TEST(ConvTest, Conv2D_Bias_2) { auto expected_vals = {-0.3419531583786011f, -0.6116723418235779f, -0.39677709341049194f, -0.7316848039627075f, -0.5647197365760803f, 0.02788025140762329f, -0.30450713634490967f, -0.6786775588989258f}; - attrs.excluded_providers.insert(kCudaExecutionProvider); // asymmetric padding is not supported by cudnn - TestConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape); TestConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, true); @@ -655,5 +652,30 @@ TEST(ConvTest, ConvDimWithZero) { TestConvOp(attrs, {X, W}, {X_shape, W_shape}, {}, out_shape, false, OpTester::ExpectResult::kExpectSuccess, "", 10); } +TEST(ConvTest, Conv1D_asymmetric_padding) { + ConvOpAndTestAttributes attrs = { + "", // auto_pad + vector{1}, // dilations + 1, // group + vector{3}, // kernel_shape + vector{1, 0}, // pads + vector{1}, // strides + {} // excluded EPs + }; + + vector X = {1.f, 2.f, 3.f}; + vector X_shape = {1, 1, 3}; + vector W = {1.f, 1.f, 1.f}; + vector W_shape = {1, 1, 3}; + vector B = {0.f}; + vector B_shape = {1}; + vector Y_shape = {1, 1, 2}; + auto expected_vals = {3.f, 6.f}; + + TestConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape); + + TestConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, true); +} + } // namespace test } // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cpu/tensor/slice_grad.cc b/orttraining/orttraining/training_ops/cpu/tensor/slice_grad.cc index c6cbb517b6..d14de751e7 100644 --- a/orttraining/orttraining/training_ops/cpu/tensor/slice_grad.cc +++ b/orttraining/orttraining/training_ops/cpu/tensor/slice_grad.cc @@ -28,11 +28,6 @@ Status SliceGrad::Compute(OpKernelContext* context) const { Tensor& output = *context->Output(0, data_shape); memset(output.MutableDataRaw(), 0, output.SizeInBytes()); // Initialize the starts & ends to the actual tensor shape - std::vector starts(data_shape.GetDims().size(), 0); - std::vector steps(data_shape.GetDims().size(), 1); - std::vector output_dims(data_shape.GetDims()); - std::vector flattened_output_dims; - std::vector* p_flattened_output_dims = &flattened_output_dims; std::vector input_starts; std::vector input_ends; std::vector input_axes; @@ -40,17 +35,18 @@ Status SliceGrad::Compute(OpKernelContext* context) const { FillVectorsFromInput(*context->Input(2), *context->Input(3), context->Input(4), context->Input(5), input_starts, input_ends, input_axes, input_steps); - ORT_RETURN_IF_ERROR(PrepareForCompute(input_starts, input_ends, input_axes, input_steps, - data_shape.GetDims(), starts, steps, output_dims, - p_flattened_output_dims)); + SliceOp::PrepareForComputeMetadata compute_metadata(data_shape.GetDims()); + ORT_RETURN_IF_ERROR(PrepareForCompute(input_starts, input_ends, input_axes, input_steps, compute_metadata)); MLDataType T_type = grad.DataType(); if (T_type == DataTypeImpl::GetType()) { - return ComputeImpl(context, output, output_dims, p_flattened_output_dims, starts, steps); + return ComputeImpl(context, output, compute_metadata.output_dims_, compute_metadata.p_flattened_output_dims_, + compute_metadata.starts_, compute_metadata.steps_); } if (T_type == DataTypeImpl::GetType()) { - return ComputeImpl(context, output, output_dims, p_flattened_output_dims, starts, steps); + return ComputeImpl(context, output, compute_metadata.output_dims_, compute_metadata.p_flattened_output_dims_, + compute_metadata.starts_, compute_metadata.steps_); } return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Type for T or Tind not supported yet in SliceGrad."); diff --git a/orttraining/orttraining/training_ops/cuda/tensor/slice_grad.cc b/orttraining/orttraining/training_ops/cuda/tensor/slice_grad.cc index 2d0e39c125..6b17d2cb37 100644 --- a/orttraining/orttraining/training_ops/cuda/tensor/slice_grad.cc +++ b/orttraining/orttraining/training_ops/cuda/tensor/slice_grad.cc @@ -50,7 +50,7 @@ void SliceGrad::FillInputVectors(OpKernelContext* ctx, std::vector& inp Status SliceGrad::CallSliceImp(size_t element_size, size_t dimension_count, const TArray& starts_buffer, const TArray& steps_buffer, const TArray& input_strides, const TArray& output_strides, OpKernelContext* ctx, - TensorShape output_shape) const { + const TensorShape& output_shape) const { Tensor* gradient_out_tensor = GetOutputGradientTensor(ctx); CUDA_RETURN_IF_ERROR(cudaMemset(gradient_out_tensor->MutableDataRaw(), 0, gradient_out_tensor->SizeInBytes())); return SliceImplGrad(element_size, diff --git a/orttraining/orttraining/training_ops/cuda/tensor/slice_grad.h b/orttraining/orttraining/training_ops/cuda/tensor/slice_grad.h index d43b7bbb54..368db4203b 100644 --- a/orttraining/orttraining/training_ops/cuda/tensor/slice_grad.h +++ b/orttraining/orttraining/training_ops/cuda/tensor/slice_grad.h @@ -17,7 +17,7 @@ class SliceGrad final : public Slice { Status CallSliceImp(size_t element_size, size_t dimension_count, const TArray& starts_buffer, const TArray& steps_buffer, const TArray& input_strides, - const TArray& output_strides, OpKernelContext* ctx, TensorShape output_shape) + const TArray& output_strides, OpKernelContext* ctx, const TensorShape& output_shape) const override; }; } // namespace cuda