Commit 6ffbd254 authored by Adam Procter's avatar Adam Procter Committed by Robert Kimball

New macros for faster node validation checks (#2537)

* Skeleton for faster validation asserts

* Switch to __VA_ARGS__ for compatibility, remove -Wno-variadic-macros

* Add benchmarks for constructing Add and Convolution

* Quick hack to avoid shadowing inside the CHECK macro

* Quick hack to avoid inadvertent capture inside the macro

* Update convolution, and change a bunch of tests to anticipate the new error class
parent eb34c884
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <exception>
#include <sstream>
#include <vector>
#include "ngraph/except.hpp"
namespace ngraph
{
static inline std::ostream& write_all_to_stream(std::ostream& str) { return str; }
template <typename T, typename... TS>
static inline std::ostream& write_all_to_stream(std::ostream& str, const T& arg, TS... args)
{
return write_all_to_stream(str << arg, args...);
}
struct CheckLocInfo
{
const char* file;
int line;
const char* check_string;
};
/// Base class for check failure exceptions.
class CheckFailure : public ngraph_error
{
public:
CheckFailure(const CheckLocInfo& check_loc_info,
const std::string& context_info,
const std::string& explanation)
: ngraph_error(make_what(check_loc_info, context_info, explanation))
{
}
private:
static std::string make_what(const CheckLocInfo& check_loc_info,
const std::string& context_info,
const std::string& explanation)
{
std::stringstream ss;
ss << "Check '" << check_loc_info.check_string << "' failed at " << check_loc_info.file
<< ":" << check_loc_info.line << ":" << std::endl;
ss << context_info << ":" << std::endl;
ss << explanation << std::endl;
return ss.str();
}
};
}
// TODO(amprocte): refactor so we don't have to introduce a locally-scoped variable and risk
// shadowing here.
#define CHECK(exc_class, ctx, check, ...) \
do \
{ \
if (!(check)) \
{ \
::std::stringstream ss___; \
::ngraph::write_all_to_stream(ss___, __VA_ARGS__); \
throw exc_class( \
(::ngraph::CheckLocInfo{__FILE__, __LINE__, #check}), (ctx), ss___.str()); \
} \
} while (0)
......@@ -422,15 +422,20 @@ NodeVector Node::get_users(bool check_is_used) const
std::string ngraph::node_validation_assertion_string(const Node* node)
{
std::stringstream ss;
ss << "While validating node '" << *node << "' of type '" << node->description() << "'";
ss << "While validating node '" << *node << "'";
return ss.str();
}
void ngraph::check_new_args_count(const Node* node, const NodeVector& new_args)
{
NODE_VALIDATION_ASSERT(node, new_args.size() == node->get_arguments().size())
<< "copy_with_new_args() expected " << node->get_arguments().size() << " argument"
<< (node->get_arguments().size() == 1 ? "" : "s") << " but got " << new_args.size();
NODE_VALIDATION_CHECK(node,
new_args.size() == node->get_arguments().size(),
"copy_with_new_args() expected ",
node->get_arguments().size(),
" argument",
(node->get_arguments().size() == 1 ? "" : "s"),
" but got ",
new_args.size());
}
const std::shared_ptr<Node>& ngraph::check_single_output_arg(const std::shared_ptr<Node>& node,
......@@ -459,13 +464,14 @@ std::tuple<element::Type, PartialShape> Node::validate_and_infer_elementwise_arg
{
for (size_t i = 1; i < get_input_size(); ++i)
{
NODE_VALIDATION_ASSERT(
this, element::Type::merge(element_type, element_type, get_input_element_type(i)))
<< "Argument element types are inconsistent.";
NODE_VALIDATION_ASSERT(this,
PartialShape::merge_into(pshape, get_input_partial_shape(i)))
<< "Argument shapes are inconsistent.";
NODE_VALIDATION_CHECK(
this,
element::Type::merge(element_type, element_type, get_input_element_type(i)),
"Argument element types are inconsistent.");
NODE_VALIDATION_CHECK(this,
PartialShape::merge_into(pshape, get_input_partial_shape(i)),
"Argument shapes are inconsistent.");
}
}
......@@ -478,8 +484,11 @@ void Node::validate_and_infer_elementwise_arithmetic()
element::Type& args_et = std::get<0>(args_et_pshape);
PartialShape& args_pshape = std::get<1>(args_et_pshape);
NODE_VALIDATION_ASSERT(this, args_et.is_dynamic() || args_et != element::boolean)
<< "Arguments cannot have boolean element type (argument element type: " << args_et << ").";
NODE_VALIDATION_CHECK(this,
args_et.is_dynamic() || args_et != element::boolean,
"Arguments cannot have boolean element type (argument element type: ",
args_et,
").");
set_output_type(0, args_et, args_pshape);
}
......@@ -490,9 +499,12 @@ void Node::validate_and_infer_elementwise_logical()
element::Type& args_et = std::get<0>(args_et_pshape);
PartialShape& args_pshape = std::get<1>(args_et_pshape);
NODE_VALIDATION_ASSERT(this, args_et.is_dynamic() || args_et == element::boolean)
<< "Operands for logical operators must have boolean element type but have element type "
<< args_et << ".";
NODE_VALIDATION_CHECK(
this,
args_et.is_dynamic() || args_et == element::boolean,
"Operands for logical operators must have boolean element type but have element type ",
args_et,
".");
set_output_type(0, element::boolean, args_pshape);
}
......@@ -30,6 +30,7 @@
#include "ngraph/assertion.hpp"
#include "ngraph/autodiff/adjoints.hpp"
#include "ngraph/check.hpp"
#include "ngraph/descriptor/input.hpp"
#include "ngraph/descriptor/output.hpp"
#include "ngraph/descriptor/tensor.hpp"
......@@ -287,6 +288,17 @@ namespace ngraph
}
};
class NodeValidationFailure : public CheckFailure
{
public:
NodeValidationFailure(const CheckLocInfo& check_loc_info,
const Node* node,
const std::string& explanation)
: CheckFailure(check_loc_info, node_validation_assertion_string(node), explanation)
{
}
};
class NodeDescription
{
public:
......@@ -315,3 +327,6 @@ namespace ngraph
#define NODE_VALIDATION_FAIL(node) \
NGRAPH_FAIL_STREAM_WITH_LOC(::ngraph::NodeValidationError, \
::ngraph::node_validation_assertion_string(node))
#define NODE_VALIDATION_CHECK(node, cond, ...) \
CHECK(::NodeValidationFailure, (node), (cond), __VA_ARGS__)
......@@ -297,9 +297,14 @@ void op::ConvolutionBackpropData::validate_and_infer_types()
m_window_movement_strides_forward,
m_window_dilation_strides_forward);
NODE_VALIDATION_ASSERT(this, forward_result_shape.compatible(delta_shape))
<< "Inferred forward output shape (" << forward_result_shape << ") does not match shape of "
<< "delta (" << delta_shape << ").";
NODE_VALIDATION_CHECK(this,
forward_result_shape.compatible(delta_shape),
"Inferred forward output shape (",
forward_result_shape,
") does not match shape of ",
"delta (",
delta_shape,
").");
set_output_type(0, forward_result_et, m_data_batch_shape);
......@@ -494,9 +499,14 @@ void op::ConvolutionBackpropFilters::validate_and_infer_types()
m_window_movement_strides_forward,
m_window_dilation_strides_forward);
NODE_VALIDATION_ASSERT(this, forward_result_shape.compatible(delta_shape))
<< "Inferred forward output shape (" << forward_result_shape << ") does not match shape of "
<< "delta (" << delta_shape << ").";
NODE_VALIDATION_CHECK(this,
forward_result_shape.compatible(delta_shape),
"Inferred forward output shape (",
forward_result_shape,
") does not match shape of ",
"delta (",
delta_shape,
").");
set_output_type(0, forward_result_et, m_filters_shape);
......
This diff is collapsed.
......@@ -884,7 +884,7 @@ TEST(partial_shape, infer_windowed_reduction_rank_dynamic_rank_dynamic_zero_data
window_dilation,
is_window_all_in_padding_allowed);
},
NodeValidationError);
NodeValidationFailure);
}
TEST(partial_shape, infer_windowed_reduction_rank_dynamic_rank_dynamic_zero_window_dilation)
......@@ -911,7 +911,7 @@ TEST(partial_shape, infer_windowed_reduction_rank_dynamic_rank_dynamic_zero_wind
window_dilation,
is_window_all_in_padding_allowed);
},
NodeValidationError);
NodeValidationFailure);
}
TEST(partial_shape, infer_windowed_reduction_rank_dynamic_rank_dynamic_zero_window_strides)
......@@ -938,7 +938,7 @@ TEST(partial_shape, infer_windowed_reduction_rank_dynamic_rank_dynamic_zero_wind
window_dilation,
is_window_all_in_padding_allowed);
},
NodeValidationError);
NodeValidationFailure);
}
TEST(partial_shape, infer_windowed_reduction_rank_static_dynamic_rank_dynamic_ok)
......@@ -992,7 +992,7 @@ TEST(partial_shape,
window_dilation,
is_window_all_in_padding_allowed);
},
NodeValidationError);
NodeValidationFailure);
}
TEST(partial_shape, infer_windowed_reduction_rank_static_dynamic_rank_dynamic_neg_padding_ok)
......@@ -1071,7 +1071,7 @@ TEST(partial_shape, infer_windowed_reduction_rank_dynamic_rank_static_dynamic_wi
window_dilation,
is_window_all_in_padding_allowed);
},
NodeValidationError);
NodeValidationFailure);
}
TEST(partial_shape,
......@@ -1100,7 +1100,7 @@ TEST(partial_shape,
window_dilation,
is_window_all_in_padding_allowed);
},
NodeValidationError);
NodeValidationFailure);
}
TEST(partial_shape,
......@@ -1156,7 +1156,7 @@ TEST(partial_shape,
window_dilation,
is_window_all_in_padding_allowed);
},
NodeValidationError);
NodeValidationFailure);
}
TEST(partial_shape,
......@@ -1294,7 +1294,7 @@ TEST(partial_shape, infer_windowed_reduction_rank_static_dynamic_rank_static_dyn
window_dilation,
is_window_all_in_padding_allowed);
},
NodeValidationError);
NodeValidationFailure);
}
TEST(partial_shape,
......@@ -1351,5 +1351,5 @@ TEST(partial_shape,
window_dilation,
is_window_all_in_padding_allowed);
},
NodeValidationError);
NodeValidationFailure);
}
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment