Unverified Commit 1d53977a authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Force AutoBroadcast defaults (#3878)

* Force AutoBroadcast to be specified at the op level since no default is correct for all ops.

* exports
parent e45c64f0
...@@ -35,7 +35,10 @@ namespace ngraph ...@@ -35,7 +35,10 @@ namespace ngraph
static constexpr NodeTypeInfo type_info{"Add", 0}; static constexpr NodeTypeInfo type_info{"Add", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; } const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs an uninitialized addition operation /// \brief Constructs an uninitialized addition operation
Add() = default; Add()
: util::BinaryElementwiseArithmetic(AutoBroadcastSpec::NONE)
{
}
/// \brief Constructs an addition operation. /// \brief Constructs an addition operation.
/// ///
...@@ -71,7 +74,10 @@ namespace ngraph ...@@ -71,7 +74,10 @@ namespace ngraph
static constexpr NodeTypeInfo type_info{"Add", 1}; static constexpr NodeTypeInfo type_info{"Add", 1};
const NodeTypeInfo& get_type_info() const override { return type_info; } const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs an uninitialized addition operation /// \brief Constructs an uninitialized addition operation
Add() = default; Add()
: util::BinaryElementwiseArithmetic(AutoBroadcastSpec::NUMPY)
{
}
/// \brief Constructs an addition operation. /// \brief Constructs an addition operation.
/// ///
......
...@@ -31,7 +31,10 @@ namespace ngraph ...@@ -31,7 +31,10 @@ namespace ngraph
NGRAPH_API NGRAPH_API
static constexpr NodeTypeInfo type_info{"Atan2", 0}; static constexpr NodeTypeInfo type_info{"Atan2", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; } const NodeTypeInfo& get_type_info() const override { return type_info; }
Atan2() = default; Atan2()
: util::BinaryElementwiseArithmetic(AutoBroadcastSpec::NONE)
{
}
/// \brief atan2(y,x) is the angle from the origin to the point (x,y) (note reversed /// \brief atan2(y,x) is the angle from the origin to the point (x,y) (note reversed
/// order). /// order).
......
...@@ -32,7 +32,10 @@ namespace ngraph ...@@ -32,7 +32,10 @@ namespace ngraph
static constexpr NodeTypeInfo type_info{"Divide", 0}; static constexpr NodeTypeInfo type_info{"Divide", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; } const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a division operation. /// \brief Constructs a division operation.
Divide() = default; Divide()
: util::BinaryElementwiseArithmetic(AutoBroadcastSpec::NONE)
{
}
/// \brief Constructs a division operation. /// \brief Constructs a division operation.
/// ///
/// \param arg0 Node that produces the first input tensor. /// \param arg0 Node that produces the first input tensor.
...@@ -76,7 +79,11 @@ namespace ngraph ...@@ -76,7 +79,11 @@ namespace ngraph
static constexpr NodeTypeInfo type_info{"Divide", 1}; static constexpr NodeTypeInfo type_info{"Divide", 1};
const NodeTypeInfo& get_type_info() const override { return type_info; } const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a division operation. /// \brief Constructs a division operation.
Divide() = default; Divide()
: util::BinaryElementwiseArithmetic(AutoBroadcastSpec::NUMPY)
{
}
/// \brief Constructs a division operation. /// \brief Constructs a division operation.
/// ///
/// \param arg0 Node that produces the first input tensor. /// \param arg0 Node that produces the first input tensor.
......
...@@ -32,7 +32,10 @@ namespace ngraph ...@@ -32,7 +32,10 @@ namespace ngraph
static constexpr NodeTypeInfo type_info{"Maximum", 0}; static constexpr NodeTypeInfo type_info{"Maximum", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; } const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a maximum operation. /// \brief Constructs a maximum operation.
Maximum() = default; Maximum()
: util::BinaryElementwiseArithmetic(AutoBroadcastSpec::NONE)
{
}
/// \brief Constructs a maximum operation. /// \brief Constructs a maximum operation.
/// ///
/// \param arg0 Node that produces the first input tensor. /// \param arg0 Node that produces the first input tensor.
...@@ -62,7 +65,11 @@ namespace ngraph ...@@ -62,7 +65,11 @@ namespace ngraph
static constexpr NodeTypeInfo type_info{"Maximum", 1}; static constexpr NodeTypeInfo type_info{"Maximum", 1};
const NodeTypeInfo& get_type_info() const override { return type_info; } const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a maximum operation. /// \brief Constructs a maximum operation.
Maximum() = default; Maximum()
: util::BinaryElementwiseArithmetic(AutoBroadcastSpec::NUMPY)
{
}
/// \brief Constructs a maximum operation. /// \brief Constructs a maximum operation.
/// ///
/// \param arg0 Node that produces the first input tensor. /// \param arg0 Node that produces the first input tensor.
......
...@@ -32,7 +32,10 @@ namespace ngraph ...@@ -32,7 +32,10 @@ namespace ngraph
static constexpr NodeTypeInfo type_info{"Minimum", 0}; static constexpr NodeTypeInfo type_info{"Minimum", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; } const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a minimum operation. /// \brief Constructs a minimum operation.
Minimum() = default; Minimum()
: util::BinaryElementwiseArithmetic(AutoBroadcastSpec::NONE)
{
}
/// \brief Constructs a minimum operation. /// \brief Constructs a minimum operation.
/// ///
/// \param arg0 Node that produces the first input tensor. /// \param arg0 Node that produces the first input tensor.
...@@ -62,7 +65,11 @@ namespace ngraph ...@@ -62,7 +65,11 @@ namespace ngraph
static constexpr NodeTypeInfo type_info{"Minimum", 1}; static constexpr NodeTypeInfo type_info{"Minimum", 1};
const NodeTypeInfo& get_type_info() const override { return type_info; } const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a minimum operation. /// \brief Constructs a minimum operation.
Minimum() = default; Minimum()
: util::BinaryElementwiseArithmetic(AutoBroadcastSpec::NUMPY)
{
}
/// \brief Constructs a minimum operation. /// \brief Constructs a minimum operation.
/// ///
/// \param arg0 Node that produces the first input tensor. /// \param arg0 Node that produces the first input tensor.
......
...@@ -32,7 +32,10 @@ namespace ngraph ...@@ -32,7 +32,10 @@ namespace ngraph
static constexpr NodeTypeInfo type_info{"Multiply", 0}; static constexpr NodeTypeInfo type_info{"Multiply", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; } const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a multiplication operation. /// \brief Constructs a multiplication operation.
Multiply() = default; Multiply()
: util::BinaryElementwiseArithmetic(AutoBroadcastSpec::NONE)
{
}
/// \brief Constructs a multiplication operation. /// \brief Constructs a multiplication operation.
/// ///
/// \param arg0 Node that produces the first input tensor. /// \param arg0 Node that produces the first input tensor.
...@@ -62,7 +65,11 @@ namespace ngraph ...@@ -62,7 +65,11 @@ namespace ngraph
static constexpr NodeTypeInfo type_info{"Multiply", 1}; static constexpr NodeTypeInfo type_info{"Multiply", 1};
const NodeTypeInfo& get_type_info() const override { return type_info; } const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a multiplication operation. /// \brief Constructs a multiplication operation.
Multiply() = default; Multiply()
: util::BinaryElementwiseArithmetic(AutoBroadcastSpec::NUMPY)
{
}
/// \brief Constructs a multiplication operation. /// \brief Constructs a multiplication operation.
/// ///
/// \param arg0 Node that produces the first input tensor. /// \param arg0 Node that produces the first input tensor.
......
...@@ -46,7 +46,10 @@ namespace ngraph ...@@ -46,7 +46,10 @@ namespace ngraph
NGRAPH_API NGRAPH_API
static constexpr NodeTypeInfo type_info{"Power", 0}; static constexpr NodeTypeInfo type_info{"Power", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; } const NodeTypeInfo& get_type_info() const override { return type_info; }
Power() = default; Power()
: util::BinaryElementwiseArithmetic(AutoBroadcastSpec::NONE)
{
}
/// \brief Constructs an exponentiation operation. /// \brief Constructs an exponentiation operation.
/// ///
/// \param arg0 Node that produces the first input tensor. /// \param arg0 Node that produces the first input tensor.
...@@ -89,7 +92,11 @@ namespace ngraph ...@@ -89,7 +92,11 @@ namespace ngraph
NGRAPH_API NGRAPH_API
static constexpr NodeTypeInfo type_info{"Power", 1}; static constexpr NodeTypeInfo type_info{"Power", 1};
const NodeTypeInfo& get_type_info() const override { return type_info; } const NodeTypeInfo& get_type_info() const override { return type_info; }
Power() = default; Power()
: util::BinaryElementwiseArithmetic(AutoBroadcastSpec::NUMPY)
{
}
/// \brief Constructs an exponentiation operation. /// \brief Constructs an exponentiation operation.
/// ///
/// \param arg0 Node that produces the first input tensor. /// \param arg0 Node that produces the first input tensor.
......
...@@ -36,7 +36,7 @@ shared_ptr<Node> op::Relu::copy_with_new_args(const NodeVector& new_args) const ...@@ -36,7 +36,7 @@ shared_ptr<Node> op::Relu::copy_with_new_args(const NodeVector& new_args) const
} }
op::ReluBackprop::ReluBackprop(shared_ptr<Node> arg, shared_ptr<Node> delta) op::ReluBackprop::ReluBackprop(shared_ptr<Node> arg, shared_ptr<Node> delta)
: BinaryElementwiseArithmetic(arg, delta) : BinaryElementwiseArithmetic(arg, delta, AutoBroadcastSpec::NONE)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
......
...@@ -37,7 +37,7 @@ op::Sigmoid::Sigmoid(const Output<Node>& arg) ...@@ -37,7 +37,7 @@ op::Sigmoid::Sigmoid(const Output<Node>& arg)
} }
op::SigmoidBackprop::SigmoidBackprop(const Output<Node>& arg, const Output<Node>& delta) op::SigmoidBackprop::SigmoidBackprop(const Output<Node>& arg, const Output<Node>& delta)
: BinaryElementwiseArithmetic(arg, delta) : BinaryElementwiseArithmetic(arg, delta, AutoBroadcastSpec::NONE)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
......
...@@ -47,7 +47,11 @@ namespace ngraph ...@@ -47,7 +47,11 @@ namespace ngraph
NGRAPH_API NGRAPH_API
static constexpr NodeTypeInfo type_info{"SigmoidBackprop", 0}; static constexpr NodeTypeInfo type_info{"SigmoidBackprop", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; } const NodeTypeInfo& get_type_info() const override { return type_info; }
SigmoidBackprop() = default; SigmoidBackprop()
: util::BinaryElementwiseArithmetic(AutoBroadcastSpec::NONE)
{
}
/// \brief Constructs a SigmoidBackprop operation. /// \brief Constructs a SigmoidBackprop operation.
/// ///
/// \param arg Node that produces the Sigmoid forward input tensor. /// \param arg Node that produces the Sigmoid forward input tensor.
......
...@@ -29,7 +29,11 @@ namespace ngraph ...@@ -29,7 +29,11 @@ namespace ngraph
NGRAPH_API NGRAPH_API
static constexpr NodeTypeInfo type_info{"Subtract", 0}; static constexpr NodeTypeInfo type_info{"Subtract", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; } const NodeTypeInfo& get_type_info() const override { return type_info; }
Subtract() = default; Subtract()
: util::BinaryElementwiseArithmetic(AutoBroadcastSpec::NONE)
{
}
/// \brief Constructs a subtraction operation. /// \brief Constructs a subtraction operation.
/// ///
/// \param arg0 Node that produces the first input tensor. /// \param arg0 Node that produces the first input tensor.
......
...@@ -19,6 +19,9 @@ ...@@ -19,6 +19,9 @@
using namespace ngraph; using namespace ngraph;
const op::AutoBroadcastSpec op::AutoBroadcastSpec::NUMPY(AutoBroadcastType::NUMPY, 0);
const op::AutoBroadcastSpec op::AutoBroadcastSpec::NONE{AutoBroadcastType::NONE, 0};
namespace ngraph namespace ngraph
{ {
template <> template <>
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include <ostream> #include <ostream>
#include "ngraph/attribute_adapter.hpp" #include "ngraph/attribute_adapter.hpp"
#include "ngraph/ngraph_visibility.hpp"
#include "ngraph/type.hpp" #include "ngraph/type.hpp"
namespace ngraph namespace ngraph
...@@ -269,6 +270,11 @@ namespace ngraph ...@@ -269,6 +270,11 @@ namespace ngraph
{ {
return a.m_type == m_type && a.m_axis == m_axis; return a.m_type == m_type && a.m_axis == m_axis;
} }
NGRAPH_API
static const AutoBroadcastSpec NUMPY;
NGRAPH_API
static const AutoBroadcastSpec NONE;
}; };
} }
} }
...@@ -19,7 +19,8 @@ ...@@ -19,7 +19,8 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::util::BinaryElementwiseArithmetic::BinaryElementwiseArithmetic() op::util::BinaryElementwiseArithmetic::BinaryElementwiseArithmetic(const AutoBroadcastSpec& autob)
: m_autob(autob)
{ {
} }
......
...@@ -54,12 +54,12 @@ namespace ngraph ...@@ -54,12 +54,12 @@ namespace ngraph
class BinaryElementwiseArithmetic : public Op class BinaryElementwiseArithmetic : public Op
{ {
protected: protected:
/// \brief Constructs a binary elementwise arithmetic operation. BinaryElementwiseArithmetic(const AutoBroadcastSpec& autob);
BinaryElementwiseArithmetic();
/// \brief Constructs a binary elementwise arithmetic operation.
BinaryElementwiseArithmetic(const std::shared_ptr<Node>& arg0, BinaryElementwiseArithmetic(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1, const std::shared_ptr<Node>& arg1,
const AutoBroadcastSpec& autob = AutoBroadcastSpec()); const AutoBroadcastSpec& autob);
/// \brief Constructs a binary elementwise arithmetic operation. /// \brief Constructs a binary elementwise arithmetic operation.
/// ///
...@@ -67,7 +67,7 @@ namespace ngraph ...@@ -67,7 +67,7 @@ namespace ngraph
/// \param arg1 Output that produces the second input tensor. /// \param arg1 Output that produces the second input tensor.
BinaryElementwiseArithmetic(const Output<Node>& arg0, BinaryElementwiseArithmetic(const Output<Node>& arg0,
const Output<Node>& arg1, const Output<Node>& arg1,
const AutoBroadcastSpec& autob = AutoBroadcastSpec()); const AutoBroadcastSpec& autob);
/// \brief Constructs a binary elementwise arithmetic operation. /// \brief Constructs a binary elementwise arithmetic operation.
/// ///
...@@ -77,7 +77,7 @@ namespace ngraph ...@@ -77,7 +77,7 @@ namespace ngraph
BinaryElementwiseArithmetic(const std::string& node_type, BinaryElementwiseArithmetic(const std::string& node_type,
const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1, const std::shared_ptr<Node>& arg1,
const AutoBroadcastSpec& autob = AutoBroadcastSpec()); const AutoBroadcastSpec& autob);
public: public:
void validate_and_infer_types() override; void validate_and_infer_types() override;
......
...@@ -22,7 +22,7 @@ using namespace ngraph; ...@@ -22,7 +22,7 @@ using namespace ngraph;
constexpr NodeTypeInfo op::GeluBackprop::type_info; constexpr NodeTypeInfo op::GeluBackprop::type_info;
op::GeluBackprop::GeluBackprop(const Output<ngraph::Node>& arg, const Output<ngraph::Node>& delta) op::GeluBackprop::GeluBackprop(const Output<ngraph::Node>& arg, const Output<ngraph::Node>& delta)
: BinaryElementwiseArithmetic(arg, delta) : BinaryElementwiseArithmetic(arg, delta, AutoBroadcastSpec::NONE)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
set_output_size(1); set_output_size(1);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment