Unverified Commit 61c7ba8c authored by Tomasz Dołbniak's avatar Tomasz Dołbniak Committed by GitHub

Merge branch 'master' into tomdol/pycapsule

parents 8c40ab98 feefdbb2
......@@ -16,9 +16,8 @@
import numpy as np
import ngraph as ng
from string import ascii_uppercase
from ngraph.utils.types import NumericData
from typing import Any, Callable, List
import test
......@@ -32,10 +31,14 @@ def get_runtime():
def run_op_node(input_data, op_fun, *args):
# type: (NumericData, Callable, *Any) -> List[NumericData]
"""Run computation on node performing `op_fun`.
`op_fun` has to accept a node as an argument.
This function converts passed raw input data to nGraph Constant Node and that form is passed
to `op_fun`.
:param input_data: The input data for performed computation.
:param op_fun: The function handler for operation we want to carry out.
:param args: The arguments passed to operation we want to carry out.
......@@ -45,14 +48,8 @@ def run_op_node(input_data, op_fun, *args):
comp_args = []
op_fun_args = []
comp_inputs = []
for idx, data in enumerate(input_data):
if np.isscalar(data):
op_fun_args.append(ng.constant(data, _get_numpy_dtype(data)))
else:
node = ng.parameter(data.shape, name=ascii_uppercase[idx], dtype=data.dtype)
op_fun_args.append(node)
comp_args.append(node)
comp_inputs.append(data)
for data in input_data:
op_fun_args.append(ng.constant(data, _get_numpy_dtype(data)))
op_fun_args.extend(args)
node = op_fun(*op_fun_args)
computation = runtime.computation(node, *comp_args)
......@@ -60,10 +57,15 @@ def run_op_node(input_data, op_fun, *args):
def run_op_numeric_data(input_data, op_fun, *args):
# type: (NumericData, Callable, *Any) -> List[NumericData]
"""Run computation on node performing `op_fun`.
`op_fun` has to accept a scalar or an array.
This function passess input data AS IS. This mean that in case they're a scalar (integral,
or floating point value) or a NumPy's ndarray object they will be automatically converted
to nGraph's Constant Nodes.
:param input_data: The input data for performed computation.
:param op_fun: The function handler for operation we want to carry out.
:param args: The arguments passed to operation we want to carry out.
......
......@@ -350,6 +350,7 @@ namespace
}
return callBackFuncPtr;
}
// NGDialect converters
Type NGraphTypeConverter::convertType(Type type)
{
......@@ -576,7 +577,6 @@ namespace
// Create Value for result, and extract type info.
Value* result = m_pass.buildOutputDefs(op, rewriter)[0];
NGRAPH_CHECK(result, "Unexpected null result in ConcatOp");
auto resultTy = result->getType().cast<MemRefType>();
// Create view to write into result.
MemRefView vRes(result);
......@@ -590,7 +590,6 @@ namespace
for (auto& operand : operands)
{
NGRAPH_CHECK(operand, "Unexpected null operand in ConcatOp");
auto operandTy = result->getType().cast<MemRefType>();
// Assuming rank = r, and the concatenation axis is A where A<r, we'll be creating
// loops of this form:
......
......@@ -74,7 +74,6 @@ void MLIRSubgraphExtractionPass::MLIRSubgraph::merge(MLIRSubgraph& sg2)
// Associate nodes of second sub-graph to first one
auto sg_nodes = sg2.get_nodes();
auto& node_map = m_pass.m_node_to_graph;
for (auto node : sg_nodes)
{
NGRAPH_DEBUG << *node;
......@@ -112,7 +111,6 @@ bool MLIRSubgraphExtractionPass::run_on_function(std::shared_ptr<Function> func)
for (auto op : func->get_ordered_ops())
{
NodeVector inputs;
int first_graph_id = -1;
std::unordered_set<int> subgraph_ids;
// unsupported ops, skip
if (!is_supported_mlir_op(op))
......
......@@ -160,5 +160,5 @@ namespace ngraph
/// \brief Macro to signal a code path that is unreachable in a successful execution. It's
/// implemented with NGRAPH_CHECK macro.
/// \param ... Additional error message that should describe why that execution path is unreachable.
/// \throws ::ngrap::CheckFailure if the macro is executed.
/// \throws ::ngraph::CheckFailure if the macro is executed.
#define NGRAPH_UNREACHABLE(...) NGRAPH_CHECK(false, "Unreachable: ", ##__VA_ARGS__)
......@@ -214,7 +214,7 @@ namespace ngraph
virtual bool is_constant() const;
virtual bool is_null() const { return false; }
virtual bool is_op() const { return false; }
virtual bool is_commutative() { return false; }
virtual bool is_commutative() const { return false; }
virtual bool is_dynamic() const;
virtual bool has_state() const { return false; }
size_t get_instance_id() const { return m_instance_id; }
......
......@@ -49,7 +49,7 @@ void op::Add::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector&
adjoints.add_delta(y, delta);
}
shared_ptr<Node> ngraph::operator+(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg1)
shared_ptr<Node> ngraph::operator+(const Output<Node>& arg0, const Output<Node>& arg1)
{
return make_shared<op::Add>(arg0, arg1);
}
......@@ -51,13 +51,12 @@ namespace ngraph
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
virtual bool is_commutative() const override { return true; }
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
virtual bool is_commutative() override { return true; }
};
}
std::shared_ptr<ngraph::Node> operator+(const std::shared_ptr<ngraph::Node>& arg0,
const std::shared_ptr<ngraph::Node>& arg1);
std::shared_ptr<Node> operator+(const Output<Node>& arg0, const Output<Node>& arg1);
}
......@@ -51,8 +51,7 @@ namespace ngraph
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
protected:
virtual bool is_commutative() override { return true; }
virtual bool is_commutative() const override { return true; }
};
}
}
......@@ -22,12 +22,15 @@
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/validation_util.hpp"
const std::string ngraph::op::BatchNormTraining::type_name{"BatchNormTraining"};
using namespace std;
using namespace ngraph;
ngraph::op::BatchNormTraining::BatchNormTraining(Output<ngraph::Node> input,
Output<ngraph::Node> gamma,
Output<ngraph::Node> beta,
double epsilon)
const string op::BatchNormTraining::type_name{"BatchNormTraining"};
op::BatchNormTraining::BatchNormTraining(const Output<Node>& input,
const Output<Node>& gamma,
const Output<Node>& beta,
double epsilon)
: Op({gamma, beta, input})
, m_epsilon(epsilon)
{
......@@ -35,17 +38,17 @@ ngraph::op::BatchNormTraining::BatchNormTraining(Output<ngraph::Node> input,
}
// DEPRECATED
ngraph::op::BatchNormTraining::BatchNormTraining(double eps,
Output<ngraph::Node> gamma,
Output<ngraph::Node> beta,
Output<ngraph::Node> input)
op::BatchNormTraining::BatchNormTraining(double eps,
const Output<Node>& gamma,
const Output<Node>& beta,
const Output<Node>& input)
: Op({gamma, beta, input})
, m_epsilon(eps)
{
constructor_validate_and_infer_types();
}
void ngraph::op::BatchNormTraining::validate_and_infer_types()
void op::BatchNormTraining::validate_and_infer_types()
{
element::Type result_et;
PartialShape result_batch_shape;
......@@ -66,16 +69,15 @@ void ngraph::op::BatchNormTraining::validate_and_infer_types()
set_output_type(2, result_et, result_channel_shape);
}
std::shared_ptr<ngraph::Node>
ngraph::op::BatchNormTraining::copy_with_new_args(const NodeVector& new_args) const
std::shared_ptr<Node> op::BatchNormTraining::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return std::make_shared<BatchNormTraining>(
new_args.at(2), new_args.at(0), new_args.at(1), m_epsilon);
}
void ngraph::op::BatchNormTraining::generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas)
void op::BatchNormTraining::generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas)
{
auto gamma = input(0).get_source_output();
auto beta = input(1).get_source_output();
......@@ -102,14 +104,14 @@ void ngraph::op::BatchNormTraining::generate_adjoints(autodiff::Adjoints& adjoin
adjoints.add_delta(beta, dbeta);
}
const std::string ngraph::op::BatchNormInference::type_name{"BatchNormInference"};
const string op::BatchNormInference::type_name{"BatchNormInference"};
ngraph::op::BatchNormInference::BatchNormInference(Output<ngraph::Node> input,
Output<ngraph::Node> gamma,
Output<ngraph::Node> beta,
Output<ngraph::Node> mean,
Output<ngraph::Node> variance,
double epsilon)
op::BatchNormInference::BatchNormInference(const Output<Node>& input,
const Output<Node>& gamma,
const Output<Node>& beta,
const Output<Node>& mean,
const Output<Node>& variance,
double epsilon)
: Op({gamma, beta, input, mean, variance})
, m_epsilon(epsilon)
{
......@@ -117,19 +119,19 @@ ngraph::op::BatchNormInference::BatchNormInference(Output<ngraph::Node> input,
}
// DEPRECATED
ngraph::op::BatchNormInference::BatchNormInference(double eps,
Output<ngraph::Node> gamma,
Output<ngraph::Node> beta,
Output<ngraph::Node> input,
Output<ngraph::Node> mean,
Output<ngraph::Node> variance)
op::BatchNormInference::BatchNormInference(double eps,
const Output<Node>& gamma,
const Output<Node>& beta,
const Output<Node>& input,
const Output<Node>& mean,
const Output<Node>& variance)
: Op({gamma, beta, input, mean, variance})
, m_epsilon(eps)
{
constructor_validate_and_infer_types();
}
void ngraph::op::BatchNormInference::validate_and_infer_types()
void op::BatchNormInference::validate_and_infer_types()
{
element::Type result_et;
PartialShape result_batch_shape;
......@@ -152,23 +154,22 @@ void ngraph::op::BatchNormInference::validate_and_infer_types()
set_output_type(0, result_et, result_batch_shape);
}
std::shared_ptr<ngraph::Node>
ngraph::op::BatchNormInference::copy_with_new_args(const NodeVector& new_args) const
std::shared_ptr<Node> op::BatchNormInference::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return std::make_shared<BatchNormInference>(
new_args.at(2), new_args.at(0), new_args.at(1), new_args.at(3), new_args.at(4), m_epsilon);
}
const std::string ngraph::op::BatchNormTrainingBackprop::type_name{"BatchNormTrainingBackprop"};
const string op::BatchNormTrainingBackprop::type_name{"BatchNormTrainingBackprop"};
ngraph::op::BatchNormTrainingBackprop::BatchNormTrainingBackprop(Output<ngraph::Node> input,
Output<ngraph::Node> gamma,
Output<ngraph::Node> beta,
Output<ngraph::Node> mean,
Output<ngraph::Node> variance,
Output<ngraph::Node> delta,
double epsilon)
op::BatchNormTrainingBackprop::BatchNormTrainingBackprop(const Output<Node>& input,
const Output<Node>& gamma,
const Output<Node>& beta,
const Output<Node>& mean,
const Output<Node>& variance,
const Output<Node>& delta,
double epsilon)
: Op({gamma, beta, input, mean, variance, delta})
, m_epsilon(epsilon)
......@@ -177,13 +178,13 @@ ngraph::op::BatchNormTrainingBackprop::BatchNormTrainingBackprop(Output<ngraph::
constructor_validate_and_infer_types();
}
ngraph::op::BatchNormTrainingBackprop::BatchNormTrainingBackprop(double epsilon,
Output<ngraph::Node> gamma,
Output<ngraph::Node> beta,
Output<ngraph::Node> input,
Output<ngraph::Node> mean,
Output<ngraph::Node> variance,
Output<ngraph::Node> delta)
op::BatchNormTrainingBackprop::BatchNormTrainingBackprop(double epsilon,
const Output<Node>& gamma,
const Output<Node>& beta,
const Output<Node>& input,
const Output<Node>& mean,
const Output<Node>& variance,
const Output<Node>& delta)
: Op({gamma, beta, input, mean, variance, delta})
, m_epsilon(epsilon)
......@@ -192,7 +193,7 @@ ngraph::op::BatchNormTrainingBackprop::BatchNormTrainingBackprop(double epsilon,
constructor_validate_and_infer_types();
}
void ngraph::op::BatchNormTrainingBackprop::validate_and_infer_types()
void op::BatchNormTrainingBackprop::validate_and_infer_types()
{
PartialShape input_and_delta_shape{get_input_partial_shape(INPUT_DATA)};
......@@ -239,8 +240,8 @@ void ngraph::op::BatchNormTrainingBackprop::validate_and_infer_types()
set_output_type(2, result_et, result_channel_shape);
}
std::shared_ptr<ngraph::Node>
ngraph::op::BatchNormTrainingBackprop::copy_with_new_args(const NodeVector& new_args) const
std::shared_ptr<Node>
op::BatchNormTrainingBackprop::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return std::make_shared<op::BatchNormTrainingBackprop>(new_args.at(2),
......
......@@ -39,9 +39,9 @@ namespace ngraph
/// \param gamma gamma scaling for normalized value. [C]
/// \param beta bias added to the scaled normalized value [C]
/// \param epsilon Avoids divsion by 0 if input has 0 variance
BatchNormTraining(Output<Node> input,
Output<Node> gamma,
Output<Node> beta,
BatchNormTraining(const Output<Node>& input,
const Output<Node>& gamma,
const Output<Node>& beta,
double epsilon);
NGRAPH_DEPRECATED_DOC
......@@ -66,9 +66,9 @@ namespace ngraph
/// output[2]: shall have rank 1, with the same span as input's channel axis.
NGRAPH_DEPRECATED("Use another constructor")
BatchNormTraining(double eps,
Output<Node> gamma,
Output<Node> beta,
Output<Node> input);
const Output<Node>& gamma,
const Output<Node>& beta,
const Output<Node>& input);
void validate_and_infer_types() override;
......@@ -101,11 +101,11 @@ namespace ngraph
/// \param mean value for mean normalization [C]
/// \param variance value for variance normalization [C]
/// \param epsilon Avoids divsion by 0 if input has 0 variance
BatchNormInference(Output<ngraph::Node> input,
Output<ngraph::Node> gamma,
Output<ngraph::Node> beta,
Output<ngraph::Node> mean,
Output<ngraph::Node> variance,
BatchNormInference(const Output<Node>& input,
const Output<Node>& gamma,
const Output<Node>& beta,
const Output<Node>& mean,
const Output<Node>& variance,
double epsilon);
NGRAPH_DEPRECATED_DOC
......@@ -128,11 +128,11 @@ namespace ngraph
/// output: shall have the same shape as 'input'.
NGRAPH_DEPRECATED("Use another constructor")
BatchNormInference(double eps,
Output<ngraph::Node> gamma,
Output<ngraph::Node> beta,
Output<ngraph::Node> input,
Output<ngraph::Node> mean,
Output<ngraph::Node> variance);
const Output<Node>& gamma,
const Output<Node>& beta,
const Output<Node>& input,
const Output<Node>& mean,
const Output<Node>& variance);
void validate_and_infer_types() override;
......@@ -165,24 +165,23 @@ namespace ngraph
static const std::string type_name;
const std::string& description() const override { return type_name; }
BatchNormTrainingBackprop() = default;
BatchNormTrainingBackprop(Output<Node> input,
Output<Node> gamma,
Output<Node> beta,
Output<Node> mean,
Output<Node> variance,
Output<Node> delta,
BatchNormTrainingBackprop(const Output<Node>& input,
const Output<Node>& gamma,
const Output<Node>& beta,
const Output<Node>& mean,
const Output<Node>& variance,
const Output<Node>& delta,
double epsilon);
NGRAPH_DEPRECATED_DOC
NGRAPH_DEPRECATED("Use another constructor")
BatchNormTrainingBackprop(double epsilon,
Output<Node> gamma,
Output<Node> beta,
Output<Node> input,
Output<Node> mean,
Output<Node> variance,
Output<Node> delta);
const Output<Node>& gamma,
const Output<Node>& beta,
const Output<Node>& input,
const Output<Node>& mean,
const Output<Node>& variance,
const Output<Node>& delta);
void validate_and_infer_types() override;
......
......@@ -64,7 +64,7 @@ void op::Divide::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVecto
adjoints.add_delta(y, -delta * shared_from_this() / y);
}
shared_ptr<Node> ngraph::operator/(const Output<Node> arg0, const Output<Node> arg1)
shared_ptr<Node> ngraph::operator/(const Output<Node>& arg0, const Output<Node>& arg1)
{
return make_shared<op::Divide>(arg0, arg1);
}
......@@ -64,6 +64,5 @@ namespace ngraph
};
}
std::shared_ptr<ngraph::Node> operator/(const Output<ngraph::Node> arg0,
const Output<ngraph::Node> arg1);
std::shared_ptr<Node> operator/(const Output<Node>& arg0, const Output<Node>& arg1);
}
......@@ -58,7 +58,7 @@ namespace ngraph
void validate_and_infer_types() override;
size_t get_reduction_axes_count() const { return m_reduction_axes_count; }
void get_reduction_axes_count(size_t reduction_axes_count)
void set_reduction_axes_count(size_t reduction_axes_count)
{
m_reduction_axes_count = reduction_axes_count;
}
......
......@@ -56,6 +56,8 @@ namespace ngraph
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
virtual bool is_commutative() const override { return true; }
};
}
}
......@@ -23,6 +23,8 @@ using namespace ngraph;
static int PARAMS = 0;
static int INDICES = 1;
const string op::Gather::type_name{"Gather"};
shared_ptr<Node> op::Gather::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
......
......@@ -26,13 +26,15 @@ namespace ngraph
class Gather : public Op
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
Gather() = default;
/// \param params The tensor from which slices are gathered
/// \param indices Index tensor: Data type must be `element::i32` or `element::i64`
/// \param axis Axis in params to gather
Gather(const std::shared_ptr<Node>& params,
const std::shared_ptr<Node>& indices,
size_t axis = 0)
: Op("Gather", check_single_output_args({params, indices}))
Gather(const Output<Node>& params, const Output<Node>& indices, size_t axis = 0)
: Op({params, indices})
, m_axis(axis)
{
constructor_validate_and_infer_types();
......@@ -46,6 +48,7 @@ namespace ngraph
}
size_t get_axis() const { return m_axis; }
void set_axis(size_t axis) { m_axis = axis; }
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
......
......@@ -23,6 +23,8 @@ using namespace ngraph;
static int PARAMS = 0;
static int INDICES = 1;
const string op::GatherND::type_name{"GatherND"};
shared_ptr<Node> op::GatherND::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
......
......@@ -26,10 +26,14 @@ namespace ngraph
class GatherND : public Op
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
GatherND() = default;
/// \param params The tensor from which slices are gathered
/// \param indices Index tensor: Data type must be `element::i32` or `element::i64`
GatherND(const std::shared_ptr<Node>& params, const std::shared_ptr<Node>& indices)
: Op("GatherND", check_single_output_args({params, indices}))
GatherND(const Output<Node>& params, const Output<Node>& indices)
: Op({params, indices})
{
constructor_validate_and_infer_types();
}
......
......@@ -19,10 +19,12 @@
using namespace std;
using namespace ngraph;
op::Greater::Greater(const shared_ptr<Node>& arg0,
const shared_ptr<Node>& arg1,
const string op::Greater::type_name{"Greater"};
op::Greater::Greater(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& autob)
: BinaryElementwiseComparison("Greater", arg0, arg1, autob)
: BinaryElementwiseComparison(arg0, arg1, autob)
{
constructor_validate_and_infer_types();
}
......
......@@ -26,13 +26,18 @@ namespace ngraph
class Greater : public util::BinaryElementwiseComparison
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a greater-than operation.
Greater() = default;
/// \brief Constructs a greater-than operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param autob Auto broadcast specification
Greater(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1,
Greater(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& autob = AutoBroadcastSpec());
virtual std::shared_ptr<Node>
......
......@@ -19,10 +19,12 @@
using namespace std;
using namespace ngraph;
op::GreaterEq::GreaterEq(const shared_ptr<Node>& arg0,
const shared_ptr<Node>& arg1,
const string op::GreaterEq::type_name{"GreaterEq"};
op::GreaterEq::GreaterEq(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& autob)
: BinaryElementwiseComparison("GreaterEq", arg0, arg1, autob)
: BinaryElementwiseComparison(arg0, arg1, autob)
{
constructor_validate_and_infer_types();
}
......
......@@ -26,13 +26,18 @@ namespace ngraph
class GreaterEq : public util::BinaryElementwiseComparison
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a greater-than-or-equal operation.
GreaterEq() = default;
/// \brief Constructs a greater-than-or-equal operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param autob Auto broadcast specification
GreaterEq(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1,
GreaterEq(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& autob = AutoBroadcastSpec());
virtual std::shared_ptr<Node>
......
......@@ -19,10 +19,10 @@
using namespace std;
using namespace ngraph;
op::Less::Less(const shared_ptr<Node>& arg0,
const shared_ptr<Node>& arg1,
const AutoBroadcastSpec& autob)
: BinaryElementwiseComparison("Less", arg0, arg1, autob)
const string op::Less::type_name{"Less"};
op::Less::Less(const Output<Node>& arg0, const Output<Node>& arg1, const AutoBroadcastSpec& autob)
: BinaryElementwiseComparison(arg0, arg1, autob)
{
constructor_validate_and_infer_types();
}
......
......@@ -26,13 +26,18 @@ namespace ngraph
class Less : public util::BinaryElementwiseComparison
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a less-than operation.
Less() = default;
/// \brief Constructs a less-than operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param autob Auto broadcast specification
Less(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1,
Less(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& autob = AutoBroadcastSpec());
virtual std::shared_ptr<Node>
......
......@@ -19,10 +19,12 @@
using namespace std;
using namespace ngraph;
op::LessEq::LessEq(const shared_ptr<Node>& arg0,
const shared_ptr<Node>& arg1,
const string op::LessEq::type_name{"LessEq"};
op::LessEq::LessEq(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& autob)
: BinaryElementwiseComparison("LessEq", arg0, arg1, autob)
: BinaryElementwiseComparison(arg0, arg1, autob)
{
constructor_validate_and_infer_types();
}
......
......@@ -26,13 +26,18 @@ namespace ngraph
class LessEq : public util::BinaryElementwiseComparison
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a less-than-or-equal operation.
LessEq() = default;
/// \brief Constructs a less-than-or-equal operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param autob Auto broadcast specification
LessEq(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1,
LessEq(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& autob = AutoBroadcastSpec());
virtual std::shared_ptr<Node>
......
......@@ -20,8 +20,10 @@
using namespace std;
using namespace ngraph;
op::Log::Log(const shared_ptr<Node>& arg)
: UnaryElementwiseArithmetic("Log", arg)
const string op::Log::type_name{"Log"};
op::Log::Log(const Output<Node>& arg)
: UnaryElementwiseArithmetic(arg)
{
constructor_validate_and_infer_types();
}
......
......@@ -26,10 +26,15 @@ namespace ngraph
class Log : public util::UnaryElementwiseArithmetic
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a natural log operation.
Log() = default;
/// \brief Constructs a natural log operation.
///
/// \param arg Node that produces the input tensor.
Log(const std::shared_ptr<Node>& arg);
Log(const Output<Node>& arg);
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
......
......@@ -20,12 +20,14 @@
using namespace std;
using namespace ngraph;
op::LRN::LRN(const std::shared_ptr<Node>& arg, double alpha, double beta, double bias, size_t nsize)
: UnaryElementwiseArithmetic("LRN", arg)
const string op::LRN::type_name{"LRN"};
op::LRN::LRN(const Output<Node>& arg, double alpha, double beta, double bias, size_t size)
: UnaryElementwiseArithmetic(arg)
, m_alpha(alpha)
, m_beta(beta)
, m_bias(bias)
, m_size(nsize)
, m_size(size)
{
constructor_validate_and_infer_types();
}
......
......@@ -38,23 +38,28 @@ namespace ngraph
class LRN : public util::UnaryElementwiseArithmetic
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a LRN operation.
LRN() = default;
/// \brief Constructs a LRN operation.
///
/// \param arg Node that produces the input tensor.
LRN(const std::shared_ptr<Node>& arg,
double alpha,
double beta,
double bias,
size_t size);
LRN(const Output<Node>& arg, double alpha, double beta, double bias, size_t size);
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
void validate_and_infer_types() override;
double get_alpha() const { return m_alpha; }
void set_alpha(double alpha) { m_alpha = alpha; }
double get_beta() const { return m_beta; }
void set_beta(double beta) { m_beta = beta; }
double get_bias() const { return m_bias; }
void set_bias(double bias) { m_bias = bias; }
size_t get_nsize() const { return m_size; }
void set_nsize(size_t size) { m_size = size; }
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
......
......@@ -22,10 +22,6 @@ using namespace ngraph;
const string op::Max::type_name{"Max"};
op::Max::Max()
{
}
op::Max::Max(const Output<Node>& arg, const AxisSet& reduction_axes)
: ArithmeticReduction(arg, reduction_axes)
{
......
......@@ -30,7 +30,7 @@ namespace ngraph
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a "max" reduction operation.
Max();
Max() = default;
/// \brief Constructs a max-reduction operation.
///
/// \param arg The tensor to be reduced.
......
......@@ -25,14 +25,16 @@
using namespace std;
using namespace ngraph;
op::MaxPool::MaxPool(const shared_ptr<Node>& arg,
const string op::MaxPool::type_name{"MaxPool"};
op::MaxPool::MaxPool(const Output<Node>& arg,
const Shape& window_shape,
const Strides& window_movement_strides,
const Shape& padding_below,
const Shape& padding_above,
const PadType& pad_type,
bool ceil_mode)
: Op("MaxPool", check_single_output_args({arg}))
: Op({arg})
, m_window_shape(window_shape)
, m_window_movement_strides(window_movement_strides)
, m_padding_below(padding_below)
......@@ -43,7 +45,7 @@ op::MaxPool::MaxPool(const shared_ptr<Node>& arg,
constructor_validate_and_infer_types();
}
op::MaxPool::MaxPool(const shared_ptr<Node>& arg,
op::MaxPool::MaxPool(const Output<Node>& arg,
const Shape& window_shape,
const Strides& window_movement_strides,
const Shape& padding_below,
......@@ -54,7 +56,7 @@ op::MaxPool::MaxPool(const shared_ptr<Node>& arg,
{
}
op::MaxPool::MaxPool(const shared_ptr<Node>& arg,
op::MaxPool::MaxPool(const Output<Node>& arg,
const Shape& window_shape,
const Strides& window_movement_strides,
const Shape& padding_below,
......@@ -121,14 +123,14 @@ void op::MaxPool::validate_and_infer_types()
m_ceil_mode));
}
op::MaxPool::MaxPool(const shared_ptr<Node>& arg,
op::MaxPool::MaxPool(const Output<Node>& arg,
const Shape& window_shape,
const Strides& window_movement_strides)
: MaxPool(arg, window_shape, window_movement_strides, Shape(), Shape())
{
}
op::MaxPool::MaxPool(const shared_ptr<Node>& arg, const Shape& window_shape)
op::MaxPool::MaxPool(const Output<Node>& arg, const Shape& window_shape)
: MaxPool(arg, window_shape, Strides(), Shape(), Shape())
{
}
......@@ -145,13 +147,15 @@ shared_ptr<Node> op::MaxPool::copy_with_new_args(const NodeVector& new_args) con
m_ceil_mode);
}
op::MaxPoolBackprop::MaxPoolBackprop(const shared_ptr<Node>& arg_forward,
const shared_ptr<Node>& delta,
const string op::MaxPoolBackprop::type_name{"MaxPoolBackprop"};
op::MaxPoolBackprop::MaxPoolBackprop(const Output<Node>& arg_forward,
const Output<Node>& delta,
const Shape& window_shape,
const Strides& window_movement_strides,
const Shape& padding_below,
const Shape& padding_above)
: Op("MaxPoolBackprop", check_single_output_args({arg_forward, delta}))
: Op({arg_forward, delta})
, m_window_shape(window_shape)
, m_window_movement_strides(window_movement_strides)
, m_padding_below(padding_below)
......@@ -160,14 +164,14 @@ op::MaxPoolBackprop::MaxPoolBackprop(const shared_ptr<Node>& arg_forward,
constructor_validate_and_infer_types();
}
op::MaxPoolBackprop::MaxPoolBackprop(const shared_ptr<Node>& arg_forward,
const shared_ptr<Node>& delta,
const shared_ptr<Node>& result_forward,
op::MaxPoolBackprop::MaxPoolBackprop(const Output<Node>& arg_forward,
const Output<Node>& delta,
const Output<Node>& result_forward,
const Shape& window_shape,
const Strides& window_movement_strides,
const Shape& padding_below,
const Shape& padding_above)
: Op("MaxPoolBackprop", check_single_output_args({arg_forward, delta, result_forward}))
: Op({arg_forward, delta, result_forward})
, m_window_shape(window_shape)
, m_window_movement_strides(window_movement_strides)
, m_padding_below(padding_below)
......
......@@ -28,6 +28,12 @@ namespace ngraph
class MaxPool : public Op
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a batched max pooling operation.
MaxPool() = default;
/// \brief Constructs a batched max pooling operation.
///
/// \param arg The node producing the input data batch tensor.
......@@ -37,7 +43,7 @@ namespace ngraph
/// \param padding_above The above-padding shape.
/// \param pad_type The pad type for automatically computing padding sizes
/// \param ceil_mode Whether to use ceiling while computing output shape.
MaxPool(const std::shared_ptr<Node>& arg,
MaxPool(const Output<Node>& arg,
const Shape& window_shape,
const Strides& window_movement_strides,
const Shape& padding_below,
......@@ -53,7 +59,7 @@ namespace ngraph
/// \param padding_below The below-padding shape.
/// \param padding_above The above-padding shape.
/// \param pad_type The pad type for automatically computing padding sizes
MaxPool(const std::shared_ptr<Node>& arg,
MaxPool(const Output<Node>& arg,
const Shape& window_shape,
const Strides& window_movement_strides,
const Shape& padding_below,
......@@ -67,7 +73,7 @@ namespace ngraph
/// \param window_movement_strides The window movement strides.
/// \param padding_below The below-padding shape.
/// \param padding_above The above-padding shape.
MaxPool(const std::shared_ptr<Node>& arg,
MaxPool(const Output<Node>& arg,
const Shape& window_shape,
const Strides& window_movement_strides,
const Shape& padding_below,
......@@ -80,7 +86,7 @@ namespace ngraph
/// \param arg The node producing the input data batch tensor.
/// \param window_shape The window shape.
/// \param window_movement_strides The window movement strides.
MaxPool(const std::shared_ptr<Node>& arg,
MaxPool(const Output<Node>& arg,
const Shape& window_shape,
const Strides& window_movement_strides);
......@@ -88,23 +94,32 @@ namespace ngraph
///
/// \param arg The node producing the input data batch tensor.
/// \param window_shape The window shape.
MaxPool(const std::shared_ptr<Node>& arg, const Shape& window_shape);
MaxPool(const Output<Node>& arg, const Shape& window_shape);
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
/// \return The window shape.
const Shape& get_window_shape() const { return m_window_shape; }
void set_window_shape(const Shape& window_shape) { m_window_shape = window_shape; }
/// \return The window movement strides.
const Strides& get_window_movement_strides() const { return m_window_movement_strides; }
void set_window_movement_strides(const Strides& window_movement_strides)
{
m_window_movement_strides = window_movement_strides;
}
/// \return The below-padding shape.
const Shape& get_padding_below() const { return m_padding_below; }
void set_padding_below(const Shape& padding_below) { m_padding_below = padding_below; }
/// \return The above-padding shape.
const Shape& get_padding_above() const { return m_padding_above; }
void set_adding_above(const Shape& padding_above) { m_padding_above = padding_above; }
/// \return The pad type for pooling.
const PadType& get_pad_type() const { return m_pad_type; }
void set_pad_type(const PadType& pad_type) { m_pad_type = pad_type; }
/// \return The ceiling mode being used for output shape computations
bool get_ceil_mode() const { return m_ceil_mode; }
void set_ceil_mode(bool ceil_mode) { m_ceil_mode = ceil_mode; }
/// \return The default value for MaxPool.
virtual std::shared_ptr<Node> get_default_value() const override
{
......@@ -126,16 +141,21 @@ namespace ngraph
class MaxPoolBackprop : public Op
{
public:
MaxPoolBackprop(const std::shared_ptr<Node>& arg_forward,
const std::shared_ptr<Node>& delta,
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
MaxPoolBackprop() = default;
MaxPoolBackprop(const Output<Node>& arg_forward,
const Output<Node>& delta,
const Shape& window_shape,
const Strides& window_movement_strides,
const Shape& padding_below,
const Shape& padding_above);
MaxPoolBackprop(const std::shared_ptr<Node>& arg_forward,
const std::shared_ptr<Node>& delta,
const std::shared_ptr<Node>& result_forward,
MaxPoolBackprop(const Output<Node>& arg_forward,
const Output<Node>& delta,
const Output<Node>& result_forward,
const Shape& window_shape,
const Strides& window_movement_strides,
const Shape& padding_below,
......@@ -147,9 +167,16 @@ namespace ngraph
void validate_and_infer_types() override;
const Shape& get_window_shape() const { return m_window_shape; }
void set_window_shape(const Shape& window_shape) { m_window_shape = window_shape; }
const Strides& get_window_movement_strides() const { return m_window_movement_strides; }
void set_window_movement_strides(const Strides& window_movement_strides)
{
m_window_movement_strides = window_movement_strides;
}
const Shape& get_padding_below() const { return m_padding_below; }
void set_padding_below(const Shape& padding_below) { m_padding_below = padding_below; }
const Shape& get_padding_above() const { return m_padding_above; }
void set_padding_above(const Shape& padding_above) { m_padding_above = padding_above; }
protected:
Shape m_window_shape;
Strides m_window_movement_strides;
......
......@@ -25,10 +25,12 @@
using namespace std;
using namespace ngraph;
op::Maximum::Maximum(const shared_ptr<Node>& arg0,
const shared_ptr<Node>& arg1,
const string op::Maximum::type_name{"Maximum"};
op::Maximum::Maximum(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& autob)
: BinaryElementwiseArithmetic("Maximum", arg0, arg1, autob)
: BinaryElementwiseArithmetic(arg0, arg1, autob)
{
constructor_validate_and_infer_types();
}
......
......@@ -26,19 +26,24 @@ namespace ngraph
class Maximum : public util::BinaryElementwiseArithmetic
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a maximum operation.
Maximum() = default;
/// \brief Constructs a maximum operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param autob Auto broadcast specification
Maximum(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1,
Maximum(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& autob = AutoBroadcastSpec());
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
virtual bool is_commutative() override { return true; }
virtual bool is_commutative() const override { return true; }
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
......
......@@ -22,10 +22,6 @@ using namespace ngraph;
const string op::Min::type_name{"Min"};
op::Min::Min()
{
}
op::Min::Min(const Output<Node>& arg, const AxisSet& reduction_axes)
: ArithmeticReduction(arg, reduction_axes)
{
......
......@@ -30,7 +30,7 @@ namespace ngraph
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a "min" reduction operation.
Min();
Min() = default;
/// \brief Constructs a min-reduction operation.
///
/// \param arg The tensor to be reduced.
......
......@@ -25,10 +25,12 @@
using namespace std;
using namespace ngraph;
op::Minimum::Minimum(const shared_ptr<Node>& arg0,
const shared_ptr<Node>& arg1,
const string op::Minimum::type_name{"Minimum"};
op::Minimum::Minimum(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& autob)
: BinaryElementwiseArithmetic("Minimum", arg0, arg1, autob)
: BinaryElementwiseArithmetic(arg0, arg1, autob)
{
constructor_validate_and_infer_types();
}
......
......@@ -26,18 +26,24 @@ namespace ngraph
class Minimum : public util::BinaryElementwiseArithmetic
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a minimum operation.
Minimum() = default;
/// \brief Constructs a minimum operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param autob Auto broadcast specification
Minimum(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1,
Minimum(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& autob = AutoBroadcastSpec());
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
virtual bool is_commutative() const override { return true; }
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
......
......@@ -19,10 +19,12 @@
using namespace std;
using namespace ngraph;
op::Multiply::Multiply(const shared_ptr<Node>& arg0,
const shared_ptr<Node>& arg1,
const string op::Multiply::type_name{"Multiply"};
op::Multiply::Multiply(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& autob)
: BinaryElementwiseArithmetic("Multiply", arg0, arg1, autob)
: BinaryElementwiseArithmetic(arg0, arg1, autob)
{
constructor_validate_and_infer_types();
}
......@@ -49,7 +51,7 @@ void op::Multiply::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVec
adjoints.add_delta(y, x * delta);
}
shared_ptr<Node> ngraph::operator*(const shared_ptr<Node> arg0, const shared_ptr<Node> arg1)
shared_ptr<Node> ngraph::operator*(const Output<Node>& arg0, const Output<Node>& arg1)
{
return make_shared<op::Multiply>(arg0, arg1);
}
......@@ -26,25 +26,29 @@ namespace ngraph
class Multiply : public util::BinaryElementwiseArithmetic
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a multiplication operation.
Multiply() = default;
/// \brief Constructs a multiplication operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param autob Auto broadcast specification
Multiply(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1,
Multiply(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& autob = AutoBroadcastSpec());
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
virtual bool is_commutative() const override { return true; }
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
virtual bool is_commutative() override { return true; }
};
};
std::shared_ptr<ngraph::Node> operator*(const std::shared_ptr<ngraph::Node> arg0,
const std::shared_ptr<ngraph::Node> arg1);
std::shared_ptr<Node> operator*(const Output<Node>& arg0, const Output<Node>& arg1);
}
......@@ -19,8 +19,10 @@
using namespace std;
using namespace ngraph;
op::Negative::Negative(const shared_ptr<Node>& arg)
: UnaryElementwiseArithmetic("Negative", arg)
const string op::Negative::type_name{"Negative"};
op::Negative::Negative(const Output<Node>& arg)
: UnaryElementwiseArithmetic(arg)
{
constructor_validate_and_infer_types();
}
......@@ -40,7 +42,7 @@ void op::Negative::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVec
adjoints.add_delta(x, -delta);
}
shared_ptr<Node> ngraph::operator-(const shared_ptr<Node> arg0)
shared_ptr<Node> ngraph::operator-(const Output<Node>& arg0)
{
return make_shared<op::Negative>(arg0);
}
......@@ -26,17 +26,23 @@ namespace ngraph
class Negative : public util::UnaryElementwiseArithmetic
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a negative operation.
Negative() = default;
/// \brief Constructs a negative operation.
///
/// \param arg Node that produces the input tensor.
Negative(const std::shared_ptr<Node>& arg);
Negative(const Output<Node>& arg);
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
};
}
std::shared_ptr<ngraph::Node> operator-(const std::shared_ptr<ngraph::Node> arg0);
std::shared_ptr<Node> operator-(const Output<Node>& arg0);
}
......@@ -20,8 +20,10 @@
using namespace ngraph;
using namespace std;
op::Not::Not(const shared_ptr<Node>& arg)
: Op("Not", check_single_output_args({arg}))
const string op::Not::type_name{"Not"};
op::Not::Not(const Output<Node>& arg)
: Op({arg})
{
constructor_validate_and_infer_types();
}
......
......@@ -26,10 +26,15 @@ namespace ngraph
class Not : public Op
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a logical negation operation.
Not() = default;
/// \brief Constructs a logical negation operation.
///
/// \param arg Node that produces the input tensor.
Not(const std::shared_ptr<Node>& arg);
Not(const Output<Node>& arg);
void validate_and_infer_types() override;
......
......@@ -19,10 +19,12 @@
using namespace std;
using namespace ngraph;
op::NotEqual::NotEqual(const shared_ptr<Node>& arg0,
const shared_ptr<Node>& arg1,
const string op::NotEqual::type_name{"NotEqual"};
op::NotEqual::NotEqual(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& autob)
: BinaryElementwiseComparison("NotEqual", arg0, arg1, autob)
: BinaryElementwiseComparison(arg0, arg1, autob)
{
constructor_validate_and_infer_types();
}
......
......@@ -26,17 +26,24 @@ namespace ngraph
class NotEqual : public util::BinaryElementwiseComparison
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a not-equal operation.
NotEqual() = default;
/// \brief Constructs a not-equal operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param autob Auto broadcast specification
NotEqual(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1,
NotEqual(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& autob = AutoBroadcastSpec());
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
virtual bool is_commutative() const override { return true; }
};
}
}
......@@ -20,8 +20,10 @@
using namespace std;
using namespace ngraph;
op::OneHot::OneHot(const shared_ptr<Node>& arg, const PartialShape& shape, size_t one_hot_axis)
: Op("OneHot", check_single_output_args({arg}))
const string op::OneHot::type_name{"OneHot"};
op::OneHot::OneHot(const Output<Node>& arg, const PartialShape& shape, size_t one_hot_axis)
: Op({arg})
, m_shape(shape)
, m_one_hot_axis(one_hot_axis)
{
......
......@@ -45,14 +45,17 @@ namespace ngraph
class OneHot : public Op
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a one-hot operation.
OneHot() = default;
/// \brief Constructs a one-hot operation.
///
/// \param arg Node that produces the input tensor to be one-hot encoded.
/// \param shape The shape of the output tensor, including the new one-hot axis.
/// \param one_hot_axis The index within the output shape of the new one-hot axis.
OneHot(const std::shared_ptr<Node>& arg,
const PartialShape& shape,
size_t one_hot_axis);
OneHot(const Output<Node>& arg, const PartialShape& shape, size_t one_hot_axis);
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
......@@ -60,6 +63,7 @@ namespace ngraph
/// \return The index of the one-hot axis.
size_t get_one_hot_axis() const { return m_one_hot_axis; }
void set_one_hot_axis(size_t one_hot_axis) { m_one_hot_axis = one_hot_axis; }
protected:
PartialShape m_shape;
size_t m_one_hot_axis;
......
......@@ -19,10 +19,10 @@
using namespace std;
using namespace ngraph;
op::Or::Or(const shared_ptr<Node>& arg0,
const shared_ptr<Node>& arg1,
const AutoBroadcastSpec& autob)
: BinaryElementwiseLogical("Or", arg0, arg1, autob)
const string op::Or::type_name{"Or"};
op::Or::Or(const Output<Node>& arg0, const Output<Node>& arg1, const AutoBroadcastSpec& autob)
: BinaryElementwiseLogical(arg0, arg1, autob)
{
constructor_validate_and_infer_types();
}
......
......@@ -29,6 +29,9 @@ namespace ngraph
class Or : public util::BinaryElementwiseLogical
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a logical-or operation.
///
/// \param arg0 Node that produces the first input tensor.<br>
......@@ -39,15 +42,14 @@ namespace ngraph
///
/// Output `[d0, ...]`
///
Or(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1,
Or(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& autob = AutoBroadcastSpec());
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
protected:
virtual bool is_commutative() override { return true; }
virtual bool is_commutative() const override { return true; }
};
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment