Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
0a77e3d9
Commit
0a77e3d9
authored
Mar 08, 2018
by
Fenglei Tian
Browse files
Options
Browse Files
Download
Plain Diff
merge master and resolve conflict
parents
998d7c6b
529362b5
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
306 additions
and
85 deletions
+306
-85
gpu_cuda_kernel_builder.cpp
src/ngraph/runtime/gpu/gpu_cuda_kernel_builder.cpp
+13
-13
gpu_cuda_kernel_builder.hpp
src/ngraph/runtime/gpu/gpu_cuda_kernel_builder.hpp
+12
-12
gpu_cuda_kernel_emitters.cpp
src/ngraph/runtime/gpu/gpu_cuda_kernel_emitters.cpp
+1
-37
gpu_cuda_kernel_emitters.hpp
src/ngraph/runtime/gpu/gpu_cuda_kernel_emitters.hpp
+41
-1
gpu_cuda_kernel_ops.hpp
src/ngraph/runtime/gpu/gpu_cuda_kernel_ops.hpp
+153
-0
gpu_emitter.cpp
src/ngraph/runtime/gpu/gpu_emitter.cpp
+21
-0
gpu_emitter.hpp
src/ngraph/runtime/gpu/gpu_emitter.hpp
+7
-7
gpu_external_function.cpp
src/ngraph/runtime/gpu/gpu_external_function.cpp
+56
-0
backend_test.in.cpp
test/backend_test.in.cpp
+2
-15
No files found.
src/ngraph/runtime/gpu/gpu_cuda_kernel_builder.cpp
View file @
0a77e3d9
...
...
@@ -13,7 +13,6 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "ngraph/runtime/gpu/gpu_cuda_kernel_builder.hpp"
namespace
ngraph
...
...
@@ -22,10 +21,10 @@ namespace ngraph
{
namespace
gpu
{
void
CudaKernelBuilder
::
get_
1_element
_op
(
const
std
::
string
&
name
,
const
std
::
string
&
data_type
,
const
std
::
string
&
op
,
std
::
string
&
kernel
)
void
CudaKernelBuilder
::
get_
unary_elementwise
_op
(
const
std
::
string
&
name
,
const
std
::
string
&
data_type
,
const
std
::
string
&
op
,
std
::
string
&
kernel
)
{
kernel
=
R"(
extern "C" __global__
...
...
@@ -40,10 +39,10 @@ out[tid] =)" + op + "(in[tid]);\n" +
return
;
}
void
CudaKernelBuilder
::
get_
2_element
_op
(
const
std
::
string
&
name
,
const
std
::
string
&
data_type
,
const
std
::
string
&
op
,
std
::
string
&
kernel
)
void
CudaKernelBuilder
::
get_
binary_elementwise
_op
(
const
std
::
string
&
name
,
const
std
::
string
&
data_type
,
const
std
::
string
&
op
,
std
::
string
&
kernel
)
{
kernel
=
R"(
extern "C" __global__
...
...
@@ -60,10 +59,11 @@ out[tid] = in1[tid] )" + op +
return
;
}
void
CudaKernelBuilder
::
get_n_element_op
(
const
std
::
string
&
name
,
const
std
::
string
&
data_type
,
const
std
::
vector
<
std
::
string
>&
ops
,
std
::
string
&
kernel
)
void
CudaKernelBuilder
::
get_arbitrary_elementwise_op
(
const
std
::
string
&
name
,
const
std
::
string
&
data_type
,
const
std
::
vector
<
std
::
string
>&
ops
,
std
::
string
&
kernel
)
{
kernel
=
""
;
return
;
...
...
src/ngraph/runtime/gpu/gpu_cuda_kernel_builder.hpp
View file @
0a77e3d9
...
...
@@ -28,20 +28,20 @@ namespace ngraph
class
CudaKernelBuilder
{
public
:
static
void
get_
1_element
_op
(
const
std
::
string
&
name
,
const
std
::
string
&
data_type
,
const
std
::
string
&
op
,
std
::
string
&
kernel
);
static
void
get_
unary_elementwise
_op
(
const
std
::
string
&
name
,
const
std
::
string
&
data_type
,
const
std
::
string
&
op
,
std
::
string
&
kernel
);
static
void
get_
2_element
_op
(
const
std
::
string
&
name
,
const
std
::
string
&
data_type
,
const
std
::
string
&
op
,
std
::
string
&
kernel
);
static
void
get_
binary_elementwise
_op
(
const
std
::
string
&
name
,
const
std
::
string
&
data_type
,
const
std
::
string
&
op
,
std
::
string
&
kernel
);
static
void
get_
n_element
_op
(
const
std
::
string
&
name
,
const
std
::
string
&
data_type
,
const
std
::
vector
<
std
::
string
>&
ops
,
std
::
string
&
kernel
);
static
void
get_
arbitrary_elementwise
_op
(
const
std
::
string
&
name
,
const
std
::
string
&
data_type
,
const
std
::
vector
<
std
::
string
>&
ops
,
std
::
string
&
kernel
);
};
}
}
...
...
src/ngraph/runtime/gpu/gpu_cuda_kernel_emitters.cpp
View file @
0a77e3d9
...
...
@@ -17,10 +17,8 @@
#include <algorithm>
#include <map>
#include "ngraph/runtime/gpu/gpu_cuda_function_builder.hpp"
#include "ngraph/runtime/gpu/gpu_cuda_function_pool.hpp"
#include "ngraph/runtime/gpu/gpu_cuda_kernel_builder.hpp"
#include "ngraph/runtime/gpu/gpu_cuda_kernel_emitters.hpp"
#include "ngraph/runtime/gpu/gpu_cuda_kernel_ops.hpp"
namespace
ngraph
{
...
...
@@ -28,40 +26,6 @@ namespace ngraph
{
namespace
gpu
{
void
emit_abs
(
void
*
in
,
void
*
out
,
size_t
count
)
{
std
::
string
name
=
"abs"
;
// Create an instance of nvrtcProgram with the code string.
if
(
CudaFunctionPool
::
instance
().
get
(
name
)
==
nullptr
)
{
const
char
*
opts
[]
=
{
"--gpu-architecture=compute_35"
,
"--relocatable-device-code=true"
};
std
::
string
kernel
;
CudaKernelBuilder
::
get_1_element_op
(
name
,
"float"
,
"fabsf"
,
kernel
);
CudaFunctionPool
::
instance
().
set
(
name
,
CudaFunctionBuilder
::
get
(
"cuda_"
+
name
,
kernel
,
2
,
opts
));
}
//convert runtime ptr to driver api ptr
CUdeviceptr
d_ptr_in
,
d_ptr_out
;
d_ptr_in
=
(
CUdeviceptr
)
in
;
d_ptr_out
=
(
CUdeviceptr
)
out
;
void
*
args_list
[]
=
{
&
d_ptr_in
,
&
d_ptr_out
,
&
count
};
CUDA_SAFE_CALL
(
cuLaunchKernel
(
*
CudaFunctionPool
::
instance
().
get
(
name
).
get
(),
count
,
1
,
1
,
// grid dim
1
,
1
,
1
,
// block dim
0
,
NULL
,
// shared mem and stream
args_list
,
0
));
// arguments
CUDA_SAFE_CALL
(
cuCtxSynchronize
());
// Retrieve and print output.
}
void
emit_broadcast
(
void
*
in
,
void
*
out
,
size_t
repeat_size
,
size_t
repeat_times
,
size_t
count
)
{
...
...
src/ngraph/runtime/gpu/gpu_cuda_kernel_emitters.hpp
View file @
0a77e3d9
...
...
@@ -18,6 +18,9 @@
#include "ngraph/codegen/code_writer.hpp"
#include "ngraph/coordinate.hpp"
#include "ngraph/runtime/gpu/gpu_cuda_function_builder.hpp"
#include "ngraph/runtime/gpu/gpu_cuda_function_pool.hpp"
#include "ngraph/runtime/gpu/gpu_cuda_kernel_builder.hpp"
#include "ngraph/strides.hpp"
namespace
ngraph
...
...
@@ -26,9 +29,46 @@ namespace ngraph
{
namespace
gpu
{
void
emit_abs
(
void
*
in
,
void
*
out
,
size_t
count
);
template
<
typename
T
>
struct
CudaOpMap
;
void
emit_broadcast
(
void
*
in
,
void
*
out
,
size_t
repeat_size
,
size_t
repeat_times
,
size_t
count
);
template
<
typename
T
>
void
emit_unary_elementwise_op
(
void
*
in
,
void
*
out
,
size_t
count
,
std
::
string
name
)
{
// Create an instance of nvrtcProgram with the code string.
if
(
CudaFunctionPool
::
instance
().
get
(
name
)
==
nullptr
)
{
const
char
*
opts
[]
=
{
"--gpu-architecture=compute_35"
,
"--relocatable-device-code=true"
};
std
::
string
kernel
;
CudaKernelBuilder
::
get_unary_elementwise_op
(
name
,
"float"
,
CudaOpMap
<
T
>::
op
,
kernel
);
CudaFunctionPool
::
instance
().
set
(
name
,
CudaFunctionBuilder
::
get
(
"cuda_"
+
name
,
kernel
,
2
,
opts
));
}
//convert runtime ptr to driver api ptr
CUdeviceptr
d_ptr_in
,
d_ptr_out
;
d_ptr_in
=
(
CUdeviceptr
)
in
;
d_ptr_out
=
(
CUdeviceptr
)
out
;
void
*
args_list
[]
=
{
&
d_ptr_in
,
&
d_ptr_out
,
&
count
};
CUDA_SAFE_CALL
(
cuLaunchKernel
(
*
CudaFunctionPool
::
instance
().
get
(
name
).
get
(),
count
,
1
,
1
,
// grid dim
1
,
1
,
1
,
// block dim
0
,
NULL
,
// shared mem and stream
args_list
,
0
));
// arguments
CUDA_SAFE_CALL
(
cuCtxSynchronize
());
// Retrieve and print output.
}
}
}
}
src/ngraph/runtime/gpu/gpu_cuda_kernel_ops.hpp
0 → 100644
View file @
0a77e3d9
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
namespace
ngraph
{
namespace
op
{
class
Abs
;
class
Acos
;
class
Asin
;
class
Atan
;
class
Ceiling
;
class
Cos
;
class
Cosh
;
class
Exp
;
class
Floor
;
class
Log
;
class
Sin
;
class
Sinh
;
class
Tan
;
class
Tanh
;
// Unimplemented or unused in favor of cuDNN impl.
class
Max
;
class
Min
;
class
Negative
;
class
Not
;
class
Sign
;
class
Sqrt
;
}
namespace
runtime
{
namespace
gpu
{
template
<>
struct
CudaOpMap
<
ngraph
::
op
::
Abs
>
{
static
constexpr
const
char
*
op
=
"fabsf"
;
};
template
<>
struct
CudaOpMap
<
ngraph
::
op
::
Acos
>
{
static
constexpr
const
char
*
op
=
"acosf"
;
};
template
<>
struct
CudaOpMap
<
ngraph
::
op
::
Asin
>
{
static
constexpr
const
char
*
op
=
"asinf"
;
};
template
<>
struct
CudaOpMap
<
ngraph
::
op
::
Atan
>
{
static
constexpr
const
char
*
op
=
"atanf"
;
};
template
<>
struct
CudaOpMap
<
ngraph
::
op
::
Ceiling
>
{
static
constexpr
const
char
*
op
=
"ceilf"
;
};
template
<>
struct
CudaOpMap
<
ngraph
::
op
::
Cos
>
{
static
constexpr
const
char
*
op
=
"cosf"
;
};
template
<>
struct
CudaOpMap
<
ngraph
::
op
::
Cosh
>
{
static
constexpr
const
char
*
op
=
"coshf"
;
};
template
<>
struct
CudaOpMap
<
ngraph
::
op
::
Exp
>
{
static
constexpr
const
char
*
op
=
"expf"
;
};
template
<>
struct
CudaOpMap
<
ngraph
::
op
::
Floor
>
{
static
constexpr
const
char
*
op
=
"floorf"
;
};
template
<>
struct
CudaOpMap
<
ngraph
::
op
::
Log
>
{
static
constexpr
const
char
*
op
=
"logf"
;
};
template
<>
struct
CudaOpMap
<
ngraph
::
op
::
Max
>
{
static
constexpr
const
char
*
op
=
"fmaxf"
;
};
template
<>
struct
CudaOpMap
<
ngraph
::
op
::
Min
>
{
static
constexpr
const
char
*
op
=
"fminf"
;
};
template
<>
struct
CudaOpMap
<
ngraph
::
op
::
Sin
>
{
static
constexpr
const
char
*
op
=
"sinf"
;
};
template
<>
struct
CudaOpMap
<
ngraph
::
op
::
Sinh
>
{
static
constexpr
const
char
*
op
=
"sinhf"
;
};
template
<>
struct
CudaOpMap
<
ngraph
::
op
::
Sqrt
>
{
static
constexpr
const
char
*
op
=
"sqrtf"
;
};
template
<>
struct
CudaOpMap
<
ngraph
::
op
::
Tan
>
{
static
constexpr
const
char
*
op
=
"tanf"
;
};
template
<>
struct
CudaOpMap
<
ngraph
::
op
::
Tanh
>
{
static
constexpr
const
char
*
op
=
"tanhf"
;
};
}
}
}
src/ngraph/runtime/gpu/gpu_emitter.cpp
View file @
0a77e3d9
...
...
@@ -121,6 +121,27 @@ namespace ngraph
writer
<<
"}
\n
"
;
}
void
GPU_Emitter
::
EmitUnaryElementwise
(
GPU_ExternalFunction
*
external_function
,
codegen
::
CodeWriter
&
writer
,
const
ngraph
::
Node
*
node
,
const
std
::
vector
<
GPU_TensorViewWrapper
>&
args
,
const
std
::
vector
<
GPU_TensorViewWrapper
>&
out
)
{
if
(
out
[
0
].
get_size
()
==
0
)
{
return
;
}
writer
<<
"{ // "
<<
node
->
get_name
()
<<
"
\n
"
;
writer
.
indent
++
;
writer
<<
"int count = "
<<
out
[
0
].
get_size
()
<<
";
\n
"
;
writer
<<
"if(count == 0) return;
\n
"
;
writer
<<
"ngraph::runtime::gpu::emit_unary_elementwise_op<ngraph::op::"
<<
node
->
description
()
<<
">((void*) "
<<
args
[
0
].
get_name
()
<<
", (void*) "
<<
out
[
0
].
get_name
()
<<
", count,
\"
"
<<
node
->
description
()
<<
"
\"
);
\n
"
;
writer
.
indent
--
;
writer
<<
"}
\n
"
;
}
template
<>
void
GPU_Emitter
::
EMITTER_DECL
(
ngraph
::
op
::
Add
)
{
...
...
src/ngraph/runtime/gpu/gpu_emitter.hpp
View file @
0a77e3d9
...
...
@@ -58,13 +58,13 @@ namespace ngraph
{
}
private
:
static
std
::
string
emit_vector
(
const
GPU_TensorViewWrapper
&
,
const
std
::
string
&
name
=
""
);
static
std
::
string
emit_array1d
(
const
GPU_TensorViewWrapper
&
,
const
std
::
string
&
name
=
""
);
static
std
::
string
emit_matrix
(
const
GPU_TensorViewWrapper
&
,
const
std
::
string
&
name
=
""
);
static
void
EmitUnaryElementwise
(
GPU_ExternalFunction
*
external_function
,
codegen
::
CodeWriter
&
writer
,
const
ngraph
::
Node
*
node
,
const
std
::
vector
<
GPU_TensorViewWrapper
>&
args
,
const
std
::
vector
<
GPU_TensorViewWrapper
>&
out
)
{
}
};
}
}
...
...
src/ngraph/runtime/gpu/gpu_external_function.cpp
View file @
0a77e3d9
...
...
@@ -160,6 +160,7 @@ static StaticInitializers s_static_initializers;
#define TI(x) type_index(typeid(x))
<<<<<<<
HEAD
static
const
ngraph
::
runtime
::
gpu
::
OpMap
dispatcher
{
{
TI
(
ngraph
::
op
::
Add
),
&
ngraph
::
runtime
::
gpu
::
GPU_Emitter
::
emit
<
ngraph
::
op
::
Add
>
},
{
TI
(
ngraph
::
op
::
Dot
),
&
ngraph
::
runtime
::
gpu
::
GPU_Emitter
::
emit
<
ngraph
::
op
::
Dot
>
},
...
...
@@ -239,6 +240,60 @@ static const ngraph::runtime::gpu::OpMap dispatcher{
{
TI
(
ngraph
::
op
::
ReluBackprop
),
&
ngraph
::
runtime
::
gpu
::
GPU_Emitter
::
emit
<
ngraph
::
op
::
ReluBackprop
>
},
{
TI
(
ngraph
::
op
::
Softmax
),
&
ngraph
::
runtime
::
gpu
::
GPU_Emitter
::
emit
<
ngraph
::
op
::
Softmax
>
},
=======
static
const
runtime
::
gpu
::
OpMap
dispatcher
{
{
TI
(
ngraph
::
op
::
Add
),
&
runtime
::
gpu
::
GPU_Emitter
::
EmitAdd
},
{
TI
(
ngraph
::
op
::
Dot
),
&
runtime
::
gpu
::
GPU_Emitter
::
EmitDot
},
{
TI
(
ngraph
::
op
::
Multiply
),
&
runtime
::
gpu
::
GPU_Emitter
::
EmitMultiply
},
{
TI
(
ngraph
::
op
::
Parameter
),
&
runtime
::
gpu
::
GPU_Emitter
::
EmitNop
},
{
TI
(
ngraph
::
op
::
Abs
),
&
runtime
::
gpu
::
GPU_Emitter
::
EmitUnaryElementwise
},
{
TI
(
ngraph
::
op
::
Concat
),
&
runtime
::
gpu
::
GPU_Emitter
::
EmitConcat
},
{
TI
(
ngraph
::
op
::
Divide
),
&
runtime
::
gpu
::
GPU_Emitter
::
EmitDivide
},
{
TI
(
ngraph
::
op
::
Equal
),
&
runtime
::
gpu
::
GPU_Emitter
::
EmitEqual
},
{
TI
(
ngraph
::
op
::
Greater
),
&
runtime
::
gpu
::
GPU_Emitter
::
EmitGreater
},
{
TI
(
ngraph
::
op
::
GreaterEq
),
&
runtime
::
gpu
::
GPU_Emitter
::
EmitGreaterEq
},
{
TI
(
ngraph
::
op
::
Less
),
&
runtime
::
gpu
::
GPU_Emitter
::
EmitLess
},
{
TI
(
ngraph
::
op
::
LessEq
),
&
runtime
::
gpu
::
GPU_Emitter
::
EmitLessEq
},
{
TI
(
ngraph
::
op
::
Log
),
&
runtime
::
gpu
::
GPU_Emitter
::
EmitUnaryElementwise
},
{
TI
(
ngraph
::
op
::
Maximum
),
&
runtime
::
gpu
::
GPU_Emitter
::
EmitMaximum
},
{
TI
(
ngraph
::
op
::
Minimum
),
&
runtime
::
gpu
::
GPU_Emitter
::
EmitMinimum
},
{
TI
(
ngraph
::
op
::
Negative
),
&
runtime
::
gpu
::
GPU_Emitter
::
EmitNegative
},
{
TI
(
ngraph
::
op
::
NotEqual
),
&
runtime
::
gpu
::
GPU_Emitter
::
EmitNotEqual
},
{
TI
(
ngraph
::
op
::
Power
),
&
runtime
::
gpu
::
GPU_Emitter
::
EmitPower
},
{
TI
(
ngraph
::
op
::
Select
),
&
runtime
::
gpu
::
GPU_Emitter
::
EmitSelect
},
{
TI
(
ngraph
::
op
::
Subtract
),
&
runtime
::
gpu
::
GPU_Emitter
::
EmitSubtract
},
{
TI
(
ngraph
::
op
::
Broadcast
),
&
runtime
::
gpu
::
GPU_Emitter
::
EmitBroadcast
},
{
TI
(
ngraph
::
op
::
Convert
),
&
runtime
::
gpu
::
GPU_Emitter
::
EmitConvert
},
{
TI
(
ngraph
::
op
::
Constant
),
&
runtime
::
gpu
::
GPU_Emitter
::
EmitConstant
},
{
TI
(
ngraph
::
op
::
Reshape
),
&
runtime
::
gpu
::
GPU_Emitter
::
EmitReshape
},
{
TI
(
ngraph
::
op
::
FunctionCall
),
&
runtime
::
gpu
::
GPU_Emitter
::
EmitFunctionCall
},
{
TI
(
ngraph
::
op
::
Reduce
),
&
runtime
::
gpu
::
GPU_Emitter
::
EmitReduce
},
{
TI
(
ngraph
::
op
::
Sign
),
&
runtime
::
gpu
::
GPU_Emitter
::
EmitUnaryElementwise
},
{
TI
(
ngraph
::
op
::
Slice
),
&
runtime
::
gpu
::
GPU_Emitter
::
EmitSlice
},
{
TI
(
ngraph
::
op
::
Sum
),
&
runtime
::
gpu
::
GPU_Emitter
::
EmitSum
},
{
TI
(
ngraph
::
op
::
Exp
),
&
runtime
::
gpu
::
GPU_Emitter
::
EmitUnaryElementwise
},
{
TI
(
ngraph
::
op
::
Sin
),
&
runtime
::
gpu
::
GPU_Emitter
::
EmitUnaryElementwise
},
{
TI
(
ngraph
::
op
::
Sinh
),
&
runtime
::
gpu
::
GPU_Emitter
::
EmitUnaryElementwise
},
{
TI
(
ngraph
::
op
::
Cos
),
&
runtime
::
gpu
::
GPU_Emitter
::
EmitUnaryElementwise
},
{
TI
(
ngraph
::
op
::
Cosh
),
&
runtime
::
gpu
::
GPU_Emitter
::
EmitUnaryElementwise
},
{
TI
(
ngraph
::
op
::
Tan
),
&
runtime
::
gpu
::
GPU_Emitter
::
EmitUnaryElementwise
},
{
TI
(
ngraph
::
op
::
Tanh
),
&
runtime
::
gpu
::
GPU_Emitter
::
EmitUnaryElementwise
},
{
TI
(
ngraph
::
op
::
Asin
),
&
runtime
::
gpu
::
GPU_Emitter
::
EmitUnaryElementwise
},
{
TI
(
ngraph
::
op
::
Acos
),
&
runtime
::
gpu
::
GPU_Emitter
::
EmitUnaryElementwise
},
{
TI
(
ngraph
::
op
::
Atan
),
&
runtime
::
gpu
::
GPU_Emitter
::
EmitUnaryElementwise
},
{
TI
(
ngraph
::
op
::
ReplaceSlice
),
&
runtime
::
gpu
::
GPU_Emitter
::
EmitReplaceSlice
},
{
TI
(
ngraph
::
op
::
OneHot
),
&
runtime
::
gpu
::
GPU_Emitter
::
EmitOneHot
},
{
TI
(
ngraph
::
op
::
Floor
),
&
runtime
::
gpu
::
GPU_Emitter
::
EmitUnaryElementwise
},
{
TI
(
ngraph
::
op
::
Ceiling
),
&
runtime
::
gpu
::
GPU_Emitter
::
EmitUnaryElementwise
},
{
TI
(
ngraph
::
op
::
Sqrt
),
&
runtime
::
gpu
::
GPU_Emitter
::
EmitSqrt
},
{
TI
(
ngraph
::
op
::
Convolution
),
&
runtime
::
gpu
::
GPU_Emitter
::
EmitConvolution
},
{
TI
(
ngraph
::
op
::
Not
),
&
runtime
::
gpu
::
GPU_Emitter
::
EmitUnaryElementwise
},
{
TI
(
ngraph
::
op
::
MaxPool
),
&
runtime
::
gpu
::
GPU_Emitter
::
EmitMaxPool
},
{
TI
(
ngraph
::
op
::
Reverse
),
&
runtime
::
gpu
::
GPU_Emitter
::
EmitReverse
},
{
TI
(
ngraph
::
op
::
ReduceWindow
),
&
runtime
::
gpu
::
GPU_Emitter
::
EmitReduceWindow
},
{
TI
(
ngraph
::
op
::
SelectAndScatter
),
&
runtime
::
gpu
::
GPU_Emitter
::
EmitSelectAndScatter
},
{
TI
(
ngraph
::
op
::
Result
),
&
runtime
::
gpu
::
GPU_Emitter
::
EmitResult
},
>>>>>>>
origin
/
master
};
ngraph
::
runtime
::
gpu
::
GPU_ExternalFunction
::
GPU_ExternalFunction
(
...
...
@@ -292,6 +347,7 @@ void ngraph::runtime::gpu::GPU_ExternalFunction::compile()
#include "ngraph/pass/memory_layout.hpp"
#include "ngraph/runtime/aligned_buffer.hpp"
#include "ngraph/runtime/gpu/gpu_cuda_kernel_emitters.hpp"
#include "ngraph/runtime/gpu/gpu_cuda_kernel_ops.hpp"
#include "ngraph/runtime/gpu/gpu_util.hpp"
#include "ngraph/util.hpp"
)"
;
...
...
test/backend_test.in.cpp
View file @
0a77e3d9
...
...
@@ -336,7 +336,6 @@ TEST(${BACKEND_NAME}, abs)
TEST
(
$
{
BACKEND_NAME
},
ceiling
)
{
SKIP_TEST_FOR
(
"GPU"
,
"${BACKEND_NAME}"
);
Shape
shape
{
2
,
2
};
auto
A
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
shape
);
auto
f
=
make_shared
<
Function
>
(
make_shared
<
op
::
Ceiling
>
(
A
),
op
::
ParameterVector
{
A
});
...
...
@@ -777,7 +776,6 @@ TEST(${BACKEND_NAME}, equal)
TEST
(
$
{
BACKEND_NAME
},
floor
)
{
SKIP_TEST_FOR
(
"GPU"
,
"${BACKEND_NAME}"
);
Shape
shape
{
2
,
2
};
auto
A
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
shape
);
auto
f
=
make_shared
<
Function
>
(
make_shared
<
op
::
Floor
>
(
A
),
op
::
ParameterVector
{
A
});
...
...
@@ -1371,7 +1369,6 @@ TEST(${BACKEND_NAME}, lesseq_bool)
TEST
(
$
{
BACKEND_NAME
},
log
)
{
SKIP_TEST_FOR
(
"GPU"
,
"${BACKEND_NAME}"
);
Shape
shape
{
2
,
2
,
2
};
auto
A
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
shape
);
auto
f
=
make_shared
<
Function
>
(
make_shared
<
op
::
Log
>
(
A
),
op
::
ParameterVector
{
A
});
...
...
@@ -2674,7 +2671,6 @@ TEST(${BACKEND_NAME}, reshape_6d)
TEST
(
$
{
BACKEND_NAME
},
sin
)
{
SKIP_TEST_FOR
(
"GPU"
,
"${BACKEND_NAME}"
);
Shape
shape
{
6
};
auto
A
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
shape
);
auto
f
=
make_shared
<
Function
>
(
make_shared
<
op
::
Sin
>
(
A
),
op
::
ParameterVector
{
A
});
...
...
@@ -2700,7 +2696,6 @@ TEST(${BACKEND_NAME}, sin)
TEST
(
$
{
BACKEND_NAME
},
cos
)
{
SKIP_TEST_FOR
(
"GPU"
,
"${BACKEND_NAME}"
);
Shape
shape
{
6
};
auto
A
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
shape
);
auto
f
=
make_shared
<
Function
>
(
make_shared
<
op
::
Cos
>
(
A
),
op
::
ParameterVector
{
A
});
...
...
@@ -2726,7 +2721,6 @@ TEST(${BACKEND_NAME}, cos)
TEST
(
$
{
BACKEND_NAME
},
tan
)
{
SKIP_TEST_FOR
(
"GPU"
,
"${BACKEND_NAME}"
);
Shape
shape
{
6
};
auto
A
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
shape
);
auto
f
=
make_shared
<
Function
>
(
make_shared
<
op
::
Tan
>
(
A
),
op
::
ParameterVector
{
A
});
...
...
@@ -2747,12 +2741,11 @@ TEST(${BACKEND_NAME}, tan)
input
.
begin
(),
input
.
end
(),
input
.
begin
(),
[](
float
x
)
->
float
{
return
tanf
(
x
);
});
cf
->
call
({
a
},
{
result
});
EXPECT_
EQ
(
input
,
read_vector
<
float
>
(
result
));
EXPECT_
TRUE
(
test
::
all_close
(
input
,
read_vector
<
float
>
(
result
)
));
}
TEST
(
$
{
BACKEND_NAME
},
asin
)
{
SKIP_TEST_FOR
(
"GPU"
,
"${BACKEND_NAME}"
);
Shape
shape
{
6
};
auto
A
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
shape
);
auto
f
=
make_shared
<
Function
>
(
make_shared
<
op
::
Asin
>
(
A
),
op
::
ParameterVector
{
A
});
...
...
@@ -2777,7 +2770,6 @@ TEST(${BACKEND_NAME}, asin)
TEST
(
$
{
BACKEND_NAME
},
acos
)
{
SKIP_TEST_FOR
(
"GPU"
,
"${BACKEND_NAME}"
);
Shape
shape
{
6
};
auto
A
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
shape
);
auto
f
=
make_shared
<
Function
>
(
make_shared
<
op
::
Acos
>
(
A
),
op
::
ParameterVector
{
A
});
...
...
@@ -2802,7 +2794,6 @@ TEST(${BACKEND_NAME}, acos)
TEST
(
$
{
BACKEND_NAME
},
atan
)
{
SKIP_TEST_FOR
(
"GPU"
,
"${BACKEND_NAME}"
);
Shape
shape
{
6
};
auto
A
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
shape
);
auto
f
=
make_shared
<
Function
>
(
make_shared
<
op
::
Atan
>
(
A
),
op
::
ParameterVector
{
A
});
...
...
@@ -2827,7 +2818,6 @@ TEST(${BACKEND_NAME}, atan)
TEST
(
$
{
BACKEND_NAME
},
sinh
)
{
SKIP_TEST_FOR
(
"GPU"
,
"${BACKEND_NAME}"
);
Shape
shape
{
6
};
auto
A
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
shape
);
auto
f
=
make_shared
<
Function
>
(
make_shared
<
op
::
Sinh
>
(
A
),
op
::
ParameterVector
{
A
});
...
...
@@ -2852,7 +2842,6 @@ TEST(${BACKEND_NAME}, sinh)
TEST
(
$
{
BACKEND_NAME
},
cosh
)
{
SKIP_TEST_FOR
(
"GPU"
,
"${BACKEND_NAME}"
);
Shape
shape
{
6
};
auto
A
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
shape
);
auto
f
=
make_shared
<
Function
>
(
make_shared
<
op
::
Cosh
>
(
A
),
op
::
ParameterVector
{
A
});
...
...
@@ -2877,7 +2866,6 @@ TEST(${BACKEND_NAME}, cosh)
TEST
(
$
{
BACKEND_NAME
},
tanh
)
{
SKIP_TEST_FOR
(
"GPU"
,
"${BACKEND_NAME}"
);
Shape
shape
{
6
};
auto
A
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
shape
);
auto
f
=
make_shared
<
Function
>
(
make_shared
<
op
::
Tanh
>
(
A
),
op
::
ParameterVector
{
A
});
...
...
@@ -2897,12 +2885,11 @@ TEST(${BACKEND_NAME}, tanh)
input
.
begin
(),
input
.
end
(),
input
.
begin
(),
[](
float
x
)
->
float
{
return
tanhf
(
x
);
});
cf
->
call
({
a
},
{
result
});
EXPECT_
EQ
(
input
,
read_vector
<
float
>
(
result
));
EXPECT_
TRUE
(
test
::
all_close
(
input
,
read_vector
<
float
>
(
result
)
));
}
TEST
(
$
{
BACKEND_NAME
},
exp
)
{
SKIP_TEST_FOR
(
"GPU"
,
"${BACKEND_NAME}"
);
Shape
shape
{
8
};
auto
A
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
shape
);
auto
f
=
make_shared
<
Function
>
(
make_shared
<
op
::
Exp
>
(
A
),
op
::
ParameterVector
{
A
});
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment