Memory map files on windows to speed up model load (#8349)

* Memory map files on windows to speed up model load

* fix custom ops

Co-authored-by: Sheil Kumar <sheilk@microsoft.com>
This commit is contained in:
Sheil Kumar 2021-07-12 11:52:08 -07:00 committed by GitHub
parent f6956e0259
commit eec8e1394a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 48 additions and 20 deletions

View file

@ -1328,7 +1328,7 @@ STDMETHODIMP OnnxruntimeEngineFactory::CreateModel(_In_ void* data, _In_ size_t
RETURN_IF_FAILED(EnsureEnvironment());
OrtModel* ort_model = nullptr;
if (auto status = winml_adapter_api_->CreateModelFromData(data, size, &ort_model)) {
return E_INVALIDARG;
return __HRESULT_FROM_WIN32(ERROR_FILE_CORRUPT);
}
auto model = UniqueOrtModel(ort_model, winml_adapter_api_->ReleaseModel);

View file

@ -17,8 +17,49 @@
namespace WINMLP {
// Constructs a LearningModel from the model file at `path`.
//
// The file is opened for shared read access and mapped read-only into memory. The mapped
// bytes and the file size are passed to engine_factory_->CreateModel(), and the view is
// unmapped again before the constructor returns, on both the success and the failure path.
//
// path        - path of the model file to open.
// op_provider - operator provider, stored in operator_provider_.
//
// Any failing step throws through the WINML_THROW_* macros; the function-try-block hands
// the exception to WINML_CATCH_ALL.
LearningModel::LearningModel(
    const hstring& path,
    const winml::ILearningModelOperatorProvider op_provider) try : LearningModel(_winml::Strings::UTF8FromHString(path),
    op_provider) {
    const winml::ILearningModelOperatorProvider op_provider) try : operator_provider_(op_provider) {
  _winmlt::TelemetryEvent loadModel_event(_winmlt::EventCategory::kModelLoad);

  WINML_THROW_IF_FAILED(CreateOnnxruntimeEngineFactory(engine_factory_.put()));

  // unique_hfile (not unique_handle) because CreateFileW reports failure with
  // INVALID_HANDLE_VALUE rather than nullptr. FILE_SHARE_READ lets other readers keep the
  // model open while it is being loaded; a share mode of 0 would demand exclusive access.
  wil::unique_hfile file_handle(CreateFileW(path.c_str(),
                                            GENERIC_READ,
                                            FILE_SHARE_READ,
                                            NULL,
                                            OPEN_EXISTING,
                                            FILE_ATTRIBUTE_READONLY,
                                            NULL));
  WINML_THROW_HR_IF_TRUE_MSG(__HRESULT_FROM_WIN32(GetLastError()),
                             file_handle.get() == INVALID_HANDLE_VALUE,
                             "Model load failed!");

  // GetFileSizeEx reports the full 64-bit size and signals failure through its return
  // value; GetFileSize with a null high-part pointer silently drops the upper 32 bits.
  LARGE_INTEGER file_size = {};
  const BOOL got_file_size = GetFileSizeEx(file_handle.get(), &file_size);
  WINML_THROW_HR_IF_TRUE_MSG(__HRESULT_FROM_WIN32(GetLastError()),
                             got_file_size == FALSE,
                             "Model load failed!");
  const auto file_size_in_bytes = static_cast<size_t>(file_size.QuadPart);

  auto file_mapping = wil::unique_handle(CreateFileMappingW(file_handle.get(),  // current file handle
                                                            NULL,               // default security
                                                            PAGE_READONLY,      // read-only page protection
                                                            0,                  // size of mapping object, high
                                                            0,                  // size of mapping object, low
                                                            NULL));             // name of mapping object
  WINML_THROW_HR_IF_TRUE_MSG(__HRESULT_FROM_WIN32(GetLastError()),
                             file_mapping == nullptr,
                             "Model load failed!");

  auto buffer = MapViewOfFile(file_mapping.get(),  // handle to mapping object
                              FILE_MAP_READ,       // read-only view
                              0,                   // high-order 32 bits of file offset
                              0,                   // low-order 32 bits of file offset
                              0);                  // number of bytes to map. 0 means read whole file.
  // Check the view itself: MapViewOfFile returns nullptr on failure even though the
  // mapping object is valid.
  WINML_THROW_HR_IF_TRUE_MSG(__HRESULT_FROM_WIN32(GetLastError()),
                             buffer == nullptr,
                             "Model load failed!");

  // Make sure the view does not leak if CreateModel fails and the macro below throws.
  auto unmap_on_failure = wil::scope_exit([buffer]() { UnmapViewOfFile(buffer); });
  WINML_THROW_IF_FAILED(engine_factory_->CreateModel(buffer, file_size_in_bytes, model_.put()));

  // Success path: dismiss the guard and unmap explicitly so an unmap failure is still reported.
  unmap_on_failure.release();
  WINML_THROW_HR_IF_TRUE_MSG(E_UNEXPECTED, UnmapViewOfFile(buffer) == 0, "Could not unmap model file.");

  WINML_THROW_IF_FAILED(model_->GetModelInfo(model_info_.put()));
}
WINML_CATCH_ALL
@ -33,17 +74,6 @@ LearningModel::LearningModel(
}
WINML_CATCH_ALL
// Constructs a LearningModel from a std::string `path`.
//
// path              - its character data and length are passed to engine_factory_->CreateModel().
//                     NOTE(review): presumably a UTF-8 file path that the engine factory opens
//                     itself - confirm against the CreateModel overload taking (const char*, size_t).
// operator_provider - operator provider, stored in operator_provider_.
//
// Each step is wrapped in WINML_THROW_IF_FAILED; the function-try-block hands anything thrown
// to WINML_CATCH_ALL.
LearningModel::LearningModel(
const std::string& path,
const winml::ILearningModelOperatorProvider operator_provider) try : operator_provider_(operator_provider) {
// Telemetry event of category kModelLoad, alive for the rest of the constructor body.
_winmlt::TelemetryEvent loadModel_event(_winmlt::EventCategory::kModelLoad);
// The engine factory must exist before it can be asked to create the model.
WINML_THROW_IF_FAILED(CreateOnnxruntimeEngineFactory(engine_factory_.put()));
WINML_THROW_IF_FAILED(engine_factory_->CreateModel(path.c_str(), path.size(), model_.put()));
// Fetch the model's info object from the freshly created model into model_info_.
WINML_THROW_IF_FAILED(model_->GetModelInfo(model_info_.put()));
}
WINML_CATCH_ALL
static HRESULT CreateModelFromStream(
_winml::IEngineFactory* engine_factory,
const wss::IRandomAccessStreamReference stream,
@ -64,7 +94,9 @@ static HRESULT CreateModelFromStream(
WINML_THROW_IF_FAILED_MSG(bytes->Buffer(reinterpret_cast<byte**>(&data)), "Failed to acquire buffer from model stream.");
size_t len = static_cast<size_t>(content.Size());
WINML_THROW_IF_FAILED(engine_factory->CreateModel(data, len, model));
if (FAILED(engine_factory->CreateModel(data, len, model))) {
WINML_THROW_HR(E_INVALIDARG);
}
return S_OK;
}
@ -282,7 +314,7 @@ __stdcall LearningModel::Load(
WINML_THROW_HR_IF_FALSE_MSG(E_INVALIDARG, model_path_size > 0, "Failed to create LearningModel. Ivalid argument model_path_size.");
WINML_THROW_HR_IF_NULL_MSG(E_INVALIDARG, pp_model_unk, "Failed to create LearningModel. Ivalid argument pp_model_unk.");
auto path = _winml::Strings::UTF8FromUnicode(p_model_path, model_path_size);
winrt::hstring path(p_model_path, model_path_size);
auto model = make<winmlp::LearningModel>(path, nullptr);
*pp_model_unk = model.as<IUnknown>().detach();
return S_OK;

View file

@ -25,10 +25,6 @@ struct LearningModel : LearningModelT<LearningModel> {
const wss::IRandomAccessStreamReference stream,
const winml::ILearningModelOperatorProvider operator_provider);
LearningModel(
const std::string& path,
const winml::ILearningModelOperatorProvider operator_provider);
LearningModel(
_winml::IEngineFactory* engine_factory,
_winml::IModel* model,