Commit 11a78e5f authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Merge pull request #103 from NervanaSystems/bob/pass4

Add pass manager
parents 826ce031 accc570c
...@@ -24,7 +24,7 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-global-constructors") ...@@ -24,7 +24,7 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-global-constructors")
# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-gnu-zero-variadic-macro-arguments") # set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-gnu-zero-variadic-macro-arguments")
# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-undef") # set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-undef")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-exit-time-destructors") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-exit-time-destructors")
# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-missing-prototypes") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-missing-prototypes")
# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-disabled-macro-expansion") # set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-disabled-macro-expansion")
# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-pedantic") # set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-pedantic")
# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-documentation") # set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-documentation")
......
...@@ -12,30 +12,35 @@ ...@@ -12,30 +12,35 @@
# limitations under the License. # limitations under the License.
set (SRC set (SRC
tree.cpp
util.cpp
log.cpp log.cpp
ngraph/descriptor/input.cpp ngraph/descriptor/input.cpp
ngraph/descriptor/output.cpp ngraph/descriptor/output.cpp
ngraph/descriptor/tensor.cpp
ngraph/descriptor/tensor_view.cpp ngraph/descriptor/tensor_view.cpp
ngraph/descriptor/tensor.cpp
ngraph/node.cpp
ngraph/pass/call_pass.cpp
ngraph/pass/manager.cpp
ngraph/pass/pass.cpp
ngraph/pass/propagate_types.cpp
ngraph/pass/topological_sort.cpp
ngraph/pass/tree_pass.cpp
ngraph/visualize.cpp
ops/binary_elementwise_builtin.cpp ops/binary_elementwise_builtin.cpp
ops/broadcast.cpp ops/broadcast.cpp
ops/concatenate.cpp ops/concatenate.cpp
ops/convert.cpp
ops/constant.cpp ops/constant.cpp
ops/convert.cpp
ops/dot.cpp ops/dot.cpp
ops/function.cpp ops/function.cpp
ops/op.cpp ops/op.cpp
ops/parameter.cpp ops/parameter.cpp
ops/tuple.cpp ops/tuple.cpp
ops/unary_elementwise_builtin.cpp ops/unary_elementwise_builtin.cpp
tree.cpp
types/element_type.cpp types/element_type.cpp
types/type.cpp types/type.cpp
ngraph/node.cpp util.cpp
ngraph/topological_sort.cpp )
ngraph/visualize.cpp
)
set(NGRAPH_INCLUDE_PATH set(NGRAPH_INCLUDE_PATH
${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_SOURCE_DIR}
......
...@@ -45,6 +45,7 @@ namespace ngraph ...@@ -45,6 +45,7 @@ namespace ngraph
bool operator==(const Type& other) const; bool operator==(const Type& other) const;
bool operator!=(const Type& other) const { return !(*this == other); } bool operator!=(const Type& other) const { return !(*this == other); }
friend std::ostream& operator<<(std::ostream&, const Type&);
private: private:
static std::map<std::string, Type> m_element_list; static std::map<std::string, Type> m_element_list;
......
...@@ -64,6 +64,7 @@ namespace ngraph ...@@ -64,6 +64,7 @@ namespace ngraph
void assign_tensors(); void assign_tensors();
const Nodes& get_arguments() const { return m_arguments; } const Nodes& get_arguments() const { return m_arguments; }
void clear_arguments() { m_arguments.clear(); }
const std::multiset<Node*>& users() const { return m_users; } const std::multiset<Node*>& users() const { return m_users; }
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "call_pass.hpp"
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <list>
#include "pass.hpp"
namespace ngraph
{
namespace pass
{
class CallBase;
}
class Node;
}
class ngraph::pass::CallBase : public Base
{
public:
virtual ~CallBase() {}
virtual bool run_on_call_list(std::list<Node*>) = 0;
private:
};
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <iostream>
#include <memory>
#include "manager.hpp"
#include "ngraph/node.hpp"
#include "log.hpp"
using namespace std;
ngraph::pass::Manager::Manager()
{
}
ngraph::pass::Manager::~Manager()
{
}
void ngraph::pass::Manager::initialize_default_passes()
{
}
void ngraph::pass::Manager::register_pass(std::shared_ptr<TreeBase> p)
{
if (p == nullptr)
{
throw invalid_argument("null pass registered");
}
m_tree_passes.push_back(p);
}
void ngraph::pass::Manager::register_pass(std::shared_ptr<CallBase> p)
{
if (p == nullptr)
{
throw invalid_argument("null pass registered");
}
m_call_passes.push_back(p);
}
void ngraph::pass::Manager::run_passes(std::shared_ptr<Node> nodes)
{
for (shared_ptr<TreeBase> p : m_tree_passes)
{
p->run_on_tree(nodes);
if (p->call_graph_produced())
{
m_sorted_list = p->get_call_graph();
}
}
for (shared_ptr<CallBase>& p : m_call_passes)
{
p->run_on_call_list(m_sorted_list);
}
}
const std::list<ngraph::Node*>& ngraph::pass::Manager::get_sorted_list() const
{
return m_sorted_list;
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <vector>
#include "call_pass.hpp"
#include "tree_pass.hpp"
namespace ngraph
{
namespace pass
{
class Manager;
}
class Node;
}
class ngraph::pass::Manager
{
public:
Manager();
~Manager();
void initialize_default_passes();
void register_pass(std::shared_ptr<TreeBase>);
void register_pass(std::shared_ptr<CallBase>);
void run_passes(std::shared_ptr<Node> nodes);
const std::list<Node*>& get_sorted_list() const;
private:
std::vector<std::shared_ptr<TreeBase>> m_tree_passes;
std::vector<std::shared_ptr<CallBase>> m_call_passes;
std::list<Node*> m_sorted_list;
};
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "pass.hpp"
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace pass
{
class Base;
}
}
class ngraph::pass::Base
{
public:
private:
};
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <sstream>
#include "propagate_types.hpp"
#include "ngraph/ngraph.hpp"
using namespace std;
using namespace ngraph;
bool pass::PropagateTypes::run_on_call_list(std::list<Node*> node_list)
{
for (Node* node : node_list)
{
try
{
node->propagate_types();
}
catch (exception& e)
{
stringstream ss;
ss << "Error with node " << *node << ": ";
ss << e.what();
throw invalid_argument(ss.str());
}
}
return false;
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "call_pass.hpp"
namespace ngraph
{
namespace pass
{
class PropagateTypes;
}
class Node;
}
class ngraph::pass::PropagateTypes : public CallBase
{
public:
virtual bool run_on_call_list(std::list<Node*>) override;
private:
};
...@@ -18,20 +18,21 @@ ...@@ -18,20 +18,21 @@
#include "topological_sort.hpp" #include "topological_sort.hpp"
#include "node.hpp" #include "node.hpp"
#include "util.hpp" #include "util.hpp"
#include "log.hpp"
using namespace ngraph; using namespace ngraph;
using namespace std; using namespace std;
void ngraph::TopologicalSort::process(node_ptr p) bool ngraph::pass::TopologicalSort::run_on_tree(std::shared_ptr<Node> p)
{ {
deque<Node*> independent_nodes; deque<Node*> independent_nodes;
unordered_map<Node*, size_t> node_depencency_count; unordered_map<Node*, size_t> node_depencency_count;
traverse_nodes(p, [&](node_ptr node) { traverse_nodes(p, [&](Node* node) {
node_depencency_count[node.get()] = node->get_arguments().size(); node_depencency_count[node] = node->get_arguments().size();
if (node->get_arguments().size() == 0) if (node->get_arguments().size() == 0)
{ {
independent_nodes.push_back(node.get()); independent_nodes.push_back(node);
} }
}); });
...@@ -51,14 +52,11 @@ void ngraph::TopologicalSort::process(node_ptr p) ...@@ -51,14 +52,11 @@ void ngraph::TopologicalSort::process(node_ptr p)
} }
} }
} }
}
const std::list<Node*>& ngraph::TopologicalSort::get_sorted_list() const return false;
{
return m_sorted_list;
} }
std::list<Node*>& ngraph::TopologicalSort::get_sorted_list() std::list<Node*> ngraph::pass::TopologicalSort::get_call_graph() const
{ {
return m_sorted_list; return m_sorted_list;
} }
...@@ -17,24 +17,27 @@ ...@@ -17,24 +17,27 @@
#include <memory> #include <memory>
#include <list> #include <list>
#include "tree_pass.hpp"
namespace ngraph namespace ngraph
{ {
namespace pass
{
class TopologicalSort; class TopologicalSort;
}
class Node; class Node;
using node_ptr = std::shared_ptr<Node>;
} }
class ngraph::TopologicalSort class ngraph::pass::TopologicalSort : public TreeBase
{ {
public: public:
TopologicalSort() {} TopologicalSort() {}
void process(node_ptr); bool run_on_tree(std::shared_ptr<Node>) override;
const std::list<Node*>& get_sorted_list() const;
std::list<Node*>& get_sorted_list();
private: bool call_graph_produced() const override { return true; }
void promote_node(Node* n); std::list<Node*> get_call_graph() const override;
private:
std::list<Node*> m_sorted_list; std::list<Node*> m_sorted_list;
}; };
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "tree_pass.hpp"
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <memory>
#include <list>
#include "pass.hpp"
namespace ngraph
{
namespace pass
{
class TreeBase;
}
class Node;
}
class ngraph::pass::TreeBase : public Base
{
public:
virtual ~TreeBase() {}
// return true if changes were made to the tree
virtual bool run_on_tree(std::shared_ptr<Node>) = 0;
virtual bool call_graph_produced() const { return false; }
virtual std::list<Node*> get_call_graph() const { return std::list<Node*>(); }
private:
std::list<Node*> m_sorted_list;
};
...@@ -37,6 +37,7 @@ namespace ngraph ...@@ -37,6 +37,7 @@ namespace ngraph
/// Add tensor views in depth-first order. /// Add tensor views in depth-first order.
virtual void collect_tensor_views(std::vector<std::shared_ptr<const TensorViewType>>& views) const = 0; virtual void collect_tensor_views(std::vector<std::shared_ptr<const TensorViewType>>& views) const = 0;
friend std::ostream& operator<<(std::ostream&, const ValueType&);
}; };
/// Describes a tensor view; an element type and a shape. /// Describes a tensor view; an element type and a shape.
...@@ -57,6 +58,8 @@ namespace ngraph ...@@ -57,6 +58,8 @@ namespace ngraph
virtual bool operator==(const ValueType& that) const override; virtual bool operator==(const ValueType& that) const override;
virtual void collect_tensor_views(std::vector<std::shared_ptr<const TensorViewType>>& views) const override; virtual void collect_tensor_views(std::vector<std::shared_ptr<const TensorViewType>>& views) const override;
friend std::ostream& operator<<(std::ostream&, const TensorViewType&);
protected: protected:
const element::Type& m_element_type; const element::Type& m_element_type;
Shape m_shape; Shape m_shape;
...@@ -83,6 +86,7 @@ namespace ngraph ...@@ -83,6 +86,7 @@ namespace ngraph
virtual bool operator==(const ValueType& that) const override; virtual bool operator==(const ValueType& that) const override;
virtual void collect_tensor_views(std::vector<std::shared_ptr<const TensorViewType>>& views) const override; virtual void collect_tensor_views(std::vector<std::shared_ptr<const TensorViewType>>& views) const override;
friend std::ostream& operator<<(std::ostream&, const TupleType&);
protected: protected:
std::vector<std::shared_ptr<ValueType>> m_element_types; std::vector<std::shared_ptr<ValueType>> m_element_types;
......
...@@ -31,10 +31,10 @@ Visualize::Visualize(const string& name) ...@@ -31,10 +31,10 @@ Visualize::Visualize(const string& name)
void Visualize::add(node_ptr p) void Visualize::add(node_ptr p)
{ {
// map<size_t, list<node_ptr>> dependent_nodes; // map<size_t, list<node_ptr>> dependent_nodes;
traverse_nodes(p, [&](node_ptr node) { traverse_nodes(p, [&](Node* node) {
for (auto arg : node->get_arguments()) for (auto arg : node->get_arguments())
{ {
m_ss << add_attributes(arg); m_ss << add_attributes(arg.get());
m_ss << add_attributes(node); m_ss << add_attributes(node);
m_ss << " " << arg->get_node_id() << " -> " << node->get_node_id(); m_ss << " " << arg->get_node_id() << " -> " << node->get_node_id();
m_ss << ";\n"; m_ss << ";\n";
...@@ -42,7 +42,7 @@ void Visualize::add(node_ptr p) ...@@ -42,7 +42,7 @@ void Visualize::add(node_ptr p)
}); });
} }
std::string Visualize::add_attributes(node_ptr node) std::string Visualize::add_attributes(const Node* node)
{ {
string rc; string rc;
if (!contains(m_nodes_with_attributes, node)) if (!contains(m_nodes_with_attributes, node))
...@@ -53,7 +53,7 @@ std::string Visualize::add_attributes(node_ptr node) ...@@ -53,7 +53,7 @@ std::string Visualize::add_attributes(node_ptr node)
return rc; return rc;
} }
std::string Visualize::get_attributes(node_ptr node) std::string Visualize::get_attributes(const Node* node)
{ {
stringstream ss; stringstream ss;
if (node->is_parameter()) if (node->is_parameter())
......
...@@ -36,10 +36,10 @@ public: ...@@ -36,10 +36,10 @@ public:
void save_dot(const std::string& path) const; void save_dot(const std::string& path) const;
private: private:
std::string add_attributes(node_ptr node); std::string add_attributes(const Node* node);
std::string get_attributes(node_ptr node); std::string get_attributes(const Node* node);
std::stringstream m_ss; std::stringstream m_ss;
std::string m_name; std::string m_name;
std::set<node_ptr> m_nodes_with_attributes; std::set<const Node*> m_nodes_with_attributes;
}; };
...@@ -14,8 +14,10 @@ ...@@ -14,8 +14,10 @@
#include <cassert> #include <cassert>
#include <cmath> #include <cmath>
#include <iostream>
#include "ngraph/element_type.hpp" #include "ngraph/element_type.hpp"
#include "log.hpp"
using namespace ngraph; using namespace ngraph;
...@@ -30,6 +32,7 @@ ngraph::element::Type::Type(size_t bitwidth, ...@@ -30,6 +32,7 @@ ngraph::element::Type::Type(size_t bitwidth,
, m_is_signed{is_signed} , m_is_signed{is_signed}
, m_cname{cname} , m_cname{cname}
{ {
INFO << m_cname;
assert(m_bitwidth % 8 == 0); assert(m_bitwidth % 8 == 0);
} }
...@@ -48,3 +51,9 @@ size_t ngraph::element::Type::size() const ...@@ -48,3 +51,9 @@ size_t ngraph::element::Type::size() const
{ {
return std::ceil((float)m_bitwidth / 8.0); return std::ceil((float)m_bitwidth / 8.0);
} }
std::ostream& ngraph::element::operator<<(std::ostream& out, const ngraph::element::Type& obj)
{
// out << "ElementType(" << obj.c_type_string() << ")";
return out;
}
\ No newline at end of file
...@@ -54,7 +54,26 @@ bool TupleType::operator==(const ValueType& that) const ...@@ -54,7 +54,26 @@ bool TupleType::operator==(const ValueType& that) const
void TupleType::collect_tensor_views(std::vector<std::shared_ptr<const TensorViewType>>& views) const void TupleType::collect_tensor_views(std::vector<std::shared_ptr<const TensorViewType>>& views) const
{ {
for(auto elt : m_element_types){ for(auto elt : m_element_types)
{
elt->collect_tensor_views(views); elt->collect_tensor_views(views);
} }
} }
std::ostream& ngraph::operator<<(std::ostream& out, const ValueType& obj)
{
out << "ValueType()";
return out;
}
std::ostream& ngraph::operator<<(std::ostream& out, const TensorViewType& obj)
{
out << "TensorViewType(" << obj.m_element_type << ")";
return out;
}
std::ostream& ngraph::operator<<(std::ostream& out, const TupleType& obj)
{
out << "TupleType()";
return out;
}
...@@ -14,9 +14,13 @@ ...@@ -14,9 +14,13 @@
#include <iomanip> #include <iomanip>
#include <map> #include <map>
#include <deque>
#include <forward_list>
#include <unordered_set>
#include "util.hpp" #include "util.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "log.hpp"
using namespace std; using namespace std;
...@@ -131,24 +135,37 @@ size_t ngraph::hash_combine(const std::vector<size_t>& list) ...@@ -131,24 +135,37 @@ size_t ngraph::hash_combine(const std::vector<size_t>& list)
return seed; return seed;
} }
static void traverse_nodes(std::shared_ptr<ngraph::Node> p, void ngraph::traverse_nodes(const std::shared_ptr<ngraph::Node>& p,
std::function<void(std::shared_ptr<ngraph::Node>)> f, std::function<void(Node*)> f)
std::set<size_t>& instances_seen)
{ {
f(p); std::unordered_set<Node*> instances_seen;
for (auto arg : p->get_arguments()) deque<Node*> stack;
stack.push_front(p.get());
while (stack.size() > 0)
{ {
if (instances_seen.find(arg->get_instance_id()) == instances_seen.end()) Node* n = stack.front();
if (instances_seen.find(n) == instances_seen.end())
{ {
instances_seen.insert(arg->get_instance_id()); instances_seen.insert(n);
traverse_nodes(arg, f, instances_seen); f(n);
} }
stack.pop_front();
for (auto arg : n->get_arguments()) { stack.push_front(arg.get()); }
} }
} }
void ngraph::traverse_nodes(std::shared_ptr<ngraph::Node> p, void ngraph::free_nodes(shared_ptr<Node> p)
std::function<void(std::shared_ptr<ngraph::Node>)> f)
{ {
std::set<size_t> instances_seen; std::deque<Node*> sorted_list;
::traverse_nodes(p, f, instances_seen);
traverse_nodes(p, [&](Node* n)
{
sorted_list.push_front(n);
});
for (Node* n : sorted_list)
{
n->clear_arguments();
}
} }
...@@ -195,5 +195,7 @@ namespace ngraph ...@@ -195,5 +195,7 @@ namespace ngraph
return a * b; return a * b;
} }
void traverse_nodes(std::shared_ptr<Node> p, std::function<void(std::shared_ptr<Node>)> f); void traverse_nodes(const std::shared_ptr<Node>& p, std::function<void(Node*)> f);
void free_nodes(std::shared_ptr<Node>);
} // end namespace ngraph } // end namespace ngraph
...@@ -22,13 +22,15 @@ include_directories( ...@@ -22,13 +22,15 @@ include_directories(
) )
set (SRC set (SRC
main.cpp
build_graph.cpp build_graph.cpp
eigen.cpp eigen.cpp
element_type.cpp element_type.cpp
op.cpp
input_output_assign.cpp input_output_assign.cpp
main.cpp
op.cpp
pass_manager.cpp
tensor.cpp tensor.cpp
test_tools.cpp
topological_sort.cpp topological_sort.cpp
type_prop.cpp type_prop.cpp
util.cpp util.cpp
...@@ -48,4 +50,3 @@ add_dependencies(unit-test ngraph libgtest eigen) ...@@ -48,4 +50,3 @@ add_dependencies(unit-test ngraph libgtest eigen)
add_custom_target(check add_custom_target(check
COMMAND ${PROJECT_BINARY_DIR}/test/unit-test COMMAND ${PROJECT_BINARY_DIR}/test/unit-test
DEPENDS unit-test) DEPENDS unit-test)
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <memory>
#include <sstream>
#include <string>
#include <vector>
#include "gtest/gtest.h"
#include "ngraph/pass/topological_sort.hpp"
#include "ngraph/pass/propagate_types.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/ngraph.hpp"
#include "test_tools.hpp"
using namespace ngraph;
using namespace std;
// TEST(pass_manager, add)
// {
// pass::Manager pass_manager;
// auto topological_sort = make_shared<pass::TopologicalSort>();
// auto propagate_types = make_shared<pass::PropagateTypes>();
// pass_manager.register_pass(topological_sort);
// pass_manager.register_pass(propagate_types);
// auto graph = make_test_graph();
// size_t node_count = get_node_count(graph);
// pass_manager.run_passes(graph);
// auto sorted = pass_manager.get_sorted_list();
// EXPECT_EQ(node_count, sorted.size());
// EXPECT_TRUE(validate_list(sorted));
// }
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <algorithm>
#include "test_tools.hpp"
#include "ngraph/ngraph.hpp"
#include "util.hpp"
using namespace std;
using namespace ngraph;
// This function traverses the list of ops and verifies that each op's dependencies (its inputs)
// is located earlier in the list. That is enough to be valid
bool validate_list(const list<Node*>& nodes)
{
bool rc = true;
for (auto it = nodes.rbegin(); it != nodes.rend(); it++)
{
auto node_tmp = *it;
auto dependencies_tmp = node_tmp->get_arguments();
vector<Node*> dependencies;
for (shared_ptr<Node> n : dependencies_tmp)
{
dependencies.push_back(n.get());
}
auto tmp = it++;
for (; tmp != nodes.rend(); tmp++)
{
auto dep_tmp = *tmp;
auto found = find(dependencies.begin(), dependencies.end(), dep_tmp);
if (found != dependencies.end())
{
dependencies.erase(found);
}
}
if (dependencies.size() > 0)
{
rc = false;
}
}
return rc;
}
shared_ptr<Node> make_test_graph()
{
auto arg_0 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{1});
auto arg_1 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{1});
auto arg_2 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{1});
auto arg_3 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{1});
auto arg_4 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{1});
auto arg_5 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{1});
auto t0 = make_shared<op::Add>(arg_0, arg_1);
auto t1 = make_shared<op::Dot>(t0, arg_2);
auto t2 = make_shared<op::Multiply>(t0, arg_3);
auto t3 = make_shared<op::Add>(t1, arg_4);
auto t4 = make_shared<op::Add>(t2, arg_5);
auto r0 = make_shared<op::Add>(t3, t4);
return r0;
}
size_t get_node_count(std::shared_ptr<Node> n)
{
size_t node_count = 0;
traverse_nodes(n, [&](const Node* node) {
node_count++;
});
return node_count;
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <list>
#include <memory>
namespace ngraph
{
class Node;
}
bool validate_list(const std::list<ngraph::Node*>& nodes);
std::shared_ptr<ngraph::Node> make_test_graph();
size_t get_node_count(std::shared_ptr<ngraph::Node> n);
...@@ -20,43 +20,15 @@ ...@@ -20,43 +20,15 @@
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
#include "ngraph/topological_sort.hpp" #include "ngraph/pass/topological_sort.hpp"
#include "ngraph/visualize.hpp" #include "ngraph/visualize.hpp"
#include "util.hpp" #include "util.hpp"
#include "log.hpp"
#include "test_tools.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
static bool validate_list(const list<Node*>& nodes)
{
bool rc = true;
for (auto it = nodes.rbegin(); it != nodes.rend(); it++)
{
auto node_tmp = *it;
auto dependencies_tmp = node_tmp->get_arguments();
vector<Node*> dependencies;
for (shared_ptr<Node> n : dependencies_tmp)
{
dependencies.push_back(n.get());
}
auto tmp = it++;
for (; tmp != nodes.rend(); tmp++)
{
auto dep_tmp = *tmp;
auto found = find(dependencies.begin(), dependencies.end(), dep_tmp);
if (found != dependencies.end())
{
dependencies.erase(found);
}
}
if (dependencies.size() > 0)
{
rc = false;
}
}
return rc;
}
TEST(topological_sort, basic) TEST(topological_sort, basic)
{ {
vector<shared_ptr<op::Parameter>> args; vector<shared_ptr<op::Parameter>> args;
...@@ -86,20 +58,73 @@ TEST(topological_sort, basic) ...@@ -86,20 +58,73 @@ TEST(topological_sort, basic)
ASSERT_NE(nullptr, f0); ASSERT_NE(nullptr, f0);
ASSERT_EQ(2, r0->get_arguments().size()); ASSERT_EQ(2, r0->get_arguments().size());
auto op_r0 = static_pointer_cast<Op>(r0);
// Visualize vz; // Visualize vz;
// vz.add(r0); // vz.add(r0);
// vz.save_dot("test.png"); // vz.save_dot("test.png");
TopologicalSort ts; pass::TopologicalSort ts;
ts.process(r0); ts.run_on_tree(r0);
auto sorted_list = ts.get_sorted_list(); auto sorted_list = ts.get_call_graph();
size_t node_count = get_node_count(r0);
EXPECT_EQ(node_count, sorted_list.size());
EXPECT_TRUE(validate_list(sorted_list));
}
// TEST(topological_sort, cycle)
// {
// vector<shared_ptr<op::Parameter>> args;
// for (int i = 0; i < 10; i++)
// {
// auto arg = make_shared<op::Parameter>(element::Float32::element_type(), Shape{1});
// ASSERT_NE(nullptr, arg);
// args.push_back(arg);
// }
// auto add_0 = make_shared<op::Add>(args[0], args[1]);
// auto add_1 = make_shared<op::Add>(args[0], args[1]);
// }
shared_ptr<Node> make_cell(shared_ptr<Node> in_0, shared_ptr<Node> in_1, shared_ptr<Node> in_2)
{
auto t0 = make_shared<op::Dot>(in_0, in_1);
auto t1 = make_shared<op::Add>(t0, in_2);
auto t2 = make_shared<op::Negative>(t1); // no tanh yet, this will do
return static_pointer_cast<Node>(t2);
}
TEST(benchmark, topological_sort)
{
stopwatch timer;
// x[i+1] = tanh(dot(W,x[i])+b)
shared_ptr<Node> result;
result = make_shared<op::Parameter>(element::Float32::element_type(), Shape{1});
for (int i=0; i<1000000; i++)
{
auto in_1 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{1});
auto in_2 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{1});
result = make_cell(result, in_1, in_2);
}
auto op_r0 = static_pointer_cast<Op>(result);
timer.start();
pass::TopologicalSort ts;
ts.run_on_tree(op_r0);
auto sorted_list = ts.get_call_graph();
timer.stop();
INFO << "topological sort took " << timer.get_milliseconds() << "ms";
size_t node_count = 0; size_t node_count = 0;
traverse_nodes(r0, [&](node_ptr node) { traverse_nodes(op_r0, [&](const Node* node) {
node_count++; node_count++;
}); });
EXPECT_EQ(node_count, sorted_list.size()); INFO << "node count " << node_count;
EXPECT_TRUE(validate_list(sorted_list));
timer.start();
ngraph::free_nodes(result);
timer.stop();
INFO << "delete nodes took " << timer.get_milliseconds() << "ms";
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment