Commit 3ad5140c authored by Robert Kimball's avatar Robert Kimball Committed by Adam Procter

Bob/add pass to interpreter (#146)

* cleanup pass registration code.

* Replace explicit code to do passes with pass_manager passes

* make external function work with const types internally, it should not change the graph
parent d9acd066
...@@ -65,7 +65,7 @@ void ngraph::pass::Manager::initialize_default_passes() ...@@ -65,7 +65,7 @@ void ngraph::pass::Manager::initialize_default_passes()
{ {
} }
void ngraph::pass::Manager::register_pass(std::shared_ptr<TreeBase> p) void ngraph::pass::Manager::register_pass_ptr(std::shared_ptr<TreeBase> p)
{ {
if (p == nullptr) if (p == nullptr)
{ {
...@@ -75,7 +75,7 @@ void ngraph::pass::Manager::register_pass(std::shared_ptr<TreeBase> p) ...@@ -75,7 +75,7 @@ void ngraph::pass::Manager::register_pass(std::shared_ptr<TreeBase> p)
m_tree_passes.push_back(p); m_tree_passes.push_back(p);
} }
void ngraph::pass::Manager::register_pass(std::shared_ptr<CallBase> p) void ngraph::pass::Manager::register_pass_ptr(std::shared_ptr<CallBase> p)
{ {
if (p == nullptr) if (p == nullptr)
{ {
......
...@@ -59,8 +59,19 @@ public: ...@@ -59,8 +59,19 @@ public:
void initialize_default_passes(); void initialize_default_passes();
void register_pass(std::shared_ptr<TreeBase>); template<typename T, class... Args>
void register_pass(std::shared_ptr<CallBase>); void register_pass(Args... args)
{
static_assert(std::is_base_of<pass::Base, T>::value, "pass not derived from pass base");
if (std::is_base_of<TreeBase, T>::value)
{
register_pass_ptr(std::make_shared<T>(args...));
}
else if (std::is_base_of<CallBase, T>::value)
{
register_pass_ptr(std::make_shared<T>(args...));
}
}
void run_passes(Function*); void run_passes(Function*);
void run_passes(std::shared_ptr<Function>); void run_passes(std::shared_ptr<Function>);
...@@ -70,6 +81,9 @@ public: ...@@ -70,6 +81,9 @@ public:
ManagerState& get_state(); ManagerState& get_state();
private: private:
void register_pass_ptr(std::shared_ptr<TreeBase>);
void register_pass_ptr(std::shared_ptr<CallBase>);
std::vector<std::shared_ptr<TreeBase>> m_tree_passes; std::vector<std::shared_ptr<TreeBase>> m_tree_passes;
std::vector<std::shared_ptr<CallBase>> m_call_passes; std::vector<std::shared_ptr<CallBase>> m_call_passes;
ManagerState m_state; ManagerState m_state;
......
...@@ -39,7 +39,9 @@ ...@@ -39,7 +39,9 @@
#include "ngraph/ops/select.hpp" #include "ngraph/ops/select.hpp"
#include "ngraph/ops/subtract.hpp" #include "ngraph/ops/subtract.hpp"
#include "ngraph/ops/tuple.hpp" #include "ngraph/ops/tuple.hpp"
#include "ngraph/pass/assign_tensors.hpp"
#include "ngraph/pass/manager.hpp" #include "ngraph/pass/manager.hpp"
#include "ngraph/pass/propagate_types.hpp"
#include "ngraph/pass/topological_sort.hpp" #include "ngraph/pass/topological_sort.hpp"
#include "ngraph/runtime/eigen/abs.hpp" #include "ngraph/runtime/eigen/abs.hpp"
#include "ngraph/runtime/eigen/add.hpp" #include "ngraph/runtime/eigen/add.hpp"
...@@ -78,7 +80,7 @@ ExternalFunction::ExternalFunction(const std::shared_ptr<ngraph::Function>& func ...@@ -78,7 +80,7 @@ ExternalFunction::ExternalFunction(const std::shared_ptr<ngraph::Function>& func
} }
#define REGISTER_INSTRUCTION(op_class, instr_class, ...) \ #define REGISTER_INSTRUCTION(op_class, instr_class, ...) \
op_map[type_index(typeid(op_class))] = [](Node* n, \ op_map[type_index(typeid(op_class))] = [](const Node* n, \
ExternalFunction* ef, \ ExternalFunction* ef, \
const std::vector<size_t>& in, \ const std::vector<size_t>& in, \
const std::vector<size_t>& out) { \ const std::vector<size_t>& out) { \
...@@ -94,7 +96,7 @@ ExternalFunction::ExternalFunction(const std::shared_ptr<ngraph::Function>& func ...@@ -94,7 +96,7 @@ ExternalFunction::ExternalFunction(const std::shared_ptr<ngraph::Function>& func
// Define code generators for handled ops. // Define code generators for handled ops.
std::unordered_map<std::type_index, std::unordered_map<std::type_index,
std::function<void(ngraph::Node*, std::function<void(const ngraph::Node*,
ExternalFunction*, ExternalFunction*,
const std::vector<size_t>& inputs, const std::vector<size_t>& inputs,
const std::vector<size_t>& outputs)>>& const std::vector<size_t>& outputs)>>&
...@@ -102,7 +104,7 @@ std::unordered_map<std::type_index, ...@@ -102,7 +104,7 @@ std::unordered_map<std::type_index,
{ {
static bool initialized = false; static bool initialized = false;
static std::unordered_map<std::type_index, static std::unordered_map<std::type_index,
std::function<void(Node*, std::function<void(const Node*,
ExternalFunction*, ExternalFunction*,
const std::vector<size_t>& inputs, const std::vector<size_t>& inputs,
const std::vector<size_t>& outputs)>> const std::vector<size_t>& outputs)>>
...@@ -126,16 +128,16 @@ std::unordered_map<std::type_index, ...@@ -126,16 +128,16 @@ std::unordered_map<std::type_index,
op::ScalarConstant<element::Float32>, op::ScalarConstant<element::Float32>,
runtime::eigen::ConstantInstruction<element::Float32>, runtime::eigen::ConstantInstruction<element::Float32>,
std::vector<element::Float32::type>{ std::vector<element::Float32::type>{
dynamic_cast<op::ScalarConstant<element::Float32>*>(n)->get_value()}, dynamic_cast<const op::ScalarConstant<element::Float32>*>(n)->get_value()},
out[0]); out[0]);
REGISTER_INSTRUCTION( REGISTER_INSTRUCTION(
op::TensorConstant<element::Float32>, op::TensorConstant<element::Float32>,
runtime::eigen::ConstantInstruction<element::Float32>, runtime::eigen::ConstantInstruction<element::Float32>,
dynamic_cast<op::TensorConstant<element::Float32>*>(n)->get_value()->get_vector(), dynamic_cast<const op::TensorConstant<element::Float32>*>(n)->get_value()->get_vector(),
out[0]); out[0]);
op_map[type_index(typeid(op::Concat))] = [](Node* n, op_map[type_index(typeid(op::Concat))] = [](const Node* n,
ExternalFunction* ef, ExternalFunction* ef,
const std::vector<size_t>& in, const std::vector<size_t>& in,
const std::vector<size_t>& out) { const std::vector<size_t>& out) {
...@@ -155,7 +157,7 @@ std::unordered_map<std::type_index, ...@@ -155,7 +157,7 @@ std::unordered_map<std::type_index,
{ {
ef->get_instructions()->push_back( ef->get_instructions()->push_back(
make_shared<runtime::eigen::ConcatMatrixInstruction<element::Float32>>( make_shared<runtime::eigen::ConcatMatrixInstruction<element::Float32>>(
in, (dynamic_cast<op::Concat *>(n))->get_concatenation_axis(), out[0])); in, (dynamic_cast<const op::Concat *>(n))->get_concatenation_axis(), out[0]));
} }
else else
{ {
...@@ -163,7 +165,7 @@ std::unordered_map<std::type_index, ...@@ -163,7 +165,7 @@ std::unordered_map<std::type_index,
} }
}; };
op_map[type_index(typeid(op::Dot))] = [](Node* n, op_map[type_index(typeid(op::Dot))] = [](const Node* n,
ExternalFunction* ef, ExternalFunction* ef,
const std::vector<size_t>& in, const std::vector<size_t>& in,
const std::vector<size_t>& out) { const std::vector<size_t>& out) {
...@@ -228,24 +230,24 @@ std::unordered_map<std::type_index, ...@@ -228,24 +230,24 @@ std::unordered_map<std::type_index,
}; };
// Parameter is a "runtime no-op" because the output tensor has already been filled. // Parameter is a "runtime no-op" because the output tensor has already been filled.
op_map[type_index(typeid(op::Parameter))] = [](Node* n, op_map[type_index(typeid(op::Parameter))] = [](const Node* n,
ExternalFunction* ef, ExternalFunction* ef,
const std::vector<size_t>& in, const std::vector<size_t>& in,
const std::vector<size_t>& out) {}; const std::vector<size_t>& out) {};
// GetTupleElement will be spliced out, with the users of out redirected to in's source, but, for now, we need to copy. // GetTupleElement will be spliced out, with the users of out redirected to in's source, but, for now, we need to copy.
op_map[type_index(typeid(op::GetTupleElement))] = [](Node* n, op_map[type_index(typeid(op::GetTupleElement))] = [](const Node* n,
ExternalFunction* ef, ExternalFunction* ef,
const std::vector<size_t>& in, const std::vector<size_t>& in,
const std::vector<size_t>& out) { const std::vector<size_t>& out) {
auto get_tuple_element = static_cast<op::GetTupleElement*>(n); auto get_tuple_element = static_cast<const op::GetTupleElement*>(n);
ef->get_instructions()->push_back( ef->get_instructions()->push_back(
make_shared<runtime::eigen::CopyInstruction<element::Float32>>( make_shared<runtime::eigen::CopyInstruction<element::Float32>>(
in.at(get_tuple_element->get_n()), out.at(0))); in.at(get_tuple_element->get_n()), out.at(0)));
}; };
// Tuple will be spliced out, with the users of out connected to the corresponding in's source, but, for now, we need to copy. // Tuple will be spliced out, with the users of out connected to the corresponding in's source, but, for now, we need to copy.
op_map[type_index(typeid(op::Tuple))] = [](Node* n, op_map[type_index(typeid(op::Tuple))] = [](const Node* n,
ExternalFunction* ef, ExternalFunction* ef,
const std::vector<size_t>& in, const std::vector<size_t>& in,
const std::vector<size_t>& out) { const std::vector<size_t>& out) {
...@@ -272,27 +274,17 @@ void ExternalFunction::compile() ...@@ -272,27 +274,17 @@ void ExternalFunction::compile()
// This will be replaced with the pass manager // This will be replaced with the pass manager
// Get the ordered list of ops in execution order // Get the ordered list of ops in execution order
pass::Manager pass_manager; pass::Manager pass_manager;
auto topological_sort = make_shared<pass::TopologicalSort>(); pass_manager.register_pass<pass::TopologicalSort>();
pass_manager.register_pass(topological_sort); pass_manager.register_pass<pass::PropagateTypes>();
pass_manager.register_pass<pass::AssignTensors>();
pass_manager.run_passes(m_function); pass_manager.run_passes(m_function);
auto nodes = pass_manager.get_call_graph();
// Types
for (auto node : nodes)
{
node->propagate_types();
}
// Determine tensors
for (auto node : nodes)
{
node->assign_tensors();
}
// Determine tensor requirements for the call frame // Determine tensor requirements for the call frame
unordered_map<shared_ptr<ngraph::descriptor::TensorView>, size_t> tensor_index; unordered_map<shared_ptr<ngraph::descriptor::TensorView>, size_t> tensor_index;
// First come the function inputs // First come the function inputs
for (auto param : m_function->get_parameters()) for (auto param : m_function->get_parameters())
{ {
for (auto output : param->get_outputs()) for (const descriptor::Output& output : param->get_outputs())
{ {
auto tv = output.get_tensor_view(); auto tv = output.get_tensor_view();
size_t index = tensor_index.size(); size_t index = tensor_index.size();
...@@ -302,7 +294,7 @@ void ExternalFunction::compile() ...@@ -302,7 +294,7 @@ void ExternalFunction::compile()
m_n_inputs = tensor_index.size(); m_n_inputs = tensor_index.size();
// Next are the function outputs // Next are the function outputs
for (auto output : m_function->get_result()->get_outputs()) for (const descriptor::Output& output : m_function->get_result()->get_outputs())
{ {
auto tv = output.get_tensor_view(); auto tv = output.get_tensor_view();
size_t index = tensor_index.size(); size_t index = tensor_index.size();
...@@ -311,9 +303,9 @@ void ExternalFunction::compile() ...@@ -311,9 +303,9 @@ void ExternalFunction::compile()
m_n_outputs = tensor_index.size() - m_n_inputs; m_n_outputs = tensor_index.size() - m_n_inputs;
// All remaining tensor views // All remaining tensor views
for (auto node : nodes) for (const Node* node : pass_manager.get_call_graph())
{ {
for (auto output : node->get_outputs()) for (const descriptor::Output& output : node->get_outputs())
{ {
auto tv = output.get_tensor_view(); auto tv = output.get_tensor_view();
if (0 == tensor_index.count(tv)) if (0 == tensor_index.count(tv))
...@@ -327,7 +319,7 @@ void ExternalFunction::compile() ...@@ -327,7 +319,7 @@ void ExternalFunction::compile()
// Now we build the eigen-VM instructions // Now we build the eigen-VM instructions
auto op_map = get_op_map(); auto op_map = get_op_map();
for (auto node : nodes) for (const Node* node : pass_manager.get_call_graph())
{ {
auto handler_it = op_map.find(type_index(typeid(*node))); auto handler_it = op_map.find(type_index(typeid(*node)));
if (handler_it == op_map.end()) if (handler_it == op_map.end())
...@@ -335,14 +327,14 @@ void ExternalFunction::compile() ...@@ -335,14 +327,14 @@ void ExternalFunction::compile()
throw ngraph_error("Unhandled op during code generation"); throw ngraph_error("Unhandled op during code generation");
} }
std::vector<size_t> in; std::vector<size_t> in;
for (auto input : node->get_inputs()) for (const descriptor::Input& input : node->get_inputs())
{ {
auto output = input.get_output(); const descriptor::Output& output = input.get_output();
auto tv = output.get_tensor_view(); auto tv = output.get_tensor_view();
in.push_back(tensor_index.at(tv)); in.push_back(tensor_index.at(tv));
} }
std::vector<size_t> out; std::vector<size_t> out;
for (auto output : node->get_outputs()) for (const descriptor::Output& output : node->get_outputs())
{ {
auto tv = output.get_tensor_view(); auto tv = output.get_tensor_view();
out.push_back(tensor_index.at(tv)); out.push_back(tensor_index.at(tv));
......
...@@ -53,7 +53,7 @@ namespace ngraph ...@@ -53,7 +53,7 @@ namespace ngraph
std::vector<std::shared_ptr<ngraph::descriptor::TensorView>> m_temp_views; std::vector<std::shared_ptr<ngraph::descriptor::TensorView>> m_temp_views;
static std::unordered_map<std::type_index, static std::unordered_map<std::type_index,
std::function<void(ngraph::Node*, std::function<void(const ngraph::Node*,
ExternalFunction*, ExternalFunction*,
const std::vector<size_t>& inputs, const std::vector<size_t>& inputs,
const std::vector<size_t>& outputs)>>& const std::vector<size_t>& outputs)>>&
......
...@@ -41,19 +41,13 @@ TEST(pass, liveness) ...@@ -41,19 +41,13 @@ TEST(pass, liveness)
string image = "liveness.png"; string image = "liveness.png";
string dump_file = "liveness.txt"; string dump_file = "liveness.txt";
pass::Manager pass_manager; pass::Manager pass_manager;
auto visualize = make_shared<pass::VisualizeTree>(image);
auto topological_sort = make_shared<pass::TopologicalSort>(); pass_manager.register_pass<pass::VisualizeTree>(image);
auto propagate_types = make_shared<pass::PropagateTypes>(); pass_manager.register_pass<pass::TopologicalSort>();
auto assign_tensors = make_shared<pass::AssignTensors>(); pass_manager.register_pass<pass::PropagateTypes>();
auto liveness = make_shared<pass::Liveness>(); pass_manager.register_pass<pass::AssignTensors>();
auto dump_sorted = make_shared<pass::DumpSorted>(dump_file); pass_manager.register_pass<pass::Liveness>();
pass_manager.register_pass<pass::DumpSorted>(dump_file);
pass_manager.register_pass(visualize);
pass_manager.register_pass(topological_sort);
pass_manager.register_pass(propagate_types);
pass_manager.register_pass(assign_tensors);
pass_manager.register_pass(liveness);
pass_manager.register_pass(dump_sorted);
shared_ptr<Function> func = make_test_graph(); shared_ptr<Function> func = make_test_graph();
pass_manager.run_passes(func.get()); pass_manager.run_passes(func.get());
......
...@@ -32,13 +32,10 @@ using namespace std; ...@@ -32,13 +32,10 @@ using namespace std;
TEST(pass_manager, add) TEST(pass_manager, add)
{ {
pass::Manager pass_manager; pass::Manager pass_manager;
auto topological_sort = make_shared<pass::TopologicalSort>();
auto propagate_types = make_shared<pass::PropagateTypes>();
auto assign_tensors = make_shared<pass::AssignTensors>();
pass_manager.register_pass(topological_sort); pass_manager.register_pass<pass::TopologicalSort>();
pass_manager.register_pass(propagate_types); pass_manager.register_pass<pass::PropagateTypes>();
pass_manager.register_pass(assign_tensors); pass_manager.register_pass<pass::AssignTensors>();
auto graph = make_test_graph(); auto graph = make_test_graph();
size_t node_count = get_node_count(graph->get_result()); size_t node_count = get_node_count(graph->get_result());
...@@ -51,10 +48,7 @@ TEST(pass_manager, add) ...@@ -51,10 +48,7 @@ TEST(pass_manager, add)
TEST(pass_manager, dependency) TEST(pass_manager, dependency)
{ {
pass::Manager pass_manager; pass::Manager pass_manager;
auto topological_sort = make_shared<pass::TopologicalSort>();
auto propagate_types = make_shared<pass::PropagateTypes>();
auto assign_tensors = make_shared<pass::AssignTensors>();
pass_manager.register_pass(topological_sort); pass_manager.register_pass<pass::TopologicalSort>();
EXPECT_THROW(pass_manager.register_pass(assign_tensors), runtime_error); EXPECT_THROW(pass_manager.register_pass<pass::AssignTensors>(), runtime_error);
} }
...@@ -209,19 +209,12 @@ TEST(memory_layout, basic) ...@@ -209,19 +209,12 @@ TEST(memory_layout, basic)
{ {
string dump_file = "memory_layout.txt"; string dump_file = "memory_layout.txt";
pass::Manager pass_manager; pass::Manager pass_manager;
auto topological_sort = make_shared<pass::TopologicalSort>(); pass_manager.register_pass<pass::TopologicalSort>();
auto propagate_types = make_shared<pass::PropagateTypes>(); pass_manager.register_pass<pass::PropagateTypes>();
auto assign_tensors = make_shared<pass::AssignTensors>(); pass_manager.register_pass<pass::AssignTensors>();
auto liveness = make_shared<pass::Liveness>(); pass_manager.register_pass<pass::Liveness>();
auto memory_layout = make_shared<pass::MemoryLayout>(); pass_manager.register_pass<pass::MemoryLayout>();
auto dump_sorted = make_shared<pass::DumpSorted>(dump_file); pass_manager.register_pass<pass::DumpSorted>(dump_file);
pass_manager.register_pass(topological_sort);
pass_manager.register_pass(propagate_types);
pass_manager.register_pass(assign_tensors);
pass_manager.register_pass(liveness);
pass_manager.register_pass(memory_layout);
pass_manager.register_pass(dump_sorted);
auto graph = make_test_graph(); auto graph = make_test_graph();
pass_manager.run_passes(graph); pass_manager.run_passes(graph);
......
...@@ -35,15 +35,11 @@ using namespace ngraph::descriptor; ...@@ -35,15 +35,11 @@ using namespace ngraph::descriptor;
TEST(tensor, size) TEST(tensor, size)
{ {
pass::Manager pass_manager; pass::Manager pass_manager;
auto topological_sort = make_shared<pass::TopologicalSort>();
auto propagate_types = make_shared<pass::PropagateTypes>();
auto assign_tensors = make_shared<pass::AssignTensors>();
auto liveness = make_shared<pass::Liveness>();
pass_manager.register_pass(topological_sort); pass_manager.register_pass<pass::TopologicalSort>();
pass_manager.register_pass(propagate_types); pass_manager.register_pass<pass::PropagateTypes>();
pass_manager.register_pass(assign_tensors); pass_manager.register_pass<pass::AssignTensors>();
pass_manager.register_pass(liveness); pass_manager.register_pass<pass::Liveness>();
{ {
auto arg0 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{2, 3}); auto arg0 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{2, 3});
......
...@@ -64,8 +64,7 @@ TEST(topological_sort, basic) ...@@ -64,8 +64,7 @@ TEST(topological_sort, basic)
// vz.add(r0); // vz.add(r0);
// vz.save_dot("test.png"); // vz.save_dot("test.png");
pass::Manager pass_manager; pass::Manager pass_manager;
auto topological_sort = make_shared<pass::TopologicalSort>(); pass_manager.register_pass<pass::TopologicalSort>();
pass_manager.register_pass(topological_sort);
pass_manager.run_passes(f0); pass_manager.run_passes(f0);
auto sorted_list = pass_manager.get_call_graph(); auto sorted_list = pass_manager.get_call_graph();
...@@ -116,8 +115,7 @@ TEST(benchmark, topological_sort) ...@@ -116,8 +115,7 @@ TEST(benchmark, topological_sort)
timer.start(); timer.start();
pass::Manager pass_manager; pass::Manager pass_manager;
auto topological_sort = make_shared<pass::TopologicalSort>(); pass_manager.register_pass<pass::TopologicalSort>();
pass_manager.register_pass(topological_sort);
pass_manager.run_passes(f0); pass_manager.run_passes(f0);
auto sorted_list = pass_manager.get_call_graph(); auto sorted_list = pass_manager.get_call_graph();
timer.stop(); timer.stop();
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment