Commit fb3f9e95 authored by Adam Procter's avatar Adam Procter Committed by Robert Kimball

Move SigmoidBackprop to BinaryElementwiseArithmetic (#1914)

parent 08483fbd
......@@ -28,24 +28,15 @@ shared_ptr<Node> op::Sigmoid::copy_with_new_args(const NodeVector& new_args) con
}
op::Sigmoid::Sigmoid(shared_ptr<Node> arg)
: UnaryElementwiseArithmetic("Sigmoid", {arg})
: UnaryElementwiseArithmetic("Sigmoid", arg)
{
set_output_type(0, arg->get_element_type(), arg->get_shape());
constructor_validate_and_infer_types();
}
op::SigmoidBackprop::SigmoidBackprop(shared_ptr<Node> arg, shared_ptr<Node> delta)
: Op("SigmoidBackprop", check_single_output_args({arg, delta}))
: BinaryElementwiseArithmetic("SigmoidBackprop", arg, delta)
{
NODE_VALIDATION_ASSERT(this, arg->get_element_type() == delta->get_element_type())
<< "Argument and delta element types do not match (argument element type: "
<< arg->get_element_type() << ", delta element type: " << delta->get_element_type() << ").";
NODE_VALIDATION_ASSERT(this, arg->get_shape() == delta->get_shape())
<< "Argument and delta shapes do not match (argument shape: " << arg->get_shape()
<< ", delta shape: " << delta->get_shape() << ").";
set_output_type(0, delta->get_element_type(), delta->get_shape());
constructor_validate_and_infer_types();
}
shared_ptr<Node> op::SigmoidBackprop::copy_with_new_args(const NodeVector& new_args) const
......
......@@ -17,6 +17,7 @@
#pragma once
#include "ngraph/op/op.hpp"
#include "ngraph/op/util/binary_elementwise_arithmetic.hpp"
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
#include "ngraph/util.hpp"
......@@ -36,7 +37,7 @@ namespace ngraph
/// \brief Elementwise SigmoidBackprop operation.
///
class SigmoidBackprop : public Op
class SigmoidBackprop : public util::BinaryElementwiseArithmetic
{
public:
/// \brief Constructs a SigmoidBackprop operation.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment