//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#include <algorithm>

#include "ngraph/ngraph.hpp"
#include "ngraph/util.hpp"

#include "test_tools.hpp"

using namespace std;
using namespace ngraph;
vector<float> read_float_vector(shared_ptr<runtime::Tensor> tv)
27 28
{
    vector<float> float_vec;
Scott Cyphers's avatar
Scott Cyphers committed
29
    element::Type element_type = tv->get_tensor_layout()->get_element_type();
30 31 32 33

    if (element_type == element::boolean)
    {
        vector<char> vec = read_vector<char>(tv);
34 35 36 37 38 39
        // Changed from vector ctor to explicit for loop to add static_cast
        // This silences MSVC warnings
        for (char value : vec)
        {
            float_vec.push_back(static_cast<float>(value));
        }
40 41 42 43
    }
    else if (element_type == element::f32)
    {
        vector<float> vec = read_vector<float>(tv);
44 45 46 47
        for (float value : vec)
        {
            float_vec.push_back(static_cast<float>(value));
        }
48 49 50 51
    }
    else if (element_type == element::f64)
    {
        vector<double> vec = read_vector<double>(tv);
52 53 54 55
        for (double value : vec)
        {
            float_vec.push_back(static_cast<float>(value));
        }
56 57 58 59
    }
    else if (element_type == element::i8)
    {
        vector<int8_t> vec = read_vector<int8_t>(tv);
60 61 62 63
        for (int8_t value : vec)
        {
            float_vec.push_back(static_cast<float>(value));
        }
64 65 66 67
    }
    else if (element_type == element::i16)
    {
        vector<int16_t> vec = read_vector<int16_t>(tv);
68 69 70 71
        for (int16_t value : vec)
        {
            float_vec.push_back(static_cast<float>(value));
        }
72 73 74 75
    }
    else if (element_type == element::i32)
    {
        vector<int32_t> vec = read_vector<int32_t>(tv);
76 77 78 79
        for (int32_t value : vec)
        {
            float_vec.push_back(static_cast<float>(value));
        }
80 81 82 83
    }
    else if (element_type == element::i64)
    {
        vector<int64_t> vec = read_vector<int64_t>(tv);
84 85 86 87
        for (int64_t value : vec)
        {
            float_vec.push_back(static_cast<float>(value));
        }
88 89 90 91
    }
    else if (element_type == element::u8)
    {
        vector<uint8_t> vec = read_vector<uint8_t>(tv);
92 93 94 95
        for (uint8_t value : vec)
        {
            float_vec.push_back(static_cast<float>(value));
        }
96 97 98 99
    }
    else if (element_type == element::u16)
    {
        vector<uint16_t> vec = read_vector<uint16_t>(tv);
100 101 102 103
        for (uint16_t value : vec)
        {
            float_vec.push_back(static_cast<float>(value));
        }
104 105 106 107
    }
    else if (element_type == element::u32)
    {
        vector<uint32_t> vec = read_vector<uint32_t>(tv);
108 109 110 111
        for (uint32_t value : vec)
        {
            float_vec.push_back(static_cast<float>(value));
        }
112 113 114 115
    }
    else if (element_type == element::u64)
    {
        vector<uint64_t> vec = read_vector<uint64_t>(tv);
116 117 118 119
        for (uint64_t value : vec)
        {
            float_vec.push_back(static_cast<float>(value));
        }
120 121 122 123 124 125 126 127 128
    }
    else
    {
        throw ngraph_error("Unsupported nGraph element type.");
    }

    return float_vec;
}

// This function traverses the list of ops and verifies that each op's dependencies (its inputs)
// is located earlier in the list. That is enough to be valid
131
bool validate_list(const list<shared_ptr<Node>>& nodes)
Robert Kimball's avatar
Robert Kimball committed
132 133 134 135
{
    bool rc = true;
    for (auto it = nodes.rbegin(); it != nodes.rend(); it++)
    {
136
        auto node_tmp = *it;
137
        auto dependencies_tmp = node_tmp->get_arguments();
Robert Kimball's avatar
Robert Kimball committed
138
        vector<Node*> dependencies;
139

Robert Kimball's avatar
Robert Kimball committed
140 141 142 143
        for (shared_ptr<Node> n : dependencies_tmp)
        {
            dependencies.push_back(n.get());
        }
144 145
        auto tmp = it;
        for (tmp++; tmp != nodes.rend(); tmp++)
Robert Kimball's avatar
Robert Kimball committed
146 147
        {
            auto dep_tmp = *tmp;
148
            auto found = find(dependencies.begin(), dependencies.end(), dep_tmp.get());
Robert Kimball's avatar
Robert Kimball committed
149 150 151 152 153 154 155 156 157 158 159 160 161
            if (found != dependencies.end())
            {
                dependencies.erase(found);
            }
        }
        if (dependencies.size() > 0)
        {
            rc = false;
        }
    }
    return rc;
}

// Builds a small fixed dataflow graph over six scalar f32 parameters:
//   r = ((p0 + p1) . p2 + p4) + ((p0 + p1) * p3 + p5)
shared_ptr<Function> make_test_graph()
{
    // Six scalar float inputs.
    auto p0 = make_shared<op::Parameter>(element::f32, Shape{});
    auto p1 = make_shared<op::Parameter>(element::f32, Shape{});
    auto p2 = make_shared<op::Parameter>(element::f32, Shape{});
    auto p3 = make_shared<op::Parameter>(element::f32, Shape{});
    auto p4 = make_shared<op::Parameter>(element::f32, Shape{});
    auto p5 = make_shared<op::Parameter>(element::f32, Shape{});

    // Shared subexpression feeding both branches.
    auto sum_01 = make_shared<op::Add>(p0, p1);
    auto left_branch = make_shared<op::Add>(make_shared<op::Dot>(sum_01, p2), p4);
    auto right_branch = make_shared<op::Add>(make_shared<op::Multiply>(sum_01, p3), p5);
    auto result = make_shared<op::Add>(left_branch, right_branch);

    return make_shared<Function>(result, ParameterVector{p0, p1, p2, p3, p4, p5});
}

// Specialization for char: std::uniform_int_distribution is not defined for
// character types, so sample int16_t values and narrow each draw.
template <>
void init_int_tv<char>(ngraph::runtime::Tensor* tv,
                       std::default_random_engine& engine,
                       char min,
                       char max)
{
    std::uniform_int_distribution<int16_t> dist(static_cast<short>(min), static_cast<short>(max));
    std::vector<char> vec(tv->get_element_count());
    std::generate(vec.begin(), vec.end(), [&]() { return static_cast<char>(dist(engine)); });
    tv->write(vec.data(), vec.size() * sizeof(char));
}

// Specialization for int8_t: std::uniform_int_distribution is not defined for
// 8-bit types, so sample int16_t values and narrow each draw.
template <>
void init_int_tv<int8_t>(ngraph::runtime::Tensor* tv,
                         std::default_random_engine& engine,
                         int8_t min,
                         int8_t max)
{
    std::uniform_int_distribution<int16_t> dist(static_cast<short>(min), static_cast<short>(max));
    std::vector<int8_t> vec(tv->get_element_count());
    std::generate(vec.begin(), vec.end(), [&]() { return static_cast<int8_t>(dist(engine)); });
    tv->write(vec.data(), vec.size() * sizeof(int8_t));
}

// Specialization for uint8_t: std::uniform_int_distribution is not defined for
// 8-bit types, so sample int16_t values and narrow each draw.
template <>
void init_int_tv<uint8_t>(ngraph::runtime::Tensor* tv,
                          std::default_random_engine& engine,
                          uint8_t min,
                          uint8_t max)
{
    std::uniform_int_distribution<int16_t> dist(static_cast<short>(min), static_cast<short>(max));
    std::vector<uint8_t> vec(tv->get_element_count());
    std::generate(vec.begin(), vec.end(), [&]() { return static_cast<uint8_t>(dist(engine)); });
    tv->write(vec.data(), vec.size() * sizeof(uint8_t));
}

// Fills the tensor with small random values appropriate to its element type:
// boolean/integer types draw from a narrow integer range, floating-point
// types draw from [numeric_limits<T>::min(), 1.0).
// Throws runtime_error for unsupported element types.
void random_init(ngraph::runtime::Tensor* tv, std::default_random_engine& engine)
{
    element::Type et = tv->get_element_type();
    if (et == element::boolean)
    {
        init_int_tv<char>(tv, engine, 0, 1);
    }
    else if (et == element::f32)
    {
        init_real_tv<float>(tv, engine, numeric_limits<float>::min(), 1.0f);
    }
    else if (et == element::f64)
    {
        // Use the double limits here; this previously reused the float limits
        // and literal, inconsistent with the f32 branch above.
        init_real_tv<double>(tv, engine, numeric_limits<double>::min(), 1.0);
    }
    else if (et == element::i8)
    {
        init_int_tv<int8_t>(tv, engine, -1, 1);
    }
    else if (et == element::i16)
    {
        init_int_tv<int16_t>(tv, engine, -1, 1);
    }
    else if (et == element::i32)
    {
        init_int_tv<int32_t>(tv, engine, 0, 1);
    }
    else if (et == element::i64)
    {
        init_int_tv<int64_t>(tv, engine, 0, 1);
    }
    else if (et == element::u8)
    {
        init_int_tv<uint8_t>(tv, engine, 0, 1);
    }
    else if (et == element::u16)
    {
        init_int_tv<uint16_t>(tv, engine, 0, 1);
    }
    else if (et == element::u32)
    {
        init_int_tv<uint32_t>(tv, engine, 0, 1);
    }
    else if (et == element::u64)
    {
        init_int_tv<uint64_t>(tv, engine, 0, 1);
    }
    else
    {
        throw runtime_error("unsupported type");
    }
}

template <>
287 288 289
string get_results_str(const std::vector<char>& ref_data,
                       const std::vector<char>& actual_data,
                       size_t max_results)
290
{
291
    stringstream ss;
292
    size_t num_results = std::min(static_cast<size_t>(max_results), ref_data.size());
293
    ss << "First " << num_results << " results";
294 295
    for (size_t i = 0; i < num_results; ++i)
    {
296
        ss << std::endl
297 298 299
           << std::setw(4) << i << " ref: " << std::setw(16) << std::left
           << static_cast<int>(ref_data[i]) << "  actual: " << std::setw(16) << std::left
           << static_cast<int>(actual_data[i]);
300
    }
301
    ss << std::endl;
302 303

    return ss.str();
304
}
#ifndef NGRAPH_JSON_DISABLE
307 308 309 310 311 312 313 314
std::shared_ptr<Function> make_function_from_file(const std::string& file_name)
{
    const string json_path = file_util::path_join(SERIALIZED_ZOO, file_name);
    const string json_string = file_util::read_file_to_string(json_path);
    stringstream ss(json_string);
    shared_ptr<Function> func = ngraph::deserialize(ss);
    return func;
}
315
#endif

// Verifies that f->get_ordered_ops() is a valid topological order:
// no node appears twice, every node's inputs and control dependencies appear
// before the node itself, and every op in required_ops is present.
::testing::AssertionResult test_ordered_ops(shared_ptr<Function> f, const NodeVector& required_ops)
{
    unordered_set<Node*> seen;
    for (auto& node_ptr : f->get_ordered_ops())
    {
        Node* node = node_ptr.get();
        if (seen.count(node) > 0)
        {
            return ::testing::AssertionFailure() << "Duplication in ordered ops";
        }
        size_t arg_count = node->get_input_size();
        for (size_t i = 0; i < arg_count; ++i)
        {
            Node* dep = node->input(i).get_source_output().get_node();
            if (seen.count(dep) == 0)
            {
                // Surrounding spaces added: streaming a node does not insert
                // them, so the old message ran the node name into "op".
                return ::testing::AssertionFailure()
                       << "Argument " << *dep << " does not occur before op " << *node;
            }
        }
        for (auto& dep_ptr : node->get_control_dependencies())
        {
            if (seen.count(dep_ptr.get()) == 0)
            {
                return ::testing::AssertionFailure()
                       << "Control dependency " << *dep_ptr << " does not occur before op "
                       << *node;
            }
        }
        seen.insert(node);
    }
    for (auto& node_ptr : required_ops)
    {
        if (seen.count(node_ptr.get()) == 0)
        {
            return ::testing::AssertionFailure() << "Required op " << *node_ptr
                                                 << " does not occur in ordered ops";
        }
    }
    return ::testing::AssertionSuccess();
}