Commit 26afa898 authored by fenglei.tian

fix bugs in dot op

parent dd3af46c
......@@ -42,7 +42,7 @@ runtime::gpu::GPU_CallFrame::GPU_CallFrame(std::shared_ptr<GPU_ExternalFunction>
}
// Pass scalars as reference on the Host
- cublasSetPointerMode(m_cublas_handle, CUBLAS_POINTER_MODE_HOST);
+ cublasSetPointerMode(m_cublas_handle, CUBLAS_POINTER_MODE_DEVICE);
}
runtime::gpu::GPU_CallFrame::~GPU_CallFrame()
......
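Note: the hunk above makes CUBLAS_POINTER_MODE_DEVICE the default for the call frame's handle; the emitter hunks below then switch to HOST mode only around the gemv/gemm calls that pass alpha/beta from the host. As a reference, a minimal sketch of what the pointer mode controls (function, handle, and buffer names here are illustrative, not taken from the commit):

#include <cublas_v2.h>

// Illustrative sketch only -- not part of the commit.
// cublasSetPointerMode controls where cuBLAS reads scalar arguments (alpha, beta).
void gemv_with_host_scalars(cublasHandle_t handle, int m, int n,
                            const float* A, const float* x, float* y)
{
    const float alpha = 1.0f;
    const float beta = 0.0f;
    // HOST mode: &alpha / &beta are ordinary host pointers.
    cublasSetPointerMode(handle, CUBLAS_POINTER_MODE_HOST);
    cublasSgemv(handle, CUBLAS_OP_T, m, n, &alpha, A, m, x, 1, &beta, y, 1);
    // Restore the default this commit sets on the call frame's handle;
    // in DEVICE mode, scalar arguments must point into device memory.
    cublasSetPointerMode(handle, CUBLAS_POINTER_MODE_DEVICE);
}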
......@@ -141,25 +141,54 @@ void runtime::gpu::GPU_Emitter::EmitDot(codegen::CodeWriter& writer,
const vector<runtime::gpu::GPU_TensorViewWrapper>& args,
const vector<runtime::gpu::GPU_TensorViewWrapper>& out)
{
- throw std::runtime_error(n->get_name() + " is not implemented.");
const ngraph::op::Dot* dot = static_cast<const ngraph::op::Dot*>(n);
const Shape& arg0_shape = args[0].get_shape();
const Shape& arg1_shape = args[1].get_shape();
if (arg0_shape.empty() || arg1_shape.empty())
{
auto& first = (arg0_shape.empty() ? args[0] : args[1]);
auto& second = (arg0_shape.empty() ? args[1] : args[0]);
writer << "{ // " << n->get_name() << "\n";
writer.indent++;
writer << "int count = " << second.get_size() << ";\n";
writer << "if(count == 0) return;\n";
writer << "cublasScopy("
<< "cublas_handle,"
<< "count ," << second.get_name() << ","
<< "1," << out[0].get_name() << ", 1);\n";
writer << "cublasSscal("
<< "cublas_handle,"
<< "count ," << first.get_name() << "," << out[0].get_name() << ", 1);\n";
writer.indent--;
writer << "}\n";
return;
}
//return if out size is 0;
if (out[0].get_size() == 0)
{
writer << "{ // " << n->get_name() << "\n";
writer.indent++;
writer << "cublasSdot("
<< "cublas_handle," << second.get_size() << "," << first.get_name() << ","
<< "1," << second.get_name() << ","
<< "1," << out[0].get_name() << ");\n";
writer << "return;\n";
writer.indent--;
writer << "}\n";
+ return;
}
- else if ((arg0_shape.size() == 1) && (arg1_shape.size() == 1))
+ // set output to 0 if an input has size 0
+ if (args[0].get_size() == 0 || args[1].get_size() == 0)
+ {
+ writer << "{ // " << n->get_name() << "\n";
+ writer.indent++;
+ writer << "runtime::gpu::cuda_memset(" << out[0].get_name() << ", 0, " << out[0].get_size()
+ << " * sizeof(float));\n";
+ writer << "return;\n";
+ writer.indent--;
+ writer << "}\n";
+ return;
+ }
+ if ((arg0_shape.size() == 1) && (arg1_shape.size() == 1))
{
writer << "{ // " << n->get_name() << "\n";
writer.indent++;
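For readability, the scalar-times-tensor branch added above boils down to a copy followed by an in-place scale rather than a dot product; a minimal sketch of the runtime call sequence the emitted code performs (function and buffer names are illustrative, not the emitter's actual output):

#include <cublas_v2.h>

// Illustrative sketch of the emitted scalar x tensor path (names are made up).
void scalar_times_tensor(cublasHandle_t cublas_handle, int count,
                         const float* tensor_arg, const float* scalar_arg, float* out)
{
    if (count == 0)
        return;                                                // nothing to do for empty tensors
    cublasScopy(cublas_handle, count, tensor_arg, 1, out, 1);  // out = tensor
    cublasSscal(cublas_handle, count, scalar_arg, out, 1);     // out *= *scalar (read via DEVICE pointer mode)
}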
......@@ -174,10 +203,9 @@ void runtime::gpu::GPU_Emitter::EmitDot(codegen::CodeWriter& writer,
{
writer << "{ // " << n->get_name() << "\n";
writer.indent++;
writer << "static const float alpha = 1.0;\n";
writer << "static const float beta = 1.0;\n";
writer << "const float alpha = 1.0;\n";
writer << "const float beta = 0;\n";
writer << "cublasSetPointerMode(cublas_handle, CUBLAS_POINTER_MODE_HOST);\n";
;
writer << "cublasSgemv("
<< "cublas_handle,"
<< "CUBLAS_OP_T," << arg0_shape[0] << "," << arg0_shape[1] << ","
......@@ -187,6 +215,7 @@ void runtime::gpu::GPU_Emitter::EmitDot(codegen::CodeWriter& writer,
<< "&beta," // beta
<< out[0].get_name() << ","
<< "1);\n";
writer << "cublasSetPointerMode(cublas_handle, CUBLAS_POINTER_MODE_DEVICE);\n";
writer.indent--;
writer << "}\n";
}
......@@ -201,8 +230,8 @@ void runtime::gpu::GPU_Emitter::EmitDot(codegen::CodeWriter& writer,
}
writer << "{ // " << n->get_name() << "\n";
writer.indent++;
writer << "static const float alpha = 1.0;\n";
writer << "static const float beta = 0.0;\n";
writer << "const float alpha = 1.0;\n";
writer << "const float beta = 0.0;\n";
writer << "int m = " << arg0_shape[0] << ";\n";
writer << "int n = " << arg1_shape[1] << ";\n";
writer << "int k = " << arg0_shape[0] << ";\n";
......@@ -221,12 +250,13 @@ void runtime::gpu::GPU_Emitter::EmitDot(codegen::CodeWriter& writer,
<< "&beta," // beta
<< out[0].get_name() << ","
<< "n);\n";
writer << "cublasSetPointerMode(cublas_handle, CUBLAS_POINTER_MODE_DEVICE);\n";
writer.indent--;
writer << "}\n";
}
else
{
// General ND Call?
throw std::runtime_error(n->get_name() + " with more then 2D is not implemented.");
}
}
......@@ -505,7 +535,7 @@ void runtime::gpu::GPU_Emitter::EmitReshape(codegen::CodeWriter& writer,
writer.indent++;
writer << "static const float alpha = 1.0;\n";
writer << "static const float beta = 0.0;\n";
writer << "cublasSetPointerMode(cublas_handle, CUBLAS_POINTER_MODE_HOST);\n";
//writer << "cublasSetPointerMode(cublas_handle, CUBLAS_POINTER_MODE_HOST);\n";
;
writer << "cublasSgeam("
<< "cublas_handle,"
......
......@@ -38,7 +38,7 @@ runtime::gpu::GPU_TensorView::GPU_TensorView(const ngraph::element::Type& elemen
m_descriptor->set_tensor_view_layout(
std::make_shared<ngraph::descriptor::layout::DenseTensorViewLayout>(*m_descriptor));
- m_buffer_size = m_descriptor->get_tensor_view_layout()->get_size() * element_type.size();
+ m_buffer_size = shape_size(shape) * element_type.size();
if (m_buffer_size > 0)
{
cudaMalloc((void**)&m_allocated_buffer_pool, m_buffer_size);
......
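The buffer is now sized directly from the shape; assuming ngraph's shape_size is the product of the dimensions, zero-element shapes (as in the dot_0_0 and dot_matrix_2x0_0x2 tests re-enabled below) yield a zero-byte buffer, so the cudaMalloc branch is skipped. A self-contained sketch of that assumption:

#include <cstddef>
#include <functional>
#include <numeric>
#include <vector>

// Sketch only: a stand-in for ngraph::shape_size under the assumption that it
// is the product of the shape's dimensions.
static size_t product_of_dims(const std::vector<size_t>& shape)
{
    return std::accumulate(shape.begin(), shape.end(), size_t{1}, std::multiplies<size_t>());
}

// {2, 3} -> 6 elements -> 6 * sizeof(float) = 24 bytes allocated
// {2, 0} -> 0 elements -> m_buffer_size == 0, no allocation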
......@@ -64,3 +64,8 @@ void runtime::gpu::cuda_memcpyHtD(void* d, void* s, size_t buffer_size)
{
cudaMemcpy(d, s, buffer_size, cudaMemcpyHostToDevice);
}
+ void runtime::gpu::cuda_memset(void* d, int value, size_t buffer_size)
+ {
+ cudaMemset(d, value, buffer_size);
+ }
......@@ -27,6 +27,7 @@ namespace ngraph
void* create_gpu_buffer(size_t buffer_size);
void cuda_memcpyDtD(void* d, void* s, size_t element_count, size_t element_size);
void cuda_memcpyHtD(void* d, void* s, size_t buffer_size);
+ void cuda_memset(void* d, int value, size_t buffer_size);
}
}
}
......@@ -712,7 +712,6 @@ TEST(${BACKEND_NAME}, floor)
TEST(${BACKEND_NAME}, dot_0_0)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape{0};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto B = make_shared<op::Parameter>(element::f32, shape);
......@@ -740,7 +739,6 @@ TEST(${BACKEND_NAME}, dot_0_0)
TEST(${BACKEND_NAME}, dot_matrix_2x0_0x2)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape_a{2, 0};
Shape shape_b{0, 2};
Shape shape_r{2, 2};
......@@ -775,7 +773,6 @@ TEST(${BACKEND_NAME}, dot_matrix_2x0_0x2)
TEST(${BACKEND_NAME}, dot_matrix_0x2_2x0)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape_a{0, 2};
auto A = make_shared<op::Parameter>(element::f32, shape_a);
Shape shape_b{2, 0};
......@@ -801,7 +798,6 @@ TEST(${BACKEND_NAME}, dot_matrix_0x2_2x0)
TEST(${BACKEND_NAME}, dot_matrix_3x2_2x0)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape_a{3, 2};
auto A = make_shared<op::Parameter>(element::f32, shape_a);
Shape shape_b{2, 0};
......@@ -827,7 +823,6 @@ TEST(${BACKEND_NAME}, dot_matrix_3x2_2x0)
TEST(${BACKEND_NAME}, dot_scalar_0x2)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape_a{};
auto A = make_shared<op::Parameter>(element::f32, shape_a);
Shape shape_b{0, 2};
......@@ -853,7 +848,6 @@ TEST(${BACKEND_NAME}, dot_scalar_0x2)
TEST(${BACKEND_NAME}, dot_2x0_0)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape_a{2, 0};
auto A = make_shared<op::Parameter>(element::f32, shape_a);
Shape shape_b{0};
......@@ -882,7 +876,6 @@ TEST(${BACKEND_NAME}, dot_2x0_0)
TEST(${BACKEND_NAME}, dot1d)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape{4};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto B = make_shared<op::Parameter>(element::f32, shape);
......@@ -907,7 +900,6 @@ TEST(${BACKEND_NAME}, dot1d)
TEST(${BACKEND_NAME}, dot2d)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape{2, 2};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto B = make_shared<op::Parameter>(element::f32, shape);
......@@ -1034,7 +1026,6 @@ TEST(${BACKEND_NAME}, dot3d_2d)
TEST(${BACKEND_NAME}, dot_scalar_tensor_arg0)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape_a{};
Shape shape_b{2, 2, 2};
auto A = make_shared<op::Parameter>(element::f32, shape_a);
......@@ -1059,7 +1050,6 @@ TEST(${BACKEND_NAME}, dot_scalar_tensor_arg0)
TEST(${BACKEND_NAME}, dot_scalar_tensor_arg1)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape_a{2, 2, 2};
Shape shape_b{};
auto A = make_shared<op::Parameter>(element::f32, shape_a);
......@@ -1084,7 +1074,6 @@ TEST(${BACKEND_NAME}, dot_scalar_tensor_arg1)
TEST(${BACKEND_NAME}, dot_scalar_scalar)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape{};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto B = make_shared<op::Parameter>(element::f32, shape);
......@@ -1108,7 +1097,6 @@ TEST(${BACKEND_NAME}, dot_scalar_scalar)
TEST(${BACKEND_NAME}, dot_matrix_vector)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape_a{4, 4};
Shape shape_b{4};
auto A = make_shared<op::Parameter>(element::f32, shape_a);
......@@ -6058,6 +6046,7 @@ TEST(${BACKEND_NAME}, convolution_outlining)
TEST(${BACKEND_NAME}, convolution_layout)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape_a{1, 16, 2, 2};
auto A = make_shared<op::Parameter>(element::f32, shape_a);
Shape shape_b{32, 16, 1, 1};
......