Unverified Commit 8c63cd53 authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Remove _backwards from conv backprop ops. (#2593)

* Remove _backwards from conv backprop ops.

FLip filter in backprop_data instead of general_convolution

* Review comment
parent a0e7694b
......@@ -270,13 +270,6 @@ void op::ConvolutionBackpropData::validate_and_infer_types()
//
// To _validate_, we simply need to check/infer the output shape of the forward convolution,
// then check to make sure that the incoming delta has the same shape as the forward output.
//
// We will also compute and store the various parameters in the "backward" column above, since
// some backends need them. (TODO(amprocte): Is it just because of the way the reference works
// that this stuff is needed? If so, we can probably get rid of it and have conv_backprop
// reference kernels that do the calculations of the backward parameters internally, or supply
// utility functions to do it.)
const PartialShape& filters_shape = get_input_partial_shape(0);
element::Type filters_et = get_input_element_type(0);
const PartialShape& delta_shape = get_input_partial_shape(1);
......@@ -307,40 +300,6 @@ void op::ConvolutionBackpropData::validate_and_infer_types()
").");
set_output_type(0, forward_result_et, m_data_batch_shape);
//
// Compute parameters needed for backprop-as-convolution.
//
// TODO(amprocte): Remove these fields, compute where needed.
//
if (delta_shape.is_static() && filters_shape.is_static())
{
size_t spatial_dim_count = static_cast<size_t>(delta_shape.rank()) - 2;
m_window_movement_strides_backward = m_data_dilation_strides_forward;
m_window_dilation_strides_backward = m_window_dilation_strides_forward;
m_data_dilation_strides_backward = m_window_movement_strides_forward;
m_padding_below_backward.resize(spatial_dim_count);
m_padding_above_backward.resize(spatial_dim_count);
for (size_t i = 0; i < spatial_dim_count; i++)
{
m_padding_below_backward[i] = (static_cast<ptrdiff_t>(filters_shape[i + 2]) - 1) *
m_window_dilation_strides_forward[i] -
m_padding_below_forward[i];
m_padding_above_backward[i] =
(static_cast<ptrdiff_t>(filters_shape[i + 2]) - 1) *
m_window_dilation_strides_forward[i] +
((m_padding_below_forward[i] +
(m_data_batch_shape[i + 2] - 1) * m_data_dilation_strides_forward[i] +
m_padding_above_forward[i] -
(static_cast<ptrdiff_t>(filters_shape[i + 2]) - 1) *
m_window_dilation_strides_forward[i]) %
m_window_movement_strides_forward[i]) -
m_padding_above_forward[i];
}
}
}
void op::ConvolutionBackpropData::generate_adjoints(autodiff::Adjoints& adjoints,
......@@ -364,23 +323,35 @@ void op::ConvolutionBackpropData::generate_adjoints(autodiff::Adjoints& adjoints
adjoints.add_delta(x, data_conv);
Strides window_movement_strides;
Strides window_dilation_strides;
Strides window_movement_strides = m_window_dilation_strides_forward;
Strides window_dilation_strides = m_data_dilation_strides_forward;
Strides data_dilation_strides = m_window_movement_strides_forward;
CoordinateDiff padding_below;
CoordinateDiff padding_above;
Strides data_dilation_strides;
const Shape& filters_shape = get_input_shape(0);
for (size_t i = 0; i < f_shape.size() - 2; i++)
{
window_movement_strides.push_back(m_window_dilation_strides_backward[i]);
window_dilation_strides.push_back(m_window_movement_strides_backward[i]);
padding_below.push_back(m_padding_below_backward[i]);
padding_above.push_back(m_padding_above_backward[i] -
(m_padding_below_backward[i] +
(x_shape[i + 2] - 1) * m_data_dilation_strides_backward[i] +
m_padding_above_backward[i] -
(f_shape[i + 2] - 1) * m_window_dilation_strides_backward[i]) %
m_window_movement_strides_backward[i]);
data_dilation_strides.push_back(m_data_dilation_strides_backward[i]);
ptrdiff_t padding_below_backward =
(static_cast<ptrdiff_t>(filters_shape[i + 2]) - 1) * window_dilation_strides[i] -
m_padding_below_forward[i];
padding_below.push_back(padding_below_backward);
ptrdiff_t padding_above_backward =
(static_cast<ptrdiff_t>(filters_shape[i + 2]) - 1) *
m_window_dilation_strides_forward[i] +
((m_padding_below_forward[i] +
((m_data_batch_shape[i + 2]) - 1) * m_data_dilation_strides_forward[i] +
m_padding_above_forward[i] -
(static_cast<ptrdiff_t>(filters_shape[i + 2]) - 1) *
m_window_dilation_strides_forward[i]) %
m_window_movement_strides_forward[i]) -
m_padding_above_forward[i];
padding_above.push_back(
padding_above_backward -
(padding_below_backward + (x_shape[i + 2] - 1) * m_window_movement_strides_forward[i] +
padding_above_backward - (f_shape[i + 2] - 1) * m_window_dilation_strides_forward[i]) %
m_data_dilation_strides_forward[i]);
}
auto swap_NC = [](const shared_ptr<Node> n) {
......@@ -427,6 +398,52 @@ shared_ptr<Node> op::ConvolutionBackpropData::copy_with_new_args(const NodeVecto
m_data_dilation_strides_forward);
}
CoordinateDiff op::ConvolutionBackpropData::compute_backward_delta_out_pad_below() const
{
auto& in_shape = get_data_batch_shape();
auto& filter_dilation = get_window_dilation_strides_forward();
auto& filter_shape = get_input_shape(0);
auto& in_pad_below = get_padding_below_forward();
size_t spatial_dim_count = static_cast<size_t>(in_shape.size()) - 2;
CoordinateDiff backward_delta_out_pad_below;
backward_delta_out_pad_below.resize(spatial_dim_count);
for (size_t i = 0; i < spatial_dim_count; i++)
{
backward_delta_out_pad_below[i] =
(static_cast<ptrdiff_t>(filter_shape[i + 2]) - 1) * filter_dilation[i] -
in_pad_below[i];
}
return backward_delta_out_pad_below;
}
CoordinateDiff op::ConvolutionBackpropData::compute_backward_delta_out_pad_above() const
{
auto& in_shape = get_data_batch_shape();
auto& filter_dilation = get_window_dilation_strides_forward();
auto& filter_shape = get_input_shape(0);
auto& in_pad_below = get_padding_below_forward();
auto& in_pad_above = get_padding_above_forward();
auto& in_dilation = get_data_dilation_strides_forward();
auto& stride = get_window_movement_strides_forward();
size_t spatial_dim_count = static_cast<size_t>(in_shape.size()) - 2;
CoordinateDiff backward_delta_out_pad_above;
backward_delta_out_pad_above.resize(spatial_dim_count);
for (size_t i = 0; i < spatial_dim_count; i++)
{
backward_delta_out_pad_above[i] =
(static_cast<ptrdiff_t>(filter_shape[i + 2]) - 1) * filter_dilation[i] +
((in_pad_below[i] + ((in_shape[i + 2]) - 1) * in_dilation[i] + in_pad_above[i] -
(static_cast<ptrdiff_t>(filter_shape[i + 2]) - 1) * filter_dilation[i]) %
stride[i]) -
in_pad_above[i];
}
return backward_delta_out_pad_above;
}
op::ConvolutionBackpropFilters::ConvolutionBackpropFilters(
const shared_ptr<Node>& data_batch,
const Shape& filters_shape,
......@@ -509,35 +526,6 @@ void op::ConvolutionBackpropFilters::validate_and_infer_types()
").");
set_output_type(0, forward_result_et, m_filters_shape);
//
// Compute parameters needed for backprop-as-convolution.
//
// TODO(amprocte): Remove these fields, compute where needed.
//
if (delta_shape.is_static() && data_batch_shape.is_static())
{
size_t spatial_dim_count = static_cast<size_t>(delta_shape.rank()) - 2;
m_window_movement_strides_backward = m_window_dilation_strides_forward;
m_window_dilation_strides_backward = m_window_movement_strides_forward;
m_padding_below_backward = m_padding_below_forward;
m_data_dilation_strides_backward = m_data_dilation_strides_forward;
m_padding_above_backward.resize(spatial_dim_count);
for (size_t i = 0; i < spatial_dim_count; i++)
{
m_padding_above_backward[i] =
m_padding_above_forward[i] -
(m_padding_below_forward[i] +
(static_cast<ptrdiff_t>(data_batch_shape[i + 2]) - 1) *
m_data_dilation_strides_forward[i] +
m_padding_above_forward[i] -
(m_filters_shape[i + 2] - 1) * m_window_dilation_strides_forward[i]) %
m_window_movement_strides_forward[i];
}
}
}
shared_ptr<Node>
......@@ -554,6 +542,31 @@ shared_ptr<Node>
m_data_dilation_strides_forward);
}
CoordinateDiff op::ConvolutionBackpropFilters::compute_backward_in_pad_above() const
{
const auto& in_shape = get_input_shape(0);
const auto& out_shape = get_input_shape(1);
const auto& filter_shape = get_filters_shape();
const auto& in_pad_above = get_padding_above_forward();
const auto& in_pad_below = get_padding_below_forward();
const auto& in_dilation = get_data_dilation_strides_forward();
const auto& filter_dilation = get_window_dilation_strides_forward();
const auto& stride = get_window_movement_strides_forward();
size_t spatial_dim_count = static_cast<size_t>(out_shape.size()) - 2;
CoordinateDiff backward_in_pad_above;
backward_in_pad_above.resize(spatial_dim_count);
for (size_t i = 0; i < spatial_dim_count; i++)
{
backward_in_pad_above[i] =
in_pad_above[i] -
(in_pad_below[i] + (static_cast<ptrdiff_t>(in_shape[i + 2]) - 1) * in_dilation[i] +
in_pad_above[i] - (filter_shape[i + 2] - 1) * filter_dilation[i]) %
stride[i];
}
return backward_in_pad_above;
}
//
// This is a legacy function, retained because the CPU backend uses it for now.
// TODO(amprocte): Update CPU backend to use the new stuff in validation_util.hpp, and remove this
......
......@@ -220,31 +220,9 @@ namespace ngraph
return m_data_dilation_strides_forward;
}
/// \return The window movement strides for the backward prop.
const Strides& get_window_movement_strides_backward() const
{
return m_window_movement_strides_backward;
}
/// \return The window dilation strides for the backward prop.
const Strides& get_window_dilation_strides_backward() const
{
return m_window_dilation_strides_backward;
}
/// \return The padding-below sizes (possibly negative) for the backward prop.
const CoordinateDiff& get_padding_below_backward() const
{
return m_padding_below_backward;
}
/// \return The padding-above sizes (possibly negative) for the backward prop.
const CoordinateDiff& get_padding_above_backward() const
{
return m_padding_above_backward;
}
/// \return The input data dilation strides for the backward prop.
const Strides& get_data_dilation_strides_backward() const
{
return m_data_dilation_strides_backward;
}
// Compute the pad_above values to be used if in a convolution
CoordinateDiff compute_backward_delta_out_pad_above() const;
CoordinateDiff compute_backward_delta_out_pad_below() const;
protected:
Shape m_data_batch_shape;
......@@ -253,12 +231,6 @@ namespace ngraph
CoordinateDiff m_padding_below_forward;
CoordinateDiff m_padding_above_forward;
Strides m_data_dilation_strides_forward;
Strides m_window_movement_strides_backward;
Strides m_window_dilation_strides_backward;
CoordinateDiff m_padding_below_backward;
CoordinateDiff m_padding_above_backward;
Strides m_data_dilation_strides_backward;
};
/// \brief Filters backprop for batched convolution operation.
......@@ -317,31 +289,8 @@ namespace ngraph
return m_data_dilation_strides_forward;
}
/// \return The window movement strides for the backward prop.
const Strides& get_window_movement_strides_backward() const
{
return m_window_movement_strides_backward;
}
/// \return The window dilation strides for the backward prop.
const Strides& get_window_dilation_strides_backward() const
{
return m_window_dilation_strides_backward;
}
/// \return The padding-below sizes (possibly negative) for the backward prop.
const CoordinateDiff& get_padding_below_backward() const
{
return m_padding_below_backward;
}
/// \return The padding-above sizes (possibly negative) for the backward prop.
const CoordinateDiff& get_padding_above_backward() const
{
return m_padding_above_backward;
}
/// \return The data dilation strides for the backward prop.
const Strides& get_data_dilation_strides_backward() const
{
return m_data_dilation_strides_backward;
}
// Compute the pad_above value to be used if in a convolution
CoordinateDiff compute_backward_in_pad_above() const;
protected:
Shape m_filters_shape;
......@@ -350,12 +299,6 @@ namespace ngraph
CoordinateDiff m_padding_below_forward;
CoordinateDiff m_padding_above_forward;
Strides m_data_dilation_strides_forward;
Strides m_window_movement_strides_backward;
Strides m_window_dilation_strides_backward;
CoordinateDiff m_padding_below_backward;
CoordinateDiff m_padding_above_backward;
Strides m_data_dilation_strides_backward;
};
namespace util
......
......@@ -14,9 +14,8 @@
// limitations under the License.
//*****************************************************************************
#include "ngraph/op/convolution.hpp"
#include "ngraph/runtime/cpu/cpu_builder.hpp"
#include "ngraph/runtime/cpu/kernel/convolution.hpp"
#include "ngraph/runtime/cpu/cpu_builder.hpp"
#include "ngraph/runtime/cpu/mkldnn_invoke.hpp"
#include "ngraph/runtime/cpu/mkldnn_utils.hpp"
#include "ngraph/runtime/cpu/op/conv_add.hpp"
......@@ -107,14 +106,7 @@ namespace ngraph
window_dilation_strides,
padding_below,
padding_above,
data_dilation_strides,
0,
1,
1,
0,
0,
1,
false);
data_dilation_strides);
};
functors.emplace_back(functor);
}
......@@ -313,7 +305,6 @@ namespace ngraph
auto arg0_shape = args[0].get_shape();
auto arg1_shape = args[1].get_shape();
auto result_shape = out[0].get_shape();
auto& arg0_tensor = external_function->get_tensor_data(args[0].get_name());
auto& arg1_tensor = external_function->get_tensor_data(args[1].get_name());
......@@ -347,48 +338,47 @@ namespace ngraph
}
else
{
std::function<decltype(runtime::cpu::kernel::convolution<float>)> kernel;
std::function<decltype(runtime::cpu::kernel::convolution_backprop_in<float>)>
kernel;
SELECT_KERNEL(
kernel, out[0].get_element_type(), runtime::cpu::kernel::convolution);
auto window_movement_strides =
convolution->get_window_movement_strides_backward();
SELECT_KERNEL(kernel,
out[0].get_element_type(),
runtime::cpu::kernel::convolution_backprop_in);
auto& in_shape = convolution->get_data_batch_shape();
auto data_dilation_strides = convolution->get_data_dilation_strides_forward();
auto window_dilation_strides =
convolution->get_window_dilation_strides_backward();
auto padding_below = convolution->get_padding_below_backward();
auto padding_above = convolution->get_padding_above_backward();
auto data_dilation_strides = convolution->get_data_dilation_strides_backward();
convolution->get_window_dilation_strides_forward();
auto padding_below = convolution->get_padding_below_forward();
auto padding_above = convolution->get_padding_above_forward();
auto window_movement_strides =
convolution->get_window_movement_strides_forward();
auto backward_delta_out_pad_below =
convolution->compute_backward_delta_out_pad_below();
auto backward_delta_out_pad_above =
convolution->compute_backward_delta_out_pad_above();
auto functor = [&,
kernel,
arg0_shape,
arg1_shape,
result_shape,
window_movement_strides,
in_shape,
data_dilation_strides,
window_dilation_strides,
padding_below,
padding_above,
data_dilation_strides](CPURuntimeContext* ctx,
backward_delta_out_pad_below,
backward_delta_out_pad_above,
window_movement_strides](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
kernel(arg1_tensor,
arg0_tensor,
out_tensor,
arg1_shape,
arg0_shape,
result_shape,
window_movement_strides,
window_dilation_strides,
padding_below,
padding_above,
in_shape,
data_dilation_strides,
0,
1,
0,
1,
0,
1,
true);
window_dilation_strides,
backward_delta_out_pad_below,
backward_delta_out_pad_above,
window_movement_strides);
};
functors.emplace_back(functor);
}
......@@ -403,7 +393,6 @@ namespace ngraph
auto arg0_shape = args[0].get_shape();
auto arg1_shape = args[1].get_shape();
auto result_shape = out[0].get_shape();
auto& arg0_tensor = external_function->get_tensor_data(args[0].get_name());
auto& arg1_tensor = external_function->get_tensor_data(args[1].get_name());
......@@ -437,28 +426,34 @@ namespace ngraph
}
else
{
std::function<decltype(runtime::cpu::kernel::convolution<float>)> kernel;
std::function<decltype(
runtime::cpu::kernel::convolution_backprop_filter<float>)>
kernel;
SELECT_KERNEL(
kernel, out[0].get_element_type(), runtime::cpu::kernel::convolution);
SELECT_KERNEL(kernel,
out[0].get_element_type(),
runtime::cpu::kernel::convolution_backprop_filter);
auto window_movement_strides =
convolution->get_window_movement_strides_backward();
auto& filters_shape = convolution->get_filters_shape();
auto window_dilation_strides =
convolution->get_window_dilation_strides_backward();
auto padding_below = convolution->get_padding_below_backward();
auto padding_above = convolution->get_padding_above_backward();
auto data_dilation_strides = convolution->get_data_dilation_strides_backward();
convolution->get_window_dilation_strides_forward();
auto window_movement_strides =
convolution->get_window_movement_strides_forward();
auto padding_below = convolution->get_padding_below_forward();
auto padding_above = convolution->get_padding_above_forward();
auto data_dilation_strides = convolution->get_data_dilation_strides_forward();
CoordinateDiff backward_in_pad_above =
convolution->compute_backward_in_pad_above();
auto functor = [&,
kernel,
arg0_shape,
arg1_shape,
result_shape,
window_movement_strides,
filters_shape,
window_dilation_strides,
window_movement_strides,
padding_below,
padding_above,
backward_in_pad_above,
data_dilation_strides](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
kernel(arg0_tensor,
......@@ -466,19 +461,12 @@ namespace ngraph
out_tensor,
arg0_shape,
arg1_shape,
result_shape,
filters_shape,
window_movement_strides,
window_dilation_strides,
padding_below,
padding_above,
data_dilation_strides,
1,
0,
0,
1,
1,
0,
false);
backward_in_pad_above,
data_dilation_strides);
};
functors.emplace_back(functor);
}
......@@ -624,6 +612,6 @@ namespace ngraph
REGISTER_OP_BUILDER(GroupConvolution);
REGISTER_OP_BUILDER(ConvolutionAdd);
REGISTER_OP_BUILDER(GroupConvolutionBias);
}
}
}
} // namespace cpu
} // namespace runtime
} // namespace ngraph
......@@ -2288,8 +2288,7 @@ namespace ngraph
writer << " {" << join(convolution->get_padding_above())
<< "},\n";
writer << " {"
<< join(convolution->get_data_dilation_strides()) << "},\n";
writer << " 0, 1, 1, 0, 0, 1, false);\n";
<< join(convolution->get_data_dilation_strides()) << "});\n";
}
}
......@@ -2323,24 +2322,25 @@ namespace ngraph
}
else
{
writer << "reference::convolution<" << out[0].get_type() << ">("
<< args[0].get_name() << ",\n";
writer << "reference::convolution_backprop_filters<" << out[0].get_type()
<< ">(" << args[0].get_name() << ",\n";
writer << " " << args[1].get_name() << ",\n";
writer << " " << out[0].get_name() << ",\n";
writer << " {" << join(convolution->get_filters_shape())
<< "},\n";
writer << " {" << join(arg0_shape) << "},\n";
writer << " {" << join(arg1_shape) << "},\n";
writer << " {" << join(result_shape) << "},\n";
writer << " {"
<< join(convolution->get_window_movement_strides_backward()) << "},\n";
<< join(convolution->get_window_dilation_strides_forward()) << "},\n";
writer << " {"
<< join(convolution->get_window_dilation_strides_backward()) << "},\n";
<< join(convolution->get_window_movement_strides_forward()) << "},\n";
writer << " {"
<< join(convolution->get_padding_below_backward()) << "},\n";
<< join(convolution->get_padding_below_forward()) << "},\n";
writer << " {"
<< join(convolution->get_padding_above_backward()) << "},\n";
<< join(convolution->get_padding_above_forward()) << "},\n";
writer << " {"
<< join(convolution->get_data_dilation_strides_backward()) << "},\n";
writer << " 1, 0, 0, 1, 1, 0, false);\n";
<< join(convolution->get_data_dilation_strides_forward()) << "});\n";
}
}
......@@ -2375,24 +2375,25 @@ namespace ngraph
else
{
// Note that args[1] and args[0] are switched here from the usual order.
writer << "reference::convolution<" << out[0].get_type() << ">("
writer << "reference::convolution_backprop_data<" << out[0].get_type() << ">("
<< args[1].get_name() << ",\n";
writer << " " << args[0].get_name() << ",\n";
writer << " " << out[0].get_name() << ",\n";
writer << " {"
<< join(convolution->get_data_batch_shape()) << "},\n";
writer << " {" << join(arg1_shape) << "},\n";
writer << " {" << join(arg0_shape) << "},\n";
writer << " {" << join(result_shape) << "},\n";
writer << " {"
<< join(convolution->get_window_movement_strides_backward()) << "},\n";
<< join(convolution->get_data_dilation_strides_forward()) << "},\n";
writer << " {"
<< join(convolution->get_window_dilation_strides_backward()) << "},\n";
<< join(convolution->get_window_dilation_strides_forward()) << "},\n";
writer << " {"
<< join(convolution->get_padding_below_backward()) << "},\n";
<< join(convolution->get_padding_below_forward()) << "},\n";
writer << " {"
<< join(convolution->get_padding_above_backward()) << "},\n";
<< join(convolution->get_padding_above_forward()) << "},\n";
writer << " {"
<< join(convolution->get_data_dilation_strides_backward()) << "},\n";
writer << " 0, 1, 0, 1, 0, 1, true);\n";
<< join(convolution->get_window_movement_strides_forward()) << "});\n";
}
}
......
......@@ -38,14 +38,7 @@ namespace ngraph
const Strides& window_dilation_strides,
const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above,
const Strides& data_dilation_strides,
size_t batch_axis_data,
size_t input_channel_axis_data,
size_t input_channel_axis_filters,
size_t output_channel_axis_filters,
size_t batch_axis_result,
size_t output_channel_axis_result,
bool rotate_filter)
const Strides& data_dilation_strides)
{
reference::convolution<ElementType>(static_cast<const ElementType*>(input0),
static_cast<const ElementType*>(input1),
......@@ -57,16 +50,63 @@ namespace ngraph
window_dilation_strides,
padding_below,
padding_above,
data_dilation_strides,
batch_axis_data,
input_channel_axis_data,
input_channel_axis_filters,
output_channel_axis_filters,
batch_axis_result,
output_channel_axis_result,
rotate_filter);
}
data_dilation_strides);
}
template <typename ElementType>
void convolution_backprop_filter(void* input0,
void* input1,
void* output,
const Shape& arg0_shape,
const Shape& arg1_shape,
const Shape& filter_shape,
const Strides& window_dilation_strides,
const Strides& window_movement_strides,
const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above,
const Strides& data_dilation_strides)
{
reference::convolution_backprop_filter<ElementType>(
static_cast<const ElementType*>(input0),
static_cast<const ElementType*>(input1),
static_cast<ElementType*>(output),
arg0_shape,
arg1_shape,
filter_shape,
window_dilation_strides,
window_movement_strides,
padding_below,
padding_above,
data_dilation_strides);
}
template <typename ElementType>
void convolution_backprop_in(void* input0,
void* input1,
void* output,
const Shape& arg0_shape,
const Shape& arg1_shape,
const Shape& in_shape,
const Strides& window_movement_strides,
const Strides& window_dilation_strides,
const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above,
const Strides& data_dilation_strides)
{
reference::convolution_backprop_in<ElementType>(
static_cast<const ElementType*>(input0),
static_cast<const ElementType*>(input1),
static_cast<ElementType*>(output),
arg0_shape,
arg1_shape,
in_shape,
window_movement_strides,
window_dilation_strides,
padding_below,
padding_above,
data_dilation_strides);
}
}
} // namespace kernel
} // namespace cpu
} // namespace runtime
} // namespace ngraph
......@@ -1629,11 +1629,11 @@ shared_ptr<runtime::Executable>
const shared_ptr<op::ConvolutionBackpropFilters> conv_op =
static_pointer_cast<op::ConvolutionBackpropFilters>(op);
const Strides& win_stride = conv_op->get_window_movement_strides_backward();
const CoordinateDiff& pad_below = conv_op->get_padding_below_backward();
const CoordinateDiff& pad_above = conv_op->get_padding_above_backward();
const Strides& win_dilation = conv_op->get_window_dilation_strides_backward();
const Strides& data_dilation = conv_op->get_data_dilation_strides_backward();
const Strides& win_stride = conv_op->get_window_dilation_strides_forward();
const CoordinateDiff& pad_below = conv_op->get_padding_below_forward();
CoordinateDiff pad_above = conv_op->compute_backward_in_pad_above();
const Strides& win_dilation = conv_op->get_window_movement_strides_forward();
const Strides& data_dilation = conv_op->get_data_dilation_strides_forward();
if ((win_stride.size() > 2) || (win_stride.at(0) != 1) || (win_stride.at(1) != 1) ||
(pad_below.size() > 2) || (pad_above.size() > 2) || (data_dilation.size() > 2) ||
......@@ -1648,10 +1648,10 @@ shared_ptr<runtime::Executable>
get_output_name(op),
get_output_shape(op),
get_output_type(op),
conv_op->get_padding_below_backward(),
conv_op->get_window_movement_strides_backward(),
conv_op->get_window_dilation_strides_backward(),
conv_op->get_data_dilation_strides_backward(),
conv_op->get_padding_below_forward(),
win_stride,
win_dilation,
data_dilation,
1,
0,
0,
......@@ -1728,11 +1728,11 @@ shared_ptr<runtime::Executable>
const shared_ptr<op::ConvolutionBackpropData> conv_op =
static_pointer_cast<op::ConvolutionBackpropData>(op);
const Strides& win_stride = conv_op->get_window_movement_strides_backward();
const CoordinateDiff& pad_below = conv_op->get_padding_below_backward();
const CoordinateDiff& pad_above = conv_op->get_padding_above_backward();
const Strides& win_dilation = conv_op->get_window_dilation_strides_backward();
const Strides& data_dilation = conv_op->get_data_dilation_strides_backward();
const Strides& win_stride = conv_op->get_data_dilation_strides_forward();
CoordinateDiff pad_below = conv_op->compute_backward_delta_out_pad_below();
CoordinateDiff pad_above = conv_op->compute_backward_delta_out_pad_above();
const Strides& win_dilation = conv_op->get_window_dilation_strides_forward();
const Strides& data_dilation = conv_op->get_window_movement_strides_forward();
if ((win_stride.size() > 2) || (win_stride.at(0) != 1) || (win_stride.at(1) != 1) ||
(pad_below.size() > 2) || (pad_above.size() > 2) || (data_dilation.size() > 2) ||
......@@ -1749,10 +1749,10 @@ shared_ptr<runtime::Executable>
get_output_name(op),
get_output_shape(op),
get_output_type(op),
conv_op->get_padding_below_backward(),
conv_op->get_window_movement_strides_backward(),
conv_op->get_window_dilation_strides_backward(),
conv_op->get_data_dilation_strides_backward(),
pad_below,
win_stride,
win_dilation,
data_dilation,
0,
1,
1,
......
......@@ -346,17 +346,7 @@ void print_node_parameters(ostringstream& writer, const shared_ptr<Node>& node)
<< print_table_row_dims("pad_above_forward",
conv_op_filt->get_padding_above_forward())
<< print_table_row_dims("pad_below_forward",
conv_op_filt->get_padding_below_forward())
<< print_table_row_dims("window_movement_strides_backward",
conv_op_filt->get_window_movement_strides_backward())
<< print_table_row_dims("window_dilation_strides_backward",
conv_op_filt->get_window_dilation_strides_backward())
<< print_table_row_dims("data_dilation_strides_backward",
conv_op_filt->get_data_dilation_strides_backward())
<< print_table_row_dims("padding_above_backward",
conv_op_filt->get_padding_above_backward())
<< print_table_row_dims("padding_below_backward",
conv_op_filt->get_padding_below_backward());
conv_op_filt->get_padding_below_forward());
break;
}
case OP_TYPEID::ConvolutionBackpropData:
......@@ -374,17 +364,7 @@ void print_node_parameters(ostringstream& writer, const shared_ptr<Node>& node)
<< print_table_row_dims("pad_above_forward",
conv_op_data->get_padding_above_forward())
<< print_table_row_dims("pad_below_forward",
conv_op_data->get_padding_below_forward())
<< print_table_row_dims("window_movement_strides_backward",
conv_op_data->get_window_movement_strides_backward())
<< print_table_row_dims("window_dilation_strides_backward",
conv_op_data->get_window_dilation_strides_backward())
<< print_table_row_dims("data_dilation_strides_backward",
conv_op_data->get_data_dilation_strides_backward())
<< print_table_row_dims("padding_above_backward",
conv_op_data->get_padding_above_backward())
<< print_table_row_dims("padding_below_backward",
conv_op_data->get_padding_below_backward());
conv_op_data->get_padding_below_forward());
break;
}
case OP_TYPEID::UNDEFINED_OP:
......
......@@ -146,9 +146,9 @@ namespace ngraph
{
class INTBackend;
class INTExecutable;
}
}
}
} // namespace interpreter
} // namespace runtime
} // namespace ngraph
class ngraph::runtime::interpreter::INTExecutable : public Executable
{
......@@ -551,38 +551,26 @@ private:
c->get_window_dilation_strides(),
c->get_padding_below(),
c->get_padding_above(),
c->get_data_dilation_strides(),
0,
1,
1,
0,
0,
1,
false);
c->get_data_dilation_strides());
break;
}
case OP_TYPEID::ConvolutionBackpropFilters:
{
const op::ConvolutionBackpropFilters* c =
static_cast<const op::ConvolutionBackpropFilters*>(&node);
reference::convolution<T>(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(0),
node.get_input_shape(1),
node.get_output_shape(0),
c->get_window_movement_strides_backward(),
c->get_window_dilation_strides_backward(),
c->get_padding_below_backward(),
c->get_padding_above_backward(),
c->get_data_dilation_strides_backward(),
1,
0,
0,
1,
1,
0,
false);
reference::convolution_backprop_filter<T>(
args[0]->get_data_ptr<const T>(), // input
args[1]->get_data_ptr<const T>(), // delta_convolution_output
out[0]->get_data_ptr<T>(), // delta_filter
c->get_input_shape(0), // input_shape
c->get_input_shape(1), // convolution_output_shape
c->get_filters_shape(), // filter_shape
c->get_window_dilation_strides_forward(),
c->get_window_movement_strides_forward(),
c->get_padding_below_forward(),
c->compute_backward_in_pad_above(),
c->get_data_dilation_strides_forward());
break;
}
case OP_TYPEID::ConvolutionBackpropData:
......@@ -590,24 +578,17 @@ private:
// Note that args[1] and args[0] are switched here from the usual order.
const op::ConvolutionBackpropData* c =
static_cast<const op::ConvolutionBackpropData*>(&node);
reference::convolution<T>(args[1]->get_data_ptr<const T>(),
reference::convolution_backprop_in<T>(args[1]->get_data_ptr<const T>(),
args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(1),
node.get_input_shape(0),
node.get_output_shape(0),
c->get_window_movement_strides_backward(),
c->get_window_dilation_strides_backward(),
c->get_padding_below_backward(),
c->get_padding_above_backward(),
c->get_data_dilation_strides_backward(),
0,
1,
0,
1,
0,
1,
true);
c->get_input_shape(1),
c->get_input_shape(0),
c->get_data_batch_shape(),
c->get_data_dilation_strides_forward(),
c->get_window_dilation_strides_forward(),
c->compute_backward_delta_out_pad_below(),
c->compute_backward_delta_out_pad_above(),
c->get_window_movement_strides_forward());
break;
}
case OP_TYPEID::Cos:
......
......@@ -20,6 +20,7 @@
#include "ngraph/axis_vector.hpp"
#include "ngraph/coordinate_transform.hpp"
#include "ngraph/runtime/reference/reverse.hpp"
#include "ngraph/util.hpp"
namespace ngraph
......@@ -28,47 +29,48 @@ namespace ngraph
{
namespace reference
{
// in: NC_I...
// filter: C_OC_I...
// out: NC_O...
template <typename T>
void convolution(const T* arg0,
const T* arg1,
void general_convolution(const T* in,
const T* filter,
T* out,
const Shape& arg0_shape,
const Shape& arg1_shape,
const Shape& in_shape,
const Shape& filter_shape,
const Shape& out_shape,
const Strides& window_movement_strides,
const Strides& window_dilation_strides,
const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above,
const Strides& data_dilation_strides,
size_t batch_axis_data,
size_t input_channel_axis_data,
size_t input_channel_axis_filters,
size_t output_channel_axis_filters,
size_t batch_axis_result,
size_t output_channel_axis_result,
bool rotate_filter)
const Strides& stride,
const Strides& filter_dilation,
const CoordinateDiff& in_pad_below,
const CoordinateDiff& in_pad_above,
const Strides& in_dilation,
size_t in_batch_axis,
size_t in_channel_axis,
size_t filter_out_channel_axis,
size_t filter_in_channel_axis,
size_t out_batch_axis,
size_t out_channel_axis)
{
// Comments throughout assume without loss of generality that:
//
// * batch axes for both input data and output data are 0
// * input channel axes for both input data and filters are 1
// * output channel axes for filters is 0
// * output channel axis for output data is 1
// * rotate_filter is false
// * batch axes for both in and out are 0
// * in channel axes for both in and filter are 1
// * out channel axes for filter is 0
// * out channel axis for out is 1
// At the outermost level we will walk over every output coordinate O.
CoordinateTransform output_transform(out_shape);
// At the outermost level we will walk over every out coordinate O.
CoordinateTransform out_transform(out_shape);
for (const Coordinate& out_coord : output_transform)
for (const Coordinate& out_coord : out_transform)
{
// Our output coordinate O will have the form:
// Our out coordinate O will have the form:
//
// (N,chan_out,i_1,...,i_n)
size_t batch_index = out_coord[batch_axis_result];
size_t output_channel = out_coord[output_channel_axis_result];
size_t batch_index = out_coord[out_batch_axis];
size_t out_channel = out_coord[out_channel_axis];
// For the input data we need to iterate the coordinate:
// For the in we need to iterate the coordinate:
//
// I:
//
......@@ -82,59 +84,57 @@ namespace ngraph
//
// (1,l_1,...,l_n).
//
// Note that we are iterating within the *padded* and *dilated* data batch, so further
// down we must check the current coordinate is in the padding or dilation gap.
// Note that we are iterating within the *padded* and *dilated* in batch, so further
// down we must check the current coordinate is in the pad or dilation gap.
size_t n_spatial_dimensions = arg0_shape.size() - 2;
size_t n_input_channels = arg0_shape[input_channel_axis_data];
size_t n_spatial_dimensions = in_shape.size() - 2;
size_t n_in_channels = in_shape[in_channel_axis];
Coordinate input_batch_transform_start(2 + n_spatial_dimensions);
Coordinate input_batch_transform_end(2 + n_spatial_dimensions);
Strides input_batch_transform_movement_strides(2 + n_spatial_dimensions, 1);
CoordinateDiff input_batch_transform_padding_below(2 + n_spatial_dimensions, 0);
CoordinateDiff input_batch_transform_padding_above(2 + n_spatial_dimensions, 0);
Strides input_batch_transform_dilation_strides(2 + n_spatial_dimensions, 1);
Coordinate in_transform_start(2 + n_spatial_dimensions);
Coordinate in_transform_end(2 + n_spatial_dimensions);
Strides in_transform_movement_strides(2 + n_spatial_dimensions, 1);
CoordinateDiff in_transform_pad_below(2 + n_spatial_dimensions, 0);
CoordinateDiff in_transform_pad_above(2 + n_spatial_dimensions, 0);
Strides in_transform_dilation_strides(2 + n_spatial_dimensions, 1);
input_batch_transform_start[batch_axis_data] = batch_index;
input_batch_transform_end[batch_axis_data] = batch_index + 1;
input_batch_transform_start[input_channel_axis_data] = 0;
input_batch_transform_end[input_channel_axis_data] = n_input_channels;
in_transform_start[in_batch_axis] = batch_index;
in_transform_end[in_batch_axis] = batch_index + 1;
in_transform_start[in_channel_axis] = 0;
in_transform_end[in_channel_axis] = n_in_channels;
for (size_t i = 2; i < n_spatial_dimensions + 2; i++)
{
size_t window_dilation_stride = window_dilation_strides[i - 2];
size_t window_movement_stride = window_movement_strides[i - 2];
std::ptrdiff_t below_pad = padding_below[i - 2];
std::ptrdiff_t above_pad = padding_above[i - 2];
size_t data_dilation_stride = data_dilation_strides[i - 2];
input_batch_transform_start[i] = window_movement_stride * out_coord[i];
input_batch_transform_end[i] =
input_batch_transform_start[i] +
(arg1_shape[i] - 1) * window_dilation_stride + 1;
input_batch_transform_movement_strides[i] = window_dilation_stride;
input_batch_transform_padding_below[i] = below_pad;
input_batch_transform_padding_above[i] = above_pad;
input_batch_transform_dilation_strides[i] = data_dilation_stride;
size_t filter_dilation_stride = filter_dilation[i - 2];
size_t filter_movement_stride = stride[i - 2];
std::ptrdiff_t below_pad = in_pad_below[i - 2];
std::ptrdiff_t above_pad = in_pad_above[i - 2];
size_t in_dilation_stride = in_dilation[i - 2];
in_transform_start[i] = filter_movement_stride * out_coord[i];
in_transform_end[i] = in_transform_start[i] +
(filter_shape[i] - 1) * filter_dilation_stride + 1;
in_transform_movement_strides[i] = filter_dilation_stride;
in_transform_pad_below[i] = below_pad;
in_transform_pad_above[i] = above_pad;
in_transform_dilation_strides[i] = in_dilation_stride;
}
AxisVector input_batch_transform_axis_order(2 + n_spatial_dimensions);
for (size_t i = 0; i < input_batch_transform_axis_order.size(); i++)
AxisVector in_transform_axis_order(2 + n_spatial_dimensions);
for (size_t i = 0; i < in_transform_axis_order.size(); i++)
{
input_batch_transform_axis_order[i] = i;
in_transform_axis_order[i] = i;
}
CoordinateTransform input_batch_transform(
arg0_shape,
input_batch_transform_start,
input_batch_transform_end,
input_batch_transform_movement_strides,
input_batch_transform_axis_order,
input_batch_transform_padding_below,
input_batch_transform_padding_above,
input_batch_transform_dilation_strides);
// Simultaneously with iterating I, for the filters we need to iterate the coordinate:
CoordinateTransform in_transform(in_shape,
in_transform_start,
in_transform_end,
in_transform_movement_strides,
in_transform_axis_order,
in_transform_pad_below,
in_transform_pad_above,
in_transform_dilation_strides);
// Simultaneously with iterating I, for the filter we need to iterate the coordinate:
//
// F
//
......@@ -147,61 +147,153 @@ namespace ngraph
Shape filter_transform_start(2 + n_spatial_dimensions);
Shape filter_transform_end(2 + n_spatial_dimensions);
filter_transform_start[output_channel_axis_filters] = output_channel;
filter_transform_end[output_channel_axis_filters] = output_channel + 1;
filter_transform_start[input_channel_axis_filters] = 0;
filter_transform_end[input_channel_axis_filters] = n_input_channels;
filter_transform_start[filter_out_channel_axis] = out_channel;
filter_transform_end[filter_out_channel_axis] = out_channel + 1;
filter_transform_start[filter_in_channel_axis] = 0;
filter_transform_end[filter_in_channel_axis] = n_in_channels;
for (size_t i = 2; i < n_spatial_dimensions + 2; i++)
{
filter_transform_start[i] = 0;
filter_transform_end[i] = arg1_shape[i];
filter_transform_end[i] = filter_shape[i];
}
CoordinateTransform filter_transform(
arg1_shape, filter_transform_start, filter_transform_end);
filter_shape, filter_transform_start, filter_transform_end);
// As we go, we sum up:
//
// output[O] += arg0[I] * arg1[F].
// out[O] += in[I] * filter[F].
T result = 0;
CoordinateTransform::Iterator input_it = input_batch_transform.begin();
CoordinateTransform::Iterator in_it = in_transform.begin();
CoordinateTransform::Iterator filter_it = filter_transform.begin();
CoordinateTransform::Iterator input_it_end = input_batch_transform.end();
CoordinateTransform::Iterator in_it_end = in_transform.end();
CoordinateTransform::Iterator filter_it_end = filter_transform.end();
while (input_it != input_it_end && filter_it != filter_it_end)
{
const Coordinate& input_batch_coord = *input_it;
Coordinate filter_coord = *filter_it;
if (rotate_filter)
{
Shape target_shape = filter_transform.get_target_shape();
// Note that we only reverse the spatial dimensions here (loop
// starts at 2)
for (size_t i = 2; i < filter_coord.size(); i++)
while (in_it != in_it_end && filter_it != filter_it_end)
{
filter_coord[i] = target_shape[i] - filter_coord[i] - 1;
}
}
T v = input_batch_transform.has_source_coordinate(input_batch_coord)
? arg0[input_batch_transform.index(input_batch_coord)]
const Coordinate& in_coord = *in_it;
T v = in_transform.has_source_coordinate(in_coord)
? in[in_transform.index(in_coord)]
: 0;
result += v * arg1[filter_transform.index(filter_coord)];
result += v * filter[filter_transform.index(*filter_it)];
++input_it;
++in_it;
++filter_it;
}
out[output_transform.index(out_coord)] = result;
out[out_transform.index(out_coord)] = result;
}
}
template <typename T>
void convolution(const T* in,
const T* filter,
T* out,
const Shape& in_shape,
const Shape& filter_shape,
const Shape& out_shape,
const Strides& stride,
const Strides& filter_dilation,
const CoordinateDiff& in_pad_below,
const CoordinateDiff& in_pad_above,
const Strides& in_dilation)
{
general_convolution(in,
filter,
out,
in_shape,
filter_shape,
out_shape,
stride,
filter_dilation,
in_pad_below,
in_pad_above,
in_dilation,
0,
1,
0,
1,
0,
1);
}
template <typename T>
void convolution_backprop_filter(const T* in,
const T* delta_out,
T* delta_filter,
const Shape& in_shape,
const Shape& out_shape,
const Shape& filter_shape,
const Strides& filter_dilation,
const Strides& stride,
const CoordinateDiff& in_pad_below,
const CoordinateDiff& backprop_in_pad_above,
const Strides& in_dilation)
{
general_convolution(in,
delta_out,
delta_filter,
in_shape,
out_shape,
filter_shape,
filter_dilation,
stride,
in_pad_below,
backprop_in_pad_above,
in_dilation,
1,
0,
1,
0,
1,
0);
}
template <typename T>
void convolution_backprop_in(const T* delta_out,
const T* filter,
T* delta_in,
const Shape& out_shape,
const Shape& filter_shape,
const Shape& in_shape,
const Strides& in_dilation,
const Strides& filter_dilation,
const CoordinateDiff& backward_delta_out_pad_below,
const CoordinateDiff& backward_delta_out_pad_above,
const Strides& stride)
{
// Note that we only reverse the spatial dimensions here (loop
// starts at 2)
std::vector<T> reversed(shape_size(filter_shape));
AxisSet reverse_axes;
for (size_t i = 2; i < filter_shape.size(); ++i)
{
reverse_axes.insert(i);
}
reverse<T>(filter, &reversed[0], filter_shape, filter_shape, reverse_axes);
general_convolution(delta_out,
&reversed[0],
delta_in,
out_shape,
filter_shape,
in_shape,
in_dilation,
filter_dilation,
backward_delta_out_pad_below,
backward_delta_out_pad_above,
stride,
0,
1,
1,
0,
0,
1);
}
}
} // namespace reference
} // namespace runtime
} // namespace ngraph
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment