Commit 76047c77 authored by Robert Kimball's avatar Robert Kimball Committed by Scott Cyphers

Use less complex pass base where possible (#829)

parent 81c0ef79
...@@ -27,12 +27,10 @@ namespace ngraph ...@@ -27,12 +27,10 @@ namespace ngraph
namespace pass namespace pass
{ {
template <typename LT> template <typename LT>
class AssignLayout : public CallGraphPass class AssignLayout : public NodePass
{ {
public: public:
virtual bool run_on_call_graph(const std::list<std::shared_ptr<Node>>& nodes) override virtual bool run_on_node(std::shared_ptr<Node> node) override
{
for (const std::shared_ptr<Node>& node : nodes)
{ {
try try
{ {
...@@ -53,7 +51,6 @@ namespace ngraph ...@@ -53,7 +51,6 @@ namespace ngraph
ss << e.what(); ss << e.what();
throw std::invalid_argument(ss.str()); throw std::invalid_argument(ss.str());
} }
}
return false; return false;
} }
}; };
......
...@@ -18,25 +18,15 @@ ...@@ -18,25 +18,15 @@
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/placement.hpp" #include "ngraph/placement.hpp"
using namespace std;
using namespace ngraph; using namespace ngraph;
using namespace std;
ngraph::pass::AssignPlacement::AssignPlacement( pass::AssignPlacement::AssignPlacement(function<Placement(shared_ptr<Node>)> placement_policy)
std::function<Placement(std::shared_ptr<Node>)> placement_policy)
: m_placement_policy(placement_policy) : m_placement_policy(placement_policy)
{ {
} }
bool ngraph::pass::AssignPlacement::run_on_call_graph(const std::list<std::shared_ptr<Node>>& nodes) bool pass::AssignPlacement::run_on_node(shared_ptr<Node> node)
{
for (const std::shared_ptr<Node>& node : nodes)
{
run_on_node(node);
}
return false;
}
bool ngraph::pass::AssignPlacement::run_on_node(shared_ptr<Node> node)
{ {
node->set_placement(m_placement_policy(node)); node->set_placement(m_placement_policy(node));
return false; return false;
......
...@@ -27,15 +27,14 @@ namespace ngraph ...@@ -27,15 +27,14 @@ namespace ngraph
{ {
namespace pass namespace pass
{ {
class AssignPlacement : public CallGraphPass class AssignPlacement : public NodePass
{ {
public: public:
// TODO: make policy a class // TODO: make policy a class
AssignPlacement(std::function<Placement(std::shared_ptr<Node>)> placement_policy); AssignPlacement(std::function<Placement(std::shared_ptr<Node>)> placement_policy);
virtual bool run_on_call_graph(const std::list<std::shared_ptr<Node>>& nodes) override;
private: private:
bool run_on_node(std::shared_ptr<Node> node); bool run_on_node(std::shared_ptr<Node> node) override;
std::function<Placement(std::shared_ptr<Node>)> m_placement_policy; std::function<Placement(std::shared_ptr<Node>)> m_placement_policy;
}; };
} }
......
...@@ -30,23 +30,20 @@ ...@@ -30,23 +30,20 @@
#include "ngraph/op/sum.hpp" #include "ngraph/op/sum.hpp"
using namespace ngraph; using namespace ngraph;
using namespace std;
bool ngraph::pass::GetOutputElementElimination::run_on_function(std::shared_ptr<ngraph::Function> f) bool pass::GetOutputElementElimination::run_on_node(shared_ptr<Node> n)
{ {
bool optimized = false; bool optimized = false;
for (auto n : f->get_ordered_ops())
{
for (auto& input : n->get_inputs()) for (auto& input : n->get_inputs())
{ {
if (auto goe = if (auto goe = dynamic_pointer_cast<op::GetOutputElement>(input.get_output().get_node()))
std::dynamic_pointer_cast<op::GetOutputElement>(input.get_output().get_node()))
{ {
auto multi = goe->get_inputs().at(0).get_output().get_node(); auto multi = goe->get_inputs().at(0).get_output().get_node();
input.replace_output(goe->get_inputs().at(goe->get_n()).get_output()); input.replace_output(goe->get_inputs().at(goe->get_n()).get_output());
//we don't need to fix anything w.r.t GetOutputElement as it will become unreachable // we don't need to fix anything w.r.t GetOutputElement as it will become unreachable
optimized = true; optimized = true;
} }
} }
}
return optimized; return optimized;
} }
...@@ -26,13 +26,8 @@ namespace ngraph ...@@ -26,13 +26,8 @@ namespace ngraph
} }
} }
class ngraph::pass::GetOutputElementElimination : public FunctionPass class ngraph::pass::GetOutputElementElimination : public NodePass
{ {
public: public:
GetOutputElementElimination() bool run_on_node(std::shared_ptr<Node> node) override;
: FunctionPass()
{
}
virtual bool run_on_function(std::shared_ptr<ngraph::Function> f);
}; };
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment