Unverified Commit 8c63cd53 authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Remove _backwards from conv backprop ops. (#2593)

* Remove _backwards from conv backprop ops.

FLip filter in backprop_data instead of general_convolution

* Review comment
parent a0e7694b
This diff is collapsed.
......@@ -220,31 +220,9 @@ namespace ngraph
return m_data_dilation_strides_forward;
}
/// \return The window movement strides for the backward prop.
const Strides& get_window_movement_strides_backward() const
{
return m_window_movement_strides_backward;
}
/// \return The window dilation strides for the backward prop.
const Strides& get_window_dilation_strides_backward() const
{
return m_window_dilation_strides_backward;
}
/// \return The padding-below sizes (possibly negative) for the backward prop.
const CoordinateDiff& get_padding_below_backward() const
{
return m_padding_below_backward;
}
/// \return The padding-above sizes (possibly negative) for the backward prop.
const CoordinateDiff& get_padding_above_backward() const
{
return m_padding_above_backward;
}
/// \return The input data dilation strides for the backward prop.
const Strides& get_data_dilation_strides_backward() const
{
return m_data_dilation_strides_backward;
}
// Compute the pad_above values to be used if in a convolution
CoordinateDiff compute_backward_delta_out_pad_above() const;
CoordinateDiff compute_backward_delta_out_pad_below() const;
protected:
Shape m_data_batch_shape;
......@@ -253,12 +231,6 @@ namespace ngraph
CoordinateDiff m_padding_below_forward;
CoordinateDiff m_padding_above_forward;
Strides m_data_dilation_strides_forward;
Strides m_window_movement_strides_backward;
Strides m_window_dilation_strides_backward;
CoordinateDiff m_padding_below_backward;
CoordinateDiff m_padding_above_backward;
Strides m_data_dilation_strides_backward;
};
/// \brief Filters backprop for batched convolution operation.
......@@ -317,31 +289,8 @@ namespace ngraph
return m_data_dilation_strides_forward;
}
/// \return The window movement strides for the backward prop.
const Strides& get_window_movement_strides_backward() const
{
return m_window_movement_strides_backward;
}
/// \return The window dilation strides for the backward prop.
const Strides& get_window_dilation_strides_backward() const
{
return m_window_dilation_strides_backward;
}
/// \return The padding-below sizes (possibly negative) for the backward prop.
const CoordinateDiff& get_padding_below_backward() const
{
return m_padding_below_backward;
}
/// \return The padding-above sizes (possibly negative) for the backward prop.
const CoordinateDiff& get_padding_above_backward() const
{
return m_padding_above_backward;
}
/// \return The data dilation strides for the backward prop.
const Strides& get_data_dilation_strides_backward() const
{
return m_data_dilation_strides_backward;
}
// Compute the pad_above value to be used if in a convolution
CoordinateDiff compute_backward_in_pad_above() const;
protected:
Shape m_filters_shape;
......@@ -350,12 +299,6 @@ namespace ngraph
CoordinateDiff m_padding_below_forward;
CoordinateDiff m_padding_above_forward;
Strides m_data_dilation_strides_forward;
Strides m_window_movement_strides_backward;
Strides m_window_dilation_strides_backward;
CoordinateDiff m_padding_below_backward;
CoordinateDiff m_padding_above_backward;
Strides m_data_dilation_strides_backward;
};
namespace util
......
......@@ -14,9 +14,8 @@
// limitations under the License.
//*****************************************************************************
#include "ngraph/op/convolution.hpp"
#include "ngraph/runtime/cpu/cpu_builder.hpp"
#include "ngraph/runtime/cpu/kernel/convolution.hpp"
#include "ngraph/runtime/cpu/cpu_builder.hpp"
#include "ngraph/runtime/cpu/mkldnn_invoke.hpp"
#include "ngraph/runtime/cpu/mkldnn_utils.hpp"
#include "ngraph/runtime/cpu/op/conv_add.hpp"
......@@ -107,14 +106,7 @@ namespace ngraph
window_dilation_strides,
padding_below,
padding_above,
data_dilation_strides,
0,
1,
1,
0,
0,
1,
false);
data_dilation_strides);
};
functors.emplace_back(functor);
}
......@@ -313,7 +305,6 @@ namespace ngraph
auto arg0_shape = args[0].get_shape();
auto arg1_shape = args[1].get_shape();
auto result_shape = out[0].get_shape();
auto& arg0_tensor = external_function->get_tensor_data(args[0].get_name());
auto& arg1_tensor = external_function->get_tensor_data(args[1].get_name());
......@@ -347,48 +338,47 @@ namespace ngraph
}
else
{
std::function<decltype(runtime::cpu::kernel::convolution<float>)> kernel;
SELECT_KERNEL(
kernel, out[0].get_element_type(), runtime::cpu::kernel::convolution);
auto window_movement_strides =
convolution->get_window_movement_strides_backward();
std::function<decltype(runtime::cpu::kernel::convolution_backprop_in<float>)>
kernel;
SELECT_KERNEL(kernel,
out[0].get_element_type(),
runtime::cpu::kernel::convolution_backprop_in);
auto& in_shape = convolution->get_data_batch_shape();
auto data_dilation_strides = convolution->get_data_dilation_strides_forward();
auto window_dilation_strides =
convolution->get_window_dilation_strides_backward();
auto padding_below = convolution->get_padding_below_backward();
auto padding_above = convolution->get_padding_above_backward();
auto data_dilation_strides = convolution->get_data_dilation_strides_backward();
convolution->get_window_dilation_strides_forward();
auto padding_below = convolution->get_padding_below_forward();
auto padding_above = convolution->get_padding_above_forward();
auto window_movement_strides =
convolution->get_window_movement_strides_forward();
auto backward_delta_out_pad_below =
convolution->compute_backward_delta_out_pad_below();
auto backward_delta_out_pad_above =
convolution->compute_backward_delta_out_pad_above();
auto functor = [&,
kernel,
arg0_shape,
arg1_shape,
result_shape,
window_movement_strides,
in_shape,
data_dilation_strides,
window_dilation_strides,
padding_below,
padding_above,
data_dilation_strides](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
backward_delta_out_pad_below,
backward_delta_out_pad_above,
window_movement_strides](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
kernel(arg1_tensor,
arg0_tensor,
out_tensor,
arg1_shape,
arg0_shape,
result_shape,
window_movement_strides,
window_dilation_strides,
padding_below,
padding_above,
in_shape,
data_dilation_strides,
0,
1,
0,
1,
0,
1,
true);
window_dilation_strides,
backward_delta_out_pad_below,
backward_delta_out_pad_above,
window_movement_strides);
};
functors.emplace_back(functor);
}
......@@ -403,7 +393,6 @@ namespace ngraph
auto arg0_shape = args[0].get_shape();
auto arg1_shape = args[1].get_shape();
auto result_shape = out[0].get_shape();
auto& arg0_tensor = external_function->get_tensor_data(args[0].get_name());
auto& arg1_tensor = external_function->get_tensor_data(args[1].get_name());
......@@ -437,28 +426,34 @@ namespace ngraph
}
else
{
std::function<decltype(runtime::cpu::kernel::convolution<float>)> kernel;
std::function<decltype(
runtime::cpu::kernel::convolution_backprop_filter<float>)>
kernel;
SELECT_KERNEL(
kernel, out[0].get_element_type(), runtime::cpu::kernel::convolution);
SELECT_KERNEL(kernel,
out[0].get_element_type(),
runtime::cpu::kernel::convolution_backprop_filter);
auto window_movement_strides =
convolution->get_window_movement_strides_backward();
auto& filters_shape = convolution->get_filters_shape();
auto window_dilation_strides =
convolution->get_window_dilation_strides_backward();
auto padding_below = convolution->get_padding_below_backward();
auto padding_above = convolution->get_padding_above_backward();
auto data_dilation_strides = convolution->get_data_dilation_strides_backward();
convolution->get_window_dilation_strides_forward();
auto window_movement_strides =
convolution->get_window_movement_strides_forward();
auto padding_below = convolution->get_padding_below_forward();
auto padding_above = convolution->get_padding_above_forward();
auto data_dilation_strides = convolution->get_data_dilation_strides_forward();
CoordinateDiff backward_in_pad_above =
convolution->compute_backward_in_pad_above();
auto functor = [&,
kernel,
arg0_shape,
arg1_shape,
result_shape,
window_movement_strides,
filters_shape,
window_dilation_strides,
window_movement_strides,
padding_below,
padding_above,
backward_in_pad_above,
data_dilation_strides](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
kernel(arg0_tensor,
......@@ -466,19 +461,12 @@ namespace ngraph
out_tensor,
arg0_shape,
arg1_shape,
result_shape,
filters_shape,
window_movement_strides,
window_dilation_strides,
padding_below,
padding_above,
data_dilation_strides,
1,
0,
0,
1,
1,
0,
false);
backward_in_pad_above,
data_dilation_strides);
};
functors.emplace_back(functor);
}
......@@ -624,6 +612,6 @@ namespace ngraph
REGISTER_OP_BUILDER(GroupConvolution);
REGISTER_OP_BUILDER(ConvolutionAdd);
REGISTER_OP_BUILDER(GroupConvolutionBias);
}
}
}
} // namespace cpu
} // namespace runtime
} // namespace ngraph
......@@ -2288,8 +2288,7 @@ namespace ngraph
writer << " {" << join(convolution->get_padding_above())
<< "},\n";
writer << " {"
<< join(convolution->get_data_dilation_strides()) << "},\n";
writer << " 0, 1, 1, 0, 0, 1, false);\n";
<< join(convolution->get_data_dilation_strides()) << "});\n";
}
}
......@@ -2323,24 +2322,25 @@ namespace ngraph
}
else
{
writer << "reference::convolution<" << out[0].get_type() << ">("
<< args[0].get_name() << ",\n";
writer << "reference::convolution_backprop_filters<" << out[0].get_type()
<< ">(" << args[0].get_name() << ",\n";
writer << " " << args[1].get_name() << ",\n";
writer << " " << out[0].get_name() << ",\n";
writer << " {" << join(convolution->get_filters_shape())
<< "},\n";
writer << " {" << join(arg0_shape) << "},\n";
writer << " {" << join(arg1_shape) << "},\n";
writer << " {" << join(result_shape) << "},\n";
writer << " {"
<< join(convolution->get_window_movement_strides_backward()) << "},\n";
<< join(convolution->get_window_dilation_strides_forward()) << "},\n";
writer << " {"
<< join(convolution->get_window_dilation_strides_backward()) << "},\n";
<< join(convolution->get_window_movement_strides_forward()) << "},\n";
writer << " {"
<< join(convolution->get_padding_below_backward()) << "},\n";
<< join(convolution->get_padding_below_forward()) << "},\n";
writer << " {"
<< join(convolution->get_padding_above_backward()) << "},\n";
<< join(convolution->get_padding_above_forward()) << "},\n";
writer << " {"
<< join(convolution->get_data_dilation_strides_backward()) << "},\n";
writer << " 1, 0, 0, 1, 1, 0, false);\n";
<< join(convolution->get_data_dilation_strides_forward()) << "});\n";
}
}
......@@ -2375,24 +2375,25 @@ namespace ngraph
else
{
// Note that args[1] and args[0] are switched here from the usual order.
writer << "reference::convolution<" << out[0].get_type() << ">("
writer << "reference::convolution_backprop_data<" << out[0].get_type() << ">("
<< args[1].get_name() << ",\n";
writer << " " << args[0].get_name() << ",\n";
writer << " " << out[0].get_name() << ",\n";
writer << " {"
<< join(convolution->get_data_batch_shape()) << "},\n";
writer << " {" << join(arg1_shape) << "},\n";
writer << " {" << join(arg0_shape) << "},\n";
writer << " {" << join(result_shape) << "},\n";
writer << " {"
<< join(convolution->get_window_movement_strides_backward()) << "},\n";
<< join(convolution->get_data_dilation_strides_forward()) << "},\n";
writer << " {"
<< join(convolution->get_window_dilation_strides_backward()) << "},\n";
<< join(convolution->get_window_dilation_strides_forward()) << "},\n";
writer << " {"
<< join(convolution->get_padding_below_backward()) << "},\n";
<< join(convolution->get_padding_below_forward()) << "},\n";
writer << " {"
<< join(convolution->get_padding_above_backward()) << "},\n";
<< join(convolution->get_padding_above_forward()) << "},\n";
writer << " {"
<< join(convolution->get_data_dilation_strides_backward()) << "},\n";
writer << " 0, 1, 0, 1, 0, 1, true);\n";
<< join(convolution->get_window_movement_strides_forward()) << "});\n";
}
}
......
......@@ -38,14 +38,7 @@ namespace ngraph
const Strides& window_dilation_strides,
const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above,
const Strides& data_dilation_strides,
size_t batch_axis_data,
size_t input_channel_axis_data,
size_t input_channel_axis_filters,
size_t output_channel_axis_filters,
size_t batch_axis_result,
size_t output_channel_axis_result,
bool rotate_filter)
const Strides& data_dilation_strides)
{
reference::convolution<ElementType>(static_cast<const ElementType*>(input0),
static_cast<const ElementType*>(input1),
......@@ -57,16 +50,63 @@ namespace ngraph
window_dilation_strides,
padding_below,
padding_above,
data_dilation_strides,
batch_axis_data,
input_channel_axis_data,
input_channel_axis_filters,
output_channel_axis_filters,
batch_axis_result,
output_channel_axis_result,
rotate_filter);
data_dilation_strides);
}
}
}
}
}
template <typename ElementType>
void convolution_backprop_filter(void* input0,
void* input1,
void* output,
const Shape& arg0_shape,
const Shape& arg1_shape,
const Shape& filter_shape,
const Strides& window_dilation_strides,
const Strides& window_movement_strides,
const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above,
const Strides& data_dilation_strides)
{
reference::convolution_backprop_filter<ElementType>(
static_cast<const ElementType*>(input0),
static_cast<const ElementType*>(input1),
static_cast<ElementType*>(output),
arg0_shape,
arg1_shape,
filter_shape,
window_dilation_strides,
window_movement_strides,
padding_below,
padding_above,
data_dilation_strides);
}
template <typename ElementType>
void convolution_backprop_in(void* input0,
void* input1,
void* output,
const Shape& arg0_shape,
const Shape& arg1_shape,
const Shape& in_shape,
const Strides& window_movement_strides,
const Strides& window_dilation_strides,
const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above,
const Strides& data_dilation_strides)
{
reference::convolution_backprop_in<ElementType>(
static_cast<const ElementType*>(input0),
static_cast<const ElementType*>(input1),
static_cast<ElementType*>(output),
arg0_shape,
arg1_shape,
in_shape,
window_movement_strides,
window_dilation_strides,
padding_below,
padding_above,
data_dilation_strides);
}
} // namespace kernel
} // namespace cpu
} // namespace runtime
} // namespace ngraph
......@@ -1629,11 +1629,11 @@ shared_ptr<runtime::Executable>
const shared_ptr<op::ConvolutionBackpropFilters> conv_op =
static_pointer_cast<op::ConvolutionBackpropFilters>(op);
const Strides& win_stride = conv_op->get_window_movement_strides_backward();
const CoordinateDiff& pad_below = conv_op->get_padding_below_backward();
const CoordinateDiff& pad_above = conv_op->get_padding_above_backward();
const Strides& win_dilation = conv_op->get_window_dilation_strides_backward();
const Strides& data_dilation = conv_op->get_data_dilation_strides_backward();
const Strides& win_stride = conv_op->get_window_dilation_strides_forward();
const CoordinateDiff& pad_below = conv_op->get_padding_below_forward();
CoordinateDiff pad_above = conv_op->compute_backward_in_pad_above();
const Strides& win_dilation = conv_op->get_window_movement_strides_forward();
const Strides& data_dilation = conv_op->get_data_dilation_strides_forward();
if ((win_stride.size() > 2) || (win_stride.at(0) != 1) || (win_stride.at(1) != 1) ||
(pad_below.size() > 2) || (pad_above.size() > 2) || (data_dilation.size() > 2) ||
......@@ -1648,10 +1648,10 @@ shared_ptr<runtime::Executable>
get_output_name(op),
get_output_shape(op),
get_output_type(op),
conv_op->get_padding_below_backward(),
conv_op->get_window_movement_strides_backward(),
conv_op->get_window_dilation_strides_backward(),
conv_op->get_data_dilation_strides_backward(),
conv_op->get_padding_below_forward(),
win_stride,
win_dilation,
data_dilation,
1,
0,
0,
......@@ -1728,11 +1728,11 @@ shared_ptr<runtime::Executable>
const shared_ptr<op::ConvolutionBackpropData> conv_op =
static_pointer_cast<op::ConvolutionBackpropData>(op);
const Strides& win_stride = conv_op->get_window_movement_strides_backward();
const CoordinateDiff& pad_below = conv_op->get_padding_below_backward();
const CoordinateDiff& pad_above = conv_op->get_padding_above_backward();
const Strides& win_dilation = conv_op->get_window_dilation_strides_backward();
const Strides& data_dilation = conv_op->get_data_dilation_strides_backward();
const Strides& win_stride = conv_op->get_data_dilation_strides_forward();
CoordinateDiff pad_below = conv_op->compute_backward_delta_out_pad_below();
CoordinateDiff pad_above = conv_op->compute_backward_delta_out_pad_above();
const Strides& win_dilation = conv_op->get_window_dilation_strides_forward();
const Strides& data_dilation = conv_op->get_window_movement_strides_forward();
if ((win_stride.size() > 2) || (win_stride.at(0) != 1) || (win_stride.at(1) != 1) ||
(pad_below.size() > 2) || (pad_above.size() > 2) || (data_dilation.size() > 2) ||
......@@ -1749,10 +1749,10 @@ shared_ptr<runtime::Executable>
get_output_name(op),
get_output_shape(op),
get_output_type(op),
conv_op->get_padding_below_backward(),
conv_op->get_window_movement_strides_backward(),
conv_op->get_window_dilation_strides_backward(),
conv_op->get_data_dilation_strides_backward(),
pad_below,
win_stride,
win_dilation,
data_dilation,
0,
1,
1,
......
......@@ -346,17 +346,7 @@ void print_node_parameters(ostringstream& writer, const shared_ptr<Node>& node)
<< print_table_row_dims("pad_above_forward",
conv_op_filt->get_padding_above_forward())
<< print_table_row_dims("pad_below_forward",
conv_op_filt->get_padding_below_forward())
<< print_table_row_dims("window_movement_strides_backward",
conv_op_filt->get_window_movement_strides_backward())
<< print_table_row_dims("window_dilation_strides_backward",
conv_op_filt->get_window_dilation_strides_backward())
<< print_table_row_dims("data_dilation_strides_backward",
conv_op_filt->get_data_dilation_strides_backward())
<< print_table_row_dims("padding_above_backward",
conv_op_filt->get_padding_above_backward())
<< print_table_row_dims("padding_below_backward",
conv_op_filt->get_padding_below_backward());
conv_op_filt->get_padding_below_forward());
break;
}
case OP_TYPEID::ConvolutionBackpropData:
......@@ -374,17 +364,7 @@ void print_node_parameters(ostringstream& writer, const shared_ptr<Node>& node)
<< print_table_row_dims("pad_above_forward",
conv_op_data->get_padding_above_forward())
<< print_table_row_dims("pad_below_forward",
conv_op_data->get_padding_below_forward())
<< print_table_row_dims("window_movement_strides_backward",
conv_op_data->get_window_movement_strides_backward())
<< print_table_row_dims("window_dilation_strides_backward",
conv_op_data->get_window_dilation_strides_backward())
<< print_table_row_dims("data_dilation_strides_backward",
conv_op_data->get_data_dilation_strides_backward())
<< print_table_row_dims("padding_above_backward",
conv_op_data->get_padding_above_backward())
<< print_table_row_dims("padding_below_backward",
conv_op_data->get_padding_below_backward());
conv_op_data->get_padding_below_forward());
break;
}
case OP_TYPEID::UNDEFINED_OP:
......
......@@ -146,9 +146,9 @@ namespace ngraph
{
class INTBackend;
class INTExecutable;
}
}
}
} // namespace interpreter
} // namespace runtime
} // namespace ngraph
class ngraph::runtime::interpreter::INTExecutable : public Executable
{
......@@ -551,38 +551,26 @@ private:
c->get_window_dilation_strides(),
c->get_padding_below(),
c->get_padding_above(),
c->get_data_dilation_strides(),
0,
1,
1,
0,
0,
1,
false);
c->get_data_dilation_strides());
break;
}
case OP_TYPEID::ConvolutionBackpropFilters:
{
const op::ConvolutionBackpropFilters* c =
static_cast<const op::ConvolutionBackpropFilters*>(&node);
reference::convolution<T>(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(0),
node.get_input_shape(1),
node.get_output_shape(0),
c->get_window_movement_strides_backward(),
c->get_window_dilation_strides_backward(),
c->get_padding_below_backward(),
c->get_padding_above_backward(),
c->get_data_dilation_strides_backward(),
1,
0,
0,
1,
1,
0,
false);
reference::convolution_backprop_filter<T>(
args[0]->get_data_ptr<const T>(), // input
args[1]->get_data_ptr<const T>(), // delta_convolution_output
out[0]->get_data_ptr<T>(), // delta_filter
c->get_input_shape(0), // input_shape
c->get_input_shape(1), // convolution_output_shape
c->get_filters_shape(), // filter_shape
c->get_window_dilation_strides_forward(),
c->get_window_movement_strides_forward(),
c->get_padding_below_forward(),
c->compute_backward_in_pad_above(),
c->get_data_dilation_strides_forward());
break;
}
case OP_TYPEID::ConvolutionBackpropData:
......@@ -590,24 +578,17 @@ private:
// Note that args[1] and args[0] are switched here from the usual order.
const op::ConvolutionBackpropData* c =
static_cast<const op::ConvolutionBackpropData*>(&node);
reference::convolution<T>(args[1]->get_data_ptr<const T>(),
args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(1),
node.get_input_shape(0),
node.get_output_shape(0),
c->get_window_movement_strides_backward(),
c->get_window_dilation_strides_backward(),
c->get_padding_below_backward(),
c->get_padding_above_backward(),
c->get_data_dilation_strides_backward(),
0,
1,
0,
1,
0,
1,
true);
reference::convolution_backprop_in<T>(args[1]->get_data_ptr<const T>(),
args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
c->get_input_shape(1),
c->get_input_shape(0),
c->get_data_batch_shape(),
c->get_data_dilation_strides_forward(),
c->get_window_dilation_strides_forward(),
c->compute_backward_delta_out_pad_below(),
c->compute_backward_delta_out_pad_above(),
c->get_window_movement_strides_forward());
break;
}
case OP_TYPEID::Cos:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment