Commit 951e77b4 authored by Nick Korovaiko, committed by Scott Cyphers

ArgMin (#1435)

* argmin

* address feedback on argmin

* add new lines

* add new lines

* address adam's nitpicks

* scott's feedback

* fix unit tests
parent c46d4546
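For orientation, here is a minimal sketch of how the new op is constructed (an illustrative sketch only, assuming the umbrella ngraph.hpp header of this tree; compiling and running the resulting graph through a backend follows the argmin_trivial unit test added at the bottom of this diff):

#include <memory>

#include "ngraph/ngraph.hpp"
#include "ngraph/op/argmin.hpp"

using namespace ngraph;

int main()
{
    // A 4x3 float input reduced over axis 0 yields one i32 index per column,
    // i.e. an output of shape {3}.
    Shape shape{4, 3};
    auto A = std::make_shared<op::Parameter>(element::f32, shape);
    auto argmin = std::make_shared<op::ArgMin>(A, 0 /*axis*/, element::i32);
    auto f = std::make_shared<Function>(argmin, op::ParameterVector{A});
    return 0;
}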
......@@ -40,6 +40,7 @@ set (SRC
op/add.cpp
op/allreduce.cpp
op/and.cpp
op/argmin.cpp
op/asin.cpp
op/atan.cpp
op/avg_pool.cpp
......@@ -109,6 +110,7 @@ set (SRC
op/util/binary_elementwise_comparison.cpp
op/util/binary_elementwise_logical.cpp
op/util/binary_elementwise.cpp
op/util/index_reduction.cpp
op/util/requires_tensor_view_args.cpp
op/util/unary_elementwise_arithmetic.cpp
op/util/unary_elementwise.cpp
......
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "ngraph/op/argmin.hpp"
using namespace std;
using namespace ngraph;
shared_ptr<Node> op::ArgMin::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 1)
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<ArgMin>(new_args.at(0), m_axis, this->get_element_type());
}
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include "ngraph/axis_set.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/op/util/index_reduction.hpp"
#include "ngraph/op/util/requires_tensor_view_args.hpp"
namespace ngraph
{
namespace op
{
/// \brief Computes the index of the minimum value along a specified axis of a tensor
class ArgMin : public op::util::IndexReduction
{
public:
/// \brief Constructs an ArgMin operation.
///
/// \param arg The input tensor
/// \param axis The axis along which to compute the index of the minimum
/// \param index_element_type The element type of the produced indices; currently only int64 and int32 are supported
ArgMin(const std::shared_ptr<Node>& arg,
size_t axis,
const element::Type& index_element_type)
: IndexReduction("ArgMin", arg, axis, index_element_type)
{
}
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
};
}
}
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <memory>
#include "ngraph/op/util/index_reduction.hpp"
using namespace std;
using namespace ngraph;
op::util::IndexReduction::IndexReduction(const std::string& node_type,
const std::shared_ptr<Node>& arg,
size_t axis,
const element::Type& index_element_type)
: RequiresTensorViewArgs(node_type, {arg})
, m_axis(axis)
{
auto rank = arg->get_shape().size();
TYPE_CHECK_ASSERT(this, rank >= 1) << "Tensor's rank must be at least 1";
TYPE_CHECK_ASSERT(this, axis < rank) << "Axis " << axis << " must be less than rank " << rank;
TYPE_CHECK_ASSERT(this,
index_element_type == element::i32 || index_element_type == element::i64)
<< "Index element type must be i64 or i32";
Shape output_shape = arg->get_shape();
output_shape.erase(output_shape.begin() + axis);
set_value_type_checked(index_element_type, output_shape);
}
void op::util::IndexReduction::generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas)
{
throw ngraph_error("Forward-propagation-only operation");
}
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include "ngraph/op/util/requires_tensor_view_args.hpp"
namespace ngraph
{
namespace op
{
namespace util
{
class IndexReduction : public util::RequiresTensorViewArgs
{
public:
size_t get_reduction_axis() const { return m_axis; }
element::Type get_index_element_type() const { return m_index_element_type; }
IndexReduction(const std::string& node_type,
const std::shared_ptr<Node>& arg,
size_t axis,
const element::Type& index_element_type);
protected:
size_t m_axis;
element::Type m_index_element_type;
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
};
}
}
}
......@@ -26,6 +26,7 @@ set(SRC
cpu_tracing.cpp
builder/add.cpp
builder/avg_pool.cpp
builder/argmin.cpp
builder/batch_norm.cpp
builder/broadcast.cpp
builder/bounded_relu.cpp
......
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <cstring>
#include "ngraph/op/argmin.hpp"
#include "ngraph/runtime/cpu/cpu_builder.hpp"
#include "ngraph/runtime/reference/argmin.hpp"
using namespace std;
using namespace ngraph;
namespace ngraph
{
namespace runtime
{
namespace cpu
{
template <>
void Builder::BUILDER_DECL(ngraph::op::ArgMin)
{
auto& functors = external_function->get_functors();
auto& tensor_data = external_function->get_tensor_data();
const ngraph::op::ArgMin* argmin = static_cast<const ngraph::op::ArgMin*>(node);
function<void(CPURuntimeContext*)> functor;
auto& arg_tensor = tensor_data[args[0].get_name()];
auto& out_tensor = tensor_data[out[0].get_name()];
if (out[0].get_element_type() != element::i64 &&
out[0].get_element_type() != element::i32)
{
throw ngraph_error("Unsupported index element type");
}
bool is_int64 = out[0].get_element_type() == element::i64;
auto axis = argmin->get_reduction_axis();
auto in_shape = args[0].get_shape();
auto out_shape = out[0].get_shape();
auto element_type = args[0].get_element_type();
if (element_type == element::f32)
{
if (is_int64)
{
functor = [&, in_shape, out_shape, axis](CPURuntimeContext* ctx) {
ngraph::runtime::reference::argmin<float, int64_t>(
static_cast<float*>(arg_tensor),
static_cast<int64_t*>(out_tensor),
in_shape,
out_shape,
axis);
};
}
else
{
functor = [&, in_shape, out_shape, axis](CPURuntimeContext* ctx) {
ngraph::runtime::reference::argmin<float, int32_t>(
static_cast<float*>(arg_tensor),
static_cast<int32_t*>(out_tensor),
in_shape,
out_shape,
axis);
};
}
}
else if (element_type == element::f64)
{
if (is_int64)
{
functor = [&, in_shape, out_shape, axis](CPURuntimeContext* ctx) {
ngraph::runtime::reference::argmin<double, int64_t>(
static_cast<double*>(arg_tensor),
static_cast<int64_t*>(out_tensor),
in_shape,
out_shape,
axis);
};
}
else
{
functor = [&, in_shape, out_shape, axis](CPURuntimeContext* ctx) {
ngraph::runtime::reference::argmin<double, int32_t>(
static_cast<double*>(arg_tensor),
static_cast<int32_t*>(out_tensor),
in_shape,
out_shape,
axis);
};
}
}
else
{
throw ngraph_error("Unsupported type in CPU Builder for ArgMin");
}
functors.emplace_back(functor);
}
REGISTER_OP_BUILDER(ArgMin);
}
}
}
......@@ -28,6 +28,7 @@
#include "ngraph/op/add.hpp"
#include "ngraph/op/allreduce.hpp"
#include "ngraph/op/and.hpp"
#include "ngraph/op/argmin.hpp"
#include "ngraph/op/asin.hpp"
#include "ngraph/op/atan.hpp"
#include "ngraph/op/avg_pool.hpp"
......@@ -2292,6 +2293,26 @@ namespace ngraph
writer.block_end();
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::ArgMin)
{
auto argmin = static_cast<const ngraph::op::ArgMin*>(node);
if (out[0].get_element_type() != element::i64 &&
out[0].get_element_type() != element::i32)
{
throw ngraph_error("Unsupported index element type");
}
writer.block_begin();
writer << "reference::argmin<" << args[0].get_type() << ", "
<< out[0].get_element_type().c_type_string() << ">(" << args[0].get_name()
<< ",\n";
writer << " " << out[0].get_name() << ",\n";
writer << " {" << join(args[0].get_shape()) << "},\n";
writer << " {" << join(out[0].get_shape()) << "},\n";
writer << " " << argmin->get_reduction_axis() << ");\n";
writer.block_end();
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Power)
{
......
......@@ -52,6 +52,7 @@
#include "ngraph/op/add.hpp"
#include "ngraph/op/allreduce.hpp"
#include "ngraph/op/and.hpp"
#include "ngraph/op/argmin.hpp"
#include "ngraph/op/asin.hpp"
#include "ngraph/op/atan.hpp"
#include "ngraph/op/avg_pool.hpp"
......@@ -279,6 +280,7 @@ static const runtime::cpu::OpMap dispatcher{
{TI(ngraph::op::Tan), &runtime::cpu::CPU_Emitter::emit<op::Tan>},
{TI(ngraph::op::Tanh), &runtime::cpu::CPU_Emitter::emit<op::Tanh>},
{TI(ngraph::op::Asin), &runtime::cpu::CPU_Emitter::emit<op::Asin>},
{TI(ngraph::op::ArgMin), &runtime::cpu::CPU_Emitter::emit<op::ArgMin>},
{TI(ngraph::op::Acos), &runtime::cpu::CPU_Emitter::emit<op::Acos>},
{TI(ngraph::op::Atan), &runtime::cpu::CPU_Emitter::emit<op::Atan>},
{TI(ngraph::op::ReplaceSlice), &runtime::cpu::CPU_Emitter::emit<op::ReplaceSlice>},
......@@ -424,6 +426,7 @@ void runtime::cpu::CPU_ExternalFunction::compile()
#include "ngraph/runtime/cpu/cpu_runtime_context.hpp"
#include "ngraph/runtime/cpu/mkldnn_invoke.hpp"
#include "ngraph/runtime/reference/and.hpp"
#include "ngraph/runtime/reference/argmin.hpp"
#include "ngraph/runtime/reference/avg_pool.hpp"
#include "ngraph/runtime/reference/batch_norm.hpp"
#include "ngraph/runtime/reference/broadcast.hpp"
......
......@@ -32,3 +32,4 @@ backwards_avgpool_n1_c1_hw4x4
backwards_avgpool_n2_c2_hw4x4
max_pool_3d
avg_pool_3d
argmin_trivial
......@@ -117,3 +117,4 @@ zero_sized_sqrt
zero_sized_subtract
zero_sized_tan
zero_sized_tanh
argmin_trivial
......@@ -25,6 +25,7 @@
#include "ngraph/runtime/host_tensor_view.hpp"
#include "ngraph/runtime/tensor_view.hpp"
#include "ngraph/op/argmin.hpp"
#include "ngraph/op/avg_pool.hpp"
#include "ngraph/op/batch_norm.hpp"
#include "ngraph/op/broadcast.hpp"
......@@ -56,6 +57,7 @@
#include "ngraph/runtime/reference/acos.hpp"
#include "ngraph/runtime/reference/add.hpp"
#include "ngraph/runtime/reference/and.hpp"
#include "ngraph/runtime/reference/argmin.hpp"
#include "ngraph/runtime/reference/asin.hpp"
#include "ngraph/runtime/reference/atan.hpp"
#include "ngraph/runtime/reference/avg_pool.hpp"
......@@ -209,6 +211,30 @@ private:
out[0]->get_data_ptr<T>(),
out[0]->get_element_count());
}
else if (node_op == "ArgMin")
{
const op::ArgMin* argmin = static_cast<const op::ArgMin*>(&node);
if (out[0]->get_element_type() == element::i64)
{
reference::argmin<T, int64_t>(args[0]->get_data_ptr<T>(),
out[0]->get_data_ptr<int64_t>(),
args[0]->get_shape(),
out[0]->get_shape(),
argmin->get_reduction_axis());
}
else if (out[0]->get_element_type() == element::i32)
{
reference::argmin<T, int32_t>(args[0]->get_data_ptr<T>(),
out[0]->get_data_ptr<int32_t>(),
args[0]->get_shape(),
out[0]->get_shape(),
argmin->get_reduction_axis());
}
else
{
throw ngraph_error("Unexpected type");
}
}
else if (node_op == "Asin")
{
reference::asin<T>(
......
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include <cmath>
#include <cstring>
#include "ngraph/coordinate_transform.hpp"
namespace ngraph
{
namespace runtime
{
namespace reference
{
template <typename T, typename U>
void argmin(
const T* arg, U* out, const Shape& in_shape, const Shape& out_shape, size_t axis)
{
// Start with index 0 along the reduction axis as the current minimum for every output coordinate
memset(out, 0, shape_size(out_shape) * sizeof(U));
AxisVector av{axis};
CoordinateTransform input_transform(in_shape);
for (const Coordinate& input_coord : input_transform)
{
Coordinate output_coord = project(input_coord, av);
CoordinateTransform output_transform(out_shape);
auto min_index = static_cast<size_t>(out[output_transform.index(output_coord)]);
auto min_coord = input_coord;
min_coord[axis] = min_index;
if (arg[input_transform.index(input_coord)] <
arg[input_transform.index(min_coord)])
{
out[output_transform.index(output_coord)] =
static_cast<U>(input_coord[axis]);
}
}
}
}
}
}
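To make the traversal above concrete, here is a simplified, self-contained sketch (plain row-major C++, not the library kernel; the real code generalizes the index arithmetic with CoordinateTransform) of argmin along axis 0 of a 4x3 array. It reproduces the {3, 2, 1} result expected by the argmin_trivial test below:

#include <cstddef>
#include <cstdio>
#include <vector>

int main()
{
    const size_t rows = 4, cols = 3;
    std::vector<float> in{12, 2, 10, 9, 8, 4, 6, 1, 5, 3, 11, 7};
    std::vector<int> out(cols, 0); // start with row 0 as the current minimum for each column

    for (size_t r = 0; r < rows; r++)
    {
        for (size_t c = 0; c < cols; c++)
        {
            // Compare against the element at the currently recorded minimum index.
            if (in[r * cols + c] < in[out[c] * cols + c])
            {
                out[c] = static_cast<int>(r);
            }
        }
    }

    for (int i : out)
    {
        std::printf("%d ", i); // prints: 3 2 1
    }
    std::printf("\n");
    return 0;
}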
......@@ -25,6 +25,7 @@
#include "ngraph/op/add.hpp"
#include "ngraph/op/allreduce.hpp"
#include "ngraph/op/and.hpp"
#include "ngraph/op/argmin.hpp"
#include "ngraph/op/asin.hpp"
#include "ngraph/op/atan.hpp"
#include "ngraph/op/avg_pool.hpp"
......@@ -387,6 +388,12 @@ static shared_ptr<ngraph::Function>
{
node = make_shared<op::And>(args[0], args[1]);
}
else if (node_op == "ArgMin")
{
auto axis = node_js.at("axis").get<size_t>();
auto target_type = read_element_type(node_js.at("index_element_type"));
node = make_shared<op::ArgMin>(args[0], axis, target_type);
}
else if (node_op == "Asin")
{
node = make_shared<op::Asin>(args[0]);
......@@ -1011,6 +1018,12 @@ static json write(const Node& n, bool binary_constant_data)
else if (node_op == "Add")
{
}
else if (node_op == "ArgMin")
{
auto tmp = dynamic_cast<const op::ArgMin*>(&n);
node["axis"] = tmp->get_reduction_axis();
node["index_element_type"] = write_element_type(tmp->get_element_type());
}
else if (node_op == "AllReduce")
{
}
......
......@@ -25,6 +25,7 @@
#include "ngraph/autodiff/adjoints.hpp"
#include "ngraph/log.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/op/argmin.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/lrn.hpp"
#include "ngraph/serializer.hpp"
......@@ -8926,3 +8927,23 @@ NGRAPH_TEST(${BACKEND_NAME}, reverse_sequence_n4d2c3h2w2)
backend->call_with_validate(f, {result}, {a, b});
EXPECT_EQ(read_vector<int>(result), expected);
}
// Trivial case.
NGRAPH_TEST(${BACKEND_NAME}, argmin_trivial)
{
Shape shape{4, 3};
Shape rshape{3};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto f =
make_shared<Function>(make_shared<op::ArgMin>(A, 0, element::i32), op::ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape);
copy_data(a, vector<float>{12, 2, 10, 9, 8, 4, 6, 1, 5, 3, 11, 7});
auto result = backend->create_tensor(element::i32, rshape);
backend->call_with_validate(f, {result}, {a});
EXPECT_EQ((vector<int>{3, 2, 1}), read_vector<int>(result));
}
......@@ -17,6 +17,7 @@
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/op/argmin.hpp"
#include "ngraph/op/batch_norm.hpp"
#include "ngraph/op/reverse_sequence.hpp"
......@@ -6457,3 +6458,60 @@ TEST(type_prop, sum_axis_oob)
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, index_reduction_scalar)
{
auto a = make_shared<op::Parameter>(element::f32, Shape{});
try
{
auto argmin = make_shared<op::ArgMin>(a, 0, element::i32);
FAIL() << "ArgMin c-tor should throw for scalar shapes";
}
catch (const TypeCheckError& error)
{
EXPECT_HAS_SUBSTRING(error.what(), "Tensor's rank must be at least 1");
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, index_reduction_invalid_rank)
{
auto a = make_shared<op::Parameter>(element::f32, Shape{2, 2});
try
{
auto argmin = make_shared<op::ArgMin>(a, 2, element::i32);
FAIL() << "ArgMin c-tor should throw for scalar shapes";
}
catch (const TypeCheckError& error)
{
EXPECT_HAS_SUBSTRING(error.what(), "is greater than rank of");
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, index_reduction_invalid_index_type)
{
auto a = make_shared<op::Parameter>(element::f32, Shape{2, 2});
try
{
auto argmin = make_shared<op::ArgMin>(a, 1, element::f32);
FAIL() << "ArgMin c-tor should throw for scalar shapes";
}
catch (const TypeCheckError& error)
{
EXPECT_HAS_SUBSTRING(error.what(), "Index element type must be");
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}