mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-16 21:00:14 +00:00
57 lines
1.7 KiB
C++
57 lines
1.7 KiB
C++
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// Licensed under the MIT License.
|
|
|
|
#include <iostream>
|
|
|
|
#include "gtest/gtest.h"
|
|
|
|
#include "executor.h"
|
|
#include "http/json_handling.h"
|
|
#include <spdlog/spdlog.h>
|
|
#include <spdlog/sinks/sink.h>
|
|
#include <spdlog/sinks/stdout_sinks.h>
|
|
#include "test_server_environment.h"
|
|
|
|
namespace onnxruntime {
|
|
namespace server {
|
|
namespace test {
|
|
|
|
class ExecutorTest : public ::testing::Test {
|
|
protected:
|
|
void SetUp() override {
|
|
const static auto model_file = "testdata/mul_1.onnx";
|
|
|
|
onnxruntime::server::ServerEnvironment* env = ServerEnv();
|
|
env->InitializeModel(model_file, "Name", "version");
|
|
}
|
|
|
|
void TearDown() override {
|
|
onnxruntime::server::ServerEnvironment* env = ServerEnv();
|
|
env->UnloadModel("Name", "version");
|
|
}
|
|
};
|
|
|
|
TEST_F(ExecutorTest, TestMul_1) {
|
|
const static auto input_json = R"({"inputs":{"X":{"dims":[3,2],"dataType":1,"floatData":[1,2,3,4,5,6]}},"outputFilter":["Y"]})";
|
|
const static auto expected = R"({"outputs":{"Y":{"dims":["3","2"],"dataType":1,"floatData":[1,4,9,16,25,36]}}})";
|
|
|
|
onnxruntime::server::ServerEnvironment* env = ServerEnv();
|
|
|
|
onnxruntime::server::Executor executor(env, "RequestId");
|
|
onnxruntime::server::PredictRequest request{};
|
|
onnxruntime::server::PredictResponse response{};
|
|
|
|
auto protostatus = onnxruntime::server::GetRequestFromJson(input_json, request);
|
|
EXPECT_TRUE(protostatus.ok());
|
|
|
|
auto prediction_res = executor.Predict("Name", "version", request, response);
|
|
EXPECT_TRUE(prediction_res.ok());
|
|
|
|
std::string body;
|
|
protostatus = GenerateResponseInJson(response, body);
|
|
EXPECT_EQ(expected, body);
|
|
}
|
|
|
|
} // namespace test
|
|
} // namespace server
|
|
} // namespace onnxruntime
|