Commit 38b28c13 authored by nikolay.korovaiko's avatar nikolay.korovaiko

conv+bias tests

parent 22819e78
......@@ -298,7 +298,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_conv_bias()
std::shared_ptr<Node> nn;
auto conv = std::dynamic_pointer_cast<op::Convolution>(m.match_root()->get_input_op(0));
auto bias = m.match_root()->get_input_op(1);
auto bias = m.match_root()->get_input_op(1)->get_input_op(0);
auto conv_bias = std::shared_ptr<Node>(new op::ConvolutionBias(conv, bias));
return conv_bias;
};
......
......@@ -36,6 +36,7 @@
#include "ngraph/json.hpp"
#include "ngraph/pass/reshape_elimination.hpp"
#include "ngraph/pass/visualize_tree.hpp"
#include "ngraph/runtime/cpu/ops/conv_bias.hpp"
#include "ngraph/runtime/cpu/ops/matmul_bias.hpp"
#include "ngraph/runtime/cpu/pass/cpu_fusion.hpp"
#include "ngraph/serializer.hpp"
......@@ -288,3 +289,17 @@ TEST(cpu_fusion, fuse_fprop_bn)
size_t ccg = count_ops_of_type<op::BatchNorm>(func);
ASSERT_EQ(ccg, 1);
}
TEST(cpu_fusion, fuse_conv_bias)
{
pass::Manager pass_manager;
pass_manager.register_pass<ngraph::pass::ReshapeElimination>();
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>();
const string json_path = file_util::path_join(SERIALIZED_ZOO, "conv_bias.json");
const string json_string = file_util::read_file_to_string(json_path);
stringstream ss(json_string);
shared_ptr<Function> func = ngraph::deserialize(ss);
pass_manager.run_passes(func);
size_t cb = count_ops_of_type<op::ConvolutionBias>(func);
ASSERT_GT(cb, 0);
}
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment