Commit b6bc86bf authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Scott Cyphers

TF-flavoured group convolution (#1182)

* tf group convolution

* change perms
parent 2fc0bbb4
......@@ -254,6 +254,8 @@ std::shared_ptr<Node> fuse_group_convolution(const std::shared_ptr<Node>& n)
auto concat = std::dynamic_pointer_cast<op::Concat>(n);
std::shared_ptr<op::Convolution> sconv;
NodeVector slices;
const size_t CHANNEL = 1;
if (concat->get_concatenation_axis() != CHANNEL)
{
......@@ -296,9 +298,15 @@ std::shared_ptr<Node> fuse_group_convolution(const std::shared_ptr<Node>& n)
auto slice = pattern_map[slice_weights_label];
if (weights->get_shape().at(IC) != slice->get_shape().at(IC))
{
NGRAPH_DEBUG << "slices are done on the wrong axis (IC)";
return {nullptr};
slices.push_back(slice);
}
}
//TF-flavoured group convolution needs channels re-arranged
const size_t CONCAT_AXIS_OC = 0;
if (!slices.empty())
{
weights = std::make_shared<op::Concat>(slices, CONCAT_AXIS_OC);
}
auto new_conv = std::make_shared<op::GroupConvolution>(data,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment