Commit a48a75d5 authored by Jaikrishnan Menon's avatar Jaikrishnan Menon

CPU: Implement op::Convert

parent 5bf6c0ef
...@@ -855,3 +855,30 @@ void Emitter::EMITTER_DECL(EmitBroadcast) ...@@ -855,3 +855,30 @@ void Emitter::EMITTER_DECL(EmitBroadcast)
throw ngraph_error("Broadcast not implemented for given inputs"); throw ngraph_error("Broadcast not implemented for given inputs");
} }
} }
void Emitter::EMITTER_DECL(EmitConvert)
{
auto arg = n->get_arguments().at(0);
auto arg_tensor_type = dynamic_pointer_cast<const TensorViewType>(arg->get_value_type());
assert(arg_tensor_type);
auto& arg_element_type = arg_tensor_type->get_element_type();
auto result_tensor_type = dynamic_pointer_cast<const TensorViewType>(n->get_value_type());
assert(result_tensor_type);
auto& result_element_type = result_tensor_type->get_element_type();
TU += " {\n"
" auto arg0 = call_frame->get_tensor_view_data<" + element_type_names[TI(arg_element_type)] +
">(" + to_string(inputs[0].get_index()) + ");\n"
" auto out = call_frame->get_tensor_view_data<" + element_type_names[TI(result_element_type)] +
">(" + to_string(outputs[0].get_index()) + ");\n"
" EigenArray1d<" + element_type_names[TI(result_element_type)] + ">(out, "
EIGEN_VECTOR_FORMAT(outputs[0].get_layout<DenseTensorViewLayout>()->get_size()) ") =\n"
" EigenArray1d<" + element_type_names[TI(arg_element_type)] + ">(arg0, "
EIGEN_VECTOR_FORMAT(inputs[0].get_layout<DenseTensorViewLayout>()->get_size()) ")\n"
".template cast<typename " + element_type_names[TI(result_element_type)] + "::type>();\n"
" }\n";
}
...@@ -74,6 +74,7 @@ namespace ngraph ...@@ -74,6 +74,7 @@ namespace ngraph
void EMITTER_DECL(EmitParameterizedConstantUInt32); void EMITTER_DECL(EmitParameterizedConstantUInt32);
void EMITTER_DECL(EmitParameterizedConstantUInt64); void EMITTER_DECL(EmitParameterizedConstantUInt64);
void EMITTER_DECL(EmitBroadcast); void EMITTER_DECL(EmitBroadcast);
void EMITTER_DECL(EmitConvert);
}; };
} }
} }
......
...@@ -104,6 +104,7 @@ static const OpMap dispatcher{ ...@@ -104,6 +104,7 @@ static const OpMap dispatcher{
{TI(ngraph::op::ParameterizedConstant<ngraph::element::UInt64>), {TI(ngraph::op::ParameterizedConstant<ngraph::element::UInt64>),
&Emitter::EmitParameterizedConstantUInt64}, &Emitter::EmitParameterizedConstantUInt64},
{TI(ngraph::op::Broadcast), &Emitter::EmitBroadcast}, {TI(ngraph::op::Broadcast), &Emitter::EmitBroadcast},
{TI(ngraph::op::Convert), &Emitter::EmitConvert},
}; };
#undef TI #undef TI
......
...@@ -1291,8 +1291,7 @@ TEST(cpu, broadcast_vector_rowwise_int64) ...@@ -1291,8 +1291,7 @@ TEST(cpu, broadcast_vector_rowwise_int64)
result->get_vector()); result->get_vector());
} }
/* TEST(cpu, convert_int32_float32)
TEST(execute, convert_int32_float32)
{ {
auto shape = Shape{2, 2}; auto shape = Shape{2, 2};
auto A = make_shared<op::Parameter>(element::Int32::element_type(), shape); auto A = make_shared<op::Parameter>(element::Int32::element_type(), shape);
...@@ -1300,7 +1299,7 @@ TEST(execute, convert_int32_float32) ...@@ -1300,7 +1299,7 @@ TEST(execute, convert_int32_float32)
auto f = make_shared<Function>( auto f = make_shared<Function>(
make_shared<op::Convert>(A, element::Float32::element_type()), rt, op::Parameters{A}); make_shared<op::Convert>(A, element::Float32::element_type()), rt, op::Parameters{A});
auto manager = runtime::Manager::get("NGVM"); auto manager = runtime::Manager::get("CPU");
auto external = manager->compile(f); auto external = manager->compile(f);
auto backend = manager->allocate_backend(); auto backend = manager->allocate_backend();
auto cf = backend->make_call_frame(external); auto cf = backend->make_call_frame(external);
...@@ -1314,7 +1313,7 @@ TEST(execute, convert_int32_float32) ...@@ -1314,7 +1313,7 @@ TEST(execute, convert_int32_float32)
ASSERT_EQ((vector<element::Float32::type>{1, 2, 3, 4}), result->get_vector()); ASSERT_EQ((vector<element::Float32::type>{1, 2, 3, 4}), result->get_vector());
} }
TEST(execute, convert_int32_bool) TEST(cpu, convert_int32_bool)
{ {
auto shape = Shape{2, 2}; auto shape = Shape{2, 2};
auto A = make_shared<op::Parameter>(element::Int32::element_type(), shape); auto A = make_shared<op::Parameter>(element::Int32::element_type(), shape);
...@@ -1322,7 +1321,7 @@ TEST(execute, convert_int32_bool) ...@@ -1322,7 +1321,7 @@ TEST(execute, convert_int32_bool)
auto f = make_shared<Function>( auto f = make_shared<Function>(
make_shared<op::Convert>(A, element::Bool::element_type()), rt, op::Parameters{A}); make_shared<op::Convert>(A, element::Bool::element_type()), rt, op::Parameters{A});
auto manager = runtime::Manager::get("NGVM"); auto manager = runtime::Manager::get("CPU");
auto external = manager->compile(f); auto external = manager->compile(f);
auto backend = manager->allocate_backend(); auto backend = manager->allocate_backend();
auto cf = backend->make_call_frame(external); auto cf = backend->make_call_frame(external);
...@@ -1336,7 +1335,7 @@ TEST(execute, convert_int32_bool) ...@@ -1336,7 +1335,7 @@ TEST(execute, convert_int32_bool)
ASSERT_EQ((vector<element::Bool::type>{1, 2, 3, 4}), result->get_vector()); ASSERT_EQ((vector<element::Bool::type>{1, 2, 3, 4}), result->get_vector());
} }
TEST(execute, convert_float32_bool) TEST(cpu, convert_float32_bool)
{ {
auto shape = Shape{2, 2}; auto shape = Shape{2, 2};
auto A = make_shared<op::Parameter>(element::Float32::element_type(), shape); auto A = make_shared<op::Parameter>(element::Float32::element_type(), shape);
...@@ -1344,7 +1343,7 @@ TEST(execute, convert_float32_bool) ...@@ -1344,7 +1343,7 @@ TEST(execute, convert_float32_bool)
auto f = make_shared<Function>( auto f = make_shared<Function>(
make_shared<op::Convert>(A, element::Bool::element_type()), rt, op::Parameters{A}); make_shared<op::Convert>(A, element::Bool::element_type()), rt, op::Parameters{A});
auto manager = runtime::Manager::get("NGVM"); auto manager = runtime::Manager::get("CPU");
auto external = manager->compile(f); auto external = manager->compile(f);
auto backend = manager->allocate_backend(); auto backend = manager->allocate_backend();
auto cf = backend->make_call_frame(external); auto cf = backend->make_call_frame(external);
...@@ -1358,6 +1357,7 @@ TEST(execute, convert_float32_bool) ...@@ -1358,6 +1357,7 @@ TEST(execute, convert_float32_bool)
ASSERT_EQ((vector<element::Bool::type>{1, 2, 3, 4}), result->get_vector()); ASSERT_EQ((vector<element::Bool::type>{1, 2, 3, 4}), result->get_vector());
} }
/*
// Trivial case with no reduction axes. // Trivial case with no reduction axes.
TEST(execute, reduce_trivial) TEST(execute, reduce_trivial)
{ {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment