Commit fe77223d authored by Dmitry Kurtaev's avatar Dmitry Kurtaev

Modify nGraph's ConvolutionBackpropData and GroupConvolution

parent a2642d83
...@@ -544,6 +544,12 @@ public: ...@@ -544,6 +544,12 @@ public:
const int group = inpCn / inpGroupCn; const int group = inpCn / inpGroupCn;
std::vector<size_t> kernel_shape = getShape<size_t>(blobs[0]); std::vector<size_t> kernel_shape = getShape<size_t>(blobs[0]);
if (group != 1)
{
kernel_shape[0] /= group;
kernel_shape.insert(kernel_shape.begin(), group);
}
auto ieWeights = std::make_shared<ngraph::op::Constant>(ngraph::element::f32, kernel_shape, blobs[0].data); auto ieWeights = std::make_shared<ngraph::op::Constant>(ngraph::element::f32, kernel_shape, blobs[0].data);
if (fusedWeights) if (fusedWeights)
{ {
...@@ -566,14 +572,12 @@ public: ...@@ -566,14 +572,12 @@ public:
std::shared_ptr<ngraph::Node> conv_node; std::shared_ptr<ngraph::Node> conv_node;
if (group != 1) { if (group != 1) {
conv_node = std::make_shared<ngraph::op::GroupConvolution>( conv_node = std::make_shared<ngraph::op::v1::GroupConvolution>(
ieInpNode, ieWeights, ieInpNode, ieWeights,
ngraph::Strides(strides), ngraph::Strides(strides),
ngraph::Strides(dilations),
ngraph::CoordinateDiff(std::vector<std::ptrdiff_t>(pads_begin.begin(), pads_begin.end())), ngraph::CoordinateDiff(std::vector<std::ptrdiff_t>(pads_begin.begin(), pads_begin.end())),
ngraph::CoordinateDiff(std::vector<std::ptrdiff_t>(pads_end.begin(), pads_end.end())), ngraph::CoordinateDiff(std::vector<std::ptrdiff_t>(pads_end.begin(), pads_end.end())),
ngraph::Strides{}, ngraph::Strides(dilations),
group,
pad_type); pad_type);
} else { } else {
conv_node = std::make_shared<ngraph::op::v1::Convolution>( conv_node = std::make_shared<ngraph::op::v1::Convolution>(
...@@ -2037,37 +2041,29 @@ public: ...@@ -2037,37 +2041,29 @@ public:
Mat newWeights = blobs[0].reshape(1, inpCn); Mat newWeights = blobs[0].reshape(1, inpCn);
transpose(weightsMat, newWeights); transpose(weightsMat, newWeights);
} }
size_t batch = ieInpNode->get_shape()[0];
std::vector<size_t> out_shape = {batch, (size_t)numOutput};
std::vector<size_t> paddings_end; std::vector<size_t> paddings_end;
std::vector<size_t> inpShape = ieInpNode->get_shape();
if (padMode.empty()) if (padMode.empty())
{ {
for (int i = 0; i < pads_end.size(); i++) { for (int i = 0; i < pads_end.size(); i++) {
out_shape.push_back(strides[i] * (inpShape[2 + i] - 1) +
kernel_size[i] - pads_begin[i] - pads_end[i] + adjust_pads[i]);
paddings_end.push_back(pads_end[i] - adjust_pads[i]); paddings_end.push_back(pads_end[i] - adjust_pads[i]);
} }
} }
else if (padMode == "SAME") else if (padMode == "SAME")
{ {
for (int i = 0; i < pads_begin.size(); i++) { for (int i = 0; i < pads_begin.size(); i++) {
out_shape.push_back(strides[i] * (inpShape[2 + i] - 1) + 1 + adjust_pads[i]);
paddings_end.push_back(kernel_size[i] - pads_begin[i] - 1 - adjust_pads[i]); paddings_end.push_back(kernel_size[i] - pads_begin[i] - 1 - adjust_pads[i]);
} }
} else { } else {
paddings_end = pads_end; paddings_end = pads_end;
} }
auto deconv = std::make_shared<ngraph::op::ConvolutionBackpropData>( auto deconv = std::make_shared<ngraph::op::v1::ConvolutionBackpropData>(
ngraph::Shape{out_shape},
ieWeights,
ieInpNode, ieInpNode,
ieWeights,
ngraph::Strides(strides), ngraph::Strides(strides),
ngraph::Strides(dilations),
ngraph::CoordinateDiff(std::vector<std::ptrdiff_t>(pads_begin.begin(), pads_begin.end())), ngraph::CoordinateDiff(std::vector<std::ptrdiff_t>(pads_begin.begin(), pads_begin.end())),
ngraph::CoordinateDiff(std::vector<std::ptrdiff_t>(paddings_end.begin(), paddings_end.end())), ngraph::CoordinateDiff(std::vector<std::ptrdiff_t>(paddings_end.begin(), paddings_end.end())),
(strides.size() == 2 ? ngraph::Strides{1, 1} : ngraph::Strides{1, 1, 1})); ngraph::Strides(dilations));
if (hasBias() || fusedBias) if (hasBias() || fusedBias)
{ {
std::vector<size_t> shape(deconv->get_shape().size(), 1); std::vector<size_t> shape(deconv->get_shape().size(), 1);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment