Commit cef0508b authored by Pruthvi's avatar Pruthvi Committed by Scott Cyphers

use constructor_validate_and_infer_types() in CumSum ctor (#4044)

* - use construct_validate_infer_types() in CumSum ctor

* - remove unused variable
- relax rank check

* Warning
parent a9b8e213
......@@ -22,39 +22,53 @@
using namespace std;
using namespace ngraph;
constexpr NodeTypeInfo op::CumSum::type_info;
constexpr NodeTypeInfo op::v0::CumSum::type_info;
op::CumSum::CumSum(const Output<Node>& arg,
const Output<Node>& axis,
const bool exclusive,
const bool reverse)
op::v0::CumSum::CumSum(const Output<Node>& arg,
const Output<Node>& axis,
const bool exclusive,
const bool reverse)
: Op({arg, axis})
, m_exclusive(exclusive)
, m_reverse(reverse)
{
constructor_validate_and_infer_types();
}
void op::v0::CumSum::validate_and_infer_types()
{
element::Type arg_type = get_input_element_type(0);
PartialShape arg_shape = get_input_partial_shape(0);
set_output_type(0, arg_type, arg_shape);
PartialShape axes_shape{PartialShape::dynamic()};
if (get_input_partial_shape(1).is_static())
{
axes_shape = get_input_partial_shape(1);
}
const auto& axis_type = get_input_element_type(1);
NODE_VALIDATION_CHECK(this,
axis.get_element_type() == element::i32 ||
axis.get_element_type() == element::i64,
axis_type == element::i32 || axis_type == element::i64,
"axis element type must be either int64_t or int32_t but got (",
axis.get_element_type(),
axis_type,
").");
set_output_type(0, arg.get_element_type(), arg.get_shape());
}
shared_ptr<Node> op::CumSum::copy_with_new_args(const NodeVector& new_args) const
shared_ptr<Node> op::v0::CumSum::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<op::CumSum>(new_args.at(0), new_args.at(1), m_exclusive, m_reverse);
}
void op::CumSum::generate_adjoints(autodiff::Adjoints& adjoints, const OutputVector& deltas)
void op::v0::CumSum::generate_adjoints(autodiff::Adjoints& adjoints, const OutputVector& deltas)
{
auto delta = deltas.at(0);
auto input_tensor = input_value(0);
adjoints.add_delta(input_tensor, delta);
}
shared_ptr<Node> op::CumSum::get_default_value() const
shared_ptr<Node> op::v0::CumSum::get_default_value() const
{
return ngraph::make_constant_from_string("0", get_element_type(), get_shape());
}
......@@ -91,6 +91,8 @@ namespace ngraph
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
void validate_and_infer_types() override;
/// \return The default value for CumSum.
virtual std::shared_ptr<Node> get_default_value() const override;
bool is_exclusive() const { return m_exclusive; }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment