Addressing PR comments (#3334)

* PR comments

* PR comments

* PR comments

* error out bad shape

Co-authored-by: Ethan Tao <ettao@microsoft.com>
This commit is contained in:
ytaous 2020-03-26 18:43:30 -07:00 committed by GitHub
parent 0a6ec0df56
commit 131c65d23d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 35 additions and 34 deletions

View file

@ -188,8 +188,7 @@ static int64_t CalculateMemoryPatternsKey(const std::vector<std::reference_wrapp
namespace {
Status ResolveDimParams(const GraphViewer& graph, const std::map<std::string, TensorShape>& feeds, std::unordered_map<std::string, int64_t>& out) {
for (size_t i = 0; i < graph.GetInputs().size(); ++i) {
auto* input = graph.GetInputs()[i];
for (const auto* input : graph.GetInputs()) {
auto* shape = input->Shape();
auto it = feeds.find(input->Name());
if (it == feeds.end())
@ -198,7 +197,7 @@ Status ResolveDimParams(const GraphViewer& graph, const std::map<std::string, Te
return Status(ONNXRUNTIME, FAIL, "Graph input " + input->Name() +
"'s shape is not present or its shape doesn't match feed's shape."
"Unable to resolve the value for dynamic shape");
for (int k = 0; k < shape->dim_size(); ++k) {
for (int k = 0, end = shape->dim_size(); k < end; ++k) {
if (shape->dim()[k].has_dim_param()) {
out.insert({shape->dim()[k].dim_param(), it->second.GetDims()[k]});
}
@ -212,7 +211,7 @@ Status SessionState::GeneratePatternGroupCache(const std::vector<std::reference_
const std::vector<int>& feed_mlvalue_idxs,
MemoryPatternGroup* output) const {
std::map<std::string, TensorShape> feeds;
for (size_t i = 0; i < feed_mlvalue_idxs.size(); ++i) {
for (size_t i = 0, end = feed_mlvalue_idxs.size(); i < end; ++i) {
std::string name;
ORT_RETURN_IF_ERROR(this->ort_value_name_idx_map_.GetName(feed_mlvalue_idxs[i], name));
feeds.insert({name, input_shape[i]});
@ -228,7 +227,7 @@ Status SessionState::GeneratePatternGroupCache(const std::vector<std::reference_
auto* node = graph_viewer_->GetNode(node_plan.node_index);
int output_start = node_index + static_cast<int>(node->InputDefs().size()) + static_cast<int>(node->ImplicitInputDefs().size());
//allocate output
for (int i = 0; i < static_cast<int>(node->OutputDefs().size()); ++i) {
for (int i = 0, end = static_cast<int>(node->OutputDefs().size()); i < end; ++i) {
const auto ml_value_idx = node_index_info.GetMLValueIndex(output_start + i);
if (ml_value_idx == NodeIndexInfo::kInvalidEntry)
continue;
@ -251,8 +250,10 @@ Status SessionState::GeneratePatternGroupCache(const std::vector<std::reference_
return Status(ONNXRUNTIME, FAIL, "Unknown shape found in memory pattern compute");
}
len *= it->second;
} else {
} else if (dim.has_dim_value()) {
len *= dim.dim_value();
} else {
return Status(ONNXRUNTIME, FAIL, "Unknown shape found in memory pattern compute");
}
}
if (!IAllocator::CalcMemSizeForArrayWithAlignment<64>(len, ml_data_type->Size(), &size)) {

View file

@ -260,21 +260,21 @@ Status SliceBase::PrepareForCompute(const std::vector<int64_t>& raw_starts,
}
// Slice V10 & DynamicSlice
void SliceBase::FillVectorsFromInput(const Tensor* start_tensor,
const Tensor* ends_tensor,
void SliceBase::FillVectorsFromInput(const Tensor& start_tensor,
const Tensor& ends_tensor,
const Tensor* axes_tensor,
const Tensor* steps_tensor,
std::vector<int64_t>& input_starts,
std::vector<int64_t>& input_ends,
std::vector<int64_t>& input_axes,
std::vector<int64_t>& input_steps) const {
ORT_ENFORCE(nullptr != start_tensor && start_tensor->Shape().NumDimensions() == 1, "Starts must be a 1-D array");
ORT_ENFORCE(nullptr != ends_tensor && ends_tensor->Shape().NumDimensions() == 1, "Ends must be a 1-D array");
ORT_ENFORCE(start_tensor->Shape() == ends_tensor->Shape(), "Starts and ends shape mismatch");
ORT_ENFORCE(nullptr == axes_tensor || start_tensor->Shape() == axes_tensor->Shape(), "Starts and axes shape mismatch");
ORT_ENFORCE(nullptr == steps_tensor || start_tensor->Shape() == steps_tensor->Shape(), "Starts and steps shape mismatch");
ORT_ENFORCE(start_tensor.Shape().NumDimensions() == 1, "Starts must be a 1-D array");
ORT_ENFORCE(ends_tensor.Shape().NumDimensions() == 1, "Ends must be a 1-D array");
ORT_ENFORCE(start_tensor.Shape() == ends_tensor.Shape(), "Starts and ends shape mismatch");
ORT_ENFORCE(nullptr == axes_tensor || start_tensor.Shape() == axes_tensor->Shape(), "Starts and axes shape mismatch");
ORT_ENFORCE(nullptr == steps_tensor || start_tensor.Shape() == steps_tensor->Shape(), "Starts and steps shape mismatch");
const auto& size = start_tensor->Shape().Size();
const auto& size = start_tensor.Shape().Size();
input_starts.resize(size);
input_ends.resize(size);
if (nullptr != axes_tensor)
@ -283,9 +283,9 @@ void SliceBase::FillVectorsFromInput(const Tensor* start_tensor,
if (nullptr != steps_tensor)
input_steps.resize(size);
if (start_tensor->IsDataType<int32_t>()) {
std::copy(start_tensor->Data<int32_t>(), start_tensor->Data<int32_t>() + size, input_starts.begin());
std::copy(ends_tensor->Data<int32_t>(), ends_tensor->Data<int32_t>() + size, input_ends.begin());
if (start_tensor.IsDataType<int32_t>()) {
std::copy(start_tensor.Data<int32_t>(), start_tensor.Data<int32_t>() + size, input_starts.begin());
std::copy(ends_tensor.Data<int32_t>(), ends_tensor.Data<int32_t>() + size, input_ends.begin());
if (nullptr != axes_tensor)
std::copy(axes_tensor->Data<int32_t>(), axes_tensor->Data<int32_t>() + size, input_axes.begin());
// Slice V10
@ -293,9 +293,9 @@ void SliceBase::FillVectorsFromInput(const Tensor* start_tensor,
std::copy(steps_tensor->Data<int32_t>(), steps_tensor->Data<int32_t>() + size, input_steps.begin());
}
else if (start_tensor->IsDataType<int64_t>()) {
std::copy(start_tensor->Data<int64_t>(), start_tensor->Data<int64_t>() + size, input_starts.begin());
std::copy(ends_tensor->Data<int64_t>(), ends_tensor->Data<int64_t>() + size, input_ends.begin());
else if (start_tensor.IsDataType<int64_t>()) {
std::copy(start_tensor.Data<int64_t>(), start_tensor.Data<int64_t>() + size, input_starts.begin());
std::copy(ends_tensor.Data<int64_t>(), ends_tensor.Data<int64_t>() + size, input_ends.begin());
if (nullptr != axes_tensor)
std::copy(axes_tensor->Data<int64_t>(), axes_tensor->Data<int64_t>() + size, input_axes.begin());
// Slice V10
@ -305,7 +305,7 @@ void SliceBase::FillVectorsFromInput(const Tensor* start_tensor,
// should not reach this as no kernel is registered for this condition to be triggered - just an additional safety check
else {
ORT_THROW("Data type for starts and ends inputs' need to be int32_t or int64_t, but instead got ", start_tensor->DataType());
ORT_THROW("Data type for starts and ends inputs' need to be int32_t or int64_t, but instead got ", start_tensor.DataType());
}
}
@ -379,7 +379,7 @@ Status Slice<T, dynamic>::Compute(OpKernelContext* ctx) const {
std::vector<int64_t> input_ends;
std::vector<int64_t> input_axes;
std::vector<int64_t> input_steps;
FillVectorsFromInput(ctx->Input<Tensor>(1), ctx->Input<Tensor>(2), ctx->Input<Tensor>(3),
FillVectorsFromInput(*ctx->Input<Tensor>(1), *ctx->Input<Tensor>(2), ctx->Input<Tensor>(3),
ctx->Input<Tensor>(4), input_starts, input_ends, input_axes, input_steps);
ORT_RETURN_IF_ERROR(PrepareForCompute(input_starts, input_ends, input_axes, input_steps,

View file

@ -43,8 +43,8 @@ class SliceBase {
std::vector<int64_t>*& flattened_output_dims) const;
// Slice V10 & DynamicSlice
void FillVectorsFromInput(const Tensor* start_tensor,
const Tensor* ends_tensor,
void FillVectorsFromInput(const Tensor& start_tensor,
const Tensor& ends_tensor,
const Tensor* axes_tensor,
const Tensor* steps_tensor,
std::vector<int64_t>& input_starts,

View file

@ -15,7 +15,10 @@ namespace onnxruntime {
struct CUDAProviderFactory : IExecutionProviderFactory {
// Constructor: initializes device_id_, cuda_mem_limit_ and arena_extend_strategy_ from the
// arguments; the constructor body itself is empty.
// cuda_mem_limit defaults to the maximum size_t value and arena_extend_strategy defaults to
// ArenaExtendStrategy::kNextPowerOfTwo.
// NOTE(review): this span is a rendered diff hunk. The long single line carrying the whole
// initializer list is the pre-change text; the four lines after it are the post-change text,
// which has the same parameters and initializers, reformatted one member per line.
CUDAProviderFactory(OrtDevice::DeviceId device_id,
size_t cuda_mem_limit = std::numeric_limits<size_t>::max(),
ArenaExtendStrategy arena_extend_strategy = ArenaExtendStrategy::kNextPowerOfTwo) : device_id_(device_id), cuda_mem_limit_(cuda_mem_limit), arena_extend_strategy_(arena_extend_strategy) {}
ArenaExtendStrategy arena_extend_strategy = ArenaExtendStrategy::kNextPowerOfTwo)
: device_id_(device_id),
cuda_mem_limit_(cuda_mem_limit),
arena_extend_strategy_(arena_extend_strategy) {}
~CUDAProviderFactory() override {}
std::unique_ptr<IExecutionProvider> CreateProvider() override;

View file

@ -156,7 +156,7 @@ template <bool dynamic>
// Gathers the slice parameters for this kernel: fetches inputs 1-4 from ctx and forwards them
// to FillVectorsFromInput as (starts, ends, axes, steps), which writes the results into
// input_starts, input_ends, input_axes and input_steps.
// NOTE(review): this span is a rendered diff hunk. The two consecutive lines that open the
// FillVectorsFromInput call are the pre-change and post-change forms of the same line; only one
// of them exists in any real revision of the file.
void Slice<dynamic>::FillInputVectors(OpKernelContext* ctx, std::vector<int64_t>& input_starts,
std::vector<int64_t>& input_ends, std::vector<int64_t>& input_axes,
std::vector<int64_t>& input_steps) const {
// Pre-change form: starts and ends are handed over as pointers.
FillVectorsFromInput(ctx->Input<Tensor>(1), ctx->Input<Tensor>(2), ctx->Input<Tensor>(3),
// Post-change form: starts and ends are dereferenced so they bind to `const Tensor&` parameters;
// axes (3) and steps (4) stay as pointers.
// NOTE(review): inputs 1 and 2 are dereferenced with no null check visible here - presumably
// they are always present for this kernel; confirm against the operator schema.
FillVectorsFromInput(*ctx->Input<Tensor>(1), *ctx->Input<Tensor>(2), ctx->Input<Tensor>(3),
ctx->Input<Tensor>(4), input_starts, input_ends, input_axes, input_steps);
}

View file

@ -690,9 +690,6 @@ common::Status InferenceSession::CreateSubgraphSessionState(Graph& graph, Sessio
// Pass fused function manager to subgraph
subgraph_session_state->GetMutableFuncMgr().SetFusedFuncs(session_state.GetFuncMgr());
// Pass fused function manager to subgraph
subgraph_session_state->GetMutableFuncMgr().SetFusedFuncs(session_state.GetFuncMgr());
// recurse
ORT_RETURN_IF_ERROR_SESSIONID_(CreateSubgraphSessionState(*subgraph, *subgraph_session_state));

View file

@ -373,8 +373,6 @@ class InferenceSession {
// The file path of where the model was loaded. e.g. /tmp/test_squeezenet/model.onnx
std::basic_string<ORTCHAR_T> model_location_;
SessionOptions session_options_;
private:
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(InferenceSession);
@ -430,6 +428,8 @@ class InferenceSession {
template <typename T>
void StartProfiling(const std::basic_string<T>& file_prefix);
SessionOptions session_options_;
onnxruntime::GraphTransformerManager graph_transformation_mgr_;
// List of transformers to run. When this list is not empty only the transformers in this list

View file

@ -371,7 +371,7 @@ void TrainingSession::AddPredefinedTransformers(GraphTransformerManager& transfo
const std::vector<std::string>& custom_list) {
auto add_transformers = [&](TransformerLevel level) {
// Generate and register transformers for level
auto transformers_to_register = transformer_utils::GenerateTransformers(level, session_options_.free_dimension_overrides, custom_list);
auto transformers_to_register = transformer_utils::GenerateTransformers(level, GetSessionOptions().free_dimension_overrides, custom_list);
for (auto& entry : transformers_to_register) {
transformer_manager.Register(std::move(entry), level);
}

View file

@ -37,7 +37,7 @@ Status SliceGrad::Compute(OpKernelContext* context) const {
std::vector<int64_t> input_ends;
std::vector<int64_t> input_axes;
std::vector<int64_t> input_steps;
FillVectorsFromInput(context->Input<Tensor>(2), context->Input<Tensor>(3), context->Input<Tensor>(4),
FillVectorsFromInput(*context->Input<Tensor>(2), *context->Input<Tensor>(3), context->Input<Tensor>(4),
context->Input<Tensor>(5), input_starts, input_ends, input_axes, input_steps);
ORT_RETURN_IF_ERROR(PrepareForCompute(input_starts, input_ends, input_axes, input_steps,

View file

@ -43,7 +43,7 @@ const Tensor* SliceGrad::GetSlicedOrUnslicedTensor(OpKernelContext* ctx) const {
// Gathers the slice parameters for the gradient kernel: fetches inputs 2-5 from ctx and forwards
// them to FillVectorsFromInput as (starts, ends, axes, steps), which writes the results into
// input_starts, input_ends, input_axes and input_steps.
// NOTE(review): this span is a rendered diff hunk. The two consecutive lines that open the
// FillVectorsFromInput call are the pre-change and post-change forms of the same line; only one
// of them exists in any real revision of the file.
void SliceGrad::FillInputVectors(OpKernelContext* ctx, std::vector<int64_t>& input_starts,
std::vector<int64_t>& input_ends, std::vector<int64_t>& input_axes,
std::vector<int64_t>& input_steps) const {
// Pre-change form: starts and ends are handed over as pointers.
FillVectorsFromInput(ctx->Input<Tensor>(2), ctx->Input<Tensor>(3), ctx->Input<Tensor>(4),
// Post-change form: starts and ends are dereferenced so they bind to `const Tensor&` parameters;
// axes (4) and steps (5) stay as pointers.
// NOTE(review): inputs 2 and 3 are dereferenced with no null check visible here - presumably
// they are always present for this kernel; confirm against the operator schema.
FillVectorsFromInput(*ctx->Input<Tensor>(2), *ctx->Input<Tensor>(3), ctx->Input<Tensor>(4),
ctx->Input<Tensor>(5), input_starts, input_ends, input_axes, input_steps);
}