Commit 1d1d6633 authored by Robert Kimball's avatar Robert Kimball Committed by Scott Cyphers

simplify argument type check (#398)

parent bd01bf2c
...@@ -20,12 +20,10 @@ ...@@ -20,12 +20,10 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::BinaryElementwise::BinaryElementwise( op::BinaryElementwise::BinaryElementwise(const std::string& node_type,
const std::string& node_type, const element::Type& result_element_type,
std::function<const element::Type&(const element::Type&, const element::Type&)> const std::shared_ptr<Node>& arg0,
element_type_function, const std::shared_ptr<Node>& arg1)
const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
: RequiresTensorViewArgs(node_type, Nodes{arg0, arg1}) : RequiresTensorViewArgs(node_type, Nodes{arg0, arg1})
{ {
auto& input_0 = get_inputs().at(0); auto& input_0 = get_inputs().at(0);
...@@ -35,8 +33,5 @@ op::BinaryElementwise::BinaryElementwise( ...@@ -35,8 +33,5 @@ op::BinaryElementwise::BinaryElementwise(
throw ngraph_error("Arguments must have the same tensor view shape"); throw ngraph_error("Arguments must have the same tensor view shape");
} }
const element::Type& result_element_type =
element_type_function(input_0.get_element_type(), input_1.get_element_type());
set_value_type_checked(make_shared<TensorViewType>(result_element_type, input_0.get_shape())); set_value_type_checked(make_shared<TensorViewType>(result_element_type, input_0.get_shape()));
} }
...@@ -20,24 +20,15 @@ using namespace ngraph; ...@@ -20,24 +20,15 @@ using namespace ngraph;
op::BinaryElementwiseArithmetic::BinaryElementwiseArithmetic(const std::string& node_type, op::BinaryElementwiseArithmetic::BinaryElementwiseArithmetic(const std::string& node_type,
const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1) const std::shared_ptr<Node>& arg1)
: BinaryElementwise( : BinaryElementwise(node_type, arg0->get_element_type(), arg0, arg1)
node_type,
[](const element::Type& arg0_element_type,
const element::Type& arg1_element_type) -> const element::Type& {
if (arg0_element_type != arg1_element_type)
{
throw ngraph_error("Arguments must have the same tensor view element type");
}
if (arg0_element_type == element::boolean)
{
throw ngraph_error(
"Operands for arithmetic operators must have numeric element type");
}
return arg0_element_type;
},
arg0,
arg1)
{ {
if (arg0->get_element_type() != arg1->get_element_type())
{
throw ngraph_error("Arguments must have the same tensor view element type");
}
if (arg0->get_element_type() == element::boolean)
{
throw ngraph_error("Operands for arithmetic operators must have numeric element type");
}
} }
...@@ -21,18 +21,10 @@ using namespace ngraph; ...@@ -21,18 +21,10 @@ using namespace ngraph;
op::BinaryElementwiseComparison::BinaryElementwiseComparison(const std::string& node_type, op::BinaryElementwiseComparison::BinaryElementwiseComparison(const std::string& node_type,
const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1) const std::shared_ptr<Node>& arg1)
: BinaryElementwise(node_type, : BinaryElementwise(node_type, element::boolean, arg0, arg1)
[](const element::Type& arg0_element_type,
const element::Type& arg1_element_type) -> const element::Type& {
if (arg0_element_type != arg1_element_type)
{
throw ngraph_error(
"Arguments must have the same tensor view element type");
}
return element::boolean;
},
arg0,
arg1)
{ {
if (arg0->get_element_type() != arg1->get_element_type())
{
throw ngraph_error("Arguments must have the same tensor view element type");
}
} }
...@@ -20,11 +20,7 @@ using namespace std; ...@@ -20,11 +20,7 @@ using namespace std;
using namespace ngraph; using namespace ngraph;
op::Convert::Convert(const std::shared_ptr<Node>& arg, const ngraph::element::Type& element_type) op::Convert::Convert(const std::shared_ptr<Node>& arg, const ngraph::element::Type& element_type)
: UnaryElementwise("Convert", : UnaryElementwise("Convert", element_type, arg)
[&](const ngraph::element::Type& ignored) -> const ngraph::element::Type& {
return element_type;
},
arg)
, m_element_type(element_type) , m_element_type(element_type)
{ {
} }
......
...@@ -19,19 +19,7 @@ using namespace ngraph; ...@@ -19,19 +19,7 @@ using namespace ngraph;
using namespace ngraph::op; using namespace ngraph::op;
op::Not::Not(const std::shared_ptr<Node>& arg) op::Not::Not(const std::shared_ptr<Node>& arg)
: UnaryElementwise( : UnaryElementwise("Not", arg->get_element_type(), arg)
"Not",
[](const ngraph::element::Type& arg_element_type) -> const ngraph::element::Type& {
if (arg_element_type != element::boolean)
{
throw ngraph_error(
"Operands for logical operators must have boolean element "
"type");
}
return arg_element_type;
},
arg)
{ {
} }
......
...@@ -60,10 +60,9 @@ namespace ngraph ...@@ -60,10 +60,9 @@ namespace ngraph
/// \brief Constructs a unary elementwise tensor operation. /// \brief Constructs a unary elementwise tensor operation.
/// ///
/// \param arg Node that produces the input tensor. /// \param arg Node that produces the input tensor.
UnaryElementwise( UnaryElementwise(const std::string& node_type,
const std::string& node_type, const element::Type& result_element_type,
std::function<const element::Type&(const element::Type&)> element_type_function, const std::shared_ptr<Node>& arg);
const std::shared_ptr<Node>& arg);
}; };
/// \brief Abstract base class for elementwise unary arithmetic operations, i.e., operations where the same /// \brief Abstract base class for elementwise unary arithmetic operations, i.e., operations where the same
...@@ -119,12 +118,10 @@ namespace ngraph ...@@ -119,12 +118,10 @@ namespace ngraph
/// ///
/// \param arg0 Node that produces the first input tensor. /// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor. /// \param arg1 Node that produces the second input tensor.
BinaryElementwise( BinaryElementwise(const std::string& node_type,
const std::string& node_type, const element::Type& result_element_type,
std::function<const element::Type&(const element::Type&, const element::Type&)> const std::shared_ptr<Node>& arg0,
element_type_function, const std::shared_ptr<Node>& arg1);
const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1);
}; };
/// \brief Abstract base class for elementwise binary comparison operations, i.e., operations where the same /// \brief Abstract base class for elementwise binary comparison operations, i.e., operations where the same
......
...@@ -19,14 +19,12 @@ ...@@ -19,14 +19,12 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::UnaryElementwise::UnaryElementwise( op::UnaryElementwise::UnaryElementwise(const std::string& node_type,
const std::string& node_type, const element::Type& result_element_type,
std::function<const element::Type&(const element::Type&)> element_type_function, const std::shared_ptr<Node>& arg)
const std::shared_ptr<Node>& arg)
: RequiresTensorViewArgs(node_type, Nodes{arg}) : RequiresTensorViewArgs(node_type, Nodes{arg})
{ {
auto& input = get_inputs().at(0); auto& input = get_inputs().at(0);
const element::Type& result_element_type = element_type_function(input.get_element_type());
set_value_type_checked(result_element_type, input.get_shape()); set_value_type_checked(result_element_type, input.get_shape());
} }
...@@ -18,18 +18,10 @@ using namespace ngraph; ...@@ -18,18 +18,10 @@ using namespace ngraph;
op::UnaryElementwiseArithmetic::UnaryElementwiseArithmetic(const std::string& node_type, op::UnaryElementwiseArithmetic::UnaryElementwiseArithmetic(const std::string& node_type,
const std::shared_ptr<Node>& arg) const std::shared_ptr<Node>& arg)
: UnaryElementwise( : UnaryElementwise(node_type, arg->get_element_type(), arg)
node_type,
[](const ngraph::element::Type& arg_element_type) -> const ngraph::element::Type& {
if (arg_element_type == element::boolean)
{
throw ngraph_error(
"Operands for arithmetic operators must have numeric element "
"type");
}
return arg_element_type;
},
arg)
{ {
if (arg->get_element_type() == element::boolean)
{
throw ngraph_error("Operands for arithmetic operators must have numeric element type");
}
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment