Commit deacf29a authored by Adam Procter's avatar Adam Procter Committed by Scott Cyphers

Add some convenience macros/classes for error messages (#1258)

* Testing out some ideas for better error messages on AvgPool

* Add uncaught_exception() check to ConstructionAssertLogger dtor

* More general assertion class, not homed inside Node

* Minor formatting change

* NODE_ASSERT for type prop failure

* Produce lighter-weight DummyAssertionHandler when assertion succeeds

* New ctor for AssertionHelper that takes a single location arg; more const&-ness for the constructors

* Remove move constructor for AssertionHelper; fix broken test in assertion.cpp

* Miscellaneous improvements

* Templatized AssertionHelper so different exception classes can be used; implemented TYPE_CHECK_ASSERT around this
* Changed from a "stack" of locations to a single location (the stack was too complicated)
* Added "FAIL" classes/macros which do not take a condition

* Rename a helper function

* Cleanup, cruft removal

* Add test to make sure the assert helper has the lifetime we expect

* Missing includes
parent 289586ab
...@@ -15,10 +15,14 @@ ...@@ -15,10 +15,14 @@
# ****************************************************************************** # ******************************************************************************
set (SRC set (SRC
axis_set.cpp
axis_vector.cpp
autodiff/adjoints.cpp autodiff/adjoints.cpp
builder/autobroadcast.cpp builder/autobroadcast.cpp
builder/numpy_transpose.cpp builder/numpy_transpose.cpp
builder/reduce_ops.cpp builder/reduce_ops.cpp
coordinate.cpp
coordinate_diff.cpp
coordinate_transform.cpp coordinate_transform.cpp
descriptor/input.cpp descriptor/input.cpp
descriptor/layout/dense_tensor_view_layout.cpp descriptor/layout/dense_tensor_view_layout.cpp
...@@ -137,6 +141,8 @@ set (SRC ...@@ -137,6 +141,8 @@ set (SRC
runtime/host_tensor_view.cpp runtime/host_tensor_view.cpp
runtime/tensor_view.cpp runtime/tensor_view.cpp
serializer.cpp serializer.cpp
shape.cpp
strides.cpp
type/element_type.cpp type/element_type.cpp
type/type.cpp type/type.cpp
util.cpp util.cpp
......
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include <exception>
#include <sstream>
#include <vector>
#include "ngraph/except.hpp"
namespace ngraph
{
/// Base class for ngraph assertion failure exceptions.
class AssertionFailure : public ngraph_error
{
public:
AssertionFailure(const std::string& what_arg)
: ngraph_error(what_arg)
, m_what(what_arg)
{
}
AssertionFailure(const char* what_arg)
: ngraph_error(what_arg)
, m_what(what_arg)
{
}
const char* what() const noexcept override { return m_what.c_str(); }
private:
std::string m_what;
};
///
/// Helper class for failed assertions. Callers should not instantiate this class directly.
/// This class is meant to be wrapped with a macro like NGRAPH_ASSERT. This class provides
/// two main facilities: (1) an ostream accessible via get_stream(), to which a detailed
/// error explanation can be written; and (2) throws an exception of type T when the
/// AssertionHelper is destructed.
///
///
/// Typical usage is via a wrapper around the NGRAPH_ASSERT_STREAM macro:
///
/// class MyException : public AssertionFailure;
///
/// #define MY_ASSERT(cond) NGRAPH_ASSERT_STREAM(::ngraph::MyException, cond)
///
/// ...
///
/// MY_ASSERT(42 != 43) << "Uh-oh. " << 42 << " is not " << 43 << ".";
///
/// If the assertion fails, it will throw a CompileError exception with a what() string of:
///
/// Assertion '42 != 43' failed at foo.cpp:123:
/// Uh-oh. 42 is not 43.
///
///
/// AssertionHelper also provides support for tagging the exception with a "location" string,
/// reflecting things like the op that was being processed when the error occurred. For
/// example:
///
/// class CompileError : public AssertionFailure;
///
/// #define COMPILE_ASSERT(node,cond) \
/// NGRAPH_ASSERT_STREAM_WITH_LOC(::ngraph::CompileError, cond, \
/// "While compiling node " + node->name())
///
/// ...
///
/// COMPILE_ASSERT(node, node->get_users().size != 0) << "Node has no users";
///
/// If the assertion fails, it will throw a CompileError exception with a what() string
/// similar to:
///
/// While compiling node Add_123:
/// Assertion 'node->get_users().size != 0' failed at foo.cpp:123:
/// Node has no users
///
template <class T>
class AssertionHelper
{
public:
AssertionHelper(const std::string& file,
int line,
const std::string& assertion_expression = "",
const std::string& location_info = "")
: m_file(file)
, m_line(line)
, m_assertion_expression(assertion_expression)
, m_location_info(location_info)
{
}
~AssertionHelper() noexcept(false)
{
// If stack unwinding is already in progress, do not double-throw.
if (!std::uncaught_exception())
{
std::stringstream ss;
if (!m_location_info.empty())
{
ss << m_location_info << ":" << std::endl;
}
if (m_assertion_expression.empty())
{
ss << "Failure ";
}
else
{
ss << "Assertion '" << m_assertion_expression << "' failed ";
}
ss << "at " << m_file << ":" << m_line << ":" << std::endl;
std::string explanation = m_stream.str();
if (explanation.empty())
{
explanation = "(no explanation given)";
}
ss << explanation;
throw T(ss.str());
}
}
/// Returns an ostream to which additional error details can be written. The returned
/// stream has the lifetime of the AssertionHelper.
std::ostream& get_stream() { return m_stream; }
private:
std::stringstream m_stream;
std::string m_file;
int m_line;
std::string m_assertion_expression;
std::string m_location_info;
};
///
/// Class that returns a dummy ostream to absorb error strings for non-failed assertions.
/// This is cheaper to construct than AssertionHelper, so the macros will produce a
/// DummyAssertionHelper in lieu of an AssertionHelper if the condition is true.
///
class DummyAssertionHelper
{
public:
/// Returns an ostream to which additional error details can be written. Anything written
/// to this stream will be ignored. The returned stream has the lifetime of the
/// DummyAssertionHelper.
std::ostream& get_stream() { return m_stream; }
private:
std::stringstream m_stream;
};
}
/// Asserts condition "cond" with an exception class of "T", at location "loc".
#define NGRAPH_ASSERT_STREAM_WITH_LOC(T, cond, loc) \
(cond ? ::ngraph::DummyAssertionHelper().get_stream() \
: ::ngraph::AssertionHelper<T>(__FILE__, __LINE__, #cond, loc).get_stream())
/// Asserts condition "cond" with an exception class of "T", and no location specified.
#define NGRAPH_ASSERT_STREAM(T, cond) \
(cond ? ::ngraph::DummyAssertionHelper().get_stream() \
: ::ngraph::AssertionHelper<T>(__FILE__, __LINE__, #cond).get_stream())
/// Fails unconditionally with an exception class of "T", at location "loc".
#define NGRAPH_FAIL_STREAM_WITH_LOC(T, loc) \
::ngraph::AssertionHelper<T>(__FILE__, __LINE__, "", loc).get_stream()
/// Fails unconditionally with an exception class of "T", and no location specified.
#define NGRAPH_FAIL_STREAM(T) ::ngraph::AssertionHelper<T>(__FILE__, __LINE__).get_stream()
#define NGRAPH_ASSERT(cond) NGRAPH_ASSERT_STREAM(::ngraph::AssertionFailure, cond)
#define NGRAPH_FAIL() NGRAPH_FAIL_STREAM(::ngraph::AssertionFailure)
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "ngraph/axis_set.hpp"
#include "ngraph/util.hpp"
std::ostream& ngraph::operator<<(std::ostream& s, const AxisSet& axis_set)
{
s << "AxisSet{";
s << ngraph::join(axis_set);
s << "}";
return s;
}
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#pragma once #pragma once
#include <cstddef> #include <cstddef>
#include <ostream>
#include <set> #include <set>
#include <vector> #include <vector>
...@@ -59,4 +60,6 @@ namespace ngraph ...@@ -59,4 +60,6 @@ namespace ngraph
return *this; return *this;
} }
}; };
std::ostream& operator<<(std::ostream& s, const AxisSet& axis_set);
} }
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "ngraph/axis_vector.hpp"
#include "ngraph/util.hpp"
std::ostream& ngraph::operator<<(std::ostream& s, const AxisVector& axis_vector)
{
s << "AxisVector{";
s << ngraph::join(axis_vector);
s << "}";
return s;
}
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#pragma once #pragma once
#include <cstddef> #include <cstddef>
#include <ostream>
#include <vector> #include <vector>
namespace ngraph namespace ngraph
...@@ -63,4 +64,6 @@ namespace ngraph ...@@ -63,4 +64,6 @@ namespace ngraph
return *this; return *this;
} }
}; };
std::ostream& operator<<(std::ostream& s, const AxisVector& axis_vector);
} }
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "ngraph/coordinate.hpp"
#include "ngraph/util.hpp"
std::ostream& ngraph::operator<<(std::ostream& s, const Coordinate& coordinate)
{
s << "Coordinate{";
s << ngraph::join(coordinate);
s << "}";
return s;
}
...@@ -130,4 +130,6 @@ namespace ngraph ...@@ -130,4 +130,6 @@ namespace ngraph
std::vector<std::pair<size_t, size_t>>{ std::vector<std::pair<size_t, size_t>>{
std::pair<size_t, size_t>(new_axis_pos, new_axis_val)}); std::pair<size_t, size_t>(new_axis_pos, new_axis_val)});
} }
std::ostream& operator<<(std::ostream& s, const Coordinate& coordinate);
} }
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "ngraph/coordinate_diff.hpp"
#include "ngraph/util.hpp"
std::ostream& ngraph::operator<<(std::ostream& s, const CoordinateDiff& coordinate_diff)
{
s << "CoordinateDiff{";
s << ngraph::join(coordinate_diff);
s << "}";
return s;
}
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#pragma once #pragma once
#include <cstddef> #include <cstddef>
#include <ostream>
#include <vector> #include <vector>
namespace ngraph namespace ngraph
...@@ -63,4 +64,6 @@ namespace ngraph ...@@ -63,4 +64,6 @@ namespace ngraph
return *this; return *this;
} }
}; };
std::ostream& operator<<(std::ostream& s, const CoordinateDiff& coordinate_diff);
} }
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include <memory> #include <memory>
#include <sstream>
#include <typeindex> #include <typeindex>
#include <typeinfo> #include <typeinfo>
...@@ -329,3 +330,11 @@ NodeVector Node::get_users() const ...@@ -329,3 +330,11 @@ NodeVector Node::get_users() const
return result; return result;
} }
std::string ngraph::type_check_assert_string(const Node* node)
{
std::stringstream ss;
ss << "While type-checking node '" << node->get_name() << "' of type '" << node->description()
<< "'";
return ss.str();
}
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
#include "ngraph/assertion.hpp"
#include "ngraph/autodiff/adjoints.hpp" #include "ngraph/autodiff/adjoints.hpp"
#include "ngraph/descriptor/input.hpp" #include "ngraph/descriptor/input.hpp"
#include "ngraph/descriptor/output.hpp" #include "ngraph/descriptor/output.hpp"
...@@ -58,6 +59,8 @@ namespace ngraph ...@@ -58,6 +59,8 @@ namespace ngraph
const std::shared_ptr<Node>& dst_node, const std::shared_ptr<Node>& dst_node,
const std::shared_ptr<Node>& new_node); const std::shared_ptr<Node>& new_node);
std::string type_check_assert_string(const Node* node);
/// Nodes are the backbone of the graph of Value dataflow. Every node has /// Nodes are the backbone of the graph of Value dataflow. Every node has
/// zero or more nodes as arguments and one value, which is either a tensor /// zero or more nodes as arguments and one value, which is either a tensor
/// view or a (possibly empty) tuple of values. /// view or a (possibly empty) tuple of values.
...@@ -204,4 +207,23 @@ namespace ngraph ...@@ -204,4 +207,23 @@ namespace ngraph
std::unordered_map<Node*, autodiff::Adjoints> m_adjoint_map; std::unordered_map<Node*, autodiff::Adjoints> m_adjoint_map;
Placement m_placement = Placement::DEFAULT; Placement m_placement = Placement::DEFAULT;
}; };
class TypeCheckError : public AssertionFailure
{
public:
TypeCheckError(std::string what)
: AssertionFailure(what)
{
}
TypeCheckError(const char* what)
: AssertionFailure(what)
{
}
};
} }
#define TYPE_CHECK_ASSERT(node, cond) \
NGRAPH_ASSERT_STREAM_WITH_LOC( \
::ngraph::TypeCheckError, cond, ::ngraph::type_check_assert_string(node))
#define TYPE_CHECK_FAIL(node) \
NGRAPH_FAIL_STREAM_WITH_LOC(::ngraph::TypeCheckError, ::ngraph::type_check_assert_string(node))
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
*******************************************************************************/ *******************************************************************************/
#include "ngraph/op/avg_pool.hpp" #include "ngraph/op/avg_pool.hpp"
#include "ngraph/assertion.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
using namespace std; using namespace std;
...@@ -36,56 +37,39 @@ op::AvgPool::AvgPool(const shared_ptr<Node>& arg, ...@@ -36,56 +37,39 @@ op::AvgPool::AvgPool(const shared_ptr<Node>& arg,
auto& arg_shape = get_input_shape(0); auto& arg_shape = get_input_shape(0);
// //
// Make sure arg: NCDi for some Di of rank>0, N != 0, C != 0. // Make sure batch size and channel count are not zero, and that we have at least one spatial
// dimension (in other words, that arg has shape NCDi for some Di of rank>0, N != 0, C != 0).
// //
if (arg_shape.size() < 3) TYPE_CHECK_ASSERT(this, arg_shape.size() >= 3)
{ << "Data input shape does not have rank of at least 3 (data input shape: " << arg_shape
throw ngraph_error( << ").";
"Average-pool data batch input must have rank of at least 3 (one batch axis, one "
"channel axis, at least one spatial dimension).");
}
size_t batch_size = arg_shape[0]; size_t batch_size = arg_shape[0];
if (batch_size == 0) TYPE_CHECK_ASSERT(this, batch_size != 0)
{ << "Data batch size is zero (data input shape: " << arg_shape << ").";
throw ngraph_error("Average-pool data batch size is zero.");
}
size_t channel_count = arg_shape[1]; size_t channel_count = arg_shape[1];
if (channel_count == 0) TYPE_CHECK_ASSERT(this, channel_count != 0)
{ << "Channel count is zero (data input shape: " << arg_shape << ").";
throw ngraph_error("Average-pool requires at least one feature channel.");
}
size_t spatial_dimension_count = arg_shape.size() - 2; size_t spatial_dimension_count = arg_shape.size() - 2;
// //
// Make sure window shape, window movement strides, and padding have same rank as Di. // Make sure window shape, window movement strides, and padding have same rank as Di.
// //
if (window_shape.size() != spatial_dimension_count) TYPE_CHECK_ASSERT(this, window_shape.size() == spatial_dimension_count)
{ << "Window shape rank does not match number of spatial dimensions (window shape: "
throw ngraph_error( << window_shape << ", data input shape: " << arg_shape << ").";
"Average-pool window shape rank does not match number of spatial dimensions."); TYPE_CHECK_ASSERT(this, window_movement_strides.size() == spatial_dimension_count)
} << "Window movement stride rank does not match number of spatial dimensions (window "
"movement strides: "
if (window_movement_strides.size() != spatial_dimension_count) << window_movement_strides << ", data input shape: " << arg_shape << ").";
{ TYPE_CHECK_ASSERT(this, padding_below.size() == spatial_dimension_count)
throw ngraph_error( << "Below-padding rank does not match number of spatial dimensions (padding below: "
"Average-pool window movement stride rank does not match number of spatial " << padding_below << ", data input shape: " << arg_shape << ").";
"dimensions."); TYPE_CHECK_ASSERT(this, padding_above.size() == spatial_dimension_count)
} << "Above-padding rank does not match number of spatial dimensions (padding above: "
<< padding_above << ", data input shape: " << arg_shape << ").";
if (padding_below.size() != spatial_dimension_count)
{
throw ngraph_error(
"Average-pool below-padding rank does not match number of spatial dimensions.");
}
if (padding_above.size() != spatial_dimension_count)
{
throw ngraph_error(
"Average-pool above-padding rank does not match number of spatial dimensions.");
}
// //
// Extract input item shape Di and make sure all dimensions are larger than 0. // Extract input item shape Di and make sure all dimensions are larger than 0.
...@@ -97,11 +81,14 @@ op::AvgPool::AvgPool(const shared_ptr<Node>& arg, ...@@ -97,11 +81,14 @@ op::AvgPool::AvgPool(const shared_ptr<Node>& arg,
size_t dim_size = arg_shape[1 + 1 + i]; size_t dim_size = arg_shape[1 + 1 + i];
size_t virtual_dim_size = padding_below[i] + dim_size + padding_above[i]; size_t virtual_dim_size = padding_below[i] + dim_size + padding_above[i];
input_item_virtual_shape.push_back(virtual_dim_size); input_item_virtual_shape.push_back(virtual_dim_size);
}
if (virtual_dim_size == 0) for (size_t i = 0; i < spatial_dimension_count; i++)
{ {
throw ngraph_error("Average-pool input spatial dimension is zero even after padding."); TYPE_CHECK_ASSERT(this, input_item_virtual_shape[i] != 0)
} << "Data input spatial dimension " << i
<< " has zero length even after padding (virtual shape of input item: "
<< input_item_virtual_shape << ").";
} }
// //
...@@ -109,10 +96,9 @@ op::AvgPool::AvgPool(const shared_ptr<Node>& arg, ...@@ -109,10 +96,9 @@ op::AvgPool::AvgPool(const shared_ptr<Node>& arg,
// //
for (size_t i = 0; i < spatial_dimension_count; i++) for (size_t i = 0; i < spatial_dimension_count; i++)
{ {
if (window_shape[i] == 0) TYPE_CHECK_ASSERT(this, window_shape[i] != 0)
{ << "Window shape dimension " << i << " has zero length (window shape: " << window_shape
throw ngraph_error("Average-pool window shape has a zero-length axis."); << ").";
}
} }
// //
...@@ -120,12 +106,10 @@ op::AvgPool::AvgPool(const shared_ptr<Node>& arg, ...@@ -120,12 +106,10 @@ op::AvgPool::AvgPool(const shared_ptr<Node>& arg,
// //
for (size_t i = 0; i < spatial_dimension_count; i++) for (size_t i = 0; i < spatial_dimension_count; i++)
{ {
if (window_shape[i] > input_item_virtual_shape[i]) TYPE_CHECK_ASSERT(this, window_shape[i] <= input_item_virtual_shape[i])
{ << "Window shape after padding is larger than the spatial dimensions (window shape: "
throw ngraph_error( << window_shape << ", virtual shape of input item: " << input_item_virtual_shape
"Average-pool window shape is larger than the spatial dimensions even after " << ").";
"padding.");
}
} }
// //
...@@ -135,10 +119,9 @@ op::AvgPool::AvgPool(const shared_ptr<Node>& arg, ...@@ -135,10 +119,9 @@ op::AvgPool::AvgPool(const shared_ptr<Node>& arg,
for (size_t i = 0; i < spatial_dimension_count; i++) for (size_t i = 0; i < spatial_dimension_count; i++)
{ {
if (window_movement_strides[i] == 0) TYPE_CHECK_ASSERT(this, window_movement_strides[i] != 0)
{ << "Window movement strides dimension " << i
throw ngraph_error("Average-pool window axis movement stride is zero."); << " has zero length (window movement strides: " << window_movement_strides << ").";
}
output_item_shape.push_back(ceil_div(input_item_virtual_shape[i] - window_shape[i] + 1, output_item_shape.push_back(ceil_div(input_item_virtual_shape[i] - window_shape[i] + 1,
window_movement_strides[i])); window_movement_strides[i]));
} }
...@@ -160,12 +143,10 @@ op::AvgPool::AvgPool(const shared_ptr<Node>& arg, ...@@ -160,12 +143,10 @@ op::AvgPool::AvgPool(const shared_ptr<Node>& arg,
// Checking the lower edge of each dimension is easy, because there's no mystery // Checking the lower edge of each dimension is easy, because there's no mystery
// regarding the window's lower-edge placement... // regarding the window's lower-edge placement...
if ((dim_padding_below > 0) && (dim_window_size <= dim_padding_below)) TYPE_CHECK_ASSERT(this, dim_padding_below == 0 || dim_window_size > dim_padding_below)
{ << "Window will sometimes reside entirely within the below-padding region, but"
throw ngraph_error( << " include_padding_in_avg_computation was not set (padding below: "
"Average-pool window will sometimes reside entirely within the padding-below " << padding_below << ", window shape: " << window_shape << ").";
"region, but this average-pool op disregards padding elements.");
}
// Now check the upper-bound... // Now check the upper-bound...
{ {
...@@ -173,14 +154,12 @@ op::AvgPool::AvgPool(const shared_ptr<Node>& arg, ...@@ -173,14 +154,12 @@ op::AvgPool::AvgPool(const shared_ptr<Node>& arg,
const size_t dim_window_max_lower_offset = dim_num_strides * dim_stride; const size_t dim_window_max_lower_offset = dim_num_strides * dim_stride;
const size_t dim_padding_above_start_offset = dim_virtual_size - dim_padding_above; const size_t dim_padding_above_start_offset = dim_virtual_size - dim_padding_above;
if ((dim_padding_above > 0) && TYPE_CHECK_ASSERT(this,
(dim_window_max_lower_offset >= dim_padding_above_start_offset)) dim_padding_above == 0 ||
{ dim_window_max_lower_offset < dim_padding_above_start_offset)
throw ngraph_error( << "Window will sometimes reside entirely within the above-padding region, but"
"Average-pool window will sometimes reside entirely within the " << " include_padding_in_avg_computation was not set (padding above: "
"padding-above " << padding_above << ", window shape: " << window_shape << ").";
"region, but this average-pool op disregards padding elements.");
}
} }
} }
} }
...@@ -198,18 +177,10 @@ op::AvgPool::AvgPool(const shared_ptr<Node>& arg, ...@@ -198,18 +177,10 @@ op::AvgPool::AvgPool(const shared_ptr<Node>& arg,
static Shape default_padding(const shared_ptr<Node>& arg) static Shape default_padding(const shared_ptr<Node>& arg)
{ {
if (arg->get_outputs().size() != 1) auto& arg_shape = arg->get_output_shape(0);
{
throw ngraph_error("Average-pool data batch argument must have exactly one output");
}
auto& arg_shape = arg->get_outputs().at(0).get_shape();
if (arg_shape.size() < 3) if (arg_shape.size() < 3)
{ {
// For consistency we should throw the same error message here that we throw in the constructor. return Shape{};
throw ngraph_error(
"Average-pool data batch input must have rank of at least 3 (one batch axis, one "
"channel axis, at least one spatial dimension).");
} }
return Shape(arg_shape.size() - 2, 0); return Shape(arg_shape.size() - 2, 0);
} }
...@@ -228,18 +199,10 @@ op::AvgPool::AvgPool(const shared_ptr<Node>& arg, ...@@ -228,18 +199,10 @@ op::AvgPool::AvgPool(const shared_ptr<Node>& arg,
static Strides default_strides(const shared_ptr<Node>& arg) static Strides default_strides(const shared_ptr<Node>& arg)
{ {
if (arg->get_outputs().size() != 1) auto& arg_shape = arg->get_output_shape(0);
{
throw ngraph_error("Average-pool data batch argument must have exactly one output");
}
auto& arg_shape = arg->get_outputs().at(0).get_shape();
if (arg_shape.size() < 3) if (arg_shape.size() < 3)
{ {
// For consistency we should throw the same error message here that we throw in the constructor. return Strides{};
throw ngraph_error(
"Average-pool data batch input must have rank of at least 3 (one batch axis, one "
"channel axis, at least one spatial dimension).");
} }
return Strides(arg_shape.size() - 2, 1); return Strides(arg_shape.size() - 2, 1);
} }
......
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "ngraph/shape.hpp"
#include "ngraph/util.hpp"
std::ostream& ngraph::operator<<(std::ostream& s, const Shape& shape)
{
s << "Shape{";
s << ngraph::join(shape);
s << "}";
return s;
}
...@@ -105,4 +105,6 @@ namespace ngraph ...@@ -105,4 +105,6 @@ namespace ngraph
{ {
return 1 == shape.size(); return 1 == shape.size();
} }
std::ostream& operator<<(std::ostream& s, const Shape& shape);
} }
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "ngraph/strides.hpp"
#include "ngraph/util.hpp"
std::ostream& ngraph::operator<<(std::ostream& s, const Strides& strides)
{
s << "Strides{";
s << ngraph::join(strides);
s << "}";
return s;
}
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#pragma once #pragma once
#include <cstddef> #include <cstddef>
#include <ostream>
#include <vector> #include <vector>
namespace ngraph namespace ngraph
...@@ -63,4 +64,6 @@ namespace ngraph ...@@ -63,4 +64,6 @@ namespace ngraph
return *this; return *this;
} }
}; };
std::ostream& operator<<(std::ostream& s, const Strides& strides);
} }
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
set(SRC set(SRC
algebraic_simplification.cpp algebraic_simplification.cpp
assertion.cpp
builder_autobroadcast.cpp builder_autobroadcast.cpp
build_graph.cpp build_graph.cpp
constant_folding.cpp constant_folding.cpp
......
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <gtest/gtest.h>
#include "ngraph/assertion.hpp"
using namespace ngraph;
using namespace std;
TEST(assertion, assertion_true)
{
NGRAPH_ASSERT(true) << "this should not throw";
}
TEST(assertion, assertion_false)
{
EXPECT_THROW({ NGRAPH_ASSERT(false) << "this should throw"; }, AssertionFailure);
}
TEST(assertion, assertion_with_explanation)
{
bool assertion_failure_thrown = false;
try
{
NGRAPH_ASSERT(false) << "xyzzyxyzzy";
}
catch (const AssertionFailure& e)
{
assertion_failure_thrown = true;
EXPECT_PRED_FORMAT2(testing::IsSubstring, "Assertion 'false' failed", e.what());
EXPECT_PRED_FORMAT2(testing::IsSubstring, "xyzzyxyzzy", e.what());
}
EXPECT_TRUE(assertion_failure_thrown);
}
TEST(assertion, assertion_throws_at_semicolon)
{
bool assertion_failure_thrown = false;
bool got_past_semicolon = false;
try
{
NGRAPH_ASSERT(false) << "first assert";
got_past_semicolon = true;
NGRAPH_ASSERT(false) << "second assert";
}
catch (const AssertionFailure& e)
{
assertion_failure_thrown = true;
EXPECT_PRED_FORMAT2(testing::IsSubstring, "first assert", e.what());
}
EXPECT_FALSE(got_past_semicolon);
EXPECT_TRUE(assertion_failure_thrown);
}
TEST(assertion, assertion_no_explanation)
{
bool assertion_failure_thrown = false;
try
{
NGRAPH_ASSERT(false);
}
catch (const AssertionFailure& e)
{
assertion_failure_thrown = true;
EXPECT_PRED_FORMAT2(testing::IsSubstring, "Assertion 'false' failed", e.what());
EXPECT_PRED_FORMAT2(testing::IsSubstring, "(no explanation given)", e.what());
}
EXPECT_TRUE(assertion_failure_thrown);
}
// Internally, NGRAPH_ASSERT works by throwing from the destructor of an "AssertionHelper" object
// generated inside the macro. This can be dangerous if a throw happens somewhere else while the
// AssertionHelper is in scope, because stack unwinding will cause a call ~AssertionHelper, and
// this causes a "double-throw", resulting in uncatchable program termination.
//
// To avoid this, ~AssertionHelper destructor checks std::uncaught_exception() and does not throw
// if it returns true. This avoids the most likely double-throw scenario in ordinary usage, where
// the expressions feeding the stream throw exceptions themselves.
//
// Here we are testing to make sure that the exception from the stream-feeding expression is
// propagated properly, and that ~AssertionHelper itself does not throw even though the assertion
// is false.
TEST(assertion, throw_in_stream)
{
auto f = []() -> std::string {
// The choice of exception class here is arbitrary.
throw std::domain_error("this should throw std::domain_error");
};
EXPECT_THROW({ NGRAPH_ASSERT(false) << f(); }, std::domain_error);
}
TEST(assertion, fail_with_explanation)
{
bool assertion_failure_thrown = false;
try
{
NGRAPH_FAIL() << "xyzzyxyzzy";
}
catch (const AssertionFailure& e)
{
assertion_failure_thrown = true;
EXPECT_PRED_FORMAT2(testing::IsSubstring, "Failure", e.what());
EXPECT_PRED_FORMAT2(testing::IsSubstring, "xyzzyxyzzy", e.what());
}
EXPECT_TRUE(assertion_failure_thrown);
}
TEST(assertion, fail_no_explanation)
{
bool assertion_failure_thrown = false;
try
{
NGRAPH_FAIL();
}
catch (const AssertionFailure& e)
{
assertion_failure_thrown = true;
EXPECT_PRED_FORMAT2(testing::IsSubstring, "Failure", e.what());
EXPECT_PRED_FORMAT2(testing::IsSubstring, "(no explanation given)", e.what());
}
EXPECT_TRUE(assertion_failure_thrown);
}
...@@ -24,6 +24,9 @@ ...@@ -24,6 +24,9 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
#define EXPECT_HAS_SUBSTRING(haystack, needle) \
EXPECT_PRED_FORMAT2(testing::IsSubstring, needle, haystack)
// //
// Tests for broadcast. // Tests for broadcast.
// //
...@@ -5879,12 +5882,9 @@ TEST(type_prop, avg_pool_invalid_0d_input) ...@@ -5879,12 +5882,9 @@ TEST(type_prop, avg_pool_invalid_0d_input)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid 0D input not detected"; FAIL() << "Invalid 0D input not detected";
} }
catch (const ngraph_error& error) catch (const TypeCheckError& error)
{ {
EXPECT_EQ(error.what(), EXPECT_HAS_SUBSTRING(error.what(), "Data input shape does not have rank of at least 3");
std::string("Average-pool data batch input must have rank of at "
"least 3 (one batch axis, one channel axis, at "
"least one spatial dimension)."));
} }
catch (...) catch (...)
{ {
...@@ -5904,12 +5904,9 @@ TEST(type_prop, avg_pool_invalid_1d_input) ...@@ -5904,12 +5904,9 @@ TEST(type_prop, avg_pool_invalid_1d_input)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid 1D input not detected"; FAIL() << "Invalid 1D input not detected";
} }
catch (const ngraph_error& error) catch (const TypeCheckError& error)
{ {
EXPECT_EQ(error.what(), EXPECT_HAS_SUBSTRING(error.what(), "Data input shape does not have rank of at least 3");
std::string("Average-pool data batch input must have rank of at "
"least 3 (one batch axis, one channel axis, at "
"least one spatial dimension)."));
} }
catch (...) catch (...)
{ {
...@@ -5929,12 +5926,9 @@ TEST(type_prop, avg_pool_invalid_2d_input) ...@@ -5929,12 +5926,9 @@ TEST(type_prop, avg_pool_invalid_2d_input)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid 2D input not detected"; FAIL() << "Invalid 2D input not detected";
} }
catch (const ngraph_error& error) catch (const TypeCheckError& error)
{ {
EXPECT_EQ(error.what(), EXPECT_HAS_SUBSTRING(error.what(), "Data input shape does not have rank of at least 3");
std::string("Average-pool data batch input must have rank of at "
"least 3 (one batch axis, one channel axis, at "
"least one spatial dimension)."));
} }
catch (...) catch (...)
{ {
...@@ -5954,9 +5948,9 @@ TEST(type_prop, avg_pool_invalid_0_batch_size) ...@@ -5954,9 +5948,9 @@ TEST(type_prop, avg_pool_invalid_0_batch_size)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with 0 batch size not detected"; FAIL() << "Invalid input with 0 batch size not detected";
} }
catch (const ngraph_error& error) catch (const TypeCheckError& error)
{ {
EXPECT_EQ(error.what(), std::string("Average-pool data batch size is zero.")); EXPECT_HAS_SUBSTRING(error.what(), "Data batch size is zero");
} }
catch (...) catch (...)
{ {
...@@ -5976,9 +5970,9 @@ TEST(type_prop, avg_pool_invalid_0_channels) ...@@ -5976,9 +5970,9 @@ TEST(type_prop, avg_pool_invalid_0_channels)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with 0 channels not detected"; FAIL() << "Invalid input with 0 channels not detected";
} }
catch (const ngraph_error& error) catch (const TypeCheckError& error)
{ {
EXPECT_EQ(error.what(), std::string("Average-pool requires at least one feature channel.")); EXPECT_HAS_SUBSTRING(error.what(), "Channel count is zero");
} }
catch (...) catch (...)
{ {
...@@ -5998,12 +5992,10 @@ TEST(type_prop, avg_pool_invalid_wrong_number_of_window_dimensions_too_many) ...@@ -5998,12 +5992,10 @@ TEST(type_prop, avg_pool_invalid_wrong_number_of_window_dimensions_too_many)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with too many window dimensions not detected"; FAIL() << "Invalid input with too many window dimensions not detected";
} }
catch (const ngraph_error& error) catch (const TypeCheckError& error)
{ {
EXPECT_EQ( EXPECT_HAS_SUBSTRING(error.what(),
error.what(), "Window shape rank does not match number of spatial dimensions");
std::string(
"Average-pool window shape rank does not match number of spatial dimensions."));
} }
catch (...) catch (...)
{ {
...@@ -6023,12 +6015,10 @@ TEST(type_prop, avg_pool_invalid_wrong_number_of_window_dimensions_too_few) ...@@ -6023,12 +6015,10 @@ TEST(type_prop, avg_pool_invalid_wrong_number_of_window_dimensions_too_few)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with too few window dimensions not detected"; FAIL() << "Invalid input with too few window dimensions not detected";
} }
catch (const ngraph_error& error) catch (const TypeCheckError& error)
{ {
EXPECT_EQ( EXPECT_HAS_SUBSTRING(error.what(),
error.what(), "Window shape rank does not match number of spatial dimensions");
std::string(
"Average-pool window shape rank does not match number of spatial dimensions."));
} }
catch (...) catch (...)
{ {
...@@ -6049,11 +6039,11 @@ TEST(type_prop, avg_pool_invalid_movement_stride_rank) ...@@ -6049,11 +6039,11 @@ TEST(type_prop, avg_pool_invalid_movement_stride_rank)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with wrong movement stride rank not detected"; FAIL() << "Invalid input with wrong movement stride rank not detected";
} }
catch (const ngraph_error& error) catch (const TypeCheckError& error)
{ {
EXPECT_EQ(error.what(), EXPECT_HAS_SUBSTRING(
std::string("Average-pool window movement stride rank does not " error.what(),
"match number of spatial dimensions.")); "Window movement stride rank does not match number of spatial dimensions");
} }
catch (...) catch (...)
{ {
...@@ -6077,11 +6067,10 @@ TEST(type_prop, avg_pool_invalid_padding_below_rank) ...@@ -6077,11 +6067,10 @@ TEST(type_prop, avg_pool_invalid_padding_below_rank)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with wrong below-padding rank not detected"; FAIL() << "Invalid input with wrong below-padding rank not detected";
} }
catch (const ngraph_error& error) catch (const TypeCheckError& error)
{ {
EXPECT_EQ(error.what(), EXPECT_HAS_SUBSTRING(error.what(),
std::string("Average-pool below-padding rank does not " "Below-padding rank does not match number of spatial dimensions");
"match number of spatial dimensions."));
} }
catch (...) catch (...)
{ {
...@@ -6105,11 +6094,10 @@ TEST(type_prop, avg_pool_invalid_padding_above_rank) ...@@ -6105,11 +6094,10 @@ TEST(type_prop, avg_pool_invalid_padding_above_rank)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with wrong above-padding rank not detected"; FAIL() << "Invalid input with wrong above-padding rank not detected";
} }
catch (const ngraph_error& error) catch (const TypeCheckError& error)
{ {
EXPECT_EQ(error.what(), EXPECT_HAS_SUBSTRING(error.what(),
std::string("Average-pool above-padding rank does not " "Above-padding rank does not match number of spatial dimensions");
"match number of spatial dimensions."));
} }
catch (...) catch (...)
{ {
...@@ -6129,10 +6117,10 @@ TEST(type_prop, avg_pool_invalid_input_item_size_0) ...@@ -6129,10 +6117,10 @@ TEST(type_prop, avg_pool_invalid_input_item_size_0)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with zero-length spatial axis not detected"; FAIL() << "Invalid input with zero-length spatial axis not detected";
} }
catch (const ngraph_error& error) catch (const TypeCheckError& error)
{ {
EXPECT_EQ(error.what(), EXPECT_HAS_SUBSTRING(error.what(),
std::string("Average-pool input spatial dimension is zero even after padding.")); "Data input spatial dimension 0 has zero length even after padding");
} }
catch (...) catch (...)
{ {
...@@ -6152,9 +6140,9 @@ TEST(type_prop, avg_pool_invalid_window_size_0) ...@@ -6152,9 +6140,9 @@ TEST(type_prop, avg_pool_invalid_window_size_0)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with zero-length window axis not detected"; FAIL() << "Invalid input with zero-length window axis not detected";
} }
catch (const ngraph_error& error) catch (const TypeCheckError& error)
{ {
EXPECT_EQ(error.what(), std::string("Average-pool window shape has a zero-length axis.")); EXPECT_HAS_SUBSTRING(error.what(), "Window shape dimension 1 has zero length");
} }
catch (...) catch (...)
{ {
...@@ -6174,11 +6162,10 @@ TEST(type_prop, avg_pool_invalid_dilated_too_large) ...@@ -6174,11 +6162,10 @@ TEST(type_prop, avg_pool_invalid_dilated_too_large)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with oversized window not detected"; FAIL() << "Invalid input with oversized window not detected";
} }
catch (const ngraph_error& error) catch (const TypeCheckError& error)
{ {
EXPECT_EQ(error.what(), EXPECT_HAS_SUBSTRING(error.what(),
std::string("Average-pool window shape is larger than the spatial " "Window shape after padding is larger than the spatial dimensions");
"dimensions even after padding."));
} }
catch (...) catch (...)
{ {
...@@ -6199,9 +6186,9 @@ TEST(type_prop, avg_pool_invalid_movement_stride_0) ...@@ -6199,9 +6186,9 @@ TEST(type_prop, avg_pool_invalid_movement_stride_0)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with 0-length movement stride axis not detected"; FAIL() << "Invalid input with 0-length movement stride axis not detected";
} }
catch (const ngraph_error& error) catch (const TypeCheckError& error)
{ {
EXPECT_EQ(error.what(), std::string("Average-pool window axis movement stride is zero.")); EXPECT_HAS_SUBSTRING(error.what(), "Window movement strides dimension 0 has zero length");
} }
catch (...) catch (...)
{ {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment