Commit 6ebc3c8c authored by Fenglei, committed by adstraw

Dot op that can handle more than 2D on GPU (#645)

* general dot for gpu
parent bae77590
...@@ -182,6 +182,7 @@ cudnnSetOpTensorDescriptor(opTensorDesc, ...@@ -182,6 +182,7 @@ cudnnSetOpTensorDescriptor(opTensorDesc,
const ngraph::op::Dot* dot = static_cast<const ngraph::op::Dot*>(node); const ngraph::op::Dot* dot = static_cast<const ngraph::op::Dot*>(node);
const Shape& arg0_shape = args[0].get_shape(); const Shape& arg0_shape = args[0].get_shape();
const Shape& arg1_shape = args[1].get_shape(); const Shape& arg1_shape = args[1].get_shape();
const Shape& out_shape = out[0].get_shape();
if (arg0_shape.empty() || arg1_shape.empty()) if (arg0_shape.empty() || arg1_shape.empty())
{ {
auto& first = (arg0_shape.empty() ? args[0] : args[1]); auto& first = (arg0_shape.empty() ? args[0] : args[1]);
...@@ -201,7 +202,7 @@ cudnnSetOpTensorDescriptor(opTensorDesc, ...@@ -201,7 +202,7 @@ cudnnSetOpTensorDescriptor(opTensorDesc,
return; return;
} }
//set output to 0 if input size is 0 // set output to 0 if input size is 0
if (args[0].get_size() == 0 || args[1].get_size() == 0) if (args[0].get_size() == 0 || args[1].get_size() == 0)
{ {
writer.block_begin(" // " + node->get_name()); writer.block_begin(" // " + node->get_name());
...@@ -211,7 +212,7 @@ cudnnSetOpTensorDescriptor(opTensorDesc, ...@@ -211,7 +212,7 @@ cudnnSetOpTensorDescriptor(opTensorDesc,
return; return;
} }
//case that can be treat as dot1d // case that can be treat as dot1d
if ((arg0_shape.size() == arg1_shape.size()) && if ((arg0_shape.size() == arg1_shape.size()) &&
(arg0_shape.size() == dot->get_reduction_axes_count())) (arg0_shape.size() == dot->get_reduction_axes_count()))
...@@ -220,8 +221,8 @@ cudnnSetOpTensorDescriptor(opTensorDesc, ...@@ -220,8 +221,8 @@ cudnnSetOpTensorDescriptor(opTensorDesc,
{ {
if (arg0_shape[i] != arg1_shape[i]) if (arg0_shape[i] != arg1_shape[i])
{ {
throw std::runtime_error( throw std::invalid_argument(
"input1 and input2 shape does not match for dot;"); "arg0 and arg1 shape does not match for dot.");
} }
} }
writer.block_begin(" // " + node->get_name()); writer.block_begin(" // " + node->get_name());
...@@ -232,6 +233,7 @@ cudnnSetOpTensorDescriptor(opTensorDesc, ...@@ -232,6 +233,7 @@ cudnnSetOpTensorDescriptor(opTensorDesc,
<< "1," << out[0].get_name() << ");\n"; << "1," << out[0].get_name() << ");\n";
writer.block_end(); writer.block_end();
} }
// matrix vector
else if ((arg0_shape.size() == 2) && (arg1_shape.size() == 1) && else if ((arg0_shape.size() == 2) && (arg1_shape.size() == 1) &&
(dot->get_reduction_axes_count() == 1)) (dot->get_reduction_axes_count() == 1))
{ {
...@@ -252,22 +254,63 @@ cudnnSetOpTensorDescriptor(opTensorDesc, ...@@ -252,22 +254,63 @@ cudnnSetOpTensorDescriptor(opTensorDesc,
writer << "cublasSetPointerMode(cublas_handle, CUBLAS_POINTER_MODE_DEVICE);\n"; writer << "cublasSetPointerMode(cublas_handle, CUBLAS_POINTER_MODE_DEVICE);\n";
writer.block_end(); writer.block_end();
} }
else if ((arg0_shape.size() == 2) && (arg1_shape.size() == 2) && // cases that can be treat as matrix multiply
(dot->get_reduction_axes_count() == 1)) else
{ {
// GEMM Call // treat as out[m,n] = arg0[m,k] * arg1[k,n]
if (arg0_shape[0] != out[0].get_shape()[0] || // m size_t reduction_axes = dot->get_reduction_axes_count();
arg1_shape[1] != out[0].get_shape()[1] || // n size_t num_of_axes_for_m = arg0_shape.size() - reduction_axes;
arg0_shape[1] != arg1_shape[0]) // k size_t num_of_axes_for_n = arg1_shape.size() - reduction_axes;
size_t num_of_axes_for_k = reduction_axes;
size_t m = 1;
size_t n = 1;
size_t k = 1;
// check if input and output size correct
// check and calculate k for arg0 and arg1
size_t arg0_k_idx = num_of_axes_for_m; // first axe in arg0 for k
size_t arg1_k_idx = 0; // first axe in arg1 for k
for (size_t i = 0; i < num_of_axes_for_k; i++)
{
k *= arg0_shape[arg0_k_idx];
if (arg0_shape[arg0_k_idx++] != arg1_shape[arg1_k_idx++])
{
throw std::invalid_argument(
"arg0 and arg1 shape does not match for dot.");
}
}
// check and calculate m for arg0 and out
size_t arg0_m_idx = 0; // first axe in arg0 for m
size_t out_m_idx = 0; // first axe in out for m
for (size_t i = 0; i < num_of_axes_for_m; i++)
{
m *= arg0_shape[arg0_m_idx];
if (arg0_shape[arg0_m_idx++] != out_shape[out_m_idx++])
{
throw std::invalid_argument(
"arg0 and output shape does not match for dot.");
}
}
// check and calculate n for arg1 and out
size_t arg1_n_idx = num_of_axes_for_k; // first axe in arg1 for n
size_t out_n_idx = num_of_axes_for_m; // first axis in out for n
for (size_t i = 0; i < num_of_axes_for_n; i++)
{ {
throw std::runtime_error("input and output shape does not match for dot;"); n *= arg1_shape[arg1_n_idx];
if (arg1_shape[arg1_n_idx++] != out_shape[out_n_idx++])
{
throw std::invalid_argument(
"arg1 and output shape does not match for dot.");
}
} }
// GEMM Call
writer.block_begin(" // " + node->get_name()); writer.block_begin(" // " + node->get_name());
writer << "const float alpha = 1.0;\n"; writer << "const float alpha = 1.0;\n";
writer << "const float beta = 0.0;\n"; writer << "const float beta = 0.0;\n";
writer << "int m = " << arg0_shape[0] << ";\n"; writer << "int m = " << m << ";\n";
writer << "int n = " << arg1_shape[1] << ";\n"; writer << "int n = " << n << ";\n";
writer << "int k = " << arg0_shape[0] << ";\n"; writer << "int k = " << k << ";\n";
writer << "cublasSetPointerMode(cublas_handle, CUBLAS_POINTER_MODE_HOST);\n"; writer << "cublasSetPointerMode(cublas_handle, CUBLAS_POINTER_MODE_HOST);\n";
writer << "cublasSgemm(" writer << "cublasSgemm("
<< "cublas_handle," << "cublas_handle,"
...@@ -286,11 +329,6 @@ cudnnSetOpTensorDescriptor(opTensorDesc, ...@@ -286,11 +329,6 @@ cudnnSetOpTensorDescriptor(opTensorDesc,
writer << "cublasSetPointerMode(cublas_handle, CUBLAS_POINTER_MODE_DEVICE);\n"; writer << "cublasSetPointerMode(cublas_handle, CUBLAS_POINTER_MODE_DEVICE);\n";
writer.block_end(); writer.block_end();
} }
else
{
throw std::runtime_error(node->get_name() +
" with more then 2D is not implemented.");
}
} }
template <> template <>
...@@ -425,7 +463,7 @@ cudnnSetOpTensorDescriptor(opTensorDesc, ...@@ -425,7 +463,7 @@ cudnnSetOpTensorDescriptor(opTensorDesc,
auto result_shape = out[0].get_shape(); auto result_shape = out[0].get_shape();
auto& axes = broadcast->get_broadcast_axes(); auto& axes = broadcast->get_broadcast_axes();
//broadcast axes is empty, do a copy // broadcast axes is empty, do a copy
if (axes.empty()) if (axes.empty())
{ {
writer.block_begin(" // " + node->get_name()); writer.block_begin(" // " + node->get_name());
...@@ -434,7 +472,7 @@ cudnnSetOpTensorDescriptor(opTensorDesc, ...@@ -434,7 +472,7 @@ cudnnSetOpTensorDescriptor(opTensorDesc,
return; return;
} }
//broadcast axes size is 1, or can be group to 1 (consecutive axes, like 01 or 12 or 123 etc) // broadcast axes size is 1, or can be group to 1 (consecutive axes, like 01 or 12 or 123 etc)
vector<int> axes_v; vector<int> axes_v;
std::copy(axes.begin(), axes.end(), std::back_inserter(axes_v)); std::copy(axes.begin(), axes.end(), std::back_inserter(axes_v));
std::sort(axes_v.begin(), axes_v.end()); std::sort(axes_v.begin(), axes_v.end());
...@@ -506,7 +544,7 @@ cudnnSetOpTensorDescriptor(opTensorDesc, ...@@ -506,7 +544,7 @@ cudnnSetOpTensorDescriptor(opTensorDesc,
result_shape_product *= i; result_shape_product *= i;
} }
// If there is no layout change or we are just going from 1^n to 1^m or a zero-size tensor, // If there is no layout change or we are just going from 1^n to 1^m or a zero-size tensor,
// we can just copy. // we can just copy.
if (same_layout || result_shape_product < 2) if (same_layout || result_shape_product < 2)
{ {
kernel::emit_memcpyDtD(writer, out[0], args[0]); kernel::emit_memcpyDtD(writer, out[0], args[0]);
......
...@@ -1044,7 +1044,6 @@ TEST(${BACKEND_NAME}, dot2d) ...@@ -1044,7 +1044,6 @@ TEST(${BACKEND_NAME}, dot2d)
// //
TEST(${BACKEND_NAME}, dot3d_3d) TEST(${BACKEND_NAME}, dot3d_3d)
{ {
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape{2, 2, 2}; Shape shape{2, 2, 2};
auto A = make_shared<op::Parameter>(element::f32, shape); auto A = make_shared<op::Parameter>(element::f32, shape);
auto B = make_shared<op::Parameter>(element::f32, shape); auto B = make_shared<op::Parameter>(element::f32, shape);
...@@ -1093,7 +1092,6 @@ TEST(${BACKEND_NAME}, dot3d_3d) ...@@ -1093,7 +1092,6 @@ TEST(${BACKEND_NAME}, dot3d_3d)
// //
TEST(${BACKEND_NAME}, dot3d_2d) TEST(${BACKEND_NAME}, dot3d_2d)
{ {
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape_a{4, 2, 3}; Shape shape_a{4, 2, 3};
auto A = make_shared<op::Parameter>(element::f32, shape_a); auto A = make_shared<op::Parameter>(element::f32, shape_a);
Shape shape_b{3, 4}; Shape shape_b{3, 4};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment