Commit aacbb305 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Robert Kimball

remove functions from handlers in nop_elimination (#1007)

parent fa221c5f
......@@ -25,20 +25,19 @@
#include "ngraph/op/pad.hpp"
#include "ngraph/op/slice.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/util.hpp"
#include "nop_elimination.hpp"
#define TI(x) std::type_index(typeid(x))
#define HANDLER_DECL(x) \
static bool x(const std::shared_ptr<ngraph::Function>& function, \
const std::shared_ptr<ngraph::Node>& node)
#define HANDLER_DECL(x) static bool x(const std::shared_ptr<ngraph::Node>& node)
HANDLER_DECL(eliminate_pad)
{
auto pad = std::dynamic_pointer_cast<ngraph::op::Pad>(node);
if (pad->get_input_shape(0) == pad->get_output_shape(0))
{
function->replace_node(node, node->get_argument(0));
ngraph::replace_node(node, node->get_argument(0));
return true;
}
return false;
......@@ -49,7 +48,7 @@ HANDLER_DECL(eliminate_sum)
auto sum = std::dynamic_pointer_cast<ngraph::op::Sum>(node);
if (sum->get_reduction_axes().empty())
{
function->replace_node(node, node->get_argument(0));
ngraph::replace_node(node, node->get_argument(0));
return true;
}
return false;
......@@ -60,7 +59,7 @@ HANDLER_DECL(eliminate_convert)
auto convert = std::dynamic_pointer_cast<ngraph::op::Convert>(node);
if (convert->get_convert_element_type() == convert->get_argument(0)->get_element_type())
{
function->replace_node(node, node->get_argument(0));
ngraph::replace_node(node, node->get_argument(0));
return true;
}
return false;
......@@ -71,7 +70,7 @@ HANDLER_DECL(eliminate_slice)
auto slice = std::dynamic_pointer_cast<ngraph::op::Slice>(node);
if (slice->get_input_shape(0) == slice->get_output_shape(0))
{
function->replace_node(node, node->get_argument(0));
ngraph::replace_node(node, node->get_argument(0));
return true;
}
return false;
......@@ -82,15 +81,14 @@ HANDLER_DECL(eliminate_broadcast)
auto broadcast = std::dynamic_pointer_cast<ngraph::op::Broadcast>(node);
if (broadcast->get_input_shape(0) == broadcast->get_output_shape(0))
{
function->replace_node(node, node->get_argument(0));
ngraph::replace_node(node, node->get_argument(0));
return true;
}
return false;
}
static const std::unordered_map<std::type_index,
std::function<bool(const std::shared_ptr<ngraph::Function>&,
const std::shared_ptr<ngraph::Node>&)>>
std::function<bool(const std::shared_ptr<ngraph::Node>&)>>
dispatcher{{TI(ngraph::op::Pad), &eliminate_pad},
{TI(ngraph::op::Sum), &eliminate_sum},
{TI(ngraph::op::Convert), &eliminate_convert},
......@@ -108,7 +106,7 @@ bool ngraph::pass::NopElimination::run_on_function(std::shared_ptr<ngraph::Funct
auto handler = dispatcher.find(TI(node));
if (handler != dispatcher.end())
{
clobbered = handler->second(function, n) || clobbered;
clobbered = handler->second(n) || clobbered;
}
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment