Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
066037c2
Commit
066037c2
authored
Mar 05, 2019
by
Adam Procter
Committed by
Scott Cyphers
Mar 05, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Switch everything from NODE_VALIDATION_ASSERT to NODE_VALIDATION_CHECK (#2546)
parent
dd23b0cb
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
36 changed files
with
887 additions
and
444 deletions
+887
-444
node.hpp
src/ngraph/node.hpp
+1
-21
allreduce.cpp
src/ngraph/op/allreduce.cpp
+5
-4
avg_pool.cpp
src/ngraph/op/avg_pool.cpp
+9
-3
batch_norm.cpp
src/ngraph/op/batch_norm.cpp
+15
-10
broadcast.cpp
src/ngraph/op/broadcast.cpp
+21
-7
concat.cpp
src/ngraph/op/concat.cpp
+23
-13
constant.hpp
src/ngraph/op/constant.hpp
+21
-9
convolution.cpp
src/ngraph/op/convolution.cpp
+0
-0
dequantize.cpp
src/ngraph/op/dequantize.cpp
+57
-28
dot.cpp
src/ngraph/op/dot.cpp
+41
-20
embedding_lookup.cpp
src/ngraph/op/embedding_lookup.cpp
+4
-3
generate_mask.cpp
src/ngraph/op/experimental/generate_mask.cpp
+5
-4
quantized_avg_pool.cpp
src/ngraph/op/experimental/quantized_avg_pool.cpp
+89
-45
quantized_concat.cpp
src/ngraph/op/experimental/quantized_concat.cpp
+22
-13
quantized_max_pool.cpp
src/ngraph/op/experimental/quantized_max_pool.cpp
+71
-34
get_output_element.cpp
src/ngraph/op/get_output_element.cpp
+7
-3
lrn.cpp
src/ngraph/op/lrn.cpp
+6
-3
max_pool.cpp
src/ngraph/op/max_pool.cpp
+16
-6
one_hot.cpp
src/ngraph/op/one_hot.cpp
+24
-11
pad.cpp
src/ngraph/op/pad.cpp
+36
-18
quantize.cpp
src/ngraph/op/quantize.cpp
+57
-28
replace_slice.cpp
src/ngraph/op/replace_slice.cpp
+67
-28
reshape.cpp
src/ngraph/op/reshape.cpp
+34
-15
result.cpp
src/ngraph/op/result.cpp
+2
-2
reverse.cpp
src/ngraph/op/reverse.cpp
+7
-3
reverse_sequence.cpp
src/ngraph/op/reverse_sequence.cpp
+32
-14
select.cpp
src/ngraph/op/select.cpp
+15
-11
slice.cpp
src/ngraph/op/slice.cpp
+45
-17
softmax.cpp
src/ngraph/op/softmax.cpp
+7
-3
topk.cpp
src/ngraph/op/topk.cpp
+25
-15
arithmetic_reduction.cpp
src/ngraph/op/util/arithmetic_reduction.cpp
+10
-4
index_reduction.cpp
src/ngraph/op/util/index_reduction.cpp
+12
-7
logical_reduction.cpp
src/ngraph/op/util/logical_reduction.cpp
+13
-6
conv_add.cpp
src/ngraph/runtime/cpu/op/conv_add.cpp
+21
-8
update_slice.cpp
src/ngraph/runtime/cpu/op/update_slice.cpp
+67
-28
type_prop.cpp
test/type_prop.cpp
+0
-0
No files found.
src/ngraph/node.hpp
View file @
066037c2
...
...
@@ -275,19 +275,6 @@ namespace ngraph
size_t
m_placement_index
=
placement_invalid
;
};
class
NodeValidationError
:
public
AssertionFailure
{
public
:
NodeValidationError
(
std
::
string
what
)
:
AssertionFailure
(
what
)
{
}
NodeValidationError
(
const
char
*
what
)
:
AssertionFailure
(
what
)
{
}
};
class
NodeValidationFailure
:
public
CheckFailure
{
public
:
...
...
@@ -321,12 +308,5 @@ namespace ngraph
void
check_new_args_count
(
const
Node
*
node
,
const
NodeVector
&
new_args
);
}
// namespace ngraph
#define NODE_VALIDATION_ASSERT(node, cond) \
NGRAPH_ASSERT_STREAM_WITH_LOC( \
::ngraph::NodeValidationError, cond, ::ngraph::node_validation_assertion_string(node))
#define NODE_VALIDATION_FAIL(node) \
NGRAPH_FAIL_STREAM_WITH_LOC(::ngraph::NodeValidationError, \
::ngraph::node_validation_assertion_string(node))
#define NODE_VALIDATION_CHECK(node, cond, ...) \
NGRAPH_CHECK(::NodeValidationFailure, (node), (cond), __VA_ARGS__)
NGRAPH_CHECK(::
ngraph::
NodeValidationFailure, (node), (cond), __VA_ARGS__)
src/ngraph/op/allreduce.cpp
View file @
066037c2
...
...
@@ -27,12 +27,13 @@ op::AllReduce::AllReduce(const shared_ptr<Node>& arg)
void
op
::
AllReduce
::
validate_and_infer_types
()
{
NODE_VALIDATION_
ASSERT
(
this
,
NODE_VALIDATION_
CHECK
(
this
,
get_input_element_type
(
0
).
is_dynamic
()
||
get_input_element_type
(
0
)
==
element
::
f32
||
get_input_element_type
(
0
)
==
element
::
f64
)
<<
"Only element types f32 and f64 are supported (argument element type: "
<<
get_input_element_type
(
0
)
<<
")."
;
get_input_element_type
(
0
)
==
element
::
f64
,
"Only element types f32 and f64 are supported (argument element type: "
,
get_input_element_type
(
0
),
")."
);
set_output_type
(
0
,
get_input_element_type
(
0
),
get_input_partial_shape
(
0
));
}
...
...
src/ngraph/op/avg_pool.cpp
View file @
066037c2
...
...
@@ -132,9 +132,15 @@ void op::AvgPoolBackprop::validate_and_infer_types()
const
PartialShape
&
delta_shape
=
get_input_partial_shape
(
0
);
NODE_VALIDATION_ASSERT
(
this
,
forward_result_shape
.
compatible
(
delta_shape
))
<<
"Inferred forward output shape does not match delta shape (inferred forward output "
<<
"shape: "
<<
forward_result_shape
<<
", delta shape: "
<<
delta_shape
<<
")."
;
NODE_VALIDATION_CHECK
(
this
,
forward_result_shape
.
compatible
(
delta_shape
),
"Inferred forward output shape does not match delta shape (inferred forward output "
,
"shape: "
,
forward_result_shape
,
", delta shape: "
,
delta_shape
,
")."
);
// TODO(amprocte): Once m_forward_arg_shape is allowed to be dynamic, we may technically be
// able to infer some extra information from forward_result_shape that was not present in the
...
...
src/ngraph/op/batch_norm.cpp
View file @
066037c2
...
...
@@ -205,21 +205,26 @@ void ngraph::op::BatchNormTrainingBackprop::validate_and_infer_types()
{
PartialShape
input_and_delta_shape
{
get_input_partial_shape
(
INPUT_DATA
)};
NODE_VALIDATION_ASSERT
(
this
,
PartialShape
::
merge_into
(
input_and_delta_shape
,
get_input_partial_shape
(
INPUT_DELTA
)))
<<
"Shape of delta does not match the shape of the input data (input data shape: "
<<
get_input_partial_shape
(
INPUT_DATA
)
<<
", delta shape: "
<<
get_input_partial_shape
(
INPUT_DELTA
)
<<
")."
;
NODE_VALIDATION_CHECK
(
this
,
PartialShape
::
merge_into
(
input_and_delta_shape
,
get_input_partial_shape
(
INPUT_DELTA
)),
"Shape of delta does not match the shape of the input data (input data shape: "
,
get_input_partial_shape
(
INPUT_DATA
),
", delta shape: "
,
get_input_partial_shape
(
INPUT_DELTA
),
")."
);
element
::
Type
input_and_delta_et
;
NODE_VALIDATION_
ASSERT
(
this
,
NODE_VALIDATION_
CHECK
(
this
,
element
::
Type
::
merge
(
input_and_delta_et
,
get_input_element_type
(
INPUT_DATA
),
get_input_element_type
(
INPUT_DELTA
)))
<<
"Element type for input ("
<<
get_input_element_type
(
INPUT_DATA
)
<<
") does not match element type for delta ("
<<
get_input_element_type
(
INPUT_DATA
)
<<
")."
;
get_input_element_type
(
INPUT_DELTA
)),
"Element type for input ("
,
get_input_element_type
(
INPUT_DATA
),
") does not match element type for delta ("
,
get_input_element_type
(
INPUT_DATA
),
")."
);
element
::
Type
result_et
;
PartialShape
result_batch_shape
;
...
...
src/ngraph/op/broadcast.cpp
View file @
066037c2
...
...
@@ -44,9 +44,16 @@ void op::Broadcast::validate_and_infer_types()
for
(
auto
axis
:
m_broadcast_axes
)
{
NODE_VALIDATION_ASSERT
(
this
,
axis
<
m_shape
.
size
())
<<
"Broadcast axis index ("
<<
axis
<<
") exceeds specified output shape rank "
<<
"(broadcast axes: "
<<
m_broadcast_axes
<<
", output shape: "
<<
m_shape
<<
")."
;
NODE_VALIDATION_CHECK
(
this
,
axis
<
m_shape
.
size
(),
"Broadcast axis index ("
,
axis
,
") exceeds specified output shape rank "
,
"(broadcast axes: "
,
m_broadcast_axes
,
", output shape: "
,
m_shape
,
")."
);
}
Shape
required_input_shape
=
m_shape
;
...
...
@@ -59,10 +66,17 @@ void op::Broadcast::validate_and_infer_types()
// There are two things that can go wrong, which are being picked up in
// one fell swoop by this check: either the number of broadcast axes is not
// enough, or there is a mismatch with one of the pre-broadcast axis lengths.
NODE_VALIDATION_ASSERT
(
this
,
get_input_partial_shape
(
0
).
compatible
(
required_input_shape
))
<<
"Broadcast argument shape, specified output shape, and axes are incompatible "
<<
"(argument shape: "
<<
get_input_partial_shape
(
0
)
<<
", output shape: "
<<
m_shape
<<
", broadcast axes: "
<<
m_broadcast_axes
<<
")."
;
NODE_VALIDATION_CHECK
(
this
,
get_input_partial_shape
(
0
).
compatible
(
required_input_shape
),
"Broadcast argument shape, specified output shape, and axes are incompatible "
,
"(argument shape: "
,
get_input_partial_shape
(
0
),
", output shape: "
,
m_shape
,
", broadcast axes: "
,
m_broadcast_axes
,
")."
);
set_output_type
(
0
,
get_input_element_type
(
0
),
m_shape
);
}
...
...
src/ngraph/op/concat.cpp
View file @
066037c2
...
...
@@ -32,7 +32,7 @@ op::Concat::Concat(const NodeVector& args, size_t concatenation_axis)
void
op
::
Concat
::
validate_and_infer_types
()
{
NODE_VALIDATION_
ASSERT
(
this
,
m_inputs
.
size
()
>=
1
)
<<
"At least one argument required."
;
NODE_VALIDATION_
CHECK
(
this
,
m_inputs
.
size
()
>=
1
,
"At least one argument required."
)
;
PartialShape
inputs_shape_scheme
{
PartialShape
::
dynamic
()};
element
::
Type
inputs_et
{
element
::
dynamic
};
...
...
@@ -44,22 +44,32 @@ void op::Concat::validate_and_infer_types()
Dimension
this_input_rank
=
this_input_shape
.
rank
();
if
(
this_input_rank
.
is_static
())
{
NODE_VALIDATION_ASSERT
(
this
,
m_concatenation_axis
<
size_t
(
this_input_rank
))
<<
"Concatenation axis ("
<<
m_concatenation_axis
<<
") is out of bounds for "
<<
"argument "
<<
i
<<
", which has shape "
<<
this_input_shape
<<
"."
;
NODE_VALIDATION_CHECK
(
this
,
m_concatenation_axis
<
size_t
(
this_input_rank
),
"Concatenation axis ("
,
m_concatenation_axis
,
") is out of bounds for "
,
"argument "
,
i
,
", which has shape "
,
this_input_shape
,
"."
);
concatenation_axis_output_dim
+=
this_input_shape
[
m_concatenation_axis
];
this_input_shape
[
m_concatenation_axis
]
=
Dimension
::
dynamic
();
NODE_VALIDATION_ASSERT
(
this
,
PartialShape
::
merge_into
(
inputs_shape_scheme
,
this_input_shape
))
<<
"Argument shapes are inconsistent; they must have the same rank, and must have "
<<
"equal dimension everywhere except on the concatenation axis (axis "
<<
m_concatenation_axis
<<
")."
;
NODE_VALIDATION_ASSERT
(
this
,
element
::
Type
::
merge
(
inputs_et
,
inputs_et
,
get_input_element_type
(
i
)))
<<
"Argument element types are inconsistent."
;
NODE_VALIDATION_CHECK
(
this
,
PartialShape
::
merge_into
(
inputs_shape_scheme
,
this_input_shape
),
"Argument shapes are inconsistent; they must have the same rank, and must have "
,
"equal dimension everywhere except on the concatenation axis (axis "
,
m_concatenation_axis
,
")."
);
NODE_VALIDATION_CHECK
(
this
,
element
::
Type
::
merge
(
inputs_et
,
inputs_et
,
get_input_element_type
(
i
)),
"Argument element types are inconsistent."
);
}
else
{
...
...
src/ngraph/op/constant.hpp
View file @
066037c2
...
...
@@ -47,11 +47,17 @@ namespace ngraph
,
m_data
(
ngraph
::
aligned_alloc
(
m_element_type
.
size
(),
shape_size
(
m_shape
)
*
m_element_type
.
size
()))
{
NODE_VALIDATION_ASSERT
(
this
,
values
.
size
()
==
1
||
values
.
size
()
==
shape_size
(
m_shape
))
<<
"Did not get the expected number of literals for a constant of shape "
<<
m_shape
<<
" (got "
<<
values
.
size
()
<<
", expected "
<<
(
shape_size
(
m_shape
)
==
1
?
""
:
"1 or "
)
<<
shape_size
(
m_shape
)
<<
")."
;
NODE_VALIDATION_CHECK
(
this
,
values
.
size
()
==
1
||
values
.
size
()
==
shape_size
(
m_shape
),
"Did not get the expected number of literals for a constant of shape "
,
m_shape
,
" (got "
,
values
.
size
(),
", expected "
,
(
shape_size
(
m_shape
)
==
1
?
""
:
"1 or "
),
shape_size
(
m_shape
),
")."
);
if
(
values
.
size
()
==
1
)
{
...
...
@@ -77,10 +83,16 @@ namespace ngraph
,
m_data
(
ngraph
::
aligned_alloc
(
m_element_type
.
size
(),
shape_size
(
m_shape
)
*
m_element_type
.
size
()))
{
NODE_VALIDATION_ASSERT
(
this
,
values
.
size
()
==
shape_size
(
m_shape
))
<<
"Did not get the expected number of literals for a constant of shape "
<<
m_shape
<<
" (got "
<<
values
.
size
()
<<
", expected "
<<
shape_size
(
m_shape
)
<<
"."
;
NODE_VALIDATION_CHECK
(
this
,
values
.
size
()
==
shape_size
(
m_shape
),
"Did not get the expected number of literals for a constant of shape "
,
m_shape
,
" (got "
,
values
.
size
(),
", expected "
,
shape_size
(
m_shape
),
"."
);
std
::
vector
<
double
>
dvalues
=
parse_string
<
double
>
(
values
);
write_values
(
dvalues
);
...
...
src/ngraph/op/convolution.cpp
View file @
066037c2
This diff is collapsed.
Click to expand it.
src/ngraph/op/dequantize.cpp
View file @
066037c2
...
...
@@ -42,50 +42,73 @@ void op::Dequantize::validate_and_infer_types()
OFFSET
};
NODE_VALIDATION_
ASSERT
(
this
,
m_type
.
is_static
())
<<
"Output element type must not be dynamic"
;
NODE_VALIDATION_
CHECK
(
this
,
m_type
.
is_static
(),
"Output element type must not be dynamic"
)
;
NODE_VALIDATION_
ASSERT
(
this
,
m_type
.
is_real
())
<<
"Output element type ("
<<
m_type
<<
") must be a floating point type"
;
NODE_VALIDATION_
CHECK
(
this
,
m_type
.
is_real
(),
"Output element type ("
,
m_type
,
") must be a floating point type"
)
;
element
::
Type
quantized_type
;
NODE_VALIDATION_
ASSERT
(
this
,
NODE_VALIDATION_
CHECK
(
this
,
element
::
Type
::
merge
(
quantized_type
,
get_input_element_type
(
INPUT
),
get_input_element_type
(
OFFSET
)))
<<
"Offset element type ("
<<
get_input_element_type
(
OFFSET
)
<<
") must match input element type ("
<<
get_input_element_type
(
INPUT
)
<<
")"
;
get_input_element_type
(
OFFSET
)),
"Offset element type ("
,
get_input_element_type
(
OFFSET
),
") must match input element type ("
,
get_input_element_type
(
INPUT
),
")"
);
NODE_VALIDATION_ASSERT
(
this
,
quantized_type
.
is_dynamic
()
||
quantized_type
.
is_quantized
())
<<
"Offset/input element type ("
<<
quantized_type
<<
") must be a quantized type"
;
NODE_VALIDATION_CHECK
(
this
,
quantized_type
.
is_dynamic
()
||
quantized_type
.
is_quantized
(),
"Offset/input element type ("
,
quantized_type
,
") must be a quantized type"
);
element
::
Type
unquantized_type
;
NODE_VALIDATION_ASSERT
(
this
,
element
::
Type
::
merge
(
unquantized_type
,
get_input_element_type
(
SCALE
),
m_type
))
<<
"Scale element type ("
<<
get_input_element_type
(
SCALE
)
<<
") must match output element type ("
<<
m_type
<<
")"
;
NODE_VALIDATION_CHECK
(
this
,
element
::
Type
::
merge
(
unquantized_type
,
get_input_element_type
(
SCALE
),
m_type
),
"Scale element type ("
,
get_input_element_type
(
SCALE
),
") must match output element type ("
,
m_type
,
")"
);
PartialShape
input_shape
=
get_input_partial_shape
(
0
);
Dimension
input_rank
=
input_shape
.
rank
();
for
(
auto
axis
:
m_axes
)
{
NODE_VALIDATION_ASSERT
(
this
,
input_rank
.
is_dynamic
()
||
axis
<
size_t
(
input_rank
))
<<
"Quantization axis ("
<<
axis
<<
") must be less than input shape rank ("
<<
input_rank
<<
")"
;
NODE_VALIDATION_CHECK
(
this
,
input_rank
.
is_dynamic
()
||
axis
<
size_t
(
input_rank
),
"Quantization axis ("
,
axis
,
") must be less than input shape rank ("
,
input_rank
,
")"
);
}
PartialShape
scale_offset_shape
=
get_input_partial_shape
(
SCALE
);
NODE_VALIDATION_ASSERT
(
this
,
PartialShape
::
merge_into
(
scale_offset_shape
,
get_input_partial_shape
(
OFFSET
)))
<<
"Scale shape ("
<<
get_input_partial_shape
(
SCALE
)
<<
") and offset shape ("
<<
get_input_partial_shape
(
OFFSET
)
<<
") must match"
;
NODE_VALIDATION_ASSERT
(
this
,
scale_offset_shape
.
rank
().
compatible
(
m_axes
.
size
()))
<<
"Scale/offset rank ("
<<
scale_offset_shape
.
rank
()
<<
") does not match the number of "
<<
"quantization axes ("
<<
m_axes
.
size
()
<<
")"
;
NODE_VALIDATION_CHECK
(
this
,
PartialShape
::
merge_into
(
scale_offset_shape
,
get_input_partial_shape
(
OFFSET
)),
"Scale shape ("
,
get_input_partial_shape
(
SCALE
),
") and offset shape ("
,
get_input_partial_shape
(
OFFSET
),
") must match"
);
NODE_VALIDATION_CHECK
(
this
,
scale_offset_shape
.
rank
().
compatible
(
m_axes
.
size
()),
"Scale/offset rank ("
,
scale_offset_shape
.
rank
(),
") does not match the number of "
,
"quantization axes ("
,
m_axes
.
size
(),
")"
);
set_output_size
(
1
);
...
...
@@ -108,10 +131,16 @@ void op::Dequantize::validate_and_infer_types()
}
PartialShape
result_shape
=
input_shape
;
NODE_VALIDATION_ASSERT
(
this
,
PartialShape
::
merge_into
(
result_shape
,
PartialShape
{
injected_scale_offset_dims
}))
<<
"Scale/offset shape ("
<<
scale_offset_shape
<<
") must match input shape ("
<<
input_shape
<<
") at the quantization axes ("
<<
m_axes
<<
")"
;
NODE_VALIDATION_CHECK
(
this
,
PartialShape
::
merge_into
(
result_shape
,
PartialShape
{
injected_scale_offset_dims
}),
"Scale/offset shape ("
,
scale_offset_shape
,
") must match input shape ("
,
input_shape
,
") at the quantization axes ("
,
m_axes
,
")"
);
set_output_type
(
0
,
unquantized_type
,
result_shape
);
}
else
...
...
src/ngraph/op/dot.cpp
View file @
066037c2
...
...
@@ -49,11 +49,14 @@ void op::Dot::validate_and_infer_types()
{
element
::
Type
result_et
;
NODE_VALIDATION_ASSERT
(
this
,
element
::
Type
::
merge
(
result_et
,
get_input_element_type
(
0
),
get_input_element_type
(
1
)))
<<
"Arguments do not have the same element type (arg0 element type: "
<<
get_input_element_type
(
0
)
<<
", arg1 element type: "
<<
get_input_element_type
(
1
)
<<
")."
;
NODE_VALIDATION_CHECK
(
this
,
element
::
Type
::
merge
(
result_et
,
get_input_element_type
(
0
),
get_input_element_type
(
1
)),
"Arguments do not have the same element type (arg0 element type: "
,
get_input_element_type
(
0
),
", arg1 element type: "
,
get_input_element_type
(
1
),
")."
);
const
PartialShape
&
arg0_shape
=
get_input_partial_shape
(
0
);
const
PartialShape
&
arg1_shape
=
get_input_partial_shape
(
1
);
...
...
@@ -82,17 +85,27 @@ void op::Dot::validate_and_infer_types()
PartialShape
result_shape
;
NODE_VALIDATION_
ASSERT
(
this
,
NODE_VALIDATION_
CHECK
(
this
,
reduction_axes_ambiguous
||
arg0_shape
.
rank
().
is_dynamic
()
||
m_reduction_axes_count
<=
size_t
(
arg0_shape
.
rank
()))
<<
"Reduction axes count ("
<<
m_reduction_axes_count
<<
") is too large (arg0 shape: "
<<
arg0_shape
<<
", arg1 shape: "
<<
arg1_shape
<<
")."
;
NODE_VALIDATION_ASSERT
(
this
,
m_reduction_axes_count
<=
size_t
(
arg0_shape
.
rank
()),
"Reduction axes count ("
,
m_reduction_axes_count
,
") is too large (arg0 shape: "
,
arg0_shape
,
", arg1 shape: "
,
arg1_shape
,
")."
);
NODE_VALIDATION_CHECK
(
this
,
reduction_axes_ambiguous
||
arg1_shape
.
rank
().
is_dynamic
()
||
m_reduction_axes_count
<=
size_t
(
arg1_shape
.
rank
()))
<<
"Reduction axes count ("
<<
m_reduction_axes_count
<<
") is too large (arg0 shape: "
<<
arg0_shape
<<
", arg1 shape: "
<<
arg1_shape
<<
")."
;
m_reduction_axes_count
<=
size_t
(
arg1_shape
.
rank
()),
"Reduction axes count ("
,
m_reduction_axes_count
,
") is too large (arg0 shape: "
,
arg0_shape
,
", arg1 shape: "
,
arg1_shape
,
")."
);
if
(
!
reduction_axes_ambiguous
&&
arg0_shape
.
rank
().
is_static
()
&&
arg1_shape
.
rank
().
is_static
())
{
...
...
@@ -101,12 +114,20 @@ void op::Dot::validate_and_infer_types()
size_t
axis_index_arg0
=
size_t
(
arg0_shape
.
rank
())
-
m_reduction_axes_count
+
i
;
size_t
axis_index_arg1
=
i
;
NODE_VALIDATION_ASSERT
(
this
,
arg0_shape
[
axis_index_arg0
].
compatible
(
arg1_shape
[
axis_index_arg1
]))
<<
"Paired axes (axis "
<<
axis_index_arg0
<<
" from arg0, axis "
<<
axis_index_arg1
<<
" from arg1) do not have same length (arg0 shape: "
<<
arg0_shape
<<
", arg1 shape: "
<<
arg1_shape
<<
", reduction axes count: "
<<
m_reduction_axes_count
<<
")."
;
NODE_VALIDATION_CHECK
(
this
,
arg0_shape
[
axis_index_arg0
].
compatible
(
arg1_shape
[
axis_index_arg1
]),
"Paired axes (axis "
,
axis_index_arg0
,
" from arg0, axis "
,
axis_index_arg1
,
" from arg1) do not have same length (arg0 shape: "
,
arg0_shape
,
", arg1 shape: "
,
arg1_shape
,
", reduction axes count: "
,
m_reduction_axes_count
,
")."
);
}
std
::
vector
<
Dimension
>
result_dims
(
size_t
(
arg0_shape
.
rank
())
+
size_t
(
arg1_shape
.
rank
())
-
...
...
src/ngraph/op/embedding_lookup.cpp
View file @
066037c2
...
...
@@ -26,9 +26,10 @@ void op::EmbeddingLookup::validate_and_infer_types()
const
PartialShape
&
arg0_shape
=
get_input_partial_shape
(
0
);
const
PartialShape
&
arg1_shape
=
get_input_partial_shape
(
1
);
NODE_VALIDATION_ASSERT
(
this
,
arg1_shape
.
rank
().
is_dynamic
()
||
static_cast
<
size_t
>
(
arg1_shape
.
rank
())
==
2
)
<<
"weights are expected to be a matrix"
;
NODE_VALIDATION_CHECK
(
this
,
arg1_shape
.
rank
().
is_dynamic
()
||
static_cast
<
size_t
>
(
arg1_shape
.
rank
())
==
2
,
"weights are expected to be a matrix"
);
PartialShape
result_shape
;
if
(
arg0_shape
.
rank
().
is_static
())
...
...
src/ngraph/op/experimental/generate_mask.cpp
View file @
066037c2
...
...
@@ -42,11 +42,12 @@ shared_ptr<Node> op::GenerateMask::copy_with_new_args(const NodeVector& new_args
void
ngraph
::
op
::
GenerateMask
::
validate_and_infer_types
()
{
NODE_VALIDATION_ASSERT
(
this
,
get_input_partial_shape
(
0
).
compatible
(
PartialShape
{}))
<<
"Training node should be a scalar flag indicating a mode"
;
NODE_VALIDATION_CHECK
(
this
,
get_input_partial_shape
(
0
).
compatible
(
PartialShape
{}),
"Training node should be a scalar flag indicating a mode"
);
NODE_VALIDATION_
ASSERT
(
this
,
m_element_type
.
is_static
())
<<
"Output element type must not be dynamic."
;
NODE_VALIDATION_
CHECK
(
this
,
m_element_type
.
is_static
(),
"Output element type must not be dynamic."
)
;
set_output_type
(
0
,
m_element_type
,
m_shape
);
}
src/ngraph/op/experimental/quantized_avg_pool.cpp
View file @
066037c2
...
...
@@ -65,36 +65,58 @@ void op::QuantizedAvgPool::validate_and_infer_types()
// Make sure batch size and channel count are not zero, and that we have at least one spatial
// dimension (in other words, that arg has shape NCDi for some Di of rank>0, N != 0, C != 0).
//
NODE_VALIDATION_ASSERT
(
this
,
arg_shape
.
size
()
>=
3
)
<<
"Data input shape does not have rank of at least 3 (data input shape: "
<<
arg_shape
<<
")."
;
NODE_VALIDATION_CHECK
(
this
,
arg_shape
.
size
()
>=
3
,
"Data input shape does not have rank of at least 3 (data input shape: "
,
arg_shape
,
")."
);
size_t
batch_size
=
arg_shape
[
0
];
NODE_VALIDATION_
ASSERT
(
this
,
batch_size
!=
0
)
<<
"Data batch size is zero (data input shape: "
<<
arg_shape
<<
")."
;
NODE_VALIDATION_
CHECK
(
this
,
batch_size
!=
0
,
"Data batch size is zero (data input shape: "
,
arg_shape
,
")."
)
;
size_t
channel_count
=
arg_shape
[
1
];
NODE_VALIDATION_
ASSERT
(
this
,
channel_count
!=
0
)
<<
"Channel count is zero (data input shape: "
<<
arg_shape
<<
")."
;
NODE_VALIDATION_
CHECK
(
this
,
channel_count
!=
0
,
"Channel count is zero (data input shape: "
,
arg_shape
,
")."
)
;
size_t
spatial_dimension_count
=
arg_shape
.
size
()
-
2
;
//
// Make sure window shape, window movement strides, and padding have same rank as Di.
//
NODE_VALIDATION_ASSERT
(
this
,
m_window_shape
.
size
()
==
spatial_dimension_count
)
<<
"Window shape rank does not match number of spatial dimensions (window shape: "
<<
m_window_shape
<<
", data input shape: "
<<
arg_shape
<<
")."
;
NODE_VALIDATION_ASSERT
(
this
,
m_window_movement_strides
.
size
()
==
spatial_dimension_count
)
<<
"Window movement stride rank does not match number of spatial dimensions (window "
"movement strides: "
<<
m_window_movement_strides
<<
", data input shape: "
<<
arg_shape
<<
")."
;
NODE_VALIDATION_ASSERT
(
this
,
m_padding_below
.
size
()
==
spatial_dimension_count
)
<<
"Below-padding rank does not match number of spatial dimensions (padding below: "
<<
m_padding_below
<<
", data input shape: "
<<
arg_shape
<<
")."
;
NODE_VALIDATION_ASSERT
(
this
,
m_padding_above
.
size
()
==
spatial_dimension_count
)
<<
"Above-padding rank does not match number of spatial dimensions (padding above: "
<<
m_padding_above
<<
", data input shape: "
<<
arg_shape
<<
")."
;
NODE_VALIDATION_CHECK
(
this
,
m_window_shape
.
size
()
==
spatial_dimension_count
,
"Window shape rank does not match number of spatial dimensions (window shape: "
,
m_window_shape
,
", data input shape: "
,
arg_shape
,
")."
);
NODE_VALIDATION_CHECK
(
this
,
m_window_movement_strides
.
size
()
==
spatial_dimension_count
,
"Window movement stride rank does not match number of spatial dimensions (window "
"movement strides: "
,
m_window_movement_strides
,
", data input shape: "
,
arg_shape
,
")."
);
NODE_VALIDATION_CHECK
(
this
,
m_padding_below
.
size
()
==
spatial_dimension_count
,
"Below-padding rank does not match number of spatial dimensions (padding below: "
,
m_padding_below
,
", data input shape: "
,
arg_shape
,
")."
);
NODE_VALIDATION_CHECK
(
this
,
m_padding_above
.
size
()
==
spatial_dimension_count
,
"Above-padding rank does not match number of spatial dimensions (padding above: "
,
m_padding_above
,
", data input shape: "
,
arg_shape
,
")."
);
//
// Extract input item shape Di and make sure all dimensions are larger than 0.
...
...
@@ -110,10 +132,13 @@ void op::QuantizedAvgPool::validate_and_infer_types()
for
(
size_t
i
=
0
;
i
<
spatial_dimension_count
;
i
++
)
{
NODE_VALIDATION_ASSERT
(
this
,
input_item_virtual_shape
[
i
]
!=
0
)
<<
"Data input spatial dimension "
<<
i
<<
" has zero length even after padding (virtual shape of input item: "
<<
input_item_virtual_shape
<<
")."
;
NODE_VALIDATION_CHECK
(
this
,
input_item_virtual_shape
[
i
]
!=
0
,
"Data input spatial dimension "
,
i
,
" has zero length even after padding (virtual shape of input item: "
,
input_item_virtual_shape
,
")."
);
}
//
...
...
@@ -121,9 +146,13 @@ void op::QuantizedAvgPool::validate_and_infer_types()
//
for
(
size_t
i
=
0
;
i
<
spatial_dimension_count
;
i
++
)
{
NODE_VALIDATION_ASSERT
(
this
,
m_window_shape
[
i
]
!=
0
)
<<
"Window shape dimension "
<<
i
<<
" has zero length (window shape: "
<<
m_window_shape
<<
")."
;
NODE_VALIDATION_CHECK
(
this
,
m_window_shape
[
i
]
!=
0
,
"Window shape dimension "
,
i
,
" has zero length (window shape: "
,
m_window_shape
,
")."
);
}
//
...
...
@@ -131,10 +160,14 @@ void op::QuantizedAvgPool::validate_and_infer_types()
//
for
(
size_t
i
=
0
;
i
<
spatial_dimension_count
;
i
++
)
{
NODE_VALIDATION_ASSERT
(
this
,
m_window_shape
[
i
]
<=
input_item_virtual_shape
[
i
])
<<
"Window shape after padding is larger than the spatial dimensions (window shape: "
<<
m_window_shape
<<
", virtual shape of input item: "
<<
input_item_virtual_shape
<<
")."
;
NODE_VALIDATION_CHECK
(
this
,
m_window_shape
[
i
]
<=
input_item_virtual_shape
[
i
],
"Window shape after padding is larger than the spatial dimensions (window shape: "
,
m_window_shape
,
", virtual shape of input item: "
,
input_item_virtual_shape
,
")."
);
}
//
// Compute output item shape Do, checking at the same time that all window movement strides are larger than 0.
...
...
@@ -143,9 +176,13 @@ void op::QuantizedAvgPool::validate_and_infer_types()
for
(
size_t
i
=
0
;
i
<
spatial_dimension_count
;
i
++
)
{
NODE_VALIDATION_ASSERT
(
this
,
m_window_movement_strides
[
i
]
!=
0
)
<<
"Window movement strides dimension "
<<
i
<<
" has zero length (window movement strides: "
<<
m_window_movement_strides
<<
")."
;
NODE_VALIDATION_CHECK
(
this
,
m_window_movement_strides
[
i
]
!=
0
,
"Window movement strides dimension "
,
i
,
" has zero length (window movement strides: "
,
m_window_movement_strides
,
")."
);
output_item_shape
.
push_back
(
ceil_div
(
input_item_virtual_shape
[
i
]
-
m_window_shape
[
i
]
+
1
,
m_window_movement_strides
[
i
]));
}
...
...
@@ -167,11 +204,15 @@ void op::QuantizedAvgPool::validate_and_infer_types()
// Checking the lower edge of each dimension is easy, because there's no mystery
// regarding the window's lower-edge placement...
NODE_VALIDATION_ASSERT
(
this
,
dim_padding_below
==
0
||
dim_window_size
>
dim_padding_below
)
<<
"Window will sometimes reside entirely within the below-padding region, but"
<<
" include_padding_in_avg_computation was not set (padding below: "
<<
m_padding_below
<<
", window shape: "
<<
m_window_shape
<<
")."
;
NODE_VALIDATION_CHECK
(
this
,
dim_padding_below
==
0
||
dim_window_size
>
dim_padding_below
,
"Window will sometimes reside entirely within the below-padding region, but"
,
" include_padding_in_avg_computation was not set (padding below: "
,
m_padding_below
,
", window shape: "
,
m_window_shape
,
")."
);
// Now check the upper-bound...
{
...
...
@@ -179,13 +220,16 @@ void op::QuantizedAvgPool::validate_and_infer_types()
const
size_t
dim_window_max_lower_offset
=
dim_num_strides
*
dim_stride
;
const
size_t
dim_padding_above_start_offset
=
dim_virtual_size
-
dim_padding_above
;
NODE_VALIDATION_ASSERT
(
this
,
NODE_VALIDATION_CHECK
(
this
,
dim_padding_above
==
0
||
dim_window_max_lower_offset
<
dim_padding_above_start_offset
)
<<
"Window will sometimes reside entirely within the above-padding region, but"
<<
" include_padding_in_avg_computation was not set (padding above: "
<<
m_padding_above
<<
", window shape: "
<<
m_window_shape
<<
")."
;
dim_window_max_lower_offset
<
dim_padding_above_start_offset
,
"Window will sometimes reside entirely within the above-padding region, but"
,
" include_padding_in_avg_computation was not set (padding above: "
,
m_padding_above
,
", window shape: "
,
m_window_shape
,
")."
);
}
}
}
...
...
src/ngraph/op/experimental/quantized_concat.cpp
View file @
066037c2
...
...
@@ -33,7 +33,7 @@ op::QuantizedConcat::QuantizedConcat(const NodeVector& args, size_t concatenatio
void
op
::
QuantizedConcat
::
validate_and_infer_types
()
{
NODE_VALIDATION_
ASSERT
(
this
,
m_inputs
.
size
()
>=
1
)
<<
"At least one argument required."
;
NODE_VALIDATION_
CHECK
(
this
,
m_inputs
.
size
()
>=
1
,
"At least one argument required."
)
;
PartialShape
inputs_shape_scheme
{
PartialShape
::
dynamic
()};
element
::
Type
inputs_et
{
element
::
dynamic
};
...
...
@@ -45,23 +45,32 @@ void op::QuantizedConcat::validate_and_infer_types()
Dimension
this_input_rank
=
this_input_shape
.
rank
();
if
(
this_input_rank
.
is_static
())
{
NODE_VALIDATION_ASSERT
(
this
,
m_concatenation_axis
<
size_t
(
this_input_rank
))
<<
"QuantizedConcatenation axis ("
<<
m_concatenation_axis
<<
") is out of bounds for "
<<
"argument "
<<
i
<<
", which has shape "
<<
this_input_shape
<<
"."
;
NODE_VALIDATION_CHECK
(
this
,
m_concatenation_axis
<
size_t
(
this_input_rank
),
"QuantizedConcatenation axis ("
,
m_concatenation_axis
,
") is out of bounds for "
,
"argument "
,
i
,
", which has shape "
,
this_input_shape
,
"."
);
concatenation_axis_output_dim
+=
this_input_shape
[
m_concatenation_axis
];
this_input_shape
[
m_concatenation_axis
]
=
Dimension
::
dynamic
();
NODE_VALIDATION_ASSERT
(
this
,
PartialShape
::
merge_into
(
inputs_shape_scheme
,
this_input_shape
))
<<
"Argument shapes are inconsistent; they must have the same rank, and must have "
<<
"equal dimension everywhere except on the concatenation axis (axis "
<<
m_concatenation_axis
<<
")."
;
NODE_VALIDATION_CHECK
(
this
,
PartialShape
::
merge_into
(
inputs_shape_scheme
,
this_input_shape
),
"Argument shapes are inconsistent; they must have the same rank, and must have "
,
"equal dimension everywhere except on the concatenation axis (axis "
,
m_concatenation_axis
,
")."
);
NODE_VALIDATION_ASSERT
(
this
,
element
::
Type
::
merge
(
inputs_et
,
inputs_et
,
get_input_element_type
(
i
)))
<<
"Argument element types are inconsistent."
;
NODE_VALIDATION_CHECK
(
this
,
element
::
Type
::
merge
(
inputs_et
,
inputs_et
,
get_input_element_type
(
i
)),
"Argument element types are inconsistent."
);
}
else
{
...
...
src/ngraph/op/experimental/quantized_max_pool.cpp
View file @
066037c2
...
...
@@ -64,36 +64,58 @@ void op::QuantizedMaxPool::validate_and_infer_types()
// Make sure batch size and channel count are not zero, and that we have at least one spatial
// dimension (in other words, that arg has shape NCDi for some Di of rank>0, N != 0, C != 0).
//
NODE_VALIDATION_ASSERT
(
this
,
arg_shape
.
size
()
>=
3
)
<<
"Data input shape does not have rank of at least 3 (data input shape: "
<<
arg_shape
<<
")."
;
NODE_VALIDATION_CHECK
(
this
,
arg_shape
.
size
()
>=
3
,
"Data input shape does not have rank of at least 3 (data input shape: "
,
arg_shape
,
")."
);
size_t
batch_size
=
arg_shape
[
0
];
NODE_VALIDATION_
ASSERT
(
this
,
batch_size
!=
0
)
<<
"Data batch size is zero (data input shape: "
<<
arg_shape
<<
")."
;
NODE_VALIDATION_
CHECK
(
this
,
batch_size
!=
0
,
"Data batch size is zero (data input shape: "
,
arg_shape
,
")."
)
;
size_t
channel_count
=
arg_shape
[
1
];
NODE_VALIDATION_
ASSERT
(
this
,
channel_count
!=
0
)
<<
"Channel count is zero (data input shape: "
<<
arg_shape
<<
")."
;
NODE_VALIDATION_
CHECK
(
this
,
channel_count
!=
0
,
"Channel count is zero (data input shape: "
,
arg_shape
,
")."
)
;
size_t
spatial_dimension_count
=
arg_shape
.
size
()
-
2
;
//
// Make sure window shape, window movement strides, and padding have same rank as Di.
//
NODE_VALIDATION_ASSERT
(
this
,
m_window_shape
.
size
()
==
spatial_dimension_count
)
<<
"Window shape rank does not match number of spatial dimensions (window shape: "
<<
m_window_shape
<<
", data input shape: "
<<
arg_shape
<<
")."
;
NODE_VALIDATION_ASSERT
(
this
,
m_window_movement_strides
.
size
()
==
spatial_dimension_count
)
<<
"Window movement stride rank does not match number of spatial dimensions (window "
"movement strides: "
<<
m_window_movement_strides
<<
", data input shape: "
<<
arg_shape
<<
")."
;
NODE_VALIDATION_ASSERT
(
this
,
m_padding_below
.
size
()
==
spatial_dimension_count
)
<<
"Below-padding rank does not match number of spatial dimensions (padding below: "
<<
m_padding_below
<<
", data input shape: "
<<
arg_shape
<<
")."
;
NODE_VALIDATION_ASSERT
(
this
,
m_padding_above
.
size
()
==
spatial_dimension_count
)
<<
"Above-padding rank does not match number of spatial dimensions (padding above: "
<<
m_padding_above
<<
", data input shape: "
<<
arg_shape
<<
")."
;
NODE_VALIDATION_CHECK
(
this
,
m_window_shape
.
size
()
==
spatial_dimension_count
,
"Window shape rank does not match number of spatial dimensions (window shape: "
,
m_window_shape
,
", data input shape: "
,
arg_shape
,
")."
);
NODE_VALIDATION_CHECK
(
this
,
m_window_movement_strides
.
size
()
==
spatial_dimension_count
,
"Window movement stride rank does not match number of spatial dimensions (window "
"movement strides: "
,
m_window_movement_strides
,
", data input shape: "
,
arg_shape
,
")."
);
NODE_VALIDATION_CHECK
(
this
,
m_padding_below
.
size
()
==
spatial_dimension_count
,
"Below-padding rank does not match number of spatial dimensions (padding below: "
,
m_padding_below
,
", data input shape: "
,
arg_shape
,
")."
);
NODE_VALIDATION_CHECK
(
this
,
m_padding_above
.
size
()
==
spatial_dimension_count
,
"Above-padding rank does not match number of spatial dimensions (padding above: "
,
m_padding_above
,
", data input shape: "
,
arg_shape
,
")."
);
//
// Extract input item shape Di and make sure all dimensions are larger than 0.
...
...
@@ -109,10 +131,13 @@ void op::QuantizedMaxPool::validate_and_infer_types()
for
(
size_t
i
=
0
;
i
<
spatial_dimension_count
;
i
++
)
{
NODE_VALIDATION_ASSERT
(
this
,
input_item_virtual_shape
[
i
]
!=
0
)
<<
"Data input spatial dimension "
<<
i
<<
" has zero length even after padding (virtual shape of input item: "
<<
input_item_virtual_shape
<<
")."
;
NODE_VALIDATION_CHECK
(
this
,
input_item_virtual_shape
[
i
]
!=
0
,
"Data input spatial dimension "
,
i
,
" has zero length even after padding (virtual shape of input item: "
,
input_item_virtual_shape
,
")."
);
}
//
...
...
@@ -120,9 +145,13 @@ void op::QuantizedMaxPool::validate_and_infer_types()
//
for
(
size_t
i
=
0
;
i
<
spatial_dimension_count
;
i
++
)
{
NODE_VALIDATION_ASSERT
(
this
,
m_window_shape
[
i
]
!=
0
)
<<
"Window shape dimension "
<<
i
<<
" has zero length (window shape: "
<<
m_window_shape
<<
")."
;
NODE_VALIDATION_CHECK
(
this
,
m_window_shape
[
i
]
!=
0
,
"Window shape dimension "
,
i
,
" has zero length (window shape: "
,
m_window_shape
,
")."
);
}
//
...
...
@@ -130,10 +159,14 @@ void op::QuantizedMaxPool::validate_and_infer_types()
//
for
(
size_t
i
=
0
;
i
<
spatial_dimension_count
;
i
++
)
{
NODE_VALIDATION_ASSERT
(
this
,
m_window_shape
[
i
]
<=
input_item_virtual_shape
[
i
])
<<
"Window shape after padding is larger than the spatial dimensions (window shape: "
<<
m_window_shape
<<
", virtual shape of input item: "
<<
input_item_virtual_shape
<<
")."
;
NODE_VALIDATION_CHECK
(
this
,
m_window_shape
[
i
]
<=
input_item_virtual_shape
[
i
],
"Window shape after padding is larger than the spatial dimensions (window shape: "
,
m_window_shape
,
", virtual shape of input item: "
,
input_item_virtual_shape
,
")."
);
}
//
...
...
@@ -143,9 +176,13 @@ void op::QuantizedMaxPool::validate_and_infer_types()
for
(
size_t
i
=
0
;
i
<
spatial_dimension_count
;
i
++
)
{
NODE_VALIDATION_ASSERT
(
this
,
m_window_movement_strides
[
i
]
!=
0
)
<<
"Window movement strides dimension "
<<
i
<<
" has zero length (window movement strides: "
<<
m_window_movement_strides
<<
")."
;
NODE_VALIDATION_CHECK
(
this
,
m_window_movement_strides
[
i
]
!=
0
,
"Window movement strides dimension "
,
i
,
" has zero length (window movement strides: "
,
m_window_movement_strides
,
")."
);
output_item_shape
.
push_back
(
ceil_div
(
input_item_virtual_shape
[
i
]
-
m_window_shape
[
i
]
+
1
,
m_window_movement_strides
[
i
]));
}
...
...
src/ngraph/op/get_output_element.cpp
View file @
066037c2
...
...
@@ -30,9 +30,13 @@ op::GetOutputElement::GetOutputElement(const shared_ptr<Node>& arg, size_t n)
void
op
::
GetOutputElement
::
validate_and_infer_types
()
{
NODE_VALIDATION_ASSERT
(
this
,
m_n
<
get_input_size
())
<<
"Output at index "
<<
m_n
<<
" requested, but node has only "
<<
get_input_size
()
<<
" inputs."
;
NODE_VALIDATION_CHECK
(
this
,
m_n
<
get_input_size
(),
"Output at index "
,
m_n
,
" requested, but node has only "
,
get_input_size
(),
" inputs."
);
set_output_type
(
0
,
get_input_element_type
(
m_n
),
get_input_partial_shape
(
m_n
));
}
...
...
src/ngraph/op/lrn.cpp
View file @
066037c2
...
...
@@ -36,9 +36,12 @@ void op::LRN::validate_and_infer_types()
const
PartialShape
&
input_shape
=
get_input_partial_shape
(
0
);
NODE_VALIDATION_ASSERT
(
this
,
input_shape
.
rank
().
is_dynamic
()
||
static_cast
<
size_t
>
(
input_shape
.
rank
())
>=
3
)
<<
"Argument must have rank >= 3 (argument shape: "
<<
input_shape
<<
")."
;
NODE_VALIDATION_CHECK
(
this
,
input_shape
.
rank
().
is_dynamic
()
||
static_cast
<
size_t
>
(
input_shape
.
rank
())
>=
3
,
"Argument must have rank >= 3 (argument shape: "
,
input_shape
,
")."
);
}
shared_ptr
<
Node
>
op
::
LRN
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
...
...
src/ngraph/op/max_pool.cpp
View file @
066037c2
...
...
@@ -134,9 +134,13 @@ void op::MaxPoolBackprop::validate_and_infer_types()
element
::
Type
result_et
;
NODE_VALIDATION_ASSERT
(
this
,
element
::
Type
::
merge
(
result_et
,
forward_arg_et
,
delta_et
))
<<
"Element types for forward argument ("
<<
forward_arg_et
<<
") and delta ("
<<
delta_et
<<
") do not match."
;
NODE_VALIDATION_CHECK
(
this
,
element
::
Type
::
merge
(
result_et
,
forward_arg_et
,
delta_et
),
"Element types for forward argument ("
,
forward_arg_et
,
") and delta ("
,
delta_et
,
") do not match."
);
// infer_batched_forward_pooling wants CoordinateDiffs for these, while the pooling ops for
// now still take Shape (no negative padding).
...
...
@@ -155,9 +159,15 @@ void op::MaxPoolBackprop::validate_and_infer_types()
const
PartialShape
&
delta_shape
=
get_input_partial_shape
(
1
);
NODE_VALIDATION_ASSERT
(
this
,
forward_result_shape
.
compatible
(
delta_shape
))
<<
"Inferred forward output shape does not match delta shape (inferred forward output "
<<
"shape: "
<<
forward_result_shape
<<
", delta shape: "
<<
delta_shape
<<
")."
;
NODE_VALIDATION_CHECK
(
this
,
forward_result_shape
.
compatible
(
delta_shape
),
"Inferred forward output shape does not match delta shape (inferred forward output "
,
"shape: "
,
forward_result_shape
,
", delta shape: "
,
delta_shape
,
")."
);
// TODO(amprocte): We may technically be able to infer some extra information from
// forward_result_shape that was not present in the forward arg shape---namely batch size and
...
...
src/ngraph/op/one_hot.cpp
View file @
066037c2
...
...
@@ -34,16 +34,25 @@ void op::OneHot::validate_and_infer_types()
PartialShape
arg_shape
=
get_input_partial_shape
(
0
);
Rank
arg_rank
=
arg_shape
.
rank
();
NODE_VALIDATION_
ASSERT
(
this
,
m_shape
.
rank
().
is_static
())
<<
"Requested result shape has dynamic rank."
;
NODE_VALIDATION_
CHECK
(
this
,
m_shape
.
rank
().
is_static
(),
"Requested result shape has dynamic rank."
)
;
NODE_VALIDATION_ASSERT
(
this
,
m_one_hot_axis
<
static_cast
<
size_t
>
(
m_shape
.
rank
()))
<<
"One-hot axis ("
<<
m_one_hot_axis
<<
") is out of bounds (requested result shape: "
<<
m_shape
<<
")."
;
NODE_VALIDATION_CHECK
(
this
,
m_one_hot_axis
<
static_cast
<
size_t
>
(
m_shape
.
rank
()),
"One-hot axis ("
,
m_one_hot_axis
,
") is out of bounds (requested result shape: "
,
m_shape
,
")."
);
NODE_VALIDATION_ASSERT
(
this
,
m_shape
[
m_one_hot_axis
].
is_static
())
<<
"Requested result shape ("
<<
m_shape
<<
") has dynamic dimension at the one-hot axis "
<<
"("
<<
m_one_hot_axis
<<
")."
;
NODE_VALIDATION_CHECK
(
this
,
m_shape
[
m_one_hot_axis
].
is_static
(),
"Requested result shape ("
,
m_shape
,
") has dynamic dimension at the one-hot axis "
,
"("
,
m_one_hot_axis
,
")."
);
PartialShape
result_shape
{
m_shape
};
...
...
@@ -58,9 +67,13 @@ void op::OneHot::validate_and_infer_types()
PartialShape
expected_input_shape
{
expected_input_dims
};
PartialShape
merged_input_shape
{
expected_input_shape
};
NODE_VALIDATION_ASSERT
(
this
,
PartialShape
::
merge_into
(
merged_input_shape
,
arg_shape
))
<<
"Argument shape "
<<
arg_shape
<<
" does not match the expected shape of "
<<
expected_input_shape
<<
"."
;
NODE_VALIDATION_CHECK
(
this
,
PartialShape
::
merge_into
(
merged_input_shape
,
arg_shape
),
"Argument shape "
,
arg_shape
,
" does not match the expected shape of "
,
expected_input_shape
,
"."
);
std
::
vector
<
Dimension
>
output_dims
(
static_cast
<
size_t
>
(
merged_input_shape
.
rank
()));
for
(
size_t
i
=
0
;
i
<
static_cast
<
size_t
>
(
merged_input_shape
.
rank
());
i
++
)
...
...
src/ngraph/op/pad.cpp
View file @
066037c2
...
...
@@ -38,31 +38,49 @@ void op::Pad::validate_and_infer_types()
{
element
::
Type
result_et
;
NODE_VALIDATION_ASSERT
(
this
,
element
::
Type
::
merge
(
result_et
,
get_input_element_type
(
0
),
get_input_element_type
(
1
)))
<<
"Argument element types do not match (arg0 element type: "
<<
get_input_element_type
(
0
)
<<
", arg1 element type: "
<<
get_input_element_type
(
1
)
<<
")."
;
NODE_VALIDATION_ASSERT
(
this
,
get_input_partial_shape
(
1
).
compatible
(
PartialShape
{}))
<<
"Argument for padding value is not a scalar (shape: "
<<
get_input_partial_shape
(
1
)
<<
")."
;
NODE_VALIDATION_CHECK
(
this
,
element
::
Type
::
merge
(
result_et
,
get_input_element_type
(
0
),
get_input_element_type
(
1
)),
"Argument element types do not match (arg0 element type: "
,
get_input_element_type
(
0
),
", arg1 element type: "
,
get_input_element_type
(
1
),
")."
);
NODE_VALIDATION_CHECK
(
this
,
get_input_partial_shape
(
1
).
compatible
(
PartialShape
{}),
"Argument for padding value is not a scalar (shape: "
,
get_input_partial_shape
(
1
),
")."
);
auto
arg_shape
=
get_input_partial_shape
(
0
);
NODE_VALIDATION_
ASSERT
(
this
,
NODE_VALIDATION_
CHECK
(
this
,
m_padding_below
.
size
()
==
m_padding_above
.
size
()
&&
m_padding_below
.
size
()
==
m_padding_interior
.
size
())
<<
"Ranks for padding below ("
<<
m_padding_below
<<
"), padding above ("
<<
m_padding_above
<<
") and interior padding ("
<<
m_padding_interior
<<
") "
<<
"do not match."
;
m_padding_below
.
size
()
==
m_padding_interior
.
size
(),
"Ranks for padding below ("
,
m_padding_below
,
"), padding above ("
,
m_padding_above
,
") and interior padding ("
,
m_padding_interior
,
") "
,
"do not match."
);
size_t
implied_rank
=
m_padding_below
.
size
();
NODE_VALIDATION_ASSERT
(
this
,
arg_shape
.
rank
().
compatible
(
implied_rank
))
<<
"Rank for padding below/padding above/interior padding does not match the rank of the "
<<
"data argument (padding below: "
<<
m_padding_below
<<
", "
<<
", padding above: "
<<
m_padding_above
<<
", interior padding: "
<<
m_padding_interior
<<
")."
;
NODE_VALIDATION_CHECK
(
this
,
arg_shape
.
rank
().
compatible
(
implied_rank
),
"Rank for padding below/padding above/interior padding does not match the rank of the "
,
"data argument (padding below: "
,
m_padding_below
,
", "
,
", padding above: "
,
m_padding_above
,
", interior padding: "
,
m_padding_interior
,
")."
);
std
::
vector
<
Dimension
>
result_dims
(
implied_rank
,
Dimension
::
dynamic
());
...
...
src/ngraph/op/quantize.cpp
View file @
066037c2
...
...
@@ -44,50 +44,73 @@ void op::Quantize::validate_and_infer_types()
OFFSET
};
NODE_VALIDATION_
ASSERT
(
this
,
m_type
.
is_static
())
<<
"Output element type must not be dynamic"
;
NODE_VALIDATION_
CHECK
(
this
,
m_type
.
is_static
(),
"Output element type must not be dynamic"
)
;
NODE_VALIDATION_
ASSERT
(
this
,
m_type
.
is_quantized
())
<<
"Output element type ("
<<
m_type
<<
") must be a quantized type"
;
NODE_VALIDATION_
CHECK
(
this
,
m_type
.
is_quantized
(),
"Output element type ("
,
m_type
,
") must be a quantized type"
)
;
element
::
Type
unquantized_type
;
NODE_VALIDATION_
ASSERT
(
this
,
NODE_VALIDATION_
CHECK
(
this
,
element
::
Type
::
merge
(
unquantized_type
,
get_input_element_type
(
INPUT
),
get_input_element_type
(
SCALE
)))
<<
"Scale element type ("
<<
get_input_element_type
(
SCALE
)
<<
") must match input element type ("
<<
get_input_element_type
(
INPUT
)
<<
")"
;
get_input_element_type
(
SCALE
)),
"Scale element type ("
,
get_input_element_type
(
SCALE
),
") must match input element type ("
,
get_input_element_type
(
INPUT
),
")"
);
NODE_VALIDATION_ASSERT
(
this
,
unquantized_type
.
is_dynamic
()
||
unquantized_type
.
is_real
())
<<
"Scale/input element type ("
<<
unquantized_type
<<
") must be a floating point number"
;
NODE_VALIDATION_CHECK
(
this
,
unquantized_type
.
is_dynamic
()
||
unquantized_type
.
is_real
(),
"Scale/input element type ("
,
unquantized_type
,
") must be a floating point number"
);
element
::
Type
quantized_type
;
NODE_VALIDATION_ASSERT
(
this
,
element
::
Type
::
merge
(
quantized_type
,
get_input_element_type
(
OFFSET
),
m_type
))
<<
"Offset element type ("
<<
get_input_element_type
(
OFFSET
)
<<
") must match output element type ("
<<
m_type
<<
")"
;
NODE_VALIDATION_CHECK
(
this
,
element
::
Type
::
merge
(
quantized_type
,
get_input_element_type
(
OFFSET
),
m_type
),
"Offset element type ("
,
get_input_element_type
(
OFFSET
),
") must match output element type ("
,
m_type
,
")"
);
PartialShape
input_shape
=
get_input_partial_shape
(
0
);
Dimension
input_rank
=
input_shape
.
rank
();
for
(
auto
axis
:
m_axes
)
{
NODE_VALIDATION_ASSERT
(
this
,
input_rank
.
is_dynamic
()
||
axis
<
size_t
(
input_rank
))
<<
"Quantization axis ("
<<
axis
<<
") must be less than input shape rank ("
<<
input_rank
<<
")"
;
NODE_VALIDATION_CHECK
(
this
,
input_rank
.
is_dynamic
()
||
axis
<
size_t
(
input_rank
),
"Quantization axis ("
,
axis
,
") must be less than input shape rank ("
,
input_rank
,
")"
);
}
PartialShape
scale_offset_shape
=
get_input_partial_shape
(
SCALE
);
NODE_VALIDATION_ASSERT
(
this
,
PartialShape
::
merge_into
(
scale_offset_shape
,
get_input_partial_shape
(
OFFSET
)))
<<
"Scale shape ("
<<
get_input_partial_shape
(
SCALE
)
<<
") and offset shape ("
<<
get_input_partial_shape
(
OFFSET
)
<<
") must match"
;
NODE_VALIDATION_ASSERT
(
this
,
scale_offset_shape
.
rank
().
compatible
(
m_axes
.
size
()))
<<
"Scale/offset rank ("
<<
scale_offset_shape
.
rank
()
<<
") does not match the number of "
<<
"quantization axes ("
<<
m_axes
.
size
()
<<
")"
;
NODE_VALIDATION_CHECK
(
this
,
PartialShape
::
merge_into
(
scale_offset_shape
,
get_input_partial_shape
(
OFFSET
)),
"Scale shape ("
,
get_input_partial_shape
(
SCALE
),
") and offset shape ("
,
get_input_partial_shape
(
OFFSET
),
") must match"
);
NODE_VALIDATION_CHECK
(
this
,
scale_offset_shape
.
rank
().
compatible
(
m_axes
.
size
()),
"Scale/offset rank ("
,
scale_offset_shape
.
rank
(),
") does not match the number of "
,
"quantization axes ("
,
m_axes
.
size
(),
")"
);
set_output_size
(
1
);
...
...
@@ -110,10 +133,16 @@ void op::Quantize::validate_and_infer_types()
}
PartialShape
result_shape
=
input_shape
;
NODE_VALIDATION_ASSERT
(
this
,
PartialShape
::
merge_into
(
result_shape
,
PartialShape
{
injected_scale_offset_dims
}))
<<
"Scale/offset shape ("
<<
scale_offset_shape
<<
") must match input shape ("
<<
input_shape
<<
") at the quantization axes ("
<<
m_axes
<<
")"
;
NODE_VALIDATION_CHECK
(
this
,
PartialShape
::
merge_into
(
result_shape
,
PartialShape
{
injected_scale_offset_dims
}),
"Scale/offset shape ("
,
scale_offset_shape
,
") must match input shape ("
,
input_shape
,
") at the quantization axes ("
,
m_axes
,
")"
);
set_output_type
(
0
,
quantized_type
,
result_shape
);
}
else
...
...
src/ngraph/op/replace_slice.cpp
View file @
066037c2
...
...
@@ -59,51 +59,85 @@ void op::ReplaceSlice::validate_and_infer_types()
const
PartialShape
&
arg1_shape
=
get_input_partial_shape
(
1
);
Dimension
merged_args_rank
;
NODE_VALIDATION_ASSERT
(
this
,
Dimension
::
merge
(
merged_args_rank
,
arg0_shape
.
rank
(),
arg1_shape
.
rank
()))
<<
"Argument ranks do not match (arg0 shape: "
<<
arg0_shape
<<
", arg1 shape: "
<<
arg1_shape
<<
")."
;
NODE_VALIDATION_CHECK
(
this
,
Dimension
::
merge
(
merged_args_rank
,
arg0_shape
.
rank
(),
arg1_shape
.
rank
()),
"Argument ranks do not match (arg0 shape: "
,
arg0_shape
,
", arg1 shape: "
,
arg1_shape
,
")."
);
element
::
Type
arg0_et
=
get_input_element_type
(
0
);
element
::
Type
arg1_et
=
get_input_element_type
(
1
);
element
::
Type
merged_args_et
;
NODE_VALIDATION_ASSERT
(
this
,
element
::
Type
::
merge
(
merged_args_et
,
arg0_et
,
arg1_et
))
<<
"Argument element types do not match (arg0 element type: "
<<
arg0_et
<<
", arg1 element type: "
<<
arg1_et
<<
")."
;
NODE_VALIDATION_CHECK
(
this
,
element
::
Type
::
merge
(
merged_args_et
,
arg0_et
,
arg1_et
),
"Argument element types do not match (arg0 element type: "
,
arg0_et
,
", arg1 element type: "
,
arg1_et
,
")."
);
NODE_VALIDATION_
ASSERT
(
this
,
NODE_VALIDATION_
CHECK
(
this
,
m_lower_bounds
.
size
()
==
m_upper_bounds
.
size
()
&&
m_lower_bounds
.
size
()
==
m_strides
.
size
())
<<
"Ranks of lower bounds ("
<<
m_lower_bounds
<<
"), upper bounds ("
<<
m_upper_bounds
<<
") and strides ("
<<
m_strides
<<
") do not match."
;
m_lower_bounds
.
size
()
==
m_strides
.
size
(),
"Ranks of lower bounds ("
,
m_lower_bounds
,
"), upper bounds ("
,
m_upper_bounds
,
") and strides ("
,
m_strides
,
") do not match."
);
size_t
output_rank
=
m_upper_bounds
.
size
();
for
(
size_t
i
=
0
;
i
<
output_rank
;
i
++
)
{
NODE_VALIDATION_ASSERT
(
this
,
m_lower_bounds
[
i
]
<=
m_upper_bounds
[
i
])
<<
"Lower bound for slice is greater than upper bound at axis "
<<
i
<<
" (lower bounds: "
<<
m_lower_bounds
<<
", upper bounds: "
<<
m_upper_bounds
<<
")."
;
NODE_VALIDATION_ASSERT
(
this
,
m_strides
[
i
]
!=
0
)
<<
"Stride for slice is zero at axis "
<<
i
<<
" (strides: "
<<
m_strides
<<
")."
;
NODE_VALIDATION_CHECK
(
this
,
m_lower_bounds
[
i
]
<=
m_upper_bounds
[
i
],
"Lower bound for slice is greater than upper bound at axis "
,
i
,
" (lower bounds: "
,
m_lower_bounds
,
", upper bounds: "
,
m_upper_bounds
,
")."
);
NODE_VALIDATION_CHECK
(
this
,
m_strides
[
i
]
!=
0
,
"Stride for slice is zero at axis "
,
i
,
" (strides: "
,
m_strides
,
")."
);
}
NODE_VALIDATION_ASSERT
(
this
,
merged_args_rank
.
is_dynamic
()
||
size_t
(
merged_args_rank
)
==
output_rank
)
<<
"Argument ranks do not match the rank of the lower bounds ("
<<
m_lower_bounds
<<
"), upper bounds ("
<<
m_upper_bounds
<<
"), and strides ("
<<
m_strides
<<
")."
;
NODE_VALIDATION_CHECK
(
this
,
merged_args_rank
.
is_dynamic
()
||
size_t
(
merged_args_rank
)
==
output_rank
,
"Argument ranks do not match the rank of the lower bounds ("
,
m_lower_bounds
,
"), upper bounds ("
,
m_upper_bounds
,
"), and strides ("
,
m_strides
,
")."
);
std
::
vector
<
Dimension
>
sliced_dims
(
output_rank
);
for
(
size_t
i
=
0
;
i
<
output_rank
;
i
++
)
{
NODE_VALIDATION_
ASSERT
(
this
,
NODE_VALIDATION_
CHECK
(
this
,
arg0_shape
.
rank
().
is_dynamic
()
||
arg0_shape
[
i
].
is_dynamic
()
||
m_upper_bounds
[
i
]
<=
size_t
(
arg0_shape
[
i
]))
<<
"Upper bound for slice at axis "
<<
i
<<
" is out of range "
<<
"(upper bounds: "
<<
m_upper_bounds
<<
", argument shape: "
<<
arg0_shape
<<
")."
;
m_upper_bounds
[
i
]
<=
size_t
(
arg0_shape
[
i
]),
"Upper bound for slice at axis "
,
i
,
" is out of range "
,
"(upper bounds: "
,
m_upper_bounds
,
", argument shape: "
,
arg0_shape
,
")."
);
size_t
sliced_dim
=
m_upper_bounds
[
i
]
-
m_lower_bounds
[
i
];
sliced_dim
=
sliced_dim
/
m_strides
[
i
]
+
((
sliced_dim
%
m_strides
[
i
]
==
0
)
?
0
:
1
);
...
...
@@ -112,9 +146,14 @@ void op::ReplaceSlice::validate_and_infer_types()
PartialShape
slice_shape
{
sliced_dims
};
NODE_VALIDATION_ASSERT
(
this
,
arg1_shape
.
compatible
(
slice_shape
))
<<
"Shape of replacement tensor ("
<<
arg1_shape
<<
") does not match the slice shape "
<<
"("
<<
slice_shape
<<
")."
;
NODE_VALIDATION_CHECK
(
this
,
arg1_shape
.
compatible
(
slice_shape
),
"Shape of replacement tensor ("
,
arg1_shape
,
") does not match the slice shape "
,
"("
,
slice_shape
,
")."
);
// Slight corner case here: if arg0 was rank-unknown, we can go ahead and set the output rank
// because the attribs will have given us enough info.
...
...
src/ngraph/op/reshape.cpp
View file @
066037c2
...
...
@@ -41,25 +41,39 @@ void op::Reshape::validate_and_infer_types()
// Check that the input axis order is a permutation of (0,...,n-1) for some n.
for
(
size_t
i
=
0
;
i
<
m_input_order
.
size
();
i
++
)
{
NODE_VALIDATION_ASSERT
(
this
,
find
(
begin
(
m_input_order
),
end
(
m_input_order
),
i
)
!=
end
(
m_input_order
))
<<
"Input axis order is not a permutation of argument's axis indices (axis order: "
<<
m_input_order
<<
", argument shape: "
<<
input_shape
<<
")."
;
NODE_VALIDATION_CHECK
(
this
,
find
(
begin
(
m_input_order
),
end
(
m_input_order
),
i
)
!=
end
(
m_input_order
),
"Input axis order is not a permutation of argument's axis indices (axis order: "
,
m_input_order
,
", argument shape: "
,
input_shape
,
")."
);
}
// TODO(amprocte): should be possible to move around unknown dims in the input shape.
if
(
input_rank
.
is_static
())
{
NODE_VALIDATION_ASSERT
(
this
,
m_input_order
.
size
()
==
size_t
(
input_rank
))
<<
"Input axis order is not a permutation of argument's axis indices (axis order: "
<<
m_input_order
<<
", argument shape: "
<<
input_shape
<<
")."
;
NODE_VALIDATION_CHECK
(
this
,
m_input_order
.
size
()
==
size_t
(
input_rank
),
"Input axis order is not a permutation of argument's axis indices (axis order: "
,
m_input_order
,
", argument shape: "
,
input_shape
,
")."
);
for
(
size_t
i
=
0
;
i
<
size_t
(
input_rank
);
i
++
)
{
auto
it
=
find
(
begin
(
m_input_order
),
end
(
m_input_order
),
i
);
NODE_VALIDATION_ASSERT
(
this
,
it
!=
end
(
m_input_order
))
<<
"Input axis order is not a permutation of argument's axis indices (axis order: "
<<
m_input_order
<<
", argument shape: "
<<
input_shape
<<
")."
;
NODE_VALIDATION_CHECK
(
this
,
it
!=
end
(
m_input_order
),
"Input axis order is not a permutation of argument's axis indices (axis order: "
,
m_input_order
,
", argument shape: "
,
input_shape
,
")."
);
}
// TODO(amprocte): make a partial_shape_size() analogous to shape_size().
...
...
@@ -71,11 +85,16 @@ void op::Reshape::validate_and_infer_types()
if
(
input_shape_product
.
is_static
())
{
NODE_VALIDATION_ASSERT
(
this
,
size_t
(
input_shape_product
)
==
shape_size
(
m_output_shape
))
<<
"Product of output shape dimensions does not match product of argument shape "
"dimensions "
<<
"(output shape: "
<<
m_output_shape
<<
", argument shape: "
<<
input_shape
<<
")."
;
NODE_VALIDATION_CHECK
(
this
,
size_t
(
input_shape_product
)
==
shape_size
(
m_output_shape
),
"Product of output shape dimensions does not match product of argument shape "
"dimensions "
,
"(output shape: "
,
m_output_shape
,
", argument shape: "
,
input_shape
,
")."
);
}
}
...
...
src/ngraph/op/result.cpp
View file @
066037c2
...
...
@@ -32,8 +32,8 @@ op::Result::Result(const shared_ptr<Node>& arg)
void
op
::
Result
::
validate_and_infer_types
()
{
NODE_VALIDATION_
ASSERT
(
this
,
get_input_size
()
==
1
)
<<
"Argument has "
<<
get_input_size
()
<<
" outputs (1 expected)."
;
NODE_VALIDATION_
CHECK
(
this
,
get_input_size
()
==
1
,
"Argument has "
,
get_input_size
(),
" outputs (1 expected)."
)
;
// always borrow the placement conf even the default one
set_placement_index
(
get_argument
(
0
)
->
get_placement_index
());
...
...
src/ngraph/op/reverse.cpp
View file @
066037c2
...
...
@@ -40,9 +40,13 @@ void op::Reverse::validate_and_infer_types()
// Make sure all reversed axis indices are valid.
for
(
size_t
axis
:
m_reversed_axes
)
{
NODE_VALIDATION_ASSERT
(
this
,
axis
<
size_t
(
input_rank
))
<<
"Reverse axis ("
<<
axis
<<
") is out of bounds (argument shape: "
<<
input_shape
<<
")."
;
NODE_VALIDATION_CHECK
(
this
,
axis
<
size_t
(
input_rank
),
"Reverse axis ("
,
axis
,
") is out of bounds (argument shape: "
,
input_shape
,
")."
);
}
}
...
...
src/ngraph/op/reverse_sequence.cpp
View file @
066037c2
...
...
@@ -41,20 +41,31 @@ void op::ReverseSequence::validate_and_infer_types()
auto
input_shape
=
get_input_partial_shape
(
0
);
auto
input_rank
=
input_shape
.
rank
();
NODE_VALIDATION_ASSERT
(
this
,
input_rank
.
is_dynamic
()
||
m_batch_axis
<
size_t
(
input_rank
))
<<
"Batch axis index ("
<<
m_batch_axis
<<
") is out of bounds (argument shape: "
<<
input_shape
<<
")."
;
NODE_VALIDATION_CHECK
(
this
,
input_rank
.
is_dynamic
()
||
m_batch_axis
<
size_t
(
input_rank
),
"Batch axis index ("
,
m_batch_axis
,
") is out of bounds (argument shape: "
,
input_shape
,
")."
);
NODE_VALIDATION_ASSERT
(
this
,
input_rank
.
is_dynamic
()
||
m_seq_axis
<
size_t
(
input_rank
))
<<
"Sequence axis index ("
<<
m_seq_axis
<<
") is out of bounds (argument shape: "
<<
input_shape
<<
")."
;
NODE_VALIDATION_CHECK
(
this
,
input_rank
.
is_dynamic
()
||
m_seq_axis
<
size_t
(
input_rank
),
"Sequence axis index ("
,
m_seq_axis
,
") is out of bounds (argument shape: "
,
input_shape
,
")."
);
auto
indices_shape
=
get_input_partial_shape
(
1
);
auto
indices_rank
=
indices_shape
.
rank
();
NODE_VALIDATION_ASSERT
(
this
,
indices_rank
.
is_dynamic
()
||
size_t
(
indices_rank
)
==
1
)
<<
"Sequence indices must be a 1-dimensional tensor (sequence indices shape: "
<<
get_input_partial_shape
(
1
)
<<
")."
;
NODE_VALIDATION_CHECK
(
this
,
indices_rank
.
is_dynamic
()
||
size_t
(
indices_rank
)
==
1
,
"Sequence indices must be a 1-dimensional tensor (sequence indices shape: "
,
get_input_partial_shape
(
1
),
")."
);
PartialShape
output_shape
{
input_shape
};
...
...
@@ -62,12 +73,19 @@ void op::ReverseSequence::validate_and_infer_types()
{
Dimension
merged_sequence_length
;
NODE_VALIDATION_
ASSERT
(
NODE_VALIDATION_
CHECK
(
this
,
Dimension
::
merge
(
merged_sequence_length
,
input_shape
[
m_batch_axis
],
indices_shape
[
0
]))
<<
"Sequence length ("
<<
indices_shape
[
0
]
<<
") is not equal to batch axis "
<<
"dimension ("
<<
input_shape
[
m_batch_axis
]
<<
") (argument shape: "
<<
input_shape
<<
", sequence indices shape: "
<<
indices_shape
<<
")."
;
Dimension
::
merge
(
merged_sequence_length
,
input_shape
[
m_batch_axis
],
indices_shape
[
0
]),
"Sequence length ("
,
indices_shape
[
0
],
") is not equal to batch axis "
,
"dimension ("
,
input_shape
[
m_batch_axis
],
") (argument shape: "
,
input_shape
,
", sequence indices shape: "
,
indices_shape
,
")."
);
output_shape
[
m_batch_axis
]
=
merged_sequence_length
;
}
...
...
src/ngraph/op/select.cpp
View file @
066037c2
...
...
@@ -36,24 +36,28 @@ op::Select::Select(const shared_ptr<Node>& arg0,
void
op
::
Select
::
validate_and_infer_types
()
{
NODE_VALIDATION_
ASSERT
(
this
,
NODE_VALIDATION_
CHECK
(
this
,
get_input_element_type
(
0
).
is_dynamic
()
||
get_input_element_type
(
0
)
==
element
::
boolean
)
<<
"Argument 0 does not have boolean element type (element type: "
<<
get_input_element_type
(
0
)
<<
")."
;
get_input_element_type
(
0
)
==
element
::
boolean
,
"Argument 0 does not have boolean element type (element type: "
,
get_input_element_type
(
0
),
")."
);
PartialShape
result_shape
=
get_input_partial_shape
(
0
);
NODE_VALIDATION_ASSERT
(
this
,
PartialShape
::
merge_into
(
result_shape
,
get_input_partial_shape
(
1
)))
<<
"Argument shapes are inconsistent."
;
NODE_VALIDATION_ASSERT
(
this
,
PartialShape
::
merge_into
(
result_shape
,
get_input_partial_shape
(
2
)))
<<
"Argument shapes are inconsistent."
;
NODE_VALIDATION_CHECK
(
this
,
PartialShape
::
merge_into
(
result_shape
,
get_input_partial_shape
(
1
)),
"Argument shapes are inconsistent."
);
NODE_VALIDATION_CHECK
(
this
,
PartialShape
::
merge_into
(
result_shape
,
get_input_partial_shape
(
2
)),
"Argument shapes are inconsistent."
);
element
::
Type
result_et
;
NODE_VALIDATION_ASSERT
(
this
,
element
::
Type
::
merge
(
result_et
,
get_input_element_type
(
1
),
get_input_element_type
(
2
)))
<<
"Argument 1 and 2 element types are inconsistent."
;
NODE_VALIDATION_CHECK
(
this
,
element
::
Type
::
merge
(
result_et
,
get_input_element_type
(
1
),
get_input_element_type
(
2
)),
"Argument 1 and 2 element types are inconsistent."
);
set_output_type
(
0
,
result_et
,
result_shape
);
}
...
...
src/ngraph/op/slice.cpp
View file @
066037c2
...
...
@@ -51,40 +51,68 @@ void op::Slice::validate_and_infer_types()
m_strides
=
Strides
(
m_lower_bounds
.
size
(),
1
);
}
NODE_VALIDATION_
ASSERT
(
this
,
NODE_VALIDATION_
CHECK
(
this
,
m_lower_bounds
.
size
()
==
m_upper_bounds
.
size
()
&&
m_lower_bounds
.
size
()
==
m_strides
.
size
())
<<
"Ranks of lower bounds ("
<<
m_lower_bounds
<<
"), upper bounds ("
<<
m_upper_bounds
<<
") and strides ("
<<
m_strides
<<
") do not match."
;
m_lower_bounds
.
size
()
==
m_strides
.
size
(),
"Ranks of lower bounds ("
,
m_lower_bounds
,
"), upper bounds ("
,
m_upper_bounds
,
") and strides ("
,
m_strides
,
") do not match."
);
size_t
output_rank
=
m_upper_bounds
.
size
();
for
(
size_t
i
=
0
;
i
<
output_rank
;
i
++
)
{
NODE_VALIDATION_ASSERT
(
this
,
m_lower_bounds
[
i
]
<=
m_upper_bounds
[
i
])
<<
"Lower bound for slice is greater than upper bound at axis "
<<
i
<<
" (lower bounds: "
<<
m_lower_bounds
<<
", upper bounds: "
<<
m_upper_bounds
<<
")."
;
NODE_VALIDATION_ASSERT
(
this
,
m_strides
[
i
]
!=
0
)
<<
"Stride for slice is zero at axis "
<<
i
<<
" (strides: "
<<
m_strides
<<
")."
;
NODE_VALIDATION_CHECK
(
this
,
m_lower_bounds
[
i
]
<=
m_upper_bounds
[
i
],
"Lower bound for slice is greater than upper bound at axis "
,
i
,
" (lower bounds: "
,
m_lower_bounds
,
", upper bounds: "
,
m_upper_bounds
,
")."
);
NODE_VALIDATION_CHECK
(
this
,
m_strides
[
i
]
!=
0
,
"Stride for slice is zero at axis "
,
i
,
" (strides: "
,
m_strides
,
")."
);
}
const
PartialShape
&
input_shape
=
get_input_partial_shape
(
0
);
Dimension
input_rank
=
input_shape
.
rank
();
NODE_VALIDATION_ASSERT
(
this
,
input_rank
.
is_dynamic
()
||
size_t
(
input_rank
)
==
output_rank
)
<<
"Input rank does not match the rank of the lower bounds ("
<<
m_lower_bounds
<<
"), upper bounds ("
<<
m_upper_bounds
<<
"), and strides ("
<<
m_strides
<<
")."
;
NODE_VALIDATION_CHECK
(
this
,
input_rank
.
is_dynamic
()
||
size_t
(
input_rank
)
==
output_rank
,
"Input rank does not match the rank of the lower bounds ("
,
m_lower_bounds
,
"), upper bounds ("
,
m_upper_bounds
,
"), and strides ("
,
m_strides
,
")."
);
std
::
vector
<
Dimension
>
result_dims
(
output_rank
);
for
(
size_t
i
=
0
;
i
<
output_rank
;
i
++
)
{
NODE_VALIDATION_
ASSERT
(
this
,
NODE_VALIDATION_
CHECK
(
this
,
input_rank
.
is_dynamic
()
||
input_shape
[
i
].
is_dynamic
()
||
m_upper_bounds
[
i
]
<=
size_t
(
input_shape
[
i
]))
<<
"Upper bound for slice at axis "
<<
i
<<
" is out of range "
<<
"(upper bounds: "
<<
m_upper_bounds
<<
", argument shape: "
<<
input_shape
<<
")."
;
m_upper_bounds
[
i
]
<=
size_t
(
input_shape
[
i
]),
"Upper bound for slice at axis "
,
i
,
" is out of range "
,
"(upper bounds: "
,
m_upper_bounds
,
", argument shape: "
,
input_shape
,
")."
);
size_t
result_axis_size
=
m_upper_bounds
[
i
]
-
m_lower_bounds
[
i
];
result_axis_size
=
...
...
src/ngraph/op/softmax.cpp
View file @
066037c2
...
...
@@ -37,9 +37,13 @@ op::Softmax::Softmax(const shared_ptr<Node>& arg, const AxisSet& axes)
for
(
auto
axis
:
m_axes
)
{
NODE_VALIDATION_ASSERT
(
this
,
axis
<
get_shape
().
size
())
<<
"Reduction axis ("
<<
axis
<<
") is out of bounds (argument shape: "
<<
get_shape
()
<<
")."
;
NODE_VALIDATION_CHECK
(
this
,
axis
<
get_shape
().
size
(),
"Reduction axis ("
,
axis
,
") is out of bounds (argument shape: "
,
get_shape
(),
")."
);
}
// empty axes == all axes
...
...
src/ngraph/op/topk.cpp
View file @
066037c2
...
...
@@ -43,26 +43,36 @@ void op::TopK::validate_and_infer_types()
Rank
input_rank
=
input_shape
.
rank
();
element
::
Type
input_element_type
=
get_input_element_type
(
0
);
NODE_VALIDATION_
ASSERT
(
this
,
!
m_index_element_type
.
is_dynamic
())
<<
"Argument element type must not be dynamic."
;
NODE_VALIDATION_
CHECK
(
this
,
!
m_index_element_type
.
is_dynamic
(),
"Argument element type must not be dynamic."
)
;
NODE_VALIDATION_ASSERT
(
this
,
m_index_element_type
==
element
::
i32
||
m_index_element_type
==
element
::
i64
)
<<
"Argument element type must be i64 or i32 (got "
<<
m_index_element_type
<<
")."
;
NODE_VALIDATION_CHECK
(
this
,
m_index_element_type
==
element
::
i32
||
m_index_element_type
==
element
::
i64
,
"Argument element type must be i64 or i32 (got "
,
m_index_element_type
,
")."
);
NODE_VALIDATION_ASSERT
(
this
,
input_rank
.
is_dynamic
()
||
static_cast
<
size_t
>
(
input_rank
)
>
0
)
<<
"Argument rank must be greater than 0."
;
NODE_VALIDATION_CHECK
(
this
,
input_rank
.
is_dynamic
()
||
static_cast
<
size_t
>
(
input_rank
)
>
0
,
"Argument rank must be greater than 0."
);
NODE_VALIDATION_ASSERT
(
this
,
input_rank
.
is_dynamic
()
||
m_top_k_axis
<
static_cast
<
size_t
>
(
input_rank
))
<<
"TopK axis ("
<<
m_top_k_axis
<<
") is out of bounds."
;
NODE_VALIDATION_CHECK
(
this
,
input_rank
.
is_dynamic
()
||
m_top_k_axis
<
static_cast
<
size_t
>
(
input_rank
),
"TopK axis ("
,
m_top_k_axis
,
") is out of bounds."
);
NODE_VALIDATION_
ASSERT
(
this
,
NODE_VALIDATION_
CHECK
(
this
,
input_rank
.
is_dynamic
()
||
input_shape
[
m_top_k_axis
].
is_dynamic
()
||
m_k
<=
static_cast
<
size_t
>
(
input_shape
[
m_top_k_axis
]))
<<
"K ("
<<
m_k
<<
") exceeds the dimension ("
<<
(
input_rank
.
is_static
()
?
input_shape
[
m_top_k_axis
]
:
0
)
<<
") of the TopK axis (axis "
<<
m_top_k_axis
<<
")."
;
m_k
<=
static_cast
<
size_t
>
(
input_shape
[
m_top_k_axis
]),
"K ("
,
m_k
,
") exceeds the dimension ("
,
(
input_rank
.
is_static
()
?
input_shape
[
m_top_k_axis
]
:
0
),
") of the TopK axis (axis "
,
m_top_k_axis
,
")."
);
PartialShape
output_shape
{
input_shape
};
...
...
src/ngraph/op/util/arithmetic_reduction.cpp
View file @
066037c2
...
...
@@ -40,10 +40,16 @@ void op::util::ArithmeticReduction::validate_and_infer_types()
for
(
auto
axis
:
m_reduction_axes
)
{
NODE_VALIDATION_ASSERT
(
this
,
axis
<
size_t
(
input_rank
))
<<
"Reduction axis ("
<<
axis
<<
") is out of bounds "
<<
"(argument shape: "
<<
input_shape
<<
", reduction axes: "
<<
m_reduction_axes
<<
")"
;
NODE_VALIDATION_CHECK
(
this
,
axis
<
size_t
(
input_rank
),
"Reduction axis ("
,
axis
,
") is out of bounds "
,
"(argument shape: "
,
input_shape
,
", reduction axes: "
,
m_reduction_axes
,
")"
);
}
for
(
size_t
i
=
0
;
i
<
size_t
(
input_rank
);
i
++
)
...
...
src/ngraph/op/util/index_reduction.cpp
View file @
066037c2
...
...
@@ -37,13 +37,18 @@ void op::util::IndexReduction::validate_and_infer_types()
const
PartialShape
&
arg_shape
=
get_input_partial_shape
(
0
);
Rank
rank
=
arg_shape
.
rank
();
NODE_VALIDATION_ASSERT
(
this
,
rank
.
is_dynamic
()
||
size_t
(
rank
)
>=
1
)
<<
"Argument rank is zero."
;
NODE_VALIDATION_ASSERT
(
this
,
rank
.
is_dynamic
()
||
m_axis
<
size_t
(
rank
))
<<
"Reduction axis ("
<<
m_axis
<<
") is not less than argument rank ("
<<
rank
<<
")."
;
NODE_VALIDATION_ASSERT
(
this
,
m_index_element_type
==
element
::
i32
||
m_index_element_type
==
element
::
i64
)
<<
"Index element is neither i64 or i32."
;
NODE_VALIDATION_CHECK
(
this
,
rank
.
is_dynamic
()
||
size_t
(
rank
)
>=
1
,
"Argument rank is zero."
);
NODE_VALIDATION_CHECK
(
this
,
rank
.
is_dynamic
()
||
m_axis
<
size_t
(
rank
),
"Reduction axis ("
,
m_axis
,
") is not less than argument rank ("
,
rank
,
")."
);
NODE_VALIDATION_CHECK
(
this
,
m_index_element_type
==
element
::
i32
||
m_index_element_type
==
element
::
i64
,
"Index element is neither i64 or i32."
);
PartialShape
output_shape
{
PartialShape
::
dynamic
()};
...
...
src/ngraph/op/util/logical_reduction.cpp
View file @
066037c2
...
...
@@ -40,10 +40,16 @@ void op::util::LogicalReduction::validate_and_infer_types()
for
(
auto
axis
:
m_reduction_axes
)
{
NODE_VALIDATION_ASSERT
(
this
,
axis
<
size_t
(
input_rank
))
<<
"Reduction axis ("
<<
axis
<<
") is out of bounds "
<<
"(argument shape: "
<<
input_shape
<<
", reduction axes: "
<<
m_reduction_axes
<<
")"
;
NODE_VALIDATION_CHECK
(
this
,
axis
<
size_t
(
input_rank
),
"Reduction axis ("
,
axis
,
") is out of bounds "
,
"(argument shape: "
,
input_shape
,
", reduction axes: "
,
m_reduction_axes
,
")"
);
}
for
(
size_t
i
=
0
;
i
<
size_t
(
input_rank
);
i
++
)
...
...
@@ -57,8 +63,9 @@ void op::util::LogicalReduction::validate_and_infer_types()
result_shape
=
PartialShape
(
dims
);
}
NODE_VALIDATION_ASSERT
(
this
,
get_input_element_type
(
0
).
compatible
(
element
::
boolean
))
<<
"Input element type must be boolean."
;
NODE_VALIDATION_CHECK
(
this
,
get_input_element_type
(
0
).
compatible
(
element
::
boolean
),
"Input element type must be boolean."
);
set_output_type
(
0
,
element
::
boolean
,
result_shape
);
}
src/ngraph/runtime/cpu/op/conv_add.cpp
View file @
066037c2
...
...
@@ -29,9 +29,14 @@ void op::util::validate_conv_shapes(const Node* node,
const
Shape
&
data_shape
,
const
Shape
&
filters_shape
)
{
NODE_VALIDATION_ASSERT
(
node
,
data_shape
[
1
]
==
filters_shape
[
1
])
<<
"Number of channels for data and filters do not match (data num channels: "
<<
data_shape
[
1
]
<<
", filters num channels: "
<<
filters_shape
[
1
]
<<
")."
;
NODE_VALIDATION_CHECK
(
node
,
data_shape
[
1
]
==
filters_shape
[
1
],
"Number of channels for data and filters do not match (data num channels: "
,
data_shape
[
1
],
", filters num channels: "
,
filters_shape
[
1
],
")."
);
}
op
::
ConvolutionAdd
::
ConvolutionAdd
(
const
std
::
shared_ptr
<
op
::
Convolution
>&
conv
,
...
...
@@ -79,9 +84,14 @@ op::ConvolutionAdd::ConvolutionAdd(const std::shared_ptr<Node>& data_batch,
//
// Make sure data batch and filter element types match.
//
NODE_VALIDATION_ASSERT
(
this
,
data_batch_et
==
filters_et
)
<<
"Element types for data_batch and filters do not match (data batch element type: "
<<
data_batch_et
<<
", filters element type: "
<<
filters_et
<<
")."
;
NODE_VALIDATION_CHECK
(
this
,
data_batch_et
==
filters_et
,
"Element types for data_batch and filters do not match (data batch element type: "
,
data_batch_et
,
", filters element type: "
,
filters_et
,
")."
);
util
::
validate_conv_shapes
(
this
,
data_batch_shape
,
filters_shape
);
set_output_type
(
0
,
...
...
@@ -105,8 +115,11 @@ op::ConvolutionAdd::ConvolutionAdd(const std::shared_ptr<Node>& data_batch,
std
::
shared_ptr
<
Node
>
op
::
ConvolutionAdd
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
NODE_VALIDATION_ASSERT
(
this
,
new_args
.
size
()
==
3
)
<<
"New arg size is not 3 (new args size: "
<<
new_args
.
size
()
<<
")."
;
NODE_VALIDATION_CHECK
(
this
,
new_args
.
size
()
==
3
,
"New arg size is not 3 (new args size: "
,
new_args
.
size
(),
")."
);
return
std
::
shared_ptr
<
Node
>
(
new
ConvolutionAdd
(
new_args
.
at
(
0
),
new_args
.
at
(
1
),
...
...
src/ngraph/runtime/cpu/op/update_slice.cpp
View file @
066037c2
...
...
@@ -57,51 +57,85 @@ void op::UpdateSlice::validate_and_infer_types()
const
PartialShape
&
arg1_shape
=
get_input_partial_shape
(
1
);
Dimension
merged_args_rank
;
NODE_VALIDATION_ASSERT
(
this
,
Dimension
::
merge
(
merged_args_rank
,
arg0_shape
.
rank
(),
arg1_shape
.
rank
()))
<<
"Argument ranks do not match (arg0 shape: "
<<
arg0_shape
<<
", arg1 shape: "
<<
arg1_shape
<<
")."
;
NODE_VALIDATION_CHECK
(
this
,
Dimension
::
merge
(
merged_args_rank
,
arg0_shape
.
rank
(),
arg1_shape
.
rank
()),
"Argument ranks do not match (arg0 shape: "
,
arg0_shape
,
", arg1 shape: "
,
arg1_shape
,
")."
);
element
::
Type
arg0_et
=
get_input_element_type
(
0
);
element
::
Type
arg1_et
=
get_input_element_type
(
1
);
element
::
Type
merged_args_et
;
NODE_VALIDATION_ASSERT
(
this
,
element
::
Type
::
merge
(
merged_args_et
,
arg0_et
,
arg1_et
))
<<
"Argument element types do not match (arg0 element type: "
<<
arg0_et
<<
", arg1 element type: "
<<
arg1_et
<<
")."
;
NODE_VALIDATION_CHECK
(
this
,
element
::
Type
::
merge
(
merged_args_et
,
arg0_et
,
arg1_et
),
"Argument element types do not match (arg0 element type: "
,
arg0_et
,
", arg1 element type: "
,
arg1_et
,
")."
);
NODE_VALIDATION_
ASSERT
(
this
,
NODE_VALIDATION_
CHECK
(
this
,
m_lower_bounds
.
size
()
==
m_upper_bounds
.
size
()
&&
m_lower_bounds
.
size
()
==
m_strides
.
size
())
<<
"Ranks of lower bounds ("
<<
m_lower_bounds
<<
"), upper bounds ("
<<
m_upper_bounds
<<
") and strides ("
<<
m_strides
<<
") do not match."
;
m_lower_bounds
.
size
()
==
m_strides
.
size
(),
"Ranks of lower bounds ("
,
m_lower_bounds
,
"), upper bounds ("
,
m_upper_bounds
,
") and strides ("
,
m_strides
,
") do not match."
);
size_t
output_rank
=
m_upper_bounds
.
size
();
for
(
size_t
i
=
0
;
i
<
output_rank
;
i
++
)
{
NODE_VALIDATION_ASSERT
(
this
,
m_lower_bounds
[
i
]
<=
m_upper_bounds
[
i
])
<<
"Lower bound for slice is greater than upper bound at axis "
<<
i
<<
" (lower bounds: "
<<
m_lower_bounds
<<
", upper bounds: "
<<
m_upper_bounds
<<
")."
;
NODE_VALIDATION_ASSERT
(
this
,
m_strides
[
i
]
!=
0
)
<<
"Stride for slice is zero at axis "
<<
i
<<
" (strides: "
<<
m_strides
<<
")."
;
NODE_VALIDATION_CHECK
(
this
,
m_lower_bounds
[
i
]
<=
m_upper_bounds
[
i
],
"Lower bound for slice is greater than upper bound at axis "
,
i
,
" (lower bounds: "
,
m_lower_bounds
,
", upper bounds: "
,
m_upper_bounds
,
")."
);
NODE_VALIDATION_CHECK
(
this
,
m_strides
[
i
]
!=
0
,
"Stride for slice is zero at axis "
,
i
,
" (strides: "
,
m_strides
,
")."
);
}
NODE_VALIDATION_ASSERT
(
this
,
merged_args_rank
.
is_dynamic
()
||
size_t
(
merged_args_rank
)
==
output_rank
)
<<
"Argument ranks do not match the rank of the lower bounds ("
<<
m_lower_bounds
<<
"), upper bounds ("
<<
m_upper_bounds
<<
"), and strides ("
<<
m_strides
<<
")."
;
NODE_VALIDATION_CHECK
(
this
,
merged_args_rank
.
is_dynamic
()
||
size_t
(
merged_args_rank
)
==
output_rank
,
"Argument ranks do not match the rank of the lower bounds ("
,
m_lower_bounds
,
"), upper bounds ("
,
m_upper_bounds
,
"), and strides ("
,
m_strides
,
")."
);
std
::
vector
<
Dimension
>
sliced_dims
(
output_rank
);
for
(
size_t
i
=
0
;
i
<
output_rank
;
i
++
)
{
NODE_VALIDATION_
ASSERT
(
this
,
NODE_VALIDATION_
CHECK
(
this
,
arg0_shape
.
rank
().
is_dynamic
()
||
arg0_shape
[
i
].
is_dynamic
()
||
m_upper_bounds
[
i
]
<=
size_t
(
arg0_shape
[
i
]))
<<
"Upper bound for slice at axis "
<<
i
<<
" is out of range "
<<
"(upper bounds: "
<<
m_upper_bounds
<<
", argument shape: "
<<
arg0_shape
<<
")."
;
m_upper_bounds
[
i
]
<=
size_t
(
arg0_shape
[
i
]),
"Upper bound for slice at axis "
,
i
,
" is out of range "
,
"(upper bounds: "
,
m_upper_bounds
,
", argument shape: "
,
arg0_shape
,
")."
);
size_t
sliced_dim
=
m_upper_bounds
[
i
]
-
m_lower_bounds
[
i
];
sliced_dim
=
sliced_dim
/
m_strides
[
i
]
+
((
sliced_dim
%
m_strides
[
i
]
==
0
)
?
0
:
1
);
...
...
@@ -110,9 +144,14 @@ void op::UpdateSlice::validate_and_infer_types()
PartialShape
slice_shape
{
sliced_dims
};
NODE_VALIDATION_ASSERT
(
this
,
arg1_shape
.
compatible
(
slice_shape
))
<<
"Shape of replacement tensor ("
<<
arg1_shape
<<
") does not match the slice shape "
<<
"("
<<
slice_shape
<<
")."
;
NODE_VALIDATION_CHECK
(
this
,
arg1_shape
.
compatible
(
slice_shape
),
"Shape of replacement tensor ("
,
arg1_shape
,
") does not match the slice shape "
,
"("
,
slice_shape
,
")."
);
// Slight corner case here: if arg0 was rank-unknown, we can go ahead and set the output rank
// because the attribs will have given us enough info.
...
...
test/type_prop.cpp
View file @
066037c2
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment