mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-26 22:35:43 +00:00
Addressing PR comments (#3334)
* PR comments * PR comments * PR comments * error out bad shape Co-authored-by: Ethan Tao <ettao@microsoft.com>
This commit is contained in:
parent
0a6ec0df56
commit
131c65d23d
10 changed files with 35 additions and 34 deletions
|
|
@ -188,8 +188,7 @@ static int64_t CalculateMemoryPatternsKey(const std::vector<std::reference_wrapp
|
|||
|
||||
namespace {
|
||||
Status ResolveDimParams(const GraphViewer& graph, const std::map<std::string, TensorShape>& feeds, std::unordered_map<std::string, int64_t>& out) {
|
||||
for (size_t i = 0; i < graph.GetInputs().size(); ++i) {
|
||||
auto* input = graph.GetInputs()[i];
|
||||
for (const auto* input : graph.GetInputs()) {
|
||||
auto* shape = input->Shape();
|
||||
auto it = feeds.find(input->Name());
|
||||
if (it == feeds.end())
|
||||
|
|
@ -198,7 +197,7 @@ Status ResolveDimParams(const GraphViewer& graph, const std::map<std::string, Te
|
|||
return Status(ONNXRUNTIME, FAIL, "Graph input " + input->Name() +
|
||||
"'s shape is not present or its shape doesn't match feed's shape."
|
||||
"Unable to resolve the value for dynamic shape");
|
||||
for (int k = 0; k < shape->dim_size(); ++k) {
|
||||
for (int k = 0, end = shape->dim_size(); k < end; ++k) {
|
||||
if (shape->dim()[k].has_dim_param()) {
|
||||
out.insert({shape->dim()[k].dim_param(), it->second.GetDims()[k]});
|
||||
}
|
||||
|
|
@ -212,7 +211,7 @@ Status SessionState::GeneratePatternGroupCache(const std::vector<std::reference_
|
|||
const std::vector<int>& feed_mlvalue_idxs,
|
||||
MemoryPatternGroup* output) const {
|
||||
std::map<std::string, TensorShape> feeds;
|
||||
for (size_t i = 0; i < feed_mlvalue_idxs.size(); ++i) {
|
||||
for (size_t i = 0, end = feed_mlvalue_idxs.size(); i < end; ++i) {
|
||||
std::string name;
|
||||
ORT_RETURN_IF_ERROR(this->ort_value_name_idx_map_.GetName(feed_mlvalue_idxs[i], name));
|
||||
feeds.insert({name, input_shape[i]});
|
||||
|
|
@ -228,7 +227,7 @@ Status SessionState::GeneratePatternGroupCache(const std::vector<std::reference_
|
|||
auto* node = graph_viewer_->GetNode(node_plan.node_index);
|
||||
int output_start = node_index + static_cast<int>(node->InputDefs().size()) + static_cast<int>(node->ImplicitInputDefs().size());
|
||||
//allocate output
|
||||
for (int i = 0; i < static_cast<int>(node->OutputDefs().size()); ++i) {
|
||||
for (int i = 0, end = static_cast<int>(node->OutputDefs().size()); i < end; ++i) {
|
||||
const auto ml_value_idx = node_index_info.GetMLValueIndex(output_start + i);
|
||||
if (ml_value_idx == NodeIndexInfo::kInvalidEntry)
|
||||
continue;
|
||||
|
|
@ -251,8 +250,10 @@ Status SessionState::GeneratePatternGroupCache(const std::vector<std::reference_
|
|||
return Status(ONNXRUNTIME, FAIL, "Unknown shape found in memory pattern compute");
|
||||
}
|
||||
len *= it->second;
|
||||
} else {
|
||||
} else if (dim.has_dim_value()) {
|
||||
len *= dim.dim_value();
|
||||
} else {
|
||||
return Status(ONNXRUNTIME, FAIL, "Unknown shape found in memory pattern compute");
|
||||
}
|
||||
}
|
||||
if (!IAllocator::CalcMemSizeForArrayWithAlignment<64>(len, ml_data_type->Size(), &size)) {
|
||||
|
|
|
|||
|
|
@ -260,21 +260,21 @@ Status SliceBase::PrepareForCompute(const std::vector<int64_t>& raw_starts,
|
|||
}
|
||||
|
||||
// Slice V10 & DynamicSlice
|
||||
void SliceBase::FillVectorsFromInput(const Tensor* start_tensor,
|
||||
const Tensor* ends_tensor,
|
||||
void SliceBase::FillVectorsFromInput(const Tensor& start_tensor,
|
||||
const Tensor& ends_tensor,
|
||||
const Tensor* axes_tensor,
|
||||
const Tensor* steps_tensor,
|
||||
std::vector<int64_t>& input_starts,
|
||||
std::vector<int64_t>& input_ends,
|
||||
std::vector<int64_t>& input_axes,
|
||||
std::vector<int64_t>& input_steps) const {
|
||||
ORT_ENFORCE(nullptr != start_tensor && start_tensor->Shape().NumDimensions() == 1, "Starts must be a 1-D array");
|
||||
ORT_ENFORCE(nullptr != ends_tensor && ends_tensor->Shape().NumDimensions() == 1, "Ends must be a 1-D array");
|
||||
ORT_ENFORCE(start_tensor->Shape() == ends_tensor->Shape(), "Starts and ends shape mismatch");
|
||||
ORT_ENFORCE(nullptr == axes_tensor || start_tensor->Shape() == axes_tensor->Shape(), "Starts and axes shape mismatch");
|
||||
ORT_ENFORCE(nullptr == steps_tensor || start_tensor->Shape() == steps_tensor->Shape(), "Starts and steps shape mismatch");
|
||||
ORT_ENFORCE(start_tensor.Shape().NumDimensions() == 1, "Starts must be a 1-D array");
|
||||
ORT_ENFORCE(ends_tensor.Shape().NumDimensions() == 1, "Ends must be a 1-D array");
|
||||
ORT_ENFORCE(start_tensor.Shape() == ends_tensor.Shape(), "Starts and ends shape mismatch");
|
||||
ORT_ENFORCE(nullptr == axes_tensor || start_tensor.Shape() == axes_tensor->Shape(), "Starts and axes shape mismatch");
|
||||
ORT_ENFORCE(nullptr == steps_tensor || start_tensor.Shape() == steps_tensor->Shape(), "Starts and steps shape mismatch");
|
||||
|
||||
const auto& size = start_tensor->Shape().Size();
|
||||
const auto& size = start_tensor.Shape().Size();
|
||||
input_starts.resize(size);
|
||||
input_ends.resize(size);
|
||||
if (nullptr != axes_tensor)
|
||||
|
|
@ -283,9 +283,9 @@ void SliceBase::FillVectorsFromInput(const Tensor* start_tensor,
|
|||
if (nullptr != steps_tensor)
|
||||
input_steps.resize(size);
|
||||
|
||||
if (start_tensor->IsDataType<int32_t>()) {
|
||||
std::copy(start_tensor->Data<int32_t>(), start_tensor->Data<int32_t>() + size, input_starts.begin());
|
||||
std::copy(ends_tensor->Data<int32_t>(), ends_tensor->Data<int32_t>() + size, input_ends.begin());
|
||||
if (start_tensor.IsDataType<int32_t>()) {
|
||||
std::copy(start_tensor.Data<int32_t>(), start_tensor.Data<int32_t>() + size, input_starts.begin());
|
||||
std::copy(ends_tensor.Data<int32_t>(), ends_tensor.Data<int32_t>() + size, input_ends.begin());
|
||||
if (nullptr != axes_tensor)
|
||||
std::copy(axes_tensor->Data<int32_t>(), axes_tensor->Data<int32_t>() + size, input_axes.begin());
|
||||
// Slice V10
|
||||
|
|
@ -293,9 +293,9 @@ void SliceBase::FillVectorsFromInput(const Tensor* start_tensor,
|
|||
std::copy(steps_tensor->Data<int32_t>(), steps_tensor->Data<int32_t>() + size, input_steps.begin());
|
||||
}
|
||||
|
||||
else if (start_tensor->IsDataType<int64_t>()) {
|
||||
std::copy(start_tensor->Data<int64_t>(), start_tensor->Data<int64_t>() + size, input_starts.begin());
|
||||
std::copy(ends_tensor->Data<int64_t>(), ends_tensor->Data<int64_t>() + size, input_ends.begin());
|
||||
else if (start_tensor.IsDataType<int64_t>()) {
|
||||
std::copy(start_tensor.Data<int64_t>(), start_tensor.Data<int64_t>() + size, input_starts.begin());
|
||||
std::copy(ends_tensor.Data<int64_t>(), ends_tensor.Data<int64_t>() + size, input_ends.begin());
|
||||
if (nullptr != axes_tensor)
|
||||
std::copy(axes_tensor->Data<int64_t>(), axes_tensor->Data<int64_t>() + size, input_axes.begin());
|
||||
// Slice V10
|
||||
|
|
@ -305,7 +305,7 @@ void SliceBase::FillVectorsFromInput(const Tensor* start_tensor,
|
|||
|
||||
// should not reach this as no kernel is registered for this condition to be triggered - just an additional safety check
|
||||
else {
|
||||
ORT_THROW("Data type for starts and ends inputs' need to be int32_t or int64_t, but instead got ", start_tensor->DataType());
|
||||
ORT_THROW("Data type for starts and ends inputs' need to be int32_t or int64_t, but instead got ", start_tensor.DataType());
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -379,7 +379,7 @@ Status Slice<T, dynamic>::Compute(OpKernelContext* ctx) const {
|
|||
std::vector<int64_t> input_ends;
|
||||
std::vector<int64_t> input_axes;
|
||||
std::vector<int64_t> input_steps;
|
||||
FillVectorsFromInput(ctx->Input<Tensor>(1), ctx->Input<Tensor>(2), ctx->Input<Tensor>(3),
|
||||
FillVectorsFromInput(*ctx->Input<Tensor>(1), *ctx->Input<Tensor>(2), ctx->Input<Tensor>(3),
|
||||
ctx->Input<Tensor>(4), input_starts, input_ends, input_axes, input_steps);
|
||||
|
||||
ORT_RETURN_IF_ERROR(PrepareForCompute(input_starts, input_ends, input_axes, input_steps,
|
||||
|
|
|
|||
|
|
@ -43,8 +43,8 @@ class SliceBase {
|
|||
std::vector<int64_t>*& flattened_output_dims) const;
|
||||
|
||||
// Slice V10 & DynamicSlice
|
||||
void FillVectorsFromInput(const Tensor* start_tensor,
|
||||
const Tensor* ends_tensor,
|
||||
void FillVectorsFromInput(const Tensor& start_tensor,
|
||||
const Tensor& ends_tensor,
|
||||
const Tensor* axes_tensor,
|
||||
const Tensor* steps_tensor,
|
||||
std::vector<int64_t>& input_starts,
|
||||
|
|
|
|||
|
|
@ -15,7 +15,10 @@ namespace onnxruntime {
|
|||
struct CUDAProviderFactory : IExecutionProviderFactory {
|
||||
CUDAProviderFactory(OrtDevice::DeviceId device_id,
|
||||
size_t cuda_mem_limit = std::numeric_limits<size_t>::max(),
|
||||
ArenaExtendStrategy arena_extend_strategy = ArenaExtendStrategy::kNextPowerOfTwo) : device_id_(device_id), cuda_mem_limit_(cuda_mem_limit), arena_extend_strategy_(arena_extend_strategy) {}
|
||||
ArenaExtendStrategy arena_extend_strategy = ArenaExtendStrategy::kNextPowerOfTwo)
|
||||
: device_id_(device_id),
|
||||
cuda_mem_limit_(cuda_mem_limit),
|
||||
arena_extend_strategy_(arena_extend_strategy) {}
|
||||
~CUDAProviderFactory() override {}
|
||||
|
||||
std::unique_ptr<IExecutionProvider> CreateProvider() override;
|
||||
|
|
|
|||
|
|
@ -156,7 +156,7 @@ template <bool dynamic>
|
|||
void Slice<dynamic>::FillInputVectors(OpKernelContext* ctx, std::vector<int64_t>& input_starts,
|
||||
std::vector<int64_t>& input_ends, std::vector<int64_t>& input_axes,
|
||||
std::vector<int64_t>& input_steps) const {
|
||||
FillVectorsFromInput(ctx->Input<Tensor>(1), ctx->Input<Tensor>(2), ctx->Input<Tensor>(3),
|
||||
FillVectorsFromInput(*ctx->Input<Tensor>(1), *ctx->Input<Tensor>(2), ctx->Input<Tensor>(3),
|
||||
ctx->Input<Tensor>(4), input_starts, input_ends, input_axes, input_steps);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -690,9 +690,6 @@ common::Status InferenceSession::CreateSubgraphSessionState(Graph& graph, Sessio
|
|||
// Pass fused function manager to subgraph
|
||||
subgraph_session_state->GetMutableFuncMgr().SetFusedFuncs(session_state.GetFuncMgr());
|
||||
|
||||
// Pass fused function manager to subgraph
|
||||
subgraph_session_state->GetMutableFuncMgr().SetFusedFuncs(session_state.GetFuncMgr());
|
||||
|
||||
// recurse
|
||||
ORT_RETURN_IF_ERROR_SESSIONID_(CreateSubgraphSessionState(*subgraph, *subgraph_session_state));
|
||||
|
||||
|
|
|
|||
|
|
@ -373,8 +373,6 @@ class InferenceSession {
|
|||
// The file path of where the model was loaded. e.g. /tmp/test_squeezenet/model.onnx
|
||||
std::basic_string<ORTCHAR_T> model_location_;
|
||||
|
||||
SessionOptions session_options_;
|
||||
|
||||
private:
|
||||
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(InferenceSession);
|
||||
|
||||
|
|
@ -430,6 +428,8 @@ class InferenceSession {
|
|||
template <typename T>
|
||||
void StartProfiling(const std::basic_string<T>& file_prefix);
|
||||
|
||||
SessionOptions session_options_;
|
||||
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr_;
|
||||
|
||||
// List of transformers to run. When this list is not empty only the transformers in this list
|
||||
|
|
|
|||
|
|
@ -371,7 +371,7 @@ void TrainingSession::AddPredefinedTransformers(GraphTransformerManager& transfo
|
|||
const std::vector<std::string>& custom_list) {
|
||||
auto add_transformers = [&](TransformerLevel level) {
|
||||
// Generate and register transformers for level
|
||||
auto transformers_to_register = transformer_utils::GenerateTransformers(level, session_options_.free_dimension_overrides, custom_list);
|
||||
auto transformers_to_register = transformer_utils::GenerateTransformers(level, GetSessionOptions().free_dimension_overrides, custom_list);
|
||||
for (auto& entry : transformers_to_register) {
|
||||
transformer_manager.Register(std::move(entry), level);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -37,7 +37,7 @@ Status SliceGrad::Compute(OpKernelContext* context) const {
|
|||
std::vector<int64_t> input_ends;
|
||||
std::vector<int64_t> input_axes;
|
||||
std::vector<int64_t> input_steps;
|
||||
FillVectorsFromInput(context->Input<Tensor>(2), context->Input<Tensor>(3), context->Input<Tensor>(4),
|
||||
FillVectorsFromInput(*context->Input<Tensor>(2), *context->Input<Tensor>(3), context->Input<Tensor>(4),
|
||||
context->Input<Tensor>(5), input_starts, input_ends, input_axes, input_steps);
|
||||
|
||||
ORT_RETURN_IF_ERROR(PrepareForCompute(input_starts, input_ends, input_axes, input_steps,
|
||||
|
|
|
|||
|
|
@ -43,7 +43,7 @@ const Tensor* SliceGrad::GetSlicedOrUnslicedTensor(OpKernelContext* ctx) const {
|
|||
void SliceGrad::FillInputVectors(OpKernelContext* ctx, std::vector<int64_t>& input_starts,
|
||||
std::vector<int64_t>& input_ends, std::vector<int64_t>& input_axes,
|
||||
std::vector<int64_t>& input_steps) const {
|
||||
FillVectorsFromInput(ctx->Input<Tensor>(2), ctx->Input<Tensor>(3), ctx->Input<Tensor>(4),
|
||||
FillVectorsFromInput(*ctx->Input<Tensor>(2), *ctx->Input<Tensor>(3), ctx->Input<Tensor>(4),
|
||||
ctx->Input<Tensor>(5), input_starts, input_ends, input_axes, input_steps);
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue