Commit 2ca1528e authored by Jaikrishnan Menon's avatar Jaikrishnan Menon

Merge fixes

parent 1fbc72d6
...@@ -97,7 +97,7 @@ static std::vector<T> apply_permutation(std::vector<T> input, ngraph::AxisVector ...@@ -97,7 +97,7 @@ static std::vector<T> apply_permutation(std::vector<T> input, ngraph::AxisVector
return output; return output;
} }
void ngraph::pass::CPUFusion::construct_gemm_pattern() void ngraph::runtime::cpu::pass::CPUFusion::construct_gemm_pattern()
{ {
auto shape_w = Shape{2, 4}; auto shape_w = Shape{2, 4};
auto shape_x = Shape{4, 1}; auto shape_x = Shape{4, 1};
......
...@@ -20,13 +20,19 @@ ...@@ -20,13 +20,19 @@
namespace ngraph namespace ngraph
{ {
namespace runtime
{
namespace cpu
{
namespace pass namespace pass
{ {
class CPUFusion; class CPUFusion;
} }
}
}
} }
class ngraph::pass::CPUFusion : public ngraph::pass::GraphRewrite class ngraph::runtime::cpu::pass::CPUFusion : public ngraph::pass::GraphRewrite
{ {
public: public:
CPUFusion() CPUFusion()
......
...@@ -144,7 +144,7 @@ TEST(cpu_fusion, cpu_fusion_pass_basic) ...@@ -144,7 +144,7 @@ TEST(cpu_fusion, cpu_fusion_pass_basic)
auto add = dot + broadcast; auto add = dot + broadcast;
auto graph = make_shared<op::Abs>(add); auto graph = make_shared<op::Abs>(add);
pass::Manager pass_manager; pass::Manager pass_manager;
pass_manager.register_pass<pass::CPUFusion>(); pass_manager.register_pass<runtime::cpu::pass::CPUFusion>();
auto func = make_shared<Function>(graph, op::Parameters{A, B, C}); auto func = make_shared<Function>(graph, op::Parameters{A, B, C});
pass_manager.run_passes(func); pass_manager.run_passes(func);
ASSERT_NE(std::dynamic_pointer_cast<op::MatmulBias>(graph->get_input_op(0)), nullptr); ASSERT_NE(std::dynamic_pointer_cast<op::MatmulBias>(graph->get_input_op(0)), nullptr);
...@@ -157,7 +157,7 @@ TEST(cpu_fusion, gemm_mlp) ...@@ -157,7 +157,7 @@ TEST(cpu_fusion, gemm_mlp)
stringstream ss(json_string); stringstream ss(json_string);
shared_ptr<Function> func = ngraph::deserialize(ss); shared_ptr<Function> func = ngraph::deserialize(ss);
pass::Manager pass_manager; pass::Manager pass_manager;
pass_manager.register_pass<pass::CPUFusion>(); pass_manager.register_pass<runtime::cpu::pass::CPUFusion>();
pass_manager.run_passes(func); pass_manager.run_passes(func);
size_t ccg = count_ops_of_type<op::MatmulBias>(func); size_t ccg = count_ops_of_type<op::MatmulBias>(func);
ASSERT_EQ(ccg, 3); ASSERT_EQ(ccg, 3);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment