Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
66198b33
Commit
66198b33
authored
Apr 26, 2018
by
Jayaram Bobba
Committed by
Scott Cyphers
Apr 26, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Support for batchnorm+relu fusion for all batchnorm variants. (#903)
parent
82c19d24
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
218 additions
and
86 deletions
+218
-86
cpu_emitter.cpp
src/ngraph/runtime/cpu/cpu_emitter.cpp
+31
-74
cpu_emitter.hpp
src/ngraph/runtime/cpu/cpu_emitter.hpp
+7
-0
batch_norm_relu.cpp
src/ngraph/runtime/cpu/op/batch_norm_relu.cpp
+95
-2
batch_norm_relu.hpp
src/ngraph/runtime/cpu/op/batch_norm_relu.hpp
+8
-0
cpu_fusion.cpp
src/ngraph/runtime/cpu/pass/cpu_fusion.cpp
+58
-6
cpu_fusion.hpp
src/ngraph/runtime/cpu/pass/cpu_fusion.hpp
+2
-0
cpu_layout.cpp
src/ngraph/runtime/cpu/pass/cpu_layout.cpp
+17
-4
No files found.
src/ngraph/runtime/cpu/cpu_emitter.cpp
View file @
66198b33
...
@@ -366,8 +366,12 @@ namespace ngraph
...
@@ -366,8 +366,12 @@ namespace ngraph
writer
.
block_end
();
writer
.
block_end
();
}
}
template
<>
void
CPU_Emitter
::
emitBatchNorm
(
CPU_ExternalFunction
*
external_function
,
void
CPU_Emitter
::
EMITTER_DECL
(
ngraph
::
op
::
BatchNorm
)
codegen
::
CodeWriter
&
writer
,
const
ngraph
::
Node
*
node
,
const
std
::
vector
<
TensorViewWrapper
>&
args
,
const
std
::
vector
<
TensorViewWrapper
>&
out
,
bool
append_relu
)
{
{
const
ngraph
::
op
::
BatchNorm
*
batchnorm
=
const
ngraph
::
op
::
BatchNorm
*
batchnorm
=
static_cast
<
const
ngraph
::
op
::
BatchNorm
*>
(
node
);
static_cast
<
const
ngraph
::
op
::
BatchNorm
*>
(
node
);
...
@@ -382,6 +386,17 @@ namespace ngraph
...
@@ -382,6 +386,17 @@ namespace ngraph
<<
args
[
1
].
get_name
()
<<
", "
<<
args
[
1
].
get_name
()
<<
", "
<<
args
[
1
].
get_size
()
*
args
[
1
].
get_element_type
().
size
()
<<
");
\n
"
;
<<
args
[
1
].
get_size
()
*
args
[
1
].
get_element_type
().
size
()
<<
");
\n
"
;
const
float
ops_scale
=
1.
f
;
const
float
ops_alpha
=
-
0.
f
;
// relu negative slope
const
float
ops_beta
=
0.
f
;
mkldnn
::
post_ops
ops
;
if
(
append_relu
)
{
ops
.
append_eltwise
(
ops_scale
,
mkldnn
::
algorithm
::
eltwise_relu
,
ops_alpha
,
ops_beta
);
}
if
(
batchnorm
->
get_training_flag
()
&&
args
.
size
()
==
3
)
if
(
batchnorm
->
get_training_flag
()
&&
args
.
size
()
==
3
)
{
{
auto
input_format
=
auto
input_format
=
...
@@ -413,7 +428,8 @@ namespace ngraph
...
@@ -413,7 +428,8 @@ namespace ngraph
variance_desc
,
variance_desc
,
batchnorm
->
get_eps_value
(),
batchnorm
->
get_eps_value
(),
false
,
false
,
batchnorm
->
get_training_flag
());
batchnorm
->
get_training_flag
(),
ops
);
auto
&
deps
=
mkldnn_emitter
->
get_primitive_deps
(
batchnorm_index
);
auto
&
deps
=
mkldnn_emitter
->
get_primitive_deps
(
batchnorm_index
);
writer
<<
"cpu::mkldnn_utils::set_memory_ptr(ctx, "
<<
to_string
(
deps
[
0
])
writer
<<
"cpu::mkldnn_utils::set_memory_ptr(ctx, "
<<
to_string
(
deps
[
0
])
...
@@ -459,7 +475,8 @@ namespace ngraph
...
@@ -459,7 +475,8 @@ namespace ngraph
variance_desc
,
variance_desc
,
batchnorm
->
get_eps_value
(),
batchnorm
->
get_eps_value
(),
true
,
true
,
batchnorm
->
get_training_flag
());
batchnorm
->
get_training_flag
(),
ops
);
auto
&
deps
=
mkldnn_emitter
->
get_primitive_deps
(
batchnorm_index
);
auto
&
deps
=
mkldnn_emitter
->
get_primitive_deps
(
batchnorm_index
);
writer
<<
"cpu::mkldnn_utils::set_memory_ptr(ctx, "
<<
to_string
(
deps
[
0
])
writer
<<
"cpu::mkldnn_utils::set_memory_ptr(ctx, "
<<
to_string
(
deps
[
0
])
...
@@ -480,83 +497,23 @@ namespace ngraph
...
@@ -480,83 +497,23 @@ namespace ngraph
}
}
template
<>
template
<>
void
CPU_Emitter
::
EMITTER_DECL
(
ngraph
::
op
::
BatchNorm
Relu
)
void
CPU_Emitter
::
EMITTER_DECL
(
ngraph
::
op
::
BatchNorm
)
{
{
if
(
!
mkldnn_utils
::
use_mkldnn_kernel
(
node
))
if
(
!
mkldnn_utils
::
use_mkldnn_kernel
(
node
))
{
{
throw
ngraph_error
(
"BatchNormRelu is only supported with MKLDNN kernel."
);
throw
ngraph_error
(
"BatchNorm is only supported with 4-D MKLDNN kernel."
);
}
emitBatchNorm
(
external_function
,
writer
,
node
,
args
,
out
,
false
);
}
}
const
ngraph
::
op
::
BatchNormRelu
*
batchnorm
=
template
<>
static_cast
<
const
ngraph
::
op
::
BatchNormRelu
*>
(
node
);
void
CPU_Emitter
::
EMITTER_DECL
(
ngraph
::
op
::
BatchNormRelu
)
if
(
!
batchnorm
->
get_training_flag
()
||
batchnorm
->
get_inputs
().
size
()
!=
3
)
{
{
throw
ngraph_error
(
"Only training batchnorm should have been fused"
);
if
(
!
mkldnn_utils
::
use_mkldnn_kernel
(
node
))
{
throw
ngraph_error
(
"BatchNormRelu is only supported with 4-D MKLDNN kernel."
);
}
}
emitBatchNorm
(
external_function
,
writer
,
node
,
args
,
out
,
true
);
const
float
ops_scale
=
1.
f
;
const
float
ops_alpha
=
-
0.
f
;
// relu negative slope
const
float
ops_beta
=
0.
f
;
mkldnn
::
post_ops
ops
;
ops
.
append_eltwise
(
ops_scale
,
mkldnn
::
algorithm
::
eltwise_relu
,
ops_alpha
,
ops_beta
);
writer
.
block_begin
();
writer
<<
"{
\n
"
;
// define weights
writer
<<
"std::vector<"
<<
args
[
0
].
get_element_type
().
c_type_string
()
<<
">bn_weights(2*"
<<
args
[
0
].
get_size
()
<<
");
\n
"
;
writer
<<
"memcpy(&bn_weights[0], "
<<
args
[
0
].
get_name
()
<<
", "
<<
args
[
0
].
get_size
()
*
args
[
0
].
get_element_type
().
size
()
<<
");
\n
"
;
writer
<<
"memcpy(&bn_weights[0]+"
<<
args
[
0
].
get_size
()
<<
", "
<<
args
[
1
].
get_name
()
<<
", "
<<
args
[
1
].
get_size
()
*
args
[
1
].
get_element_type
().
size
()
<<
");
\n
"
;
auto
input_format
=
runtime
::
cpu
::
mkldnn_utils
::
get_input_mkldnn_format
(
node
,
2
);
auto
result_format
=
runtime
::
cpu
::
mkldnn_utils
::
get_output_mkldnn_format
(
node
,
0
);
auto
mean_format
=
runtime
::
cpu
::
mkldnn_utils
::
get_output_mkldnn_format
(
node
,
1
);
auto
variance_format
=
runtime
::
cpu
::
mkldnn_utils
::
get_output_mkldnn_format
(
node
,
2
);
auto
&
mkldnn_emitter
=
external_function
->
get_mkldnn_emitter
();
auto
weights_shape
=
Shape
{
2
,
args
[
0
].
get_size
()};
auto
input_desc
=
mkldnn_emitter
->
build_memory_descriptor
(
args
[
2
],
input_format
);
auto
weights_desc
=
mkldnn_emitter
->
build_memory_descriptor
(
weights_shape
,
args
[
0
].
get_element_type
(),
mkldnn
::
memory
::
format
::
nc
);
auto
results_desc
=
mkldnn_emitter
->
build_memory_descriptor
(
out
[
0
],
result_format
);
auto
mean_desc
=
mkldnn_emitter
->
build_memory_descriptor
(
out
[
1
],
mean_format
);
auto
variance_desc
=
mkldnn_emitter
->
build_memory_descriptor
(
out
[
2
],
variance_format
);
auto
batchnorm_index
=
mkldnn_emitter
->
build_batchnorm_forward
(
input_desc
,
weights_desc
,
results_desc
,
mean_desc
,
variance_desc
,
batchnorm
->
get_eps_value
(),
false
,
batchnorm
->
get_training_flag
(),
ops
);
auto
&
deps
=
mkldnn_emitter
->
get_primitive_deps
(
batchnorm_index
);
writer
<<
"cpu::mkldnn_utils::set_memory_ptr(ctx, "
<<
to_string
(
deps
[
0
])
<<
", "
<<
args
[
2
].
get_name
()
<<
");
\n
"
;
writer
<<
"cpu::mkldnn_utils::set_memory_ptr(ctx, "
<<
to_string
(
deps
[
1
])
<<
", bn_weights.data());
\n
"
;
writer
<<
"cpu::mkldnn_utils::set_memory_ptr(ctx, "
<<
to_string
(
deps
[
2
])
<<
", "
<<
out
[
0
].
get_name
()
<<
");
\n
"
;
writer
<<
"cpu::mkldnn_utils::set_memory_ptr(ctx, "
<<
to_string
(
deps
[
3
])
<<
", "
<<
out
[
1
].
get_name
()
<<
");
\n
"
;
writer
<<
"cpu::mkldnn_utils::set_memory_ptr(ctx, "
<<
to_string
(
deps
[
4
])
<<
", "
<<
out
[
2
].
get_name
()
<<
");
\n
"
;
writer
<<
"cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, "
<<
to_string
(
batchnorm_index
)
<<
");
\n
"
;
writer
.
block_end
();
writer
<<
"}
\n
"
;
}
}
template
<>
template
<>
...
...
src/ngraph/runtime/cpu/cpu_emitter.hpp
View file @
66198b33
...
@@ -58,6 +58,13 @@ namespace ngraph
...
@@ -58,6 +58,13 @@ namespace ngraph
{
{
}
}
static
void
emitBatchNorm
(
CPU_ExternalFunction
*
external_function
,
codegen
::
CodeWriter
&
writer
,
const
ngraph
::
Node
*
node
,
const
std
::
vector
<
TensorViewWrapper
>&
args
,
const
std
::
vector
<
TensorViewWrapper
>&
out
,
bool
append_relu
=
false
);
private
:
private
:
static
std
::
string
emit_vector
(
const
TensorViewWrapper
&
,
static
std
::
string
emit_vector
(
const
TensorViewWrapper
&
,
const
std
::
string
&
name
=
""
);
const
std
::
string
&
name
=
""
);
...
...
src/ngraph/runtime/cpu/op/batch_norm_relu.cpp
View file @
66198b33
...
@@ -76,11 +76,104 @@ ngraph::op::BatchNormRelu::BatchNormRelu(double eps,
...
@@ -76,11 +76,104 @@ ngraph::op::BatchNormRelu::BatchNormRelu(double eps,
add_output
(
input
->
get_element_type
(),
m_bn_variance_shape
);
add_output
(
input
->
get_element_type
(),
m_bn_variance_shape
);
}
}
ngraph
::
op
::
BatchNormRelu
::
BatchNormRelu
(
double
eps
,
std
::
shared_ptr
<
ngraph
::
Node
>
gamma
,
std
::
shared_ptr
<
ngraph
::
Node
>
beta
,
std
::
shared_ptr
<
ngraph
::
Node
>
input
,
std
::
shared_ptr
<
ngraph
::
Node
>
mean
,
std
::
shared_ptr
<
ngraph
::
Node
>
variance
,
bool
training
)
:
RequiresTensorViewArgs
(
"BatchNormRelu"
,
{
gamma
,
beta
,
input
,
mean
,
variance
})
,
m_bn_input_shape
(
input
->
get_shape
())
,
m_bn_variance_shape
(
variance
->
get_shape
())
,
m_bn_mean_shape
(
mean
->
get_shape
())
,
m_epsilon
(
eps
)
,
m_training
(
training
)
{
if
(
m_bn_input_shape
.
size
()
!=
4
)
{
throw
ngraph_error
(
"input tensor to batchnorm must have rank 4"
);
}
else
{
this
->
m_bn_variance_shape
.
push_back
(
input
->
get_shape
()[
1
]);
this
->
m_bn_mean_shape
.
push_back
(
input
->
get_shape
()[
1
]);
}
if
(
m_bn_input_shape
[
1
]
==
0
)
{
throw
ngraph_error
(
"input tensor must have at least one channel axis for batch normalization"
);
}
auto
et
=
input
->
get_element_type
();
const
char
*
input_names
[]
=
{
"gamma"
,
"beta"
};
for
(
size_t
i
=
0
;
i
<
2
;
i
++
)
{
if
(
get_argument
(
i
)
->
get_element_type
()
!=
et
)
{
auto
err_msg
=
std
::
string
(
"The element type of "
)
+
input_names
[
i
]
+
" isn't equal to input data's type"
;
throw
ngraph_error
(
err_msg
.
c_str
());
}
}
if
((
gamma
->
get_shape
().
size
()
!=
1
)
||
(
beta
->
get_shape
().
size
()
!=
1
))
{
throw
ngraph_error
(
"gamma and beta shoud have rank 1"
);
}
if
(
gamma
->
get_shape
().
size
()
!=
beta
->
get_shape
().
size
())
{
throw
ngraph_error
(
"gamma and beta rank does not match"
);
}
if
(
gamma
->
get_element_type
()
!=
beta
->
get_element_type
())
{
throw
ngraph_error
(
"gamma and beta element type does not match"
);
}
add_output
(
input
->
get_element_type
(),
m_bn_input_shape
);
}
std
::
shared_ptr
<
ngraph
::
Node
>
std
::
shared_ptr
<
ngraph
::
Node
>
ngraph
::
op
::
BatchNormRelu
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
ngraph
::
op
::
BatchNormRelu
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
if
(
new_args
.
size
()
!=
3
)
if
(
this
->
m_training
)
throw
ngraph_error
(
"Incorrect number of new arguments"
);
{
if
(
new_args
.
size
()
==
3
)
{
return
std
::
make_shared
<
BatchNormRelu
>
(
return
std
::
make_shared
<
BatchNormRelu
>
(
m_epsilon
,
new_args
.
at
(
0
),
new_args
.
at
(
1
),
new_args
.
at
(
2
));
m_epsilon
,
new_args
.
at
(
0
),
new_args
.
at
(
1
),
new_args
.
at
(
2
));
}
else
if
(
new_args
.
size
()
==
5
)
{
return
std
::
make_shared
<
BatchNormRelu
>
(
m_epsilon
,
new_args
.
at
(
0
),
new_args
.
at
(
1
),
new_args
.
at
(
2
),
new_args
.
at
(
3
),
new_args
.
at
(
4
),
true
);
}
else
{
throw
ngraph_error
(
"BatchNormRelu: Incorrect number of new arguments"
);
}
}
else
{
if
(
new_args
.
size
()
!=
5
)
{
throw
ngraph_error
(
"Incorrect number of new arguments"
);
}
return
std
::
make_shared
<
BatchNormRelu
>
(
m_epsilon
,
new_args
.
at
(
0
),
new_args
.
at
(
1
),
new_args
.
at
(
2
),
new_args
.
at
(
3
),
new_args
.
at
(
4
),
false
);
}
}
}
src/ngraph/runtime/cpu/op/batch_norm_relu.hpp
View file @
66198b33
...
@@ -35,6 +35,14 @@ namespace ngraph
...
@@ -35,6 +35,14 @@ namespace ngraph
std
::
shared_ptr
<
Node
>
beta
,
std
::
shared_ptr
<
Node
>
beta
,
std
::
shared_ptr
<
Node
>
input
);
std
::
shared_ptr
<
Node
>
input
);
BatchNormRelu
(
double
eps
,
std
::
shared_ptr
<
ngraph
::
Node
>
gamma
,
std
::
shared_ptr
<
ngraph
::
Node
>
beta
,
std
::
shared_ptr
<
ngraph
::
Node
>
input
,
std
::
shared_ptr
<
ngraph
::
Node
>
mean
,
std
::
shared_ptr
<
ngraph
::
Node
>
variance
,
bool
training
=
false
);
const
Shape
&
get_inputs_shape
()
const
{
return
m_bn_input_shape
;
}
const
Shape
&
get_inputs_shape
()
const
{
return
m_bn_input_shape
;
}
const
Shape
&
get_variance_shape
()
const
{
return
m_bn_variance_shape
;
}
const
Shape
&
get_variance_shape
()
const
{
return
m_bn_variance_shape
;
}
const
Shape
&
get_mean_shape
()
const
{
return
m_bn_mean_shape
;
}
const
Shape
&
get_mean_shape
()
const
{
return
m_bn_mean_shape
;
}
...
...
src/ngraph/runtime/cpu/pass/cpu_fusion.cpp
View file @
66198b33
...
@@ -775,12 +775,6 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_batch_norm_relu()
...
@@ -775,12 +775,6 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_batch_norm_relu()
auto
m_bn
=
std
::
dynamic_pointer_cast
<
op
::
BatchNorm
>
(
auto
m_bn
=
std
::
dynamic_pointer_cast
<
op
::
BatchNorm
>
(
m
.
match_root
()
->
get_argument
(
0
)
->
get_inputs
().
at
(
0
).
get_output
().
get_node
());
m
.
match_root
()
->
get_argument
(
0
)
->
get_inputs
().
at
(
0
).
get_output
().
get_node
());
if
(
!
m_bn
->
get_training_flag
())
{
NGRAPH_DEBUG
<<
" This is an inference batchnorm, so skipping fusion"
;
return
false
;
}
//as of now, only MKLDNN supports this fusion
//as of now, only MKLDNN supports this fusion
//and it requires input data's rank to be equal to 4
//and it requires input data's rank to be equal to 4
if
(
pattern_map
[
input
]
->
get_shape
().
size
()
!=
4
)
if
(
pattern_map
[
input
]
->
get_shape
().
size
()
!=
4
)
...
@@ -825,6 +819,64 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_batch_norm_relu()
...
@@ -825,6 +819,64 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_batch_norm_relu()
this
->
add_matcher
(
m
);
this
->
add_matcher
(
m
);
}
}
void
ngraph
::
runtime
::
cpu
::
pass
::
CPUFusion
::
construct_batch_norm_relu_global_stats
()
{
auto
input_shape
=
Shape
{
1
,
2
,
2
,
2
};
auto
input
=
std
::
make_shared
<
pattern
::
op
::
Label
>
(
element
::
f32
,
input_shape
);
auto
mean_shape
=
Shape
{
2
};
auto
mean
=
std
::
make_shared
<
pattern
::
op
::
Label
>
(
element
::
f32
,
mean_shape
);
auto
var_shape
=
Shape
{
2
};
auto
var
=
std
::
make_shared
<
pattern
::
op
::
Label
>
(
element
::
f32
,
var_shape
);
auto
gamma_shape
=
Shape
{
2
};
auto
gamma
=
std
::
make_shared
<
pattern
::
op
::
Label
>
(
element
::
f32
,
gamma_shape
);
auto
beta_shape
=
Shape
{
2
};
auto
beta
=
std
::
make_shared
<
pattern
::
op
::
Label
>
(
element
::
f32
,
beta_shape
);
double
eps
=
0.001
;
auto
shape_r
=
Shape
{
1
,
2
,
2
,
2
};
auto
bn
=
std
::
make_shared
<
op
::
BatchNorm
>
(
eps
,
gamma
,
beta
,
input
,
mean
,
var
);
auto
prelu
=
std
::
make_shared
<
op
::
Relu
>
(
bn
);
ngraph
::
pattern
::
graph_rewrite_callback
callback
=
[
input
,
mean
,
var
,
gamma
,
beta
](
pattern
::
Matcher
&
m
)
{
NGRAPH_DEBUG
<<
"In callback for construct_batch_norm_relu against node = "
<<
m
.
match_root
()
->
get_name
();
auto
pattern_map
=
m
.
get_pattern_map
();
auto
m_bn
=
std
::
dynamic_pointer_cast
<
op
::
BatchNorm
>
(
m
.
match_root
()
->
get_inputs
().
at
(
0
).
get_output
().
get_node
());
//as of now, only MKLDNN supports this fusion
//and it requires input data's rank to be equal to 4
if
(
pattern_map
[
input
]
->
get_shape
().
size
()
!=
4
)
{
NGRAPH_DEBUG
<<
" Input data's rank isn't equal to 4. Shape = "
<<
pattern_map
[
input
]
->
get_shape
().
size
();
return
false
;
}
if
(
m_bn
->
get_users
().
size
()
>
1
)
{
NGRAPH_DEBUG
<<
"Relu isn't the only user of BatchNorm's output"
;
return
false
;
}
auto
bn_relu
=
std
::
make_shared
<
op
::
BatchNormRelu
>
(
m_bn
->
get_eps_value
(),
pattern_map
[
gamma
],
pattern_map
[
beta
],
pattern_map
[
input
],
pattern_map
[
mean
],
pattern_map
[
var
],
m_bn
->
get_training_flag
());
ngraph
::
replace_node
(
m
.
match_root
(),
bn_relu
);
return
true
;
};
auto
m
=
std
::
make_shared
<
ngraph
::
pattern
::
Matcher
>
(
prelu
,
callback
);
this
->
add_matcher
(
m
);
}
void
ngraph
::
runtime
::
cpu
::
pass
::
CPUFusion
::
construct_conv_relu
()
void
ngraph
::
runtime
::
cpu
::
pass
::
CPUFusion
::
construct_conv_relu
()
{
{
Shape
shape
{
2
,
2
,
1
,
1
};
Shape
shape
{
2
,
2
,
1
,
1
};
...
...
src/ngraph/runtime/cpu/pass/cpu_fusion.hpp
View file @
66198b33
...
@@ -48,6 +48,7 @@ public:
...
@@ -48,6 +48,7 @@ public:
construct_sigmoid_bprop
();
construct_sigmoid_bprop
();
construct_conv_bias
();
construct_conv_bias
();
construct_batch_norm_relu
();
construct_batch_norm_relu
();
construct_batch_norm_relu_global_stats
();
construct_conv_relu
();
construct_conv_relu
();
}
}
...
@@ -62,5 +63,6 @@ private:
...
@@ -62,5 +63,6 @@ private:
void
construct_zero_padded_conv
();
void
construct_zero_padded_conv
();
void
construct_zero_padded_conv_backprop_filters
();
void
construct_zero_padded_conv_backprop_filters
();
void
construct_batch_norm_relu
();
void
construct_batch_norm_relu
();
void
construct_batch_norm_relu_global_stats
();
void
construct_conv_relu
();
void
construct_conv_relu
();
};
};
src/ngraph/runtime/cpu/pass/cpu_layout.cpp
View file @
66198b33
...
@@ -1104,17 +1104,30 @@ namespace ngraph
...
@@ -1104,17 +1104,30 @@ namespace ngraph
vector
<
memory
::
format
>
prim_input_formats
;
vector
<
memory
::
format
>
prim_input_formats
;
vector
<
memory
::
format
>
prim_output_formats
;
vector
<
memory
::
format
>
prim_output_formats
;
if
(
!
bn
->
get_training_flag
()
||
bn
->
get_inputs
().
size
()
!
=
3
)
if
(
bn
->
get_inputs
().
size
()
=
=
3
)
{
{
throw
ngraph_error
(
"Only training batchnorm should have been fused"
);
}
prim_input_formats
.
push_back
(
memory
::
format
::
x
);
prim_input_formats
.
push_back
(
memory
::
format
::
x
);
prim_input_formats
.
push_back
(
memory
::
format
::
x
);
prim_input_formats
.
push_back
(
memory
::
format
::
x
);
prim_input_formats
.
push_back
(
input_layout
);
prim_input_formats
.
push_back
(
input_layout
);
prim_output_formats
.
push_back
(
input_layout
);
prim_output_formats
.
push_back
(
input_layout
);
prim_output_formats
.
push_back
(
memory
::
format
::
x
);
prim_output_formats
.
push_back
(
memory
::
format
::
x
);
prim_output_formats
.
push_back
(
memory
::
format
::
x
);
prim_output_formats
.
push_back
(
memory
::
format
::
x
);
}
else
if
(
bn
->
get_inputs
().
size
()
==
5
)
{
prim_input_formats
.
push_back
(
memory
::
format
::
x
);
prim_input_formats
.
push_back
(
memory
::
format
::
x
);
prim_input_formats
.
push_back
(
input_layout
);
prim_input_formats
.
push_back
(
memory
::
format
::
x
);
prim_input_formats
.
push_back
(
memory
::
format
::
x
);
prim_output_formats
.
push_back
(
input_layout
);
}
else
{
throw
ngraph_error
(
"In CPU Layout: unknown number of inputs for BatchNormRelu "
+
to_string
(
bn
->
get_inputs
().
size
()));
}
node
=
node
=
insert_input_conversions
(
external_function
,
node
,
prim_input_formats
);
insert_input_conversions
(
external_function
,
node
,
prim_input_formats
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment