Commit a58d3bc2 authored by gaurides's avatar gaurides Committed by Scott Cyphers

Fix perf regression in some models (#3260)

* Fix perf regression in vgg16

* Make switch generic

* Remove unused variables

* Review comments

* Remove unused function parameters

* trivial commit to restart CI
parent 5465677f
......@@ -1194,6 +1194,7 @@ void runtime::cpu::CPU_ExternalFunction::register_common_passes(
REGISTER_KNOBBED_PASS(RecurrentReshapeElimination, false, ngraph::pass);
REGISTER_KNOBBED_PASS_WITH_ARGS(
CoreFusion, true, ngraph::pass, ngraph::pass::FusionType::ALL_FUSIONS);
REGISTER_KNOBBED_PASS(CPUPreFusion, true, runtime::cpu::pass);
// Disable CPUFusion if MLIR is enabled to preserve core ops.
if (std::getenv("NGRAPH_MLIR") == nullptr)
......
......@@ -520,6 +520,70 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_conv_bias_bprop()
this->add_matcher(m, callback);
}
static bool switch_nodes(std::shared_ptr<ngraph::Node> node1,
std::shared_ptr<ngraph::Node> node2,
size_t source_input_index = 0)
{
// check if node1 has only 1 argument, not sure how it will work with >1 args
if (node1->inputs().size() > 1)
{
NGRAPH_DEBUG << "Cannot switch. More than 1 inputs to this node\n";
return false;
}
if (node1->get_users().size() > 1)
{
NGRAPH_DEBUG << "Cannot switch. More than 1 user of this node\n";
return false;
}
if (node1->outputs().size() > 1)
{
NGRAPH_DEBUG << "Cannot switch. More than 1 output of this node\n";
return false;
}
if (node2->outputs().size() > 1)
{
NGRAPH_DEBUG << "Cannot switch. More than 1 output of this node\n";
return false;
}
auto target_inputs = node2->output(0).get_target_inputs();
// Remove the control_dependency, which shouldn't be there, but in case
// Other control_dependencies will work out fine even after switch.
node2->remove_control_dependency(node1);
// actual switch happening after this
auto arg = node1->get_argument(source_input_index);
node2->input(0).replace_source_output(arg);
node1->input(0).replace_source_output(node2->output(0));
// used implementation ref from replace_node
for (auto& input : target_inputs)
{
input.replace_source_output(node1->output(0));
}
return true;
}
void ngraph::runtime::cpu::pass::CPUPreFusion::construct_maxpool_relu_switch()
{
auto input_shape = Shape{1, 2, 2, 2};
auto input = std::make_shared<pattern::op::Label>(element::f32, input_shape);
Shape window_shape{2, 2};
auto max_pool = std::make_shared<ngraph::op::MaxPool>(input, window_shape);
auto prelu = std::make_shared<ngraph::op::Relu>(max_pool);
auto callback = [input](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for construct_maxpool_relu_switch against node = "
<< m.get_match_root()->get_name();
return switch_nodes(m.get_match_root()->get_argument(0), m.get_match_root());
};
auto m = std::make_shared<ngraph::pattern::Matcher>(prelu, "CPUPreFusion.MaxpoolReluSwitch");
this->add_matcher(m, callback);
}
void ngraph::runtime::cpu::pass::CPUFusion::construct_batch_norm_relu()
{
auto input_shape = Shape{1, 2, 2, 2};
......
......@@ -27,6 +27,7 @@ namespace ngraph
{
namespace pass
{
class CPUPreFusion;
class CPUFusion;
class CPUQuantFusion;
}
......@@ -34,6 +35,19 @@ namespace ngraph
}
}
class CPU_BACKEND_API ngraph::runtime::cpu::pass::CPUPreFusion : public ngraph::pass::GraphRewrite
{
public:
CPUPreFusion()
: GraphRewrite()
{
construct_maxpool_relu_switch();
}
private:
void construct_maxpool_relu_switch();
};
class CPU_BACKEND_API ngraph::runtime::cpu::pass::CPUFusion : public ngraph::pass::GraphRewrite
{
public:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment