Commit 680be054 authored by Fenglei's avatar Fenglei Committed by Nick Korovaiko

change get_shape().size() to get_size(), we need to check the actual size (#1051)

parent 1adcf3b2
......@@ -1161,7 +1161,7 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
<< ", temp_d, " << out[0].get_size() << " * "
<< out[0].get_element_type().size() << ");\n";
}
else if (args[0].get_shape().size() == out[0].get_shape().size())
else if (args[0].get_size() == out[0].get_size())
{
kernel::emit_memcpyDtD(writer, out[0], args[0]);
}
......@@ -1209,7 +1209,7 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
<< ", temp_d, " << out[0].get_size() << " * "
<< out[0].get_element_type().size() << ");\n";
}
else if (args[0].get_shape().size() == out[0].get_shape().size())
else if (args[0].get_size() == out[0].get_size())
{
kernel::emit_memcpyDtD(writer, out[0], args[0]);
}
......@@ -1247,7 +1247,7 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
{
kernel::emit_memset(writer, out[0], 0);
}
else if (args[0].get_shape().size() == out[0].get_shape().size())
else if (args[0].get_size() == out[0].get_size())
{
kernel::emit_memcpyDtD(writer, out[0], args[0]);
}
......@@ -1291,7 +1291,7 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
<< ", (void*)temp.data(), " << out[0].get_size() << " * "
<< out[0].get_element_type().size() << ");\n";
}
else if (args[0].get_shape().size() == out[0].get_shape().size())
else if (args[0].get_size() == out[0].get_size())
{
kernel::emit_memcpyDtD(writer, out[0], args[0]);
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment