Commit 4a11e81a authored by tsocha's avatar tsocha Committed by Artur Wojcik

[ONNX] Transpose op (#1611)

* [ONNX] Transpose op

* Review fix pt. 1

* Review fix pt. 2
parent 9d1f2367
...@@ -109,6 +109,8 @@ add_library(onnx_import STATIC ...@@ -109,6 +109,8 @@ add_library(onnx_import STATIC
op/tanh.hpp op/tanh.hpp
op/thresholded_relu.cpp op/thresholded_relu.cpp
op/thresholded_relu.hpp op/thresholded_relu.hpp
op/transpose.cpp
op/transpose.hpp
op/unsqueeze.cpp op/unsqueeze.cpp
op/unsqueeze.hpp op/unsqueeze.hpp
op/xor.hpp op/xor.hpp
......
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <memory>
#include <vector>
#include "ngraph/node.hpp"
#include "transpose.hpp"
#include "utils/reshape.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
NodeVector transpose(const Node& node)
{
std::shared_ptr<ngraph::Node> data = node.get_ng_inputs().at(0);
auto permute_axes = node.get_attribute_value<std::vector<std::size_t>>("perm", {});
return {(permute_axes.empty()) ? reshape::transpose(data)
: reshape::reorder_axes(data, permute_axes)};
}
} // namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/node_vector.hpp"
#include "core/node.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
NodeVector transpose(const Node& node);
} // namespace op
} // namespace onnx_import
} // namespace ngraph
...@@ -71,6 +71,7 @@ ...@@ -71,6 +71,7 @@
#include "op/sum.hpp" #include "op/sum.hpp"
#include "op/tanh.hpp" #include "op/tanh.hpp"
#include "op/thresholded_relu.hpp" #include "op/thresholded_relu.hpp"
#include "op/transpose.hpp"
#include "op/unsqueeze.hpp" #include "op/unsqueeze.hpp"
#include "op/xor.hpp" #include "op/xor.hpp"
#include "ops_bridge.hpp" #include "ops_bridge.hpp"
...@@ -186,6 +187,7 @@ namespace ngraph ...@@ -186,6 +187,7 @@ namespace ngraph
m_map.emplace("Tanh", std::bind(op::tanh, std::placeholders::_1)); m_map.emplace("Tanh", std::bind(op::tanh, std::placeholders::_1));
m_map.emplace("ThresholdedRelu", m_map.emplace("ThresholdedRelu",
std::bind(op::thresholded_relu, std::placeholders::_1)); std::bind(op::thresholded_relu, std::placeholders::_1));
m_map.emplace("Transpose", std::bind(op::transpose, std::placeholders::_1));
m_map.emplace("Unsqueeze", std::bind(op::unsqueeze, std::placeholders::_1)); m_map.emplace("Unsqueeze", std::bind(op::unsqueeze, std::placeholders::_1));
m_map.emplace("Xor", std::bind(op::logical_xor, std::placeholders::_1)); m_map.emplace("Xor", std::bind(op::logical_xor, std::placeholders::_1));
} }
......
...@@ -117,7 +117,7 @@ namespace ngraph ...@@ -117,7 +117,7 @@ namespace ngraph
} }
std::shared_ptr<ngraph::Node> reorder_axes(const std::shared_ptr<ngraph::Node>& node, std::shared_ptr<ngraph::Node> reorder_axes(const std::shared_ptr<ngraph::Node>& node,
std::vector<size_t> axes_order = {}) std::vector<std::size_t> axes_order = {})
{ {
Shape out_shape = node->get_shape(); Shape out_shape = node->get_shape();
if (axes_order.empty()) if (axes_order.empty())
......
...@@ -70,7 +70,7 @@ namespace ngraph ...@@ -70,7 +70,7 @@ namespace ngraph
/// ///
/// \return: New node with permuted axes. /// \return: New node with permuted axes.
std::shared_ptr<ngraph::Node> reorder_axes(const std::shared_ptr<ngraph::Node>& node, std::shared_ptr<ngraph::Node> reorder_axes(const std::shared_ptr<ngraph::Node>& node,
std::vector<int> axes_order); std::vector<std::size_t> axes_order);
/// \brief Return transposed tensor (with axes in reversed order). /// \brief Return transposed tensor (with axes in reversed order).
/// ///
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment