Commit 822aa81d authored by Nick Korovaiko, committed by Scott Cyphers

ArgMax (#1453)

* argmax

* manifests and serializer
parent 30649733
......@@ -41,6 +41,7 @@ set (SRC
op/allreduce.cpp
op/and.cpp
op/argmin.cpp
op/argmax.cpp
op/asin.cpp
op/atan.cpp
op/avg_pool.cpp
......
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "ngraph/op/argmax.hpp"
using namespace std;
using namespace ngraph;
shared_ptr<Node> op::ArgMax::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 1)
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<ArgMax>(new_args.at(0), m_axis, this->get_element_type());
}
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include "ngraph/axis_set.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/op/util/index_reduction.hpp"
#include "ngraph/op/util/requires_tensor_view_args.hpp"
namespace ngraph
{
namespace op
{
/// \brief Computes the index of the maximum value along a specified axis for a given tensor
class ArgMax : public op::util::IndexReduction
{
public:
/// \brief Constructs an ArgMax operation.
///
/// \param arg The input tensor
/// \param axis The axis along which to compute an index for maximum
/// \param index_element_type The element type of the produced indices; currently only int64 and int32 are supported
ArgMax(const std::shared_ptr<Node>& arg,
size_t axis,
const element::Type& index_element_type)
: IndexReduction("ArgMax", arg, axis, index_element_type)
{
}
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
};
}
}
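A minimal usage sketch of the new op (shapes and variable names are illustrative, not part of this commit; it assumes the usual ngraph headers are available). It builds an ArgMax node that reduces a 4x3 float tensor along axis 0, producing int32 indices of shape {3} — the same configuration the argmax_trivial test below exercises.
// Sketch only: reduce a 4x3 f32 tensor along axis 0, yielding i32 indices.
auto A = std::make_shared<ngraph::op::Parameter>(ngraph::element::f32,
                                                 ngraph::Shape{4, 3});
auto argmax = std::make_shared<ngraph::op::ArgMax>(A, 0, ngraph::element::i32);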
......@@ -27,6 +27,7 @@ set(SRC
builder/add.cpp
builder/avg_pool.cpp
builder/argmin.cpp
builder/argmax.cpp
builder/batch_norm.cpp
builder/broadcast.cpp
builder/bounded_relu.cpp
......
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <cstring>
#include "ngraph/op/argmax.hpp"
#include "ngraph/runtime/cpu/cpu_builder.hpp"
#include "ngraph/runtime/reference/argmax.hpp"
using namespace std;
using namespace ngraph;
namespace ngraph
{
namespace runtime
{
namespace cpu
{
template <>
void Builder::BUILDER_DECL(ngraph::op::ArgMax)
{
auto& functors = external_function->get_functors();
auto& tensor_data = external_function->get_tensor_data();
const ngraph::op::ArgMax* argmax = static_cast<const ngraph::op::ArgMax*>(node);
function<void(CPURuntimeContext*)> functor;
auto& arg_tensor = tensor_data[args[0].get_name()];
auto& out_tensor = tensor_data[out[0].get_name()];
if (out[0].get_element_type() != element::i64 &&
out[0].get_element_type() != element::i32)
{
throw ngraph_error("Unsupported index element type");
}
bool is_int64 = out[0].get_element_type() == element::i64;
auto axis = argmax->get_reduction_axis();
auto in_shape = args[0].get_shape();
auto out_shape = out[0].get_shape();
auto element_type = args[0].get_element_type();
if (element_type == element::f32)
{
if (is_int64)
{
functor = [&, in_shape, out_shape, axis](CPURuntimeContext* ctx) {
ngraph::runtime::reference::argmax<float, int64_t>(
static_cast<float*>(arg_tensor),
static_cast<int64_t*>(out_tensor),
in_shape,
out_shape,
axis);
};
}
else
{
functor = [&, in_shape, out_shape, axis](CPURuntimeContext* ctx) {
ngraph::runtime::reference::argmax<float, int32_t>(
static_cast<float*>(arg_tensor),
static_cast<int32_t*>(out_tensor),
in_shape,
out_shape,
axis);
};
}
}
else if (element_type == element::f64)
{
if (is_int64)
{
functor = [&, in_shape, out_shape, axis](CPURuntimeContext* ctx) {
ngraph::runtime::reference::argmax<double, int64_t>(
static_cast<double*>(arg_tensor),
static_cast<int64_t*>(out_tensor),
in_shape,
out_shape,
axis);
};
}
else
{
functor = [&, in_shape, out_shape, axis](CPURuntimeContext* ctx) {
ngraph::runtime::reference::argmax<double, int32_t>(
static_cast<double*>(arg_tensor),
static_cast<int32_t*>(out_tensor),
in_shape,
out_shape,
axis);
};
}
}
else
{
throw ngraph_error("Unsupported type in CPU Builder for ArgMax");
}
functors.emplace_back(functor);
}
REGISTER_OP_BUILDER(ArgMax);
}
}
}
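For orientation, each builder appends a functor that captures the tensor-data slots by reference (so they resolve to whatever buffers are bound at call time) and the shapes and axis by value. A rough sketch of how such a functor list is consumed at run time — illustrative only, with a stand-in context type; the real driver lives in CPU_ExternalFunction:
#include <functional>
#include <vector>

struct Ctx; // stand-in for CPURuntimeContext; assumption, not the real type

// Invoke each registered functor in order with the runtime context.
void run_functors(std::vector<std::function<void(Ctx*)>>& functors, Ctx* ctx)
{
    for (auto& f : functors)
    {
        f(ctx); // captured tensor references resolve to the bound buffers
    }
}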
......@@ -28,6 +28,7 @@
#include "ngraph/op/add.hpp"
#include "ngraph/op/allreduce.hpp"
#include "ngraph/op/and.hpp"
#include "ngraph/op/argmax.hpp"
#include "ngraph/op/argmin.hpp"
#include "ngraph/op/asin.hpp"
#include "ngraph/op/atan.hpp"
......@@ -2293,26 +2294,43 @@ namespace ngraph
writer.block_end();
}
static void emitArgMinArgMax(const std::vector<TensorViewWrapper>& args,
const std::vector<TensorViewWrapper>& out,
size_t reduction_axis,
const char* kernel_name,
codegen::CodeWriter& writer)
{
if (out[0].get_element_type() != element::i64 &&
out[0].get_element_type() != element::i32)
{
throw ngraph_error("Unsupported index element type");
}
writer.block_begin();
writer << "reference::" << kernel_name << "<" << args[0].get_type() << ", "
<< out[0].get_element_type().c_type_string() << ">(" << args[0].get_name()
<< ",\n";
writer << " " << out[0].get_name() << ",\n";
writer << " {" << join(args[0].get_shape()) << "},\n";
writer << " {" << join(out[0].get_shape()) << "},\n";
writer << " " << reduction_axis << ");\n";
writer.block_end();
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::ArgMin)
{
auto argmin = static_cast<const ngraph::op::ArgMin*>(node);
emitArgMinArgMax(args, out, argmin->get_reduction_axis(), "argmin", writer);
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::ArgMax)
{
auto argmax = static_cast<const ngraph::op::ArgMax*>(node);
emitArgMinArgMax(args, out, argmax->get_reduction_axis(), "argmax", writer);
}
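For reference, the call the emitter writes into the generated source would look roughly like this for a 4x3 f32 input reduced along axis 0 with i32 indices (tensor names and shapes are illustrative assumptions):
// Approximate shape of the emitted code, not verbatim output.
reference::argmax<float, int32_t>(arg0,
                                  out0,
                                  {4, 3},
                                  {3},
                                  0);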
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Power)
{
......
......@@ -52,6 +52,7 @@
#include "ngraph/op/add.hpp"
#include "ngraph/op/allreduce.hpp"
#include "ngraph/op/and.hpp"
#include "ngraph/op/argmax.hpp"
#include "ngraph/op/argmin.hpp"
#include "ngraph/op/asin.hpp"
#include "ngraph/op/atan.hpp"
......@@ -281,6 +282,7 @@ static const runtime::cpu::OpMap dispatcher{
{TI(ngraph::op::Tanh), &runtime::cpu::CPU_Emitter::emit<op::Tanh>},
{TI(ngraph::op::Asin), &runtime::cpu::CPU_Emitter::emit<op::Asin>},
{TI(ngraph::op::ArgMin), &runtime::cpu::CPU_Emitter::emit<op::ArgMin>},
{TI(ngraph::op::ArgMax), &runtime::cpu::CPU_Emitter::emit<op::ArgMax>},
{TI(ngraph::op::Acos), &runtime::cpu::CPU_Emitter::emit<op::Acos>},
{TI(ngraph::op::Atan), &runtime::cpu::CPU_Emitter::emit<op::Atan>},
{TI(ngraph::op::ReplaceSlice), &runtime::cpu::CPU_Emitter::emit<op::ReplaceSlice>},
......@@ -426,6 +428,7 @@ void runtime::cpu::CPU_ExternalFunction::compile()
#include "ngraph/runtime/cpu/cpu_runtime_context.hpp"
#include "ngraph/runtime/cpu/mkldnn_invoke.hpp"
#include "ngraph/runtime/reference/and.hpp"
#include "ngraph/runtime/reference/argmax.hpp"
#include "ngraph/runtime/reference/argmin.hpp"
#include "ngraph/runtime/reference/avg_pool.hpp"
#include "ngraph/runtime/reference/batch_norm.hpp"
......
......@@ -33,3 +33,4 @@ backwards_avgpool_n2_c2_hw4x4
max_pool_3d
avg_pool_3d
argmin_trivial
argmax_trivial
......@@ -118,3 +118,4 @@ zero_sized_subtract
zero_sized_tan
zero_sized_tanh
argmin_trivial
argmax_trivial
......@@ -25,6 +25,7 @@
#include "ngraph/runtime/host_tensor_view.hpp"
#include "ngraph/runtime/tensor_view.hpp"
#include "ngraph/op/argmax.hpp"
#include "ngraph/op/argmin.hpp"
#include "ngraph/op/avg_pool.hpp"
#include "ngraph/op/batch_norm.hpp"
......@@ -57,6 +58,7 @@
#include "ngraph/runtime/reference/acos.hpp"
#include "ngraph/runtime/reference/add.hpp"
#include "ngraph/runtime/reference/and.hpp"
#include "ngraph/runtime/reference/argmax.hpp"
#include "ngraph/runtime/reference/argmin.hpp"
#include "ngraph/runtime/reference/asin.hpp"
#include "ngraph/runtime/reference/atan.hpp"
......@@ -235,6 +237,30 @@ private:
throw ngraph_error("Unexpected type");
}
}
else if (node_op == "ArgMax")
{
const op::ArgMax* argmax = static_cast<const op::ArgMax*>(&node);
if (out[0]->get_element_type() == element::i64)
{
reference::argmax<T, int64_t>(args[0]->get_data_ptr<T>(),
out[0]->get_data_ptr<int64_t>(),
args[0]->get_shape(),
out[0]->get_shape(),
argmax->get_reduction_axis());
}
else if (out[0]->get_element_type() == element::i32)
{
reference::argmax<T, int32_t>(args[0]->get_data_ptr<T>(),
out[0]->get_data_ptr<int32_t>(),
args[0]->get_shape(),
out[0]->get_shape(),
argmax->get_reduction_axis());
}
else
{
throw ngraph_error("Unexpected type");
}
}
else if (node_op == "Asin")
{
reference::asin<T>(
......
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include <cmath>
#include <cstring> // memset
#include "ngraph/coordinate_transform.hpp"
namespace ngraph
{
namespace runtime
{
namespace reference
{
template <typename T, typename U>
void argmax(
const T* arg, U* out, const Shape& in_shape, const Shape& out_shape, size_t axis)
{
// Start with index 0 along the reduction axis as the provisional maximum
// for every output element.
memset(out, 0, shape_size(out_shape) * sizeof(U));
AxisVector av{axis};
CoordinateTransform input_transform(in_shape);
for (const Coordinate& input_coord : input_transform)
{
Coordinate output_coord = project(input_coord, av); // drop the reduction axis
CoordinateTransform output_transform(out_shape);
auto max_index = static_cast<size_t>(out[output_transform.index(output_coord)]);
auto max_coord = input_coord;
max_coord[axis] = max_index;
if (arg[input_transform.index(input_coord)] >
arg[input_transform.index(max_coord)])
{
out[output_transform.index(output_coord)] =
static_cast<U>(input_coord[axis]);
}
}
}
}
}
}
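The kernel's running-maximum scan can be illustrated standalone. The sketch below is a hypothetical, ngraph-free rendering of the same update rule for a row-major 2D input reduced along axis 0; it uses the same data as the argmax_trivial test further down.
#include <cstddef>
#include <cstdio>
#include <vector>

// Standalone sketch of the same algorithm for a row-major rows x cols input,
// reducing along axis 0. Not the ngraph implementation.
std::vector<int> argmax_axis0(const std::vector<float>& arg, size_t rows, size_t cols)
{
    std::vector<int> out(cols, 0); // row 0 starts as the assumed maximum
    for (size_t r = 0; r < rows; r++)
    {
        for (size_t c = 0; c < cols; c++)
        {
            // Compare against the element at the currently recorded max row.
            if (arg[r * cols + c] > arg[static_cast<size_t>(out[c]) * cols + c])
            {
                out[c] = static_cast<int>(r);
            }
        }
    }
    return out;
}

int main()
{
    // Same data as argmax_trivial: per-column maxima sit at rows 1, 3, 0.
    auto idx = argmax_axis0({9, 2, 10, 12, 8, 4, 6, 1, 5, 3, 11, 7}, 4, 3);
    printf("%d %d %d\n", idx[0], idx[1], idx[2]); // prints: 1 3 0
}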
......@@ -25,6 +25,7 @@
#include "ngraph/op/add.hpp"
#include "ngraph/op/allreduce.hpp"
#include "ngraph/op/and.hpp"
#include "ngraph/op/argmax.hpp"
#include "ngraph/op/argmin.hpp"
#include "ngraph/op/asin.hpp"
#include "ngraph/op/atan.hpp"
......@@ -394,6 +395,12 @@ static shared_ptr<ngraph::Function>
auto target_type = read_element_type(node_js.at("index_element_type"));
node = make_shared<op::ArgMin>(args[0], axis, target_type);
}
else if (node_op == "ArgMax")
{
auto axis = node_js.at("axis").get<size_t>();
auto target_type = read_element_type(node_js.at("index_element_type"));
node = make_shared<op::ArgMax>(args[0], axis, target_type);
}
else if (node_op == "Asin")
{
node = make_shared<op::Asin>(args[0]);
......@@ -1024,6 +1031,12 @@ static json write(const Node& n, bool binary_constant_data)
node["axis"] = tmp->get_reduction_axis();
node["index_element_type"] = write_element_type(tmp->get_element_type());
}
else if (node_op == "ArgMax")
{
auto tmp = dynamic_cast<const op::ArgMax*>(&n);
node["axis"] = tmp->get_reduction_axis();
node["index_element_type"] = write_element_type(tmp->get_element_type());
}
else if (node_op == "AllReduce")
{
}
......
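A hedged sketch of the round trip these two serializer branches implement, assuming ngraph's serialize()/deserialize() helpers from serializer.hpp; "axis" and "index_element_type" are the JSON fields written above:
// Round-trip sketch; f is a Function containing an ArgMax node, as in the
// usage sketch earlier. Illustrative only.
std::string js = ngraph::serialize(f);
std::shared_ptr<ngraph::Function> f2 = ngraph::deserialize(js);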
......@@ -25,6 +25,7 @@
#include "ngraph/autodiff/adjoints.hpp"
#include "ngraph/log.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/op/argmax.hpp"
#include "ngraph/op/argmin.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/lrn.hpp"
......@@ -8947,3 +8948,22 @@ NGRAPH_TEST(${BACKEND_NAME}, argmin_trivial)
backend->call_with_validate(f, {result}, {a});
EXPECT_EQ((vector<int>{3, 2, 1}), read_vector<int>(result));
}
NGRAPH_TEST(${BACKEND_NAME}, argmax_trivial)
{
Shape shape{4, 3};
Shape rshape{3};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto f =
make_shared<Function>(make_shared<op::ArgMax>(A, 0, element::i32), op::ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape);
copy_data(a, vector<float>{9, 2, 10, 12, 8, 4, 6, 1, 5, 3, 11, 7});
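// Row-major 4x3 input: rows {9,2,10}, {12,8,4}, {6,1,5}, {3,11,7}; the
// per-column maxima (12, 11, 10) sit at row indices 1, 3, 0.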
auto result = backend->create_tensor(element::i32, rshape);
backend->call_with_validate(f, {result}, {a});
EXPECT_EQ((vector<int>{1, 3, 0}), read_vector<int>(result));
}