Commit 03f13e4b authored by gaurides, committed by Scott Cyphers

DeconvBias (#2716)

* Deconvolution optimizations for DCGAN

* Added test cases

* modified some tests, not working at this point

* Removed temp code

* Fixes to get unit tests to pass

* Added node validation checks

* Update mkldnn emitter to memory reuse design

* Code cleanup

* Fix to let deconv select the right kernel

* Fix file permissions

* Disabled unit test cases

* Remove unused variable

* Address PR feedback

* Removed dead code

* Style check

* Removed dead code
parent 59632bac
......@@ -94,6 +94,7 @@ set(SRC
op/conv_add.cpp
op/conv_relu.cpp
op/convert_layout.cpp
op/deconv.cpp
op/group_conv.cpp
op/group_conv_bias.cpp
op/halide_op.cpp
......
......@@ -602,6 +602,56 @@ namespace ngraph
}
}
template <>
void Builder::BUILDER_DECL(ngraph::op::DeconvolutionBias)
{
auto& functors = external_function->get_functors();
auto arg0_shape = args[0].get_shape();
auto arg1_shape = args[1].get_shape();
auto arg2_shape = args[2].get_shape();
auto result_shape = out[0].get_shape();
auto& arg0_tensor = external_function->get_tensor_data(args[0].get_name());
auto& arg1_tensor = external_function->get_tensor_data(args[1].get_name());
auto& arg2_tensor = external_function->get_tensor_data(args[2].get_name());
auto& out_tensor = external_function->get_tensor_data(out[0].get_name());
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto deconvbias_desc =
mkldnn_emitter
->get_deconvolutionbias_forward_data<ngraph::op::DeconvolutionBias>(
node);
auto weights_desc = mkldnn_utils::get_input_mkldnn_md(node, 0);
// DeconvolutionBias needs 5 primitives: weights, delta, bias, result,
// and deconvolutionbias.
auto conv_index = mkldnn_emitter->reserve_primitive_space(5);
auto& deps = mkldnn_emitter->get_primitive_deps(conv_index);
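// Dependency slots, as populated by build_deconvolutionbias_forward below:
// deps[0] = weights, deps[1] = delta, deps[2] = bias, deps[3] = result.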
auto functor = [&, deconvbias_desc, conv_index, weights_desc](
CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
if (ctx->first_iteration)
{
mkldnn_emitter->build_deconvolutionbias_forward(
deconvbias_desc, conv_index, weights_desc);
}
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[0], arg0_tensor);
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[1], arg1_tensor);
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[2], arg2_tensor);
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[3], out_tensor);
cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, conv_index);
};
functors.emplace_back(functor);
}
else
{
throw ngraph_error("DeconvolutionBias is only supported with MKLDNN kernel");
}
}
REGISTER_OP_BUILDER(Convolution);
REGISTER_OP_BUILDER(ConvolutionRelu);
REGISTER_OP_BUILDER(ConvolutionBias);
......@@ -612,6 +662,7 @@ namespace ngraph
REGISTER_OP_BUILDER(GroupConvolution);
REGISTER_OP_BUILDER(ConvolutionAdd);
REGISTER_OP_BUILDER(GroupConvolutionBias);
REGISTER_OP_BUILDER(DeconvolutionBias);
} // namespace cpu
} // namespace runtime
} // namespace ngraph
......@@ -117,6 +117,7 @@
#include "ngraph/runtime/cpu/op/conv_add.hpp"
#include "ngraph/runtime/cpu/op/conv_relu.hpp"
#include "ngraph/runtime/cpu/op/convert_layout.hpp"
#include "ngraph/runtime/cpu/op/deconv.hpp"
#include "ngraph/runtime/cpu/op/group_conv.hpp"
#include "ngraph/runtime/cpu/op/group_conv_bias.hpp"
#include "ngraph/runtime/cpu/op/leaky_relu.hpp"
......@@ -2283,6 +2284,40 @@ namespace ngraph
}
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::DeconvolutionBias)
{
auto arg0_shape = args[0].get_shape();
auto arg1_shape = args[1].get_shape();
auto arg2_shape = args[2].get_shape();
auto result_shape = out[0].get_shape();
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto conv_index =
mkldnn_emitter->build_deconvolution<ngraph::op::DeconvolutionBias>(
node, args, out);
auto& deps = mkldnn_emitter->get_primitive_deps(conv_index);
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[0])
<< ", " << args[0].get_name() << ");\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[1])
<< ", " << args[1].get_name() << ");\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[2])
<< ", " << args[2].get_name() << ");\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[3])
<< ", " << out[0].get_name() << ");\n";
writer << "cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, "
<< to_string(conv_index) << ");\n";
}
else
{
throw ngraph_error("DeconvolutionBias is only supported with MKLDNN kernel.");
}
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::ConvolutionBackpropData)
{
......
......@@ -159,6 +159,7 @@
#include "ngraph/runtime/cpu/op/conv_add.hpp"
#include "ngraph/runtime/cpu/op/conv_relu.hpp"
#include "ngraph/runtime/cpu/op/convert_layout.hpp"
#include "ngraph/runtime/cpu/op/deconv.hpp"
#include "ngraph/runtime/cpu/op/group_conv.hpp"
#include "ngraph/runtime/cpu/op/group_conv_bias.hpp"
#include "ngraph/runtime/cpu/op/leaky_relu.hpp"
......@@ -430,6 +431,8 @@ static const runtime::cpu::OpMap dispatcher{
{TI(ngraph::op::Dequantize), &runtime::cpu::CPU_Emitter::emit<ngraph::op::Dequantize>},
{TI(ngraph::op::GroupConvolutionBias),
&runtime::cpu::CPU_Emitter::emit<op::GroupConvolutionBias>},
{TI(ngraph::op::DeconvolutionBias),
&runtime::cpu::CPU_Emitter::emit<ngraph::op::DeconvolutionBias>},
{TI(ngraph::op::QuantizedConcat), &runtime::cpu::CPU_Emitter::emit<op::QuantizedConcat>},
};
......
......@@ -51,6 +51,7 @@
#include "ngraph/runtime/cpu/mkldnn_invoke.hpp"
#include "ngraph/runtime/cpu/mkldnn_utils.hpp"
#include "ngraph/runtime/cpu/op/convert_layout.hpp"
#include "ngraph/runtime/cpu/op/deconv.hpp"
#include "ngraph/runtime/cpu/op/lstm.hpp"
#include "ngraph/runtime/cpu/op/max_pool_with_indices.hpp"
#include "ngraph/runtime/cpu/op/rnn.hpp"
......@@ -295,6 +296,87 @@ mkldnn::memory::format MKLDNNEmitter::query_convolution_forward_weight_format(
prim_desc.weights_primitive_desc().desc().data.format);
}
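// Note (memory-reuse path): this overload is called lazily from the CPU builder's
// functor on the first iteration, after reserve_primitive_space(5) has reserved the
// weights/delta/bias/result/deconvolution slots whose indices it reads from
// m_primitive_deps[deconv_index].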
void MKLDNNEmitter::build_deconvolutionbias_forward(
const mkldnn::deconvolution_forward::desc& deconv_desc,
size_t deconv_index,
const mkldnn::memory::desc& weights_desc)
{
size_t weights_index = m_primitive_deps[deconv_index][0];
build_memory_primitive(weights_desc, weights_index);
size_t delta_index = m_primitive_deps[deconv_index][1];
build_memory_primitive(deconv_desc.data.src_desc, delta_index);
size_t bias_index = m_primitive_deps[deconv_index][2];
build_memory_primitive(deconv_desc.data.bias_desc, bias_index);
size_t result_index = m_primitive_deps[deconv_index][3];
build_memory_primitive(deconv_desc.data.dst_desc, result_index);
try
{
m_mkldnn_primitives[deconv_index] =
new mkldnn::deconvolution_forward({deconv_desc, executor::global_cpu_engine},
*m_mkldnn_primitives[delta_index],
*m_mkldnn_primitives[weights_index],
*m_mkldnn_primitives[bias_index],
*m_mkldnn_primitives[result_index]);
}
catch (const mkldnn::error& e)
{
throw ngraph_error("Could not create mkldnn deconvolution_forward " + e.message);
}
}
size_t MKLDNNEmitter::build_deconvolutionbias_forward(const mkldnn::memory::desc& input_data_desc,
const mkldnn::memory::desc& weights_desc,
const mkldnn::memory::desc& bias_desc,
const mkldnn::memory::desc& result_desc,
const ngraph::Strides& strides,
const ngraph::Strides& dilation_strides,
const ngraph::CoordinateDiff& padding_below,
const ngraph::CoordinateDiff& padding_above,
const mkldnn::post_ops& pops)
{
size_t input_data_index = build_memory_primitive(input_data_desc);
size_t weights_index = build_memory_primitive(weights_desc);
size_t bias_index = build_memory_primitive(bias_desc);
size_t result_index = build_memory_primitive(result_desc);
mkldnn::primitive_attr conv_attr;
conv_attr.set_post_ops(pops);
size_t conv_index = 0;
try
{
auto conv_prim = new mkldnn::deconvolution_forward(
{{mkldnn::prop_kind::forward,
mkldnn::algorithm::deconvolution_direct,
input_data_desc,
weights_desc,
bias_desc,
result_desc,
mkldnn::memory::dims(strides.begin(), strides.end()),
mkldnn::memory::dims(dilation_strides.begin(), dilation_strides.end()),
mkldnn::memory::dims(padding_below.begin(), padding_below.end()),
mkldnn::memory::dims(padding_above.begin(), padding_above.end()),
mkldnn::padding_kind::zero},
conv_attr,
executor::global_cpu_engine},
*m_mkldnn_primitives[input_data_index],
*m_mkldnn_primitives[weights_index],
*m_mkldnn_primitives[bias_index],
*m_mkldnn_primitives[result_index]);
conv_index = insert_primitive(conv_prim);
m_primitive_deps[conv_index] = {weights_index, input_data_index, bias_index, result_index};
}
catch (const mkldnn::error& e)
{
throw ngraph_error("Could not create mkldnn deconvolution_forward " + e.message);
}
return conv_index;
}
size_t MKLDNNEmitter::build_convolution_forward(const mkldnn::memory::desc& input_data_desc,
const mkldnn::memory::desc& weights_desc,
const mkldnn::memory::desc& result_desc,
......
......@@ -50,6 +50,7 @@
#include "ngraph/runtime/cpu/op/bounded_relu.hpp"
#include "ngraph/runtime/cpu/op/conv_add.hpp"
#include "ngraph/runtime/cpu/op/conv_relu.hpp"
#include "ngraph/runtime/cpu/op/deconv.hpp"
#include "ngraph/runtime/cpu/op/group_conv.hpp"
#include "ngraph/runtime/cpu/op/group_conv_bias.hpp"
#include "ngraph/runtime/cpu/op/leaky_relu.hpp"
......@@ -324,6 +325,89 @@ namespace ngraph
}
}
void build_deconvolutionbias_forward(
const mkldnn::deconvolution_forward::desc& fwd_desc,
size_t conv_index,
const mkldnn::memory::desc& weights_desc);
size_t build_deconvolutionbias_forward(
const mkldnn::memory::desc& input_data_desc,
const mkldnn::memory::desc& weights_desc,
const mkldnn::memory::desc& bias_desc,
const mkldnn::memory::desc& result_desc,
const ngraph::Strides& strides,
const ngraph::Strides& dilation_strides,
const ngraph::CoordinateDiff& padding_below,
const ngraph::CoordinateDiff& padding_above,
const mkldnn::post_ops& pops = mkldnn::post_ops());
template <typename OP>
size_t build_deconvolution(const ngraph::Node* node,
const std::vector<TensorViewWrapper>& args,
const std::vector<TensorViewWrapper>& out)
{
auto convolution = static_cast<const OP*>(node);
// For dilation, MKLDNN wants to know how many elements to insert between, not how far
// apart to space the elements like nGraph, so we have to subtract 1 from each value.
Strides window_dilation_strides_adjusted;
for (size_t s : convolution->get_window_dilation_strides_forward())
{
window_dilation_strides_adjusted.push_back(s - 1);
}
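// For example (illustrative): an nGraph window dilation stride of 1 (no dilation)
// becomes an MKLDNN dilation of 0, and a stride of 2 becomes 1.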
auto weights_desc = mkldnn_utils::get_input_mkldnn_md(node, 0);
auto data_desc = mkldnn_utils::get_input_mkldnn_md(node, 1);
// MKLDNN relies on named formats for kernel selection
if (weights_desc.data.format == mkldnn_nchw)
weights_desc.data.format = mkldnn_oihw;
if (weights_desc.data.format == mkldnn_ncdhw)
weights_desc.data.format = mkldnn_oidhw;
auto result_desc = mkldnn_utils::get_output_mkldnn_md(node, 0);
mkldnn::post_ops ops;
auto add_relu = [&]() {
if (dynamic_cast<const ngraph::op::DeconvolutionBias*>(node))
{
return (dynamic_cast<const ngraph::op::DeconvolutionBias*>(node))
->with_relu();
}
return false;
};
if (add_relu())
{
const float ops_scale = 1.f;
const float ops_alpha = -0.f; // relu negative slope
const float ops_beta = 0.f;
ops.append_eltwise(
ops_scale, mkldnn::algorithm::eltwise_relu, ops_alpha, ops_beta);
}
if (std::is_same<OP, ngraph::op::DeconvolutionBias>())
{
auto bias_desc = mkldnn_utils::get_input_mkldnn_md(node, 2);
return build_deconvolutionbias_forward(
data_desc,
weights_desc,
bias_desc,
result_desc,
convolution->get_window_movement_strides_forward(),
window_dilation_strides_adjusted,
convolution->get_padding_below_forward(),
convolution->get_padding_above_forward(),
ops);
}
else
{
throw ngraph_error("Unsupported Op.");
}
}
template <typename OP>
size_t build_inner_product(const ngraph::Node* node,
const std::vector<TensorViewWrapper>& args,
......@@ -1496,6 +1580,55 @@ namespace ngraph
m_mkldnn_primitives[ip_idx] = prim;
}
template <typename OP>
mkldnn::deconvolution_forward::desc
get_deconvolutionbias_forward_data(const ngraph::Node* node)
{
auto convolution = static_cast<const OP*>(node);
// For dilation, MKLDNN wants to know how many elements to insert between, not how far
// apart to space the elements like nGraph, so we have to subtract 1 from each value.
Strides window_dilation_strides_adjusted;
for (size_t s : convolution->get_window_dilation_strides_forward())
{
window_dilation_strides_adjusted.push_back(s - 1);
}
auto weights_desc = mkldnn_utils::get_input_mkldnn_md(node, 0);
// MKLDNN relies on named formats for kernel selection
if (weights_desc.data.format == mkldnn_nchw)
{
weights_desc.data.format = mkldnn_oihw;
}
if (weights_desc.data.format == mkldnn_ncdhw)
{
weights_desc.data.format = mkldnn_oidhw;
}
// The MKLDNN deconvolution primitive needs the weights format to be "mkldnn_any";
// with any other format it picks the reference kernel, which is very slow.
// TODO: check whether the MKLDNN primitive format requirement changes.
weights_desc.data.format = mkldnn_any;
auto delta_desc = mkldnn_utils::get_input_mkldnn_md(node, 1);
auto bias_desc = mkldnn_utils::get_input_mkldnn_md(node, 2);
auto result_desc = mkldnn_utils::get_output_mkldnn_md(node, 0);
mkldnn::algorithm deconvolution_algo = mkldnn_utils::get_deconv_algo();
mkldnn::post_ops ops;
return mkldnn::deconvolution_forward::desc(
mkldnn::prop_kind::forward,
deconvolution_algo,
delta_desc,
weights_desc,
bias_desc,
result_desc,
MKLDNN_DIMS(convolution->get_window_movement_strides_forward()),
MKLDNN_DIMS(window_dilation_strides_adjusted),
MKLDNN_DIMS(convolution->get_padding_below_forward()),
MKLDNN_DIMS(convolution->get_padding_above_forward()),
mkldnn::padding_kind::zero);
}
template <typename OP>
mkldnn::convolution_backward_data::desc
get_convolution_backward_data_desc(const ngraph::Node* node)
......
......@@ -726,6 +726,13 @@ bool runtime::cpu::mkldnn_utils::can_use_mkldnn_batchnorm_fprop(const ngraph::No
}
}
mkldnn::algorithm runtime::cpu::mkldnn_utils::get_deconv_algo()
{
// Note: there is no deconvolution_auto in MKLDNN yet, so return direct for now.
// TODO: switch once an "auto" algorithm is available for deconvolution.
return mkldnn::algorithm::deconvolution_direct;
}
mkldnn::algorithm runtime::cpu::mkldnn_utils::get_conv_algo()
{
#if defined(MKLDNN_VERSION_MAJOR) && defined(MKLDNN_VERSION_MINOR) && defined(MKLDNN_VERSION_PATCH)
......
......@@ -84,6 +84,9 @@ namespace ngraph
*/
mkldnn::algorithm get_conv_algo();
// Placeholder for when "auto" support is added for deconv
mkldnn::algorithm get_deconv_algo();
bool use_mkldnn_kernel(const ngraph::Node* node);
void assign_mkldnn_kernel(Node* node);
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <numeric>
#include "deconv.hpp"
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/util.hpp"
#include "ngraph/validation_util.hpp"
using namespace std;
using namespace ngraph;
op::DeconvolutionBias::DeconvolutionBias(const Shape& data_batch_shape,
const shared_ptr<Node>& filters,
const shared_ptr<Node>& output_delta,
const shared_ptr<Node>& bias,
const Strides& window_movement_strides_forward,
const Strides& window_dilation_strides_forward,
const CoordinateDiff& padding_below_forward,
const CoordinateDiff& padding_above_forward,
const Strides& data_dilation_strides_forward,
const bool with_relu)
: Op("DeconvolutionBias", check_single_output_args({filters, output_delta, bias}))
, m_data_batch_shape(data_batch_shape)
, m_window_movement_strides_forward(window_movement_strides_forward)
, m_window_dilation_strides_forward(window_dilation_strides_forward)
, m_padding_below_forward(padding_below_forward)
, m_padding_above_forward(padding_above_forward)
, m_data_dilation_strides_forward(data_dilation_strides_forward)
, m_with_relu(with_relu)
{
NGRAPH_DEBUG << "DeconvolutionBias ctor" << endl;
NGRAPH_DEBUG << "data: " << data_batch_shape << ", filters: " << filters->get_shape()
<< ", output_delta: " << output_delta->get_shape();
constructor_validate_and_infer_types();
}
void op::DeconvolutionBias::validate_and_infer_types()
{
NGRAPH_DEBUG << "DeconvolutionBias::validate_and_infer_types" << endl;
// Backprop to data is itself convolution, with inputs/outputs/attributes transmogrified as
// follows.
//
// Forward Backward
// "N" axis for data batch 0 0
// "C" axis for data batch 1 1
// "Co" axis for filters 0 0
// "Ci" axis for filters 1 1
// "N" axis for output 0 0
// "C" axis for output 1 1
// Data batch x delta
// Data batch shape S_x S_o
// Filters f reverse(f) [on spatial axes]
// Filters shape S_f S_f
// Window movement strides q_x p_x
// Window dilation strides p_f p_f
// Padding below a_x (S_f - 1)p_f - a_x
// Padding above b_x (S_f - 1)p_f + ((a_x + (S_x - 1)p_x + b_x - (S_f - 1)p_f) % q_x) - b_x
// Data dilation strides p_x q_x
// Output shape S_o S_x
//
// To _validate_, we simply need to check/infer the output shape of the forward convolution,
// then check to make sure that the incoming delta has the same shape as the forward output.
//
// We will also compute and store the various parameters in the "backward" column above, since
// some backends need them. (TODO(amprocte): Is it just because of the way the reference works
// that this stuff is needed? If so, we can probably get rid of it and have conv_backprop
// reference kernels that do the calculations of the backward parameters internally, or supply
// utility functions to do it.)
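// Worked example (illustrative only; the shapes match the DCGAN-style fusion tests):
// S_x = 4, S_f = 4, q_x = 1, p_f = 1, p_x = 1, a_x = b_x = 0.
// Forward output: S_o = ((S_x - 1)p_x + 1 + a_x + b_x - ((S_f - 1)p_f + 1)) / q_x + 1 = 1,
// so the incoming delta is expected to have spatial size 1.
// Backward padding below: (S_f - 1)p_f - a_x = 3.
// Backward padding above: (S_f - 1)p_f + ((a_x + (S_x - 1)p_x + b_x - (S_f - 1)p_f) % q_x) - b_x = 3.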
const PartialShape& filters_shape = get_input_partial_shape(0);
element::Type filters_et = get_input_element_type(0);
const PartialShape& delta_shape = get_input_partial_shape(1);
element::Type delta_et = get_input_element_type(1);
const PartialShape& bias_shape = get_input_partial_shape(2);
element::Type bias_et = get_input_element_type(2);
element::Type forward_result_et;
PartialShape forward_result_shape;
const PartialShape& fwd_filters_shape{
filters_shape[1], filters_shape[0], filters_shape[2], filters_shape[3]};
std::tie(forward_result_et, forward_result_shape) =
infer_convolution_forward(this,
delta_et,
filters_et,
m_data_batch_shape,
m_data_dilation_strides_forward,
m_padding_below_forward,
m_padding_above_forward,
fwd_filters_shape,
m_window_movement_strides_forward,
m_window_dilation_strides_forward);
NGRAPH_DEBUG << "\tpartial filter_shape: " << filters_shape << "delta_shape: " << delta_shape
<< ", inferred_res_shape: " << forward_result_shape << endl;
NODE_VALIDATION_CHECK(this,
forward_result_shape.compatible(delta_shape),
"Inferred forward output shape (",
forward_result_shape,
") does not match shape of ",
"delta (",
delta_shape,
").");
NODE_VALIDATION_CHECK(this,
filters_et.compatible(bias_et),
"Filter element type (",
filters_et,
") does not match bias element type (",
bias_et,
").");
NODE_VALIDATION_CHECK(this,
static_cast<size_t>(bias_shape.rank()) == 1,
"Bias rank (",
bias_shape.rank(),
") is not equal to 1.");
NODE_VALIDATION_CHECK(this,
static_cast<size_t>(bias_shape[0]) ==
static_cast<size_t>(filters_shape[0]),
"Filter input channel count (",
filters_shape,
") is not compatible with ",
"bias channel count (",
bias_shape,
").");
set_output_type(0, forward_result_et, m_data_batch_shape);
}
void op::DeconvolutionBias::generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas)
{
throw ngraph_error("DeconvolutionBias generate_adjoints not supported implemented");
}
shared_ptr<Node> op::DeconvolutionBias::copy_with_new_args(const NodeVector& new_args) const
{
NGRAPH_DEBUG << "DeconvolutionBias::copy_with_new_args" << endl;
check_new_args_count(this, new_args);
return make_shared<DeconvolutionBias>(m_data_batch_shape,
new_args.at(0),
new_args.at(1),
new_args.at(2),
m_window_movement_strides_forward,
m_window_dilation_strides_forward,
m_padding_below_forward,
m_padding_above_forward,
m_data_dilation_strides_forward,
m_with_relu);
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/op.hpp"
namespace ngraph
{
namespace op
{
/// \brief Deconvolution + Bias
class DeconvolutionBias : public Op
{
public:
/// \brief Constructs a batched deconvolution (convolution data-batch backprop) operation with bias.
///
/// \param data_batch_shape The shape of the data batch from forward-prop.
/// \param filters The node producing the filters from forward-prop.
/// \param output_delta The node producing output delta.
/// \param bias The node producing the bias.
/// \param window_movement_strides_forward The window movement strides from forward-prop.
/// \param window_dilation_strides_forward The window dilation strides from forward-prop.
/// \param padding_below_forward The padding-below sizes from forward-prop.
/// \param padding_above_forward The padding-above sizes from forward-prop.
/// \param data_dilation_strides_forward The data dilation strides from forward-prop.
/// \param with_relu Flag indicating whether to apply a fused ReLU.
DeconvolutionBias(const Shape& data_batch_shape,
const std::shared_ptr<Node>& filters,
const std::shared_ptr<Node>& output_delta,
const std::shared_ptr<Node>& bias,
const Strides& window_movement_strides_forward,
const Strides& window_dilation_strides_forward,
const CoordinateDiff& padding_below_forward,
const CoordinateDiff& padding_above_forward,
const Strides& data_dilation_strides_forward,
const bool with_relu);
void validate_and_infer_types() override;
void generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas) override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
/// \return The data batch shape.
const Shape& get_data_batch_shape() const { return m_data_batch_shape; }
/// \return The window movement strides from the forward prop.
const Strides& get_window_movement_strides_forward() const
{
return m_window_movement_strides_forward;
}
/// \return The window dilation strides from the forward prop.
const Strides& get_window_dilation_strides_forward() const
{
return m_window_dilation_strides_forward;
}
/// \return The padding-below sizes (possibly negative) from the forward prop.
const CoordinateDiff& get_padding_below_forward() const
{
return m_padding_below_forward;
}
/// \return The padding-above sizes (possibly negative) from the forward prop.
const CoordinateDiff& get_padding_above_forward() const
{
return m_padding_above_forward;
}
/// \return The input data dilation strides from the forward prop.
const Strides& get_data_dilation_strides_forward() const
{
return m_data_dilation_strides_forward;
}
/// \return The window movement strides for the backward prop.
const Strides& get_window_movement_strides_backward() const
{
return m_window_movement_strides_backward;
}
/// \return The window dilation strides for the backward prop.
const Strides& get_window_dilation_strides_backward() const
{
return m_window_dilation_strides_backward;
}
/// \return The padding-below sizes (possibly negative) for the backward prop.
const CoordinateDiff& get_padding_below_backward() const
{
return m_padding_below_backward;
}
/// \return The padding-above sizes (possibly negative) for the backward prop.
const CoordinateDiff& get_padding_above_backward() const
{
return m_padding_above_backward;
}
/// \return The input data dilation strides for the backward prop.
const Strides& get_data_dilation_strides_backward() const
{
return m_data_dilation_strides_backward;
}
bool with_relu() const { return m_with_relu; }
protected:
Shape m_data_batch_shape;
Strides m_window_movement_strides_forward;
Strides m_window_dilation_strides_forward;
CoordinateDiff m_padding_below_forward;
CoordinateDiff m_padding_above_forward;
Strides m_data_dilation_strides_forward;
Strides m_window_movement_strides_backward;
Strides m_window_dilation_strides_backward;
CoordinateDiff m_padding_below_backward;
CoordinateDiff m_padding_above_backward;
Strides m_data_dilation_strides_backward;
bool m_with_relu;
};
}
}
......@@ -55,6 +55,7 @@
#include "ngraph/runtime/cpu/op/bounded_relu.hpp"
#include "ngraph/runtime/cpu/op/conv_add.hpp"
#include "ngraph/runtime/cpu/op/conv_relu.hpp"
#include "ngraph/runtime/cpu/op/deconv.hpp"
#include "ngraph/runtime/cpu/op/group_conv.hpp"
#include "ngraph/runtime/cpu/op/group_conv_bias.hpp"
#include "ngraph/runtime/cpu/op/leaky_relu.hpp"
......@@ -253,6 +254,40 @@ namespace ngraph
}
}
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::DeconvolutionBias)
{
auto convolution = static_cast<ngraph::op::DeconvolutionBias*>(node);
auto arg0_shape = node->get_input_shape(0);
auto arg1_shape = node->get_input_shape(1);
auto arg2_shape = node->get_input_shape(2);
auto result_shape = node->get_output_shape(0);
auto arg0_rank = arg0_shape.size();
auto arg1_rank = arg1_shape.size();
auto arg2_rank = arg2_shape.size();
bool data_dilated = false;
for (size_t s : convolution->get_data_dilation_strides_forward())
{
data_dilated = data_dilated || (s != 1);
}
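// Assign an MKLDNN kernel only for non-data-dilated 2D/3D spatial deconvolutions
// (4D or 5D weights/delta), a rank-1 bias, and f32 input; otherwise the op is left
// unannotated and the CPU layout pass will reject it.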
if (!data_dilated && ((arg0_rank == 4 && arg1_rank == 4) ||
(arg0_rank == 5 && arg1_rank == 5)) &&
(arg2_rank == 1) && node->get_input_element_type(0) == element::f32)
{
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
convolution->set_op_annotations(op_annotations);
}
else
{
NGRAPH_DEBUG << "DeconvolutionBias : data_dilated = " << data_dilated;
}
}
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::ConvolutionBackpropData)
{
......@@ -952,6 +987,8 @@ static const runtime::cpu::pass::AssignOpMap s_dispatcher{
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::QuantizedDotBias>},
{TI(ngraph::op::GetOutputElement),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::GetOutputElement>},
{TI(ngraph::op::DeconvolutionBias),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::DeconvolutionBias>},
};
bool runtime::cpu::pass::CPUAssignment::run_on_call_graph(
......
......@@ -69,6 +69,7 @@
#include "ngraph/runtime/cpu/op/bounded_relu.hpp"
#include "ngraph/runtime/cpu/op/conv_add.hpp"
#include "ngraph/runtime/cpu/op/conv_relu.hpp"
#include "ngraph/runtime/cpu/op/deconv.hpp"
#include "ngraph/runtime/cpu/op/group_conv.hpp"
#include "ngraph/runtime/cpu/op/group_conv_bias.hpp"
#include "ngraph/runtime/cpu/op/leaky_relu.hpp"
......@@ -1834,6 +1835,158 @@ void ngraph::runtime::cpu::pass::CPUFusion::
this->add_matcher(m);
}
void ngraph::runtime::cpu::pass::CPUFusion::construct_deconvolution_affine_folding()
{
Shape data_batch_shape{100, 512, 4, 4};
Shape filters_shape{64, 512, 4, 4};
auto data_label = std::make_shared<pattern::op::Label>(element::f32, data_batch_shape);
auto filters = std::make_shared<pattern::op::Label>(element::f32, filters_shape);
Shape conv_out_shape{100, 64, 1, 1};
auto out_delta = std::make_shared<pattern::op::Label>(element::f32, conv_out_shape);
auto conv = std::make_shared<op::ConvolutionBackpropData>(data_label->get_shape(),
filters,
out_delta,
Strides{1, 1},
Strides{1, 1},
CoordinateDiff{0, 0},
CoordinateDiff{0, 0},
Strides{1, 1});
auto conv_label = std::make_shared<pattern::op::Label>(conv, nullptr, NodeVector{conv});
auto mean = std::make_shared<pattern::op::Label>(element::f32, Shape{512});
auto var = std::make_shared<pattern::op::Label>(element::f32, Shape{512});
auto gamma = std::make_shared<pattern::op::Label>(element::f32, Shape{512});
auto beta = std::make_shared<pattern::op::Label>(element::f32, Shape{512});
double eps = 0.001;
auto bn = std::make_shared<op::BatchNormInference>(eps, gamma, beta, conv_label, mean, var);
ngraph::pattern::graph_rewrite_callback callback =
[data_label, filters, out_delta, conv_label, mean, var, gamma, beta, eps](
pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for deconv affine folding against node = "
<< m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map();
auto m_bn = std::dynamic_pointer_cast<op::BatchNormInference>(m.get_match_root());
auto conv_m =
std::static_pointer_cast<op::ConvolutionBackpropData>(pattern_map[conv_label]);
if (conv_m->get_users().size() > 1)
{
return false;
}
if (conv_m->get_shape().size() != 4)
{
return false;
}
// new weights = old weights * gamma / sqrt(variance + epsilon)
// new biases = (-mean) * gamma / sqrt(variance + epsilon) + beta
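// Derivation sketch (per output channel; the deconvolution is linear in its weights):
//   bn(deconv(x, W)) = gamma * (deconv(x, W) - mean) / sqrt(var + eps) + beta
//                    = deconv(x, W * gamma / sqrt(var + eps))
//                      + (beta - mean * gamma / sqrt(var + eps)),
// which is exactly the new_weights / new_biases computed below.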
auto bn_eps = op::Constant::create(element::f32, Shape{}, {m_bn->get_eps_value()});
auto var_eps = std::make_shared<op::Add>(
pattern_map[var],
std::make_shared<op::Broadcast>(bn_eps, pattern_map[var]->get_shape(), AxisSet{0}));
auto sqrt_var_eps = std::make_shared<op::Sqrt>(var_eps);
auto weight_scaling = std::make_shared<op::Divide>(pattern_map[gamma], sqrt_var_eps);
auto weight_scaling_bcast = std::make_shared<op::Broadcast>(
weight_scaling, pattern_map[filters]->get_shape(), AxisSet{0, 2, 3});
auto new_weights =
std::make_shared<op::Multiply>(pattern_map[filters], weight_scaling_bcast);
auto mean_gamma = std::make_shared<op::Multiply>(pattern_map[mean], weight_scaling);
auto new_biases = std::make_shared<op::Subtract>(pattern_map[beta], mean_gamma);
// Weights are in i,o,h,w relative to deconvolution. Flip them to o,i,h,w
auto new_weights_reshape =
std::make_shared<op::Reshape>(new_weights,
AxisVector{1, 0, 2, 3},
Shape{new_weights->get_shape().at(1),
new_weights->get_shape().at(0),
new_weights->get_shape().at(2),
new_weights->get_shape().at(3)});
auto g_conv_bprop_data_bias = std::make_shared<op::DeconvolutionBias>(
conv_m->get_data_batch_shape(),
new_weights_reshape,
pattern_map[out_delta],
new_biases,
conv_m->get_window_movement_strides_forward(),
conv_m->get_window_dilation_strides_forward(),
conv_m->get_padding_below_forward(),
conv_m->get_padding_above_forward(),
conv_m->get_data_dilation_strides_forward(),
false);
ngraph::replace_node(m.get_match_root(), g_conv_bprop_data_bias);
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(bn, callback);
this->add_matcher(m);
}
void ngraph::runtime::cpu::pass::CPUFusion::construct_deconvolution_affine_folding_relu()
{
Shape data_batch_shape{100, 512, 4, 4};
Shape filters_shape{512, 64, 4, 4}; // Note: the weights are in o,i,h,w
auto data_label = std::make_shared<pattern::op::Label>(element::f32, data_batch_shape);
auto filters = std::make_shared<pattern::op::Label>(element::f32, filters_shape);
Shape conv_out_shape{100, 64, 1, 1};
auto out_delta = std::make_shared<pattern::op::Label>(element::f32, conv_out_shape);
auto bias = std::make_shared<pattern::op::Label>(element::f32, Shape{512});
auto deconvb = std::make_shared<op::DeconvolutionBias>(data_label->get_shape(),
filters,
out_delta,
bias,
Strides{1, 1},
Strides{1, 1},
CoordinateDiff{0, 0},
CoordinateDiff{0, 0},
Strides{1, 1},
false);
auto deconvb_label =
std::make_shared<pattern::op::Label>(deconvb, nullptr, NodeVector{deconvb});
auto prelu = std::make_shared<op::Relu>(deconvb_label);
ngraph::pattern::graph_rewrite_callback callback =
[data_label, filters, out_delta, deconvb_label](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for deconvbias+relu against node = "
<< m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map();
auto deconvb_m =
std::static_pointer_cast<op::DeconvolutionBias>(pattern_map[deconvb_label]);
if (deconvb_m->get_users().size() > 1)
{
return false;
}
auto g_deconvbias_relu = std::make_shared<op::DeconvolutionBias>(
deconvb_m->get_data_batch_shape(),
deconvb_m->get_argument(0),
deconvb_m->get_argument(1),
deconvb_m->get_argument(2),
deconvb_m->get_window_movement_strides_forward(),
deconvb_m->get_window_dilation_strides_forward(),
deconvb_m->get_padding_below_forward(),
deconvb_m->get_padding_above_forward(),
deconvb_m->get_data_dilation_strides_forward(),
true);
ngraph::replace_node(m.get_match_root(), g_deconvbias_relu);
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(prelu, callback);
this->add_matcher(m);
}
void ngraph::runtime::cpu::pass::CPUFusion::construct_fuse_lstm_recurrent_state()
{
auto src_layer_label = std::make_shared<pattern::op::Label>(element::f32, Shape{30, 100});
......
......@@ -72,6 +72,12 @@ public:
construct_conv_add_relu();
construct_update_slice();
construct_fuse_lstm_recurrent_state();
if (std::getenv("NGRAPH_DECONV_FUSE") != nullptr)
{
// Note: enable by default once deconv performance is better than ConvolutionBackpropData
construct_deconvolution_affine_folding();
construct_deconvolution_affine_folding_relu();
}
}
}
......@@ -101,6 +107,8 @@ private:
void construct_groupconv_batchnorm_global_stats_folding_relu();
void construct_update_slice();
void construct_fuse_lstm_recurrent_state();
void construct_deconvolution_affine_folding();
void construct_deconvolution_affine_folding_relu();
};
class CPU_BACKEND_API ngraph::runtime::cpu::pass::CPUQuantFusion : public ngraph::pass::GraphRewrite
......
......@@ -63,6 +63,7 @@
#include "ngraph/runtime/cpu/op/conv_add.hpp"
#include "ngraph/runtime/cpu/op/conv_relu.hpp"
#include "ngraph/runtime/cpu/op/convert_layout.hpp"
#include "ngraph/runtime/cpu/op/deconv.hpp"
#include "ngraph/runtime/cpu/op/group_conv.hpp"
#include "ngraph/runtime/cpu/op/group_conv_bias.hpp"
#include "ngraph/runtime/cpu/op/leaky_relu.hpp"
......@@ -838,6 +839,86 @@ namespace ngraph
}
}
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::DeconvolutionBias)
{
if (mkldnn_utils::use_mkldnn_kernel(node.get()))
{
auto convolution =
static_cast<const ngraph::op::DeconvolutionBias*>(node.get());
auto data_batch_shape = convolution->get_data_batch_shape();
auto weights_shape = node->get_input_shape(0);
auto delta_shape = node->get_input_shape(1);
auto bias_shape = node->get_input_shape(2);
auto result_shape = node->get_output_shape(0);
auto filter_strides = convolution->get_window_movement_strides_forward();
auto padding_below = convolution->get_padding_below_forward();
auto padding_above = convolution->get_padding_above_forward();
Strides window_dilation_strides_adjusted;
for (size_t s : convolution->get_window_dilation_strides_forward())
{
window_dilation_strides_adjusted.push_back(s - 1);
}
memory::data_type et =
mkldnn_utils::get_mkldnn_data_type(node->get_input_element_type(0));
engine cpu_engine(engine::cpu, 0);
memory::dims mkldnn_arg0_shape(weights_shape.begin(), weights_shape.end());
memory::dims mkldnn_arg1_shape(delta_shape.begin(), delta_shape.end());
memory::dims mkldnn_arg2_shape(bias_shape.begin(), bias_shape.end());
memory::dims mkldnn_result_shape(result_shape.begin(), result_shape.end());
memory::dims mkldnn_filter_strides(filter_strides.begin(),
filter_strides.end());
memory::dims mkldnn_dilated_strides(
window_dilation_strides_adjusted.begin(),
window_dilation_strides_adjusted.end());
memory::dims mkldnn_padding_below(padding_below.begin(),
padding_below.end());
memory::dims mkldnn_padding_above(padding_above.begin(),
padding_above.end());
const memory::desc weights_desc(mkldnn_arg0_shape, et, memory::format::any);
const memory::desc delta_desc(mkldnn_arg1_shape, et, memory::format::any);
const memory::desc bias_desc(mkldnn_arg2_shape, et, memory::format::any);
const memory::desc result_desc(
mkldnn_result_shape, et, memory::format::any);
deconvolution_forward::desc deconv_desc(prop_kind::forward_inference,
algorithm::deconvolution_direct,
delta_desc, //src_desc
weights_desc, //weights_desc
bias_desc, //bias_desc
result_desc, // dst_desc
mkldnn_filter_strides,
mkldnn_dilated_strides,
mkldnn_padding_below,
mkldnn_padding_above,
padding_kind::zero);
deconvolution_forward::primitive_desc deconv_prim_desc(deconv_desc,
cpu_engine);
vector<memory::desc> i_mds;
vector<memory::desc> o_mds;
i_mds.push_back(deconv_prim_desc.weights_primitive_desc()
.desc()); // TODO: determine which format MKLDNN chose here
i_mds.push_back(deconv_prim_desc.src_primitive_desc().desc());
i_mds.push_back(deconv_prim_desc.bias_primitive_desc().desc());
o_mds.push_back(deconv_prim_desc.dst_primitive_desc().desc());
node = insert_input_conversions(external_function, node, i_mds);
set_output_layouts(node, o_mds);
}
else
{
throw ngraph_error("DeconvolutionBias only supported in MKLDNN for now");
}
}
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::ConvolutionBackpropData)
{
......@@ -2307,6 +2388,8 @@ static const runtime::cpu::pass::LayoutOpMap s_dispatcher{
&runtime::cpu::pass::CPULayout::layout<ngraph::op::QuantizedConvolutionBiasSignedAdd>},
{TI(ngraph::op::GroupConvolutionBias),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::GroupConvolutionBias>},
{TI(ngraph::op::DeconvolutionBias),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::DeconvolutionBias>},
{TI(ngraph::op::QuantizedConcat),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::QuantizedConcat>},
{TI(ngraph::op::QuantizedDotBias),
......
......@@ -61,6 +61,7 @@
#include "ngraph/runtime/cpu/op/conv_add.hpp"
#include "ngraph/runtime/cpu/op/conv_relu.hpp"
#include "ngraph/runtime/cpu/op/convert_layout.hpp"
#include "ngraph/runtime/cpu/op/deconv.hpp"
#include "ngraph/runtime/cpu/op/group_conv.hpp"
#include "ngraph/runtime/cpu/op/group_conv_bias.hpp"
#include "ngraph/runtime/cpu/op/leaky_relu.hpp"
......@@ -1108,6 +1109,94 @@ TEST(cpu_fusion, conv_add)
EXPECT_TRUE(test::all_close(cpu_results.at(0), int_results.at(0)));
}
shared_ptr<Function> gen_deconv(const bool add_goe)
{
Shape conv_out_shape{100, 64, 1, 1};
auto out_delta = std::make_shared<op::Parameter>(element::f32, conv_out_shape);
Shape filters_shape{64, 512, 4, 4};
Shape bias_shape{512};
Shape data_batch_shape{100, 512, 4, 4};
auto data_label = std::make_shared<pattern::op::Label>(element::f32, data_batch_shape);
auto filters = std::make_shared<op::Parameter>(element::f32, filters_shape);
auto conv = std::make_shared<op::ConvolutionBackpropData>(data_label->get_shape(),
filters,
out_delta,
Strides{1, 1},
Strides{1, 1},
CoordinateDiff{0, 0},
CoordinateDiff{0, 0},
Strides{1, 1});
auto conv_label = std::make_shared<pattern::op::Label>(conv, nullptr, NodeVector{conv});
auto mean = std::make_shared<op::Parameter>(element::f32, bias_shape);
auto var = std::make_shared<op::Parameter>(element::f32, bias_shape);
auto gamma = std::make_shared<op::Parameter>(element::f32, bias_shape);
auto beta = std::make_shared<op::Parameter>(element::f32, bias_shape);
double eps = 0.001;
auto goe_bn = std::make_shared<op::GetOutputElement>(conv, 0);
// Adding a GOE will stop fusion, since the patterns won't expect to see this op
auto bn = add_goe
? std::make_shared<op::BatchNormInference>(goe_bn, gamma, beta, mean, var, eps)
: std::make_shared<op::BatchNormInference>(conv, gamma, beta, mean, var, eps);
return make_shared<Function>(NodeVector{bn},
ParameterVector{filters, out_delta, gamma, beta, mean, var});
}
TEST(cpu_fusion, fuse_deconv)
{
bool use_deconv_fuse = (getenv("NGRAPH_DECONV_FUSE") != nullptr);
if (!use_deconv_fuse)
{
set_environment("NGRAPH_DECONV_FUSE", "1", 1);
}
auto fuse_func = gen_deconv(false);
auto nofuse_func = gen_deconv(true);
{
pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>();
pass_manager.run_passes(fuse_func);
ASSERT_EQ(count_ops_of_type<op::DeconvolutionBias>(fuse_func), 1);
}
{
pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>();
pass_manager.run_passes(nofuse_func);
ASSERT_EQ(count_ops_of_type<op::DeconvolutionBias>(nofuse_func), 0);
ASSERT_EQ(count_ops_of_type<op::Relu>(nofuse_func), 0);
}
// Test values
{
test::Uniform<float> rng(1.0f, 100.0f);
vector<vector<float>> args;
for (shared_ptr<op::Parameter> param : fuse_func->get_parameters())
{
auto name = param->get_name();
vector<float> tensor_val(shape_size(param->get_shape()));
rng.initialize(tensor_val);
args.push_back(tensor_val);
}
auto nofuse_results = execute(nofuse_func, args, "CPU");
auto fuse_results = execute(fuse_func, args, "CPU");
EXPECT_TRUE(test::all_close(fuse_results.at(0), nofuse_results.at(0)));
}
if (!use_deconv_fuse)
{
unset_environment("NGRAPH_DECONV_FUSE");
}
}
shared_ptr<Function> gen_groupconv_batchnorm(const bool add_goe,
const bool with_relu,
const Shape shape_in,
......