Commit 0768a969, authored by Jayaram Bobba and committed by Scott Cyphers

Jbobba/conv sum cleanup (#1167)

* inplace compute

* fix warnings

* Initial support for convolution sum fusion

* Added in-place support for conv sum fusion and test cases

* reverting spurious changes

* Bug fix to account for inplace input in conv sum fusion

* fix compilation error

* Addressed PR feedback

* Handle corner cases for conv sum fusion. Skip computation reuse while using an inplace kernel

* Check node argument for in-place relu assignment

* Addressed PR comments

* Addressed PR feedback
parent 5be99c0a
...@@ -174,6 +174,36 @@ std::list<std::shared_ptr<ngraph::Node>> ...@@ -174,6 +174,36 @@ std::list<std::shared_ptr<ngraph::Node>>
return result_list; return result_list;
} }
// Check if all paths from X to a result go through Y.
//
// Walks the user edges depth-first starting at X, never expanding beyond Y.
// If any node reached this way reports is_output(), some path reaches a result
// without being cut off by Y and the function returns false. The output check
// runs before the comparison against Y, so X or Y being an output itself also
// yields false.
//
// X: node the walk starts from; it is dereferenced, so it must not be null.
// Y: candidate post-dominator; compared by address, and dereferenced only if
//    the walk actually reaches it.
// Returns true when no output node is reachable from X without passing Y.
bool ngraph::is_post_dominated(Node* X, Node* Y)
{
    // Nodes are marked when pushed, so each node is stacked at most once.
    std::unordered_set<Node*> visited;
    std::deque<Node*> stack;
    visited.insert(X);
    stack.push_front(X);
    while (!stack.empty())
    {
        ngraph::Node* curr = stack.front();
        stack.pop_front();
        if (curr->is_output())
        {
            return false;
        }
        if (curr == Y)
        {
            // Paths may continue through Y; nothing past it needs checking.
            continue;
        }
        // Bind by const reference: copying a shared_ptr per user would cost an
        // atomic reference-count update for no benefit.
        for (const auto& next : curr->get_users())
        {
            if (visited.insert(next.get()).second)
            {
                stack.push_front(next.get());
            }
        }
    }
    return true;
}
void ngraph::NodeMap::update(std::shared_ptr<ngraph::Node> orig, std::shared_ptr<ngraph::Node> val) void ngraph::NodeMap::update(std::shared_ptr<ngraph::Node> orig, std::shared_ptr<ngraph::Node> val)
{ {
if (!exists(orig)) if (!exists(orig))
...@@ -435,15 +465,15 @@ bool ngraph::is_one(std::shared_ptr<Node> reduce_constant) ...@@ -435,15 +465,15 @@ bool ngraph::is_one(std::shared_ptr<Node> reduce_constant)
return result_bool; return result_bool;
} }
bool ngraph::is_used(std::shared_ptr<ngraph::Node> node) bool ngraph::is_used(Node* node)
{ {
std::unordered_set<std::shared_ptr<ngraph::Node>> instances_seen; std::unordered_set<Node*> instances_seen;
std::deque<std::shared_ptr<ngraph::Node>> stack; std::deque<Node*> stack;
stack.push_front(node); stack.push_front(node);
while (stack.size() > 0) while (stack.size() > 0)
{ {
std::shared_ptr<ngraph::Node> n = stack.front(); ngraph::Node* n = stack.front();
if (instances_seen.count(n) == 0) if (instances_seen.count(n) == 0)
{ {
if (n->is_output()) if (n->is_output())
...@@ -455,11 +485,21 @@ bool ngraph::is_used(std::shared_ptr<ngraph::Node> node) ...@@ -455,11 +485,21 @@ bool ngraph::is_used(std::shared_ptr<ngraph::Node> node)
stack.pop_front(); stack.pop_front();
for (auto arg : n->get_users()) for (auto arg : n->get_users())
{ {
if (instances_seen.count(arg) == 0) if (instances_seen.count(arg.get()) == 0)
{ {
stack.push_front(arg); stack.push_front(arg.get());
} }
} }
} }
return false; return false;
} }
// Counts the direct users of `node` for which is_used() reports true.
// `node` is dereferenced and therefore must not be null.
size_t ngraph::get_user_count(Node* node)
{
    size_t live_users = 0;
    for (const auto& user : node->get_users())
    {
        if (is_used(user.get()))
        {
            ++live_users;
        }
    }
    return live_users;
}
...@@ -52,6 +52,9 @@ namespace ngraph ...@@ -52,6 +52,9 @@ namespace ngraph
std::list<std::shared_ptr<Node>> std::list<std::shared_ptr<Node>>
topological_sort(const std::list<std::shared_ptr<Node>>& nodes); topological_sort(const std::list<std::shared_ptr<Node>>& nodes);
// Check if all paths from X to a result go through Y
bool is_post_dominated(Node* X, Node* Y);
bool is_equal_to_const_value(std::string const_value, std::shared_ptr<Node> reduce_constant); bool is_equal_to_const_value(std::string const_value, std::shared_ptr<Node> reduce_constant);
// maps original to replacement nodes e.g. for clone utilities // maps original to replacement nodes e.g. for clone utilities
...@@ -132,5 +135,10 @@ namespace ngraph ...@@ -132,5 +135,10 @@ namespace ngraph
bool is_one(std::shared_ptr<Node> reduce_constant); bool is_one(std::shared_ptr<Node> reduce_constant);
bool is_used(std::shared_ptr<Node> node); // Returns true if `node` is live in the graph i.e. a result op
// transitively uses this `node`
bool is_used(Node* node);
// Returns count of `node` users that are still live in the graph
size_t get_user_count(Node* node);
} }
...@@ -809,7 +809,7 @@ using namespace ngraph::runtime; ...@@ -809,7 +809,7 @@ using namespace ngraph::runtime;
} }
auto computes_output = [&]() { auto computes_output = [&]() {
if (std::dynamic_pointer_cast<ngraph::op::Result>(node)) if (node->is_output())
{ {
return true; return true;
} }
...@@ -828,8 +828,33 @@ using namespace ngraph::runtime; ...@@ -828,8 +828,33 @@ using namespace ngraph::runtime;
} }
return false; return false;
}; };
// Always enable nodes computing output tensors
if (computes_output()) auto possibly_overwritten = [&]() {
// Returns true if any consumer of any output of `node` is an Op whose
// annotations list an in-place pair matching the consuming input's index;
// returns false when no such consumer is found.
for (const descriptor::Output& output : node->get_outputs())
{
for (const descriptor::Input* input : output.get_inputs())
{
// Only consumers that are ngraph::op::Op instances are inspected.
if (auto op =
std::dynamic_pointer_cast<ngraph::op::Op>(input->get_node()))
{
// Consumers without op annotations are skipped.
if (auto op_annotations = op->get_op_annotations())
{
for (auto oi_pair : op_annotations->get_in_place_oi_pairs())
{
// NOTE(review): oi_pair.second is presumably the input index of an
// (output, input) in-place pair -- confirm against the annotation API.
if (input->get_index() == oi_pair.second)
{
return true;
}
}
}
}
}
}
return false;
};
// Always enable nodes computing output tensors or nodes whose outputs might get
// overwritten due to inplace kernels
if (computes_output() || possibly_overwritten())
{ {
writer << " || 1"; writer << " || 1";
} }
......
...@@ -499,8 +499,12 @@ namespace ngraph ...@@ -499,8 +499,12 @@ namespace ngraph
auto op_annotations = auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>(); std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true); op_annotations->set_mkldnn_op(true);
std::map<size_t, size_t> oi_pairs = {{0, 0}}; if (get_user_count(node->get_argument(0).get()) == 1)
op_annotations->set_in_place_oi_pairs(oi_pairs); {
// Safe to overwrite input
std::map<size_t, size_t> oi_pairs = {{0, 0}};
op_annotations->set_in_place_oi_pairs(oi_pairs);
}
relu->set_op_annotations(op_annotations); relu->set_op_annotations(op_annotations);
} }
} }
......
...@@ -1063,18 +1063,17 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_conv_bias_add() ...@@ -1063,18 +1063,17 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_conv_bias_add()
return false; return false;
} }
if (conv_m->get_users().size() > 1) if (get_user_count(conv_m.get()) > 1)
{ {
NGRAPH_DEBUG << "Convolution has more than one user"; NGRAPH_DEBUG << "Convolution has more than one user";
return false;
// return false;
} }
if (inplace_input->get_users().size() > 1) if (!is_post_dominated(inplace_input.get(), add_m.get()))
{ {
NGRAPH_DEBUG << "Add has more than one user. Convolution Add might use an in-place " NGRAPH_DEBUG << "Unsafe to use in-place kernel since add's in-place input has "
"destructive kernel"; "potential live users";
// return false; return false;
} }
if (inplace_input->is_parameter()) if (inplace_input->is_parameter())
......
...@@ -539,7 +539,7 @@ void ngraph::runtime::cpu::pass::RNNFusion::construct_rnn_lstm_fprop() ...@@ -539,7 +539,7 @@ void ngraph::runtime::cpu::pass::RNNFusion::construct_rnn_lstm_fprop()
{ {
if (std::find(lstm_nodes.begin(), lstm_nodes.end(), goe0_user) == if (std::find(lstm_nodes.begin(), lstm_nodes.end(), goe0_user) ==
lstm_nodes.end() && lstm_nodes.end() &&
ngraph::is_used(goe0_user)) ngraph::is_used(goe0_user.get()))
{ {
lstm_goe0_user.insert(goe0_user); lstm_goe0_user.insert(goe0_user);
map_goe_to_lstm_slices[goe_0] = ht_slice_per_timestep[index]; map_goe_to_lstm_slices[goe_0] = ht_slice_per_timestep[index];
...@@ -818,7 +818,7 @@ void ngraph::runtime::cpu::pass::MultiLayerRNNFusion::construct_multi_layer_rnn_ ...@@ -818,7 +818,7 @@ void ngraph::runtime::cpu::pass::MultiLayerRNNFusion::construct_multi_layer_rnn_
node_to_replace = rnn_ct->get_users()[0]; node_to_replace = rnn_ct->get_users()[0];
} }
} }
if (ngraph::is_used(node_to_replace)) if (ngraph::is_used(node_to_replace.get()))
{ {
ngraph::replace_node(node_to_replace, ct_slice); ngraph::replace_node(node_to_replace, ct_slice);
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment