Commit 2f69f86c authored by Mateusz Bencer, committed by Scott Cyphers

[FUSED] Add new LogSoftmax fused op (#3867)

* LogSoftmax introduced

* Added LogSoftmax to serializer

* Fixed style

* Fixed CmakeLists style

* code review remarks introduced

* Code review remarks introduced
parent bc1ca25b
@@ -357,6 +357,8 @@ set (SRC
     op/fused/gru_cell.hpp
     op/fused/layer_norm.cpp
     op/fused/layer_norm.hpp
+    op/fused/log_softmax.cpp
+    op/fused/log_softmax.hpp
     op/fused/lstm_cell.cpp
     op/fused/lstm_cell.hpp
     op/fused/lstm_sequence.cpp
...
@@ -112,6 +112,7 @@ add_library(onnx_import STATIC
     op/leaky_relu.hpp
     op/less.hpp
     op/log.hpp
+    op/log_softmax.cpp
     op/log_softmax.hpp
     op/lp_norm.cpp
     op/lp_norm.hpp
...
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <memory>
#include "core/node.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/fused/log_softmax.hpp"
#include "utils/common.hpp"
namespace ngraph
{
    namespace onnx_import
    {
        namespace op
        {
            namespace set_1
            {
                NodeVector log_softmax(const Node& node)
                {
                    NodeVector inputs{node.get_ng_inputs()};
                    const auto data = inputs.at(0);

                    // The ONNX LogSoftmax axis attribute defaults to 1.
                    const auto axis = node.get_attribute_value<int64_t>("axis", 1);

                    return {std::make_shared<ngraph::op::LogSoftmax>(data, axis)};
                }

            } // namespace set_1
        } // namespace op
    } // namespace onnx_import
} // namespace ngraph
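For context (an editorial note, not part of the diff): the importer now simply forwards the ONNX axis attribute to the fused op, which is expected to compute, over the coerced axes,

\operatorname{LogSoftmax}(x)_i = \log \frac{e^{x_i}}{\sum_j e^{x_j}} = x_i - \log \sum_j e^{x_j}
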
@@ -19,9 +19,7 @@
 #include <memory>

 #include "core/node.hpp"
-#include "ngraph/frontend/onnx_import/op/softmax.hpp"
 #include "ngraph/node.hpp"
-#include "ngraph/op/log.hpp"

 namespace ngraph
 {
@@ -31,10 +29,7 @@ namespace ngraph
         {
             namespace set_1
             {
-                inline NodeVector log_softmax(const Node& node)
-                {
-                    return {std::make_shared<ngraph::op::Log>(softmax(node).at(0))};
-                }
+                NodeVector log_softmax(const Node& node);

             } // namespace set_1
...
@@ -143,6 +143,7 @@ namespace ngraph
 #include "ngraph/op/fused/gru_cell.hpp"
 #include "ngraph/op/fused/hard_sigmoid.hpp"
 #include "ngraph/op/fused/layer_norm.hpp"
+#include "ngraph/op/fused/log_softmax.hpp"
 #include "ngraph/op/fused/lstm_cell.hpp"
 #include "ngraph/op/fused/lstm_sequence.hpp"
 #include "ngraph/op/fused/matmul.hpp"
...
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <numeric>
#include "ngraph/op/fused/log_softmax.hpp"
#include "ngraph/op/log.hpp"
#include "ngraph/op/softmax.hpp"
#include "ngraph/validation_util.hpp"
using namespace std;
using namespace ngraph;
constexpr NodeTypeInfo op::LogSoftmax::type_info;
op::LogSoftmax::LogSoftmax(const Output<Node>& data, int64_t axis)
    : FusedOp({data})
    , m_axis(axis)
{
    constructor_validate_and_infer_types();
}

NodeVector op::LogSoftmax::decompose_op() const
{
    const auto data = input_value(0);
    const auto data_shape = data.get_shape();
    const auto axis = ngraph::normalize_axis(this, m_axis, data_shape.size());

    // Following the ONNX "coerced to 2D" semantics, softmax is applied over
    // every axis from `axis` up to the last one, and Log is taken on top.
    std::vector<size_t> axes(data_shape.size() - axis);
    std::iota(std::begin(axes), std::end(axes), axis);

    const auto softmax = std::make_shared<ngraph::op::Softmax>(data, axes);
    return {std::make_shared<ngraph::op::Log>(softmax)};
}

shared_ptr<Node> op::LogSoftmax::copy_with_new_args(const NodeVector& new_args) const
{
    check_new_args_count(this, new_args);
    return make_shared<LogSoftmax>(new_args.at(0), m_axis);
}
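A minimal standalone sketch of the axis arithmetic in decompose_op above (illustration only; nothing beyond the standard library is assumed): for a rank-4 tensor and the ONNX default axis of 1, softmax runs over axes {1, 2, 3}.

#include <iostream>
#include <numeric>
#include <vector>

// Illustration only: reproduces the axes computed by LogSoftmax::decompose_op.
int main()
{
    const size_t rank = 4; // e.g. an NCHW tensor
    const size_t axis = 1; // ONNX default

    std::vector<size_t> axes(rank - axis); // one entry per reduced axis
    std::iota(axes.begin(), axes.end(), axis);

    for (size_t a : axes)
    {
        std::cout << a << ' '; // prints: 1 2 3
    }
    std::cout << std::endl;
    return 0;
}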
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once

#include "ngraph/node.hpp"
#include "ngraph/op/op.hpp"
#include "ngraph/op/util/fused_op.hpp"

namespace ngraph
{
    namespace op
    {
        /// \brief LogSoftmax operation
        class LogSoftmax : public ngraph::op::util::FusedOp
        {
        public:
            NGRAPH_API
            static constexpr NodeTypeInfo type_info{"LogSoftmax", 0};
            LogSoftmax() = default;
            const NodeTypeInfo& get_type_info() const override { return type_info; }
            /// \brief Constructs a LogSoftmax node.
            ///
            /// \param data Node that produces the first input tensor
            /// \param axis Describes the axis of the inputs when coerced to 2D
            LogSoftmax(const Output<Node>& data, int64_t axis);

            virtual NodeVector decompose_op() const override;

            virtual std::shared_ptr<Node>
                copy_with_new_args(const NodeVector& new_args) const override;

            int64_t get_axis() const { return m_axis; }
        protected:
            int64_t m_axis;
        };
    } // namespace op
} // namespace ngraph
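A minimal usage sketch (not part of the commit; assumes the ngraph headers above are on the include path): build the op over a parameter and query the stored axis.

#include "ngraph/ngraph.hpp"

std::shared_ptr<ngraph::Node> make_log_softmax_example()
{
    using namespace ngraph;

    auto data = std::make_shared<op::Parameter>(element::f32, Shape{2, 3});
    auto log_softmax = std::make_shared<op::LogSoftmax>(data, 1);

    // The axis is retained for serialization and decomposition:
    // log_softmax->get_axis() == 1
    return log_softmax;
}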
@@ -39,6 +39,7 @@ NGRAPH_OP(GRUCell, ngraph::op)
 NGRAPH_OP(HardSigmoid, ngraph::op)
 NGRAPH_OP(LayerNorm, ngraph::op)
 NGRAPH_OP(LayerNormBackprop, ngraph::op)
+NGRAPH_OP(LogSoftmax, ngraph::op)
 NGRAPH_OP(LSTMCell, ngraph::op)
 NGRAPH_OP(LSTMSequence, ngraph::op)
 NGRAPH_OP(MatMul, ngraph::op)
...
@@ -83,6 +83,7 @@
 #include "ngraph/op/fused/gru_cell.hpp"
 #include "ngraph/op/fused/hard_sigmoid.hpp"
 #include "ngraph/op/fused/layer_norm.hpp"
+#include "ngraph/op/fused/log_softmax.hpp"
 #include "ngraph/op/fused/lstm_cell.hpp"
 #include "ngraph/op/fused/lstm_sequence.hpp"
 #include "ngraph/op/fused/matmul.hpp"
@@ -1823,6 +1824,12 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
                 args[0], args[1], read_auto_broadcast(node_js, "auto_broadcast"));
             break;
         }
+        case OP_TYPEID::LogSoftmax:
+        {
+            auto axis = node_js.at("axis").get<int64_t>();
+            node = make_shared<op::LogSoftmax>(args[0], axis);
+            break;
+        }
         case OP_TYPEID::LRN:
         {
             auto alpha = node_js.at("alpha").get<double>();
@@ -3585,6 +3592,12 @@ json JSONSerializer::serialize_node(const Node& n)
            }
            break;
        }
+        case OP_TYPEID::LogSoftmax:
+        {
+            auto tmp = static_cast<const op::LogSoftmax*>(&n);
+            node["axis"] = tmp->get_axis();
+            break;
+        }
         case OP_TYPEID::LRN:
         {
             auto tmp = static_cast<const op::LRN*>(&n);
...
@@ -794,3 +794,27 @@ PartialShape ngraph::infer_slice_shape(const Node* node,
     return dim;
 }
+
+std::size_t ngraph::normalize_axis(const Node* node, std::int64_t axis, std::int64_t tensor_rank)
+{
+    const auto axis_range_min = -tensor_rank;
+    const auto axis_range_max = tensor_rank - 1;
+
+    // The accepted range of values for axis is [axis_range_min, axis_range_max].
+    NGRAPH_CHECK(((axis >= axis_range_min) && (axis <= axis_range_max)),
+                 node->description(),
+                 " Parameter axis ",
+                 axis,
+                 " out of the tensor rank range [",
+                 axis_range_min,
+                 ", ",
+                 axis_range_max,
+                 "].");
+
+    if (axis < 0)
+    {
+        axis = axis + tensor_rank;
+    }
+
+    return static_cast<size_t>(axis);
+}
@@ -102,4 +102,6 @@ namespace ngraph
                                    const AxisSet& new_axis_mask,
                                    const AxisSet& shrink_axis_mask,
                                    const AxisSet& ellipsis_mask);
+
+    std::size_t normalize_axis(const Node* node, std::int64_t axis, std::int64_t tensor_rank);
 }
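A hedged sketch of the normalization rule (an illustrative helper, not the library function): negative axes count from the back, so for a rank-3 tensor the accepted range is [-3, 2] and axis == -1 maps to 2.

#include <cassert>
#include <cstddef>
#include <cstdint>

// Illustration of the rule implemented by ngraph::normalize_axis above.
std::size_t normalize_axis_sketch(std::int64_t axis, std::int64_t rank)
{
    assert(axis >= -rank && axis <= rank - 1); // stands in for NGRAPH_CHECK
    return static_cast<std::size_t>(axis < 0 ? axis + rank : axis);
}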
@@ -138,6 +138,7 @@ set(SRC
     type_prop/hard_sigmoid.cpp
     type_prop/index_reduction.cpp
     type_prop/layer_norm.cpp
+    type_prop/log_softmax.cpp
     type_prop/lrn.cpp
     type_prop/lstm_cell.cpp
     type_prop/lstm_sequence.cpp
...
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "util/type_prop.hpp"
using namespace std;
using namespace ngraph;
TEST(type_prop, log_softmax)
{
    const auto data = make_shared<op::Parameter>(element::f64, Shape{2, 2});
    const auto axis = 2;

    try
    {
        const auto log_softmax = make_shared<op::LogSoftmax>(data, axis);
        // Should have thrown, so fail if it didn't
        FAIL() << "Invalid axis value not detected";
    }
    catch (const ngraph_error& error)
    {
        EXPECT_HAS_SUBSTRING(error.what(), std::string("Parameter axis "));
    }
    catch (...)
    {
        FAIL() << "Log softmax failed for unexpected reason";
    }
}
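A companion positive-path check could look like the sketch below (not part of the commit): with a valid axis, the fused op should preserve the input element type and shape, since FusedOp infers output types from the decomposition.

TEST(type_prop, log_softmax_valid_axis_sketch)
{
    const auto data = make_shared<op::Parameter>(element::f32, Shape{2, 2});
    const auto log_softmax = make_shared<op::LogSoftmax>(data, 1);

    EXPECT_EQ(log_softmax->get_element_type(), element::f32);
    EXPECT_EQ(log_softmax->get_shape(), (Shape{2, 2}));
}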