mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-08 00:23:03 +00:00
refactor: extract shared util function ComputeBroadcastOutputShape (#21940)
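### Description The broadcast output-shape computation is used in multiple places, so it is extracted into a single shared utility function, ComputeBroadcastOutputShape.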
This commit is contained in:
parent
628c0a8f0e
commit
decb3852a0
10 changed files with 46 additions and 77 deletions
|
|
@ -10,6 +10,7 @@
|
|||
|
||||
// ORT system.
|
||||
#include "core/providers/cuda/tensor/expand.h"
|
||||
#include "core/providers/common.h"
|
||||
|
||||
// std C++.
|
||||
#include <iostream>
|
||||
|
|
@ -51,7 +52,7 @@ Status DistributedExpand<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
TensorShapeVector original_output_dims{p_shape, p_shape + shape_tensor->Shape().Size()};
|
||||
TensorShape original_output_shape(original_output_dims);
|
||||
ORT_ENFORCE(
|
||||
onnxruntime::cuda::ComputeOutputShape(
|
||||
onnxruntime::ComputeBroadcastOutputShape(
|
||||
Node().Name(),
|
||||
original_input_shape,
|
||||
original_output_dims, original_output_shape)
|
||||
|
|
|
|||
|
|
@ -224,34 +224,5 @@ void GenerateHashValue(const std::string string, HashValue& hash_value) {
|
|||
hash_value = hash[0] | (uint64_t(hash[1]) << 32);
|
||||
}
|
||||
|
||||
// Computes the shape that results from broadcasting lhs_shape with rhs_shape.
// Axes are matched from the last dimension backwards; an operand with fewer axes is
// treated as having size 1 on the axes it lacks. For each axis pair the result is 0 when
// the smaller value is 0, otherwise the larger value. A FAIL status whose message starts
// with node_name is returned when an operand's dimension is neither 1 nor equal to that
// result. out_shape is assigned only when every axis is compatible.
Status ComputeOutputShape(const std::string& node_name, const TensorShape& lhs_shape,
                          const TensorShape& rhs_shape, TensorShape& out_shape) {
  const size_t lhs_rank = lhs_shape.NumDimensions();
  const size_t rhs_rank = rhs_shape.NumDimensions();
  const size_t out_rank = std::max(lhs_rank, rhs_rank);

  std::vector<int64_t> result_dims(out_rank, 0);
  for (size_t back = 0; back < out_rank; ++back) {
    // `back` counts axes from the trailing end; absent axes default to 1.
    const int64_t lhs_dim = (back < lhs_rank) ? lhs_shape[lhs_rank - 1 - back] : int64_t{1};
    const int64_t rhs_dim = (back < rhs_rank) ? rhs_shape[rhs_rank - 1 - back] : int64_t{1};

    const int64_t lower = std::min(lhs_dim, rhs_dim);
    const int64_t upper = std::max(lhs_dim, rhs_dim);
    const int64_t merged = (lower == 0) ? lower : upper;  // special case a dim value of 0.

    const bool lhs_fits = (lhs_dim == merged) || (lhs_dim == 1);
    if (!lhs_fits) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_name, ": left operand cannot broadcast on dim ", lhs_rank - 1 - back,
                             " LeftShape: ", lhs_shape.ToString(), ", RightShape: ", rhs_shape.ToString());
    }

    const bool rhs_fits = (rhs_dim == merged) || (rhs_dim == 1);
    if (!rhs_fits) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_name, ": right operand cannot broadcast on dim ", rhs_rank - 1 - back,
                             " LeftShape: ", lhs_shape.ToString(), ", RightShape: ", rhs_shape.ToString());
    }

    result_dims[out_rank - 1 - back] = merged;
  }

  out_shape = TensorShape(result_dims);
  return Status::OK();
}
|
||||
|
||||
} // namespace cann
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -124,8 +124,6 @@ Status aclrtblasGemmEx(aclTransType transA,
|
|||
|
||||
bool FileExist(const std::string& file_name);
|
||||
void GenerateHashValue(const std::string string, HashValue& hash_value);
|
||||
Status ComputeOutputShape(const std::string& node_name, const TensorShape& lhs_shape,
|
||||
const TensorShape& rhs_shape, TensorShape& out_shape);
|
||||
|
||||
std::unique_ptr<Model> CreateModel(const GraphViewer& graph_viewer, const logging::Logger& logger);
|
||||
|
||||
|
|
|
|||
|
|
@ -2,6 +2,8 @@
|
|||
// Copyright (c) Huawei. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/providers/shared_library/provider_api.h"
|
||||
#include "core/providers/common.h"
|
||||
#include "core/providers/cann/math/binary_elementwise_ops.h"
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
|
|
@ -20,7 +22,7 @@ Status BinaryElementwise::Prepare(OpKernelContext* ctx, CannPreparation& prepare
|
|||
const Tensor* B = ctx->Input<Tensor>(1);
|
||||
|
||||
TensorShape output_shape;
|
||||
ORT_RETURN_IF_ERROR(ComputeOutputShape(Node().Name(), A->Shape(), B->Shape(), output_shape));
|
||||
ORT_RETURN_IF_ERROR(ComputeBroadcastOutputShape(Node().Name(), A->Shape(), B->Shape(), output_shape));
|
||||
Tensor* C = ctx->Output(0, output_shape);
|
||||
|
||||
void* A_data = const_cast<void*>(A->DataRaw());
|
||||
|
|
|
|||
|
|
@ -180,4 +180,38 @@ T Product(const Container<T>& c) {
|
|||
return accumulate(c.cbegin(), c.cend(), static_cast<T>(1), std::multiplies<T>());
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Compute the output shape for broadcasting the given input shapes of lhs and rhs.
|
||||
/// </summary>
|
||||
inline Status ComputeBroadcastOutputShape(const std::string& node_name,
|
||||
const TensorShape& lhs_shape,
|
||||
const TensorShape& rhs_shape,
|
||||
TensorShape& out_shape) {
|
||||
size_t lhs_rank = lhs_shape.NumDimensions();
|
||||
size_t rhs_rank = rhs_shape.NumDimensions();
|
||||
size_t out_rank = std::max(lhs_rank, rhs_rank);
|
||||
|
||||
std::vector<int64_t> output_dims(out_rank, 0);
|
||||
for (size_t i = 0; i < out_rank; ++i) {
|
||||
int64_t lhs_dim = 1;
|
||||
if (i < lhs_rank)
|
||||
lhs_dim = lhs_shape[lhs_rank - 1 - i];
|
||||
int64_t rhs_dim = 1;
|
||||
if (i < rhs_rank)
|
||||
rhs_dim = rhs_shape[rhs_rank - 1 - i];
|
||||
int64_t max = std::max(lhs_dim, rhs_dim);
|
||||
int64_t min = std::min(lhs_dim, rhs_dim);
|
||||
int64_t out_dim = (min == 0 ? min : max); // special case a dim value of 0.
|
||||
if (lhs_dim != out_dim && lhs_dim != 1)
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_name, ": left operand cannot broadcast on dim ", lhs_rank - 1 - i,
|
||||
" LeftShape: ", lhs_shape.ToString(), ", RightShape: ", rhs_shape.ToString());
|
||||
if (rhs_dim != out_dim && rhs_dim != 1)
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_name, ": right operand cannot broadcast on dim ", rhs_rank - 1 - i,
|
||||
" LeftShape: ", lhs_shape.ToString(), ", RightShape: ", rhs_shape.ToString());
|
||||
output_dims[out_rank - 1 - i] = out_dim;
|
||||
}
|
||||
out_shape = TensorShape(output_dims);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/providers/shared_library/provider_api.h"
|
||||
#include "core/providers/common.h"
|
||||
#include "core/providers/cuda/math/binary_elementwise_ops.h"
|
||||
#include "core/providers/cuda/math/binary_elementwise_ops_impl.h"
|
||||
#include "core/providers/cuda/math/unary_elementwise_ops_impl.h"
|
||||
|
|
@ -21,34 +23,6 @@ Status BinaryElementwise<ShouldNotBroadcast>::Prepare(OpKernelContext* context,
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// Determines the broadcast output shape for a binary elementwise operation.
// The two shapes are aligned at their last axis and missing leading axes count as 1.
// Per axis, the output is 0 when the smaller input dimension is 0 and the larger input
// dimension otherwise; an input dimension that is neither 1 nor equal to the output
// yields a FAIL status naming node_name and both shapes. out_shape is left untouched
// unless all axes are compatible.
Status ComputeOutputShape(const std::string& node_name, const TensorShape& lhs_shape, const TensorShape& rhs_shape, TensorShape& out_shape) {
  const size_t lhs_rank = lhs_shape.NumDimensions();
  const size_t rhs_rank = rhs_shape.NumDimensions();
  const size_t out_rank = std::max(lhs_rank, rhs_rank);

  std::vector<int64_t> dims(out_rank, 0);
  // Visit output axes from last to first so that mismatches are reported trailing-axis first.
  for (size_t out_axis = out_rank; out_axis-- > 0;) {
    const size_t from_end = out_rank - 1 - out_axis;

    int64_t lhs_dim = 1;
    if (from_end < lhs_rank) {
      lhs_dim = lhs_shape[lhs_rank - 1 - from_end];
    }

    int64_t rhs_dim = 1;
    if (from_end < rhs_rank) {
      rhs_dim = rhs_shape[rhs_rank - 1 - from_end];
    }

    const int64_t low = std::min(lhs_dim, rhs_dim);
    const int64_t high = std::max(lhs_dim, rhs_dim);
    int64_t out_dim = high;
    if (low == 0) {
      out_dim = low;  // special case a dim value of 0.
    }

    if (lhs_dim != 1 && lhs_dim != out_dim) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_name, ": left operand cannot broadcast on dim ", lhs_rank - 1 - from_end,
                             " LeftShape: ", lhs_shape.ToString(), ", RightShape: ", rhs_shape.ToString());
    }
    if (rhs_dim != 1 && rhs_dim != out_dim) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_name, ": right operand cannot broadcast on dim ", rhs_rank - 1 - from_end,
                             " LeftShape: ", lhs_shape.ToString(), ", RightShape: ", rhs_shape.ToString());
    }

    dims[out_axis] = out_dim;
  }

  out_shape = TensorShape(dims);
  return Status::OK();
}
|
||||
|
||||
Status BinaryElementwiseBroadcastPrepare(
|
||||
const Tensor* lhs_tensor,
|
||||
const Tensor* rhs_tensor,
|
||||
|
|
@ -77,7 +51,7 @@ Status BinaryElementwise<ShouldBroadcast>::Prepare(OpKernelContext* context, Bin
|
|||
const auto& rhs_shape = rhs_tensor->Shape();
|
||||
|
||||
TensorShape output_shape;
|
||||
ORT_RETURN_IF_ERROR(ComputeOutputShape(Node().Name(), lhs_shape, rhs_shape, output_shape));
|
||||
ORT_RETURN_IF_ERROR(ComputeBroadcastOutputShape(Node().Name(), lhs_shape, rhs_shape, output_shape));
|
||||
auto output_tensor = context->Output(0, output_shape);
|
||||
|
||||
ORT_RETURN_IF_ERROR(BinaryElementwiseBroadcastPrepare(lhs_tensor, rhs_tensor, output_tensor, p));
|
||||
|
|
|
|||
|
|
@ -108,12 +108,6 @@ struct BinaryElementwisePreparation {
|
|||
}
|
||||
};
|
||||
|
||||
Status ComputeOutputShape(
|
||||
const std::string& node_name,
|
||||
const TensorShape& lhs_shape,
|
||||
const TensorShape& rhs_shape,
|
||||
TensorShape& out_shape);
|
||||
|
||||
Status BinaryElementwiseBroadcastPrepare(
|
||||
const Tensor* lhs_tensor,
|
||||
const Tensor* rhs_tensor,
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/providers/shared_library/provider_api.h"
|
||||
#include "core/providers/common.h"
|
||||
#include "core/providers/cuda/math/variadic_elementwise_ops.h"
|
||||
|
||||
#include <cassert>
|
||||
|
|
@ -209,7 +210,7 @@ Status VariadicElementwiseOp<VariadicElementwiseOpTag, SupportedElementTypes...>
|
|||
TensorShape output_shape;
|
||||
TensorShape previous_output_shape = first_input_tensor.Shape();
|
||||
for (int index = 1; index < input_count; index++) {
|
||||
ORT_RETURN_IF_ERROR(ComputeOutputShape(
|
||||
ORT_RETURN_IF_ERROR(ComputeBroadcastOutputShape(
|
||||
node_name, previous_output_shape, input_tensors[index].get().Shape(), output_shape));
|
||||
previous_output_shape = output_shape;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -95,7 +95,7 @@ Status Expand::ComputeInternal(OpKernelContext* ctx) const {
|
|||
TensorShapeVector output_dims{p_shape, p_shape + input_shape_tensor.Shape().Size()};
|
||||
TensorShape output_shape(output_dims);
|
||||
|
||||
ORT_RETURN_IF_ERROR(ComputeOutputShape(Node().Name(), input_data_tensor.Shape(), output_dims, output_shape));
|
||||
ORT_RETURN_IF_ERROR(ComputeBroadcastOutputShape(Node().Name(), input_data_tensor.Shape(), output_dims, output_shape));
|
||||
auto& output_tensor = *ctx->Output(0, output_shape);
|
||||
if (0 == output_shape.Size()) {
|
||||
return Status::OK();
|
||||
|
|
@ -202,7 +202,7 @@ std::unique_ptr<Tensor> FuncExpand(
|
|||
TensorShape output_shape(output_dims);
|
||||
|
||||
ORT_ENFORCE(
|
||||
ComputeOutputShape(
|
||||
ComputeBroadcastOutputShape(
|
||||
cuda_kernel->Node().Name(),
|
||||
input_data_tensor->Shape(),
|
||||
output_dims, output_shape)
|
||||
|
|
|
|||
|
|
@ -14,12 +14,6 @@ class Expand final : public CudaKernel {
|
|||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
};
|
||||
|
||||
Status ComputeOutputShape(
|
||||
const std::string& node_name,
|
||||
const TensorShape& lhs_shape,
|
||||
const TensorShape& rhs_shape,
|
||||
TensorShape& out_shape);
|
||||
|
||||
Status FuncExpand(
|
||||
const CudaKernel* cuda_kernel,
|
||||
OpKernelContext* ctx,
|
||||
|
|
|
|||
Loading…
Reference in a new issue