Commit 90820076 authored by Tomasz Dołbniak's avatar Tomasz Dołbniak Committed by Michał Karzyński

[ONNX] Where operator support (#2449)

parent bfb48511
......@@ -162,6 +162,7 @@ add_library(onnx_import STATIC
op/transpose.hpp
op/unsqueeze.cpp
op/unsqueeze.hpp
op/where.hpp
op/xor.hpp
ops_bridge.cpp
ops_bridge.hpp
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <memory>
#include "core/node.hpp"
#include "ngraph/node_vector.hpp"
#include "ngraph/op/select.hpp"
#include "utils/broadcasting.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace set_1
{
/// \brief Convert an ONNX Where operation into an nGraph Select node.
///
/// The three inputs (condition, true-branch values, false-branch values)
/// are first numpy-style broadcast to a common shape, since
/// ngraph::op::Select requires shape-compatible arguments.
///
/// \param node The ONNX node wrapping the Where operation.
/// \return A single-element NodeVector holding the resulting Select node.
inline NodeVector where(const Node& node)
{
    const NodeVector broadcasted_inputs{numpy_style_broadcast(node.get_ng_inputs())};
    const auto& condition = broadcasted_inputs.at(0);
    const auto& values_if_true = broadcasted_inputs.at(1);
    const auto& values_if_false = broadcasted_inputs.at(2);
    return {std::make_shared<ngraph::op::Select>(condition, values_if_true, values_if_false)};
}
} // namespace set_1
} //namespace op
} // namespace onnx_import
} // namespace ngraph
......@@ -103,6 +103,7 @@
#include "op/topk.hpp"
#include "op/transpose.hpp"
#include "op/unsqueeze.hpp"
#include "op/where.hpp"
#include "op/xor.hpp"
#include "ops_bridge.hpp"
......@@ -304,6 +305,7 @@ namespace ngraph
REGISTER_OPERATOR("TopK", 1, topk);
REGISTER_OPERATOR("Transpose", 1, transpose);
REGISTER_OPERATOR("Unsqueeze", 1, unsqueeze);
REGISTER_OPERATOR("Where", 1, where);
REGISTER_OPERATOR("Xor", 1, logical_xor);
}
......
......@@ -14,12 +14,14 @@
// limitations under the License.
//*****************************************************************************
#include <cstddef>
#include <iterator>
#include <numeric>
#include <vector>
#include "broadcasting.hpp"
#include "ngraph/axis_vector.hpp"
#include "ngraph/log.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/reshape.hpp"
#include "reshape.hpp"
......@@ -38,15 +40,12 @@ static std::vector<ngraph::Shape> get_numpy_broadcast_shape(ngraph::Shape left_s
auto rank_right = right_shape.size();
auto max_rank = std::max(rank_left, rank_right);
for (auto i = 0; i < (max_rank - rank_left); ++i)
{
left_shape.insert(std::begin(left_shape), 1);
}
for (auto i = 0; i < (max_rank - rank_right); ++i)
{
right_shape.insert(std::begin(right_shape), 1);
}
for (auto index = 0; index < max_rank; ++index)
// left-pad the left_shape with ones
left_shape.insert(std::begin(left_shape), max_rank - rank_left, 1);
// left-pad the right_shape with ones
right_shape.insert(std::begin(right_shape), max_rank - rank_right, 1);
for (std::size_t index = 0; index < max_rank; ++index)
{
output_shape.push_back(std::max(left_shape.at(index), right_shape.at(index)));
}
......@@ -54,22 +53,64 @@ static std::vector<ngraph::Shape> get_numpy_broadcast_shape(ngraph::Shape left_s
return {output_shape, left_shape, right_shape};
}
/// \brief Compute the common numpy-style broadcast shape for a set of input nodes.
///
/// Folds all input shapes into a single target shape (the shape that an
/// element-wise operation applied to all inputs would produce) and, for each
/// input, produces a copy of its shape left-padded with ones up to the rank
/// of the target shape, ready to be fed to the broadcasting routine.
///
/// \param inputs The nodes whose shapes should be made broadcast-compatible.
/// \return A pair of {target shape, per-input rank-padded shapes}.
static std::pair<ngraph::Shape, std::vector<ngraph::Shape>>
    get_numpy_broadcast_shapes(const ngraph::NodeVector& inputs)
{
    // Fold step: combine the shape accumulated so far with the next input's shape.
    // TODO: in a separate PR remove the 'get_numpy_broadcast_shape' function
    const auto fold_broadcast_shape = [](const ngraph::Shape& shape_so_far,
                                         const std::shared_ptr<ngraph::Node>& next_input) {
        return get_numpy_broadcast_shape(shape_so_far, next_input->get_shape()).at(0);
    };

    const ngraph::Shape target_shape = std::accumulate(
        std::begin(inputs), std::end(inputs), ngraph::Shape{}, fold_broadcast_shape);

    std::vector<ngraph::Shape> padded_input_shapes;
    padded_input_shapes.reserve(inputs.size());
    for (const auto& input : inputs)
    {
        ngraph::Shape padded_shape = input->get_shape();
        // Left-pad with ones so that every padded shape has the target rank.
        padded_shape.insert(
            std::begin(padded_shape), target_shape.size() - padded_shape.size(), 1);
        padded_input_shapes.push_back(std::move(padded_shape));
    }

    return {target_shape, padded_input_shapes};
}
/// \brief Broadcast input node.
///
/// \note The source shape does not have to be the actual shape of input node. However
/// it should be a superset of it (containing it as a continuous subset). This implies
/// we may expand the number of axes of input node.
/// The ranks of source_shape and output_shape must be equal. This means that the
/// source_shape has to be padded with ones for this operation.
///
/// \param[in] node The input Node to be broadcasted.
/// \param[in] output_shape The output shape.
/// \param[in] source_shape The source shape from which we want to broadcast input node.
///
/// \return The boroadcasted Node.
/// \return The broadcasted Node.
///
static std::shared_ptr<ngraph::Node> broadcast(const std::shared_ptr<ngraph::Node>& node,
const ngraph::Shape& output_shape,
const ngraph::Shape& source_shape)
static std::shared_ptr<ngraph::Node>
broadcast_node_numpy_style(const std::shared_ptr<ngraph::Node>& node,
const ngraph::Shape& output_shape,
const ngraph::Shape& source_shape)
{
if (source_shape.size() != output_shape.size())
{
NGRAPH_WARN << "Ranks of source_shape and output_shape dont match: " << source_shape.size()
<< " vs " << output_shape.size();
}
ngraph::AxisVector broadcast_axes;
ngraph::Shape squeezed_shape;
// Positions of axes which have length of 1 are needed to calculate broadcast_axes
......@@ -111,8 +152,31 @@ namespace ngraph
auto left_full_shape = numpy_shapes.at(1);
auto right_full_shape = numpy_shapes.at(2);
return {broadcast(left, output_shape, left_full_shape),
broadcast(right, output_shape, right_full_shape)};
return {broadcast_node_numpy_style(left, output_shape, left_full_shape),
broadcast_node_numpy_style(right, output_shape, right_full_shape)};
}
/// \brief Numpy-style broadcast all inputs to a common shape.
///
/// With fewer than two inputs there is nothing to broadcast against, so the
/// inputs are returned untouched. Otherwise the common target shape is
/// computed and each input is broadcast to it.
///
/// \param inputs The nodes to be made shape-compatible.
/// \return The broadcasted nodes, in the same order as the inputs.
NodeVector numpy_style_broadcast(NodeVector inputs)
{
    if (inputs.size() <= 1)
    {
        return inputs;
    }

    // find the output tensor's shape, then broadcast all inputs so that they are compatible
    const auto bcast_shapes = get_numpy_broadcast_shapes(inputs);

    NodeVector broadcasted_inputs;
    broadcasted_inputs.reserve(inputs.size());
    for (std::size_t i = 0; i < inputs.size(); ++i)
    {
        // NOTE: the previous version created an unused shared_ptr copy and an
        // unused 'source_shape' local here; both were dead code and are removed.
        broadcasted_inputs.push_back(broadcast_node_numpy_style(
            inputs[i], bcast_shapes.first, bcast_shapes.second[i]));
    }
    return broadcasted_inputs;
}
NodeVector
......@@ -147,8 +211,8 @@ namespace ngraph
std::next(std::begin(right_shape), right_shape.size() - 2),
std::end(right_shape));
return {broadcast(left, left_output_shape, left_full_shape),
broadcast(right, right_output_shape, right_full_shape)};
return {broadcast_node_numpy_style(left, left_output_shape, left_full_shape),
broadcast_node_numpy_style(right, right_output_shape, right_full_shape)};
}
NodeVector
......@@ -181,8 +245,8 @@ namespace ngraph
}
// Find first dimensions at front with length different from 1
size_t num_ones = 0;
for (size_t dimension : new_right_shape)
std::size_t num_ones = 0;
for (std::size_t dimension : new_right_shape)
{
if (dimension == 1)
{
......@@ -216,7 +280,7 @@ namespace ngraph
const Shape& input_shape,
std::size_t start_match_axis)
{
std::vector<size_t> result(output_shape.size() - input_shape.size());
std::vector<std::size_t> result(output_shape.size() - input_shape.size());
// Populate the result vector with monotonic increasing series from 0 until
// output_shape_size, excluding values in range [start_match_axis, start_match_axis + input_shape.size()
std::iota(std::begin(result), std::begin(result) + start_match_axis, 0);
......
......@@ -47,6 +47,13 @@ namespace ngraph
return numpy_style_broadcast_for_binary_operation(inputs.at(0), inputs.at(1));
}
/// \brief Cast shape of all input nodes for an element-wise operation that requires shape-compatibility
///
/// \param inputs Original list of inputs
///
/// \return Numpy-style broadcasted list of nodes.
NodeVector numpy_style_broadcast(NodeVector inputs);
/// \brief Cast shape of two nodes to make them compatible for an element-wise binary operation.
///
/// If necessary the right-hand-side argument will be broadcast to match the shape
......
ngraph ONNXImporter:±
"
cond cond_bool"Cast*
to  

cond_bool
x1
x2y"Where where_graphZ
cond



Z
x1


Z
x2


b
y



B
\ No newline at end of file
......@@ -1969,3 +1969,34 @@ TEST(onnx_${BACKEND_NAME}, model_sign)
EXPECT_TRUE(test::all_close_f(expected_outputs.front(), outputs.front()));
}
TEST(onnx_${BACKEND_NAME}, model_where)
{
    // import a model containing a single Where node (plus a Cast on the condition)
    auto function =
        onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/where.onnx"));

    using Inputs = std::vector<std::vector<int>>;
    using Outputs = std::vector<std::vector<int>>;

    Inputs inputs;
    // conditions tensor - 3x3x3, alternating zeros and ones
    inputs.emplace_back(std::vector<int>{
        0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0});
    // 1x3 tensor of "1" - selected where the condition is true
    inputs.emplace_back(std::vector<int>{1, 1, 1});
    // 3x1 tensor of "2" - selected where the condition is false
    inputs.emplace_back(std::vector<int>{2, 2, 2});

    // y = 3x3x3
    Outputs expected_outputs{
        {2, 1, 2, 1, 2, 1, 2, 1, 2, 2, 1, 2, 1, 2, 1, 2, 1, 2, 2, 1, 2, 1, 2, 1, 2, 1, 2}};

    Outputs outputs{execute(function, inputs, "${BACKEND_NAME}")};
    EXPECT_EQ(expected_outputs.front(), outputs.front());
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment