Commit fb3f9e95 authored by Adam Procter's avatar Adam Procter Committed by Robert Kimball

Move SigmoidBackprop to BinaryElementwiseArithmetic (#1914)

parent 08483fbd
...@@ -28,24 +28,15 @@ shared_ptr<Node> op::Sigmoid::copy_with_new_args(const NodeVector& new_args) con ...@@ -28,24 +28,15 @@ shared_ptr<Node> op::Sigmoid::copy_with_new_args(const NodeVector& new_args) con
} }
op::Sigmoid::Sigmoid(shared_ptr<Node> arg) op::Sigmoid::Sigmoid(shared_ptr<Node> arg)
: UnaryElementwiseArithmetic("Sigmoid", {arg}) : UnaryElementwiseArithmetic("Sigmoid", arg)
{ {
set_output_type(0, arg->get_element_type(), arg->get_shape());
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
op::SigmoidBackprop::SigmoidBackprop(shared_ptr<Node> arg, shared_ptr<Node> delta) op::SigmoidBackprop::SigmoidBackprop(shared_ptr<Node> arg, shared_ptr<Node> delta)
: Op("SigmoidBackprop", check_single_output_args({arg, delta})) : BinaryElementwiseArithmetic("SigmoidBackprop", arg, delta)
{ {
NODE_VALIDATION_ASSERT(this, arg->get_element_type() == delta->get_element_type()) constructor_validate_and_infer_types();
<< "Argument and delta element types do not match (argument element type: "
<< arg->get_element_type() << ", delta element type: " << delta->get_element_type() << ").";
NODE_VALIDATION_ASSERT(this, arg->get_shape() == delta->get_shape())
<< "Argument and delta shapes do not match (argument shape: " << arg->get_shape()
<< ", delta shape: " << delta->get_shape() << ").";
set_output_type(0, delta->get_element_type(), delta->get_shape());
} }
shared_ptr<Node> op::SigmoidBackprop::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::SigmoidBackprop::copy_with_new_args(const NodeVector& new_args) const
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#pragma once #pragma once
#include "ngraph/op/op.hpp" #include "ngraph/op/op.hpp"
#include "ngraph/op/util/binary_elementwise_arithmetic.hpp"
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp" #include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
...@@ -36,7 +37,7 @@ namespace ngraph ...@@ -36,7 +37,7 @@ namespace ngraph
/// \brief Elementwise SigmoidBackprop operation. /// \brief Elementwise SigmoidBackprop operation.
/// ///
class SigmoidBackprop : public Op class SigmoidBackprop : public util::BinaryElementwiseArithmetic
{ {
public: public:
/// \brief Constructs a SigmoidBackprop operation. /// \brief Constructs a SigmoidBackprop operation.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment