Commit 0c5d5a65 authored by Jaikrishnan Menon's avatar Jaikrishnan Menon

CPU: Enable more concatenation tests

parent 1e73a52c
......@@ -195,7 +195,7 @@ void Emitter::EMITTER_DECL(EmitConcat)
to_string(inputs[i].get_tensor_view_layout()->get_shape().at(0)) + ") << "
"EigenVector<" + element_type_names[TI(result_element_type)] + ">(call_frame->"
"get_tensor_view_data<" + element_type_names[TI(result_element_type)] + ">(" +
to_string(outputs[0].get_index()) + "), "
to_string(inputs[i].get_index()) + "), "
EIGEN_VECTOR_FORMAT(inputs[i].get_layout<DenseTensorViewLayout>()->get_size()) ");\n";
concat_pos += inputs[i].get_tensor_view_layout()->get_shape().at(0);
}
......@@ -204,15 +204,6 @@ void Emitter::EMITTER_DECL(EmitConcat)
}
else if (result_shape.size() == 2)
{
/*
PUSH_POLYMORPHIC_INSTRUCTION(
result_element_type,
"Concat has unhandled element type",
eigen::ConcatMatrixInstruction,
in,
(dynamic_cast<const op::Concat*>(n))->get_concatenation_axis(),
out[0]);
*/
auto out_layout = outputs[0].get_layout<DenseTensorViewLayout>();
auto axis = (dynamic_cast<const op::Concat*>(n))->get_concatenation_axis();
......
......@@ -293,8 +293,7 @@ TEST(cpu, concat_matrix_colwise)
result->get_vector());
}
/*
TEST(execute, concat_matrix_rowwise)
TEST(cpu, concat_matrix_rowwise)
{
auto shape_a = Shape{2, 2};
auto A = make_shared<op::Parameter>(element::Float32::element_type(), shape_a);
......@@ -307,7 +306,7 @@ TEST(execute, concat_matrix_rowwise)
auto f = make_shared<Function>(
make_shared<op::Concat>(Nodes{A, B, C}, 0), rt, op::Parameters{A, B, C});
auto manager = runtime::Manager::get("NGVM");
auto manager = runtime::Manager::get("CPU");
auto external = manager->compile(f);
auto backend = manager->allocate_backend();
auto cf = backend->make_call_frame(external);
......@@ -326,7 +325,7 @@ TEST(execute, concat_matrix_rowwise)
result->get_vector());
}
TEST(execute, concat_matrix_int64)
TEST(cpu, concat_matrix_int64)
{
auto shape_a = Shape{2, 2};
auto A = make_shared<op::Parameter>(element::Int64::element_type(), shape_a);
......@@ -339,7 +338,7 @@ TEST(execute, concat_matrix_int64)
auto f = make_shared<Function>(
make_shared<op::Concat>(Nodes{A, B, C}, 0), rt, op::Parameters{A, B, C});
auto manager = runtime::Manager::get("NGVM");
auto manager = runtime::Manager::get("CPU");
auto external = manager->compile(f);
auto backend = manager->allocate_backend();
auto cf = backend->make_call_frame(external);
......@@ -358,7 +357,7 @@ TEST(execute, concat_matrix_int64)
result->get_vector());
}
TEST(execute, concat_vector)
TEST(cpu, concat_vector)
{
auto shape_a = Shape{4};
auto A = make_shared<op::Parameter>(element::Float32::element_type(), shape_a);
......@@ -371,7 +370,7 @@ TEST(execute, concat_vector)
auto f = make_shared<Function>(
make_shared<op::Concat>(Nodes{A, B, C}, 0), rt, op::Parameters{A, B, C});
auto manager = runtime::Manager::get("NGVM");
auto manager = runtime::Manager::get("CPU");
auto external = manager->compile(f);
auto backend = manager->allocate_backend();
auto cf = backend->make_call_frame(external);
......@@ -389,6 +388,7 @@ TEST(execute, concat_vector)
ASSERT_EQ((vector<float>{2, 4, 8, 16, 1, 2, 4, 8, 16, 32, 18, 19}), result->get_vector());
}
/*
TEST(execute, divide)
{
auto manager = runtime::Manager::get("NGVM");
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment