Return error message from status instead of swallowing it. (#1221)

* Return error message from status instead of swallowing it.

* Return OrtValue* from OpKernelContext::GetOrCreateOutputMLValue

* Add unstaged change.
This commit is contained in:
Scott McKay 2019-06-22 06:26:42 +10:00 committed by GitHub
parent 18b7d2b18a
commit 4d765dc6d0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 34 additions and 20 deletions

View file

@@ -93,8 +93,7 @@ class OpKernelContext {
if (index < 0 || index >= OutputCount())
return nullptr;
OrtValue* p_ml_value = nullptr;
ORT_ENFORCE(GetOrCreateOutputMLValue(index, p_ml_value).IsOK());
OrtValue* p_ml_value = GetOrCreateOutputMLValue(index);
return p_ml_value ? p_ml_value->GetMutable<T>() : nullptr;
}
@@ -174,7 +173,7 @@ class OpKernelContext {
private:
ORT_DISALLOW_COPY_AND_ASSIGNMENT(OpKernelContext);
Status GetOrCreateOutputMLValue(int index, OrtValue*& value);
OrtValue* GetOrCreateOutputMLValue(int index);
int GetInputArgIndex(int index) const;
int GetImplicitInputArgIndex(int index) const;

View file

@@ -102,10 +102,12 @@ Fence_t OpKernelContext::OutputFence(int index) const {
return p_ml_value ? p_ml_value->Fence() : nullptr;
}
Status OpKernelContext::GetOrCreateOutputMLValue(int index, OrtValue*& p_value) {
OrtValue* OpKernelContext::GetOrCreateOutputMLValue(int index) {
auto output_arg_index = GetOutputArgIndex(index);
ORT_ENFORCE(execution_frame_->GetOrCreateNodeOutputMLValue(output_arg_index, nullptr, p_value).IsOK());
return Status::OK();
OrtValue* value = nullptr;
auto status = execution_frame_->GetOrCreateNodeOutputMLValue(output_arg_index, nullptr, value);
ORT_ENFORCE(status.IsOK(), status.ErrorMessage());
return value;
}
int OpKernelContext::GetInputArgIndex(int index) const {

View file

@@ -181,7 +181,8 @@ void ParallelExecutor::RunNodeAsyncInternal(size_t p_node_index,
// Execute the kernel.
auto status = p_op_kernel->Compute(&op_kernel_context);
if (!status.IsOK()) {
ORT_THROW("Compute failed for node: ", graph_viewer->GetNode(node_index)->Name());
ORT_THROW("Compute failed for node: ", graph_viewer->GetNode(node_index)->Name(),
". Error:", status.ErrorMessage());
}
if (f_profiler_enabled) {
session_state.Profiler().EndTimeAndRecordEvent(profiling::NODE_EVENT,

View file

@@ -96,7 +96,14 @@ inline int CompareCString<wchar_t>(const wchar_t* s1, const wchar_t* s2) {
return wcscmp(s1, s2);
}
enum class OrtFileType { TYPE_BLK, TYPE_CHR, TYPE_DIR, TYPE_FIFO, TYPE_LNK, TYPE_REG, TYPE_SOCK, TYPE_UNKNOWN };
enum class OrtFileType { TYPE_BLK,
TYPE_CHR,
TYPE_DIR,
TYPE_FIFO,
TYPE_LNK,
TYPE_REG,
TYPE_SOCK,
TYPE_UNKNOWN };
template <typename PATH_CHAR_TYPE>
PATH_CHAR_TYPE GetPathSep();
@@ -236,7 +243,7 @@ void LoopDir(const std::string& dir_name, T func) {
auto e = errno;
char buf[1024];
char* msg;
#if defined(__GLIBC__) && defined(_GNU_SOURCE) && !defined (__ANDROID__)
#if defined(__GLIBC__) && defined(_GNU_SOURCE) && !defined(__ANDROID__)
msg = strerror_r(e, buf, sizeof(buf));
#else
if (strerror_r(e, buf, sizeof(buf)) != 0) {
@@ -266,7 +273,8 @@ void LoopDir(const std::string& dir_name, T func) {
template <typename T>
inline T ReplaceFilename(const T& input, const T& new_value) {
T ret;
ORT_ENFORCE(GetDirNameFromFilePath(input, ret).IsOK());
auto status = GetDirNameFromFilePath(input, ret);
ORT_ENFORCE(status.IsOK(), status.ErrorMessage());
return ConcatPathComponent(ret, new_value);
}

View file

@@ -149,7 +149,8 @@ FunctionImpl::FunctionImpl(const onnxruntime::Graph& graph,
}
//TODO: if we reuse the nodes in parent graph, maybe we don't need to resolve it.
ORT_ENFORCE(sub_graph.Resolve().IsOK());
auto status = sub_graph.Resolve();
ORT_ENFORCE(status.IsOK(), status.ErrorMessage());
}
FunctionImpl::FunctionImpl(const onnxruntime::Graph& graph,

View file

@@ -157,14 +157,16 @@ void TransformerMemcpyImpl::ProcessDefs(onnxruntime::Node& node, const KernelReg
const KernelCreateInfo* kci = nullptr;
kernel_registries.SearchKernelRegistry(node, &kci);
ORT_ENFORCE(onnxruntime::Node::ForEachWithIndex(node.InputDefs(), [this, &kci](const onnxruntime::NodeArg& arg,
size_t index) {
if (kci && kci->kernel_def->IsInputOnCpu(index))
non_provider_input_defs_.insert(&arg);
else
provider_input_defs_.insert(&arg);
return Status::OK();
}).IsOK());
auto status = onnxruntime::Node::ForEachWithIndex(node.InputDefs(),
[this, &kci](const onnxruntime::NodeArg& arg, size_t index) {
if (kci && kci->kernel_def->IsInputOnCpu(index))
non_provider_input_defs_.insert(&arg);
else
provider_input_defs_.insert(&arg);
return Status::OK();
});
ORT_ENFORCE(status.IsOK(), status.ErrorMessage());
// we don't need to handle implicit input here as provider_ is never kCpuExecutionProvider, all control flow
// nodes are CPU based, and only control flow nodes have implicit inputs.

View file

@@ -540,7 +540,8 @@ struct TBroadcastOutput {
template <typename T>
struct TensorAllocator {
TensorAllocator(OpKernelContext& context) {
ORT_ENFORCE(context.GetTempSpaceAllocator(&allocator_).IsOK());
auto status = context.GetTempSpaceAllocator(&allocator_);
ORT_ENFORCE(status.IsOK(), status.ErrorMessage());
}
std::unique_ptr<Tensor> Allocate(const TensorShape& shape) {