Commit 6c676d2d authored by Jaikrishnan Menon's avatar Jaikrishnan Menon

CPU: Merge fixes

parent 2659d5be
...@@ -226,20 +226,13 @@ void runtime::cpu::CPU_ExternalFunction::compile() ...@@ -226,20 +226,13 @@ void runtime::cpu::CPU_ExternalFunction::compile()
string function_name = m_function->get_name(); string function_name = m_function->get_name();
<<<<<<< HEAD
ngraph::pass::Manager pass_manager; ngraph::pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>();
pass_manager.register_pass<runtime::cpu::pass::CPULayout>(); pass_manager.register_pass<runtime::cpu::pass::CPULayout>();
pass_manager.register_pass<ngraph::pass::Liveness>(); pass_manager.register_pass<ngraph::pass::Liveness>();
pass_manager.register_pass<ngraph::pass::MemoryLayout>(MemoryPoolAlignment); pass_manager.register_pass<ngraph::pass::MemoryLayout>(MemoryPoolAlignment);
=======
pass::Manager pass_manager;
// For now, just make everyone row-major.
pass_manager.register_pass<pass::CPUFusion>();
pass_manager.register_pass<pass::AssignLayout<descriptor::layout::DenseTensorViewLayout>>();
pass_manager.register_pass<pass::Liveness>();
pass_manager.register_pass<pass::MemoryLayout>(64);
>>>>>>> master
pass_manager.run_passes(m_function); pass_manager.run_passes(m_function);
codegen::CodeWriter writer; codegen::CodeWriter writer;
......
...@@ -95,7 +95,7 @@ static std::vector<T> apply_permutation(std::vector<T> input, ngraph::AxisVector ...@@ -95,7 +95,7 @@ static std::vector<T> apply_permutation(std::vector<T> input, ngraph::AxisVector
return output; return output;
} }
void ngraph::pass::CPUFusion::construct_gemm_pattern() void ngraph::runtime::cpu::pass::CPUFusion::construct_gemm_pattern()
{ {
auto shape_w = Shape{2, 4}; auto shape_w = Shape{2, 4};
auto shape_x = Shape{4, 1}; auto shape_x = Shape{4, 1};
......
...@@ -18,13 +18,19 @@ ...@@ -18,13 +18,19 @@
namespace ngraph namespace ngraph
{ {
namespace runtime
{
namespace cpu
{
namespace pass namespace pass
{ {
class CPUFusion; class CPUFusion;
} }
}
}
} }
class ngraph::pass::CPUFusion : public ngraph::pass::GraphRewrite class ngraph::runtime::cpu::pass::CPUFusion : public ngraph::pass::GraphRewrite
{ {
public: public:
CPUFusion() CPUFusion()
......
...@@ -142,7 +142,7 @@ TEST(cpu_fusion, cpu_fusion_pass_basic) ...@@ -142,7 +142,7 @@ TEST(cpu_fusion, cpu_fusion_pass_basic)
auto add = dot + broadcast; auto add = dot + broadcast;
auto graph = make_shared<op::Abs>(add); auto graph = make_shared<op::Abs>(add);
pass::Manager pass_manager; pass::Manager pass_manager;
pass_manager.register_pass<pass::CPUFusion>(); pass_manager.register_pass<runtime::cpu::pass::CPUFusion>();
auto func = make_shared<Function>(graph, op::Parameters{A, B, C}); auto func = make_shared<Function>(graph, op::Parameters{A, B, C});
pass_manager.run_passes(func); pass_manager.run_passes(func);
ASSERT_NE(std::dynamic_pointer_cast<op::MatmulBias>(graph->get_input_op(0)), nullptr); ASSERT_NE(std::dynamic_pointer_cast<op::MatmulBias>(graph->get_input_op(0)), nullptr);
...@@ -155,7 +155,7 @@ TEST(cpu_fusion, gemm_mlp) ...@@ -155,7 +155,7 @@ TEST(cpu_fusion, gemm_mlp)
stringstream ss(json_string); stringstream ss(json_string);
shared_ptr<Function> func = ngraph::deserialize(ss); shared_ptr<Function> func = ngraph::deserialize(ss);
pass::Manager pass_manager; pass::Manager pass_manager;
pass_manager.register_pass<pass::CPUFusion>(); pass_manager.register_pass<runtime::cpu::pass::CPUFusion>();
pass_manager.run_passes(func); pass_manager.run_passes(func);
size_t ccg = count_ops_of_type<op::MatmulBias>(func); size_t ccg = count_ops_of_type<op::MatmulBias>(func);
ASSERT_EQ(ccg, 3); ASSERT_EQ(ccg, 3);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment