pytorch/test/cpp/jit/test_utils.h
Han Qi 0723639b60 Revert D34455360: Multisect successfully blamed D34455360 for test failures
Summary:
This diff is reverting D34455360 (61d6c43864)
D34455360 (61d6c43864) is causing the following tests to fail; this revert diff is either a revert of the blamed diff or a revert of the stack of diffs that must be reverted in order to revert the blamed diff

Tests affected:
- https://www.internalfb.com/intern/test/562950004334605/

Multisect link:
https://www.internalfb.com/intern/testinfra/multisect/756170

Test Plan: NA

Reviewed By: zhxchen17

Differential Revision: D34596156

fbshipit-source-id: a465bca0094db3caf6130c80f1ed49eea981359b
(cherry picked from commit ef5e5578c64ce9827570757fb016aafa9c782c6a)
2022-03-08 23:18:54 +00:00

108 lines
3.4 KiB
C++

#pragma once
#include <algorithm>
#include <cctype>
#include <exception>
#include <string>

#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/runtime/autodiff.h>
#include <torch/csrc/jit/runtime/interpreter.h>
#include <torch/csrc/jit/testing/file_check.h>
namespace {
// Normalizes `s` in place so that strings can be compared while ignoring
// incidental formatting differences:
//   1. leading and trailing whitespace (anything std::isspace accepts) is
//      stripped;
//   2. every remaining '\n' is removed;
//   3. each run of consecutive ' ' characters is collapsed to a single ' '.
// Other interior whitespace (e.g. '\t', '\r') is left untouched.
//
// Steps 2 and 3 are done in one linear compacting pass with unsigned indices;
// erasing characters one at a time would be quadratic in the worst case.
//
// NOTE(review): an anonymous namespace in a header gives every including
// translation unit its own copy; it is kept because ASSERT_THROWS_WITH_MESSAGE
// calls `trim` unqualified.
inline void trim(std::string& s) {
  // Take unsigned char so std::isspace never sees a negative value.
  const auto not_space = [](unsigned char ch) { return !std::isspace(ch); };

  // Step 1: strip leading, then trailing, whitespace.
  s.erase(s.begin(), std::find_if(s.begin(), s.end(), not_space));
  s.erase(std::find_if(s.rbegin(), s.rend(), not_space).base(), s.end());

  // Steps 2 and 3: `out` is the write position; kept characters are compacted
  // towards the front (out <= in always holds, so s[out - 1] is the last kept
  // character). Newlines are never kept, so spaces separated only by newlines
  // are merged too -- the same result as removing newlines first and then
  // collapsing spaces.
  std::string::size_type out = 0;
  for (std::string::size_type in = 0; in < s.size(); ++in) {
    const char ch = s[in];
    if (ch == '\n') {
      continue;
    }
    if (ch == ' ' && out > 0 && s[out - 1] == ' ') {
      continue;
    }
    s[out++] = ch;
  }
  s.resize(out);
}
} // namespace
// Fatal gtest assertion: evaluating `statement` must throw an exception derived
// from std::exception whose what() contains `substring`. Both the message and
// the expected substring are normalized with trim() first, so differences in
// leading/trailing whitespace, newlines and runs of spaces are ignored.
//
// Fails the current test via FAIL() if `statement` does not throw. Exceptions
// not derived from std::exception are not caught and propagate.
//
// Wrapped in do { ... } while (0) so the macro acts as a single statement
// (safe in an unbraced if/else). `statement` is parenthesized before the void
// cast so expressions such as `x = f()` are evaluated as a whole. FAIL() is
// issued outside the try block so the catch clause can never intercept it.
#define ASSERT_THROWS_WITH_MESSAGE(statement, substring)                 \
  do {                                                                   \
    bool threw_expected_exception_ = false;                              \
    try {                                                                \
      (void)(statement);                                                 \
    } catch (const std::exception& e) {                                  \
      threw_expected_exception_ = true;                                  \
      std::string substring_s(substring);                                \
      trim(substring_s);                                                 \
      auto exception_string = std::string(e.what());                     \
      trim(exception_string);                                            \
      ASSERT_NE(exception_string.find(substring_s), std::string::npos)   \
          << "Expected exception message containing: \"" << substring_s  \
          << "\"\nActual message: \"" << exception_string << "\"";       \
    }                                                                    \
    if (!threw_expected_exception_) {                                    \
      FAIL() << "Expected `" #statement "` to throw";                    \
    }                                                                    \
  } while (0)
namespace torch {
namespace jit {
// Shorthand for a list of tensors, used by the test helpers declared below.
using tensor_list = std::vector<at::Tensor>;
// NOTE(review): a using-directive in a header pulls every torch::autograd name
// into torch::jit for all includers; kept because test code may rely on it.
using namespace torch::autograd;
// Builds an interpreter Stack from `list`, which is taken by rvalue reference
// (callers pass a temporary or std::move their vector). Presumably each tensor
// becomes one stack entry, in order -- confirm in the definition.
// Historical note from the original comment: this works around
// variable_tensor_list not duplicating all of std::vector's constructors (most
// of which were only needed by tests). NOTE(review): variable_tensor_list is
// not referenced anywhere in this header; this note may be stale.
Stack createStack(std::vector<at::Tensor>&& list);
// Checks that the tensors in `a` and `b` are close to each other.
// NOTE(review): presumably compared pairwise by index with equal lengths
// required; tolerances live in the definition -- confirm there.
void assertAllClose(const tensor_list& a, const tensor_list& b);
// Executes `interp` with `inputs` and returns the resulting tensors.
// NOTE(review): presumably the outputs are what remains on the stack after
// execution -- confirm in the definition.
std::vector<at::Tensor> run(
InterpreterState& interp,
const std::vector<at::Tensor>& inputs);
// Runs the gradient computation described by `grad_spec` using forward inputs
// `tensors_in` and incoming gradients `tensor_grads_in`, returning a pair of
// tensor lists -- presumably {forward outputs, computed gradients}; TODO
// confirm. The arguments are taken by non-const reference; whether they are
// modified is not visible from this declaration.
std::pair<tensor_list, tensor_list> runGradient(
Gradient& grad_spec,
tensor_list& tensors_in,
tensor_list& tensor_grads_in);
// Factories returning IR graphs (via shared_ptr) for individual tests: an LSTM
// graph and several mobile-export analysis variants. The graph contents are
// not visible from this header.
std::shared_ptr<Graph> build_lstm();
std::shared_ptr<Graph> build_mobile_export_analysis_graph();
std::shared_ptr<Graph> build_mobile_export_with_out();
std::shared_ptr<Graph> build_mobile_export_analysis_graph_with_vararg();
std::shared_ptr<Graph> build_mobile_export_analysis_graph_nested();
std::shared_ptr<Graph> build_mobile_export_analysis_graph_non_const();
// Tensor -> tensor test helpers; their semantics are not visible from this
// header. NOTE(review): the names suggest a "use" and a "def" of a value --
// confirm in the definition.
at::Tensor t_use(at::Tensor x);
at::Tensor t_def(at::Tensor x);
// Given the difference between an output and the expected tensor, checks
// whether that difference is within a relative tolerance range. This is a
// standard way of matching tensor values up to a certain precision.
// NOTE(review): `inputs` is a const std::vector taken by value, so each call
// copies the vector. Changing it to a reference would require changing the
// out-of-line definition in lockstep; otherwise an ambiguous overload is
// created.
bool checkRtol(const at::Tensor& diff, const std::vector<at::Tensor> inputs);
// Approximate and exact equality checks for two tensors; exactlyEqual also has
// an overload comparing two tensor lists.
bool almostEqual(const at::Tensor& a, const at::Tensor& b);
bool exactlyEqual(const at::Tensor& a, const at::Tensor& b);
bool exactlyEqual(
const std::vector<at::Tensor>& a,
const std::vector<at::Tensor>& b);
// Executes `graph` with `inputs` and returns the resulting tensors.
std::vector<at::Tensor> runGraph(
std::shared_ptr<Graph> graph,
const std::vector<at::Tensor>& inputs);
// LSTM cell helper taking the input, previous hidden/cell state and the
// input-hidden / hidden-hidden weights. NOTE(review): presumably an eager
// reference implementation returning {hy, cy} (e.g. to compare against
// build_lstm()) -- confirm in the definition.
std::pair<at::Tensor, at::Tensor> lstm(
at::Tensor input,
at::Tensor hx,
at::Tensor cx,
at::Tensor w_ih,
at::Tensor w_hh);
} // namespace jit
} // namespace torch