Commit 0768a969, authored by Jayaram Bobba, committed by Scott Cyphers

Jbobba/conv sum cleanup (#1167)

* inplace compute

* fix warnings

* Initial support for convolution sum fusion

* Added in-place support for conv sum fusion and test cases

* reverting spurious changes

* Bug fix to account for inplace input in conv sum fusion

* fix compilation error

* Addressed PR feedback

* Handle corner cases for conv sum fusion. Skip computation reuse while using an inplace kernel

* Check node argument for in-place relu assignment

* Addressed PR comments

* Addressed PR feedback
parent 5be99c0a
......@@ -174,6 +174,36 @@ std::list<std::shared_ptr<ngraph::Node>>
return result_list;
}
// Check if all paths from X to a result go through Y.
//
// Walks the user graph depth-first starting at X and never expands past Y.
// Returns false as soon as any visited node reports is_output() (this check is
// applied to every visited node, X and Y included); returns true once the
// worklist drains without meeting such a node.
bool ngraph::is_post_dominated(Node* X, Node* Y)
{
    std::unordered_set<Node*> seen;
    std::deque<Node*> pending; // used strictly as a LIFO worklist
    pending.push_back(X);

    while (!pending.empty())
    {
        Node* const current = pending.back();
        pending.pop_back();
        seen.insert(current);

        // A result node was reached: the answer is already known.
        if (current->is_output())
        {
            return false;
        }

        // Y acts as a barrier; its users are deliberately not explored.
        if (current == Y)
        {
            continue;
        }

        for (const auto& user : current->get_users())
        {
            Node* const candidate = user.get();
            if (seen.count(candidate) == 0)
            {
                pending.push_back(candidate);
            }
        }
    }

    return true;
}
void ngraph::NodeMap::update(std::shared_ptr<ngraph::Node> orig, std::shared_ptr<ngraph::Node> val)
{
if (!exists(orig))
......@@ -435,15 +465,15 @@ bool ngraph::is_one(std::shared_ptr<Node> reduce_constant)
return result_bool;
}
bool ngraph::is_used(std::shared_ptr<ngraph::Node> node)
bool ngraph::is_used(Node* node)
{
std::unordered_set<std::shared_ptr<ngraph::Node>> instances_seen;
std::deque<std::shared_ptr<ngraph::Node>> stack;
std::unordered_set<Node*> instances_seen;
std::deque<Node*> stack;
stack.push_front(node);
while (stack.size() > 0)
{
std::shared_ptr<ngraph::Node> n = stack.front();
ngraph::Node* n = stack.front();
if (instances_seen.count(n) == 0)
{
if (n->is_output())
......@@ -455,11 +485,21 @@ bool ngraph::is_used(std::shared_ptr<ngraph::Node> node)
stack.pop_front();
for (auto arg : n->get_users())
{
if (instances_seen.count(arg) == 0)
if (instances_seen.count(arg.get()) == 0)
{
stack.push_front(arg);
stack.push_front(arg.get());
}
}
}
return false;
}
// Counts the direct users of `node` for which is_used() returns true.
size_t ngraph::get_user_count(Node* node)
{
    size_t live_users = 0;
    for (const auto& user : node->get_users())
    {
        if (is_used(user.get()))
        {
            ++live_users;
        }
    }
    return live_users;
}
......@@ -52,6 +52,9 @@ namespace ngraph
std::list<std::shared_ptr<Node>>
topological_sort(const std::list<std::shared_ptr<Node>>& nodes);
// Check if all paths from X to a result go through Y
bool is_post_dominated(Node* X, Node* Y);
bool is_equal_to_const_value(std::string const_value, std::shared_ptr<Node> reduce_constant);
// maps original to replacement nodes e.g. for clone utilities
......@@ -132,5 +135,10 @@ namespace ngraph
bool is_one(std::shared_ptr<Node> reduce_constant);
bool is_used(std::shared_ptr<Node> node);
// Returns true if `node` is live in the graph i.e. a result op
// transitively uses this `node`
bool is_used(Node* node);
// Returns count of `node` users that are still live in the graph
size_t get_user_count(Node* node);
}
......@@ -809,7 +809,7 @@ using namespace ngraph::runtime;
}
auto computes_output = [&]() {
if (std::dynamic_pointer_cast<ngraph::op::Result>(node))
if (node->is_output())
{
return true;
}
......@@ -828,8 +828,33 @@ using namespace ngraph::runtime;
}
return false;
};
// Always enable nodes computing output tensors
if (computes_output())
// Returns true when some consumer of one of this node's outputs is an op::Op
// whose annotations list an in-place pair with a second element equal to the
// index of the consuming input; returns false otherwise.
// NOTE(review): the second element of each pair is presumably the input index
// of the in-place output/input pairing -- confirm against the annotations class.
auto possibly_overwritten = [&]() {
    for (const descriptor::Output& node_output : node->get_outputs())
    {
        for (const descriptor::Input* consumer : node_output.get_inputs())
        {
            // Only nodes that are ops can carry op annotations.
            auto consumer_op = std::dynamic_pointer_cast<ngraph::op::Op>(consumer->get_node());
            if (!consumer_op)
            {
                continue;
            }

            auto annotations = consumer_op->get_op_annotations();
            if (!annotations)
            {
                continue;
            }

            for (const auto& in_place_pair : annotations->get_in_place_oi_pairs())
            {
                if (consumer->get_index() == in_place_pair.second)
                {
                    return true;
                }
            }
        }
    }
    return false;
};
// Always enable nodes computing output tensors or nodes whose outputs might get
// overwritten due to inplace kernels
if (computes_output() || possibly_overwritten())
{
writer << " || 1";
}
......
......@@ -499,8 +499,12 @@ namespace ngraph
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
std::map<size_t, size_t> oi_pairs = {{0, 0}};
op_annotations->set_in_place_oi_pairs(oi_pairs);
if (get_user_count(node->get_argument(0).get()) == 1)
{
// Safe to overwrite input
std::map<size_t, size_t> oi_pairs = {{0, 0}};
op_annotations->set_in_place_oi_pairs(oi_pairs);
}
relu->set_op_annotations(op_annotations);
}
}
......
......@@ -1063,18 +1063,17 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_conv_bias_add()
return false;
}
if (conv_m->get_users().size() > 1)
if (get_user_count(conv_m.get()) > 1)
{
NGRAPH_DEBUG << "Convolution has more than one user";
// return false;
return false;
}
if (inplace_input->get_users().size() > 1)
if (!is_post_dominated(inplace_input.get(), add_m.get()))
{
NGRAPH_DEBUG << "Add has more than one user. Convolution Add might use an in-place "
"destructive kernel";
// return false;
NGRAPH_DEBUG << "Unsafe to use in-place kernel since add's in-place input has "
"potential live users";
return false;
}
if (inplace_input->is_parameter())
......
......@@ -539,7 +539,7 @@ void ngraph::runtime::cpu::pass::RNNFusion::construct_rnn_lstm_fprop()
{
if (std::find(lstm_nodes.begin(), lstm_nodes.end(), goe0_user) ==
lstm_nodes.end() &&
ngraph::is_used(goe0_user))
ngraph::is_used(goe0_user.get()))
{
lstm_goe0_user.insert(goe0_user);
map_goe_to_lstm_slices[goe_0] = ht_slice_per_timestep[index];
......@@ -818,7 +818,7 @@ void ngraph::runtime::cpu::pass::MultiLayerRNNFusion::construct_multi_layer_rnn_
node_to_replace = rnn_ct->get_users()[0];
}
}
if (ngraph::is_used(node_to_replace))
if (ngraph::is_used(node_to_replace.get()))
{
ngraph::replace_node(node_to_replace, ct_slice);
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment