mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-17 21:10:43 +00:00
Description: Format all python files under onnxruntime with black and isort. After checking in, we can use .git-blame-ignore-revs to ignore the formatting PR in git blame. #11315, #11316
212 lines
8.1 KiB
Python
212 lines
8.1 KiB
Python
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
# Licensed under the MIT License.
|
|
|
|
import base64
|
|
import datetime
|
|
import errno
|
|
import json
|
|
import math
|
|
import os
|
|
import socket
|
|
import struct
|
|
import subprocess
|
|
import sys
|
|
import time
|
|
import urllib.request
|
|
|
|
import numpy
|
|
import onnx_ml_pb2
|
|
import predict_pb2
|
|
import requests
|
|
|
|
|
|
def test_log(str):
|
|
print("[Test Log][{0}] {1}".format(datetime.datetime.now(), str))
|
|
|
|
|
|
def is_process_killed(pid):
    """Return True if the server process is no longer running.

    On Windows the check is by image name only: it returns True when ``TASKLIST``
    lists no ``onnxruntime_host.exe`` process (``pid`` is not consulted there).
    On other platforms ``pid`` is probed with signal 0.

    Args:
        pid: Process id of the server process (used on non-Windows platforms).

    Returns:
        True if the process is gone, False if it still exists.
    """
    if sys.platform.startswith("win"):
        process_name = "onnxruntime_host.exe"
        call = "TASKLIST", "/FI", "imagename eq {0}".format(process_name)
        output = subprocess.check_output(call).decode("utf-8")
        print(output)
        last_line = output.strip().split("\r\n")[-1]
        return not last_line.lower().startswith(process_name)

    try:
        # Signal 0 delivers nothing; it only checks that the pid exists and may be signalled.
        os.kill(pid, 0)
    except OSError as err:
        # ESRCH: no such process, so it has been killed (and reaped).
        # EPERM (or anything else): the pid still exists, we just cannot signal it.
        return err.errno == errno.ESRCH
    # The probe succeeded: the pid is still present (possibly as an unreaped zombie).
    return False
|
|
|
|
|
|
def prepare_mnist_model(target_path):
    """Ensure the MNIST test model file exists at ``target_path``.

    An existing file is used as-is. Otherwise the model is downloaded from the
    test-data blob store as a fallback.
    """
    if os.path.isfile(target_path):
        test_log("Found mnist model at {0}".format(target_path))
        return

    # Backup path: in case the mnist model is missing, we need to download it from Internet.
    model_url = "https://ortsrvdev.blob.core.windows.net/test-data/model.onnx"
    test_log(
        "Downloading model from blob storage: https://ortsrvdev.blob.core.windows.net/test-data/model.onnx to {0}".format(
            target_path
        )
    )
    urllib.request.urlretrieve(model_url, target_path)
|
|
|
|
|
|
def decode_base64_string(s, count_and_type):
    """Base64-decode ``s`` and unpack the bytes with the ``struct`` format ``count_and_type``.

    Returns the tuple produced by ``struct.unpack`` (e.g. format ``"3f"`` yields three floats).
    """
    return struct.unpack(count_and_type, base64.b64decode(s))
|
|
|
|
|
|
def compare_floats(a, b, rel_tol=0.0001, abs_tol=0.0001):
    """Return whether ``a`` and ``b`` are close according to ``math.isclose``.

    A mismatch is logged via ``test_log`` (with both tolerances and values) before
    returning False.
    """
    if math.isclose(a, b, rel_tol=rel_tol, abs_tol=abs_tol):
        return True

    test_log(
        "Not match with relative tolerance {0} and absolute tolerance {1}: {2} and {3}".format(rel_tol, abs_tol, a, b)
    )
    return False
|
|
|
|
|
|
def wait_service_up(server, port, timeout=1, retry_interval=0.1):
    """Poll until a TCP connection to ``(server, port)`` succeeds.

    Args:
        server: Host name or IP address to connect to.
        port: TCP port number.
        timeout: Overall deadline in seconds. A falsy value (0 or None) waits
            indefinitely.
        retry_interval: Pause in seconds between failed connection attempts,
            so an unreachable service is not polled in a busy loop.

    Returns:
        True once a connection has been established (it is closed again
        immediately), False if the deadline passed first.
    """
    # Monotonic clock: the deadline must not move if the wall clock is adjusted.
    deadline = time.monotonic() + timeout if timeout else None

    while True:
        remaining = None
        if deadline is not None:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                return False

        # After a failed connect() a socket's state is unspecified, so each attempt
        # uses a fresh socket, and it is always closed (success or failure).
        sock = socket.socket()
        try:
            if remaining is not None:
                sock.settimeout(remaining)
            sock.connect((server, port))
        except socket.timeout:
            # This attempt used up the whole remaining budget.
            if deadline is not None:
                return False
        except OSError:
            # Refused / unreachable / not listening yet: retry after a short pause.
            pass
        else:
            return True
        finally:
            sock.close()

        pause = retry_interval
        if deadline is not None:
            pause = min(pause, max(deadline - time.monotonic(), 0))
        time.sleep(pause)
|
|
|
|
|
|
def launch_server_app(cmd, server_ip, server_port, wait_server_ready_in_seconds):
    """Start the server process and wait for it to accept TCP connections.

    Args:
        cmd: Command line as a sequence of strings, passed to ``subprocess.Popen``.
        server_ip: Address polled via ``wait_service_up``.
        server_port: Port polled via ``wait_service_up``.
        wait_server_ready_in_seconds: Upper bound for that polling.

    Returns:
        The ``subprocess.Popen`` handle of the launched server. The handle is
        returned even if the server never became reachable; that case is logged.
    """
    test_log("Launching server app: [{0}]".format(" ".join(cmd)))
    server_app_proc = subprocess.Popen(cmd)
    test_log("Server app PID: {0}".format(server_app_proc.pid))

    test_log("Wait up to {0} second(s) for server initialization".format(wait_server_ready_in_seconds))
    if not wait_service_up(server_ip, server_port, wait_server_ready_in_seconds):
        # Don't discard the readiness result: a server that crashed on startup or
        # never bound its port should leave a trace in the log. poll() is None
        # while the process is still running.
        exit_code = server_app_proc.poll()
        test_log(
            "Server app did not accept connections on {0}:{1} within {2} second(s) (exit code: {3})".format(
                server_ip, server_port, wait_server_ready_in_seconds, exit_code
            )
        )

    # Additional sleep to make sure the server is ready.
    time.sleep(1)

    return server_app_proc
|
|
|
|
|
|
def shutdown_server_app(server_app_proc, wait_for_server_off_in_seconds):
    """Kill the server process (if any) and wait until ``is_process_killed`` reports it gone.

    Args:
        server_app_proc: ``subprocess.Popen`` handle of the server, or None.
        wait_for_server_off_in_seconds: Timeout for each ``Popen.wait`` call in the
            wait loop. ``Popen.wait`` raises ``subprocess.TimeoutExpired`` if it elapses.

    Returns:
        Always True.
    """
    if server_app_proc is None:
        return True

    pid = server_app_proc.pid
    test_log("Shutdown server app")
    server_app_proc.kill()

    while not is_process_killed(pid):
        server_app_proc.wait(timeout=wait_for_server_off_in_seconds)
    test_log("PID {0} has been killed: {1}".format(pid, is_process_killed(pid)))

    # Additional sleep to make sure the resource has been freed.
    time.sleep(1)

    return True
|
|
|
|
|
|
def make_http_request(url, request_headers, payload):
    """POST ``payload`` to ``url`` with ``request_headers`` and return the ``requests`` response.

    The start and the completion of the request are logged via ``test_log``.
    """
    test_log("POST Request Started")
    response = requests.post(url, data=payload, headers=request_headers)
    test_log("POST Request Done")
    return response
|
|
|
|
|
|
# ONNX TensorProto.DataType codes whose raw tensor bytes the validators below compare
# element-wise, mapped to the keyword tolerances passed to compare_floats.
# Outputs of any other data type are checked for type and shape only.
_FLOAT_TOLERANCES_BY_ONNX_TYPE = {
    1: {"rel_tol": 0.001},  # FLOAT (abs_tol keeps the compare_floats default)
    10: {"rel_tol": 0.05, "abs_tol": 0.05},  # FLOAT16
    16: {"rel_tol": 0.05, "abs_tol": 0.05},  # BFLOAT16
}


def _decode_raw_tensor(raw_bytes, data_type):
    """Decode FLOAT/FLOAT16/BFLOAT16 ONNX ``raw_data`` bytes into a 1-D numpy array.

    ONNX stores ``raw_data`` little-endian, so the dtypes are pinned with ``<``.
    Raises ValueError for any other data type.
    """
    if data_type == 1:
        return numpy.frombuffer(raw_bytes, dtype="<f4")
    if data_type == 10:
        return numpy.frombuffer(raw_bytes, dtype="<f2")
    if data_type == 16:
        # bfloat16 is the upper half of an IEEE float32 bit pattern, with a different
        # exponent/mantissa split than float16, so it must not be read as float16.
        # Widen each 16-bit pattern into the high half of a uint32 and reinterpret.
        widened = numpy.frombuffer(raw_bytes, dtype="<u2").astype(numpy.uint32) << 16
        return widened.view(numpy.float32)
    raise ValueError("Unsupported ONNX data type for raw tensor comparison: {0}".format(data_type))


def _assert_raw_tensors_close(cls, actual_raw, expected_raw, data_type, count):
    """Assert both raw buffers hold ``count`` elements of ``data_type`` that are pairwise close."""
    tolerances = _FLOAT_TOLERANCES_BY_ONNX_TYPE[data_type]
    actual_array = _decode_raw_tensor(actual_raw, data_type)
    expected_array = _decode_raw_tensor(expected_raw, data_type)
    cls.assertEqual(len(actual_array), len(expected_array))
    cls.assertEqual(len(actual_array), count)
    for actual_value, expected_value in zip(actual_array, expected_array):
        cls.assertTrue(compare_floats(actual_value, expected_value, **tolerances))


def json_response_validation(cls, resp, expected_resp_json_file):
    """Validate a JSON prediction response against an expected-response JSON file.

    Args:
        cls: A ``unittest.TestCase`` (or anything with the same assert methods)
            used to report failures.
        resp: HTTP response; ``status_code``, ``headers`` and ``content`` are read.
        expected_resp_json_file: Path to a JSON file containing the expected ``outputs``.

    Every expected output must be present in the response with the same
    ``dataType`` and ``dims`` and a non-empty ``rawData``. FLOAT, FLOAT16 and
    BFLOAT16 outputs are also base64-decoded and compared element-wise within
    tolerance.
    """
    cls.assertEqual(resp.status_code, 200)
    cls.assertTrue(resp.headers.get("x-ms-request-id"))
    cls.assertEqual(resp.headers.get("Content-Type"), "application/json")

    with open(expected_resp_json_file, encoding="utf-8") as f:
        expected_result = json.load(f)

    actual_response = json.loads(resp.content.decode("utf-8"))
    cls.assertIn("outputs", actual_response)
    cls.assertTrue(actual_response["outputs"])

    for name, expected_output in expected_result["outputs"].items():
        # Fail with an assertion (not a KeyError) when an expected output is missing.
        cls.assertIn(name, actual_response["outputs"])
        actual_output = actual_response["outputs"][name]
        cls.assertTrue(actual_output)
        cls.assertTrue(actual_output["dataType"])
        cls.assertEqual(actual_output["dataType"], expected_output["dataType"])
        cls.assertTrue(actual_output["dims"])
        cls.assertEqual(actual_output["dims"], expected_output["dims"])
        cls.assertTrue(actual_output["rawData"])

        data_type = actual_output["dataType"]
        if data_type not in _FLOAT_TOLERANCES_BY_ONNX_TYPE:
            continue

        # dims entries are converted with int() because they may arrive as strings.
        count = 1
        for dim in actual_output["dims"]:
            count *= int(dim)

        _assert_raw_tensors_close(
            cls,
            base64.b64decode(actual_output["rawData"]),
            base64.b64decode(expected_output["rawData"]),
            data_type,
            count,
        )


def pb_response_validation(cls, resp, expected_resp_pb_file):
    """Validate a protobuf ``PredictResponse`` against an expected-response protobuf file.

    Args:
        cls: A ``unittest.TestCase`` (or anything with the same assert methods)
            used to report failures.
        resp: HTTP response; ``status_code``, ``headers`` and ``content`` are read.
        expected_resp_pb_file: Path to a serialized ``PredictResponse`` with the expected outputs.

    Every expected output must be present in the response with the same
    ``data_type`` and ``dims``. FLOAT, FLOAT16 and BFLOAT16 outputs are also
    compared element-wise within tolerance.
    """
    cls.assertEqual(resp.status_code, 200)
    cls.assertTrue(resp.headers.get("x-ms-request-id"))
    cls.assertEqual(resp.headers.get("Content-Type"), "application/octet-stream")

    actual_result = predict_pb2.PredictResponse()
    actual_result.ParseFromString(resp.content)

    expected_result = predict_pb2.PredictResponse()
    with open(expected_resp_pb_file, "rb") as f:
        expected_result.ParseFromString(f.read())

    for name, expected_output in expected_result.outputs.items():
        # Check membership first: if ``outputs`` is a protobuf message map, indexing
        # a missing key would silently insert a default entry.
        cls.assertIn(name, actual_result.outputs)
        actual_output = actual_result.outputs[name]
        cls.assertEqual(actual_output.data_type, expected_output.data_type)
        # Compare the whole shape, so extra dims in the actual output are caught too.
        cls.assertEqual(list(actual_output.dims), list(expected_output.dims))

        data_type = actual_output.data_type
        if data_type not in _FLOAT_TOLERANCES_BY_ONNX_TYPE:
            continue

        count = 1
        for dim in actual_output.dims:
            count *= int(dim)

        _assert_raw_tensors_close(cls, actual_output.raw_data, expected_output.raw_data, data_type, count)
|