Commit 90820076 authored by Tomasz Dołbniak's avatar Tomasz Dołbniak Committed by Michał Karzyński

[ONNX] Where operator support (#2449)

parent bfb48511
...@@ -162,6 +162,7 @@ add_library(onnx_import STATIC ...@@ -162,6 +162,7 @@ add_library(onnx_import STATIC
op/transpose.hpp op/transpose.hpp
op/unsqueeze.cpp op/unsqueeze.cpp
op/unsqueeze.hpp op/unsqueeze.hpp
op/where.hpp
op/xor.hpp op/xor.hpp
ops_bridge.cpp ops_bridge.cpp
ops_bridge.hpp ops_bridge.hpp
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <memory>
#include "core/node.hpp"
#include "ngraph/node_vector.hpp"
#include "ngraph/op/select.hpp"
#include "utils/broadcasting.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace set_1
{
inline NodeVector where(const Node& node)
{
NodeVector ng_inputs{numpy_style_broadcast(node.get_ng_inputs())};
return {std::make_shared<ngraph::op::Select>(
ng_inputs.at(0), ng_inputs.at(1), ng_inputs.at(2))};
}
} // namespace set_1
} //namespace op
} // namespace onnx_import
} // namespace ngraph
...@@ -103,6 +103,7 @@ ...@@ -103,6 +103,7 @@
#include "op/topk.hpp" #include "op/topk.hpp"
#include "op/transpose.hpp" #include "op/transpose.hpp"
#include "op/unsqueeze.hpp" #include "op/unsqueeze.hpp"
#include "op/where.hpp"
#include "op/xor.hpp" #include "op/xor.hpp"
#include "ops_bridge.hpp" #include "ops_bridge.hpp"
...@@ -304,6 +305,7 @@ namespace ngraph ...@@ -304,6 +305,7 @@ namespace ngraph
REGISTER_OPERATOR("TopK", 1, topk); REGISTER_OPERATOR("TopK", 1, topk);
REGISTER_OPERATOR("Transpose", 1, transpose); REGISTER_OPERATOR("Transpose", 1, transpose);
REGISTER_OPERATOR("Unsqueeze", 1, unsqueeze); REGISTER_OPERATOR("Unsqueeze", 1, unsqueeze);
REGISTER_OPERATOR("Where", 1, where);
REGISTER_OPERATOR("Xor", 1, logical_xor); REGISTER_OPERATOR("Xor", 1, logical_xor);
} }
......
...@@ -14,12 +14,14 @@ ...@@ -14,12 +14,14 @@
// limitations under the License. // limitations under the License.
//***************************************************************************** //*****************************************************************************
#include <cstddef>
#include <iterator> #include <iterator>
#include <numeric> #include <numeric>
#include <vector> #include <vector>
#include "broadcasting.hpp" #include "broadcasting.hpp"
#include "ngraph/axis_vector.hpp" #include "ngraph/axis_vector.hpp"
#include "ngraph/log.hpp"
#include "ngraph/op/broadcast.hpp" #include "ngraph/op/broadcast.hpp"
#include "ngraph/op/reshape.hpp" #include "ngraph/op/reshape.hpp"
#include "reshape.hpp" #include "reshape.hpp"
...@@ -38,15 +40,12 @@ static std::vector<ngraph::Shape> get_numpy_broadcast_shape(ngraph::Shape left_s ...@@ -38,15 +40,12 @@ static std::vector<ngraph::Shape> get_numpy_broadcast_shape(ngraph::Shape left_s
auto rank_right = right_shape.size(); auto rank_right = right_shape.size();
auto max_rank = std::max(rank_left, rank_right); auto max_rank = std::max(rank_left, rank_right);
for (auto i = 0; i < (max_rank - rank_left); ++i) // left-pad the left_shape with ones
{ left_shape.insert(std::begin(left_shape), max_rank - rank_left, 1);
left_shape.insert(std::begin(left_shape), 1); // left-pad the right_shape with ones
} right_shape.insert(std::begin(right_shape), max_rank - rank_right, 1);
for (auto i = 0; i < (max_rank - rank_right); ++i)
{ for (std::size_t index = 0; index < max_rank; ++index)
right_shape.insert(std::begin(right_shape), 1);
}
for (auto index = 0; index < max_rank; ++index)
{ {
output_shape.push_back(std::max(left_shape.at(index), right_shape.at(index))); output_shape.push_back(std::max(left_shape.at(index), right_shape.at(index)));
} }
...@@ -54,22 +53,64 @@ static std::vector<ngraph::Shape> get_numpy_broadcast_shape(ngraph::Shape left_s ...@@ -54,22 +53,64 @@ static std::vector<ngraph::Shape> get_numpy_broadcast_shape(ngraph::Shape left_s
return {output_shape, left_shape, right_shape}; return {output_shape, left_shape, right_shape};
} }
/// \brief Calculate the output shape of numpy-style broadcast operation for all input nodes.
///
/// This function finds the maximum tensor shape that will be the result of element-wise operation
/// that will be applied to the inputs vector. The function also prepares the shape of each input
/// for the element-wise operation by left-padding those shapes so that their rank is equal to
/// the target_shape's rank.
///
/// \param inputs A vector of input nodes for which a common shape should be found
/// \return A pair that contains the target shape as its first object and a vector of padded
/// input shapes ready to be broadcasted as the second object
static std::pair<ngraph::Shape, std::vector<ngraph::Shape>>
get_numpy_broadcast_shapes(const ngraph::NodeVector& inputs)
{
auto shape_left_fold = [](const ngraph::Shape& accumulator,
const std::shared_ptr<ngraph::Node>& input) {
// TODO: in a separate PR remove the 'get_numpy_broadcast_shape' function
return get_numpy_broadcast_shape(accumulator, input->get_shape()).at(0);
};
ngraph::Shape target_shape =
std::accumulate(std::begin(inputs), std::end(inputs), ngraph::Shape{}, shape_left_fold);
std::vector<ngraph::Shape> full_shapes;
for (const std::shared_ptr<ngraph::Node>& input : inputs)
{
ngraph::Shape padded_shape = input->get_shape();
padded_shape.insert(std::begin(padded_shape), target_shape.size() - padded_shape.size(), 1);
full_shapes.push_back(std::move(padded_shape));
}
return {target_shape, full_shapes};
}
/// \brief Broadcast input node. /// \brief Broadcast input node.
/// ///
/// \note The source shape does not have to be the actual shape of input node. However /// \note The source shape does not have to be the actual shape of input node. However
/// it should be a superset of it (containing it as a continuous subset). This implies /// it should be a superset of it (containing it as a continuous subset). This implies
/// we may expand the number of axes of input node. /// we may expand the number of axes of input node.
/// The ranks of source_shape and output_shape must be equal. This means that the
/// source_shape has to be padded with ones for this operation.
/// ///
/// \param[in] node The input Node to be broadcasted. /// \param[in] node The input Node to be broadcasted.
/// \param[in] output_shape The output shape. /// \param[in] output_shape The output shape.
/// \param[in] source_shape The source shape from which we want to broadcast input node. /// \param[in] source_shape The source shape from which we want to broadcast input node.
/// ///
/// \return The boroadcasted Node. /// \return The broadcasted Node.
/// ///
static std::shared_ptr<ngraph::Node> broadcast(const std::shared_ptr<ngraph::Node>& node, static std::shared_ptr<ngraph::Node>
const ngraph::Shape& output_shape, broadcast_node_numpy_style(const std::shared_ptr<ngraph::Node>& node,
const ngraph::Shape& source_shape) const ngraph::Shape& output_shape,
const ngraph::Shape& source_shape)
{ {
if (source_shape.size() != output_shape.size())
{
NGRAPH_WARN << "Ranks of source_shape and output_shape dont match: " << source_shape.size()
<< " vs " << output_shape.size();
}
ngraph::AxisVector broadcast_axes; ngraph::AxisVector broadcast_axes;
ngraph::Shape squeezed_shape; ngraph::Shape squeezed_shape;
// Positions of axes which have length of 1 are needed to calculate broadcast_axes // Positions of axes which have length of 1 are needed to calculate broadcast_axes
...@@ -111,8 +152,31 @@ namespace ngraph ...@@ -111,8 +152,31 @@ namespace ngraph
auto left_full_shape = numpy_shapes.at(1); auto left_full_shape = numpy_shapes.at(1);
auto right_full_shape = numpy_shapes.at(2); auto right_full_shape = numpy_shapes.at(2);
return {broadcast(left, output_shape, left_full_shape), return {broadcast_node_numpy_style(left, output_shape, left_full_shape),
broadcast(right, output_shape, right_full_shape)}; broadcast_node_numpy_style(right, output_shape, right_full_shape)};
}
NodeVector numpy_style_broadcast(NodeVector inputs)
{
if (inputs.size() <= 1)
{
return inputs;
}
// find the output tensor's shape, then broadcast all inputs so that they are compatible
auto bcast_shapes = get_numpy_broadcast_shapes(inputs);
NodeVector broadcasted_inputs;
for (std::size_t i = 0; i < inputs.size(); ++i)
{
const std::shared_ptr<ngraph::Node> input_node = inputs[i];
Shape source_shape = input_node->get_shape();
broadcasted_inputs.push_back(broadcast_node_numpy_style(
inputs[i], bcast_shapes.first, bcast_shapes.second[i]));
}
return broadcasted_inputs;
} }
NodeVector NodeVector
...@@ -147,8 +211,8 @@ namespace ngraph ...@@ -147,8 +211,8 @@ namespace ngraph
std::next(std::begin(right_shape), right_shape.size() - 2), std::next(std::begin(right_shape), right_shape.size() - 2),
std::end(right_shape)); std::end(right_shape));
return {broadcast(left, left_output_shape, left_full_shape), return {broadcast_node_numpy_style(left, left_output_shape, left_full_shape),
broadcast(right, right_output_shape, right_full_shape)}; broadcast_node_numpy_style(right, right_output_shape, right_full_shape)};
} }
NodeVector NodeVector
...@@ -181,8 +245,8 @@ namespace ngraph ...@@ -181,8 +245,8 @@ namespace ngraph
} }
// Find first dimensions at front with length different from 1 // Find first dimensions at front with length different from 1
size_t num_ones = 0; std::size_t num_ones = 0;
for (size_t dimension : new_right_shape) for (std::size_t dimension : new_right_shape)
{ {
if (dimension == 1) if (dimension == 1)
{ {
...@@ -216,7 +280,7 @@ namespace ngraph ...@@ -216,7 +280,7 @@ namespace ngraph
const Shape& input_shape, const Shape& input_shape,
std::size_t start_match_axis) std::size_t start_match_axis)
{ {
std::vector<size_t> result(output_shape.size() - input_shape.size()); std::vector<std::size_t> result(output_shape.size() - input_shape.size());
// Populate the result vector with monotonic increasing series from 0 until // Populate the result vector with monotonic increasing series from 0 until
// output_shape_size, excluding values in range [start_match_axis, start_match_axis + input_shape.size() // output_shape_size, excluding values in range [start_match_axis, start_match_axis + input_shape.size()
std::iota(std::begin(result), std::begin(result) + start_match_axis, 0); std::iota(std::begin(result), std::begin(result) + start_match_axis, 0);
......
...@@ -47,6 +47,13 @@ namespace ngraph ...@@ -47,6 +47,13 @@ namespace ngraph
return numpy_style_broadcast_for_binary_operation(inputs.at(0), inputs.at(1)); return numpy_style_broadcast_for_binary_operation(inputs.at(0), inputs.at(1));
} }
/// \brief Cast shape of all input nodes for an element-wise operation that requires shape-compatibility
///
/// \param inputs Original list of inputs
///
/// \return Numpy-style broadcasted list of nodes.
NodeVector numpy_style_broadcast(NodeVector inputs);
/// \brief Cast shape of two nodes to make them compatible for an element-wise binary operation. /// \brief Cast shape of two nodes to make them compatible for an element-wise binary operation.
/// ///
/// If necessary the right-hand-side argument will be broadcast to match the shape /// If necessary the right-hand-side argument will be broadcast to match the shape
......
ngraph ONNXImporter:±
"
cond cond_bool"Cast*
to  

cond_bool
x1
x2y"Where where_graphZ
cond



Z
x1


Z
x2


b
y



B
\ No newline at end of file
...@@ -1969,3 +1969,34 @@ TEST(onnx_${BACKEND_NAME}, model_sign) ...@@ -1969,3 +1969,34 @@ TEST(onnx_${BACKEND_NAME}, model_sign)
EXPECT_TRUE(test::all_close_f(expected_outputs.front(), outputs.front())); EXPECT_TRUE(test::all_close_f(expected_outputs.front(), outputs.front()));
} }
TEST(onnx_${BACKEND_NAME}, model_where)
{
auto function =
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/where.onnx"));
using Inputs = std::vector<std::vector<int>>;
using Outputs = std::vector<std::vector<int>>;
// conditions tensor - 3x3x3
auto condition = std::vector<int>{
{0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0}};
// 1x3 tensor of "1"
auto x1 = std::vector<int>{1, 1, 1};
// 3x1 tensor of "2"
auto x2 = std::vector<int>{2, 2, 2};
Inputs inputs;
inputs.push_back(std::move(condition));
inputs.push_back(std::move(x1));
inputs.push_back(std::move(x2));
// y = 3x3x3
Outputs expected_outputs{
{2, 1, 2, 1, 2, 1, 2, 1, 2, 2, 1, 2, 1, 2, 1, 2, 1, 2, 2, 1, 2, 1, 2, 1, 2, 1, 2}};
Outputs outputs{execute(function, inputs, "${BACKEND_NAME}")};
EXPECT_EQ(expected_outputs.front(), outputs.front());
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment