//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <iostream>
#include <list>
#include <memory>
#include <set>
#include <string>
#include <vector>

#include "gtest/gtest.h"

#include "ngraph/file_util.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/batch_norm.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/sqrt.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/pass/graph_rewrite.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/label.hpp"
#include "ngraph/pattern/op/skip.hpp"
#include "ngraph/runtime/cpu/pass/cpu_fusion.hpp"
#include "ngraph/serializer.hpp"
#include "util/matcher.hpp"
#include "util/test_tools.hpp"

using namespace ngraph;
using namespace std;

template <typename T>
std::shared_ptr<Node> create_reduction(const std::shared_ptr<Node>& node,
                                       const std::string& init_val,
                                       const AxisSet& reduction_axes)
{
    const auto& et = node->get_element_type();
    auto f_A = std::make_shared<op::Parameter>(et, Shape{});
    auto f_B = std::make_shared<op::Parameter>(et, Shape{});
58 59
    auto f =
        std::make_shared<Function>(std::make_shared<T>(f_A, f_B), op::ParameterVector{f_A, f_B});
60 61 62 63 64 65 66 67 68 69

    auto init = std::make_shared<op::Constant>(et, Shape{}, std::vector<std::string>({init_val}));
    return std::make_shared<op::Reduce>(node, init, f, reduction_axes);
}

// XLA-style sum: a Reduce over `reduction_axes` with an Add reduction function
// and the constant "0" as its initial value.
std::shared_ptr<Node> xla_sum(const std::shared_ptr<Node>& node, const AxisSet& reduction_axes)
{
    const std::string additive_identity{"0"};
    auto reduction = create_reduction<op::Add>(node, additive_identity, reduction_axes);
    return reduction;
}

70 71
static std::shared_ptr<Node> construct_constant_node(int n)
{
72 73 74 75 76 77 78 79
    return op::Constant::create(element::i32, Shape{}, {n});
}

bool sum_predicate(std::shared_ptr<Node> gn)
{
    NGRAPH_DEBUG << "pred_v2 : looking at " << gn->get_name();
    if (auto r = std::dynamic_pointer_cast<op::Reduce>(gn))
    {
80 81
        auto reducee = gn->get_argument(0);
        auto reduce_constant = gn->get_argument(1);
82

83
        if (!ngraph::is_zero(reduce_constant))
84 85 86 87
        {
            return false;
        }

88
        auto result = r->get_functions()[0]->get_result()->get_argument(0);
89 90
        NGRAPH_DEBUG << "looking at function's result  " << result->get_name();
        if (auto sum = std::dynamic_pointer_cast<op::Add>(result))
91
        {
92 93
            auto parm1 = std::dynamic_pointer_cast<op::Parameter>(sum->get_argument(0));
            auto parm2 = std::dynamic_pointer_cast<op::Parameter>(sum->get_argument(1));
94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109

            const auto parm_or_nil = [](std::shared_ptr<Node> p) {
                return p ? p->get_name() : std::string("(nil)");
            };
            NGRAPH_DEBUG << "parm1 = " << parm_or_nil(parm1) << " , parm2 = " << parm_or_nil(parm2)
                         << std::endl;
            if (parm1 && parm2 && parm1 != parm2)
            {
                return true;
            }
        }
    }

    return false;
}

110
std::shared_ptr<pattern::op::Label> construct_sum_pattern() // for the sake of explicitness
111 112
{
    return std::make_shared<pattern::op::Label>(element::i32, Shape{}, sum_predicate);
113 114
}

115 116 117 118 119 120 121 122 123 124 125 126
static std::shared_ptr<pattern::op::Label> construct_variance_graph()
{
    // construct varaiance
    auto N = op::Constant::create(element::f32, Shape{3}, {2, 2, 2});
    auto input = std::make_shared<pattern::op::Label>(element::f32, Shape{2, 3});
    auto input_sq = std::make_shared<op::Multiply>(input, input);
    auto sum_input = std::make_shared<op::Sum>(input, AxisSet{0});
    auto square_sumed_input = std::make_shared<op::Multiply>(sum_input, sum_input);
    auto sum_squared_input = std::make_shared<op::Sum>(input_sq, AxisSet{0});
    auto avg_input_sum_sq = std::make_shared<op::Divide>(square_sumed_input, N);
    auto xmu = std::make_shared<op::Subtract>(sum_squared_input, avg_input_sum_sq);
    auto variance = std::make_shared<op::Divide>(xmu, N);
127 128
    auto variance_label =
        std::make_shared<pattern::op::Label>(variance, nullptr, NodeVector{variance});
129 130 131 132 133 134

    return variance_label;
}

static std::shared_ptr<pattern::op::Label> construct_mean_graph()
{
135
    // construct mean;
136 137 138 139
    auto input = std::make_shared<pattern::op::Label>(element::f32, Shape{2, 3});
    auto N = op::Constant::create(element::f32, Shape{3}, {2, 2, 2});
    auto sum_input1 = std::make_shared<op::Sum>(input, AxisSet{0});
    auto mean = std::make_shared<op::Divide>(sum_input1, N);
140
    auto mean_label = std::make_shared<pattern::op::Label>(mean, nullptr, NodeVector{mean});
141 142 143
    return mean_label;
}

144 145 146 147 148
class TestGraphRewrite : public ngraph::pass::GraphRewrite
{
public:
    void construct_multiply_by_one()
    {
149
        // pattern #1 : a * 1 = a
150
        auto iconst1 = construct_constant_node(1);
151
        auto pattern = std::make_shared<pattern::op::Label>(iconst1);
152

153
        ngraph::pattern::graph_rewrite_callback callback = [pattern](pattern::Matcher& m) {
154
            NGRAPH_DEBUG << "In a callback for construct_multiply_by_one against "
155 156
                         << m.get_match_root()->get_name();
            assert(m.get_match_root()->get_arguments().size() == 2);
157

158 159
            auto pattern_map = m.get_pattern_map();

160 161
            size_t const_node_index =
                m.get_match_root()->get_arguments().at(0) == pattern_map[pattern];
162
            auto const_node = dynamic_pointer_cast<op::Constant>(
163 164
                m.get_match_root()->get_arguments().at(const_node_index));
            auto second_node = m.get_match_root()->get_arguments().at(const_node_index);
165 166
            NGRAPH_DEBUG << "second_node = " << second_node->get_name()
                         << " , pattern = " << pattern_map[pattern]->get_name();
167

168 169
            if (pattern_map[pattern]->get_element_type() != const_node->get_element_type() ||
                pattern_map[pattern]->get_shape() != const_node->get_shape())
170
            {
171
                NGRAPH_DEBUG << "Operands' types and/or shape don't match";
172
                return false;
173 174
            }

175
            auto const_values = const_node->get_vector<int32_t>();
176 177 178 179 180
            bool all_ones =
                std::all_of(begin(const_values), end(const_values), [](int e) { return e == 1; });

            if (!all_ones)
            {
181
                NGRAPH_DEBUG << "Constant vector's values aren't equal to 1";
182
                return false;
183
            }
184

185
            ngraph::replace_node(m.get_match_root(), pattern_map[pattern]);
186
            return true;
187 188 189 190 191 192 193 194
        };

        auto m = make_shared<TestMatcher>(pattern * iconst1, callback);
        this->add_matcher(m);
    }

    void construct_add_zero()
    {
195
        // pattern #2 : a + 0 = a
196
        auto iconst0 = construct_constant_node(0);
197
        auto pattern = std::make_shared<pattern::op::Label>(iconst0);
198

199
        auto callback = [pattern](pattern::Matcher& m) {
200
            NGRAPH_DEBUG << "In a callback for construct_add_zero against "
201 202
                         << m.get_match_root()->get_name();
            assert(m.get_match_root()->get_arguments().size() == 2);
203

204 205
            auto pattern_map = m.get_pattern_map();

206 207
            size_t const_node_index =
                m.get_match_root()->get_arguments().at(0) == pattern_map[pattern];
208
            auto const_node = dynamic_pointer_cast<op::Constant>(
209 210
                m.get_match_root()->get_arguments().at(const_node_index));
            auto second_node = m.get_match_root()->get_arguments().at(const_node_index);
211 212
            NGRAPH_DEBUG << "second_node = " << second_node->get_name()
                         << " , pattern = " << pattern_map[pattern]->get_name();
213

214 215
            if (pattern_map[pattern]->get_element_type() != const_node->get_element_type() ||
                pattern_map[pattern]->get_shape() != const_node->get_shape())
216
            {
217
                NGRAPH_DEBUG << "Operands' types and/or shape don't match";
218
                return false;
219 220
            }

221
            auto const_values = const_node->get_vector<int>();
222 223 224 225 226
            bool all_zeros =
                std::all_of(begin(const_values), end(const_values), [](int e) { return e == 0; });

            if (!all_zeros)
            {
227
                NGRAPH_DEBUG << "Constant vector's values aren't equal to 0";
228
                return false;
229 230
            }

231
            ngraph::replace_node(m.get_match_root(), pattern_map[pattern]);
232
            return true;
233 234 235 236 237 238
        };

        auto m = make_shared<TestMatcher>(pattern + iconst0, callback);
        this->add_matcher(m);
    }

239 240 241 242
    void construct_sum()
    {
        auto sum_pattern = construct_sum_pattern();

243
        ngraph::pattern::graph_rewrite_callback callback = [](pattern::Matcher& m) {
244
            NGRAPH_DEBUG << "In a callback for construct_sum_pattern against "
245 246
                         << m.get_match_root()->get_name();
            auto reduce = std::dynamic_pointer_cast<op::Reduce>(m.get_match_root());
247 248
            auto reducee = reduce->get_inputs().at(0).get_output().get_node();
            NGRAPH_DEBUG << "reducee = " << reducee->get_name();
249 250
            auto sum =
                std::shared_ptr<ngraph::Node>(new op::Sum(reducee, reduce->get_reduction_axes()));
251

252
            ngraph::replace_node(m.get_match_root(), sum);
253
            return true;
254 255 256 257 258 259
        };

        auto m = make_shared<TestMatcher>(sum_pattern, callback);
        this->add_matcher(m);
    }

260 261 262 263 264
    TestGraphRewrite()
        : GraphRewrite()
    {
        construct_multiply_by_one();
        construct_add_zero();
265
        construct_sum();
266 267 268 269 270 271 272
    }
};

static void run_passes(pass::Manager& pass_manager,
                       shared_ptr<Node> graph,
                       std::vector<shared_ptr<op::Parameter>> parms)
{
273
    auto func = make_shared<Function>(graph, op::ParameterVector{parms});
274 275 276 277 278
    pass_manager.run_passes(func);
}

// Exercises TestGraphRewrite: identity elimination for `x + 0` / `x * 1`
// (including chains and either operand order) and the Reduce(Add) -> Sum rewrite.
TEST(pattern, graph_rewrite)
{
    Shape shape{};
    pass::Manager pass_manager;
    pass_manager.register_pass<TestGraphRewrite>();

    {
        auto a = make_shared<op::Parameter>(element::i32, shape);
        auto b = make_shared<op::Parameter>(element::i32, shape);
        auto c = make_shared<op::Parameter>(element::i32, shape);
        auto iconst0 = construct_constant_node(0);
        auto graph_a = a + iconst0;
        auto graph_b = b + iconst0;

        auto f = std::make_shared<Function>(ngraph::NodeVector{a, b, graph_a, c, graph_b},
                                            op::ParameterVector{a, b, c});
        pass_manager.run_passes(f);

        ASSERT_TRUE(graph_a->get_output_inputs(0).empty());
        ASSERT_TRUE(graph_b->get_output_inputs(0).empty());

        // Both `x + 0` nodes are rewritten to their non-constant operand (the outputs are
        // intended to become {a, b, a, c, b}), so no Add remains in f.
        ASSERT_TRUE(count_ops_of_type<op::Add>(f) == 0);
    }

    {
        auto a = make_shared<op::Parameter>(element::i32, shape);
        auto b = make_shared<op::Parameter>(element::i32, shape);
        auto iconst0 = construct_constant_node(0);
        auto sum = (a + iconst0);
        auto graph = b + sum;
        run_passes(pass_manager, graph, {a, b});
        ASSERT_EQ(graph->get_arguments().at(1), a);
        ASSERT_EQ(&graph->get_inputs().at(1).get_output(),
                  &a->get_outputs().at(0)); // graph's input points to a's output
        ASSERT_TRUE(sum->get_output_inputs(0)
                        .empty()); // graph's input is removed from sum's output.get_inputs()
        ASSERT_TRUE(a->get_outputs().at(0).get_inputs().count(
            &graph->get_inputs().at(1))); // a's output feeds into graph's input
    }

    {
        auto a = make_shared<op::Parameter>(element::i32, shape);
        auto b = make_shared<op::Parameter>(element::i32, shape);
        auto iconst1 = construct_constant_node(1);
        auto mul = (a * iconst1);
        auto graph = b + mul;
        run_passes(pass_manager, graph, {a, b});
        ASSERT_EQ(graph->get_arguments().at(1), a);
        ASSERT_EQ(&graph->get_inputs().at(1).get_output(),
                  &a->get_outputs().at(0)); // graph's input points to a's output
        ASSERT_TRUE(mul->get_outputs()
                        .at(0)
                        .get_inputs()
                        .empty()); // graph's input is removed from mul's output.get_inputs()
        ASSERT_TRUE(a->get_outputs().at(0).get_inputs().count(
            &graph->get_inputs().at(1))); // a's output feeds into graph's input
    }

    {
        auto a = make_shared<op::Parameter>(element::i32, shape);
        auto b = make_shared<op::Parameter>(element::i32, shape);
        auto iconst1 = construct_constant_node(1);
        auto graph = ((((a * iconst1) * iconst1) * iconst1) * iconst1) + b;
        run_passes(pass_manager, graph, {a, b});
        ASSERT_EQ(graph->get_arguments().at(0), a);
        ASSERT_EQ(&graph->get_inputs().at(0).get_output(),
                  &a->get_outputs().at(0)); // graph's input points to a's output
        ASSERT_TRUE(a->get_outputs().at(0).get_inputs().count(
            &graph->get_inputs().at(0))); // a's output feeds into graph's input
    }

    {
        auto a = make_shared<op::Parameter>(element::i32, shape);
        auto b = make_shared<op::Parameter>(element::i32, shape);
        auto iconst0 = construct_constant_node(0);
        auto iconst1 = construct_constant_node(1);
        auto graph = b + (iconst0 + ((a + iconst0) * iconst1));
        run_passes(pass_manager, graph, {a, b});
        ASSERT_EQ(graph->get_arguments().at(1), a);
        ASSERT_EQ(&graph->get_inputs().at(1).get_output(),
                  &a->get_outputs().at(0)); // graph's input points to a's output
        ASSERT_TRUE(a->get_outputs().at(0).get_inputs().count(
            &graph->get_inputs().at(1))); // a's output feeds into graph's input
    }

    {
        auto a = make_shared<op::Parameter>(element::i32, shape);
        auto b = make_shared<op::Parameter>(element::i32, shape);
        auto iconst1 = construct_constant_node(1);
        auto graph = b + (iconst1 * (iconst1 * (iconst1 * (iconst1 * a))));
        run_passes(pass_manager, graph, {a, b});
        ASSERT_EQ(graph->get_arguments().at(1), a);
        ASSERT_EQ(&graph->get_inputs().at(1).get_output(),
                  &a->get_outputs().at(0)); // graph's input points to a's output
        ASSERT_TRUE(a->get_outputs().at(0).get_inputs().count(
            &graph->get_inputs().at(1))); // a's output feeds into graph's input
    }

    // Sum rewrite
    {
        auto parm = make_shared<op::Parameter>(element::i32, Shape{2, 2});
        auto axes = AxisSet{0, 1};
        auto sum_graph = xla_sum(parm, axes);
        auto innermost_abs = make_shared<op::Abs>(sum_graph);

        auto nested_sum_graph = make_shared<op::Abs>(
            make_shared<op::Abs>(make_shared<op::Abs>(make_shared<op::Abs>(innermost_abs))));

        run_passes(pass_manager, nested_sum_graph, {parm});
        auto sum = std::dynamic_pointer_cast<op::Sum>(innermost_abs->get_argument(0));
        ASSERT_TRUE(sum);
        ASSERT_EQ(sum->get_reduction_axes(), axes);
        ASSERT_EQ(sum->get_argument(0), parm);
    }
}

394 395 396 397 398 399 400 401 402 403 404
std::ostream& operator<<(std::ostream& os, const ngraph::NodeVector& nv)
{
    std::vector<std::string> names;
    for (auto n : nv)
    {
        names.push_back(n->get_name());
    }
    os << vector_to_string(names);
    return os;
}

405 406
TEST(pattern, matcher)
{
407
    Shape shape{};
408
    auto a = make_shared<op::Parameter>(element::i32, shape);
409 410
    TestMatcher n(nullptr);
    ASSERT_TRUE(n.match(a, a));
411
    ASSERT_EQ(n.get_matched_nodes(), (NodeVector{a}));
412 413

    auto abs = make_shared<op::Abs>(a);
414
    auto any = std::make_shared<pattern::op::Skip>(a);
415
    ASSERT_TRUE(n.match(any, abs));
416
    ASSERT_EQ(n.get_matched_nodes(), (NodeVector{abs, a}));
417

Nick Korovaiko's avatar
Nick Korovaiko committed
418 419
    auto false_pred = [](std::shared_ptr<Node> no) { return false; };
    auto any_false = std::make_shared<pattern::op::Skip>(a, false_pred);
420
    ASSERT_TRUE(n.match(any_false, a));
421
    ASSERT_EQ(n.get_matched_nodes(), (NodeVector{a, a}));
422

423
    auto pattern = std::make_shared<pattern::op::Label>(a);
424
    ASSERT_TRUE(n.match(pattern, a));
425
    ASSERT_EQ(n.get_pattern_map()[pattern], a);
426
    ASSERT_EQ(n.get_matched_nodes(), (NodeVector{a}));
427

Nick Korovaiko's avatar
Nick Korovaiko committed
428
    auto pattern_false = std::make_shared<pattern::op::Label>(a, false_pred);
429
    ASSERT_FALSE(n.match(pattern_false, a));
430
    ASSERT_EQ(n.get_matched_nodes(), (NodeVector{}));
431

432
    auto b = make_shared<op::Parameter>(element::i32, shape);
Nick Korovaiko's avatar
Nick Korovaiko committed
433 434 435

    auto is_bea = pattern::has_class<op::util::BinaryElementwiseArithmetic>();
    auto bea = std::make_shared<pattern::op::Any>(a, is_bea, NodeVector{a, b});
436 437 438
    auto add_ab = a + b;
    ASSERT_TRUE(n.match(bea, add_ab));
    ASSERT_EQ(n.get_matched_nodes(), (NodeVector{add_ab, a, b}));
Nick Korovaiko's avatar
Nick Korovaiko committed
439 440 441 442 443 444 445 446 447 448
    ASSERT_TRUE(n.match(bea, b + a));

    auto bea_false = std::make_shared<pattern::op::Any>(a, false_pred, NodeVector{a, b});
    ASSERT_FALSE(n.match(bea_false, a + b));

    auto bea_label = std::make_shared<pattern::op::Label>(a, nullptr, NodeVector{bea});
    auto ab = a + b;
    ASSERT_TRUE(n.match(bea_label, ab));
    ASSERT_EQ(n.get_pattern_map()[bea_label], ab);

449
    auto d = make_shared<op::Parameter>(element::i32, shape);
450 451 452
    ASSERT_FALSE(n.match(d, b));

    ASSERT_FALSE(n.match(abs + b, b + b));
453 454 455 456 457
    ASSERT_EQ(n.get_matched_nodes(), (NodeVector{}));

    auto add_absb = abs + b;
    ASSERT_TRUE(n.match(any + b, add_absb));
    ASSERT_EQ(n.get_matched_nodes(), (NodeVector{add_absb, abs, a, b}));
458

459
    ASSERT_TRUE(n.match(pattern + b, add_absb));
460
    ASSERT_EQ(n.get_pattern_map()[pattern], abs);
461
    ASSERT_EQ(n.get_matched_nodes(), (NodeVector{add_absb, abs, b}));
462

463
    ASSERT_TRUE(n.match(b + pattern, add_absb));
464
    ASSERT_EQ(n.get_pattern_map()[pattern], abs);
465
    ASSERT_EQ(n.get_matched_nodes(), (NodeVector{add_absb, abs, b}));
466

467
    auto c = make_shared<op::Parameter>(element::i32, shape);
468 469
    auto mul_add_absb = c * (add_absb);
    ASSERT_TRUE(n.match(c * (b + pattern), mul_add_absb));
470
    ASSERT_EQ(n.get_pattern_map()[pattern], abs);
471
    ASSERT_EQ(n.get_matched_nodes(), (NodeVector{mul_add_absb, c, add_absb, abs, b}));
472

473
    ASSERT_TRUE(n.match(c * (any + b), mul_add_absb)); // nested any
474
    ASSERT_EQ(n.get_matched_nodes(), (NodeVector{mul_add_absb, c, add_absb, abs, a, b}));
475
    ASSERT_TRUE(n.match(c * (any + b), (b + abs) * c)); // permutations w/ any
476
    auto mul_c_add_ab = c * add_ab;
477 478
    ASSERT_TRUE(n.match(c * (any_false + b), c * (a + b)));  // nested any
    ASSERT_TRUE(n.match(c * (any_false + b), mul_c_add_ab)); // permutations w/ any_false
479
    ASSERT_EQ(n.get_matched_nodes(), (NodeVector{mul_c_add_ab, c, add_ab, a, a, b}));
480 481 482

    auto iconst1_0 = construct_constant_node(1);
    auto iconst1_1 = construct_constant_node(1);
483
    ASSERT_TRUE(n.match(pattern * iconst1_0, a * iconst1_1)); // different iconst
484
    ASSERT_EQ(n.get_pattern_map()[pattern], a);
485 486
    auto fconst1_0 = op::Constant::create(element::f32, shape, {1});
    auto patternf = std::make_shared<pattern::op::Label>(fconst1_0);
487
    ASSERT_TRUE(n.match(patternf * fconst1_0, a * iconst1_1)); // different iconst
488

489
    // Subgraph labels
490
    auto add = a + b;
491
    auto label = std::make_shared<pattern::op::Label>(add, nullptr, NodeVector{add});
492 493
    ASSERT_TRUE(n.match(label, add));
    ASSERT_EQ(n.get_pattern_map()[label], add);
494
    ASSERT_EQ(n.get_matched_nodes(), (NodeVector{add, add, a, b}));
495 496 497 498 499 500

    ASSERT_FALSE(n.match(label, a - b));

    ASSERT_TRUE(n.match(make_shared<op::Abs>(label), make_shared<op::Abs>(add)));
    ASSERT_EQ(n.get_pattern_map()[label], add);

501
    // Correct argument order
502 503 504 505 506 507 508 509 510
    ASSERT_FALSE(n.match(b - a, a - b));
    auto aab = a * (a - b);
    auto paab = pattern * (pattern - b);
    ASSERT_TRUE(n.match(paab, aab));
    auto aba = a * (b - a);
    ASSERT_FALSE(n.match(paab, aba));
    auto paba = pattern * (b - pattern);
    ASSERT_FALSE(n.match(paba, aab));

511
    // Correlations
512 513
    auto label1 = std::make_shared<pattern::op::Label>(a);
    auto tmp = label1 + b;
514
    auto label2 = std::make_shared<pattern::op::Label>(tmp, nullptr, NodeVector{tmp});
515
    auto sub_label1 = label1 - label2;
516 517
    auto sub_add = a - add;
    ASSERT_TRUE(n.match(sub_label1, sub_add));
518 519
    ASSERT_EQ(n.get_pattern_map()[label1], a);
    ASSERT_EQ(n.get_pattern_map()[label2], add);
520
    ASSERT_EQ(n.get_matched_nodes(), (NodeVector{sub_add, a, add, add, a, b}));
521 522 523 524 525 526 527

    ASSERT_FALSE(n.match(sub_label1, add - a));

    auto add_label1 = label1 + label2;
    ASSERT_TRUE(n.match(add_label1, add + a));
    ASSERT_EQ(n.get_pattern_map()[label1], a);
    ASSERT_EQ(n.get_pattern_map()[label2], add);
528
}
529 530 531

TEST(pattern, sum)
{
    // An XLA-style Reduce(Add) with a zero initial value matches the sum label,
    // both at the root and nested under a chain of Abs nodes.
    TestMatcher n(nullptr);
    auto reducee_const = std::make_shared<op::Constant>(
        element::i32, Shape{2, 2}, std::vector<std::string>({"0", "0", "0", "0"}));
    auto sum_graph = xla_sum(reducee_const, AxisSet{0, 1});

    auto reduce_label = construct_sum_pattern();
    ASSERT_TRUE(n.match(reduce_label, sum_graph));
    ASSERT_EQ(n.get_pattern_map()[reduce_label], sum_graph);

    auto nested_sum_graph = make_shared<op::Abs>(make_shared<op::Abs>(
        make_shared<op::Abs>(make_shared<op::Abs>(make_shared<op::Abs>(sum_graph)))));

    auto nested_reduce_label = make_shared<op::Abs>(make_shared<op::Abs>(
        make_shared<op::Abs>(make_shared<op::Abs>(make_shared<op::Abs>(reduce_label)))));

    ASSERT_TRUE(n.match(nested_reduce_label, nested_sum_graph));
    ASSERT_EQ(n.get_pattern_map()[reduce_label], sum_graph);
}

TEST(pattern, mean)
{
    // construct mean: sum over axis 0 divided by N, then match it against the
    // labelled pattern from construct_mean_graph.
    TestMatcher n(nullptr);

    auto input = std::make_shared<op::Parameter>(element::f32, Shape{2, 3});
    auto N = op::Constant::create(element::f32, Shape{3}, {2, 2, 2});
    auto sum_input1 = std::make_shared<op::Sum>(input, AxisSet{0});
    auto mean = std::make_shared<op::Divide>(sum_input1, N);

    auto mean_graph = construct_mean_graph();
    ASSERT_TRUE(n.match(mean_graph, mean));
    ASSERT_EQ(n.get_pattern_map()[mean_graph], mean);
}

TEST(pattern, variance)
{
    // construct variance: (sum(x * x) - sum(x)^2 / N) / N over axis 0, then
    // match it against the labelled pattern from construct_variance_graph.
    TestMatcher n(nullptr);
    auto N = op::Constant::create(element::f32, Shape{3}, {2, 2, 2});
    auto input = std::make_shared<pattern::op::Label>(element::f32, Shape{2, 3});
    auto input_sq = std::make_shared<op::Multiply>(input, input);
    auto sum_input = std::make_shared<op::Sum>(input, AxisSet{0});
    auto square_sumed_input = std::make_shared<op::Multiply>(sum_input, sum_input);
    auto sum_squared_input = std::make_shared<op::Sum>(input_sq, AxisSet{0});
    auto avg_input_sum_sq = std::make_shared<op::Divide>(square_sumed_input, N);
    auto xmu = std::make_shared<op::Subtract>(sum_squared_input, avg_input_sum_sq);
    auto variance = std::make_shared<op::Divide>(xmu, N);

    auto var_graph = construct_variance_graph();
    ASSERT_TRUE(n.match(var_graph, variance));
    ASSERT_EQ(n.get_pattern_map()[var_graph], variance);
}

// A Matcher seeded with prior label bindings must respect them: the label is
// free to bind to abs(lhs) on its own, but not once it is pre-bound to lhs.
TEST(pattern, previous_matches)
{
    using ngraph::pattern::Matcher;
    const Shape scalar{};
    Matcher::PatternMap prior_bindings;
    auto lhs = make_shared<op::Parameter>(element::i32, scalar);
    auto rhs = make_shared<op::Parameter>(element::i32, scalar);
    auto label = std::make_shared<pattern::op::Label>(rhs);
    auto abs_lhs = make_shared<op::Abs>(lhs);
    auto graph = abs_lhs + rhs;

    // No conflicting binding yet: the label binds to abs(lhs).
    {
        Matcher matcher(label + rhs);
        ASSERT_TRUE(matcher.match(graph, prior_bindings));
        ASSERT_EQ(matcher.get_pattern_map()[label], abs_lhs);
    }

    // Label pre-bound to lhs: abs(lhs) no longer fits, so the match fails.
    {
        Matcher matcher(label + rhs);
        prior_bindings.insert(std::make_pair(label, lhs));
        ASSERT_FALSE(matcher.match(graph, prior_bindings));
    }
}

// Exercises RecurrentMatcher on chains of `const + x` adds: single and multiple
// bound labels, correlated labels that must bind the same node on every
// iteration, and reuse of one matcher across different graphs.
TEST(pattern, recurrent_pattern)
{
    using ngraph::pattern::RecurrentMatcher;
    Shape shape{};
    auto b = make_shared<op::Parameter>(element::i32, shape);
    auto rpattern = std::make_shared<pattern::op::Label>(b);
    auto iconst0 = construct_constant_node(0);
    auto add1 = iconst0 + b;
    auto add2 = iconst0 + add1;
    auto add3 = iconst0 + add2;
    auto padd = iconst0 + rpattern;
    std::set<std::shared_ptr<pattern::op::Label>> empty_correlated_matches;
    RecurrentMatcher rm(padd, rpattern, empty_correlated_matches, nullptr);
    ASSERT_TRUE(rm.match(add3));
    ASSERT_EQ(rm.get_number_of_bound_labels(), 1);
    auto recurrent_matches = rm.get_bound_nodes_for_pattern(rpattern);
    ASSERT_EQ(recurrent_matches.at(0), add2);
    ASSERT_EQ(recurrent_matches.at(1), add1);
    ASSERT_EQ(recurrent_matches.at(2), b);

    // Multiple labels in a recurring pattern
    auto iconst1 = construct_constant_node(1);
    auto iconst_label = std::make_shared<pattern::op::Label>(iconst1, nullptr, NodeVector{iconst1});
    auto add2_2 = iconst1 + add1;
    auto add3_2 = iconst0 + add2_2;
    auto padd2 = iconst_label + rpattern;
    RecurrentMatcher rm2(padd2, rpattern, empty_correlated_matches, nullptr);
    ASSERT_TRUE(rm2.match(add3_2));
    ASSERT_EQ(rm2.get_number_of_bound_labels(), 2);
    recurrent_matches = rm2.get_bound_nodes_for_pattern(rpattern);
    ASSERT_EQ(recurrent_matches.at(0), add2_2);
    ASSERT_EQ(recurrent_matches.at(1), add1);
    ASSERT_EQ(recurrent_matches.at(2), b);
    auto iconst_matches = rm2.get_bound_nodes_for_pattern(iconst_label);
    ASSERT_EQ(iconst_matches.at(0), iconst0);
    ASSERT_EQ(iconst_matches.at(1), iconst1);
    ASSERT_EQ(iconst_matches.at(2), iconst0);

    // Non-matching correlated labels
    std::set<std::shared_ptr<pattern::op::Label>> correlated_matches;
    correlated_matches.insert(iconst_label);
    RecurrentMatcher rm3(padd2, rpattern, correlated_matches, nullptr);
    ASSERT_TRUE(rm3.match(add3_2));
    ASSERT_EQ(rm3.get_number_of_bound_labels(), 2);
    iconst_matches = rm3.get_bound_nodes_for_pattern(iconst_label);
    ASSERT_EQ(iconst_matches.size(), 1);
    ASSERT_EQ(iconst_matches.at(0), iconst0);

    // Matching correlated labels and
    // testing if RecurrentMatcher can be reused for different nodes
    ASSERT_TRUE(rm3.match(add3));
    ASSERT_EQ(rm3.get_number_of_bound_labels(), 2);
    recurrent_matches = rm3.get_bound_nodes_for_pattern(rpattern);
    ASSERT_EQ(recurrent_matches.at(0), add2);
    ASSERT_EQ(recurrent_matches.at(1), add1);
    ASSERT_EQ(recurrent_matches.at(2), b);
    iconst_matches = rm3.get_bound_nodes_for_pattern(iconst_label);
    ASSERT_EQ(iconst_matches.at(0), iconst0);
    ASSERT_EQ(iconst_matches.at(1), iconst0);
    ASSERT_EQ(iconst_matches.at(2), iconst0);
}

/// Graph-rewrite pass used by the `recurrent_graph_rewrite` test.
///
/// Registers a single RecurrentMatcher for the pattern `iconst_label + rpattern`.
/// When every constant bound to `iconst_label` across the recurrent matches is zero,
/// the callback replaces the match root with the last node bound to `rpattern`.
class TestRecurrentGraphRewrite : public ngraph::pass::RecurrentGraphRewrite
{
public:
    /// Builds the recurrent "add a constant" pattern and registers its matcher/callback.
    void construct_recurrent_add()
    {
        Shape shape{};
        auto iconst0 = construct_constant_node(0);
        // Label wrapping a constant node; the constants it actually binds are
        // checked for zero-ness inside the callback below.
        auto iconst_label =
            std::make_shared<pattern::op::Label>(iconst0, nullptr, NodeVector{iconst0});
        // Recurrent slot: the RecurrentMatcher is constructed with this label as its
        // recurrent pattern (see the make_shared<RecurrentMatcher> call below).
        auto rpattern = std::make_shared<pattern::op::Label>(element::i32, shape);
        auto padd = iconst_label + rpattern;

        // NOTE(review): sum_pattern is not referenced below; kept as in the original
        // in case construct_sum_pattern() is otherwise unused in this file.
        auto sum_pattern = construct_sum_pattern();

        ngraph::pattern::recurrent_graph_rewrite_callback callback = [iconst_label, rpattern](
            pattern::RecurrentMatcher& rm) {
            NGRAPH_DEBUG << "In a callback for construct_recurrent_add against "
                         << rm.get_match_root()->get_name();

            auto iconst_matches = rm.get_bound_nodes_for_pattern(iconst_label);

            // Evaluate is_zero once and reuse the result for both logging and the
            // return value (previously it was computed twice).
            auto is_iconst_zero = [](const std::shared_ptr<Node>& n) {
                bool result = ngraph::is_zero(n);
                NGRAPH_DEBUG << n->get_name() << " is " << (result ? "a zero" : "not a zero");
                return result;
            };

            bool are_all_iconst_zeros =
                std::all_of(iconst_matches.begin(), iconst_matches.end(), is_iconst_zero);

            // Only a chain of additions of zero can be collapsed safely.
            if (!are_all_iconst_zeros)
            {
                return false;
            }

            auto number_of_adds = rm.get_number_of_recurrent_matches();
            // replace the topmost add with the seed (i.e. the first parameter to add)
            // matches are added in reverse order (i.e. the first match is the topmost node)
            auto arg = rm.get_bound_nodes_for_pattern(rpattern).at(number_of_adds - 1);
            NGRAPH_DEBUG << "Replacing " << rm.get_match_root()->get_name() << " with "
                         << arg->get_name();
            ngraph::replace_node(rm.get_match_root(), arg);
            return true;
        };

        std::set<std::shared_ptr<pattern::op::Label>> empty_correlated_matches;
        auto rm = make_shared<pattern::RecurrentMatcher>(
            padd, rpattern, empty_correlated_matches, callback);
        this->add_matcher(rm);
    }

    TestRecurrentGraphRewrite()
        : RecurrentGraphRewrite()
    {
        construct_recurrent_add();
    }
};

// Runs TestRecurrentGraphRewrite over abs(((a + 0) + 0) + 0) * abs((b + 0) + 0)
// and checks that each Abs ends up consuming its Parameter directly.
TEST(pattern, recurrent_graph_rewrite)
{
    Shape scalar_shape{};
    pass::Manager manager;
    manager.register_pass<TestRecurrentGraphRewrite>();

    // Left operand: three chained additions of a zero constant to `a`.
    auto a = make_shared<op::Parameter>(element::i32, scalar_shape);
    auto zero = construct_constant_node(0);
    auto a_sum1 = a + zero;
    auto a_sum2 = a_sum1 + zero;
    auto a_sum3 = a_sum2 + zero;
    auto left_abs_node = std::make_shared<op::Abs>(a_sum3);

    // Right operand: two chained additions of the same zero constant to `b`.
    auto b = make_shared<op::Parameter>(element::i32, scalar_shape);
    auto b_sum1 = b + zero;
    auto b_sum2 = b_sum1 + zero;
    auto right_abs_node = std::make_shared<op::Abs>(b_sum2);

    auto product = left_abs_node * right_abs_node;

    auto func =
        std::make_shared<Function>(ngraph::NodeVector{product}, op::ParameterVector{a, b});
    manager.run_passes(func);

    // After the rewrite the add chains are gone: each Abs reads its Parameter.
    ASSERT_EQ(product->get_argument(0)->get_argument(0), a);
    ASSERT_EQ(product->get_argument(1)->get_argument(0), b);
}
Nick Korovaiko's avatar
Nick Korovaiko committed
765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796

// A Label placed on a Skip(Broadcast) node should bind the Broadcast when one is
// present, and the bare scalar zero when the Broadcast is absent.
TEST(pattern, label_on_skip)
{
    Shape matrix_shape{2, 2};
    auto matrix_param = make_shared<op::Parameter>(element::i32, matrix_shape);
    auto scalar_param = make_shared<op::Parameter>(element::i32, Shape{});
    auto zero = ngraph::make_zero(element::i32, Shape{});
    auto any_label = std::make_shared<pattern::op::Label>(zero);
    auto zero_label =
        std::make_shared<pattern::op::Label>(zero, ngraph::is_zero, NodeVector{zero});

    // Skip predicate: true only for Broadcast nodes.
    auto is_broadcast = [](std::shared_ptr<Node> node) {
        return std::dynamic_pointer_cast<op::Broadcast>(node) != nullptr;
    };

    auto skip_broadcast = std::make_shared<pattern::op::Skip>(zero_label, is_broadcast);
    auto skip_label =
        std::make_shared<pattern::op::Label>(skip_broadcast, nullptr, NodeVector{skip_broadcast});
    auto mul_pattern = std::make_shared<op::Multiply>(any_label, skip_label);
    auto matcher = std::make_shared<pattern::Matcher>(mul_pattern, nullptr);

    auto zero_broadcast = make_shared<op::Broadcast>(zero, matrix_shape, AxisSet{0, 1});
    auto mul = matrix_param * zero_broadcast;
    auto mul_scalar = scalar_param * zero;

    // Broadcast present: skip_label binds the Broadcast, zero_label the constant under it.
    ASSERT_TRUE(matcher->match(mul));
    {
        auto bound = matcher->get_pattern_map();
        ASSERT_EQ(bound[skip_label], zero_broadcast);
        ASSERT_EQ(bound[zero_label], zero);
        ASSERT_EQ(bound[any_label], matrix_param);
    }

    // Broadcast absent: the Skip passes through, so both labels bind the constant.
    ASSERT_TRUE(matcher->match(mul_scalar));
    {
        auto bound = matcher->get_pattern_map();
        ASSERT_EQ(bound[skip_label], zero);
        ASSERT_EQ(bound[zero_label], zero);
        ASSERT_EQ(bound[any_label], scalar_param);
    }
}
Nick Korovaiko's avatar
Nick Korovaiko committed
797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816

// Exercises Matcher::is_contained_match().
// Case 1: pattern Abs(a) matched against absn, whose only consumer is a Result
//         -> expected contained.
// Case 2: pattern Abs(Abs(a)) matched against absn2; the inner absn also feeds
//         result_absn, which lies outside the match -> expected NOT contained.
TEST(pattern, is_contained_match)
{
    Shape shape{};
    auto a = make_shared<op::Parameter>(element::i32, shape);
    auto absn = make_shared<op::Abs>(a);
    TestMatcher n(nullptr);

    // The pattern references the parameter `a` directly rather than through a Label.
    // (The original also built an unused Label; it was removed.)
    auto abs_pattern = make_shared<op::Abs>(a);
    ASSERT_TRUE(n.match(abs_pattern, absn));
    // Give absn a consumer so the check has an output to inspect.
    auto result_absn = make_shared<op::Result>(absn);
    ASSERT_TRUE(n.is_contained_match());

    auto absn2 = make_shared<op::Abs>(absn);
    auto result_absn2 = make_shared<op::Result>(absn2);
    auto abs_abs_pattern = make_shared<op::Abs>(abs_pattern);
    ASSERT_TRUE(n.match(abs_abs_pattern, absn2));
    // absn is now an interior matched node that still feeds result_absn.
    ASSERT_FALSE(n.is_contained_match());
}