From 131c65d23d040079d50a6dc3e85f72f0de401346 Mon Sep 17 00:00:00 2001 From: ytaous <4484531+ytaous@users.noreply.github.com> Date: Thu, 26 Mar 2020 18:43:30 -0700 Subject: [PATCH] Addressing PR comments (#3334) * PR comments * PR comments * PR comments * error out bad shape Co-authored-by: Ethan Tao --- onnxruntime/core/framework/session_state.cc | 13 ++++---- .../core/providers/cpu/tensor/slice.cc | 32 +++++++++---------- onnxruntime/core/providers/cpu/tensor/slice.h | 4 +-- .../providers/cuda/cuda_provider_factory.cc | 5 ++- .../core/providers/cuda/tensor/slice.cc | 2 +- onnxruntime/core/session/inference_session.cc | 3 -- onnxruntime/core/session/inference_session.h | 4 +-- .../core/session/training_session.cc | 2 +- .../training_ops/cpu/tensor/slice_grad.cc | 2 +- .../training_ops/cuda/tensor/slice_grad.cc | 2 +- 10 files changed, 35 insertions(+), 34 deletions(-) diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc index eb78ff89f0..2c7560df19 100644 --- a/onnxruntime/core/framework/session_state.cc +++ b/onnxruntime/core/framework/session_state.cc @@ -188,8 +188,7 @@ static int64_t CalculateMemoryPatternsKey(const std::vector& feeds, std::unordered_map& out) { - for (size_t i = 0; i < graph.GetInputs().size(); ++i) { - auto* input = graph.GetInputs()[i]; + for (const auto* input : graph.GetInputs()) { auto* shape = input->Shape(); auto it = feeds.find(input->Name()); if (it == feeds.end()) @@ -198,7 +197,7 @@ Status ResolveDimParams(const GraphViewer& graph, const std::mapName() + "'s shape is not present or its shape doesn't match feed's shape." "Unable to resolve the value for dynamic shape"); - for (int k = 0; k < shape->dim_size(); ++k) { + for (int k = 0, end = shape->dim_size(); k < end; ++k) { if (shape->dim()[k].has_dim_param()) { out.insert({shape->dim()[k].dim_param(), it->second.GetDims()[k]}); } @@ -212,7 +211,7 @@ Status SessionState::GeneratePatternGroupCache(const std::vector& feed_mlvalue_idxs, MemoryPatternGroup* output) const { std::map feeds; - for (size_t i = 0; i < feed_mlvalue_idxs.size(); ++i) { + for (size_t i = 0, end = feed_mlvalue_idxs.size(); i < end; ++i) { std::string name; ORT_RETURN_IF_ERROR(this->ort_value_name_idx_map_.GetName(feed_mlvalue_idxs[i], name)); feeds.insert({name, input_shape[i]}); @@ -228,7 +227,7 @@ Status SessionState::GeneratePatternGroupCache(const std::vectorGetNode(node_plan.node_index); int output_start = node_index + static_cast(node->InputDefs().size()) + static_cast(node->ImplicitInputDefs().size()); //allocate output - for (int i = 0; i < static_cast(node->OutputDefs().size()); ++i) { + for (int i = 0, end = static_cast(node->OutputDefs().size()); i < end; ++i) { const auto ml_value_idx = node_index_info.GetMLValueIndex(output_start + i); if (ml_value_idx == NodeIndexInfo::kInvalidEntry) continue; @@ -251,8 +250,10 @@ Status SessionState::GeneratePatternGroupCache(const std::vectorsecond; - } else { + } else if (dim.has_dim_value()) { len *= dim.dim_value(); + } else { + return Status(ONNXRUNTIME, FAIL, "Unknown shape found in memory pattern compute"); } } if (!IAllocator::CalcMemSizeForArrayWithAlignment<64>(len, ml_data_type->Size(), &size)) { diff --git a/onnxruntime/core/providers/cpu/tensor/slice.cc b/onnxruntime/core/providers/cpu/tensor/slice.cc index a4350396e2..640df73199 100644 --- a/onnxruntime/core/providers/cpu/tensor/slice.cc +++ b/onnxruntime/core/providers/cpu/tensor/slice.cc @@ -260,21 +260,21 @@ Status SliceBase::PrepareForCompute(const std::vector& raw_starts, } // Slice V10 & DynamicSlice -void SliceBase::FillVectorsFromInput(const Tensor* start_tensor, - const Tensor* ends_tensor, +void SliceBase::FillVectorsFromInput(const Tensor& start_tensor, + const Tensor& ends_tensor, const Tensor* axes_tensor, const Tensor* steps_tensor, std::vector& input_starts, std::vector& input_ends, std::vector& input_axes, std::vector& input_steps) const { - ORT_ENFORCE(nullptr != start_tensor && start_tensor->Shape().NumDimensions() == 1, "Starts must be a 1-D array"); - ORT_ENFORCE(nullptr != ends_tensor && ends_tensor->Shape().NumDimensions() == 1, "Ends must be a 1-D array"); - ORT_ENFORCE(start_tensor->Shape() == ends_tensor->Shape(), "Starts and ends shape mismatch"); - ORT_ENFORCE(nullptr == axes_tensor || start_tensor->Shape() == axes_tensor->Shape(), "Starts and axes shape mismatch"); - ORT_ENFORCE(nullptr == steps_tensor || start_tensor->Shape() == steps_tensor->Shape(), "Starts and steps shape mismatch"); + ORT_ENFORCE(start_tensor.Shape().NumDimensions() == 1, "Starts must be a 1-D array"); + ORT_ENFORCE(ends_tensor.Shape().NumDimensions() == 1, "Ends must be a 1-D array"); + ORT_ENFORCE(start_tensor.Shape() == ends_tensor.Shape(), "Starts and ends shape mismatch"); + ORT_ENFORCE(nullptr == axes_tensor || start_tensor.Shape() == axes_tensor->Shape(), "Starts and axes shape mismatch"); + ORT_ENFORCE(nullptr == steps_tensor || start_tensor.Shape() == steps_tensor->Shape(), "Starts and steps shape mismatch"); - const auto& size = start_tensor->Shape().Size(); + const auto& size = start_tensor.Shape().Size(); input_starts.resize(size); input_ends.resize(size); if (nullptr != axes_tensor) @@ -283,9 +283,9 @@ void SliceBase::FillVectorsFromInput(const Tensor* start_tensor, if (nullptr != steps_tensor) input_steps.resize(size); - if (start_tensor->IsDataType()) { - std::copy(start_tensor->Data(), start_tensor->Data() + size, input_starts.begin()); - std::copy(ends_tensor->Data(), ends_tensor->Data() + size, input_ends.begin()); + if (start_tensor.IsDataType()) { + std::copy(start_tensor.Data(), start_tensor.Data() + size, input_starts.begin()); + std::copy(ends_tensor.Data(), ends_tensor.Data() + size, input_ends.begin()); if (nullptr != axes_tensor) std::copy(axes_tensor->Data(), axes_tensor->Data() + size, input_axes.begin()); // Slice V10 @@ -293,9 +293,9 @@ void SliceBase::FillVectorsFromInput(const Tensor* start_tensor, std::copy(steps_tensor->Data(), steps_tensor->Data() + size, input_steps.begin()); } - else if (start_tensor->IsDataType()) { - std::copy(start_tensor->Data(), start_tensor->Data() + size, input_starts.begin()); - std::copy(ends_tensor->Data(), ends_tensor->Data() + size, input_ends.begin()); + else if (start_tensor.IsDataType()) { + std::copy(start_tensor.Data(), start_tensor.Data() + size, input_starts.begin()); + std::copy(ends_tensor.Data(), ends_tensor.Data() + size, input_ends.begin()); if (nullptr != axes_tensor) std::copy(axes_tensor->Data(), axes_tensor->Data() + size, input_axes.begin()); // Slice V10 @@ -305,7 +305,7 @@ void SliceBase::FillVectorsFromInput(const Tensor* start_tensor, // should not reach this as no kernel is registered for this condition to be triggered - just an additional safety check else { - ORT_THROW("Data type for starts and ends inputs' need to be int32_t or int64_t, but instead got ", start_tensor->DataType()); + ORT_THROW("Data type for starts and ends inputs' need to be int32_t or int64_t, but instead got ", start_tensor.DataType()); } } @@ -379,7 +379,7 @@ Status Slice::Compute(OpKernelContext* ctx) const { std::vector input_ends; std::vector input_axes; std::vector input_steps; - FillVectorsFromInput(ctx->Input(1), ctx->Input(2), ctx->Input(3), + FillVectorsFromInput(*ctx->Input(1), *ctx->Input(2), ctx->Input(3), ctx->Input(4), input_starts, input_ends, input_axes, input_steps); ORT_RETURN_IF_ERROR(PrepareForCompute(input_starts, input_ends, input_axes, input_steps, diff --git a/onnxruntime/core/providers/cpu/tensor/slice.h b/onnxruntime/core/providers/cpu/tensor/slice.h index 7aabd5e636..9abd928e37 100644 --- a/onnxruntime/core/providers/cpu/tensor/slice.h +++ b/onnxruntime/core/providers/cpu/tensor/slice.h @@ -43,8 +43,8 @@ class SliceBase { std::vector*& flattened_output_dims) const; // Slice V10 & DynamicSlice - void FillVectorsFromInput(const Tensor* start_tensor, - const Tensor* ends_tensor, + void FillVectorsFromInput(const Tensor& start_tensor, + const Tensor& ends_tensor, const Tensor* axes_tensor, const Tensor* steps_tensor, std::vector& input_starts, diff --git a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc index 052c9de9f1..f5026387f9 100644 --- a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc +++ b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc @@ -15,7 +15,10 @@ namespace onnxruntime { struct CUDAProviderFactory : IExecutionProviderFactory { CUDAProviderFactory(OrtDevice::DeviceId device_id, size_t cuda_mem_limit = std::numeric_limits::max(), - ArenaExtendStrategy arena_extend_strategy = ArenaExtendStrategy::kNextPowerOfTwo) : device_id_(device_id), cuda_mem_limit_(cuda_mem_limit), arena_extend_strategy_(arena_extend_strategy) {} + ArenaExtendStrategy arena_extend_strategy = ArenaExtendStrategy::kNextPowerOfTwo) + : device_id_(device_id), + cuda_mem_limit_(cuda_mem_limit), + arena_extend_strategy_(arena_extend_strategy) {} ~CUDAProviderFactory() override {} std::unique_ptr CreateProvider() override; diff --git a/onnxruntime/core/providers/cuda/tensor/slice.cc b/onnxruntime/core/providers/cuda/tensor/slice.cc index 328dfec66b..640a804aab 100644 --- a/onnxruntime/core/providers/cuda/tensor/slice.cc +++ b/onnxruntime/core/providers/cuda/tensor/slice.cc @@ -156,7 +156,7 @@ template void Slice::FillInputVectors(OpKernelContext* ctx, std::vector& input_starts, std::vector& input_ends, std::vector& input_axes, std::vector& input_steps) const { - FillVectorsFromInput(ctx->Input(1), ctx->Input(2), ctx->Input(3), + FillVectorsFromInput(*ctx->Input(1), *ctx->Input(2), ctx->Input(3), ctx->Input(4), input_starts, input_ends, input_axes, input_steps); } diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 799e201c13..a560284dbc 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -690,9 +690,6 @@ common::Status InferenceSession::CreateSubgraphSessionState(Graph& graph, Sessio // Pass fused function manager to subgraph subgraph_session_state->GetMutableFuncMgr().SetFusedFuncs(session_state.GetFuncMgr()); - // Pass fused function manager to subgraph - subgraph_session_state->GetMutableFuncMgr().SetFusedFuncs(session_state.GetFuncMgr()); - // recurse ORT_RETURN_IF_ERROR_SESSIONID_(CreateSubgraphSessionState(*subgraph, *subgraph_session_state)); diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index 55c58364cf..bd495f9657 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -373,8 +373,6 @@ class InferenceSession { // The file path of where the model was loaded. e.g. /tmp/test_squeezenet/model.onnx std::basic_string model_location_; - SessionOptions session_options_; - private: ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(InferenceSession); @@ -430,6 +428,8 @@ class InferenceSession { template void StartProfiling(const std::basic_string& file_prefix); + SessionOptions session_options_; + onnxruntime::GraphTransformerManager graph_transformation_mgr_; // List of transformers to run. When this list is not empty only the transformers in this list diff --git a/orttraining/orttraining/core/session/training_session.cc b/orttraining/orttraining/core/session/training_session.cc index 748dfc3c44..3e27722e5a 100644 --- a/orttraining/orttraining/core/session/training_session.cc +++ b/orttraining/orttraining/core/session/training_session.cc @@ -371,7 +371,7 @@ void TrainingSession::AddPredefinedTransformers(GraphTransformerManager& transfo const std::vector& custom_list) { auto add_transformers = [&](TransformerLevel level) { // Generate and register transformers for level - auto transformers_to_register = transformer_utils::GenerateTransformers(level, session_options_.free_dimension_overrides, custom_list); + auto transformers_to_register = transformer_utils::GenerateTransformers(level, GetSessionOptions().free_dimension_overrides, custom_list); for (auto& entry : transformers_to_register) { transformer_manager.Register(std::move(entry), level); } diff --git a/orttraining/orttraining/training_ops/cpu/tensor/slice_grad.cc b/orttraining/orttraining/training_ops/cpu/tensor/slice_grad.cc index c03009163f..c6cbb517b6 100644 --- a/orttraining/orttraining/training_ops/cpu/tensor/slice_grad.cc +++ b/orttraining/orttraining/training_ops/cpu/tensor/slice_grad.cc @@ -37,7 +37,7 @@ Status SliceGrad::Compute(OpKernelContext* context) const { std::vector input_ends; std::vector input_axes; std::vector input_steps; - FillVectorsFromInput(context->Input(2), context->Input(3), context->Input(4), + FillVectorsFromInput(*context->Input(2), *context->Input(3), context->Input(4), context->Input(5), input_starts, input_ends, input_axes, input_steps); ORT_RETURN_IF_ERROR(PrepareForCompute(input_starts, input_ends, input_axes, input_steps, diff --git a/orttraining/orttraining/training_ops/cuda/tensor/slice_grad.cc b/orttraining/orttraining/training_ops/cuda/tensor/slice_grad.cc index 9ad4e6132e..2d0e39c125 100644 --- a/orttraining/orttraining/training_ops/cuda/tensor/slice_grad.cc +++ b/orttraining/orttraining/training_ops/cuda/tensor/slice_grad.cc @@ -43,7 +43,7 @@ const Tensor* SliceGrad::GetSlicedOrUnslicedTensor(OpKernelContext* ctx) const { void SliceGrad::FillInputVectors(OpKernelContext* ctx, std::vector& input_starts, std::vector& input_ends, std::vector& input_axes, std::vector& input_steps) const { - FillVectorsFromInput(ctx->Input(2), ctx->Input(3), ctx->Input(4), + FillVectorsFromInput(*ctx->Input(2), *ctx->Input(3), ctx->Input(4), ctx->Input(5), input_starts, input_ends, input_axes, input_steps); }