Print hex value for float compare when test failed (#228)

This helps identify fp accuracy issues
This commit is contained in:
KeDengMS 2018-12-19 21:31:15 -08:00 committed by GitHub
parent 0dca080238
commit abce6041c1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -64,14 +64,20 @@ std::pair<COMPARE_RESULT, std::string> CompareFloatResult(const Tensor& outvalue
const double real_value = post_processing ? std::max<double>(0.0, std::min<double>(255.0, real_output[di]))
: real_output[di];
const double diff = fabs(expected_output[di] - real_value);
const double rtol = per_sample_tolerance + relative_per_sample_tolerance * fabs(expected_output[di]);
if (diff > rtol || (std::isnan(diff) && !std::isnan(expected_output[di]))) {
const double tol = per_sample_tolerance + relative_per_sample_tolerance * fabs(expected_output[di]);
if (diff > tol || (std::isnan(diff) && !std::isnan(expected_output[di]))) {
res.first = COMPARE_RESULT::RESULT_DIFFERS;
// update error message if this is a larger diff
if (diff > max_diff || (std::isnan(diff) && !std::isnan(max_diff))) {
int64_t expected_int = 0;
int64_t real_int = 0;
memcpy(&expected_int, &expected_output[di], sizeof(FLOAT_TYPE));
memcpy(&real_int, &real_output[di], sizeof(FLOAT_TYPE));
std::ostringstream oss;
oss << "expected " << expected_output[di] << ", got " << real_value
<< ", diff: " << diff << ", tol=" << rtol << ".";
oss << std::hex << "expected " << expected_output[di] << " (" << expected_int << "), got "
<< real_value << " (" << real_int << ")"
<< ", diff: " << diff << ", tol=" << tol << ".";
res.second = oss.str();
max_diff = diff;
}