Commit 3548772b authored by Sang Ik Lee's avatar Sang Ik Lee Committed by Scott Cyphers

TopK (w/ArgMax, ArgMin python wrapper) (#1560)

* Implement TopK.

* Update python wrappers for TopK, ArgMin and ArgMax.

* Address some reviewer comments.

* Add type property check tests for TopK.
Set correct TopK behavior for K==0.

* TopK: Add 1d and 3d unit tests.

* Address more reviewer comments.

* Apply code style.
parent d309e96f
......@@ -19,6 +19,8 @@ from ngraph.ops import absolute
from ngraph.ops import absolute as abs
from ngraph.ops import acos
from ngraph.ops import add
from ngraph.ops import argmax
from ngraph.ops import argmin
from ngraph.ops import asin
from ngraph.ops import atan
from ngraph.ops import avg_pool
......@@ -79,5 +81,6 @@ from ngraph.ops import subtract
from ngraph.ops import sum
from ngraph.ops import tan
from ngraph.ops import tanh
from ngraph.ops import topk
from ngraph.runtime import runtime
......@@ -39,6 +39,8 @@ from _pyngraph.op import Acos
from _pyngraph.op import Add
from _pyngraph.op import AllReduce
from _pyngraph.op import And
from _pyngraph.op import ArgMax
from _pyngraph.op import ArgMin
from _pyngraph.op import Asin
from _pyngraph.op import Atan
from _pyngraph.op import AvgPool
......@@ -112,3 +114,4 @@ from _pyngraph.op import Subtract
from _pyngraph.op import Sum
from _pyngraph.op import Tan
from _pyngraph.op import Tanh
from _pyngraph.op import TopK
......@@ -37,3 +37,4 @@ from _pyngraph.op.util import BinaryElementwiseArithmetic
from _pyngraph.op.util import BinaryElementwiseLogical
from _pyngraph.op.util import OpAnnotations
from _pyngraph.op.util import ArithmeticReduction
from _pyngraph.op.util import IndexReduction
......@@ -20,12 +20,12 @@ import numpy as np
from ngraph.impl import AxisSet, AxisVector, Coordinate, CoordinateDiff, Function, Node, \
NodeVector, Shape, Strides
from ngraph.impl.op import Abs, Acos, Add, And, Asin, Atan, AvgPool, BatchNorm, Broadcast, \
Ceiling, Concat, Constant, Convert, Convolution, ConvolutionBackpropData, Cos, Cosh, Divide, \
Dot, Equal, Exp, Floor, FunctionCall, GetOutputElement, Greater, GreaterEq, Less, LessEq, \
Log, LRN, Max, Maximum, MaxPool, Min, Minimum, Multiply, Negative, Not, NotEqual, OneHot, Or, \
Pad, Parameter, Product, Power, Reduce, Relu, ReplaceSlice, Reshape, Reverse, Select, Sign, \
Sin, Sinh, Slice, Softmax, Sqrt, Subtract, Sum, Tan, Tanh
from ngraph.impl.op import Abs, Acos, Add, And, Asin, ArgMax, ArgMin, Atan, AvgPool, BatchNorm, \
Broadcast, Ceiling, Concat, Constant, Convert, Convolution, ConvolutionBackpropData, Cos, \
Cosh, Divide, Dot, Equal, Exp, Floor, FunctionCall, GetOutputElement, Greater, GreaterEq, \
Less, LessEq, Log, LRN, Max, Maximum, MaxPool, Min, Minimum, Multiply, Negative, Not, \
NotEqual, OneHot, Or, Pad, Parameter, Product, Power, Reduce, Relu, ReplaceSlice, Reshape, \
Reverse, Select, Sign, Sin, Sinh, Slice, Softmax, Sqrt, Subtract, Sum, Tan, Tanh, TopK
from typing import Callable, Iterable, List, Union
......@@ -951,6 +951,56 @@ def lrn(data, # type: Node
return LRN(data, alpha, beta, bias, size)
@nameable_op
def argmax(data, # type: Node
axis=0, # type: int
):
# type: (...) -> Node
"""Return a node which performs ArgMax index reduction operation.
:param data: Input data.
:param axis: Reduction Axis.
:return: The new node which performs ArgMax
"""
return ArgMax(data, axis, get_element_type(np.int32))
@nameable_op
def argmin(data, # type: Node
axis=0, # type: int
):
# type: (...) -> Node
"""Return a node which performs ArgMin index reduction operation.
:param data: Input data.
:param axis: Reduction Axis.
:return: The new node which performs ArgMin
"""
return ArgMin(data, axis, get_element_type(np.int32))
@nameable_op
def topk(data, # type: Node
k, # type: int
kaxis=-1, # type: int
cmax=True, # type: bool
):
# type: (...) -> Node
"""Return a node which performs TopK.
:param data: Input data.
:param kaxis: TopK Axis.
:param k: K.
:param cmax: Compute TopK largest (True) or smallest (False)
:return: The new node which performs TopK (both indices and values)
"""
return TopK(data,
len(data.get_shape()) - 1 if kaxis == -1 else kaxis,
get_element_type(np.int32),
k,
cmax)
@nameable_op
def function_call(function_to_call, args): # type: (Node, NodeVector) -> Node
"""Return Function call op."""
......
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "ngraph/op/argmax.hpp" // ngraph::op::ArgMax
#include "pyngraph/ops/argmax.hpp"
namespace py = pybind11;
void regclass_pyngraph_op_ArgMax(py::module m)
{
py::class_<ngraph::op::ArgMax,
std::shared_ptr<ngraph::op::ArgMax>,
ngraph::op::util::IndexReduction>
add(m, "ArgMax");
add.doc() = "ngraph.impl.op.ArgMax wraps ngraph::op::ArgMax";
add.def(py::init<const std::shared_ptr<ngraph::Node>&, size_t, const ngraph::element::Type&>());
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
void regclass_pyngraph_op_ArgMax(py::module m);
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "ngraph/op/argmin.hpp" // ngraph::op::ArgMin
#include "pyngraph/ops/argmin.hpp"
namespace py = pybind11;
void regclass_pyngraph_op_ArgMin(py::module m)
{
py::class_<ngraph::op::ArgMin,
std::shared_ptr<ngraph::op::ArgMin>,
ngraph::op::util::IndexReduction>
add(m, "ArgMin");
add.doc() = "ngraph.impl.op.ArgMin wraps ngraph::op::ArgMin";
add.def(py::init<const std::shared_ptr<ngraph::Node>&, size_t, const ngraph::element::Type&>());
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
void regclass_pyngraph_op_ArgMin(py::module m);
......@@ -24,6 +24,8 @@ void regmodule_pyngraph_op(py::module m_op)
{
regclass_pyngraph_op_Abs(m_op);
regclass_pyngraph_op_Acos(m_op);
regclass_pyngraph_op_ArgMax(m_op);
regclass_pyngraph_op_ArgMin(m_op);
regclass_pyngraph_op_Asin(m_op);
regclass_pyngraph_op_Atan(m_op);
regclass_pyngraph_op_AvgPool(m_op);
......@@ -80,6 +82,7 @@ void regmodule_pyngraph_op(py::module m_op)
regclass_pyngraph_op_Sum(m_op);
regclass_pyngraph_op_Tan(m_op);
regclass_pyngraph_op_Tanh(m_op);
regclass_pyngraph_op_TopK(m_op);
regclass_pyngraph_op_Relu(m_op);
regclass_pyngraph_op_ReluBackprop(m_op);
regclass_pyngraph_op_Max(m_op);
......
......@@ -21,6 +21,8 @@
#include "pyngraph/ops/acos.hpp"
#include "pyngraph/ops/add.hpp"
#include "pyngraph/ops/and.hpp"
#include "pyngraph/ops/argmax.hpp"
#include "pyngraph/ops/argmin.hpp"
#include "pyngraph/ops/asin.hpp"
#include "pyngraph/ops/atan.hpp"
#include "pyngraph/ops/avg_pool.hpp"
......@@ -80,6 +82,7 @@
#include "pyngraph/ops/sum.hpp"
#include "pyngraph/ops/tan.hpp"
#include "pyngraph/ops/tanh.hpp"
#include "pyngraph/ops/topk.hpp"
namespace py = pybind11;
......
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "ngraph/op/topk.hpp" // ngraph::op::TopK
#include "pyngraph/ops/topk.hpp"
namespace py = pybind11;
void regclass_pyngraph_op_TopK(py::module m)
{
py::class_<ngraph::op::TopK, std::shared_ptr<ngraph::op::TopK>, ngraph::op::Op> add(m, "TopK");
add.doc() = "ngraph.impl.op.TopK wraps ngraph::op::TopK";
add.def(py::init<const std::shared_ptr<ngraph::Node>&,
size_t,
const ngraph::element::Type&,
size_t,
bool>());
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
void regclass_pyngraph_op_TopK(py::module m);
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "ngraph/op/op.hpp"
#include "ngraph/op/util/index_reduction.hpp"
#include "pyngraph/ops/util/index_reduction.hpp"
namespace py = pybind11;
void regclass_pyngraph_op_util_IndexReduction(py::module m)
{
py::class_<ngraph::op::util::IndexReduction,
std::shared_ptr<ngraph::op::util::IndexReduction>,
ngraph::op::Op>
indexReduction(m, "IndexRedection");
indexReduction.def_property_readonly("reduction_axis",
&ngraph::op::util::IndexReduction::get_reduction_axis);
indexReduction.def_property_readonly("index_element_type",
&ngraph::op::util::IndexReduction::get_index_element_type);
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
void regclass_pyngraph_op_util_IndexReduction(py::module m);
......@@ -31,4 +31,5 @@ void regmodule_pyngraph_op_util(py::module m)
regclass_pyngraph_op_util_BinaryElementwiseLogical(m_util);
// regclass_pyngraph_op_util_UnaryElementwise(m_util);
regclass_pyngraph_op_util_UnaryElementwiseArithmetic(m_util);
regclass_pyngraph_op_util_IndexReduction(m_util);
}
......@@ -21,6 +21,7 @@
#include "pyngraph/ops/util/binary_elementwise_arithmetic.hpp"
#include "pyngraph/ops/util/binary_elementwise_comparison.hpp"
#include "pyngraph/ops/util/binary_elementwise_logical.hpp"
#include "pyngraph/ops/util/index_reduction.hpp"
#include "pyngraph/ops/util/op_annotations.hpp"
#include "pyngraph/ops/util/unary_elementwise_arithmetic.hpp"
......
......@@ -21,7 +21,7 @@ import setuptools
import os
import distutils.ccompiler
__version__ = '0.2.0'
__version__ = '0.7.0'
PYNGRAPH_SOURCE_DIR = os.path.abspath(os.path.dirname(__file__))
NGRAPH_DEFAULT_INSTALL_DIR = os.environ.get('HOME')
......@@ -142,10 +142,13 @@ sources = ['pyngraph/function.cpp',
'pyngraph/ops/util/binary_elementwise_logical.cpp',
'pyngraph/ops/util/regmodule_pyngraph_op_util.cpp',
'pyngraph/ops/util/unary_elementwise_arithmetic.cpp',
'pyngraph/ops/util/index_reduction.cpp',
'pyngraph/ops/abs.cpp',
'pyngraph/ops/acos.cpp',
'pyngraph/ops/add.cpp',
'pyngraph/ops/and.cpp',
'pyngraph/ops/argmax.cpp',
'pyngraph/ops/argmin.cpp',
'pyngraph/ops/asin.cpp',
'pyngraph/ops/atan.cpp',
'pyngraph/ops/avg_pool.cpp',
......@@ -200,6 +203,7 @@ sources = ['pyngraph/function.cpp',
'pyngraph/ops/sum.cpp',
'pyngraph/ops/tan.cpp',
'pyngraph/ops/tanh.cpp',
'pyngraph/ops/topk.cpp',
'pyngraph/ops/allreduce.cpp',
'pyngraph/ops/function_call.cpp',
'pyngraph/ops/get_output_element.cpp',
......@@ -294,8 +298,8 @@ setup(
package_dir={'ngraph': PYNGRAPH_SOURCE_DIR + "/ngraph",
'ngraph.utils': PYNGRAPH_SOURCE_DIR + "/ngraph/utils",
'ngraph.impl': PYNGRAPH_SOURCE_DIR + "/ngraph/impl",
'ngraph.impl.op': PYNGRAPH_SOURCE_DIR + "/ngraph/impl/op",
'ngraph.impl.onnx_import': PYNGRAPH_SOURCE_DIR + "/ngraph/impl/onnx_import",
'ngraph.impl.op': PYNGRAPH_SOURCE_DIR + "/ngraph/impl/op",
'ngraph.impl.op.util': PYNGRAPH_SOURCE_DIR + "/ngraph/impl/op/util",
'ngraph.impl.passes': PYNGRAPH_SOURCE_DIR + "/ngraph/impl/passes",
'ngraph.impl.runtime': PYNGRAPH_SOURCE_DIR + "/ngraph/impl/runtime"},
......
......@@ -17,7 +17,7 @@ import numpy as np
import pytest
import ngraph as ng
from test.ngraph.util import run_op_node
from test.ngraph.util import run_op_node, get_runtime
@pytest.mark.parametrize('ng_api_helper, numpy_function, reduction_axes', [
......@@ -45,6 +45,56 @@ def test_reduction_ops(ng_api_helper, numpy_function, reduction_axes):
assert np.allclose(result, expected)
@pytest.config.gpu_skip(reason='Not implemented')
def test_argmax():
runtime = get_runtime()
input_x = ng.constant(np.array([[9, 2, 10],
[12, 8, 4],
[6, 1, 5],
[3, 11, 7]], dtype=np.float32))
model = runtime.computation(ng.argmax(input_x, 0))
result = model()
assert np.allclose(result,
np.array([1, 3, 0], dtype=np.int32))
@pytest.config.gpu_skip(reason='Not implemented')
def test_argmin():
runtime = get_runtime()
input_x = ng.constant(np.array([[12, 2, 10],
[9, 8, 4],
[6, 1, 5],
[3, 11, 7]], dtype=np.float32))
model = runtime.computation(ng.argmin(input_x, 0))
result = model()
assert np.allclose(result,
np.array([3, 2, 1], dtype=np.int32))
@pytest.config.gpu_skip(reason='Not implemented')
def test_topk():
runtime = get_runtime()
input_x = ng.constant(np.array([[9, 2, 10],
[12, 8, 4],
[6, 1, 5],
[3, 11, 7]], dtype=np.float32))
comp_topk = ng.topk(input_x, 4, 0)
model0 = runtime.computation(ng.get_output_element(comp_topk, 0))
result0 = model0()
assert np.allclose(result0,
np.array([[1, 3, 0],
[0, 1, 3],
[2, 0, 2],
[3, 2, 1]], dtype=np.int32))
model1 = runtime.computation(ng.get_output_element(comp_topk, 1))
result1 = model1()
assert np.allclose(result1,
np.array([[12, 11, 10],
[9, 8, 7],
[6, 2, 5],
[3, 1, 4]], dtype=np.float32))
def test_reduce():
from functools import reduce
np.random.seed(133391)
......
......@@ -104,6 +104,7 @@ set (SRC
op/sum.cpp
op/tan.cpp
op/tanh.cpp
op/topk.cpp
op/util/arithmetic_reduction.cpp
op/util/binary_elementwise_arithmetic.cpp
op/util/binary_elementwise_comparison.cpp
......
......@@ -61,6 +61,8 @@
#include "ngraph/op/add.hpp"
#include "ngraph/op/allreduce.hpp"
#include "ngraph/op/and.hpp"
#include "ngraph/op/argmax.hpp"
#include "ngraph/op/argmin.hpp"
#include "ngraph/op/asin.hpp"
#include "ngraph/op/atan.hpp"
#include "ngraph/op/avg_pool.hpp"
......@@ -85,6 +87,7 @@
#include "ngraph/op/less.hpp"
#include "ngraph/op/less_eq.hpp"
#include "ngraph/op/log.hpp"
#include "ngraph/op/lrn.hpp"
#include "ngraph/op/max.hpp"
#include "ngraph/op/max_pool.hpp"
#include "ngraph/op/maximum.hpp"
......@@ -123,6 +126,7 @@
#include "ngraph/op/sum.hpp"
#include "ngraph/op/tan.hpp"
#include "ngraph/op/tanh.hpp"
#include "ngraph/op/topk.hpp"
#include "ngraph/runtime/backend.hpp"
#include "ngraph/runtime/tensor_view.hpp"
#include "ngraph/shape.hpp"
......
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <memory>
#include "ngraph/axis_vector.hpp"
#include "ngraph/op/topk.hpp"
#include "ngraph/shape.hpp"
using namespace std;
using namespace ngraph;
op::TopK::TopK(const shared_ptr<Node>& arg,
size_t top_k_axis,
const element::Type& index_element_type,
size_t k,
bool compute_max)
: Op("TopK", check_single_output_args({arg}))
, m_top_k_axis(top_k_axis)
, m_index_element_type(index_element_type)
, m_k(k)
, m_compute_max(compute_max)
{
constructor_validate_and_infer_types();
}
void op::TopK::validate_and_infer_types()
{
auto& input = get_inputs().at(0);
auto rank = input.get_shape().size();
NODE_VALIDATION_ASSERT(this, rank > 0) << "Input Tensor's rank must be greater than 0";
NODE_VALIDATION_ASSERT(this, m_top_k_axis < rank) << "TopK axis must be less than rank";
NODE_VALIDATION_ASSERT(
this, m_index_element_type == element::i32 || m_index_element_type == element::i64)
<< "Index element type must be i64 or i32";
NODE_VALIDATION_ASSERT(this, m_k <= input.get_shape()[m_top_k_axis])
<< "K should not exceed TopK axis length";
Shape input_shape = input.get_shape();
Shape output_shape(input_shape);
if (m_k != 0)
{
output_shape[m_top_k_axis] = m_k;
}
else
{
m_k = input_shape[m_top_k_axis];
}
set_output_size(2);
set_output_type(0, m_index_element_type, output_shape);
set_output_type(1, input.get_element_type(), output_shape);
}
shared_ptr<Node> op::TopK::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<TopK>(
new_args.at(0), m_top_k_axis, m_index_element_type, m_k, m_compute_max);
}
void op::TopK::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
{
throw ngraph_error("Forward-propagation-only operation");
}
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include <memory>
#include "ngraph/axis_set.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/op/op.hpp"
namespace ngraph
{
namespace op
{
//brief Computes indices of top k maximum/minimum index along a specified axis for a given tensor
class TopK : public Op
{
public:
/// \brief Constructs a TopK operation.
///
/// \param arg The input tensor
/// \param top_k_axis The axis along which to compute top k indices
/// \param index_element_type produce indices. Currently, only int64 or int32 are supported
/// \param k Number of top indices to compute. Compute all indices if k = 0
/// \param compute_max Compute top k max or top k min?
TopK(const std::shared_ptr<Node>& arg,
size_t top_k_axis,
const element::Type& index_element_type,
size_t k = 0,
bool compute_max = true);
void validate_and_infer_types() override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
size_t get_top_k_axis() const { return m_top_k_axis; }
element::Type get_index_element_type() const { return m_index_element_type; }
size_t get_k() const { return m_k; }
bool get_compute_max() const { return m_compute_max; }
protected:
size_t m_top_k_axis;
element::Type m_index_element_type;
size_t m_k;
bool m_compute_max;
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
};
}
}
......@@ -61,6 +61,7 @@ set(SRC
builder/slice.cpp
builder/softmax.cpp
builder/sum.cpp
builder/topk.cpp
kernel/eigen_thread_pool.cpp
kernel/pad.cpp
kernel/reduce_max.cpp
......
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <cstring>
#include "ngraph/op/topk.hpp"
#include "ngraph/runtime/cpu/cpu_builder.hpp"
#include "ngraph/runtime/reference/topk.hpp"
using namespace std;
using namespace ngraph;
namespace ngraph
{
namespace runtime
{
namespace cpu
{
template <>
void Builder::BUILDER_DECL(ngraph::op::TopK)
{
auto& functors = external_function->get_functors();
auto& tensor_data = external_function->get_tensor_data();
const ngraph::op::TopK* topk = static_cast<const ngraph::op::TopK*>(node);
function<void(CPURuntimeContext*)> functor;
auto& arg_tensor = tensor_data[args[0].get_name()];
auto& out_indices_tensor = tensor_data[out[0].get_name()];
auto& out_values_tensor = tensor_data[out[1].get_name()];
if (out[0].get_element_type() != element::i64 &&
out[0].get_element_type() != element::i32)
{
throw ngraph_error("Unsupported index element type");
}
bool is_int64 = out[0].get_element_type() == element::i64;
auto axis = topk->get_top_k_axis();
auto in_shape = args[0].get_shape();
auto out_shape = out[0].get_shape();
auto k = topk->get_k();
auto compute_max = topk->get_compute_max();
auto element_type = args[0].get_element_type();
if (element_type == element::f32)
{
if (is_int64)
{
functor =
[&, in_shape, out_shape, axis, k, compute_max](CPURuntimeContext* ctx) {
ngraph::runtime::reference::topk<float, int64_t>(
static_cast<float*>(arg_tensor),
static_cast<int64_t*>(out_indices_tensor),
static_cast<float*>(out_values_tensor),
in_shape,
out_shape,
axis,
k,
compute_max);
};
}
else
{
functor =
[&, in_shape, out_shape, axis, k, compute_max](CPURuntimeContext* ctx) {
ngraph::runtime::reference::topk<float, int32_t>(
static_cast<float*>(arg_tensor),
static_cast<int32_t*>(out_indices_tensor),
static_cast<float*>(out_values_tensor),
in_shape,
out_shape,
axis,
k,
compute_max);
};
}
}
else if (element_type == element::f64)
{
if (is_int64)
{
functor =
[&, in_shape, out_shape, axis, k, compute_max](CPURuntimeContext* ctx) {
ngraph::runtime::reference::topk<double, int64_t>(
static_cast<double*>(arg_tensor),
static_cast<int64_t*>(out_indices_tensor),
static_cast<double*>(out_values_tensor),
in_shape,
out_shape,
axis,
k,
compute_max);
};
}
else
{
functor =
[&, in_shape, out_shape, axis, k, compute_max](CPURuntimeContext* ctx) {
ngraph::runtime::reference::topk<double, int32_t>(
static_cast<double*>(arg_tensor),
static_cast<int32_t*>(out_indices_tensor),
static_cast<double*>(out_values_tensor),
in_shape,
out_shape,
axis,
k,
compute_max);
};
}
}
else
{
throw ngraph_error("Unsupported type in CPU Builder for TopK");
}
functors.emplace_back(functor);
}
REGISTER_OP_BUILDER(TopK);
}
}
}
......@@ -92,6 +92,7 @@
#include "ngraph/op/sum.hpp"
#include "ngraph/op/tan.hpp"
#include "ngraph/op/tanh.hpp"
#include "ngraph/op/topk.hpp"
#include "ngraph/runtime/cpu/cpu_kernel_emitters.hpp"
#include "ngraph/runtime/cpu/cpu_op_annotations.hpp"
#include "ngraph/runtime/cpu/mkldnn_utils.hpp"
......@@ -2331,6 +2332,30 @@ namespace ngraph
emitArgMinArgMax(args, out, argmax->get_reduction_axis(), "argmax", writer);
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::TopK)
{
auto topk = static_cast<const ngraph::op::TopK*>(node);
if (out[0].get_element_type() != element::i64 &&
out[0].get_element_type() != element::i32)
{
throw ngraph_error("Unsupported index element type");
}
writer.block_begin();
writer << "reference::topk<" << args[0].get_type() << ", "
<< out[0].get_element_type().c_type_string() << ">(" << args[0].get_name()
<< ",\n";
writer << " " << out[0].get_name() << ",\n";
writer << " " << out[1].get_name() << ",\n";
writer << " {" << join(args[0].get_shape()) << "},\n";
writer << " {" << join(out[0].get_shape()) << "},\n";
writer << " " << topk->get_top_k_axis() << ",\n";
writer << " " << topk->get_k() << ",\n";
writer << " " << topk->get_compute_max() << ");\n";
writer.block_end();
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Power)
{
......
......@@ -115,6 +115,7 @@
#include "ngraph/op/sum.hpp"
#include "ngraph/op/tan.hpp"
#include "ngraph/op/tanh.hpp"
#include "ngraph/op/topk.hpp"
#include "ngraph/pass/algebraic_simplification.hpp"
#include "ngraph/pass/common_function_collection.hpp"
#include "ngraph/pass/core_fusion.hpp"
......@@ -279,6 +280,7 @@ static const runtime::cpu::OpMap dispatcher{
{TI(ngraph::op::Cosh), &runtime::cpu::CPU_Emitter::emit<op::Cosh>},
{TI(ngraph::op::Tan), &runtime::cpu::CPU_Emitter::emit<op::Tan>},
{TI(ngraph::op::Tanh), &runtime::cpu::CPU_Emitter::emit<op::Tanh>},
{TI(ngraph::op::TopK), &runtime::cpu::CPU_Emitter::emit<op::TopK>},
{TI(ngraph::op::Asin), &runtime::cpu::CPU_Emitter::emit<op::Asin>},
{TI(ngraph::op::ArgMin), &runtime::cpu::CPU_Emitter::emit<op::ArgMin>},
{TI(ngraph::op::ArgMax), &runtime::cpu::CPU_Emitter::emit<op::ArgMax>},
......@@ -455,6 +457,7 @@ void runtime::cpu::CPU_ExternalFunction::compile()
#include "ngraph/runtime/reference/select_and_scatter.hpp"
#include "ngraph/runtime/reference/slice.hpp"
#include "ngraph/runtime/reference/sum.hpp"
#include "ngraph/runtime/reference/topk.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/strides.hpp"
#include "ngraph/util.hpp"
......
......@@ -34,3 +34,21 @@ max_pool_3d
avg_pool_3d
argmin_trivial
argmax_trivial
topk_1d_max_all
topk_1d_max_partial
topk_1d_max_one
topk_1d_min_all
topk_1d_min_partial
topk_1d_min_one
topk_2d_max_all
topk_2d_max_partial
topk_2d_max_one
topk_2d_min_all
topk_2d_min_partial
topk_2d_min_one
topk_3d_max_all
topk_3d_max_partial
topk_3d_max_one
topk_3d_min_all
topk_3d_min_partial
topk_3d_min_one
......@@ -115,3 +115,21 @@ zero_sized_tan
zero_sized_tanh
argmin_trivial
argmax_trivial
topk_1d_max_all
topk_1d_max_partial
topk_1d_max_one
topk_1d_min_all
topk_1d_min_partial
topk_1d_min_one
topk_2d_max_all
topk_2d_max_partial
topk_2d_max_one
topk_2d_min_all
topk_2d_min_partial
topk_2d_min_one
topk_3d_max_all
topk_3d_max_partial
topk_3d_max_one
topk_3d_min_all
topk_3d_min_partial
topk_3d_min_one
......@@ -54,6 +54,7 @@
#include "ngraph/op/sum.hpp"
#include "ngraph/op/select_and_scatter.hpp"
#include "ngraph/op/topk.hpp"
#include "ngraph/runtime/reference/abs.hpp"
#include "ngraph/runtime/reference/acos.hpp"
#include "ngraph/runtime/reference/add.hpp"
......@@ -119,6 +120,7 @@
#include "ngraph/runtime/reference/sum.hpp"
#include "ngraph/runtime/reference/tan.hpp"
#include "ngraph/runtime/reference/tanh.hpp"
#include "ngraph/runtime/reference/topk.hpp"
#ifdef NGRAPH_DISTRIBUTED
#include "ngraph/runtime/reference/allreduce.hpp"
......@@ -1025,6 +1027,36 @@ private:
reference::tanh<T>(
args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
}
else if (node_op == "TopK")
{
const op::TopK* topk = static_cast<const op::TopK*>(&node);
if (out[0]->get_element_type() == element::i64)
{
reference::topk<T, int64_t>(args[0]->get_data_ptr<T>(),
out[0]->get_data_ptr<int64_t>(),
out[1]->get_data_ptr<T>(),
args[0]->get_shape(),
out[0]->get_shape(),
topk->get_top_k_axis(),
topk->get_k(),
topk->get_compute_max());
}
else if (out[0]->get_element_type() == element::i32)
{
reference::topk<T, int32_t>(args[0]->get_data_ptr<T>(),
out[0]->get_data_ptr<int32_t>(),
out[1]->get_data_ptr<T>(),
args[0]->get_shape(),
out[0]->get_shape(),
topk->get_top_k_axis(),
topk->get_k(),
topk->get_compute_max());
}
else
{
throw ngraph_error("Unexpected type");
}
}
else
{
std::stringstream ss;
......
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include <algorithm>
#include <cmath>
#include <numeric>
#include "ngraph/coordinate_transform.hpp"
namespace ngraph
{
namespace runtime
{
namespace reference
{
template <typename T, typename U>
void topk(const T* arg,
U* out_indices,
T* out_values,
const Shape& in_shape,
const Shape& out_shape,
size_t axis,
size_t k,
bool compute_max)
{
using namespace std;
// reorder source axis visit order and make "axis" inner most
size_t ndim = static_cast<size_t>(in_shape.size());
Coordinate start_corner(ndim, 0);
Coordinate end_corner(in_shape);
end_corner[axis] = 1;
Strides strides(ndim, 1);
AxisVector axis_order(ndim);
iota(axis_order.begin(), axis_order.end(), 0);
axis_order.erase(axis_order.begin() + axis);
axis_order.push_back(axis);
// Create CoordinateTransforms that visits only the first element along "axis"
CoordinateTransform input_transform(
in_shape, start_corner, end_corner, strides, axis_order);
CoordinateTransform output_transform(
out_shape, start_corner, end_corner, strides, axis_order);
// Create temp vector for sorting.
vector<tuple<T, U>> workspace(in_shape[axis]);
vector<size_t> in_strides = ngraph::row_major_strides(in_shape);
vector<size_t> out_strides = ngraph::row_major_strides(out_shape);
auto in_axis_stride = in_strides[axis];
auto out_axis_stride = out_strides[axis];
for (const Coordinate& coord : input_transform)
{
auto arg_index = input_transform.index(coord);
auto out_index = output_transform.index(coord);
// Fill the temp vector
U i = 0;
for (tuple<T, U>& entry : workspace)
{
get<0>(entry) = arg[arg_index];
get<1>(entry) = i;
arg_index += in_axis_stride;
i++;
}
// Sort the temp vector
sort(
workspace.begin(),
workspace.end(),
compute_max
? [](const tuple<T, U>& a, const tuple<T, U>& b) -> bool { return a > b; }
: [](const tuple<T, U>& a, const tuple<T, U>& b) -> bool { return a < b; });
// Write temp vector to output
for (size_t j = 0; j < k; j++)
{
tuple<T, U> entry = workspace[j];
out_values[out_index] = get<0>(entry);
out_indices[out_index] = get<1>(entry);
out_index += out_axis_stride;
}
}
}
}
}
}
......@@ -90,6 +90,7 @@
#include "ngraph/op/sum.hpp"
#include "ngraph/op/tan.hpp"
#include "ngraph/op/tanh.hpp"
#include "ngraph/op/topk.hpp"
#include "ngraph/serializer.hpp"
#include "ngraph/util.hpp"
#include "nlohmann/json.hpp"
......@@ -915,6 +916,14 @@ static shared_ptr<ngraph::Function>
{
node = make_shared<op::Tanh>(args[0]);
}
else if (node_op == "TopK")
{
auto top_k_axis = node_js.at("top_k_axis").get<size_t>();
auto k = node_js.at("k").get<size_t>();
auto compute_max = node_js.at("compute_max").get<bool>();
auto target_type = read_element_type(node_js.at("index_element_type"));
node = make_shared<op::TopK>(args[0], top_k_axis, target_type, k, compute_max);
}
else if (node_op == "StopGradient")
{
node = make_shared<op::StopGradient>(args[0]);
......@@ -1365,6 +1374,14 @@ static json write(const Node& n, bool binary_constant_data)
else if (node_op == "Tanh")
{
}
else if (node_op == "TopK")
{
auto tmp = dynamic_cast<const op::TopK*>(&n);
node["top_k_axis"] = tmp->get_top_k_axis();
node["index_element_type"] = write_element_type(tmp->get_index_element_type());
node["k"] = tmp->get_k();
node["compute_max"] = tmp->get_compute_max();
}
return node;
}
This diff is collapsed.
......@@ -17,9 +17,6 @@
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/op/argmin.hpp"
#include "ngraph/op/batch_norm.hpp"
#include "ngraph/op/reverse_sequence.hpp"
#include <memory>
using namespace std;
......@@ -6493,3 +6490,79 @@ TEST(type_prop, index_reduction_invalid_index_type)
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, topk_invalid_rank)
{
auto a = make_shared<op::Parameter>(element::f32, Shape{});
try
{
auto topk = make_shared<op::TopK>(a, 0, element::i32, 1, true);
FAIL() << "TopK c-tor should throw for scalar shapes";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(error.what(), "Input Tensor's rank must be greater than 0");
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, topk_invalid_top_k)
{
auto a = make_shared<op::Parameter>(element::f32, Shape{2, 2});
try
{
auto topk = make_shared<op::TopK>(a, 2, element::i32, 1, true);
FAIL() << "TopK c-tor should throw for invalid top k axis";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(error.what(), "TopK axis must be less than rank");
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, topk_invalid_index_type)
{
auto a = make_shared<op::Parameter>(element::f32, Shape{2, 2});
try
{
auto topk = make_shared<op::TopK>(a, 0, element::f32, 1, true);
FAIL() << "TopK c-tor should throw for invalid index element type";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(error.what(), "Index element type must be i64 or i32");
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, topk_invalid_k)
{
auto a = make_shared<op::Parameter>(element::f32, Shape{2, 2});
try
{
auto topk = make_shared<op::TopK>(a, 0, element::i32, 3, true);
FAIL() << "TopK c-tor should throw for invalid K";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(error.what(), "K should not exceed TopK axis length");
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment