Commit ddf3f4f0 authored by Adam Straw's avatar Adam Straw Committed by Scott Cyphers

move nop elimination pass to nGraph and add broadcast elimination (#995)

* move nop elimination pass to nGraph and add broadcast elimination

* fix pad test bug

* remove graph visualizer and clean up includes in nop eliminate test

* code format
parent 76e36f2a
......@@ -117,6 +117,7 @@ set (SRC
pass/manager_state.cpp
pass/memory_layout.cpp
pass/memory_visualize.cpp
pass/nop_elimination.cpp
pass/pass.cpp
pass/reshape_elimination.cpp
pass/result_copy_elimination.cpp
......@@ -222,7 +223,6 @@ if (NGRAPH_CPU_ENABLE AND LLVM_INCLUDE_DIR AND MKLDNN_INCLUDE_DIR)
runtime/cpu/pass/cpu_fusion.cpp
runtime/cpu/pass/cpu_workspace_insertion.cpp
runtime/cpu/pass/cpu_layout.cpp
runtime/cpu/pass/cpu_nop_elimination.cpp
runtime/cpu/pass/cpu_rnn_mat_fusion.cpp
runtime/cpu/pass/cpu_post_layout_optimizations.cpp
runtime/cpu/pass/cpu_shuffle_folding.cpp
......
......@@ -99,7 +99,7 @@ static bool simplify_multiply(std::shared_ptr<Node> n)
return false;
}
//`simplify_multiply` optimizes the following 2 *base* cases
//`simplify_add` optimizes the following 2 *base* cases
//(4 cases in total including variants due to commutativity)
//
//a + 0 -> a
......
......@@ -20,11 +20,12 @@
#include <typeinfo>
#include <unordered_map>
#include "cpu_nop_elimination.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/convert.hpp"
#include "ngraph/op/pad.hpp"
#include "ngraph/op/slice.hpp"
#include "ngraph/op/sum.hpp"
#include "nop_elimination.hpp"
#define TI(x) std::type_index(typeid(x))
......@@ -76,16 +77,27 @@ HANDLER_DECL(eliminate_slice)
return false;
}
HANDLER_DECL(eliminate_broadcast)
{
auto broadcast = std::dynamic_pointer_cast<ngraph::op::Broadcast>(node);
if (broadcast->get_input_shape(0) == broadcast->get_output_shape(0))
{
function->replace_node(node, node->get_argument(0));
return true;
}
return false;
}
static const std::unordered_map<std::type_index,
std::function<bool(const std::shared_ptr<ngraph::Function>&,
const std::shared_ptr<ngraph::Node>&)>>
dispatcher{{TI(ngraph::op::Pad), &eliminate_pad},
{TI(ngraph::op::Sum), &eliminate_sum},
{TI(ngraph::op::Convert), &eliminate_convert},
{TI(ngraph::op::Slice), &eliminate_slice}};
{TI(ngraph::op::Slice), &eliminate_slice},
{TI(ngraph::op::Broadcast), &eliminate_broadcast}};
bool ngraph::runtime::cpu::pass::CPUNopElimination::run_on_function(
std::shared_ptr<ngraph::Function> function)
bool ngraph::pass::NopElimination::run_on_function(std::shared_ptr<ngraph::Function> function)
{
bool clobbered = false;
......
......@@ -20,18 +20,12 @@
namespace ngraph
{
namespace runtime
namespace pass
{
namespace cpu
class NopElimination : public FunctionPass
{
namespace pass
{
class CPUNopElimination : public ngraph::pass::FunctionPass
{
public:
bool run_on_function(std::shared_ptr<ngraph::Function> function) override;
};
}
}
public:
bool run_on_function(std::shared_ptr<ngraph::Function> function) override;
};
}
}
......@@ -107,6 +107,7 @@
#include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/memory_layout.hpp"
#include "ngraph/pass/nop_elimination.hpp"
#include "ngraph/pass/result_copy_elimination.hpp"
#include "ngraph/runtime/cpu/cpu_backend.hpp"
#include "ngraph/runtime/cpu/cpu_call_frame.hpp"
......@@ -125,7 +126,6 @@
#include "ngraph/runtime/cpu/pass/cpu_assignment.hpp"
#include "ngraph/runtime/cpu/pass/cpu_fusion.hpp"
#include "ngraph/runtime/cpu/pass/cpu_layout.hpp"
#include "ngraph/runtime/cpu/pass/cpu_nop_elimination.hpp"
#include "ngraph/runtime/cpu/pass/cpu_post_layout_optimizations.hpp"
#include "ngraph/runtime/cpu/pass/cpu_shuffle_folding.hpp"
#include "ngraph/runtime/cpu/pass/cpu_workspace_insertion.hpp"
......@@ -319,7 +319,7 @@ void runtime::cpu::CPU_ExternalFunction::compile()
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::CPUNopElimination>();
pass_manager.register_pass<ngraph::pass::NopElimination>();
pass_manager.register_pass<ngraph::pass::AlgebraicSimplification>();
pass_manager.register_pass<ngraph::pass::CommonSubexpressionElimination>();
pass_manager.register_pass<ngraph::pass::CoreFusion>();
......
......@@ -38,6 +38,7 @@ set (SRC
main.cpp
op.cpp
graph_partition.cpp
nop_elimination.cpp
pass_liveness.cpp
pass_manager.cpp
pass_memory_layout.cpp
......
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <memory>
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/nop_elimination.hpp"
#include "util/test_tools.hpp"
using namespace ngraph;
using namespace std;
TEST(nop_elimination, eliminate_pad)
{
Shape shape_a{2};
auto A = make_shared<op::Parameter>(element::f32, shape_a);
Shape shape_b{};
auto B = make_shared<op::Parameter>(element::f32, shape_b);
Shape padding_below{0};
Shape padding_above{0};
Shape padding_interior{0};
auto f = make_shared<Function>(
make_shared<op::Pad>(A, B, padding_below, padding_above, padding_interior),
op::ParameterVector{A, B});
pass::Manager pass_manager;
pass_manager.register_pass<pass::NopElimination>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Pad>(f), 0);
}
TEST(nop_elimination, eliminate_sum)
{
Shape shape{2, 2};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto f = make_shared<Function>(make_shared<op::Sum>(A, AxisSet{}), op::ParameterVector{A});
pass::Manager pass_manager;
pass_manager.register_pass<pass::NopElimination>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Sum>(f), 0);
}
TEST(nop_elimination, eliminate_convert)
{
Shape shape{};
auto type = element::f32;
auto A = make_shared<op::Parameter>(type, shape);
auto f =
make_shared<Function>(make_shared<op::Convert>(A, element::f32), op::ParameterVector{A});
pass::Manager pass_manager;
pass_manager.register_pass<pass::NopElimination>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Convert>(f), 0);
}
TEST(nop_elimination, eliminate_slice)
{
Shape shape{2, 2};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto f = make_shared<Function>(make_shared<op::Slice>(A, Coordinate{0, 0}, Coordinate{2, 2}),
op::ParameterVector{A});
pass::Manager pass_manager;
pass_manager.register_pass<pass::NopElimination>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Slice>(f), 0);
}
TEST(nop_elimination, eliminate_broadcast)
{
Shape shape{};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto f = make_shared<Function>(make_shared<op::Broadcast>(A, shape, AxisSet{}),
op::ParameterVector{A});
pass::Manager pass_manager;
pass_manager.register_pass<pass::NopElimination>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Broadcast>(f), 0);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment