Commit b5844622 authored by Jayaram Bobba's avatar Jayaram Bobba Committed by Scott Cyphers

Adding support for mkldnn convolution+bias+relu kernel (#913)

* Adding support for mkldnn convolution+bias+relu op to use in batch norm folding

* Style fix

* Style fix
parent 606ad20b
......@@ -2511,6 +2511,11 @@ namespace ngraph
auto data_format = mkldnn_utils::get_input_mkldnn_format(node, 0);
auto weights_format = mkldnn_utils::get_input_mkldnn_format(node, 1);
auto bias_format = mkldnn_utils::get_input_mkldnn_format(node, 2);
// HACK to help MKLDNN pick the right implementation
if (weights_format == mkldnn::memory::format::nchw)
{
weights_format = mkldnn::memory::format::oihw;
}
auto result_format = mkldnn_utils::get_output_mkldnn_format(node, 0);
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
......@@ -2559,6 +2564,82 @@ namespace ngraph
}
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::ConvolutionBiasRelu)
{
auto convolution = static_cast<const ngraph::op::ConvolutionBiasRelu*>(node);
auto arg0_shape = args[0].get_shape();
auto arg1_shape = args[1].get_shape();
auto result_shape = out[0].get_shape();
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{
// For dilation, MKLDNN wants to know how many elements to insert between, not how far
// apart to space the elements like nGraph. So we have to subtract 1 from each pos.
Strides window_dilation_strides_adjusted;
for (size_t s : convolution->get_window_dilation_strides())
{
window_dilation_strides_adjusted.push_back(s - 1);
}
auto input_format =
runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 0);
auto weights_format =
runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 1);
auto bias_format = mkldnn_utils::get_input_mkldnn_format(node, 2);
// HACK to help MKLDNN pick the right implementation
if (weights_format == mkldnn::memory::format::nchw)
{
weights_format = mkldnn::memory::format::oihw;
}
auto output_format =
runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 0);
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto input_data_desc =
mkldnn_emitter->build_memory_descriptor(args[0], input_format);
auto weights_desc =
mkldnn_emitter->build_memory_descriptor(args[1], weights_format);
auto bias_desc = mkldnn_emitter->build_memory_descriptor(args[2], bias_format);
auto result_desc =
mkldnn_emitter->build_memory_descriptor(out[0], output_format);
size_t conv_index = 0;
const float ops_scale = 1.f;
const float ops_alpha = -0.f; // relu negative slope
const float ops_beta = 0.f;
mkldnn::post_ops ops;
ops.append_eltwise(
ops_scale, mkldnn::algorithm::eltwise_relu, ops_alpha, ops_beta);
conv_index = mkldnn_emitter->build_convolution_forward(
input_data_desc,
weights_desc,
bias_desc,
result_desc,
convolution->get_window_movement_strides(),
window_dilation_strides_adjusted,
convolution->get_padding_below(),
convolution->get_padding_above(),
ops);
auto& deps = mkldnn_emitter->get_primitive_deps(conv_index);
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[0])
<< ", " << args[0].get_name() << ");\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[1])
<< ", " << args[1].get_name() << ");\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[2])
<< ", " << args[2].get_name() << ");\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[3])
<< ", " << out[0].get_name() << ");\n";
writer << "cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, "
<< to_string(conv_index) << ");\n";
}
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::ConvolutionBiasBackpropFiltersBias)
{
......
......@@ -251,6 +251,8 @@ static const runtime::cpu::OpMap dispatcher{
&runtime::cpu::CPU_Emitter::emit<op::ConvolutionBackpropData>},
{TI(ngraph::op::ConvolutionBias), &runtime::cpu::CPU_Emitter::emit<op::ConvolutionBias>},
{TI(ngraph::op::ConvolutionRelu), &runtime::cpu::CPU_Emitter::emit<op::ConvolutionRelu>},
{TI(ngraph::op::ConvolutionBiasRelu),
&runtime::cpu::CPU_Emitter::emit<op::ConvolutionBiasRelu>},
// conv+bias backprop for data share the same implementation as ConvolutionBackpropData
{TI(ngraph::op::ConvolutionBiasBackpropFiltersBias),
&runtime::cpu::CPU_Emitter::emit<op::ConvolutionBiasBackpropFiltersBias>},
......
......@@ -146,13 +146,17 @@ size_t MKLDNNEmitter::build_convolution_forward(const mkldnn::memory::desc& inpu
const ngraph::Strides& strides,
const ngraph::Strides& dilation_strides,
const ngraph::CoordinateDiff& padding_below,
const ngraph::CoordinateDiff& padding_above)
const ngraph::CoordinateDiff& padding_above,
const mkldnn::post_ops& pops)
{
const size_t input_data_index = build_memory_primitive(input_data_desc);
const size_t weights_index = build_memory_primitive(weights_desc);
const size_t bias_index = build_memory_primitive(bias_desc);
const size_t result_index = build_memory_primitive(result_desc);
mkldnn::primitive_attr conv_attr;
conv_attr.set_post_ops(pops);
const size_t conv_index = insert_primitive(new mkldnn::convolution_forward(
{{mkldnn::prop_kind::forward,
mkldnn::algorithm::convolution_direct,
......@@ -165,6 +169,7 @@ size_t MKLDNNEmitter::build_convolution_forward(const mkldnn::memory::desc& inpu
mkldnn::memory::dims(padding_below.begin(), padding_below.end()),
mkldnn::memory::dims(padding_above.begin(), padding_above.end()),
mkldnn::padding_kind::zero},
conv_attr,
mkldnn_utils::global_cpu_engine},
*m_mkldnn_primitives[input_data_index],
*m_mkldnn_primitives[weights_index],
......
......@@ -86,7 +86,8 @@ namespace ngraph
const ngraph::Strides& strides,
const ngraph::Strides& dilation_strides,
const ngraph::CoordinateDiff& padding_below,
const ngraph::CoordinateDiff& padding_above);
const ngraph::CoordinateDiff& padding_above,
const mkldnn::post_ops& pops = mkldnn::post_ops());
size_t
build_convolution_backward_weights(const mkldnn::memory::desc& input_desc,
......
......@@ -30,6 +30,15 @@ namespace ngraph
ConvolutionBias(const std::shared_ptr<op::Convolution>& conv,
const std::shared_ptr<Node>& bias);
ConvolutionBias(const std::shared_ptr<Node>& data_batch,
const std::shared_ptr<Node>& filters,
const std::shared_ptr<Node>& bias,
const Strides& window_movement_strides,
const Strides& window_dilation_strides,
const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above,
const Strides& data_dilation_strides);
const Strides& get_window_movement_strides() const { return m_window_movement_strides; }
const Strides& get_window_dilation_strides() const { return m_window_dilation_strides; }
const CoordinateDiff& get_padding_below() const { return m_padding_below; }
......@@ -49,16 +58,6 @@ namespace ngraph
CoordinateDiff m_padding_below;
CoordinateDiff m_padding_above;
Strides m_data_dilation_strides;
private:
ConvolutionBias(const std::shared_ptr<Node>& data_batch,
const std::shared_ptr<Node>& filters,
const std::shared_ptr<Node>& bias,
const Strides& window_movement_strides,
const Strides& window_dilation_strides,
const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above,
const Strides& data_dilation_strides);
};
/// \brief Filters and bias backprop for batched convolution operation. Data backprop is
......
......@@ -95,3 +95,78 @@ std::shared_ptr<Node> op::ConvolutionRelu::copy_with_new_args(const NodeVector&
get_padding_above(),
get_data_dilation_strides()));
}
op::ConvolutionBiasRelu::ConvolutionBiasRelu(const std::shared_ptr<op::ConvolutionBias>& conv)
: RequiresTensorViewArgs("ConvolutionBiasRelu",
{conv->get_argument(0), conv->get_argument(1), conv->get_argument(2)})
, m_window_movement_strides(conv->get_window_movement_strides())
, m_window_dilation_strides(conv->get_window_dilation_strides())
, m_padding_below(conv->get_padding_below())
, m_padding_above(conv->get_padding_above())
, m_data_dilation_strides(conv->get_data_dilation_strides())
{
set_value_type_checked(conv->get_element_type(), conv->get_shape());
}
op::ConvolutionBiasRelu::ConvolutionBiasRelu(const std::shared_ptr<Node>& data_batch,
const std::shared_ptr<Node>& filters,
const std::shared_ptr<Node>& bias,
const Strides& window_movement_strides,
const Strides& window_dilation_strides,
const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above,
const Strides& data_dilation_strides)
: RequiresTensorViewArgs("ConvolutionBiasRelu", {data_batch, filters, bias})
, m_window_movement_strides(window_movement_strides)
, m_window_dilation_strides(window_dilation_strides)
, m_padding_below(padding_below)
, m_padding_above(padding_above)
, m_data_dilation_strides(data_dilation_strides)
{
auto& data_batch_shape = data_batch->get_shape();
auto& data_batch_et = data_batch->get_element_type();
auto& filters_shape = filters->get_shape();
auto& filters_et = filters->get_element_type();
//
// Make sure data batch and filter element types match.
//
if (data_batch_et != filters_et)
{
throw ngraph_error("Convolution data batch and filter element types do not match");
}
set_value_type_checked(
data_batch_et,
util::infer_convolution_output_shape(data_batch_shape,
filters_shape,
window_movement_strides,
window_dilation_strides,
padding_below,
padding_above,
data_dilation_strides,
0, /* batch_axis_data, */
1, /* input_channel_axis_data, */
1, /* input_channel_axis_filters, */
0, /* output_channel_axis_filters, */
0, /* batch_axis_result, */
1, /* output_channel_axis_result, */
""));
}
std::shared_ptr<Node> op::ConvolutionBiasRelu::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 3)
{
throw ngraph_error("Incorrect number of new arguments");
}
return std::shared_ptr<Node>(new ConvolutionBiasRelu(new_args.at(0),
new_args.at(1),
new_args.at(2),
get_window_movement_strides(),
get_window_dilation_strides(),
get_padding_below(),
get_padding_above(),
get_data_dilation_strides()));
}
......@@ -18,6 +18,7 @@
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/util/requires_tensor_view_args.hpp"
#include "ngraph/runtime/cpu/op/conv_bias.hpp"
namespace ngraph
{
......@@ -29,6 +30,14 @@ namespace ngraph
public:
ConvolutionRelu(const std::shared_ptr<op::Convolution>& conv);
ConvolutionRelu(const std::shared_ptr<Node>& data_batch,
const std::shared_ptr<Node>& filters,
const Strides& window_movement_strides,
const Strides& window_dilation_strides,
const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above,
const Strides& data_dilation_strides);
const Strides& get_window_movement_strides() const { return m_window_movement_strides; }
const Strides& get_window_dilation_strides() const { return m_window_dilation_strides; }
const CoordinateDiff& get_padding_below() const { return m_padding_below; }
......@@ -45,15 +54,39 @@ namespace ngraph
CoordinateDiff m_padding_below;
CoordinateDiff m_padding_above;
Strides m_data_dilation_strides;
};
private:
ConvolutionRelu(const std::shared_ptr<Node>& data_batch,
const std::shared_ptr<Node>& filters,
const Strides& window_movement_strides,
const Strides& window_dilation_strides,
const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above,
const Strides& data_dilation_strides);
/// \brief Relu(Convolution) forward prop for batched convolution operation with bias
class ConvolutionBiasRelu : public util::RequiresTensorViewArgs
{
public:
ConvolutionBiasRelu(const std::shared_ptr<op::ConvolutionBias>& conv);
ConvolutionBiasRelu(const std::shared_ptr<Node>& data_batch,
const std::shared_ptr<Node>& filters,
const std::shared_ptr<Node>& bias,
const Strides& window_movement_strides,
const Strides& window_dilation_strides,
const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above,
const Strides& data_dilation_strides);
const Strides& get_window_movement_strides() const { return m_window_movement_strides; }
const Strides& get_window_dilation_strides() const { return m_window_dilation_strides; }
const CoordinateDiff& get_padding_below() const { return m_padding_below; }
const CoordinateDiff& get_padding_above() const { return m_padding_above; }
const Strides& get_data_dilation_strides() const { return m_data_dilation_strides; }
std::shared_ptr<Node> get_filters() { return get_argument(1); }
std::shared_ptr<Node> get_data_batch() { return get_argument(0); }
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
protected:
Strides m_window_movement_strides;
Strides m_window_dilation_strides;
CoordinateDiff m_padding_below;
CoordinateDiff m_padding_above;
Strides m_data_dilation_strides;
};
}
}
......@@ -121,10 +121,48 @@ namespace ngraph
void CPUAssignment::ASSIGN_DECL(ngraph::op::ConvolutionRelu)
{
auto convolution = static_cast<op::ConvolutionRelu*>(node);
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
convolution->set_op_annotations(op_annotations);
auto arg0_rank = node->get_input_shape(0).size();
auto arg1_rank = node->get_input_shape(1).size();
bool data_dilated = false;
for (size_t s : convolution->get_data_dilation_strides())
{
data_dilated = data_dilated || (s != 1);
}
if (!data_dilated && arg0_rank == 4 && arg1_rank == 4 &&
node->get_input_element_type(0) == element::f32)
{
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
convolution->set_op_annotations(op_annotations);
}
}
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::ConvolutionBiasRelu)
{
auto convolution = static_cast<op::ConvolutionBiasRelu*>(node);
auto arg0_rank = node->get_input_shape(0).size();
auto arg1_rank = node->get_input_shape(1).size();
bool data_dilated = false;
for (size_t s : convolution->get_data_dilation_strides())
{
data_dilated = data_dilated || (s != 1);
}
if (!data_dilated && arg0_rank == 4 && arg1_rank == 4 &&
node->get_input_element_type(0) == element::f32)
{
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
convolution->set_op_annotations(op_annotations);
}
}
template <>
......@@ -440,6 +478,8 @@ static const runtime::cpu::pass::AssignOpMap s_dispatcher{
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::Convolution>},
{TI(ngraph::op::ConvolutionRelu),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::ConvolutionRelu>},
{TI(ngraph::op::ConvolutionBiasRelu),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::ConvolutionBiasRelu>},
{TI(ngraph::op::BatchNormRelu),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::BatchNormRelu>},
{TI(ngraph::op::ConvolutionBackpropData),
......
......@@ -377,6 +377,25 @@ namespace ngraph
}
}
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::ConvolutionBiasRelu)
{
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node.get()))
{
vector<memory::format> prim_input_formats;
vector<memory::format> prim_output_formats;
ConvolutionLayout<ngraph::op::ConvolutionBiasRelu, true>(
node, prim_input_formats, prim_output_formats);
node =
insert_input_conversions(external_function, node, prim_input_formats);
set_output_layouts(node, prim_output_formats);
}
else
{
set_default_layouts(external_function, node);
}
}
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::ConvolutionBackpropData)
{
......@@ -1290,6 +1309,8 @@ static const runtime::cpu::pass::LayoutOpMap s_dispatcher{
&runtime::cpu::pass::CPULayout::layout<ngraph::op::ConvolutionBias>},
{TI(ngraph::op::ConvolutionRelu),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::ConvolutionRelu>},
{TI(ngraph::op::ConvolutionBiasRelu),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::ConvolutionBiasRelu>},
{TI(ngraph::op::ConvolutionBiasBackpropFiltersBias),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::ConvolutionBiasBackpropFiltersBias>},
{TI(ngraph::op::BatchNorm), &runtime::cpu::pass::CPULayout::layout<ngraph::op::BatchNorm>},
......
......@@ -838,7 +838,7 @@ template <typename T>
static std::vector<std::vector<T>>
execute(std::shared_ptr<Function> f, std::vector<std::vector<T>> args, std::string cbackend)
{
auto backend = runtime::Backend::create("CPU");
auto backend = runtime::Backend::create(cbackend);
auto parms = f->get_parameters();
......@@ -916,6 +916,56 @@ TEST(cpu_fusion, conv_relu_n2c1h2w2_2)
EXPECT_TRUE(test::all_close(cpu_results.at(0), int_results.at(0)));
}
TEST(cpu_fusion, conv_bias_relu_n2c1h2w2_2)
{
Shape shape_a{2, 1, 6, 6};
Shape shape_weights{1, 1, 2, 2};
Shape shape_bias{1};
auto make_int_function = [shape_a, shape_weights, shape_bias]() {
auto A = std::make_shared<op::Parameter>(element::f32, shape_a);
auto weights = std::make_shared<op::Parameter>(element::f32, shape_weights);
auto conv = std::make_shared<op::Convolution>(A, weights, Strides{2, 2}, Strides{1, 1});
auto bias = std::make_shared<op::Parameter>(element::f32, shape_bias);
auto conv_bias =
conv + std::make_shared<op::Broadcast>(bias, conv->get_shape(), AxisSet{0, 2, 3});
auto relu = std::make_shared<op::Relu>(conv_bias);
auto f = make_shared<Function>(NodeVector{relu}, op::ParameterVector{A, weights, bias});
return f;
};
auto int_f = make_int_function();
auto make_cpu_function = [shape_a, shape_weights, shape_bias]() {
auto A = std::make_shared<op::Parameter>(element::f32, shape_a);
auto weights = std::make_shared<op::Parameter>(element::f32, shape_weights);
auto bias = std::make_shared<op::Parameter>(element::f32, shape_bias);
auto conv = std::make_shared<op::Convolution>(A, weights, Strides{2, 2}, Strides{1, 1});
auto conv_bias_relu = std::make_shared<op::ConvolutionBiasRelu>(
std::make_shared<op::ConvolutionBias>(conv, bias));
auto f = make_shared<Function>(NodeVector{conv_bias_relu},
op::ParameterVector{A, weights, bias});
return f;
};
auto cpu_f = make_cpu_function();
vector<vector<float>> args{
{1.25f, 2.25f, 5.25f, 6.25f, -1.25f, -1.25f, 3.25f, -4.25f, 7.25f, 8.25f, -1.25f,
-1.25f, 1.25f, 2.25f, -3.25f, 2.25f, 4.25f, 4.25f, 1.25f, 2.25f, -4.25f, 2.25f,
4.25f, 4.25f, 0.f, 0.f, -1.f, 0.f, 2.f, 2.f, 0.f, 0.f, 0.f,
0.f, 2.f, 2.f, 1.25f, 2.25f, 5.25f, 6.25f, 1.25f, 1.25f, 3.25f, 4.25f,
-7.25f, 8.25f, 1.25f, -1.25f, -1.25f, 2.25f, 3.25f, 2.25f, -4.25f, -4.25f, -1.25f,
-2.25f, 4.25f, 2.25f, 4.25f, 4.25f, 0.f, 0.f, 1.f, 0.f, -2.f, 2.f,
0.f, 0.f, 0.f, 0.f, -2.f, -2.f},
{2., 2., 2., 2.},
{0.1f}};
auto int_results = execute(int_f, args, "INTERPRETER");
auto cpu_results = execute(cpu_f, args, "CPU");
EXPECT_TRUE(test::all_close(cpu_results.at(0), int_results.at(0)));
}
std::vector<shared_ptr<runtime::TensorView>>
rnn_matrix_fusion_eval(const size_t time_steps,
const Shape& data_shape,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment