Commit aa36865c authored by Matthew Brookhart's avatar Matthew Brookhart Committed by Nick Korovaiko

working generate_adjoints (#1173)

parent eef2b19d
...@@ -484,6 +484,78 @@ op::ConvolutionBackpropData::ConvolutionBackpropData(const Shape& data_batch_sha ...@@ -484,6 +484,78 @@ op::ConvolutionBackpropData::ConvolutionBackpropData(const Shape& data_batch_sha
set_value_type_checked(filters_et, inferred_convolution_output_shape); set_value_type_checked(filters_et, inferred_convolution_output_shape);
} }
void op::ConvolutionBackpropData::generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas)
{
auto delta = deltas.at(0);
auto x = get_argument(1);
const auto x_shape = x->get_shape();
auto f = get_argument(0);
const auto f_shape = f->get_shape();
auto data_conv = make_shared<op::Convolution>(delta,
f,
m_window_movement_strides_forward,
m_window_dilation_strides_forward,
m_padding_below_forward,
m_padding_above_forward,
m_data_dilation_strides_forward);
adjoints.add_delta(x, data_conv);
Strides window_movement_strides;
Strides window_dilation_strides;
CoordinateDiff padding_below;
CoordinateDiff padding_above;
Strides data_dilation_strides;
for (size_t i = 0; i < f_shape.size() - 2; i++)
{
window_movement_strides.push_back(m_window_dilation_strides_backward[i]);
window_dilation_strides.push_back(m_window_movement_strides_backward[i]);
padding_below.push_back(m_padding_below_backward[i]);
padding_above.push_back(m_padding_above_backward[i] -
(m_padding_below_backward[i] +
(x_shape[i + 2] - 1) * m_data_dilation_strides_backward[i] +
m_padding_above_backward[i] -
(f_shape[i + 2] - 1) * m_window_dilation_strides_backward[i]) %
m_window_movement_strides_backward[i]);
data_dilation_strides.push_back(m_data_dilation_strides_backward[i]);
}
auto swap_NC = [](const shared_ptr<Node> n) {
AxisVector ax_order(n->get_shape().size());
iota(ax_order.begin(), ax_order.end(), 0);
ax_order[0] = 1;
ax_order[1] = 0;
auto new_shape = n->get_shape();
new_shape[0] = n->get_shape()[1];
new_shape[1] = n->get_shape()[0];
return make_shared<op::Reshape>(n, ax_order, new_shape);
};
delta = swap_NC(delta);
x = swap_NC(x);
shared_ptr<Node> filter_deconv_bprop = make_shared<op::Convolution>(x,
delta,
window_movement_strides,
window_dilation_strides,
padding_below,
padding_above,
data_dilation_strides);
AxisSet axes;
for (size_t i = 2; i < filter_deconv_bprop->get_shape().size(); ++i)
{
axes.insert(i);
}
filter_deconv_bprop = make_shared<ngraph::op::Reverse>(filter_deconv_bprop, axes);
adjoints.add_delta(f, filter_deconv_bprop);
}
shared_ptr<Node> op::ConvolutionBackpropData::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::ConvolutionBackpropData::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != 2) if (new_args.size() != 2)
......
...@@ -180,6 +180,7 @@ namespace ngraph ...@@ -180,6 +180,7 @@ namespace ngraph
const CoordinateDiff& padding_above_forward, const CoordinateDiff& padding_above_forward,
const Strides& data_dilation_strides_forward); const Strides& data_dilation_strides_forward);
void generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas) override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment