Update Vitis AI EP to support multiple DPU targets through provider options (#6690)

* Update Vitis-AI EP to support multiple DPU targets & specifically the arm64 dpuczdx8g target

* Fix Vitis AI docker and default PyXIR versions

Co-authored-by: Jorn Tuyls <jornt@xilinx.com>
Co-authored-by: Jorn Tuyls <jornt.tuyls@gmail.com>
This commit is contained in:
Jorn Tuyls 2021-06-03 11:53:46 +02:00 committed by GitHub
parent 896f32ec09
commit 3bb780dcd5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 148 additions and 48 deletions

View file

@ -4,16 +4,23 @@
# --------------------------------------------------------------
# Dockerfile to run ONNXRuntime with Vitis-AI integration
FROM xilinx/vitis-ai:latest
FROM xilinx/vitis-ai-cpu:1.3.598
ARG ONNXRUNTIME_REPO=https://github.com/Microsoft/onnxruntime
ARG ONNXRUNTIME_BRANCH=master
ARG PYXIR_REPO=https://github.com/Xilinx/pyxir
ARG PYXIR_BRANCH=master
ARG PYXIR_BRANCH=v0.2.0
ARG PYXIR_FLAG="--use_vai_rt"
RUN apt-get update &&\
apt-get install -y sudo git bash
RUN apt-get update && \
apt-get install -y \
sudo \
git \
bash \
gcc-aarch64-linux-gnu && \
apt-get clean && \
rm -rf /var/lib/apt/lists/*
ENV PATH /code/cmake-3.14.3-Linux-x86_64/bin:$PATH
ENV LD_LIBRARY_PATH /opt/xilinx/xrt/lib:$LD_LIBRARY_PATH
@ -22,7 +29,7 @@ WORKDIR /code
RUN . $VAI_ROOT/conda/etc/profile.d/conda.sh &&\
conda activate vitis-ai-tensorflow &&\
git clone --single-branch --branch ${PYXIR_BRANCH} --recursive ${PYXIR_REPO} pyxir &&\
cd pyxir && python3 setup.py install --use_vai_rt
cd pyxir && python3 setup.py install ${PYXIR_FLAG}
RUN . $VAI_ROOT/conda/etc/profile.d/conda.sh &&\
conda activate vitis-ai-tensorflow &&\
git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXRUNTIME_REPO} onnxruntime &&\

View file

@ -10,7 +10,9 @@ extern "C" {
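/**
* \param options Session options to which the Vitis AI execution provider is appended.
* \param backend_type Name of the Vitis AI DPU target.
* \param device_id Device ID (unused for now).
* \param export_runtime_module If not empty, the path to the file the PyXIR runtime module is exported to.
* \param load_runtime_module If not empty, the path to the file the PyXIR runtime module is loaded from.
*/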
ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_VITISAI, _In_ OrtSessionOptions* options, const char *backend_type, int device_id);
ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_VITISAI, _In_ OrtSessionOptions* options,
const char* backend_type, int device_id, const char* export_runtime_module,
const char* load_runtime_module);
#ifdef __cplusplus
}

View file

@ -54,9 +54,12 @@ static ONNX_NAMESPACE::ModelProto GetModelProtoFromFusedNode(const onnxruntime::
VitisAICustomOp::VitisAICustomOp(const ComputeContext* context,
const onnxruntime::Node* fused_node,
const std::string &backend_type,
const std::string& backend_type,
const std::string& export_runtime_module,
const std::string& load_runtime_module,
const logging::Logger* logger)
: backend_type_(backend_type)
: backend_type_(backend_type), export_runtime_module_(export_runtime_module),
load_runtime_module_(load_runtime_module)
{
SetLogger(logger);
@ -66,26 +69,43 @@ VitisAICustomOp::VitisAICustomOp(const ComputeContext* context,
name_ = context->node_name;
model_proto_ = GetModelProtoFromFusedNode(fused_node, *GetLogger());
std::istringstream model_stream{model_proto_.SerializeAsString()};
xg_ = pyxir::onnx::import_onnx_model(model_stream);
pyxir::partition(xg_, std::vector<std::string>{backend_type_}, "");
auto input_defs = fused_node->InputDefs();
for (auto idef : input_defs) {
in_tensor_names_.push_back(idef->Name());
}
auto output_defs = fused_node->OutputDefs();
for (auto odef : output_defs) {
out_tensor_names_.push_back(odef->Name());
}
pyxir::RunOptionsHolder run_options(new pyxir::runtime::RunOptions());
run_options->on_the_fly_quantization = true;
rt_mod_ = pyxir::build_rt(xg_, backend_type_, in_tensor_names_, out_tensor_names_,
"vai", run_options);
// If the `load_runtime_module` provider option is empty we build a PyXIR
// runtime module from scratch. Otherwise, we load the runtime module from
// the provided file.
if (load_runtime_module_.empty()) {
pyxir::partition(xg_, std::vector<std::string>{backend_type_}, "");
auto input_defs = fused_node->InputDefs();
for (auto idef : input_defs) {
in_tensor_names_.push_back(idef->Name());
}
auto output_defs = fused_node->OutputDefs();
for (auto odef : output_defs) {
out_tensor_names_.push_back(odef->Name());
}
pyxir::RunOptionsHolder run_options(new pyxir::runtime::RunOptions());
run_options->on_the_fly_quantization = true;
run_options->export_runtime_module_path = export_runtime_module_;
rt_mod_ = pyxir::build_rt(xg_, backend_type_, in_tensor_names_,
out_tensor_names_, "vai", run_options);
} else {
std::ifstream in_file(load_runtime_module_);
std::stringstream buffer;
buffer << in_file.rdbuf();
std::string serialized_rt_mod = buffer.str();
in_file.close();
std::istringstream sstream(serialized_rt_mod);
rt_mod_.reset(new pyxir::runtime::RuntimeModule());
rt_mod_->deserialize(sstream);
in_tensor_names_ = rt_mod_->get_in_tensor_names();
out_tensor_names_ = rt_mod_->get_out_tensor_names();
}
}
// Nothing to release by hand: every member cleans up through its own destructor.
VitisAICustomOp::~VitisAICustomOp() = default;

View file

@ -30,7 +30,9 @@ class VitisAICustomOp {
public:
VitisAICustomOp(const ComputeContext* context,
const onnxruntime::Node* fused_node,
const std::string &backend_type,
const std::string& backend_type,
const std::string& export_runtime_module,
const std::string& load_runtime_module,
const logging::Logger* logger);
Status Compute(const OrtApi* api, OrtKernelContext* context) const;
@ -46,27 +48,35 @@ class VitisAICustomOp {
}
private:
// The partition input tensor names
std::vector<std::string> in_tensor_names_;
// The partition output tensor names
std::vector<std::string> out_tensor_names_;
// The PyXIR graph data structure
pyxir::XGraphHolder xg_;
// The Vitis AI DPU target
std::string backend_type_;
// If not empty, the path to the file where the PyXIR runtime module
// should be exported to (used for cross compilation)
std::string export_runtime_module_;
// If not empty, the path to the file where the PyXIR runtime module should
// be loaded from
std::string load_runtime_module_;
// The PyXIR runtime module
pyxir::RtModHolder rt_mod_ = nullptr;
// The EP ComputeContext allocation function
AllocateFunc allocate_func_ = nullptr;
// The EP ComputeContext release function
DestroyFunc release_func_ = nullptr;
// The EP ComputeContext allocator
AllocatorHandle allocator_ = nullptr;
// The EP ComputeContext node name
std::string name_;
// The compute lock
mutable std::mutex compute_lock_;
// The logger
const logging::Logger* logger_ = nullptr;
// The ONNX ModelProto to go from fused node -> ModelProto -> PyXIR
ONNX_NAMESPACE::ModelProto model_proto_;
};
} // namespace vitisai_ep

View file

@ -30,7 +30,9 @@ typedef std::shared_ptr<pyxir::graph::XGraph> XGraphHolder;
typedef std::shared_ptr<pyxir::graph::XLayer> XLayerHolder;
VitisAIExecutionProvider::VitisAIExecutionProvider(const VitisAIExecutionProviderInfo& info)
: IExecutionProvider{onnxruntime::kVitisAIExecutionProvider}, backend_type_(info.backend_type), device_id_(info.device_id) {
: IExecutionProvider{onnxruntime::kVitisAIExecutionProvider}, backend_type_(info.backend_type),
device_id_(info.device_id), export_runtime_module_(info.export_runtime_module),
load_runtime_module_(info.load_runtime_module) {
AllocatorCreationInfo default_memory_info{
[](int) {
return std::make_unique<CPUAllocator>(OrtMemoryInfo(VITISAI, OrtAllocatorType::OrtDeviceAllocator));
@ -276,7 +278,8 @@ common::Status VitisAIExecutionProvider::Compile(const std::vector<onnxruntime::
for (const auto& fused_node : fused_nodes) {
NodeComputeInfo compute_info;
compute_info.create_state_func = [this, fused_node, logger = GetLogger()](ComputeContext* context, FunctionState* state) {
auto* p = new vitisai_ep::VitisAICustomOp(context, fused_node, backend_type_, logger);
auto* p = new vitisai_ep::VitisAICustomOp(context, fused_node, backend_type_, export_runtime_module_,
load_runtime_module_, logger);
*state = p;
return 0;
};

View file

@ -13,6 +13,8 @@ namespace onnxruntime {
// Configuration used to construct a VitisAIExecutionProvider.
struct VitisAIExecutionProviderInfo {
// Device ID (the provider marks it as unused for now)
int device_id{0};
// The Vitis AI DPU target
std::string backend_type;
// If not empty, the path to the file where the PyXIR runtime module
// should be exported to (used for cross compilation)
std::string export_runtime_module;
// If not empty, the path to the file where the PyXIR runtime module
// should be loaded from
std::string load_runtime_module;
};
// Logical device representation.
@ -31,8 +33,16 @@ class VitisAIExecutionProvider : public IExecutionProvider {
std::vector<NodeComputeInfo>& node_compute_funcs) override;
private:
std::string backend_type_;
// The Vitis AI DPU target
std::string backend_type_;
// Device ID (Unused for now)
int device_id_;
// If not empty, the path to the file where the PyXIR runtime module
// should be exported to (used for cross compilation)
std::string export_runtime_module_;
// If not empty, the path to the file where the PyXIR runtime module
// should be loaded from
std::string load_runtime_module_;
};
} // namespace onnxruntime

View file

@ -11,31 +11,50 @@ using namespace onnxruntime;
namespace onnxruntime {
struct VitisAIProviderFactory : IExecutionProviderFactory {
VitisAIProviderFactory(std::string&& backend_type, int device_id)
: backend_type_(std::move(backend_type)), device_id_(device_id) {}
VitisAIProviderFactory(std::string&& backend_type, int device_id, std::string&& export_runtime_module,
std::string&& load_runtime_module)
: backend_type_(std::move(backend_type)), device_id_(device_id),
export_runtime_module_(std::move(export_runtime_module)),
load_runtime_module_(std::move(load_runtime_module)) {}
~VitisAIProviderFactory() = default;
std::unique_ptr<IExecutionProvider> CreateProvider() override;
private:
// The Vitis AI DPU target
const std::string backend_type_;
// Device ID (Unused for now)
int device_id_;
// If not empty, the path to the file where the PyXIR runtime module
// should be exported to (used for cross compilation)
const std::string export_runtime_module_;
// If not empty, the path to the file where the PyXIR runtime module
// should be loaded from
const std::string load_runtime_module_;
};
std::unique_ptr<IExecutionProvider> VitisAIProviderFactory::CreateProvider() {
VitisAIExecutionProviderInfo info;
info.backend_type = backend_type_;
info.device_id = device_id_;
info.export_runtime_module = export_runtime_module_;
info.load_runtime_module = load_runtime_module_;
return std::make_unique<VitisAIExecutionProvider>(info);
}
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_VITISAI(const char *backend_type, int device_id) {
return std::make_shared<onnxruntime::VitisAIProviderFactory>(backend_type, device_id);
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_VITISAI(
const char* backend_type, int device_id, const char* export_runtime_module,
const char* load_runtime_module) {
return std::make_shared<onnxruntime::VitisAIProviderFactory>(
backend_type, device_id, export_runtime_module, load_runtime_module);
}
} // namespace onnxruntime
ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_VITISAI, _In_ OrtSessionOptions* options, _In_ const char* backend_type, int device_id) {
options->provider_factories.push_back(onnxruntime::CreateExecutionProviderFactory_VITISAI(backend_type, device_id));
ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_VITISAI,
_In_ OrtSessionOptions* options, _In_ const char* backend_type, int device_id,
const char* export_runtime_module, const char* load_runtime_module) {
options->provider_factories.push_back(onnxruntime::CreateExecutionProviderFactory_VITISAI(
backend_type, device_id, export_runtime_module, load_runtime_module));
return nullptr;
}

View file

@ -47,7 +47,9 @@ std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Cuda(c
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Dnnl(int use_arena);
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_OpenVINO(const OrtOpenVINOProviderOptions* params);
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Nuphar(bool, const char*);
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_VITISAI(const char* backend_type, int device_id);
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_VITISAI(const char* backend_type, int device_id,
const char* export_runtime_module,
const char* load_runtime_module);
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_ACL(int use_arena);
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_ArmNN(int use_arena);
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_DML(int device_id);
@ -574,7 +576,34 @@ static void RegisterExecutionProviders(InferenceSession* sess, const std::vector
#endif
} else if (type == kVitisAIExecutionProvider) {
#if USE_VITISAI
RegisterExecutionProvider(sess, *onnxruntime::CreateExecutionProviderFactory_VITISAI("dpuv1", 0));
// Retrieve Vitis AI provider options
// `target`: The name of the DPU target (default is DPUCADX8G for backward compatibility).
// `export_runtime_module`: export a Vitis AI PyXIR runtime module to the specified file.
// This can be used for cross compilation or saving state.
// `load_runtime_module`: Load an exported runtime module from disk.
std::string target = "DPUCADX8G";
std::string export_runtime_module = "";
std::string load_runtime_module = "";
auto it = provider_options_map.find(type);
if (it != provider_options_map.end()) {
auto vitis_ai_provider_options = it->second;
auto vai_options_it = vitis_ai_provider_options.find("target");
if (vai_options_it != vitis_ai_provider_options.end()) {
target = vai_options_it->second;
}
vai_options_it = vitis_ai_provider_options.find("export_runtime_module");
if (vai_options_it != vitis_ai_provider_options.end()) {
export_runtime_module = vai_options_it->second;
}
vai_options_it = vitis_ai_provider_options.find("load_runtime_module");
if (vai_options_it != vitis_ai_provider_options.end()) {
load_runtime_module = vai_options_it->second;
}
}
RegisterExecutionProvider(
sess, *onnxruntime::CreateExecutionProviderFactory_VITISAI(target.c_str(), 0,
export_runtime_module.c_str(),
load_runtime_module.c_str()));
#endif
} else if (type == kAclExecutionProvider) {
#ifdef USE_ACL
@ -861,7 +890,7 @@ void addGlobalMethods(py::module& m, Environment& env) {
onnxruntime::CreateExecutionProviderFactory_MIGraphX(0),
#endif
#ifdef USE_VITISAI
onnxruntime::CreateExecutionProviderFactory_VitisAI("DPU", 0),
onnxruntime::CreateExecutionProviderFactory_VITISAI("DPUCADX8G", 0, "", ""),
#endif
#ifdef USE_ACL
onnxruntime::CreateExecutionProviderFactory_ACL(0),