Commit 3c8ab010 authored by Nick Korovaiko, committed by adstraw

ConvolutionRelu (#689)

parent c9737e83
@@ -190,6 +190,7 @@ if (NGRAPH_CPU_ENABLE AND LLVM_INCLUDE_DIR AND
runtime/cpu/kernel/eigen_thread_pool.cpp
runtime/cpu/kernel/pad.cpp
runtime/cpu/op/conv_bias.cpp
runtime/cpu/op/conv_relu.cpp
runtime/cpu/op/convert_layout.cpp
runtime/cpu/op/sigmoid.cpp
runtime/cpu/op/matmul_bias.cpp
......
@@ -91,6 +91,7 @@
#include "ngraph/runtime/cpu/cpu_op_annotations.hpp"
#include "ngraph/runtime/cpu/mkldnn_utils.hpp"
#include "ngraph/runtime/cpu/op/conv_bias.hpp"
#include "ngraph/runtime/cpu/op/conv_relu.hpp"
#include "ngraph/runtime/cpu/op/convert_layout.hpp"
#include "ngraph/runtime/cpu/op/matmul_bias.hpp"
#include "ngraph/runtime/cpu/op/sigmoid.hpp"
@@ -2228,6 +2229,77 @@ namespace ngraph
writer << "}\n";
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::ConvolutionRelu)
{
auto convolution = static_cast<const ngraph::op::ConvolutionRelu*>(node);
auto arg0_shape = args[0].get_shape();
auto arg1_shape = args[1].get_shape();
auto result_shape = out[0].get_shape();
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{
// For dilation, MKLDNN wants to know how many elements to insert between two
// elements, not how far apart to space them as nGraph does, so we subtract 1
// from each stride.
Strides window_dilation_strides_adjusted;
for (size_t s : convolution->get_window_dilation_strides())
{
window_dilation_strides_adjusted.push_back(s - 1);
}
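// Example (hypothetical values): an nGraph window dilation stride of {2, 2}
// becomes an MKLDNN dilation of {1, 1}, i.e. one zero inserted between
// adjacent filter elements in each spatial dimension.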
auto input_format =
runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 0);
auto weights_format =
runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 1);
// HACK to help MKLDNN pick the right implementation
if (weights_format == mkldnn::memory::format::nchw)
{
weights_format = mkldnn::memory::format::oihw;
}
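// (nchw and oihw describe the same physical layout; relabeling the four
// dimensions as oihw tells MKLDNN these are convolution weights, which lets
// it pick a weights-aware implementation.)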
auto output_format =
runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 0);
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto input_data_desc =
mkldnn_emitter->build_memory_descriptor(args[0], input_format);
auto weights_desc =
mkldnn_emitter->build_memory_descriptor(args[1], weights_format);
auto result_desc =
mkldnn_emitter->build_memory_descriptor(out[0], output_format);
size_t conv_index = 0;
const float ops_scale = 1.f;
const float ops_alpha = -0.f; // relu negative slope
const float ops_beta = 0.f;
mkldnn::post_ops ops;
ops.append_eltwise(
ops_scale, mkldnn::algorithm::eltwise_relu, ops_alpha, ops_beta);
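// With scale 1 and alpha/beta 0 the appended eltwise op reduces to
// max(0, x), so MKLDNN applies the ReLU to the convolution output inside
// the primitive, before the result is written back to memory.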
conv_index = mkldnn_emitter->build_convolution_forward(
input_data_desc,
weights_desc,
result_desc,
convolution->get_window_movement_strides(),
window_dilation_strides_adjusted,
convolution->get_padding_below(),
convolution->get_padding_above(),
ops);
auto& deps = mkldnn_emitter->get_primitive_deps(conv_index);
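// deps holds the memory-primitive indices registered by
// build_convolution_forward, in order: {input, weights, output}.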
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[0])
<< ", " << args[0].get_name() << ");\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[1])
<< ", " << args[1].get_name() << ");\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[2])
<< ", " << out[0].get_name() << ");\n";
writer << "cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, "
<< to_string(conv_index) << ");\n";
}
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Convolution)
{
......
......@@ -111,6 +111,7 @@
#include "ngraph/runtime/cpu/cpu_tracing.hpp"
#include "ngraph/runtime/cpu/mkldnn_utils.hpp"
#include "ngraph/runtime/cpu/op/conv_bias.hpp"
#include "ngraph/runtime/cpu/op/conv_relu.hpp"
#include "ngraph/runtime/cpu/op/convert_layout.hpp"
#include "ngraph/runtime/cpu/op/matmul_bias.hpp"
#include "ngraph/runtime/cpu/op/sigmoid.hpp"
@@ -244,6 +245,7 @@ static const runtime::cpu::OpMap dispatcher{
{TI(ngraph::op::ConvolutionBackpropData),
&runtime::cpu::CPU_Emitter::emit<op::ConvolutionBackpropData>},
{TI(ngraph::op::ConvolutionBias), &runtime::cpu::CPU_Emitter::emit<op::ConvolutionBias>},
{TI(ngraph::op::ConvolutionRelu), &runtime::cpu::CPU_Emitter::emit<op::ConvolutionRelu>},
// conv+bias backprop for data shares the same implementation as ConvolutionBackpropData
{TI(ngraph::op::ConvolutionBiasBackpropFiltersBias),
&runtime::cpu::CPU_Emitter::emit<op::ConvolutionBiasBackpropFiltersBias>},
......
@@ -108,12 +108,16 @@ size_t MKLDNNEmitter::build_convolution_forward(const mkldnn::memory::desc& inpu
const ngraph::Strides& strides,
const ngraph::Strides& dilation_strides,
const ngraph::CoordinateDiff& padding_below,
const ngraph::CoordinateDiff& padding_above)
const ngraph::CoordinateDiff& padding_above,
const mkldnn::post_ops& pops)
{
size_t input_data_index = build_memory_primitive(input_data_desc);
size_t weights_index = build_memory_primitive(weights_desc);
size_t result_index = build_memory_primitive(result_desc);
mkldnn::primitive_attr conv_attr;
conv_attr.set_post_ops(pops);
size_t conv_index = insert_primitive(new mkldnn::convolution_forward(
{{mkldnn::prop_kind::forward,
mkldnn::algorithm::convolution_direct,
@@ -125,6 +129,7 @@ size_t MKLDNNEmitter::build_convolution_forward(const mkldnn::memory::desc& inpu
mkldnn::memory::dims(padding_below.begin(), padding_below.end()),
mkldnn::memory::dims(padding_above.begin(), padding_above.end()),
mkldnn::padding_kind::zero},
conv_attr,
mkldnn_utils::global_cpu_engine},
*m_mkldnn_primitives[input_data_index],
*m_mkldnn_primitives[weights_index],
......
@@ -73,7 +73,8 @@ namespace ngraph
const ngraph::Strides& strides,
const ngraph::Strides& dilation_strides,
const ngraph::CoordinateDiff& padding_below,
const ngraph::CoordinateDiff& padding_above);
const ngraph::CoordinateDiff& padding_above,
const mkldnn::post_ops& pops = mkldnn::post_ops());
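// (The defaulted, empty post_ops above keeps existing call sites
// source-compatible; callers that want a fused epilogue, such as the
// ConvolutionRelu emitter, pass an explicit post_ops instead.)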
/**
* Convolution + bias forward
......
@@ -29,6 +29,7 @@
#include "ngraph/runtime/cpu/cpu_layout_descriptor.hpp"
#include "ngraph/runtime/cpu/cpu_op_annotations.hpp"
#include "ngraph/runtime/cpu/op/conv_bias.hpp"
#include "ngraph/runtime/cpu/op/conv_relu.hpp"
#include "ngraph/type/element_type.hpp"
#include "mkldnn_utils.hpp"
@@ -49,6 +50,7 @@ static const std::unordered_set<std::type_index> s_op_registry{
TI(ngraph::op::ConvolutionBackpropData),
TI(ngraph::op::ConvolutionBackpropFilters),
TI(ngraph::op::ConvolutionBias),
TI(ngraph::op::ConvolutionRelu),
TI(ngraph::op::ConvolutionBiasBackpropFiltersBias),
TI(ngraph::op::MaxPool),
TI(ngraph::op::MaxPoolBackprop),
......
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <numeric>
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/runtime/cpu/op/conv_relu.hpp"
#include "ngraph/util.hpp"
using namespace std;
using namespace ngraph;
op::ConvolutionRelu::ConvolutionRelu(const std::shared_ptr<op::Convolution>& conv)
: RequiresTensorViewArgs("ConvolutionRelu", {conv->get_input_op(0), conv->get_input_op(1)})
, m_window_movement_strides(conv->get_window_movement_strides())
, m_window_dilation_strides(conv->get_window_dilation_strides())
, m_padding_below(conv->get_padding_below())
, m_padding_above(conv->get_padding_above())
, m_data_dilation_strides(conv->get_data_dilation_strides())
{
set_value_type_checked(conv->get_element_type(), conv->get_shape());
}
op::ConvolutionRelu::ConvolutionRelu(const std::shared_ptr<Node>& data_batch,
const std::shared_ptr<Node>& filters,
const Strides& window_movement_strides,
const Strides& window_dilation_strides,
const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above,
const Strides& data_dilation_strides)
: RequiresTensorViewArgs("ConvolutionRelu", {data_batch, filters})
, m_window_movement_strides(window_movement_strides)
, m_window_dilation_strides(window_dilation_strides)
, m_padding_below(padding_below)
, m_padding_above(padding_above)
, m_data_dilation_strides(data_dilation_strides)
{
}
std::shared_ptr<Node> op::ConvolutionRelu::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 2)
{
throw ngraph_error("Incorrect number of new arguments");
}
return std::shared_ptr<Node>(new ConvolutionRelu(new_args.at(0),
new_args.at(1),
get_window_movement_strides(),
get_window_dilation_strides(),
get_padding_below(),
get_padding_above(),
get_data_dilation_strides()));
}
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/util/requires_tensor_view_args.hpp"
namespace ngraph
{
namespace op
{
/// \brief Relu(Convolution) forward prop for a batched convolution operation.
class ConvolutionRelu : public util::RequiresTensorViewArgs
{
public:
ConvolutionRelu(const std::shared_ptr<op::Convolution>& conv);
const Strides& get_window_movement_strides() const { return m_window_movement_strides; }
const Strides& get_window_dilation_strides() const { return m_window_dilation_strides; }
const CoordinateDiff& get_padding_below() const { return m_padding_below; }
const CoordinateDiff& get_padding_above() const { return m_padding_above; }
const Strides& get_data_dilation_strides() const { return m_data_dilation_strides; }
std::shared_ptr<Node> get_filters() { return get_input_op(1); }
std::shared_ptr<Node> get_data_batch() { return get_input_op(0); }
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
protected:
Strides m_window_movement_strides;
Strides m_window_dilation_strides;
CoordinateDiff m_padding_below;
CoordinateDiff m_padding_above;
Strides m_data_dilation_strides;
private:
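// Used by copy_with_new_args to rebuild the node from raw inputs plus the
// cached convolution parameters, without requiring an op::Convolution.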
ConvolutionRelu(const std::shared_ptr<Node>& data_batch,
const std::shared_ptr<Node>& filters,
const Strides& window_movement_strides,
const Strides& window_dilation_strides,
const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above,
const Strides& data_dilation_strides);
};
}
}
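For orientation, a minimal usage sketch: the public constructor wraps an existing Convolution node, as the fusion pass and tests below do (the parameter shapes here are illustrative only):

auto data = std::make_shared<op::Parameter>(element::f32, Shape{2, 1, 6, 6});
auto filters = std::make_shared<op::Parameter>(element::f32, Shape{1, 1, 2, 2});
auto conv = std::make_shared<op::Convolution>(data, filters, Strides{2, 2}, Strides{1, 1});
auto conv_relu = std::make_shared<op::ConvolutionRelu>(conv);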
@@ -34,6 +34,7 @@
#include "ngraph/runtime/cpu/cpu_op_annotations.hpp"
#include "ngraph/runtime/cpu/mkldnn_utils.hpp"
#include "ngraph/runtime/cpu/op/conv_bias.hpp"
#include "ngraph/runtime/cpu/op/conv_relu.hpp"
#include "ngraph/runtime/cpu/op/sigmoid.hpp"
using namespace std;
@@ -101,6 +102,16 @@ namespace ngraph
}
}
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::ConvolutionRelu)
{
auto convolution = static_cast<op::ConvolutionRelu*>(node);
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
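// use_mkldnn_kernel() in the emitter keys off this annotation, so setting
// it here is what routes ConvolutionRelu onto the MKLDNN code path.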
convolution->set_op_annotations(op_annotations);
}
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::ConvolutionBackpropData)
{
@@ -398,6 +409,8 @@ static const runtime::cpu::pass::AssignOpMap s_dispatcher{
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::BatchNormBackprop>},
{TI(ngraph::op::Convolution),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::Convolution>},
{TI(ngraph::op::ConvolutionRelu),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::ConvolutionRelu>},
{TI(ngraph::op::ConvolutionBackpropData),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::ConvolutionBackpropData>},
{TI(ngraph::op::ConvolutionBackpropFilters),
......
@@ -36,6 +36,7 @@
#include "ngraph/op/negative.hpp"
#include "ngraph/op/pad.hpp"
#include "ngraph/op/parameter.hpp"
#include "ngraph/op/relu.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/sqrt.hpp"
#include "ngraph/op/subtract.hpp"
@@ -44,6 +45,7 @@
#include "ngraph/pattern/op/any.hpp"
#include "ngraph/pattern/op/label.hpp"
#include "ngraph/runtime/cpu/op/conv_bias.hpp"
#include "ngraph/runtime/cpu/op/conv_relu.hpp"
#include "ngraph/runtime/cpu/op/matmul_bias.hpp"
#include "ngraph/runtime/cpu/op/sigmoid.hpp"
@@ -678,3 +680,62 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_conv_bias()
auto m = std::make_shared<ngraph::pattern::Matcher>(p_conv_bias, callback);
this->add_matcher(m);
}
void ngraph::runtime::cpu::pass::CPUFusion::construct_conv_relu()
{
Shape shape{2, 2, 1, 1};
auto data_batch = std::make_shared<pattern::op::Label>(element::f32, shape);
auto filters = std::make_shared<pattern::op::Label>(element::f32, shape);
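// The concrete shapes only make the pattern graph well-formed; matching
// compares node types and structure, not shapes, so any Convolution+Relu
// subgraph is a candidate (the callback below rechecks rank and type).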
auto pconv = std::make_shared<op::Convolution>(data_batch,
filters,
Strides{1, 1},
Strides{1, 1},
CoordinateDiff{0, 0},
CoordinateDiff{0, 0},
Strides{1, 1});
auto prelu = std::make_shared<op::Relu>(pconv);
pattern::gr_callback_fn callback = [](pattern::Matcher& m) {
NGRAPH_DEBUG << "In a callback for construct_conv_relu against "
<< m.match_root()->get_name();
auto conv = std::dynamic_pointer_cast<op::Convolution>(m.match_root()->get_input_op(0));
// These checks make sure an MKLDNN convolution kernel can be used.
bool data_dilated = false;
for (size_t s : conv->get_data_dilation_strides())
{
data_dilated = data_dilated || (s != 1);
}
if (data_dilated)
{
NGRAPH_DEBUG << "Convolution has dilations greater than 1";
return false;
}
if (conv->get_element_type() != element::f32)
{
NGRAPH_DEBUG << "Convolution isn't of type float";
return false;
}
auto arg0_rank = conv->get_input_shape(0).size();
auto arg1_rank = conv->get_input_shape(1).size();
if (arg0_rank != 4 || arg1_rank != 4)
{
NGRAPH_DEBUG << "Convolution's arguments ranks aren't equal to 4";
return false;
}
auto conv_relu = std::shared_ptr<Node>(new op::ConvolutionRelu(conv));
ngraph::replace_node(m.match_root(), conv_relu);
return true;
};
auto m = std::make_shared<pattern::Matcher>(prelu, callback);
this->add_matcher(m);
}
@@ -46,6 +46,7 @@ public:
construct_sigmoid();
construct_sigmoid_bprop();
construct_conv_bias();
construct_conv_relu();
}
private:
@@ -57,4 +58,5 @@ private:
void construct_sigmoid_bprop();
void construct_zero_padded_reshaped_conv();
void construct_zero_padded_conv();
void construct_conv_relu();
};
@@ -28,6 +28,7 @@
#include "ngraph/op/batch_norm.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/parameter.hpp"
#include "ngraph/op/relu.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/pass/graph_rewrite.hpp"
#include "ngraph/pass/manager.hpp"
@@ -41,6 +42,7 @@
#include "ngraph/pass/reshape_elimination.hpp"
#include "ngraph/pass/visualize_tree.hpp"
#include "ngraph/runtime/cpu/op/conv_bias.hpp"
#include "ngraph/runtime/cpu/op/conv_relu.hpp"
#include "ngraph/runtime/cpu/op/matmul_bias.hpp"
#include "ngraph/runtime/cpu/op/sigmoid.hpp"
#include "ngraph/runtime/cpu/pass/cpu_fusion.hpp"
@@ -982,6 +984,65 @@ TEST(cpu_fusion, sigmoid_bprop_n1c1h4)
EXPECT_TRUE(test::all_close(expected, read_vector<float>(result)));
}
TEST(cpu_fusion, fuse_conv_relu)
{
auto A = std::make_shared<op::Parameter>(element::f32, Shape{2, 1, 2, 2});
auto weights = std::make_shared<op::Parameter>(element::f32, Shape{1, 1, 2, 2});
auto convolution = std::make_shared<op::Convolution>(A, weights, Strides{1, 1}, Strides{1, 1});
auto relu = std::make_shared<op::Relu>(convolution);
auto abs_node =
std::make_shared<op::Abs>(std::make_shared<op::Abs>(std::make_shared<op::Abs>(relu)));
auto func = make_shared<Function>(abs_node, op::ParameterVector{A, weights});
pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>();
pass_manager.run_passes(func);
size_t cb = count_ops_of_type<op::ConvolutionRelu>(func);
ASSERT_GT(cb, 0);
}
TEST(cpu_fusion, conv_relu_n2c1h2w2)
{
Shape shape_a{2, 1, 6, 6};
Shape shape_weights{1, 1, 2, 2};
auto A = std::make_shared<op::Parameter>(element::f32, shape_a);
auto weights = std::make_shared<op::Parameter>(element::f32, shape_weights);
auto conv = std::make_shared<op::Convolution>(A, weights, Strides{2, 2}, Strides{1, 1});
auto relu = std::make_shared<op::Relu>(conv);
auto conv_relu = std::make_shared<op::ConvolutionRelu>(conv);
auto manager = runtime::Manager::get("CPU");
auto backend = manager->allocate_backend();
auto _a = backend->make_primary_tensor_view(element::f32, shape_a);
vector<float> va{
1.25f, 2.25f, 5.25f, 6.25f, -1.25f, -1.25f, 3.25f, -4.25f, 7.25f, 8.25f, -1.25f, -1.25f,
1.25f, 2.25f, -3.25f, 2.25f, 4.25f, 4.25f, 1.25f, 2.25f, -4.25f, 2.25f, 4.25f, 4.25f,
0.f, 0.f, -1.f, 0.f, 2.f, 2.f, 0.f, 0.f, 0.f, 0.f, 2.f, 2.f,
1.25f, 2.25f, 5.25f, 6.25f, 1.25f, 1.25f, 3.25f, 4.25f, -7.25f, 8.25f, 1.25f, -1.25f,
-1.25f, 2.25f, 3.25f, 2.25f, -4.25f, -4.25f, -1.25f, -2.25f, 4.25f, 2.25f, 4.25f, 4.25f,
0.f, 0.f, 1.f, 0.f, -2.f, 2.f, 0.f, 0.f, 0.f, 0.f, -2.f, -2.f};
copy_data(_a, va);
auto _weights = backend->make_primary_tensor_view(element::f32, shape_weights);
copy_data(_weights, vector<float>{2., 2., 2., 2.});
auto f = make_shared<Function>(NodeVector{conv_relu, relu}, op::ParameterVector{A, weights});
auto external = manager->compile(f);
auto cf = backend->make_call_frame(external);
shared_ptr<runtime::TensorView> _conv_relu =
backend->make_primary_tensor_view(element::f32, conv_relu->get_shape());
shared_ptr<runtime::TensorView> _relu =
backend->make_primary_tensor_view(element::f32, relu->get_shape());
cf->call({_conv_relu, _relu}, {_a, _weights});
EXPECT_TRUE(test::all_close(read_vector<float>(_conv_relu), read_vector<float>(_relu)));
}
TEST(cpu_fusion, batchnorm_fprop_inference_b2c2h2w1)
{
auto input_shape = Shape{2, 2, 2, 1};
......