Commit a87ee09e authored by amy.zhuang's avatar amy.zhuang

No in place concat immediately after non-concat in place ops.

parent 6ca1a511
......@@ -74,24 +74,27 @@ bool pass::MemoryLayout::run_on_function(shared_ptr<ngraph::Function> function)
? in_place_outputs.at(tensor)->get_pool_offset()
: mm.allocate(tensor->size());
tensor->set_pool_offset(offset);
// check if the op is concat
if (auto concat = std::dynamic_pointer_cast<op::Concat>(node))
}
// check if the op is concat
if (auto concat = std::dynamic_pointer_cast<op::Concat>(node))
{
if (auto op_annotations = concat->get_op_annotations())
{
if (auto op_annotations = concat->get_op_annotations())
auto in_place_oi_pairs = op_annotations->get_in_place_oi_pairs();
if (in_place_oi_pairs.size() > 0)
{
auto in_place_oi_pairs = op_annotations->get_in_place_oi_pairs();
if (in_place_oi_pairs.size() > 0)
auto output_tensor = &concat->get_output_tensor();
auto offset = output_tensor->get_pool_offset();
for (auto arg : concat->get_arguments())
{
for (auto arg : concat->get_arguments())
{
auto input_node = std::dynamic_pointer_cast<op::Op>(arg);
auto input_tensor = &input_node->get_output_tensor();
auto old_offset = input_tensor->get_pool_offset();
input_tensor->set_pool_offset(offset);
NGRAPH_DEBUG << "memeory_layout: change offset, old offset is "
<< old_offset << ", new offset is " << offset << std::endl;
offset += input_tensor->size();
}
auto input_node = std::dynamic_pointer_cast<op::Op>(arg);
auto input_tensor = &input_node->get_output_tensor();
auto old_offset = input_tensor->get_pool_offset();
input_tensor->set_pool_offset(offset);
NGRAPH_DEBUG << "memeory_layout: change offset, old offset is "
<< old_offset << ", new offset is " << offset << std::endl;
offset += input_tensor->size();
}
}
}
......
......@@ -66,6 +66,30 @@ bool runtime::cpu::pass::CPUMemoryOptimization::run_on_function(std::shared_ptr<
break;
}
if (arg->get_output_size() != 1)
{
NGRAPH_DEBUG << "cpu_post_layout_assignment: " << arg->get_name()
<< ": multiple outputs, no in place concat";
in_place_concat = false;
break;
}
if (!std::dynamic_pointer_cast<op::Concat>(arg))
{
if (auto op = std::dynamic_pointer_cast<op::Op>(arg))
{
auto annotation = op->get_op_annotations();
if (annotation && annotation->get_in_place_oi_pairs().size() > 0)
{
NGRAPH_DEBUG << "cpu_post_layout_assignment: " << arg->get_name()
<< ": in place non concat op, no in place concat";
in_place_concat = false;
break;
}
}
}
if (output.get_inputs().size() != 1)
{
// check if we can do in place concat
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment