Commit 5c706276 authored by Adam Rogowiec's avatar Adam Rogowiec Committed by Robert Kimball

[ONNX] Reshape operator (#1529)

* Move reshape utils down to reshape namespace.

* Reshape operation.

* Reshape operator binding.

* Error fixes.

* Reshape unit tests.

* Move flatten utility function to reshape namespace.

* Fix unused catched exception object

* Add Constant support for int64

* Review fix.

* clang-format

* Review fix part 2.

* Enable output shape as a second node input (only Constant).

* Unit test for "dynamic" output shape (from Constant node).

* Review fixes.

* Make sure second Reshape op input is Constant node.
parent 04b1434d
...@@ -56,6 +56,8 @@ add_library(onnx_import STATIC ...@@ -56,6 +56,8 @@ add_library(onnx_import STATIC
op/min.hpp op/min.hpp
op/mul.hpp op/mul.hpp
op/relu.hpp op/relu.hpp
op/reshape.cpp
op/reshape.hpp
op/softmax.cpp op/softmax.cpp
op/softmax.hpp op/softmax.hpp
op/split.cpp op/split.cpp
......
...@@ -69,6 +69,13 @@ namespace ngraph ...@@ -69,6 +69,13 @@ namespace ngraph
return __make_ng_constant<int32_t>(element::i32, tensor); return __make_ng_constant<int32_t>(element::i32, tensor);
} }
template <>
inline std::shared_ptr<ngraph::op::Constant>
make_ng_constant<Tensor::Type::int64>(const Tensor& tensor)
{
return __make_ng_constant<int64_t>(element::i64, tensor);
}
template <> template <>
inline std::shared_ptr<ngraph::op::Constant> inline std::shared_ptr<ngraph::op::Constant>
make_ng_constant<Tensor::Type::uint32>(const Tensor& tensor) make_ng_constant<Tensor::Type::uint32>(const Tensor& tensor)
...@@ -94,6 +101,7 @@ namespace ngraph ...@@ -94,6 +101,7 @@ namespace ngraph
MAKE_NG_CONSTANT(Tensor::Type::float32); MAKE_NG_CONSTANT(Tensor::Type::float32);
MAKE_NG_CONSTANT(Tensor::Type::float64); MAKE_NG_CONSTANT(Tensor::Type::float64);
MAKE_NG_CONSTANT(Tensor::Type::int32); MAKE_NG_CONSTANT(Tensor::Type::int32);
MAKE_NG_CONSTANT(Tensor::Type::int64);
MAKE_NG_CONSTANT(Tensor::Type::uint32); MAKE_NG_CONSTANT(Tensor::Type::uint32);
MAKE_NG_CONSTANT(Tensor::Type::uint64); MAKE_NG_CONSTANT(Tensor::Type::uint64);
default: throw error::tensor::invalid_data_type{tensor}; default: throw error::tensor::invalid_data_type{tensor};
......
...@@ -38,7 +38,7 @@ namespace ngraph ...@@ -38,7 +38,7 @@ namespace ngraph
"): provided axis attribute is not valid."); "): provided axis attribute is not valid.");
} }
return {utils::flatten(data, axis)}; return {reshape::flatten(data, axis)};
} }
} // namespace op } // namespace op
......
...@@ -47,11 +47,11 @@ namespace ngraph ...@@ -47,11 +47,11 @@ namespace ngraph
if (trans_a != 0) if (trans_a != 0)
{ {
input_a = transpose(input_a); input_a = reshape::transpose(input_a);
} }
if (trans_b != 0) if (trans_b != 0)
{ {
input_b = transpose(input_b); input_b = reshape::transpose(input_b);
} }
// code from python not implemented in c++ yet. // code from python not implemented in c++ yet.
......
//*****************************************************************************
// Copyright 2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <cstddef>
#include <memory>
#include <vector>
#include "ngraph/axis_vector.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/shape.hpp"
#include "exceptions.hpp"
#include "reshape.hpp"
#include "utils/reshape.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
NodeVector reshape(const Node& node)
{
NodeVector ng_inputs{node.get_ng_inputs()};
auto data = ng_inputs.at(0);
auto data_shape = data->get_shape();
auto output_shape = node.get_attribute_value<std::vector<std::size_t>>("shape", {});
// If no shape argument (opset >= 5) and there is second input.
if (output_shape.empty() && ng_inputs.size() == 2)
{
// Currently only support Constant node.
if (ng_inputs.at(1)->description() == "Constant")
{
auto output_shape_node =
std::dynamic_pointer_cast<ngraph::op::Constant>(ng_inputs.at(1));
output_shape = output_shape_node->get_vector<std::size_t>();
}
else
{
throw error::NotSupported("Reshape",
node.get_name(),
"doesn't support "
"shape input of other type than Constant.");
}
}
// Do nothing if there is no shape argument nor second node input.
else if (output_shape.empty())
{
return {data};
}
output_shape = reshape::infer_dimensions(node.get_name(), data_shape, output_shape);
return {std::make_shared<ngraph::op::Reshape>(
data,
reshape::get_default_axis_vector(data_shape.size()),
Shape{output_shape})};
}
} // namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/node_vector.hpp"
#include "core/node.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
///
/// \brief Reshape the input tensor similar to numpy.reshape.
///
/// \param[in] node The ONNX node representing this operation.
///
/// \return Ngraph node representing this operation.
///
NodeVector reshape(const Node& node);
} // namespace op
} // namespace onnx_import
} // namespace ngraph
...@@ -34,6 +34,7 @@ ...@@ -34,6 +34,7 @@
#include "op/min.hpp" #include "op/min.hpp"
#include "op/mul.hpp" #include "op/mul.hpp"
#include "op/relu.hpp" #include "op/relu.hpp"
#include "op/reshape.hpp"
#include "op/softmax.hpp" #include "op/softmax.hpp"
#include "op/split.hpp" #include "op/split.hpp"
#include "op/sub.hpp" #include "op/sub.hpp"
...@@ -101,6 +102,7 @@ namespace ngraph ...@@ -101,6 +102,7 @@ namespace ngraph
m_map.emplace("Min", std::bind(op::min, std::placeholders::_1)); m_map.emplace("Min", std::bind(op::min, std::placeholders::_1));
m_map.emplace("Mul", std::bind(op::mul, std::placeholders::_1)); m_map.emplace("Mul", std::bind(op::mul, std::placeholders::_1));
m_map.emplace("Relu", std::bind(op::relu, std::placeholders::_1)); m_map.emplace("Relu", std::bind(op::relu, std::placeholders::_1));
m_map.emplace("Reshape", std::bind(op::reshape, std::placeholders::_1));
m_map.emplace("Softmax", std::bind(op::softmax, std::placeholders::_1)); m_map.emplace("Softmax", std::bind(op::softmax, std::placeholders::_1));
m_map.emplace("Split", std::bind(op::split, std::placeholders::_1)); m_map.emplace("Split", std::bind(op::split, std::placeholders::_1));
m_map.emplace("Sub", std::bind(op::sub, std::placeholders::_1)); m_map.emplace("Sub", std::bind(op::sub, std::placeholders::_1));
......
...@@ -14,17 +14,25 @@ ...@@ -14,17 +14,25 @@
// limitations under the License. // limitations under the License.
//***************************************************************************** //*****************************************************************************
#include <algorithm>
#include <cstddef>
#include <functional>
#include <iterator>
#include <numeric> #include <numeric>
#include <stdexcept>
#include "ngraph/axis_vector.hpp"
#include "ngraph/op/reshape.hpp" #include "ngraph/op/reshape.hpp"
#include "ngraph/shape.hpp"
#include "exceptions.hpp"
#include "utils/reshape.hpp" #include "utils/reshape.hpp"
namespace ngraph namespace ngraph
{ {
namespace onnx_import namespace onnx_import
{ {
namespace utils namespace reshape
{ {
std::shared_ptr<ngraph::Node> flatten(const std::shared_ptr<ngraph::Node>& node, std::shared_ptr<ngraph::Node> flatten(const std::shared_ptr<ngraph::Node>& node,
int axis) int axis)
...@@ -52,16 +60,82 @@ namespace ngraph ...@@ -52,16 +60,82 @@ namespace ngraph
std::iota(std::begin(input_order), std::end(input_order), 0); std::iota(std::begin(input_order), std::end(input_order), 0);
return std::make_shared<ngraph::op::Reshape>( return std::make_shared<ngraph::op::Reshape>(
node, node, AxisVector{input_order}, Shape{first_dim_size, last_dim_size});
ngraph::AxisVector{input_order}, }
ngraph::Shape{first_dim_size, last_dim_size});
AxisVector get_default_axis_vector(std::size_t data_shape_size, std::size_t start_value)
{
AxisVector axis_vector(data_shape_size);
std::iota(std::begin(axis_vector), std::end(axis_vector), start_value);
return axis_vector;
}
std::vector<std::size_t> infer_dimensions(const std::string& node_name,
const std::vector<std::size_t>& input_shape,
const std::vector<std::size_t>& output_shape)
{
std::vector<std::size_t> inferred_dims{output_shape};
// If an output dimension is equal to zero its actual value is copied from the input
// shape argument.
for (std::size_t idx = 0; idx < inferred_dims.size(); ++idx)
{
if (inferred_dims.at(idx) == 0)
{
if (idx < input_shape.size())
{
inferred_dims.at(idx) = input_shape.at(idx);
}
else
{
throw error::parameter::Value(
"Reshape",
node_name,
"can not copy dimension from the input data shape since requested "
"index is out of range.");
}
}
}
// Check whether there are dimensions equal to -1 in output_shape. There may be at most
// one such case. Its value is then inferred from the size of the tensor and the
// remaining dimensions.
auto neg_value_it =
std::find(std::begin(inferred_dims), std::end(inferred_dims), -1);
if (neg_value_it != std::end(inferred_dims))
{
// only single '-1' value is allowed
if (std::find(std::next(neg_value_it), std::end(inferred_dims), -1) !=
std::end(inferred_dims))
{
throw error::parameter::Value("Reshape",
node_name,
"more than one dimension is set to (-1). "
"Only one dimension value can be inferred.");
}
// Set dimension value to 1 temporarily to be able to calculate its value.
*neg_value_it = 1;
std::size_t input_shape_product =
std::accumulate(std::begin(input_shape),
std::end(input_shape),
1UL,
std::multiplies<std::size_t>());
std::size_t output_shape_product =
std::accumulate(std::begin(inferred_dims),
std::end(inferred_dims),
1UL,
std::multiplies<std::size_t>());
*neg_value_it = input_shape_product / output_shape_product;
}
return inferred_dims;
} }
} // namespace utils
std::shared_ptr<ngraph::Node> reorder_axes(const std::shared_ptr<ngraph::Node>& node, std::shared_ptr<ngraph::Node> reorder_axes(const std::shared_ptr<ngraph::Node>& node,
std::vector<size_t> axes_order = {}) std::vector<size_t> axes_order = {})
{ {
ngraph::Shape out_shape = node->get_shape(); Shape out_shape = node->get_shape();
if (axes_order.empty()) if (axes_order.empty())
{ {
axes_order.resize(out_shape.size()); axes_order.resize(out_shape.size());
...@@ -69,13 +143,13 @@ namespace ngraph ...@@ -69,13 +143,13 @@ namespace ngraph
} }
else else
{ {
for (auto i = 0; i < axes_order.size(); ++i) for (std::size_t i = 0; i < axes_order.size(); ++i)
{ {
out_shape[i] = node->get_shape().at(axes_order.at(i)); out_shape[i] = node->get_shape().at(axes_order.at(i));
} }
} }
auto axis_vector = ngraph::AxisVector{axes_order.begin(), axes_order.end()}; auto axis_vector = AxisVector{std::begin(axes_order), std::end(axes_order)};
return std::make_shared<ngraph::op::Reshape>(node, axis_vector, out_shape); return std::make_shared<ngraph::op::Reshape>(node, axis_vector, out_shape);
} }
...@@ -86,6 +160,7 @@ namespace ngraph ...@@ -86,6 +160,7 @@ namespace ngraph
std::reverse(std::begin(axes_order), std::end(axes_order)); std::reverse(std::begin(axes_order), std::end(axes_order));
return reorder_axes(node, axes_order); return reorder_axes(node, axes_order);
} }
} // namespace onnx_import
} // namespace reshape
} // namespace onnx_import
} // namespace ngraph } // namespace ngraph
...@@ -16,13 +16,14 @@ ...@@ -16,13 +16,14 @@
#pragma once #pragma once
#include "ngraph/axis_vector.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
namespace ngraph namespace ngraph
{ {
namespace onnx_import namespace onnx_import
{ {
namespace utils namespace reshape
{ {
/// \brief Flatten the input tensor into a 2D matrix. /// \brief Flatten the input tensor into a 2D matrix.
/// ///
...@@ -32,14 +33,42 @@ namespace ngraph ...@@ -32,14 +33,42 @@ namespace ngraph
/// \return The new node being a 2D matrix representing flattened input node. /// \return The new node being a 2D matrix representing flattened input node.
std::shared_ptr<ngraph::Node> flatten(const std::shared_ptr<ngraph::Node>& node, std::shared_ptr<ngraph::Node> flatten(const std::shared_ptr<ngraph::Node>& node,
int axis); int axis);
} // namespace utils
/// \brief Gets the AxisVector filled with monotonic increasing
/// sequence.
///
/// \param[in] data_shape_size The data shape size.
/// \param[in] start_value The start_value for sequence. Default equals 0.
///
/// \return The filled AxisVector.
///
AxisVector get_default_axis_vector(std::size_t data_shape_size,
std::size_t start_value = 0);
/// \brief Infer `output_shape` dimension values.
///
/// \par Inferention rules
/// \li The input_shape may consist at most on -1 value. In this case the value
/// is inferred from the size of the tensor and the remaining dimensions.
/// \li If a dimension value is equal to 0, then its output value is going to
/// be copied from the input_shape argument.
///
/// \param[in] node_name The node name.
/// \param[in] input_shape The input node shape.
/// \param[in] output_shape The requested output shape for the input node data.
///
/// \return A vector containig new, valid node shape.
///
std::vector<std::size_t> infer_dimensions(const std::string& node_name,
const std::vector<std::size_t>& input_shape,
const std::vector<std::size_t>& output_shape);
/// \brief Permute axes according to specified axes_order parameter. /// \brief Permute axes according to specified axes_order parameter.
/// ///
/// \param node The node which axes we want to permute. /// \param node The node which axes we want to permute.
/// \param axes_order The permutation of node tensor axes. /// \param axes_order The permutation of node tensor axes.
/// ///
/// \return New node with permuted axes. /// \return: New node with permuted axes.
std::shared_ptr<ngraph::Node> reorder_axes(const std::shared_ptr<ngraph::Node>& node, std::shared_ptr<ngraph::Node> reorder_axes(const std::shared_ptr<ngraph::Node>& node,
std::vector<int> axes_order); std::vector<int> axes_order);
...@@ -47,8 +76,9 @@ namespace ngraph ...@@ -47,8 +76,9 @@ namespace ngraph
/// ///
/// \param node Input tensor we want to transpose /// \param node Input tensor we want to transpose
/// ///
/// \return New node with reversed dimensions. /// \return: New node with reversed dimensions.
std::shared_ptr<ngraph::Node> transpose(const std::shared_ptr<ngraph::Node>& node); std::shared_ptr<ngraph::Node> transpose(const std::shared_ptr<ngraph::Node>& node);
} // namespace onnx_import
} // namespace reshape
} // namespace onnx_import
} // namespace ngraph } // namespace ngraph
ONNXNgraphImporter:j
#
AB"Reshape*
shape@@@@ compute_graphZ
A



b
B




B
\ No newline at end of file
ONNXNgraphImporter:d
!
AB"Reshape*
shape@@@ compute_graphZ
A



b
B



B
\ No newline at end of file
ONNXNgraphImporter:^

AB"Reshape*
shape@@  compute_graphZ
A



b
B


 B
\ No newline at end of file
ONNXNgraphImporter:d
!
AB"Reshape*
shape@@@ compute_graphZ
A



b
B



B
\ No newline at end of file
ONNXNgraphImporter:X

AB"Reshape*
shape@ compute_graphZ
A



b
B

B
\ No newline at end of file
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment