Commit de132138 authored by Adam Procter

Merge remote-tracking branch 'origin/master' into aprocter/cf-element-types

parents b46fa4c7 23f838e5
@@ -21,7 +21,7 @@ include(ExternalProject)
# ONNX.proto definition version
#------------------------------------------------------------------------------
-set(ONNX_VERSION 1.3.0)
+set(ONNX_VERSION 1.5.0)
#------------------------------------------------------------------------------
# Download and install libonnx ...
@@ -30,6 +30,9 @@ set(ONNX_VERSION 1.3.0)
set(ONNX_GIT_REPO_URL https://github.com/onnx/onnx.git)
set(ONNX_GIT_BRANCH rel-${ONNX_VERSION})
+add_definitions(-DONNX_BUILD_SHARED_LIBS=ON)
+add_definitions(-DONNX_ML=ON)
ExternalProject_Add(
    ext_onnx
    PREFIX onnx
@@ -58,8 +61,8 @@ ExternalProject_Add(
ExternalProject_Get_Property(ext_onnx SOURCE_DIR BINARY_DIR)
-set(ONNX_INCLUDE_DIR ${SOURCE_DIR}/onnx)
-set(ONNX_PROTO_INCLUDE_DIR ${BINARY_DIR}/onnx)
+set(ONNX_INCLUDE_DIR ${SOURCE_DIR})
+set(ONNX_PROTO_INCLUDE_DIR ${BINARY_DIR})
if (WIN32)
    set(ONNX_LIBRARY ${BINARY_DIR}/${CMAKE_BUILD_TYPE}/onnx.lib)
    set(ONNX_PROTO_LIBRARY ${BINARY_DIR}/${CMAKE_BUILD_TYPE}/onnx_proto.lib)
...
@@ -37,6 +37,7 @@ ngraph.ops
   equal
   exp
   floor
+  gelu
   get_output_element
   greater
   greater_eq
...
@@ -50,6 +50,7 @@ from ngraph.ops import elu
from ngraph.ops import equal
from ngraph.ops import exp
from ngraph.ops import floor
+from ngraph.ops import gelu
from ngraph.ops import get_output_element
from ngraph.ops import greater
from ngraph.ops import greater_eq
...
@@ -74,6 +74,7 @@ from _pyngraph.op import Elu
from _pyngraph.op import Equal
from _pyngraph.op import Exp
from _pyngraph.op import Floor
+from _pyngraph.op import Gelu
from _pyngraph.op import GetOutputElement
from _pyngraph.op import Greater
from _pyngraph.op import GreaterEq
...
@@ -23,7 +23,7 @@ from ngraph.impl import AxisSet, AxisVector, Coordinate, CoordinateDiff, Functio
from ngraph.impl.op import Abs, Acos, Add, And, Asin, ArgMax, ArgMin, Atan, AvgPool, \
    BatchNormTraining, BatchNormInference, Broadcast, Ceiling, Clamp, Concat, Constant, Convert, \
    Convolution, ConvolutionBackpropData, Cos, Cosh, Divide, Dot, Elu, Equal, Exp, Floor, \
-   GetOutputElement, Greater, GreaterEq, Less, LessEq, Log, LRN, Max, Maximum, MaxPool, \
+   Gelu, GetOutputElement, Greater, GreaterEq, Less, LessEq, Log, LRN, Max, Maximum, MaxPool, \
    Min, Minimum, Multiply, Negative, Not, NotEqual, OneHot, Or, Pad, Parameter, Product, \
    Power, Relu, ReplaceSlice, Reshape, Reverse, Select, Sign, Sin, Sinh, Slice, Softmax, \
    Sqrt, Subtract, Sum, Tan, Tanh, TopK
@@ -527,6 +527,24 @@ def convert(node, new_type, name=None):  # type: (Node, NumericType, str) -> Nod
    return Convert(node, new_element_type)
@nameable_op
def gelu(node, name=None):  # type: (NodeInput, str) -> Node
    r"""Perform Gaussian Error Linear Unit operation element-wise on data from input node.

    Computes GELU function:

    .. math:: f(x) = 0.5\cdot x\cdot(1 + erf(\dfrac{x}{\sqrt{2}}))

    For more information refer to:
    `Gaussian Error Linear Unit (GELU) <https://arxiv.org/pdf/1606.08415.pdf>`_

    :param node: Input tensor. One of: input node, array or scalar.
    :param name: Optional output node name.
    :return: The new node performing a GELU operation on its input data element-wise.
    """
    return Gelu(as_node(node))
@nameable_op
def select(selection_node, input_node1, input_node2, name=None):
    # type: (Node, Node, Node, str) -> Node
...
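The docstring above gives the GELU formula; as an illustration only (not part of this commit), the same expression can be written as a minimal NumPy reference, which also matches the expected values used in the new Python tests further down. The helper name gelu_reference and the numpy/math imports are assumptions for this sketch.

import math

import numpy as np


def gelu_reference(x):
    """Reference GELU: 0.5 * x * (1 + erf(x / sqrt(2))), element-wise."""
    x = np.asarray(x, dtype=np.float32)
    erf = np.vectorize(math.erf)  # apply the scalar erf element-wise
    return 0.5 * x * (1.0 + erf(x / np.sqrt(2.0)))


print(gelu_reference([[-5, 1], [-2, 3]]))
# approx. [[-1.5e-06, 0.8413], [-0.0455, 2.9960]]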
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "ngraph/op/fused/gelu.hpp"
#include "pyngraph/ops/fused/gelu.hpp"
namespace py = pybind11;
void regclass_pyngraph_op_Gelu(py::module m)
{
py::class_<ngraph::op::Gelu, std::shared_ptr<ngraph::op::Gelu>, ngraph::op::Op> gelu(m, "Gelu");
gelu.doc() = "ngraph.impl.op.Gelu wraps ngraph::op::Gelu";
gelu.def(py::init<const std::shared_ptr<ngraph::Node>&>());
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
void regclass_pyngraph_op_Gelu(py::module m);
@@ -54,6 +54,7 @@ void regmodule_pyngraph_op(py::module m_op)
    regclass_pyngraph_op_Equal(m_op);
    regclass_pyngraph_op_Exp(m_op);
    regclass_pyngraph_op_Floor(m_op);
+   regclass_pyngraph_op_Gelu(m_op);
    regclass_pyngraph_op_GetOutputElement(m_op);
    regclass_pyngraph_op_Greater(m_op);
    regclass_pyngraph_op_GreaterEq(m_op);
...
@@ -44,6 +44,7 @@
#include "pyngraph/ops/floor.hpp"
#include "pyngraph/ops/fused/clamp.hpp"
#include "pyngraph/ops/fused/elu.hpp"
+#include "pyngraph/ops/fused/gelu.hpp"
#include "pyngraph/ops/get_output_element.hpp"
#include "pyngraph/ops/greater.hpp"
#include "pyngraph/ops/greater_eq.hpp"
...
@@ -184,6 +184,7 @@ sources = [
    'pyngraph/ops/equal.cpp',
    'pyngraph/ops/exp.cpp',
    'pyngraph/ops/floor.cpp',
+   'pyngraph/ops/fused/gelu.cpp',
    'pyngraph/ops/greater.cpp',
    'pyngraph/ops/greater_eq.cpp',
    'pyngraph/ops/less.cpp',
...
@@ -19,7 +19,7 @@ import ngraph as ng
from test.ngraph.util import get_runtime
-def test_elu_operator():
+def test_elu_operator_with_parameters():
    runtime = get_runtime()
    data_shape = [2, 2]
@@ -69,6 +69,38 @@ def test_elu_operator_with_scalar():
    assert np.allclose(result, expected)
def test_gelu_operator_with_parameters():
    runtime = get_runtime()

    data_value = np.array([[-5, 1], [-2, 3]], dtype=np.float32)
    data_shape = [2, 2]
    parameter_data = ng.parameter(data_shape, name='Data', dtype=np.float32)

    model = ng.gelu(parameter_data)
    computation = runtime.computation(model, parameter_data)

    result = computation(data_value)
    expected = np.array([[-1.4901161e-06, 8.4134471e-01], [-4.5500278e-02, 2.9959502]],
                        dtype=np.float32)
    assert np.allclose(result, expected)


def test_gelu_operator_with_array():
    runtime = get_runtime()

    data_value = np.array([[-5, 1], [-2, 3]], dtype=np.float32)

    model = ng.gelu(data_value)
    computation = runtime.computation(model)

    result = computation()
    expected = np.array([[-1.4901161e-06, 8.4134471e-01], [-4.5500278e-02, 2.9959502]],
                        dtype=np.float32)
    assert np.allclose(result, expected)
def test_clamp_operator():
    runtime = get_runtime()
@@ -99,4 +131,5 @@ def test_clamp_operator_with_array():
    result = computation()
    expected = np.clip(data_value, min_value, max_value)
    assert np.allclose(result, expected)
@@ -16,7 +16,7 @@
#pragma once
-#include <onnx-ml.pb.h>
+#include <onnx/onnx_pb.h>
#include "ngraph/except.hpp"
#include "tensor.hpp"
...
@@ -16,7 +16,7 @@
#pragma once
-#include <onnx-ml.pb.h>
+#include <onnx/onnx_pb.h>
#include <string>
#include <vector>
...
@@ -14,7 +14,7 @@
// limitations under the License.
//*****************************************************************************
-#include <onnx-ml.pb.h>
+#include <onnx/onnx_pb.h>
#include "model.hpp"
#include "ngraph/log.hpp"
...
@@ -16,7 +16,7 @@
#pragma once
-#include <onnx-ml.pb.h>
+#include <onnx/onnx_pb.h>
#include <ostream>
#include <string>
#include <unordered_map>
...
@@ -14,7 +14,7 @@
// limitations under the License.
//*****************************************************************************
-#include <onnx-ml.pb.h>
+#include <onnx/onnx_pb.h>
#include "attribute.hpp"
#include "graph.hpp"
...
@@ -16,7 +16,7 @@
#pragma once
-#include <onnx-ml.pb.h>
+#include <onnx/onnx_pb.h>
#include <utility>
#include <vector>
...
@@ -16,7 +16,7 @@
#pragma once
-#include <onnx-ml.pb.h>
+#include <onnx/onnx_pb.h>
#include "ngraph/op/constant.hpp"
#include "ngraph/op/parameter.hpp"
...
@@ -13,7 +13,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
-#include <onnx-ml.pb.h> // onnx types
+#include <onnx/onnx_pb.h> // onnx types
#include "common.hpp"
...
@@ -15,7 +15,7 @@
//*****************************************************************************
#include <cstdlib> // std::size_t, std::uintptr_t
-#include <onnxifi.h>
+#include <onnx/onnxifi.h>
#include <stdexcept> // std::invalid_argument, std::out_of_range
#include "backend.hpp"
...
@@ -19,7 +19,7 @@
#include <cstddef> // std::size_t, std::uintptr_t
#include <map> // std::map
#include <mutex> // std::mutex
-#include <onnxifi.h>
+#include <onnx/onnxifi.h>
#include "backend.hpp"
#include "ngraph/runtime/backend.hpp"
...
@@ -16,7 +16,7 @@
#pragma once
-#include <onnxifi.h>
+#include <onnx/onnxifi.h>
namespace ngraph
{
...
@@ -16,7 +16,7 @@
#include <cstddef>
#include <cstdint>
-#include <onnxifi.h>
+#include <onnx/onnxifi.h>
#include <stdexcept>
#include "backend_manager.hpp"
...
@@ -17,7 +17,7 @@
#pragma once
#include <memory>
-#include <onnxifi.h>
+#include <onnx/onnxifi.h>
#include "ngraph/runtime/backend.hpp"
#include "ngraph/runtime/tensor.hpp"
...
@@ -42,8 +42,8 @@ public:
    CONVERT,
    SHAPE_OF,
    REVERSE,
-   PRODUCT,
-   SUM,
+   ARITHMETIC_REDUCTION,
+   LOGICAL_REDUCTION,
    CONCAT,
    GATHER,
    SLICE,
@@ -68,8 +68,8 @@ public:
    construct_constant_convert();
    construct_constant_shape_of();
    construct_constant_reverse();
-   construct_constant_product();
-   construct_constant_sum();
+   construct_constant_arithmetic_reduction();
+   construct_constant_logical_reduction();
    construct_constant_concat();
    construct_constant_gather();
    construct_constant_slice();
@@ -101,8 +101,12 @@ public:
    case CFTransformations::CONVERT: construct_constant_convert(); break;
    case CFTransformations::SHAPE_OF: construct_constant_shape_of(); break;
    case CFTransformations::REVERSE: construct_constant_reverse(); break;
-   case CFTransformations::PRODUCT: construct_constant_product(); break;
-   case CFTransformations::SUM: construct_constant_sum(); break;
+   case CFTransformations::ARITHMETIC_REDUCTION:
+       construct_constant_arithmetic_reduction();
+       break;
+   case CFTransformations::LOGICAL_REDUCTION:
+       construct_constant_logical_reduction();
+       break;
    case CFTransformations::CONCAT: construct_constant_concat(); break;
    case CFTransformations::GATHER: construct_constant_gather(); break;
    case CFTransformations::SLICE: construct_constant_slice(); break;
@@ -126,8 +130,8 @@ private:
    void construct_constant_convert();
    void construct_constant_shape_of();
    void construct_constant_reverse();
-   void construct_constant_product();
-   void construct_constant_sum();
+   void construct_constant_arithmetic_reduction();
+   void construct_constant_logical_reduction();
    void construct_constant_concat();
    void construct_constant_gather();
    void construct_constant_slice();
...
@@ -17,7 +17,6 @@
#pragma once
#include "ngraph/op/op.hpp"
-#include "ngraph/runtime/cpu/cpu_backend_visibility.h"
namespace ngraph
{
@@ -35,7 +34,6 @@ namespace ngraph
class BatchMatMulTranspose : public Op
{
public:
-   CPU_BACKEND_API
    static const std::string type_name;
    const std::string& description() const override { return type_name; }
    /// \brief Constructs a batch of matmul product operation.
...
@@ -30,7 +30,6 @@ namespace ngraph
class BatchNormTrainingRelu : public Op
{
public:
-   CPU_BACKEND_API
    static const std::string type_name;
    const std::string& description() const override { return type_name; }
    CPU_BACKEND_API BatchNormTrainingRelu(double eps,
@@ -60,7 +59,6 @@ namespace ngraph
class BatchNormInferenceRelu : public Op
{
public:
-   CPU_BACKEND_API
    static const std::string type_name;
    const std::string& description() const override { return type_name; }
    BatchNormInferenceRelu(double eps,
...
@@ -19,7 +19,6 @@
#include "ngraph/node.hpp"
#include "ngraph/op/op.hpp"
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
-#include "ngraph/runtime/cpu/cpu_backend_visibility.h"
namespace ngraph
{
@@ -30,7 +29,6 @@ namespace ngraph
class BoundedRelu : public ngraph::op::util::UnaryElementwiseArithmetic
{
public:
-   CPU_BACKEND_API
    static const std::string type_name;
    const std::string& description() const override { return type_name; }
    /// \brief Constructs a BoundedRelu operation.
...
@@ -18,7 +18,6 @@
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/op.hpp"
-#include "ngraph/runtime/cpu/cpu_backend_visibility.h"
namespace ngraph
{
@@ -27,7 +26,6 @@ namespace ngraph
class ConvolutionAdd : public Op
{
public:
-   CPU_BACKEND_API
    static const std::string type_name;
    const std::string& description() const override { return type_name; }
    ConvolutionAdd(const std::shared_ptr<op::Convolution>& conv,
...
@@ -28,7 +28,6 @@ namespace ngraph
class ConvolutionRelu : public Op
{
public:
-   CPU_BACKEND_API
    static const std::string type_name;
    const std::string& description() const override { return type_name; }
    CPU_BACKEND_API ConvolutionRelu(const std::shared_ptr<op::Convolution>& conv);
...
@@ -35,7 +35,6 @@ namespace ngraph
class ConvertLayout : public ngraph::op::Op
{
public:
-   CPU_BACKEND_API
    static const std::string type_name;
    const std::string& description() const override { return type_name; }
    CPU_BACKEND_API ConvertLayout(
...
@@ -18,7 +18,6 @@
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/op.hpp"
-#include "ngraph/runtime/cpu/cpu_backend_visibility.h"
namespace ngraph
{
@@ -28,7 +27,6 @@ namespace ngraph
class DeconvolutionBias : public Op
{
public:
-   CPU_BACKEND_API
    static const std::string type_name;
    const std::string& description() const override { return type_name; }
    /// \brief Constructs a batched-convolution data batch-backprop operation.
...
@@ -17,7 +17,6 @@
#pragma once
#include "ngraph/op/op.hpp"
-#include "ngraph/runtime/cpu/cpu_backend_visibility.h"
#include "ngraph/util.hpp"
namespace ngraph
@@ -27,7 +26,6 @@ namespace ngraph
class Dropout : public Op
{
public:
-   CPU_BACKEND_API
    static const std::string type_name;
    const std::string& description() const override { return type_name; }
    Dropout(const Output<Node>& input,
...
@@ -17,7 +17,6 @@
#pragma once
#include "ngraph/op/fused/group_conv.hpp"
#include "ngraph/op/op.hpp"
-#include "ngraph/runtime/cpu/cpu_backend_visibility.h"
namespace ngraph
{
@@ -28,7 +27,6 @@ namespace ngraph
class GroupConvolutionBias : public Op
{
public:
-   CPU_BACKEND_API
    static const std::string type_name;
    const std::string& description() const override { return type_name; }
    GroupConvolutionBias(const std::shared_ptr<op::GroupConvolution>& conv,
...
@@ -20,7 +20,6 @@
#include <vector>
#include "ngraph/op/op.hpp"
-#include "ngraph/runtime/cpu/cpu_backend_visibility.h"
namespace ngraph
{
@@ -33,7 +32,6 @@ namespace ngraph
class HalideOp : public ngraph::op::Op
{
public:
-   CPU_BACKEND_API
    static const std::string type_name;
    const std::string& description() const override { return type_name; }
    HalideOp(const OutputVector& args,
...
@@ -19,7 +19,6 @@
#include "ngraph/node.hpp"
#include "ngraph/op/op.hpp"
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
-#include "ngraph/runtime/cpu/cpu_backend_visibility.h"
namespace ngraph
{
@@ -31,7 +30,6 @@ namespace ngraph
class CPULeakyRelu : public ngraph::op::util::UnaryElementwiseArithmetic
{
public:
-   CPU_BACKEND_API
    static const std::string type_name;
    const std::string& description() const override { return type_name; }
    /// \brief Constructs a CPULeakyRelu operation.
...
@@ -17,7 +17,6 @@
#pragma once
#include "ngraph/op/op.hpp"
-#include "ngraph/runtime/cpu/cpu_backend_visibility.h"
#include "ngraph/runtime/cpu/op/rnn_utils.hpp"
#include "ngraph/util.hpp"
@@ -28,7 +27,6 @@ namespace ngraph
class Lstm : public Op
{
public:
-   CPU_BACKEND_API
    static const std::string type_name;
    const std::string& description() const override { return type_name; }
    // INPUTS:
...
@@ -27,7 +27,6 @@ namespace ngraph
class MatmulBias : public Op
{
public:
-   CPU_BACKEND_API
    static const std::string type_name;
    const std::string& description() const override { return type_name; }
    CPU_BACKEND_API MatmulBias(const Output<Node>& W,
...
@@ -32,7 +32,6 @@ namespace ngraph
class MaxPoolWithIndices : public Op
{
public:
-   CPU_BACKEND_API
    static const std::string type_name;
    const std::string& description() const override { return type_name; }
    CPU_BACKEND_API MaxPoolWithIndices(const Output<Node>& arg,
@@ -68,7 +67,6 @@ namespace ngraph
class MaxPoolWithIndicesBackprop : public Op
{
public:
-   CPU_BACKEND_API
    static const std::string type_name;
    const std::string& description() const override { return type_name; }
    CPU_BACKEND_API MaxPoolWithIndicesBackprop(const Output<Node>& arg_forward,
...
@@ -19,7 +19,6 @@
#include <utility>
#include "ngraph/op/op.hpp"
-#include "ngraph/runtime/cpu/cpu_backend_visibility.h"
namespace ngraph
{
@@ -28,7 +27,6 @@ namespace ngraph
class QuantizedMatmul : public Op
{
public:
-   CPU_BACKEND_API
    static const std::string type_name;
    const std::string& description() const override { return type_name; }
    QuantizedMatmul(const Output<Node>& data,
...
@@ -48,7 +48,6 @@ namespace ngraph
class Rnn : public Op
{
public:
-   CPU_BACKEND_API
    static const std::string type_name;
    const std::string& description() const override { return type_name; }
    CPU_BACKEND_API Rnn(const Output<Node>& src_layer,
...
@@ -30,7 +30,6 @@ namespace ngraph
class SigmoidMultiply : public Op
{
public:
-   CPU_BACKEND_API
    static const std::string type_name;
    const std::string& description() const override { return type_name; }
    /// Defines valid function types
@@ -69,7 +68,6 @@ namespace ngraph
class SigmoidMultiplyBackprop : public Op
{
public:
-   CPU_BACKEND_API
    static const std::string type_name;
    const std::string& description() const override { return type_name; }
    typedef SigmoidMultiply::FunctionType FunctionType;
...
@@ -18,7 +18,6 @@
#include "ngraph/coordinate.hpp"
#include "ngraph/op/op.hpp"
-#include "ngraph/runtime/cpu/cpu_backend_visibility.h"
#include "ngraph/strides.hpp"
namespace ngraph
@@ -51,7 +50,6 @@ namespace ngraph
class UpdateSlice : public Op
{
public:
-   CPU_BACKEND_API
    static const std::string type_name;
    const std::string& description() const override { return type_name; }
    /// \brief Constructs a tensor slice update operation.
...
@@ -39,7 +39,6 @@ namespace ngraph
class ngraph::runtime::plaidml::op::Convolution final : public ngraph::op::Op
{
public:
-   NGRAPH_API
    static const std::string type_name;
    const std::string& description() const override { return type_name; }
    Convolution(std::shared_ptr<ngraph::op::Convolution> src,
@@ -66,7 +65,6 @@ private:
class ngraph::runtime::plaidml::op::ConvolutionBackpropData final : public ngraph::op::Op
{
public:
-   NGRAPH_API
    static const std::string type_name;
    const std::string& description() const override { return type_name; }
    ConvolutionBackpropData(std::shared_ptr<ngraph::op::ConvolutionBackpropData> src,
@@ -93,7 +91,6 @@ private:
class ngraph::runtime::plaidml::op::ConvolutionBackpropFilters final : public ngraph::op::Op
{
public:
-   NGRAPH_API
    static const std::string type_name;
    const std::string& description() const override { return type_name; }
    ConvolutionBackpropFilters(std::shared_ptr<ngraph::op::ConvolutionBackpropFilters> src,
...
@@ -40,7 +40,6 @@ namespace ngraph
class ngraph::runtime::plaidml::op::ImplicitBroadcast final : public ngraph::op::Op
{
public:
-   NGRAPH_API
    static const std::string type_name;
    const std::string& description() const override { return type_name; }
    ImplicitBroadcast(const Output<Node>& input, const Shape& shape);
...
@@ -39,7 +39,6 @@ namespace ngraph
class ngraph::runtime::plaidml::op::Replicate final : public ngraph::op::Op
{
public:
-   NGRAPH_API
    static const std::string type_name;
    const std::string& description() const override { return type_name; }
    Replicate(const Output<Node>& arg, std::size_t replication_axis, std::size_t replication_count);
...
@@ -38,7 +38,6 @@ namespace ngraph
class ngraph::runtime::plaidml::op::Winograd final : public ngraph::op::Op
{
public:
-   NGRAPH_API
    static const std::string type_name;
    const std::string& description() const override { return type_name; }
    Winograd(std::shared_ptr<Convolution> conv, const OutputVector& args);
...
@@ -36,7 +36,7 @@ namespace ngraph
    const AxisSet& reduction_axes)
{
    T minval = std::numeric_limits<T>::has_infinity
-                   ? -std::numeric_limits<T>::infinity()
+                   ? T(-std::numeric_limits<T>::infinity())
                    : std::numeric_limits<T>::min();
    CoordinateTransform output_transform(out_shape);
...
@@ -1540,8 +1540,7 @@ NGRAPH_TEST(${BACKEND_NAME}, group_conv_transpose)
        -0.0270785f, -0.00680824f, -0.06650258f, 0.08004665f, 0.07918708f, -0.0724144f,
        0.06256775f, -0.17838378f, -0.18863615f, 0.20064656f, 0.133717f, -0.06876295f,
        -0.06398046f, -0.00864975f, 0.19289537f, -0.01490572f, -0.13673618f, 0.01949645f});
-   test_case.set_tolerance(3);
-   test_case.run();
+   test_case.run(DEFAULT_FLOAT_TOLERANCE_BITS + 1);
}
NGRAPH_TEST(${BACKEND_NAME}, group_conv_transpose_output_shape)
...
@@ -434,6 +434,110 @@ TEST(constant_folding, const_sum)
    ASSERT_EQ(values_expected, values_out);
}
TEST(constant_folding, const_max)
{
Shape input_shape{3, 3};
vector<int32_t> values_in{1, 2, 3, 4, 5, 6, 7, 8, 9};
auto constant = op::Constant::create(element::i32, input_shape, values_in);
auto convert = make_shared<op::Max>(constant, AxisSet{1});
auto f = make_shared<Function>(convert, ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Max>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto new_const =
std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(new_const);
auto values_out = new_const->get_vector<int32_t>();
vector<int32_t> values_expected{3, 6, 9};
ASSERT_EQ(values_expected, values_out);
}
TEST(constant_folding, const_min)
{
Shape input_shape{3, 3};
vector<int32_t> values_in{1, 2, 3, 4, 5, 6, 7, 8, 9};
auto constant = op::Constant::create(element::i32, input_shape, values_in);
auto convert = make_shared<op::Min>(constant, AxisSet{1});
auto f = make_shared<Function>(convert, ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Min>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto new_const =
std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(new_const);
auto values_out = new_const->get_vector<int32_t>();
vector<int32_t> values_expected{1, 4, 7};
ASSERT_EQ(values_expected, values_out);
}
TEST(constant_folding, const_all)
{
Shape input_shape{3, 3};
vector<char> values_in{0, 1, 1, 0, 1, 0, 1, 1, 1};
auto constant = op::Constant::create(element::boolean, input_shape, values_in);
auto convert = make_shared<op::All>(constant, AxisSet{1});
auto f = make_shared<Function>(convert, ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::All>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto new_const =
std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(new_const);
auto values_out = new_const->get_vector<char>();
vector<char> values_expected{0, 0, 1};
ASSERT_EQ(values_expected, values_out);
}
TEST(constant_folding, const_any)
{
Shape input_shape{3, 3};
vector<char> values_in{1, 0, 0, 1, 0, 1, 0, 0, 0};
auto constant = op::Constant::create(element::boolean, input_shape, values_in);
auto convert = make_shared<op::Any>(constant, AxisSet{1});
auto f = make_shared<Function>(convert, ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Any>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto new_const =
std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(new_const);
auto values_out = new_const->get_vector<char>();
vector<char> values_expected{1, 1, 0};
ASSERT_EQ(values_expected, values_out);
}
TEST(constant_folding, const_concat)
{
    auto constant0 =
...
@@ -104,8 +104,7 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, model_lstm_fwd_with_clip)
    // We have to enlarge tolerance bits to 3 - it's only one bit more than default value.
    // The discrepancies may occur at most on 7th decimal position.
-   test_case.set_tolerance(3);
-   test_case.run();
+   test_case.run(DEFAULT_FLOAT_TOLERANCE_BITS + 1);
}
NGRAPH_TEST(onnx_${BACKEND_NAME}, model_lstm_fwd_mixed_seq)
@@ -144,8 +143,7 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, model_lstm_fwd_mixed_seq)
    // We have to enlarge tolerance bits to 3 - it's only one bit more than default value.
    // The discrepancies may occur at most on 7th decimal position.
-   test_case.set_tolerance(3);
-   test_case.run();
+   test_case.run(DEFAULT_FLOAT_TOLERANCE_BITS + 1);
}
NGRAPH_TEST(onnx_${BACKEND_NAME}, model_lstm_fwd_hardsigmoid_activation)
@@ -201,8 +199,7 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, model_lstm_fwd_hardsigmoid_activation)
    test_case.add_expected_output<float>(Shape{1, 1, 2}, {0.19017234f, 0.00356848f});
    // The discrepancies occur at most at 18th mantissa bit - 8th decimal position.
-   test_case.set_tolerance(6);
-   test_case.run();
+   test_case.run(DEFAULT_FLOAT_TOLERANCE_BITS + 4);
}
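For reference only (not part of this change set), the relationship between the tolerance bits in these comments and the quoted decimal positions can be checked with a few lines of arithmetic. This sketch assumes DEFAULT_FLOAT_TOLERANCE_BITS is 2, as implied by the set_tolerance(3) to run(DEFAULT_FLOAT_TOLERANCE_BITS + 1) substitutions above, and that the tolerance counts low-order float32 significand bits allowed to differ.

# float32 carries a 24-bit significand; allowing the lowest 6 bits to differ
# (DEFAULT_FLOAT_TOLERANCE_BITS + 4, assuming the default is 2) leaves 18
# matching bits, i.e. a relative error of roughly 2**-18.
mantissa_bits = 24
tolerance_bits = 2 + 4
rel_err = 2.0 ** -(mantissa_bits - tolerance_bits)
print(rel_err)               # ~3.8e-06
print(0.00356848 * rel_err)  # ~1.4e-08, i.e. around the 8th decimal position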
NGRAPH_TEST(onnx_${BACKEND_NAME}, model_lstm_fwd_large_batch_no_clip)
@@ -307,8 +304,7 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, model_lstm_bdir_short_input_seq)
    test_case.add_expected_output<float>(Shape{2, 1, 2},
                                         {-0.0251062f, 0.0561262f, -0.0318928f, 0.0762679f});
-   test_case.set_tolerance(DEFAULT_FLOAT_TOLERANCE_BITS + 3);
-   test_case.run();
+   test_case.run(DEFAULT_FLOAT_TOLERANCE_BITS + 3);
}
NGRAPH_TEST(onnx_${BACKEND_NAME}, model_lstm_mixed_seq_reverse)
@@ -353,6 +349,5 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, model_lstm_mixed_seq_reverse)
    Shape{1, 2, 3},
    {0.52497941f, 0.54983425f, 0.5744428f, 1.34960834f, 1.54772296f, 1.65633056f});
-   test_case.set_tolerance(DEFAULT_FLOAT_TOLERANCE_BITS + 1);
-   test_case.run();
+   test_case.run(DEFAULT_FLOAT_TOLERANCE_BITS + 1);
}
@@ -17,7 +17,7 @@
#include <cstring>
#include <gtest/gtest.h>
-#include <onnxifi.h>
+#include <onnx/onnxifi.h>
#include "ngraph/runtime/backend_manager.hpp"
...
@@ -19,8 +19,9 @@
#include "gtest/gtest.h"
#include "ngraph/assertion.hpp"
-void ngraph::test::NgraphTestCase::run()
+void ngraph::test::NgraphTestCase::run(size_t tolerance_bits)
{
+   m_tolerance_bits = tolerance_bits;
    const auto& function_results = m_function->get_results();
    NGRAPH_CHECK(m_expected_outputs.size() == function_results.size(),
                 "Expected number of outputs is different from the function's number of results.");
@@ -52,12 +53,6 @@ void ngraph::test::NgraphTestCase::run()
    }
}
-ngraph::test::NgraphTestCase& ngraph::test::NgraphTestCase::set_tolerance(int tolerance_bits)
-{
-   m_tolerance_bits = tolerance_bits;
-   return *this;
-}
ngraph::test::NgraphTestCase& ngraph::test::NgraphTestCase::dump_results(bool dump)
{
    m_dump_results = dump;
...
@@ -38,8 +38,6 @@ namespace ngraph
{
}
-   NgraphTestCase& set_tolerance(int tolerance_bits);
    /// \brief Makes the test case print the expected and computed values to the console. This should only be used for debugging purposes.
    ///
    /// Just before the assertion is done, the current test case will gather expected and computed values,
@@ -130,7 +128,7 @@ namespace ngraph
        add_expected_output(expected_shape, value);
    }
-   void run();
+   void run(size_t tolerance_bits = DEFAULT_FLOAT_TOLERANCE_BITS);
private:
    template <typename T>
...