pattern.cpp 30.3 KB
Newer Older
1
//*****************************************************************************
2
// Copyright 2017-2020 Intel Corporation
3 4 5 6 7 8 9 10 11 12 13 14 15
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
16 17 18 19 20 21 22 23

#include <algorithm>
#include <cstdio>
#include <iostream>
#include <list>
#include <memory>

#include "gtest/gtest.h"
24
#include "ngraph/file_util.hpp"
varun-intel's avatar
varun-intel committed
25
#include "ngraph/graph_util.hpp"
26 27
#include "ngraph/log.hpp"
#include "ngraph/ngraph.hpp"
28 29 30 31 32 33 34 35 36
#include "ngraph/op/add.hpp"
#include "ngraph/op/batch_norm.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/sqrt.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/op/sum.hpp"
37 38 39
#include "ngraph/pass/graph_rewrite.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pattern/matcher.hpp"
40
#include "ngraph/pattern/op/branch.hpp"
41
#include "ngraph/pattern/op/label.hpp"
42
#include "ngraph/pattern/op/or.hpp"
43
#include "ngraph/pattern/op/skip.hpp"
44
#include "ngraph/pattern/op/true.hpp"
45
#include "ngraph/serializer.hpp"
46
#include "util/matcher.hpp"
47
#include "util/test_tools.hpp"
48 49 50 51 52 53

using namespace ngraph;
using namespace std;

static std::shared_ptr<Node> construct_constant_node(int n)
{
54 55 56
    return op::Constant::create(element::i32, Shape{}, {n});
}

57 58 59 60 61 62 63 64 65 66 67 68
static std::shared_ptr<pattern::op::Label> construct_variance_graph()
{
    // construct varaiance
    auto N = op::Constant::create(element::f32, Shape{3}, {2, 2, 2});
    auto input = std::make_shared<pattern::op::Label>(element::f32, Shape{2, 3});
    auto input_sq = std::make_shared<op::Multiply>(input, input);
    auto sum_input = std::make_shared<op::Sum>(input, AxisSet{0});
    auto square_sumed_input = std::make_shared<op::Multiply>(sum_input, sum_input);
    auto sum_squared_input = std::make_shared<op::Sum>(input_sq, AxisSet{0});
    auto avg_input_sum_sq = std::make_shared<op::Divide>(square_sumed_input, N);
    auto xmu = std::make_shared<op::Subtract>(sum_squared_input, avg_input_sum_sq);
    auto variance = std::make_shared<op::Divide>(xmu, N);
69 70
    auto variance_label =
        std::make_shared<pattern::op::Label>(variance, nullptr, NodeVector{variance});
71 72 73 74 75 76

    return variance_label;
}

static std::shared_ptr<pattern::op::Label> construct_mean_graph()
{
77
    // construct mean;
78 79 80 81
    auto input = std::make_shared<pattern::op::Label>(element::f32, Shape{2, 3});
    auto N = op::Constant::create(element::f32, Shape{3}, {2, 2, 2});
    auto sum_input1 = std::make_shared<op::Sum>(input, AxisSet{0});
    auto mean = std::make_shared<op::Divide>(sum_input1, N);
82
    auto mean_label = std::make_shared<pattern::op::Label>(mean, nullptr, NodeVector{mean});
83 84 85
    return mean_label;
}

86 87 88 89 90
class TestGraphRewrite : public ngraph::pass::GraphRewrite
{
public:
    void construct_multiply_by_one()
    {
91
        // pattern #1 : a * 1 = a
92
        auto iconst1 = construct_constant_node(1);
93
        auto pattern = std::make_shared<pattern::op::Label>(iconst1);
94

95
        auto callback = [pattern](pattern::Matcher& m) {
96
            NGRAPH_DEBUG << "In a callback for construct_multiply_by_one against "
97
                         << m.get_match_root()->get_name();
98
            NGRAPH_CHECK(m.get_match_root()->get_arguments().size() == 2);
99

100 101
            auto pattern_map = m.get_pattern_map();

102 103
            size_t const_node_index =
                m.get_match_root()->get_arguments().at(0) == pattern_map[pattern];
104 105
            auto const_node =
                as_type_ptr<op::Constant>(m.get_match_root()->get_arguments().at(const_node_index));
106
            auto second_node = m.get_match_root()->get_arguments().at(const_node_index);
107 108
            NGRAPH_DEBUG << "second_node = " << second_node->get_name()
                         << " , pattern = " << pattern_map[pattern]->get_name();
109

110 111
            if (pattern_map[pattern]->get_element_type() != const_node->get_element_type() ||
                pattern_map[pattern]->get_shape() != const_node->get_shape())
112
            {
113
                NGRAPH_DEBUG << "Operands' types and/or shape don't match";
114
                return false;
115 116
            }

117
            auto const_values = const_node->get_vector<int32_t>();
118 119 120 121 122
            bool all_ones =
                std::all_of(begin(const_values), end(const_values), [](int e) { return e == 1; });

            if (!all_ones)
            {
123
                NGRAPH_DEBUG << "Constant vector's values aren't equal to 1";
124
                return false;
125
            }
126

127
            ngraph::replace_node(m.get_match_root(), pattern_map[pattern]);
128
            return true;
129 130
        };

131 132
        auto m = make_shared<TestMatcher>(pattern * iconst1);
        this->add_matcher(m, callback);
133 134 135 136
    }

    void construct_add_zero()
    {
137
        // pattern #2 : a + 0 = a
138
        auto iconst0 = construct_constant_node(0);
139
        auto pattern = std::make_shared<pattern::op::Label>(iconst0);
140

141
        auto callback = [pattern](pattern::Matcher& m) {
142
            NGRAPH_DEBUG << "In a callback for construct_add_zero against "
143
                         << m.get_match_root()->get_name();
144
            NGRAPH_CHECK(m.get_match_root()->get_arguments().size() == 2);
145

146 147
            auto pattern_map = m.get_pattern_map();

148 149
            size_t const_node_index =
                m.get_match_root()->get_arguments().at(0) == pattern_map[pattern];
Scott Cyphers's avatar
Scott Cyphers committed
150 151
            auto const_node =
                as_type_ptr<op::Constant>(m.get_match_root()->get_arguments().at(const_node_index));
152
            auto second_node = m.get_match_root()->get_arguments().at(const_node_index);
153 154
            NGRAPH_DEBUG << "second_node = " << second_node->get_name()
                         << " , pattern = " << pattern_map[pattern]->get_name();
155

156 157
            if (pattern_map[pattern]->get_element_type() != const_node->get_element_type() ||
                pattern_map[pattern]->get_shape() != const_node->get_shape())
158
            {
159
                NGRAPH_DEBUG << "Operands' types and/or shape don't match";
160
                return false;
161 162
            }

163
            auto const_values = const_node->get_vector<int>();
164 165 166 167 168
            bool all_zeros =
                std::all_of(begin(const_values), end(const_values), [](int e) { return e == 0; });

            if (!all_zeros)
            {
169
                NGRAPH_DEBUG << "Constant vector's values aren't equal to 0";
170
                return false;
171 172
            }

173
            ngraph::replace_node(m.get_match_root(), pattern_map[pattern]);
174
            return true;
175 176
        };

177 178 179
        auto add = pattern + iconst0;
        auto m = make_shared<TestMatcher>(add);
        this->add_matcher(m, callback);
180 181 182 183 184 185 186 187 188 189 190 191 192 193
    }

    TestGraphRewrite()
        : GraphRewrite()
    {
        construct_multiply_by_one();
        construct_add_zero();
    }
};

// Wraps `graph` (with parameters `parms`) in a Function and runs every pass
// registered on `pass_manager` over it.
static void run_passes(pass::Manager& pass_manager,
                       shared_ptr<Node> graph,
                       std::vector<shared_ptr<op::Parameter>> parms)
{
    pass_manager.run_passes(make_shared<Function>(graph, ParameterVector{parms}));
}

// Runs TestGraphRewrite (a + 0 => a, a * 1 => a) over several small graphs and
// checks that the identity operations are bypassed and that the consumer's
// input is rewired to the surviving operand.
TEST(pattern, graph_rewrite)
{
    Shape shape{};
    pass::Manager pass_manager;
    pass_manager.register_pass<TestGraphRewrite>();

    // Two "x + 0" results in one function: no Add may remain afterwards.
    {
        auto a = make_shared<op::Parameter>(element::i32, shape);
        auto b = make_shared<op::Parameter>(element::i32, shape);
        auto c = make_shared<op::Parameter>(element::i32, shape);
        auto iconst0 = construct_constant_node(0);
        auto graph_a = a + iconst0;
        auto graph_b = b + iconst0;

        auto f = std::make_shared<Function>(ngraph::NodeVector{a, b, graph_a, c, graph_b},
                                            ParameterVector{a, b, c});
        pass_manager.run_passes(f);

        ASSERT_TRUE(graph_a->output(0).get_target_inputs().empty());
        ASSERT_TRUE(graph_b->output(0).get_target_inputs().empty());

        ASSERT_TRUE(count_ops_of_type<op::Add>(f) == 0);
    }

    // b + (a + 0)  =>  b + a
    {
        auto a = make_shared<op::Parameter>(element::i32, shape);
        auto b = make_shared<op::Parameter>(element::i32, shape);
        auto iconst0 = construct_constant_node(0);
        auto sum = (a + iconst0);
        auto graph = b + sum;
        run_passes(pass_manager, graph, {a, b});
        ASSERT_EQ(graph->get_arguments().at(1), a);
        ASSERT_EQ(graph->input(1).get_source_output(),
                  a->output(0)); // graph's input points to a's output
        ASSERT_TRUE(sum->output(0)
                        .get_target_inputs()
                        .empty()); // graph's input is removed from sum's target inputs
        ASSERT_TRUE(a->output(0).get_target_inputs().count(
            graph->input(1))); // a's output feeds into graph's input
    }

    // b + (a * 1)  =>  b + a
    {
        auto a = make_shared<op::Parameter>(element::i32, shape);
        auto b = make_shared<op::Parameter>(element::i32, shape);
        auto iconst1 = construct_constant_node(1);
        auto mul = (a * iconst1);
        auto graph = b + mul;
        run_passes(pass_manager, graph, {a, b});
        ASSERT_EQ(graph->get_arguments().at(1), a);
        ASSERT_EQ(graph->input(1).get_source_output(),
                  a->output(0)); // graph's input points to a's output
        ASSERT_TRUE(mul->output(0)
                        .get_target_inputs()
                        .empty()); // graph's input is removed from mul's target inputs
        ASSERT_TRUE(a->output(0).get_target_inputs().count(
            graph->input(1))); // a's output feeds into graph's input
    }

    // Chain of multiplications by one on the left operand.
    {
        auto a = make_shared<op::Parameter>(element::i32, shape);
        auto b = make_shared<op::Parameter>(element::i32, shape);
        auto iconst1 = construct_constant_node(1);
        auto graph = ((((a * iconst1) * iconst1) * iconst1) * iconst1) + b;
        run_passes(pass_manager, graph, {a, b});
        ASSERT_EQ(graph->get_arguments().at(0), a);
        ASSERT_EQ(graph->input(0).get_source_output(),
                  a->output(0)); // graph's input points to a's output
        ASSERT_TRUE(a->output(0).get_target_inputs().count(
            graph->input(0))); // a's output feeds into graph's input
    }

    // Mixed "+ 0" and "* 1" identities nested in the right operand.
    {
        auto a = make_shared<op::Parameter>(element::i32, shape);
        auto b = make_shared<op::Parameter>(element::i32, shape);
        auto iconst0 = construct_constant_node(0);
        auto iconst1 = construct_constant_node(1);
        auto graph = b + (iconst0 + ((a + iconst0) * iconst1));
        run_passes(pass_manager, graph, {a, b});
        ASSERT_EQ(graph->get_arguments().at(1), a);
        ASSERT_EQ(graph->input(1).get_source_output(),
                  a->output(0)); // graph's input points to a's output
        ASSERT_TRUE(a->output(0).get_target_inputs().count(
            graph->input(1))); // a's output feeds into graph's input
    }

    // Chain of multiplications by one with the constant as the first argument.
    {
        auto a = make_shared<op::Parameter>(element::i32, shape);
        auto b = make_shared<op::Parameter>(element::i32, shape);
        auto iconst1 = construct_constant_node(1);
        auto graph = b + (iconst1 * (iconst1 * (iconst1 * (iconst1 * a))));
        run_passes(pass_manager, graph, {a, b});
        ASSERT_EQ(graph->get_arguments().at(1), a);
        ASSERT_EQ(graph->input(1).get_source_output(),
                  a->output(0)); // graph's input points to a's output
        ASSERT_TRUE(a->output(0).get_target_inputs().count(
            graph->input(1))); // a's output feeds into graph's input
    }
}

// Exercises TestMatcher against Skip, Label, Any, AnyOf, Or and Branch patterns,
// checking both the match result and (where asserted) the bound labels and the
// recorded list of matched nodes. The single matcher instance `tm` is reused
// throughout, so the order of the statements below is significant.
TEST(pattern, matcher)
{
    Shape shape{};
    auto a = make_shared<op::Parameter>(element::i32, shape);
    TestMatcher tm;

    // A node used as its own pattern.
    ASSERT_TRUE(tm.match(a, a));
    ASSERT_EQ(tm.get_matched_nodes(), (NodeVector{a}));

    // Skip wrapped around `a`, matched against Abs(a).
    auto abs = make_shared<op::Abs>(a);
    auto any = std::make_shared<pattern::op::Skip>(a);
    ASSERT_TRUE(tm.match(any, abs));
    ASSERT_EQ(tm.get_matched_nodes(), (NodeVector{abs, a}));

    // Skip whose predicate always rejects, matched directly against `a`.
    auto never = [](std::shared_ptr<Node> /* no */) { return false; };
    auto any_false = std::make_shared<pattern::op::Skip>(a, never);
    ASSERT_TRUE(tm.match(any_false, a));
    ASSERT_EQ(tm.get_matched_nodes(), (NodeVector{a, a}));

    // Plain label: binds to the graph node it is matched against.
    auto pattern = std::make_shared<pattern::op::Label>(a);
    ASSERT_TRUE(tm.match(pattern, a));
    ASSERT_EQ(tm.get_pattern_map()[pattern], a);
    ASSERT_EQ(tm.get_matched_nodes(), (NodeVector{a}));

    // Label whose predicate always rejects: no match, nothing recorded.
    auto pattern_false = std::make_shared<pattern::op::Label>(a, never);
    ASSERT_FALSE(tm.match(pattern_false, a));
    ASSERT_EQ(tm.get_matched_nodes(), (NodeVector{}));

    auto b = make_shared<op::Parameter>(element::i32, shape);

    // Any: predicate on the root plus patterns for its arguments.
    auto is_binary_arith = [](std::shared_ptr<Node> node) -> bool {
        return node->is_binary_elementwise_arithmetic();
    };
    auto bea = std::make_shared<pattern::op::Any>(a, is_binary_arith, NodeVector{a, b});
    auto add_ab = a + b;
    ASSERT_TRUE(tm.match(bea, add_ab));
    ASSERT_EQ(tm.get_matched_nodes(), (NodeVector{add_ab, a, b}));
    ASSERT_TRUE(tm.match(bea, b + a));

    auto bea_false = std::make_shared<pattern::op::Any>(a, never, NodeVector{a, b});
    ASSERT_FALSE(tm.match(bea_false, a + b));

    // AnyOf: a single argument pattern, tried against both operand orders.
    auto add_abs_b = abs + b;
    auto bea_any_of = std::make_shared<pattern::op::AnyOf>(a, is_binary_arith, NodeVector{abs});
    ASSERT_TRUE(tm.match(bea_any_of, add_abs_b));

    auto add_b_abs = b + abs;
    ASSERT_TRUE(tm.match(bea_any_of, add_b_abs));

    auto bea_any_of_label =
        std::make_shared<pattern::op::Label>(a, nullptr, NodeVector{bea_any_of});
    ASSERT_TRUE(tm.match(bea_any_of_label, add_b_abs));
    ASSERT_EQ(tm.get_pattern_map()[bea_any_of_label], add_b_abs);

    auto abs_label = std::make_shared<pattern::op::Label>(a, nullptr, NodeVector{abs});
    auto bea_label_any_of =
        std::make_shared<pattern::op::AnyOf>(a, is_binary_arith, NodeVector{abs_label});
    ASSERT_TRUE(tm.match(bea_label_any_of, add_b_abs));
    ASSERT_EQ(tm.get_pattern_map()[abs_label], abs);

    auto bea_label = std::make_shared<pattern::op::Label>(a, nullptr, NodeVector{bea});
    auto ab = a + b;
    ASSERT_TRUE(tm.match(bea_label, ab));
    ASSERT_EQ(tm.get_pattern_map()[bea_label], ab);

    // Distinct parameters do not match each other.
    auto d = make_shared<op::Parameter>(element::i32, shape);
    ASSERT_FALSE(tm.match(d, b));

    ASSERT_FALSE(tm.match(abs + b, b + b));
    ASSERT_EQ(tm.get_matched_nodes(), (NodeVector{}));

    // Skip and labels used as operands of larger patterns.
    auto add_absb = abs + b;
    ASSERT_TRUE(tm.match(any + b, add_absb));
    ASSERT_EQ(tm.get_matched_nodes(), (NodeVector{add_absb, abs, a, b}));

    ASSERT_TRUE(tm.match(pattern + b, add_absb));
    ASSERT_EQ(tm.get_pattern_map()[pattern], abs);
    ASSERT_EQ(tm.get_matched_nodes(), (NodeVector{add_absb, abs, b}));

    ASSERT_TRUE(tm.match(b + pattern, add_absb));
    ASSERT_EQ(tm.get_pattern_map()[pattern], abs);
    ASSERT_EQ(tm.get_matched_nodes(), (NodeVector{add_absb, abs, b}));

    auto c = make_shared<op::Parameter>(element::i32, shape);
    auto mul_add_absb = c * (add_absb);
    ASSERT_TRUE(tm.match(c * (b + pattern), mul_add_absb));
    ASSERT_EQ(tm.get_pattern_map()[pattern], abs);
    ASSERT_EQ(tm.get_matched_nodes(), (NodeVector{mul_add_absb, c, add_absb, abs, b}));

    ASSERT_TRUE(tm.match(c * (any + b), mul_add_absb)); // nested any
    ASSERT_EQ(tm.get_matched_nodes(), (NodeVector{mul_add_absb, c, add_absb, abs, a, b}));
    ASSERT_TRUE(tm.match(c * (any + b), (b + abs) * c)); // permutations w/ any
    auto mul_c_add_ab = c * add_ab;
    ASSERT_TRUE(tm.match(c * (any_false + b), c * (a + b)));  // nested any
    ASSERT_TRUE(tm.match(c * (any_false + b), mul_c_add_ab)); // permutations w/ any_false
    ASSERT_EQ(tm.get_matched_nodes(), (NodeVector{mul_c_add_ab, c, add_ab, a, a, b}));

    // Constants in the pattern versus distinct constant nodes in the graph.
    auto iconst1_0 = construct_constant_node(1);
    auto iconst1_1 = construct_constant_node(1);
    ASSERT_TRUE(tm.match(pattern * iconst1_0, a * iconst1_1)); // different iconst
    ASSERT_EQ(tm.get_pattern_map()[pattern], a);
    auto fconst1_0 = op::Constant::create(element::f32, shape, {1});
    auto patternf = std::make_shared<pattern::op::Label>(fconst1_0);
    // NOTE(review): here the pattern constant is f32 while the graph constant is
    // i32, so "different iconst" looks copy-pasted from above -- confirm intent.
    ASSERT_TRUE(tm.match(patternf * fconst1_0, a * iconst1_1)); // different iconst

    // Subgraph labels
    auto add = a + b;
    auto label = std::make_shared<pattern::op::Label>(add, nullptr, NodeVector{add});
    ASSERT_TRUE(tm.match(label, add));
    ASSERT_EQ(tm.get_pattern_map()[label], add);
    ASSERT_EQ(tm.get_matched_nodes(), (NodeVector{add, add, a, b}));

    ASSERT_FALSE(tm.match(label, a - b));

    ASSERT_TRUE(tm.match(make_shared<op::Abs>(label), make_shared<op::Abs>(add)));
    ASSERT_EQ(tm.get_pattern_map()[label], add);

    // Correct argument order
    ASSERT_FALSE(tm.match(b - a, a - b));
    auto aab = a * (a - b);
    auto paab = pattern * (pattern - b);
    ASSERT_TRUE(tm.match(paab, aab));
    auto aba = a * (b - a);
    ASSERT_FALSE(tm.match(paab, aba));
    auto paba = pattern * (b - pattern);
    ASSERT_FALSE(tm.match(paba, aab));

    // Correlations
    auto label1 = std::make_shared<pattern::op::Label>(a);
    auto tmp = label1 + b;
    auto label2 = std::make_shared<pattern::op::Label>(tmp, nullptr, NodeVector{tmp});
    auto sub_label1 = label1 - label2;
    auto sub_add = a - add;
    ASSERT_TRUE(tm.match(sub_label1, sub_add));
    ASSERT_EQ(tm.get_pattern_map()[label1], a);
    ASSERT_EQ(tm.get_pattern_map()[label2], add);
    ASSERT_EQ(tm.get_matched_nodes(), (NodeVector{sub_add, a, add, add, a, b}));

    ASSERT_FALSE(tm.match(sub_label1, add - a));

    auto add_label1 = label1 + label2;
    ASSERT_TRUE(tm.match(add_label1, add + a));
    ASSERT_EQ(tm.get_pattern_map()[label1], a);
    ASSERT_EQ(tm.get_pattern_map()[label2], add);

    // Or
    ASSERT_TRUE(tm.match(std::make_shared<pattern::op::Or>(OutputVector{a + b, a - b}), a + b));
    ASSERT_TRUE(tm.match(std::make_shared<pattern::op::Or>(OutputVector{a + b, a - b}), a - b));

    // Branch
    {
        // `star` is either the branch (whose destination is set to the sum
        // pattern itself below) or a True pattern.
        auto branch = std::make_shared<pattern::op::Branch>();
        auto star = std::make_shared<pattern::op::Or>(
            OutputVector{branch, std::make_shared<pattern::op::True>()});
        auto star_sum = star + star;
        branch->set_destination(star_sum);
        ASSERT_TRUE(tm.match(star_sum, ((a + b) + (b + a) + a)));
        ASSERT_EQ(tm.get_matched_nodes().size(), 4);
    }

    // strict mode
    {
        TestMatcher sm(Output<Node>{}, "TestMatcher", true);
        // exact shape and type
        auto scalar_param = make_shared<op::Parameter>(element::i32, Shape{});
        auto label_dynamic_shape =
            make_shared<pattern::op::Label>(element::i32, PartialShape::dynamic());
        auto param = make_shared<op::Parameter>(element::f32, Shape{});
        ASSERT_TRUE(sm.match(label_dynamic_shape, scalar_param));
        // wrong type
        // NOTE(review): this passes the outer subgraph `label` (the a + b label),
        // not `label_dynamic_shape`; `param` above is also never used. Confirm
        // which pattern this check was meant to exercise.
        auto scalar_param_wrong_type = make_shared<op::Parameter>(element::f32, Shape{});
        ASSERT_FALSE(sm.match(label, scalar_param_wrong_type));
        // dynamic dimension
        auto label_dynamic_dimension =
            make_shared<pattern::op::Label>(element::i32, PartialShape{Dimension::dynamic()});
        auto vector_param = make_shared<op::Parameter>(element::i32, Shape{10});
        ASSERT_TRUE(sm.match(label_dynamic_dimension, vector_param));
        // dynamic type
        auto label_dynamic_type =
            make_shared<pattern::op::Label>(element::dynamic, PartialShape{Dimension::dynamic()});
        ASSERT_TRUE(sm.match(label_dynamic_type, vector_param));
    }
}
479

480 481
// Builds a concrete sum(x) / N graph over a {2, 3} f32 parameter and checks that
// the pattern from construct_mean_graph() matches it and binds to the Divide.
TEST(pattern, mean)
{
    // construct mean
    TestMatcher matcher;

    auto x = std::make_shared<op::Parameter>(element::f32, Shape{2, 3});
    auto count = op::Constant::create(element::f32, Shape{3}, {2, 2, 2});
    auto sum_x = std::make_shared<op::Sum>(x, AxisSet{0});
    auto mean = std::make_shared<op::Divide>(sum_x, count);

    auto mean_pattern = construct_mean_graph();
    ASSERT_TRUE(matcher.match(mean_pattern, mean));
    ASSERT_EQ(matcher.get_pattern_map()[mean_pattern], mean);
}

// Builds a variance graph with the same structure as construct_variance_graph()
// and checks that the pattern matches it and binds to the final Divide.
// Note that the graph's input here is itself a pattern Label, not a Parameter.
TEST(pattern, variance)
{
    // construct variance
    TestMatcher matcher;
    auto count = op::Constant::create(element::f32, Shape{3}, {2, 2, 2});
    auto x = std::make_shared<pattern::op::Label>(element::f32, Shape{2, 3});
    auto x_squared = std::make_shared<op::Multiply>(x, x);
    auto sum_x = std::make_shared<op::Sum>(x, AxisSet{0});
    auto sum_x_squared = std::make_shared<op::Multiply>(sum_x, sum_x);
    auto sum_of_squares = std::make_shared<op::Sum>(x_squared, AxisSet{0});
    auto scaled_sum_x_squared = std::make_shared<op::Divide>(sum_x_squared, count);
    auto numerator = std::make_shared<op::Subtract>(sum_of_squares, scaled_sum_x_squared);
    auto variance = std::make_shared<op::Divide>(numerator, count);

    auto variance_pattern = construct_variance_graph();
    ASSERT_TRUE(matcher.match(variance_pattern, variance));
    ASSERT_EQ(matcher.get_pattern_map()[variance_pattern], variance);
}
513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536

// Checks how a pre-populated PatternMap affects matching: with an empty map the
// label binds to Abs(a); once the map already binds the label to `a`, the same
// pattern no longer matches the same graph.
TEST(pattern, previous_matches)
{
    using ngraph::pattern::Matcher;
    Shape shape{};
    Matcher::PatternMap previous_matches;
    auto a = make_shared<op::Parameter>(element::i32, shape);
    auto b = make_shared<op::Parameter>(element::i32, shape);
    auto pattern = std::make_shared<pattern::op::Label>(b);
    auto abs = make_shared<op::Abs>(a);
    auto add = abs + b;

    // No prior bindings: the label is free to bind to abs.
    {
        Matcher unconstrained(pattern + b);
        ASSERT_TRUE(unconstrained.match(add, previous_matches));
        ASSERT_EQ(unconstrained.get_pattern_map()[pattern], abs);
    }

    // Label pre-bound to `a`, which conflicts with abs in the graph.
    {
        Matcher constrained(pattern + b);
        previous_matches.insert(std::make_pair(pattern, a));
        ASSERT_FALSE(constrained.match(add, previous_matches));
    }
}

537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563
// Matches the same graph twice with one Matcher and checks that the label is
// bound to the same node both times.
TEST(pattern, test_sort)
{
    using ngraph::pattern::Matcher;
    Shape shape{};

    // Graph: Abs(a) + Abs(b)
    auto a = make_shared<op::Parameter>(element::i32, shape);
    auto b = make_shared<op::Parameter>(element::i32, shape);
    auto abs1 = make_shared<op::Abs>(a);
    auto abs2 = make_shared<op::Abs>(b);
    auto add = abs1 + abs2;

    // Pattern: Label(Abs(pa)) + Abs(b). Note that the second Abs is built on the
    // graph parameter `b`; `pb` is created but not used.
    auto pa = make_shared<op::Parameter>(element::i32, shape);
    auto pb = make_shared<op::Parameter>(element::i32, shape);
    auto pabs1 = make_shared<op::Abs>(pa);
    auto pabs1_label = std::make_shared<pattern::op::Label>(pabs1);
    auto pabs2 = make_shared<op::Abs>(b);
    auto padd = pabs1_label + pabs2;

    {
        Matcher matcher(padd);
        ASSERT_TRUE(matcher.match(add));
        auto first_binding = matcher.get_pattern_map()[pabs1_label];
        ASSERT_TRUE(matcher.match(add));
        ASSERT_EQ(first_binding, matcher.get_pattern_map()[pabs1_label]);
    }
}

564 565
// Exercises RecurrentMatcher on chains of additions: a single recurring label,
// several labels per recurrence, correlated labels, and reuse of one matcher
// on a different graph.
TEST(pattern, recurrent_pattern)
{
    using ngraph::pattern::RecurrentMatcher;
    Shape shape{};
    ngraph::pattern::Matcher::PatternMap previous_matches;
    auto a = make_shared<op::Parameter>(element::i32, shape);
    auto b = make_shared<op::Parameter>(element::i32, shape);
    auto rpattern = std::make_shared<pattern::op::Label>(b);
    auto iconst0 = construct_constant_node(0);
    auto abs = make_shared<op::Abs>(a);
    auto add1 = iconst0 + b;
    auto add2 = iconst0 + add1;
    auto add3 = iconst0 + add2;
    auto padd = iconst0 + rpattern;
    std::set<std::shared_ptr<pattern::op::Label>> empty_correlated_matches;

    // Single recurring label over the chain 0 + (0 + (0 + b)).
    RecurrentMatcher chain_rm(padd, rpattern, empty_correlated_matches);
    ASSERT_TRUE(chain_rm.match(add3));
    ASSERT_EQ(chain_rm.get_number_of_bound_labels(), 1);
    auto bound_operands = chain_rm.get_bound_nodes_for_pattern(rpattern);
    ASSERT_EQ(bound_operands.at(0), add2);
    ASSERT_EQ(bound_operands.at(1), add1);
    ASSERT_EQ(bound_operands.at(2), b);

    // Multiple labels in a recurring pattern
    auto iconst1 = construct_constant_node(1);
    auto iconst_label = std::make_shared<pattern::op::Label>(iconst1, nullptr, NodeVector{iconst1});
    auto add2_2 = iconst1 + add1;
    auto add3_2 = iconst0 + add2_2;
    auto padd2 = iconst_label + rpattern;
    RecurrentMatcher two_label_rm(padd2, rpattern, empty_correlated_matches);
    ASSERT_TRUE(two_label_rm.match(add3_2));
    ASSERT_EQ(two_label_rm.get_number_of_bound_labels(), 2);
    bound_operands = two_label_rm.get_bound_nodes_for_pattern(rpattern);
    ASSERT_EQ(bound_operands.at(0), add2_2);
    ASSERT_EQ(bound_operands.at(1), add1);
    ASSERT_EQ(bound_operands.at(2), b);
    auto bound_constants = two_label_rm.get_bound_nodes_for_pattern(iconst_label);
    ASSERT_EQ(bound_constants.at(0), iconst0);
    ASSERT_EQ(bound_constants.at(1), iconst1);
    ASSERT_EQ(bound_constants.at(2), iconst0);

    // Non-matching correlated labels
    std::set<std::shared_ptr<pattern::op::Label>> correlated_matches;
    correlated_matches.insert(iconst_label);
    RecurrentMatcher correlated_rm(padd2, rpattern, correlated_matches);
    ASSERT_TRUE(correlated_rm.match(add3_2));
    ASSERT_EQ(correlated_rm.get_number_of_bound_labels(), 2);
    bound_constants = correlated_rm.get_bound_nodes_for_pattern(iconst_label);
    ASSERT_EQ(bound_constants.size(), 1);
    ASSERT_EQ(bound_constants.at(0), iconst0);

    // Matching correlated labels and
    // testing if RecurrentMatcher can be reused for different nodes
    ASSERT_TRUE(correlated_rm.match(add3));
    ASSERT_EQ(correlated_rm.get_number_of_bound_labels(), 2);
    bound_operands = correlated_rm.get_bound_nodes_for_pattern(rpattern);
    ASSERT_EQ(bound_operands.at(0), add2);
    ASSERT_EQ(bound_operands.at(1), add1);
    ASSERT_EQ(bound_operands.at(2), b);
    bound_constants = correlated_rm.get_bound_nodes_for_pattern(iconst_label);
    ASSERT_EQ(bound_constants.at(0), iconst0);
    ASSERT_EQ(bound_constants.at(1), iconst0);
    ASSERT_EQ(bound_constants.at(2), iconst0);
}
628 629 630 631 632 633 634 635 636 637 638 639 640

// Test-only RecurrentGraphRewrite pass that collapses a chain of additions
// with zero-valued constants down to the innermost non-constant operand.
class TestRecurrentGraphRewrite : public ngraph::pass::RecurrentGraphRewrite
{
public:
    // Registers a recurrent matcher for "const_label + rpattern". The callback
    // rejects the match unless every node bound to the constant label is zero,
    // and otherwise replaces the match root with the last node bound to rpattern.
    void construct_recurrent_add()
    {
        Shape shape{};
        auto iconst0 = construct_constant_node(0);
        auto iconst_label =
            std::make_shared<pattern::op::Label>(iconst0, nullptr, NodeVector{iconst0});
        auto rpattern = std::make_shared<pattern::op::Label>(element::i32, shape);
        auto padd = iconst_label + rpattern;

        auto callback = [iconst_label, rpattern](pattern::RecurrentMatcher& rm) {
            NGRAPH_DEBUG << "In a callback for construct_recurrent_add against "
                         << rm.get_match_root()->get_name();

            auto iconst_matches = rm.get_bound_nodes_for_pattern(iconst_label);

            auto is_iconst_zero = [](std::shared_ptr<Node> n) {
                // Evaluate the predicate once and reuse it for both the log
                // message and the return value.
                bool result = ngraph::is_zero(n);
                NGRAPH_DEBUG << n->get_name() << " is " << (result ? " a zero " : " not a zero");
                return result;
            };

            bool are_all_iconst_zeros =
                std::all_of(iconst_matches.begin(), iconst_matches.end(), is_iconst_zero);

            if (!are_all_iconst_zeros)
            {
                return false;
            }

            auto number_of_adds = rm.get_number_of_recurrent_matches();
            // replace the topmost add with the seed (i.e. the first parameter to add)
            // matches are added in reverse order (i.e. the first match is the topmost node)
            auto arg = rm.get_bound_nodes_for_pattern(rpattern).at(number_of_adds - 1);
            NGRAPH_DEBUG << "Replacing " << rm.get_match_root()->get_name() << " with "
                         << arg->get_name();
            ngraph::replace_node(rm.get_match_root(), arg);
            return true;
        };

        std::set<std::shared_ptr<pattern::op::Label>> empty_correlated_matches;
        auto rm = make_shared<pattern::RecurrentMatcher>(padd, rpattern, empty_correlated_matches);
        this->add_matcher(rm, callback);
    }

    TestRecurrentGraphRewrite()
        : RecurrentGraphRewrite()
    {
        construct_recurrent_add();
    }
};

// Runs TestRecurrentGraphRewrite over Abs(a + 0 + 0 + 0) * Abs(b + 0 + 0) and
// checks that each Abs ends up reading directly from its parameter.
TEST(pattern, recurrent_graph_rewrite)
{
    Shape shape{};
    pass::Manager pass_manager;
    pass_manager.register_pass<TestRecurrentGraphRewrite>();

    {
        // Left operand: three chained "+ 0" on top of a.
        auto a = make_shared<op::Parameter>(element::i32, shape);
        auto iconst0 = construct_constant_node(0);
        auto add_a1 = a + iconst0;
        auto add_a2 = add_a1 + iconst0;
        auto add_a3 = add_a2 + iconst0;
        auto abs_add_a3 = std::make_shared<op::Abs>(add_a3);

        // Right operand: two chained "+ 0" on top of b.
        auto b = make_shared<op::Parameter>(element::i32, shape);
        auto add_b1 = b + iconst0;
        auto add_b2 = add_b1 + iconst0;
        auto abs_add_b2 = std::make_shared<op::Abs>(add_b2);

        auto graph = abs_add_a3 * abs_add_b2;

        auto f = std::make_shared<Function>(ngraph::NodeVector{graph}, ParameterVector{a, b});
        pass_manager.run_passes(f);

        auto left_abs = graph->get_argument(0);
        auto left_abs_input = left_abs->get_argument(0);
        ASSERT_EQ(left_abs_input, a);

        auto right_abs = graph->get_argument(1);
        auto right_abs_input = right_abs->get_argument(0);
        ASSERT_EQ(right_abs_input, b);
    }
}
Nick Korovaiko's avatar
Nick Korovaiko committed
716 717 718 719 720 721 722 723 724 725 726 727

// Checks a label placed on top of a Skip: the pattern Multiply(label, Label(Skip(zero)))
// is matched against a product with a broadcast zero and against a product with
// the bare zero constant, and the bindings of all three labels are verified.
TEST(pattern, label_on_skip)
{
    Shape shape{2, 2};
    auto matrix_param = make_shared<op::Parameter>(element::i32, shape);
    auto scalar_param = make_shared<op::Parameter>(element::i32, Shape{});
    auto zero = ngraph::make_zero(element::i32, Shape{});
    auto operand_label = std::make_shared<pattern::op::Label>(zero);
    auto zero_label = std::make_shared<pattern::op::Label>(zero, ngraph::is_zero, NodeVector{zero});

    // Predicate handed to Skip: true only for Broadcast nodes.
    auto is_broadcast = [](std::shared_ptr<Node> n) {
        return as_type_ptr<op::Broadcast>(n) != nullptr;
    };

    auto skip_broadcast = std::make_shared<pattern::op::Skip>(zero_label, is_broadcast);
    auto skip_label =
        std::make_shared<pattern::op::Label>(skip_broadcast, nullptr, NodeVector{skip_broadcast});
    auto matcher = std::make_shared<pattern::Matcher>(
        std::make_shared<op::Multiply>(operand_label, skip_label), "label_on_skip");

    auto const_broadcast = make_shared<op::Broadcast>(zero, shape, AxisSet{0, 1});
    auto mul_broadcast = matrix_param * const_broadcast;
    auto mul_scalar = scalar_param * zero;

    // Broadcast present: the outer label binds to the Broadcast, the inner to the constant.
    ASSERT_TRUE(matcher->match(mul_broadcast));
    ASSERT_EQ(matcher->get_pattern_map()[skip_label], const_broadcast);
    ASSERT_EQ(matcher->get_pattern_map()[zero_label], zero);
    ASSERT_EQ(matcher->get_pattern_map()[operand_label], matrix_param);

    // No Broadcast: both labels bind to the constant itself.
    ASSERT_TRUE(matcher->match(mul_scalar));
    ASSERT_EQ(matcher->get_pattern_map()[skip_label], zero);
    ASSERT_EQ(matcher->get_pattern_map()[zero_label], zero);
    ASSERT_EQ(matcher->get_pattern_map()[operand_label], scalar_param);
}
Nick Korovaiko's avatar
Nick Korovaiko committed
748 749 750 751 752 753

// Checks Matcher::is_contained_match() after two matches: a single Abs whose
// only added user is a Result (expected true), and a two-level Abs chain whose
// inner Abs also feeds a Result created earlier (expected false).
TEST(pattern, is_contained_match)
{
    Shape shape{};
    auto a = make_shared<op::Parameter>(element::i32, shape);
    auto inner_abs = make_shared<op::Abs>(a);
    TestMatcher matcher;

    auto label_a = std::make_shared<pattern::op::Label>(a);
    auto inner_abs_pattern = make_shared<op::Abs>(a);
    ASSERT_TRUE(matcher.match(inner_abs_pattern, inner_abs));
    auto inner_abs_result = make_shared<op::Result>(inner_abs);
    ASSERT_TRUE(matcher.is_contained_match());

    auto outer_abs = make_shared<op::Abs>(inner_abs);
    auto outer_abs_result = make_shared<op::Result>(outer_abs);
    auto outer_abs_pattern = make_shared<op::Abs>(inner_abs_pattern);
    ASSERT_TRUE(matcher.match(outer_abs_pattern, outer_abs));
    ASSERT_FALSE(matcher.is_contained_match());
}