mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-29 23:06:41 +00:00
Introduce output type/shape validation (#17301)
### Description Validate outputs type and shapes. Make sure sparse initializers are taken into account. ### Motivation and Context ORT currently does not validate output types or shapes. Further, neither inputs or outputs take into account sparse initializers that are converted from dense. It is currently possible to pre-allocate a wrong type/shape buffer for output. Cc: @Craigacp
This commit is contained in:
parent
8818a99c93
commit
dbcc60bed5
5 changed files with 187 additions and 147 deletions
|
|
@ -1829,83 +1829,102 @@ const DataTransferManager& InferenceSession::GetDataTransferManager() const {
|
|||
return data_transfer_mgr_;
|
||||
}
|
||||
|
||||
common::Status InferenceSession::CheckShapes(const std::string& input_name, const TensorShape& input_shape,
|
||||
const TensorShape& expected_shape) const {
|
||||
auto input_shape_sz = input_shape.NumDimensions();
|
||||
auto expected_shape_sz = expected_shape.NumDimensions();
|
||||
if (input_shape_sz != expected_shape_sz) {
|
||||
std::ostringstream ostr;
|
||||
ostr << "Invalid rank for input: " << input_name << " Got: " << input_shape_sz << " Expected: " << expected_shape_sz
|
||||
<< " Please fix either the inputs or the model.";
|
||||
return Status(ONNXRUNTIME, INVALID_ARGUMENT, ostr.str());
|
||||
common::Status InferenceSession::CheckShapes(const std::string& input_output_name, const TensorShape& input_output_shape,
|
||||
const TensorShape& expected_shape, const char* input_output_moniker) const {
|
||||
const auto shape_size = input_output_shape.NumDimensions();
|
||||
const auto expected_shape_size = expected_shape.NumDimensions();
|
||||
if (shape_size != expected_shape_size) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid rank for ", input_output_moniker, ": ",
|
||||
input_output_name, " Got: ", shape_size, " Expected: ", expected_shape_size,
|
||||
" Please fix either the inputs/outputs or the model.");
|
||||
}
|
||||
|
||||
std::vector<size_t> invalid_dim_indices;
|
||||
for (size_t i = 0; i < input_shape_sz; ++i) {
|
||||
InlinedVector<size_t> invalid_dim_indices;
|
||||
for (size_t i = 0; i < shape_size; ++i) {
|
||||
if (expected_shape[i] < 0) {
|
||||
continue; // this represents a symbolic shape dimension
|
||||
}
|
||||
if (input_shape[i] != expected_shape[i]) {
|
||||
if (input_output_shape[i] != expected_shape[i]) {
|
||||
invalid_dim_indices.push_back(i);
|
||||
}
|
||||
}
|
||||
|
||||
if (!invalid_dim_indices.empty()) {
|
||||
std::ostringstream ostr;
|
||||
ostr << "Got invalid dimensions for input: " << input_name << " for the following indices\n";
|
||||
ostr << "Got invalid dimensions for " << input_output_moniker << ": " << input_output_name << " for the following indices\n";
|
||||
for (size_t i = 0, end = invalid_dim_indices.size(); i < end; ++i) {
|
||||
size_t idx = invalid_dim_indices[i];
|
||||
ostr << " index: " << idx << " Got: " << input_shape[idx] << " Expected: " << expected_shape[idx] << "\n";
|
||||
ostr << " index: " << idx << " Got: " << input_output_shape[idx] << " Expected: " << expected_shape[idx] << "\n";
|
||||
}
|
||||
ostr << " Please fix either the inputs or the model.";
|
||||
ostr << " Please fix either the inputs/outputs or the model.";
|
||||
return Status(ONNXRUNTIME, INVALID_ARGUMENT, ostr.str());
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
static common::Status CheckTypes(MLDataType actual, MLDataType expected, const std::string& base_type) {
|
||||
static common::Status CheckTypes(MLDataType actual, MLDataType expected, const std::string& base_type,
|
||||
const char* input_output_moniker) {
|
||||
if (actual == expected) {
|
||||
return Status::OK();
|
||||
}
|
||||
std::ostringstream ostr;
|
||||
ostr << "Unexpected input data type. Actual: (";
|
||||
ostr << base_type;
|
||||
ostr << "(";
|
||||
ostr << DataTypeImpl::ToString(actual);
|
||||
ostr << ")) , expected: (";
|
||||
ostr << base_type;
|
||||
ostr << "(";
|
||||
ostr << DataTypeImpl::ToString(expected);
|
||||
ostr << "))";
|
||||
|
||||
return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, ostr.str());
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unexpected ", input_output_moniker, " data type. Actual: (",
|
||||
base_type, "(",
|
||||
DataTypeImpl::ToString(actual), ")) , expected: (", base_type, "(",
|
||||
DataTypeImpl::ToString(expected), "))");
|
||||
}
|
||||
|
||||
common::Status InferenceSession::ValidateInputs(gsl::span<const std::string> feed_names,
|
||||
gsl::span<const OrtValue> feeds) const {
|
||||
if (feed_names.size() != feeds.size()) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Size mismatch: feed_names has ", feed_names.size(),
|
||||
"elements, but feeds has ", feeds.size(), " elements.");
|
||||
common::Status InferenceSession::ValidateInputsOutputs(gsl::span<const std::string> names,
|
||||
gsl::span<const OrtValue> feeds_fetches,
|
||||
const InputOutputDefMetaMap& input_output_meta_map,
|
||||
ArgType arg_type) const {
|
||||
ORT_ENFORCE(arg_type == ArgType::kInput || arg_type == ArgType::kOutput, "Valid values kInput, kOutput");
|
||||
|
||||
const bool is_inputs = arg_type == ArgType::kInput;
|
||||
|
||||
const char* const input_output_moniker = is_inputs ? "input" : "output";
|
||||
const char* const feed_fetches_moniker = is_inputs ? "feed" : "fetch";
|
||||
|
||||
#if !defined(DISABLE_SPARSE_TENSORS)
|
||||
auto is_sparse_initializer = [this](const std::string& name) -> bool {
|
||||
int idx = -1;
|
||||
if (session_state_->GetOrtValueNameIdxMap().GetIdx(name, idx).IsOK()) {
|
||||
return session_state_->IsSparseInitializer(idx);
|
||||
}
|
||||
return false;
|
||||
};
|
||||
#endif
|
||||
|
||||
if (names.size() != feeds_fetches.size()) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, feed_fetches_moniker, " names has ", names.size(),
|
||||
" elements, but ", feed_fetches_moniker, " has ", feeds_fetches.size(), " elements.");
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < feeds.size(); ++i) {
|
||||
const auto& feed_name = feed_names[i];
|
||||
for (size_t i = 0; i < feeds_fetches.size(); ++i) {
|
||||
const auto& name = names[i];
|
||||
|
||||
auto iter = input_def_map_.find(feed_name);
|
||||
if (input_def_map_.end() == iter) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid Feed Input Name:", feed_name);
|
||||
auto iter = input_output_meta_map.find(name);
|
||||
if (input_output_meta_map.end() == iter) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid ", input_output_moniker, " name: ", name);
|
||||
}
|
||||
|
||||
const auto& input_output_ml_value = feeds_fetches[i];
|
||||
|
||||
// For outputs the user may supply an unallocated placeholder.
|
||||
if (!is_inputs && !input_output_ml_value.IsAllocated()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto expected_type = iter->second.ml_data_type;
|
||||
auto& input_ml_value = feeds[i];
|
||||
if (input_ml_value.IsTensor()) {
|
||||
|
||||
if (input_output_ml_value.IsTensor()) {
|
||||
if (!expected_type->IsTensorType()
|
||||
#if !defined(DISABLE_OPTIONAL_TYPE)
|
||||
&& !utils::IsOptionalTensor(expected_type)
|
||||
#endif
|
||||
) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input with name: ", feed_name,
|
||||
" is not expected to be of type tensor.");
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, input_output_moniker, " with name: '", name,
|
||||
"' expected to be of type: ", static_cast<int>(expected_type->type_), " but received a tensor");
|
||||
}
|
||||
|
||||
// check for type
|
||||
|
|
@ -1919,44 +1938,56 @@ common::Status InferenceSession::ValidateInputs(gsl::span<const std::string> fee
|
|||
auto expected_element_type = expected_type->AsTensorType()->GetElementType();
|
||||
#endif
|
||||
|
||||
auto input_element_type = input_ml_value.Get<Tensor>().DataType();
|
||||
ORT_RETURN_IF_ERROR_SESSIONID_(CheckTypes(input_element_type, expected_element_type, "tensor"));
|
||||
const auto& input_output_tensor = input_output_ml_value.Get<Tensor>();
|
||||
ORT_RETURN_IF_ERROR_SESSIONID_(CheckTypes(input_output_tensor.DataType(),
|
||||
expected_element_type, "tensor", input_output_moniker));
|
||||
|
||||
// check for shape
|
||||
const auto& expected_shape = iter->second.tensor_shape;
|
||||
if (expected_shape.NumDimensions() > 0) {
|
||||
const auto& input_shape = input_ml_value.Get<Tensor>().Shape();
|
||||
ORT_RETURN_IF_ERROR_SESSIONID_(CheckShapes(feed_name, input_shape, expected_shape));
|
||||
if (iter->second.tensor_shape.has_value()) {
|
||||
ORT_RETURN_IF_ERROR_SESSIONID_(CheckShapes(name, input_output_tensor.Shape(),
|
||||
*iter->second.tensor_shape, input_output_moniker));
|
||||
}
|
||||
} else if (input_ml_value.IsSparseTensor()) {
|
||||
} else if (input_output_ml_value.IsSparseTensor()) {
|
||||
#if !defined(DISABLE_SPARSE_TENSORS)
|
||||
if (!expected_type->IsSparseTensorType()) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input with name: ", feed_name,
|
||||
" is not expected to be of type sparse tensor.");
|
||||
}
|
||||
auto expected_element_type = expected_type->AsSparseTensorType()->GetElementType();
|
||||
const SparseTensor& sparse_tensor = input_ml_value.Get<SparseTensor>();
|
||||
auto input_element_type = sparse_tensor.DataType();
|
||||
ORT_RETURN_IF_ERROR_SESSIONID_(CheckTypes(input_element_type, expected_element_type, "sparse_tensor"));
|
||||
// Check shape
|
||||
const auto& expected_shape = iter->second.tensor_shape;
|
||||
if (expected_shape.NumDimensions() > 0) {
|
||||
const auto& input_shape = sparse_tensor.DenseShape();
|
||||
ORT_RETURN_IF_ERROR_SESSIONID_(CheckShapes(feed_name, input_shape, expected_shape));
|
||||
|
||||
const SparseTensor& sparse_tensor = input_output_ml_value.Get<SparseTensor>();
|
||||
if (expected_type->IsSparseTensorType()) {
|
||||
auto expected_element_type = expected_type->AsSparseTensorType()->GetElementType();
|
||||
ORT_RETURN_IF_ERROR_SESSIONID_(CheckTypes(sparse_tensor.DataType(), expected_element_type,
|
||||
"sparse_tensor", input_output_moniker));
|
||||
// Check shape
|
||||
if (iter->second.tensor_shape.has_value()) {
|
||||
ORT_RETURN_IF_ERROR_SESSIONID_(CheckShapes(name, sparse_tensor.DenseShape(),
|
||||
*iter->second.tensor_shape, input_output_moniker));
|
||||
}
|
||||
} else if (is_sparse_initializer(name) &&
|
||||
expected_type->IsTensorType()) {
|
||||
// If this metadata came from a sparse initializer converted to dense, then still validate it.
|
||||
auto expected_element_type = expected_type->AsTensorType()->GetElementType();
|
||||
ORT_RETURN_IF_ERROR_SESSIONID_(CheckTypes(sparse_tensor.DataType(), expected_element_type,
|
||||
"sparse_tensor", input_output_moniker));
|
||||
// Check shape
|
||||
if (iter->second.tensor_shape.has_value()) {
|
||||
ORT_RETURN_IF_ERROR_SESSIONID_(CheckShapes(name, sparse_tensor.DenseShape(),
|
||||
*iter->second.tensor_shape, input_output_moniker));
|
||||
}
|
||||
} else {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, input_output_moniker, " with name: '", name,
|
||||
"' expected to be of type: ", static_cast<int>(expected_type->type_), " but received a sparse tensor");
|
||||
}
|
||||
#else
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input with name ", feed_name,
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, input_output_moniker, " with name ", name,
|
||||
" is a sparse tensor, which is not supported in this build.");
|
||||
#endif
|
||||
|
||||
} else if (input_ml_value.IsTensorSequence()) {
|
||||
} else if (input_output_ml_value.IsTensorSequence()) {
|
||||
if (!expected_type->IsTensorSequenceType()
|
||||
#if !defined(DISABLE_OPTIONAL_TYPE)
|
||||
&& !utils::IsOptionalSeqTensor(expected_type)
|
||||
#endif
|
||||
) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input with name: ", feed_name,
|
||||
" is not expected to be of type tensor sequence.");
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, input_output_moniker, " with name: '", name,
|
||||
"' expected to be of type: ", static_cast<int>(expected_type->type_), " but received a tensor sequence");
|
||||
}
|
||||
|
||||
#if !defined(DISABLE_OPTIONAL_TYPE)
|
||||
|
|
@ -1969,43 +2000,40 @@ common::Status InferenceSession::ValidateInputs(gsl::span<const std::string> fee
|
|||
auto expected_element_type = expected_type->AsSequenceTensorType()->GetElementType();
|
||||
#endif
|
||||
|
||||
auto input_element_type = input_ml_value.Get<TensorSeq>().DataType();
|
||||
ORT_RETURN_IF_ERROR_SESSIONID_(CheckTypes(input_element_type, expected_element_type, "seq"));
|
||||
auto input_output_element_type = input_output_ml_value.Get<TensorSeq>().DataType();
|
||||
ORT_RETURN_IF_ERROR_SESSIONID_(CheckTypes(input_output_element_type, expected_element_type, "seq", input_output_moniker));
|
||||
} else {
|
||||
auto input_type = input_ml_value.Type();
|
||||
ORT_RETURN_IF_ERROR_SESSIONID_(CheckTypes(input_type, expected_type, ""));
|
||||
auto input_output_type = input_output_ml_value.Type();
|
||||
ORT_RETURN_IF_ERROR_SESSIONID_(CheckTypes(input_output_type, expected_type, "", input_output_moniker));
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
common::Status InferenceSession::ValidateInputs(gsl::span<const std::string> feed_names,
|
||||
gsl::span<const OrtValue> feeds) const {
|
||||
return ValidateInputsOutputs(feed_names, feeds, input_def_map_, ArgType::kInput);
|
||||
}
|
||||
|
||||
common::Status InferenceSession::ValidateOutputs(gsl::span<const std::string> output_names,
|
||||
const std::vector<OrtValue>* p_fetches) const {
|
||||
if (p_fetches == nullptr) {
|
||||
return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Output vector pointer is NULL");
|
||||
}
|
||||
|
||||
if (output_names.empty()) {
|
||||
return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "At least one output should be requested.");
|
||||
}
|
||||
|
||||
if (!p_fetches->empty() && (output_names.size() != p_fetches->size())) {
|
||||
std::ostringstream ostr;
|
||||
ostr << "Output vector incorrectly sized: output_names.size(): " << output_names.size()
|
||||
<< "p_fetches->size(): " << p_fetches->size();
|
||||
return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, ostr.str());
|
||||
}
|
||||
const auto fetches = (p_fetches == nullptr) ? EmptySpan<const OrtValue>() : gsl::make_span(*p_fetches);
|
||||
|
||||
for (const auto& name : output_names) {
|
||||
if (model_output_names_.find(name) == model_output_names_.end()) {
|
||||
return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Invalid Output Name:" + name);
|
||||
if (fetches.empty()) {
|
||||
for (const auto& name : output_names) {
|
||||
if (output_def_map_.count(name) == 0) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid output name:", name);
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// TODO add more validation here like checking shape of the allocated buffers
|
||||
|
||||
return common::Status::OK();
|
||||
return ValidateInputsOutputs(output_names, fetches, output_def_map_, ArgType::kOutput);
|
||||
}
|
||||
|
||||
#ifdef ENABLE_TRAINING
|
||||
|
|
@ -2483,7 +2511,7 @@ std::pair<common::Status, const OutputDefList*> InferenceSession::GetModelOutput
|
|||
}
|
||||
}
|
||||
|
||||
return std::make_pair(common::Status::OK(), &output_def_list_);
|
||||
return std::make_pair(common::Status::OK(), &model_->MainGraph().GetOutputs());
|
||||
}
|
||||
|
||||
common::Status InferenceSession::NewIOBinding(std::unique_ptr<IOBinding>* io_binding) {
|
||||
|
|
@ -2697,43 +2725,40 @@ common::Status InferenceSession::SaveModelMetadata(const onnxruntime::Model& mod
|
|||
model_metadata_.custom_metadata_map = model.MetaData();
|
||||
model_metadata_.graph_name = graph.Name();
|
||||
|
||||
required_inputs_.clear();
|
||||
for (auto input : graph.GetInputs()) {
|
||||
required_inputs_.insert(input->Name());
|
||||
}
|
||||
|
||||
auto add_inputs = [this](const InputDefList& inputs) {
|
||||
input_def_map_.clear();
|
||||
input_def_map_.reserve(inputs.size());
|
||||
for (auto elem : inputs) {
|
||||
auto add_inputs_outputs = [](const InputDefList& inputs_outputs, InputOutputDefMetaMap& map) {
|
||||
map.reserve(inputs_outputs.size());
|
||||
for (auto elem : inputs_outputs) {
|
||||
auto elem_type = utils::GetMLDataType(*elem);
|
||||
auto elem_shape_proto = elem->Shape();
|
||||
input_def_map_.insert(
|
||||
{elem->Name(),
|
||||
InputDefMetaData(
|
||||
elem, elem_type,
|
||||
elem_shape_proto ? utils::GetTensorShapeFromTensorShapeProto(*elem_shape_proto) : TensorShape())});
|
||||
const auto* elem_shape_proto = elem->Shape();
|
||||
if (elem_shape_proto != nullptr) {
|
||||
map.emplace(elem->Name(), InputOutputDefMetaData(
|
||||
elem, elem_type,
|
||||
utils::GetTensorShapeFromTensorShapeProto(*elem_shape_proto)));
|
||||
} else {
|
||||
map.emplace(elem->Name(), InputOutputDefMetaData(elem, elem_type));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if (graph.CanOverrideInitializer()) {
|
||||
// for IR 4 or higher it is optional to have a matching graph input for an initializer, and if one exists the
|
||||
// initializer is explicitly overridable.
|
||||
add_inputs(graph.GetInputsIncludingInitializers());
|
||||
} else {
|
||||
// for IR < 4 we don't allow overriding initializers so that they can be treated as constant. exclude them from
|
||||
// the list of valid inputs by just using the GetInputs() list.
|
||||
add_inputs(graph.GetInputs());
|
||||
{
|
||||
InputOutputDefMetaMap input_defs;
|
||||
if (graph.CanOverrideInitializer()) {
|
||||
// for IR 4 or higher it is optional to have a matching graph input for an initializer, and if one exists the
|
||||
// initializer is explicitly overridable.
|
||||
add_inputs_outputs(graph.GetInputsIncludingInitializers(), input_defs);
|
||||
} else {
|
||||
// for IR < 4 we don't allow overriding initializers so that they can be treated as constant. exclude them from
|
||||
// the list of valid inputs by just using the GetInputs() list.
|
||||
add_inputs_outputs(graph.GetInputs(), input_defs);
|
||||
}
|
||||
input_def_map_.swap(input_defs);
|
||||
}
|
||||
|
||||
// save outputs
|
||||
const auto& outputs = graph.GetOutputs();
|
||||
output_def_list_ = outputs; // A direct copy of outputs
|
||||
|
||||
model_output_names_.clear();
|
||||
model_output_names_.reserve(outputs.size());
|
||||
for (const auto& elem : outputs) {
|
||||
model_output_names_.insert(elem->Name());
|
||||
{
|
||||
InputOutputDefMetaMap output_defs;
|
||||
add_inputs_outputs(outputs, output_defs);
|
||||
output_def_map_.swap(output_defs);
|
||||
}
|
||||
|
||||
VLOGS(*session_logger_, 1) << "Done saving model metadata";
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
|
||||
|
|
@ -103,6 +104,22 @@ struct ModelMetadata {
|
|||
*/
|
||||
|
||||
class InferenceSession {
|
||||
struct InputOutputDefMetaData {
|
||||
InputOutputDefMetaData(const NodeArg* node_arg0, MLDataType ml_data_type0, TensorShape&& tensor_shape0)
|
||||
: node_arg(node_arg0), ml_data_type(ml_data_type0), tensor_shape(std::move(tensor_shape0)) {
|
||||
}
|
||||
|
||||
InputOutputDefMetaData(const NodeArg* node_arg0, MLDataType ml_data_type0)
|
||||
: node_arg(node_arg0), ml_data_type(ml_data_type0) {
|
||||
}
|
||||
|
||||
gsl::not_null<const NodeArg*> node_arg;
|
||||
MLDataType ml_data_type;
|
||||
std::optional<TensorShape> tensor_shape; // not applicable if the input is non-tensor type
|
||||
};
|
||||
|
||||
using InputOutputDefMetaMap = InlinedHashMap<std::string_view, InputOutputDefMetaData>;
|
||||
|
||||
public:
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
|
||||
|
|
@ -570,9 +587,6 @@ class InferenceSession {
|
|||
// if they need.
|
||||
std::shared_ptr<onnxruntime::Model> model_;
|
||||
|
||||
// names of model outputs used for quick validation.
|
||||
std::unordered_set<std::string> model_output_names_;
|
||||
|
||||
// The file path of where the model was loaded. e.g. /tmp/test_squeezenet/model.onnx
|
||||
PathString model_location_;
|
||||
|
||||
|
|
@ -628,7 +642,7 @@ class InferenceSession {
|
|||
void InitLogger(logging::LoggingManager* logging_manager);
|
||||
|
||||
[[nodiscard]] common::Status CheckShapes(const std::string& input_name, const TensorShape& input_shape,
|
||||
const TensorShape& expected_shape) const;
|
||||
const TensorShape& expected_shape, const char* input_output_moniker) const;
|
||||
|
||||
[[nodiscard]] common::Status ValidateInputs(gsl::span<const std::string> feed_names,
|
||||
gsl::span<const OrtValue> feeds) const;
|
||||
|
|
@ -636,6 +650,11 @@ class InferenceSession {
|
|||
[[nodiscard]] common::Status ValidateOutputs(gsl::span<const std::string> output_names,
|
||||
const std::vector<OrtValue>* p_fetches) const;
|
||||
|
||||
[[nodiscard]] common::Status ValidateInputsOutputs(gsl::span<const std::string> feed_fetches_names,
|
||||
gsl::span<const OrtValue> feeds_fetches,
|
||||
const InputOutputDefMetaMap& input_output_meta_map,
|
||||
ArgType arg_type) const;
|
||||
|
||||
[[nodiscard]] common::Status WaitForNotification(Notification* p_executor_done, int64_t timeout_in_ms);
|
||||
|
||||
template <typename T>
|
||||
|
|
@ -737,19 +756,9 @@ class InferenceSession {
|
|||
#endif
|
||||
|
||||
ModelMetadata model_metadata_;
|
||||
std::unordered_set<std::string> required_inputs_;
|
||||
|
||||
struct InputDefMetaData {
|
||||
InputDefMetaData(const NodeArg* node_arg0, MLDataType ml_data_type0, TensorShape&& tensor_shape0)
|
||||
: node_arg(node_arg0), ml_data_type(ml_data_type0), tensor_shape(std::move(tensor_shape0)) {
|
||||
}
|
||||
const NodeArg* node_arg;
|
||||
MLDataType ml_data_type;
|
||||
TensorShape tensor_shape; // not applicable if the input is non-tensor type
|
||||
};
|
||||
|
||||
std::unordered_map<std::string, InputDefMetaData> input_def_map_;
|
||||
OutputDefList output_def_list_;
|
||||
InputOutputDefMetaMap input_def_map_;
|
||||
InputOutputDefMetaMap output_def_map_;
|
||||
|
||||
// Data transfer manager.
|
||||
DataTransferManager data_transfer_mgr_;
|
||||
|
|
|
|||
|
|
@ -496,14 +496,16 @@ TEST(ExecutionFrameTestInit, InitializerAsOutput) {
|
|||
|
||||
#if !defined(DISABLE_SPARSE_TENSORS)
|
||||
TEST(ExecutionFrameTestInit, SparseInitializerAsOutput) {
|
||||
const std::vector<int64_t> dense_shape{3, 3};
|
||||
std::vector<float> dense_data = {
|
||||
0, 0, 1.764052391052246f,
|
||||
0.40015721321105957f, 0, 0.978738009929657f,
|
||||
0, 0, 0};
|
||||
constexpr std::array<int64_t, 2> dense_shape{3, 3};
|
||||
|
||||
const std::vector<float> expected_values = {1.764052391052246f, 0.40015721321105957f, 0.978738009929657f};
|
||||
const std::vector<int64_t> expected_linear_indices = {2, 3, 5};
|
||||
// Tensor data in a dense form, useful for debugging and reference.
|
||||
// constexpr std::array<float, 9> dense_data = {
|
||||
// 0, 0, 1.764052391052246f,
|
||||
// 0.40015721321105957f, 0, 0.978738009929657f,
|
||||
// 0, 0, 0};
|
||||
|
||||
constexpr std::array<float, 3> expected_values = {1.764052391052246f, 0.40015721321105957f, 0.978738009929657f};
|
||||
constexpr std::array<int64_t, 3> expected_linear_indices = {2, 3, 5};
|
||||
|
||||
// sparse_initializer_as_output.onnx
|
||||
SessionOptions so;
|
||||
|
|
@ -515,14 +517,18 @@ TEST(ExecutionFrameTestInit, SparseInitializerAsOutput) {
|
|||
ASSERT_STATUS_OK(session.Initialize());
|
||||
|
||||
auto allocator = test::AllocatorManager::Instance().GetAllocator(CPU);
|
||||
auto p_tensor = std::make_unique<SparseTensor>();
|
||||
|
||||
std::vector<OrtValue> results;
|
||||
results.resize(1);
|
||||
auto ml_type = DataTypeImpl::GetType<SparseTensor>();
|
||||
results[0].Init(p_tensor.release(), ml_type, ml_type->GetDeleteFunc());
|
||||
|
||||
// Initialize the output value as a SparseTensor with pre-allocated memory
|
||||
// this is done here to test output types.
|
||||
auto element_type = DataTypeImpl::GetSparseTensorType<float>()->AsSparseTensorType()->GetElementType();
|
||||
SparseTensor::InitOrtValue(element_type, TensorShape(dense_shape), allocator, results[0]);
|
||||
|
||||
RunOptions ro;
|
||||
ASSERT_STATUS_OK(session.Run(ro, EmptySpan<std::string>(), EmptySpan<OrtValue>(), AsSpan<std::string>({"values"}), &results, nullptr));
|
||||
ASSERT_STATUS_OK(session.Run(ro, EmptySpan<std::string>(), EmptySpan<OrtValue>(),
|
||||
AsSpan<std::string>({"values"}), &results, nullptr));
|
||||
|
||||
ASSERT_TRUE(results[0].IsAllocated());
|
||||
ASSERT_TRUE(results[0].IsSparseTensor());
|
||||
|
|
|
|||
|
|
@ -1218,13 +1218,13 @@ TEST(InferenceSessionTests, TestOptionalInputs) {
|
|||
// required, optional and invalid input
|
||||
status = RunOptionalInputTest(true, true, true, version, sess_env);
|
||||
ASSERT_FALSE(status.IsOK());
|
||||
EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("Invalid Feed Input Name"));
|
||||
EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("Invalid input name"));
|
||||
|
||||
// missing required
|
||||
status = RunOptionalInputTest(false, true, false, version, sess_env);
|
||||
ASSERT_FALSE(status.IsOK());
|
||||
if (version == 3) {
|
||||
EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("Invalid Feed Input Name"));
|
||||
EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("Invalid input name"));
|
||||
} else {
|
||||
EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("Missing Input:"));
|
||||
}
|
||||
|
|
|
|||
|
|
@ -159,8 +159,8 @@ static void TestInference(Ort::Env& env, const std::basic_string<ORTCHAR_T>& mod
|
|||
expected_values_y,
|
||||
nullptr);
|
||||
// with preallocated output tensor
|
||||
Ort::Value value_y = Ort::Value::CreateTensor<InT>(default_allocator.get(),
|
||||
expected_dims_y.data(), expected_dims_y.size());
|
||||
Ort::Value value_y = Ort::Value::CreateTensor<OutT>(default_allocator.get(),
|
||||
expected_dims_y.data(), expected_dims_y.size());
|
||||
|
||||
// test it twice
|
||||
for (int i = 0; i != 2; ++i)
|
||||
|
|
|
|||
Loading…
Reference in a new issue