Commit 22819e78 authored by nikolay.korovaiko's avatar nikolay.korovaiko

conv+bias fusion

parent 5c0a29ee
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
#include "ngraph/ops/broadcast.hpp" #include "ngraph/ops/broadcast.hpp"
#include "ngraph/ops/broadcast.hpp" #include "ngraph/ops/broadcast.hpp"
#include "ngraph/ops/constant.hpp" #include "ngraph/ops/constant.hpp"
#include "ngraph/ops/convolution.hpp"
#include "ngraph/ops/divide.hpp" #include "ngraph/ops/divide.hpp"
#include "ngraph/ops/dot.hpp" #include "ngraph/ops/dot.hpp"
#include "ngraph/ops/multiply.hpp" #include "ngraph/ops/multiply.hpp"
...@@ -38,6 +39,7 @@ ...@@ -38,6 +39,7 @@
#include "ngraph/pattern/matcher.hpp" #include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/any.hpp" #include "ngraph/pattern/op/any.hpp"
#include "ngraph/pattern/op/label.hpp" #include "ngraph/pattern/op/label.hpp"
#include "ngraph/runtime/cpu/ops/conv_bias.hpp"
#include "ngraph/runtime/cpu/ops/matmul_bias.hpp" #include "ngraph/runtime/cpu/ops/matmul_bias.hpp"
static bool init_cblas_arg(std::shared_ptr<ngraph::Node> reshape, static bool init_cblas_arg(std::shared_ptr<ngraph::Node> reshape,
...@@ -270,3 +272,37 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_fprop_bn() ...@@ -270,3 +272,37 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_fprop_bn()
auto m = std::make_shared<ngraph::pattern::Matcher>(add_beta, callback); auto m = std::make_shared<ngraph::pattern::Matcher>(add_beta, callback);
this->add_matcher(m); this->add_matcher(m);
} }
void ngraph::runtime::cpu::pass::CPUFusion::construct_conv_bias()
{
Shape shape{2, 2, 1, 1};
auto data_batch = std::make_shared<pattern::op::Label>(element::f32, shape);
auto filters = std::make_shared<pattern::op::Label>(element::f32, shape);
auto pbias = std::make_shared<pattern::op::Label>(element::f32, Shape{});
auto pbroadcast = std::make_shared<op::Broadcast>(pbias, shape, AxisSet{0, 1, 2, 3});
auto pconv1 = std::make_shared<op::Convolution>(data_batch,
filters,
Strides{1, 1},
Strides{1, 1},
CoordinateDiff{0, 0},
CoordinateDiff{0, 0},
Strides{1, 1});
auto p_conv_bias = pbroadcast + pconv1;
ngraph::pattern::gr_callback_fn callback = [](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for construct_conv_bias against node = "
<< m.match_root()->get_name();
auto pattern_map = m.get_pattern_map();
std::shared_ptr<Node> nn;
auto conv = std::dynamic_pointer_cast<op::Convolution>(m.match_root()->get_input_op(0));
auto bias = m.match_root()->get_input_op(1);
auto conv_bias = std::shared_ptr<Node>(new op::ConvolutionBias(conv, bias));
return conv_bias;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(p_conv_bias, callback);
this->add_matcher(m);
}
...@@ -40,9 +40,11 @@ public: ...@@ -40,9 +40,11 @@ public:
{ {
construct_gemm_pattern(); construct_gemm_pattern();
construct_fprop_bn(); construct_fprop_bn();
construct_conv_bias();
} }
private: private:
void construct_gemm_pattern(); void construct_gemm_pattern();
void construct_fprop_bn(); void construct_fprop_bn();
void construct_conv_bias();
}; };
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment