Commit 47a727a6 authored by baojun, committed by Scott Cyphers

Make softmax axes dynamic (#3601)

* make axes dynamic

* add ut

* update softmax deserializer

* check axes to be constant

* remove duplicates
parent 5a6cf4d0
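
For orientation: the change removes the stored AxisSet member from op::v0::Softmax and instead passes the axes as a second graph input (wrapped into an i64 Constant when the caller still supplies an AxisSet). Below is a minimal sketch of the two construction paths, modeled on the unit test added at the end of this diff; the variable names and shapes are illustrative only, and the usual `using namespace std; using namespace ngraph;` from the test files is assumed:

    // Static axes (existing API): the AxisSet overload wraps the axes
    // into an i64 Constant that becomes input 1 of the op.
    auto data = make_shared<op::Parameter>(element::f32, Shape{2, 3});
    auto sm_static = make_shared<op::Softmax>(data, AxisSet{1});

    // Dynamic axes (new API): the axes arrive as an ordinary graph input and
    // may only be known at call time; get_axes() throws until they are constant.
    auto axes = make_shared<op::Parameter>(element::i64, Shape{1});
    auto sm_dynamic = make_shared<op::Softmax>(data, axes);
    auto f = make_shared<Function>(sm_dynamic, ParameterVector{data, axes});
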
@@ -19,6 +19,7 @@
 #include <algorithm>
 #include "ngraph/builder/autobroadcast.hpp"
+#include "ngraph/op/constant.hpp"
 #include "ngraph/op/multiply.hpp"
 #include "ngraph/op/reshape.hpp"
 #include "ngraph/op/subtract.hpp"
@@ -32,68 +33,109 @@ using namespace ngraph;

 constexpr NodeTypeInfo op::v0::Softmax::type_info;

 op::v0::Softmax::Softmax(const Output<Node>& arg, const AxisSet& axes)
-    : Op({arg})
-    , m_axes(axes)
+    : Op({arg,
+          op::Constant::create(element::i64, Shape{axes.to_vector().size()}, axes.to_vector())
+              ->output(0)})
 {
     constructor_validate_and_infer_types();
+}

-    const PartialShape& input_shape = get_input_partial_shape(0);
-    NODE_VALIDATION_CHECK(this,
-                          input_shape.rank().is_static(),
-                          "Input node rank must be static (input_shape=",
-                          input_shape,
-                          ").");
-    for (auto axis : m_axes)
-    {
-        NODE_VALIDATION_CHECK(this,
-                              axis < static_cast<size_t>(input_shape.rank()),
-                              "Reduction axis (",
-                              axis,
-                              ") is out of bounds (argument shape: ",
-                              input_shape,
-                              ").");
-    }
-
-    // empty axes == all axes
-    if (m_axes.size() == 0)
-    {
-        for (size_t i = 0; i < get_shape().size(); ++i)
-        {
-            m_axes.insert(i);
-        }
-    }
+op::v0::Softmax::Softmax(const Output<Node>& arg, const Output<Node>& axes)
+    : Op({arg, axes})
+{
+    constructor_validate_and_infer_types();
+}
+
+bool op::v0::Softmax::are_axes_constant() const
+{
+    return input_value(1).get_node_shared_ptr()->is_constant();
+}
+
+const AxisSet op::v0::Softmax::get_axes() const
+{
+    AxisSet axes;
+    auto const_op = dynamic_pointer_cast<op::Constant>(input_value(1).get_node_shared_ptr());
+    if (const_op)
+    {
+        axes = const_op->get_axis_set_val();
+    }
+    else
+    {
+        throw ngraph_error("get_axes called on a Softmax node whose 'axes' input is not constant");
+    }
+    return axes;
+}
+
+void op::v0::Softmax::set_axes(const AxisSet& axes)
+{
+    this->input(1).replace_source_output(
+        op::Constant::create(element::i64, Shape{axes.to_vector().size()}, axes.to_vector())
+            ->output(0));
 }

 void op::v0::Softmax::validate_and_infer_types()
 {
     const PartialShape& input_shape = get_input_partial_shape(0);
-    if (input_shape.is_static())
+
+    NODE_VALIDATION_CHECK(this,
+                          input_shape.rank().is_static(),
+                          "Input node rank must be static (input_shape=",
+                          input_shape,
+                          ").");
+
+    if (input_shape.is_dynamic())
     {
-        set_output_type(0, get_input_element_type(0), input_shape.to_shape());
+        set_output_type(0, get_input_element_type(0), input_shape);
     }
     else
     {
-        set_output_type(0, get_input_element_type(0), PartialShape::dynamic());
+        set_output_type(0, get_input_element_type(0), input_shape.to_shape());
+
+        if (are_axes_constant())
+        {
+            auto m_axes = get_axes();
+            for (auto axis : m_axes)
+            {
+                NODE_VALIDATION_CHECK(this,
+                                      axis >= 0 && axis < static_cast<size_t>(input_shape.rank()),
+                                      "Reduction axis (",
+                                      axis,
+                                      ") is out of bounds (argument shape: ",
+                                      input_shape,
+                                      ").");
+            }
+            // empty axes == all axes
+            if (m_axes.size() == 0)
+            {
+                for (size_t i = 0; i < get_shape().size(); ++i)
+                {
+                    m_axes.insert(i);
+                }
+                set_axes(m_axes);
+            }
+        }
     }
+
+    set_input_is_relevant_to_shape(1);
 }

 shared_ptr<Node> op::v0::Softmax::copy_with_new_args(const NodeVector& new_args) const
 {
     check_new_args_count(this, new_args);
-    return make_shared<Softmax>(new_args.at(0), m_axes);
+    return make_shared<Softmax>(new_args.at(0), new_args.at(1));
 }

 void op::v0::Softmax::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
 {
     auto delta = deltas.at(0);

+    NGRAPH_CHECK(are_axes_constant(), "axes need to be constant");
+    auto axes = get_axes();
+
     auto z = delta * shared_from_this();
-    auto zsum = make_shared<op::Sum>(z, m_axes);
+    auto zsum = make_shared<op::Sum>(z, axes);

     Shape shape;
     for (size_t i = 0; i < get_shape().size(); ++i)
     {
-        if (m_axes.find(i) == m_axes.end())
+        if (axes.find(i) == axes.end())
         {
             shape.push_back(get_shape()[i]);
         }
...
@@ -42,20 +42,29 @@ namespace ngraph
                 /// Output `[d0, ...]`
                 ///
                 Softmax(const Output<Node>& arg, const AxisSet& axes);
+                /// \brief Constructs a softmax operation.
+                ///
+                /// \param arg Node that produces the first input tensor.<br>
+                /// `[d0, ...]`
+                /// \param axes node produces the axis positions (0-based) on which to calculate the
+                /// softmax.
+                ///
+                /// Output `[d0, ...]`
+                ///
+                Softmax(const Output<Node>& arg, const Output<Node>& axes);

                 void validate_and_infer_types() override;

                 virtual std::shared_ptr<Node>
                     copy_with_new_args(const NodeVector& new_args) const override;

-                const AxisSet& get_axes() const { return m_axes; }
-                void set_axes(const AxisSet& axes) { m_axes = axes; }
+                bool are_axes_constant() const;
+                const AxisSet get_axes() const;
+                void set_axes(const AxisSet& axes);
+
             protected:
                 virtual void generate_adjoints(autodiff::Adjoints& adjoints,
                                                const NodeVector& deltas) override;
-
-            private:
-                AxisSet m_axes;
             };
         }
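
Because the axes are now an input rather than a stored member, they can only be read back when input 1 is an op::Constant; generate_adjoints above and the Opset1Upgrade pass below both guard on this before calling get_axes(). A minimal sketch of that guard as a free helper; the helper name and its placement are illustrative, not part of the commit:

    #include "ngraph/check.hpp"
    #include "ngraph/op/softmax.hpp"

    using namespace ngraph;

    // Illustrative helper: read the softmax axes only when they are statically
    // known, mirroring the NGRAPH_CHECK pattern used in generate_adjoints.
    static AxisSet require_constant_axes(const op::v0::Softmax& softmax)
    {
        NGRAPH_CHECK(softmax.are_axes_constant(), "Softmax axes must be constant here");
        return softmax.get_axes();
    }
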
@@ -367,6 +367,10 @@ bool pass::Opset1Upgrade::run_on_node(shared_ptr<Node> node)
     case OP_TYPEID::Softmax:
     {
         auto tmp = dynamic_cast<const op::v0::Softmax*>(node.get());
+
+        NGRAPH_CHECK(node->input_value(1).get_node_shared_ptr()->is_constant(),
+                     "axes parameter is expected to be a static constant");
+
         AxisSet axes = tmp->get_axes();
         NGRAPH_CHECK(
...
@@ -2068,8 +2068,15 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
     {
         if (op_version == 0)
         {
-            auto softmax_axes = deserialize_axis_set(node_js.at("softmax_axes"));
-            node = make_shared<op::Softmax>(args[0], softmax_axes);
+            if (has_key(node_js, "softmax_axes"))
+            {
+                auto softmax_axes = deserialize_axis_set(node_js.at("softmax_axes"));
+                node = make_shared<op::Softmax>(args[0], softmax_axes);
+            }
+            else
+            {
+                node = make_shared<op::Softmax>(args[0], args[1]);
+            }
         }
         if (op_version == 1)
         {
@@ -3320,8 +3327,7 @@ json JSONSerializer::serialize_node(const Node& n)
     {
         if (op_version == 0)
         {
-            auto tmp = static_cast<const op::v0::Softmax*>(&n);
-            node["softmax_axes"] = serialize_axis_set(tmp->get_axes());
+            break;
         }
         if (op_version == 1)
         {
@@ -38,6 +38,31 @@ using namespace ngraph;

 static string s_manifest = "${MANIFEST}";

+NGRAPH_TEST(${BACKEND_NAME}, softmax_dynamic_axes)
+{
+    Shape shape_A{2, 3};
+    Shape shape_B{2};
+    auto A = make_shared<op::Parameter>(element::f32, shape_A);
+    auto B = make_shared<op::Parameter>(element::i64, shape_B);
+    auto f = make_shared<Function>(make_shared<op::Softmax>(A, B), ParameterVector{A, B});
+
+    auto backend = runtime::Backend::create("${BACKEND_NAME}", true);
+
+    auto a = backend->create_tensor(element::f32, shape_A);
+    auto b = backend->create_tensor(element::i64, shape_B);
+    copy_data(a, vector<float>{-3, -2, -1, 0, 1, 2});
+    copy_data(b, vector<int64_t>{0, 1});
+    auto result = backend->create_tensor(element::f32, shape_A);
+
+    auto d = expf(-3) + expf(-2) + expf(-1) + expf(0) + expf(1) + expf(2);
+
+    auto handle = backend->compile(f);
+    handle->call_with_validate({result}, {a, b});
+    vector<float> expected{
+        expf(-3) / d, expf(-2) / d, expf(-1) / d, expf(0) / d, expf(1) / d, expf(2) / d};
+    EXPECT_TRUE(test::all_close_f(expected, read_vector<float>(result)));
+}
+
 NGRAPH_TEST(${BACKEND_NAME}, softmax_all)
 {
     Shape shape{2, 3};
...