Commit 69c51c27 authored by Jayaram Bobba's avatar Jayaram Bobba Committed by Scott Cyphers

Switch to using more expressive layout descriptors instead of numeric layout names (#1278)

* Switch to using mkldnn memory descriptors for layout

* More changes for using mkldnn descriptor instead of format

* Removed mkldnn format from cpu layout descriptor. TODO - shuffle folding

* Rotate mkldnn layouts on transpose

* Modifications to builder reshape to skip rotated layouts

* More fixes to layouts and removes axis order from cpu layout descriptor

* Code cleanup

* Removed shuffle folding pass since the functionality is subsumed by the layout pass

* Canonicalize a few more formats to keep MKLDNN happy.

* Style fixes

* Style fixes

* Style fixes

* Addressed PR feedback and added reshape passthrough for non-transpose cases

* Adjust named formats for weights tensors to keep MKLDNN happy

* Style fixes

* resolved merge issues
parent 5f77fe86
...@@ -79,7 +79,6 @@ set(SRC ...@@ -79,7 +79,6 @@ set(SRC
pass/cpu_rnn_fusion.cpp pass/cpu_rnn_fusion.cpp
pass/cpu_mat_fusion.cpp pass/cpu_mat_fusion.cpp
pass/cpu_loop_kernel_fusion.cpp pass/cpu_loop_kernel_fusion.cpp
pass/cpu_shuffle_folding.cpp
pass/cpu_workspace_insertion.cpp pass/cpu_workspace_insertion.cpp
) )
......
...@@ -53,10 +53,8 @@ namespace ngraph ...@@ -53,10 +53,8 @@ namespace ngraph
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node)) if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{ {
auto& mkldnn_emitter = external_function->get_mkldnn_emitter(); auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto input_desc = mkldnn_emitter->build_memory_descriptor( auto input_desc = mkldnn_utils::get_input_mkldnn_md(node, 0);
args[0], runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 0)); auto result_desc = mkldnn_utils::get_output_mkldnn_md(node, 0);
auto result_desc = mkldnn_emitter->build_memory_descriptor(
out[0], runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 0));
size_t avg_pool_index = mkldnn_emitter->build_pooling_forward( size_t avg_pool_index = mkldnn_emitter->build_pooling_forward(
(include_padding_in_avg_computation (include_padding_in_avg_computation
...@@ -132,10 +130,8 @@ namespace ngraph ...@@ -132,10 +130,8 @@ namespace ngraph
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node)) if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{ {
auto& mkldnn_emitter = external_function->get_mkldnn_emitter(); auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto diff_dst_desc = mkldnn_emitter->build_memory_descriptor( auto diff_dst_desc = runtime::cpu::mkldnn_utils::get_input_mkldnn_md(node, 0);
args[0], runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 0)); auto diff_src_desc = runtime::cpu::mkldnn_utils::get_output_mkldnn_md(node, 0);
auto diff_src_desc = mkldnn_emitter->build_memory_descriptor(
out[0], runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 0));
size_t avg_pool_index = mkldnn_emitter->build_pooling_backward( size_t avg_pool_index = mkldnn_emitter->build_pooling_backward(
(apb->get_include_padding_in_avg_computation() (apb->get_include_padding_in_avg_computation()
...@@ -159,7 +155,6 @@ namespace ngraph ...@@ -159,7 +155,6 @@ namespace ngraph
else else
{ {
std::function<decltype(runtime::cpu::kernel::avg_pool_backprop<float>)> kernel; std::function<decltype(runtime::cpu::kernel::avg_pool_backprop<float>)> kernel;
SELECT_KERNEL( SELECT_KERNEL(
kernel, out[0].get_element_type(), runtime::cpu::kernel::avg_pool_backprop); kernel, out[0].get_element_type(), runtime::cpu::kernel::avg_pool_backprop);
......
...@@ -77,26 +77,14 @@ namespace ngraph ...@@ -77,26 +77,14 @@ namespace ngraph
auto& out1_tensor = tensor_data[out[1].get_name()]; auto& out1_tensor = tensor_data[out[1].get_name()];
auto& out2_tensor = tensor_data[out[2].get_name()]; auto& out2_tensor = tensor_data[out[2].get_name()];
auto input_format =
runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 2);
auto result_format =
runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 0);
auto mean_format =
runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 1);
auto variance_format =
runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 2);
auto& mkldnn_emitter = external_function->get_mkldnn_emitter(); auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto input_desc = mkldnn_utils::get_input_mkldnn_md(node, 2);
auto weights_shape = Shape{2, args[0].get_size()}; auto weights_shape = Shape{2, args[0].get_size()};
auto input_desc =
mkldnn_emitter->build_memory_descriptor(args[2], input_format);
auto weights_desc = mkldnn_emitter->build_memory_descriptor( auto weights_desc = mkldnn_emitter->build_memory_descriptor(
weights_shape, args[0].get_element_type(), mkldnn::memory::format::nc); weights_shape, args[0].get_element_type(), mkldnn::memory::format::nc);
auto results_desc = auto results_desc = mkldnn_utils::get_output_mkldnn_md(node, 0);
mkldnn_emitter->build_memory_descriptor(out[0], result_format); auto mean_desc = mkldnn_utils::get_output_mkldnn_md(node, 1);
auto mean_desc = mkldnn_emitter->build_memory_descriptor(out[1], mean_format); auto variance_desc = mkldnn_utils::get_output_mkldnn_md(node, 2);
auto variance_desc =
mkldnn_emitter->build_memory_descriptor(out[2], variance_format);
auto batchnorm_index = auto batchnorm_index =
mkldnn_emitter->build_batchnorm_forward(input_desc, mkldnn_emitter->build_batchnorm_forward(input_desc,
...@@ -131,24 +119,14 @@ namespace ngraph ...@@ -131,24 +119,14 @@ namespace ngraph
auto& arg3_tensor = tensor_data[args[3].get_name()]; auto& arg3_tensor = tensor_data[args[3].get_name()];
auto& arg4_tensor = tensor_data[args[4].get_name()]; auto& arg4_tensor = tensor_data[args[4].get_name()];
auto input_format =
runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 2);
auto mean_format = runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 3);
auto variance_format =
runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 4);
auto result_format =
runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 0);
auto& mkldnn_emitter = external_function->get_mkldnn_emitter(); auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto weights_shape = Shape{2, args[0].get_size()}; auto weights_shape = Shape{2, args[0].get_size()};
auto input_desc = auto input_desc = mkldnn_utils::get_input_mkldnn_md(node, 2);
mkldnn_emitter->build_memory_descriptor(args[2], input_format);
auto weights_desc = mkldnn_emitter->build_memory_descriptor( auto weights_desc = mkldnn_emitter->build_memory_descriptor(
weights_shape, args[0].get_element_type(), mkldnn::memory::format::nc); weights_shape, args[0].get_element_type(), mkldnn::memory::format::nc);
auto mean_desc = mkldnn_emitter->build_memory_descriptor(args[3], mean_format); auto mean_desc = mkldnn_utils::get_input_mkldnn_md(node, 3);
auto variance_desc = auto variance_desc = mkldnn_utils::get_input_mkldnn_md(node, 4);
mkldnn_emitter->build_memory_descriptor(args[4], variance_format); auto results_desc = mkldnn_utils::get_output_mkldnn_md(node, 0);
auto results_desc =
mkldnn_emitter->build_memory_descriptor(out[0], result_format);
auto batchnorm_index = auto batchnorm_index =
mkldnn_emitter->build_batchnorm_forward(input_desc, mkldnn_emitter->build_batchnorm_forward(input_desc,
...@@ -298,22 +276,15 @@ namespace ngraph ...@@ -298,22 +276,15 @@ namespace ngraph
#pragma clang diagnostic pop #pragma clang diagnostic pop
auto input_format = runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 2);
auto mean_format = runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 3);
auto variance_format = runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 4);
auto delta_format = runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 5);
auto dinput_format = runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 0);
auto& mkldnn_emitter = external_function->get_mkldnn_emitter(); auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto weights_shape = Shape{2, args[0].get_size()}; auto weights_shape = Shape{2, args[0].get_size()};
auto weights_desc = mkldnn_emitter->build_memory_descriptor( auto weights_desc = mkldnn_emitter->build_memory_descriptor(
weights_shape, args[0].get_element_type(), mkldnn::memory::format::nc); weights_shape, args[0].get_element_type(), mkldnn::memory::format::nc);
auto input_desc = mkldnn_emitter->build_memory_descriptor(args[2], input_format); auto input_desc = mkldnn_utils::get_input_mkldnn_md(node, 2);
auto mean_desc = mkldnn_emitter->build_memory_descriptor(args[3], mean_format); auto mean_desc = mkldnn_utils::get_input_mkldnn_md(node, 3);
auto variance_desc = auto variance_desc = mkldnn_utils::get_input_mkldnn_md(node, 4);
mkldnn_emitter->build_memory_descriptor(args[4], variance_format); auto delta_desc = mkldnn_utils::get_input_mkldnn_md(node, 5);
auto delta_desc = mkldnn_emitter->build_memory_descriptor(args[5], delta_format); auto dinput_desc = mkldnn_utils::get_output_mkldnn_md(node, 0);
auto dinput_desc = mkldnn_emitter->build_memory_descriptor(out[0], dinput_format);
auto dweights_desc = mkldnn_emitter->build_memory_descriptor( auto dweights_desc = mkldnn_emitter->build_memory_descriptor(
weights_shape, args[0].get_element_type(), mkldnn::memory::format::nc); weights_shape, args[0].get_element_type(), mkldnn::memory::format::nc);
......
...@@ -37,42 +37,10 @@ namespace ngraph ...@@ -37,42 +37,10 @@ namespace ngraph
auto& arg_tensor = tensor_data[args[0].get_name()]; auto& arg_tensor = tensor_data[args[0].get_name()];
auto& out_tensor = tensor_data[out[0].get_name()]; auto& out_tensor = tensor_data[out[0].get_name()];
auto input_tvl =
node->get_inputs()[0].get_output().get_tensor_view()->get_tensor_view_layout();
auto input_cpu_tvl =
dynamic_pointer_cast<runtime::cpu::LayoutDescriptor>(input_tvl);
auto input_format = input_cpu_tvl->get_mkldnn_format();
// Reorder input shape if needed
auto input_axis_order = input_cpu_tvl->get_axis_order();
Shape input_shape(input_axis_order.size());
for (size_t idx = 0; idx < input_axis_order.size(); idx++)
{
input_shape[idx] = args[0].get_shape()[input_axis_order[idx]];
}
auto output_tvl = node->get_output_tensor_view(0)->get_tensor_view_layout();
auto output_format =
dynamic_cast<runtime::cpu::LayoutDescriptor&>(*output_tvl).get_mkldnn_format();
// MKLDNN relies on format names for selecting optimized kernel implementations
// Hacky way to deal with this until they move to using canonicalized layouts
if (input_format == mkldnn::memory::format::nchw &&
runtime::cpu::mkldnn_utils::is_mkldnn_filter_format(output_format))
{
input_format = mkldnn::memory::format::oihw;
}
if (output_format == mkldnn::memory::format::nchw &&
runtime::cpu::mkldnn_utils::is_mkldnn_filter_format(input_format))
{
output_format = mkldnn::memory::format::oihw;
}
auto& mkldnn_emitter = external_function->get_mkldnn_emitter(); auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto input_desc = mkldnn_emitter->build_memory_descriptor( auto input_desc = mkldnn_utils::get_input_mkldnn_md(node, 0);
input_shape, args[0].get_element_type(), input_format); auto result_desc = mkldnn_utils::get_output_mkldnn_md(node, 0);
auto result_desc = mkldnn_emitter->build_memory_descriptor(out[0], output_format);
size_t reorder_index = mkldnn_emitter->build_reorder(input_desc, result_desc); size_t reorder_index = mkldnn_emitter->build_reorder(input_desc, result_desc);
......
...@@ -429,15 +429,8 @@ namespace ngraph ...@@ -429,15 +429,8 @@ namespace ngraph
window_dilation_strides_adjusted.push_back(s - 1); window_dilation_strides_adjusted.push_back(s - 1);
} }
auto input_format =
runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 0);
auto output_format =
runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 0);
auto& mkldnn_emitter = external_function->get_mkldnn_emitter(); auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto input_data_desc = auto input_data_desc = mkldnn_utils::get_input_mkldnn_md(node, 0);
mkldnn_emitter->build_memory_descriptor(args[0], input_format);
Shape weights_shape_groups = convolution->get_weights_dimensions(); Shape weights_shape_groups = convolution->get_weights_dimensions();
...@@ -451,8 +444,7 @@ namespace ngraph ...@@ -451,8 +444,7 @@ namespace ngraph
auto padding_above = convolution->get_padding_above(); auto padding_above = convolution->get_padding_above();
auto filter_strides = convolution->get_window_movement_strides(); auto filter_strides = convolution->get_window_movement_strides();
auto result_desc = auto result_desc = mkldnn_utils::get_output_mkldnn_md(node, 0);
mkldnn_emitter->build_memory_descriptor(out[0], output_format);
auto weights_optimized_format = auto weights_optimized_format =
mkldnn_emitter->query_convolution_forward_weight_format( mkldnn_emitter->query_convolution_forward_weight_format(
......
...@@ -52,10 +52,8 @@ namespace ngraph ...@@ -52,10 +52,8 @@ namespace ngraph
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node)) if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{ {
auto& mkldnn_emitter = external_function->get_mkldnn_emitter(); auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto input_desc = mkldnn_emitter->build_memory_descriptor( auto input_desc = mkldnn_utils::get_input_mkldnn_md(node, 0);
args[0], runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 0)); auto result_desc = mkldnn_utils::get_output_mkldnn_md(node, 0);
auto result_desc = mkldnn_emitter->build_memory_descriptor(
out[0], runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 0));
size_t max_pool_index = size_t max_pool_index =
mkldnn_emitter->build_pooling_forward(mkldnn::algorithm::pooling_max, mkldnn_emitter->build_pooling_forward(mkldnn::algorithm::pooling_max,
...@@ -126,12 +124,9 @@ namespace ngraph ...@@ -126,12 +124,9 @@ namespace ngraph
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node)) if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{ {
auto& mkldnn_emitter = external_function->get_mkldnn_emitter(); auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto fprop_src_desc = mkldnn_emitter->build_memory_descriptor( auto fprop_src_desc = mkldnn_utils::get_input_mkldnn_md(node, 0);
args[0], runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 0)); auto diff_dst_desc = mkldnn_utils::get_input_mkldnn_md(node, 1);
auto diff_dst_desc = mkldnn_emitter->build_memory_descriptor( auto diff_src_desc = mkldnn_utils::get_output_mkldnn_md(node, 0);
args[1], runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 1));
auto diff_src_desc = mkldnn_emitter->build_memory_descriptor(
out[0], runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 0));
size_t max_pool_index = mkldnn_emitter->build_max_pooling_backward( size_t max_pool_index = mkldnn_emitter->build_max_pooling_backward(
mkldnn::algorithm::pooling_max, mkldnn::algorithm::pooling_max,
...@@ -210,10 +205,8 @@ namespace ngraph ...@@ -210,10 +205,8 @@ namespace ngraph
auto& out1_tensor = tensor_data[out[1].get_name()]; auto& out1_tensor = tensor_data[out[1].get_name()];
auto& mkldnn_emitter = external_function->get_mkldnn_emitter(); auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto input_desc = mkldnn_emitter->build_memory_descriptor( auto input_desc = runtime::cpu::mkldnn_utils::get_input_mkldnn_md(node, 0);
args[0], runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 0)); auto result_desc = runtime::cpu::mkldnn_utils::get_output_mkldnn_md(node, 0);
auto result_desc = mkldnn_emitter->build_memory_descriptor(
out[0], runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 0));
size_t max_pool_index = mkldnn_emitter->build_max_pooling_with_indices_forward( size_t max_pool_index = mkldnn_emitter->build_max_pooling_with_indices_forward(
mkldnn::algorithm::pooling_max, mkldnn::algorithm::pooling_max,
...@@ -253,10 +246,8 @@ namespace ngraph ...@@ -253,10 +246,8 @@ namespace ngraph
auto mpb = static_cast<const ngraph::op::MaxPoolWithIndicesBackprop*>(node); auto mpb = static_cast<const ngraph::op::MaxPoolWithIndicesBackprop*>(node);
auto& mkldnn_emitter = external_function->get_mkldnn_emitter(); auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto diff_dst_desc = mkldnn_emitter->build_memory_descriptor( auto diff_dst_desc = runtime::cpu::mkldnn_utils::get_input_mkldnn_md(node, 1);
args[1], runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 1)); auto diff_src_desc = runtime::cpu::mkldnn_utils::get_output_mkldnn_md(node, 0);
auto diff_src_desc = mkldnn_emitter->build_memory_descriptor(
out[0], runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 0));
size_t max_pool_index = mkldnn_emitter->build_max_pooling_with_indices_backward( size_t max_pool_index = mkldnn_emitter->build_max_pooling_with_indices_backward(
mkldnn::algorithm::pooling_max, mkldnn::algorithm::pooling_max,
......
...@@ -43,12 +43,9 @@ namespace ngraph ...@@ -43,12 +43,9 @@ namespace ngraph
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node)) if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{ {
auto& mkldnn_emitter = external_function->get_mkldnn_emitter(); auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto input_desc = mkldnn_emitter->build_memory_descriptor( auto input_desc = mkldnn_utils::get_input_mkldnn_md(node, 0);
args[0], runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 0)); auto delta_desc = mkldnn_utils::get_input_mkldnn_md(node, 1);
auto delta_desc = mkldnn_emitter->build_memory_descriptor( auto result_desc = mkldnn_utils::get_output_mkldnn_md(node, 0);
args[1], runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 1));
auto result_desc = mkldnn_emitter->build_memory_descriptor(
out[0], runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 0));
size_t relu_index = size_t relu_index =
mkldnn_emitter->build_relu_backward(input_desc, delta_desc, result_desc); mkldnn_emitter->build_relu_backward(input_desc, delta_desc, result_desc);
......
This diff is collapsed.
This diff is collapsed.
...@@ -151,7 +151,6 @@ ...@@ -151,7 +151,6 @@
#include "ngraph/runtime/cpu/pass/cpu_mat_fusion.hpp" #include "ngraph/runtime/cpu/pass/cpu_mat_fusion.hpp"
#include "ngraph/runtime/cpu/pass/cpu_post_layout_optimizations.hpp" #include "ngraph/runtime/cpu/pass/cpu_post_layout_optimizations.hpp"
#include "ngraph/runtime/cpu/pass/cpu_rnn_fusion.hpp" #include "ngraph/runtime/cpu/pass/cpu_rnn_fusion.hpp"
#include "ngraph/runtime/cpu/pass/cpu_shuffle_folding.hpp"
#include "ngraph/runtime/cpu/pass/cpu_workspace_insertion.hpp" #include "ngraph/runtime/cpu/pass/cpu_workspace_insertion.hpp"
#ifdef NGRAPH_DISTRIBUTED #ifdef NGRAPH_DISTRIBUTED
...@@ -374,7 +373,6 @@ void runtime::cpu::CPU_ExternalFunction::compile() ...@@ -374,7 +373,6 @@ void runtime::cpu::CPU_ExternalFunction::compile()
pass_manager.register_pass<runtime::cpu::pass::CPUAssignment>(this); pass_manager.register_pass<runtime::cpu::pass::CPUAssignment>(this);
pass_manager.register_pass<runtime::cpu::pass::CPULayout>(this); pass_manager.register_pass<runtime::cpu::pass::CPULayout>(this);
pass_manager.register_pass<runtime::cpu::pass::CPUPostLayoutOptimizations>(); pass_manager.register_pass<runtime::cpu::pass::CPUPostLayoutOptimizations>();
pass_manager.register_pass<runtime::cpu::pass::CPUShuffleFolding>();
pass_manager.register_pass<ngraph::pass::ResultCopyElimination>(); pass_manager.register_pass<ngraph::pass::ResultCopyElimination>();
pass_manager.register_pass<ngraph::pass::GetOutputElementElimination>(); pass_manager.register_pass<ngraph::pass::GetOutputElementElimination>();
unordered_map<Node*, Node*> node_function_map; unordered_map<Node*, Node*> node_function_map;
...@@ -980,17 +978,19 @@ using namespace ngraph::runtime; ...@@ -980,17 +978,19 @@ using namespace ngraph::runtime;
void runtime::cpu::CPU_ExternalFunction::propagate_in_place_input( void runtime::cpu::CPU_ExternalFunction::propagate_in_place_input(
ngraph::descriptor::Output* output, std::string input_name) ngraph::descriptor::Output* output, std::string input_name)
{ {
auto it = output; std::deque<ngraph::descriptor::Output*> stack;
auto propagate_further = false; stack.push_front(output);
do
while (stack.size() > 0)
{ {
propagate_further = false; ngraph::descriptor::Output* it = stack.front();
stack.pop_front();
for (auto input : it->get_inputs()) for (auto input : it->get_inputs())
{ {
auto c_op = std::dynamic_pointer_cast<ngraph::op::Op>(input->get_node()); auto c_op = std::dynamic_pointer_cast<ngraph::op::Op>(input->get_node());
if (!c_op || c_op->is_output()) if (!c_op || c_op->is_output())
{ {
break; continue;
} }
if (auto op_annotations = c_op->get_op_annotations()) if (auto op_annotations = c_op->get_op_annotations())
...@@ -1003,13 +1003,14 @@ void runtime::cpu::CPU_ExternalFunction::propagate_in_place_input( ...@@ -1003,13 +1003,14 @@ void runtime::cpu::CPU_ExternalFunction::propagate_in_place_input(
auto& output_tensor = c_op->get_outputs().at(output_index).get_tensor(); auto& output_tensor = c_op->get_outputs().at(output_index).get_tensor();
m_variable_name_map[output_tensor.get_name()] = input_name; m_variable_name_map[output_tensor.get_name()] = input_name;
it = &c_op->get_outputs().at(output_index); NGRAPH_DEBUG << "CPU codegen: Forwarding " << input_name << " through "
propagate_further = true; << output_tensor.get_name();
stack.push_back(&c_op->get_outputs().at(output_index));
} }
} }
} }
} }
} while (propagate_further); }
} }
void runtime::cpu::CPU_ExternalFunction::propagate_in_place_output( void runtime::cpu::CPU_ExternalFunction::propagate_in_place_output(
...@@ -1078,7 +1079,6 @@ void runtime::cpu::CPU_ExternalFunction::build() ...@@ -1078,7 +1079,6 @@ void runtime::cpu::CPU_ExternalFunction::build()
pass_manager.register_pass<runtime::cpu::pass::CPUAssignment>(this); pass_manager.register_pass<runtime::cpu::pass::CPUAssignment>(this);
pass_manager.register_pass<runtime::cpu::pass::CPULayout>(this); pass_manager.register_pass<runtime::cpu::pass::CPULayout>(this);
pass_manager.register_pass<runtime::cpu::pass::CPUPostLayoutOptimizations>(); pass_manager.register_pass<runtime::cpu::pass::CPUPostLayoutOptimizations>();
pass_manager.register_pass<runtime::cpu::pass::CPUShuffleFolding>();
pass_manager.register_pass<ngraph::pass::ResultCopyElimination>(); pass_manager.register_pass<ngraph::pass::ResultCopyElimination>();
pass_manager.register_pass<ngraph::pass::GetOutputElementElimination>(); pass_manager.register_pass<ngraph::pass::GetOutputElementElimination>();
pass_manager.register_pass<ngraph::pass::Liveness>(); pass_manager.register_pass<ngraph::pass::Liveness>();
......
...@@ -18,63 +18,46 @@ ...@@ -18,63 +18,46 @@
#include <algorithm> #include <algorithm>
#include <numeric> #include <numeric>
#include "ngraph/runtime/cpu/mkldnn_utils.hpp"
namespace ngraph namespace ngraph
{ {
namespace runtime namespace runtime
{ {
namespace cpu namespace cpu
{ {
const AxisVector LayoutDescriptor::Native2DAxisOrder{0, 1}; const mkldnn::memory::desc
const AxisVector LayoutDescriptor::Native4DAxisOrder{0, 1, 2, 3}; LayoutDescriptor::DummyDesc(mkldnn::memory::dims(TENSOR_MAX_DIMS),
const AxisVector LayoutDescriptor::CHWNAxisOrder{1, 2, 3, 0}; mkldnn::memory::f32,
mkldnn::memory::format::format_undef);
AxisVector LayoutDescriptor::create_native_axis_order(size_t rank)
{
AxisVector native_axis_order(rank);
std::iota(native_axis_order.begin(), native_axis_order.end(), 0);
return native_axis_order;
}
LayoutDescriptor::LayoutDescriptor(const ngraph::descriptor::TensorView& tv, LayoutDescriptor::LayoutDescriptor(const ngraph::descriptor::TensorView& tv)
const AxisVector& tv_axis_order)
: TensorViewLayout(tv) : TensorViewLayout(tv)
, axis_order(tv_axis_order) , m_offset(0)
, offset(0) , m_size(ngraph::shape_size(tv.get_tensor_view_type()->get_shape()))
, size(ngraph::shape_size(tv.get_tensor_view_type()->get_shape())) , m_mkldnn_md(LayoutDescriptor::DummyDesc)
, mkldnn_format(mkldnn::memory::format::format_undef)
{ {
auto shape = get_shape(); auto shape = get_shape();
size_t s = 1; size_t s = 1;
if (tv_axis_order.size() != shape.size()) for (size_t i = 0; i < shape.size(); i++)
{ {
throw ngraph_error("Axis order is incomplete"); m_strides.emplace_back(s);
s *= shape[shape.size() - (i + 1)];
} }
std::reverse(m_strides.begin(), m_strides.end());
for (auto it = tv_axis_order.crbegin(); it != tv_axis_order.crend(); it++)
{
if (*it >= shape.size())
{
throw ngraph_error("Axis is out of bounds");
}
strides.emplace_back(s);
s *= shape[*it];
}
std::reverse(strides.begin(), strides.end());
} }
void LayoutDescriptor::set_axis_order(const AxisVector& perm) { axis_order = perm; }
size_t LayoutDescriptor::get_index_offset(const std::vector<size_t>& indices) size_t LayoutDescriptor::get_index_offset(const std::vector<size_t>& indices)
{ {
if (indices.size() != strides.size()) if (indices.size() != m_strides.size())
{ {
throw ngraph_error("Indices have incorrect rank"); throw ngraph_error("Indices have incorrect rank");
} }
size_t result = 0; size_t result = 0;
for (int i = 0; i < indices.size(); i++) for (int i = 0; i < indices.size(); i++)
{ {
result += strides[i] + indices[i]; result += m_strides[i] * indices[i];
} }
return result; return result;
} }
...@@ -93,12 +76,22 @@ namespace ngraph ...@@ -93,12 +76,22 @@ namespace ngraph
return false; return false;
} }
if (strides != p_other->strides) if (p_other->is_mkldnn_layout())
{
if (!is_mkldnn_layout())
{
return false;
}
return runtime::cpu::mkldnn_utils::compare_mkldnn_mds(m_mkldnn_md,
p_other->get_mkldnn_md());
}
if (m_strides != p_other->m_strides)
{ {
return false; return false;
} }
if (offset != p_other->offset) if (m_offset != p_other->m_offset)
{ {
return false; return false;
} }
......
...@@ -23,7 +23,6 @@ ...@@ -23,7 +23,6 @@
#include <mkldnn.hpp> #include <mkldnn.hpp>
#include "ngraph/axis_vector.hpp"
#include "ngraph/descriptor/layout/tensor_view_layout.hpp" #include "ngraph/descriptor/layout/tensor_view_layout.hpp"
#include "ngraph/shape.hpp" #include "ngraph/shape.hpp"
#include "ngraph/type/type.hpp" #include "ngraph/type/type.hpp"
...@@ -37,36 +36,36 @@ namespace ngraph ...@@ -37,36 +36,36 @@ namespace ngraph
class LayoutDescriptor : public ngraph::descriptor::layout::TensorViewLayout class LayoutDescriptor : public ngraph::descriptor::layout::TensorViewLayout
{ {
public: public:
LayoutDescriptor(const ngraph::descriptor::TensorView& tv, LayoutDescriptor(const ngraph::descriptor::TensorView& tv);
const AxisVector& tv_axis_order);
~LayoutDescriptor() override {} ~LayoutDescriptor() override {}
size_t get_size() override { return size; } size_t get_size() override { return m_size; }
size_t get_offset() const { return offset; } size_t get_offset() const { return m_offset; }
size_t get_index_offset(const std::vector<size_t>& indices) override; size_t get_index_offset(const std::vector<size_t>& indices) override;
const Strides& get_strides() const override { return strides; } const Strides& get_strides() const override { return m_strides; }
void set_strides(Strides& strides) { m_strides = strides; }
bool operator==(const TensorViewLayout& other) const override; bool operator==(const TensorViewLayout& other) const override;
void set_mkldnn_format(const mkldnn::memory::format& format) const mkldnn::memory::desc& get_mkldnn_md() const { return m_mkldnn_md; }
void set_mkldnn_md(const mkldnn::memory::desc md) { m_mkldnn_md = md; }
bool is_mkldnn_layout() const
{ {
mkldnn_format = format; return m_mkldnn_md.data.format != mkldnn::memory::format::format_undef;
} }
mkldnn::memory::format get_mkldnn_format() const { return mkldnn_format; }
const AxisVector& get_axis_order() const { return axis_order; } static const mkldnn::memory::desc DummyDesc;
void set_axis_order(const AxisVector& perm);
static const AxisVector Native2DAxisOrder;
static const AxisVector Native4DAxisOrder;
static const AxisVector CHWNAxisOrder;
static AxisVector create_native_axis_order(size_t rank);
private: private:
AxisVector axis_order; // Native row-major layout for now
Strides strides; Strides m_strides;
size_t offset; size_t m_offset;
size_t size; size_t m_size;
// Numeric backend-specific fields // For tensor views that can be tracked with MKLDNN memory
mkldnn::memory::format mkldnn_format; // descriptors, this holds the physical layout information
// Otherwise, physical layout is assumed to be in row-major
// format represented by m_strides
mkldnn::memory::desc m_mkldnn_md;
}; };
typedef std::vector<std::shared_ptr<ngraph::runtime::cpu::LayoutDescriptor>> typedef std::vector<std::shared_ptr<ngraph::runtime::cpu::LayoutDescriptor>>
......
...@@ -46,8 +46,8 @@ runtime::cpu::CPUTensorView::CPUTensorView(const ngraph::element::Type& element_ ...@@ -46,8 +46,8 @@ runtime::cpu::CPUTensorView::CPUTensorView(const ngraph::element::Type& element_
// TODO(jmenon): A fallback layout should not be needed but is required // TODO(jmenon): A fallback layout should not be needed but is required
// because of how some unit test functionality is written (ex. 'backprop_derivative') // because of how some unit test functionality is written (ex. 'backprop_derivative')
// This needs to be removed // This needs to be removed
m_descriptor->set_tensor_view_layout(std::make_shared<runtime::cpu::LayoutDescriptor>( m_descriptor->set_tensor_view_layout(
*m_descriptor, runtime::cpu::LayoutDescriptor::create_native_axis_order(shape.size()))); std::make_shared<runtime::cpu::LayoutDescriptor>(*m_descriptor));
buffer_size = shape_size(shape) * element_type.size(); buffer_size = shape_size(shape) * element_type.size();
...@@ -119,23 +119,42 @@ void runtime::cpu::CPUTensorView::read(void* target, size_t tensor_offset, size_ ...@@ -119,23 +119,42 @@ void runtime::cpu::CPUTensorView::read(void* target, size_t tensor_offset, size_
auto tvl = this->get_tensor_view_layout(); auto tvl = this->get_tensor_view_layout();
auto cpu_tvl = dynamic_cast<runtime::cpu::LayoutDescriptor*>(tvl.get()); auto cpu_tvl = dynamic_cast<runtime::cpu::LayoutDescriptor*>(tvl.get());
if (cpu_tvl && cpu_tvl->get_mkldnn_format() != memory::format::format_undef &&
!runtime::cpu::mkldnn_utils::compare_mkldnn_formats( auto needs_conversion = [&]() {
cpu_tvl->get_mkldnn_format(), if (!cpu_tvl)
runtime::cpu::mkldnn_utils::CreateNativeDataFormat(*cpu_tvl))) {
return false;
}
if (!cpu_tvl->is_mkldnn_layout())
{
return false;
}
if (cpu_tvl->get_size() <= 1)
{
return false;
}
auto native_md = mkldnn_utils::create_blocked_mkldnn_md(
this->get_shape(),
cpu_tvl->get_strides(),
this->get_descriptor()->get_tensor_view_type()->get_element_type());
if (mkldnn_utils::compare_mkldnn_mds(cpu_tvl->get_mkldnn_md(), native_md))
{
return false;
}
return true;
};
if (needs_conversion())
{ {
auto tensor_shape = this->get_shape(); auto tensor_shape = this->get_shape();
auto input_format = cpu_tvl->get_mkldnn_format(); auto input_desc = cpu_tvl->get_mkldnn_md();
auto output_format = runtime::cpu::mkldnn_utils::CreateNativeDataFormat(*cpu_tvl); auto output_desc = mkldnn_utils::create_blocked_mkldnn_md(
memory::data_type et = runtime::cpu::mkldnn_utils::get_mkldnn_data_type( this->get_shape(),
cpu_tvl->get_strides(),
this->get_descriptor()->get_tensor_view_type()->get_element_type()); this->get_descriptor()->get_tensor_view_type()->get_element_type());
engine cpu_engine{engine::cpu, 0}; memory input{{input_desc, mkldnn_utils::global_cpu_engine}, aligned_buffer};
memory::dims mkldnn_shape{tensor_shape.begin(), tensor_shape.end()}; memory output{{output_desc, mkldnn_utils::global_cpu_engine}, target};
memory::desc input_desc{mkldnn_shape, et, input_format};
memory::desc output_desc{mkldnn_shape, et, output_format};
memory input{{input_desc, cpu_engine}, aligned_buffer};
memory output{{output_desc, cpu_engine}, target};
reorder prim{input, output}; reorder prim{input, output};
mkldnn::stream s(mkldnn::stream::kind::eager); mkldnn::stream s(mkldnn::stream::kind::eager);
s.submit({prim}).wait(); s.submit({prim}).wait();
......
...@@ -64,24 +64,24 @@ const std::vector<size_t>& MKLDNNEmitter::get_primitive_deps(size_t index) const ...@@ -64,24 +64,24 @@ const std::vector<size_t>& MKLDNNEmitter::get_primitive_deps(size_t index) const
mkldnn::memory::desc MKLDNNEmitter::build_memory_descriptor(const TensorViewWrapper& tvw, mkldnn::memory::desc MKLDNNEmitter::build_memory_descriptor(const TensorViewWrapper& tvw,
mkldnn::memory::format fmt) const mkldnn::memory::format fmt) const
{ {
if (fmt == mkldnn::memory::format::blocked)
{
throw ngraph_error("Cannot created blocked descriptor.");
}
return mkldnn::memory::desc( return mkldnn::memory::desc(
mkldnn::memory::dims(tvw.get_shape().begin(), tvw.get_shape().end()), mkldnn::memory::dims(tvw.get_shape().begin(), tvw.get_shape().end()),
mkldnn_utils::get_mkldnn_data_type(tvw.get_element_type()), mkldnn_utils::get_mkldnn_data_type(tvw.get_element_type()),
fmt); fmt);
} }
mkldnn::memory::desc MKLDNNEmitter::build_memory_descriptor(const TensorViewWrapper& tvw) const
{
auto layout =
std::static_pointer_cast<LayoutDescriptor>(tvw.get_tensor_view()->get_tensor_view_layout());
return build_memory_descriptor(tvw, layout->get_mkldnn_format());
}
mkldnn::memory::desc MKLDNNEmitter::build_memory_descriptor(const Shape& shape, mkldnn::memory::desc MKLDNNEmitter::build_memory_descriptor(const Shape& shape,
const ngraph::element::Type& et, const ngraph::element::Type& et,
mkldnn::memory::format fmt) const mkldnn::memory::format fmt) const
{ {
if (fmt == mkldnn::memory::format::blocked)
{
throw ngraph_error("Cannot created blocked descriptor");
}
return mkldnn::memory::desc(mkldnn::memory::dims(shape.begin(), shape.end()), return mkldnn::memory::desc(mkldnn::memory::dims(shape.begin(), shape.end()),
mkldnn_utils::get_mkldnn_data_type(et), mkldnn_utils::get_mkldnn_data_type(et),
fmt); fmt);
...@@ -112,11 +112,6 @@ mkldnn::memory::desc ...@@ -112,11 +112,6 @@ mkldnn::memory::desc
return mkldnn::memory::desc(md); return mkldnn::memory::desc(md);
} }
mkldnn::memory MKLDNNEmitter::build_memory_primitive(const TensorViewWrapper& tvw) const
{
return mkldnn::memory({build_memory_descriptor(tvw), mkldnn_utils::global_cpu_engine}, nullptr);
}
size_t MKLDNNEmitter::build_memory_primitive(const mkldnn::memory::desc& desc) size_t MKLDNNEmitter::build_memory_primitive(const mkldnn::memory::desc& desc)
{ {
// The MKL-DNN C++ API forces proper initialization of a memory primitive // The MKL-DNN C++ API forces proper initialization of a memory primitive
...@@ -200,24 +195,34 @@ size_t MKLDNNEmitter::build_convolution_forward(const mkldnn::memory::desc& inpu ...@@ -200,24 +195,34 @@ size_t MKLDNNEmitter::build_convolution_forward(const mkldnn::memory::desc& inpu
mkldnn::primitive_attr conv_attr; mkldnn::primitive_attr conv_attr;
conv_attr.set_post_ops(pops); conv_attr.set_post_ops(pops);
size_t conv_index = insert_primitive(new mkldnn::convolution_forward( size_t conv_index = 0;
{{mkldnn::prop_kind::forward, try
mkldnn::algorithm::convolution_direct, {
input_data_desc, auto conv_prim = new mkldnn::convolution_forward(
weights_desc, {{mkldnn::prop_kind::forward,
result_desc, mkldnn::algorithm::convolution_direct,
mkldnn::memory::dims(strides.begin(), strides.end()), input_data_desc,
mkldnn::memory::dims(dilation_strides.begin(), dilation_strides.end()), weights_desc,
mkldnn::memory::dims(padding_below.begin(), padding_below.end()), result_desc,
mkldnn::memory::dims(padding_above.begin(), padding_above.end()), mkldnn::memory::dims(strides.begin(), strides.end()),
mkldnn::padding_kind::zero}, mkldnn::memory::dims(dilation_strides.begin(), dilation_strides.end()),
conv_attr, mkldnn::memory::dims(padding_below.begin(), padding_below.end()),
mkldnn_utils::global_cpu_engine}, mkldnn::memory::dims(padding_above.begin(), padding_above.end()),
*m_mkldnn_primitives[input_data_index], mkldnn::padding_kind::zero},
*m_mkldnn_primitives[weights_index], conv_attr,
*m_mkldnn_primitives[result_index])); mkldnn_utils::global_cpu_engine},
*m_mkldnn_primitives[input_data_index],
*m_mkldnn_primitives[weights_index],
*m_mkldnn_primitives[result_index]);
conv_index = insert_primitive(conv_prim);
m_primitive_deps[conv_index] = {input_data_index, weights_index, result_index}; m_primitive_deps[conv_index] = {input_data_index, weights_index, result_index};
}
catch (const mkldnn::error& e)
{
throw ngraph_error("Could not create mkldnn convolution " + e.message);
}
return conv_index; return conv_index;
} }
...@@ -239,26 +244,34 @@ size_t MKLDNNEmitter::build_convolution_forward(const mkldnn::memory::desc& inpu ...@@ -239,26 +244,34 @@ size_t MKLDNNEmitter::build_convolution_forward(const mkldnn::memory::desc& inpu
mkldnn::primitive_attr conv_attr; mkldnn::primitive_attr conv_attr;
conv_attr.set_post_ops(pops); conv_attr.set_post_ops(pops);
const size_t conv_index = insert_primitive(new mkldnn::convolution_forward( size_t conv_index = -1;
{{mkldnn::prop_kind::forward, try
mkldnn::algorithm::convolution_direct, {
input_data_desc, conv_index = insert_primitive(new mkldnn::convolution_forward(
weights_desc, {{mkldnn::prop_kind::forward,
bias_desc, mkldnn::algorithm::convolution_direct,
result_desc, input_data_desc,
mkldnn::memory::dims(strides.begin(), strides.end()), weights_desc,
mkldnn::memory::dims(dilation_strides.begin(), dilation_strides.end()), bias_desc,
mkldnn::memory::dims(padding_below.begin(), padding_below.end()), result_desc,
mkldnn::memory::dims(padding_above.begin(), padding_above.end()), mkldnn::memory::dims(strides.begin(), strides.end()),
mkldnn::padding_kind::zero}, mkldnn::memory::dims(dilation_strides.begin(), dilation_strides.end()),
conv_attr, mkldnn::memory::dims(padding_below.begin(), padding_below.end()),
mkldnn_utils::global_cpu_engine}, mkldnn::memory::dims(padding_above.begin(), padding_above.end()),
*m_mkldnn_primitives[input_data_index], mkldnn::padding_kind::zero},
*m_mkldnn_primitives[weights_index], conv_attr,
*m_mkldnn_primitives[bias_index], mkldnn_utils::global_cpu_engine},
*m_mkldnn_primitives[result_index])); *m_mkldnn_primitives[input_data_index],
*m_mkldnn_primitives[weights_index],
*m_mkldnn_primitives[bias_index],
*m_mkldnn_primitives[result_index]));
m_primitive_deps[conv_index] = {input_data_index, weights_index, bias_index, result_index}; m_primitive_deps[conv_index] = {input_data_index, weights_index, bias_index, result_index};
}
catch (const mkldnn::error& e)
{
throw ngraph_error("Could not create convolution " + e.message);
}
return conv_index; return conv_index;
} }
......
...@@ -66,7 +66,6 @@ namespace ngraph ...@@ -66,7 +66,6 @@ namespace ngraph
// TODO(jmenon): Get rid of TensorViewWrappers at some point // TODO(jmenon): Get rid of TensorViewWrappers at some point
mkldnn::memory::desc build_memory_descriptor(const TensorViewWrapper& tvw, mkldnn::memory::desc build_memory_descriptor(const TensorViewWrapper& tvw,
mkldnn::memory::format fmt) const; mkldnn::memory::format fmt) const;
mkldnn::memory::desc build_memory_descriptor(const TensorViewWrapper& tvw) const;
mkldnn::memory::desc build_memory_descriptor(const Shape& shape, mkldnn::memory::desc build_memory_descriptor(const Shape& shape,
const ngraph::element::Type& et, const ngraph::element::Type& et,
mkldnn::memory::format fmt) const; mkldnn::memory::format fmt) const;
...@@ -74,7 +73,6 @@ namespace ngraph ...@@ -74,7 +73,6 @@ namespace ngraph
build_blocked_memory_descriptor(const mkldnn::memory::dims& dim, build_blocked_memory_descriptor(const mkldnn::memory::dims& dim,
const mkldnn::memory::dims& strides, const mkldnn::memory::dims& strides,
mkldnn::memory::data_type dtype) const; mkldnn::memory::data_type dtype) const;
mkldnn::memory build_memory_primitive(const TensorViewWrapper& tvw) const;
size_t build_memory_primitive(const mkldnn::memory::desc& desc); size_t build_memory_primitive(const mkldnn::memory::desc& desc);
size_t build_convolution_forward(const mkldnn::memory::desc& input_data_desc, size_t build_convolution_forward(const mkldnn::memory::desc& input_data_desc,
...@@ -115,19 +113,16 @@ namespace ngraph ...@@ -115,19 +113,16 @@ namespace ngraph
window_dilation_strides_adjusted.push_back(s - 1); window_dilation_strides_adjusted.push_back(s - 1);
} }
auto data_format = mkldnn_utils::get_input_mkldnn_format(node, 0); auto data_desc = mkldnn_utils::get_input_mkldnn_md(node, 0);
auto weights_format = mkldnn_utils::get_input_mkldnn_format(node, 1); auto weights_desc = mkldnn_utils::get_input_mkldnn_md(node, 1);
// HACK to help MKLDNN pick the right implementation // MKLDNN relies on named formats for kernel selection
if (weights_format == mkldnn::memory::format::nchw) if (weights_desc.data.format == mkldnn_nchw)
{ weights_desc.data.format = mkldnn_oihw;
weights_format = mkldnn::memory::format::oihw; if (weights_desc.data.format == mkldnn_ncdhw)
} weights_desc.data.format = mkldnn_oidhw;
auto result_format = mkldnn_utils::get_output_mkldnn_format(node, 0);
auto data_desc = build_memory_descriptor(args[0], data_format); auto result_desc = mkldnn_utils::get_output_mkldnn_md(node, 0);
auto weights_desc = build_memory_descriptor(args[1], weights_format);
auto result_desc = build_memory_descriptor(out[0], result_format);
mkldnn::post_ops ops; mkldnn::post_ops ops;
...@@ -166,8 +161,7 @@ namespace ngraph ...@@ -166,8 +161,7 @@ namespace ngraph
if (std::is_same<OP, ngraph::op::ConvolutionBias>() || if (std::is_same<OP, ngraph::op::ConvolutionBias>() ||
std::is_same<OP, ngraph::op::ConvolutionBiasAdd>()) std::is_same<OP, ngraph::op::ConvolutionBiasAdd>())
{ {
auto bias_format = mkldnn_utils::get_input_mkldnn_format(node, 2); auto bias_desc = mkldnn_utils::get_input_mkldnn_md(node, 2);
auto bias_desc = build_memory_descriptor(args[2], bias_format);
return build_convolution_forward(data_desc, return build_convolution_forward(data_desc,
weights_desc, weights_desc,
bias_desc, bias_desc,
...@@ -254,22 +248,18 @@ namespace ngraph ...@@ -254,22 +248,18 @@ namespace ngraph
window_dilation_strides_adjusted.push_back(s - 1); window_dilation_strides_adjusted.push_back(s - 1);
} }
auto arg0_format = mkldnn_utils::get_input_mkldnn_format(node, 0); auto arg0_desc = mkldnn_utils::get_input_mkldnn_md(node, 0);
if (std::is_same<OP, ngraph::op::ConvolutionBackpropData>()) auto arg1_desc = mkldnn_utils::get_input_mkldnn_md(node, 1);
{ auto out0_desc = mkldnn_utils::get_output_mkldnn_md(node, 0);
// HACK to help MKLDNN pick the right implementation
arg0_format = (arg0_format == mkldnn::memory::format::nchw)
? mkldnn::memory::format::oihw
: arg0_format;
}
auto arg0_desc = build_memory_descriptor(args[0], arg0_format);
auto arg1_format = mkldnn_utils::get_input_mkldnn_format(node, 1);
auto arg1_desc = build_memory_descriptor(args[1], arg1_format);
auto out0_format = mkldnn_utils::get_output_mkldnn_format(node, 0);
auto out0_desc = build_memory_descriptor(out[0], out0_format);
if (std::is_same<OP, ngraph::op::ConvolutionBackpropData>()) if (std::is_same<OP, ngraph::op::ConvolutionBackpropData>())
{ {
// MKLDNN relies on named formats for kernel selection
if (arg0_desc.data.format == mkldnn_nchw)
arg0_desc.data.format = mkldnn_oihw;
if (arg0_desc.data.format == mkldnn_ncdhw)
arg0_desc.data.format = mkldnn_oidhw;
return build_convolution_backward_data( return build_convolution_backward_data(
arg0_desc, arg0_desc,
arg1_desc, arg1_desc,
...@@ -292,8 +282,7 @@ namespace ngraph ...@@ -292,8 +282,7 @@ namespace ngraph
} }
if (std::is_same<OP, ngraph::op::ConvolutionBiasBackpropFiltersBias>()) if (std::is_same<OP, ngraph::op::ConvolutionBiasBackpropFiltersBias>())
{ {
auto out1_format = mkldnn_utils::get_output_mkldnn_format(node, 1); auto out1_desc = mkldnn_utils::get_output_mkldnn_md(node, 1);
auto out1_desc = build_memory_descriptor(out[1], out1_format);
return build_convolution_backward_weights_bias( return build_convolution_backward_weights_bias(
arg0_desc, arg0_desc,
arg1_desc, arg1_desc,
......
...@@ -36,5 +36,12 @@ extern "C" void ngraph::runtime::cpu::mkldnn_utils::mkldnn_invoke_primitive(CPUR ...@@ -36,5 +36,12 @@ extern "C" void ngraph::runtime::cpu::mkldnn_utils::mkldnn_invoke_primitive(CPUR
size_t primitive_index) size_t primitive_index)
{ {
mkldnn::stream s(mkldnn::stream::kind::eager); mkldnn::stream s(mkldnn::stream::kind::eager);
s.submit({*ctx->mkldnn_primitives[primitive_index]}).wait(); try
{
s.submit({*ctx->mkldnn_primitives[primitive_index]}).wait();
}
catch (const mkldnn::error& e)
{
throw ngraph_error("Could not run mkdnn primitive " + e.message);
}
} }
...@@ -193,7 +193,7 @@ mkldnn::memory::data_type ...@@ -193,7 +193,7 @@ mkldnn::memory::data_type
runtime::cpu::mkldnn_utils::get_mkldnn_data_type(const ngraph::element::Type& type) runtime::cpu::mkldnn_utils::get_mkldnn_data_type(const ngraph::element::Type& type)
{ {
auto it = s_mkldnn_data_type_map.find(type); auto it = s_mkldnn_data_type_map.find(type);
if (it == s_mkldnn_data_type_map.end() || it->second == memory::data_type::data_undef) if (it == s_mkldnn_data_type_map.end())
{ {
throw ngraph_error("No MKLDNN data type exists for the given element type"); throw ngraph_error("No MKLDNN data type exists for the given element type");
} }
...@@ -209,18 +209,211 @@ const std::string& runtime::cpu::mkldnn_utils::get_mkldnn_format_string(memory:: ...@@ -209,18 +209,211 @@ const std::string& runtime::cpu::mkldnn_utils::get_mkldnn_format_string(memory::
return it->second; return it->second;
} }
mkldnn::memory::format runtime::cpu::mkldnn_utils::get_input_mkldnn_format(const Node* node, const mkldnn::memory::desc& runtime::cpu::mkldnn_utils::get_input_mkldnn_md(const Node* node,
size_t index) size_t index)
{ {
auto tvl = node->get_inputs()[index].get_output().get_tensor_view()->get_tensor_view_layout(); auto cpu_tvl = dynamic_pointer_cast<runtime::cpu::LayoutDescriptor>(
return dynamic_cast<runtime::cpu::LayoutDescriptor&>(*tvl).get_mkldnn_format(); node->get_inputs()[index].get_output().get_tensor_view()->get_tensor_view_layout());
return cpu_tvl->get_mkldnn_md();
} }
mkldnn::memory::format runtime::cpu::mkldnn_utils::get_output_mkldnn_format(const Node* node, const mkldnn::memory::desc& runtime::cpu::mkldnn_utils::get_output_mkldnn_md(const Node* node,
size_t index) size_t index)
{ {
auto tvl = node->get_output_tensor_view(index)->get_tensor_view_layout(); auto tvl = node->get_output_tensor_view(index)->get_tensor_view_layout();
return dynamic_cast<runtime::cpu::LayoutDescriptor&>(*tvl).get_mkldnn_format(); return dynamic_cast<runtime::cpu::LayoutDescriptor&>(*tvl).get_mkldnn_md();
}
// Builds a default MKLDNN memory descriptor for the indexed input or output
// tensor of `node`, using the given named format (defaults to `any`, which
// lets MKLDNN pick the layout).
//
// BUG FIX: the element type was previously always taken from input/output 0
// regardless of `index`, while the shape honored `index`. Both now use `index`.
mkldnn::memory::desc runtime::cpu::mkldnn_utils::create_default_mkldnn_md(
    const Node* node,
    size_t index,
    bool output = false,
    mkldnn::memory::format format = mkldnn::memory::format::any)
{
    Shape shape;
    mkldnn::memory::data_type et;
    if (output)
    {
        shape = node->get_output_shape(index);
        et = runtime::cpu::mkldnn_utils::get_mkldnn_data_type(
            node->get_output_element_type(index));
    }
    else
    {
        shape = node->get_input_shape(index);
        et = runtime::cpu::mkldnn_utils::get_mkldnn_data_type(
            node->get_input_element_type(index));
    }

    return memory::desc(memory::dims(shape.begin(), shape.end()), et, format);
}
// Reports whether an MKLDNN memory descriptor can represent a tensor with the
// given shape/strides/element type: the element type must map to a defined
// MKLDNN data type, the rank must fit MKLDNN's limit, and the tensor must be
// non-empty. (`strides` is currently not consulted.)
bool runtime::cpu::mkldnn_utils::can_create_mkldnn_md(const Shape& dims,
                                                      const Strides& strides,
                                                      const ngraph::element::Type type)
{
    auto entry = s_mkldnn_data_type_map.find(type);
    if (entry == s_mkldnn_data_type_map.end() ||
        entry->second == mkldnn::memory::data_type::data_undef)
    {
        return false;
    }
    return dims.size() <= TENSOR_MAX_DIMS && shape_size(dims) != 0;
}
// Returns true when the strides `a`, visited in `perm` order, are
// non-increasing — i.e. `perm` lists axes from outermost to innermost.
//
// BUG FIX: the previous loop bound `i < a.size() - 1` underflows for an empty
// `a` (size_t wraps to SIZE_MAX), reading out of bounds. `i + 1 < a.size()`
// is safe for all sizes and returns true for empty/single-element input.
bool runtime::cpu::mkldnn_utils::is_perm_sorted(const Strides& a, const AxisVector& perm)
{
    for (size_t i = 0; i + 1 < a.size(); i++)
    {
        if (a[perm[i]] < a[perm[i + 1]])
        {
            return false;
        }
    }
    return true;
}
// Builds an MKLDNN memory descriptor for a tensor with the given shape and
// per-axis element strides. When the stride pattern matches a named MKLDNN
// format (x, nc, nchw, nhwc, ncdhw, ndhwc) that format is used — MKLDNN
// relies on named formats for kernel selection — otherwise a generic
// "blocked" descriptor with explicit strides is constructed.
mkldnn::memory::desc runtime::cpu::mkldnn_utils::create_blocked_mkldnn_md(
    const Shape& dims, const Strides& strides, const ngraph::element::Type type)
{
    memory::dims dim(dims.begin(), dims.end());
    memory::dims stride(strides.begin(), strides.end());
    memory::data_type dtype = get_mkldnn_data_type(type);

    if (dims.size() == 1)
    {
        return memory::desc(dim, dtype, memory::format::x);
    }
    if (dims.size() == 2)
    {
        if (is_perm_sorted(strides, {0, 1}))
        {
            return memory::desc(dim, dtype, memory::format::nc);
        }
    }
    if (dims.size() == 4)
    {
        if (is_perm_sorted(strides, {0, 1, 2, 3}))
        {
            return memory::desc(dim, dtype, memory::format::nchw);
        }
        if (is_perm_sorted(strides, {0, 2, 3, 1}))
        {
            return memory::desc(dim, dtype, memory::format::nhwc);
        }
    }
    if (dims.size() == 5)
    {
        if (is_perm_sorted(strides, {0, 1, 2, 3, 4}))
        {
            return memory::desc(dim, dtype, memory::format::ncdhw);
        }
        if (is_perm_sorted(strides, {0, 2, 3, 4, 1}))
        {
            return memory::desc(dim, dtype, memory::format::ndhwc);
        }
    }

    // Fallback: hand-build a blocked descriptor via the MKLDNN C API.
    // NOTE(review): `md` is stack-allocated and only the fields below are
    // assigned; remaining fields of mkldnn_memory_desc_t are left
    // uninitialized — confirm MKLDNN ignores them for mkldnn_blocked.
    mkldnn_memory_desc_t md;
    md.primitive_kind = mkldnn_memory;
    md.ndims = static_cast<int>(dim.size());
    md.format = mkldnn_blocked;
    md.data_type = mkldnn::memory::convert_to_c(dtype);

    for (size_t i = 0; i < dim.size(); i++)
    {
        // block_dims of 1 => no inner blocking; strides[0] are the outer
        // (between-block) strides, strides[1] the within-block strides.
        md.layout_desc.blocking.block_dims[i] = 1;
        md.layout_desc.blocking.strides[1][i] = 1;
        md.layout_desc.blocking.strides[0][i] = stride[i];
        md.layout_desc.blocking.padding_dims[i] = dim[i];
        md.layout_desc.blocking.offset_padding_to_data[i] = 0;
        md.dims[i] = dim[i];
    }
    md.layout_desc.blocking.offset_padding = 0;

    return memory::desc(md);
}
// Produces a descriptor equivalent to `in` with its dimensions — and all
// per-dimension blocking metadata — permuted by `axis_order`, as happens on a
// transpose. The result is first built as a generic "blocked" layout and then
// canonicalized to a named MKLDNN format when an identical one exists, since
// MKLDNN kernel selection relies on named formats.
memory::desc runtime::cpu::mkldnn_utils::rotate_blocked_md(const memory::desc& in,
                                                           AxisVector& axis_order)
{
    mkldnn_memory_desc_t md;
    md.primitive_kind = in.data.primitive_kind;
    md.ndims = in.data.ndims;
    md.format = mkldnn_blocked;
    md.data_type = in.data.data_type;

    // Permute every per-dimension field of the blocking descriptor.
    for (size_t i = 0; i < in.data.ndims; i++)
    {
        md.layout_desc.blocking.block_dims[i] =
            in.data.layout_desc.blocking.block_dims[axis_order[i]];
        md.layout_desc.blocking.strides[1][i] =
            in.data.layout_desc.blocking.strides[1][axis_order[i]];
        md.layout_desc.blocking.strides[0][i] =
            in.data.layout_desc.blocking.strides[0][axis_order[i]];
        md.layout_desc.blocking.padding_dims[i] =
            in.data.layout_desc.blocking.padding_dims[axis_order[i]];
        md.layout_desc.blocking.offset_padding_to_data[i] =
            in.data.layout_desc.blocking.offset_padding_to_data[axis_order[i]];
        md.dims[i] = in.data.dims[axis_order[i]];
    }
    md.layout_desc.blocking.offset_padding = in.data.layout_desc.blocking.offset_padding;

    auto out_md = memory::desc(md);

    // Builds a descriptor for the same dims/data type using a named format.
    auto get_named_md = [](const mkldnn_memory_desc_t& blk, const mkldnn_memory_format_t format) {
        mkldnn_memory_desc_t named_md;
        // Could throw an exception if named `format` is not compatible with `md.dims`
        error::wrap_c_api(
            mkldnn_memory_desc_init(&named_md, blk.ndims, blk.dims, blk.data_type, format), "");
        return memory::desc(named_md);
    };

    // True when the named-format descriptor describes the same memory layout
    // as the rotated blocked descriptor `out`.
    auto compare_named_md = [&](const mkldnn_memory_desc_t& blk,
                                const mkldnn_memory_format_t format,
                                const memory::desc& out) {
        try
        {
            auto named_md = get_named_md(blk, format);
            if (compare_mkldnn_mds(named_md, out))
            {
                return true;
            }
        }
        catch (const mkldnn::error&)
        {
            // Cannot create the named descriptor compatible with `in` desc
            return false;
        }
        return false;
    };

// Returns the named-format descriptor when it matches the rotated layout.
#define CANONICALIZE_MD(X)                                                                         \
    if (compare_named_md(md, X, out_md))                                                           \
        return get_named_md(md, X);
    switch (md.ndims)
    {
    case 1: CANONICALIZE_MD(mkldnn_x); break;
    case 2: CANONICALIZE_MD(mkldnn_nc); break;
    case 4:
        CANONICALIZE_MD(mkldnn_nchw);
        CANONICALIZE_MD(mkldnn_nhwc);
        CANONICALIZE_MD(mkldnn_nChw8c);
        CANONICALIZE_MD(mkldnn_nChw16c);
        break;
    case 5:
        CANONICALIZE_MD(mkldnn_ncdhw);
        CANONICALIZE_MD(mkldnn_ndhwc);
        CANONICALIZE_MD(mkldnn_nCdhw16c);
        break;
    default:;
    }
    // No named equivalent; keep the generic blocked descriptor.
    return out_md;
}
bool runtime::cpu::mkldnn_utils::use_mkldnn_kernel(const ngraph::Node* node) bool runtime::cpu::mkldnn_utils::use_mkldnn_kernel(const ngraph::Node* node)
...@@ -231,19 +424,26 @@ bool runtime::cpu::mkldnn_utils::use_mkldnn_kernel(const ngraph::Node* node) ...@@ -231,19 +424,26 @@ bool runtime::cpu::mkldnn_utils::use_mkldnn_kernel(const ngraph::Node* node)
->is_mkldnn_op()); ->is_mkldnn_op());
} }
bool runtime::cpu::mkldnn_utils::compare_mkldnn_formats(mkldnn::memory::format fmt1, bool runtime::cpu::mkldnn_utils::compare_mkldnn_formats(mkldnn::memory::format lhs,
mkldnn::memory::format fmt2) mkldnn::memory::format rhs)
{ {
std::set<mkldnn::memory::format> similar_4d_formats{mkldnn::memory::format::nchw, std::set<mkldnn::memory::format> similar_4d_formats{mkldnn::memory::format::nchw,
mkldnn::memory::format::oihw}; mkldnn::memory::format::oihw};
if ((fmt1 == fmt2) || (similar_4d_formats.find(fmt1) != similar_4d_formats.end() && if ((lhs == rhs) || (similar_4d_formats.find(lhs) != similar_4d_formats.end() &&
similar_4d_formats.find(fmt2) != similar_4d_formats.end())) similar_4d_formats.find(rhs) != similar_4d_formats.end()))
{ {
return true; return true;
} }
return false; return false;
} }
// Two memory descriptors are considered equal when MKLDNN produces identical
// primitive descriptors for them on the global CPU engine.
bool runtime::cpu::mkldnn_utils::compare_mkldnn_mds(const mkldnn::memory::desc& lhs,
                                                    const mkldnn::memory::desc& rhs)
{
    memory::primitive_desc lhs_pd(lhs, runtime::cpu::mkldnn_utils::global_cpu_engine);
    memory::primitive_desc rhs_pd(rhs, runtime::cpu::mkldnn_utils::global_cpu_engine);
    return lhs_pd == rhs_pd;
}
bool runtime::cpu::mkldnn_utils::is_mkldnn_filter_format(mkldnn::memory::format fmt) bool runtime::cpu::mkldnn_utils::is_mkldnn_filter_format(mkldnn::memory::format fmt)
{ {
if (s_filter_formats.find(fmt) != s_filter_formats.end()) if (s_filter_formats.find(fmt) != s_filter_formats.end())
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include <mkldnn.hpp> #include <mkldnn.hpp>
#include "ngraph/axis_vector.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/runtime/cpu/cpu_layout_descriptor.hpp" #include "ngraph/runtime/cpu/cpu_layout_descriptor.hpp"
#include "ngraph/type/element_type.hpp" #include "ngraph/type/element_type.hpp"
...@@ -40,11 +41,25 @@ namespace ngraph ...@@ -40,11 +41,25 @@ namespace ngraph
mkldnn::memory::data_type get_mkldnn_data_type(const ngraph::element::Type& type); mkldnn::memory::data_type get_mkldnn_data_type(const ngraph::element::Type& type);
const std::string& get_mkldnn_format_string(mkldnn::memory::format fmt); const std::string& get_mkldnn_format_string(mkldnn::memory::format fmt);
mkldnn::memory::format get_input_mkldnn_format(const Node* node, size_t index); const mkldnn::memory::desc& get_input_mkldnn_md(const Node* node, size_t index);
mkldnn::memory::format get_output_mkldnn_format(const Node* node, size_t index); const mkldnn::memory::desc& get_output_mkldnn_md(const Node* node, size_t index);
mkldnn::memory::desc create_default_mkldnn_md(const Node* node,
size_t index,
bool is_output,
mkldnn::memory::format format);
bool is_perm_sorted(const Strides& a, const AxisVector& perm);
bool can_create_mkldnn_md(const Shape& dims,
const Strides& strides,
const ngraph::element::Type type);
mkldnn::memory::desc create_blocked_mkldnn_md(const Shape& dims,
const Strides& strides,
const ngraph::element::Type type);
mkldnn::memory::desc rotate_blocked_md(const mkldnn::memory::desc& in,
AxisVector& axis_order);
bool use_mkldnn_kernel(const ngraph::Node* node); bool use_mkldnn_kernel(const ngraph::Node* node);
bool compare_mkldnn_formats(mkldnn::memory::format fmt1, bool compare_mkldnn_formats(mkldnn::memory::format lhs, mkldnn::memory::format rhs);
mkldnn::memory::format fmt2); bool compare_mkldnn_mds(const mkldnn::memory::desc& lhs,
const mkldnn::memory::desc& rhs);
bool is_mkldnn_filter_format(mkldnn::memory::format fmt); bool is_mkldnn_filter_format(mkldnn::memory::format fmt);
bool is_mkldnn_blocked_data_format(mkldnn::memory::format fmt); bool is_mkldnn_blocked_data_format(mkldnn::memory::format fmt);
} }
......
...@@ -97,7 +97,6 @@ op::Lstm::Lstm(std::shared_ptr<Node> input_xt_1, ...@@ -97,7 +97,6 @@ op::Lstm::Lstm(std::shared_ptr<Node> input_xt_1,
if (shape_size(input_xt_1->get_shape()) != if (shape_size(input_xt_1->get_shape()) !=
m_src_sequence_length * m_batch_size * m_src_layer_feature_size) m_src_sequence_length * m_batch_size * m_src_layer_feature_size)
{ {
std::cout << "shape_size: " << shape_size(input_xt_1->get_shape()) << std::endl;
throw ngraph_error("input_xt_1 size is not equal t*n*c"); throw ngraph_error("input_xt_1 size is not equal t*n*c");
} }
...@@ -159,7 +158,6 @@ op::Lstm::Lstm(std::shared_ptr<Node> src_layer, ...@@ -159,7 +158,6 @@ op::Lstm::Lstm(std::shared_ptr<Node> src_layer,
if (shape_size(src_layer->get_shape()) != if (shape_size(src_layer->get_shape()) !=
m_src_sequence_length * m_batch_size * m_src_layer_feature_size) m_src_sequence_length * m_batch_size * m_src_layer_feature_size)
{ {
std::cout << "shape_size: " << shape_size(src_layer->get_shape()) << std::endl;
throw ngraph_error("src_layer size is not equal t*n*c"); throw ngraph_error("src_layer size is not equal t*n*c");
} }
......
...@@ -87,7 +87,6 @@ op::Rnn::Rnn(std::shared_ptr<Node> src_layer, ...@@ -87,7 +87,6 @@ op::Rnn::Rnn(std::shared_ptr<Node> src_layer,
if (shape_size(src_layer->get_shape()) != if (shape_size(src_layer->get_shape()) !=
m_src_sequence_length * m_batch_size * m_src_layer_feature_size) m_src_sequence_length * m_batch_size * m_src_layer_feature_size)
{ {
std::cout << "shape_size: " << shape_size(src_layer->get_shape()) << std::endl;
throw ngraph_error("src_layer size is not equal t*n*c"); throw ngraph_error("src_layer size is not equal t*n*c");
} }
......
...@@ -33,7 +33,6 @@ ...@@ -33,7 +33,6 @@
#include "ngraph/op/lrn.hpp" #include "ngraph/op/lrn.hpp"
#include "ngraph/op/max_pool.hpp" #include "ngraph/op/max_pool.hpp"
#include "ngraph/op/relu.hpp" #include "ngraph/op/relu.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/softmax.hpp" #include "ngraph/op/softmax.hpp"
#include "ngraph/runtime/cpu/cpu_op_annotations.hpp" #include "ngraph/runtime/cpu/cpu_op_annotations.hpp"
#include "ngraph/runtime/cpu/mkldnn_utils.hpp" #include "ngraph/runtime/cpu/mkldnn_utils.hpp"
...@@ -546,53 +545,6 @@ namespace ngraph ...@@ -546,53 +545,6 @@ namespace ngraph
} }
} }
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::Reshape)
{
auto reshape = static_cast<op::Reshape*>(node);
auto arg0_shape = node->get_input_shape(0);
auto result_shape = node->get_output_shape(0);
auto axis_order = reshape->get_input_order();
bool flag = true;
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
auto users = reshape->get_users();
auto arg = reshape->get_argument(0);
// we need to copy input data if reshape modifies the data or inputs are
// not in the memory pool, or has output users.
bool need_copy = reshape->get_is_transpose() || arg->is_constant();
if (!need_copy)
{
// map output to the input memory
op_annotations->add_in_place_oi_pair({0, 0, false});
reshape->set_op_annotations(op_annotations);
}
// Use Eigen for 3D
if (node->get_input_element_type(0) == element::f32 &&
arg0_shape.size() < TENSOR_MAX_DIMS && arg0_shape.size() > 3 &&
arg0_shape.size() == result_shape.size())
{
for (size_t i = 0; i < axis_order.size(); i++)
{
if (arg0_shape[axis_order[i]] != result_shape[i])
{
flag = false;
break;
}
}
if (flag)
{
op_annotations->set_mkldnn_op(true);
reshape->set_op_annotations(op_annotations);
}
}
}
template <> template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::BatchNorm) void CPUAssignment::ASSIGN_DECL(ngraph::op::BatchNorm)
{ {
...@@ -759,7 +711,6 @@ static const runtime::cpu::pass::AssignOpMap s_dispatcher{ ...@@ -759,7 +711,6 @@ static const runtime::cpu::pass::AssignOpMap s_dispatcher{
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::SigmoidBackprop>}, &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::SigmoidBackprop>},
{TI(ngraph::op::Lstm), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::Lstm>}, {TI(ngraph::op::Lstm), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::Lstm>},
{TI(ngraph::op::Rnn), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::Rnn>}, {TI(ngraph::op::Rnn), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::Rnn>},
{TI(ngraph::op::Reshape), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::Reshape>},
{TI(ngraph::op::Softmax), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::Softmax>}, {TI(ngraph::op::Softmax), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::Softmax>},
}; };
......
This diff is collapsed.
...@@ -56,13 +56,13 @@ namespace ngraph ...@@ -56,13 +56,13 @@ namespace ngraph
static std::shared_ptr<Node> insert_input_conversions( static std::shared_ptr<Node> insert_input_conversions(
CPU_ExternalFunction* external_function, CPU_ExternalFunction* external_function,
std::shared_ptr<Node>& node, std::shared_ptr<Node>& node,
const std::vector<mkldnn::memory::format>& required_formats); const std::vector<mkldnn::memory::desc>& required_mds);
static void set_output_layouts( static void
std::shared_ptr<Node>& node, set_output_layouts(std::shared_ptr<Node>& node,
const std::vector<mkldnn::memory::format>& output_formats); const std::vector<mkldnn::memory::desc>& output_mds);
static void set_default_layouts(CPU_ExternalFunction* external_function, static void set_native_layouts(CPU_ExternalFunction* external_function,
std::shared_ptr<Node> node, std::shared_ptr<Node> node,
bool use_replace); bool use_replace);
}; };
} }
} }
......
...@@ -45,7 +45,7 @@ void ngraph::runtime::cpu::pass::CPUPostLayoutOptimizations::construct_weight_fu ...@@ -45,7 +45,7 @@ void ngraph::runtime::cpu::pass::CPUPostLayoutOptimizations::construct_weight_fu
std::make_shared<ngraph::op::Reshape>(param, AxisVector{0}, Shape{16, 4, 1, 1}); std::make_shared<ngraph::op::Reshape>(param, AxisVector{0}, Shape{16, 4, 1, 1});
auto data_conv = std::make_shared<pattern::op::Label>(element::f32, Shape{16, 4, 7, 7}); auto data_conv = std::make_shared<pattern::op::Label>(element::f32, Shape{16, 4, 7, 7});
auto tvt = reshape_conv->get_outputs().at(0).get_tensor_view().get(); auto tvt = reshape_conv->get_outputs().at(0).get_tensor_view().get();
auto lt_desc = std::make_shared<runtime::cpu::LayoutDescriptor>(*tvt, AxisVector{0, 1, 2, 3}); auto lt_desc = std::make_shared<runtime::cpu::LayoutDescriptor>(*tvt);
auto cvt_lt_conv = std::make_shared<runtime::cpu::op::ConvertLayout>(reshape_conv, lt_desc); auto cvt_lt_conv = std::make_shared<runtime::cpu::op::ConvertLayout>(reshape_conv, lt_desc);
auto conv = std::make_shared<ngraph::op::Convolution>( auto conv = std::make_shared<ngraph::op::Convolution>(
data_conv, cvt_lt_conv, Strides{1, 1}, Strides{1, 1}); data_conv, cvt_lt_conv, Strides{1, 1}, Strides{1, 1});
......
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <algorithm>
#include <iostream>
#include <map>
#include <memory>
#include "ngraph/op/reshape.hpp"
#include "ngraph/runtime/cpu/mkldnn_utils.hpp"
#include "ngraph/runtime/cpu/op/convert_layout.hpp"
#include "cpu_shuffle_folding.hpp"
// Maps a transpose axis order to the named MKLDNN filter format that lays the
// elements out the same way; only {3, 2, 0, 1} -> hwio is recognized.
static const std::map<const ngraph::AxisVector, const mkldnn::memory::format>
    input_order_format_map{{ngraph::AxisVector{3, 2, 0, 1}, mkldnn::memory::format::hwio}};
// Folds a shuffle (a pure-transpose Reshape) into the layout of its input:
// when such a Reshape feeds a ConvertLayout targeting a known MKLDNN filter
// format, the Reshape is removed and its input tensor's layout descriptor is
// stamped with the equivalent named MKLDNN format from input_order_format_map.
//
// Returns true when the function graph was modified.
//
// BUG FIX: `clobbered` was never set even though replace_node() mutates the
// graph, so the pass always reported "no change"; it is now set on mutation.
bool ngraph::runtime::cpu::pass::CPUShuffleFolding::run_on_function(
    std::shared_ptr<ngraph::Function> function)
{
    bool clobbered = false;
    for (const auto& n : function->get_ordered_ops())
    {
        auto convert_layout = std::dynamic_pointer_cast<op::ConvertLayout>(n);
        if (!convert_layout)
        {
            continue;
        }
        auto reshape = std::dynamic_pointer_cast<ngraph::op::Reshape>(n->get_argument(0));
        if (!reshape)
        {
            continue;
        }

        auto output_shape = reshape->get_output_shape();
        auto input_shape = reshape->get_input_shape(0);
        // A shuffle preserves rank; anything else is a genuine reshape.
        if (output_shape.size() != input_shape.size())
        {
            continue;
        }

        // The reshape is a shuffle iff the output shape equals the input
        // shape permuted by the reshape's input order.
        size_t out_axis = 0;
        bool is_shuffle = true;
        for (auto in_axis : reshape->get_input_order())
        {
            if (input_shape.at(in_axis) != output_shape.at(out_axis++))
            {
                is_shuffle = false;
                break;
            }
        }
        if (!is_shuffle)
        {
            continue;
        }

        auto reshape_input_layout =
            reshape->get_argument(0)->get_output_tensor_view()->get_tensor_view_layout();
        auto output_layout = convert_layout->get_output_tensor_view()->get_tensor_view_layout();
        if (!reshape_input_layout)
        {
            continue;
        }

        auto reshape_input_layout_descriptor =
            std::static_pointer_cast<runtime::cpu::LayoutDescriptor>(reshape_input_layout);
        auto reshape_input_format = reshape_input_layout_descriptor->get_mkldnn_format();
        auto output_format =
            std::static_pointer_cast<runtime::cpu::LayoutDescriptor>(output_layout)
                ->get_mkldnn_format();

        // Only the known OIhw16i16o-from-nchw filter case is folded today.
        if (mkldnn_utils::is_mkldnn_filter_format(output_format) &&
            output_format == mkldnn::memory::format::OIhw16i16o &&
            reshape_input_format == mkldnn::memory::format::nchw)
        {
            auto order = reshape->get_input_order();
            if (input_order_format_map.find(order) != input_order_format_map.end())
            {
                reshape_input_layout_descriptor->set_mkldnn_format(
                    input_order_format_map.at(order));
                reshape_input_layout_descriptor->set_axis_order(order);
                function->replace_node(reshape, reshape->get_argument(0));
                clobbered = true;
            }
        }
    }
    return clobbered;
}
...@@ -1063,7 +1063,7 @@ TEST(cpu_fusion, weight_fusion) ...@@ -1063,7 +1063,7 @@ TEST(cpu_fusion, weight_fusion)
std::make_shared<ngraph::op::Reshape>(param, AxisVector{0}, Shape{16, 4, 1, 1}); std::make_shared<ngraph::op::Reshape>(param, AxisVector{0}, Shape{16, 4, 1, 1});
auto data_conv = std::make_shared<op::Parameter>(element::f32, Shape{16, 4, 7, 7}); auto data_conv = std::make_shared<op::Parameter>(element::f32, Shape{16, 4, 7, 7});
auto tvt = reshape_conv->get_outputs().at(0).get_tensor_view().get(); auto tvt = reshape_conv->get_outputs().at(0).get_tensor_view().get();
auto lt_desc = std::make_shared<runtime::cpu::LayoutDescriptor>(*tvt, AxisVector{0, 1, 2, 3}); auto lt_desc = std::make_shared<runtime::cpu::LayoutDescriptor>(*tvt);
auto cvt_lt_conv = std::make_shared<runtime::cpu::op::ConvertLayout>(reshape_conv, lt_desc); auto cvt_lt_conv = std::make_shared<runtime::cpu::op::ConvertLayout>(reshape_conv, lt_desc);
auto conv = std::make_shared<ngraph::op::Convolution>( auto conv = std::make_shared<ngraph::op::Convolution>(
data_conv, cvt_lt_conv, Strides{1, 1}, Strides{1, 1}); data_conv, cvt_lt_conv, Strides{1, 1}, Strides{1, 1});
...@@ -1072,8 +1072,7 @@ TEST(cpu_fusion, weight_fusion) ...@@ -1072,8 +1072,7 @@ TEST(cpu_fusion, weight_fusion)
std::make_shared<op::Reshape>(param, AxisVector{0}, Shape{16, 4, 1, 1}); std::make_shared<op::Reshape>(param, AxisVector{0}, Shape{16, 4, 1, 1});
auto dummy_arg_conv_bprop = std::make_shared<op::Parameter>(element::f32, Shape{1, 16, 7, 7}); auto dummy_arg_conv_bprop = std::make_shared<op::Parameter>(element::f32, Shape{1, 16, 7, 7});
auto tvt_bprop = reshape_conv_bprop->get_outputs().at(0).get_tensor_view().get(); auto tvt_bprop = reshape_conv_bprop->get_outputs().at(0).get_tensor_view().get();
auto lt_desc_bprop = auto lt_desc_bprop = std::make_shared<runtime::cpu::LayoutDescriptor>(*tvt_bprop);
std::make_shared<runtime::cpu::LayoutDescriptor>(*tvt_bprop, AxisVector{0, 1, 2, 3});
auto cvt_lt_conv_bprop = auto cvt_lt_conv_bprop =
std::make_shared<runtime::cpu::op::ConvertLayout>(reshape_conv_bprop, lt_desc_bprop); std::make_shared<runtime::cpu::op::ConvertLayout>(reshape_conv_bprop, lt_desc_bprop);
auto conv_bprop = std::make_shared<op::ConvolutionBackpropData>(Shape{1, 4, 7, 7}, auto conv_bprop = std::make_shared<op::ConvolutionBackpropData>(Shape{1, 4, 7, 7},
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment