Handle incorrect perm data in Transpose op gracefully (#739)

This commit is contained in:
Hariharan Seshadri 2019-03-29 10:42:34 -07:00 committed by Changming Sun
parent 667fa39551
commit e6a2bdfacd
3 changed files with 18 additions and 3 deletions

View file

@ -267,7 +267,9 @@ Status Transpose::Compute(OpKernelContext* ctx) const {
std::vector<int64_t> output_dims(rank);
const std::vector<int64_t>* p_perm;
std::vector<int64_t> default_perm(rank);
ComputeOutputShape(X, output_dims, default_perm, p_perm);
const auto& status = ComputeOutputShape(X, output_dims, default_perm, p_perm);
if (!status.IsOK())
return status;
TensorShape output_shape{output_dims};
Tensor& Y = *ctx->Output(0, output_shape);

View file

@ -6,6 +6,7 @@
#include "gsl/gsl_util"
#include "core/common/common.h"
#include "core/framework/op_kernel.h"
#include <sstream>
namespace onnxruntime {
@ -36,7 +37,7 @@ class TransposeBase {
}
}
void ComputeOutputShape(const Tensor& X, std::vector<int64_t>& output_dims,
Status ComputeOutputShape(const Tensor& X, std::vector<int64_t>& output_dims,
std::vector<int64_t>& default_perm, const std::vector<int64_t>*& p_perm) const {
size_t rank = X.Shape().NumDimensions();
const auto& input_dims = X.Shape().GetDims();
@ -57,8 +58,18 @@ class TransposeBase {
output_dims.resize(rank);
for (int i = 0; i < rank; i++) {
size_t inpdim = (*p_perm)[i];
if (inpdim >= rank) {
std::ostringstream ss;
ss << "[ ";
for (const auto& p : *p_perm)
ss << p << " ";
ss << "]";
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"perm: ", ss.str(), " does not align with rank of input data: ", std::to_string(rank));
}
output_dims[i] = input_dims[inpdim];
}
return Status::OK();
}
bool perm_specified_ = false;

View file

@ -31,7 +31,9 @@ Status Transpose<T>::ComputeInternal(OpKernelContext* ctx) const {
std::vector<int64_t> output_dims(rank);
std::vector<int64_t> default_perm(rank);
const std::vector<int64_t>* p_perm = nullptr;
ComputeOutputShape(X, output_dims, default_perm, p_perm);
const auto& status = ComputeOutputShape(X, output_dims, default_perm, p_perm);
if (!status.IsOK())
return status;
TensorShape output_shape{output_dims};
Tensor* Y = ctx->Output(0, output_shape);