//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#include <algorithm>
#include <cstdio>
#include <iostream>
#include <list>
#include <memory>

#include "gtest/gtest.h"
#include "ngraph/file_util.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/batch_norm.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/exp.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/log.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/negative.hpp"
#include "ngraph/op/product.hpp"
#include "ngraph/op/sqrt.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/pass/algebraic_simplification.hpp"
#include "ngraph/pass/graph_rewrite.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/pass.hpp"
#include "ngraph/pass/visualize_tree.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/label.hpp"
#include "ngraph/pattern/op/skip.hpp"
#include "ngraph/serializer.hpp"
#include "util/matcher.hpp"
#include "util/test_tools.hpp"

using namespace ngraph;
using namespace std;

// Verifies that adding a constant zero (x + 0) is folded away for every
// combination of element type and shape: after the pass no Add remains and
// each result is fed directly by the corresponding bare parameter.
TEST(algebraic_simplification, add_types_shapes)
{
    const Shape shapes[] = {Shape{}, Shape{2, 2}, Shape{3, 3, 3}};
    const element::Type types[] = {element::i32, element::f32, element::f64};
    // Iterate by const reference: Shape is a container type, so `auto` would
    // copy it on every iteration.
    for (const auto& type : types)
    {
        for (const auto& shape : shapes)
        {
            pass::Manager pass_manager;
            pass_manager.register_pass<pass::AlgebraicSimplification>();

            auto a = make_shared<op::Parameter>(type, shape);
            auto b = make_shared<op::Parameter>(type, shape);
            auto c = make_shared<op::Parameter>(type, shape);
            auto iconst0 = ngraph::make_constant_from_string("0", type, shape);
            // (a + 0) + 0 and (b + 0) + 0 should both collapse to the bare parameter.
            auto add_a_0 = a + iconst0;
            auto add_a_0_0 = add_a_0 + iconst0;
            auto add_b_0 = b + iconst0;
            auto add_b_0_0 = add_b_0 + iconst0;

            auto f = std::make_shared<Function>(ngraph::NodeVector{a, b, add_a_0_0, c, add_b_0_0},
                                                ParameterVector{a, b, c});
            pass_manager.run_passes(f);

            ASSERT_EQ(count_ops_of_type<op::Add>(f), 0);
            auto expected = ngraph::NodeVector{a, b, a, c, b};
            auto results = f->get_results();
            // Without this, a shorter result list would silently skip checks.
            ASSERT_EQ(results.size(), expected.size());
            for (size_t i = 0; i < results.size(); i++)
            {
                ASSERT_EQ(expected.at(i), results.at(i)->get_argument(0));
            }
        }
    }
}

// x + broadcast(0) must fold to x when the zero operand is a scalar constant
// broadcast to the full {2, 2} shape: afterwards no Add remains in the graph.
TEST(algebraic_simplification, add_broadcast)
{
    Shape shape{2, 2};
    pass::Manager pass_manager;
    pass_manager.register_pass<pass::AlgebraicSimplification>();

    auto a = make_shared<op::Parameter>(element::i32, shape);
    auto b = make_shared<op::Parameter>(element::i32, shape);
    auto c = make_shared<op::Parameter>(element::i32, shape);

    auto zero_bcast =
        make_shared<op::Broadcast>(ngraph::make_zero(element::i32, Shape{}), shape, AxisSet{0, 1});
    auto a_plus_zero_twice = (a + zero_bcast) + zero_bcast;
    auto b_plus_zero_twice = (b + zero_bcast) + zero_bcast;

    ngraph::NodeVector outputs{a, b, a_plus_zero_twice, c, b_plus_zero_twice};
    auto f = std::make_shared<Function>(outputs, ParameterVector{a, b, c});
    pass_manager.run_passes(f);

    ASSERT_EQ(count_ops_of_type<op::Add>(f), 0);
    const ngraph::NodeVector expected{a, b, a, c, b};
    auto results = f->get_results();
    for (size_t idx = 0; idx != results.size(); ++idx)
    {
        ASSERT_EQ(expected.at(idx), results.at(idx)->get_argument(0));
    }
}

// x * broadcast(0) must fold to the zero broadcast itself: afterwards no
// Multiply remains and results 2 and 4 are fed directly by const_broadcast.
TEST(algebraic_simplification, multiply_broadcast)
{
    Shape shape{2, 2};
    pass::Manager pass_manager;
    pass_manager.register_pass<pass::AlgebraicSimplification>();

    auto a = make_shared<op::Parameter>(element::i32, shape);
    auto b = make_shared<op::Parameter>(element::i32, shape);
    auto c = make_shared<op::Parameter>(element::i32, shape);
    auto iconst0 = ngraph::make_zero(element::i32, Shape{});
    auto const_broadcast = make_shared<op::Broadcast>(iconst0, shape, AxisSet{0, 1});
    auto mul_a_0 = a * const_broadcast;
    auto mul_a_0_0 = mul_a_0 * const_broadcast;
    auto mul_b_0 = b * const_broadcast;
    auto mul_b_0_0 = mul_b_0 * const_broadcast;

    auto f = std::make_shared<Function>(ngraph::NodeVector{a, b, mul_a_0_0, c, mul_b_0_0},
                                        ParameterVector{a, b, c});
    pass_manager.run_passes(f);

    // Count the op this test actually builds; this graph never contains an
    // Add, so counting Add would pass trivially.
    ASSERT_EQ(count_ops_of_type<op::Multiply>(f), 0);
    auto expected = ngraph::NodeVector{a, b, const_broadcast, c, const_broadcast};
    auto results = f->get_results();
    for (size_t i = 0; i < results.size(); i++)
    {
        ASSERT_EQ(expected.at(i), results.at(i)->get_argument(0));
    }
}

// 0 + 0 must simplify to a zero value, and 0 + b (zero on the left-hand side)
// must simplify to b.
TEST(algebraic_simplification, zero_plus_zero_commutativity)
{
    Shape shape{};
    auto type = element::f32;
    pass::Manager pass_manager;
    pass_manager.register_pass<pass::AlgebraicSimplification>();

    auto a = make_shared<op::Parameter>(type, shape);
    auto b = make_shared<op::Parameter>(type, shape);
    auto c = make_shared<op::Parameter>(type, shape);
    auto iconst0 = ngraph::make_constant_from_string("0", type, shape);
    // Only these two sums are wired into the function under test.
    auto zero_plus_zero = iconst0 + iconst0;
    auto zero_plus_b = iconst0 + b;

    auto f = std::make_shared<Function>(ngraph::NodeVector{a, b, zero_plus_zero, c, zero_plus_b},
                                        ParameterVector{a, b, c});
    pass_manager.run_passes(f);

    ASSERT_TRUE(ngraph::is_zero(f->get_results().at(2)->get_argument(0)));
    ASSERT_EQ(f->get_results().at(4)->get_argument(0), b);
}

// 0 * 0 and 1 * 0 must both simplify to a zero value.
TEST(algebraic_simplification, zero_multiply_zero_one)
{
    auto type = element::f32;
    Shape shape{};

    auto a = make_shared<op::Parameter>(type, shape);
    auto b = make_shared<op::Parameter>(type, shape);
    auto c = make_shared<op::Parameter>(type, shape);
    auto zero = ngraph::make_constant_from_string("0", type, shape);
    auto one = ngraph::make_constant_from_string("1", type, shape);
    auto zero_times_zero = zero * zero;
    auto one_times_zero = one * zero;

    ngraph::NodeVector outputs{a, b, zero_times_zero, c, one_times_zero};
    auto f = std::make_shared<Function>(outputs, ParameterVector{a, b, c});

    pass::Manager pass_manager;
    pass_manager.register_pass<pass::AlgebraicSimplification>();
    pass_manager.run_passes(f);

    auto results = f->get_results();
    ASSERT_TRUE(ngraph::is_zero(results.at(2)->get_argument(0)));
    ASSERT_TRUE(ngraph::is_zero(results.at(4)->get_argument(0)));
}

// Negative cases for Add: adding a non-zero constant (2) or a non-constant
// node (abs(a)) must leave every result fed by its original node.
TEST(algebraic_simplification, add_negative_tests)
{
    Shape shape{};
    auto type = element::f32;
    pass::Manager pass_manager;
    pass_manager.register_pass<pass::AlgebraicSimplification>();

    auto a = make_shared<op::Parameter>(type, shape);
    auto b = make_shared<op::Parameter>(type, shape);
    auto c = make_shared<op::Parameter>(type, shape);
    auto abs_a = make_shared<op::Abs>(a);
    auto two = ngraph::make_constant_from_string("2", type, shape);
    auto a_plus_two_twice = (a + two) + two;
    auto b_plus_abs_twice = (b + abs_a) + abs_a;

    // The Function copies this vector, so it still holds the pre-pass nodes
    // and doubles as the expectation below.
    ngraph::NodeVector outputs{a, b, a_plus_two_twice, c, b_plus_abs_twice};
    auto f = std::make_shared<Function>(outputs, ParameterVector{a, b, c});
    pass_manager.run_passes(f);

    auto results = f->get_results();
    for (size_t idx = 0; idx != results.size(); ++idx)
    {
        ASSERT_EQ(outputs.at(idx), results.at(idx)->get_argument(0));
    }
}

// Negative cases for Multiply: multiplying by a non-zero, non-one constant (2)
// or by a non-constant node (abs(a)) must leave every result fed by its
// original node.
TEST(algebraic_simplification, multiply_negative_tests)
{
    Shape shape{};
    auto type = element::f32;
    pass::Manager pass_manager;
    pass_manager.register_pass<pass::AlgebraicSimplification>();

    auto a = make_shared<op::Parameter>(type, shape);
    auto b = make_shared<op::Parameter>(type, shape);
    auto c = make_shared<op::Parameter>(type, shape);
    auto abs_a = make_shared<op::Abs>(a);
    auto two = ngraph::make_constant_from_string("2", type, shape);
    auto a_times_two_twice = (a * two) * two;
    auto b_times_abs_twice = (b * abs_a) * abs_a;

    // The Function copies this vector, so it still holds the pre-pass nodes
    // and doubles as the expectation below.
    ngraph::NodeVector outputs{a, b, a_times_two_twice, c, b_times_abs_twice};
    auto f = std::make_shared<Function>(outputs, ParameterVector{a, b, c});
    pass_manager.run_passes(f);

    auto results = f->get_results();
    for (size_t idx = 0; idx != results.size(); ++idx)
    {
        ASSERT_EQ(outputs.at(idx), results.at(idx)->get_argument(0));
    }
}
// Product over axis 1 of a {3, 5} broadcast of the scalar 2.0 should fold into
// a Broadcast of the single constant 2^5 = 32.
TEST(algebraic_simplification, multiply_prod_vector_one)
{
    auto fconst2 = ngraph::op::Constant::create(element::f64, Shape{}, {2.0});
    auto broadcast = std::make_shared<op::Broadcast>(fconst2, Shape{3, 5}, AxisSet{0, 1});
    auto prod_fconst2 = std::make_shared<op::Product>(broadcast, AxisSet{1});

    pass::Manager pass_manager;
    pass_manager.register_pass<pass::AlgebraicSimplification>();

    auto f = std::make_shared<Function>(ngraph::NodeVector{prod_fconst2}, ParameterVector{});
    pass_manager.run_passes(f);
    auto new_broadcast =
        std::dynamic_pointer_cast<op::Broadcast>(f->get_results().at(0)->get_argument(0));
    ASSERT_TRUE(new_broadcast);
    auto new_const = std::dynamic_pointer_cast<op::Constant>(new_broadcast->get_argument(0));
    // Fail the test cleanly instead of dereferencing a null pointer if the
    // broadcast's input is not a Constant.
    ASSERT_TRUE(new_const);
    auto values = new_const->get_vector<double>();
    ASSERT_EQ(values.size(), size_t{1});
    ASSERT_EQ(values.at(0), 32);
}

// Product over every axis of a {3, 5} broadcast of the scalar 2.0 should fold
// directly into a single Constant 2^15 = 32768.
TEST(algebraic_simplification, multiply_prod_scalar_one)
{
    pass::Manager pass_manager;
    pass_manager.register_pass<pass::AlgebraicSimplification>();

    auto two = ngraph::op::Constant::create(element::f64, Shape{}, {2.0});
    auto two_bcast = std::make_shared<op::Broadcast>(two, Shape{3, 5}, AxisSet{0, 1});
    auto full_product = std::make_shared<op::Product>(two_bcast, AxisSet{0, 1});

    auto f = std::make_shared<Function>(ngraph::NodeVector{full_product}, ParameterVector{});
    pass_manager.run_passes(f);

    auto result_input = f->get_results().at(0)->get_argument(0);
    auto folded = std::dynamic_pointer_cast<op::Constant>(result_input);
    ASSERT_TRUE(folded);
    auto folded_values = folded->get_vector<double>();
    ASSERT_EQ(folded_values.size(), 1);
    ASSERT_EQ(folded_values.at(0), 32768);
}

// A Product whose input broadcasts a non-scalar constant ({2} -> {2, 5}) is
// left alone: the function result must still be fed by the original Product.
TEST(algebraic_simplification, multiply_prod_negative)
{
    auto ones_vec = ngraph::op::Constant::create(element::f64, Shape{2}, {1.0, 1.0});
    auto ones_bcast = std::make_shared<op::Broadcast>(ones_vec, Shape{2, 5}, AxisSet{1});
    auto prod_node = std::make_shared<op::Product>(ones_bcast, AxisSet{0, 1});
    auto f = std::make_shared<Function>(ngraph::NodeVector{prod_node}, ParameterVector{});

    pass::Manager pass_manager;
    pass_manager.register_pass<pass::AlgebraicSimplification>();
    pass_manager.run_passes(f);

    ASSERT_EQ(f->get_results().at(0)->get_argument(0), prod_node);
}

310 311 312 313 314 315 316 317 318
// Sum over every axis of a {3, 5} broadcast of the scalar 1.0 should fold
// directly into a single Constant 3 * 5 = 15.
TEST(algebraic_simplification, multiply_sum_scalar_one)
{
    pass::Manager pass_manager;
    pass_manager.register_pass<pass::AlgebraicSimplification>();

    auto one = ngraph::op::Constant::create(element::f64, Shape{}, {1.0});
    auto one_bcast = std::make_shared<op::Broadcast>(one, Shape{3, 5}, AxisSet{0, 1});
    auto full_sum = std::make_shared<op::Sum>(one_bcast, AxisSet{0, 1});

    auto f = std::make_shared<Function>(ngraph::NodeVector{full_sum}, ParameterVector{});
    pass_manager.run_passes(f);

    auto result_input = f->get_results().at(0)->get_argument(0);
    auto folded = std::dynamic_pointer_cast<op::Constant>(result_input);
    ASSERT_TRUE(folded);
    auto folded_values = folded->get_vector<double>();
    ASSERT_EQ(folded_values.size(), 1);
    ASSERT_EQ(folded_values.at(0), 15);
}

// Sum over axis 1 of a {3, 5} broadcast of the scalar 1.0 should fold into a
// Broadcast of the single constant 5.
TEST(algebraic_simplification, multiply_sum_vector_one)
{
    auto fconst1 = ngraph::op::Constant::create(element::f64, Shape{}, {1.0});
    auto broadcast = std::make_shared<op::Broadcast>(fconst1, Shape{3, 5}, AxisSet{0, 1});
    auto sum_fconst1 = std::make_shared<op::Sum>(broadcast, AxisSet{1});

    pass::Manager pass_manager;
    pass_manager.register_pass<pass::AlgebraicSimplification>();

    auto f = std::make_shared<Function>(ngraph::NodeVector{sum_fconst1}, ParameterVector{});
    pass_manager.run_passes(f);
    auto new_broadcast =
        std::dynamic_pointer_cast<op::Broadcast>(f->get_results().at(0)->get_argument(0));
    ASSERT_TRUE(new_broadcast);
    auto new_const = std::dynamic_pointer_cast<op::Constant>(new_broadcast->get_argument(0));
    // Fail the test cleanly instead of dereferencing a null pointer if the
    // broadcast's input is not a Constant.
    ASSERT_TRUE(new_const);
    auto values = new_const->get_vector<double>();
    ASSERT_EQ(values.size(), size_t{1});
    ASSERT_EQ(values.at(0), 5);
}

// A Sum whose input broadcasts a non-scalar constant ({2} -> {2, 5}) is left
// alone: the function result must still be fed by the original Sum.
TEST(algebraic_simplification, multiply_sum_negative)
{
    auto ones_vec = ngraph::op::Constant::create(element::f64, Shape{2}, {1.0, 1.0});
    auto ones_bcast = std::make_shared<op::Broadcast>(ones_vec, Shape{2, 5}, AxisSet{1});
    auto sum_node = std::make_shared<op::Sum>(ones_bcast, AxisSet{0, 1});
    auto f = std::make_shared<Function>(ngraph::NodeVector{sum_node}, ParameterVector{});

    pass::Manager pass_manager;
    pass_manager.register_pass<pass::AlgebraicSimplification>();
    pass_manager.run_passes(f);

    ASSERT_EQ(f->get_results().at(0)->get_argument(0), sum_node);
}
// Three back-to-back row slices of a {96, 100} tensor, each reshaped to
// {32, 1, 100} and concatenated along axis 1: the pass should rewrite the
// Concat so that the function result is fed by a Reshape.
TEST(algebraic_simplification, concat_reshape_slice)
{
    auto a = make_shared<op::Parameter>(element::f32, Shape{96, 100});
    auto goe = make_shared<op::GetOutputElement>(a, 0);
    auto slice1 = make_shared<op::Slice>(goe, Coordinate{0, 0}, Coordinate{32, 100}, Strides{1, 1});
    auto slice2 =
        make_shared<op::Slice>(goe, Coordinate{32, 0}, Coordinate{64, 100}, Strides{1, 1});
    auto slice3 =
        make_shared<op::Slice>(goe, Coordinate{64, 0}, Coordinate{96, 100}, Strides{1, 1});

    auto reshape1 = make_shared<op::Reshape>(slice1, AxisVector{0, 1}, Shape{32, 1, 100});
    auto reshape2 = make_shared<op::Reshape>(slice2, AxisVector{0, 1}, Shape{32, 1, 100});
    auto reshape3 = make_shared<op::Reshape>(slice3, AxisVector{0, 1}, Shape{32, 1, 100});

    size_t concat_axis = 1;
    auto concat = make_shared<op::Concat>(NodeVector{reshape1, reshape2, reshape3}, concat_axis);

    pass::Manager pass_manager;
    pass_manager.register_pass<pass::AlgebraicSimplification>();

    auto f = std::make_shared<Function>(ngraph::NodeVector{concat}, ParameterVector{a});
    pass_manager.run_passes(f);
    ASSERT_TRUE(std::dynamic_pointer_cast<op::Reshape>(f->get_results().at(0)->get_argument(0)));
}

// Concatenating the back-to-back row slices [0, 32), [32, 64) and [64, 96) of
// one GetOutputElement along axis 0 reassembles the whole tensor, so the
// function result should become that GetOutputElement.
TEST(algebraic_simplification, concat_slice)
{
    auto a = make_shared<op::Parameter>(element::f32, Shape{96, 100});
    auto goe = make_shared<op::GetOutputElement>(a, 0);
    auto slice1 = make_shared<op::Slice>(goe, Coordinate{0, 0}, Coordinate{32, 100}, Strides{1, 1});
    auto slice2 =
        make_shared<op::Slice>(goe, Coordinate{32, 0}, Coordinate{64, 100}, Strides{1, 1});
    auto slice3 =
        make_shared<op::Slice>(goe, Coordinate{64, 0}, Coordinate{96, 100}, Strides{1, 1});

    size_t concat_axis = 0;
    auto concat = make_shared<op::Concat>(NodeVector{slice1, slice2, slice3}, concat_axis);

    pass::Manager pass_manager;
    pass_manager.register_pass<pass::AlgebraicSimplification>();

    auto f = std::make_shared<Function>(ngraph::NodeVector{concat}, ParameterVector{a});
    pass_manager.run_passes(f);
    ASSERT_EQ(f->get_results().at(0)->get_argument(0), goe);
}

// Back-to-back row slices taken directly from a Parameter and concatenated
// in order along axis 0 should collapse to the Parameter itself.
TEST(algebraic_simplification, concat_parameter_slice)
{
    auto a = make_shared<op::Parameter>(element::f32, Shape{96, 100});
    auto slice1 = make_shared<op::Slice>(a, Coordinate{0, 0}, Coordinate{32, 100}, Strides{1, 1});
    auto slice2 = make_shared<op::Slice>(a, Coordinate{32, 0}, Coordinate{64, 100}, Strides{1, 1});
    auto slice3 = make_shared<op::Slice>(a, Coordinate{64, 0}, Coordinate{96, 100}, Strides{1, 1});

    size_t concat_axis = 0;
    auto concat = make_shared<op::Concat>(NodeVector{slice1, slice2, slice3}, concat_axis);

    pass::Manager pass_manager;
    pass_manager.register_pass<pass::AlgebraicSimplification>();

    auto f = std::make_shared<Function>(ngraph::NodeVector{concat}, ParameterVector{a});
    pass_manager.run_passes(f);
    ASSERT_EQ(f->get_results().at(0)->get_argument(0), a);
}

// Slices concatenated in reverse order do not reproduce the input, so the
// Concat must be kept.
TEST(algebraic_simplification, concat_parameter_slices_reversed)
{
    auto a = make_shared<op::Parameter>(element::f32, Shape{96, 100});
    auto slice1 = make_shared<op::Slice>(a, Coordinate{0, 0}, Coordinate{32, 100}, Strides{1, 1});
    auto slice2 = make_shared<op::Slice>(a, Coordinate{32, 0}, Coordinate{64, 100}, Strides{1, 1});
    auto slice3 = make_shared<op::Slice>(a, Coordinate{64, 0}, Coordinate{96, 100}, Strides{1, 1});

    size_t concat_axis = 0;
    auto concat = make_shared<op::Concat>(NodeVector{slice3, slice2, slice1}, concat_axis);

    pass::Manager pass_manager;
    pass_manager.register_pass<pass::AlgebraicSimplification>();

    auto f = std::make_shared<Function>(ngraph::NodeVector{concat}, ParameterVector{a});
    pass_manager.run_passes(f);
    ASSERT_EQ(f->get_results().at(0)->get_argument(0), concat);
}

// The slices cover only rows [0, 30) of the 96-row input, so concatenating
// them does not reassemble the whole tensor and the Concat must be kept.
TEST(algebraic_simplification, concat_parameter_slices_element_count)
{
    auto a = make_shared<op::Parameter>(element::f32, Shape{96, 100});
    // Slicing 30 of 96 rows; the pass must notice that elements are missing.
    auto slice1 = make_shared<op::Slice>(a, Coordinate{0, 0}, Coordinate{10, 100}, Strides{1, 1});
    auto slice2 = make_shared<op::Slice>(a, Coordinate{10, 0}, Coordinate{20, 100}, Strides{1, 1});
    auto slice3 = make_shared<op::Slice>(a, Coordinate{20, 0}, Coordinate{30, 100}, Strides{1, 1});

    size_t concat_axis = 0;
    auto concat = make_shared<op::Concat>(NodeVector{slice1, slice2, slice3}, concat_axis);

    pass::Manager pass_manager;
    pass_manager.register_pass<pass::AlgebraicSimplification>();

    auto f = std::make_shared<Function>(ngraph::NodeVector{concat}, ParameterVector{a});
    pass_manager.run_passes(f);
    ASSERT_EQ(f->get_results().at(0)->get_argument(0), concat);
}

// The slices have unequal row extents (38, 26 and 32 rows). The test expects
// the pass to leave the Concat in place for such non-uniform slices.
TEST(algebraic_simplification, concat_parameter_non_uniform_slices)
{
    auto a = make_shared<op::Parameter>(element::f32, Shape{96, 100});
    auto slice1 = make_shared<op::Slice>(a, Coordinate{0, 0}, Coordinate{38, 100}, Strides{1, 1});
    auto slice2 = make_shared<op::Slice>(a, Coordinate{38, 0}, Coordinate{64, 100}, Strides{1, 1});
    auto slice3 = make_shared<op::Slice>(a, Coordinate{64, 0}, Coordinate{96, 100}, Strides{1, 1});

    size_t concat_axis = 0;
    auto concat = make_shared<op::Concat>(NodeVector{slice1, slice2, slice3}, concat_axis);

    pass::Manager pass_manager;
    pass_manager.register_pass<pass::AlgebraicSimplification>();

    auto f = std::make_shared<Function>(ngraph::NodeVector{concat}, ParameterVector{a});
    pass_manager.run_passes(f);
    ASSERT_EQ(f->get_results().at(0)->get_argument(0), concat);
}

// Slices taken through two distinct GetOutputElement nodes (goe1 and goe2,
// both reading output 0 of the same Parameter) are not merged: the Concat
// must be kept.
TEST(algebraic_simplification, concat_different_goes)
{
    auto a = make_shared<op::Parameter>(element::f32, Shape{96, 100});
    auto goe1 = make_shared<op::GetOutputElement>(a, 0);
    auto goe2 = make_shared<op::GetOutputElement>(a, 0);
    auto slice1 =
        make_shared<op::Slice>(goe1, Coordinate{0, 0}, Coordinate{32, 100}, Strides{1, 1});
    auto slice2 =
        make_shared<op::Slice>(goe2, Coordinate{32, 0}, Coordinate{64, 100}, Strides{1, 1});
    auto slice3 =
        make_shared<op::Slice>(goe1, Coordinate{64, 0}, Coordinate{96, 100}, Strides{1, 1});

    size_t concat_axis = 0;
    auto concat = make_shared<op::Concat>(NodeVector{slice1, slice2, slice3}, concat_axis);

    pass::Manager pass_manager;
    pass_manager.register_pass<pass::AlgebraicSimplification>();

    auto f = std::make_shared<Function>(ngraph::NodeVector{concat}, ParameterVector{a});
    pass_manager.run_passes(f);
    ASSERT_EQ(f->get_results().at(0)->get_argument(0), concat);
}

// log(exp(a) / b) should be rewritten to a - log(b). The rewritten expression
// is inspected through the innermost Negative (neg_inner), whose input the
// pass replaces.
TEST(algebraic_simplification, log_neg_neg)
{
    auto a = make_shared<op::Parameter>(element::f32, Shape{96, 100});
    auto b = make_shared<op::Parameter>(element::f32, Shape{96, 100});
    auto log_of_quotient = make_shared<op::Log>(make_shared<op::Exp>(a) / b);

    // Stack four Negatives on top of the log; neg_inner is the innermost one.
    auto neg_inner = make_shared<op::Negative>(log_of_quotient);
    std::shared_ptr<Node> outermost = neg_inner;
    for (int depth = 1; depth < 4; ++depth)
    {
        outermost = make_shared<op::Negative>(outermost);
    }

    pass::Manager pass_manager;
    pass_manager.register_pass<pass::AlgebraicSimplification>();

    auto f = std::make_shared<Function>(ngraph::NodeVector{outermost}, ParameterVector{a, b});
    pass_manager.run_passes(f);

    auto difference = std::dynamic_pointer_cast<op::Subtract>(neg_inner->get_argument(0));
    ASSERT_TRUE(difference != nullptr);
    ASSERT_EQ(difference->get_argument(0), a);
    auto log_b = std::dynamic_pointer_cast<op::Log>(difference->get_argument(1));
    ASSERT_TRUE(log_b != nullptr);
    ASSERT_EQ(log_b->get_argument(0), b);
}

// Without an Exp in the numerator, log(abs(a) / b) must be left unchanged:
// the innermost Negative still reads the original Log.
TEST(algebraic_simplification, log_no_exp)
{
    auto a = make_shared<op::Parameter>(element::f32, Shape{96, 100});
    auto b = make_shared<op::Parameter>(element::f32, Shape{96, 100});
    auto log_div = make_shared<op::Log>(make_shared<op::Abs>(a) / b);

    // Stack four Negatives on top of the log; neg_inner is the innermost one.
    auto neg_inner = make_shared<op::Negative>(log_div);
    std::shared_ptr<Node> outermost = neg_inner;
    for (int depth = 1; depth < 4; ++depth)
    {
        outermost = make_shared<op::Negative>(outermost);
    }

    pass::Manager pass_manager;
    pass_manager.register_pass<pass::AlgebraicSimplification>();

    auto f = std::make_shared<Function>(ngraph::NodeVector{outermost}, ParameterVector{a, b});
    pass_manager.run_passes(f);
    ASSERT_EQ(neg_inner->get_argument(0), log_div);
}

// With a Multiply instead of a Divide, log(exp(a) * b) must be left
// unchanged: the innermost Negative still reads the original Log.
TEST(algebraic_simplification, log_no_divide)
{
    auto a = make_shared<op::Parameter>(element::f32, Shape{96, 100});
    auto b = make_shared<op::Parameter>(element::f32, Shape{96, 100});
    auto log_mul = make_shared<op::Log>(make_shared<op::Exp>(a) * b);

    // Stack four Negatives on top of the log; neg_inner is the innermost one.
    auto neg_inner = make_shared<op::Negative>(log_mul);
    std::shared_ptr<Node> outermost = neg_inner;
    for (int depth = 1; depth < 4; ++depth)
    {
        outermost = make_shared<op::Negative>(outermost);
    }

    pass::Manager pass_manager;
    pass_manager.register_pass<pass::AlgebraicSimplification>();

    auto f = std::make_shared<Function>(ngraph::NodeVector{outermost}, ParameterVector{a, b});
    pass_manager.run_passes(f);
    ASSERT_EQ(neg_inner->get_argument(0), log_mul);
}

// AlgebraicSimplification must report REQUIRE_STATIC_SHAPE as set and
// CHANGE_DYNAMIC_STATE as unset.
TEST(algebraic_simplification, pass_property)
{
    // Named so it does not read like the `pass` namespace used below.
    auto simplification = std::make_shared<ngraph::pass::AlgebraicSimplification>();

    // Boolean expectations are stated directly rather than via
    // ASSERT_EQ(true/false, ...), which some googletest versions warn about.
    ASSERT_TRUE(simplification->get_property(pass::PassProperty::REQUIRE_STATIC_SHAPE));
    ASSERT_FALSE(simplification->get_property(pass::PassProperty::CHANGE_DYNAMIC_STATE));
}