Commit 76047c77 authored by Robert Kimball's avatar Robert Kimball Committed by Scott Cyphers

Use less complex pass base where possible (#829)

parent 81c0ef79
......@@ -27,12 +27,10 @@ namespace ngraph
namespace pass
{
template <typename LT>
class AssignLayout : public CallGraphPass
class AssignLayout : public NodePass
{
public:
virtual bool run_on_call_graph(const std::list<std::shared_ptr<Node>>& nodes) override
{
for (const std::shared_ptr<Node>& node : nodes)
virtual bool run_on_node(std::shared_ptr<Node> node) override
{
try
{
......@@ -53,7 +51,6 @@ namespace ngraph
ss << e.what();
throw std::invalid_argument(ss.str());
}
}
return false;
}
};
......
......@@ -18,25 +18,15 @@
#include "ngraph/node.hpp"
#include "ngraph/placement.hpp"
using namespace std;
using namespace ngraph;
using namespace std;
ngraph::pass::AssignPlacement::AssignPlacement(
std::function<Placement(std::shared_ptr<Node>)> placement_policy)
pass::AssignPlacement::AssignPlacement(function<Placement(shared_ptr<Node>)> placement_policy)
: m_placement_policy(placement_policy)
{
}
bool ngraph::pass::AssignPlacement::run_on_call_graph(const std::list<std::shared_ptr<Node>>& nodes)
{
for (const std::shared_ptr<Node>& node : nodes)
{
run_on_node(node);
}
return false;
}
bool ngraph::pass::AssignPlacement::run_on_node(shared_ptr<Node> node)
bool pass::AssignPlacement::run_on_node(shared_ptr<Node> node)
{
node->set_placement(m_placement_policy(node));
return false;
......
......@@ -27,15 +27,14 @@ namespace ngraph
{
namespace pass
{
class AssignPlacement : public CallGraphPass
class AssignPlacement : public NodePass
{
public:
// TODO: make policy a class
AssignPlacement(std::function<Placement(std::shared_ptr<Node>)> placement_policy);
virtual bool run_on_call_graph(const std::list<std::shared_ptr<Node>>& nodes) override;
private:
bool run_on_node(std::shared_ptr<Node> node);
bool run_on_node(std::shared_ptr<Node> node) override;
std::function<Placement(std::shared_ptr<Node>)> m_placement_policy;
};
}
......
......@@ -30,23 +30,20 @@
#include "ngraph/op/sum.hpp"
using namespace ngraph;
using namespace std;
bool ngraph::pass::GetOutputElementElimination::run_on_function(std::shared_ptr<ngraph::Function> f)
bool pass::GetOutputElementElimination::run_on_node(shared_ptr<Node> n)
{
bool optimized = false;
for (auto n : f->get_ordered_ops())
{
for (auto& input : n->get_inputs())
{
if (auto goe =
std::dynamic_pointer_cast<op::GetOutputElement>(input.get_output().get_node()))
if (auto goe = dynamic_pointer_cast<op::GetOutputElement>(input.get_output().get_node()))
{
auto multi = goe->get_inputs().at(0).get_output().get_node();
input.replace_output(goe->get_inputs().at(goe->get_n()).get_output());
//we don't need to fix anything w.r.t GetOutputElement as it will become unreachable
// we don't need to fix anything w.r.t GetOutputElement as it will become unreachable
optimized = true;
}
}
}
return optimized;
}
......@@ -26,13 +26,8 @@ namespace ngraph
}
}
class ngraph::pass::GetOutputElementElimination : public FunctionPass
class ngraph::pass::GetOutputElementElimination : public NodePass
{
public:
GetOutputElementElimination()
: FunctionPass()
{
}
virtual bool run_on_function(std::shared_ptr<ngraph::Function> f);
bool run_on_node(std::shared_ptr<Node> node) override;
};
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment