mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-16 21:00:14 +00:00
Return error message from status instead of swallowing it. (#1221)
* Return error message from status instead of swallowing it. * Return OrtValue* from OpKernelContext::GetOrCreateOutputMLValue * Add unstaged change.
This commit is contained in:
parent
18b7d2b18a
commit
4d765dc6d0
7 changed files with 34 additions and 20 deletions
|
|
@ -93,8 +93,7 @@ class OpKernelContext {
|
|||
if (index < 0 || index >= OutputCount())
|
||||
return nullptr;
|
||||
|
||||
OrtValue* p_ml_value = nullptr;
|
||||
ORT_ENFORCE(GetOrCreateOutputMLValue(index, p_ml_value).IsOK());
|
||||
OrtValue* p_ml_value = GetOrCreateOutputMLValue(index);
|
||||
return p_ml_value ? p_ml_value->GetMutable<T>() : nullptr;
|
||||
}
|
||||
|
||||
|
|
@ -174,7 +173,7 @@ class OpKernelContext {
|
|||
private:
|
||||
ORT_DISALLOW_COPY_AND_ASSIGNMENT(OpKernelContext);
|
||||
|
||||
Status GetOrCreateOutputMLValue(int index, OrtValue*& value);
|
||||
OrtValue* GetOrCreateOutputMLValue(int index);
|
||||
|
||||
int GetInputArgIndex(int index) const;
|
||||
int GetImplicitInputArgIndex(int index) const;
|
||||
|
|
|
|||
|
|
@ -102,10 +102,12 @@ Fence_t OpKernelContext::OutputFence(int index) const {
|
|||
return p_ml_value ? p_ml_value->Fence() : nullptr;
|
||||
}
|
||||
|
||||
Status OpKernelContext::GetOrCreateOutputMLValue(int index, OrtValue*& p_value) {
|
||||
OrtValue* OpKernelContext::GetOrCreateOutputMLValue(int index) {
|
||||
auto output_arg_index = GetOutputArgIndex(index);
|
||||
ORT_ENFORCE(execution_frame_->GetOrCreateNodeOutputMLValue(output_arg_index, nullptr, p_value).IsOK());
|
||||
return Status::OK();
|
||||
OrtValue* value = nullptr;
|
||||
auto status = execution_frame_->GetOrCreateNodeOutputMLValue(output_arg_index, nullptr, value);
|
||||
ORT_ENFORCE(status.IsOK(), status.ErrorMessage());
|
||||
return value;
|
||||
}
|
||||
|
||||
int OpKernelContext::GetInputArgIndex(int index) const {
|
||||
|
|
|
|||
|
|
@ -181,7 +181,8 @@ void ParallelExecutor::RunNodeAsyncInternal(size_t p_node_index,
|
|||
// Execute the kernel.
|
||||
auto status = p_op_kernel->Compute(&op_kernel_context);
|
||||
if (!status.IsOK()) {
|
||||
ORT_THROW("Compute failed for node: ", graph_viewer->GetNode(node_index)->Name());
|
||||
ORT_THROW("Compute failed for node: ", graph_viewer->GetNode(node_index)->Name(),
|
||||
". Error:", status.ErrorMessage());
|
||||
}
|
||||
if (f_profiler_enabled) {
|
||||
session_state.Profiler().EndTimeAndRecordEvent(profiling::NODE_EVENT,
|
||||
|
|
|
|||
|
|
@ -96,7 +96,14 @@ inline int CompareCString<wchar_t>(const wchar_t* s1, const wchar_t* s2) {
|
|||
return wcscmp(s1, s2);
|
||||
}
|
||||
|
||||
enum class OrtFileType { TYPE_BLK, TYPE_CHR, TYPE_DIR, TYPE_FIFO, TYPE_LNK, TYPE_REG, TYPE_SOCK, TYPE_UNKNOWN };
|
||||
enum class OrtFileType { TYPE_BLK,
|
||||
TYPE_CHR,
|
||||
TYPE_DIR,
|
||||
TYPE_FIFO,
|
||||
TYPE_LNK,
|
||||
TYPE_REG,
|
||||
TYPE_SOCK,
|
||||
TYPE_UNKNOWN };
|
||||
|
||||
template <typename PATH_CHAR_TYPE>
|
||||
PATH_CHAR_TYPE GetPathSep();
|
||||
|
|
@ -236,7 +243,7 @@ void LoopDir(const std::string& dir_name, T func) {
|
|||
auto e = errno;
|
||||
char buf[1024];
|
||||
char* msg;
|
||||
#if defined(__GLIBC__) && defined(_GNU_SOURCE) && !defined (__ANDROID__)
|
||||
#if defined(__GLIBC__) && defined(_GNU_SOURCE) && !defined(__ANDROID__)
|
||||
msg = strerror_r(e, buf, sizeof(buf));
|
||||
#else
|
||||
if (strerror_r(e, buf, sizeof(buf)) != 0) {
|
||||
|
|
@ -266,7 +273,8 @@ void LoopDir(const std::string& dir_name, T func) {
|
|||
template <typename T>
|
||||
inline T ReplaceFilename(const T& input, const T& new_value) {
|
||||
T ret;
|
||||
ORT_ENFORCE(GetDirNameFromFilePath(input, ret).IsOK());
|
||||
auto status = GetDirNameFromFilePath(input, ret);
|
||||
ORT_ENFORCE(status.IsOK(), status.ErrorMessage());
|
||||
return ConcatPathComponent(ret, new_value);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -149,7 +149,8 @@ FunctionImpl::FunctionImpl(const onnxruntime::Graph& graph,
|
|||
}
|
||||
|
||||
//TODO: if we reuse the nodes in parent graph, maybe we don't need to resolve it.
|
||||
ORT_ENFORCE(sub_graph.Resolve().IsOK());
|
||||
auto status = sub_graph.Resolve();
|
||||
ORT_ENFORCE(status.IsOK(), status.ErrorMessage());
|
||||
}
|
||||
|
||||
FunctionImpl::FunctionImpl(const onnxruntime::Graph& graph,
|
||||
|
|
|
|||
|
|
@ -157,14 +157,16 @@ void TransformerMemcpyImpl::ProcessDefs(onnxruntime::Node& node, const KernelReg
|
|||
const KernelCreateInfo* kci = nullptr;
|
||||
kernel_registries.SearchKernelRegistry(node, &kci);
|
||||
|
||||
ORT_ENFORCE(onnxruntime::Node::ForEachWithIndex(node.InputDefs(), [this, &kci](const onnxruntime::NodeArg& arg,
|
||||
size_t index) {
|
||||
if (kci && kci->kernel_def->IsInputOnCpu(index))
|
||||
non_provider_input_defs_.insert(&arg);
|
||||
else
|
||||
provider_input_defs_.insert(&arg);
|
||||
return Status::OK();
|
||||
}).IsOK());
|
||||
auto status = onnxruntime::Node::ForEachWithIndex(node.InputDefs(),
|
||||
[this, &kci](const onnxruntime::NodeArg& arg, size_t index) {
|
||||
if (kci && kci->kernel_def->IsInputOnCpu(index))
|
||||
non_provider_input_defs_.insert(&arg);
|
||||
else
|
||||
provider_input_defs_.insert(&arg);
|
||||
return Status::OK();
|
||||
});
|
||||
|
||||
ORT_ENFORCE(status.IsOK(), status.ErrorMessage());
|
||||
|
||||
// we don't need to handle implicit input here as provider_ is never kCpuExecutionProvider, all control flow
|
||||
// nodes are CPU based, and only control flow nodes have implicit inputs.
|
||||
|
|
|
|||
|
|
@ -540,7 +540,8 @@ struct TBroadcastOutput {
|
|||
template <typename T>
|
||||
struct TensorAllocator {
|
||||
TensorAllocator(OpKernelContext& context) {
|
||||
ORT_ENFORCE(context.GetTempSpaceAllocator(&allocator_).IsOK());
|
||||
auto status = context.GetTempSpaceAllocator(&allocator_);
|
||||
ORT_ENFORCE(status.IsOK(), status.ErrorMessage());
|
||||
}
|
||||
|
||||
std::unique_ptr<Tensor> Allocate(const TensorShape& shape) {
|
||||
|
|
|
|||
Loading…
Reference in a new issue