Commit 58f9af01 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Robert Kimball

Control dependencies (#1445)

* topological sort with cdeps

* add control deps API, fix unit tests

* rollback adjoints changes

* fix test failures,add more tests

* remove dead code

* address scott's feedback
parent 68eb2e7d
......@@ -85,24 +85,23 @@ void Function::init()
{
validate_nodes_and_infer_types();
traverse_nodes(this, [&](shared_ptr<Node> node) {
std::shared_ptr<op::Parameter> p = std::dynamic_pointer_cast<op::Parameter>(node);
if (nullptr != p)
{
auto it = std::find_if(m_parameters.begin(),
m_parameters.end(),
[p](std::shared_ptr<op::Parameter> q) { return (p == q); });
if (it == m_parameters.end())
{
throw ngraph_error("Function references undeclared parameter");
}
}
});
traverse_nodes(this,
[&](shared_ptr<Node> node) {
if (node->is_parameter())
{
auto it = std::find(m_parameters.begin(), m_parameters.end(), node);
if (it == m_parameters.end())
{
throw ngraph_error("Function references undeclared parameter");
}
}
},
true /*include control dependencies*/);
}
std::list<shared_ptr<Node>> Function::get_ordered_ops()
std::list<shared_ptr<Node>> Function::get_ordered_ops(bool include_control_deps) const
{
return topological_sort(get_ops());
return topological_sort(get_ops(include_control_deps), include_control_deps);
}
const std::string& Function::get_friendly_name() const
......@@ -176,10 +175,10 @@ shared_ptr<Node> Function::get_result() const
return m_results.at(0);
}
std::list<shared_ptr<Node>> Function::get_ops() const
std::list<shared_ptr<Node>> Function::get_ops(bool include_control_deps) const
{
std::list<std::shared_ptr<Node>> ops;
traverse_nodes(this, [&](shared_ptr<Node> node) { ops.push_back(node); });
traverse_nodes(this, [&](shared_ptr<Node> node) { ops.push_back(node); }, include_control_deps);
return ops;
}
......
......@@ -73,8 +73,8 @@ namespace ngraph
// so we can use `dynamic_cast` in FunctionCall to double check if we are dealing with
// an XLA or regular function
void set_name(const std::string& name);
std::list<std::shared_ptr<Node>> get_ops() const;
std::list<std::shared_ptr<Node>> get_ordered_ops();
std::list<std::shared_ptr<Node>> get_ops(bool include_control_deps = true) const;
std::list<std::shared_ptr<Node>> get_ordered_ops(bool include_control_deps = true) const;
friend std::ostream& operator<<(std::ostream&, const Function&);
size_t get_instance_id() { return m_instance_id; }
size_t get_temporary_pool_size();
......
......@@ -39,12 +39,15 @@ using namespace std;
using namespace ngraph;
void ngraph::traverse_nodes(const std::shared_ptr<const Function> p,
std::function<void(std::shared_ptr<Node>)> f)
std::function<void(std::shared_ptr<Node>)> f,
bool include_control_deps)
{
traverse_nodes(p.get(), f);
traverse_nodes(p.get(), f, include_control_deps);
}
void ngraph::traverse_nodes(const Function* p, std::function<void(std::shared_ptr<Node>)> f)
void ngraph::traverse_nodes(const Function* p,
std::function<void(std::shared_ptr<Node>)> f,
bool include_control_deps)
{
NodeVector nodes;
......@@ -58,14 +61,15 @@ void ngraph::traverse_nodes(const Function* p, std::function<void(std::shared_pt
nodes.push_back(param);
}
traverse_nodes(nodes, f);
traverse_nodes(nodes, f, include_control_deps);
}
// This version of traverses directly from input/output nodes to perform functions on
// graphs that are not wrapped by functions. Most useful for finding parameters of a graph
// directly from the result nodes, not from function parameters.
void ngraph::traverse_nodes(const NodeVector& io_nodes,
std::function<void(std::shared_ptr<Node>)> f)
std::function<void(std::shared_ptr<Node>)> f,
bool include_control_deps)
{
std::unordered_set<std::shared_ptr<Node>> instances_seen;
std::deque<std::shared_ptr<Node>> stack;
......@@ -91,6 +95,17 @@ void ngraph::traverse_nodes(const NodeVector& io_nodes,
stack.push_front(arg);
}
}
if (include_control_deps)
{
for (auto cdep : n->get_control_dependencies())
{
if (instances_seen.count(cdep) == 0)
{
stack.push_front(cdep);
}
}
}
}
}
......@@ -213,7 +228,7 @@ std::list<std::shared_ptr<ngraph::Node>>
ngraph::clone_nodes(const std::list<std::shared_ptr<ngraph::Node>>& nodes, NodeMap& node_map)
{
// for each node in topological order
auto sorted_nodes = topological_sort(nodes);
auto sorted_nodes = topological_sort(nodes, true);
for (auto node : sorted_nodes)
{
if (!node_map.exists(node))
......@@ -224,7 +239,14 @@ std::list<std::shared_ptr<ngraph::Node>>
{
cloned_args.push_back(node_map.get(arg));
}
node_map.add(node, node->copy_with_new_args(cloned_args));
auto cloned_node = node->copy_with_new_args(cloned_args);
//copy control dependencies
for (auto cdep : node->get_control_dependencies())
{
cloned_node->add_control_dependency(node_map.get(cdep));
}
node_map.add(node, cloned_node);
}
}
......@@ -248,7 +270,7 @@ std::shared_ptr<ngraph::Function> ngraph::clone_function(const ngraph::Function&
NodeMap& node_map)
{
// clone function operations
clone_nodes(func.get_ops(), node_map);
clone_nodes(func.get_ops(true), node_map);
// get cloned function results and parameters
ResultVector cloned_results;
......
......@@ -43,10 +43,15 @@ namespace ngraph
}
void traverse_nodes(const std::shared_ptr<const Function> p,
std::function<void(std::shared_ptr<Node>)> f);
void traverse_nodes(const Function* p, std::function<void(std::shared_ptr<Node>)> f);
std::function<void(std::shared_ptr<Node>)> f,
bool include_control_deps = false);
void traverse_nodes(const Function* p,
std::function<void(std::shared_ptr<Node>)> f,
bool include_control_deps);
void traverse_nodes(const NodeVector& io_nodes, std::function<void(std::shared_ptr<Node>)> f);
void traverse_nodes(const NodeVector& io_nodes,
std::function<void(std::shared_ptr<Node>)> f,
bool include_control_deps);
void traverse_functions(std::shared_ptr<Function> p,
std::function<void(std::shared_ptr<Function>)> f);
......@@ -54,17 +59,31 @@ namespace ngraph
void replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> replacement);
template <typename T>
std::list<std::shared_ptr<Node>> topological_sort(const T& nodes)
std::list<std::shared_ptr<Node>> topological_sort(const T& nodes,
bool include_control_deps = false)
{
std::deque<ngraph::Node*> independent_nodes;
std::unordered_map<const ngraph::Node*, size_t> node_dependency_count;
std::unordered_map<ngraph::Node*, std::shared_ptr<ngraph::Node>> node_map;
std::unordered_map<ngraph::Node*, std::set<Node*>> control_deps_users;
for (auto node : nodes)
{
//build an equivalent of node->get_users() but for control dependencies
size_t control_deps_count = 0;
if (include_control_deps)
{
for (auto cd : node->get_control_dependencies())
{
control_deps_users[cd.get()].insert(node.get());
}
control_deps_count = node->get_control_dependencies().size();
}
node_map[node.get()] = node;
node_dependency_count[node.get()] = node->get_arguments().size();
if (node->get_arguments().size() == 0)
size_t deps_count = node->get_arguments().size() + control_deps_count;
node_dependency_count[node.get()] = deps_count;
if (deps_count == 0)
{
independent_nodes.push_back(node.get());
}
......@@ -87,6 +106,21 @@ namespace ngraph
independent_nodes.push_back(user);
}
}
if (include_control_deps)
{
auto cdit = control_deps_users.find(independent_node);
if (cdit != control_deps_users.end())
for (auto cd_user : cdit->second)
{
node_dependency_count[cd_user] -= 1;
size_t count = node_dependency_count[cd_user];
if (count == 0)
{
independent_nodes.push_back(cd_user);
}
}
}
}
return result_list;
......
......@@ -182,6 +182,16 @@ NodeVector Node::get_arguments() const
return result;
}
const std::set<std::shared_ptr<Node>>& Node::get_control_dependencies() const
{
return m_control_dependencies;
}
void Node::add_control_dependency(std::shared_ptr<Node> node)
{
m_control_dependencies.insert(node);
}
std::vector<std::shared_ptr<Function>> Node::get_functions() const
{
return std::vector<std::shared_ptr<Function>>{};
......
......@@ -144,6 +144,16 @@ namespace ngraph
// TODO: Remove from unit tests.
const std::deque<descriptor::Output>& get_outputs() const;
/// Get control dependencies registered on the node
const std::set<std::shared_ptr<Node>>& get_control_dependencies() const;
void add_control_dependency(std::shared_ptr<Node> node);
void remove_control_dependency(std::shared_ptr<Node> node)
{
m_control_dependencies.erase(node);
}
/// Returns the number of outputs on the for the node.
size_t get_output_size() const;
......@@ -216,6 +226,7 @@ namespace ngraph
/// Use instance ids for comparison instead of memory addresses to improve determinism
bool operator<(const Node& other) const { return m_instance_id < other.m_instance_id; }
protected:
std::set<std::shared_ptr<Node>> m_control_dependencies;
void set_output_size(size_t n);
std::string m_node_type;
......
......@@ -167,14 +167,17 @@ void ngraph::serialize(ostream& out, shared_ptr<ngraph::Function> func, size_t i
writer.write(func->get_name(), j.c_str(), static_cast<uint32_t>(j.size()));
traverse_functions(func, [&](shared_ptr<ngraph::Function> f) {
traverse_nodes(const_cast<Function*>(f.get()), [&](shared_ptr<Node> node) {
if (auto c = dynamic_pointer_cast<op::Constant>(node))
{
uint32_t size = static_cast<uint32_t>(shape_size(c->get_output_shape(0)) *
c->get_output_element_type(0).size());
writer.write(c->get_name(), c->get_data_ptr(), size);
}
});
traverse_nodes(const_cast<Function*>(f.get()),
[&](shared_ptr<Node> node) {
if (auto c = dynamic_pointer_cast<op::Constant>(node))
{
uint32_t size =
static_cast<uint32_t>(shape_size(c->get_output_shape(0)) *
c->get_output_element_type(0).size());
writer.write(c->get_name(), c->get_data_ptr(), size);
}
},
true);
});
writer.close();
......@@ -301,45 +304,13 @@ static json write(const Function& f, bool binary_constant_data)
function["result"].push_back(f.get_output_op(i)->get_name());
}
list<shared_ptr<Node>> result_list;
{
deque<Node*> independent_nodes;
unordered_map<const Node*, size_t> node_depencency_count;
unordered_map<Node*, shared_ptr<Node>> node_map;
traverse_nodes(const_cast<Function*>(&f), [&](shared_ptr<Node> node) {
node_map[node.get()] = node;
node_depencency_count[node.get()] = node->get_arguments().size();
if (node->get_arguments().size() == 0)
{
independent_nodes.push_back(node.get());
}
});
while (independent_nodes.size() > 0)
{
auto independent_node = independent_nodes.front();
result_list.push_back(node_map[independent_node]);
independent_nodes.pop_front();
for (auto sp_user : independent_node->get_users())
{
Node* user = sp_user.get();
node_depencency_count[user] -= 1;
size_t count = node_depencency_count[user];
if (count == 0)
{
independent_nodes.push_back(user);
}
}
}
}
Function* pf = const_cast<Function*>(&f);
json nodes;
for (shared_ptr<Node> node : result_list)
for (shared_ptr<Node> node : pf->get_ordered_ops(true))
{
nodes.push_back(write(*node, binary_constant_data));
}
function["ops"] = nodes;
return function;
}
......@@ -362,9 +333,12 @@ static shared_ptr<ngraph::Function>
string node_name = node_js.at("name").get<string>();
string node_op = node_js.at("op").get<string>();
vector<string> node_inputs = node_js.at("inputs").get<vector<string>>();
vector<string> control_deps_inputs =
get_or_default<vector<string>>(node_js, "control_deps", vector<string>{});
vector<string> node_outputs = node_js.at("outputs").get<vector<string>>();
shared_ptr<Node> node;
vector<shared_ptr<Node>> args;
vector<shared_ptr<Node>> control_deps;
for (const string& name : node_inputs)
{
args.push_back(node_map.at(name));
......@@ -934,6 +908,12 @@ static shared_ptr<ngraph::Function>
ss << "unsupported op " << node_op;
throw runtime_error(ss.str());
}
for (const string& name : control_deps_inputs)
{
node->add_control_dependency(node_map.at(name));
}
node_map[node_name] = node;
// Typically, it could be unsafe to change the name of a node since it may break nameing
......@@ -1000,18 +980,24 @@ static json write(const Node& n, bool binary_constant_data)
node["op"] = n.description();
// TODO Multiple outputs
json inputs = json::array();
json control_deps = json::array();
json outputs = json::array();
for (const descriptor::Input& input : n.get_inputs())
{
inputs.push_back(input.get_output().get_node()->get_name());
}
for (auto cdep : n.get_control_dependencies())
{
control_deps.push_back(cdep->get_name());
}
for (size_t i = 0; i < n.get_output_size(); ++i)
{
outputs.push_back(n.get_output_tensor(i).get_name());
}
node["inputs"] = inputs;
node["control_deps"] = control_deps;
node["outputs"] = outputs;
if (std::getenv("NGRAPH_SERIALIZER_OUTPUT_SHAPES") != nullptr)
......
......@@ -205,15 +205,17 @@ ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop,
// Traverse bprop to find all of the nodes in the bprop graph
std::unordered_set<std::shared_ptr<Node>> in_bprop;
ngraph::traverse_nodes(bprop, [&in_bprop](std::shared_ptr<Node> node) {
if (node->get_outputs().size() == 1)
{
if (in_bprop.count(node) == 0)
{
in_bprop.insert(node);
}
}
});
ngraph::traverse_nodes(bprop,
[&in_bprop](std::shared_ptr<Node> node) {
if (node->get_outputs().size() == 1)
{
if (in_bprop.count(node) == 0)
{
in_bprop.insert(node);
}
}
},
false /* no control dependencies */);
// Traverse fprop to make a map that stores parameters with the same
// shape and element type as the nodes in fprop iff they are in bprop
......@@ -290,7 +292,8 @@ ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop,
{
fprop_cache.fprop_output_nodes.push_back(inverted_node_map.at(node));
}
});
},
false /* no control dependencies */);
// create the new outputs for fprop and the new fprop function
ResultVector fprop_outputs = fprop->get_results();
......
......@@ -44,6 +44,8 @@ set(SRC
util.cpp
uuid.cpp
zero_dim_tensor_elimination.cpp
control_dependencies.cpp
)
if (NGRAPH_ONNX_IMPORT_ENABLE)
......
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <algorithm>
#include <cstdio>
#include <iostream>
#include <list>
#include <memory>
#include "gtest/gtest.h"
#include "ngraph/autodiff/adjoints.hpp"
#include "ngraph/file_util.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/op/batch_norm.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/parameter.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/visualize_tree.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/runtime/cpu/pass/cpu_fusion.hpp"
#include "ngraph/serializer.hpp"
#include "ngraph/util.hpp"
#include "nlohmann/json.hpp"
#include "util/all_close.hpp"
#include "util/autodiff/backprop_function.hpp"
#include "util/autodiff/numeric_compare.hpp"
#include "util/ndarray.hpp"
#include "util/random.hpp"
#include "util/test_tools.hpp"
using namespace ngraph;
using namespace std;
class ControlDependencyOp : public ngraph::op::Op
{
public:
virtual std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override
{
auto clone = make_shared<ControlDependencyOp>(new_args, std::set<std::shared_ptr<Node>>{});
return clone;
}
ControlDependencyOp(const NodeVector& args, const std::set<std::shared_ptr<Node>>& deps)
: Op("ControlDependencyOp", args)
{
if (args.size() == 0 && deps.size() == 0)
{
throw ngraph_error("Expected some arguments or dependencies");
}
if (deps.size() != 0)
{
m_control_dependencies.insert(deps.begin(), deps.end());
}
if (args.size() != 0)
{
set_output_type(0, args.at(0)->get_element_type(), args.at(0)->get_shape());
}
else
{
auto dn = *(deps.begin());
set_output_type(0, dn->get_element_type(), dn->get_shape());
}
}
};
TEST(control_dependencies, cdep_ops)
{
auto A = make_shared<op::Parameter>(element::f32, Shape{});
auto B = make_shared<op::Parameter>(element::f32, Shape{});
auto absn = make_shared<op::Abs>(A);
auto cdop =
make_shared<ControlDependencyOp>(NodeVector{A}, std::set<std::shared_ptr<Node>>{absn});
auto f = make_shared<Function>(cdop, op::ParameterVector{A, B});
auto nodes = f->get_ordered_ops(true);
ASSERT_EQ(nodes.back()->get_argument(0), cdop);
}
TEST(control_dependencies, two_cdep_ops)
{
auto A = make_shared<op::Parameter>(element::f32, Shape{});
auto B = make_shared<op::Parameter>(element::f32, Shape{});
auto absn = make_shared<op::Abs>(A);
auto C = make_shared<op::Parameter>(element::f32, Shape{});
auto absn_c = make_shared<op::Abs>(C);
auto cdop = make_shared<ControlDependencyOp>(NodeVector{A},
std::set<std::shared_ptr<Node>>{absn, absn_c});
auto f = make_shared<Function>(cdop, op::ParameterVector{A, B, C});
auto nodes = f->get_ordered_ops(true);
ASSERT_EQ(nodes.back()->get_argument(0), cdop);
}
TEST(control_dependencies, two_cdep_ops_op_on_top)
{
auto A = make_shared<op::Parameter>(element::f32, Shape{});
auto absn = make_shared<op::Abs>(A);
auto B = make_shared<op::Parameter>(element::f32, Shape{});
auto absn_b = make_shared<op::Abs>(B);
auto cdop = make_shared<ControlDependencyOp>(NodeVector{A},
std::set<std::shared_ptr<Node>>{absn, absn_b});
auto absn_cdop = make_shared<op::Abs>(cdop);
auto f = make_shared<Function>(absn_cdop, op::ParameterVector{A, B});
auto nodes = f->get_ordered_ops(true);
ASSERT_EQ(nodes.back()->get_argument(0), absn_cdop);
}
TEST(control_dependencies, clone_function_cdop)
{
auto A = make_shared<op::Parameter>(element::f32, Shape{});
auto absn = make_shared<op::Abs>(A);
auto cdop =
make_shared<ControlDependencyOp>(NodeVector{A}, std::set<std::shared_ptr<Node>>{absn});
auto f = make_shared<Function>(cdop, op::ParameterVector{A});
auto clone = ngraph::clone_function(*f.get());
auto matcher = std::make_shared<pattern::Matcher>(cdop, nullptr);
auto cdop_clone = clone->get_results().at(0)->get_argument(0);
ASSERT_TRUE(matcher->match(cdop_clone));
auto cloned_deps = cdop_clone->get_control_dependencies();
ASSERT_EQ(cloned_deps.size(), 1);
auto cloned_abs = *begin(cloned_deps);
ASSERT_TRUE(std::dynamic_pointer_cast<op::Abs>(cloned_abs));
}
TEST(control_dependencies, clone_function_cdop_abs)
{
auto A = make_shared<op::Parameter>(element::f32, Shape{});
auto absn = make_shared<op::Abs>(A);
auto B = make_shared<op::Parameter>(element::f32, Shape{});
auto absn_b = make_shared<op::Abs>(B);
auto cdop = make_shared<ControlDependencyOp>(NodeVector{A},
std::set<std::shared_ptr<Node>>{absn, absn_b});
auto absn_cdop = make_shared<op::Abs>(cdop);
auto f = make_shared<Function>(absn_cdop, op::ParameterVector{A, B});
auto clone = ngraph::clone_function(*f.get());
auto matcher = std::make_shared<pattern::Matcher>(cdop, nullptr);
auto cdop_clone = clone->get_results().at(0)->get_argument(0)->get_argument(0);
ASSERT_TRUE(matcher->match(cdop_clone));
auto cloned_deps = cdop_clone->get_control_dependencies();
ASSERT_EQ(cloned_deps.size(), 2);
for (auto ccdep : cloned_deps)
{
ASSERT_TRUE(std::dynamic_pointer_cast<op::Abs>(ccdep));
}
}
TEST(control_dependencies, serialize_cdop)
{
auto A = make_shared<op::Parameter>(element::f32, Shape{});
auto absn = make_shared<op::Abs>(A);
auto cdop = make_shared<op::Negative>(A);
cdop->add_control_dependency(absn);
auto f = make_shared<Function>(cdop, op::ParameterVector{A});
string js = serialize(f, 4);
shared_ptr<Function> clone = deserialize(js);
auto matcher = std::make_shared<pattern::Matcher>(cdop, nullptr);
auto cdop_clone = clone->get_results().at(0)->get_argument(0);
ASSERT_TRUE(matcher->match(cdop_clone));
auto cloned_deps = cdop_clone->get_control_dependencies();
ASSERT_EQ(cloned_deps.size(), 1);
auto cloned_abs = *begin(cloned_deps);
ASSERT_TRUE(std::dynamic_pointer_cast<op::Abs>(cloned_abs));
}
TEST(control_dependencies, serialize_cdop_abs)
{
auto A = make_shared<op::Parameter>(element::f32, Shape{});
auto absn = make_shared<op::Abs>(A);
auto B = make_shared<op::Parameter>(element::f32, Shape{});
auto absn_b = make_shared<op::Abs>(B);
auto cdop = make_shared<op::Negative>(A);
cdop->add_control_dependency(absn);
cdop->add_control_dependency(absn_b);
auto absn_cdop = make_shared<op::Abs>(cdop);
auto f = make_shared<Function>(absn_cdop, op::ParameterVector{A, B});
string js = serialize(f, 4);
shared_ptr<Function> clone = deserialize(js);
auto matcher = std::make_shared<pattern::Matcher>(cdop, nullptr);
auto cdop_clone = clone->get_results().at(0)->get_argument(0)->get_argument(0);
ASSERT_TRUE(matcher->match(cdop_clone));
auto cloned_deps = cdop_clone->get_control_dependencies();
ASSERT_EQ(cloned_deps.size(), 2);
for (auto ccdep : cloned_deps)
{
ASSERT_TRUE(std::dynamic_pointer_cast<op::Abs>(ccdep));
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment