Commit 734facb1 authored by Jaikrishnan Menon

CPU: Implement op::Dot on matrices using Eigen

parent 305dd5b7
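The diff below extends EmitDot in the CPU emitter with two new cases: rank-2 x rank-1 (matrix times vector) and rank-2 x rank-2 (matrix times matrix), both lowered to Eigen expressions over the flat tensor buffers. As a rough mental model of what the generated code does at runtime, here is a minimal standalone sketch using plain Eigen::Map; the row-major layout, float element type, and function names are assumptions made for illustration, since the real generated code takes its element type, shapes, and strides from DenseTensorViewLayout.

```cpp
#include <Eigen/Dense>

// Assumed types for the sketch; the emitted code derives these from the layout.
using Matrix = Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
using Vector = Eigen::Matrix<float, Eigen::Dynamic, 1>;

// Rank-2 x rank-1 branch: out[rows] = arg0[rows x cols] * arg1[cols]
void dot_matrix_vector(const float* arg0, const float* arg1, float* out,
                       Eigen::Index rows, Eigen::Index cols)
{
    Eigen::Map<Vector>(out, rows) =
        Eigen::Map<const Matrix>(arg0, rows, cols) * Eigen::Map<const Vector>(arg1, cols);
}

// Rank-2 x rank-2 branch: out[m x n] = arg0[m x k] * arg1[k x n]
void dot_matrix_matrix(const float* arg0, const float* arg1, float* out,
                       Eigen::Index m, Eigen::Index k, Eigen::Index n)
{
    Eigen::Map<Matrix>(out, m, n) =
        Eigen::Map<const Matrix>(arg0, m, k) * Eigen::Map<const Matrix>(arg1, k, n);
}
```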
@@ -107,11 +107,11 @@ void Emitter::EMITTER_DECL(EmitDot)
auto arg0_tensor_type =
dynamic_pointer_cast<const TensorViewType>(arg_nodes.at(0)->get_value_type());
assert(nullptr != arg0_tensor_type);
assert(arg0_tensor_type);
auto arg1_tensor_type =
dynamic_pointer_cast<const TensorViewType>(arg_nodes.at(1)->get_value_type());
assert(nullptr != arg1_tensor_type);
assert(arg1_tensor_type);
auto arg0_shape = arg0_tensor_type->get_shape();
auto arg1_shape = arg1_tensor_type->get_shape();
@@ -152,6 +152,46 @@ void Emitter::EMITTER_DECL(EmitDot)
EIGEN_VECTOR_FORMAT(inputs[1].get_layout<DenseTensorViewLayout>()->get_size()) "));\n"
" }\n";
}
else if ((arg0_shape.size() == 2) && (arg1_shape.size() == 1))
{
auto arg0_layout = inputs[0].get_layout<DenseTensorViewLayout>();
TU += " {\n"
" auto arg0 = call_frame->get_tensor_view_data<" + element_type_names[TI(arg0_element_type)] + ">(" +
to_string(inputs[0].get_index()) + ");\n"
" auto arg1 = call_frame->get_tensor_view_data<" + element_type_names[TI(arg0_element_type)] + ">(" +
to_string(inputs[1].get_index()) + ");\n"
" auto out = call_frame->get_tensor_view_data<" + element_type_names[TI(arg0_element_type)] + ">(" +
to_string(outputs[0].get_index()) + ");\n"
" EigenVector<" + element_type_names[TI(arg0_element_type)] + ">(out, "
EIGEN_VECTOR_FORMAT(outputs[0].get_layout<DenseTensorViewLayout>()->get_size()) ") = \n"
" EigenMatrix<" + element_type_names[TI(arg0_element_type)] + ">(arg0, " +
EIGEN_MATRIX_FORMAT(arg0_layout->get_shape(), arg0_layout->get_strides()) + ") * "
"EigenVector<" + element_type_names[TI(arg0_element_type)] + ">(arg1, "
EIGEN_VECTOR_FORMAT(inputs[1].get_layout<DenseTensorViewLayout>()->get_size()) ");\n"
" }\n";
}
else if ((arg0_shape.size() == 2) && (arg1_shape.size() == 2))
{
auto arg0_layout = inputs[0].get_layout<DenseTensorViewLayout>();
auto arg1_layout = inputs[1].get_layout<DenseTensorViewLayout>();
auto out_layout = outputs[0].get_layout<DenseTensorViewLayout>();
TU += " {\n"
" auto arg0 = call_frame->get_tensor_view_data<" + element_type_names[TI(arg0_element_type)] + ">(" +
to_string(inputs[0].get_index()) + ");\n"
" auto arg1 = call_frame->get_tensor_view_data<" + element_type_names[TI(arg0_element_type)] + ">(" +
to_string(inputs[1].get_index()) + ");\n"
" auto out = call_frame->get_tensor_view_data<" + element_type_names[TI(arg0_element_type)] + ">(" +
to_string(outputs[0].get_index()) + ");\n"
" EigenMatrix<" + element_type_names[TI(arg0_element_type)] + ">(out, " +
EIGEN_MATRIX_FORMAT(out_layout->get_shape(), out_layout->get_strides()) + ") = \n"
" EigenMatrix<" + element_type_names[TI(arg0_element_type)] + ">(arg0, " +
EIGEN_MATRIX_FORMAT(arg0_layout->get_shape(), arg0_layout->get_strides()) + ") * "
"EigenMatrix<" + element_type_names[TI(arg0_element_type)] + ">(arg1, " +
EIGEN_MATRIX_FORMAT(arg1_layout->get_shape(), arg1_layout->get_strides()) + ");\n"
" }\n";
}
else
{
throw ngraph_error("Dot product for given tensors unimplemented");
......
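For readers unfamiliar with the helpers referenced in the generated strings above: EigenMatrix, EigenVector, EIGEN_MATRIX_FORMAT, and EIGEN_VECTOR_FORMAT are defined elsewhere in the CPU backend and are not part of this diff. One plausible shape for the aliases, assuming they are Eigen::Map typedefs over the raw buffer returned by get_tensor_view_data with runtime sizes and strides (the actual definitions in the repository may differ), is:

```cpp
#include <Eigen/Dense>

// Hypothetical sketch of the helper aliases; not taken from this commit.
using DynamicStrides = Eigen::Stride<Eigen::Dynamic, Eigen::Dynamic>;

template <typename ET>
using EigenMatrix = Eigen::Map<
    Eigen::Matrix<typename ET::type, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>,
    0,
    DynamicStrides>;

template <typename ET>
using EigenVector =
    Eigen::Map<Eigen::Matrix<typename ET::type, Eigen::Dynamic, 1>>;

// Under this assumption, emitted text such as
//   EigenMatrix<ngraph::element::Float32>(arg0, <rows>, <cols>, DynamicStrides(<rs>, <cs>))
// wraps the flat tensor buffer without copying, so the whole Dot lowers to a
// single Eigen product expression; EIGEN_MATRIX_FORMAT(shape, strides) would
// then expand to the size and stride arguments of that constructor.
```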
@@ -467,14 +467,13 @@ TEST(cpu, dot_0_0)
ASSERT_EQ((vector<float>{0}), result->get_vector());
}
/*
TEST(execute, dot_matrix_2x0_0x2)
TEST(cpu, dot_matrix_2x0_0x2)
{
auto shape_a = Shape{2, 0};
auto shape_b = Shape{0, 2};
auto shape_r = Shape{2, 2};
auto manager = runtime::Manager::get("NGVM");
auto manager = runtime::Manager::get("CPU");
auto backend = manager->allocate_backend();
auto make_external = [&]() {
@@ -500,7 +499,7 @@ TEST(execute, dot_matrix_2x0_0x2)
ASSERT_EQ((vector<float>{0, 0, 0, 0}), result->get_vector());
}
TEST(execute, dot_matrix_0x2_2x0)
TEST(cpu, dot_matrix_0x2_2x0)
{
auto shape_a = Shape{0, 2};
auto A = make_shared<op::Parameter>(element::Float32::element_type(), shape_a);
@@ -510,7 +509,7 @@ TEST(execute, dot_matrix_0x2_2x0)
auto rt = make_shared<TensorViewType>(element::Float32::element_type(), shape_r);
auto f = make_shared<Function>(make_shared<op::Dot>(A, B), rt, op::Parameters{A, B});
auto manager = runtime::Manager::get("NGVM");
auto manager = runtime::Manager::get("CPU");
auto external = manager->compile(f);
auto backend = manager->allocate_backend();
auto cf = backend->make_call_frame(external);
@@ -526,7 +525,7 @@ TEST(execute, dot_matrix_0x2_2x0)
ASSERT_EQ((vector<float>{}), result->get_vector());
}
TEST(execute, dot_matrix_3x2_2x0)
TEST(cpu, dot_matrix_3x2_2x0)
{
auto shape_a = Shape{3, 2};
auto A = make_shared<op::Parameter>(element::Float32::element_type(), shape_a);
@@ -536,7 +535,7 @@ TEST(execute, dot_matrix_3x2_2x0)
auto rt = make_shared<TensorViewType>(element::Float32::element_type(), shape_r);
auto f = make_shared<Function>(make_shared<op::Dot>(A, B), rt, op::Parameters{A, B});
auto manager = runtime::Manager::get("NGVM");
auto manager = runtime::Manager::get("CPU");
auto external = manager->compile(f);
auto backend = manager->allocate_backend();
auto cf = backend->make_call_frame(external);
@@ -552,6 +551,7 @@ TEST(execute, dot_matrix_3x2_2x0)
ASSERT_EQ((vector<float>{}), result->get_vector());
}
/*
TEST(execute, dot_scalar_0x2)
{
auto shape_a = Shape{};
......
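The tests enabled above only cover degenerate (zero-sized) shapes, and the scalar case remains commented out. As a standalone numeric sanity check of what the two new EmitDot branches compute, independent of the ngraph test harness, the following plain-Eigen snippet multiplies a 3x2 matrix by a 2-vector and a pair of 2x2 matrices and checks the hand-computed results:

```cpp
#include <Eigen/Dense>
#include <cassert>

int main()
{
    // Matrix (3x2) times vector (2): corresponds to the new rank-2 x rank-1 branch.
    Eigen::Matrix<float, 3, 2> m;
    m << 1, 2,
         3, 4,
         5, 6;
    Eigen::Vector2f v(1, 2);
    Eigen::Vector3f mv = m * v;
    assert(mv(0) == 5 && mv(1) == 11 && mv(2) == 17);

    // Matrix (2x2) times matrix (2x2): corresponds to the new rank-2 x rank-2 branch.
    Eigen::Matrix2f a, b;
    a << 1, 2,
         3, 4;
    b << 5, 6,
         7, 8;
    Eigen::Matrix2f c = a * b;
    assert(c(0, 0) == 19 && c(0, 1) == 22 && c(1, 0) == 43 && c(1, 1) == 50);
    return 0;
}
```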