2023-10-12 08:28:55 +00:00
|
|
|
#include <stdexcept>
|
|
|
|
|
|
2023-12-19 14:11:01 +00:00
|
|
|
#include <torch/csrc/inductor/aoti_runner/model_container_runner_cpu.h>
|
2024-08-08 19:49:53 +00:00
|
|
|
#if defined(USE_CUDA) || defined(USE_ROCM)
|
2023-12-19 14:11:01 +00:00
|
|
|
#include <torch/csrc/inductor/aoti_runner/model_container_runner_cuda.h>
|
2023-10-12 08:28:55 +00:00
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
#include "aoti_custom_class.h"
|
|
|
|
|
|
|
|
|
|
namespace torch::aot_inductor {
|
|
|
|
|
|
|
|
|
|
static auto registerMyAOTIClass =
|
|
|
|
|
torch::class_<MyAOTIClass>("aoti", "MyAOTIClass")
|
|
|
|
|
.def(torch::init<std::string, std::string>())
|
|
|
|
|
.def("forward", &MyAOTIClass::forward)
|
|
|
|
|
.def_pickle(
|
|
|
|
|
[](const c10::intrusive_ptr<MyAOTIClass>& self)
|
|
|
|
|
-> std::vector<std::string> {
|
|
|
|
|
std::vector<std::string> v;
|
|
|
|
|
v.push_back(self->lib_path());
|
|
|
|
|
v.push_back(self->device());
|
|
|
|
|
return v;
|
|
|
|
|
},
|
|
|
|
|
[](std::vector<std::string> params) {
|
|
|
|
|
return c10::make_intrusive<MyAOTIClass>(params[0], params[1]);
|
|
|
|
|
});
|
|
|
|
|
|
|
|
|
|
// Builds a MyAOTIClass bound to the compiled model at `model_path`.
//
// `device` selects which container runner is created:
//   - "cpu"  -> AOTIModelContainerRunnerCpu
//   - "cuda" -> AOTIModelContainerRunnerCuda (only when built with USE_CUDA
//               or USE_ROCM)
// Any other value throws std::runtime_error. Both arguments are stored so
// that lib_path()/device() can report them later (e.g. for pickling).
MyAOTIClass::MyAOTIClass(
    const std::string& model_path,
    const std::string& device)
    : lib_path_(model_path), device_(device) {
  if (device_ == "cpu") {
    runner_ = std::make_unique<torch::inductor::AOTIModelContainerRunnerCpu>(
        model_path.c_str());
    return;
  }

#if defined(USE_CUDA) || defined(USE_ROCM)
  if (device_ == "cuda") {
    runner_ = std::make_unique<torch::inductor::AOTIModelContainerRunnerCuda>(
        model_path.c_str());
    return;
  }
#endif

  // No runner matched the requested device string.
  throw std::runtime_error("invalid device: " + device);
}
|
|
|
|
|
|
|
|
|
|
std::vector<torch::Tensor> MyAOTIClass::forward(
|
|
|
|
|
std::vector<torch::Tensor> inputs) {
|
|
|
|
|
return runner_->run(inputs);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} // namespace torch::aot_inductor
|