diff --git a/aten/src/ATen/native/SegmentReduce.cpp b/aten/src/ATen/native/SegmentReduce.cpp index 1139515b119..3e562b7cf85 100644 --- a/aten/src/ATen/native/SegmentReduce.cpp +++ b/aten/src/ATen/native/SegmentReduce.cpp @@ -8,8 +8,10 @@ namespace at { namespace native { -DEFINE_DISPATCH(_segment_reduce_stub); -DEFINE_DISPATCH(_segment_reduce_backward_stub); +DEFINE_DISPATCH(_segment_reduce_lengths_stub); +DEFINE_DISPATCH(_segment_reduce_offsets_stub); +DEFINE_DISPATCH(_segment_reduce_lengths_backward_stub); +DEFINE_DISPATCH(_segment_reduce_offsets_backward_stub); namespace { @@ -29,8 +31,8 @@ SegmentReductionType get_reduction_enum(const c10::string_view& reduce) { } } -template -void _segment_reduce_cpu_kernel1( +template +void _segment_reduce_lengths_cpu_kernel1( SegmentReductionType reduction, const Tensor& data, const T* lengths_data, @@ -46,14 +48,30 @@ void _segment_reduce_cpu_kernel1( outer_offset *= output.size(d); for (int64_t d = axis + 1; d < output.dim(); d++) inner_offset *= output.size(d); + int64_t lengths_size_axis = is_offsets_like ? segment_count + 1 : segment_count; + auto data_stride_axis = data.stride(axis); + auto data_size_axis = data.size(axis); + auto output_stride_axis = output.stride(axis); + auto output_size_axis = output.size(axis); AT_DISPATCH_FLOATING_TYPES_AND2( kBFloat16, kHalf, data.scalar_type(), "_segment_reduce_cpu", [&]() { auto* output_data = output.data_ptr(); const auto* values_data = data.data_ptr(); for (const auto outer_idx : c10::irange(outer_offset)) { - int64_t lengths_cum_sum = 0; + int64_t segment_start, segment_length; + int64_t segment_end = is_offsets_like ? + lengths_data[outer_idx * lengths_stride_axis * lengths_size_axis] : + 0; for (const auto dim_idx : c10::irange(segment_count)) { - int64_t segment_length = lengths_data[outer_idx * lengths_stride_axis * segment_count + dim_idx]; + segment_start = segment_end; + auto lengths_idx = outer_idx * lengths_stride_axis * lengths_size_axis + dim_idx; + if (is_offsets_like) { + segment_end = lengths_data[lengths_idx + 1]; + segment_length = segment_end - segment_start; + } else { + segment_length = lengths_data[lengths_idx]; + segment_end += segment_length; + } for (const auto inner_idx : c10::irange(inner_offset)) { // ===== step1: initialize starting value scalar_t initial_value; @@ -72,9 +90,9 @@ void _segment_reduce_cpu_kernel1( } // ===== step2: apply reduction - for (const auto j : c10::irange(segment_length)) { - int64_t data_index = outer_idx * data.stride(axis) * data.size(axis) - + (lengths_cum_sum + j) * data.stride(axis) + inner_idx; + for (const auto j : c10::irange(segment_start, segment_end)) { + int64_t data_index = outer_idx * data_stride_axis * data_size_axis + + j * data_stride_axis + inner_idx; const auto val = values_data[data_index]; if (reduction == SegmentReductionType::MAX) { initial_value = at::_isnan(val) @@ -104,17 +122,16 @@ void _segment_reduce_cpu_kernel1( segment_length > 0 && !at::_isnan(initial_value)) { initial_value = initial_value / segment_length; } - int64_t output_index = outer_idx * output.stride(axis) * output.size(axis) - + dim_idx * output.stride(axis) + inner_idx; + int64_t output_index = outer_idx * output_stride_axis * output_size_axis + + dim_idx * output_stride_axis + inner_idx; output_data[output_index] = initial_value; } - lengths_cum_sum += segment_length; } } }); } -Tensor _segment_reduce_cpu_kernel( +Tensor _segment_reduce_lengths_cpu_kernel( SegmentReductionType reduction, const Tensor& data, const Tensor& lengths, @@ -131,17 +148,43 @@ 
Tensor _segment_reduce_cpu_kernel( output_shape[axis] = segment_count; auto output = at::empty(output_shape, data.options()); - AT_DISPATCH_INDEX_TYPES(lengths.scalar_type(), "_segment_reduce_cpu_kernel1", [&]() { + AT_DISPATCH_INDEX_TYPES(lengths.scalar_type(), "_segment_reduce_lengths_cpu_kernel1", [&]() { const auto* lengths_data = lengths.data_ptr(); - _segment_reduce_cpu_kernel1( + _segment_reduce_lengths_cpu_kernel1( reduction, data, lengths_data, axis, initial, output, segment_count, lengths_stride_axis); }); return output; } -template -void _segment_reduce_cpu_backward_kernel1( +Tensor _segment_reduce_offsets_cpu_kernel( + SegmentReductionType reduction, + const Tensor& data, + const Tensor& offsets, + int64_t axis, + const c10::optional& initial) { + // data and lengths should be contiguous from the call to .contiguous in segment_reduce_kernel + TORCH_CHECK(data.is_contiguous(), "Expected data to be contiguous."); + TORCH_CHECK(offsets.is_contiguous(), "Expected offsets to be contiguous."); + // reduction axis should always be the last dimension of lengths + axis = offsets.dim() - 1; + int64_t segment_count = offsets.size(axis) - 1; + int64_t offsets_stride_axis = offsets.stride(axis); + auto output_shape = data.sizes().vec(); + output_shape[axis] = segment_count; + auto output = at::empty(output_shape, data.options()); + + AT_DISPATCH_INDEX_TYPES(offsets.scalar_type(), "_segment_reduce_offsets_cpu_kernel1", [&]() { + const auto* offsets_data = offsets.data_ptr(); + _segment_reduce_lengths_cpu_kernel1( + reduction, data, offsets_data, axis, initial, output, segment_count, offsets_stride_axis); + }); + + return output; +} + +template +void _segment_reduce_cpu_lengths_backward_kernel1( const Tensor& grad_contig, const Tensor& output_contig, const Tensor& data_contig, @@ -159,7 +202,12 @@ void _segment_reduce_cpu_backward_kernel1( outer_offset *= output_contig.size(d); for (int64_t d = axis + 1; d < output_contig.dim(); d++) inner_offset *= output_contig.size(d); - // TODO: Swtich to TensorIterator for better maintainablility and + int64_t lengths_size_axis = is_offsets_like ? segment_count + 1 : segment_count; + auto data_stride_axis = data_contig.stride(axis); + auto data_size_axis = data_contig.size(axis); + auto output_stride_axis = output_contig.stride(axis); + auto output_size_axis = output_contig.size(axis); + // TODO: Switch to TensorIterator for better maintainablility and // readability AT_DISPATCH_FLOATING_TYPES_AND2( kBFloat16, @@ -182,21 +230,34 @@ void _segment_reduce_cpu_backward_kernel1( } for (const auto outer_idx : c10::irange(outer_offset)) { - int64_t lengths_cum_sum = 0; + // int64_t lengths_cum_sum = 0; + int64_t segment_start, segment_length; + int64_t segment_end = is_offsets_like ? 
+ lengths_data[outer_idx * lengths_stride_axis * lengths_size_axis] : + 0; for (const auto dim_idx : c10::irange(segment_count)) { - int64_t segment_length = lengths_data[outer_idx * lengths_stride_axis * segment_count + dim_idx]; + // int64_t segment_length = lengths_data[outer_idx * lengths_stride_axis * segment_count + dim_idx]; + segment_start = segment_end; + auto lengths_idx = outer_idx * lengths_stride_axis * lengths_size_axis + dim_idx; + if (is_offsets_like) { + segment_end = lengths_data[lengths_idx + 1]; + segment_length = segment_end - segment_start; + } else { + segment_length = lengths_data[lengths_idx]; + segment_end += segment_length; + } if (segment_length == 0) { continue; } for (const auto inner_idx : c10::irange(inner_offset)) { - int64_t output_index = outer_idx * output_contig.stride(axis) * output_contig.size(axis) - + dim_idx * output_contig.stride(axis) + inner_idx; + int64_t output_index = outer_idx * output_stride_axis * output_size_axis + + dim_idx * output_stride_axis + inner_idx; if (reduction == SegmentReductionType::MAX || reduction == SegmentReductionType::MIN) { int64_t counter = 0; - for (const auto j : c10::irange(segment_length)) { - int64_t data_index = outer_idx * data_contig.stride(axis) * data_contig.size(axis) - + (lengths_cum_sum + j) * data_contig.stride(axis) + inner_idx; + for (const auto j : c10::irange(segment_start, segment_end)) { + int64_t data_index = outer_idx * data_stride_axis * data_size_axis + + j * data_stride_axis + inner_idx; if (at::_isnan(values_data[data_index]) || values_data[data_index] == output_data[output_index]) { grad_input_data[data_index] = grad_data[output_index]; @@ -208,9 +269,9 @@ void _segment_reduce_cpu_backward_kernel1( if (counter < 2) { continue; } - for (const auto j : c10::irange(segment_length)) { - int64_t data_index = outer_idx * data_contig.stride(axis) * data_contig.size(axis) - + (lengths_cum_sum + j) * data_contig.stride(axis) + inner_idx; + for (const auto j : c10::irange(segment_start, segment_end)) { + int64_t data_index = outer_idx * data_stride_axis * data_size_axis + + j * data_stride_axis + inner_idx; if (grad_input_data[data_index] > 0) { grad_input_data[data_index] = grad_input_data[data_index] / counter; @@ -218,32 +279,32 @@ void _segment_reduce_cpu_backward_kernel1( } } else if (reduction == SegmentReductionType::MEAN) { auto grad_val = grad_data[output_index] / segment_length; - for (const auto j : c10::irange(segment_length)) { - int64_t data_index = outer_idx * data_contig.stride(axis) * data_contig.size(axis) - + (lengths_cum_sum + j) * data_contig.stride(axis) + inner_idx; + for (const auto j : c10::irange(segment_start, segment_end)) { + int64_t data_index = outer_idx * data_stride_axis * data_size_axis + + j * data_stride_axis + inner_idx; grad_input_data[data_index] = grad_val; } } else if (reduction == SegmentReductionType::SUM) { const auto& grad_val = grad_data[output_index]; - for (const auto j : c10::irange(segment_length)) { - int64_t data_index = outer_idx * data_contig.stride(axis) * data_contig.size(axis) - + (lengths_cum_sum + j) * data_contig.stride(axis) + inner_idx; + for (const auto j : c10::irange(segment_start, segment_end)) { + int64_t data_index = outer_idx * data_stride_axis * data_size_axis + + j * data_stride_axis + inner_idx; grad_input_data[data_index] = grad_val; } } else if (reduction == SegmentReductionType::PROD) { const auto& grad_val = grad_data[output_index] * output_data[output_index]; - for (const auto j : c10::irange(segment_length)) { - int64_t 
data_index = outer_idx * data_contig.stride(axis) * data_contig.size(axis) - + (lengths_cum_sum + j) * data_contig.stride(axis) + inner_idx; + for (const auto j : c10::irange(segment_start, segment_end)) { + int64_t data_index = outer_idx * data_stride_axis * data_size_axis + + j * data_stride_axis + inner_idx; if (at::_isnan(values_data[data_index]) || values_data[data_index] == 0) { // explicitly compute exclusive prod scalar_t exclusive_prod = initial_prod_value; int64_t idx; - for (const auto k : c10::irange(segment_length)) { + for (const auto k : c10::irange(segment_start, segment_end)) { if (k != j) { - idx = outer_idx * data_contig.stride(axis) * data_contig.size(axis) - + (lengths_cum_sum + k) * data_contig.stride(axis) + inner_idx; + idx = outer_idx * data_stride_axis * data_size_axis + + k * data_stride_axis + inner_idx; exclusive_prod *= values_data[idx]; } } @@ -254,13 +315,12 @@ void _segment_reduce_cpu_backward_kernel1( } } } - lengths_cum_sum += segment_length; } } }); } -Tensor _segment_reduce_cpu_backward_kernel( +Tensor _segment_reduce_cpu_lengths_backward_kernel( const Tensor& grad_contig, const Tensor& output_contig, const Tensor& data_contig, @@ -274,9 +334,9 @@ Tensor _segment_reduce_cpu_backward_kernel( auto grad_input = at::zeros({data_contig.sizes()}, grad_contig.options()); AT_DISPATCH_INDEX_TYPES( - lengths_contig.scalar_type(), "_segment_reduce_cpu_backward_kernel1", [&] { + lengths_contig.scalar_type(), "_segment_reduce_cpu_lengths_backward_kernel1", [&] { const auto* lengths_data = lengths_contig.data_ptr(); - _segment_reduce_cpu_backward_kernel1( + _segment_reduce_cpu_lengths_backward_kernel1( grad_contig, output_contig, data_contig, @@ -292,6 +352,39 @@ Tensor _segment_reduce_cpu_backward_kernel( return grad_input; } + +Tensor _segment_reduce_cpu_offsets_backward_kernel( + const Tensor& grad_contig, + const Tensor& output_contig, + const Tensor& data_contig, + SegmentReductionType reduction, + const Tensor& offsets_contig, + int64_t axis, + const c10::optional& initial) { + axis = offsets_contig.dim() - 1; + int64_t segment_count = offsets_contig.size(axis) - 1; + int64_t offsets_stride_axis = offsets_contig.stride(axis); + auto grad_input = at::zeros({data_contig.sizes()}, grad_contig.options()); + + AT_DISPATCH_INDEX_TYPES( + offsets_contig.scalar_type(), "_segment_reduce_cpu_offsets_backward_kernel1", [&] { + const auto* offsets_data = offsets_contig.data_ptr(); + _segment_reduce_cpu_lengths_backward_kernel1( + grad_contig, + output_contig, + data_contig, + reduction, + offsets_data, + axis, + initial, + grad_input, + segment_count, + offsets_stride_axis); + }); + + return grad_input; +} + } // namespace Tensor segment_reduce_kernel( @@ -299,49 +392,94 @@ Tensor segment_reduce_kernel( c10::string_view reduce, const c10::optional& lengths, const c10::optional& indices, + const c10::optional& offsets, int64_t axis, bool unsafe, const c10::optional& initial) { axis = maybe_wrap_dim(axis, data.ndimension()); TORCH_CHECK(data.numel() > 0); - // length related checks + // check that one of lengths or offsets is defined + auto lengths_has_value = lengths.has_value(); + auto offsets_has_value = offsets.has_value(); TORCH_CHECK( - lengths.has_value() && !indices.has_value(), - "Currently only lengths based reduction is supported!") - const auto& lengths_value = lengths.value(); - TORCH_CHECK(data.get_device() == lengths_value.get_device()); - TORCH_CHECK(data.dim() >= lengths_value.dim()); - TORCH_CHECK(axis == lengths_value.dim() - 1, "Expected axis to be equal 
to lengths.ndim() - 1 but got ", axis, "."); - - if (!unsafe) { - auto min_length = lengths_value.min().item(); - TORCH_CHECK((min_length >= 0), "lengths contains negative value!"); - TORCH_CHECK(all(lengths_value.sum({-1}) == data.size(axis)).item(), - "Expected all rows of lengths to sum to data.size(lengths.dim()-1) when unsafe=False"); - } + !indices.has_value(), + "segment_reduce(): indices based reduction is not supported yet."); + TORCH_CHECK( + lengths_has_value || offsets_has_value, + "segment_reduce(): Either lengths or offsets must be defined.") auto reduction = get_reduction_enum(reduce); const auto data_contig = data.contiguous(); - const auto lengths_contig = lengths_value.contiguous(); - return _segment_reduce_stub( + if (offsets_has_value) { + const auto& offsets_value = offsets.value(); + + // offsets related checks + TORCH_CHECK(data.get_device() == offsets_value.get_device()); + TORCH_CHECK(data.dim() >= offsets_value.dim()); + TORCH_CHECK(axis == offsets_value.dim() - 1, + "segment_reduce(): Expected axis to be the last dimension of offsets but got ", axis, "."); + + // TODO: add checks when !unsafe + + const auto offsets_contig = offsets_value.contiguous(); + + return _segment_reduce_offsets_stub( + data_contig.device().type(), + reduction, + data_contig, + offsets_contig, + axis, + initial); + + } else { + const auto& lengths_value = lengths.value(); + + // length related checks + TORCH_CHECK(data.get_device() == lengths_value.get_device()); + TORCH_CHECK(data.dim() >= lengths_value.dim()); + TORCH_CHECK(axis == lengths_value.dim() - 1, + "segment_reduce(): Expected axis to be the last dimension of lengths but got ", axis, "."); + + if (!unsafe) { + auto min_length = lengths_value.min().item(); + TORCH_CHECK((min_length >= 0), "lengths contains negative value!"); + TORCH_CHECK(all(lengths_value.sum({-1}) == data.size(axis)).item(), + "segment_reduce(): Expected all rows of lengths along axis ", + "to sum to data.size(lengths.dim()-1) when !unsafe."); + } + + const auto lengths_contig = lengths_value.contiguous(); + + return _segment_reduce_lengths_stub( data_contig.device().type(), reduction, data_contig, lengths_contig, axis, initial); + } } REGISTER_ARCH_DISPATCH( - _segment_reduce_stub, + _segment_reduce_lengths_stub, DEFAULT, - &_segment_reduce_cpu_kernel); -REGISTER_AVX2_DISPATCH(_segment_reduce_stub, &_segment_reduce_cpu_kernel); -REGISTER_AVX512_DISPATCH(_segment_reduce_stub, &_segment_reduce_cpu_kernel); -REGISTER_VSX_DISPATCH(_segment_reduce_stub, &_segment_reduce_cpu_kernel); -REGISTER_ZVECTOR_DISPATCH(_segment_reduce_stub, &_segment_reduce_cpu_kernel); + &_segment_reduce_lengths_cpu_kernel); +REGISTER_AVX2_DISPATCH(_segment_reduce_lengths_stub, &_segment_reduce_lengths_cpu_kernel); +REGISTER_AVX512_DISPATCH(_segment_reduce_lengths_stub, &_segment_reduce_lengths_cpu_kernel); +REGISTER_VSX_DISPATCH(_segment_reduce_lengths_stub, &_segment_reduce_lengths_cpu_kernel); +REGISTER_ZVECTOR_DISPATCH(_segment_reduce_lengths_stub, &_segment_reduce_lengths_cpu_kernel); + +// offsets dispatches +REGISTER_ARCH_DISPATCH( + _segment_reduce_offsets_stub, + DEFAULT, + &_segment_reduce_offsets_cpu_kernel); +REGISTER_AVX2_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cpu_kernel); +REGISTER_AVX512_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cpu_kernel); +REGISTER_VSX_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cpu_kernel); +REGISTER_ZVECTOR_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cpu_kernel); // 
Currently some computation is being duplicated across forward and backward. // TODO: Cache indices in forward pass to re-use in backward @@ -351,21 +489,40 @@ Tensor _segment_reduce_backward_kernel( const Tensor& data, c10::string_view reduce, const c10::optional& lengths, + const c10::optional& offsets, int64_t axis, const c10::optional& initial) { axis = maybe_wrap_dim(axis, data.ndimension()); + // check that one of lengths or offsets is defined + // codegen for derivatives.yaml passes an undefined Tensor for None rather than a c10::optional + // so checking .has_value() doesn't work unlike in the forward pass + auto lengths_has_value = lengths.has_value() && lengths.value().defined(); + auto offsets_has_value = offsets.has_value() && offsets.value().defined(); TORCH_CHECK( - lengths.has_value(), - "Currently only lengths based reduction is supported!") - const auto& lengths_value = lengths.value(); + lengths_has_value || offsets_has_value, + "segment_reduce(): Either lengths or offsets must be defined."); const auto grad_contig = grad.contiguous(); const auto output_contig = output.contiguous(); const auto data_contig = data.contiguous(); - const auto lengths_contig = lengths_value.contiguous(); - auto reduction = get_reduction_enum(reduce); - return _segment_reduce_backward_stub( + + if (offsets_has_value) { + const auto& offsets_value = offsets.value(); + const auto offsets_contig = offsets_value.contiguous(); + return _segment_reduce_offsets_backward_stub( + grad_contig.device().type(), + grad_contig, + output_contig, + data_contig, + reduction, + offsets_contig, + axis, + initial); + } else { + const auto& lengths_value = lengths.value(); + const auto lengths_contig = lengths_value.contiguous(); + return _segment_reduce_lengths_backward_stub( grad_contig.device().type(), grad_contig, output_contig, @@ -374,24 +531,42 @@ Tensor _segment_reduce_backward_kernel( lengths_contig, axis, initial); + } } REGISTER_ARCH_DISPATCH( - _segment_reduce_backward_stub, + _segment_reduce_lengths_backward_stub, DEFAULT, - &_segment_reduce_cpu_backward_kernel); + &_segment_reduce_cpu_lengths_backward_kernel); REGISTER_AVX512_DISPATCH( - _segment_reduce_backward_stub, - &_segment_reduce_cpu_backward_kernel); + _segment_reduce_lengths_backward_stub, + &_segment_reduce_cpu_lengths_backward_kernel); REGISTER_AVX2_DISPATCH( - _segment_reduce_backward_stub, - &_segment_reduce_cpu_backward_kernel); + _segment_reduce_lengths_backward_stub, + &_segment_reduce_cpu_lengths_backward_kernel); REGISTER_VSX_DISPATCH( - _segment_reduce_backward_stub, - &_segment_reduce_cpu_backward_kernel); + _segment_reduce_lengths_backward_stub, + &_segment_reduce_cpu_lengths_backward_kernel); REGISTER_ZVECTOR_DISPATCH( - _segment_reduce_backward_stub, - &_segment_reduce_cpu_backward_kernel); + _segment_reduce_lengths_backward_stub, + &_segment_reduce_cpu_lengths_backward_kernel); + +REGISTER_ARCH_DISPATCH( + _segment_reduce_offsets_backward_stub, + DEFAULT, + &_segment_reduce_cpu_offsets_backward_kernel); +REGISTER_AVX512_DISPATCH( + _segment_reduce_offsets_backward_stub, + &_segment_reduce_cpu_offsets_backward_kernel); +REGISTER_AVX2_DISPATCH( + _segment_reduce_offsets_backward_stub, + &_segment_reduce_cpu_offsets_backward_kernel); +REGISTER_VSX_DISPATCH( + _segment_reduce_offsets_backward_stub, + &_segment_reduce_cpu_offsets_backward_kernel); +REGISTER_ZVECTOR_DISPATCH( + _segment_reduce_offsets_backward_stub, + &_segment_reduce_cpu_offsets_backward_kernel); } // namespace native } // namespace at diff --git 
a/aten/src/ATen/native/SegmentReduce.h b/aten/src/ATen/native/SegmentReduce.h index a7cb5f8881a..7fb1512fd4c 100644 --- a/aten/src/ATen/native/SegmentReduce.h +++ b/aten/src/ATen/native/SegmentReduce.h @@ -11,15 +11,23 @@ namespace native { enum SegmentReductionType { MAX, MEAN, MIN, SUM, PROD}; -using segment_reduce_fn = Tensor (*)( +using segment_reduce_lengths_fn = Tensor (*)( SegmentReductionType, const Tensor&, const Tensor&, int64_t, const c10::optional&); -DECLARE_DISPATCH(segment_reduce_fn, _segment_reduce_stub); +DECLARE_DISPATCH(segment_reduce_lengths_fn, _segment_reduce_lengths_stub); -using segment_reduce_backward_fn = Tensor (*)( +using segment_reduce_offsets_fn = Tensor (*)( + SegmentReductionType, + const Tensor&, + const Tensor&, + int64_t, + const c10::optional&); +DECLARE_DISPATCH(segment_reduce_offsets_fn, _segment_reduce_offsets_stub); + +using segment_reduce_lengths_backward_fn = Tensor (*)( const Tensor&, const Tensor&, const Tensor&, @@ -27,7 +35,17 @@ using segment_reduce_backward_fn = Tensor (*)( const Tensor&, int64_t, const c10::optional&); -DECLARE_DISPATCH(segment_reduce_backward_fn, _segment_reduce_backward_stub); +DECLARE_DISPATCH(segment_reduce_lengths_backward_fn, _segment_reduce_lengths_backward_stub); + +using segment_reduce_offsets_backward_fn = Tensor (*)( + const Tensor&, + const Tensor&, + const Tensor&, + SegmentReductionType, + const Tensor&, + int64_t, + const c10::optional&); +DECLARE_DISPATCH(segment_reduce_offsets_backward_fn, _segment_reduce_offsets_backward_stub); } // namespace native } // namespace at diff --git a/aten/src/ATen/native/cuda/SegmentReduce.cu b/aten/src/ATen/native/cuda/SegmentReduce.cu index ab8571df922..bfaa5cacd66 100644 --- a/aten/src/ATen/native/cuda/SegmentReduce.cu +++ b/aten/src/ATen/native/cuda/SegmentReduce.cu @@ -70,7 +70,7 @@ Tensor _get_complete_sum(const Tensor& lengths) { offsets[0].zero_(); AT_DISPATCH_INDEX_TYPES( - lengths.scalar_type(), "_segment_reduce_cuda_backward_kernel1", ([&] { + lengths.scalar_type(), "_segment_reduce_cuda_lengths_offsets_backward_kernel1", ([&] { auto* lengths_data_ptr = lengths.data_ptr(); auto* offsets_data_ptr = offsets.data_ptr(); at::cuda::cub::inclusive_sum( @@ -278,23 +278,33 @@ __global__ void segment_reduce_backward_kernel( } } // namespace -Tensor _segment_reduce_cuda_backward_kernel( +Tensor _segment_reduce_lengths_offsets_backward_cuda_kernel( const Tensor& grad_contig, const Tensor& output_contig, const Tensor& data_contig, SegmentReductionType reduction, - const Tensor& lengths_contig, + const Tensor& lengths_or_offsets_contig, int64_t axis, - const c10::optional& initial) { - axis = lengths_contig.dim() - 1; - int64_t segment_count = lengths_contig.size(axis); - int64_t lengths_stride_axis = lengths_contig.stride(axis); + const c10::optional& initial, + bool is_offsets_like) { + axis = lengths_or_offsets_contig.dim() - 1; + int64_t segment_count = is_offsets_like ? 
+ lengths_or_offsets_contig.size(axis) - 1 : + lengths_or_offsets_contig.size(axis); + int64_t lengths_stride_axis = lengths_or_offsets_contig.stride(axis); auto grad_input = at::zeros({data_contig.sizes()}, grad_contig.options()); - auto zeros_shape = lengths_contig.sizes().vec(); - zeros_shape[axis] = 1; - auto offsets = at::cat({at::zeros(zeros_shape, lengths_contig.options()), lengths_contig}, axis); - offsets.cumsum_(axis); + auto offsets = lengths_or_offsets_contig; + auto lengths = lengths_or_offsets_contig; + if (is_offsets_like) { + lengths = lengths.diff(); + } else { + // _get_complete_sum only supports 1D + auto zeros_shape = offsets.sizes().vec(); + zeros_shape[axis] = 1; + offsets = at::cat({at::zeros(zeros_shape, offsets.options()), offsets}, axis); + offsets.cumsum_(axis); + } // outer_offset is the size of the outer dimensions of output (before axis) // inner_offset is the size of the inner dimensions of output (after axis) @@ -318,8 +328,8 @@ Tensor _segment_reduce_cuda_backward_kernel( auto offsets_stride_axis = offsets.stride(axis); AT_DISPATCH_INDEX_TYPES( - lengths_contig.scalar_type(), "_segment_reduce_cuda_backward_kernel1", ([&] { - const auto* lengths_data = lengths_contig.data_ptr(); + lengths_or_offsets_contig.scalar_type(), "_segment_reduce_cuda_lengths_offsets_backward_kernel1", ([&] { + const auto* lengths_data = lengths.data_ptr(); auto* offsets_data = offsets.data_ptr(); // TODO: Switch to TensorIterator for better maintainablility and @@ -371,27 +381,59 @@ Tensor _segment_reduce_cuda_backward_kernel( return grad_input; } -Tensor _segment_reduce_cuda_kernel( - SegmentReductionType reduction, - const Tensor& data, - const Tensor& lengths, - int64_t axis, - const c10::optional& initial) { - // data and lengths should be contiguous from the call to .contiguous in segment_reduce_kernel - TORCH_CHECK(data.is_contiguous(), "Expected data to be contiguous."); - TORCH_CHECK(lengths.is_contiguous(), "Expected lengths to be contiguous."); - axis = lengths.dim() - 1; - int64_t segment_count = lengths.size(axis); - int64_t lengths_stride_axis = lengths.stride(axis); +Tensor _segment_reduce_lengths_backward_cuda_kernel( + const Tensor& grad_contig, + const Tensor& output_contig, + const Tensor& data_contig, + SegmentReductionType reduction, + const Tensor& lengths_contig, + int64_t axis, + const c10::optional& initial) { + return _segment_reduce_lengths_offsets_backward_cuda_kernel( + grad_contig, output_contig, data_contig, reduction, lengths_contig, axis, initial, /*is_offsets_like=*/false); +} + +Tensor _segment_reduce_offsets_backward_cuda_kernel( + const Tensor& grad_contig, + const Tensor& output_contig, + const Tensor& data_contig, + SegmentReductionType reduction, + const Tensor& offsets_contig, + int64_t axis, + const c10::optional& initial) { + return _segment_reduce_lengths_offsets_backward_cuda_kernel( + grad_contig, output_contig, data_contig, reduction, offsets_contig, axis, initial, /*is_offsets_like=*/true); +} + +Tensor _segment_reduce_lengths_offsets_cuda_kernel( + SegmentReductionType reduction, + const Tensor& data, + const Tensor& lengths_or_offsets, + int64_t axis, + const c10::optional& initial, + bool is_offsets_like) { + // data and lengths_or_offsets should be contiguous from the call to .contiguous in segment_reduce_kernel + TORCH_CHECK(data.is_contiguous()); + TORCH_CHECK(lengths_or_offsets.is_contiguous()); + axis = lengths_or_offsets.dim() - 1; + int64_t segment_count = is_offsets_like ? 
lengths_or_offsets.size(axis) - 1 : lengths_or_offsets.size(axis); + int64_t lengths_stride_axis = lengths_or_offsets.stride(axis); auto output_shape = data.sizes().vec(); output_shape[axis] = segment_count; auto output = at::empty(output_shape, data.options()); - // _get_complete_sum only supports 1D? - auto zeros_shape = lengths.sizes().vec(); - zeros_shape[axis] = 1; - auto offsets = at::cat({at::zeros(zeros_shape, lengths.options()), lengths}, axis); - offsets.cumsum_(axis); + + auto offsets = lengths_or_offsets; + auto lengths = lengths_or_offsets; + if (is_offsets_like) { + lengths = lengths.diff(); + } else { + // _get_complete_sum only supports 1D + auto zeros_shape = offsets.sizes().vec(); + zeros_shape[axis] = 1; + offsets = at::cat({at::zeros(zeros_shape, offsets.options()), offsets}, axis); + offsets.cumsum_(axis); + } // outer_offset is the size of the outer dimensions of output (before axis) // inner_offset is the size of the inner dimensions of output (after axis) @@ -416,7 +458,7 @@ Tensor _segment_reduce_cuda_kernel( auto offsets_stride_axis = offsets.stride(axis); AT_DISPATCH_INDEX_TYPES( - lengths.scalar_type(), "_segment_reduce_cuda_kernel1", ([&] { + lengths_or_offsets.scalar_type(), "_segment_reduce_cuda_kernel1", ([&] { auto* offsets_data_ptr = offsets.data_ptr(); auto* lengths_data_ptr = lengths.data_ptr(); AT_DISPATCH_FLOATING_TYPES_AND2( @@ -549,10 +591,34 @@ Tensor _segment_reduce_cuda_kernel( return output; } -REGISTER_DISPATCH(_segment_reduce_stub, &_segment_reduce_cuda_kernel); +Tensor _segment_reduce_lengths_cuda_kernel( + SegmentReductionType reduction, + const Tensor& data, + const Tensor& lengths, + int64_t axis, + const c10::optional& initial) { + return _segment_reduce_lengths_offsets_cuda_kernel( + reduction, data, lengths, axis, initial, /*is_offsets_like=*/false); +} + +Tensor _segment_reduce_offsets_cuda_kernel( + SegmentReductionType reduction, + const Tensor& data, + const Tensor& offsets, + int64_t axis, + const c10::optional& initial) { + return _segment_reduce_lengths_offsets_cuda_kernel( + reduction, data, offsets, axis, initial, /*is_offsets_like=*/true); +} + +REGISTER_DISPATCH(_segment_reduce_lengths_stub, &_segment_reduce_lengths_cuda_kernel); +REGISTER_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cuda_kernel); REGISTER_DISPATCH( - _segment_reduce_backward_stub, - &_segment_reduce_cuda_backward_kernel); + _segment_reduce_lengths_backward_stub, + &_segment_reduce_lengths_backward_cuda_kernel); +REGISTER_DISPATCH( + _segment_reduce_offsets_backward_stub, + &_segment_reduce_offsets_backward_cuda_kernel); } // namespace native } // namespace at diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 4d750973e4b..87974033742 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -11924,12 +11924,12 @@ dispatch: CompositeExplicitAutograd: _test_warn_in_autograd -- func: segment_reduce(Tensor data, str reduce, *, Tensor? lengths=None, Tensor? indices=None, int axis=0, bool unsafe=False, Scalar? initial=None) -> Tensor +- func: segment_reduce(Tensor data, str reduce, *, Tensor? lengths=None, Tensor? indices=None, Tensor? offsets=None, int axis=0, bool unsafe=False, Scalar? initial=None) -> Tensor variants: function dispatch: CPU, CUDA: segment_reduce_kernel -- func: _segment_reduce_backward(Tensor grad, Tensor output, Tensor data, str reduce, *, Tensor? lengths=None, int axis=0, Scalar? 
initial=None) -> Tensor +- func: _segment_reduce_backward(Tensor grad, Tensor output, Tensor data, str reduce, *, Tensor? lengths=None, Tensor? offsets=None, int axis=0, Scalar? initial=None) -> Tensor variants: function dispatch: CPU, CUDA: _segment_reduce_backward_kernel diff --git a/test/forward_backward_compatibility/check_forward_backward_compatibility.py b/test/forward_backward_compatibility/check_forward_backward_compatibility.py index 4a073301804..86f91976c3c 100644 --- a/test/forward_backward_compatibility/check_forward_backward_compatibility.py +++ b/test/forward_backward_compatibility/check_forward_backward_compatibility.py @@ -87,8 +87,8 @@ ALLOW_LIST = [ ("prim::infer_squeeze_size", datetime.date(9999, 1, 1)), ("aten::_weight_norm_cuda_interface", datetime.date(9999, 1, 1)), ("aten::_weight_norm_cuda_interface_backward", datetime.date(9999, 1, 1)), - ("aten::segment_reduce", datetime.date(9999, 1, 1)), - ("aten::_segment_reduce_backward", datetime.date(9999, 1, 1)), + ("aten::segment_reduce", datetime.date(2022, 6, 30)), + ("aten::_segment_reduce_backward", datetime.date(2022, 6, 30)), ("aten::empty.SymInt", datetime.date(9999, 1, 1)), # TODO: FIXME: prims shouldn't be checked ("prims::.*", datetime.date(9999, 1, 1)), diff --git a/test/test_segment_reductions.py b/test/test_segment_reductions.py index 20f871ae4bd..b91a56eeb15 100644 --- a/test/test_segment_reductions.py +++ b/test/test_segment_reductions.py @@ -1,6 +1,7 @@ # Owner(s): ["module: scatter & gather ops"] from itertools import product +from functools import partial import numpy as np import torch @@ -52,6 +53,11 @@ class TestSegmentReductions(TestCase): lengths_dtype=torch.int, ): lengths = torch.tensor(lengths_arr, device=device, dtype=lengths_dtype) + # generate offsets from lengths + zeros_shape = list(lengths.shape) + zeros_shape[-1] = 1 + offsets = torch.cat((lengths.new_zeros(zeros_shape), lengths), -1).cumsum_(-1) + data = torch.tensor( data_arr, device=device, @@ -60,52 +66,56 @@ class TestSegmentReductions(TestCase): ) expected_result = torch.tensor(expected_arr, device=device, dtype=dtype) expected_grad = torch.tensor(expected_grad_arr, device=device, dtype=dtype) - actual_result = torch.segment_reduce( - data=data, - reduce=reduction, - lengths=lengths, - axis=axis, - unsafe=unsafe, - initial=initial_value, - ) - self.assertEqual( - expected_result, actual_result, rtol=1e-02, atol=1e-05, equal_nan=True - ) - - if not check_backward: - return - - # Test backward - actual_result.sum().backward() - self.assertEqual( - expected_grad, data.grad, rtol=1e-02, atol=1e-05, equal_nan=True - ) - - # gradcheck does not work well with bfloat16 or fp16 cpu types - # also there is small numerical difference with fp32 - if dtype not in [torch.half, torch.bfloat16, torch.float]: - # gradcheck does not like "nan" input, setting to random 10 - d_non_nan = np.nan_to_num(data_arr, nan=10) - data = torch.tensor( - # [10 if v == float("nan") else v for v in data], - d_non_nan, - device=device, - dtype=dtype, - requires_grad=True, + for mode in ['lengths', 'offsets']: + segment_reduce_kwargs = dict( + axis=axis, + unsafe=unsafe, + initial=initial_value) + if (mode == 'lengths'): + segment_reduce_kwargs['lengths'] = lengths + else: + segment_reduce_kwargs['offsets'] = offsets + actual_result = torch.segment_reduce( + data=data, + reduce=reduction, + **segment_reduce_kwargs ) - self.assertTrue( - gradcheck( - lambda x: torch.segment_reduce( - data=x, - reduce=reduction, - lengths=lengths, - axis=axis, - unsafe=unsafe, - 
initial=initial_value, - ), - (data,), + self.assertEqual( + expected_result, actual_result, rtol=1e-02, atol=1e-05, equal_nan=True + ) + + if not check_backward: + return + + # Test backward + actual_result.sum().backward() + self.assertEqual( + expected_grad, data.grad, rtol=1e-02, atol=1e-05, equal_nan=True + ) + data = data.clone().detach().requires_grad_(True) + + # gradcheck does not work well with bfloat16 or fp16 cpu types + # also there is a small numerical difference with fp32 + if dtype not in [torch.half, torch.bfloat16, torch.float]: + # gradcheck does not like "nan" input, setting to random 10 + d_non_nan = np.nan_to_num(data_arr, nan=10) + new_data = torch.tensor( + # [10 if v == float("nan") else v for v in data], + d_non_nan, + device=device, + dtype=dtype, + requires_grad=True, + ) + self.assertTrue( + gradcheck( + lambda x: torch.segment_reduce( + data=x, + reduce=reduction, + **segment_reduce_kwargs + ), + (new_data,), + ) ) - ) @dtypes( *product( @@ -384,8 +394,18 @@ class TestSegmentReductions(TestCase): ) self.assertEqual(actual_result, expected) + # test offsets + actual_result = torch.segment_reduce( + data=data, + reduce=reduce, + offsets=indptr, + axis=dim, + unsafe=True, + ) + self.assertEqual(actual_result, expected) + if val_dtype == torch.float64: - def fn(x): + def fn(x, mode='lengths'): initial = 1 # supply initial values to prevent gradcheck from failing for 0 length segments # where nan/inf are reduction identities that produce nans when calculating the numerical jacobian if reduce == 'min': initial = 1000 elif reduce == 'max': initial = -1000 - return torch.segment_reduce(x, reduce, lengths=lengths, axis=dim, unsafe=True, initial=initial) - self.assertTrue(gradcheck(fn, (data.clone().detach().requires_grad_(True)))) + segment_reduce_args = (x, reduce) + segment_reduce_kwargs = dict(axis=dim, unsafe=True, initial=initial) + if mode == 'lengths': + segment_reduce_kwargs[mode] = lengths + elif mode == 'offsets': + segment_reduce_kwargs[mode] = indptr + return torch.segment_reduce(*segment_reduce_args, **segment_reduce_kwargs) + self.assertTrue(gradcheck(partial(fn, mode='lengths'), (data.clone().detach().requires_grad_(True)))) + self.assertTrue(gradcheck(partial(fn, mode='offsets'), (data.clone().detach().requires_grad_(True)))) + @dtypes( *product( diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index 1ec70448720..124dbceb21f 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -2731,8 +2731,8 @@ - name: nonzero(Tensor self) -> Tensor output_differentiability: [False] -- name: segment_reduce(Tensor data, str reduce, *, Tensor? lengths=None, Tensor? indices=None, int axis=0, bool unsafe=False, Scalar? initial=None) -> Tensor - data: _segment_reduce_backward(grad, result, data, reduce, lengths, axis, initial) +- name: segment_reduce(Tensor data, str reduce, *, Tensor? lengths=None, Tensor? indices=None, Tensor? offsets=None, int axis=0, bool unsafe=False, Scalar? initial=None) -> Tensor + data: _segment_reduce_backward(grad, result, data, reduce, lengths, offsets, axis, initial) - name: _pin_memory(Tensor self, Device?
device=None) -> Tensor self: grad diff --git a/torch/overrides.py index 933817019d6..dea16a87530 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -948,7 +948,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]: torch.scatter_add: lambda input, dim, index, src: -1, torch.scatter_reduce: lambda input, dim, index, src, reduce, include_self=True: -1, torch.searchsorted: lambda sorted_sequence, input, out_int32=False, right=False, out=None: -1, - torch.segment_reduce: lambda data, reduce="max", lengths=None, indices=None, axis=0, unsafe=False: -1, + torch.segment_reduce: lambda data, reduce="max", lengths=None, indices=None, offsets=None, axis=0, unsafe=False: -1, torch.select: lambda input, dim, index: -1, torch.select_scatter: lambda input, src, dim, index: -1, torch.slice_scatter: lambda input, src, dim=0, start=None, end=None, step=1: -1, diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 8311acf84c9..e0814e57199 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -8449,9 +8449,19 @@ def sample_inputs_segment_reduce(op_info, device, dtype, requires_grad, *, mode= for args, reduce, initial in product(test_cases, reductions, [1, 2]): inp_shape, dim, lengths, unsafe = args lengths_t = torch.tensor(lengths, dtype=torch.long, device=device) + sample_input_kwargs = {'axis': dim, 'unsafe': unsafe, 'initial': initial} + if mode == 'lengths': + sample_input_kwargs['lengths'] = lengths_t + elif mode == 'offsets': + zeros_shape = list(lengths_t.shape) + zeros_shape[dim] = 1 + offsets_t = torch.cat((lengths_t.new_zeros(zeros_shape), lengths_t), dim).cumsum_(dim) + sample_input_kwargs['offsets'] = offsets_t + else: + raise RuntimeError(f"mode must be one of 'offsets' or 'lengths', got '{mode}'.") yield SampleInput(_tensor(inp_shape), args=(reduce,), - kwargs={'lengths': lengths_t, 'axis': dim, 'unsafe': unsafe, 'initial': initial}) + kwargs=sample_input_kwargs) def sample_inputs_ravel(op_info, device, dtype, requires_grad, **kwargs): @@ -19586,6 +19596,25 @@ op_db: List[OpInfo] = [ ), ), ), + OpInfo( + 'segment_reduce', + variant_test_name='offsets', + dtypes=floating_types_and(torch.float16, torch.bfloat16), + supports_out=False, + # RuntimeError: derivative for aten::_segment_reduce_backward is not implemented + supports_gradgrad=False, + sample_inputs_func=partial(sample_inputs_segment_reduce, mode='offsets'), + skips=( + # FIXME: CUDA driver API confirmed a leak in + # __main__.TestJitCUDA.test_variant_consistency_jit_segment_reduce_cuda_float32 + DecorateInfo( + unittest.skip("Skipped!"), + "TestJit", + "test_variant_consistency_jit", + device_type="cuda", + ), + ), + ), UnaryUfuncInfo( 'special.bessel_j0', decorators=(
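# --- Hedged usage sketch (not part of the patch) ----------------------------
# A minimal illustration of the lengths/offsets equivalence this change adds,
# based on the segment_reduce signature declared in native_functions.yaml above.
# It assumes a PyTorch build that already contains this change; the tensors,
# the expected values, and the assert_close call are illustrative only.
import torch

data = torch.tensor([1., 2., 3., 4., 5., 6.], requires_grad=True)
lengths = torch.tensor([2, 0, 4])  # three segments along axis 0, one of them empty
# offsets are the exclusive cumulative sum of lengths with a leading zero,
# exactly how the tests above derive them: [0, 2, 2, 6]
offsets = torch.cat((lengths.new_zeros(1), lengths)).cumsum(0)

out_lengths = torch.segment_reduce(data, "sum", lengths=lengths, axis=0, initial=0)
out_offsets = torch.segment_reduce(data, "sum", offsets=offsets, axis=0, initial=0)

# Both code paths should reduce each segment the same way: [3., 0., 18.]
torch.testing.assert_close(out_lengths, out_offsets)

# Backward routes through the new *_offsets_backward stubs registered above;
# for "sum" every input element receives a gradient of 1.
out_offsets.sum().backward()
print(data.grad)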