Unverified Commit c386da90 authored by Adam Procter's avatar Adam Procter Committed by GitHub

Better node validation error messages (#1533)

parent 132b5305
...@@ -344,6 +344,20 @@ NodeVector Node::get_users() const ...@@ -344,6 +344,20 @@ NodeVector Node::get_users() const
return result; return result;
} }
std::string ngraph::node_validation_assertion_string(const Node* node)
{
std::stringstream ss;
ss << "While validating node '" << *node << "' of type '" << node->description() << "'";
return ss.str();
}
void ngraph::check_new_args_count(const Node* node, const NodeVector& new_args)
{
NODE_VALIDATION_ASSERT(node, new_args.size() == node->get_arguments().size())
<< "copy_with_new_args() expected " << node->get_arguments().size() << " argument"
<< (node->get_arguments().size() == 1 ? "" : "s") << " but got " << new_args.size();
}
const std::shared_ptr<Node>& ngraph::check_single_output_arg(const std::shared_ptr<Node>& node, const std::shared_ptr<Node>& ngraph::check_single_output_arg(const std::shared_ptr<Node>& node,
size_t i) size_t i)
{ {
...@@ -361,13 +375,6 @@ const NodeVector& ngraph::check_single_output_args(const NodeVector& args) ...@@ -361,13 +375,6 @@ const NodeVector& ngraph::check_single_output_args(const NodeVector& args)
return args; return args;
} }
std::string ngraph::type_check_assert_string(const Node* node)
{
std::stringstream ss;
ss << "While type-checking node " << *node;
return ss.str();
}
void Node::validate_and_infer_elementwise(element::Type result_type) void Node::validate_and_infer_elementwise(element::Type result_type)
{ {
const element::Type& element_type = get_input_element_type(0); const element::Type& element_type = get_input_element_type(0);
...@@ -376,12 +383,12 @@ void Node::validate_and_infer_elementwise(element::Type result_type) ...@@ -376,12 +383,12 @@ void Node::validate_and_infer_elementwise(element::Type result_type)
{ {
for (size_t i = 1; i < get_input_size(); ++i) for (size_t i = 1; i < get_input_size(); ++i)
{ {
TYPE_CHECK_ASSERT(this, get_input_element_type(i) == element_type) NODE_VALIDATION_ASSERT(this, get_input_element_type(i) == element_type)
<< "Argument 0 element type " << element_type << "Argument 0 element type " << element_type
<< " differs in element type from argument " << i << " " << *get_argument(i) << " differs in element type from argument " << i << " " << *get_argument(i)
<< " element type " << get_input_element_type(i); << " element type " << get_input_element_type(i);
TYPE_CHECK_ASSERT(this, get_input_shape(i) == shape) NODE_VALIDATION_ASSERT(this, get_input_shape(i) == shape)
<< "Argument 0 shape " << shape << " differs in shape from argument " << i << " " << "Argument 0 shape " << shape << " differs in shape from argument " << i << " "
<< *get_argument(i) << " shape " << get_input_shape(i); << *get_argument(i) << " shape " << get_input_shape(i);
} }
...@@ -391,16 +398,16 @@ void Node::validate_and_infer_elementwise(element::Type result_type) ...@@ -391,16 +398,16 @@ void Node::validate_and_infer_elementwise(element::Type result_type)
void Node::validate_and_infer_elementwise_arithmetic() void Node::validate_and_infer_elementwise_arithmetic()
{ {
TYPE_CHECK_ASSERT(this, get_input_element_type(0) != element::boolean) NODE_VALIDATION_ASSERT(this, get_input_element_type(0) != element::boolean)
<< "Operands for arithmetic operators must have numeric element type but have element type " << "Arguments cannot have boolean element type (argument element type: "
<< get_input_element_type(0); << get_input_element_type(0) << ").";
validate_and_infer_elementwise(get_input_element_type(0)); validate_and_infer_elementwise(get_input_element_type(0));
} }
void Node::validate_and_infer_elementwise_logical() void Node::validate_and_infer_elementwise_logical()
{ {
TYPE_CHECK_ASSERT(this, get_input_element_type(0) == element::boolean) NODE_VALIDATION_ASSERT(this, get_input_element_type(0) == element::boolean)
<< "Operands for logical operators must have boolean element type but have element type " << "Operands for logical operators must have boolean element type but have element type "
<< get_input_element_type(0); << get_input_element_type(0) << ".";
validate_and_infer_elementwise(get_input_element_type(0)); validate_and_infer_elementwise(get_input_element_type(0));
} }
...@@ -58,7 +58,11 @@ namespace ngraph ...@@ -58,7 +58,11 @@ namespace ngraph
const std::shared_ptr<Node>& dst_node, const std::shared_ptr<Node>& dst_node,
const std::shared_ptr<Node>& new_node); const std::shared_ptr<Node>& new_node);
std::string type_check_assert_string(const Node* node); std::string node_validation_assertion_string(const Node* node);
const std::shared_ptr<Node>& check_single_output_arg(const std::shared_ptr<Node>& node,
size_t i);
const NodeVector& check_single_output_args(const NodeVector& args);
const std::shared_ptr<Node>& check_single_output_arg(const std::shared_ptr<Node>& node, const std::shared_ptr<Node>& check_single_output_arg(const std::shared_ptr<Node>& node,
size_t i); size_t i);
...@@ -223,22 +227,25 @@ namespace ngraph ...@@ -223,22 +227,25 @@ namespace ngraph
Placement m_placement = Placement::DEFAULT; Placement m_placement = Placement::DEFAULT;
}; };
class TypeCheckError : public AssertionFailure class NodeValidationError : public AssertionFailure
{ {
public: public:
TypeCheckError(std::string what) NodeValidationError(std::string what)
: AssertionFailure(what) : AssertionFailure(what)
{ {
} }
TypeCheckError(const char* what) NodeValidationError(const char* what)
: AssertionFailure(what) : AssertionFailure(what)
{ {
} }
}; };
void check_new_args_count(const Node* node, const NodeVector& new_args);
} }
#define TYPE_CHECK_ASSERT(node, cond) \ #define NODE_VALIDATION_ASSERT(node, cond) \
NGRAPH_ASSERT_STREAM_WITH_LOC( \ NGRAPH_ASSERT_STREAM_WITH_LOC( \
::ngraph::TypeCheckError, cond, ::ngraph::type_check_assert_string(node)) ::ngraph::NodeValidationError, cond, ::ngraph::node_validation_assertion_string(node))
#define TYPE_CHECK_FAIL(node) \ #define NODE_VALIDATION_FAIL(node) \
NGRAPH_FAIL_STREAM_WITH_LOC(::ngraph::TypeCheckError, ::ngraph::type_check_assert_string(node)) NGRAPH_FAIL_STREAM_WITH_LOC(::ngraph::NodeValidationError, \
::ngraph::node_validation_assertion_string(node))
...@@ -29,10 +29,7 @@ op::Abs::Abs(const shared_ptr<Node>& arg) ...@@ -29,10 +29,7 @@ op::Abs::Abs(const shared_ptr<Node>& arg)
shared_ptr<Node> op::Abs::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Abs::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != 1) check_new_args_count(this, new_args);
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Abs>(new_args.at(0)); return make_shared<Abs>(new_args.at(0));
} }
......
...@@ -40,10 +40,7 @@ op::Acos::Acos(const shared_ptr<Node>& arg) ...@@ -40,10 +40,7 @@ op::Acos::Acos(const shared_ptr<Node>& arg)
shared_ptr<Node> op::Acos::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Acos::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != 1) check_new_args_count(this, new_args);
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Acos>(new_args.at(0)); return make_shared<Acos>(new_args.at(0));
} }
......
...@@ -27,10 +27,7 @@ op::Add::Add(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg1) ...@@ -27,10 +27,7 @@ op::Add::Add(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg1)
shared_ptr<Node> op::Add::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Add::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != 2) check_new_args_count(this, new_args);
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Add>(new_args.at(0), new_args.at(1)); return make_shared<Add>(new_args.at(0), new_args.at(1));
} }
......
...@@ -27,19 +27,17 @@ op::AllReduce::AllReduce(const shared_ptr<Node>& arg) ...@@ -27,19 +27,17 @@ op::AllReduce::AllReduce(const shared_ptr<Node>& arg)
void op::AllReduce::validate_and_infer_types() void op::AllReduce::validate_and_infer_types()
{ {
set_output_type(0, get_input_element_type(0), get_input_shape(0)); NODE_VALIDATION_ASSERT(this,
get_input_element_type(0) == element::f32 ||
get_input_element_type(0) == element::f64)
<< "Only element types f32 and f64 are supported (argument element type: "
<< get_input_element_type(0) << ").";
if ((get_input_element_type(0) != element::f32) && (get_input_element_type(0) != element::f64)) set_output_type(0, get_input_element_type(0), get_input_shape(0));
{
throw ngraph_error("Unsupported data type for AllReduce");
}
} }
shared_ptr<Node> op::AllReduce::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::AllReduce::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != 1) check_new_args_count(this, new_args);
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<AllReduce>(new_args.at(0)); return make_shared<AllReduce>(new_args.at(0));
} }
...@@ -27,9 +27,6 @@ op::And::And(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg1) ...@@ -27,9 +27,6 @@ op::And::And(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg1)
shared_ptr<Node> op::And::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::And::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != 2) check_new_args_count(this, new_args);
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<And>(new_args.at(0), new_args.at(1)); return make_shared<And>(new_args.at(0), new_args.at(1));
} }
...@@ -21,9 +21,6 @@ using namespace ngraph; ...@@ -21,9 +21,6 @@ using namespace ngraph;
shared_ptr<Node> op::ArgMax::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::ArgMax::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != 1) check_new_args_count(this, new_args);
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<ArgMax>(new_args.at(0), m_axis, this->get_element_type()); return make_shared<ArgMax>(new_args.at(0), m_axis, this->get_element_type());
} }
...@@ -21,9 +21,6 @@ using namespace ngraph; ...@@ -21,9 +21,6 @@ using namespace ngraph;
shared_ptr<Node> op::ArgMin::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::ArgMin::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != 1) check_new_args_count(this, new_args);
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<ArgMin>(new_args.at(0), m_axis, this->get_element_type()); return make_shared<ArgMin>(new_args.at(0), m_axis, this->get_element_type());
} }
...@@ -39,10 +39,7 @@ op::Asin::Asin(const shared_ptr<Node>& arg) ...@@ -39,10 +39,7 @@ op::Asin::Asin(const shared_ptr<Node>& arg)
shared_ptr<Node> op::Asin::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Asin::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != 1) check_new_args_count(this, new_args);
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Asin>(new_args.at(0)); return make_shared<Asin>(new_args.at(0));
} }
......
...@@ -38,10 +38,7 @@ op::Atan::Atan(const shared_ptr<Node>& arg) ...@@ -38,10 +38,7 @@ op::Atan::Atan(const shared_ptr<Node>& arg)
shared_ptr<Node> op::Atan::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Atan::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != 1) check_new_args_count(this, new_args);
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Atan>(new_args.at(0)); return make_shared<Atan>(new_args.at(0));
} }
......
This diff is collapsed.
...@@ -44,25 +44,32 @@ void op::Broadcast::validate_and_infer_types() ...@@ -44,25 +44,32 @@ void op::Broadcast::validate_and_infer_types()
Shape target_shape = m_shape; Shape target_shape = m_shape;
for (auto i = m_broadcast_axes.rbegin(); i != m_broadcast_axes.rend(); ++i) for (auto i = m_broadcast_axes.rbegin(); i != m_broadcast_axes.rend(); ++i)
{ {
if (*i >= target_shape.size()) NODE_VALIDATION_ASSERT(this, *i < target_shape.size())
{ << "Broadcast axis index (" << *i << ") exceeds target shape rank "
throw ngraph_error("Broadcast axis exceeds target shape rank"); << "(broadcast axes: " << m_broadcast_axes << ", target shape: " << target_shape
} << ").";
target_shape.erase(target_shape.begin() + *i); target_shape.erase(target_shape.begin() + *i);
} }
if (Shape{target_shape} != get_input_shape(0))
{ // TODO(amprocte): We can probably have a more helpful error message here.
throw ngraph_error("Broadcast arg, shape, and axes are incompatible"); // There are two things that can go wrong, which are being picked up in
} // one fell swoop by this check: either the number of broadcast axes is not
// enough (arg->get_shape().size() + broadcast_axes.size() != shape.size())
// or there is a mismatch with one of the pre-broadcast axis lengths
// (i.e. target_shape.size() == arg->get_shape.size() but there is some i
// where target_shape[i] != arg->get_shape[i]).
NODE_VALIDATION_ASSERT(this, target_shape == get_input_shape(0))
<< "Broadcast argument shape, target shape, and axes are incompatible "
<< "(argument shape: " << get_input_shape(0) << ", target shape: " << m_shape
<< ", broadcast axes: " << m_broadcast_axes << ").";
set_output_type(0, get_input_element_type(0), m_shape); set_output_type(0, get_input_element_type(0), m_shape);
} }
shared_ptr<Node> op::Broadcast::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Broadcast::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != 1) check_new_args_count(this, new_args);
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Broadcast>(new_args.at(0), m_shape, m_broadcast_axes); return make_shared<Broadcast>(new_args.at(0), m_shape, m_broadcast_axes);
} }
......
...@@ -27,9 +27,6 @@ op::Ceiling::Ceiling(const shared_ptr<Node>& arg) ...@@ -27,9 +27,6 @@ op::Ceiling::Ceiling(const shared_ptr<Node>& arg)
shared_ptr<Node> op::Ceiling::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Ceiling::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != 1) check_new_args_count(this, new_args);
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Ceiling>(new_args.at(0)); return make_shared<Ceiling>(new_args.at(0));
} }
...@@ -32,56 +32,58 @@ op::Concat::Concat(const NodeVector& args, size_t concatenation_axis) ...@@ -32,56 +32,58 @@ op::Concat::Concat(const NodeVector& args, size_t concatenation_axis)
void op::Concat::validate_and_infer_types() void op::Concat::validate_and_infer_types()
{ {
if (m_inputs.size() < 1) NODE_VALIDATION_ASSERT(this, m_inputs.size() >= 1) << "At least one argument required.";
{
throw ngraph_error("At least one argument required");
}
auto& input_0 = get_inputs().at(0);
auto input_0_shape = input_0.get_shape();
if (m_concatenation_axis >= input_0_shape.size())
{
throw ngraph_error("Concatenation axis is out of bounds");
}
size_t concatenation_axis_length = input_0_shape.at(m_concatenation_axis); Shape first_input_shape = get_input_shape(0);
auto& input_0_element_type = input_0.get_element_type(); size_t expected_rank = first_input_shape.size();
element::Type expected_et = get_input_element_type(0);
for (auto i = 1; i < get_inputs().size(); i++) for (auto i = 1; i < get_inputs().size(); i++)
{ {
auto& input_i = get_inputs().at(i); NODE_VALIDATION_ASSERT(this, get_input_shape(i).size() == expected_rank)
auto input_i_shape = input_i.get_shape(); << "Not all arguments have the same rank: argument 0 has shape " << first_input_shape
if (input_i_shape.size() != input_0_shape.size()) << " of rank " << expected_rank << " but argument " << i << " has shape "
{ << get_input_shape(i) << " of rank " << get_input_shape(i).size() << ".";
throw ngraph_error("Arguments to concat do not have same rank");
NODE_VALIDATION_ASSERT(this, get_input_element_type(i) == expected_et)
<< "Not all arguments have the same element type: argument 0 has element type "
<< expected_et << " but argument " << i << " has element type "
<< get_input_element_type(i) << ".";
} }
if (input_i.get_element_type() != input_0_element_type) NODE_VALIDATION_ASSERT(this, m_concatenation_axis < expected_rank)
{ << "Concatenation axis (" << m_concatenation_axis << ") is out of bounds (inputs have rank "
throw ngraph_error("Argument element types do not match"); << expected_rank << ").";
}
size_t concatenation_axis_output_length = first_input_shape.at(m_concatenation_axis);
for (auto j = 0; j < input_i_shape.size(); j++) for (auto i = 1; i < get_inputs().size(); i++)
{
for (auto j = 0; j < get_input_shape(i).size(); j++)
{ {
if (j != m_concatenation_axis && input_0_shape.at(j) != input_i_shape.at(j)) if (j != m_concatenation_axis)
{ {
throw ngraph_error( NODE_VALIDATION_ASSERT(this, first_input_shape[j] == get_input_shape(i)[j])
"Arguments to concat do not have same dimension on a non-concatenation axis"); << "Dimensions of argument " << i << " do not match for axis " << j
<< " (expected " << first_input_shape[j] << ", got " << get_input_shape(i)[j]
<< ").";
} }
else if (j == m_concatenation_axis) else
{ {
concatenation_axis_length += input_i_shape.at(j); concatenation_axis_output_length += get_input_shape(i)[j];
} }
} }
} }
vector<size_t> concatenated_shape = input_0_shape;
concatenated_shape.at(m_concatenation_axis) = concatenation_axis_length;
set_output_type(0, input_0_element_type, concatenated_shape); Shape concatenated_shape = first_input_shape;
concatenated_shape[m_concatenation_axis] = concatenation_axis_output_length;
set_output_type(0, expected_et, concatenated_shape);
} }
shared_ptr<Node> op::Concat::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Concat::copy_with_new_args(const NodeVector& new_args) const
{ {
// TODO(amprocte): Should we check the new_args count here?
return make_shared<Concat>(new_args, m_concatenation_axis); return make_shared<Concat>(new_args, m_concatenation_axis);
} }
......
...@@ -151,10 +151,7 @@ vector<string> op::Constant::get_value_strings() const ...@@ -151,10 +151,7 @@ vector<string> op::Constant::get_value_strings() const
shared_ptr<Node> op::Constant::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Constant::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != 0) check_new_args_count(this, new_args);
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Constant>(m_element_type, m_shape, m_data); return make_shared<Constant>(m_element_type, m_shape, m_data);
} }
......
...@@ -46,17 +46,19 @@ namespace ngraph ...@@ -46,17 +46,19 @@ namespace ngraph
, m_data(ngraph::aligned_alloc(m_element_type.size(), , m_data(ngraph::aligned_alloc(m_element_type.size(),
shape_size(m_shape) * m_element_type.size())) shape_size(m_shape) * m_element_type.size()))
{ {
NODE_VALIDATION_ASSERT(this,
values.size() == 1 || values.size() == shape_size(m_shape))
<< "Did not get the expected number of literals for a constant of shape "
<< m_shape << " (got " << values.size() << ", expected "
<< (shape_size(m_shape) == 1 ? "" : "1 or ") << shape_size(m_shape) << ").";
if (values.size() == 1) if (values.size() == 1)
{ {
write_values(std::vector<T>(shape_size(m_shape), values[0])); write_values(std::vector<T>(shape_size(m_shape), values[0]));
} }
else if (values.size() == shape_size(m_shape))
{
write_values(values);
}
else else
{ {
throw ngraph_error("Constant does not have the expected number of literals"); write_values(values);
} }
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
...@@ -74,10 +76,11 @@ namespace ngraph ...@@ -74,10 +76,11 @@ namespace ngraph
, m_data(ngraph::aligned_alloc(m_element_type.size(), , m_data(ngraph::aligned_alloc(m_element_type.size(),
shape_size(m_shape) * m_element_type.size())) shape_size(m_shape) * m_element_type.size()))
{ {
if (values.size() != shape_size(m_shape)) NODE_VALIDATION_ASSERT(this, values.size() == shape_size(m_shape))
{ << "Did not get the expected number of literals for a constant of shape "
throw ngraph_error("Constant does not have the expected number of literals"); << m_shape << " (got " << values.size() << ", expected " << shape_size(m_shape)
} << ".";
std::vector<double> dvalues = parse_string<double>(values); std::vector<double> dvalues = parse_string<double>(values);
write_values(dvalues); write_values(dvalues);
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
......
...@@ -35,10 +35,7 @@ void op::Convert::validate_and_infer_types() ...@@ -35,10 +35,7 @@ void op::Convert::validate_and_infer_types()
shared_ptr<Node> op::Convert::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Convert::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != 1) check_new_args_count(this, new_args);
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Convert>(new_args.at(0), m_element_type); return make_shared<Convert>(new_args.at(0), m_element_type);
} }
......
...@@ -379,10 +379,7 @@ op::Convolution::Convolution(const shared_ptr<Node>& data_batch, const shared_pt ...@@ -379,10 +379,7 @@ op::Convolution::Convolution(const shared_ptr<Node>& data_batch, const shared_pt
shared_ptr<Node> op::Convolution::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Convolution::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != 2) check_new_args_count(this, new_args);
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Convolution>(new_args.at(0), return make_shared<Convolution>(new_args.at(0),
new_args.at(1), new_args.at(1),
m_window_movement_strides, m_window_movement_strides,
...@@ -584,10 +581,7 @@ void op::ConvolutionBackpropData::generate_adjoints(autodiff::Adjoints& adjoints ...@@ -584,10 +581,7 @@ void op::ConvolutionBackpropData::generate_adjoints(autodiff::Adjoints& adjoints
shared_ptr<Node> op::ConvolutionBackpropData::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::ConvolutionBackpropData::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != 2) check_new_args_count(this, new_args);
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<ConvolutionBackpropData>(m_data_batch_shape, return make_shared<ConvolutionBackpropData>(m_data_batch_shape,
new_args.at(0), new_args.at(0),
new_args.at(1), new_args.at(1),
...@@ -687,10 +681,7 @@ void op::ConvolutionBackpropFilters::validate_and_infer_types() ...@@ -687,10 +681,7 @@ void op::ConvolutionBackpropFilters::validate_and_infer_types()
shared_ptr<Node> shared_ptr<Node>
op::ConvolutionBackpropFilters::copy_with_new_args(const NodeVector& new_args) const op::ConvolutionBackpropFilters::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != 2) check_new_args_count(this, new_args);
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<ConvolutionBackpropFilters>(new_args.at(0), return make_shared<ConvolutionBackpropFilters>(new_args.at(0),
m_filters_shape, m_filters_shape,
new_args.at(1), new_args.at(1),
......
...@@ -30,10 +30,7 @@ op::Cos::Cos(const shared_ptr<Node>& arg) ...@@ -30,10 +30,7 @@ op::Cos::Cos(const shared_ptr<Node>& arg)
shared_ptr<Node> op::Cos::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Cos::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != 1) check_new_args_count(this, new_args);
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Cos>(new_args.at(0)); return make_shared<Cos>(new_args.at(0));
} }
......
...@@ -29,10 +29,7 @@ op::Cosh::Cosh(const shared_ptr<Node>& arg) ...@@ -29,10 +29,7 @@ op::Cosh::Cosh(const shared_ptr<Node>& arg)
shared_ptr<Node> op::Cosh::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Cosh::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != 1) check_new_args_count(this, new_args);
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Cosh>(new_args.at(0)); return make_shared<Cosh>(new_args.at(0));
} }
......
...@@ -29,10 +29,7 @@ op::Divide::Divide(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg1) ...@@ -29,10 +29,7 @@ op::Divide::Divide(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg1)
shared_ptr<Node> op::Divide::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Divide::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != 2) check_new_args_count(this, new_args);
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Divide>(new_args.at(0), new_args.at(1)); return make_shared<Divide>(new_args.at(0), new_args.at(1));
} }
......
...@@ -56,30 +56,33 @@ void op::Dot::validate_and_infer_types() ...@@ -56,30 +56,33 @@ void op::Dot::validate_and_infer_types()
(input_0.get_shape().size() == 0 || input_1.get_shape().size() == 0) ? 0 : 1; (input_0.get_shape().size() == 0 || input_1.get_shape().size() == 0) ? 0 : 1;
} }
if (input_0.get_element_type() != input_1.get_element_type()) NODE_VALIDATION_ASSERT(this, input_0.get_element_type() == input_1.get_element_type())
{ << "Arguments do not have the same element type (arg0 element type: "
throw ngraph_error("Arguments to dot must have the same element type"); << input_0.get_element_type() << ", arg1 element type: " << input_1.get_element_type()
} << ").";
Shape input_0_shape = input_0.get_shape(); Shape input_0_shape = input_0.get_shape();
Shape input_1_shape = input_1.get_shape(); Shape input_1_shape = input_1.get_shape();
if (m_reduction_axes_count > input_0_shape.size()) NODE_VALIDATION_ASSERT(this,
{ m_reduction_axes_count <= input_0_shape.size() &&
throw ngraph_error("Dot has too many axes for arg0"); m_reduction_axes_count <= input_1_shape.size())
} << "Reduction axes count (" << m_reduction_axes_count
<< ") is too large (arg0 shape: " << input_0_shape << ", arg1 shape: " << input_1_shape
if (m_reduction_axes_count > input_1_shape.size()) << ").";
{
throw ngraph_error("Dot has too many axes for arg1");
}
for (size_t i = 0; i < m_reduction_axes_count; i++) for (size_t i = 0; i < m_reduction_axes_count; i++)
{ {
if (input_0_shape[input_0_shape.size() - m_reduction_axes_count + i] != input_1_shape[i]) size_t axis_index_arg0 = input_0_shape.size() - m_reduction_axes_count + i;
{ size_t axis_index_arg1 = i;
throw ngraph_error("Dot axes do not have same length");
} NODE_VALIDATION_ASSERT(this,
input_0_shape[axis_index_arg0] == input_1_shape[axis_index_arg1])
<< "Paired axes (axis " << axis_index_arg0 << " from arg0, axis " << axis_index_arg1
<< " from arg1) "
<< "do not have same length (arg0 shape: " << input_0_shape
<< ", arg1 shape: " << input_1_shape << ", "
<< "reduction axes count: " << m_reduction_axes_count << ").";
} }
Shape result_shape(input_0_shape.size() + input_1_shape.size() - 2 * m_reduction_axes_count); Shape result_shape(input_0_shape.size() + input_1_shape.size() - 2 * m_reduction_axes_count);
......
...@@ -56,10 +56,7 @@ namespace ngraph ...@@ -56,10 +56,7 @@ namespace ngraph
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override copy_with_new_args(const NodeVector& new_args) const override
{ {
if (new_args.size() != 2) check_new_args_count(this, new_args);
{
throw ngraph_error("Incorrect number of new arguments");
}
return std::make_shared<Dot>( return std::make_shared<Dot>(
new_args.at(0), new_args.at(1), m_reduction_axes_count); new_args.at(0), new_args.at(1), m_reduction_axes_count);
} }
......
...@@ -27,9 +27,6 @@ op::Equal::Equal(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg1) ...@@ -27,9 +27,6 @@ op::Equal::Equal(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg1)
shared_ptr<Node> op::Equal::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Equal::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != 2) check_new_args_count(this, new_args);
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Equal>(new_args.at(0), new_args.at(1)); return make_shared<Equal>(new_args.at(0), new_args.at(1));
} }
...@@ -28,10 +28,7 @@ op::Exp::Exp(const shared_ptr<Node>& arg) ...@@ -28,10 +28,7 @@ op::Exp::Exp(const shared_ptr<Node>& arg)
shared_ptr<Node> op::Exp::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Exp::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != 1) check_new_args_count(this, new_args);
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Exp>(new_args.at(0)); return make_shared<Exp>(new_args.at(0));
} }
......
...@@ -27,9 +27,6 @@ op::Floor::Floor(const shared_ptr<Node>& arg) ...@@ -27,9 +27,6 @@ op::Floor::Floor(const shared_ptr<Node>& arg)
shared_ptr<Node> op::Floor::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Floor::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != 1) check_new_args_count(this, new_args);
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Floor>(new_args.at(0)); return make_shared<Floor>(new_args.at(0));
} }
...@@ -30,18 +30,22 @@ op::FunctionCall::FunctionCall(shared_ptr<Function> function, const NodeVector& ...@@ -30,18 +30,22 @@ op::FunctionCall::FunctionCall(shared_ptr<Function> function, const NodeVector&
// TODO : [nikolayk] this needs to be rewritten as follows // TODO : [nikolayk] this needs to be rewritten as follows
// for each i : FunctionCall->get_inputs.at(i).get_tensor_view_type == // for each i : FunctionCall->get_inputs.at(i).get_tensor_view_type ==
// flatten(function_parms).at(i) // flatten(function_parms).at(i)
if (get_input_size() != function_params.size()) NODE_VALIDATION_ASSERT(this, get_input_size() == function_params.size())
{ << "Number of arguments (" << get_input_size() << ") does not match "
throw ngraph_error("Wrong number of arguments."); << "number of function parameters (" << function_params.size() << ").";
}
for (size_t i = 0; i < get_input_size(); i++) for (size_t i = 0; i < get_input_size(); i++)
{ {
if (get_input_element_type(i) != function->get_parameters().at(i)->get_element_type() || NODE_VALIDATION_ASSERT(
get_input_shape(i) != function->get_parameters().at(i)->get_shape()) this, get_input_element_type(i) == function->get_parameters()[i]->get_element_type())
{ << "Element type mismatch for argument " << i << " (argument has type "
throw ngraph_error("Function argument type mismatch."); << get_input_element_type(i) << ", function expects type "
} << function->get_parameters()[i]->get_element_type();
NODE_VALIDATION_ASSERT(this,
get_input_shape(i) == function->get_parameters()[i]->get_shape())
<< "Shape mismatch for argument " << i << " (argument has shape " << get_input_shape(i)
<< ", function expects shape " << function->get_parameters()[i]->get_shape();
} }
set_output_size(m_function->get_output_size()); set_output_size(m_function->get_output_size());
...@@ -53,6 +57,7 @@ op::FunctionCall::FunctionCall(shared_ptr<Function> function, const NodeVector& ...@@ -53,6 +57,7 @@ op::FunctionCall::FunctionCall(shared_ptr<Function> function, const NodeVector&
shared_ptr<Node> op::FunctionCall::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::FunctionCall::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args);
shared_ptr<FunctionCall> fc = make_shared<FunctionCall>(m_function, new_args); shared_ptr<FunctionCall> fc = make_shared<FunctionCall>(m_function, new_args);
fc->m_function = clone_function(*m_function); fc->m_function = clone_function(*m_function);
return fc; return fc;
......
...@@ -25,20 +25,16 @@ op::GetOutputElement::GetOutputElement(const shared_ptr<Node>& arg, size_t n) ...@@ -25,20 +25,16 @@ op::GetOutputElement::GetOutputElement(const shared_ptr<Node>& arg, size_t n)
: Node("GetOutputElement", {arg}) : Node("GetOutputElement", {arg})
, m_n{n} , m_n{n}
{ {
if (m_n >= arg->get_output_size()) NODE_VALIDATION_ASSERT(this, m_n < arg->get_output_size())
{ << "Output at index " << m_n << " requested, but argument has only "
throw ngraph_error("Indexing tuple beyond its size"); << arg->get_output_size() << " outputs.";
}
set_output_type(0, arg->get_output_element_type(n), arg->get_output_shape(n)); set_output_type(0, arg->get_output_element_type(n), arg->get_output_shape(n));
} }
shared_ptr<Node> op::GetOutputElement::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::GetOutputElement::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != 1) check_new_args_count(this, new_args);
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<GetOutputElement>(new_args.at(0), m_n); return make_shared<GetOutputElement>(new_args.at(0), m_n);
} }
......
...@@ -27,9 +27,6 @@ op::Greater::Greater(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg1) ...@@ -27,9 +27,6 @@ op::Greater::Greater(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg1)
shared_ptr<Node> op::Greater::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Greater::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != 2) check_new_args_count(this, new_args);
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Greater>(new_args.at(0), new_args.at(1)); return make_shared<Greater>(new_args.at(0), new_args.at(1));
} }
...@@ -27,9 +27,6 @@ op::GreaterEq::GreaterEq(const shared_ptr<Node>& arg0, const shared_ptr<Node>& a ...@@ -27,9 +27,6 @@ op::GreaterEq::GreaterEq(const shared_ptr<Node>& arg0, const shared_ptr<Node>& a
shared_ptr<Node> op::GreaterEq::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::GreaterEq::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != 2) check_new_args_count(this, new_args);
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<GreaterEq>(new_args.at(0), new_args.at(1)); return make_shared<GreaterEq>(new_args.at(0), new_args.at(1));
} }
...@@ -27,9 +27,6 @@ op::Less::Less(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg1) ...@@ -27,9 +27,6 @@ op::Less::Less(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg1)
shared_ptr<Node> op::Less::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Less::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != 2) check_new_args_count(this, new_args);
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Less>(new_args.at(0), new_args.at(1)); return make_shared<Less>(new_args.at(0), new_args.at(1));
} }
...@@ -27,9 +27,6 @@ op::LessEq::LessEq(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg1) ...@@ -27,9 +27,6 @@ op::LessEq::LessEq(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg1)
shared_ptr<Node> op::LessEq::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::LessEq::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != 2) check_new_args_count(this, new_args);
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<LessEq>(new_args.at(0), new_args.at(1)); return make_shared<LessEq>(new_args.at(0), new_args.at(1));
} }
...@@ -28,10 +28,7 @@ op::Log::Log(const shared_ptr<Node>& arg) ...@@ -28,10 +28,7 @@ op::Log::Log(const shared_ptr<Node>& arg)
shared_ptr<Node> op::Log::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Log::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != 1) check_new_args_count(this, new_args);
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Log>(new_args.at(0)); return make_shared<Log>(new_args.at(0));
} }
......
...@@ -28,18 +28,13 @@ op::LRN::LRN(const std::shared_ptr<Node>& arg, double alpha, double beta, double ...@@ -28,18 +28,13 @@ op::LRN::LRN(const std::shared_ptr<Node>& arg, double alpha, double beta, double
, m_size(nsize) , m_size(nsize)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
if (arg->get_shape().size() < 3) NODE_VALIDATION_ASSERT(this, arg->get_shape().size() >= 3)
{ << "Argument must have rank >= 3 (argument shape: " << arg->get_shape() << ").";
throw ngraph_error("LRN expects a tensor at least of rank of 3");
}
} }
shared_ptr<Node> op::LRN::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::LRN::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != 1) check_new_args_count(this, new_args);
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<op::LRN>(new_args.at(0), m_alpha, m_beta, m_bias, m_size); return make_shared<op::LRN>(new_args.at(0), m_alpha, m_beta, m_bias, m_size);
} }
......
...@@ -27,9 +27,6 @@ op::Max::Max(const shared_ptr<Node>& arg, const AxisSet& reduction_axes) ...@@ -27,9 +27,6 @@ op::Max::Max(const shared_ptr<Node>& arg, const AxisSet& reduction_axes)
shared_ptr<Node> op::Max::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Max::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != 1) check_new_args_count(this, new_args);
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Max>(new_args.at(0), m_reduction_axes); return make_shared<Max>(new_args.at(0), m_reduction_axes);
} }
...@@ -201,10 +201,7 @@ op::MaxPool::MaxPool(const shared_ptr<Node>& arg, const Shape& window_shape) ...@@ -201,10 +201,7 @@ op::MaxPool::MaxPool(const shared_ptr<Node>& arg, const Shape& window_shape)
shared_ptr<Node> op::MaxPool::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::MaxPool::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != 1) check_new_args_count(this, new_args);
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<MaxPool>(new_args.at(0), return make_shared<MaxPool>(new_args.at(0),
m_window_shape, m_window_shape,
m_window_movement_strides, m_window_movement_strides,
...@@ -378,18 +375,13 @@ shared_ptr<op::MaxPool> op::MaxPoolBackprop::get_forward_op() const ...@@ -378,18 +375,13 @@ shared_ptr<op::MaxPool> op::MaxPoolBackprop::get_forward_op() const
shared_ptr<Node> op::MaxPoolBackprop::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::MaxPoolBackprop::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != 2) check_new_args_count(this, new_args);
{ return make_shared<op::MaxPoolBackprop>(new_args.at(0),
throw ngraph_error("Incorrect number of new arguments");
}
MaxPoolBackprop* mpbp = new MaxPoolBackprop(new_args.at(0),
new_args.at(1), new_args.at(1),
m_window_shape, m_window_shape,
m_window_movement_strides, m_window_movement_strides,
m_padding_below, m_padding_below,
m_padding_above); m_padding_above);
return shared_ptr<op::MaxPoolBackprop>(mpbp);
} }
void op::MaxPool::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas) void op::MaxPool::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
......
...@@ -33,10 +33,7 @@ op::Maximum::Maximum(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg1) ...@@ -33,10 +33,7 @@ op::Maximum::Maximum(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg1)
shared_ptr<Node> op::Maximum::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Maximum::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != 2) check_new_args_count(this, new_args);
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Maximum>(new_args.at(0), new_args.at(1)); return make_shared<Maximum>(new_args.at(0), new_args.at(1));
} }
......
...@@ -27,9 +27,6 @@ op::Min::Min(const shared_ptr<Node>& arg, const AxisSet& reduction_axes) ...@@ -27,9 +27,6 @@ op::Min::Min(const shared_ptr<Node>& arg, const AxisSet& reduction_axes)
shared_ptr<Node> op::Min::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Min::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != 1) check_new_args_count(this, new_args);
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Min>(new_args.at(0), m_reduction_axes); return make_shared<Min>(new_args.at(0), m_reduction_axes);
} }
...@@ -33,10 +33,7 @@ op::Minimum::Minimum(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg1) ...@@ -33,10 +33,7 @@ op::Minimum::Minimum(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg1)
shared_ptr<Node> op::Minimum::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Minimum::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != 2) check_new_args_count(this, new_args);
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Minimum>(new_args.at(0), new_args.at(1)); return make_shared<Minimum>(new_args.at(0), new_args.at(1));
} }
......
...@@ -27,10 +27,7 @@ op::Multiply::Multiply(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg ...@@ -27,10 +27,7 @@ op::Multiply::Multiply(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg
shared_ptr<Node> op::Multiply::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Multiply::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != 2) check_new_args_count(this, new_args);
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Multiply>(new_args.at(0), new_args.at(1)); return make_shared<Multiply>(new_args.at(0), new_args.at(1));
} }
......
...@@ -27,10 +27,7 @@ op::Negative::Negative(const shared_ptr<Node>& arg) ...@@ -27,10 +27,7 @@ op::Negative::Negative(const shared_ptr<Node>& arg)
shared_ptr<Node> op::Negative::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Negative::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != 1) check_new_args_count(this, new_args);
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Negative>(new_args.at(0)); return make_shared<Negative>(new_args.at(0));
} }
......
...@@ -33,9 +33,6 @@ void op::Not::validate_and_infer_types() ...@@ -33,9 +33,6 @@ void op::Not::validate_and_infer_types()
shared_ptr<Node> op::Not::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Not::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != 1) check_new_args_count(this, new_args);
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Not>(new_args.at(0)); return make_shared<Not>(new_args.at(0));
} }
...@@ -27,9 +27,6 @@ op::NotEqual::NotEqual(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg ...@@ -27,9 +27,6 @@ op::NotEqual::NotEqual(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg
shared_ptr<Node> op::NotEqual::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::NotEqual::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != 2) check_new_args_count(this, new_args);
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<NotEqual>(new_args.at(0), new_args.at(1)); return make_shared<NotEqual>(new_args.at(0), new_args.at(1));
} }
...@@ -30,27 +30,22 @@ op::OneHot::OneHot(const shared_ptr<Node>& arg, const Shape& shape, size_t one_h ...@@ -30,27 +30,22 @@ op::OneHot::OneHot(const shared_ptr<Node>& arg, const Shape& shape, size_t one_h
auto& input = m_inputs.at(0); auto& input = m_inputs.at(0);
auto& input_element_type = input.get_element_type(); auto& input_element_type = input.get_element_type();
if (one_hot_axis >= shape.size()) NODE_VALIDATION_ASSERT(this, one_hot_axis < shape.size())
{ << "One-hot axis (" << one_hot_axis
throw ngraph_error("One-hot axis is out of bounds"); << ") is out of bounds (requested result shape: " << shape << ").";
}
auto expected_input_shape = shape; auto expected_input_shape = shape;
expected_input_shape.erase(expected_input_shape.begin() + one_hot_axis); expected_input_shape.erase(expected_input_shape.begin() + one_hot_axis);
if (input.get_shape() != expected_input_shape) NODE_VALIDATION_ASSERT(this, input.get_shape() == expected_input_shape)
{ << "Argument shape " << input.get_shape() << " does not match the expected shape of "
throw ngraph_error("One-hot argument shape is not compatible with desired output shape"); << expected_input_shape << ".";
}
set_output_type(0, input_element_type, shape); set_output_type(0, input_element_type, shape);
} }
shared_ptr<Node> op::OneHot::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::OneHot::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != 1) check_new_args_count(this, new_args);
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<OneHot>(new_args.at(0), m_shape, m_one_hot_axis); return make_shared<OneHot>(new_args.at(0), m_shape, m_one_hot_axis);
} }
...@@ -27,9 +27,6 @@ op::Or::Or(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg1) ...@@ -27,9 +27,6 @@ op::Or::Or(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg1)
shared_ptr<Node> op::Or::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Or::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != 2) check_new_args_count(this, new_args);
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Or>(new_args.at(0), new_args.at(1)); return make_shared<Or>(new_args.at(0), new_args.at(1));
} }
...@@ -33,32 +33,27 @@ op::Pad::Pad(const shared_ptr<Node>& arg, ...@@ -33,32 +33,27 @@ op::Pad::Pad(const shared_ptr<Node>& arg,
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
if (get_input_element_type(0) != get_input_element_type(1)) NODE_VALIDATION_ASSERT(this, get_input_element_type(0) == get_input_element_type(1))
{ << "Argument element types do not match (arg0 element type: " << get_input_element_type(0)
throw ngraph_error("Pad argument tensor and padding value element types do not match"); << ", arg1 element type: " << get_input_element_type(1) << ").";
}
if (get_input_shape(1) != Shape{}) NODE_VALIDATION_ASSERT(this, get_input_shape(1) == Shape{})
{ << "Argument for padding value is not a scalar (shape: " << get_input_shape(1) << ").";
throw ngraph_error("Padding value for pad is not a scalar");
}
auto arg_shape = get_input_shape(0); auto arg_shape = get_input_shape(0);
if (arg_shape.size() != padding_below.size()) NODE_VALIDATION_ASSERT(this, arg_shape.size() == padding_below.size())
{ << "Rank for padding below does not match the rank of the data argument (padding below: "
throw ngraph_error("Pad rank for below-padding does not match rank of argument tensor"); << padding_below << ", data argument shape: " << arg_shape << ").";
}
if (arg_shape.size() != padding_above.size()) NODE_VALIDATION_ASSERT(this, arg_shape.size() == padding_above.size())
{ << "Rank for padding above does not match the rank of the data argument (padding above: "
throw ngraph_error("Pad rank for above-padding does not match rank of argument tensor"); << padding_above << ", data argument shape: " << arg_shape << ").";
}
if (arg_shape.size() != padding_interior.size()) NODE_VALIDATION_ASSERT(this, arg_shape.size() == padding_interior.size())
{ << "Rank for interior padding does not match the rank of the data argument (interior "
throw ngraph_error("Pad rank for interior padding does not match rank of argument tensor"); "padding: "
} << padding_interior << ", data argument shape: " << arg_shape << ").";
Shape result_shape; Shape result_shape;
...@@ -75,10 +70,7 @@ op::Pad::Pad(const shared_ptr<Node>& arg, ...@@ -75,10 +70,7 @@ op::Pad::Pad(const shared_ptr<Node>& arg,
shared_ptr<Node> op::Pad::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Pad::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != 2) check_new_args_count(this, new_args);
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Pad>( return make_shared<Pad>(
new_args.at(0), new_args.at(1), m_padding_below, m_padding_above, m_padding_interior); new_args.at(0), new_args.at(1), m_padding_below, m_padding_above, m_padding_interior);
} }
......
...@@ -40,10 +40,7 @@ void op::Parameter::validate_and_infer_types() ...@@ -40,10 +40,7 @@ void op::Parameter::validate_and_infer_types()
shared_ptr<Node> op::Parameter::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Parameter::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != 0) check_new_args_count(this, new_args);
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Parameter>(m_element_type, m_shape); return make_shared<Parameter>(m_element_type, m_shape);
} }
......
...@@ -30,10 +30,7 @@ op::Power::Power(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg1) ...@@ -30,10 +30,7 @@ op::Power::Power(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg1)
shared_ptr<Node> op::Power::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Power::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != 2) check_new_args_count(this, new_args);
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Power>(new_args.at(0), new_args.at(1)); return make_shared<Power>(new_args.at(0), new_args.at(1));
} }
......
...@@ -27,9 +27,6 @@ op::Product::Product(const shared_ptr<Node>& arg, const AxisSet& reduction_axes) ...@@ -27,9 +27,6 @@ op::Product::Product(const shared_ptr<Node>& arg, const AxisSet& reduction_axes)
shared_ptr<Node> op::Product::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Product::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != 1) check_new_args_count(this, new_args);
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Product>(new_args.at(0), m_reduction_axes); return make_shared<Product>(new_args.at(0), m_reduction_axes);
} }
...@@ -98,10 +98,7 @@ op::Reduce::Reduce(const shared_ptr<Node>& arg_reductee, ...@@ -98,10 +98,7 @@ op::Reduce::Reduce(const shared_ptr<Node>& arg_reductee,
shared_ptr<Node> op::Reduce::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Reduce::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != 2) check_new_args_count(this, new_args);
{
throw ngraph_error("Incorrect number of new arguments");
}
shared_ptr<Reduce> fc = shared_ptr<Reduce> fc =
make_shared<Reduce>(new_args.at(0), new_args.at(1), m_reduction_function, m_reduction_axes); make_shared<Reduce>(new_args.at(0), new_args.at(1), m_reduction_function, m_reduction_axes);
fc->m_reduction_function = clone_function(*m_reduction_function); fc->m_reduction_function = clone_function(*m_reduction_function);
......
...@@ -135,10 +135,7 @@ op::ReduceWindow::ReduceWindow(const shared_ptr<Node>& arg_reductee, ...@@ -135,10 +135,7 @@ op::ReduceWindow::ReduceWindow(const shared_ptr<Node>& arg_reductee,
shared_ptr<Node> op::ReduceWindow::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::ReduceWindow::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != 2) check_new_args_count(this, new_args);
{
throw ngraph_error("Incorrect number of new arguments");
}
auto node = make_shared<ReduceWindow>(new_args.at(0), auto node = make_shared<ReduceWindow>(new_args.at(0),
new_args.at(1), new_args.at(1),
m_reduction_function, m_reduction_function,
......
...@@ -24,38 +24,23 @@ op::Relu::Relu(shared_ptr<Node> arg) ...@@ -24,38 +24,23 @@ op::Relu::Relu(shared_ptr<Node> arg)
: UnaryElementwiseArithmetic("Relu", {arg}) : UnaryElementwiseArithmetic("Relu", {arg})
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
set_output_type(0, arg->get_element_type(), arg->get_shape());
} }
shared_ptr<Node> op::Relu::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Relu::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != 1) check_new_args_count(this, new_args);
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Relu>(new_args.at(0)); return make_shared<Relu>(new_args.at(0));
} }
op::ReluBackprop::ReluBackprop(shared_ptr<Node> arg, shared_ptr<Node> delta) op::ReluBackprop::ReluBackprop(shared_ptr<Node> arg, shared_ptr<Node> delta)
: Op("ReluBackprop", check_single_output_args({arg, delta})) : BinaryElementwiseArithmetic("ReluBackprop", arg, delta)
{ {
if (arg->get_element_type() != delta->get_element_type()) constructor_validate_and_infer_types();
{
throw ngraph_error("Argument and delta element types for Relu backprop do not match");
}
if (arg->get_shape() != delta->get_shape())
{
throw ngraph_error("Argument and delta shape for Relu backprop do not match");
}
set_output_type(0, delta->get_element_type(), delta->get_shape());
} }
shared_ptr<Node> op::ReluBackprop::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::ReluBackprop::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != 2) check_new_args_count(this, new_args);
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<ReluBackprop>(new_args.at(0), new_args.at(1)); return make_shared<ReluBackprop>(new_args.at(0), new_args.at(1));
} }
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/op/op.hpp" #include "ngraph/op/op.hpp"
#include "ngraph/op/op.hpp" #include "ngraph/op/util/binary_elementwise_arithmetic.hpp"
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp" #include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
...@@ -47,7 +47,7 @@ namespace ngraph ...@@ -47,7 +47,7 @@ namespace ngraph
/// \brief Elementwise ReluBackprop operation. /// \brief Elementwise ReluBackprop operation.
/// ///
class ReluBackprop : public Op class ReluBackprop : public ngraph::op::util::BinaryElementwiseArithmetic
{ {
public: public:
/// \brief Constructs a ReluBackprop operation. /// \brief Constructs a ReluBackprop operation.
......
...@@ -27,9 +27,6 @@ op::Remainder::Remainder(const shared_ptr<Node>& arg0, const shared_ptr<Node>& a ...@@ -27,9 +27,6 @@ op::Remainder::Remainder(const shared_ptr<Node>& arg0, const shared_ptr<Node>& a
shared_ptr<Node> op::Remainder::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Remainder::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != 2) check_new_args_count(this, new_args);
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Remainder>(new_args.at(0), new_args.at(1)); return make_shared<Remainder>(new_args.at(0), new_args.at(1));
} }
...@@ -60,52 +60,43 @@ void op::ReplaceSlice::check_args() ...@@ -60,52 +60,43 @@ void op::ReplaceSlice::check_args()
auto& input_1_shape = input_1.get_shape(); auto& input_1_shape = input_1.get_shape();
auto& input_1_element_type = input_1.get_element_type(); auto& input_1_element_type = input_1.get_element_type();
if (input_0_shape.size() != input_1_shape.size()) NODE_VALIDATION_ASSERT(this, input_0_shape.size() == input_1_shape.size())
{ << "Argument ranks do not match (arg0 shape: " << input_0_shape
throw ngraph_error("Replace-slice argument ranks do not match"); << ", arg1 shape: " << input_1_shape << ").";
}
if (input_0_element_type != input_1_element_type) NODE_VALIDATION_ASSERT(this, input_0_element_type == input_1_element_type)
{ << "Argument element types do not match (arg0 element type: " << input_0_element_type
throw ngraph_error("Element types for replace-slice arguments do not match"); << ", arg1 element type: " << input_1_element_type << ").";
}
if (m_lower_bounds.size() != input_0_shape.size()) NODE_VALIDATION_ASSERT(this, m_lower_bounds.size() == input_0_shape.size())
{ << "Rank of lower bounds (" << m_lower_bounds.size() << ") does not match rank "
throw ngraph_error( << "of argument (" << input_0_shape.size() << ") (lower bounds: " << m_lower_bounds
"Number of lower bounds provided for slice does not match number of input axes"); << ", argument shape: " << input_0_shape << ").";
}
if (m_upper_bounds.size() != input_0_shape.size()) NODE_VALIDATION_ASSERT(this, m_upper_bounds.size() == input_0_shape.size())
{ << "Rank of upper bounds (" << m_upper_bounds.size() << ") does not match rank "
throw ngraph_error( << "of argument (" << input_0_shape.size() << ") (upper bounds: " << m_upper_bounds
"Number of upper bounds provided for slice does not match number of input axes"); << ", argument shape: " << input_0_shape << ").";
}
if (m_strides.size() != input_0_shape.size()) NODE_VALIDATION_ASSERT(this, m_strides.size() == input_0_shape.size())
{ << "Rank of strides (" << m_strides.size() << ") does not match rank "
throw ngraph_error( << "of argument (" << input_0_shape.size() << ") (strides: " << m_strides
"Number of strides provided for slice does not match number of input axes"); << ", argument shape: " << input_0_shape << ").";
}
Shape slice_shape; Shape slice_shape;
for (size_t i = 0; i < input_0_shape.size(); i++) for (size_t i = 0; i < input_0_shape.size(); i++)
{ {
if (m_upper_bounds[i] > input_0_shape[i]) NODE_VALIDATION_ASSERT(this, m_upper_bounds[i] <= input_0_shape[i])
{ << "Upper bound for slice at axis " << i << " is out of range "
throw ngraph_error("Upper bound for slice is out of range"); << "(upper bounds: " << m_upper_bounds << ", argument shape: " << input_0_shape << ").";
}
if (m_lower_bounds[i] > m_upper_bounds[i]) NODE_VALIDATION_ASSERT(this, m_lower_bounds[i] <= m_upper_bounds[i])
{ << "Lower bound for slice is greater than upper bound at axis " << i
throw ngraph_error("Lower bound for slice is greater than upper bound"); << " (lower bounds: " << m_lower_bounds << ", upper bounds: " << m_upper_bounds << ").";
}
if (0 == m_strides[i]) NODE_VALIDATION_ASSERT(this, m_strides[i] != 0) << "Stride for slice is zero at axis " << i
{ << " (strides: " << m_strides << ").";
throw ngraph_error("Stride for slice is zero");
}
size_t slice_axis_size = m_upper_bounds[i] - m_lower_bounds[i]; size_t slice_axis_size = m_upper_bounds[i] - m_lower_bounds[i];
slice_axis_size = slice_axis_size =
...@@ -113,20 +104,16 @@ void op::ReplaceSlice::check_args() ...@@ -113,20 +104,16 @@ void op::ReplaceSlice::check_args()
slice_shape.push_back(slice_axis_size); slice_shape.push_back(slice_axis_size);
} }
if (input_1_shape != slice_shape) NODE_VALIDATION_ASSERT(this, input_1_shape == slice_shape)
{ << "Shape of replacement tensor (" << input_1_shape << ") does not match the slice shape "
throw ngraph_error("Shape of replacement tensor does not match slice shape"); << "(" << slice_shape << ").";
}
set_output_type(0, input_0_element_type, input_0_shape); set_output_type(0, input_0_element_type, input_0_shape);
} }
shared_ptr<Node> op::ReplaceSlice::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::ReplaceSlice::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != 2) check_new_args_count(this, new_args);
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<ReplaceSlice>( return make_shared<ReplaceSlice>(
new_args.at(0), new_args.at(1), m_lower_bounds, m_upper_bounds, m_strides); new_args.at(0), new_args.at(1), m_lower_bounds, m_upper_bounds, m_strides);
} }
......
...@@ -39,19 +39,16 @@ void op::Reshape::validate_and_infer_types() ...@@ -39,19 +39,16 @@ void op::Reshape::validate_and_infer_types()
auto input_shape = input.get_shape(); auto input_shape = input.get_shape();
auto input_rank = input_shape.size(); auto input_rank = input_shape.size();
if (m_input_order.size() != input_rank) NODE_VALIDATION_ASSERT(this, m_input_order.size() == input_rank)
{ << "Input axis order is not a permutation of argument's axis indices (axis order: "
throw ngraph_error("Input axis order for reshape is not a permutation of argument's axes"); << m_input_order << ", argument shape: " << input_shape << ").";
}
for (size_t i = 0; i < input_rank; i++) for (size_t i = 0; i < input_rank; i++)
{ {
auto it = find(begin(m_input_order), end(m_input_order), i); auto it = find(begin(m_input_order), end(m_input_order), i);
if (end(m_input_order) == it) NODE_VALIDATION_ASSERT(this, it != end(m_input_order))
{ << "Input axis order is not a permutation of argument's axis indices (axis order: "
throw ngraph_error( << m_input_order << ", argument shape: " << input_shape << ").";
"Input axis order for reshape is not a permutation of argument's axes");
}
} }
size_t input_shape_product = 1; size_t input_shape_product = 1;
...@@ -66,12 +63,9 @@ void op::Reshape::validate_and_infer_types() ...@@ -66,12 +63,9 @@ void op::Reshape::validate_and_infer_types()
output_shape_product *= i; output_shape_product *= i;
} }
if (input_shape_product != output_shape_product) NODE_VALIDATION_ASSERT(this, input_shape_product == output_shape_product)
{ << "Product of output shape dimensions does not match product of argument shape dimensions "
throw ngraph_error( << "(output shape: " << m_output_shape << ", argument shape: " << input_shape << ").";
"Product of output shape dimensions does not match product of argument shape "
"dimensions for reshape");
}
if (!std::is_sorted(m_input_order.begin(), m_input_order.end())) if (!std::is_sorted(m_input_order.begin(), m_input_order.end()))
{ {
...@@ -82,10 +76,7 @@ void op::Reshape::validate_and_infer_types() ...@@ -82,10 +76,7 @@ void op::Reshape::validate_and_infer_types()
shared_ptr<Node> op::Reshape::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Reshape::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != 1) check_new_args_count(this, new_args);
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Reshape>(new_args.at(0), m_input_order, m_output_shape); return make_shared<Reshape>(new_args.at(0), m_input_order, m_output_shape);
} }
......
...@@ -32,10 +32,8 @@ op::Result::Result(const shared_ptr<Node>& arg) ...@@ -32,10 +32,8 @@ op::Result::Result(const shared_ptr<Node>& arg)
void op::Result::validate_and_infer_types() void op::Result::validate_and_infer_types()
{ {
if (get_input_size() != 1) NODE_VALIDATION_ASSERT(this, get_input_size() == 1) << "Argument has " << get_input_size()
{ << " outputs (1 expected).";
throw ngraph_error("Result expected a single-output argument");
}
// always borrow the placement conf even the default one // always borrow the placement conf even the default one
set_placement(get_argument(0)->get_placement()); set_placement(get_argument(0)->get_placement());
...@@ -44,15 +42,7 @@ void op::Result::validate_and_infer_types() ...@@ -44,15 +42,7 @@ void op::Result::validate_and_infer_types()
shared_ptr<Node> op::Result::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Result::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != 1) check_new_args_count(this, new_args);
{
throw ngraph_error("Incorrect number of new arguments");
}
if (new_args.at(0)->get_outputs().size() != 1)
{
throw ngraph_error("Result::copy_with_new_args expected a single-output argument");
}
auto res = make_shared<Result>(new_args.at(0)); auto res = make_shared<Result>(new_args.at(0));
if (res) if (res)
......
...@@ -38,13 +38,9 @@ void op::Reverse::validate_and_infer_types() ...@@ -38,13 +38,9 @@ void op::Reverse::validate_and_infer_types()
// Make sure all reversed axis indices are valid. // Make sure all reversed axis indices are valid.
for (size_t axis : m_reversed_axes) for (size_t axis : m_reversed_axes)
{ {
if (axis >= input_rank) NODE_VALIDATION_ASSERT(this, axis < input_rank)
{ << "Reverse axis (" << axis << ") is out of bounds (argument shape: " << input_shape
stringstream ss;
ss << "Reverse axis " << axis << " is out of bounds (input rank is " << input_rank
<< ")."; << ").";
throw ngraph_error(ss.str());
}
} }
set_output_type(0, get_input_element_type(0), input_shape); set_output_type(0, get_input_element_type(0), input_shape);
...@@ -52,10 +48,7 @@ void op::Reverse::validate_and_infer_types() ...@@ -52,10 +48,7 @@ void op::Reverse::validate_and_infer_types()
shared_ptr<Node> op::Reverse::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Reverse::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != 1) check_new_args_count(this, new_args);
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Reverse>(new_args.at(0), m_reversed_axes); return make_shared<Reverse>(new_args.at(0), m_reversed_axes);
} }
......
...@@ -38,36 +38,30 @@ op::ReverseSequence::ReverseSequence(const std::shared_ptr<Node> arg, ...@@ -38,36 +38,30 @@ op::ReverseSequence::ReverseSequence(const std::shared_ptr<Node> arg,
void op::ReverseSequence::validate_and_infer_types() void op::ReverseSequence::validate_and_infer_types()
{ {
if (get_input_shape(1).size() != 1) NODE_VALIDATION_ASSERT(this, get_input_shape(1).size() == 1)
{ << "Sequence indices must be a 1-dimensional tensor (sequence indices shape: "
throw ngraph_error("indices should be a 1-dimensional array"); << get_input_shape(1) << ").";
}
if (m_batch_axis >= get_input_shape(0).size()) NODE_VALIDATION_ASSERT(this, m_batch_axis < get_input_shape(0).size())
{ << "Batch axis index (" << m_batch_axis
throw ngraph_error("batch axis index is out of bounds"); << ") is out of bounds (argument shape: " << get_input_shape(0) << ").";
}
if (m_seq_axis >= get_input_shape(0).size()) NODE_VALIDATION_ASSERT(this, m_seq_axis < get_input_shape(0).size())
{ << "Sequence axis index (" << m_seq_axis
throw ngraph_error("sequence axis index is out of bounds"); << ") is out of bounds (argument shape: " << get_input_shape(0) << ").";
}
if (get_input_shape(0).at(m_batch_axis) != get_input_shape(1).at(0)) NODE_VALIDATION_ASSERT(this, get_input_shape(0)[m_batch_axis] == get_input_shape(1)[0])
{ << "Sequence length (" << get_input_shape(1)[0] << ") is not equal to batch axis "
throw ngraph_error("Sequence length size should be equal to batch axis dimension"); << "dimension (" << get_input_shape(0)[m_batch_axis]
} << ") (argument shape: " << get_input_shape(0)
<< ", sequence indices shape: " << get_input_shape(1) << ").";
set_output_type(0, get_input_element_type(0), get_input_shape(0)); set_output_type(0, get_input_element_type(0), get_input_shape(0));
} }
shared_ptr<Node> op::ReverseSequence::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::ReverseSequence::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != 2) check_new_args_count(this, new_args);
{
throw ngraph_error("Incorrect number of new arguments");
}
auto res = auto res =
make_shared<ReverseSequence>(new_args.at(0), new_args.at(1), m_batch_axis, m_seq_axis); make_shared<ReverseSequence>(new_args.at(0), new_args.at(1), m_batch_axis, m_seq_axis);
return res; return res;
......
...@@ -37,28 +37,27 @@ op::Select::Select(const shared_ptr<Node>& arg0, ...@@ -37,28 +37,27 @@ op::Select::Select(const shared_ptr<Node>& arg0,
auto& input_1 = get_inputs().at(1); auto& input_1 = get_inputs().at(1);
auto& input_2 = get_inputs().at(2); auto& input_2 = get_inputs().at(2);
if (input_0.get_element_type() != element::boolean) NODE_VALIDATION_ASSERT(this, input_0.get_element_type() == element::boolean)
{ << "Argument 0 does not have boolean element type (element type: "
throw ngraph_error("Argument 0 for arithmetic operators must have boolean element type"); << input_0.get_element_type() << ").";
}
if (input_0.get_shape() != input_1.get_shape() || input_0.get_shape() != input_2.get_shape()) NODE_VALIDATION_ASSERT(this,
{ input_0.get_shape() == input_1.get_shape() &&
throw ngraph_error("Arguments must have the same shape"); input_0.get_shape() == input_2.get_shape())
} << "Arguments do not all have the same shape (arg0 shape: " << input_0.get_shape()
if (input_1.get_element_type() != input_2.get_element_type()) << ", arg1 shape: " << input_1.get_shape() << ", arg2 shape: " << input_2.get_shape()
{ << ").";
throw ngraph_error("Arguments 1 and 2 must have the same element type");
} NODE_VALIDATION_ASSERT(this, input_1.get_element_type() == input_2.get_element_type())
<< "Arguments 1 and 2 do not have the same element type (arg1 type: "
<< input_1.get_element_type() << ", arg2 type: " << input_2.get_element_type() << ").";
set_output_type(0, input_1.get_element_type(), input_1.get_shape()); set_output_type(0, input_1.get_element_type(), input_1.get_shape());
} }
shared_ptr<Node> op::Select::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Select::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != 3) check_new_args_count(this, new_args);
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Select>(new_args.at(0), new_args.at(1), new_args.at(2)); return make_shared<Select>(new_args.at(0), new_args.at(1), new_args.at(2));
} }
......
...@@ -222,10 +222,7 @@ op::SelectAndScatter::SelectAndScatter(const shared_ptr<Node>& arg_selectee, ...@@ -222,10 +222,7 @@ op::SelectAndScatter::SelectAndScatter(const shared_ptr<Node>& arg_selectee,
shared_ptr<Node> op::SelectAndScatter::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::SelectAndScatter::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != 3) check_new_args_count(this, new_args);
{
throw ngraph_error("Incorrect number of new arguments");
}
auto node = make_shared<SelectAndScatter>(new_args.at(0), auto node = make_shared<SelectAndScatter>(new_args.at(0),
new_args.at(1), new_args.at(1),
new_args.at(2), new_args.at(2),
......
...@@ -23,11 +23,7 @@ using namespace ngraph; ...@@ -23,11 +23,7 @@ using namespace ngraph;
shared_ptr<Node> op::Sigmoid::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Sigmoid::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != 1) check_new_args_count(this, new_args);
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Sigmoid>(new_args.at(0)); return make_shared<Sigmoid>(new_args.at(0));
} }
...@@ -41,23 +37,20 @@ op::Sigmoid::Sigmoid(shared_ptr<Node> arg) ...@@ -41,23 +37,20 @@ op::Sigmoid::Sigmoid(shared_ptr<Node> arg)
op::SigmoidBackprop::SigmoidBackprop(shared_ptr<Node> arg, shared_ptr<Node> delta) op::SigmoidBackprop::SigmoidBackprop(shared_ptr<Node> arg, shared_ptr<Node> delta)
: Op("SigmoidBackprop", check_single_output_args({arg, delta})) : Op("SigmoidBackprop", check_single_output_args({arg, delta}))
{ {
if (arg->get_element_type() != delta->get_element_type()) NODE_VALIDATION_ASSERT(this, arg->get_element_type() == delta->get_element_type())
{ << "Argument and delta element types do not match (argument element type: "
throw ngraph_error("Argument and delta element types for Sigmoid backprop do not match"); << arg->get_element_type() << ", delta element type: " << delta->get_element_type() << ").";
}
if (arg->get_shape() != delta->get_shape()) NODE_VALIDATION_ASSERT(this, arg->get_shape() == delta->get_shape())
{ << "Argument and delta shapes do not match (argument shape: " << arg->get_shape()
throw ngraph_error("Argument and delta shape for Sigmoid backprop do not match"); << ", delta shape: " << delta->get_shape() << ").";
}
set_output_type(0, delta->get_element_type(), delta->get_shape()); set_output_type(0, delta->get_element_type(), delta->get_shape());
} }
shared_ptr<Node> op::SigmoidBackprop::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::SigmoidBackprop::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != 2) check_new_args_count(this, new_args);
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<SigmoidBackprop>(new_args.at(0), new_args.at(1)); return make_shared<SigmoidBackprop>(new_args.at(0), new_args.at(1));
} }
......
...@@ -27,9 +27,6 @@ op::Sign::Sign(const shared_ptr<Node>& arg) ...@@ -27,9 +27,6 @@ op::Sign::Sign(const shared_ptr<Node>& arg)
shared_ptr<Node> op::Sign::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Sign::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != 1) check_new_args_count(this, new_args);
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Sign>(new_args.at(0)); return make_shared<Sign>(new_args.at(0));
} }
...@@ -29,10 +29,7 @@ op::Sin::Sin(const shared_ptr<Node>& arg) ...@@ -29,10 +29,7 @@ op::Sin::Sin(const shared_ptr<Node>& arg)
shared_ptr<Node> op::Sin::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Sin::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != 1) check_new_args_count(this, new_args);
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Sin>(new_args.at(0)); return make_shared<Sin>(new_args.at(0));
} }
......
...@@ -29,10 +29,7 @@ op::Sinh::Sinh(const shared_ptr<Node>& arg) ...@@ -29,10 +29,7 @@ op::Sinh::Sinh(const shared_ptr<Node>& arg)
shared_ptr<Node> op::Sinh::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Sinh::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != 1) check_new_args_count(this, new_args);
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Sinh>(new_args.at(0)); return make_shared<Sinh>(new_args.at(0));
} }
......
...@@ -51,42 +51,35 @@ void op::Slice::validate_and_infer_types() ...@@ -51,42 +51,35 @@ void op::Slice::validate_and_infer_types()
auto& input = get_inputs().at(0); auto& input = get_inputs().at(0);
auto& input_shape = input.get_shape(); auto& input_shape = input.get_shape();
if (m_lower_bounds.size() != input_shape.size()) NODE_VALIDATION_ASSERT(this, m_lower_bounds.size() == input_shape.size())
{ << "Rank of lower bounds (" << m_lower_bounds.size() << ") does not match rank "
throw ngraph_error( << "of argument (" << input_shape.size() << ") (lower bounds: " << m_lower_bounds
"Number of lower bounds provided for slice does not match number of input axes"); << ", argument shape: " << input_shape << ").";
}
if (m_upper_bounds.size() != input_shape.size()) NODE_VALIDATION_ASSERT(this, m_upper_bounds.size() == input_shape.size())
{ << "Rank of upper bounds (" << m_upper_bounds.size() << ") does not match rank "
throw ngraph_error( << "of argument (" << input_shape.size() << ") (upper bounds: " << m_upper_bounds
"Number of upper bounds provided for slice does not match number of input axes"); << ", argument shape: " << input_shape << ").";
}
if (m_strides.size() != input_shape.size()) NODE_VALIDATION_ASSERT(this, m_strides.size() == input_shape.size())
{ << "Rank of strides (" << m_strides.size() << ") does not match rank "
throw ngraph_error( << "of argument (" << input_shape.size() << ") (strides: " << m_strides
"Number of strides provided for slice does not match number of input axes"); << ", argument shape: " << input_shape << ").";
}
Shape result_shape; Shape result_shape;
for (size_t i = 0; i < input_shape.size(); i++) for (size_t i = 0; i < input_shape.size(); i++)
{ {
if (m_upper_bounds[i] > input_shape[i]) NODE_VALIDATION_ASSERT(this, m_upper_bounds[i] <= input_shape[i])
{ << "Upper bound for slice at axis " << i << " is out of range "
throw ngraph_error("Upper bound for slice is out of range"); << "(upper bounds: " << m_upper_bounds << ", argument shape: " << input_shape << ").";
}
if (m_lower_bounds[i] > m_upper_bounds[i]) NODE_VALIDATION_ASSERT(this, m_lower_bounds[i] <= m_upper_bounds[i])
{ << "Lower bound for slice is greater than upper bound at axis " << i
throw ngraph_error("Lower bound for slice is greater than upper bound"); << " (lower bounds: " << m_lower_bounds << ", upper bounds: " << m_upper_bounds << ").";
}
if (0 == m_strides[i]) NODE_VALIDATION_ASSERT(this, m_strides[i] != 0) << "Stride for slice is zero at axis " << i
{ << " (strides: " << m_strides << ").";
throw ngraph_error("Strides distance for slice is zero");
}
size_t result_axis_size = m_upper_bounds[i] - m_lower_bounds[i]; size_t result_axis_size = m_upper_bounds[i] - m_lower_bounds[i];
result_axis_size = result_axis_size =
...@@ -99,10 +92,7 @@ void op::Slice::validate_and_infer_types() ...@@ -99,10 +92,7 @@ void op::Slice::validate_and_infer_types()
shared_ptr<Node> op::Slice::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Slice::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != 1) check_new_args_count(this, new_args);
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Slice>(new_args.at(0), m_lower_bounds, m_upper_bounds, m_strides); return make_shared<Slice>(new_args.at(0), m_lower_bounds, m_upper_bounds, m_strides);
} }
......
...@@ -37,10 +37,9 @@ op::Softmax::Softmax(const shared_ptr<Node>& arg, const AxisSet& axes) ...@@ -37,10 +37,9 @@ op::Softmax::Softmax(const shared_ptr<Node>& arg, const AxisSet& axes)
for (auto axis : m_axes) for (auto axis : m_axes)
{ {
if (axis >= get_shape().size()) NODE_VALIDATION_ASSERT(this, axis < get_shape().size())
{ << "Reduction axis (" << axis << ") is out of bounds (argument shape: " << get_shape()
throw ngraph_error("Axis for softmax reduction operator is out of bounds"); << ").";
}
} }
// empty axes == all axes // empty axes == all axes
...@@ -55,10 +54,7 @@ op::Softmax::Softmax(const shared_ptr<Node>& arg, const AxisSet& axes) ...@@ -55,10 +54,7 @@ op::Softmax::Softmax(const shared_ptr<Node>& arg, const AxisSet& axes)
shared_ptr<Node> op::Softmax::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Softmax::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != 1) check_new_args_count(this, new_args);
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Softmax>(new_args.at(0), m_axes); return make_shared<Softmax>(new_args.at(0), m_axes);
} }
......
...@@ -29,10 +29,7 @@ op::Sqrt::Sqrt(const shared_ptr<Node>& arg) ...@@ -29,10 +29,7 @@ op::Sqrt::Sqrt(const shared_ptr<Node>& arg)
shared_ptr<Node> op::Sqrt::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Sqrt::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != 1) check_new_args_count(this, new_args);
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Sqrt>(new_args.at(0)); return make_shared<Sqrt>(new_args.at(0));
} }
......
...@@ -29,9 +29,6 @@ op::StopGradient::StopGradient(const shared_ptr<Node>& arg) ...@@ -29,9 +29,6 @@ op::StopGradient::StopGradient(const shared_ptr<Node>& arg)
shared_ptr<Node> op::StopGradient::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::StopGradient::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != 1) check_new_args_count(this, new_args);
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<StopGradient>(new_args.at(0)); return make_shared<StopGradient>(new_args.at(0));
} }
...@@ -28,10 +28,7 @@ op::Subtract::Subtract(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg ...@@ -28,10 +28,7 @@ op::Subtract::Subtract(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg
shared_ptr<Node> op::Subtract::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Subtract::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != 2) check_new_args_count(this, new_args);
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Subtract>(new_args.at(0), new_args.at(1)); return make_shared<Subtract>(new_args.at(0), new_args.at(1));
} }
......
...@@ -28,10 +28,7 @@ op::Sum::Sum(const shared_ptr<Node>& arg, const AxisSet& reduction_axes) ...@@ -28,10 +28,7 @@ op::Sum::Sum(const shared_ptr<Node>& arg, const AxisSet& reduction_axes)
shared_ptr<Node> op::Sum::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Sum::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != 1) check_new_args_count(this, new_args);
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Sum>(new_args.at(0), m_reduction_axes); return make_shared<Sum>(new_args.at(0), m_reduction_axes);
} }
......
...@@ -30,10 +30,7 @@ op::Tan::Tan(const shared_ptr<Node>& arg) ...@@ -30,10 +30,7 @@ op::Tan::Tan(const shared_ptr<Node>& arg)
shared_ptr<Node> op::Tan::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Tan::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != 1) check_new_args_count(this, new_args);
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Tan>(new_args.at(0)); return make_shared<Tan>(new_args.at(0));
} }
......
...@@ -29,10 +29,7 @@ op::Tanh::Tanh(const shared_ptr<Node>& arg) ...@@ -29,10 +29,7 @@ op::Tanh::Tanh(const shared_ptr<Node>& arg)
shared_ptr<Node> op::Tanh::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Tanh::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != 1) check_new_args_count(this, new_args);
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Tanh>(new_args.at(0)); return make_shared<Tanh>(new_args.at(0));
} }
......
...@@ -33,10 +33,10 @@ void op::util::ArithmeticReduction::validate_and_infer_types() ...@@ -33,10 +33,10 @@ void op::util::ArithmeticReduction::validate_and_infer_types()
for (auto axis : m_reduction_axes) for (auto axis : m_reduction_axes)
{ {
if (axis >= input_shape.size()) NODE_VALIDATION_ASSERT(this, axis < input_shape.size())
{ << "Reduction axis (" << axis << ") is out of bounds "
throw ngraph_error("Reduction axis for arithmetic reduction operator is out of bounds"); << "(argument shape: " << input_shape << ", reduction axes: " << m_reduction_axes
} << ")";
} }
Shape result_shape; Shape result_shape;
......
...@@ -31,9 +31,10 @@ op::util::IndexReduction::IndexReduction(const std::string& node_type, ...@@ -31,9 +31,10 @@ op::util::IndexReduction::IndexReduction(const std::string& node_type,
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
auto rank = arg->get_shape().size(); auto rank = arg->get_shape().size();
TYPE_CHECK_ASSERT(this, rank >= 1) << "Tensor's rank must be at least 1"; NODE_VALIDATION_ASSERT(this, rank >= 1) << "Argument rank must be at least 1";
TYPE_CHECK_ASSERT(this, axis < rank) << "Axis " << axis << " is greater than rank of " << rank; NODE_VALIDATION_ASSERT(this, axis < rank) << "Axis " << axis << " is greater than rank of "
TYPE_CHECK_ASSERT(this, << rank;
NODE_VALIDATION_ASSERT(this,
index_element_type == element::i32 || index_element_type == element::i64) index_element_type == element::i32 || index_element_type == element::i64)
<< "Index element type must be i64 or i32"; << "Index element type must be i64 or i32";
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment