Commit 23c0c2fa authored by Amy Zhuang's avatar Amy Zhuang Committed by Scott Cyphers

Add more checks in horizontal fusion. (#2088)

* Add more checks in horizontal fusion.

* Address PR feedback: use a helper function.
parent 9da2c78c
......@@ -30,6 +30,55 @@
using namespace ngraph;
using namespace std;
bool has_same_attributes(const std::shared_ptr<ngraph::op::ConvolutionBias> conv1,
const std::shared_ptr<ngraph::op::ConvolutionBias> conv2)
{
auto conv1_shape = conv1->get_input_shape(1);
auto conv2_shape = conv2->get_input_shape(1);
if (conv1_shape[2] != conv2_shape[2] || conv1_shape[3] != conv2_shape[3])
{
NGRAPH_DEBUG << "conv_horizontal_fusion: skip conv node with different filter shape\n";
return false;
}
if (conv1->get_window_movement_strides() != conv2->get_window_movement_strides())
{
NGRAPH_DEBUG << "conv_horizontal_fusion: skip conv node with different window "
"movement strides\n";
return false;
}
if (conv1->get_window_dilation_strides() != conv2->get_window_dilation_strides())
{
NGRAPH_DEBUG << "conv_horizontal_fusion: skip conv node with different window "
"dilation strides\n";
return false;
}
if (conv1->get_padding_below() != conv2->get_padding_below())
{
NGRAPH_DEBUG << "conv_horizontal_fusion: skip conv node with different padding "
"below\n";
return false;
}
if (conv1->get_padding_above() != conv2->get_padding_above())
{
NGRAPH_DEBUG << "conv_horizontal_fusion: skip conv node with different padding "
"above\n";
return false;
}
if (conv1->get_data_dilation_strides() != conv2->get_data_dilation_strides())
{
NGRAPH_DEBUG << "conv_horizontal_fusion: skip conv node with different data "
"dilation strides\n";
return false;
}
if (conv1->with_relu() != conv2->with_relu())
{
NGRAPH_DEBUG << "conv_horizontal_fusion: skip conv node with different relu "
"status\n";
return false;
}
return true;
};
void ngraph::runtime::cpu::pass::CPUHorizontalFusion::cpu_conv_horizontal_fusion()
{
auto has_multiple_users = [](std::shared_ptr<Node> n) {
......@@ -65,10 +114,6 @@ void ngraph::runtime::cpu::pass::CPUHorizontalFusion::cpu_conv_horizontal_fusion
return false;
}
auto m_filters_shape = conv_bias_root->get_input_shape(1);
auto f_h = m_filters_shape[2];
auto f_w = m_filters_shape[3];
// get weights and bias from each CBR and create Concat nodes
std::vector<std::shared_ptr<Node>> weights_nodes;
std::vector<std::shared_ptr<Node>> bias_nodes;
......@@ -92,13 +137,13 @@ void ngraph::runtime::cpu::pass::CPUHorizontalFusion::cpu_conv_horizontal_fusion
<< u->get_name() << "\n";
continue;
}
auto u_filters_shape = u->get_input_shape(1);
if (u_filters_shape[2] != f_h || u_filters_shape[3] != f_w)
auto conv_u = std::static_pointer_cast<op::ConvolutionBias>(u);
if (!has_same_attributes(conv_u, conv_bias_root))
{
NGRAPH_DEBUG
<< "conv_horizontal_fusion: skip conv node with different filter shape\n";
continue;
}
weights_nodes.push_back(u->get_argument(1));
bias_nodes.push_back(u->get_argument(2));
conv_bias_nodes.push_back(u);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment