Unverified commit 11d93be4, authored by Scott Cyphers, committed by GitHub

Merge branch 'master' into bob/fused_tolerance

parents fe89d17b 45ad33ea
@@ -21,7 +21,7 @@ include(ExternalProject)
 # ONNX.proto definition version
 #------------------------------------------------------------------------------
-set(ONNX_VERSION 1.3.0)
+set(ONNX_VERSION 1.5.0)
 #------------------------------------------------------------------------------
 # Download and install libonnx ...
@@ -30,6 +30,9 @@ set(ONNX_VERSION 1.3.0)
 set(ONNX_GIT_REPO_URL https://github.com/onnx/onnx.git)
 set(ONNX_GIT_BRANCH rel-${ONNX_VERSION})
+add_definitions(-DONNX_BUILD_SHARED_LIBS=ON)
+add_definitions(-DONNX_ML=ON)
 ExternalProject_Add(
     ext_onnx
     PREFIX onnx
@@ -58,8 +61,8 @@ ExternalProject_Add(
 ExternalProject_Get_Property(ext_onnx SOURCE_DIR BINARY_DIR)
-set(ONNX_INCLUDE_DIR ${SOURCE_DIR}/onnx)
-set(ONNX_PROTO_INCLUDE_DIR ${BINARY_DIR}/onnx)
+set(ONNX_INCLUDE_DIR ${SOURCE_DIR})
+set(ONNX_PROTO_INCLUDE_DIR ${BINARY_DIR})
 if (WIN32)
     set(ONNX_LIBRARY ${BINARY_DIR}/${CMAKE_BUILD_TYPE}/onnx.lib)
     set(ONNX_PROTO_LIBRARY ${BINARY_DIR}/${CMAKE_BUILD_TYPE}/onnx_proto.lib)
...
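Note on the two new definitions above: ONNX_ML selects the ML flavour of ONNX's generated protobuf headers, which is why the include directives further down in this diff switch from <onnx-ml.pb.h> to the umbrella header <onnx/onnx_pb.h>. Roughly, and paraphrased from memory rather than taken from this diff, the umbrella header dispatches like this:

// Paraphrase of onnx/onnx_pb.h (illustrative, not part of this diff):
#ifdef ONNX_ML
#include "onnx/onnx-ml.pb.h"
#else
#include "onnx/onnx.pb.h"
#endif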
@@ -37,6 +37,7 @@ ngraph.ops
    equal
    exp
    floor
+   gelu
    get_output_element
    greater
    greater_eq
...
@@ -50,6 +50,7 @@ from ngraph.ops import elu
 from ngraph.ops import equal
 from ngraph.ops import exp
 from ngraph.ops import floor
+from ngraph.ops import gelu
 from ngraph.ops import get_output_element
 from ngraph.ops import greater
 from ngraph.ops import greater_eq
...
@@ -74,6 +74,7 @@ from _pyngraph.op import Elu
 from _pyngraph.op import Equal
 from _pyngraph.op import Exp
 from _pyngraph.op import Floor
+from _pyngraph.op import Gelu
 from _pyngraph.op import GetOutputElement
 from _pyngraph.op import Greater
 from _pyngraph.op import GreaterEq
...
@@ -23,7 +23,7 @@ from ngraph.impl import AxisSet, AxisVector, Coordinate, CoordinateDiff, Functio
 from ngraph.impl.op import Abs, Acos, Add, And, Asin, ArgMax, ArgMin, Atan, AvgPool, \
     BatchNormTraining, BatchNormInference, Broadcast, Ceiling, Clamp, Concat, Constant, Convert, \
     Convolution, ConvolutionBackpropData, Cos, Cosh, Divide, Dot, Elu, Equal, Exp, Floor, \
-    GetOutputElement, Greater, GreaterEq, Less, LessEq, Log, LRN, Max, Maximum, MaxPool, \
+    Gelu, GetOutputElement, Greater, GreaterEq, Less, LessEq, Log, LRN, Max, Maximum, MaxPool, \
     Min, Minimum, Multiply, Negative, Not, NotEqual, OneHot, Or, Pad, Parameter, Product, \
     Power, Relu, ReplaceSlice, Reshape, Reverse, Select, Sign, Sin, Sinh, Slice, Softmax, \
     Sqrt, Subtract, Sum, Tan, Tanh, TopK
@@ -527,6 +527,24 @@ def convert(node, new_type, name=None): # type: (Node, NumericType, str) -> Nod
     return Convert(node, new_element_type)
+@nameable_op
+def gelu(node, name=None):  # type: (NodeInput, str) -> Node
+    r"""Perform Gaussian Error Linear Unit operation element-wise on data from input node.
+
+    Computes GELU function:
+
+    .. math:: f(x) = 0.5\cdot x\cdot(1 + erf(\dfrac{x}{\sqrt{2}}))
+
+    For more information refer to:
+    `Gaussian Error Linear Unit (GELU) <https://arxiv.org/pdf/1606.08415.pdf>`_
+
+    :param node: Input tensor. One of: input node, array or scalar.
+    :param name: Optional output node name.
+    :return: The new node performing a GELU operation on its input data element-wise.
+    """
+    return Gelu(as_node(node))
 @nameable_op
 def select(selection_node, input_node1, input_node2, name=None):
     # type: (Node, Node, Node, str) -> Node
...
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "ngraph/op/fused/gelu.hpp"
#include "pyngraph/ops/fused/gelu.hpp"
namespace py = pybind11;
void regclass_pyngraph_op_Gelu(py::module m)
{
py::class_<ngraph::op::Gelu, std::shared_ptr<ngraph::op::Gelu>, ngraph::op::Op> gelu(m, "Gelu");
gelu.doc() = "ngraph.impl.op.Gelu wraps ngraph::op::Gelu";
gelu.def(py::init<const std::shared_ptr<ngraph::Node>&>());
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
void regclass_pyngraph_op_Gelu(py::module m);
@@ -54,6 +54,7 @@ void regmodule_pyngraph_op(py::module m_op)
     regclass_pyngraph_op_Equal(m_op);
     regclass_pyngraph_op_Exp(m_op);
     regclass_pyngraph_op_Floor(m_op);
+    regclass_pyngraph_op_Gelu(m_op);
     regclass_pyngraph_op_GetOutputElement(m_op);
     regclass_pyngraph_op_Greater(m_op);
     regclass_pyngraph_op_GreaterEq(m_op);
...
@@ -44,6 +44,7 @@
 #include "pyngraph/ops/floor.hpp"
 #include "pyngraph/ops/fused/clamp.hpp"
 #include "pyngraph/ops/fused/elu.hpp"
+#include "pyngraph/ops/fused/gelu.hpp"
 #include "pyngraph/ops/get_output_element.hpp"
 #include "pyngraph/ops/greater.hpp"
 #include "pyngraph/ops/greater_eq.hpp"
...
@@ -184,6 +184,7 @@ sources = [
     'pyngraph/ops/equal.cpp',
     'pyngraph/ops/exp.cpp',
     'pyngraph/ops/floor.cpp',
+    'pyngraph/ops/fused/gelu.cpp',
     'pyngraph/ops/greater.cpp',
     'pyngraph/ops/greater_eq.cpp',
     'pyngraph/ops/less.cpp',
...
@@ -19,7 +19,7 @@ import ngraph as ng
 from test.ngraph.util import get_runtime
-def test_elu_operator():
+def test_elu_operator_with_parameters():
     runtime = get_runtime()
     data_shape = [2, 2]
@@ -69,6 +69,38 @@ def test_elu_operator_with_scalar():
     assert np.allclose(result, expected)
+def test_gelu_operator_with_parameters():
+    runtime = get_runtime()
+
+    data_value = np.array([[-5, 1], [-2, 3]], dtype=np.float32)
+    data_shape = [2, 2]
+    parameter_data = ng.parameter(data_shape, name='Data', dtype=np.float32)
+
+    model = ng.gelu(parameter_data)
+    computation = runtime.computation(model, parameter_data)
+
+    result = computation(data_value)
+    expected = np.array([[-1.4901161e-06, 8.4134471e-01], [-4.5500278e-02, 2.9959502]],
+                        dtype=np.float32)
+    assert np.allclose(result, expected)
+
+
+def test_gelu_operator_with_array():
+    runtime = get_runtime()
+
+    data_value = np.array([[-5, 1], [-2, 3]], dtype=np.float32)
+
+    model = ng.gelu(data_value)
+    computation = runtime.computation(model)
+
+    result = computation()
+    expected = np.array([[-1.4901161e-06, 8.4134471e-01], [-4.5500278e-02, 2.9959502]],
+                        dtype=np.float32)
+    assert np.allclose(result, expected)
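The expected arrays in the two GELU tests above follow directly from the formula in the gelu docstring; a small standalone C++ check (an illustrative sketch, not part of the diff) reproduces them:

// Cross-check of the expected GELU values used in the tests above.
#include <cmath>
#include <cstdio>

int main()
{
    const double inputs[] = {-5.0, 1.0, -2.0, 3.0};
    for (double x : inputs)
    {
        // GELU(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
        std::printf("gelu(%g) = %g\n", x, 0.5 * x * (1.0 + std::erf(x / std::sqrt(2.0))));
    }
    // Prints approximately -1.49012e-06, 0.841345, -0.0455003, 2.99595,
    // matching the `expected` arrays above.
    return 0;
}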
 def test_clamp_operator():
     runtime = get_runtime()
@@ -99,4 +131,5 @@ def test_clamp_operator_with_array():
     result = computation()
     expected = np.clip(data_value, min_value, max_value)
     assert np.allclose(result, expected)
...
@@ -43,6 +43,8 @@
 using namespace mlir::edsc::op;
 using namespace ngraph::runtime;
 using namespace ngraph::runtime::ngmlir;
+// Index notation to generate standard (i.e., non-affine) loads and stores.
+using StdIndexedValue = TemplatedIndexedValue<intrinsics::std_load, intrinsics::std_store>;
 class DialectLoweringPass;
@@ -682,7 +684,8 @@
     // Create view to write into result.
     MemRefView vRes(result), vParams(params), vIndices(indices);
     // Indexed Values
-    IndexedValue iRes(result), iParams(params), iIndices(indices);
+    IndexedValue iRes(result), iIndices(indices);
+    StdIndexedValue iParams(params);
     // Construct outer loop for params dims. Exclude the axis dim.
     SmallVector<ValueHandle, 4> paramsLbs, paramsUbs;
@@ -894,7 +897,8 @@
     // Views
     MemRefView vRes(result), vArg(arg);
     // Index Values
-    IndexedValue iRes(result), iArg(arg);
+    StdIndexedValue iRes(result), stdArg(arg);
+    IndexedValue affineArg(arg);
     // Bounds Index Handles
     auto resLbs = vRes.getLbs();
     auto resUbs = vRes.getUbs();
@@ -944,9 +948,9 @@
     ValueHandle newRedIdx =
         std::is_same<RedOp, NGArgMinRedOp>()
             ? edsc::intrinsics::select(
-                  iArg(allIVs) < iArg(tempIVs), allIVs[axis], currRedIdx)
+                  affineArg(allIVs) < stdArg(tempIVs), allIVs[axis], currRedIdx)
             : edsc::intrinsics::select(
-                  iArg(tempIVs) < iArg(allIVs), allIVs[axis], currRedIdx);
+                  stdArg(tempIVs) < affineArg(allIVs), allIVs[axis], currRedIdx);
     iRes(nonRedIVs) = ValueHandle::create<IndexCastOp>(newRedIdx, resTy);
 });
...
@@ -16,7 +16,7 @@
 #pragma once
-#include <onnx-ml.pb.h>
+#include <onnx/onnx_pb.h>
 #include "ngraph/except.hpp"
 #include "tensor.hpp"
...
@@ -16,7 +16,7 @@
 #pragma once
-#include <onnx-ml.pb.h>
+#include <onnx/onnx_pb.h>
 #include <string>
 #include <vector>
...
@@ -14,7 +14,7 @@
 // limitations under the License.
 //*****************************************************************************
-#include <onnx-ml.pb.h>
+#include <onnx/onnx_pb.h>
 #include "model.hpp"
 #include "ngraph/log.hpp"
...
@@ -16,7 +16,7 @@
 #pragma once
-#include <onnx-ml.pb.h>
+#include <onnx/onnx_pb.h>
 #include <ostream>
 #include <string>
 #include <unordered_map>
...
@@ -14,7 +14,7 @@
 // limitations under the License.
 //*****************************************************************************
-#include <onnx-ml.pb.h>
+#include <onnx/onnx_pb.h>
 #include "attribute.hpp"
 #include "graph.hpp"
...
@@ -16,7 +16,7 @@
 #pragma once
-#include <onnx-ml.pb.h>
+#include <onnx/onnx_pb.h>
 #include <utility>
 #include <vector>
...
@@ -16,7 +16,7 @@
 #pragma once
-#include <onnx-ml.pb.h>
+#include <onnx/onnx_pb.h>
 #include "ngraph/op/constant.hpp"
 #include "ngraph/op/parameter.hpp"
...
@@ -13,7 +13,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 //*****************************************************************************
-#include <onnx-ml.pb.h> // onnx types
+#include <onnx/onnx_pb.h> // onnx types
 #include "common.hpp"
...
@@ -15,7 +15,7 @@
 //*****************************************************************************
 #include <cstdlib> // std::size_t, std::uintptr_t
-#include <onnxifi.h>
+#include <onnx/onnxifi.h>
 #include <stdexcept> // std::invalid_argument, std::out_of_range
 #include "backend.hpp"
...
@@ -19,7 +19,7 @@
 #include <cstddef> // std::size_t, std::uintptr_t
 #include <map>     // std::map
 #include <mutex>   // std::mutex
-#include <onnxifi.h>
+#include <onnx/onnxifi.h>
 #include "backend.hpp"
 #include "ngraph/runtime/backend.hpp"
...
@@ -16,7 +16,7 @@
 #pragma once
-#include <onnxifi.h>
+#include <onnx/onnxifi.h>
 namespace ngraph
 {
...
@@ -16,7 +16,7 @@
 #include <cstddef>
 #include <cstdint>
-#include <onnxifi.h>
+#include <onnx/onnxifi.h>
 #include <stdexcept>
 #include "backend_manager.hpp"
...
@@ -17,7 +17,7 @@
 #pragma once
 #include <memory>
-#include <onnxifi.h>
+#include <onnx/onnxifi.h>
 #include "ngraph/runtime/backend.hpp"
 #include "ngraph/runtime/tensor.hpp"
...
@@ -173,6 +173,7 @@ namespace ngraph
 class AvgPoolBackprop : public Op
 {
 public:
+    NGRAPH_API
     static const std::string type_name;
     const std::string& description() const override { return type_name; }
     AvgPoolBackprop() = default;
...
@@ -92,6 +92,7 @@ namespace ngraph
 class BatchNormInference : public Op
 {
 public:
+    NGRAPH_API
     static const std::string type_name;
     const std::string& description() const override { return type_name; }
     BatchNormInference() = default;
...
@@ -24,8 +24,10 @@
 using namespace std;
 using namespace ngraph;
+const string op::BatchMatMul::type_name{"BatchMatMul"};
 op::BatchMatMul::BatchMatMul(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg1)
-    : Op("BatchMatMul", check_single_output_args({arg0, arg1}))
+    : Op(check_single_output_args({arg0, arg1}))
 {
     constructor_validate_and_infer_types();
 }
...
@@ -32,6 +32,9 @@ namespace ngraph
 class BatchMatMul : public Op
 {
 public:
+    NGRAPH_API
+    static const std::string type_name;
+    const std::string& description() const override { return type_name; }
     /// \brief Constructs a batch of matmul product operation.
     ///
     /// \param arg0 The node producing the first argument.
...
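From this point on the diff repeats one mechanical refactoring per op: the class-name string is dropped from the Op base-class constructor call, and each op instead declares an exported static type_name (defined in its .cpp file) that is returned from a description() override. A condensed sketch of the pattern, with a hypothetical MyOp standing in for the real classes (the ngraph helpers shown are the ones used throughout this diff):

// my_op.hpp (illustrative)
class MyOp : public ngraph::op::Op
{
public:
    NGRAPH_API
    static const std::string type_name;
    const std::string& description() const override { return type_name; }

    MyOp(const std::shared_ptr<ngraph::Node>& arg)
        : Op(check_single_output_args({arg})) // name string no longer passed to the base class
    {
        constructor_validate_and_infer_types();
    }
};

// my_op.cpp (illustrative)
const std::string ngraph::op::MyOp::type_name{"MyOp"};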
@@ -23,6 +23,8 @@
 using namespace std;
 using namespace ngraph;
+const string op::CompiledKernel::type_name{"CompiledKernel"};
 shared_ptr<Node> ngraph::op::CompiledKernel::copy_with_new_args(const NodeVector& new_args) const
 {
     auto args = inputs();
@@ -64,7 +66,7 @@ shared_ptr<Node> ngraph::op::CompiledKernel::copy_with_new_args(const NodeVector
 ngraph::op::CompiledKernel::CompiledKernel(const NodeVector& node_list,
     const NodeVector& outputs,
     const NodeVector& args)
-    : Op("CompiledKernel", check_single_output_args({args}))
+    : Op(check_single_output_args({args}))
     , m_node_list(node_list)
     , m_output_nodes(outputs)
 {
...
@@ -32,6 +32,9 @@ namespace ngraph
 class CompiledKernel : public ngraph::op::Op
 {
 public:
+    NGRAPH_API
+    static const std::string type_name;
+    const std::string& description() const override { return type_name; }
     CompiledKernel(const NodeVector& node_list,
         const NodeVector& outputs,
         const NodeVector& args);
...
@@ -20,10 +20,12 @@
 using namespace std;
 using namespace ngraph;
+const string op::DynBroadcast::type_name{"DynBroadcast"};
 op::DynBroadcast::DynBroadcast(const shared_ptr<Node>& arg,
     const shared_ptr<Node>& shape,
     const shared_ptr<Node>& broadcast_axes)
-    : Op("DynBroadcast", check_single_output_args({arg, shape, broadcast_axes}))
+    : Op(check_single_output_args({arg, shape, broadcast_axes}))
 {
     constructor_validate_and_infer_types();
 }
...
@@ -28,6 +28,9 @@ namespace ngraph
 class DynBroadcast : public Op
 {
 public:
+    NGRAPH_API
+    static const std::string type_name;
+    const std::string& description() const override { return type_name; }
     /// \brief Constructs a dynamic broadcast operation.
     ///
     /// \param arg Node that produces the input tensor to be broadcast.
...
@@ -19,12 +19,14 @@
 using namespace std;
 using namespace ngraph;
+const string op::DynPad::type_name{"DynPad"};
 op::DynPad::DynPad(const std::shared_ptr<Node>& arg,
     const std::shared_ptr<Node>& padding_below,
     const std::shared_ptr<Node>& padding_above,
     const std::shared_ptr<Node>& padding_value,
     op::PadMode pad_mode)
-    : Op("DynPad", check_single_output_args({arg, padding_below, padding_above, padding_value}))
+    : Op(check_single_output_args({arg, padding_below, padding_above, padding_value}))
 {
     constructor_validate_and_infer_types();
 }
...
@@ -27,6 +27,9 @@ namespace ngraph
 class DynPad : public Op
 {
 public:
+    NGRAPH_API
+    static const std::string type_name;
+    const std::string& description() const override { return type_name; }
     /// \brief Perform dynamic padding of a tensor
     ///
     /// \param arg The node producing input tensor to be padded.
...
@@ -24,6 +24,8 @@
 using namespace std;
 using namespace ngraph;
+const string op::DynReplaceSlice::type_name{"DynReplaceSlice"};
 op::DynReplaceSlice::DynReplaceSlice(const shared_ptr<Node>& arg,
     const shared_ptr<Node>& replacement,
     const shared_ptr<Node>& lower_bounds,
@@ -34,8 +36,7 @@ op::DynReplaceSlice::DynReplaceSlice(const shared_ptr<Node>& arg,
     const AxisSet& new_axis,
     const AxisSet& shrink_axis,
     const AxisSet& ellipsis_mask)
-    : Op("DynReplaceSlice",
-         check_single_output_args({arg, replacement, lower_bounds, upper_bounds, strides}))
+    : Op(check_single_output_args({arg, replacement, lower_bounds, upper_bounds, strides}))
     , m_lower_bounds_mask(lower_bounds_mask)
     , m_upper_bounds_mask(upper_bounds_mask)
     , m_new_axis(new_axis)
...
@@ -27,6 +27,9 @@ namespace ngraph
 class DynReplaceSlice : public Op
 {
 public:
+    NGRAPH_API
+    static const std::string type_name;
+    const std::string& description() const override { return type_name; }
     /// \brief Constructs a dynamic tensor replace-slice operation.
     ///
     /// \param arg The tensor in which to replace the slice.
...
@@ -24,10 +24,12 @@
 using namespace std;
 using namespace ngraph;
+const string op::DynReshape::type_name{"DynReshape"};
 op::DynReshape::DynReshape(const shared_ptr<Node>& arg,
     const shared_ptr<Node>& pattern,
     bool zero_flag)
-    : Op("DynReshape", check_single_output_args({arg, pattern}))
+    : Op(check_single_output_args({arg, pattern}))
     , m_zero_flag(zero_flag)
 {
     constructor_validate_and_infer_types();
...
@@ -31,6 +31,9 @@ namespace ngraph
 class DynReshape : public Op
 {
 public:
+    NGRAPH_API
+    static const std::string type_name;
+    const std::string& description() const override { return type_name; }
     /// \brief Constructs a dynamic reshape operation. This operation does not perform transpose.
     ///
     /// \param arg The tensor to be reshaped.
...
@@ -24,6 +24,8 @@
 using namespace std;
 using namespace ngraph;
+const string op::DynSlice::type_name{"DynSlice"};
 op::DynSlice::DynSlice(const shared_ptr<Node>& arg,
     const shared_ptr<Node>& lower_bounds,
     const shared_ptr<Node>& upper_bounds,
@@ -33,7 +35,7 @@ op::DynSlice::DynSlice(const shared_ptr<Node>& arg,
     const AxisSet& new_axis,
     const AxisSet& shrink_axis,
     const AxisSet& ellipsis_mask)
-    : Op("DynSlice", check_single_output_args({arg, lower_bounds, upper_bounds, strides}))
+    : Op(check_single_output_args({arg, lower_bounds, upper_bounds, strides}))
     , m_lower_bounds_mask(lower_bounds_mask)
     , m_upper_bounds_mask(upper_bounds_mask)
     , m_new_axis(new_axis)
...
@@ -27,6 +27,9 @@ namespace ngraph
 class DynSlice : public Op
 {
 public:
+    NGRAPH_API
+    static const std::string type_name;
+    const std::string& description() const override { return type_name; }
     /// \brief Constructs a dynamic tensor slice operation.
     ///
     /// \param arg The tensor to be sliced.
...
@@ -19,10 +19,12 @@
 using namespace std;
 using namespace ngraph;
+const string op::CTCGreedyDecoder::type_name{"CTCGreedyDecoder"};
 op::CTCGreedyDecoder::CTCGreedyDecoder(const shared_ptr<Node>& input,
     const std::shared_ptr<Node>& seq_len,
     const bool ctc_merge_repeated)
-    : Op("CTCGreedyDecoder", check_single_output_args({input, seq_len}))
+    : Op(check_single_output_args({input, seq_len}))
     , m_ctc_merge_repeated(ctc_merge_repeated)
 {
     constructor_validate_and_infer_types();
...
@@ -25,6 +25,9 @@ namespace ngraph
 class CTCGreedyDecoder : public Op
 {
 public:
+    NGRAPH_API
+    static const std::string type_name;
+    const std::string& description() const override { return type_name; }
     /// \brief Constructs a CTCGreedyDecoder operation
     ///
     /// \param input Logits on which greedy decoding is performed
...
@@ -21,15 +21,16 @@
 using namespace std;
 using namespace ngraph;
+const string op::DetectionOutput::type_name{"DetectionOutput"};
 op::DetectionOutput::DetectionOutput(const std::shared_ptr<Node>& box_logits,
     const std::shared_ptr<Node>& class_preds,
     const std::shared_ptr<Node>& proposals,
     const std::shared_ptr<Node>& aux_class_preds,
     const std::shared_ptr<Node>& aux_box_preds,
     const DetectionOutputAttrs& attrs)
-    : Op("DetectionOutput",
-         check_single_output_args(
-             {box_logits, class_preds, proposals, aux_class_preds, aux_box_preds}))
+    : Op(check_single_output_args(
+          {box_logits, class_preds, proposals, aux_class_preds, aux_box_preds}))
     , m_attrs(attrs)
 {
     constructor_validate_and_infer_types();
...
@@ -47,6 +47,9 @@ namespace ngraph
 class DetectionOutput : public Op
 {
 public:
+    NGRAPH_API
+    static const std::string type_name;
+    const std::string& description() const override { return type_name; }
     /// \brief Constructs a DetectionOutput operation
     ///
     /// \param box_logits Box logits
...
@@ -21,10 +21,12 @@
 using namespace std;
 using namespace ngraph;
+const string op::Interpolate::type_name{"Interpolate"};
 op::Interpolate::Interpolate(const std::shared_ptr<Node>& image,
     const std::shared_ptr<Node>& output_shape,
     const InterpolateAttrs& attrs)
-    : Op("Interpolate", check_single_output_args({image, output_shape}))
+    : Op(check_single_output_args({image, output_shape}))
     , m_attrs(attrs)
 {
     constructor_validate_and_infer_types();
...
@@ -36,6 +36,9 @@ namespace ngraph
 class Interpolate : public Op
 {
 public:
+    NGRAPH_API
+    static const std::string type_name;
+    const std::string& description() const override { return type_name; }
     /// \brief Constructs a Interpolate operation
     ///
     /// \param image Input image
...
@@ -21,10 +21,12 @@
 using namespace std;
 using namespace ngraph;
+const string op::PriorBox::type_name{"PriorBox"};
 op::PriorBox::PriorBox(const shared_ptr<Node>& layer_shape,
     const shared_ptr<Node>& image_shape,
     const PriorBoxAttrs& attrs)
-    : Op("PriorBox", check_single_output_args({layer_shape, image_shape}))
+    : Op(check_single_output_args({layer_shape, image_shape}))
     , m_attrs(attrs)
 {
     constructor_validate_and_infer_types();
...
@@ -49,6 +49,9 @@ namespace ngraph
 class PriorBox : public Op
 {
 public:
+    NGRAPH_API
+    static const std::string type_name;
+    const std::string& description() const override { return type_name; }
     /// \brief Constructs a PriorBox operation
     ///
     /// \param layer_shape Shape of layer for which prior boxes are computed
...
@@ -21,10 +21,12 @@
 using namespace std;
 using namespace ngraph;
+const string op::PriorBoxClustered::type_name{"PriorBoxClustered"};
 op::PriorBoxClustered::PriorBoxClustered(const shared_ptr<Node>& layer_shape,
     const shared_ptr<Node>& image_shape,
     const PriorBoxClusteredAttrs& attrs)
-    : Op("PriorBoxClustered", check_single_output_args({layer_shape, image_shape}))
+    : Op(check_single_output_args({layer_shape, image_shape}))
     , m_attrs(attrs)
 {
     constructor_validate_and_infer_types();
...
@@ -47,6 +47,9 @@ namespace ngraph
 class PriorBoxClustered : public Op
 {
 public:
+    NGRAPH_API
+    static const std::string type_name;
+    const std::string& description() const override { return type_name; }
     /// \brief Constructs a PriorBoxClustered operation
     ///
     /// \param layer_shape Shape of layer for which prior boxes are computed
...
@@ -21,11 +21,13 @@
 using namespace std;
 using namespace ngraph;
+const string op::Proposal::type_name{"Proposal"};
 op::Proposal::Proposal(const std::shared_ptr<Node>& class_probs,
     const std::shared_ptr<Node>& class_logits,
     const std::shared_ptr<Node>& image_shape,
     const ProposalAttrs& attrs)
-    : Op("Proposal", check_single_output_args({class_probs, class_logits, image_shape}))
+    : Op(check_single_output_args({class_probs, class_logits, image_shape}))
     , m_attrs(attrs)
 {
     constructor_validate_and_infer_types();
...
@@ -57,6 +57,9 @@ namespace ngraph
 class Proposal : public Op
 {
 public:
+    NGRAPH_API
+    static const std::string type_name;
+    const std::string& description() const override { return type_name; }
     /// \brief Constructs a Proposal operation
     ///
     /// \param class_probs Class probability scores
...
@@ -19,6 +19,8 @@
 using namespace std;
 using namespace ngraph;
+const string op::PSROIPooling::type_name{"PSROIPooling"};
 op::PSROIPooling::PSROIPooling(const shared_ptr<Node>& input,
     const std::shared_ptr<Node>& coords,
     const size_t output_dim,
@@ -26,7 +28,7 @@ op::PSROIPooling::PSROIPooling(const shared_ptr<Node>& input,
     const float spatial_scale,
     const Shape& num_bins,
     const std::string& kind)
-    : Op("PSROIPooling", check_single_output_args({input, coords}))
+    : Op(check_single_output_args({input, coords}))
     , m_output_dim(output_dim)
     , m_group_size(group_size)
     , m_spatial_scale(spatial_scale)
...
@@ -25,6 +25,9 @@ namespace ngraph
 class PSROIPooling : public Op
 {
 public:
+    NGRAPH_API
+    static const std::string type_name;
+    const std::string& description() const override { return type_name; }
     /// \brief Constructs a PSROIPooling operation
     ///
     /// \param input Input feature map {N, C, ...}
...
@@ -19,6 +19,8 @@
 using namespace std;
 using namespace ngraph;
+const string op::RegionYolo::type_name{"RegionYolo"};
 op::RegionYolo::RegionYolo(const shared_ptr<Node>& input,
     const size_t num_coords,
     const size_t num_classes,
@@ -27,7 +29,7 @@ op::RegionYolo::RegionYolo(const shared_ptr<Node>& input,
     const vector<int64_t>& mask,
     const int axis,
     const int end_axis)
-    : Op("RegionYolo", check_single_output_args({input}))
+    : Op(check_single_output_args({input}))
     , m_num_coords(num_coords)
     , m_num_classes(num_classes)
     , m_num_regions(num_regions)
...
@@ -25,6 +25,9 @@ namespace ngraph
 class RegionYolo : public Op
 {
 public:
+    NGRAPH_API
+    static const std::string type_name;
+    const std::string& description() const override { return type_name; }
     /// \brief Constructs a RegionYolo operation
     ///
     /// \param input Input
...
@@ -21,8 +21,10 @@
 using namespace std;
 using namespace ngraph;
+const string op::ReorgYolo::type_name{"ReorgYolo"};
 op::ReorgYolo::ReorgYolo(const shared_ptr<Node>& input, const Strides& strides)
-    : Op("ReorgYolo", check_single_output_args({input}))
+    : Op(check_single_output_args({input}))
     , m_strides(strides)
 {
     constructor_validate_and_infer_types();
...
@@ -25,6 +25,9 @@ namespace ngraph
 class ReorgYolo : public Op
 {
 public:
+    NGRAPH_API
+    static const std::string type_name;
+    const std::string& description() const override { return type_name; }
     /// \brief Constructs a ReorgYolo operation
     ///
     /// \param input Input
...
@@ -19,12 +19,14 @@
 using namespace std;
 using namespace ngraph;
+const string op::ROIPooling::type_name{"ROIPooling"};
 op::ROIPooling::ROIPooling(const shared_ptr<Node>& input,
     const std::shared_ptr<Node>& coords,
     const Shape& output_size,
     const float spatial_scale,
     const std::string& kind)
-    : Op("ROIPooling", check_single_output_args({input, coords}))
+    : Op(check_single_output_args({input, coords}))
     , m_output_size(output_size)
     , m_spatial_scale(spatial_scale)
     , m_kind(kind)
...
@@ -25,6 +25,9 @@ namespace ngraph
 class ROIPooling : public Op
 {
 public:
+    NGRAPH_API
+    static const std::string type_name;
+    const std::string& description() const override { return type_name; }
     /// \brief Constructs a ROIPooling operation
     ///
     /// \param input Input feature map {N, C, ...}
...
@@ -23,8 +23,10 @@
 using namespace std;
 using namespace ngraph;
+const string op::QuantizedConcat::type_name{"QuantizedConcat"};
 op::QuantizedConcat::QuantizedConcat(const NodeVector& args, size_t concatenation_axis)
-    : Op("QuantizedConcat", check_single_output_args(args))
+    : Op(check_single_output_args(args))
     , m_concatenation_axis(concatenation_axis)
 {
     constructor_validate_and_infer_types();
...
@@ -28,6 +28,9 @@ namespace ngraph
 class QuantizedConcat : public Op
 {
 public:
+    NGRAPH_API
+    static const std::string type_name;
+    const std::string& description() const override { return type_name; }
     /// \brief Constructs a concatenation operation.
     ///
     /// \param args The nodes producing the input tensors.
...
@@ -20,8 +20,10 @@
 using namespace std;
 using namespace ngraph;
+const string op::ShapeOf::type_name{"ShapeOf"};
 op::ShapeOf::ShapeOf(const shared_ptr<Node>& arg)
-    : Op("ShapeOf", check_single_output_args({arg}))
+    : Op(check_single_output_args({arg}))
 {
     constructor_validate_and_infer_types();
 }
...
@@ -26,6 +26,9 @@ namespace ngraph
 class ShapeOf : public Op
 {
 public:
+    NGRAPH_API
+    static const std::string type_name;
+    const std::string& description() const override { return type_name; }
     /// \brief Constructs a shape-of operation.
     ShapeOf(const std::shared_ptr<Node>& arg);
...
@@ -21,8 +21,10 @@
 using namespace std;
 using namespace ngraph;
+const string op::Tile::type_name{"Tile"};
 op::Tile::Tile(const std::shared_ptr<Node>& arg, const std::shared_ptr<Node>& repeats)
-    : Op("Tile", check_single_output_args({arg, repeats}))
+    : Op(check_single_output_args({arg, repeats}))
 {
     constructor_validate_and_infer_types();
 }
...
@@ -27,6 +27,9 @@ namespace ngraph
 class Tile : public Op
 {
 public:
+    NGRAPH_API
+    static const std::string type_name;
+    const std::string& description() const override { return type_name; }
     /// \brief Perform dynamic padding of a tensor
     ///
     /// \param arg The node producing input tensor to be padded.
...
@@ -22,8 +22,10 @@
 using namespace std;
 using namespace ngraph;
+const string op::Transpose::type_name{"Transpose"};
 op::Transpose::Transpose(const shared_ptr<Node>& arg, const shared_ptr<Node>& input_order)
-    : Op("Transpose", check_single_output_args({arg, input_order}))
+    : Op(check_single_output_args({arg, input_order}))
 {
     constructor_validate_and_infer_types();
 }
...
@@ -28,6 +28,9 @@ namespace ngraph
 class Transpose : public Op
 {
 public:
+    NGRAPH_API
+    static const std::string type_name;
+    const std::string& description() const override { return type_name; }
     /// \brief Constructs a transpose operation.
     ///
     /// \param arg Node producing the tensor to be transposed.
...
@@ -21,12 +21,14 @@
 using namespace std;
 using namespace ngraph;
+const string op::Pad::type_name{"Pad"};
 op::Pad::Pad(const shared_ptr<Node>& arg,
     const shared_ptr<Node>& arg_pad_value,
     const CoordinateDiff& padding_below,
     const CoordinateDiff& padding_above,
     PadMode pad_mode)
-    : Op("Pad", check_single_output_args({arg, arg_pad_value}))
+    : Op(check_single_output_args({arg, arg_pad_value}))
     , m_padding_below(padding_below)
     , m_padding_above(padding_above)
     , m_padding_interior_fake(padding_below.size())
...
@@ -28,6 +28,9 @@ namespace ngraph
 class Pad : public Op
 {
 public:
+    NGRAPH_API
+    static const std::string type_name;
+    const std::string& description() const override { return type_name; }
     /// \brief Constructs a generic padding operation.
     ///
     /// \param arg The node producing input tensor to be padded.
...
@@ -21,10 +21,12 @@
 using namespace std;
 using namespace ngraph;
+const string op::Parameter::type_name{"Parameter"};
 op::Parameter::Parameter(const element::Type& element_type,
     const PartialShape& pshape,
     const bool cacheable)
-    : Op("Parameter", {})
+    : Op(NodeVector{})
     , m_cacheable(cacheable)
     , m_partial_shape(pshape)
     , m_element_type(element_type)
...
@@ -35,6 +35,9 @@ namespace ngraph
     const NodeVector& deltas) override;
 public:
+    NGRAPH_API
+    static const std::string type_name;
+    const std::string& description() const override { return type_name; }
     /// \brief Constructions a tensor-typed parameter node.
     ///
     /// \param element_type The element type of the parameter.
...
@@ -18,12 +18,17 @@
 #include "ngraph/op/passthrough.hpp"
+using namespace std;
+using namespace ngraph;
+const string op::Passthrough::type_name{"Passthrough"};
 ngraph::op::Passthrough::Passthrough(const std::string& logical_type,
     const std::string& language,
     const std::string& function,
     const NodeVector& args,
     std::vector<std::tuple<element::Type, PartialShape>> outputs)
-    : Op{"Passthrough", args}
+    : Op{args}
     , m_logical_type{logical_type}
     , m_language{language}
     , m_function{function}
@@ -65,5 +70,5 @@ std::shared_ptr<ngraph::Node>
         "Passthrough node input counts cannot be changed for a given Passthrough function"};
 }
 return std::make_shared<Passthrough>(
-    description(), m_language, m_function, new_args, m_output_shapes);
+    m_logical_type, m_language, m_function, new_args, m_output_shapes);
 }
@@ -38,6 +38,9 @@ namespace ngraph
 class ngraph::op::Passthrough final : public Op
 {
 public:
+    NGRAPH_API
+    static const std::string type_name;
+    const std::string& description() const override { return type_name; }
     Passthrough(const std::string& logical_type, // aka "What this operation is doing"
         const std::string& language, // The language the implementation is written in
         const std::string& function, // The operation implementation
...
@@ -22,10 +22,12 @@
 using namespace std;
 using namespace ngraph;
+const string op::Power::type_name{"Power"};
 op::Power::Power(const shared_ptr<Node>& arg0,
     const shared_ptr<Node>& arg1,
     const AutoBroadcastSpec& autob)
-    : BinaryElementwiseArithmetic("Power", arg0, arg1, autob)
+    : BinaryElementwiseArithmetic(arg0, arg1, autob)
 {
     constructor_validate_and_infer_types();
 }
...
@@ -39,6 +39,9 @@ namespace ngraph
 class Power : public util::BinaryElementwiseArithmetic
 {
 public:
+    NGRAPH_API
+    static const std::string type_name;
+    const std::string& description() const override { return type_name; }
     /// \brief Constructs an exponentiation operation.
     ///
     /// \param arg0 Node that produces the first input tensor.
...
@@ -20,8 +20,11 @@
 using namespace std;
 using namespace ngraph;
+const string op::Relu::type_name{"Relu"};
+const string op::ReluBackprop::type_name{"ReluBackprop"};
 op::Relu::Relu(shared_ptr<Node> arg)
-    : UnaryElementwiseArithmetic("Relu", {arg})
+    : UnaryElementwiseArithmetic(arg)
 {
     constructor_validate_and_infer_types();
 }
@@ -33,7 +36,7 @@ shared_ptr<Node> op::Relu::copy_with_new_args(const NodeVector& new_args) const
 }
 op::ReluBackprop::ReluBackprop(shared_ptr<Node> arg, shared_ptr<Node> delta)
-    : BinaryElementwiseArithmetic("ReluBackprop", arg, delta)
+    : BinaryElementwiseArithmetic(arg, delta)
 {
     constructor_validate_and_infer_types();
 }
...
@@ -33,6 +33,9 @@ namespace ngraph
 class Relu : public ngraph::op::util::UnaryElementwiseArithmetic
 {
 public:
+    NGRAPH_API
+    static const std::string type_name;
+    const std::string& description() const override { return type_name; }
     /// \brief Constructs a Relu operation.
     ///
     /// \param arg Node that produces the input tensor.
@@ -50,6 +53,9 @@ namespace ngraph
 class ReluBackprop : public ngraph::op::util::BinaryElementwiseArithmetic
 {
 public:
+    NGRAPH_API
+    static const std::string type_name;
+    const std::string& description() const override { return type_name; }
     /// \brief Constructs a ReluBackprop operation.
     ///
     /// \param arg Node that produces the relu forward input tensor.
...
@@ -23,8 +23,10 @@
 using namespace std;
 using namespace ngraph;
+const string op::Reverse::type_name{"Reverse"};
 op::Reverse::Reverse(const shared_ptr<Node>& arg, const AxisSet& reversed_axes)
-    : Op("Reverse", check_single_output_args({arg}))
+    : Op(check_single_output_args({arg}))
     , m_reversed_axes(reversed_axes)
 {
     constructor_validate_and_infer_types();
...
@@ -46,6 +46,9 @@ namespace ngraph
 class Reverse : public Op
 {
 public:
+    NGRAPH_API
+    static const std::string type_name;
+    const std::string& description() const override { return type_name; }
     /// \brief Constructs a reverse operation.
     ///
     /// \param arg The input tensor, some of whose axes are to be reversed.
...
@@ -25,11 +25,13 @@
 using namespace std;
 using namespace ngraph;
+const string op::ReverseSequence::type_name{"ReverseSequence"};
 op::ReverseSequence::ReverseSequence(const std::shared_ptr<Node> arg,
     const std::shared_ptr<Node> seq_indices,
     size_t batch_axis,
     size_t seq_axis)
-    : Op("ReverseSequence", check_single_output_args({arg, seq_indices}))
+    : Op(check_single_output_args({arg, seq_indices}))
     , m_batch_axis(batch_axis)
     , m_seq_axis(seq_axis)
 {
...
@@ -25,6 +25,9 @@ namespace ngraph
 class ReverseSequence : public Op
 {
 public:
+    NGRAPH_API
+    static const std::string type_name;
+    const std::string& description() const override { return type_name; }
     /// \brief Constructs an arcsin operation.
     ///
     /// \param arg Node that produces the input tensor.
...
@@ -24,6 +24,8 @@ static int INPUTS = 0;
 static int INDICES = 1;
 static int UPDATES = 2;
+const string op::ScatterAdd::type_name{"ScatterAdd"};
 shared_ptr<Node> op::ScatterAdd::copy_with_new_args(const NodeVector& new_args) const
 {
     check_new_args_count(this, new_args);
...
@@ -26,13 +26,16 @@ namespace ngraph
 class ScatterAdd : public Op
 {
 public:
+    NGRAPH_API
+    static const std::string type_name;
+    const std::string& description() const override { return type_name; }
     /// \param inputs Tensor
     /// \param indices Index tensor: Data type must be `element::i32` or `element::i64`
     /// \param updates Tensor: Must have same type as inputs
     ScatterAdd(const std::shared_ptr<Node>& inputs,
         const std::shared_ptr<Node>& indices,
         const std::shared_ptr<Node>& updates)
-        : Op("ScatterAdd", check_single_output_args({inputs, indices, updates}))
+        : Op(check_single_output_args({inputs, indices, updates}))
     {
         constructor_validate_and_infer_types();
     }
...
...@@ -24,6 +24,8 @@ static int INPUTS = 0; ...@@ -24,6 +24,8 @@ static int INPUTS = 0;
static int INDICES = 1; static int INDICES = 1;
static int UPDATES = 2; static int UPDATES = 2;
const string op::ScatterNDAdd::type_name{"ScatterNDAdd"};
shared_ptr<Node> op::ScatterNDAdd::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::ScatterNDAdd::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
......
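Both scatter hunks also touch `copy_with_new_args`, but the diff only shows the opening `check_new_args_count(this, new_args);` call; the rest of each body is collapsed. Assuming it follows the usual nGraph cloning pattern for a three-input op, the omitted part would look roughly like this (the `make_shared` line is an assumption, not visible in the diff):

    shared_ptr<Node> op::ScatterAdd::copy_with_new_args(const NodeVector& new_args) const
    {
        // Fail fast if the caller did not supply exactly inputs, indices, updates.
        check_new_args_count(this, new_args);
        // Re-create the node over the replacement arguments in the same order.
        return make_shared<ScatterAdd>(new_args.at(0), new_args.at(1), new_args.at(2));
    }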
...@@ -26,13 +26,16 @@ namespace ngraph ...@@ -26,13 +26,16 @@ namespace ngraph
class ScatterNDAdd : public Op class ScatterNDAdd : public Op
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \param inputs Tensor /// \param inputs Tensor
/// \param indices Index tensor: Data type must be `element::i32` or `element::i64` /// \param indices Index tensor: Data type must be `element::i32` or `element::i64`
/// \param updates Tensor: Must have same type as inputs /// \param updates Tensor: Must have same type as inputs
ScatterNDAdd(const std::shared_ptr<Node>& inputs, ScatterNDAdd(const std::shared_ptr<Node>& inputs,
const std::shared_ptr<Node>& indices, const std::shared_ptr<Node>& indices,
const std::shared_ptr<Node>& updates) const std::shared_ptr<Node>& updates)
: Op("ScatterNDAdd", check_single_output_args({inputs, indices, updates})) : Op(check_single_output_args({inputs, indices, updates}))
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
......
...@@ -21,6 +21,9 @@ ...@@ -21,6 +21,9 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
const string op::Sigmoid::type_name{"Sigmoid"};
const string op::SigmoidBackprop::type_name{"SigmoidBackprop"};
shared_ptr<Node> op::Sigmoid::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Sigmoid::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
...@@ -28,13 +31,13 @@ shared_ptr<Node> op::Sigmoid::copy_with_new_args(const NodeVector& new_args) con ...@@ -28,13 +31,13 @@ shared_ptr<Node> op::Sigmoid::copy_with_new_args(const NodeVector& new_args) con
} }
op::Sigmoid::Sigmoid(shared_ptr<Node> arg) op::Sigmoid::Sigmoid(shared_ptr<Node> arg)
: UnaryElementwiseArithmetic("Sigmoid", arg) : UnaryElementwiseArithmetic(arg)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
op::SigmoidBackprop::SigmoidBackprop(shared_ptr<Node> arg, shared_ptr<Node> delta) op::SigmoidBackprop::SigmoidBackprop(shared_ptr<Node> arg, shared_ptr<Node> delta)
: BinaryElementwiseArithmetic("SigmoidBackprop", arg, delta) : BinaryElementwiseArithmetic(arg, delta)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
......
...@@ -28,6 +28,9 @@ namespace ngraph ...@@ -28,6 +28,9 @@ namespace ngraph
class Sigmoid : public util::UnaryElementwiseArithmetic class Sigmoid : public util::UnaryElementwiseArithmetic
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
Sigmoid(std::shared_ptr<Node> arg); Sigmoid(std::shared_ptr<Node> arg);
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
...@@ -40,6 +43,9 @@ namespace ngraph ...@@ -40,6 +43,9 @@ namespace ngraph
class SigmoidBackprop : public util::BinaryElementwiseArithmetic class SigmoidBackprop : public util::BinaryElementwiseArithmetic
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a SigmoidBackprop operation. /// \brief Constructs a SigmoidBackprop operation.
/// ///
/// \param arg Node that produces the Sigmoid forward input tensor. /// \param arg Node that produces the Sigmoid forward input tensor.
......
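With the name string gone from the `UnaryElementwiseArithmetic` and `BinaryElementwiseArithmetic` constructor calls, an op's runtime identity comes entirely from `type_name` via `description()`. A hedged usage sketch (the umbrella header, `op::Parameter`, `element::f32`, and `Shape` are standard nGraph API but are assumptions here, not part of this diff):

    #include "ngraph/ngraph.hpp" // assumed umbrella header
    #include <iostream>
    #include <memory>

    using namespace ngraph;

    int main()
    {
        auto x = std::make_shared<op::Parameter>(element::f32, Shape{2, 3});
        auto sig = std::make_shared<op::Sigmoid>(x);
        // description() returns the statically defined type_name, here "Sigmoid".
        std::cout << sig->description() << std::endl;
        return 0;
    }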
...@@ -19,8 +19,10 @@ ...@@ -19,8 +19,10 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
const string op::Sign::type_name{"Sign"};
op::Sign::Sign(const shared_ptr<Node>& arg) op::Sign::Sign(const shared_ptr<Node>& arg)
: UnaryElementwiseArithmetic("Sign", arg) : UnaryElementwiseArithmetic(arg)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
......
...@@ -27,6 +27,9 @@ namespace ngraph ...@@ -27,6 +27,9 @@ namespace ngraph
class Sign : public util::UnaryElementwiseArithmetic class Sign : public util::UnaryElementwiseArithmetic
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs an elementwise sign operation. /// \brief Constructs an elementwise sign operation.
/// ///
/// \param arg Node that produces the input tensor. /// \param arg Node that produces the input tensor.
......
...@@ -21,8 +21,10 @@ ...@@ -21,8 +21,10 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
const string op::Sin::type_name{"Sin"};
op::Sin::Sin(const shared_ptr<Node>& arg) op::Sin::Sin(const shared_ptr<Node>& arg)
: UnaryElementwiseArithmetic("Sin", arg) : UnaryElementwiseArithmetic(arg)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
......
...@@ -38,6 +38,9 @@ namespace ngraph ...@@ -38,6 +38,9 @@ namespace ngraph
class Sin : public util::UnaryElementwiseArithmetic class Sin : public util::UnaryElementwiseArithmetic
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a sine operation. /// \brief Constructs a sine operation.
/// ///
/// \param arg Node that produces the input tensor. /// \param arg Node that produces the input tensor.
......
...@@ -21,8 +21,10 @@ ...@@ -21,8 +21,10 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
const string op::Sinh::type_name{"Sinh"};
op::Sinh::Sinh(const shared_ptr<Node>& arg) op::Sinh::Sinh(const shared_ptr<Node>& arg)
: UnaryElementwiseArithmetic("Sinh", arg) : UnaryElementwiseArithmetic(arg)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
......
...@@ -26,6 +26,9 @@ namespace ngraph ...@@ -26,6 +26,9 @@ namespace ngraph
class Sinh : public util::UnaryElementwiseArithmetic class Sinh : public util::UnaryElementwiseArithmetic
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a hyperbolic sine operation. /// \brief Constructs a hyperbolic sine operation.
/// ///
/// \param arg Node that produces the input tensor. /// \param arg Node that produces the input tensor.
......
...@@ -29,8 +29,10 @@ ...@@ -29,8 +29,10 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
const string op::Softmax::type_name{"Softmax"};
op::Softmax::Softmax(const shared_ptr<Node>& arg, const AxisSet& axes) op::Softmax::Softmax(const shared_ptr<Node>& arg, const AxisSet& axes)
: UnaryElementwiseArithmetic("Softmax", arg) : UnaryElementwiseArithmetic(arg)
, m_axes(axes) , m_axes(axes)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
......
...@@ -27,6 +27,9 @@ namespace ngraph ...@@ -27,6 +27,9 @@ namespace ngraph
class Softmax : public util::UnaryElementwiseArithmetic class Softmax : public util::UnaryElementwiseArithmetic
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a softmax operation. /// \brief Constructs a softmax operation.
/// ///
/// \param arg Node that produces the first input tensor.<br> /// \param arg Node that produces the first input tensor.<br>
......
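Softmax is the one op in this group whose constructor carries extra state (the reduction axes) in addition to the argument node; removing the name string leaves that parameter untouched. A brief illustrative construction, reusing the assumed `Parameter` and `Shape` names from the sketch above:

    // Softmax over axis 1 of a batch of 8 ten-class score vectors.
    auto logits = std::make_shared<op::Parameter>(element::f32, Shape{8, 10});
    auto probs = std::make_shared<op::Softmax>(logits, AxisSet{1});
    // probs->description() == op::Softmax::type_name, i.e. "Softmax".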
...@@ -21,8 +21,10 @@ ...@@ -21,8 +21,10 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
const string op::Sqrt::type_name{"Sqrt"};
op::Sqrt::Sqrt(const shared_ptr<Node>& arg) op::Sqrt::Sqrt(const shared_ptr<Node>& arg)
: UnaryElementwiseArithmetic("Sqrt", arg) : UnaryElementwiseArithmetic(arg)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
......
...@@ -38,6 +38,9 @@ namespace ngraph ...@@ -38,6 +38,9 @@ namespace ngraph
class Sqrt : public util::UnaryElementwiseArithmetic class Sqrt : public util::UnaryElementwiseArithmetic
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a square root operation. /// \brief Constructs a square root operation.
/// ///
/// \param arg Node that produces the input tensor. /// \param arg Node that produces the input tensor.
......
...@@ -21,8 +21,10 @@ ...@@ -21,8 +21,10 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
const string op::StopGradient::type_name{"StopGradient"};
op::StopGradient::StopGradient(const shared_ptr<Node>& arg) op::StopGradient::StopGradient(const shared_ptr<Node>& arg)
: UnaryElementwiseArithmetic("StopGradient", arg) : UnaryElementwiseArithmetic(arg)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
......