Commit ae352fa4 authored by Tomasz Dołbniak, committed by Scott Cyphers

[ONNX] Hardmax support in ONNX importer (#2869)

* [ONNX] Hardmax implementation in the onnx importer

* [ONNX] More generic handling of types in hardmax

* [ONNX] Support for doubles in EmbeddingLookup CPU builder

* [ONNX] Throw when the provided axis is out of range

* [ONNX] Skip the hardmax test on GPU

* Unused headers clean-up

* refactor: move the identity matrix generator to common.hpp

* ASSERT_VALID_ARGUMENT for axis range validation

* Adapt to the code changes in master
parent d83d18a4
......@@ -96,6 +96,8 @@ add_library(onnx_import STATIC
    op/greater.hpp
    op/hard_sigmoid.cpp
    op/hard_sigmoid.hpp
    op/hardmax.cpp
    op/hardmax.hpp
    op/identity.hpp
    op/leaky_relu.cpp
    op/leaky_relu.hpp
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "hardmax.hpp"
#include "exceptions.hpp"
#include "ngraph/frontend/onnx_import/utils/common.hpp"
#include "ngraph/op/argmax.hpp"
#include "ngraph/op/embedding_lookup.hpp"
#include "ngraph/op/util/reshape.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector hardmax(const Node& node)
{
const auto input = node.get_ng_inputs().at(0);
const auto& input_shape = input->get_shape();
auto axis = node.get_attribute_value<std::int64_t>("axis", 1);
ASSERT_VALID_ARGUMENT(node, axis >= 0 && axis < input_shape.size())
<< "The provided axis value " << axis
<< " does not match the input tensor dimensions";
// reshape to 2D - "batch size" x "input feature dimensions" (NxD)
const auto coerced_tensor = ngraph::op::util::flatten(input, axis);
const auto& coerced_shape = coerced_tensor->get_shape();
const std::shared_ptr<ngraph::Node> argmax_2d =
std::make_shared<ngraph::op::ArgMax>(coerced_tensor, 1, element::i64);
std::shared_ptr<ngraph::Node> eye_matrix =
common::square_identity(coerced_shape.at(1), input->get_element_type());
// the results are elements of the eye_matrix indexed by argmax_2d values
// in other words: eye_matrix[argmax_2d]
auto results =
std::make_shared<ngraph::op::EmbeddingLookup>(argmax_2d, eye_matrix);
return {ngraph::op::util::reshape(results, input_shape)};
}
} // namespace set_1
} // namespace op
} // namespace onnx_import
} // namespace ngraph
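The importer builds Hardmax entirely out of existing ops: flatten the input to an N x D matrix, take a row-wise ArgMax, turn each index into a one-hot row by looking it up in an identity matrix via EmbeddingLookup, and reshape back. Below is a minimal standalone sketch of the same composition on plain buffers; the hardmax_2d helper and all names in it are illustrative, not part of the importer.

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

// One-hot "hardmax" over each row of an n x d row-major buffer:
// per-row argmax followed by a lookup into an identity matrix --
// the same ArgMax + EmbeddingLookup composition the importer builds.
std::vector<float> hardmax_2d(const std::vector<float>& data, size_t n, size_t d)
{
    // identity matrix playing the role of common::square_identity
    std::vector<float> eye(d * d, 0.0f);
    for (size_t i = 0; i < d; ++i)
    {
        eye[i * d + i] = 1.0f;
    }

    std::vector<float> out(n * d);
    for (size_t row = 0; row < n; ++row)
    {
        const auto begin = data.begin() + row * d;
        const size_t argmax = std::max_element(begin, begin + d) - begin;
        // "embedding lookup": copy row `argmax` of the identity matrix
        std::copy(eye.begin() + argmax * d,
                  eye.begin() + (argmax + 1) * d,
                  out.begin() + row * d);
    }
    return out;
}

int main()
{
    const auto result = hardmax_2d({0.1f, 2.0f, 0.3f, 5.0f, 4.0f, -1.0f}, 2, 3);
    for (float v : result)
    {
        std::cout << v << ' '; // prints: 0 1 0 1 0 0
    }
}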
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once

#include "core/node.hpp"
#include "ngraph/node.hpp"

namespace ngraph
{
    namespace onnx_import
    {
        namespace op
        {
            namespace set_1
            {
                NodeVector hardmax(const Node& node);

            } // namespace set_1
        } // namespace op
    } // namespace onnx_import
} // namespace ngraph
......@@ -60,6 +60,7 @@
#include "op/global_max_pool.hpp"
#include "op/greater.hpp"
#include "op/hard_sigmoid.hpp"
#include "op/hardmax.hpp"
#include "op/identity.hpp"
#include "op/leaky_relu.hpp"
#include "op/less.hpp"
......@@ -262,6 +263,7 @@ namespace ngraph
REGISTER_OPERATOR("GlobalLpPool", 1, global_lp_pool);
REGISTER_OPERATOR("GlobalMaxPool", 1, global_max_pool);
REGISTER_OPERATOR("Greater", 1, greater);
REGISTER_OPERATOR("Hardmax", 1, hardmax);
REGISTER_OPERATOR("HardSigmoid", 1, hard_sigmoid);
REGISTER_OPERATOR("Identity", 1, identity);
REGISTER_OPERATOR("LeakyRelu", 1, leaky_relu);
......
......@@ -86,6 +86,25 @@ namespace ngraph
                }
            }

            /// \brief      Creates a square identity matrix.
            ///
            /// \param[in]  n     Order of the resulting matrix.
            ///
            /// \return     A Constant node representing an identity matrix of shape (n, n).
            template <typename T = double>
            std::shared_ptr<ngraph::op::Constant> square_identity(const size_t n,
                                                                  const element::Type& type)
            {
                // fill an n x n buffer with zeros, then set the main diagonal to ones
                std::vector<T> identity_matrix(n * n, T{0});
                for (size_t row = 0; row < n; ++row)
                {
                    const size_t diagonal_element = (n * row) + row;
                    identity_matrix.at(diagonal_element) = T{1};
                }
                return std::make_shared<ngraph::op::Constant>(type, Shape{{n, n}}, identity_matrix);
            }

        } // namespace common
    } // namespace onnx_import
} // namespace ngraph
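square_identity stores the n x n matrix in a flat row-major vector, so the diagonal element of row `row` sits at flat offset n * row + row. A small self-contained check of that indexing, written in plain standard C++ rather than nGraph types:

#include <cassert>
#include <cstddef>
#include <vector>

int main()
{
    const size_t n = 4;
    std::vector<double> m(n * n, 0.0);
    for (size_t row = 0; row < n; ++row)
    {
        m[(n * row) + row] = 1.0; // same offset formula as in square_identity
    }
    // an element equals 1 exactly when its row and column indices match
    for (size_t r = 0; r < n; ++r)
    {
        for (size_t c = 0; c < n; ++c)
        {
            assert(m[r * n + c] == (r == c ? 1.0 : 0.0));
        }
    }
}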
......@@ -112,6 +112,68 @@ namespace ngraph
"Unsupported index type in CPU Builder for EmbeddingLookup");
}
}
else if (element_type == element::f64)
{
if (index_element_type == element::f32)
{
functor = [&,
in_shape,
element_count,
arg0_buffer_index,
arg1_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
ngraph::runtime::reference::embedding<double, float>(
static_cast<float*>(ctx->buffer_data[arg0_buffer_index]),
static_cast<double*>(ctx->buffer_data[arg1_buffer_index]),
static_cast<double*>(ctx->buffer_data[out_buffer_index]),
element_count,
in_shape);
};
}
else if (index_element_type == element::i32)
{
functor = [&,
in_shape,
element_count,
arg0_buffer_index,
arg1_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
ngraph::runtime::reference::embedding<double, int>(
static_cast<int*>(ctx->buffer_data[arg0_buffer_index]),
static_cast<double*>(ctx->buffer_data[arg1_buffer_index]),
static_cast<double*>(ctx->buffer_data[out_buffer_index]),
element_count,
in_shape);
};
}
else if (index_element_type == element::i64)
{
functor = [&,
in_shape,
element_count,
arg0_buffer_index,
arg1_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
ngraph::runtime::reference::embedding<double, int64_t>(
static_cast<int64_t*>(ctx->buffer_data[arg0_buffer_index]),
static_cast<double*>(ctx->buffer_data[arg1_buffer_index]),
static_cast<double*>(ctx->buffer_data[out_buffer_index]),
element_count,
in_shape);
};
}
else
{
throw ngraph_error(
"Unsupported index type in CPU Builder for EmbeddingLookup");
}
}
else if (element_type == element::i32)
{
if (index_element_type == element::f32)
......
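Each new branch builds a functor that, at execution time, calls the reference embedding kernel, which gathers one row of the weight table per index — this is what maps the argmax indices to rows of the identity matrix in the Hardmax decomposition. Below is a simplified standalone sketch of those gather semantics; the parameter list is condensed for illustration (the real reference::embedding takes an element count and shape, as called above), and embedding_sketch is a hypothetical name.

#include <cstddef>
#include <cstdint>
#include <vector>

// Standalone sketch of EmbeddingLookup's gather semantics: for each
// index, copy the matching row of the weight table to the output.
// `row_len` is the size of one embedding vector.
template <typename ElemT, typename IndexT>
void embedding_sketch(const IndexT* indices,
                      const ElemT* weights,
                      ElemT* out,
                      size_t index_count,
                      size_t row_len)
{
    for (size_t i = 0; i < index_count; ++i)
    {
        const auto row = static_cast<size_t>(indices[i]);
        for (size_t j = 0; j < row_len; ++j)
        {
            out[i * row_len + j] = weights[row * row_len + j];
        }
    }
}

int main()
{
    const std::vector<double> table = {1.0, 0.0, 0.0, 1.0}; // 2 x 2 identity
    const std::vector<int64_t> idx = {1, 0, 1};
    std::vector<double> out(idx.size() * 2);
    embedding_sketch(idx.data(), table.data(), out.data(), idx.size(), 2);
    // out is now {0, 1, 1, 0, 0, 1}
}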
......@@ -155,3 +155,4 @@ gather_scalar_indices
gather_nd_single_indices
gemm
gemm_broadcast_input_C
model_hardmax
ir_version: 3
producer_name: "nGraph ONNX Importer"
graph {
  node {
    input: "x"
    output: "y"
    op_type: "Hardmax"
    attribute {
      name: "axis"
      i: 2
      type: INT
    }
  }
  name: "hardmax_graph"
  input {
    name: "x"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
            dim_value: 3
          }
          dim {
            dim_value: 4
          }
          dim {
            dim_value: 5
          }
        }
      }
    }
  }
  output {
    name: "y"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
            dim_value: 3
          }
          dim {
            dim_value: 4
          }
          dim {
            dim_value: 5
          }
        }
      }
    }
  }
}
opset_import {
  version: 9
}
......@@ -1420,3 +1420,39 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, model_erf_int32)
EXPECT_TRUE(test::all_close(expected_outputs, outputs.front()));
}
NGRAPH_TEST(onnx_${BACKEND_NAME}, model_hardmax)
{
    auto hardmax_fn = onnx_import::import_onnx_model(
        file_util::path_join(SERIALIZED_ZOO, "onnx/hardmax.prototxt"));

    auto test_case = ngraph::test::NgraphTestCase(hardmax_fn, "${BACKEND_NAME}");
    test_case.add_input<float>(
        {-2.02458119f, 0.00126542f,  -0.58045743f, -0.75186814f, 0.9406899f,
         -0.513188f,   0.85887463f,  1.61444086f,  0.23801147f,  -0.26816885f,
         0.6597208f,   1.43889519f,  0.28798895f,  1.44769952f,  -1.99466756f,
         0.41386644f,  0.69389555f,  1.46118255f,  -1.67628606f, 1.49697552f,
         0.06337166f,  -1.15740783f, 0.8792142f,   -0.95352717f, -1.87895792f,
         -0.74066102f, -0.27131459f, 0.2219685f,   0.31831001f,  0.52495901f,
         0.60283089f,  0.60397976f,  0.92401468f,  0.29565101f,  -1.14443776f,
         -1.07399045f, -0.92266259f, 0.24017731f,  -0.30105675f, 1.18513269f,
         0.55494542f,  1.12119279f,  -0.43156474f, 0.15101668f,  -1.460439f,
         0.96375129f,  1.10411785f,  -0.30272771f, -0.48855848f, 0.12103213f,
         -0.71388492f, 1.38398178f,  0.21924434f,  0.93105052f,  -0.21074303f,
         0.48213503f,  -1.37810638f, 8.99060285f,  0.54794592f,  -0.46820172f});

    // values for hardmax with axis==2
    test_case.add_expected_output<float>(
        Shape{3, 4, 5}, {0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f,
                         0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f,
                         0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f,
                         0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f,
                         0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f,
                         0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f});

    test_case.run();
}