Commit b04f3c36 authored by Jayaram Bobba's avatar Jayaram Bobba Committed by Scott Cyphers

Moved mkldnn conv availability checks to utils and use it across passes (#1984)

* Moved mkldnn conv availability checks to utils and use it across passes

* Style fix
parent 28002287
......@@ -77,6 +77,41 @@ namespace ngraph
std::map<element::Type, const std::string>& get_mkldnn_data_type_string_map();
std::map<mkldnn::memory::format, const std::string>& get_mkldnn_format_string_map();
std::set<mkldnn::memory::format>& get_filter_formats();
template <typename T>
bool can_use_mkldnn_conv(ngraph::Node* node)
{
auto convolution = static_cast<const T*>(node);
auto arg0_rank = node->get_input_shape(0).size();
for (size_t s : convolution->get_data_dilation_strides())
{
if (s != 1)
return false;
}
if (arg0_rank != 4 && arg0_rank != 5)
{
return false;
}
if (node->get_input_element_type(0) != element::f32)
{
return false;
}
// Temporarily disable MKLDNN for large paddings due to
// a bug in v0.16 - MKFDNN-982
for (auto s : convolution->get_padding_below())
{
if (s >= 7)
return false;
}
for (auto s : convolution->get_padding_above())
{
if (s >= 7)
return false;
}
return true;
}
}
}
}
......
......@@ -125,47 +125,12 @@ namespace ngraph
}
}
template <typename T>
bool can_use_mkldnn_conv(ngraph::Node* node)
{
auto convolution = static_cast<const T*>(node);
auto arg0_rank = node->get_input_shape(0).size();
for (size_t s : convolution->get_data_dilation_strides())
{
if (s != 1)
return false;
}
if (arg0_rank != 4 && arg0_rank != 5)
{
return false;
}
if (node->get_input_element_type(0) != element::f32)
{
return false;
}
// Temporarily disable MKLDNN for large paddings due to
// a bug in v0.16 - MKFDNN-982
for (auto s : convolution->get_padding_below())
{
if (s >= 7)
return false;
}
for (auto s : convolution->get_padding_above())
{
if (s >= 7)
return false;
}
return true;
}
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::Convolution)
{
auto convolution = static_cast<op::Convolution*>(node);
if (can_use_mkldnn_conv<ngraph::op::Convolution>(node))
if (mkldnn_utils::can_use_mkldnn_conv<ngraph::op::Convolution>(node))
{
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
......@@ -179,7 +144,7 @@ namespace ngraph
{
auto convolution = static_cast<op::GroupConvolution*>(node);
if (can_use_mkldnn_conv<ngraph::op::GroupConvolution>(node))
if (mkldnn_utils::can_use_mkldnn_conv<ngraph::op::GroupConvolution>(node))
{
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
......@@ -193,7 +158,7 @@ namespace ngraph
{
auto convolution = static_cast<op::GroupConvolutionBias*>(node);
if (can_use_mkldnn_conv<ngraph::op::GroupConvolutionBias>(node))
if (mkldnn_utils::can_use_mkldnn_conv<ngraph::op::GroupConvolutionBias>(node))
{
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
......@@ -207,7 +172,7 @@ namespace ngraph
{
auto convolution = static_cast<op::ConvolutionRelu*>(node);
if (can_use_mkldnn_conv<ngraph::op::ConvolutionRelu>(node))
if (mkldnn_utils::can_use_mkldnn_conv<ngraph::op::ConvolutionRelu>(node))
{
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
......@@ -221,7 +186,7 @@ namespace ngraph
{
auto convolution = static_cast<op::ConvolutionBiasAdd*>(node);
if (can_use_mkldnn_conv<ngraph::op::ConvolutionBiasAdd>(node))
if (mkldnn_utils::can_use_mkldnn_conv<ngraph::op::ConvolutionBiasAdd>(node))
{
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
......@@ -238,7 +203,7 @@ namespace ngraph
{
auto convolution = static_cast<op::ConvolutionAdd*>(node);
if (can_use_mkldnn_conv<ngraph::op::ConvolutionAdd>(node))
if (mkldnn_utils::can_use_mkldnn_conv<ngraph::op::ConvolutionAdd>(node))
{
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
......@@ -335,20 +300,7 @@ namespace ngraph
{
auto convolution = static_cast<op::ConvolutionBias*>(node);
auto data_shape = node->get_input_shape(0);
auto weights_shape = node->get_input_shape(1);
auto result_shape = node->get_output_shape(0);
auto data_rank = data_shape.size();
auto weights_rank = weights_shape.size();
bool data_dilated = false;
for (size_t s : convolution->get_data_dilation_strides())
{
data_dilated = data_dilated || (s != 1);
}
if (!data_dilated && data_rank == 4 && weights_rank == 4 &&
node->get_input_element_type(0) == element::f32)
if (mkldnn_utils::can_use_mkldnn_conv<ngraph::op::ConvolutionBias>(node))
{
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
......
......@@ -52,6 +52,7 @@
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/label.hpp"
#include "ngraph/pattern/op/skip.hpp"
#include "ngraph/runtime/cpu/mkldnn_utils.hpp"
#include "ngraph/runtime/cpu/op/batch_norm_relu.hpp"
#include "ngraph/runtime/cpu/op/bounded_relu.hpp"
#include "ngraph/runtime/cpu/op/conv_add.hpp"
......@@ -639,32 +640,31 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_conv_bias()
auto pattern_map = m.get_pattern_map();
auto conv = std::static_pointer_cast<op::Convolution>(m.get_match_root()->get_argument(0));
if (conv->get_input_shape(0).size() == 4)
if (!runtime::cpu::mkldnn_utils::can_use_mkldnn_conv<op::Convolution>(conv.get()))
{
NGRAPH_DEBUG << "Convolution not supported by MKLDNN";
return false;
}
auto bias = m.get_match_root()->get_argument(1)->get_argument(0);
auto bias_shape = bias->get_shape();
if (bias_shape.size() > 1)
{
NGRAPH_DEBUG
<< "mpattern = " << m.get_match_root()->get_name()
NGRAPH_DEBUG << "mpattern = " << m.get_match_root()->get_name()
<< "conv_bias bias shape != 1, requires reshape to match filter count.";
auto order = ngraph::get_default_order(bias_shape);
auto bias_reshape =
std::make_shared<op::Reshape>(bias, order, Shape{conv->get_input_shape(1)[0]});
auto conv_bias = std::shared_ptr<Node>(new op::ConvolutionBias(conv, bias_reshape));
ngraph::replace_node(m.get_match_root(), conv_bias);
return true;
}
else
{
auto conv_bias = std::shared_ptr<Node>(new op::ConvolutionBias(conv, bias));
ngraph::replace_node(m.get_match_root(), conv_bias);
return true;
}
}
NGRAPH_DEBUG << "mpattern = " << m.get_match_root()->get_name()
<< "conv_bias fusion skipped due to input rank size != 4.";
return false;
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(p_conv_bias, callback);
......@@ -910,31 +910,9 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_conv_relu()
auto conv = std::static_pointer_cast<op::Convolution>(m.get_match_root()->get_argument(0));
// These checks are to make sure a MKLDNN Convolution kernel can be used.
bool data_dilated = false;
for (size_t s : conv->get_data_dilation_strides())
if (!runtime::cpu::mkldnn_utils::can_use_mkldnn_conv<op::Convolution>(conv.get()))
{
data_dilated = data_dilated || (s != 1);
}
if (data_dilated)
{
NGRAPH_DEBUG << "Convolution has dilations greater than 1";
return false;
}
if (conv->get_element_type() != element::f32)
{
NGRAPH_DEBUG << "Convolution isn't of type float";
return false;
}
auto arg0_rank = conv->get_input_shape(0).size();
auto arg1_rank = conv->get_input_shape(1).size();
if (arg0_rank != 4 || arg1_rank != 4)
{
NGRAPH_DEBUG << "Convolution's arguments ranks aren't equal to 4";
NGRAPH_DEBUG << "Convolution not supported by MKLDNN";
return false;
}
......@@ -978,40 +956,14 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_conv_bias_relu()
auto conv =
std::static_pointer_cast<op::ConvolutionBias>(m.get_match_root()->get_argument(0));
// These checks are to make sure a MKLDNN Convolution kernel can be used.
bool data_dilated = false;
for (size_t s : conv->get_data_dilation_strides())
{
data_dilated = data_dilated || (s != 1);
}
if (data_dilated)
{
NGRAPH_DEBUG << "Convolution has dilations greater than 1";
return false;
}
if (conv->get_element_type() != element::f32)
{
NGRAPH_DEBUG << "Convolution isn't of type float";
return false;
}
auto arg0_rank = conv->get_input_shape(0).size();
auto arg1_rank = conv->get_input_shape(1).size();
if (arg0_rank != 4 || arg1_rank != 4)
{
NGRAPH_DEBUG << "Convolution's arguments ranks aren't equal to 4";
return false;
}
if (conv->get_users().size() > 1)
{
NGRAPH_DEBUG << "Convolution has more than one user";
return false;
}
// ConvolutionBias created only if it can run with MKLDNN.
// No further checks needed.
auto conv_relu = std::make_shared<op::ConvolutionBias>(conv->get_argument(0),
conv->get_argument(1),
conv->get_argument(2),
......@@ -1060,31 +1012,9 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_conv_add()
inplace_input = add_m->get_argument(1);
}
//These checks are to make sure a MKLDNN Convolution kernel can be used.
bool data_dilated = false;
for (size_t s : conv_m->get_data_dilation_strides())
if (!runtime::cpu::mkldnn_utils::can_use_mkldnn_conv<op::Convolution>(conv_m.get()))
{
data_dilated = data_dilated || (s != 1);
}
if (data_dilated)
{
NGRAPH_DEBUG << "Convolution has dilations greater than 1";
return false;
}
if (conv_m->get_element_type() != element::f32)
{
NGRAPH_DEBUG << "Convolution isn't of type float";
return false;
}
auto arg0_rank = conv_m->get_input_shape(0).size();
auto arg1_rank = conv_m->get_input_shape(1).size();
if (arg0_rank != 4 || arg1_rank != 4)
{
NGRAPH_DEBUG << "Convolution's arguments ranks aren't equal to 4";
NGRAPH_DEBUG << "Convolution not supported by MKLDNN";
return false;
}
......@@ -1199,31 +1129,9 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_conv_bias_add()
inplace_input = add_m->get_argument(1);
}
// These checks are to make sure a MKLDNN Convolution kernel can be used.
bool data_dilated = false;
for (size_t s : conv_m->get_data_dilation_strides())
{
data_dilated = data_dilated || (s != 1);
}
if (data_dilated)
{
NGRAPH_DEBUG << "Convolution has dilations greater than 1";
return false;
}
if (conv_m->get_element_type() != element::f32)
{
NGRAPH_DEBUG << "Convolution isn't of type float";
return false;
}
auto arg0_rank = conv_m->get_input_shape(0).size();
auto arg1_rank = conv_m->get_input_shape(1).size();
if (arg0_rank != 4 || arg1_rank != 4)
if (!runtime::cpu::mkldnn_utils::can_use_mkldnn_conv<op::ConvolutionBias>(conv_m.get()))
{
NGRAPH_DEBUG << "Convolution's arguments ranks aren't equal to 4";
NGRAPH_DEBUG << "Convolution not supported by MKLDNN";
return false;
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment