Commit 46e0dea7 authored by Jayaram Bobba's avatar Jayaram Bobba

Enable optimal layouts on MKLDNN convolution backprop ops

parent d0f8dff2
......@@ -2001,11 +2001,7 @@ namespace ngraph
auto arg1_shape = args[1].get_shape();
auto result_shape = out[0].get_shape();
auto op_annotations =
static_cast<const ngraph::op::Op*>(node)->get_op_annotations();
if (op_annotations &&
static_pointer_cast<ngraph::runtime::cpu::CPUOpAnnotations>(op_annotations)
->is_mkldnn_op())
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{
// For dilation, MKLDNN wants to know how many elements to insert between, not how far
// apart to space the elements like nGraph. So we have to subtract 1 from each pos.
......@@ -2014,22 +2010,13 @@ namespace ngraph
{
window_dilation_strides_adjusted.push_back(s - 1);
}
auto input_tvl = node->get_inputs()[0]
.get_output()
.get_tensor_view()
->get_tensor_view_layout();
auto weights_tvl = node->get_inputs()[1]
.get_output()
.get_tensor_view()
->get_tensor_view_layout();
auto output_tvl = node->get_output_tensor_view(0)->get_tensor_view_layout();
auto input_format = dynamic_cast<runtime::cpu::LayoutDescriptor&>(*input_tvl)
.get_mkldnn_format();
auto input_format =
runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 0);
auto weights_format =
dynamic_cast<runtime::cpu::LayoutDescriptor&>(*weights_tvl)
.get_mkldnn_format();
auto output_format = dynamic_cast<runtime::cpu::LayoutDescriptor&>(*output_tvl)
.get_mkldnn_format();
runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 1);
auto output_format =
runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 0);
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto input_data_desc =
......@@ -2091,17 +2078,8 @@ namespace ngraph
auto arg0_shape = args[0].get_shape();
auto arg1_shape = args[1].get_shape();
auto result_shape = out[0].get_shape();
auto arg0_rank = arg0_shape.size();
auto arg1_rank = arg1_shape.size();
bool data_dilated = false;
for (size_t s : convolution->get_data_dilation_strides_forward())
{
data_dilated = data_dilated || (s != 1);
}
if (!data_dilated && arg0_rank == 4 && arg1_rank == 4 &&
args[0].get_element_type() == element::f32)
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{
const string& elem_type =
runtime::cpu::mkldnn_utils::get_mkldnn_data_type_string(
......@@ -2112,12 +2090,19 @@ namespace ngraph
{
window_dilation_strides_adjusted.push_back(s - 1);
}
auto data_format = runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 0);
auto delta_format =
runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 1);
auto result_format =
runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 0);
auto emit_memory_desc = [&writer](const std::string& var,
const std::string& shape,
const std::string& type,
const std::string& layout) {
writer << "memory::desc " << var << " = memory::desc({" << shape << "}, "
<< type << ", memory::format::" << layout << ");\n";
<< type << ", " << layout << ");\n";
};
auto emit_memory = [&writer](
......@@ -2135,9 +2120,21 @@ namespace ngraph
writer << "try\n";
writer.block_begin();
writer << "engine cpu_engine = engine(engine::cpu, 0);\n";
emit_memory_desc("data_desc", join(arg0_shape), elem_type, "nchw");
emit_memory_desc("delta_desc", join(arg1_shape), elem_type, "nchw");
emit_memory_desc("result_desc", join(result_shape), elem_type, "oihw");
emit_memory_desc(
"data_desc",
join(arg0_shape),
elem_type,
runtime::cpu::mkldnn_utils::get_mkldnn_format_string(data_format));
emit_memory_desc(
"delta_desc",
join(arg1_shape),
elem_type,
runtime::cpu::mkldnn_utils::get_mkldnn_format_string(delta_format));
emit_memory_desc(
"result_desc",
join(result_shape),
elem_type,
runtime::cpu::mkldnn_utils::get_mkldnn_format_string(result_format));
emit_memory("data", "data_desc", args[0].get_name());
emit_memory("delta", "delta_desc", args[1].get_name());
emit_memory("result", "result_desc", out[0].get_name());
......@@ -2202,17 +2199,8 @@ namespace ngraph
auto arg0_shape = args[0].get_shape();
auto arg1_shape = args[1].get_shape();
auto result_shape = out[0].get_shape();
auto arg0_rank = arg0_shape.size();
auto arg1_rank = arg1_shape.size();
bool data_dilated = false;
for (size_t s : convolution->get_data_dilation_strides_forward())
{
data_dilated = data_dilated || (s != 1);
}
if (!data_dilated && arg0_rank == 4 && arg1_rank == 4 &&
args[0].get_element_type() == element::f32)
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{
const string& elem_type =
runtime::cpu::mkldnn_utils::get_mkldnn_data_type_string(
......@@ -2224,12 +2212,19 @@ namespace ngraph
window_dilation_strides_adjusted.push_back(s - 1);
}
auto weight_format =
runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 0);
auto delta_format =
runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 1);
auto result_format =
runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 0);
auto emit_memory_desc = [&writer](const std::string& var,
const std::string& shape,
const std::string& type,
const std::string& layout) {
writer << "memory::desc " << var << " = memory::desc({" << shape << "}, "
<< type << ", memory::format::" << layout << ");\n";
<< type << ", " << layout << ");\n";
};
auto emit_memory = [&writer](
......@@ -2247,9 +2242,21 @@ namespace ngraph
writer << "try\n";
writer.block_begin();
writer << "engine cpu_engine = engine(engine::cpu, 0);\n";
emit_memory_desc("weight_desc", join(arg0_shape), elem_type, "oihw");
emit_memory_desc("delta_desc", join(arg1_shape), elem_type, "nchw");
emit_memory_desc("result_desc", join(result_shape), elem_type, "nchw");
emit_memory_desc(
"weight_desc",
join(arg0_shape),
elem_type,
runtime::cpu::mkldnn_utils::get_mkldnn_format_string(weight_format));
emit_memory_desc(
"delta_desc",
join(arg1_shape),
elem_type,
runtime::cpu::mkldnn_utils::get_mkldnn_format_string(delta_format));
emit_memory_desc(
"result_desc",
join(result_shape),
elem_type,
runtime::cpu::mkldnn_utils::get_mkldnn_format_string(result_format));
emit_memory("weight", "weight_desc", args[0].get_name());
emit_memory("delta", "delta_desc", args[1].get_name());
emit_memory("result", "result_desc", out[0].get_name());
......
......@@ -107,8 +107,9 @@ void runtime::cpu::CPUTensorView::read(void* target, size_t tensor_offset, size_
auto tvl = this->get_tensor_view_layout();
auto cpu_tvl = dynamic_cast<runtime::cpu::LayoutDescriptor*>(tvl.get());
if (cpu_tvl && cpu_tvl->get_mkldnn_format() != memory::format::format_undef &&
cpu_tvl->get_mkldnn_format() !=
runtime::cpu::mkldnn_utils::CreateNativeDataFormat(*cpu_tvl))
!runtime::cpu::mkldnn_utils::compare_mkldnn_formats(
cpu_tvl->get_mkldnn_format(),
runtime::cpu::mkldnn_utils::CreateNativeDataFormat(*cpu_tvl)))
{
auto tensor_shape = this->get_shape();
auto input_format = cpu_tvl->get_mkldnn_format();
......
......@@ -19,18 +19,21 @@
#include <typeinfo>
#include <unordered_set>
#include "ngraph/types/element_type.hpp"
#include "ngraph/node.hpp"
#include "ngraph/ops/avg_pool.hpp"
#include "ngraph/ops/batch_norm.hpp"
#include "ngraph/ops/convolution.hpp"
#include "ngraph/ops/max_pool.hpp"
#include "ngraph/ops/relu.hpp"
#include "ngraph/runtime/cpu/cpu_layout_descriptor.hpp"
#include "ngraph/runtime/cpu/cpu_op_annotations.hpp"
#include "ngraph/types/element_type.hpp"
#include "mkldnn_utils.hpp"
using namespace mkldnn;
using namespace ngraph;
using namespace std;
#define TI(x) std::type_index(typeid(x))
......@@ -120,7 +123,8 @@ mkldnn::memory::format runtime::cpu::mkldnn_utils::CreateNativeDataFormat(
}
}
const std::string& runtime::cpu::mkldnn_utils::get_mkldnn_data_type_string(const ngraph::element::Type& type)
const std::string&
runtime::cpu::mkldnn_utils::get_mkldnn_data_type_string(const ngraph::element::Type& type)
{
auto it = s_mkldnn_data_type_string_map.find(type);
if (it == s_mkldnn_data_type_string_map.end() || it->second.empty())
......@@ -128,7 +132,8 @@ const std::string& runtime::cpu::mkldnn_utils::get_mkldnn_data_type_string(const
return it->second;
}
mkldnn::memory::data_type runtime::cpu::mkldnn_utils::get_mkldnn_data_type(const ngraph::element::Type& type)
mkldnn::memory::data_type
runtime::cpu::mkldnn_utils::get_mkldnn_data_type(const ngraph::element::Type& type)
{
auto it = s_mkldnn_data_type_map.find(type);
if (it == s_mkldnn_data_type_map.end() || it->second == memory::data_type::data_undef)
......@@ -146,3 +151,38 @@ const std::string& runtime::cpu::mkldnn_utils::get_mkldnn_format_string(memory::
std::to_string(fmt));
return it->second;
}
mkldnn::memory::format runtime::cpu::mkldnn_utils::get_input_mkldnn_format(const Node* node,
int index)
{
auto tvl = node->get_inputs()[index].get_output().get_tensor_view()->get_tensor_view_layout();
return dynamic_cast<runtime::cpu::LayoutDescriptor&>(*tvl).get_mkldnn_format();
}
mkldnn::memory::format runtime::cpu::mkldnn_utils::get_output_mkldnn_format(const Node* node,
int index)
{
auto tvl = node->get_output_tensor_view(0)->get_tensor_view_layout();
return dynamic_cast<runtime::cpu::LayoutDescriptor&>(*tvl).get_mkldnn_format();
}
bool runtime::cpu::mkldnn_utils::use_mkldnn_kernel(const ngraph::Node* node)
{
auto op_annotations = static_cast<const ngraph::op::Op*>(node)->get_op_annotations();
return (op_annotations &&
static_pointer_cast<ngraph::runtime::cpu::CPUOpAnnotations>(op_annotations)
->is_mkldnn_op());
}
bool runtime::cpu::mkldnn_utils::compare_mkldnn_formats(mkldnn::memory::format fmt1,
mkldnn::memory::format fmt2)
{
set<mkldnn::memory::format> similar_4d_formats{mkldnn::memory::format::nchw,
mkldnn::memory::format::oihw};
if ((fmt1 == fmt2) || (similar_4d_formats.find(fmt1) != similar_4d_formats.end() &&
similar_4d_formats.find(fmt2) != similar_4d_formats.end()))
{
return true;
}
return false;
}
\ No newline at end of file
......@@ -38,6 +38,12 @@ namespace ngraph
const std::string& get_mkldnn_data_type_string(const ngraph::element::Type& type);
mkldnn::memory::data_type get_mkldnn_data_type(const ngraph::element::Type& type);
const std::string& get_mkldnn_format_string(mkldnn::memory::format fmt);
mkldnn::memory::format get_input_mkldnn_format(const Node* node, int index);
mkldnn::memory::format get_output_mkldnn_format(const Node* node, int index);
bool use_mkldnn_kernel(const ngraph::Node* node);
bool compare_mkldnn_formats(mkldnn::memory::format fmt1,
mkldnn::memory::format fmt2);
}
}
}
......
......@@ -66,6 +66,60 @@ namespace ngraph
convolution->set_op_annotations(op_annotations);
}
}
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::ConvolutionBackpropData)
{
auto convolution = static_cast<op::ConvolutionBackpropData*>(node);
auto arg0_shape = node->get_input_shape(0);
auto arg1_shape = node->get_input_shape(1);
auto result_shape = node->get_output_shape(0);
auto arg0_rank = arg0_shape.size();
auto arg1_rank = arg1_shape.size();
bool data_dilated = false;
for (size_t s : convolution->get_data_dilation_strides_forward())
{
data_dilated = data_dilated || (s != 1);
}
if (!data_dilated && arg0_rank == 4 && arg1_rank == 4 &&
node->get_input_element_type(0) == element::f32)
{
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
convolution->set_op_annotations(op_annotations);
}
}
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::ConvolutionBackpropFilters)
{
auto convolution = static_cast<op::ConvolutionBackpropFilters*>(node);
auto arg0_shape = node->get_input_shape(0);
auto arg1_shape = node->get_input_shape(1);
auto result_shape = node->get_output_shape(0);
auto arg0_rank = arg0_shape.size();
auto arg1_rank = arg1_shape.size();
bool data_dilated = false;
for (size_t s : convolution->get_data_dilation_strides_forward())
{
data_dilated = data_dilated || (s != 1);
}
if (!data_dilated && arg0_rank == 4 && arg1_rank == 4 &&
node->get_input_element_type(0) == element::f32)
{
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
convolution->set_op_annotations(op_annotations);
}
}
}
}
}
......@@ -76,6 +130,10 @@ namespace ngraph
static const runtime::cpu::pass::AssignOpMap s_dispatcher{
{TI(ngraph::op::Convolution),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::Convolution>},
{TI(ngraph::op::ConvolutionBackpropData),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::ConvolutionBackpropData>},
{TI(ngraph::op::ConvolutionBackpropFilters),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::ConvolutionBackpropFilters>},
};
bool runtime::cpu::pass::CPUAssignment::run_on_call_graph(
......
......@@ -36,6 +36,95 @@ using namespace std;
using namespace mkldnn;
using namespace ngraph;
shared_ptr<Node> runtime::cpu::pass::CPULayout::insert_input_conversions(
runtime::cpu::CPU_ExternalFunction* external_function,
shared_ptr<Node>& node,
const vector<memory::format>& required_formats)
{
vector<shared_ptr<Node>> new_args;
bool replace_node = false;
uint index = 0;
for (const descriptor::Input& input : node->get_inputs())
{
const auto& output = input.get_output();
auto tv = output.get_tensor_view();
auto tvt = tv->get_tensor_view_type();
auto rank = tvt->get_shape().size();
auto tvl = tv->get_tensor_view_layout();
auto mkldnn_tvl = dynamic_cast<runtime::cpu::LayoutDescriptor*>(tvl.get());
if (!mkldnn_tvl ||
!runtime::cpu::mkldnn_utils::compare_mkldnn_formats(mkldnn_tvl->get_mkldnn_format(),
required_formats[index]))
{
auto native_axis_order =
ngraph::runtime::cpu::LayoutDescriptor::create_native_axis_order(rank);
auto layout =
std::make_shared<ngraph::runtime::cpu::LayoutDescriptor>(*tv, native_axis_order);
layout->set_mkldnn_format(required_formats[index]);
auto new_node = std::shared_ptr<Node>(
new runtime::cpu::op::ConvertLayout(output.get_node(), output.get_index(), layout));
new_args.push_back(new_node);
replace_node = true;
NGRAPH_DEBUG << "Inserted conversion node " << new_node->get_name() << " between "
<< output.get_node()->get_name()
<< "(layout: " << mkldnn_tvl->get_mkldnn_format() << ") and "
<< node->get_name() << "(layout: " << required_formats[index] << ")";
}
else
{
new_args.push_back(node->get_input_op(index));
}
index++;
}
shared_ptr<Node> new_node;
if (replace_node)
{
new_node = node->copy_with_new_args(new_args);
if (node->is_output())
{
external_function->get_function()->replace_node(node, new_node);
}
else
{
ngraph::replace_node(node, new_node);
}
NGRAPH_DEBUG << "Replaced " << node->get_name() << " with " << new_node->get_name();
auto old_op_annotations = static_pointer_cast<ngraph::op::Op>(node)->get_op_annotations();
static_pointer_cast<ngraph::op::Op>(new_node)->set_op_annotations(old_op_annotations);
node = new_node;
}
return node;
}
void runtime::cpu::pass::CPULayout::set_output_layouts(shared_ptr<Node>& node,
const vector<memory::format>& output_formats)
{
for (size_t i = 0; i < node->get_output_size(); ++i)
{
auto tv = node->get_output_tensor_view(i);
auto tvt = tv->get_tensor_view_type();
auto rank = tvt->get_shape().size();
auto tvl = tv->get_tensor_view_layout();
if (tvl)
{
throw ngraph_error("Node output layout already set");
}
auto native_axis_order =
ngraph::runtime::cpu::LayoutDescriptor::create_native_axis_order(rank);
auto layout =
std::make_shared<ngraph::runtime::cpu::LayoutDescriptor>(*tv, native_axis_order);
layout->set_mkldnn_format(output_formats[i]);
tv->set_tensor_view_layout(layout);
NGRAPH_DEBUG << "Setting Node: " << node->get_name()
<< " output layout: " << output_formats[i] << endl;
}
}
void runtime::cpu::pass::CPULayout::set_default_layouts(
runtime::cpu::CPU_ExternalFunction* external_function, std::shared_ptr<Node> node)
{
......@@ -51,8 +140,9 @@ void runtime::cpu::pass::CPULayout::set_default_layouts(
auto tvl = tv->get_tensor_view_layout();
auto cpu_tvl = dynamic_cast<runtime::cpu::LayoutDescriptor*>(tvl.get());
if (cpu_tvl && cpu_tvl->get_mkldnn_format() != memory::format::format_undef &&
cpu_tvl->get_mkldnn_format() !=
runtime::cpu::mkldnn_utils::CreateNativeDataFormat(*cpu_tvl))
!runtime::cpu::mkldnn_utils::compare_mkldnn_formats(
cpu_tvl->get_mkldnn_format(),
runtime::cpu::mkldnn_utils::CreateNativeDataFormat(*cpu_tvl)))
{
auto native_axis_order =
ngraph::runtime::cpu::LayoutDescriptor::create_native_axis_order(rank);
......@@ -127,11 +217,7 @@ namespace ngraph
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::Convolution)
{
auto op_annotations =
static_pointer_cast<ngraph::op::Op>(node)->get_op_annotations();
if (op_annotations &&
static_pointer_cast<ngraph::runtime::cpu::CPUOpAnnotations>(op_annotations)
->is_mkldnn_op())
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node.get()))
{
auto convolution = static_cast<const ngraph::op::Convolution*>(node.get());
......@@ -181,100 +267,194 @@ namespace ngraph
mkldnn_padding_above,
padding_kind::zero);
convolution_forward::primitive_desc prim_desc(fwd_desc, cpu_engine);
memory::format prim_input_formats[2];
memory::format prim_output_formats[1];
prim_input_formats[0] = static_cast<memory::format>(
prim_desc.src_primitive_desc().desc().data.format);
prim_output_formats[0] = static_cast<memory::format>(
prim_desc.dst_primitive_desc().desc().data.format);
prim_input_formats[1] = static_cast<memory::format>(
prim_desc.weights_primitive_desc().desc().data.format);
std::vector<shared_ptr<Node>> new_args;
bool replace_node = false;
uint index = 0;
for (const descriptor::Input& input : node->get_inputs())
{
const auto& output = input.get_output();
auto tv = output.get_tensor_view();
auto tvt = tv->get_tensor_view_type();
auto rank = tvt->get_shape().size();
auto tvl = tv->get_tensor_view_layout();
auto mkldnn_tvl =
dynamic_cast<runtime::cpu::LayoutDescriptor*>(tvl.get());
if (!mkldnn_tvl ||
mkldnn_tvl->get_mkldnn_format() != prim_input_formats[index])
{
auto native_axis_order = ngraph::runtime::cpu::LayoutDescriptor::
create_native_axis_order(rank);
auto layout =
std::make_shared<ngraph::runtime::cpu::LayoutDescriptor>(
*tv, native_axis_order);
layout->set_mkldnn_format(prim_input_formats[index]);
auto new_node =
std::shared_ptr<Node>(new runtime::cpu::op::ConvertLayout(
output.get_node(), output.get_index(), layout));
new_args.push_back(new_node);
replace_node = true;
NGRAPH_DEBUG << "Inserted conversion node " << new_node->get_name()
<< " between " << output.get_node()->get_name()
<< "(layout: " << mkldnn_tvl->get_mkldnn_format()
<< ") and " << node->get_name()
<< "(layout: " << prim_input_formats[index] << ")";
}
else
{
new_args.push_back(node->get_input_op(index));
}
index++;
}
vector<memory::format> prim_input_formats;
vector<memory::format> prim_output_formats;
prim_input_formats.push_back(static_cast<memory::format>(
prim_desc.src_primitive_desc().desc().data.format));
prim_input_formats.push_back(static_cast<memory::format>(
prim_desc.weights_primitive_desc().desc().data.format));
prim_output_formats.push_back(static_cast<memory::format>(
prim_desc.dst_primitive_desc().desc().data.format));
shared_ptr<Node> new_node;
if (replace_node)
node =
insert_input_conversions(external_function, node, prim_input_formats);
set_output_layouts(node, prim_output_formats);
}
else
{
set_default_layouts(external_function, node);
}
}
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::ConvolutionBackpropData)
{
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node.get()))
{
auto convolution =
static_cast<const ngraph::op::ConvolutionBackpropData*>(node.get());
auto arg0_shape = node->get_input_shape(0);
auto arg1_shape = node->get_input_shape(1);
auto result_shape = node->get_output_shape(0);
auto filter_strides = convolution->get_window_movement_strides_forward();
auto padding_below = convolution->get_padding_below_forward();
auto padding_above = convolution->get_padding_above_forward();
Strides window_dilation_strides_adjusted;
for (size_t s : convolution->get_window_dilation_strides_forward())
{
new_node = node->copy_with_new_args(new_args);
if (node->is_output())
{
external_function->get_function()->replace_node(node, new_node);
}
else
{
ngraph::replace_node(node, new_node);
}
NGRAPH_DEBUG << "Replaced " << node->get_name() << " with "
<< new_node->get_name();
auto old_op_annotations =
static_pointer_cast<ngraph::op::Op>(node)->get_op_annotations();
static_pointer_cast<ngraph::op::Op>(new_node)->set_op_annotations(
old_op_annotations);
node = new_node;
window_dilation_strides_adjusted.push_back(s - 1);
}
// Set convolution output format
for (size_t i = 0; i < node->get_output_size(); ++i)
memory::data_type et = runtime::cpu::mkldnn_utils::get_mkldnn_data_type(
node->get_input_element_type(0));
engine cpu_engine(engine::cpu, 0);
memory::dims mkldnn_arg0_shape(arg0_shape.begin(), arg0_shape.end());
memory::dims mkldnn_arg1_shape(arg1_shape.begin(), arg1_shape.end());
memory::dims mkldnn_result_shape(result_shape.begin(), result_shape.end());
memory::dims mkldnn_filter_strides(filter_strides.begin(),
filter_strides.end());
memory::dims mkldnn_dilated_strides(
window_dilation_strides_adjusted.begin(),
window_dilation_strides_adjusted.end());
memory::dims mkldnn_padding_below(padding_below.begin(),
padding_below.end());
memory::dims mkldnn_padding_above(padding_above.begin(),
padding_above.end());
const memory::desc weights_desc(mkldnn_arg0_shape, et, memory::format::any);
const memory::desc delta_desc(mkldnn_arg1_shape, et, memory::format::any);
const memory::desc result_desc(
mkldnn_result_shape, et, memory::format::any);
convolution_backward_data::desc bwd_desc(algorithm::convolution_direct,
result_desc,
weights_desc,
delta_desc,
mkldnn_filter_strides,
mkldnn_dilated_strides,
mkldnn_padding_below,
mkldnn_padding_above,
padding_kind::zero);
convolution_forward::desc fwd_desc(prop_kind::forward,
algorithm::convolution_direct,
result_desc,
weights_desc,
delta_desc,
mkldnn_filter_strides,
mkldnn_dilated_strides,
mkldnn_padding_below,
mkldnn_padding_above,
padding_kind::zero);
convolution_forward::primitive_desc fwd_prim_desc(fwd_desc, cpu_engine);
convolution_backward_data::primitive_desc prim_desc(
bwd_desc, cpu_engine, fwd_prim_desc);
vector<memory::format> prim_input_formats;
vector<memory::format> prim_output_formats;
prim_input_formats.push_back(static_cast<memory::format>(
prim_desc.weights_primitive_desc().desc().data.format));
prim_input_formats.push_back(static_cast<memory::format>(
prim_desc.diff_dst_primitive_desc().desc().data.format));
prim_output_formats.push_back(static_cast<memory::format>(
prim_desc.diff_src_primitive_desc().desc().data.format));
node =
insert_input_conversions(external_function, node, prim_input_formats);
set_output_layouts(node, prim_output_formats);
}
else
{
set_default_layouts(external_function, node);
}
}
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::ConvolutionBackpropFilters)
{
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node.get()))
{
auto convolution =
static_cast<const ngraph::op::ConvolutionBackpropFilters*>(node.get());
auto arg0_shape = node->get_input_shape(0);
auto arg1_shape = node->get_input_shape(1);
auto result_shape = node->get_output_shape(0);
auto filter_strides = convolution->get_window_movement_strides_forward();
auto padding_below = convolution->get_padding_below_forward();
auto padding_above = convolution->get_padding_above_forward();
Strides window_dilation_strides_adjusted;
for (size_t s : convolution->get_window_dilation_strides_forward())
{
auto tv = node->get_output_tensor_view(i);
auto tvt = tv->get_tensor_view_type();
auto rank = tvt->get_shape().size();
auto tvl = tv->get_tensor_view_layout();
if (tvl)
{
throw ngraph_error("Convolution output layout already set");
}
auto native_axis_order =
ngraph::runtime::cpu::LayoutDescriptor::create_native_axis_order(
rank);
auto layout = std::make_shared<ngraph::runtime::cpu::LayoutDescriptor>(
*tv, native_axis_order);
layout->set_mkldnn_format(prim_output_formats[i]);
tv->set_tensor_view_layout(layout);
NGRAPH_DEBUG << "Setting Node: " << node->get_name()
<< " output layout: " << prim_output_formats[i] << endl;
window_dilation_strides_adjusted.push_back(s - 1);
}
memory::data_type et = runtime::cpu::mkldnn_utils::get_mkldnn_data_type(
node->get_input_element_type(0));
engine cpu_engine(engine::cpu, 0);
memory::dims mkldnn_arg0_shape(arg0_shape.begin(), arg0_shape.end());
memory::dims mkldnn_arg1_shape(arg1_shape.begin(), arg1_shape.end());
memory::dims mkldnn_result_shape(result_shape.begin(), result_shape.end());
memory::dims mkldnn_filter_strides(filter_strides.begin(),
filter_strides.end());
memory::dims mkldnn_dilated_strides(
window_dilation_strides_adjusted.begin(),
window_dilation_strides_adjusted.end());
memory::dims mkldnn_padding_below(padding_below.begin(),
padding_below.end());
memory::dims mkldnn_padding_above(padding_above.begin(),
padding_above.end());
const memory::desc data_desc(mkldnn_arg0_shape, et, memory::format::any);
const memory::desc delta_desc(mkldnn_arg1_shape, et, memory::format::any);
const memory::desc result_desc(
mkldnn_result_shape, et, memory::format::any);
convolution_backward_weights::desc bwd_desc(algorithm::convolution_direct,
data_desc,
result_desc,
delta_desc,
mkldnn_filter_strides,
mkldnn_dilated_strides,
mkldnn_padding_below,
mkldnn_padding_above,
padding_kind::zero);
convolution_forward::desc fwd_desc(prop_kind::forward,
algorithm::convolution_direct,
data_desc,
result_desc,
delta_desc,
mkldnn_filter_strides,
mkldnn_dilated_strides,
mkldnn_padding_below,
mkldnn_padding_above,
padding_kind::zero);
convolution_forward::primitive_desc fwd_prim_desc(fwd_desc, cpu_engine);
convolution_backward_weights::primitive_desc prim_desc(
bwd_desc, cpu_engine, fwd_prim_desc);
vector<memory::format> prim_input_formats;
vector<memory::format> prim_output_formats;
prim_input_formats.push_back(static_cast<memory::format>(
prim_desc.src_primitive_desc().desc().data.format));
prim_input_formats.push_back(static_cast<memory::format>(
prim_desc.diff_dst_primitive_desc().desc().data.format));
prim_output_formats.push_back(static_cast<memory::format>(
prim_desc.diff_weights_primitive_desc().desc().data.format));
node =
insert_input_conversions(external_function, node, prim_input_formats);
set_output_layouts(node, prim_output_formats);
}
else
{
......@@ -290,6 +470,10 @@ namespace ngraph
static const runtime::cpu::pass::LayoutOpMap s_dispatcher{
{TI(ngraph::op::Convolution), &runtime::cpu::pass::CPULayout::layout<ngraph::op::Convolution>},
{TI(ngraph::op::ConvolutionBackpropData),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::ConvolutionBackpropData>},
{TI(ngraph::op::ConvolutionBackpropFilters),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::ConvolutionBackpropFilters>},
};
bool runtime::cpu::pass::CPULayout::run_on_call_graph(const std::list<std::shared_ptr<Node>>& nodes)
......
......@@ -53,6 +53,13 @@ namespace ngraph
private:
std::shared_ptr<CPU_ExternalFunction> m_external_function;
static std::shared_ptr<Node> insert_input_conversions(
CPU_ExternalFunction* external_function,
std::shared_ptr<Node>& node,
const std::vector<mkldnn::memory::format>& required_formats);
static void set_output_layouts(
std::shared_ptr<Node>& node,
const std::vector<mkldnn::memory::format>& output_formats);
static void set_default_layouts(CPU_ExternalFunction* external_function,
std::shared_ptr<Node> node);
};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment