Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
f5cd6381
Unverified
Commit
f5cd6381
authored
Feb 28, 2018
by
Fenglei
Committed by
GitHub
Feb 28, 2018
Browse files
Options
Browse Files
Download
Plain Diff
Merge pull request #544 from NervanaSystems/tfl/gpu_dot
Fix bugs and enable dot in GPU
parents
a217ffef
6482b3d2
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
52 additions
and
29 deletions
+52
-29
gpu_call_frame.cpp
src/ngraph/runtime/gpu/gpu_call_frame.cpp
+1
-1
gpu_emitter.cpp
src/ngraph/runtime/gpu/gpu_emitter.cpp
+43
-15
gpu_tensor_view.cpp
src/ngraph/runtime/gpu/gpu_tensor_view.cpp
+1
-1
gpu_util.cpp
src/ngraph/runtime/gpu/gpu_util.cpp
+5
-0
gpu_util.hpp
src/ngraph/runtime/gpu/gpu_util.hpp
+1
-0
backend_test.in.cpp
test/backend_test.in.cpp
+1
-12
No files found.
src/ngraph/runtime/gpu/gpu_call_frame.cpp
View file @
f5cd6381
...
...
@@ -47,7 +47,7 @@ runtime::gpu::GPU_CallFrame::GPU_CallFrame(std::shared_ptr<GPU_ExternalFunction>
}
// Pass scalars as reference on the Device
cublasSetPointerMode
(
m_cublas_handle
,
CUBLAS_POINTER_MODE_
HOST
);
cublasSetPointerMode
(
m_cublas_handle
,
CUBLAS_POINTER_MODE_
DEVICE
);
}
runtime
::
gpu
::
GPU_CallFrame
::~
GPU_CallFrame
()
...
...
src/ngraph/runtime/gpu/gpu_emitter.cpp
View file @
f5cd6381
...
...
@@ -149,25 +149,54 @@ void runtime::gpu::GPU_Emitter::EmitDot(codegen::CodeWriter& writer,
const
vector
<
runtime
::
gpu
::
GPU_TensorViewWrapper
>&
args
,
const
vector
<
runtime
::
gpu
::
GPU_TensorViewWrapper
>&
out
)
{
throw
std
::
runtime_error
(
n
->
get_name
()
+
" is not implemented."
);
const
ngraph
::
op
::
Dot
*
dot
=
static_cast
<
const
ngraph
::
op
::
Dot
*>
(
n
);
const
Shape
&
arg0_shape
=
args
[
0
].
get_shape
();
const
Shape
&
arg1_shape
=
args
[
1
].
get_shape
();
if
(
arg0_shape
.
empty
()
||
arg1_shape
.
empty
())
{
auto
&
first
=
(
arg0_shape
.
empty
()
?
args
[
0
]
:
args
[
1
]);
auto
&
second
=
(
arg0_shape
.
empty
()
?
args
[
1
]
:
args
[
0
]);
writer
<<
"{ // "
<<
n
->
get_name
()
<<
"
\n
"
;
writer
.
indent
++
;
writer
<<
"int count = "
<<
second
.
get_size
()
<<
";
\n
"
;
writer
<<
"if(count == 0) return;
\n
"
;
writer
<<
"cublasScopy("
<<
"cublas_handle,"
<<
"count ,"
<<
second
.
get_name
()
<<
","
<<
"1,"
<<
out
[
0
].
get_name
()
<<
", 1);
\n
"
;
writer
<<
"cublasSscal("
<<
"cublas_handle,"
<<
"count ,"
<<
first
.
get_name
()
<<
","
<<
out
[
0
].
get_name
()
<<
", 1);
\n
"
;
writer
.
indent
--
;
writer
<<
"}
\n
"
;
return
;
}
//return if output size is 0;
if
(
out
[
0
].
get_size
()
==
0
)
{
writer
<<
"{ // "
<<
n
->
get_name
()
<<
"
\n
"
;
writer
.
indent
++
;
writer
<<
"cublasSdot("
<<
"cublas_handle,"
<<
second
.
get_size
()
<<
","
<<
first
.
get_name
()
<<
","
<<
"1,"
<<
second
.
get_name
()
<<
","
<<
"1,"
<<
out
[
0
].
get_name
()
<<
");
\n
"
;
writer
<<
"return;
\n
"
;
writer
.
indent
--
;
writer
<<
"}
\n
"
;
return
;
}
else
if
((
arg0_shape
.
size
()
==
1
)
&&
(
arg1_shape
.
size
()
==
1
))
//set output to 0 if input size is 0
if
(
args
[
0
].
get_size
()
==
0
||
args
[
1
].
get_size
()
==
0
)
{
writer
<<
"{ // "
<<
n
->
get_name
()
<<
"
\n
"
;
writer
.
indent
++
;
writer
<<
"runtime::gpu::cuda_memset("
<<
out
[
0
].
get_name
()
<<
", 0, "
<<
out
[
0
].
get_size
()
<<
" * sizeof(float));
\n
"
;
writer
<<
"return;
\n
"
;
writer
.
indent
--
;
writer
<<
"}
\n
"
;
return
;
}
if
((
arg0_shape
.
size
()
==
1
)
&&
(
arg1_shape
.
size
()
==
1
))
{
writer
<<
"{ // "
<<
n
->
get_name
()
<<
"
\n
"
;
writer
.
indent
++
;
...
...
@@ -182,10 +211,9 @@ void runtime::gpu::GPU_Emitter::EmitDot(codegen::CodeWriter& writer,
{
writer
<<
"{ // "
<<
n
->
get_name
()
<<
"
\n
"
;
writer
.
indent
++
;
writer
<<
"
static
const float alpha = 1.0;
\n
"
;
writer
<<
"
static const float beta = 1.
0;
\n
"
;
writer
<<
"const float alpha = 1.0;
\n
"
;
writer
<<
"
const float beta =
0;
\n
"
;
writer
<<
"cublasSetPointerMode(cublas_handle, CUBLAS_POINTER_MODE_HOST);
\n
"
;
;
writer
<<
"cublasSgemv("
<<
"cublas_handle,"
<<
"CUBLAS_OP_T,"
<<
arg0_shape
[
0
]
<<
","
<<
arg0_shape
[
1
]
<<
","
...
...
@@ -195,6 +223,7 @@ void runtime::gpu::GPU_Emitter::EmitDot(codegen::CodeWriter& writer,
<<
"&beta,"
// beta
<<
out
[
0
].
get_name
()
<<
","
<<
"1);
\n
"
;
writer
<<
"cublasSetPointerMode(cublas_handle, CUBLAS_POINTER_MODE_DEVICE);
\n
"
;
writer
.
indent
--
;
writer
<<
"}
\n
"
;
}
...
...
@@ -209,8 +238,8 @@ void runtime::gpu::GPU_Emitter::EmitDot(codegen::CodeWriter& writer,
}
writer
<<
"{ // "
<<
n
->
get_name
()
<<
"
\n
"
;
writer
.
indent
++
;
writer
<<
"
static
const float alpha = 1.0;
\n
"
;
writer
<<
"
static
const float beta = 0.0;
\n
"
;
writer
<<
"const float alpha = 1.0;
\n
"
;
writer
<<
"const float beta = 0.0;
\n
"
;
writer
<<
"int m = "
<<
arg0_shape
[
0
]
<<
";
\n
"
;
writer
<<
"int n = "
<<
arg1_shape
[
1
]
<<
";
\n
"
;
writer
<<
"int k = "
<<
arg0_shape
[
0
]
<<
";
\n
"
;
...
...
@@ -229,12 +258,13 @@ void runtime::gpu::GPU_Emitter::EmitDot(codegen::CodeWriter& writer,
<<
"&beta,"
// beta
<<
out
[
0
].
get_name
()
<<
","
<<
"n);
\n
"
;
writer
<<
"cublasSetPointerMode(cublas_handle, CUBLAS_POINTER_MODE_DEVICE);
\n
"
;
writer
.
indent
--
;
writer
<<
"}
\n
"
;
}
else
{
// General ND Call?
throw
std
::
runtime_error
(
n
->
get_name
()
+
" with more then 2D is not implemented."
);
}
}
...
...
@@ -513,8 +543,6 @@ void runtime::gpu::GPU_Emitter::EmitReshape(codegen::CodeWriter& writer,
writer
.
indent
++
;
writer
<<
"static const float alpha = 1.0;
\n
"
;
writer
<<
"static const float beta = 0.0;
\n
"
;
writer
<<
"cublasSetPointerMode(cublas_handle, CUBLAS_POINTER_MODE_HOST);
\n
"
;
;
writer
<<
"cublasSgeam("
<<
"cublas_handle,"
<<
"CUBLAS_OP_T,"
...
...
src/ngraph/runtime/gpu/gpu_tensor_view.cpp
View file @
f5cd6381
...
...
@@ -38,7 +38,7 @@ runtime::gpu::GPU_TensorView::GPU_TensorView(const ngraph::element::Type& elemen
m_descriptor
->
set_tensor_view_layout
(
std
::
make_shared
<
ngraph
::
descriptor
::
layout
::
DenseTensorViewLayout
>
(
*
m_descriptor
));
m_buffer_size
=
m_descriptor
->
get_tensor_view_layout
()
->
get_size
(
)
*
element_type
.
size
();
m_buffer_size
=
shape_size
(
shape
)
*
element_type
.
size
();
if
(
m_buffer_size
>
0
)
{
cudaMalloc
((
void
**
)
&
m_allocated_buffer_pool
,
m_buffer_size
);
...
...
src/ngraph/runtime/gpu/gpu_util.cpp
View file @
f5cd6381
...
...
@@ -64,3 +64,8 @@ void runtime::gpu::cuda_memcpyHtD(void* d, void* s, size_t buffer_size)
{
cudaMemcpy
(
d
,
s
,
buffer_size
,
cudaMemcpyHostToDevice
);
}
// Thin wrapper: forwards d, value and buffer_size unchanged to cudaMemset.
// buffer_size is a byte count; the EmitDot call site in this change passes
// `<output element count> * sizeof(float)`.
// Whatever cudaMemset returns is ignored, so a failed call is not reported to the caller.
// NOTE(review): d is presumably a device pointer (the call site passes an output
// tensor name) - confirm against the GPU tensor view allocation.
void
runtime
::
gpu
::
cuda_memset
(
void
*
d
,
int
value
,
size_t
buffer_size
)
{
cudaMemset
(
d
,
value
,
buffer_size
);
}
src/ngraph/runtime/gpu/gpu_util.hpp
View file @
f5cd6381
...
...
@@ -63,6 +63,7 @@ namespace ngraph
void
*
create_gpu_buffer
(
size_t
buffer_size
);
void
cuda_memcpyDtD
(
void
*
d
,
void
*
s
,
size_t
element_count
,
size_t
element_size
);
void
cuda_memcpyHtD
(
void
*
d
,
void
*
s
,
size_t
buffer_size
);
void
cuda_memset
(
void
*
d
,
int
value
,
size_t
buffer_size
);
}
}
}
test/backend_test.in.cpp
View file @
f5cd6381
...
...
@@ -716,7 +716,6 @@ TEST(${BACKEND_NAME}, floor)
TEST
(
$
{
BACKEND_NAME
},
dot_0_0
)
{
SKIP_TEST_FOR
(
"GPU"
,
"${BACKEND_NAME}"
);
SKIP_TEST_FOR
(
"ARGON"
,
"${BACKEND_NAME}"
);
Shape
shape
{
0
};
...
...
@@ -746,7 +745,6 @@ TEST(${BACKEND_NAME}, dot_0_0)
TEST
(
$
{
BACKEND_NAME
},
dot_matrix_2x0_0x2
)
{
SKIP_TEST_FOR
(
"GPU"
,
"${BACKEND_NAME}"
);
SKIP_TEST_FOR
(
"ARGON"
,
"${BACKEND_NAME}"
);
Shape
shape_a
{
2
,
0
};
...
...
@@ -783,7 +781,6 @@ TEST(${BACKEND_NAME}, dot_matrix_2x0_0x2)
TEST
(
$
{
BACKEND_NAME
},
dot_matrix_0x2_2x0
)
{
SKIP_TEST_FOR
(
"GPU"
,
"${BACKEND_NAME}"
);
SKIP_TEST_FOR
(
"ARGON"
,
"${BACKEND_NAME}"
);
Shape
shape_a
{
0
,
2
};
...
...
@@ -812,7 +809,6 @@ TEST(${BACKEND_NAME}, dot_matrix_0x2_2x0)
TEST
(
$
{
BACKEND_NAME
},
dot_matrix_3x2_2x0
)
{
SKIP_TEST_FOR
(
"GPU"
,
"${BACKEND_NAME}"
);
SKIP_TEST_FOR
(
"ARGON"
,
"${BACKEND_NAME}"
);
Shape
shape_a
{
3
,
2
};
...
...
@@ -841,7 +837,6 @@ TEST(${BACKEND_NAME}, dot_matrix_3x2_2x0)
TEST
(
$
{
BACKEND_NAME
},
dot_scalar_0x2
)
{
SKIP_TEST_FOR
(
"GPU"
,
"${BACKEND_NAME}"
);
SKIP_TEST_FOR
(
"ARGON"
,
"${BACKEND_NAME}"
);
Shape
shape_a
{};
...
...
@@ -869,7 +864,6 @@ TEST(${BACKEND_NAME}, dot_scalar_0x2)
TEST
(
$
{
BACKEND_NAME
},
dot_2x0_0
)
{
SKIP_TEST_FOR
(
"GPU"
,
"${BACKEND_NAME}"
);
SKIP_TEST_FOR
(
"ARGON"
,
"${BACKEND_NAME}"
);
Shape
shape_a
{
2
,
0
};
...
...
@@ -900,7 +894,6 @@ TEST(${BACKEND_NAME}, dot_2x0_0)
TEST
(
$
{
BACKEND_NAME
},
dot1d
)
{
SKIP_TEST_FOR
(
"GPU"
,
"${BACKEND_NAME}"
);
Shape
shape
{
4
};
auto
A
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
shape
);
auto
B
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
shape
);
...
...
@@ -925,7 +918,6 @@ TEST(${BACKEND_NAME}, dot1d)
TEST
(
$
{
BACKEND_NAME
},
dot2d
)
{
SKIP_TEST_FOR
(
"GPU"
,
"${BACKEND_NAME}"
);
Shape
shape
{
2
,
2
};
auto
A
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
shape
);
auto
B
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
shape
);
...
...
@@ -1052,7 +1044,6 @@ TEST(${BACKEND_NAME}, dot3d_2d)
TEST
(
$
{
BACKEND_NAME
},
dot_scalar_tensor_arg0
)
{
SKIP_TEST_FOR
(
"GPU"
,
"${BACKEND_NAME}"
);
Shape
shape_a
{};
Shape
shape_b
{
2
,
2
,
2
};
auto
A
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
shape_a
);
...
...
@@ -1077,7 +1068,6 @@ TEST(${BACKEND_NAME}, dot_scalar_tensor_arg0)
TEST
(
$
{
BACKEND_NAME
},
dot_scalar_tensor_arg1
)
{
SKIP_TEST_FOR
(
"GPU"
,
"${BACKEND_NAME}"
);
Shape
shape_a
{
2
,
2
,
2
};
Shape
shape_b
{};
auto
A
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
shape_a
);
...
...
@@ -1102,7 +1092,6 @@ TEST(${BACKEND_NAME}, dot_scalar_tensor_arg1)
TEST
(
$
{
BACKEND_NAME
},
dot_scalar_scalar
)
{
SKIP_TEST_FOR
(
"GPU"
,
"${BACKEND_NAME}"
);
Shape
shape
{};
auto
A
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
shape
);
auto
B
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
shape
);
...
...
@@ -1126,7 +1115,6 @@ TEST(${BACKEND_NAME}, dot_scalar_scalar)
TEST
(
$
{
BACKEND_NAME
},
dot_matrix_vector
)
{
SKIP_TEST_FOR
(
"GPU"
,
"${BACKEND_NAME}"
);
Shape
shape_a
{
4
,
4
};
Shape
shape_b
{
4
};
auto
A
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
shape_a
);
...
...
@@ -6245,6 +6233,7 @@ TEST(${BACKEND_NAME}, convolution_outlining)
TEST
(
$
{
BACKEND_NAME
},
mkldnn_layouts
)
{
ONLY_ENABLE_TEST_FOR
(
"CPU"
,
"${BACKEND_NAME}"
);
Shape
shape_a
{
1
,
16
,
2
,
2
};
auto
A
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
shape_a
);
Shape
shape_b
{
32
,
16
,
1
,
1
};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment