Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
3d53e58a
Commit
3d53e58a
authored
Mar 13, 2018
by
fenglei.tian
Browse files
Options
Browse Files
Download
Plain Diff
merge and resolve conflict with origin master
parents
39dc384d
b5467550
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
407 additions
and
194 deletions
+407
-194
cpu_emitter.cpp
src/ngraph/runtime/cpu/cpu_emitter.cpp
+39
-0
cpu_external_function.cpp
src/ngraph/runtime/cpu/cpu_external_function.cpp
+1
-0
mkldnn_emitter.cpp
src/ngraph/runtime/cpu/mkldnn_emitter.cpp
+26
-0
mkldnn_emitter.hpp
src/ngraph/runtime/cpu/mkldnn_emitter.hpp
+4
-0
sigmoid.cpp
src/ngraph/runtime/cpu/ops/sigmoid.cpp
+21
-0
sigmoid.hpp
src/ngraph/runtime/cpu/ops/sigmoid.hpp
+24
-0
cpu_assignment.cpp
src/ngraph/runtime/cpu/pass/cpu_assignment.cpp
+15
-0
cpu_fusion.cpp
src/ngraph/runtime/cpu/pass/cpu_fusion.cpp
+51
-0
cpu_fusion.hpp
src/ngraph/runtime/cpu/pass/cpu_fusion.hpp
+2
-0
cpu_layout.cpp
src/ngraph/runtime/cpu/pass/cpu_layout.cpp
+25
-0
gpu_cuda_context_manager.cpp
src/ngraph/runtime/gpu/gpu_cuda_context_manager.cpp
+12
-19
gpu_cuda_function_builder.cpp
src/ngraph/runtime/gpu/gpu_cuda_function_builder.cpp
+29
-36
gpu_cuda_function_pool.cpp
src/ngraph/runtime/gpu/gpu_cuda_function_pool.cpp
+23
-32
gpu_cuda_kernel_builder.cpp
src/ngraph/runtime/gpu/gpu_cuda_kernel_builder.cpp
+55
-62
gpu_cuda_kernel_emitters.cpp
src/ngraph/runtime/gpu/gpu_cuda_kernel_emitters.cpp
+31
-38
gpu_emitter.cpp
src/ngraph/runtime/gpu/gpu_emitter.cpp
+1
-5
gpu_external_function.cpp
src/ngraph/runtime/gpu/gpu_external_function.cpp
+0
-0
gpu_tensor_view.cpp
src/ngraph/runtime/gpu/gpu_tensor_view.cpp
+1
-1
gpu_util.cpp
src/ngraph/runtime/gpu/gpu_util.cpp
+1
-1
cpu_fusion.cpp
test/cpu_fusion.cpp
+46
-0
No files found.
src/ngraph/runtime/cpu/cpu_emitter.cpp
View file @
3d53e58a
...
...
@@ -3279,6 +3279,45 @@ namespace ngraph
<<
to_string
(
sigmoid_index
)
<<
");
\n
"
;
}
/// \brief Emits the generated-code sequence for ngraph::op::SigmoidBackprop.
///
/// args[0] is the forward input, args[1] is the delta and out[0] receives the result.
/// Each tensor is described to MKLDNN as a flat 1-D buffer (memory::format::x) whose
/// length is the tensor's total element count. The sigmoid-backward primitive is built
/// here; the emitted code only binds the three buffers and invokes the primitive.
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::SigmoidBackprop)
{
    auto& mkldnn_emitter = external_function->get_mkldnn_emitter();

    // Total element counts, narrowed to int before being handed to memory::desc.
    const int input_1d_size = static_cast<int>(shape_size(args[0].get_shape()));
    const int delta_1d_size = static_cast<int>(shape_size(args[1].get_shape()));
    const int result_1d_size = static_cast<int>(shape_size(out[0].get_shape()));

    auto input_desc =
        mkldnn::memory::desc({input_1d_size},
                             mkldnn_utils::get_mkldnn_data_type(args[0].get_element_type()),
                             mkldnn::memory::format::x);
    auto delta_desc =
        mkldnn::memory::desc({delta_1d_size},
                             mkldnn_utils::get_mkldnn_data_type(args[1].get_element_type()),
                             mkldnn::memory::format::x);
    auto result_desc =
        mkldnn::memory::desc({result_1d_size},
                             mkldnn_utils::get_mkldnn_data_type(out[0].get_element_type()),
                             mkldnn::memory::format::x);

    size_t sigmoid_index =
        mkldnn_emitter->build_sigmoid_backward(input_desc, delta_desc, result_desc);
    auto& deps = mkldnn_emitter->get_primitive_deps(sigmoid_index);

    // Bind the runtime buffers to the primitive's memory dependencies:
    // deps[0] <- forward input, deps[1] <- delta, deps[2] <- result.
    writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[0]) << ", "
           << args[0].get_name() << ");\n";
    writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[1]) << ", "
           << args[1].get_name() << ");\n";
    writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[2]) << ", "
           << out[0].get_name() << ");\n";

    // Finally run the primitive itself.
    writer << "cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, " << to_string(sigmoid_index)
           << ");\n";
}
template
<>
void
CPU_Emitter
::
EMITTER_DECL
(
ngraph
::
op
::
Softmax
)
{
...
...
src/ngraph/runtime/cpu/cpu_external_function.cpp
View file @
3d53e58a
...
...
@@ -252,6 +252,7 @@ static const runtime::cpu::OpMap dispatcher{
{
TI
(
ngraph
::
op
::
ReluBackprop
),
&
runtime
::
cpu
::
CPU_Emitter
::
emit
<
op
::
ReluBackprop
>
},
{
TI
(
ngraph
::
op
::
Sigmoid
),
&
runtime
::
cpu
::
CPU_Emitter
::
emit
<
op
::
Sigmoid
>
},
{
TI
(
ngraph
::
op
::
Softmax
),
&
runtime
::
cpu
::
CPU_Emitter
::
emit
<
op
::
Softmax
>
},
{
TI
(
ngraph
::
op
::
SigmoidBackprop
),
&
runtime
::
cpu
::
CPU_Emitter
::
emit
<
op
::
SigmoidBackprop
>
},
};
runtime
::
cpu
::
CPU_ExternalFunction
::
CPU_ExternalFunction
(
...
...
src/ngraph/runtime/cpu/mkldnn_emitter.cpp
View file @
3d53e58a
...
...
@@ -513,6 +513,32 @@ size_t MKLDNNEmitter::build_sigmoid_forward(const mkldnn::memory::desc& input_de
return
primitive_index
;
}
/// \brief Builds the MKLDNN eltwise (logistic) backward primitive used for sigmoid backprop.
///
/// \param input_desc  Memory descriptor of the forward input.
/// \param delta_desc  Memory descriptor of the delta.
/// \param result_desc Memory descriptor of the result.
/// \return Index of the newly inserted backward primitive. Its dependency list is
///         recorded as {input, delta, result}, in that order.
size_t MKLDNNEmitter::build_sigmoid_backward(const mkldnn::memory::desc& input_desc,
                                             const mkldnn::memory::desc& delta_desc,
                                             const mkldnn::memory::desc& result_desc)
{
    // One memory primitive per tensor; the indices are returned to callers through
    // m_primitive_deps below.
    const size_t input_index = build_memory_primitive(input_desc);
    const size_t delta_index = build_memory_primitive(delta_desc);
    const size_t result_index = build_memory_primitive(result_desc);

    // The backward primitive descriptor is constructed from a forward primitive
    // descriptor, so build the matching logistic forward descriptor first.
    mkldnn::eltwise_forward::primitive_desc sigmoid_fwd_pd(
        {mkldnn::prop_kind::forward, mkldnn::algorithm::eltwise_logistic, input_desc, 0, 0},
        mkldnn_utils::global_cpu_engine);

    // Backward descriptor: logistic algorithm over (delta_desc, input_desc).
    mkldnn::eltwise_backward::primitive_desc sigmoid_bwd_pd{
        {mkldnn::algorithm::eltwise_logistic, delta_desc, input_desc, 0, 0},
        mkldnn_utils::global_cpu_engine,
        sigmoid_fwd_pd};

    const size_t primitive_index =
        insert_primitive(new mkldnn::eltwise_backward(sigmoid_bwd_pd,
                                                      *m_mkldnn_primitives[input_index],
                                                      *m_mkldnn_primitives[delta_index],
                                                      *m_mkldnn_primitives[result_index]));

    m_primitive_deps[primitive_index] = {input_index, delta_index, result_index};
    return primitive_index;
}
size_t
MKLDNNEmitter
::
build_elementwise_add
(
const
mkldnn
::
memory
::
desc
&
input0_data_desc
,
const
mkldnn
::
memory
::
desc
&
input1_data_desc
,
...
...
src/ngraph/runtime/cpu/mkldnn_emitter.hpp
View file @
3d53e58a
...
...
@@ -153,6 +153,10 @@ namespace ngraph
size_t
build_sigmoid_forward
(
const
mkldnn
::
memory
::
desc
&
input_desc
,
const
mkldnn
::
memory
::
desc
&
result_desc
);
size_t
build_sigmoid_backward
(
const
mkldnn
::
memory
::
desc
&
input_desc
,
const
mkldnn
::
memory
::
desc
&
delta_desc
,
const
mkldnn
::
memory
::
desc
&
result_desc
);
size_t
build_elementwise_add
(
const
mkldnn
::
memory
::
desc
&
input0_data_desc
,
const
mkldnn
::
memory
::
desc
&
input1_data_desc
,
...
...
src/ngraph/runtime/cpu/ops/sigmoid.cpp
View file @
3d53e58a
...
...
@@ -35,3 +35,24 @@ ngraph::op::Sigmoid::Sigmoid(std::shared_ptr<ngraph::Node> input)
{
add_output
(
input
->
get_element_type
(),
m_shape_input
);
}
/// \brief Constructs a SigmoidBackprop node from the forward input and the delta.
///
/// The output takes delta's element type and shape.
///
/// \param arg   Node producing the Sigmoid forward input tensor.
/// \param delta Node producing the delta tensor.
/// \throws ngraph_error if arg and delta differ in element type, or in shape
///         (element type is checked first).
ngraph::op::SigmoidBackprop::SigmoidBackprop(std::shared_ptr<Node> arg,
                                             std::shared_ptr<Node> delta)
    : RequiresTensorViewArgs("SigmoidBackprop", {arg, delta})
{
    const bool element_type_mismatch = arg->get_element_type() != delta->get_element_type();
    if (element_type_mismatch)
    {
        throw ngraph_error("Argument and delta element types for Sigmoid backprop do not match");
    }

    const bool shape_mismatch = arg->get_shape() != delta->get_shape();
    if (shape_mismatch)
    {
        throw ngraph_error("Argument and delta shape for Sigmoid backprop do not match");
    }

    set_value_type_checked(delta->get_element_type(), delta->get_shape());
}
/// \brief Contributes Sigmoid's gradient to the adjoints of its input.
///
/// Wraps the forward input and the incoming delta in a SigmoidBackprop node and
/// registers that node as the delta for the forward input.
void ngraph::op::Sigmoid::generate_adjoints(ngraph::autodiff::Adjoints& adjoints,
                                            const std::shared_ptr<Node>& delta)
{
    auto forward_input = get_input_op(0);
    adjoints.add_delta(forward_input,
                       std::make_shared<op::SigmoidBackprop>(forward_input, delta));
}
src/ngraph/runtime/cpu/ops/sigmoid.hpp
View file @
3d53e58a
...
...
@@ -17,6 +17,7 @@
#pragma once
#include "ngraph/ops/util/requires_tensor_view_args.hpp"
#include "ngraph/util.hpp"
namespace
ngraph
{
...
...
@@ -29,9 +30,32 @@ namespace ngraph
Shape
get_input_shape
()
const
{
return
m_shape_input
;
}
virtual
std
::
shared_ptr
<
Node
>
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
virtual
void
generate_adjoints
(
autodiff
::
Adjoints
&
adjoints
,
const
std
::
shared_ptr
<
Node
>&
delta
)
override
;
private
:
Shape
m_shape_input
;
};
/// \brief Elementwise SigmoidBackprop operation.
///
/// Takes two arguments: the Sigmoid forward input and the delta.
class SigmoidBackprop : public util::RequiresTensorViewArgs
{
public:
    /// \brief Constructs a SigmoidBackprop operation.
    ///
    /// \param arg Node that produces the Sigmoid forward input tensor.
    /// \param delta Node that produces the delta tensor.
    SigmoidBackprop(std::shared_ptr<ngraph::Node> arg, std::shared_ptr<ngraph::Node> delta);

    /// \brief Creates a new SigmoidBackprop node over a replacement argument pair.
    ///
    /// \param new_args Replacement arguments; must contain exactly two nodes
    ///                 (forward input first, delta second).
    /// \throws ngraph_error if new_args does not hold exactly two nodes.
    virtual std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override
    {
        if (new_args.size() == 2)
        {
            return std::make_shared<SigmoidBackprop>(new_args.at(0), new_args.at(1));
        }
        throw ngraph_error("Incorrect number of new arguments");
    }
};
}
}
src/ngraph/runtime/cpu/pass/cpu_assignment.cpp
View file @
3d53e58a
...
...
@@ -316,6 +316,19 @@ namespace ngraph
}
}
/// \brief Marks a SigmoidBackprop node as an MKLDNN op when its first input is f32.
///
/// Nodes whose first input has any other element type are left without annotations.
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::SigmoidBackprop)
{
    // Only f32 inputs are assigned to the MKLDNN kernel.
    if (node->get_input_element_type(0) != element::f32)
    {
        return;
    }

    auto mkldnn_annotations = std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
    mkldnn_annotations->set_mkldnn_op(true);
    static_cast<op::SigmoidBackprop*>(node)->set_op_annotations(mkldnn_annotations);
}
template
<>
void
CPUAssignment
::
ASSIGN_DECL
(
ngraph
::
op
::
ReluBackprop
)
{
...
...
@@ -386,6 +399,8 @@ static const runtime::cpu::pass::AssignOpMap s_dispatcher{
{
TI
(
ngraph
::
op
::
ReluBackprop
),
&
runtime
::
cpu
::
pass
::
CPUAssignment
::
assign
<
ngraph
::
op
::
ReluBackprop
>
},
{
TI
(
ngraph
::
op
::
Sigmoid
),
&
runtime
::
cpu
::
pass
::
CPUAssignment
::
assign
<
ngraph
::
op
::
Sigmoid
>
},
{
TI
(
ngraph
::
op
::
SigmoidBackprop
),
&
runtime
::
cpu
::
pass
::
CPUAssignment
::
assign
<
ngraph
::
op
::
SigmoidBackprop
>
},
};
bool
runtime
::
cpu
::
pass
::
CPUAssignment
::
run_on_call_graph
(
...
...
src/ngraph/runtime/cpu/pass/cpu_fusion.cpp
View file @
3d53e58a
...
...
@@ -568,6 +568,57 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_sigmoid()
this
->
add_matcher
(
m
);
}
/// \brief Registers a matcher that replaces the expanded sigmoid-backprop subgraph with a
///        single op::SigmoidBackprop node.
///
/// The pattern, written from the root down, is:
///
///     Negative(
///         Multiply(
///             Divide(Multiply(sigmoid_fwd, Negative(delta)),
///                    Add(Exp(Negative(input)), Broadcast(constant))),
///             Exp(Negative(input))))
///
/// `input`, `constant`, `sigmoid_fwd` and `delta` are pattern labels. On a match the
/// callback builds SigmoidBackprop(input, delta); it returns nullptr (no replacement)
/// when the matched root is not f32 or its rank differs from the rank of `input`.
void ngraph::runtime::cpu::pass::CPUFusion::construct_sigmoid_bprop()
{
    // exp(-input)
    auto input = std::make_shared<pattern::op::Label>(element::f32, Shape{3, 4});
    auto neg_input = std::make_shared<op::Negative>(input);
    auto exp_neg_input = std::make_shared<op::Exp>(neg_input);

    // exp(-input) + broadcast(constant)
    auto constant = std::make_shared<pattern::op::Label>(element::f32, Shape{});
    auto broadcast_constant =
        std::make_shared<op::Broadcast>(constant, Shape{3, 4}, AxisSet{0, 1});
    auto add_exp = std::make_shared<op::Add>(exp_neg_input, broadcast_constant);

    // The forward sigmoid value and the incoming delta are both free labels.
    auto sigmoid_fwd = std::make_shared<pattern::op::Label>(element::f32, Shape{3, 4});
    auto delta = std::make_shared<pattern::op::Label>(element::f32, Shape{3, 4});

    // -(((sigmoid_fwd * -delta) / (exp(-input) + constant)) * exp(-input))
    auto neg_delta = std::make_shared<op::Negative>(delta);
    auto multiply_sigmoid_delta = std::make_shared<op::Multiply>(sigmoid_fwd, neg_delta);
    auto divide_2 = std::make_shared<op::Divide>(multiply_sigmoid_delta, add_exp);
    auto multiply_2 = std::make_shared<op::Multiply>(divide_2, exp_neg_input);
    auto negative_2 = std::make_shared<op::Negative>(multiply_2);

    // Callback invoked once the graph matches the pattern rooted at negative_2.
    ngraph::pattern::gr_callback_fn callback =
        [input, delta](pattern::Matcher& m) -> std::shared_ptr<Node> {
        NGRAPH_DEBUG << "In a callback for construct_sigmoid_bprop pattern against "
                     << m.match_root()->get_name();
        auto pattern_map = m.get_pattern_map();

        if (m.match_root()->get_element_type() != element::f32)
        {
            NGRAPH_DEBUG << "mpattern = " << m.match_root()->get_name()
                         << " type is not float!";
            return nullptr;
        }

        // Only the number of dimensions is compared here, not the full shape.
        if (m.match_root()->get_shape().size() != pattern_map[input]->get_shape().size())
        {
            NGRAPH_DEBUG << "mpattern = " << m.match_root()->get_name()
                         << " input = " << pattern_map[input]->get_name()
                         << " rank does not match!";
            return nullptr;
        }

        auto dsigmoid =
            std::make_shared<op::SigmoidBackprop>(pattern_map[input], pattern_map[delta]);
        return dsigmoid;
    };

    auto m = std::make_shared<ngraph::pattern::Matcher>(negative_2, callback);
    this->add_matcher(m);
}
void
ngraph
::
runtime
::
cpu
::
pass
::
CPUFusion
::
construct_conv_bias
()
{
Shape
shape
{
2
,
2
,
1
,
1
};
...
...
src/ngraph/runtime/cpu/pass/cpu_fusion.hpp
View file @
3d53e58a
...
...
@@ -44,6 +44,7 @@ public:
construct_zero_padded_reshaped_conv
();
construct_zero_padded_conv
();
construct_sigmoid
();
construct_sigmoid_bprop
();
construct_conv_bias
();
}
...
...
@@ -53,6 +54,7 @@ private:
void
construct_conv_bias
();
void
construct_fprop_bn
();
void
construct_sigmoid
();
void
construct_sigmoid_bprop
();
void
construct_zero_padded_reshaped_conv
();
void
construct_zero_padded_conv
();
};
src/ngraph/runtime/cpu/pass/cpu_layout.cpp
View file @
3d53e58a
...
...
@@ -960,6 +960,29 @@ namespace ngraph
}
}
/// \brief Chooses tensor layouts for a SigmoidBackprop node.
///
/// When the node uses an MKLDNN kernel, both inputs and the single output are given the
/// MKLDNN format of input 0, with input conversions inserted as needed. Otherwise the
/// default layouts are applied.
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::SigmoidBackprop)
{
    if (!runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node.get()))
    {
        set_default_layouts(external_function, node);
        return;
    }

    auto input_layout = runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node.get(), 0);

    // Delta (input 1) is forced to the same layout as the forward input (input 0),
    // and the output reuses that layout as well.
    vector<memory::format> prim_input_formats(2, input_layout);
    vector<memory::format> prim_output_formats(1, input_layout);

    node = insert_input_conversions(external_function, node, prim_input_formats);
    set_output_layouts(node, prim_output_formats);
}
template
<>
void
CPULayout
::
LAYOUT_DECL
(
ngraph
::
op
::
ReluBackprop
)
{
...
...
@@ -1095,6 +1118,8 @@ static const runtime::cpu::pass::LayoutOpMap s_dispatcher{
{
TI
(
ngraph
::
op
::
ReluBackprop
),
&
runtime
::
cpu
::
pass
::
CPULayout
::
layout
<
ngraph
::
op
::
ReluBackprop
>
},
{
TI
(
ngraph
::
op
::
Sigmoid
),
&
runtime
::
cpu
::
pass
::
CPULayout
::
layout
<
ngraph
::
op
::
Sigmoid
>
},
{
TI
(
ngraph
::
op
::
SigmoidBackprop
),
&
runtime
::
cpu
::
pass
::
CPULayout
::
layout
<
ngraph
::
op
::
SigmoidBackprop
>
},
};
bool
runtime
::
cpu
::
pass
::
CPULayout
::
run_on_call_graph
(
const
std
::
list
<
std
::
shared_ptr
<
Node
>>&
nodes
)
...
...
src/ngraph/runtime/gpu/gpu_cuda_context_manager.cpp
View file @
3d53e58a
...
...
@@ -19,25 +19,18 @@
#include "ngraph/runtime/gpu/gpu_cuda_context_manager.hpp"
namespace
ngraph
using
namespace
ngraph
;
runtime
::
gpu
::
CudaContextManager
&
runtime
::
gpu
::
CudaContextManager
::
instance
()
{
namespace
runtime
{
namespace
gpu
{
CudaContextManager
&
CudaContextManager
::
instance
()
{
static
CudaContextManager
manager
;
return
manager
;
}
static
CudaContextManager
manager
;
return
manager
;
}
CudaContextManager
::
CudaContextManager
()
{
CUDA_SAFE_CALL
(
cuInit
(
0
));
CUDA_SAFE_CALL
(
cuDeviceGet
(
&
m_device
,
0
));
CUDA_SAFE_CALL
(
cuCtxCreate
(
&
m_context
,
0
,
m_device
));
m_context_ptr
=
std
::
make_shared
<
CUcontext
>
(
m_context
);
}
}
}
runtime
::
gpu
::
CudaContextManager
::
CudaContextManager
()
{
CUDA_SAFE_CALL
(
cuInit
(
0
));
CUDA_SAFE_CALL
(
cuDeviceGet
(
&
m_device
,
0
));
CUDA_SAFE_CALL
(
cuCtxCreate
(
&
m_context
,
0
,
m_device
));
m_context_ptr
=
std
::
make_shared
<
CUcontext
>
(
m_context
);
}
src/ngraph/runtime/gpu/gpu_cuda_function_builder.cpp
View file @
3d53e58a
...
...
@@ -20,46 +20,39 @@
#include "ngraph/runtime/gpu/gpu_cuda_function_builder.hpp"
#include "ngraph/runtime/gpu/gpu_util.hpp"
namespace
ngraph
using
namespace
ngraph
;
std
::
shared_ptr
<
CUfunction
>
runtime
::
gpu
::
CudaFunctionBuilder
::
get
(
const
std
::
string
&
name
,
const
std
::
string
&
kernel
,
int
number_of_options
,
const
char
**
options
)
{
namespace
runtime
{
namespace
gpu
{
std
::
shared_ptr
<
CUfunction
>
CudaFunctionBuilder
::
get
(
const
std
::
string
&
name
,
const
std
::
string
&
kernel
,
int
number_of_options
,
const
char
**
options
)
{
nvrtcProgram
prog
;
NVRTC_SAFE_CALL
(
nvrtcCreateProgram
(
&
prog
,
kernel
.
c_str
(),
"op.cu"
,
0
,
// numHeaders
NULL
,
// headers
NULL
));
// includeNames
nvrtcProgram
prog
;
NVRTC_SAFE_CALL
(
nvrtcCreateProgram
(
&
prog
,
kernel
.
c_str
(),
"op.cu"
,
0
,
// numHeaders
NULL
,
// headers
NULL
));
// includeNames
nvrtcResult
compile_result
=
nvrtcCompileProgram
(
prog
,
number_of_options
,
options
);
nvrtcResult
compile_result
=
nvrtcCompileProgram
(
prog
,
number_of_options
,
options
);
if
(
compile_result
!=
NVRTC_SUCCESS
)
{
throw
std
::
runtime_error
(
"compile error:
\n
"
+
kernel
+
"
\n
options"
);
}
if
(
compile_result
!=
NVRTC_SUCCESS
)
{
throw
std
::
runtime_error
(
"compile error:
\n
"
+
kernel
+
"
\n
options"
);
}
size_t
ptx_size
;
NVRTC_SAFE_CALL
(
nvrtcGetPTXSize
(
prog
,
&
ptx_size
));
char
*
ptx
=
new
char
[
ptx_size
];
NVRTC_SAFE_CALL
(
nvrtcGetPTX
(
prog
,
size_t
ptx_size
;
NVRTC_SAFE_CALL
(
nvrtcGetPTXSize
(
prog
,
&
ptx_size
));
char
*
ptx
=
new
char
[
ptx_size
];
NVRTC_SAFE_CALL
(
nvrtcGetPTX
(
prog
,
ptx
));
// Load the generated PTX and get a handle to the parent kernel.
NVRTC_SAFE_CALL
(
nvrtcDestroyProgram
(
&
prog
));
// Destroy the program.
NVRTC_SAFE_CALL
(
nvrtcDestroyProgram
(
&
prog
));
// Destroy the program.
CUmodule
module
;
CUfunction
function
;
CUDA_SAFE_CALL
(
cuModuleLoadDataEx
(
&
module
,
ptx
,
0
,
0
,
0
));
CUDA_SAFE_CALL
(
cuModuleGetFunction
(
&
function
,
module
,
name
.
c_str
()));
return
std
::
make_shared
<
CUfunction
>
(
function
);
}
}
}
CUmodule
module
;
CUfunction
function
;
CUDA_SAFE_CALL
(
cuModuleLoadDataEx
(
&
module
,
ptx
,
0
,
0
,
0
));
CUDA_SAFE_CALL
(
cuModuleGetFunction
(
&
function
,
module
,
name
.
c_str
()));
return
std
::
make_shared
<
CUfunction
>
(
function
);
}
src/ngraph/runtime/gpu/gpu_cuda_function_pool.cpp
View file @
3d53e58a
...
...
@@ -26,40 +26,31 @@
static
const
std
::
string
s_output_dir
=
"gpu_codegen"
;
namespace
ngraph
using
namespace
ngraph
;
runtime
::
gpu
::
CudaFunctionPool
&
runtime
::
gpu
::
CudaFunctionPool
::
instance
()
{
namespace
runtime
{
namespace
gpu
{
CudaFunctionPool
&
CudaFunctionPool
::
instance
()
{
static
CudaFunctionPool
pool
;
return
pool
;
}
static
CudaFunctionPool
pool
;
return
pool
;
}
void
CudaFunctionPool
::
set
(
const
std
::
string
&
name
,
const
std
::
string
&
kernel
)
{
const
char
*
opts
[]
=
{
"--gpu-architecture=compute_35"
,
"--relocatable-device-code=true"
};
std
::
string
filename
=
file_util
::
path_join
(
s_output_dir
,
"cuda_kernel_"
+
name
+
"_codegen.cu"
);
std
::
ofstream
out
(
filename
);
out
<<
kernel
;
out
.
close
();
m_function_map
.
insert
(
{
name
,
CudaFunctionBuilder
::
get
(
"cuda_"
+
name
,
kernel
,
2
,
opts
)});
}
void
runtime
::
gpu
::
CudaFunctionPool
::
set
(
const
std
::
string
&
name
,
const
std
::
string
&
kernel
)
{
const
char
*
opts
[]
=
{
"--gpu-architecture=compute_35"
,
"--relocatable-device-code=true"
};
std
::
string
filename
=
file_util
::
path_join
(
s_output_dir
,
"cuda_kernel_"
+
name
+
"_codegen.cu"
);
std
::
ofstream
out
(
filename
);
out
<<
kernel
;
out
.
close
();
m_function_map
.
insert
({
name
,
CudaFunctionBuilder
::
get
(
"cuda_"
+
name
,
kernel
,
2
,
opts
)});
}
std
::
shared_ptr
<
CUfunction
>
CudaFunctionPool
::
get
(
const
std
::
string
&
name
)
{
auto
it
=
m_function_map
.
find
(
name
);
if
(
it
!=
m_function_map
.
end
())
{
return
(
*
it
).
second
;
}
return
nullptr
;
}
}
std
::
shared_ptr
<
CUfunction
>
runtime
::
gpu
::
CudaFunctionPool
::
get
(
const
std
::
string
&
name
)
{
auto
it
=
m_function_map
.
find
(
name
);
if
(
it
!=
m_function_map
.
end
())
{
return
(
*
it
).
second
;
}
return
nullptr
;
}
src/ngraph/runtime/gpu/gpu_cuda_kernel_builder.cpp
View file @
3d53e58a
...
...
@@ -16,74 +16,67 @@
#include "ngraph/runtime/gpu/gpu_cuda_kernel_builder.hpp"
#include "ngraph/codegen/code_writer.hpp"
namespace
ngraph
using
namespace
ngraph
;
void
runtime
::
gpu
::
CudaKernelBuilder
::
get_elementwise_op
(
codegen
::
CodeWriter
&
writer
,
const
std
::
string
&
name
,
const
std
::
string
&
data_type
,
const
std
::
string
&
op
,
const
size_t
&
num_inputs
)
{
namespace
runtime
writer
<<
"extern
\"
C
\"
__global__ void cuda_"
<<
name
<<
"("
;
for
(
size_t
i
=
0
;
i
<
num_inputs
;
i
++
)
{
writer
<<
data_type
<<
"* in"
<<
i
<<
", "
;
}
writer
<<
data_type
<<
"* out,"
<<
"size_t n)
\n
"
;
writer
<<
"{
\n
"
;
writer
.
indent
++
;
{
namespace
gpu
writer
<<
"size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
\n
"
;
writer
<<
"if (tid < n)
\n
"
;
writer
<<
"{
\n
"
;
writer
.
indent
++
;
{
void
CudaKernelBuilder
::
get_elementwise_op
(
codegen
::
CodeWriter
&
writer
,
const
std
::
string
&
name
,
const
std
::
string
&
data_type
,
const
std
::
string
&
op
,
const
size_t
&
num_inputs
)
writer
<<
"out[tid] = "
<<
op
<<
"("
;
for
(
size_t
i
=
0
;
i
<
num_inputs
-
1
;
i
++
)
{
writer
<<
"extern
\"
C
\"
__global__ void cuda_"
<<
name
<<
"("
;
for
(
size_t
i
=
0
;
i
<
num_inputs
;
i
++
)
{
writer
<<
data_type
<<
"* in"
<<
i
<<
", "
;
}
writer
<<
data_type
<<
"* out,"
<<
"size_t n)
\n
"
;
writer
<<
"{
\n
"
;
writer
.
indent
++
;
{
writer
<<
"size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
\n
"
;
writer
<<
"if (tid < n)
\n
"
;
writer
<<
"{
\n
"
;
writer
.
indent
++
;
{
writer
<<
"out[tid] = "
<<
op
<<
"("
;
for
(
size_t
i
=
0
;
i
<
num_inputs
-
1
;
i
++
)
{
writer
<<
"in"
<<
i
<<
"[tid], "
;
}
writer
<<
"in"
<<
num_inputs
-
1
<<
"[tid]);
\n
"
;
}
writer
.
indent
--
;
writer
<<
"}
\n
"
;
}
writer
.
indent
--
;
writer
<<
"}
\n
"
;
return
;
writer
<<
"in"
<<
i
<<
"[tid], "
;
}
writer
<<
"in"
<<
num_inputs
-
1
<<
"[tid]);
\n
"
;
}
writer
.
indent
--
;
writer
<<
"}
\n
"
;
}
writer
.
indent
--
;
writer
<<
"}
\n
"
;
void
CudaKernelBuilder
::
get_device_helper
(
codegen
::
CodeWriter
&
writer
,
const
std
::
string
&
name
,
const
std
::
string
&
data_type
,
const
std
::
string
&
math_kernel
,
const
size_t
&
num_inputs
)
{
if
(
math_kernel
.
size
())
{
writer
<<
"__device__ "
<<
data_type
<<
" "
<<
name
<<
"("
;
for
(
size_t
i
=
0
;
i
<
num_inputs
-
1
;
i
++
)
{
writer
<<
data_type
<<
" x"
<<
i
<<
", "
;
}
writer
<<
data_type
<<
" x"
<<
num_inputs
-
1
;
writer
<<
")
\n
"
;
writer
<<
"{
\n
"
;
writer
.
indent
++
;
{
writer
<<
"return "
+
math_kernel
<<
";
\n
"
;
}
writer
.
indent
--
;
writer
<<
"}
\n
"
;
}
return
;
}
return
;
}
void
runtime
::
gpu
::
CudaKernelBuilder
::
get_device_helper
(
codegen
::
CodeWriter
&
writer
,
const
std
::
string
&
name
,
const
std
::
string
&
data_type
,
const
std
::
string
&
math_kernel
,
const
size_t
&
num_inputs
)
{
if
(
math_kernel
.
size
())
{
writer
<<
"__device__ "
<<
data_type
<<
" "
<<
name
<<
"("
;
for
(
size_t
i
=
0
;
i
<
num_inputs
-
1
;
i
++
)
{
writer
<<
data_type
<<
" x"
<<
i
<<
", "
;
}
writer
<<
data_type
<<
" x"
<<
num_inputs
-
1
;
writer
<<
")
\n
"
;
writer
<<
"{
\n
"
;
writer
.
indent
++
;
{
writer
<<
"return "
+
math_kernel
<<
";
\n
"
;
}
writer
.
indent
--
;
writer
<<
"}
\n
"
;
}
return
;
}
src/ngraph/runtime/gpu/gpu_cuda_kernel_emitters.cpp
View file @
3d53e58a
...
...
@@ -20,26 +20,22 @@
#include "ngraph/runtime/gpu/gpu_cuda_kernel_emitters.hpp"
#include "ngraph/runtime/gpu/gpu_cuda_kernel_ops.hpp"
namespace
ngraph
using
namespace
ngraph
;
void
runtime
::
gpu
::
emit_broadcast
(
void
*
in
,
void
*
out
,
size_t
repeat_size
,
size_t
repeat_times
,
size_t
count
)
{
namespace
runtime
std
::
string
name
=
"broadcast"
;
// Create an instance of nvrtcProgram with the code string.
if
(
CudaFunctionPool
::
instance
().
get
(
name
)
==
nullptr
)
{
namespace
gpu
{
void
emit_broadcast
(
void
*
in
,
void
*
out
,
size_t
repeat_size
,
size_t
repeat_times
,
size_t
count
)
{
std
::
string
name
=
"broadcast"
;
// Create an instance of nvrtcProgram with the code string.
if
(
CudaFunctionPool
::
instance
().
get
(
name
)
==
nullptr
)
{
std
::
string
kernel
;
std
::
string
data_type
(
"float"
);
std
::
string
kernel
;
std
::
string
data_type
(
"float"
);
kernel
=
R"(
kernel
=
R"(
extern "C" __global__
void cuda_)"
+
name
+
"("
+
data_type
+
"* in, "
+
data_type
+
"* out, size_t m, size_t k, size_t n)
\n
"
+
R"(
void cuda_)"
+
name
+
"("
+
data_type
+
"* in, "
+
data_type
+
"* out, size_t m, size_t k, size_t n)
\n
"
+
R"(
{
size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
if(tid < n)
...
...
@@ -48,28 +44,25 @@ void cuda_)" + name + "(" + data_type +
out[tid] = in[idx];
}
})"
;
CudaFunctionPool
::
instance
().
set
(
name
,
kernel
);
}
CudaFunctionPool
::
instance
().
set
(
name
,
kernel
);
}
//convert runtime ptr to driver api ptr
CUdeviceptr
d_ptr_in
,
d_ptr_out
;
d_ptr_in
=
(
CUdeviceptr
)
in
;
d_ptr_out
=
(
CUdeviceptr
)
out
;
//convert runtime ptr to driver api ptr
CUdeviceptr
d_ptr_in
,
d_ptr_out
;
d_ptr_in
=
CUdeviceptr
(
in
)
;
d_ptr_out
=
CUdeviceptr
(
out
)
;
void
*
args_list
[]
=
{
&
d_ptr_in
,
&
d_ptr_out
,
&
repeat_size
,
&
repeat_times
,
&
count
};
CUDA_SAFE_CALL
(
cuLaunchKernel
(
*
CudaFunctionPool
::
instance
().
get
(
name
).
get
(),
count
,
1
,
1
,
// grid dim
1
,
1
,
1
,
// block dim
0
,
NULL
,
// shared mem and stream
args_list
,
0
));
// arguments
CUDA_SAFE_CALL
(
cuCtxSynchronize
());
// Retrieve and print output.
}
}
}
void
*
args_list
[]
=
{
&
d_ptr_in
,
&
d_ptr_out
,
&
repeat_size
,
&
repeat_times
,
&
count
};
CUDA_SAFE_CALL
(
cuLaunchKernel
(
*
CudaFunctionPool
::
instance
().
get
(
name
).
get
(),
static_cast
<
unsigned
int
>
(
count
),
1
,
1
,
// grid dim
1
,
1
,
1
,
// block dim
0
,
NULL
,
// shared mem and stream
args_list
,
0
));
// arguments
CUDA_SAFE_CALL
(
cuCtxSynchronize
());
// Retrieve and print output.
}
src/ngraph/runtime/gpu/gpu_emitter.cpp
View file @
3d53e58a
...
...
@@ -518,15 +518,11 @@ cudnnSetOpTensorDescriptor(opTensorDesc,
writer
.
indent
++
;
auto
arg_shape
=
args
[
0
].
get_shape
();
auto
arg_rank
=
arg_shape
.
size
();
auto
result_shape
=
out
[
0
].
get_shape
();
auto
&
result_element_type
=
out
[
0
].
get_element_type
();
auto
input_order
=
reshape
->
get_input_order
();
bool
same_layout
=
is_sorted
(
input_order
.
begin
(),
input_order
.
end
());
size_t
result_shape_product
=
1
;
for
(
auto
i
:
result_shape
)
{
result_shape_product
*=
i
;
...
...
src/ngraph/runtime/gpu/gpu_external_function.cpp
View file @
3d53e58a
This diff is collapsed.
Click to expand it.
src/ngraph/runtime/gpu/gpu_tensor_view.cpp
View file @
3d53e58a
...
...
@@ -41,7 +41,7 @@ runtime::gpu::GPU_TensorView::GPU_TensorView(const ngraph::element::Type& elemen
m_buffer_size
=
shape_size
(
shape
)
*
element_type
.
size
();
if
(
m_buffer_size
>
0
)
{
cudaMalloc
(
(
void
**
)
&
m_allocated_buffer_pool
,
m_buffer_size
);
cudaMalloc
(
static_cast
<
void
**>
(
&
m_allocated_buffer_pool
)
,
m_buffer_size
);
}
}
...
...
src/ngraph/runtime/gpu/gpu_util.cpp
View file @
3d53e58a
...
...
@@ -50,7 +50,7 @@ void runtime::gpu::check_cuda_errors(CUresult err)
void
*
runtime
::
gpu
::
create_gpu_buffer
(
size_t
buffer_size
)
{
void
*
allocated_buffer_pool
;
cudaMalloc
(
(
void
**
)
&
allocated_buffer_pool
,
buffer_size
);
cudaMalloc
(
static_cast
<
void
**>
(
&
allocated_buffer_pool
)
,
buffer_size
);
return
allocated_buffer_pool
;
}
...
...
test/cpu_fusion.cpp
View file @
3d53e58a
...
...
@@ -48,6 +48,8 @@
#include "ngraph/util.hpp"
#include "nlohmann/json.hpp"
#include "util/all_close.hpp"
#include "util/autodiff/backprop_function.hpp"
#include "util/autodiff/numeric_compare.hpp"
#include "util/matcher.hpp"
#include "util/test_tools.hpp"
...
...
@@ -914,3 +916,47 @@ TEST(cpu_fusion, sigmoid_n1c1h4)
vector
<
float
>
expected
{
0.73105858
f
,
0.98201379
f
,
0.73105858
f
,
0.98201379
f
};
ASSERT_TRUE
(
read_vector
<
float
>
(
result
)
==
expected
);
}
// Deserializes the serialized sigmoid forward graph, derives its backprop function,
// compiles it for the CPU backend and expects exactly one SigmoidBackprop node in the
// backprop function afterwards.
TEST(cpu_fusion, sigmoid_bprop_fusion)
{
    const string json_path = file_util::path_join(SERIALIZED_ZOO, "mxnet/Graph_fprop_sigmoid.json");
    const string json_string = file_util::read_file_to_string(json_path);
    stringstream serialized_graph(json_string);
    shared_ptr<Function> forward_function = ngraph::deserialize(serialized_graph);
    auto backprop = autodiff::backprop_function(forward_function);

    auto manager = runtime::Manager::get("CPU");
    auto external = manager->compile(backprop);
    auto backend = manager->allocate_backend();
    auto cf = backend->make_call_frame(external);

    // Count only after compilation, once the backend has processed the function.
    size_t fused_count = count_ops_of_type<op::SigmoidBackprop>(backprop);
    ASSERT_EQ(fused_count, 1);
}
// Runs a stand-alone SigmoidBackprop node on the CPU backend with a {1, 1, 4} input and
// an all-ones delta, and checks the result against precomputed expected values.
TEST(cpu_fusion, sigmoid_bprop_n1c1h4)
{
    const Shape tensor_shape{1, 1, 4};
    auto input = make_shared<op::Parameter>(element::f32, tensor_shape);
    auto delta = make_shared<op::Parameter>(element::f32, tensor_shape);
    auto sigmoid_node = make_shared<op::SigmoidBackprop>(input, delta);
    auto func = make_shared<Function>(sigmoid_node, op::ParameterVector{input, delta});

    auto manager = runtime::Manager::get("CPU");
    auto external = manager->compile(func);
    auto backend = manager->allocate_backend();
    auto cf = backend->make_call_frame(external);

    shared_ptr<runtime::TensorView> input_tv =
        backend->make_primary_tensor_view(element::f32, input->get_shape());
    shared_ptr<runtime::TensorView> delta_tv =
        backend->make_primary_tensor_view(element::f32, delta->get_shape());
    shared_ptr<runtime::TensorView> result =
        backend->make_primary_tensor_view(element::f32, input->get_shape());

    vector<float> input_values{1.0f, 4.0f, 1.0f, 4.0f};
    vector<float> delta_values{1.0f, 1.0f, 1.0f, 1.0f};
    copy_data(input_tv, input_values);
    copy_data(delta_tv, delta_values);

    cf->call({input_tv, delta_tv}, {result});

    vector<float> expected{0.196612f, 0.0176627f, 0.196612f, 0.0176627f};
    EXPECT_TRUE(test::all_close(expected, read_vector<float>(result)));
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment