Unverified Commit ec9ed92c authored by Robert Kimball, committed by GitHub

Merge branch 'master' into bob/nbench_db

parents b26c0e38 feefdbb2
@@ -32,6 +32,7 @@ ngraph.ops
cosh
divide
dot
+elu
equal
exp
floor
...
@@ -45,6 +45,7 @@ from ngraph.ops import cos
from ngraph.ops import cosh
from ngraph.ops import divide
from ngraph.ops import dot
+from ngraph.ops import elu
from ngraph.ops import equal
from ngraph.ops import exp
from ngraph.ops import floor
...
@@ -69,6 +69,7 @@ from _pyngraph.op import Cos
from _pyngraph.op import Cosh
from _pyngraph.op import Divide
from _pyngraph.op import Dot
+from _pyngraph.op import Elu
from _pyngraph.op import Equal
from _pyngraph.op import Exp
from _pyngraph.op import Floor
...
@@ -22,7 +22,7 @@ from ngraph.impl import AxisSet, AxisVector, Coordinate, CoordinateDiff, Functio
from ngraph.impl.op import Abs, Acos, Add, And, Asin, ArgMax, ArgMin, Atan, AvgPool, \
    BatchNormTraining, BatchNormInference, Broadcast, Ceiling, Concat, Constant, Convert, \
-    Convolution, ConvolutionBackpropData, Cos, Cosh, Divide, Dot, Equal, Exp, Floor, \
+    Convolution, ConvolutionBackpropData, Cos, Cosh, Divide, Dot, Elu, Equal, Exp, Floor, \
    GetOutputElement, Greater, GreaterEq, Less, LessEq, Log, LRN, Max, Maximum, MaxPool, \
    Min, Minimum, Multiply, Negative, Not, NotEqual, OneHot, Or, Pad, Parameter, Product, \
    Power, Relu, ReplaceSlice, Reshape, Reverse, Select, Sign, Sin, Sinh, Slice, Softmax, \
@@ -35,7 +35,7 @@ from ngraph.utils.decorators import nameable_op, binary_op, unary_op
from ngraph.utils.input_validation import assert_list_of_ints
from ngraph.utils.reduction import get_reduction_axes
from ngraph.utils.types import NumericType, NumericData, TensorShape, make_constant_node, \
-    NodeInput, ScalarData
+    NodeInput, ScalarData, as_node
from ngraph.utils.types import get_element_type
@@ -60,6 +60,24 @@ def constant(value, dtype=None, name=None):  # type: (NumericData, NumericType,
    return make_constant_node(value, dtype)

+@nameable_op
+def elu(data, alpha, name=None):  # type: (NodeInput, NodeInput, str) -> Node
+    """Perform Exponential Linear Unit operation element-wise on data from input node.
+
+    Computes exponential linear: alpha * (exp(data) - 1) if data < 0, data otherwise.
+
+    For more information refer to:
+    `Fast and Accurate Deep Network Learning by Exponential Linear Units (ELUs)
+    <http://arxiv.org/abs/1511.07289>`_
+
+    :param data: Input tensor. One of: input node, array or scalar.
+    :param alpha: Multiplier for negative values. One of: input node or scalar value.
+    :param name: Optional output node name.
+    :return: The new node performing an ELU operation on its input data element-wise.
+    """
+    return Elu(as_node(data), as_node(alpha))

# Unary ops
@unary_op
def absolute(node, name=None):  # type: (NodeInput, str) -> Node
...
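As a quick sanity check on the formula in the docstring above, here is an illustrative NumPy sketch of the same element-wise rule (not part of the diff; elu_reference is a hypothetical helper name):

import numpy as np

def elu_reference(data, alpha):
    # alpha * (exp(x) - 1) where x < 0, x unchanged elsewhere
    data = np.asarray(data, dtype=np.float32)
    return np.where(data < 0, alpha * (np.exp(data) - 1), data)

# elu_reference([[-5, 1], [-2, 3]], 3) gives [[-2.9797862, 1.], [-2.5939941, 3.]],
# matching the expected values used in the new tests further down.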
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "ngraph/op/fused/elu.hpp"
#include "pyngraph/ops/elu.hpp"
namespace py = pybind11;
void regclass_pyngraph_op_Elu(py::module m)
{
py::class_<ngraph::op::Elu, std::shared_ptr<ngraph::op::Elu>, ngraph::op::Op> elu(m, "Elu");
elu.doc() = "ngraph.impl.op.Elu wraps ngraph::op::Elu";
elu.def(py::init<const std::shared_ptr<ngraph::Node>&, const std::shared_ptr<ngraph::Node>&>());
}
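For orientation, once this class is registered it is reachable from Python through the imports added above; a hedged usage sketch (assuming an installed ngraph package and the parameter/constant helpers shown in the tests below):

import numpy as np
import ngraph as ng
from ngraph.impl.op import Elu

# The raw binding takes two already-built nodes, mirroring the
# py::init with two shared_ptr<Node> arguments registered above.
data = ng.parameter([2, 2], name='data', dtype=np.float32)
alpha = ng.constant(np.float32(3))
node = Elu(data, alpha)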
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
void regclass_pyngraph_op_Elu(py::module m);
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "ngraph/op/fused/elu.hpp"
#include "pyngraph/ops/fused/elu.hpp"
namespace py = pybind11;
void regclass_pyngraph_op_Elu(py::module m)
{
py::class_<ngraph::op::Elu, std::shared_ptr<ngraph::op::Elu>, ngraph::op::Op> elu(m, "Elu");
elu.doc() = "ngraph.impl.op.Elu wraps ngraph::op::Elu";
elu.def(py::init<const std::shared_ptr<ngraph::Node>&, const std::shared_ptr<ngraph::Node>&>());
}
@@ -49,6 +49,7 @@ void regmodule_pyngraph_op(py::module m_op)
    regclass_pyngraph_op_Cosh(m_op);
    regclass_pyngraph_op_Divide(m_op);
    regclass_pyngraph_op_Dot(m_op);
+    regclass_pyngraph_op_Elu(m_op);
    regclass_pyngraph_op_Equal(m_op);
    regclass_pyngraph_op_Exp(m_op);
    regclass_pyngraph_op_Floor(m_op);
...
@@ -39,6 +39,7 @@
#include "pyngraph/ops/cosh.hpp"
#include "pyngraph/ops/divide.hpp"
#include "pyngraph/ops/dot.hpp"
+#include "pyngraph/ops/elu.hpp"
#include "pyngraph/ops/equal.hpp"
#include "pyngraph/ops/exp.hpp"
#include "pyngraph/ops/floor.hpp"
...
@@ -179,6 +179,7 @@ sources = [
    'pyngraph/ops/ceiling.cpp',
    'pyngraph/ops/divide.cpp',
    'pyngraph/ops/dot.cpp',
+    'pyngraph/ops/elu.cpp',
    'pyngraph/ops/equal.cpp',
    'pyngraph/ops/exp.cpp',
    'pyngraph/ops/floor.cpp',
...
# ******************************************************************************
# Copyright 2017-2019 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
import numpy as np
import ngraph as ng
from test.ngraph.util import get_runtime
def test_elu_operator():
    runtime = get_runtime()

    data_shape = [2, 2]
    alpha_shape = [2]

    parameter_data = ng.parameter(data_shape, name='Data', dtype=np.float32)
    parameter_alpha = ng.parameter(alpha_shape, name='Alpha', dtype=np.float32)

    model = ng.elu(parameter_data, parameter_alpha)
    computation = runtime.computation(model, parameter_data, parameter_alpha)

    value_data = np.array([[-5, 1], [-2, 3]], dtype=np.float32)
    value_alpha = np.array([3, 3], dtype=np.float32)

    result = computation(value_data, value_alpha)
    expected = np.array([[-2.9797862, 1.], [-2.5939941, 3.]], dtype=np.float32)
    assert np.allclose(result, expected)


def test_elu_operator_with_scalar_and_array():
    runtime = get_runtime()

    data_value = np.array([[-5, 1], [-2, 3]], dtype=np.float32)
    alpha_value = np.float32(3)

    model = ng.elu(data_value, alpha_value)
    computation = runtime.computation(model)

    result = computation()
    expected = np.array([[-2.9797862, 1.], [-2.5939941, 3.]], dtype=np.float32)
    assert np.allclose(result, expected)


def test_elu_operator_with_scalar():
    runtime = get_runtime()

    data_value = np.array([[-5, 1], [-2, 3]], dtype=np.float32)
    alpha_value = np.float32(3)

    data_shape = [2, 2]
    parameter_data = ng.parameter(data_shape, name='Data', dtype=np.float32)

    model = ng.elu(parameter_data, alpha_value)
    computation = runtime.computation(model, parameter_data)

    result = computation(data_value)
    expected = np.array([[-2.9797862, 1.], [-2.5939941, 3.]], dtype=np.float32)
    assert np.allclose(result, expected)
@@ -16,9 +16,8 @@
import numpy as np
import ngraph as ng
-from string import ascii_uppercase
+from ngraph.utils.types import NumericData
+from typing import Any, Callable, List
import test
@@ -32,10 +31,14 @@ def get_runtime():
def run_op_node(input_data, op_fun, *args):
+    # type: (NumericData, Callable, *Any) -> List[NumericData]
    """Run computation on node performing `op_fun`.

    `op_fun` has to accept a node as an argument.

+    This function converts the passed raw input data to nGraph Constant nodes, and that form
+    is passed to `op_fun`.
+
    :param input_data: The input data for performed computation.
    :param op_fun: The function handler for operation we want to carry out.
    :param args: The arguments passed to operation we want to carry out.
@@ -45,14 +48,8 @@ def run_op_node(input_data, op_fun, *args):
    comp_args = []
    op_fun_args = []
    comp_inputs = []
-    for idx, data in enumerate(input_data):
-        if np.isscalar(data):
-            op_fun_args.append(ng.constant(data, _get_numpy_dtype(data)))
-        else:
-            node = ng.parameter(data.shape, name=ascii_uppercase[idx], dtype=data.dtype)
-            op_fun_args.append(node)
-            comp_args.append(node)
-            comp_inputs.append(data)
+    for data in input_data:
+        op_fun_args.append(ng.constant(data, _get_numpy_dtype(data)))
    op_fun_args.extend(args)
    node = op_fun(*op_fun_args)
    computation = runtime.computation(node, *comp_args)
@@ -60,10 +57,15 @@
def run_op_numeric_data(input_data, op_fun, *args):
+    # type: (NumericData, Callable, *Any) -> List[NumericData]
    """Run computation on node performing `op_fun`.

    `op_fun` has to accept a scalar or an array.

+    This function passes input data AS IS. This means that a scalar (integral or
+    floating point value) or a NumPy ndarray object will be automatically converted
+    to nGraph's Constant nodes.
+
    :param input_data: The input data for performed computation.
    :param op_fun: The function handler for operation we want to carry out.
    :param args: The arguments passed to operation we want to carry out.
...
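To make the refactor concrete, a hypothetical call (assuming this test package is importable) now needs no runtime inputs, because every array is baked in as a Constant:

import numpy as np
import ngraph as ng
from test.ngraph.util import run_op_node

# All inputs become ng.constant nodes inside run_op_node, so the
# computation is parameterless and runs on the constants alone.
result = run_op_node([np.array([[-5, 1], [-2, 3]], dtype=np.float32)], ng.absolute)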
@@ -350,6 +350,7 @@ namespace
    }
    return callBackFuncPtr;
}
+
// NGDialect converters
Type NGraphTypeConverter::convertType(Type type)
{
@@ -576,7 +577,6 @@ namespace
    // Create Value for result, and extract type info.
    Value* result = m_pass.buildOutputDefs(op, rewriter)[0];
    NGRAPH_CHECK(result, "Unexpected null result in ConcatOp");
-    auto resultTy = result->getType().cast<MemRefType>();

    // Create view to write into result.
    MemRefView vRes(result);
@@ -590,7 +590,6 @@
    for (auto& operand : operands)
    {
        NGRAPH_CHECK(operand, "Unexpected null operand in ConcatOp");
-        auto operandTy = result->getType().cast<MemRefType>();

        // Assuming rank = r, and the concatenation axis is A where A<r, we'll be creating
        // loops of this form:
...
@@ -74,7 +74,6 @@ void MLIRSubgraphExtractionPass::MLIRSubgraph::merge(MLIRSubgraph& sg2)
    // Associate nodes of second sub-graph to first one
    auto sg_nodes = sg2.get_nodes();
-    auto& node_map = m_pass.m_node_to_graph;
    for (auto node : sg_nodes)
    {
        NGRAPH_DEBUG << *node;
@@ -112,7 +111,6 @@ bool MLIRSubgraphExtractionPass::run_on_function(std::shared_ptr<Function> func)
    for (auto op : func->get_ordered_ops())
    {
        NodeVector inputs;
-        int first_graph_id = -1;
        std::unordered_set<int> subgraph_ids;
        // unsupported ops, skip
        if (!is_supported_mlir_op(op))
...
@@ -160,5 +160,5 @@ namespace ngraph
/// \brief Macro to signal a code path that is unreachable in a successful execution. It's
/// implemented with NGRAPH_CHECK macro.
/// \param ... Additional error message that should describe why that execution path is unreachable.
-/// \throws ::ngrap::CheckFailure if the macro is executed.
+/// \throws ::ngraph::CheckFailure if the macro is executed.
#define NGRAPH_UNREACHABLE(...) NGRAPH_CHECK(false, "Unreachable: ", ##__VA_ARGS__)
@@ -214,7 +214,7 @@ namespace ngraph
        virtual bool is_constant() const;
        virtual bool is_null() const { return false; }
        virtual bool is_op() const { return false; }
-        virtual bool is_commutative() { return false; }
+        virtual bool is_commutative() const { return false; }
        virtual bool is_dynamic() const;
        virtual bool has_state() const { return false; }
        size_t get_instance_id() const { return m_instance_id; }
...
@@ -49,7 +49,7 @@ void op::Add::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector&
    adjoints.add_delta(y, delta);
}

-shared_ptr<Node> ngraph::operator+(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg1)
+shared_ptr<Node> ngraph::operator+(const Output<Node>& arg0, const Output<Node>& arg1)
{
    return make_shared<op::Add>(arg0, arg1);
}
@@ -51,13 +51,12 @@ namespace ngraph
            std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;

+            virtual bool is_commutative() const override { return true; }
        protected:
            virtual void generate_adjoints(autodiff::Adjoints& adjoints,
                                           const NodeVector& deltas) override;
-            virtual bool is_commutative() override { return true; }
        };
    }

-    std::shared_ptr<ngraph::Node> operator+(const std::shared_ptr<ngraph::Node>& arg0,
-                                            const std::shared_ptr<ngraph::Node>& arg1);
+    std::shared_ptr<Node> operator+(const Output<Node>& arg0, const Output<Node>& arg1);
}
@@ -51,8 +51,7 @@ namespace ngraph
            std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;

-        protected:
-            virtual bool is_commutative() override { return true; }
+            virtual bool is_commutative() const override { return true; }
        };
    }
}
@@ -22,12 +22,15 @@
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/validation_util.hpp"

-const std::string ngraph::op::BatchNormTraining::type_name{"BatchNormTraining"};
+using namespace std;
+using namespace ngraph;

-ngraph::op::BatchNormTraining::BatchNormTraining(Output<ngraph::Node> input,
-                                                 Output<ngraph::Node> gamma,
-                                                 Output<ngraph::Node> beta,
-                                                 double epsilon)
+const string op::BatchNormTraining::type_name{"BatchNormTraining"};
+
+op::BatchNormTraining::BatchNormTraining(const Output<Node>& input,
+                                         const Output<Node>& gamma,
+                                         const Output<Node>& beta,
+                                         double epsilon)
    : Op({gamma, beta, input})
    , m_epsilon(epsilon)
{
@@ -35,17 +38,17 @@ ngraph::op::BatchNormTraining::BatchNormTraining(Output<ngraph::Node> input,
}

// DEPRECATED
-ngraph::op::BatchNormTraining::BatchNormTraining(double eps,
-                                                 Output<ngraph::Node> gamma,
-                                                 Output<ngraph::Node> beta,
-                                                 Output<ngraph::Node> input)
+op::BatchNormTraining::BatchNormTraining(double eps,
+                                         const Output<Node>& gamma,
+                                         const Output<Node>& beta,
+                                         const Output<Node>& input)
    : Op({gamma, beta, input})
    , m_epsilon(eps)
{
    constructor_validate_and_infer_types();
}

-void ngraph::op::BatchNormTraining::validate_and_infer_types()
+void op::BatchNormTraining::validate_and_infer_types()
{
    element::Type result_et;
    PartialShape result_batch_shape;
@@ -66,16 +69,15 @@ void ngraph::op::BatchNormTraining::validate_and_infer_types()
    set_output_type(2, result_et, result_channel_shape);
}

-std::shared_ptr<ngraph::Node>
-    ngraph::op::BatchNormTraining::copy_with_new_args(const NodeVector& new_args) const
+std::shared_ptr<Node> op::BatchNormTraining::copy_with_new_args(const NodeVector& new_args) const
{
    check_new_args_count(this, new_args);
    return std::make_shared<BatchNormTraining>(
        new_args.at(2), new_args.at(0), new_args.at(1), m_epsilon);
}

-void ngraph::op::BatchNormTraining::generate_adjoints(autodiff::Adjoints& adjoints,
-                                                      const NodeVector& deltas)
+void op::BatchNormTraining::generate_adjoints(autodiff::Adjoints& adjoints,
+                                              const NodeVector& deltas)
{
    auto gamma = input(0).get_source_output();
    auto beta = input(1).get_source_output();
@@ -102,14 +104,14 @@ void ngraph::op::BatchNormTraining::generate_adjoints(autodiff::Adjoints& adjoin
    adjoints.add_delta(beta, dbeta);
}

-const std::string ngraph::op::BatchNormInference::type_name{"BatchNormInference"};
+const string op::BatchNormInference::type_name{"BatchNormInference"};

-ngraph::op::BatchNormInference::BatchNormInference(Output<ngraph::Node> input,
-                                                   Output<ngraph::Node> gamma,
-                                                   Output<ngraph::Node> beta,
-                                                   Output<ngraph::Node> mean,
-                                                   Output<ngraph::Node> variance,
-                                                   double epsilon)
+op::BatchNormInference::BatchNormInference(const Output<Node>& input,
+                                           const Output<Node>& gamma,
+                                           const Output<Node>& beta,
+                                           const Output<Node>& mean,
+                                           const Output<Node>& variance,
+                                           double epsilon)
    : Op({gamma, beta, input, mean, variance})
    , m_epsilon(epsilon)
{
@@ -117,19 +119,19 @@ ngraph::op::BatchNormInference::BatchNormInference(Output<ngraph::Node> input,
}

// DEPRECATED
-ngraph::op::BatchNormInference::BatchNormInference(double eps,
-                                                   Output<ngraph::Node> gamma,
-                                                   Output<ngraph::Node> beta,
-                                                   Output<ngraph::Node> input,
-                                                   Output<ngraph::Node> mean,
-                                                   Output<ngraph::Node> variance)
+op::BatchNormInference::BatchNormInference(double eps,
+                                           const Output<Node>& gamma,
+                                           const Output<Node>& beta,
+                                           const Output<Node>& input,
+                                           const Output<Node>& mean,
+                                           const Output<Node>& variance)
    : Op({gamma, beta, input, mean, variance})
    , m_epsilon(eps)
{
    constructor_validate_and_infer_types();
}

-void ngraph::op::BatchNormInference::validate_and_infer_types()
+void op::BatchNormInference::validate_and_infer_types()
{
    element::Type result_et;
    PartialShape result_batch_shape;
@@ -152,23 +154,22 @@ void ngraph::op::BatchNormInference::validate_and_infer_types()
    set_output_type(0, result_et, result_batch_shape);
}

-std::shared_ptr<ngraph::Node>
-    ngraph::op::BatchNormInference::copy_with_new_args(const NodeVector& new_args) const
+std::shared_ptr<Node> op::BatchNormInference::copy_with_new_args(const NodeVector& new_args) const
{
    check_new_args_count(this, new_args);
    return std::make_shared<BatchNormInference>(
        new_args.at(2), new_args.at(0), new_args.at(1), new_args.at(3), new_args.at(4), m_epsilon);
}

-const std::string ngraph::op::BatchNormTrainingBackprop::type_name{"BatchNormTrainingBackprop"};
+const string op::BatchNormTrainingBackprop::type_name{"BatchNormTrainingBackprop"};

-ngraph::op::BatchNormTrainingBackprop::BatchNormTrainingBackprop(Output<ngraph::Node> input,
-                                                                 Output<ngraph::Node> gamma,
-                                                                 Output<ngraph::Node> beta,
-                                                                 Output<ngraph::Node> mean,
-                                                                 Output<ngraph::Node> variance,
-                                                                 Output<ngraph::Node> delta,
-                                                                 double epsilon)
+op::BatchNormTrainingBackprop::BatchNormTrainingBackprop(const Output<Node>& input,
+                                                         const Output<Node>& gamma,
+                                                         const Output<Node>& beta,
+                                                         const Output<Node>& mean,
+                                                         const Output<Node>& variance,
+                                                         const Output<Node>& delta,
+                                                         double epsilon)
    : Op({gamma, beta, input, mean, variance, delta})
    , m_epsilon(epsilon)
@@ -177,13 +178,13 @@ ngraph::op::BatchNormTrainingBackprop::BatchNormTrainingBackprop(Output<ngraph::
    constructor_validate_and_infer_types();
}

-ngraph::op::BatchNormTrainingBackprop::BatchNormTrainingBackprop(double epsilon,
-                                                                 Output<ngraph::Node> gamma,
-                                                                 Output<ngraph::Node> beta,
-                                                                 Output<ngraph::Node> input,
-                                                                 Output<ngraph::Node> mean,
-                                                                 Output<ngraph::Node> variance,
-                                                                 Output<ngraph::Node> delta)
+op::BatchNormTrainingBackprop::BatchNormTrainingBackprop(double epsilon,
+                                                         const Output<Node>& gamma,
+                                                         const Output<Node>& beta,
+                                                         const Output<Node>& input,
+                                                         const Output<Node>& mean,
+                                                         const Output<Node>& variance,
+                                                         const Output<Node>& delta)
    : Op({gamma, beta, input, mean, variance, delta})
    , m_epsilon(epsilon)
@@ -192,7 +193,7 @@ ngraph::op::BatchNormTrainingBackprop::BatchNormTrainingBackprop(double epsilon,
    constructor_validate_and_infer_types();
}

-void ngraph::op::BatchNormTrainingBackprop::validate_and_infer_types()
+void op::BatchNormTrainingBackprop::validate_and_infer_types()
{
    PartialShape input_and_delta_shape{get_input_partial_shape(INPUT_DATA)};
@@ -239,8 +240,8 @@ void ngraph::op::BatchNormTrainingBackprop::validate_and_infer_types()
    set_output_type(2, result_et, result_channel_shape);
}

-std::shared_ptr<ngraph::Node>
-    ngraph::op::BatchNormTrainingBackprop::copy_with_new_args(const NodeVector& new_args) const
+std::shared_ptr<Node>
+    op::BatchNormTrainingBackprop::copy_with_new_args(const NodeVector& new_args) const
{
    check_new_args_count(this, new_args);
    return std::make_shared<op::BatchNormTrainingBackprop>(new_args.at(2),
...
@@ -39,9 +39,9 @@ namespace ngraph
            /// \param gamma gamma scaling for normalized value. [C]
            /// \param beta bias added to the scaled normalized value [C]
            /// \param epsilon Avoids division by 0 if input has 0 variance
-            BatchNormTraining(Output<Node> input,
-                              Output<Node> gamma,
-                              Output<Node> beta,
+            BatchNormTraining(const Output<Node>& input,
+                              const Output<Node>& gamma,
+                              const Output<Node>& beta,
                              double epsilon);

            NGRAPH_DEPRECATED_DOC
@@ -66,9 +66,9 @@ namespace ngraph
            /// output[2]: shall have rank 1, with the same span as input's channel axis.
            NGRAPH_DEPRECATED("Use another constructor")
            BatchNormTraining(double eps,
-                              Output<Node> gamma,
-                              Output<Node> beta,
-                              Output<Node> input);
+                              const Output<Node>& gamma,
+                              const Output<Node>& beta,
+                              const Output<Node>& input);

            void validate_and_infer_types() override;
@@ -101,11 +101,11 @@ namespace ngraph
            /// \param mean value for mean normalization [C]
            /// \param variance value for variance normalization [C]
            /// \param epsilon Avoids division by 0 if input has 0 variance
-            BatchNormInference(Output<ngraph::Node> input,
-                               Output<ngraph::Node> gamma,
-                               Output<ngraph::Node> beta,
-                               Output<ngraph::Node> mean,
-                               Output<ngraph::Node> variance,
+            BatchNormInference(const Output<Node>& input,
+                               const Output<Node>& gamma,
+                               const Output<Node>& beta,
+                               const Output<Node>& mean,
+                               const Output<Node>& variance,
                               double epsilon);

            NGRAPH_DEPRECATED_DOC
@@ -128,11 +128,11 @@ namespace ngraph
            /// output: shall have the same shape as 'input'.
            NGRAPH_DEPRECATED("Use another constructor")
            BatchNormInference(double eps,
-                               Output<ngraph::Node> gamma,
-                               Output<ngraph::Node> beta,
-                               Output<ngraph::Node> input,
-                               Output<ngraph::Node> mean,
-                               Output<ngraph::Node> variance);
+                               const Output<Node>& gamma,
+                               const Output<Node>& beta,
+                               const Output<Node>& input,
+                               const Output<Node>& mean,
+                               const Output<Node>& variance);

            void validate_and_infer_types() override;
@@ -165,24 +165,23 @@ namespace ngraph
            static const std::string type_name;
            const std::string& description() const override { return type_name; }
            BatchNormTrainingBackprop() = default;
-            BatchNormTrainingBackprop(Output<Node> input,
-                                      Output<Node> gamma,
-                                      Output<Node> beta,
-                                      Output<Node> mean,
-                                      Output<Node> variance,
-                                      Output<Node> delta,
+            BatchNormTrainingBackprop(const Output<Node>& input,
+                                      const Output<Node>& gamma,
+                                      const Output<Node>& beta,
+                                      const Output<Node>& mean,
+                                      const Output<Node>& variance,
+                                      const Output<Node>& delta,
                                      double epsilon);

            NGRAPH_DEPRECATED_DOC
            NGRAPH_DEPRECATED("Use another constructor")
            BatchNormTrainingBackprop(double epsilon,
-                                      Output<Node> gamma,
-                                      Output<Node> beta,
-                                      Output<Node> input,
-                                      Output<Node> mean,
-                                      Output<Node> variance,
-                                      Output<Node> delta);
+                                      const Output<Node>& gamma,
+                                      const Output<Node>& beta,
+                                      const Output<Node>& input,
+                                      const Output<Node>& mean,
+                                      const Output<Node>& variance,
+                                      const Output<Node>& delta);

            void validate_and_infer_types() override;
...
@@ -64,7 +64,7 @@ void op::Divide::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVecto
    adjoints.add_delta(y, -delta * shared_from_this() / y);
}

-shared_ptr<Node> ngraph::operator/(const Output<Node> arg0, const Output<Node> arg1)
+shared_ptr<Node> ngraph::operator/(const Output<Node>& arg0, const Output<Node>& arg1)
{
    return make_shared<op::Divide>(arg0, arg1);
}
@@ -64,6 +64,5 @@ namespace ngraph
        };
    }

-    std::shared_ptr<ngraph::Node> operator/(const Output<ngraph::Node> arg0,
-                                            const Output<ngraph::Node> arg1);
+    std::shared_ptr<Node> operator/(const Output<Node>& arg0, const Output<Node>& arg1);
}
@@ -58,7 +58,7 @@ namespace ngraph
            void validate_and_infer_types() override;

            size_t get_reduction_axes_count() const { return m_reduction_axes_count; }
-            void get_reduction_axes_count(size_t reduction_axes_count)
+            void set_reduction_axes_count(size_t reduction_axes_count)
            {
                m_reduction_axes_count = reduction_axes_count;
            }
...
@@ -56,6 +56,8 @@ namespace ngraph
            virtual std::shared_ptr<Node>
                copy_with_new_args(const NodeVector& new_args) const override;
+
+            virtual bool is_commutative() const override { return true; }
        };
    }
}
@@ -39,7 +39,7 @@ NodeVector op::Elu::decompose_op() const
    auto data = get_argument(0);
    auto alpha_node = get_argument(1);

-    alpha_node = ngraph::op::make_broadcast_node(alpha_node, data->get_shape());
+    alpha_node = ngraph::op::numpy_style_broadcast(alpha_node, data->get_shape());

    shared_ptr<ngraph::Node> zero_node =
        builder::make_constant(data->get_element_type(), data->get_shape(), 0);
...
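The switch to numpy_style_broadcast is what lets the test above pass an alpha of shape [2] against data of shape [2, 2]; a NumPy sketch of the same trailing-dimension rule (illustrative only):

import numpy as np

alpha = np.array([3, 3], dtype=np.float32)   # shape [2]
data = np.zeros((2, 2), dtype=np.float32)    # shape [2, 2]
# NumPy-style broadcasting aligns trailing dimensions, so [2] -> [2, 2]:
alpha_b = np.broadcast_to(alpha, data.shape)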
@@ -23,6 +23,8 @@ using namespace ngraph;
static int PARAMS = 0;
static int INDICES = 1;

+const string op::Gather::type_name{"Gather"};
+
shared_ptr<Node> op::Gather::copy_with_new_args(const NodeVector& new_args) const
{
    check_new_args_count(this, new_args);
...
@@ -26,13 +26,15 @@ namespace ngraph
        class Gather : public Op
        {
        public:
+            NGRAPH_API
+            static const std::string type_name;
+            const std::string& description() const override { return type_name; }
+            Gather() = default;
            /// \param params The tensor from which slices are gathered
            /// \param indices Index tensor: Data type must be `element::i32` or `element::i64`
            /// \param axis Axis in params to gather
-            Gather(const std::shared_ptr<Node>& params,
-                   const std::shared_ptr<Node>& indices,
-                   size_t axis = 0)
-                : Op("Gather", check_single_output_args({params, indices}))
+            Gather(const Output<Node>& params, const Output<Node>& indices, size_t axis = 0)
+                : Op({params, indices})
                , m_axis(axis)
            {
                constructor_validate_and_infer_types();
@@ -46,6 +48,7 @@
            }

            size_t get_axis() const { return m_axis; }
+            void set_axis(size_t axis) { m_axis = axis; }
            virtual std::shared_ptr<Node>
                copy_with_new_args(const NodeVector& new_args) const override;
...
@@ -23,6 +23,8 @@ using namespace ngraph;
static int PARAMS = 0;
static int INDICES = 1;

+const string op::GatherND::type_name{"GatherND"};
+
shared_ptr<Node> op::GatherND::copy_with_new_args(const NodeVector& new_args) const
{
    check_new_args_count(this, new_args);
...
@@ -26,10 +26,14 @@ namespace ngraph
        class GatherND : public Op
        {
        public:
+            NGRAPH_API
+            static const std::string type_name;
+            const std::string& description() const override { return type_name; }
+            GatherND() = default;
            /// \param params The tensor from which slices are gathered
            /// \param indices Index tensor: Data type must be `element::i32` or `element::i64`
-            GatherND(const std::shared_ptr<Node>& params, const std::shared_ptr<Node>& indices)
-                : Op("GatherND", check_single_output_args({params, indices}))
+            GatherND(const Output<Node>& params, const Output<Node>& indices)
+                : Op({params, indices})
            {
                constructor_validate_and_infer_types();
            }
...
@@ -19,10 +19,12 @@
using namespace std;
using namespace ngraph;

-op::Greater::Greater(const shared_ptr<Node>& arg0,
-                     const shared_ptr<Node>& arg1,
+const string op::Greater::type_name{"Greater"};
+
+op::Greater::Greater(const Output<Node>& arg0,
+                     const Output<Node>& arg1,
                     const AutoBroadcastSpec& autob)
-    : BinaryElementwiseComparison("Greater", arg0, arg1, autob)
+    : BinaryElementwiseComparison(arg0, arg1, autob)
{
    constructor_validate_and_infer_types();
}
...
@@ -26,13 +26,18 @@ namespace ngraph
        class Greater : public util::BinaryElementwiseComparison
        {
        public:
+            NGRAPH_API
+            static const std::string type_name;
+            const std::string& description() const override { return type_name; }
+            /// \brief Constructs a greater-than operation.
+            Greater() = default;
            /// \brief Constructs a greater-than operation.
            ///
            /// \param arg0 Node that produces the first input tensor.
            /// \param arg1 Node that produces the second input tensor.
            /// \param autob Auto broadcast specification
-            Greater(const std::shared_ptr<Node>& arg0,
-                    const std::shared_ptr<Node>& arg1,
+            Greater(const Output<Node>& arg0,
+                    const Output<Node>& arg1,
                    const AutoBroadcastSpec& autob = AutoBroadcastSpec());

            virtual std::shared_ptr<Node>
...
@@ -19,10 +19,12 @@
using namespace std;
using namespace ngraph;

-op::GreaterEq::GreaterEq(const shared_ptr<Node>& arg0,
-                         const shared_ptr<Node>& arg1,
+const string op::GreaterEq::type_name{"GreaterEq"};
+
+op::GreaterEq::GreaterEq(const Output<Node>& arg0,
+                         const Output<Node>& arg1,
                         const AutoBroadcastSpec& autob)
-    : BinaryElementwiseComparison("GreaterEq", arg0, arg1, autob)
+    : BinaryElementwiseComparison(arg0, arg1, autob)
{
    constructor_validate_and_infer_types();
}
...
@@ -26,13 +26,18 @@ namespace ngraph
        class GreaterEq : public util::BinaryElementwiseComparison
        {
        public:
+            NGRAPH_API
+            static const std::string type_name;
+            const std::string& description() const override { return type_name; }
+            /// \brief Constructs a greater-than-or-equal operation.
+            GreaterEq() = default;
            /// \brief Constructs a greater-than-or-equal operation.
            ///
            /// \param arg0 Node that produces the first input tensor.
            /// \param arg1 Node that produces the second input tensor.
            /// \param autob Auto broadcast specification
-            GreaterEq(const std::shared_ptr<Node>& arg0,
-                      const std::shared_ptr<Node>& arg1,
+            GreaterEq(const Output<Node>& arg0,
+                      const Output<Node>& arg1,
                      const AutoBroadcastSpec& autob = AutoBroadcastSpec());

            virtual std::shared_ptr<Node>
...
@@ -19,10 +19,10 @@
using namespace std;
using namespace ngraph;

-op::Less::Less(const shared_ptr<Node>& arg0,
-               const shared_ptr<Node>& arg1,
-               const AutoBroadcastSpec& autob)
-    : BinaryElementwiseComparison("Less", arg0, arg1, autob)
+const string op::Less::type_name{"Less"};
+
+op::Less::Less(const Output<Node>& arg0, const Output<Node>& arg1, const AutoBroadcastSpec& autob)
+    : BinaryElementwiseComparison(arg0, arg1, autob)
{
    constructor_validate_and_infer_types();
}
...
@@ -26,13 +26,18 @@ namespace ngraph
        class Less : public util::BinaryElementwiseComparison
        {
        public:
+            NGRAPH_API
+            static const std::string type_name;
+            const std::string& description() const override { return type_name; }
+            /// \brief Constructs a less-than operation.
+            Less() = default;
            /// \brief Constructs a less-than operation.
            ///
            /// \param arg0 Node that produces the first input tensor.
            /// \param arg1 Node that produces the second input tensor.
            /// \param autob Auto broadcast specification
-            Less(const std::shared_ptr<Node>& arg0,
-                 const std::shared_ptr<Node>& arg1,
+            Less(const Output<Node>& arg0,
+                 const Output<Node>& arg1,
                 const AutoBroadcastSpec& autob = AutoBroadcastSpec());

            virtual std::shared_ptr<Node>
...
@@ -19,10 +19,12 @@
using namespace std;
using namespace ngraph;

-op::LessEq::LessEq(const shared_ptr<Node>& arg0,
-                   const shared_ptr<Node>& arg1,
+const string op::LessEq::type_name{"LessEq"};
+
+op::LessEq::LessEq(const Output<Node>& arg0,
+                   const Output<Node>& arg1,
                   const AutoBroadcastSpec& autob)
-    : BinaryElementwiseComparison("LessEq", arg0, arg1, autob)
+    : BinaryElementwiseComparison(arg0, arg1, autob)
{
    constructor_validate_and_infer_types();
}
...
@@ -26,13 +26,18 @@ namespace ngraph
        class LessEq : public util::BinaryElementwiseComparison
        {
        public:
+            NGRAPH_API
+            static const std::string type_name;
+            const std::string& description() const override { return type_name; }
+            /// \brief Constructs a less-than-or-equal operation.
+            LessEq() = default;
            /// \brief Constructs a less-than-or-equal operation.
            ///
            /// \param arg0 Node that produces the first input tensor.
            /// \param arg1 Node that produces the second input tensor.
            /// \param autob Auto broadcast specification
-            LessEq(const std::shared_ptr<Node>& arg0,
-                   const std::shared_ptr<Node>& arg1,
+            LessEq(const Output<Node>& arg0,
+                   const Output<Node>& arg1,
                   const AutoBroadcastSpec& autob = AutoBroadcastSpec());

            virtual std::shared_ptr<Node>
...
@@ -20,8 +20,10 @@
using namespace std;
using namespace ngraph;

-op::Log::Log(const shared_ptr<Node>& arg)
-    : UnaryElementwiseArithmetic("Log", arg)
+const string op::Log::type_name{"Log"};
+
+op::Log::Log(const Output<Node>& arg)
+    : UnaryElementwiseArithmetic(arg)
{
    constructor_validate_and_infer_types();
}
...
@@ -26,10 +26,15 @@ namespace ngraph
        class Log : public util::UnaryElementwiseArithmetic
        {
        public:
+            NGRAPH_API
+            static const std::string type_name;
+            const std::string& description() const override { return type_name; }
+            /// \brief Constructs a natural log operation.
+            Log() = default;
            /// \brief Constructs a natural log operation.
            ///
            /// \param arg Node that produces the input tensor.
-            Log(const std::shared_ptr<Node>& arg);
+            Log(const Output<Node>& arg);

            virtual std::shared_ptr<Node>
                copy_with_new_args(const NodeVector& new_args) const override;
...
@@ -20,12 +20,14 @@
using namespace std;
using namespace ngraph;

-op::LRN::LRN(const std::shared_ptr<Node>& arg, double alpha, double beta, double bias, size_t nsize)
-    : UnaryElementwiseArithmetic("LRN", arg)
+const string op::LRN::type_name{"LRN"};
+
+op::LRN::LRN(const Output<Node>& arg, double alpha, double beta, double bias, size_t size)
+    : UnaryElementwiseArithmetic(arg)
    , m_alpha(alpha)
    , m_beta(beta)
    , m_bias(bias)
-    , m_size(nsize)
+    , m_size(size)
{
    constructor_validate_and_infer_types();
}
...
@@ -38,23 +38,28 @@ namespace ngraph
        class LRN : public util::UnaryElementwiseArithmetic
        {
        public:
+            NGRAPH_API
+            static const std::string type_name;
+            const std::string& description() const override { return type_name; }
+            /// \brief Constructs a LRN operation.
+            LRN() = default;
            /// \brief Constructs a LRN operation.
            ///
            /// \param arg Node that produces the input tensor.
-            LRN(const std::shared_ptr<Node>& arg,
-                double alpha,
-                double beta,
-                double bias,
-                size_t size);
+            LRN(const Output<Node>& arg, double alpha, double beta, double bias, size_t size);

            virtual std::shared_ptr<Node>
                copy_with_new_args(const NodeVector& new_args) const override;
            void validate_and_infer_types() override;

            double get_alpha() const { return m_alpha; }
+            void set_alpha(double alpha) { m_alpha = alpha; }
            double get_beta() const { return m_beta; }
+            void set_beta(double beta) { m_beta = beta; }
            double get_bias() const { return m_bias; }
+            void set_bias(double bias) { m_bias = bias; }
            size_t get_nsize() const { return m_size; }
+            void set_nsize(size_t size) { m_size = size; }
        protected:
            virtual void generate_adjoints(autodiff::Adjoints& adjoints,
                                           const NodeVector& deltas) override;
...
@@ -22,10 +22,6 @@ using namespace ngraph;
const string op::Max::type_name{"Max"};

-op::Max::Max()
-{
-}
-
op::Max::Max(const Output<Node>& arg, const AxisSet& reduction_axes)
    : ArithmeticReduction(arg, reduction_axes)
{
...
@@ -30,7 +30,7 @@ namespace ngraph
            static const std::string type_name;
            const std::string& description() const override { return type_name; }
            /// \brief Constructs a "max" reduction operation.
-            Max();
+            Max() = default;
            /// \brief Constructs a max-reduction operation.
            ///
            /// \param arg The tensor to be reduced.
...
@@ -25,14 +25,16 @@
using namespace std;
using namespace ngraph;

-op::MaxPool::MaxPool(const shared_ptr<Node>& arg,
+const string op::MaxPool::type_name{"MaxPool"};
+
+op::MaxPool::MaxPool(const Output<Node>& arg,
                     const Shape& window_shape,
                     const Strides& window_movement_strides,
                     const Shape& padding_below,
                     const Shape& padding_above,
                     const PadType& pad_type,
                     bool ceil_mode)
-    : Op("MaxPool", check_single_output_args({arg}))
+    : Op({arg})
    , m_window_shape(window_shape)
    , m_window_movement_strides(window_movement_strides)
    , m_padding_below(padding_below)
@@ -43,7 +45,7 @@ op::MaxPool::MaxPool(const shared_ptr<Node>& arg,
    constructor_validate_and_infer_types();
}

-op::MaxPool::MaxPool(const shared_ptr<Node>& arg,
+op::MaxPool::MaxPool(const Output<Node>& arg,
                     const Shape& window_shape,
                     const Strides& window_movement_strides,
                     const Shape& padding_below,
@@ -54,7 +56,7 @@ op::MaxPool::MaxPool(const shared_ptr<Node>& arg,
{
}

-op::MaxPool::MaxPool(const shared_ptr<Node>& arg,
+op::MaxPool::MaxPool(const Output<Node>& arg,
                     const Shape& window_shape,
                     const Strides& window_movement_strides,
                     const Shape& padding_below,
@@ -121,14 +123,14 @@ void op::MaxPool::validate_and_infer_types()
                                                     m_ceil_mode));
}

-op::MaxPool::MaxPool(const shared_ptr<Node>& arg,
+op::MaxPool::MaxPool(const Output<Node>& arg,
                     const Shape& window_shape,
                     const Strides& window_movement_strides)
    : MaxPool(arg, window_shape, window_movement_strides, Shape(), Shape())
{
}

-op::MaxPool::MaxPool(const shared_ptr<Node>& arg, const Shape& window_shape)
+op::MaxPool::MaxPool(const Output<Node>& arg, const Shape& window_shape)
    : MaxPool(arg, window_shape, Strides(), Shape(), Shape())
{
}
@@ -145,13 +147,15 @@ shared_ptr<Node> op::MaxPool::copy_with_new_args(const NodeVector& new_args) con
                                 m_ceil_mode);
}

-op::MaxPoolBackprop::MaxPoolBackprop(const shared_ptr<Node>& arg_forward,
-                                     const shared_ptr<Node>& delta,
+const string op::MaxPoolBackprop::type_name{"MaxPoolBackprop"};
+
+op::MaxPoolBackprop::MaxPoolBackprop(const Output<Node>& arg_forward,
+                                     const Output<Node>& delta,
                                     const Shape& window_shape,
                                     const Strides& window_movement_strides,
                                     const Shape& padding_below,
                                     const Shape& padding_above)
-    : Op("MaxPoolBackprop", check_single_output_args({arg_forward, delta}))
+    : Op({arg_forward, delta})
    , m_window_shape(window_shape)
    , m_window_movement_strides(window_movement_strides)
    , m_padding_below(padding_below)
@@ -160,14 +164,14 @@ op::MaxPoolBackprop::MaxPoolBackprop(const shared_ptr<Node>& arg_forward,
    constructor_validate_and_infer_types();
}

-op::MaxPoolBackprop::MaxPoolBackprop(const shared_ptr<Node>& arg_forward,
-                                     const shared_ptr<Node>& delta,
-                                     const shared_ptr<Node>& result_forward,
+op::MaxPoolBackprop::MaxPoolBackprop(const Output<Node>& arg_forward,
+                                     const Output<Node>& delta,
+                                     const Output<Node>& result_forward,
                                     const Shape& window_shape,
                                     const Strides& window_movement_strides,
                                     const Shape& padding_below,
                                     const Shape& padding_above)
-    : Op("MaxPoolBackprop", check_single_output_args({arg_forward, delta, result_forward}))
+    : Op({arg_forward, delta, result_forward})
    , m_window_shape(window_shape)
    , m_window_movement_strides(window_movement_strides)
    , m_padding_below(padding_below)
...
@@ -28,6 +28,12 @@ namespace ngraph
         class MaxPool : public Op
         {
         public:
+            NGRAPH_API
+            static const std::string type_name;
+            const std::string& description() const override { return type_name; }
+            /// \brief Constructs a batched max pooling operation.
+            MaxPool() = default;
+
             /// \brief Constructs a batched max pooling operation.
             ///
             /// \param arg The node producing the input data batch tensor.
@@ -37,7 +43,7 @@ namespace ngraph
             /// \param padding_above The above-padding shape.
             /// \param pad_type The pad type for automatically computing padding sizes
             /// \param ceil_mode Whether to use ceiling while computing output shape.
-            MaxPool(const std::shared_ptr<Node>& arg,
+            MaxPool(const Output<Node>& arg,
                     const Shape& window_shape,
                     const Strides& window_movement_strides,
                     const Shape& padding_below,
@@ -53,7 +59,7 @@ namespace ngraph
             /// \param padding_below The below-padding shape.
             /// \param padding_above The above-padding shape.
             /// \param pad_type The pad type for automatically computing padding sizes
-            MaxPool(const std::shared_ptr<Node>& arg,
+            MaxPool(const Output<Node>& arg,
                     const Shape& window_shape,
                     const Strides& window_movement_strides,
                     const Shape& padding_below,
@@ -67,7 +73,7 @@ namespace ngraph
             /// \param window_movement_strides The window movement strides.
             /// \param padding_below The below-padding shape.
             /// \param padding_above The above-padding shape.
-            MaxPool(const std::shared_ptr<Node>& arg,
+            MaxPool(const Output<Node>& arg,
                     const Shape& window_shape,
                     const Strides& window_movement_strides,
                     const Shape& padding_below,
@@ -80,7 +86,7 @@ namespace ngraph
             /// \param arg The node producing the input data batch tensor.
             /// \param window_shape The window shape.
             /// \param window_movement_strides The window movement strides.
-            MaxPool(const std::shared_ptr<Node>& arg,
+            MaxPool(const Output<Node>& arg,
                     const Shape& window_shape,
                     const Strides& window_movement_strides);
@@ -88,23 +94,32 @@ namespace ngraph
             ///
             /// \param arg The node producing the input data batch tensor.
             /// \param window_shape The window shape.
-            MaxPool(const std::shared_ptr<Node>& arg, const Shape& window_shape);
+            MaxPool(const Output<Node>& arg, const Shape& window_shape);
 
             virtual std::shared_ptr<Node>
                 copy_with_new_args(const NodeVector& new_args) const override;
 
             /// \return The window shape.
             const Shape& get_window_shape() const { return m_window_shape; }
+            void set_window_shape(const Shape& window_shape) { m_window_shape = window_shape; }
             /// \return The window movement strides.
             const Strides& get_window_movement_strides() const { return m_window_movement_strides; }
+            void set_window_movement_strides(const Strides& window_movement_strides)
+            {
+                m_window_movement_strides = window_movement_strides;
+            }
             /// \return The below-padding shape.
             const Shape& get_padding_below() const { return m_padding_below; }
+            void set_padding_below(const Shape& padding_below) { m_padding_below = padding_below; }
             /// \return The above-padding shape.
             const Shape& get_padding_above() const { return m_padding_above; }
+            void set_adding_above(const Shape& padding_above) { m_padding_above = padding_above; }
             /// \return The pad type for pooling.
             const PadType& get_pad_type() const { return m_pad_type; }
+            void set_pad_type(const PadType& pad_type) { m_pad_type = pad_type; }
             /// \return The ceiling mode being used for output shape computations
             bool get_ceil_mode() const { return m_ceil_mode; }
+            void set_ceil_mode(bool ceil_mode) { m_ceil_mode = ceil_mode; }
             /// \return The default value for MaxPool.
             virtual std::shared_ptr<Node> get_default_value() const override
             {
@@ -126,16 +141,21 @@ namespace ngraph
         class MaxPoolBackprop : public Op
         {
         public:
-            MaxPoolBackprop(const std::shared_ptr<Node>& arg_forward,
-                            const std::shared_ptr<Node>& delta,
+            NGRAPH_API
+            static const std::string type_name;
+            const std::string& description() const override { return type_name; }
+            MaxPoolBackprop() = default;
+
+            MaxPoolBackprop(const Output<Node>& arg_forward,
+                            const Output<Node>& delta,
                             const Shape& window_shape,
                             const Strides& window_movement_strides,
                             const Shape& padding_below,
                             const Shape& padding_above);
 
-            MaxPoolBackprop(const std::shared_ptr<Node>& arg_forward,
-                            const std::shared_ptr<Node>& delta,
-                            const std::shared_ptr<Node>& result_forward,
+            MaxPoolBackprop(const Output<Node>& arg_forward,
+                            const Output<Node>& delta,
+                            const Output<Node>& result_forward,
                             const Shape& window_shape,
                             const Strides& window_movement_strides,
                             const Shape& padding_below,
@@ -147,9 +167,16 @@ namespace ngraph
             void validate_and_infer_types() override;
 
             const Shape& get_window_shape() const { return m_window_shape; }
+            void set_window_shape(const Shape& window_shape) { m_window_shape = window_shape; }
             const Strides& get_window_movement_strides() const { return m_window_movement_strides; }
+            void set_window_movement_strides(const Strides& window_movement_strides)
+            {
+                m_window_movement_strides = window_movement_strides;
+            }
             const Shape& get_padding_below() const { return m_padding_below; }
+            void set_padding_below(const Shape& padding_below) { m_padding_below = padding_below; }
             const Shape& get_padding_above() const { return m_padding_above; }
+            void set_padding_above(const Shape& padding_above) { m_padding_above = padding_above; }
 
         protected:
             Shape m_window_shape;
             Strides m_window_movement_strides;
...
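Note (commentary, not part of the commit): across these hunks the MaxPool constructors move from shared_ptr<Node> to Output<Node>, and the new default constructor plus setters let a node be built empty and configured afterwards (e.g. by a deserializer). A minimal usage sketch under those assumptions; the helper name and shapes below are invented for illustration:

#include <memory>
#include "ngraph/op/max_pool.hpp"
#include "ngraph/op/parameter.hpp"

using namespace ngraph;

std::shared_ptr<op::MaxPool> make_pool()
{
    // NCHW data batch: 1 image, 1 channel, 4x4 spatial dims (hypothetical).
    auto data = std::make_shared<op::Parameter>(element::f32, Shape{1, 1, 4, 4});
    // shared_ptr<Node> converts implicitly to Output<Node>, so call sites
    // written against the old signatures keep compiling after this migration.
    return std::make_shared<op::MaxPool>(data, Shape{2, 2}, Strides{2, 2});
}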
@@ -25,10 +25,12 @@
 using namespace std;
 using namespace ngraph;
 
-op::Maximum::Maximum(const shared_ptr<Node>& arg0,
-                     const shared_ptr<Node>& arg1,
+const string op::Maximum::type_name{"Maximum"};
+
+op::Maximum::Maximum(const Output<Node>& arg0,
+                     const Output<Node>& arg1,
                      const AutoBroadcastSpec& autob)
-    : BinaryElementwiseArithmetic("Maximum", arg0, arg1, autob)
+    : BinaryElementwiseArithmetic(arg0, arg1, autob)
 {
     constructor_validate_and_infer_types();
 }
...
@@ -26,19 +26,24 @@ namespace ngraph
         class Maximum : public util::BinaryElementwiseArithmetic
         {
         public:
+            NGRAPH_API
+            static const std::string type_name;
+            const std::string& description() const override { return type_name; }
+            /// \brief Constructs a maximum operation.
+            Maximum() = default;
             /// \brief Constructs a maximum operation.
             ///
             /// \param arg0 Node that produces the first input tensor.
             /// \param arg1 Node that produces the second input tensor.
             /// \param autob Auto broadcast specification
-            Maximum(const std::shared_ptr<Node>& arg0,
-                    const std::shared_ptr<Node>& arg1,
+            Maximum(const Output<Node>& arg0,
+                    const Output<Node>& arg1,
                     const AutoBroadcastSpec& autob = AutoBroadcastSpec());
 
             virtual std::shared_ptr<Node>
                 copy_with_new_args(const NodeVector& new_args) const override;
 
-            virtual bool is_commutative() override { return true; }
+            virtual bool is_commutative() const override { return true; }
 
         protected:
             virtual void generate_adjoints(autodiff::Adjoints& adjoints,
                                            const NodeVector& deltas) override;
...
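Aside (an inference from the hunk above, not text from the commit): the is_commutative() override gains a const qualifier, which only compiles if the base Node virtual is const-qualified in the same change; that in turn allows the flag to be queried through const references. A sketch, with a hypothetical helper name:

#include "ngraph/op/maximum.hpp"

// With the const-qualified virtual, commutativity can be checked on a
// const Node& rather than requiring a mutable node.
bool commutes(const ngraph::Node& n)
{
    return n.is_commutative();
}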
@@ -22,10 +22,6 @@ using namespace ngraph;
 
 const string op::Min::type_name{"Min"};
 
-op::Min::Min()
-{
-}
-
 op::Min::Min(const Output<Node>& arg, const AxisSet& reduction_axes)
     : ArithmeticReduction(arg, reduction_axes)
 {
...
@@ -30,7 +30,7 @@ namespace ngraph
             static const std::string type_name;
             const std::string& description() const override { return type_name; }
             /// \brief Constructs a "min" reduction operation.
-            Min();
+            Min() = default;
             /// \brief Constructs a min-reduction operation.
             ///
             /// \param arg The tensor to be reduced.
...
@@ -25,10 +25,12 @@
 using namespace std;
 using namespace ngraph;
 
-op::Minimum::Minimum(const shared_ptr<Node>& arg0,
-                     const shared_ptr<Node>& arg1,
+const string op::Minimum::type_name{"Minimum"};
+
+op::Minimum::Minimum(const Output<Node>& arg0,
+                     const Output<Node>& arg1,
                      const AutoBroadcastSpec& autob)
-    : BinaryElementwiseArithmetic("Minimum", arg0, arg1, autob)
+    : BinaryElementwiseArithmetic(arg0, arg1, autob)
 {
     constructor_validate_and_infer_types();
 }
...
@@ -26,18 +26,24 @@ namespace ngraph
         class Minimum : public util::BinaryElementwiseArithmetic
         {
         public:
+            NGRAPH_API
+            static const std::string type_name;
+            const std::string& description() const override { return type_name; }
+            /// \brief Constructs a minimum operation.
+            Minimum() = default;
             /// \brief Constructs a minimum operation.
             ///
             /// \param arg0 Node that produces the first input tensor.
             /// \param arg1 Node that produces the second input tensor.
             /// \param autob Auto broadcast specification
-            Minimum(const std::shared_ptr<Node>& arg0,
-                    const std::shared_ptr<Node>& arg1,
+            Minimum(const Output<Node>& arg0,
+                    const Output<Node>& arg1,
                     const AutoBroadcastSpec& autob = AutoBroadcastSpec());
 
             virtual std::shared_ptr<Node>
                 copy_with_new_args(const NodeVector& new_args) const override;
 
+            virtual bool is_commutative() const override { return true; }
         protected:
             virtual void generate_adjoints(autodiff::Adjoints& adjoints,
                                            const NodeVector& deltas) override;
...
@@ -19,10 +19,12 @@
 using namespace std;
 using namespace ngraph;
 
-op::Multiply::Multiply(const shared_ptr<Node>& arg0,
-                       const shared_ptr<Node>& arg1,
+const string op::Multiply::type_name{"Multiply"};
+
+op::Multiply::Multiply(const Output<Node>& arg0,
+                       const Output<Node>& arg1,
                        const AutoBroadcastSpec& autob)
-    : BinaryElementwiseArithmetic("Multiply", arg0, arg1, autob)
+    : BinaryElementwiseArithmetic(arg0, arg1, autob)
 {
     constructor_validate_and_infer_types();
 }
@@ -49,7 +51,7 @@ void op::Multiply::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVec
     adjoints.add_delta(y, x * delta);
 }
 
-shared_ptr<Node> ngraph::operator*(const shared_ptr<Node> arg0, const shared_ptr<Node> arg1)
+shared_ptr<Node> ngraph::operator*(const Output<Node>& arg0, const Output<Node>& arg1)
 {
     return make_shared<op::Multiply>(arg0, arg1);
 }
...
@@ -26,25 +26,29 @@ namespace ngraph
         class Multiply : public util::BinaryElementwiseArithmetic
         {
         public:
+            NGRAPH_API
+            static const std::string type_name;
+            const std::string& description() const override { return type_name; }
+            /// \brief Constructs a multiplication operation.
+            Multiply() = default;
             /// \brief Constructs a multiplication operation.
             ///
             /// \param arg0 Node that produces the first input tensor.
             /// \param arg1 Node that produces the second input tensor.
             /// \param autob Auto broadcast specification
-            Multiply(const std::shared_ptr<Node>& arg0,
-                     const std::shared_ptr<Node>& arg1,
+            Multiply(const Output<Node>& arg0,
+                     const Output<Node>& arg1,
                      const AutoBroadcastSpec& autob = AutoBroadcastSpec());
 
             virtual std::shared_ptr<Node>
                 copy_with_new_args(const NodeVector& new_args) const override;
 
+            virtual bool is_commutative() const override { return true; }
         protected:
             virtual void generate_adjoints(autodiff::Adjoints& adjoints,
                                            const NodeVector& deltas) override;
-
-            virtual bool is_commutative() override { return true; }
         };
     };
 
-    std::shared_ptr<ngraph::Node> operator*(const std::shared_ptr<ngraph::Node> arg0,
-                                            const std::shared_ptr<ngraph::Node> arg1);
+    std::shared_ptr<Node> operator*(const Output<Node>& arg0, const Output<Node>& arg1);
 }
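Commentary (not part of the commit): the free operator* now takes Output<Node> arguments and still returns a shared_ptr<Node> wrapping an op::Multiply, so graph arithmetic reads like ordinary expressions. A hedged call-site sketch; the helper name is invented:

#include <memory>
#include "ngraph/op/multiply.hpp"

using namespace ngraph;

// Builds x * y as an op::Multiply node. Output<Node> also accepts a
// shared_ptr<Node> implicitly, so both old- and new-style callers work.
std::shared_ptr<Node> scale(const Output<Node>& x, const Output<Node>& y)
{
    return x * y;
}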
...
@@ -19,8 +19,10 @@
 using namespace std;
 using namespace ngraph;
 
-op::Negative::Negative(const shared_ptr<Node>& arg)
-    : UnaryElementwiseArithmetic("Negative", arg)
+const string op::Negative::type_name{"Negative"};
+
+op::Negative::Negative(const Output<Node>& arg)
+    : UnaryElementwiseArithmetic(arg)
 {
     constructor_validate_and_infer_types();
 }
@@ -40,7 +42,7 @@ void op::Negative::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVec
     adjoints.add_delta(x, -delta);
 }
 
-shared_ptr<Node> ngraph::operator-(const shared_ptr<Node> arg0)
+shared_ptr<Node> ngraph::operator-(const Output<Node>& arg0)
 {
     return make_shared<op::Negative>(arg0);
 }
...
@@ -26,17 +26,23 @@ namespace ngraph
         class Negative : public util::UnaryElementwiseArithmetic
         {
         public:
+            NGRAPH_API
+            static const std::string type_name;
+            const std::string& description() const override { return type_name; }
+            /// \brief Constructs a negative operation.
+            Negative() = default;
             /// \brief Constructs a negative operation.
             ///
             /// \param arg Node that produces the input tensor.
-            Negative(const std::shared_ptr<Node>& arg);
+            Negative(const Output<Node>& arg);
 
             virtual std::shared_ptr<Node>
                 copy_with_new_args(const NodeVector& new_args) const override;
 
+        protected:
             virtual void generate_adjoints(autodiff::Adjoints& adjoints,
                                            const NodeVector& deltas) override;
         };
     }
 
-    std::shared_ptr<ngraph::Node> operator-(const std::shared_ptr<ngraph::Node> arg0);
+    std::shared_ptr<Node> operator-(const Output<Node>& arg0);
 }
...
@@ -20,8 +20,10 @@
 using namespace ngraph;
 using namespace std;
 
-op::Not::Not(const shared_ptr<Node>& arg)
-    : Op("Not", check_single_output_args({arg}))
+const string op::Not::type_name{"Not"};
+
+op::Not::Not(const Output<Node>& arg)
+    : Op({arg})
 {
     constructor_validate_and_infer_types();
 }
...
@@ -26,10 +26,15 @@ namespace ngraph
         class Not : public Op
         {
         public:
+            NGRAPH_API
+            static const std::string type_name;
+            const std::string& description() const override { return type_name; }
+            /// \brief Constructs a logical negation operation.
+            Not() = default;
             /// \brief Constructs a logical negation operation.
             ///
             /// \param arg Node that produces the input tensor.
-            Not(const std::shared_ptr<Node>& arg);
+            Not(const Output<Node>& arg);
 
             void validate_and_infer_types() override;
...
@@ -19,10 +19,12 @@
 using namespace std;
 using namespace ngraph;
 
-op::NotEqual::NotEqual(const shared_ptr<Node>& arg0,
-                       const shared_ptr<Node>& arg1,
+const string op::NotEqual::type_name{"NotEqual"};
+
+op::NotEqual::NotEqual(const Output<Node>& arg0,
+                       const Output<Node>& arg1,
                        const AutoBroadcastSpec& autob)
-    : BinaryElementwiseComparison("NotEqual", arg0, arg1, autob)
+    : BinaryElementwiseComparison(arg0, arg1, autob)
 {
     constructor_validate_and_infer_types();
 }
...
@@ -26,17 +26,24 @@ namespace ngraph
         class NotEqual : public util::BinaryElementwiseComparison
         {
         public:
+            NGRAPH_API
+            static const std::string type_name;
+            const std::string& description() const override { return type_name; }
+            /// \brief Constructs a not-equal operation.
+            NotEqual() = default;
+
             /// \brief Constructs a not-equal operation.
             ///
             /// \param arg0 Node that produces the first input tensor.
             /// \param arg1 Node that produces the second input tensor.
             /// \param autob Auto broadcast specification
-            NotEqual(const std::shared_ptr<Node>& arg0,
-                     const std::shared_ptr<Node>& arg1,
+            NotEqual(const Output<Node>& arg0,
+                     const Output<Node>& arg1,
                      const AutoBroadcastSpec& autob = AutoBroadcastSpec());
 
             virtual std::shared_ptr<Node>
                 copy_with_new_args(const NodeVector& new_args) const override;
 
+            virtual bool is_commutative() const override { return true; }
         };
     }
 }
...
@@ -20,8 +20,10 @@
 using namespace std;
 using namespace ngraph;
 
-op::OneHot::OneHot(const shared_ptr<Node>& arg, const PartialShape& shape, size_t one_hot_axis)
-    : Op("OneHot", check_single_output_args({arg}))
+const string op::OneHot::type_name{"OneHot"};
+
+op::OneHot::OneHot(const Output<Node>& arg, const PartialShape& shape, size_t one_hot_axis)
+    : Op({arg})
     , m_shape(shape)
     , m_one_hot_axis(one_hot_axis)
 {
...
@@ -45,14 +45,17 @@ namespace ngraph
         class OneHot : public Op
         {
         public:
+            NGRAPH_API
+            static const std::string type_name;
+            const std::string& description() const override { return type_name; }
+            /// \brief Constructs a one-hot operation.
+            OneHot() = default;
             /// \brief Constructs a one-hot operation.
             ///
             /// \param arg Node that produces the input tensor to be one-hot encoded.
             /// \param shape The shape of the output tensor, including the new one-hot axis.
             /// \param one_hot_axis The index within the output shape of the new one-hot axis.
-            OneHot(const std::shared_ptr<Node>& arg,
-                   const PartialShape& shape,
-                   size_t one_hot_axis);
+            OneHot(const Output<Node>& arg, const PartialShape& shape, size_t one_hot_axis);
 
             virtual std::shared_ptr<Node>
                 copy_with_new_args(const NodeVector& new_args) const override;
@@ -60,6 +63,7 @@ namespace ngraph
             /// \return The index of the one-hot axis.
             size_t get_one_hot_axis() const { return m_one_hot_axis; }
+            void set_one_hot_axis(size_t one_hot_axis) { m_one_hot_axis = one_hot_axis; }
 
         protected:
             PartialShape m_shape;
             size_t m_one_hot_axis;
...
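Usage sketch for the updated OneHot signature (commentary only; the helper name and shapes are illustrative). The output shape parameter is now a PartialShape, so brace-initialized static shapes convert directly:

#include <memory>
#include "ngraph/op/one_hot.hpp"
#include "ngraph/op/parameter.hpp"

using namespace ngraph;

std::shared_ptr<op::OneHot> make_one_hot()
{
    // Eight class indices, to be expanded along a new axis of size 3.
    auto indices = std::make_shared<op::Parameter>(element::i64, Shape{8});
    // Output shape {8, 3}; the new one-hot axis is axis 1.
    return std::make_shared<op::OneHot>(indices, PartialShape{8, 3}, 1);
}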
@@ -19,10 +19,10 @@
 using namespace std;
 using namespace ngraph;
 
-op::Or::Or(const shared_ptr<Node>& arg0,
-           const shared_ptr<Node>& arg1,
-           const AutoBroadcastSpec& autob)
-    : BinaryElementwiseLogical("Or", arg0, arg1, autob)
+const string op::Or::type_name{"Or"};
+
+op::Or::Or(const Output<Node>& arg0, const Output<Node>& arg1, const AutoBroadcastSpec& autob)
+    : BinaryElementwiseLogical(arg0, arg1, autob)
 {
     constructor_validate_and_infer_types();
 }
...
@@ -29,6 +29,9 @@ namespace ngraph
         class Or : public util::BinaryElementwiseLogical
         {
         public:
+            NGRAPH_API
+            static const std::string type_name;
+            const std::string& description() const override { return type_name; }
             /// \brief Constructs a logical-or operation.
             ///
             /// \param arg0 Node that produces the first input tensor.<br>
@@ -39,15 +42,14 @@ namespace ngraph
             ///
             /// Output `[d0, ...]`
             ///
-            Or(const std::shared_ptr<Node>& arg0,
-               const std::shared_ptr<Node>& arg1,
+            Or(const Output<Node>& arg0,
+               const Output<Node>& arg1,
                const AutoBroadcastSpec& autob = AutoBroadcastSpec());
 
             virtual std::shared_ptr<Node>
                 copy_with_new_args(const NodeVector& new_args) const override;
 
-        protected:
-            virtual bool is_commutative() override { return true; }
+            virtual bool is_commutative() const override { return true; }
         };
     }
 }