MNIST Sample Fix (#2259)

* Fix Global variable initialization order

* Remove static initialization, and add error messages
This commit is contained in:
Ryan Hill 2019-10-28 11:22:45 -07:00 committed by GitHub
parent 0b88eff43a
commit 3ecdd985cb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -10,8 +10,6 @@
#pragma comment(lib, "gdi32.lib")
#pragma comment(lib, "onnxruntime.lib")
Ort::Env env{ORT_LOGGING_LEVEL_WARNING, "test"};
// This is the structure to interface with the MNIST model
// After instantiation, set the input_image_ data to be the 28x28 pixel image of the number to recognize
// Then call Run() to fill in the results_ data with the probabilities of each
@ -41,6 +39,7 @@ struct MNIST {
int64_t result_{0};
private:
Ort::Env env;
Ort::Session session_{env, L"model.onnx", Ort::SessionOptions{nullptr}};
Ort::Value input_tensor_{nullptr};
@ -55,7 +54,7 @@ const constexpr int drawing_area_scale_{4}; // Number of times larger to make t
const constexpr int drawing_area_width_{MNIST::width_ * drawing_area_scale_};
const constexpr int drawing_area_height_{MNIST::height_ * drawing_area_scale_};
MNIST mnist_;
std::unique_ptr<MNIST> mnist_;
HBITMAP dib_;
HDC hdc_dib_;
bool painting_{};
@ -79,9 +78,9 @@ void ConvertDibToMnist() {
DIBInfo info{dib_};
const DWORD* input = reinterpret_cast<const DWORD*>(info.Bits());
float* output = mnist_.input_image_.data();
float* output = mnist_->input_image_.data();
std::fill(mnist_.input_image_.begin(), mnist_.input_image_.end(), 0.f);
std::fill(mnist_->input_image_.begin(), mnist_->input_image_.end(), 0.f);
for (unsigned y = 0; y < MNIST::height_; y++) {
for (unsigned x = 0; x < MNIST::width_; x++) {
@ -97,6 +96,13 @@ LRESULT CALLBACK WndProc(HWND, UINT, WPARAM, LPARAM);
// The Windows entry point function
int APIENTRY wWinMain(_In_ HINSTANCE hInstance, _In_opt_ HINSTANCE /*hPrevInstance*/, _In_ LPTSTR /*lpCmdLine*/,
_In_ int nCmdShow) {
try {
mnist_ = std::make_unique<MNIST>();
} catch (const Ort::Exception& exception) {
MessageBoxA(nullptr, exception.what(), "Error:", MB_OK);
return 0;
}
{
WNDCLASSEX wc{};
wc.cbSize = sizeof(WNDCLASSEX);
@ -139,6 +145,13 @@ int APIENTRY wWinMain(_In_ HINSTANCE hInstance, _In_opt_ HINSTANCE /*hPrevInstan
TranslateMessage(&msg);
DispatchMessage(&msg);
}
DeleteObject(dib_);
DeleteDC(hdc_dib_);
DeleteObject(brush_winner_);
DeleteObject(brush_bars_);
return (int)msg.wParam;
}
@ -158,14 +171,14 @@ LRESULT CALLBACK WndProc(HWND hWnd, UINT message, WPARAM wParam, LPARAM lParam)
constexpr int graph_width = 64;
SelectObject(hdc, brush_bars_);
auto least = *std::min_element(mnist_.results_.begin(), mnist_.results_.end());
auto greatest = mnist_.results_[mnist_.result_];
auto least = *std::min_element(mnist_->results_.begin(), mnist_->results_.end());
auto greatest = mnist_->results_[mnist_->result_];
auto range = greatest - least;
int graphs_zero = static_cast<int>(graphs_left - least * graph_width / range);
// Hilight the winner
RECT rc{graphs_left, static_cast<LONG>(mnist_.result_) * 16, graphs_left + graph_width + 128, static_cast<LONG>(mnist_.result_ + 1) * 16};
RECT rc{graphs_left, static_cast<LONG>(mnist_->result_) * 16, graphs_left + graph_width + 128, static_cast<LONG>(mnist_->result_ + 1) * 16};
FillRect(hdc, &rc, brush_winner_);
// For every entry, draw the odds and the graph for it
@ -173,7 +186,7 @@ LRESULT CALLBACK WndProc(HWND hWnd, UINT message, WPARAM wParam, LPARAM lParam)
wchar_t value[80];
for (unsigned i = 0; i < 10; i++) {
int y = 16 * i;
float result = mnist_.results_[i];
float result = mnist_->results_[i];
auto length = wsprintf(value, L"%2d: %d.%02d", i, int(result), abs(int(result * 100) % 100));
TextOut(hdc, graphs_left + graph_width + 5, y, value, length);
@ -214,7 +227,7 @@ LRESULT CALLBACK WndProc(HWND hWnd, UINT message, WPARAM wParam, LPARAM lParam)
case WM_LBUTTONUP:
ReleaseCapture();
ConvertDibToMnist();
mnist_.Run();
mnist_->Run();
InvalidateRect(hWnd, nullptr, true);
return 0;