Commit c4970542 authored by tsocha's avatar tsocha Committed by Michał Karzyński

[ONNX] Flatten operator (#1516)

parent a17ec605
...@@ -41,6 +41,8 @@ add_library(onnx_import STATIC ...@@ -41,6 +41,8 @@ add_library(onnx_import STATIC
op/conv.cpp op/conv.cpp
op/conv.hpp op/conv.hpp
op/div.hpp op/div.hpp
op/flatten.cpp
op/flatten.hpp
op/gemm.cpp op/gemm.cpp
op/gemm.hpp op/gemm.hpp
op/matmul.hpp op/matmul.hpp
......
//*****************************************************************************
// Copyright 2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "flatten.hpp"
#include "exceptions.hpp"
#include "utils/reshape.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
/// \brief Convert an ONNX Flatten node into an nGraph node vector.
///
/// \param node The ONNX node representing the Flatten operation.
/// \return A vector with a single node: the input flattened to a 2D matrix.
/// \throw error::parameter::Value If the axis attribute is outside [0, rank].
NodeVector flatten(const Node& node)
{
    NodeVector inputs{node.get_ng_inputs()};
    auto data = inputs.at(0);
    // ONNX specifies a default axis of 1 when the attribute is absent.
    auto axis = node.get_attribute_value<int64_t>("axis", 1);
    // Valid axis range is [0, rank]. Cast the rank to int64_t so the
    // comparison is not between signed and unsigned types.
    if (axis < 0 || axis > static_cast<int64_t>(data->get_shape().size()))
    {
        throw error::parameter::Value("Flatten node (",
                                      node.get_name(),
                                      "): provided axis attribute is not valid.");
    }
    return {utils::flatten(data, axis)};
}
} // namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <memory>
#include "ngraph/node_vector.hpp"
#include "core/node.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
/// \brief Convert an ONNX Flatten node into an nGraph node vector.
///
/// \param node The ONNX node representing the Flatten operation.
/// \return A vector with a single node: the input flattened to a 2D matrix.
NodeVector flatten(const Node& node);
} // namespace op
} // namespace onnx_import
} // namespace ngraph
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
#include "op/constant.hpp" #include "op/constant.hpp"
#include "op/conv.hpp" #include "op/conv.hpp"
#include "op/div.hpp" #include "op/div.hpp"
#include "op/flatten.hpp"
#include "op/gemm.hpp" #include "op/gemm.hpp"
#include "op/matmul.hpp" #include "op/matmul.hpp"
#include "op/max.hpp" #include "op/max.hpp"
...@@ -88,6 +89,7 @@ namespace ngraph ...@@ -88,6 +89,7 @@ namespace ngraph
m_map.emplace("Constant", std::bind(op::constant, std::placeholders::_1)); m_map.emplace("Constant", std::bind(op::constant, std::placeholders::_1));
m_map.emplace("Conv", std::bind(op::conv, std::placeholders::_1)); m_map.emplace("Conv", std::bind(op::conv, std::placeholders::_1));
m_map.emplace("Div", std::bind(op::div, std::placeholders::_1)); m_map.emplace("Div", std::bind(op::div, std::placeholders::_1));
m_map.emplace("Flatten", std::bind(op::flatten, std::placeholders::_1));
m_map.emplace("Gemm", std::bind(op::gemm, std::placeholders::_1)); m_map.emplace("Gemm", std::bind(op::gemm, std::placeholders::_1));
m_map.emplace("MatMul", std::bind(op::matmul, std::placeholders::_1)); m_map.emplace("MatMul", std::bind(op::matmul, std::placeholders::_1));
m_map.emplace("MaxPool", std::bind(op::max_pool, std::placeholders::_1)); m_map.emplace("MaxPool", std::bind(op::max_pool, std::placeholders::_1));
......
...@@ -24,6 +24,40 @@ namespace ngraph ...@@ -24,6 +24,40 @@ namespace ngraph
{ {
namespace onnx_import namespace onnx_import
{ {
namespace utils
{
    /// \brief Flatten the input tensor into a 2D matrix.
    ///
    /// The output's first dimension is the product of input dimensions
    /// [d_0, ..., d_{axis-1}]; the second dimension is the product of the
    /// remaining dimensions [d_axis, ..., d_n]. An axis of 0 yields a
    /// first dimension of 1 (empty product).
    ///
    /// \param node The tensor to be flattened.
    /// \param axis The axis dividing the shape; expected in [0, rank].
    /// \return A new node: a 2D Reshape of the flattened input.
    std::shared_ptr<ngraph::Node> flatten(const std::shared_ptr<ngraph::Node>& node,
                                          int axis)
    {
        auto data_shape = node->get_shape();

        // Split the shape at `axis` and take the product of each half.
        // Use a size_t index to avoid a signed/unsigned comparison with
        // data_shape.size().
        size_t first_dim_size = 1;
        size_t last_dim_size = 1;
        for (size_t index = 0; index < data_shape.size(); ++index)
        {
            if (static_cast<int>(index) < axis)
            {
                first_dim_size *= data_shape.at(index);
            }
            else
            {
                last_dim_size *= data_shape.at(index);
            }
        }

        // Identity permutation (0, 1, 2, ...) as input_order for Reshape.
        std::vector<size_t> input_order(data_shape.size());
        std::iota(std::begin(input_order), std::end(input_order), 0);

        return std::make_shared<ngraph::op::Reshape>(
            node,
            ngraph::AxisVector{input_order},
            ngraph::Shape{first_dim_size, last_dim_size});
    }
} // namespace utils
std::shared_ptr<ngraph::Node> reorder_axes(const std::shared_ptr<ngraph::Node>& node, std::shared_ptr<ngraph::Node> reorder_axes(const std::shared_ptr<ngraph::Node>& node,
std::vector<size_t> axes_order = {}) std::vector<size_t> axes_order = {})
{ {
......
...@@ -22,12 +22,24 @@ namespace ngraph ...@@ -22,12 +22,24 @@ namespace ngraph
{ {
namespace onnx_import namespace onnx_import
{ {
namespace utils
{
/// \brief Flatten the input tensor into a 2D matrix.
///
/// \param node The tensor to be flattened.
/// \param axis The axis dividing shape.
///
/// \return The new node being a 2D matrix representing flattened input node.
std::shared_ptr<ngraph::Node> flatten(const std::shared_ptr<ngraph::Node>& node,
int axis);
} // namespace utils
/// \brief Permute axes according to specified axes_order parameter. /// \brief Permute axes according to specified axes_order parameter.
/// ///
/// \param node The node which axes we want to permute. /// \param node The node which axes we want to permute.
/// \param axes_order The permutation of node tensor axes. /// \param axes_order The permutation of node tensor axes.
/// ///
/// \return: New node with permuted axes. /// \return New node with permuted axes.
std::shared_ptr<ngraph::Node> reorder_axes(const std::shared_ptr<ngraph::Node>& node, std::shared_ptr<ngraph::Node> reorder_axes(const std::shared_ptr<ngraph::Node>& node,
std::vector<int> axes_order); std::vector<int> axes_order);
...@@ -35,7 +47,7 @@ namespace ngraph ...@@ -35,7 +47,7 @@ namespace ngraph
/// ///
/// \param node Input tensor we want to transpose /// \param node Input tensor we want to transpose
/// ///
/// \return: New node with reversed dimensions. /// \return New node with reversed dimensions.
std::shared_ptr<ngraph::Node> transpose(const std::shared_ptr<ngraph::Node>& node); std::shared_ptr<ngraph::Node> transpose(const std::shared_ptr<ngraph::Node>& node);
} // namespace onnx_import } // namespace onnx_import
......
...@@ -490,6 +490,22 @@ TEST(onnx, model_softmax) ...@@ -490,6 +490,22 @@ TEST(onnx, model_softmax)
EXPECT_TRUE(test::all_close_f(expected_output, result_vectors.front())); EXPECT_TRUE(test::all_close_f(expected_output, result_vectors.front()));
} }
// Import onnx/flatten.onnx and check that a 1x2x2x2 input produces the
// expected flattened output on the INTERPRETER backend.
TEST(onnx, model_flatten)
{
    auto model = onnx_import::import_onnx_function(
        file_util::path_join(SERIALIZED_ZOO, "onnx/flatten.onnx"));

    Inputs input_data;
    input_data.emplace_back(
        test::NDArray<float, 4>({{{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}}).get_vector());

    Outputs expected{test::NDArray<float, 3>({{{1, 2, 3, 4}, {5, 6, 7, 8}}}).get_vector()};
    Outputs computed{execute(model, input_data, "INTERPRETER")};

    EXPECT_TRUE(test::all_close_f(expected.front(), computed.front()));
}
TEST(onnx, model_sub) TEST(onnx, model_sub)
{ {
auto function = ngraph::onnx_import::import_onnx_function( auto function = ngraph::onnx_import::import_onnx_function(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment