Commit 2a49f1c8 authored by Michał Karzyński, committed by Robert Kimball

[ONNX] Add ArgMin/Max operators (#1898)

* Add ArgMin operator

* Add ArgMax and a basic test case

* Rename variables

* Apply workaround for problems with Reshape on i64

* Review comments

* Review comments
parent 46ed8e05
CMakeLists.txt
@@ -38,6 +38,10 @@ add_library(onnx_import STATIC
op/acos.hpp
op/add.hpp
op/and.hpp
op/argmax.cpp
op/argmax.hpp
op/argmin.cpp
op/argmin.hpp
op/asin.hpp
op/atan.hpp
op/average_pool.cpp
......
op/argmax.cpp
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/op/argmax.hpp"
#include "core/node.hpp"
#include "ngraph/node_vector.hpp"
#include "utils/reduction.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector argmax(const Node& node)
{
return {reduction::make_ng_index_reduction_op<ngraph::op::ArgMax>(node)};
}
} // namespace set_1
} //namespace op
} // namespace onnx_import
} // namespace ngraph
op/argmax.hpp
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once

#include "core/node.hpp"
#include "ngraph/node_vector.hpp"

namespace ngraph
{
    namespace onnx_import
    {
        namespace op
        {
            namespace set_1
            {
                /// \brief Convert an ONNX ArgMax operation to an nGraph node.
                ///
                /// \param node The ONNX node object representing this operation.
                ///
                /// \return A vector containing the nGraph node which produces the output
                ///         of the ONNX ArgMax operation.
                NodeVector argmax(const Node& node);

            } // namespace set_1

        } // namespace op

    } // namespace onnx_import

} // namespace ngraph
op/argmin.cpp
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/op/argmin.hpp"
#include "core/node.hpp"
#include "ngraph/node_vector.hpp"
#include "utils/reduction.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector argmin(const Node& node)
{
return {reduction::make_ng_index_reduction_op<ngraph::op::ArgMin>(node)};
}
} // namespace set_1
} //namespace op
} // namespace onnx_import
} // namespace ngraph
op/argmin.hpp
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once

#include "core/node.hpp"
#include "ngraph/node_vector.hpp"

namespace ngraph
{
    namespace onnx_import
    {
        namespace op
        {
            namespace set_1
            {
                /// \brief Convert an ONNX ArgMin operation to an nGraph node.
                ///
                /// \param node The ONNX node object representing this operation.
                ///
                /// \return A vector containing the nGraph node which produces the output
                ///         of the ONNX ArgMin operation.
                NodeVector argmin(const Node& node);

            } // namespace set_1

        } // namespace op

    } // namespace onnx_import

} // namespace ngraph
@@ -25,6 +25,8 @@
#include "op/acos.hpp"
#include "op/add.hpp"
#include "op/and.hpp"
#include "op/argmax.hpp"
#include "op/argmin.hpp"
#include "op/asin.hpp"
#include "op/atan.hpp"
#include "op/average_pool.hpp"
@@ -145,6 +147,8 @@ namespace ngraph
REGISTER_OPERATOR("Acos", 1, acos);
REGISTER_OPERATOR("Add", 1, add);
REGISTER_OPERATOR("And", 1, logical_and);
REGISTER_OPERATOR("ArgMin", 1, argmin);
REGISTER_OPERATOR("ArgMax", 1, argmax);
REGISTER_OPERATOR("Asin", 1, asin);
REGISTER_OPERATOR("Atan", 1, atan);
REGISTER_OPERATOR("AveragePool", 1, average_pool);
......
utils/reduction.hpp
@@ -27,6 +27,7 @@
#include "ngraph/axis_set.hpp"
#include "ngraph/node_vector.hpp"
#include "ngraph/op/convert.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/util/arithmetic_reduction.hpp"
#include "ngraph/shape.hpp"
......@@ -101,6 +102,38 @@ namespace ngraph
                    reshape::get_default_axis_vector(op_node->get_shape().size()),
                    Shape{output_shape});
            }

            template <class IndexReduction>
            std::shared_ptr<ngraph::Node> make_ng_index_reduction_op(const Node& node)
            {
                auto axis = node.get_attribute_value<int64_t>("axis", 0);
                auto keepdims = node.get_attribute_value<int64_t>("keepdims", 1);
                auto input_node = node.get_ng_inputs().at(0);

                auto op_node = std::make_shared<IndexReduction>(input_node, axis, element::i64);

                // With keepdims == 0 the reduced axis is dropped, which is exactly what
                // the nGraph index-reduction ops produce, so no reshape is needed.
                if (keepdims == 0)
                {
                    return op_node;
                }

                // With keepdims == 1 (the ONNX default) the reduced axis must be kept
                // with length 1, which requires a reshape of the index tensor.
                // WORKAROUND FOR PROBLEMS WITH RESHAPE ON i64 @TODO: remove
                auto convert_node = std::make_shared<ngraph::op::Convert>(op_node, element::f32);

                auto output_shape = input_node->get_shape();
                output_shape.at(axis) = 1;
                auto reshape_node = std::make_shared<ngraph::op::Reshape>(
                    convert_node,
                    reshape::get_default_axis_vector(op_node->get_shape().size()),
                    Shape{output_shape});

                // WORKAROUND FOR PROBLEMS WITH RESHAPE ON i64 @TODO: remove
                auto reconvert_node =
                    std::make_shared<ngraph::op::Convert>(reshape_node, element::i64);

                return reconvert_node;
            }
        } // namespace reduction
    } // namespace onnx_import
} // namespace ngraph
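As a rough standalone illustration (not part of the commit), the subgraph this helper emits for the default keepdims == 1 case can be written out by hand. The {2, 3} input shape and the axis value 1 below are arbitrary example choices, and the double Convert around Reshape mirrors the i64 workaround flagged in the comments above:

// Hedged sketch only: builds the same node pattern as make_ng_index_reduction_op
// for a {2, 3} input reduced with ArgMax along axis 1, keepdims == 1.
#include <memory>
#include "ngraph/axis_vector.hpp"
#include "ngraph/op/argmax.hpp"
#include "ngraph/op/convert.hpp"
#include "ngraph/op/parameter.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/shape.hpp"

std::shared_ptr<ngraph::Node> argmax_keepdims_sketch()
{
    auto data = std::make_shared<ngraph::op::Parameter>(ngraph::element::f32,
                                                        ngraph::Shape{2, 3});
    // Index reduction along axis 1 yields i64 indices of shape {2}.
    auto indices = std::make_shared<ngraph::op::ArgMax>(data, 1, ngraph::element::i64);
    // Workaround: reshape on f32, then convert the result back to i64.
    auto as_f32 = std::make_shared<ngraph::op::Convert>(indices, ngraph::element::f32);
    auto reshaped = std::make_shared<ngraph::op::Reshape>(
        as_f32, ngraph::AxisVector{0}, ngraph::Shape{2, 1});
    return std::make_shared<ngraph::op::Convert>(reshaped, ngraph::element::i64);
}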
@@ -17,6 +17,7 @@
#include <cstdint>
#include <fstream>
#include <sstream>
#include <vector>
#include "gtest/gtest.h"
#include "ngraph/frontend/onnx_import/onnx.hpp"
@@ -74,6 +75,18 @@ TEST(onnx, model_addmul_abc)
EXPECT_TRUE(test::all_close_f(expected_output, result_vectors.front()));
}
TEST(onnx, model_argmin_no_keepdims)
{
    auto function = onnx_import::import_onnx_function(
        file_util::path_join(SERIALIZED_ZOO, "onnx/argmin_no_keepdims.onnx"));

    // The model reduces along axis 1 with keepdims disabled, so ArgMin of
    // {{2, 1}, {3, 10}} yields the per-row indices {1, 0}.
    Inputs inputs{test::NDArray<float, 2>{{2, 1}, {3, 10}}.get_vector()};
    std::vector<std::vector<int64_t>> expected_output{{1, 0}};

    std::vector<std::vector<int64_t>> result{
        execute<float, int64_t>(function, inputs, "INTERPRETER")};
    EXPECT_EQ(expected_output, result);
}
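For reference, the expected {1, 0} constant can be reproduced with a tiny freestanding program (not part of the commit and independent of nGraph) that walks each row of the test input and records the position of its smallest element:

// Plain C++ sketch of ArgMin semantics along the last axis of the 2x2 test input.
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

int main()
{
    const std::vector<std::vector<float>> data{{2, 1}, {3, 10}};
    std::vector<std::int64_t> indices;
    for (const auto& row : data)
    {
        std::size_t best = 0;
        for (std::size_t i = 1; i < row.size(); ++i)
        {
            if (row[i] < row[best])
            {
                best = i;
            }
        }
        indices.push_back(static_cast<std::int64_t>(best));
    }
    // Prints "1 0", matching expected_output in the test above.
    for (auto index : indices)
    {
        std::cout << index << ' ';
    }
    std::cout << '\n';
    return 0;
}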
TEST(onnx, model_split_equal_parts_default)
{
Model model{onnx_import::load_onnx_model(
......