Commit 7ac72a21 authored by Jaikrishnan Menon's avatar Jaikrishnan Menon

CPU: Implement scalar broadcast

parent 22e1368a
......@@ -20,6 +20,7 @@
#include "ngraph/descriptor/layout/dense_tensor_view_layout.hpp"
#include "ngraph/node.hpp"
#include "ngraph/ops/broadcast.hpp"
#include "ngraph/ops/concatenate.hpp"
#include "ngraph/ops/constant.hpp"
#include "ngraph/ops/get_tuple_element.hpp"
......@@ -771,3 +772,47 @@ void Emitter::EMITTER_DECL(EmitParameterizedConstantUInt64)
TU += "};\n }\n";
}
void Emitter::EMITTER_DECL(EmitBroadcast)
{
auto broadcast = static_cast<const op::Broadcast*>(n);
auto arg_tensor_type =
dynamic_pointer_cast<const TensorViewType>(n->get_arguments().at(0)->get_value_type());
assert(arg_tensor_type);
auto result_tensor_type = dynamic_pointer_cast<const TensorViewType>(n->get_value_type());
assert(result_tensor_type);
auto arg_shape = arg_tensor_type->get_shape();
auto result_shape = result_tensor_type->get_shape();
auto& result_element_type = result_tensor_type->get_element_type();
if (broadcast->get_broadcast_axes().empty())
{
TU +=
" {\n"
" call_frame->get_parameterized_tensor_view<" +
element_type_names[TI(result_element_type)] + ">(" + to_string(outputs[0].get_index()) +
")->get_vector() =\n"
" call_frame->get_parameterized_tensor_view<" +
element_type_names[TI(result_element_type)] + ">(" + to_string(inputs[0].get_index()) +
")->get_vector();\n"
" }\n";
}
else if (arg_shape.size() == 0)
{
TU += " {\n"
" auto arg0 = call_frame->get_tensor_view_data<" + element_type_names[TI(result_element_type)] + ">(" + to_string(inputs[0].get_index()) + ");\n"
" auto out = call_frame->get_tensor_view_data<" + element_type_names[TI(result_element_type)] + ">(" + to_string(outputs[0].get_index()) + ");\n"
" EigenArray1d<" + element_type_names[TI(result_element_type)] + ">(out, "
EIGEN_VECTOR_FORMAT(outputs[0].get_layout<DenseTensorViewLayout>()->get_size()) ") =\n"
" EigenArray1d<" + element_type_names[TI(result_element_type)] + ">(arg0, "
EIGEN_VECTOR_FORMAT(inputs[0].get_layout<DenseTensorViewLayout>()->get_size()) ")(0, 0);\n"
" }\n";
}
else
{
throw ngraph_error("Broadcast not implemented for given inputs");
}
}
......@@ -73,6 +73,7 @@ namespace ngraph
void EMITTER_DECL(EmitParameterizedConstantUInt8);
void EMITTER_DECL(EmitParameterizedConstantUInt32);
void EMITTER_DECL(EmitParameterizedConstantUInt64);
void EMITTER_DECL(EmitBroadcast);
};
}
}
......
......@@ -62,6 +62,7 @@
// TODO: Decide if we want to ship this or
// just enable it for developer build-test cycles
//#define NGCPU_PCH
//#define NGCPU_DEBUGINFO
using namespace std;
using namespace ngraph::runtime::cpu;
......@@ -107,7 +108,7 @@ static const OpMap dispatcher{
&Emitter::EmitParameterizedConstantUInt32},
{TI(ngraph::op::ParameterizedConstant<ngraph::element::UInt64>),
&Emitter::EmitParameterizedConstantUInt64},
{TI(ngraph::op::Broadcast), &Emitter::EmitBroadcast},
};
#undef TI
......@@ -252,6 +253,10 @@ extern "C" void __entrypoint(ngraph::runtime::cpu::CallFrame* call_frame,
estate.enable_pch();
#endif
#if defined(NGCPU_DEBUGINFO)
estate.enable_debuginfo();
#endif
auto llvm_module = estate.compile(TU, "__ngcpu_codegen.cpp");
assert(llvm_module);
estate.add_module(llvm_module);
......
......@@ -1128,8 +1128,9 @@ TEST(execute, function_call)
(*cf)({x, z, y}, {result});
ASSERT_EQ((vector<float>{100, 144, 196, 256}), result->get_vector());
}
*/
TEST(execute, broadcast_scalar_vector)
TEST(cpu, broadcast_scalar_vector)
{
auto shape_a = Shape{};
auto A = make_shared<op::Parameter>(element::Float32::element_type(), shape_a);
......@@ -1138,7 +1139,7 @@ TEST(execute, broadcast_scalar_vector)
auto f = make_shared<Function>(
make_shared<op::Broadcast>(A, shape_r, AxisSet{0}), rt, op::Parameters{A});
auto manager = runtime::Manager::get("NGVM");
auto manager = runtime::Manager::get("CPU");
auto external = manager->compile(f);
auto backend = manager->allocate_backend();
auto cf = backend->make_call_frame(external);
......@@ -1152,7 +1153,7 @@ TEST(execute, broadcast_scalar_vector)
ASSERT_EQ((vector<float>{6, 6, 6, 6}), result->get_vector());
}
TEST(execute, broadcast_scalar_matrix)
TEST(cpu, broadcast_scalar_matrix)
{
auto shape_a = Shape{};
auto A = make_shared<op::Parameter>(element::Float32::element_type(), shape_a);
......@@ -1161,7 +1162,7 @@ TEST(execute, broadcast_scalar_matrix)
auto f = make_shared<Function>(
make_shared<op::Broadcast>(A, shape_r, AxisSet{0, 1}), rt, op::Parameters{A});
auto manager = runtime::Manager::get("NGVM");
auto manager = runtime::Manager::get("CPU");
auto external = manager->compile(f);
auto backend = manager->allocate_backend();
auto cf = backend->make_call_frame(external);
......@@ -1175,7 +1176,7 @@ TEST(execute, broadcast_scalar_matrix)
ASSERT_EQ((vector<float>{6, 6, 6, 6}), result->get_vector());
}
TEST(execute, broadcast_scalar_tensor)
TEST(cpu, broadcast_scalar_tensor)
{
auto shape_a = Shape{};
auto A = make_shared<op::Parameter>(element::Float32::element_type(), shape_a);
......@@ -1184,7 +1185,7 @@ TEST(execute, broadcast_scalar_tensor)
auto f = make_shared<Function>(
make_shared<op::Broadcast>(A, shape_r, AxisSet{0, 1, 2}), rt, op::Parameters{A});
auto manager = runtime::Manager::get("NGVM");
auto manager = runtime::Manager::get("CPU");
auto external = manager->compile(f);
auto backend = manager->allocate_backend();
auto cf = backend->make_call_frame(external);
......@@ -1198,7 +1199,7 @@ TEST(execute, broadcast_scalar_tensor)
ASSERT_EQ((vector<float>{6, 6, 6, 6, 6, 6, 6, 6}), result->get_vector());
}
TEST(execute, broadcast_trivial)
TEST(cpu, broadcast_trivial)
{
auto shape = Shape{2, 2, 2};
auto A = make_shared<op::Parameter>(element::Float32::element_type(), shape);
......@@ -1206,7 +1207,7 @@ TEST(execute, broadcast_trivial)
auto f = make_shared<Function>(
make_shared<op::Broadcast>(A, shape, AxisSet{}), rt, op::Parameters{A});
auto manager = runtime::Manager::get("NGVM");
auto manager = runtime::Manager::get("CPU");
auto external = manager->compile(f);
auto backend = manager->allocate_backend();
auto cf = backend->make_call_frame(external);
......@@ -1220,6 +1221,7 @@ TEST(execute, broadcast_trivial)
ASSERT_EQ((vector<float>{2, 4, 6, 8, 16, 32, 64, 128}), result->get_vector());
}
/*
TEST(execute, broadcast_vector_colwise)
{
auto shape_a = Shape{3};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment