Support asymmetric paddings in CUDA Conv kernel (#4627)

This commit is contained in:
Hariharan Seshadri 2020-08-18 02:09:30 -07:00 committed by GitHub
parent c878ecbbe0
commit a3c95374c3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 485 additions and 190 deletions

View file

@ -112,13 +112,26 @@ struct ConvAttributes {
std::vector<int64_t>& output_shape,
bool force_symmetric_auto_padding = false) const {
size_t rank = input_shape.NumDimensions();
// Make sure all "metadata" containers have the right number of elements
if (rank > strides_p.size())
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Not enough elements in strides. Expected: ", rank, " Got: ", strides_p.size());
if (rank > kernel_shape.size())
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Not enough elements in kernel shape. Expected: ", rank, " Got: ", kernel_shape.size());
if (rank > dilations_p.size())
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Not enough elements in dilations. Expected: ", rank, " Got: ", dilations_p.size());
if ((2 * rank) > pads_p.size())
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Not enough elements in pads. Expected: ", (2 * rank), " Got: ", pads_p.size());
for (size_t dim = 0; dim < rank; ++dim) {
if (dim >= strides_p.size() || dim >= kernel_shape.size() ||
dim >= dilations_p.size() || dim >= pads_p.size() ||
rank + dim >= pads_p.size()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Out of bound access to array");
}
int64_t dim_size = 0;
int64_t output_dim_size = 0;
ORT_RETURN_IF_ERROR(ComputePadAndOutputShape(input_shape[dim],
strides_p[dim],
kernel_shape[dim],
@ -126,12 +139,135 @@ struct ConvAttributes {
auto_pad,
pads_p.at(dim),
pads_p.at(input_shape.NumDimensions() + dim),
dim_size,
output_dim_size,
force_symmetric_auto_padding));
if (dim_size <= 0) {
if (output_dim_size <= 0) {
return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Invalid input shape: " + input_shape.ToString());
}
output_shape.push_back(dim_size);
output_shape.push_back(output_dim_size);
}
return Status::OK();
}
// Makes asymmetric pads symmetric by over-padding, and records how to slice the over-padded
// output back down to the output that the pads "as given" would have produced.
//
// For every dimension `dim` of `input_shape` this method:
//   - appends the output size computed with the pads as given to `output_shape`;
//   - appends the output size computed with the (possibly revised) symmetric pads to
//     `output_shape_with_revised_pads`;
//   - if the revised output is larger than the actual one, sets `post_slicing_needed` to true
//     and appends one entry each to `slice_axes`, `slice_starts` and `slice_ends`.
//
// Notes for callers:
//   - `pads_p` is modified in place (head pad at [dim], tail pad at [rank + dim]).
//   - `post_slicing_needed` is only ever set to true here; initialize it to false before calling.
//   - All output vectors are appended to and never cleared.
//   - Slice axes are recorded as `dim + 2`, i.e. offset past two leading dimensions of the
//     full output tensor.
//
// Returns INVALID_ARGUMENT if a metadata container has too few elements, if a stride is not
// positive, or if a computed output dimension is not positive. Errors reported by
// ComputePadAndOutputShape are propagated unchanged.
Status InferOutputShapeWithAdjustedPads(const TensorShape& input_shape,
                                        const std::vector<int64_t>& kernel_shape,
                                        const std::vector<int64_t>& strides_p,
                                        const std::vector<int64_t>& dilations_p,
                                        std::vector<int64_t>& pads_p,
                                        std::vector<int64_t>& output_shape,
                                        std::vector<int64_t>& output_shape_with_revised_pads,
                                        bool& post_slicing_needed,
                                        std::vector<int64_t>& slice_starts,
                                        std::vector<int64_t>& slice_ends,
                                        std::vector<int64_t>& slice_axes) const {
  size_t rank = input_shape.NumDimensions();
  // Make sure all "metadata" containers have the right number of elements
  if (rank > strides_p.size())
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "Not enough elements in strides. Expected: ", rank, " Got: ", strides_p.size());
  if (rank > kernel_shape.size())
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "Not enough elements in kernel shape. Expected: ", rank, " Got: ", kernel_shape.size());
  if (rank > dilations_p.size())
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "Not enough elements in dilations. Expected: ", rank, " Got: ", dilations_p.size());
  if ((2 * rank) > pads_p.size())
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "Not enough elements in pads. Expected: ", (2 * rank), " Got: ", pads_p.size());
  for (size_t dim = 0; dim < rank; ++dim) {
    // The over-padding logic below advances the head pad in steps of 'stride'.
    // A non-positive stride would never reach the tail pad, so reject it up front.
    const int64_t stride = strides_p[dim];
    if (stride <= 0)
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                             "Strides must be positive. Got: ", stride, " for spatial dimension: ", dim);
    int64_t output_dim_size = 0;
    ORT_RETURN_IF_ERROR(ComputePadAndOutputShape(input_shape[dim],
                                                 strides_p[dim],
                                                 kernel_shape[dim],
                                                 dilations_p[dim],
                                                 auto_pad,
                                                 pads_p[dim],
                                                 pads_p[rank + dim],
                                                 output_dim_size));
    if (output_dim_size <= 0) {
      return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Invalid input shape: " + input_shape.ToString());
    }
    // This is the "actual" output shape of the Conv op (i.e.) with given pads as is
    output_shape.push_back(output_dim_size);
    // Deal with asymmetric pads if any and adjust them to be symmetric
    // Along the way - note down how many values need to be sliced out from the start and end
    // of each spatial dimension while we over-pad to get symmetric pads
    int64_t& pad_head = pads_p[dim];
    int64_t& pad_tail = pads_p[rank + dim];
    if (pad_head == pad_tail) {
      // symmetric padding - the output with "revised" pads is the actual output
      output_shape_with_revised_pads.push_back(output_dim_size);
      continue;
    }
    // asymmetric padding
    bool head_overpadded = false;
    if (pad_head < pad_tail) {
      // pad_head is under-padded. Over-pad it in multiples of 'stride' so that the filter
      // still lands on the same input positions; each multiple of 'stride' contributes
      // exactly one excess output value at the head.
      // This is the smallest multiple count that makes pad_head >= pad_tail.
      const int64_t excess_output_head = (pad_tail - pad_head + stride - 1) / stride;
      pad_head += excess_output_head * stride;
      post_slicing_needed = true;
      slice_axes.push_back(dim + 2);
      slice_starts.push_back(excess_output_head);
      slice_ends.push_back(excess_output_head + output_dim_size);
      output_shape_with_revised_pads.push_back(excess_output_head + output_dim_size);  // we may modify this below
      head_overpadded = true;
    }
    // we may enter this section even if the head was initially under-padded,
    // because we had to over-pad by multiples of 'stride', now `pad_head` might be > `pad_tail`
    if (pad_tail < pad_head) {
      pad_tail = pad_head;
      auto revised_dim_size = ComputeOutputShape(input_shape[dim], strides_p[dim],
                                                 kernel_shape[dim], dilations_p[dim],
                                                 pad_head, pad_tail);
      if (head_overpadded) {
        // Head has already been over-padded and the slice for this dimension is recorded.
        // Only the size of the over-padded output needs refreshing (it may be unchanged,
        // as additional tail pads need not result in additional output).
        output_shape_with_revised_pads.back() = revised_dim_size;
      } else {
        // Additional tail pads need not result in additional output
        // Ensure that the size has changed - otherwise no operation needed.
        if (revised_dim_size != output_dim_size) {
          // Head has not been over-padded. Only tail pads need to be modified.
          post_slicing_needed = true;
          slice_axes.push_back(dim + 2);
          slice_starts.push_back(0);
          // Negative value: (output_dim_size - revised_dim_size) counted back from the end of
          // the revised dimension is index output_dim_size - assuming the consumer applies
          // Slice-style negative-index handling.
          slice_ends.push_back(output_dim_size - revised_dim_size);
        }
        // make note of the shape of this spatial dimension
        output_shape_with_revised_pads.push_back(revised_dim_size);
      }
    }
  }
  return Status::OK();
}

View file

@ -89,22 +89,18 @@ static void FlattenOutputDims(const std::vector<int64_t>& input_dimensions,
Status SliceBase::PrepareForCompute(const std::vector<int64_t>& raw_starts,
const std::vector<int64_t>& raw_ends,
const std::vector<int64_t>& raw_axes,
const std::vector<int64_t>& input_dimensions,
std::vector<int64_t>& starts,
std::vector<int64_t>& steps,
std::vector<int64_t>& output_dims,
std::vector<int64_t>*& flattened_output_dims) const {
SliceOp::PrepareForComputeMetadata& compute_metadata) {
// Initialize axes to the provided axes attribute or to the default sequence
std::vector<int64_t> axes(raw_axes);
if (axes.empty()) {
//axes are omitted, they are set to[0, ..., ndim - 1]
axes.resize(starts.size());
axes.resize(compute_metadata.starts_.size());
std::iota(axes.begin(), axes.end(), 0);
}
// Iterate through the provided axes and override the start/end ranges
std::unordered_set<int64_t> unique_axes;
const auto& dimension_count = input_dimensions.size();
const auto& dimension_count = compute_metadata.input_dimensions_.size();
for (size_t axis_index = 0, axes_count = axes.size(); axis_index < axes_count; ++axis_index) {
auto axis = HandleNegativeAxis(axes[axis_index], dimension_count); // handle negative and enforce axis is valid
if (axis >= static_cast<int64_t>(dimension_count) || axis < 0)
@ -116,23 +112,24 @@ Status SliceBase::PrepareForCompute(const std::vector<int64_t>& raw_starts,
// process start
auto start = raw_starts[axis_index];
if (start < 0)
start += input_dimensions[axis];
starts[axis] = clamp(start, int64_t{0}, input_dimensions[axis]);
start += compute_metadata.input_dimensions_[axis];
compute_metadata.starts_[axis] = clamp(start, int64_t{0}, compute_metadata.input_dimensions_[axis]);
// process end
auto end = raw_ends[axis_index];
if (end < 0)
end += input_dimensions[axis];
end += compute_metadata.input_dimensions_[axis];
// find output dim value for this axis
auto temp = clamp(end, int64_t{0}, input_dimensions[axis]) - starts[axis];
auto temp = clamp(end, int64_t{0}, compute_metadata.input_dimensions_[axis]) - compute_metadata.starts_[axis];
if (temp < 0)
output_dims[axis] = 0;
compute_metadata.output_dims_[axis] = 0;
else
output_dims[axis] = temp;
compute_metadata.output_dims_[axis] = temp;
}
FlattenOutputDims(input_dimensions, output_dims, starts, steps, flattened_output_dims);
FlattenOutputDims(compute_metadata.input_dimensions_, compute_metadata.output_dims_, compute_metadata.starts_,
compute_metadata.steps_, compute_metadata.p_flattened_output_dims_);
return Status::OK();
}
@ -142,23 +139,19 @@ Status SliceBase::PrepareForCompute(const std::vector<int64_t>& raw_starts,
const std::vector<int64_t>& raw_ends,
const std::vector<int64_t>& raw_axes,
const std::vector<int64_t>& raw_steps,
const std::vector<int64_t>& input_dimensions,
std::vector<int64_t>& starts,
std::vector<int64_t>& steps,
std::vector<int64_t>& output_dims,
std::vector<int64_t>*& flattened_output_dims) const {
SliceOp::PrepareForComputeMetadata& compute_metadata) {
// Initialize axes to the provided axes attribute or to the default sequence
std::vector<int64_t> axes(raw_axes);
if (axes.empty()) {
// axes are omitted, they are set to[0, ..., ndim - 1]
axes.resize(starts.size());
axes.resize(compute_metadata.starts_.size());
std::iota(axes.begin(), axes.end(), 0);
}
// Iterate through the provided axes and override the start/end/steps ranges
std::unordered_set<int64_t> unique_axes;
const auto& dimension_count = input_dimensions.size();
const auto& dimension_count = compute_metadata.input_dimensions_.size();
for (size_t axis_index = 0, axes_count = axes.size(); axis_index < axes_count; ++axis_index) {
auto axis = axes[axis_index] < 0 ? axes[axis_index] + static_cast<int64_t>(dimension_count) : axes[axis_index];
if (axis >= static_cast<int64_t>(dimension_count) || axis < 0)
@ -171,16 +164,16 @@ Status SliceBase::PrepareForCompute(const std::vector<int64_t>& raw_starts,
auto step = axis_index < raw_steps.size() ? raw_steps[axis_index] : 1;
if (step == 0)
return Status(ONNXRUNTIME, INVALID_ARGUMENT, "'step' value cannot be 0");
steps[axis] = step;
compute_metadata.steps_[axis] = step;
// process start
auto start = raw_starts[axis_index];
if (start < 0)
start += input_dimensions[axis];
start += compute_metadata.input_dimensions_[axis];
if (step < 0)
starts[axis] = clamp(start, int64_t{0}, input_dimensions[axis] - 1);
compute_metadata.starts_[axis] = clamp(start, int64_t{0}, compute_metadata.input_dimensions_[axis] - 1);
else
starts[axis] = clamp(start, int64_t{0}, input_dimensions[axis]);
compute_metadata.starts_[axis] = clamp(start, int64_t{0}, compute_metadata.input_dimensions_[axis]);
// process end
auto end = raw_ends[axis_index];
@ -189,27 +182,28 @@ Status SliceBase::PrepareForCompute(const std::vector<int64_t>& raw_starts,
// it represent slicing to the end of the dimension
if (end == std::numeric_limits<int32_t>::max() ||
end == std::numeric_limits<int64_t>::max()) {
end = step < 0 ? -1 : input_dimensions[axis];
end = step < 0 ? -1 : compute_metadata.input_dimensions_[axis];
}
else {
if (end < 0)
end += input_dimensions[axis];
end += compute_metadata.input_dimensions_[axis];
if (step < 0)
end = clamp(end, int64_t{-1}, input_dimensions[axis]);
end = clamp(end, int64_t{-1}, compute_metadata.input_dimensions_[axis]);
else
end = clamp(end, int64_t{0}, input_dimensions[axis]);
end = clamp(end, int64_t{0}, compute_metadata.input_dimensions_[axis]);
}
// find output dim value for this axis
auto temp = static_cast<int64_t>(ceil(1.0 * (end - starts[axis]) / step));
auto temp = static_cast<int64_t>(ceil(1.0 * (end - compute_metadata.starts_[axis]) / step));
if (temp < 0)
output_dims[axis] = 0;
compute_metadata.output_dims_[axis] = 0;
else
output_dims[axis] = temp;
compute_metadata.output_dims_[axis] = temp;
}
FlattenOutputDims(input_dimensions, output_dims, starts, steps, flattened_output_dims);
FlattenOutputDims(compute_metadata.input_dimensions_, compute_metadata.output_dims_, compute_metadata.starts_,
compute_metadata.steps_, compute_metadata.p_flattened_output_dims_);
return Status::OK();
}
@ -222,7 +216,7 @@ void SliceBase::FillVectorsFromInput(const Tensor& start_tensor,
std::vector<int64_t>& input_starts,
std::vector<int64_t>& input_ends,
std::vector<int64_t>& input_axes,
std::vector<int64_t>& input_steps) const {
std::vector<int64_t>& input_steps) {
ORT_ENFORCE(start_tensor.Shape().NumDimensions() == 1, "Starts must be a 1-D array");
ORT_ENFORCE(ends_tensor.Shape().NumDimensions() == 1, "Ends must be a 1-D array");
ORT_ENFORCE(start_tensor.Shape() == ends_tensor.Shape(), "Starts and ends shape mismatch");
@ -267,11 +261,8 @@ void SliceBase::FillVectorsFromInput(const Tensor& start_tensor,
template <typename T>
static Status SliceImpl(OpKernelContext* ctx,
const Tensor& input_tensor,
std::vector<int64_t>& output_dims,
std::vector<int64_t>* flattened_output_dims,
const std::vector<int64_t>& starts,
const std::vector<int64_t>& steps) {
TensorShape output_shape(output_dims);
SliceOp::PrepareForComputeMetadata& compute_metadata) {
TensorShape output_shape(compute_metadata.output_dims_);
auto& output_tensor = *ctx->Output(0, output_shape);
// output tensor's size is 0, nothing to fill - return
@ -296,18 +287,18 @@ static Status SliceImpl(OpKernelContext* ctx,
ORT_ENFORCE(output == output_end);
};
if (flattened_output_dims) {
if (compute_metadata.p_flattened_output_dims_) {
// if we have flattened output dims we need to also flatten the input dims.
// as we're combining the innermost dims and keeping all values we can just copy the size of the last dim
std::vector<int64_t> flattened_input_dims(input_tensor.Shape().GetDims());
flattened_input_dims.resize(flattened_output_dims->size());
flattened_input_dims.back() = flattened_output_dims->back();
flattened_input_dims.resize(compute_metadata.p_flattened_output_dims_->size());
flattened_input_dims.back() = compute_metadata.p_flattened_output_dims_->back();
TensorShape input_shape(std::move(flattened_input_dims));
auto input_iterator = SliceIterator<T>(input_tensor, input_shape, starts, *flattened_output_dims, steps);
auto input_iterator = SliceIterator<T>(input_tensor, input_shape, compute_metadata.starts_, *compute_metadata.p_flattened_output_dims_, compute_metadata.steps_);
create_output(input_iterator);
} else {
auto input_iterator = SliceIterator<T>(input_tensor, starts, output_dims, steps);
auto input_iterator = SliceIterator<T>(input_tensor, compute_metadata.starts_, compute_metadata.output_dims_, compute_metadata.steps_);
create_output(input_iterator);
}
@ -320,13 +311,7 @@ Status SliceBase::Compute(OpKernelContext* ctx) const {
const auto& input_tensor = *input_tensor_ptr;
const auto& input_dimensions = input_tensor.Shape().GetDims();
if (input_dimensions.empty()) return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Cannot slice scalars");
// Initialize the starts & ends to the actual tensor shape
std::vector<int64_t> starts(input_dimensions.size(), 0);
std::vector<int64_t> steps(input_dimensions.size(), 1);
std::vector<int64_t> output_dims(input_dimensions);
std::vector<int64_t> flattened_output_dims;
std::vector<int64_t>* p_flattened_output_dims = &flattened_output_dims;
SliceOp::PrepareForComputeMetadata compute_metadata(input_dimensions);
// Slice V10 & DynamicSlice
if (dynamic_) {
@ -337,36 +322,32 @@ Status SliceBase::Compute(OpKernelContext* ctx) const {
FillVectorsFromInput(*ctx->Input<Tensor>(1), *ctx->Input<Tensor>(2), ctx->Input<Tensor>(3),
ctx->Input<Tensor>(4), input_starts, input_ends, input_axes, input_steps);
ORT_RETURN_IF_ERROR(PrepareForCompute(input_starts, input_ends, input_axes, input_steps,
input_dimensions, starts, steps, output_dims,
p_flattened_output_dims));
ORT_RETURN_IF_ERROR(PrepareForCompute(input_starts, input_ends, input_axes, input_steps, compute_metadata));
}
// Slice V1-9
else {
ORT_RETURN_IF_ERROR(PrepareForCompute(attr_starts_, attr_ends_, attr_axes_,
input_dimensions, starts, steps, output_dims,
p_flattened_output_dims));
ORT_RETURN_IF_ERROR(PrepareForCompute(attr_starts_, attr_ends_, attr_axes_, compute_metadata));
}
Status status = Status::OK();
if (input_tensor.IsDataTypeString()) {
status = SliceImpl<std::string>(ctx, input_tensor, output_dims, p_flattened_output_dims, starts, steps);
status = SliceImpl<std::string>(ctx, input_tensor, compute_metadata);
} else {
const auto element_size = input_tensor.DataType()->Size();
switch (element_size) {
case sizeof(uint32_t):
status = SliceImpl<uint32_t>(ctx, input_tensor, output_dims, p_flattened_output_dims, starts, steps);
status = SliceImpl<uint32_t>(ctx, input_tensor, compute_metadata);
break;
case sizeof(uint64_t):
status = SliceImpl<uint64_t>(ctx, input_tensor, output_dims, p_flattened_output_dims, starts, steps);
status = SliceImpl<uint64_t>(ctx, input_tensor, compute_metadata);
break;
case sizeof(uint16_t):
status = SliceImpl<uint16_t>(ctx, input_tensor, output_dims, p_flattened_output_dims, starts, steps);
status = SliceImpl<uint16_t>(ctx, input_tensor, compute_metadata);
break;
case sizeof(uint8_t):
status = SliceImpl<uint8_t>(ctx, input_tensor, output_dims, p_flattened_output_dims, starts, steps);
status = SliceImpl<uint8_t>(ctx, input_tensor, compute_metadata);
break;
default:
ORT_THROW("Unsupported input data type of ", input_tensor.DataType());

View file

@ -7,7 +7,52 @@
namespace onnxruntime {
namespace SliceOp {
// Bundles the state that the Slice "prepare" step produces.
//
// On construction, for an input with N dimensions:
//   - starts_      is N zeros,
//   - steps_       is N ones,
//   - output_dims_ is a copy of the input dimensions,
//   - flattened_output_dims_ is empty and p_flattened_output_dims_ points at it.
//
// NOTE: input_dimensions_ is a reference, not a copy. The vector passed to the constructor
// must outlive this object (do not pass a temporary).
struct PrepareForComputeMetadata {
  PrepareForComputeMetadata() = delete;
  PrepareForComputeMetadata(const std::vector<int64_t>& input_dimensions)
      : input_dimensions_(input_dimensions) {
    size_t dimension_count = input_dimensions.size();
    starts_.resize(dimension_count, 0);
    steps_.resize(dimension_count, 1);
    output_dims_ = input_dimensions;
  }
  // Copying must re-target the self-pointer: the implicitly generated copy would leave
  // p_flattened_output_dims_ pointing into `other`, which dangles once `other` is destroyed.
  // A null or externally-targeted pointer is copied as is.
  PrepareForComputeMetadata(const PrepareForComputeMetadata& other)
      : input_dimensions_(other.input_dimensions_),
        starts_(other.starts_),
        steps_(other.steps_),
        output_dims_(other.output_dims_),
        flattened_output_dims_(other.flattened_output_dims_),
        p_flattened_output_dims_(other.p_flattened_output_dims_ == &other.flattened_output_dims_
                                     ? &flattened_output_dims_
                                     : other.p_flattened_output_dims_) {
  }
  // Dimensions of the tensor being sliced (non-owning reference).
  const std::vector<int64_t>& input_dimensions_;
  // Per-axis start index, initialized to 0.
  std::vector<int64_t> starts_;
  // Per-axis step, initialized to 1.
  std::vector<int64_t> steps_;
  // Per-axis output size, initialized to the input dimensions.
  std::vector<int64_t> output_dims_;
  // Storage for the flattened output dimensions.
  std::vector<int64_t> flattened_output_dims_;
  // Initially points at flattened_output_dims_. It is handed to FlattenOutputDims as a
  // pointer reference, and consumers test it against nullptr before using the flattened dims.
  std::vector<int64_t>* p_flattened_output_dims_ = &flattened_output_dims_;
};
} // namespace SliceOp
class SliceBase {
// static methods that can be used from other ops if needed
public:
// compute output_dims without steps (Slice V1-9 & DynamicSlice)
static Status PrepareForCompute(const std::vector<int64_t>& raw_starts,
const std::vector<int64_t>& raw_ends,
const std::vector<int64_t>& raw_axes,
SliceOp::PrepareForComputeMetadata& compute_metadata);
// compute output_dims with steps (Slice V10)
static Status PrepareForCompute(const std::vector<int64_t>& raw_starts,
const std::vector<int64_t>& raw_ends,
const std::vector<int64_t>& raw_axes,
const std::vector<int64_t>& raw_steps,
SliceOp::PrepareForComputeMetadata& compute_metadata);
// Slice V10 & DynamicSlice
static void FillVectorsFromInput(const Tensor& start_tensor,
const Tensor& ends_tensor,
const Tensor* axes_tensor,
const Tensor* steps_tensor,
std::vector<int64_t>& input_starts,
std::vector<int64_t>& input_ends,
std::vector<int64_t>& input_axes,
std::vector<int64_t>& input_steps);
protected:
SliceBase(const OpKernelInfo& info, bool dynamic = false)
: dynamic_(dynamic) {
@ -22,37 +67,6 @@ class SliceBase {
}
}
// compute output_dims without steps (Slice V1-9 & DynamicSlice)
Status PrepareForCompute(const std::vector<int64_t>& raw_starts,
const std::vector<int64_t>& raw_ends,
const std::vector<int64_t>& raw_axes,
const std::vector<int64_t>& input_dimensions,
std::vector<int64_t>& starts,
std::vector<int64_t>& steps,
std::vector<int64_t>& output_dims,
std::vector<int64_t>*& flattened_output_dims) const;
// compute output_dims with steps (Slice V10)
Status PrepareForCompute(const std::vector<int64_t>& raw_starts,
const std::vector<int64_t>& raw_ends,
const std::vector<int64_t>& raw_axes,
const std::vector<int64_t>& raw_steps,
const std::vector<int64_t>& input_dimensions,
std::vector<int64_t>& starts,
std::vector<int64_t>& steps,
std::vector<int64_t>& output_dims,
std::vector<int64_t>*& flattened_output_dims) const;
// Slice V10 & DynamicSlice
void FillVectorsFromInput(const Tensor& start_tensor,
const Tensor& ends_tensor,
const Tensor* axes_tensor,
const Tensor* steps_tensor,
std::vector<int64_t>& input_starts,
std::vector<int64_t>& input_ends,
std::vector<int64_t>& input_axes,
std::vector<int64_t>& input_steps) const;
Status Compute(OpKernelContext* context) const;
protected:

View file

@ -1372,7 +1372,7 @@ static bool RNNNeedFallbackToCPU(const onnxruntime::Node& node,
return false;
}
static bool ConvNeedFallbackToCPU(const onnxruntime::Node& node) {
static bool ConvTransposeNeedFallbackToCPU(const onnxruntime::Node& node) {
const auto& node_attributes = node.GetAttributes();
// Check attributes
for (auto& attr : node_attributes) {
@ -1380,6 +1380,8 @@ static bool ConvNeedFallbackToCPU(const onnxruntime::Node& node) {
auto attr_value = attr.second;
//cudnn only supports symmetric padding
// TODO: Check if we can adopt a similar approach to deal with asymmetric pads in 'ConvTranspose'
// as we did for 'Conv' to circumvent the cudnn limitation
if ("pads" == attr_name && ::ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_INTS == attr_value.type()) {
auto& pads = attr_value.ints();
int pads_size = pads.size();
@ -1468,8 +1470,8 @@ CUDAExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph,
std::vector<std::string> activations_supported{"sigmoid", "tanh", "sigmoid", "tanh"};
not_supported = RNNNeedFallbackToCPU(node, activations_supported, node.OpType());
force_inside = !not_supported;
} else if ("Conv" == node.OpType()) {
not_supported = ConvNeedFallbackToCPU(node);
} else if ("ConvTranspose" == node.OpType()) {
not_supported = ConvTransposeNeedFallbackToCPU(node);
force_inside = !not_supported;
} else if ("Cast" == node.OpType()) {
not_supported = CastNeedFallbackToCPU(node);

View file

@ -5,6 +5,7 @@
#include "core/providers/cuda/cuda_common.h"
#include "core/providers/cuda/nn/conv.h"
#include "core/providers/cuda/shared_inc/fpgeneric.h"
#include "core/providers/cuda/tensor/slice.h"
namespace onnxruntime {
namespace cuda {
@ -33,6 +34,24 @@ REGISTER_KERNEL_TYPED(float)
REGISTER_KERNEL_TYPED(double)
REGISTER_KERNEL_TYPED(MLFloat16)
// Slices `input_data` (shaped `input_dims`) into `output_data` using the given
// `starts` / `ends` / `axes`, via the shared Slice preparation logic and SliceCuda::Impl.
//
// `output_dims` is the shape the caller expects the slice to produce; a mismatch with the
// shape computed by the Slice preparation step trips ORT_ENFORCE.
// `element_size` is forwarded unchanged to SliceCuda::Impl.
//
// Returns the error from SliceBase::PrepareForCompute if preparation fails, otherwise the
// status of SliceCuda::Impl.
static Status SliceOutUnwantedOutputSection(const void* input_data,
                                            const std::vector<int64_t>& input_dims,
                                            void* output_data,
                                            const std::vector<int64_t>& output_dims,
                                            std::vector<int64_t> starts,
                                            const std::vector<int64_t>& ends,
                                            const std::vector<int64_t>& axes,
                                            size_t element_size) {
  SliceOp::PrepareForComputeMetadata compute_metadata(input_dims);
  // Propagate preparation failures instead of silently slicing with default metadata
  ORT_RETURN_IF_ERROR(SliceBase::PrepareForCompute(starts, ends, axes, compute_metadata));
  // As a sanity check, ensure that the slice operator's output shape matches with the expected output shape
  ORT_ENFORCE(compute_metadata.output_dims_ == output_dims);
  return SliceCuda::Impl(input_data, input_dims, output_data, compute_metadata, element_size);
}
template <typename T>
Status Conv<T>::ComputeInternal(OpKernelContext* context) const {
typedef typename ToCudaType<T>::MappedType CudaT;
@ -52,6 +71,18 @@ Status Conv<T>::ComputeInternal(OpKernelContext* context) const {
CudaT* y_data = nullptr;
AllocatorPtr allocator;
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
size_t element_size = X->DataType()->Size();
Tensor* Y = nullptr;
// We may have to write the CuDNN Conv results to a temporary buffer when we deal with
// asymmetric padding as we have to take the results written to this temporary buffer and slice out
// extraneous portions of the result
IAllocatorUniquePtr<void> memory_for_cudnn_conv_results;
{
std::lock_guard<OrtMutex> lock(s_.mutex);
// TODO: add a global cache if need to handle cases for multiple frames running simultaneously with different batch_size
@ -88,15 +119,46 @@ Status Conv<T>::ComputeInternal(OpKernelContext* context) const {
}
std::vector<int64_t> y_dims;
y_dims.reserve(2 + rank); // rank indicates number of feature dimensions - so add 2 to account for 'N' and 'C'
y_dims.insert(y_dims.begin(), {N, M});
ORT_RETURN_IF_ERROR(conv_attrs_.InferOutputShape(x_shape.Slice(2), kernel_shape,
strides, dilations, pads, y_dims, true));
std::vector<int64_t> y_dims_with_adjusted_pads;
y_dims_with_adjusted_pads.reserve(2 + rank); // rank indicates number of feature dimensions - so add 2 to account for 'N' and 'C'
y_dims_with_adjusted_pads.insert(y_dims_with_adjusted_pads.begin(), {N, M});
bool post_slicing_required = false;
std::vector<int64_t> slice_starts;
slice_starts.reserve(rank);
std::vector<int64_t> slice_ends;
slice_ends.reserve(rank);
std::vector<int64_t> slice_axes;
slice_axes.reserve(rank);
ORT_RETURN_IF_ERROR(conv_attrs_.InferOutputShapeWithAdjustedPads(x_shape.Slice(2), kernel_shape,
strides, dilations, pads, y_dims, y_dims_with_adjusted_pads,
post_slicing_required, slice_starts, slice_ends, slice_axes));
ORT_ENFORCE(y_dims.size() == y_dims_with_adjusted_pads.size());
s_.y_dims = y_dims;
Tensor* Y = context->Output(0, TensorShape(s_.y_dims));
y_data = reinterpret_cast<CudaT*>(Y->template MutableData<T>());
s_.y_dims_with_adjusted_pads = y_dims_with_adjusted_pads;
s_.post_slicing_required = post_slicing_required;
s_.slice_starts = slice_starts;
s_.slice_ends = slice_ends;
s_.slice_axes = slice_axes;
Y = context->Output(0, TensorShape(s_.y_dims));
if (!post_slicing_required) {
// No post slicing needed. Fill the output tensor's buffer directly.
y_data = reinterpret_cast<CudaT*>(Y->template MutableData<T>());
} else {
// Post slicing needed. Create and fill in the Conv results in an intermediate buffer.
memory_for_cudnn_conv_results = GetScratchBuffer<void>(TensorShape(y_dims_with_adjusted_pads).Size() * element_size);
y_data = reinterpret_cast<CudaT*>(memory_for_cudnn_conv_results.get());
}
std::vector<int64_t> x_dims_cudnn = x_dims;
std::vector<int64_t> y_dims_cudnn = y_dims;
std::vector<int64_t> y_dims_cudnn = !post_slicing_required ? y_dims : y_dims_with_adjusted_pads;
if (rank < 2) {
// cudnn only takes 4D or 5D input, so pad dimensions if needed
x_dims_cudnn.push_back(1);
@ -173,12 +235,18 @@ Status Conv<T>::ComputeInternal(OpKernelContext* context) const {
}
if (!y_data) {
Tensor* Y = context->Output(0, TensorShape(s_.y_dims));
Y = context->Output(0, TensorShape(s_.y_dims));
// special case when there is a dim value of 0 in the shape.
if (Y->Shape().Size() == 0)
return Status::OK();
y_data = reinterpret_cast<CudaT*>(Y->template MutableData<T>());
if (!s_.post_slicing_required) {
y_data = reinterpret_cast<CudaT*>(Y->template MutableData<T>());
} else {
// Post slicing needed. Create and fill in the Conv results in an intermediate buffer.
memory_for_cudnn_conv_results = GetScratchBuffer<void>(TensorShape(s_.y_dims_with_adjusted_pads).Size() * element_size);
y_data = reinterpret_cast<CudaT*>(memory_for_cudnn_conv_results.get());
}
}
const auto alpha = Consts<CudaT>::One;
@ -203,7 +271,15 @@ Status Conv<T>::ComputeInternal(OpKernelContext* context) const {
if (has_bias) {
const Tensor* B = context->Input<Tensor>(2);
auto b_data = reinterpret_cast<const CudaT*>(B->template Data<T>());
CUDNN_RETURN_IF_ERROR(cudnnAddTensor(CudnnHandle(), &alpha, s_.b_tensor, b_data, &alpha, s_.y_tensor, y_data));
CUDNN_RETURN_IF_ERROR(cudnnAddTensor(CudnnHandle(), &alpha, s_.b_tensor, b_data, &alpha, s_.y_tensor,
y_data));
}
// To deal with asymmetric padding, we may have over-padded on one or both sides of the spatial dimensions
// This may have led to extra results that are unnecessary, and hence we slice them off here
if (s_.post_slicing_required) {
SliceOutUnwantedOutputSection(y_data, s_.y_dims_with_adjusted_pads, Y->MutableDataRaw(),
s_.y_dims, s_.slice_starts, s_.slice_ends, s_.slice_axes, element_size);
}
}

View file

@ -88,7 +88,7 @@ class lru_unordered_map {
lru_list_.clear();
}
private:
private:
using list_type = std::list<Key, ListAllocator>;
using iterator_type = typename list_type::iterator;
struct value_type {
@ -117,6 +117,7 @@ struct CudnnConvState {
// these would be recomputed if x/w dims change
std::vector<int64_t> y_dims;
std::vector<int64_t> y_dims_with_adjusted_pads;
size_t workspace_bytes;
decltype(AlgoPerfType().algo) algo;
CudnnTensor x_tensor;
@ -126,12 +127,18 @@ struct CudnnConvState {
CudnnConvolutionDescriptor conv_desc;
struct PerfResultParams {
decltype(AlgoPerfType().algo) algo;
decltype(AlgoPerfType().memory) memory;
decltype(AlgoPerfType().algo) algo;
decltype(AlgoPerfType().memory) memory;
decltype(AlgoPerfType().mathType) mathType;
};
lru_unordered_map<std::vector<int64_t>, PerfResultParams, vector_hash<int64_t>> cached_benchmark_results { MAX_CACHED_ALGO_PERF_RESULTS };
lru_unordered_map<std::vector<int64_t>, PerfResultParams, vector_hash<int64_t>> cached_benchmark_results{MAX_CACHED_ALGO_PERF_RESULTS};
// Some properties needed to support asymmetric padded Conv nodes
bool post_slicing_required;
std::vector<int64_t> slice_starts;
std::vector<int64_t> slice_ends;
std::vector<int64_t> slice_axes;
// note that conv objects are shared between execution frames, and a lock is needed to avoid multi-thread racing
OrtMutex mutex;
@ -147,10 +154,6 @@ class Conv : public CudaKernel {
Conv(const OpKernelInfo& info) : CudaKernel(info), conv_attrs_(info) {
auto pads_size = conv_attrs_.pads.size();
ORT_ENFORCE(pads_size % 2 == 0);
auto rank = pads_size / 2;
for (size_t i = 0; i < rank; i++) {
ORT_ENFORCE(conv_attrs_.pads[i] == conv_attrs_.pads[i + rank], "cudnn only supports symmetric padding");
}
}
Status ComputeInternal(OpKernelContext* context) const override;

View file

@ -63,46 +63,43 @@ REGISTER_V11_TYPED_SLICE(int32_t)
REGISTER_V11_TYPED_SLICE(int64_t)
REGISTER_V11_TYPED_SLICE(float)
template <bool dynamic>
Status Slice<dynamic>::ComputeInternal(OpKernelContext* ctx) const {
const Tensor* input_tensor = GetSlicedOrUnslicedTensor(ctx);
ORT_ENFORCE(nullptr != input_tensor);
auto& input_dimensions = input_tensor->Shape().GetDims();
// Initialize the starts & ends to the actual tensor shape
size_t dimension_count = input_dimensions.size();
std::vector<int64_t> starts(dimension_count, 0);
std::vector<int64_t> steps(dimension_count, 1);
std::vector<int64_t> output_dims(input_dimensions);
std::vector<int64_t> flattened_output_dims;
std::vector<int64_t>* p_flattened_output_dims = &flattened_output_dims;
if (dynamic) {
std::vector<int64_t> input_starts, input_ends, input_axes, input_steps;
FillInputVectors(ctx, input_starts, input_ends, input_axes, input_steps);
ORT_RETURN_IF_ERROR(PrepareForCompute(input_starts, input_ends, input_axes,
input_steps, input_dimensions, starts, steps, output_dims,
p_flattened_output_dims));
} else {
ORT_RETURN_IF_ERROR(PrepareForCompute(StartsAttribute(), EndsAttribute(), AxesAttribute(),
input_dimensions, starts, steps, output_dims,
p_flattened_output_dims));
// Forwards the prepared slice description to SliceImpl.
// Returns Status::OK() without calling SliceImpl when `output_shape` has zero elements.
static Status SliceImpCore(const void* input_data, void* output_data,
                           size_t element_size, size_t dimension_count,
                           const TArray<int64_t>& starts_buffer, const TArray<int64_t>& steps_buffer,
                           const TArray<int64_t>& input_strides, const TArray<fast_divmod>& output_strides,
                           const TensorShape& output_shape) {
  const auto output_element_count = output_shape.Size();
  // Empty output: nothing to copy
  if (output_element_count == 0) return Status::OK();

  const auto narrowed_dimension_count = gsl::narrow_cast<int32_t>(dimension_count);
  return SliceImpl(element_size, narrowed_dimension_count,
                   starts_buffer, steps_buffer,
                   input_strides, output_strides,
                   input_data, output_data,
                   output_element_count);
}
namespace SliceCuda {
static Status ComputeSliceStrides(const TensorShape& input_shape,
TArray<int64_t>& input_strides,
TArray<fast_divmod>& output_strides,
SliceOp::PrepareForComputeMetadata& compute_metadata) {
const auto& input_dimensions = input_shape.GetDims();
size_t dimension_count = input_dimensions.size();
// if we are able to flatten the output dims we updated 'starts' and 'steps' to match the smaller number of dims.
// update dimension_count to match.
if (p_flattened_output_dims != nullptr) {
dimension_count = flattened_output_dims.size();
if (compute_metadata.p_flattened_output_dims_) {
dimension_count = compute_metadata.p_flattened_output_dims_->size();
}
TArray<int64_t> starts_buffer(starts);
TArray<int64_t> steps_buffer(steps);
TArray<int64_t> input_strides(gsl::narrow_cast<int32_t>(dimension_count));
input_strides.SetSize(gsl::narrow_cast<int32_t>(dimension_count));
const gsl::span<int64_t> input_strides_span = gsl::make_span(input_strides.Data(), input_strides.Size());
if (p_flattened_output_dims != nullptr) {
if (compute_metadata.p_flattened_output_dims_ != nullptr) {
// we were able to flatten the innermost dimensions as they're being copied in full to the output.
// do the same flattening to the innermost input dimensions in order to calculate pitches that match
// the flattened output dimensions.
@ -119,22 +116,82 @@ Status Slice<dynamic>::ComputeInternal(OpKernelContext* ctx) const {
ORT_ENFORCE(TensorPitches::Calculate(input_strides_span, input_dimensions));
}
TensorPitches original_output_strides(p_flattened_output_dims != nullptr ? flattened_output_dims : output_dims);
TArray<fast_divmod> output_strides(gsl::narrow_cast<int32_t>(original_output_strides.size()));
TensorPitches original_output_strides(
compute_metadata.p_flattened_output_dims_ != nullptr ? compute_metadata.flattened_output_dims_ : compute_metadata.output_dims_);
output_strides.SetSize(gsl::narrow_cast<int32_t>(original_output_strides.size()));
for (int32_t i = 0; i < static_cast<int32_t>(original_output_strides.size()); ++i) {
output_strides[i] = fast_divmod(gsl::narrow_cast<int>(original_output_strides[i]));
}
size_t element_size = input_tensor->DataType()->Size();
return Status::OK();
}
ORT_RETURN_IF_ERROR(CallSliceImp(element_size,
Status Impl(const void* input_data,
const TensorShape& input_shape,
void* output_data,
SliceOp::PrepareForComputeMetadata& compute_metadata,
size_t element_size) {
const auto& input_dimensions = input_shape.GetDims();
size_t dimension_count = input_dimensions.size();
TArray<int64_t> starts_buffer(compute_metadata.starts_);
TArray<int64_t> steps_buffer(compute_metadata.steps_);
TArray<int64_t> input_strides;
TArray<fast_divmod> output_strides;
ORT_RETURN_IF_ERROR(ComputeSliceStrides(input_shape, input_strides, output_strides, compute_metadata));
TensorShape output_shape(compute_metadata.output_dims_);
ORT_RETURN_IF_ERROR(SliceImpCore(input_data,
output_data,
element_size,
gsl::narrow_cast<int32_t>(dimension_count),
starts_buffer,
steps_buffer,
input_strides,
output_strides,
ctx,
TensorShape(output_dims)));
output_shape));
return Status::OK();
}
} // namespace SliceCuda
// Kernel entry point for Slice.
//
// Steps:
//   1. Fetch the tensor to slice and reject scalars (rank 0) with INVALID_ARGUMENT.
//   2. Fill a `SliceOp::PrepareForComputeMetadata` from either the runtime inputs
//      (`dynamic == true`) or the node attributes (`dynamic == false`).
//   3. Compute input/output strides with `SliceCuda::ComputeSliceStrides`.
//   4. Dispatch through the virtual `CallSliceImp()`.
//
// Returns the first non-OK status produced by any of the steps, otherwise OK.
// Throws (via ORT_ENFORCE) if no input tensor is available.
template <bool dynamic>
Status Slice<dynamic>::ComputeInternal(OpKernelContext* ctx) const {
  const Tensor* input_tensor = GetSlicedOrUnslicedTensor(ctx);
  ORT_ENFORCE(nullptr != input_tensor);
  const auto& input_shape = input_tensor->Shape();
  const auto& input_dimensions = input_shape.GetDims();
  if (input_dimensions.empty()) return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Cannot slice scalars");

  SliceOp::PrepareForComputeMetadata compute_metadata(input_dimensions);
  if (dynamic) {
    std::vector<int64_t> input_starts, input_ends, input_axes, input_steps;
    FillInputVectors(ctx, input_starts, input_ends, input_axes, input_steps);
    ORT_RETURN_IF_ERROR(PrepareForCompute(input_starts, input_ends, input_axes, input_steps, compute_metadata));
  } else {
    ORT_RETURN_IF_ERROR(PrepareForCompute(StartsAttribute(), EndsAttribute(), AxesAttribute(), compute_metadata));
  }

  TensorShape output_shape(compute_metadata.output_dims_);
  TArray<int64_t> starts_buffer(compute_metadata.starts_);
  TArray<int64_t> steps_buffer(compute_metadata.steps_);
  TArray<int64_t> input_strides;
  TArray<fast_divmod> output_strides;
  ORT_RETURN_IF_ERROR(SliceCuda::ComputeSliceStrides(input_shape, input_strides, output_strides, compute_metadata));

  // When the output dims were flattened, 'starts' and 'steps' (and therefore the stride buffers)
  // describe fewer dimensions than the input tensor has. The dimension count handed to the
  // implementation must match those buffers, so take it from the stride buffer that
  // `ComputeSliceStrides()` sized to the effective rank rather than from the original input rank.
  const size_t dimension_count = gsl::narrow_cast<size_t>(input_strides.Size());

  // `SliceImpCore()` is not called directly: `ComputeInternal()` is shared between the inferencing
  // and training kernels, and the training kernel overrides `CallSliceImp()`.
  ORT_RETURN_IF_ERROR(CallSliceImp(input_tensor->DataType()->Size(), dimension_count, starts_buffer,
                                   steps_buffer, input_strides,
                                   output_strides, ctx,
                                   output_shape));
  return Status::OK();
}
@ -156,21 +213,19 @@ template <bool dynamic>
Status Slice<dynamic>::CallSliceImp(size_t element_size, size_t dimension_count, const TArray<int64_t>& starts_buffer,
const TArray<int64_t>& steps_buffer, const TArray<int64_t>& input_strides,
const TArray<fast_divmod>& output_strides, OpKernelContext* ctx,
TensorShape output_shape) const {
const TensorShape& output_shape) const {
const auto* input_tensor = ctx->Input<Tensor>(0);
auto* output_tensor = ctx->Output(0, output_shape);
if (output_shape.Size() == 0) {
return Status::OK();
}
return SliceImpl(element_size,
gsl::narrow_cast<int32_t>(dimension_count),
starts_buffer,
steps_buffer,
input_strides,
output_strides,
ctx->Input<Tensor>(0)->DataRaw(),
output_tensor->MutableDataRaw(),
output_shape.Size());
return SliceImpCore(input_tensor->DataRaw(),
output_tensor->MutableDataRaw(),
element_size,
gsl::narrow_cast<int32_t>(dimension_count),
starts_buffer,
steps_buffer,
input_strides,
output_strides,
output_shape);
}
} // namespace cuda

View file

@ -9,9 +9,19 @@
namespace onnxruntime {
namespace cuda {
namespace SliceCuda {
// Slice entry point that works on raw buffers instead of an OpKernelContext.
//
// input_data       : raw pointer to the data being sliced
// input_shape      : shape of the tensor behind `input_data`
// output_data      : raw pointer receiving the sliced result
// prepare_metadata : slice parameters (starts, steps, output dims, optional flattened
//                    output dims) that the implementation reads.
//                    NOTE(review): presumably this must already have been filled in by
//                    PrepareForCompute() before the call - confirm against callers.
// element_size     : size in bytes of one tensor element
//
// Returns a non-OK Status if stride computation or the slice itself fails.
Status Impl(const void* input_data,
const TensorShape& input_shape,
void* output_data,
SliceOp::PrepareForComputeMetadata& prepare_metadata,
size_t element_size);
} // namespace SliceCuda
template <bool dynamic>
class Slice : public CudaKernel, public SliceBase {
public:
public:
Slice(const OpKernelInfo& info) : CudaKernel(info), SliceBase(info, dynamic) {}
Status ComputeInternal(OpKernelContext* ctx) const override;
@ -25,7 +35,7 @@ public:
virtual Status CallSliceImp(size_t element_size, size_t dimension_count, const TArray<int64_t>& starts_buffer,
const TArray<int64_t>& steps_buffer, const TArray<int64_t>& input_strides,
const TArray<fast_divmod>& output_strides, OpKernelContext* ctx,
TensorShape output_shape) const;
const TensorShape& output_shape) const;
};
} // namespace cuda
} // namespace onnxruntime

View file

@ -200,7 +200,6 @@ TEST(ConvTest, Conv2D_1) {
auto expected_vals = {-0.012311071157455444f, 0.02822777070105076f, -0.028432954102754593f, -0.037657227367162704f,
-0.04396762326359749f, 0.10081233829259872f, -0.10154513269662857f, -0.13448859751224518f};
attrs.excluded_providers.insert(kCudaExecutionProvider); // asymmetric padding is not supported by cudnn
TestConvOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape);
// NNAPI EP requires weight to be an initializer
@ -372,8 +371,6 @@ TEST(ConvTest, Conv2D_Bias_2) {
auto expected_vals = {-0.3419531583786011f, -0.6116723418235779f, -0.39677709341049194f, -0.7316848039627075f,
-0.5647197365760803f, 0.02788025140762329f, -0.30450713634490967f, -0.6786775588989258f};
attrs.excluded_providers.insert(kCudaExecutionProvider); // asymmetric padding is not supported by cudnn
TestConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape);
TestConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, true);
@ -655,5 +652,30 @@ TEST(ConvTest, ConvDimWithZero) {
TestConvOp(attrs, {X, W}, {X_shape, W_shape}, {}, out_shape, false, OpTester::ExpectResult::kExpectSuccess, "", 10);
}
// 1-D convolution whose padding differs at the two ends of the spatial axis
// (one padded element at the start, none at the end). No execution provider
// is excluded for this case.
TEST(ConvTest, Conv1D_asymmetric_padding) {
  // Input: batch 1, 1 channel, 3 samples.
  const vector<int64_t> input_shape = {1, 1, 3};
  const vector<float> input = {1.f, 2.f, 3.f};

  // Weights: a single all-ones filter of width 3.
  const vector<int64_t> weight_shape = {1, 1, 3};
  const vector<float> weight = {1.f, 1.f, 1.f};

  // Bias: zero, so it does not affect the expected values.
  const vector<int64_t> bias_shape = {1};
  const vector<float> bias = {0.f};

  ConvOpAndTestAttributes conv_attrs = {
      "",                     // auto_pad
      vector<int64_t>{1},     // dilations
      1,                      // group
      vector<int64_t>{3},     // kernel_shape
      vector<int64_t>{1, 0},  // pads
      vector<int64_t>{1},     // strides
      {}                      // excluded EPs
  };

  // The padded input is [0, 1, 2, 3]; the two windows sum to 0+1+2 = 3 and 1+2+3 = 6.
  const vector<int64_t> output_shape = {1, 1, 2};
  auto expected_output = {3.f, 6.f};

  // Run once with the default trailing argument and once with it set to true.
  TestConvOp(conv_attrs, {input, weight, bias}, {input_shape, weight_shape, bias_shape},
             expected_output, output_shape);
  TestConvOp(conv_attrs, {input, weight, bias}, {input_shape, weight_shape, bias_shape},
             expected_output, output_shape, true);
}
} // namespace test
} // namespace onnxruntime

View file

@ -28,11 +28,6 @@ Status SliceGrad::Compute(OpKernelContext* context) const {
Tensor& output = *context->Output(0, data_shape);
memset(output.MutableDataRaw(), 0, output.SizeInBytes());
// Initialize the starts & ends to the actual tensor shape
std::vector<int64_t> starts(data_shape.GetDims().size(), 0);
std::vector<int64_t> steps(data_shape.GetDims().size(), 1);
std::vector<int64_t> output_dims(data_shape.GetDims());
std::vector<int64_t> flattened_output_dims;
std::vector<int64_t>* p_flattened_output_dims = &flattened_output_dims;
std::vector<int64_t> input_starts;
std::vector<int64_t> input_ends;
std::vector<int64_t> input_axes;
@ -40,17 +35,18 @@ Status SliceGrad::Compute(OpKernelContext* context) const {
FillVectorsFromInput(*context->Input<Tensor>(2), *context->Input<Tensor>(3), context->Input<Tensor>(4),
context->Input<Tensor>(5), input_starts, input_ends, input_axes, input_steps);
ORT_RETURN_IF_ERROR(PrepareForCompute(input_starts, input_ends, input_axes, input_steps,
data_shape.GetDims(), starts, steps, output_dims,
p_flattened_output_dims));
SliceOp::PrepareForComputeMetadata compute_metadata(data_shape.GetDims());
ORT_RETURN_IF_ERROR(PrepareForCompute(input_starts, input_ends, input_axes, input_steps, compute_metadata));
MLDataType T_type = grad.DataType();
if (T_type == DataTypeImpl::GetType<float>()) {
return ComputeImpl<float>(context, output, output_dims, p_flattened_output_dims, starts, steps);
return ComputeImpl<float>(context, output, compute_metadata.output_dims_, compute_metadata.p_flattened_output_dims_,
compute_metadata.starts_, compute_metadata.steps_);
}
if (T_type == DataTypeImpl::GetType<double>()) {
return ComputeImpl<double>(context, output, output_dims, p_flattened_output_dims, starts, steps);
return ComputeImpl<double>(context, output, compute_metadata.output_dims_, compute_metadata.p_flattened_output_dims_,
compute_metadata.starts_, compute_metadata.steps_);
}
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Type for T or Tind not supported yet in SliceGrad.");

View file

@ -50,7 +50,7 @@ void SliceGrad::FillInputVectors(OpKernelContext* ctx, std::vector<int64_t>& inp
Status SliceGrad::CallSliceImp(size_t element_size, size_t dimension_count, const TArray<int64_t>& starts_buffer,
const TArray<int64_t>& steps_buffer, const TArray<int64_t>& input_strides,
const TArray<fast_divmod>& output_strides, OpKernelContext* ctx,
TensorShape output_shape) const {
const TensorShape& output_shape) const {
Tensor* gradient_out_tensor = GetOutputGradientTensor(ctx);
CUDA_RETURN_IF_ERROR(cudaMemset(gradient_out_tensor->MutableDataRaw(), 0, gradient_out_tensor->SizeInBytes()));
return SliceImplGrad(element_size,

View file

@ -17,7 +17,7 @@ class SliceGrad final : public Slice<true> {
Status CallSliceImp(size_t element_size, size_t dimension_count, const TArray<int64_t>& starts_buffer,
const TArray<int64_t>& steps_buffer, const TArray<int64_t>& input_strides,
const TArray<fast_divmod>& output_strides, OpKernelContext* ctx, TensorShape output_shape)
const TArray<fast_divmod>& output_strides, OpKernelContext* ctx, const TensorShape& output_shape)
const override;
};
} // namespace cuda