Commit 9fd64b6f authored by fenglei.tian's avatar fenglei.tian

fix bug for 2d2d2 dot, enable some bprop dot tests

parent dd5c77e0
......@@ -226,18 +226,30 @@ cudnnSetOpTensorDescriptor(opTensorDesc,
return;
}
if ((arg0_shape.size() == 1) && (arg1_shape.size() == 1))
{
writer << "{ // " << node->get_name() << "\n";
writer.indent++;
writer << "cublasSdot("
<< "cublas_handle," << arg0_shape[0] << "," << args[0].get_name() << ","
<< "1," << args[1].get_name() << ","
<< "1," << out[0].get_name() << ");\n";
writer.indent--;
writer << "}\n";
}
else if ((arg0_shape.size() == 2) && (arg1_shape.size() == 1))
//case that can be treat as dot1d
if ((arg0_shape.size() == arg1_shape.size()) &&
(arg0_shape.size() == dot->get_reduction_axes_count()))
{
for (int i = 0; i < arg0_shape.size(); i++)
{
if (arg0_shape[i] != arg1_shape[i])
{
throw std::runtime_error("two input shape is not correct for dot;");
}
}
writer << "{ // " << node->get_name() << "\n";
writer.indent++;
writer << "cublasSdot("
<< "cublas_handle," << args[0].get_size() << ","
<< args[0].get_name() << ","
<< "1," << args[1].get_name() << ","
<< "1," << out[0].get_name() << ");\n";
writer.indent--;
writer << "}\n";
}
else if ((arg0_shape.size() == 2) && (arg1_shape.size() == 1) &&
(dot->get_reduction_axes_count() == 1))
{
writer << "{ // " << node->get_name() << "\n";
writer.indent++;
......@@ -258,7 +270,8 @@ cudnnSetOpTensorDescriptor(opTensorDesc,
writer.indent--;
writer << "}\n";
}
else if ((arg0_shape.size() == 2) && (arg1_shape.size() == 2))
else if ((arg0_shape.size() == 2) && (arg1_shape.size() == 2) &&
(dot->get_reduction_axes_count() == 1))
{
// GEMM Call
if (arg0_shape[0] != out[0].get_shape()[0] || // m
......@@ -275,6 +288,7 @@ cudnnSetOpTensorDescriptor(opTensorDesc,
writer << "int n = " << arg1_shape[1] << ";\n";
writer << "int k = " << arg0_shape[0] << ";\n";
writer << "cublasSetPointerMode(cublas_handle, CUBLAS_POINTER_MODE_HOST);\n";
writer << "cublasSgemm("
<< "cublas_handle,"
<< "CUBLAS_OP_N,"
......@@ -289,6 +303,7 @@ cudnnSetOpTensorDescriptor(opTensorDesc,
<< "&beta," // beta
<< out[0].get_name() << ","
<< "n);\n";
writer << "cublasSetPointerMode(cublas_handle, CUBLAS_POINTER_MODE_DEVICE);\n";
writer.indent--;
writer << "}\n";
......
......@@ -706,7 +706,6 @@ TEST(${BACKEND_NAME}, backwards_dot_scalar_scalar)
TEST(${BACKEND_NAME}, backwards_dot_scalar_tensor)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
auto manager = runtime::Manager::get("${BACKEND_NAME}");
auto backend = manager->allocate_backend();
......@@ -728,7 +727,6 @@ TEST(${BACKEND_NAME}, backwards_dot_scalar_tensor)
TEST(${BACKEND_NAME}, backwards_dot_tensor_scalar)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
auto manager = runtime::Manager::get("${BACKEND_NAME}");
auto backend = manager->allocate_backend();
......@@ -750,7 +748,6 @@ TEST(${BACKEND_NAME}, backwards_dot_tensor_scalar)
TEST(${BACKEND_NAME}, backwards_dot_vector_vector)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
auto manager = runtime::Manager::get("${BACKEND_NAME}");
auto backend = manager->allocate_backend();
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment