Unverified Commit 189cf3b7 authored by Michał Karzyński, committed by GitHub

[ONNX] Refactor exceptions to asserts (#1573)

parent 37174c90
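Summary of the change, as read from the hunks below: hand-rolled `throw error::NotSupported(...)` and `throw error::parameter::Value(...)` statements throughout the ONNX importer are replaced with stream-style assertion macros (`ASSERT_IS_SUPPORTED`, `ASSERT_VALID_ARGUMENT`, `NGRAPH_ASSERT`) built on `ngraph/assertion.hpp`. For example, the BatchNormalization handler's `throw error::NotSupported("BatchNormalization", node.get_name(), "only 'is_test' mode is currently supported.")` becomes `ASSERT_IS_SUPPORTED(node, is_test) << "only 'is_test' mode is supported.";`, with the node context supplied automatically by the macro.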
@@ -166,12 +166,12 @@ namespace ngraph
 /// Asserts condition "cond" with an exception class of "T", at location "loc".
 #define NGRAPH_ASSERT_STREAM_WITH_LOC(T, cond, loc)                                               \
-    (cond ? ::ngraph::DummyAssertionHelper().get_stream()                                         \
-          : ::ngraph::AssertionHelper<T>(__FILE__, __LINE__, #cond, loc).get_stream())
+    ((cond) ? ::ngraph::DummyAssertionHelper().get_stream()                                       \
+            : ::ngraph::AssertionHelper<T>(__FILE__, __LINE__, #cond, loc).get_stream())
 /// Asserts condition "cond" with an exception class of "T", and no location specified.
 #define NGRAPH_ASSERT_STREAM(T, cond)                                                             \
-    (cond ? ::ngraph::DummyAssertionHelper().get_stream()                                         \
-          : ::ngraph::AssertionHelper<T>(__FILE__, __LINE__, #cond).get_stream())
+    ((cond) ? ::ngraph::DummyAssertionHelper().get_stream()                                       \
+            : ::ngraph::AssertionHelper<T>(__FILE__, __LINE__, #cond).get_stream())
 /// Fails unconditionally with an exception class of "T", at location "loc".
 #define NGRAPH_FAIL_STREAM_WITH_LOC(T, loc)                                                       \
     ::ngraph::AssertionHelper<T>(__FILE__, __LINE__, "", loc).get_stream()
...
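Two things in this hunk are worth spelling out. First, wrapping the macro argument as `(cond)` is standard macro hygiene: without the extra parentheses, a condition containing an assignment or another conditional expression could regroup around the `?:` inside the macro. Second, the whole mechanism is a ternary that constructs the throwing helper only when the condition is false; that helper throws at the end of the full expression, carrying everything streamed into it. Below is a standalone sketch of the mechanism using only the standard library; the helper and macro names are simplified stand-ins for illustration, not nGraph's actual classes.

```cpp
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <string>

// Stand-in for ngraph::DummyAssertionHelper: swallows the streamed message.
struct DummyHelper
{
    std::ostringstream stream;
    std::ostream& get_stream() { return stream; }
};

// Stand-in for ngraph::AssertionHelper<T>: throws T from its destructor,
// carrying everything streamed into it during the full expression.
template <typename E>
struct ThrowingHelper
{
    std::ostringstream stream;
    std::ostream& get_stream() { return stream; }
    ~ThrowingHelper() noexcept(false) { throw E{stream.str()}; }
};

// Simplified model of NGRAPH_ASSERT_STREAM: only one branch of the ternary is
// evaluated, so the throwing helper is constructed only when "cond" is false.
#define ASSERT_STREAM_SKETCH(E, cond)                                                              \
    ((cond) ? DummyHelper().get_stream() : ThrowingHelper<E>().get_stream())

struct NotSupported : std::runtime_error
{
    using std::runtime_error::runtime_error;
};

int main()
{
    bool is_test = false;
    try
    {
        ASSERT_STREAM_SKETCH(NotSupported, is_test) << "only 'is_test' mode is supported.";
    }
    catch (const NotSupported& error)
    {
        std::cout << "caught: " << error.what() << '\n';
    }
}
```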
@@ -33,6 +33,22 @@ namespace ngraph
             return result;
         }

+        std::string Node::get_description() const
+        {
+            if (!get_name().empty())
+            {
+                return get_name();
+            }
+
+            std::stringstream stream;
+            for (std::size_t index = 0; index < m_output_names.size(); ++index)
+            {
+                stream << (index != 0 ? ", " : "");
+                stream << m_output_names.at(index).get();
+            }
+            return stream.str();
+        }
+
     } // namespace onnx_import
 } // namespace ngraph
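With `get_description()` in place, the `operator<<` change in the node header below prints an unnamed node by its output names instead of an empty string, e.g. `<Node(Conv): conv1>` for a named node or `<Node(Conv): out_a, out_b>` for an unnamed node with two outputs (the names in these examples are made up for illustration).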
@@ -72,6 +72,11 @@ namespace ngraph
             const std::string& op_type() const { return m_node_proto->op_type(); }
             const std::string& get_name() const { return m_node_proto->name(); }
+
+            /// @brief Describe the ONNX Node to make debugging graphs easier
+            /// Function will return the Node's name if it has one, or the names of its outputs.
+            /// \return Description of Node
+            std::string get_description() const;
             const std::vector<std::reference_wrapper<const std::string>>& get_output_names() const
             {
                 return m_output_names;
@@ -114,7 +119,7 @@ namespace ngraph
         inline std::ostream& operator<<(std::ostream& outs, const Node& node)
         {
-            return (outs << "<Node(" << node.op_type() << "): " << node.get_name() << ">");
+            return (outs << "<Node(" << node.op_type() << "): " << node.get_description() << ">");
         }

     } // namespace onnx_import
...
@@ -16,6 +16,7 @@
 #pragma once

+#include "ngraph/assertion.hpp"
 #include "ngraph/except.hpp"

 namespace ngraph
@@ -24,31 +25,29 @@ namespace ngraph
     {
         namespace error
         {
-            struct NotSupported : ngraph_error
+            struct NotSupported : AssertionFailure
             {
-                explicit NotSupported(const std::string& op_name,
-                                      const std::string& name,
-                                      const std::string& message)
-                    : ngraph_error{op_name + " node (" + name + "): " + message}
+                explicit NotSupported(const std::string& what_arg)
+                    : AssertionFailure(what_arg)
                 {
                 }
             };

-            namespace parameter
+            struct InvalidArgument : AssertionFailure
             {
-                struct Value : ngraph_error
-                {
-                    Value(const std::string& op_name,
-                          const std::string& name,
-                          const std::string& message)
-                        : ngraph_error{op_name + " node (" + name + "): " + message}
-                    {
-                    }
-                };
-            } // namespace paramter
+                explicit InvalidArgument(const std::string& what_arg)
+                    : AssertionFailure(what_arg)
+                {
+                }
+            };
         } // namespace error
     } // namespace onnx_import
 } // namespace ngraph
+
+#define ASSERT_IS_SUPPORTED(node_, cond_)                                                         \
+    NGRAPH_ASSERT_STREAM(ngraph::onnx_import::error::NotSupported, cond_) << (node_) << " "
+
+#define ASSERT_VALID_ARGUMENT(node_, cond_)                                                       \
+    NGRAPH_ASSERT_STREAM(ngraph::onnx_import::error::InvalidArgument, cond_) << (node_) << " "
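Note how these macros compose with the node printing change above: each streams `(node_) << " "` before the caller's message, and streaming a `Node` now goes through `operator<<` and `get_description()`, so a failed check identifies the offending node up front, e.g. `<Node(BatchNormalization): bn_1> only 'is_test' mode is supported.` (the node name `bn_1` is a hypothetical example).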
@@ -44,18 +44,8 @@ namespace ngraph
                     // float momentum{node.get_attribute_value<float>("momentum", 0.9f)};
                     bool training = false;

-                    if (!is_test)
-                    {
-                        throw error::NotSupported("BatchNormalization",
-                                                  node.get_name(),
-                                                  "only 'is_test' mode is currently supported.");
-                    }
-                    if (!spatial)
-                    {
-                        throw error::NotSupported("BatchNormalization",
-                                                  node.get_name(),
-                                                  "only 'spatial' mode is currently supported.");
-                    }
+                    ASSERT_IS_SUPPORTED(node, is_test) << "only 'is_test' mode is supported.";
+                    ASSERT_IS_SUPPORTED(node, spatial) << "only 'spatial' mode is supported.";

                     if (inputs.size() >= 5)
                     {
...
@@ -108,15 +108,10 @@ namespace ngraph
                     int64_t groups{node.get_attribute_value<int64_t>("group", 1)};

-                    // TODO: update to ASSERTION CHECK
-                    if (groups < 0 || groups > data->get_shape().at(1) ||
-                        groups > filters->get_shape().at(0))
-                    {
-                        throw error::parameter::Value{"Conv",
-                                                      node.get_name(),
-                                                      "incorrect value of 'group' attribute: " +
-                                                          std::to_string(groups)};
-                    }
+                    ASSERT_VALID_ARGUMENT(node,
+                                          ((groups >= 0) && (groups <= data->get_shape().at(1)) &&
+                                           (groups <= filters->get_shape().at(0))))
+                        << "incorrect value of 'group' attribute: " << groups;

                     auto strides = convpool::get_strides(node);
                     auto dilations = convpool::get_dilations(node);
...
@@ -31,12 +31,8 @@ namespace ngraph
                     auto data = inputs.at(0);
                     auto axis = node.get_attribute_value<int64_t>("axis", 1);

-                    if (axis < 0 || axis > data->get_shape().size())
-                    {
-                        throw error::parameter::Value("Flatten node (",
-                                                      node.get_name(),
-                                                      "): provided axis attribute is not valid.");
-                    }
+                    ASSERT_VALID_ARGUMENT(node, (axis >= 0) && (axis <= data->get_shape().size()))
+                        << "provided 'axis' attribute is not valid.";

                     return {reshape::flatten(data, axis)};
                 }
...
@@ -45,19 +45,12 @@ namespace ngraph
                     if (output_shape.empty() && ng_inputs.size() == 2)
                     {
                         // Currently only support Constant node.
-                        if (ng_inputs.at(1)->description() == "Constant")
-                        {
-                            auto output_shape_node =
-                                std::dynamic_pointer_cast<ngraph::op::Constant>(ng_inputs.at(1));
-                            output_shape = output_shape_node->get_vector<std::size_t>();
-                        }
-                        else
-                        {
-                            throw error::NotSupported("Reshape",
-                                                      node.get_name(),
-                                                      "doesn't support "
-                                                      "shape input of other type than Constant.");
-                        }
+                        ASSERT_IS_SUPPORTED(node, ng_inputs.at(1)->description() == "Constant")
+                            << "doesn't support shape input of other type than Constant.";
+
+                        auto output_shape_node =
+                            std::dynamic_pointer_cast<ngraph::op::Constant>(ng_inputs.at(1));
+                        output_shape = output_shape_node->get_vector<std::size_t>();
                     }
                     // Do nothing if there is no shape argument nor second node input.
                     else if (output_shape.empty())
...
@@ -39,13 +39,11 @@ namespace ngraph
                     {
                         axis = data_shape.size() + axis;
                     }
-                    else if (axis >= data_shape.size())
-                    {
-                        throw error::parameter::Value(
-                            "Softmax node (",
-                            node.get_name(),
-                            "): provided axis attribute is out of input tensor dimensions range.");
-                    }
+
+                    ASSERT_VALID_ARGUMENT(node, axis < data_shape.size())
+                        << "provided 'axis' value:" << axis
+                        << " is out of input tensor dimensions range.";

                     // create vector of capacity data_dimensions - axis_divider position
                     std::vector<size_t> axes(data_shape.size() - axis);
                     std::iota(std::begin(axes), std::end(axes), axis);
...
@@ -34,22 +34,18 @@ namespace ngraph
                     auto data = inputs.at(0);
                     auto data_shape = data->get_shape();
                     auto axes = node.get_attribute_value<std::vector<int64_t>>("axes");

-                    if (axes.empty())
-                    {
-                        throw error::parameter::Value(
-                            "Unsqueeze", node.get_name(), "axes attribute is mandatory.");
-                    }
+                    ASSERT_VALID_ARGUMENT(node, !axes.empty()) << "'axes' attribute is mandatory.";

                     std::sort(std::begin(axes), std::end(axes), std::greater<int64_t>());
                     AxisVector input_order{reshape::get_default_axis_vector(data_shape.size())};

                     for (auto axis : axes)
                     {
-                        if ((axis < 0) || (axis > data_shape.size()))
-                        {
-                            throw error::parameter::Value(
-                                "Unsqueeze", node.get_name(), "provided axes attribute is not valid.");
-                        }
+                        ASSERT_VALID_ARGUMENT(node, axis >= 0 && axis <= data_shape.size())
+                            << "provided 'axes' attribute is not valid.";
+
                         data_shape.insert(std::next(std::begin(data_shape), axis), 1);
                     }
...
@@ -76,15 +76,11 @@ namespace ngraph
                 auto data_shape = ng_input->get_shape();

                 auto reduction_axes = detail::get_reduction_axes(node);

-                if (reduction_axes.size() > data_shape.size())
-                {
-                    throw error::parameter::Value(node.op_type(),
-                                                  node.get_name(),
-                                                  "provided reduction axes count (" +
-                                                      std::to_string(reduction_axes.size()) +
-                                                      ") is larger than input tensor rank (" +
-                                                      std::to_string(data_shape.size()) + ")");
-                }
+                ASSERT_VALID_ARGUMENT(node, reduction_axes.size() <= data_shape.size())
+                    << "provided reduction axes count (" << reduction_axes.size()
+                    << ") is larger than input tensor rank (" << data_shape.size() << ")";

                 auto op_node = std::make_shared<OnnxOperator>(ng_input, reduction_axes);

                 std::int64_t keepdims = node.get_attribute_value<std::int64_t>("keepdims", 1);
...
@@ -76,18 +76,12 @@ namespace ngraph
                 {
                     if (inferred_dims.at(idx) == 0)
                     {
-                        if (idx < input_shape.size())
-                        {
-                            inferred_dims.at(idx) = input_shape.at(idx);
-                        }
-                        else
-                        {
-                            throw error::parameter::Value(
-                                "Reshape",
-                                node_name,
-                                "can not copy dimension from the input data shape since requested "
-                                "index is out of range.");
-                        }
+                        NGRAPH_ASSERT(idx < input_shape.size())
+                            << "Node " << node_name
+                            << " cannot copy dimension from the input data shape because "
+                               "requested index is out of range.";
+
+                        inferred_dims.at(idx) = input_shape.at(idx);
                     }
                 }
@@ -99,14 +93,10 @@ namespace ngraph
                 if (neg_value_it != std::end(inferred_dims))
                 {
                     // only single '-1' value is allowed
-                    if (std::find(std::next(neg_value_it), std::end(inferred_dims), -1) !=
-                        std::end(inferred_dims))
-                    {
-                        throw error::parameter::Value("Reshape",
-                                                      node_name,
-                                                      "more than one dimension is set to (-1). "
-                                                      "Only one dimension value can be inferred.");
-                    }
+                    NGRAPH_ASSERT(std::find(std::next(neg_value_it), std::end(inferred_dims), -1) ==
+                                  std::end(inferred_dims))
+                        << "Node " << node_name << " more than one dimension is set to (-1). "
+                        << "Only one dimension value can be inferred.";

                     // Set dimension value to 1 temporarily to be able to calculate its value.
                     *neg_value_it = 1;
...
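For context on what these two Reshape checks guard: in ONNX Reshape, a target dimension of 0 copies the corresponding input dimension, and a single -1 is inferred so that the element count stays unchanged. Below is a small standalone sketch of that inference, independent of nGraph and written only to illustrate the rule the assertions above enforce.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <numeric>
#include <vector>

// Infer an ONNX Reshape target shape: 0 copies the input dimension at the same
// index, and at most one -1 is replaced so the element count stays unchanged.
std::vector<std::size_t> infer_shape(const std::vector<std::size_t>& input_shape,
                                     std::vector<long long> target)
{
    std::size_t neg_count = 0;
    std::size_t neg_index = 0;
    for (std::size_t idx = 0; idx < target.size(); ++idx)
    {
        if (target.at(idx) == 0)
        {
            assert(idx < input_shape.size() && "cannot copy dimension: index out of range");
            target.at(idx) = static_cast<long long>(input_shape.at(idx));
        }
        else if (target.at(idx) == -1)
        {
            ++neg_count;
            neg_index = idx;
        }
    }
    assert(neg_count <= 1 && "only one dimension value can be inferred");

    if (neg_count == 1)
    {
        target.at(neg_index) = 1; // temporary value so the product below is valid
        const auto input_count = std::accumulate(std::begin(input_shape), std::end(input_shape),
                                                 std::size_t{1}, std::multiplies<std::size_t>());
        const auto known_count = std::accumulate(std::begin(target), std::end(target), 1LL,
                                                 std::multiplies<long long>());
        target.at(neg_index) = static_cast<long long>(input_count) / known_count;
    }
    return std::vector<std::size_t>(std::begin(target), std::end(target));
}

int main()
{
    // Input of shape {2, 3, 4} reshaped with target {0, -1} becomes {2, 12}.
    auto result = infer_shape({2, 3, 4}, {0, -1});
    assert(result.size() == 2 && result[0] == 2 && result[1] == 12);
}
```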