Commit a87ee09e authored by amy.zhuang's avatar amy.zhuang

No in place concat immediately after non-concat in place ops.

parent 6ca1a511
...@@ -74,24 +74,27 @@ bool pass::MemoryLayout::run_on_function(shared_ptr<ngraph::Function> function) ...@@ -74,24 +74,27 @@ bool pass::MemoryLayout::run_on_function(shared_ptr<ngraph::Function> function)
? in_place_outputs.at(tensor)->get_pool_offset() ? in_place_outputs.at(tensor)->get_pool_offset()
: mm.allocate(tensor->size()); : mm.allocate(tensor->size());
tensor->set_pool_offset(offset); tensor->set_pool_offset(offset);
// check if the op is concat }
if (auto concat = std::dynamic_pointer_cast<op::Concat>(node))
// check if the op is concat
if (auto concat = std::dynamic_pointer_cast<op::Concat>(node))
{
if (auto op_annotations = concat->get_op_annotations())
{ {
if (auto op_annotations = concat->get_op_annotations()) auto in_place_oi_pairs = op_annotations->get_in_place_oi_pairs();
if (in_place_oi_pairs.size() > 0)
{ {
auto in_place_oi_pairs = op_annotations->get_in_place_oi_pairs(); auto output_tensor = &concat->get_output_tensor();
if (in_place_oi_pairs.size() > 0) auto offset = output_tensor->get_pool_offset();
for (auto arg : concat->get_arguments())
{ {
for (auto arg : concat->get_arguments()) auto input_node = std::dynamic_pointer_cast<op::Op>(arg);
{ auto input_tensor = &input_node->get_output_tensor();
auto input_node = std::dynamic_pointer_cast<op::Op>(arg); auto old_offset = input_tensor->get_pool_offset();
auto input_tensor = &input_node->get_output_tensor(); input_tensor->set_pool_offset(offset);
auto old_offset = input_tensor->get_pool_offset(); NGRAPH_DEBUG << "memeory_layout: change offset, old offset is "
input_tensor->set_pool_offset(offset); << old_offset << ", new offset is " << offset << std::endl;
NGRAPH_DEBUG << "memeory_layout: change offset, old offset is " offset += input_tensor->size();
<< old_offset << ", new offset is " << offset << std::endl;
offset += input_tensor->size();
}
} }
} }
} }
......
...@@ -66,6 +66,30 @@ bool runtime::cpu::pass::CPUMemoryOptimization::run_on_function(std::shared_ptr< ...@@ -66,6 +66,30 @@ bool runtime::cpu::pass::CPUMemoryOptimization::run_on_function(std::shared_ptr<
break; break;
} }
if (arg->get_output_size() != 1)
{
NGRAPH_DEBUG << "cpu_post_layout_assignment: " << arg->get_name()
<< ": multiple outputs, no in place concat";
in_place_concat = false;
break;
}
if (!std::dynamic_pointer_cast<op::Concat>(arg))
{
if (auto op = std::dynamic_pointer_cast<op::Op>(arg))
{
auto annotation = op->get_op_annotations();
if (annotation && annotation->get_in_place_oi_pairs().size() > 0)
{
NGRAPH_DEBUG << "cpu_post_layout_assignment: " << arg->get_name()
<< ": in place non concat op, no in place concat";
in_place_concat = false;
break;
}
}
}
if (output.get_inputs().size() != 1) if (output.get_inputs().size() != 1)
{ {
// check if we can do in place concat // check if we can do in place concat
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment