Commit ae48ed32 authored by Robert Kimball's avatar Robert Kimball

add pass skeletons

add test framework

stuff wired up

benchmark segfault for 5M nodes. fun

make traverse_nodes not recursive

print timings on benchmark

unit test for pass manager
parent f1608316
......@@ -24,7 +24,7 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-global-constructors")
# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-gnu-zero-variadic-macro-arguments")
# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-undef")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-exit-time-destructors")
# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-missing-prototypes")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-missing-prototypes")
# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-disabled-macro-expansion")
# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-pedantic")
# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-documentation")
......
......@@ -33,9 +33,14 @@ set (SRC
types/element_type.cpp
types/type.cpp
ngraph/node.cpp
ngraph/topological_sort.cpp
ngraph/visualize.cpp
)
ngraph/pass/pass.cpp
ngraph/pass/manager.cpp
ngraph/pass/call_pass.cpp
ngraph/pass/tree_pass.cpp
ngraph/pass/topological_sort.cpp
ngraph/pass/propagate_types.cpp
)
set(NGRAPH_INCLUDE_PATH
${CMAKE_CURRENT_SOURCE_DIR}
......
......@@ -45,6 +45,7 @@ namespace ngraph
bool operator==(const Type& other) const;
bool operator!=(const Type& other) const { return !(*this == other); }
friend std::ostream& operator<<(std::ostream&, const Type&);
private:
static std::map<std::string, Type> m_element_list;
......
......@@ -64,6 +64,7 @@ namespace ngraph
void assign_tensors();
const Nodes& get_arguments() const { return m_arguments; }
void clear_arguments() { m_arguments.clear(); }
const std::multiset<Node*>& users() const { return m_users; }
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "call_pass.hpp"
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <list>
#include "pass.hpp"
namespace ngraph
{
namespace pass
{
class CallBase;
}
class Node;
}
class ngraph::pass::CallBase : public Base
{
public:
virtual ~CallBase() {}
virtual bool run_on_call_list(std::list<Node*>) = 0;
private:
};
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <iostream>
#include <memory>
#include "manager.hpp"
#include "ngraph/node.hpp"
#include "log.hpp"
using namespace std;
ngraph::pass::Manager::Manager()
{
}
ngraph::pass::Manager::~Manager()
{
}
void ngraph::pass::Manager::initialize_default_passes()
{
}
void ngraph::pass::Manager::register_pass(std::shared_ptr<TreeBase> p)
{
if (p == nullptr)
{
throw invalid_argument("null pass registered");
}
m_tree_passes.push_back(p);
}
void ngraph::pass::Manager::register_pass(std::shared_ptr<CallBase> p)
{
if (p == nullptr)
{
throw invalid_argument("null pass registered");
}
m_call_passes.push_back(p);
}
void ngraph::pass::Manager::run_passes(std::shared_ptr<Node> nodes)
{
for (shared_ptr<TreeBase> p : m_tree_passes)
{
p->run_on_tree(nodes);
if (p->call_graph_produced())
{
m_sorted_list = p->get_call_graph();
}
}
for (shared_ptr<CallBase>& p : m_call_passes)
{
p->run_on_call_list(m_sorted_list);
}
}
const std::list<ngraph::Node*>& ngraph::pass::Manager::get_sorted_list() const
{
return m_sorted_list;
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <vector>
#include "call_pass.hpp"
#include "tree_pass.hpp"
namespace ngraph
{
namespace pass
{
class Manager;
}
class Node;
}
class ngraph::pass::Manager
{
public:
Manager();
~Manager();
void initialize_default_passes();
void register_pass(std::shared_ptr<TreeBase>);
void register_pass(std::shared_ptr<CallBase>);
void run_passes(std::shared_ptr<Node> nodes);
const std::list<Node*>& get_sorted_list() const;
private:
std::vector<std::shared_ptr<TreeBase>> m_tree_passes;
std::vector<std::shared_ptr<CallBase>> m_call_passes;
std::list<Node*> m_sorted_list;
};
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "pass.hpp"
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace pass
{
class Base;
}
}
class ngraph::pass::Base
{
public:
private:
};
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <sstream>
#include "propagate_types.hpp"
#include "ngraph/ngraph.hpp"
using namespace std;
using namespace ngraph;
bool pass::PropagateTypes::run_on_call_list(std::list<Node*> node_list)
{
for (Node* node : node_list)
{
try
{
node->propagate_types();
}
catch (exception& e)
{
stringstream ss;
ss << "Error with node " << *node << ": ";
ss << e.what();
throw invalid_argument(ss.str());
}
}
return false;
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "call_pass.hpp"
namespace ngraph
{
namespace pass
{
class PropagateTypes;
}
class Node;
}
class ngraph::pass::PropagateTypes : public CallBase
{
public:
virtual bool run_on_call_list(std::list<Node*>) override;
private:
};
......@@ -18,20 +18,21 @@
#include "topological_sort.hpp"
#include "node.hpp"
#include "util.hpp"
#include "log.hpp"
using namespace ngraph;
using namespace std;
void ngraph::TopologicalSort::process(node_ptr p)
bool ngraph::pass::TopologicalSort::run_on_tree(std::shared_ptr<Node> p)
{
deque<Node*> independent_nodes;
unordered_map<Node*, size_t> node_depencency_count;
traverse_nodes(p, [&](node_ptr node) {
node_depencency_count[node.get()] = node->get_arguments().size();
traverse_nodes(p, [&](Node* node) {
node_depencency_count[node] = node->get_arguments().size();
if (node->get_arguments().size() == 0)
{
independent_nodes.push_back(node.get());
independent_nodes.push_back(node);
}
});
......@@ -51,14 +52,11 @@ void ngraph::TopologicalSort::process(node_ptr p)
}
}
}
}
const std::list<Node*>& ngraph::TopologicalSort::get_sorted_list() const
{
return m_sorted_list;
return false;
}
std::list<Node*>& ngraph::TopologicalSort::get_sorted_list()
std::list<Node*> ngraph::pass::TopologicalSort::get_call_graph() const
{
return m_sorted_list;
}
......@@ -17,24 +17,27 @@
#include <memory>
#include <list>
#include "tree_pass.hpp"
namespace ngraph
{
namespace pass
{
class TopologicalSort;
}
class Node;
using node_ptr = std::shared_ptr<Node>;
}
class ngraph::TopologicalSort
class ngraph::pass::TopologicalSort : public TreeBase
{
public:
TopologicalSort() {}
void process(node_ptr);
const std::list<Node*>& get_sorted_list() const;
std::list<Node*>& get_sorted_list();
bool run_on_tree(std::shared_ptr<Node>) override;
private:
void promote_node(Node* n);
bool call_graph_produced() const override { return true; }
std::list<Node*> get_call_graph() const override;
private:
std::list<Node*> m_sorted_list;
};
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "tree_pass.hpp"
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <memory>
#include <list>
#include "pass.hpp"
namespace ngraph
{
namespace pass
{
class TreeBase;
}
class Node;
}
class ngraph::pass::TreeBase : public Base
{
public:
virtual ~TreeBase() {}
// return true if changes were made to the tree
virtual bool run_on_tree(std::shared_ptr<Node>) = 0;
virtual bool call_graph_produced() const { return false; }
virtual std::list<Node*> get_call_graph() const { return std::list<Node*>(); }
private:
std::list<Node*> m_sorted_list;
};
......@@ -37,6 +37,7 @@ namespace ngraph
/// Add tensor views in depth-first order.
virtual void collect_tensor_views(std::vector<std::shared_ptr<const TensorViewType>>& views) const = 0;
friend std::ostream& operator<<(std::ostream&, const ValueType&);
};
/// Describes a tensor view; an element type and a shape.
......@@ -57,6 +58,8 @@ namespace ngraph
virtual bool operator==(const ValueType& that) const override;
virtual void collect_tensor_views(std::vector<std::shared_ptr<const TensorViewType>>& views) const override;
friend std::ostream& operator<<(std::ostream&, const TensorViewType&);
protected:
const element::Type& m_element_type;
Shape m_shape;
......@@ -83,6 +86,7 @@ namespace ngraph
virtual bool operator==(const ValueType& that) const override;
virtual void collect_tensor_views(std::vector<std::shared_ptr<const TensorViewType>>& views) const override;
friend std::ostream& operator<<(std::ostream&, const TupleType&);
protected:
std::vector<std::shared_ptr<ValueType>> m_element_types;
......
......@@ -31,10 +31,10 @@ Visualize::Visualize(const string& name)
void Visualize::add(node_ptr p)
{
// map<size_t, list<node_ptr>> dependent_nodes;
traverse_nodes(p, [&](node_ptr node) {
traverse_nodes(p, [&](Node* node) {
for (auto arg : node->get_arguments())
{
m_ss << add_attributes(arg);
m_ss << add_attributes(arg.get());
m_ss << add_attributes(node);
m_ss << " " << arg->get_node_id() << " -> " << node->get_node_id();
m_ss << ";\n";
......@@ -42,7 +42,7 @@ void Visualize::add(node_ptr p)
});
}
std::string Visualize::add_attributes(node_ptr node)
std::string Visualize::add_attributes(const Node* node)
{
string rc;
if (!contains(m_nodes_with_attributes, node))
......@@ -53,7 +53,7 @@ std::string Visualize::add_attributes(node_ptr node)
return rc;
}
std::string Visualize::get_attributes(node_ptr node)
std::string Visualize::get_attributes(const Node* node)
{
stringstream ss;
if (node->is_parameter())
......
......@@ -36,10 +36,10 @@ public:
void save_dot(const std::string& path) const;
private:
std::string add_attributes(node_ptr node);
std::string get_attributes(node_ptr node);
std::string add_attributes(const Node* node);
std::string get_attributes(const Node* node);
std::stringstream m_ss;
std::string m_name;
std::set<node_ptr> m_nodes_with_attributes;
std::set<const Node*> m_nodes_with_attributes;
};
......@@ -14,8 +14,10 @@
#include <cassert>
#include <cmath>
#include <iostream>
#include "ngraph/element_type.hpp"
#include "log.hpp"
using namespace ngraph;
......@@ -30,6 +32,7 @@ ngraph::element::Type::Type(size_t bitwidth,
, m_is_signed{is_signed}
, m_cname{cname}
{
INFO << m_cname;
assert(m_bitwidth % 8 == 0);
}
......@@ -48,3 +51,9 @@ size_t ngraph::element::Type::size() const
{
return std::ceil((float)m_bitwidth / 8.0);
}
std::ostream& ngraph::element::operator<<(std::ostream& out, const ngraph::element::Type& obj)
{
// out << "ElementType(" << obj.c_type_string() << ")";
return out;
}
\ No newline at end of file
......@@ -54,7 +54,26 @@ bool TupleType::operator==(const ValueType& that) const
void TupleType::collect_tensor_views(std::vector<std::shared_ptr<const TensorViewType>>& views) const
{
for(auto elt : m_element_types){
for(auto elt : m_element_types)
{
elt->collect_tensor_views(views);
}
}
std::ostream& ngraph::operator<<(std::ostream& out, const ValueType& obj)
{
out << "ValueType()";
return out;
}
std::ostream& ngraph::operator<<(std::ostream& out, const TensorViewType& obj)
{
out << "TensorViewType(" << obj.m_element_type << ")";
return out;
}
std::ostream& ngraph::operator<<(std::ostream& out, const TupleType& obj)
{
out << "TupleType()";
return out;
}
......@@ -14,9 +14,13 @@
#include <iomanip>
#include <map>
#include <deque>
#include <forward_list>
#include <unordered_set>
#include "util.hpp"
#include "ngraph/node.hpp"
#include "log.hpp"
using namespace std;
......@@ -131,24 +135,37 @@ size_t ngraph::hash_combine(const std::vector<size_t>& list)
return seed;
}
static void traverse_nodes(std::shared_ptr<ngraph::Node> p,
std::function<void(std::shared_ptr<ngraph::Node>)> f,
std::set<size_t>& instances_seen)
void ngraph::traverse_nodes(const std::shared_ptr<ngraph::Node>& p,
std::function<void(Node*)> f)
{
f(p);
for (auto arg : p->get_arguments())
std::unordered_set<Node*> instances_seen;
deque<Node*> stack;
stack.push_front(p.get());
while (stack.size() > 0)
{
if (instances_seen.find(arg->get_instance_id()) == instances_seen.end())
Node* n = stack.front();
if (instances_seen.find(n) == instances_seen.end())
{
instances_seen.insert(arg->get_instance_id());
traverse_nodes(arg, f, instances_seen);
instances_seen.insert(n);
f(n);
}
stack.pop_front();
for (auto arg : n->get_arguments()) { stack.push_front(arg.get()); }
}
}
void ngraph::traverse_nodes(std::shared_ptr<ngraph::Node> p,
std::function<void(std::shared_ptr<ngraph::Node>)> f)
void ngraph::free_nodes(shared_ptr<Node> p)
{
std::set<size_t> instances_seen;
::traverse_nodes(p, f, instances_seen);
std::deque<Node*> sorted_list;
traverse_nodes(p, [&](Node* n)
{
sorted_list.push_front(n);
});
for (Node* n : sorted_list)
{
n->clear_arguments();
}
}
......@@ -195,5 +195,7 @@ namespace ngraph
return a * b;
}
void traverse_nodes(std::shared_ptr<Node> p, std::function<void(std::shared_ptr<Node>)> f);
void traverse_nodes(const std::shared_ptr<Node>& p, std::function<void(Node*)> f);
void free_nodes(std::shared_ptr<Node>);
} // end namespace ngraph
......@@ -33,6 +33,8 @@ set (SRC
type_prop.cpp
util.cpp
uuid.cpp
pass_manager.cpp
test_tools.cpp
)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
......@@ -48,4 +50,3 @@ add_dependencies(unit-test ngraph libgtest eigen)
add_custom_target(check
COMMAND ${PROJECT_BINARY_DIR}/test/unit-test
DEPENDS unit-test)
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <memory>
#include <sstream>
#include <string>
#include <vector>
#include "gtest/gtest.h"
#include "ngraph/pass/topological_sort.hpp"
#include "ngraph/pass/propagate_types.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/ngraph.hpp"
#include "test_tools.hpp"
using namespace ngraph;
using namespace std;
// TEST(pass_manager, add)
// {
// pass::Manager pass_manager;
// auto topological_sort = make_shared<pass::TopologicalSort>();
// auto propagate_types = make_shared<pass::PropagateTypes>();
// pass_manager.register_pass(topological_sort);
// pass_manager.register_pass(propagate_types);
// auto graph = make_test_graph();
// size_t node_count = get_node_count(graph);
// pass_manager.run_passes(graph);
// auto sorted = pass_manager.get_sorted_list();
// EXPECT_EQ(node_count, sorted.size());
// EXPECT_TRUE(validate_list(sorted));
// }
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <algorithm>
#include "test_tools.hpp"
#include "ngraph/ngraph.hpp"
#include "util.hpp"
using namespace std;
using namespace ngraph;
bool validate_list(const list<Node*>& nodes)
{
bool rc = true;
for (auto it = nodes.rbegin(); it != nodes.rend(); it++)
{
auto node_tmp = *it;
auto dependencies_tmp = node_tmp->get_arguments();
vector<Node*> dependencies;
for (shared_ptr<Node> n : dependencies_tmp)
{
dependencies.push_back(n.get());
}
auto tmp = it++;
for (; tmp != nodes.rend(); tmp++)
{
auto dep_tmp = *tmp;
auto found = find(dependencies.begin(), dependencies.end(), dep_tmp);
if (found != dependencies.end())
{
dependencies.erase(found);
}
}
if (dependencies.size() > 0)
{
rc = false;
}
}
return rc;
}
shared_ptr<Node> make_test_graph()
{
auto arg_0 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{1});
auto arg_1 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{1});
auto arg_2 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{1});
auto arg_3 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{1});
auto arg_4 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{1});
auto arg_5 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{1});
auto t0 = make_shared<op::Add>(arg_0, arg_1);
auto t1 = make_shared<op::Dot>(t0, arg_2);
auto t2 = make_shared<op::Multiply>(t0, arg_3);
auto t3 = make_shared<op::Add>(t1, arg_4);
auto t4 = make_shared<op::Add>(t2, arg_5);
auto r0 = make_shared<op::Add>(t3, t4);
return r0;
}
size_t get_node_count(std::shared_ptr<Node> n)
{
size_t node_count = 0;
traverse_nodes(n, [&](const Node* node) {
node_count++;
});
return node_count;
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <list>
#include <memory>
namespace ngraph
{
class Node;
}
bool validate_list(const std::list<ngraph::Node*>& nodes);
std::shared_ptr<ngraph::Node> make_test_graph();
size_t get_node_count(std::shared_ptr<ngraph::Node> n);
......@@ -20,43 +20,15 @@
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/topological_sort.hpp"
#include "ngraph/pass/topological_sort.hpp"
#include "ngraph/visualize.hpp"
#include "util.hpp"
#include "log.hpp"
#include "test_tools.hpp"
using namespace std;
using namespace ngraph;
static bool validate_list(const list<Node*>& nodes)
{
bool rc = true;
for (auto it = nodes.rbegin(); it != nodes.rend(); it++)
{
auto node_tmp = *it;
auto dependencies_tmp = node_tmp->get_arguments();
vector<Node*> dependencies;
for (shared_ptr<Node> n : dependencies_tmp)
{
dependencies.push_back(n.get());
}
auto tmp = it++;
for (; tmp != nodes.rend(); tmp++)
{
auto dep_tmp = *tmp;
auto found = find(dependencies.begin(), dependencies.end(), dep_tmp);
if (found != dependencies.end())
{
dependencies.erase(found);
}
}
if (dependencies.size() > 0)
{
rc = false;
}
}
return rc;
}
TEST(topological_sort, basic)
{
vector<shared_ptr<op::Parameter>> args;
......@@ -86,20 +58,73 @@ TEST(topological_sort, basic)
ASSERT_NE(nullptr, f0);
ASSERT_EQ(2, r0->get_arguments().size());
auto op_r0 = static_pointer_cast<Op>(r0);
// Visualize vz;
// vz.add(r0);
// vz.save_dot("test.png");
TopologicalSort ts;
ts.process(r0);
auto sorted_list = ts.get_sorted_list();
pass::TopologicalSort ts;
ts.run_on_tree(r0);
auto sorted_list = ts.get_call_graph();
size_t node_count = get_node_count(r0);
EXPECT_EQ(node_count, sorted_list.size());
EXPECT_TRUE(validate_list(sorted_list));
}
// TEST(topological_sort, cycle)
// {
// vector<shared_ptr<op::Parameter>> args;
// for (int i = 0; i < 10; i++)
// {
// auto arg = make_shared<op::Parameter>(element::Float32::element_type(), Shape{1});
// ASSERT_NE(nullptr, arg);
// args.push_back(arg);
// }
// auto add_0 = make_shared<op::Add>(args[0], args[1]);
// auto add_1 = make_shared<op::Add>(args[0], args[1]);
// }
shared_ptr<Node> make_cell(shared_ptr<Node> in_0, shared_ptr<Node> in_1, shared_ptr<Node> in_2)
{
auto t0 = make_shared<op::Dot>(in_0, in_1);
auto t1 = make_shared<op::Add>(t0, in_2);
auto t2 = make_shared<op::Negative>(t1); // no tanh yet, this will do
return static_pointer_cast<Node>(t2);
}
TEST(benchmark, topological_sort)
{
stopwatch timer;
// x[i+1] = tanh(dot(W,x[i])+b)
shared_ptr<Node> result;
result = make_shared<op::Parameter>(element::Float32::element_type(), Shape{1});
for (int i=0; i<1000000; i++)
{
auto in_1 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{1});
auto in_2 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{1});
result = make_cell(result, in_1, in_2);
}
auto op_r0 = static_pointer_cast<Op>(result);
timer.start();
pass::TopologicalSort ts;
ts.run_on_tree(op_r0);
auto sorted_list = ts.get_call_graph();
timer.stop();
INFO << "topological sort took " << timer.get_milliseconds() << "ms";
size_t node_count = 0;
traverse_nodes(r0, [&](node_ptr node) {
traverse_nodes(op_r0, [&](const Node* node) {
node_count++;
});
EXPECT_EQ(node_count, sorted_list.size());
EXPECT_TRUE(validate_list(sorted_list));
INFO << "node count " << node_count;
timer.start();
ngraph::free_nodes(result);
timer.stop();
INFO << "delete nodes took " << timer.get_milliseconds() << "ms";
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment