Commit 411f83e2 authored by Adam Rogowiec's avatar Adam Rogowiec Committed by Robert Kimball

[ONNX] Update Squeeze Op to conform with doc. (#1746)

* Update ONNX Squeeze Op implementation to conform with doc. Add unit test.

* Apply code-format.

* Correct attribute value type.

* Change used loop structure.

* Modified version of loops.

- Without erase and with minimal computation time complexity.

* Run CI
parent e61c2e21
...@@ -14,15 +14,20 @@ ...@@ -14,15 +14,20 @@
// limitations under the License. // limitations under the License.
//***************************************************************************** //*****************************************************************************
#include <cstddef>
#include <functional>
#include <iterator>
#include <numeric> #include <numeric>
#include <set>
#include <vector> #include <vector>
#include "ngraph/axis_vector.hpp" #include "ngraph/axis_vector.hpp"
#include "ngraph/op/reshape.hpp" #include "ngraph/op/reshape.hpp"
#include "utils/reshape.hpp" #include "ngraph/shape.hpp"
#include "exceptions.hpp" #include "exceptions.hpp"
#include "squeeze.hpp" #include "squeeze.hpp"
#include "utils/reshape.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -37,35 +42,49 @@ namespace ngraph ...@@ -37,35 +42,49 @@ namespace ngraph
NodeVector inputs{node.get_ng_inputs()}; NodeVector inputs{node.get_ng_inputs()};
auto data = inputs.at(0); auto data = inputs.at(0);
auto data_shape = data->get_shape(); auto data_shape = data->get_shape();
auto axes = node.get_attribute_value<std::vector<uint64_t>>("axes", {}); auto axes = node.get_attribute_value<std::vector<std::size_t>>("axes", {});
AxisVector input_order{reshape::get_default_axis_vector(data_shape.size())};
// Prepare set of unique axes marked to be removed from input data.
if (axes.empty()) if (axes.empty())
{ {
for (auto index = 0; index < data_shape.size(); ++index) // Default behaviour is to remove all single dimension axes.
for (std::size_t idx = 0; idx < data_shape.size(); ++idx)
{ {
if (data_shape.at(index) == 1) if (data_shape.at(idx) == 1)
{ {
axes.push_back(index); // Mark with zero elements to remove;
data_shape.at(idx) = 0;
} }
} }
} }
else
std::sort(std::begin(axes), std::end(axes), std::greater<uint64_t>()); {
std::set<std::size_t, std::greater<std::size_t>> unique_axes(
AxisVector input_order{reshape::get_default_axis_vector(data_shape.size())}; std::begin(axes), std::end(axes));
for (uint64_t axis : unique_axes)
for (auto axis : axes)
{ {
data_shape.erase(std::next(std::begin(data_shape), axis)); ASSERT_VALID_ARGUMENT(node, data_shape.at(axis) == 1)
<< "provided axis value is invalid. Only single dimension axes may "
"be removed.";
// Mark with zero elements to remove;
data_shape.at(axis) = 0;
}
} }
return {std::make_shared<ngraph::op::Reshape>(data, input_order, data_shape)}; Shape output_data_shape;
for (std::size_t idx = 0; idx < data_shape.size(); ++idx)
{
if (data_shape.at(idx) != 0)
{
output_data_shape.push_back(data_shape.at(idx));
}
}
return {std::make_shared<ngraph::op::Reshape>(
data, input_order, output_data_shape)};
} }
} // namespace set_1 } // namespace set_1
} //namespace op } //namespace op
} // namespace onnx_import } // namespace onnx_import
} // namespace ngraph } // namespace ngraph
...@@ -561,6 +561,25 @@ TEST(onnx, model_unsqueeze) ...@@ -561,6 +561,25 @@ TEST(onnx, model_unsqueeze)
EXPECT_TRUE(test::all_close_f(expected_output.front(), outputs.front())); EXPECT_TRUE(test::all_close_f(expected_output.front(), outputs.front()));
} }
TEST(onnx, model_squeeze)
{
auto function = onnx_import::import_onnx_function(
file_util::path_join(SERIALIZED_ZOO, "onnx/squeeze_duplicate_axes.onnx"));
// {1, 4, 1, 1, 2}
Inputs inputs{test::NDArray<float, 5>(
{{{{{1.0f, 2.0f}}}, {{{3.0f, 4.0f}}}, {{{5.0f, 6.0f}}}, {{{7.0f, 8.0f}}}}})
.get_vector()};
// {4, 2}
Outputs expected_output{
test::NDArray<float, 2>({{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}, {7.0f, 8.0f}})
.get_vector()};
Outputs outputs{execute(function, inputs, "INTERPRETER")};
EXPECT_TRUE(test::all_close_f(expected_output.front(), outputs.front()));
}
TEST(onnx, model_div) TEST(onnx, model_div)
{ {
auto function = auto function =
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment