mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-22 22:01:08 +00:00
MNIST Sample Fix (#2259)
* Fix Global variable initialization order * Remove static initialization, and add error messages
This commit is contained in:
parent
0b88eff43a
commit
3ecdd985cb
1 changed files with 23 additions and 10 deletions
|
|
@ -10,8 +10,6 @@
|
|||
#pragma comment(lib, "gdi32.lib")
|
||||
#pragma comment(lib, "onnxruntime.lib")
|
||||
|
||||
Ort::Env env{ORT_LOGGING_LEVEL_WARNING, "test"};
|
||||
|
||||
// This is the structure to interface with the MNIST model
|
||||
// After instantiation, set the input_image_ data to be the 28x28 pixel image of the number to recognize
|
||||
// Then call Run() to fill in the results_ data with the probabilities of each
|
||||
|
|
@ -41,6 +39,7 @@ struct MNIST {
|
|||
int64_t result_{0};
|
||||
|
||||
private:
|
||||
Ort::Env env;
|
||||
Ort::Session session_{env, L"model.onnx", Ort::SessionOptions{nullptr}};
|
||||
|
||||
Ort::Value input_tensor_{nullptr};
|
||||
|
|
@ -55,7 +54,7 @@ const constexpr int drawing_area_scale_{4}; // Number of times larger to make t
|
|||
const constexpr int drawing_area_width_{MNIST::width_ * drawing_area_scale_};
|
||||
const constexpr int drawing_area_height_{MNIST::height_ * drawing_area_scale_};
|
||||
|
||||
MNIST mnist_;
|
||||
std::unique_ptr<MNIST> mnist_;
|
||||
HBITMAP dib_;
|
||||
HDC hdc_dib_;
|
||||
bool painting_{};
|
||||
|
|
@ -79,9 +78,9 @@ void ConvertDibToMnist() {
|
|||
DIBInfo info{dib_};
|
||||
|
||||
const DWORD* input = reinterpret_cast<const DWORD*>(info.Bits());
|
||||
float* output = mnist_.input_image_.data();
|
||||
float* output = mnist_->input_image_.data();
|
||||
|
||||
std::fill(mnist_.input_image_.begin(), mnist_.input_image_.end(), 0.f);
|
||||
std::fill(mnist_->input_image_.begin(), mnist_->input_image_.end(), 0.f);
|
||||
|
||||
for (unsigned y = 0; y < MNIST::height_; y++) {
|
||||
for (unsigned x = 0; x < MNIST::width_; x++) {
|
||||
|
|
@ -97,6 +96,13 @@ LRESULT CALLBACK WndProc(HWND, UINT, WPARAM, LPARAM);
|
|||
// The Windows entry point function
|
||||
int APIENTRY wWinMain(_In_ HINSTANCE hInstance, _In_opt_ HINSTANCE /*hPrevInstance*/, _In_ LPTSTR /*lpCmdLine*/,
|
||||
_In_ int nCmdShow) {
|
||||
try {
|
||||
mnist_ = std::make_unique<MNIST>();
|
||||
} catch (const Ort::Exception& exception) {
|
||||
MessageBoxA(nullptr, exception.what(), "Error:", MB_OK);
|
||||
return 0;
|
||||
}
|
||||
|
||||
{
|
||||
WNDCLASSEX wc{};
|
||||
wc.cbSize = sizeof(WNDCLASSEX);
|
||||
|
|
@ -139,6 +145,13 @@ int APIENTRY wWinMain(_In_ HINSTANCE hInstance, _In_opt_ HINSTANCE /*hPrevInstan
|
|||
TranslateMessage(&msg);
|
||||
DispatchMessage(&msg);
|
||||
}
|
||||
|
||||
DeleteObject(dib_);
|
||||
DeleteDC(hdc_dib_);
|
||||
|
||||
DeleteObject(brush_winner_);
|
||||
DeleteObject(brush_bars_);
|
||||
|
||||
return (int)msg.wParam;
|
||||
}
|
||||
|
||||
|
|
@ -158,14 +171,14 @@ LRESULT CALLBACK WndProc(HWND hWnd, UINT message, WPARAM wParam, LPARAM lParam)
|
|||
constexpr int graph_width = 64;
|
||||
SelectObject(hdc, brush_bars_);
|
||||
|
||||
auto least = *std::min_element(mnist_.results_.begin(), mnist_.results_.end());
|
||||
auto greatest = mnist_.results_[mnist_.result_];
|
||||
auto least = *std::min_element(mnist_->results_.begin(), mnist_->results_.end());
|
||||
auto greatest = mnist_->results_[mnist_->result_];
|
||||
auto range = greatest - least;
|
||||
|
||||
int graphs_zero = static_cast<int>(graphs_left - least * graph_width / range);
|
||||
|
||||
// Hilight the winner
|
||||
RECT rc{graphs_left, static_cast<LONG>(mnist_.result_) * 16, graphs_left + graph_width + 128, static_cast<LONG>(mnist_.result_ + 1) * 16};
|
||||
RECT rc{graphs_left, static_cast<LONG>(mnist_->result_) * 16, graphs_left + graph_width + 128, static_cast<LONG>(mnist_->result_ + 1) * 16};
|
||||
FillRect(hdc, &rc, brush_winner_);
|
||||
|
||||
// For every entry, draw the odds and the graph for it
|
||||
|
|
@ -173,7 +186,7 @@ LRESULT CALLBACK WndProc(HWND hWnd, UINT message, WPARAM wParam, LPARAM lParam)
|
|||
wchar_t value[80];
|
||||
for (unsigned i = 0; i < 10; i++) {
|
||||
int y = 16 * i;
|
||||
float result = mnist_.results_[i];
|
||||
float result = mnist_->results_[i];
|
||||
|
||||
auto length = wsprintf(value, L"%2d: %d.%02d", i, int(result), abs(int(result * 100) % 100));
|
||||
TextOut(hdc, graphs_left + graph_width + 5, y, value, length);
|
||||
|
|
@ -214,7 +227,7 @@ LRESULT CALLBACK WndProc(HWND hWnd, UINT message, WPARAM wParam, LPARAM lParam)
|
|||
case WM_LBUTTONUP:
|
||||
ReleaseCapture();
|
||||
ConvertDibToMnist();
|
||||
mnist_.Run();
|
||||
mnist_->Run();
|
||||
InvalidateRect(hWnd, nullptr, true);
|
||||
return 0;
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue