Commit 1d1d6633 authored by Robert Kimball's avatar Robert Kimball Committed by Scott Cyphers

simplify argument type check (#398)

parent bd01bf2c
......@@ -20,12 +20,10 @@
using namespace std;
using namespace ngraph;
op::BinaryElementwise::BinaryElementwise(
const std::string& node_type,
std::function<const element::Type&(const element::Type&, const element::Type&)>
element_type_function,
const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
op::BinaryElementwise::BinaryElementwise(const std::string& node_type,
const element::Type& result_element_type,
const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
: RequiresTensorViewArgs(node_type, Nodes{arg0, arg1})
{
auto& input_0 = get_inputs().at(0);
......@@ -35,8 +33,5 @@ op::BinaryElementwise::BinaryElementwise(
throw ngraph_error("Arguments must have the same tensor view shape");
}
const element::Type& result_element_type =
element_type_function(input_0.get_element_type(), input_1.get_element_type());
set_value_type_checked(make_shared<TensorViewType>(result_element_type, input_0.get_shape()));
}
......@@ -20,24 +20,15 @@ using namespace ngraph;
op::BinaryElementwiseArithmetic::BinaryElementwiseArithmetic(const std::string& node_type,
const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
: BinaryElementwise(
node_type,
[](const element::Type& arg0_element_type,
const element::Type& arg1_element_type) -> const element::Type& {
if (arg0_element_type != arg1_element_type)
{
throw ngraph_error("Arguments must have the same tensor view element type");
}
if (arg0_element_type == element::boolean)
{
throw ngraph_error(
"Operands for arithmetic operators must have numeric element type");
}
return arg0_element_type;
},
arg0,
arg1)
: BinaryElementwise(node_type, arg0->get_element_type(), arg0, arg1)
{
if (arg0->get_element_type() != arg1->get_element_type())
{
throw ngraph_error("Arguments must have the same tensor view element type");
}
if (arg0->get_element_type() == element::boolean)
{
throw ngraph_error("Operands for arithmetic operators must have numeric element type");
}
}
......@@ -21,18 +21,10 @@ using namespace ngraph;
op::BinaryElementwiseComparison::BinaryElementwiseComparison(const std::string& node_type,
const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
: BinaryElementwise(node_type,
[](const element::Type& arg0_element_type,
const element::Type& arg1_element_type) -> const element::Type& {
if (arg0_element_type != arg1_element_type)
{
throw ngraph_error(
"Arguments must have the same tensor view element type");
}
return element::boolean;
},
arg0,
arg1)
: BinaryElementwise(node_type, element::boolean, arg0, arg1)
{
if (arg0->get_element_type() != arg1->get_element_type())
{
throw ngraph_error("Arguments must have the same tensor view element type");
}
}
......@@ -20,11 +20,7 @@ using namespace std;
using namespace ngraph;
op::Convert::Convert(const std::shared_ptr<Node>& arg, const ngraph::element::Type& element_type)
: UnaryElementwise("Convert",
[&](const ngraph::element::Type& ignored) -> const ngraph::element::Type& {
return element_type;
},
arg)
: UnaryElementwise("Convert", element_type, arg)
, m_element_type(element_type)
{
}
......
......@@ -19,19 +19,7 @@ using namespace ngraph;
using namespace ngraph::op;
op::Not::Not(const std::shared_ptr<Node>& arg)
: UnaryElementwise(
"Not",
[](const ngraph::element::Type& arg_element_type) -> const ngraph::element::Type& {
if (arg_element_type != element::boolean)
{
throw ngraph_error(
"Operands for logical operators must have boolean element "
"type");
}
return arg_element_type;
},
arg)
: UnaryElementwise("Not", arg->get_element_type(), arg)
{
}
......
......@@ -60,10 +60,9 @@ namespace ngraph
/// \brief Constructs a unary elementwise tensor operation.
///
/// \param arg Node that produces the input tensor.
UnaryElementwise(
const std::string& node_type,
std::function<const element::Type&(const element::Type&)> element_type_function,
const std::shared_ptr<Node>& arg);
UnaryElementwise(const std::string& node_type,
const element::Type& result_element_type,
const std::shared_ptr<Node>& arg);
};
/// \brief Abstract base class for elementwise unary arithmetic operations, i.e., operations where the same
......@@ -119,12 +118,10 @@ namespace ngraph
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
BinaryElementwise(
const std::string& node_type,
std::function<const element::Type&(const element::Type&, const element::Type&)>
element_type_function,
const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1);
BinaryElementwise(const std::string& node_type,
const element::Type& result_element_type,
const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1);
};
/// \brief Abstract base class for elementwise binary comparison operations, i.e., operations where the same
......
......@@ -19,14 +19,12 @@
using namespace std;
using namespace ngraph;
op::UnaryElementwise::UnaryElementwise(
const std::string& node_type,
std::function<const element::Type&(const element::Type&)> element_type_function,
const std::shared_ptr<Node>& arg)
op::UnaryElementwise::UnaryElementwise(const std::string& node_type,
const element::Type& result_element_type,
const std::shared_ptr<Node>& arg)
: RequiresTensorViewArgs(node_type, Nodes{arg})
{
auto& input = get_inputs().at(0);
const element::Type& result_element_type = element_type_function(input.get_element_type());
set_value_type_checked(result_element_type, input.get_shape());
}
......@@ -18,18 +18,10 @@ using namespace ngraph;
op::UnaryElementwiseArithmetic::UnaryElementwiseArithmetic(const std::string& node_type,
const std::shared_ptr<Node>& arg)
: UnaryElementwise(
node_type,
[](const ngraph::element::Type& arg_element_type) -> const ngraph::element::Type& {
if (arg_element_type == element::boolean)
{
throw ngraph_error(
"Operands for arithmetic operators must have numeric element "
"type");
}
return arg_element_type;
},
arg)
: UnaryElementwise(node_type, arg->get_element_type(), arg)
{
if (arg->get_element_type() == element::boolean)
{
throw ngraph_error("Operands for arithmetic operators must have numeric element type");
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment