refactor: extract shared util function ComputeBroadcastOutputShape (#21940)

### Description

This is used in multiple places.
This commit is contained in:
Yulong Wang 2024-09-03 18:21:36 -07:00 committed by GitHub
parent 628c0a8f0e
commit decb3852a0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 46 additions and 77 deletions

View file

@ -10,6 +10,7 @@
// ORT system.
#include "core/providers/cuda/tensor/expand.h"
#include "core/providers/common.h"
// std C++.
#include <iostream>
@ -51,7 +52,7 @@ Status DistributedExpand<T>::ComputeInternal(OpKernelContext* context) const {
TensorShapeVector original_output_dims{p_shape, p_shape + shape_tensor->Shape().Size()};
TensorShape original_output_shape(original_output_dims);
ORT_ENFORCE(
onnxruntime::cuda::ComputeOutputShape(
onnxruntime::ComputeBroadcastOutputShape(
Node().Name(),
original_input_shape,
original_output_dims, original_output_shape)

View file

@ -224,34 +224,5 @@ void GenerateHashValue(const std::string string, HashValue& hash_value) {
hash_value = hash[0] | (uint64_t(hash[1]) << 32);
}
Status ComputeOutputShape(const std::string& node_name, const TensorShape& lhs_shape,
const TensorShape& rhs_shape, TensorShape& out_shape) {
size_t lhs_rank = lhs_shape.NumDimensions();
size_t rhs_rank = rhs_shape.NumDimensions();
size_t out_rank = std::max(lhs_rank, rhs_rank);
std::vector<int64_t> output_dims(out_rank, 0);
for (size_t i = 0; i < out_rank; ++i) {
int64_t lhs_dim = 1;
if (i < lhs_rank)
lhs_dim = lhs_shape[lhs_rank - 1 - i];
int64_t rhs_dim = 1;
if (i < rhs_rank)
rhs_dim = rhs_shape[rhs_rank - 1 - i];
int64_t max = std::max(lhs_dim, rhs_dim);
int64_t min = std::min(lhs_dim, rhs_dim);
int64_t out_dim = (min == 0 ? min : max); // special case a dim value of 0.
if (lhs_dim != out_dim && lhs_dim != 1)
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_name, ": left operand cannot broadcast on dim ", lhs_rank - 1 - i,
" LeftShape: ", lhs_shape.ToString(), ", RightShape: ", rhs_shape.ToString());
if (rhs_dim != out_dim && rhs_dim != 1)
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_name, ": right operand cannot broadcast on dim ", rhs_rank - 1 - i,
" LeftShape: ", lhs_shape.ToString(), ", RightShape: ", rhs_shape.ToString());
output_dims[out_rank - 1 - i] = out_dim;
}
out_shape = TensorShape(output_dims);
return Status::OK();
}
} // namespace cann
} // namespace onnxruntime

View file

@ -124,8 +124,6 @@ Status aclrtblasGemmEx(aclTransType transA,
bool FileExist(const std::string& file_name);
void GenerateHashValue(const std::string string, HashValue& hash_value);
Status ComputeOutputShape(const std::string& node_name, const TensorShape& lhs_shape,
const TensorShape& rhs_shape, TensorShape& out_shape);
std::unique_ptr<Model> CreateModel(const GraphViewer& graph_viewer, const logging::Logger& logger);

View file

@ -2,6 +2,8 @@
// Copyright (c) Huawei. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/shared_library/provider_api.h"
#include "core/providers/common.h"
#include "core/providers/cann/math/binary_elementwise_ops.h"
#include <vector>
#include <algorithm>
@ -20,7 +22,7 @@ Status BinaryElementwise::Prepare(OpKernelContext* ctx, CannPreparation& prepare
const Tensor* B = ctx->Input<Tensor>(1);
TensorShape output_shape;
ORT_RETURN_IF_ERROR(ComputeOutputShape(Node().Name(), A->Shape(), B->Shape(), output_shape));
ORT_RETURN_IF_ERROR(ComputeBroadcastOutputShape(Node().Name(), A->Shape(), B->Shape(), output_shape));
Tensor* C = ctx->Output(0, output_shape);
void* A_data = const_cast<void*>(A->DataRaw());

View file

@ -180,4 +180,38 @@ T Product(const Container<T>& c) {
return accumulate(c.cbegin(), c.cend(), static_cast<T>(1), std::multiplies<T>());
}
/// <summary>
/// Compute the output shape for broadcasting the given input shapes of lhs and rhs.
/// </summary>
inline Status ComputeBroadcastOutputShape(const std::string& node_name,
const TensorShape& lhs_shape,
const TensorShape& rhs_shape,
TensorShape& out_shape) {
size_t lhs_rank = lhs_shape.NumDimensions();
size_t rhs_rank = rhs_shape.NumDimensions();
size_t out_rank = std::max(lhs_rank, rhs_rank);
std::vector<int64_t> output_dims(out_rank, 0);
for (size_t i = 0; i < out_rank; ++i) {
int64_t lhs_dim = 1;
if (i < lhs_rank)
lhs_dim = lhs_shape[lhs_rank - 1 - i];
int64_t rhs_dim = 1;
if (i < rhs_rank)
rhs_dim = rhs_shape[rhs_rank - 1 - i];
int64_t max = std::max(lhs_dim, rhs_dim);
int64_t min = std::min(lhs_dim, rhs_dim);
int64_t out_dim = (min == 0 ? min : max); // special case a dim value of 0.
if (lhs_dim != out_dim && lhs_dim != 1)
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_name, ": left operand cannot broadcast on dim ", lhs_rank - 1 - i,
" LeftShape: ", lhs_shape.ToString(), ", RightShape: ", rhs_shape.ToString());
if (rhs_dim != out_dim && rhs_dim != 1)
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_name, ": right operand cannot broadcast on dim ", rhs_rank - 1 - i,
" LeftShape: ", lhs_shape.ToString(), ", RightShape: ", rhs_shape.ToString());
output_dims[out_rank - 1 - i] = out_dim;
}
out_shape = TensorShape(output_dims);
return Status::OK();
}
} // namespace onnxruntime

View file

@ -1,6 +1,8 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/shared_library/provider_api.h"
#include "core/providers/common.h"
#include "core/providers/cuda/math/binary_elementwise_ops.h"
#include "core/providers/cuda/math/binary_elementwise_ops_impl.h"
#include "core/providers/cuda/math/unary_elementwise_ops_impl.h"
@ -21,34 +23,6 @@ Status BinaryElementwise<ShouldNotBroadcast>::Prepare(OpKernelContext* context,
return Status::OK();
}
Status ComputeOutputShape(const std::string& node_name, const TensorShape& lhs_shape, const TensorShape& rhs_shape, TensorShape& out_shape) {
size_t lhs_rank = lhs_shape.NumDimensions();
size_t rhs_rank = rhs_shape.NumDimensions();
size_t out_rank = std::max(lhs_rank, rhs_rank);
std::vector<int64_t> output_dims(out_rank, 0);
for (size_t i = 0; i < out_rank; ++i) {
int64_t lhs_dim = 1;
if (i < lhs_rank)
lhs_dim = lhs_shape[lhs_rank - 1 - i];
int64_t rhs_dim = 1;
if (i < rhs_rank)
rhs_dim = rhs_shape[rhs_rank - 1 - i];
int64_t max = std::max(lhs_dim, rhs_dim);
int64_t min = std::min(lhs_dim, rhs_dim);
int64_t out_dim = (min == 0 ? min : max); // special case a dim value of 0.
if (lhs_dim != out_dim && lhs_dim != 1)
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_name, ": left operand cannot broadcast on dim ", lhs_rank - 1 - i,
" LeftShape: ", lhs_shape.ToString(), ", RightShape: ", rhs_shape.ToString());
if (rhs_dim != out_dim && rhs_dim != 1)
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_name, ": right operand cannot broadcast on dim ", rhs_rank - 1 - i,
" LeftShape: ", lhs_shape.ToString(), ", RightShape: ", rhs_shape.ToString());
output_dims[out_rank - 1 - i] = out_dim;
}
out_shape = TensorShape(output_dims);
return Status::OK();
}
Status BinaryElementwiseBroadcastPrepare(
const Tensor* lhs_tensor,
const Tensor* rhs_tensor,
@ -77,7 +51,7 @@ Status BinaryElementwise<ShouldBroadcast>::Prepare(OpKernelContext* context, Bin
const auto& rhs_shape = rhs_tensor->Shape();
TensorShape output_shape;
ORT_RETURN_IF_ERROR(ComputeOutputShape(Node().Name(), lhs_shape, rhs_shape, output_shape));
ORT_RETURN_IF_ERROR(ComputeBroadcastOutputShape(Node().Name(), lhs_shape, rhs_shape, output_shape));
auto output_tensor = context->Output(0, output_shape);
ORT_RETURN_IF_ERROR(BinaryElementwiseBroadcastPrepare(lhs_tensor, rhs_tensor, output_tensor, p));

View file

@ -108,12 +108,6 @@ struct BinaryElementwisePreparation {
}
};
Status ComputeOutputShape(
const std::string& node_name,
const TensorShape& lhs_shape,
const TensorShape& rhs_shape,
TensorShape& out_shape);
Status BinaryElementwiseBroadcastPrepare(
const Tensor* lhs_tensor,
const Tensor* rhs_tensor,

View file

@ -2,6 +2,7 @@
// Licensed under the MIT License.
#include "core/providers/shared_library/provider_api.h"
#include "core/providers/common.h"
#include "core/providers/cuda/math/variadic_elementwise_ops.h"
#include <cassert>
@ -209,7 +210,7 @@ Status VariadicElementwiseOp<VariadicElementwiseOpTag, SupportedElementTypes...>
TensorShape output_shape;
TensorShape previous_output_shape = first_input_tensor.Shape();
for (int index = 1; index < input_count; index++) {
ORT_RETURN_IF_ERROR(ComputeOutputShape(
ORT_RETURN_IF_ERROR(ComputeBroadcastOutputShape(
node_name, previous_output_shape, input_tensors[index].get().Shape(), output_shape));
previous_output_shape = output_shape;
}

View file

@ -95,7 +95,7 @@ Status Expand::ComputeInternal(OpKernelContext* ctx) const {
TensorShapeVector output_dims{p_shape, p_shape + input_shape_tensor.Shape().Size()};
TensorShape output_shape(output_dims);
ORT_RETURN_IF_ERROR(ComputeOutputShape(Node().Name(), input_data_tensor.Shape(), output_dims, output_shape));
ORT_RETURN_IF_ERROR(ComputeBroadcastOutputShape(Node().Name(), input_data_tensor.Shape(), output_dims, output_shape));
auto& output_tensor = *ctx->Output(0, output_shape);
if (0 == output_shape.Size()) {
return Status::OK();
@ -202,7 +202,7 @@ std::unique_ptr<Tensor> FuncExpand(
TensorShape output_shape(output_dims);
ORT_ENFORCE(
ComputeOutputShape(
ComputeBroadcastOutputShape(
cuda_kernel->Node().Name(),
input_data_tensor->Shape(),
output_dims, output_shape)

View file

@ -14,12 +14,6 @@ class Expand final : public CudaKernel {
Status ComputeInternal(OpKernelContext* context) const override;
};
Status ComputeOutputShape(
const std::string& node_name,
const TensorShape& lhs_shape,
const TensorShape& rhs_shape,
TensorShape& out_shape);
Status FuncExpand(
const CudaKernel* cuda_kernel,
OpKernelContext* ctx,