//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

Robert Kimball's avatar
Robert Kimball committed
17 18 19
#include <exception>
#include <fstream>
#include <sstream>

#include "gmock/gmock.h"
#include "gtest/gtest.h"

#include "ngraph/file_util.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/passthrough.hpp"
#include "ngraph/serializer.hpp"
#include "ngraph/util.hpp"
#include "nlohmann/json.hpp"
#include "util/all_close_f.hpp"
#include "util/test_tools.hpp"

using namespace std;
using namespace ngraph;
using json = nlohmann::json;

using ::testing::ElementsAre;
using ::testing::NotNull;
using ::testing::StrEq;

42 43 44 45 46 47 48 49 50 51 52 53 54 55 56
// Looks up `key` in the JSON value `j` and converts the entry to `T`.
//
// Returns the converted value when `j.at(key).get<T>()` succeeds, and
// `default_value` when that expression throws an exception derived from
// std::exception (for example because the key is absent). Anything that is
// not a std::exception is allowed to propagate rather than being silently
// swallowed. `T` does not have to be default-constructible.
template <typename T>
T get_or_default(nlohmann::json& j, const std::string& key, const T& default_value)
{
    try
    {
        return j.at(key).get<T>();
    }
    catch (const std::exception&)
    {
        return default_value;
    }
}

57
#if defined(NGRAPH_INTERPRETER_ENABLE)
Robert Kimball's avatar
Robert Kimball committed
58 59 60
// Builds f(A,B,C) = (A+B)*C, round-trips it through serialize()/deserialize(),
// compiles the deserialized function on the INTERPRETER backend and checks the
// results for three different argument orderings.
TEST(serialize, main)
{
    // First create "f(A,B,C) = (A+B)*C".
    Shape shape{2, 2};
    auto A = make_shared<op::Parameter>(element::f32, shape);
    auto B = make_shared<op::Parameter>(element::f32, shape);
    auto C = make_shared<op::Parameter>(element::f32, shape);
    auto f = make_shared<Function>((A + B) * C, ParameterVector{A, B, C}, "f");

    string js = serialize(f, 4);

    {
        // Keep a copy of the serialized text on disk; the stream is scoped so
        // it is closed before the rest of the test runs.
        ofstream out("serialize_function.js");
        out << js;
    }

    // Deserialize from the in-memory string, not from the file written above.
    istringstream in(js);
    shared_ptr<Function> sfunc = deserialize(in);
    auto backend = runtime::Backend::create("INTERPRETER");
    auto handle = backend->compile(sfunc);

    auto x = backend->create_tensor(element::f32, shape);
    copy_data(x, vector<float>{1, 2, 3, 4});
    auto y = backend->create_tensor(element::f32, shape);
    copy_data(y, vector<float>{5, 6, 7, 8});
    auto z = backend->create_tensor(element::f32, shape);
    copy_data(z, vector<float>{9, 10, 11, 12});
    auto result = backend->create_tensor(element::f32, shape);

    // (x + y) * z
    handle->call_with_validate({result}, {x, y, z});
    EXPECT_EQ((vector<float>{54, 80, 110, 144}), read_vector<float>(result));

    // (y + x) * z: swapping the addends must not change the result.
    handle->call_with_validate({result}, {y, x, z});
    EXPECT_EQ((vector<float>{54, 80, 110, 144}), read_vector<float>(result));

    // (x + z) * y
    handle->call_with_validate({result}, {x, z, y});
    EXPECT_EQ((vector<float>{50, 72, 98, 128}), read_vector<float>(result));
}

// Same round trip as TEST(serialize, main), but every node of
// f(A,B,C) = (A+B)*C is given a friendly name before serialization. Checks
// that the names are still present on the deserialized function and that the
// function still computes the expected values on the INTERPRETER backend.
TEST(serialize, friendly_name)
{
    // First create "f(A,B,C) = (A+B)*C".
    Shape shape{2, 2};
    auto A = make_shared<op::Parameter>(element::f32, shape);
    auto B = make_shared<op::Parameter>(element::f32, shape);
    auto C = make_shared<op::Parameter>(element::f32, shape);
    auto sum = A + B;
    auto product = sum * C;
    auto f = make_shared<Function>(product, ParameterVector{A, B, C}, "f");

    A->set_friendly_name("A");
    B->set_friendly_name("B");
    C->set_friendly_name("C");
    sum->set_friendly_name("Sum");
    product->set_friendly_name("Product");

    string js = serialize(f, 4);
    {
        // Keep a copy of the serialized text on disk; the stream is scoped so
        // it is closed before the rest of the test runs.
        ofstream out("serialize_function.js");
        out << js;
    }

    // Deserialize from the in-memory string, not from the file written above.
    istringstream in(js);
    shared_ptr<Function> sfunc = deserialize(in);

    // Every friendly name assigned above must appear on some deserialized op.
    const vector<string> expected_names{"A", "B", "C", "Sum", "Product"};
    for (const string& expected : expected_names)
    {
        bool found = false;
        for (const auto& node : sfunc->get_ops())
        {
            if (node->get_friendly_name() == expected)
            {
                found = true;
                break;
            }
        }
        EXPECT_TRUE(found) << "friendly name '" << expected << "' missing after round trip";
    }

    auto backend = runtime::Backend::create("INTERPRETER");
    auto handle = backend->compile(sfunc);

    auto x = backend->create_tensor(element::f32, shape);
    copy_data(x, vector<float>{1, 2, 3, 4});
    auto y = backend->create_tensor(element::f32, shape);
    copy_data(y, vector<float>{5, 6, 7, 8});
    auto z = backend->create_tensor(element::f32, shape);
    copy_data(z, vector<float>{9, 10, 11, 12});
    auto result = backend->create_tensor(element::f32, shape);

    // (x + y) * z
    handle->call_with_validate({result}, {x, y, z});
    EXPECT_EQ((vector<float>{54, 80, 110, 144}), read_vector<float>(result));

    // (y + x) * z: swapping the addends must not change the result.
    handle->call_with_validate({result}, {y, x, z});
    EXPECT_EQ((vector<float>{54, 80, 110, 144}), read_vector<float>(result));

    // (x + z) * y
    handle->call_with_validate({result}, {x, z, y});
    EXPECT_EQ((vector<float>{50, 72, 98, 128}), read_vector<float>(result));
}
140
#endif
141 142 143

// Deserializes a set of previously serialized models found under
// SERIALIZED_ZOO and checks that each one yields a function.
TEST(serialize, existing_models)
{
    vector<string> models = {"mxnet/mnist_mlp_forward.json",
                             "mxnet/10_bucket_LSTM.json",
                             "mxnet/LSTM_backward.json",
                             "mxnet/LSTM_forward.json"};

    for (const string& model : models)
    {
        const string json_path = file_util::path_join(SERIALIZED_ZOO, model);
        const string json_string = file_util::read_file_to_string(json_path);
        shared_ptr<Function> f = ngraph::deserialize(json_string);
        // Non-fatal so that the remaining models are still exercised.
        EXPECT_NE(f, nullptr) << "deserialize returned null for " << model;
    }
}
156

157 158 159 160 161 162 163 164 165 166 167 168
// Covers get_or_default() for a key that exists and a key that does not,
// alongside a plain at()/get() lookup for comparison.
TEST(serialize, default_value)
{
    json doc = {{"test1", 1}, {"test2", 2}};

    // Plain lookup of an existing key.
    const int direct = doc.at("test1").get<int>();
    EXPECT_EQ(direct, 1);

    // Existing key: the stored value wins over the supplied default.
    const int present = get_or_default<int>(doc, "test2", 0);
    EXPECT_EQ(present, 2);

    // Missing key: the supplied default is returned.
    const int absent = get_or_default<int>(doc, "test3", 3);
    EXPECT_EQ(absent, 3);
}

169 170 171 172 173
// Serializes a function consisting of a single f32 constant to a file,
// deserializes it again and checks that a Constant op with the same data is
// present in the result.
TEST(serialize, constant)
{
    const string tmp_file = "serialize_constant.cpio";
    Shape shape{2, 2, 2};
    auto A = op::Constant::create(element::f32, shape, {1, 2, 3, 4, 5, 6, 7, 8});
    auto f = make_shared<Function>(A, ParameterVector{});

    EXPECT_EQ((vector<float>{1, 2, 3, 4, 5, 6, 7, 8}), A->get_vector<float>());
    serialize(tmp_file, f);
    auto g = deserialize(tmp_file);
    // Remove the temporary file before the fatal assertion so it is not left
    // behind when deserialization fails.
    file_util::remove_file(tmp_file);
    ASSERT_NE(g, nullptr);

    bool found = false;
    for (const auto& node : g->get_ops())
    {
        shared_ptr<op::Constant> c = dynamic_pointer_cast<op::Constant>(node);
        if (c)
        {
            found = true;
            EXPECT_EQ((vector<float>{1, 2, 3, 4, 5, 6, 7, 8}), c->get_vector<float>());
            break;
        }
    }
    EXPECT_TRUE(found);
}

195 196 197 198 199 200 201 202 203 204 205 206 207 208
// Prints how long it takes to read and to deserialize the LSTM_backward model,
// then serializes the result to "test.json" with output shapes enabled.
// There are no assertions; this test only reports timings and writes a file.
TEST(benchmark, serialize)
{
    stopwatch timer;
    string model = "mxnet/LSTM_backward.json";

    const string json_path = file_util::path_join(SERIALIZED_ZOO, model);
    timer.start();
    const string json_string = file_util::read_file_to_string(json_path);
    timer.stop();
    cout << "file read took " << timer.get_milliseconds() << "ms\n";
    timer.start();
    shared_ptr<Function> f = ngraph::deserialize(json_string);
    timer.stop();
    cout << "deserialize took " << timer.get_milliseconds() << "ms\n";

    // NOTE(review): this setting is never reset here; if it is process-wide it
    // will also affect tests that run afterwards - confirm.
    ngraph::set_serialize_output_shapes(true);
    ofstream out("test.json");
    out << serialize(f, 4);
}
214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260

// Matches a tuple-like `arg` whose element 0 equals `type` and whose element 1,
// converted with to_shape(), equals `shape`. The shape is only converted once
// the type has matched.
MATCHER_P2(IsOutputShape, type, shape, "")
{
    if (!(std::get<0>(arg) == type))
    {
        return false;
    }
    return std::get<1>(arg).to_shape() == shape;
}

// Round-trips a function built around an op::Passthrough node through a file
// and checks that the logical type, language, function text and output shapes
// of the Passthrough op survive.
TEST(serialize, passthrough)
{
    const string tmp_file = "serialize_passthrough.json";

    using estuple = std::tuple<element::Type, PartialShape>;

    auto p = make_shared<op::Passthrough>(
        "SerializationTest",
        "Plain",
        "Hello, world!",
        NodeVector{},
        std::vector<estuple>{estuple{element::f32, PartialShape{2, 3}},
                             estuple{element::i8, PartialShape{4, 5}}});
    auto f = make_shared<Function>(NodeVector{std::make_shared<op::GetOutputElement>(p, 0),
                                              std::make_shared<op::GetOutputElement>(p, 1)},
                                   ParameterVector{});
    serialize(tmp_file, f);

    auto g = deserialize(tmp_file);
    // Remove the temporary file before the fatal assertion so it is not left
    // behind when deserialization fails.
    file_util::remove_file(tmp_file);
    ASSERT_THAT(g, NotNull());

    // Find the first Passthrough op in the deserialized function.
    std::shared_ptr<op::Passthrough> pt;
    for (const auto& node : g->get_ops())
    {
        pt = dynamic_pointer_cast<op::Passthrough>(node);
        if (pt)
        {
            break;
        }
    }
    ASSERT_THAT(pt.get(), NotNull());

    EXPECT_THAT(pt->logical_type(), StrEq("SerializationTest"));
    EXPECT_THAT(pt->language(), StrEq("Plain"));
    EXPECT_THAT(pt->function(), StrEq("Hello, world!"));
    EXPECT_THAT(pt->output_shapes(),
                ElementsAre(IsOutputShape(element::f32, Shape{2, 3}),
                            IsOutputShape(element::i8, Shape{4, 5})));
}
261 262 263

// Round-trips four constants through serialize()/deserialize() and compares
// their data with the originals: f32 data containing +/-infinity and NaN, f32
// data where every element is equal, f32 data where one element differs
// slightly from the rest, and i64 data including a value beyond 32 bits.
// The deserialized constants are located by friendly name.
TEST(serialize, constant_infinity_nan)
{
    vector<float> a_data{123.f, 456.f, INFINITY, -INFINITY, NAN};
    vector<float> b_data{5.f, 5.f, 5.f, 5.f, 5.f, 5.f};
    vector<float> c_data{0.05f, 0.05f, 0.05f, 0.05f, 0.05f, 0.05001f, 0.05f};
    vector<int64_t> d_data{-100, -10, -1, 0, 50, 5000000000001};
    auto A = make_shared<op::Constant>(element::f32, Shape{5}, a_data);
    auto B = make_shared<op::Constant>(element::f32, Shape{6}, b_data);
    auto C = make_shared<op::Constant>(element::f32, Shape{7}, c_data);
    auto D = make_shared<op::Constant>(element::i64, Shape{d_data.size()}, d_data);
    A->set_friendly_name("A");
    B->set_friendly_name("B");
    C->set_friendly_name("C");
    D->set_friendly_name("D");
    auto f = make_shared<Function>(NodeVector{A, B, C, D}, ParameterVector{});

    string s = serialize(f, 4);
    shared_ptr<Function> g = deserialize(s);

    shared_ptr<op::Constant> a;
    shared_ptr<op::Constant> b;
    shared_ptr<op::Constant> c;
    shared_ptr<op::Constant> d;
    for (const auto& node : g->get_ops())
    {
        // dynamic_pointer_cast yields nullptr for a node that is not a
        // Constant, which the ASSERT_NE checks below then report.
        if (node->get_friendly_name() == "A")
        {
            a = dynamic_pointer_cast<op::Constant>(node);
        }
        else if (node->get_friendly_name() == "B")
        {
            b = dynamic_pointer_cast<op::Constant>(node);
        }
        else if (node->get_friendly_name() == "C")
        {
            c = dynamic_pointer_cast<op::Constant>(node);
        }
        else if (node->get_friendly_name() == "D")
        {
            d = dynamic_pointer_cast<op::Constant>(node);
        }
    }
    ASSERT_NE(a, nullptr);
    ASSERT_NE(b, nullptr);
    ASSERT_NE(c, nullptr);
    ASSERT_NE(d, nullptr);
    EXPECT_TRUE(test::all_close_f(a->get_vector<float>(), a_data));
    EXPECT_TRUE(test::all_close_f(b->get_vector<float>(), b_data));
    EXPECT_TRUE(test::all_close_f(c->get_vector<float>(), c_data));
    EXPECT_EQ(d->get_vector<int64_t>(), d_data);
}