Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
47be91bd
Commit
47be91bd
authored
Jun 17, 2019
by
nishant.b.patel
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Change infer_convolution_forward method to just do shape checks
parent
351917d5
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
137 additions
and
116 deletions
+137
-116
convolution.cpp
src/ngraph/op/convolution.cpp
+51
-32
conv_fused.cpp
src/ngraph/op/fused/conv_fused.cpp
+34
-20
quantized_convolution.cpp
src/ngraph/op/quantized_convolution.cpp
+17
-19
deconv.cpp
src/ngraph/runtime/cpu/op/deconv.cpp
+18
-11
validation_util.cpp
src/ngraph/validation_util.cpp
+9
-23
validation_util.hpp
src/ngraph/validation_util.hpp
+8
-11
No files found.
src/ngraph/op/convolution.cpp
View file @
47be91bd
...
@@ -100,16 +100,23 @@ void op::Convolution::validate_and_infer_types()
...
@@ -100,16 +100,23 @@ void op::Convolution::validate_and_infer_types()
element
::
Type
result_et
;
element
::
Type
result_et
;
PartialShape
result_shape
;
PartialShape
result_shape
;
std
::
tie
(
result_et
,
result_shape
)
=
infer_convolution_forward
(
this
,
NODE_VALIDATION_CHECK
(
data_batch_et
,
this
,
filters_et
,
element
::
Type
::
merge
(
result_et
,
data_batch_et
,
filters_et
),
data_batch_shape
,
"Element types for data batch and filters do not match (data batch element type: "
,
m_data_dilation_strides
,
data_batch_et
,
m_padding_below
,
", filters element type: "
,
m_padding_above
,
filters_et
,
filters_shape
,
")."
);
m_window_movement_strides
,
m_window_dilation_strides
);
result_shape
=
infer_convolution_forward
(
this
,
data_batch_shape
,
m_data_dilation_strides
,
m_padding_below
,
m_padding_above
,
filters_shape
,
m_window_movement_strides
,
m_window_dilation_strides
);
set_output_type
(
0
,
result_et
,
result_shape
);
set_output_type
(
0
,
result_et
,
result_shape
);
}
}
...
@@ -255,17 +262,23 @@ void op::ConvolutionBackpropData::validate_and_infer_types()
...
@@ -255,17 +262,23 @@ void op::ConvolutionBackpropData::validate_and_infer_types()
element
::
Type
forward_result_et
;
element
::
Type
forward_result_et
;
PartialShape
forward_result_shape
;
PartialShape
forward_result_shape
;
std
::
tie
(
forward_result_et
,
forward_result_shape
)
=
NODE_VALIDATION_CHECK
(
infer_convolution_forward
(
this
,
this
,
delta_et
,
element
::
Type
::
merge
(
forward_result_et
,
delta_et
,
filters_et
),
filters_et
,
"Element types for data batch and filters do not match (data batch element type: "
,
m_data_batch_shape
,
delta_et
,
m_data_dilation_strides_forward
,
", filters element type: "
,
m_padding_below_forward
,
filters_et
,
m_padding_above_forward
,
")."
);
filters_shape
,
m_window_movement_strides_forward
,
forward_result_shape
=
infer_convolution_forward
(
this
,
m_window_dilation_strides_forward
);
m_data_batch_shape
,
m_data_dilation_strides_forward
,
m_padding_below_forward
,
m_padding_above_forward
,
filters_shape
,
m_window_movement_strides_forward
,
m_window_dilation_strides_forward
);
NODE_VALIDATION_CHECK
(
this
,
NODE_VALIDATION_CHECK
(
this
,
forward_result_shape
.
compatible
(
delta_shape
),
forward_result_shape
.
compatible
(
delta_shape
),
...
@@ -481,17 +494,23 @@ void op::ConvolutionBackpropFilters::validate_and_infer_types()
...
@@ -481,17 +494,23 @@ void op::ConvolutionBackpropFilters::validate_and_infer_types()
element
::
Type
forward_result_et
;
element
::
Type
forward_result_et
;
PartialShape
forward_result_shape
;
PartialShape
forward_result_shape
;
std
::
tie
(
forward_result_et
,
forward_result_shape
)
=
NODE_VALIDATION_CHECK
(
infer_convolution_forward
(
this
,
this
,
data_batch_et
,
element
::
Type
::
merge
(
forward_result_et
,
data_batch_et
,
delta_et
),
delta_et
,
"Element types for data batch and filters do not match (data batch element type: "
,
data_batch_shape
,
data_batch_et
,
m_data_dilation_strides_forward
,
", filters element type: "
,
m_padding_below_forward
,
delta_et
,
m_padding_above_forward
,
")."
);
m_filters_shape
,
m_window_movement_strides_forward
,
forward_result_shape
=
infer_convolution_forward
(
this
,
m_window_dilation_strides_forward
);
data_batch_shape
,
m_data_dilation_strides_forward
,
m_padding_below_forward
,
m_padding_above_forward
,
m_filters_shape
,
m_window_movement_strides_forward
,
m_window_dilation_strides_forward
);
NODE_VALIDATION_CHECK
(
this
,
NODE_VALIDATION_CHECK
(
this
,
forward_result_shape
.
compatible
(
delta_shape
),
forward_result_shape
.
compatible
(
delta_shape
),
...
...
src/ngraph/op/fused/conv_fused.cpp
View file @
47be91bd
...
@@ -159,16 +159,23 @@ void op::ConvolutionBias::validate_and_infer_types()
...
@@ -159,16 +159,23 @@ void op::ConvolutionBias::validate_and_infer_types()
element
::
Type
result_et
;
element
::
Type
result_et
;
PartialShape
result_shape
;
PartialShape
result_shape
;
std
::
tie
(
result_et
,
result_shape
)
=
infer_convolution_forward
(
this
,
NODE_VALIDATION_CHECK
(
data_batch_et
,
this
,
filters_et
,
element
::
Type
::
merge
(
result_et
,
data_batch_et
,
filters_et
),
data_batch_shape
,
"Element types for data batch and filters do not match (data batch element type: "
,
m_data_dilation_strides
,
data_batch_et
,
m_padding_below
,
", filters element type: "
,
m_padding_above
,
filters_et
,
filters_shape
,
")."
);
m_window_movement_strides
,
m_window_dilation_strides
);
result_shape
=
infer_convolution_forward
(
this
,
data_batch_shape
,
m_data_dilation_strides
,
m_padding_below
,
m_padding_above
,
filters_shape
,
m_window_movement_strides
,
m_window_dilation_strides
);
set_output_type
(
0
,
result_et
,
result_shape
);
set_output_type
(
0
,
result_et
,
result_shape
);
}
}
...
@@ -407,16 +414,23 @@ void op::ConvolutionBiasAdd::validate_and_infer_types()
...
@@ -407,16 +414,23 @@ void op::ConvolutionBiasAdd::validate_and_infer_types()
element
::
Type
result_et
;
element
::
Type
result_et
;
PartialShape
result_shape
;
PartialShape
result_shape
;
std
::
tie
(
result_et
,
result_shape
)
=
infer_convolution_forward
(
this
,
NODE_VALIDATION_CHECK
(
data_batch_et
,
this
,
filters_et
,
element
::
Type
::
merge
(
result_et
,
data_batch_et
,
filters_et
),
data_batch_shape
,
"Element types for data batch and filters do not match (data batch element type: "
,
m_data_dilation_strides
,
data_batch_et
,
m_padding_below
,
", filters element type: "
,
m_padding_above
,
filters_et
,
filters_shape
,
")."
);
m_window_movement_strides
,
m_window_dilation_strides
);
result_shape
=
infer_convolution_forward
(
this
,
data_batch_shape
,
m_data_dilation_strides
,
m_padding_below
,
m_padding_above
,
filters_shape
,
m_window_movement_strides
,
m_window_dilation_strides
);
// TODO: Check result_shape is compatible with add_input
// TODO: Check result_shape is compatible with add_input
set_output_type
(
0
,
result_et
,
result_shape
);
set_output_type
(
0
,
result_et
,
result_shape
);
}
}
...
...
src/ngraph/op/quantized_convolution.cpp
View file @
47be91bd
...
@@ -118,8 +118,11 @@ void op::QuantizedConvolution::validate_and_infer_types()
...
@@ -118,8 +118,11 @@ void op::QuantizedConvolution::validate_and_infer_types()
shape_size
(
get_input_shape
(
7
))
==
1
,
shape_size
(
get_input_shape
(
7
))
==
1
,
"Output scale and output zero point shape must be same and 1"
);
"Output scale and output zero point shape must be same and 1"
);
auto
input_shape
=
get_input_shape
(
0
);
// auto input_shape = get_input_shape(0);
auto
filters_shape
=
get_input_shape
(
1
);
// auto filters_shape = get_input_shape(1);
const
PartialShape
&
input_shape
=
get_input_partial_shape
(
0
);
const
PartialShape
&
filters_shape
=
get_input_partial_shape
(
1
);
if
(
m_data_dilation_strides
.
size
()
==
0
)
if
(
m_data_dilation_strides
.
size
()
==
0
)
{
{
...
@@ -146,23 +149,16 @@ void op::QuantizedConvolution::validate_and_infer_types()
...
@@ -146,23 +149,16 @@ void op::QuantizedConvolution::validate_and_infer_types()
m_padding_above
=
conv_default_padding
(
this
,
input_shape
,
filters_shape
);
m_padding_above
=
conv_default_padding
(
this
,
input_shape
,
filters_shape
);
}
}
set_output_type
(
0
,
PartialShape
result_shape
;
m_output_type
,
util
::
infer_convolution_output_shape
(
this
,
result_shape
=
infer_convolution_forward
(
this
,
input_shape
,
input_shape
,
filters_shape
,
m_data_dilation_strides
,
m_window_movement_strides
,
m_padding_below
,
m_window_dilation_strides
,
m_padding_above
,
m_padding_below
,
filters_shape
,
m_padding_above
,
m_window_movement_strides
,
m_data_dilation_strides
,
m_window_dilation_strides
);
0
,
/* batch_axis_data, */
1
,
/* input_channel_axis_data, */
1
,
/* input_channel_axis_filters, */
0
,
/* output_channel_axis_filters, */
0
,
/* batch_axis_result, */
1
/* output_channel_axis_result, */
));
NODE_VALIDATION_CHECK
(
NODE_VALIDATION_CHECK
(
this
,
this
,
...
@@ -172,6 +168,8 @@ void op::QuantizedConvolution::validate_and_infer_types()
...
@@ -172,6 +168,8 @@ void op::QuantizedConvolution::validate_and_infer_types()
") must match output element type ("
,
") must match output element type ("
,
get_output_element_type
(
0
),
get_output_element_type
(
0
),
")"
);
")"
);
set_output_type
(
0
,
m_output_type
,
result_shape
);
}
}
shared_ptr
<
Node
>
op
::
QuantizedConvolution
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
QuantizedConvolution
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
...
...
src/ngraph/runtime/cpu/op/deconv.cpp
View file @
47be91bd
...
@@ -96,17 +96,24 @@ void op::DeconvolutionBias::validate_and_infer_types()
...
@@ -96,17 +96,24 @@ void op::DeconvolutionBias::validate_and_infer_types()
const
PartialShape
&
fwd_filters_shape
{
const
PartialShape
&
fwd_filters_shape
{
filters_shape
[
1
],
filters_shape
[
0
],
filters_shape
[
2
],
filters_shape
[
3
]};
filters_shape
[
1
],
filters_shape
[
0
],
filters_shape
[
2
],
filters_shape
[
3
]};
std
::
tie
(
forward_result_et
,
forward_result_shape
)
=
infer_convolution_forward
(
this
,
NODE_VALIDATION_CHECK
(
delta_et
,
this
,
filters_et
,
element
::
Type
::
merge
(
forward_result_et
,
delta_et
,
filters_et
),
m_data_batch_shape
,
"Element types for data batch and filters do not match (data batch element type: "
,
m_data_dilation_strides_forward
,
delta_et
,
m_padding_below_forward
,
", filters element type: "
,
m_padding_above_forward
,
filters_et
,
fwd_filters_shape
,
")."
);
m_window_movement_strides_forward
,
m_window_dilation_strides_forward
);
forward_result_shape
=
infer_convolution_forward
(
this
,
m_data_batch_shape
,
m_data_dilation_strides_forward
,
m_padding_below_forward
,
m_padding_above_forward
,
fwd_filters_shape
,
m_window_movement_strides_forward
,
m_window_dilation_strides_forward
);
NGRAPH_DEBUG
<<
"
\t
partial filter_shape: "
<<
filters_shape
<<
"delta_shape: "
<<
delta_shape
NGRAPH_DEBUG
<<
"
\t
partial filter_shape: "
<<
filters_shape
<<
"delta_shape: "
<<
delta_shape
<<
", inferred_res_shape: "
<<
forward_result_shape
<<
endl
;
<<
", inferred_res_shape: "
<<
forward_result_shape
<<
endl
;
...
...
src/ngraph/validation_util.cpp
View file @
47be91bd
...
@@ -211,29 +211,15 @@ PartialShape ngraph::infer_windowed_reduction_output_shape(const Node* node,
...
@@ -211,29 +211,15 @@ PartialShape ngraph::infer_windowed_reduction_output_shape(const Node* node,
//
//
// Infers the output batch shape and element type for convolution fprop.
// Infers the output batch shape and element type for convolution fprop.
//
//
std
::
tuple
<
element
::
Type
,
PartialShape
>
PartialShape
ngraph
::
infer_convolution_forward
(
const
Node
*
node
,
ngraph
::
infer_convolution_forward
(
const
Node
*
node
,
const
PartialShape
&
data_batch_shape
,
element
::
Type
et_batch
,
const
Strides
&
data_dilation
,
element
::
Type
et_filters
,
const
CoordinateDiff
&
data_padding_below
,
const
PartialShape
&
data_batch_shape
,
const
CoordinateDiff
&
data_padding_above
,
const
Strides
&
data_dilation
,
const
PartialShape
&
filters_shape
,
const
CoordinateDiff
&
data_padding_below
,
const
Strides
&
filter_strides
,
const
CoordinateDiff
&
data_padding_above
,
const
Strides
&
filter_dilation
)
const
PartialShape
&
filters_shape
,
const
Strides
&
filter_strides
,
const
Strides
&
filter_dilation
)
{
{
element
::
Type
et_result
;
NODE_VALIDATION_CHECK
(
node
,
element
::
Type
::
merge
(
et_result
,
et_batch
,
et_filters
),
"Element types for data batch and filters do not match (data batch element type: "
,
et_batch
,
", filters element type: "
,
et_filters
,
")."
);
Rank
data_batch_filters_rank
{
Rank
::
dynamic
()};
Rank
data_batch_filters_rank
{
Rank
::
dynamic
()};
NODE_VALIDATION_CHECK
(
NODE_VALIDATION_CHECK
(
...
@@ -358,7 +344,7 @@ std::tuple<element::Type, PartialShape>
...
@@ -358,7 +344,7 @@ std::tuple<element::Type, PartialShape>
batch_output_shape
[
i
+
2
]
=
data_output_shape
[
i
];
batch_output_shape
[
i
+
2
]
=
data_output_shape
[
i
];
}
}
return
std
::
make_tuple
(
et_result
,
batch_output_shape
)
;
return
batch_output_shape
;
}
}
//
//
...
...
src/ngraph/validation_util.hpp
View file @
47be91bd
...
@@ -43,17 +43,14 @@ namespace ngraph
...
@@ -43,17 +43,14 @@ namespace ngraph
const
Strides
&
window_dilation
,
const
Strides
&
window_dilation
,
bool
is_window_all_in_padding_allowed
);
bool
is_window_all_in_padding_allowed
);
std
::
tuple
<
element
::
Type
,
PartialShape
>
PartialShape
infer_convolution_forward
(
const
Node
*
node
,
infer_convolution_forward
(
const
Node
*
node
,
const
PartialShape
&
data_batch_shape
,
element
::
Type
et_batch
,
const
Strides
&
data_dilation
,
element
::
Type
et_filters
,
const
CoordinateDiff
&
data_padding_below
,
const
PartialShape
&
data_batch_shape
,
const
CoordinateDiff
&
data_padding_above
,
const
Strides
&
data_dilation
,
const
PartialShape
&
filters_shape
,
const
CoordinateDiff
&
data_padding_below
,
const
Strides
&
filter_strides
,
const
CoordinateDiff
&
data_padding_above
,
const
Strides
&
filter_dilation
);
const
PartialShape
&
filters_shape
,
const
Strides
&
filter_strides
,
const
Strides
&
filter_dilation
);
PartialShape
infer_batched_pooling_forward
(
const
Node
*
node
,
PartialShape
infer_batched_pooling_forward
(
const
Node
*
node
,
const
PartialShape
&
data_batch_shape
,
const
PartialShape
&
data_batch_shape
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment