Commit 9afbc891 authored by tsocha's avatar tsocha Committed by Michał Karzyński

[ONNX] Unsqueeze operator (#1521)

parent 1a7e2583
......@@ -62,6 +62,8 @@ add_library(onnx_import STATIC
op/split.hpp
op/sub.hpp
op/sum.hpp
op/unsqueeze.cpp
op/unsqueeze.hpp
ops_bridge.cpp
utils/broadcasting.cpp
utils/broadcasting.hpp
......
//*****************************************************************************
// Copyright 2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <numeric>

#include "ngraph/op/reshape.hpp"

#include "exceptions.hpp"
#include "unsqueeze.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
/// \brief Converts an ONNX Unsqueeze node into an nGraph Reshape node.
///
/// Inserts a dimension of size 1 at every position listed in the node's
/// mandatory "axes" attribute. Per the ONNX specification, axis indices
/// refer to the OUTPUT shape (rank = input rank + number of axes).
///
/// \param node  ONNX node carrying the input tensor and the "axes" attribute.
/// \return      A NodeVector holding a single Reshape producing the
///              unsqueezed tensor.
/// \throws error::parameter::Value if "axes" is missing/empty or an axis is
///         out of range.
NodeVector unsqueeze(const Node& node)
{
    NodeVector inputs{node.get_ng_inputs()};
    auto data = inputs.at(0);
    auto data_shape = data->get_shape();
    auto axes = node.get_attribute_value<std::vector<int64_t>>("axes");
    if (axes.empty())
    {
        throw error::parameter::Value(
            "Unsqueeze", node.get_name(), "axes attribute is mandatory.");
    }
    // Axes index into the OUTPUT shape, which has more dimensions than the
    // input. Sort ASCENDING so that each insertion happens into the shape as
    // grown by the previous insertions; then an axis is valid exactly when
    // 0 <= axis <= current shape size. (A descending sort would wrongly
    // reject valid axes >= the input rank — e.g. axes {0, 4} on a rank-3
    // input — and, unchecked, would insert past the end of the vector.)
    std::sort(std::begin(axes), std::end(axes));
    // Identity input_order (0,1,2,...) for Reshape: element order is kept.
    AxisVector input_order(data_shape.size());
    std::iota(std::begin(input_order), std::end(input_order), 0);
    for (auto axis : axes)
    {
        // Cast avoids a signed/unsigned comparison; axis < 0 is handled first.
        if ((axis < 0) || (static_cast<std::size_t>(axis) > data_shape.size()))
        {
            throw error::parameter::Value(
                "Unsqueeze", node.get_name(), "provided axes attribute is not valid.");
        }
        data_shape.insert(std::next(std::begin(data_shape), axis), 1);
    }
    return {std::make_shared<ngraph::op::Reshape>(data, input_order, data_shape)};
}
} // namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/node_vector.hpp"
#include "core/node.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
/// \brief ONNX Unsqueeze operator: inserts size-1 dimensions into the input
/// tensor's shape at the positions given by the node's "axes" attribute.
NodeVector unsqueeze(const Node& node);
} // namespace op
} // namespace onnx_import
} // namespace ngraph
......@@ -38,6 +38,7 @@
#include "op/split.hpp"
#include "op/sub.hpp"
#include "op/sum.hpp"
#include "op/unsqueeze.hpp"
#include "ops_bridge.hpp"
namespace ngraph
......@@ -104,6 +105,7 @@ namespace ngraph
m_map.emplace("Split", std::bind(op::split, std::placeholders::_1));
m_map.emplace("Sub", std::bind(op::sub, std::placeholders::_1));
m_map.emplace("Sum", std::bind(op::sum, std::placeholders::_1));
m_map.emplace("Unsqueeze", std::bind(op::unsqueeze, std::placeholders::_1));
}
NodeVector operator()(const Node& node) const
......
......@@ -536,6 +536,29 @@ TEST(onnx, model_sub)
EXPECT_TRUE(test::all_close_f(expected_output, result_vectors.front()));
}
// Imports onnx/unsqueeze.onnx and checks the result on the INTERPRETER
// backend. The input is a rank-3 (3,4,5) tensor of ones; the expected output
// is rank-4 (1,3,4,5) — presumably the model unsqueezes at axis 0; confirm
// against the serialized model. NOTE(review): since every element is 1 and
// all_close_f compares flattened vectors, this only verifies element count
// and values, not the actual output shape.
TEST(onnx, model_unsqueeze)
{
auto function = onnx_import::import_onnx_function(
file_util::path_join(SERIALIZED_ZOO, "onnx/unsqueeze.onnx"));
// Rank-3 input: 3 x 4 x 5 tensor filled with ones.
Inputs inputs;
inputs.emplace_back(test::NDArray<float, 3>(
{{{1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}},
{{1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}},
{{1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}})
.get_vector());
// Rank-4 expected output: same 60 ones wrapped in an extra leading dim.
Outputs expected_output{
test::NDArray<float, 4>(
{{{{1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}},
{{1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}},
{{1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}}})
.get_vector()};
Outputs outputs{execute(function, inputs, "INTERPRETER")};
EXPECT_TRUE(test::all_close_f(expected_output.front(), outputs.front()));
}
TEST(onnx, model_div)
{
auto function = ngraph::onnx_import::import_onnx_function(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment