Commit 6ffbd254 authored by Adam Procter's avatar Adam Procter Committed by Robert Kimball

New macros for faster node validation checks (#2537)

* Skeleton for faster validation asserts

* Switch to __VA_ARGS__ for compatibility, remove -Wno-variadic-macros

* Add benchmarks for constructing Add and Convolution

* Quick hack to avoid shadowing inside the CHECK macro

* Quick hack to avoid inadvertent capture inside the macro

* Update convolution, and change a bunch of tests to anticipate the new error class
parent eb34c884
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <exception>
#include <sstream>
#include <vector>
#include "ngraph/except.hpp"
namespace ngraph
{
static inline std::ostream& write_all_to_stream(std::ostream& str) { return str; }
template <typename T, typename... TS>
static inline std::ostream& write_all_to_stream(std::ostream& str, const T& arg, TS... args)
{
return write_all_to_stream(str << arg, args...);
}
struct CheckLocInfo
{
const char* file;
int line;
const char* check_string;
};
/// Base class for check failure exceptions.
class CheckFailure : public ngraph_error
{
public:
CheckFailure(const CheckLocInfo& check_loc_info,
const std::string& context_info,
const std::string& explanation)
: ngraph_error(make_what(check_loc_info, context_info, explanation))
{
}
private:
static std::string make_what(const CheckLocInfo& check_loc_info,
const std::string& context_info,
const std::string& explanation)
{
std::stringstream ss;
ss << "Check '" << check_loc_info.check_string << "' failed at " << check_loc_info.file
<< ":" << check_loc_info.line << ":" << std::endl;
ss << context_info << ":" << std::endl;
ss << explanation << std::endl;
return ss.str();
}
};
}
// TODO(amprocte): refactor so we don't have to introduce a locally-scoped variable and risk
// shadowing here.
#define CHECK(exc_class, ctx, check, ...) \
do \
{ \
if (!(check)) \
{ \
::std::stringstream ss___; \
::ngraph::write_all_to_stream(ss___, __VA_ARGS__); \
throw exc_class( \
(::ngraph::CheckLocInfo{__FILE__, __LINE__, #check}), (ctx), ss___.str()); \
} \
} while (0)
...@@ -422,15 +422,20 @@ NodeVector Node::get_users(bool check_is_used) const ...@@ -422,15 +422,20 @@ NodeVector Node::get_users(bool check_is_used) const
std::string ngraph::node_validation_assertion_string(const Node* node) std::string ngraph::node_validation_assertion_string(const Node* node)
{ {
std::stringstream ss; std::stringstream ss;
ss << "While validating node '" << *node << "' of type '" << node->description() << "'"; ss << "While validating node '" << *node << "'";
return ss.str(); return ss.str();
} }
void ngraph::check_new_args_count(const Node* node, const NodeVector& new_args) void ngraph::check_new_args_count(const Node* node, const NodeVector& new_args)
{ {
NODE_VALIDATION_ASSERT(node, new_args.size() == node->get_arguments().size()) NODE_VALIDATION_CHECK(node,
<< "copy_with_new_args() expected " << node->get_arguments().size() << " argument" new_args.size() == node->get_arguments().size(),
<< (node->get_arguments().size() == 1 ? "" : "s") << " but got " << new_args.size(); "copy_with_new_args() expected ",
node->get_arguments().size(),
" argument",
(node->get_arguments().size() == 1 ? "" : "s"),
" but got ",
new_args.size());
} }
const std::shared_ptr<Node>& ngraph::check_single_output_arg(const std::shared_ptr<Node>& node, const std::shared_ptr<Node>& ngraph::check_single_output_arg(const std::shared_ptr<Node>& node,
...@@ -459,13 +464,14 @@ std::tuple<element::Type, PartialShape> Node::validate_and_infer_elementwise_arg ...@@ -459,13 +464,14 @@ std::tuple<element::Type, PartialShape> Node::validate_and_infer_elementwise_arg
{ {
for (size_t i = 1; i < get_input_size(); ++i) for (size_t i = 1; i < get_input_size(); ++i)
{ {
NODE_VALIDATION_ASSERT( NODE_VALIDATION_CHECK(
this, element::Type::merge(element_type, element_type, get_input_element_type(i))) this,
<< "Argument element types are inconsistent."; element::Type::merge(element_type, element_type, get_input_element_type(i)),
"Argument element types are inconsistent.");
NODE_VALIDATION_ASSERT(this,
PartialShape::merge_into(pshape, get_input_partial_shape(i))) NODE_VALIDATION_CHECK(this,
<< "Argument shapes are inconsistent."; PartialShape::merge_into(pshape, get_input_partial_shape(i)),
"Argument shapes are inconsistent.");
} }
} }
...@@ -478,8 +484,11 @@ void Node::validate_and_infer_elementwise_arithmetic() ...@@ -478,8 +484,11 @@ void Node::validate_and_infer_elementwise_arithmetic()
element::Type& args_et = std::get<0>(args_et_pshape); element::Type& args_et = std::get<0>(args_et_pshape);
PartialShape& args_pshape = std::get<1>(args_et_pshape); PartialShape& args_pshape = std::get<1>(args_et_pshape);
NODE_VALIDATION_ASSERT(this, args_et.is_dynamic() || args_et != element::boolean) NODE_VALIDATION_CHECK(this,
<< "Arguments cannot have boolean element type (argument element type: " << args_et << ")."; args_et.is_dynamic() || args_et != element::boolean,
"Arguments cannot have boolean element type (argument element type: ",
args_et,
").");
set_output_type(0, args_et, args_pshape); set_output_type(0, args_et, args_pshape);
} }
...@@ -490,9 +499,12 @@ void Node::validate_and_infer_elementwise_logical() ...@@ -490,9 +499,12 @@ void Node::validate_and_infer_elementwise_logical()
element::Type& args_et = std::get<0>(args_et_pshape); element::Type& args_et = std::get<0>(args_et_pshape);
PartialShape& args_pshape = std::get<1>(args_et_pshape); PartialShape& args_pshape = std::get<1>(args_et_pshape);
NODE_VALIDATION_ASSERT(this, args_et.is_dynamic() || args_et == element::boolean) NODE_VALIDATION_CHECK(
<< "Operands for logical operators must have boolean element type but have element type " this,
<< args_et << "."; args_et.is_dynamic() || args_et == element::boolean,
"Operands for logical operators must have boolean element type but have element type ",
args_et,
".");
set_output_type(0, element::boolean, args_pshape); set_output_type(0, element::boolean, args_pshape);
} }
...@@ -30,6 +30,7 @@ ...@@ -30,6 +30,7 @@
#include "ngraph/assertion.hpp" #include "ngraph/assertion.hpp"
#include "ngraph/autodiff/adjoints.hpp" #include "ngraph/autodiff/adjoints.hpp"
#include "ngraph/check.hpp"
#include "ngraph/descriptor/input.hpp" #include "ngraph/descriptor/input.hpp"
#include "ngraph/descriptor/output.hpp" #include "ngraph/descriptor/output.hpp"
#include "ngraph/descriptor/tensor.hpp" #include "ngraph/descriptor/tensor.hpp"
...@@ -287,6 +288,17 @@ namespace ngraph ...@@ -287,6 +288,17 @@ namespace ngraph
} }
}; };
class NodeValidationFailure : public CheckFailure
{
public:
NodeValidationFailure(const CheckLocInfo& check_loc_info,
const Node* node,
const std::string& explanation)
: CheckFailure(check_loc_info, node_validation_assertion_string(node), explanation)
{
}
};
class NodeDescription class NodeDescription
{ {
public: public:
...@@ -315,3 +327,6 @@ namespace ngraph ...@@ -315,3 +327,6 @@ namespace ngraph
#define NODE_VALIDATION_FAIL(node) \ #define NODE_VALIDATION_FAIL(node) \
NGRAPH_FAIL_STREAM_WITH_LOC(::ngraph::NodeValidationError, \ NGRAPH_FAIL_STREAM_WITH_LOC(::ngraph::NodeValidationError, \
::ngraph::node_validation_assertion_string(node)) ::ngraph::node_validation_assertion_string(node))
#define NODE_VALIDATION_CHECK(node, cond, ...) \
CHECK(::NodeValidationFailure, (node), (cond), __VA_ARGS__)
...@@ -297,9 +297,14 @@ void op::ConvolutionBackpropData::validate_and_infer_types() ...@@ -297,9 +297,14 @@ void op::ConvolutionBackpropData::validate_and_infer_types()
m_window_movement_strides_forward, m_window_movement_strides_forward,
m_window_dilation_strides_forward); m_window_dilation_strides_forward);
NODE_VALIDATION_ASSERT(this, forward_result_shape.compatible(delta_shape)) NODE_VALIDATION_CHECK(this,
<< "Inferred forward output shape (" << forward_result_shape << ") does not match shape of " forward_result_shape.compatible(delta_shape),
<< "delta (" << delta_shape << ")."; "Inferred forward output shape (",
forward_result_shape,
") does not match shape of ",
"delta (",
delta_shape,
").");
set_output_type(0, forward_result_et, m_data_batch_shape); set_output_type(0, forward_result_et, m_data_batch_shape);
...@@ -494,9 +499,14 @@ void op::ConvolutionBackpropFilters::validate_and_infer_types() ...@@ -494,9 +499,14 @@ void op::ConvolutionBackpropFilters::validate_and_infer_types()
m_window_movement_strides_forward, m_window_movement_strides_forward,
m_window_dilation_strides_forward); m_window_dilation_strides_forward);
NODE_VALIDATION_ASSERT(this, forward_result_shape.compatible(delta_shape)) NODE_VALIDATION_CHECK(this,
<< "Inferred forward output shape (" << forward_result_shape << ") does not match shape of " forward_result_shape.compatible(delta_shape),
<< "delta (" << delta_shape << ")."; "Inferred forward output shape (",
forward_result_shape,
") does not match shape of ",
"delta (",
delta_shape,
").");
set_output_type(0, forward_result_et, m_filters_shape); set_output_type(0, forward_result_et, m_filters_shape);
......
This diff is collapsed.
...@@ -884,7 +884,7 @@ TEST(partial_shape, infer_windowed_reduction_rank_dynamic_rank_dynamic_zero_data ...@@ -884,7 +884,7 @@ TEST(partial_shape, infer_windowed_reduction_rank_dynamic_rank_dynamic_zero_data
window_dilation, window_dilation,
is_window_all_in_padding_allowed); is_window_all_in_padding_allowed);
}, },
NodeValidationError); NodeValidationFailure);
} }
TEST(partial_shape, infer_windowed_reduction_rank_dynamic_rank_dynamic_zero_window_dilation) TEST(partial_shape, infer_windowed_reduction_rank_dynamic_rank_dynamic_zero_window_dilation)
...@@ -911,7 +911,7 @@ TEST(partial_shape, infer_windowed_reduction_rank_dynamic_rank_dynamic_zero_wind ...@@ -911,7 +911,7 @@ TEST(partial_shape, infer_windowed_reduction_rank_dynamic_rank_dynamic_zero_wind
window_dilation, window_dilation,
is_window_all_in_padding_allowed); is_window_all_in_padding_allowed);
}, },
NodeValidationError); NodeValidationFailure);
} }
TEST(partial_shape, infer_windowed_reduction_rank_dynamic_rank_dynamic_zero_window_strides) TEST(partial_shape, infer_windowed_reduction_rank_dynamic_rank_dynamic_zero_window_strides)
...@@ -938,7 +938,7 @@ TEST(partial_shape, infer_windowed_reduction_rank_dynamic_rank_dynamic_zero_wind ...@@ -938,7 +938,7 @@ TEST(partial_shape, infer_windowed_reduction_rank_dynamic_rank_dynamic_zero_wind
window_dilation, window_dilation,
is_window_all_in_padding_allowed); is_window_all_in_padding_allowed);
}, },
NodeValidationError); NodeValidationFailure);
} }
TEST(partial_shape, infer_windowed_reduction_rank_static_dynamic_rank_dynamic_ok) TEST(partial_shape, infer_windowed_reduction_rank_static_dynamic_rank_dynamic_ok)
...@@ -992,7 +992,7 @@ TEST(partial_shape, ...@@ -992,7 +992,7 @@ TEST(partial_shape,
window_dilation, window_dilation,
is_window_all_in_padding_allowed); is_window_all_in_padding_allowed);
}, },
NodeValidationError); NodeValidationFailure);
} }
TEST(partial_shape, infer_windowed_reduction_rank_static_dynamic_rank_dynamic_neg_padding_ok) TEST(partial_shape, infer_windowed_reduction_rank_static_dynamic_rank_dynamic_neg_padding_ok)
...@@ -1071,7 +1071,7 @@ TEST(partial_shape, infer_windowed_reduction_rank_dynamic_rank_static_dynamic_wi ...@@ -1071,7 +1071,7 @@ TEST(partial_shape, infer_windowed_reduction_rank_dynamic_rank_static_dynamic_wi
window_dilation, window_dilation,
is_window_all_in_padding_allowed); is_window_all_in_padding_allowed);
}, },
NodeValidationError); NodeValidationFailure);
} }
TEST(partial_shape, TEST(partial_shape,
...@@ -1100,7 +1100,7 @@ TEST(partial_shape, ...@@ -1100,7 +1100,7 @@ TEST(partial_shape,
window_dilation, window_dilation,
is_window_all_in_padding_allowed); is_window_all_in_padding_allowed);
}, },
NodeValidationError); NodeValidationFailure);
} }
TEST(partial_shape, TEST(partial_shape,
...@@ -1156,7 +1156,7 @@ TEST(partial_shape, ...@@ -1156,7 +1156,7 @@ TEST(partial_shape,
window_dilation, window_dilation,
is_window_all_in_padding_allowed); is_window_all_in_padding_allowed);
}, },
NodeValidationError); NodeValidationFailure);
} }
TEST(partial_shape, TEST(partial_shape,
...@@ -1294,7 +1294,7 @@ TEST(partial_shape, infer_windowed_reduction_rank_static_dynamic_rank_static_dyn ...@@ -1294,7 +1294,7 @@ TEST(partial_shape, infer_windowed_reduction_rank_static_dynamic_rank_static_dyn
window_dilation, window_dilation,
is_window_all_in_padding_allowed); is_window_all_in_padding_allowed);
}, },
NodeValidationError); NodeValidationFailure);
} }
TEST(partial_shape, TEST(partial_shape,
...@@ -1351,5 +1351,5 @@ TEST(partial_shape, ...@@ -1351,5 +1351,5 @@ TEST(partial_shape,
window_dilation, window_dilation,
is_window_all_in_padding_allowed); is_window_all_in_padding_allowed);
}, },
NodeValidationError); NodeValidationFailure);
} }
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment