Commit 066037c2 authored by Adam Procter's avatar Adam Procter Committed by Scott Cyphers

Switch everything from NODE_VALIDATION_ASSERT to NODE_VALIDATION_CHECK (#2546)

parent dd23b0cb
...@@ -275,19 +275,6 @@ namespace ngraph ...@@ -275,19 +275,6 @@ namespace ngraph
size_t m_placement_index = placement_invalid; size_t m_placement_index = placement_invalid;
}; };
class NodeValidationError : public AssertionFailure
{
public:
NodeValidationError(std::string what)
: AssertionFailure(what)
{
}
NodeValidationError(const char* what)
: AssertionFailure(what)
{
}
};
class NodeValidationFailure : public CheckFailure class NodeValidationFailure : public CheckFailure
{ {
public: public:
...@@ -321,12 +308,5 @@ namespace ngraph ...@@ -321,12 +308,5 @@ namespace ngraph
void check_new_args_count(const Node* node, const NodeVector& new_args); void check_new_args_count(const Node* node, const NodeVector& new_args);
} // namespace ngraph } // namespace ngraph
#define NODE_VALIDATION_ASSERT(node, cond) \
NGRAPH_ASSERT_STREAM_WITH_LOC( \
::ngraph::NodeValidationError, cond, ::ngraph::node_validation_assertion_string(node))
#define NODE_VALIDATION_FAIL(node) \
NGRAPH_FAIL_STREAM_WITH_LOC(::ngraph::NodeValidationError, \
::ngraph::node_validation_assertion_string(node))
#define NODE_VALIDATION_CHECK(node, cond, ...) \ #define NODE_VALIDATION_CHECK(node, cond, ...) \
NGRAPH_CHECK(::NodeValidationFailure, (node), (cond), __VA_ARGS__) NGRAPH_CHECK(::ngraph::NodeValidationFailure, (node), (cond), __VA_ARGS__)
...@@ -27,12 +27,13 @@ op::AllReduce::AllReduce(const shared_ptr<Node>& arg) ...@@ -27,12 +27,13 @@ op::AllReduce::AllReduce(const shared_ptr<Node>& arg)
void op::AllReduce::validate_and_infer_types() void op::AllReduce::validate_and_infer_types()
{ {
NODE_VALIDATION_ASSERT(this, NODE_VALIDATION_CHECK(this,
get_input_element_type(0).is_dynamic() || get_input_element_type(0).is_dynamic() ||
get_input_element_type(0) == element::f32 || get_input_element_type(0) == element::f32 ||
get_input_element_type(0) == element::f64) get_input_element_type(0) == element::f64,
<< "Only element types f32 and f64 are supported (argument element type: " "Only element types f32 and f64 are supported (argument element type: ",
<< get_input_element_type(0) << ")."; get_input_element_type(0),
").");
set_output_type(0, get_input_element_type(0), get_input_partial_shape(0)); set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
} }
......
...@@ -132,9 +132,15 @@ void op::AvgPoolBackprop::validate_and_infer_types() ...@@ -132,9 +132,15 @@ void op::AvgPoolBackprop::validate_and_infer_types()
const PartialShape& delta_shape = get_input_partial_shape(0); const PartialShape& delta_shape = get_input_partial_shape(0);
NODE_VALIDATION_ASSERT(this, forward_result_shape.compatible(delta_shape)) NODE_VALIDATION_CHECK(
<< "Inferred forward output shape does not match delta shape (inferred forward output " this,
<< "shape: " << forward_result_shape << ", delta shape: " << delta_shape << ")."; forward_result_shape.compatible(delta_shape),
"Inferred forward output shape does not match delta shape (inferred forward output ",
"shape: ",
forward_result_shape,
", delta shape: ",
delta_shape,
").");
// TODO(amprocte): Once m_forward_arg_shape is allowed to be dynamic, we may technically be // TODO(amprocte): Once m_forward_arg_shape is allowed to be dynamic, we may technically be
// able to infer some extra information from forward_result_shape that was not present in the // able to infer some extra information from forward_result_shape that was not present in the
......
...@@ -205,21 +205,26 @@ void ngraph::op::BatchNormTrainingBackprop::validate_and_infer_types() ...@@ -205,21 +205,26 @@ void ngraph::op::BatchNormTrainingBackprop::validate_and_infer_types()
{ {
PartialShape input_and_delta_shape{get_input_partial_shape(INPUT_DATA)}; PartialShape input_and_delta_shape{get_input_partial_shape(INPUT_DATA)};
NODE_VALIDATION_ASSERT( NODE_VALIDATION_CHECK(
this, PartialShape::merge_into(input_and_delta_shape, get_input_partial_shape(INPUT_DELTA))) this,
<< "Shape of delta does not match the shape of the input data (input data shape: " PartialShape::merge_into(input_and_delta_shape, get_input_partial_shape(INPUT_DELTA)),
<< get_input_partial_shape(INPUT_DATA) "Shape of delta does not match the shape of the input data (input data shape: ",
<< ", delta shape: " << get_input_partial_shape(INPUT_DELTA) << ")."; get_input_partial_shape(INPUT_DATA),
", delta shape: ",
get_input_partial_shape(INPUT_DELTA),
").");
element::Type input_and_delta_et; element::Type input_and_delta_et;
NODE_VALIDATION_ASSERT(this, NODE_VALIDATION_CHECK(this,
element::Type::merge(input_and_delta_et, element::Type::merge(input_and_delta_et,
get_input_element_type(INPUT_DATA), get_input_element_type(INPUT_DATA),
get_input_element_type(INPUT_DELTA))) get_input_element_type(INPUT_DELTA)),
<< "Element type for input (" << get_input_element_type(INPUT_DATA) "Element type for input (",
<< ") does not match element type for delta (" << get_input_element_type(INPUT_DATA) get_input_element_type(INPUT_DATA),
<< ")."; ") does not match element type for delta (",
get_input_element_type(INPUT_DATA),
").");
element::Type result_et; element::Type result_et;
PartialShape result_batch_shape; PartialShape result_batch_shape;
......
...@@ -44,9 +44,16 @@ void op::Broadcast::validate_and_infer_types() ...@@ -44,9 +44,16 @@ void op::Broadcast::validate_and_infer_types()
for (auto axis : m_broadcast_axes) for (auto axis : m_broadcast_axes)
{ {
NODE_VALIDATION_ASSERT(this, axis < m_shape.size()) NODE_VALIDATION_CHECK(this,
<< "Broadcast axis index (" << axis << ") exceeds specified output shape rank " axis < m_shape.size(),
<< "(broadcast axes: " << m_broadcast_axes << ", output shape: " << m_shape << ")."; "Broadcast axis index (",
axis,
") exceeds specified output shape rank ",
"(broadcast axes: ",
m_broadcast_axes,
", output shape: ",
m_shape,
").");
} }
Shape required_input_shape = m_shape; Shape required_input_shape = m_shape;
...@@ -59,10 +66,17 @@ void op::Broadcast::validate_and_infer_types() ...@@ -59,10 +66,17 @@ void op::Broadcast::validate_and_infer_types()
// There are two things that can go wrong, which are being picked up in // There are two things that can go wrong, which are being picked up in
// one fell swoop by this check: either the number of broadcast axes is not // one fell swoop by this check: either the number of broadcast axes is not
// enough, or there is a mismatch with one of the pre-broadcast axis lengths. // enough, or there is a mismatch with one of the pre-broadcast axis lengths.
NODE_VALIDATION_ASSERT(this, get_input_partial_shape(0).compatible(required_input_shape)) NODE_VALIDATION_CHECK(
<< "Broadcast argument shape, specified output shape, and axes are incompatible " this,
<< "(argument shape: " << get_input_partial_shape(0) << ", output shape: " << m_shape get_input_partial_shape(0).compatible(required_input_shape),
<< ", broadcast axes: " << m_broadcast_axes << ")."; "Broadcast argument shape, specified output shape, and axes are incompatible ",
"(argument shape: ",
get_input_partial_shape(0),
", output shape: ",
m_shape,
", broadcast axes: ",
m_broadcast_axes,
").");
set_output_type(0, get_input_element_type(0), m_shape); set_output_type(0, get_input_element_type(0), m_shape);
} }
......
...@@ -32,7 +32,7 @@ op::Concat::Concat(const NodeVector& args, size_t concatenation_axis) ...@@ -32,7 +32,7 @@ op::Concat::Concat(const NodeVector& args, size_t concatenation_axis)
void op::Concat::validate_and_infer_types() void op::Concat::validate_and_infer_types()
{ {
NODE_VALIDATION_ASSERT(this, m_inputs.size() >= 1) << "At least one argument required."; NODE_VALIDATION_CHECK(this, m_inputs.size() >= 1, "At least one argument required.");
PartialShape inputs_shape_scheme{PartialShape::dynamic()}; PartialShape inputs_shape_scheme{PartialShape::dynamic()};
element::Type inputs_et{element::dynamic}; element::Type inputs_et{element::dynamic};
...@@ -44,22 +44,32 @@ void op::Concat::validate_and_infer_types() ...@@ -44,22 +44,32 @@ void op::Concat::validate_and_infer_types()
Dimension this_input_rank = this_input_shape.rank(); Dimension this_input_rank = this_input_shape.rank();
if (this_input_rank.is_static()) if (this_input_rank.is_static())
{ {
NODE_VALIDATION_ASSERT(this, m_concatenation_axis < size_t(this_input_rank)) NODE_VALIDATION_CHECK(this,
<< "Concatenation axis (" << m_concatenation_axis << ") is out of bounds for " m_concatenation_axis < size_t(this_input_rank),
<< "argument " << i << ", which has shape " << this_input_shape << "."; "Concatenation axis (",
m_concatenation_axis,
") is out of bounds for ",
"argument ",
i,
", which has shape ",
this_input_shape,
".");
concatenation_axis_output_dim += this_input_shape[m_concatenation_axis]; concatenation_axis_output_dim += this_input_shape[m_concatenation_axis];
this_input_shape[m_concatenation_axis] = Dimension::dynamic(); this_input_shape[m_concatenation_axis] = Dimension::dynamic();
NODE_VALIDATION_ASSERT(this, NODE_VALIDATION_CHECK(
PartialShape::merge_into(inputs_shape_scheme, this_input_shape)) this,
<< "Argument shapes are inconsistent; they must have the same rank, and must have " PartialShape::merge_into(inputs_shape_scheme, this_input_shape),
<< "equal dimension everywhere except on the concatenation axis (axis " "Argument shapes are inconsistent; they must have the same rank, and must have ",
<< m_concatenation_axis << ")."; "equal dimension everywhere except on the concatenation axis (axis ",
m_concatenation_axis,
NODE_VALIDATION_ASSERT( ").");
this, element::Type::merge(inputs_et, inputs_et, get_input_element_type(i)))
<< "Argument element types are inconsistent."; NODE_VALIDATION_CHECK(
this,
element::Type::merge(inputs_et, inputs_et, get_input_element_type(i)),
"Argument element types are inconsistent.");
} }
else else
{ {
......
...@@ -47,11 +47,17 @@ namespace ngraph ...@@ -47,11 +47,17 @@ namespace ngraph
, m_data(ngraph::aligned_alloc(m_element_type.size(), , m_data(ngraph::aligned_alloc(m_element_type.size(),
shape_size(m_shape) * m_element_type.size())) shape_size(m_shape) * m_element_type.size()))
{ {
NODE_VALIDATION_ASSERT(this, NODE_VALIDATION_CHECK(
values.size() == 1 || values.size() == shape_size(m_shape)) this,
<< "Did not get the expected number of literals for a constant of shape " values.size() == 1 || values.size() == shape_size(m_shape),
<< m_shape << " (got " << values.size() << ", expected " "Did not get the expected number of literals for a constant of shape ",
<< (shape_size(m_shape) == 1 ? "" : "1 or ") << shape_size(m_shape) << ")."; m_shape,
" (got ",
values.size(),
", expected ",
(shape_size(m_shape) == 1 ? "" : "1 or "),
shape_size(m_shape),
").");
if (values.size() == 1) if (values.size() == 1)
{ {
...@@ -77,10 +83,16 @@ namespace ngraph ...@@ -77,10 +83,16 @@ namespace ngraph
, m_data(ngraph::aligned_alloc(m_element_type.size(), , m_data(ngraph::aligned_alloc(m_element_type.size(),
shape_size(m_shape) * m_element_type.size())) shape_size(m_shape) * m_element_type.size()))
{ {
NODE_VALIDATION_ASSERT(this, values.size() == shape_size(m_shape)) NODE_VALIDATION_CHECK(
<< "Did not get the expected number of literals for a constant of shape " this,
<< m_shape << " (got " << values.size() << ", expected " << shape_size(m_shape) values.size() == shape_size(m_shape),
<< "."; "Did not get the expected number of literals for a constant of shape ",
m_shape,
" (got ",
values.size(),
", expected ",
shape_size(m_shape),
".");
std::vector<double> dvalues = parse_string<double>(values); std::vector<double> dvalues = parse_string<double>(values);
write_values(dvalues); write_values(dvalues);
......
This diff is collapsed.
...@@ -42,50 +42,73 @@ void op::Dequantize::validate_and_infer_types() ...@@ -42,50 +42,73 @@ void op::Dequantize::validate_and_infer_types()
OFFSET OFFSET
}; };
NODE_VALIDATION_ASSERT(this, m_type.is_static()) << "Output element type must not be dynamic"; NODE_VALIDATION_CHECK(this, m_type.is_static(), "Output element type must not be dynamic");
NODE_VALIDATION_ASSERT(this, m_type.is_real()) << "Output element type (" << m_type NODE_VALIDATION_CHECK(
<< ") must be a floating point type"; this, m_type.is_real(), "Output element type (", m_type, ") must be a floating point type");
element::Type quantized_type; element::Type quantized_type;
NODE_VALIDATION_ASSERT(this, NODE_VALIDATION_CHECK(this,
element::Type::merge(quantized_type, element::Type::merge(quantized_type,
get_input_element_type(INPUT), get_input_element_type(INPUT),
get_input_element_type(OFFSET))) get_input_element_type(OFFSET)),
<< "Offset element type (" << get_input_element_type(OFFSET) "Offset element type (",
<< ") must match input element type (" << get_input_element_type(INPUT) << ")"; get_input_element_type(OFFSET),
") must match input element type (",
NODE_VALIDATION_ASSERT(this, quantized_type.is_dynamic() || quantized_type.is_quantized()) get_input_element_type(INPUT),
<< "Offset/input element type (" << quantized_type << ") must be a quantized type"; ")");
NODE_VALIDATION_CHECK(this,
quantized_type.is_dynamic() || quantized_type.is_quantized(),
"Offset/input element type (",
quantized_type,
") must be a quantized type");
element::Type unquantized_type; element::Type unquantized_type;
NODE_VALIDATION_ASSERT( NODE_VALIDATION_CHECK(
this, element::Type::merge(unquantized_type, get_input_element_type(SCALE), m_type)) this,
<< "Scale element type (" << get_input_element_type(SCALE) element::Type::merge(unquantized_type, get_input_element_type(SCALE), m_type),
<< ") must match output element type (" << m_type << ")"; "Scale element type (",
get_input_element_type(SCALE),
") must match output element type (",
m_type,
")");
PartialShape input_shape = get_input_partial_shape(0); PartialShape input_shape = get_input_partial_shape(0);
Dimension input_rank = input_shape.rank(); Dimension input_rank = input_shape.rank();
for (auto axis : m_axes) for (auto axis : m_axes)
{ {
NODE_VALIDATION_ASSERT(this, input_rank.is_dynamic() || axis < size_t(input_rank)) NODE_VALIDATION_CHECK(this,
<< "Quantization axis (" << axis << ") must be less than input shape rank (" input_rank.is_dynamic() || axis < size_t(input_rank),
<< input_rank << ")"; "Quantization axis (",
axis,
") must be less than input shape rank (",
input_rank,
")");
} }
PartialShape scale_offset_shape = get_input_partial_shape(SCALE); PartialShape scale_offset_shape = get_input_partial_shape(SCALE);
NODE_VALIDATION_ASSERT( NODE_VALIDATION_CHECK(
this, PartialShape::merge_into(scale_offset_shape, get_input_partial_shape(OFFSET))) this,
<< "Scale shape (" << get_input_partial_shape(SCALE) << ") and offset shape (" PartialShape::merge_into(scale_offset_shape, get_input_partial_shape(OFFSET)),
<< get_input_partial_shape(OFFSET) << ") must match"; "Scale shape (",
get_input_partial_shape(SCALE),
NODE_VALIDATION_ASSERT(this, scale_offset_shape.rank().compatible(m_axes.size())) ") and offset shape (",
<< "Scale/offset rank (" << scale_offset_shape.rank() << ") does not match the number of " get_input_partial_shape(OFFSET),
<< "quantization axes (" << m_axes.size() << ")"; ") must match");
NODE_VALIDATION_CHECK(this,
scale_offset_shape.rank().compatible(m_axes.size()),
"Scale/offset rank (",
scale_offset_shape.rank(),
") does not match the number of ",
"quantization axes (",
m_axes.size(),
")");
set_output_size(1); set_output_size(1);
...@@ -108,10 +131,16 @@ void op::Dequantize::validate_and_infer_types() ...@@ -108,10 +131,16 @@ void op::Dequantize::validate_and_infer_types()
} }
PartialShape result_shape = input_shape; PartialShape result_shape = input_shape;
NODE_VALIDATION_ASSERT( NODE_VALIDATION_CHECK(
this, PartialShape::merge_into(result_shape, PartialShape{injected_scale_offset_dims})) this,
<< "Scale/offset shape (" << scale_offset_shape << ") must match input shape (" PartialShape::merge_into(result_shape, PartialShape{injected_scale_offset_dims}),
<< input_shape << ") at the quantization axes (" << m_axes << ")"; "Scale/offset shape (",
scale_offset_shape,
") must match input shape (",
input_shape,
") at the quantization axes (",
m_axes,
")");
set_output_type(0, unquantized_type, result_shape); set_output_type(0, unquantized_type, result_shape);
} }
else else
......
...@@ -49,11 +49,14 @@ void op::Dot::validate_and_infer_types() ...@@ -49,11 +49,14 @@ void op::Dot::validate_and_infer_types()
{ {
element::Type result_et; element::Type result_et;
NODE_VALIDATION_ASSERT( NODE_VALIDATION_CHECK(
this, element::Type::merge(result_et, get_input_element_type(0), get_input_element_type(1))) this,
<< "Arguments do not have the same element type (arg0 element type: " element::Type::merge(result_et, get_input_element_type(0), get_input_element_type(1)),
<< get_input_element_type(0) << ", arg1 element type: " << get_input_element_type(1) "Arguments do not have the same element type (arg0 element type: ",
<< ")."; get_input_element_type(0),
", arg1 element type: ",
get_input_element_type(1),
").");
const PartialShape& arg0_shape = get_input_partial_shape(0); const PartialShape& arg0_shape = get_input_partial_shape(0);
const PartialShape& arg1_shape = get_input_partial_shape(1); const PartialShape& arg1_shape = get_input_partial_shape(1);
...@@ -82,17 +85,27 @@ void op::Dot::validate_and_infer_types() ...@@ -82,17 +85,27 @@ void op::Dot::validate_and_infer_types()
PartialShape result_shape; PartialShape result_shape;
NODE_VALIDATION_ASSERT(this, NODE_VALIDATION_CHECK(this,
reduction_axes_ambiguous || arg0_shape.rank().is_dynamic() || reduction_axes_ambiguous || arg0_shape.rank().is_dynamic() ||
m_reduction_axes_count <= size_t(arg0_shape.rank())) m_reduction_axes_count <= size_t(arg0_shape.rank()),
<< "Reduction axes count (" << m_reduction_axes_count "Reduction axes count (",
<< ") is too large (arg0 shape: " << arg0_shape << ", arg1 shape: " << arg1_shape << ")."; m_reduction_axes_count,
") is too large (arg0 shape: ",
NODE_VALIDATION_ASSERT(this, arg0_shape,
reduction_axes_ambiguous || arg1_shape.rank().is_dynamic() || ", arg1 shape: ",
m_reduction_axes_count <= size_t(arg1_shape.rank())) arg1_shape,
<< "Reduction axes count (" << m_reduction_axes_count ").");
<< ") is too large (arg0 shape: " << arg0_shape << ", arg1 shape: " << arg1_shape << ").";
NODE_VALIDATION_CHECK(this,
reduction_axes_ambiguous || arg1_shape.rank().is_dynamic() ||
m_reduction_axes_count <= size_t(arg1_shape.rank()),
"Reduction axes count (",
m_reduction_axes_count,
") is too large (arg0 shape: ",
arg0_shape,
", arg1 shape: ",
arg1_shape,
").");
if (!reduction_axes_ambiguous && arg0_shape.rank().is_static() && arg1_shape.rank().is_static()) if (!reduction_axes_ambiguous && arg0_shape.rank().is_static() && arg1_shape.rank().is_static())
{ {
...@@ -101,12 +114,20 @@ void op::Dot::validate_and_infer_types() ...@@ -101,12 +114,20 @@ void op::Dot::validate_and_infer_types()
size_t axis_index_arg0 = size_t(arg0_shape.rank()) - m_reduction_axes_count + i; size_t axis_index_arg0 = size_t(arg0_shape.rank()) - m_reduction_axes_count + i;
size_t axis_index_arg1 = i; size_t axis_index_arg1 = i;
NODE_VALIDATION_ASSERT( NODE_VALIDATION_CHECK(
this, arg0_shape[axis_index_arg0].compatible(arg1_shape[axis_index_arg1])) this,
<< "Paired axes (axis " << axis_index_arg0 << " from arg0, axis " << axis_index_arg1 arg0_shape[axis_index_arg0].compatible(arg1_shape[axis_index_arg1]),
<< " from arg1) do not have same length (arg0 shape: " << arg0_shape "Paired axes (axis ",
<< ", arg1 shape: " << arg1_shape axis_index_arg0,
<< ", reduction axes count: " << m_reduction_axes_count << ")."; " from arg0, axis ",
axis_index_arg1,
" from arg1) do not have same length (arg0 shape: ",
arg0_shape,
", arg1 shape: ",
arg1_shape,
", reduction axes count: ",
m_reduction_axes_count,
").");
} }
std::vector<Dimension> result_dims(size_t(arg0_shape.rank()) + size_t(arg1_shape.rank()) - std::vector<Dimension> result_dims(size_t(arg0_shape.rank()) + size_t(arg1_shape.rank()) -
......
...@@ -26,9 +26,10 @@ void op::EmbeddingLookup::validate_and_infer_types() ...@@ -26,9 +26,10 @@ void op::EmbeddingLookup::validate_and_infer_types()
const PartialShape& arg0_shape = get_input_partial_shape(0); const PartialShape& arg0_shape = get_input_partial_shape(0);
const PartialShape& arg1_shape = get_input_partial_shape(1); const PartialShape& arg1_shape = get_input_partial_shape(1);
NODE_VALIDATION_ASSERT( NODE_VALIDATION_CHECK(this,
this, arg1_shape.rank().is_dynamic() || static_cast<size_t>(arg1_shape.rank()) == 2) arg1_shape.rank().is_dynamic() ||
<< "weights are expected to be a matrix"; static_cast<size_t>(arg1_shape.rank()) == 2,
"weights are expected to be a matrix");
PartialShape result_shape; PartialShape result_shape;
if (arg0_shape.rank().is_static()) if (arg0_shape.rank().is_static())
......
...@@ -42,11 +42,12 @@ shared_ptr<Node> op::GenerateMask::copy_with_new_args(const NodeVector& new_args ...@@ -42,11 +42,12 @@ shared_ptr<Node> op::GenerateMask::copy_with_new_args(const NodeVector& new_args
void ngraph::op::GenerateMask::validate_and_infer_types() void ngraph::op::GenerateMask::validate_and_infer_types()
{ {
NODE_VALIDATION_ASSERT(this, get_input_partial_shape(0).compatible(PartialShape{})) NODE_VALIDATION_CHECK(this,
<< "Training node should be a scalar flag indicating a mode"; get_input_partial_shape(0).compatible(PartialShape{}),
"Training node should be a scalar flag indicating a mode");
NODE_VALIDATION_ASSERT(this, m_element_type.is_static()) NODE_VALIDATION_CHECK(
<< "Output element type must not be dynamic."; this, m_element_type.is_static(), "Output element type must not be dynamic.");
set_output_type(0, m_element_type, m_shape); set_output_type(0, m_element_type, m_shape);
} }
...@@ -65,36 +65,58 @@ void op::QuantizedAvgPool::validate_and_infer_types() ...@@ -65,36 +65,58 @@ void op::QuantizedAvgPool::validate_and_infer_types()
// Make sure batch size and channel count are not zero, and that we have at least one spatial // Make sure batch size and channel count are not zero, and that we have at least one spatial
// dimension (in other words, that arg has shape NCDi for some Di of rank>0, N != 0, C != 0). // dimension (in other words, that arg has shape NCDi for some Di of rank>0, N != 0, C != 0).
// //
NODE_VALIDATION_ASSERT(this, arg_shape.size() >= 3) NODE_VALIDATION_CHECK(this,
<< "Data input shape does not have rank of at least 3 (data input shape: " << arg_shape arg_shape.size() >= 3,
<< ")."; "Data input shape does not have rank of at least 3 (data input shape: ",
arg_shape,
").");
size_t batch_size = arg_shape[0]; size_t batch_size = arg_shape[0];
NODE_VALIDATION_ASSERT(this, batch_size != 0) NODE_VALIDATION_CHECK(
<< "Data batch size is zero (data input shape: " << arg_shape << ")."; this, batch_size != 0, "Data batch size is zero (data input shape: ", arg_shape, ").");
size_t channel_count = arg_shape[1]; size_t channel_count = arg_shape[1];
NODE_VALIDATION_ASSERT(this, channel_count != 0) NODE_VALIDATION_CHECK(
<< "Channel count is zero (data input shape: " << arg_shape << ")."; this, channel_count != 0, "Channel count is zero (data input shape: ", arg_shape, ").");
size_t spatial_dimension_count = arg_shape.size() - 2; size_t spatial_dimension_count = arg_shape.size() - 2;
// //
// Make sure window shape, window movement strides, and padding have same rank as Di. // Make sure window shape, window movement strides, and padding have same rank as Di.
// //
NODE_VALIDATION_ASSERT(this, m_window_shape.size() == spatial_dimension_count) NODE_VALIDATION_CHECK(
<< "Window shape rank does not match number of spatial dimensions (window shape: " this,
<< m_window_shape << ", data input shape: " << arg_shape << ")."; m_window_shape.size() == spatial_dimension_count,
NODE_VALIDATION_ASSERT(this, m_window_movement_strides.size() == spatial_dimension_count) "Window shape rank does not match number of spatial dimensions (window shape: ",
<< "Window movement stride rank does not match number of spatial dimensions (window " m_window_shape,
"movement strides: " ", data input shape: ",
<< m_window_movement_strides << ", data input shape: " << arg_shape << ")."; arg_shape,
NODE_VALIDATION_ASSERT(this, m_padding_below.size() == spatial_dimension_count) ").");
<< "Below-padding rank does not match number of spatial dimensions (padding below: " NODE_VALIDATION_CHECK(
<< m_padding_below << ", data input shape: " << arg_shape << ")."; this,
NODE_VALIDATION_ASSERT(this, m_padding_above.size() == spatial_dimension_count) m_window_movement_strides.size() == spatial_dimension_count,
<< "Above-padding rank does not match number of spatial dimensions (padding above: " "Window movement stride rank does not match number of spatial dimensions (window "
<< m_padding_above << ", data input shape: " << arg_shape << ")."; "movement strides: ",
m_window_movement_strides,
", data input shape: ",
arg_shape,
").");
NODE_VALIDATION_CHECK(
this,
m_padding_below.size() == spatial_dimension_count,
"Below-padding rank does not match number of spatial dimensions (padding below: ",
m_padding_below,
", data input shape: ",
arg_shape,
").");
NODE_VALIDATION_CHECK(
this,
m_padding_above.size() == spatial_dimension_count,
"Above-padding rank does not match number of spatial dimensions (padding above: ",
m_padding_above,
", data input shape: ",
arg_shape,
").");
// //
// Extract input item shape Di and make sure all dimensions are larger than 0. // Extract input item shape Di and make sure all dimensions are larger than 0.
...@@ -110,10 +132,13 @@ void op::QuantizedAvgPool::validate_and_infer_types() ...@@ -110,10 +132,13 @@ void op::QuantizedAvgPool::validate_and_infer_types()
for (size_t i = 0; i < spatial_dimension_count; i++) for (size_t i = 0; i < spatial_dimension_count; i++)
{ {
NODE_VALIDATION_ASSERT(this, input_item_virtual_shape[i] != 0) NODE_VALIDATION_CHECK(this,
<< "Data input spatial dimension " << i input_item_virtual_shape[i] != 0,
<< " has zero length even after padding (virtual shape of input item: " "Data input spatial dimension ",
<< input_item_virtual_shape << ")."; i,
" has zero length even after padding (virtual shape of input item: ",
input_item_virtual_shape,
").");
} }
// //
...@@ -121,9 +146,13 @@ void op::QuantizedAvgPool::validate_and_infer_types() ...@@ -121,9 +146,13 @@ void op::QuantizedAvgPool::validate_and_infer_types()
// //
for (size_t i = 0; i < spatial_dimension_count; i++) for (size_t i = 0; i < spatial_dimension_count; i++)
{ {
NODE_VALIDATION_ASSERT(this, m_window_shape[i] != 0) NODE_VALIDATION_CHECK(this,
<< "Window shape dimension " << i m_window_shape[i] != 0,
<< " has zero length (window shape: " << m_window_shape << ")."; "Window shape dimension ",
i,
" has zero length (window shape: ",
m_window_shape,
").");
} }
// //
...@@ -131,10 +160,14 @@ void op::QuantizedAvgPool::validate_and_infer_types() ...@@ -131,10 +160,14 @@ void op::QuantizedAvgPool::validate_and_infer_types()
// //
for (size_t i = 0; i < spatial_dimension_count; i++) for (size_t i = 0; i < spatial_dimension_count; i++)
{ {
NODE_VALIDATION_ASSERT(this, m_window_shape[i] <= input_item_virtual_shape[i]) NODE_VALIDATION_CHECK(
<< "Window shape after padding is larger than the spatial dimensions (window shape: " this,
<< m_window_shape << ", virtual shape of input item: " << input_item_virtual_shape m_window_shape[i] <= input_item_virtual_shape[i],
<< ")."; "Window shape after padding is larger than the spatial dimensions (window shape: ",
m_window_shape,
", virtual shape of input item: ",
input_item_virtual_shape,
").");
} }
// //
// Compute output item shape Do, checking at the same time that all window movement strides are larger than 0. // Compute output item shape Do, checking at the same time that all window movement strides are larger than 0.
...@@ -143,9 +176,13 @@ void op::QuantizedAvgPool::validate_and_infer_types() ...@@ -143,9 +176,13 @@ void op::QuantizedAvgPool::validate_and_infer_types()
for (size_t i = 0; i < spatial_dimension_count; i++) for (size_t i = 0; i < spatial_dimension_count; i++)
{ {
NODE_VALIDATION_ASSERT(this, m_window_movement_strides[i] != 0) NODE_VALIDATION_CHECK(this,
<< "Window movement strides dimension " << i m_window_movement_strides[i] != 0,
<< " has zero length (window movement strides: " << m_window_movement_strides << ")."; "Window movement strides dimension ",
i,
" has zero length (window movement strides: ",
m_window_movement_strides,
").");
output_item_shape.push_back(ceil_div(input_item_virtual_shape[i] - m_window_shape[i] + 1, output_item_shape.push_back(ceil_div(input_item_virtual_shape[i] - m_window_shape[i] + 1,
m_window_movement_strides[i])); m_window_movement_strides[i]));
} }
...@@ -167,11 +204,15 @@ void op::QuantizedAvgPool::validate_and_infer_types() ...@@ -167,11 +204,15 @@ void op::QuantizedAvgPool::validate_and_infer_types()
// Checking the lower edge of each dimension is easy, because there's no mystery // Checking the lower edge of each dimension is easy, because there's no mystery
// regarding the window's lower-edge placement... // regarding the window's lower-edge placement...
NODE_VALIDATION_ASSERT(this, NODE_VALIDATION_CHECK(
dim_padding_below == 0 || dim_window_size > dim_padding_below) this,
<< "Window will sometimes reside entirely within the below-padding region, but" dim_padding_below == 0 || dim_window_size > dim_padding_below,
<< " include_padding_in_avg_computation was not set (padding below: " "Window will sometimes reside entirely within the below-padding region, but",
<< m_padding_below << ", window shape: " << m_window_shape << ")."; " include_padding_in_avg_computation was not set (padding below: ",
m_padding_below,
", window shape: ",
m_window_shape,
").");
// Now check the upper-bound... // Now check the upper-bound...
{ {
...@@ -179,13 +220,16 @@ void op::QuantizedAvgPool::validate_and_infer_types() ...@@ -179,13 +220,16 @@ void op::QuantizedAvgPool::validate_and_infer_types()
const size_t dim_window_max_lower_offset = dim_num_strides * dim_stride; const size_t dim_window_max_lower_offset = dim_num_strides * dim_stride;
const size_t dim_padding_above_start_offset = dim_virtual_size - dim_padding_above; const size_t dim_padding_above_start_offset = dim_virtual_size - dim_padding_above;
NODE_VALIDATION_ASSERT(this, NODE_VALIDATION_CHECK(
dim_padding_above == 0 || this,
dim_window_max_lower_offset < dim_padding_above == 0 ||
dim_padding_above_start_offset) dim_window_max_lower_offset < dim_padding_above_start_offset,
<< "Window will sometimes reside entirely within the above-padding region, but" "Window will sometimes reside entirely within the above-padding region, but",
<< " include_padding_in_avg_computation was not set (padding above: " " include_padding_in_avg_computation was not set (padding above: ",
<< m_padding_above << ", window shape: " << m_window_shape << ")."; m_padding_above,
", window shape: ",
m_window_shape,
").");
} }
} }
} }
......
...@@ -33,7 +33,7 @@ op::QuantizedConcat::QuantizedConcat(const NodeVector& args, size_t concatenatio ...@@ -33,7 +33,7 @@ op::QuantizedConcat::QuantizedConcat(const NodeVector& args, size_t concatenatio
void op::QuantizedConcat::validate_and_infer_types() void op::QuantizedConcat::validate_and_infer_types()
{ {
NODE_VALIDATION_ASSERT(this, m_inputs.size() >= 1) << "At least one argument required."; NODE_VALIDATION_CHECK(this, m_inputs.size() >= 1, "At least one argument required.");
PartialShape inputs_shape_scheme{PartialShape::dynamic()}; PartialShape inputs_shape_scheme{PartialShape::dynamic()};
element::Type inputs_et{element::dynamic}; element::Type inputs_et{element::dynamic};
...@@ -45,23 +45,32 @@ void op::QuantizedConcat::validate_and_infer_types() ...@@ -45,23 +45,32 @@ void op::QuantizedConcat::validate_and_infer_types()
Dimension this_input_rank = this_input_shape.rank(); Dimension this_input_rank = this_input_shape.rank();
if (this_input_rank.is_static()) if (this_input_rank.is_static())
{ {
NODE_VALIDATION_ASSERT(this, m_concatenation_axis < size_t(this_input_rank)) NODE_VALIDATION_CHECK(this,
<< "QuantizedConcatenation axis (" << m_concatenation_axis m_concatenation_axis < size_t(this_input_rank),
<< ") is out of bounds for " "QuantizedConcatenation axis (",
<< "argument " << i << ", which has shape " << this_input_shape << "."; m_concatenation_axis,
") is out of bounds for ",
"argument ",
i,
", which has shape ",
this_input_shape,
".");
concatenation_axis_output_dim += this_input_shape[m_concatenation_axis]; concatenation_axis_output_dim += this_input_shape[m_concatenation_axis];
this_input_shape[m_concatenation_axis] = Dimension::dynamic(); this_input_shape[m_concatenation_axis] = Dimension::dynamic();
NODE_VALIDATION_ASSERT(this, NODE_VALIDATION_CHECK(
PartialShape::merge_into(inputs_shape_scheme, this_input_shape)) this,
<< "Argument shapes are inconsistent; they must have the same rank, and must have " PartialShape::merge_into(inputs_shape_scheme, this_input_shape),
<< "equal dimension everywhere except on the concatenation axis (axis " "Argument shapes are inconsistent; they must have the same rank, and must have ",
<< m_concatenation_axis << ")."; "equal dimension everywhere except on the concatenation axis (axis ",
m_concatenation_axis,
").");
NODE_VALIDATION_ASSERT( NODE_VALIDATION_CHECK(
this, element::Type::merge(inputs_et, inputs_et, get_input_element_type(i))) this,
<< "Argument element types are inconsistent."; element::Type::merge(inputs_et, inputs_et, get_input_element_type(i)),
"Argument element types are inconsistent.");
} }
else else
{ {
......
...@@ -64,36 +64,58 @@ void op::QuantizedMaxPool::validate_and_infer_types() ...@@ -64,36 +64,58 @@ void op::QuantizedMaxPool::validate_and_infer_types()
// Make sure batch size and channel count are not zero, and that we have at least one spatial // Make sure batch size and channel count are not zero, and that we have at least one spatial
// dimension (in other words, that arg has shape NCDi for some Di of rank>0, N != 0, C != 0). // dimension (in other words, that arg has shape NCDi for some Di of rank>0, N != 0, C != 0).
// //
NODE_VALIDATION_ASSERT(this, arg_shape.size() >= 3) NODE_VALIDATION_CHECK(this,
<< "Data input shape does not have rank of at least 3 (data input shape: " << arg_shape arg_shape.size() >= 3,
<< ")."; "Data input shape does not have rank of at least 3 (data input shape: ",
arg_shape,
").");
size_t batch_size = arg_shape[0]; size_t batch_size = arg_shape[0];
NODE_VALIDATION_ASSERT(this, batch_size != 0) NODE_VALIDATION_CHECK(
<< "Data batch size is zero (data input shape: " << arg_shape << ")."; this, batch_size != 0, "Data batch size is zero (data input shape: ", arg_shape, ").");
size_t channel_count = arg_shape[1]; size_t channel_count = arg_shape[1];
NODE_VALIDATION_ASSERT(this, channel_count != 0) NODE_VALIDATION_CHECK(
<< "Channel count is zero (data input shape: " << arg_shape << ")."; this, channel_count != 0, "Channel count is zero (data input shape: ", arg_shape, ").");
size_t spatial_dimension_count = arg_shape.size() - 2; size_t spatial_dimension_count = arg_shape.size() - 2;
// //
// Make sure window shape, window movement strides, and padding have same rank as Di. // Make sure window shape, window movement strides, and padding have same rank as Di.
// //
NODE_VALIDATION_ASSERT(this, m_window_shape.size() == spatial_dimension_count) NODE_VALIDATION_CHECK(
<< "Window shape rank does not match number of spatial dimensions (window shape: " this,
<< m_window_shape << ", data input shape: " << arg_shape << ")."; m_window_shape.size() == spatial_dimension_count,
NODE_VALIDATION_ASSERT(this, m_window_movement_strides.size() == spatial_dimension_count) "Window shape rank does not match number of spatial dimensions (window shape: ",
<< "Window movement stride rank does not match number of spatial dimensions (window " m_window_shape,
"movement strides: " ", data input shape: ",
<< m_window_movement_strides << ", data input shape: " << arg_shape << ")."; arg_shape,
NODE_VALIDATION_ASSERT(this, m_padding_below.size() == spatial_dimension_count) ").");
<< "Below-padding rank does not match number of spatial dimensions (padding below: " NODE_VALIDATION_CHECK(
<< m_padding_below << ", data input shape: " << arg_shape << ")."; this,
NODE_VALIDATION_ASSERT(this, m_padding_above.size() == spatial_dimension_count) m_window_movement_strides.size() == spatial_dimension_count,
<< "Above-padding rank does not match number of spatial dimensions (padding above: " "Window movement stride rank does not match number of spatial dimensions (window "
<< m_padding_above << ", data input shape: " << arg_shape << ")."; "movement strides: ",
m_window_movement_strides,
", data input shape: ",
arg_shape,
").");
NODE_VALIDATION_CHECK(
this,
m_padding_below.size() == spatial_dimension_count,
"Below-padding rank does not match number of spatial dimensions (padding below: ",
m_padding_below,
", data input shape: ",
arg_shape,
").");
NODE_VALIDATION_CHECK(
this,
m_padding_above.size() == spatial_dimension_count,
"Above-padding rank does not match number of spatial dimensions (padding above: ",
m_padding_above,
", data input shape: ",
arg_shape,
").");
// //
// Extract input item shape Di and make sure all dimensions are larger than 0. // Extract input item shape Di and make sure all dimensions are larger than 0.
...@@ -109,10 +131,13 @@ void op::QuantizedMaxPool::validate_and_infer_types() ...@@ -109,10 +131,13 @@ void op::QuantizedMaxPool::validate_and_infer_types()
for (size_t i = 0; i < spatial_dimension_count; i++) for (size_t i = 0; i < spatial_dimension_count; i++)
{ {
NODE_VALIDATION_ASSERT(this, input_item_virtual_shape[i] != 0) NODE_VALIDATION_CHECK(this,
<< "Data input spatial dimension " << i input_item_virtual_shape[i] != 0,
<< " has zero length even after padding (virtual shape of input item: " "Data input spatial dimension ",
<< input_item_virtual_shape << ")."; i,
" has zero length even after padding (virtual shape of input item: ",
input_item_virtual_shape,
").");
} }
// //
...@@ -120,9 +145,13 @@ void op::QuantizedMaxPool::validate_and_infer_types() ...@@ -120,9 +145,13 @@ void op::QuantizedMaxPool::validate_and_infer_types()
// //
for (size_t i = 0; i < spatial_dimension_count; i++) for (size_t i = 0; i < spatial_dimension_count; i++)
{ {
NODE_VALIDATION_ASSERT(this, m_window_shape[i] != 0) NODE_VALIDATION_CHECK(this,
<< "Window shape dimension " << i m_window_shape[i] != 0,
<< " has zero length (window shape: " << m_window_shape << ")."; "Window shape dimension ",
i,
" has zero length (window shape: ",
m_window_shape,
").");
} }
// //
...@@ -130,10 +159,14 @@ void op::QuantizedMaxPool::validate_and_infer_types() ...@@ -130,10 +159,14 @@ void op::QuantizedMaxPool::validate_and_infer_types()
// //
for (size_t i = 0; i < spatial_dimension_count; i++) for (size_t i = 0; i < spatial_dimension_count; i++)
{ {
NODE_VALIDATION_ASSERT(this, m_window_shape[i] <= input_item_virtual_shape[i]) NODE_VALIDATION_CHECK(
<< "Window shape after padding is larger than the spatial dimensions (window shape: " this,
<< m_window_shape << ", virtual shape of input item: " << input_item_virtual_shape m_window_shape[i] <= input_item_virtual_shape[i],
<< ")."; "Window shape after padding is larger than the spatial dimensions (window shape: ",
m_window_shape,
", virtual shape of input item: ",
input_item_virtual_shape,
").");
} }
// //
...@@ -143,9 +176,13 @@ void op::QuantizedMaxPool::validate_and_infer_types() ...@@ -143,9 +176,13 @@ void op::QuantizedMaxPool::validate_and_infer_types()
for (size_t i = 0; i < spatial_dimension_count; i++) for (size_t i = 0; i < spatial_dimension_count; i++)
{ {
NODE_VALIDATION_ASSERT(this, m_window_movement_strides[i] != 0) NODE_VALIDATION_CHECK(this,
<< "Window movement strides dimension " << i m_window_movement_strides[i] != 0,
<< " has zero length (window movement strides: " << m_window_movement_strides << ")."; "Window movement strides dimension ",
i,
" has zero length (window movement strides: ",
m_window_movement_strides,
").");
output_item_shape.push_back(ceil_div(input_item_virtual_shape[i] - m_window_shape[i] + 1, output_item_shape.push_back(ceil_div(input_item_virtual_shape[i] - m_window_shape[i] + 1,
m_window_movement_strides[i])); m_window_movement_strides[i]));
} }
......
...@@ -30,9 +30,13 @@ op::GetOutputElement::GetOutputElement(const shared_ptr<Node>& arg, size_t n) ...@@ -30,9 +30,13 @@ op::GetOutputElement::GetOutputElement(const shared_ptr<Node>& arg, size_t n)
void op::GetOutputElement::validate_and_infer_types() void op::GetOutputElement::validate_and_infer_types()
{ {
NODE_VALIDATION_ASSERT(this, m_n < get_input_size()) NODE_VALIDATION_CHECK(this,
<< "Output at index " << m_n << " requested, but node has only " << get_input_size() m_n < get_input_size(),
<< " inputs."; "Output at index ",
m_n,
" requested, but node has only ",
get_input_size(),
" inputs.");
set_output_type(0, get_input_element_type(m_n), get_input_partial_shape(m_n)); set_output_type(0, get_input_element_type(m_n), get_input_partial_shape(m_n));
} }
......
...@@ -36,9 +36,12 @@ void op::LRN::validate_and_infer_types() ...@@ -36,9 +36,12 @@ void op::LRN::validate_and_infer_types()
const PartialShape& input_shape = get_input_partial_shape(0); const PartialShape& input_shape = get_input_partial_shape(0);
NODE_VALIDATION_ASSERT( NODE_VALIDATION_CHECK(this,
this, input_shape.rank().is_dynamic() || static_cast<size_t>(input_shape.rank()) >= 3) input_shape.rank().is_dynamic() ||
<< "Argument must have rank >= 3 (argument shape: " << input_shape << ")."; static_cast<size_t>(input_shape.rank()) >= 3,
"Argument must have rank >= 3 (argument shape: ",
input_shape,
").");
} }
shared_ptr<Node> op::LRN::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::LRN::copy_with_new_args(const NodeVector& new_args) const
......
...@@ -134,9 +134,13 @@ void op::MaxPoolBackprop::validate_and_infer_types() ...@@ -134,9 +134,13 @@ void op::MaxPoolBackprop::validate_and_infer_types()
element::Type result_et; element::Type result_et;
NODE_VALIDATION_ASSERT(this, element::Type::merge(result_et, forward_arg_et, delta_et)) NODE_VALIDATION_CHECK(this,
<< "Element types for forward argument (" << forward_arg_et << ") and delta (" << delta_et element::Type::merge(result_et, forward_arg_et, delta_et),
<< ") do not match."; "Element types for forward argument (",
forward_arg_et,
") and delta (",
delta_et,
") do not match.");
// infer_batched_forward_pooling wants CoordinateDiffs for these, while the pooling ops for // infer_batched_forward_pooling wants CoordinateDiffs for these, while the pooling ops for
// now still take Shape (no negative padding). // now still take Shape (no negative padding).
...@@ -155,9 +159,15 @@ void op::MaxPoolBackprop::validate_and_infer_types() ...@@ -155,9 +159,15 @@ void op::MaxPoolBackprop::validate_and_infer_types()
const PartialShape& delta_shape = get_input_partial_shape(1); const PartialShape& delta_shape = get_input_partial_shape(1);
NODE_VALIDATION_ASSERT(this, forward_result_shape.compatible(delta_shape)) NODE_VALIDATION_CHECK(
<< "Inferred forward output shape does not match delta shape (inferred forward output " this,
<< "shape: " << forward_result_shape << ", delta shape: " << delta_shape << ")."; forward_result_shape.compatible(delta_shape),
"Inferred forward output shape does not match delta shape (inferred forward output ",
"shape: ",
forward_result_shape,
", delta shape: ",
delta_shape,
").");
// TODO(amprocte): We may technically be able to infer some extra information from // TODO(amprocte): We may technically be able to infer some extra information from
// forward_result_shape that was not present in the forward arg shape---namely batch size and // forward_result_shape that was not present in the forward arg shape---namely batch size and
......
...@@ -34,16 +34,25 @@ void op::OneHot::validate_and_infer_types() ...@@ -34,16 +34,25 @@ void op::OneHot::validate_and_infer_types()
PartialShape arg_shape = get_input_partial_shape(0); PartialShape arg_shape = get_input_partial_shape(0);
Rank arg_rank = arg_shape.rank(); Rank arg_rank = arg_shape.rank();
NODE_VALIDATION_ASSERT(this, m_shape.rank().is_static()) NODE_VALIDATION_CHECK(
<< "Requested result shape has dynamic rank."; this, m_shape.rank().is_static(), "Requested result shape has dynamic rank.");
NODE_VALIDATION_ASSERT(this, m_one_hot_axis < static_cast<size_t>(m_shape.rank())) NODE_VALIDATION_CHECK(this,
<< "One-hot axis (" << m_one_hot_axis m_one_hot_axis < static_cast<size_t>(m_shape.rank()),
<< ") is out of bounds (requested result shape: " << m_shape << ")."; "One-hot axis (",
m_one_hot_axis,
") is out of bounds (requested result shape: ",
m_shape,
").");
NODE_VALIDATION_ASSERT(this, m_shape[m_one_hot_axis].is_static()) NODE_VALIDATION_CHECK(this,
<< "Requested result shape (" << m_shape << ") has dynamic dimension at the one-hot axis " m_shape[m_one_hot_axis].is_static(),
<< "(" << m_one_hot_axis << ")."; "Requested result shape (",
m_shape,
") has dynamic dimension at the one-hot axis ",
"(",
m_one_hot_axis,
").");
PartialShape result_shape{m_shape}; PartialShape result_shape{m_shape};
...@@ -58,9 +67,13 @@ void op::OneHot::validate_and_infer_types() ...@@ -58,9 +67,13 @@ void op::OneHot::validate_and_infer_types()
PartialShape expected_input_shape{expected_input_dims}; PartialShape expected_input_shape{expected_input_dims};
PartialShape merged_input_shape{expected_input_shape}; PartialShape merged_input_shape{expected_input_shape};
NODE_VALIDATION_ASSERT(this, PartialShape::merge_into(merged_input_shape, arg_shape)) NODE_VALIDATION_CHECK(this,
<< "Argument shape " << arg_shape << " does not match the expected shape of " PartialShape::merge_into(merged_input_shape, arg_shape),
<< expected_input_shape << "."; "Argument shape ",
arg_shape,
" does not match the expected shape of ",
expected_input_shape,
".");
std::vector<Dimension> output_dims(static_cast<size_t>(merged_input_shape.rank())); std::vector<Dimension> output_dims(static_cast<size_t>(merged_input_shape.rank()));
for (size_t i = 0; i < static_cast<size_t>(merged_input_shape.rank()); i++) for (size_t i = 0; i < static_cast<size_t>(merged_input_shape.rank()); i++)
......
...@@ -38,31 +38,49 @@ void op::Pad::validate_and_infer_types() ...@@ -38,31 +38,49 @@ void op::Pad::validate_and_infer_types()
{ {
element::Type result_et; element::Type result_et;
NODE_VALIDATION_ASSERT( NODE_VALIDATION_CHECK(
this, element::Type::merge(result_et, get_input_element_type(0), get_input_element_type(1))) this,
<< "Argument element types do not match (arg0 element type: " << get_input_element_type(0) element::Type::merge(result_et, get_input_element_type(0), get_input_element_type(1)),
<< ", arg1 element type: " << get_input_element_type(1) << ")."; "Argument element types do not match (arg0 element type: ",
get_input_element_type(0),
NODE_VALIDATION_ASSERT(this, get_input_partial_shape(1).compatible(PartialShape{})) ", arg1 element type: ",
<< "Argument for padding value is not a scalar (shape: " << get_input_partial_shape(1) get_input_element_type(1),
<< ")."; ").");
NODE_VALIDATION_CHECK(this,
get_input_partial_shape(1).compatible(PartialShape{}),
"Argument for padding value is not a scalar (shape: ",
get_input_partial_shape(1),
").");
auto arg_shape = get_input_partial_shape(0); auto arg_shape = get_input_partial_shape(0);
NODE_VALIDATION_ASSERT(this, NODE_VALIDATION_CHECK(this,
m_padding_below.size() == m_padding_above.size() && m_padding_below.size() == m_padding_above.size() &&
m_padding_below.size() == m_padding_interior.size()) m_padding_below.size() == m_padding_interior.size(),
<< "Ranks for padding below (" << m_padding_below << "), padding above (" << m_padding_above "Ranks for padding below (",
<< ") and interior padding (" << m_padding_interior << ") " m_padding_below,
<< "do not match."; "), padding above (",
m_padding_above,
") and interior padding (",
m_padding_interior,
") ",
"do not match.");
size_t implied_rank = m_padding_below.size(); size_t implied_rank = m_padding_below.size();
NODE_VALIDATION_ASSERT(this, arg_shape.rank().compatible(implied_rank)) NODE_VALIDATION_CHECK(
<< "Rank for padding below/padding above/interior padding does not match the rank of the " this,
<< "data argument (padding below: " << m_padding_below << ", " arg_shape.rank().compatible(implied_rank),
<< ", padding above: " << m_padding_above << ", interior padding: " << m_padding_interior "Rank for padding below/padding above/interior padding does not match the rank of the ",
<< ")."; "data argument (padding below: ",
m_padding_below,
", ",
", padding above: ",
m_padding_above,
", interior padding: ",
m_padding_interior,
").");
std::vector<Dimension> result_dims(implied_rank, Dimension::dynamic()); std::vector<Dimension> result_dims(implied_rank, Dimension::dynamic());
......
...@@ -44,50 +44,73 @@ void op::Quantize::validate_and_infer_types() ...@@ -44,50 +44,73 @@ void op::Quantize::validate_and_infer_types()
OFFSET OFFSET
}; };
NODE_VALIDATION_ASSERT(this, m_type.is_static()) << "Output element type must not be dynamic"; NODE_VALIDATION_CHECK(this, m_type.is_static(), "Output element type must not be dynamic");
NODE_VALIDATION_ASSERT(this, m_type.is_quantized()) << "Output element type (" << m_type NODE_VALIDATION_CHECK(
<< ") must be a quantized type"; this, m_type.is_quantized(), "Output element type (", m_type, ") must be a quantized type");
element::Type unquantized_type; element::Type unquantized_type;
NODE_VALIDATION_ASSERT(this, NODE_VALIDATION_CHECK(this,
element::Type::merge(unquantized_type, element::Type::merge(unquantized_type,
get_input_element_type(INPUT), get_input_element_type(INPUT),
get_input_element_type(SCALE))) get_input_element_type(SCALE)),
<< "Scale element type (" << get_input_element_type(SCALE) "Scale element type (",
<< ") must match input element type (" << get_input_element_type(INPUT) << ")"; get_input_element_type(SCALE),
") must match input element type (",
NODE_VALIDATION_ASSERT(this, unquantized_type.is_dynamic() || unquantized_type.is_real()) get_input_element_type(INPUT),
<< "Scale/input element type (" << unquantized_type << ") must be a floating point number"; ")");
NODE_VALIDATION_CHECK(this,
unquantized_type.is_dynamic() || unquantized_type.is_real(),
"Scale/input element type (",
unquantized_type,
") must be a floating point number");
element::Type quantized_type; element::Type quantized_type;
NODE_VALIDATION_ASSERT( NODE_VALIDATION_CHECK(
this, element::Type::merge(quantized_type, get_input_element_type(OFFSET), m_type)) this,
<< "Offset element type (" << get_input_element_type(OFFSET) element::Type::merge(quantized_type, get_input_element_type(OFFSET), m_type),
<< ") must match output element type (" << m_type << ")"; "Offset element type (",
get_input_element_type(OFFSET),
") must match output element type (",
m_type,
")");
PartialShape input_shape = get_input_partial_shape(0); PartialShape input_shape = get_input_partial_shape(0);
Dimension input_rank = input_shape.rank(); Dimension input_rank = input_shape.rank();
for (auto axis : m_axes) for (auto axis : m_axes)
{ {
NODE_VALIDATION_ASSERT(this, input_rank.is_dynamic() || axis < size_t(input_rank)) NODE_VALIDATION_CHECK(this,
<< "Quantization axis (" << axis << ") must be less than input shape rank (" input_rank.is_dynamic() || axis < size_t(input_rank),
<< input_rank << ")"; "Quantization axis (",
axis,
") must be less than input shape rank (",
input_rank,
")");
} }
PartialShape scale_offset_shape = get_input_partial_shape(SCALE); PartialShape scale_offset_shape = get_input_partial_shape(SCALE);
NODE_VALIDATION_ASSERT( NODE_VALIDATION_CHECK(
this, PartialShape::merge_into(scale_offset_shape, get_input_partial_shape(OFFSET))) this,
<< "Scale shape (" << get_input_partial_shape(SCALE) << ") and offset shape (" PartialShape::merge_into(scale_offset_shape, get_input_partial_shape(OFFSET)),
<< get_input_partial_shape(OFFSET) << ") must match"; "Scale shape (",
get_input_partial_shape(SCALE),
NODE_VALIDATION_ASSERT(this, scale_offset_shape.rank().compatible(m_axes.size())) ") and offset shape (",
<< "Scale/offset rank (" << scale_offset_shape.rank() << ") does not match the number of " get_input_partial_shape(OFFSET),
<< "quantization axes (" << m_axes.size() << ")"; ") must match");
NODE_VALIDATION_CHECK(this,
scale_offset_shape.rank().compatible(m_axes.size()),
"Scale/offset rank (",
scale_offset_shape.rank(),
") does not match the number of ",
"quantization axes (",
m_axes.size(),
")");
set_output_size(1); set_output_size(1);
...@@ -110,10 +133,16 @@ void op::Quantize::validate_and_infer_types() ...@@ -110,10 +133,16 @@ void op::Quantize::validate_and_infer_types()
} }
PartialShape result_shape = input_shape; PartialShape result_shape = input_shape;
NODE_VALIDATION_ASSERT( NODE_VALIDATION_CHECK(
this, PartialShape::merge_into(result_shape, PartialShape{injected_scale_offset_dims})) this,
<< "Scale/offset shape (" << scale_offset_shape << ") must match input shape (" PartialShape::merge_into(result_shape, PartialShape{injected_scale_offset_dims}),
<< input_shape << ") at the quantization axes (" << m_axes << ")"; "Scale/offset shape (",
scale_offset_shape,
") must match input shape (",
input_shape,
") at the quantization axes (",
m_axes,
")");
set_output_type(0, quantized_type, result_shape); set_output_type(0, quantized_type, result_shape);
} }
else else
......
...@@ -59,51 +59,85 @@ void op::ReplaceSlice::validate_and_infer_types() ...@@ -59,51 +59,85 @@ void op::ReplaceSlice::validate_and_infer_types()
const PartialShape& arg1_shape = get_input_partial_shape(1); const PartialShape& arg1_shape = get_input_partial_shape(1);
Dimension merged_args_rank; Dimension merged_args_rank;
NODE_VALIDATION_ASSERT(this, NODE_VALIDATION_CHECK(this,
Dimension::merge(merged_args_rank, arg0_shape.rank(), arg1_shape.rank())) Dimension::merge(merged_args_rank, arg0_shape.rank(), arg1_shape.rank()),
<< "Argument ranks do not match (arg0 shape: " << arg0_shape "Argument ranks do not match (arg0 shape: ",
<< ", arg1 shape: " << arg1_shape << ")."; arg0_shape,
", arg1 shape: ",
arg1_shape,
").");
element::Type arg0_et = get_input_element_type(0); element::Type arg0_et = get_input_element_type(0);
element::Type arg1_et = get_input_element_type(1); element::Type arg1_et = get_input_element_type(1);
element::Type merged_args_et; element::Type merged_args_et;
NODE_VALIDATION_ASSERT(this, element::Type::merge(merged_args_et, arg0_et, arg1_et)) NODE_VALIDATION_CHECK(this,
<< "Argument element types do not match (arg0 element type: " << arg0_et element::Type::merge(merged_args_et, arg0_et, arg1_et),
<< ", arg1 element type: " << arg1_et << ")."; "Argument element types do not match (arg0 element type: ",
arg0_et,
NODE_VALIDATION_ASSERT(this, ", arg1 element type: ",
m_lower_bounds.size() == m_upper_bounds.size() && arg1_et,
m_lower_bounds.size() == m_strides.size()) ").");
<< "Ranks of lower bounds (" << m_lower_bounds << "), upper bounds (" << m_upper_bounds
<< ") and strides (" << m_strides << ") do not match."; NODE_VALIDATION_CHECK(this,
m_lower_bounds.size() == m_upper_bounds.size() &&
m_lower_bounds.size() == m_strides.size(),
"Ranks of lower bounds (",
m_lower_bounds,
"), upper bounds (",
m_upper_bounds,
") and strides (",
m_strides,
") do not match.");
size_t output_rank = m_upper_bounds.size(); size_t output_rank = m_upper_bounds.size();
for (size_t i = 0; i < output_rank; i++) for (size_t i = 0; i < output_rank; i++)
{ {
NODE_VALIDATION_ASSERT(this, m_lower_bounds[i] <= m_upper_bounds[i]) NODE_VALIDATION_CHECK(this,
<< "Lower bound for slice is greater than upper bound at axis " << i m_lower_bounds[i] <= m_upper_bounds[i],
<< " (lower bounds: " << m_lower_bounds << ", upper bounds: " << m_upper_bounds << ")."; "Lower bound for slice is greater than upper bound at axis ",
i,
NODE_VALIDATION_ASSERT(this, m_strides[i] != 0) << "Stride for slice is zero at axis " << i " (lower bounds: ",
<< " (strides: " << m_strides << ")."; m_lower_bounds,
", upper bounds: ",
m_upper_bounds,
").");
NODE_VALIDATION_CHECK(this,
m_strides[i] != 0,
"Stride for slice is zero at axis ",
i,
" (strides: ",
m_strides,
").");
} }
NODE_VALIDATION_ASSERT(this, NODE_VALIDATION_CHECK(this,
merged_args_rank.is_dynamic() || size_t(merged_args_rank) == output_rank) merged_args_rank.is_dynamic() || size_t(merged_args_rank) == output_rank,
<< "Argument ranks do not match the rank of the lower bounds (" << m_lower_bounds "Argument ranks do not match the rank of the lower bounds (",
<< "), upper bounds (" << m_upper_bounds << "), and strides (" << m_strides << ")."; m_lower_bounds,
"), upper bounds (",
m_upper_bounds,
"), and strides (",
m_strides,
").");
std::vector<Dimension> sliced_dims(output_rank); std::vector<Dimension> sliced_dims(output_rank);
for (size_t i = 0; i < output_rank; i++) for (size_t i = 0; i < output_rank; i++)
{ {
NODE_VALIDATION_ASSERT(this, NODE_VALIDATION_CHECK(this,
arg0_shape.rank().is_dynamic() || arg0_shape[i].is_dynamic() || arg0_shape.rank().is_dynamic() || arg0_shape[i].is_dynamic() ||
m_upper_bounds[i] <= size_t(arg0_shape[i])) m_upper_bounds[i] <= size_t(arg0_shape[i]),
<< "Upper bound for slice at axis " << i << " is out of range " "Upper bound for slice at axis ",
<< "(upper bounds: " << m_upper_bounds << ", argument shape: " << arg0_shape << ")."; i,
" is out of range ",
"(upper bounds: ",
m_upper_bounds,
", argument shape: ",
arg0_shape,
").");
size_t sliced_dim = m_upper_bounds[i] - m_lower_bounds[i]; size_t sliced_dim = m_upper_bounds[i] - m_lower_bounds[i];
sliced_dim = sliced_dim / m_strides[i] + ((sliced_dim % m_strides[i] == 0) ? 0 : 1); sliced_dim = sliced_dim / m_strides[i] + ((sliced_dim % m_strides[i] == 0) ? 0 : 1);
...@@ -112,9 +146,14 @@ void op::ReplaceSlice::validate_and_infer_types() ...@@ -112,9 +146,14 @@ void op::ReplaceSlice::validate_and_infer_types()
PartialShape slice_shape{sliced_dims}; PartialShape slice_shape{sliced_dims};
NODE_VALIDATION_ASSERT(this, arg1_shape.compatible(slice_shape)) NODE_VALIDATION_CHECK(this,
<< "Shape of replacement tensor (" << arg1_shape << ") does not match the slice shape " arg1_shape.compatible(slice_shape),
<< "(" << slice_shape << ")."; "Shape of replacement tensor (",
arg1_shape,
") does not match the slice shape ",
"(",
slice_shape,
").");
// Slight corner case here: if arg0 was rank-unknown, we can go ahead and set the output rank // Slight corner case here: if arg0 was rank-unknown, we can go ahead and set the output rank
// because the attribs will have given us enough info. // because the attribs will have given us enough info.
......
...@@ -41,25 +41,39 @@ void op::Reshape::validate_and_infer_types() ...@@ -41,25 +41,39 @@ void op::Reshape::validate_and_infer_types()
// Check that the input axis order is a permutation of (0,...,n-1) for some n. // Check that the input axis order is a permutation of (0,...,n-1) for some n.
for (size_t i = 0; i < m_input_order.size(); i++) for (size_t i = 0; i < m_input_order.size(); i++)
{ {
NODE_VALIDATION_ASSERT( NODE_VALIDATION_CHECK(
this, find(begin(m_input_order), end(m_input_order), i) != end(m_input_order)) this,
<< "Input axis order is not a permutation of argument's axis indices (axis order: " find(begin(m_input_order), end(m_input_order), i) != end(m_input_order),
<< m_input_order << ", argument shape: " << input_shape << ")."; "Input axis order is not a permutation of argument's axis indices (axis order: ",
m_input_order,
", argument shape: ",
input_shape,
").");
} }
// TODO(amprocte): should be possible to move around unknown dims in the input shape. // TODO(amprocte): should be possible to move around unknown dims in the input shape.
if (input_rank.is_static()) if (input_rank.is_static())
{ {
NODE_VALIDATION_ASSERT(this, m_input_order.size() == size_t(input_rank)) NODE_VALIDATION_CHECK(
<< "Input axis order is not a permutation of argument's axis indices (axis order: " this,
<< m_input_order << ", argument shape: " << input_shape << ")."; m_input_order.size() == size_t(input_rank),
"Input axis order is not a permutation of argument's axis indices (axis order: ",
m_input_order,
", argument shape: ",
input_shape,
").");
for (size_t i = 0; i < size_t(input_rank); i++) for (size_t i = 0; i < size_t(input_rank); i++)
{ {
auto it = find(begin(m_input_order), end(m_input_order), i); auto it = find(begin(m_input_order), end(m_input_order), i);
NODE_VALIDATION_ASSERT(this, it != end(m_input_order)) NODE_VALIDATION_CHECK(
<< "Input axis order is not a permutation of argument's axis indices (axis order: " this,
<< m_input_order << ", argument shape: " << input_shape << ")."; it != end(m_input_order),
"Input axis order is not a permutation of argument's axis indices (axis order: ",
m_input_order,
", argument shape: ",
input_shape,
").");
} }
// TODO(amprocte): make a partial_shape_size() analogous to shape_size(). // TODO(amprocte): make a partial_shape_size() analogous to shape_size().
...@@ -71,11 +85,16 @@ void op::Reshape::validate_and_infer_types() ...@@ -71,11 +85,16 @@ void op::Reshape::validate_and_infer_types()
if (input_shape_product.is_static()) if (input_shape_product.is_static())
{ {
NODE_VALIDATION_ASSERT(this, size_t(input_shape_product) == shape_size(m_output_shape)) NODE_VALIDATION_CHECK(
<< "Product of output shape dimensions does not match product of argument shape " this,
"dimensions " size_t(input_shape_product) == shape_size(m_output_shape),
<< "(output shape: " << m_output_shape << ", argument shape: " << input_shape "Product of output shape dimensions does not match product of argument shape "
<< ")."; "dimensions ",
"(output shape: ",
m_output_shape,
", argument shape: ",
input_shape,
").");
} }
} }
......
...@@ -32,8 +32,8 @@ op::Result::Result(const shared_ptr<Node>& arg) ...@@ -32,8 +32,8 @@ op::Result::Result(const shared_ptr<Node>& arg)
void op::Result::validate_and_infer_types() void op::Result::validate_and_infer_types()
{ {
NODE_VALIDATION_ASSERT(this, get_input_size() == 1) << "Argument has " << get_input_size() NODE_VALIDATION_CHECK(
<< " outputs (1 expected)."; this, get_input_size() == 1, "Argument has ", get_input_size(), " outputs (1 expected).");
// always borrow the placement conf even the default one // always borrow the placement conf even the default one
set_placement_index(get_argument(0)->get_placement_index()); set_placement_index(get_argument(0)->get_placement_index());
......
...@@ -40,9 +40,13 @@ void op::Reverse::validate_and_infer_types() ...@@ -40,9 +40,13 @@ void op::Reverse::validate_and_infer_types()
// Make sure all reversed axis indices are valid. // Make sure all reversed axis indices are valid.
for (size_t axis : m_reversed_axes) for (size_t axis : m_reversed_axes)
{ {
NODE_VALIDATION_ASSERT(this, axis < size_t(input_rank)) NODE_VALIDATION_CHECK(this,
<< "Reverse axis (" << axis << ") is out of bounds (argument shape: " << input_shape axis < size_t(input_rank),
<< ")."; "Reverse axis (",
axis,
") is out of bounds (argument shape: ",
input_shape,
").");
} }
} }
......
...@@ -41,20 +41,31 @@ void op::ReverseSequence::validate_and_infer_types() ...@@ -41,20 +41,31 @@ void op::ReverseSequence::validate_and_infer_types()
auto input_shape = get_input_partial_shape(0); auto input_shape = get_input_partial_shape(0);
auto input_rank = input_shape.rank(); auto input_rank = input_shape.rank();
NODE_VALIDATION_ASSERT(this, input_rank.is_dynamic() || m_batch_axis < size_t(input_rank)) NODE_VALIDATION_CHECK(this,
<< "Batch axis index (" << m_batch_axis input_rank.is_dynamic() || m_batch_axis < size_t(input_rank),
<< ") is out of bounds (argument shape: " << input_shape << ")."; "Batch axis index (",
m_batch_axis,
") is out of bounds (argument shape: ",
input_shape,
").");
NODE_VALIDATION_ASSERT(this, input_rank.is_dynamic() || m_seq_axis < size_t(input_rank)) NODE_VALIDATION_CHECK(this,
<< "Sequence axis index (" << m_seq_axis input_rank.is_dynamic() || m_seq_axis < size_t(input_rank),
<< ") is out of bounds (argument shape: " << input_shape << ")."; "Sequence axis index (",
m_seq_axis,
") is out of bounds (argument shape: ",
input_shape,
").");
auto indices_shape = get_input_partial_shape(1); auto indices_shape = get_input_partial_shape(1);
auto indices_rank = indices_shape.rank(); auto indices_rank = indices_shape.rank();
NODE_VALIDATION_ASSERT(this, indices_rank.is_dynamic() || size_t(indices_rank) == 1) NODE_VALIDATION_CHECK(
<< "Sequence indices must be a 1-dimensional tensor (sequence indices shape: " this,
<< get_input_partial_shape(1) << ")."; indices_rank.is_dynamic() || size_t(indices_rank) == 1,
"Sequence indices must be a 1-dimensional tensor (sequence indices shape: ",
get_input_partial_shape(1),
").");
PartialShape output_shape{input_shape}; PartialShape output_shape{input_shape};
...@@ -62,12 +73,19 @@ void op::ReverseSequence::validate_and_infer_types() ...@@ -62,12 +73,19 @@ void op::ReverseSequence::validate_and_infer_types()
{ {
Dimension merged_sequence_length; Dimension merged_sequence_length;
NODE_VALIDATION_ASSERT( NODE_VALIDATION_CHECK(
this, this,
Dimension::merge(merged_sequence_length, input_shape[m_batch_axis], indices_shape[0])) Dimension::merge(merged_sequence_length, input_shape[m_batch_axis], indices_shape[0]),
<< "Sequence length (" << indices_shape[0] << ") is not equal to batch axis " "Sequence length (",
<< "dimension (" << input_shape[m_batch_axis] << ") (argument shape: " << input_shape indices_shape[0],
<< ", sequence indices shape: " << indices_shape << ")."; ") is not equal to batch axis ",
"dimension (",
input_shape[m_batch_axis],
") (argument shape: ",
input_shape,
", sequence indices shape: ",
indices_shape,
").");
output_shape[m_batch_axis] = merged_sequence_length; output_shape[m_batch_axis] = merged_sequence_length;
} }
......
...@@ -36,24 +36,28 @@ op::Select::Select(const shared_ptr<Node>& arg0, ...@@ -36,24 +36,28 @@ op::Select::Select(const shared_ptr<Node>& arg0,
void op::Select::validate_and_infer_types() void op::Select::validate_and_infer_types()
{ {
NODE_VALIDATION_ASSERT(this, NODE_VALIDATION_CHECK(this,
get_input_element_type(0).is_dynamic() || get_input_element_type(0).is_dynamic() ||
get_input_element_type(0) == element::boolean) get_input_element_type(0) == element::boolean,
<< "Argument 0 does not have boolean element type (element type: " "Argument 0 does not have boolean element type (element type: ",
<< get_input_element_type(0) << ")."; get_input_element_type(0),
").");
PartialShape result_shape = get_input_partial_shape(0); PartialShape result_shape = get_input_partial_shape(0);
NODE_VALIDATION_ASSERT(this, PartialShape::merge_into(result_shape, get_input_partial_shape(1))) NODE_VALIDATION_CHECK(this,
<< "Argument shapes are inconsistent."; PartialShape::merge_into(result_shape, get_input_partial_shape(1)),
NODE_VALIDATION_ASSERT(this, PartialShape::merge_into(result_shape, get_input_partial_shape(2))) "Argument shapes are inconsistent.");
<< "Argument shapes are inconsistent."; NODE_VALIDATION_CHECK(this,
PartialShape::merge_into(result_shape, get_input_partial_shape(2)),
"Argument shapes are inconsistent.");
element::Type result_et; element::Type result_et;
NODE_VALIDATION_ASSERT( NODE_VALIDATION_CHECK(
this, element::Type::merge(result_et, get_input_element_type(1), get_input_element_type(2))) this,
<< "Argument 1 and 2 element types are inconsistent."; element::Type::merge(result_et, get_input_element_type(1), get_input_element_type(2)),
"Argument 1 and 2 element types are inconsistent.");
set_output_type(0, result_et, result_shape); set_output_type(0, result_et, result_shape);
} }
......
...@@ -51,40 +51,68 @@ void op::Slice::validate_and_infer_types() ...@@ -51,40 +51,68 @@ void op::Slice::validate_and_infer_types()
m_strides = Strides(m_lower_bounds.size(), 1); m_strides = Strides(m_lower_bounds.size(), 1);
} }
NODE_VALIDATION_ASSERT(this, NODE_VALIDATION_CHECK(this,
m_lower_bounds.size() == m_upper_bounds.size() && m_lower_bounds.size() == m_upper_bounds.size() &&
m_lower_bounds.size() == m_strides.size()) m_lower_bounds.size() == m_strides.size(),
<< "Ranks of lower bounds (" << m_lower_bounds << "), upper bounds (" << m_upper_bounds "Ranks of lower bounds (",
<< ") and strides (" << m_strides << ") do not match."; m_lower_bounds,
"), upper bounds (",
m_upper_bounds,
") and strides (",
m_strides,
") do not match.");
size_t output_rank = m_upper_bounds.size(); size_t output_rank = m_upper_bounds.size();
for (size_t i = 0; i < output_rank; i++) for (size_t i = 0; i < output_rank; i++)
{ {
NODE_VALIDATION_ASSERT(this, m_lower_bounds[i] <= m_upper_bounds[i]) NODE_VALIDATION_CHECK(this,
<< "Lower bound for slice is greater than upper bound at axis " << i m_lower_bounds[i] <= m_upper_bounds[i],
<< " (lower bounds: " << m_lower_bounds << ", upper bounds: " << m_upper_bounds << ")."; "Lower bound for slice is greater than upper bound at axis ",
i,
NODE_VALIDATION_ASSERT(this, m_strides[i] != 0) << "Stride for slice is zero at axis " << i " (lower bounds: ",
<< " (strides: " << m_strides << ")."; m_lower_bounds,
", upper bounds: ",
m_upper_bounds,
").");
NODE_VALIDATION_CHECK(this,
m_strides[i] != 0,
"Stride for slice is zero at axis ",
i,
" (strides: ",
m_strides,
").");
} }
const PartialShape& input_shape = get_input_partial_shape(0); const PartialShape& input_shape = get_input_partial_shape(0);
Dimension input_rank = input_shape.rank(); Dimension input_rank = input_shape.rank();
NODE_VALIDATION_ASSERT(this, input_rank.is_dynamic() || size_t(input_rank) == output_rank) NODE_VALIDATION_CHECK(this,
<< "Input rank does not match the rank of the lower bounds (" << m_lower_bounds input_rank.is_dynamic() || size_t(input_rank) == output_rank,
<< "), upper bounds (" << m_upper_bounds << "), and strides (" << m_strides << ")."; "Input rank does not match the rank of the lower bounds (",
m_lower_bounds,
"), upper bounds (",
m_upper_bounds,
"), and strides (",
m_strides,
").");
std::vector<Dimension> result_dims(output_rank); std::vector<Dimension> result_dims(output_rank);
for (size_t i = 0; i < output_rank; i++) for (size_t i = 0; i < output_rank; i++)
{ {
NODE_VALIDATION_ASSERT(this, NODE_VALIDATION_CHECK(this,
input_rank.is_dynamic() || input_shape[i].is_dynamic() || input_rank.is_dynamic() || input_shape[i].is_dynamic() ||
m_upper_bounds[i] <= size_t(input_shape[i])) m_upper_bounds[i] <= size_t(input_shape[i]),
<< "Upper bound for slice at axis " << i << " is out of range " "Upper bound for slice at axis ",
<< "(upper bounds: " << m_upper_bounds << ", argument shape: " << input_shape << ")."; i,
" is out of range ",
"(upper bounds: ",
m_upper_bounds,
", argument shape: ",
input_shape,
").");
size_t result_axis_size = m_upper_bounds[i] - m_lower_bounds[i]; size_t result_axis_size = m_upper_bounds[i] - m_lower_bounds[i];
result_axis_size = result_axis_size =
......
...@@ -37,9 +37,13 @@ op::Softmax::Softmax(const shared_ptr<Node>& arg, const AxisSet& axes) ...@@ -37,9 +37,13 @@ op::Softmax::Softmax(const shared_ptr<Node>& arg, const AxisSet& axes)
for (auto axis : m_axes) for (auto axis : m_axes)
{ {
NODE_VALIDATION_ASSERT(this, axis < get_shape().size()) NODE_VALIDATION_CHECK(this,
<< "Reduction axis (" << axis << ") is out of bounds (argument shape: " << get_shape() axis < get_shape().size(),
<< ")."; "Reduction axis (",
axis,
") is out of bounds (argument shape: ",
get_shape(),
").");
} }
// empty axes == all axes // empty axes == all axes
......
...@@ -43,26 +43,36 @@ void op::TopK::validate_and_infer_types() ...@@ -43,26 +43,36 @@ void op::TopK::validate_and_infer_types()
Rank input_rank = input_shape.rank(); Rank input_rank = input_shape.rank();
element::Type input_element_type = get_input_element_type(0); element::Type input_element_type = get_input_element_type(0);
NODE_VALIDATION_ASSERT(this, !m_index_element_type.is_dynamic()) NODE_VALIDATION_CHECK(
<< "Argument element type must not be dynamic."; this, !m_index_element_type.is_dynamic(), "Argument element type must not be dynamic.");
NODE_VALIDATION_ASSERT( NODE_VALIDATION_CHECK(this,
this, m_index_element_type == element::i32 || m_index_element_type == element::i64) m_index_element_type == element::i32 ||
<< "Argument element type must be i64 or i32 (got " << m_index_element_type << ")."; m_index_element_type == element::i64,
"Argument element type must be i64 or i32 (got ",
m_index_element_type,
").");
NODE_VALIDATION_ASSERT(this, input_rank.is_dynamic() || static_cast<size_t>(input_rank) > 0) NODE_VALIDATION_CHECK(this,
<< "Argument rank must be greater than 0."; input_rank.is_dynamic() || static_cast<size_t>(input_rank) > 0,
"Argument rank must be greater than 0.");
NODE_VALIDATION_ASSERT( NODE_VALIDATION_CHECK(this,
this, input_rank.is_dynamic() || m_top_k_axis < static_cast<size_t>(input_rank)) input_rank.is_dynamic() || m_top_k_axis < static_cast<size_t>(input_rank),
<< "TopK axis (" << m_top_k_axis << ") is out of bounds."; "TopK axis (",
m_top_k_axis,
") is out of bounds.");
NODE_VALIDATION_ASSERT(this, NODE_VALIDATION_CHECK(this,
input_rank.is_dynamic() || input_shape[m_top_k_axis].is_dynamic() || input_rank.is_dynamic() || input_shape[m_top_k_axis].is_dynamic() ||
m_k <= static_cast<size_t>(input_shape[m_top_k_axis])) m_k <= static_cast<size_t>(input_shape[m_top_k_axis]),
<< "K (" << m_k << ") exceeds the dimension (" "K (",
<< (input_rank.is_static() ? input_shape[m_top_k_axis] : 0) << ") of the TopK axis (axis " m_k,
<< m_top_k_axis << ")."; ") exceeds the dimension (",
(input_rank.is_static() ? input_shape[m_top_k_axis] : 0),
") of the TopK axis (axis ",
m_top_k_axis,
").");
PartialShape output_shape{input_shape}; PartialShape output_shape{input_shape};
......
...@@ -40,10 +40,16 @@ void op::util::ArithmeticReduction::validate_and_infer_types() ...@@ -40,10 +40,16 @@ void op::util::ArithmeticReduction::validate_and_infer_types()
for (auto axis : m_reduction_axes) for (auto axis : m_reduction_axes)
{ {
NODE_VALIDATION_ASSERT(this, axis < size_t(input_rank)) NODE_VALIDATION_CHECK(this,
<< "Reduction axis (" << axis << ") is out of bounds " axis < size_t(input_rank),
<< "(argument shape: " << input_shape << ", reduction axes: " << m_reduction_axes "Reduction axis (",
<< ")"; axis,
") is out of bounds ",
"(argument shape: ",
input_shape,
", reduction axes: ",
m_reduction_axes,
")");
} }
for (size_t i = 0; i < size_t(input_rank); i++) for (size_t i = 0; i < size_t(input_rank); i++)
......
...@@ -37,13 +37,18 @@ void op::util::IndexReduction::validate_and_infer_types() ...@@ -37,13 +37,18 @@ void op::util::IndexReduction::validate_and_infer_types()
const PartialShape& arg_shape = get_input_partial_shape(0); const PartialShape& arg_shape = get_input_partial_shape(0);
Rank rank = arg_shape.rank(); Rank rank = arg_shape.rank();
NODE_VALIDATION_ASSERT(this, rank.is_dynamic() || size_t(rank) >= 1) NODE_VALIDATION_CHECK(this, rank.is_dynamic() || size_t(rank) >= 1, "Argument rank is zero.");
<< "Argument rank is zero."; NODE_VALIDATION_CHECK(this,
NODE_VALIDATION_ASSERT(this, rank.is_dynamic() || m_axis < size_t(rank)) rank.is_dynamic() || m_axis < size_t(rank),
<< "Reduction axis (" << m_axis << ") is not less than argument rank (" << rank << ")."; "Reduction axis (",
NODE_VALIDATION_ASSERT( m_axis,
this, m_index_element_type == element::i32 || m_index_element_type == element::i64) ") is not less than argument rank (",
<< "Index element is neither i64 or i32."; rank,
").");
NODE_VALIDATION_CHECK(this,
m_index_element_type == element::i32 ||
m_index_element_type == element::i64,
"Index element is neither i64 or i32.");
PartialShape output_shape{PartialShape::dynamic()}; PartialShape output_shape{PartialShape::dynamic()};
......
...@@ -40,10 +40,16 @@ void op::util::LogicalReduction::validate_and_infer_types() ...@@ -40,10 +40,16 @@ void op::util::LogicalReduction::validate_and_infer_types()
for (auto axis : m_reduction_axes) for (auto axis : m_reduction_axes)
{ {
NODE_VALIDATION_ASSERT(this, axis < size_t(input_rank)) NODE_VALIDATION_CHECK(this,
<< "Reduction axis (" << axis << ") is out of bounds " axis < size_t(input_rank),
<< "(argument shape: " << input_shape << ", reduction axes: " << m_reduction_axes "Reduction axis (",
<< ")"; axis,
") is out of bounds ",
"(argument shape: ",
input_shape,
", reduction axes: ",
m_reduction_axes,
")");
} }
for (size_t i = 0; i < size_t(input_rank); i++) for (size_t i = 0; i < size_t(input_rank); i++)
...@@ -57,8 +63,9 @@ void op::util::LogicalReduction::validate_and_infer_types() ...@@ -57,8 +63,9 @@ void op::util::LogicalReduction::validate_and_infer_types()
result_shape = PartialShape(dims); result_shape = PartialShape(dims);
} }
NODE_VALIDATION_ASSERT(this, get_input_element_type(0).compatible(element::boolean)) NODE_VALIDATION_CHECK(this,
<< "Input element type must be boolean."; get_input_element_type(0).compatible(element::boolean),
"Input element type must be boolean.");
set_output_type(0, element::boolean, result_shape); set_output_type(0, element::boolean, result_shape);
} }
...@@ -29,9 +29,14 @@ void op::util::validate_conv_shapes(const Node* node, ...@@ -29,9 +29,14 @@ void op::util::validate_conv_shapes(const Node* node,
const Shape& data_shape, const Shape& data_shape,
const Shape& filters_shape) const Shape& filters_shape)
{ {
NODE_VALIDATION_ASSERT(node, data_shape[1] == filters_shape[1]) NODE_VALIDATION_CHECK(
<< "Number of channels for data and filters do not match (data num channels: " node,
<< data_shape[1] << ", filters num channels: " << filters_shape[1] << ")."; data_shape[1] == filters_shape[1],
"Number of channels for data and filters do not match (data num channels: ",
data_shape[1],
", filters num channels: ",
filters_shape[1],
").");
} }
op::ConvolutionAdd::ConvolutionAdd(const std::shared_ptr<op::Convolution>& conv, op::ConvolutionAdd::ConvolutionAdd(const std::shared_ptr<op::Convolution>& conv,
...@@ -79,9 +84,14 @@ op::ConvolutionAdd::ConvolutionAdd(const std::shared_ptr<Node>& data_batch, ...@@ -79,9 +84,14 @@ op::ConvolutionAdd::ConvolutionAdd(const std::shared_ptr<Node>& data_batch,
// //
// Make sure data batch and filter element types match. // Make sure data batch and filter element types match.
// //
NODE_VALIDATION_ASSERT(this, data_batch_et == filters_et) NODE_VALIDATION_CHECK(
<< "Element types for data_batch and filters do not match (data batch element type: " this,
<< data_batch_et << ", filters element type: " << filters_et << ")."; data_batch_et == filters_et,
"Element types for data_batch and filters do not match (data batch element type: ",
data_batch_et,
", filters element type: ",
filters_et,
").");
util::validate_conv_shapes(this, data_batch_shape, filters_shape); util::validate_conv_shapes(this, data_batch_shape, filters_shape);
set_output_type(0, set_output_type(0,
...@@ -105,8 +115,11 @@ op::ConvolutionAdd::ConvolutionAdd(const std::shared_ptr<Node>& data_batch, ...@@ -105,8 +115,11 @@ op::ConvolutionAdd::ConvolutionAdd(const std::shared_ptr<Node>& data_batch,
std::shared_ptr<Node> op::ConvolutionAdd::copy_with_new_args(const NodeVector& new_args) const std::shared_ptr<Node> op::ConvolutionAdd::copy_with_new_args(const NodeVector& new_args) const
{ {
NODE_VALIDATION_ASSERT(this, new_args.size() == 3) NODE_VALIDATION_CHECK(this,
<< "New arg size is not 3 (new args size: " << new_args.size() << ")."; new_args.size() == 3,
"New arg size is not 3 (new args size: ",
new_args.size(),
").");
return std::shared_ptr<Node>(new ConvolutionAdd(new_args.at(0), return std::shared_ptr<Node>(new ConvolutionAdd(new_args.at(0),
new_args.at(1), new_args.at(1),
......
...@@ -57,51 +57,85 @@ void op::UpdateSlice::validate_and_infer_types() ...@@ -57,51 +57,85 @@ void op::UpdateSlice::validate_and_infer_types()
const PartialShape& arg1_shape = get_input_partial_shape(1); const PartialShape& arg1_shape = get_input_partial_shape(1);
Dimension merged_args_rank; Dimension merged_args_rank;
NODE_VALIDATION_ASSERT(this, NODE_VALIDATION_CHECK(this,
Dimension::merge(merged_args_rank, arg0_shape.rank(), arg1_shape.rank())) Dimension::merge(merged_args_rank, arg0_shape.rank(), arg1_shape.rank()),
<< "Argument ranks do not match (arg0 shape: " << arg0_shape "Argument ranks do not match (arg0 shape: ",
<< ", arg1 shape: " << arg1_shape << ")."; arg0_shape,
", arg1 shape: ",
arg1_shape,
").");
element::Type arg0_et = get_input_element_type(0); element::Type arg0_et = get_input_element_type(0);
element::Type arg1_et = get_input_element_type(1); element::Type arg1_et = get_input_element_type(1);
element::Type merged_args_et; element::Type merged_args_et;
NODE_VALIDATION_ASSERT(this, element::Type::merge(merged_args_et, arg0_et, arg1_et)) NODE_VALIDATION_CHECK(this,
<< "Argument element types do not match (arg0 element type: " << arg0_et element::Type::merge(merged_args_et, arg0_et, arg1_et),
<< ", arg1 element type: " << arg1_et << ")."; "Argument element types do not match (arg0 element type: ",
arg0_et,
NODE_VALIDATION_ASSERT(this, ", arg1 element type: ",
m_lower_bounds.size() == m_upper_bounds.size() && arg1_et,
m_lower_bounds.size() == m_strides.size()) ").");
<< "Ranks of lower bounds (" << m_lower_bounds << "), upper bounds (" << m_upper_bounds
<< ") and strides (" << m_strides << ") do not match."; NODE_VALIDATION_CHECK(this,
m_lower_bounds.size() == m_upper_bounds.size() &&
m_lower_bounds.size() == m_strides.size(),
"Ranks of lower bounds (",
m_lower_bounds,
"), upper bounds (",
m_upper_bounds,
") and strides (",
m_strides,
") do not match.");
size_t output_rank = m_upper_bounds.size(); size_t output_rank = m_upper_bounds.size();
for (size_t i = 0; i < output_rank; i++) for (size_t i = 0; i < output_rank; i++)
{ {
NODE_VALIDATION_ASSERT(this, m_lower_bounds[i] <= m_upper_bounds[i]) NODE_VALIDATION_CHECK(this,
<< "Lower bound for slice is greater than upper bound at axis " << i m_lower_bounds[i] <= m_upper_bounds[i],
<< " (lower bounds: " << m_lower_bounds << ", upper bounds: " << m_upper_bounds << ")."; "Lower bound for slice is greater than upper bound at axis ",
i,
NODE_VALIDATION_ASSERT(this, m_strides[i] != 0) << "Stride for slice is zero at axis " << i " (lower bounds: ",
<< " (strides: " << m_strides << ")."; m_lower_bounds,
", upper bounds: ",
m_upper_bounds,
").");
NODE_VALIDATION_CHECK(this,
m_strides[i] != 0,
"Stride for slice is zero at axis ",
i,
" (strides: ",
m_strides,
").");
} }
NODE_VALIDATION_ASSERT(this, NODE_VALIDATION_CHECK(this,
merged_args_rank.is_dynamic() || size_t(merged_args_rank) == output_rank) merged_args_rank.is_dynamic() || size_t(merged_args_rank) == output_rank,
<< "Argument ranks do not match the rank of the lower bounds (" << m_lower_bounds "Argument ranks do not match the rank of the lower bounds (",
<< "), upper bounds (" << m_upper_bounds << "), and strides (" << m_strides << ")."; m_lower_bounds,
"), upper bounds (",
m_upper_bounds,
"), and strides (",
m_strides,
").");
std::vector<Dimension> sliced_dims(output_rank); std::vector<Dimension> sliced_dims(output_rank);
for (size_t i = 0; i < output_rank; i++) for (size_t i = 0; i < output_rank; i++)
{ {
NODE_VALIDATION_ASSERT(this, NODE_VALIDATION_CHECK(this,
arg0_shape.rank().is_dynamic() || arg0_shape[i].is_dynamic() || arg0_shape.rank().is_dynamic() || arg0_shape[i].is_dynamic() ||
m_upper_bounds[i] <= size_t(arg0_shape[i])) m_upper_bounds[i] <= size_t(arg0_shape[i]),
<< "Upper bound for slice at axis " << i << " is out of range " "Upper bound for slice at axis ",
<< "(upper bounds: " << m_upper_bounds << ", argument shape: " << arg0_shape << ")."; i,
" is out of range ",
"(upper bounds: ",
m_upper_bounds,
", argument shape: ",
arg0_shape,
").");
size_t sliced_dim = m_upper_bounds[i] - m_lower_bounds[i]; size_t sliced_dim = m_upper_bounds[i] - m_lower_bounds[i];
sliced_dim = sliced_dim / m_strides[i] + ((sliced_dim % m_strides[i] == 0) ? 0 : 1); sliced_dim = sliced_dim / m_strides[i] + ((sliced_dim % m_strides[i] == 0) ? 0 : 1);
...@@ -110,9 +144,14 @@ void op::UpdateSlice::validate_and_infer_types() ...@@ -110,9 +144,14 @@ void op::UpdateSlice::validate_and_infer_types()
PartialShape slice_shape{sliced_dims}; PartialShape slice_shape{sliced_dims};
NODE_VALIDATION_ASSERT(this, arg1_shape.compatible(slice_shape)) NODE_VALIDATION_CHECK(this,
<< "Shape of replacement tensor (" << arg1_shape << ") does not match the slice shape " arg1_shape.compatible(slice_shape),
<< "(" << slice_shape << ")."; "Shape of replacement tensor (",
arg1_shape,
") does not match the slice shape ",
"(",
slice_shape,
").");
// Slight corner case here: if arg0 was rank-unknown, we can go ahead and set the output rank // Slight corner case here: if arg0 was rank-unknown, we can go ahead and set the output rank
// because the attribs will have given us enough info. // because the attribs will have given us enough info.
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment