Commit 118e0679 authored by Jai Menon's avatar Jai Menon Committed by GitHub

Merge branch 'master' into jmenon/codegen

parents c2ff1508 c3dfdf5a
......@@ -39,16 +39,16 @@ set (SRC
ops/unary_elementwise_arithmetic.cpp
ops/unary_elementwise_builtin.cpp
pass/assign_tensors.cpp
pass/call_pass.cpp
pass/collect_functions.cpp
pass/dump_sorted.cpp
pass/liveness.cpp
pass/manager.cpp
pass/manager_state.cpp
pass/memory_layout.cpp
pass/memory_visualize.cpp
pass/pass.cpp
pass/propagate_types.cpp
pass/topological_sort.cpp
pass/tree_pass.cpp
pass/visualize_tree.cpp
runtime/call_frame.cpp
runtime/external_function.cpp
......@@ -58,7 +58,6 @@ set (SRC
types/element_type.cpp
types/type.cpp
util.cpp
visualize.cpp
)
# find_program (GRAPHVIZ dot)
......
......@@ -15,21 +15,95 @@
#include <memory>
#include "ngraph/function.hpp"
#include "ngraph/util.hpp"
using namespace std;
using namespace ngraph;
size_t Function::m_next_instance_id = 0;
Function::Function(const std::shared_ptr<Node>& result,
const std::shared_ptr<ValueType>& result_type,
const std::vector<std::shared_ptr<op::Parameter>>& parameters)
const std::vector<std::shared_ptr<op::Parameter>>& parameters,
const std::string& name)
: m_result(result)
, m_parameters(parameters)
, m_name("Function")
, m_name(name)
, m_result_type(result_type)
, m_ordered_ops_valid(false)
, m_instance_id(m_next_instance_id++)
{
size_t i = 0;
for (auto parameter : parameters)
{
parameter->assign_function(this, i++);
}
traverse_nodes(result, [&](Node* node) { m_ops.push_back(node); });
}
void Function::set_ordered_ops(const std::list<Node*>& ordered_ops)
{
m_ordered_ops = ordered_ops;
m_ordered_ops_valid = true;
}
std::list<Node*>& Function::get_ops()
{
return m_ops;
}
const std::list<Node*>& Function::get_ops() const
{
return m_ops;
}
std::list<Node*>& Function::get_ordered_ops()
{
if (!m_ordered_ops_valid)
{
throw ngraph_error("Access to ordered ops invalid");
}
return m_ordered_ops;
}
const std::list<Node*>& Function::get_ordered_ops() const
{
if (!m_ordered_ops_valid)
{
throw ngraph_error("Access to ordered ops invalid");
}
return m_ordered_ops;
}
std::string Function::get_name() const
{
string rc;
if (m_name.empty())
{
rc = "Function_" + to_string(m_instance_id);
}
else
{
rc = m_name;
}
return rc;
}
void Function::set_name(const string& name)
{
if (m_name.empty())
{
m_name = name;
}
else
{
throw ngraph_error("Function name may be set exactly once");
}
}
std::ostream& ngraph::operator<<(std::ostream& out, const Function& f)
{
out << "Function(" << f.get_name() << ")";
return out;
}
......@@ -15,11 +15,13 @@
#pragma once
#include <initializer_list>
#include <list>
#include <memory>
#include <string>
#include <vector>
#include "ngraph/descriptor/tensor_view.hpp"
#include "ngraph/log.hpp"
#include "ngraph/node.hpp"
#include "ngraph/ops/op.hpp"
#include "ngraph/ops/parameter.hpp"
......@@ -34,7 +36,8 @@ namespace ngraph
public:
Function(const std::shared_ptr<Node>& result,
const std::shared_ptr<ValueType>& result_type,
const std::vector<std::shared_ptr<op::Parameter>>& parameters);
const std::vector<std::shared_ptr<op::Parameter>>& parameters,
const std::string& name = "");
std::shared_ptr<Node> get_result() { return m_result; }
const std::vector<std::shared_ptr<op::Parameter>> get_parameters() const
......@@ -42,11 +45,31 @@ namespace ngraph
return m_parameters;
}
const std::shared_ptr<ValueType> get_result_type() const { return m_result_type; }
std::string get_name() const { return m_name; }
std::string get_name() const;
void set_name(const std::string& name);
std::list<Node*>& get_ops();
const std::list<Node*>& get_ops() const;
std::list<Node*>& get_ordered_ops();
const std::list<Node*>& get_ordered_ops() const;
void set_ordered_ops(const std::list<Node*>&);
void set_ordered_ops_valid() { m_ordered_ops_valid = true; }
void clear_ordered_ops_valid() { m_ordered_ops_valid = false; }
friend std::ostream& operator<<(std::ostream&, const Function&);
protected:
std::shared_ptr<Node> m_result;
std::vector<std::shared_ptr<ngraph::op::Parameter>> m_parameters;
std::string m_name;
std::shared_ptr<ValueType> m_result_type;
bool m_ordered_ops_valid;
std::list<Node*> m_ordered_ops;
std::list<Node*> m_ops;
private:
Function(const Function&) = delete;
Function(const Function&&) = delete;
static size_t m_next_instance_id;
size_t m_instance_id;
};
}
......@@ -104,6 +104,32 @@ std::string Node::get_node_id() const
return ss.str();
}
std::string Node::get_name() const
{
string rc;
if (m_name.empty())
{
rc = description() + "_" + to_string(m_instance_id);
}
else
{
rc = m_name;
}
return rc;
}
void Node::set_name(const string& name)
{
if (m_name.empty())
{
m_name = name;
}
else
{
throw ngraph_error("Node name may be set exactly once");
}
}
namespace ngraph
{
ostream& operator<<(ostream& out, const Node& node)
......
......@@ -55,6 +55,8 @@ namespace ngraph
public:
/// The class name, must not contain spaces
virtual std::string description() const = 0;
std::string get_name() const;
void set_name(const std::string& name);
/// Propagate types and check arguments for consistency
virtual void propagate_types() = 0;
......
......@@ -25,15 +25,15 @@
using namespace std;
using namespace ngraph;
bool pass::AssignTensors::run_on_call_list(std::list<Node*>& node_list)
bool pass::AssignTensors::run_on_call_graph(list<Node*>& nodes)
{
for (Node* node : node_list)
for (Node* node : nodes)
{
try
{
// We need to set the nodes is_output state prior to call assign_tensors
// so that the output state can be passes to the constructed tensors.
if (node == get_state().get_function()->get_result().get())
if (node == get_state().get_functions().at(0)->get_result().get())
{
node->set_is_output();
}
......@@ -50,21 +50,3 @@ bool pass::AssignTensors::run_on_call_list(std::list<Node*>& node_list)
}
return false;
}
void pass::AssignTensors::check_dependencies(
const std::vector<std::shared_ptr<CallBase>>& registered_passes) const
{
bool found_propagate_types = false;
for (auto pass : registered_passes)
{
if (dynamic_pointer_cast<PropagateTypes>(pass))
{
found_propagate_types = true;
}
}
if (!found_propagate_types)
{
throw runtime_error("Dependency 'PropagateTypes' not found for pass 'AssignTensors'");
}
}
......@@ -14,7 +14,7 @@
#pragma once
#include "ngraph/pass/call_pass.hpp"
#include "ngraph/pass/pass.hpp"
namespace ngraph
{
......@@ -25,12 +25,10 @@ namespace ngraph
class Node;
}
class ngraph::pass::AssignTensors : public CallBase
class ngraph::pass::AssignTensors : public CallGraphPass
{
public:
virtual bool run_on_call_list(std::list<Node*>&) override;
void check_dependencies(const std::vector<std::shared_ptr<CallBase>>&) const override;
virtual bool run_on_call_graph(std::list<Node*>& nodes) override;
private:
};
......@@ -12,34 +12,39 @@
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <functional>
#include <memory>
#include <set>
#include <sstream>
namespace ngraph
#include "ngraph/pass/collect_functions.hpp"
#include "ngraph/function.hpp"
#include "ngraph/log.hpp"
#include "ngraph/node.hpp"
#include "ngraph/ops/function_call.hpp"
#include "ngraph/ops/op.hpp"
#include "ngraph/util.hpp"
using namespace std;
using namespace ngraph;
using namespace ngraph::pass;
bool CollectFunctions::run_on_function(ngraph::Function* func)
{
class Visualize;
class Node;
using node_ptr = std::shared_ptr<Node>;
set<Function*> functions;
deque<Function*> stack;
stack.push_back(func);
while (stack.empty() == false)
{
Function* f = stack.front();
stack.pop_front();
functions.insert(f);
traverse_nodes(f->get_result(), [&](Node* node) {
op::FunctionCall* fc = dynamic_cast<op::FunctionCall*>(node);
if (fc)
{
stack.push_back(fc->get_function().get());
}
});
}
get_state().set_functions(functions);
return false;
}
class ngraph::Visualize
{
public:
Visualize(const std::string& name = "ngraph");
void add(node_ptr);
void save_dot(const std::string& path) const;
private:
std::string add_attributes(const Node* node);
std::string get_attributes(const Node* node);
std::stringstream m_ss;
std::string m_name;
std::set<const Node*> m_nodes_with_attributes;
};
......@@ -12,4 +12,22 @@
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "ngraph/pass/call_pass.hpp"
#pragma once
#include "ngraph/pass/pass.hpp"
namespace ngraph
{
namespace pass
{
class CollectFunctions;
}
}
class ngraph::pass::CollectFunctions : public FunctionPass
{
public:
bool run_on_function(ngraph::Function*) override;
private:
};
......@@ -27,41 +27,44 @@ pass::DumpSorted::DumpSorted(const string& output_file)
{
}
bool pass::DumpSorted::run_on_call_list(list<Node*>& nodes)
bool pass::DumpSorted::run_on_module(vector<Function*>& functions)
{
ofstream out{m_output_file};
if (out)
{
for (const Node* node : nodes)
for (Function* f : functions)
{
out << node->get_node_id() << "(";
vector<string> inputs;
for (const Input& input : node->get_inputs())
for (const Node* node : f->get_ordered_ops())
{
inputs.push_back(input.get_tensor().get_name());
}
out << join(inputs);
out << ") -> ";
out << node->get_name() << "(";
vector<string> inputs;
for (const Input& input : node->get_inputs())
{
inputs.push_back(input.get_tensor().get_name());
}
out << join(inputs);
out << ") -> ";
vector<string> outputs;
for (const Output& output : node->get_outputs())
{
outputs.push_back(output.get_tensor().get_name());
}
out << join(outputs);
out << "\n";
vector<string> outputs;
for (const Output& output : node->get_outputs())
{
outputs.push_back(output.get_tensor().get_name());
}
out << join(outputs);
out << "\n";
for (const Tensor* tensor : node->liveness_live_list)
{
out << " L " << tensor->get_name() << "\n";
}
for (const Tensor* tensor : node->liveness_new_list)
{
out << " N " << tensor->get_name() << "\n";
}
for (const Tensor* tensor : node->liveness_free_list)
{
out << " F " << tensor->get_name() << "\n";
for (const Tensor* tensor : node->liveness_live_list)
{
out << " L " << tensor->get_name() << "\n";
}
for (const Tensor* tensor : node->liveness_new_list)
{
out << " N " << tensor->get_name() << "\n";
}
for (const Tensor* tensor : node->liveness_free_list)
{
out << " F " << tensor->get_name() << "\n";
}
}
}
}
......
......@@ -16,7 +16,7 @@
#include <string>
#include "ngraph/pass/call_pass.hpp"
#include "ngraph/pass/pass.hpp"
namespace ngraph
{
......@@ -27,12 +27,12 @@ namespace ngraph
class Node;
}
class ngraph::pass::DumpSorted : public CallBase
class ngraph::pass::DumpSorted : public ModulePass
{
public:
DumpSorted(const std::string& output_file);
virtual bool run_on_call_list(std::list<Node*>&) override;
virtual bool run_on_module(std::vector<Function*>&) override;
private:
const std::string m_output_file;
......
......@@ -27,7 +27,7 @@ using namespace std;
using namespace ngraph;
using namespace ngraph::descriptor;
bool pass::Liveness::run_on_call_list(list<Node*>& ops)
bool pass::Liveness::run_on_call_graph(list<Node*>& ops)
{
unordered_set<Tensor*> currently_live;
......@@ -123,24 +123,6 @@ bool pass::Liveness::run_on_call_list(list<Node*>& ops)
return false;
}
void pass::Liveness::check_dependencies(
const std::vector<std::shared_ptr<CallBase>>& registered_passes) const
{
bool found_propagate_types = false;
for (auto pass : registered_passes)
{
if (dynamic_pointer_cast<AssignTensors>(pass))
{
found_propagate_types = true;
}
}
if (!found_propagate_types)
{
throw runtime_error("Dependency 'PropagateTypes' not found for pass 'AssignTensors'");
}
}
bool pass::Liveness::is_temporary(const Tensor& tensor)
{
return tensor.is_persistent() == false && tensor.is_input() == false &&
......
......@@ -15,7 +15,7 @@
#pragma once
#include "ngraph/descriptor/tensor.hpp"
#include "ngraph/pass/call_pass.hpp"
#include "ngraph/pass/pass.hpp"
namespace ngraph
{
......@@ -26,12 +26,10 @@ namespace ngraph
class Node;
}
class ngraph::pass::Liveness : public CallBase
class ngraph::pass::Liveness : public CallGraphPass
{
public:
virtual bool run_on_call_list(std::list<Node*>&) override;
void check_dependencies(const std::vector<std::shared_ptr<CallBase>>&) const override;
virtual bool run_on_call_graph(std::list<Node*>&) override;
private:
bool is_temporary(const descriptor::Tensor&);
......
......@@ -19,40 +19,11 @@
#include "ngraph/log.hpp"
#include "ngraph/node.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/pass.hpp"
using namespace std;
using namespace ngraph;
Function* ngraph::pass::ManagerState::get_function()
{
return m_function;
}
void ngraph::pass::ManagerState::set_function(Function* func)
{
m_function = func;
}
size_t ngraph::pass::ManagerState::get_temporary_pool_size()
{
return m_temporary_pool_size;
}
void ngraph::pass::ManagerState::set_temporary_pool_size(size_t size)
{
m_temporary_pool_size = size;
}
std::list<Node*>& ngraph::pass::ManagerState::get_call_graph()
{
return m_call_graph;
}
const std::list<Node*>& ngraph::pass::ManagerState::get_call_graph() const
{
return m_call_graph;
}
ngraph::pass::Manager::Manager()
{
}
......@@ -65,26 +36,6 @@ void ngraph::pass::Manager::initialize_default_passes()
{
}
void ngraph::pass::Manager::register_pass_ptr(std::shared_ptr<TreeBase> p)
{
if (p == nullptr)
{
throw invalid_argument("null pass registered");
}
p->check_dependencies(m_tree_passes);
m_tree_passes.push_back(p);
}
void ngraph::pass::Manager::register_pass_ptr(std::shared_ptr<CallBase> p)
{
if (p == nullptr)
{
throw invalid_argument("null pass registered");
}
p->check_dependencies(m_call_passes);
m_call_passes.push_back(p);
}
void ngraph::pass::Manager::run_passes(shared_ptr<Function> func)
{
run_passes(func.get());
......@@ -92,23 +43,79 @@ void ngraph::pass::Manager::run_passes(shared_ptr<Function> func)
void ngraph::pass::Manager::run_passes(Function* func)
{
m_state.set_function(func);
for (shared_ptr<TreeBase> p : m_tree_passes)
{
p->set_state(get_state());
p->run_on_tree(func->get_result());
}
vector<Function*> fs = {func};
get_state().set_functions(fs);
for (shared_ptr<CallBase>& p : m_call_passes)
for (shared_ptr<PassBase> pass : m_pass_list)
{
p->set_state(get_state());
p->run_on_call_list(get_state().get_call_graph());
pass->set_state(get_state());
auto module_pass = dynamic_pointer_cast<ModulePass>(pass);
auto function_pass = dynamic_pointer_cast<FunctionPass>(pass);
auto node_pass = dynamic_pointer_cast<NodePass>(pass);
auto call_graph_pass = dynamic_pointer_cast<CallGraphPass>(pass);
if (module_pass)
{
module_pass->run_on_module(fs);
}
else if (function_pass)
{
for (Function* f : fs)
{
function_pass->run_on_function(f);
}
}
else if (node_pass)
{
for (Function* f : fs)
{
for (Node* n : f->get_ops())
{
node_pass->run_on_node(n);
}
}
}
else if (call_graph_pass)
{
for (Function* f : fs)
{
call_graph_pass->run_on_call_graph(f->get_ordered_ops());
}
}
}
}
const std::list<ngraph::Node*>& ngraph::pass::Manager::get_call_graph() const
{
return m_state.get_call_graph();
// for (shared_ptr<ModulePass>& p : m_module_passes)
// {
// p->set_state(get_state());
// p->run_on_module(fs);
// }
// for (Function* f : fs)
// {
// for (shared_ptr<FunctionPass> p : m_function_passes)
// {
// p->set_state(get_state());
// p->run_on_function(f);
// }
// }
// for (Function* f : fs)
// {
// NGRAPH_INFO;
// for (shared_ptr<NodePass> p : m_node_passes)
// {
// for (Node* node : f->get_ops())
// {
// NGRAPH_INFO;
// p->set_state(get_state());
// p->run_on_node(node);
// }
// }
// }
// for (shared_ptr<CallGraphPass>& p : m_call_graph_passes)
// {
// p->set_state(get_state());
// p->run_on_call_graph(func->get_ordered_ops());
// }
}
ngraph::pass::ManagerState& ngraph::pass::Manager::get_state()
......
......@@ -18,8 +18,8 @@
#include <memory>
#include <vector>
#include "ngraph/pass/call_pass.hpp"
#include "ngraph/pass/tree_pass.hpp"
#include "ngraph/pass/manager_state.hpp"
#include "ngraph/pass/pass.hpp"
namespace ngraph
{
......@@ -33,24 +33,6 @@ namespace ngraph
class Function;
}
class ngraph::pass::ManagerState
{
public:
Function* get_function();
void set_function(Function*);
size_t get_temporary_pool_size();
void set_temporary_pool_size(size_t);
std::list<Node*>& get_call_graph();
const std::list<Node*>& get_call_graph() const;
private:
Function* m_function = nullptr;
size_t m_temporary_pool_size = 0;
std::list<Node*> m_call_graph;
};
class ngraph::pass::Manager
{
public:
......@@ -62,29 +44,18 @@ public:
template <typename T, class... Args>
void register_pass(Args... args)
{
static_assert(std::is_base_of<pass::Base, T>::value, "pass not derived from pass base");
if (std::is_base_of<TreeBase, T>::value)
{
register_pass_ptr(std::make_shared<T>(args...));
}
else if (std::is_base_of<CallBase, T>::value)
{
register_pass_ptr(std::make_shared<T>(args...));
}
static_assert(std::is_base_of<pass::PassBase, T>::value, "pass not derived from pass base");
auto pass = std::make_shared<T>(args...);
auto pass_base = std::static_pointer_cast<PassBase>(pass);
m_pass_list.push_back(pass_base);
}
void run_passes(Function*);
void run_passes(std::shared_ptr<Function>);
const std::list<Node*>& get_call_graph() const;
ManagerState& get_state();
private:
void register_pass_ptr(std::shared_ptr<TreeBase>);
void register_pass_ptr(std::shared_ptr<CallBase>);
std::vector<std::shared_ptr<TreeBase>> m_tree_passes;
std::vector<std::shared_ptr<CallBase>> m_call_passes;
std::vector<std::shared_ptr<PassBase>> m_pass_list;
ManagerState m_state;
};
......@@ -12,31 +12,28 @@
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <list>
#include <iostream>
#include <memory>
#include <vector>
#include "ngraph/pass/pass.hpp"
#include "ngraph/function.hpp"
#include "ngraph/log.hpp"
#include "ngraph/node.hpp"
#include "ngraph/pass/manager_state.hpp"
namespace ngraph
{
namespace pass
{
class TreeBase;
}
using namespace std;
using namespace ngraph;
class Node;
vector<Function*>& ngraph::pass::ManagerState::get_functions()
{
return m_function_list;
}
class ngraph::pass::TreeBase : public Base
size_t ngraph::pass::ManagerState::get_temporary_pool_size()
{
public:
virtual ~TreeBase() {}
// return true if changes were made to the tree
virtual bool run_on_tree(std::shared_ptr<Node>) = 0;
return m_temporary_pool_size;
}
// derived class throws exception if its dependencies have not been met
virtual void check_dependencies(const std::vector<std::shared_ptr<TreeBase>>&) const {}
};
void ngraph::pass::ManagerState::set_temporary_pool_size(size_t size)
{
m_temporary_pool_size = size;
}
......@@ -14,29 +14,36 @@
#pragma once
#include <list>
#include <memory>
#include <vector>
#include "ngraph/pass/pass.hpp"
namespace ngraph
{
namespace pass
{
class CallBase;
class ManagerState;
}
class Node;
class Function;
}
class ngraph::pass::CallBase : public Base
class ngraph::pass::ManagerState
{
public:
virtual ~CallBase() {}
virtual bool run_on_call_list(std::list<Node*>&) = 0;
std::vector<Function*>& get_functions();
template <typename T>
void set_functions(const T& collection)
{
m_function_list.clear();
m_function_list.insert(m_function_list.begin(), collection.begin(), collection.end());
}
size_t get_temporary_pool_size();
void set_temporary_pool_size(size_t);
// derived class throws exception if its dependencies have not been met
virtual void check_dependencies(const std::vector<std::shared_ptr<CallBase>>&) const {}
private:
size_t m_temporary_pool_size = 0;
std::vector<Function*> m_function_list;
};
......@@ -27,7 +27,7 @@ using namespace std;
using namespace ngraph;
using namespace ngraph::descriptor;
bool pass::MemoryLayout::run_on_call_list(std::list<Node*>& node_list)
bool pass::MemoryLayout::run_on_call_graph(std::list<Node*>& node_list)
{
MemoryManager mm;
for (const Node* node : node_list)
......@@ -47,24 +47,6 @@ bool pass::MemoryLayout::run_on_call_list(std::list<Node*>& node_list)
return false;
}
void pass::MemoryLayout::check_dependencies(
const std::vector<std::shared_ptr<CallBase>>& registered_passes) const
{
bool found_propagate_types = false;
for (auto pass : registered_passes)
{
if (dynamic_pointer_cast<Liveness>(pass))
{
found_propagate_types = true;
}
}
if (!found_propagate_types)
{
throw runtime_error("Dependency 'PropagateTypes' not found for pass 'AssignTensors'");
}
}
pass::MemoryManager::node::node(size_t size, block_state state)
: m_size{size}
, m_state{state}
......
......@@ -18,7 +18,7 @@
#include <list>
#include <sstream>
#include "ngraph/pass/call_pass.hpp"
#include "ngraph/pass/pass.hpp"
namespace ngraph
{
......@@ -31,12 +31,10 @@ namespace ngraph
class Node;
}
class ngraph::pass::MemoryLayout : public CallBase
class ngraph::pass::MemoryLayout : public CallGraphPass
{
public:
virtual bool run_on_call_list(std::list<Node*>&) override;
void check_dependencies(const std::vector<std::shared_ptr<CallBase>>&) const override;
virtual bool run_on_call_graph(std::list<Node*>&) override;
private:
};
......
......@@ -19,6 +19,7 @@
#include "memory_visualize.hpp"
#include "ngraph/descriptor/tensor.hpp"
#include "ngraph/function.hpp"
#include "ngraph/node.hpp"
#include "ngraph/util.hpp"
......@@ -31,71 +32,70 @@ pass::MemoryVisualize::MemoryVisualize(const string& filename)
{
}
bool pass::MemoryVisualize::run_on_call_list(list<Node*>& _nodes)
bool pass::MemoryVisualize::run_on_module(vector<Function*>& functions)
{
const list<Node*> nodes = _nodes;
ofstream file(m_filename);
{
file << "<!DOCTYPE html>\n<html>\n";
file << "<head>\n";
file << " <style>\n";
file << " th, td {\n";
file << " border-bottom: 1px solid #ddd;\n";
file << " width: 200px;\n";
file << " }\n";
file << " table, td, th {\n";
// file << " border: 1px solid #ddd;\n";
// file << " text-align: left;\n";
file << " }\n";
file << " table {\n";
file << " border-collapse: collapse;\n";
// file << " width: 100%;\n";
file << " }\n";
// file << " tr:hover {background-color: #f5f5f5}\n";
file << " tr:nth-child(even) {background-color: #f2f2f2}\n";
file << " </style>\n";
file << "</head>\n";
file << "<body>\n";
unordered_set<descriptor::Tensor*> tensors;
size_t temp_max_size = 0;
for (Node* node : nodes)
{
tensors.insert(node->liveness_live_list.begin(), node->liveness_live_list.end());
}
for (descriptor::Tensor* tensor : tensors)
for (const Function* f : functions)
{
if (tensor->is_persistent() == false)
const list<Node*> nodes = f->get_ordered_ops();
file << "<!DOCTYPE html>\n<html>\n";
file << "<head>\n";
file << " <style>\n";
file << " th, td {\n";
file << " border-bottom: 1px solid #ddd;\n";
file << " width: 200px;\n";
file << " }\n";
file << " table, td, th {\n";
// file << " border: 1px solid #ddd;\n";
// file << " text-align: left;\n";
file << " }\n";
file << " table {\n";
file << " border-collapse: collapse;\n";
// file << " width: 100%;\n";
file << " }\n";
// file << " tr:hover {background-color: #f5f5f5}\n";
file << " tr:nth-child(even) {background-color: #f2f2f2}\n";
file << " </style>\n";
file << "</head>\n";
file << "<body>\n";
unordered_set<descriptor::Tensor*> tensors;
size_t temp_max_size = 0;
for (Node* node : nodes)
{
temp_max_size += tensor->size();
tensors.insert(node->liveness_live_list.begin(), node->liveness_live_list.end());
}
for (descriptor::Tensor* tensor : tensors)
{
if (tensor->is_persistent() == false)
{
temp_max_size += tensor->size();
}
}
}
// file << "<table>\n";
// file << "<tr><td>Persistent Memory Footprint</td><td align=\"right\">";
// file << computation_decl.exop_block.persistent_size() << "</td></tr>\n";
// file << "<tr><td>Temporary Memory Footprint</td><td align=\"right\">";
// file << computation_decl.exop_block.memory_footprint() << "</td></tr>\n";
// file << "<tr><td>Max temporary Memory Footprint</td><td align=\"right\">";
// file << temp_max_size << "</td></tr>\n";
// file << "</table>\n";
// file << "<table>\n";
// file << "<tr><td>Persistent Memory Footprint</td><td align=\"right\">";
// file << computation_decl.exop_block.persistent_size() << "</td></tr>\n";
// file << "<tr><td>Temporary Memory Footprint</td><td align=\"right\">";
// file << computation_decl.exop_block.memory_footprint() << "</td></tr>\n";
// file << "<tr><td>Max temporary Memory Footprint</td><td align=\"right\">";
// file << temp_max_size << "</td></tr>\n";
// file << "</table>\n";
file << "<hr>\n";
draw_tensor_weight(file, nodes);
// file << "<hr>\n";
// draw_op_influence(file);
file << "<hr>\n";
draw_histogram(file, nodes);
// file << "<hr>\n";
file << "</body>\n</html>\n";
file << "<hr>\n";
draw_tensor_weight(file, nodes);
// file << "<hr>\n";
// draw_op_influence(file);
file << "<hr>\n";
draw_histogram(file, nodes);
// file << "<hr>\n";
file << "</body>\n</html>\n";
}
}
return false;
}
void pass::MemoryVisualize::check_dependencies(const vector<shared_ptr<CallBase>>& deps) const
{
}
const Node* pass::MemoryVisualize::find_largest_op(const list<Node*>& nodes)
{
const Node* largest_op = nullptr;
......@@ -207,7 +207,7 @@ void pass::MemoryVisualize::draw_histogram(ostream& file, const list<Node*>& nod
size_t x2 = ((usage / memory_footprint) * scale) + offset;
file << "<text x=\"" << 0 << "\" y=\"" << y + text_offset << "\" fill=\""
<< "black"
<< "\">" << node->get_node_id() << "</text>\n";
<< "\">" << node->get_name() << "</text>\n";
file << "<line x1=\"" << x1 << "\" y1=\"" << y << "\" x2=\"" << x2 << "\" y2=\"" << y
<< "\"";
file << " style=\"stroke:forestgreen;stroke-width:" << stroke_width << "\" />\n";
......@@ -231,7 +231,7 @@ void pass::MemoryVisualize::draw_op_influence(ostream& file, const list<Node*>&
{
int weight = compute_op_weight(exop);
file << " <tr>";
file << "<td>" << exop->get_node_id() << "</td>";
file << "<td>" << exop->get_name() << "</td>";
file << "<td align=\"right\">" << weight << "</td>";
file << "</tr>\n";
}
......
......@@ -18,7 +18,7 @@
#include <limits>
#include <list>
#include "ngraph/pass/call_pass.hpp"
#include "ngraph/pass/pass.hpp"
namespace ngraph
{
......@@ -29,13 +29,11 @@ namespace ngraph
class Node;
}
class ngraph::pass::MemoryVisualize : public CallBase
class ngraph::pass::MemoryVisualize : public ModulePass
{
public:
MemoryVisualize(const std::string& filename);
virtual bool run_on_call_list(std::list<Node*>&) override;
void check_dependencies(const std::vector<std::shared_ptr<CallBase>>&) const override;
virtual bool run_on_module(std::vector<Function*>&) override;
private:
const Node* find_largest_op(const std::list<Node*>& nodes);
......
......@@ -15,12 +15,12 @@
#include "ngraph/pass/pass.hpp"
#include "ngraph/pass/manager.hpp"
ngraph::pass::ManagerState& ngraph::pass::Base::get_state()
ngraph::pass::ManagerState& ngraph::pass::PassBase::get_state()
{
return *m_state;
}
void ngraph::pass::Base::set_state(ManagerState& state)
void ngraph::pass::PassBase::set_state(ManagerState& state)
{
m_state = &state;
}
......@@ -14,21 +14,33 @@
#pragma once
#include <list>
#include <memory>
#include <vector>
#include "ngraph/node.hpp"
namespace ngraph
{
namespace pass
{
class Base;
class PassBase;
class ModulePass;
class FunctionPass;
class NodePass;
class CallGraphPass;
class Manager;
class ManagerState;
}
class Function;
}
class ngraph::pass::Base
class ngraph::pass::PassBase
{
friend class Manager;
public:
virtual ~PassBase() {}
protected:
ManagerState& get_state();
void set_state(ManagerState&);
......@@ -36,3 +48,31 @@ protected:
private:
ManagerState* m_state;
};
class ngraph::pass::ModulePass : public PassBase
{
public:
virtual ~ModulePass() {}
virtual bool run_on_module(std::vector<ngraph::Function*>&) = 0;
};
class ngraph::pass::FunctionPass : public PassBase
{
public:
virtual ~FunctionPass() {}
virtual bool run_on_function(ngraph::Function*) = 0;
};
class ngraph::pass::NodePass : public PassBase
{
public:
virtual ~NodePass() {}
virtual bool run_on_node(ngraph::Node*) = 0;
};
class ngraph::pass::CallGraphPass : public PassBase
{
public:
virtual ~CallGraphPass() {}
virtual bool run_on_call_graph(std::list<ngraph::Node*>&) = 0;
};
......@@ -20,9 +20,9 @@
using namespace std;
using namespace ngraph;
bool pass::PropagateTypes::run_on_call_list(std::list<Node*>& node_list)
bool pass::PropagateTypes::run_on_call_graph(list<Node*>& nodes)
{
for (Node* node : node_list)
for (Node* node : nodes)
{
try
{
......
......@@ -14,7 +14,7 @@
#pragma once
#include "ngraph/pass/call_pass.hpp"
#include "ngraph/pass/pass.hpp"
namespace ngraph
{
......@@ -25,10 +25,10 @@ namespace ngraph
class Node;
}
class ngraph::pass::PropagateTypes : public CallBase
class ngraph::pass::PropagateTypes : public CallGraphPass
{
public:
virtual bool run_on_call_list(std::list<Node*>&) override;
virtual bool run_on_call_graph(std::list<Node*>&) override;
private:
};
......@@ -15,6 +15,7 @@
#include <deque>
#include <unordered_map>
#include "ngraph/function.hpp"
#include "ngraph/log.hpp"
#include "ngraph/node.hpp"
#include "ngraph/pass/manager.hpp"
......@@ -24,14 +25,13 @@
using namespace ngraph;
using namespace std;
bool ngraph::pass::TopologicalSort::run_on_tree(std::shared_ptr<Node> p)
bool ngraph::pass::TopologicalSort::run_on_function(ngraph::Function* func)
{
list<Node*>& sorted_list = get_state().get_call_graph();
sorted_list.clear();
list<Node*> result_list;
deque<Node*> independent_nodes;
unordered_map<Node*, size_t> node_depencency_count;
traverse_nodes(p, [&](Node* node) {
traverse_nodes(func->get_result(), [&](Node* node) {
node_depencency_count[node] = node->get_arguments().size();
if (node->get_arguments().size() == 0)
{
......@@ -42,7 +42,7 @@ bool ngraph::pass::TopologicalSort::run_on_tree(std::shared_ptr<Node> p)
while (independent_nodes.size() > 0)
{
auto independent_node = independent_nodes.front();
sorted_list.push_back(independent_node);
result_list.push_back(independent_node);
independent_nodes.pop_front();
for (auto user : independent_node->users())
......@@ -56,5 +56,7 @@ bool ngraph::pass::TopologicalSort::run_on_tree(std::shared_ptr<Node> p)
}
}
func->set_ordered_ops(result_list);
return false;
}
......@@ -17,7 +17,7 @@
#include <list>
#include <memory>
#include "ngraph/pass/tree_pass.hpp"
#include "ngraph/pass/pass.hpp"
namespace ngraph
{
......@@ -28,9 +28,9 @@ namespace ngraph
class Node;
}
class ngraph::pass::TopologicalSort : public TreeBase
class ngraph::pass::TopologicalSort : public FunctionPass
{
public:
TopologicalSort() {}
bool run_on_tree(std::shared_ptr<Node>) override;
bool run_on_function(ngraph::Function*) override;
};
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "ngraph/pass/tree_pass.hpp"
......@@ -14,25 +14,30 @@
#include <fstream>
#include "ngraph/function.hpp"
#include "ngraph/node.hpp"
#include "ngraph/pass/pass.hpp"
#include "ngraph/pass/visualize_tree.hpp"
#include "ngraph/util.hpp"
using namespace ngraph;
using namespace std;
bool pass::VisualizeTree::run_on_tree(std::shared_ptr<Node> base_node)
bool pass::VisualizeTree::run_on_module(vector<ngraph::Function*>& functions)
{
// map<size_t, list<node_ptr>> dependent_nodes;
traverse_nodes(base_node, [&](Node* node) {
for (auto arg : node->get_arguments())
{
m_ss << add_attributes(arg.get());
m_ss << add_attributes(node);
m_ss << " " << arg->get_node_id() << " -> " << node->get_node_id();
m_ss << ";\n";
}
});
for (Function* f : functions)
{
// map<size_t, list<node_ptr>> dependent_nodes;
traverse_nodes(f->get_result(), [&](Node* node) {
for (auto arg : node->get_arguments())
{
m_ss << add_attributes(arg.get());
m_ss << add_attributes(node);
m_ss << " " << arg->get_name() << " -> " << node->get_name();
m_ss << ";\n";
}
});
}
render();
......@@ -60,11 +65,11 @@ std::string pass::VisualizeTree::get_attributes(const Node* node)
stringstream ss;
if (node->is_parameter())
{
ss << " " << node->get_node_id() << " [shape=box color=blue]\n";
ss << " " << node->get_name() << " [shape=box color=blue]\n";
}
else
{
ss << " " << node->get_node_id() << " [shape=ellipse color=black]\n";
ss << " " << node->get_name() << " [shape=ellipse color=black]\n";
}
return ss.str();
}
......
......@@ -18,7 +18,7 @@
#include <sstream>
#include <string>
#include "ngraph/pass/tree_pass.hpp"
#include "ngraph/pass/pass.hpp"
namespace ngraph
{
......@@ -29,11 +29,11 @@ namespace ngraph
class Node;
}
class ngraph::pass::VisualizeTree : public TreeBase
class ngraph::pass::VisualizeTree : public ModulePass
{
public:
VisualizeTree(const std::string& file_name);
bool run_on_tree(std::shared_ptr<Node>) override;
bool run_on_module(std::vector<ngraph::Function*>&) override;
private:
std::string add_attributes(const Node* node);
......
......@@ -659,7 +659,7 @@ void ExternalFunction::compile(FunctionMap& function_map)
// Turn this into a pass
// Assign layouts
// For now, just make everyone row-major.
for (const Node* node : pass_manager.get_call_graph())
for (const Node* node : m_function->get_ordered_ops())
{
for (const descriptor::Output& output : node->get_outputs())
{
......@@ -696,7 +696,7 @@ void ExternalFunction::compile(FunctionMap& function_map)
m_n_outputs = tensor_index.size() - m_n_inputs;
// All remaining tensor views
for (const Node* node : pass_manager.get_call_graph())
for (const Node* node : m_function->get_ordered_ops())
{
for (const descriptor::Output& output : node->get_outputs())
{
......@@ -712,7 +712,7 @@ void ExternalFunction::compile(FunctionMap& function_map)
// Now we build the eigen-VM instructions
auto op_map = get_op_map();
for (const Node* node : pass_manager.get_call_graph())
for (const Node* node : m_function->get_ordered_ops())
{
auto handler_it = op_map.find(type_index(typeid(*node)));
if (handler_it == op_map.end())
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <cstdio>
#include <fstream>
#include <list>
#include "ngraph/node.hpp"
#include "ngraph/util.hpp"
#include "ngraph/visualize.hpp"
using namespace ngraph;
using namespace std;
Visualize::Visualize(const string& name)
: m_name{name}
{
}
void Visualize::add(node_ptr p)
{
// map<size_t, list<node_ptr>> dependent_nodes;
traverse_nodes(p, [&](Node* node) {
for (auto arg : node->get_arguments())
{
m_ss << add_attributes(arg.get());
m_ss << add_attributes(node);
m_ss << " " << arg->get_node_id() << " -> " << node->get_node_id();
m_ss << ";\n";
}
});
}
std::string Visualize::add_attributes(const Node* node)
{
string rc;
if (!contains(m_nodes_with_attributes, node))
{
m_nodes_with_attributes.insert(node);
rc = get_attributes(node);
}
return rc;
}
std::string Visualize::get_attributes(const Node* node)
{
stringstream ss;
if (node->is_parameter())
{
ss << " " << node->get_node_id() << " [shape=box color=blue]\n";
}
else
{
ss << " " << node->get_node_id() << " [shape=ellipse color=black]\n";
}
return ss.str();
}
void Visualize::save_dot(const string& path) const
{
#ifdef GRAPHVIZ_FOUND
auto tmp_file = path + ".tmp";
ofstream out(tmp_file);
if (out)
{
out << "digraph " << m_name << "\n{\n";
out << m_ss.str();
out << "}\n";
out.close();
stringstream ss;
ss << "dot -Tpng " << tmp_file << " -o " << path;
auto cmd = ss.str();
auto stream = popen(cmd.c_str(), "r");
pclose(stream);
remove(tmp_file.c_str());
}
#else
#endif
}
......@@ -51,7 +51,7 @@ TEST(pass, liveness)
shared_ptr<Function> func = make_test_graph();
pass_manager.run_passes(func.get());
auto sorted = pass_manager.get_call_graph();
auto sorted = func->get_ordered_ops();
// for (const Node* node : sorted)
// {
......
......@@ -40,15 +40,28 @@ TEST(pass_manager, add)
auto graph = make_test_graph();
size_t node_count = get_node_count(graph->get_result());
pass_manager.run_passes(graph.get());
auto sorted = pass_manager.get_call_graph();
auto sorted = graph->get_ordered_ops();
EXPECT_EQ(node_count, sorted.size());
EXPECT_TRUE(validate_list(sorted));
}
TEST(pass_manager, dependency)
TEST(pass_manager, module_add_function)
{
pass::Manager pass_manager;
// First create "f(A,B,C) = (A+B)*C".
auto shape = Shape{2, 2};
auto A = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto B = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto C = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto rt_f = make_shared<TensorViewType>(element::Float32::element_type(), shape);
auto f = make_shared<Function>((A + B) * C, rt_f, op::Parameters{A, B, C});
pass_manager.register_pass<pass::TopologicalSort>();
EXPECT_THROW(pass_manager.register_pass<pass::AssignTensors>(), runtime_error);
// Now make "g(X,Y,Z) = f(X,Y,Z) + f(X,Y,Z)"
auto X = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto Y = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto Z = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto rt_g = make_shared<TensorViewType>(element::Float32::element_type(), shape);
auto g = make_shared<Function>(make_shared<op::FunctionCall>(f, Nodes{X, Y, Z}) +
make_shared<op::FunctionCall>(f, Nodes{X, Y, Z}),
rt_g,
op::Parameters{X, Y, Z});
}
......@@ -218,7 +218,7 @@ TEST(memory_layout, basic)
auto graph = make_test_graph();
pass_manager.run_passes(graph);
auto sorted = pass_manager.get_call_graph();
auto sorted = graph->get_ordered_ops();
size_t temporary_pool_size = pass_manager.get_state().get_temporary_pool_size();
EXPECT_EQ(12, temporary_pool_size);
}
......@@ -21,10 +21,10 @@
#include "ngraph/log.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/collect_functions.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/topological_sort.hpp"
#include "ngraph/util.hpp"
#include "ngraph/visualize.hpp"
#include "test_tools.hpp"
using namespace std;
......@@ -69,7 +69,7 @@ TEST(topological_sort, basic)
pass::Manager pass_manager;
pass_manager.register_pass<pass::TopologicalSort>();
pass_manager.run_passes(f0);
auto sorted_list = pass_manager.get_call_graph();
auto sorted_list = f0->get_ordered_ops();
size_t node_count = get_node_count(r0);
......@@ -121,7 +121,7 @@ TEST(benchmark, topological_sort)
pass::Manager pass_manager;
pass_manager.register_pass<pass::TopologicalSort>();
pass_manager.run_passes(f0);
auto sorted_list = pass_manager.get_call_graph();
auto sorted_list = f0->get_ordered_ops();
timer.stop();
NGRAPH_INFO << "topological sort took " << timer.get_milliseconds() << "ms";
......@@ -135,3 +135,51 @@ TEST(benchmark, topological_sort)
timer.stop();
NGRAPH_INFO << "delete nodes took " << timer.get_milliseconds() << "ms";
}
TEST(topological_sort, collect_functions)
{
// First create "f(A,B,C) = (A+B)*C".
auto shape = Shape{2, 2};
auto A = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto B = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto C = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto rt_f = make_shared<TensorViewType>(element::Float32::element_type(), shape);
auto f = make_shared<Function>((A + B) * C, rt_f, op::Parameters{A, B, C}, "f");
// Now make "g(X,Y,Z) = f(X,Y,Z) + f(X,Y,Z)"
auto X = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto Y = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto Z = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto rt_g = make_shared<TensorViewType>(element::Float32::element_type(), shape);
auto g = make_shared<Function>(make_shared<op::FunctionCall>(f, Nodes{X, Y, Z}) +
make_shared<op::FunctionCall>(f, Nodes{X, Y, Z}),
rt_g,
op::Parameters{X, Y, Z},
"g");
// Now make "h(X,Y,Z) = g(X,Y,Z) + g(X,Y,Z)"
auto X1 = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto Y1 = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto Z1 = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto rt_h = make_shared<TensorViewType>(element::Float32::element_type(), shape);
auto h = make_shared<Function>(make_shared<op::FunctionCall>(g, Nodes{X1, Y1, Z1}) +
make_shared<op::FunctionCall>(g, Nodes{X1, Y1, Z1}),
rt_h,
op::Parameters{X1, Y1, Z1},
"h");
pass::Manager pass_manager;
pass_manager.register_pass<pass::CollectFunctions>();
pass_manager.run_passes(h);
set<string> expected = {"f", "g", "h"};
auto functions = pass_manager.get_state().get_functions();
vector<string> fnames;
for (Function* func : functions)
{
fnames.push_back(func->get_name());
}
EXPECT_EQ(expected.size(), functions.size());
EXPECT_TRUE(contains(fnames, "f"));
EXPECT_TRUE(contains(fnames, "g"));
EXPECT_TRUE(contains(fnames, "h"));
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment