Commit 47a727a6 authored by baojun, committed by Scott Cyphers

Make softmax axes dynamic (#3601)

* make axes dynamic

* add ut

* update softmax deserializer

* check axes to be constant

* remove duplicates
parent 5a6cf4d0
......@@ -19,6 +19,7 @@
#include <algorithm>
#include "ngraph/builder/autobroadcast.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/subtract.hpp"
......@@ -32,28 +33,75 @@ using namespace ngraph;
constexpr NodeTypeInfo op::v0::Softmax::type_info;
// Constructs a v0 Softmax from a statically known set of axes.
//
// NOTE(review): this span is rendered from a diff and shows two initializer lists back to
// back, presumably the removed and the added version. The first (`Op({arg})` plus
// `m_axes(axes)`) keeps the axes in a member; the second wraps the axes in a one-dimensional
// i64 Constant and passes that Constant's output 0 to Op alongside `arg`.
op::v0::Softmax::Softmax(const Output<Node>& arg, const AxisSet& axes)
: Op({arg})
, m_axes(axes)
: Op({arg,
op::Constant::create(element::i64, Shape{axes.to_vector().size()}, axes.to_vector())
->output(0)})
{
// Run validation and type inference once the inputs are attached.
constructor_validate_and_infer_types();
}
// Constructs a v0 Softmax whose axes are supplied by another graph output.
//
// `arg` and `axes` are handed to the Op base in that order, after which validation and
// type inference are run on the freshly built node.
op::v0::Softmax::Softmax(const Output<Node>& arg, const Output<Node>& axes)
    : Op({arg, axes})
{
    this->constructor_validate_and_infer_types();
}
// Reports whether the node feeding input 1 (the axes) is a constant.
bool op::v0::Softmax::are_axes_constant() const
{
    const auto axes_source = input_value(1).get_node_shared_ptr();
    return axes_source->is_constant();
}
// Returns the axes held by the Constant that feeds input 1.
//
// Throws ngraph_error when the node feeding input 1 is not an op::Constant, because the
// axis set cannot be read from a non-constant input.
const AxisSet op::v0::Softmax::get_axes() const
{
    const auto axes_constant =
        dynamic_pointer_cast<op::Constant>(input_value(1).get_node_shared_ptr());
    if (!axes_constant)
    {
        throw ngraph_error("get_axes called on a Softmax node whose 'axes' input is not constant");
    }
    return axes_constant->get_axis_set_val();
}
// Replaces the source of input 1 with a new i64 Constant built from `axes`.
//
// The Constant is one-dimensional, with one element per axis in `axes`.
void op::v0::Softmax::set_axes(const AxisSet& axes)
{
    const auto axis_values = axes.to_vector();
    const auto axes_constant =
        op::Constant::create(element::i64, Shape{axis_values.size()}, axis_values);
    this->input(1).replace_source_output(axes_constant->output(0));
}
void op::v0::Softmax::validate_and_infer_types()
{
const PartialShape& input_shape = get_input_partial_shape(0);
NODE_VALIDATION_CHECK(this,
input_shape.rank().is_static(),
"Input node rank must be static (input_shape=",
input_shape,
").");
if (input_shape.is_dynamic())
{
set_output_type(0, get_input_element_type(0), input_shape);
}
else
{
set_output_type(0, get_input_element_type(0), input_shape.to_shape());
if (are_axes_constant())
{
auto m_axes = get_axes();
for (auto axis : m_axes)
{
NODE_VALIDATION_CHECK(this,
axis < static_cast<size_t>(input_shape.rank()),
axis >= 0 && axis < static_cast<size_t>(input_shape.rank()),
"Reduction axis (",
axis,
") is out of bounds (argument shape: ",
input_shape,
").");
}
// empty axes == all axes
if (m_axes.size() == 0)
{
......@@ -61,39 +109,33 @@ op::v0::Softmax::Softmax(const Output<Node>& arg, const AxisSet& axes)
{
m_axes.insert(i);
}
set_axes(m_axes);
}
}
void op::v0::Softmax::validate_and_infer_types()
{
const PartialShape& input_shape = get_input_partial_shape(0);
if (input_shape.is_static())
{
set_output_type(0, get_input_element_type(0), input_shape.to_shape());
}
else
{
set_output_type(0, get_input_element_type(0), PartialShape::dynamic());
}
set_input_is_relevant_to_shape(1);
}
// Creates a new Softmax node built from the inputs in `new_args`.
shared_ptr<Node> op::v0::Softmax::copy_with_new_args(const NodeVector& new_args) const
{
// Presumably verifies that new_args has the number of inputs this node expects before
// the elements are accessed below -- confirm against check_new_args_count.
check_new_args_count(this, new_args);
// NOTE(review): two consecutive returns. This text is rendered from a diff, so these are
// presumably the removed (m_axes) and added (new_args.at(1)) versions of the same line.
// As literally written, the second return is unreachable.
return make_shared<Softmax>(new_args.at(0), m_axes);
return make_shared<Softmax>(new_args.at(0), new_args.at(1));
}
void op::v0::Softmax::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
{
auto delta = deltas.at(0);
NGRAPH_CHECK(are_axes_constant(), "axes need to be constant");
auto axes = get_axes();
auto z = delta * shared_from_this();
auto zsum = make_shared<op::Sum>(z, m_axes);
auto zsum = make_shared<op::Sum>(z, axes);
Shape shape;
for (size_t i = 0; i < get_shape().size(); ++i)
{
if (m_axes.find(i) == m_axes.end())
if (axes.find(i) == axes.end())
{
shape.push_back(get_shape()[i]);
}
......
......@@ -42,20 +42,29 @@ namespace ngraph
/// Output `[d0, ...]`
///
Softmax(const Output<Node>& arg, const AxisSet& axes);
/// \brief Constructs a softmax operation.
///
/// \param arg Node that produces the first input tensor.<br>
/// `[d0, ...]`
/// \param axes Node that produces the axis positions (0-based) on which to calculate the
/// softmax.
///
/// Output `[d0, ...]`
///
Softmax(const Output<Node>& arg, const Output<Node>& axes);
void validate_and_infer_types() override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
const AxisSet& get_axes() const { return m_axes; }
void set_axes(const AxisSet& axes) { m_axes = axes; }
bool are_axes_constant() const;
const AxisSet get_axes() const;
void set_axes(const AxisSet& axes);
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
private:
AxisSet m_axes;
};
}
......
......@@ -367,6 +367,10 @@ bool pass::Opset1Upgrade::run_on_node(shared_ptr<Node> node)
case OP_TYPEID::Softmax:
{
auto tmp = dynamic_cast<const op::v0::Softmax*>(node.get());
NGRAPH_CHECK(node->input_value(1).get_node_shared_ptr()->is_constant(),
"axes parameter is expected to be a static constant");
AxisSet axes = tmp->get_axes();
NGRAPH_CHECK(
......
......@@ -2067,10 +2067,17 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
case OP_TYPEID::Softmax:
{
if (op_version == 0)
{
if (has_key(node_js, "softmax_axes"))
{
auto softmax_axes = deserialize_axis_set(node_js.at("softmax_axes"));
node = make_shared<op::Softmax>(args[0], softmax_axes);
}
else
{
node = make_shared<op::Softmax>(args[0], args[1]);
}
}
if (op_version == 1)
{
size_t softmax_axis = node_js.at("softmax_axis");
......@@ -3320,8 +3327,7 @@ json JSONSerializer::serialize_node(const Node& n)
{
if (op_version == 0)
{
auto tmp = static_cast<const op::v0::Softmax*>(&n);
node["softmax_axes"] = serialize_axis_set(tmp->get_axes());
break;
}
if (op_version == 1)
{
......
......@@ -38,6 +38,31 @@ using namespace ngraph;
static string s_manifest = "${MANIFEST}";
// Softmax whose axes arrive at run time through a second, i64 parameter.
//
// The axes tensor holds {0, 1}, i.e. both dimensions of the 2x3 input, so every expected
// value is exp(x) divided by the sum of exp over all six input elements.
NGRAPH_TEST(${BACKEND_NAME}, softmax_dynamic_axes)
{
    Shape data_shape{2, 3};
    Shape axes_shape{2};

    // Graph: Softmax(data, axes) with both operands as parameters.
    auto data_param = make_shared<op::Parameter>(element::f32, data_shape);
    auto axes_param = make_shared<op::Parameter>(element::i64, axes_shape);
    auto softmax = make_shared<op::Softmax>(data_param, axes_param);
    auto func = make_shared<Function>(softmax, ParameterVector{data_param, axes_param});

    // The second argument to Backend::create is passed as `true`, as in the original test.
    auto backend = runtime::Backend::create("${BACKEND_NAME}", true);

    // Input and output tensors.
    auto data_tensor = backend->create_tensor(element::f32, data_shape);
    auto axes_tensor = backend->create_tensor(element::i64, axes_shape);
    copy_data(data_tensor, vector<float>{-3, -2, -1, 0, 1, 2});
    copy_data(axes_tensor, vector<int64_t>{0, 1});
    auto output_tensor = backend->create_tensor(element::f32, data_shape);

    auto executable = backend->compile(func);
    executable->call_with_validate({output_tensor}, {data_tensor, axes_tensor});

    // Reference result: normalise by the sum of exponentials of all inputs.
    auto denom = expf(-3) + expf(-2) + expf(-1) + expf(0) + expf(1) + expf(2);
    vector<float> expected{expf(-3) / denom,
                           expf(-2) / denom,
                           expf(-1) / denom,
                           expf(0) / denom,
                           expf(1) / denom,
                           expf(2) / denom};
    EXPECT_TRUE(test::all_close_f(expected, read_vector<float>(output_tensor)));
}
NGRAPH_TEST(${BACKEND_NAME}, softmax_all)
{
Shape shape{2, 3};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment