Commit 0c5d5a65 authored by Jaikrishnan Menon's avatar Jaikrishnan Menon

CPU: Enable more concatenation tests

parent 1e73a52c
...@@ -195,7 +195,7 @@ void Emitter::EMITTER_DECL(EmitConcat) ...@@ -195,7 +195,7 @@ void Emitter::EMITTER_DECL(EmitConcat)
to_string(inputs[i].get_tensor_view_layout()->get_shape().at(0)) + ") << " to_string(inputs[i].get_tensor_view_layout()->get_shape().at(0)) + ") << "
"EigenVector<" + element_type_names[TI(result_element_type)] + ">(call_frame->" "EigenVector<" + element_type_names[TI(result_element_type)] + ">(call_frame->"
"get_tensor_view_data<" + element_type_names[TI(result_element_type)] + ">(" + "get_tensor_view_data<" + element_type_names[TI(result_element_type)] + ">(" +
to_string(outputs[0].get_index()) + "), " to_string(inputs[i].get_index()) + "), "
EIGEN_VECTOR_FORMAT(inputs[i].get_layout<DenseTensorViewLayout>()->get_size()) ");\n"; EIGEN_VECTOR_FORMAT(inputs[i].get_layout<DenseTensorViewLayout>()->get_size()) ");\n";
concat_pos += inputs[i].get_tensor_view_layout()->get_shape().at(0); concat_pos += inputs[i].get_tensor_view_layout()->get_shape().at(0);
} }
...@@ -204,15 +204,6 @@ void Emitter::EMITTER_DECL(EmitConcat) ...@@ -204,15 +204,6 @@ void Emitter::EMITTER_DECL(EmitConcat)
} }
else if (result_shape.size() == 2) else if (result_shape.size() == 2)
{ {
/*
PUSH_POLYMORPHIC_INSTRUCTION(
result_element_type,
"Concat has unhandled element type",
eigen::ConcatMatrixInstruction,
in,
(dynamic_cast<const op::Concat*>(n))->get_concatenation_axis(),
out[0]);
*/
auto out_layout = outputs[0].get_layout<DenseTensorViewLayout>(); auto out_layout = outputs[0].get_layout<DenseTensorViewLayout>();
auto axis = (dynamic_cast<const op::Concat*>(n))->get_concatenation_axis(); auto axis = (dynamic_cast<const op::Concat*>(n))->get_concatenation_axis();
......
...@@ -293,8 +293,7 @@ TEST(cpu, concat_matrix_colwise) ...@@ -293,8 +293,7 @@ TEST(cpu, concat_matrix_colwise)
result->get_vector()); result->get_vector());
} }
/* TEST(cpu, concat_matrix_rowwise)
TEST(execute, concat_matrix_rowwise)
{ {
auto shape_a = Shape{2, 2}; auto shape_a = Shape{2, 2};
auto A = make_shared<op::Parameter>(element::Float32::element_type(), shape_a); auto A = make_shared<op::Parameter>(element::Float32::element_type(), shape_a);
...@@ -307,7 +306,7 @@ TEST(execute, concat_matrix_rowwise) ...@@ -307,7 +306,7 @@ TEST(execute, concat_matrix_rowwise)
auto f = make_shared<Function>( auto f = make_shared<Function>(
make_shared<op::Concat>(Nodes{A, B, C}, 0), rt, op::Parameters{A, B, C}); make_shared<op::Concat>(Nodes{A, B, C}, 0), rt, op::Parameters{A, B, C});
auto manager = runtime::Manager::get("NGVM"); auto manager = runtime::Manager::get("CPU");
auto external = manager->compile(f); auto external = manager->compile(f);
auto backend = manager->allocate_backend(); auto backend = manager->allocate_backend();
auto cf = backend->make_call_frame(external); auto cf = backend->make_call_frame(external);
...@@ -326,7 +325,7 @@ TEST(execute, concat_matrix_rowwise) ...@@ -326,7 +325,7 @@ TEST(execute, concat_matrix_rowwise)
result->get_vector()); result->get_vector());
} }
TEST(execute, concat_matrix_int64) TEST(cpu, concat_matrix_int64)
{ {
auto shape_a = Shape{2, 2}; auto shape_a = Shape{2, 2};
auto A = make_shared<op::Parameter>(element::Int64::element_type(), shape_a); auto A = make_shared<op::Parameter>(element::Int64::element_type(), shape_a);
...@@ -339,7 +338,7 @@ TEST(execute, concat_matrix_int64) ...@@ -339,7 +338,7 @@ TEST(execute, concat_matrix_int64)
auto f = make_shared<Function>( auto f = make_shared<Function>(
make_shared<op::Concat>(Nodes{A, B, C}, 0), rt, op::Parameters{A, B, C}); make_shared<op::Concat>(Nodes{A, B, C}, 0), rt, op::Parameters{A, B, C});
auto manager = runtime::Manager::get("NGVM"); auto manager = runtime::Manager::get("CPU");
auto external = manager->compile(f); auto external = manager->compile(f);
auto backend = manager->allocate_backend(); auto backend = manager->allocate_backend();
auto cf = backend->make_call_frame(external); auto cf = backend->make_call_frame(external);
...@@ -358,7 +357,7 @@ TEST(execute, concat_matrix_int64) ...@@ -358,7 +357,7 @@ TEST(execute, concat_matrix_int64)
result->get_vector()); result->get_vector());
} }
TEST(execute, concat_vector) TEST(cpu, concat_vector)
{ {
auto shape_a = Shape{4}; auto shape_a = Shape{4};
auto A = make_shared<op::Parameter>(element::Float32::element_type(), shape_a); auto A = make_shared<op::Parameter>(element::Float32::element_type(), shape_a);
...@@ -371,7 +370,7 @@ TEST(execute, concat_vector) ...@@ -371,7 +370,7 @@ TEST(execute, concat_vector)
auto f = make_shared<Function>( auto f = make_shared<Function>(
make_shared<op::Concat>(Nodes{A, B, C}, 0), rt, op::Parameters{A, B, C}); make_shared<op::Concat>(Nodes{A, B, C}, 0), rt, op::Parameters{A, B, C});
auto manager = runtime::Manager::get("NGVM"); auto manager = runtime::Manager::get("CPU");
auto external = manager->compile(f); auto external = manager->compile(f);
auto backend = manager->allocate_backend(); auto backend = manager->allocate_backend();
auto cf = backend->make_call_frame(external); auto cf = backend->make_call_frame(external);
...@@ -389,6 +388,7 @@ TEST(execute, concat_vector) ...@@ -389,6 +388,7 @@ TEST(execute, concat_vector)
ASSERT_EQ((vector<float>{2, 4, 8, 16, 1, 2, 4, 8, 16, 32, 18, 19}), result->get_vector()); ASSERT_EQ((vector<float>{2, 4, 8, 16, 1, 2, 4, 8, 16, 32, 18, 19}), result->get_vector());
} }
/*
TEST(execute, divide) TEST(execute, divide)
{ {
auto manager = runtime::Manager::get("NGVM"); auto manager = runtime::Manager::get("NGVM");
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment