Commit 58f9af01 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Robert Kimball

Control dependencies (#1445)

* topological sort with cdeps

* add control deps API, fix unit tests

* rollback adjoints changes

* fix test failures,add more tests

* remove dead code

* address scott's feedback
parent 68eb2e7d
...@@ -85,24 +85,23 @@ void Function::init() ...@@ -85,24 +85,23 @@ void Function::init()
{ {
validate_nodes_and_infer_types(); validate_nodes_and_infer_types();
traverse_nodes(this, [&](shared_ptr<Node> node) { traverse_nodes(this,
std::shared_ptr<op::Parameter> p = std::dynamic_pointer_cast<op::Parameter>(node); [&](shared_ptr<Node> node) {
if (nullptr != p) if (node->is_parameter())
{ {
auto it = std::find_if(m_parameters.begin(), auto it = std::find(m_parameters.begin(), m_parameters.end(), node);
m_parameters.end(), if (it == m_parameters.end())
[p](std::shared_ptr<op::Parameter> q) { return (p == q); }); {
if (it == m_parameters.end()) throw ngraph_error("Function references undeclared parameter");
{ }
throw ngraph_error("Function references undeclared parameter"); }
} },
} true /*include control dependencies*/);
});
} }
std::list<shared_ptr<Node>> Function::get_ordered_ops() std::list<shared_ptr<Node>> Function::get_ordered_ops(bool include_control_deps) const
{ {
return topological_sort(get_ops()); return topological_sort(get_ops(include_control_deps), include_control_deps);
} }
const std::string& Function::get_friendly_name() const const std::string& Function::get_friendly_name() const
...@@ -176,10 +175,10 @@ shared_ptr<Node> Function::get_result() const ...@@ -176,10 +175,10 @@ shared_ptr<Node> Function::get_result() const
return m_results.at(0); return m_results.at(0);
} }
std::list<shared_ptr<Node>> Function::get_ops() const std::list<shared_ptr<Node>> Function::get_ops(bool include_control_deps) const
{ {
std::list<std::shared_ptr<Node>> ops; std::list<std::shared_ptr<Node>> ops;
traverse_nodes(this, [&](shared_ptr<Node> node) { ops.push_back(node); }); traverse_nodes(this, [&](shared_ptr<Node> node) { ops.push_back(node); }, include_control_deps);
return ops; return ops;
} }
......
...@@ -73,8 +73,8 @@ namespace ngraph ...@@ -73,8 +73,8 @@ namespace ngraph
// so we can use `dynamic_cast` in FunctionCall to double check if we are dealing with // so we can use `dynamic_cast` in FunctionCall to double check if we are dealing with
// an XLA or regular function // an XLA or regular function
void set_name(const std::string& name); void set_name(const std::string& name);
std::list<std::shared_ptr<Node>> get_ops() const; std::list<std::shared_ptr<Node>> get_ops(bool include_control_deps = true) const;
std::list<std::shared_ptr<Node>> get_ordered_ops(); std::list<std::shared_ptr<Node>> get_ordered_ops(bool include_control_deps = true) const;
friend std::ostream& operator<<(std::ostream&, const Function&); friend std::ostream& operator<<(std::ostream&, const Function&);
size_t get_instance_id() { return m_instance_id; } size_t get_instance_id() { return m_instance_id; }
size_t get_temporary_pool_size(); size_t get_temporary_pool_size();
......
...@@ -39,12 +39,15 @@ using namespace std; ...@@ -39,12 +39,15 @@ using namespace std;
using namespace ngraph; using namespace ngraph;
void ngraph::traverse_nodes(const std::shared_ptr<const Function> p, void ngraph::traverse_nodes(const std::shared_ptr<const Function> p,
std::function<void(std::shared_ptr<Node>)> f) std::function<void(std::shared_ptr<Node>)> f,
bool include_control_deps)
{ {
traverse_nodes(p.get(), f); traverse_nodes(p.get(), f, include_control_deps);
} }
void ngraph::traverse_nodes(const Function* p, std::function<void(std::shared_ptr<Node>)> f) void ngraph::traverse_nodes(const Function* p,
std::function<void(std::shared_ptr<Node>)> f,
bool include_control_deps)
{ {
NodeVector nodes; NodeVector nodes;
...@@ -58,14 +61,15 @@ void ngraph::traverse_nodes(const Function* p, std::function<void(std::shared_pt ...@@ -58,14 +61,15 @@ void ngraph::traverse_nodes(const Function* p, std::function<void(std::shared_pt
nodes.push_back(param); nodes.push_back(param);
} }
traverse_nodes(nodes, f); traverse_nodes(nodes, f, include_control_deps);
} }
// This version of traverses directly from input/output nodes to perform functions on // This version of traverses directly from input/output nodes to perform functions on
// graphs that are not wrapped by functions. Most useful for finding parameters of a graph // graphs that are not wrapped by functions. Most useful for finding parameters of a graph
// directly from the result nodes, not from function parameters. // directly from the result nodes, not from function parameters.
void ngraph::traverse_nodes(const NodeVector& io_nodes, void ngraph::traverse_nodes(const NodeVector& io_nodes,
std::function<void(std::shared_ptr<Node>)> f) std::function<void(std::shared_ptr<Node>)> f,
bool include_control_deps)
{ {
std::unordered_set<std::shared_ptr<Node>> instances_seen; std::unordered_set<std::shared_ptr<Node>> instances_seen;
std::deque<std::shared_ptr<Node>> stack; std::deque<std::shared_ptr<Node>> stack;
...@@ -91,6 +95,17 @@ void ngraph::traverse_nodes(const NodeVector& io_nodes, ...@@ -91,6 +95,17 @@ void ngraph::traverse_nodes(const NodeVector& io_nodes,
stack.push_front(arg); stack.push_front(arg);
} }
} }
if (include_control_deps)
{
for (auto cdep : n->get_control_dependencies())
{
if (instances_seen.count(cdep) == 0)
{
stack.push_front(cdep);
}
}
}
} }
} }
...@@ -213,7 +228,7 @@ std::list<std::shared_ptr<ngraph::Node>> ...@@ -213,7 +228,7 @@ std::list<std::shared_ptr<ngraph::Node>>
ngraph::clone_nodes(const std::list<std::shared_ptr<ngraph::Node>>& nodes, NodeMap& node_map) ngraph::clone_nodes(const std::list<std::shared_ptr<ngraph::Node>>& nodes, NodeMap& node_map)
{ {
// for each node in topological order // for each node in topological order
auto sorted_nodes = topological_sort(nodes); auto sorted_nodes = topological_sort(nodes, true);
for (auto node : sorted_nodes) for (auto node : sorted_nodes)
{ {
if (!node_map.exists(node)) if (!node_map.exists(node))
...@@ -224,7 +239,14 @@ std::list<std::shared_ptr<ngraph::Node>> ...@@ -224,7 +239,14 @@ std::list<std::shared_ptr<ngraph::Node>>
{ {
cloned_args.push_back(node_map.get(arg)); cloned_args.push_back(node_map.get(arg));
} }
node_map.add(node, node->copy_with_new_args(cloned_args)); auto cloned_node = node->copy_with_new_args(cloned_args);
//copy control dependencies
for (auto cdep : node->get_control_dependencies())
{
cloned_node->add_control_dependency(node_map.get(cdep));
}
node_map.add(node, cloned_node);
} }
} }
...@@ -248,7 +270,7 @@ std::shared_ptr<ngraph::Function> ngraph::clone_function(const ngraph::Function& ...@@ -248,7 +270,7 @@ std::shared_ptr<ngraph::Function> ngraph::clone_function(const ngraph::Function&
NodeMap& node_map) NodeMap& node_map)
{ {
// clone function operations // clone function operations
clone_nodes(func.get_ops(), node_map); clone_nodes(func.get_ops(true), node_map);
// get cloned function results and parameters // get cloned function results and parameters
ResultVector cloned_results; ResultVector cloned_results;
......
...@@ -43,10 +43,15 @@ namespace ngraph ...@@ -43,10 +43,15 @@ namespace ngraph
} }
void traverse_nodes(const std::shared_ptr<const Function> p, void traverse_nodes(const std::shared_ptr<const Function> p,
std::function<void(std::shared_ptr<Node>)> f); std::function<void(std::shared_ptr<Node>)> f,
void traverse_nodes(const Function* p, std::function<void(std::shared_ptr<Node>)> f); bool include_control_deps = false);
void traverse_nodes(const Function* p,
std::function<void(std::shared_ptr<Node>)> f,
bool include_control_deps);
void traverse_nodes(const NodeVector& io_nodes, std::function<void(std::shared_ptr<Node>)> f); void traverse_nodes(const NodeVector& io_nodes,
std::function<void(std::shared_ptr<Node>)> f,
bool include_control_deps);
void traverse_functions(std::shared_ptr<Function> p, void traverse_functions(std::shared_ptr<Function> p,
std::function<void(std::shared_ptr<Function>)> f); std::function<void(std::shared_ptr<Function>)> f);
...@@ -54,17 +59,31 @@ namespace ngraph ...@@ -54,17 +59,31 @@ namespace ngraph
void replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> replacement); void replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> replacement);
template <typename T> template <typename T>
std::list<std::shared_ptr<Node>> topological_sort(const T& nodes) std::list<std::shared_ptr<Node>> topological_sort(const T& nodes,
bool include_control_deps = false)
{ {
std::deque<ngraph::Node*> independent_nodes; std::deque<ngraph::Node*> independent_nodes;
std::unordered_map<const ngraph::Node*, size_t> node_dependency_count; std::unordered_map<const ngraph::Node*, size_t> node_dependency_count;
std::unordered_map<ngraph::Node*, std::shared_ptr<ngraph::Node>> node_map; std::unordered_map<ngraph::Node*, std::shared_ptr<ngraph::Node>> node_map;
std::unordered_map<ngraph::Node*, std::set<Node*>> control_deps_users;
for (auto node : nodes) for (auto node : nodes)
{ {
//build an equivalent of node->get_users() but for control dependencies
size_t control_deps_count = 0;
if (include_control_deps)
{
for (auto cd : node->get_control_dependencies())
{
control_deps_users[cd.get()].insert(node.get());
}
control_deps_count = node->get_control_dependencies().size();
}
node_map[node.get()] = node; node_map[node.get()] = node;
node_dependency_count[node.get()] = node->get_arguments().size(); size_t deps_count = node->get_arguments().size() + control_deps_count;
if (node->get_arguments().size() == 0) node_dependency_count[node.get()] = deps_count;
if (deps_count == 0)
{ {
independent_nodes.push_back(node.get()); independent_nodes.push_back(node.get());
} }
...@@ -87,6 +106,21 @@ namespace ngraph ...@@ -87,6 +106,21 @@ namespace ngraph
independent_nodes.push_back(user); independent_nodes.push_back(user);
} }
} }
if (include_control_deps)
{
auto cdit = control_deps_users.find(independent_node);
if (cdit != control_deps_users.end())
for (auto cd_user : cdit->second)
{
node_dependency_count[cd_user] -= 1;
size_t count = node_dependency_count[cd_user];
if (count == 0)
{
independent_nodes.push_back(cd_user);
}
}
}
} }
return result_list; return result_list;
......
...@@ -182,6 +182,16 @@ NodeVector Node::get_arguments() const ...@@ -182,6 +182,16 @@ NodeVector Node::get_arguments() const
return result; return result;
} }
const std::set<std::shared_ptr<Node>>& Node::get_control_dependencies() const
{
return m_control_dependencies;
}
void Node::add_control_dependency(std::shared_ptr<Node> node)
{
m_control_dependencies.insert(node);
}
std::vector<std::shared_ptr<Function>> Node::get_functions() const std::vector<std::shared_ptr<Function>> Node::get_functions() const
{ {
return std::vector<std::shared_ptr<Function>>{}; return std::vector<std::shared_ptr<Function>>{};
......
...@@ -144,6 +144,16 @@ namespace ngraph ...@@ -144,6 +144,16 @@ namespace ngraph
// TODO: Remove from unit tests. // TODO: Remove from unit tests.
const std::deque<descriptor::Output>& get_outputs() const; const std::deque<descriptor::Output>& get_outputs() const;
/// Get control dependencies registered on the node
const std::set<std::shared_ptr<Node>>& get_control_dependencies() const;
void add_control_dependency(std::shared_ptr<Node> node);
void remove_control_dependency(std::shared_ptr<Node> node)
{
m_control_dependencies.erase(node);
}
/// Returns the number of outputs on the for the node. /// Returns the number of outputs on the for the node.
size_t get_output_size() const; size_t get_output_size() const;
...@@ -216,6 +226,7 @@ namespace ngraph ...@@ -216,6 +226,7 @@ namespace ngraph
/// Use instance ids for comparison instead of memory addresses to improve determinism /// Use instance ids for comparison instead of memory addresses to improve determinism
bool operator<(const Node& other) const { return m_instance_id < other.m_instance_id; } bool operator<(const Node& other) const { return m_instance_id < other.m_instance_id; }
protected: protected:
std::set<std::shared_ptr<Node>> m_control_dependencies;
void set_output_size(size_t n); void set_output_size(size_t n);
std::string m_node_type; std::string m_node_type;
......
...@@ -167,14 +167,17 @@ void ngraph::serialize(ostream& out, shared_ptr<ngraph::Function> func, size_t i ...@@ -167,14 +167,17 @@ void ngraph::serialize(ostream& out, shared_ptr<ngraph::Function> func, size_t i
writer.write(func->get_name(), j.c_str(), static_cast<uint32_t>(j.size())); writer.write(func->get_name(), j.c_str(), static_cast<uint32_t>(j.size()));
traverse_functions(func, [&](shared_ptr<ngraph::Function> f) { traverse_functions(func, [&](shared_ptr<ngraph::Function> f) {
traverse_nodes(const_cast<Function*>(f.get()), [&](shared_ptr<Node> node) { traverse_nodes(const_cast<Function*>(f.get()),
if (auto c = dynamic_pointer_cast<op::Constant>(node)) [&](shared_ptr<Node> node) {
{ if (auto c = dynamic_pointer_cast<op::Constant>(node))
uint32_t size = static_cast<uint32_t>(shape_size(c->get_output_shape(0)) * {
c->get_output_element_type(0).size()); uint32_t size =
writer.write(c->get_name(), c->get_data_ptr(), size); static_cast<uint32_t>(shape_size(c->get_output_shape(0)) *
} c->get_output_element_type(0).size());
}); writer.write(c->get_name(), c->get_data_ptr(), size);
}
},
true);
}); });
writer.close(); writer.close();
...@@ -301,45 +304,13 @@ static json write(const Function& f, bool binary_constant_data) ...@@ -301,45 +304,13 @@ static json write(const Function& f, bool binary_constant_data)
function["result"].push_back(f.get_output_op(i)->get_name()); function["result"].push_back(f.get_output_op(i)->get_name());
} }
list<shared_ptr<Node>> result_list; Function* pf = const_cast<Function*>(&f);
{
deque<Node*> independent_nodes;
unordered_map<const Node*, size_t> node_depencency_count;
unordered_map<Node*, shared_ptr<Node>> node_map;
traverse_nodes(const_cast<Function*>(&f), [&](shared_ptr<Node> node) {
node_map[node.get()] = node;
node_depencency_count[node.get()] = node->get_arguments().size();
if (node->get_arguments().size() == 0)
{
independent_nodes.push_back(node.get());
}
});
while (independent_nodes.size() > 0)
{
auto independent_node = independent_nodes.front();
result_list.push_back(node_map[independent_node]);
independent_nodes.pop_front();
for (auto sp_user : independent_node->get_users())
{
Node* user = sp_user.get();
node_depencency_count[user] -= 1;
size_t count = node_depencency_count[user];
if (count == 0)
{
independent_nodes.push_back(user);
}
}
}
}
json nodes; json nodes;
for (shared_ptr<Node> node : result_list) for (shared_ptr<Node> node : pf->get_ordered_ops(true))
{ {
nodes.push_back(write(*node, binary_constant_data)); nodes.push_back(write(*node, binary_constant_data));
} }
function["ops"] = nodes; function["ops"] = nodes;
return function; return function;
} }
...@@ -362,9 +333,12 @@ static shared_ptr<ngraph::Function> ...@@ -362,9 +333,12 @@ static shared_ptr<ngraph::Function>
string node_name = node_js.at("name").get<string>(); string node_name = node_js.at("name").get<string>();
string node_op = node_js.at("op").get<string>(); string node_op = node_js.at("op").get<string>();
vector<string> node_inputs = node_js.at("inputs").get<vector<string>>(); vector<string> node_inputs = node_js.at("inputs").get<vector<string>>();
vector<string> control_deps_inputs =
get_or_default<vector<string>>(node_js, "control_deps", vector<string>{});
vector<string> node_outputs = node_js.at("outputs").get<vector<string>>(); vector<string> node_outputs = node_js.at("outputs").get<vector<string>>();
shared_ptr<Node> node; shared_ptr<Node> node;
vector<shared_ptr<Node>> args; vector<shared_ptr<Node>> args;
vector<shared_ptr<Node>> control_deps;
for (const string& name : node_inputs) for (const string& name : node_inputs)
{ {
args.push_back(node_map.at(name)); args.push_back(node_map.at(name));
...@@ -934,6 +908,12 @@ static shared_ptr<ngraph::Function> ...@@ -934,6 +908,12 @@ static shared_ptr<ngraph::Function>
ss << "unsupported op " << node_op; ss << "unsupported op " << node_op;
throw runtime_error(ss.str()); throw runtime_error(ss.str());
} }
for (const string& name : control_deps_inputs)
{
node->add_control_dependency(node_map.at(name));
}
node_map[node_name] = node; node_map[node_name] = node;
// Typically, it could be unsafe to change the name of a node since it may break nameing // Typically, it could be unsafe to change the name of a node since it may break nameing
...@@ -1000,18 +980,24 @@ static json write(const Node& n, bool binary_constant_data) ...@@ -1000,18 +980,24 @@ static json write(const Node& n, bool binary_constant_data)
node["op"] = n.description(); node["op"] = n.description();
// TODO Multiple outputs // TODO Multiple outputs
json inputs = json::array(); json inputs = json::array();
json control_deps = json::array();
json outputs = json::array(); json outputs = json::array();
for (const descriptor::Input& input : n.get_inputs()) for (const descriptor::Input& input : n.get_inputs())
{ {
inputs.push_back(input.get_output().get_node()->get_name()); inputs.push_back(input.get_output().get_node()->get_name());
} }
for (auto cdep : n.get_control_dependencies())
{
control_deps.push_back(cdep->get_name());
}
for (size_t i = 0; i < n.get_output_size(); ++i) for (size_t i = 0; i < n.get_output_size(); ++i)
{ {
outputs.push_back(n.get_output_tensor(i).get_name()); outputs.push_back(n.get_output_tensor(i).get_name());
} }
node["inputs"] = inputs; node["inputs"] = inputs;
node["control_deps"] = control_deps;
node["outputs"] = outputs; node["outputs"] = outputs;
if (std::getenv("NGRAPH_SERIALIZER_OUTPUT_SHAPES") != nullptr) if (std::getenv("NGRAPH_SERIALIZER_OUTPUT_SHAPES") != nullptr)
......
...@@ -205,15 +205,17 @@ ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop, ...@@ -205,15 +205,17 @@ ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop,
// Traverse bprop to find all of the nodes in the bprop graph // Traverse bprop to find all of the nodes in the bprop graph
std::unordered_set<std::shared_ptr<Node>> in_bprop; std::unordered_set<std::shared_ptr<Node>> in_bprop;
ngraph::traverse_nodes(bprop, [&in_bprop](std::shared_ptr<Node> node) { ngraph::traverse_nodes(bprop,
if (node->get_outputs().size() == 1) [&in_bprop](std::shared_ptr<Node> node) {
{ if (node->get_outputs().size() == 1)
if (in_bprop.count(node) == 0) {
{ if (in_bprop.count(node) == 0)
in_bprop.insert(node); {
} in_bprop.insert(node);
} }
}); }
},
false /* no control dependencies */);
// Traverse fprop to make a map that stores parameters with the same // Traverse fprop to make a map that stores parameters with the same
// shape and element type as the nodes in fprop iff they are in bprop // shape and element type as the nodes in fprop iff they are in bprop
...@@ -290,7 +292,8 @@ ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop, ...@@ -290,7 +292,8 @@ ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop,
{ {
fprop_cache.fprop_output_nodes.push_back(inverted_node_map.at(node)); fprop_cache.fprop_output_nodes.push_back(inverted_node_map.at(node));
} }
}); },
false /* no control dependencies */);
// create the new outputs for fprop and the new fprop function // create the new outputs for fprop and the new fprop function
ResultVector fprop_outputs = fprop->get_results(); ResultVector fprop_outputs = fprop->get_results();
......
...@@ -44,6 +44,8 @@ set(SRC ...@@ -44,6 +44,8 @@ set(SRC
util.cpp util.cpp
uuid.cpp uuid.cpp
zero_dim_tensor_elimination.cpp zero_dim_tensor_elimination.cpp
control_dependencies.cpp
) )
if (NGRAPH_ONNX_IMPORT_ENABLE) if (NGRAPH_ONNX_IMPORT_ENABLE)
......
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <algorithm>
#include <cstdio>
#include <iostream>
#include <list>
#include <memory>
#include "gtest/gtest.h"
#include "ngraph/autodiff/adjoints.hpp"
#include "ngraph/file_util.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/op/batch_norm.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/parameter.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/visualize_tree.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/runtime/cpu/pass/cpu_fusion.hpp"
#include "ngraph/serializer.hpp"
#include "ngraph/util.hpp"
#include "nlohmann/json.hpp"
#include "util/all_close.hpp"
#include "util/autodiff/backprop_function.hpp"
#include "util/autodiff/numeric_compare.hpp"
#include "util/ndarray.hpp"
#include "util/random.hpp"
#include "util/test_tools.hpp"
using namespace ngraph;
using namespace std;
class ControlDependencyOp : public ngraph::op::Op
{
public:
virtual std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override
{
auto clone = make_shared<ControlDependencyOp>(new_args, std::set<std::shared_ptr<Node>>{});
return clone;
}
ControlDependencyOp(const NodeVector& args, const std::set<std::shared_ptr<Node>>& deps)
: Op("ControlDependencyOp", args)
{
if (args.size() == 0 && deps.size() == 0)
{
throw ngraph_error("Expected some arguments or dependencies");
}
if (deps.size() != 0)
{
m_control_dependencies.insert(deps.begin(), deps.end());
}
if (args.size() != 0)
{
set_output_type(0, args.at(0)->get_element_type(), args.at(0)->get_shape());
}
else
{
auto dn = *(deps.begin());
set_output_type(0, dn->get_element_type(), dn->get_shape());
}
}
};
TEST(control_dependencies, cdep_ops)
{
auto A = make_shared<op::Parameter>(element::f32, Shape{});
auto B = make_shared<op::Parameter>(element::f32, Shape{});
auto absn = make_shared<op::Abs>(A);
auto cdop =
make_shared<ControlDependencyOp>(NodeVector{A}, std::set<std::shared_ptr<Node>>{absn});
auto f = make_shared<Function>(cdop, op::ParameterVector{A, B});
auto nodes = f->get_ordered_ops(true);
ASSERT_EQ(nodes.back()->get_argument(0), cdop);
}
TEST(control_dependencies, two_cdep_ops)
{
auto A = make_shared<op::Parameter>(element::f32, Shape{});
auto B = make_shared<op::Parameter>(element::f32, Shape{});
auto absn = make_shared<op::Abs>(A);
auto C = make_shared<op::Parameter>(element::f32, Shape{});
auto absn_c = make_shared<op::Abs>(C);
auto cdop = make_shared<ControlDependencyOp>(NodeVector{A},
std::set<std::shared_ptr<Node>>{absn, absn_c});
auto f = make_shared<Function>(cdop, op::ParameterVector{A, B, C});
auto nodes = f->get_ordered_ops(true);
ASSERT_EQ(nodes.back()->get_argument(0), cdop);
}
TEST(control_dependencies, two_cdep_ops_op_on_top)
{
auto A = make_shared<op::Parameter>(element::f32, Shape{});
auto absn = make_shared<op::Abs>(A);
auto B = make_shared<op::Parameter>(element::f32, Shape{});
auto absn_b = make_shared<op::Abs>(B);
auto cdop = make_shared<ControlDependencyOp>(NodeVector{A},
std::set<std::shared_ptr<Node>>{absn, absn_b});
auto absn_cdop = make_shared<op::Abs>(cdop);
auto f = make_shared<Function>(absn_cdop, op::ParameterVector{A, B});
auto nodes = f->get_ordered_ops(true);
ASSERT_EQ(nodes.back()->get_argument(0), absn_cdop);
}
TEST(control_dependencies, clone_function_cdop)
{
auto A = make_shared<op::Parameter>(element::f32, Shape{});
auto absn = make_shared<op::Abs>(A);
auto cdop =
make_shared<ControlDependencyOp>(NodeVector{A}, std::set<std::shared_ptr<Node>>{absn});
auto f = make_shared<Function>(cdop, op::ParameterVector{A});
auto clone = ngraph::clone_function(*f.get());
auto matcher = std::make_shared<pattern::Matcher>(cdop, nullptr);
auto cdop_clone = clone->get_results().at(0)->get_argument(0);
ASSERT_TRUE(matcher->match(cdop_clone));
auto cloned_deps = cdop_clone->get_control_dependencies();
ASSERT_EQ(cloned_deps.size(), 1);
auto cloned_abs = *begin(cloned_deps);
ASSERT_TRUE(std::dynamic_pointer_cast<op::Abs>(cloned_abs));
}
TEST(control_dependencies, clone_function_cdop_abs)
{
auto A = make_shared<op::Parameter>(element::f32, Shape{});
auto absn = make_shared<op::Abs>(A);
auto B = make_shared<op::Parameter>(element::f32, Shape{});
auto absn_b = make_shared<op::Abs>(B);
auto cdop = make_shared<ControlDependencyOp>(NodeVector{A},
std::set<std::shared_ptr<Node>>{absn, absn_b});
auto absn_cdop = make_shared<op::Abs>(cdop);
auto f = make_shared<Function>(absn_cdop, op::ParameterVector{A, B});
auto clone = ngraph::clone_function(*f.get());
auto matcher = std::make_shared<pattern::Matcher>(cdop, nullptr);
auto cdop_clone = clone->get_results().at(0)->get_argument(0)->get_argument(0);
ASSERT_TRUE(matcher->match(cdop_clone));
auto cloned_deps = cdop_clone->get_control_dependencies();
ASSERT_EQ(cloned_deps.size(), 2);
for (auto ccdep : cloned_deps)
{
ASSERT_TRUE(std::dynamic_pointer_cast<op::Abs>(ccdep));
}
}
TEST(control_dependencies, serialize_cdop)
{
auto A = make_shared<op::Parameter>(element::f32, Shape{});
auto absn = make_shared<op::Abs>(A);
auto cdop = make_shared<op::Negative>(A);
cdop->add_control_dependency(absn);
auto f = make_shared<Function>(cdop, op::ParameterVector{A});
string js = serialize(f, 4);
shared_ptr<Function> clone = deserialize(js);
auto matcher = std::make_shared<pattern::Matcher>(cdop, nullptr);
auto cdop_clone = clone->get_results().at(0)->get_argument(0);
ASSERT_TRUE(matcher->match(cdop_clone));
auto cloned_deps = cdop_clone->get_control_dependencies();
ASSERT_EQ(cloned_deps.size(), 1);
auto cloned_abs = *begin(cloned_deps);
ASSERT_TRUE(std::dynamic_pointer_cast<op::Abs>(cloned_abs));
}
TEST(control_dependencies, serialize_cdop_abs)
{
auto A = make_shared<op::Parameter>(element::f32, Shape{});
auto absn = make_shared<op::Abs>(A);
auto B = make_shared<op::Parameter>(element::f32, Shape{});
auto absn_b = make_shared<op::Abs>(B);
auto cdop = make_shared<op::Negative>(A);
cdop->add_control_dependency(absn);
cdop->add_control_dependency(absn_b);
auto absn_cdop = make_shared<op::Abs>(cdop);
auto f = make_shared<Function>(absn_cdop, op::ParameterVector{A, B});
string js = serialize(f, 4);
shared_ptr<Function> clone = deserialize(js);
auto matcher = std::make_shared<pattern::Matcher>(cdop, nullptr);
auto cdop_clone = clone->get_results().at(0)->get_argument(0)->get_argument(0);
ASSERT_TRUE(matcher->match(cdop_clone));
auto cloned_deps = cdop_clone->get_control_dependencies();
ASSERT_EQ(cloned_deps.size(), 2);
for (auto ccdep : cloned_deps)
{
ASSERT_TRUE(std::dynamic_pointer_cast<op::Abs>(ccdep));
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment