Commit 5e0fdc60 authored by amy.zhuang's avatar amy.zhuang

Addressed PR feedback.

parent 4fb88c5b
...@@ -46,8 +46,6 @@ ...@@ -46,8 +46,6 @@
/// ///
/// After optimization: the result of add1 is stored to the memory buffer assigned to concat, same for add2 and add3. /// After optimization: the result of add1 is stored to the memory buffer assigned to concat, same for add2 and add3.
#include <cassert>
#include "ngraph/runtime/cpu/pass/cpu_memory_optimization.hpp" #include "ngraph/runtime/cpu/pass/cpu_memory_optimization.hpp"
#include "ngraph/descriptor/output.hpp" #include "ngraph/descriptor/output.hpp"
...@@ -80,16 +78,13 @@ bool runtime::cpu::pass::CPUMemoryOptimization::run_on_function(std::shared_ptr< ...@@ -80,16 +78,13 @@ bool runtime::cpu::pass::CPUMemoryOptimization::run_on_function(std::shared_ptr<
} }
bool in_place_concat = true; bool in_place_concat = true;
AxisVector axis_list; AxisVector axis_list = ngraph::get_default_order(shape);
for (auto i = 0; i < shape.size(); i++)
{
axis_list.push_back(i);
}
auto index = 0; auto index = 0;
for (descriptor::Input& input : concat->get_inputs()) for (descriptor::Input& input : concat->get_inputs())
{ {
// no tensors with zero-sized dimensions after zero_dim_tensor_elimination // no tensors with zero-sized dimensions after zero_dim_tensor_elimination
assert(shape_size(input.get_shape()) != 0); NGRAPH_ASSERT(shape_size(input.get_shape()) != 0);
// check if input layout is padded // check if input layout is padded
auto input_md = mkldnn_utils::get_input_mkldnn_md(n.get(), index); auto input_md = mkldnn_utils::get_input_mkldnn_md(n.get(), index);
...@@ -113,13 +108,7 @@ bool runtime::cpu::pass::CPUMemoryOptimization::run_on_function(std::shared_ptr< ...@@ -113,13 +108,7 @@ bool runtime::cpu::pass::CPUMemoryOptimization::run_on_function(std::shared_ptr<
break; break;
} }
if (arg->get_output_size() != 1) NGRAPH_ASSERT(arg->get_output_size() == 1);
{
NGRAPH_DEBUG << "cpu_memory_optimization: " << arg->get_name()
<< ": multiple outputs, no in place concat";
in_place_concat = false;
break;
}
if (!std::dynamic_pointer_cast<op::Concat>(arg)) if (!std::dynamic_pointer_cast<op::Concat>(arg))
{ {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment