Commit 87197ec3 authored by amy.zhuang's avatar amy.zhuang

No in place concat if input format differs from output format.

parent 8f102516
...@@ -1006,19 +1006,19 @@ namespace ngraph ...@@ -1006,19 +1006,19 @@ namespace ngraph
{ {
writer << "if (" << args[i].get_name() << " < " << out[0].get_name() writer << "if (" << args[i].get_name() << " < " << out[0].get_name()
<< " || " << args[i].get_name() << " >= " << out[0].get_name() << " || " << args[i].get_name() << " >= " << out[0].get_name()
<< " + " << out[0].get_size() << ")\n"; << " + " << out[0].get_size() * out[0].get_element_type().size()
<< ")\n";
writer.block_begin(); writer.block_begin();
writer << "memcpy(" << out[0].get_name() << " + " << offset << ", " writer << "memcpy(" << out[0].get_name() << " + " << offset << ", "
<< args[i].get_name() << ", " << args[i].get_name() << ", "
<< args[i].get_size() * out[0].get_element_type().size() << args[i].get_size() * out[0].get_element_type().size()
<< ");\n"; << ");\n";
writer.block_end(); writer.block_end();
offset += args[i].get_size(); offset += args[i].get_size() * out[0].get_element_type().size();
} }
return; return;
} }
} }
auto result_shape = out[0].get_shape(); auto result_shape = out[0].get_shape();
#if USE_EIGEN_CORE_INLINE == 1 #if USE_EIGEN_CORE_INLINE == 1
......
...@@ -78,6 +78,25 @@ bool runtime::cpu::pass::CPUMemoryOptimization::run_on_function(std::shared_ptr< ...@@ -78,6 +78,25 @@ bool runtime::cpu::pass::CPUMemoryOptimization::run_on_function(std::shared_ptr<
} }
bool in_place_concat = true; bool in_place_concat = true;
auto output_md = mkldnn_utils::get_output_mkldnn_md(n.get(), 0);
auto output_format = static_cast<mkldnn::memory::format>(output_md.data.format);
for (size_t i = 0; i < n->get_input_size(); i++)
{
auto input_md = mkldnn_utils::get_input_mkldnn_md(n.get(), i);
auto input_format = static_cast<mkldnn::memory::format>(input_md.data.format);
if (output_format != input_format)
{
NGRAPH_DEBUG << "cpu_memory_optimization: input format is different from "
"output format, no in place concat";
in_place_concat = false;
break;
}
}
if (!in_place_concat)
{
continue;
}
AxisVector axis_list = ngraph::get_default_order(shape); AxisVector axis_list = ngraph::get_default_order(shape);
auto index = 0; auto index = 0;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment