Commit 7b3b1b6c authored by Mateusz Bencer, committed by Scott Cyphers

Change LogSoftmax to produce v1 and remove FusedOp (#4139)

* Removed LogSoftmax FusedOp, changed onnx to produce v1

* Code review remarks addressed

* Fix after merge from master
Co-authored-by: Scott Cyphers <diyessi@users.noreply.github.com>
parent 3b30799a
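For orientation, the importer change below composes LogSoftmax from two v1 ops instead of the fused op. A minimal sketch of that composition, assuming the importer's default_opset alias points at opset v1 (the helper name make_log_softmax is hypothetical):

#include <memory>

#include "default_opset.hpp"

// Hypothetical helper mirroring what the importer emits after this change:
// log_softmax(x, axis) == Log(Softmax(x, axis)) for a single, already
// normalized axis.
std::shared_ptr<ngraph::Node> make_log_softmax(const ngraph::Output<ngraph::Node>& data,
                                               std::size_t axis)
{
    const auto softmax = std::make_shared<default_opset::Softmax>(data, axis);
    return std::make_shared<default_opset::Log>(softmax);
}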
@@ -385,8 +385,6 @@ set (SRC
 op/fused/gru_cell.hpp
 op/fused/layer_norm.cpp
 op/fused/layer_norm.hpp
-op/fused/log_softmax.cpp
-op/fused/log_softmax.hpp
 op/fused/lstm_cell.cpp
 op/fused/lstm_cell.hpp
 op/fused/lstm_sequence.cpp
@@ -16,8 +16,9 @@
 #include <memory>

+#include "default_opset.hpp"
 #include "log_softmax.hpp"
-#include "ngraph/opsets/opset0.hpp"
+#include "ngraph/validation_util.hpp"

 namespace ngraph
 {
@@ -30,11 +31,16 @@ namespace ngraph
             NodeVector log_softmax(const Node& node)
             {
                 NodeVector inputs{node.get_ng_inputs()};
-                auto data = inputs.at(0);
-                auto data_shape = data->get_shape();
-                int axis = node.get_attribute_value<int64_t>("axis", 1);
+                const auto data = inputs.at(0);
+                const auto data_shape = data->get_shape();

-                return {std::make_shared<ngraph::opset0::LogSoftmax>(data, axis)};
+                const auto axis = node.get_attribute_value<int64_t>("axis", 1);
+                const auto normalized_axis =
+                    ngraph::normalize_axis(node.get_description(), axis, data_shape.size());
+
+                const auto softmax =
+                    std::make_shared<default_opset::Softmax>(data, normalized_axis);
+                return {std::make_shared<default_opset::Log>(softmax)};
             }
         } // namespace set_1
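A quick sketch of the normalize_axis step above (values illustrative): the helper maps a possibly negative ONNX axis into [0, rank) and rejects out-of-range values.

#include <cstdint>

#include "ngraph/shape.hpp"
#include "ngraph/validation_util.hpp"

// For a rank-3 input, the ONNX axis -1 normalizes to 2;
// anything outside [-3, 2] is expected to be rejected.
const ngraph::Shape data_shape{2, 3, 4};
const std::int64_t onnx_axis = -1;
const auto normalized_axis =
    ngraph::normalize_axis("LogSoftmax", onnx_axis, data_shape.size());
// normalized_axis == 2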
deleted file: op/fused/log_softmax.cpp
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <numeric>
#include "ngraph/op/fused/log_softmax.hpp"
#include "ngraph/op/log.hpp"
#include "ngraph/op/softmax.hpp"
#include "ngraph/validation_util.hpp"
using namespace std;
using namespace ngraph;
constexpr NodeTypeInfo op::LogSoftmax::type_info;
op::LogSoftmax::LogSoftmax(const Output<Node>& data, int64_t axis)
: FusedOp({data})
, m_axis(axis)
{
constructor_validate_and_infer_types();
}
NodeVector op::LogSoftmax::decompose_op() const
{
const auto data = input_value(0);
const auto data_shape = data.get_shape();
auto axis = ngraph::normalize_axis(this, m_axis, data_shape.size());
std::vector<size_t> axes(data_shape.size() - axis);
std::iota(std::begin(axes), std::end(axes), axis);
auto softmax = std::make_shared<ngraph::op::Softmax>(data, axes);
return {std::make_shared<ngraph::op::Log>(softmax)};
}
shared_ptr<Node> op::LogSoftmax::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<LogSoftmax>(new_args.at(0), m_axis);
}
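One detail worth noting in the removed decompose_op above: v0 Softmax took a set of axes, and the std::iota call normalized over every dimension from the normalized axis to the last one. A small illustration:

#include <numeric>
#include <vector>

// For a rank-4 input with axis == 1, the vector has size 4 - 1 == 3
// and std::iota fills it with {1, 2, 3}, so v0 Softmax reduced over
// all trailing dimensions (ONNX's coerce-to-2D semantics).
std::vector<size_t> axes(4 - 1);
std::iota(std::begin(axes), std::end(axes), 1);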
deleted file: op/fused/log_softmax.hpp
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/node.hpp"
#include "ngraph/op/op.hpp"
#include "ngraph/op/util/fused_op.hpp"
namespace ngraph
{
namespace op
{
namespace v0
{
/// \brief LogSoftmax operation
class NGRAPH_API LogSoftmax : public ngraph::op::util::FusedOp
{
public:
static constexpr NodeTypeInfo type_info{"LogSoftmax", 0};
LogSoftmax() = default;
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a LogSoftmax node.
///
/// \param data Node that produces the first input tensor
/// \param axis Describes the axis of the inputs when coerced to 2D
LogSoftmax(const Output<Node>& data, int64_t axis);
virtual NodeVector decompose_op() const override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
int64_t get_axis() const { return m_axis; }
protected:
int64_t m_axis;
};
}
using v0::LogSoftmax;
} // namespace op
} // namespace ngraph
@@ -132,7 +132,6 @@ NGRAPH_OP(Less, ngraph::op::v1, 1)
 NGRAPH_OP(LessEq, ngraph::op::v0, 0)
 NGRAPH_OP(LessEqual, ngraph::op::v1, 1)
 NGRAPH_OP(Log, ngraph::op, 0)
-NGRAPH_OP(LogSoftmax, ngraph::op::v0, 0)
 NGRAPH_OP(LogicalAnd, ngraph::op::v1, 1)
 NGRAPH_OP(LogicalNot, ngraph::op::v1, 1)
 NGRAPH_OP(LogicalOr, ngraph::op::v1, 1)
@@ -98,7 +98,6 @@
 #include "ngraph/op/fused/gru_cell.hpp"
 #include "ngraph/op/fused/hard_sigmoid.hpp"
 #include "ngraph/op/fused/layer_norm.hpp"
-#include "ngraph/op/fused/log_softmax.hpp"
 #include "ngraph/op/fused/lstm_cell.hpp"
 #include "ngraph/op/fused/lstm_sequence.hpp"
 #include "ngraph/op/fused/matmul.hpp"
@@ -127,7 +127,6 @@ NGRAPH_OP(LayerNormBackprop, ngraph::op)
 NGRAPH_OP(Less, ngraph::op)
 NGRAPH_OP(LessEq, ngraph::op)
 NGRAPH_OP(Log, ngraph::op)
-NGRAPH_OP(LogSoftmax, ngraph::op)
 NGRAPH_OP(LRN, ngraph::op)
 NGRAPH_OP(LSTMCell, ngraph::op)
 NGRAPH_OP(LSTMSequence, ngraph::op)
@@ -1865,7 +1865,6 @@ protected:
 case OP_TYPEID::Interpolate:
 case OP_TYPEID::LayerNorm:
 case OP_TYPEID::LayerNormBackprop:
-case OP_TYPEID::LogSoftmax:
 case OP_TYPEID::LSTMCell:
 case OP_TYPEID::LSTMSequence:
 case OP_TYPEID::MVN:
@@ -1939,12 +1939,6 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
         args[0], args[1], read_auto_broadcast(node_js, "auto_broadcast"));
     break;
 }
-case OP_TYPEID::LogSoftmax:
-{
-    auto axis = node_js.at("axis").get<int64_t>();
-    node = make_shared<op::LogSoftmax>(args[0], axis);
-    break;
-}
 case OP_TYPEID::LRN:
 {
     auto alpha = node_js.at("alpha").get<double>();
@@ -3960,12 +3954,6 @@ json JSONSerializer::serialize_node(const Node& n)
     }
     break;
 }
-case OP_TYPEID::LogSoftmax:
-{
-    auto tmp = static_cast<const op::LogSoftmax*>(&n);
-    node["axis"] = tmp->get_axis();
-    break;
-}
 case OP_TYPEID::LRN:
 {
     auto tmp = static_cast<const op::LRN*>(&n);
@@ -148,7 +148,6 @@ set(SRC
 type_prop/hard_sigmoid.cpp
 type_prop/index_reduction.cpp
 type_prop/layer_norm.cpp
-type_prop/log_softmax.cpp
 type_prop/lrn.cpp
 type_prop/lstm_cell.cpp
 type_prop/lstm_sequence.cpp
@@ -717,15 +717,6 @@ namespace
     EXPECT_FALSE(node.is_binary_elementwise_logical());
 }

-void op_is_LogSoftmax()
-{
-    op::LogSoftmax node;
-    EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
-    EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
-    EXPECT_FALSE(node.is_binary_elementwise_comparison());
-    EXPECT_FALSE(node.is_binary_elementwise_logical());
-}
-
 void op_is_LRN()
 {
     op::LRN node;
deleted file: test/type_prop/log_softmax.cpp
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "util/type_prop.hpp"
using namespace std;
using namespace ngraph;
TEST(type_prop, log_softmax)
{
const auto data = make_shared<op::Parameter>(element::f64, Shape{2, 2});
const auto axis = 2;
try
{
const auto log_softmax = make_shared<op::LogSoftmax>(data, axis);
// Should have thrown, so fail if it didn't
FAIL() << "Invalid axis value not detected";
}
catch (const ngraph_error& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Parameter axis "));
}
catch (...)
{
FAIL() << "Log softmax failed for unexpected reason";
}
}
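With the fused op gone, the axis validation this test covered is expected to come from ngraph::normalize_axis at import time instead; a sketch of the same scenario:

// Rank-2 data admits axes in [-2, 1], so axis 2 should be rejected
// (assuming normalize_axis reports the out-of-range axis via
// ngraph_error, as the old fused op did through its own validation).
const ngraph::Shape shape{2, 2};
// ngraph::normalize_axis("LogSoftmax", 2, shape.size()); // expected to throw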