Commit 466c4d43 authored by Nagy Mostafa's avatar Nagy Mostafa Committed by Sang Ik Lee

Fix compose/decompose of group_conv for CPU backend. (#3947)

* disable group conv if data dilation is not 1s

* Enable group conv decomposition if data or window dilation is enabled. Re-enable a unit-test

* PR fix

* PR fix

* Fix fails
parent c546282d
...@@ -62,8 +62,8 @@ static bool is_trivial_convolution(std::shared_ptr<op::Convolution> conv) ...@@ -62,8 +62,8 @@ static bool is_trivial_convolution(std::shared_ptr<op::Convolution> conv)
{ {
Strides stride_1{1, 1}; Strides stride_1{1, 1};
CoordinateDiff pad_0{0, 0}; CoordinateDiff pad_0{0, 0};
return conv->get_window_dilation_strides() == stride_1 || return conv->get_window_dilation_strides() == stride_1 &&
conv->get_data_dilation_strides() == stride_1 || conv->get_padding_above() == pad_0 || conv->get_data_dilation_strides() == stride_1 && conv->get_padding_above() == pad_0 &&
conv->get_padding_below() == pad_0; conv->get_padding_below() == pad_0;
} }
...@@ -115,6 +115,7 @@ std::shared_ptr<Node> fuse_group_convolution(const std::shared_ptr<Node>& n) ...@@ -115,6 +115,7 @@ std::shared_ptr<Node> fuse_group_convolution(const std::shared_ptr<Node>& n)
NGRAPH_DEBUG << "convolution data's rank isn't equal to 4"; NGRAPH_DEBUG << "convolution data's rank isn't equal to 4";
return {nullptr}; return {nullptr};
} }
if (!is_trivial_convolution(sconv)) if (!is_trivial_convolution(sconv))
{ {
NGRAPH_DEBUG << arg->get_name() << " isn't trivial convolution"; NGRAPH_DEBUG << arg->get_name() << " isn't trivial convolution";
......
...@@ -1227,6 +1227,11 @@ void runtime::cpu::CPU_ExternalFunction::register_common_passes( ...@@ -1227,6 +1227,11 @@ void runtime::cpu::CPU_ExternalFunction::register_common_passes(
return false; return false;
#endif #endif
} }
// GroupConvolution is only supported with MKLDNN
else if (auto conv = as_type<ngraph::op::GroupConvolution>(const_cast<Node*>(&node)))
{
return mkldnn_utils::can_use_mkldnn_conv<ngraph::op::GroupConvolution>(conv);
}
if (dex) if (dex)
{ {
......
...@@ -11,9 +11,6 @@ max_3d_to_scalar_int32 ...@@ -11,9 +11,6 @@ max_3d_to_scalar_int32
send_recv send_recv
send_recv_ring send_recv_ring
# param not supported in CPU backend
group_conv_data_dilation
# axes input param not supported # axes input param not supported
lrn_across_h lrn_across_h
lrn_across_hw lrn_across_hw
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment