Commit a855a3ad authored by Nick Korovaiko, committed by adstraw

Make MatMulBias aware of addition commutativity (#713)

* make matmulbias callback aware that addition is commutative
parent d73f92c4
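For context: the matmul-bias fusion matches a MatmulBias (the previously fused Dot) added to a broadcast bias. The callback previously assumed a fixed operand order on that Add, but addition is commutative, so the graph may present either ordering. A minimal sketch of the two equivalent graphs, using the same ops as the new test below (W, x and b are hypothetical Parameters, named here only for illustration):

    // Both graphs should end up fused into MatmulBias; only the Add operand order differs.
    auto dot  = std::make_shared<op::Dot>(W, x);
    auto bias = std::make_shared<op::Broadcast>(b, dot->get_shape(), AxisSet{0});
    auto add_a = dot + bias;  // ordering the callback already handled
    auto add_b = bias + dot;  // commutative ordering this commit adds support for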
......@@ -63,6 +63,27 @@ namespace ngraph
/// \param graph_node is an input graph to be matched against
bool match(const std::shared_ptr<Node>& graph_node);
/// \brief Returns the single input of \p node that is of type \p T.
///        Returns nullptr if no input matches; throws if more than one does.
template <typename T>
static std::shared_ptr<T> unique_match(std::shared_ptr<Node> node)
{
    std::shared_ptr<T> matched;
    for (auto arg : node->get_input_ops())
    {
        if (auto t_casted = std::dynamic_pointer_cast<T>(arg))
        {
            if (matched)
            {
                throw ngraph_error("There's more than one argument of the same type");
            }
            else
            {
                matched = t_casted;
            }
        }
    }
    return matched;
}
bool process_match(gr_callback_fn callback = nullptr);
void reset() {}
......
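unique_match<T> scans a node's inputs and returns the single input that dynamic-casts to T, irrespective of its argument position; it returns nullptr when no input matches and throws when more than one does. A hedged usage sketch (add stands for any matched Add root node; error handling elided):

    // Fetch the lone Broadcast input of an Add, whether it is input 0 or input 1.
    auto m_broadcast = ngraph::pattern::Matcher::unique_match<op::Broadcast>(add);
    if (m_broadcast)
    {
        auto bias = m_broadcast->get_input_op(0); // the pre-broadcast bias vector
    }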
......@@ -138,8 +138,8 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_matmulbias_pattern()
<< m.match_root()->get_name();
auto mpattern = m.match_root(); // root of the match: the Add node
- auto m_matmul = std::dynamic_pointer_cast<op::MatmulBias>(mpattern->get_input_op(0));
- auto m_broadcast = std::dynamic_pointer_cast<op::Broadcast>(mpattern->get_input_op(1));
+ auto m_matmul = ngraph::pattern::Matcher::unique_match<op::MatmulBias>(mpattern);
+ auto m_broadcast = ngraph::pattern::Matcher::unique_match<op::Broadcast>(mpattern);
auto m_bias = m_broadcast->get_input_op(0);
auto pattern_map = m.get_pattern_map();
......
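The switch from positional get_input_op calls to unique_match is what makes the callback order-agnostic: when the graph was built as broadcast + dot, input 0 of the Add is the Broadcast rather than the MatmulBias, so the old positional dynamic_pointer_cast yields nullptr and the fusion is skipped. A sketch of the difference, assuming mpattern is the matched Add node as in the callback above:

    // With `broadcast + dot`, input 0 is the Broadcast, so this cast returns nullptr:
    auto by_position = std::dynamic_pointer_cast<op::MatmulBias>(mpattern->get_input_op(0));
    // unique_match finds the MatmulBias operand by type, in either position:
    auto by_type = ngraph::pattern::Matcher::unique_match<op::MatmulBias>(mpattern);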
......@@ -269,6 +269,27 @@ TEST(cpu_fusion, cpu_fusion_pass_basic)
ASSERT_NE(std::dynamic_pointer_cast<op::MatmulBias>(graph->get_input_op(0)), nullptr);
}
TEST(cpu_fusion, commutative_matmul_bias)
{
    Shape shape{};
    Shape shape_w{2, 4};
    Shape shape_x{4, 1};
    Shape shape_b{1};
    auto A = make_shared<op::Parameter>(element::f32, shape_w);
    auto B = make_shared<op::Parameter>(element::f32, shape_x);
    auto C = make_shared<op::Parameter>(element::f32, shape_b);
    auto dot = make_shared<op::Dot>(A, B);
    auto broadcast = make_shared<op::Broadcast>(C, dot->get_shape(), AxisSet{0});
    auto add = broadcast + dot; // bias on the left: exercises the commutative ordering
    auto graph = make_shared<op::Abs>(add);
    pass::Manager pass_manager;
    pass_manager.register_pass<runtime::cpu::pass::CPUFusion>();
    auto func = make_shared<Function>(graph, op::ParameterVector{A, B, C});
    pass_manager.run_passes(func);
    ASSERT_NE(std::dynamic_pointer_cast<op::MatmulBias>(graph->get_input_op(0)), nullptr);
}
TEST(cpu_fusion, cpu_fusion_pass_matmul_bias)
{
Shape shape_w{2, 4};
......