Commit b6bc86bf authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Scott Cyphers

TF-flavoured group convolution (#1182)

* tf group convolution

* change perms
parent 2fc0bbb4
...@@ -254,6 +254,8 @@ std::shared_ptr<Node> fuse_group_convolution(const std::shared_ptr<Node>& n) ...@@ -254,6 +254,8 @@ std::shared_ptr<Node> fuse_group_convolution(const std::shared_ptr<Node>& n)
auto concat = std::dynamic_pointer_cast<op::Concat>(n); auto concat = std::dynamic_pointer_cast<op::Concat>(n);
std::shared_ptr<op::Convolution> sconv; std::shared_ptr<op::Convolution> sconv;
NodeVector slices;
const size_t CHANNEL = 1; const size_t CHANNEL = 1;
if (concat->get_concatenation_axis() != CHANNEL) if (concat->get_concatenation_axis() != CHANNEL)
{ {
...@@ -296,11 +298,17 @@ std::shared_ptr<Node> fuse_group_convolution(const std::shared_ptr<Node>& n) ...@@ -296,11 +298,17 @@ std::shared_ptr<Node> fuse_group_convolution(const std::shared_ptr<Node>& n)
auto slice = pattern_map[slice_weights_label]; auto slice = pattern_map[slice_weights_label];
if (weights->get_shape().at(IC) != slice->get_shape().at(IC)) if (weights->get_shape().at(IC) != slice->get_shape().at(IC))
{ {
NGRAPH_DEBUG << "slices are done on the wrong axis (IC)"; slices.push_back(slice);
return {nullptr};
} }
} }
//TF-flavoured group convolution needs channels re-arranged
const size_t CONCAT_AXIS_OC = 0;
if (!slices.empty())
{
weights = std::make_shared<op::Concat>(slices, CONCAT_AXIS_OC);
}
auto new_conv = std::make_shared<op::GroupConvolution>(data, auto new_conv = std::make_shared<op::GroupConvolution>(data,
weights, weights,
sconv->get_window_movement_strides(), sconv->get_window_movement_strides(),
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment