Commit 118e0679 authored by Jai Menon's avatar Jai Menon Committed by GitHub

Merge branch 'master' into jmenon/codegen

parents c2ff1508 c3dfdf5a
...@@ -39,16 +39,16 @@ set (SRC ...@@ -39,16 +39,16 @@ set (SRC
ops/unary_elementwise_arithmetic.cpp ops/unary_elementwise_arithmetic.cpp
ops/unary_elementwise_builtin.cpp ops/unary_elementwise_builtin.cpp
pass/assign_tensors.cpp pass/assign_tensors.cpp
pass/call_pass.cpp pass/collect_functions.cpp
pass/dump_sorted.cpp pass/dump_sorted.cpp
pass/liveness.cpp pass/liveness.cpp
pass/manager.cpp pass/manager.cpp
pass/manager_state.cpp
pass/memory_layout.cpp pass/memory_layout.cpp
pass/memory_visualize.cpp pass/memory_visualize.cpp
pass/pass.cpp pass/pass.cpp
pass/propagate_types.cpp pass/propagate_types.cpp
pass/topological_sort.cpp pass/topological_sort.cpp
pass/tree_pass.cpp
pass/visualize_tree.cpp pass/visualize_tree.cpp
runtime/call_frame.cpp runtime/call_frame.cpp
runtime/external_function.cpp runtime/external_function.cpp
...@@ -58,7 +58,6 @@ set (SRC ...@@ -58,7 +58,6 @@ set (SRC
types/element_type.cpp types/element_type.cpp
types/type.cpp types/type.cpp
util.cpp util.cpp
visualize.cpp
) )
# find_program (GRAPHVIZ dot) # find_program (GRAPHVIZ dot)
......
...@@ -15,21 +15,95 @@ ...@@ -15,21 +15,95 @@
#include <memory> #include <memory>
#include "ngraph/function.hpp" #include "ngraph/function.hpp"
#include "ngraph/util.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
size_t Function::m_next_instance_id = 0;
Function::Function(const std::shared_ptr<Node>& result, Function::Function(const std::shared_ptr<Node>& result,
const std::shared_ptr<ValueType>& result_type, const std::shared_ptr<ValueType>& result_type,
const std::vector<std::shared_ptr<op::Parameter>>& parameters) const std::vector<std::shared_ptr<op::Parameter>>& parameters,
const std::string& name)
: m_result(result) : m_result(result)
, m_parameters(parameters) , m_parameters(parameters)
, m_name("Function") , m_name(name)
, m_result_type(result_type) , m_result_type(result_type)
, m_ordered_ops_valid(false)
, m_instance_id(m_next_instance_id++)
{ {
size_t i = 0; size_t i = 0;
for (auto parameter : parameters) for (auto parameter : parameters)
{ {
parameter->assign_function(this, i++); parameter->assign_function(this, i++);
} }
traverse_nodes(result, [&](Node* node) { m_ops.push_back(node); });
}
void Function::set_ordered_ops(const std::list<Node*>& ordered_ops)
{
m_ordered_ops = ordered_ops;
m_ordered_ops_valid = true;
}
std::list<Node*>& Function::get_ops()
{
return m_ops;
}
const std::list<Node*>& Function::get_ops() const
{
return m_ops;
}
std::list<Node*>& Function::get_ordered_ops()
{
if (!m_ordered_ops_valid)
{
throw ngraph_error("Access to ordered ops invalid");
}
return m_ordered_ops;
}
const std::list<Node*>& Function::get_ordered_ops() const
{
if (!m_ordered_ops_valid)
{
throw ngraph_error("Access to ordered ops invalid");
}
return m_ordered_ops;
}
std::string Function::get_name() const
{
string rc;
if (m_name.empty())
{
rc = "Function_" + to_string(m_instance_id);
}
else
{
rc = m_name;
}
return rc;
}
void Function::set_name(const string& name)
{
if (m_name.empty())
{
m_name = name;
}
else
{
throw ngraph_error("Function name may be set exactly once");
}
}
std::ostream& ngraph::operator<<(std::ostream& out, const Function& f)
{
out << "Function(" << f.get_name() << ")";
return out;
} }
...@@ -15,11 +15,13 @@ ...@@ -15,11 +15,13 @@
#pragma once #pragma once
#include <initializer_list> #include <initializer_list>
#include <list>
#include <memory> #include <memory>
#include <string> #include <string>
#include <vector> #include <vector>
#include "ngraph/descriptor/tensor_view.hpp" #include "ngraph/descriptor/tensor_view.hpp"
#include "ngraph/log.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/ops/op.hpp" #include "ngraph/ops/op.hpp"
#include "ngraph/ops/parameter.hpp" #include "ngraph/ops/parameter.hpp"
...@@ -34,7 +36,8 @@ namespace ngraph ...@@ -34,7 +36,8 @@ namespace ngraph
public: public:
Function(const std::shared_ptr<Node>& result, Function(const std::shared_ptr<Node>& result,
const std::shared_ptr<ValueType>& result_type, const std::shared_ptr<ValueType>& result_type,
const std::vector<std::shared_ptr<op::Parameter>>& parameters); const std::vector<std::shared_ptr<op::Parameter>>& parameters,
const std::string& name = "");
std::shared_ptr<Node> get_result() { return m_result; } std::shared_ptr<Node> get_result() { return m_result; }
const std::vector<std::shared_ptr<op::Parameter>> get_parameters() const const std::vector<std::shared_ptr<op::Parameter>> get_parameters() const
...@@ -42,11 +45,31 @@ namespace ngraph ...@@ -42,11 +45,31 @@ namespace ngraph
return m_parameters; return m_parameters;
} }
const std::shared_ptr<ValueType> get_result_type() const { return m_result_type; } const std::shared_ptr<ValueType> get_result_type() const { return m_result_type; }
std::string get_name() const { return m_name; } std::string get_name() const;
void set_name(const std::string& name);
std::list<Node*>& get_ops();
const std::list<Node*>& get_ops() const;
std::list<Node*>& get_ordered_ops();
const std::list<Node*>& get_ordered_ops() const;
void set_ordered_ops(const std::list<Node*>&);
void set_ordered_ops_valid() { m_ordered_ops_valid = true; }
void clear_ordered_ops_valid() { m_ordered_ops_valid = false; }
friend std::ostream& operator<<(std::ostream&, const Function&);
protected: protected:
std::shared_ptr<Node> m_result; std::shared_ptr<Node> m_result;
std::vector<std::shared_ptr<ngraph::op::Parameter>> m_parameters; std::vector<std::shared_ptr<ngraph::op::Parameter>> m_parameters;
std::string m_name; std::string m_name;
std::shared_ptr<ValueType> m_result_type; std::shared_ptr<ValueType> m_result_type;
bool m_ordered_ops_valid;
std::list<Node*> m_ordered_ops;
std::list<Node*> m_ops;
private:
Function(const Function&) = delete;
Function(const Function&&) = delete;
static size_t m_next_instance_id;
size_t m_instance_id;
}; };
} }
...@@ -104,6 +104,32 @@ std::string Node::get_node_id() const ...@@ -104,6 +104,32 @@ std::string Node::get_node_id() const
return ss.str(); return ss.str();
} }
std::string Node::get_name() const
{
string rc;
if (m_name.empty())
{
rc = description() + "_" + to_string(m_instance_id);
}
else
{
rc = m_name;
}
return rc;
}
void Node::set_name(const string& name)
{
if (m_name.empty())
{
m_name = name;
}
else
{
throw ngraph_error("Node name may be set exactly once");
}
}
namespace ngraph namespace ngraph
{ {
ostream& operator<<(ostream& out, const Node& node) ostream& operator<<(ostream& out, const Node& node)
......
...@@ -55,6 +55,8 @@ namespace ngraph ...@@ -55,6 +55,8 @@ namespace ngraph
public: public:
/// The class name, must not contain spaces /// The class name, must not contain spaces
virtual std::string description() const = 0; virtual std::string description() const = 0;
std::string get_name() const;
void set_name(const std::string& name);
/// Propagate types and check arguments for consistency /// Propagate types and check arguments for consistency
virtual void propagate_types() = 0; virtual void propagate_types() = 0;
......
...@@ -25,15 +25,15 @@ ...@@ -25,15 +25,15 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
bool pass::AssignTensors::run_on_call_list(std::list<Node*>& node_list) bool pass::AssignTensors::run_on_call_graph(list<Node*>& nodes)
{ {
for (Node* node : node_list) for (Node* node : nodes)
{ {
try try
{ {
// We need to set the nodes is_output state prior to call assign_tensors // We need to set the nodes is_output state prior to call assign_tensors
// so that the output state can be passes to the constructed tensors. // so that the output state can be passes to the constructed tensors.
if (node == get_state().get_function()->get_result().get()) if (node == get_state().get_functions().at(0)->get_result().get())
{ {
node->set_is_output(); node->set_is_output();
} }
...@@ -50,21 +50,3 @@ bool pass::AssignTensors::run_on_call_list(std::list<Node*>& node_list) ...@@ -50,21 +50,3 @@ bool pass::AssignTensors::run_on_call_list(std::list<Node*>& node_list)
} }
return false; return false;
} }
void pass::AssignTensors::check_dependencies(
const std::vector<std::shared_ptr<CallBase>>& registered_passes) const
{
bool found_propagate_types = false;
for (auto pass : registered_passes)
{
if (dynamic_pointer_cast<PropagateTypes>(pass))
{
found_propagate_types = true;
}
}
if (!found_propagate_types)
{
throw runtime_error("Dependency 'PropagateTypes' not found for pass 'AssignTensors'");
}
}
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
#pragma once #pragma once
#include "ngraph/pass/call_pass.hpp" #include "ngraph/pass/pass.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -25,12 +25,10 @@ namespace ngraph ...@@ -25,12 +25,10 @@ namespace ngraph
class Node; class Node;
} }
class ngraph::pass::AssignTensors : public CallBase class ngraph::pass::AssignTensors : public CallGraphPass
{ {
public: public:
virtual bool run_on_call_list(std::list<Node*>&) override; virtual bool run_on_call_graph(std::list<Node*>& nodes) override;
void check_dependencies(const std::vector<std::shared_ptr<CallBase>>&) const override;
private: private:
}; };
...@@ -12,34 +12,39 @@ ...@@ -12,34 +12,39 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#pragma once #include "ngraph/pass/collect_functions.hpp"
#include "ngraph/function.hpp"
#include <functional> #include "ngraph/log.hpp"
#include <memory> #include "ngraph/node.hpp"
#include <set> #include "ngraph/ops/function_call.hpp"
#include <sstream> #include "ngraph/ops/op.hpp"
#include "ngraph/util.hpp"
namespace ngraph
using namespace std;
using namespace ngraph;
using namespace ngraph::pass;
bool CollectFunctions::run_on_function(ngraph::Function* func)
{ {
class Visualize; set<Function*> functions;
class Node; deque<Function*> stack;
using node_ptr = std::shared_ptr<Node>; stack.push_back(func);
while (stack.empty() == false)
{
Function* f = stack.front();
stack.pop_front();
functions.insert(f);
traverse_nodes(f->get_result(), [&](Node* node) {
op::FunctionCall* fc = dynamic_cast<op::FunctionCall*>(node);
if (fc)
{
stack.push_back(fc->get_function().get());
}
});
}
get_state().set_functions(functions);
return false;
} }
class ngraph::Visualize
{
public:
Visualize(const std::string& name = "ngraph");
void add(node_ptr);
void save_dot(const std::string& path) const;
private:
std::string add_attributes(const Node* node);
std::string get_attributes(const Node* node);
std::stringstream m_ss;
std::string m_name;
std::set<const Node*> m_nodes_with_attributes;
};
...@@ -12,4 +12,22 @@ ...@@ -12,4 +12,22 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#include "ngraph/pass/call_pass.hpp" #pragma once
#include "ngraph/pass/pass.hpp"
namespace ngraph
{
namespace pass
{
class CollectFunctions;
}
}
class ngraph::pass::CollectFunctions : public FunctionPass
{
public:
bool run_on_function(ngraph::Function*) override;
private:
};
...@@ -27,14 +27,16 @@ pass::DumpSorted::DumpSorted(const string& output_file) ...@@ -27,14 +27,16 @@ pass::DumpSorted::DumpSorted(const string& output_file)
{ {
} }
bool pass::DumpSorted::run_on_call_list(list<Node*>& nodes) bool pass::DumpSorted::run_on_module(vector<Function*>& functions)
{ {
ofstream out{m_output_file}; ofstream out{m_output_file};
if (out) if (out)
{ {
for (const Node* node : nodes) for (Function* f : functions)
{ {
out << node->get_node_id() << "("; for (const Node* node : f->get_ordered_ops())
{
out << node->get_name() << "(";
vector<string> inputs; vector<string> inputs;
for (const Input& input : node->get_inputs()) for (const Input& input : node->get_inputs())
{ {
...@@ -65,6 +67,7 @@ bool pass::DumpSorted::run_on_call_list(list<Node*>& nodes) ...@@ -65,6 +67,7 @@ bool pass::DumpSorted::run_on_call_list(list<Node*>& nodes)
} }
} }
} }
}
return false; return false;
} }
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#include <string> #include <string>
#include "ngraph/pass/call_pass.hpp" #include "ngraph/pass/pass.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -27,12 +27,12 @@ namespace ngraph ...@@ -27,12 +27,12 @@ namespace ngraph
class Node; class Node;
} }
class ngraph::pass::DumpSorted : public CallBase class ngraph::pass::DumpSorted : public ModulePass
{ {
public: public:
DumpSorted(const std::string& output_file); DumpSorted(const std::string& output_file);
virtual bool run_on_call_list(std::list<Node*>&) override; virtual bool run_on_module(std::vector<Function*>&) override;
private: private:
const std::string m_output_file; const std::string m_output_file;
......
...@@ -27,7 +27,7 @@ using namespace std; ...@@ -27,7 +27,7 @@ using namespace std;
using namespace ngraph; using namespace ngraph;
using namespace ngraph::descriptor; using namespace ngraph::descriptor;
bool pass::Liveness::run_on_call_list(list<Node*>& ops) bool pass::Liveness::run_on_call_graph(list<Node*>& ops)
{ {
unordered_set<Tensor*> currently_live; unordered_set<Tensor*> currently_live;
...@@ -123,24 +123,6 @@ bool pass::Liveness::run_on_call_list(list<Node*>& ops) ...@@ -123,24 +123,6 @@ bool pass::Liveness::run_on_call_list(list<Node*>& ops)
return false; return false;
} }
void pass::Liveness::check_dependencies(
const std::vector<std::shared_ptr<CallBase>>& registered_passes) const
{
bool found_propagate_types = false;
for (auto pass : registered_passes)
{
if (dynamic_pointer_cast<AssignTensors>(pass))
{
found_propagate_types = true;
}
}
if (!found_propagate_types)
{
throw runtime_error("Dependency 'PropagateTypes' not found for pass 'AssignTensors'");
}
}
bool pass::Liveness::is_temporary(const Tensor& tensor) bool pass::Liveness::is_temporary(const Tensor& tensor)
{ {
return tensor.is_persistent() == false && tensor.is_input() == false && return tensor.is_persistent() == false && tensor.is_input() == false &&
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
#pragma once #pragma once
#include "ngraph/descriptor/tensor.hpp" #include "ngraph/descriptor/tensor.hpp"
#include "ngraph/pass/call_pass.hpp" #include "ngraph/pass/pass.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -26,12 +26,10 @@ namespace ngraph ...@@ -26,12 +26,10 @@ namespace ngraph
class Node; class Node;
} }
class ngraph::pass::Liveness : public CallBase class ngraph::pass::Liveness : public CallGraphPass
{ {
public: public:
virtual bool run_on_call_list(std::list<Node*>&) override; virtual bool run_on_call_graph(std::list<Node*>&) override;
void check_dependencies(const std::vector<std::shared_ptr<CallBase>>&) const override;
private: private:
bool is_temporary(const descriptor::Tensor&); bool is_temporary(const descriptor::Tensor&);
......
...@@ -19,40 +19,11 @@ ...@@ -19,40 +19,11 @@
#include "ngraph/log.hpp" #include "ngraph/log.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/pass/manager.hpp" #include "ngraph/pass/manager.hpp"
#include "ngraph/pass/pass.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
Function* ngraph::pass::ManagerState::get_function()
{
return m_function;
}
void ngraph::pass::ManagerState::set_function(Function* func)
{
m_function = func;
}
size_t ngraph::pass::ManagerState::get_temporary_pool_size()
{
return m_temporary_pool_size;
}
void ngraph::pass::ManagerState::set_temporary_pool_size(size_t size)
{
m_temporary_pool_size = size;
}
std::list<Node*>& ngraph::pass::ManagerState::get_call_graph()
{
return m_call_graph;
}
const std::list<Node*>& ngraph::pass::ManagerState::get_call_graph() const
{
return m_call_graph;
}
ngraph::pass::Manager::Manager() ngraph::pass::Manager::Manager()
{ {
} }
...@@ -65,26 +36,6 @@ void ngraph::pass::Manager::initialize_default_passes() ...@@ -65,26 +36,6 @@ void ngraph::pass::Manager::initialize_default_passes()
{ {
} }
void ngraph::pass::Manager::register_pass_ptr(std::shared_ptr<TreeBase> p)
{
if (p == nullptr)
{
throw invalid_argument("null pass registered");
}
p->check_dependencies(m_tree_passes);
m_tree_passes.push_back(p);
}
void ngraph::pass::Manager::register_pass_ptr(std::shared_ptr<CallBase> p)
{
if (p == nullptr)
{
throw invalid_argument("null pass registered");
}
p->check_dependencies(m_call_passes);
m_call_passes.push_back(p);
}
void ngraph::pass::Manager::run_passes(shared_ptr<Function> func) void ngraph::pass::Manager::run_passes(shared_ptr<Function> func)
{ {
run_passes(func.get()); run_passes(func.get());
...@@ -92,23 +43,79 @@ void ngraph::pass::Manager::run_passes(shared_ptr<Function> func) ...@@ -92,23 +43,79 @@ void ngraph::pass::Manager::run_passes(shared_ptr<Function> func)
void ngraph::pass::Manager::run_passes(Function* func) void ngraph::pass::Manager::run_passes(Function* func)
{ {
m_state.set_function(func); vector<Function*> fs = {func};
for (shared_ptr<TreeBase> p : m_tree_passes) get_state().set_functions(fs);
for (shared_ptr<PassBase> pass : m_pass_list)
{
pass->set_state(get_state());
auto module_pass = dynamic_pointer_cast<ModulePass>(pass);
auto function_pass = dynamic_pointer_cast<FunctionPass>(pass);
auto node_pass = dynamic_pointer_cast<NodePass>(pass);
auto call_graph_pass = dynamic_pointer_cast<CallGraphPass>(pass);
if (module_pass)
{ {
p->set_state(get_state()); module_pass->run_on_module(fs);
p->run_on_tree(func->get_result());
} }
else if (function_pass)
for (shared_ptr<CallBase>& p : m_call_passes) {
for (Function* f : fs)
{ {
p->set_state(get_state()); function_pass->run_on_function(f);
p->run_on_call_list(get_state().get_call_graph());
} }
} }
else if (node_pass)
const std::list<ngraph::Node*>& ngraph::pass::Manager::get_call_graph() const {
{ for (Function* f : fs)
return m_state.get_call_graph(); {
for (Node* n : f->get_ops())
{
node_pass->run_on_node(n);
}
}
}
else if (call_graph_pass)
{
for (Function* f : fs)
{
call_graph_pass->run_on_call_graph(f->get_ordered_ops());
}
}
}
// for (shared_ptr<ModulePass>& p : m_module_passes)
// {
// p->set_state(get_state());
// p->run_on_module(fs);
// }
// for (Function* f : fs)
// {
// for (shared_ptr<FunctionPass> p : m_function_passes)
// {
// p->set_state(get_state());
// p->run_on_function(f);
// }
// }
// for (Function* f : fs)
// {
// NGRAPH_INFO;
// for (shared_ptr<NodePass> p : m_node_passes)
// {
// for (Node* node : f->get_ops())
// {
// NGRAPH_INFO;
// p->set_state(get_state());
// p->run_on_node(node);
// }
// }
// }
// for (shared_ptr<CallGraphPass>& p : m_call_graph_passes)
// {
// p->set_state(get_state());
// p->run_on_call_graph(func->get_ordered_ops());
// }
} }
ngraph::pass::ManagerState& ngraph::pass::Manager::get_state() ngraph::pass::ManagerState& ngraph::pass::Manager::get_state()
......
...@@ -18,8 +18,8 @@ ...@@ -18,8 +18,8 @@
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "ngraph/pass/call_pass.hpp" #include "ngraph/pass/manager_state.hpp"
#include "ngraph/pass/tree_pass.hpp" #include "ngraph/pass/pass.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -33,24 +33,6 @@ namespace ngraph ...@@ -33,24 +33,6 @@ namespace ngraph
class Function; class Function;
} }
class ngraph::pass::ManagerState
{
public:
Function* get_function();
void set_function(Function*);
size_t get_temporary_pool_size();
void set_temporary_pool_size(size_t);
std::list<Node*>& get_call_graph();
const std::list<Node*>& get_call_graph() const;
private:
Function* m_function = nullptr;
size_t m_temporary_pool_size = 0;
std::list<Node*> m_call_graph;
};
class ngraph::pass::Manager class ngraph::pass::Manager
{ {
public: public:
...@@ -62,29 +44,18 @@ public: ...@@ -62,29 +44,18 @@ public:
template <typename T, class... Args> template <typename T, class... Args>
void register_pass(Args... args) void register_pass(Args... args)
{ {
static_assert(std::is_base_of<pass::Base, T>::value, "pass not derived from pass base"); static_assert(std::is_base_of<pass::PassBase, T>::value, "pass not derived from pass base");
if (std::is_base_of<TreeBase, T>::value) auto pass = std::make_shared<T>(args...);
{ auto pass_base = std::static_pointer_cast<PassBase>(pass);
register_pass_ptr(std::make_shared<T>(args...)); m_pass_list.push_back(pass_base);
}
else if (std::is_base_of<CallBase, T>::value)
{
register_pass_ptr(std::make_shared<T>(args...));
}
} }
void run_passes(Function*); void run_passes(Function*);
void run_passes(std::shared_ptr<Function>); void run_passes(std::shared_ptr<Function>);
const std::list<Node*>& get_call_graph() const;
ManagerState& get_state(); ManagerState& get_state();
private: private:
void register_pass_ptr(std::shared_ptr<TreeBase>); std::vector<std::shared_ptr<PassBase>> m_pass_list;
void register_pass_ptr(std::shared_ptr<CallBase>);
std::vector<std::shared_ptr<TreeBase>> m_tree_passes;
std::vector<std::shared_ptr<CallBase>> m_call_passes;
ManagerState m_state; ManagerState m_state;
}; };
...@@ -12,31 +12,28 @@ ...@@ -12,31 +12,28 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#pragma once #include <iostream>
#include <list>
#include <memory> #include <memory>
#include <vector>
#include "ngraph/pass/pass.hpp" #include "ngraph/function.hpp"
#include "ngraph/log.hpp"
#include "ngraph/node.hpp"
#include "ngraph/pass/manager_state.hpp"
namespace ngraph using namespace std;
{ using namespace ngraph;
namespace pass
{
class TreeBase;
}
class Node; vector<Function*>& ngraph::pass::ManagerState::get_functions()
{
return m_function_list;
} }
class ngraph::pass::TreeBase : public Base size_t ngraph::pass::ManagerState::get_temporary_pool_size()
{ {
public: return m_temporary_pool_size;
virtual ~TreeBase() {} }
// return true if changes were made to the tree
virtual bool run_on_tree(std::shared_ptr<Node>) = 0;
// derived class throws exception if its dependencies have not been met void ngraph::pass::ManagerState::set_temporary_pool_size(size_t size)
virtual void check_dependencies(const std::vector<std::shared_ptr<TreeBase>>&) const {} {
}; m_temporary_pool_size = size;
}
...@@ -14,29 +14,36 @@ ...@@ -14,29 +14,36 @@
#pragma once #pragma once
#include <list>
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "ngraph/pass/pass.hpp"
namespace ngraph namespace ngraph
{ {
namespace pass namespace pass
{ {
class CallBase; class ManagerState;
} }
class Node; class Node;
class Function;
} }
class ngraph::pass::CallBase : public Base class ngraph::pass::ManagerState
{ {
public: public:
virtual ~CallBase() {} std::vector<Function*>& get_functions();
virtual bool run_on_call_list(std::list<Node*>&) = 0;
template <typename T>
void set_functions(const T& collection)
{
m_function_list.clear();
m_function_list.insert(m_function_list.begin(), collection.begin(), collection.end());
}
size_t get_temporary_pool_size();
void set_temporary_pool_size(size_t);
// derived class throws exception if its dependencies have not been met
virtual void check_dependencies(const std::vector<std::shared_ptr<CallBase>>&) const {}
private: private:
size_t m_temporary_pool_size = 0;
std::vector<Function*> m_function_list;
}; };
...@@ -27,7 +27,7 @@ using namespace std; ...@@ -27,7 +27,7 @@ using namespace std;
using namespace ngraph; using namespace ngraph;
using namespace ngraph::descriptor; using namespace ngraph::descriptor;
bool pass::MemoryLayout::run_on_call_list(std::list<Node*>& node_list) bool pass::MemoryLayout::run_on_call_graph(std::list<Node*>& node_list)
{ {
MemoryManager mm; MemoryManager mm;
for (const Node* node : node_list) for (const Node* node : node_list)
...@@ -47,24 +47,6 @@ bool pass::MemoryLayout::run_on_call_list(std::list<Node*>& node_list) ...@@ -47,24 +47,6 @@ bool pass::MemoryLayout::run_on_call_list(std::list<Node*>& node_list)
return false; return false;
} }
void pass::MemoryLayout::check_dependencies(
const std::vector<std::shared_ptr<CallBase>>& registered_passes) const
{
bool found_propagate_types = false;
for (auto pass : registered_passes)
{
if (dynamic_pointer_cast<Liveness>(pass))
{
found_propagate_types = true;
}
}
if (!found_propagate_types)
{
throw runtime_error("Dependency 'PropagateTypes' not found for pass 'AssignTensors'");
}
}
pass::MemoryManager::node::node(size_t size, block_state state) pass::MemoryManager::node::node(size_t size, block_state state)
: m_size{size} : m_size{size}
, m_state{state} , m_state{state}
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
#include <list> #include <list>
#include <sstream> #include <sstream>
#include "ngraph/pass/call_pass.hpp" #include "ngraph/pass/pass.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -31,12 +31,10 @@ namespace ngraph ...@@ -31,12 +31,10 @@ namespace ngraph
class Node; class Node;
} }
class ngraph::pass::MemoryLayout : public CallBase class ngraph::pass::MemoryLayout : public CallGraphPass
{ {
public: public:
virtual bool run_on_call_list(std::list<Node*>&) override; virtual bool run_on_call_graph(std::list<Node*>&) override;
void check_dependencies(const std::vector<std::shared_ptr<CallBase>>&) const override;
private: private:
}; };
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include "memory_visualize.hpp" #include "memory_visualize.hpp"
#include "ngraph/descriptor/tensor.hpp" #include "ngraph/descriptor/tensor.hpp"
#include "ngraph/function.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
...@@ -31,11 +32,13 @@ pass::MemoryVisualize::MemoryVisualize(const string& filename) ...@@ -31,11 +32,13 @@ pass::MemoryVisualize::MemoryVisualize(const string& filename)
{ {
} }
bool pass::MemoryVisualize::run_on_call_list(list<Node*>& _nodes) bool pass::MemoryVisualize::run_on_module(vector<Function*>& functions)
{ {
const list<Node*> nodes = _nodes;
ofstream file(m_filename); ofstream file(m_filename);
{ {
for (const Function* f : functions)
{
const list<Node*> nodes = f->get_ordered_ops();
file << "<!DOCTYPE html>\n<html>\n"; file << "<!DOCTYPE html>\n<html>\n";
file << "<head>\n"; file << "<head>\n";
file << " <style>\n"; file << " <style>\n";
...@@ -89,13 +92,10 @@ bool pass::MemoryVisualize::run_on_call_list(list<Node*>& _nodes) ...@@ -89,13 +92,10 @@ bool pass::MemoryVisualize::run_on_call_list(list<Node*>& _nodes)
// file << "<hr>\n"; // file << "<hr>\n";
file << "</body>\n</html>\n"; file << "</body>\n</html>\n";
} }
}
return false; return false;
} }
void pass::MemoryVisualize::check_dependencies(const vector<shared_ptr<CallBase>>& deps) const
{
}
const Node* pass::MemoryVisualize::find_largest_op(const list<Node*>& nodes) const Node* pass::MemoryVisualize::find_largest_op(const list<Node*>& nodes)
{ {
const Node* largest_op = nullptr; const Node* largest_op = nullptr;
...@@ -207,7 +207,7 @@ void pass::MemoryVisualize::draw_histogram(ostream& file, const list<Node*>& nod ...@@ -207,7 +207,7 @@ void pass::MemoryVisualize::draw_histogram(ostream& file, const list<Node*>& nod
size_t x2 = ((usage / memory_footprint) * scale) + offset; size_t x2 = ((usage / memory_footprint) * scale) + offset;
file << "<text x=\"" << 0 << "\" y=\"" << y + text_offset << "\" fill=\"" file << "<text x=\"" << 0 << "\" y=\"" << y + text_offset << "\" fill=\""
<< "black" << "black"
<< "\">" << node->get_node_id() << "</text>\n"; << "\">" << node->get_name() << "</text>\n";
file << "<line x1=\"" << x1 << "\" y1=\"" << y << "\" x2=\"" << x2 << "\" y2=\"" << y file << "<line x1=\"" << x1 << "\" y1=\"" << y << "\" x2=\"" << x2 << "\" y2=\"" << y
<< "\""; << "\"";
file << " style=\"stroke:forestgreen;stroke-width:" << stroke_width << "\" />\n"; file << " style=\"stroke:forestgreen;stroke-width:" << stroke_width << "\" />\n";
...@@ -231,7 +231,7 @@ void pass::MemoryVisualize::draw_op_influence(ostream& file, const list<Node*>& ...@@ -231,7 +231,7 @@ void pass::MemoryVisualize::draw_op_influence(ostream& file, const list<Node*>&
{ {
int weight = compute_op_weight(exop); int weight = compute_op_weight(exop);
file << " <tr>"; file << " <tr>";
file << "<td>" << exop->get_node_id() << "</td>"; file << "<td>" << exop->get_name() << "</td>";
file << "<td align=\"right\">" << weight << "</td>"; file << "<td align=\"right\">" << weight << "</td>";
file << "</tr>\n"; file << "</tr>\n";
} }
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
#include <limits> #include <limits>
#include <list> #include <list>
#include "ngraph/pass/call_pass.hpp" #include "ngraph/pass/pass.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -29,13 +29,11 @@ namespace ngraph ...@@ -29,13 +29,11 @@ namespace ngraph
class Node; class Node;
} }
class ngraph::pass::MemoryVisualize : public CallBase class ngraph::pass::MemoryVisualize : public ModulePass
{ {
public: public:
MemoryVisualize(const std::string& filename); MemoryVisualize(const std::string& filename);
virtual bool run_on_call_list(std::list<Node*>&) override; virtual bool run_on_module(std::vector<Function*>&) override;
void check_dependencies(const std::vector<std::shared_ptr<CallBase>>&) const override;
private: private:
const Node* find_largest_op(const std::list<Node*>& nodes); const Node* find_largest_op(const std::list<Node*>& nodes);
......
...@@ -15,12 +15,12 @@ ...@@ -15,12 +15,12 @@
#include "ngraph/pass/pass.hpp" #include "ngraph/pass/pass.hpp"
#include "ngraph/pass/manager.hpp" #include "ngraph/pass/manager.hpp"
ngraph::pass::ManagerState& ngraph::pass::Base::get_state() ngraph::pass::ManagerState& ngraph::pass::PassBase::get_state()
{ {
return *m_state; return *m_state;
} }
void ngraph::pass::Base::set_state(ManagerState& state) void ngraph::pass::PassBase::set_state(ManagerState& state)
{ {
m_state = &state; m_state = &state;
} }
...@@ -14,21 +14,33 @@ ...@@ -14,21 +14,33 @@
#pragma once #pragma once
#include <list>
#include <memory>
#include <vector>
#include "ngraph/node.hpp"
namespace ngraph namespace ngraph
{ {
namespace pass namespace pass
{ {
class Base; class PassBase;
class ModulePass;
class FunctionPass;
class NodePass;
class CallGraphPass;
class Manager; class Manager;
class ManagerState; class ManagerState;
} }
class Function;
} }
class ngraph::pass::Base class ngraph::pass::PassBase
{ {
friend class Manager; friend class Manager;
public: public:
virtual ~PassBase() {}
protected: protected:
ManagerState& get_state(); ManagerState& get_state();
void set_state(ManagerState&); void set_state(ManagerState&);
...@@ -36,3 +48,31 @@ protected: ...@@ -36,3 +48,31 @@ protected:
private: private:
ManagerState* m_state; ManagerState* m_state;
}; };
class ngraph::pass::ModulePass : public PassBase
{
public:
virtual ~ModulePass() {}
virtual bool run_on_module(std::vector<ngraph::Function*>&) = 0;
};
class ngraph::pass::FunctionPass : public PassBase
{
public:
virtual ~FunctionPass() {}
virtual bool run_on_function(ngraph::Function*) = 0;
};
class ngraph::pass::NodePass : public PassBase
{
public:
virtual ~NodePass() {}
virtual bool run_on_node(ngraph::Node*) = 0;
};
class ngraph::pass::CallGraphPass : public PassBase
{
public:
virtual ~CallGraphPass() {}
virtual bool run_on_call_graph(std::list<ngraph::Node*>&) = 0;
};
...@@ -20,9 +20,9 @@ ...@@ -20,9 +20,9 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
bool pass::PropagateTypes::run_on_call_list(std::list<Node*>& node_list) bool pass::PropagateTypes::run_on_call_graph(list<Node*>& nodes)
{ {
for (Node* node : node_list) for (Node* node : nodes)
{ {
try try
{ {
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
#pragma once #pragma once
#include "ngraph/pass/call_pass.hpp" #include "ngraph/pass/pass.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -25,10 +25,10 @@ namespace ngraph ...@@ -25,10 +25,10 @@ namespace ngraph
class Node; class Node;
} }
class ngraph::pass::PropagateTypes : public CallBase class ngraph::pass::PropagateTypes : public CallGraphPass
{ {
public: public:
virtual bool run_on_call_list(std::list<Node*>&) override; virtual bool run_on_call_graph(std::list<Node*>&) override;
private: private:
}; };
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include <deque> #include <deque>
#include <unordered_map> #include <unordered_map>
#include "ngraph/function.hpp"
#include "ngraph/log.hpp" #include "ngraph/log.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/pass/manager.hpp" #include "ngraph/pass/manager.hpp"
...@@ -24,14 +25,13 @@ ...@@ -24,14 +25,13 @@
using namespace ngraph; using namespace ngraph;
using namespace std; using namespace std;
bool ngraph::pass::TopologicalSort::run_on_tree(std::shared_ptr<Node> p) bool ngraph::pass::TopologicalSort::run_on_function(ngraph::Function* func)
{ {
list<Node*>& sorted_list = get_state().get_call_graph(); list<Node*> result_list;
sorted_list.clear();
deque<Node*> independent_nodes; deque<Node*> independent_nodes;
unordered_map<Node*, size_t> node_depencency_count; unordered_map<Node*, size_t> node_depencency_count;
traverse_nodes(p, [&](Node* node) { traverse_nodes(func->get_result(), [&](Node* node) {
node_depencency_count[node] = node->get_arguments().size(); node_depencency_count[node] = node->get_arguments().size();
if (node->get_arguments().size() == 0) if (node->get_arguments().size() == 0)
{ {
...@@ -42,7 +42,7 @@ bool ngraph::pass::TopologicalSort::run_on_tree(std::shared_ptr<Node> p) ...@@ -42,7 +42,7 @@ bool ngraph::pass::TopologicalSort::run_on_tree(std::shared_ptr<Node> p)
while (independent_nodes.size() > 0) while (independent_nodes.size() > 0)
{ {
auto independent_node = independent_nodes.front(); auto independent_node = independent_nodes.front();
sorted_list.push_back(independent_node); result_list.push_back(independent_node);
independent_nodes.pop_front(); independent_nodes.pop_front();
for (auto user : independent_node->users()) for (auto user : independent_node->users())
...@@ -56,5 +56,7 @@ bool ngraph::pass::TopologicalSort::run_on_tree(std::shared_ptr<Node> p) ...@@ -56,5 +56,7 @@ bool ngraph::pass::TopologicalSort::run_on_tree(std::shared_ptr<Node> p)
} }
} }
func->set_ordered_ops(result_list);
return false; return false;
} }
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
#include <list> #include <list>
#include <memory> #include <memory>
#include "ngraph/pass/tree_pass.hpp" #include "ngraph/pass/pass.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -28,9 +28,9 @@ namespace ngraph ...@@ -28,9 +28,9 @@ namespace ngraph
class Node; class Node;
} }
class ngraph::pass::TopologicalSort : public TreeBase class ngraph::pass::TopologicalSort : public FunctionPass
{ {
public: public:
TopologicalSort() {} TopologicalSort() {}
bool run_on_tree(std::shared_ptr<Node>) override; bool run_on_function(ngraph::Function*) override;
}; };
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "ngraph/pass/tree_pass.hpp"
...@@ -14,25 +14,30 @@ ...@@ -14,25 +14,30 @@
#include <fstream> #include <fstream>
#include "ngraph/function.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/pass/pass.hpp"
#include "ngraph/pass/visualize_tree.hpp" #include "ngraph/pass/visualize_tree.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
using namespace ngraph; using namespace ngraph;
using namespace std; using namespace std;
bool pass::VisualizeTree::run_on_tree(std::shared_ptr<Node> base_node) bool pass::VisualizeTree::run_on_module(vector<ngraph::Function*>& functions)
{ {
for (Function* f : functions)
{
// map<size_t, list<node_ptr>> dependent_nodes; // map<size_t, list<node_ptr>> dependent_nodes;
traverse_nodes(base_node, [&](Node* node) { traverse_nodes(f->get_result(), [&](Node* node) {
for (auto arg : node->get_arguments()) for (auto arg : node->get_arguments())
{ {
m_ss << add_attributes(arg.get()); m_ss << add_attributes(arg.get());
m_ss << add_attributes(node); m_ss << add_attributes(node);
m_ss << " " << arg->get_node_id() << " -> " << node->get_node_id(); m_ss << " " << arg->get_name() << " -> " << node->get_name();
m_ss << ";\n"; m_ss << ";\n";
} }
}); });
}
render(); render();
...@@ -60,11 +65,11 @@ std::string pass::VisualizeTree::get_attributes(const Node* node) ...@@ -60,11 +65,11 @@ std::string pass::VisualizeTree::get_attributes(const Node* node)
stringstream ss; stringstream ss;
if (node->is_parameter()) if (node->is_parameter())
{ {
ss << " " << node->get_node_id() << " [shape=box color=blue]\n"; ss << " " << node->get_name() << " [shape=box color=blue]\n";
} }
else else
{ {
ss << " " << node->get_node_id() << " [shape=ellipse color=black]\n"; ss << " " << node->get_name() << " [shape=ellipse color=black]\n";
} }
return ss.str(); return ss.str();
} }
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
#include <sstream> #include <sstream>
#include <string> #include <string>
#include "ngraph/pass/tree_pass.hpp" #include "ngraph/pass/pass.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -29,11 +29,11 @@ namespace ngraph ...@@ -29,11 +29,11 @@ namespace ngraph
class Node; class Node;
} }
class ngraph::pass::VisualizeTree : public TreeBase class ngraph::pass::VisualizeTree : public ModulePass
{ {
public: public:
VisualizeTree(const std::string& file_name); VisualizeTree(const std::string& file_name);
bool run_on_tree(std::shared_ptr<Node>) override; bool run_on_module(std::vector<ngraph::Function*>&) override;
private: private:
std::string add_attributes(const Node* node); std::string add_attributes(const Node* node);
......
...@@ -659,7 +659,7 @@ void ExternalFunction::compile(FunctionMap& function_map) ...@@ -659,7 +659,7 @@ void ExternalFunction::compile(FunctionMap& function_map)
// Turn this into a pass // Turn this into a pass
// Assign layouts // Assign layouts
// For now, just make everyone row-major. // For now, just make everyone row-major.
for (const Node* node : pass_manager.get_call_graph()) for (const Node* node : m_function->get_ordered_ops())
{ {
for (const descriptor::Output& output : node->get_outputs()) for (const descriptor::Output& output : node->get_outputs())
{ {
...@@ -696,7 +696,7 @@ void ExternalFunction::compile(FunctionMap& function_map) ...@@ -696,7 +696,7 @@ void ExternalFunction::compile(FunctionMap& function_map)
m_n_outputs = tensor_index.size() - m_n_inputs; m_n_outputs = tensor_index.size() - m_n_inputs;
// All remaining tensor views // All remaining tensor views
for (const Node* node : pass_manager.get_call_graph()) for (const Node* node : m_function->get_ordered_ops())
{ {
for (const descriptor::Output& output : node->get_outputs()) for (const descriptor::Output& output : node->get_outputs())
{ {
...@@ -712,7 +712,7 @@ void ExternalFunction::compile(FunctionMap& function_map) ...@@ -712,7 +712,7 @@ void ExternalFunction::compile(FunctionMap& function_map)
// Now we build the eigen-VM instructions // Now we build the eigen-VM instructions
auto op_map = get_op_map(); auto op_map = get_op_map();
for (const Node* node : pass_manager.get_call_graph()) for (const Node* node : m_function->get_ordered_ops())
{ {
auto handler_it = op_map.find(type_index(typeid(*node))); auto handler_it = op_map.find(type_index(typeid(*node)));
if (handler_it == op_map.end()) if (handler_it == op_map.end())
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <cstdio>
#include <fstream>
#include <list>
#include "ngraph/node.hpp"
#include "ngraph/util.hpp"
#include "ngraph/visualize.hpp"
using namespace ngraph;
using namespace std;
Visualize::Visualize(const string& name)
: m_name{name}
{
}
void Visualize::add(node_ptr p)
{
// map<size_t, list<node_ptr>> dependent_nodes;
traverse_nodes(p, [&](Node* node) {
for (auto arg : node->get_arguments())
{
m_ss << add_attributes(arg.get());
m_ss << add_attributes(node);
m_ss << " " << arg->get_node_id() << " -> " << node->get_node_id();
m_ss << ";\n";
}
});
}
std::string Visualize::add_attributes(const Node* node)
{
string rc;
if (!contains(m_nodes_with_attributes, node))
{
m_nodes_with_attributes.insert(node);
rc = get_attributes(node);
}
return rc;
}
std::string Visualize::get_attributes(const Node* node)
{
stringstream ss;
if (node->is_parameter())
{
ss << " " << node->get_node_id() << " [shape=box color=blue]\n";
}
else
{
ss << " " << node->get_node_id() << " [shape=ellipse color=black]\n";
}
return ss.str();
}
void Visualize::save_dot(const string& path) const
{
#ifdef GRAPHVIZ_FOUND
auto tmp_file = path + ".tmp";
ofstream out(tmp_file);
if (out)
{
out << "digraph " << m_name << "\n{\n";
out << m_ss.str();
out << "}\n";
out.close();
stringstream ss;
ss << "dot -Tpng " << tmp_file << " -o " << path;
auto cmd = ss.str();
auto stream = popen(cmd.c_str(), "r");
pclose(stream);
remove(tmp_file.c_str());
}
#else
#endif
}
...@@ -51,7 +51,7 @@ TEST(pass, liveness) ...@@ -51,7 +51,7 @@ TEST(pass, liveness)
shared_ptr<Function> func = make_test_graph(); shared_ptr<Function> func = make_test_graph();
pass_manager.run_passes(func.get()); pass_manager.run_passes(func.get());
auto sorted = pass_manager.get_call_graph(); auto sorted = func->get_ordered_ops();
// for (const Node* node : sorted) // for (const Node* node : sorted)
// { // {
......
...@@ -40,15 +40,28 @@ TEST(pass_manager, add) ...@@ -40,15 +40,28 @@ TEST(pass_manager, add)
auto graph = make_test_graph(); auto graph = make_test_graph();
size_t node_count = get_node_count(graph->get_result()); size_t node_count = get_node_count(graph->get_result());
pass_manager.run_passes(graph.get()); pass_manager.run_passes(graph.get());
auto sorted = pass_manager.get_call_graph(); auto sorted = graph->get_ordered_ops();
EXPECT_EQ(node_count, sorted.size()); EXPECT_EQ(node_count, sorted.size());
EXPECT_TRUE(validate_list(sorted)); EXPECT_TRUE(validate_list(sorted));
} }
TEST(pass_manager, dependency) TEST(pass_manager, module_add_function)
{ {
pass::Manager pass_manager; // First create "f(A,B,C) = (A+B)*C".
auto shape = Shape{2, 2};
auto A = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto B = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto C = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto rt_f = make_shared<TensorViewType>(element::Float32::element_type(), shape);
auto f = make_shared<Function>((A + B) * C, rt_f, op::Parameters{A, B, C});
pass_manager.register_pass<pass::TopologicalSort>(); // Now make "g(X,Y,Z) = f(X,Y,Z) + f(X,Y,Z)"
EXPECT_THROW(pass_manager.register_pass<pass::AssignTensors>(), runtime_error); auto X = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto Y = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto Z = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto rt_g = make_shared<TensorViewType>(element::Float32::element_type(), shape);
auto g = make_shared<Function>(make_shared<op::FunctionCall>(f, Nodes{X, Y, Z}) +
make_shared<op::FunctionCall>(f, Nodes{X, Y, Z}),
rt_g,
op::Parameters{X, Y, Z});
} }
...@@ -218,7 +218,7 @@ TEST(memory_layout, basic) ...@@ -218,7 +218,7 @@ TEST(memory_layout, basic)
auto graph = make_test_graph(); auto graph = make_test_graph();
pass_manager.run_passes(graph); pass_manager.run_passes(graph);
auto sorted = pass_manager.get_call_graph(); auto sorted = graph->get_ordered_ops();
size_t temporary_pool_size = pass_manager.get_state().get_temporary_pool_size(); size_t temporary_pool_size = pass_manager.get_state().get_temporary_pool_size();
EXPECT_EQ(12, temporary_pool_size); EXPECT_EQ(12, temporary_pool_size);
} }
...@@ -21,10 +21,10 @@ ...@@ -21,10 +21,10 @@
#include "ngraph/log.hpp" #include "ngraph/log.hpp"
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
#include "ngraph/pass/collect_functions.hpp"
#include "ngraph/pass/manager.hpp" #include "ngraph/pass/manager.hpp"
#include "ngraph/pass/topological_sort.hpp" #include "ngraph/pass/topological_sort.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
#include "ngraph/visualize.hpp"
#include "test_tools.hpp" #include "test_tools.hpp"
using namespace std; using namespace std;
...@@ -69,7 +69,7 @@ TEST(topological_sort, basic) ...@@ -69,7 +69,7 @@ TEST(topological_sort, basic)
pass::Manager pass_manager; pass::Manager pass_manager;
pass_manager.register_pass<pass::TopologicalSort>(); pass_manager.register_pass<pass::TopologicalSort>();
pass_manager.run_passes(f0); pass_manager.run_passes(f0);
auto sorted_list = pass_manager.get_call_graph(); auto sorted_list = f0->get_ordered_ops();
size_t node_count = get_node_count(r0); size_t node_count = get_node_count(r0);
...@@ -121,7 +121,7 @@ TEST(benchmark, topological_sort) ...@@ -121,7 +121,7 @@ TEST(benchmark, topological_sort)
pass::Manager pass_manager; pass::Manager pass_manager;
pass_manager.register_pass<pass::TopologicalSort>(); pass_manager.register_pass<pass::TopologicalSort>();
pass_manager.run_passes(f0); pass_manager.run_passes(f0);
auto sorted_list = pass_manager.get_call_graph(); auto sorted_list = f0->get_ordered_ops();
timer.stop(); timer.stop();
NGRAPH_INFO << "topological sort took " << timer.get_milliseconds() << "ms"; NGRAPH_INFO << "topological sort took " << timer.get_milliseconds() << "ms";
...@@ -135,3 +135,51 @@ TEST(benchmark, topological_sort) ...@@ -135,3 +135,51 @@ TEST(benchmark, topological_sort)
timer.stop(); timer.stop();
NGRAPH_INFO << "delete nodes took " << timer.get_milliseconds() << "ms"; NGRAPH_INFO << "delete nodes took " << timer.get_milliseconds() << "ms";
} }
TEST(topological_sort, collect_functions)
{
// First create "f(A,B,C) = (A+B)*C".
auto shape = Shape{2, 2};
auto A = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto B = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto C = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto rt_f = make_shared<TensorViewType>(element::Float32::element_type(), shape);
auto f = make_shared<Function>((A + B) * C, rt_f, op::Parameters{A, B, C}, "f");
// Now make "g(X,Y,Z) = f(X,Y,Z) + f(X,Y,Z)"
auto X = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto Y = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto Z = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto rt_g = make_shared<TensorViewType>(element::Float32::element_type(), shape);
auto g = make_shared<Function>(make_shared<op::FunctionCall>(f, Nodes{X, Y, Z}) +
make_shared<op::FunctionCall>(f, Nodes{X, Y, Z}),
rt_g,
op::Parameters{X, Y, Z},
"g");
// Now make "h(X,Y,Z) = g(X,Y,Z) + g(X,Y,Z)"
auto X1 = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto Y1 = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto Z1 = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto rt_h = make_shared<TensorViewType>(element::Float32::element_type(), shape);
auto h = make_shared<Function>(make_shared<op::FunctionCall>(g, Nodes{X1, Y1, Z1}) +
make_shared<op::FunctionCall>(g, Nodes{X1, Y1, Z1}),
rt_h,
op::Parameters{X1, Y1, Z1},
"h");
pass::Manager pass_manager;
pass_manager.register_pass<pass::CollectFunctions>();
pass_manager.run_passes(h);
set<string> expected = {"f", "g", "h"};
auto functions = pass_manager.get_state().get_functions();
vector<string> fnames;
for (Function* func : functions)
{
fnames.push_back(func->get_name());
}
EXPECT_EQ(expected.size(), functions.size());
EXPECT_TRUE(contains(fnames, "f"));
EXPECT_TRUE(contains(fnames, "g"));
EXPECT_TRUE(contains(fnames, "h"));
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment