Commit aacbb305 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Robert Kimball

remove functions from handlers in nop_elimination (#1007)

parent fa221c5f
...@@ -25,20 +25,19 @@ ...@@ -25,20 +25,19 @@
#include "ngraph/op/pad.hpp" #include "ngraph/op/pad.hpp"
#include "ngraph/op/slice.hpp" #include "ngraph/op/slice.hpp"
#include "ngraph/op/sum.hpp" #include "ngraph/op/sum.hpp"
#include "ngraph/util.hpp"
#include "nop_elimination.hpp" #include "nop_elimination.hpp"
#define TI(x) std::type_index(typeid(x)) #define TI(x) std::type_index(typeid(x))
#define HANDLER_DECL(x) \ #define HANDLER_DECL(x) static bool x(const std::shared_ptr<ngraph::Node>& node)
static bool x(const std::shared_ptr<ngraph::Function>& function, \
const std::shared_ptr<ngraph::Node>& node)
HANDLER_DECL(eliminate_pad) HANDLER_DECL(eliminate_pad)
{ {
auto pad = std::dynamic_pointer_cast<ngraph::op::Pad>(node); auto pad = std::dynamic_pointer_cast<ngraph::op::Pad>(node);
if (pad->get_input_shape(0) == pad->get_output_shape(0)) if (pad->get_input_shape(0) == pad->get_output_shape(0))
{ {
function->replace_node(node, node->get_argument(0)); ngraph::replace_node(node, node->get_argument(0));
return true; return true;
} }
return false; return false;
...@@ -49,7 +48,7 @@ HANDLER_DECL(eliminate_sum) ...@@ -49,7 +48,7 @@ HANDLER_DECL(eliminate_sum)
auto sum = std::dynamic_pointer_cast<ngraph::op::Sum>(node); auto sum = std::dynamic_pointer_cast<ngraph::op::Sum>(node);
if (sum->get_reduction_axes().empty()) if (sum->get_reduction_axes().empty())
{ {
function->replace_node(node, node->get_argument(0)); ngraph::replace_node(node, node->get_argument(0));
return true; return true;
} }
return false; return false;
...@@ -60,7 +59,7 @@ HANDLER_DECL(eliminate_convert) ...@@ -60,7 +59,7 @@ HANDLER_DECL(eliminate_convert)
auto convert = std::dynamic_pointer_cast<ngraph::op::Convert>(node); auto convert = std::dynamic_pointer_cast<ngraph::op::Convert>(node);
if (convert->get_convert_element_type() == convert->get_argument(0)->get_element_type()) if (convert->get_convert_element_type() == convert->get_argument(0)->get_element_type())
{ {
function->replace_node(node, node->get_argument(0)); ngraph::replace_node(node, node->get_argument(0));
return true; return true;
} }
return false; return false;
...@@ -71,7 +70,7 @@ HANDLER_DECL(eliminate_slice) ...@@ -71,7 +70,7 @@ HANDLER_DECL(eliminate_slice)
auto slice = std::dynamic_pointer_cast<ngraph::op::Slice>(node); auto slice = std::dynamic_pointer_cast<ngraph::op::Slice>(node);
if (slice->get_input_shape(0) == slice->get_output_shape(0)) if (slice->get_input_shape(0) == slice->get_output_shape(0))
{ {
function->replace_node(node, node->get_argument(0)); ngraph::replace_node(node, node->get_argument(0));
return true; return true;
} }
return false; return false;
...@@ -82,15 +81,14 @@ HANDLER_DECL(eliminate_broadcast) ...@@ -82,15 +81,14 @@ HANDLER_DECL(eliminate_broadcast)
auto broadcast = std::dynamic_pointer_cast<ngraph::op::Broadcast>(node); auto broadcast = std::dynamic_pointer_cast<ngraph::op::Broadcast>(node);
if (broadcast->get_input_shape(0) == broadcast->get_output_shape(0)) if (broadcast->get_input_shape(0) == broadcast->get_output_shape(0))
{ {
function->replace_node(node, node->get_argument(0)); ngraph::replace_node(node, node->get_argument(0));
return true; return true;
} }
return false; return false;
} }
static const std::unordered_map<std::type_index, static const std::unordered_map<std::type_index,
std::function<bool(const std::shared_ptr<ngraph::Function>&, std::function<bool(const std::shared_ptr<ngraph::Node>&)>>
const std::shared_ptr<ngraph::Node>&)>>
dispatcher{{TI(ngraph::op::Pad), &eliminate_pad}, dispatcher{{TI(ngraph::op::Pad), &eliminate_pad},
{TI(ngraph::op::Sum), &eliminate_sum}, {TI(ngraph::op::Sum), &eliminate_sum},
{TI(ngraph::op::Convert), &eliminate_convert}, {TI(ngraph::op::Convert), &eliminate_convert},
...@@ -108,7 +106,7 @@ bool ngraph::pass::NopElimination::run_on_function(std::shared_ptr<ngraph::Funct ...@@ -108,7 +106,7 @@ bool ngraph::pass::NopElimination::run_on_function(std::shared_ptr<ngraph::Funct
auto handler = dispatcher.find(TI(node)); auto handler = dispatcher.find(TI(node));
if (handler != dispatcher.end()) if (handler != dispatcher.end())
{ {
clobbered = handler->second(function, n) || clobbered; clobbered = handler->second(n) || clobbered;
} }
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment