pattern.cpp 29.2 KB
Newer Older
1
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
16 17 18 19 20 21 22 23

#include <algorithm>
#include <cstdio>
#include <iostream>
#include <list>
#include <memory>

#include "gtest/gtest.h"
24
#include "ngraph/file_util.hpp"
varun-intel's avatar
varun-intel committed
25
#include "ngraph/graph_util.hpp"
26 27
#include "ngraph/log.hpp"
#include "ngraph/ngraph.hpp"
28 29 30 31 32 33 34 35 36
#include "ngraph/op/add.hpp"
#include "ngraph/op/batch_norm.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/sqrt.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/op/sum.hpp"
37 38 39 40
#include "ngraph/pass/graph_rewrite.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/label.hpp"
41
#include "ngraph/pattern/op/skip.hpp"
42 43
#include "ngraph/runtime/cpu/pass/cpu_fusion.hpp"
#include "ngraph/serializer.hpp"
44
#include "util/matcher.hpp"
45
#include "util/test_tools.hpp"
46 47 48 49 50 51

using namespace ngraph;
using namespace std;

static std::shared_ptr<Node> construct_constant_node(int n)
{
52 53 54
    return op::Constant::create(element::i32, Shape{}, {n});
}

55 56 57 58 59 60 61 62 63 64 65 66
static std::shared_ptr<pattern::op::Label> construct_variance_graph()
{
    // construct varaiance
    auto N = op::Constant::create(element::f32, Shape{3}, {2, 2, 2});
    auto input = std::make_shared<pattern::op::Label>(element::f32, Shape{2, 3});
    auto input_sq = std::make_shared<op::Multiply>(input, input);
    auto sum_input = std::make_shared<op::Sum>(input, AxisSet{0});
    auto square_sumed_input = std::make_shared<op::Multiply>(sum_input, sum_input);
    auto sum_squared_input = std::make_shared<op::Sum>(input_sq, AxisSet{0});
    auto avg_input_sum_sq = std::make_shared<op::Divide>(square_sumed_input, N);
    auto xmu = std::make_shared<op::Subtract>(sum_squared_input, avg_input_sum_sq);
    auto variance = std::make_shared<op::Divide>(xmu, N);
67 68
    auto variance_label =
        std::make_shared<pattern::op::Label>(variance, nullptr, NodeVector{variance});
69 70 71 72 73 74

    return variance_label;
}

static std::shared_ptr<pattern::op::Label> construct_mean_graph()
{
75
    // construct mean;
76 77 78 79
    auto input = std::make_shared<pattern::op::Label>(element::f32, Shape{2, 3});
    auto N = op::Constant::create(element::f32, Shape{3}, {2, 2, 2});
    auto sum_input1 = std::make_shared<op::Sum>(input, AxisSet{0});
    auto mean = std::make_shared<op::Divide>(sum_input1, N);
80
    auto mean_label = std::make_shared<pattern::op::Label>(mean, nullptr, NodeVector{mean});
81 82 83
    return mean_label;
}

84 85 86 87 88
class TestGraphRewrite : public ngraph::pass::GraphRewrite
{
public:
    void construct_multiply_by_one()
    {
89
        // pattern #1 : a * 1 = a
90
        auto iconst1 = construct_constant_node(1);
91
        auto pattern = std::make_shared<pattern::op::Label>(iconst1);
92

93
        ngraph::pattern::graph_rewrite_callback callback = [pattern](pattern::Matcher& m) {
94
            NGRAPH_DEBUG << "In a callback for construct_multiply_by_one against "
95
                         << m.get_match_root()->get_name();
96
            NGRAPH_CHECK(m.get_match_root()->get_arguments().size() == 2);
97

98 99
            auto pattern_map = m.get_pattern_map();

100 101
            size_t const_node_index =
                m.get_match_root()->get_arguments().at(0) == pattern_map[pattern];
102
            auto const_node = dynamic_pointer_cast<op::Constant>(
103 104
                m.get_match_root()->get_arguments().at(const_node_index));
            auto second_node = m.get_match_root()->get_arguments().at(const_node_index);
105 106
            NGRAPH_DEBUG << "second_node = " << second_node->get_name()
                         << " , pattern = " << pattern_map[pattern]->get_name();
107

108 109
            if (pattern_map[pattern]->get_element_type() != const_node->get_element_type() ||
                pattern_map[pattern]->get_shape() != const_node->get_shape())
110
            {
111
                NGRAPH_DEBUG << "Operands' types and/or shape don't match";
112
                return false;
113 114
            }

115
            auto const_values = const_node->get_vector<int32_t>();
116 117 118 119 120
            bool all_ones =
                std::all_of(begin(const_values), end(const_values), [](int e) { return e == 1; });

            if (!all_ones)
            {
121
                NGRAPH_DEBUG << "Constant vector's values aren't equal to 1";
122
                return false;
123
            }
124

125
            ngraph::replace_node(m.get_match_root(), pattern_map[pattern]);
126
            return true;
127 128 129 130 131 132 133 134
        };

        auto m = make_shared<TestMatcher>(pattern * iconst1, callback);
        this->add_matcher(m);
    }

    void construct_add_zero()
    {
135
        // pattern #2 : a + 0 = a
136
        auto iconst0 = construct_constant_node(0);
137
        auto pattern = std::make_shared<pattern::op::Label>(iconst0);
138

139
        auto callback = [pattern](pattern::Matcher& m) {
140
            NGRAPH_DEBUG << "In a callback for construct_add_zero against "
141
                         << m.get_match_root()->get_name();
142
            NGRAPH_CHECK(m.get_match_root()->get_arguments().size() == 2);
143

144 145
            auto pattern_map = m.get_pattern_map();

146 147
            size_t const_node_index =
                m.get_match_root()->get_arguments().at(0) == pattern_map[pattern];
148
            auto const_node = dynamic_pointer_cast<op::Constant>(
149 150
                m.get_match_root()->get_arguments().at(const_node_index));
            auto second_node = m.get_match_root()->get_arguments().at(const_node_index);
151 152
            NGRAPH_DEBUG << "second_node = " << second_node->get_name()
                         << " , pattern = " << pattern_map[pattern]->get_name();
153

154 155
            if (pattern_map[pattern]->get_element_type() != const_node->get_element_type() ||
                pattern_map[pattern]->get_shape() != const_node->get_shape())
156
            {
157
                NGRAPH_DEBUG << "Operands' types and/or shape don't match";
158
                return false;
159 160
            }

161
            auto const_values = const_node->get_vector<int>();
162 163 164 165 166
            bool all_zeros =
                std::all_of(begin(const_values), end(const_values), [](int e) { return e == 0; });

            if (!all_zeros)
            {
167
                NGRAPH_DEBUG << "Constant vector's values aren't equal to 0";
168
                return false;
169 170
            }

171
            ngraph::replace_node(m.get_match_root(), pattern_map[pattern]);
172
            return true;
173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190
        };

        auto m = make_shared<TestMatcher>(pattern + iconst0, callback);
        this->add_matcher(m);
    }

    TestGraphRewrite()
        : GraphRewrite()
    {
        construct_multiply_by_one();
        construct_add_zero();
    }
};

static void run_passes(pass::Manager& pass_manager,
                       shared_ptr<Node> graph,
                       std::vector<shared_ptr<op::Parameter>> parms)
{
    // Wrap the single-output graph in a temporary Function so the registered passes can
    // run over it; callers inspect `graph` afterwards for any in-place rewrites.
    const ParameterVector function_params(parms.begin(), parms.end());
    pass_manager.run_passes(make_shared<Function>(graph, function_params));
}

TEST(pattern, graph_rewrite)
{
    Shape shape{};
    pass::Manager pass_manager;
    pass_manager.register_pass<TestGraphRewrite>();

    {
        auto a = make_shared<op::Parameter>(element::i32, shape);
        auto b = make_shared<op::Parameter>(element::i32, shape);
        auto c = make_shared<op::Parameter>(element::i32, shape);
        auto iconst0 = construct_constant_node(0);
        auto graph_a = a + iconst0;
        auto graph_b = b + iconst0;

        auto f = std::make_shared<Function>(ngraph::NodeVector{a, b, graph_a, c, graph_b},
                                            ParameterVector{a, b, c});
        pass_manager.run_passes(f);

        ASSERT_TRUE(graph_a->output(0).get_target_inputs().empty());
        ASSERT_TRUE(graph_b->output(0).get_target_inputs().empty());

        // Each `x + 0` output must now be fed directly by `x`.
        auto expected = ngraph::NodeVector{a, b, a, c, b};
        const auto& results = f->get_results();
        ASSERT_EQ(results.size(), expected.size());
        for (size_t i = 0; i < results.size(); ++i)
        {
            ASSERT_EQ(results.at(i)->get_argument(0), expected.at(i));
        }
        ASSERT_TRUE(count_ops_of_type<op::Add>(f) == 0);
    }

    {
        auto a = make_shared<op::Parameter>(element::i32, shape);
        auto b = make_shared<op::Parameter>(element::i32, shape);
        auto iconst0 = construct_constant_node(0);
        auto sum = (a + iconst0);
        auto graph = b + sum;
        run_passes(pass_manager, graph, {a, b});
        ASSERT_EQ(graph->get_arguments().at(1), a);
        ASSERT_EQ(graph->input(1).get_source_output(),
                  a->output(0)); // graph's input points to a's output
        ASSERT_TRUE(sum->output(0)
                        .get_target_inputs()
                        .empty()); // graph's input is removed from sum's target inputs
        ASSERT_TRUE(a->output(0).get_target_inputs().count(
            graph->input(1))); // a's output feeds into graph's input
    }

    {
        auto a = make_shared<op::Parameter>(element::i32, shape);
        auto b = make_shared<op::Parameter>(element::i32, shape);
        auto iconst1 = construct_constant_node(1);
        auto mul = (a * iconst1);
        auto graph = b + mul;
        run_passes(pass_manager, graph, {a, b});
        ASSERT_EQ(graph->get_arguments().at(1), a);
        ASSERT_EQ(graph->input(1).get_source_output(),
                  a->output(0)); // graph's input points to a's output
        ASSERT_TRUE(mul->output(0)
                        .get_target_inputs()
                        .empty()); // graph's input is removed from mul's target inputs
        ASSERT_TRUE(a->output(0).get_target_inputs().count(
            graph->input(1))); // a's output feeds into graph's input
    }

    {
        auto a = make_shared<op::Parameter>(element::i32, shape);
        auto b = make_shared<op::Parameter>(element::i32, shape);
        auto iconst1 = construct_constant_node(1);
        auto graph = ((((a * iconst1) * iconst1) * iconst1) * iconst1) + b;
        run_passes(pass_manager, graph, {a, b});
        ASSERT_EQ(graph->get_arguments().at(0), a);
        ASSERT_EQ(graph->input(0).get_source_output(),
                  a->output(0)); // graph's input points to a's output
        ASSERT_TRUE(a->output(0).get_target_inputs().count(
            graph->input(0))); // a's output feeds into graph's input
    }

    {
        auto a = make_shared<op::Parameter>(element::i32, shape);
        auto b = make_shared<op::Parameter>(element::i32, shape);
        auto iconst0 = construct_constant_node(0);
        auto iconst1 = construct_constant_node(1);
        auto graph = b + (iconst0 + ((a + iconst0) * iconst1));
        run_passes(pass_manager, graph, {a, b});
        ASSERT_EQ(graph->get_arguments().at(1), a);
        ASSERT_EQ(graph->input(1).get_source_output(),
                  a->output(0)); // graph's input points to a's output
        ASSERT_TRUE(a->output(0).get_target_inputs().count(
            graph->input(1))); // a's output feeds into graph's input
    }

    {
        auto a = make_shared<op::Parameter>(element::i32, shape);
        auto b = make_shared<op::Parameter>(element::i32, shape);
        auto iconst1 = construct_constant_node(1);
        auto graph = b + (iconst1 * (iconst1 * (iconst1 * (iconst1 * a))));
        run_passes(pass_manager, graph, {a, b});
        ASSERT_EQ(graph->get_arguments().at(1), a);
        ASSERT_EQ(graph->input(1).get_source_output(),
                  a->output(0)); // graph's input points to a's output
        ASSERT_TRUE(a->output(0).get_target_inputs().count(
            graph->input(1))); // a's output feeds into graph's input
    }
}

295 296 297 298 299 300 301 302 303 304 305
std::ostream& operator<<(std::ostream& os, const ngraph::NodeVector& nv)
{
    // Print a node list as the formatted list of its node names (used by gtest output).
    std::vector<std::string> node_names;
    node_names.reserve(nv.size());
    for (const auto& node : nv)
    {
        node_names.emplace_back(node->get_name());
    }
    return os << vector_to_string(node_names);
}

306 307
TEST(pattern, matcher)
{
    Shape shape{};
    auto a = make_shared<op::Parameter>(element::i32, shape);
    TestMatcher n(nullptr);
    ASSERT_TRUE(n.match(a, a));
    ASSERT_EQ(n.get_matched_nodes(), (NodeVector{a}));

    auto abs = make_shared<op::Abs>(a);
    auto any = std::make_shared<pattern::op::Skip>(a);
    ASSERT_TRUE(n.match(any, abs));
    ASSERT_EQ(n.get_matched_nodes(), (NodeVector{abs, a}));

    auto false_pred = [](std::shared_ptr<Node>) { return false; };
    auto any_false = std::make_shared<pattern::op::Skip>(a, false_pred);
    ASSERT_TRUE(n.match(any_false, a));
    ASSERT_EQ(n.get_matched_nodes(), (NodeVector{a, a}));

    auto pattern = std::make_shared<pattern::op::Label>(a);
    ASSERT_TRUE(n.match(pattern, a));
    ASSERT_EQ(n.get_pattern_map()[pattern], a);
    ASSERT_EQ(n.get_matched_nodes(), (NodeVector{a}));

    auto pattern_false = std::make_shared<pattern::op::Label>(a, false_pred);
    ASSERT_FALSE(n.match(pattern_false, a));
    ASSERT_EQ(n.get_matched_nodes(), (NodeVector{}));

    auto b = make_shared<op::Parameter>(element::i32, shape);

    auto is_bea = pattern::has_class<op::util::BinaryElementwiseArithmetic>();
    auto bea = std::make_shared<pattern::op::Any>(a, is_bea, NodeVector{a, b});
    auto add_ab = a + b;
    ASSERT_TRUE(n.match(bea, add_ab));
    ASSERT_EQ(n.get_matched_nodes(), (NodeVector{add_ab, a, b}));
    ASSERT_TRUE(n.match(bea, b + a));

    auto bea_false = std::make_shared<pattern::op::Any>(a, false_pred, NodeVector{a, b});
    ASSERT_FALSE(n.match(bea_false, a + b));

    auto add_abs_b = abs + b;
    auto bea_any_of = std::make_shared<pattern::op::AnyOf>(a, is_bea, NodeVector{abs});
    ASSERT_TRUE(n.match(bea_any_of, add_abs_b));

    auto add_b_abs = b + abs;
    ASSERT_TRUE(n.match(bea_any_of, add_b_abs));

    auto bea_any_of_label =
        std::make_shared<pattern::op::Label>(a, nullptr, NodeVector{bea_any_of});
    ASSERT_TRUE(n.match(bea_any_of_label, add_b_abs));
    ASSERT_EQ(n.get_pattern_map()[bea_any_of_label], add_b_abs);

    auto abs_label = std::make_shared<pattern::op::Label>(a, nullptr, NodeVector{abs});
    auto bea_label_any_of = std::make_shared<pattern::op::AnyOf>(a, is_bea, NodeVector{abs_label});
    ASSERT_TRUE(n.match(bea_label_any_of, add_b_abs));
    ASSERT_EQ(n.get_pattern_map()[abs_label], abs);

    auto bea_label = std::make_shared<pattern::op::Label>(a, nullptr, NodeVector{bea});
    auto ab = a + b;
    ASSERT_TRUE(n.match(bea_label, ab));
    ASSERT_EQ(n.get_pattern_map()[bea_label], ab);

    auto d = make_shared<op::Parameter>(element::i32, shape);
    ASSERT_FALSE(n.match(d, b));

    ASSERT_FALSE(n.match(abs + b, b + b));
    ASSERT_EQ(n.get_matched_nodes(), (NodeVector{}));

    auto add_absb = abs + b;
    ASSERT_TRUE(n.match(any + b, add_absb));
    ASSERT_EQ(n.get_matched_nodes(), (NodeVector{add_absb, abs, a, b}));

    ASSERT_TRUE(n.match(pattern + b, add_absb));
    ASSERT_EQ(n.get_pattern_map()[pattern], abs);
    ASSERT_EQ(n.get_matched_nodes(), (NodeVector{add_absb, abs, b}));

    ASSERT_TRUE(n.match(b + pattern, add_absb));
    ASSERT_EQ(n.get_pattern_map()[pattern], abs);
    ASSERT_EQ(n.get_matched_nodes(), (NodeVector{add_absb, abs, b}));

    auto c = make_shared<op::Parameter>(element::i32, shape);
    auto mul_add_absb = c * (add_absb);
    ASSERT_TRUE(n.match(c * (b + pattern), mul_add_absb));
    ASSERT_EQ(n.get_pattern_map()[pattern], abs);
    ASSERT_EQ(n.get_matched_nodes(), (NodeVector{mul_add_absb, c, add_absb, abs, b}));

    ASSERT_TRUE(n.match(c * (any + b), mul_add_absb)); // nested any
    ASSERT_EQ(n.get_matched_nodes(), (NodeVector{mul_add_absb, c, add_absb, abs, a, b}));
    ASSERT_TRUE(n.match(c * (any + b), (b + abs) * c)); // permutations w/ any
    auto mul_c_add_ab = c * add_ab;
    ASSERT_TRUE(n.match(c * (any_false + b), c * (a + b)));  // nested any_false
    ASSERT_TRUE(n.match(c * (any_false + b), mul_c_add_ab)); // permutations w/ any_false
    ASSERT_EQ(n.get_matched_nodes(), (NodeVector{mul_c_add_ab, c, add_ab, a, a, b}));

    auto iconst1_0 = construct_constant_node(1);
    auto iconst1_1 = construct_constant_node(1);
    ASSERT_TRUE(n.match(pattern * iconst1_0, a * iconst1_1)); // different iconst
    ASSERT_EQ(n.get_pattern_map()[pattern], a);
    auto fconst1_0 = op::Constant::create(element::f32, shape, {1});
    auto patternf = std::make_shared<pattern::op::Label>(fconst1_0);
    // f32 pattern vs i32 graph: element types differ, yet the (non-strict) matcher accepts it
    ASSERT_TRUE(n.match(patternf * fconst1_0, a * iconst1_1));

    // Subgraph labels
    auto add = a + b;
    auto label = std::make_shared<pattern::op::Label>(add, nullptr, NodeVector{add});
    ASSERT_TRUE(n.match(label, add));
    ASSERT_EQ(n.get_pattern_map()[label], add);
    ASSERT_EQ(n.get_matched_nodes(), (NodeVector{add, add, a, b}));

    ASSERT_FALSE(n.match(label, a - b));

    ASSERT_TRUE(n.match(make_shared<op::Abs>(label), make_shared<op::Abs>(add)));
    ASSERT_EQ(n.get_pattern_map()[label], add);

    // Correct argument order
    ASSERT_FALSE(n.match(b - a, a - b));
    auto aab = a * (a - b);
    auto paab = pattern * (pattern - b);
    ASSERT_TRUE(n.match(paab, aab));
    auto aba = a * (b - a);
    ASSERT_FALSE(n.match(paab, aba));
    auto paba = pattern * (b - pattern);
    ASSERT_FALSE(n.match(paba, aab));

    // Correlations
    auto label1 = std::make_shared<pattern::op::Label>(a);
    auto tmp = label1 + b;
    auto label2 = std::make_shared<pattern::op::Label>(tmp, nullptr, NodeVector{tmp});
    auto sub_label1 = label1 - label2;
    auto sub_add = a - add;
    ASSERT_TRUE(n.match(sub_label1, sub_add));
    ASSERT_EQ(n.get_pattern_map()[label1], a);
    ASSERT_EQ(n.get_pattern_map()[label2], add);
    ASSERT_EQ(n.get_matched_nodes(), (NodeVector{sub_add, a, add, add, a, b}));

    ASSERT_FALSE(n.match(sub_label1, add - a));

    auto add_label1 = label1 + label2;
    ASSERT_TRUE(n.match(add_label1, add + a));
    ASSERT_EQ(n.get_pattern_map()[label1], a);
    ASSERT_EQ(n.get_pattern_map()[label2], add);

    // strict mode
    {
        TestMatcher sm(nullptr, nullptr, "TestMatcher", pass::PassProperty::REGULAR_FUSIONS, true);
        // a dynamic-shape i32 label matches a static i32 scalar
        auto scalar_param = make_shared<op::Parameter>(element::i32, Shape{});
        auto label_dynamic_shape =
            make_shared<pattern::op::Label>(element::i32, PartialShape::dynamic());
        ASSERT_TRUE(sm.match(label_dynamic_shape, scalar_param));
        // wrong type
        auto scalar_param_wrong_type = make_shared<op::Parameter>(element::f32, Shape{});
        ASSERT_FALSE(sm.match(label, scalar_param_wrong_type));
        // the same shape-agnostic i32 label must reject an f32 node purely on element type
        ASSERT_FALSE(sm.match(label_dynamic_shape, scalar_param_wrong_type));
        // dynamic dimension
        auto label_dynamic_dimension =
            make_shared<pattern::op::Label>(element::i32, PartialShape{Dimension::dynamic()});
        auto vector_param = make_shared<op::Parameter>(element::i32, Shape{10});
        ASSERT_TRUE(sm.match(label_dynamic_dimension, vector_param));
        // dynamic type
        auto label_dynamic_type =
            make_shared<pattern::op::Label>(element::dynamic, PartialShape{Dimension::dynamic()});
        ASSERT_TRUE(sm.match(label_dynamic_type, vector_param));
    }
}
470

471 472
TEST(pattern, mean)
{
    // A concrete sum-over-axis-0 / N graph must match the mean pattern, with the
    // pattern's root label bound to the graph's final Divide.
    TestMatcher n(nullptr);

    auto data = std::make_shared<op::Parameter>(element::f32, Shape{2, 3});
    auto divisor = op::Constant::create(element::f32, Shape{3}, {2, 2, 2});
    auto column_sums = std::make_shared<op::Sum>(data, AxisSet{0});
    auto mean = std::make_shared<op::Divide>(column_sums, divisor);

    auto mean_pattern = construct_mean_graph();
    ASSERT_TRUE(n.match(mean_pattern, mean));
    ASSERT_EQ(n.get_pattern_map()[mean_pattern], mean);
}

TEST(pattern, variance)
{
    // Build the variance expression by hand and check that the variance pattern
    // matches it, binding the pattern's root label to the final Divide.
    TestMatcher n(nullptr);
    auto divisor = op::Constant::create(element::f32, Shape{3}, {2, 2, 2});
    auto x = std::make_shared<pattern::op::Label>(element::f32, Shape{2, 3});
    auto x_times_x = std::make_shared<op::Multiply>(x, x);
    auto sum_x = std::make_shared<op::Sum>(x, AxisSet{0});
    auto sum_x_squared = std::make_shared<op::Multiply>(sum_x, sum_x);
    auto sum_of_squares = std::make_shared<op::Sum>(x_times_x, AxisSet{0});
    auto scaled_square_of_sum = std::make_shared<op::Divide>(sum_x_squared, divisor);
    auto numerator = std::make_shared<op::Subtract>(sum_of_squares, scaled_square_of_sum);
    auto variance = std::make_shared<op::Divide>(numerator, divisor);

    auto variance_pattern = construct_variance_graph();
    ASSERT_TRUE(n.match(variance_pattern, variance));
    ASSERT_EQ(n.get_pattern_map()[variance_pattern], variance);
}
504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529

TEST(pattern, previous_matches)
{
    using ngraph::pattern::Matcher;
    Shape shape{};
    Matcher::PatternMap previous_matches;
    auto a = make_shared<op::Parameter>(element::i32, shape);
    auto b = make_shared<op::Parameter>(element::i32, shape);
    auto pattern = std::make_shared<pattern::op::Label>(b);
    auto abs = make_shared<op::Abs>(a);
    auto add = abs + b;

    // With no prior bindings the label is free to bind to `abs`.
    {
        Matcher matcher(pattern + b);
        ASSERT_TRUE(matcher.match(add, previous_matches));
        ASSERT_EQ(matcher.get_pattern_map()[pattern], abs);
    }

    // Pre-binding the label to `a` must make the same graph fail to match.
    previous_matches.insert(std::make_pair(pattern, a));
    {
        Matcher matcher(pattern + b);
        ASSERT_FALSE(matcher.match(add, previous_matches));
    }
}

TEST(pattern, recurrent_pattern)
{
    using ngraph::pattern::RecurrentMatcher;
    Shape shape{};
    auto b = make_shared<op::Parameter>(element::i32, shape);
    auto rpattern = std::make_shared<pattern::op::Label>(b);
    auto iconst0 = construct_constant_node(0);
    auto add1 = iconst0 + b;
    auto add2 = iconst0 + add1;
    auto add3 = iconst0 + add2;
    auto padd = iconst0 + rpattern;
    std::set<std::shared_ptr<pattern::op::Label>> empty_correlated_matches;
    RecurrentMatcher rm(padd, rpattern, empty_correlated_matches, nullptr);
    ASSERT_TRUE(rm.match(add3));
    ASSERT_EQ(rm.get_number_of_bound_labels(), 1);
    auto recurrent_matches = rm.get_bound_nodes_for_pattern(rpattern);
    ASSERT_EQ(recurrent_matches.at(0), add2);
    ASSERT_EQ(recurrent_matches.at(1), add1);
    ASSERT_EQ(recurrent_matches.at(2), b);

    // Multiple labels in a recurring pattern
    auto iconst1 = construct_constant_node(1);
    auto iconst_label = std::make_shared<pattern::op::Label>(iconst1, nullptr, NodeVector{iconst1});
    auto add2_2 = iconst1 + add1;
    auto add3_2 = iconst0 + add2_2;
    auto padd2 = iconst_label + rpattern;
    RecurrentMatcher rm2(padd2, rpattern, empty_correlated_matches, nullptr);
    ASSERT_TRUE(rm2.match(add3_2));
    ASSERT_EQ(rm2.get_number_of_bound_labels(), 2);
    recurrent_matches = rm2.get_bound_nodes_for_pattern(rpattern);
    ASSERT_EQ(recurrent_matches.at(0), add2_2);
    ASSERT_EQ(recurrent_matches.at(1), add1);
    ASSERT_EQ(recurrent_matches.at(2), b);
    auto iconst_matches = rm2.get_bound_nodes_for_pattern(iconst_label);
    ASSERT_EQ(iconst_matches.at(0), iconst0);
    ASSERT_EQ(iconst_matches.at(1), iconst1);
    ASSERT_EQ(iconst_matches.at(2), iconst0);

    // Non-matching correlated labels
    std::set<std::shared_ptr<pattern::op::Label>> correlated_matches;
    correlated_matches.insert(iconst_label);
    RecurrentMatcher rm3(padd2, rpattern, correlated_matches, nullptr);
    ASSERT_TRUE(rm3.match(add3_2));
    ASSERT_EQ(rm3.get_number_of_bound_labels(), 2);
    iconst_matches = rm3.get_bound_nodes_for_pattern(iconst_label);
    ASSERT_EQ(iconst_matches.size(), 1);
    ASSERT_EQ(iconst_matches.at(0), iconst0);

    // Matching correlated labels and
    // testing if RecurrentMatcher can be reused for different nodes
    ASSERT_TRUE(rm3.match(add3));
    ASSERT_EQ(rm3.get_number_of_bound_labels(), 2);
    recurrent_matches = rm3.get_bound_nodes_for_pattern(rpattern);
    ASSERT_EQ(recurrent_matches.at(0), add2);
    ASSERT_EQ(recurrent_matches.at(1), add1);
    ASSERT_EQ(recurrent_matches.at(2), b);
    iconst_matches = rm3.get_bound_nodes_for_pattern(iconst_label);
    ASSERT_EQ(iconst_matches.at(0), iconst0);
    ASSERT_EQ(iconst_matches.at(1), iconst0);
    ASSERT_EQ(iconst_matches.at(2), iconst0);
}
592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612

class TestRecurrentGraphRewrite : public ngraph::pass::RecurrentGraphRewrite
{
public:
    // Registers a recurrent matcher for `label + rpattern`, where the label wraps an i32
    // constant. When every constant bound across the recurrence is zero, the topmost Add
    // is replaced by the innermost operand bound to `rpattern`.
    void construct_recurrent_add()
    {
        Shape shape{};
        auto iconst0 = construct_constant_node(0);
        auto iconst_label =
            std::make_shared<pattern::op::Label>(iconst0, nullptr, NodeVector{iconst0});
        auto rpattern = std::make_shared<pattern::op::Label>(element::i32, shape);
        auto padd = iconst_label + rpattern;

        ngraph::pattern::recurrent_graph_rewrite_callback callback = [iconst_label, rpattern](
            pattern::RecurrentMatcher& rm) {
            NGRAPH_DEBUG << "In a callback for construct_recurrent_add against "
                         << rm.get_match_root()->get_name();

            auto iconst_matches = rm.get_bound_nodes_for_pattern(iconst_label);

            auto is_iconst_zero = [](std::shared_ptr<Node> n) {
                // Evaluate once and reuse the result for both the log line and the return.
                bool result = ngraph::is_zero(n);
                NGRAPH_DEBUG << n->get_name() << " is " << (result ? " a zero " : " not a zero");
                return result;
            };

            bool are_all_iconst_zeros =
                std::all_of(iconst_matches.begin(), iconst_matches.end(), is_iconst_zero);

            if (!are_all_iconst_zeros)
            {
                return false;
            }

            auto number_of_adds = rm.get_number_of_recurrent_matches();
            // replace the topmost add with the seed (i.e. the first parameter to add)
            // matches are added in reverse order (i.e. the first match is the topmost node)
            auto arg = rm.get_bound_nodes_for_pattern(rpattern).at(number_of_adds - 1);
            NGRAPH_DEBUG << "Replacing " << rm.get_match_root()->get_name() << " with "
                         << arg->get_name();
            ngraph::replace_node(rm.get_match_root(), arg);
            return true;
        };

        std::set<std::shared_ptr<pattern::op::Label>> empty_correlated_matches;
        auto rm = make_shared<pattern::RecurrentMatcher>(
            padd, rpattern, empty_correlated_matches, callback);
        this->add_matcher(rm);
    }

    TestRecurrentGraphRewrite()
        : RecurrentGraphRewrite()
    {
        construct_recurrent_add();
    }
};

TEST(pattern, recurrent_graph_rewrite)
{
    Shape shape{};
    pass::Manager pass_manager;
    pass_manager.register_pass<TestRecurrentGraphRewrite>();

    {
        // Left branch: Abs(((a + 0) + 0) + 0)
        auto a = make_shared<op::Parameter>(element::i32, shape);
        auto zero = construct_constant_node(0);
        auto a_plus_0 = a + zero;
        auto a_plus_00 = a_plus_0 + zero;
        auto a_plus_000 = a_plus_00 + zero;
        auto left = std::make_shared<op::Abs>(a_plus_000);

        // Right branch: Abs((b + 0) + 0)
        auto b = make_shared<op::Parameter>(element::i32, shape);
        auto b_plus_0 = b + zero;
        auto b_plus_00 = b_plus_0 + zero;
        auto right = std::make_shared<op::Abs>(b_plus_00);

        auto graph = left * right;

        auto f = std::make_shared<Function>(ngraph::NodeVector{graph}, ParameterVector{a, b});
        pass_manager.run_passes(f);

        // After the rewrite each Abs must consume its parameter directly.
        ASSERT_EQ(graph->get_argument(0)->get_argument(0), a);
        ASSERT_EQ(graph->get_argument(1)->get_argument(0), b);
    }
}
Nick Korovaiko's avatar
Nick Korovaiko committed
682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713

TEST(pattern, label_on_skip)
{
    Shape shape{2, 2};
    auto a = make_shared<op::Parameter>(element::i32, shape);
    auto b = make_shared<op::Parameter>(element::i32, Shape{});
    auto zero = ngraph::make_zero(element::i32, Shape{});
    auto any_label = std::make_shared<pattern::op::Label>(zero);
    auto zero_label =
        std::make_shared<pattern::op::Label>(zero, ngraph::is_zero, NodeVector{zero});

    // A Skip over Broadcast nodes lets the same pattern match both a broadcast zero
    // and a bare scalar zero, as the two match() calls below check.
    auto is_broadcast = [](std::shared_ptr<Node> n) {
        return std::dynamic_pointer_cast<op::Broadcast>(n) != nullptr;
    };
    auto skip_bcst = std::make_shared<pattern::op::Skip>(zero_label, is_broadcast);
    auto skip_label =
        std::make_shared<pattern::op::Label>(skip_bcst, nullptr, NodeVector{skip_bcst});
    auto matcher = std::make_shared<pattern::Matcher>(
        std::make_shared<op::Multiply>(any_label, skip_label), nullptr);

    auto const_broadcast = make_shared<op::Broadcast>(zero, shape, AxisSet{0, 1});
    auto mul = a * const_broadcast;
    auto mul_scalar = b * zero;

    ASSERT_TRUE(matcher->match(mul));
    ASSERT_EQ(matcher->get_pattern_map()[skip_label], const_broadcast);
    ASSERT_EQ(matcher->get_pattern_map()[zero_label], zero);
    ASSERT_EQ(matcher->get_pattern_map()[any_label], a);

    ASSERT_TRUE(matcher->match(mul_scalar));
    ASSERT_EQ(matcher->get_pattern_map()[skip_label], zero);
    ASSERT_EQ(matcher->get_pattern_map()[zero_label], zero);
    ASSERT_EQ(matcher->get_pattern_map()[any_label], b);
}
Nick Korovaiko's avatar
Nick Korovaiko committed
714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733

TEST(pattern, is_contained_match)
{
    Shape shape{};
    auto a = make_shared<op::Parameter>(element::i32, shape);
    auto absn = make_shared<op::Abs>(a);
    TestMatcher n(nullptr);

    // Concrete (label-free) pattern Abs(a) matched against the graph node Abs(a).
    auto abs_pattern = make_shared<op::Abs>(a);
    ASSERT_TRUE(n.match(abs_pattern, absn));
    // Give absn a consumer before querying containment.
    auto result_absn = make_shared<op::Result>(absn);
    ASSERT_TRUE(n.is_contained_match());

    // Abs(Abs(a)): the inner absn now also feeds result_absn outside the matched
    // chain, which is presumably why this match is reported as not contained.
    auto absn2 = make_shared<op::Abs>(absn);
    auto result_absn2 = make_shared<op::Result>(absn2);
    auto abs_pattern2 = make_shared<op::Abs>(abs_pattern);
    ASSERT_TRUE(n.match(abs_pattern2, absn2));
    ASSERT_FALSE(n.is_contained_match());
}