Commit cef0508b authored by Pruthvi's avatar Pruthvi Committed by Scott Cyphers

use constructor_validate_and_infer_types() in CumSum ctor (#4044)

* - use construct_validate_infer_types() in CumSum ctor

* - remove unused variable
- relax rank check

* Warning
parent a9b8e213
...@@ -22,9 +22,9 @@ ...@@ -22,9 +22,9 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
constexpr NodeTypeInfo op::CumSum::type_info; constexpr NodeTypeInfo op::v0::CumSum::type_info;
op::CumSum::CumSum(const Output<Node>& arg, op::v0::CumSum::CumSum(const Output<Node>& arg,
const Output<Node>& axis, const Output<Node>& axis,
const bool exclusive, const bool exclusive,
const bool reverse) const bool reverse)
...@@ -32,29 +32,43 @@ op::CumSum::CumSum(const Output<Node>& arg, ...@@ -32,29 +32,43 @@ op::CumSum::CumSum(const Output<Node>& arg,
, m_exclusive(exclusive) , m_exclusive(exclusive)
, m_reverse(reverse) , m_reverse(reverse)
{ {
constructor_validate_and_infer_types();
}
void op::v0::CumSum::validate_and_infer_types()
{
element::Type arg_type = get_input_element_type(0);
PartialShape arg_shape = get_input_partial_shape(0);
set_output_type(0, arg_type, arg_shape);
PartialShape axes_shape{PartialShape::dynamic()};
if (get_input_partial_shape(1).is_static())
{
axes_shape = get_input_partial_shape(1);
}
const auto& axis_type = get_input_element_type(1);
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
axis.get_element_type() == element::i32 || axis_type == element::i32 || axis_type == element::i64,
axis.get_element_type() == element::i64,
"axis element type must be either int64_t or int32_t but got (", "axis element type must be either int64_t or int32_t but got (",
axis.get_element_type(), axis_type,
")."); ").");
set_output_type(0, arg.get_element_type(), arg.get_shape());
} }
shared_ptr<Node> op::CumSum::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::v0::CumSum::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
return make_shared<op::CumSum>(new_args.at(0), new_args.at(1), m_exclusive, m_reverse); return make_shared<op::CumSum>(new_args.at(0), new_args.at(1), m_exclusive, m_reverse);
} }
void op::CumSum::generate_adjoints(autodiff::Adjoints& adjoints, const OutputVector& deltas) void op::v0::CumSum::generate_adjoints(autodiff::Adjoints& adjoints, const OutputVector& deltas)
{ {
auto delta = deltas.at(0); auto delta = deltas.at(0);
auto input_tensor = input_value(0); auto input_tensor = input_value(0);
adjoints.add_delta(input_tensor, delta); adjoints.add_delta(input_tensor, delta);
} }
shared_ptr<Node> op::CumSum::get_default_value() const shared_ptr<Node> op::v0::CumSum::get_default_value() const
{ {
return ngraph::make_constant_from_string("0", get_element_type(), get_shape()); return ngraph::make_constant_from_string("0", get_element_type(), get_shape());
} }
...@@ -91,6 +91,8 @@ namespace ngraph ...@@ -91,6 +91,8 @@ namespace ngraph
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
void validate_and_infer_types() override;
/// \return The default value for CumSum. /// \return The default value for CumSum.
virtual std::shared_ptr<Node> get_default_value() const override; virtual std::shared_ptr<Node> get_default_value() const override;
bool is_exclusive() const { return m_exclusive; } bool is_exclusive() const { return m_exclusive; }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment