Commit 66d06693 authored by varun-intel's avatar varun-intel Committed by Scott Cyphers

recreate ops (#325)

* recreate ops

* style

* recompute ops

* style

* fix

* recreate ops

* style

* recompute ops

* style

* fix

* some

* more

* style

* remove a line

* const

* style

* NodeMap was using non-standard operator[] behavior.

* Missing include
parent d092cb91
...@@ -96,6 +96,7 @@ set (SRC ...@@ -96,6 +96,7 @@ set (SRC
types/element_type.cpp types/element_type.cpp
types/type.cpp types/type.cpp
util.cpp util.cpp
graph_util.cpp
) )
message(STATUS ${CMAKE_CURRENT_SOURCE_DIR}/ops) message(STATUS ${CMAKE_CURRENT_SOURCE_DIR}/ops)
......
...@@ -12,9 +12,11 @@ ...@@ -12,9 +12,11 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#include <list>
#include <memory> #include <memory>
#include "ngraph/function.hpp" #include "ngraph/function.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp" #include "ngraph/log.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
...@@ -34,8 +36,6 @@ Function::Function(const Nodes& results, ...@@ -34,8 +36,6 @@ Function::Function(const Nodes& results,
, m_instance_id(m_next_instance_id.fetch_add(1)) , m_instance_id(m_next_instance_id.fetch_add(1))
{ {
traverse_nodes(this, [&](shared_ptr<Node> node) { traverse_nodes(this, [&](shared_ptr<Node> node) {
m_ops.push_back(node);
std::shared_ptr<op::Parameter> p = std::dynamic_pointer_cast<op::Parameter>(node); std::shared_ptr<op::Parameter> p = std::dynamic_pointer_cast<op::Parameter>(node);
if (nullptr != p) if (nullptr != p)
{ {
...@@ -63,16 +63,6 @@ void Function::set_ordered_ops(const std::list<shared_ptr<Node>>& ordered_ops) ...@@ -63,16 +63,6 @@ void Function::set_ordered_ops(const std::list<shared_ptr<Node>>& ordered_ops)
m_ordered_ops_valid = true; m_ordered_ops_valid = true;
} }
std::list<shared_ptr<Node>>& Function::get_ops()
{
return m_ops;
}
const std::list<shared_ptr<Node>>& Function::get_ops() const
{
return m_ops;
}
std::list<shared_ptr<Node>>& Function::get_ordered_ops() std::list<shared_ptr<Node>>& Function::get_ordered_ops()
{ {
if (!m_ordered_ops_valid) if (!m_ordered_ops_valid)
...@@ -161,3 +151,24 @@ shared_ptr<Node> Function::get_result() const ...@@ -161,3 +151,24 @@ shared_ptr<Node> Function::get_result() const
} }
return m_results.at(0); return m_results.at(0);
} }
std::list<shared_ptr<Node>> Function::get_ops() const
{
std::list<std::shared_ptr<Node>> ops;
traverse_nodes(this, [&](shared_ptr<Node> node) {
ops.push_back(node);
std::shared_ptr<op::Parameter> p = std::dynamic_pointer_cast<op::Parameter>(node);
if (nullptr != p)
{
auto it = std::find_if(m_parameters.begin(),
m_parameters.end(),
[p](std::shared_ptr<op::Parameter> q) { return (p == q); });
if (it == m_parameters.end())
{
throw ngraph_error("Function references undeclared parameter");
}
}
});
return ops;
}
...@@ -72,8 +72,7 @@ namespace ngraph ...@@ -72,8 +72,7 @@ namespace ngraph
void set_name( void set_name(
const std::string& const std::string&
name); //so we can use `dynamic_cast` in FunctionCall to double check if we are dealing with an XLA or regular function name); //so we can use `dynamic_cast` in FunctionCall to double check if we are dealing with an XLA or regular function
std::list<std::shared_ptr<Node>>& get_ops(); std::list<std::shared_ptr<Node>> get_ops() const;
const std::list<std::shared_ptr<Node>>& get_ops() const;
std::list<std::shared_ptr<Node>>& get_ordered_ops(); std::list<std::shared_ptr<Node>>& get_ordered_ops();
const std::list<std::shared_ptr<Node>>& get_ordered_ops() const; const std::list<std::shared_ptr<Node>>& get_ordered_ops() const;
void set_ordered_ops(const std::list<std::shared_ptr<Node>>&); void set_ordered_ops(const std::list<std::shared_ptr<Node>>&);
...@@ -90,7 +89,6 @@ namespace ngraph ...@@ -90,7 +89,6 @@ namespace ngraph
std::string m_name; std::string m_name;
bool m_ordered_ops_valid; bool m_ordered_ops_valid;
std::list<std::shared_ptr<Node>> m_ordered_ops; std::list<std::shared_ptr<Node>> m_ordered_ops;
std::list<std::shared_ptr<Node>> m_ops;
size_t m_temporary_pool_size; size_t m_temporary_pool_size;
private: private:
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <cassert>
#include <deque>
#include <forward_list>
#include <iomanip>
#include <map>
#include <unordered_set>
#include "ngraph/function.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/node.hpp"
using namespace std;
void ngraph::traverse_nodes(const std::shared_ptr<const Function> p,
std::function<void(std::shared_ptr<Node>)> f)
{
traverse_nodes(p.get(), f);
}
void ngraph::traverse_nodes(const Function* p, std::function<void(std::shared_ptr<Node>)> f)
{
std::unordered_set<std::shared_ptr<Node>> instances_seen;
std::deque<std::shared_ptr<Node>> stack;
for (auto r : p->get_results())
{
stack.push_front(r);
}
for (auto param : p->get_parameters())
{
stack.push_front(param);
}
while (stack.size() > 0)
{
std::shared_ptr<Node> n = stack.front();
if (instances_seen.count(n) == 0)
{
instances_seen.insert(n);
f(n);
}
stack.pop_front();
for (auto arg : n->get_input_ops())
{
if (instances_seen.count(arg) == 0)
{
stack.push_front(arg);
}
}
}
}
void ngraph::traverse_functions(std::shared_ptr<ngraph::Function> p,
std::function<void(shared_ptr<Function>)> f)
{
std::unordered_set<shared_ptr<Function>> instances_seen;
deque<shared_ptr<Function>> stack;
stack.push_front(p);
while (stack.size() > 0)
{
shared_ptr<Function> func = stack.front();
if (instances_seen.find(func) == instances_seen.end())
{
instances_seen.insert(func);
f(func);
}
stack.pop_front();
for (shared_ptr<Node> op : func->get_ops())
{
shared_ptr<Function> fp = op->get_function();
if (fp)
{
stack.push_front(fp);
}
}
}
}
void ngraph::free_nodes(shared_ptr<Function> p)
{
std::deque<Node*> sorted_list;
traverse_nodes(p, [&](shared_ptr<Node> n) { sorted_list.push_front(n.get()); });
for (Node* n : sorted_list)
{
n->clear_arguments();
}
}
void ngraph::replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> replacement)
{
if (target->is_output()) //this restriction can be lifted when we find an use case for it
{
return;
}
//fix input/output descriptors
NGRAPH_DEBUG << "Replacing target = " << target << " , " << target->get_name() << " , "
<< "replacement = " << replacement << " , " << replacement->get_name();
assert(target->get_outputs().size() == replacement->get_outputs().size());
for (size_t i = 0; i < target->get_outputs().size(); i++)
{
auto& target_output = target->get_outputs().at(i);
std::set<ngraph::descriptor::Input*> copy_inputs{
begin(target_output.get_inputs()),
end(target_output.get_inputs())}; //replace_output modifies target_output->m_inputs
for (auto input : copy_inputs)
{
input->replace_output(replacement->get_outputs().at(i));
}
}
//fix users and arguments
replace_node_users_arguments(target, replacement);
}
void ngraph::replace_node_users_arguments(std::shared_ptr<Node> target,
std::shared_ptr<Node> replacement)
{
NGRAPH_DEBUG << "Replacing target = " << target << " , " << target->get_name() << " , "
<< "replacement = " << replacement << " , " << replacement->get_name();
NGRAPH_DEBUG << "user = " << replacement << " , " << replacement->get_name();
for (auto user : target->users())
{
auto& args = const_cast<ngraph::Nodes&>(user->get_arguments_FOR_GRAPH_REWRITE_ONLY());
auto it = std::find(begin(args), end(args), target);
assert(it != end(args));
//NGRAPH_DEBUG << "Replaced " << *it << " w/ " << replacement << " in args of " << user << " , args = " << &args;
it = args.erase(it);
args.insert(it, replacement);
const_cast<std::multiset<Node*>&>(replacement->users()).insert(user);
}
const_cast<std::multiset<Node*>&>(target->users()).clear();
}
std::list<std::shared_ptr<ngraph::Node>>
ngraph::topological_sort(const std::list<std::shared_ptr<Node>>& nodes)
{
deque<ngraph::Node*> independent_nodes;
unordered_map<const ngraph::Node*, size_t> node_depencency_count;
unordered_map<ngraph::Node*, shared_ptr<ngraph::Node>> node_map;
for (auto node : nodes)
{
node_map[node.get()] = node;
node_depencency_count[node.get()] = node->get_input_ops().size();
if (node->get_input_ops().size() == 0)
{
independent_nodes.push_back(node.get());
}
}
list<shared_ptr<ngraph::Node>> result_list;
while (independent_nodes.size() > 0)
{
auto independent_node = independent_nodes.front();
result_list.push_back(node_map[independent_node]);
independent_nodes.pop_front();
for (auto user : independent_node->users())
{
node_depencency_count[user] -= 1;
size_t count = node_depencency_count[user];
if (count == 0)
{
independent_nodes.push_back(user);
}
}
}
return result_list;
}
void ngraph::NodeMap::add(std::shared_ptr<ngraph::Node> orig,
std::shared_ptr<ngraph::Node> replacement)
{
if (exists(orig))
{
throw ngraph_error("NodeMap: key already exists");
}
m_node_map[orig] = replacement;
}
std::shared_ptr<ngraph::Node> ngraph::NodeMap::get(std::shared_ptr<ngraph::Node> orig) const
{
if (!exists(orig))
{
throw ngraph_error("NodeMap: key does not exist");
}
return m_node_map.at(orig);
}
std::list<std::shared_ptr<ngraph::Node>>
ngraph::clone_nodes(const std::list<std::shared_ptr<ngraph::Node>>& nodes, NodeMap& node_map)
{
// for each node in topological order
auto sorted_nodes = topological_sort(nodes);
for (auto node : sorted_nodes)
{
if (!node_map.exists(node))
{
// get (already) cloned arguments and clone the node
Nodes cloned_args;
for (auto arg : node->get_input_ops())
{
cloned_args.push_back(node_map.get(arg));
}
node_map.add(node, node->copy_with_new_args(cloned_args));
}
}
// create and return list of cloned nodes
// order matches input list (not necessarily topological)
std::list<std::shared_ptr<ngraph::Node>> cloned_nodes;
for (auto node : nodes)
{
cloned_nodes.push_back(node_map.get(node));
}
return cloned_nodes;
}
std::shared_ptr<ngraph::Function> ngraph::clone_function(std::shared_ptr<ngraph::Function> func,
NodeMap& node_map)
{
// clone function operations
clone_nodes(func->get_ops(), node_map);
// get cloned function result and parameters
auto cloned_result = node_map.get(func->get_result());
std::vector<std::shared_ptr<op::Parameter>> cloned_params;
for (auto param : func->get_parameters())
{
cloned_params.push_back(std::dynamic_pointer_cast<op::Parameter>(node_map.get(param)));
}
// create and return cloned function
return std::make_shared<ngraph::Function>(cloned_result, cloned_params);
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <algorithm>
#include <chrono>
#include <deque>
#include <functional>
#include <iostream>
#include <list>
#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
namespace ngraph
{
class Node;
class Function;
void traverse_nodes(const std::shared_ptr<const Function> p,
std::function<void(std::shared_ptr<Node>)> f);
void traverse_nodes(const Function* p, std::function<void(std::shared_ptr<Node>)> f);
void traverse_functions(std::shared_ptr<Function> p,
std::function<void(std::shared_ptr<Function>)> f);
void free_nodes(std::shared_ptr<Function>);
void replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> replacement);
void replace_node_users_arguments(std::shared_ptr<Node> target,
std::shared_ptr<Node> replacement);
std::list<std::shared_ptr<Node>>
topological_sort(const std::list<std::shared_ptr<Node>>& nodes);
// maps original to replacement nodes e.g. for clone utilities
// performs index checking on access
class NodeMap
{
public:
// map original node to replcacement node
// throws ngraph_error if key already exists
void add(std::shared_ptr<ngraph::Node> orig, std::shared_ptr<ngraph::Node> replacement);
// get replacement node from original node
// throws ngrah_error if key does not exist
std::shared_ptr<ngraph::Node> get(std::shared_ptr<ngraph::Node> orig) const;
// returns true if original node is already mapped
bool exists(std::shared_ptr<ngraph::Node> orig) const
{
return (m_node_map.count(orig) != 0);
}
const std::unordered_map<std::shared_ptr<ngraph::Node>, std::shared_ptr<ngraph::Node>>&
get_node_map() const
{
return m_node_map;
}
std::unordered_map<std::shared_ptr<ngraph::Node>, std::shared_ptr<ngraph::Node>>&
get_node_map()
{
return m_node_map;
}
private:
std::unordered_map<std::shared_ptr<ngraph::Node>, std::shared_ptr<ngraph::Node>> m_node_map;
};
// input nodes are cloned and returned
// NodeMap input may contain default node mapping i.e. pre-cloned nodes
// NodeMap output (by reference) fully maps input and cloned nodes
std::list<std::shared_ptr<ngraph::Node>>
clone_nodes(const std::list<std::shared_ptr<ngraph::Node>>& nodes, NodeMap& node_map);
// input function is cloned and returned
// NodeMap input may contain default node mapping i.e. pre-cloned nodes
// NodeMap output (by reference) fully maps input and cloned function ops
std::shared_ptr<ngraph::Function> clone_function(std::shared_ptr<ngraph::Function> func,
NodeMap& node_map);
}
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include "ngraph/descriptor/input.hpp" #include "ngraph/descriptor/input.hpp"
#include "ngraph/descriptor/output.hpp" #include "ngraph/descriptor/output.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp" #include "ngraph/log.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/pass/liveness.hpp" #include "ngraph/pass/liveness.hpp"
......
...@@ -16,12 +16,12 @@ ...@@ -16,12 +16,12 @@
#include <memory> #include <memory>
#include "ngraph/function.hpp" #include "ngraph/function.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/ops/function_call.hpp" #include "ngraph/ops/function_call.hpp"
#include "ngraph/ops/reduce.hpp" #include "ngraph/ops/reduce.hpp"
#include "ngraph/pass/manager.hpp" #include "ngraph/pass/manager.hpp"
#include "ngraph/pass/pass.hpp" #include "ngraph/pass/pass.hpp"
#include "ngraph/util.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include "memory_visualize.hpp" #include "memory_visualize.hpp"
#include "ngraph/descriptor/tensor.hpp" #include "ngraph/descriptor/tensor.hpp"
#include "ngraph/function.hpp" #include "ngraph/function.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
......
...@@ -16,11 +16,11 @@ ...@@ -16,11 +16,11 @@
#include <unordered_map> #include <unordered_map>
#include "ngraph/function.hpp" #include "ngraph/function.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp" #include "ngraph/log.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/pass/manager.hpp" #include "ngraph/pass/manager.hpp"
#include "ngraph/pass/topological_sort.hpp" #include "ngraph/pass/topological_sort.hpp"
#include "ngraph/util.hpp"
using namespace ngraph; using namespace ngraph;
using namespace std; using namespace std;
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include <fstream> #include <fstream>
#include "ngraph/function.hpp" #include "ngraph/function.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/pass/pass.hpp" #include "ngraph/pass/pass.hpp"
#include "ngraph/pass/visualize_tree.hpp" #include "ngraph/pass/visualize_tree.hpp"
......
...@@ -29,6 +29,7 @@ ...@@ -29,6 +29,7 @@
#include "ngraph/descriptor/output.hpp" #include "ngraph/descriptor/output.hpp"
#include "ngraph/file_util.hpp" #include "ngraph/file_util.hpp"
#include "ngraph/function.hpp" #include "ngraph/function.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/ops/abs.hpp" #include "ngraph/ops/abs.hpp"
#include "ngraph/ops/acos.hpp" #include "ngraph/ops/acos.hpp"
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include <vector> #include <vector>
#include "ngraph/function.hpp" #include "ngraph/function.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/ops/broadcast.hpp" #include "ngraph/ops/broadcast.hpp"
#include "ngraph/ops/concatenate.hpp" #include "ngraph/ops/concatenate.hpp"
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#include "ngraph/serializer.hpp" #include "ngraph/serializer.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/ops/abs.hpp" #include "ngraph/ops/abs.hpp"
#include "ngraph/ops/acos.hpp" #include "ngraph/ops/acos.hpp"
#include "ngraph/ops/add.hpp" #include "ngraph/ops/add.hpp"
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include <unordered_set> #include <unordered_set>
#include "ngraph/function.hpp" #include "ngraph/function.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp" #include "ngraph/log.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/runtime/backend.hpp" #include "ngraph/runtime/backend.hpp"
...@@ -138,221 +139,6 @@ size_t ngraph::hash_combine(const std::vector<size_t>& list) ...@@ -138,221 +139,6 @@ size_t ngraph::hash_combine(const std::vector<size_t>& list)
return seed; return seed;
} }
void ngraph::traverse_nodes(std::shared_ptr<ngraph::Function> p,
std::function<void(shared_ptr<Node>)> f)
{
traverse_nodes(p.get(), f);
}
void ngraph::traverse_nodes(ngraph::Function* p, std::function<void(shared_ptr<Node>)> f)
{
std::unordered_set<shared_ptr<Node>> instances_seen;
deque<shared_ptr<Node>> stack;
for (size_t i = 0; i < p->get_output_size(); ++i)
{
stack.push_front(p->get_output_op(i));
}
for (auto param : p->get_parameters())
{
stack.push_front(param);
}
while (stack.size() > 0)
{
shared_ptr<Node> n = stack.front();
if (instances_seen.count(n) == 0)
{
instances_seen.insert(n);
f(n);
}
stack.pop_front();
for (auto arg : n->get_input_ops())
{
if (instances_seen.count(arg) == 0)
{
stack.push_front(arg);
}
}
}
}
void ngraph::traverse_functions(std::shared_ptr<ngraph::Function> p,
std::function<void(shared_ptr<Function>)> f)
{
std::unordered_set<shared_ptr<Function>> instances_seen;
deque<shared_ptr<Function>> stack;
stack.push_front(p);
while (stack.size() > 0)
{
shared_ptr<Function> func = stack.front();
if (instances_seen.find(func) == instances_seen.end())
{
instances_seen.insert(func);
f(func);
}
stack.pop_front();
for (shared_ptr<Node> op : func->get_ops())
{
shared_ptr<Function> fp = op->get_function();
if (fp)
{
stack.push_front(fp);
}
}
}
}
void ngraph::free_nodes(shared_ptr<Function> p)
{
std::deque<Node*> sorted_list;
traverse_nodes(p, [&](shared_ptr<Node> n) { sorted_list.push_front(n.get()); });
for (Node* n : sorted_list)
{
n->clear_arguments();
}
}
void ngraph::replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> replacement)
{
if (target->is_output()) //this restriction can be lifted when we find an use case for it
{
return;
}
//fix input/output descriptors
NGRAPH_DEBUG << "Replacing target = " << target << " , " << target->get_name() << " , "
<< "replacement = " << replacement << " , " << replacement->get_name();
assert(target->get_output_size() == replacement->get_output_size());
for (size_t i = 0; i < target->get_output_size(); i++)
{
std::set<ngraph::descriptor::Input*> copy_inputs{
begin(target->get_output_inputs(i)),
end(target->get_output_inputs(i))}; //replace_output modifies target_output->m_inputs
for (auto input : copy_inputs)
{
input->replace_output(replacement, i);
}
}
//fix users and arguments
replace_node_users_arguments(target, replacement);
}
void ngraph::replace_node_users_arguments(std::shared_ptr<Node> target,
std::shared_ptr<Node> replacement)
{
NGRAPH_DEBUG << "Replacing target = " << target << " , " << target->get_name() << " , "
<< "replacement = " << replacement << " , " << replacement->get_name();
NGRAPH_DEBUG << "user = " << replacement << " , " << replacement->get_name();
for (auto user : target->users())
{
auto& args = const_cast<ngraph::Nodes&>(user->get_arguments_FOR_GRAPH_REWRITE_ONLY());
auto it = std::find(begin(args), end(args), target);
assert(it != end(args));
//NGRAPH_DEBUG << "Replaced " << *it << " w/ " << replacement << " in args of " << user << " , args = " << &args;
it = args.erase(it);
args.insert(it, replacement);
const_cast<std::multiset<Node*>&>(replacement->users()).insert(user);
}
const_cast<std::multiset<Node*>&>(target->users()).clear();
}
std::list<std::shared_ptr<ngraph::Node>>
ngraph::topological_sort(const std::list<std::shared_ptr<Node>>& nodes)
{
deque<ngraph::Node*> independent_nodes;
unordered_map<const ngraph::Node*, size_t> node_depencency_count;
unordered_map<ngraph::Node*, shared_ptr<ngraph::Node>> node_map;
for (auto node : nodes)
{
node_map[node.get()] = node;
node_depencency_count[node.get()] = node->get_input_ops().size();
if (node->get_input_ops().size() == 0)
{
independent_nodes.push_back(node.get());
}
}
list<shared_ptr<ngraph::Node>> result_list;
while (independent_nodes.size() > 0)
{
auto independent_node = independent_nodes.front();
result_list.push_back(node_map[independent_node]);
independent_nodes.pop_front();
for (auto user : independent_node->users())
{
node_depencency_count[user] -= 1;
size_t count = node_depencency_count[user];
if (count == 0)
{
independent_nodes.push_back(user);
}
}
}
return result_list;
}
std::list<std::shared_ptr<ngraph::Node>>
ngraph::clone_nodes(const std::list<std::shared_ptr<ngraph::Node>>& nodes, NodeMap& node_map)
{
// for each node in topological order
auto sorted_nodes = topological_sort(nodes);
for (auto node : sorted_nodes)
{
if (node_map.count(node) == 0)
{
// get (already) cloned arguments and clone the node
Nodes cloned_args;
for (auto arg : node->get_input_ops())
{
cloned_args.push_back(node_map[arg]);
}
node_map[node] = node->copy_with_new_args(cloned_args);
}
}
// create and return list of cloned nodes
// order matches input list (not necessarily topological)
std::list<std::shared_ptr<ngraph::Node>> cloned_nodes;
for (auto node : nodes)
{
cloned_nodes.push_back(node_map[node]);
}
return cloned_nodes;
}
std::shared_ptr<ngraph::Function> ngraph::clone_function(std::shared_ptr<ngraph::Function> func,
NodeMap& node_map)
{
// clone function operations
clone_nodes(func->get_ops(), node_map);
// get cloned function result and parameters
Nodes cloned_results;
for (size_t i = 0; i < func->get_output_size(); ++i)
{
cloned_results.push_back(node_map[func->get_output_op(i)]);
}
std::vector<std::shared_ptr<op::Parameter>> cloned_params;
for (auto param : func->get_parameters())
{
cloned_params.push_back(std::dynamic_pointer_cast<op::Parameter>(node_map[param]));
}
// create and return cloned function
return std::make_shared<ngraph::Function>(cloned_results, cloned_params);
}
void* ngraph::aligned_alloc(size_t alignment, size_t size) void* ngraph::aligned_alloc(size_t alignment, size_t size)
{ {
#ifdef __APPLE__ #ifdef __APPLE__
...@@ -397,8 +183,8 @@ ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop, ...@@ -397,8 +183,8 @@ ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop,
// shape and element type as the nodes in fprop // shape and element type as the nodes in fprop
NodeMap node_param_map; NodeMap node_param_map;
ngraph::traverse_nodes(fprop, [&node_param_map](std::shared_ptr<Node> node) { ngraph::traverse_nodes(fprop, [&node_param_map](std::shared_ptr<Node> node) {
node_param_map[node] = node_param_map.get(
std::make_shared<op::Parameter>(node->get_element_type(), node->get_shape()); std::make_shared<op::Parameter>(node->get_element_type(), node->get_shape()));
}); });
// Traverse bprop to find all of the nodes in the graph // Traverse bprop to find all of the nodes in the graph
...@@ -425,7 +211,7 @@ ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop, ...@@ -425,7 +211,7 @@ ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop,
// and store those nodes that aren't needed in bprop // and store those nodes that aren't needed in bprop
FpropCache fprop_cache; FpropCache fprop_cache;
std::vector<std::shared_ptr<Node>> unused_nodes; std::vector<std::shared_ptr<Node>> unused_nodes;
for (auto kv : node_param_map) for (auto kv : node_param_map.get_node_map())
{ {
// if it's not in bprop, mark it unused // if it's not in bprop, mark it unused
if (in_bprop.count(kv.first) == 0) if (in_bprop.count(kv.first) == 0)
...@@ -442,7 +228,7 @@ ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop, ...@@ -442,7 +228,7 @@ ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop,
// erase all unused nodes form the map // erase all unused nodes form the map
for (auto node : unused_nodes) for (auto node : unused_nodes)
{ {
node_param_map.erase(node); node_param_map.get_node_map().erase(node);
} }
// create the new outputs for fprop and the new fprop function // create the new outputs for fprop and the new fprop function
...@@ -461,7 +247,7 @@ ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop, ...@@ -461,7 +247,7 @@ ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop,
Nodes cloned_results; Nodes cloned_results;
for (auto node : bprop->get_results()) for (auto node : bprop->get_results())
{ {
cloned_results.push_back(node_param_map[node]); cloned_results.push_back(node_param_map.get(node));
} }
// get clone bprop parameters // get clone bprop parameters
...@@ -469,13 +255,14 @@ ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop, ...@@ -469,13 +255,14 @@ ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop,
for (auto param : adjoints) for (auto param : adjoints)
{ {
bprop_input_params.push_back( bprop_input_params.push_back(
std::dynamic_pointer_cast<op::Parameter>(node_param_map[param])); std::dynamic_pointer_cast<op::Parameter>(node_param_map.get(param)));
} }
// add the cached fprop nodes as inputs to bprop // add the cached fprop nodes as inputs to bprop
for (auto x : fprop_cache.fprop_output_nodes) for (auto x : fprop_cache.fprop_output_nodes)
{ {
bprop_input_params.push_back(std::dynamic_pointer_cast<op::Parameter>(node_param_map[x])); bprop_input_params.push_back(
std::dynamic_pointer_cast<op::Parameter>(node_param_map.get(x)));
} }
// create the new bprop function // create the new bprop function
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include <algorithm> #include <algorithm>
#include <chrono> #include <chrono>
#include <deque>
#include <functional> #include <functional>
#include <iostream> #include <iostream>
#include <list> #include <list>
...@@ -24,6 +25,7 @@ ...@@ -24,6 +25,7 @@
#include <sstream> #include <sstream>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <unordered_set>
#include <vector> #include <vector>
namespace ngraph namespace ngraph
...@@ -235,36 +237,6 @@ namespace ngraph ...@@ -235,36 +237,6 @@ namespace ngraph
return (x == 0 ? 0 : (1 + (x - 1) / y)); return (x == 0 ? 0 : (1 + (x - 1) / y));
} }
void traverse_nodes(Function* p, std::function<void(std::shared_ptr<Node>)> f);
void traverse_nodes(std::shared_ptr<Function> p, std::function<void(std::shared_ptr<Node>)> f);
void traverse_functions(std::shared_ptr<Function> p,
std::function<void(std::shared_ptr<Function>)> f);
void free_nodes(std::shared_ptr<Function>);
void replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> replacement);
void replace_node_users_arguments(std::shared_ptr<Node> target,
std::shared_ptr<Node> replacement);
std::list<std::shared_ptr<Node>>
topological_sort(const std::list<std::shared_ptr<Node>>& nodes);
using NodeMap =
std::unordered_map<std::shared_ptr<ngraph::Node>, std::shared_ptr<ngraph::Node>>;
// input nodes are cloned and returned
// NodeMap input may contain default node mapping i.e. pre-cloned nodes
// NodeMap output (by reference) fully maps input and cloned nodes
std::list<std::shared_ptr<ngraph::Node>>
clone_nodes(const std::list<std::shared_ptr<ngraph::Node>>& nodes, NodeMap& node_map);
// input function is cloned and returned
// NodeMap input may contain default node mapping i.e. pre-cloned nodes
// NodeMap output (by reference) fully maps input and cloned function ops
std::shared_ptr<ngraph::Function> clone_function(std::shared_ptr<ngraph::Function> func,
NodeMap& node_map);
void* aligned_alloc(size_t alignment, size_t size); void* aligned_alloc(size_t alignment, size_t size);
void aligned_free(void*); void aligned_free(void*);
size_t round_up(size_t size, size_t alignment); size_t round_up(size_t size, size_t alignment);
......
...@@ -132,6 +132,7 @@ TEST(build_graph, function_undeclared_parameters) ...@@ -132,6 +132,7 @@ TEST(build_graph, function_undeclared_parameters)
try try
{ {
auto f = make_shared<Function>(dot, op::Parameters{arg0, arg1, arg3}); auto f = make_shared<Function>(dot, op::Parameters{arg0, arg1, arg3});
f->get_ops();
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Undeclared parameter not detected."; FAIL() << "Undeclared parameter not detected.";
} }
......
...@@ -19,10 +19,10 @@ ...@@ -19,10 +19,10 @@
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "ngraph/graph_util.hpp"
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
#include "ngraph/pass/manager.hpp" #include "ngraph/pass/manager.hpp"
#include "ngraph/pass/topological_sort.hpp" #include "ngraph/pass/topological_sort.hpp"
#include "ngraph/util.hpp"
#include "util/test_tools.hpp" #include "util/test_tools.hpp"
using namespace ngraph; using namespace ngraph;
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include <memory> #include <memory>
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp" #include "ngraph/log.hpp"
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
#include "ngraph/pass/graph_rewrite.hpp" #include "ngraph/pass/graph_rewrite.hpp"
...@@ -27,7 +28,6 @@ ...@@ -27,7 +28,6 @@
#include "ngraph/pattern/matcher.hpp" #include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/any.hpp" #include "ngraph/pattern/op/any.hpp"
#include "ngraph/pattern/op/label.hpp" #include "ngraph/pattern/op/label.hpp"
#include "ngraph/util.hpp"
using namespace ngraph; using namespace ngraph;
using namespace std; using namespace std;
......
...@@ -19,13 +19,13 @@ ...@@ -19,13 +19,13 @@
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp" #include "ngraph/log.hpp"
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
#include "ngraph/pass/dump_sorted.hpp" #include "ngraph/pass/dump_sorted.hpp"
#include "ngraph/pass/manager.hpp" #include "ngraph/pass/manager.hpp"
#include "ngraph/pass/topological_sort.hpp" #include "ngraph/pass/topological_sort.hpp"
#include "ngraph/pass/visualize_tree.hpp" #include "ngraph/pass/visualize_tree.hpp"
#include "ngraph/util.hpp"
#include "util/test_tools.hpp" #include "util/test_tools.hpp"
using namespace std; using namespace std;
......
...@@ -19,8 +19,8 @@ ...@@ -19,8 +19,8 @@
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "ngraph/function.hpp" #include "ngraph/function.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
#include "ngraph/util.hpp"
#include "util/all_close.hpp" #include "util/all_close.hpp"
#include "util/ndarray.hpp" #include "util/ndarray.hpp"
...@@ -274,7 +274,7 @@ public: ...@@ -274,7 +274,7 @@ public:
auto cloneit = clone.begin(); auto cloneit = clone.begin();
while (origit != orig.end() && cloneit != clone.end()) while (origit != orig.end() && cloneit != clone.end())
{ {
if (*cloneit != nm.at(*origit)) if (*cloneit != nm.get_node_map().at(*origit))
{ {
return false; return false;
} }
...@@ -290,11 +290,11 @@ TEST_F(CloneTest, clone_nodes_full) ...@@ -290,11 +290,11 @@ TEST_F(CloneTest, clone_nodes_full)
auto cloned_nodes = clone_nodes(nodes, node_map); auto cloned_nodes = clone_nodes(nodes, node_map);
ASSERT_TRUE(CompareNodes(nodes, cloned_nodes, node_map)); ASSERT_TRUE(CompareNodes(nodes, cloned_nodes, node_map));
ASSERT_NE(nullptr, std::dynamic_pointer_cast<op::Parameter>(node_map[A])); ASSERT_NE(nullptr, std::dynamic_pointer_cast<op::Parameter>(node_map.get(A)));
ASSERT_NE(nullptr, std::dynamic_pointer_cast<op::Parameter>(node_map[B])); ASSERT_NE(nullptr, std::dynamic_pointer_cast<op::Parameter>(node_map.get(B)));
ASSERT_NE(nullptr, std::dynamic_pointer_cast<op::Parameter>(node_map[C])); ASSERT_NE(nullptr, std::dynamic_pointer_cast<op::Parameter>(node_map.get(C)));
ASSERT_NE(nullptr, std::dynamic_pointer_cast<op::Add>(node_map[AplusB])); ASSERT_NE(nullptr, std::dynamic_pointer_cast<op::Add>(node_map.get(AplusB)));
ASSERT_NE(nullptr, std::dynamic_pointer_cast<op::Multiply>(node_map[AplusBtimesC])); ASSERT_NE(nullptr, std::dynamic_pointer_cast<op::Multiply>(node_map.get(AplusBtimesC)));
auto sorted_nodes = topological_sort(nodes); auto sorted_nodes = topological_sort(nodes);
auto sorted_cloned_nodes = topological_sort(cloned_nodes); auto sorted_cloned_nodes = topological_sort(cloned_nodes);
...@@ -305,13 +305,13 @@ TEST_F(CloneTest, clone_nodes_partial) ...@@ -305,13 +305,13 @@ TEST_F(CloneTest, clone_nodes_partial)
{ {
// map A -> A' prior to clone // map A -> A' prior to clone
auto Aprime = make_shared<op::Parameter>(element::f32, shape); auto Aprime = make_shared<op::Parameter>(element::f32, shape);
node_map[A] = Aprime; node_map.add(A, Aprime);
auto cloned_nodes = clone_nodes(nodes, node_map); auto cloned_nodes = clone_nodes(nodes, node_map);
ASSERT_TRUE(CompareNodes(nodes, cloned_nodes, node_map)); ASSERT_TRUE(CompareNodes(nodes, cloned_nodes, node_map));
// ensure A -> A' after clone // ensure A -> A' after clone
ASSERT_EQ(Aprime, node_map[A]); ASSERT_EQ(Aprime, node_map.get(A));
} }
TEST_F(CloneTest, clone_function_full) TEST_F(CloneTest, clone_function_full)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment