Commit 6c676d2d authored by Jaikrishnan Menon's avatar Jaikrishnan Menon

CPU: Merge fixes

parent 2659d5be
......@@ -226,20 +226,13 @@ void runtime::cpu::CPU_ExternalFunction::compile()
string function_name = m_function->get_name();
<<<<<<< HEAD
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>();
pass_manager.register_pass<runtime::cpu::pass::CPULayout>();
pass_manager.register_pass<ngraph::pass::Liveness>();
pass_manager.register_pass<ngraph::pass::MemoryLayout>(MemoryPoolAlignment);
=======
pass::Manager pass_manager;
// For now, just make everyone row-major.
pass_manager.register_pass<pass::CPUFusion>();
pass_manager.register_pass<pass::AssignLayout<descriptor::layout::DenseTensorViewLayout>>();
pass_manager.register_pass<pass::Liveness>();
pass_manager.register_pass<pass::MemoryLayout>(64);
>>>>>>> master
pass_manager.run_passes(m_function);
codegen::CodeWriter writer;
......
......@@ -95,7 +95,7 @@ static std::vector<T> apply_permutation(std::vector<T> input, ngraph::AxisVector
return output;
}
void ngraph::pass::CPUFusion::construct_gemm_pattern()
void ngraph::runtime::cpu::pass::CPUFusion::construct_gemm_pattern()
{
auto shape_w = Shape{2, 4};
auto shape_x = Shape{4, 1};
......
......@@ -18,13 +18,19 @@
namespace ngraph
{
namespace runtime
{
namespace cpu
{
namespace pass
{
class CPUFusion;
}
}
}
}
class ngraph::pass::CPUFusion : public ngraph::pass::GraphRewrite
class ngraph::runtime::cpu::pass::CPUFusion : public ngraph::pass::GraphRewrite
{
public:
CPUFusion()
......
......@@ -142,7 +142,7 @@ TEST(cpu_fusion, cpu_fusion_pass_basic)
auto add = dot + broadcast;
auto graph = make_shared<op::Abs>(add);
pass::Manager pass_manager;
pass_manager.register_pass<pass::CPUFusion>();
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>();
auto func = make_shared<Function>(graph, op::Parameters{A, B, C});
pass_manager.run_passes(func);
ASSERT_NE(std::dynamic_pointer_cast<op::MatmulBias>(graph->get_input_op(0)), nullptr);
......@@ -155,7 +155,7 @@ TEST(cpu_fusion, gemm_mlp)
stringstream ss(json_string);
shared_ptr<Function> func = ngraph::deserialize(ss);
pass::Manager pass_manager;
pass_manager.register_pass<pass::CPUFusion>();
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>();
pass_manager.run_passes(func);
size_t ccg = count_ops_of_type<op::MatmulBias>(func);
ASSERT_EQ(ccg, 3);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment