Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
e1b2f54c
Commit
e1b2f54c
authored
Mar 02, 2018
by
fenglei.tian
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
add gpu broadcast
parent
f4ff1c3b
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
63 additions
and
11 deletions
+63
-11
gpu_cuda_kernel_emitters.cpp
src/ngraph/runtime/gpu/gpu_cuda_kernel_emitters.cpp
+6
-0
gpu_cuda_kernel_emitters.hpp
src/ngraph/runtime/gpu/gpu_cuda_kernel_emitters.hpp
+2
-0
gpu_emitter.cpp
src/ngraph/runtime/gpu/gpu_emitter.cpp
+55
-1
backend_performance.cpp
test/backend_performance.cpp
+0
-1
backend_test.in.cpp
test/backend_test.in.cpp
+0
-9
No files found.
src/ngraph/runtime/gpu/gpu_cuda_kernel_emitters.cpp
View file @
e1b2f54c
...
...
@@ -61,6 +61,12 @@ namespace ngraph
0
));
// arguments
CUDA_SAFE_CALL
(
cuCtxSynchronize
());
// Retrieve and print output.
}
void
emit_broadcast
(
void
*
in
,
void
*
out
,
size_t
repeat_size
,
size_t
repeat_times
,
size_t
count
)
{
return
;
}
}
}
}
src/ngraph/runtime/gpu/gpu_cuda_kernel_emitters.hpp
View file @
e1b2f54c
...
...
@@ -27,6 +27,8 @@ namespace ngraph
namespace
gpu
{
void
emit_abs
(
void
*
in
,
void
*
out
,
size_t
count
);
void
emit_broadcast
(
void
*
in
,
void
*
out
,
size_t
repeat_size
,
size_t
repeat_times
,
size_t
count
);
}
}
}
src/ngraph/runtime/gpu/gpu_emitter.cpp
View file @
e1b2f54c
...
...
@@ -457,7 +457,62 @@ void runtime::gpu::GPU_Emitter::EmitBroadcast(
const
vector
<
runtime
::
gpu
::
GPU_TensorViewWrapper
>&
args
,
const
vector
<
runtime
::
gpu
::
GPU_TensorViewWrapper
>&
out
)
{
auto
broadcast
=
static_cast
<
const
ngraph
::
op
::
Broadcast
*>
(
n
);
auto
arg_shape
=
args
[
0
].
get_shape
();
auto
result_shape
=
out
[
0
].
get_shape
();
auto
&
axes
=
broadcast
->
get_broadcast_axes
();
if
(
axes
.
empty
())
{
writer
<<
"{ // "
<<
n
->
get_name
()
<<
"
\n
"
;
writer
.
indent
++
;
writer
<<
"runtime::gpu::cuda_memcpyDtD("
<<
out
[
0
].
get_name
()
<<
", "
<<
args
[
0
].
get_name
()
<<
", "
<<
out
[
0
].
get_size
()
<<
" * "
<<
out
[
0
].
get_element_type
().
size
()
<<
");
\n
"
;
writer
.
indent
--
;
writer
<<
"}
\n
"
;
return
;
}
vector
<
int
>
axes_v
;
std
::
copy
(
axes
.
begin
(),
axes
.
end
(),
std
::
back_inserter
(
axes_v
));
bool
is_one_axes
=
true
;
if
(
axes
.
size
()
!=
1
)
{
for
(
int
i
=
1
;
i
<
axes_v
.
size
();
i
++
)
{
if
(
axes_v
[
i
]
!=
axes_v
[
i
-
1
]
+
1
)
{
is_one_axes
=
false
;
break
;
}
}
}
if
(
is_one_axes
)
{
int
repeat_times
=
1
;
for
(
int
i
=
0
;
i
<
axes
.
size
();
i
++
)
{
repeat_times
*=
result_shape
[
axes_v
[
i
]];
}
int
repeat_size
=
1
;
for
(
int
i
=
*
axes
.
rbegin
();
i
<
result_shape
.
size
();
i
++
)
{
repeat_size
*=
result_shape
[
i
];
}
writer
<<
"{ // "
<<
n
->
get_name
()
<<
"
\n
"
;
writer
.
indent
++
;
writer
<<
"runtime::gpu::emit_broadcast("
<<
args
[
0
].
get_name
()
<<
", "
<<
out
[
0
].
get_name
()
<<
", "
<<
repeat_size
<<
", "
<<
repeat_times
<<
", "
<<
out
[
0
].
get_size
()
<<
");
\n
"
;
writer
.
indent
--
;
writer
<<
"}
\n
"
;
}
else
{
throw
std
::
runtime_error
(
n
->
get_name
()
+
" is not implemented."
);
}
}
void
runtime
::
gpu
::
GPU_Emitter
::
EmitConvert
(
codegen
::
CodeWriter
&
writer
,
...
...
@@ -474,7 +529,6 @@ void runtime::gpu::GPU_Emitter::EmitConstant(
const
vector
<
runtime
::
gpu
::
GPU_TensorViewWrapper
>&
args
,
const
vector
<
runtime
::
gpu
::
GPU_TensorViewWrapper
>&
out
)
{
throw
std
::
runtime_error
(
n
->
get_name
()
+
" is not implemented."
);
}
void
runtime
::
gpu
::
GPU_Emitter
::
EmitReshape
(
codegen
::
CodeWriter
&
writer
,
...
...
test/backend_performance.cpp
View file @
e1b2f54c
...
...
@@ -46,7 +46,6 @@ TEST(benchmark, mxnet_mnist_mlp_forward)
TEST
(
benchmark
,
gpu_mxnet_mnist_mlp_forward
)
{
SKIP_TEST_FOR
(
"GPU"
,
"${BACKEND_NAME}"
);
const
string
json_path
=
file_util
::
path_join
(
SERIALIZED_ZOO
,
"mxnet/mnist_mlp_forward.json"
);
run_benchmark
(
json_path
,
"GPU"
,
1000
);
}
...
...
test/backend_test.in.cpp
View file @
e1b2f54c
...
...
@@ -1620,7 +1620,6 @@ TEST(${BACKEND_NAME}, function_call)
TEST
(
$
{
BACKEND_NAME
},
broadcast_scalar_vector
)
{
SKIP_TEST_FOR
(
"GPU"
,
"${BACKEND_NAME}"
);
Shape
shape_a
{};
auto
A
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
shape_a
);
Shape
shape_r
{
4
};
...
...
@@ -1643,7 +1642,6 @@ TEST(${BACKEND_NAME}, broadcast_scalar_vector)
TEST
(
$
{
BACKEND_NAME
},
broadcast_scalar_matrix
)
{
SKIP_TEST_FOR
(
"GPU"
,
"${BACKEND_NAME}"
);
Shape
shape_a
{};
auto
A
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
shape_a
);
Shape
shape_r
{
2
,
2
};
...
...
@@ -1666,7 +1664,6 @@ TEST(${BACKEND_NAME}, broadcast_scalar_matrix)
TEST
(
$
{
BACKEND_NAME
},
broadcast_scalar_tensor
)
{
SKIP_TEST_FOR
(
"GPU"
,
"${BACKEND_NAME}"
);
Shape
shape_a
{};
auto
A
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
shape_a
);
Shape
shape_r
{
2
,
2
,
2
};
...
...
@@ -1689,7 +1686,6 @@ TEST(${BACKEND_NAME}, broadcast_scalar_tensor)
TEST
(
$
{
BACKEND_NAME
},
broadcast_trivial
)
{
SKIP_TEST_FOR
(
"GPU"
,
"${BACKEND_NAME}"
);
Shape
shape
{
2
,
2
,
2
};
auto
A
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
shape
);
auto
f
=
make_shared
<
Function
>
(
make_shared
<
op
::
Broadcast
>
(
A
,
shape
,
AxisSet
{}),
...
...
@@ -1711,7 +1707,6 @@ TEST(${BACKEND_NAME}, broadcast_trivial)
TEST
(
$
{
BACKEND_NAME
},
broadcast_vector_colwise
)
{
SKIP_TEST_FOR
(
"GPU"
,
"${BACKEND_NAME}"
);
Shape
shape_a
{
3
};
auto
A
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
shape_a
);
Shape
shape_r
{
3
,
4
};
...
...
@@ -1734,7 +1729,6 @@ TEST(${BACKEND_NAME}, broadcast_vector_colwise)
TEST
(
$
{
BACKEND_NAME
},
broadcast_vector_rowwise
)
{
SKIP_TEST_FOR
(
"GPU"
,
"${BACKEND_NAME}"
);
Shape
shape_a
{
4
};
auto
A
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
shape_a
);
Shape
shape_r
{
3
,
4
};
...
...
@@ -1806,7 +1800,6 @@ TEST(${BACKEND_NAME}, broadcast_vector_rowwise_int64)
TEST
(
$
{
BACKEND_NAME
},
broadcast_matrix_0
)
{
SKIP_TEST_FOR
(
"GPU"
,
"${BACKEND_NAME}"
);
Shape
shape_a
{
2
,
2
};
auto
A
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
shape_a
);
Shape
shape_r
{
2
,
2
,
2
};
...
...
@@ -1829,7 +1822,6 @@ TEST(${BACKEND_NAME}, broadcast_matrix_0)
TEST
(
$
{
BACKEND_NAME
},
broadcast_matrix_1
)
{
SKIP_TEST_FOR
(
"GPU"
,
"${BACKEND_NAME}"
);
Shape
shape_a
{
2
,
2
};
auto
A
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
shape_a
);
Shape
shape_r
{
2
,
2
,
2
};
...
...
@@ -1852,7 +1844,6 @@ TEST(${BACKEND_NAME}, broadcast_matrix_1)
TEST
(
$
{
BACKEND_NAME
},
broadcast_matrix_2
)
{
SKIP_TEST_FOR
(
"GPU"
,
"${BACKEND_NAME}"
);
Shape
shape_a
{
2
,
2
};
auto
A
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
shape_a
);
Shape
shape_r
{
2
,
2
,
2
};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment