mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-17 21:10:43 +00:00
Description: Format all python files under onnxruntime with black and isort. After checking in, we can use .git-blame-ignore-revs to ignore the formatting PR in git blame. #11315, #11316
519 lines
22 KiB
Python
519 lines
22 KiB
Python
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
# Licensed under the MIT License.
|
|
|
|
import json
|
|
import os
|
|
import subprocess
|
|
import time
|
|
import unittest
|
|
|
|
import grpc
|
|
import numpy
|
|
import onnx_ml_pb2
|
|
import predict_pb2
|
|
import prediction_service_pb2_grpc
|
|
import requests
|
|
import test_util
|
|
|
|
|
|
class HttpJsonPayloadTests(unittest.TestCase):
    """HTTP prediction tests that send JSON request bodies to a live server.

    One server process is started for the whole class in ``setUpClass`` and
    killed in ``tearDownClass``; every test posts to that process.
    """

    server_ip = "127.0.0.1"
    server_port = 54321
    url_pattern = "http://{0}:{1}/v1/models/{2}/versions/{3}:predict"
    # The three paths below default to empty strings; presumably the test
    # launcher overwrites them before the suite runs -- TODO confirm.
    server_app_path = ""
    test_data_path = ""
    model_path = ""
    log_level = "verbose"
    server_app_proc = None
    wait_server_ready_in_seconds = 1

    @classmethod
    def setUpClass(cls):
        """Prepare the MNIST model and launch the server app with an HTTP port."""
        onnx_model = os.path.join(cls.model_path, "model.onnx")
        test_util.prepare_mnist_model(onnx_model)
        cmd = [
            cls.server_app_path,
            "--http_port",
            str(cls.server_port),
            "--model_path",
            onnx_model,
            "--log_level",
            cls.log_level,
        ]
        test_util.test_log("Launching server app: [{0}]".format(" ".join(cmd)))
        cls.server_app_proc = subprocess.Popen(cmd)
        test_util.test_log("Server app PID: {0}".format(cls.server_app_proc.pid))
        test_util.test_log(
            "Sleep {0} second(s) to wait for server initialization".format(cls.wait_server_ready_in_seconds)
        )
        # Fixed delay only; the server's readiness is not probed here.
        time.sleep(cls.wait_server_ready_in_seconds)

    @classmethod
    def tearDownClass(cls):
        """Kill the server process, log the outcome, and reap the child."""
        test_util.test_log("Shutdown server app")
        cls.server_app_proc.kill()
        test_util.test_log(
            "PID {0} has been killed: {1}".format(
                cls.server_app_proc.pid, test_util.is_process_killed(cls.server_app_proc.pid)
            )
        )
        # kill() only requests termination. wait() blocks until the process has
        # actually exited and reaps it, so no defunct child is left behind and
        # the port is free before another test class starts a server on it.
        cls.server_app_proc.wait()

    def _read_test_data(self, file_name):
        """Return the text content of ``file_name`` located under ``test_data_path``."""
        with open(os.path.join(self.test_data_path, file_name), "r") as f:
            return f.read()

    def _assert_mnist_prediction(self, url):
        """Post the MNIST JSON sample to ``url`` and verify the full response.

        Checks the status code, the response headers (including the echoed
        client request id) and the ``Plus214_Output_0`` tensor against the
        recorded expected output file.
        """
        request_payload = self._read_test_data("mnist_test_data_set_0_input.json")
        expected_response = json.loads(self._read_test_data("mnist_test_data_set_0_output.json"))

        request_headers = {
            "Content-Type": "application/json",
            "Accept": "application/json",
            "x-ms-client-request-id": "This~is~my~id",
        }

        test_util.test_log(url)
        r = requests.post(url, headers=request_headers, data=request_payload)
        self.assertEqual(r.status_code, 200)
        self.assertEqual(r.headers.get("Content-Type"), "application/json")
        self.assertTrue(r.headers.get("x-ms-request-id"))
        self.assertEqual(r.headers.get("x-ms-client-request-id"), "This~is~my~id")

        actual_response = json.loads(r.content.decode("utf-8"))

        # Note:
        # The 'dims' field is defined as "repeated int64" in protobuf.
        # When it is serialized to JSON, all int64/fixed64/uint64 numbers are converted to string
        # Reference: https://developers.google.com/protocol-buffers/docs/proto3#json

        self.assertTrue(actual_response["outputs"])
        output = actual_response["outputs"]["Plus214_Output_0"]
        self.assertTrue(output)
        self.assertTrue(output["dims"])
        self.assertEqual(output["dims"], ["1", "10"])
        self.assertTrue(output["dataType"])
        self.assertEqual(output["dataType"], 1)
        self.assertTrue(output["rawData"])
        actual_data = test_util.decode_base64_string(output["rawData"], "10f")
        expected_data = test_util.decode_base64_string(
            expected_response["outputs"]["Plus214_Output_0"]["rawData"], "10f"
        )

        # Index explicitly so a result with fewer than 10 values fails loudly.
        for i in range(10):
            self.assertTrue(test_util.compare_floats(actual_data[i], expected_data[i]))

    def _assert_content_type_rejected(self, request_headers):
        """Post the MNIST sample with ``request_headers`` and expect a 400 error.

        Returns the response so that callers can make additional assertions.
        """
        url = self.url_pattern.format(self.server_ip, self.server_port, "default", 1)
        test_util.test_log(url)

        request_payload = self._read_test_data("mnist_test_data_set_0_input.json")

        r = requests.post(url, headers=request_headers, data=request_payload)
        self.assertEqual(r.status_code, 400)
        self.assertEqual(r.headers.get("Content-Type"), "application/json")
        self.assertTrue(r.headers.get("x-ms-request-id"))
        self.assertEqual(
            r.content.decode("utf-8"),
            '{"error_code": 400, "error_message": "Missing or unknown \'Content-Type\' header field in the request"}\n',
        )
        return r

    def test_mnist_happy_path(self):
        """The fully qualified model/version route returns the expected prediction."""
        url = self.url_pattern.format(self.server_ip, self.server_port, "default", 1)
        self._assert_mnist_prediction(url)

    def test_mnist_invalid_url(self):
        """A request for model version -1 is answered with 404 and a JSON content type."""
        url = self.url_pattern.format(self.server_ip, self.server_port, "default", -1)
        test_util.test_log(url)

        request_headers = {"Content-Type": "application/json", "Accept": "application/json"}

        r = requests.post(url, headers=request_headers, data={"foo": "bar"})
        self.assertEqual(r.status_code, 404)
        self.assertEqual(r.headers.get("Content-Type"), "application/json")
        self.assertTrue(r.headers.get("x-ms-request-id"))

    def test_mnist_invalid_content_type(self):
        """An unknown Content-Type yields 400 and the client request id is still echoed."""
        request_headers = {
            "Content-Type": "application/abc",
            "Accept": "application/json",
            "x-ms-client-request-id": "This~is~my~id",
        }

        r = self._assert_content_type_rejected(request_headers)
        self.assertEqual(r.headers.get("x-ms-client-request-id"), "This~is~my~id")

    def test_mnist_missing_content_type(self):
        """A request without a Content-Type header yields the same 400 error body."""
        self._assert_content_type_rejected({"Accept": "application/json"})

    def test_single_model_shortcut(self):
        """The ``/score`` route returns the same prediction as the full route."""
        url = "http://{0}:{1}/score".format(self.server_ip, self.server_port)
        self._assert_mnist_prediction(url)

    def test_single_version_shortcut(self):
        """The route without a version segment returns the same prediction."""
        url = "http://{0}:{1}/v1/models/{2}:predict".format(self.server_ip, self.server_port, "default")
        self._assert_mnist_prediction(url)
|
|
|
|
|
|
class HttpProtobufPayloadTests(unittest.TestCase):
    """HTTP prediction tests that send protobuf request bodies to a live server.

    One server process is started for the whole class in ``setUpClass`` and
    killed in ``tearDownClass``; every test posts to that process.
    """

    server_ip = "127.0.0.1"
    server_port = 54321
    url_pattern = "http://{0}:{1}/v1/models/{2}/versions/{3}:predict"
    # The three paths below default to empty strings; presumably the test
    # launcher overwrites them before the suite runs -- TODO confirm.
    server_app_path = ""
    test_data_path = ""
    model_path = ""
    log_level = "verbose"
    server_app_proc = None
    wait_server_ready_in_seconds = 1

    @classmethod
    def setUpClass(cls):
        """Prepare the MNIST model and launch the server app with an HTTP port."""
        onnx_model = os.path.join(cls.model_path, "model.onnx")
        test_util.prepare_mnist_model(onnx_model)
        cmd = [
            cls.server_app_path,
            "--http_port",
            str(cls.server_port),
            "--model_path",
            onnx_model,
            "--log_level",
            cls.log_level,
        ]
        test_util.test_log("Launching server app: [{0}]".format(" ".join(cmd)))
        cls.server_app_proc = subprocess.Popen(cmd)
        test_util.test_log("Server app PID: {0}".format(cls.server_app_proc.pid))
        test_util.test_log(
            "Sleep {0} second(s) to wait for server initialization".format(cls.wait_server_ready_in_seconds)
        )
        # Fixed delay only; the server's readiness is not probed here.
        time.sleep(cls.wait_server_ready_in_seconds)

    @classmethod
    def tearDownClass(cls):
        """Kill the server process, log the outcome, and reap the child."""
        test_util.test_log("Shutdown server app")
        cls.server_app_proc.kill()
        test_util.test_log(
            "PID {0} has been killed: {1}".format(
                cls.server_app_proc.pid, test_util.is_process_killed(cls.server_app_proc.pid)
            )
        )
        # kill() only requests termination. wait() blocks until the process has
        # actually exited and reaps it, so no defunct child is left behind and
        # the port is free before another test class starts a server on it.
        cls.server_app_proc.wait()

    def _load_request_payload(self):
        """Return the raw bytes of the serialized MNIST request sample."""
        input_data_file = os.path.join(self.test_data_path, "mnist_test_data_set_0_input.pb")
        with open(input_data_file, "rb") as f:
            return f.read()

    def _post_prediction(self, request_headers, request_payload):
        """Post ``request_payload`` to the default model, version 1, and return the response."""
        url = self.url_pattern.format(self.server_ip, self.server_port, "default", 1)
        test_util.test_log(url)
        return requests.post(url, headers=request_headers, data=request_payload)

    def _assert_outputs_match(self, actual_result, expected_result):
        """Assert that two ``PredictResponse`` messages carry the same prediction.

        Data types are compared for every expected output; dims and the
        float32 payload are compared for ``Plus214_Output_0``.
        """
        for k in expected_result.outputs.keys():
            self.assertEqual(actual_result.outputs[k].data_type, expected_result.outputs[k].data_type)

        actual_output = actual_result.outputs["Plus214_Output_0"]
        expected_output = expected_result.outputs["Plus214_Output_0"]

        # The element count is derived from the dims actually returned so the
        # raw buffer length can be cross-checked against them below.
        count = 1
        for i, expected_dim in enumerate(expected_output.dims):
            self.assertEqual(actual_output.dims[i], expected_dim)
            count = count * int(actual_output.dims[i])

        actual_array = numpy.frombuffer(actual_output.raw_data, dtype=numpy.float32)
        expected_array = numpy.frombuffer(expected_output.raw_data, dtype=numpy.float32)
        self.assertEqual(len(actual_array), len(expected_array))
        self.assertEqual(len(actual_array), count)
        for i in range(count):
            self.assertTrue(test_util.compare_floats(actual_array[i], expected_array[i], rel_tol=0.001))

    def test_mnist_happy_path(self):
        """Each accepted protobuf Content-Type produces the expected prediction."""
        request_payload = self._load_request_payload()

        # The expected response does not depend on the header under test, so
        # it is parsed once instead of once per loop iteration.
        output_data_file = os.path.join(self.test_data_path, "mnist_test_data_set_0_output.pb")
        expected_result = predict_pb2.PredictResponse()
        with open(output_data_file, "rb") as f:
            expected_result.ParseFromString(f.read())

        content_type_headers = ["application/x-protobuf", "application/octet-stream", "application/vnd.google.protobuf"]

        for h in content_type_headers:
            # subTest labels a failure with the header that triggered it.
            with self.subTest(content_type=h):
                request_headers = {"Content-Type": h, "Accept": "application/x-protobuf"}

                r = self._post_prediction(request_headers, request_payload)
                self.assertEqual(r.status_code, 200)
                self.assertEqual(r.headers.get("Content-Type"), "application/x-protobuf")
                self.assertTrue(r.headers.get("x-ms-request-id"))

                actual_result = predict_pb2.PredictResponse()
                actual_result.ParseFromString(r.content)

                self._assert_outputs_match(actual_result, expected_result)

    def test_respect_accept_header(self):
        """The response Content-Type equals the protobuf type named in Accept."""
        request_payload = self._load_request_payload()

        accept_headers = ["application/x-protobuf", "application/octet-stream", "application/vnd.google.protobuf"]

        for h in accept_headers:
            # subTest labels a failure with the header that triggered it.
            with self.subTest(accept=h):
                request_headers = {"Content-Type": "application/x-protobuf", "Accept": h}

                r = self._post_prediction(request_headers, request_payload)
                self.assertEqual(r.status_code, 200)
                self.assertEqual(r.headers.get("Content-Type"), h)

    def test_missing_accept_header(self):
        """Without an Accept header the response is ``application/octet-stream``."""
        request_payload = self._load_request_payload()

        request_headers = {
            "Content-Type": "application/x-protobuf",
        }

        r = self._post_prediction(request_headers, request_payload)
        self.assertEqual(r.status_code, 200)
        self.assertEqual(r.headers.get("Content-Type"), "application/octet-stream")

    def test_any_accept_header(self):
        """With ``Accept: */*`` the response is ``application/octet-stream``."""
        request_payload = self._load_request_payload()

        request_headers = {"Content-Type": "application/x-protobuf", "Accept": "*/*"}

        r = self._post_prediction(request_headers, request_payload)
        self.assertEqual(r.status_code, 200)
        self.assertEqual(r.headers.get("Content-Type"), "application/octet-stream")
|
|
|
|
|
|
class HttpEndpointTests(unittest.TestCase):
    """Tests for the server's non-prediction HTTP endpoints.

    One server process is started for the whole class in ``setUpClass`` and
    killed in ``tearDownClass``.
    """

    server_ip = "127.0.0.1"
    server_port = 54321
    # The three paths below default to empty strings; presumably the test
    # launcher overwrites them before the suite runs -- TODO confirm.
    server_app_path = ""
    test_data_path = ""
    model_path = ""
    log_level = "verbose"
    server_app_proc = None
    wait_server_ready_in_seconds = 1

    @classmethod
    def setUpClass(cls):
        """Prepare the MNIST model and launch the server app with an HTTP port."""
        onnx_model = os.path.join(cls.model_path, "model.onnx")
        test_util.prepare_mnist_model(onnx_model)
        cmd = [
            cls.server_app_path,
            "--http_port",
            str(cls.server_port),
            "--model_path",
            onnx_model,
            "--log_level",
            cls.log_level,
        ]
        test_util.test_log("Launching server app: [{0}]".format(" ".join(cmd)))
        cls.server_app_proc = subprocess.Popen(cmd)
        test_util.test_log("Server app PID: {0}".format(cls.server_app_proc.pid))
        test_util.test_log(
            "Sleep {0} second(s) to wait for server initialization".format(cls.wait_server_ready_in_seconds)
        )
        # Fixed delay only; the server's readiness is not probed here.
        time.sleep(cls.wait_server_ready_in_seconds)

    @classmethod
    def tearDownClass(cls):
        """Kill the server process, log the outcome, and reap the child."""
        test_util.test_log("Shutdown server app")
        cls.server_app_proc.kill()
        test_util.test_log(
            "PID {0} has been killed: {1}".format(
                cls.server_app_proc.pid, test_util.is_process_killed(cls.server_app_proc.pid)
            )
        )
        # kill() only requests termination. wait() blocks until the process has
        # actually exited and reaps it, so no defunct child is left behind and
        # the port is free before another test class starts a server on it.
        cls.server_app_proc.wait()

    def test_health_endpoint(self):
        """A GET on the root path answers 200 with the body ``Healthy``."""
        url = "http://{0}:{1}/".format(self.server_ip, self.server_port)
        test_util.test_log(url)
        r = requests.get(url)
        self.assertEqual(r.status_code, 200)
        self.assertEqual(r.content.decode("utf-8"), "Healthy")
|
|
|
|
|
|
class GRPCTests(unittest.TestCase):
    """Prediction tests that talk to a live server over gRPC.

    One server process is started for the whole class in ``setUpClass`` and
    killed in ``tearDownClass``.
    """

    server_ip = "127.0.0.1"
    server_port = 54321
    # The three paths below default to empty strings; presumably the test
    # launcher overwrites them before the suite runs -- TODO confirm.
    server_app_path = ""
    test_data_path = ""
    model_path = ""
    log_level = "verbose"
    server_app_proc = None
    wait_server_ready_in_seconds = 1

    @classmethod
    def setUpClass(cls):
        """Prepare the MNIST model and launch the server app with a gRPC port."""
        onnx_model = os.path.join(cls.model_path, "model.onnx")
        test_util.prepare_mnist_model(onnx_model)
        cmd = [
            cls.server_app_path,
            "--grpc_port",
            str(cls.server_port),
            "--model_path",
            onnx_model,
            "--log_level",
            cls.log_level,
        ]
        test_util.test_log("Launching server app: [{0}]".format(" ".join(cmd)))
        cls.server_app_proc = subprocess.Popen(cmd)
        test_util.test_log("Server app PID: {0}".format(cls.server_app_proc.pid))
        test_util.test_log(
            "Sleep {0} second(s) to wait for server initialization".format(cls.wait_server_ready_in_seconds)
        )
        # Fixed delay only; the server's readiness is not probed here.
        time.sleep(cls.wait_server_ready_in_seconds)

    @classmethod
    def tearDownClass(cls):
        """Kill the server process, log the outcome, and reap the child."""
        test_util.test_log("Shutdown server app")
        cls.server_app_proc.kill()
        test_util.test_log(
            "PID {0} has been killed: {1}".format(
                cls.server_app_proc.pid, test_util.is_process_killed(cls.server_app_proc.pid)
            )
        )
        # kill() only requests termination. wait() blocks until the process has
        # actually exited and reaps it, so no defunct child is left behind and
        # the port is free before another test class starts a server on it.
        cls.server_app_proc.wait()

    def _assert_outputs_match(self, actual_result, expected_result):
        """Assert that two ``PredictResponse`` messages carry the same prediction.

        Data types are compared for every expected output; dims and the
        float32 payload are compared for ``Plus214_Output_0``.
        """
        for k in expected_result.outputs.keys():
            self.assertEqual(actual_result.outputs[k].data_type, expected_result.outputs[k].data_type)

        actual_output = actual_result.outputs["Plus214_Output_0"]
        expected_output = expected_result.outputs["Plus214_Output_0"]

        # The element count is derived from the dims actually returned so the
        # raw buffer length can be cross-checked against them below.
        count = 1
        for i, expected_dim in enumerate(expected_output.dims):
            self.assertEqual(actual_output.dims[i], expected_dim)
            count = count * int(actual_output.dims[i])

        actual_array = numpy.frombuffer(actual_output.raw_data, dtype=numpy.float32)
        expected_array = numpy.frombuffer(expected_output.raw_data, dtype=numpy.float32)
        self.assertEqual(len(actual_array), len(expected_array))
        self.assertEqual(len(actual_array), count)
        for i in range(count):
            self.assertTrue(test_util.compare_floats(actual_array[i], expected_array[i], rel_tol=0.001))

    def test_mnist_happy_path(self):
        """A ``Predict`` call over an insecure channel returns the expected prediction."""
        input_data_file = os.path.join(self.test_data_path, "mnist_test_data_set_0_input.pb")
        output_data_file = os.path.join(self.test_data_path, "mnist_test_data_set_0_output.pb")

        request = predict_pb2.PredictRequest()
        with open(input_data_file, "rb") as f:
            request.ParseFromString(f.read())

        uri = "{}:{}".format(self.server_ip, self.server_port)
        test_util.test_log(uri)
        with grpc.insecure_channel(uri) as channel:
            stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
            actual_result = stub.Predict(request)

        expected_result = predict_pb2.PredictResponse()
        with open(output_data_file, "rb") as f:
            expected_result.ParseFromString(f.read())

        self._assert_outputs_match(actual_result, expected_result)
|
|
|
|
|
|
# Run the test cases defined in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
|