Unverified Commit 61c7ba8c authored by Tomasz Dołbniak's avatar Tomasz Dołbniak Committed by GitHub

Merge branch 'master' into tomdol/pycapsule

parents 8c40ab98 feefdbb2
...@@ -16,9 +16,8 @@ ...@@ -16,9 +16,8 @@
import numpy as np import numpy as np
import ngraph as ng import ngraph as ng
from ngraph.utils.types import NumericData
from string import ascii_uppercase from typing import Any, Callable, List
import test import test
...@@ -32,10 +31,14 @@ def get_runtime(): ...@@ -32,10 +31,14 @@ def get_runtime():
def run_op_node(input_data, op_fun, *args): def run_op_node(input_data, op_fun, *args):
# type: (NumericData, Callable, *Any) -> List[NumericData]
"""Run computation on node performing `op_fun`. """Run computation on node performing `op_fun`.
`op_fun` has to accept a node as an argument. `op_fun` has to accept a node as an argument.
This function converts passed raw input data to nGraph Constant Node and that form is passed
to `op_fun`.
:param input_data: The input data for performed computation. :param input_data: The input data for performed computation.
:param op_fun: The function handler for operation we want to carry out. :param op_fun: The function handler for operation we want to carry out.
:param args: The arguments passed to operation we want to carry out. :param args: The arguments passed to operation we want to carry out.
...@@ -45,14 +48,8 @@ def run_op_node(input_data, op_fun, *args): ...@@ -45,14 +48,8 @@ def run_op_node(input_data, op_fun, *args):
comp_args = [] comp_args = []
op_fun_args = [] op_fun_args = []
comp_inputs = [] comp_inputs = []
for idx, data in enumerate(input_data): for data in input_data:
if np.isscalar(data): op_fun_args.append(ng.constant(data, _get_numpy_dtype(data)))
op_fun_args.append(ng.constant(data, _get_numpy_dtype(data)))
else:
node = ng.parameter(data.shape, name=ascii_uppercase[idx], dtype=data.dtype)
op_fun_args.append(node)
comp_args.append(node)
comp_inputs.append(data)
op_fun_args.extend(args) op_fun_args.extend(args)
node = op_fun(*op_fun_args) node = op_fun(*op_fun_args)
computation = runtime.computation(node, *comp_args) computation = runtime.computation(node, *comp_args)
...@@ -60,10 +57,15 @@ def run_op_node(input_data, op_fun, *args): ...@@ -60,10 +57,15 @@ def run_op_node(input_data, op_fun, *args):
def run_op_numeric_data(input_data, op_fun, *args): def run_op_numeric_data(input_data, op_fun, *args):
# type: (NumericData, Callable, *Any) -> List[NumericData]
"""Run computation on node performing `op_fun`. """Run computation on node performing `op_fun`.
`op_fun` has to accept a scalar or an array. `op_fun` has to accept a scalar or an array.
This function passess input data AS IS. This mean that in case they're a scalar (integral,
or floating point value) or a NumPy's ndarray object they will be automatically converted
to nGraph's Constant Nodes.
:param input_data: The input data for performed computation. :param input_data: The input data for performed computation.
:param op_fun: The function handler for operation we want to carry out. :param op_fun: The function handler for operation we want to carry out.
:param args: The arguments passed to operation we want to carry out. :param args: The arguments passed to operation we want to carry out.
......
...@@ -350,6 +350,7 @@ namespace ...@@ -350,6 +350,7 @@ namespace
} }
return callBackFuncPtr; return callBackFuncPtr;
} }
// NGDialect converters // NGDialect converters
Type NGraphTypeConverter::convertType(Type type) Type NGraphTypeConverter::convertType(Type type)
{ {
...@@ -576,7 +577,6 @@ namespace ...@@ -576,7 +577,6 @@ namespace
// Create Value for result, and extract type info. // Create Value for result, and extract type info.
Value* result = m_pass.buildOutputDefs(op, rewriter)[0]; Value* result = m_pass.buildOutputDefs(op, rewriter)[0];
NGRAPH_CHECK(result, "Unexpected null result in ConcatOp"); NGRAPH_CHECK(result, "Unexpected null result in ConcatOp");
auto resultTy = result->getType().cast<MemRefType>();
// Create view to write into result. // Create view to write into result.
MemRefView vRes(result); MemRefView vRes(result);
...@@ -590,7 +590,6 @@ namespace ...@@ -590,7 +590,6 @@ namespace
for (auto& operand : operands) for (auto& operand : operands)
{ {
NGRAPH_CHECK(operand, "Unexpected null operand in ConcatOp"); NGRAPH_CHECK(operand, "Unexpected null operand in ConcatOp");
auto operandTy = result->getType().cast<MemRefType>();
// Assuming rank = r, and the concatenation axis is A where A<r, we'll be creating // Assuming rank = r, and the concatenation axis is A where A<r, we'll be creating
// loops of this form: // loops of this form:
......
...@@ -74,7 +74,6 @@ void MLIRSubgraphExtractionPass::MLIRSubgraph::merge(MLIRSubgraph& sg2) ...@@ -74,7 +74,6 @@ void MLIRSubgraphExtractionPass::MLIRSubgraph::merge(MLIRSubgraph& sg2)
// Associate nodes of second sub-graph to first one // Associate nodes of second sub-graph to first one
auto sg_nodes = sg2.get_nodes(); auto sg_nodes = sg2.get_nodes();
auto& node_map = m_pass.m_node_to_graph;
for (auto node : sg_nodes) for (auto node : sg_nodes)
{ {
NGRAPH_DEBUG << *node; NGRAPH_DEBUG << *node;
...@@ -112,7 +111,6 @@ bool MLIRSubgraphExtractionPass::run_on_function(std::shared_ptr<Function> func) ...@@ -112,7 +111,6 @@ bool MLIRSubgraphExtractionPass::run_on_function(std::shared_ptr<Function> func)
for (auto op : func->get_ordered_ops()) for (auto op : func->get_ordered_ops())
{ {
NodeVector inputs; NodeVector inputs;
int first_graph_id = -1;
std::unordered_set<int> subgraph_ids; std::unordered_set<int> subgraph_ids;
// unsupported ops, skip // unsupported ops, skip
if (!is_supported_mlir_op(op)) if (!is_supported_mlir_op(op))
......
...@@ -160,5 +160,5 @@ namespace ngraph ...@@ -160,5 +160,5 @@ namespace ngraph
/// \brief Macro to signal a code path that is unreachable in a successful execution. It's /// \brief Macro to signal a code path that is unreachable in a successful execution. It's
/// implemented with NGRAPH_CHECK macro. /// implemented with NGRAPH_CHECK macro.
/// \param ... Additional error message that should describe why that execution path is unreachable. /// \param ... Additional error message that should describe why that execution path is unreachable.
/// \throws ::ngrap::CheckFailure if the macro is executed. /// \throws ::ngraph::CheckFailure if the macro is executed.
#define NGRAPH_UNREACHABLE(...) NGRAPH_CHECK(false, "Unreachable: ", ##__VA_ARGS__) #define NGRAPH_UNREACHABLE(...) NGRAPH_CHECK(false, "Unreachable: ", ##__VA_ARGS__)
...@@ -214,7 +214,7 @@ namespace ngraph ...@@ -214,7 +214,7 @@ namespace ngraph
virtual bool is_constant() const; virtual bool is_constant() const;
virtual bool is_null() const { return false; } virtual bool is_null() const { return false; }
virtual bool is_op() const { return false; } virtual bool is_op() const { return false; }
virtual bool is_commutative() { return false; } virtual bool is_commutative() const { return false; }
virtual bool is_dynamic() const; virtual bool is_dynamic() const;
virtual bool has_state() const { return false; } virtual bool has_state() const { return false; }
size_t get_instance_id() const { return m_instance_id; } size_t get_instance_id() const { return m_instance_id; }
......
...@@ -49,7 +49,7 @@ void op::Add::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& ...@@ -49,7 +49,7 @@ void op::Add::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector&
adjoints.add_delta(y, delta); adjoints.add_delta(y, delta);
} }
shared_ptr<Node> ngraph::operator+(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg1) shared_ptr<Node> ngraph::operator+(const Output<Node>& arg0, const Output<Node>& arg1)
{ {
return make_shared<op::Add>(arg0, arg1); return make_shared<op::Add>(arg0, arg1);
} }
...@@ -51,13 +51,12 @@ namespace ngraph ...@@ -51,13 +51,12 @@ namespace ngraph
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override; std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
virtual bool is_commutative() const override { return true; }
protected: protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override; const NodeVector& deltas) override;
virtual bool is_commutative() override { return true; }
}; };
} }
std::shared_ptr<ngraph::Node> operator+(const std::shared_ptr<ngraph::Node>& arg0, std::shared_ptr<Node> operator+(const Output<Node>& arg0, const Output<Node>& arg1);
const std::shared_ptr<ngraph::Node>& arg1);
} }
...@@ -51,8 +51,7 @@ namespace ngraph ...@@ -51,8 +51,7 @@ namespace ngraph
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override; std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
protected: virtual bool is_commutative() const override { return true; }
virtual bool is_commutative() override { return true; }
}; };
} }
} }
...@@ -22,12 +22,15 @@ ...@@ -22,12 +22,15 @@
#include "ngraph/op/get_output_element.hpp" #include "ngraph/op/get_output_element.hpp"
#include "ngraph/validation_util.hpp" #include "ngraph/validation_util.hpp"
const std::string ngraph::op::BatchNormTraining::type_name{"BatchNormTraining"}; using namespace std;
using namespace ngraph;
ngraph::op::BatchNormTraining::BatchNormTraining(Output<ngraph::Node> input, const string op::BatchNormTraining::type_name{"BatchNormTraining"};
Output<ngraph::Node> gamma,
Output<ngraph::Node> beta, op::BatchNormTraining::BatchNormTraining(const Output<Node>& input,
double epsilon) const Output<Node>& gamma,
const Output<Node>& beta,
double epsilon)
: Op({gamma, beta, input}) : Op({gamma, beta, input})
, m_epsilon(epsilon) , m_epsilon(epsilon)
{ {
...@@ -35,17 +38,17 @@ ngraph::op::BatchNormTraining::BatchNormTraining(Output<ngraph::Node> input, ...@@ -35,17 +38,17 @@ ngraph::op::BatchNormTraining::BatchNormTraining(Output<ngraph::Node> input,
} }
// DEPRECATED // DEPRECATED
ngraph::op::BatchNormTraining::BatchNormTraining(double eps, op::BatchNormTraining::BatchNormTraining(double eps,
Output<ngraph::Node> gamma, const Output<Node>& gamma,
Output<ngraph::Node> beta, const Output<Node>& beta,
Output<ngraph::Node> input) const Output<Node>& input)
: Op({gamma, beta, input}) : Op({gamma, beta, input})
, m_epsilon(eps) , m_epsilon(eps)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
void ngraph::op::BatchNormTraining::validate_and_infer_types() void op::BatchNormTraining::validate_and_infer_types()
{ {
element::Type result_et; element::Type result_et;
PartialShape result_batch_shape; PartialShape result_batch_shape;
...@@ -66,16 +69,15 @@ void ngraph::op::BatchNormTraining::validate_and_infer_types() ...@@ -66,16 +69,15 @@ void ngraph::op::BatchNormTraining::validate_and_infer_types()
set_output_type(2, result_et, result_channel_shape); set_output_type(2, result_et, result_channel_shape);
} }
std::shared_ptr<ngraph::Node> std::shared_ptr<Node> op::BatchNormTraining::copy_with_new_args(const NodeVector& new_args) const
ngraph::op::BatchNormTraining::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
return std::make_shared<BatchNormTraining>( return std::make_shared<BatchNormTraining>(
new_args.at(2), new_args.at(0), new_args.at(1), m_epsilon); new_args.at(2), new_args.at(0), new_args.at(1), m_epsilon);
} }
void ngraph::op::BatchNormTraining::generate_adjoints(autodiff::Adjoints& adjoints, void op::BatchNormTraining::generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) const NodeVector& deltas)
{ {
auto gamma = input(0).get_source_output(); auto gamma = input(0).get_source_output();
auto beta = input(1).get_source_output(); auto beta = input(1).get_source_output();
...@@ -102,14 +104,14 @@ void ngraph::op::BatchNormTraining::generate_adjoints(autodiff::Adjoints& adjoin ...@@ -102,14 +104,14 @@ void ngraph::op::BatchNormTraining::generate_adjoints(autodiff::Adjoints& adjoin
adjoints.add_delta(beta, dbeta); adjoints.add_delta(beta, dbeta);
} }
const std::string ngraph::op::BatchNormInference::type_name{"BatchNormInference"}; const string op::BatchNormInference::type_name{"BatchNormInference"};
ngraph::op::BatchNormInference::BatchNormInference(Output<ngraph::Node> input, op::BatchNormInference::BatchNormInference(const Output<Node>& input,
Output<ngraph::Node> gamma, const Output<Node>& gamma,
Output<ngraph::Node> beta, const Output<Node>& beta,
Output<ngraph::Node> mean, const Output<Node>& mean,
Output<ngraph::Node> variance, const Output<Node>& variance,
double epsilon) double epsilon)
: Op({gamma, beta, input, mean, variance}) : Op({gamma, beta, input, mean, variance})
, m_epsilon(epsilon) , m_epsilon(epsilon)
{ {
...@@ -117,19 +119,19 @@ ngraph::op::BatchNormInference::BatchNormInference(Output<ngraph::Node> input, ...@@ -117,19 +119,19 @@ ngraph::op::BatchNormInference::BatchNormInference(Output<ngraph::Node> input,
} }
// DEPRECATED // DEPRECATED
ngraph::op::BatchNormInference::BatchNormInference(double eps, op::BatchNormInference::BatchNormInference(double eps,
Output<ngraph::Node> gamma, const Output<Node>& gamma,
Output<ngraph::Node> beta, const Output<Node>& beta,
Output<ngraph::Node> input, const Output<Node>& input,
Output<ngraph::Node> mean, const Output<Node>& mean,
Output<ngraph::Node> variance) const Output<Node>& variance)
: Op({gamma, beta, input, mean, variance}) : Op({gamma, beta, input, mean, variance})
, m_epsilon(eps) , m_epsilon(eps)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
void ngraph::op::BatchNormInference::validate_and_infer_types() void op::BatchNormInference::validate_and_infer_types()
{ {
element::Type result_et; element::Type result_et;
PartialShape result_batch_shape; PartialShape result_batch_shape;
...@@ -152,23 +154,22 @@ void ngraph::op::BatchNormInference::validate_and_infer_types() ...@@ -152,23 +154,22 @@ void ngraph::op::BatchNormInference::validate_and_infer_types()
set_output_type(0, result_et, result_batch_shape); set_output_type(0, result_et, result_batch_shape);
} }
std::shared_ptr<ngraph::Node> std::shared_ptr<Node> op::BatchNormInference::copy_with_new_args(const NodeVector& new_args) const
ngraph::op::BatchNormInference::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
return std::make_shared<BatchNormInference>( return std::make_shared<BatchNormInference>(
new_args.at(2), new_args.at(0), new_args.at(1), new_args.at(3), new_args.at(4), m_epsilon); new_args.at(2), new_args.at(0), new_args.at(1), new_args.at(3), new_args.at(4), m_epsilon);
} }
const std::string ngraph::op::BatchNormTrainingBackprop::type_name{"BatchNormTrainingBackprop"}; const string op::BatchNormTrainingBackprop::type_name{"BatchNormTrainingBackprop"};
ngraph::op::BatchNormTrainingBackprop::BatchNormTrainingBackprop(Output<ngraph::Node> input, op::BatchNormTrainingBackprop::BatchNormTrainingBackprop(const Output<Node>& input,
Output<ngraph::Node> gamma, const Output<Node>& gamma,
Output<ngraph::Node> beta, const Output<Node>& beta,
Output<ngraph::Node> mean, const Output<Node>& mean,
Output<ngraph::Node> variance, const Output<Node>& variance,
Output<ngraph::Node> delta, const Output<Node>& delta,
double epsilon) double epsilon)
: Op({gamma, beta, input, mean, variance, delta}) : Op({gamma, beta, input, mean, variance, delta})
, m_epsilon(epsilon) , m_epsilon(epsilon)
...@@ -177,13 +178,13 @@ ngraph::op::BatchNormTrainingBackprop::BatchNormTrainingBackprop(Output<ngraph:: ...@@ -177,13 +178,13 @@ ngraph::op::BatchNormTrainingBackprop::BatchNormTrainingBackprop(Output<ngraph::
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
ngraph::op::BatchNormTrainingBackprop::BatchNormTrainingBackprop(double epsilon, op::BatchNormTrainingBackprop::BatchNormTrainingBackprop(double epsilon,
Output<ngraph::Node> gamma, const Output<Node>& gamma,
Output<ngraph::Node> beta, const Output<Node>& beta,
Output<ngraph::Node> input, const Output<Node>& input,
Output<ngraph::Node> mean, const Output<Node>& mean,
Output<ngraph::Node> variance, const Output<Node>& variance,
Output<ngraph::Node> delta) const Output<Node>& delta)
: Op({gamma, beta, input, mean, variance, delta}) : Op({gamma, beta, input, mean, variance, delta})
, m_epsilon(epsilon) , m_epsilon(epsilon)
...@@ -192,7 +193,7 @@ ngraph::op::BatchNormTrainingBackprop::BatchNormTrainingBackprop(double epsilon, ...@@ -192,7 +193,7 @@ ngraph::op::BatchNormTrainingBackprop::BatchNormTrainingBackprop(double epsilon,
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
void ngraph::op::BatchNormTrainingBackprop::validate_and_infer_types() void op::BatchNormTrainingBackprop::validate_and_infer_types()
{ {
PartialShape input_and_delta_shape{get_input_partial_shape(INPUT_DATA)}; PartialShape input_and_delta_shape{get_input_partial_shape(INPUT_DATA)};
...@@ -239,8 +240,8 @@ void ngraph::op::BatchNormTrainingBackprop::validate_and_infer_types() ...@@ -239,8 +240,8 @@ void ngraph::op::BatchNormTrainingBackprop::validate_and_infer_types()
set_output_type(2, result_et, result_channel_shape); set_output_type(2, result_et, result_channel_shape);
} }
std::shared_ptr<ngraph::Node> std::shared_ptr<Node>
ngraph::op::BatchNormTrainingBackprop::copy_with_new_args(const NodeVector& new_args) const op::BatchNormTrainingBackprop::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
return std::make_shared<op::BatchNormTrainingBackprop>(new_args.at(2), return std::make_shared<op::BatchNormTrainingBackprop>(new_args.at(2),
......
...@@ -39,9 +39,9 @@ namespace ngraph ...@@ -39,9 +39,9 @@ namespace ngraph
/// \param gamma gamma scaling for normalized value. [C] /// \param gamma gamma scaling for normalized value. [C]
/// \param beta bias added to the scaled normalized value [C] /// \param beta bias added to the scaled normalized value [C]
/// \param epsilon Avoids divsion by 0 if input has 0 variance /// \param epsilon Avoids divsion by 0 if input has 0 variance
BatchNormTraining(Output<Node> input, BatchNormTraining(const Output<Node>& input,
Output<Node> gamma, const Output<Node>& gamma,
Output<Node> beta, const Output<Node>& beta,
double epsilon); double epsilon);
NGRAPH_DEPRECATED_DOC NGRAPH_DEPRECATED_DOC
...@@ -66,9 +66,9 @@ namespace ngraph ...@@ -66,9 +66,9 @@ namespace ngraph
/// output[2]: shall have rank 1, with the same span as input's channel axis. /// output[2]: shall have rank 1, with the same span as input's channel axis.
NGRAPH_DEPRECATED("Use another constructor") NGRAPH_DEPRECATED("Use another constructor")
BatchNormTraining(double eps, BatchNormTraining(double eps,
Output<Node> gamma, const Output<Node>& gamma,
Output<Node> beta, const Output<Node>& beta,
Output<Node> input); const Output<Node>& input);
void validate_and_infer_types() override; void validate_and_infer_types() override;
...@@ -101,11 +101,11 @@ namespace ngraph ...@@ -101,11 +101,11 @@ namespace ngraph
/// \param mean value for mean normalization [C] /// \param mean value for mean normalization [C]
/// \param variance value for variance normalization [C] /// \param variance value for variance normalization [C]
/// \param epsilon Avoids divsion by 0 if input has 0 variance /// \param epsilon Avoids divsion by 0 if input has 0 variance
BatchNormInference(Output<ngraph::Node> input, BatchNormInference(const Output<Node>& input,
Output<ngraph::Node> gamma, const Output<Node>& gamma,
Output<ngraph::Node> beta, const Output<Node>& beta,
Output<ngraph::Node> mean, const Output<Node>& mean,
Output<ngraph::Node> variance, const Output<Node>& variance,
double epsilon); double epsilon);
NGRAPH_DEPRECATED_DOC NGRAPH_DEPRECATED_DOC
...@@ -128,11 +128,11 @@ namespace ngraph ...@@ -128,11 +128,11 @@ namespace ngraph
/// output: shall have the same shape as 'input'. /// output: shall have the same shape as 'input'.
NGRAPH_DEPRECATED("Use another constructor") NGRAPH_DEPRECATED("Use another constructor")
BatchNormInference(double eps, BatchNormInference(double eps,
Output<ngraph::Node> gamma, const Output<Node>& gamma,
Output<ngraph::Node> beta, const Output<Node>& beta,
Output<ngraph::Node> input, const Output<Node>& input,
Output<ngraph::Node> mean, const Output<Node>& mean,
Output<ngraph::Node> variance); const Output<Node>& variance);
void validate_and_infer_types() override; void validate_and_infer_types() override;
...@@ -165,24 +165,23 @@ namespace ngraph ...@@ -165,24 +165,23 @@ namespace ngraph
static const std::string type_name; static const std::string type_name;
const std::string& description() const override { return type_name; } const std::string& description() const override { return type_name; }
BatchNormTrainingBackprop() = default; BatchNormTrainingBackprop() = default;
BatchNormTrainingBackprop(Output<Node> input, BatchNormTrainingBackprop(const Output<Node>& input,
Output<Node> gamma, const Output<Node>& gamma,
Output<Node> beta, const Output<Node>& beta,
Output<Node> mean, const Output<Node>& mean,
Output<Node> variance, const Output<Node>& variance,
Output<Node> delta, const Output<Node>& delta,
double epsilon); double epsilon);
NGRAPH_DEPRECATED_DOC NGRAPH_DEPRECATED_DOC
NGRAPH_DEPRECATED("Use another constructor") NGRAPH_DEPRECATED("Use another constructor")
BatchNormTrainingBackprop(double epsilon, BatchNormTrainingBackprop(double epsilon,
Output<Node> gamma, const Output<Node>& gamma,
Output<Node> beta, const Output<Node>& beta,
Output<Node> input, const Output<Node>& input,
const Output<Node>& mean,
Output<Node> mean, const Output<Node>& variance,
Output<Node> variance, const Output<Node>& delta);
Output<Node> delta);
void validate_and_infer_types() override; void validate_and_infer_types() override;
......
...@@ -64,7 +64,7 @@ void op::Divide::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVecto ...@@ -64,7 +64,7 @@ void op::Divide::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVecto
adjoints.add_delta(y, -delta * shared_from_this() / y); adjoints.add_delta(y, -delta * shared_from_this() / y);
} }
shared_ptr<Node> ngraph::operator/(const Output<Node> arg0, const Output<Node> arg1) shared_ptr<Node> ngraph::operator/(const Output<Node>& arg0, const Output<Node>& arg1)
{ {
return make_shared<op::Divide>(arg0, arg1); return make_shared<op::Divide>(arg0, arg1);
} }
...@@ -64,6 +64,5 @@ namespace ngraph ...@@ -64,6 +64,5 @@ namespace ngraph
}; };
} }
std::shared_ptr<ngraph::Node> operator/(const Output<ngraph::Node> arg0, std::shared_ptr<Node> operator/(const Output<Node>& arg0, const Output<Node>& arg1);
const Output<ngraph::Node> arg1);
} }
...@@ -58,7 +58,7 @@ namespace ngraph ...@@ -58,7 +58,7 @@ namespace ngraph
void validate_and_infer_types() override; void validate_and_infer_types() override;
size_t get_reduction_axes_count() const { return m_reduction_axes_count; } size_t get_reduction_axes_count() const { return m_reduction_axes_count; }
void get_reduction_axes_count(size_t reduction_axes_count) void set_reduction_axes_count(size_t reduction_axes_count)
{ {
m_reduction_axes_count = reduction_axes_count; m_reduction_axes_count = reduction_axes_count;
} }
......
...@@ -56,6 +56,8 @@ namespace ngraph ...@@ -56,6 +56,8 @@ namespace ngraph
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
virtual bool is_commutative() const override { return true; }
}; };
} }
} }
...@@ -23,6 +23,8 @@ using namespace ngraph; ...@@ -23,6 +23,8 @@ using namespace ngraph;
static int PARAMS = 0; static int PARAMS = 0;
static int INDICES = 1; static int INDICES = 1;
const string op::Gather::type_name{"Gather"};
shared_ptr<Node> op::Gather::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Gather::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
......
...@@ -26,13 +26,15 @@ namespace ngraph ...@@ -26,13 +26,15 @@ namespace ngraph
class Gather : public Op class Gather : public Op
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
Gather() = default;
/// \param params The tensor from which slices are gathered /// \param params The tensor from which slices are gathered
/// \param indices Index tensor: Data type must be `element::i32` or `element::i64` /// \param indices Index tensor: Data type must be `element::i32` or `element::i64`
/// \param axis Axis in params to gather /// \param axis Axis in params to gather
Gather(const std::shared_ptr<Node>& params, Gather(const Output<Node>& params, const Output<Node>& indices, size_t axis = 0)
const std::shared_ptr<Node>& indices, : Op({params, indices})
size_t axis = 0)
: Op("Gather", check_single_output_args({params, indices}))
, m_axis(axis) , m_axis(axis)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
...@@ -46,6 +48,7 @@ namespace ngraph ...@@ -46,6 +48,7 @@ namespace ngraph
} }
size_t get_axis() const { return m_axis; } size_t get_axis() const { return m_axis; }
void set_axis(size_t axis) { m_axis = axis; }
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
......
...@@ -23,6 +23,8 @@ using namespace ngraph; ...@@ -23,6 +23,8 @@ using namespace ngraph;
static int PARAMS = 0; static int PARAMS = 0;
static int INDICES = 1; static int INDICES = 1;
const string op::GatherND::type_name{"GatherND"};
shared_ptr<Node> op::GatherND::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::GatherND::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
......
...@@ -26,10 +26,14 @@ namespace ngraph ...@@ -26,10 +26,14 @@ namespace ngraph
class GatherND : public Op class GatherND : public Op
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
GatherND() = default;
/// \param params The tensor from which slices are gathered /// \param params The tensor from which slices are gathered
/// \param indices Index tensor: Data type must be `element::i32` or `element::i64` /// \param indices Index tensor: Data type must be `element::i32` or `element::i64`
GatherND(const std::shared_ptr<Node>& params, const std::shared_ptr<Node>& indices) GatherND(const Output<Node>& params, const Output<Node>& indices)
: Op("GatherND", check_single_output_args({params, indices})) : Op({params, indices})
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
......
...@@ -19,10 +19,12 @@ ...@@ -19,10 +19,12 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::Greater::Greater(const shared_ptr<Node>& arg0, const string op::Greater::type_name{"Greater"};
const shared_ptr<Node>& arg1,
op::Greater::Greater(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& autob) const AutoBroadcastSpec& autob)
: BinaryElementwiseComparison("Greater", arg0, arg1, autob) : BinaryElementwiseComparison(arg0, arg1, autob)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
......
...@@ -26,13 +26,18 @@ namespace ngraph ...@@ -26,13 +26,18 @@ namespace ngraph
class Greater : public util::BinaryElementwiseComparison class Greater : public util::BinaryElementwiseComparison
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a greater-than operation.
Greater() = default;
/// \brief Constructs a greater-than operation. /// \brief Constructs a greater-than operation.
/// ///
/// \param arg0 Node that produces the first input tensor. /// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor. /// \param arg1 Node that produces the second input tensor.
/// \param autob Auto broadcast specification /// \param autob Auto broadcast specification
Greater(const std::shared_ptr<Node>& arg0, Greater(const Output<Node>& arg0,
const std::shared_ptr<Node>& arg1, const Output<Node>& arg1,
const AutoBroadcastSpec& autob = AutoBroadcastSpec()); const AutoBroadcastSpec& autob = AutoBroadcastSpec());
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
......
...@@ -19,10 +19,12 @@ ...@@ -19,10 +19,12 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::GreaterEq::GreaterEq(const shared_ptr<Node>& arg0, const string op::GreaterEq::type_name{"GreaterEq"};
const shared_ptr<Node>& arg1,
op::GreaterEq::GreaterEq(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& autob) const AutoBroadcastSpec& autob)
: BinaryElementwiseComparison("GreaterEq", arg0, arg1, autob) : BinaryElementwiseComparison(arg0, arg1, autob)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
......
...@@ -26,13 +26,18 @@ namespace ngraph ...@@ -26,13 +26,18 @@ namespace ngraph
class GreaterEq : public util::BinaryElementwiseComparison class GreaterEq : public util::BinaryElementwiseComparison
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a greater-than-or-equal operation.
GreaterEq() = default;
/// \brief Constructs a greater-than-or-equal operation. /// \brief Constructs a greater-than-or-equal operation.
/// ///
/// \param arg0 Node that produces the first input tensor. /// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor. /// \param arg1 Node that produces the second input tensor.
/// \param autob Auto broadcast specification /// \param autob Auto broadcast specification
GreaterEq(const std::shared_ptr<Node>& arg0, GreaterEq(const Output<Node>& arg0,
const std::shared_ptr<Node>& arg1, const Output<Node>& arg1,
const AutoBroadcastSpec& autob = AutoBroadcastSpec()); const AutoBroadcastSpec& autob = AutoBroadcastSpec());
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
......
...@@ -19,10 +19,10 @@ ...@@ -19,10 +19,10 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::Less::Less(const shared_ptr<Node>& arg0, const string op::Less::type_name{"Less"};
const shared_ptr<Node>& arg1,
const AutoBroadcastSpec& autob) op::Less::Less(const Output<Node>& arg0, const Output<Node>& arg1, const AutoBroadcastSpec& autob)
: BinaryElementwiseComparison("Less", arg0, arg1, autob) : BinaryElementwiseComparison(arg0, arg1, autob)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
......
...@@ -26,13 +26,18 @@ namespace ngraph ...@@ -26,13 +26,18 @@ namespace ngraph
class Less : public util::BinaryElementwiseComparison class Less : public util::BinaryElementwiseComparison
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a less-than operation.
Less() = default;
/// \brief Constructs a less-than operation. /// \brief Constructs a less-than operation.
/// ///
/// \param arg0 Node that produces the first input tensor. /// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor. /// \param arg1 Node that produces the second input tensor.
/// \param autob Auto broadcast specification /// \param autob Auto broadcast specification
Less(const std::shared_ptr<Node>& arg0, Less(const Output<Node>& arg0,
const std::shared_ptr<Node>& arg1, const Output<Node>& arg1,
const AutoBroadcastSpec& autob = AutoBroadcastSpec()); const AutoBroadcastSpec& autob = AutoBroadcastSpec());
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
......
...@@ -19,10 +19,12 @@ ...@@ -19,10 +19,12 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::LessEq::LessEq(const shared_ptr<Node>& arg0, const string op::LessEq::type_name{"LessEq"};
const shared_ptr<Node>& arg1,
op::LessEq::LessEq(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& autob) const AutoBroadcastSpec& autob)
: BinaryElementwiseComparison("LessEq", arg0, arg1, autob) : BinaryElementwiseComparison(arg0, arg1, autob)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
......
...@@ -26,13 +26,18 @@ namespace ngraph ...@@ -26,13 +26,18 @@ namespace ngraph
class LessEq : public util::BinaryElementwiseComparison class LessEq : public util::BinaryElementwiseComparison
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a less-than-or-equal operation.
LessEq() = default;
/// \brief Constructs a less-than-or-equal operation. /// \brief Constructs a less-than-or-equal operation.
/// ///
/// \param arg0 Node that produces the first input tensor. /// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor. /// \param arg1 Node that produces the second input tensor.
/// \param autob Auto broadcast specification /// \param autob Auto broadcast specification
LessEq(const std::shared_ptr<Node>& arg0, LessEq(const Output<Node>& arg0,
const std::shared_ptr<Node>& arg1, const Output<Node>& arg1,
const AutoBroadcastSpec& autob = AutoBroadcastSpec()); const AutoBroadcastSpec& autob = AutoBroadcastSpec());
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
......
...@@ -20,8 +20,10 @@ ...@@ -20,8 +20,10 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::Log::Log(const shared_ptr<Node>& arg) const string op::Log::type_name{"Log"};
: UnaryElementwiseArithmetic("Log", arg)
op::Log::Log(const Output<Node>& arg)
: UnaryElementwiseArithmetic(arg)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
......
...@@ -26,10 +26,15 @@ namespace ngraph ...@@ -26,10 +26,15 @@ namespace ngraph
class Log : public util::UnaryElementwiseArithmetic class Log : public util::UnaryElementwiseArithmetic
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a natural log operation.
Log() = default;
/// \brief Constructs a natural log operation. /// \brief Constructs a natural log operation.
/// ///
/// \param arg Node that produces the input tensor. /// \param arg Node that produces the input tensor.
Log(const std::shared_ptr<Node>& arg); Log(const Output<Node>& arg);
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
......
...@@ -20,12 +20,14 @@ ...@@ -20,12 +20,14 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::LRN::LRN(const std::shared_ptr<Node>& arg, double alpha, double beta, double bias, size_t nsize) const string op::LRN::type_name{"LRN"};
: UnaryElementwiseArithmetic("LRN", arg)
op::LRN::LRN(const Output<Node>& arg, double alpha, double beta, double bias, size_t size)
: UnaryElementwiseArithmetic(arg)
, m_alpha(alpha) , m_alpha(alpha)
, m_beta(beta) , m_beta(beta)
, m_bias(bias) , m_bias(bias)
, m_size(nsize) , m_size(size)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
......
...@@ -38,23 +38,28 @@ namespace ngraph ...@@ -38,23 +38,28 @@ namespace ngraph
class LRN : public util::UnaryElementwiseArithmetic class LRN : public util::UnaryElementwiseArithmetic
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a LRN operation.
LRN() = default;
/// \brief Constructs a LRN operation. /// \brief Constructs a LRN operation.
/// ///
/// \param arg Node that produces the input tensor. /// \param arg Node that produces the input tensor.
LRN(const std::shared_ptr<Node>& arg, LRN(const Output<Node>& arg, double alpha, double beta, double bias, size_t size);
double alpha,
double beta,
double bias,
size_t size);
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
void validate_and_infer_types() override; void validate_and_infer_types() override;
double get_alpha() const { return m_alpha; } double get_alpha() const { return m_alpha; }
void set_alpha(double alpha) { m_alpha = alpha; }
double get_beta() const { return m_beta; } double get_beta() const { return m_beta; }
void set_beta(double beta) { m_beta = beta; }
double get_bias() const { return m_bias; } double get_bias() const { return m_bias; }
void set_bias(double bias) { m_bias = bias; }
size_t get_nsize() const { return m_size; } size_t get_nsize() const { return m_size; }
void set_nsize(size_t size) { m_size = size; }
protected: protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override; const NodeVector& deltas) override;
......
...@@ -22,10 +22,6 @@ using namespace ngraph; ...@@ -22,10 +22,6 @@ using namespace ngraph;
const string op::Max::type_name{"Max"}; const string op::Max::type_name{"Max"};
op::Max::Max()
{
}
op::Max::Max(const Output<Node>& arg, const AxisSet& reduction_axes) op::Max::Max(const Output<Node>& arg, const AxisSet& reduction_axes)
: ArithmeticReduction(arg, reduction_axes) : ArithmeticReduction(arg, reduction_axes)
{ {
......
...@@ -30,7 +30,7 @@ namespace ngraph ...@@ -30,7 +30,7 @@ namespace ngraph
static const std::string type_name; static const std::string type_name;
const std::string& description() const override { return type_name; } const std::string& description() const override { return type_name; }
/// \brief Constructs a "max" reduction operation. /// \brief Constructs a "max" reduction operation.
Max(); Max() = default;
/// \brief Constructs a max-reduction operation. /// \brief Constructs a max-reduction operation.
/// ///
/// \param arg The tensor to be reduced. /// \param arg The tensor to be reduced.
......
...@@ -25,14 +25,16 @@ ...@@ -25,14 +25,16 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::MaxPool::MaxPool(const shared_ptr<Node>& arg, const string op::MaxPool::type_name{"MaxPool"};
op::MaxPool::MaxPool(const Output<Node>& arg,
const Shape& window_shape, const Shape& window_shape,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Shape& padding_below, const Shape& padding_below,
const Shape& padding_above, const Shape& padding_above,
const PadType& pad_type, const PadType& pad_type,
bool ceil_mode) bool ceil_mode)
: Op("MaxPool", check_single_output_args({arg})) : Op({arg})
, m_window_shape(window_shape) , m_window_shape(window_shape)
, m_window_movement_strides(window_movement_strides) , m_window_movement_strides(window_movement_strides)
, m_padding_below(padding_below) , m_padding_below(padding_below)
...@@ -43,7 +45,7 @@ op::MaxPool::MaxPool(const shared_ptr<Node>& arg, ...@@ -43,7 +45,7 @@ op::MaxPool::MaxPool(const shared_ptr<Node>& arg,
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
op::MaxPool::MaxPool(const shared_ptr<Node>& arg, op::MaxPool::MaxPool(const Output<Node>& arg,
const Shape& window_shape, const Shape& window_shape,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Shape& padding_below, const Shape& padding_below,
...@@ -54,7 +56,7 @@ op::MaxPool::MaxPool(const shared_ptr<Node>& arg, ...@@ -54,7 +56,7 @@ op::MaxPool::MaxPool(const shared_ptr<Node>& arg,
{ {
} }
op::MaxPool::MaxPool(const shared_ptr<Node>& arg, op::MaxPool::MaxPool(const Output<Node>& arg,
const Shape& window_shape, const Shape& window_shape,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Shape& padding_below, const Shape& padding_below,
...@@ -121,14 +123,14 @@ void op::MaxPool::validate_and_infer_types() ...@@ -121,14 +123,14 @@ void op::MaxPool::validate_and_infer_types()
m_ceil_mode)); m_ceil_mode));
} }
op::MaxPool::MaxPool(const shared_ptr<Node>& arg, op::MaxPool::MaxPool(const Output<Node>& arg,
const Shape& window_shape, const Shape& window_shape,
const Strides& window_movement_strides) const Strides& window_movement_strides)
: MaxPool(arg, window_shape, window_movement_strides, Shape(), Shape()) : MaxPool(arg, window_shape, window_movement_strides, Shape(), Shape())
{ {
} }
op::MaxPool::MaxPool(const shared_ptr<Node>& arg, const Shape& window_shape) op::MaxPool::MaxPool(const Output<Node>& arg, const Shape& window_shape)
: MaxPool(arg, window_shape, Strides(), Shape(), Shape()) : MaxPool(arg, window_shape, Strides(), Shape(), Shape())
{ {
} }
...@@ -145,13 +147,15 @@ shared_ptr<Node> op::MaxPool::copy_with_new_args(const NodeVector& new_args) con ...@@ -145,13 +147,15 @@ shared_ptr<Node> op::MaxPool::copy_with_new_args(const NodeVector& new_args) con
m_ceil_mode); m_ceil_mode);
} }
op::MaxPoolBackprop::MaxPoolBackprop(const shared_ptr<Node>& arg_forward, const string op::MaxPoolBackprop::type_name{"MaxPoolBackprop"};
const shared_ptr<Node>& delta,
op::MaxPoolBackprop::MaxPoolBackprop(const Output<Node>& arg_forward,
const Output<Node>& delta,
const Shape& window_shape, const Shape& window_shape,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Shape& padding_below, const Shape& padding_below,
const Shape& padding_above) const Shape& padding_above)
: Op("MaxPoolBackprop", check_single_output_args({arg_forward, delta})) : Op({arg_forward, delta})
, m_window_shape(window_shape) , m_window_shape(window_shape)
, m_window_movement_strides(window_movement_strides) , m_window_movement_strides(window_movement_strides)
, m_padding_below(padding_below) , m_padding_below(padding_below)
...@@ -160,14 +164,14 @@ op::MaxPoolBackprop::MaxPoolBackprop(const shared_ptr<Node>& arg_forward, ...@@ -160,14 +164,14 @@ op::MaxPoolBackprop::MaxPoolBackprop(const shared_ptr<Node>& arg_forward,
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
op::MaxPoolBackprop::MaxPoolBackprop(const shared_ptr<Node>& arg_forward, op::MaxPoolBackprop::MaxPoolBackprop(const Output<Node>& arg_forward,
const shared_ptr<Node>& delta, const Output<Node>& delta,
const shared_ptr<Node>& result_forward, const Output<Node>& result_forward,
const Shape& window_shape, const Shape& window_shape,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Shape& padding_below, const Shape& padding_below,
const Shape& padding_above) const Shape& padding_above)
: Op("MaxPoolBackprop", check_single_output_args({arg_forward, delta, result_forward})) : Op({arg_forward, delta, result_forward})
, m_window_shape(window_shape) , m_window_shape(window_shape)
, m_window_movement_strides(window_movement_strides) , m_window_movement_strides(window_movement_strides)
, m_padding_below(padding_below) , m_padding_below(padding_below)
......
...@@ -28,6 +28,12 @@ namespace ngraph ...@@ -28,6 +28,12 @@ namespace ngraph
class MaxPool : public Op class MaxPool : public Op
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a batched max pooling operation.
MaxPool() = default;
/// \brief Constructs a batched max pooling operation. /// \brief Constructs a batched max pooling operation.
/// ///
/// \param arg The node producing the input data batch tensor. /// \param arg The node producing the input data batch tensor.
...@@ -37,7 +43,7 @@ namespace ngraph ...@@ -37,7 +43,7 @@ namespace ngraph
/// \param padding_above The above-padding shape. /// \param padding_above The above-padding shape.
/// \param pad_type The pad type for automatically computing padding sizes /// \param pad_type The pad type for automatically computing padding sizes
/// \param ceil_mode Whether to use ceiling while computing output shape. /// \param ceil_mode Whether to use ceiling while computing output shape.
MaxPool(const std::shared_ptr<Node>& arg, MaxPool(const Output<Node>& arg,
const Shape& window_shape, const Shape& window_shape,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Shape& padding_below, const Shape& padding_below,
...@@ -53,7 +59,7 @@ namespace ngraph ...@@ -53,7 +59,7 @@ namespace ngraph
/// \param padding_below The below-padding shape. /// \param padding_below The below-padding shape.
/// \param padding_above The above-padding shape. /// \param padding_above The above-padding shape.
/// \param pad_type The pad type for automatically computing padding sizes /// \param pad_type The pad type for automatically computing padding sizes
MaxPool(const std::shared_ptr<Node>& arg, MaxPool(const Output<Node>& arg,
const Shape& window_shape, const Shape& window_shape,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Shape& padding_below, const Shape& padding_below,
...@@ -67,7 +73,7 @@ namespace ngraph ...@@ -67,7 +73,7 @@ namespace ngraph
/// \param window_movement_strides The window movement strides. /// \param window_movement_strides The window movement strides.
/// \param padding_below The below-padding shape. /// \param padding_below The below-padding shape.
/// \param padding_above The above-padding shape. /// \param padding_above The above-padding shape.
MaxPool(const std::shared_ptr<Node>& arg, MaxPool(const Output<Node>& arg,
const Shape& window_shape, const Shape& window_shape,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Shape& padding_below, const Shape& padding_below,
...@@ -80,7 +86,7 @@ namespace ngraph ...@@ -80,7 +86,7 @@ namespace ngraph
/// \param arg The node producing the input data batch tensor. /// \param arg The node producing the input data batch tensor.
/// \param window_shape The window shape. /// \param window_shape The window shape.
/// \param window_movement_strides The window movement strides. /// \param window_movement_strides The window movement strides.
MaxPool(const std::shared_ptr<Node>& arg, MaxPool(const Output<Node>& arg,
const Shape& window_shape, const Shape& window_shape,
const Strides& window_movement_strides); const Strides& window_movement_strides);
...@@ -88,23 +94,32 @@ namespace ngraph ...@@ -88,23 +94,32 @@ namespace ngraph
/// ///
/// \param arg The node producing the input data batch tensor. /// \param arg The node producing the input data batch tensor.
/// \param window_shape The window shape. /// \param window_shape The window shape.
MaxPool(const std::shared_ptr<Node>& arg, const Shape& window_shape); MaxPool(const Output<Node>& arg, const Shape& window_shape);
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
/// \return The window shape. /// \return The window shape.
const Shape& get_window_shape() const { return m_window_shape; } const Shape& get_window_shape() const { return m_window_shape; }
void set_window_shape(const Shape& window_shape) { m_window_shape = window_shape; }
/// \return The window movement strides. /// \return The window movement strides.
const Strides& get_window_movement_strides() const { return m_window_movement_strides; } const Strides& get_window_movement_strides() const { return m_window_movement_strides; }
void set_window_movement_strides(const Strides& window_movement_strides)
{
m_window_movement_strides = window_movement_strides;
}
/// \return The below-padding shape. /// \return The below-padding shape.
const Shape& get_padding_below() const { return m_padding_below; } const Shape& get_padding_below() const { return m_padding_below; }
void set_padding_below(const Shape& padding_below) { m_padding_below = padding_below; }
/// \return The above-padding shape. /// \return The above-padding shape.
const Shape& get_padding_above() const { return m_padding_above; } const Shape& get_padding_above() const { return m_padding_above; }
void set_adding_above(const Shape& padding_above) { m_padding_above = padding_above; }
/// \return The pad type for pooling. /// \return The pad type for pooling.
const PadType& get_pad_type() const { return m_pad_type; } const PadType& get_pad_type() const { return m_pad_type; }
void set_pad_type(const PadType& pad_type) { m_pad_type = pad_type; }
/// \return The ceiling mode being used for output shape computations /// \return The ceiling mode being used for output shape computations
bool get_ceil_mode() const { return m_ceil_mode; } bool get_ceil_mode() const { return m_ceil_mode; }
void set_ceil_mode(bool ceil_mode) { m_ceil_mode = ceil_mode; }
/// \return The default value for MaxPool. /// \return The default value for MaxPool.
virtual std::shared_ptr<Node> get_default_value() const override virtual std::shared_ptr<Node> get_default_value() const override
{ {
...@@ -126,16 +141,21 @@ namespace ngraph ...@@ -126,16 +141,21 @@ namespace ngraph
class MaxPoolBackprop : public Op class MaxPoolBackprop : public Op
{ {
public: public:
MaxPoolBackprop(const std::shared_ptr<Node>& arg_forward, NGRAPH_API
const std::shared_ptr<Node>& delta, static const std::string type_name;
const std::string& description() const override { return type_name; }
MaxPoolBackprop() = default;
MaxPoolBackprop(const Output<Node>& arg_forward,
const Output<Node>& delta,
const Shape& window_shape, const Shape& window_shape,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Shape& padding_below, const Shape& padding_below,
const Shape& padding_above); const Shape& padding_above);
MaxPoolBackprop(const std::shared_ptr<Node>& arg_forward, MaxPoolBackprop(const Output<Node>& arg_forward,
const std::shared_ptr<Node>& delta, const Output<Node>& delta,
const std::shared_ptr<Node>& result_forward, const Output<Node>& result_forward,
const Shape& window_shape, const Shape& window_shape,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Shape& padding_below, const Shape& padding_below,
...@@ -147,9 +167,16 @@ namespace ngraph ...@@ -147,9 +167,16 @@ namespace ngraph
void validate_and_infer_types() override; void validate_and_infer_types() override;
const Shape& get_window_shape() const { return m_window_shape; } const Shape& get_window_shape() const { return m_window_shape; }
void set_window_shape(const Shape& window_shape) { m_window_shape = window_shape; }
const Strides& get_window_movement_strides() const { return m_window_movement_strides; } const Strides& get_window_movement_strides() const { return m_window_movement_strides; }
void set_window_movement_strides(const Strides& window_movement_strides)
{
m_window_movement_strides = window_movement_strides;
}
const Shape& get_padding_below() const { return m_padding_below; } const Shape& get_padding_below() const { return m_padding_below; }
void set_padding_below(const Shape& padding_below) { m_padding_below = padding_below; }
const Shape& get_padding_above() const { return m_padding_above; } const Shape& get_padding_above() const { return m_padding_above; }
void set_padding_above(const Shape& padding_above) { m_padding_above = padding_above; }
protected: protected:
Shape m_window_shape; Shape m_window_shape;
Strides m_window_movement_strides; Strides m_window_movement_strides;
......
...@@ -25,10 +25,12 @@ ...@@ -25,10 +25,12 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::Maximum::Maximum(const shared_ptr<Node>& arg0, const string op::Maximum::type_name{"Maximum"};
const shared_ptr<Node>& arg1,
op::Maximum::Maximum(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& autob) const AutoBroadcastSpec& autob)
: BinaryElementwiseArithmetic("Maximum", arg0, arg1, autob) : BinaryElementwiseArithmetic(arg0, arg1, autob)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
......
...@@ -26,19 +26,24 @@ namespace ngraph ...@@ -26,19 +26,24 @@ namespace ngraph
class Maximum : public util::BinaryElementwiseArithmetic class Maximum : public util::BinaryElementwiseArithmetic
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a maximum operation.
Maximum() = default;
/// \brief Constructs a maximum operation. /// \brief Constructs a maximum operation.
/// ///
/// \param arg0 Node that produces the first input tensor. /// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor. /// \param arg1 Node that produces the second input tensor.
/// \param autob Auto broadcast specification /// \param autob Auto broadcast specification
Maximum(const std::shared_ptr<Node>& arg0, Maximum(const Output<Node>& arg0,
const std::shared_ptr<Node>& arg1, const Output<Node>& arg1,
const AutoBroadcastSpec& autob = AutoBroadcastSpec()); const AutoBroadcastSpec& autob = AutoBroadcastSpec());
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
virtual bool is_commutative() override { return true; } virtual bool is_commutative() const override { return true; }
protected: protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override; const NodeVector& deltas) override;
......
...@@ -22,10 +22,6 @@ using namespace ngraph; ...@@ -22,10 +22,6 @@ using namespace ngraph;
const string op::Min::type_name{"Min"}; const string op::Min::type_name{"Min"};
op::Min::Min()
{
}
op::Min::Min(const Output<Node>& arg, const AxisSet& reduction_axes) op::Min::Min(const Output<Node>& arg, const AxisSet& reduction_axes)
: ArithmeticReduction(arg, reduction_axes) : ArithmeticReduction(arg, reduction_axes)
{ {
......
...@@ -30,7 +30,7 @@ namespace ngraph ...@@ -30,7 +30,7 @@ namespace ngraph
static const std::string type_name; static const std::string type_name;
const std::string& description() const override { return type_name; } const std::string& description() const override { return type_name; }
/// \brief Constructs a "min" reduction operation. /// \brief Constructs a "min" reduction operation.
Min(); Min() = default;
/// \brief Constructs a min-reduction operation. /// \brief Constructs a min-reduction operation.
/// ///
/// \param arg The tensor to be reduced. /// \param arg The tensor to be reduced.
......
...@@ -25,10 +25,12 @@ ...@@ -25,10 +25,12 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::Minimum::Minimum(const shared_ptr<Node>& arg0, const string op::Minimum::type_name{"Minimum"};
const shared_ptr<Node>& arg1,
op::Minimum::Minimum(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& autob) const AutoBroadcastSpec& autob)
: BinaryElementwiseArithmetic("Minimum", arg0, arg1, autob) : BinaryElementwiseArithmetic(arg0, arg1, autob)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
......
...@@ -26,18 +26,24 @@ namespace ngraph ...@@ -26,18 +26,24 @@ namespace ngraph
class Minimum : public util::BinaryElementwiseArithmetic class Minimum : public util::BinaryElementwiseArithmetic
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a minimum operation.
Minimum() = default;
/// \brief Constructs a minimum operation. /// \brief Constructs a minimum operation.
/// ///
/// \param arg0 Node that produces the first input tensor. /// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor. /// \param arg1 Node that produces the second input tensor.
/// \param autob Auto broadcast specification /// \param autob Auto broadcast specification
Minimum(const std::shared_ptr<Node>& arg0, Minimum(const Output<Node>& arg0,
const std::shared_ptr<Node>& arg1, const Output<Node>& arg1,
const AutoBroadcastSpec& autob = AutoBroadcastSpec()); const AutoBroadcastSpec& autob = AutoBroadcastSpec());
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
virtual bool is_commutative() const override { return true; }
protected: protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override; const NodeVector& deltas) override;
......
...@@ -19,10 +19,12 @@ ...@@ -19,10 +19,12 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::Multiply::Multiply(const shared_ptr<Node>& arg0, const string op::Multiply::type_name{"Multiply"};
const shared_ptr<Node>& arg1,
op::Multiply::Multiply(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& autob) const AutoBroadcastSpec& autob)
: BinaryElementwiseArithmetic("Multiply", arg0, arg1, autob) : BinaryElementwiseArithmetic(arg0, arg1, autob)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
...@@ -49,7 +51,7 @@ void op::Multiply::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVec ...@@ -49,7 +51,7 @@ void op::Multiply::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVec
adjoints.add_delta(y, x * delta); adjoints.add_delta(y, x * delta);
} }
shared_ptr<Node> ngraph::operator*(const shared_ptr<Node> arg0, const shared_ptr<Node> arg1) shared_ptr<Node> ngraph::operator*(const Output<Node>& arg0, const Output<Node>& arg1)
{ {
return make_shared<op::Multiply>(arg0, arg1); return make_shared<op::Multiply>(arg0, arg1);
} }
...@@ -26,25 +26,29 @@ namespace ngraph ...@@ -26,25 +26,29 @@ namespace ngraph
class Multiply : public util::BinaryElementwiseArithmetic class Multiply : public util::BinaryElementwiseArithmetic
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a multiplication operation.
Multiply() = default;
/// \brief Constructs a multiplication operation. /// \brief Constructs a multiplication operation.
/// ///
/// \param arg0 Node that produces the first input tensor. /// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor. /// \param arg1 Node that produces the second input tensor.
/// \param autob Auto broadcast specification /// \param autob Auto broadcast specification
Multiply(const std::shared_ptr<Node>& arg0, Multiply(const Output<Node>& arg0,
const std::shared_ptr<Node>& arg1, const Output<Node>& arg1,
const AutoBroadcastSpec& autob = AutoBroadcastSpec()); const AutoBroadcastSpec& autob = AutoBroadcastSpec());
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
virtual bool is_commutative() const override { return true; }
protected: protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override; const NodeVector& deltas) override;
virtual bool is_commutative() override { return true; }
}; };
}; };
std::shared_ptr<ngraph::Node> operator*(const std::shared_ptr<ngraph::Node> arg0, std::shared_ptr<Node> operator*(const Output<Node>& arg0, const Output<Node>& arg1);
const std::shared_ptr<ngraph::Node> arg1);
} }
...@@ -19,8 +19,10 @@ ...@@ -19,8 +19,10 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::Negative::Negative(const shared_ptr<Node>& arg) const string op::Negative::type_name{"Negative"};
: UnaryElementwiseArithmetic("Negative", arg)
op::Negative::Negative(const Output<Node>& arg)
: UnaryElementwiseArithmetic(arg)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
...@@ -40,7 +42,7 @@ void op::Negative::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVec ...@@ -40,7 +42,7 @@ void op::Negative::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVec
adjoints.add_delta(x, -delta); adjoints.add_delta(x, -delta);
} }
shared_ptr<Node> ngraph::operator-(const shared_ptr<Node> arg0) shared_ptr<Node> ngraph::operator-(const Output<Node>& arg0)
{ {
return make_shared<op::Negative>(arg0); return make_shared<op::Negative>(arg0);
} }
...@@ -26,17 +26,23 @@ namespace ngraph ...@@ -26,17 +26,23 @@ namespace ngraph
class Negative : public util::UnaryElementwiseArithmetic class Negative : public util::UnaryElementwiseArithmetic
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a negative operation.
Negative() = default;
/// \brief Constructs a negative operation. /// \brief Constructs a negative operation.
/// ///
/// \param arg Node that produces the input tensor. /// \param arg Node that produces the input tensor.
Negative(const std::shared_ptr<Node>& arg); Negative(const Output<Node>& arg);
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override; const NodeVector& deltas) override;
}; };
} }
std::shared_ptr<ngraph::Node> operator-(const std::shared_ptr<ngraph::Node> arg0); std::shared_ptr<Node> operator-(const Output<Node>& arg0);
} }
...@@ -20,8 +20,10 @@ ...@@ -20,8 +20,10 @@
using namespace ngraph; using namespace ngraph;
using namespace std; using namespace std;
op::Not::Not(const shared_ptr<Node>& arg) const string op::Not::type_name{"Not"};
: Op("Not", check_single_output_args({arg}))
op::Not::Not(const Output<Node>& arg)
: Op({arg})
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
......
...@@ -26,10 +26,15 @@ namespace ngraph ...@@ -26,10 +26,15 @@ namespace ngraph
class Not : public Op class Not : public Op
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a logical negation operation.
Not() = default;
/// \brief Constructs a logical negation operation. /// \brief Constructs a logical negation operation.
/// ///
/// \param arg Node that produces the input tensor. /// \param arg Node that produces the input tensor.
Not(const std::shared_ptr<Node>& arg); Not(const Output<Node>& arg);
void validate_and_infer_types() override; void validate_and_infer_types() override;
......
...@@ -19,10 +19,12 @@ ...@@ -19,10 +19,12 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::NotEqual::NotEqual(const shared_ptr<Node>& arg0, const string op::NotEqual::type_name{"NotEqual"};
const shared_ptr<Node>& arg1,
op::NotEqual::NotEqual(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& autob) const AutoBroadcastSpec& autob)
: BinaryElementwiseComparison("NotEqual", arg0, arg1, autob) : BinaryElementwiseComparison(arg0, arg1, autob)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
......
...@@ -26,17 +26,24 @@ namespace ngraph ...@@ -26,17 +26,24 @@ namespace ngraph
class NotEqual : public util::BinaryElementwiseComparison class NotEqual : public util::BinaryElementwiseComparison
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a not-equal operation.
NotEqual() = default;
/// \brief Constructs a not-equal operation. /// \brief Constructs a not-equal operation.
/// ///
/// \param arg0 Node that produces the first input tensor. /// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor. /// \param arg1 Node that produces the second input tensor.
/// \param autob Auto broadcast specification /// \param autob Auto broadcast specification
NotEqual(const std::shared_ptr<Node>& arg0, NotEqual(const Output<Node>& arg0,
const std::shared_ptr<Node>& arg1, const Output<Node>& arg1,
const AutoBroadcastSpec& autob = AutoBroadcastSpec()); const AutoBroadcastSpec& autob = AutoBroadcastSpec());
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
virtual bool is_commutative() const override { return true; }
}; };
} }
} }
...@@ -20,8 +20,10 @@ ...@@ -20,8 +20,10 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::OneHot::OneHot(const shared_ptr<Node>& arg, const PartialShape& shape, size_t one_hot_axis) const string op::OneHot::type_name{"OneHot"};
: Op("OneHot", check_single_output_args({arg}))
op::OneHot::OneHot(const Output<Node>& arg, const PartialShape& shape, size_t one_hot_axis)
: Op({arg})
, m_shape(shape) , m_shape(shape)
, m_one_hot_axis(one_hot_axis) , m_one_hot_axis(one_hot_axis)
{ {
......
...@@ -45,14 +45,17 @@ namespace ngraph ...@@ -45,14 +45,17 @@ namespace ngraph
class OneHot : public Op class OneHot : public Op
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a one-hot operation.
OneHot() = default;
/// \brief Constructs a one-hot operation. /// \brief Constructs a one-hot operation.
/// ///
/// \param arg Node that produces the input tensor to be one-hot encoded. /// \param arg Node that produces the input tensor to be one-hot encoded.
/// \param shape The shape of the output tensor, including the new one-hot axis. /// \param shape The shape of the output tensor, including the new one-hot axis.
/// \param one_hot_axis The index within the output shape of the new one-hot axis. /// \param one_hot_axis The index within the output shape of the new one-hot axis.
OneHot(const std::shared_ptr<Node>& arg, OneHot(const Output<Node>& arg, const PartialShape& shape, size_t one_hot_axis);
const PartialShape& shape,
size_t one_hot_axis);
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
...@@ -60,6 +63,7 @@ namespace ngraph ...@@ -60,6 +63,7 @@ namespace ngraph
/// \return The index of the one-hot axis. /// \return The index of the one-hot axis.
size_t get_one_hot_axis() const { return m_one_hot_axis; } size_t get_one_hot_axis() const { return m_one_hot_axis; }
void set_one_hot_axis(size_t one_hot_axis) { m_one_hot_axis = one_hot_axis; }
protected: protected:
PartialShape m_shape; PartialShape m_shape;
size_t m_one_hot_axis; size_t m_one_hot_axis;
......
...@@ -19,10 +19,10 @@ ...@@ -19,10 +19,10 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::Or::Or(const shared_ptr<Node>& arg0, const string op::Or::type_name{"Or"};
const shared_ptr<Node>& arg1,
const AutoBroadcastSpec& autob) op::Or::Or(const Output<Node>& arg0, const Output<Node>& arg1, const AutoBroadcastSpec& autob)
: BinaryElementwiseLogical("Or", arg0, arg1, autob) : BinaryElementwiseLogical(arg0, arg1, autob)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
......
...@@ -29,6 +29,9 @@ namespace ngraph ...@@ -29,6 +29,9 @@ namespace ngraph
class Or : public util::BinaryElementwiseLogical class Or : public util::BinaryElementwiseLogical
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a logical-or operation. /// \brief Constructs a logical-or operation.
/// ///
/// \param arg0 Node that produces the first input tensor.<br> /// \param arg0 Node that produces the first input tensor.<br>
...@@ -39,15 +42,14 @@ namespace ngraph ...@@ -39,15 +42,14 @@ namespace ngraph
/// ///
/// Output `[d0, ...]` /// Output `[d0, ...]`
/// ///
Or(const std::shared_ptr<Node>& arg0, Or(const Output<Node>& arg0,
const std::shared_ptr<Node>& arg1, const Output<Node>& arg1,
const AutoBroadcastSpec& autob = AutoBroadcastSpec()); const AutoBroadcastSpec& autob = AutoBroadcastSpec());
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
protected: virtual bool is_commutative() const override { return true; }
virtual bool is_commutative() override { return true; }
}; };
} }
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment