//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#include <algorithm>
#include <climits>
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <iomanip>
#include <iostream>
#include <limits>
#include <sstream>

#include "util/all_close_f.hpp"

using namespace std;
using namespace ngraph;

// Bit-level views of the IEEE-754 encodings. The floating-point member is declared first so
// that brace initialisation (e.g. FloatUnion{x}) stores into it.
union FloatUnion {
    float f;    // floating-point view
    uint32_t i; // raw 32-bit encoding
};

union DoubleUnion {
    double d;   // floating-point view
    uint64_t i; // raw 64-bit encoding
};

// Reserved distance values. The all-ones value flags "both operands below min_signal";
// the value just below it is what a genuine all-ones distance is remapped to, so the two
// meanings can never collide.
constexpr uint32_t FLOAT_BELOW_MIN_SIGNAL = UINT_MAX;
constexpr uint32_t FLOAT_MAX_DIFF = FLOAT_BELOW_MIN_SIGNAL - 1;
constexpr uint64_t DOUBLE_BELOW_MIN_SIGNAL = ULLONG_MAX;
constexpr uint64_t DOUBLE_MAX_DIFF = DOUBLE_BELOW_MIN_SIGNAL - 1;

uint32_t test::float_distance(float a, float b, float min_signal)
41
{
42
    if (std::isnan(a) && std::isnan(b))
43
    {
44 45 46 47 48 49 50 51 52 53 54 55
        return 0;
    }
    else if (std::isinf(a) && std::isinf(b))
    {
        if (a > 0 && b > 0)
        {
            return 0;
        }
        else if (a < 0 && b < 0)
        {
            return 0;
        }
56
        return FLOAT_MAX_DIFF;
57 58 59 60
    }

    FloatUnion a_fu{a};
    FloatUnion b_fu{b};
61
    FloatUnion min_signal_fu{min_signal};
62 63 64 65 66 67 68
    uint32_t a_uint = a_fu.i;
    uint32_t b_uint = b_fu.i;

    // A trick to handle both positive and negative numbers, see https://goo.gl/YbdnFQ
    // - If negative: convert to two's complement
    // - If positive: mask with sign bit
    uint32_t sign_mask = static_cast<uint32_t>(1U) << 31;
69
    uint32_t abs_value_bits_mask = ~sign_mask;
70 71 72
    a_uint = (sign_mask & a_uint) ? (~a_uint + 1) : (sign_mask | a_uint);
    b_uint = (sign_mask & b_uint) ? (~b_uint + 1) : (sign_mask | b_uint);

73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91
    uint32_t distance;
    uint32_t a_uint_abs = (abs_value_bits_mask & a_fu.i);
    uint32_t b_uint_abs = (abs_value_bits_mask & b_fu.i);
    uint32_t min_signal_uint_abs = (abs_value_bits_mask & min_signal_fu.i);
    if ((a_uint_abs < min_signal_uint_abs) && (b_uint_abs < min_signal_uint_abs))
    {
        // Both a & b below minimum signal
        distance = FLOAT_BELOW_MIN_SIGNAL;
    }
    else
    {
        distance = (a_uint >= b_uint) ? (a_uint - b_uint) : (b_uint - a_uint);
        // We've reserved UINT_MAX to mean FLOAT_BELOW_MIN_SIGNAL
        if (distance == UINT_MAX)
        {
            distance = FLOAT_MAX_DIFF;
        }
    }

92 93 94
    return distance;
}

95
uint64_t test::float_distance(double a, double b, double min_signal)
96
{
97
    if (std::isnan(a) && std::isnan(b))
98
    {
99 100 101 102 103 104 105 106 107 108 109 110
        return 0;
    }
    else if (std::isinf(a) && std::isinf(b))
    {
        if (a > 0 && b > 0)
        {
            return 0;
        }
        else if (a < 0 && b < 0)
        {
            return 0;
        }
111
        return DOUBLE_MAX_DIFF;
112 113 114 115
    }

    DoubleUnion a_du{a};
    DoubleUnion b_du{b};
116
    DoubleUnion min_signal_du{min_signal};
117 118 119 120 121 122 123
    uint64_t a_uint = a_du.i;
    uint64_t b_uint = b_du.i;

    // A trick to handle both positive and negative numbers, see https://goo.gl/YbdnFQ
    // - If negative: convert to two's complement
    // - If positive: mask with sign bit
    uint64_t sign_mask = static_cast<uint64_t>(1U) << 63;
124
    uint64_t abs_value_bits_mask = ~sign_mask;
125 126 127
    a_uint = (sign_mask & a_uint) ? (~a_uint + 1) : (sign_mask | a_uint);
    b_uint = (sign_mask & b_uint) ? (~b_uint + 1) : (sign_mask | b_uint);

128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146
    uint64_t distance;
    uint64_t a_uint_abs = (abs_value_bits_mask & a_du.i);
    uint64_t b_uint_abs = (abs_value_bits_mask & b_du.i);
    uint64_t min_signal_uint_abs = (abs_value_bits_mask & min_signal_du.i);
    if ((a_uint_abs < min_signal_uint_abs) && (b_uint_abs < min_signal_uint_abs))
    {
        // Both a & b below minimum signal
        distance = DOUBLE_BELOW_MIN_SIGNAL;
    }
    else
    {
        distance = (a_uint >= b_uint) ? (a_uint - b_uint) : (b_uint - a_uint);
        // We've reserved ULLONG_MAX to mean DOUBLE_BELOW_MIN_SIGNAL
        if (distance == ULLONG_MAX)
        {
            distance = DOUBLE_MAX_DIFF;
        }
    }

147 148 149
    return distance;
}

150
bool test::close_f(float a, float b, int tolerance_bits, float min_signal)
151
{
152
    if (std::isnan(a) && std::isnan(b))
153
    {
154 155 156 157 158 159 160 161 162 163 164 165
        return true;
    }
    else if (std::isinf(a) && std::isinf(b))
    {
        if (a > 0 && b > 0)
        {
            return true;
        }
        else if (a < 0 && b < 0)
        {
            return true;
        }
166 167 168
        return false;
    }

169
    uint32_t distance = float_distance(a, b, min_signal);
170 171 172 173

    // e.g. for float with 24 bit mantissa, 2 bit accuracy, and hard-coded 8 bit exponent_bits
    // tolerance_bit_shift = 32 -           (1 +  8 + (24 -     1         ) - 2             )
    //                       float_length    sign exp  mantissa implicit 1    tolerance_bits
174
    uint32_t tolerance_bit_shift = 32 - (1 + 8 + (FLOAT_MANTISSA_BITS - 1) - tolerance_bits);
175 176
    uint32_t tolerance = static_cast<uint32_t>(1U) << tolerance_bit_shift;

177
    return (distance <= tolerance) || (distance == FLOAT_BELOW_MIN_SIGNAL);
178 179
}

180
bool test::close_f(double a, double b, int tolerance_bits, double min_signal)
181
{
182
    if (std::isnan(a) && std::isnan(b))
183
    {
184 185 186 187 188 189 190 191 192 193 194 195
        return true;
    }
    else if (std::isinf(a) && std::isinf(b))
    {
        if (a > 0 && b > 0)
        {
            return true;
        }
        else if (a < 0 && b < 0)
        {
            return true;
        }
196 197 198
        return false;
    }

199
    uint64_t distance = float_distance(a, b, min_signal);
200 201 202 203

    // e.g. for double with 52 bit mantissa, 2 bit accuracy, and hard-coded 11 bit exponent_bits
    // tolerance_bit_shift = 64 -           (1 +  11 + (53 -     1         ) - 2             )
    //                       double_length   sign exp   mantissa implicit 1    tolerance_bits
204
    uint64_t tolerance_bit_shift = 64 - (1 + 11 + (DOUBLE_MANTISSA_BITS - 1) - tolerance_bits);
205 206
    uint64_t tolerance = static_cast<uint64_t>(1U) << tolerance_bit_shift;

207
    return (distance <= tolerance) || (distance == DOUBLE_BELOW_MIN_SIGNAL);
208 209
}

210 211
vector<uint32_t>
    test::float_distances(const vector<float>& a, const vector<float>& b, float min_signal)
212 213 214 215 216 217 218 219
{
    if (a.size() != b.size())
    {
        throw ngraph_error("a.size() != b.size() for float_distances comparison.");
    }
    vector<uint32_t> distances(a.size());
    for (size_t i = 0; i < a.size(); ++i)
    {
220
        distances[i] = float_distance(a[i], b[i], min_signal);
221 222 223 224 225
    }

    return distances;
}

226 227
vector<uint64_t>
    test::float_distances(const vector<double>& a, const vector<double>& b, double min_signal)
228 229 230 231 232 233 234 235
{
    if (a.size() != b.size())
    {
        throw ngraph_error("a.size() != b.size() for float_distances comparison.");
    }
    vector<uint64_t> distances(a.size());
    for (size_t i = 0; i < a.size(); ++i)
    {
236
        distances[i] = float_distance(a[i], b[i], min_signal);
237 238 239 240 241
    }

    return distances;
}

242 243
uint32_t test::matching_mantissa_bits(uint32_t distance)
{
244 245
    uint32_t tolerance_bit_shift = 0;
    uint32_t num_bits_on = 0;
246

247 248 249
    // Do some bit probing to find the most significant bit that's on,
    // as well as how many bits are on.
    for (uint32_t check_bit = 0; check_bit < 32; ++check_bit)
250
    {
251
        if (distance & (1 << check_bit))
252
        {
253 254
            tolerance_bit_shift = check_bit;
            ++num_bits_on;
255 256 257
        }
    }

258 259 260
    // all_close_f is <= test for tolerance (where tolerance is uint32_t with single bit on)
    // So if more than one bit is on we need the next higher tolerance
    if (num_bits_on > 1)
261 262 263 264
    {
        ++tolerance_bit_shift;
    }

265
    // clang-format off
266 267 268 269 270 271 272 273 274
    // all_close_f calculation of tolerance_bit_shift:
    // e.g. for float with 24 bit mantissa, 2 bit accuracy, and hard-coded 8 bit exponent_bits
    //  tolerance_bit_shift   =     32 -          (1 +  8 + (24 -                    1         ) - 2             )
    //                              float_length   sign exp  matching_matissa_bits   implicit 1    tolerance_bits
    //
    // Assuming 0 tolerance_bits and solving for matching_matissa_bits yields:
    //  tolerance_bit_shift   =     32 -          (1 +  8 + (matching_matissa_bits - 1         ) - 0             )
    //  tolerance_bit_shift   =     32 -          (1 +  8 + (matching_matissa_bits - 1         )                 )
    //  matching_matissa_bits =     32 -          (1 +  8 + (tolerance_bit_shift   - 1         )                 )
275
    // clang-format on
276 277 278 279 280
    uint32_t matching_matissa_bits =
        tolerance_bit_shift < 24 ? (32 - (1 + 8 + (tolerance_bit_shift - 1))) : 0;
    return matching_matissa_bits;
}

281
uint32_t test::matching_mantissa_bits(uint64_t distance)
282
{
283 284
    uint32_t tolerance_bit_shift = 0;
    uint32_t num_bits_on = 0;
285

286 287 288
    // Do some bit probing to find the most significant bit that's on,
    // as well as how many bits are on.
    for (uint32_t check_bit = 0; check_bit < 64; ++check_bit)
289
    {
290
        if (distance & (1ull << check_bit))
291
        {
292 293
            tolerance_bit_shift = check_bit;
            ++num_bits_on;
294 295 296
        }
    }

297 298 299
    // all_close_f is <= test for tolerance (where tolerance is uint64_t with single bit on)
    // So if more than one bit is on we need the next higher tolerance
    if (num_bits_on > 1)
300 301 302 303
    {
        ++tolerance_bit_shift;
    }

304
    // clang-format off
305 306 307 308 309 310 311 312 313
    // all_close_f calculation of tolerance_bit_shift:
    // e.g. for double with 53 bit mantissa, 2 bit accuracy, and hard-coded 8 bit exponent_bits
    //  tolerance_bit_shift   =     64 -          (1 +  11 + (53 -                    1         ) - 2             )
    //                              double_length  sign exp   matching_matissa_bits   implicit 1    tolerance_bits
    //
    // Assuming 0 tolerance_bits and solving for matching_matissa_bits yields:
    //  tolerance_bit_shift   =     64 -          (1 +  11 + (matching_matissa_bits - 1         ) - 0             )
    //  tolerance_bit_shift   =     64 -          (1 +  11 + (matching_matissa_bits - 1         )                 )
    //  matching_matissa_bits =     64 -          (1 +  11 + (tolerance_bit_shift   - 1         )                 )
314
    // clang-format on
315
    uint32_t matching_matissa_bits =
316 317 318 319
        tolerance_bit_shift < 53 ? (64 - (1 + 11 + (tolerance_bit_shift - 1))) : 0;
    return matching_matissa_bits;
}

320 321 322 323
::testing::AssertionResult test::all_close_f(const vector<float>& a,
                                             const vector<float>& b,
                                             int tolerance_bits,
                                             float min_signal)
324
{
325 326 327 328 329 330 331 332 333
    if (tolerance_bits < MIN_FLOAT_TOLERANCE_BITS)
    {
        tolerance_bits = MIN_FLOAT_TOLERANCE_BITS;
    }
    if (tolerance_bits >= FLOAT_MANTISSA_BITS)
    {
        tolerance_bits = FLOAT_MANTISSA_BITS - 1;
    }

334
    bool rc = true;
335
    stringstream msg;
336 337
    if (a.size() != b.size())
    {
338
        return ::testing::AssertionFailure() << "a.size() != b.size() for all_close_f comparison.";
339
    }
340 341 342 343
    if (a.size() == 0)
    {
        return ::testing::AssertionSuccess() << "No elements to compare";
    }
344
    vector<uint32_t> distances = float_distances(a, b, min_signal);
345 346 347 348

    // e.g. for float with 24 bit mantissa, 2 bit accuracy, and hard-coded 8 bit exponent_bits
    // tolerance_bit_shift = 32 -           (1 +  8 + (24 -     1         ) - 2             )
    //                       float_length    sign exp  mantissa implicit 1    tolerance_bits
349
    uint32_t tolerance_bit_shift = 32 - (1 + 8 + (FLOAT_MANTISSA_BITS - 1) - tolerance_bits);
350 351
    uint32_t tolerance = static_cast<uint32_t>(1U) << tolerance_bit_shift;
    uint32_t max_distance = 0;
352
    uint32_t min_distance = FLOAT_BELOW_MIN_SIGNAL;
353 354 355
    size_t max_distance_index = 0;
    size_t min_distance_index = 0;
    size_t diff_count = 0;
356
    size_t below_min_count = 0;
357 358
    for (size_t i = 0; i < a.size(); ++i)
    {
359 360 361 362 363 364 365
        if (distances[i] == FLOAT_BELOW_MIN_SIGNAL)
        {
            // Special value that indicates both values were below min_signal
            below_min_count++;
            continue;
        }

366 367 368 369 370 371 372 373 374 375 376
        if (distances[i] > max_distance)
        {
            max_distance = distances[i];
            max_distance_index = i;
        }
        if (distances[i] < min_distance)
        {
            min_distance = distances[i];
            min_distance_index = i;
        }
        bool is_close_f = distances[i] <= tolerance;
377 378
        if (!is_close_f)
        {
379
            if (diff_count < 5)
380
            {
381
                msg << std::setprecision(std::numeric_limits<long double>::digits10 + 1) << a[i]
382
                    << " is not close to " << b[i] << " at index " << i << std::endl;
383 384
            }

385
            rc = false;
386
            diff_count++;
387 388
        }
    }
389 390
    if (!rc)
    {
391
        msg << "diff count: " << diff_count << " out of " << a.size() << std::endl;
392
    }
393 394 395 396 397 398 399 400 401 402 403 404
    // Find median value via partial sorting
    size_t middle = distances.size() / 2;
    std::nth_element(distances.begin(), distances.begin() + middle, distances.end());
    uint32_t median_distance = distances[middle];
    if (distances.size() % 2 == 0)
    {
        // Find middle-1 value
        uint64_t median_sum = static_cast<uint64_t>(median_distance) +
                              *max_element(distances.begin(), distances.begin() + middle);
        median_distance = median_sum / 2;
    }

405
    bool all_below_min_signal = below_min_count == distances.size();
406 407 408
    if (rc && (std::getenv("NGRAPH_GTEST_INFO") != nullptr))
    {
        // Short unobtrusive message when passing
409 410
        std::cout << "[   INFO   ] Verifying match of <= " << (FLOAT_MANTISSA_BITS - tolerance_bits)
                  << " mantissa bits (" << FLOAT_MANTISSA_BITS << " bits precision - "
411 412 413
                  << tolerance_bits << " tolerance). ";
        if (all_below_min_signal)
        {
414
            std::cout << "All values below min_signal: " << min_signal << std::endl;
415 416 417 418 419 420 421
        }
        else
        {
            std::cout << below_min_count << " value(s) below min_signal: " << min_signal
                      << " Loosest match found is " << matching_mantissa_bits(max_distance)
                      << " mantissa bits.\n";
        }
422 423
    }

424
    msg << "passing criteria - mismatch allowed  @ mantissa bit: "
425 426
        << (FLOAT_MANTISSA_BITS - tolerance_bits) << " or later (" << tolerance_bits
        << " tolerance bits)\n";
427 428
    if (all_below_min_signal)
    {
429
        msg << "All values below min_signal: " << min_signal << std::endl;
430 431 432
    }
    else
    {
433
        msg << below_min_count << " value(s) below min_signal: " << min_signal << std::endl;
434 435 436 437 438 439 440 441 442 443 444
        msg << std::setprecision(std::numeric_limits<long double>::digits10 + 1)
            << "tightest match   - mismatch occurred @ mantissa bit: "
            << matching_mantissa_bits(min_distance) << " or next bit (" << a[min_distance_index]
            << " vs " << b[min_distance_index] << " at [" << min_distance_index << "])\n";
        msg << std::setprecision(std::numeric_limits<long double>::digits10 + 1)
            << "loosest match    - mismatch occurred @ mantissa bit: "
            << matching_mantissa_bits(max_distance) << " or next bit (" << a[max_distance_index]
            << " vs " << b[max_distance_index] << " at [" << max_distance_index << "])\n";
        msg << "median match     - mismatch occurred @ mantissa bit: "
            << matching_mantissa_bits(median_distance) << " or next bit\n";
    }
445 446 447 448 449

    ::testing::AssertionResult res =
        rc ? ::testing::AssertionSuccess() : ::testing::AssertionFailure();
    res << msg.str();
    return res;
450 451
}

452 453 454 455
::testing::AssertionResult test::all_close_f(const vector<double>& a,
                                             const vector<double>& b,
                                             int tolerance_bits,
                                             double min_signal)
456
{
457 458 459 460 461 462 463 464
    if (tolerance_bits < 0)
    {
        tolerance_bits = 0;
    }
    if (tolerance_bits >= DOUBLE_MANTISSA_BITS)
    {
        tolerance_bits = DOUBLE_MANTISSA_BITS - 1;
    }
465 466

    bool rc = true;
467
    stringstream msg;
468 469
    if (a.size() != b.size())
    {
470
        return ::testing::AssertionFailure() << "a.size() != b.size() for all_close_f comparison.";
471
    }
472 473 474 475
    if (a.size() == 0)
    {
        return ::testing::AssertionSuccess() << "No elements to compare";
    }
476
    vector<uint64_t> distances = float_distances(a, b, min_signal);
477 478 479 480

    // e.g. for double with 52 bit mantissa, 2 bit accuracy, and hard-coded 11 bit exponent_bits
    // tolerance_bit_shift = 64 -           (1 +  11 + (53 -     1         ) - 2             )
    //                       double_length   sign exp   mantissa implicit 1    tolerance_bits
481
    uint64_t tolerance_bit_shift = 64 - (1 + 11 + (DOUBLE_MANTISSA_BITS - 1) - tolerance_bits);
482 483
    uint64_t tolerance = static_cast<uint64_t>(1U) << tolerance_bit_shift;
    uint64_t max_distance = 0;
484
    uint64_t min_distance = DOUBLE_BELOW_MIN_SIGNAL;
485 486 487
    size_t max_distance_index = 0;
    size_t min_distance_index = 0;
    size_t diff_count = 0;
488
    size_t below_min_count = 0;
489 490
    for (size_t i = 0; i < a.size(); ++i)
    {
491 492 493 494 495 496 497
        if (distances[i] == DOUBLE_BELOW_MIN_SIGNAL)
        {
            // Special value that indicates both values were below min_signal
            below_min_count++;
            continue;
        }

498 499 500 501 502 503 504 505 506 507 508 509 510 511 512
        if (distances[i] > max_distance)
        {
            max_distance = distances[i];
            max_distance_index = i;
        }
        if (distances[i] < min_distance)
        {
            min_distance = distances[i];
            min_distance_index = i;
        }
        bool is_close_f = distances[i] <= tolerance;
        if (!is_close_f)
        {
            if (diff_count < 5)
            {
513
                msg << a[i] << " is not close to " << b[i] << " at index " << i << std::endl;
514 515 516 517 518 519
            }

            rc = false;
            diff_count++;
        }
    }
520 521
    if (!rc)
    {
522
        msg << "diff count: " << diff_count << " out of " << a.size() << std::endl;
523
    }
524 525 526 527 528 529 530 531 532 533 534 535 536
    // Find median value via partial sorting
    size_t middle = distances.size() / 2;
    std::nth_element(distances.begin(), distances.begin() + middle, distances.end());
    uint64_t median_distance = distances[middle];
    if (distances.size() % 2 == 0)
    {
        uint64_t median_distance2 = *max_element(distances.begin(), distances.begin() + middle);
        uint64_t remainder1 = median_distance % 2;
        uint64_t remainder2 = median_distance2 % 2;
        median_distance =
            (median_distance / 2) + (median_distance2 / 2) + ((remainder1 + remainder2) / 2);
    }

537
    bool all_below_min_signal = below_min_count == distances.size();
538 539 540
    if (rc && (std::getenv("NGRAPH_GTEST_INFO") != nullptr))
    {
        // Short unobtrusive message when passing
541 542 543
        std::cout << "[   INFO   ] Verifying match of >= "
                  << (DOUBLE_MANTISSA_BITS - tolerance_bits) << " mantissa bits ("
                  << DOUBLE_MANTISSA_BITS << " bits precision - " << tolerance_bits
544 545 546
                  << " tolerance). ";
        if (all_below_min_signal)
        {
547
            std::cout << "All values below min_signal: " << min_signal << std::endl;
548 549 550 551 552 553 554
        }
        else
        {
            std::cout << below_min_count << " value(s) below min_signal: " << min_signal
                      << " Loosest match found is " << matching_mantissa_bits(max_distance)
                      << " mantissa bits.\n";
        }
555 556
    }

557
    msg << "passing criteria - mismatch allowed  @ mantissa bit: "
558 559
        << (DOUBLE_MANTISSA_BITS - tolerance_bits) << " or later (" << tolerance_bits
        << " tolerance bits)\n";
560 561
    if (all_below_min_signal)
    {
562
        msg << "All values below min_signal: " << min_signal << std::endl;
563 564 565
    }
    else
    {
566
        msg << below_min_count << " value(s) below min_signal: " << min_signal << std::endl;
567 568 569 570 571 572 573 574 575 576 577
        msg << std::setprecision(std::numeric_limits<long double>::digits10 + 1)
            << "tightest match   - mismatch occurred @ mantissa bit: "
            << matching_mantissa_bits(min_distance) << " or next bit (" << a[min_distance_index]
            << " vs " << b[min_distance_index] << " at [" << min_distance_index << "])\n";
        msg << std::setprecision(std::numeric_limits<long double>::digits10 + 1)
            << "loosest match    - mismatch occurred @ mantissa bit: "
            << matching_mantissa_bits(max_distance) << " or next bit (" << a[max_distance_index]
            << " vs " << b[max_distance_index] << " at [" << max_distance_index << "])\n";
        msg << "median match     - mismatch occurred @ mantissa bit: "
            << matching_mantissa_bits(median_distance) << " or next bit\n";
    }
578 579 580 581 582

    ::testing::AssertionResult res =
        rc ? ::testing::AssertionSuccess() : ::testing::AssertionFailure();
    res << msg.str();
    return res;
583 584
}

585 586
::testing::AssertionResult test::all_close_f(const std::shared_ptr<runtime::Tensor>& a,
                                             const std::shared_ptr<runtime::Tensor>& b,
587 588
                                             int tolerance_bits,
                                             float min_signal)
589 590
{
    // Check that the layouts are compatible
Scott Cyphers's avatar
Scott Cyphers committed
591
    if (*a->get_tensor_layout() != *b->get_tensor_layout())
592
    {
593
        return ::testing::AssertionFailure() << "Cannot compare tensors with different layouts";
594 595 596
    }
    if (a->get_shape() != b->get_shape())
    {
597
        return ::testing::AssertionFailure() << "Cannot compare tensors with different shapes";
598 599
    }

600 601
    return test::all_close_f(
        read_float_vector(a), read_float_vector(b), tolerance_bits, min_signal);
602 603
}

604 605 606
::testing::AssertionResult
    test::all_close_f(const std::vector<std::shared_ptr<runtime::Tensor>>& as,
                      const std::vector<std::shared_ptr<runtime::Tensor>>& bs,
607 608
                      int tolerance_bits,
                      float min_signal)
609 610 611
{
    if (as.size() != bs.size())
    {
612
        return ::testing::AssertionFailure() << "Cannot compare tensors with different sizes";
613 614 615
    }
    for (size_t i = 0; i < as.size(); ++i)
    {
616
        auto ar = test::all_close_f(as[i], bs[i], tolerance_bits, min_signal);
617
        if (!ar)
618
        {
619
            return ar;
620 621
        }
    }
622
    return ::testing::AssertionSuccess();
623
}