Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
64e1dbe9
Commit
64e1dbe9
authored
Jun 14, 2019
by
gaurides
Committed by
Scott Cyphers
Jun 14, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Use all args for dropout (#3069)
parent
33c74139
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
51 additions
and
30 deletions
+51
-30
dropout.cpp
src/ngraph/runtime/cpu/builder/dropout.cpp
+8
-4
dropout.cpp
src/ngraph/runtime/cpu/op/dropout.cpp
+16
-7
dropout.hpp
src/ngraph/runtime/cpu/op/dropout.hpp
+5
-10
cpu_fusion.cpp
src/ngraph/runtime/cpu/pass/cpu_fusion.cpp
+22
-9
No files found.
src/ngraph/runtime/cpu/builder/dropout.cpp
View file @
64e1dbe9
...
...
@@ -38,13 +38,13 @@ namespace ngraph
auto
arg_buffer_index
=
external_function
->
get_buffer_index
(
args
[
0
].
get_name
());
auto
arg1_buffer_index
=
external_function
->
get_buffer_index
(
args
[
1
].
get_name
());
auto
arg4_buffer_index
=
external_function
->
get_buffer_index
(
args
[
4
].
get_name
());
auto
out0_buffer_index
=
external_function
->
get_buffer_index
(
out
[
0
].
get_name
());
auto
out1_buffer_index
=
external_function
->
get_buffer_index
(
out
[
1
].
get_name
());
size_t
element_count
=
out
[
0
].
get_size
();
bool
use_seed
=
drop
->
get_use_seed
();
double
keep_prob
=
drop
->
get_keep_prob
();
// Note: for performance optimization in addition to parallel RNG with multiple,
// threads, we create, initialize and advance each msr here in builder instead of
...
...
@@ -56,7 +56,7 @@ namespace ngraph
std
::
vector
<
std
::
minstd_rand
>
vmsr
(
nthr
);
if
(
use_seed
)
{
uint
32
_t
seed
=
drop
->
get_seed
();
uint
64
_t
seed
=
drop
->
get_seed
();
for
(
size_t
i
=
0
;
i
<
nthr
;
i
++
)
{
std
::
minstd_rand
msr
;
...
...
@@ -72,13 +72,15 @@ namespace ngraph
element_count
,
arg_buffer_index
,
arg1_buffer_index
,
arg4_buffer_index
,
out0_buffer_index
,
out1_buffer_index
,
keep_prob
,
vmsr
,
use_seed
](
CPURuntimeContext
*
ctx
,
CPUExecutionContext
*
ectx
)
{
bool
training
=
static_cast
<
bool
>
(
static_cast
<
float
*>
(
ctx
->
buffer_data
[
arg1_buffer_index
])[
0
]);
double
keep_prob
=
static_cast
<
double
*>
(
ctx
->
buffer_data
[
arg4_buffer_index
])[
0
];
runtime
::
cpu
::
kernel
::
generate_dropout
(
static_cast
<
float
*>
(
ctx
->
buffer_data
[
arg_buffer_index
]),
static_cast
<
float
*>
(
ctx
->
buffer_data
[
out0_buffer_index
]),
...
...
@@ -96,13 +98,15 @@ namespace ngraph
element_count
,
arg_buffer_index
,
arg1_buffer_index
,
arg4_buffer_index
,
out0_buffer_index
,
out1_buffer_index
,
keep_prob
,
vmsr
,
use_seed
](
CPURuntimeContext
*
ctx
,
CPUExecutionContext
*
ectx
)
{
bool
training
=
static_cast
<
bool
>
(
static_cast
<
double
*>
(
ctx
->
buffer_data
[
arg1_buffer_index
])[
0
]);
double
keep_prob
=
static_cast
<
double
*>
(
ctx
->
buffer_data
[
arg4_buffer_index
])[
0
];
runtime
::
cpu
::
kernel
::
generate_dropout
(
static_cast
<
double
*>
(
ctx
->
buffer_data
[
arg_buffer_index
]),
static_cast
<
double
*>
(
ctx
->
buffer_data
[
out0_buffer_index
]),
...
...
src/ngraph/runtime/cpu/op/dropout.cpp
View file @
64e1dbe9
...
...
@@ -26,11 +26,9 @@ using namespace ngraph;
op
::
Dropout
::
Dropout
(
const
std
::
shared_ptr
<
Node
>&
input
,
const
std
::
shared_ptr
<
Node
>&
gm_const
,
const
std
::
shared_ptr
<
Node
>&
use_seed
,
const
uint32_t
seed
,
const
double
keep_prob
)
:
Op
(
"Dropout"
,
check_single_output_args
({
input
,
gm_const
,
use_seed
}))
,
m_seed
(
seed
)
,
m_keep_prob
(
keep_prob
)
const
std
::
shared_ptr
<
Node
>&
seed
,
const
std
::
shared_ptr
<
Node
>&
keep_prob
)
:
Op
(
"Dropout"
,
check_single_output_args
({
input
,
gm_const
,
use_seed
,
seed
,
keep_prob
}))
{
constructor_validate_and_infer_types
();
...
...
@@ -41,13 +39,13 @@ op::Dropout::Dropout(const std::shared_ptr<Node>& input,
shared_ptr
<
Node
>
op
::
Dropout
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
if
(
new_args
.
size
()
!=
3
)
if
(
new_args
.
size
()
!=
5
)
{
throw
ngraph_error
(
"Incorrect number of new arguments"
);
}
return
make_shared
<
Dropout
>
(
new_args
.
at
(
0
),
new_args
.
at
(
1
),
new_args
.
at
(
2
),
m_seed
,
m_keep_prob
);
new_args
.
at
(
0
),
new_args
.
at
(
1
),
new_args
.
at
(
2
),
new_args
.
at
(
3
),
new_args
.
at
(
4
)
);
}
bool
op
::
Dropout
::
get_use_seed
()
const
...
...
@@ -60,3 +58,14 @@ bool op::Dropout::get_use_seed() const
}
return
use_seed
;
}
uint64_t
op
::
Dropout
::
get_seed
()
const
{
uint64_t
seed
=
0
;
if
(
auto
const_op
=
dynamic_pointer_cast
<
op
::
Constant
>
(
get_argument
(
3
)))
{
auto
seed_ptr
=
static_cast
<
const
uint64_t
*>
(
const_op
->
get_data_ptr
());
seed
=
*
seed_ptr
;
}
return
seed
;
}
src/ngraph/runtime/cpu/op/dropout.hpp
View file @
64e1dbe9
...
...
@@ -29,20 +29,15 @@ namespace ngraph
Dropout
(
const
std
::
shared_ptr
<
Node
>&
input
,
const
std
::
shared_ptr
<
Node
>&
gm_const
,
const
std
::
shared_ptr
<
Node
>&
use_seed
,
const
uint32_t
seed
,
const
double
keep_prob
);
// keep_prob = 1 - dropout_prob
const
std
::
shared_ptr
<
Node
>&
seed
,
const
std
::
shared_ptr
<
Node
>&
keep_prob
);
// keep_prob = 1 - dropout_prob
bool
get_use_seed
()
const
;
uint32_t
get_seed
()
const
{
return
m_seed
;
}
double
get_keep_prob
()
const
{
return
m_keep_prob
;
}
void
set_seed
(
uint32_t
new_seed
)
{
m_seed
=
new_seed
;
}
void
set_keep_prob
(
double
new_keep_prob
)
{
m_keep_prob
=
new_keep_prob
;
}
uint64_t
get_seed
()
const
;
double
get_keep_prob
()
const
;
virtual
std
::
shared_ptr
<
Node
>
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
private
:
uint32_t
m_seed
;
double
m_keep_prob
;
};
}
}
src/ngraph/runtime/cpu/pass/cpu_fusion.cpp
View file @
64e1dbe9
...
...
@@ -923,8 +923,8 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_dropout()
auto
x
=
std
::
make_shared
<
pattern
::
op
::
Label
>
(
element
::
f32
,
shape
);
auto
x_label
=
std
::
make_shared
<
pattern
::
op
::
Label
>
(
x
,
nullptr
,
NodeVector
{
x
});
uint
32
_t
seed
=
1234
;
auto
seed_label
=
std
::
make_shared
<
pattern
::
op
::
Label
>
(
element
::
u
32
,
Shape
{
0
});
uint
64
_t
seed
=
1234
;
auto
seed_label
=
std
::
make_shared
<
pattern
::
op
::
Label
>
(
element
::
u
64
,
Shape
{
0
});
double
value
=
0.9
;
auto
value_const
=
ngraph
::
op
::
Constant
::
create
(
element
::
f32
,
Shape
{
1
,
1
,
2
,
2
},
{
value
});
...
...
@@ -960,15 +960,28 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_dropout()
NGRAPH_DEBUG
<<
"training argument to GenerateMask must be constant"
;
return
false
;
}
if
(
!
std
::
dynamic_pointer_cast
<
ngraph
::
op
::
Constant
>
(
gm
->
get_argument
(
2
)))
{
NGRAPH_DEBUG
<<
"use_seed argument to GenerateMask must be constant"
;
return
false
;
}
if
(
!
std
::
dynamic_pointer_cast
<
ngraph
::
op
::
Constant
>
(
gm
->
get_argument
(
3
)))
{
NGRAPH_DEBUG
<<
"seed argument to GenerateMask must be constant"
;
return
false
;
}
if
(
!
std
::
dynamic_pointer_cast
<
ngraph
::
op
::
Constant
>
(
gm
->
get_argument
(
4
)))
{
NGRAPH_DEBUG
<<
"probability argument to GenerateMask must be constant"
;
return
false
;
}
auto
gm_value
=
gm
->
get_probability
();
auto
gm_seed
=
gm
->
get_seed
();
auto
training
=
gm
->
get_argument
(
0
);
//for training purpose this is always going to be 1
auto
use_seed_arg
=
gm
->
get_argument
(
2
);
// this is the use_seed node
auto
dropout_n
=
std
::
make_shared
<
ngraph
::
op
::
Dropout
>
(
pattern_map
[
x
],
gm
->
get_argument
(
0
),
gm
->
get_argument
(
2
),
gm
->
get_argument
(
3
),
gm
->
get_argument
(
4
));
auto
dropout_n
=
std
::
make_shared
<
ngraph
::
op
::
Dropout
>
(
pattern_map
[
x
],
training
,
use_seed_arg
,
gm_seed
,
gm_value
);
auto
goe1
=
std
::
make_shared
<
ngraph
::
op
::
GetOutputElement
>
(
dropout_n
,
0
);
ngraph
::
replace_node
(
m
.
get_match_root
(),
goe1
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment