Commit 5c706276 authored by Adam Rogowiec, committed by Robert Kimball

[ONNX] Reshape operator (#1529)

* Move reshape utils down to reshape namespace.

* Reshape operation.

* Reshape operator binding.

* Error fixes.

* Reshape unit tests.

* Move flatten utility function to reshape namespace.

* Fix unused caught exception object

* Add Constant support for int64

* Review fix.

* clang-format

* Review fix part 2.

* Enable output shape as a second node input (only Constant).

* Unit test for "dynamic" output shape (from Constant node).

* Review fixes.

* Make sure the second Reshape op input is a Constant node.
parent 04b1434d
......@@ -56,6 +56,8 @@ add_library(onnx_import STATIC
op/min.hpp
op/mul.hpp
op/relu.hpp
op/reshape.cpp
op/reshape.hpp
op/softmax.cpp
op/softmax.hpp
op/split.cpp
......
......@@ -69,6 +69,13 @@ namespace ngraph
return __make_ng_constant<int32_t>(element::i32, tensor);
}
template <>
inline std::shared_ptr<ngraph::op::Constant>
make_ng_constant<Tensor::Type::int64>(const Tensor& tensor)
{
return __make_ng_constant<int64_t>(element::i64, tensor);
}
template <>
inline std::shared_ptr<ngraph::op::Constant>
make_ng_constant<Tensor::Type::uint32>(const Tensor& tensor)
......@@ -94,6 +101,7 @@ namespace ngraph
MAKE_NG_CONSTANT(Tensor::Type::float32);
MAKE_NG_CONSTANT(Tensor::Type::float64);
MAKE_NG_CONSTANT(Tensor::Type::int32);
MAKE_NG_CONSTANT(Tensor::Type::int64);
MAKE_NG_CONSTANT(Tensor::Type::uint32);
MAKE_NG_CONSTANT(Tensor::Type::uint64);
default: throw error::tensor::invalid_data_type{tensor};
......
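The int64 addition follows the file's existing pattern: one template specialization of make_ng_constant per Tensor::Type, forwarding to a typed helper, plus the MAKE_NG_CONSTANT macro expanding to switch cases. A minimal standalone sketch of the same pattern, with simplified stand-in names (not the real onnx_import types):

```cpp
#include <cstdint>
#include <iostream>
#include <stdexcept>

// Stand-in tags; the real code builds an ngraph::op::Constant from tensor data.
enum class TensorType { int32, int64, uint32 };

template <typename T>
void make_constant_impl()
{
    std::cout << "constant of " << sizeof(T) * 8 << "-bit elements\n";
}

// One explicit specialization per runtime tensor type tag...
template <TensorType> void make_constant();
template <> void make_constant<TensorType::int32>() { make_constant_impl<std::int32_t>(); }
template <> void make_constant<TensorType::int64>() { make_constant_impl<std::int64_t>(); }
template <> void make_constant<TensorType::uint32>() { make_constant_impl<std::uint32_t>(); }

// ...and a macro expanding to switch cases, mirroring MAKE_NG_CONSTANT.
#define MAKE_CONSTANT(tag) case tag: return make_constant<tag>()

void dispatch(TensorType type)
{
    switch (type)
    {
        MAKE_CONSTANT(TensorType::int32);
        MAKE_CONSTANT(TensorType::int64);
        MAKE_CONSTANT(TensorType::uint32);
        default: throw std::runtime_error{"unsupported tensor type"};
    }
}

int main()
{
    dispatch(TensorType::int64); // prints: constant of 64-bit elements
}
```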
......@@ -38,7 +38,7 @@ namespace ngraph
"): provided axis attribute is not valid.");
}
return {utils::flatten(data, axis)};
return {reshape::flatten(data, axis)};
}
} // namespace op
......
......@@ -47,11 +47,11 @@ namespace ngraph
if (trans_a != 0)
{
input_a = transpose(input_a);
input_a = reshape::transpose(input_a);
}
if (trans_b != 0)
{
input_b = transpose(input_b);
input_b = reshape::transpose(input_b);
}
// code from Python not implemented in C++ yet.
......
//*****************************************************************************
// Copyright 2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <cstddef>
#include <memory>
#include <vector>
#include "ngraph/axis_vector.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/shape.hpp"
#include "exceptions.hpp"
#include "reshape.hpp"
#include "utils/reshape.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
NodeVector reshape(const Node& node)
{
NodeVector ng_inputs{node.get_ng_inputs()};
auto data = ng_inputs.at(0);
auto data_shape = data->get_shape();
auto output_shape = node.get_attribute_value<std::vector<std::size_t>>("shape", {});
// If there is no shape attribute (opset >= 5) and there is a second input.
if (output_shape.empty() && ng_inputs.size() == 2)
{
// Currently only a Constant node is supported.
if (ng_inputs.at(1)->description() == "Constant")
{
auto output_shape_node =
std::dynamic_pointer_cast<ngraph::op::Constant>(ng_inputs.at(1));
output_shape = output_shape_node->get_vector<std::size_t>();
}
else
{
throw error::NotSupported("Reshape",
node.get_name(),
"doesn't support "
"shape input of other type than Constant.");
}
}
// Do nothing if there is neither a shape argument nor a second node input.
else if (output_shape.empty())
{
return {data};
}
output_shape = reshape::infer_dimensions(node.get_name(), data_shape, output_shape);
return {std::make_shared<ngraph::op::Reshape>(
data,
reshape::get_default_axis_vector(data_shape.size()),
Shape{output_shape})};
}
} // namespace op
} // namespace onnx_import
} // namespace ngraph
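From ONNX opset 5 onward the target shape is supplied as a second input tensor rather than a shape attribute; nGraph's Reshape op needs a static output shape at graph-construction time, which is why the converter above accepts that input only when it is a Constant. A compact standalone sketch of this resolution order, with simplified types and a hypothetical helper name (not the importer's real API):

```cpp
#include <cstddef>
#include <stdexcept>
#include <vector>

using Shape = std::vector<std::size_t>;

// Hypothetical, simplified mirror of the converter's decision flow above:
// prefer the opset 1-4 "shape" attribute; otherwise require a second input
// whose values are known at import time (a Constant); otherwise pass through.
Shape resolve_target_shape(const Shape& shape_attribute,
                           bool has_second_input,
                           bool second_input_is_constant,
                           const Shape& second_input_values,
                           const Shape& data_shape)
{
    if (!shape_attribute.empty())
    {
        return shape_attribute; // opset 1-4: shape given as an attribute
    }
    if (has_second_input)
    {
        if (!second_input_is_constant)
        {
            // Run-time-computed shapes cannot be honoured at graph-build time.
            throw std::runtime_error{"Reshape: only Constant shape input supported"};
        }
        return second_input_values; // opset 5+: shape given as a Constant input
    }
    return data_shape; // neither attribute nor input: reshape is a no-op
}
```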
//*****************************************************************************
// Copyright 2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/node_vector.hpp"
#include "core/node.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
///
/// \brief Reshape the input tensor similar to numpy.reshape.
///
/// \param[in] node The ONNX node representing this operation.
///
/// \return nGraph node representing this operation.
///
NodeVector reshape(const Node& node);
} // namespace op
} // namespace onnx_import
} // namespace ngraph
......@@ -34,6 +34,7 @@
#include "op/min.hpp"
#include "op/mul.hpp"
#include "op/relu.hpp"
#include "op/reshape.hpp"
#include "op/softmax.hpp"
#include "op/split.hpp"
#include "op/sub.hpp"
......@@ -101,6 +102,7 @@ namespace ngraph
m_map.emplace("Min", std::bind(op::min, std::placeholders::_1));
m_map.emplace("Mul", std::bind(op::mul, std::placeholders::_1));
m_map.emplace("Relu", std::bind(op::relu, std::placeholders::_1));
m_map.emplace("Reshape", std::bind(op::reshape, std::placeholders::_1));
m_map.emplace("Softmax", std::bind(op::softmax, std::placeholders::_1));
m_map.emplace("Split", std::bind(op::split, std::placeholders::_1));
m_map.emplace("Sub", std::bind(op::sub, std::placeholders::_1));
......
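Registration follows the bridge's existing convention: a map from the ONNX operator name to a converter with the uniform signature NodeVector(const Node&). A small self-contained sketch of the same dispatch idea (stand-in types; std::bind as in the file above, though a lambda is equivalent):

```cpp
#include <functional>
#include <iostream>
#include <map>
#include <string>
#include <vector>

struct Node {};                      // stand-in for onnx_import::Node
using NodeVector = std::vector<int>; // stand-in for the real NodeVector

NodeVector reshape(const Node&) { return {1}; }
NodeVector relu(const Node&) { return {2}; }

int main()
{
    // Dispatch table keyed by ONNX operator name, as in ops_bridge.
    std::map<std::string, std::function<NodeVector(const Node&)>> op_map;
    op_map.emplace("Reshape", std::bind(reshape, std::placeholders::_1));
    // An equivalent registration using a lambda instead of std::bind:
    op_map.emplace("Relu", [](const Node& node) { return relu(node); });

    Node node;
    std::cout << op_map.at("Reshape")(node).front() << '\n'; // prints 1
}
```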
......@@ -14,17 +14,25 @@
// limitations under the License.
//*****************************************************************************
#include <algorithm>
#include <cstddef>
#include <functional>
#include <iterator>
#include <numeric>
#include <stdexcept>
#include "ngraph/axis_vector.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/shape.hpp"
#include "exceptions.hpp"
#include "utils/reshape.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace utils
namespace reshape
{
std::shared_ptr<ngraph::Node> flatten(const std::shared_ptr<ngraph::Node>& node,
int axis)
......@@ -52,16 +60,82 @@ namespace ngraph
std::iota(std::begin(input_order), std::end(input_order), 0);
return std::make_shared<ngraph::op::Reshape>(
node,
ngraph::AxisVector{input_order},
ngraph::Shape{first_dim_size, last_dim_size});
node, AxisVector{input_order}, Shape{first_dim_size, last_dim_size});
}
AxisVector get_default_axis_vector(std::size_t data_shape_size, std::size_t start_value)
{
AxisVector axis_vector(data_shape_size);
std::iota(std::begin(axis_vector), std::end(axis_vector), start_value);
return axis_vector;
}
std::vector<std::size_t> infer_dimensions(const std::string& node_name,
const std::vector<std::size_t>& input_shape,
const std::vector<std::size_t>& output_shape)
{
std::vector<std::size_t> inferred_dims{output_shape};
// If an output dimension is equal to zero, its actual value is copied from the
// input shape argument.
for (std::size_t idx = 0; idx < inferred_dims.size(); ++idx)
{
if (inferred_dims.at(idx) == 0)
{
if (idx < input_shape.size())
{
inferred_dims.at(idx) = input_shape.at(idx);
}
else
{
throw error::parameter::Value(
"Reshape",
node_name,
"can not copy dimension from the input data shape since requested "
"index is out of range.");
}
}
}
// Check whether there are dimensions equal to -1 in output_shape. There may be at most
// one such case. Its value is then inferred from the size of the tensor and the
// remaining dimensions.
auto neg_value_it =
std::find(std::begin(inferred_dims), std::end(inferred_dims), -1);
if (neg_value_it != std::end(inferred_dims))
{
// only a single '-1' value is allowed
if (std::find(std::next(neg_value_it), std::end(inferred_dims), -1) !=
std::end(inferred_dims))
{
throw error::parameter::Value("Reshape",
node_name,
"more than one dimension is set to (-1). "
"Only one dimension value can be inferred.");
}
// Temporarily set this dimension to 1 so the product of the remaining dimensions can be computed.
*neg_value_it = 1;
std::size_t input_shape_product =
std::accumulate(std::begin(input_shape),
std::end(input_shape),
1UL,
std::multiplies<std::size_t>());
std::size_t output_shape_product =
std::accumulate(std::begin(inferred_dims),
std::end(inferred_dims),
1UL,
std::multiplies<std::size_t>());
*neg_value_it = input_shape_product / output_shape_product;
}
return inferred_dims;
}
} // namespace utils
std::shared_ptr<ngraph::Node> reorder_axes(const std::shared_ptr<ngraph::Node>& node,
std::vector<size_t> axes_order = {})
{
ngraph::Shape out_shape = node->get_shape();
Shape out_shape = node->get_shape();
if (axes_order.empty())
{
axes_order.resize(out_shape.size());
......@@ -69,13 +143,13 @@ namespace ngraph
}
else
{
for (auto i = 0; i < axes_order.size(); ++i)
for (std::size_t i = 0; i < axes_order.size(); ++i)
{
out_shape[i] = node->get_shape().at(axes_order.at(i));
}
}
auto axis_vector = ngraph::AxisVector{axes_order.begin(), axes_order.end()};
auto axis_vector = AxisVector{std::begin(axes_order), std::end(axes_order)};
return std::make_shared<ngraph::op::Reshape>(node, axis_vector, out_shape);
}
......@@ -86,6 +160,7 @@ namespace ngraph
std::reverse(std::begin(axes_order), std::end(axes_order));
return reorder_axes(node, axes_order);
}
} // namespace onnx_import
} // namespace reshape
} // namespace onnx_import
} // namespace ngraph
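A worked example of the two rules infer_dimensions implements: with input shape (2, 3, 4) (24 elements) and requested shape (0, -1, 2), the 0 copies dimension 0 from the input (2), and the -1 is inferred as 24 / (2 * 2) = 6, giving (2, 6, 2) — matching the reshape_negative_with_zero_dim test later in this diff. A self-contained sketch mirroring that logic, with a hypothetical helper name:

```cpp
#include <cstddef>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

// Simplified, hypothetical mirror of reshape::infer_dimensions; the target
// shape is signed here so the -1 marker is explicit (the real helper keeps
// it in std::size_t and relies on the comparison converting -1).
std::vector<std::size_t> infer(const std::vector<std::size_t>& input,
                               std::vector<long long> target)
{
    const std::size_t input_elems = std::accumulate(
        input.begin(), input.end(), std::size_t{1}, std::multiplies<std::size_t>());

    long long known_product = 1;
    std::size_t neg_index = target.size(); // sentinel: no -1 seen
    for (std::size_t i = 0; i < target.size(); ++i)
    {
        if (target[i] == 0)
        {
            target[i] = static_cast<long long>(input.at(i)); // 0: copy input dim
        }
        if (target[i] == -1)
        {
            neg_index = i; // -1: infer after the loop
        }
        else
        {
            known_product *= target[i];
        }
    }
    if (neg_index != target.size())
    {
        target[neg_index] = static_cast<long long>(input_elems) / known_product;
    }
    return {target.begin(), target.end()};
}

int main()
{
    // Input (2, 3, 4) has 24 elements; requested (0, -1, 2) -> (2, 6, 2).
    for (std::size_t dim : infer({2, 3, 4}, {0, -1, 2}))
    {
        std::cout << dim << ' '; // prints: 2 6 2
    }
}
```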
......@@ -16,13 +16,14 @@
#pragma once
#include "ngraph/axis_vector.hpp"
#include "ngraph/node.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace utils
namespace reshape
{
/// \brief Flatten the input tensor into a 2D matrix.
///
......@@ -32,14 +33,42 @@ namespace ngraph
/// \return The new node being a 2D matrix representing flattened input node.
std::shared_ptr<ngraph::Node> flatten(const std::shared_ptr<ngraph::Node>& node,
int axis);
} // namespace utils
/// \brief Gets an AxisVector filled with a monotonically increasing
/// sequence.
///
/// \param[in] data_shape_size The data shape size.
/// \param[in] start_value The starting value of the sequence. Defaults to 0.
///
/// \return The filled AxisVector.
///
AxisVector get_default_axis_vector(std::size_t data_shape_size,
std::size_t start_value = 0);
/// \brief Infer `output_shape` dimension values.
///
/// \par Inference rules
/// \li The output_shape may contain at most one -1 value. In this case the value
/// is inferred from the size of the tensor and the remaining dimensions.
/// \li If a dimension value is equal to 0, its output value is copied
/// from the input_shape argument.
///
/// \param[in] node_name The node name.
/// \param[in] input_shape The input node shape.
/// \param[in] output_shape The requested output shape for the input node data.
///
/// \return A vector containing the new, valid node shape.
///
std::vector<std::size_t> infer_dimensions(const std::string& node_name,
const std::vector<std::size_t>& input_shape,
const std::vector<std::size_t>& output_shape);
/// \brief Permute axes according to specified axes_order parameter.
///
/// \param node The node whose axes we want to permute.
/// \param axes_order The permutation of node tensor axes.
///
/// \return New node with permuted axes.
/// \return: New node with permuted axes.
std::shared_ptr<ngraph::Node> reorder_axes(const std::shared_ptr<ngraph::Node>& node,
std::vector<int> axes_order);
......@@ -47,8 +76,9 @@ namespace ngraph
///
/// \param node Input tensor we want to transpose.
///
/// \return New node with reversed dimensions.
/// \return: New node with reversed dimensions.
std::shared_ptr<ngraph::Node> transpose(const std::shared_ptr<ngraph::Node>& node);
} // namespace onnx_import
} // namespace reshape
} // namespace onnx_import
} // namespace ngraph
[Binary files: serialized ONNX protobuf test models for the Reshape operator (graph "compute_graph", a single Reshape node A -> B with varying shape specifications); raw content is not human-readable and is omitted here.]
......@@ -14,6 +14,7 @@
// limitations under the License.
//*****************************************************************************
#include <cstdint>
#include <fstream>
#include <sstream>
......@@ -196,8 +197,8 @@ TEST(onnx, model_conv2d_strides_assymetric_padding)
TEST(onnx, model_average_pool_2d)
{
// Pooling with strides=2 and no padding
auto model = ngraph::onnx_import::import_onnx_function(
ngraph::file_util::path_join(SERIALIZED_ZOO, "onnx/average_pool_2d.onnx"));
auto model = onnx_import::import_onnx_function(
file_util::path_join(SERIALIZED_ZOO, "onnx/average_pool_2d.onnx"));
// input data shape (1, 1, 4, 4)
Inputs inputs;
......@@ -218,8 +219,8 @@ TEST(onnx, model_average_pool_2d)
TEST(onnx, model_average_pool_2d_pads)
{
// Pooling with strides=2 and padding=1
auto model = ngraph::onnx_import::import_onnx_function(
ngraph::file_util::path_join(SERIALIZED_ZOO, "onnx/average_pool_2d_pads.onnx"));
auto model = onnx_import::import_onnx_function(
file_util::path_join(SERIALIZED_ZOO, "onnx/average_pool_2d_pads.onnx"));
// input data shape (1, 1, 4, 4)
Inputs inputs;
......@@ -242,8 +243,8 @@ TEST(onnx, model_average_pool_2d_pads)
TEST(onnx, model_max_pool_2d_pads)
{
// Pooling with strides=2 and padding=1
auto model = ngraph::onnx_import::import_onnx_function(
ngraph::file_util::path_join(SERIALIZED_ZOO, "onnx/max_pool_2d_pads.onnx"));
auto model = onnx_import::import_onnx_function(
file_util::path_join(SERIALIZED_ZOO, "onnx/max_pool_2d_pads.onnx"));
// input data shape (1, 1, 4, 4)
Inputs inputs;
......@@ -309,8 +310,8 @@ TEST(onnx, model_relu)
TEST(onnx, model_sum)
{
// Simple Sum test
auto function = ngraph::onnx_import::import_onnx_function(
ngraph::file_util::path_join(SERIALIZED_ZOO, "onnx/sum.onnx"));
auto function =
onnx_import::import_onnx_function(file_util::path_join(SERIALIZED_ZOO, "onnx/sum.onnx"));
// input data shape (3, )
Inputs inputs;
......@@ -325,8 +326,8 @@ TEST(onnx, model_sum)
TEST(onnx, model_sum_one_input)
{
auto function = ngraph::onnx_import::import_onnx_function(
ngraph::file_util::path_join(SERIALIZED_ZOO, "onnx/sum_one_input.onnx"));
auto function = onnx_import::import_onnx_function(
file_util::path_join(SERIALIZED_ZOO, "onnx/sum_one_input.onnx"));
// input data shape (3, )
Inputs inputs{{3.f, 0.f, 2.f}};
......@@ -337,8 +338,8 @@ TEST(onnx, model_sum_one_input)
TEST(onnx, model_min_two_inputs)
{
auto function = ngraph::onnx_import::import_onnx_function(
ngraph::file_util::path_join(SERIALIZED_ZOO, "onnx/min_two_inputs.onnx"));
auto function = onnx_import::import_onnx_function(
file_util::path_join(SERIALIZED_ZOO, "onnx/min_two_inputs.onnx"));
// input data shape (3, )
Inputs inputs;
......@@ -352,8 +353,8 @@ TEST(onnx, model_min_two_inputs)
TEST(onnx, model_max)
{
auto function = ngraph::onnx_import::import_onnx_function(
ngraph::file_util::path_join(SERIALIZED_ZOO, "onnx/max.onnx"));
auto function =
onnx_import::import_onnx_function(file_util::path_join(SERIALIZED_ZOO, "onnx/max.onnx"));
// input data shape (3, )
Inputs inputs;
......@@ -368,8 +369,8 @@ TEST(onnx, model_max)
TEST(onnx, model_mean)
{
auto function = ngraph::onnx_import::import_onnx_function(
ngraph::file_util::path_join(SERIALIZED_ZOO, "onnx/mean.onnx"));
auto function =
onnx_import::import_onnx_function(file_util::path_join(SERIALIZED_ZOO, "onnx/mean.onnx"));
// input data shape (3, )
Inputs inputs;
......@@ -435,12 +436,12 @@ TEST(onnx, model_matmul)
TEST(onnx, model_softmax)
{
auto function = ngraph::onnx_import::import_onnx_function(
ngraph::file_util::path_join(SERIALIZED_ZOO, "onnx/softmax.onnx"));
auto function = onnx_import::import_onnx_function(
file_util::path_join(SERIALIZED_ZOO, "onnx/softmax.onnx"));
Inputs inputs;
inputs.emplace_back(
ngraph::test::NDArray<float, 3>(
test::NDArray<float, 3>(
{{{1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}, {11, 12, 13, 14, 15}, {16, 17, 18, 19, 20}},
{{21, 22, 23, 24, 25},
......@@ -455,7 +456,7 @@ TEST(onnx, model_softmax)
.get_vector());
auto expected_output =
ngraph::test::NDArray<float, 3>(
test::NDArray<float, 3>(
{{{1.50461533e-26f, 4.08996852e-26f, 1.11176871e-25f, 3.02210068e-25f, 8.21492137e-25f},
{2.23304715e-24f, 6.07005148e-24f, 1.65001106e-23f, 4.48519509e-23f, 1.21920243e-22f},
{3.31413582e-22f, 9.00875516e-22f, 2.44883355e-21f, 6.65661973e-21f, 1.80945684e-20f},
......@@ -522,15 +523,15 @@ TEST(onnx, model_flatten)
TEST(onnx, model_sub)
{
auto function = ngraph::onnx_import::import_onnx_function(
ngraph::file_util::path_join(SERIALIZED_ZOO, "onnx/sub.onnx"));
auto function =
onnx_import::import_onnx_function(file_util::path_join(SERIALIZED_ZOO, "onnx/sub.onnx"));
Inputs inputs;
inputs.emplace_back(ngraph::test::NDArray<float, 3>({{{1, 2, 3}}}).get_vector());
inputs.emplace_back(test::NDArray<float, 3>({{{1, 2, 3}}}).get_vector());
inputs.emplace_back(ngraph::test::NDArray<float, 3>({{{4, 5, 7}}}).get_vector());
inputs.emplace_back(test::NDArray<float, 3>({{{4, 5, 7}}}).get_vector());
auto expected_output = ngraph::test::NDArray<float, 3>({{{-3, -3, -4}}}).get_vector();
auto expected_output = test::NDArray<float, 3>({{{-3, -3, -4}}}).get_vector();
auto result_vectors = execute(function, inputs, "INTERPRETER");
EXPECT_TRUE(test::all_close_f(expected_output, result_vectors.front()));
......@@ -561,16 +562,160 @@ TEST(onnx, model_unsqueeze)
TEST(onnx, model_div)
{
auto function = ngraph::onnx_import::import_onnx_function(
ngraph::file_util::path_join(SERIALIZED_ZOO, "onnx/div.onnx"));
auto function =
onnx_import::import_onnx_function(file_util::path_join(SERIALIZED_ZOO, "onnx/div.onnx"));
Inputs inputs;
inputs.emplace_back(ngraph::test::NDArray<float, 3>({{{1, 2, 3}}}).get_vector());
inputs.emplace_back(test::NDArray<float, 3>({{{1, 2, 3}}}).get_vector());
inputs.emplace_back(ngraph::test::NDArray<float, 3>({{{1, 4, 12}}}).get_vector());
inputs.emplace_back(test::NDArray<float, 3>({{{1, 4, 12}}}).get_vector());
auto expected_output = ngraph::test::NDArray<float, 3>({{{1, 0.5, 0.25}}}).get_vector();
auto expected_output = test::NDArray<float, 3>({{{1, 0.5, 0.25}}}).get_vector();
auto result_vectors = execute(function, inputs, "INTERPRETER");
EXPECT_TRUE(test::all_close_f(expected_output, result_vectors.front()));
}
TEST(onnx, model_reshape_reduced_dims)
{
auto function = onnx_import::import_onnx_function(
file_util::path_join(SERIALIZED_ZOO, "onnx/reshape_reduced_dims.onnx"));
// input data shape (2, 3, 4)
Inputs inputs{test::NDArray<float, 3>({{{0, 1, 2, 3}, {4, 5, 6, 7}, {8, 9, 10, 11}},
{{12, 13, 14, 15}, {16, 17, 18, 19}, {20, 21, 22, 23}}})
.get_vector()};
// output data shape (2, 12)
Outputs expected_outputs{
test::NDArray<float, 2>({{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11},
{12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}})
.get_vector()};
Outputs outputs{execute(function, inputs, "INTERPRETER")};
EXPECT_TRUE(test::all_close_f(expected_outputs.front(), outputs.front()));
}
TEST(onnx, model_reshape_reordered_dims)
{
auto function = onnx_import::import_onnx_function(
file_util::path_join(SERIALIZED_ZOO, "onnx/reshape_reordered_dims.onnx"));
// input data shape (2, 3, 4)
Inputs inputs{test::NDArray<float, 3>({{{0, 1, 2, 3}, {4, 5, 6, 7}, {8, 9, 10, 11}},
{{12, 13, 14, 15}, {16, 17, 18, 19}, {20, 21, 22, 23}}})
.get_vector()};
// output data shape (4, 2, 3)
Outputs expected_outputs{test::NDArray<float, 3>({{{0, 1, 2}, {3, 4, 5}},
{{6, 7, 8}, {9, 10, 11}},
{{12, 13, 14}, {15, 16, 17}},
{{18, 19, 20}, {21, 22, 23}}})
.get_vector()};
Outputs outputs{execute(function, inputs, "INTERPRETER")};
EXPECT_TRUE(test::all_close_f(expected_outputs.front(), outputs.front()));
}
TEST(onnx, model_reshape_extended_dims)
{
auto function = onnx_import::import_onnx_function(
file_util::path_join(SERIALIZED_ZOO, "onnx/reshape_extended_dims.onnx"));
// input data shape (2, 3, 4)
Inputs inputs{test::NDArray<float, 3>({{{0, 1, 2, 3}, {4, 5, 6, 7}, {8, 9, 10, 11}},
{{12, 13, 14, 15}, {16, 17, 18, 19}, {20, 21, 22, 23}}})
.get_vector()};
// output data shape (3, 2, 2, 2)
Outputs expected_outputs{test::NDArray<float, 4>({{{{0, 1}, {2, 3}}, {{4, 5}, {6, 7}}},
{{{8, 9}, {10, 11}}, {{12, 13}, {14, 15}}},
{{{16, 17}, {18, 19}}, {{20, 21}, {22, 23}}}})
.get_vector()};
Outputs outputs{execute(function, inputs, "INTERPRETER")};
EXPECT_TRUE(test::all_close_f(expected_outputs.front(), outputs.front()));
}
TEST(onnx, model_reshape_single_dim)
{
auto function = onnx_import::import_onnx_function(
file_util::path_join(SERIALIZED_ZOO, "onnx/reshape_single_dim.onnx"));
// input data shape (2, 3, 4)
Inputs inputs{test::NDArray<float, 3>({{{0, 1, 2, 3}, {4, 5, 6, 7}, {8, 9, 10, 11}},
{{12, 13, 14, 15}, {16, 17, 18, 19}, {20, 21, 22, 23}}})
.get_vector()};
// output data shape (24, )
Outputs expected_outputs{
test::NDArray<float, 1>(
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23})
.get_vector()};
Outputs outputs{execute(function, inputs, "INTERPRETER")};
EXPECT_TRUE(test::all_close_f(expected_outputs.front(), outputs.front()));
}
TEST(onnx, model_reshape_negative_dim)
{
auto function = onnx_import::import_onnx_function(
file_util::path_join(SERIALIZED_ZOO, "onnx/reshape_negative_dim.onnx"));
// input data shape (2, 3, 4)
Inputs inputs{test::NDArray<float, 3>({{{0, 1, 2, 3}, {4, 5, 6, 7}, {8, 9, 10, 11}},
{{12, 13, 14, 15}, {16, 17, 18, 19}, {20, 21, 22, 23}}})
.get_vector()};
// output data shape (6, 2, 2)
Outputs expected_outputs{test::NDArray<float, 3>({{{0, 1}, {2, 3}},
{{4, 5}, {6, 7}},
{{8, 9}, {10, 11}},
{{12, 13}, {14, 15}},
{{16, 17}, {18, 19}},
{{20, 21}, {22, 23}}})
.get_vector()};
Outputs outputs{execute(function, inputs, "INTERPRETER")};
EXPECT_TRUE(test::all_close_f(expected_outputs.front(), outputs.front()));
}
TEST(onnx, model_reshape_negative_with_zero_dim)
{
auto function = onnx_import::import_onnx_function(
file_util::path_join(SERIALIZED_ZOO, "onnx/reshape_negative_with_zero_dims.onnx"));
// input data shape (2, 3, 4)
Inputs inputs{test::NDArray<float, 3>({{{0, 1, 2, 3}, {4, 5, 6, 7}, {8, 9, 10, 11}},
{{12, 13, 14, 15}, {16, 17, 18, 19}, {20, 21, 22, 23}}})
.get_vector()};
// output data shape (2, 6, 2)
Outputs expected_outputs{
test::NDArray<float, 3>({{{0, 1}, {2, 3}, {4, 5}, {6, 7}, {8, 9}, {10, 11}},
{{12, 13}, {14, 15}, {16, 17}, {18, 19}, {20, 21}, {22, 23}}})
.get_vector()};
Outputs outputs{execute(function, inputs, "INTERPRETER")};
EXPECT_TRUE(test::all_close_f(expected_outputs.front(), outputs.front()));
}
TEST(onnx, model_reshape_output_shape_as_input)
{
auto function = onnx_import::import_onnx_function(
file_util::path_join(SERIALIZED_ZOO, "onnx/reshape_output_shape_as_input.onnx"));
// input data shape (2, 3, 4)
Inputs inputs{test::NDArray<float, 3>({{{0, 1, 2, 3}, {4, 5, 6, 7}, {8, 9, 10, 11}},
{{12, 13, 14, 15}, {16, 17, 18, 19}, {20, 21, 22, 23}}})
.get_vector()};
// output data shape (2, 6, 2)
Outputs expected_outputs{
test::NDArray<float, 3>({{{0, 1}, {2, 3}, {4, 5}, {6, 7}, {8, 9}, {10, 11}},
{{12, 13}, {14, 15}, {16, 17}, {18, 19}, {20, 21}, {22, 23}}})
.get_vector()};
Outputs outputs{execute(function, inputs, "INTERPRETER")};
EXPECT_TRUE(test::all_close_f(expected_outputs.front(), outputs.front()));
}