Commit 6ebc3c8c authored by Fenglei, committed by adstraw

Dot op that can handle more than 2D on GPU (#645)

* general dot for gpu
parent bae77590
@@ -182,6 +182,7 @@ cudnnSetOpTensorDescriptor(opTensorDesc,
const ngraph::op::Dot* dot = static_cast<const ngraph::op::Dot*>(node);
const Shape& arg0_shape = args[0].get_shape();
const Shape& arg1_shape = args[1].get_shape();
+ const Shape& out_shape = out[0].get_shape();
if (arg0_shape.empty() || arg1_shape.empty())
{
auto& first = (arg0_shape.empty() ? args[0] : args[1]);
@@ -201,7 +202,7 @@ cudnnSetOpTensorDescriptor(opTensorDesc,
return;
}
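(When one argument has an empty shape, i.e. rank 0, dot degenerates to elementwise scaling of the other argument: e.g. dot(2, [[1, 2], [3, 4]]) = [[2, 4], [6, 8]], so only the non-scalar operand's layout matters in this branch.)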
- //set output to 0 if input size is 0
+ // set output to 0 if input size is 0
if (args[0].get_size() == 0 || args[1].get_size() == 0)
{
writer.block_begin(" // " + node->get_name());
@@ -211,7 +212,7 @@ cudnnSetOpTensorDescriptor(opTensorDesc,
return;
}
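(A zero-sized input makes every output element an empty sum: e.g. a {2,0} . {0,3} dot yields a {2,3} result that is all zeros, so zero-filling the output buffer is the complete computation.)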
- //case that can be treat as dot1d
+ // case that can be treated as dot1d
if ((arg0_shape.size() == arg1_shape.size()) &&
(arg0_shape.size() == dot->get_reduction_axes_count()))
@@ -220,8 +221,8 @@ cudnnSetOpTensorDescriptor(opTensorDesc,
{
if (arg0_shape[i] != arg1_shape[i])
{
- throw std::runtime_error(
-     "input1 and input2 shape does not match for dot;");
+ throw std::invalid_argument(
+     "arg0 and arg1 shapes do not match for dot.");
}
}
writer.block_begin(" // " + node->get_name());
@@ -232,6 +233,7 @@ cudnnSetOpTensorDescriptor(opTensorDesc,
<< "1," << out[0].get_name() << ");\n";
writer.block_end();
}
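This branch reduces over every axis of two equal-shaped inputs, i.e. a flattened inner product. The emitted call is truncated above but is consistent with cublasSdot; a minimal standalone sketch of the same reduction (buffer names and the default host pointer mode are illustrative, not from this commit, which switches the handle to device pointer mode):

    #include <cublas_v2.h>
    #include <cuda_runtime.h>
    #include <cstdio>
    #include <vector>

    int main()
    {
        // Two tensors of identical shape, flattened to length 4.
        std::vector<float> a{1, 2, 3, 4};
        std::vector<float> b{5, 6, 7, 8};
        float *d_a = nullptr, *d_b = nullptr;
        cudaMalloc(&d_a, 4 * sizeof(float));
        cudaMalloc(&d_b, 4 * sizeof(float));
        cudaMemcpy(d_a, a.data(), 4 * sizeof(float), cudaMemcpyHostToDevice);
        cudaMemcpy(d_b, b.data(), 4 * sizeof(float), cudaMemcpyHostToDevice);

        cublasHandle_t handle;
        cublasCreate(&handle);
        float result = 0.0f; // full reduction yields a single scalar
        cublasSdot(handle, 4, d_a, 1, d_b, 1, &result);
        cudaDeviceSynchronize();
        std::printf("dot = %f\n", result); // 1*5 + 2*6 + 3*7 + 4*8 = 70
        cublasDestroy(handle);
        cudaFree(d_a);
        cudaFree(d_b);
    }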
+ // matrix vector
else if ((arg0_shape.size() == 2) && (arg1_shape.size() == 1) &&
(dot->get_reduction_axes_count() == 1))
{
@@ -252,22 +254,63 @@ cudnnSetOpTensorDescriptor(opTensorDesc,
writer << "cublasSetPointerMode(cublas_handle, CUBLAS_POINTER_MODE_DEVICE);\n";
writer.block_end();
}
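The cublasSgemv call itself is elided above. y[m] = A[m,k] * x[k] with a row-major A is conventionally expressed through column-major cuBLAS by transposing the k-by-m column-major view of the same buffer; a sketch under that assumption (not the literal emitted code):

    #include <cublas_v2.h>
    #include <cuda_runtime.h>
    #include <cstdio>

    int main()
    {
        // A is 2x3 row-major; compute y = A * x.
        const int m = 2, k = 3;
        const float A[m * k] = {1, 2, 3,
                                4, 5, 6};
        const float x[k] = {1, 1, 1};
        float *dA, *dx, *dy;
        cudaMalloc(&dA, sizeof(A));
        cudaMalloc(&dx, sizeof(x));
        cudaMalloc(&dy, m * sizeof(float));
        cudaMemcpy(dA, A, sizeof(A), cudaMemcpyHostToDevice);
        cudaMemcpy(dx, x, sizeof(x), cudaMemcpyHostToDevice);

        cublasHandle_t handle;
        cublasCreate(&handle);
        const float alpha = 1.0f, beta = 0.0f;
        // The row-major m-by-k A is a k-by-m column-major matrix,
        // so OP_T on that view multiplies by A itself.
        cublasSgemv(handle, CUBLAS_OP_T, k, m, &alpha, dA, k, dx, 1, &beta, dy, 1);

        float y[m];
        cudaMemcpy(y, dy, sizeof(y), cudaMemcpyDeviceToHost);
        std::printf("y = [%f, %f]\n", y[0], y[1]); // [6, 15]
        cublasDestroy(handle);
        cudaFree(dA);
        cudaFree(dx);
        cudaFree(dy);
    }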
- else if ((arg0_shape.size() == 2) && (arg1_shape.size() == 2) &&
-          (dot->get_reduction_axes_count() == 1))
+ // cases that can be treated as matrix multiply
+ else
{
- // GEMM Call
- if (arg0_shape[0] != out[0].get_shape()[0] || // m
-     arg1_shape[1] != out[0].get_shape()[1] || // n
-     arg0_shape[1] != arg1_shape[0]) // k
+ // treat as out[m,n] = arg0[m,k] * arg1[k,n]
+ size_t reduction_axes = dot->get_reduction_axes_count();
+ size_t num_of_axes_for_m = arg0_shape.size() - reduction_axes;
+ size_t num_of_axes_for_n = arg1_shape.size() - reduction_axes;
+ size_t num_of_axes_for_k = reduction_axes;
+ size_t m = 1;
+ size_t n = 1;
+ size_t k = 1;
+ // check that input and output shapes are consistent
+ // check and calculate k for arg0 and arg1
+ size_t arg0_k_idx = num_of_axes_for_m; // first axis in arg0 for k
+ size_t arg1_k_idx = 0; // first axis in arg1 for k
+ for (size_t i = 0; i < num_of_axes_for_k; i++)
+ {
+     k *= arg0_shape[arg0_k_idx];
+     if (arg0_shape[arg0_k_idx++] != arg1_shape[arg1_k_idx++])
+     {
+         throw std::invalid_argument(
+             "arg0 and arg1 shapes do not match for dot.");
+     }
+ }
+ // check and calculate m for arg0 and out
+ size_t arg0_m_idx = 0; // first axis in arg0 for m
+ size_t out_m_idx = 0; // first axis in out for m
+ for (size_t i = 0; i < num_of_axes_for_m; i++)
+ {
+     m *= arg0_shape[arg0_m_idx];
+     if (arg0_shape[arg0_m_idx++] != out_shape[out_m_idx++])
+     {
+         throw std::invalid_argument(
+             "arg0 and output shapes do not match for dot.");
+     }
+ }
+ // check and calculate n for arg1 and out
+ size_t arg1_n_idx = num_of_axes_for_k; // first axis in arg1 for n
+ size_t out_n_idx = num_of_axes_for_m; // first axis in out for n
+ for (size_t i = 0; i < num_of_axes_for_n; i++)
+ {
- throw std::runtime_error("input and output shape does not match for dot;");
+     n *= arg1_shape[arg1_n_idx];
+     if (arg1_shape[arg1_n_idx++] != out_shape[out_n_idx++])
+     {
+         throw std::invalid_argument(
+             "arg1 and output shapes do not match for dot.");
+     }
+ }
+ // GEMM Call
writer.block_begin(" // " + node->get_name());
writer << "const float alpha = 1.0;\n";
writer << "const float beta = 0.0;\n";
writer << "int m = " << arg0_shape[0] << ";\n";
writer << "int n = " << arg1_shape[1] << ";\n";
writer << "int k = " << arg0_shape[0] << ";\n";
writer << "int m = " << m << ";\n";
writer << "int n = " << n << ";\n";
writer << "int k = " << k << ";\n";
writer << "cublasSetPointerMode(cublas_handle, CUBLAS_POINTER_MODE_HOST);\n";
writer << "cublasSgemm("
<< "cublas_handle,"
@@ -286,11 +329,6 @@ cudnnSetOpTensorDescriptor(opTensorDesc,
writer << "cublasSetPointerMode(cublas_handle, CUBLAS_POINTER_MODE_DEVICE);\n";
writer.block_end();
}
- else
- {
-     throw std::runtime_error(node->get_name() +
-                              " with more then 2D is not implemented.");
- }
}
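The general branch above collapses the leading axes of arg0 into m, the trailing axes of arg1 into n, and the contracted axes into k, then issues a single SGEMM. A standalone restatement of that bookkeeping (the helper name is illustrative, not from the commit), using the dot3d_2d shapes from the tests below:

    #include <cassert>
    #include <cstddef>
    #include <cstdio>
    #include <vector>

    using Shape = std::vector<size_t>;

    // Collapse an n-D dot into one GEMM: out[m,n] = arg0[m,k] * arg1[k,n].
    void dot_to_gemm_dims(const Shape& arg0, const Shape& arg1,
                          size_t reduction_axes, size_t& m, size_t& n, size_t& k)
    {
        size_t axes_for_m = arg0.size() - reduction_axes;
        size_t axes_for_n = arg1.size() - reduction_axes;
        m = n = k = 1;
        for (size_t i = 0; i < axes_for_m; i++) m *= arg0[i];
        for (size_t i = 0; i < reduction_axes; i++)
        {
            assert(arg0[axes_for_m + i] == arg1[i]); // contracted axes must agree
            k *= arg1[i];
        }
        for (size_t i = 0; i < axes_for_n; i++) n *= arg1[reduction_axes + i];
    }

    int main()
    {
        // dot3d_2d below: {4,2,3} . {3,4} with one reduction axis.
        size_t m, n, k;
        dot_to_gemm_dims({4, 2, 3}, {3, 4}, 1, m, n, k);
        std::printf("m=%zu n=%zu k=%zu\n", m, n, k); // m=8 n=4 k=3 -> out {4,2,4}
    }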
template <>
@@ -425,7 +463,7 @@ cudnnSetOpTensorDescriptor(opTensorDesc,
auto result_shape = out[0].get_shape();
auto& axes = broadcast->get_broadcast_axes();
- //broadcast axes is empty, do a copy
+ // broadcast axes is empty, do a copy
if (axes.empty())
{
writer.block_begin(" // " + node->get_name());
@@ -434,7 +472,7 @@ cudnnSetOpTensorDescriptor(opTensorDesc,
return;
}
- //broadcast axes size is 1, or can be group to 1 (consecutive axes, like 01 or 12 or 123 etc)
+ // broadcast axes size is 1, or axes can be grouped to 1 (consecutive axes, like 01 or 12 or 123 etc.)
vector<int> axes_v;
std::copy(axes.begin(), axes.end(), std::back_inserter(axes_v));
std::sort(axes_v.begin(), axes_v.end());
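The grouping mentioned in the comment above requires the sorted axes to form one contiguous run; a sketch of that check (the helper name is illustrative, the rest of the hunk is elided):

    #include <algorithm>
    #include <cstddef>
    #include <vector>

    // Broadcast axes can be fused into a single axis only when, after
    // sorting, they are consecutive (e.g. {1,2,3} but not {0,2}).
    bool axes_are_consecutive(std::vector<int> axes)
    {
        std::sort(axes.begin(), axes.end());
        for (std::size_t i = 1; i < axes.size(); i++)
        {
            if (axes[i] != axes[i - 1] + 1)
            {
                return false;
            }
        }
        return true;
    }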
@@ -506,7 +544,7 @@ cudnnSetOpTensorDescriptor(opTensorDesc,
result_shape_product *= i;
}
// If there is no layout change or we are just going from 1^n to 1^m or a zero-size tensor,
- // we can just copy.
+ // we can just copy.
if (same_layout || result_shape_product < 2)
{
kernel::emit_memcpyDtD(writer, out[0], args[0]);
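For example, reshaping {2,3} to {6} with the identity input order moves no data, and a {0,5} tensor has no elements at all, so in both cases a device-to-device memcpy is the whole operation (assuming same_layout, computed in the elided context, captures the identity-permutation case).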
......
@@ -1044,7 +1044,6 @@ TEST(${BACKEND_NAME}, dot2d)
//
TEST(${BACKEND_NAME}, dot3d_3d)
{
- SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape{2, 2, 2};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto B = make_shared<op::Parameter>(element::f32, shape);
@@ -1093,7 +1092,6 @@ TEST(${BACKEND_NAME}, dot3d_3d)
//
TEST(${BACKEND_NAME}, dot3d_2d)
{
- SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape_a{4, 2, 3};
auto A = make_shared<op::Parameter>(element::f32, shape_a);
Shape shape_b{3, 4};
......
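With the general branch above, dot3d_3d contracts {2,2,2} . {2,2,2} over one axis, collapsing to m = 4, k = 2, n = 4 with a {2,2,2,2} result, and dot3d_2d contracts {4,2,3} . {3,4} to m = 8, k = 3, n = 4 with a {4,2,4} result, so both tests can now run on the GPU backend instead of being skipped.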