Commit 6ffbd254 authored by Adam Procter, committed by Robert Kimball

New macros for faster node validation checks (#2537)

* Skeleton for faster validation asserts

* Switch to __VA_ARGS__ for compatibility, remove -Wno-variadic-macros

* Add benchmarks for constructing Add and Convolution

* Quick hack to avoid shadowing inside the CHECK macro

* Quick hack to avoid inadvertent capture inside the macro

* Update convolution, and change a bunch of tests to anticipate the new error class
parent eb34c884
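The pattern this commit applies throughout the tree, shown on a hypothetical check (not taken from the diff below): the old macro streams message fragments into an assertion object, while the new macro takes the fragments as variadic arguments and only formats them when the condition actually fails, which appears to be the source of the speedup the new benchmarks measure.

// Before: streaming form; message fragments go through operator<<.
NODE_VALIDATION_ASSERT(this, rank >= 3)
    << "Rank must be at least 3 (got " << rank << ").";

// After: variadic form; fragments are arguments, formatted only on failure.
NODE_VALIDATION_CHECK(this, rank >= 3, "Rank must be at least 3 (got ", rank, ").");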
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <exception>
#include <sstream>
#include <vector>
#include "ngraph/except.hpp"
namespace ngraph
{
static inline std::ostream& write_all_to_stream(std::ostream& str) { return str; }
template <typename T, typename... TS>
static inline std::ostream& write_all_to_stream(std::ostream& str, const T& arg, TS... args)
{
return write_all_to_stream(str << arg, args...);
}
struct CheckLocInfo
{
const char* file;
int line;
const char* check_string;
};
/// Base class for check failure exceptions.
class CheckFailure : public ngraph_error
{
public:
CheckFailure(const CheckLocInfo& check_loc_info,
const std::string& context_info,
const std::string& explanation)
: ngraph_error(make_what(check_loc_info, context_info, explanation))
{
}
private:
static std::string make_what(const CheckLocInfo& check_loc_info,
const std::string& context_info,
const std::string& explanation)
{
std::stringstream ss;
ss << "Check '" << check_loc_info.check_string << "' failed at " << check_loc_info.file
<< ":" << check_loc_info.line << ":" << std::endl;
ss << context_info << ":" << std::endl;
ss << explanation << std::endl;
return ss.str();
}
};
}
// TODO(amprocte): refactor so we don't have to introduce a locally-scoped variable and risk
// shadowing here.
#define CHECK(exc_class, ctx, check, ...) \
do \
{ \
if (!(check)) \
{ \
::std::stringstream ss___; \
::ngraph::write_all_to_stream(ss___, __VA_ARGS__); \
throw exc_class( \
(::ngraph::CheckLocInfo{__FILE__, __LINE__, #check}), (ctx), ss___.str()); \
} \
} while (0)
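To illustrate the expansion (values hypothetical): the __VA_ARGS__ list is folded into the stream one argument per recursive call of write_all_to_stream, and the resulting string becomes the exception's explanation. The trailing underscores on ss___ are the "quick hack" from the commit message, making a collision with a same-named variable at the call site unlikely.

// Roughly what CHECK(MyError, "some context", x == 2, "expected 2, got ", x)
// executes when x == 3 (MyError is a hypothetical CheckFailure subclass):
::std::stringstream ss___;
::ngraph::write_all_to_stream(ss___, "expected 2, got ", 3);
// ss___.str() == "expected 2, got 3"
throw MyError((::ngraph::CheckLocInfo{__FILE__, __LINE__, "x == 2"}),
              ("some context"),
              ss___.str());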
......@@ -422,15 +422,20 @@ NodeVector Node::get_users(bool check_is_used) const
std::string ngraph::node_validation_assertion_string(const Node* node)
{
std::stringstream ss;
ss << "While validating node '" << *node << "' of type '" << node->description() << "'";
ss << "While validating node '" << *node << "'";
return ss.str();
}
void ngraph::check_new_args_count(const Node* node, const NodeVector& new_args)
{
NODE_VALIDATION_ASSERT(node, new_args.size() == node->get_arguments().size())
<< "copy_with_new_args() expected " << node->get_arguments().size() << " argument"
<< (node->get_arguments().size() == 1 ? "" : "s") << " but got " << new_args.size();
NODE_VALIDATION_CHECK(node,
new_args.size() == node->get_arguments().size(),
"copy_with_new_args() expected ",
node->get_arguments().size(),
" argument",
(node->get_arguments().size() == 1 ? "" : "s"),
" but got ",
new_args.size());
}
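Given make_what in check.hpp and the context string above, a failing check here would produce a what() string roughly of this shape (file, line, and node name are hypothetical):

Check 'new_args.size() == node->get_arguments().size()' failed at .../src/ngraph/node.cpp:428:
While validating node 'Add_3':
copy_with_new_args() expected 2 arguments but got 3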
const std::shared_ptr<Node>& ngraph::check_single_output_arg(const std::shared_ptr<Node>& node,
......@@ -459,13 +464,14 @@ std::tuple<element::Type, PartialShape> Node::validate_and_infer_elementwise_arg
{
for (size_t i = 1; i < get_input_size(); ++i)
{
NODE_VALIDATION_ASSERT(
this, element::Type::merge(element_type, element_type, get_input_element_type(i)))
<< "Argument element types are inconsistent.";
NODE_VALIDATION_ASSERT(this,
PartialShape::merge_into(pshape, get_input_partial_shape(i)))
<< "Argument shapes are inconsistent.";
NODE_VALIDATION_CHECK(
this,
element::Type::merge(element_type, element_type, get_input_element_type(i)),
"Argument element types are inconsistent.");
NODE_VALIDATION_CHECK(this,
PartialShape::merge_into(pshape, get_input_partial_shape(i)),
"Argument shapes are inconsistent.");
}
}
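The merge functions used as conditions here report incompatibility by returning false rather than throwing, which is what lets them sit directly inside NODE_VALIDATION_CHECK. A small sketch of that contract, inferred from the call sites above:

element::Type merged = element::f32;
// Compatible types: returns true; merged is unchanged.
bool ok = element::Type::merge(merged, merged, element::f32);
// Incompatible types: returns false, so the check above would throw with
// "Argument element types are inconsistent."
bool inconsistent = !element::Type::merge(merged, merged, element::i64);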
......@@ -478,8 +484,11 @@ void Node::validate_and_infer_elementwise_arithmetic()
element::Type& args_et = std::get<0>(args_et_pshape);
PartialShape& args_pshape = std::get<1>(args_et_pshape);
NODE_VALIDATION_ASSERT(this, args_et.is_dynamic() || args_et != element::boolean)
<< "Arguments cannot have boolean element type (argument element type: " << args_et << ").";
NODE_VALIDATION_CHECK(this,
args_et.is_dynamic() || args_et != element::boolean,
"Arguments cannot have boolean element type (argument element type: ",
args_et,
").");
set_output_type(0, args_et, args_pshape);
}
......@@ -490,9 +499,12 @@ void Node::validate_and_infer_elementwise_logical()
element::Type& args_et = std::get<0>(args_et_pshape);
PartialShape& args_pshape = std::get<1>(args_et_pshape);
NODE_VALIDATION_ASSERT(this, args_et.is_dynamic() || args_et == element::boolean)
<< "Operands for logical operators must have boolean element type but have element type "
<< args_et << ".";
NODE_VALIDATION_CHECK(
this,
args_et.is_dynamic() || args_et == element::boolean,
"Operands for logical operators must have boolean element type but have element type ",
args_et,
".");
set_output_type(0, element::boolean, args_pshape);
}
......@@ -30,6 +30,7 @@
#include "ngraph/assertion.hpp"
#include "ngraph/autodiff/adjoints.hpp"
#include "ngraph/check.hpp"
#include "ngraph/descriptor/input.hpp"
#include "ngraph/descriptor/output.hpp"
#include "ngraph/descriptor/tensor.hpp"
......@@ -287,6 +288,17 @@ namespace ngraph
}
};
class NodeValidationFailure : public CheckFailure
{
public:
NodeValidationFailure(const CheckLocInfo& check_loc_info,
const Node* node,
const std::string& explanation)
: CheckFailure(check_loc_info, node_validation_assertion_string(node), explanation)
{
}
};
class NodeDescription
{
public:
......@@ -315,3 +327,6 @@ namespace ngraph
#define NODE_VALIDATION_FAIL(node) \
NGRAPH_FAIL_STREAM_WITH_LOC(::ngraph::NodeValidationError, \
::ngraph::node_validation_assertion_string(node))
#define NODE_VALIDATION_CHECK(node, cond, ...) \
CHECK(::ngraph::NodeValidationFailure, (node), (cond), __VA_ARGS__)
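A minimal call-site sketch for the new macro, on a hypothetical op (MyOp is not part of this commit); the node argument supplies the "While validating node ..." context, and everything after the condition is streamed into the explanation only on failure:

void MyOp::validate_and_infer_types()
{
    NODE_VALIDATION_CHECK(this,
                          get_input_size() == 2,
                          "Expected 2 inputs, got ",
                          get_input_size(),
                          ".");
}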
......@@ -297,9 +297,14 @@ void op::ConvolutionBackpropData::validate_and_infer_types()
m_window_movement_strides_forward,
m_window_dilation_strides_forward);
NODE_VALIDATION_ASSERT(this, forward_result_shape.compatible(delta_shape))
<< "Inferred forward output shape (" << forward_result_shape << ") does not match shape of "
<< "delta (" << delta_shape << ").";
NODE_VALIDATION_CHECK(this,
forward_result_shape.compatible(delta_shape),
"Inferred forward output shape (",
forward_result_shape,
") does not match shape of ",
"delta (",
delta_shape,
").");
set_output_type(0, forward_result_et, m_data_batch_shape);
......@@ -494,9 +499,14 @@ void op::ConvolutionBackpropFilters::validate_and_infer_types()
m_window_movement_strides_forward,
m_window_dilation_strides_forward);
NODE_VALIDATION_ASSERT(this, forward_result_shape.compatible(delta_shape))
<< "Inferred forward output shape (" << forward_result_shape << ") does not match shape of "
<< "delta (" << delta_shape << ").";
NODE_VALIDATION_CHECK(this,
forward_result_shape.compatible(delta_shape),
"Inferred forward output shape (",
forward_result_shape,
") does not match shape of ",
"delta (",
delta_shape,
").");
set_output_type(0, forward_result_et, m_filters_shape);
......@@ -39,18 +39,29 @@ PartialShape ngraph::infer_windowed_reduction_output_shape(const Node* node,
{
PartialShape data_shape_merged{PartialShape::dynamic()};
NODE_VALIDATION_ASSERT(node,
data_shape_merged.merge_rank(data_shape.rank()) &&
data_shape_merged.merge_rank(data_dilation.size()) &&
data_shape_merged.merge_rank(data_padding_below.size()) &&
data_shape_merged.merge_rank(data_padding_above.size()) &&
data_shape_merged.merge_rank(window_shape.rank()) &&
data_shape_merged.merge_rank(window_strides.size()) &&
data_shape_merged.merge_rank(window_dilation.size()))
<< "Ranks for data shape (" << data_shape << "), data dilation (" << data_dilation
<< "), padding below (" << data_padding_below << "), padding above (" << data_padding_above
<< "), window shape (" << window_shape << "), window strides (" << window_strides
<< "), and window dilation (" << window_dilation << ") do not match.";
NODE_VALIDATION_CHECK(node,
data_shape_merged.merge_rank(data_shape.rank()) &&
data_shape_merged.merge_rank(data_dilation.size()) &&
data_shape_merged.merge_rank(data_padding_below.size()) &&
data_shape_merged.merge_rank(data_padding_above.size()) &&
data_shape_merged.merge_rank(window_shape.rank()) &&
data_shape_merged.merge_rank(window_strides.size()) &&
data_shape_merged.merge_rank(window_dilation.size()),
"Ranks for data shape (",
data_shape,
"), data dilation (",
data_dilation,
"), padding below (",
data_padding_below,
"), padding above (",
data_padding_above,
"), window shape (",
window_shape,
"), window strides (",
window_strides,
"), and window dilation (",
window_dilation,
") do not match.");
PartialShape output_shape = PartialShape::dynamic(data_shape_merged.rank());
......@@ -58,15 +69,27 @@ PartialShape ngraph::infer_windowed_reduction_output_shape(const Node* node,
{
for (size_t i = 0; i < static_cast<size_t>(output_shape.rank()); i++)
{
NODE_VALIDATION_ASSERT(node, data_dilation[i] > 0)
<< "Data dilation (" << data_dilation << ") has zero dimension at axis " << i
<< ".";
NODE_VALIDATION_ASSERT(node, window_strides[i] > 0)
<< "Window strides (" << window_strides << ") has zero dimension at axis " << i
<< ".";
NODE_VALIDATION_ASSERT(node, window_dilation[i] > 0)
<< "Window dilation (" << window_dilation << ") has zero dimension at axis " << i
<< ".";
NODE_VALIDATION_CHECK(node,
data_dilation[i] > 0,
"Data dilation (",
data_dilation,
") has zero dimension at axis ",
i,
".");
NODE_VALIDATION_CHECK(node,
window_strides[i] > 0,
"Window strides (",
window_strides,
") has zero dimension at axis ",
i,
".");
NODE_VALIDATION_CHECK(node,
window_dilation[i] > 0,
"Window dilation (",
window_dilation,
") has zero dimension at axis ",
i,
".");
bool data_dim_static = data_shape.rank().is_static() && data_shape[i].is_static();
bool window_dim_static = window_shape.rank().is_static() && window_shape[i].is_static();
......@@ -77,9 +100,14 @@ PartialShape ngraph::infer_windowed_reduction_output_shape(const Node* node,
data_padded_dilated_dim = (static_cast<ptrdiff_t>(data_dilation[i]) *
(static_cast<ptrdiff_t>(data_shape[i]) - 1)) +
1 + data_padding_below[i] + data_padding_above[i];
NODE_VALIDATION_ASSERT(node, data_padded_dilated_dim > 0)
<< "Data shape after padding and dilation has dimension less than 1 (dim: "
<< data_padded_dilated_dim << ") at axis " << i << ".";
NODE_VALIDATION_CHECK(
node,
data_padded_dilated_dim > 0,
"Data shape after padding and dilation has dimension less than 1 (dim: ",
data_padded_dilated_dim,
") at axis ",
i,
".");
}
ptrdiff_t window_dilated_dim = -1;
......@@ -89,28 +117,42 @@ PartialShape ngraph::infer_windowed_reduction_output_shape(const Node* node,
(static_cast<ptrdiff_t>(window_shape[i]) - 1) +
1;
NODE_VALIDATION_ASSERT(node, window_dilated_dim > 0)
<< "Window after dilation has dimension less than 1 (dim: "
<< window_dilated_dim << ") at axis " << i << ".";
NODE_VALIDATION_ASSERT(node,
is_window_all_in_padding_allowed ||
(window_dilated_dim > data_padding_below[i] &&
window_dilated_dim > data_padding_above[i]))
<< "Window after dilation is sometimes entirely in the padding area for axis "
<< i << " (dilated window dimension: " << window_dilated_dim
<< ", padding below dimension: " << data_padding_below[i]
<< ", padding above dimension: " << data_padding_above[i]
<< ") and this is not "
<< "allowed.";
NODE_VALIDATION_CHECK(node,
window_dilated_dim > 0,
"Window after dilation has dimension less than 1 (dim: ",
window_dilated_dim,
") at axis ",
i,
".");
NODE_VALIDATION_CHECK(
node,
is_window_all_in_padding_allowed ||
(window_dilated_dim > data_padding_below[i] &&
window_dilated_dim > data_padding_above[i]),
"Window after dilation is sometimes entirely in the padding area for axis ",
i,
" (dilated window dimension: ",
window_dilated_dim,
", padding below dimension: ",
data_padding_below[i],
", padding above dimension: ",
data_padding_above[i],
") and this is not ",
"allowed.");
}
if (data_dim_static && window_dim_static)
{
NODE_VALIDATION_ASSERT(node, window_dilated_dim <= data_padded_dilated_dim)
<< "Window after dilation has dimension (dim: " << window_dilated_dim
<< ") larger than the data shape after padding (dim: "
<< data_padded_dilated_dim << ") at axis " << i << ".";
NODE_VALIDATION_CHECK(node,
window_dilated_dim <= data_padded_dilated_dim,
"Window after dilation has dimension (dim: ",
window_dilated_dim,
") larger than the data shape after padding (dim: ",
data_padded_dilated_dim,
") at axis ",
i,
".");
output_shape[i] = ceil_div(static_cast<size_t>(data_padded_dilated_dim) -
static_cast<size_t>(window_dilated_dim) + 1,
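A quick numeric check of the arithmetic above, with hypothetical values: data dim 5, data dilation 1, zero padding, window dim 3, window dilation 1.

// data_padded_dilated_dim = 1 * (5 - 1) + 1 + 0 + 0 = 5
// window_dilated_dim      = 1 * (3 - 1) + 1         = 3
// ceil_div(5 - 3 + 1, stride): 3 for stride 1, 2 for stride 2
// (the stride divisor itself is cut off at the hunk boundary).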
......@@ -139,39 +181,64 @@ std::tuple<element::Type, PartialShape>
{
element::Type et_result;
NODE_VALIDATION_ASSERT(node, element::Type::merge(et_result, et_batch, et_filters))
<< "Element types for data batch and filters do not match (data batch element type: "
<< et_batch << ", filters element type: " << et_filters << ").";
NODE_VALIDATION_CHECK(
node,
element::Type::merge(et_result, et_batch, et_filters),
"Element types for data batch and filters do not match (data batch element type: ",
et_batch,
", filters element type: ",
et_filters,
").");
Rank data_batch_filters_rank{Rank::dynamic()};
NODE_VALIDATION_ASSERT(
node, Rank::merge(data_batch_filters_rank, data_batch_shape.rank(), filters_shape.rank()))
<< "Data batch and filters rank do not match (data batch shape: " << data_batch_shape
<< ", filters shape: " << filters_shape << ").";
NODE_VALIDATION_ASSERT(node,
data_batch_filters_rank.is_dynamic() ||
static_cast<size_t>(data_batch_filters_rank) >= 3)
<< "Data batch and filters must have rank of at least 3 (one batch axis, "
<< "one input-channel axis, and at least one spatial dimension) "
<< "(data batch shape: " << data_batch_shape << ", filters shape: " << filters_shape
<< ").";
NODE_VALIDATION_CHECK(
node,
Rank::merge(data_batch_filters_rank, data_batch_shape.rank(), filters_shape.rank()),
"Data batch and filters rank do not match (data batch shape: ",
data_batch_shape,
", filters shape: ",
filters_shape,
").");
NODE_VALIDATION_CHECK(node,
data_batch_filters_rank.is_dynamic() ||
static_cast<size_t>(data_batch_filters_rank) >= 3,
"Data batch and filters must have rank of at least 3 (one batch axis, ",
"one input-channel axis, and at least one spatial dimension) ",
"(data batch shape: ",
data_batch_shape,
", filters shape: ",
filters_shape,
").");
Rank spatial_rank{Rank::dynamic()};
NODE_VALIDATION_ASSERT(node,
Rank::merge(spatial_rank, spatial_rank, data_batch_filters_rank - 2) &&
Rank::merge(spatial_rank, spatial_rank, data_dilation.size()) &&
Rank::merge(spatial_rank, spatial_rank, data_padding_below.size()) &&
Rank::merge(spatial_rank, spatial_rank, data_padding_above.size()) &&
Rank::merge(spatial_rank, spatial_rank, filter_strides.size()) &&
Rank::merge(spatial_rank, spatial_rank, filter_dilation.size()))
<< "Ranks for data item shape/filters shape (data batch has shape " << data_batch_shape
<< ", so data item rank is " << (data_batch_shape.rank() - 2) << " and filters have shape "
<< filters_shape << ", so filters spatial rank is " << (filters_shape.rank() - 2)
<< "), data dilation (" << data_dilation << "), padding below (" << data_padding_below
<< "), padding above (" << data_padding_above << "), filter strides (" << filter_strides
<< "), and filter dilation (" << filter_dilation << ") do not match.";
NODE_VALIDATION_CHECK(node,
Rank::merge(spatial_rank, spatial_rank, data_batch_filters_rank - 2) &&
Rank::merge(spatial_rank, spatial_rank, data_dilation.size()) &&
Rank::merge(spatial_rank, spatial_rank, data_padding_below.size()) &&
Rank::merge(spatial_rank, spatial_rank, data_padding_above.size()) &&
Rank::merge(spatial_rank, spatial_rank, filter_strides.size()) &&
Rank::merge(spatial_rank, spatial_rank, filter_dilation.size()),
"Ranks for data item shape/filters shape (data batch has shape ",
data_batch_shape,
", so data item rank is ",
(data_batch_shape.rank() - 2),
" and filters have shape ",
filters_shape,
", so filters spatial rank is ",
(filters_shape.rank() - 2),
"), data dilation (",
data_dilation,
"), padding below (",
data_padding_below,
"), padding above (",
data_padding_above,
"), filter strides (",
filter_strides,
"), and filter dilation (",
filter_dilation,
") do not match.");
Dimension batch_size =
(data_batch_shape.rank().is_static() ? data_batch_shape[0] : Dimension::dynamic());
......@@ -202,25 +269,31 @@ std::tuple<element::Type, PartialShape>
}
}
NODE_VALIDATION_ASSERT(node, batch_size.is_dynamic() || static_cast<size_t>(batch_size) > 0)
<< "Batch size is zero.";
NODE_VALIDATION_CHECK(node,
batch_size.is_dynamic() || static_cast<size_t>(batch_size) > 0,
"Batch size is zero.");
Dimension merged_channel_count;
NODE_VALIDATION_ASSERT(
NODE_VALIDATION_CHECK(
node,
Dimension::merge(merged_channel_count, data_channel_count, filter_input_channel_count))
<< "Data batch channel count (" << data_channel_count << ") does not match filter input "
<< "channel count (" << filter_input_channel_count << ").";
NODE_VALIDATION_ASSERT(
node, merged_channel_count.is_dynamic() || static_cast<size_t>(merged_channel_count) > 0)
<< "Data batch channel count and/or filter input channel count is zero.";
NODE_VALIDATION_ASSERT(node,
filter_output_channel_count.is_dynamic() ||
static_cast<size_t>(filter_output_channel_count) > 0)
<< "Filter output channel count is zero.";
Dimension::merge(merged_channel_count, data_channel_count, filter_input_channel_count),
"Data batch channel count (",
data_channel_count,
") does not match filter input ",
"channel count (",
filter_input_channel_count,
").");
NODE_VALIDATION_CHECK(node,
merged_channel_count.is_dynamic() ||
static_cast<size_t>(merged_channel_count) > 0,
"Data batch channel count and/or filter input channel count is zero.");
NODE_VALIDATION_CHECK(node,
filter_output_channel_count.is_dynamic() ||
static_cast<size_t>(filter_output_channel_count) > 0,
"Filter output channel count is zero.");
PartialShape data_output_shape = infer_windowed_reduction_output_shape(node,
data_spatial_shape,
......@@ -255,25 +328,36 @@ PartialShape ngraph::infer_batched_pooling_forward(const Node* node,
const Strides& window_strides,
bool is_window_all_in_padding_allowed)
{
NODE_VALIDATION_ASSERT(node,
data_batch_shape.rank().is_dynamic() ||
static_cast<size_t>(data_batch_shape.rank()) >= 3)
<< "Data batch must have rank of at least 3 (one batch axis, "
<< "one input-channel axis, and at least one spatial dimension) "
<< "(data batch shape: " << data_batch_shape << ").";
NODE_VALIDATION_CHECK(node,
data_batch_shape.rank().is_dynamic() ||
static_cast<size_t>(data_batch_shape.rank()) >= 3,
"Data batch must have rank of at least 3 (one batch axis, ",
"one input-channel axis, and at least one spatial dimension) ",
"(data batch shape: ",
data_batch_shape,
").");
PartialShape data_spatial_shape{PartialShape::dynamic()};
NODE_VALIDATION_ASSERT(node,
data_spatial_shape.merge_rank(data_batch_shape.rank() - 2) &&
data_spatial_shape.merge_rank(data_padding_below.size()) &&
data_spatial_shape.merge_rank(data_padding_above.size()) &&
data_spatial_shape.merge_rank(window_shape.rank()) &&
data_spatial_shape.merge_rank(window_strides.size()))
<< "Ranks for data item shape (data batch has shape " << data_batch_shape
<< ", so data item rank is " << (data_batch_shape.rank() - 2) << "), padding below ("
<< data_padding_below << "), padding above (" << data_padding_above << "), window shape ("
<< window_shape << "), and window strides (" << window_strides << ") do not match.";
NODE_VALIDATION_CHECK(node,
data_spatial_shape.merge_rank(data_batch_shape.rank() - 2) &&
data_spatial_shape.merge_rank(data_padding_below.size()) &&
data_spatial_shape.merge_rank(data_padding_above.size()) &&
data_spatial_shape.merge_rank(window_shape.rank()) &&
data_spatial_shape.merge_rank(window_strides.size()),
"Ranks for data item shape (data batch has shape ",
data_batch_shape,
", so data item rank is ",
(data_batch_shape.rank() - 2),
"), padding below (",
data_padding_below,
"), padding above (",
data_padding_above,
"), window shape (",
window_shape,
"), and window strides (",
window_strides,
") do not match.");
Dimension batch_size{Dimension::dynamic()};
Dimension channel_count{Dimension::dynamic()};
......@@ -289,12 +373,13 @@ PartialShape ngraph::infer_batched_pooling_forward(const Node* node,
data_spatial_shape[i] = data_batch_shape[i + 2];
}
NODE_VALIDATION_ASSERT(node, batch_size.is_dynamic() || static_cast<size_t>(batch_size) > 0)
<< "Batch size is zero.";
NODE_VALIDATION_CHECK(node,
batch_size.is_dynamic() || static_cast<size_t>(batch_size) > 0,
"Batch size is zero.");
NODE_VALIDATION_ASSERT(node,
channel_count.is_dynamic() || static_cast<size_t>(channel_count) > 0)
<< "Channel count is zero.";
NODE_VALIDATION_CHECK(node,
channel_count.is_dynamic() || static_cast<size_t>(channel_count) > 0,
"Channel count is zero.");
// For pooling ops we don't need dilation, so we fill in the identity value (all 1).
Strides data_dilation(static_cast<size_t>(data_spatial_shape.rank()), 1);
......@@ -358,17 +443,19 @@ static std::tuple<element::Type, PartialShape, PartialShape> infer_batch_norm_fo
for (auto& inp : channel_shaped_inputs)
{
NODE_VALIDATION_ASSERT(node, element::Type::merge(et_result, et_result, inp.m_element_type))
<< "Input element types do not match.";
NODE_VALIDATION_CHECK(node,
element::Type::merge(et_result, et_result, inp.m_element_type),
"Input element types do not match.");
}
// Extract channel dimension from input shape.
Dimension channel_dim{Dimension::dynamic()};
NODE_VALIDATION_ASSERT(node,
input_shape.is_dynamic() || static_cast<size_t>(input_shape.rank()) >= 2)
<< "Input argument must have rank of at least 2 (input argument shape: " << input_shape
<< ").";
NODE_VALIDATION_CHECK(node,
input_shape.is_dynamic() || static_cast<size_t>(input_shape.rank()) >= 2,
"Input argument must have rank of at least 2 (input argument shape: ",
input_shape,
").");
if (input_shape.rank().is_static())
{
......@@ -380,20 +467,34 @@ static std::tuple<element::Type, PartialShape, PartialShape> infer_batch_norm_fo
for (auto& inp : channel_shaped_inputs)
{
NODE_VALIDATION_ASSERT(node, PartialShape::merge_into(channel_shape, inp.m_shape))
<< "Shapes for " << channel_input_names << " do not match.";
NODE_VALIDATION_CHECK(node,
PartialShape::merge_into(channel_shape, inp.m_shape),
"Shapes for ",
channel_input_names,
" do not match.");
}
NODE_VALIDATION_ASSERT(node, channel_shape.merge_rank(1)) << "Shape for " << channel_input_names
<< " (" << channel_shape
<< ") does not have rank 1.";
NODE_VALIDATION_ASSERT(node, Dimension::merge(channel_dim, channel_dim, channel_shape[0]))
<< "Input channel dimension (" << channel_dim << ") does not match shape for "
<< channel_input_names << " (" << channel_shape << ").";
NODE_VALIDATION_ASSERT(node, channel_dim.is_dynamic() || static_cast<size_t>(channel_dim) >= 1)
<< "Channel count must be at least 1.";
NODE_VALIDATION_CHECK(node,
channel_shape.merge_rank(1),
"Shape for ",
channel_input_names,
" (",
channel_shape,
") does not have rank 1.");
NODE_VALIDATION_CHECK(node,
Dimension::merge(channel_dim, channel_dim, channel_shape[0]),
"Input channel dimension (",
channel_dim,
") does not match shape for ",
channel_input_names,
" (",
channel_shape,
").");
NODE_VALIDATION_CHECK(node,
channel_dim.is_dynamic() || static_cast<size_t>(channel_dim) >= 1,
"Channel count must be at least 1.");
// Batch result shape is same as the input shape, except we may possibly have inferred more
// information from the channel count via gamma/beta/etc.
......@@ -884,7 +884,7 @@ TEST(partial_shape, infer_windowed_reduction_rank_dynamic_rank_dynamic_zero_data
window_dilation,
is_window_all_in_padding_allowed);
},
NodeValidationError);
NodeValidationFailure);
}
TEST(partial_shape, infer_windowed_reduction_rank_dynamic_rank_dynamic_zero_window_dilation)
......@@ -911,7 +911,7 @@ TEST(partial_shape, infer_windowed_reduction_rank_dynamic_rank_dynamic_zero_wind
window_dilation,
is_window_all_in_padding_allowed);
},
NodeValidationError);
NodeValidationFailure);
}
TEST(partial_shape, infer_windowed_reduction_rank_dynamic_rank_dynamic_zero_window_strides)
......@@ -938,7 +938,7 @@ TEST(partial_shape, infer_windowed_reduction_rank_dynamic_rank_dynamic_zero_wind
window_dilation,
is_window_all_in_padding_allowed);
},
NodeValidationError);
NodeValidationFailure);
}
TEST(partial_shape, infer_windowed_reduction_rank_static_dynamic_rank_dynamic_ok)
......@@ -992,7 +992,7 @@ TEST(partial_shape,
window_dilation,
is_window_all_in_padding_allowed);
},
NodeValidationError);
NodeValidationFailure);
}
TEST(partial_shape, infer_windowed_reduction_rank_static_dynamic_rank_dynamic_neg_padding_ok)
......@@ -1071,7 +1071,7 @@ TEST(partial_shape, infer_windowed_reduction_rank_dynamic_rank_static_dynamic_wi
window_dilation,
is_window_all_in_padding_allowed);
},
NodeValidationError);
NodeValidationFailure);
}
TEST(partial_shape,
......@@ -1100,7 +1100,7 @@ TEST(partial_shape,
window_dilation,
is_window_all_in_padding_allowed);
},
NodeValidationError);
NodeValidationFailure);
}
TEST(partial_shape,
......@@ -1156,7 +1156,7 @@ TEST(partial_shape,
window_dilation,
is_window_all_in_padding_allowed);
},
NodeValidationError);
NodeValidationFailure);
}
TEST(partial_shape,
......@@ -1294,7 +1294,7 @@ TEST(partial_shape, infer_windowed_reduction_rank_static_dynamic_rank_static_dyn
window_dilation,
is_window_all_in_padding_allowed);
},
NodeValidationError);
NodeValidationFailure);
}
TEST(partial_shape,
......@@ -1351,5 +1351,5 @@ TEST(partial_shape,
window_dilation,
is_window_all_in_padding_allowed);
},
NodeValidationError);
NodeValidationFailure);
}
......@@ -212,7 +212,7 @@ TEST(type_prop, batchnorm_training_rank_less_than_2)
auto bc = make_shared<op::BatchNormTraining>(dummy, dummy, dummy, 0.001);
FAIL() << "BatchNorm c-tor should throw for tensors whose rank is less than 2";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Input argument must have rank of at least 2"));
......@@ -233,7 +233,7 @@ TEST(type_prop, batchnorm_training_zero_channel_check)
auto bc = make_shared<op::BatchNormTraining>(data_batch, gamma, beta, 0.001);
FAIL() << "BatchNorm c-tor should throw for tensors w/ zero-dimension channels";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Channel count must be at least 1"));
}
......@@ -254,7 +254,7 @@ TEST(type_prop, batchnorm_training_et_check)
auto bc = make_shared<op::BatchNormTraining>(data_batch, gamma, beta, 0.001);
FAIL() << "BatchNorm c-tor should throw for different element types";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Input element types do not match"));
}
......@@ -275,7 +275,7 @@ TEST(type_prop, batchnorm_training_shape_check)
auto bc = make_shared<op::BatchNormTraining>(data_batch, gamma, beta, 0.001);
FAIL() << "BatchNorm c-tor should throw if gamma and beta shapes don't match";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Shapes for gamma/beta do not match"));
}
......@@ -300,7 +300,7 @@ TEST(type_prop, batchnorm_training_backprop_et_check)
data_batch, gamma, beta, mean, variance, delta, 0.001);
FAIL() << "Deduced type should disagree with c-tor arguments";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Input element types do not match"));
}
......@@ -325,7 +325,7 @@ TEST(type_prop, batchnorm_training_backprop_shape_check)
data_batch, gamma, beta, mean, variance, delta, 0.001);
FAIL() << "Deduced type should disagree with c-tor arguments";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Shapes for gamma/beta/mean/variance do not match"));
......@@ -443,7 +443,7 @@ TEST(type_prop, batchnorm_inference_partial_input_rank_static_dynamic_zero_chann
make_shared<op::BatchNormInference>(data_batch, gamma, beta, mean, variance, epsilon);
FAIL() << "Zero channel count not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Channel count must be at least 1"));
}
......@@ -506,7 +506,7 @@ TEST(type_prop, batchnorm_inference_partial_input_rank_dynamic_some_rank_static_
make_shared<op::BatchNormInference>(data_batch, gamma, beta, mean, variance, epsilon);
FAIL() << "Wrong gamma/beta/mean/variance shape not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
......@@ -545,7 +545,7 @@ TEST(type_prop,
make_shared<op::BatchNormInference>(data_batch, gamma, beta, mean, variance, epsilon);
FAIL() << "Inconsistent gamma/beta/mean/variance shape not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Shapes for gamma/beta/mean/variance do not match"));
......@@ -583,7 +583,7 @@ TEST(type_prop,
make_shared<op::BatchNormInference>(data_batch, gamma, beta, mean, variance, epsilon);
FAIL() << "Inconsistent gamma/beta/mean/variance channel count not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Shapes for gamma/beta/mean/variance do not match"));
......@@ -649,7 +649,7 @@ TEST(type_prop,
make_shared<op::BatchNormInference>(data_batch, gamma, beta, mean, variance, epsilon);
FAIL() << "Inconsistent input/gamma/beta/mean/variance channel count not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Input channel dimension (4) does not match "
......@@ -759,7 +759,7 @@ TEST(type_prop, batchnorm_training_partial_input_rank_static_dynamic_zero_channe
auto bn = make_shared<op::BatchNormTraining>(data_batch, gamma, beta, epsilon);
FAIL() << "Zero channel count not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Channel count must be at least 1"));
}
......@@ -813,7 +813,7 @@ TEST(type_prop, batchnorm_training_partial_input_rank_dynamic_some_rank_static_d
auto bn = make_shared<op::BatchNormTraining>(data_batch, gamma, beta, epsilon);
FAIL() << "Wrong gamma/beta shape not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Shape for gamma/beta ({?,?}) does not have rank 1"));
......@@ -844,7 +844,7 @@ TEST(type_prop,
auto bn = make_shared<op::BatchNormTraining>(data_batch, gamma, beta, epsilon);
FAIL() << "Inconsistent gamma/beta shape not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Shapes for gamma/beta do not match"));
}
......@@ -874,7 +874,7 @@ TEST(type_prop,
auto bn = make_shared<op::BatchNormTraining>(data_batch, gamma, beta, epsilon);
FAIL() << "Inconsistent gamma/beta channel count not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Shapes for gamma/beta do not match"));
}
......@@ -930,7 +930,7 @@ TEST(type_prop,
auto bn = make_shared<op::BatchNormTraining>(data_batch, gamma, beta, epsilon);
FAIL() << "Inconsistent input/gamma/beta channel count not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
......@@ -1049,7 +1049,7 @@ TEST(type_prop, batchnorm_training_backprop_partial_input_rank_static_dynamic_ze
data_batch, gamma, beta, mean, variance, delta, epsilon);
FAIL() << "Zero channel count not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Channel count must be at least 1"));
}
......@@ -1160,7 +1160,7 @@ TEST(type_prop, batchnorm_training_backprop_partial_delta_rank_static_dynamic_ze
data_batch, gamma, beta, mean, variance, delta, epsilon);
FAIL() << "Zero channel count not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Channel count must be at least 1"));
}
......@@ -1237,7 +1237,7 @@ TEST(
data_batch, gamma, beta, mean, variance, delta, epsilon);
FAIL() << "Wrong gamma/beta/mean/variance shape not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
......@@ -1280,7 +1280,7 @@ TEST(
data_batch, gamma, beta, mean, variance, delta, epsilon);
FAIL() << "Wrong gamma/beta/mean/variance shape not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Shapes for gamma/beta/mean/variance do not match"));
......@@ -1322,7 +1322,7 @@ TEST(
data_batch, gamma, beta, mean, variance, delta, epsilon);
FAIL() << "nconsistent gamma/beta/mean/variance channel count not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Shapes for gamma/beta/mean/variance do not match"));
......@@ -1400,7 +1400,7 @@ TEST(
data_batch, gamma, beta, mean, variance, delta, epsilon);
FAIL() << "Inconsistent delta/gamma/beta/mean/variance channel count not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Input channel dimension (4) does not match "
......@@ -2218,7 +2218,7 @@ void test_binary(std::string node_type,
// Should have thrown, so fail if it didn't
FAIL() << "Incompatible view arguments not detected.";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Argument shapes are inconsistent"));
}
......@@ -2237,7 +2237,7 @@ void test_binary(std::string node_type,
// Should have thrown, so fail if it didn't
FAIL() << "Incompatible view arguments not detected.";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Argument element types are inconsistent"));
......@@ -2310,7 +2310,7 @@ void test_binary_logical(std::string node_type,
// Should have thrown, so fail if it didn't
FAIL() << "Incompatible view arguments not detected.";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Argument shapes are inconsistent"));
}
......@@ -2329,7 +2329,7 @@ void test_binary_logical(std::string node_type,
// Should have thrown, so fail if it didn't
FAIL() << "Incompatible view arguments not detected.";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Argument element types are inconsistent"));
......@@ -2463,7 +2463,7 @@ TEST(type_prop, binary_arithmetic_bad_argument_element_types)
// Should have thrown, so fail if it didn't
FAIL() << "Did not detect incorrect element types for arithmetic operator";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Arguments cannot have boolean element type"));
......@@ -2483,7 +2483,7 @@ TEST(type_prop, unary_arithmetic_bad_argument_element_types)
// Should have thrown, so fail if it didn't
FAIL() << "Did not detect incorrect element types for arithmetic operator";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Arguments cannot have boolean element type"));
......@@ -5420,7 +5420,7 @@ TEST(type_prop, conv_invalid_element_type_mismatch)
// Should have thrown, so fail if it didn't
FAIL() << "Invalid input with element type mismatch not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Element types for data batch and filters do not match"));
......@@ -5443,7 +5443,7 @@ TEST(type_prop, conv_invalid_0d_input)
// Should have thrown, so fail if it didn't
FAIL() << "Invalid 0D input not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Data batch and filters must have rank of at least 3 "
......@@ -5468,7 +5468,7 @@ TEST(type_prop, conv_invalid_1d_input)
// Should have thrown, so fail if it didn't
FAIL() << "Invalid 1D input not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Data batch and filters must have rank of at least 3 "
......@@ -5493,7 +5493,7 @@ TEST(type_prop, conv_invalid_2d_input)
// Should have thrown, so fail if it didn't
FAIL() << "Invalid 2D input not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Data batch and filters must have rank of at least 3 "
......@@ -5518,7 +5518,7 @@ TEST(type_prop, conv_invalid_0_batch_size)
// Should have thrown, so fail if it didn't
FAIL() << "Invalid input with 0 batch size not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Batch size is zero"));
}
......@@ -5540,7 +5540,7 @@ TEST(type_prop, conv_invalid_0_input_channels)
// Should have thrown, so fail if it didn't
FAIL() << "Invalid input with 0 input channels not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
......@@ -5564,7 +5564,7 @@ TEST(type_prop, conv_invalid_wrong_number_of_filter_dimensions_too_many)
// Should have thrown, so fail if it didn't
FAIL() << "Invalid input with too many filter dimensions not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Data batch and filters rank do not match"));
}
......@@ -5586,7 +5586,7 @@ TEST(type_prop, conv_invalid_wrong_number_of_filter_dimensions_too_few)
// Should have thrown, so fail if it didn't
FAIL() << "Invalid input with too few filter dimensions not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Data batch and filters rank do not match"));
}
......@@ -5608,7 +5608,7 @@ TEST(type_prop, conv_invalid_0_output_channels)
// Should have thrown, so fail if it didn't
FAIL() << "Invalid input with 0 output channels not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Filter output channel count is zero"));
}
......@@ -5630,7 +5630,7 @@ TEST(type_prop, conv_invalid_input_channel_mismatch)
// Should have thrown, so fail if it didn't
FAIL() << "Invalid input with channel count mismatch not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
......@@ -5655,7 +5655,7 @@ TEST(type_prop, conv_invalid_movement_stride_rank)
// Should have thrown, so fail if it didn't
FAIL() << "Invalid input with wrong movement stride rank not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
......@@ -5684,7 +5684,7 @@ TEST(type_prop, conv_invalid_window_dilation_stride_rank)
// Should have thrown, so fail if it didn't
FAIL() << "Invalid input with wrong window dilation stride rank not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
......@@ -5719,7 +5719,7 @@ TEST(type_prop, conv_invalid_data_dilation_stride_rank)
// Should have thrown, so fail if it didn't
FAIL() << "Invalid input with wrong data dilation stride rank not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
......@@ -5753,7 +5753,7 @@ TEST(type_prop, conv_invalid_padding_below_rank)
// Should have thrown, so fail if it didn't
FAIL() << "Invalid input with wrong padding-below rank not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
......@@ -5787,7 +5787,7 @@ TEST(type_prop, conv_invalid_padding_above_rank)
// Should have thrown, so fail if it didn't
FAIL() << "Invalid input with wrong padding-above rank not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
......@@ -5821,7 +5821,7 @@ TEST(type_prop, conv_invalid_input_spatial_size_negative_after_padding)
// Should have thrown, so fail if it didn't
FAIL() << "Invalid input with negative-length post-padding spatial axis not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Data shape after padding and dilation has dimension less "
......@@ -5850,7 +5850,7 @@ TEST(type_prop, conv_invalid_input_spatial_size_zero_after_padding)
// Should have thrown, so fail if it didn't
FAIL() << "Invalid input with zero-length post-padding spatial axis not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Data shape after padding and dilation has dimension less "
......@@ -5874,7 +5874,7 @@ TEST(type_prop, conv_invalid_input_spatial_size_0)
// Should have thrown, so fail if it didn't
FAIL() << "Invalid input with zero-length spatial axis not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Data shape after padding and dilation has "
......@@ -5898,7 +5898,7 @@ TEST(type_prop, conv_invalid_window_size_0)
// Should have thrown, so fail if it didn't
FAIL() << "Invalid input with zero-length window axis not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
......@@ -5922,7 +5922,7 @@ TEST(type_prop, conv_invalid_window_dilation_stride_0)
// Should have thrown, so fail if it didn't
FAIL() << "Invalid input with wrong 0-length window dilation stride axis not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
......@@ -5952,7 +5952,7 @@ TEST(type_prop, conv_invalid_data_dilation_stride_0)
// Should have thrown, so fail if it didn't
FAIL() << "Invalid input with wrong 0-length data dilation stride axis not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
......@@ -5976,7 +5976,7 @@ TEST(type_prop, conv_invalid_dilated_window_too_large)
// Should have thrown, so fail if it didn't
FAIL() << "Invalid input with oversized dilated window not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Window after dilation has dimension (dim: 9) larger than "
......@@ -6000,7 +6000,7 @@ TEST(type_prop, conv_invalid_movement_stride_0)
// Should have thrown, so fail if it didn't
FAIL() << "Invalid input with wrong 0-length movement stride axis not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
......@@ -6062,7 +6062,7 @@ TEST(type_prop, conv_partial_rank_dynamic_rank_dynamic_window_strides_rank_wrong
FAIL() << "Window stride rank mismatch not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
......@@ -6103,7 +6103,7 @@ TEST(type_prop, conv_partial_rank_dynamic_rank_dynamic_window_strides_dim_zero)
FAIL() << "Window stride with dimension zero not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
......@@ -6140,7 +6140,7 @@ TEST(type_prop, conv_partial_rank_dynamic_rank_dynamic_window_dilation_rank_wron
FAIL() << "Window dilation rank mismatch not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
......@@ -6181,7 +6181,7 @@ TEST(type_prop, conv_partial_rank_dynamic_rank_dynamic_window_dilation_dim_zero)
FAIL() << "Window dilation with dimension zero not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
......@@ -6218,7 +6218,7 @@ TEST(type_prop, conv_partial_rank_dynamic_rank_dynamic_padding_below_rank_wrong)
FAIL() << "Padding below rank mismatch not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
......@@ -6259,7 +6259,7 @@ TEST(type_prop, conv_partial_rank_dynamic_rank_dynamic_padding_above_rank_wrong)
FAIL() << "Padding above rank mismatch not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
......@@ -6300,7 +6300,7 @@ TEST(type_prop, conv_partial_rank_dynamic_rank_dynamic_data_dilation_rank_wrong)
FAIL() << "Data dilation rank mismatch not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
......@@ -6341,7 +6341,7 @@ TEST(type_prop, conv_partial_rank_dynamic_rank_dynamic_data_dilation_dim_zero)
FAIL() << "Data dilation with dimension zero not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
......@@ -6403,7 +6403,7 @@ TEST(type_prop, conv_partial_rank_static_dynamic_rank_dynamic_data_batch_rank_wr
FAIL() << "Data batch rank mismatch not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
......@@ -6473,7 +6473,7 @@ TEST(type_prop, conv_partial_rank_static_dynamic_rank_dynamic_batch_size_known_z
FAIL() << "Zero batch size not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Batch size is zero"));
}
......@@ -6535,7 +6535,7 @@ TEST(type_prop, conv_partial_rank_static_dynamic_rank_dynamic_input_channel_coun
FAIL() << "Zero input channel count not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
......@@ -6599,7 +6599,7 @@ TEST(type_prop, conv_partial_rank_dynamic_rank_static_dynamic_output_channel_cou
FAIL() << "Zero output channel count not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Filter output channel count is zero"));
}
......@@ -6659,7 +6659,7 @@ TEST(type_prop, conv_partial_rank_dynamic_rank_static_dynamic_input_channel_coun
FAIL() << "Zero input channel count not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
......@@ -6721,7 +6721,7 @@ TEST(type_prop, conv_partial_rank_static_dynamic_rank_static_dynamic_arg_ranks_m
FAIL() << "Argument rank mismatch not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Data batch and filters rank do not match (data batch "
......@@ -6786,7 +6786,7 @@ TEST(type_prop, conv_partial_rank_static_dynamic_rank_static_dynamic_input_chann
FAIL() << "Input channel count mismatch not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
......@@ -6879,7 +6879,7 @@ TEST(
FAIL() << "Oversize filter not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Window after dilation has dimension (dim: 201) larger "
......@@ -7002,7 +7002,7 @@ TEST(
FAIL() << "Oversize filter after window dilation not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Window after dilation has dimension (dim: 201) larger "
......@@ -7041,7 +7041,7 @@ TEST(
FAIL() << "Zero dimension in data batch not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Data shape after padding and dilation has "
......@@ -7108,7 +7108,7 @@ TEST(
FAIL() << "Zero padded dimension in data batch not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Data shape after padding and dilation has "
......@@ -7147,7 +7147,7 @@ TEST(
FAIL() << "Negative padded dimension in data batch not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Data shape after padding and dilation has dimension less "
......@@ -7302,7 +7302,7 @@ TEST(type_prop, max_pool_invalid_0d_input)
// Should have thrown, so fail if it didn't
FAIL() << "Invalid 0D input not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Data batch must have rank of at least 3"));
}
......@@ -7324,7 +7324,7 @@ TEST(type_prop, max_pool_invalid_1d_input)
// Should have thrown, so fail if it didn't
FAIL() << "Invalid 1D input not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Data batch must have rank of at least 3"));
}
......@@ -7346,7 +7346,7 @@ TEST(type_prop, max_pool_invalid_2d_input)
// Should have thrown, so fail if it didn't
FAIL() << "Invalid 2D input not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Data batch must have rank of at least 3"));
}
......@@ -7368,7 +7368,7 @@ TEST(type_prop, max_pool_invalid_0_batch_size)
// Should have thrown, so fail if it didn't
FAIL() << "Invalid input with 0 batch size not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Batch size is zero"));
}
......@@ -7390,7 +7390,7 @@ TEST(type_prop, max_pool_invalid_0_channels)
// Should have thrown, so fail if it didn't
FAIL() << "Invalid input with 0 channels not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Channel count is zero"));
}
......@@ -7412,7 +7412,7 @@ TEST(type_prop, max_pool_invalid_wrong_number_of_window_dimensions_too_many)
// Should have thrown, so fail if it didn't
FAIL() << "Invalid input with too many window dimensions not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
......@@ -7439,7 +7439,7 @@ TEST(type_prop, max_pool_invalid_wrong_number_of_window_dimensions_too_few)
// Should have thrown, so fail if it didn't
FAIL() << "Invalid input with too few window dimensions not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
......@@ -7467,7 +7467,7 @@ TEST(type_prop, max_pool_invalid_movement_stride_rank)
// Should have thrown, so fail if it didn't
FAIL() << "Invalid input with wrong movement stride rank not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
......@@ -7494,7 +7494,7 @@ TEST(type_prop, max_pool_invalid_input_data_size_0)
// Should have thrown, so fail if it didn't
FAIL() << "Invalid input with zero-length spatial axis not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Data shape after padding and dilation has "
......@@ -7518,7 +7518,7 @@ TEST(type_prop, max_pool_invalid_window_size_0)
// Should have thrown, so fail if it didn't
FAIL() << "Invalid input with zero-length window axis not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
......@@ -7542,7 +7542,7 @@ TEST(type_prop, max_pool_invalid_dilated_too_large)
// Should have thrown, so fail if it didn't
FAIL() << "Invalid input with oversized window not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Window after dilation has dimension (dim: 9) larger than "
......@@ -7567,7 +7567,7 @@ TEST(type_prop, max_pool_invalid_movement_stride_0)
// Should have thrown, so fail if it didn't
FAIL() << "Invalid input with 0-length movement stride axis not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
......@@ -7611,7 +7611,7 @@ TEST(type_prop, max_pool_partial_rank_dynamic_attrib_rank_mismatch)
param, window_shape, window_movement_strides, padding_below, padding_above);
FAIL() << "Mismatch of attribute ranks not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
......@@ -7675,7 +7675,7 @@ TEST(type_prop, max_pool_partial_rank_static_dynamic_attrib_rank_mismatch)
param, window_shape, window_movement_strides, padding_below, padding_above);
FAIL() << "Mismatch of attribute ranks not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
......@@ -7706,7 +7706,7 @@ TEST(type_prop, max_pool_partial_rank_static_dynamic_window_not_too_big)
param, window_shape, window_movement_strides, padding_below, padding_above);
FAIL() << "Oversized window not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Window after dilation has dimension (dim: 9) larger than "
......@@ -8371,7 +8371,7 @@ TEST(type_prop, avg_pool_invalid_0d_input)
// Should have thrown, so fail if it didn't
FAIL() << "Invalid 0D input not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
"Data batch must have rank of at least 3 (one batch axis, one "
......@@ -8395,7 +8395,7 @@ TEST(type_prop, avg_pool_invalid_1d_input)
// Should have thrown, so fail if it didn't
FAIL() << "Invalid 1D input not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
"Data batch must have rank of at least 3 (one batch axis, one "
......@@ -8419,7 +8419,7 @@ TEST(type_prop, avg_pool_invalid_2d_input)
// Should have thrown, so fail if it didn't
FAIL() << "Invalid 2D input not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
"Data batch must have rank of at least 3 (one batch axis, one "
@@ -8443,7 +8443,7 @@ TEST(type_prop, avg_pool_invalid_0_batch_size)
// Should have thrown, so fail if it didn't
FAIL() << "Invalid input with 0 batch size not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), "Batch size is zero");
}
@@ -8465,7 +8465,7 @@ TEST(type_prop, avg_pool_invalid_0_channels)
// Should have thrown, so fail if it didn't
FAIL() << "Invalid input with 0 channels not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), "Channel count is zero");
}
@@ -8487,7 +8487,7 @@ TEST(type_prop, avg_pool_invalid_wrong_number_of_window_dimensions_too_many)
// Should have thrown, so fail if it didn't
FAIL() << "Invalid input with too many window dimensions not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
"Ranks for data item shape (data batch has shape {6,2,10,10}, so data "
@@ -8513,7 +8513,7 @@ TEST(type_prop, avg_pool_invalid_wrong_number_of_window_dimensions_too_few)
// Should have thrown, so fail if it didn't
FAIL() << "Invalid input with too few window dimensions not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
"Ranks for data item shape (data batch has shape {6,2,10,10}, so data "
@@ -8540,7 +8540,7 @@ TEST(type_prop, avg_pool_invalid_movement_stride_rank)
// Should have thrown, so fail if it didn't
FAIL() << "Invalid input with wrong movement stride rank not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
"Ranks for data item shape (data batch has shape {6,2,10,10}, so data "
@@ -8570,7 +8570,7 @@ TEST(type_prop, avg_pool_invalid_padding_below_rank)
// Should have thrown, so fail if it didn't
FAIL() << "Invalid input with wrong below-padding rank not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
"Ranks for data item shape (data batch has shape {6,2,10,10}, so data "
@@ -8600,7 +8600,7 @@ TEST(type_prop, avg_pool_invalid_padding_above_rank)
// Should have thrown, so fail if it didn't
FAIL() << "Invalid input with wrong above-padding rank not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
"Ranks for data item shape (data batch has shape {6,2,10,10}, so data "
@@ -8626,7 +8626,7 @@ TEST(type_prop, avg_pool_invalid_input_item_size_0)
// Should have thrown, so fail if it didn't
FAIL() << "Invalid input with zero-length spatial axis not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
@@ -8650,7 +8650,7 @@ TEST(type_prop, avg_pool_invalid_window_size_0)
// Should have thrown, so fail if it didn't
FAIL() << "Invalid input with zero-length window axis not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
"Window after dilation has dimension less than 1 (dim: 0) at axis 1");
@@ -8673,7 +8673,7 @@ TEST(type_prop, avg_pool_invalid_dilated_too_large)
// Should have thrown, so fail if it didn't
FAIL() << "Invalid input with oversized window not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
"Window after dilation has dimension (dim: 9) larger than the data "
@@ -8712,7 +8712,7 @@ TEST(type_prop, avg_pool_invalid_movement_stride_0)
// Should have thrown, so fail if it didn't
FAIL() << "Invalid input with 0-length movement stride axis not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
"Window strides (Strides{0, 1}) has zero dimension at axis 0");
@@ -8765,7 +8765,7 @@ TEST(type_prop, avg_pool_partial_rank_dynamic_attrib_rank_mismatch)
include_padding_in_average);
FAIL() << "Mismatch of attribute ranks not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
@@ -8844,7 +8844,7 @@ TEST(type_prop, avg_pool_partial_rank_static_dynamic_attrib_rank_mismatch)
include_padding_in_average);
FAIL() << "Mismatch of attribute ranks not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
@@ -8880,7 +8880,7 @@ TEST(type_prop, avg_pool_partial_rank_static_dynamic_window_not_too_big)
include_padding_in_average);
FAIL() << "Oversized window not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Window after dilation has dimension (dim: 9) larger than "
@@ -8935,7 +8935,7 @@ TEST(type_prop, avg_pool_partial_rank_static_dynamic_window_in_padding)
include_padding_in_average);
FAIL() << "Window in padding not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Window after dilation has dimension (dim: 9) larger than "
@@ -10030,7 +10030,7 @@ TEST(type_prop, binary_elementwise_arithmetic_left_rank_static_dynamic_inconsist
auto add = make_shared<op::Add>(a, b);
FAIL() << "Inconsistent partial shapes not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), "Argument shapes are inconsistent");
}
@@ -10050,7 +10050,7 @@ TEST(type_prop, binary_elementwise_arithmetic_right_rank_static_dynamic_inconsis
auto add = make_shared<op::Add>(a, b);
FAIL() << "Inconsistent partial shapes not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), "Argument shapes are inconsistent");
}
@@ -10070,7 +10070,7 @@ TEST(type_prop, binary_elementwise_arithmetic_both_rank_static_dynamic_inconsist
auto add = make_shared<op::Add>(a, b);
FAIL() << "Inconsistent partial shapes not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), "Argument shapes are inconsistent");
}
@@ -10090,7 +10090,7 @@ TEST(type_prop, binary_elementwise_arithmetic_left_rank_static_dynamic_different
auto add = make_shared<op::Add>(a, b);
FAIL() << "Inconsistent partial shapes not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), "Argument shapes are inconsistent");
}
@@ -10110,7 +10110,7 @@ TEST(type_prop, binary_elementwise_arithmetic_right_rank_static_dynamic_differen
auto add = make_shared<op::Add>(a, b);
FAIL() << "Inconsistent partial shapes not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), "Argument shapes are inconsistent");
}
@@ -10130,7 +10130,7 @@ TEST(type_prop, binary_elementwise_arithmetic_both_rank_static_dynamic_different
auto add = make_shared<op::Add>(a, b);
FAIL() << "Inconsistent partial shapes not detected";
}
catch (const NodeValidationError& error)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), "Argument shapes are inconsistent");
}
@@ -12095,3 +12095,56 @@ TEST(type_prop, all_partial_rank_static_dynamic_axes_oob)
FAIL() << "Deduced type check failed for unexpected reason";
}
}
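// Disabled-by-default microbenchmark: measures the cost of constructing an Add
// node, which exercises the node validation path. Run it via gtest's
// --gtest_also_run_disabled_tests flag.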
TEST(type_prop, DISABLED_benchmark_type_prop_add)
{
auto p1 = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3, 4});
auto p2 = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3, 4});
constexpr size_t num_iterations = 1000000;
size_t total_nanosec = 0;
stopwatch sw;
for (size_t i = 0; i < num_iterations; i++)
{
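        // The stopwatch brackets only the node construction, so loop overhead is
        // excluded from the accumulated time.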
sw.start();
auto n = make_shared<op::Add>(p1, p2);
sw.stop();
total_nanosec += sw.get_nanoseconds();
}
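    // Imbue with the user's default locale so the large nanosecond counts print
    // with thousands separators (in most locales).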
std::cout.imbue(std::locale(""));
std::cout << "Constructed " << std::fixed << num_iterations << " Add ops in " << std::fixed
<< total_nanosec << " ns" << std::endl;
}
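// Companion microbenchmark for Convolution, whose shape inference does far more
// work than Add's; comparing the two gives a rough sense of validation overhead.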
TEST(type_prop, DISABLED_benchmark_type_prop_convolution)
{
auto d = make_shared<op::Parameter>(element::f32, Shape{64, 3, 224, 224});
auto f = make_shared<op::Parameter>(element::f32, Shape{64, 3, 7, 7});
auto strides = Strides{1, 1};
auto dilation = Strides{1, 1};
auto padding_below = CoordinateDiff{1, 1};
auto padding_above = CoordinateDiff{1, 1};
constexpr size_t num_iterations = 1000000;
size_t total_nanosec = 0;
stopwatch sw;
for (size_t i = 0; i < num_iterations; i++)
{
sw.start();
auto n =
make_shared<op::Convolution>(d, f, strides, dilation, padding_below, padding_above);
sw.stop();
total_nanosec += sw.get_nanoseconds();
}
std::cout.imbue(std::locale(""));
std::cout << "Constructed " << std::fixed << num_iterations << " Convolution ops in "
<< std::fixed << total_nanosec << " ns" << std::endl;
}