Commit 6ffbd254 authored by Adam Procter's avatar Adam Procter Committed by Robert Kimball

New macros for faster node validation checks (#2537)

* Skeleton for faster validation asserts

* Switch to __VA_ARGS__ for compatibility, remove -Wno-variadic-macros

* Add benchmarks for constructing Add and Convolution

* Quick hack to avoid shadowing inside the CHECK macro

* Quick hack to avoid inadvertent capture inside the macro

* Update convolution, and change a bunch of tests to anticipate the new error class
parent eb34c884
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <exception>
#include <sstream>
#include <vector>
#include "ngraph/except.hpp"
namespace ngraph
{
static inline std::ostream& write_all_to_stream(std::ostream& str) { return str; }
template <typename T, typename... TS>
static inline std::ostream& write_all_to_stream(std::ostream& str, const T& arg, TS... args)
{
return write_all_to_stream(str << arg, args...);
}
struct CheckLocInfo
{
const char* file;
int line;
const char* check_string;
};
/// Base class for check failure exceptions.
class CheckFailure : public ngraph_error
{
public:
CheckFailure(const CheckLocInfo& check_loc_info,
const std::string& context_info,
const std::string& explanation)
: ngraph_error(make_what(check_loc_info, context_info, explanation))
{
}
private:
static std::string make_what(const CheckLocInfo& check_loc_info,
const std::string& context_info,
const std::string& explanation)
{
std::stringstream ss;
ss << "Check '" << check_loc_info.check_string << "' failed at " << check_loc_info.file
<< ":" << check_loc_info.line << ":" << std::endl;
ss << context_info << ":" << std::endl;
ss << explanation << std::endl;
return ss.str();
}
};
}
// TODO(amprocte): refactor so we don't have to introduce a locally-scoped variable and risk
// shadowing here.
#define CHECK(exc_class, ctx, check, ...) \
do \
{ \
if (!(check)) \
{ \
::std::stringstream ss___; \
::ngraph::write_all_to_stream(ss___, __VA_ARGS__); \
throw exc_class( \
(::ngraph::CheckLocInfo{__FILE__, __LINE__, #check}), (ctx), ss___.str()); \
} \
} while (0)
...@@ -422,15 +422,20 @@ NodeVector Node::get_users(bool check_is_used) const ...@@ -422,15 +422,20 @@ NodeVector Node::get_users(bool check_is_used) const
std::string ngraph::node_validation_assertion_string(const Node* node) std::string ngraph::node_validation_assertion_string(const Node* node)
{ {
std::stringstream ss; std::stringstream ss;
ss << "While validating node '" << *node << "' of type '" << node->description() << "'"; ss << "While validating node '" << *node << "'";
return ss.str(); return ss.str();
} }
void ngraph::check_new_args_count(const Node* node, const NodeVector& new_args) void ngraph::check_new_args_count(const Node* node, const NodeVector& new_args)
{ {
NODE_VALIDATION_ASSERT(node, new_args.size() == node->get_arguments().size()) NODE_VALIDATION_CHECK(node,
<< "copy_with_new_args() expected " << node->get_arguments().size() << " argument" new_args.size() == node->get_arguments().size(),
<< (node->get_arguments().size() == 1 ? "" : "s") << " but got " << new_args.size(); "copy_with_new_args() expected ",
node->get_arguments().size(),
" argument",
(node->get_arguments().size() == 1 ? "" : "s"),
" but got ",
new_args.size());
} }
const std::shared_ptr<Node>& ngraph::check_single_output_arg(const std::shared_ptr<Node>& node, const std::shared_ptr<Node>& ngraph::check_single_output_arg(const std::shared_ptr<Node>& node,
...@@ -459,13 +464,14 @@ std::tuple<element::Type, PartialShape> Node::validate_and_infer_elementwise_arg ...@@ -459,13 +464,14 @@ std::tuple<element::Type, PartialShape> Node::validate_and_infer_elementwise_arg
{ {
for (size_t i = 1; i < get_input_size(); ++i) for (size_t i = 1; i < get_input_size(); ++i)
{ {
NODE_VALIDATION_ASSERT( NODE_VALIDATION_CHECK(
this, element::Type::merge(element_type, element_type, get_input_element_type(i))) this,
<< "Argument element types are inconsistent."; element::Type::merge(element_type, element_type, get_input_element_type(i)),
"Argument element types are inconsistent.");
NODE_VALIDATION_ASSERT(this,
PartialShape::merge_into(pshape, get_input_partial_shape(i))) NODE_VALIDATION_CHECK(this,
<< "Argument shapes are inconsistent."; PartialShape::merge_into(pshape, get_input_partial_shape(i)),
"Argument shapes are inconsistent.");
} }
} }
...@@ -478,8 +484,11 @@ void Node::validate_and_infer_elementwise_arithmetic() ...@@ -478,8 +484,11 @@ void Node::validate_and_infer_elementwise_arithmetic()
element::Type& args_et = std::get<0>(args_et_pshape); element::Type& args_et = std::get<0>(args_et_pshape);
PartialShape& args_pshape = std::get<1>(args_et_pshape); PartialShape& args_pshape = std::get<1>(args_et_pshape);
NODE_VALIDATION_ASSERT(this, args_et.is_dynamic() || args_et != element::boolean) NODE_VALIDATION_CHECK(this,
<< "Arguments cannot have boolean element type (argument element type: " << args_et << ")."; args_et.is_dynamic() || args_et != element::boolean,
"Arguments cannot have boolean element type (argument element type: ",
args_et,
").");
set_output_type(0, args_et, args_pshape); set_output_type(0, args_et, args_pshape);
} }
...@@ -490,9 +499,12 @@ void Node::validate_and_infer_elementwise_logical() ...@@ -490,9 +499,12 @@ void Node::validate_and_infer_elementwise_logical()
element::Type& args_et = std::get<0>(args_et_pshape); element::Type& args_et = std::get<0>(args_et_pshape);
PartialShape& args_pshape = std::get<1>(args_et_pshape); PartialShape& args_pshape = std::get<1>(args_et_pshape);
NODE_VALIDATION_ASSERT(this, args_et.is_dynamic() || args_et == element::boolean) NODE_VALIDATION_CHECK(
<< "Operands for logical operators must have boolean element type but have element type " this,
<< args_et << "."; args_et.is_dynamic() || args_et == element::boolean,
"Operands for logical operators must have boolean element type but have element type ",
args_et,
".");
set_output_type(0, element::boolean, args_pshape); set_output_type(0, element::boolean, args_pshape);
} }
...@@ -30,6 +30,7 @@ ...@@ -30,6 +30,7 @@
#include "ngraph/assertion.hpp" #include "ngraph/assertion.hpp"
#include "ngraph/autodiff/adjoints.hpp" #include "ngraph/autodiff/adjoints.hpp"
#include "ngraph/check.hpp"
#include "ngraph/descriptor/input.hpp" #include "ngraph/descriptor/input.hpp"
#include "ngraph/descriptor/output.hpp" #include "ngraph/descriptor/output.hpp"
#include "ngraph/descriptor/tensor.hpp" #include "ngraph/descriptor/tensor.hpp"
...@@ -287,6 +288,17 @@ namespace ngraph ...@@ -287,6 +288,17 @@ namespace ngraph
} }
}; };
class NodeValidationFailure : public CheckFailure
{
public:
NodeValidationFailure(const CheckLocInfo& check_loc_info,
const Node* node,
const std::string& explanation)
: CheckFailure(check_loc_info, node_validation_assertion_string(node), explanation)
{
}
};
class NodeDescription class NodeDescription
{ {
public: public:
...@@ -315,3 +327,6 @@ namespace ngraph ...@@ -315,3 +327,6 @@ namespace ngraph
#define NODE_VALIDATION_FAIL(node) \ #define NODE_VALIDATION_FAIL(node) \
NGRAPH_FAIL_STREAM_WITH_LOC(::ngraph::NodeValidationError, \ NGRAPH_FAIL_STREAM_WITH_LOC(::ngraph::NodeValidationError, \
::ngraph::node_validation_assertion_string(node)) ::ngraph::node_validation_assertion_string(node))
#define NODE_VALIDATION_CHECK(node, cond, ...) \
CHECK(::NodeValidationFailure, (node), (cond), __VA_ARGS__)
...@@ -297,9 +297,14 @@ void op::ConvolutionBackpropData::validate_and_infer_types() ...@@ -297,9 +297,14 @@ void op::ConvolutionBackpropData::validate_and_infer_types()
m_window_movement_strides_forward, m_window_movement_strides_forward,
m_window_dilation_strides_forward); m_window_dilation_strides_forward);
NODE_VALIDATION_ASSERT(this, forward_result_shape.compatible(delta_shape)) NODE_VALIDATION_CHECK(this,
<< "Inferred forward output shape (" << forward_result_shape << ") does not match shape of " forward_result_shape.compatible(delta_shape),
<< "delta (" << delta_shape << ")."; "Inferred forward output shape (",
forward_result_shape,
") does not match shape of ",
"delta (",
delta_shape,
").");
set_output_type(0, forward_result_et, m_data_batch_shape); set_output_type(0, forward_result_et, m_data_batch_shape);
...@@ -494,9 +499,14 @@ void op::ConvolutionBackpropFilters::validate_and_infer_types() ...@@ -494,9 +499,14 @@ void op::ConvolutionBackpropFilters::validate_and_infer_types()
m_window_movement_strides_forward, m_window_movement_strides_forward,
m_window_dilation_strides_forward); m_window_dilation_strides_forward);
NODE_VALIDATION_ASSERT(this, forward_result_shape.compatible(delta_shape)) NODE_VALIDATION_CHECK(this,
<< "Inferred forward output shape (" << forward_result_shape << ") does not match shape of " forward_result_shape.compatible(delta_shape),
<< "delta (" << delta_shape << ")."; "Inferred forward output shape (",
forward_result_shape,
") does not match shape of ",
"delta (",
delta_shape,
").");
set_output_type(0, forward_result_et, m_filters_shape); set_output_type(0, forward_result_et, m_filters_shape);
......
...@@ -39,18 +39,29 @@ PartialShape ngraph::infer_windowed_reduction_output_shape(const Node* node, ...@@ -39,18 +39,29 @@ PartialShape ngraph::infer_windowed_reduction_output_shape(const Node* node,
{ {
PartialShape data_shape_merged{PartialShape::dynamic()}; PartialShape data_shape_merged{PartialShape::dynamic()};
NODE_VALIDATION_ASSERT(node, NODE_VALIDATION_CHECK(node,
data_shape_merged.merge_rank(data_shape.rank()) && data_shape_merged.merge_rank(data_shape.rank()) &&
data_shape_merged.merge_rank(data_dilation.size()) && data_shape_merged.merge_rank(data_dilation.size()) &&
data_shape_merged.merge_rank(data_padding_below.size()) && data_shape_merged.merge_rank(data_padding_below.size()) &&
data_shape_merged.merge_rank(data_padding_above.size()) && data_shape_merged.merge_rank(data_padding_above.size()) &&
data_shape_merged.merge_rank(window_shape.rank()) && data_shape_merged.merge_rank(window_shape.rank()) &&
data_shape_merged.merge_rank(window_strides.size()) && data_shape_merged.merge_rank(window_strides.size()) &&
data_shape_merged.merge_rank(window_dilation.size())) data_shape_merged.merge_rank(window_dilation.size()),
<< "Ranks for data shape (" << data_shape << "), data dilation (" << data_dilation "Ranks for data shape (",
<< "), padding below (" << data_padding_below << "), padding above (" << data_padding_above data_shape,
<< "), window shape (" << window_shape << "), window strides (" << window_strides "), data dilation (",
<< "), and window dilation (" << window_dilation << ") do not match."; data_dilation,
"), padding below (",
data_padding_below,
"), padding above (",
data_padding_above,
"), window shape (",
window_shape,
"), window strides (",
window_strides,
"), and window dilation (",
window_dilation,
") do not match.");
PartialShape output_shape = PartialShape::dynamic(data_shape_merged.rank()); PartialShape output_shape = PartialShape::dynamic(data_shape_merged.rank());
...@@ -58,15 +69,27 @@ PartialShape ngraph::infer_windowed_reduction_output_shape(const Node* node, ...@@ -58,15 +69,27 @@ PartialShape ngraph::infer_windowed_reduction_output_shape(const Node* node,
{ {
for (size_t i = 0; i < static_cast<size_t>(output_shape.rank()); i++) for (size_t i = 0; i < static_cast<size_t>(output_shape.rank()); i++)
{ {
NODE_VALIDATION_ASSERT(node, data_dilation[i] > 0) NODE_VALIDATION_CHECK(node,
<< "Data dilation (" << data_dilation << ") has zero dimension at axis " << i data_dilation[i] > 0,
<< "."; "Data dilation (",
NODE_VALIDATION_ASSERT(node, window_strides[i] > 0) data_dilation,
<< "Window strides (" << window_strides << ") has zero dimension at axis " << i ") has zero dimension at axis ",
<< "."; i,
NODE_VALIDATION_ASSERT(node, window_dilation[i] > 0) ".");
<< "Window dilation (" << window_dilation << ") has zero dimension at axis " << i NODE_VALIDATION_CHECK(node,
<< "."; window_strides[i] > 0,
"Window strides (",
window_strides,
") has zero dimension at axis ",
i,
".");
NODE_VALIDATION_CHECK(node,
window_dilation[i] > 0,
"Window dilation (",
window_dilation,
") has zero dimension at axis ",
i,
".");
bool data_dim_static = data_shape.rank().is_static() && data_shape[i].is_static(); bool data_dim_static = data_shape.rank().is_static() && data_shape[i].is_static();
bool window_dim_static = window_shape.rank().is_static() && window_shape[i].is_static(); bool window_dim_static = window_shape.rank().is_static() && window_shape[i].is_static();
...@@ -77,9 +100,14 @@ PartialShape ngraph::infer_windowed_reduction_output_shape(const Node* node, ...@@ -77,9 +100,14 @@ PartialShape ngraph::infer_windowed_reduction_output_shape(const Node* node,
data_padded_dilated_dim = (static_cast<ptrdiff_t>(data_dilation[i]) * data_padded_dilated_dim = (static_cast<ptrdiff_t>(data_dilation[i]) *
(static_cast<ptrdiff_t>(data_shape[i]) - 1)) + (static_cast<ptrdiff_t>(data_shape[i]) - 1)) +
1 + data_padding_below[i] + data_padding_above[i]; 1 + data_padding_below[i] + data_padding_above[i];
NODE_VALIDATION_ASSERT(node, data_padded_dilated_dim > 0) NODE_VALIDATION_CHECK(
<< "Data shape after padding and dilation has dimension less than 1 (dim: " node,
<< data_padded_dilated_dim << ") at axis " << i << "."; data_padded_dilated_dim > 0,
"Data shape after padding and dilation has dimension less than 1 (dim: ",
data_padded_dilated_dim,
") at axis ",
i,
".");
} }
ptrdiff_t window_dilated_dim = -1; ptrdiff_t window_dilated_dim = -1;
...@@ -89,28 +117,42 @@ PartialShape ngraph::infer_windowed_reduction_output_shape(const Node* node, ...@@ -89,28 +117,42 @@ PartialShape ngraph::infer_windowed_reduction_output_shape(const Node* node,
(static_cast<ptrdiff_t>(window_shape[i]) - 1) + (static_cast<ptrdiff_t>(window_shape[i]) - 1) +
1; 1;
NODE_VALIDATION_ASSERT(node, window_dilated_dim > 0) NODE_VALIDATION_CHECK(node,
<< "Window after dilation has dimension less than 1 (dim: " window_dilated_dim > 0,
<< window_dilated_dim << ") at axis " << i << "."; "Window after dilation has dimension less than 1 (dim: ",
window_dilated_dim,
NODE_VALIDATION_ASSERT(node, ") at axis ",
is_window_all_in_padding_allowed || i,
(window_dilated_dim > data_padding_below[i] && ".");
window_dilated_dim > data_padding_above[i]))
<< "Window after dilation is sometimes entirely in the padding area for axis " NODE_VALIDATION_CHECK(
<< i << " (dilated window dimension: " << window_dilated_dim node,
<< ", padding below dimension: " << data_padding_below[i] is_window_all_in_padding_allowed ||
<< ", padding above dimension: " << data_padding_above[i] (window_dilated_dim > data_padding_below[i] &&
<< ") and this is not " window_dilated_dim > data_padding_above[i]),
<< "allowed."; "Window after dilation is sometimes entirely in the padding area for axis ",
i,
" (dilated window dimension: ",
window_dilated_dim,
", padding below dimension: ",
data_padding_below[i],
", padding above dimension: ",
data_padding_above[i],
") and this is not ",
"allowed.");
} }
if (data_dim_static && window_dim_static) if (data_dim_static && window_dim_static)
{ {
NODE_VALIDATION_ASSERT(node, window_dilated_dim <= data_padded_dilated_dim) NODE_VALIDATION_CHECK(node,
<< "Window after dilation has dimension (dim: " << window_dilated_dim window_dilated_dim <= data_padded_dilated_dim,
<< ") larger than the data shape after padding (dim: " "Window after dilation has dimension (dim: ",
<< data_padded_dilated_dim << ") at axis " << i << "."; window_dilated_dim,
") larger than the data shape after padding (dim: ",
data_padded_dilated_dim,
") at axis ",
i,
".");
output_shape[i] = ceil_div(static_cast<size_t>(data_padded_dilated_dim) - output_shape[i] = ceil_div(static_cast<size_t>(data_padded_dilated_dim) -
static_cast<size_t>(window_dilated_dim) + 1, static_cast<size_t>(window_dilated_dim) + 1,
...@@ -139,39 +181,64 @@ std::tuple<element::Type, PartialShape> ...@@ -139,39 +181,64 @@ std::tuple<element::Type, PartialShape>
{ {
element::Type et_result; element::Type et_result;
NODE_VALIDATION_ASSERT(node, element::Type::merge(et_result, et_batch, et_filters)) NODE_VALIDATION_CHECK(
<< "Element types for data batch and filters do not match (data batch element type: " node,
<< et_batch << ", filters element type: " << et_filters << ")."; element::Type::merge(et_result, et_batch, et_filters),
"Element types for data batch and filters do not match (data batch element type: ",
et_batch,
", filters element type: ",
et_filters,
").");
Rank data_batch_filters_rank{Rank::dynamic()}; Rank data_batch_filters_rank{Rank::dynamic()};
NODE_VALIDATION_ASSERT( NODE_VALIDATION_CHECK(
node, Rank::merge(data_batch_filters_rank, data_batch_shape.rank(), filters_shape.rank())) node,
<< "Data batch and filters rank do not match (data batch shape: " << data_batch_shape Rank::merge(data_batch_filters_rank, data_batch_shape.rank(), filters_shape.rank()),
<< ", filters shape: " << filters_shape << ")."; "Data batch and filters rank do not match (data batch shape: ",
data_batch_shape,
NODE_VALIDATION_ASSERT(node, ", filters shape: ",
data_batch_filters_rank.is_dynamic() || filters_shape,
static_cast<size_t>(data_batch_filters_rank) >= 3) ").");
<< "Data batch and filters must have rank of at least 3 (one batch axis, "
<< "one input-channel axis, and at least one spatial dimension) " NODE_VALIDATION_CHECK(node,
<< "(data batch shape: " << data_batch_shape << ", filters shape: " << filters_shape data_batch_filters_rank.is_dynamic() ||
<< ")."; static_cast<size_t>(data_batch_filters_rank) >= 3,
"Data batch and filters must have rank of at least 3 (one batch axis, ",
"one input-channel axis, and at least one spatial dimension) ",
"(data batch shape: ",
data_batch_shape,
", filters shape: ",
filters_shape,
").");
Rank spatial_rank{Rank::dynamic()}; Rank spatial_rank{Rank::dynamic()};
NODE_VALIDATION_ASSERT(node, NODE_VALIDATION_CHECK(node,
Rank::merge(spatial_rank, spatial_rank, data_batch_filters_rank - 2) && Rank::merge(spatial_rank, spatial_rank, data_batch_filters_rank - 2) &&
Rank::merge(spatial_rank, spatial_rank, data_dilation.size()) && Rank::merge(spatial_rank, spatial_rank, data_dilation.size()) &&
Rank::merge(spatial_rank, spatial_rank, data_padding_below.size()) && Rank::merge(spatial_rank, spatial_rank, data_padding_below.size()) &&
Rank::merge(spatial_rank, spatial_rank, data_padding_above.size()) && Rank::merge(spatial_rank, spatial_rank, data_padding_above.size()) &&
Rank::merge(spatial_rank, spatial_rank, filter_strides.size()) && Rank::merge(spatial_rank, spatial_rank, filter_strides.size()) &&
Rank::merge(spatial_rank, spatial_rank, filter_dilation.size())) Rank::merge(spatial_rank, spatial_rank, filter_dilation.size()),
<< "Ranks for data item shape/filters shape (data batch has shape " << data_batch_shape "Ranks for data item shape/filters shape (data batch has shape ",
<< ", so data item rank is " << (data_batch_shape.rank() - 2) << " and filters have shape " data_batch_shape,
<< filters_shape << ", so filters spatial rank is " << (filters_shape.rank() - 2) ", so data item rank is ",
<< "), data dilation (" << data_dilation << "), padding below (" << data_padding_below (data_batch_shape.rank() - 2),
<< "), padding above (" << data_padding_above << "), filter strides (" << filter_strides " and filters have shape ",
<< "), and filter dilation (" << filter_dilation << ") do not match."; filters_shape,
", so filters spatial rank is ",
(filters_shape.rank() - 2),
"), data dilation (",
data_dilation,
"), padding below (",
data_padding_below,
"), padding above (",
data_padding_above,
"), filter strides (",
filter_strides,
"), and filter dilation (",
filter_dilation,
") do not match.");
Dimension batch_size = Dimension batch_size =
(data_batch_shape.rank().is_static() ? data_batch_shape[0] : Dimension::dynamic()); (data_batch_shape.rank().is_static() ? data_batch_shape[0] : Dimension::dynamic());
...@@ -202,25 +269,31 @@ std::tuple<element::Type, PartialShape> ...@@ -202,25 +269,31 @@ std::tuple<element::Type, PartialShape>
} }
} }
NODE_VALIDATION_ASSERT(node, batch_size.is_dynamic() || static_cast<size_t>(batch_size) > 0) NODE_VALIDATION_CHECK(node,
<< "Batch size is zero."; batch_size.is_dynamic() || static_cast<size_t>(batch_size) > 0,
"Batch size is zero.");
Dimension merged_channel_count; Dimension merged_channel_count;
NODE_VALIDATION_ASSERT( NODE_VALIDATION_CHECK(
node, node,
Dimension::merge(merged_channel_count, data_channel_count, filter_input_channel_count)) Dimension::merge(merged_channel_count, data_channel_count, filter_input_channel_count),
<< "Data batch channel count (" << data_channel_count << ") does not match filter input " "Data batch channel count (",
<< "channel count (" << filter_input_channel_count << ")."; data_channel_count,
") does not match filter input ",
NODE_VALIDATION_ASSERT( "channel count (",
node, merged_channel_count.is_dynamic() || static_cast<size_t>(merged_channel_count) > 0) filter_input_channel_count,
<< "Data batch channel count and/or filter input channel count is zero."; ").");
NODE_VALIDATION_ASSERT(node, NODE_VALIDATION_CHECK(node,
filter_output_channel_count.is_dynamic() || merged_channel_count.is_dynamic() ||
static_cast<size_t>(filter_output_channel_count) > 0) static_cast<size_t>(merged_channel_count) > 0,
<< "Filter output channel count is zero."; "Data batch channel count and/or filter input channel count is zero.");
NODE_VALIDATION_CHECK(node,
filter_output_channel_count.is_dynamic() ||
static_cast<size_t>(filter_output_channel_count) > 0,
"Filter output channel count is zero.");
PartialShape data_output_shape = infer_windowed_reduction_output_shape(node, PartialShape data_output_shape = infer_windowed_reduction_output_shape(node,
data_spatial_shape, data_spatial_shape,
...@@ -255,25 +328,36 @@ PartialShape ngraph::infer_batched_pooling_forward(const Node* node, ...@@ -255,25 +328,36 @@ PartialShape ngraph::infer_batched_pooling_forward(const Node* node,
const Strides& window_strides, const Strides& window_strides,
bool is_window_all_in_padding_allowed) bool is_window_all_in_padding_allowed)
{ {
NODE_VALIDATION_ASSERT(node, NODE_VALIDATION_CHECK(node,
data_batch_shape.rank().is_dynamic() || data_batch_shape.rank().is_dynamic() ||
static_cast<size_t>(data_batch_shape.rank()) >= 3) static_cast<size_t>(data_batch_shape.rank()) >= 3,
<< "Data batch must have rank of at least 3 (one batch axis, " "Data batch must have rank of at least 3 (one batch axis, ",
<< "one input-channel axis, and at least one spatial dimension) " "one input-channel axis, and at least one spatial dimension) ",
<< "(data batch shape: " << data_batch_shape << ")."; "(data batch shape: ",
data_batch_shape,
").");
PartialShape data_spatial_shape{PartialShape::dynamic()}; PartialShape data_spatial_shape{PartialShape::dynamic()};
NODE_VALIDATION_ASSERT(node, NODE_VALIDATION_CHECK(node,
data_spatial_shape.merge_rank(data_batch_shape.rank() - 2) && data_spatial_shape.merge_rank(data_batch_shape.rank() - 2) &&
data_spatial_shape.merge_rank(data_padding_below.size()) && data_spatial_shape.merge_rank(data_padding_below.size()) &&
data_spatial_shape.merge_rank(data_padding_above.size()) && data_spatial_shape.merge_rank(data_padding_above.size()) &&
data_spatial_shape.merge_rank(window_shape.rank()) && data_spatial_shape.merge_rank(window_shape.rank()) &&
data_spatial_shape.merge_rank(window_strides.size())) data_spatial_shape.merge_rank(window_strides.size()),
<< "Ranks for data item shape (data batch has shape " << data_batch_shape "Ranks for data item shape (data batch has shape ",
<< ", so data item rank is " << (data_batch_shape.rank() - 2) << "), padding below (" data_batch_shape,
<< data_padding_below << "), padding above (" << data_padding_above << "), window shape (" ", so data item rank is ",
<< window_shape << "), and window strides (" << window_strides << ") do not match."; (data_batch_shape.rank() - 2),
"), padding below (",
data_padding_below,
"), padding above (",
data_padding_above,
"), window shape (",
window_shape,
"), and window strides (",
window_strides,
") do not match.");
Dimension batch_size{Dimension::dynamic()}; Dimension batch_size{Dimension::dynamic()};
Dimension channel_count{Dimension::dynamic()}; Dimension channel_count{Dimension::dynamic()};
...@@ -289,12 +373,13 @@ PartialShape ngraph::infer_batched_pooling_forward(const Node* node, ...@@ -289,12 +373,13 @@ PartialShape ngraph::infer_batched_pooling_forward(const Node* node,
data_spatial_shape[i] = data_batch_shape[i + 2]; data_spatial_shape[i] = data_batch_shape[i + 2];
} }
NODE_VALIDATION_ASSERT(node, batch_size.is_dynamic() || static_cast<size_t>(batch_size) > 0) NODE_VALIDATION_CHECK(node,
<< "Batch size is zero."; batch_size.is_dynamic() || static_cast<size_t>(batch_size) > 0,
"Batch size is zero.");
NODE_VALIDATION_ASSERT(node, NODE_VALIDATION_CHECK(node,
channel_count.is_dynamic() || static_cast<size_t>(channel_count) > 0) channel_count.is_dynamic() || static_cast<size_t>(channel_count) > 0,
<< "Channel count is zero."; "Channel count is zero.");
// For pooling ops we don't need dilation, so we fill in the identity value (all 1). // For pooling ops we don't need dilation, so we fill in the identity value (all 1).
Strides data_dilation(static_cast<size_t>(data_spatial_shape.rank()), 1); Strides data_dilation(static_cast<size_t>(data_spatial_shape.rank()), 1);
...@@ -358,17 +443,19 @@ static std::tuple<element::Type, PartialShape, PartialShape> infer_batch_norm_fo ...@@ -358,17 +443,19 @@ static std::tuple<element::Type, PartialShape, PartialShape> infer_batch_norm_fo
for (auto& inp : channel_shaped_inputs) for (auto& inp : channel_shaped_inputs)
{ {
NODE_VALIDATION_ASSERT(node, element::Type::merge(et_result, et_result, inp.m_element_type)) NODE_VALIDATION_CHECK(node,
<< "Input element types do not match."; element::Type::merge(et_result, et_result, inp.m_element_type),
"Input element types do not match.");
} }
// Extract channel dimension from input shape. // Extract channel dimension from input shape.
Dimension channel_dim{Dimension::dynamic()}; Dimension channel_dim{Dimension::dynamic()};
NODE_VALIDATION_ASSERT(node, NODE_VALIDATION_CHECK(node,
input_shape.is_dynamic() || static_cast<size_t>(input_shape.rank()) >= 2) input_shape.is_dynamic() || static_cast<size_t>(input_shape.rank()) >= 2,
<< "Input argument must have rank of at least 2 (input argument shape: " << input_shape "Input argument must have rank of at least 2 (input argument shape: ",
<< ")."; input_shape,
").");
if (input_shape.rank().is_static()) if (input_shape.rank().is_static())
{ {
...@@ -380,20 +467,34 @@ static std::tuple<element::Type, PartialShape, PartialShape> infer_batch_norm_fo ...@@ -380,20 +467,34 @@ static std::tuple<element::Type, PartialShape, PartialShape> infer_batch_norm_fo
for (auto& inp : channel_shaped_inputs) for (auto& inp : channel_shaped_inputs)
{ {
NODE_VALIDATION_ASSERT(node, PartialShape::merge_into(channel_shape, inp.m_shape)) NODE_VALIDATION_CHECK(node,
<< "Shapes for " << channel_input_names << " do not match."; PartialShape::merge_into(channel_shape, inp.m_shape),
"Shapes for ",
channel_input_names,
" do not match.");
} }
NODE_VALIDATION_ASSERT(node, channel_shape.merge_rank(1)) << "Shape for " << channel_input_names NODE_VALIDATION_CHECK(node,
<< " (" << channel_shape channel_shape.merge_rank(1),
<< ") does not have rank 1."; "Shape for ",
channel_input_names,
NODE_VALIDATION_ASSERT(node, Dimension::merge(channel_dim, channel_dim, channel_shape[0])) " (",
<< "Input channel dimension (" << channel_dim << ") does not match shape for " channel_shape,
<< channel_input_names << " (" << channel_shape << ")."; ") does not have rank 1.");
NODE_VALIDATION_ASSERT(node, channel_dim.is_dynamic() || static_cast<size_t>(channel_dim) >= 1) NODE_VALIDATION_CHECK(node,
<< "Channel count must be at least 1."; Dimension::merge(channel_dim, channel_dim, channel_shape[0]),
"Input channel dimension (",
channel_dim,
") does not match shape for ",
channel_input_names,
" (",
channel_shape,
").");
NODE_VALIDATION_CHECK(node,
channel_dim.is_dynamic() || static_cast<size_t>(channel_dim) >= 1,
"Channel count must be at least 1.");
// Batch result shape is same as the input shape, except we may possibly have inferred more // Batch result shape is same as the input shape, except we may possibly have inferred more
// information from the channel count via gamma/beta/etc. // information from the channel count via gamma/beta/etc.
......
...@@ -884,7 +884,7 @@ TEST(partial_shape, infer_windowed_reduction_rank_dynamic_rank_dynamic_zero_data ...@@ -884,7 +884,7 @@ TEST(partial_shape, infer_windowed_reduction_rank_dynamic_rank_dynamic_zero_data
window_dilation, window_dilation,
is_window_all_in_padding_allowed); is_window_all_in_padding_allowed);
}, },
NodeValidationError); NodeValidationFailure);
} }
TEST(partial_shape, infer_windowed_reduction_rank_dynamic_rank_dynamic_zero_window_dilation) TEST(partial_shape, infer_windowed_reduction_rank_dynamic_rank_dynamic_zero_window_dilation)
...@@ -911,7 +911,7 @@ TEST(partial_shape, infer_windowed_reduction_rank_dynamic_rank_dynamic_zero_wind ...@@ -911,7 +911,7 @@ TEST(partial_shape, infer_windowed_reduction_rank_dynamic_rank_dynamic_zero_wind
window_dilation, window_dilation,
is_window_all_in_padding_allowed); is_window_all_in_padding_allowed);
}, },
NodeValidationError); NodeValidationFailure);
} }
TEST(partial_shape, infer_windowed_reduction_rank_dynamic_rank_dynamic_zero_window_strides) TEST(partial_shape, infer_windowed_reduction_rank_dynamic_rank_dynamic_zero_window_strides)
...@@ -938,7 +938,7 @@ TEST(partial_shape, infer_windowed_reduction_rank_dynamic_rank_dynamic_zero_wind ...@@ -938,7 +938,7 @@ TEST(partial_shape, infer_windowed_reduction_rank_dynamic_rank_dynamic_zero_wind
window_dilation, window_dilation,
is_window_all_in_padding_allowed); is_window_all_in_padding_allowed);
}, },
NodeValidationError); NodeValidationFailure);
} }
TEST(partial_shape, infer_windowed_reduction_rank_static_dynamic_rank_dynamic_ok) TEST(partial_shape, infer_windowed_reduction_rank_static_dynamic_rank_dynamic_ok)
...@@ -992,7 +992,7 @@ TEST(partial_shape, ...@@ -992,7 +992,7 @@ TEST(partial_shape,
window_dilation, window_dilation,
is_window_all_in_padding_allowed); is_window_all_in_padding_allowed);
}, },
NodeValidationError); NodeValidationFailure);
} }
TEST(partial_shape, infer_windowed_reduction_rank_static_dynamic_rank_dynamic_neg_padding_ok) TEST(partial_shape, infer_windowed_reduction_rank_static_dynamic_rank_dynamic_neg_padding_ok)
...@@ -1071,7 +1071,7 @@ TEST(partial_shape, infer_windowed_reduction_rank_dynamic_rank_static_dynamic_wi ...@@ -1071,7 +1071,7 @@ TEST(partial_shape, infer_windowed_reduction_rank_dynamic_rank_static_dynamic_wi
window_dilation, window_dilation,
is_window_all_in_padding_allowed); is_window_all_in_padding_allowed);
}, },
NodeValidationError); NodeValidationFailure);
} }
TEST(partial_shape, TEST(partial_shape,
...@@ -1100,7 +1100,7 @@ TEST(partial_shape, ...@@ -1100,7 +1100,7 @@ TEST(partial_shape,
window_dilation, window_dilation,
is_window_all_in_padding_allowed); is_window_all_in_padding_allowed);
}, },
NodeValidationError); NodeValidationFailure);
} }
TEST(partial_shape, TEST(partial_shape,
...@@ -1156,7 +1156,7 @@ TEST(partial_shape, ...@@ -1156,7 +1156,7 @@ TEST(partial_shape,
window_dilation, window_dilation,
is_window_all_in_padding_allowed); is_window_all_in_padding_allowed);
}, },
NodeValidationError); NodeValidationFailure);
} }
TEST(partial_shape, TEST(partial_shape,
...@@ -1294,7 +1294,7 @@ TEST(partial_shape, infer_windowed_reduction_rank_static_dynamic_rank_static_dyn ...@@ -1294,7 +1294,7 @@ TEST(partial_shape, infer_windowed_reduction_rank_static_dynamic_rank_static_dyn
window_dilation, window_dilation,
is_window_all_in_padding_allowed); is_window_all_in_padding_allowed);
}, },
NodeValidationError); NodeValidationFailure);
} }
TEST(partial_shape, TEST(partial_shape,
...@@ -1351,5 +1351,5 @@ TEST(partial_shape, ...@@ -1351,5 +1351,5 @@ TEST(partial_shape,
window_dilation, window_dilation,
is_window_all_in_padding_allowed); is_window_all_in_padding_allowed);
}, },
NodeValidationError); NodeValidationFailure);
} }
...@@ -212,7 +212,7 @@ TEST(type_prop, batchnorm_training_rank_less_than_2) ...@@ -212,7 +212,7 @@ TEST(type_prop, batchnorm_training_rank_less_than_2)
auto bc = make_shared<op::BatchNormTraining>(dummy, dummy, dummy, 0.001); auto bc = make_shared<op::BatchNormTraining>(dummy, dummy, dummy, 0.001);
FAIL() << "BatchNorm c-tor should throw for tensors whose rank is less than 2"; FAIL() << "BatchNorm c-tor should throw for tensors whose rank is less than 2";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(error.what(),
std::string("Input argument must have rank of at least 2")); std::string("Input argument must have rank of at least 2"));
...@@ -233,7 +233,7 @@ TEST(type_prop, batchnorm_training_zero_channel_check) ...@@ -233,7 +233,7 @@ TEST(type_prop, batchnorm_training_zero_channel_check)
auto bc = make_shared<op::BatchNormTraining>(data_batch, gamma, beta, 0.001); auto bc = make_shared<op::BatchNormTraining>(data_batch, gamma, beta, 0.001);
FAIL() << "BatchNorm c-tor should throw for tensors w/ zero-dimension channels"; FAIL() << "BatchNorm c-tor should throw for tensors w/ zero-dimension channels";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), std::string("Channel count must be at least 1")); EXPECT_HAS_SUBSTRING(error.what(), std::string("Channel count must be at least 1"));
} }
...@@ -254,7 +254,7 @@ TEST(type_prop, batchnorm_training_et_check) ...@@ -254,7 +254,7 @@ TEST(type_prop, batchnorm_training_et_check)
auto bc = make_shared<op::BatchNormTraining>(data_batch, gamma, beta, 0.001); auto bc = make_shared<op::BatchNormTraining>(data_batch, gamma, beta, 0.001);
FAIL() << "BatchNorm c-tor should throw for different element types"; FAIL() << "BatchNorm c-tor should throw for different element types";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), std::string("Input element types do not match")); EXPECT_HAS_SUBSTRING(error.what(), std::string("Input element types do not match"));
} }
...@@ -275,7 +275,7 @@ TEST(type_prop, batchnorm_training_shape_check) ...@@ -275,7 +275,7 @@ TEST(type_prop, batchnorm_training_shape_check)
auto bc = make_shared<op::BatchNormTraining>(data_batch, gamma, beta, 0.001); auto bc = make_shared<op::BatchNormTraining>(data_batch, gamma, beta, 0.001);
FAIL() << "BatchNorm c-tor should throw if gamma and beta shapes don't match"; FAIL() << "BatchNorm c-tor should throw if gamma and beta shapes don't match";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), std::string("Shapes for gamma/beta do not match")); EXPECT_HAS_SUBSTRING(error.what(), std::string("Shapes for gamma/beta do not match"));
} }
...@@ -300,7 +300,7 @@ TEST(type_prop, batchnorm_training_backprop_et_check) ...@@ -300,7 +300,7 @@ TEST(type_prop, batchnorm_training_backprop_et_check)
data_batch, gamma, beta, mean, variance, delta, 0.001); data_batch, gamma, beta, mean, variance, delta, 0.001);
FAIL() << "Deduced type should disagree with c-tor arguments"; FAIL() << "Deduced type should disagree with c-tor arguments";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), std::string("Input element types do not match")); EXPECT_HAS_SUBSTRING(error.what(), std::string("Input element types do not match"));
} }
...@@ -325,7 +325,7 @@ TEST(type_prop, batchnorm_training_backprop_shape_check) ...@@ -325,7 +325,7 @@ TEST(type_prop, batchnorm_training_backprop_shape_check)
data_batch, gamma, beta, mean, variance, delta, 0.001); data_batch, gamma, beta, mean, variance, delta, 0.001);
FAIL() << "Deduced type should disagree with c-tor arguments"; FAIL() << "Deduced type should disagree with c-tor arguments";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(error.what(),
std::string("Shapes for gamma/beta/mean/variance do not match")); std::string("Shapes for gamma/beta/mean/variance do not match"));
...@@ -443,7 +443,7 @@ TEST(type_prop, batchnorm_inference_partial_input_rank_static_dynamic_zero_chann ...@@ -443,7 +443,7 @@ TEST(type_prop, batchnorm_inference_partial_input_rank_static_dynamic_zero_chann
make_shared<op::BatchNormInference>(data_batch, gamma, beta, mean, variance, epsilon); make_shared<op::BatchNormInference>(data_batch, gamma, beta, mean, variance, epsilon);
FAIL() << "Zero channel count not detected"; FAIL() << "Zero channel count not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), std::string("Channel count must be at least 1")); EXPECT_HAS_SUBSTRING(error.what(), std::string("Channel count must be at least 1"));
} }
...@@ -506,7 +506,7 @@ TEST(type_prop, batchnorm_inference_partial_input_rank_dynamic_some_rank_static_ ...@@ -506,7 +506,7 @@ TEST(type_prop, batchnorm_inference_partial_input_rank_dynamic_some_rank_static_
make_shared<op::BatchNormInference>(data_batch, gamma, beta, mean, variance, epsilon); make_shared<op::BatchNormInference>(data_batch, gamma, beta, mean, variance, epsilon);
FAIL() << "Wrong gamma/beta/mean/variance shape not detected"; FAIL() << "Wrong gamma/beta/mean/variance shape not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(
error.what(), error.what(),
...@@ -545,7 +545,7 @@ TEST(type_prop, ...@@ -545,7 +545,7 @@ TEST(type_prop,
make_shared<op::BatchNormInference>(data_batch, gamma, beta, mean, variance, epsilon); make_shared<op::BatchNormInference>(data_batch, gamma, beta, mean, variance, epsilon);
FAIL() << "Inconsistent gamma/beta/mean/variance shape not detected"; FAIL() << "Inconsistent gamma/beta/mean/variance shape not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(error.what(),
std::string("Shapes for gamma/beta/mean/variance do not match")); std::string("Shapes for gamma/beta/mean/variance do not match"));
...@@ -583,7 +583,7 @@ TEST(type_prop, ...@@ -583,7 +583,7 @@ TEST(type_prop,
make_shared<op::BatchNormInference>(data_batch, gamma, beta, mean, variance, epsilon); make_shared<op::BatchNormInference>(data_batch, gamma, beta, mean, variance, epsilon);
FAIL() << "Inconsistent gamma/beta/mean/variance channel count not detected"; FAIL() << "Inconsistent gamma/beta/mean/variance channel count not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(error.what(),
std::string("Shapes for gamma/beta/mean/variance do not match")); std::string("Shapes for gamma/beta/mean/variance do not match"));
...@@ -649,7 +649,7 @@ TEST(type_prop, ...@@ -649,7 +649,7 @@ TEST(type_prop,
make_shared<op::BatchNormInference>(data_batch, gamma, beta, mean, variance, epsilon); make_shared<op::BatchNormInference>(data_batch, gamma, beta, mean, variance, epsilon);
FAIL() << "Inconsistent input/gamma/beta/mean/variance channel count not detected"; FAIL() << "Inconsistent input/gamma/beta/mean/variance channel count not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(error.what(),
std::string("Input channel dimension (4) does not match " std::string("Input channel dimension (4) does not match "
...@@ -759,7 +759,7 @@ TEST(type_prop, batchnorm_training_partial_input_rank_static_dynamic_zero_channe ...@@ -759,7 +759,7 @@ TEST(type_prop, batchnorm_training_partial_input_rank_static_dynamic_zero_channe
auto bn = make_shared<op::BatchNormTraining>(data_batch, gamma, beta, epsilon); auto bn = make_shared<op::BatchNormTraining>(data_batch, gamma, beta, epsilon);
FAIL() << "Zero channel count not detected"; FAIL() << "Zero channel count not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), std::string("Channel count must be at least 1")); EXPECT_HAS_SUBSTRING(error.what(), std::string("Channel count must be at least 1"));
} }
...@@ -813,7 +813,7 @@ TEST(type_prop, batchnorm_training_partial_input_rank_dynamic_some_rank_static_d ...@@ -813,7 +813,7 @@ TEST(type_prop, batchnorm_training_partial_input_rank_dynamic_some_rank_static_d
auto bn = make_shared<op::BatchNormTraining>(data_batch, gamma, beta, epsilon); auto bn = make_shared<op::BatchNormTraining>(data_batch, gamma, beta, epsilon);
FAIL() << "Wrong gamma/beta shape not detected"; FAIL() << "Wrong gamma/beta shape not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(error.what(),
std::string("Shape for gamma/beta ({?,?}) does not have rank 1")); std::string("Shape for gamma/beta ({?,?}) does not have rank 1"));
...@@ -844,7 +844,7 @@ TEST(type_prop, ...@@ -844,7 +844,7 @@ TEST(type_prop,
auto bn = make_shared<op::BatchNormTraining>(data_batch, gamma, beta, epsilon); auto bn = make_shared<op::BatchNormTraining>(data_batch, gamma, beta, epsilon);
FAIL() << "Inconsistent gamma/beta shape not detected"; FAIL() << "Inconsistent gamma/beta shape not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), std::string("Shapes for gamma/beta do not match")); EXPECT_HAS_SUBSTRING(error.what(), std::string("Shapes for gamma/beta do not match"));
} }
...@@ -874,7 +874,7 @@ TEST(type_prop, ...@@ -874,7 +874,7 @@ TEST(type_prop,
auto bn = make_shared<op::BatchNormTraining>(data_batch, gamma, beta, epsilon); auto bn = make_shared<op::BatchNormTraining>(data_batch, gamma, beta, epsilon);
FAIL() << "Inconsistent gamma/beta channel count not detected"; FAIL() << "Inconsistent gamma/beta channel count not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), std::string("Shapes for gamma/beta do not match")); EXPECT_HAS_SUBSTRING(error.what(), std::string("Shapes for gamma/beta do not match"));
} }
...@@ -930,7 +930,7 @@ TEST(type_prop, ...@@ -930,7 +930,7 @@ TEST(type_prop,
auto bn = make_shared<op::BatchNormTraining>(data_batch, gamma, beta, epsilon); auto bn = make_shared<op::BatchNormTraining>(data_batch, gamma, beta, epsilon);
FAIL() << "Inconsistent input/gamma/beta channel count not detected"; FAIL() << "Inconsistent input/gamma/beta channel count not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(
error.what(), error.what(),
...@@ -1049,7 +1049,7 @@ TEST(type_prop, batchnorm_training_backprop_partial_input_rank_static_dynamic_ze ...@@ -1049,7 +1049,7 @@ TEST(type_prop, batchnorm_training_backprop_partial_input_rank_static_dynamic_ze
data_batch, gamma, beta, mean, variance, delta, epsilon); data_batch, gamma, beta, mean, variance, delta, epsilon);
FAIL() << "Zero channel count not detected"; FAIL() << "Zero channel count not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), std::string("Channel count must be at least 1")); EXPECT_HAS_SUBSTRING(error.what(), std::string("Channel count must be at least 1"));
} }
...@@ -1160,7 +1160,7 @@ TEST(type_prop, batchnorm_training_backprop_partial_delta_rank_static_dynamic_ze ...@@ -1160,7 +1160,7 @@ TEST(type_prop, batchnorm_training_backprop_partial_delta_rank_static_dynamic_ze
data_batch, gamma, beta, mean, variance, delta, epsilon); data_batch, gamma, beta, mean, variance, delta, epsilon);
FAIL() << "Zero channel count not detected"; FAIL() << "Zero channel count not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), std::string("Channel count must be at least 1")); EXPECT_HAS_SUBSTRING(error.what(), std::string("Channel count must be at least 1"));
} }
...@@ -1237,7 +1237,7 @@ TEST( ...@@ -1237,7 +1237,7 @@ TEST(
data_batch, gamma, beta, mean, variance, delta, epsilon); data_batch, gamma, beta, mean, variance, delta, epsilon);
FAIL() << "Wrong gamma/beta/mean/variance shape not detected"; FAIL() << "Wrong gamma/beta/mean/variance shape not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(
error.what(), error.what(),
...@@ -1280,7 +1280,7 @@ TEST( ...@@ -1280,7 +1280,7 @@ TEST(
data_batch, gamma, beta, mean, variance, delta, epsilon); data_batch, gamma, beta, mean, variance, delta, epsilon);
FAIL() << "Wrong gamma/beta/mean/variance shape not detected"; FAIL() << "Wrong gamma/beta/mean/variance shape not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(error.what(),
std::string("Shapes for gamma/beta/mean/variance do not match")); std::string("Shapes for gamma/beta/mean/variance do not match"));
...@@ -1322,7 +1322,7 @@ TEST( ...@@ -1322,7 +1322,7 @@ TEST(
data_batch, gamma, beta, mean, variance, delta, epsilon); data_batch, gamma, beta, mean, variance, delta, epsilon);
FAIL() << "nconsistent gamma/beta/mean/variance channel count not detected"; FAIL() << "nconsistent gamma/beta/mean/variance channel count not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(error.what(),
std::string("Shapes for gamma/beta/mean/variance do not match")); std::string("Shapes for gamma/beta/mean/variance do not match"));
...@@ -1400,7 +1400,7 @@ TEST( ...@@ -1400,7 +1400,7 @@ TEST(
data_batch, gamma, beta, mean, variance, delta, epsilon); data_batch, gamma, beta, mean, variance, delta, epsilon);
FAIL() << "Inconsistent delta/gamma/beta/mean/variance channel count not detected"; FAIL() << "Inconsistent delta/gamma/beta/mean/variance channel count not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(error.what(),
std::string("Input channel dimension (4) does not match " std::string("Input channel dimension (4) does not match "
...@@ -2218,7 +2218,7 @@ void test_binary(std::string node_type, ...@@ -2218,7 +2218,7 @@ void test_binary(std::string node_type,
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Incompatible view arguments not detected."; FAIL() << "Incompatible view arguments not detected.";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), std::string("Argument shapes are inconsistent")); EXPECT_HAS_SUBSTRING(error.what(), std::string("Argument shapes are inconsistent"));
} }
...@@ -2237,7 +2237,7 @@ void test_binary(std::string node_type, ...@@ -2237,7 +2237,7 @@ void test_binary(std::string node_type,
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Incompatible view arguments not detected."; FAIL() << "Incompatible view arguments not detected.";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(error.what(),
std::string("Argument element types are inconsistent")); std::string("Argument element types are inconsistent"));
...@@ -2310,7 +2310,7 @@ void test_binary_logical(std::string node_type, ...@@ -2310,7 +2310,7 @@ void test_binary_logical(std::string node_type,
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Incompatible view arguments not detected."; FAIL() << "Incompatible view arguments not detected.";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), std::string("Argument shapes are inconsistent")); EXPECT_HAS_SUBSTRING(error.what(), std::string("Argument shapes are inconsistent"));
} }
...@@ -2329,7 +2329,7 @@ void test_binary_logical(std::string node_type, ...@@ -2329,7 +2329,7 @@ void test_binary_logical(std::string node_type,
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Incompatible view arguments not detected."; FAIL() << "Incompatible view arguments not detected.";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(error.what(),
std::string("Argument element types are inconsistent")); std::string("Argument element types are inconsistent"));
...@@ -2463,7 +2463,7 @@ TEST(type_prop, binary_arithmetic_bad_argument_element_types) ...@@ -2463,7 +2463,7 @@ TEST(type_prop, binary_arithmetic_bad_argument_element_types)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Did not detect incorrect element types for arithmetic operator"; FAIL() << "Did not detect incorrect element types for arithmetic operator";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(error.what(),
std::string("Arguments cannot have boolean element type")); std::string("Arguments cannot have boolean element type"));
...@@ -2483,7 +2483,7 @@ TEST(type_prop, unary_arithmetic_bad_argument_element_types) ...@@ -2483,7 +2483,7 @@ TEST(type_prop, unary_arithmetic_bad_argument_element_types)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Did not detect incorrect element types for arithmetic operator"; FAIL() << "Did not detect incorrect element types for arithmetic operator";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(error.what(),
std::string("Arguments cannot have boolean element type")); std::string("Arguments cannot have boolean element type"));
...@@ -5420,7 +5420,7 @@ TEST(type_prop, conv_invalid_element_type_mismatch) ...@@ -5420,7 +5420,7 @@ TEST(type_prop, conv_invalid_element_type_mismatch)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with element type mismatch not detected"; FAIL() << "Invalid input with element type mismatch not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(error.what(),
std::string("Element types for data batch and filters do not match")); std::string("Element types for data batch and filters do not match"));
...@@ -5443,7 +5443,7 @@ TEST(type_prop, conv_invalid_0d_input) ...@@ -5443,7 +5443,7 @@ TEST(type_prop, conv_invalid_0d_input)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid 0D input not detected"; FAIL() << "Invalid 0D input not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(error.what(),
std::string("Data batch and filters must have rank of at least 3 " std::string("Data batch and filters must have rank of at least 3 "
...@@ -5468,7 +5468,7 @@ TEST(type_prop, conv_invalid_1d_input) ...@@ -5468,7 +5468,7 @@ TEST(type_prop, conv_invalid_1d_input)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid 1D input not detected"; FAIL() << "Invalid 1D input not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(error.what(),
std::string("Data batch and filters must have rank of at least 3 " std::string("Data batch and filters must have rank of at least 3 "
...@@ -5493,7 +5493,7 @@ TEST(type_prop, conv_invalid_2d_input) ...@@ -5493,7 +5493,7 @@ TEST(type_prop, conv_invalid_2d_input)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid 2D input not detected"; FAIL() << "Invalid 2D input not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(error.what(),
std::string("Data batch and filters must have rank of at least 3 " std::string("Data batch and filters must have rank of at least 3 "
...@@ -5518,7 +5518,7 @@ TEST(type_prop, conv_invalid_0_batch_size) ...@@ -5518,7 +5518,7 @@ TEST(type_prop, conv_invalid_0_batch_size)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with 0 batch size not detected"; FAIL() << "Invalid input with 0 batch size not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), std::string("Batch size is zero")); EXPECT_HAS_SUBSTRING(error.what(), std::string("Batch size is zero"));
} }
...@@ -5540,7 +5540,7 @@ TEST(type_prop, conv_invalid_0_input_channels) ...@@ -5540,7 +5540,7 @@ TEST(type_prop, conv_invalid_0_input_channels)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with 0 input channels not detected"; FAIL() << "Invalid input with 0 input channels not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(
error.what(), error.what(),
...@@ -5564,7 +5564,7 @@ TEST(type_prop, conv_invalid_wrong_number_of_filter_dimensions_too_many) ...@@ -5564,7 +5564,7 @@ TEST(type_prop, conv_invalid_wrong_number_of_filter_dimensions_too_many)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with too many filter dimensions not detected"; FAIL() << "Invalid input with too many filter dimensions not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), std::string("Data batch and filters rank do not match")); EXPECT_HAS_SUBSTRING(error.what(), std::string("Data batch and filters rank do not match"));
} }
...@@ -5586,7 +5586,7 @@ TEST(type_prop, conv_invalid_wrong_number_of_filter_dimensions_too_few) ...@@ -5586,7 +5586,7 @@ TEST(type_prop, conv_invalid_wrong_number_of_filter_dimensions_too_few)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with too few filter dimensions not detected"; FAIL() << "Invalid input with too few filter dimensions not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), std::string("Data batch and filters rank do not match")); EXPECT_HAS_SUBSTRING(error.what(), std::string("Data batch and filters rank do not match"));
} }
...@@ -5608,7 +5608,7 @@ TEST(type_prop, conv_invalid_0_output_channels) ...@@ -5608,7 +5608,7 @@ TEST(type_prop, conv_invalid_0_output_channels)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with 0 output channels not detected"; FAIL() << "Invalid input with 0 output channels not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), std::string("Filter output channel count is zero")); EXPECT_HAS_SUBSTRING(error.what(), std::string("Filter output channel count is zero"));
} }
...@@ -5630,7 +5630,7 @@ TEST(type_prop, conv_invalid_input_channel_mismatch) ...@@ -5630,7 +5630,7 @@ TEST(type_prop, conv_invalid_input_channel_mismatch)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with channel count mismatch not detected"; FAIL() << "Invalid input with channel count mismatch not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(
error.what(), error.what(),
...@@ -5655,7 +5655,7 @@ TEST(type_prop, conv_invalid_movement_stride_rank) ...@@ -5655,7 +5655,7 @@ TEST(type_prop, conv_invalid_movement_stride_rank)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with wrong movement stride rank not detected"; FAIL() << "Invalid input with wrong movement stride rank not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(
error.what(), error.what(),
...@@ -5684,7 +5684,7 @@ TEST(type_prop, conv_invalid_window_dilation_stride_rank) ...@@ -5684,7 +5684,7 @@ TEST(type_prop, conv_invalid_window_dilation_stride_rank)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with wrong window dilation stride rank not detected"; FAIL() << "Invalid input with wrong window dilation stride rank not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(
error.what(), error.what(),
...@@ -5719,7 +5719,7 @@ TEST(type_prop, conv_invalid_data_dilation_stride_rank) ...@@ -5719,7 +5719,7 @@ TEST(type_prop, conv_invalid_data_dilation_stride_rank)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with wrong data dilation stride rank not detected"; FAIL() << "Invalid input with wrong data dilation stride rank not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(
error.what(), error.what(),
...@@ -5753,7 +5753,7 @@ TEST(type_prop, conv_invalid_padding_below_rank) ...@@ -5753,7 +5753,7 @@ TEST(type_prop, conv_invalid_padding_below_rank)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with wrong padding-below rank not detected"; FAIL() << "Invalid input with wrong padding-below rank not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(
error.what(), error.what(),
...@@ -5787,7 +5787,7 @@ TEST(type_prop, conv_invalid_padding_above_rank) ...@@ -5787,7 +5787,7 @@ TEST(type_prop, conv_invalid_padding_above_rank)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with wrong padding-above rank not detected"; FAIL() << "Invalid input with wrong padding-above rank not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(
error.what(), error.what(),
...@@ -5821,7 +5821,7 @@ TEST(type_prop, conv_invalid_input_spatial_size_negative_after_padding) ...@@ -5821,7 +5821,7 @@ TEST(type_prop, conv_invalid_input_spatial_size_negative_after_padding)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with negative-length post-padding spatial axis not detected"; FAIL() << "Invalid input with negative-length post-padding spatial axis not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(error.what(),
std::string("Data shape after padding and dilation has dimension less " std::string("Data shape after padding and dilation has dimension less "
...@@ -5850,7 +5850,7 @@ TEST(type_prop, conv_invalid_input_spatial_size_zero_after_padding) ...@@ -5850,7 +5850,7 @@ TEST(type_prop, conv_invalid_input_spatial_size_zero_after_padding)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with zero-length post-padding spatial axis not detected"; FAIL() << "Invalid input with zero-length post-padding spatial axis not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(error.what(),
std::string("Data shape after padding and dilation has dimension less " std::string("Data shape after padding and dilation has dimension less "
...@@ -5874,7 +5874,7 @@ TEST(type_prop, conv_invalid_input_spatial_size_0) ...@@ -5874,7 +5874,7 @@ TEST(type_prop, conv_invalid_input_spatial_size_0)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with zero-length spatial axis not detected"; FAIL() << "Invalid input with zero-length spatial axis not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(error.what(),
std::string("Data shape after padding and dilation has " std::string("Data shape after padding and dilation has "
...@@ -5898,7 +5898,7 @@ TEST(type_prop, conv_invalid_window_size_0) ...@@ -5898,7 +5898,7 @@ TEST(type_prop, conv_invalid_window_size_0)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with zero-length window axis not detected"; FAIL() << "Invalid input with zero-length window axis not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(
error.what(), error.what(),
...@@ -5922,7 +5922,7 @@ TEST(type_prop, conv_invalid_window_dilation_stride_0) ...@@ -5922,7 +5922,7 @@ TEST(type_prop, conv_invalid_window_dilation_stride_0)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with wrong 0-length window dilation stride axis not detected"; FAIL() << "Invalid input with wrong 0-length window dilation stride axis not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(
error.what(), error.what(),
...@@ -5952,7 +5952,7 @@ TEST(type_prop, conv_invalid_data_dilation_stride_0) ...@@ -5952,7 +5952,7 @@ TEST(type_prop, conv_invalid_data_dilation_stride_0)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with wrong 0-length data dilation stride axis not detected"; FAIL() << "Invalid input with wrong 0-length data dilation stride axis not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(
error.what(), error.what(),
...@@ -5976,7 +5976,7 @@ TEST(type_prop, conv_invalid_dilated_window_too_large) ...@@ -5976,7 +5976,7 @@ TEST(type_prop, conv_invalid_dilated_window_too_large)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with oversized dilated window not detected"; FAIL() << "Invalid input with oversized dilated window not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(error.what(),
std::string("Window after dilation has dimension (dim: 9) larger than " std::string("Window after dilation has dimension (dim: 9) larger than "
...@@ -6000,7 +6000,7 @@ TEST(type_prop, conv_invalid_movement_stride_0) ...@@ -6000,7 +6000,7 @@ TEST(type_prop, conv_invalid_movement_stride_0)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with wrong 0-length movement stride axis not detected"; FAIL() << "Invalid input with wrong 0-length movement stride axis not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(
error.what(), error.what(),
...@@ -6062,7 +6062,7 @@ TEST(type_prop, conv_partial_rank_dynamic_rank_dynamic_window_strides_rank_wrong ...@@ -6062,7 +6062,7 @@ TEST(type_prop, conv_partial_rank_dynamic_rank_dynamic_window_strides_rank_wrong
FAIL() << "Window stride rank mismatch not detected"; FAIL() << "Window stride rank mismatch not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(
error.what(), error.what(),
...@@ -6103,7 +6103,7 @@ TEST(type_prop, conv_partial_rank_dynamic_rank_dynamic_window_strides_dim_zero) ...@@ -6103,7 +6103,7 @@ TEST(type_prop, conv_partial_rank_dynamic_rank_dynamic_window_strides_dim_zero)
FAIL() << "Window stride with dimension zero not detected"; FAIL() << "Window stride with dimension zero not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(
error.what(), error.what(),
...@@ -6140,7 +6140,7 @@ TEST(type_prop, conv_partial_rank_dynamic_rank_dynamic_window_dilation_rank_wron ...@@ -6140,7 +6140,7 @@ TEST(type_prop, conv_partial_rank_dynamic_rank_dynamic_window_dilation_rank_wron
FAIL() << "Window dilation rank mismatch not detected"; FAIL() << "Window dilation rank mismatch not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(
error.what(), error.what(),
...@@ -6181,7 +6181,7 @@ TEST(type_prop, conv_partial_rank_dynamic_rank_dynamic_window_dilation_dim_zero) ...@@ -6181,7 +6181,7 @@ TEST(type_prop, conv_partial_rank_dynamic_rank_dynamic_window_dilation_dim_zero)
FAIL() << "Window dilation with dimension zero not detected"; FAIL() << "Window dilation with dimension zero not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(
error.what(), error.what(),
...@@ -6218,7 +6218,7 @@ TEST(type_prop, conv_partial_rank_dynamic_rank_dynamic_padding_below_rank_wrong) ...@@ -6218,7 +6218,7 @@ TEST(type_prop, conv_partial_rank_dynamic_rank_dynamic_padding_below_rank_wrong)
FAIL() << "Padding below rank mismatch not detected"; FAIL() << "Padding below rank mismatch not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(
error.what(), error.what(),
...@@ -6259,7 +6259,7 @@ TEST(type_prop, conv_partial_rank_dynamic_rank_dynamic_padding_above_rank_wrong) ...@@ -6259,7 +6259,7 @@ TEST(type_prop, conv_partial_rank_dynamic_rank_dynamic_padding_above_rank_wrong)
FAIL() << "Padding above rank mismatch not detected"; FAIL() << "Padding above rank mismatch not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(
error.what(), error.what(),
...@@ -6300,7 +6300,7 @@ TEST(type_prop, conv_partial_rank_dynamic_rank_dynamic_data_dilation_rank_wrong) ...@@ -6300,7 +6300,7 @@ TEST(type_prop, conv_partial_rank_dynamic_rank_dynamic_data_dilation_rank_wrong)
FAIL() << "Data dilation rank mismatch not detected"; FAIL() << "Data dilation rank mismatch not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(
error.what(), error.what(),
...@@ -6341,7 +6341,7 @@ TEST(type_prop, conv_partial_rank_dynamic_rank_dynamic_data_dilation_dim_zero) ...@@ -6341,7 +6341,7 @@ TEST(type_prop, conv_partial_rank_dynamic_rank_dynamic_data_dilation_dim_zero)
FAIL() << "Data dilation with dimension zero not detected"; FAIL() << "Data dilation with dimension zero not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(
error.what(), error.what(),
...@@ -6403,7 +6403,7 @@ TEST(type_prop, conv_partial_rank_static_dynamic_rank_dynamic_data_batch_rank_wr ...@@ -6403,7 +6403,7 @@ TEST(type_prop, conv_partial_rank_static_dynamic_rank_dynamic_data_batch_rank_wr
FAIL() << "Data batch rank mismatch not detected"; FAIL() << "Data batch rank mismatch not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(
error.what(), error.what(),
...@@ -6473,7 +6473,7 @@ TEST(type_prop, conv_partial_rank_static_dynamic_rank_dynamic_batch_size_known_z ...@@ -6473,7 +6473,7 @@ TEST(type_prop, conv_partial_rank_static_dynamic_rank_dynamic_batch_size_known_z
FAIL() << "Zero batch size not detected"; FAIL() << "Zero batch size not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), std::string("Batch size is zero")); EXPECT_HAS_SUBSTRING(error.what(), std::string("Batch size is zero"));
} }
...@@ -6535,7 +6535,7 @@ TEST(type_prop, conv_partial_rank_static_dynamic_rank_dynamic_input_channel_coun ...@@ -6535,7 +6535,7 @@ TEST(type_prop, conv_partial_rank_static_dynamic_rank_dynamic_input_channel_coun
FAIL() << "Zero input channel count not detected"; FAIL() << "Zero input channel count not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(
error.what(), error.what(),
...@@ -6599,7 +6599,7 @@ TEST(type_prop, conv_partial_rank_dynamic_rank_static_dynamic_output_channel_cou ...@@ -6599,7 +6599,7 @@ TEST(type_prop, conv_partial_rank_dynamic_rank_static_dynamic_output_channel_cou
FAIL() << "Zero output channel count not detected"; FAIL() << "Zero output channel count not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), std::string("Filter output channel count is zero")); EXPECT_HAS_SUBSTRING(error.what(), std::string("Filter output channel count is zero"));
} }
...@@ -6659,7 +6659,7 @@ TEST(type_prop, conv_partial_rank_dynamic_rank_static_dynamic_input_channel_coun ...@@ -6659,7 +6659,7 @@ TEST(type_prop, conv_partial_rank_dynamic_rank_static_dynamic_input_channel_coun
FAIL() << "Zero input channel count not detected"; FAIL() << "Zero input channel count not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(
error.what(), error.what(),
...@@ -6721,7 +6721,7 @@ TEST(type_prop, conv_partial_rank_static_dynamic_rank_static_dynamic_arg_ranks_m ...@@ -6721,7 +6721,7 @@ TEST(type_prop, conv_partial_rank_static_dynamic_rank_static_dynamic_arg_ranks_m
FAIL() << "Argument rank mismatch not detected"; FAIL() << "Argument rank mismatch not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(error.what(),
std::string("Data batch and filters rank do not match (data batch " std::string("Data batch and filters rank do not match (data batch "
...@@ -6786,7 +6786,7 @@ TEST(type_prop, conv_partial_rank_static_dynamic_rank_static_dynamic_input_chann ...@@ -6786,7 +6786,7 @@ TEST(type_prop, conv_partial_rank_static_dynamic_rank_static_dynamic_input_chann
FAIL() << "Input channel count mismatch not detected"; FAIL() << "Input channel count mismatch not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(
error.what(), error.what(),
...@@ -6879,7 +6879,7 @@ TEST( ...@@ -6879,7 +6879,7 @@ TEST(
FAIL() << "Oversize filter not detected"; FAIL() << "Oversize filter not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(error.what(),
std::string("Window after dilation has dimension (dim: 201) larger " std::string("Window after dilation has dimension (dim: 201) larger "
...@@ -7002,7 +7002,7 @@ TEST( ...@@ -7002,7 +7002,7 @@ TEST(
FAIL() << "Oversize filter after window dilation not detected"; FAIL() << "Oversize filter after window dilation not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(error.what(),
std::string("Window after dilation has dimension (dim: 201) larger " std::string("Window after dilation has dimension (dim: 201) larger "
...@@ -7041,7 +7041,7 @@ TEST( ...@@ -7041,7 +7041,7 @@ TEST(
FAIL() << "Zero dimension in data batch not detected"; FAIL() << "Zero dimension in data batch not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(error.what(),
std::string("Data shape after padding and dilation has " std::string("Data shape after padding and dilation has "
...@@ -7108,7 +7108,7 @@ TEST( ...@@ -7108,7 +7108,7 @@ TEST(
FAIL() << "Zero padded dimension in data batch not detected"; FAIL() << "Zero padded dimension in data batch not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(error.what(),
std::string("Data shape after padding and dilation has " std::string("Data shape after padding and dilation has "
...@@ -7147,7 +7147,7 @@ TEST( ...@@ -7147,7 +7147,7 @@ TEST(
FAIL() << "Negative padded dimension in data batch not detected"; FAIL() << "Negative padded dimension in data batch not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(error.what(),
std::string("Data shape after padding and dilation has dimension less " std::string("Data shape after padding and dilation has dimension less "
...@@ -7302,7 +7302,7 @@ TEST(type_prop, max_pool_invalid_0d_input) ...@@ -7302,7 +7302,7 @@ TEST(type_prop, max_pool_invalid_0d_input)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid 0D input not detected"; FAIL() << "Invalid 0D input not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), std::string("Data batch must have rank of at least 3")); EXPECT_HAS_SUBSTRING(error.what(), std::string("Data batch must have rank of at least 3"));
} }
...@@ -7324,7 +7324,7 @@ TEST(type_prop, max_pool_invalid_1d_input) ...@@ -7324,7 +7324,7 @@ TEST(type_prop, max_pool_invalid_1d_input)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid 1D input not detected"; FAIL() << "Invalid 1D input not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), std::string("Data batch must have rank of at least 3")); EXPECT_HAS_SUBSTRING(error.what(), std::string("Data batch must have rank of at least 3"));
} }
...@@ -7346,7 +7346,7 @@ TEST(type_prop, max_pool_invalid_2d_input) ...@@ -7346,7 +7346,7 @@ TEST(type_prop, max_pool_invalid_2d_input)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid 2D input not detected"; FAIL() << "Invalid 2D input not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), std::string("Data batch must have rank of at least 3")); EXPECT_HAS_SUBSTRING(error.what(), std::string("Data batch must have rank of at least 3"));
} }
...@@ -7368,7 +7368,7 @@ TEST(type_prop, max_pool_invalid_0_batch_size) ...@@ -7368,7 +7368,7 @@ TEST(type_prop, max_pool_invalid_0_batch_size)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with 0 batch size not detected"; FAIL() << "Invalid input with 0 batch size not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), std::string("Batch size is zero")); EXPECT_HAS_SUBSTRING(error.what(), std::string("Batch size is zero"));
} }
...@@ -7390,7 +7390,7 @@ TEST(type_prop, max_pool_invalid_0_channels) ...@@ -7390,7 +7390,7 @@ TEST(type_prop, max_pool_invalid_0_channels)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with 0 channels not detected"; FAIL() << "Invalid input with 0 channels not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), std::string("Channel count is zero")); EXPECT_HAS_SUBSTRING(error.what(), std::string("Channel count is zero"));
} }
...@@ -7412,7 +7412,7 @@ TEST(type_prop, max_pool_invalid_wrong_number_of_window_dimensions_too_many) ...@@ -7412,7 +7412,7 @@ TEST(type_prop, max_pool_invalid_wrong_number_of_window_dimensions_too_many)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with too many window dimensions not detected"; FAIL() << "Invalid input with too many window dimensions not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(
error.what(), error.what(),
...@@ -7439,7 +7439,7 @@ TEST(type_prop, max_pool_invalid_wrong_number_of_window_dimensions_too_few) ...@@ -7439,7 +7439,7 @@ TEST(type_prop, max_pool_invalid_wrong_number_of_window_dimensions_too_few)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with too few window dimensions not detected"; FAIL() << "Invalid input with too few window dimensions not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(
error.what(), error.what(),
...@@ -7467,7 +7467,7 @@ TEST(type_prop, max_pool_invalid_movement_stride_rank) ...@@ -7467,7 +7467,7 @@ TEST(type_prop, max_pool_invalid_movement_stride_rank)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with wrong movement stride rank not detected"; FAIL() << "Invalid input with wrong movement stride rank not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(
error.what(), error.what(),
...@@ -7494,7 +7494,7 @@ TEST(type_prop, max_pool_invalid_input_data_size_0) ...@@ -7494,7 +7494,7 @@ TEST(type_prop, max_pool_invalid_input_data_size_0)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with zero-length spatial axis not detected"; FAIL() << "Invalid input with zero-length spatial axis not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(error.what(),
std::string("Data shape after padding and dilation has " std::string("Data shape after padding and dilation has "
...@@ -7518,7 +7518,7 @@ TEST(type_prop, max_pool_invalid_window_size_0) ...@@ -7518,7 +7518,7 @@ TEST(type_prop, max_pool_invalid_window_size_0)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with zero-length window axis not detected"; FAIL() << "Invalid input with zero-length window axis not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(
error.what(), error.what(),
...@@ -7542,7 +7542,7 @@ TEST(type_prop, max_pool_invalid_dilated_too_large) ...@@ -7542,7 +7542,7 @@ TEST(type_prop, max_pool_invalid_dilated_too_large)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with oversized window not detected"; FAIL() << "Invalid input with oversized window not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(error.what(),
std::string("Window after dilation has dimension (dim: 9) larger than " std::string("Window after dilation has dimension (dim: 9) larger than "
...@@ -7567,7 +7567,7 @@ TEST(type_prop, max_pool_invalid_movement_stride_0) ...@@ -7567,7 +7567,7 @@ TEST(type_prop, max_pool_invalid_movement_stride_0)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with 0-length movement stride axis not detected"; FAIL() << "Invalid input with 0-length movement stride axis not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(
error.what(), error.what(),
...@@ -7611,7 +7611,7 @@ TEST(type_prop, max_pool_partial_rank_dynamic_attrib_rank_mismatch) ...@@ -7611,7 +7611,7 @@ TEST(type_prop, max_pool_partial_rank_dynamic_attrib_rank_mismatch)
param, window_shape, window_movement_strides, padding_below, padding_above); param, window_shape, window_movement_strides, padding_below, padding_above);
FAIL() << "Mismatch of attribute ranks not detected"; FAIL() << "Mismatch of attribute ranks not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(
error.what(), error.what(),
...@@ -7675,7 +7675,7 @@ TEST(type_prop, max_pool_partial_rank_static_dynamic_attrib_rank_mismatch) ...@@ -7675,7 +7675,7 @@ TEST(type_prop, max_pool_partial_rank_static_dynamic_attrib_rank_mismatch)
param, window_shape, window_movement_strides, padding_below, padding_above); param, window_shape, window_movement_strides, padding_below, padding_above);
FAIL() << "Mismatch of attribute ranks not detected"; FAIL() << "Mismatch of attribute ranks not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(
error.what(), error.what(),
...@@ -7706,7 +7706,7 @@ TEST(type_prop, max_pool_partial_rank_static_dynamic_window_not_too_big) ...@@ -7706,7 +7706,7 @@ TEST(type_prop, max_pool_partial_rank_static_dynamic_window_not_too_big)
param, window_shape, window_movement_strides, padding_below, padding_above); param, window_shape, window_movement_strides, padding_below, padding_above);
FAIL() << "Oversized window not detected"; FAIL() << "Oversized window not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(error.what(),
std::string("Window after dilation has dimension (dim: 9) larger than " std::string("Window after dilation has dimension (dim: 9) larger than "
...@@ -8371,7 +8371,7 @@ TEST(type_prop, avg_pool_invalid_0d_input) ...@@ -8371,7 +8371,7 @@ TEST(type_prop, avg_pool_invalid_0d_input)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid 0D input not detected"; FAIL() << "Invalid 0D input not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(error.what(),
"Data batch must have rank of at least 3 (one batch axis, one " "Data batch must have rank of at least 3 (one batch axis, one "
...@@ -8395,7 +8395,7 @@ TEST(type_prop, avg_pool_invalid_1d_input) ...@@ -8395,7 +8395,7 @@ TEST(type_prop, avg_pool_invalid_1d_input)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid 1D input not detected"; FAIL() << "Invalid 1D input not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(error.what(),
"Data batch must have rank of at least 3 (one batch axis, one " "Data batch must have rank of at least 3 (one batch axis, one "
...@@ -8419,7 +8419,7 @@ TEST(type_prop, avg_pool_invalid_2d_input) ...@@ -8419,7 +8419,7 @@ TEST(type_prop, avg_pool_invalid_2d_input)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid 2D input not detected"; FAIL() << "Invalid 2D input not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(error.what(),
"Data batch must have rank of at least 3 (one batch axis, one " "Data batch must have rank of at least 3 (one batch axis, one "
...@@ -8443,7 +8443,7 @@ TEST(type_prop, avg_pool_invalid_0_batch_size) ...@@ -8443,7 +8443,7 @@ TEST(type_prop, avg_pool_invalid_0_batch_size)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with 0 batch size not detected"; FAIL() << "Invalid input with 0 batch size not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), "Batch size is zero"); EXPECT_HAS_SUBSTRING(error.what(), "Batch size is zero");
} }
...@@ -8465,7 +8465,7 @@ TEST(type_prop, avg_pool_invalid_0_channels) ...@@ -8465,7 +8465,7 @@ TEST(type_prop, avg_pool_invalid_0_channels)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with 0 channels not detected"; FAIL() << "Invalid input with 0 channels not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), "Channel count is zero"); EXPECT_HAS_SUBSTRING(error.what(), "Channel count is zero");
} }
...@@ -8487,7 +8487,7 @@ TEST(type_prop, avg_pool_invalid_wrong_number_of_window_dimensions_too_many) ...@@ -8487,7 +8487,7 @@ TEST(type_prop, avg_pool_invalid_wrong_number_of_window_dimensions_too_many)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with too many window dimensions not detected"; FAIL() << "Invalid input with too many window dimensions not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(error.what(),
"Ranks for data item shape (data batch has shape {6,2,10,10}, so data " "Ranks for data item shape (data batch has shape {6,2,10,10}, so data "
...@@ -8513,7 +8513,7 @@ TEST(type_prop, avg_pool_invalid_wrong_number_of_window_dimensions_too_few) ...@@ -8513,7 +8513,7 @@ TEST(type_prop, avg_pool_invalid_wrong_number_of_window_dimensions_too_few)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with too few window dimensions not detected"; FAIL() << "Invalid input with too few window dimensions not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(error.what(),
"Ranks for data item shape (data batch has shape {6,2,10,10}, so data " "Ranks for data item shape (data batch has shape {6,2,10,10}, so data "
...@@ -8540,7 +8540,7 @@ TEST(type_prop, avg_pool_invalid_movement_stride_rank) ...@@ -8540,7 +8540,7 @@ TEST(type_prop, avg_pool_invalid_movement_stride_rank)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with wrong movement stride rank not detected"; FAIL() << "Invalid input with wrong movement stride rank not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(error.what(),
"Ranks for data item shape (data batch has shape {6,2,10,10}, so data " "Ranks for data item shape (data batch has shape {6,2,10,10}, so data "
...@@ -8570,7 +8570,7 @@ TEST(type_prop, avg_pool_invalid_padding_below_rank) ...@@ -8570,7 +8570,7 @@ TEST(type_prop, avg_pool_invalid_padding_below_rank)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with wrong below-padding rank not detected"; FAIL() << "Invalid input with wrong below-padding rank not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(error.what(),
"Ranks for data item shape (data batch has shape {6,2,10,10}, so data " "Ranks for data item shape (data batch has shape {6,2,10,10}, so data "
...@@ -8600,7 +8600,7 @@ TEST(type_prop, avg_pool_invalid_padding_above_rank) ...@@ -8600,7 +8600,7 @@ TEST(type_prop, avg_pool_invalid_padding_above_rank)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with wrong above-padding rank not detected"; FAIL() << "Invalid input with wrong above-padding rank not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(error.what(),
"Ranks for data item shape (data batch has shape {6,2,10,10}, so data " "Ranks for data item shape (data batch has shape {6,2,10,10}, so data "
...@@ -8626,7 +8626,7 @@ TEST(type_prop, avg_pool_invalid_input_item_size_0) ...@@ -8626,7 +8626,7 @@ TEST(type_prop, avg_pool_invalid_input_item_size_0)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with zero-length spatial axis not detected"; FAIL() << "Invalid input with zero-length spatial axis not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(
error.what(), error.what(),
...@@ -8650,7 +8650,7 @@ TEST(type_prop, avg_pool_invalid_window_size_0) ...@@ -8650,7 +8650,7 @@ TEST(type_prop, avg_pool_invalid_window_size_0)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with zero-length window axis not detected"; FAIL() << "Invalid input with zero-length window axis not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(error.what(),
"Window after dilation has dimension less than 1 (dim: 0) at axis 1"); "Window after dilation has dimension less than 1 (dim: 0) at axis 1");
...@@ -8673,7 +8673,7 @@ TEST(type_prop, avg_pool_invalid_dilated_too_large) ...@@ -8673,7 +8673,7 @@ TEST(type_prop, avg_pool_invalid_dilated_too_large)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with oversized window not detected"; FAIL() << "Invalid input with oversized window not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(error.what(),
"Window after dilation has dimension (dim: 9) larger than the data " "Window after dilation has dimension (dim: 9) larger than the data "
...@@ -8712,7 +8712,7 @@ TEST(type_prop, avg_pool_invalid_movement_stride_0) ...@@ -8712,7 +8712,7 @@ TEST(type_prop, avg_pool_invalid_movement_stride_0)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with 0-length movement stride axis not detected"; FAIL() << "Invalid input with 0-length movement stride axis not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(error.what(),
"Window strides (Strides{0, 1}) has zero dimension at axis 0"); "Window strides (Strides{0, 1}) has zero dimension at axis 0");
...@@ -8765,7 +8765,7 @@ TEST(type_prop, avg_pool_partial_rank_dynamic_attrib_rank_mismatch) ...@@ -8765,7 +8765,7 @@ TEST(type_prop, avg_pool_partial_rank_dynamic_attrib_rank_mismatch)
include_padding_in_average); include_padding_in_average);
FAIL() << "Mismatch of attribute ranks not detected"; FAIL() << "Mismatch of attribute ranks not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(
error.what(), error.what(),
...@@ -8844,7 +8844,7 @@ TEST(type_prop, avg_pool_partial_rank_static_dynamic_attrib_rank_mismatch) ...@@ -8844,7 +8844,7 @@ TEST(type_prop, avg_pool_partial_rank_static_dynamic_attrib_rank_mismatch)
include_padding_in_average); include_padding_in_average);
FAIL() << "Mismatch of attribute ranks not detected"; FAIL() << "Mismatch of attribute ranks not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(
error.what(), error.what(),
...@@ -8880,7 +8880,7 @@ TEST(type_prop, avg_pool_partial_rank_static_dynamic_window_not_too_big) ...@@ -8880,7 +8880,7 @@ TEST(type_prop, avg_pool_partial_rank_static_dynamic_window_not_too_big)
include_padding_in_average); include_padding_in_average);
FAIL() << "Oversized window not detected"; FAIL() << "Oversized window not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(error.what(),
std::string("Window after dilation has dimension (dim: 9) larger than " std::string("Window after dilation has dimension (dim: 9) larger than "
...@@ -8935,7 +8935,7 @@ TEST(type_prop, avg_pool_partial_rank_static_dynamic_window_in_padding) ...@@ -8935,7 +8935,7 @@ TEST(type_prop, avg_pool_partial_rank_static_dynamic_window_in_padding)
include_padding_in_average); include_padding_in_average);
FAIL() << "Window in padding not detected"; FAIL() << "Window in padding not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(error.what(),
std::string("Window after dilation has dimension (dim: 9) larger than " std::string("Window after dilation has dimension (dim: 9) larger than "
...@@ -10030,7 +10030,7 @@ TEST(type_prop, binary_elementwise_arithmetic_left_rank_static_dynamic_inconsist ...@@ -10030,7 +10030,7 @@ TEST(type_prop, binary_elementwise_arithmetic_left_rank_static_dynamic_inconsist
auto add = make_shared<op::Add>(a, b); auto add = make_shared<op::Add>(a, b);
FAIL() << "Inconsistent partial shapes not detected"; FAIL() << "Inconsistent partial shapes not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), "Argument shapes are inconsistent"); EXPECT_HAS_SUBSTRING(error.what(), "Argument shapes are inconsistent");
} }
...@@ -10050,7 +10050,7 @@ TEST(type_prop, binary_elementwise_arithmetic_right_rank_static_dynamic_inconsis ...@@ -10050,7 +10050,7 @@ TEST(type_prop, binary_elementwise_arithmetic_right_rank_static_dynamic_inconsis
auto add = make_shared<op::Add>(a, b); auto add = make_shared<op::Add>(a, b);
FAIL() << "Inconsistent partial shapes not detected"; FAIL() << "Inconsistent partial shapes not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), "Argument shapes are inconsistent"); EXPECT_HAS_SUBSTRING(error.what(), "Argument shapes are inconsistent");
} }
...@@ -10070,7 +10070,7 @@ TEST(type_prop, binary_elementwise_arithmetic_both_rank_static_dynamic_inconsist ...@@ -10070,7 +10070,7 @@ TEST(type_prop, binary_elementwise_arithmetic_both_rank_static_dynamic_inconsist
auto add = make_shared<op::Add>(a, b); auto add = make_shared<op::Add>(a, b);
FAIL() << "Inconsistent partial shapes not detected"; FAIL() << "Inconsistent partial shapes not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), "Argument shapes are inconsistent"); EXPECT_HAS_SUBSTRING(error.what(), "Argument shapes are inconsistent");
} }
...@@ -10090,7 +10090,7 @@ TEST(type_prop, binary_elementwise_arithmetic_left_rank_static_dynamic_different ...@@ -10090,7 +10090,7 @@ TEST(type_prop, binary_elementwise_arithmetic_left_rank_static_dynamic_different
auto add = make_shared<op::Add>(a, b); auto add = make_shared<op::Add>(a, b);
FAIL() << "Inconsistent partial shapes not detected"; FAIL() << "Inconsistent partial shapes not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), "Argument shapes are inconsistent"); EXPECT_HAS_SUBSTRING(error.what(), "Argument shapes are inconsistent");
} }
...@@ -10110,7 +10110,7 @@ TEST(type_prop, binary_elementwise_arithmetic_right_rank_static_dynamic_differen ...@@ -10110,7 +10110,7 @@ TEST(type_prop, binary_elementwise_arithmetic_right_rank_static_dynamic_differen
auto add = make_shared<op::Add>(a, b); auto add = make_shared<op::Add>(a, b);
FAIL() << "Inconsistent partial shapes not detected"; FAIL() << "Inconsistent partial shapes not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), "Argument shapes are inconsistent"); EXPECT_HAS_SUBSTRING(error.what(), "Argument shapes are inconsistent");
} }
...@@ -10130,7 +10130,7 @@ TEST(type_prop, binary_elementwise_arithmetic_both_rank_static_dynamic_different ...@@ -10130,7 +10130,7 @@ TEST(type_prop, binary_elementwise_arithmetic_both_rank_static_dynamic_different
auto add = make_shared<op::Add>(a, b); auto add = make_shared<op::Add>(a, b);
FAIL() << "Inconsistent partial shapes not detected"; FAIL() << "Inconsistent partial shapes not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), "Argument shapes are inconsistent"); EXPECT_HAS_SUBSTRING(error.what(), "Argument shapes are inconsistent");
} }
...@@ -12095,3 +12095,56 @@ TEST(type_prop, all_partial_rank_static_dynamic_axes_oob) ...@@ -12095,3 +12095,56 @@ TEST(type_prop, all_partial_rank_static_dynamic_axes_oob)
FAIL() << "Deduced type check failed for unexpected reason"; FAIL() << "Deduced type check failed for unexpected reason";
} }
} }
TEST(type_prop, DISABLED_benchmark_type_prop_add)
{
auto p1 = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3, 4});
auto p2 = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3, 4});
constexpr size_t num_iterations = 1000000;
size_t total_nanosec = 0;
stopwatch sw;
for (size_t i = 0; i < num_iterations; i++)
{
sw.start();
auto n = make_shared<op::Add>(p1, p2);
sw.stop();
total_nanosec += sw.get_nanoseconds();
}
std::cout.imbue(std::locale(""));
std::cout << "Constructed " << std::fixed << num_iterations << " Add ops in " << std::fixed
<< total_nanosec << " ns" << std::endl;
}
TEST(type_prop, DISABLED_benchmark_type_prop_convolution)
{
auto d = make_shared<op::Parameter>(element::f32, Shape{64, 3, 224, 224});
auto f = make_shared<op::Parameter>(element::f32, Shape{64, 3, 7, 7});
auto strides = Strides{1, 1};
auto dilation = Strides{1, 1};
auto padding_below = CoordinateDiff{1, 1};
auto padding_above = CoordinateDiff{1, 1};
constexpr size_t num_iterations = 1000000;
size_t total_nanosec = 0;
stopwatch sw;
for (size_t i = 0; i < num_iterations; i++)
{
sw.start();
auto n =
make_shared<op::Convolution>(d, f, strides, dilation, padding_below, padding_above);
sw.stop();
total_nanosec += sw.get_nanoseconds();
}
std::cout.imbue(std::locale(""));
std::cout << "Constructed " << std::fixed << num_iterations << " Convolution ops in "
<< std::fixed << total_nanosec << " ns" << std::endl;
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment