onnxruntime/docs/python/inference/notebooks/onnxruntime-tvm-tutorial.ipynb
Valery Chernov 1cdc23aba4
[TVM EP] Rename Standalone TVM (STVM) Execution Provider to TVM EP (#10260)
* update java API for STVM EP. Issue is from PR#10019

* use_stvm -> use_tvm

* rename stvm worktree

* STVMAllocator -> TVMAllocator

* StvmExecutionProviderInfo -> TvmExecutionProviderInfo

* stvm -> tvm for cpu_targets. resolve onnxruntime::tvm and origin tvm namespaces conflict

* STVMRunner -> TVMRunner

* StvmExecutionProvider -> TvmExecutionProvider

* tvm::env_vars

* StvmProviderFactory -> TvmProviderFactory

* rename factory funcs

* StvmCPUDataTransfer -> TvmCPUDataTransfer

* small clean

* STVMFuncState -> TVMFuncState

* USE_TVM -> NUPHAR_USE_TVM

* USE_STVM -> USE_TVM

* python API: providers.stvm -> providers.tvm. clean TVM_EP.md

* clean build scripts #1

* clean build scripts, java frontend and others #2

* once more clean #3

* fix build of nuphar tvm test

* final transfer stvm namespace to onnxruntime::tvm

* rename stvm->tvm

* NUPHAR_USE_TVM -> USE_NUPHAR_TVM

* small fixes for correct CI tests

* clean after rebase. Last renaming stvm to tvm, separate TVM and Nuphar in cmake and build files

* update CUDA support for TVM EP

* roll back CudaNN home check

* ERROR for not positive input shape dimension instead of WARNING

* update documentation for CUDA

* small corrections after review

* update GPU description

* update GPU description

* misprints were fixed

* cleaned up error msgs

Co-authored-by: Valery Chernov <valery.chernov@deelvin.com>
Co-authored-by: KJlaccHoeUM9l <wotpricol@mail.ru>
Co-authored-by: Thierry Moreau <tmoreau@octoml.ai>
2022-02-15 10:21:02 +01:00


{
"cells": [
{
"cell_type": "markdown",
"id": "72476497",
"metadata": {},
"source": [
"# ONNX Runtime: Tutorial for TVM execution provider\n",
"\n",
"This notebook shows a simple example of model inference with the TVM execution provider (TVM EP).\n",
"\n",
"\n",
"#### Tutorial Roadmap:\n",
"1. Prerequisites\n",
"2. Accuracy check for TVM EP\n",
"3. Configuration options"
]
},
{
"cell_type": "markdown",
"id": "9345cbab",
"metadata": {},
"source": [
"## 1. Prerequisites\n",
"\n",
"Make sure that you have installed all the necessary dependencies described in the corresponding section of the documentation.\n",
"\n",
"Also, make sure you have the `tvm` and `onnxruntime-tvm` packages in your pip environment.\n",
"\n",
"If you rely on the `PYTHONPATH` environment variable, make sure it contains the following paths: `<path_to_msft_onnxrt>/onnxruntime/cmake/external/tvm_update/python` and `<path_to_msft_onnxrt>/onnxruntime/build/Linux/Release`."
]
},
{
"cell_type": "markdown",
"id": "da4ca21f",
"metadata": {},
"source": [
"### Common import\n",
"\n",
"These packages can be installed with standard `pip`."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "0f072875",
"metadata": {},
"outputs": [],
"source": [
"import onnx\n",
"import numpy as np\n",
"from typing import List, AnyStr\n",
"from onnx import ModelProto, helper, checker, mapping"
]
},
{
"cell_type": "markdown",
"id": "118670aa",
"metadata": {},
"source": [
"### Specialized import\n",
"\n",
"It is better to build these packages from source so that you know exactly which features are available to you."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "a5502966",
"metadata": {},
"outputs": [],
"source": [
"import tvm.testing\n",
"from tvm.contrib.download import download_testdata\n",
"import onnxruntime.providers.tvm  # necessary to register tvm_onnx_import_and_compile and others"
]
},
{
"cell_type": "markdown",
"id": "b7313183",
"metadata": {},
"source": [
"### Helper functions for working with ONNX ModelProto\n",
"\n",
"This set of helper functions allows you to recognize the meta information of the models. This information is needed for more versatile processing of ONNX models."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "7d0a36e8",
"metadata": {},
"outputs": [],
"source": [
"def get_onnx_input_names(model: ModelProto) -> List[AnyStr]:\n",
"    # graph.input may also list initializers (weights); keep only the real model inputs.\n",
"    inputs = [node.name for node in model.graph.input]\n",
"    initializer = [node.name for node in model.graph.initializer]\n",
"    inputs = list(set(inputs) - set(initializer))\n",
"    return sorted(inputs)\n",
"\n",
"\n",
"def get_onnx_output_names(model: ModelProto) -> List[AnyStr]:\n",
"    return [node.name for node in model.graph.output]\n",
"\n",
"\n",
"def get_onnx_input_types(model: ModelProto) -> List[np.dtype]:\n",
"    # Iterate in name order so the result lines up with get_onnx_input_names.\n",
"    input_names = get_onnx_input_names(model)\n",
"    return [\n",
"        mapping.TENSOR_TYPE_TO_NP_TYPE[node.type.tensor_type.elem_type]\n",
"        for node in sorted(model.graph.input, key=lambda node: node.name) if node.name in input_names\n",
"    ]\n",
"\n",
"\n",
"def get_onnx_input_shapes(model: ModelProto) -> List[List[int]]:\n",
"    # Note: dim_value is 0 for symbolic dimensions, so this helper assumes static shapes.\n",
"    input_names = get_onnx_input_names(model)\n",
"    return [\n",
"        [dv.dim_value for dv in node.type.tensor_type.shape.dim]\n",
"        for node in sorted(model.graph.input, key=lambda node: node.name) if node.name in input_names\n",
"    ]\n",
"\n",
"\n",
"def get_random_model_inputs(model: ModelProto) -> List[np.ndarray]:\n",
"    input_shapes = get_onnx_input_shapes(model)\n",
"    input_types = get_onnx_input_types(model)\n",
"    assert len(input_types) == len(input_shapes)\n",
"    inputs = [np.random.uniform(size=shape).astype(dtype) for shape, dtype in zip(input_shapes, input_types)]\n",
"    return inputs"
]
},
{
"cell_type": "markdown",
"id": "f0de1682",
"metadata": {},
"source": [
"### Wrapper helper functions for Inference\n",
"\n",
"These wrappers run model inference through ONNX Runtime with a chosen execution provider (EP)."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "258ce9e9",
"metadata": {},
"outputs": [],
"source": [
"def get_onnxruntime_output(model: ModelProto, inputs: List, provider_name: AnyStr) -> np.ndarray:\n",
"    output_names = get_onnx_output_names(model)\n",
"    input_names = get_onnx_input_names(model)\n",
"    assert len(input_names) == len(inputs)\n",
"    input_dict = {input_name: input_value for input_name, input_value in zip(input_names, inputs)}\n",
"\n",
"    inference_session = onnxruntime.InferenceSession(model.SerializeToString(), providers=[provider_name])\n",
"    output = inference_session.run(output_names, input_dict)\n",
"\n",
"    # Unpack output if there's only a single value.\n",
"    if len(output) == 1:\n",
"        output = output[0]\n",
"    return output\n",
"\n",
"\n",
"def get_cpu_onnxruntime_output(model: ModelProto, inputs: List) -> np.ndarray:\n",
"    return get_onnxruntime_output(model, inputs, \"CPUExecutionProvider\")\n",
"\n",
"\n",
"def get_tvm_onnxruntime_output(model: ModelProto, inputs: List) -> np.ndarray:\n",
"    return get_onnxruntime_output(model, inputs, \"TvmExecutionProvider\")"
]
},
{
"cell_type": "markdown",
"id": "cc17d3b2",
"metadata": {},
"source": [
"### Helper function for checking accuracy\n",
"\n",
"This function uses the TVM API to compare two output tensors. The tensor obtained using the `CPUExecutionProvider` is used as a reference.\n",
"\n",
"If the tensors differ by more than the given tolerances (`rtol`, `atol`), `tvm.testing.assert_allclose` raises an exception."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "f33a372b",
"metadata": {},
"outputs": [],
"source": [
"def verify_with_ort_with_inputs(\n",
"    model,\n",
"    inputs,\n",
"    out_shape=None,\n",
"    opset=None,\n",
"    freeze_params=False,\n",
"    dtype=\"float32\",\n",
"    rtol=1e-5,\n",
"    atol=1e-5,\n",
"    opt_level=1,\n",
"):\n",
"    if opset is not None:\n",
"        model.opset_import[0].version = opset\n",
"\n",
"    # The CPU EP result is the reference; the TVM EP result is checked against it.\n",
"    ort_out = get_cpu_onnxruntime_output(model, inputs)\n",
"    tvm_out = get_tvm_onnxruntime_output(model, inputs)\n",
"    for tvm_val, ort_val in zip(tvm_out, ort_out):\n",
"        tvm.testing.assert_allclose(ort_val, tvm_val, rtol=rtol, atol=atol)\n",
"        assert ort_val.dtype == tvm_val.dtype"
]
},
{
"cell_type": "markdown",
"id": "8c62b01a",
"metadata": {},
"source": [
"### Helper functions for downloading models\n",
"\n",
"These functions use the TVM API to download models from the ONNX Model Zoo."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "324c00e7",
"metadata": {},
"outputs": [],
"source": [
"BASE_MODEL_URL = \"https://github.com/onnx/models/raw/master/\"\n",
"MODEL_URL_COLLECTION = {\n",
"    \"ResNet50-v1\": \"vision/classification/resnet/model/resnet50-v1-7.onnx\",\n",
"    \"ResNet50-v2\": \"vision/classification/resnet/model/resnet50-v2-7.onnx\",\n",
"    \"SqueezeNet-v1.1\": \"vision/classification/squeezenet/model/squeezenet1.1-7.onnx\",\n",
"    \"SqueezeNet-v1.0\": \"vision/classification/squeezenet/model/squeezenet1.0-7.onnx\",\n",
"    \"Inception-v1\": \"vision/classification/inception_and_googlenet/inception_v1/model/inception-v1-7.onnx\",\n",
"    \"Inception-v2\": \"vision/classification/inception_and_googlenet/inception_v2/model/inception-v2-7.onnx\",\n",
"}\n",
"\n",
"\n",
"def get_model_url(model_name):\n",
"    return BASE_MODEL_URL + MODEL_URL_COLLECTION[model_name]\n",
"\n",
"\n",
"def get_name_from_url(url):\n",
"    return url[url.rfind(\"/\") + 1 :].strip()\n",
"\n",
"\n",
"def find_of_download(model_name):\n",
"    model_url = get_model_url(model_name)\n",
"    model_file_name = get_name_from_url(model_url)\n",
"    return download_testdata(model_url, model_file_name, module=\"models\")"
]
},
{
"cell_type": "markdown",
"id": "90fb7c5c",
"metadata": {},
"source": [
"## 2. Accuracy check for TVM EP\n",
"\n",
"This section checks the accuracy of TVM EP by comparing the output tensors of `CPUExecutionProvider` and `TvmExecutionProvider`. See the description of the `verify_with_ort_with_inputs` function defined above.\n",
"\n",
"\n",
"### Check for simple architectures"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "c739ed5c",
"metadata": {},
"outputs": [],
"source": [
"def get_two_input_model(op_name: AnyStr) -> ModelProto:\n",
"    dtype = \"float32\"\n",
"    in_shape = [1, 2, 3, 3]\n",
"    in_type = mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)]\n",
"    out_shape = in_shape\n",
"    out_type = in_type\n",
"\n",
"    layer = helper.make_node(op_name, [\"in1\", \"in2\"], [\"out\"])\n",
"    graph = helper.make_graph(\n",
"        [layer],\n",
"        \"two_input_test\",\n",
"        inputs=[\n",
"            helper.make_tensor_value_info(\"in1\", in_type, in_shape),\n",
"            helper.make_tensor_value_info(\"in2\", in_type, in_shape),\n",
"        ],\n",
"        outputs=[\n",
"            helper.make_tensor_value_info(\n",
"                \"out\", out_type, out_shape\n",
"            )\n",
"        ],\n",
"    )\n",
"    model = helper.make_model(graph, producer_name=\"two_input_test\")\n",
"    checker.check_model(model, full_check=True)\n",
"    return model"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "7048ee6d",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/onnxruntime/onnxruntime/core/providers/tvm/tvm_execution_provider.cc:480: TVM EP options:\n",
"target: llvm -mcpu=skylake-avx512\n",
"target_host: llvm -mcpu=skylake-avx512\n",
"opt level: 3\n",
"freeze weights: 1\n",
"tuning file path: \n",
"tuning type: AutoTVM\n",
"convert layout to NHWC: 0\n",
"input tensor names: \n",
"input tensor shapes: \n",
"/home/tvm/python/tvm/relay/build_module.py:414: UserWarning: target_host parameter is going to be deprecated. Please pass in tvm.target.Target(target, host=target_host) instead.\n",
" warnings.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"****************** Success! ******************\n"
]
}
],
"source": [
"onnx_model = get_two_input_model(\"Add\")\n",
"inputs = get_random_model_inputs(onnx_model)\n",
"verify_with_ort_with_inputs(onnx_model, inputs)\n",
"print(\"****************** Success! ******************\")"
]
},
{
"cell_type": "markdown",
"id": "52c880f4",
"metadata": {},
"source": [
"### Check for DNN architectures"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "f5d465dc",
"metadata": {},
"outputs": [],
"source": [
"def get_onnx_model(model_name):\n",
"    model_path = find_of_download(model_name)\n",
"    onnx_model = onnx.load(model_path)\n",
"    return onnx_model"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "68daac7e",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/onnxruntime/onnxruntime/core/providers/tvm/tvm_execution_provider.cc:480: TVM EP options:\n",
"target: llvm -mcpu=skylake-avx512\n",
"target_host: llvm -mcpu=skylake-avx512\n",
"opt level: 3\n",
"freeze weights: 1\n",
"tuning file path: \n",
"tuning type: AutoTVM\n",
"convert layout to NHWC: 0\n",
"input tensor names: \n",
"input tensor shapes: \n",
"/home/tvm/python/tvm/driver/build_module.py:235: UserWarning: Specifying name with IRModule input is useless\n",
" warnings.warn(\"Specifying name with IRModule input is useless\")\n",
"One or more operators have not been tuned. Please tune your model for better performance. Use DEBUG logging level to see more details.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"****************** Success! ******************\n"
]
}
],
"source": [
"model_name = \"ResNet50-v1\"\n",
"\n",
"onnx_model = get_onnx_model(model_name)\n",
"inputs = get_random_model_inputs(onnx_model)\n",
"verify_with_ort_with_inputs(onnx_model, inputs)\n",
"print(\"****************** Success! ******************\")"
]
},
{
"cell_type": "markdown",
"id": "e27f64a2",
"metadata": {},
"source": [
"## 3. Configuration options\n",
"\n",
"This section shows how you can configure TVM EP using custom options. For more details on the options used, see the corresponding section of the documentation."
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "a053f59f",
"metadata": {},
"outputs": [],
"source": [
"provider_name = \"TvmExecutionProvider\"\n",
"provider_options = dict(\n",
"    target=\"llvm -mtriple=x86_64-linux-gnu\",\n",
"    target_host=\"llvm -mtriple=x86_64-linux-gnu\",\n",
"    opt_level=3,\n",
"    freeze_weights=True,\n",
"    tuning_file_path=\"\",\n",
"    tuning_type=\"Ansor\",\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "3f6e6f01",
"metadata": {},
"outputs": [],
"source": [
"model_name = \"ResNet50-v1\"\n",
"onnx_model = get_onnx_model(model_name)\n",
"input_dict = dict(zip(get_onnx_input_names(onnx_model), get_random_model_inputs(onnx_model)))\n",
"output_names = get_onnx_output_names(onnx_model)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "85ab83f2",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/onnxruntime/onnxruntime/core/providers/tvm/tvm_execution_provider.cc:480: TVM EP options:\n",
"target: llvm -mtriple=x86_64-linux-gnu\n",
"target_host: llvm -mtriple=x86_64-linux-gnu\n",
"opt level: 3\n",
"freeze weights: 1\n",
"tuning file path: \n",
"tuning type: Ansor\n",
"convert layout to NHWC: 0\n",
"input tensor names: \n",
"input tensor shapes: \n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"****************** Output shape: (1, 1000) ******************\n"
]
}
],
"source": [
"tvm_session = onnxruntime.InferenceSession(\n",
"    onnx_model.SerializeToString(),\n",
"    providers=[provider_name],\n",
"    provider_options=[provider_options],\n",
")\n",
"output = tvm_session.run(output_names, input_dict)[0]\n",
"print(f\"****************** Output shape: {output.shape} ******************\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}