Commit 2ca1528e authored by Jaikrishnan Menon's avatar Jaikrishnan Menon

Merge fixes

parent 1fbc72d6
......@@ -97,7 +97,7 @@ static std::vector<T> apply_permutation(std::vector<T> input, ngraph::AxisVector
return output;
}
void ngraph::pass::CPUFusion::construct_gemm_pattern()
void ngraph::runtime::cpu::pass::CPUFusion::construct_gemm_pattern()
{
auto shape_w = Shape{2, 4};
auto shape_x = Shape{4, 1};
......
......@@ -20,13 +20,19 @@
namespace ngraph
{
namespace pass
namespace runtime
{
class CPUFusion;
namespace cpu
{
namespace pass
{
class CPUFusion;
}
}
}
}
class ngraph::pass::CPUFusion : public ngraph::pass::GraphRewrite
class ngraph::runtime::cpu::pass::CPUFusion : public ngraph::pass::GraphRewrite
{
public:
CPUFusion()
......
......@@ -144,7 +144,7 @@ TEST(cpu_fusion, cpu_fusion_pass_basic)
auto add = dot + broadcast;
auto graph = make_shared<op::Abs>(add);
pass::Manager pass_manager;
pass_manager.register_pass<pass::CPUFusion>();
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>();
auto func = make_shared<Function>(graph, op::Parameters{A, B, C});
pass_manager.run_passes(func);
ASSERT_NE(std::dynamic_pointer_cast<op::MatmulBias>(graph->get_input_op(0)), nullptr);
......@@ -157,7 +157,7 @@ TEST(cpu_fusion, gemm_mlp)
stringstream ss(json_string);
shared_ptr<Function> func = ngraph::deserialize(ss);
pass::Manager pass_manager;
pass_manager.register_pass<pass::CPUFusion>();
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>();
pass_manager.run_passes(func);
size_t ccg = count_ops_of_type<op::MatmulBias>(func);
ASSERT_EQ(ccg, 3);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment