pattern.cpp 18.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ----------------------------------------------------------------------------

#include <algorithm>
#include <cassert>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <list>
#include <memory>
#include <string>
#include <vector>

#include "gtest/gtest.h"

#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/ops/sum.hpp"
#include "ngraph/pass/graph_rewrite.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/any.hpp"
#include "ngraph/pattern/op/label.hpp"
#include "util/matcher.hpp"

using namespace ngraph;
using namespace std;

36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54
template <typename T>
std::shared_ptr<Node> create_reduction(const std::shared_ptr<Node>& node,
                                       const std::string& init_val,
                                       const AxisSet& reduction_axes)
{
    const auto& et = node->get_element_type();
    auto f_A = std::make_shared<op::Parameter>(et, Shape{});
    auto f_B = std::make_shared<op::Parameter>(et, Shape{});
    auto f = std::make_shared<Function>(std::make_shared<T>(f_A, f_B), op::Parameters{f_A, f_B});

    auto init = std::make_shared<op::Constant>(et, Shape{}, std::vector<std::string>({init_val}));
    return std::make_shared<op::Reduce>(node, init, f, reduction_axes);
}

/// Sum of `node` over `reduction_axes`, expressed in the generic op::Reduce form:
/// an op::Add reduction function with initial value "0".
std::shared_ptr<Node> xla_sum(const std::shared_ptr<Node>& node, const AxisSet& reduction_axes)
{
    const std::string additive_identity = "0";
    return create_reduction<op::Add>(node, additive_identity, reduction_axes);
}

55 56
static std::shared_ptr<Node> construct_constant_node(int n)
{
57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123
    return op::Constant::create(element::i32, Shape{}, {n});
}

bool is_equal_to_const_value(std::string const_value, std::shared_ptr<Node> reduce_constant)
{
    if (auto rc = std::dynamic_pointer_cast<op::Constant>(reduce_constant))
    {
        auto cshape = rc->get_shape();
        size_t n = shape_size(cshape);
        //awkward(but generic) way to construct a constant of a given type, shape, value
        std::vector<std::string> vz{n, const_value};
        auto zero_constant = std::make_shared<op::Constant>(rc->get_element_type(), cshape, vz);

        //equally awkward way to compare elements to const_value
        size_t n_bytes = n * rc->get_element_type().size();
        NGRAPH_DEBUG << "Comparing " << n_bytes << " bytes";
        return !memcmp(zero_constant->get_data_ptr(), rc->get_data_ptr(), n_bytes);
    }
    else
    {
        return false;
    }
}

/// Returns true when `reduce_constant` is an op::Constant whose elements all equal
/// the value given by the string "0"; false otherwise, including for non-Constant nodes.
bool is_zero(std::shared_ptr<Node> reduce_constant)
{
    const std::string zero_text("0");
    return is_equal_to_const_value(zero_text, reduce_constant);
}

/// Pattern predicate that recognises an op::Reduce which is really a plain sum.
/// It accepts `gn` only when all of the following hold:
///  - `gn` is an op::Reduce,
///  - its initial value (input 1) is an op::Constant of all zeros (see is_zero), and
///  - its reduction function's result is an op::Add of two distinct op::Parameter nodes.
/// Returns false for anything else.
bool sum_predicate(std::shared_ptr<Node> gn)
{
    NGRAPH_DEBUG << "pred_v2 : looking at " << gn->get_name();
    auto r = std::dynamic_pointer_cast<op::Reduce>(gn);
    if (!r)
    {
        return false;
    }

    if (!is_zero(gn->get_input_op(1)))
    {
        return false;
    }

    auto fn_result = r->get_functions()[0]->get_result();
    NGRAPH_DEBUG << "looking at function's result  " << fn_result->get_name();
    auto sum = std::dynamic_pointer_cast<op::Add>(fn_result);
    if (!sum)
    {
        return false;
    }

    auto parm1 = std::dynamic_pointer_cast<op::Parameter>(sum->get_input_op(0));
    auto parm2 = std::dynamic_pointer_cast<op::Parameter>(sum->get_input_op(1));

    const auto parm_or_nil = [](std::shared_ptr<Node> p) {
        return p ? p->get_name() : std::string("(nil)");
    };
    NGRAPH_DEBUG << "parm1 = " << parm_or_nil(parm1) << " , parm2 = " << parm_or_nil(parm2);

    // Both operands must be parameters, and distinct (Add(p, p) is not a sum of two inputs).
    return parm1 && parm2 && parm1 != parm2;
}

std::shared_ptr<pattern::op::Label> construct_sum_pattern() //for the sake of explicitness
{
    return std::make_shared<pattern::op::Label>(element::i32, Shape{}, sum_predicate);
124 125 126 127 128 129 130 131 132
}

class TestGraphRewrite : public ngraph::pass::GraphRewrite
{
public:
    void construct_multiply_by_one()
    {
        //pattern #1 : a * 1 = a
        auto iconst1 = construct_constant_node(1);
133
        auto pattern = std::make_shared<pattern::op::Label>(iconst1);
134 135

        ngraph::pattern::gr_callback_fn callback = [pattern](pattern::Matcher& m) {
136 137
            NGRAPH_DEBUG << "In a callback for construct_multiply_by_one against "
                         << m.match_root()->get_name();
138
            assert(m.match_root()->get_input_ops().size() == 2);
139

140 141
            auto pattern_map = m.get_pattern_map();

142
            size_t const_node_index = m.match_root()->get_input_ops().at(0) == pattern_map[pattern];
143
            auto const_node = dynamic_pointer_cast<op::Constant>(
144 145
                m.match_root()->get_input_ops().at(const_node_index));
            auto second_node = m.match_root()->get_input_ops().at(const_node_index);
146 147
            NGRAPH_DEBUG << "second_node = " << second_node->get_name()
                         << " , pattern = " << pattern_map[pattern]->get_name();
148

149
            std::shared_ptr<ngraph::Node> nn = nullptr;
150 151
            if (pattern_map[pattern]->get_element_type() != const_node->get_element_type() ||
                pattern_map[pattern]->get_shape() != const_node->get_shape())
152
            {
153
                NGRAPH_DEBUG << "Operands' types and/or shape don't match";
154
                return nn;
155 156
            }

157
            auto const_values = const_node->get_vector<int32_t>();
158 159 160 161 162
            bool all_ones =
                std::all_of(begin(const_values), end(const_values), [](int e) { return e == 1; });

            if (!all_ones)
            {
163
                NGRAPH_DEBUG << "Constant vector's values aren't equal to 1";
164
                return nn;
165
            }
166
            return pattern_map[pattern];
167 168 169 170 171 172 173 174 175 176
        };

        auto m = make_shared<TestMatcher>(pattern * iconst1, callback);
        this->add_matcher(m);
    }

    void construct_add_zero()
    {
        //pattern #2 : a + 0 = a
        auto iconst0 = construct_constant_node(0);
177
        auto pattern = std::make_shared<pattern::op::Label>(iconst0);
178

179
        auto callback = [pattern](pattern::Matcher& m) {
180 181
            NGRAPH_DEBUG << "In a callback for construct_add_zero against "
                         << m.match_root()->get_name();
182
            assert(m.match_root()->get_input_ops().size() == 2);
183

184 185
            auto pattern_map = m.get_pattern_map();

186
            size_t const_node_index = m.match_root()->get_input_ops().at(0) == pattern_map[pattern];
187
            auto const_node = dynamic_pointer_cast<op::Constant>(
188 189
                m.match_root()->get_input_ops().at(const_node_index));
            auto second_node = m.match_root()->get_input_ops().at(const_node_index);
190 191
            NGRAPH_DEBUG << "second_node = " << second_node->get_name()
                         << " , pattern = " << pattern_map[pattern]->get_name();
192

193 194 195
            //ASSERT_NE(nullptr, const_node);

            std::shared_ptr<ngraph::Node> nn = nullptr;
196 197
            if (pattern_map[pattern]->get_element_type() != const_node->get_element_type() ||
                pattern_map[pattern]->get_shape() != const_node->get_shape())
198
            {
199
                NGRAPH_DEBUG << "Operands' types and/or shape don't match";
200
                return nn;
201 202
            }

203
            auto const_values = const_node->get_vector<int>();
204 205 206 207 208
            bool all_zeros =
                std::all_of(begin(const_values), end(const_values), [](int e) { return e == 0; });

            if (!all_zeros)
            {
209
                NGRAPH_DEBUG << "Constant vector's values aren't equal to 0";
210
                return nn;
211 212
            }

213
            return pattern_map[pattern];
214 215 216 217 218 219
        };

        auto m = make_shared<TestMatcher>(pattern + iconst0, callback);
        this->add_matcher(m);
    }

220 221 222 223 224 225 226 227 228 229
    void construct_sum()
    {
        auto sum_pattern = construct_sum_pattern();

        ngraph::pattern::gr_callback_fn callback = [](pattern::Matcher& m) {
            NGRAPH_DEBUG << "In a callback for construct_sum_pattern against "
                         << m.match_root()->get_name();
            auto reduce = std::dynamic_pointer_cast<op::Reduce>(m.match_root());
            auto reducee = reduce->get_inputs().at(0).get_output().get_node();
            NGRAPH_DEBUG << "reducee = " << reducee->get_name();
230 231 232
            auto sum =
                std::shared_ptr<ngraph::Node>(new op::Sum(reducee, reduce->get_reduction_axes()));
            return sum;
233 234 235 236 237 238
        };

        auto m = make_shared<TestMatcher>(sum_pattern, callback);
        this->add_matcher(m);
    }

239 240 241 242 243
    TestGraphRewrite()
        : GraphRewrite()
    {
        construct_multiply_by_one();
        construct_add_zero();
244
        construct_sum();
245 246 247 248 249 250 251
    }
};

/// Wraps `graph` in a Function whose parameters are `parms`, then runs every pass
/// registered on `pass_manager` over that Function.
static void run_passes(pass::Manager& pass_manager,
                       shared_ptr<Node> graph,
                       std::vector<shared_ptr<op::Parameter>> parms)
{
    op::Parameters fn_params{parms};
    auto fn = make_shared<Function>(graph, fn_params);
    pass_manager.run_passes(fn);
}

// Runs TestGraphRewrite over small graphs. In each scope it checks that identity
// operations (x + 0, x * 1) were bypassed, so the consumer now reads x directly, and
// that a Reduce-based sum was replaced by op::Sum.
TEST(pattern, graph_rewrite)
{
    auto shape = Shape{};
    pass::Manager pass_manager;
    pass_manager.register_pass<TestGraphRewrite>();

    // Function with several results: a + 0 and b + 0 are replaced by a and b.
    {
        auto a = make_shared<op::Parameter>(element::i32, shape);
        auto b = make_shared<op::Parameter>(element::i32, shape);
        auto c = make_shared<op::Parameter>(element::i32, shape);
        auto iconst0 = construct_constant_node(0);
        auto graph_a = a + iconst0;
        auto graph_b = b + iconst0;

        auto f = std::make_shared<Function>(ngraph::Nodes{a, b, graph_a, c, graph_b},
                                            op::Parameters{a, b, c});
        pass_manager.run_passes(f);

        ASSERT_TRUE(graph_a->get_output_inputs(0).empty());
        ASSERT_TRUE(graph_b->get_output_inputs(0).empty());

        auto expected = ngraph::Nodes{a, b, a, c, b};
        ASSERT_TRUE(f->get_results() == expected);
    }

    // b + (a + 0) -> b + a
    {
        auto a = make_shared<op::Parameter>(element::i32, shape);
        auto b = make_shared<op::Parameter>(element::i32, shape);
        auto iconst0 = construct_constant_node(0);
        auto sum = (a + iconst0);
        auto graph = b + sum;
        run_passes(pass_manager, graph, {a, b});
        ASSERT_EQ(graph->get_input_ops().at(1), a);
        // graph's input points to a's output
        ASSERT_EQ(&graph->get_inputs().at(1).get_output(), &a->get_outputs().at(0));
        // graph's input is removed from sum's output.get_inputs()
        ASSERT_TRUE(sum->get_output_inputs(0).empty());
        // a's output feeds into graph's input
        ASSERT_TRUE(a->get_outputs().at(0).get_inputs().count(&graph->get_inputs().at(1)));
    }

    // b + (a * 1) -> b + a
    {
        auto a = make_shared<op::Parameter>(element::i32, shape);
        auto b = make_shared<op::Parameter>(element::i32, shape);
        auto iconst1 = construct_constant_node(1);
        auto mul = (a * iconst1);
        auto graph = b + mul;
        run_passes(pass_manager, graph, {a, b});
        ASSERT_EQ(graph->get_input_ops().at(1), a);
        // graph's input points to a's output
        ASSERT_EQ(&graph->get_inputs().at(1).get_output(), &a->get_outputs().at(0));
        // graph's input is removed from mul's output.get_inputs()
        ASSERT_TRUE(mul->get_outputs().at(0).get_inputs().empty());
        // a's output feeds into graph's input
        ASSERT_TRUE(a->get_outputs().at(0).get_inputs().count(&graph->get_inputs().at(1)));
    }

    // Chain of multiplications by 1 on the left operand collapses to a.
    {
        auto a = make_shared<op::Parameter>(element::i32, shape);
        auto b = make_shared<op::Parameter>(element::i32, shape);
        auto iconst1 = construct_constant_node(1);
        auto graph = ((((a * iconst1) * iconst1) * iconst1) * iconst1) + b;
        run_passes(pass_manager, graph, {a, b});
        ASSERT_EQ(graph->get_input_ops().at(0), a);
        // graph's input points to a's output
        ASSERT_EQ(&graph->get_inputs().at(0).get_output(), &a->get_outputs().at(0));
        // a's output feeds into graph's input
        ASSERT_TRUE(a->get_outputs().at(0).get_inputs().count(&graph->get_inputs().at(0)));
    }

    // Mixed + 0 and * 1 nesting collapses to a.
    {
        auto a = make_shared<op::Parameter>(element::i32, shape);
        auto b = make_shared<op::Parameter>(element::i32, shape);
        auto iconst0 = construct_constant_node(0);
        auto iconst1 = construct_constant_node(1);
        auto graph = b + (iconst0 + ((a + iconst0) * iconst1));
        run_passes(pass_manager, graph, {a, b});
        ASSERT_EQ(graph->get_input_ops().at(1), a);
        // graph's input points to a's output
        ASSERT_EQ(&graph->get_inputs().at(1).get_output(), &a->get_outputs().at(0));
        // a's output feeds into graph's input
        ASSERT_TRUE(a->get_outputs().at(0).get_inputs().count(&graph->get_inputs().at(1)));
    }

    // Chain of multiplications by 1 with the constant on the left collapses to a.
    {
        auto a = make_shared<op::Parameter>(element::i32, shape);
        auto b = make_shared<op::Parameter>(element::i32, shape);
        auto iconst1 = construct_constant_node(1);
        auto graph = b + (iconst1 * (iconst1 * (iconst1 * (iconst1 * a))));
        run_passes(pass_manager, graph, {a, b});
        ASSERT_EQ(graph->get_input_ops().at(1), a);
        // graph's input points to a's output
        ASSERT_EQ(&graph->get_inputs().at(1).get_output(), &a->get_outputs().at(0));
        // a's output feeds into graph's input
        ASSERT_TRUE(a->get_outputs().at(0).get_inputs().count(&graph->get_inputs().at(1)));
    }

    //Sum rewrite
    {
        auto parm = make_shared<op::Parameter>(element::i32, Shape{2, 2});
        auto axes = AxisSet{0, 1};
        auto sum_graph = xla_sum(parm, axes);
        auto innermost_abs = make_shared<op::Abs>(sum_graph);

        auto nested_sum_graph = make_shared<op::Abs>(
            make_shared<op::Abs>(make_shared<op::Abs>(make_shared<op::Abs>(innermost_abs))));

        run_passes(pass_manager, nested_sum_graph, {parm});
        auto sum = std::dynamic_pointer_cast<op::Sum>(innermost_abs->get_input_op(0));
        ASSERT_TRUE(sum);
        ASSERT_EQ(sum->get_reduction_axes(), axes);
        ASSERT_EQ(sum->get_input_op(0), parm);
    }
}

// Exercises pattern::Matcher directly: exact nodes, Any, Label (with and without
// predicates), commutative permutations, constants, subgraph labels and label
// correlations.
TEST(pattern, matcher)
{
    auto shape = Shape{};
    auto a = make_shared<op::Parameter>(element::i32, shape);
    TestMatcher n(nullptr);
    ASSERT_TRUE(n.match(a, a));

    auto abs = make_shared<op::Abs>(a);
    auto any = std::make_shared<pattern::op::Any>(a);
    ASSERT_TRUE(n.match(any, abs));

    auto any_false =
        std::make_shared<pattern::op::Any>(a, [](std::shared_ptr<Node> no) { return false; });
    ASSERT_TRUE(n.match(any_false, a));

    auto pattern = std::make_shared<pattern::op::Label>(a);
    ASSERT_TRUE(n.match(pattern, a));
    ASSERT_EQ(n.get_pattern_map()[pattern], a);

    auto pattern_false =
        std::make_shared<pattern::op::Label>(a, [](std::shared_ptr<Node> no) { return false; });
    ASSERT_FALSE(n.match(pattern_false, a));

    auto b = make_shared<op::Parameter>(element::i32, shape);
    auto d = make_shared<op::Parameter>(element::i32, shape);
    ASSERT_FALSE(n.match(d, b));

    ASSERT_FALSE(n.match(abs + b, b + b));
    ASSERT_TRUE(n.match(any + b, abs + b));

    ASSERT_TRUE(n.match(pattern + b, abs + b));
    ASSERT_EQ(n.get_pattern_map()[pattern], abs);

    ASSERT_TRUE(n.match(b + pattern, abs + b));
    ASSERT_EQ(n.get_pattern_map()[pattern], abs);

    auto c = make_shared<op::Parameter>(element::i32, shape);
    ASSERT_TRUE(n.match(c * (b + pattern), c * (abs + b)));
    ASSERT_EQ(n.get_pattern_map()[pattern], abs);

    ASSERT_TRUE(n.match(c * (any + b), c * (abs + b)));     //nested any
    ASSERT_TRUE(n.match(c * (any + b), (b + abs) * c));     //permutations w/ any
    ASSERT_TRUE(n.match(c * (any_false + b), c * (a + b))); //nested any_false
    ASSERT_TRUE(n.match(c * (any_false + b), (b + a) * c)); //permutations w/ any_false

    auto iconst1_0 = construct_constant_node(1);
    auto iconst1_1 = construct_constant_node(1);
    ASSERT_TRUE(n.match(pattern * iconst1_0, a * iconst1_1)); //different iconst
    ASSERT_EQ(n.get_pattern_map()[pattern], a);
    auto fconst1_0 = op::Constant::create(element::f32, shape, {1});
    auto patternf = std::make_shared<pattern::op::Label>(fconst1_0);
    //f32 label/constant in the pattern vs i32 operands in the graph still matches
    ASSERT_TRUE(n.match(patternf * fconst1_0, a * iconst1_1));

    //Subgraph labels
    auto add = a + b;
    auto label = std::make_shared<pattern::op::Label>(add, nullptr, Nodes{add});
    ASSERT_TRUE(n.match(label, add));
    ASSERT_EQ(n.get_pattern_map()[label], add);

    ASSERT_FALSE(n.match(label, a - b));

    ASSERT_TRUE(n.match(make_shared<op::Abs>(label), make_shared<op::Abs>(add)));
    ASSERT_EQ(n.get_pattern_map()[label], add);

    //Correlations
    auto label1 = std::make_shared<pattern::op::Label>(a);
    auto tmp = label1 + b;
    auto label2 = std::make_shared<pattern::op::Label>(tmp, nullptr, Nodes{tmp});
    auto sub_label1 = label1 - label2;
    ASSERT_TRUE(n.match(sub_label1, a - add));
    ASSERT_EQ(n.get_pattern_map()[label1], a);
    ASSERT_EQ(n.get_pattern_map()[label2], add);

    ASSERT_FALSE(n.match(sub_label1, add - a));

    auto add_label1 = label1 + label2;
    ASSERT_TRUE(n.match(add_label1, add + a));
    ASSERT_EQ(n.get_pattern_map()[label1], a);
    ASSERT_EQ(n.get_pattern_map()[label2], add);
}
453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474

TEST(pattern, sum)
{
    //Sum
    TestMatcher n(nullptr);
    auto reducee_const = std::make_shared<op::Constant>(
        element::i32, Shape{2, 2}, std::vector<std::string>({"0", "0", "0", "0"}));
    auto sum_graph = xla_sum(reducee_const, AxisSet{0, 1});

    auto reduce_label = construct_sum_pattern();
    ASSERT_TRUE(n.match(reduce_label, sum_graph));
    ASSERT_EQ(n.get_pattern_map()[reduce_label], sum_graph);

    auto nested_sum_graph = make_shared<op::Abs>(make_shared<op::Abs>(
        make_shared<op::Abs>(make_shared<op::Abs>(make_shared<op::Abs>(sum_graph)))));

    auto nested_reduce_label = make_shared<op::Abs>(make_shared<op::Abs>(
        make_shared<op::Abs>(make_shared<op::Abs>(make_shared<op::Abs>(reduce_label)))));

    ASSERT_TRUE(n.match(nested_reduce_label, nested_sum_graph));
    ASSERT_EQ(n.get_pattern_map()[reduce_label], sum_graph);
}