Commit 15c99fe5 authored by Chris Sullivan's avatar Chris Sullivan Committed by Scott Cyphers

Add Softmax::validate_and_infer_types (#3749)

parent 5c153f46
...@@ -53,14 +53,6 @@ op::v0::Softmax::Softmax(const Output<Node>& arg, const AxisSet& axes) ...@@ -53,14 +53,6 @@ op::v0::Softmax::Softmax(const Output<Node>& arg, const AxisSet& axes)
input_shape, input_shape,
")."); ").");
} }
if (input_shape.is_static())
{
set_output_type(0, get_input_element_type(0), input_shape.to_shape());
}
else
{
set_output_type(0, get_input_element_type(0), PartialShape::dynamic());
}
// empty axes == all axes // empty axes == all axes
if (m_axes.size() == 0) if (m_axes.size() == 0)
...@@ -72,6 +64,19 @@ op::v0::Softmax::Softmax(const Output<Node>& arg, const AxisSet& axes) ...@@ -72,6 +64,19 @@ op::v0::Softmax::Softmax(const Output<Node>& arg, const AxisSet& axes)
} }
} }
void op::v0::Softmax::validate_and_infer_types()
{
const PartialShape& input_shape = get_input_partial_shape(0);
if (input_shape.is_static())
{
set_output_type(0, get_input_element_type(0), input_shape.to_shape());
}
else
{
set_output_type(0, get_input_element_type(0), PartialShape::dynamic());
}
}
shared_ptr<Node> op::v0::Softmax::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::v0::Softmax::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
...@@ -128,7 +133,11 @@ op::v1::Softmax::Softmax(const Output<Node>& arg, const size_t axis) ...@@ -128,7 +133,11 @@ op::v1::Softmax::Softmax(const Output<Node>& arg, const size_t axis)
") is out of bounds (argument shape: ", ") is out of bounds (argument shape: ",
input_shape, input_shape,
")."); ").");
}
void op::v1::Softmax::validate_and_infer_types()
{
const PartialShape& input_shape = get_input_partial_shape(0);
if (input_shape.is_static()) if (input_shape.is_static())
set_output_type(0, get_input_element_type(0), input_shape.to_shape()); set_output_type(0, get_input_element_type(0), input_shape.to_shape());
else else
......
...@@ -43,6 +43,8 @@ namespace ngraph ...@@ -43,6 +43,8 @@ namespace ngraph
/// ///
Softmax(const Output<Node>& arg, const AxisSet& axes); Softmax(const Output<Node>& arg, const AxisSet& axes);
void validate_and_infer_types() override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
...@@ -79,6 +81,8 @@ namespace ngraph ...@@ -79,6 +81,8 @@ namespace ngraph
/// ///
Softmax(const Output<Node>& arg, const size_t axis); Softmax(const Output<Node>& arg, const size_t axis);
void validate_and_infer_types() override;
size_t get_version() const override { return 1; } size_t get_version() const override { return 1; }
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment