Commit 066037c2 authored by Adam Procter, committed by Scott Cyphers

Switch everything from NODE_VALIDATION_ASSERT to NODE_VALIDATION_CHECK (#2546)

parent dd23b0cb
@@ -275,19 +275,6 @@ namespace ngraph
size_t m_placement_index = placement_invalid;
};
class NodeValidationError : public AssertionFailure
{
public:
NodeValidationError(std::string what)
: AssertionFailure(what)
{
}
NodeValidationError(const char* what)
: AssertionFailure(what)
{
}
};
class NodeValidationFailure : public CheckFailure
{
public:
@@ -321,12 +308,5 @@ namespace ngraph
void check_new_args_count(const Node* node, const NodeVector& new_args);
} // namespace ngraph
#define NODE_VALIDATION_ASSERT(node, cond) \
NGRAPH_ASSERT_STREAM_WITH_LOC( \
::ngraph::NodeValidationError, cond, ::ngraph::node_validation_assertion_string(node))
#define NODE_VALIDATION_FAIL(node) \
NGRAPH_FAIL_STREAM_WITH_LOC(::ngraph::NodeValidationError, \
::ngraph::node_validation_assertion_string(node))
#define NODE_VALIDATION_CHECK(node, cond, ...) \
NGRAPH_CHECK(::NodeValidationFailure, (node), (cond), __VA_ARGS__)
NGRAPH_CHECK(::ngraph::NodeValidationFailure, (node), (cond), __VA_ARGS__)
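A minimal sketch of the call-site pattern this commit applies throughout (the condition and message below are hypothetical, not taken from the diff):

    // Before: stream-style, with the message assembled via operator<<;
    // failures threw ::ngraph::NodeValidationError (an AssertionFailure).
    NODE_VALIDATION_ASSERT(this, rank == 2)
        << "Expected rank 2 (got " << rank << ").";

    // After: the message parts become trailing variadic arguments, which the
    // macro stringifies and concatenates when the condition is false;
    // failures throw ::ngraph::NodeValidationFailure (a CheckFailure), per
    // the header change above.
    NODE_VALIDATION_CHECK(this, rank == 2, "Expected rank 2 (got ", rank, ").");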
@@ -27,12 +27,13 @@ op::AllReduce::AllReduce(const shared_ptr<Node>& arg)
void op::AllReduce::validate_and_infer_types()
{
NODE_VALIDATION_ASSERT(this,
get_input_element_type(0).is_dynamic() ||
get_input_element_type(0) == element::f32 ||
get_input_element_type(0) == element::f64)
<< "Only element types f32 and f64 are supported (argument element type: "
<< get_input_element_type(0) << ").";
NODE_VALIDATION_CHECK(this,
get_input_element_type(0).is_dynamic() ||
get_input_element_type(0) == element::f32 ||
get_input_element_type(0) == element::f64,
"Only element types f32 and f64 are supported (argument element type: ",
get_input_element_type(0),
").");
set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
}
@@ -132,9 +132,15 @@ void op::AvgPoolBackprop::validate_and_infer_types()
const PartialShape& delta_shape = get_input_partial_shape(0);
NODE_VALIDATION_ASSERT(this, forward_result_shape.compatible(delta_shape))
<< "Inferred forward output shape does not match delta shape (inferred forward output "
<< "shape: " << forward_result_shape << ", delta shape: " << delta_shape << ").";
NODE_VALIDATION_CHECK(
this,
forward_result_shape.compatible(delta_shape),
"Inferred forward output shape does not match delta shape (inferred forward output ",
"shape: ",
forward_result_shape,
", delta shape: ",
delta_shape,
").");
// TODO(amprocte): Once m_forward_arg_shape is allowed to be dynamic, we may technically be
// able to infer some extra information from forward_result_shape that was not present in the
@@ -205,21 +205,26 @@ void ngraph::op::BatchNormTrainingBackprop::validate_and_infer_types()
{
PartialShape input_and_delta_shape{get_input_partial_shape(INPUT_DATA)};
NODE_VALIDATION_ASSERT(
this, PartialShape::merge_into(input_and_delta_shape, get_input_partial_shape(INPUT_DELTA)))
<< "Shape of delta does not match the shape of the input data (input data shape: "
<< get_input_partial_shape(INPUT_DATA)
<< ", delta shape: " << get_input_partial_shape(INPUT_DELTA) << ").";
NODE_VALIDATION_CHECK(
this,
PartialShape::merge_into(input_and_delta_shape, get_input_partial_shape(INPUT_DELTA)),
"Shape of delta does not match the shape of the input data (input data shape: ",
get_input_partial_shape(INPUT_DATA),
", delta shape: ",
get_input_partial_shape(INPUT_DELTA),
").");
element::Type input_and_delta_et;
NODE_VALIDATION_ASSERT(this,
element::Type::merge(input_and_delta_et,
get_input_element_type(INPUT_DATA),
get_input_element_type(INPUT_DELTA)))
<< "Element type for input (" << get_input_element_type(INPUT_DATA)
<< ") does not match element type for delta (" << get_input_element_type(INPUT_DATA)
<< ").";
NODE_VALIDATION_CHECK(this,
element::Type::merge(input_and_delta_et,
get_input_element_type(INPUT_DATA),
get_input_element_type(INPUT_DELTA)),
"Element type for input (",
get_input_element_type(INPUT_DATA),
") does not match element type for delta (",
get_input_element_type(INPUT_DELTA),
").");
element::Type result_et;
PartialShape result_batch_shape;
@@ -44,9 +44,16 @@ void op::Broadcast::validate_and_infer_types()
for (auto axis : m_broadcast_axes)
{
NODE_VALIDATION_ASSERT(this, axis < m_shape.size())
<< "Broadcast axis index (" << axis << ") exceeds specified output shape rank "
<< "(broadcast axes: " << m_broadcast_axes << ", output shape: " << m_shape << ").";
NODE_VALIDATION_CHECK(this,
axis < m_shape.size(),
"Broadcast axis index (",
axis,
") exceeds specified output shape rank ",
"(broadcast axes: ",
m_broadcast_axes,
", output shape: ",
m_shape,
").");
}
Shape required_input_shape = m_shape;
@@ -59,10 +66,17 @@ void op::Broadcast::validate_and_infer_types()
// There are two things that can go wrong, which are being picked up in
// one fell swoop by this check: either the number of broadcast axes is not
// enough, or there is a mismatch with one of the pre-broadcast axis lengths.
NODE_VALIDATION_ASSERT(this, get_input_partial_shape(0).compatible(required_input_shape))
<< "Broadcast argument shape, specified output shape, and axes are incompatible "
<< "(argument shape: " << get_input_partial_shape(0) << ", output shape: " << m_shape
<< ", broadcast axes: " << m_broadcast_axes << ").";
NODE_VALIDATION_CHECK(
this,
get_input_partial_shape(0).compatible(required_input_shape),
"Broadcast argument shape, specified output shape, and axes are incompatible ",
"(argument shape: ",
get_input_partial_shape(0),
", output shape: ",
m_shape,
", broadcast axes: ",
m_broadcast_axes,
").");
set_output_type(0, get_input_element_type(0), m_shape);
}
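To make the checks above concrete, a hedged sketch using the classic nGraph builder API (hypothetical shapes, not part of the diff):

    // Valid: broadcasting shape {3} to {2,3} along new axis 0. Erasing the
    // broadcast axes from the output shape {2,3} leaves required_input_shape
    // {3}, which is compatible with the argument's shape.
    auto arg = std::make_shared<op::Parameter>(element::f32, Shape{3});
    auto bc = std::make_shared<op::Broadcast>(arg, Shape{2, 3}, AxisSet{0});

    // Invalid: an argument of shape {4} is incompatible with the required
    // pre-broadcast shape {3}, so validation raises NodeValidationFailure.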
@@ -32,7 +32,7 @@ op::Concat::Concat(const NodeVector& args, size_t concatenation_axis)
void op::Concat::validate_and_infer_types()
{
NODE_VALIDATION_ASSERT(this, m_inputs.size() >= 1) << "At least one argument required.";
NODE_VALIDATION_CHECK(this, m_inputs.size() >= 1, "At least one argument required.");
PartialShape inputs_shape_scheme{PartialShape::dynamic()};
element::Type inputs_et{element::dynamic};
@@ -44,22 +44,32 @@ void op::Concat::validate_and_infer_types()
Dimension this_input_rank = this_input_shape.rank();
if (this_input_rank.is_static())
{
NODE_VALIDATION_ASSERT(this, m_concatenation_axis < size_t(this_input_rank))
<< "Concatenation axis (" << m_concatenation_axis << ") is out of bounds for "
<< "argument " << i << ", which has shape " << this_input_shape << ".";
NODE_VALIDATION_CHECK(this,
m_concatenation_axis < size_t(this_input_rank),
"Concatenation axis (",
m_concatenation_axis,
") is out of bounds for ",
"argument ",
i,
", which has shape ",
this_input_shape,
".");
concatenation_axis_output_dim += this_input_shape[m_concatenation_axis];
this_input_shape[m_concatenation_axis] = Dimension::dynamic();
NODE_VALIDATION_ASSERT(this,
PartialShape::merge_into(inputs_shape_scheme, this_input_shape))
<< "Argument shapes are inconsistent; they must have the same rank, and must have "
<< "equal dimension everywhere except on the concatenation axis (axis "
<< m_concatenation_axis << ").";
NODE_VALIDATION_ASSERT(
this, element::Type::merge(inputs_et, inputs_et, get_input_element_type(i)))
<< "Argument element types are inconsistent.";
NODE_VALIDATION_CHECK(
this,
PartialShape::merge_into(inputs_shape_scheme, this_input_shape),
"Argument shapes are inconsistent; they must have the same rank, and must have ",
"equal dimension everywhere except on the concatenation axis (axis ",
m_concatenation_axis,
").");
NODE_VALIDATION_CHECK(
this,
element::Type::merge(inputs_et, inputs_et, get_input_element_type(i)),
"Argument element types are inconsistent.");
}
else
{
@@ -47,11 +47,17 @@ namespace ngraph
, m_data(ngraph::aligned_alloc(m_element_type.size(),
shape_size(m_shape) * m_element_type.size()))
{
NODE_VALIDATION_ASSERT(this,
values.size() == 1 || values.size() == shape_size(m_shape))
<< "Did not get the expected number of literals for a constant of shape "
<< m_shape << " (got " << values.size() << ", expected "
<< (shape_size(m_shape) == 1 ? "" : "1 or ") << shape_size(m_shape) << ").";
NODE_VALIDATION_CHECK(
this,
values.size() == 1 || values.size() == shape_size(m_shape),
"Did not get the expected number of literals for a constant of shape ",
m_shape,
" (got ",
values.size(),
", expected ",
(shape_size(m_shape) == 1 ? "" : "1 or "),
shape_size(m_shape),
").");
if (values.size() == 1)
{
@@ -77,10 +83,16 @@ namespace ngraph
, m_data(ngraph::aligned_alloc(m_element_type.size(),
shape_size(m_shape) * m_element_type.size()))
{
NODE_VALIDATION_ASSERT(this, values.size() == shape_size(m_shape))
<< "Did not get the expected number of literals for a constant of shape "
<< m_shape << " (got " << values.size() << ", expected " << shape_size(m_shape)
<< ".";
NODE_VALIDATION_CHECK(
this,
values.size() == shape_size(m_shape),
"Did not get the expected number of literals for a constant of shape ",
m_shape,
" (got ",
values.size(),
", expected ",
shape_size(m_shape),
".");
std::vector<double> dvalues = parse_string<double>(values);
write_values(dvalues);
@@ -42,50 +42,73 @@ void op::Dequantize::validate_and_infer_types()
OFFSET
};
NODE_VALIDATION_ASSERT(this, m_type.is_static()) << "Output element type must not be dynamic";
NODE_VALIDATION_CHECK(this, m_type.is_static(), "Output element type must not be dynamic");
NODE_VALIDATION_ASSERT(this, m_type.is_real()) << "Output element type (" << m_type
<< ") must be a floating point type";
NODE_VALIDATION_CHECK(
this, m_type.is_real(), "Output element type (", m_type, ") must be a floating point type");
element::Type quantized_type;
NODE_VALIDATION_ASSERT(this,
element::Type::merge(quantized_type,
get_input_element_type(INPUT),
get_input_element_type(OFFSET)))
<< "Offset element type (" << get_input_element_type(OFFSET)
<< ") must match input element type (" << get_input_element_type(INPUT) << ")";
NODE_VALIDATION_ASSERT(this, quantized_type.is_dynamic() || quantized_type.is_quantized())
<< "Offset/input element type (" << quantized_type << ") must be a quantized type";
NODE_VALIDATION_CHECK(this,
element::Type::merge(quantized_type,
get_input_element_type(INPUT),
get_input_element_type(OFFSET)),
"Offset element type (",
get_input_element_type(OFFSET),
") must match input element type (",
get_input_element_type(INPUT),
")");
NODE_VALIDATION_CHECK(this,
quantized_type.is_dynamic() || quantized_type.is_quantized(),
"Offset/input element type (",
quantized_type,
") must be a quantized type");
element::Type unquantized_type;
NODE_VALIDATION_ASSERT(
this, element::Type::merge(unquantized_type, get_input_element_type(SCALE), m_type))
<< "Scale element type (" << get_input_element_type(SCALE)
<< ") must match output element type (" << m_type << ")";
NODE_VALIDATION_CHECK(
this,
element::Type::merge(unquantized_type, get_input_element_type(SCALE), m_type),
"Scale element type (",
get_input_element_type(SCALE),
") must match output element type (",
m_type,
")");
PartialShape input_shape = get_input_partial_shape(0);
Dimension input_rank = input_shape.rank();
for (auto axis : m_axes)
{
NODE_VALIDATION_ASSERT(this, input_rank.is_dynamic() || axis < size_t(input_rank))
<< "Quantization axis (" << axis << ") must be less than input shape rank ("
<< input_rank << ")";
NODE_VALIDATION_CHECK(this,
input_rank.is_dynamic() || axis < size_t(input_rank),
"Quantization axis (",
axis,
") must be less than input shape rank (",
input_rank,
")");
}
PartialShape scale_offset_shape = get_input_partial_shape(SCALE);
NODE_VALIDATION_ASSERT(
this, PartialShape::merge_into(scale_offset_shape, get_input_partial_shape(OFFSET)))
<< "Scale shape (" << get_input_partial_shape(SCALE) << ") and offset shape ("
<< get_input_partial_shape(OFFSET) << ") must match";
NODE_VALIDATION_ASSERT(this, scale_offset_shape.rank().compatible(m_axes.size()))
<< "Scale/offset rank (" << scale_offset_shape.rank() << ") does not match the number of "
<< "quantization axes (" << m_axes.size() << ")";
NODE_VALIDATION_CHECK(
this,
PartialShape::merge_into(scale_offset_shape, get_input_partial_shape(OFFSET)),
"Scale shape (",
get_input_partial_shape(SCALE),
") and offset shape (",
get_input_partial_shape(OFFSET),
") must match");
NODE_VALIDATION_CHECK(this,
scale_offset_shape.rank().compatible(m_axes.size()),
"Scale/offset rank (",
scale_offset_shape.rank(),
") does not match the number of ",
"quantization axes (",
m_axes.size(),
")");
set_output_size(1);
@@ -108,10 +131,16 @@ void op::Dequantize::validate_and_infer_types()
}
PartialShape result_shape = input_shape;
NODE_VALIDATION_ASSERT(
this, PartialShape::merge_into(result_shape, PartialShape{injected_scale_offset_dims}))
<< "Scale/offset shape (" << scale_offset_shape << ") must match input shape ("
<< input_shape << ") at the quantization axes (" << m_axes << ")";
NODE_VALIDATION_CHECK(
this,
PartialShape::merge_into(result_shape, PartialShape{injected_scale_offset_dims}),
"Scale/offset shape (",
scale_offset_shape,
") must match input shape (",
input_shape,
") at the quantization axes (",
m_axes,
")");
set_output_type(0, unquantized_type, result_shape);
}
else
@@ -49,11 +49,14 @@ void op::Dot::validate_and_infer_types()
{
element::Type result_et;
NODE_VALIDATION_ASSERT(
this, element::Type::merge(result_et, get_input_element_type(0), get_input_element_type(1)))
<< "Arguments do not have the same element type (arg0 element type: "
<< get_input_element_type(0) << ", arg1 element type: " << get_input_element_type(1)
<< ").";
NODE_VALIDATION_CHECK(
this,
element::Type::merge(result_et, get_input_element_type(0), get_input_element_type(1)),
"Arguments do not have the same element type (arg0 element type: ",
get_input_element_type(0),
", arg1 element type: ",
get_input_element_type(1),
").");
const PartialShape& arg0_shape = get_input_partial_shape(0);
const PartialShape& arg1_shape = get_input_partial_shape(1);
@@ -82,17 +85,27 @@ void op::Dot::validate_and_infer_types()
PartialShape result_shape;
NODE_VALIDATION_ASSERT(this,
reduction_axes_ambiguous || arg0_shape.rank().is_dynamic() ||
m_reduction_axes_count <= size_t(arg0_shape.rank()))
<< "Reduction axes count (" << m_reduction_axes_count
<< ") is too large (arg0 shape: " << arg0_shape << ", arg1 shape: " << arg1_shape << ").";
NODE_VALIDATION_ASSERT(this,
reduction_axes_ambiguous || arg1_shape.rank().is_dynamic() ||
m_reduction_axes_count <= size_t(arg1_shape.rank()))
<< "Reduction axes count (" << m_reduction_axes_count
<< ") is too large (arg0 shape: " << arg0_shape << ", arg1 shape: " << arg1_shape << ").";
NODE_VALIDATION_CHECK(this,
reduction_axes_ambiguous || arg0_shape.rank().is_dynamic() ||
m_reduction_axes_count <= size_t(arg0_shape.rank()),
"Reduction axes count (",
m_reduction_axes_count,
") is too large (arg0 shape: ",
arg0_shape,
", arg1 shape: ",
arg1_shape,
").");
NODE_VALIDATION_CHECK(this,
reduction_axes_ambiguous || arg1_shape.rank().is_dynamic() ||
m_reduction_axes_count <= size_t(arg1_shape.rank()),
"Reduction axes count (",
m_reduction_axes_count,
") is too large (arg0 shape: ",
arg0_shape,
", arg1 shape: ",
arg1_shape,
").");
if (!reduction_axes_ambiguous && arg0_shape.rank().is_static() && arg1_shape.rank().is_static())
{
@@ -101,12 +114,20 @@ void op::Dot::validate_and_infer_types()
size_t axis_index_arg0 = size_t(arg0_shape.rank()) - m_reduction_axes_count + i;
size_t axis_index_arg1 = i;
NODE_VALIDATION_ASSERT(
this, arg0_shape[axis_index_arg0].compatible(arg1_shape[axis_index_arg1]))
<< "Paired axes (axis " << axis_index_arg0 << " from arg0, axis " << axis_index_arg1
<< " from arg1) do not have same length (arg0 shape: " << arg0_shape
<< ", arg1 shape: " << arg1_shape
<< ", reduction axes count: " << m_reduction_axes_count << ").";
NODE_VALIDATION_CHECK(
this,
arg0_shape[axis_index_arg0].compatible(arg1_shape[axis_index_arg1]),
"Paired axes (axis ",
axis_index_arg0,
" from arg0, axis ",
axis_index_arg1,
" from arg1) do not have same length (arg0 shape: ",
arg0_shape,
", arg1 shape: ",
arg1_shape,
", reduction axes count: ",
m_reduction_axes_count,
").");
}
std::vector<Dimension> result_dims(size_t(arg0_shape.rank()) + size_t(arg1_shape.rank()) -
@@ -26,9 +26,10 @@ void op::EmbeddingLookup::validate_and_infer_types()
const PartialShape& arg0_shape = get_input_partial_shape(0);
const PartialShape& arg1_shape = get_input_partial_shape(1);
NODE_VALIDATION_ASSERT(
this, arg1_shape.rank().is_dynamic() || static_cast<size_t>(arg1_shape.rank()) == 2)
<< "weights are expected to be a matrix";
NODE_VALIDATION_CHECK(this,
arg1_shape.rank().is_dynamic() ||
static_cast<size_t>(arg1_shape.rank()) == 2,
"weights are expected to be a matrix");
PartialShape result_shape;
if (arg0_shape.rank().is_static())
@@ -42,11 +42,12 @@ shared_ptr<Node> op::GenerateMask::copy_with_new_args(const NodeVector& new_args) const
void ngraph::op::GenerateMask::validate_and_infer_types()
{
NODE_VALIDATION_ASSERT(this, get_input_partial_shape(0).compatible(PartialShape{}))
<< "Training node should be a scalar flag indicating a mode";
NODE_VALIDATION_CHECK(this,
get_input_partial_shape(0).compatible(PartialShape{}),
"Training node should be a scalar flag indicating a mode");
NODE_VALIDATION_ASSERT(this, m_element_type.is_static())
<< "Output element type must not be dynamic.";
NODE_VALIDATION_CHECK(
this, m_element_type.is_static(), "Output element type must not be dynamic.");
set_output_type(0, m_element_type, m_shape);
}
@@ -65,36 +65,58 @@ void op::QuantizedAvgPool::validate_and_infer_types()
// Make sure batch size and channel count are not zero, and that we have at least one spatial
// dimension (in other words, that arg has shape NCDi for some Di of rank>0, N != 0, C != 0).
//
NODE_VALIDATION_ASSERT(this, arg_shape.size() >= 3)
<< "Data input shape does not have rank of at least 3 (data input shape: " << arg_shape
<< ").";
NODE_VALIDATION_CHECK(this,
arg_shape.size() >= 3,
"Data input shape does not have rank of at least 3 (data input shape: ",
arg_shape,
").");
size_t batch_size = arg_shape[0];
NODE_VALIDATION_ASSERT(this, batch_size != 0)
<< "Data batch size is zero (data input shape: " << arg_shape << ").";
NODE_VALIDATION_CHECK(
this, batch_size != 0, "Data batch size is zero (data input shape: ", arg_shape, ").");
size_t channel_count = arg_shape[1];
NODE_VALIDATION_ASSERT(this, channel_count != 0)
<< "Channel count is zero (data input shape: " << arg_shape << ").";
NODE_VALIDATION_CHECK(
this, channel_count != 0, "Channel count is zero (data input shape: ", arg_shape, ").");
size_t spatial_dimension_count = arg_shape.size() - 2;
//
// Make sure window shape, window movement strides, and padding have same rank as Di.
//
NODE_VALIDATION_ASSERT(this, m_window_shape.size() == spatial_dimension_count)
<< "Window shape rank does not match number of spatial dimensions (window shape: "
<< m_window_shape << ", data input shape: " << arg_shape << ").";
NODE_VALIDATION_ASSERT(this, m_window_movement_strides.size() == spatial_dimension_count)
<< "Window movement stride rank does not match number of spatial dimensions (window "
"movement strides: "
<< m_window_movement_strides << ", data input shape: " << arg_shape << ").";
NODE_VALIDATION_ASSERT(this, m_padding_below.size() == spatial_dimension_count)
<< "Below-padding rank does not match number of spatial dimensions (padding below: "
<< m_padding_below << ", data input shape: " << arg_shape << ").";
NODE_VALIDATION_ASSERT(this, m_padding_above.size() == spatial_dimension_count)
<< "Above-padding rank does not match number of spatial dimensions (padding above: "
<< m_padding_above << ", data input shape: " << arg_shape << ").";
NODE_VALIDATION_CHECK(
this,
m_window_shape.size() == spatial_dimension_count,
"Window shape rank does not match number of spatial dimensions (window shape: ",
m_window_shape,
", data input shape: ",
arg_shape,
").");
NODE_VALIDATION_CHECK(
this,
m_window_movement_strides.size() == spatial_dimension_count,
"Window movement stride rank does not match number of spatial dimensions (window "
"movement strides: ",
m_window_movement_strides,
", data input shape: ",
arg_shape,
").");
NODE_VALIDATION_CHECK(
this,
m_padding_below.size() == spatial_dimension_count,
"Below-padding rank does not match number of spatial dimensions (padding below: ",
m_padding_below,
", data input shape: ",
arg_shape,
").");
NODE_VALIDATION_CHECK(
this,
m_padding_above.size() == spatial_dimension_count,
"Above-padding rank does not match number of spatial dimensions (padding above: ",
m_padding_above,
", data input shape: ",
arg_shape,
").");
//
// Extract input item shape Di and make sure all dimensions are larger than 0.
@@ -110,10 +132,13 @@ void op::QuantizedAvgPool::validate_and_infer_types()
for (size_t i = 0; i < spatial_dimension_count; i++)
{
NODE_VALIDATION_ASSERT(this, input_item_virtual_shape[i] != 0)
<< "Data input spatial dimension " << i
<< " has zero length even after padding (virtual shape of input item: "
<< input_item_virtual_shape << ").";
NODE_VALIDATION_CHECK(this,
input_item_virtual_shape[i] != 0,
"Data input spatial dimension ",
i,
" has zero length even after padding (virtual shape of input item: ",
input_item_virtual_shape,
").");
}
//
@@ -121,9 +146,13 @@ void op::QuantizedAvgPool::validate_and_infer_types()
//
for (size_t i = 0; i < spatial_dimension_count; i++)
{
NODE_VALIDATION_ASSERT(this, m_window_shape[i] != 0)
<< "Window shape dimension " << i
<< " has zero length (window shape: " << m_window_shape << ").";
NODE_VALIDATION_CHECK(this,
m_window_shape[i] != 0,
"Window shape dimension ",
i,
" has zero length (window shape: ",
m_window_shape,
").");
}
//
@@ -131,10 +160,14 @@ void op::QuantizedAvgPool::validate_and_infer_types()
//
for (size_t i = 0; i < spatial_dimension_count; i++)
{
NODE_VALIDATION_ASSERT(this, m_window_shape[i] <= input_item_virtual_shape[i])
<< "Window shape after padding is larger than the spatial dimensions (window shape: "
<< m_window_shape << ", virtual shape of input item: " << input_item_virtual_shape
<< ").";
NODE_VALIDATION_CHECK(
this,
m_window_shape[i] <= input_item_virtual_shape[i],
"Window shape after padding is larger than the spatial dimensions (window shape: ",
m_window_shape,
", virtual shape of input item: ",
input_item_virtual_shape,
").");
}
//
// Compute output item shape Do, checking at the same time that all window movement strides are larger than 0.
@@ -143,9 +176,13 @@ void op::QuantizedAvgPool::validate_and_infer_types()
for (size_t i = 0; i < spatial_dimension_count; i++)
{
NODE_VALIDATION_ASSERT(this, m_window_movement_strides[i] != 0)
<< "Window movement strides dimension " << i
<< " has zero length (window movement strides: " << m_window_movement_strides << ").";
NODE_VALIDATION_CHECK(this,
m_window_movement_strides[i] != 0,
"Window movement strides dimension ",
i,
" has zero length (window movement strides: ",
m_window_movement_strides,
").");
output_item_shape.push_back(ceil_div(input_item_virtual_shape[i] - m_window_shape[i] + 1,
m_window_movement_strides[i]));
}
@@ -167,11 +204,15 @@ void op::QuantizedAvgPool::validate_and_infer_types()
// Checking the lower edge of each dimension is easy, because there's no mystery
// regarding the window's lower-edge placement...
NODE_VALIDATION_ASSERT(this,
dim_padding_below == 0 || dim_window_size > dim_padding_below)
<< "Window will sometimes reside entirely within the below-padding region, but"
<< " include_padding_in_avg_computation was not set (padding below: "
<< m_padding_below << ", window shape: " << m_window_shape << ").";
NODE_VALIDATION_CHECK(
this,
dim_padding_below == 0 || dim_window_size > dim_padding_below,
"Window will sometimes reside entirely within the below-padding region, but",
" include_padding_in_avg_computation was not set (padding below: ",
m_padding_below,
", window shape: ",
m_window_shape,
").");
// Now check the upper-bound...
{
@@ -179,13 +220,16 @@ void op::QuantizedAvgPool::validate_and_infer_types()
const size_t dim_window_max_lower_offset = dim_num_strides * dim_stride;
const size_t dim_padding_above_start_offset = dim_virtual_size - dim_padding_above;
NODE_VALIDATION_ASSERT(this,
dim_padding_above == 0 ||
dim_window_max_lower_offset <
dim_padding_above_start_offset)
<< "Window will sometimes reside entirely within the above-padding region, but"
<< " include_padding_in_avg_computation was not set (padding above: "
<< m_padding_above << ", window shape: " << m_window_shape << ").";
NODE_VALIDATION_CHECK(
this,
dim_padding_above == 0 ||
dim_window_max_lower_offset < dim_padding_above_start_offset,
"Window will sometimes reside entirely within the above-padding region, but",
" include_padding_in_avg_computation was not set (padding above: ",
m_padding_above,
", window shape: ",
m_window_shape,
").");
}
}
}
@@ -33,7 +33,7 @@ op::QuantizedConcat::QuantizedConcat(const NodeVector& args, size_t concatenation_axis)
void op::QuantizedConcat::validate_and_infer_types()
{
NODE_VALIDATION_ASSERT(this, m_inputs.size() >= 1) << "At least one argument required.";
NODE_VALIDATION_CHECK(this, m_inputs.size() >= 1, "At least one argument required.");
PartialShape inputs_shape_scheme{PartialShape::dynamic()};
element::Type inputs_et{element::dynamic};
@@ -45,23 +45,32 @@ void op::QuantizedConcat::validate_and_infer_types()
Dimension this_input_rank = this_input_shape.rank();
if (this_input_rank.is_static())
{
NODE_VALIDATION_ASSERT(this, m_concatenation_axis < size_t(this_input_rank))
<< "QuantizedConcatenation axis (" << m_concatenation_axis
<< ") is out of bounds for "
<< "argument " << i << ", which has shape " << this_input_shape << ".";
NODE_VALIDATION_CHECK(this,
m_concatenation_axis < size_t(this_input_rank),
"QuantizedConcatenation axis (",
m_concatenation_axis,
") is out of bounds for ",
"argument ",
i,
", which has shape ",
this_input_shape,
".");
concatenation_axis_output_dim += this_input_shape[m_concatenation_axis];
this_input_shape[m_concatenation_axis] = Dimension::dynamic();
NODE_VALIDATION_ASSERT(this,
PartialShape::merge_into(inputs_shape_scheme, this_input_shape))
<< "Argument shapes are inconsistent; they must have the same rank, and must have "
<< "equal dimension everywhere except on the concatenation axis (axis "
<< m_concatenation_axis << ").";
NODE_VALIDATION_CHECK(
this,
PartialShape::merge_into(inputs_shape_scheme, this_input_shape),
"Argument shapes are inconsistent; they must have the same rank, and must have ",
"equal dimension everywhere except on the concatenation axis (axis ",
m_concatenation_axis,
").");
NODE_VALIDATION_ASSERT(
this, element::Type::merge(inputs_et, inputs_et, get_input_element_type(i)))
<< "Argument element types are inconsistent.";
NODE_VALIDATION_CHECK(
this,
element::Type::merge(inputs_et, inputs_et, get_input_element_type(i)),
"Argument element types are inconsistent.");
}
else
{
@@ -64,36 +64,58 @@ void op::QuantizedMaxPool::validate_and_infer_types()
// Make sure batch size and channel count are not zero, and that we have at least one spatial
// dimension (in other words, that arg has shape NCDi for some Di of rank>0, N != 0, C != 0).
//
NODE_VALIDATION_ASSERT(this, arg_shape.size() >= 3)
<< "Data input shape does not have rank of at least 3 (data input shape: " << arg_shape
<< ").";
NODE_VALIDATION_CHECK(this,
arg_shape.size() >= 3,
"Data input shape does not have rank of at least 3 (data input shape: ",
arg_shape,
").");
size_t batch_size = arg_shape[0];
NODE_VALIDATION_ASSERT(this, batch_size != 0)
<< "Data batch size is zero (data input shape: " << arg_shape << ").";
NODE_VALIDATION_CHECK(
this, batch_size != 0, "Data batch size is zero (data input shape: ", arg_shape, ").");
size_t channel_count = arg_shape[1];
NODE_VALIDATION_ASSERT(this, channel_count != 0)
<< "Channel count is zero (data input shape: " << arg_shape << ").";
NODE_VALIDATION_CHECK(
this, channel_count != 0, "Channel count is zero (data input shape: ", arg_shape, ").");
size_t spatial_dimension_count = arg_shape.size() - 2;
//
// Make sure window shape, window movement strides, and padding have same rank as Di.
//
NODE_VALIDATION_ASSERT(this, m_window_shape.size() == spatial_dimension_count)
<< "Window shape rank does not match number of spatial dimensions (window shape: "
<< m_window_shape << ", data input shape: " << arg_shape << ").";
NODE_VALIDATION_ASSERT(this, m_window_movement_strides.size() == spatial_dimension_count)
<< "Window movement stride rank does not match number of spatial dimensions (window "
"movement strides: "
<< m_window_movement_strides << ", data input shape: " << arg_shape << ").";
NODE_VALIDATION_ASSERT(this, m_padding_below.size() == spatial_dimension_count)
<< "Below-padding rank does not match number of spatial dimensions (padding below: "
<< m_padding_below << ", data input shape: " << arg_shape << ").";
NODE_VALIDATION_ASSERT(this, m_padding_above.size() == spatial_dimension_count)
<< "Above-padding rank does not match number of spatial dimensions (padding above: "
<< m_padding_above << ", data input shape: " << arg_shape << ").";
NODE_VALIDATION_CHECK(
this,
m_window_shape.size() == spatial_dimension_count,
"Window shape rank does not match number of spatial dimensions (window shape: ",
m_window_shape,
", data input shape: ",
arg_shape,
").");
NODE_VALIDATION_CHECK(
this,
m_window_movement_strides.size() == spatial_dimension_count,
"Window movement stride rank does not match number of spatial dimensions (window "
"movement strides: ",
m_window_movement_strides,
", data input shape: ",
arg_shape,
").");
NODE_VALIDATION_CHECK(
this,
m_padding_below.size() == spatial_dimension_count,
"Below-padding rank does not match number of spatial dimensions (padding below: ",
m_padding_below,
", data input shape: ",
arg_shape,
").");
NODE_VALIDATION_CHECK(
this,
m_padding_above.size() == spatial_dimension_count,
"Above-padding rank does not match number of spatial dimensions (padding above: ",
m_padding_above,
", data input shape: ",
arg_shape,
").");
//
// Extract input item shape Di and make sure all dimensions are larger than 0.
@@ -109,10 +131,13 @@ void op::QuantizedMaxPool::validate_and_infer_types()
for (size_t i = 0; i < spatial_dimension_count; i++)
{
NODE_VALIDATION_ASSERT(this, input_item_virtual_shape[i] != 0)
<< "Data input spatial dimension " << i
<< " has zero length even after padding (virtual shape of input item: "
<< input_item_virtual_shape << ").";
NODE_VALIDATION_CHECK(this,
input_item_virtual_shape[i] != 0,
"Data input spatial dimension ",
i,
" has zero length even after padding (virtual shape of input item: ",
input_item_virtual_shape,
").");
}
//
@@ -120,9 +145,13 @@ void op::QuantizedMaxPool::validate_and_infer_types()
//
for (size_t i = 0; i < spatial_dimension_count; i++)
{
NODE_VALIDATION_ASSERT(this, m_window_shape[i] != 0)
<< "Window shape dimension " << i
<< " has zero length (window shape: " << m_window_shape << ").";
NODE_VALIDATION_CHECK(this,
m_window_shape[i] != 0,
"Window shape dimension ",
i,
" has zero length (window shape: ",
m_window_shape,
").");
}
//
@@ -130,10 +159,14 @@ void op::QuantizedMaxPool::validate_and_infer_types()
//
for (size_t i = 0; i < spatial_dimension_count; i++)
{
NODE_VALIDATION_ASSERT(this, m_window_shape[i] <= input_item_virtual_shape[i])
<< "Window shape after padding is larger than the spatial dimensions (window shape: "
<< m_window_shape << ", virtual shape of input item: " << input_item_virtual_shape
<< ").";
NODE_VALIDATION_CHECK(
this,
m_window_shape[i] <= input_item_virtual_shape[i],
"Window shape after padding is larger than the spatial dimensions (window shape: ",
m_window_shape,
", virtual shape of input item: ",
input_item_virtual_shape,
").");
}
//
@@ -143,9 +176,13 @@ void op::QuantizedMaxPool::validate_and_infer_types()
for (size_t i = 0; i < spatial_dimension_count; i++)
{
NODE_VALIDATION_ASSERT(this, m_window_movement_strides[i] != 0)
<< "Window movement strides dimension " << i
<< " has zero length (window movement strides: " << m_window_movement_strides << ").";
NODE_VALIDATION_CHECK(this,
m_window_movement_strides[i] != 0,
"Window movement strides dimension ",
i,
" has zero length (window movement strides: ",
m_window_movement_strides,
").");
output_item_shape.push_back(ceil_div(input_item_virtual_shape[i] - m_window_shape[i] + 1,
m_window_movement_strides[i]));
}
@@ -30,9 +30,13 @@ op::GetOutputElement::GetOutputElement(const shared_ptr<Node>& arg, size_t n)
void op::GetOutputElement::validate_and_infer_types()
{
NODE_VALIDATION_ASSERT(this, m_n < get_input_size())
<< "Output at index " << m_n << " requested, but node has only " << get_input_size()
<< " inputs.";
NODE_VALIDATION_CHECK(this,
m_n < get_input_size(),
"Output at index ",
m_n,
" requested, but node has only ",
get_input_size(),
" inputs.");
set_output_type(0, get_input_element_type(m_n), get_input_partial_shape(m_n));
}
@@ -36,9 +36,12 @@ void op::LRN::validate_and_infer_types()
const PartialShape& input_shape = get_input_partial_shape(0);
NODE_VALIDATION_ASSERT(
this, input_shape.rank().is_dynamic() || static_cast<size_t>(input_shape.rank()) >= 3)
<< "Argument must have rank >= 3 (argument shape: " << input_shape << ").";
NODE_VALIDATION_CHECK(this,
input_shape.rank().is_dynamic() ||
static_cast<size_t>(input_shape.rank()) >= 3,
"Argument must have rank >= 3 (argument shape: ",
input_shape,
").");
}
shared_ptr<Node> op::LRN::copy_with_new_args(const NodeVector& new_args) const
@@ -134,9 +134,13 @@ void op::MaxPoolBackprop::validate_and_infer_types()
element::Type result_et;
NODE_VALIDATION_ASSERT(this, element::Type::merge(result_et, forward_arg_et, delta_et))
<< "Element types for forward argument (" << forward_arg_et << ") and delta (" << delta_et
<< ") do not match.";
NODE_VALIDATION_CHECK(this,
element::Type::merge(result_et, forward_arg_et, delta_et),
"Element types for forward argument (",
forward_arg_et,
") and delta (",
delta_et,
") do not match.");
// infer_batched_forward_pooling wants CoordinateDiffs for these, while the pooling ops for
// now still take Shape (no negative padding).
@@ -155,9 +159,15 @@ void op::MaxPoolBackprop::validate_and_infer_types()
const PartialShape& delta_shape = get_input_partial_shape(1);
NODE_VALIDATION_ASSERT(this, forward_result_shape.compatible(delta_shape))
<< "Inferred forward output shape does not match delta shape (inferred forward output "
<< "shape: " << forward_result_shape << ", delta shape: " << delta_shape << ").";
NODE_VALIDATION_CHECK(
this,
forward_result_shape.compatible(delta_shape),
"Inferred forward output shape does not match delta shape (inferred forward output ",
"shape: ",
forward_result_shape,
", delta shape: ",
delta_shape,
").");
// TODO(amprocte): We may technically be able to infer some extra information from
// forward_result_shape that was not present in the forward arg shape---namely batch size and
@@ -34,16 +34,25 @@ void op::OneHot::validate_and_infer_types()
PartialShape arg_shape = get_input_partial_shape(0);
Rank arg_rank = arg_shape.rank();
NODE_VALIDATION_ASSERT(this, m_shape.rank().is_static())
<< "Requested result shape has dynamic rank.";
NODE_VALIDATION_CHECK(
this, m_shape.rank().is_static(), "Requested result shape has dynamic rank.");
NODE_VALIDATION_ASSERT(this, m_one_hot_axis < static_cast<size_t>(m_shape.rank()))
<< "One-hot axis (" << m_one_hot_axis
<< ") is out of bounds (requested result shape: " << m_shape << ").";
NODE_VALIDATION_CHECK(this,
m_one_hot_axis < static_cast<size_t>(m_shape.rank()),
"One-hot axis (",
m_one_hot_axis,
") is out of bounds (requested result shape: ",
m_shape,
").");
NODE_VALIDATION_ASSERT(this, m_shape[m_one_hot_axis].is_static())
<< "Requested result shape (" << m_shape << ") has dynamic dimension at the one-hot axis "
<< "(" << m_one_hot_axis << ").";
NODE_VALIDATION_CHECK(this,
m_shape[m_one_hot_axis].is_static(),
"Requested result shape (",
m_shape,
") has dynamic dimension at the one-hot axis ",
"(",
m_one_hot_axis,
").");
PartialShape result_shape{m_shape};
@@ -58,9 +67,13 @@ void op::OneHot::validate_and_infer_types()
PartialShape expected_input_shape{expected_input_dims};
PartialShape merged_input_shape{expected_input_shape};
NODE_VALIDATION_ASSERT(this, PartialShape::merge_into(merged_input_shape, arg_shape))
<< "Argument shape " << arg_shape << " does not match the expected shape of "
<< expected_input_shape << ".";
NODE_VALIDATION_CHECK(this,
PartialShape::merge_into(merged_input_shape, arg_shape),
"Argument shape ",
arg_shape,
" does not match the expected shape of ",
expected_input_shape,
".");
std::vector<Dimension> output_dims(static_cast<size_t>(merged_input_shape.rank()));
for (size_t i = 0; i < static_cast<size_t>(merged_input_shape.rank()); i++)
@@ -38,31 +38,49 @@ void op::Pad::validate_and_infer_types()
{
element::Type result_et;
NODE_VALIDATION_ASSERT(
this, element::Type::merge(result_et, get_input_element_type(0), get_input_element_type(1)))
<< "Argument element types do not match (arg0 element type: " << get_input_element_type(0)
<< ", arg1 element type: " << get_input_element_type(1) << ").";
NODE_VALIDATION_ASSERT(this, get_input_partial_shape(1).compatible(PartialShape{}))
<< "Argument for padding value is not a scalar (shape: " << get_input_partial_shape(1)
<< ").";
NODE_VALIDATION_CHECK(
this,
element::Type::merge(result_et, get_input_element_type(0), get_input_element_type(1)),
"Argument element types do not match (arg0 element type: ",
get_input_element_type(0),
", arg1 element type: ",
get_input_element_type(1),
").");
NODE_VALIDATION_CHECK(this,
get_input_partial_shape(1).compatible(PartialShape{}),
"Argument for padding value is not a scalar (shape: ",
get_input_partial_shape(1),
").");
auto arg_shape = get_input_partial_shape(0);
NODE_VALIDATION_ASSERT(this,
m_padding_below.size() == m_padding_above.size() &&
m_padding_below.size() == m_padding_interior.size())
<< "Ranks for padding below (" << m_padding_below << "), padding above (" << m_padding_above
<< ") and interior padding (" << m_padding_interior << ") "
<< "do not match.";
NODE_VALIDATION_CHECK(this,
m_padding_below.size() == m_padding_above.size() &&
m_padding_below.size() == m_padding_interior.size(),
"Ranks for padding below (",
m_padding_below,
"), padding above (",
m_padding_above,
") and interior padding (",
m_padding_interior,
") ",
"do not match.");
size_t implied_rank = m_padding_below.size();
NODE_VALIDATION_ASSERT(this, arg_shape.rank().compatible(implied_rank))
<< "Rank for padding below/padding above/interior padding does not match the rank of the "
<< "data argument (padding below: " << m_padding_below << ", "
<< ", padding above: " << m_padding_above << ", interior padding: " << m_padding_interior
<< ").";
NODE_VALIDATION_CHECK(
this,
arg_shape.rank().compatible(implied_rank),
"Rank for padding below/padding above/interior padding does not match the rank of the ",
"data argument (padding below: ",
m_padding_below,
", ",
", padding above: ",
m_padding_above,
", interior padding: ",
m_padding_interior,
").");
std::vector<Dimension> result_dims(implied_rank, Dimension::dynamic());
@@ -44,50 +44,73 @@ void op::Quantize::validate_and_infer_types()
OFFSET
};
NODE_VALIDATION_ASSERT(this, m_type.is_static()) << "Output element type must not be dynamic";
NODE_VALIDATION_CHECK(this, m_type.is_static(), "Output element type must not be dynamic");
NODE_VALIDATION_ASSERT(this, m_type.is_quantized()) << "Output element type (" << m_type
<< ") must be a quantized type";
NODE_VALIDATION_CHECK(
this, m_type.is_quantized(), "Output element type (", m_type, ") must be a quantized type");
element::Type unquantized_type;
NODE_VALIDATION_ASSERT(this,
element::Type::merge(unquantized_type,
get_input_element_type(INPUT),
get_input_element_type(SCALE)))
<< "Scale element type (" << get_input_element_type(SCALE)
<< ") must match input element type (" << get_input_element_type(INPUT) << ")";
NODE_VALIDATION_ASSERT(this, unquantized_type.is_dynamic() || unquantized_type.is_real())
<< "Scale/input element type (" << unquantized_type << ") must be a floating point number";
NODE_VALIDATION_CHECK(this,
element::Type::merge(unquantized_type,
get_input_element_type(INPUT),
get_input_element_type(SCALE)),
"Scale element type (",
get_input_element_type(SCALE),
") must match input element type (",
get_input_element_type(INPUT),
")");
NODE_VALIDATION_CHECK(this,
unquantized_type.is_dynamic() || unquantized_type.is_real(),
"Scale/input element type (",
unquantized_type,
") must be a floating point number");
element::Type quantized_type;
NODE_VALIDATION_ASSERT(
this, element::Type::merge(quantized_type, get_input_element_type(OFFSET), m_type))
<< "Offset element type (" << get_input_element_type(OFFSET)
<< ") must match output element type (" << m_type << ")";
NODE_VALIDATION_CHECK(
this,
element::Type::merge(quantized_type, get_input_element_type(OFFSET), m_type),
"Offset element type (",
get_input_element_type(OFFSET),
") must match output element type (",
m_type,
")");
PartialShape input_shape = get_input_partial_shape(0);
Dimension input_rank = input_shape.rank();
for (auto axis : m_axes)
{
NODE_VALIDATION_ASSERT(this, input_rank.is_dynamic() || axis < size_t(input_rank))
<< "Quantization axis (" << axis << ") must be less than input shape rank ("
<< input_rank << ")";
NODE_VALIDATION_CHECK(this,
input_rank.is_dynamic() || axis < size_t(input_rank),
"Quantization axis (",
axis,
") must be less than input shape rank (",
input_rank,
")");
}
PartialShape scale_offset_shape = get_input_partial_shape(SCALE);
NODE_VALIDATION_ASSERT(
this, PartialShape::merge_into(scale_offset_shape, get_input_partial_shape(OFFSET)))
<< "Scale shape (" << get_input_partial_shape(SCALE) << ") and offset shape ("
<< get_input_partial_shape(OFFSET) << ") must match";
NODE_VALIDATION_ASSERT(this, scale_offset_shape.rank().compatible(m_axes.size()))
<< "Scale/offset rank (" << scale_offset_shape.rank() << ") does not match the number of "
<< "quantization axes (" << m_axes.size() << ")";
NODE_VALIDATION_CHECK(
this,
PartialShape::merge_into(scale_offset_shape, get_input_partial_shape(OFFSET)),
"Scale shape (",
get_input_partial_shape(SCALE),
") and offset shape (",
get_input_partial_shape(OFFSET),
") must match");
NODE_VALIDATION_CHECK(this,
scale_offset_shape.rank().compatible(m_axes.size()),
"Scale/offset rank (",
scale_offset_shape.rank(),
") does not match the number of ",
"quantization axes (",
m_axes.size(),
")");
set_output_size(1);
@@ -110,10 +133,16 @@ void op::Quantize::validate_and_infer_types()
}
PartialShape result_shape = input_shape;
NODE_VALIDATION_ASSERT(
this, PartialShape::merge_into(result_shape, PartialShape{injected_scale_offset_dims}))
<< "Scale/offset shape (" << scale_offset_shape << ") must match input shape ("
<< input_shape << ") at the quantization axes (" << m_axes << ")";
NODE_VALIDATION_CHECK(
this,
PartialShape::merge_into(result_shape, PartialShape{injected_scale_offset_dims}),
"Scale/offset shape (",
scale_offset_shape,
") must match input shape (",
input_shape,
") at the quantization axes (",
m_axes,
")");
set_output_type(0, quantized_type, result_shape);
}
else
@@ -59,51 +59,85 @@ void op::ReplaceSlice::validate_and_infer_types()
const PartialShape& arg1_shape = get_input_partial_shape(1);
Dimension merged_args_rank;
NODE_VALIDATION_ASSERT(this,
Dimension::merge(merged_args_rank, arg0_shape.rank(), arg1_shape.rank()))
<< "Argument ranks do not match (arg0 shape: " << arg0_shape
<< ", arg1 shape: " << arg1_shape << ").";
NODE_VALIDATION_CHECK(this,
Dimension::merge(merged_args_rank, arg0_shape.rank(), arg1_shape.rank()),
"Argument ranks do not match (arg0 shape: ",
arg0_shape,
", arg1 shape: ",
arg1_shape,
").");
element::Type arg0_et = get_input_element_type(0);
element::Type arg1_et = get_input_element_type(1);
element::Type merged_args_et;
NODE_VALIDATION_ASSERT(this, element::Type::merge(merged_args_et, arg0_et, arg1_et))
<< "Argument element types do not match (arg0 element type: " << arg0_et
<< ", arg1 element type: " << arg1_et << ").";
NODE_VALIDATION_ASSERT(this,
m_lower_bounds.size() == m_upper_bounds.size() &&
m_lower_bounds.size() == m_strides.size())
<< "Ranks of lower bounds (" << m_lower_bounds << "), upper bounds (" << m_upper_bounds
<< ") and strides (" << m_strides << ") do not match.";
NODE_VALIDATION_CHECK(this,
element::Type::merge(merged_args_et, arg0_et, arg1_et),
"Argument element types do not match (arg0 element type: ",
arg0_et,
", arg1 element type: ",
arg1_et,
").");
NODE_VALIDATION_CHECK(this,
m_lower_bounds.size() == m_upper_bounds.size() &&
m_lower_bounds.size() == m_strides.size(),
"Ranks of lower bounds (",
m_lower_bounds,
"), upper bounds (",
m_upper_bounds,
") and strides (",
m_strides,
") do not match.");
size_t output_rank = m_upper_bounds.size();
for (size_t i = 0; i < output_rank; i++)
{
NODE_VALIDATION_ASSERT(this, m_lower_bounds[i] <= m_upper_bounds[i])
<< "Lower bound for slice is greater than upper bound at axis " << i
<< " (lower bounds: " << m_lower_bounds << ", upper bounds: " << m_upper_bounds << ").";
NODE_VALIDATION_ASSERT(this, m_strides[i] != 0) << "Stride for slice is zero at axis " << i
<< " (strides: " << m_strides << ").";
NODE_VALIDATION_CHECK(this,
m_lower_bounds[i] <= m_upper_bounds[i],
"Lower bound for slice is greater than upper bound at axis ",
i,
" (lower bounds: ",
m_lower_bounds,
", upper bounds: ",
m_upper_bounds,
").");
NODE_VALIDATION_CHECK(this,
m_strides[i] != 0,
"Stride for slice is zero at axis ",
i,
" (strides: ",
m_strides,
").");
}
NODE_VALIDATION_ASSERT(this,
merged_args_rank.is_dynamic() || size_t(merged_args_rank) == output_rank)
<< "Argument ranks do not match the rank of the lower bounds (" << m_lower_bounds
<< "), upper bounds (" << m_upper_bounds << "), and strides (" << m_strides << ").";
NODE_VALIDATION_CHECK(this,
merged_args_rank.is_dynamic() || size_t(merged_args_rank) == output_rank,
"Argument ranks do not match the rank of the lower bounds (",
m_lower_bounds,
"), upper bounds (",
m_upper_bounds,
"), and strides (",
m_strides,
").");
std::vector<Dimension> sliced_dims(output_rank);
for (size_t i = 0; i < output_rank; i++)
{
NODE_VALIDATION_ASSERT(this,
arg0_shape.rank().is_dynamic() || arg0_shape[i].is_dynamic() ||
m_upper_bounds[i] <= size_t(arg0_shape[i]))
<< "Upper bound for slice at axis " << i << " is out of range "
<< "(upper bounds: " << m_upper_bounds << ", argument shape: " << arg0_shape << ").";
NODE_VALIDATION_CHECK(this,
arg0_shape.rank().is_dynamic() || arg0_shape[i].is_dynamic() ||
m_upper_bounds[i] <= size_t(arg0_shape[i]),
"Upper bound for slice at axis ",
i,
" is out of range ",
"(upper bounds: ",
m_upper_bounds,
", argument shape: ",
arg0_shape,
").");
size_t sliced_dim = m_upper_bounds[i] - m_lower_bounds[i];
sliced_dim = sliced_dim / m_strides[i] + ((sliced_dim % m_strides[i] == 0) ? 0 : 1);
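The two lines above compute ceil((upper - lower) / stride); a numeric sketch with hypothetical bounds:

    // lower = 1, upper = 6, stride = 2: sliced_dim starts at 5, and
    // 5 / 2 + 1 == 3, i.e. the slice covers indices 1, 3, and 5.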
@@ -112,9 +146,14 @@ void op::ReplaceSlice::validate_and_infer_types()
PartialShape slice_shape{sliced_dims};
NODE_VALIDATION_ASSERT(this, arg1_shape.compatible(slice_shape))
<< "Shape of replacement tensor (" << arg1_shape << ") does not match the slice shape "
<< "(" << slice_shape << ").";
NODE_VALIDATION_CHECK(this,
arg1_shape.compatible(slice_shape),
"Shape of replacement tensor (",
arg1_shape,
") does not match the slice shape ",
"(",
slice_shape,
").");
// Slight corner case here: if arg0 was rank-unknown, we can go ahead and set the output rank
// because the attribs will have given us enough info.
@@ -41,25 +41,39 @@ void op::Reshape::validate_and_infer_types()
// Check that the input axis order is a permutation of (0,...,n-1) for some n.
for (size_t i = 0; i < m_input_order.size(); i++)
{
NODE_VALIDATION_ASSERT(
this, find(begin(m_input_order), end(m_input_order), i) != end(m_input_order))
<< "Input axis order is not a permutation of argument's axis indices (axis order: "
<< m_input_order << ", argument shape: " << input_shape << ").";
NODE_VALIDATION_CHECK(
this,
find(begin(m_input_order), end(m_input_order), i) != end(m_input_order),
"Input axis order is not a permutation of argument's axis indices (axis order: ",
m_input_order,
", argument shape: ",
input_shape,
").");
}
// TODO(amprocte): should be possible to move around unknown dims in the input shape.
if (input_rank.is_static())
{
NODE_VALIDATION_ASSERT(this, m_input_order.size() == size_t(input_rank))
<< "Input axis order is not a permutation of argument's axis indices (axis order: "
<< m_input_order << ", argument shape: " << input_shape << ").";
NODE_VALIDATION_CHECK(
this,
m_input_order.size() == size_t(input_rank),
"Input axis order is not a permutation of argument's axis indices (axis order: ",
m_input_order,
", argument shape: ",
input_shape,
").");
for (size_t i = 0; i < size_t(input_rank); i++)
{
auto it = find(begin(m_input_order), end(m_input_order), i);
NODE_VALIDATION_ASSERT(this, it != end(m_input_order))
<< "Input axis order is not a permutation of argument's axis indices (axis order: "
<< m_input_order << ", argument shape: " << input_shape << ").";
NODE_VALIDATION_CHECK(
this,
it != end(m_input_order),
"Input axis order is not a permutation of argument's axis indices (axis order: ",
m_input_order,
", argument shape: ",
input_shape,
").");
}
// TODO(amprocte): make a partial_shape_size() analogous to shape_size().
@@ -71,11 +85,16 @@ void op::Reshape::validate_and_infer_types()
if (input_shape_product.is_static())
{
NODE_VALIDATION_ASSERT(this, size_t(input_shape_product) == shape_size(m_output_shape))
<< "Product of output shape dimensions does not match product of argument shape "
"dimensions "
<< "(output shape: " << m_output_shape << ", argument shape: " << input_shape
<< ").";
NODE_VALIDATION_CHECK(
this,
size_t(input_shape_product) == shape_size(m_output_shape),
"Product of output shape dimensions does not match product of argument shape "
"dimensions ",
"(output shape: ",
m_output_shape,
", argument shape: ",
input_shape,
").");
}
}
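A concrete instance of the element-count check above, with hypothetical shapes:

    // An argument of shape {2,3,4} (24 elements) may be reshaped to {4,6}
    // (also 24 elements); requesting an output shape of {5,5} (25 elements)
    // would raise NodeValidationFailure here.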
@@ -32,8 +32,8 @@ op::Result::Result(const shared_ptr<Node>& arg)
void op::Result::validate_and_infer_types()
{
NODE_VALIDATION_ASSERT(this, get_input_size() == 1) << "Argument has " << get_input_size()
<< " outputs (1 expected).";
NODE_VALIDATION_CHECK(
this, get_input_size() == 1, "Argument has ", get_input_size(), " outputs (1 expected).");
// always borrow the placement conf even the default one
set_placement_index(get_argument(0)->get_placement_index());
@@ -40,9 +40,13 @@ void op::Reverse::validate_and_infer_types()
// Make sure all reversed axis indices are valid.
for (size_t axis : m_reversed_axes)
{
NODE_VALIDATION_ASSERT(this, axis < size_t(input_rank))
<< "Reverse axis (" << axis << ") is out of bounds (argument shape: " << input_shape
<< ").";
NODE_VALIDATION_CHECK(this,
axis < size_t(input_rank),
"Reverse axis (",
axis,
") is out of bounds (argument shape: ",
input_shape,
").");
}
}
@@ -41,20 +41,31 @@ void op::ReverseSequence::validate_and_infer_types()
auto input_shape = get_input_partial_shape(0);
auto input_rank = input_shape.rank();
NODE_VALIDATION_ASSERT(this, input_rank.is_dynamic() || m_batch_axis < size_t(input_rank))
<< "Batch axis index (" << m_batch_axis
<< ") is out of bounds (argument shape: " << input_shape << ").";
NODE_VALIDATION_CHECK(this,
input_rank.is_dynamic() || m_batch_axis < size_t(input_rank),
"Batch axis index (",
m_batch_axis,
") is out of bounds (argument shape: ",
input_shape,
").");
NODE_VALIDATION_ASSERT(this, input_rank.is_dynamic() || m_seq_axis < size_t(input_rank))
<< "Sequence axis index (" << m_seq_axis
<< ") is out of bounds (argument shape: " << input_shape << ").";
NODE_VALIDATION_CHECK(this,
input_rank.is_dynamic() || m_seq_axis < size_t(input_rank),
"Sequence axis index (",
m_seq_axis,
") is out of bounds (argument shape: ",
input_shape,
").");
auto indices_shape = get_input_partial_shape(1);
auto indices_rank = indices_shape.rank();
NODE_VALIDATION_ASSERT(this, indices_rank.is_dynamic() || size_t(indices_rank) == 1)
<< "Sequence indices must be a 1-dimensional tensor (sequence indices shape: "
<< get_input_partial_shape(1) << ").";
NODE_VALIDATION_CHECK(
this,
indices_rank.is_dynamic() || size_t(indices_rank) == 1,
"Sequence indices must be a 1-dimensional tensor (sequence indices shape: ",
get_input_partial_shape(1),
").");
PartialShape output_shape{input_shape};
@@ -62,12 +73,19 @@ void op::ReverseSequence::validate_and_infer_types()
{
Dimension merged_sequence_length;
NODE_VALIDATION_ASSERT(
NODE_VALIDATION_CHECK(
this,
Dimension::merge(merged_sequence_length, input_shape[m_batch_axis], indices_shape[0]))
<< "Sequence length (" << indices_shape[0] << ") is not equal to batch axis "
<< "dimension (" << input_shape[m_batch_axis] << ") (argument shape: " << input_shape
<< ", sequence indices shape: " << indices_shape << ").";
Dimension::merge(merged_sequence_length, input_shape[m_batch_axis], indices_shape[0]),
"Sequence length (",
indices_shape[0],
") is not equal to batch axis ",
"dimension (",
input_shape[m_batch_axis],
") (argument shape: ",
input_shape,
", sequence indices shape: ",
indices_shape,
").");
output_shape[m_batch_axis] = merged_sequence_length;
}
@@ -36,24 +36,28 @@ op::Select::Select(const shared_ptr<Node>& arg0,
void op::Select::validate_and_infer_types()
{
NODE_VALIDATION_ASSERT(this,
get_input_element_type(0).is_dynamic() ||
get_input_element_type(0) == element::boolean)
<< "Argument 0 does not have boolean element type (element type: "
<< get_input_element_type(0) << ").";
NODE_VALIDATION_CHECK(this,
get_input_element_type(0).is_dynamic() ||
get_input_element_type(0) == element::boolean,
"Argument 0 does not have boolean element type (element type: ",
get_input_element_type(0),
").");
PartialShape result_shape = get_input_partial_shape(0);
NODE_VALIDATION_ASSERT(this, PartialShape::merge_into(result_shape, get_input_partial_shape(1)))
<< "Argument shapes are inconsistent.";
NODE_VALIDATION_ASSERT(this, PartialShape::merge_into(result_shape, get_input_partial_shape(2)))
<< "Argument shapes are inconsistent.";
NODE_VALIDATION_CHECK(this,
PartialShape::merge_into(result_shape, get_input_partial_shape(1)),
"Argument shapes are inconsistent.");
NODE_VALIDATION_CHECK(this,
PartialShape::merge_into(result_shape, get_input_partial_shape(2)),
"Argument shapes are inconsistent.");
element::Type result_et;
NODE_VALIDATION_ASSERT(
this, element::Type::merge(result_et, get_input_element_type(1), get_input_element_type(2)))
<< "Argument 1 and 2 element types are inconsistent.";
NODE_VALIDATION_CHECK(
this,
element::Type::merge(result_et, get_input_element_type(1), get_input_element_type(2)),
"Argument 1 and 2 element types are inconsistent.");
set_output_type(0, result_et, result_shape);
}
@@ -51,40 +51,68 @@ void op::Slice::validate_and_infer_types()
m_strides = Strides(m_lower_bounds.size(), 1);
}
NODE_VALIDATION_ASSERT(this,
m_lower_bounds.size() == m_upper_bounds.size() &&
m_lower_bounds.size() == m_strides.size())
<< "Ranks of lower bounds (" << m_lower_bounds << "), upper bounds (" << m_upper_bounds
<< ") and strides (" << m_strides << ") do not match.";
NODE_VALIDATION_CHECK(this,
m_lower_bounds.size() == m_upper_bounds.size() &&
m_lower_bounds.size() == m_strides.size(),
"Ranks of lower bounds (",
m_lower_bounds,
"), upper bounds (",
m_upper_bounds,
") and strides (",
m_strides,
") do not match.");
size_t output_rank = m_upper_bounds.size();
for (size_t i = 0; i < output_rank; i++)
{
NODE_VALIDATION_ASSERT(this, m_lower_bounds[i] <= m_upper_bounds[i])
<< "Lower bound for slice is greater than upper bound at axis " << i
<< " (lower bounds: " << m_lower_bounds << ", upper bounds: " << m_upper_bounds << ").";
NODE_VALIDATION_ASSERT(this, m_strides[i] != 0) << "Stride for slice is zero at axis " << i
<< " (strides: " << m_strides << ").";
NODE_VALIDATION_CHECK(this,
m_lower_bounds[i] <= m_upper_bounds[i],
"Lower bound for slice is greater than upper bound at axis ",
i,
" (lower bounds: ",
m_lower_bounds,
", upper bounds: ",
m_upper_bounds,
").");
NODE_VALIDATION_CHECK(this,
m_strides[i] != 0,
"Stride for slice is zero at axis ",
i,
" (strides: ",
m_strides,
").");
}
const PartialShape& input_shape = get_input_partial_shape(0);
Dimension input_rank = input_shape.rank();
NODE_VALIDATION_ASSERT(this, input_rank.is_dynamic() || size_t(input_rank) == output_rank)
<< "Input rank does not match the rank of the lower bounds (" << m_lower_bounds
<< "), upper bounds (" << m_upper_bounds << "), and strides (" << m_strides << ").";
NODE_VALIDATION_CHECK(this,
input_rank.is_dynamic() || size_t(input_rank) == output_rank,
"Input rank does not match the rank of the lower bounds (",
m_lower_bounds,
"), upper bounds (",
m_upper_bounds,
"), and strides (",
m_strides,
").");
std::vector<Dimension> result_dims(output_rank);
for (size_t i = 0; i < output_rank; i++)
{
NODE_VALIDATION_ASSERT(this,
input_rank.is_dynamic() || input_shape[i].is_dynamic() ||
m_upper_bounds[i] <= size_t(input_shape[i]))
<< "Upper bound for slice at axis " << i << " is out of range "
<< "(upper bounds: " << m_upper_bounds << ", argument shape: " << input_shape << ").";
NODE_VALIDATION_CHECK(this,
input_rank.is_dynamic() || input_shape[i].is_dynamic() ||
m_upper_bounds[i] <= size_t(input_shape[i]),
"Upper bound for slice at axis ",
i,
" is out of range ",
"(upper bounds: ",
m_upper_bounds,
", argument shape: ",
input_shape,
").");
size_t result_axis_size = m_upper_bounds[i] - m_lower_bounds[i];
result_axis_size =
......
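The assignment to result_axis_size is cut off by the collapsed diff; the complete formula appears verbatim in the UpdateSlice hunk near the end of this diff. As a worked check of that arithmetic (a hypothetical standalone helper, not the op's code):

#include <cassert>
#include <cstddef>

// Number of indices selected from [lower, upper) with the given stride:
// the ceiling of (upper - lower) / stride.
size_t slice_axis_size(size_t lower, size_t upper, size_t stride)
{
    size_t span = upper - lower; // the checks above guarantee lower <= upper, stride != 0
    return span / stride + ((span % stride == 0) ? 0 : 1);
}

int main()
{
    assert(slice_axis_size(1, 7, 2) == 3); // indices 1, 3, 5
    assert(slice_axis_size(1, 7, 4) == 2); // indices 1, 5
    assert(slice_axis_size(3, 3, 1) == 0); // empty slice
    return 0;
}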
......@@ -37,9 +37,13 @@ op::Softmax::Softmax(const shared_ptr<Node>& arg, const AxisSet& axes)
for (auto axis : m_axes)
{
NODE_VALIDATION_ASSERT(this, axis < get_shape().size())
<< "Reduction axis (" << axis << ") is out of bounds (argument shape: " << get_shape()
<< ").";
NODE_VALIDATION_CHECK(this,
axis < get_shape().size(),
"Reduction axis (",
axis,
") is out of bounds (argument shape: ",
get_shape(),
").");
}
// empty axes == all axes
......
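The trailing comment ("empty axes == all axes") points at the constructor's default behavior. A hedged sketch of that expansion, assuming the comment means what it says:

#include <cstddef>
#include <set>

// When no reduction axes are given, normalize over every axis of the argument.
std::set<size_t> expand_axes(std::set<size_t> axes, size_t rank)
{
    if (axes.empty())
    {
        for (size_t i = 0; i < rank; ++i)
        {
            axes.insert(i);
        }
    }
    return axes;
}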
......@@ -43,26 +43,36 @@ void op::TopK::validate_and_infer_types()
Rank input_rank = input_shape.rank();
element::Type input_element_type = get_input_element_type(0);
NODE_VALIDATION_ASSERT(this, !m_index_element_type.is_dynamic())
<< "Argument element type must not be dynamic.";
NODE_VALIDATION_CHECK(
this, !m_index_element_type.is_dynamic(), "Argument element type must not be dynamic.");
NODE_VALIDATION_ASSERT(
this, m_index_element_type == element::i32 || m_index_element_type == element::i64)
<< "Argument element type must be i64 or i32 (got " << m_index_element_type << ").";
NODE_VALIDATION_CHECK(this,
m_index_element_type == element::i32 ||
m_index_element_type == element::i64,
"Argument element type must be i64 or i32 (got ",
m_index_element_type,
").");
NODE_VALIDATION_ASSERT(this, input_rank.is_dynamic() || static_cast<size_t>(input_rank) > 0)
<< "Argument rank must be greater than 0.";
NODE_VALIDATION_CHECK(this,
input_rank.is_dynamic() || static_cast<size_t>(input_rank) > 0,
"Argument rank must be greater than 0.");
NODE_VALIDATION_ASSERT(
this, input_rank.is_dynamic() || m_top_k_axis < static_cast<size_t>(input_rank))
<< "TopK axis (" << m_top_k_axis << ") is out of bounds.";
NODE_VALIDATION_CHECK(this,
input_rank.is_dynamic() || m_top_k_axis < static_cast<size_t>(input_rank),
"TopK axis (",
m_top_k_axis,
") is out of bounds.");
NODE_VALIDATION_ASSERT(this,
input_rank.is_dynamic() || input_shape[m_top_k_axis].is_dynamic() ||
m_k <= static_cast<size_t>(input_shape[m_top_k_axis]))
<< "K (" << m_k << ") exceeds the dimension ("
<< (input_rank.is_static() ? input_shape[m_top_k_axis] : 0) << ") of the TopK axis (axis "
<< m_top_k_axis << ").";
NODE_VALIDATION_CHECK(this,
input_rank.is_dynamic() || input_shape[m_top_k_axis].is_dynamic() ||
m_k <= static_cast<size_t>(input_shape[m_top_k_axis]),
"K (",
m_k,
") exceeds the dimension (",
(input_rank.is_static() ? input_shape[m_top_k_axis] : 0),
") of the TopK axis (axis ",
m_top_k_axis,
").");
PartialShape output_shape{input_shape};
......
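For context on what the K-versus-dimension check guards: selecting the k largest values along an axis is only well defined when k does not exceed that axis's extent. A standard-library illustration (not nGraph code):

#include <algorithm>
#include <cstddef>
#include <functional>
#include <vector>

int main()
{
    std::vector<float> v{3.0f, 1.0f, 4.0f, 1.0f, 5.0f};
    size_t k = 3; // must satisfy k <= v.size(), mirroring the check above
    // After partial_sort, v[0..k) holds the k largest values in descending order.
    std::partial_sort(v.begin(), v.begin() + k, v.end(), std::greater<float>());
    return 0;
}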
......@@ -40,10 +40,16 @@ void op::util::ArithmeticReduction::validate_and_infer_types()
for (auto axis : m_reduction_axes)
{
NODE_VALIDATION_ASSERT(this, axis < size_t(input_rank))
<< "Reduction axis (" << axis << ") is out of bounds "
<< "(argument shape: " << input_shape << ", reduction axes: " << m_reduction_axes
<< ")";
NODE_VALIDATION_CHECK(this,
axis < size_t(input_rank),
"Reduction axis (",
axis,
") is out of bounds ",
"(argument shape: ",
input_shape,
", reduction axes: ",
m_reduction_axes,
")");
}
for (size_t i = 0; i < size_t(input_rank); i++)
......
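The loop that follows the axis checks is truncated here; its presumed job, judging from how these reductions shape their outputs, is to keep exactly the dimensions whose axes are not reduced. A sketch under that assumption:

#include <cstddef>
#include <set>
#include <vector>

// Output shape of a reduction: drop every reduced axis, keep the rest.
std::vector<size_t> reduced_shape(const std::vector<size_t>& in_shape,
                                  const std::set<size_t>& reduction_axes)
{
    std::vector<size_t> out;
    for (size_t i = 0; i < in_shape.size(); ++i)
    {
        if (reduction_axes.count(i) == 0)
        {
            out.push_back(in_shape[i]);
        }
    }
    return out; // e.g. {2, 3, 4} reduced over {1} yields {2, 4}
}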
......@@ -37,13 +37,18 @@ void op::util::IndexReduction::validate_and_infer_types()
const PartialShape& arg_shape = get_input_partial_shape(0);
Rank rank = arg_shape.rank();
NODE_VALIDATION_ASSERT(this, rank.is_dynamic() || size_t(rank) >= 1)
<< "Argument rank is zero.";
NODE_VALIDATION_ASSERT(this, rank.is_dynamic() || m_axis < size_t(rank))
<< "Reduction axis (" << m_axis << ") is not less than argument rank (" << rank << ").";
NODE_VALIDATION_ASSERT(
this, m_index_element_type == element::i32 || m_index_element_type == element::i64)
<< "Index element is neither i64 or i32.";
NODE_VALIDATION_CHECK(this, rank.is_dynamic() || size_t(rank) >= 1, "Argument rank is zero.");
NODE_VALIDATION_CHECK(this,
rank.is_dynamic() || m_axis < size_t(rank),
"Reduction axis (",
m_axis,
") is not less than argument rank (",
rank,
").");
NODE_VALIDATION_CHECK(this,
m_index_element_type == element::i32 ||
m_index_element_type == element::i64,
"Index element is neither i64 or i32.");
PartialShape output_shape{PartialShape::dynamic()};
......
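The i32/i64 restriction is not arbitrary: an i32 index cannot address positions past 2^31 - 1, so long axes need i64. A hypothetical helper a caller might use to choose the narrower type safely:

#include <cstdint>
#include <limits>

enum class IndexType { i32, i64 };

IndexType pick_index_type(uint64_t axis_extent)
{
    // i32 suffices only while every index fits in a signed 32-bit value.
    return axis_extent <= static_cast<uint64_t>(std::numeric_limits<int32_t>::max())
               ? IndexType::i32
               : IndexType::i64;
}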
......@@ -40,10 +40,16 @@ void op::util::LogicalReduction::validate_and_infer_types()
for (auto axis : m_reduction_axes)
{
NODE_VALIDATION_ASSERT(this, axis < size_t(input_rank))
<< "Reduction axis (" << axis << ") is out of bounds "
<< "(argument shape: " << input_shape << ", reduction axes: " << m_reduction_axes
<< ")";
NODE_VALIDATION_CHECK(this,
axis < size_t(input_rank),
"Reduction axis (",
axis,
") is out of bounds ",
"(argument shape: ",
input_shape,
", reduction axes: ",
m_reduction_axes,
")");
}
for (size_t i = 0; i < size_t(input_rank); i++)
......@@ -57,8 +63,9 @@ void op::util::LogicalReduction::validate_and_infer_types()
result_shape = PartialShape(dims);
}
NODE_VALIDATION_ASSERT(this, get_input_element_type(0).compatible(element::boolean))
<< "Input element type must be boolean.";
NODE_VALIDATION_CHECK(this,
get_input_element_type(0).compatible(element::boolean),
"Input element type must be boolean.");
set_output_type(0, element::boolean, result_shape);
}
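The boolean-input requirement matches what logical reductions compute: a conjunction or disjunction over the reduced axes. A plain standard-library analogue of those two reductions:

#include <algorithm>
#include <vector>

int main()
{
    std::vector<bool> row{true, true, false};
    bool all = std::all_of(row.begin(), row.end(), [](bool b) { return b; }); // false
    bool any = std::any_of(row.begin(), row.end(), [](bool b) { return b; }); // true
    (void)all;
    (void)any;
    return 0;
}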
......@@ -29,9 +29,14 @@ void op::util::validate_conv_shapes(const Node* node,
const Shape& data_shape,
const Shape& filters_shape)
{
NODE_VALIDATION_ASSERT(node, data_shape[1] == filters_shape[1])
<< "Number of channels for data and filters do not match (data num channels: "
<< data_shape[1] << ", filters num channels: " << filters_shape[1] << ").";
NODE_VALIDATION_CHECK(
node,
data_shape[1] == filters_shape[1],
"Number of channels for data and filters do not match (data num channels: ",
data_shape[1],
", filters num channels: ",
filters_shape[1],
").");
}
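The check compares index 1 of both shapes because, under the layout this function implies, data is (N, C_in, spatial...) and filters are (C_out, C_in, spatial...), so axis 1 of each carries the input-channel count. A minimal mirror of the rule (the concrete shapes below are made up for illustration):

#include <cassert>
#include <cstddef>
#include <vector>

int main()
{
    std::vector<size_t> data{32, 3, 224, 224}; // (N, C_in, H, W)
    std::vector<size_t> filters{64, 3, 7, 7};  // (C_out, C_in, kH, kW)
    assert(data[1] == filters[1]);             // mirrors validate_conv_shapes
    return 0;
}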
op::ConvolutionAdd::ConvolutionAdd(const std::shared_ptr<op::Convolution>& conv,
......@@ -79,9 +84,14 @@ op::ConvolutionAdd::ConvolutionAdd(const std::shared_ptr<Node>& data_batch,
//
// Make sure data batch and filter element types match.
//
NODE_VALIDATION_ASSERT(this, data_batch_et == filters_et)
<< "Element types for data_batch and filters do not match (data batch element type: "
<< data_batch_et << ", filters element type: " << filters_et << ").";
NODE_VALIDATION_CHECK(
this,
data_batch_et == filters_et,
"Element types for data_batch and filters do not match (data batch element type: ",
data_batch_et,
", filters element type: ",
filters_et,
").");
util::validate_conv_shapes(this, data_batch_shape, filters_shape);
set_output_type(0,
......@@ -105,8 +115,11 @@ op::ConvolutionAdd::ConvolutionAdd(const std::shared_ptr<Node>& data_batch,
std::shared_ptr<Node> op::ConvolutionAdd::copy_with_new_args(const NodeVector& new_args) const
{
NODE_VALIDATION_ASSERT(this, new_args.size() == 3)
<< "New arg size is not 3 (new args size: " << new_args.size() << ").";
NODE_VALIDATION_CHECK(this,
new_args.size() == 3,
"New arg size is not 3 (new args size: ",
new_args.size(),
").");
return std::shared_ptr<Node>(new ConvolutionAdd(new_args.at(0),
new_args.at(1),
......
......@@ -57,51 +57,85 @@ void op::UpdateSlice::validate_and_infer_types()
const PartialShape& arg1_shape = get_input_partial_shape(1);
Dimension merged_args_rank;
NODE_VALIDATION_ASSERT(this,
Dimension::merge(merged_args_rank, arg0_shape.rank(), arg1_shape.rank()))
<< "Argument ranks do not match (arg0 shape: " << arg0_shape
<< ", arg1 shape: " << arg1_shape << ").";
NODE_VALIDATION_CHECK(this,
Dimension::merge(merged_args_rank, arg0_shape.rank(), arg1_shape.rank()),
"Argument ranks do not match (arg0 shape: ",
arg0_shape,
", arg1 shape: ",
arg1_shape,
").");
element::Type arg0_et = get_input_element_type(0);
element::Type arg1_et = get_input_element_type(1);
element::Type merged_args_et;
NODE_VALIDATION_ASSERT(this, element::Type::merge(merged_args_et, arg0_et, arg1_et))
<< "Argument element types do not match (arg0 element type: " << arg0_et
<< ", arg1 element type: " << arg1_et << ").";
NODE_VALIDATION_ASSERT(this,
m_lower_bounds.size() == m_upper_bounds.size() &&
m_lower_bounds.size() == m_strides.size())
<< "Ranks of lower bounds (" << m_lower_bounds << "), upper bounds (" << m_upper_bounds
<< ") and strides (" << m_strides << ") do not match.";
NODE_VALIDATION_CHECK(this,
element::Type::merge(merged_args_et, arg0_et, arg1_et),
"Argument element types do not match (arg0 element type: ",
arg0_et,
", arg1 element type: ",
arg1_et,
").");
NODE_VALIDATION_CHECK(this,
m_lower_bounds.size() == m_upper_bounds.size() &&
m_lower_bounds.size() == m_strides.size(),
"Ranks of lower bounds (",
m_lower_bounds,
"), upper bounds (",
m_upper_bounds,
") and strides (",
m_strides,
") do not match.");
size_t output_rank = m_upper_bounds.size();
for (size_t i = 0; i < output_rank; i++)
{
NODE_VALIDATION_ASSERT(this, m_lower_bounds[i] <= m_upper_bounds[i])
<< "Lower bound for slice is greater than upper bound at axis " << i
<< " (lower bounds: " << m_lower_bounds << ", upper bounds: " << m_upper_bounds << ").";
NODE_VALIDATION_ASSERT(this, m_strides[i] != 0) << "Stride for slice is zero at axis " << i
<< " (strides: " << m_strides << ").";
NODE_VALIDATION_CHECK(this,
m_lower_bounds[i] <= m_upper_bounds[i],
"Lower bound for slice is greater than upper bound at axis ",
i,
" (lower bounds: ",
m_lower_bounds,
", upper bounds: ",
m_upper_bounds,
").");
NODE_VALIDATION_CHECK(this,
m_strides[i] != 0,
"Stride for slice is zero at axis ",
i,
" (strides: ",
m_strides,
").");
}
NODE_VALIDATION_ASSERT(this,
merged_args_rank.is_dynamic() || size_t(merged_args_rank) == output_rank)
<< "Argument ranks do not match the rank of the lower bounds (" << m_lower_bounds
<< "), upper bounds (" << m_upper_bounds << "), and strides (" << m_strides << ").";
NODE_VALIDATION_CHECK(this,
merged_args_rank.is_dynamic() || size_t(merged_args_rank) == output_rank,
"Argument ranks do not match the rank of the lower bounds (",
m_lower_bounds,
"), upper bounds (",
m_upper_bounds,
"), and strides (",
m_strides,
").");
std::vector<Dimension> sliced_dims(output_rank);
for (size_t i = 0; i < output_rank; i++)
{
NODE_VALIDATION_ASSERT(this,
arg0_shape.rank().is_dynamic() || arg0_shape[i].is_dynamic() ||
m_upper_bounds[i] <= size_t(arg0_shape[i]))
<< "Upper bound for slice at axis " << i << " is out of range "
<< "(upper bounds: " << m_upper_bounds << ", argument shape: " << arg0_shape << ").";
NODE_VALIDATION_CHECK(this,
arg0_shape.rank().is_dynamic() || arg0_shape[i].is_dynamic() ||
m_upper_bounds[i] <= size_t(arg0_shape[i]),
"Upper bound for slice at axis ",
i,
" is out of range ",
"(upper bounds: ",
m_upper_bounds,
", argument shape: ",
arg0_shape,
").");
size_t sliced_dim = m_upper_bounds[i] - m_lower_bounds[i];
sliced_dim = sliced_dim / m_strides[i] + ((sliced_dim % m_strides[i] == 0) ? 0 : 1);
......@@ -110,9 +144,14 @@ void op::UpdateSlice::validate_and_infer_types()
PartialShape slice_shape{sliced_dims};
NODE_VALIDATION_ASSERT(this, arg1_shape.compatible(slice_shape))
<< "Shape of replacement tensor (" << arg1_shape << ") does not match the slice shape "
<< "(" << slice_shape << ").";
NODE_VALIDATION_CHECK(this,
arg1_shape.compatible(slice_shape),
"Shape of replacement tensor (",
arg1_shape,
") does not match the slice shape ",
"(",
slice_shape,
").");
// Slight corner case here: if arg0 was rank-unknown, we can go ahead and set the output rank
// because the attribs will have given us enough info.
......
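UpdateSlice leans on two merge operations, Dimension::merge for ranks and element::Type::merge for element types, and every check above reads their boolean result. A hedged sketch of the dimension case (illustrative semantics only, not nGraph's implementation):

#include <cstdint>

struct Dim
{
    bool is_static;
    int64_t value; // meaningful only when is_static is true
};

// Merge succeeds iff at least one side is dynamic or both statics agree;
// on success, dst keeps whatever static information is available.
bool merge_dims(Dim& dst, const Dim& a, const Dim& b)
{
    if (!a.is_static)
    {
        dst = b;
        return true;
    }
    if (!b.is_static)
    {
        dst = a;
        return true;
    }
    if (a.value != b.value)
    {
        return false; // incompatible static extents, so validation fails
    }
    dst = a;
    return true;
}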