Unverified Commit 1d53977a authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Force AutoBroadcast defaults (#3878)

* Force AutoBroadcast to be specified at the op level since no default is correct for all ops.

* exports
parent e45c64f0
......@@ -35,7 +35,10 @@ namespace ngraph
static constexpr NodeTypeInfo type_info{"Add", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs an uninitialized addition operation
Add() = default;
Add()
: util::BinaryElementwiseArithmetic(AutoBroadcastSpec::NONE)
{
}
/// \brief Constructs an addition operation.
///
......@@ -71,7 +74,10 @@ namespace ngraph
static constexpr NodeTypeInfo type_info{"Add", 1};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs an uninitialized addition operation
Add() = default;
Add()
: util::BinaryElementwiseArithmetic(AutoBroadcastSpec::NUMPY)
{
}
/// \brief Constructs an addition operation.
///
......
......@@ -31,7 +31,10 @@ namespace ngraph
NGRAPH_API
static constexpr NodeTypeInfo type_info{"Atan2", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
Atan2() = default;
Atan2()
: util::BinaryElementwiseArithmetic(AutoBroadcastSpec::NONE)
{
}
/// \brief atan2(y,x) is the angle from the origin to the point (x,y) (note reversed
/// order).
......
......@@ -32,7 +32,10 @@ namespace ngraph
static constexpr NodeTypeInfo type_info{"Divide", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a division operation.
Divide() = default;
Divide()
: util::BinaryElementwiseArithmetic(AutoBroadcastSpec::NONE)
{
}
/// \brief Constructs a division operation.
///
/// \param arg0 Node that produces the first input tensor.
......@@ -76,7 +79,11 @@ namespace ngraph
static constexpr NodeTypeInfo type_info{"Divide", 1};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a division operation.
Divide() = default;
Divide()
: util::BinaryElementwiseArithmetic(AutoBroadcastSpec::NUMPY)
{
}
/// \brief Constructs a division operation.
///
/// \param arg0 Node that produces the first input tensor.
......
......@@ -32,7 +32,10 @@ namespace ngraph
static constexpr NodeTypeInfo type_info{"Maximum", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a maximum operation.
Maximum() = default;
Maximum()
: util::BinaryElementwiseArithmetic(AutoBroadcastSpec::NONE)
{
}
/// \brief Constructs a maximum operation.
///
/// \param arg0 Node that produces the first input tensor.
......@@ -62,7 +65,11 @@ namespace ngraph
static constexpr NodeTypeInfo type_info{"Maximum", 1};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a maximum operation.
Maximum() = default;
Maximum()
: util::BinaryElementwiseArithmetic(AutoBroadcastSpec::NUMPY)
{
}
/// \brief Constructs a maximum operation.
///
/// \param arg0 Node that produces the first input tensor.
......
......@@ -32,7 +32,10 @@ namespace ngraph
static constexpr NodeTypeInfo type_info{"Minimum", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a minimum operation.
Minimum() = default;
Minimum()
: util::BinaryElementwiseArithmetic(AutoBroadcastSpec::NONE)
{
}
/// \brief Constructs a minimum operation.
///
/// \param arg0 Node that produces the first input tensor.
......@@ -62,7 +65,11 @@ namespace ngraph
static constexpr NodeTypeInfo type_info{"Minimum", 1};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a minimum operation.
Minimum() = default;
Minimum()
: util::BinaryElementwiseArithmetic(AutoBroadcastSpec::NUMPY)
{
}
/// \brief Constructs a minimum operation.
///
/// \param arg0 Node that produces the first input tensor.
......
......@@ -32,7 +32,10 @@ namespace ngraph
static constexpr NodeTypeInfo type_info{"Multiply", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a multiplication operation.
Multiply() = default;
Multiply()
: util::BinaryElementwiseArithmetic(AutoBroadcastSpec::NONE)
{
}
/// \brief Constructs a multiplication operation.
///
/// \param arg0 Node that produces the first input tensor.
......@@ -62,7 +65,11 @@ namespace ngraph
static constexpr NodeTypeInfo type_info{"Multiply", 1};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a multiplication operation.
Multiply() = default;
Multiply()
: util::BinaryElementwiseArithmetic(AutoBroadcastSpec::NUMPY)
{
}
/// \brief Constructs a multiplication operation.
///
/// \param arg0 Node that produces the first input tensor.
......
......@@ -46,7 +46,10 @@ namespace ngraph
NGRAPH_API
static constexpr NodeTypeInfo type_info{"Power", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
Power() = default;
Power()
: util::BinaryElementwiseArithmetic(AutoBroadcastSpec::NONE)
{
}
/// \brief Constructs an exponentiation operation.
///
/// \param arg0 Node that produces the first input tensor.
......@@ -89,7 +92,11 @@ namespace ngraph
NGRAPH_API
static constexpr NodeTypeInfo type_info{"Power", 1};
const NodeTypeInfo& get_type_info() const override { return type_info; }
Power() = default;
Power()
: util::BinaryElementwiseArithmetic(AutoBroadcastSpec::NUMPY)
{
}
/// \brief Constructs an exponentiation operation.
///
/// \param arg0 Node that produces the first input tensor.
......
......@@ -36,7 +36,7 @@ shared_ptr<Node> op::Relu::copy_with_new_args(const NodeVector& new_args) const
}
op::ReluBackprop::ReluBackprop(shared_ptr<Node> arg, shared_ptr<Node> delta)
: BinaryElementwiseArithmetic(arg, delta)
: BinaryElementwiseArithmetic(arg, delta, AutoBroadcastSpec::NONE)
{
constructor_validate_and_infer_types();
}
......
......@@ -37,7 +37,7 @@ op::Sigmoid::Sigmoid(const Output<Node>& arg)
}
op::SigmoidBackprop::SigmoidBackprop(const Output<Node>& arg, const Output<Node>& delta)
: BinaryElementwiseArithmetic(arg, delta)
: BinaryElementwiseArithmetic(arg, delta, AutoBroadcastSpec::NONE)
{
constructor_validate_and_infer_types();
}
......
......@@ -47,7 +47,11 @@ namespace ngraph
NGRAPH_API
static constexpr NodeTypeInfo type_info{"SigmoidBackprop", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
SigmoidBackprop() = default;
SigmoidBackprop()
: util::BinaryElementwiseArithmetic(AutoBroadcastSpec::NONE)
{
}
/// \brief Constructs a SigmoidBackprop operation.
///
/// \param arg Node that produces the Sigmoid forward input tensor.
......
......@@ -29,7 +29,11 @@ namespace ngraph
NGRAPH_API
static constexpr NodeTypeInfo type_info{"Subtract", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
Subtract() = default;
Subtract()
: util::BinaryElementwiseArithmetic(AutoBroadcastSpec::NONE)
{
}
/// \brief Constructs a subtraction operation.
///
/// \param arg0 Node that produces the first input tensor.
......
......@@ -19,6 +19,9 @@
using namespace ngraph;
const op::AutoBroadcastSpec op::AutoBroadcastSpec::NUMPY(AutoBroadcastType::NUMPY, 0);
const op::AutoBroadcastSpec op::AutoBroadcastSpec::NONE{AutoBroadcastType::NONE, 0};
namespace ngraph
{
template <>
......
......@@ -20,6 +20,7 @@
#include <ostream>
#include "ngraph/attribute_adapter.hpp"
#include "ngraph/ngraph_visibility.hpp"
#include "ngraph/type.hpp"
namespace ngraph
......@@ -269,6 +270,11 @@ namespace ngraph
{
return a.m_type == m_type && a.m_axis == m_axis;
}
NGRAPH_API
static const AutoBroadcastSpec NUMPY;
NGRAPH_API
static const AutoBroadcastSpec NONE;
};
}
}
......@@ -19,7 +19,8 @@
using namespace std;
using namespace ngraph;
op::util::BinaryElementwiseArithmetic::BinaryElementwiseArithmetic()
op::util::BinaryElementwiseArithmetic::BinaryElementwiseArithmetic(const AutoBroadcastSpec& autob)
: m_autob(autob)
{
}
......
......@@ -54,12 +54,12 @@ namespace ngraph
class BinaryElementwiseArithmetic : public Op
{
protected:
/// \brief Constructs a binary elementwise arithmetic operation.
BinaryElementwiseArithmetic();
BinaryElementwiseArithmetic(const AutoBroadcastSpec& autob);
/// \brief Constructs a binary elementwise arithmetic operation.
BinaryElementwiseArithmetic(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1,
const AutoBroadcastSpec& autob = AutoBroadcastSpec());
const AutoBroadcastSpec& autob);
/// \brief Constructs a binary elementwise arithmetic operation.
///
......@@ -67,7 +67,7 @@ namespace ngraph
/// \param arg1 Output that produces the second input tensor.
BinaryElementwiseArithmetic(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& autob = AutoBroadcastSpec());
const AutoBroadcastSpec& autob);
/// \brief Constructs a binary elementwise arithmetic operation.
///
......@@ -77,7 +77,7 @@ namespace ngraph
BinaryElementwiseArithmetic(const std::string& node_type,
const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1,
const AutoBroadcastSpec& autob = AutoBroadcastSpec());
const AutoBroadcastSpec& autob);
public:
void validate_and_infer_types() override;
......
......@@ -22,7 +22,7 @@ using namespace ngraph;
constexpr NodeTypeInfo op::GeluBackprop::type_info;
op::GeluBackprop::GeluBackprop(const Output<ngraph::Node>& arg, const Output<ngraph::Node>& delta)
: BinaryElementwiseArithmetic(arg, delta)
: BinaryElementwiseArithmetic(arg, delta, AutoBroadcastSpec::NONE)
{
constructor_validate_and_infer_types();
set_output_size(1);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment