Commit c82d9e9e authored by Jaikrishnan Menon's avatar Jaikrishnan Menon

CPU: Implement column-wise and row-wise broadcast

parent 7f27722a
......@@ -204,7 +204,7 @@ void Emitter::EMITTER_DECL(EmitDot)
}
else
{
throw ngraph_error("Dot product for given tensors unimplemented");
throw ngraph_error("Dot product not implemented for given inputs");
}
}
......@@ -811,6 +811,45 @@ void Emitter::EMITTER_DECL(EmitBroadcast)
EIGEN_VECTOR_FORMAT(inputs[0].get_layout<DenseTensorViewLayout>()->get_size()) ")(0, 0);\n"
" }\n";
}
else if (arg_shape.size() == 1 && result_shape.size() == 2)
{
if (broadcast->get_broadcast_axes() == AxisSet{1})
{
auto out_layout = outputs[0].get_layout<DenseTensorViewLayout>();
TU += " {\n"
" auto arg0 = call_frame->get_tensor_view_data<" + element_type_names[TI(result_element_type)] +
">(" + to_string(inputs[0].get_index()) + ");\n"
" auto out = call_frame->get_tensor_view_data<" + element_type_names[TI(result_element_type)] +
">(" + to_string(outputs[0].get_index()) + ");\n"
" EigenMatrix<" + element_type_names[TI(result_element_type)] + ">(out, " +
EIGEN_MATRIX_FORMAT(out_layout->get_shape(), out_layout->get_strides()) + ").colwise() =\n"
" EigenVector<" + element_type_names[TI(result_element_type)] + ">(arg0, "
EIGEN_VECTOR_FORMAT(inputs[0].get_layout<DenseTensorViewLayout>()->get_size()) ");\n"
" }\n";
}
else if (broadcast->get_broadcast_axes() == AxisSet{0})
{
auto out_layout = outputs[0].get_layout<DenseTensorViewLayout>();
TU += " {\n"
" auto arg0 = call_frame->get_tensor_view_data<" + element_type_names[TI(result_element_type)] +
">(" + to_string(inputs[0].get_index()) + ");\n"
" auto out = call_frame->get_tensor_view_data<" + element_type_names[TI(result_element_type)] +
">(" + to_string(outputs[0].get_index()) + ");\n"
" EigenMatrix<" + element_type_names[TI(result_element_type)] + ">(out, " +
EIGEN_MATRIX_FORMAT(out_layout->get_shape(), out_layout->get_strides()) + ").rowwise() =\n"
" EigenVector<" + element_type_names[TI(result_element_type)] + ">(arg0, "
EIGEN_VECTOR_FORMAT(inputs[0].get_layout<DenseTensorViewLayout>()->get_size()) ").transpose();\n"
" }\n";
}
else
{
throw ngraph_error(
"Internal error: axis set for vector-matrix broadcast is neither {0} nor "
"{1}");
}
}
else
{
throw ngraph_error("Broadcast not implemented for given inputs");
......
......@@ -1221,8 +1221,7 @@ TEST(cpu, broadcast_trivial)
ASSERT_EQ((vector<float>{2, 4, 6, 8, 16, 32, 64, 128}), result->get_vector());
}
/*
TEST(execute, broadcast_vector_colwise)
TEST(cpu, broadcast_vector_colwise)
{
auto shape_a = Shape{3};
auto A = make_shared<op::Parameter>(element::Float32::element_type(), shape_a);
......@@ -1231,7 +1230,7 @@ TEST(execute, broadcast_vector_colwise)
auto f = make_shared<Function>(
make_shared<op::Broadcast>(A, shape_r, AxisSet{1}), rt, op::Parameters{A});
auto manager = runtime::Manager::get("NGVM");
auto manager = runtime::Manager::get("CPU");
auto external = manager->compile(f);
auto backend = manager->allocate_backend();
auto cf = backend->make_call_frame(external);
......@@ -1245,7 +1244,7 @@ TEST(execute, broadcast_vector_colwise)
ASSERT_EQ((vector<float>{1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3}), result->get_vector());
}
TEST(execute, broadcast_vector_rowwise)
TEST(cpu, broadcast_vector_rowwise)
{
auto shape_a = Shape{4};
auto A = make_shared<op::Parameter>(element::Float32::element_type(), shape_a);
......@@ -1254,7 +1253,7 @@ TEST(execute, broadcast_vector_rowwise)
auto f = make_shared<Function>(
make_shared<op::Broadcast>(A, shape_r, AxisSet{0}), rt, op::Parameters{A});
auto manager = runtime::Manager::get("NGVM");
auto manager = runtime::Manager::get("CPU");
auto external = manager->compile(f);
auto backend = manager->allocate_backend();
auto cf = backend->make_call_frame(external);
......@@ -1268,7 +1267,7 @@ TEST(execute, broadcast_vector_rowwise)
ASSERT_EQ((vector<float>{1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4}), result->get_vector());
}
TEST(execute, broadcast_vector_rowwise_int64)
TEST(cpu, broadcast_vector_rowwise_int64)
{
auto shape_a = Shape{4};
auto A = make_shared<op::Parameter>(element::Int64::element_type(), shape_a);
......@@ -1277,7 +1276,7 @@ TEST(execute, broadcast_vector_rowwise_int64)
auto f = make_shared<Function>(
make_shared<op::Broadcast>(A, shape_r, AxisSet{0}), rt, op::Parameters{A});
auto manager = runtime::Manager::get("NGVM");
auto manager = runtime::Manager::get("CPU");
auto external = manager->compile(f);
auto backend = manager->allocate_backend();
auto cf = backend->make_call_frame(external);
......@@ -1292,6 +1291,7 @@ TEST(execute, broadcast_vector_rowwise_int64)
result->get_vector());
}
/*
TEST(execute, convert_int32_float32)
{
auto shape = Shape{2, 2};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment