Commit d218ccf9 authored by Michał Karzyński, committed by Scott Cyphers

Add support for operator sets and Softmax:1 (#3420)

* Add opset_version field to Node

* Add opset version aliases to Softmax

* Add op::set1::Softmax operator

* Disable opset 1 ops in INTERPRETER

* Add serializer support for Softmax opset 1

* Opset1Transformation pass

* Added unit tests to softmax pass

* Code refactoring

* Added missing virtual to set_opset_version

* Clang styles applied

* Update src/ngraph/pass/opset1_transform.cpp
Co-Authored-By: Adam Procter <adam.m.procter@intel.com>

* Part.1 Code review remarks introduced

* Part.2 Code review remarks introduced

* Changed opset_version to op_version

* Code review remarks introduced

* Code review remarks introduced

* Set Op as base class for Softmax instead of UnaryElementwiseArithmetic

* Fixed unit tests

* Mark v1::Softmax::generate_adjoints as temporarily not supported

* Fix CI. Part.2

* Fix CI. Part.3

* Code review remarks introduced

* Rename Opset1Transformation to Opset1Upgrade

* Fixed clang style problem with enum switch

* Fixes clang compiler error

* Removed unused forward declaration

* Code review remarks introduced

* Added checking if input rank is static
parent 79f346b3
......@@ -24,10 +24,8 @@ namespace py = pybind11;
void regclass_pyngraph_op_Softmax(py::module m)
{
py::class_<ngraph::op::Softmax,
std::shared_ptr<ngraph::op::Softmax>,
ngraph::op::util::UnaryElementwiseArithmetic>
softmax(m, "Softmax");
py::class_<ngraph::op::Softmax, std::shared_ptr<ngraph::op::Softmax>, ngraph::op::Op> softmax(
m, "Softmax");
softmax.doc() = "ngraph.impl.op.Softmax wraps ngraph::op::Softmax";
softmax.def(py::init<const std::shared_ptr<ngraph::Node>&, const ngraph::AxisSet&>());
}
......@@ -437,6 +437,8 @@ set (SRC
pass/nop_elimination.hpp
pass/pass.cpp
pass/pass.hpp
pass/opset1_upgrade.cpp
pass/opset1_upgrade.hpp
pass/pass_config.cpp
pass/pass_config.hpp
pass/propagate_cacheability.cpp
......
......@@ -379,6 +379,8 @@ namespace ngraph
    /// Get all the nodes that use the current node
NodeVector get_users(bool check_is_used = false) const;
/// \return Version of this node
virtual size_t get_version() const { return 0; }
virtual std::shared_ptr<Node> get_default_value() const { return nullptr; }
/// Use instance ids for comparison instead of memory addresses to improve determinism
bool operator<(const Node& other) const { return m_instance_id < other.m_instance_id; }
......
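For orientation, a minimal sketch (not part of this diff) of how a caller can branch on the version a node reports; the INTERPRETER guard added further down in this commit performs the same check:
#include "ngraph/ngraph.hpp"

// Sketch only: reject anything that does not report opset 0.
void require_opset0(const std::shared_ptr<ngraph::Node>& node)
{
    if (node->get_version() != 0)
    {
        throw ngraph::ngraph_error("only opset 0 nodes are handled here");
    }
}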
......@@ -28,24 +28,39 @@
using namespace std;
using namespace ngraph;
// *** SOFTMAX OP SET 0 ***
const string op::Softmax::type_name{"Softmax"};
op::Softmax::Softmax(const Output<Node>& arg, const AxisSet& axes)
: UnaryElementwiseArithmetic(arg)
op::v0::Softmax::Softmax(const Output<Node>& arg, const AxisSet& axes)
: Op({arg})
, m_axes(axes)
{
constructor_validate_and_infer_types();
const PartialShape& input_shape = get_input_partial_shape(0);
NODE_VALIDATION_CHECK(this,
input_shape.rank().is_static(),
"Input node rank must be static (input_shape=",
input_shape,
").");
for (auto axis : m_axes)
{
NODE_VALIDATION_CHECK(this,
axis < get_shape().size(),
axis >= 0 && axis < static_cast<size_t>(input_shape.rank()),
"Reduction axis (",
axis,
") is out of bounds (argument shape: ",
get_shape(),
input_shape,
").");
}
if (input_shape.is_static())
{
set_output_type(0, get_input_element_type(0), input_shape.to_shape());
}
else
{
set_output_type(0, get_input_element_type(0), PartialShape::dynamic());
}
// empty axes == all axes
if (m_axes.size() == 0)
......@@ -57,13 +72,13 @@ op::Softmax::Softmax(const Output<Node>& arg, const AxisSet& axes)
}
}
shared_ptr<Node> op::Softmax::copy_with_new_args(const NodeVector& new_args) const
shared_ptr<Node> op::v0::Softmax::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<Softmax>(new_args.at(0), m_axes);
}
void op::Softmax::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
void op::v0::Softmax::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
{
auto delta = deltas.at(0);
......@@ -90,3 +105,75 @@ void op::Softmax::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVect
auto x = input_value(0);
adjoints.add_delta(x, adjoint);
}
// *** SOFTMAX OP SET V1 ***
const string op::v1::Softmax::type_name{"Softmax"};
op::v1::Softmax::Softmax(const Output<Node>& arg, const size_t axis)
: Op({arg})
, m_axis(axis)
{
constructor_validate_and_infer_types();
const PartialShape& input_shape = get_input_partial_shape(0);
NODE_VALIDATION_CHECK(this,
input_shape.rank().is_static(),
"Input node rank must be static (input_shape=",
input_shape,
").");
NODE_VALIDATION_CHECK(this,
axis >= 0 && axis < static_cast<size_t>(input_shape.rank()),
"Reduction axis (",
axis,
") is out of bounds (argument shape: ",
input_shape,
").");
if (input_shape.is_static())
set_output_type(0, get_input_element_type(0), input_shape.to_shape());
else
set_output_type(0, get_input_element_type(0), PartialShape::dynamic());
}
shared_ptr<Node> op::v1::Softmax::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<op::v1::Softmax>(new_args.at(0), m_axis);
}
void op::v1::Softmax::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
{
throw ngraph_error("op::v1::Softmax::generate_adjoints function is not implemented yet");
/* This might work, but as of this writing we have no way to test it, so we are being careful
auto delta = deltas.at(0);
auto z = delta * shared_from_this();
std::vector<size_t> axes(get_shape().size() - m_axis);
std::iota(std::begin(axes), std::end(axes), m_axis);
AxisSet axes_set{axes};
auto zsum = make_shared<op::Sum>(z, axes_set);
Shape shape;
for (size_t i = 0; i < get_shape().size(); ++i)
{
if (axes_set.find(i) == axes_set.end())
{
shape.push_back(get_shape()[i]);
}
else
{
shape.push_back(1);
}
}
auto order = ngraph::get_default_order(zsum->get_shape());
auto zreshape = make_shared<op::Reshape>(zsum, order, shape);
auto adjoint = z - builder::make_with_numpy_broadcast<op::Multiply>(output(0), zreshape);
auto x = input(0).get_source_output();
adjoints.add_delta(x, adjoint);
*/
}
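For reference (a note added for clarity, not part of the diff): the commented-out sketch follows the standard softmax adjoint. With y = softmax(x) over the reduced axes and upstream delta \delta, the input adjoint is

    \bar{x} = y \odot \delta - y \odot \sum_{\text{axes}} (y \odot \delta)

where the sum runs over the softmax axes and is broadcast back to the input shape; in the code above this corresponds to z - y * zsum with z = delta * y.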
......@@ -16,42 +16,85 @@
#pragma once
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
#include "ngraph/op/op.hpp"
namespace ngraph
{
namespace op
{
/// \brief Softmax operation.
///
class Softmax : public util::UnaryElementwiseArithmetic
namespace v0
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
Softmax() = default;
/// \brief Constructs a softmax operation.
/// \brief Softmax operation.
///
/// \param arg Node that produces the first input tensor.<br>
/// `[d0, ...]`
/// \param axes The axis positions (0-based) on which to calculate the softmax.
///
/// Output `[d0, ...]`
///
Softmax(const Output<Node>& arg, const AxisSet& axes);
class Softmax : public Op
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
Softmax() = default;
/// \brief Constructs a softmax operation.
///
/// \param arg Node that produces the first input tensor.<br>
/// `[d0, ...]`
/// \param axes The axis positions (0-based) on which to calculate the softmax.
///
/// Output `[d0, ...]`
///
Softmax(const Output<Node>& arg, const AxisSet& axes);
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
const AxisSet& get_axes() const { return m_axes; }
void set_axes(const AxisSet& axes) { m_axes = axes; }
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
private:
AxisSet m_axes;
};
}
namespace v1
{
class Softmax : public Op
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
Softmax()
: m_axis(0)
{
}
/// \brief Constructs a softmax operation.
///
/// \param arg Node that produces the first input tensor.<br>
/// `[d0, ...]`
/// \param axis The axis position (0-based) on which to calculate the softmax.
///
/// Output `[d0, ...]`
///
Softmax(const Output<Node>& arg, const size_t axis);
size_t get_version() const override { return 1; }
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
size_t get_axis() const { return m_axis; }
void set_axis(const size_t axis) { m_axis = axis; }
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
const AxisSet& get_axes() const { return m_axes; }
void set_axes(const AxisSet& axes) { m_axes = axes; }
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
private:
size_t m_axis;
};
}
private:
AxisSet m_axes;
};
// default opset version
using v0::Softmax;
}
}
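For illustration, a minimal construction sketch contrasting the two declarations above; the shapes, axes, and helper name are arbitrary and not part of this diff:
#include "ngraph/ngraph.hpp"

using namespace ngraph;

std::shared_ptr<Function> make_softmax_examples()
{
    auto data = std::make_shared<op::Parameter>(element::f32, Shape{2, 3, 4});
    // Opset 0: reduces over a set of axes (here a single-element set).
    auto sm_v0 = std::make_shared<op::v0::Softmax>(data, AxisSet{1});
    // Opset 1: takes exactly one axis and reports get_version() == 1.
    auto sm_v1 = std::make_shared<op::v1::Softmax>(sm_v0, 1);
    return std::make_shared<Function>(NodeVector{sm_v1}, ParameterVector{data});
}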
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/pass/opset1_upgrade.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/softmax.hpp"
using namespace std;
using namespace ngraph;
#define NGRAPH_OP(a, b) a,
enum class OP_TYPEID
{
#include "ngraph/op/op_tbl.hpp"
};
#undef NGRAPH_OP
#define NGRAPH_OP(a, b) {#a, OP_TYPEID::a},
static unordered_map<string, OP_TYPEID> typeid_map{
#include "ngraph/op/op_tbl.hpp"
};
#undef NGRAPH_OP
static OP_TYPEID get_typeid(shared_ptr<Node> node)
{
OP_TYPEID type_id;
auto it = typeid_map.find(node->description());
if (it != typeid_map.end())
{
type_id = it->second;
}
else
{
throw unsupported_op("Unsupported op '" + node->description() + "'");
}
return type_id;
}
// END mapping to OP_TYPEID
bool pass::Opset1Upgrade::run_on_node(shared_ptr<Node> node)
{
bool modified = false;
size_t op_version = node->get_version();
if (op_version == 1)
{
return modified;
}
NGRAPH_CHECK(op_version == 0,
"Op version 1 transformation pass failed for ",
*node,
", only op version 0 operations expected. Op version ",
op_version,
" found.");
// Not all enumeration values explicitly handled in switch
#if defined(__clang__)
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wswitch-enum"
#endif
switch (get_typeid(node))
{
case OP_TYPEID::Softmax:
{
auto tmp = dynamic_cast<const op::v0::Softmax*>(node.get());
AxisSet axes = tmp->get_axes();
NGRAPH_CHECK(
axes.size() == 1,
"Unable to convert Softmax:0 to Softmax:1 with zero or more than one axis. Node: ",
*node);
auto replacement_node =
make_shared<op::v1::Softmax>(node->input(0).get_source_output(), axes.to_vector()[0]);
replace_node(node, replacement_node);
modified = true;
break;
}
default: break;
}
#if defined(__clang__)
#pragma clang diagnostic pop
#endif
return modified;
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/pass/pass.hpp"
namespace ngraph
{
namespace pass
{
class Opset1Upgrade : public NodePass
{
public:
///
/// \brief Constructor for the Opset 1 transformation pass.
///
/// \details This transformation pass iterates over all nodes in a graph
/// and updates opset version 0 ops to their opset version 1 equivalents.
/// All ops in the final graph have opset version 1.
Opset1Upgrade() = default;
bool run_on_node(std::shared_ptr<ngraph::Node> node) override;
};
}
}
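A minimal usage sketch, mirroring the unit tests added below: register the pass on a pass::Manager and run it over a Function built from opset-0 ops (the helper name is illustrative).
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/opset1_upgrade.hpp"

using namespace ngraph;

void upgrade_to_opset1(const std::shared_ptr<Function>& f)
{
    pass::Manager pass_manager;
    pass_manager.register_pass<pass::Opset1Upgrade>();
    // Replaces each convertible opset-0 node (currently Softmax) with its
    // opset-1 equivalent; nodes already at version 1 are left untouched.
    pass_manager.run_passes(f);
}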
......@@ -142,7 +142,6 @@ namespace ngraph
class SigmoidBackprop;
class SigmoidMultiply;
class SigmoidMultiplyBackprop;
class Softmax;
class Result;
class And;
class Or;
......
......@@ -243,6 +243,16 @@ private:
{
const Node& node = *node_wrapper.get_node();
size_t op_version = node.get_version();
bool is_op_version_supported = op_version == 0;
NGRAPH_CHECK(is_op_version_supported,
"Unsupported operator version ",
op_version,
" in ",
node,
".\n",
"INTERPRETER backend currently only supports op in version 0.");
// We want to check that every OP_TYPEID enumeration is included in the list.
// These GCC flags enable compile-time checking so that if an enumeration
// is not in the list an error is generated.
......
......@@ -554,7 +554,7 @@ json JSONSerializer::serialize_function(const Function& f)
template <typename T>
T get_value(json js, const string& key)
{
T rc;
T rc = {};
auto it = js.find(key);
if (it != js.end())
{
......@@ -719,15 +719,18 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
string node_name = node_js.at("name").get<string>();
string node_op = node_js.at("op").get<string>();
string friendly_name = get_value<string>(node_js, "friendly_name");
size_t op_version = get_value<size_t>(node_js, "op_version");
vector<json> control_deps_inputs = get_value<vector<json>>(node_js, "control_deps");
vector<string> node_outputs = get_value<vector<string>>(node_js, "outputs");
OutputVectorHelper args(deserialize_output_vector(node_js["inputs"]));
#if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
#pragma GCC diagnostic push
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
// #pragma GCC diagnostic error "-Wimplicit-fallthrough"
#endif
switch (get_typeid(node_op))
{
case OP_TYPEID::Abs:
......@@ -1831,8 +1834,16 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
}
case OP_TYPEID::Softmax:
{
auto softmax_axes = deserialize_axis_set(node_js.at("softmax_axes"));
node = make_shared<op::Softmax>(args[0], softmax_axes);
if (op_version == 0)
{
auto softmax_axes = deserialize_axis_set(node_js.at("softmax_axes"));
node = make_shared<op::Softmax>(args[0], softmax_axes);
}
if (op_version == 1)
{
size_t softmax_axis = node_js.at("softmax_axis");
node = make_shared<op::v1::Softmax>(args[0], softmax_axis);
}
break;
}
case OP_TYPEID::SpaceToDepth:
......@@ -2028,6 +2039,9 @@ json JSONSerializer::serialize_node(const Node& n)
m_nodes_serialized.insert(&n);
json node;
node["name"] = n.get_name();
auto op_version = n.get_version();
node["op_version"] = op_version;
if (n.get_name() != n.get_friendly_name())
{
node["friendly_name"] = n.get_friendly_name();
......@@ -2881,8 +2895,16 @@ json JSONSerializer::serialize_node(const Node& n)
}
case OP_TYPEID::Softmax:
{
auto tmp = dynamic_cast<const op::Softmax*>(&n);
node["softmax_axes"] = serialize_axis_set(tmp->get_axes());
if (op_version == 0)
{
auto tmp = dynamic_cast<const op::v0::Softmax*>(&n);
node["softmax_axes"] = serialize_axis_set(tmp->get_axes());
}
if (op_version == 1)
{
auto tmp = dynamic_cast<const op::v1::Softmax*>(&n);
node["softmax_axis"] = tmp->get_axis();
}
break;
}
case OP_TYPEID::Tan: { break;
......
......@@ -64,6 +64,7 @@ set(SRC
node_input_output.cpp
nop_elimination.cpp
op.cpp
opset_pass/softmax_opset_pass.cpp
partial_shape.cpp
pass.cpp
pass_liveness.cpp
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/opset1_upgrade.hpp"
#include "util/type_prop.hpp"
using namespace std;
using namespace ngraph;
TEST(serialize, opset1_softmax_pass_axis)
{
const size_t axis = 2;
const AxisSet axes{axis};
auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 3, 4});
auto softmax_s0 = make_shared<op::v0::Softmax>(arg, axes);
auto result = make_shared<op::Result>(softmax_s0);
auto f = make_shared<Function>(ResultVector{result}, ParameterVector{arg});
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<pass::Opset1Upgrade>();
pass_manager.run_passes(f);
auto softmax_s1_result = f->get_results().at(0);
auto node = softmax_s1_result->input(0).get_source_output().get_node_shared_ptr();
auto softmax_s1_node = static_pointer_cast<op::v1::Softmax>(node);
EXPECT_EQ(softmax_s1_node->get_axis(), axis);
EXPECT_EQ(softmax_s1_node->description(), "Softmax");
EXPECT_EQ(softmax_s1_node->get_version(), 1);
}
TEST(serialize, opset1_softmax_pass_axis_exception)
{
const AxisSet axes{1, 2};
auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 3, 4});
auto softmax_s0 = make_shared<op::v0::Softmax>(arg, axes);
auto result = make_shared<op::Result>(softmax_s0);
auto f = make_shared<Function>(ResultVector{result}, ParameterVector{arg});
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<pass::Opset1Upgrade>();
try
{
pass_manager.run_passes(f);
FAIL() << "Exception after Opset1Upgrade pass was not thrown.";
}
catch (const ngraph_error& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
std::string(
"Unable to convert Softmax:0 to Softmax:1 with zero or more than one axis."));
}
catch (...)
{
FAIL() << "Softmax pass failed for unexpected reason";
}
}
namespace fake_v2
{
class FakeSoftmax : public op::v0::Softmax
{
public:
FakeSoftmax(const Output<Node>& arg, const AxisSet& axes)
: Softmax{arg, axes}
{
}
size_t get_version() const override { return 2; }
};
}
TEST(serialize, opset1_softmax_pass_incorrect_op_version)
{
const AxisSet axes{2};
auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 3, 4});
auto softmax_s2 = make_shared<fake_v2::FakeSoftmax>(arg, axes);
auto result = make_shared<op::Result>(softmax_s2);
auto f = make_shared<Function>(ResultVector{result}, ParameterVector{arg});
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<pass::Opset1Upgrade>();
try
{
pass_manager.run_passes(f);
FAIL() << "Opset 1 transformation pass failed for";
}
catch (const ngraph_error& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Op version 1 transformation pass failed for"));
}
catch (...)
{
FAIL() << "Softmax pass failed for unexpected reason";
}
}
......@@ -340,3 +340,19 @@ TEST(serialize, non_zero_node_output)
EXPECT_EQ(topk_out.get_index(), 1);
EXPECT_EQ(topk_out.get_node()->description(), "TopK");
}
TEST(serialize, opset1_softmax)
{
auto arg = make_shared<op::Parameter>(element::f32, Shape{10});
auto softmax = make_shared<op::v1::Softmax>(arg, 0);
auto result = make_shared<op::Result>(softmax);
auto f = make_shared<Function>(ResultVector{result}, ParameterVector{arg});
string s = serialize(f);
shared_ptr<Function> g = deserialize(s);
auto g_result = g->get_results().at(0);
auto g_softmax = g_result->input(0).get_source_output().get_node_shared_ptr();
EXPECT_EQ(g_softmax->description(), "Softmax");
EXPECT_EQ(g_softmax->get_version(), 1);
}