Unverified Commit 189cf3b7 authored by Michał Karzyński's avatar Michał Karzyński Committed by GitHub

[ONNX] Refactor exceptions to asserts (#1573)

parent 37174c90
......@@ -166,12 +166,12 @@ namespace ngraph
/// Asserts condition "cond" with an exception class of "T", at location "loc".
#define NGRAPH_ASSERT_STREAM_WITH_LOC(T, cond, loc) \
(cond ? ::ngraph::DummyAssertionHelper().get_stream() \
: ::ngraph::AssertionHelper<T>(__FILE__, __LINE__, #cond, loc).get_stream())
((cond) ? ::ngraph::DummyAssertionHelper().get_stream() \
: ::ngraph::AssertionHelper<T>(__FILE__, __LINE__, #cond, loc).get_stream())
/// Asserts condition "cond" with an exception class of "T", and no location specified.
#define NGRAPH_ASSERT_STREAM(T, cond) \
(cond ? ::ngraph::DummyAssertionHelper().get_stream() \
: ::ngraph::AssertionHelper<T>(__FILE__, __LINE__, #cond).get_stream())
((cond) ? ::ngraph::DummyAssertionHelper().get_stream() \
: ::ngraph::AssertionHelper<T>(__FILE__, __LINE__, #cond).get_stream())
/// Fails unconditionally with an exception class of "T", at location "loc".
#define NGRAPH_FAIL_STREAM_WITH_LOC(T, loc) \
::ngraph::AssertionHelper<T>(__FILE__, __LINE__, "", loc).get_stream()
......
......@@ -33,6 +33,22 @@ namespace ngraph
return result;
}
std::string Node::get_description() const
{
if (!get_name().empty())
{
return get_name();
}
std::stringstream stream;
for (std::size_t index = 0; index < m_output_names.size(); ++index)
{
stream << (index != 0 ? ", " : "");
stream << m_output_names.at(index).get();
}
return stream.str();
}
} // namespace onnx_import
} // namespace ngraph
......@@ -72,6 +72,11 @@ namespace ngraph
const std::string& op_type() const { return m_node_proto->op_type(); }
const std::string& get_name() const { return m_node_proto->name(); }
/// @brief Describe the ONNX Node to make debugging graphs easier
/// Function will return the Node's name if it has one, or the names of its outputs.
/// \return Description of Node
std::string get_description() const;
const std::vector<std::reference_wrapper<const std::string>>& get_output_names() const
{
return m_output_names;
......@@ -114,7 +119,7 @@ namespace ngraph
inline std::ostream& operator<<(std::ostream& outs, const Node& node)
{
return (outs << "<Node(" << node.op_type() << "): " << node.get_name() << ">");
return (outs << "<Node(" << node.op_type() << "): " << node.get_description() << ">");
}
} // namespace onnx_import
......
......@@ -16,6 +16,7 @@
#pragma once
#include "ngraph/assertion.hpp"
#include "ngraph/except.hpp"
namespace ngraph
......@@ -24,31 +25,29 @@ namespace ngraph
{
namespace error
{
struct NotSupported : ngraph_error
struct NotSupported : AssertionFailure
{
explicit NotSupported(const std::string& op_name,
const std::string& name,
const std::string& message)
: ngraph_error{op_name + " node (" + name + "): " + message}
explicit NotSupported(const std::string& what_arg)
: AssertionFailure(what_arg)
{
}
};
namespace parameter
struct InvalidArgument : AssertionFailure
{
struct Value : ngraph_error
explicit InvalidArgument(const std::string& what_arg)
: AssertionFailure(what_arg)
{
Value(const std::string& op_name,
const std::string& name,
const std::string& message)
: ngraph_error{op_name + " node (" + name + "): " + message}
{
}
};
} // namespace paramter
}
};
} // namespace error
} // namespace onnx_import
} // namespace ngraph
#define ASSERT_IS_SUPPORTED(node_, cond_) \
NGRAPH_ASSERT_STREAM(ngraph::onnx_import::error::NotSupported, cond_) << (node_) << " "
#define ASSERT_VALID_ARGUMENT(node_, cond_) \
NGRAPH_ASSERT_STREAM(ngraph::onnx_import::error::InvalidArgument, cond_) << (node_) << " "
......@@ -44,18 +44,8 @@ namespace ngraph
// float momentum{node.get_attribute_value<float>("momentum", 0.9f)};
bool training = false;
if (!is_test)
{
throw error::NotSupported("BatchNormalization",
node.get_name(),
"only 'is_test' mode is currently supported.");
}
if (!spatial)
{
throw error::NotSupported("BatchNormalization",
node.get_name(),
"only 'spatial' mode is currently supported.");
}
ASSERT_IS_SUPPORTED(node, is_test) << "only 'is_test' mode is supported.";
ASSERT_IS_SUPPORTED(node, spatial) << "only 'spatial' mode is supported.";
if (inputs.size() >= 5)
{
......
......@@ -108,15 +108,10 @@ namespace ngraph
int64_t groups{node.get_attribute_value<int64_t>("group", 1)};
// TODO: update to ASSERTION CHECK
if (groups < 0 || groups > data->get_shape().at(1) ||
groups > filters->get_shape().at(0))
{
throw error::parameter::Value{"Conv",
node.get_name(),
"incorrect value of 'group' attribute: " +
std::to_string(groups)};
}
ASSERT_VALID_ARGUMENT(node,
((groups >= 0) && (groups <= data->get_shape().at(1)) &&
(groups <= filters->get_shape().at(0))))
<< "incorrect value of 'group' attribute: " << groups;
auto strides = convpool::get_strides(node);
auto dilations = convpool::get_dilations(node);
......
......@@ -31,12 +31,8 @@ namespace ngraph
auto data = inputs.at(0);
auto axis = node.get_attribute_value<int64_t>("axis", 1);
if (axis < 0 || axis > data->get_shape().size())
{
throw error::parameter::Value("Flatten node (",
node.get_name(),
"): provided axis attribute is not valid.");
}
ASSERT_VALID_ARGUMENT(node, (axis >= 0) && (axis <= data->get_shape().size()))
<< "provided 'axis' attribute is not valid.";
return {reshape::flatten(data, axis)};
}
......
......@@ -45,19 +45,12 @@ namespace ngraph
if (output_shape.empty() && ng_inputs.size() == 2)
{
// Currently only support Constant node.
if (ng_inputs.at(1)->description() == "Constant")
{
auto output_shape_node =
std::dynamic_pointer_cast<ngraph::op::Constant>(ng_inputs.at(1));
output_shape = output_shape_node->get_vector<std::size_t>();
}
else
{
throw error::NotSupported("Reshape",
node.get_name(),
"doesn't support "
"shape input of other type than Constant.");
}
ASSERT_IS_SUPPORTED(node, ng_inputs.at(1)->description() == "Constant")
<< "doesn't support shape input of other type than Constant.";
auto output_shape_node =
std::dynamic_pointer_cast<ngraph::op::Constant>(ng_inputs.at(1));
output_shape = output_shape_node->get_vector<std::size_t>();
}
// Do nothing if there is no shape argument nor second node input.
else if (output_shape.empty())
......
......@@ -39,13 +39,11 @@ namespace ngraph
{
axis = data_shape.size() + axis;
}
else if (axis >= data_shape.size())
{
throw error::parameter::Value(
"Softmax node (",
node.get_name(),
"): provided axis attribute is out of input tensor dimensions range.");
}
ASSERT_VALID_ARGUMENT(node, axis < data_shape.size())
<< "provided 'axis' value:" << axis
<< " is out of input tensor dimensions range.";
// create vector of capacity data_dimensions - axis_divider position
std::vector<size_t> axes(data_shape.size() - axis);
std::iota(std::begin(axes), std::end(axes), axis);
......
......@@ -34,22 +34,18 @@ namespace ngraph
auto data = inputs.at(0);
auto data_shape = data->get_shape();
auto axes = node.get_attribute_value<std::vector<int64_t>>("axes");
if (axes.empty())
{
throw error::parameter::Value(
"Unsqueeze", node.get_name(), "axes attribute is mandatory.");
}
ASSERT_VALID_ARGUMENT(node, !axes.empty()) << "'axes' attribute is mandatory.";
std::sort(std::begin(axes), std::end(axes), std::greater<int64_t>());
AxisVector input_order{reshape::get_default_axis_vector(data_shape.size())};
for (auto axis : axes)
{
if ((axis < 0) || (axis > data_shape.size()))
{
throw error::parameter::Value(
"Unsqueeze", node.get_name(), "provided axes attribute is not valid.");
}
ASSERT_VALID_ARGUMENT(node, axis >= 0 && axis <= data_shape.size())
<< "provided 'axes' attribute is not valid.";
data_shape.insert(std::next(std::begin(data_shape), axis), 1);
}
......
......@@ -76,15 +76,11 @@ namespace ngraph
auto data_shape = ng_input->get_shape();
auto reduction_axes = detail::get_reduction_axes(node);
if (reduction_axes.size() > data_shape.size())
{
throw error::parameter::Value(node.op_type(),
node.get_name(),
"provided reduction axes count (" +
std::to_string(reduction_axes.size()) +
") is larger than input tensor rank (" +
std::to_string(data_shape.size()) + ")");
}
ASSERT_VALID_ARGUMENT(node, reduction_axes.size() <= data_shape.size())
<< "provided reduction axes count (" << reduction_axes.size()
<< ") is larger than input tensor rank (" << data_shape.size() << ")";
auto op_node = std::make_shared<OnnxOperator>(ng_input, reduction_axes);
std::int64_t keepdims = node.get_attribute_value<std::int64_t>("keepdims", 1);
......
......@@ -76,18 +76,12 @@ namespace ngraph
{
if (inferred_dims.at(idx) == 0)
{
if (idx < input_shape.size())
{
inferred_dims.at(idx) = input_shape.at(idx);
}
else
{
throw error::parameter::Value(
"Reshape",
node_name,
"can not copy dimension from the input data shape since requested "
"index is out of range.");
}
NGRAPH_ASSERT(idx < input_shape.size())
<< "Node " << node_name
<< " cannot copy dimension from the input data shape because "
"requested index is out of range.";
inferred_dims.at(idx) = input_shape.at(idx);
}
}
......@@ -99,14 +93,10 @@ namespace ngraph
if (neg_value_it != std::end(inferred_dims))
{
// only single '-1' value is allowed
if (std::find(std::next(neg_value_it), std::end(inferred_dims), -1) !=
std::end(inferred_dims))
{
throw error::parameter::Value("Reshape",
node_name,
"more than one dimension is set to (-1). "
"Only one dimension value can be inferred.");
}
NGRAPH_ASSERT(std::find(std::next(neg_value_it), std::end(inferred_dims), -1) ==
std::end(inferred_dims))
<< "Node " << node_name << " more than one dimension is set to (-1). "
<< "Only one dimension value can be inferred.";
// Set dimension value to 1 temporarily to be able to calculate its value.
*neg_value_it = 1;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment