Commit 47ca008a authored by Jaikrishnan Menon's avatar Jaikrishnan Menon Committed by Scott Cyphers

CPU: Eliminate trivial sum reductions (#703)

parent 6f547fdb
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
#include "cpu_nop_elimination.hpp" #include "cpu_nop_elimination.hpp"
#include "ngraph/ops/pad.hpp" #include "ngraph/ops/pad.hpp"
#include "ngraph/ops/sum.hpp"
#define TI(x) std::type_index(typeid(x)) #define TI(x) std::type_index(typeid(x))
...@@ -40,10 +41,21 @@ HANDLER_DECL(eliminate_pad) ...@@ -40,10 +41,21 @@ HANDLER_DECL(eliminate_pad)
return false; return false;
} }
HANDLER_DECL(eliminate_sum)
{
auto sum = std::dynamic_pointer_cast<ngraph::op::Sum>(node);
if (sum->get_reduction_axes().empty())
{
function->replace_node(node, node->get_input_op(0));
return true;
}
return false;
}
static const std::unordered_map<std::type_index, static const std::unordered_map<std::type_index,
std::function<bool(const std::shared_ptr<ngraph::Function>&, std::function<bool(const std::shared_ptr<ngraph::Function>&,
const std::shared_ptr<ngraph::Node>&)>> const std::shared_ptr<ngraph::Node>&)>>
dispatcher{{TI(ngraph::op::Pad), &eliminate_pad}}; dispatcher{{TI(ngraph::op::Pad), &eliminate_pad}, {TI(ngraph::op::Sum), &eliminate_sum}};
bool ngraph::runtime::cpu::pass::CPUNopElimination::run_on_function( bool ngraph::runtime::cpu::pass::CPUNopElimination::run_on_function(
std::shared_ptr<ngraph::Function> function) std::shared_ptr<ngraph::Function> function)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment