Project: submodule / ngraph
Commit 066037c2
Authored Mar 05, 2019 by Adam Procter; committed by Scott Cyphers, Mar 05, 2019
Switch everything from NODE_VALIDATION_ASSERT to NODE_VALIDATION_CHECK (#2546)
Parent: dd23b0cb
Showing 36 changed files with 911 additions and 468 deletions (+911, -468).
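The change is the same mechanical rewrite at every call site: the streaming-style NODE_VALIDATION_ASSERT, which built its message by chaining fragments with operator<<, is replaced by the variadic NODE_VALIDATION_CHECK, which takes the message fragments as trailing arguments. A minimal sketch of the pattern, assuming a hypothetical condition and variable name (arg_rank is a placeholder, not taken from any file in this commit):

    // Before: streaming-style assertion, message appended with operator<<.
    NODE_VALIDATION_ASSERT(this, arg_rank >= 3)
        << "Argument rank is too small (rank: " << arg_rank << ").";

    // After: variadic check, message fragments passed as extra arguments.
    NODE_VALIDATION_CHECK(this,
                          arg_rank >= 3,
                          "Argument rank is too small (rank: ",
                          arg_rank,
                          ").");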
Changed files:

src/ngraph/node.hpp (+1, -21)
src/ngraph/op/allreduce.cpp (+7, -6)
src/ngraph/op/avg_pool.cpp (+9, -3)
src/ngraph/op/batch_norm.cpp (+17, -12)
src/ngraph/op/broadcast.cpp (+21, -7)
src/ngraph/op/concat.cpp (+23, -13)
src/ngraph/op/constant.hpp (+21, -9)
src/ngraph/op/convolution.cpp (+0, -0)
src/ngraph/op/dequantize.cpp (+60, -31)
src/ngraph/op/dot.cpp (+43, -22)
src/ngraph/op/embedding_lookup.cpp (+4, -3)
src/ngraph/op/experimental/generate_mask.cpp (+5, -4)
src/ngraph/op/experimental/quantized_avg_pool.cpp (+90, -46)
src/ngraph/op/experimental/quantized_concat.cpp (+22, -13)
src/ngraph/op/experimental/quantized_max_pool.cpp (+71, -34)
src/ngraph/op/get_output_element.cpp (+7, -3)
src/ngraph/op/lrn.cpp (+6, -3)
src/ngraph/op/max_pool.cpp (+16, -6)
src/ngraph/op/one_hot.cpp (+24, -11)
src/ngraph/op/pad.cpp (+37, -19)
src/ngraph/op/quantize.cpp (+60, -31)
src/ngraph/op/replace_slice.cpp (+70, -31)
src/ngraph/op/reshape.cpp (+34, -15)
src/ngraph/op/result.cpp (+2, -2)
src/ngraph/op/reverse.cpp (+7, -3)
src/ngraph/op/reverse_sequence.cpp (+32, -14)
src/ngraph/op/select.cpp (+16, -12)
src/ngraph/op/slice.cpp (+47, -19)
src/ngraph/op/softmax.cpp (+7, -3)
src/ngraph/op/topk.cpp (+26, -16)
src/ngraph/op/util/arithmetic_reduction.cpp (+10, -4)
src/ngraph/op/util/index_reduction.cpp (+12, -7)
src/ngraph/op/util/logical_reduction.cpp (+13, -6)
src/ngraph/runtime/cpu/op/conv_add.cpp (+21, -8)
src/ngraph/runtime/cpu/op/update_slice.cpp (+70, -31)
test/type_prop.cpp (+0, -0)
src/ngraph/node.hpp

@@ -275,19 +275,6 @@ namespace ngraph
         size_t m_placement_index = placement_invalid;
     };

-    class NodeValidationError : public AssertionFailure
-    {
-    public:
-        NodeValidationError(std::string what)
-            : AssertionFailure(what)
-        {
-        }
-        NodeValidationError(const char* what)
-            : AssertionFailure(what)
-        {
-        }
-    };
-
     class NodeValidationFailure : public CheckFailure
     {
     public:
@@ -321,12 +308,5 @@ namespace ngraph
     void check_new_args_count(const Node* node, const NodeVector& new_args);
 } // namespace ngraph

-#define NODE_VALIDATION_ASSERT(node, cond)                                                \
-    NGRAPH_ASSERT_STREAM_WITH_LOC(                                                        \
-        ::ngraph::NodeValidationError, cond, ::ngraph::node_validation_assertion_string(node))
-#define NODE_VALIDATION_FAIL(node)                                                        \
-    NGRAPH_FAIL_STREAM_WITH_LOC(::ngraph::NodeValidationError,                            \
-                                ::ngraph::node_validation_assertion_string(node))
 #define NODE_VALIDATION_CHECK(node, cond, ...)                                            \
-    NGRAPH_CHECK(::NodeValidationFailure, (node), (cond), __VA_ARGS__)
+    NGRAPH_CHECK(::ngraph::NodeValidationFailure, (node), (cond), __VA_ARGS__)
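The reason the call sites below switch from << chains to comma-separated arguments is that a variadic check macro can fold its trailing arguments into a single message only when they arrive as arguments. The following is an illustrative sketch only, not nGraph's actual NGRAPH_CHECK implementation; the function name check_with_message is invented for the example:

    // Illustrative sketch: fold variadic message fragments into one string.
    #include <initializer_list>
    #include <sstream>
    #include <stdexcept>
    #include <string>

    template <typename... Args>
    void check_with_message(bool condition, const Args&... args)
    {
        if (!condition)
        {
            std::ostringstream ss;
            // Stream each argument in order (C++11 initializer-list expansion).
            (void)std::initializer_list<int>{(ss << args, 0)...};
            throw std::runtime_error(ss.str());
        }
    }

    // Usage: check_with_message(rank >= 3, "Rank too small (rank: ", rank, ").");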
src/ngraph/op/allreduce.cpp

@@ -27,12 +27,13 @@ op::AllReduce::AllReduce(const shared_ptr<Node>& arg)
 void op::AllReduce::validate_and_infer_types()
 {
-    NODE_VALIDATION_ASSERT(this,
-                           get_input_element_type(0).is_dynamic() ||
-                               get_input_element_type(0) == element::f32 ||
-                               get_input_element_type(0) == element::f64)
-        << "Only element types f32 and f64 are supported (argument element type: "
-        << get_input_element_type(0) << ").";
+    NODE_VALIDATION_CHECK(this,
+                          get_input_element_type(0).is_dynamic() ||
+                              get_input_element_type(0) == element::f32 ||
+                              get_input_element_type(0) == element::f64,
+                          "Only element types f32 and f64 are supported (argument element type: ",
+                          get_input_element_type(0),
+                          ").");

     set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
 }
src/ngraph/op/avg_pool.cpp

@@ -132,9 +132,15 @@ void op::AvgPoolBackprop::validate_and_infer_types()
     const PartialShape& delta_shape = get_input_partial_shape(0);
-    NODE_VALIDATION_ASSERT(this, forward_result_shape.compatible(delta_shape))
-        << "Inferred forward output shape does not match delta shape (inferred forward output "
-        << "shape: " << forward_result_shape << ", delta shape: " << delta_shape << ").";
+    NODE_VALIDATION_CHECK(
+        this,
+        forward_result_shape.compatible(delta_shape),
+        "Inferred forward output shape does not match delta shape (inferred forward output ",
+        "shape: ", forward_result_shape, ", delta shape: ", delta_shape, ").");

     // TODO(amprocte): Once m_forward_arg_shape is allowed to be dynamic, we may technically be
     // able to infer some extra information from forward_result_shape that was not present in the
src/ngraph/op/batch_norm.cpp

@@ -205,21 +205,26 @@ void ngraph::op::BatchNormTrainingBackprop::validate_and_infer_types()
 {
     PartialShape input_and_delta_shape{get_input_partial_shape(INPUT_DATA)};

-    NODE_VALIDATION_ASSERT(
-        this, PartialShape::merge_into(input_and_delta_shape, get_input_partial_shape(INPUT_DELTA)))
-        << "Shape of delta does not match the shape of the input data (input data shape: "
-        << get_input_partial_shape(INPUT_DATA) << ", delta shape: "
-        << get_input_partial_shape(INPUT_DELTA) << ").";
+    NODE_VALIDATION_CHECK(
+        this,
+        PartialShape::merge_into(input_and_delta_shape, get_input_partial_shape(INPUT_DELTA)),
+        "Shape of delta does not match the shape of the input data (input data shape: ",
+        get_input_partial_shape(INPUT_DATA),
+        ", delta shape: ", get_input_partial_shape(INPUT_DELTA), ").");

     element::Type input_and_delta_et;

-    NODE_VALIDATION_ASSERT(this,
-                           element::Type::merge(input_and_delta_et,
-                                                get_input_element_type(INPUT_DATA),
-                                                get_input_element_type(INPUT_DELTA)))
-        << "Element type for input (" << get_input_element_type(INPUT_DATA)
-        << ") does not match element type for delta (" << get_input_element_type(INPUT_DATA)
-        << ").";
+    NODE_VALIDATION_CHECK(this,
+                          element::Type::merge(input_and_delta_et,
+                                               get_input_element_type(INPUT_DATA),
+                                               get_input_element_type(INPUT_DELTA)),
+                          "Element type for input (", get_input_element_type(INPUT_DATA),
+                          ") does not match element type for delta (",
+                          get_input_element_type(INPUT_DATA), ").");

     element::Type result_et;
     PartialShape result_batch_shape;
src/ngraph/op/broadcast.cpp

@@ -44,9 +44,16 @@ void op::Broadcast::validate_and_infer_types()
     for (auto axis : m_broadcast_axes)
     {
-        NODE_VALIDATION_ASSERT(this, axis < m_shape.size())
-            << "Broadcast axis index (" << axis << ") exceeds specified output shape rank "
-            << "(broadcast axes: " << m_broadcast_axes << ", output shape: " << m_shape << ").";
+        NODE_VALIDATION_CHECK(this,
+                              axis < m_shape.size(),
+                              "Broadcast axis index (", axis,
+                              ") exceeds specified output shape rank ",
+                              "(broadcast axes: ", m_broadcast_axes,
+                              ", output shape: ", m_shape, ").");
     }

     Shape required_input_shape = m_shape;
@@ -59,10 +66,17 @@ void op::Broadcast::validate_and_infer_types()
     // There are two things that can go wrong, which are being picked up in
     // one fell swoop by this check: either the number of broadcast axes is not
     // enough, or there is a mismatch with one of the pre-broadcast axis lengths.
-    NODE_VALIDATION_ASSERT(this, get_input_partial_shape(0).compatible(required_input_shape))
-        << "Broadcast argument shape, specified output shape, and axes are incompatible "
-        << "(argument shape: " << get_input_partial_shape(0) << ", output shape: " << m_shape
-        << ", broadcast axes: " << m_broadcast_axes << ").";
+    NODE_VALIDATION_CHECK(
+        this,
+        get_input_partial_shape(0).compatible(required_input_shape),
+        "Broadcast argument shape, specified output shape, and axes are incompatible ",
+        "(argument shape: ", get_input_partial_shape(0),
+        ", output shape: ", m_shape,
+        ", broadcast axes: ", m_broadcast_axes, ").");

     set_output_type(0, get_input_element_type(0), m_shape);
 }
src/ngraph/op/concat.cpp

@@ -32,7 +32,7 @@ op::Concat::Concat(const NodeVector& args, size_t concatenation_axis)
 void op::Concat::validate_and_infer_types()
 {
-    NODE_VALIDATION_ASSERT(this, m_inputs.size() >= 1) << "At least one argument required.";
+    NODE_VALIDATION_CHECK(this, m_inputs.size() >= 1, "At least one argument required.");

     PartialShape inputs_shape_scheme{PartialShape::dynamic()};
     element::Type inputs_et{element::dynamic};
@@ -44,22 +44,32 @@ void op::Concat::validate_and_infer_types()
         Dimension this_input_rank = this_input_shape.rank();
         if (this_input_rank.is_static())
         {
-            NODE_VALIDATION_ASSERT(this, m_concatenation_axis < size_t(this_input_rank))
-                << "Concatenation axis (" << m_concatenation_axis << ") is out of bounds for "
-                << "argument " << i << ", which has shape " << this_input_shape << ".";
+            NODE_VALIDATION_CHECK(this,
+                                  m_concatenation_axis < size_t(this_input_rank),
+                                  "Concatenation axis (", m_concatenation_axis,
+                                  ") is out of bounds for ",
+                                  "argument ", i,
+                                  ", which has shape ", this_input_shape, ".");

             concatenation_axis_output_dim += this_input_shape[m_concatenation_axis];
             this_input_shape[m_concatenation_axis] = Dimension::dynamic();
-            NODE_VALIDATION_ASSERT(this,
-                                   PartialShape::merge_into(inputs_shape_scheme, this_input_shape))
-                << "Argument shapes are inconsistent; they must have the same rank, and must have "
-                << "equal dimension everywhere except on the concatenation axis (axis "
-                << m_concatenation_axis << ").";
-            NODE_VALIDATION_ASSERT(
-                this, element::Type::merge(inputs_et, inputs_et, get_input_element_type(i)))
-                << "Argument element types are inconsistent.";
+            NODE_VALIDATION_CHECK(
+                this,
+                PartialShape::merge_into(inputs_shape_scheme, this_input_shape),
+                "Argument shapes are inconsistent; they must have the same rank, and must have ",
+                "equal dimension everywhere except on the concatenation axis (axis ",
+                m_concatenation_axis, ").");
+            NODE_VALIDATION_CHECK(
+                this,
+                element::Type::merge(inputs_et, inputs_et, get_input_element_type(i)),
+                "Argument element types are inconsistent.");
         }
         else
         {
src/ngraph/op/constant.hpp

@@ -47,11 +47,17 @@ namespace ngraph
             , m_data(ngraph::aligned_alloc(m_element_type.size(),
                                            shape_size(m_shape) * m_element_type.size()))
         {
-            NODE_VALIDATION_ASSERT(this,
-                                   values.size() == 1 || values.size() == shape_size(m_shape))
-                << "Did not get the expected number of literals for a constant of shape "
-                << m_shape << " (got " << values.size() << ", expected "
-                << (shape_size(m_shape) == 1 ? "" : "1 or ") << shape_size(m_shape) << ").";
+            NODE_VALIDATION_CHECK(this,
+                                  values.size() == 1 || values.size() == shape_size(m_shape),
+                                  "Did not get the expected number of literals for a constant of shape ",
+                                  m_shape,
+                                  " (got ", values.size(),
+                                  ", expected ",
+                                  (shape_size(m_shape) == 1 ? "" : "1 or "),
+                                  shape_size(m_shape), ").");

             if (values.size() == 1)
             {
@@ -77,10 +83,16 @@ namespace ngraph
             , m_data(ngraph::aligned_alloc(m_element_type.size(),
                                            shape_size(m_shape) * m_element_type.size()))
         {
-            NODE_VALIDATION_ASSERT(this, values.size() == shape_size(m_shape))
-                << "Did not get the expected number of literals for a constant of shape "
-                << m_shape << " (got " << values.size() << ", expected " << shape_size(m_shape)
-                << ".";
+            NODE_VALIDATION_CHECK(this,
+                                  values.size() == shape_size(m_shape),
+                                  "Did not get the expected number of literals for a constant of shape ",
+                                  m_shape,
+                                  " (got ", values.size(),
+                                  ", expected ", shape_size(m_shape), ".");

             std::vector<double> dvalues = parse_string<double>(values);
             write_values(dvalues);
src/ngraph/op/convolution.cpp

(This diff is collapsed and not shown.)
src/ngraph/op/dequantize.cpp

@@ -42,50 +42,73 @@ void op::Dequantize::validate_and_infer_types()
         OFFSET
     };

-    NODE_VALIDATION_ASSERT(this, m_type.is_static()) << "Output element type must not be dynamic";
+    NODE_VALIDATION_CHECK(this, m_type.is_static(), "Output element type must not be dynamic");

-    NODE_VALIDATION_ASSERT(this, m_type.is_real()) << "Output element type (" << m_type
-                                                   << ") must be a floating point type";
+    NODE_VALIDATION_CHECK(
+        this, m_type.is_real(), "Output element type (", m_type, ") must be a floating point type");

     element::Type quantized_type;

-    NODE_VALIDATION_ASSERT(this,
-                           element::Type::merge(quantized_type,
-                                                get_input_element_type(INPUT),
-                                                get_input_element_type(OFFSET)))
-        << "Offset element type (" << get_input_element_type(OFFSET)
-        << ") must match input element type (" << get_input_element_type(INPUT) << ")";
-    NODE_VALIDATION_ASSERT(this, quantized_type.is_dynamic() || quantized_type.is_quantized())
-        << "Offset/input element type (" << quantized_type << ") must be a quantized type";
+    NODE_VALIDATION_CHECK(this,
+                          element::Type::merge(quantized_type,
+                                               get_input_element_type(INPUT),
+                                               get_input_element_type(OFFSET)),
+                          "Offset element type (", get_input_element_type(OFFSET),
+                          ") must match input element type (", get_input_element_type(INPUT), ")");
+    NODE_VALIDATION_CHECK(this,
+                          quantized_type.is_dynamic() || quantized_type.is_quantized(),
+                          "Offset/input element type (", quantized_type,
+                          ") must be a quantized type");

     element::Type unquantized_type;

-    NODE_VALIDATION_ASSERT(
-        this, element::Type::merge(unquantized_type, get_input_element_type(SCALE), m_type))
-        << "Scale element type (" << get_input_element_type(SCALE)
-        << ") must match output element type (" << m_type << ")";
+    NODE_VALIDATION_CHECK(
+        this,
+        element::Type::merge(unquantized_type, get_input_element_type(SCALE), m_type),
+        "Scale element type (", get_input_element_type(SCALE),
+        ") must match output element type (", m_type, ")");

     PartialShape input_shape = get_input_partial_shape(0);
     Dimension input_rank = input_shape.rank();

     for (auto axis : m_axes)
     {
-        NODE_VALIDATION_ASSERT(this, input_rank.is_dynamic() || axis < size_t(input_rank))
-            << "Quantization axis (" << axis << ") must be less than input shape rank ("
-            << input_rank << ")";
+        NODE_VALIDATION_CHECK(this,
+                              input_rank.is_dynamic() || axis < size_t(input_rank),
+                              "Quantization axis (", axis,
+                              ") must be less than input shape rank (", input_rank, ")");
     }

     PartialShape scale_offset_shape = get_input_partial_shape(SCALE);

-    NODE_VALIDATION_ASSERT(
-        this, PartialShape::merge_into(scale_offset_shape, get_input_partial_shape(OFFSET)))
-        << "Scale shape (" << get_input_partial_shape(SCALE) << ") and offset shape ("
-        << get_input_partial_shape(OFFSET) << ") must match";
-    NODE_VALIDATION_ASSERT(this, scale_offset_shape.rank().compatible(m_axes.size()))
-        << "Scale/offset rank (" << scale_offset_shape.rank() << ") does not match the number of "
-        << "quantization axes (" << m_axes.size() << ")";
+    NODE_VALIDATION_CHECK(
+        this,
+        PartialShape::merge_into(scale_offset_shape, get_input_partial_shape(OFFSET)),
+        "Scale shape (", get_input_partial_shape(SCALE),
+        ") and offset shape (", get_input_partial_shape(OFFSET), ") must match");
+    NODE_VALIDATION_CHECK(this,
+                          scale_offset_shape.rank().compatible(m_axes.size()),
+                          "Scale/offset rank (", scale_offset_shape.rank(),
+                          ") does not match the number of ",
+                          "quantization axes (", m_axes.size(), ")");

     set_output_size(1);
@@ -108,10 +131,16 @@ void op::Dequantize::validate_and_infer_types()
         }

         PartialShape result_shape = input_shape;
-        NODE_VALIDATION_ASSERT(
-            this, PartialShape::merge_into(result_shape, PartialShape{injected_scale_offset_dims}))
-            << "Scale/offset shape (" << scale_offset_shape << ") must match input shape ("
-            << input_shape << ") at the quantization axes (" << m_axes << ")";
+        NODE_VALIDATION_CHECK(
+            this,
+            PartialShape::merge_into(result_shape, PartialShape{injected_scale_offset_dims}),
+            "Scale/offset shape (", scale_offset_shape,
+            ") must match input shape (", input_shape,
+            ") at the quantization axes (", m_axes, ")");
         set_output_type(0, unquantized_type, result_shape);
     }
     else
src/ngraph/op/dot.cpp

@@ -49,11 +49,14 @@ void op::Dot::validate_and_infer_types()
 {
     element::Type result_et;

-    NODE_VALIDATION_ASSERT(
-        this, element::Type::merge(result_et, get_input_element_type(0), get_input_element_type(1)))
-        << "Arguments do not have the same element type (arg0 element type: "
-        << get_input_element_type(0) << ", arg1 element type: " << get_input_element_type(1)
-        << ").";
+    NODE_VALIDATION_CHECK(
+        this,
+        element::Type::merge(result_et, get_input_element_type(0), get_input_element_type(1)),
+        "Arguments do not have the same element type (arg0 element type: ",
+        get_input_element_type(0), ", arg1 element type: ", get_input_element_type(1), ").");

     const PartialShape& arg0_shape = get_input_partial_shape(0);
     const PartialShape& arg1_shape = get_input_partial_shape(1);
@@ -82,17 +85,27 @@ void op::Dot::validate_and_infer_types()
     PartialShape result_shape;

-    NODE_VALIDATION_ASSERT(this,
-                           reduction_axes_ambiguous || arg0_shape.rank().is_dynamic() ||
-                               m_reduction_axes_count <= size_t(arg0_shape.rank()))
-        << "Reduction axes count (" << m_reduction_axes_count
-        << ") is too large (arg0 shape: " << arg0_shape << ", arg1 shape: " << arg1_shape << ").";
-    NODE_VALIDATION_ASSERT(this,
-                           reduction_axes_ambiguous || arg1_shape.rank().is_dynamic() ||
-                               m_reduction_axes_count <= size_t(arg1_shape.rank()))
-        << "Reduction axes count (" << m_reduction_axes_count
-        << ") is too large (arg0 shape: " << arg0_shape << ", arg1 shape: " << arg1_shape << ").";
+    NODE_VALIDATION_CHECK(this,
+                          reduction_axes_ambiguous || arg0_shape.rank().is_dynamic() ||
+                              m_reduction_axes_count <= size_t(arg0_shape.rank()),
+                          "Reduction axes count (", m_reduction_axes_count,
+                          ") is too large (arg0 shape: ", arg0_shape,
+                          ", arg1 shape: ", arg1_shape, ").");
+    NODE_VALIDATION_CHECK(this,
+                          reduction_axes_ambiguous || arg1_shape.rank().is_dynamic() ||
+                              m_reduction_axes_count <= size_t(arg1_shape.rank()),
+                          "Reduction axes count (", m_reduction_axes_count,
+                          ") is too large (arg0 shape: ", arg0_shape,
+                          ", arg1 shape: ", arg1_shape, ").");

     if (!reduction_axes_ambiguous && arg0_shape.rank().is_static() && arg1_shape.rank().is_static())
     {
@@ -101,12 +114,20 @@ void op::Dot::validate_and_infer_types()
             size_t axis_index_arg0 = size_t(arg0_shape.rank()) - m_reduction_axes_count + i;
             size_t axis_index_arg1 = i;

-            NODE_VALIDATION_ASSERT(
-                this, arg0_shape[axis_index_arg0].compatible(arg1_shape[axis_index_arg1]))
-                << "Paired axes (axis " << axis_index_arg0 << " from arg0, axis " << axis_index_arg1
-                << " from arg1) do not have same length (arg0 shape: " << arg0_shape
-                << ", arg1 shape: " << arg1_shape
-                << ", reduction axes count: " << m_reduction_axes_count << ").";
+            NODE_VALIDATION_CHECK(
+                this,
+                arg0_shape[axis_index_arg0].compatible(arg1_shape[axis_index_arg1]),
+                "Paired axes (axis ", axis_index_arg0, " from arg0, axis ", axis_index_arg1,
+                " from arg1) do not have same length (arg0 shape: ", arg0_shape,
+                ", arg1 shape: ", arg1_shape,
+                ", reduction axes count: ", m_reduction_axes_count, ").");
         }

         std::vector<Dimension> result_dims(size_t(arg0_shape.rank()) + size_t(arg1_shape.rank()) -
src/ngraph/op/embedding_lookup.cpp

@@ -26,9 +26,10 @@ void op::EmbeddingLookup::validate_and_infer_types()
     const PartialShape& arg0_shape = get_input_partial_shape(0);
     const PartialShape& arg1_shape = get_input_partial_shape(1);
-    NODE_VALIDATION_ASSERT(this,
-                           arg1_shape.rank().is_dynamic() ||
-                               static_cast<size_t>(arg1_shape.rank()) == 2)
-        << "weights are expected to be a matrix";
+    NODE_VALIDATION_CHECK(this,
+                          arg1_shape.rank().is_dynamic() ||
+                              static_cast<size_t>(arg1_shape.rank()) == 2,
+                          "weights are expected to be a matrix");

     PartialShape result_shape;
     if (arg0_shape.rank().is_static())
src/ngraph/op/experimental/generate_mask.cpp

@@ -42,11 +42,12 @@ shared_ptr<Node> op::GenerateMask::copy_with_new_args(const NodeVector& new_args)
 void ngraph::op::GenerateMask::validate_and_infer_types()
 {
-    NODE_VALIDATION_ASSERT(this, get_input_partial_shape(0).compatible(PartialShape{}))
-        << "Training node should be a scalar flag indicating a mode";
+    NODE_VALIDATION_CHECK(this,
+                          get_input_partial_shape(0).compatible(PartialShape{}),
+                          "Training node should be a scalar flag indicating a mode");

-    NODE_VALIDATION_ASSERT(this, m_element_type.is_static())
-        << "Output element type must not be dynamic.";
+    NODE_VALIDATION_CHECK(
+        this, m_element_type.is_static(), "Output element type must not be dynamic.");

     set_output_type(0, m_element_type, m_shape);
 }
src/ngraph/op/experimental/quantized_avg_pool.cpp

@@ -65,36 +65,58 @@ void op::QuantizedAvgPool::validate_and_infer_types()
     // Make sure batch size and channel count are not zero, and that we have at least one spatial
     // dimension (in other words, that arg has shape NCDi for some Di of rank>0, N != 0, C != 0).
     //
-    NODE_VALIDATION_ASSERT(this, arg_shape.size() >= 3)
-        << "Data input shape does not have rank of at least 3 (data input shape: " << arg_shape
-        << ").";
+    NODE_VALIDATION_CHECK(this,
+                          arg_shape.size() >= 3,
+                          "Data input shape does not have rank of at least 3 (data input shape: ",
+                          arg_shape, ").");

     size_t batch_size = arg_shape[0];
-    NODE_VALIDATION_ASSERT(this, batch_size != 0)
-        << "Data batch size is zero (data input shape: " << arg_shape << ").";
+    NODE_VALIDATION_CHECK(
+        this, batch_size != 0, "Data batch size is zero (data input shape: ", arg_shape, ").");

     size_t channel_count = arg_shape[1];
-    NODE_VALIDATION_ASSERT(this, channel_count != 0)
-        << "Channel count is zero (data input shape: " << arg_shape << ").";
+    NODE_VALIDATION_CHECK(
+        this, channel_count != 0, "Channel count is zero (data input shape: ", arg_shape, ").");

     size_t spatial_dimension_count = arg_shape.size() - 2;

     //
     // Make sure window shape, window movement strides, and padding have same rank as Di.
     //
-    NODE_VALIDATION_ASSERT(this, m_window_shape.size() == spatial_dimension_count)
-        << "Window shape rank does not match number of spatial dimensions (window shape: "
-        << m_window_shape << ", data input shape: " << arg_shape << ").";
-    NODE_VALIDATION_ASSERT(this, m_window_movement_strides.size() == spatial_dimension_count)
-        << "Window movement stride rank does not match number of spatial dimensions (window "
-        << "movement strides: " << m_window_movement_strides << ", data input shape: " << arg_shape
-        << ").";
-    NODE_VALIDATION_ASSERT(this, m_padding_below.size() == spatial_dimension_count)
-        << "Below-padding rank does not match number of spatial dimensions (padding below: "
-        << m_padding_below << ", data input shape: " << arg_shape << ").";
-    NODE_VALIDATION_ASSERT(this, m_padding_above.size() == spatial_dimension_count)
-        << "Above-padding rank does not match number of spatial dimensions (padding above: "
-        << m_padding_above << ", data input shape: " << arg_shape << ").";
+    NODE_VALIDATION_CHECK(
+        this,
+        m_window_shape.size() == spatial_dimension_count,
+        "Window shape rank does not match number of spatial dimensions (window shape: ",
+        m_window_shape, ", data input shape: ", arg_shape, ").");
+    NODE_VALIDATION_CHECK(
+        this,
+        m_window_movement_strides.size() == spatial_dimension_count,
+        "Window movement stride rank does not match number of spatial dimensions (window ",
+        "movement strides: ", m_window_movement_strides, ", data input shape: ", arg_shape, ").");
+    NODE_VALIDATION_CHECK(
+        this,
+        m_padding_below.size() == spatial_dimension_count,
+        "Below-padding rank does not match number of spatial dimensions (padding below: ",
+        m_padding_below, ", data input shape: ", arg_shape, ").");
+    NODE_VALIDATION_CHECK(
+        this,
+        m_padding_above.size() == spatial_dimension_count,
+        "Above-padding rank does not match number of spatial dimensions (padding above: ",
+        m_padding_above, ", data input shape: ", arg_shape, ").");

     //
     // Extract input item shape Di and make sure all dimensions are larger than 0.
@@ -110,10 +132,13 @@ void op::QuantizedAvgPool::validate_and_infer_types()
     for (size_t i = 0; i < spatial_dimension_count; i++)
     {
-        NODE_VALIDATION_ASSERT(this, input_item_virtual_shape[i] != 0)
-            << "Data input spatial dimension " << i
-            << " has zero length even after padding (virtual shape of input item: "
-            << input_item_virtual_shape << ").";
+        NODE_VALIDATION_CHECK(this,
+                              input_item_virtual_shape[i] != 0,
+                              "Data input spatial dimension ", i,
+                              " has zero length even after padding (virtual shape of input item: ",
+                              input_item_virtual_shape, ").");
     }

     //
@@ -121,9 +146,13 @@ void op::QuantizedAvgPool::validate_and_infer_types()
     //
     for (size_t i = 0; i < spatial_dimension_count; i++)
     {
-        NODE_VALIDATION_ASSERT(this, m_window_shape[i] != 0)
-            << "Window shape dimension " << i
-            << " has zero length (window shape: " << m_window_shape << ").";
+        NODE_VALIDATION_CHECK(this,
+                              m_window_shape[i] != 0,
+                              "Window shape dimension ", i,
+                              " has zero length (window shape: ", m_window_shape, ").");
     }

     //
@@ -131,10 +160,14 @@ void op::QuantizedAvgPool::validate_and_infer_types()
     //
     for (size_t i = 0; i < spatial_dimension_count; i++)
     {
-        NODE_VALIDATION_ASSERT(this, m_window_shape[i] <= input_item_virtual_shape[i])
-            << "Window shape after padding is larger than the spatial dimensions (window shape: "
-            << m_window_shape << ", virtual shape of input item: " << input_item_virtual_shape
-            << ").";
+        NODE_VALIDATION_CHECK(
+            this,
+            m_window_shape[i] <= input_item_virtual_shape[i],
+            "Window shape after padding is larger than the spatial dimensions (window shape: ",
+            m_window_shape, ", virtual shape of input item: ", input_item_virtual_shape, ").");
     }

     //
     // Compute output item shape Do, checking at the same time that all window movement strides are larger than 0.
@@ -143,9 +176,13 @@ void op::QuantizedAvgPool::validate_and_infer_types()
     for (size_t i = 0; i < spatial_dimension_count; i++)
     {
-        NODE_VALIDATION_ASSERT(this, m_window_movement_strides[i] != 0)
-            << "Window movement strides dimension " << i
-            << " has zero length (window movement strides: " << m_window_movement_strides << ").";
+        NODE_VALIDATION_CHECK(this,
+                              m_window_movement_strides[i] != 0,
+                              "Window movement strides dimension ", i,
+                              " has zero length (window movement strides: ",
+                              m_window_movement_strides, ").");
         output_item_shape.push_back(ceil_div(input_item_virtual_shape[i] - m_window_shape[i] + 1,
                                              m_window_movement_strides[i]));
     }
@@ -167,11 +204,15 @@ void op::QuantizedAvgPool::validate_and_infer_types()
             // Checking the lower edge of each dimension is easy, because there's no mystery
             // regarding the window's lower-edge placement...
-            NODE_VALIDATION_ASSERT(this,
-                                   dim_padding_below == 0 || dim_window_size > dim_padding_below)
-                << "Window will sometimes reside entirely within the below-padding region, but"
-                << " include_padding_in_avg_computation was not set (padding below: "
-                << m_padding_below << ", window shape: " << m_window_shape << ").";
+            NODE_VALIDATION_CHECK(
+                this,
+                dim_padding_below == 0 || dim_window_size > dim_padding_below,
+                "Window will sometimes reside entirely within the below-padding region, but",
+                " include_padding_in_avg_computation was not set (padding below: ",
+                m_padding_below, ", window shape: ", m_window_shape, ").");

             // Now check the upper-bound...
             {
@@ -179,13 +220,16 @@ void op::QuantizedAvgPool::validate_and_infer_types()
                 const size_t dim_window_max_lower_offset = dim_num_strides * dim_stride;
                 const size_t dim_padding_above_start_offset = dim_virtual_size - dim_padding_above;

-                NODE_VALIDATION_ASSERT(this,
-                                       dim_padding_above == 0 ||
-                                           dim_window_max_lower_offset <
-                                               dim_padding_above_start_offset)
-                    << "Window will sometimes reside entirely within the above-padding region, but"
-                    << " include_padding_in_avg_computation was not set (padding above: "
-                    << m_padding_above << ", window shape: " << m_window_shape << ").";
+                NODE_VALIDATION_CHECK(
+                    this,
+                    dim_padding_above == 0 ||
+                        dim_window_max_lower_offset < dim_padding_above_start_offset,
+                    "Window will sometimes reside entirely within the above-padding region, but",
+                    " include_padding_in_avg_computation was not set (padding above: ",
+                    m_padding_above, ", window shape: ", m_window_shape, ").");
             }
         }
     }
src/ngraph/op/experimental/quantized_concat.cpp

@@ -33,7 +33,7 @@ op::QuantizedConcat::QuantizedConcat(const NodeVector& args, size_t concatenation_axis)
 void op::QuantizedConcat::validate_and_infer_types()
 {
-    NODE_VALIDATION_ASSERT(this, m_inputs.size() >= 1) << "At least one argument required.";
+    NODE_VALIDATION_CHECK(this, m_inputs.size() >= 1, "At least one argument required.");

     PartialShape inputs_shape_scheme{PartialShape::dynamic()};
     element::Type inputs_et{element::dynamic};
@@ -45,23 +45,32 @@ void op::QuantizedConcat::validate_and_infer_types()
         Dimension this_input_rank = this_input_shape.rank();
         if (this_input_rank.is_static())
         {
-            NODE_VALIDATION_ASSERT(this, m_concatenation_axis < size_t(this_input_rank))
-                << "QuantizedConcatenation axis (" << m_concatenation_axis
-                << ") is out of bounds for "
-                << "argument " << i << ", which has shape " << this_input_shape << ".";
+            NODE_VALIDATION_CHECK(this,
+                                  m_concatenation_axis < size_t(this_input_rank),
+                                  "QuantizedConcatenation axis (", m_concatenation_axis,
+                                  ") is out of bounds for ",
+                                  "argument ", i,
+                                  ", which has shape ", this_input_shape, ".");

             concatenation_axis_output_dim += this_input_shape[m_concatenation_axis];
             this_input_shape[m_concatenation_axis] = Dimension::dynamic();
-            NODE_VALIDATION_ASSERT(this,
-                                   PartialShape::merge_into(inputs_shape_scheme, this_input_shape))
-                << "Argument shapes are inconsistent; they must have the same rank, and must have "
-                << "equal dimension everywhere except on the concatenation axis (axis "
-                << m_concatenation_axis << ").";
-            NODE_VALIDATION_ASSERT(
-                this, element::Type::merge(inputs_et, inputs_et, get_input_element_type(i)))
-                << "Argument element types are inconsistent.";
+            NODE_VALIDATION_CHECK(
+                this,
+                PartialShape::merge_into(inputs_shape_scheme, this_input_shape),
+                "Argument shapes are inconsistent; they must have the same rank, and must have ",
+                "equal dimension everywhere except on the concatenation axis (axis ",
+                m_concatenation_axis, ").");
+            NODE_VALIDATION_CHECK(
+                this,
+                element::Type::merge(inputs_et, inputs_et, get_input_element_type(i)),
+                "Argument element types are inconsistent.");
         }
         else
         {
src/ngraph/op/experimental/quantized_max_pool.cpp

@@ -64,36 +64,58 @@ void op::QuantizedMaxPool::validate_and_infer_types()
     // Make sure batch size and channel count are not zero, and that we have at least one spatial
     // dimension (in other words, that arg has shape NCDi for some Di of rank>0, N != 0, C != 0).
     //
-    NODE_VALIDATION_ASSERT(this, arg_shape.size() >= 3)
-        << "Data input shape does not have rank of at least 3 (data input shape: " << arg_shape
-        << ").";
+    NODE_VALIDATION_CHECK(this,
+                          arg_shape.size() >= 3,
+                          "Data input shape does not have rank of at least 3 (data input shape: ",
+                          arg_shape, ").");

     size_t batch_size = arg_shape[0];
-    NODE_VALIDATION_ASSERT(this, batch_size != 0)
-        << "Data batch size is zero (data input shape: " << arg_shape << ").";
+    NODE_VALIDATION_CHECK(
+        this, batch_size != 0, "Data batch size is zero (data input shape: ", arg_shape, ").");

     size_t channel_count = arg_shape[1];
-    NODE_VALIDATION_ASSERT(this, channel_count != 0)
-        << "Channel count is zero (data input shape: " << arg_shape << ").";
+    NODE_VALIDATION_CHECK(
+        this, channel_count != 0, "Channel count is zero (data input shape: ", arg_shape, ").");

     size_t spatial_dimension_count = arg_shape.size() - 2;

     //
     // Make sure window shape, window movement strides, and padding have same rank as Di.
     //
-    NODE_VALIDATION_ASSERT(this, m_window_shape.size() == spatial_dimension_count)
-        << "Window shape rank does not match number of spatial dimensions (window shape: "
-        << m_window_shape << ", data input shape: " << arg_shape << ").";
-    NODE_VALIDATION_ASSERT(this, m_window_movement_strides.size() == spatial_dimension_count)
-        << "Window movement stride rank does not match number of spatial dimensions (window "
-        << "movement strides: " << m_window_movement_strides << ", data input shape: " << arg_shape
-        << ").";
-    NODE_VALIDATION_ASSERT(this, m_padding_below.size() == spatial_dimension_count)
-        << "Below-padding rank does not match number of spatial dimensions (padding below: "
-        << m_padding_below << ", data input shape: " << arg_shape << ").";
-    NODE_VALIDATION_ASSERT(this, m_padding_above.size() == spatial_dimension_count)
-        << "Above-padding rank does not match number of spatial dimensions (padding above: "
-        << m_padding_above << ", data input shape: " << arg_shape << ").";
+    NODE_VALIDATION_CHECK(
+        this,
+        m_window_shape.size() == spatial_dimension_count,
+        "Window shape rank does not match number of spatial dimensions (window shape: ",
+        m_window_shape, ", data input shape: ", arg_shape, ").");
+    NODE_VALIDATION_CHECK(
+        this,
+        m_window_movement_strides.size() == spatial_dimension_count,
+        "Window movement stride rank does not match number of spatial dimensions (window ",
+        "movement strides: ", m_window_movement_strides, ", data input shape: ", arg_shape, ").");
+    NODE_VALIDATION_CHECK(
+        this,
+        m_padding_below.size() == spatial_dimension_count,
+        "Below-padding rank does not match number of spatial dimensions (padding below: ",
+        m_padding_below, ", data input shape: ", arg_shape, ").");
+    NODE_VALIDATION_CHECK(
+        this,
+        m_padding_above.size() == spatial_dimension_count,
+        "Above-padding rank does not match number of spatial dimensions (padding above: ",
+        m_padding_above, ", data input shape: ", arg_shape, ").");

     //
     // Extract input item shape Di and make sure all dimensions are larger than 0.
@@ -109,10 +131,13 @@ void op::QuantizedMaxPool::validate_and_infer_types()
     for (size_t i = 0; i < spatial_dimension_count; i++)
     {
-        NODE_VALIDATION_ASSERT(this, input_item_virtual_shape[i] != 0)
-            << "Data input spatial dimension " << i
-            << " has zero length even after padding (virtual shape of input item: "
-            << input_item_virtual_shape << ").";
+        NODE_VALIDATION_CHECK(this,
+                              input_item_virtual_shape[i] != 0,
+                              "Data input spatial dimension ", i,
+                              " has zero length even after padding (virtual shape of input item: ",
+                              input_item_virtual_shape, ").");
     }

     //
@@ -120,9 +145,13 @@ void op::QuantizedMaxPool::validate_and_infer_types()
     //
     for (size_t i = 0; i < spatial_dimension_count; i++)
     {
-        NODE_VALIDATION_ASSERT(this, m_window_shape[i] != 0)
-            << "Window shape dimension " << i
-            << " has zero length (window shape: " << m_window_shape << ").";
+        NODE_VALIDATION_CHECK(this,
+                              m_window_shape[i] != 0,
+                              "Window shape dimension ", i,
+                              " has zero length (window shape: ", m_window_shape, ").");
     }

     //
@@ -130,10 +159,14 @@ void op::QuantizedMaxPool::validate_and_infer_types()
     //
     for (size_t i = 0; i < spatial_dimension_count; i++)
     {
-        NODE_VALIDATION_ASSERT(this, m_window_shape[i] <= input_item_virtual_shape[i])
-            << "Window shape after padding is larger than the spatial dimensions (window shape: "
-            << m_window_shape << ", virtual shape of input item: " << input_item_virtual_shape
-            << ").";
+        NODE_VALIDATION_CHECK(
+            this,
+            m_window_shape[i] <= input_item_virtual_shape[i],
+            "Window shape after padding is larger than the spatial dimensions (window shape: ",
+            m_window_shape, ", virtual shape of input item: ", input_item_virtual_shape, ").");
     }

     //
@@ -143,9 +176,13 @@ void op::QuantizedMaxPool::validate_and_infer_types()
     for (size_t i = 0; i < spatial_dimension_count; i++)
     {
-        NODE_VALIDATION_ASSERT(this, m_window_movement_strides[i] != 0)
-            << "Window movement strides dimension " << i
-            << " has zero length (window movement strides: " << m_window_movement_strides << ").";
+        NODE_VALIDATION_CHECK(this,
+                              m_window_movement_strides[i] != 0,
+                              "Window movement strides dimension ", i,
+                              " has zero length (window movement strides: ",
+                              m_window_movement_strides, ").");
         output_item_shape.push_back(ceil_div(input_item_virtual_shape[i] - m_window_shape[i] + 1,
                                              m_window_movement_strides[i]));
     }
src/ngraph/op/get_output_element.cpp

@@ -30,9 +30,13 @@ op::GetOutputElement::GetOutputElement(const shared_ptr<Node>& arg, size_t n)
 void op::GetOutputElement::validate_and_infer_types()
 {
-    NODE_VALIDATION_ASSERT(this, m_n < get_input_size())
-        << "Output at index " << m_n << " requested, but node has only " << get_input_size()
-        << " inputs.";
+    NODE_VALIDATION_CHECK(this,
+                          m_n < get_input_size(),
+                          "Output at index ", m_n,
+                          " requested, but node has only ", get_input_size(), " inputs.");

     set_output_type(0, get_input_element_type(m_n), get_input_partial_shape(m_n));
 }
src/ngraph/op/lrn.cpp

@@ -36,9 +36,12 @@ void op::LRN::validate_and_infer_types()
     const PartialShape& input_shape = get_input_partial_shape(0);
-    NODE_VALIDATION_ASSERT(this,
-                           input_shape.rank().is_dynamic() ||
-                               static_cast<size_t>(input_shape.rank()) >= 3)
-        << "Argument must have rank >= 3 (argument shape: " << input_shape << ").";
+    NODE_VALIDATION_CHECK(this,
+                          input_shape.rank().is_dynamic() ||
+                              static_cast<size_t>(input_shape.rank()) >= 3,
+                          "Argument must have rank >= 3 (argument shape: ", input_shape, ").");
 }

 shared_ptr<Node> op::LRN::copy_with_new_args(const NodeVector& new_args) const
src/ngraph/op/max_pool.cpp

@@ -134,9 +134,13 @@ void op::MaxPoolBackprop::validate_and_infer_types()
     element::Type result_et;

-    NODE_VALIDATION_ASSERT(this, element::Type::merge(result_et, forward_arg_et, delta_et))
-        << "Element types for forward argument (" << forward_arg_et << ") and delta (" << delta_et
-        << ") do not match.";
+    NODE_VALIDATION_CHECK(this,
+                          element::Type::merge(result_et, forward_arg_et, delta_et),
+                          "Element types for forward argument (", forward_arg_et,
+                          ") and delta (", delta_et, ") do not match.");

     // infer_batched_forward_pooling wants CoordinateDiffs for these, while the pooling ops for
     // now still take Shape (no negative padding).
@@ -155,9 +159,15 @@ void op::MaxPoolBackprop::validate_and_infer_types()
     const PartialShape& delta_shape = get_input_partial_shape(1);

-    NODE_VALIDATION_ASSERT(this, forward_result_shape.compatible(delta_shape))
-        << "Inferred forward output shape does not match delta shape (inferred forward output "
-        << "shape: " << forward_result_shape << ", delta shape: " << delta_shape << ").";
+    NODE_VALIDATION_CHECK(
+        this,
+        forward_result_shape.compatible(delta_shape),
+        "Inferred forward output shape does not match delta shape (inferred forward output ",
+        "shape: ", forward_result_shape, ", delta shape: ", delta_shape, ").");

     // TODO(amprocte): We may technically be able to infer some extra information from
     // forward_result_shape that was not present in the forward arg shape---namely batch size and
src/ngraph/op/one_hot.cpp

@@ -34,16 +34,25 @@ void op::OneHot::validate_and_infer_types()
     PartialShape arg_shape = get_input_partial_shape(0);
     Rank arg_rank = arg_shape.rank();

-    NODE_VALIDATION_ASSERT(this, m_shape.rank().is_static())
-        << "Requested result shape has dynamic rank.";
+    NODE_VALIDATION_CHECK(
+        this, m_shape.rank().is_static(), "Requested result shape has dynamic rank.");

-    NODE_VALIDATION_ASSERT(this, m_one_hot_axis < static_cast<size_t>(m_shape.rank()))
-        << "One-hot axis (" << m_one_hot_axis
-        << ") is out of bounds (requested result shape: " << m_shape << ").";
+    NODE_VALIDATION_CHECK(this,
+                          m_one_hot_axis < static_cast<size_t>(m_shape.rank()),
+                          "One-hot axis (", m_one_hot_axis,
+                          ") is out of bounds (requested result shape: ", m_shape, ").");

-    NODE_VALIDATION_ASSERT(this, m_shape[m_one_hot_axis].is_static())
-        << "Requested result shape (" << m_shape << ") has dynamic dimension at the one-hot axis "
-        << "(" << m_one_hot_axis << ").";
+    NODE_VALIDATION_CHECK(this,
+                          m_shape[m_one_hot_axis].is_static(),
+                          "Requested result shape (", m_shape,
+                          ") has dynamic dimension at the one-hot axis ",
+                          "(", m_one_hot_axis, ").");

     PartialShape result_shape{m_shape};
@@ -58,9 +67,13 @@ void op::OneHot::validate_and_infer_types()
     PartialShape expected_input_shape{expected_input_dims};

     PartialShape merged_input_shape{expected_input_shape};
-    NODE_VALIDATION_ASSERT(this, PartialShape::merge_into(merged_input_shape, arg_shape))
-        << "Argument shape " << arg_shape << " does not match the expected shape of "
-        << expected_input_shape << ".";
+    NODE_VALIDATION_CHECK(this,
+                          PartialShape::merge_into(merged_input_shape, arg_shape),
+                          "Argument shape ", arg_shape,
+                          " does not match the expected shape of ", expected_input_shape, ".");

     std::vector<Dimension> output_dims(static_cast<size_t>(merged_input_shape.rank()));
     for (size_t i = 0; i < static_cast<size_t>(merged_input_shape.rank()); i++)
src/ngraph/op/pad.cpp

@@ -38,31 +38,49 @@ void op::Pad::validate_and_infer_types()
 {
     element::Type result_et;

-    NODE_VALIDATION_ASSERT(
-        this, element::Type::merge(result_et, get_input_element_type(0), get_input_element_type(1)))
-        << "Argument element types do not match (arg0 element type: " << get_input_element_type(0)
-        << ", arg1 element type: " << get_input_element_type(1) << ").";
+    NODE_VALIDATION_CHECK(
+        this,
+        element::Type::merge(result_et, get_input_element_type(0), get_input_element_type(1)),
+        "Argument element types do not match (arg0 element type: ", get_input_element_type(0),
+        ", arg1 element type: ", get_input_element_type(1), ").");

-    NODE_VALIDATION_ASSERT(this, get_input_partial_shape(1).compatible(PartialShape{}))
-        << "Argument for padding value is not a scalar (shape: " << get_input_partial_shape(1)
-        << ").";
+    NODE_VALIDATION_CHECK(this,
+                          get_input_partial_shape(1).compatible(PartialShape{}),
+                          "Argument for padding value is not a scalar (shape: ",
+                          get_input_partial_shape(1), ").");

     auto arg_shape = get_input_partial_shape(0);

-    NODE_VALIDATION_ASSERT(this,
-                           m_padding_below.size() == m_padding_above.size() &&
-                               m_padding_below.size() == m_padding_interior.size())
-        << "Ranks for padding below (" << m_padding_below << "), padding above ("
-        << m_padding_above << ") and interior padding (" << m_padding_interior << ") "
-        << "do not match.";
+    NODE_VALIDATION_CHECK(this,
+                          m_padding_below.size() == m_padding_above.size() &&
+                              m_padding_below.size() == m_padding_interior.size(),
+                          "Ranks for padding below (", m_padding_below,
+                          "), padding above (", m_padding_above,
+                          ") and interior padding (", m_padding_interior, ") ",
+                          "do not match.");

     size_t implied_rank = m_padding_below.size();

-    NODE_VALIDATION_ASSERT(this, arg_shape.rank().compatible(implied_rank))
-        << "Rank for padding below/padding above/interior padding does not match the rank of the "
-        << "data argument (padding below: " << m_padding_below << ", "
-        << ", padding above: " << m_padding_above << ", interior padding: " << m_padding_interior
-        << ").";
+    NODE_VALIDATION_CHECK(
+        this,
+        arg_shape.rank().compatible(implied_rank),
+        "Rank for padding below/padding above/interior padding does not match the rank of the ",
+        "data argument (padding below: ", m_padding_below, ", ",
+        ", padding above: ", m_padding_above,
+        ", interior padding: ", m_padding_interior, ").");

     std::vector<Dimension> result_dims(implied_rank, Dimension::dynamic());
src/ngraph/op/quantize.cpp

@@ -44,50 +44,73 @@ void op::Quantize::validate_and_infer_types()
         OFFSET
     };

-    NODE_VALIDATION_ASSERT(this, m_type.is_static()) << "Output element type must not be dynamic";
+    NODE_VALIDATION_CHECK(this, m_type.is_static(), "Output element type must not be dynamic");

-    NODE_VALIDATION_ASSERT(this, m_type.is_quantized()) << "Output element type (" << m_type
-                                                        << ") must be a quantized type";
+    NODE_VALIDATION_CHECK(
+        this, m_type.is_quantized(), "Output element type (", m_type, ") must be a quantized type");

     element::Type unquantized_type;

-    NODE_VALIDATION_ASSERT(this,
-                           element::Type::merge(unquantized_type,
-                                                get_input_element_type(INPUT),
-                                                get_input_element_type(SCALE)))
-        << "Scale element type (" << get_input_element_type(SCALE)
-        << ") must match input element type (" << get_input_element_type(INPUT) << ")";
-    NODE_VALIDATION_ASSERT(this, unquantized_type.is_dynamic() || unquantized_type.is_real())
-        << "Scale/input element type (" << unquantized_type << ") must be a floating point number";
+    NODE_VALIDATION_CHECK(this,
+                          element::Type::merge(unquantized_type,
+                                               get_input_element_type(INPUT),
+                                               get_input_element_type(SCALE)),
+                          "Scale element type (", get_input_element_type(SCALE),
+                          ") must match input element type (", get_input_element_type(INPUT), ")");
+    NODE_VALIDATION_CHECK(this,
+                          unquantized_type.is_dynamic() || unquantized_type.is_real(),
+                          "Scale/input element type (", unquantized_type,
+                          ") must be a floating point number");

     element::Type quantized_type;

-    NODE_VALIDATION_ASSERT(
-        this, element::Type::merge(quantized_type, get_input_element_type(OFFSET), m_type))
-        << "Offset element type (" << get_input_element_type(OFFSET)
-        << ") must match output element type (" << m_type << ")";
+    NODE_VALIDATION_CHECK(
+        this,
+        element::Type::merge(quantized_type, get_input_element_type(OFFSET), m_type),
+        "Offset element type (", get_input_element_type(OFFSET),
+        ") must match output element type (", m_type, ")");

     PartialShape input_shape = get_input_partial_shape(0);
     Dimension input_rank = input_shape.rank();

     for (auto axis : m_axes)
     {
-        NODE_VALIDATION_ASSERT(this, input_rank.is_dynamic() || axis < size_t(input_rank))
-            << "Quantization axis (" << axis << ") must be less than input shape rank ("
-            << input_rank << ")";
+        NODE_VALIDATION_CHECK(this,
+                              input_rank.is_dynamic() || axis < size_t(input_rank),
+                              "Quantization axis (", axis,
+                              ") must be less than input shape rank (", input_rank, ")");
     }

     PartialShape scale_offset_shape = get_input_partial_shape(SCALE);

-    NODE_VALIDATION_ASSERT(
-        this, PartialShape::merge_into(scale_offset_shape, get_input_partial_shape(OFFSET)))
-        << "Scale shape (" << get_input_partial_shape(SCALE) << ") and offset shape ("
-        << get_input_partial_shape(OFFSET) << ") must match";
-    NODE_VALIDATION_ASSERT(this, scale_offset_shape.rank().compatible(m_axes.size()))
-        << "Scale/offset rank (" << scale_offset_shape.rank() << ") does not match the number of "
-        << "quantization axes (" << m_axes.size() << ")";
+    NODE_VALIDATION_CHECK(
+        this,
+        PartialShape::merge_into(scale_offset_shape, get_input_partial_shape(OFFSET)),
+        "Scale shape (", get_input_partial_shape(SCALE),
+        ") and offset shape (", get_input_partial_shape(OFFSET), ") must match");
+    NODE_VALIDATION_CHECK(this,
+                          scale_offset_shape.rank().compatible(m_axes.size()),
+                          "Scale/offset rank (", scale_offset_shape.rank(),
+                          ") does not match the number of ",
+                          "quantization axes (", m_axes.size(), ")");

     set_output_size(1);
@@ -110,10 +133,16 @@ void op::Quantize::validate_and_infer_types()
         }

         PartialShape result_shape = input_shape;
-        NODE_VALIDATION_ASSERT(
-            this, PartialShape::merge_into(result_shape, PartialShape{injected_scale_offset_dims}))
-            << "Scale/offset shape (" << scale_offset_shape << ") must match input shape ("
-            << input_shape << ") at the quantization axes (" << m_axes << ")";
+        NODE_VALIDATION_CHECK(
+            this,
+            PartialShape::merge_into(result_shape, PartialShape{injected_scale_offset_dims}),
+            "Scale/offset shape (", scale_offset_shape,
+            ") must match input shape (", input_shape,
+            ") at the quantization axes (", m_axes, ")");
         set_output_type(0, quantized_type, result_shape);
     }
     else
src/ngraph/op/replace_slice.cpp
@@ -59,51 +59,85 @@ void op::ReplaceSlice::validate_and_infer_types()
     const PartialShape& arg1_shape = get_input_partial_shape(1);
     Dimension merged_args_rank;
 
-    NODE_VALIDATION_ASSERT(this,
-                           Dimension::merge(merged_args_rank, arg0_shape.rank(), arg1_shape.rank()))
-        << "Argument ranks do not match (arg0 shape: " << arg0_shape << ", arg1 shape: "
-        << arg1_shape << ").";
+    NODE_VALIDATION_CHECK(this,
+                          Dimension::merge(merged_args_rank, arg0_shape.rank(), arg1_shape.rank()),
+                          "Argument ranks do not match (arg0 shape: ", arg0_shape,
+                          ", arg1 shape: ", arg1_shape, ").");
 
     element::Type arg0_et = get_input_element_type(0);
     element::Type arg1_et = get_input_element_type(1);
     element::Type merged_args_et;
 
-    NODE_VALIDATION_ASSERT(this, element::Type::merge(merged_args_et, arg0_et, arg1_et))
-        << "Argument element types do not match (arg0 element type: " << arg0_et
-        << ", arg1 element type: " << arg1_et << ").";
+    NODE_VALIDATION_CHECK(this,
+                          element::Type::merge(merged_args_et, arg0_et, arg1_et),
+                          "Argument element types do not match (arg0 element type: ", arg0_et,
+                          ", arg1 element type: ", arg1_et, ").");
 
-    NODE_VALIDATION_ASSERT(this,
-                           m_lower_bounds.size() == m_upper_bounds.size() &&
-                               m_lower_bounds.size() == m_strides.size())
-        << "Ranks of lower bounds (" << m_lower_bounds << "), upper bounds (" << m_upper_bounds
-        << ") and strides (" << m_strides << ") do not match.";
+    NODE_VALIDATION_CHECK(this,
+                          m_lower_bounds.size() == m_upper_bounds.size() &&
+                              m_lower_bounds.size() == m_strides.size(),
+                          "Ranks of lower bounds (", m_lower_bounds,
+                          "), upper bounds (", m_upper_bounds,
+                          ") and strides (", m_strides, ") do not match.");
 
     size_t output_rank = m_upper_bounds.size();
 
     for (size_t i = 0; i < output_rank; i++)
     {
-        NODE_VALIDATION_ASSERT(this, m_lower_bounds[i] <= m_upper_bounds[i])
-            << "Lower bound for slice is greater than upper bound at axis " << i
-            << " (lower bounds: " << m_lower_bounds << ", upper bounds: " << m_upper_bounds << ").";
+        NODE_VALIDATION_CHECK(this,
+                              m_lower_bounds[i] <= m_upper_bounds[i],
+                              "Lower bound for slice is greater than upper bound at axis ", i,
+                              " (lower bounds: ", m_lower_bounds,
+                              ", upper bounds: ", m_upper_bounds, ").");
 
-        NODE_VALIDATION_ASSERT(this, m_strides[i] != 0) << "Stride for slice is zero at axis " << i
-                                                        << " (strides: " << m_strides << ").";
+        NODE_VALIDATION_CHECK(this,
+                              m_strides[i] != 0,
+                              "Stride for slice is zero at axis ", i,
+                              " (strides: ", m_strides, ").");
     }
 
-    NODE_VALIDATION_ASSERT(this,
-                           merged_args_rank.is_dynamic() || size_t(merged_args_rank) == output_rank)
-        << "Argument ranks do not match the rank of the lower bounds (" << m_lower_bounds
-        << "), upper bounds (" << m_upper_bounds << "), and strides (" << m_strides << ").";
+    NODE_VALIDATION_CHECK(this,
+                          merged_args_rank.is_dynamic() || size_t(merged_args_rank) == output_rank,
+                          "Argument ranks do not match the rank of the lower bounds (", m_lower_bounds,
+                          "), upper bounds (", m_upper_bounds,
+                          "), and strides (", m_strides, ").");
 
     std::vector<Dimension> sliced_dims(output_rank);
 
     for (size_t i = 0; i < output_rank; i++)
     {
-        NODE_VALIDATION_ASSERT(this,
-                               arg0_shape.rank().is_dynamic() || arg0_shape[i].is_dynamic() ||
-                                   m_upper_bounds[i] <= size_t(arg0_shape[i]))
-            << "Upper bound for slice at axis " << i << " is out of range "
-            << "(upper bounds: " << m_upper_bounds << ", argument shape: " << arg0_shape << ").";
+        NODE_VALIDATION_CHECK(this,
+                              arg0_shape.rank().is_dynamic() || arg0_shape[i].is_dynamic() ||
+                                  m_upper_bounds[i] <= size_t(arg0_shape[i]),
+                              "Upper bound for slice at axis ", i, " is out of range ",
+                              "(upper bounds: ", m_upper_bounds,
+                              ", argument shape: ", arg0_shape, ").");
 
         size_t sliced_dim = m_upper_bounds[i] - m_lower_bounds[i];
         sliced_dim = sliced_dim / m_strides[i] + ((sliced_dim % m_strides[i] == 0) ? 0 : 1);
...
@@ -112,9 +146,14 @@ void op::ReplaceSlice::validate_and_infer_types()
     PartialShape slice_shape{sliced_dims};
 
-    NODE_VALIDATION_ASSERT(this, arg1_shape.compatible(slice_shape))
-        << "Shape of replacement tensor (" << arg1_shape << ") does not match the slice shape "
-        << "(" << slice_shape << ").";
+    NODE_VALIDATION_CHECK(this,
+                          arg1_shape.compatible(slice_shape),
+                          "Shape of replacement tensor (", arg1_shape,
+                          ") does not match the slice shape ", "(", slice_shape, ").");
 
     // Slight corner case here: if arg0 was rank-unknown, we can go ahead and set the output rank
     // because the attribs will have given us enough info.
...
src/ngraph/op/reshape.cpp
@@ -41,25 +41,39 @@ void op::Reshape::validate_and_infer_types()
     // Check that the input axis order is a permutation of (0,...,n-1) for some n.
     for (size_t i = 0; i < m_input_order.size(); i++)
     {
-        NODE_VALIDATION_ASSERT(
-            this, find(begin(m_input_order), end(m_input_order), i) != end(m_input_order))
-            << "Input axis order is not a permutation of argument's axis indices (axis order: "
-            << m_input_order << ", argument shape: " << input_shape << ").";
+        NODE_VALIDATION_CHECK(
+            this,
+            find(begin(m_input_order), end(m_input_order), i) != end(m_input_order),
+            "Input axis order is not a permutation of argument's axis indices (axis order: ",
+            m_input_order, ", argument shape: ", input_shape, ").");
     }
 
     // TODO(amprocte): should be possible to move around unknown dims in the input shape.
     if (input_rank.is_static())
     {
-        NODE_VALIDATION_ASSERT(this, m_input_order.size() == size_t(input_rank))
-            << "Input axis order is not a permutation of argument's axis indices (axis order: "
-            << m_input_order << ", argument shape: " << input_shape << ").";
+        NODE_VALIDATION_CHECK(
+            this,
+            m_input_order.size() == size_t(input_rank),
+            "Input axis order is not a permutation of argument's axis indices (axis order: ",
+            m_input_order, ", argument shape: ", input_shape, ").");
 
         for (size_t i = 0; i < size_t(input_rank); i++)
         {
             auto it = find(begin(m_input_order), end(m_input_order), i);
-            NODE_VALIDATION_ASSERT(this, it != end(m_input_order))
-                << "Input axis order is not a permutation of argument's axis indices (axis order: "
-                << m_input_order << ", argument shape: " << input_shape << ").";
+            NODE_VALIDATION_CHECK(
+                this,
+                it != end(m_input_order),
+                "Input axis order is not a permutation of argument's axis indices (axis order: ",
+                m_input_order, ", argument shape: ", input_shape, ").");
         }
 
     // TODO(amprocte): make a partial_shape_size() analogous to shape_size().
...
@@ -71,11 +85,16 @@ void op::Reshape::validate_and_infer_types()
     if (input_shape_product.is_static())
     {
-        NODE_VALIDATION_ASSERT(this, size_t(input_shape_product) == shape_size(m_output_shape))
-            << "Product of output shape dimensions does not match product of argument shape "
-               "dimensions "
-            << "(output shape: " << m_output_shape << ", argument shape: " << input_shape << ").";
+        NODE_VALIDATION_CHECK(
+            this,
+            size_t(input_shape_product) == shape_size(m_output_shape),
+            "Product of output shape dimensions does not match product of argument shape ",
+            "dimensions ",
+            "(output shape: ", m_output_shape, ", argument shape: ", input_shape, ").");
     }
 }
...
src/ngraph/op/result.cpp
@@ -32,8 +32,8 @@ op::Result::Result(const shared_ptr<Node>& arg)
 void op::Result::validate_and_infer_types()
 {
-    NODE_VALIDATION_ASSERT(this, get_input_size() == 1) << "Argument has " << get_input_size()
-                                                        << " outputs (1 expected).";
+    NODE_VALIDATION_CHECK(
+        this, get_input_size() == 1, "Argument has ", get_input_size(), " outputs (1 expected).");
 
     // always borrow the placement conf even the default one
     set_placement_index(get_argument(0)->get_placement_index());
...
src/ngraph/op/reverse.cpp
@@ -40,9 +40,13 @@ void op::Reverse::validate_and_infer_types()
     // Make sure all reversed axis indices are valid.
     for (size_t axis : m_reversed_axes)
     {
-        NODE_VALIDATION_ASSERT(this, axis < size_t(input_rank))
-            << "Reverse axis (" << axis << ") is out of bounds (argument shape: " << input_shape
-            << ").";
+        NODE_VALIDATION_CHECK(this,
+                              axis < size_t(input_rank),
+                              "Reverse axis (", axis,
+                              ") is out of bounds (argument shape: ", input_shape, ").");
     }
 }
...
src/ngraph/op/reverse_sequence.cpp
@@ -41,20 +41,31 @@ void op::ReverseSequence::validate_and_infer_types()
     auto input_shape = get_input_partial_shape(0);
     auto input_rank = input_shape.rank();
 
-    NODE_VALIDATION_ASSERT(this, input_rank.is_dynamic() || m_batch_axis < size_t(input_rank))
-        << "Batch axis index (" << m_batch_axis << ") is out of bounds (argument shape: "
-        << input_shape << ").";
+    NODE_VALIDATION_CHECK(this,
+                          input_rank.is_dynamic() || m_batch_axis < size_t(input_rank),
+                          "Batch axis index (", m_batch_axis,
+                          ") is out of bounds (argument shape: ", input_shape, ").");
 
-    NODE_VALIDATION_ASSERT(this, input_rank.is_dynamic() || m_seq_axis < size_t(input_rank))
-        << "Sequence axis index (" << m_seq_axis << ") is out of bounds (argument shape: "
-        << input_shape << ").";
+    NODE_VALIDATION_CHECK(this,
+                          input_rank.is_dynamic() || m_seq_axis < size_t(input_rank),
+                          "Sequence axis index (", m_seq_axis,
+                          ") is out of bounds (argument shape: ", input_shape, ").");
 
     auto indices_shape = get_input_partial_shape(1);
     auto indices_rank = indices_shape.rank();
 
-    NODE_VALIDATION_ASSERT(this, indices_rank.is_dynamic() || size_t(indices_rank) == 1)
-        << "Sequence indices must be a 1-dimensional tensor (sequence indices shape: "
-        << get_input_partial_shape(1) << ").";
+    NODE_VALIDATION_CHECK(
+        this,
+        indices_rank.is_dynamic() || size_t(indices_rank) == 1,
+        "Sequence indices must be a 1-dimensional tensor (sequence indices shape: ",
+        get_input_partial_shape(1), ").");
 
     PartialShape output_shape{input_shape};
...
@@ -62,12 +73,19 @@ void op::ReverseSequence::validate_and_infer_types()
     {
         Dimension merged_sequence_length;
 
-        NODE_VALIDATION_ASSERT(
-            this,
-            Dimension::merge(merged_sequence_length, input_shape[m_batch_axis], indices_shape[0]))
-            << "Sequence length (" << indices_shape[0] << ") is not equal to batch axis "
-            << "dimension (" << input_shape[m_batch_axis] << ") (argument shape: " << input_shape
-            << ", sequence indices shape: " << indices_shape << ").";
+        NODE_VALIDATION_CHECK(
+            this,
+            Dimension::merge(merged_sequence_length, input_shape[m_batch_axis], indices_shape[0]),
+            "Sequence length (", indices_shape[0], ") is not equal to batch axis ",
+            "dimension (", input_shape[m_batch_axis], ") (argument shape: ", input_shape,
+            ", sequence indices shape: ", indices_shape, ").");
 
         output_shape[m_batch_axis] = merged_sequence_length;
     }
...
src/ngraph/op/select.cpp
@@ -36,24 +36,28 @@ op::Select::Select(const shared_ptr<Node>& arg0,
 void op::Select::validate_and_infer_types()
 {
-    NODE_VALIDATION_ASSERT(this,
-                           get_input_element_type(0).is_dynamic() ||
-                               get_input_element_type(0) == element::boolean)
-        << "Argument 0 does not have boolean element type (element type: "
-        << get_input_element_type(0) << ").";
+    NODE_VALIDATION_CHECK(this,
+                          get_input_element_type(0).is_dynamic() ||
+                              get_input_element_type(0) == element::boolean,
+                          "Argument 0 does not have boolean element type (element type: ",
+                          get_input_element_type(0), ").");
 
     PartialShape result_shape = get_input_partial_shape(0);
 
-    NODE_VALIDATION_ASSERT(this, PartialShape::merge_into(result_shape, get_input_partial_shape(1)))
-        << "Argument shapes are inconsistent.";
-    NODE_VALIDATION_ASSERT(this, PartialShape::merge_into(result_shape, get_input_partial_shape(2)))
-        << "Argument shapes are inconsistent.";
+    NODE_VALIDATION_CHECK(this,
+                          PartialShape::merge_into(result_shape, get_input_partial_shape(1)),
+                          "Argument shapes are inconsistent.");
+    NODE_VALIDATION_CHECK(this,
+                          PartialShape::merge_into(result_shape, get_input_partial_shape(2)),
+                          "Argument shapes are inconsistent.");
 
     element::Type result_et;
 
-    NODE_VALIDATION_ASSERT(
-        this, element::Type::merge(result_et, get_input_element_type(1), get_input_element_type(2)))
-        << "Argument 1 and 2 element types are inconsistent.";
+    NODE_VALIDATION_CHECK(
+        this,
+        element::Type::merge(result_et, get_input_element_type(1), get_input_element_type(2)),
+        "Argument 1 and 2 element types are inconsistent.");
 
     set_output_type(0, result_et, result_shape);
 }
...
src/ngraph/op/slice.cpp
@@ -51,40 +51,68 @@ void op::Slice::validate_and_infer_types()
         m_strides = Strides(m_lower_bounds.size(), 1);
     }
 
-    NODE_VALIDATION_ASSERT(this,
-                           m_lower_bounds.size() == m_upper_bounds.size() &&
-                               m_lower_bounds.size() == m_strides.size())
-        << "Ranks of lower bounds (" << m_lower_bounds << "), upper bounds (" << m_upper_bounds
-        << ") and strides (" << m_strides << ") do not match.";
+    NODE_VALIDATION_CHECK(this,
+                          m_lower_bounds.size() == m_upper_bounds.size() &&
+                              m_lower_bounds.size() == m_strides.size(),
+                          "Ranks of lower bounds (", m_lower_bounds,
+                          "), upper bounds (", m_upper_bounds,
+                          ") and strides (", m_strides, ") do not match.");
 
     size_t output_rank = m_upper_bounds.size();
 
     for (size_t i = 0; i < output_rank; i++)
     {
-        NODE_VALIDATION_ASSERT(this, m_lower_bounds[i] <= m_upper_bounds[i])
-            << "Lower bound for slice is greater than upper bound at axis " << i
-            << " (lower bounds: " << m_lower_bounds << ", upper bounds: " << m_upper_bounds << ").";
+        NODE_VALIDATION_CHECK(this,
+                              m_lower_bounds[i] <= m_upper_bounds[i],
+                              "Lower bound for slice is greater than upper bound at axis ", i,
+                              " (lower bounds: ", m_lower_bounds,
+                              ", upper bounds: ", m_upper_bounds, ").");
 
-        NODE_VALIDATION_ASSERT(this, m_strides[i] != 0) << "Stride for slice is zero at axis " << i
-                                                        << " (strides: " << m_strides << ").";
+        NODE_VALIDATION_CHECK(this,
+                              m_strides[i] != 0,
+                              "Stride for slice is zero at axis ", i,
+                              " (strides: ", m_strides, ").");
     }
 
     const PartialShape& input_shape = get_input_partial_shape(0);
     Dimension input_rank = input_shape.rank();
 
-    NODE_VALIDATION_ASSERT(this, input_rank.is_dynamic() || size_t(input_rank) == output_rank)
-        << "Input rank does not match the rank of the lower bounds (" << m_lower_bounds
-        << "), upper bounds (" << m_upper_bounds << "), and strides (" << m_strides << ").";
+    NODE_VALIDATION_CHECK(this,
+                          input_rank.is_dynamic() || size_t(input_rank) == output_rank,
+                          "Input rank does not match the rank of the lower bounds (", m_lower_bounds,
+                          "), upper bounds (", m_upper_bounds,
+                          "), and strides (", m_strides, ").");
 
     std::vector<Dimension> result_dims(output_rank);
 
     for (size_t i = 0; i < output_rank; i++)
     {
-        NODE_VALIDATION_ASSERT(this,
-                               input_rank.is_dynamic() || input_shape[i].is_dynamic() ||
-                                   m_upper_bounds[i] <= size_t(input_shape[i]))
-            << "Upper bound for slice at axis " << i << " is out of range "
-            << "(upper bounds: " << m_upper_bounds << ", argument shape: " << input_shape << ").";
+        NODE_VALIDATION_CHECK(this,
+                              input_rank.is_dynamic() || input_shape[i].is_dynamic() ||
+                                  m_upper_bounds[i] <= size_t(input_shape[i]),
+                              "Upper bound for slice at axis ", i, " is out of range ",
+                              "(upper bounds: ", m_upper_bounds,
+                              ", argument shape: ", input_shape, ").");
 
         size_t result_axis_size = m_upper_bounds[i] - m_lower_bounds[i];
         result_axis_size =
...
src/ngraph/op/softmax.cpp
@@ -37,9 +37,13 @@ op::Softmax::Softmax(const shared_ptr<Node>& arg, const AxisSet& axes)
     for (auto axis : m_axes)
    {
-        NODE_VALIDATION_ASSERT(this, axis < get_shape().size())
-            << "Reduction axis (" << axis << ") is out of bounds (argument shape: " << get_shape()
-            << ").";
+        NODE_VALIDATION_CHECK(this,
+                              axis < get_shape().size(),
+                              "Reduction axis (", axis,
+                              ") is out of bounds (argument shape: ", get_shape(), ").");
     }
 
     // empty axes == all axes
...
src/ngraph/op/topk.cpp
@@ -43,26 +43,36 @@ void op::TopK::validate_and_infer_types()
     Rank input_rank = input_shape.rank();
     element::Type input_element_type = get_input_element_type(0);
 
-    NODE_VALIDATION_ASSERT(this, !m_index_element_type.is_dynamic())
-        << "Argument element type must not be dynamic.";
+    NODE_VALIDATION_CHECK(
+        this, !m_index_element_type.is_dynamic(), "Argument element type must not be dynamic.");
 
-    NODE_VALIDATION_ASSERT(this,
-                           m_index_element_type == element::i32 ||
-                               m_index_element_type == element::i64)
-        << "Argument element type must be i64 or i32 (got " << m_index_element_type << ").";
+    NODE_VALIDATION_CHECK(this,
+                          m_index_element_type == element::i32 ||
+                              m_index_element_type == element::i64,
+                          "Argument element type must be i64 or i32 (got ",
+                          m_index_element_type, ").");
 
-    NODE_VALIDATION_ASSERT(this, input_rank.is_dynamic() || static_cast<size_t>(input_rank) > 0)
-        << "Argument rank must be greater than 0.";
+    NODE_VALIDATION_CHECK(this,
+                          input_rank.is_dynamic() || static_cast<size_t>(input_rank) > 0,
+                          "Argument rank must be greater than 0.");
 
-    NODE_VALIDATION_ASSERT(this,
-                           input_rank.is_dynamic() ||
-                               m_top_k_axis < static_cast<size_t>(input_rank))
-        << "TopK axis (" << m_top_k_axis << ") is out of bounds.";
+    NODE_VALIDATION_CHECK(this,
+                          input_rank.is_dynamic() ||
+                              m_top_k_axis < static_cast<size_t>(input_rank),
+                          "TopK axis (", m_top_k_axis, ") is out of bounds.");
 
-    NODE_VALIDATION_ASSERT(this,
-                           input_rank.is_dynamic() || input_shape[m_top_k_axis].is_dynamic() ||
-                               m_k <= static_cast<size_t>(input_shape[m_top_k_axis]))
-        << "K (" << m_k << ") exceeds the dimension ("
-        << (input_rank.is_static() ? input_shape[m_top_k_axis] : 0) << ") of the TopK axis (axis "
-        << m_top_k_axis << ").";
+    NODE_VALIDATION_CHECK(this,
+                          input_rank.is_dynamic() || input_shape[m_top_k_axis].is_dynamic() ||
+                              m_k <= static_cast<size_t>(input_shape[m_top_k_axis]),
+                          "K (", m_k, ") exceeds the dimension (",
+                          (input_rank.is_static() ? input_shape[m_top_k_axis] : 0),
+                          ") of the TopK axis (axis ", m_top_k_axis, ").");
 
     PartialShape output_shape{input_shape};
...
src/ngraph/op/util/arithmetic_reduction.cpp
@@ -40,10 +40,16 @@ void op::util::ArithmeticReduction::validate_and_infer_types()
     for (auto axis : m_reduction_axes)
     {
-        NODE_VALIDATION_ASSERT(this, axis < size_t(input_rank))
-            << "Reduction axis (" << axis << ") is out of bounds "
-            << "(argument shape: " << input_shape << ", reduction axes: " << m_reduction_axes
-            << ")";
+        NODE_VALIDATION_CHECK(this,
+                              axis < size_t(input_rank),
+                              "Reduction axis (", axis, ") is out of bounds ",
+                              "(argument shape: ", input_shape,
+                              ", reduction axes: ", m_reduction_axes, ")");
     }
 
     for (size_t i = 0; i < size_t(input_rank); i++)
...
src/ngraph/op/util/index_reduction.cpp
@@ -37,13 +37,18 @@ void op::util::IndexReduction::validate_and_infer_types()
     const PartialShape& arg_shape = get_input_partial_shape(0);
     Rank rank = arg_shape.rank();
 
-    NODE_VALIDATION_ASSERT(this, rank.is_dynamic() || size_t(rank) >= 1)
-        << "Argument rank is zero.";
-    NODE_VALIDATION_ASSERT(this, rank.is_dynamic() || m_axis < size_t(rank))
-        << "Reduction axis (" << m_axis << ") is not less than argument rank (" << rank << ").";
-    NODE_VALIDATION_ASSERT(
-        this, m_index_element_type == element::i32 || m_index_element_type == element::i64)
-        << "Index element is neither i64 or i32.";
+    NODE_VALIDATION_CHECK(this, rank.is_dynamic() || size_t(rank) >= 1, "Argument rank is zero.");
+    NODE_VALIDATION_CHECK(this,
+                          rank.is_dynamic() || m_axis < size_t(rank),
+                          "Reduction axis (", m_axis,
+                          ") is not less than argument rank (", rank, ").");
+    NODE_VALIDATION_CHECK(this,
+                          m_index_element_type == element::i32 ||
+                              m_index_element_type == element::i64,
+                          "Index element is neither i64 or i32.");
 
     PartialShape output_shape{PartialShape::dynamic()};
...
src/ngraph/op/util/logical_reduction.cpp
@@ -40,10 +40,16 @@ void op::util::LogicalReduction::validate_and_infer_types()
     for (auto axis : m_reduction_axes)
     {
-        NODE_VALIDATION_ASSERT(this, axis < size_t(input_rank))
-            << "Reduction axis (" << axis << ") is out of bounds "
-            << "(argument shape: " << input_shape << ", reduction axes: " << m_reduction_axes
-            << ")";
+        NODE_VALIDATION_CHECK(this,
+                              axis < size_t(input_rank),
+                              "Reduction axis (", axis, ") is out of bounds ",
+                              "(argument shape: ", input_shape,
+                              ", reduction axes: ", m_reduction_axes, ")");
     }
 
     for (size_t i = 0; i < size_t(input_rank); i++)
...
@@ -57,8 +63,9 @@ void op::util::LogicalReduction::validate_and_infer_types()
         result_shape = PartialShape(dims);
     }
 
-    NODE_VALIDATION_ASSERT(this, get_input_element_type(0).compatible(element::boolean))
-        << "Input element type must be boolean.";
+    NODE_VALIDATION_CHECK(this,
+                          get_input_element_type(0).compatible(element::boolean),
+                          "Input element type must be boolean.");
 
     set_output_type(0, element::boolean, result_shape);
 }
src/ngraph/runtime/cpu/op/conv_add.cpp
@@ -29,9 +29,14 @@ void op::util::validate_conv_shapes(const Node* node,
                                     const Shape& data_shape,
                                     const Shape& filters_shape)
 {
-    NODE_VALIDATION_ASSERT(node, data_shape[1] == filters_shape[1])
-        << "Number of channels for data and filters do not match (data num channels: "
-        << data_shape[1] << ", filters num channels: " << filters_shape[1] << ").";
+    NODE_VALIDATION_CHECK(node,
+                          data_shape[1] == filters_shape[1],
+                          "Number of channels for data and filters do not match (data num channels: ",
+                          data_shape[1], ", filters num channels: ", filters_shape[1], ").");
 }
 
 op::ConvolutionAdd::ConvolutionAdd(const std::shared_ptr<op::Convolution>& conv,
...
@@ -79,9 +84,14 @@ op::ConvolutionAdd::ConvolutionAdd(const std::shared_ptr<Node>& data_batch,
     //
     // Make sure data batch and filter element types match.
     //
-    NODE_VALIDATION_ASSERT(this, data_batch_et == filters_et)
-        << "Element types for data_batch and filters do not match (data batch element type: "
-        << data_batch_et << ", filters element type: " << filters_et << ").";
+    NODE_VALIDATION_CHECK(this,
+                          data_batch_et == filters_et,
+                          "Element types for data_batch and filters do not match (data batch element type: ",
+                          data_batch_et, ", filters element type: ", filters_et, ").");
 
     util::validate_conv_shapes(this, data_batch_shape, filters_shape);
     set_output_type(0,
...
@@ -105,8 +115,11 @@ op::ConvolutionAdd::ConvolutionAdd(const std::shared_ptr<Node>& data_batch,
 std::shared_ptr<Node> op::ConvolutionAdd::copy_with_new_args(const NodeVector& new_args) const
 {
-    NODE_VALIDATION_ASSERT(this, new_args.size() == 3)
-        << "New arg size is not 3 (new args size: " << new_args.size() << ").";
+    NODE_VALIDATION_CHECK(this,
+                          new_args.size() == 3,
+                          "New arg size is not 3 (new args size: ", new_args.size(), ").");
 
     return std::shared_ptr<Node>(new ConvolutionAdd(new_args.at(0),
                                                     new_args.at(1),
...
src/ngraph/runtime/cpu/op/update_slice.cpp
@@ -57,51 +57,85 @@ void op::UpdateSlice::validate_and_infer_types()
     const PartialShape& arg1_shape = get_input_partial_shape(1);
     Dimension merged_args_rank;
 
-    NODE_VALIDATION_ASSERT(this,
-                           Dimension::merge(merged_args_rank, arg0_shape.rank(), arg1_shape.rank()))
-        << "Argument ranks do not match (arg0 shape: " << arg0_shape << ", arg1 shape: "
-        << arg1_shape << ").";
+    NODE_VALIDATION_CHECK(this,
+                          Dimension::merge(merged_args_rank, arg0_shape.rank(), arg1_shape.rank()),
+                          "Argument ranks do not match (arg0 shape: ", arg0_shape,
+                          ", arg1 shape: ", arg1_shape, ").");
 
     element::Type arg0_et = get_input_element_type(0);
     element::Type arg1_et = get_input_element_type(1);
     element::Type merged_args_et;
 
-    NODE_VALIDATION_ASSERT(this, element::Type::merge(merged_args_et, arg0_et, arg1_et))
-        << "Argument element types do not match (arg0 element type: " << arg0_et
-        << ", arg1 element type: " << arg1_et << ").";
+    NODE_VALIDATION_CHECK(this,
+                          element::Type::merge(merged_args_et, arg0_et, arg1_et),
+                          "Argument element types do not match (arg0 element type: ", arg0_et,
+                          ", arg1 element type: ", arg1_et, ").");
 
-    NODE_VALIDATION_ASSERT(this,
-                           m_lower_bounds.size() == m_upper_bounds.size() &&
-                               m_lower_bounds.size() == m_strides.size())
-        << "Ranks of lower bounds (" << m_lower_bounds << "), upper bounds (" << m_upper_bounds
-        << ") and strides (" << m_strides << ") do not match.";
+    NODE_VALIDATION_CHECK(this,
+                          m_lower_bounds.size() == m_upper_bounds.size() &&
+                              m_lower_bounds.size() == m_strides.size(),
+                          "Ranks of lower bounds (", m_lower_bounds,
+                          "), upper bounds (", m_upper_bounds,
+                          ") and strides (", m_strides, ") do not match.");
 
     size_t output_rank = m_upper_bounds.size();
 
     for (size_t i = 0; i < output_rank; i++)
     {
-        NODE_VALIDATION_ASSERT(this, m_lower_bounds[i] <= m_upper_bounds[i])
-            << "Lower bound for slice is greater than upper bound at axis " << i
-            << " (lower bounds: " << m_lower_bounds << ", upper bounds: " << m_upper_bounds << ").";
+        NODE_VALIDATION_CHECK(this,
+                              m_lower_bounds[i] <= m_upper_bounds[i],
+                              "Lower bound for slice is greater than upper bound at axis ", i,
+                              " (lower bounds: ", m_lower_bounds,
+                              ", upper bounds: ", m_upper_bounds, ").");
 
-        NODE_VALIDATION_ASSERT(this, m_strides[i] != 0) << "Stride for slice is zero at axis " << i
-                                                        << " (strides: " << m_strides << ").";
+        NODE_VALIDATION_CHECK(this,
+                              m_strides[i] != 0,
+                              "Stride for slice is zero at axis ", i,
+                              " (strides: ", m_strides, ").");
     }
 
-    NODE_VALIDATION_ASSERT(this,
-                           merged_args_rank.is_dynamic() || size_t(merged_args_rank) == output_rank)
-        << "Argument ranks do not match the rank of the lower bounds (" << m_lower_bounds
-        << "), upper bounds (" << m_upper_bounds << "), and strides (" << m_strides << ").";
+    NODE_VALIDATION_CHECK(this,
+                          merged_args_rank.is_dynamic() || size_t(merged_args_rank) == output_rank,
+                          "Argument ranks do not match the rank of the lower bounds (", m_lower_bounds,
+                          "), upper bounds (", m_upper_bounds,
+                          "), and strides (", m_strides, ").");
 
     std::vector<Dimension> sliced_dims(output_rank);
 
     for (size_t i = 0; i < output_rank; i++)
     {
-        NODE_VALIDATION_ASSERT(this,
-                               arg0_shape.rank().is_dynamic() || arg0_shape[i].is_dynamic() ||
-                                   m_upper_bounds[i] <= size_t(arg0_shape[i]))
-            << "Upper bound for slice at axis " << i << " is out of range "
-            << "(upper bounds: " << m_upper_bounds << ", argument shape: " << arg0_shape << ").";
+        NODE_VALIDATION_CHECK(this,
+                              arg0_shape.rank().is_dynamic() || arg0_shape[i].is_dynamic() ||
+                                  m_upper_bounds[i] <= size_t(arg0_shape[i]),
+                              "Upper bound for slice at axis ", i, " is out of range ",
+                              "(upper bounds: ", m_upper_bounds,
+                              ", argument shape: ", arg0_shape, ").");
 
         size_t sliced_dim = m_upper_bounds[i] - m_lower_bounds[i];
         sliced_dim = sliced_dim / m_strides[i] + ((sliced_dim % m_strides[i] == 0) ? 0 : 1);
...
@@ -110,9 +144,14 @@ void op::UpdateSlice::validate_and_infer_types()
     PartialShape slice_shape{sliced_dims};
 
-    NODE_VALIDATION_ASSERT(this, arg1_shape.compatible(slice_shape))
-        << "Shape of replacement tensor (" << arg1_shape << ") does not match the slice shape "
-        << "(" << slice_shape << ").";
+    NODE_VALIDATION_CHECK(this,
+                          arg1_shape.compatible(slice_shape),
+                          "Shape of replacement tensor (", arg1_shape,
+                          ") does not match the slice shape ", "(", slice_shape, ").");
 
     // Slight corner case here: if arg0 was rank-unknown, we can go ahead and set the output rank
     // because the attribs will have given us enough info.
...
test/type_prop.cpp
(This diff is collapsed and not shown here.)