pattern.cpp 32.8 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
16 17 18 19 20 21 22 23

#include <algorithm>
#include <cstdio>
#include <iostream>
#include <list>
#include <memory>

#include "gtest/gtest.h"
24
#include "ngraph/file_util.hpp"
varun-intel's avatar
varun-intel committed
25
#include "ngraph/graph_util.hpp"
26 27
#include "ngraph/log.hpp"
#include "ngraph/ngraph.hpp"
28 29 30 31 32 33 34 35 36
#include "ngraph/op/add.hpp"
#include "ngraph/op/batch_norm.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/sqrt.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/op/sum.hpp"
37 38 39 40
#include "ngraph/pass/graph_rewrite.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/label.hpp"
41
#include "ngraph/pattern/op/skip.hpp"
42 43
#include "ngraph/runtime/cpu/pass/cpu_fusion.hpp"
#include "ngraph/serializer.hpp"
44
#include "util/matcher.hpp"
45
#include "util/test_tools.hpp"
46 47 48 49

using namespace ngraph;
using namespace std;

50 51 52 53 54 55 56 57
template <typename T>
std::shared_ptr<Node> create_reduction(const std::shared_ptr<Node>& node,
                                       const std::string& init_val,
                                       const AxisSet& reduction_axes)
{
    const auto& et = node->get_element_type();
    auto f_A = std::make_shared<op::Parameter>(et, Shape{});
    auto f_B = std::make_shared<op::Parameter>(et, Shape{});
58
    auto f = std::make_shared<Function>(std::make_shared<T>(f_A, f_B), ParameterVector{f_A, f_B});
59 60 61 62 63 64 65 66 67 68

    auto init = std::make_shared<op::Constant>(et, Shape{}, std::vector<std::string>({init_val}));
    return std::make_shared<op::Reduce>(node, init, f, reduction_axes);
}

// Expresses a sum over `reduction_axes` as an op::Reduce whose reduction
// function is op::Add and whose initial value is "0".
std::shared_ptr<Node> xla_sum(const std::shared_ptr<Node>& node, const AxisSet& reduction_axes)
{
    const std::string additive_identity{"0"};
    auto reduce_node = create_reduction<op::Add>(node, additive_identity, reduction_axes);
    return reduce_node;
}

69 70
static std::shared_ptr<Node> construct_constant_node(int n)
{
71 72 73 74 75 76 77 78
    return op::Constant::create(element::i32, Shape{}, {n});
}

bool sum_predicate(std::shared_ptr<Node> gn)
{
    NGRAPH_DEBUG << "pred_v2 : looking at " << gn->get_name();
    if (auto r = std::dynamic_pointer_cast<op::Reduce>(gn))
    {
79 80
        auto reducee = gn->get_argument(0);
        auto reduce_constant = gn->get_argument(1);
81

82
        if (!ngraph::is_zero(reduce_constant))
83 84 85 86
        {
            return false;
        }

87
        auto result = r->get_functions()[0]->get_result()->get_argument(0);
88 89
        NGRAPH_DEBUG << "looking at function's result  " << result->get_name();
        if (auto sum = std::dynamic_pointer_cast<op::Add>(result))
90
        {
91 92
            auto parm1 = std::dynamic_pointer_cast<op::Parameter>(sum->get_argument(0));
            auto parm2 = std::dynamic_pointer_cast<op::Parameter>(sum->get_argument(1));
93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108

            const auto parm_or_nil = [](std::shared_ptr<Node> p) {
                return p ? p->get_name() : std::string("(nil)");
            };
            NGRAPH_DEBUG << "parm1 = " << parm_or_nil(parm1) << " , parm2 = " << parm_or_nil(parm2)
                         << std::endl;
            if (parm1 && parm2 && parm1 != parm2)
            {
                return true;
            }
        }
    }

    return false;
}

109
std::shared_ptr<pattern::op::Label> construct_sum_pattern() // for the sake of explicitness
110 111
{
    return std::make_shared<pattern::op::Label>(element::i32, Shape{}, sum_predicate);
112 113
}

114 115 116 117 118 119 120 121 122 123 124 125
static std::shared_ptr<pattern::op::Label> construct_variance_graph()
{
    // construct varaiance
    auto N = op::Constant::create(element::f32, Shape{3}, {2, 2, 2});
    auto input = std::make_shared<pattern::op::Label>(element::f32, Shape{2, 3});
    auto input_sq = std::make_shared<op::Multiply>(input, input);
    auto sum_input = std::make_shared<op::Sum>(input, AxisSet{0});
    auto square_sumed_input = std::make_shared<op::Multiply>(sum_input, sum_input);
    auto sum_squared_input = std::make_shared<op::Sum>(input_sq, AxisSet{0});
    auto avg_input_sum_sq = std::make_shared<op::Divide>(square_sumed_input, N);
    auto xmu = std::make_shared<op::Subtract>(sum_squared_input, avg_input_sum_sq);
    auto variance = std::make_shared<op::Divide>(xmu, N);
126 127
    auto variance_label =
        std::make_shared<pattern::op::Label>(variance, nullptr, NodeVector{variance});
128 129 130 131 132 133

    return variance_label;
}

static std::shared_ptr<pattern::op::Label> construct_mean_graph()
{
134
    // construct mean;
135 136 137 138
    auto input = std::make_shared<pattern::op::Label>(element::f32, Shape{2, 3});
    auto N = op::Constant::create(element::f32, Shape{3}, {2, 2, 2});
    auto sum_input1 = std::make_shared<op::Sum>(input, AxisSet{0});
    auto mean = std::make_shared<op::Divide>(sum_input1, N);
139
    auto mean_label = std::make_shared<pattern::op::Label>(mean, nullptr, NodeVector{mean});
140 141 142
    return mean_label;
}

143 144 145 146 147
class TestGraphRewrite : public ngraph::pass::GraphRewrite
{
public:
    void construct_multiply_by_one()
    {
148
        // pattern #1 : a * 1 = a
149
        auto iconst1 = construct_constant_node(1);
150
        auto pattern = std::make_shared<pattern::op::Label>(iconst1);
151

152
        ngraph::pattern::graph_rewrite_callback callback = [pattern](pattern::Matcher& m) {
153
            NGRAPH_DEBUG << "In a callback for construct_multiply_by_one against "
154 155
                         << m.get_match_root()->get_name();
            assert(m.get_match_root()->get_arguments().size() == 2);
156

157 158
            auto pattern_map = m.get_pattern_map();

159 160
            size_t const_node_index =
                m.get_match_root()->get_arguments().at(0) == pattern_map[pattern];
161
            auto const_node = dynamic_pointer_cast<op::Constant>(
162 163
                m.get_match_root()->get_arguments().at(const_node_index));
            auto second_node = m.get_match_root()->get_arguments().at(const_node_index);
164 165
            NGRAPH_DEBUG << "second_node = " << second_node->get_name()
                         << " , pattern = " << pattern_map[pattern]->get_name();
166

167 168
            if (pattern_map[pattern]->get_element_type() != const_node->get_element_type() ||
                pattern_map[pattern]->get_shape() != const_node->get_shape())
169
            {
170
                NGRAPH_DEBUG << "Operands' types and/or shape don't match";
171
                return false;
172 173
            }

174
            auto const_values = const_node->get_vector<int32_t>();
175 176 177 178 179
            bool all_ones =
                std::all_of(begin(const_values), end(const_values), [](int e) { return e == 1; });

            if (!all_ones)
            {
180
                NGRAPH_DEBUG << "Constant vector's values aren't equal to 1";
181
                return false;
182
            }
183

184
            ngraph::replace_node(m.get_match_root(), pattern_map[pattern]);
185
            return true;
186 187 188 189 190 191 192 193
        };

        auto m = make_shared<TestMatcher>(pattern * iconst1, callback);
        this->add_matcher(m);
    }

    void construct_add_zero()
    {
194
        // pattern #2 : a + 0 = a
195
        auto iconst0 = construct_constant_node(0);
196
        auto pattern = std::make_shared<pattern::op::Label>(iconst0);
197

198
        auto callback = [pattern](pattern::Matcher& m) {
199
            NGRAPH_DEBUG << "In a callback for construct_add_zero against "
200 201
                         << m.get_match_root()->get_name();
            assert(m.get_match_root()->get_arguments().size() == 2);
202

203 204
            auto pattern_map = m.get_pattern_map();

205 206
            size_t const_node_index =
                m.get_match_root()->get_arguments().at(0) == pattern_map[pattern];
207
            auto const_node = dynamic_pointer_cast<op::Constant>(
208 209
                m.get_match_root()->get_arguments().at(const_node_index));
            auto second_node = m.get_match_root()->get_arguments().at(const_node_index);
210 211
            NGRAPH_DEBUG << "second_node = " << second_node->get_name()
                         << " , pattern = " << pattern_map[pattern]->get_name();
212

213 214
            if (pattern_map[pattern]->get_element_type() != const_node->get_element_type() ||
                pattern_map[pattern]->get_shape() != const_node->get_shape())
215
            {
216
                NGRAPH_DEBUG << "Operands' types and/or shape don't match";
217
                return false;
218 219
            }

220
            auto const_values = const_node->get_vector<int>();
221 222 223 224 225
            bool all_zeros =
                std::all_of(begin(const_values), end(const_values), [](int e) { return e == 0; });

            if (!all_zeros)
            {
226
                NGRAPH_DEBUG << "Constant vector's values aren't equal to 0";
227
                return false;
228 229
            }

230
            ngraph::replace_node(m.get_match_root(), pattern_map[pattern]);
231
            return true;
232 233 234 235 236 237
        };

        auto m = make_shared<TestMatcher>(pattern + iconst0, callback);
        this->add_matcher(m);
    }

238 239 240 241
    void construct_sum()
    {
        auto sum_pattern = construct_sum_pattern();

242
        ngraph::pattern::graph_rewrite_callback callback = [](pattern::Matcher& m) {
243
            NGRAPH_DEBUG << "In a callback for construct_sum_pattern against "
244 245
                         << m.get_match_root()->get_name();
            auto reduce = std::dynamic_pointer_cast<op::Reduce>(m.get_match_root());
246 247
            auto reducee = reduce->get_inputs().at(0).get_output().get_node();
            NGRAPH_DEBUG << "reducee = " << reducee->get_name();
248 249
            auto sum =
                std::shared_ptr<ngraph::Node>(new op::Sum(reducee, reduce->get_reduction_axes()));
250

251
            ngraph::replace_node(m.get_match_root(), sum);
252
            return true;
253 254 255 256 257 258
        };

        auto m = make_shared<TestMatcher>(sum_pattern, callback);
        this->add_matcher(m);
    }

259 260 261 262 263
    TestGraphRewrite()
        : GraphRewrite()
    {
        construct_multiply_by_one();
        construct_add_zero();
264
        construct_sum();
265 266 267 268 269 270 271
    }
};

// Wraps `graph` and `parms` in a Function and runs `pass_manager` over it.
static void run_passes(pass::Manager& pass_manager,
                       shared_ptr<Node> graph,
                       std::vector<shared_ptr<op::Parameter>> parms)
{
    auto func = make_shared<Function>(graph, ParameterVector{parms});
    pass_manager.run_passes(func);
}

// Runs TestGraphRewrite over several small graphs and checks that the
// "+ 0", "* 1" and Reduce-to-Sum rewrites rewire the graph as expected.
TEST(pattern, graph_rewrite)
{
    Shape shape{};
    pass::Manager pass_manager;
    pass_manager.register_pass<TestGraphRewrite>();

    // Two independent "+ 0" nodes that are both function results.
    {
        auto a = make_shared<op::Parameter>(element::i32, shape);
        auto b = make_shared<op::Parameter>(element::i32, shape);
        auto c = make_shared<op::Parameter>(element::i32, shape);
        auto iconst0 = construct_constant_node(0);
        auto graph_a = a + iconst0;
        auto graph_b = b + iconst0;

        auto f = std::make_shared<Function>(ngraph::NodeVector{a, b, graph_a, c, graph_b},
                                            ParameterVector{a, b, c});
        pass_manager.run_passes(f);

        ASSERT_TRUE(graph_a->get_output_inputs(0).empty());
        ASSERT_TRUE(graph_b->get_output_inputs(0).empty());

        ASSERT_TRUE(count_ops_of_type<op::Add>(f) == 0);
    }

    // b + (a + 0)
    {
        auto a = make_shared<op::Parameter>(element::i32, shape);
        auto b = make_shared<op::Parameter>(element::i32, shape);
        auto iconst0 = construct_constant_node(0);
        auto sum = (a + iconst0);
        auto graph = b + sum;
        run_passes(pass_manager, graph, {a, b});
        ASSERT_EQ(graph->get_arguments().at(1), a);
        ASSERT_EQ(&graph->get_inputs().at(1).get_output(),
                  &a->get_outputs().at(0)); // graph's input points to a's output
        ASSERT_TRUE(sum->get_output_inputs(0)
                        .empty()); // graph's input is removed from sum's output.get_inputs()
        ASSERT_TRUE(a->get_outputs().at(0).get_inputs().count(
            &graph->get_inputs().at(1))); // a's output feeds into graph's input
    }

    // b + (a * 1)
    {
        auto a = make_shared<op::Parameter>(element::i32, shape);
        auto b = make_shared<op::Parameter>(element::i32, shape);
        auto iconst1 = construct_constant_node(1);
        auto mul = (a * iconst1);
        auto graph = b + mul;
        run_passes(pass_manager, graph, {a, b});
        ASSERT_EQ(graph->get_arguments().at(1), a);
        ASSERT_EQ(&graph->get_inputs().at(1).get_output(),
                  &a->get_outputs().at(0)); // graph's input points to a's output
        ASSERT_TRUE(mul->get_outputs()
                        .at(0)
                        .get_inputs()
                        .empty()); // graph's input is removed from mul's output.get_inputs()
        ASSERT_TRUE(a->get_outputs().at(0).get_inputs().count(
            &graph->get_inputs().at(1))); // a's output feeds into graph's input
    }

    // Chain of "* 1" with the constant on the right.
    {
        auto a = make_shared<op::Parameter>(element::i32, shape);
        auto b = make_shared<op::Parameter>(element::i32, shape);
        auto iconst1 = construct_constant_node(1);
        auto graph = ((((a * iconst1) * iconst1) * iconst1) * iconst1) + b;
        run_passes(pass_manager, graph, {a, b});
        ASSERT_EQ(graph->get_arguments().at(0), a);
        ASSERT_EQ(&graph->get_inputs().at(0).get_output(),
                  &a->get_outputs().at(0)); // graph's input points to a's output
        ASSERT_TRUE(a->get_outputs().at(0).get_inputs().count(
            &graph->get_inputs().at(0))); // a's output feeds into graph's input
    }

    // Mix of "+ 0" and "* 1".
    {
        auto a = make_shared<op::Parameter>(element::i32, shape);
        auto b = make_shared<op::Parameter>(element::i32, shape);
        auto iconst0 = construct_constant_node(0);
        auto iconst1 = construct_constant_node(1);
        auto graph = b + (iconst0 + ((a + iconst0) * iconst1));
        run_passes(pass_manager, graph, {a, b});
        ASSERT_EQ(graph->get_arguments().at(1), a);
        ASSERT_EQ(&graph->get_inputs().at(1).get_output(),
                  &a->get_outputs().at(0)); // graph's input points to a's output
        ASSERT_TRUE(a->get_outputs().at(0).get_inputs().count(
            &graph->get_inputs().at(1))); // a's output feeds into graph's input
    }

    // Chain of "1 *" with the constant on the left.
    {
        auto a = make_shared<op::Parameter>(element::i32, shape);
        auto b = make_shared<op::Parameter>(element::i32, shape);
        auto iconst1 = construct_constant_node(1);
        auto graph = b + (iconst1 * (iconst1 * (iconst1 * (iconst1 * a))));
        run_passes(pass_manager, graph, {a, b});
        ASSERT_EQ(graph->get_arguments().at(1), a);
        ASSERT_EQ(&graph->get_inputs().at(1).get_output(),
                  &a->get_outputs().at(0)); // graph's input points to a's output
        ASSERT_TRUE(a->get_outputs().at(0).get_inputs().count(
            &graph->get_inputs().at(1))); // a's output feeds into graph's input
    }

    // Sum rewrite
    {
        auto parm = make_shared<op::Parameter>(element::i32, Shape{2, 2});
        auto axes = AxisSet{0, 1};
        auto sum_graph = xla_sum(parm, axes);
        auto innermost_abs = make_shared<op::Abs>(sum_graph);

        auto nested_sum_graph = make_shared<op::Abs>(
            make_shared<op::Abs>(make_shared<op::Abs>(make_shared<op::Abs>(innermost_abs))));

        run_passes(pass_manager, nested_sum_graph, {parm});
        auto sum = std::dynamic_pointer_cast<op::Sum>(innermost_abs->get_argument(0));
        ASSERT_TRUE(sum);
        ASSERT_EQ(sum->get_reduction_axes(), axes);
        ASSERT_EQ(sum->get_argument(0), parm);
    }
}

393 394 395 396 397 398 399 400 401 402 403
// Streams a NodeVector as the vector_to_string() rendering of its nodes' names.
std::ostream& operator<<(std::ostream& os, const ngraph::NodeVector& nv)
{
    std::vector<std::string> names;
    names.reserve(nv.size());
    // Bind by reference: copying each shared_ptr would only churn its refcount.
    for (const auto& n : nv)
    {
        names.push_back(n->get_name());
    }
    os << vector_to_string(names);
    return os;
}

404 405
// Exercises TestMatcher::match against Skip, Label, Any and AnyOf pattern
// nodes, and checks the resulting pattern map and matched-node list.
TEST(pattern, matcher)
{
    Shape shape{};
    auto a = make_shared<op::Parameter>(element::i32, shape);
    TestMatcher n(nullptr);
    ASSERT_TRUE(n.match(a, a));
    ASSERT_EQ(n.get_matched_nodes(), (NodeVector{a}));

    // Skip
    auto abs = make_shared<op::Abs>(a);
    auto any = std::make_shared<pattern::op::Skip>(a);
    ASSERT_TRUE(n.match(any, abs));
    ASSERT_EQ(n.get_matched_nodes(), (NodeVector{abs, a}));

    // Skip with a predicate that always rejects
    auto false_pred = [](std::shared_ptr<Node>) { return false; };
    auto any_false = std::make_shared<pattern::op::Skip>(a, false_pred);
    ASSERT_TRUE(n.match(any_false, a));
    ASSERT_EQ(n.get_matched_nodes(), (NodeVector{a, a}));

    // Label
    auto pattern = std::make_shared<pattern::op::Label>(a);
    ASSERT_TRUE(n.match(pattern, a));
    ASSERT_EQ(n.get_pattern_map()[pattern], a);
    ASSERT_EQ(n.get_matched_nodes(), (NodeVector{a}));

    // Label with a predicate that always rejects
    auto pattern_false = std::make_shared<pattern::op::Label>(a, false_pred);
    ASSERT_FALSE(n.match(pattern_false, a));
    ASSERT_EQ(n.get_matched_nodes(), (NodeVector{}));

    auto b = make_shared<op::Parameter>(element::i32, shape);

    // Any
    auto is_bea = pattern::has_class<op::util::BinaryElementwiseArithmetic>();
    auto bea = std::make_shared<pattern::op::Any>(a, is_bea, NodeVector{a, b});
    auto add_ab = a + b;
    ASSERT_TRUE(n.match(bea, add_ab));
    ASSERT_EQ(n.get_matched_nodes(), (NodeVector{add_ab, a, b}));
    ASSERT_TRUE(n.match(bea, b + a));

    auto bea_false = std::make_shared<pattern::op::Any>(a, false_pred, NodeVector{a, b});
    ASSERT_FALSE(n.match(bea_false, a + b));

    // AnyOf
    auto add_abs_b = abs + b;
    auto bea_any_of = std::make_shared<pattern::op::AnyOf>(a, is_bea, NodeVector{abs});
    ASSERT_TRUE(n.match(bea_any_of, add_abs_b));

    auto add_b_abs = b + abs;
    ASSERT_TRUE(n.match(bea_any_of, add_b_abs));

    auto bea_any_of_label =
        std::make_shared<pattern::op::Label>(a, nullptr, NodeVector{bea_any_of});
    ASSERT_TRUE(n.match(bea_any_of_label, add_b_abs));
    ASSERT_EQ(n.get_pattern_map()[bea_any_of_label], add_b_abs);

    auto abs_label = std::make_shared<pattern::op::Label>(a, nullptr, NodeVector{abs});
    auto bea_label_any_of = std::make_shared<pattern::op::AnyOf>(a, is_bea, NodeVector{abs_label});
    ASSERT_TRUE(n.match(bea_label_any_of, add_b_abs));
    ASSERT_EQ(n.get_pattern_map()[abs_label], abs);

    auto bea_label = std::make_shared<pattern::op::Label>(a, nullptr, NodeVector{bea});
    auto ab = a + b;
    ASSERT_TRUE(n.match(bea_label, ab));
    ASSERT_EQ(n.get_pattern_map()[bea_label], ab);

    auto d = make_shared<op::Parameter>(element::i32, shape);
    ASSERT_FALSE(n.match(d, b));

    ASSERT_FALSE(n.match(abs + b, b + b));
    ASSERT_EQ(n.get_matched_nodes(), (NodeVector{}));

    auto add_absb = abs + b;
    ASSERT_TRUE(n.match(any + b, add_absb));
    ASSERT_EQ(n.get_matched_nodes(), (NodeVector{add_absb, abs, a, b}));

    ASSERT_TRUE(n.match(pattern + b, add_absb));
    ASSERT_EQ(n.get_pattern_map()[pattern], abs);
    ASSERT_EQ(n.get_matched_nodes(), (NodeVector{add_absb, abs, b}));

    ASSERT_TRUE(n.match(b + pattern, add_absb));
    ASSERT_EQ(n.get_pattern_map()[pattern], abs);
    ASSERT_EQ(n.get_matched_nodes(), (NodeVector{add_absb, abs, b}));

    auto c = make_shared<op::Parameter>(element::i32, shape);
    auto mul_add_absb = c * (add_absb);
    ASSERT_TRUE(n.match(c * (b + pattern), mul_add_absb));
    ASSERT_EQ(n.get_pattern_map()[pattern], abs);
    ASSERT_EQ(n.get_matched_nodes(), (NodeVector{mul_add_absb, c, add_absb, abs, b}));

    ASSERT_TRUE(n.match(c * (any + b), mul_add_absb)); // nested any
    ASSERT_EQ(n.get_matched_nodes(), (NodeVector{mul_add_absb, c, add_absb, abs, a, b}));
    ASSERT_TRUE(n.match(c * (any + b), (b + abs) * c)); // permutations w/ any
    auto mul_c_add_ab = c * add_ab;
    ASSERT_TRUE(n.match(c * (any_false + b), c * (a + b)));  // nested any
    ASSERT_TRUE(n.match(c * (any_false + b), mul_c_add_ab)); // permutations w/ any_false
    ASSERT_EQ(n.get_matched_nodes(), (NodeVector{mul_c_add_ab, c, add_ab, a, a, b}));

    auto iconst1_0 = construct_constant_node(1);
    auto iconst1_1 = construct_constant_node(1);
    ASSERT_TRUE(n.match(pattern * iconst1_0, a * iconst1_1)); // different iconst
    ASSERT_EQ(n.get_pattern_map()[pattern], a);
    auto fconst1_0 = op::Constant::create(element::f32, shape, {1});
    auto patternf = std::make_shared<pattern::op::Label>(fconst1_0);
    // f32 constant in the pattern, i32 constant in the graph
    ASSERT_TRUE(n.match(patternf * fconst1_0, a * iconst1_1));

    // Subgraph labels
    auto add = a + b;
    auto label = std::make_shared<pattern::op::Label>(add, nullptr, NodeVector{add});
    ASSERT_TRUE(n.match(label, add));
    ASSERT_EQ(n.get_pattern_map()[label], add);
    ASSERT_EQ(n.get_matched_nodes(), (NodeVector{add, add, a, b}));

    ASSERT_FALSE(n.match(label, a - b));

    ASSERT_TRUE(n.match(make_shared<op::Abs>(label), make_shared<op::Abs>(add)));
    ASSERT_EQ(n.get_pattern_map()[label], add);

    // Correct argument order
    ASSERT_FALSE(n.match(b - a, a - b));
    auto aab = a * (a - b);
    auto paab = pattern * (pattern - b);
    ASSERT_TRUE(n.match(paab, aab));
    auto aba = a * (b - a);
    ASSERT_FALSE(n.match(paab, aba));
    auto paba = pattern * (b - pattern);
    ASSERT_FALSE(n.match(paba, aab));

    // Correlations
    auto label1 = std::make_shared<pattern::op::Label>(a);
    auto tmp = label1 + b;
    auto label2 = std::make_shared<pattern::op::Label>(tmp, nullptr, NodeVector{tmp});
    auto sub_label1 = label1 - label2;
    auto sub_add = a - add;
    ASSERT_TRUE(n.match(sub_label1, sub_add));
    ASSERT_EQ(n.get_pattern_map()[label1], a);
    ASSERT_EQ(n.get_pattern_map()[label2], add);
    ASSERT_EQ(n.get_matched_nodes(), (NodeVector{sub_add, a, add, add, a, b}));

    ASSERT_FALSE(n.match(sub_label1, add - a));

    auto add_label1 = label1 + label2;
    ASSERT_TRUE(n.match(add_label1, add + a));
    ASSERT_EQ(n.get_pattern_map()[label1], a);
    ASSERT_EQ(n.get_pattern_map()[label2], add);
}
545 546 547

// Checks that the label from construct_sum_pattern() binds to an xla_sum()
// graph, both directly and underneath five nested op::Abs nodes.
TEST(pattern, sum)
{
    // Sum
    TestMatcher n(nullptr);
    auto reducee_const = std::make_shared<op::Constant>(
        element::i32, Shape{2, 2}, std::vector<std::string>({"0", "0", "0", "0"}));
    auto sum_graph = xla_sum(reducee_const, AxisSet{0, 1});

    auto reduce_label = construct_sum_pattern();
    ASSERT_TRUE(n.match(reduce_label, sum_graph));
    ASSERT_EQ(n.get_pattern_map()[reduce_label], sum_graph);

    auto nested_sum_graph = make_shared<op::Abs>(make_shared<op::Abs>(
        make_shared<op::Abs>(make_shared<op::Abs>(make_shared<op::Abs>(sum_graph)))));

    auto nested_reduce_label = make_shared<op::Abs>(make_shared<op::Abs>(
        make_shared<op::Abs>(make_shared<op::Abs>(make_shared<op::Abs>(reduce_label)))));

    ASSERT_TRUE(n.match(nested_reduce_label, nested_sum_graph));
    ASSERT_EQ(n.get_pattern_map()[reduce_label], sum_graph);
}
567 568 569

// Checks that the label from construct_mean_graph() binds to a
// sum(input, axis 0) / N graph built on an f32 parameter of shape {2, 3}.
TEST(pattern, mean)
{
    // construct mean
    TestMatcher n(nullptr);

    auto input = std::make_shared<op::Parameter>(element::f32, Shape{2, 3});
    auto N = op::Constant::create(element::f32, Shape{3}, {2, 2, 2});
    auto sum_input1 = std::make_shared<op::Sum>(input, AxisSet{0});
    auto mean = std::make_shared<op::Divide>(sum_input1, N);

    auto mean_graph = construct_mean_graph();
    ASSERT_TRUE(n.match(mean_graph, mean));
    ASSERT_EQ(n.get_pattern_map()[mean_graph], mean);
}

// Checks that the label from construct_variance_graph() binds to a graph with
// the same node structure as the one that helper builds.
TEST(pattern, variance)
{
    // construct variance
    TestMatcher n(nullptr);
    auto N = op::Constant::create(element::f32, Shape{3}, {2, 2, 2});
    auto input = std::make_shared<pattern::op::Label>(element::f32, Shape{2, 3});
    auto input_sq = std::make_shared<op::Multiply>(input, input);
    auto sum_input = std::make_shared<op::Sum>(input, AxisSet{0});
    auto square_sumed_input = std::make_shared<op::Multiply>(sum_input, sum_input);
    auto sum_squared_input = std::make_shared<op::Sum>(input_sq, AxisSet{0});
    auto avg_input_sum_sq = std::make_shared<op::Divide>(square_sumed_input, N);
    auto xmu = std::make_shared<op::Subtract>(sum_squared_input, avg_input_sum_sq);
    auto variance = std::make_shared<op::Divide>(xmu, N);

    auto var_graph = construct_variance_graph();
    ASSERT_TRUE(n.match(var_graph, variance));
    ASSERT_EQ(n.get_pattern_map()[var_graph], variance);
}
601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626

// Checks Matcher::match when it is handed a map of earlier bindings: with an
// empty map the label binds to abs(a); once the map already binds the label to
// `a`, matching the same graph is expected to fail.
TEST(pattern, previous_matches)
{
    using ngraph::pattern::Matcher;

    Shape scalar{};
    Matcher::PatternMap seeded_bindings;
    auto param_a = make_shared<op::Parameter>(element::i32, scalar);
    auto param_b = make_shared<op::Parameter>(element::i32, scalar);
    auto wildcard = std::make_shared<pattern::op::Label>(param_b);
    auto abs_a = make_shared<op::Abs>(param_a);
    auto abs_plus_b = abs_a + param_b;

    // No earlier bindings: the label is free to bind to abs(a).
    {
        Matcher unconstrained(wildcard + param_b);
        ASSERT_TRUE(unconstrained.match(abs_plus_b, seeded_bindings));
        ASSERT_EQ(unconstrained.get_pattern_map()[wildcard], abs_a);
    }

    // The label is already bound to `a`, so the same graph is expected not to match.
    {
        Matcher constrained(wildcard + param_b);
        seeded_bindings.insert(std::make_pair(wildcard, param_a));
        ASSERT_FALSE(constrained.match(abs_plus_b, seeded_bindings));
    }
}

// Exercises RecurrentMatcher on chains of additions: a single recurring label,
// several labels per recurrence, correlated labels, and re-use of one matcher
// for a second graph.
TEST(pattern, recurrent_pattern)
{
    using ngraph::pattern::RecurrentMatcher;
    Shape shape{};
    auto b = make_shared<op::Parameter>(element::i32, shape);
    auto rpattern = std::make_shared<pattern::op::Label>(b);
    auto iconst0 = construct_constant_node(0);
    auto add1 = iconst0 + b;
    auto add2 = iconst0 + add1;
    auto add3 = iconst0 + add2;
    auto padd = iconst0 + rpattern;
    std::set<std::shared_ptr<pattern::op::Label>> empty_correlated_matches;
    RecurrentMatcher rm(padd, rpattern, empty_correlated_matches, nullptr);
    ASSERT_TRUE(rm.match(add3));
    ASSERT_EQ(rm.get_number_of_bound_labels(), 1);
    auto recurrent_matches = rm.get_bound_nodes_for_pattern(rpattern);
    ASSERT_EQ(recurrent_matches.at(0), add2);
    ASSERT_EQ(recurrent_matches.at(1), add1);
    ASSERT_EQ(recurrent_matches.at(2), b);

    // Multiple labels in a recurring pattern
    auto iconst1 = construct_constant_node(1);
    auto iconst_label = std::make_shared<pattern::op::Label>(iconst1, nullptr, NodeVector{iconst1});
    auto add2_2 = iconst1 + add1;
    auto add3_2 = iconst0 + add2_2;
    auto padd2 = iconst_label + rpattern;
    RecurrentMatcher rm2(padd2, rpattern, empty_correlated_matches, nullptr);
    ASSERT_TRUE(rm2.match(add3_2));
    ASSERT_EQ(rm2.get_number_of_bound_labels(), 2);
    recurrent_matches = rm2.get_bound_nodes_for_pattern(rpattern);
    ASSERT_EQ(recurrent_matches.at(0), add2_2);
    ASSERT_EQ(recurrent_matches.at(1), add1);
    ASSERT_EQ(recurrent_matches.at(2), b);
    auto iconst_matches = rm2.get_bound_nodes_for_pattern(iconst_label);
    ASSERT_EQ(iconst_matches.at(0), iconst0);
    ASSERT_EQ(iconst_matches.at(1), iconst1);
    ASSERT_EQ(iconst_matches.at(2), iconst0);

    // Non-matching correlated labels
    std::set<std::shared_ptr<pattern::op::Label>> correlated_matches;
    correlated_matches.insert(iconst_label);
    RecurrentMatcher rm3(padd2, rpattern, correlated_matches, nullptr);
    ASSERT_TRUE(rm3.match(add3_2));
    ASSERT_EQ(rm3.get_number_of_bound_labels(), 2);
    iconst_matches = rm3.get_bound_nodes_for_pattern(iconst_label);
    ASSERT_EQ(iconst_matches.size(), 1);
    ASSERT_EQ(iconst_matches.at(0), iconst0);

    // Matching correlated labels and
    // testing if RecurrentMatcher can be reused for different nodes
    ASSERT_TRUE(rm3.match(add3));
    ASSERT_EQ(rm3.get_number_of_bound_labels(), 2);
    recurrent_matches = rm3.get_bound_nodes_for_pattern(rpattern);
    ASSERT_EQ(recurrent_matches.at(0), add2);
    ASSERT_EQ(recurrent_matches.at(1), add1);
    ASSERT_EQ(recurrent_matches.at(2), b);
    iconst_matches = rm3.get_bound_nodes_for_pattern(iconst_label);
    ASSERT_EQ(iconst_matches.at(0), iconst0);
    ASSERT_EQ(iconst_matches.at(1), iconst0);
    ASSERT_EQ(iconst_matches.at(2), iconst0);
}
689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711

// Test-only RecurrentGraphRewrite pass. On construction it registers one
// recurrent matcher for the pattern `iconst_label + rpattern`; when every node
// bound to `iconst_label` is a zero, the callback replaces the match root with
// the last node bound to `rpattern`.
class TestRecurrentGraphRewrite : public ngraph::pass::RecurrentGraphRewrite
{
public:
    // Builds the `iconst_label + rpattern` pattern, wires up the rewrite
    // callback and adds the resulting RecurrentMatcher to this pass.
    void construct_recurrent_add()
    {
        Shape shape{};
        auto iconst0 = construct_constant_node(0);
        auto iconst_label =
            std::make_shared<pattern::op::Label>(iconst0, nullptr, NodeVector{iconst0});
        auto rpattern = std::make_shared<pattern::op::Label>(element::i32, shape);
        auto padd = iconst_label + rpattern;

        // NOTE(review): the result is not used anywhere in this function; kept
        // as-is in case construct_sum_pattern() has effects the test relies on.
        auto sum_pattern = construct_sum_pattern();

        ngraph::pattern::recurrent_graph_rewrite_callback callback = [iconst_label, rpattern](
            pattern::RecurrentMatcher& rm) {
            NGRAPH_DEBUG << "In a callback for construct_recurrent_add against "
                         << rm.get_match_root()->get_name();

            auto iconst_matches = rm.get_bound_nodes_for_pattern(iconst_label);

            // Evaluate is_zero once so the logged value is the returned value.
            auto is_iconst_zero = [](const std::shared_ptr<Node>& n) {
                bool result = ngraph::is_zero(n);
                NGRAPH_DEBUG << n->get_name() << " is " << (result ? " a zero " : " not a zero");
                return result;
            };

            bool are_all_iconst_zeros =
                std::all_of(iconst_matches.begin(), iconst_matches.end(), is_iconst_zero);

            // Only collapse the chain when every matched constant is a zero.
            if (!are_all_iconst_zeros)
            {
                return false;
            }

            auto number_of_adds = rm.get_number_of_recurrent_matches();
            // replace the topmost add with the seed (i.e. the first parameter to add)
            // matches are added in reverse order (i.e. the first match is the topmost node)
            auto arg = rm.get_bound_nodes_for_pattern(rpattern).at(number_of_adds - 1);
            NGRAPH_DEBUG << "Replacing " << rm.get_match_root()->get_name() << " with "
                         << arg->get_name();
            ngraph::replace_node(rm.get_match_root(), arg);
            return true;
        };

        std::set<std::shared_ptr<pattern::op::Label>> empty_correlated_matches;
        auto rm = make_shared<pattern::RecurrentMatcher>(
            padd, rpattern, empty_correlated_matches, callback);
        this->add_matcher(rm);
    }

    // Registers the recurrent-add matcher as soon as the pass is created.
    TestRecurrentGraphRewrite()
        : RecurrentGraphRewrite()
    {
        construct_recurrent_add();
    }
};

// Runs TestRecurrentGraphRewrite over a function containing two chains of
// additions with the constant built by construct_constant_node(0) (three adds
// on `a`, two on `b`), each feeding an Abs. After the pass, each Abs is
// expected to take its parameter directly as its argument.
TEST(pattern, recurrent_graph_rewrite)
{
    Shape shape{};
    pass::Manager pass_manager;
    pass_manager.register_pass<TestRecurrentGraphRewrite>();

    {
        // Left branch: abs(((a + 0) + 0) + 0)
        auto a = make_shared<op::Parameter>(element::i32, shape);
        auto iconst0 = construct_constant_node(0);
        auto add_a1 = a + iconst0;
        auto add_a2 = add_a1 + iconst0;
        auto add_a3 = add_a2 + iconst0;
        auto abs_add_a3 = std::make_shared<op::Abs>(add_a3);

        // Right branch: abs((b + 0) + 0)
        auto b = make_shared<op::Parameter>(element::i32, shape);
        auto add_b1 = b + iconst0;
        auto add_b2 = add_b1 + iconst0;
        auto abs_add_b2 = std::make_shared<op::Abs>(add_b2);

        auto graph = abs_add_a3 * abs_add_b2;

        auto f = std::make_shared<Function>(ngraph::NodeVector{graph}, ParameterVector{a, b});
        pass_manager.run_passes(f);

        // The argument of the left Abs must now be `a` itself.
        auto left_abs = graph->get_argument(0);
        auto add_a = left_abs->get_argument(0);
        ASSERT_EQ(add_a, a);

        // The argument of the right Abs must now be `b` itself.
        auto right_abs = graph->get_argument(1);
        auto add_b = right_abs->get_argument(0);
        ASSERT_EQ(add_b, b);
    }
}

// Matches `label * Label(Skip(const_label, is-Broadcast))` against two graphs:
// a multiply by a broadcast zero constant, and a multiply by the scalar zero
// constant itself. Verifies which node each label is bound to in both cases.
TEST(pattern, label_on_skip)
{
    Shape matrix_shape{2, 2};
    auto matrix_param = std::make_shared<op::Parameter>(element::i32, matrix_shape);
    auto scalar_param = std::make_shared<op::Parameter>(element::i32, Shape{});
    auto zero = ngraph::make_zero(element::i32, Shape{});

    // Unconstrained label for the non-constant operand.
    auto any_label = std::make_shared<pattern::op::Label>(zero);
    // Label that only accepts nodes for which ngraph::is_zero holds.
    auto zero_label =
        std::make_shared<pattern::op::Label>(zero, ngraph::is_zero, NodeVector{zero});

    // Predicate handed to Skip: true only for Broadcast nodes.
    auto is_broadcast = [](std::shared_ptr<Node> candidate) {
        return static_cast<bool>(std::dynamic_pointer_cast<op::Broadcast>(candidate));
    };

    auto skip_node = std::make_shared<pattern::op::Skip>(zero_label, is_broadcast);
    auto skip_label =
        std::make_shared<pattern::op::Label>(skip_node, nullptr, NodeVector{skip_node});
    auto pattern_root = std::make_shared<op::Multiply>(any_label, skip_label);
    auto matcher = std::make_shared<pattern::Matcher>(pattern_root, nullptr);

    auto broadcast_zero = std::make_shared<op::Broadcast>(zero, matrix_shape, AxisSet{0, 1});
    auto mul_broadcast = matrix_param * broadcast_zero;
    auto mul_scalar = scalar_param * zero;

    // Case 1: the constant is reached through a Broadcast.
    ASSERT_TRUE(matcher->match(mul_broadcast));
    auto broadcast_bindings = matcher->get_pattern_map();
    ASSERT_EQ(broadcast_bindings[skip_label], broadcast_zero);
    ASSERT_EQ(broadcast_bindings[zero_label], zero);
    ASSERT_EQ(broadcast_bindings[any_label], matrix_param);

    // Case 2: no Broadcast in between; both labels bind to the constant.
    ASSERT_TRUE(matcher->match(mul_scalar));
    auto scalar_bindings = matcher->get_pattern_map();
    ASSERT_EQ(scalar_bindings[skip_label], zero);
    ASSERT_EQ(scalar_bindings[zero_label], zero);
    ASSERT_EQ(scalar_bindings[any_label], scalar_param);
}

// Exercises TestMatcher::is_contained_match(): it is expected to be true after
// matching a single Abs, and false after matching a two-level Abs chain.
// Statement order matters here because Result nodes are attached between the
// match and the containment check.
TEST(pattern, is_contained_match)
{
    Shape scalar_shape{};
    auto param = std::make_shared<op::Parameter>(element::i32, scalar_shape);
    auto inner_abs = std::make_shared<op::Abs>(param);
    TestMatcher matcher(nullptr);

    // NOTE(review): param_label is never used below; the pattern is built on
    // `param` directly. Kept to preserve the original behaviour - confirm intent.
    auto param_label = std::make_shared<pattern::op::Label>(param);
    auto abs_pattern = std::make_shared<op::Abs>(param);
    ASSERT_TRUE(matcher.match(abs_pattern, inner_abs));
    auto inner_result = std::make_shared<op::Result>(inner_abs);
    ASSERT_TRUE(matcher.is_contained_match());

    // Second scenario: abs(abs(param)), where the inner Abs also feeds
    // inner_result - presumably the reason the match is not contained.
    auto outer_abs = std::make_shared<op::Abs>(inner_abs);
    auto outer_result = std::make_shared<op::Result>(outer_abs);
    auto nested_abs_pattern = std::make_shared<op::Abs>(abs_pattern);
    ASSERT_TRUE(matcher.match(nested_abs_pattern, outer_abs));
    ASSERT_FALSE(matcher.is_contained_match());
}