Commit 9fd64b6f authored by fenglei.tian's avatar fenglei.tian

fix bug for 2d2d2 dot, enable some bprop dot tests

parent dd5c77e0
...@@ -226,18 +226,30 @@ cudnnSetOpTensorDescriptor(opTensorDesc, ...@@ -226,18 +226,30 @@ cudnnSetOpTensorDescriptor(opTensorDesc,
return; return;
} }
if ((arg0_shape.size() == 1) && (arg1_shape.size() == 1)) //case that can be treat as dot1d
if ((arg0_shape.size() == arg1_shape.size()) &&
(arg0_shape.size() == dot->get_reduction_axes_count()))
{
for (int i = 0; i < arg0_shape.size(); i++)
{ {
if (arg0_shape[i] != arg1_shape[i])
{
throw std::runtime_error("two input shape is not correct for dot;");
}
}
writer << "{ // " << node->get_name() << "\n"; writer << "{ // " << node->get_name() << "\n";
writer.indent++; writer.indent++;
writer << "cublasSdot(" writer << "cublasSdot("
<< "cublas_handle," << arg0_shape[0] << "," << args[0].get_name() << "," << "cublas_handle," << args[0].get_size() << ","
<< args[0].get_name() << ","
<< "1," << args[1].get_name() << "," << "1," << args[1].get_name() << ","
<< "1," << out[0].get_name() << ");\n"; << "1," << out[0].get_name() << ");\n";
writer.indent--; writer.indent--;
writer << "}\n"; writer << "}\n";
} }
else if ((arg0_shape.size() == 2) && (arg1_shape.size() == 1)) else if ((arg0_shape.size() == 2) && (arg1_shape.size() == 1) &&
(dot->get_reduction_axes_count() == 1))
{ {
writer << "{ // " << node->get_name() << "\n"; writer << "{ // " << node->get_name() << "\n";
writer.indent++; writer.indent++;
...@@ -258,7 +270,8 @@ cudnnSetOpTensorDescriptor(opTensorDesc, ...@@ -258,7 +270,8 @@ cudnnSetOpTensorDescriptor(opTensorDesc,
writer.indent--; writer.indent--;
writer << "}\n"; writer << "}\n";
} }
else if ((arg0_shape.size() == 2) && (arg1_shape.size() == 2)) else if ((arg0_shape.size() == 2) && (arg1_shape.size() == 2) &&
(dot->get_reduction_axes_count() == 1))
{ {
// GEMM Call // GEMM Call
if (arg0_shape[0] != out[0].get_shape()[0] || // m if (arg0_shape[0] != out[0].get_shape()[0] || // m
...@@ -275,6 +288,7 @@ cudnnSetOpTensorDescriptor(opTensorDesc, ...@@ -275,6 +288,7 @@ cudnnSetOpTensorDescriptor(opTensorDesc,
writer << "int n = " << arg1_shape[1] << ";\n"; writer << "int n = " << arg1_shape[1] << ";\n";
writer << "int k = " << arg0_shape[0] << ";\n"; writer << "int k = " << arg0_shape[0] << ";\n";
writer << "cublasSetPointerMode(cublas_handle, CUBLAS_POINTER_MODE_HOST);\n"; writer << "cublasSetPointerMode(cublas_handle, CUBLAS_POINTER_MODE_HOST);\n";
writer << "cublasSgemm(" writer << "cublasSgemm("
<< "cublas_handle," << "cublas_handle,"
<< "CUBLAS_OP_N," << "CUBLAS_OP_N,"
...@@ -289,6 +303,7 @@ cudnnSetOpTensorDescriptor(opTensorDesc, ...@@ -289,6 +303,7 @@ cudnnSetOpTensorDescriptor(opTensorDesc,
<< "&beta," // beta << "&beta," // beta
<< out[0].get_name() << "," << out[0].get_name() << ","
<< "n);\n"; << "n);\n";
writer << "cublasSetPointerMode(cublas_handle, CUBLAS_POINTER_MODE_DEVICE);\n"; writer << "cublasSetPointerMode(cublas_handle, CUBLAS_POINTER_MODE_DEVICE);\n";
writer.indent--; writer.indent--;
writer << "}\n"; writer << "}\n";
......
...@@ -706,7 +706,6 @@ TEST(${BACKEND_NAME}, backwards_dot_scalar_scalar) ...@@ -706,7 +706,6 @@ TEST(${BACKEND_NAME}, backwards_dot_scalar_scalar)
TEST(${BACKEND_NAME}, backwards_dot_scalar_tensor) TEST(${BACKEND_NAME}, backwards_dot_scalar_tensor)
{ {
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
auto manager = runtime::Manager::get("${BACKEND_NAME}"); auto manager = runtime::Manager::get("${BACKEND_NAME}");
auto backend = manager->allocate_backend(); auto backend = manager->allocate_backend();
...@@ -728,7 +727,6 @@ TEST(${BACKEND_NAME}, backwards_dot_scalar_tensor) ...@@ -728,7 +727,6 @@ TEST(${BACKEND_NAME}, backwards_dot_scalar_tensor)
TEST(${BACKEND_NAME}, backwards_dot_tensor_scalar) TEST(${BACKEND_NAME}, backwards_dot_tensor_scalar)
{ {
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
auto manager = runtime::Manager::get("${BACKEND_NAME}"); auto manager = runtime::Manager::get("${BACKEND_NAME}");
auto backend = manager->allocate_backend(); auto backend = manager->allocate_backend();
...@@ -750,7 +748,6 @@ TEST(${BACKEND_NAME}, backwards_dot_tensor_scalar) ...@@ -750,7 +748,6 @@ TEST(${BACKEND_NAME}, backwards_dot_tensor_scalar)
TEST(${BACKEND_NAME}, backwards_dot_vector_vector) TEST(${BACKEND_NAME}, backwards_dot_vector_vector)
{ {
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
auto manager = runtime::Manager::get("${BACKEND_NAME}"); auto manager = runtime::Manager::get("${BACKEND_NAME}");
auto backend = manager->allocate_backend(); auto backend = manager->allocate_backend();
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment