Commit 76047c77 authored by Robert Kimball's avatar Robert Kimball Committed by Scott Cyphers

Use less complex pass base where possible (#829)

parent 81c0ef79
......@@ -27,32 +27,29 @@ namespace ngraph
namespace pass
{
template <typename LT>
class AssignLayout : public CallGraphPass
class AssignLayout : public NodePass
{
public:
virtual bool run_on_call_graph(const std::list<std::shared_ptr<Node>>& nodes) override
virtual bool run_on_node(std::shared_ptr<Node> node) override
{
for (const std::shared_ptr<Node>& node : nodes)
try
{
try
for (size_t i = 0; i < node->get_output_size(); ++i)
{
for (size_t i = 0; i < node->get_output_size(); ++i)
auto tv = node->get_output_tensor_view(i);
if (nullptr == tv->get_tensor_view_layout())
{
auto tv = node->get_output_tensor_view(i);
if (nullptr == tv->get_tensor_view_layout())
{
auto layout = std::make_shared<LT>(*tv);
tv->set_tensor_view_layout(layout);
}
auto layout = std::make_shared<LT>(*tv);
tv->set_tensor_view_layout(layout);
}
}
catch (const std::exception& e)
{
std::stringstream ss;
ss << "Error with node " << *node << ": ";
ss << e.what();
throw std::invalid_argument(ss.str());
}
}
catch (const std::exception& e)
{
std::stringstream ss;
ss << "Error with node " << *node << ": ";
ss << e.what();
throw std::invalid_argument(ss.str());
}
return false;
}
......
......@@ -18,25 +18,15 @@
#include "ngraph/node.hpp"
#include "ngraph/placement.hpp"
using namespace std;
using namespace ngraph;
using namespace std;
ngraph::pass::AssignPlacement::AssignPlacement(
std::function<Placement(std::shared_ptr<Node>)> placement_policy)
pass::AssignPlacement::AssignPlacement(function<Placement(shared_ptr<Node>)> placement_policy)
: m_placement_policy(placement_policy)
{
}
bool ngraph::pass::AssignPlacement::run_on_call_graph(const std::list<std::shared_ptr<Node>>& nodes)
{
for (const std::shared_ptr<Node>& node : nodes)
{
run_on_node(node);
}
return false;
}
bool ngraph::pass::AssignPlacement::run_on_node(shared_ptr<Node> node)
bool pass::AssignPlacement::run_on_node(shared_ptr<Node> node)
{
node->set_placement(m_placement_policy(node));
return false;
......
......@@ -27,15 +27,14 @@ namespace ngraph
{
namespace pass
{
class AssignPlacement : public CallGraphPass
class AssignPlacement : public NodePass
{
public:
// TODO: make policy a class
AssignPlacement(std::function<Placement(std::shared_ptr<Node>)> placement_policy);
virtual bool run_on_call_graph(const std::list<std::shared_ptr<Node>>& nodes) override;
private:
bool run_on_node(std::shared_ptr<Node> node);
bool run_on_node(std::shared_ptr<Node> node) override;
std::function<Placement(std::shared_ptr<Node>)> m_placement_policy;
};
}
......
......@@ -30,22 +30,19 @@
#include "ngraph/op/sum.hpp"
using namespace ngraph;
using namespace std;
bool ngraph::pass::GetOutputElementElimination::run_on_function(std::shared_ptr<ngraph::Function> f)
bool pass::GetOutputElementElimination::run_on_node(shared_ptr<Node> n)
{
bool optimized = false;
for (auto n : f->get_ordered_ops())
for (auto& input : n->get_inputs())
{
for (auto& input : n->get_inputs())
if (auto goe = dynamic_pointer_cast<op::GetOutputElement>(input.get_output().get_node()))
{
if (auto goe =
std::dynamic_pointer_cast<op::GetOutputElement>(input.get_output().get_node()))
{
auto multi = goe->get_inputs().at(0).get_output().get_node();
input.replace_output(goe->get_inputs().at(goe->get_n()).get_output());
//we don't need to fix anything w.r.t GetOutputElement as it will become unreachable
optimized = true;
}
auto multi = goe->get_inputs().at(0).get_output().get_node();
input.replace_output(goe->get_inputs().at(goe->get_n()).get_output());
// we don't need to fix anything w.r.t GetOutputElement as it will become unreachable
optimized = true;
}
}
return optimized;
......
......@@ -26,13 +26,8 @@ namespace ngraph
}
}
class ngraph::pass::GetOutputElementElimination : public FunctionPass
class ngraph::pass::GetOutputElementElimination : public NodePass
{
public:
GetOutputElementElimination()
: FunctionPass()
{
}
virtual bool run_on_function(std::shared_ptr<ngraph::Function> f);
bool run_on_node(std::shared_ptr<Node> node) override;
};
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment