Unverified Commit 09adba0c authored by Nick Korovaiko's avatar Nick Korovaiko Committed by GitHub

workaround for depthwise convolution (#1178)

* workaround for depthwise convolution

* fixe error msg
parent 1574031c
......@@ -240,7 +240,10 @@ std::shared_ptr<Node> fuse_group_convolution(const std::shared_ptr<Node>& n)
data_label, Coordinate{0, 0, 0}, Coordinate{1, 2, 9}, Strides{1, 1, 1});
auto slice_weights = std::make_shared<op::Slice>(
weights_label, Coordinate{0, 0, 0}, Coordinate{2, 2, 3}, Strides{1, 1, 1});
auto conv = std::make_shared<op::Convolution>(slice_data, slice_weights);
auto slice_weights_label =
std::make_shared<pattern::op::Label>(slice_weights, nullptr, NodeVector{slice_weights});
auto conv = std::make_shared<op::Convolution>(slice_data, slice_weights_label);
auto matcher = std::make_shared<pattern::Matcher>(conv, nullptr);
NGRAPH_DEBUG << "In simplify_concat (group convolution) for " << n->get_name();
......@@ -288,6 +291,14 @@ std::shared_ptr<Node> fuse_group_convolution(const std::shared_ptr<Node>& n)
NGRAPH_DEBUG << "data or weights nodes are different among slices";
return {nullptr};
}
const size_t IC = 1;
auto slice = pattern_map[slice_weights_label];
if (weights->get_shape().at(IC) != slice->get_shape().at(IC))
{
NGRAPH_DEBUG << "slices are done on the wrong axis (IC)";
return {nullptr};
}
}
auto new_conv = std::make_shared<op::GroupConvolution>(data,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment