//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#include <algorithm>
#include <climits>
#include <cmath>
#include <cstdlib>
#include <cstring>
#include <iomanip>
#include <iostream>
#include <limits>
#include <sstream>

#include "util/all_close_f.hpp"

using namespace std;
using namespace ngraph;

// Overlays a float with a 32-bit unsigned integer so its raw IEEE-754 bit pattern can
// be inspected. Aggregate-initialising it (FloatUnion{x}) sets the `f` member.
// NOTE: reading `i` after writing `f` is type punning through an inactive union member,
// which standard C++ leaves undefined (GCC/Clang support it as an extension);
// std::memcpy is the portable alternative.
union FloatUnion {
    float f;
    uint32_t i;
    // The overlay is only meaningful if both members have the same width.
    static_assert(sizeof(float) == sizeof(uint32_t), "FloatUnion requires a 32-bit float");
};

// Overlays a double with a 64-bit unsigned integer so its raw IEEE-754 bit pattern can
// be inspected. Aggregate-initialising it (DoubleUnion{x}) sets the `d` member.
// NOTE: reading `i` after writing `d` is type punning through an inactive union member,
// which standard C++ leaves undefined (GCC/Clang support it as an extension);
// std::memcpy is the portable alternative.
union DoubleUnion {
    double d;
    uint64_t i;
    // The overlay is only meaningful if both members have the same width.
    static_assert(sizeof(double) == sizeof(uint64_t), "DoubleUnion requires a 64-bit double");
};

uint32_t test::float_distance(float a, float b)
36 37 38
{
    if (!isfinite(a) || !isfinite(b))
    {
39
        return UINT_MAX;
40 41 42 43 44 45 46 47 48 49 50 51 52 53 54
    }

    FloatUnion a_fu{a};
    FloatUnion b_fu{b};
    uint32_t a_uint = a_fu.i;
    uint32_t b_uint = b_fu.i;

    // A trick to handle both positive and negative numbers, see https://goo.gl/YbdnFQ
    // - If negative: convert to two's complement
    // - If positive: mask with sign bit
    uint32_t sign_mask = static_cast<uint32_t>(1U) << 31;
    a_uint = (sign_mask & a_uint) ? (~a_uint + 1) : (sign_mask | a_uint);
    b_uint = (sign_mask & b_uint) ? (~b_uint + 1) : (sign_mask | b_uint);

    uint32_t distance = (a_uint >= b_uint) ? (a_uint - b_uint) : (b_uint - a_uint);
55 56 57
    return distance;
}

// Returns the number of representable double steps (ULPs) between a and b.
// Non-finite inputs (inf/NaN) yield ULLONG_MAX, the largest possible distance.
// +0.0 and -0.0 are at distance 0 from each other.
uint64_t test::float_distance(double a, double b)
{
    if (!isfinite(a) || !isfinite(b))
    {
        return ULLONG_MAX;
    }

    // Copy out the raw IEEE-754 bit patterns. std::memcpy is the well-defined way to
    // type-pun in C++; reading an inactive union member is undefined behavior.
    static_assert(sizeof(double) == sizeof(uint64_t), "double must be 64 bits wide");
    uint64_t a_uint;
    uint64_t b_uint;
    std::memcpy(&a_uint, &a, sizeof(a_uint));
    std::memcpy(&b_uint, &b, sizeof(b_uint));

    // Map sign-magnitude patterns onto one monotonically increasing unsigned scale so
    // that the difference counts representable values (see https://goo.gl/YbdnFQ):
    // - If negative: convert to two's complement (negate the whole pattern)
    // - If positive: set the sign bit, placing it above every negative value
    // Both +0.0 and -0.0 map to 0x8000000000000000.
    const uint64_t sign_mask = static_cast<uint64_t>(1U) << 63;
    a_uint = (sign_mask & a_uint) ? (~a_uint + 1) : (sign_mask | a_uint);
    b_uint = (sign_mask & b_uint) ? (~b_uint + 1) : (sign_mask | b_uint);

    return (a_uint >= b_uint) ? (a_uint - b_uint) : (b_uint - a_uint);
}

// Returns true when a and b are finite and their ULP distance (float_distance) is at
// most 2^tolerance_bit_shift, i.e. they agree on at least
// (mantissa_bits - tolerance_bits) leading mantissa bits.
bool test::close_f(float a, float b, int mantissa_bits, int tolerance_bits)
{
    // isfinite(a) => !isinf(a) && !isnan(a); non-finite values are never "close".
    if (!isfinite(a) || !isfinite(b))
    {
        return false;
    }

    uint32_t distance = float_distance(a, b);

    // e.g. for float with 24 bit mantissa, 2 bit accuracy, and hard-coded 8 bit exponent_bits
    // tolerance_bit_shift = 32 -           (1 +  8 + (24 -     1         ) - 2             )
    //                       float_length    sign exp  mantissa implicit 1    tolerance_bits
    int tolerance_bit_shift = 32 - (1 + 8 + (mantissa_bits - 1) - tolerance_bits);
    // Out-of-range arguments would give a negative shift or one of 32+ bits, both
    // undefined behavior for a 32-bit operand; clamp to the valid range [0, 31].
    if (tolerance_bit_shift < 0)
    {
        tolerance_bit_shift = 0;
    }
    else if (tolerance_bit_shift > 31)
    {
        tolerance_bit_shift = 31;
    }
    uint32_t tolerance = static_cast<uint32_t>(1U) << tolerance_bit_shift;

    return distance <= tolerance;
}

// Returns true when a and b are finite and their ULP distance (float_distance) is at
// most 2^tolerance_bit_shift, i.e. they agree on at least (53 - tolerance_bits)
// leading mantissa bits.
bool test::close_f(double a, double b, int tolerance_bits)
{
    // 52 stored mantissa bits plus the implicit leading 1.
    constexpr int mantissa_bits = 53;

    // isfinite(a) => !isinf(a) && !isnan(a); non-finite values are never "close".
    if (!isfinite(a) || !isfinite(b))
    {
        return false;
    }

    uint64_t distance = float_distance(a, b);

    // e.g. for double with 53 bit mantissa, 2 bit accuracy, and hard-coded 11 bit exponent_bits
    // tolerance_bit_shift = 64 -           (1 +  11 + (53 -     1         ) - 2             )
    //                       double_length   sign exp   mantissa implicit 1    tolerance_bits
    int tolerance_bit_shift = 64 - (1 + 11 + (mantissa_bits - 1) - tolerance_bits);
    // A negative tolerance_bits, or one of 64 or more, would give an undefined shift of a
    // 64-bit operand; clamp to the valid range [0, 63].
    if (tolerance_bit_shift < 0)
    {
        tolerance_bit_shift = 0;
    }
    else if (tolerance_bit_shift > 63)
    {
        tolerance_bit_shift = 63;
    }
    uint64_t tolerance = static_cast<uint64_t>(1U) << tolerance_bit_shift;

    return distance <= tolerance;
}

vector<uint32_t> test::float_distances(const vector<float>& a, const vector<float>& b)
{
    // Element-wise ULP distances; both inputs must be the same length.
    const size_t count = a.size();
    if (b.size() != count)
    {
        throw ngraph_error("a.size() != b.size() for float_distances comparison.");
    }

    vector<uint32_t> result;
    result.reserve(count);
    for (size_t idx = 0; idx != count; ++idx)
    {
        result.push_back(float_distance(a[idx], b[idx]));
    }
    return result;
}

vector<uint64_t> test::float_distances(const vector<double>& a, const vector<double>& b)
{
    // Element-wise ULP distances; both inputs must be the same length.
    const size_t count = a.size();
    if (b.size() != count)
    {
        throw ngraph_error("a.size() != b.size() for float_distances comparison.");
    }

    vector<uint64_t> result;
    result.reserve(count);
    for (size_t idx = 0; idx != count; ++idx)
    {
        result.push_back(float_distance(a[idx], b[idx]));
    }
    return result;
}

// Converts a float ULP distance into the number of leading mantissa bits (of 24) that
// all_close_f would report as matching with 0 tolerance bits. Returns 0 when the
// distance is too large for even the leading mantissa bit to match.
uint32_t test::matching_mantissa_bits(uint32_t distance)
{
    uint32_t tolerance_bit_shift = 0;
    uint32_t num_bits_on = 0;

    // Do some bit probing to find the most significant bit that's on,
    // as well as how many bits are on.
    for (uint32_t check_bit = 0; check_bit < 32; ++check_bit)
    {
        // Shift an unsigned 1: `1 << 31` on a signed int overflows (UB in C++11).
        if (distance & (static_cast<uint32_t>(1U) << check_bit))
        {
            tolerance_bit_shift = check_bit;
            ++num_bits_on;
        }
    }

    // all_close_f is <= test for tolerance (where tolerance is uint32_t with single bit on)
    // So if more than one bit is on we need the next higher tolerance
    if (num_bits_on > 1)
    {
        ++tolerance_bit_shift;
    }

    // all_close_f calculation of tolerance_bit_shift (24 bit mantissa, 8 bit exponent):
    //   tolerance_bit_shift = 32 - (1 + 8 + (matching_mantissa_bits - 1) - tolerance_bits)
    // With tolerance_bits = 0, solving for matching_mantissa_bits yields:
    //   matching_mantissa_bits = 32 - (1 + 8 + (tolerance_bit_shift - 1))
    //                          = 24 - tolerance_bit_shift
    return tolerance_bit_shift < 24 ? 24 - tolerance_bit_shift : 0;
}

// Converts a double ULP distance into the number of leading mantissa bits (of 53) that
// all_close_f would report as matching with 0 tolerance bits. Returns 0 when the
// distance is too large for even the leading mantissa bit to match.
uint32_t test::matching_mantissa_bits(uint64_t distance)
{
    // Walk the bits from least to most significant, remembering the highest set bit
    // and how many bits are set in total.
    uint32_t highest_bit = 0;
    uint32_t set_bits = 0;
    uint64_t remaining = distance;
    for (uint32_t bit = 0; remaining != 0; ++bit, remaining >>= 1)
    {
        if ((remaining & 1ull) != 0)
        {
            highest_bit = bit;
            ++set_bits;
        }
    }

    // all_close_f accepts distance <= tolerance, where tolerance is a single power of
    // two; a distance with several bits set only fits under the next power up.
    const uint32_t tolerance_bit_shift = (set_bits > 1) ? highest_bit + 1 : highest_bit;

    // all_close_f (double: 53 bit mantissa, 11 bit exponent) computes
    //   tolerance_bit_shift = 64 - (1 + 11 + (matching_mantissa_bits - 1) - tolerance_bits)
    // With tolerance_bits = 0, solving for matching_mantissa_bits yields:
    //   matching_mantissa_bits = 64 - (1 + 11 + (tolerance_bit_shift - 1))
    //                          = 53 - tolerance_bit_shift
    if (tolerance_bit_shift >= 53)
    {
        return 0;
    }
    return 53 - tolerance_bit_shift;
}

// Element-wise ULP comparison of two float vectors.
//
// Each pair (a[i], b[i]) passes when float_distance(a[i], b[i]) <= tolerance, where
// tolerance = 2^tolerance_bit_shift is derived from mantissa_bits and tolerance_bits
// (formula below). Non-finite elements have distance UINT_MAX and therefore fail.
//
// The returned AssertionResult carries diagnostics: up to five mismatching elements,
// the mismatch count (on failure), the passing criteria, and the tightest, loosest and
// median match expressed in matching mantissa bits. If the NGRAPH_GTEST_INFO
// environment variable is set, a one-line summary is also printed to stdout on success.
// Two empty vectors trivially succeed.
::testing::AssertionResult test::all_close_f(const vector<float>& a,
                                             const vector<float>& b,
                                             int mantissa_bits,
                                             int tolerance_bits)
{
    if (a.size() != b.size())
    {
        return ::testing::AssertionFailure() << "a.size() != b.size() for all_close_f comparison.";
    }
    if (a.empty())
    {
        // Nothing to compare. Returning here also avoids the out-of-bounds reads the
        // statistics below would perform (distances[middle], a[min_distance_index], ...).
        return ::testing::AssertionSuccess();
    }

    bool rc = true;
    stringstream msg;
    vector<uint32_t> distances = float_distances(a, b);

    // e.g. for float with 24 bit mantissa, 2 bit accuracy, and hard-coded 8 bit exponent_bits
    // tolerance_bit_shift = 32 -           (1 +  8 + (24 -     1         ) - 2             )
    //                       float_length    sign exp  mantissa implicit 1    tolerance_bits
    int tolerance_bit_shift = 32 - (1 + 8 + (mantissa_bits - 1) - tolerance_bits);
    // Shifting a 32-bit value by a negative count or by 32+ bits is undefined behavior,
    // so clamp to [0, 31]. Capping at 31 (rather than accepting everything) keeps
    // non-finite elements, whose distance is UINT_MAX, failing.
    if (tolerance_bit_shift < 0)
    {
        tolerance_bit_shift = 0;
    }
    else if (tolerance_bit_shift > 31)
    {
        tolerance_bit_shift = 31;
    }
    uint32_t tolerance = static_cast<uint32_t>(1U) << tolerance_bit_shift;
    uint32_t max_distance = 0;
    uint32_t min_distance = UINT_MAX;
    size_t max_distance_index = 0;
    size_t min_distance_index = 0;
    size_t diff_count = 0;
    for (size_t i = 0; i < a.size(); ++i)
    {
        if (distances[i] > max_distance)
        {
            max_distance = distances[i];
            max_distance_index = i;
        }
        if (distances[i] < min_distance)
        {
            min_distance = distances[i];
            min_distance_index = i;
        }
        bool is_close_f = distances[i] <= tolerance;
        if (!is_close_f)
        {
            // Report only the first few mismatches to keep the message readable.
            if (diff_count < 5)
            {
                msg << std::setprecision(std::numeric_limits<long double>::digits10 + 1) << a[i]
                    << " is not close to " << b[i] << " at index " << i << "\n";
            }

            rc = false;
            diff_count++;
        }
    }
    if (!rc)
    {
        msg << "diff count: " << diff_count << " out of " << a.size() << "\n";
    }

    // Find median value via partial sorting. After this, `distances` is no longer
    // index-aligned with a/b (the min/max indices were captured above). For an even
    // count, average the two middle values; the sum is formed in 64 bits to avoid overflow.
    size_t middle = distances.size() / 2;
    std::nth_element(distances.begin(), distances.begin() + middle, distances.end());
    uint32_t median_distance = distances[middle];
    if (distances.size() % 2 == 0)
    {
        // Find middle-1 value
        uint64_t median_sum = static_cast<uint64_t>(median_distance) +
                              *max_element(distances.begin(), distances.begin() + middle);
        median_distance = median_sum / 2;
    }

    if (rc && (std::getenv("NGRAPH_GTEST_INFO") != nullptr))
    {
        // Short unobtrusive message when passing
        std::cout << "[   INFO   ] Verifying match of >= " << (mantissa_bits - tolerance_bits)
                  << " mantissa bits (" << mantissa_bits << " bits precision - " << tolerance_bits
                  << " tolerance). Loosest match found is " << matching_mantissa_bits(max_distance)
                  << " mantissa bits.\n";
    }

    msg << "passing criteria - mismatch allowed  @ mantissa bit: "
        << (mantissa_bits - tolerance_bits) << " or later (" << mantissa_bits
        << " mantissa bits w/ " << tolerance_bits << " tolerance bits)\n";
    msg << std::setprecision(std::numeric_limits<long double>::digits10 + 1)
        << "tightest match   - mismatch occurred @ mantissa bit: "
        << matching_mantissa_bits(min_distance) << " or next bit (" << a[min_distance_index]
        << " vs " << b[min_distance_index] << " at [" << min_distance_index << "])\n";
    msg << std::setprecision(std::numeric_limits<long double>::digits10 + 1)
        << "loosest match    - mismatch occurred @ mantissa bit: "
        << matching_mantissa_bits(max_distance) << " or next bit (" << a[max_distance_index]
        << " vs " << b[max_distance_index] << " at [" << max_distance_index << "])\n";
    msg << "median match     - mismatch occurred @ mantissa bit: "
        << matching_mantissa_bits(median_distance) << " or next bit\n";

    ::testing::AssertionResult res =
        rc ? ::testing::AssertionSuccess() : ::testing::AssertionFailure();
    res << msg.str();
    return res;
}

318 319
// Element-wise ULP comparison of two double vectors (53 bit mantissa).
//
// Each pair (a[i], b[i]) passes when float_distance(a[i], b[i]) <= tolerance, where
// tolerance = 2^tolerance_bit_shift is derived from tolerance_bits (formula below).
// Non-finite elements have distance ULLONG_MAX and therefore fail.
//
// The returned AssertionResult carries diagnostics: up to five mismatching elements,
// the mismatch count (on failure), the passing criteria, and the tightest, loosest and
// median match expressed in matching mantissa bits. If the NGRAPH_GTEST_INFO
// environment variable is set, a one-line summary is also printed to stdout on success.
// Two empty vectors trivially succeed.
::testing::AssertionResult
    test::all_close_f(const vector<double>& a, const vector<double>& b, int tolerance_bits)
{
    // 52 stored mantissa bits plus the implicit leading 1.
    constexpr int mantissa_bits = 53;

    if (a.size() != b.size())
    {
        return ::testing::AssertionFailure() << "a.size() != b.size() for all_close_f comparison.";
    }
    if (a.empty())
    {
        // Nothing to compare. Returning here also avoids the out-of-bounds reads the
        // statistics below would perform (distances[middle], a[min_distance_index], ...).
        return ::testing::AssertionSuccess();
    }

    bool rc = true;
    stringstream msg;
    vector<uint64_t> distances = float_distances(a, b);

    // e.g. for double with 53 bit mantissa, 2 bit accuracy, and hard-coded 11 bit exponent_bits
    // tolerance_bit_shift = 64 -           (1 +  11 + (53 -     1         ) - 2             )
    //                       double_length   sign exp   mantissa implicit 1    tolerance_bits
    int tolerance_bit_shift = 64 - (1 + 11 + (mantissa_bits - 1) - tolerance_bits);
    // Shifting a 64-bit value by a negative count or by 64+ bits is undefined behavior,
    // so clamp to [0, 63]. Capping at 63 (rather than accepting everything) keeps
    // non-finite elements, whose distance is ULLONG_MAX, failing.
    if (tolerance_bit_shift < 0)
    {
        tolerance_bit_shift = 0;
    }
    else if (tolerance_bit_shift > 63)
    {
        tolerance_bit_shift = 63;
    }
    uint64_t tolerance = static_cast<uint64_t>(1U) << tolerance_bit_shift;
    uint64_t max_distance = 0;
    uint64_t min_distance = ULLONG_MAX;
    size_t max_distance_index = 0;
    size_t min_distance_index = 0;
    size_t diff_count = 0;
    for (size_t i = 0; i < a.size(); ++i)
    {
        if (distances[i] > max_distance)
        {
            max_distance = distances[i];
            max_distance_index = i;
        }
        if (distances[i] < min_distance)
        {
            min_distance = distances[i];
            min_distance_index = i;
        }
        bool is_close_f = distances[i] <= tolerance;
        if (!is_close_f)
        {
            // Report only the first few mismatches, printed at full precision (as the
            // float overload does) so nearly-equal values are distinguishable.
            if (diff_count < 5)
            {
                msg << std::setprecision(std::numeric_limits<long double>::digits10 + 1) << a[i]
                    << " is not close to " << b[i] << " at index " << i << "\n";
            }

            rc = false;
            diff_count++;
        }
    }
    if (!rc)
    {
        msg << "diff count: " << diff_count << " out of " << a.size() << "\n";
    }

    // Find median value via partial sorting. After this, `distances` is no longer
    // index-aligned with a/b (the min/max indices were captured above).
    size_t middle = distances.size() / 2;
    std::nth_element(distances.begin(), distances.begin() + middle, distances.end());
    uint64_t median_distance = distances[middle];
    if (distances.size() % 2 == 0)
    {
        // Average the two middle values without overflowing 64 bits: halve each first,
        // then add back the carry from the two discarded low bits.
        uint64_t median_distance2 = *max_element(distances.begin(), distances.begin() + middle);
        uint64_t remainder1 = median_distance % 2;
        uint64_t remainder2 = median_distance2 % 2;
        median_distance =
            (median_distance / 2) + (median_distance2 / 2) + ((remainder1 + remainder2) / 2);
    }

    if (rc && (std::getenv("NGRAPH_GTEST_INFO") != nullptr))
    {
        // Short unobtrusive message when passing
        std::cout << "[   INFO   ] Verifying match of >= " << (mantissa_bits - tolerance_bits)
                  << " mantissa bits (" << mantissa_bits << " bits precision - " << tolerance_bits
                  << " tolerance). Loosest match found is " << matching_mantissa_bits(max_distance)
                  << " mantissa bits.\n";
    }

    msg << "passing criteria - mismatch allowed  @ mantissa bit: "
        << (mantissa_bits - tolerance_bits) << " or later (" << mantissa_bits
        << " mantissa bits w/ " << tolerance_bits << " tolerance bits)\n";
    msg << std::setprecision(std::numeric_limits<long double>::digits10 + 1)
        << "tightest match   - mismatch occurred @ mantissa bit: "
        << matching_mantissa_bits(min_distance) << " or next bit (" << a[min_distance_index]
        << " vs " << b[min_distance_index] << " at [" << min_distance_index << "])\n";
    msg << std::setprecision(std::numeric_limits<long double>::digits10 + 1)
        << "loosest match    - mismatch occurred @ mantissa bit: "
        << matching_mantissa_bits(max_distance) << " or next bit (" << a[max_distance_index]
        << " vs " << b[max_distance_index] << " at [" << max_distance_index << "])\n";
    msg << "median match     - mismatch occurred @ mantissa bit: "
        << matching_mantissa_bits(median_distance) << " or next bit\n";

    ::testing::AssertionResult res =
        rc ? ::testing::AssertionSuccess() : ::testing::AssertionFailure();
    res << msg.str();
    return res;
}

::testing::AssertionResult test::all_close_f(const std::shared_ptr<runtime::Tensor>& a,
                                             const std::shared_ptr<runtime::Tensor>& b,
                                             int mantissa_bits,
                                             int tolerance_bits)
{
    // Tensors are only comparable element-wise when layout and shape both agree;
    // each mismatch gets its own diagnostic.
    if (*a->get_tensor_layout() != *b->get_tensor_layout())
    {
        return ::testing::AssertionFailure() << "Cannot compare tensors with different layouts";
    }
    if (a->get_shape() != b->get_shape())
    {
        return ::testing::AssertionFailure() << "Cannot compare tensors with different shapes";
    }

    // Read both tensors back as float data and defer to the vector overload.
    auto a_values = read_float_vector(a);
    auto b_values = read_float_vector(b);
    return test::all_close_f(a_values, b_values, mantissa_bits, tolerance_bits);
}

::testing::AssertionResult
    test::all_close_f(const std::vector<std::shared_ptr<runtime::Tensor>>& as,
                      const std::vector<std::shared_ptr<runtime::Tensor>>& bs,
                      int mantissa_bits,
                      int tolerance_bits)
{
    if (as.size() != bs.size())
    {
        return ::testing::AssertionFailure() << "Cannot compare tensors with different sizes";
    }

    // Compare the tensors pairwise, in order, and stop at the first pair that fails,
    // returning that pair's diagnostic unchanged.
    auto b_it = bs.begin();
    for (const auto& a_tensor : as)
    {
        auto pair_result = test::all_close_f(a_tensor, *b_it, mantissa_bits, tolerance_bits);
        ++b_it;
        if (!pair_result)
        {
            return pair_result;
        }
    }
    return ::testing::AssertionSuccess();
}