mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-16 21:00:14 +00:00
137 lines
4.7 KiB
C++
137 lines
4.7 KiB
C++
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// Licensed under the MIT License.
|
|
#include <iostream>
|
|
|
|
#include "gtest/gtest.h"
|
|
|
|
#include "executor.h"
|
|
#include "grpc/prediction_service_impl.h"
|
|
#include "test_server_environment.h"
|
|
#include "external/server_context_test_spouse.h"
|
|
#include <grpcpp/impl/grpc_library.h>
|
|
#include "test_server_environment.h"
|
|
|
|
namespace onnxruntime {
|
|
namespace server {
|
|
namespace grpc {
|
|
namespace test {
|
|
|
|
// File-scope instance of GrpcLibraryInitializer. This class's constructor calls
// the GRPC library's init method, so defining the object here makes that call
// happen when this translation unit's static objects are constructed.
// NOTE(review): presumably required so the tests below can create
// ::grpc::ServerContext objects without starting a real server - confirm.
|
|
static ::grpc::internal::GrpcLibraryInitializer g_initializer;
|
|
|
|
// Builds the PredictRequest used as the baseline by the tests in this file.
//
// The request asks for output "Y" and carries a single input tensor named "X"
// with dims {3, 2} and six float values 1..6.
//
// Returns the populated request by value.
PredictRequest GetRequest() {
  PredictRequest req{};
  req.add_output_filter("Y");

  // Fill the tensor directly inside the request's input map instead of building
  // a standalone TensorProto and copying it in afterwards.
  onnx::TensorProto& tensor = (*req.mutable_inputs())["X"];
  tensor.add_dims(3);
  tensor.add_dims(2);
  tensor.set_data_type(1);  // 1 is FLOAT in ONNX's TensorProto.DataType enum.

  // A plain array keeps this helper independent of <vector>/<algorithm>, which
  // this file does not include.
  constexpr float kValues[] = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f};
  for (const float value : kValues) {
    tensor.add_float_data(value);
  }

  return req;
}
|
|
|
|
class PredictionServiceImplTest : public ::testing::Test {
|
|
protected:
|
|
void SetUp() override {
|
|
const static auto model_file = "testdata/mul_1.onnx";
|
|
|
|
onnxruntime::server::ServerEnvironment* env = onnxruntime::server::test::ServerEnv();
|
|
// Implementation detail - currently predict is hardcoded to model "default" version 1.
|
|
env->InitializeModel(model_file, "default", "1");
|
|
}
|
|
void TearDown() override {
|
|
onnxruntime::server::ServerEnvironment* env = onnxruntime::server::test::ServerEnv();
|
|
env->UnloadModel("default", "1");
|
|
}
|
|
std::shared_ptr<onnxruntime::server::ServerEnvironment> GetEnvironment() {
|
|
return std::shared_ptr<onnxruntime::server::ServerEnvironment>(onnxruntime::server::test::ServerEnv(), [](onnxruntime::server::ServerEnvironment*) {});
|
|
}
|
|
};
|
|
|
|
// A well-formed request must produce an OK status from Predict.
TEST_F(PredictionServiceImplTest, HappyPath) {
  auto environment = GetEnvironment();
  PredictionServiceImpl service{environment};

  // Inputs and outputs for the call under test.
  ::grpc::ServerContext server_context;
  auto predict_request = GetRequest();
  PredictResponse predict_response{};

  const auto outcome = service.Predict(&server_context, &predict_request, &predict_response);

  EXPECT_TRUE(outcome.ok());
}
|
|
|
|
// A request whose input tensor "X" has an extra dimension appended must be
// rejected with INVALID_ARGUMENT.
TEST_F(PredictionServiceImplTest, InvalidInput) {
  auto environment = GetEnvironment();
  PredictionServiceImpl service{environment};

  // Start from the baseline request and append a third dim (value 1) to "X".
  auto predict_request = GetRequest();
  auto& input_tensor = (*predict_request.mutable_inputs())["X"];
  input_tensor.add_dims(1);

  ::grpc::ServerContext server_context;
  PredictResponse predict_response{};

  const auto outcome = service.Predict(&server_context, &predict_request, &predict_response);

  EXPECT_EQ(outcome.error_code(), ::grpc::INVALID_ARGUMENT);
}
|
|
|
|
// On a successful Predict call the initial metadata must contain an
// "x-ms-request-id" entry.
TEST_F(PredictionServiceImplTest, SuccessRequestID) {
  auto environment = GetEnvironment();
  PredictionServiceImpl service{environment};

  auto predict_request = GetRequest();
  PredictResponse predict_response{};

  // The spouse gives the test access to the metadata recorded on the context.
  ::grpc::ServerContext server_context;
  ::grpc::testing::ServerContextTestSpouse context_spouse(&server_context);

  const auto outcome = service.Predict(&server_context, &predict_request, &predict_response);
  const auto initial_metadata = context_spouse.GetInitialMetadata();

  EXPECT_NE(initial_metadata.find("x-ms-request-id"), initial_metadata.end());
  EXPECT_TRUE(outcome.ok());
}
|
|
|
|
// Even when Predict fails (here: the request has had all inputs removed), the
// initial metadata must still contain an "x-ms-request-id" entry.
TEST_F(PredictionServiceImplTest, InvalidInputRequestID) {
  auto environment = GetEnvironment();
  PredictionServiceImpl service{environment};

  // Baseline request with every input stripped out.
  auto predict_request = GetRequest();
  predict_request.clear_inputs();

  PredictResponse predict_response{};
  ::grpc::ServerContext server_context;
  ::grpc::testing::ServerContextTestSpouse context_spouse(&server_context);

  const auto outcome = service.Predict(&server_context, &predict_request, &predict_response);
  const auto initial_metadata = context_spouse.GetInitialMetadata();

  EXPECT_NE(initial_metadata.find("x-ms-request-id"), initial_metadata.end());
  EXPECT_FALSE(outcome.ok());
}
|
|
|
|
// On a successful Predict call, the "x-ms-client-request-id" value supplied in
// the client metadata must appear unchanged in the initial metadata.
TEST_F(PredictionServiceImplTest, SuccessClientID) {
  auto env = GetEnvironment();
  PredictionServiceImpl test{env};
  auto request = GetRequest();
  PredictResponse resp{};

  ::grpc::ServerContext context;
  ::grpc::testing::ServerContextTestSpouse spouse(&context);
  spouse.AddClientMetadata("x-ms-client-request-id", "client-id");

  auto status = test.Predict(&context, &request, &resp);
  const auto metadata = spouse.GetInitialMetadata();

  // Checked first so the status is still reported if the fatal assertion below fires.
  EXPECT_TRUE(status.ok());

  // Look the header up once. ASSERT (fatal) rather than EXPECT: the iterator is
  // dereferenced on the next line, which would be undefined behaviour if the
  // header were missing and the test kept running.
  const auto client_id = metadata.find("x-ms-client-request-id");
  ASSERT_NE(client_id, metadata.end());
  EXPECT_EQ(client_id->second, "client-id");
}
|
|
|
|
// Even when Predict fails (here: the request has had all inputs removed), the
// "x-ms-client-request-id" value supplied in the client metadata must appear
// unchanged in the initial metadata.
TEST_F(PredictionServiceImplTest, InvalidInputClientID) {
  auto env = GetEnvironment();
  PredictionServiceImpl test{env};

  // Baseline request with every input stripped out.
  auto request = GetRequest();
  request.clear_inputs();

  PredictResponse resp{};
  ::grpc::ServerContext context;
  ::grpc::testing::ServerContextTestSpouse spouse(&context);
  spouse.AddClientMetadata("x-ms-client-request-id", "client-id");

  auto status = test.Predict(&context, &request, &resp);
  const auto metadata = spouse.GetInitialMetadata();

  // Checked first so the status is still reported if the fatal assertion below fires.
  EXPECT_FALSE(status.ok());

  // Look the header up once. ASSERT (fatal) rather than EXPECT: the iterator is
  // dereferenced on the next line, which would be undefined behaviour if the
  // header were missing and the test kept running.
  const auto client_id = metadata.find("x-ms-client-request-id");
  ASSERT_NE(client_id, metadata.end());
  EXPECT_EQ(client_id->second, "client-id");
}
|
|
|
|
} // namespace test
|
|
} // namespace grpc
|
|
} // namespace server
|
|
} // namespace onnxruntime
|