Commit a87ee09e authored by amy.zhuang's avatar amy.zhuang

No in place concat immediately after non-concat in place ops.

parent 6ca1a511
...@@ -74,6 +74,8 @@ bool pass::MemoryLayout::run_on_function(shared_ptr<ngraph::Function> function) ...@@ -74,6 +74,8 @@ bool pass::MemoryLayout::run_on_function(shared_ptr<ngraph::Function> function)
? in_place_outputs.at(tensor)->get_pool_offset() ? in_place_outputs.at(tensor)->get_pool_offset()
: mm.allocate(tensor->size()); : mm.allocate(tensor->size());
tensor->set_pool_offset(offset); tensor->set_pool_offset(offset);
}
// check if the op is concat // check if the op is concat
if (auto concat = std::dynamic_pointer_cast<op::Concat>(node)) if (auto concat = std::dynamic_pointer_cast<op::Concat>(node))
{ {
...@@ -82,6 +84,8 @@ bool pass::MemoryLayout::run_on_function(shared_ptr<ngraph::Function> function) ...@@ -82,6 +84,8 @@ bool pass::MemoryLayout::run_on_function(shared_ptr<ngraph::Function> function)
auto in_place_oi_pairs = op_annotations->get_in_place_oi_pairs(); auto in_place_oi_pairs = op_annotations->get_in_place_oi_pairs();
if (in_place_oi_pairs.size() > 0) if (in_place_oi_pairs.size() > 0)
{ {
auto output_tensor = &concat->get_output_tensor();
auto offset = output_tensor->get_pool_offset();
for (auto arg : concat->get_arguments()) for (auto arg : concat->get_arguments())
{ {
auto input_node = std::dynamic_pointer_cast<op::Op>(arg); auto input_node = std::dynamic_pointer_cast<op::Op>(arg);
...@@ -95,7 +99,6 @@ bool pass::MemoryLayout::run_on_function(shared_ptr<ngraph::Function> function) ...@@ -95,7 +99,6 @@ bool pass::MemoryLayout::run_on_function(shared_ptr<ngraph::Function> function)
} }
} }
} }
}
if (!m_disable_memory_sharing) if (!m_disable_memory_sharing)
{ {
......
...@@ -66,6 +66,30 @@ bool runtime::cpu::pass::CPUMemoryOptimization::run_on_function(std::shared_ptr< ...@@ -66,6 +66,30 @@ bool runtime::cpu::pass::CPUMemoryOptimization::run_on_function(std::shared_ptr<
break; break;
} }
if (arg->get_output_size() != 1)
{
NGRAPH_DEBUG << "cpu_post_layout_assignment: " << arg->get_name()
<< ": multiple outputs, no in place concat";
in_place_concat = false;
break;
}
if (!std::dynamic_pointer_cast<op::Concat>(arg))
{
if (auto op = std::dynamic_pointer_cast<op::Op>(arg))
{
auto annotation = op->get_op_annotations();
if (annotation && annotation->get_in_place_oi_pairs().size() > 0)
{
NGRAPH_DEBUG << "cpu_post_layout_assignment: " << arg->get_name()
<< ": in place non concat op, no in place concat";
in_place_concat = false;
break;
}
}
}
if (output.get_inputs().size() != 1) if (output.get_inputs().size() != 1)
{ {
// check if we can do in place concat // check if we can do in place concat
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment