mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-30 03:37:44 +00:00
Merge remote-tracking branch 'upstream/master' into DmlDev
This commit is contained in:
commit
1e1ba6cc4f
56 changed files with 686 additions and 675 deletions
|
|
@ -144,6 +144,7 @@ if(HAS_DEPRECATED_COPY)
|
|||
endif()
|
||||
|
||||
target_include_directories(onnxruntime_providers PRIVATE ${ONNXRUNTIME_ROOT} ${eigen_INCLUDE_DIRS} ${gemmlowp_src} ${RE2_INCLUDE_DIR})
|
||||
|
||||
add_dependencies(onnxruntime_providers onnx ${onnxruntime_EXTERNAL_DEPENDENCIES})
|
||||
|
||||
if (onnxruntime_ENABLE_TRAINING)
|
||||
|
|
@ -154,6 +155,9 @@ if (onnxruntime_ENABLE_TRAINING)
|
|||
if (onnxruntime_USE_HOROVOD)
|
||||
target_include_directories(onnxruntime_providers PRIVATE ${HOROVOD_INCLUDE_DIRS})
|
||||
endif()
|
||||
if (onnxruntime_USE_NCCL OR onnxruntime_USE_HOROVOD)
|
||||
target_include_directories(onnxruntime_providers PUBLIC ${MPI_INCLUDE_DIRS})
|
||||
endif()
|
||||
endif()
|
||||
|
||||
install(DIRECTORY ${PROJECT_SOURCE_DIR}/../include/onnxruntime/core/providers/cpu DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/onnxruntime/core/providers)
|
||||
|
|
|
|||
|
|
@ -31,16 +31,9 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
/// <returns></returns>
|
||||
public static FixedBufferOnnxValue CreateFromTensor<T>(Tensor<T> value)
|
||||
{
|
||||
if (value is Tensor<string>)
|
||||
{
|
||||
throw new ArgumentException("Only numeric tensors can be used to create FixedBufferOnnxValue.", nameof(value));
|
||||
}
|
||||
|
||||
NativeOnnxValueHelper.CreateNativeOnnxValue(value, out IntPtr onnxValue, out MemoryHandle pinnedMemoryHandle, out OnnxValueType onnxValueType, out TensorElementType elementType);
|
||||
|
||||
Debug.Assert(
|
||||
onnxValueType == OnnxValueType.ONNX_TYPE_TENSOR && elementType != TensorElementType.String,
|
||||
"the value should always be a numeric tensor");
|
||||
Debug.Assert(onnxValueType == OnnxValueType.ONNX_TYPE_TENSOR, "the value should always be a tensor");
|
||||
|
||||
return new FixedBufferOnnxValue(pinnedMemoryHandle, onnxValue, onnxValueType, elementType);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -366,6 +366,11 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
int outputIndex = 0;
|
||||
foreach (var output in outputValues)
|
||||
{
|
||||
if (output.ElementType == TensorElementType.String)
|
||||
{
|
||||
throw new NotSupportedException("Using string type FixedBufferOnnxValue in outputs is not supported.");
|
||||
}
|
||||
|
||||
outputValuesArray[outputIndex] = output.Value;
|
||||
|
||||
outputIndex++;
|
||||
|
|
@ -556,6 +561,11 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
int outputIndex = 0;
|
||||
foreach (var output in outputValues)
|
||||
{
|
||||
if (output.ElementType == TensorElementType.String)
|
||||
{
|
||||
throw new NotSupportedException("Using string type FixedBufferOnnxValue in outputs is not supported.");
|
||||
}
|
||||
|
||||
outputValuesArray[outputIndex] = output.Value;
|
||||
|
||||
outputIndex++;
|
||||
|
|
@ -695,7 +705,7 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
IntPtr nameHandle = IntPtr.Zero;
|
||||
string str = null;
|
||||
|
||||
IntPtr status = NativeMethods.OrtSessionEndProfiling(_nativeHandle,
|
||||
IntPtr status = NativeMethods.OrtSessionEndProfiling(_nativeHandle,
|
||||
NativeMemoryAllocator.DefaultInstance.Handle,
|
||||
out nameHandle);
|
||||
|
||||
|
|
@ -708,7 +718,7 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
{
|
||||
if (nameHandle != IntPtr.Zero)
|
||||
{
|
||||
NativeMemoryAllocator.DefaultInstance.FreeMemory(nameHandle);
|
||||
NativeMemoryAllocator.DefaultInstance.FreeMemory(nameHandle);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1056,12 +1056,41 @@ namespace Microsoft.ML.OnnxRuntime.Tests
|
|||
public void TestCreateFixedBufferOnnxValueFromStringTensor()
|
||||
{
|
||||
var tensor = new DenseTensor<string>(new string[] { "a", "b" }, new int[] { 1, 2 });
|
||||
using (var value = FixedBufferOnnxValue.CreateFromTensor(tensor)) { }
|
||||
}
|
||||
|
||||
Assert.Throws<ArgumentException>("value", () =>
|
||||
[Fact]
|
||||
public void TestReusingStringFixedBufferOnnxValue()
|
||||
{
|
||||
string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "test_types_STRING.pb");
|
||||
using (var session = new InferenceSession(modelPath))
|
||||
{
|
||||
// cannot create from string tensor
|
||||
FixedBufferOnnxValue.CreateFromTensor(tensor);
|
||||
});
|
||||
var tensorA = new DenseTensor<string>(new string[] { "a", "b", "c", "d", "e" }, new int[] { 1, 5 });
|
||||
var tensorB = new DenseTensor<string>(new string[] { "v", "w", "x", "y", "z" }, new int[] { 1, 5 });
|
||||
var tensorC = new DenseTensor<string>(new string[] { "i", "j", "k", "l", "m" }, new int[] { 1, 5 });
|
||||
var tensorD = new DenseTensor<string>(new string[] { "i", "j", "k", "l", "m" }, new int[] { 1, 5 });
|
||||
using (FixedBufferOnnxValue a = FixedBufferOnnxValue.CreateFromTensor(tensorA),
|
||||
b = FixedBufferOnnxValue.CreateFromTensor(tensorB),
|
||||
c = FixedBufferOnnxValue.CreateFromTensor(tensorC),
|
||||
d = FixedBufferOnnxValue.CreateFromTensor(tensorD))
|
||||
{
|
||||
// OK to use string type FixedBufferOnnxValue only in input
|
||||
session.Run(new[] { "input" }, new[] { a });
|
||||
|
||||
// Cannot use string type FixedBufferOnnxValue in output
|
||||
Assert.Throws<NotSupportedException>(() =>
|
||||
{
|
||||
// NamedOnnxValue inputs
|
||||
session.Run(new[] { NamedOnnxValue.CreateFromTensor("input", tensorB) }, new[] { "output" }, new[] { b });
|
||||
});
|
||||
Assert.Throws<NotSupportedException>(() =>
|
||||
{
|
||||
// both FixedBufferOnnxValue for inputs and outputs
|
||||
session.Run(new[] { "input" }, new[] { c }, new[] { "output" }, new[] { d });
|
||||
});
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
|
|
|
|||
|
|
@ -254,8 +254,17 @@ ExecutionFrame::ExecutionFrame(const std::vector<int>& feed_mlvalue_idxs, const
|
|||
// it's less efficient (the arena will add some overhead to coalesce individual allocations
|
||||
// back into blocks on 'free'), but better than failing completely.
|
||||
try {
|
||||
buffer = alloc->Alloc(mem_patterns_->patterns[i].PeakSize());
|
||||
|
||||
// static_activation_memory_in_bytes_ is max virtual memory size the planner computes
|
||||
auto peak_size = mem_patterns_->patterns[i].PeakSize();
|
||||
// Planning of one memory type should only happen once.
|
||||
ORT_ENFORCE(
|
||||
static_activation_memory_sizes_in_byte_.find(location.name) ==
|
||||
static_activation_memory_sizes_in_byte_.end(),
|
||||
"Memory type ",
|
||||
location.name,
|
||||
" should only appear once.");
|
||||
static_activation_memory_sizes_in_byte_[location.name] = peak_size;
|
||||
buffer = alloc->Alloc(peak_size);
|
||||
// handle allocator that doesn't throw
|
||||
if (buffer == nullptr) {
|
||||
// INFO level as this may fire on every run and there may not be much a user can do
|
||||
|
|
@ -375,6 +384,8 @@ Status ExecutionFrame::AllocateMLValueTensorSelfOwnBufferHelper(OrtValue& ort_va
|
|||
TraceAllocate(ort_value_index, size);
|
||||
}
|
||||
|
||||
dynamic_activation_memory_sizes_in_byte_[location.name] += size;
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -129,6 +129,18 @@ class ExecutionFrame final : public IExecutionFrame {
|
|||
return planner_ != nullptr;
|
||||
}
|
||||
|
||||
// Return the size of virtual memory allocated in runtime.
|
||||
// The memory is usually used for activations in forward and backward passes.
|
||||
const std::unordered_map<std::string, size_t>& GetDynamicMemorySizeInfo() {
|
||||
return dynamic_activation_memory_sizes_in_byte_;
|
||||
}
|
||||
|
||||
// Return the size of virtual memory allocated before computation.
|
||||
// The memory is usually used for activations in forward and backward passes.
|
||||
const std::unordered_map<std::string, size_t>& GetStaticMemorySizeInfo() {
|
||||
return static_activation_memory_sizes_in_byte_;
|
||||
}
|
||||
|
||||
private:
|
||||
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ExecutionFrame);
|
||||
|
||||
|
|
@ -168,5 +180,13 @@ class ExecutionFrame final : public IExecutionFrame {
|
|||
|
||||
// Big chunks on different locations that will be used by mem_pattern.
|
||||
std::map<OrtMemoryInfo, BufferUniquePtr> buffers_;
|
||||
|
||||
// Size of virtual memory allocated before any kernel execution.
|
||||
// This field is not physical memory size.
|
||||
std::unordered_map<std::string, size_t> static_activation_memory_sizes_in_byte_;
|
||||
// Size of virtual memory allocated during kernel execution (i.e., inside a kernel,
|
||||
// we may allocate some memory for its outputs, if not planned.).
|
||||
// This field is not physical memory size.
|
||||
std::unordered_map<std::string, size_t> dynamic_activation_memory_sizes_in_byte_;
|
||||
};
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -446,6 +446,16 @@ Status SequentialExecutor::Execute(const SessionState& session_state, const std:
|
|||
session_state.Profiler().EndTimeAndRecordEvent(profiling::SESSION_EVENT, "SequentialExecutor::Execute", tp);
|
||||
}
|
||||
|
||||
for (auto i: frame.GetStaticMemorySizeInfo()) {
|
||||
LOGS(logger, INFO) << "[Memory] ExecutionFrame statically allocates "
|
||||
<< i.second << " bytes for " << i.first << std::endl;
|
||||
}
|
||||
|
||||
for (auto i: frame.GetDynamicMemorySizeInfo()) {
|
||||
LOGS(logger, INFO) << "[Memory] ExecutionFrame dynamically allocates "
|
||||
<< i.second << " bytes for " << i.first << std::endl;
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -86,13 +86,13 @@ common::Status SessionStateInitializer::CreatePlan(
|
|||
const auto* exec_plan_ptr = session_state_.GetExecutionPlan();
|
||||
ORT_ENFORCE(exec_plan_ptr, "Execution plan was not found in SessionState. CreatePlan must be called first.");
|
||||
|
||||
std::unique_ptr<ITensorAllocator> tensor_allocator_(ITensorAllocator::Create(
|
||||
std::unique_ptr<ITensorAllocator> tensor_allocator(ITensorAllocator::Create(
|
||||
enable_mem_pattern_, *exec_plan_ptr, execution_providers_, session_state_.GetMutableWeightsBuffers()));
|
||||
|
||||
// lambda to save initialized tensors into SessionState directly
|
||||
const Env& env = Env::Default();
|
||||
ORT_RETURN_IF_ERROR(SaveInitializedTensors(
|
||||
env, graph_loc_, graph_, execution_providers_, ort_value_name_idx_map, tensor_allocator_.get(),
|
||||
env, graph_loc_, graph_, execution_providers_, ort_value_name_idx_map, tensor_allocator.get(),
|
||||
[this](int idx, const OrtValue& value, const OrtCallback& d, bool constant) -> Status {
|
||||
return session_state_.AddInitializedTensor(idx, value, &d, constant);
|
||||
},
|
||||
|
|
@ -191,7 +191,17 @@ common::Status SaveInitializedTensors(const Env& env, const std::basic_string<PA
|
|||
}
|
||||
|
||||
//2. allocate weight buffer on different locations
|
||||
ORT_RETURN_IF_ERROR(planner->FinalizePlan());
|
||||
// planned_initializers_memory_size_in_byte is not actual physical size.
|
||||
// It's the virtual size computed by planner.
|
||||
std::unordered_map<std::string, size_t> planned_initializers_memory_sizes_in_byte;
|
||||
ORT_RETURN_IF_ERROR(
|
||||
planner->FinalizePlan(planned_initializers_memory_sizes_in_byte));
|
||||
|
||||
for (auto i: planned_initializers_memory_sizes_in_byte) {
|
||||
LOGS(logger, INFO) << "[Memory] SessionStateInitializer statically allocates "
|
||||
<< i.second << " bytes for " << i.first << std::endl;
|
||||
}
|
||||
|
||||
OrtCallback deleter;
|
||||
//3. create weight tensors based on weights buffer
|
||||
for (const auto& entry : id_to_initialized_tensor) {
|
||||
|
|
|
|||
|
|
@ -27,7 +27,12 @@ class SimpleTensorAllocator : public ITensorAllocator {
|
|||
: ITensorAllocator(exec_providers),
|
||||
weights_buffers_(weights_buffers),
|
||||
seq_plan_(execution_plan) {}
|
||||
common::Status FinalizePlan() override { return Status::OK(); }
|
||||
common::Status FinalizePlan(std::unordered_map<std::string, size_t>& planned_memory_sizes_in_byte) override {
|
||||
// There is no memory plan to allocate a big block of memory, so
|
||||
// planned memory sizes in different locations are all empty.
|
||||
planned_memory_sizes_in_byte = std::unordered_map<std::string, size_t>();
|
||||
return Status::OK();
|
||||
}
|
||||
common::Status GetPreallocatedBuffer(int ort_value_index, const char* name, std::unique_ptr<MemBuffer>& out) override;
|
||||
common::Status Trace(int id, const ONNX_NAMESPACE::TensorProto* value) override;
|
||||
};
|
||||
|
|
|
|||
|
|
@ -22,7 +22,15 @@ class ITensorAllocator {
|
|||
public:
|
||||
AllocatorPtr GetAllocator(const OrtMemoryInfo& memory_info);
|
||||
|
||||
virtual common::Status FinalizePlan() = 0;
|
||||
/**
|
||||
*
|
||||
* \param planned_memory_size_in_byte The size of memory allocated inside FinalizePlan
|
||||
*
|
||||
* When there is no more tensor to trace, call this function to finalize the
|
||||
* allocation.
|
||||
*/
|
||||
virtual common::Status FinalizePlan(std::unordered_map<std::string, size_t>& planned_memory_sizes_in_byte) = 0;
|
||||
|
||||
/**
|
||||
*
|
||||
* \param ort_value_index The index in planner
|
||||
|
|
|
|||
|
|
@ -21,7 +21,8 @@ class TensorAllocatorWithMemPattern : public ITensorAllocator {
|
|||
bool is_sealed_ = false;
|
||||
const ExecutionPlanBase& seq_plan_;
|
||||
|
||||
common::Status AllocatePlannedBuffers() {
|
||||
common::Status AllocatePlannedBuffersAndReportTotalSize(
|
||||
std::unordered_map<std::string, size_t>& planned_memory_sizes_in_byte) {
|
||||
const size_t location_len = mem_patterns_.locations.size();
|
||||
for (size_t i = 0; i < location_len; ++i) {
|
||||
auto& location = mem_patterns_.locations[i];
|
||||
|
|
@ -30,21 +31,30 @@ class TensorAllocatorWithMemPattern : public ITensorAllocator {
|
|||
return Status(common::ONNXRUNTIME, common::FAIL,
|
||||
"Failed to get allocator for location: " + location.ToString());
|
||||
|
||||
if (mem_patterns_.patterns[i].PeakSize() > 0) {
|
||||
void* buffer;
|
||||
if (alloc->Info().alloc_type == OrtArenaAllocator) {
|
||||
buffer = static_cast<IArenaAllocator*>(alloc.get())->Reserve(mem_patterns_.patterns[i].PeakSize());
|
||||
}
|
||||
else {
|
||||
buffer = alloc->Alloc(mem_patterns_.patterns[i].PeakSize());
|
||||
}
|
||||
weights_buffers_.push_back(BufferUniquePtr(buffer, alloc));
|
||||
auto kvp = buffers_.insert(std::make_pair(location, buffer));
|
||||
if (!kvp.second) {
|
||||
alloc->Free(buffer);
|
||||
return Status(common::ONNXRUNTIME, common::FAIL, "duplicated location");
|
||||
}
|
||||
// Don't allocate memory when there is no memory usage..
|
||||
if (mem_patterns_.patterns[i].PeakSize() <= 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const auto peak_size = mem_patterns_.patterns[i].PeakSize();
|
||||
void* buffer;
|
||||
if (alloc->Info().alloc_type == OrtArenaAllocator) {
|
||||
// Arena has a specific way to store static memory.
|
||||
// Arena does not reuse static memory allocated by Reserve.
|
||||
buffer = static_cast<IArenaAllocator*>(
|
||||
alloc.get())->Reserve(peak_size);
|
||||
}
|
||||
else {
|
||||
buffer = alloc->Alloc(peak_size);
|
||||
}
|
||||
weights_buffers_.push_back(BufferUniquePtr(buffer, alloc));
|
||||
auto kvp = buffers_.insert(std::make_pair(location, buffer));
|
||||
if (!kvp.second) {
|
||||
alloc->Free(buffer);
|
||||
return Status(common::ONNXRUNTIME, common::FAIL, "duplicated location");
|
||||
}
|
||||
|
||||
planned_memory_sizes_in_byte[location.name] += peak_size;
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
@ -57,9 +67,9 @@ class TensorAllocatorWithMemPattern : public ITensorAllocator {
|
|||
weights_buffers_(weights_buffers),
|
||||
seq_plan_(execution_plan) {}
|
||||
|
||||
common::Status FinalizePlan() override {
|
||||
common::Status FinalizePlan(std::unordered_map<std::string, size_t>& planned_memory_sizes_in_byte) override {
|
||||
ORT_RETURN_IF_ERROR(planner_.GeneratePatterns(&mem_patterns_));
|
||||
ORT_RETURN_IF_ERROR(AllocatePlannedBuffers());
|
||||
ORT_RETURN_IF_ERROR(AllocatePlannedBuffersAndReportTotalSize(planned_memory_sizes_in_byte));
|
||||
is_sealed_ = true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ import os
|
|||
import random
|
||||
from pathlib import Path
|
||||
from onnx import ModelProto, TensorProto, numpy_helper
|
||||
from OnnxModel import OnnxModel
|
||||
from onnx_model import OnnxModel
|
||||
|
||||
|
||||
def fake_input_ids_data(input_ids, batch_size, sequence_length, dictionary_size):
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ import csv
|
|||
import timeit
|
||||
from datetime import datetime
|
||||
from onnx import ModelProto, TensorProto, numpy_helper
|
||||
from OnnxModel import OnnxModel
|
||||
from onnx_model import OnnxModel
|
||||
from bert_test_data import get_bert_inputs, generate_test_data, output_test_data
|
||||
from bert_perf_test import create_session, onnxruntime_inference, setup_openmp_environ
|
||||
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@
|
|||
import numpy as np
|
||||
from logging import getLogger
|
||||
from onnx import helper, numpy_helper, TensorProto
|
||||
from OnnxModel import OnnxModel
|
||||
from onnx_model import OnnxModel
|
||||
from fusion_base import Fusion
|
||||
from fusion_utils import FusionUtils
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
# Licensed under the MIT License.
|
||||
#--------------------------------------------------------------------------
|
||||
from logging import getLogger
|
||||
from OnnxModel import OnnxModel
|
||||
from onnx_model import OnnxModel
|
||||
from typing import Union, List
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@
|
|||
|
||||
from logging import getLogger
|
||||
from onnx import helper, numpy_helper
|
||||
from OnnxModel import OnnxModel
|
||||
from onnx_model import OnnxModel
|
||||
from fusion_base import Fusion
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@
|
|||
from typing import Dict
|
||||
from logging import getLogger
|
||||
from onnx import helper
|
||||
from OnnxModel import OnnxModel
|
||||
from onnx_model import OnnxModel
|
||||
from fusion_base import Fusion
|
||||
from fusion_utils import FusionUtils
|
||||
|
||||
|
|
@ -36,9 +36,7 @@ class FusionEmbedLayerNoMask(Fusion):
|
|||
v v
|
||||
SkipLayerNormalization
|
||||
"""
|
||||
def __init__(self,
|
||||
model: OnnxModel,
|
||||
description='no mask'):
|
||||
def __init__(self, model: OnnxModel, description='no mask'):
|
||||
super().__init__(model, "EmbedLayerNormalization", "SkipLayerNormalization", description)
|
||||
self.utils = FusionUtils(model)
|
||||
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@
|
|||
from typing import Dict, Optional
|
||||
from logging import getLogger
|
||||
from onnx import helper
|
||||
from OnnxModel import OnnxModel
|
||||
from onnx_model import OnnxModel
|
||||
from fusion_base import Fusion
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@
|
|||
from typing import Dict, Optional
|
||||
from logging import getLogger
|
||||
from onnx import helper
|
||||
from OnnxModel import OnnxModel
|
||||
from onnx_model import OnnxModel
|
||||
from fusion_base import Fusion
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@
|
|||
|
||||
from logging import getLogger
|
||||
from onnx import helper
|
||||
from OnnxModel import OnnxModel
|
||||
from onnx_model import OnnxModel
|
||||
from fusion_base import Fusion
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@
|
|||
import numpy as np
|
||||
from logging import getLogger
|
||||
from onnx import helper, numpy_helper, TensorProto
|
||||
from OnnxModel import OnnxModel
|
||||
from onnx_model import OnnxModel
|
||||
from fusion_base import Fusion
|
||||
from fusion_utils import FusionUtils
|
||||
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@
|
|||
import numpy as np
|
||||
from logging import getLogger
|
||||
from onnx import helper, numpy_helper, TensorProto
|
||||
from OnnxModel import OnnxModel
|
||||
from onnx_model import OnnxModel
|
||||
from fusion_base import Fusion
|
||||
from fusion_utils import FusionUtils
|
||||
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@
|
|||
from typing import Dict
|
||||
from logging import getLogger
|
||||
from onnx import helper
|
||||
from OnnxModel import OnnxModel
|
||||
from onnx_model import OnnxModel
|
||||
from fusion_base import Fusion
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@
|
|||
|
||||
from logging import getLogger
|
||||
from onnx import helper, numpy_helper, TensorProto
|
||||
from OnnxModel import OnnxModel
|
||||
from onnx_model import OnnxModel
|
||||
from fusion_base import Fusion
|
||||
import numpy as np
|
||||
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@
|
|||
|
||||
from logging import getLogger
|
||||
from onnx import helper, numpy_helper
|
||||
from OnnxModel import OnnxModel
|
||||
from onnx_model import OnnxModel
|
||||
from fusion_base import Fusion
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
# Licensed under the MIT License.
|
||||
#--------------------------------------------------------------------------
|
||||
from logging import getLogger
|
||||
from OnnxModel import OnnxModel
|
||||
from onnx_model import OnnxModel
|
||||
from typing import Tuple
|
||||
from onnx import helper, TensorProto
|
||||
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -21,7 +21,9 @@
|
|||
"source": [
|
||||
"In this tutorial, you'll be introduced to how to load a Bert model from PyTorch, convert it to ONNX, and inference it for high performance using ONNX Runtime and NVIDIA GPU. In the following sections, we are going to use the Bert model trained with Stanford Question Answering Dataset (SQuAD) dataset as an example. Bert SQuAD model is used in question answering scenarios, where the answer to every question is a segment of text, or span, from the corresponding reading passage, or the question might be unanswerable.\n",
|
||||
"\n",
|
||||
"This notebook is for GPU inference. For CPU inference, please look at another notebook [Inference PyTorch Bert Model with ONNX Runtime on CPU](PyTorch_Bert-Squad_OnnxRuntime_CPU.ipynb)."
|
||||
"This notebook is for GPU inference. For CPU inference, please look at another notebook [Inference PyTorch Bert Model with ONNX Runtime on CPU](PyTorch_Bert-Squad_OnnxRuntime_CPU.ipynb).\n",
|
||||
"\n",
|
||||
"Note that you might need change !{sys.executable} to !python when running the notebook in Linux."
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
@ -33,15 +35,16 @@
|
|||
"\n",
|
||||
"#### GPU Environment Setup using AnaConda\n",
|
||||
"\n",
|
||||
"First, we install [AnaConda](https://www.anaconda.com/distribution/) in a target machine and open an AnaConda prompt window when it is done. Then run the following commands to create a conda environment. This notebook is tested with PyTorch 1.4 and OnnxRuntime 1.2.0.\n",
|
||||
"First, we install [AnaConda](https://www.anaconda.com/distribution/) in a target machine and open an AnaConda prompt window when it is done. Then run the following commands to create a conda environment. This notebook was run with PyTorch 1.4 and OnnxRuntime 1.2.0. (We also verified it with PyTorch 1.5 and OnnxRuntime 1.3.0).\n",
|
||||
"\n",
|
||||
"```console\n",
|
||||
"conda create -n gpu_env python=3.6\n",
|
||||
"conda activate gpu_env\n",
|
||||
"conda install pytorch torchvision cudatoolkit=10.1 -c pytorch\n",
|
||||
"pip install onnxruntime-gpu\n",
|
||||
"pip install transformers==2.5.1\n",
|
||||
"pip install wget psutil onnx pytz pandas py-cpuinfo py3nvml netron\n",
|
||||
"pip install transformers==2.11.0\n",
|
||||
"pip install onnxruntime-tools\n",
|
||||
"pip install wget netron\n",
|
||||
"conda install jupyter\n",
|
||||
"jupyter notebook\n",
|
||||
"```\n",
|
||||
|
|
@ -390,11 +393,11 @@
|
|||
"latency = []\n",
|
||||
"for i in range(total_samples):\n",
|
||||
" data = dataset[i]\n",
|
||||
" # Use contiguous array as input might improve performance\n",
|
||||
" # TODO: use IO Binding (see https://github.com/microsoft/onnxruntime/pull/4206) to improve performance.\n",
|
||||
" ort_inputs = {\n",
|
||||
" 'input_ids': numpy.ascontiguousarray(data[0].cpu().reshape(1, max_seq_length).numpy()),\n",
|
||||
" 'input_mask': numpy.ascontiguousarray(data[1].cpu().reshape(1, max_seq_length).numpy()),\n",
|
||||
" 'segment_ids': numpy.ascontiguousarray(data[2].cpu().reshape(1, max_seq_length).numpy())\n",
|
||||
" 'input_ids': data[0].cpu().reshape(1, max_seq_length).numpy(),\n",
|
||||
" 'input_mask': data[1].cpu().reshape(1, max_seq_length).numpy(),\n",
|
||||
" 'segment_ids': data[2].cpu().reshape(1, max_seq_length).numpy()\n",
|
||||
" }\n",
|
||||
" start = time.time()\n",
|
||||
" ort_outputs = session.run(None, ort_inputs)\n",
|
||||
|
|
@ -544,9 +547,7 @@
|
|||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Let's compare the output and see whether the results are close.\n",
|
||||
"\n",
|
||||
"**Note**: Need end-to-end evaluation on performance and accuracy if you use this strategy."
|
||||
"Let's compare the output and see whether the results are close."
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
@ -576,69 +577,35 @@
|
|||
"source": [
|
||||
"## 5. Offline Optimization and Test Tools\n",
|
||||
"\n",
|
||||
"It is recommended to download the [OnnxRuntime Python Tools for BERT](https://github.com/microsoft/onnxruntime/tree/master/onnxruntime/python/tools/transformers), and try them on the exported ONNX models. It could help verify whether the model is fully optimized, and get performance test results.\n",
|
||||
"\n",
|
||||
"### Download OnnxRuntime Python Tools for Bert\n",
|
||||
"You may copy the whole [directory](https://github.com/microsoft/onnxruntime/tree/master/onnxruntime/python/tools/transformers) to a sub-directory named bert_scripts for this notebook. The list of script files might need update if import error happens when you run some script."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"100% [..............................................................................] 15310 / 15310Downloaded bert_perf_test.py\n",
|
||||
"100% [................................................................................] 9571 / 9571Downloaded bert_test_data.py\n",
|
||||
"100% [................................................................................] 7272 / 7272Downloaded compare_bert_results.py\n",
|
||||
"100% [..............................................................................] 44905 / 44905Downloaded BertOnnxModel.py\n",
|
||||
"100% [..............................................................................] 21565 / 21565Downloaded BertOnnxModelKeras.py\n",
|
||||
"100% [..............................................................................] 26114 / 26114Downloaded BertOnnxModelTF.py\n",
|
||||
"100% [..............................................................................] 22773 / 22773Downloaded OnnxModel.py\n",
|
||||
"100% [................................................................................] 7795 / 7795Downloaded optimizer.py\n",
|
||||
"100% [................................................................................] 5885 / 5885Downloaded MachineInfo.py\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"import wget\n",
|
||||
"\n",
|
||||
"url_prfix = \"https://raw.githubusercontent.com/microsoft/onnxruntime/master/onnxruntime/python/tools/transformers/\"\n",
|
||||
"script_files = ['bert_perf_test.py', 'bert_test_data.py', 'compare_bert_results.py', 'BertOnnxModel.py', 'BertOnnxModelKeras.py', 'BertOnnxModelTF.py', 'Gpt2OnnxModel.py', 'OnnxModel.py', 'optimizer.py', 'MachineInfo.py']\n",
|
||||
"\n",
|
||||
"script_dir = './bert_scripts'\n",
|
||||
"if not os.path.exists(script_dir):\n",
|
||||
" os.makedirs(script_dir)\n",
|
||||
"\n",
|
||||
"for filename in script_files:\n",
|
||||
" target_file = os.path.join(script_dir, filename)\n",
|
||||
" if enable_overwrite and os.path.exists(target_file):\n",
|
||||
" os.remove(target_file)\n",
|
||||
" if not os.path.exists(target_file):\n",
|
||||
" wget.download(url_prfix + filename, target_file)\n",
|
||||
" print(\"Downloaded\", filename)"
|
||||
"It is recommended to try [OnnxRuntime Transformer Model Optimization Tool](https://github.com/microsoft/onnxruntime/tree/master/onnxruntime/python/tools/transformers) on the exported ONNX models. It could help verify whether the model can be fully optimized, and get performance test results."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### BERT Optimization Script\n",
|
||||
"#### Transformer Optimizer\n",
|
||||
"\n",
|
||||
"Sometime, some optimization of OnnxRuntime cannot be applied to a Bert model due to different reasons:\n",
|
||||
"* A new subgraph pattern is exported, which is not covered by the onnxruntime version users are using. For example, Gelu from PyTorch 1.4 is not fused by OnnxRuntime 1.1.2 (Note: it is covered in OnnxRuntime v1.2.0).\n",
|
||||
"* The exported model uses dynamic axis. That impacts shape inference. Without enough shape information, some optimization cannot be applied due to the constraint on the input shape.\n",
|
||||
"* Some optimization are not supported by OnnxRuntime, but it is feasible in offline script. Like changing input tensor type from int64 to int32 to avoid extra Cast nodes, or converting model to float16 to achieve better performance in V100 or T4 GPU.\n",
|
||||
"Although OnnxRuntime could optimize Bert model exported by PyTorch. Sometime, model cannot be fully optimized due to different reasons:\n",
|
||||
"* A new subgraph pattern is generated by new version of export tool, and the pattern is not covered by older version of OnnxRuntime. \n",
|
||||
"* The exported model uses dynamic axis and this makes it harder for shape inference of the graph. That blocks some optimization to be applied.\n",
|
||||
"* Some optimization is better to be done offline. Like change input tensor type from int64 to int32 to avoid extra Cast nodes, or convert model to float16 to achieve better performance in V100 or T4 GPU.\n",
|
||||
"\n",
|
||||
"We have python script **optimizer.py**, which is flexible in graph pattern matching and model conversions to tackle these problems.\n",
|
||||
"We have python script **optimizer.py**, which is more flexible in graph pattern matching and model conversion (like float32 to float16). You can also use it to verify whether a Bert model is fully optimized.\n",
|
||||
"\n",
|
||||
"In below example, we can see that the tool provide an extra optimization - SkipLayerNormalization and bias (Add) are not fused in OnnxRuntime due to shape inference.\n",
|
||||
"In this example, we can see that it introduces optimization that is not provided by onnxruntime: SkipLayerNormalization and bias fusion, which is not fused in OnnxRuntime due to shape inference as mentioned.\n",
|
||||
"\n",
|
||||
"The tool will tell whether a model is fully optimized or not. If not, that means you might need change the script to handle some new subgraph patern."
|
||||
"It will also tell whether the model is fully optimized or not. If not, that means you might need change the script to fuse some new pattern of subgraph.\n",
|
||||
"\n",
|
||||
"Example Usage:\n",
|
||||
"```\n",
|
||||
"from onnxruntime_tools import optimizer\n",
|
||||
"optimized_model = optimizer.optimize_model(export_model_path, model_type='bert', num_heads=12, hidden_size=768, use_gpu=True)\n",
|
||||
"optimized_model.convert_model_float32_to_float16()\n",
|
||||
"optimized_model.save_model_to_file(optimized_model_path)\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"You can also use optimizer_cli as the following."
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
@ -677,13 +644,15 @@
|
|||
],
|
||||
"source": [
|
||||
"optimized_fp32_model_path = './onnx/bert-base-cased-squad_opt_{}_fp32.onnx'.format('gpu' if use_gpu else 'cpu')\n",
|
||||
"%run ./bert_scripts/optimizer.py --input $export_model_path --output $optimized_fp32_model_path --input_int32"
|
||||
"\n",
|
||||
"!{sys.executable} -m onnxruntime_tools.optimizer_cli --input $export_model_path --output $optimized_fp32_model_path --input_int32"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Note: you might change \n",
|
||||
"#### Optimized Graph\n",
|
||||
"We can open the optimized model using [Netron](https://github.com/lutzroeder/netron) to visualize.\n",
|
||||
"\n",
|
||||
|
|
@ -702,7 +671,7 @@
|
|||
"import netron\n",
|
||||
"\n",
|
||||
"# change it to True if want to view the optimized model in browser\n",
|
||||
"enable_netron = False\n",
|
||||
"enable_netron = True\n",
|
||||
"if enable_netron:\n",
|
||||
" # If you encounter error \"access a socket in a way forbidden by its access permissions\", install Netron as standalone application instead.\n",
|
||||
" netron.start(optimized_fp32_model_path)"
|
||||
|
|
@ -739,7 +708,7 @@
|
|||
"source": [
|
||||
"GPU_OPTION = '--use_gpu' if use_gpu else ''\n",
|
||||
"\n",
|
||||
"%run ./bert_scripts/bert_perf_test.py --model $optimized_fp32_model_path --batch_size 1 --sequence_length 128 --samples 1000 --test_times 1 --inclusive --all $GPU_OPTION"
|
||||
"!{sys.executable} -m onnxruntime_tools.transformers.bert_perf_test --model $optimized_fp32_model_path --batch_size 1 --sequence_length 128 --samples 1000 --test_times 1 --inclusive --all $GPU_OPTION"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
@ -1108,7 +1077,7 @@
|
|||
}
|
||||
],
|
||||
"source": [
|
||||
"%run ./bert_scripts/compare_bert_results.py --baseline_model $export_model_path --optimized_model $optimized_fp32_model_path --batch_size 1 --sequence_length 128 --samples 100 --rtol 0.01 --atol 0.01 $GPU_OPTION"
|
||||
"!{sys.executable} -m onnxruntime_tools.transformers.compare_bert_results --baseline_model $export_model_path --optimized_model $optimized_fp32_model_path --batch_size 1 --sequence_length 128 --samples 100 --rtol 0.01 --atol 0.01 $GPU_OPTION"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
@ -1148,7 +1117,7 @@
|
|||
],
|
||||
"source": [
|
||||
"optimized_fp16_model_path = './onnx/bert-base-cased-squad_opt_{}_fp16.onnx'.format('gpu' if use_gpu else 'cpu')\n",
|
||||
"%run ./bert_scripts/optimizer.py --input $export_model_path --output $optimized_fp16_model_path --float16 --input_int32"
|
||||
"!{sys.executable} -m onnxruntime_tools.optimizer_cli --input $export_model_path --output $optimized_fp16_model_path --float16 --input_int32"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
@ -1168,7 +1137,7 @@
|
|||
],
|
||||
"source": [
|
||||
"GPU_OPTION = '--use_gpu' if use_gpu else ''\n",
|
||||
"%run ./bert_scripts/bert_perf_test.py --model $optimized_fp16_model_path --batch_size 1 --sequence_length 128 --samples 1000 --test_times 1 --inclusive --all $GPU_OPTION"
|
||||
"!{sys.executable} -m onnxruntime_tools.transformers.bert_perf_test --model $optimized_fp16_model_path --batch_size 1 --sequence_length 128 --samples 1000 --test_times 1 --inclusive --all $GPU_OPTION"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
@ -1534,7 +1503,7 @@
|
|||
],
|
||||
"source": [
|
||||
"GPU_OPTION = '--use_gpu' if use_gpu else ''\n",
|
||||
"%run ./bert_scripts/bert_perf_test.py --model $optimized_fp16_model_path --batch_size 1 2 4 8 16 32 64 --sequence_length 128 --samples 1000 --test_times 1 --inclusive $GPU_OPTION"
|
||||
"!{sys.executable} -m onnxruntime_tools.transformers.bert_perf_test --model $optimized_fp16_model_path --batch_size 1 2 4 8 16 32 64 --sequence_length 128 --samples 1000 --test_times 1 --inclusive $GPU_OPTION"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
@ -1715,6 +1684,8 @@
|
|||
"\n",
|
||||
"Note that running Jupyter Notebook has slight impact on performance result since Jupyter Notebook is using system resources like CPU etc. You can close Jupyter Notebook and other applications, then run the performance test in a console to get more accurate performance numbers.\n",
|
||||
"\n",
|
||||
"We have a [benchmark script](https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/python/tools/transformers/run_benchmark.sh). It is recommended to use it compare inference speed of OnnxRuntime with PyTorch.\n",
|
||||
"\n",
|
||||
"[OnnxRuntime C API](https://github.com/microsoft/onnxruntime/blob/master/docs/C_API.md) could get slightly better performance than python API. If you use C API in inference, you can use OnnxRuntime_Perf_Test.exe built from source to measure performance instead.\n",
|
||||
"\n",
|
||||
"Here is the machine configuration that generated the above results. You might get slower or faster result according to your hardware."
|
||||
|
|
@ -1771,7 +1742,7 @@
|
|||
}
|
||||
],
|
||||
"source": [
|
||||
"%run ./bert_scripts/MachineInfo.py --silent"
|
||||
"!{sys.executable} -m onnxruntime_tools.transformers.machine_info --silent"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
|
|
|||
|
|
@ -113,7 +113,7 @@
|
|||
" os.makedirs(directory)\n",
|
||||
"\n",
|
||||
"# Download scripts for BERT optimization.\n",
|
||||
"url_prfix = \"https://raw.githubusercontent.com/microsoft/onnxruntime/master/onnxruntime/python/tools/transformers/\"\n",
|
||||
"url_prfix = \"https://raw.githubusercontent.com/microsoft/onnxruntime/rel-1.3.0/onnxruntime/python/tools/bert/\"\n",
|
||||
"script_files = ['bert_perf_test.py', 'bert_test_data.py', 'compare_bert_results.py', 'BertOnnxModel.py', 'BertOnnxModelKeras.py', 'BertOnnxModelTF.py', 'Gpt2OnnxModel.py', 'OnnxModel.py', 'optimizer.py']\n",
|
||||
"\n",
|
||||
"for filename in script_files:\n",
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@
|
|||
|
||||
from logging import getLogger
|
||||
from onnx import TensorProto, helper
|
||||
from OnnxModel import OnnxModel
|
||||
from onnx_model import OnnxModel
|
||||
from fusion_reshape import FusionReshape
|
||||
from fusion_layernorm import FusionLayerNormalization, FusionLayerNormalizationTF
|
||||
from fusion_skiplayernorm import FusionSkipLayerNormalization, FusionBiasSkipLayerNormalization
|
||||
|
|
@ -15,6 +15,7 @@ from fusion_gelu import FusionGelu
|
|||
from fusion_fastgelu import FusionFastGelu
|
||||
from fusion_biasgelu import FusionBiasGelu
|
||||
from fusion_gelu_approximation import FusionGeluApproximation
|
||||
from fusion_utils import FusionUtils
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
|
@ -126,9 +127,10 @@ class BertOnnxModel(OnnxModel):
|
|||
new_graph_inputs = []
|
||||
|
||||
bert_inputs = self.get_bert_inputs()
|
||||
utils = FusionUtils(self)
|
||||
for input in graph.input:
|
||||
if input.name in bert_inputs:
|
||||
self.remove_cast_int32(input.name)
|
||||
utils.remove_cast_int32(input.name)
|
||||
input_shape = [
|
||||
batch_size if isinstance(batch_size, int) else 1,
|
||||
sequence_length if isinstance(sequence_length, int) else 128
|
||||
|
|
@ -10,7 +10,7 @@ import argparse
|
|||
import numpy as np
|
||||
from collections import deque
|
||||
from onnx import ModelProto, TensorProto, numpy_helper
|
||||
from BertOnnxModelTF import BertOnnxModelTF
|
||||
from onnx_model_bert_tf import BertOnnxModelTF
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
|
@ -10,7 +10,7 @@ import argparse
|
|||
import numpy as np
|
||||
from collections import deque
|
||||
from onnx import ModelProto, TensorProto, numpy_helper
|
||||
from BertOnnxModel import BertOnnxModel
|
||||
from onnx_model_bert import BertOnnxModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -9,7 +9,7 @@ import argparse
|
|||
import numpy as np
|
||||
from collections import deque
|
||||
from onnx import ModelProto, TensorProto, numpy_helper
|
||||
from BertOnnxModel import BertOnnxModel
|
||||
from onnx_model_bert import BertOnnxModel
|
||||
from fusion_gpt_attention_no_past import FusionGptAttentionNoPast
|
||||
from fusion_gpt_attention import FusionGptAttention
|
||||
|
||||
|
|
@ -27,12 +27,12 @@ import numpy as np
|
|||
from typing import Dict
|
||||
from collections import deque
|
||||
from onnx import ModelProto, TensorProto, numpy_helper, load_model
|
||||
from BertOnnxModel import BertOnnxModel, BertOptimizationOptions
|
||||
from BertOnnxModelTF import BertOnnxModelTF
|
||||
from BertOnnxModelKeras import BertOnnxModelKeras
|
||||
from Gpt2OnnxModel import Gpt2OnnxModel
|
||||
from onnx_model_bert import BertOnnxModel, BertOptimizationOptions
|
||||
from onnx_model_bert_tf import BertOnnxModelTF
|
||||
from onnx_model_bert_keras import BertOnnxModelKeras
|
||||
from onnx_model_gpt2 import Gpt2OnnxModel
|
||||
|
||||
logger = logging.getLogger('')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Map model type to tuple: optimizer class, export tools (pytorch, tf2onnx, keras2onnx) and whether OnnxRuntime has the optimization.
|
||||
MODEL_CLASSES = {
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ from datetime import datetime
|
|||
from pathlib import Path
|
||||
from onnx import ModelProto, TensorProto, numpy_helper
|
||||
import onnxruntime
|
||||
from OnnxModel import OnnxModel
|
||||
from onnx_model import OnnxModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -34,7 +34,6 @@ class BertOnnxModelShapeOptimizer(OnnxModel):
|
|||
This optimizer will replace Shape output or the shape input of Reshape node by initializer. Currently, it requires
|
||||
model inputs to have static shape.
|
||||
"""
|
||||
|
||||
def __init__(self, onnx_model):
|
||||
super().__init__(onnx_model.model)
|
||||
|
||||
|
|
@ -29,7 +29,7 @@ import sys
|
|||
import argparse
|
||||
import numpy as np
|
||||
from onnx import ModelProto, TensorProto, numpy_helper
|
||||
from optimizer import OnnxModel
|
||||
from onnxruntime_tools.transformers.onnx_model import OnnxModel
|
||||
import os
|
||||
import onnxruntime
|
||||
import random
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ import sys
|
|||
import argparse
|
||||
import numpy as np
|
||||
from onnx import ModelProto, TensorProto, numpy_helper
|
||||
from OnnxModel import OnnxModel
|
||||
from onnxruntime_tools.transformers.onnx_model import OnnxModel
|
||||
import os
|
||||
import onnxruntime
|
||||
import random
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ import sys
|
|||
import argparse
|
||||
import numpy as np
|
||||
from onnx import ModelProto, TensorProto, numpy_helper
|
||||
from OnnxModel import OnnxModel
|
||||
from onnxruntime_tools.transformers.onnx_model import OnnxModel
|
||||
import os
|
||||
import onnxruntime
|
||||
import random
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ from onnx.helper import make_node, make_tensor_value_info
|
|||
import numpy as np
|
||||
from onnx import numpy_helper
|
||||
from optimizer import optimize_model, optimize_by_onnxruntime
|
||||
from OnnxModel import OnnxModel
|
||||
from onnx_model import OnnxModel
|
||||
|
||||
BERT_TEST_MODELS = {
|
||||
"bert_pytorch_0": ('bert_squad_pytorch1.4_opset11', 'BertForQuestionAnswering_0.onnx'),
|
||||
|
|
|
|||
|
|
@ -83,7 +83,10 @@ void ComputeBroadcastBackwardAxes(
|
|||
}
|
||||
|
||||
std::vector<Dimension> GetShape(const ArgDef& arg_def) {
|
||||
ORT_ENFORCE(arg_def.type_proto, "During GetShape, ", arg_def.name, "'s type_proto is null.");
|
||||
ORT_ENFORCE(arg_def.type_proto
|
||||
&& arg_def.type_proto->has_tensor_type()
|
||||
&& arg_def.type_proto->tensor_type().has_shape(),
|
||||
"During GetShape, ", arg_def.name, "'s shape is null.");
|
||||
std::vector<Dimension> shape;
|
||||
const auto& dims = arg_def.type_proto->tensor_type().shape().dim();
|
||||
for (auto dim = dims.begin(); dim < dims.end(); dim++) {
|
||||
|
|
|
|||
|
|
@ -937,7 +937,25 @@ Example 4:
|
|||
.TypeConstraint("Tind",
|
||||
{"tensor(int32)", "tensor(int64)"},
|
||||
"Constrain indices to integer types")
|
||||
.SetDoc(R"DOC(SparseSoftmaxCrossEntropy)DOC");
|
||||
.SetDoc(R"DOC(SparseSoftmaxCrossEntropy)DOC")
|
||||
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
|
||||
propagateElemTypeFromInputToOutput(ctx, 0, 0);
|
||||
std::string reduction = getAttribute(ctx, "reduction", "mean");
|
||||
if (reduction.compare("none") == 0) {
|
||||
if (hasInputShape(ctx, 1)) {
|
||||
propagateShapeFromInputToOutput(ctx, 1, 0);
|
||||
}
|
||||
} else {
|
||||
updateOutputShape(ctx, 0, TensorShapeProto());
|
||||
}
|
||||
|
||||
if(ctx.getNumOutputs() == 2) {
|
||||
propagateElemTypeFromInputToOutput(ctx, 0, 1);
|
||||
if (hasInputShape(ctx, 0)) {
|
||||
propagateShapeFromInputToOutput(ctx, 0, 1);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA(SparseSoftmaxCrossEntropyGrad)
|
||||
.SetDomain(kOnnxDomain)
|
||||
|
|
|
|||
|
|
@ -14,7 +14,9 @@ Status InsertMaxPoolOutput::Apply(Graph& graph, Node& node, RewriteRuleEffect& r
|
|||
|
||||
TypeProto t;
|
||||
t.mutable_tensor_type()->set_elem_type(TensorProto_DataType_INT64);
|
||||
t.mutable_tensor_type()->mutable_shape()->CopyFrom(*Y->Shape());
|
||||
if (Y->Shape() != nullptr) {
|
||||
t.mutable_tensor_type()->mutable_shape()->CopyFrom(*Y->Shape());
|
||||
}
|
||||
|
||||
NodeArg& node_arg = graph.GetOrCreateNodeArg(Y->Name() + "_mask", &t);
|
||||
|
||||
|
|
@ -38,7 +40,9 @@ Status InsertSoftmaxCrossEntropyLossOutput::Apply(Graph& graph, Node& node, Rewr
|
|||
|
||||
TypeProto t;
|
||||
t.mutable_tensor_type()->set_elem_type(X->TypeAsProto()->tensor_type().elem_type());
|
||||
t.mutable_tensor_type()->mutable_shape()->CopyFrom(*X->Shape()); // log probability should have the same shape as logits.
|
||||
if (X->Shape() != nullptr) {
|
||||
t.mutable_tensor_type()->mutable_shape()->CopyFrom(*X->Shape()); // log probability should have the same shape as logits.
|
||||
}
|
||||
|
||||
NodeArg& node_arg = graph.GetOrCreateNodeArg(X->Name() + "_log_prob", &t);
|
||||
|
||||
|
|
|
|||
|
|
@ -795,6 +795,7 @@ Status TrainingRunner::TrainingLoop(IDataLoader& training_data_loader, IDataLoad
|
|||
auto end_to_end_start = std::chrono::high_resolution_clock::now();
|
||||
bool end_to_end_measurement_started = false;
|
||||
|
||||
auto all_steps_time_start = std::chrono::high_resolution_clock::now();
|
||||
while (step_ < params_.num_train_steps) {
|
||||
for (size_t shard_it = 0; shard_it < num_shards_to_visit; ++shard_it) {
|
||||
auto training_data = training_data_loader.CurrentDataSet();
|
||||
|
|
@ -921,6 +922,8 @@ Status TrainingRunner::TrainingLoop(IDataLoader& training_data_loader, IDataLoad
|
|||
|
||||
++epoch;
|
||||
}
|
||||
auto all_steps_time_end = std::chrono::high_resolution_clock::now();
|
||||
std::chrono::duration<double> all_steps_duration_seconds = all_steps_time_end - all_steps_time_start;
|
||||
|
||||
const double e2e_throughput = [&]() {
|
||||
if (end_to_end_perf_start_step >= params_.num_train_steps) return 0.0;
|
||||
|
|
@ -959,7 +962,9 @@ Status TrainingRunner::TrainingLoop(IDataLoader& training_data_loader, IDataLoad
|
|||
<< "Average Running Time Per Batch: " << avg_time_per_batch << " ms\n"
|
||||
<< "Throughput: " << throughput << " Examples / Second\n"
|
||||
<< "Stabilized Throughput: " << stabilized_throughput << " Examples / Second\n"
|
||||
<< "EndToEnd Throughput: " << e2e_throughput << " Examples / Second\n";
|
||||
<< "EndToEnd Throughput: " << e2e_throughput << " Examples / Second\n"
|
||||
<< "Average Step Time: " << all_steps_duration_seconds.count() / (step_ - step_start)<< " Second\n"
|
||||
<< "Average Step Throughput: " << params_.batch_size * (step_ - step_start) / (all_steps_duration_seconds.count()) << " Examples / Second\n";
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
@ -1170,7 +1175,6 @@ Status TrainingRunner::Evaluate(InferenceSession& session, IDataLoader& data_loa
|
|||
&fetches));
|
||||
}
|
||||
|
||||
|
||||
// Assume that user-specified fetches are avaliable only on the last pipeline stage.
|
||||
// When there is no pipeline, all pipeline_context_.pipeline_stage_id should be 0 and
|
||||
// params_.pipeline_parallel_size is 1. Thus, the following condition is always true if there
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifdef USE_HOROVOD
|
||||
#if defined(USE_NCCL) || defined(USE_HOROVOD)
|
||||
|
||||
#include "orttraining/training_ops/cuda/communication/recv.h"
|
||||
#include "orttraining/training_ops/cuda/communication/common.h"
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifdef USE_HOROVOD
|
||||
#if defined(USE_NCCL) || defined(USE_HOROVOD)
|
||||
|
||||
#pragma once
|
||||
#include "core/common/common.h"
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifdef USE_HOROVOD
|
||||
#if defined(USE_NCCL) || defined(USE_HOROVOD)
|
||||
|
||||
#include "orttraining/training_ops/cuda/communication/send.h"
|
||||
#include "orttraining/training_ops/cuda/communication/common.h"
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifdef USE_HOROVOD
|
||||
#if defined(USE_NCCL) || defined(USE_HOROVOD)
|
||||
|
||||
#pragma once
|
||||
#include "core/common/common.h"
|
||||
|
|
|
|||
|
|
@ -114,14 +114,17 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1
|
|||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, SliceGrad);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, GatherElementsGrad);
|
||||
|
||||
#ifdef USE_HOROVOD
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, HorovodAllReduce);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, HorovodBarrier);
|
||||
#if defined(USE_NCCL) || defined(USE_HOROVOD)
|
||||
// P2P communication operators.
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Send);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Recv);
|
||||
#endif
|
||||
|
||||
#ifdef USE_HOROVOD
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, HorovodAllReduce);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, HorovodBarrier);
|
||||
#endif
|
||||
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, RecordEvent);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, WaitEvent);
|
||||
|
||||
|
|
@ -240,12 +243,15 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, SliceGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, GatherElementsGrad)>,
|
||||
|
||||
// P2P communication operators.
|
||||
#if defined(USE_NCCL) || defined(USE_HOROVOD)
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Send)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Recv)>,
|
||||
#endif
|
||||
|
||||
#ifdef USE_HOROVOD
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, HorovodAllReduce)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, HorovodBarrier)>,
|
||||
// P2P communication operators.
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Send)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Recv)>,
|
||||
#endif
|
||||
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, RecordEvent)>,
|
||||
|
|
|
|||
|
|
@ -5,17 +5,26 @@
|
|||
import argparse
|
||||
import hashlib
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import urllib.request
|
||||
import zipfile
|
||||
|
||||
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
|
||||
REPO_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..", ".."))
|
||||
|
||||
sys.path.append(os.path.join(REPO_DIR, "tools", "python"))
|
||||
|
||||
import get_azcopy # noqa: E402
|
||||
|
||||
# update these if the E2E test data changes
|
||||
ARCHIVE_BLOB_URL = "https://onnxruntimetestdata.blob.core.windows.net/training/onnxruntime_training_data.zip?snapshot=2020-06-15T23:17:35.8314853Z"
|
||||
ARCHIVE_SHA256_DIGEST = "B01C169B6550D1A0A6F1B4E2F34AE2A8714B52DBB70AC04DA85D371F691BDFF9"
|
||||
|
||||
def _download(url, local_path):
|
||||
urllib.request.urlretrieve(url, local_path)
|
||||
def _download(azcopy_path, url, local_path):
|
||||
subprocess.run([azcopy_path, "cp", "--log-level", "NONE", url, local_path], check=True)
|
||||
|
||||
def _get_sha256_digest(file_path):
|
||||
alg = hashlib.sha256()
|
||||
|
|
@ -36,22 +45,19 @@ def _check_file_sha256_digest(path, expected_digest):
|
|||
raise RuntimeError(
|
||||
"SHA256 digest mismatch, expected: {}, actual: {}".format(expected_digest.lower(), actual_digest.lower()))
|
||||
|
||||
def _extract_archive(archive_path, target_dir):
|
||||
with zipfile.ZipFile(archive_path) as archive:
|
||||
archive.extractall(target_dir)
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Downloads training end-to-end test data.")
|
||||
parser.add_argument("target_dir", help="The test data destination directory.")
|
||||
args = parser.parse_args()
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
with tempfile.TemporaryDirectory() as temp_dir, \
|
||||
get_azcopy.get_azcopy() as azcopy_path:
|
||||
archive_path = os.path.join(temp_dir, "archive.zip")
|
||||
print("Downloading E2E test data from '{}'...".format(ARCHIVE_BLOB_URL))
|
||||
_download(ARCHIVE_BLOB_URL, archive_path)
|
||||
_download(azcopy_path, ARCHIVE_BLOB_URL, archive_path)
|
||||
_check_file_sha256_digest(archive_path, ARCHIVE_SHA256_DIGEST)
|
||||
print("Extracting to '{}'...".format(args.target_dir))
|
||||
_extract_archive(archive_path, args.target_dir)
|
||||
shutil.unpack_archive(archive_path, args.target_dir)
|
||||
print("Done.")
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -14,14 +14,6 @@ jobs:
|
|||
continueOnError: true
|
||||
condition: always()
|
||||
|
||||
- task: CmdLine@2
|
||||
displayName: 'Download azcopy'
|
||||
inputs:
|
||||
script: |
|
||||
curl -so azcopy.tar.gz -L 'https://aka.ms/downloadazcopy-v10-linux'
|
||||
tar -zxvf azcopy.tar.gz --strip 1
|
||||
workingDirectory: $(Build.BinariesDirectory)
|
||||
|
||||
- task: PythonScript@0
|
||||
displayName: 'Download test data'
|
||||
inputs:
|
||||
|
|
|
|||
|
|
@ -1,12 +1,4 @@
|
|||
steps:
|
||||
- task: CmdLine@2
|
||||
displayName: 'Download azcopy'
|
||||
inputs:
|
||||
script: |
|
||||
curl -so azcopy.tar.gz -L 'https://aka.ms/downloadazcopy-v10-mac'
|
||||
tar -zxvf azcopy.tar.gz --strip 1
|
||||
workingDirectory: $(Build.BinariesDirectory)
|
||||
|
||||
- task: PythonScript@0
|
||||
displayName: 'Download test data'
|
||||
inputs:
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
# Assumes AZCopy and Python download is already done
|
||||
# Assumes Python download is already done
|
||||
steps:
|
||||
- task: PythonScript@0
|
||||
displayName: 'Download test data'
|
||||
|
|
|
|||
|
|
@ -10,6 +10,13 @@ from urllib.parse import urlparse
|
|||
from urllib.parse import urljoin
|
||||
from urllib.parse import urlsplit
|
||||
|
||||
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
|
||||
REPO_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..", ".."))
|
||||
sys.path.append(os.path.join(REPO_DIR, "tools", "python"))
|
||||
|
||||
from get_azcopy import get_azcopy # noqa: E402
|
||||
|
||||
|
||||
# Hardcoded map of storage account to azure region endpoint
|
||||
storage_account_to_endpoint_map = {
|
||||
'onnxruntimetestdata.blob.core.windows.net': {
|
||||
|
|
@ -47,7 +54,7 @@ def get_azure_region():
|
|||
|
||||
def parse_arguments():
|
||||
parser = argparse.ArgumentParser(description="ONNXRuntime Data Downloader.")
|
||||
parser.add_argument("--test_data_url", help="Test data URL.")
|
||||
parser.add_argument("--test_data_url", required=True, help="Test data URL.")
|
||||
parser.add_argument("--azure_region", help="Azure region")
|
||||
parser.add_argument("--build_dir", required=True, help="Path to the build directory.")
|
||||
parser.add_argument("--edge_device", action="store_true", help="Edge device with limit disk space.")
|
||||
|
|
@ -80,7 +87,7 @@ def get_region_based_url(url, azure_location):
|
|||
return url
|
||||
|
||||
|
||||
def download_and_unzip(build_dir, url, dest_folder, use_token=True):
|
||||
def download_and_unzip(azcopy_path, build_dir, url, dest_folder, use_token=True):
|
||||
dest_folder = os.path.join(build_dir, dest_folder)
|
||||
# attach the SAS token to the url. Note DO NOT print the url with the token in any logs.
|
||||
token = os.environ.get('Test_Data_Download_Key')
|
||||
|
|
@ -90,14 +97,13 @@ def download_and_unzip(build_dir, url, dest_folder, use_token=True):
|
|||
url_with_token = url
|
||||
|
||||
# Download data using AZCopy tool
|
||||
# Our linux CI build machine has azcopy in /usr/bin but the version is too old
|
||||
azcopy_exe = \
|
||||
'azcopy.exe' if sys.platform.startswith("win") and shutil.which('azcopy') else os.path.join(build_dir, 'azcopy')
|
||||
try:
|
||||
subprocess.run([azcopy_exe, 'cp', '--log-level', 'ERROR', '--recursive', url_with_token, build_dir], check=True)
|
||||
subprocess.run(
|
||||
[azcopy_path, 'cp', '--log-level', 'ERROR', '--recursive', url_with_token, build_dir],
|
||||
check=True)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
print(azcopy_exe)
|
||||
print(azcopy_path)
|
||||
raise Exception("Downloading data failed. Source: " + url + " Destination: " + build_dir)
|
||||
|
||||
os.makedirs(dest_folder, exist_ok=True)
|
||||
|
|
@ -116,29 +122,6 @@ def download_and_unzip(build_dir, url, dest_folder, use_token=True):
|
|||
os.unlink(local_file_name)
|
||||
|
||||
|
||||
def download_additional_data(build_dir, azure_region):
|
||||
additional_data_url = 'https://onnxruntimetestdata.blob.core.windows.net/models/'
|
||||
# url = get_region_based_url(args.test_data_url, azure_region)
|
||||
if not shutil.which('cmake'):
|
||||
cmake_url = urljoin(additional_data_url, 'cmake-3.15.1-win64-x64.zip')
|
||||
print("Starting download for cmake : " + cmake_url)
|
||||
download_and_unzip(build_dir, cmake_url, 'cmake_temp', False)
|
||||
dest_dir = os.path.join(build_dir, 'cmake')
|
||||
if os.path.exists(dest_dir):
|
||||
print('deleting %s' % dest_dir)
|
||||
shutil.rmtree(dest_dir)
|
||||
shutil.move(os.path.join(build_dir, 'cmake_temp', 'cmake-3.15.1-win64-x64'), dest_dir)
|
||||
|
||||
# Download OpenCPPCoverageSetup.exe
|
||||
opencpp_url = urljoin(additional_data_url, 'OpenCppCoverageSetup-x64-0.9.7.0.exe')
|
||||
print("Starting download for opencppcoverage " + opencpp_url)
|
||||
dest_folder = os.path.join(build_dir, 'installer', 'opencppcoverage')
|
||||
os.makedirs(dest_folder, exist_ok=True)
|
||||
azcopy_exe = 'azcopy.exe' if shutil.which('azcopy') else os.path.join(build_dir, 'azcopy')
|
||||
subprocess.run([azcopy_exe, 'cp', '--log-level', 'ERROR', opencpp_url, os.path.join(dest_folder, 'installer.exe')],
|
||||
check=True)
|
||||
|
||||
|
||||
args = parse_arguments()
|
||||
models_folder = 'models'
|
||||
|
||||
|
|
@ -157,9 +140,10 @@ else:
|
|||
azure_region = get_azure_region()
|
||||
try:
|
||||
# Download test data
|
||||
url = get_region_based_url(args.test_data_url, azure_region)
|
||||
print("Starting test data download %s" % url)
|
||||
download_and_unzip(args.build_dir, url, models_folder)
|
||||
with get_azcopy(os.path.join(args.build_dir, "azcopy")) as azcopy_path:
|
||||
url = get_region_based_url(args.test_data_url, azure_region)
|
||||
print("Starting test data download %s" % url)
|
||||
download_and_unzip(azcopy_path, args.build_dir, url, models_folder)
|
||||
|
||||
all_downloads_done = True
|
||||
|
||||
|
|
|
|||
78
tools/python/get_azcopy.py
Normal file
78
tools/python/get_azcopy.py
Normal file
|
|
@ -0,0 +1,78 @@
|
|||
import contextlib
|
||||
import os
|
||||
import platform
|
||||
import re
|
||||
import shutil
|
||||
import stat
|
||||
import subprocess
|
||||
import tempfile
|
||||
import urllib.parse
|
||||
import urllib.request
|
||||
|
||||
AZCOPY_VERSION = "10.4.3"
|
||||
|
||||
# See here for instructions on getting stable download links:
|
||||
# https://docs.microsoft.com/en-us/azure/storage/common/storage-use-azcopy-v10#obtain-a-static-download-link
|
||||
_AZCOPY_DOWNLOAD_URLS = {
|
||||
"Linux": "https://azcopyvnext.azureedge.net/release20200501/azcopy_linux_amd64_10.4.3.tar.gz",
|
||||
"Darwin": "https://azcopyvnext.azureedge.net/release20200501/azcopy_darwin_amd64_10.4.3.zip",
|
||||
"Windows": "https://azcopyvnext.azureedge.net/release20200501/azcopy_windows_amd64_10.4.3.zip",
|
||||
}
|
||||
|
||||
|
||||
def _check_version(azcopy_path):
|
||||
proc = subprocess.run(
|
||||
[azcopy_path, "--version"],
|
||||
stdout=subprocess.PIPE, universal_newlines=True)
|
||||
match = re.search(r"\d+(?:\.\d+)+", proc.stdout)
|
||||
|
||||
if not match:
|
||||
raise RuntimeError("Failed to determine azcopy version.")
|
||||
|
||||
return match.group(0) == AZCOPY_VERSION
|
||||
|
||||
|
||||
def _find_azcopy(start_dir):
|
||||
for root, _, file_names in os.walk(start_dir):
|
||||
for file_name in file_names:
|
||||
if file_name == "azcopy" or file_name == "azcopy.exe":
|
||||
return os.path.join(root, file_name)
|
||||
raise RuntimeError("Failed to azcopy in '{}'.".format(start_dir))
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def get_azcopy(local_azcopy_path="azcopy"):
|
||||
"""
|
||||
Creates a context manager that returns a path to a particular version of
|
||||
azcopy (specified in AZCOPY_VERSION). Downloads a temporary copy if needed.
|
||||
|
||||
:param local_azcopy_path: Path to a local azcopy to try first.
|
||||
|
||||
Example usage:
|
||||
with get_azcopy() as azcopy_path:
|
||||
subprocess.run([azcopy_path, "--version"])
|
||||
"""
|
||||
with contextlib.ExitStack() as context_stack:
|
||||
azcopy_path = shutil.which(local_azcopy_path)
|
||||
|
||||
if azcopy_path is None or not _check_version(azcopy_path):
|
||||
temp_dir = context_stack.enter_context(
|
||||
tempfile.TemporaryDirectory())
|
||||
|
||||
download_url = _AZCOPY_DOWNLOAD_URLS[platform.system()]
|
||||
download_basename = urllib.parse.urlsplit(
|
||||
download_url).path.rsplit("/", 1)[-1]
|
||||
assert len(download_basename) > 0
|
||||
downloaded_path = os.path.join(temp_dir, download_basename)
|
||||
|
||||
print("Downloading azcopy from '{}'...".format(download_url))
|
||||
urllib.request.urlretrieve(download_url, downloaded_path)
|
||||
|
||||
extracted_path = os.path.join(temp_dir, "azcopy")
|
||||
shutil.unpack_archive(downloaded_path, extracted_path)
|
||||
|
||||
azcopy_path = _find_azcopy(extracted_path)
|
||||
|
||||
os.chmod(azcopy_path, stat.S_IXUSR)
|
||||
|
||||
yield azcopy_path
|
||||
Loading…
Reference in a new issue