Commit 2fe3574d authored by amy.zhuang

No in place concat for padded input layout.

parent a87ee09e
...@@ -76,30 +76,6 @@ bool pass::MemoryLayout::run_on_function(shared_ptr<ngraph::Function> function) ...@@ -76,30 +76,6 @@ bool pass::MemoryLayout::run_on_function(shared_ptr<ngraph::Function> function)
tensor->set_pool_offset(offset); tensor->set_pool_offset(offset);
} }
// check if the op is concat
if (auto concat = std::dynamic_pointer_cast<op::Concat>(node))
{
if (auto op_annotations = concat->get_op_annotations())
{
auto in_place_oi_pairs = op_annotations->get_in_place_oi_pairs();
if (in_place_oi_pairs.size() > 0)
{
auto output_tensor = &concat->get_output_tensor();
auto offset = output_tensor->get_pool_offset();
for (auto arg : concat->get_arguments())
{
auto input_node = std::dynamic_pointer_cast<op::Op>(arg);
auto input_tensor = &input_node->get_output_tensor();
auto old_offset = input_tensor->get_pool_offset();
input_tensor->set_pool_offset(offset);
NGRAPH_DEBUG << "memeory_layout: change offset, old offset is "
<< old_offset << ", new offset is " << offset << std::endl;
offset += input_tensor->size();
}
}
}
}
if (!m_disable_memory_sharing) if (!m_disable_memory_sharing)
{ {
for (const descriptor::Tensor* tensor : node->liveness_free_list) for (const descriptor::Tensor* tensor : node->liveness_free_list)
......
...@@ -1198,8 +1198,9 @@ void runtime::cpu::CPU_ExternalFunction::propagate_in_place_concat( ...@@ -1198,8 +1198,9 @@ void runtime::cpu::CPU_ExternalFunction::propagate_in_place_concat(
auto input_tensor = &input_node->get_output_tensor(); auto input_tensor = &input_node->get_output_tensor();
auto old_offset = input_tensor->get_pool_offset(); auto old_offset = input_tensor->get_pool_offset();
input_tensor->set_pool_offset(offset); input_tensor->set_pool_offset(offset);
NGRAPH_DEBUG << "cpu_external_function: change offset, old offset is " NGRAPH_DEBUG
<< old_offset << ", new offset is " << offset << std::endl; << "cpu_external_function, propagate: change offset, old offset is "
<< old_offset << ", new offset is " << offset << std::endl;
offset += input_tensor->size(); offset += input_tensor->size();
if (auto arg_concat = std::dynamic_pointer_cast<ngraph::op::Concat>(arg)) if (auto arg_concat = std::dynamic_pointer_cast<ngraph::op::Concat>(arg))
{ {
...@@ -1276,6 +1277,18 @@ void runtime::cpu::CPU_ExternalFunction::build() ...@@ -1276,6 +1277,18 @@ void runtime::cpu::CPU_ExternalFunction::build()
if (in_place_oi_pairs.size() > 0) if (in_place_oi_pairs.size() > 0)
{ {
bool found_last_concat = true; bool found_last_concat = true;
auto output_tensor = &concat->get_output_tensor();
auto offset = output_tensor->get_pool_offset();
for (auto arg : concat->get_arguments())
{
auto input_node = std::dynamic_pointer_cast<ngraph::op::Op>(arg);
auto input_tensor = &input_node->get_output_tensor();
auto old_offset = input_tensor->get_pool_offset();
input_tensor->set_pool_offset(offset);
NGRAPH_DEBUG << "cpu_external_function: change offset, old offset is "
<< old_offset << ", new offset is " << offset << std::endl;
offset += input_tensor->size();
}
for (auto user : concat->get_users()) for (auto user : concat->get_users())
{ {
if (dynamic_pointer_cast<ngraph::op::Concat>(user)) if (dynamic_pointer_cast<ngraph::op::Concat>(user))
...@@ -1290,8 +1303,9 @@ void runtime::cpu::CPU_ExternalFunction::build() ...@@ -1290,8 +1303,9 @@ void runtime::cpu::CPU_ExternalFunction::build()
{ {
if (auto arg_concat = dynamic_pointer_cast<ngraph::op::Concat>(arg)) if (auto arg_concat = dynamic_pointer_cast<ngraph::op::Concat>(arg))
{ {
NGRAPH_DEBUG << "call propagate_in_place_concat for " NGRAPH_DEBUG
<< arg->get_name() << std::endl; << "cpu_external_function: call propagate_in_place_concat for "
<< arg->get_name() << std::endl;
propagate_in_place_concat(arg_concat); propagate_in_place_concat(arg_concat);
} }
} }
......
...@@ -17,9 +17,11 @@ ...@@ -17,9 +17,11 @@
#include "ngraph/runtime/cpu/pass/cpu_memory_optimization.hpp" #include "ngraph/runtime/cpu/pass/cpu_memory_optimization.hpp"
#include "ngraph/descriptor/output.hpp" #include "ngraph/descriptor/output.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/op/concat.hpp" #include "ngraph/op/concat.hpp"
#include "ngraph/op/constant.hpp" #include "ngraph/op/constant.hpp"
#include "ngraph/runtime/cpu/cpu_op_annotations.hpp" #include "ngraph/runtime/cpu/cpu_op_annotations.hpp"
#include "ngraph/runtime/cpu/mkldnn_utils.hpp"
using namespace ngraph; using namespace ngraph;
...@@ -44,9 +46,25 @@ bool runtime::cpu::pass::CPUMemoryOptimization::run_on_function(std::shared_ptr< ...@@ -44,9 +46,25 @@ bool runtime::cpu::pass::CPUMemoryOptimization::run_on_function(std::shared_ptr<
} }
bool in_place_concat = true; bool in_place_concat = true;
AxisVector axis_list;
for (auto i = 0; i < shape.size(); i++)
{
axis_list.push_back(i);
}
auto index = 0;
for (descriptor::Input& input : concat->get_inputs()) for (descriptor::Input& input : concat->get_inputs())
{ {
// check if input layout is padded
auto input_md = mkldnn_utils::get_input_mkldnn_md(n.get(), index);
index++;
if (mkldnn_utils::is_mkldnn_padded_layout(input_md, axis_list))
{
NGRAPH_DEBUG
<< "cpu_post_layout_assignment: padded input layout, no in place concat";
in_place_concat = false;
break;
}
if (shape_size(input.get_shape()) == 0) if (shape_size(input.get_shape()) == 0)
{ {
NGRAPH_DEBUG << "cpu_post_layout_assignment: 0 length tensor, no in " NGRAPH_DEBUG << "cpu_post_layout_assignment: 0 length tensor, no in "
...@@ -114,57 +132,37 @@ bool runtime::cpu::pass::CPUMemoryOptimization::run_on_function(std::shared_ptr< ...@@ -114,57 +132,37 @@ bool runtime::cpu::pass::CPUMemoryOptimization::run_on_function(std::shared_ptr<
break; break;
} }
std::unordered_set<Node*> visited; for (auto user : arg->get_users())
std::deque<Node*> stack;
stack.push_front(arg.get());
while (stack.size() > 0)
{ {
ngraph::Node* curr = stack.front(); if ((user != concat))
visited.insert(curr);
if (curr->is_output())
{ {
NGRAPH_DEBUG << "cpu_post_layout_assignment: not post " if (auto op = std::dynamic_pointer_cast<op::Op>(user))
"dominated, no in place concat";
in_place_concat = false;
break;
}
else
{
if (auto op = dynamic_cast<op::Op*>(curr))
{ {
if (auto op_annotations = op->get_op_annotations()) if (auto op_annotations = op->get_op_annotations())
{ {
for (auto oi_pair : op_annotations->get_in_place_oi_pairs()) for (auto oi_pair : op_annotations->get_in_place_oi_pairs())
{ {
if (oi_pair.destructive) NGRAPH_DEBUG << "cpu_post_layout_assignment: "
{ "in place oi, no in place concat";
NGRAPH_DEBUG << "cpu_post_layout_assignment: " in_place_concat = false;
"destructive in place oi, no " break;
"in place concat";
in_place_concat = false;
break;
}
} }
} }
} }
} }
stack.pop_front();
if (curr != concat.get())
{
for (auto next : curr->get_users())
{
if (visited.count(next.get()) == 0)
{
stack.push_front(next.get());
}
}
}
} }
if (!in_place_concat) if (!in_place_concat)
{ {
break; break;
} }
else if (!is_post_dominated(arg.get(), n.get()))
{
NGRAPH_DEBUG << "cpu_post_layout_assignment: "
"not post dominated, no in place concat";
in_place_concat = false;
break;
}
} }
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment