Commit 466c4d43 authored by Nagy Mostafa's avatar Nagy Mostafa Committed by Sang Ik Lee

Fix compose/decompose of group_conv for CPU backend. (#3947)

* disable group conv if data dilation is not 1s

* Enable group conv decomposition if data or window dilation is enabled. Re-enable a unit-test

* PR fix

* PR fix

* Fix fails
parent c546282d
......@@ -62,8 +62,8 @@ static bool is_trivial_convolution(std::shared_ptr<op::Convolution> conv)
{
Strides stride_1{1, 1};
CoordinateDiff pad_0{0, 0};
return conv->get_window_dilation_strides() == stride_1 ||
conv->get_data_dilation_strides() == stride_1 || conv->get_padding_above() == pad_0 ||
return conv->get_window_dilation_strides() == stride_1 &&
conv->get_data_dilation_strides() == stride_1 && conv->get_padding_above() == pad_0 &&
conv->get_padding_below() == pad_0;
}
......@@ -115,6 +115,7 @@ std::shared_ptr<Node> fuse_group_convolution(const std::shared_ptr<Node>& n)
NGRAPH_DEBUG << "convolution data's rank isn't equal to 4";
return {nullptr};
}
if (!is_trivial_convolution(sconv))
{
NGRAPH_DEBUG << arg->get_name() << " isn't trivial convolution";
......
......@@ -1227,6 +1227,11 @@ void runtime::cpu::CPU_ExternalFunction::register_common_passes(
return false;
#endif
}
// GroupConvolution is only supported with MKLDNN
else if (auto conv = as_type<ngraph::op::GroupConvolution>(const_cast<Node*>(&node)))
{
return mkldnn_utils::can_use_mkldnn_conv<ngraph::op::GroupConvolution>(conv);
}
if (dex)
{
......
......@@ -11,9 +11,6 @@ max_3d_to_scalar_int32
send_recv
send_recv_ring
# param not supported in CPU backend
group_conv_data_dilation
# axes input param not supported
lrn_across_h
lrn_across_hw
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment