mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-03 23:49:44 +00:00
Handle incorrect perm data in Transpose op gracefully (#739)
This commit is contained in:
parent
667fa39551
commit
e6a2bdfacd
3 changed files with 18 additions and 3 deletions
|
|
@ -267,7 +267,9 @@ Status Transpose::Compute(OpKernelContext* ctx) const {
|
|||
std::vector<int64_t> output_dims(rank);
|
||||
const std::vector<int64_t>* p_perm;
|
||||
std::vector<int64_t> default_perm(rank);
|
||||
ComputeOutputShape(X, output_dims, default_perm, p_perm);
|
||||
const auto& status = ComputeOutputShape(X, output_dims, default_perm, p_perm);
|
||||
if (!status.IsOK())
|
||||
return status;
|
||||
|
||||
TensorShape output_shape{output_dims};
|
||||
Tensor& Y = *ctx->Output(0, output_shape);
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@
|
|||
#include "gsl/gsl_util"
|
||||
#include "core/common/common.h"
|
||||
#include "core/framework/op_kernel.h"
|
||||
#include <sstream>
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
|
|
@ -36,7 +37,7 @@ class TransposeBase {
|
|||
}
|
||||
}
|
||||
|
||||
void ComputeOutputShape(const Tensor& X, std::vector<int64_t>& output_dims,
|
||||
Status ComputeOutputShape(const Tensor& X, std::vector<int64_t>& output_dims,
|
||||
std::vector<int64_t>& default_perm, const std::vector<int64_t>*& p_perm) const {
|
||||
size_t rank = X.Shape().NumDimensions();
|
||||
const auto& input_dims = X.Shape().GetDims();
|
||||
|
|
@ -57,8 +58,18 @@ class TransposeBase {
|
|||
output_dims.resize(rank);
|
||||
for (int i = 0; i < rank; i++) {
|
||||
size_t inpdim = (*p_perm)[i];
|
||||
if (inpdim >= rank) {
|
||||
std::ostringstream ss;
|
||||
ss << "[ ";
|
||||
for (const auto& p : *p_perm)
|
||||
ss << p << " ";
|
||||
ss << "]";
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"perm: ", ss.str(), " does not align with rank of input data: ", std::to_string(rank));
|
||||
}
|
||||
output_dims[i] = input_dims[inpdim];
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
bool perm_specified_ = false;
|
||||
|
|
|
|||
|
|
@ -31,7 +31,9 @@ Status Transpose<T>::ComputeInternal(OpKernelContext* ctx) const {
|
|||
std::vector<int64_t> output_dims(rank);
|
||||
std::vector<int64_t> default_perm(rank);
|
||||
const std::vector<int64_t>* p_perm = nullptr;
|
||||
ComputeOutputShape(X, output_dims, default_perm, p_perm);
|
||||
const auto& status = ComputeOutputShape(X, output_dims, default_perm, p_perm);
|
||||
if (!status.IsOK())
|
||||
return status;
|
||||
|
||||
TensorShape output_shape{output_dims};
|
||||
Tensor* Y = ctx->Output(0, output_shape);
|
||||
|
|
|
|||
Loading…
Reference in a new issue