Commit a8ce39d6 authored by Adam Rogowiec's avatar Adam Rogowiec Committed by Scott Cyphers

[ONNX] Variadic operators in opset 8. (#2261)

* Add broadcasting to variadic OP in opset 8.

* Apply style format.

* Update onnx_import.cpp
parent f5b2d581
......@@ -36,6 +36,15 @@ namespace ngraph
} // namespace set_1
namespace set_8
{
inline NodeVector max(const Node& node)
{
return variadic::make_ng_variadic_op_with_broadcast<ngraph::op::Maximum>(node);
}
} // namespace set_8
} //namespace op
} // namespace onnx_import
......
......@@ -44,6 +44,25 @@ namespace ngraph
} // namespace set_1
namespace set_8
{
NodeVector mean(const Node& node)
{
auto sum =
variadic::make_ng_variadic_op_with_broadcast<ngraph::op::Add>(node).front();
auto shape = sum->get_shape();
// Create a Constant representing the number of inputs with the same shape as sum
auto count = ngraph::op::Constant::create(
sum->get_element_type(),
shape,
std::vector<int>(shape_size(shape), node.get_ng_inputs().size()));
return {sum / count};
}
} // namespace set_8
} //namespace op
} // namespace onnx_import
......
......@@ -31,6 +31,12 @@ namespace ngraph
} // namespace set_1
namespace set_8
{
NodeVector mean(const Node& node);
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -36,6 +36,15 @@ namespace ngraph
} // namespace set_1
namespace set_8
{
inline NodeVector min(const Node& node)
{
return variadic::make_ng_variadic_op_with_broadcast<ngraph::op::Minimum>(node);
}
} // namespace set_8
} //namespace op
} // namespace onnx_import
......
......@@ -36,6 +36,15 @@ namespace ngraph
} // namespace set_1
namespace set_8
{
inline NodeVector sum(const Node& node)
{
return variadic::make_ng_variadic_op_with_broadcast<ngraph::op::Add>(node);
}
} // namespace set_8
} //namespace op
} // namespace onnx_import
......
......@@ -190,8 +190,11 @@ namespace ngraph
REGISTER_OPERATOR("MatMul", 1, matmul);
REGISTER_OPERATOR("MaxPool", 1, max_pool);
REGISTER_OPERATOR("Max", 1, max);
REGISTER_OPERATOR("Max", 8, max);
REGISTER_OPERATOR("Mean", 1, mean);
REGISTER_OPERATOR("Mean", 8, mean);
REGISTER_OPERATOR("Min", 1, min);
REGISTER_OPERATOR("Min", 8, min);
REGISTER_OPERATOR("Mul", 1, mul);
REGISTER_OPERATOR("Mul", 7, mul);
REGISTER_OPERATOR("Neg", 1, neg);
......@@ -228,6 +231,7 @@ namespace ngraph
REGISTER_OPERATOR("Sub", 1, sub);
REGISTER_OPERATOR("Sub", 7, sub);
REGISTER_OPERATOR("Sum", 1, sum);
REGISTER_OPERATOR("Sum", 8, sum);
REGISTER_OPERATOR("Tan", 1, tan);
REGISTER_OPERATOR("Tanh", 1, tanh);
REGISTER_OPERATOR("ThresholdedRelu", 1, thresholded_relu);
......
......@@ -24,6 +24,7 @@
#include "ngraph/node_vector.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/shape.hpp"
#include "utils/broadcasting.hpp"
namespace ngraph
{
......@@ -34,8 +35,10 @@ namespace ngraph
/// \brief Create an nGraph version of an ONNX variadic operation.
/// This creates a subgraph with a series of binary operations.
///
/// \tparam T Class of an nGraph binary operation (e.g. Add, Minimum, Maximum)
/// \param node incoming ONNX opearation
/// \param node Incoming ONNX opearation.
///
/// \tparam T Class of an nGraph binary operation (e.g. Add, Minimum, Maximum)
///
/// \return nGraph node equivalent of the ONNX operation
template <class T>
inline NodeVector make_ng_variadic_op(const Node& node)
......@@ -58,6 +61,36 @@ namespace ngraph
return {result};
}
/// \brief Create an nGraph version of an ONNX variadic operation.
/// This creates a subgraph with a series of binary operations.
///
/// \param node Incoming ONNX opearation.
///
/// \tparam T Class of an nGraph binary operation (e.g. Add, Minimum, Maximum)
///
/// \return nGraph node equivalent of the ONNX operation
template <class T>
inline NodeVector make_ng_variadic_op_with_broadcast(const Node& node)
{
NodeVector ng_inputs{node.get_ng_inputs()};
// Templated binary operation - Creates Add, Minimum, Maximum, etc.
auto binary_operation = [](const std::shared_ptr<ngraph::Node>& arg0,
const std::shared_ptr<ngraph::Node>& arg1) {
NodeVector args{numpy_style_broadcast_for_binary_operation(arg0, arg1)};
return std::make_shared<T>(args.at(0), args.at(1));
};
// Create a result node as a series of binary operations
auto result = std::accumulate(
std::next(std::begin(ng_inputs)), // First operand value - the second input
std::end(ng_inputs), // Last value - final input
ng_inputs.front(), // Initial value - first input
binary_operation);
return {result};
}
} // namespace variadic
} // namespace onnx_import
......
......@@ -1535,3 +1535,25 @@ TEST(onnx, model_matmul_vec_ten3d)
Outputs outputs{execute(function, inputs, "INTERPRETER")};
EXPECT_TRUE(test::all_close_f(expected_output.front(), outputs.front()));
}
TEST(onnx, model_sum_opset8)
{
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/sum_opset8.onnx"));
Inputs inputs;
inputs.emplace_back(std::vector<float>{1.0f, 2.0f, 3.0f});
inputs.emplace_back(test::NDArray<float, 2>{{10.0f}, {20.0f}, {30.0f}}.get_vector());
inputs.emplace_back(test::NDArray<float, 3>{{{100.0f}}, {{200.0f}}, {{300.0f}}}.get_vector());
Outputs expected_output{test::NDArray<float, 3>{
{{111.0f, 112.0f, 113.0f}, {121.0f, 122.0f, 123.0f}, {131.0f, 132.0f, 133.0f}},
{{211.0f, 212.0f, 213.0f}, {221.0f, 222.0f, 223.0f}, {231.0f, 232.0f, 233.0f}},
{{311.0f, 312.0f, 313.0f}, {321.0f, 322.0f, 323.0f}, {331.0f, 332.0f, 333.0f}}}
.get_vector()};
Outputs outputs{execute(function, inputs, "INTERPRETER")};
EXPECT_TRUE(test::all_close_f(expected_output.front(), outputs.front()));
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment