Unverified commit 8c63cd53, authored by Scott Cyphers, committed by GitHub

Remove _backwards from conv backprop ops. (#2593)

* Remove _backwards from conv backprop ops.

Flip the filter in backprop_data instead of in general_convolution.

* Review comment
parent a0e7694b
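The change replaces the precomputed *_backward attributes on the two backprop ops with their forward counterparts plus a small amount of computed padding, and routes the CPU and interpreter backends through dedicated convolution_backprop_in / convolution_backprop_filter kernels instead of the generic reference::convolution entry point with its six axis indices and rotate_filter flag. The identity being exploited is that the input gradient of a convolution is itself a convolution of the output delta with the spatially flipped filter. The stand-alone sketch below is not nGraph code (all names are illustrative); it just checks that identity for a stride-1, unpadded 1-D convolution.

    // Minimal sketch: d(input) for a stride-1, no-padding 1-D convolution equals a
    // "full" convolution of the output delta with the reversed (flipped) filter.
    #include <cassert>
    #include <cmath>
    #include <cstddef>
    #include <vector>

    int main()
    {
        std::vector<double> input = {1, 2, 3, 4, 5};
        std::vector<double> filter = {1, -1, 2};
        const std::size_t out_len = input.size() - filter.size() + 1; // 3

        // Pretend upstream gradient w.r.t. the convolution output.
        std::vector<double> delta = {0.5, -1.0, 2.0};

        // (a) Direct accumulation: d input[o + k] += delta[o] * filter[k].
        std::vector<double> din_direct(input.size(), 0.0);
        for (std::size_t o = 0; o < out_len; ++o)
            for (std::size_t k = 0; k < filter.size(); ++k)
                din_direct[o + k] += delta[o] * filter[k];

        // (b) Full convolution of delta with the reversed filter: pad delta by
        //     (filter_size - 1) on both sides, then convolve normally.
        std::vector<double> din_flip(input.size(), 0.0);
        const std::ptrdiff_t pad = static_cast<std::ptrdiff_t>(filter.size()) - 1;
        for (std::ptrdiff_t i = 0; i < static_cast<std::ptrdiff_t>(input.size()); ++i)
            for (std::ptrdiff_t k = 0; k < static_cast<std::ptrdiff_t>(filter.size()); ++k)
            {
                std::ptrdiff_t o = i - pad + k; // position in the (virtually padded) delta
                if (o >= 0 && o < static_cast<std::ptrdiff_t>(out_len))
                    din_flip[i] += delta[o] * filter[static_cast<std::size_t>(pad - k)];
            }

        for (std::size_t i = 0; i < input.size(); ++i)
            assert(std::abs(din_direct[i] - din_flip[i]) < 1e-12);
        return 0;
    }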
@@ -220,31 +220,9 @@ namespace ngraph
         return m_data_dilation_strides_forward;
     }
-    /// \return The window movement strides for the backward prop.
-    const Strides& get_window_movement_strides_backward() const
-    {
-        return m_window_movement_strides_backward;
-    }
-    /// \return The window dilation strides for the backward prop.
-    const Strides& get_window_dilation_strides_backward() const
-    {
-        return m_window_dilation_strides_backward;
-    }
-    /// \return The padding-below sizes (possibly negative) for the backward prop.
-    const CoordinateDiff& get_padding_below_backward() const
-    {
-        return m_padding_below_backward;
-    }
-    /// \return The padding-above sizes (possibly negative) for the backward prop.
-    const CoordinateDiff& get_padding_above_backward() const
-    {
-        return m_padding_above_backward;
-    }
-    /// \return The input data dilation strides for the backward prop.
-    const Strides& get_data_dilation_strides_backward() const
-    {
-        return m_data_dilation_strides_backward;
-    }
+    // Compute the pad_above values to be used if in a convolution
+    CoordinateDiff compute_backward_delta_out_pad_above() const;
+    CoordinateDiff compute_backward_delta_out_pad_below() const;

 protected:
     Shape m_data_batch_shape;
@@ -253,12 +231,6 @@ namespace ngraph
     CoordinateDiff m_padding_below_forward;
     CoordinateDiff m_padding_above_forward;
     Strides m_data_dilation_strides_forward;
-    Strides m_window_movement_strides_backward;
-    Strides m_window_dilation_strides_backward;
-    CoordinateDiff m_padding_below_backward;
-    CoordinateDiff m_padding_above_backward;
-    Strides m_data_dilation_strides_backward;
 };

 /// \brief Filters backprop for batched convolution operation.
@@ -317,31 +289,8 @@ namespace ngraph
         return m_data_dilation_strides_forward;
     }
-    /// \return The window movement strides for the backward prop.
-    const Strides& get_window_movement_strides_backward() const
-    {
-        return m_window_movement_strides_backward;
-    }
-    /// \return The window dilation strides for the backward prop.
-    const Strides& get_window_dilation_strides_backward() const
-    {
-        return m_window_dilation_strides_backward;
-    }
-    /// \return The padding-below sizes (possibly negative) for the backward prop.
-    const CoordinateDiff& get_padding_below_backward() const
-    {
-        return m_padding_below_backward;
-    }
-    /// \return The padding-above sizes (possibly negative) for the backward prop.
-    const CoordinateDiff& get_padding_above_backward() const
-    {
-        return m_padding_above_backward;
-    }
-    /// \return The data dilation strides for the backward prop.
-    const Strides& get_data_dilation_strides_backward() const
-    {
-        return m_data_dilation_strides_backward;
-    }
+    // Compute the pad_above value to be used if in a convolution
+    CoordinateDiff compute_backward_in_pad_above() const;

 protected:
     Shape m_filters_shape;
@@ -350,12 +299,6 @@ namespace ngraph
     CoordinateDiff m_padding_below_forward;
     CoordinateDiff m_padding_above_forward;
     Strides m_data_dilation_strides_forward;
-    Strides m_window_movement_strides_backward;
-    Strides m_window_dilation_strides_backward;
-    CoordinateDiff m_padding_below_backward;
-    CoordinateDiff m_padding_above_backward;
-    Strides m_data_dilation_strides_backward;
 };

 namespace util
......
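The new compute_backward_delta_out_pad_below/above() and compute_backward_in_pad_above() members above replace padding values that were previously precomputed and stored by the constructors; their definitions live in convolution.cpp and are not part of this diff. As a rough, hedged sketch of the relation they encode (an assumption based on the standard transposed-convolution padding rule, ignoring the extra remainder term that a forward window movement stride greater than one contributes to pad_above):

    // Hypothetical sketch only -- not the nGraph implementation. For the data backprop,
    // convolving the delta with the flipped filter needs (dilated_filter_extent - 1) of
    // padding on each side, minus whatever padding the forward convolution already had.
    #include <cstddef>
    #include "ngraph/coordinate_diff.hpp"
    #include "ngraph/shape.hpp"
    #include "ngraph/strides.hpp"

    ngraph::CoordinateDiff sketch_backward_delta_out_pad_below(
        const ngraph::Shape& filters_shape, // [Co, Ci, f0, f1, ...]
        const ngraph::Strides& window_dilation_strides_forward,
        const ngraph::CoordinateDiff& padding_below_forward)
    {
        ngraph::CoordinateDiff pad;
        for (std::size_t i = 0; i < padding_below_forward.size(); ++i)
        {
            auto filter_extent = static_cast<std::ptrdiff_t>(filters_shape[i + 2]);
            auto dilation = static_cast<std::ptrdiff_t>(window_dilation_strides_forward[i]);
            pad.push_back((filter_extent - 1) * dilation - padding_below_forward[i]);
        }
        return pad;
    }

The pad_above counterpart plausibly folds in the remainder of the padded, dilated input extent modulo the forward window movement stride, which is why stride > 1 cases are not captured by this sketch.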
@@ -14,9 +14,8 @@
 // limitations under the License.
 //*****************************************************************************

-#include "ngraph/op/convolution.hpp"
-#include "ngraph/runtime/cpu/cpu_builder.hpp"
 #include "ngraph/runtime/cpu/kernel/convolution.hpp"
+#include "ngraph/runtime/cpu/cpu_builder.hpp"
 #include "ngraph/runtime/cpu/mkldnn_invoke.hpp"
 #include "ngraph/runtime/cpu/mkldnn_utils.hpp"
 #include "ngraph/runtime/cpu/op/conv_add.hpp"
@@ -107,14 +106,7 @@ namespace ngraph
                window_dilation_strides,
                padding_below,
                padding_above,
-               data_dilation_strides,
-               0,
-               1,
-               1,
-               0,
-               0,
-               1,
-               false);
+               data_dilation_strides);
     };
     functors.emplace_back(functor);
 }
@@ -313,7 +305,6 @@ namespace ngraph
 auto arg0_shape = args[0].get_shape();
 auto arg1_shape = args[1].get_shape();
-auto result_shape = out[0].get_shape();
 auto& arg0_tensor = external_function->get_tensor_data(args[0].get_name());
 auto& arg1_tensor = external_function->get_tensor_data(args[1].get_name());
@@ -347,48 +338,47 @@ namespace ngraph
 }
 else
 {
-    std::function<decltype(runtime::cpu::kernel::convolution<float>)> kernel;
-    SELECT_KERNEL(
-        kernel, out[0].get_element_type(), runtime::cpu::kernel::convolution);
-    auto window_movement_strides =
-        convolution->get_window_movement_strides_backward();
+    std::function<decltype(runtime::cpu::kernel::convolution_backprop_in<float>)>
+        kernel;
+    SELECT_KERNEL(kernel,
+                  out[0].get_element_type(),
+                  runtime::cpu::kernel::convolution_backprop_in);
+    auto& in_shape = convolution->get_data_batch_shape();
+    auto data_dilation_strides = convolution->get_data_dilation_strides_forward();
     auto window_dilation_strides =
-        convolution->get_window_dilation_strides_backward();
-    auto padding_below = convolution->get_padding_below_backward();
-    auto padding_above = convolution->get_padding_above_backward();
-    auto data_dilation_strides = convolution->get_data_dilation_strides_backward();
+        convolution->get_window_dilation_strides_forward();
+    auto padding_below = convolution->get_padding_below_forward();
+    auto padding_above = convolution->get_padding_above_forward();
+    auto window_movement_strides =
+        convolution->get_window_movement_strides_forward();
+    auto backward_delta_out_pad_below =
+        convolution->compute_backward_delta_out_pad_below();
+    auto backward_delta_out_pad_above =
+        convolution->compute_backward_delta_out_pad_above();

     auto functor = [&,
                     kernel,
                     arg0_shape,
                     arg1_shape,
-                    result_shape,
-                    window_movement_strides,
+                    in_shape,
+                    data_dilation_strides,
                     window_dilation_strides,
-                    padding_below,
-                    padding_above,
-                    data_dilation_strides](CPURuntimeContext* ctx,
+                    backward_delta_out_pad_below,
+                    backward_delta_out_pad_above,
+                    window_movement_strides](CPURuntimeContext* ctx,
                                              CPUExecutionContext* ectx) {
        kernel(arg1_tensor,
               arg0_tensor,
               out_tensor,
               arg1_shape,
               arg0_shape,
-              result_shape,
-              window_movement_strides,
-              window_dilation_strides,
-              padding_below,
-              padding_above,
+              in_shape,
               data_dilation_strides,
-              0,
-              1,
-              0,
-              1,
-              0,
-              1,
-              true);
+              window_dilation_strides,
+              backward_delta_out_pad_below,
+              backward_delta_out_pad_above,
+              window_movement_strides);
     };
     functors.emplace_back(functor);
 }
@@ -403,7 +393,6 @@ namespace ngraph
 auto arg0_shape = args[0].get_shape();
 auto arg1_shape = args[1].get_shape();
-auto result_shape = out[0].get_shape();
 auto& arg0_tensor = external_function->get_tensor_data(args[0].get_name());
 auto& arg1_tensor = external_function->get_tensor_data(args[1].get_name());
@@ -437,28 +426,34 @@ namespace ngraph
 }
 else
 {
-    std::function<decltype(runtime::cpu::kernel::convolution<float>)> kernel;
-    SELECT_KERNEL(
-        kernel, out[0].get_element_type(), runtime::cpu::kernel::convolution);
-    auto window_movement_strides =
-        convolution->get_window_movement_strides_backward();
+    std::function<decltype(
+        runtime::cpu::kernel::convolution_backprop_filter<float>)>
+        kernel;
+    SELECT_KERNEL(kernel,
+                  out[0].get_element_type(),
+                  runtime::cpu::kernel::convolution_backprop_filter);
+    auto& filters_shape = convolution->get_filters_shape();
     auto window_dilation_strides =
-        convolution->get_window_dilation_strides_backward();
-    auto padding_below = convolution->get_padding_below_backward();
-    auto padding_above = convolution->get_padding_above_backward();
-    auto data_dilation_strides = convolution->get_data_dilation_strides_backward();
+        convolution->get_window_dilation_strides_forward();
+    auto window_movement_strides =
+        convolution->get_window_movement_strides_forward();
+    auto padding_below = convolution->get_padding_below_forward();
+    auto padding_above = convolution->get_padding_above_forward();
+    auto data_dilation_strides = convolution->get_data_dilation_strides_forward();
+    CoordinateDiff backward_in_pad_above =
+        convolution->compute_backward_in_pad_above();

     auto functor = [&,
                     kernel,
                     arg0_shape,
                     arg1_shape,
-                    result_shape,
-                    window_movement_strides,
+                    filters_shape,
                     window_dilation_strides,
+                    window_movement_strides,
                     padding_below,
-                    padding_above,
+                    backward_in_pad_above,
                     data_dilation_strides](CPURuntimeContext* ctx,
                                            CPUExecutionContext* ectx) {
        kernel(arg0_tensor,
@@ -466,19 +461,12 @@ namespace ngraph
               out_tensor,
               arg0_shape,
               arg1_shape,
-              result_shape,
+              filters_shape,
               window_movement_strides,
               window_dilation_strides,
               padding_below,
-              padding_above,
-              data_dilation_strides,
-              1,
-              0,
-              0,
-              1,
-              1,
-              0,
-              false);
+              backward_in_pad_above,
+              data_dilation_strides);
     };
     functors.emplace_back(functor);
 }
@@ -624,6 +612,6 @@ namespace ngraph
             REGISTER_OP_BUILDER(GroupConvolution);
             REGISTER_OP_BUILDER(ConvolutionAdd);
             REGISTER_OP_BUILDER(GroupConvolutionBias);
-        }
-    }
-}
+        } // namespace cpu
+    } // namespace runtime
+} // namespace ngraph
@@ -2288,8 +2288,7 @@ namespace ngraph
     writer << " {" << join(convolution->get_padding_above())
            << "},\n";
     writer << " {"
-           << join(convolution->get_data_dilation_strides()) << "},\n";
-    writer << " 0, 1, 1, 0, 0, 1, false);\n";
+           << join(convolution->get_data_dilation_strides()) << "});\n";
 }
 }
@@ -2323,24 +2322,25 @@ namespace ngraph
 }
 else
 {
-    writer << "reference::convolution<" << out[0].get_type() << ">("
-           << args[0].get_name() << ",\n";
+    writer << "reference::convolution_backprop_filters<" << out[0].get_type()
+           << ">(" << args[0].get_name() << ",\n";
     writer << " " << args[1].get_name() << ",\n";
     writer << " " << out[0].get_name() << ",\n";
+    writer << " {" << join(convolution->get_filters_shape())
+           << "},\n";
     writer << " {" << join(arg0_shape) << "},\n";
     writer << " {" << join(arg1_shape) << "},\n";
     writer << " {" << join(result_shape) << "},\n";
     writer << " {"
-           << join(convolution->get_window_movement_strides_backward()) << "},\n";
+           << join(convolution->get_window_dilation_strides_forward()) << "},\n";
     writer << " {"
-           << join(convolution->get_window_dilation_strides_backward()) << "},\n";
+           << join(convolution->get_window_movement_strides_forward()) << "},\n";
     writer << " {"
-           << join(convolution->get_padding_below_backward()) << "},\n";
+           << join(convolution->get_padding_below_forward()) << "},\n";
     writer << " {"
-           << join(convolution->get_padding_above_backward()) << "},\n";
+           << join(convolution->get_padding_above_forward()) << "},\n";
     writer << " {"
-           << join(convolution->get_data_dilation_strides_backward()) << "},\n";
-    writer << " 1, 0, 0, 1, 1, 0, false);\n";
+           << join(convolution->get_data_dilation_strides_forward()) << "});\n";
 }
 }
@@ -2375,24 +2375,25 @@ namespace ngraph
 else
 {
     // Note that args[1] and args[0] are switched here from the usual order.
-    writer << "reference::convolution<" << out[0].get_type() << ">("
+    writer << "reference::convolution_backprop_data<" << out[0].get_type() << ">("
            << args[1].get_name() << ",\n";
     writer << " " << args[0].get_name() << ",\n";
     writer << " " << out[0].get_name() << ",\n";
+    writer << " {"
+           << join(convolution->get_data_batch_shape()) << "},\n";
     writer << " {" << join(arg1_shape) << "},\n";
     writer << " {" << join(arg0_shape) << "},\n";
     writer << " {" << join(result_shape) << "},\n";
     writer << " {"
-           << join(convolution->get_window_movement_strides_backward()) << "},\n";
+           << join(convolution->get_data_dilation_strides_forward()) << "},\n";
     writer << " {"
-           << join(convolution->get_window_dilation_strides_backward()) << "},\n";
+           << join(convolution->get_window_dilation_strides_forward()) << "},\n";
     writer << " {"
-           << join(convolution->get_padding_below_backward()) << "},\n";
+           << join(convolution->get_padding_below_forward()) << "},\n";
     writer << " {"
-           << join(convolution->get_padding_above_backward()) << "},\n";
+           << join(convolution->get_padding_above_forward()) << "},\n";
     writer << " {"
-           << join(convolution->get_data_dilation_strides_backward()) << "},\n";
-    writer << " 0, 1, 0, 1, 0, 1, true);\n";
+           << join(convolution->get_window_movement_strides_forward()) << "});\n";
 }
 }
......
@@ -38,14 +38,7 @@ namespace ngraph
                  const Strides& window_dilation_strides,
                  const CoordinateDiff& padding_below,
                  const CoordinateDiff& padding_above,
-                 const Strides& data_dilation_strides,
-                 size_t batch_axis_data,
-                 size_t input_channel_axis_data,
-                 size_t input_channel_axis_filters,
-                 size_t output_channel_axis_filters,
-                 size_t batch_axis_result,
-                 size_t output_channel_axis_result,
-                 bool rotate_filter)
+                 const Strides& data_dilation_strides)
 {
     reference::convolution<ElementType>(static_cast<const ElementType*>(input0),
                                         static_cast<const ElementType*>(input1),
@@ -57,16 +50,63 @@ namespace ngraph
                                         window_dilation_strides,
                                         padding_below,
                                         padding_above,
-                                        data_dilation_strides,
-                                        batch_axis_data,
-                                        input_channel_axis_data,
-                                        input_channel_axis_filters,
-                                        output_channel_axis_filters,
-                                        batch_axis_result,
-                                        output_channel_axis_result,
-                                        rotate_filter);
+                                        data_dilation_strides);
 }
-}
-}
-}
-}
+
+template <typename ElementType>
+void convolution_backprop_filter(void* input0,
+                                 void* input1,
+                                 void* output,
+                                 const Shape& arg0_shape,
+                                 const Shape& arg1_shape,
+                                 const Shape& filter_shape,
+                                 const Strides& window_dilation_strides,
+                                 const Strides& window_movement_strides,
+                                 const CoordinateDiff& padding_below,
+                                 const CoordinateDiff& padding_above,
+                                 const Strides& data_dilation_strides)
+{
+    reference::convolution_backprop_filter<ElementType>(
+        static_cast<const ElementType*>(input0),
+        static_cast<const ElementType*>(input1),
+        static_cast<ElementType*>(output),
+        arg0_shape,
+        arg1_shape,
+        filter_shape,
+        window_dilation_strides,
+        window_movement_strides,
+        padding_below,
+        padding_above,
+        data_dilation_strides);
+}
+
+template <typename ElementType>
+void convolution_backprop_in(void* input0,
+                             void* input1,
+                             void* output,
+                             const Shape& arg0_shape,
+                             const Shape& arg1_shape,
+                             const Shape& in_shape,
+                             const Strides& window_movement_strides,
+                             const Strides& window_dilation_strides,
+                             const CoordinateDiff& padding_below,
+                             const CoordinateDiff& padding_above,
+                             const Strides& data_dilation_strides)
+{
+    reference::convolution_backprop_in<ElementType>(
+        static_cast<const ElementType*>(input0),
+        static_cast<const ElementType*>(input1),
+        static_cast<ElementType*>(output),
+        arg0_shape,
+        arg1_shape,
+        in_shape,
+        window_movement_strides,
+        window_dilation_strides,
+        padding_below,
+        padding_above,
+        data_dilation_strides);
+}
+} // namespace kernel
+} // namespace cpu
+} // namespace runtime
+} // namespace ngraph
@@ -1629,11 +1629,11 @@ shared_ptr<runtime::Executable>
 const shared_ptr<op::ConvolutionBackpropFilters> conv_op =
     static_pointer_cast<op::ConvolutionBackpropFilters>(op);
-const Strides& win_stride = conv_op->get_window_movement_strides_backward();
-const CoordinateDiff& pad_below = conv_op->get_padding_below_backward();
-const CoordinateDiff& pad_above = conv_op->get_padding_above_backward();
-const Strides& win_dilation = conv_op->get_window_dilation_strides_backward();
-const Strides& data_dilation = conv_op->get_data_dilation_strides_backward();
+const Strides& win_stride = conv_op->get_window_dilation_strides_forward();
+const CoordinateDiff& pad_below = conv_op->get_padding_below_forward();
+CoordinateDiff pad_above = conv_op->compute_backward_in_pad_above();
+const Strides& win_dilation = conv_op->get_window_movement_strides_forward();
+const Strides& data_dilation = conv_op->get_data_dilation_strides_forward();
 if ((win_stride.size() > 2) || (win_stride.at(0) != 1) || (win_stride.at(1) != 1) ||
     (pad_below.size() > 2) || (pad_above.size() > 2) || (data_dilation.size() > 2) ||
@@ -1648,10 +1648,10 @@ shared_ptr<runtime::Executable>
     get_output_name(op),
     get_output_shape(op),
     get_output_type(op),
-    conv_op->get_padding_below_backward(),
-    conv_op->get_window_movement_strides_backward(),
-    conv_op->get_window_dilation_strides_backward(),
-    conv_op->get_data_dilation_strides_backward(),
+    conv_op->get_padding_below_forward(),
+    win_stride,
+    win_dilation,
+    data_dilation,
     1,
     0,
     0,
@@ -1728,11 +1728,11 @@ shared_ptr<runtime::Executable>
 const shared_ptr<op::ConvolutionBackpropData> conv_op =
     static_pointer_cast<op::ConvolutionBackpropData>(op);
-const Strides& win_stride = conv_op->get_window_movement_strides_backward();
-const CoordinateDiff& pad_below = conv_op->get_padding_below_backward();
-const CoordinateDiff& pad_above = conv_op->get_padding_above_backward();
-const Strides& win_dilation = conv_op->get_window_dilation_strides_backward();
-const Strides& data_dilation = conv_op->get_data_dilation_strides_backward();
+const Strides& win_stride = conv_op->get_data_dilation_strides_forward();
+CoordinateDiff pad_below = conv_op->compute_backward_delta_out_pad_below();
+CoordinateDiff pad_above = conv_op->compute_backward_delta_out_pad_above();
+const Strides& win_dilation = conv_op->get_window_dilation_strides_forward();
+const Strides& data_dilation = conv_op->get_window_movement_strides_forward();
 if ((win_stride.size() > 2) || (win_stride.at(0) != 1) || (win_stride.at(1) != 1) ||
     (pad_below.size() > 2) || (pad_above.size() > 2) || (data_dilation.size() > 2) ||
@@ -1749,10 +1749,10 @@ shared_ptr<runtime::Executable>
     get_output_name(op),
     get_output_shape(op),
     get_output_type(op),
-    conv_op->get_padding_below_backward(),
-    conv_op->get_window_movement_strides_backward(),
-    conv_op->get_window_dilation_strides_backward(),
-    conv_op->get_data_dilation_strides_backward(),
+    pad_below,
+    win_stride,
+    win_dilation,
+    data_dilation,
     0,
     1,
     1,
......
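Read directly off the hunks above, the backend now lowers ConvolutionBackpropData onto a forward-style convolution by remapping the forward attributes. The helper below is illustrative only (the struct and function are not part of nGraph); it merely restates that mapping in one place.

    // Illustrative only: the parameter remapping used for ConvolutionBackpropData in
    // this diff. The getters and compute_* calls exist on the op; the struct and the
    // function are hypothetical conveniences.
    #include "ngraph/coordinate_diff.hpp"
    #include "ngraph/strides.hpp"

    struct DataBackpropAsForwardConv
    {
        ngraph::Strides window_movement_strides;
        ngraph::Strides window_dilation_strides;
        ngraph::CoordinateDiff padding_below;
        ngraph::CoordinateDiff padding_above;
        ngraph::Strides data_dilation_strides;
    };

    template <typename ConvBackpropData>
    DataBackpropAsForwardConv map_data_backprop(const ConvBackpropData& op)
    {
        return {op.get_data_dilation_strides_forward(),    // window movement strides
                op.get_window_dilation_strides_forward(),  // window dilation strides
                op.compute_backward_delta_out_pad_below(), // padding below
                op.compute_backward_delta_out_pad_above(), // padding above
                op.get_window_movement_strides_forward()}; // data dilation strides
    }

The filters backprop is analogous but swaps the roles of the window movement and window dilation strides and keeps the forward padding_below, replacing only padding_above with compute_backward_in_pad_above().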
@@ -346,17 +346,7 @@ void print_node_parameters(ostringstream& writer, const shared_ptr<Node>& node)
        << print_table_row_dims("pad_above_forward",
                                conv_op_filt->get_padding_above_forward())
        << print_table_row_dims("pad_below_forward",
-                               conv_op_filt->get_padding_below_forward())
-       << print_table_row_dims("window_movement_strides_backward",
-                               conv_op_filt->get_window_movement_strides_backward())
-       << print_table_row_dims("window_dilation_strides_backward",
-                               conv_op_filt->get_window_dilation_strides_backward())
-       << print_table_row_dims("data_dilation_strides_backward",
-                               conv_op_filt->get_data_dilation_strides_backward())
-       << print_table_row_dims("padding_above_backward",
-                               conv_op_filt->get_padding_above_backward())
-       << print_table_row_dims("padding_below_backward",
-                               conv_op_filt->get_padding_below_backward());
+                               conv_op_filt->get_padding_below_forward());
     break;
 }
 case OP_TYPEID::ConvolutionBackpropData:
@@ -374,17 +364,7 @@ void print_node_parameters(ostringstream& writer, const shared_ptr<Node>& node)
        << print_table_row_dims("pad_above_forward",
                                conv_op_data->get_padding_above_forward())
        << print_table_row_dims("pad_below_forward",
-                               conv_op_data->get_padding_below_forward())
-       << print_table_row_dims("window_movement_strides_backward",
-                               conv_op_data->get_window_movement_strides_backward())
-       << print_table_row_dims("window_dilation_strides_backward",
-                               conv_op_data->get_window_dilation_strides_backward())
-       << print_table_row_dims("data_dilation_strides_backward",
-                               conv_op_data->get_data_dilation_strides_backward())
-       << print_table_row_dims("padding_above_backward",
-                               conv_op_data->get_padding_above_backward())
-       << print_table_row_dims("padding_below_backward",
-                               conv_op_data->get_padding_below_backward());
+                               conv_op_data->get_padding_below_forward());
     break;
 }
 case OP_TYPEID::UNDEFINED_OP:
......
@@ -146,9 +146,9 @@ namespace ngraph
         {
             class INTBackend;
             class INTExecutable;
-        }
-    }
-}
+        } // namespace interpreter
+    } // namespace runtime
+} // namespace ngraph

 class ngraph::runtime::interpreter::INTExecutable : public Executable
 {
@@ -551,38 +551,26 @@ private:
                               c->get_window_dilation_strides(),
                               c->get_padding_below(),
                               c->get_padding_above(),
-                              c->get_data_dilation_strides(),
-                              0,
-                              1,
-                              1,
-                              0,
-                              0,
-                              1,
-                              false);
+                              c->get_data_dilation_strides());
     break;
 }
 case OP_TYPEID::ConvolutionBackpropFilters:
 {
     const op::ConvolutionBackpropFilters* c =
         static_cast<const op::ConvolutionBackpropFilters*>(&node);
-    reference::convolution<T>(args[0]->get_data_ptr<const T>(),
-                              args[1]->get_data_ptr<const T>(),
-                              out[0]->get_data_ptr<T>(),
-                              node.get_input_shape(0),
-                              node.get_input_shape(1),
-                              node.get_output_shape(0),
-                              c->get_window_movement_strides_backward(),
-                              c->get_window_dilation_strides_backward(),
-                              c->get_padding_below_backward(),
-                              c->get_padding_above_backward(),
-                              c->get_data_dilation_strides_backward(),
-                              1,
-                              0,
-                              0,
-                              1,
-                              1,
-                              0,
-                              false);
+    reference::convolution_backprop_filter<T>(
+        args[0]->get_data_ptr<const T>(), // input
+        args[1]->get_data_ptr<const T>(), // delta_convolution_output
+        out[0]->get_data_ptr<T>(),        // delta_filter
+        c->get_input_shape(0),            // input_shape
+        c->get_input_shape(1),            // convolution_output_shape
+        c->get_filters_shape(),           // filter_shape
+        c->get_window_dilation_strides_forward(),
+        c->get_window_movement_strides_forward(),
+        c->get_padding_below_forward(),
+        c->compute_backward_in_pad_above(),
+        c->get_data_dilation_strides_forward());
     break;
 }
 case OP_TYPEID::ConvolutionBackpropData:
@@ -590,24 +578,17 @@ private:
     // Note that args[1] and args[0] are switched here from the usual order.
     const op::ConvolutionBackpropData* c =
         static_cast<const op::ConvolutionBackpropData*>(&node);
-    reference::convolution<T>(args[1]->get_data_ptr<const T>(),
+    reference::convolution_backprop_in<T>(args[1]->get_data_ptr<const T>(),
                               args[0]->get_data_ptr<const T>(),
                               out[0]->get_data_ptr<T>(),
-                              node.get_input_shape(1),
-                              node.get_input_shape(0),
-                              node.get_output_shape(0),
-                              c->get_window_movement_strides_backward(),
-                              c->get_window_dilation_strides_backward(),
-                              c->get_padding_below_backward(),
-                              c->get_padding_above_backward(),
-                              c->get_data_dilation_strides_backward(),
-                              0,
-                              1,
-                              0,
-                              1,
-                              0,
-                              1,
-                              true);
+                              c->get_input_shape(1),
+                              c->get_input_shape(0),
+                              c->get_data_batch_shape(),
+                              c->get_data_dilation_strides_forward(),
+                              c->get_window_dilation_strides_forward(),
+                              c->compute_backward_delta_out_pad_below(),
+                              c->compute_backward_delta_out_pad_above(),
+                              c->get_window_movement_strides_forward());
     break;
 }
 case OP_TYPEID::Cos:
......