Commit 76047c77 authored by Robert Kimball's avatar Robert Kimball Committed by Scott Cyphers

Use less complex pass base where possible (#829)

parent 81c0ef79
...@@ -27,32 +27,29 @@ namespace ngraph ...@@ -27,32 +27,29 @@ namespace ngraph
namespace pass namespace pass
{ {
template <typename LT> template <typename LT>
class AssignLayout : public CallGraphPass class AssignLayout : public NodePass
{ {
public: public:
virtual bool run_on_call_graph(const std::list<std::shared_ptr<Node>>& nodes) override virtual bool run_on_node(std::shared_ptr<Node> node) override
{ {
for (const std::shared_ptr<Node>& node : nodes) try
{ {
try for (size_t i = 0; i < node->get_output_size(); ++i)
{ {
for (size_t i = 0; i < node->get_output_size(); ++i) auto tv = node->get_output_tensor_view(i);
if (nullptr == tv->get_tensor_view_layout())
{ {
auto tv = node->get_output_tensor_view(i); auto layout = std::make_shared<LT>(*tv);
if (nullptr == tv->get_tensor_view_layout()) tv->set_tensor_view_layout(layout);
{
auto layout = std::make_shared<LT>(*tv);
tv->set_tensor_view_layout(layout);
}
} }
} }
catch (const std::exception& e) }
{ catch (const std::exception& e)
std::stringstream ss; {
ss << "Error with node " << *node << ": "; std::stringstream ss;
ss << e.what(); ss << "Error with node " << *node << ": ";
throw std::invalid_argument(ss.str()); ss << e.what();
} throw std::invalid_argument(ss.str());
} }
return false; return false;
} }
......
...@@ -18,25 +18,15 @@ ...@@ -18,25 +18,15 @@
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/placement.hpp" #include "ngraph/placement.hpp"
using namespace std;
using namespace ngraph; using namespace ngraph;
using namespace std;
ngraph::pass::AssignPlacement::AssignPlacement( pass::AssignPlacement::AssignPlacement(function<Placement(shared_ptr<Node>)> placement_policy)
std::function<Placement(std::shared_ptr<Node>)> placement_policy)
: m_placement_policy(placement_policy) : m_placement_policy(placement_policy)
{ {
} }
bool ngraph::pass::AssignPlacement::run_on_call_graph(const std::list<std::shared_ptr<Node>>& nodes) bool pass::AssignPlacement::run_on_node(shared_ptr<Node> node)
{
for (const std::shared_ptr<Node>& node : nodes)
{
run_on_node(node);
}
return false;
}
bool ngraph::pass::AssignPlacement::run_on_node(shared_ptr<Node> node)
{ {
node->set_placement(m_placement_policy(node)); node->set_placement(m_placement_policy(node));
return false; return false;
......
...@@ -27,15 +27,14 @@ namespace ngraph ...@@ -27,15 +27,14 @@ namespace ngraph
{ {
namespace pass namespace pass
{ {
class AssignPlacement : public CallGraphPass class AssignPlacement : public NodePass
{ {
public: public:
// TODO: make policy a class // TODO: make policy a class
AssignPlacement(std::function<Placement(std::shared_ptr<Node>)> placement_policy); AssignPlacement(std::function<Placement(std::shared_ptr<Node>)> placement_policy);
virtual bool run_on_call_graph(const std::list<std::shared_ptr<Node>>& nodes) override;
private: private:
bool run_on_node(std::shared_ptr<Node> node); bool run_on_node(std::shared_ptr<Node> node) override;
std::function<Placement(std::shared_ptr<Node>)> m_placement_policy; std::function<Placement(std::shared_ptr<Node>)> m_placement_policy;
}; };
} }
......
...@@ -30,22 +30,19 @@ ...@@ -30,22 +30,19 @@
#include "ngraph/op/sum.hpp" #include "ngraph/op/sum.hpp"
using namespace ngraph; using namespace ngraph;
using namespace std;
bool ngraph::pass::GetOutputElementElimination::run_on_function(std::shared_ptr<ngraph::Function> f) bool pass::GetOutputElementElimination::run_on_node(shared_ptr<Node> n)
{ {
bool optimized = false; bool optimized = false;
for (auto n : f->get_ordered_ops()) for (auto& input : n->get_inputs())
{ {
for (auto& input : n->get_inputs()) if (auto goe = dynamic_pointer_cast<op::GetOutputElement>(input.get_output().get_node()))
{ {
if (auto goe = auto multi = goe->get_inputs().at(0).get_output().get_node();
std::dynamic_pointer_cast<op::GetOutputElement>(input.get_output().get_node())) input.replace_output(goe->get_inputs().at(goe->get_n()).get_output());
{ // we don't need to fix anything w.r.t GetOutputElement as it will become unreachable
auto multi = goe->get_inputs().at(0).get_output().get_node(); optimized = true;
input.replace_output(goe->get_inputs().at(goe->get_n()).get_output());
//we don't need to fix anything w.r.t GetOutputElement as it will become unreachable
optimized = true;
}
} }
} }
return optimized; return optimized;
......
...@@ -26,13 +26,8 @@ namespace ngraph ...@@ -26,13 +26,8 @@ namespace ngraph
} }
} }
class ngraph::pass::GetOutputElementElimination : public FunctionPass class ngraph::pass::GetOutputElementElimination : public NodePass
{ {
public: public:
GetOutputElementElimination() bool run_on_node(std::shared_ptr<Node> node) override;
: FunctionPass()
{
}
virtual bool run_on_function(std::shared_ptr<ngraph::Function> f);
}; };
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment