Commit 112ff134 authored by tsocha's avatar tsocha Committed by Robert Kimball

[ONNX] Softmax operator (#1496)

* [ONNX] Softmax operator

* Review fix pt. 1

* Review fix pt. 2

* Add softmax test

* Update onnx_import.cpp
parent b8918f52
......@@ -46,6 +46,8 @@ add_library(onnx_import STATIC
op/max_pool.hpp
op/mul.hpp
op/relu.hpp
op/softmax.cpp
op/softmax.hpp
op/split.cpp
op/split.hpp
ops_bridge.cpp
......
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <numeric>
#include "ngraph/op/softmax.hpp"
#include "exceptions.hpp"
#include "softmax.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
/// \brief Convert an ONNX Softmax node into an nGraph Softmax node.
///
/// Reads the ONNX "axis" attribute (default 1); a negative axis counts
/// from the back, per the ONNX convention. The nGraph Softmax reduces over
/// every dimension from the normalized axis to the last one.
///
/// \param node The ONNX node wrapper providing inputs and attributes.
/// \return A single-element NodeVector holding the ngraph::op::Softmax.
/// \throws error::parameter::Value when the axis attribute is outside the
///         input tensor's dimension range.
NodeVector softmax(const Node& node)
{
    NodeVector inputs{node.get_ng_inputs()};
    auto data = inputs.at(0);
    auto data_shape = data->get_shape();
    // Keep the attribute's native 64-bit signed type and compare against a
    // signed rank: avoids the signed/unsigned mismatch of comparing an int
    // with data_shape.size().
    int64_t axis = node.get_attribute_value<int64_t>("axis", 1);
    const int64_t rank = static_cast<int64_t>(data_shape.size());
    if (axis < 0)
    {
        // Negative axis counts from the last dimension (ONNX convention).
        axis += rank;
    }
    // Validate AFTER normalization so that an overly negative attribute
    // (e.g. axis = -rank - 1) is rejected instead of producing a negative
    // index that would wrap when used with size_t below.
    if (axis < 0 || axis >= rank)
    {
        throw error::parameter::Value(
            "Softmax node (",
            node.get_name(),
            "): provided axis attribute is out of input tensor dimensions range.");
    }
    // Reduction axes: every dimension from `axis` through the last one.
    std::vector<size_t> axes(static_cast<size_t>(rank - axis));
    std::iota(std::begin(axes), std::end(axes), static_cast<size_t>(axis));
    return {std::make_shared<ngraph::op::Softmax>(data, axes)};
}
} // namespace op
} // namespace onnx_import
} // namespace ngraph
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include <memory>
#include "core/node.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
/// \brief Convert an ONNX Softmax operation into an nGraph Softmax node.
///
/// \param node The ONNX node wrapper providing inputs and attributes
///             (reads the optional "axis" attribute, default 1).
/// \return A single-element NodeVector holding the resulting nGraph node.
NodeVector softmax(const Node& node);
} // namespace op
} // namespace onnx_import
} // namespace ngraph
......@@ -28,6 +28,7 @@
#include "op/max_pool.hpp"
#include "op/mul.hpp"
#include "op/relu.hpp"
#include "op/softmax.hpp"
#include "op/split.hpp"
#include "ops_bridge.hpp"
......@@ -85,6 +86,7 @@ namespace ngraph
m_map.emplace("MaxPool", std::bind(op::max_pool, std::placeholders::_1));
m_map.emplace("Mul", std::bind(op::mul, std::placeholders::_1));
m_map.emplace("Relu", std::bind(op::relu, std::placeholders::_1));
m_map.emplace("Softmax", std::bind(op::softmax, std::placeholders::_1));
m_map.emplace("Split", std::bind(op::split, std::placeholders::_1));
}
......
......@@ -358,3 +358,46 @@ TEST(onnx, model_matmul)
auto result_vectors = execute(function, inputs, "INTERPRETER");
EXPECT_TRUE(test::all_close_f(expected_output, result_vectors.front()));
}
// End-to-end check of the ONNX Softmax importer: load a serialized model,
// run it on the INTERPRETER backend, and compare against precomputed values.
TEST(onnx, model_softmax)
{
// Import the Softmax test model from the serialized model zoo.
auto function = ngraph::onnx_import::import_onnx_function(
ngraph::file_util::path_join(SERIALIZED_ZOO, "onnx/softmax.onnx"));
// Single 3x4x5 input of strictly increasing values 1..60.
Inputs inputs;
inputs.emplace_back(
ngraph::test::NDArray<float, 3>(
{{{1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}, {11, 12, 13, 14, 15}, {16, 17, 18, 19, 20}},
{{21, 22, 23, 24, 25},
{26, 27, 28, 29, 30},
{31, 32, 33, 34, 35},
{36, 37, 38, 39, 40}},
{{41, 42, 43, 44, 45},
{46, 47, 48, 49, 50},
{51, 52, 53, 54, 55},
{56, 57, 58, 59, 60}}})
.get_vector());
// Expected softmax values, precomputed externally — presumably with the
// model's axis attribute covering all elements, since the values sum to 1
// over the whole tensor rather than per row (TODO confirm against
// softmax.onnx).
auto expected_output =
ngraph::test::NDArray<float, 3>(
{{{1.50461533e-26, 4.08996852e-26, 1.11176871e-25, 3.02210068e-25, 8.21492137e-25},
{2.23304715e-24, 6.07005148e-24, 1.65001106e-23, 4.48519509e-23, 1.21920243e-22},
{3.31413582e-22, 9.00875516e-22, 2.44883355e-21, 6.65661973e-21, 1.80945684e-20},
{4.91861366e-20, 1.33701781e-19, 3.63439123e-19, 9.87929963e-19, 2.68547207e-18}},
{{7.29986992e-18, 1.98431037e-17, 5.39391483e-17, 1.46621807e-16, 3.98559393e-16},
{1.08339676e-15, 2.94497771e-15, 8.00527940e-15, 2.17606055e-14, 5.91514586e-14},
{1.60790335e-13, 4.37073446e-13, 1.18808881e-12, 3.22956021e-12, 8.77885484e-12},
{2.38634016e-11, 6.48674509e-11, 1.76328013e-10, 4.79309234e-10, 1.30289758e-09}},
{{3.54164282e-09, 9.62718331e-09, 2.61693974e-08, 7.11357975e-08, 1.93367146e-07},
{5.25626399e-07, 1.42880069e-06, 3.88388295e-06, 1.05574884e-05, 2.86982290e-05},
{7.80098743e-05, 2.12052824e-04, 5.76419338e-04, 1.56687021e-03, 4.25919482e-03},
{1.15776919e-02, 3.14714295e-02, 8.55482149e-02, 2.32544158e-01, 6.32120559e-01}}})
.get_vector();
// Run on the reference INTERPRETER backend and compare with a float
// tolerance (all_close_f) rather than exact equality.
auto result_vectors = execute(function, inputs, "INTERPRETER");
EXPECT_TRUE(test::all_close_f(expected_output, result_vectors.front()));
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment