Commit de132138 authored by Adam Procter

Merge remote-tracking branch 'origin/master' into aprocter/cf-element-types

parents b46fa4c7 23f838e5
......@@ -21,7 +21,7 @@ include(ExternalProject)
# ONNX.proto definition version
#------------------------------------------------------------------------------
set(ONNX_VERSION 1.3.0)
set(ONNX_VERSION 1.5.0)
#------------------------------------------------------------------------------
# Download and install libonnx ...
......@@ -30,6 +30,9 @@ set(ONNX_VERSION 1.3.0)
set(ONNX_GIT_REPO_URL https://github.com/onnx/onnx.git)
set(ONNX_GIT_BRANCH rel-${ONNX_VERSION})
add_definitions(-DONNX_BUILD_SHARED_LIBS=ON)
add_definitions(-DONNX_ML=ON)
ExternalProject_Add(
ext_onnx
PREFIX onnx
......@@ -58,8 +61,8 @@ ExternalProject_Add(
ExternalProject_Get_Property(ext_onnx SOURCE_DIR BINARY_DIR)
set(ONNX_INCLUDE_DIR ${SOURCE_DIR}/onnx)
set(ONNX_PROTO_INCLUDE_DIR ${BINARY_DIR}/onnx)
set(ONNX_INCLUDE_DIR ${SOURCE_DIR})
set(ONNX_PROTO_INCLUDE_DIR ${BINARY_DIR})
if (WIN32)
set(ONNX_LIBRARY ${BINARY_DIR}/${CMAKE_BUILD_TYPE}/onnx.lib)
set(ONNX_PROTO_LIBRARY ${BINARY_DIR}/${CMAKE_BUILD_TYPE}/onnx_proto.lib)
......
......@@ -37,6 +37,7 @@ ngraph.ops
equal
exp
floor
gelu
get_output_element
greater
greater_eq
......
......@@ -50,6 +50,7 @@ from ngraph.ops import elu
from ngraph.ops import equal
from ngraph.ops import exp
from ngraph.ops import floor
from ngraph.ops import gelu
from ngraph.ops import get_output_element
from ngraph.ops import greater
from ngraph.ops import greater_eq
......
......@@ -74,6 +74,7 @@ from _pyngraph.op import Elu
from _pyngraph.op import Equal
from _pyngraph.op import Exp
from _pyngraph.op import Floor
from _pyngraph.op import Gelu
from _pyngraph.op import GetOutputElement
from _pyngraph.op import Greater
from _pyngraph.op import GreaterEq
......
......@@ -23,7 +23,7 @@ from ngraph.impl import AxisSet, AxisVector, Coordinate, CoordinateDiff, Functio
from ngraph.impl.op import Abs, Acos, Add, And, Asin, ArgMax, ArgMin, Atan, AvgPool, \
BatchNormTraining, BatchNormInference, Broadcast, Ceiling, Clamp, Concat, Constant, Convert, \
Convolution, ConvolutionBackpropData, Cos, Cosh, Divide, Dot, Elu, Equal, Exp, Floor, \
GetOutputElement, Greater, GreaterEq, Less, LessEq, Log, LRN, Max, Maximum, MaxPool, \
Gelu, GetOutputElement, Greater, GreaterEq, Less, LessEq, Log, LRN, Max, Maximum, MaxPool, \
Min, Minimum, Multiply, Negative, Not, NotEqual, OneHot, Or, Pad, Parameter, Product, \
Power, Relu, ReplaceSlice, Reshape, Reverse, Select, Sign, Sin, Sinh, Slice, Softmax, \
Sqrt, Subtract, Sum, Tan, Tanh, TopK
......@@ -527,6 +527,24 @@ def convert(node, new_type, name=None): # type: (Node, NumericType, str) -> Nod
return Convert(node, new_element_type)
@nameable_op
def gelu(node, name=None): # type: (NodeInput, str) -> Node
r"""Perform Gaussian Error Linear Unit operation element-wise on data from input node.
Computes the GELU function:
.. math:: f(x) = 0.5\cdot x\cdot(1 + erf(\dfrac{x}{\sqrt{2}}))
For more information refer to:
`Gaussian Error Linear Unit (GELU) <https://arxiv.org/pdf/1606.08415.pdf>`_
:param node: Input tensor. One of: input node, array or scalar.
:param name: Optional output node name.
:return: The new node performing a GELU operation on its input data element-wise.
"""
return Gelu(as_node(node))
@nameable_op
def select(selection_node, input_node1, input_node2, name=None):
# type: (Node, Node, Node, str) -> Node
......
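
Side note for reviewers: the constants asserted in the new GELU tests further down can be reproduced directly from the docstring formula f(x) = 0.5·x·(1 + erf(x/√2)). Below is a standalone sketch using only NumPy and the standard library; the helper name gelu_reference is made up for illustration and is not part of the nGraph API.

import numpy as np
from math import erf, sqrt

def gelu_reference(x):
    # Element-wise f(x) = 0.5 * x * (1 + erf(x / sqrt(2))) for a 2-D input
    return np.array([[0.5 * v * (1.0 + erf(v / sqrt(2.0))) for v in row] for row in x],
                    dtype=np.float32)

data = np.array([[-5, 1], [-2, 3]], dtype=np.float32)
print(gelu_reference(data))
# -> approximately [[-1.4901161e-06  8.4134471e-01]
#                   [-4.5500278e-02  2.9959502e+00]]
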
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "ngraph/op/fused/gelu.hpp"
#include "pyngraph/ops/fused/gelu.hpp"
namespace py = pybind11;
void regclass_pyngraph_op_Gelu(py::module m)
{
py::class_<ngraph::op::Gelu, std::shared_ptr<ngraph::op::Gelu>, ngraph::op::Op> gelu(m, "Gelu");
gelu.doc() = "ngraph.impl.op.Gelu wraps ngraph::op::Gelu";
gelu.def(py::init<const std::shared_ptr<ngraph::Node>&>());
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
void regclass_pyngraph_op_Gelu(py::module m);
......@@ -54,6 +54,7 @@ void regmodule_pyngraph_op(py::module m_op)
regclass_pyngraph_op_Equal(m_op);
regclass_pyngraph_op_Exp(m_op);
regclass_pyngraph_op_Floor(m_op);
regclass_pyngraph_op_Gelu(m_op);
regclass_pyngraph_op_GetOutputElement(m_op);
regclass_pyngraph_op_Greater(m_op);
regclass_pyngraph_op_GreaterEq(m_op);
......
......@@ -44,6 +44,7 @@
#include "pyngraph/ops/floor.hpp"
#include "pyngraph/ops/fused/clamp.hpp"
#include "pyngraph/ops/fused/elu.hpp"
#include "pyngraph/ops/fused/gelu.hpp"
#include "pyngraph/ops/get_output_element.hpp"
#include "pyngraph/ops/greater.hpp"
#include "pyngraph/ops/greater_eq.hpp"
......
......@@ -184,6 +184,7 @@ sources = [
'pyngraph/ops/equal.cpp',
'pyngraph/ops/exp.cpp',
'pyngraph/ops/floor.cpp',
'pyngraph/ops/fused/gelu.cpp',
'pyngraph/ops/greater.cpp',
'pyngraph/ops/greater_eq.cpp',
'pyngraph/ops/less.cpp',
......
......@@ -19,7 +19,7 @@ import ngraph as ng
from test.ngraph.util import get_runtime
def test_elu_operator():
def test_elu_operator_with_parameters():
runtime = get_runtime()
data_shape = [2, 2]
......@@ -69,6 +69,38 @@ def test_elu_operator_with_scalar():
assert np.allclose(result, expected)
def test_gelu_operator_with_parameters():
runtime = get_runtime()
data_value = np.array([[-5, 1], [-2, 3]], dtype=np.float32)
data_shape = [2, 2]
parameter_data = ng.parameter(data_shape, name='Data', dtype=np.float32)
model = ng.gelu(parameter_data)
computation = runtime.computation(model, parameter_data)
result = computation(data_value)
expected = np.array([[-1.4901161e-06, 8.4134471e-01], [-4.5500278e-02, 2.9959502]],
dtype=np.float32)
assert np.allclose(result, expected)
def test_gelu_operator_with_array():
runtime = get_runtime()
data_value = np.array([[-5, 1], [-2, 3]], dtype=np.float32)
model = ng.gelu(data_value)
computation = runtime.computation(model)
result = computation()
expected = np.array([[-1.4901161e-06, 8.4134471e-01], [-4.5500278e-02, 2.9959502]],
dtype=np.float32)
assert np.allclose(result, expected)
def test_clamp_operator():
runtime = get_runtime()
......@@ -99,4 +131,5 @@ def test_clamp_operator_with_array():
result = computation()
expected = np.clip(data_value, min_value, max_value)
assert np.allclose(result, expected)
......@@ -16,7 +16,7 @@
#pragma once
#include <onnx-ml.pb.h>
#include <onnx/onnx_pb.h>
#include "ngraph/except.hpp"
#include "tensor.hpp"
......
......@@ -16,7 +16,7 @@
#pragma once
#include <onnx-ml.pb.h>
#include <onnx/onnx_pb.h>
#include <string>
#include <vector>
......
......@@ -14,7 +14,7 @@
// limitations under the License.
//*****************************************************************************
#include <onnx-ml.pb.h>
#include <onnx/onnx_pb.h>
#include "model.hpp"
#include "ngraph/log.hpp"
......
......@@ -16,7 +16,7 @@
#pragma once
#include <onnx-ml.pb.h>
#include <onnx/onnx_pb.h>
#include <ostream>
#include <string>
#include <unordered_map>
......
......@@ -14,7 +14,7 @@
// limitations under the License.
//*****************************************************************************
#include <onnx-ml.pb.h>
#include <onnx/onnx_pb.h>
#include "attribute.hpp"
#include "graph.hpp"
......
......@@ -16,7 +16,7 @@
#pragma once
#include <onnx-ml.pb.h>
#include <onnx/onnx_pb.h>
#include <utility>
#include <vector>
......
......@@ -16,7 +16,7 @@
#pragma once
#include <onnx-ml.pb.h>
#include <onnx/onnx_pb.h>
#include "ngraph/op/constant.hpp"
#include "ngraph/op/parameter.hpp"
......
......@@ -13,7 +13,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <onnx-ml.pb.h> // onnx types
#include <onnx/onnx_pb.h> // onnx types
#include "common.hpp"
......
......@@ -15,7 +15,7 @@
//*****************************************************************************
#include <cstdlib> // std::size_t, std::uintptr_t
#include <onnxifi.h>
#include <onnx/onnxifi.h>
#include <stdexcept> // std::invalid_argument, std::out_of_range
#include "backend.hpp"
......
......@@ -19,7 +19,7 @@
#include <cstddef> // std::size_t, std::uintptr_t
#include <map> // std::map
#include <mutex> // std::mutex
#include <onnxifi.h>
#include <onnx/onnxifi.h>
#include "backend.hpp"
#include "ngraph/runtime/backend.hpp"
......
......@@ -16,7 +16,7 @@
#pragma once
#include <onnxifi.h>
#include <onnx/onnxifi.h>
namespace ngraph
{
......
......@@ -16,7 +16,7 @@
#include <cstddef>
#include <cstdint>
#include <onnxifi.h>
#include <onnx/onnxifi.h>
#include <stdexcept>
#include "backend_manager.hpp"
......
......@@ -17,7 +17,7 @@
#pragma once
#include <memory>
#include <onnxifi.h>
#include <onnx/onnxifi.h>
#include "ngraph/runtime/backend.hpp"
#include "ngraph/runtime/tensor.hpp"
......
......@@ -21,7 +21,9 @@
#include "ngraph/graph_util.hpp"
#include "ngraph/op/abs.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/all.hpp"
#include "ngraph/op/and.hpp"
#include "ngraph/op/any.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/ceiling.hpp"
#include "ngraph/op/concat.hpp"
......@@ -41,7 +43,9 @@
#include "ngraph/op/greater_eq.hpp"
#include "ngraph/op/less.hpp"
#include "ngraph/op/less_eq.hpp"
#include "ngraph/op/max.hpp"
#include "ngraph/op/maximum.hpp"
#include "ngraph/op/min.hpp"
#include "ngraph/op/minimum.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/negative.hpp"
......@@ -64,7 +68,9 @@
#include "ngraph/pattern/op/label.hpp"
#include "ngraph/runtime/reference/abs.hpp"
#include "ngraph/runtime/reference/add.hpp"
#include "ngraph/runtime/reference/all.hpp"
#include "ngraph/runtime/reference/and.hpp"
#include "ngraph/runtime/reference/any.hpp"
#include "ngraph/runtime/reference/broadcast.hpp"
#include "ngraph/runtime/reference/ceiling.hpp"
#include "ngraph/runtime/reference/concat.hpp"
......@@ -78,7 +84,9 @@
#include "ngraph/runtime/reference/greater_eq.hpp"
#include "ngraph/runtime/reference/less.hpp"
#include "ngraph/runtime/reference/less_eq.hpp"
#include "ngraph/runtime/reference/max.hpp"
#include "ngraph/runtime/reference/maximum.hpp"
#include "ngraph/runtime/reference/min.hpp"
#include "ngraph/runtime/reference/minimum.hpp"
#include "ngraph/runtime/reference/multiply.hpp"
#include "ngraph/runtime/reference/negate.hpp"
......@@ -1608,180 +1616,207 @@ void pass::ConstantFolding::construct_constant_reverse()
}
template <typename T>
static shared_ptr<op::Constant> fold_constant_product_helper(shared_ptr<op::Constant> constant,
const AxisSet& reduction_axes,
const Shape& result_shape)
static shared_ptr<op::Constant>
fold_constant_arithmetic_reduction_helper(shared_ptr<op::Constant> constant,
shared_ptr<Node> reduction_node)
{
vector<T> out_vec(shape_size(result_shape));
vector<T> out_vec(shape_size(reduction_node->get_shape()));
runtime::reference::product<T>(constant->get_vector<T>().data(),
if (auto max = dynamic_pointer_cast<op::Max>(reduction_node))
{
runtime::reference::max<T>(constant->get_vector<T>().data(),
out_vec.data(),
constant->get_output_shape(0),
reduction_node->get_shape(),
max->get_reduction_axes());
}
else if (auto min = dynamic_pointer_cast<op::Min>(reduction_node))
{
runtime::reference::min<T>(constant->get_vector<T>().data(),
out_vec.data(),
constant->get_output_shape(0),
reduction_node->get_shape(),
min->get_reduction_axes());
}
else if (auto prod = dynamic_pointer_cast<op::Product>(reduction_node))
{
runtime::reference::product<T>(constant->get_vector<T>().data(),
out_vec.data(),
constant->get_output_shape(0),
reduction_node->get_shape(),
prod->get_reduction_axes());
}
else if (auto sum = dynamic_pointer_cast<op::Sum>(reduction_node))
{
runtime::reference::sum<T>(constant->get_vector<T>().data(),
out_vec.data(),
constant->get_output_shape(0),
result_shape,
reduction_axes);
reduction_node->get_shape(),
sum->get_reduction_axes());
}
else
{
NGRAPH_CHECK(false,
"Internal nGraph error: Ops handled in "
"fold_constant_arithmetic_reduction_helper must be consistent with those "
"matched in construct_constant_arithmetic_reduction");
}
return make_shared<op::Constant>(constant->get_output_element_type(0), result_shape, out_vec);
return make_shared<op::Constant>(
reduction_node->get_output_element_type(0), reduction_node->get_shape(), out_vec);
}
static shared_ptr<op::Constant> fold_constant_product(shared_ptr<op::Constant> constant,
const AxisSet& reduction_axes,
const Shape& result_shape)
static shared_ptr<op::Constant>
fold_constant_arithmetic_reduction(shared_ptr<op::Constant> constant,
shared_ptr<Node> reduction_node)
{
auto& input_element_type = constant->get_output_element_type(0);
switch (input_element_type.get_type_enum())
{
case element::Type_t::undefined:
NGRAPH_CHECK(false, "Encountered 'undefined' element type in fold_constant_product");
NGRAPH_CHECK(false,
"Encountered 'undefined' element type in fold_constant_arithmetic_reduction");
break;
case element::Type_t::dynamic:
NGRAPH_CHECK(false, "Encountered 'dynamic' element type in fold_constant_product");
NGRAPH_CHECK(false,
"Encountered 'dynamic' element type in fold_constant_arithmetic_reduction");
break;
case element::Type_t::boolean:
return fold_constant_product_helper<char>(constant, reduction_axes, result_shape);
return fold_constant_arithmetic_reduction_helper<char>(constant, reduction_node);
case element::Type_t::bf16:
return fold_constant_product_helper<bfloat16>(constant, reduction_axes, result_shape);
return fold_constant_arithmetic_reduction_helper<bfloat16>(constant, reduction_node);
case element::Type_t::f16:
return fold_constant_product_helper<float16>(constant, reduction_axes, result_shape);
return fold_constant_arithmetic_reduction_helper<float16>(constant, reduction_node);
case element::Type_t::f32:
return fold_constant_product_helper<float>(constant, reduction_axes, result_shape);
return fold_constant_arithmetic_reduction_helper<float>(constant, reduction_node);
case element::Type_t::f64:
return fold_constant_product_helper<double>(constant, reduction_axes, result_shape);
return fold_constant_arithmetic_reduction_helper<double>(constant, reduction_node);
case element::Type_t::i8:
return fold_constant_product_helper<int8_t>(constant, reduction_axes, result_shape);
return fold_constant_arithmetic_reduction_helper<int8_t>(constant, reduction_node);
case element::Type_t::i16:
return fold_constant_product_helper<int16_t>(constant, reduction_axes, result_shape);
return fold_constant_arithmetic_reduction_helper<int16_t>(constant, reduction_node);
case element::Type_t::i32:
return fold_constant_product_helper<int32_t>(constant, reduction_axes, result_shape);
return fold_constant_arithmetic_reduction_helper<int32_t>(constant, reduction_node);
case element::Type_t::i64:
return fold_constant_product_helper<int64_t>(constant, reduction_axes, result_shape);
return fold_constant_arithmetic_reduction_helper<int64_t>(constant, reduction_node);
case element::Type_t::u8:
return fold_constant_product_helper<uint8_t>(constant, reduction_axes, result_shape);
return fold_constant_arithmetic_reduction_helper<uint8_t>(constant, reduction_node);
case element::Type_t::u16:
return fold_constant_product_helper<uint16_t>(constant, reduction_axes, result_shape);
return fold_constant_arithmetic_reduction_helper<uint16_t>(constant, reduction_node);
case element::Type_t::u32:
return fold_constant_product_helper<uint32_t>(constant, reduction_axes, result_shape);
return fold_constant_arithmetic_reduction_helper<uint32_t>(constant, reduction_node);
case element::Type_t::u64:
return fold_constant_product_helper<uint64_t>(constant, reduction_axes, result_shape);
return fold_constant_arithmetic_reduction_helper<uint64_t>(constant, reduction_node);
}
NGRAPH_UNREACHABLE("Unexpected switch case");
}
void pass::ConstantFolding::construct_constant_product()
void pass::ConstantFolding::construct_constant_arithmetic_reduction()
{
auto constant_label = make_shared<pattern::op::Label>(
auto constant_data_label = make_shared<pattern::op::Label>(
element::i32, Shape{2, 3, 4}, pattern::has_class<op::Constant>());
auto convert_op = make_shared<op::Product>(constant_label, AxisSet{0, 1, 2});
auto constant_product_callback = [constant_label](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for constant_product_callback against node = "
auto constant_axes_label =
make_shared<pattern::op::Label>(element::i64, Shape{2}, pattern::has_class<op::Constant>());
auto is_supported_reduction = [](std::shared_ptr<Node> n) {
return (pattern::has_class<op::Max>()(n) || pattern::has_class<op::Min>()(n) ||
pattern::has_class<op::Product>()(n) || pattern::has_class<op::Sum>()(n));
};
auto reduction =
std::make_shared<pattern::op::Any>(element::i32,
Shape{2},
is_supported_reduction,
NodeVector{constant_data_label, constant_axes_label});
auto constant_arithmetic_reduction_callback = [constant_data_label](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for constant_arithmetic_reduction_callback against node = "
<< m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map();
auto constant_match = static_pointer_cast<op::Constant>(pattern_map[constant_label]);
auto product_match = static_pointer_cast<op::Product>(m.get_match_root());
auto constant_match = static_pointer_cast<op::Constant>(pattern_map[constant_data_label]);
auto reduction_match = m.get_match_root();
replace_node(m.get_match_root(),
fold_constant_product(constant_match,
product_match->get_reduction_axes(),
product_match->get_output_shape(0)));
replace_node(reduction_match,
fold_constant_arithmetic_reduction(constant_match, reduction_match));
return true;
};
auto convert_matcher =
make_shared<pattern::Matcher>(convert_op, "ConstantFolding.ConstantProduct");
this->add_matcher(convert_matcher, constant_product_callback, all_pass_property_off);
auto arithmetic_reduction_matcher =
make_shared<pattern::Matcher>(reduction, "ConstantFolding.ConstantArithmeticReduction");
this->add_matcher(arithmetic_reduction_matcher,
constant_arithmetic_reduction_callback,
all_pass_property_off);
}
// TODO(amprocte): Find a way to reduce duplication with Product. (The fact
// that we bottom out in a reference call makes it a bit tricky.)
template <typename T>
static shared_ptr<op::Constant> fold_constant_sum_helper(shared_ptr<op::Constant> constant,
const AxisSet& reduction_axes,
const Shape& result_shape)
static shared_ptr<op::Constant> fold_constant_logical_reduction(shared_ptr<op::Constant> constant,
shared_ptr<Node> reduction_node)
{
vector<T> out_vec(shape_size(result_shape));
runtime::reference::sum<T>(constant->get_vector<T>().data(),
out_vec.data(),
constant->get_output_shape(0),
result_shape,
reduction_axes);
return make_shared<op::Constant>(constant->get_output_element_type(0), result_shape, out_vec);
}
vector<char> out_vec(shape_size(reduction_node->get_shape()));
static shared_ptr<op::Constant> fold_constant_sum(shared_ptr<op::Constant> constant,
const AxisSet& reduction_axes,
const Shape& result_shape)
{
auto& input_element_type = constant->get_output_element_type(0);
switch (input_element_type.get_type_enum())
if (auto all = dynamic_pointer_cast<::ngraph::op::All>(reduction_node))
{
case element::Type_t::undefined:
NGRAPH_CHECK(false, "Encountered 'undefined' element type in fold_constant_sum");
break;
case element::Type_t::dynamic:
NGRAPH_CHECK(false, "Encountered 'dynamic' element type in fold_constant_sum");
break;
case element::Type_t::boolean:
return fold_constant_sum_helper<char>(constant, reduction_axes, result_shape);
case element::Type_t::bf16:
return fold_constant_sum_helper<bfloat16>(constant, reduction_axes, result_shape);
case element::Type_t::f16:
return fold_constant_sum_helper<float16>(constant, reduction_axes, result_shape);
case element::Type_t::f32:
return fold_constant_sum_helper<float>(constant, reduction_axes, result_shape);
case element::Type_t::f64:
return fold_constant_sum_helper<double>(constant, reduction_axes, result_shape);
case element::Type_t::i8:
return fold_constant_sum_helper<int8_t>(constant, reduction_axes, result_shape);
case element::Type_t::i16:
return fold_constant_sum_helper<int16_t>(constant, reduction_axes, result_shape);
case element::Type_t::i32:
return fold_constant_sum_helper<int32_t>(constant, reduction_axes, result_shape);
case element::Type_t::i64:
return fold_constant_sum_helper<int64_t>(constant, reduction_axes, result_shape);
case element::Type_t::u8:
return fold_constant_sum_helper<uint8_t>(constant, reduction_axes, result_shape);
case element::Type_t::u16:
return fold_constant_sum_helper<uint16_t>(constant, reduction_axes, result_shape);
case element::Type_t::u32:
return fold_constant_sum_helper<uint32_t>(constant, reduction_axes, result_shape);
case element::Type_t::u64:
return fold_constant_sum_helper<uint64_t>(constant, reduction_axes, result_shape);
runtime::reference::all(constant->get_vector<char>().data(),
out_vec.data(),
constant->get_output_shape(0),
reduction_node->get_shape(),
all->get_reduction_axes());
}
else if (auto any = dynamic_pointer_cast<::ngraph::op::Any>(reduction_node))
{
runtime::reference::any(constant->get_vector<char>().data(),
out_vec.data(),
constant->get_output_shape(0),
reduction_node->get_shape(),
any->get_reduction_axes());
}
else
{
NGRAPH_CHECK(false,
"Internal nGraph error: Ops handled in "
"fold_constant_logical_reduction must be consistent with those "
"matched in construct_constant_logical_reduction");
}
NGRAPH_UNREACHABLE("Unexpected switch case");
return make_shared<op::Constant>(
reduction_node->get_output_element_type(0), reduction_node->get_shape(), out_vec);
}
void pass::ConstantFolding::construct_constant_sum()
void pass::ConstantFolding::construct_constant_logical_reduction()
{
auto constant_label = make_shared<pattern::op::Label>(
element::i32, Shape{2, 3, 4}, pattern::has_class<op::Constant>());
auto convert_op = make_shared<op::Sum>(constant_label, AxisSet{0, 1, 2});
auto constant_sum_callback = [constant_label](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for constant_sum_callback against node = "
auto constant_data_label = make_shared<pattern::op::Label>(
element::boolean, Shape{2, 3, 4}, pattern::has_class<op::Constant>());
auto constant_axes_label =
make_shared<pattern::op::Label>(element::i64, Shape{2}, pattern::has_class<op::Constant>());
auto is_supported_reduction = [](std::shared_ptr<Node> n) {
return (pattern::has_class<::ngraph::op::All>()(n) ||
pattern::has_class<::ngraph::op::Any>()(n));
};
auto reduction =
std::make_shared<pattern::op::Any>(element::i32,
Shape{2},
is_supported_reduction,
NodeVector{constant_data_label, constant_axes_label});
auto constant_logical_reduction_callback = [constant_data_label](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for constant_logical_reduction_callback against node = "
<< m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map();
auto constant_match = static_pointer_cast<op::Constant>(pattern_map[constant_label]);
auto sum_match = static_pointer_cast<op::Sum>(m.get_match_root());
auto constant_match = static_pointer_cast<op::Constant>(pattern_map[constant_data_label]);
auto reduction_match = m.get_match_root();
replace_node(m.get_match_root(),
fold_constant_sum(constant_match,
sum_match->get_reduction_axes(),
sum_match->get_output_shape(0)));
replace_node(reduction_match,
fold_constant_logical_reduction(constant_match, reduction_match));
return true;
};
auto convert_matcher = make_shared<pattern::Matcher>(convert_op, "ConstantFolding.ConstantSum");
this->add_matcher(convert_matcher, constant_sum_callback, all_pass_property_off);
auto logical_reduction_matcher =
make_shared<pattern::Matcher>(reduction, "ConstantFolding.ConstantLogicalReduction");
this->add_matcher(
logical_reduction_matcher, constant_logical_reduction_callback, all_pass_property_off);
}
template <typename T>
......
......@@ -42,8 +42,8 @@ public:
CONVERT,
SHAPE_OF,
REVERSE,
PRODUCT,
SUM,
ARITHMETIC_REDUCTION,
LOGICAL_REDUCTION,
CONCAT,
GATHER,
SLICE,
......@@ -68,8 +68,8 @@ public:
construct_constant_convert();
construct_constant_shape_of();
construct_constant_reverse();
construct_constant_product();
construct_constant_sum();
construct_constant_arithmetic_reduction();
construct_constant_logical_reduction();
construct_constant_concat();
construct_constant_gather();
construct_constant_slice();
......@@ -101,8 +101,12 @@ public:
case CFTransformations::CONVERT: construct_constant_convert(); break;
case CFTransformations::SHAPE_OF: construct_constant_shape_of(); break;
case CFTransformations::REVERSE: construct_constant_reverse(); break;
case CFTransformations::PRODUCT: construct_constant_product(); break;
case CFTransformations::SUM: construct_constant_sum(); break;
case CFTransformations::ARITHMETIC_REDUCTION:
construct_constant_arithmetic_reduction();
break;
case CFTransformations::LOGICAL_REDUCTION:
construct_constant_logical_reduction();
break;
case CFTransformations::CONCAT: construct_constant_concat(); break;
case CFTransformations::GATHER: construct_constant_gather(); break;
case CFTransformations::SLICE: construct_constant_slice(); break;
......@@ -126,8 +130,8 @@ private:
void construct_constant_convert();
void construct_constant_shape_of();
void construct_constant_reverse();
void construct_constant_product();
void construct_constant_sum();
void construct_constant_arithmetic_reduction();
void construct_constant_logical_reduction();
void construct_constant_concat();
void construct_constant_gather();
void construct_constant_slice();
......
......@@ -17,7 +17,6 @@
#pragma once
#include "ngraph/op/op.hpp"
#include "ngraph/runtime/cpu/cpu_backend_visibility.h"
namespace ngraph
{
......@@ -35,7 +34,6 @@ namespace ngraph
class BatchMatMulTranspose : public Op
{
public:
CPU_BACKEND_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a batch of matmul product operation.
......
......@@ -30,7 +30,6 @@ namespace ngraph
class BatchNormTrainingRelu : public Op
{
public:
CPU_BACKEND_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
CPU_BACKEND_API BatchNormTrainingRelu(double eps,
......@@ -60,7 +59,6 @@ namespace ngraph
class BatchNormInferenceRelu : public Op
{
public:
CPU_BACKEND_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
BatchNormInferenceRelu(double eps,
......
......@@ -19,7 +19,6 @@
#include "ngraph/node.hpp"
#include "ngraph/op/op.hpp"
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
#include "ngraph/runtime/cpu/cpu_backend_visibility.h"
namespace ngraph
{
......@@ -30,7 +29,6 @@ namespace ngraph
class BoundedRelu : public ngraph::op::util::UnaryElementwiseArithmetic
{
public:
CPU_BACKEND_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a BoundedRelu operation.
......
......@@ -18,7 +18,6 @@
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/op.hpp"
#include "ngraph/runtime/cpu/cpu_backend_visibility.h"
namespace ngraph
{
......@@ -27,7 +26,6 @@ namespace ngraph
class ConvolutionAdd : public Op
{
public:
CPU_BACKEND_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
ConvolutionAdd(const std::shared_ptr<op::Convolution>& conv,
......
......@@ -28,7 +28,6 @@ namespace ngraph
class ConvolutionRelu : public Op
{
public:
CPU_BACKEND_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
CPU_BACKEND_API ConvolutionRelu(const std::shared_ptr<op::Convolution>& conv);
......
......@@ -35,7 +35,6 @@ namespace ngraph
class ConvertLayout : public ngraph::op::Op
{
public:
CPU_BACKEND_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
CPU_BACKEND_API ConvertLayout(
......
......@@ -18,7 +18,6 @@
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/op.hpp"
#include "ngraph/runtime/cpu/cpu_backend_visibility.h"
namespace ngraph
{
......@@ -28,7 +27,6 @@ namespace ngraph
class DeconvolutionBias : public Op
{
public:
CPU_BACKEND_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a batched-convolution data batch-backprop operation.
......
......@@ -17,7 +17,6 @@
#pragma once
#include "ngraph/op/op.hpp"
#include "ngraph/runtime/cpu/cpu_backend_visibility.h"
#include "ngraph/util.hpp"
namespace ngraph
......@@ -27,7 +26,6 @@ namespace ngraph
class Dropout : public Op
{
public:
CPU_BACKEND_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
Dropout(const Output<Node>& input,
......
......@@ -17,7 +17,6 @@
#pragma once
#include "ngraph/op/fused/group_conv.hpp"
#include "ngraph/op/op.hpp"
#include "ngraph/runtime/cpu/cpu_backend_visibility.h"
namespace ngraph
{
......@@ -28,7 +27,6 @@ namespace ngraph
class GroupConvolutionBias : public Op
{
public:
CPU_BACKEND_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
GroupConvolutionBias(const std::shared_ptr<op::GroupConvolution>& conv,
......
......@@ -20,7 +20,6 @@
#include <vector>
#include "ngraph/op/op.hpp"
#include "ngraph/runtime/cpu/cpu_backend_visibility.h"
namespace ngraph
{
......@@ -33,7 +32,6 @@ namespace ngraph
class HalideOp : public ngraph::op::Op
{
public:
CPU_BACKEND_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
HalideOp(const OutputVector& args,
......
......@@ -19,7 +19,6 @@
#include "ngraph/node.hpp"
#include "ngraph/op/op.hpp"
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
#include "ngraph/runtime/cpu/cpu_backend_visibility.h"
namespace ngraph
{
......@@ -31,7 +30,6 @@ namespace ngraph
class CPULeakyRelu : public ngraph::op::util::UnaryElementwiseArithmetic
{
public:
CPU_BACKEND_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a CPULeakyRelu operation.
......
......@@ -17,7 +17,6 @@
#pragma once
#include "ngraph/op/op.hpp"
#include "ngraph/runtime/cpu/cpu_backend_visibility.h"
#include "ngraph/runtime/cpu/op/rnn_utils.hpp"
#include "ngraph/util.hpp"
......@@ -28,7 +27,6 @@ namespace ngraph
class Lstm : public Op
{
public:
CPU_BACKEND_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
// INPUTS:
......
......@@ -27,7 +27,6 @@ namespace ngraph
class MatmulBias : public Op
{
public:
CPU_BACKEND_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
CPU_BACKEND_API MatmulBias(const Output<Node>& W,
......
......@@ -32,7 +32,6 @@ namespace ngraph
class MaxPoolWithIndices : public Op
{
public:
CPU_BACKEND_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
CPU_BACKEND_API MaxPoolWithIndices(const Output<Node>& arg,
......@@ -68,7 +67,6 @@ namespace ngraph
class MaxPoolWithIndicesBackprop : public Op
{
public:
CPU_BACKEND_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
CPU_BACKEND_API MaxPoolWithIndicesBackprop(const Output<Node>& arg_forward,
......
......@@ -19,7 +19,6 @@
#include <utility>
#include "ngraph/op/op.hpp"
#include "ngraph/runtime/cpu/cpu_backend_visibility.h"
namespace ngraph
{
......@@ -28,7 +27,6 @@ namespace ngraph
class QuantizedMatmul : public Op
{
public:
CPU_BACKEND_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
QuantizedMatmul(const Output<Node>& data,
......
......@@ -48,7 +48,6 @@ namespace ngraph
class Rnn : public Op
{
public:
CPU_BACKEND_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
CPU_BACKEND_API Rnn(const Output<Node>& src_layer,
......
......@@ -30,7 +30,6 @@ namespace ngraph
class SigmoidMultiply : public Op
{
public:
CPU_BACKEND_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// Defines valid function types
......@@ -69,7 +68,6 @@ namespace ngraph
class SigmoidMultiplyBackprop : public Op
{
public:
CPU_BACKEND_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
typedef SigmoidMultiply::FunctionType FunctionType;
......
......@@ -18,7 +18,6 @@
#include "ngraph/coordinate.hpp"
#include "ngraph/op/op.hpp"
#include "ngraph/runtime/cpu/cpu_backend_visibility.h"
#include "ngraph/strides.hpp"
namespace ngraph
......@@ -51,7 +50,6 @@ namespace ngraph
class UpdateSlice : public Op
{
public:
CPU_BACKEND_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a tensor slice update operation.
......
......@@ -39,7 +39,6 @@ namespace ngraph
class ngraph::runtime::plaidml::op::Convolution final : public ngraph::op::Op
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
Convolution(std::shared_ptr<ngraph::op::Convolution> src,
......@@ -66,7 +65,6 @@ private:
class ngraph::runtime::plaidml::op::ConvolutionBackpropData final : public ngraph::op::Op
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
ConvolutionBackpropData(std::shared_ptr<ngraph::op::ConvolutionBackpropData> src,
......@@ -93,7 +91,6 @@ private:
class ngraph::runtime::plaidml::op::ConvolutionBackpropFilters final : public ngraph::op::Op
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
ConvolutionBackpropFilters(std::shared_ptr<ngraph::op::ConvolutionBackpropFilters> src,
......
......@@ -40,7 +40,6 @@ namespace ngraph
class ngraph::runtime::plaidml::op::ImplicitBroadcast final : public ngraph::op::Op
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
ImplicitBroadcast(const Output<Node>& input, const Shape& shape);
......
......@@ -39,7 +39,6 @@ namespace ngraph
class ngraph::runtime::plaidml::op::Replicate final : public ngraph::op::Op
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
Replicate(const Output<Node>& arg, std::size_t replication_axis, std::size_t replication_count);
......
......@@ -38,7 +38,6 @@ namespace ngraph
class ngraph::runtime::plaidml::op::Winograd final : public ngraph::op::Op
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
Winograd(std::shared_ptr<Convolution> conv, const OutputVector& args);
......
......@@ -36,7 +36,7 @@ namespace ngraph
const AxisSet& reduction_axes)
{
T minval = std::numeric_limits<T>::has_infinity
? -std::numeric_limits<T>::infinity()
? T(-std::numeric_limits<T>::infinity())
: std::numeric_limits<T>::min();
CoordinateTransform output_transform(out_shape);
......
......@@ -1540,8 +1540,7 @@ NGRAPH_TEST(${BACKEND_NAME}, group_conv_transpose)
-0.0270785f, -0.00680824f, -0.06650258f, 0.08004665f, 0.07918708f, -0.0724144f,
0.06256775f, -0.17838378f, -0.18863615f, 0.20064656f, 0.133717f, -0.06876295f,
-0.06398046f, -0.00864975f, 0.19289537f, -0.01490572f, -0.13673618f, 0.01949645f});
test_case.set_tolerance(3);
test_case.run();
test_case.run(DEFAULT_FLOAT_TOLERANCE_BITS + 1);
}
NGRAPH_TEST(${BACKEND_NAME}, group_conv_transpose_output_shape)
......
......@@ -434,6 +434,110 @@ TEST(constant_folding, const_sum)
ASSERT_EQ(values_expected, values_out);
}
TEST(constant_folding, const_max)
{
Shape input_shape{3, 3};
vector<int32_t> values_in{1, 2, 3, 4, 5, 6, 7, 8, 9};
auto constant = op::Constant::create(element::i32, input_shape, values_in);
auto convert = make_shared<op::Max>(constant, AxisSet{1});
auto f = make_shared<Function>(convert, ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Max>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto new_const =
std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(new_const);
auto values_out = new_const->get_vector<int32_t>();
vector<int32_t> values_expected{3, 6, 9};
ASSERT_EQ(values_expected, values_out);
}
TEST(constant_folding, const_min)
{
Shape input_shape{3, 3};
vector<int32_t> values_in{1, 2, 3, 4, 5, 6, 7, 8, 9};
auto constant = op::Constant::create(element::i32, input_shape, values_in);
auto convert = make_shared<op::Min>(constant, AxisSet{1});
auto f = make_shared<Function>(convert, ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Min>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto new_const =
std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(new_const);
auto values_out = new_const->get_vector<int32_t>();
vector<int32_t> values_expected{1, 4, 7};
ASSERT_EQ(values_expected, values_out);
}
TEST(constant_folding, const_all)
{
Shape input_shape{3, 3};
vector<char> values_in{0, 1, 1, 0, 1, 0, 1, 1, 1};
auto constant = op::Constant::create(element::boolean, input_shape, values_in);
auto convert = make_shared<op::All>(constant, AxisSet{1});
auto f = make_shared<Function>(convert, ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::All>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto new_const =
std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(new_const);
auto values_out = new_const->get_vector<char>();
vector<char> values_expected{0, 0, 1};
ASSERT_EQ(values_expected, values_out);
}
TEST(constant_folding, const_any)
{
Shape input_shape{3, 3};
vector<char> values_in{1, 0, 0, 1, 0, 1, 0, 0, 0};
auto constant = op::Constant::create(element::boolean, input_shape, values_in);
auto convert = make_shared<op::Any>(constant, AxisSet{1});
auto f = make_shared<Function>(convert, ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Any>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto new_const =
std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(new_const);
auto values_out = new_const->get_vector<char>();
vector<char> values_expected{1, 1, 0};
ASSERT_EQ(values_expected, values_out);
}
TEST(constant_folding, const_concat)
{
auto constant0 =
......
......@@ -104,8 +104,7 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, model_lstm_fwd_with_clip)
// We have to enlarge the tolerance to 3 bits - only one bit more than the default value.
// The discrepancies may occur at most at the 7th decimal position.
test_case.set_tolerance(3);
test_case.run();
test_case.run(DEFAULT_FLOAT_TOLERANCE_BITS + 1);
}
NGRAPH_TEST(onnx_${BACKEND_NAME}, model_lstm_fwd_mixed_seq)
......@@ -144,8 +143,7 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, model_lstm_fwd_mixed_seq)
// We have to enlarge the tolerance to 3 bits - only one bit more than the default value.
// The discrepancies may occur at most at the 7th decimal position.
test_case.set_tolerance(3);
test_case.run();
test_case.run(DEFAULT_FLOAT_TOLERANCE_BITS + 1);
}
NGRAPH_TEST(onnx_${BACKEND_NAME}, model_lstm_fwd_hardsigmoid_activation)
......@@ -201,8 +199,7 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, model_lstm_fwd_hardsigmoid_activation)
test_case.add_expected_output<float>(Shape{1, 1, 2}, {0.19017234f, 0.00356848f});
// The discrepancies occur at most at the 18th mantissa bit - the 8th decimal position.
test_case.set_tolerance(6);
test_case.run();
test_case.run(DEFAULT_FLOAT_TOLERANCE_BITS + 4);
}
NGRAPH_TEST(onnx_${BACKEND_NAME}, model_lstm_fwd_large_batch_no_clip)
......@@ -307,8 +304,7 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, model_lstm_bdir_short_input_seq)
test_case.add_expected_output<float>(Shape{2, 1, 2},
{-0.0251062f, 0.0561262f, -0.0318928f, 0.0762679f});
test_case.set_tolerance(DEFAULT_FLOAT_TOLERANCE_BITS + 3);
test_case.run();
test_case.run(DEFAULT_FLOAT_TOLERANCE_BITS + 3);
}
NGRAPH_TEST(onnx_${BACKEND_NAME}, model_lstm_mixed_seq_reverse)
......@@ -353,6 +349,5 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, model_lstm_mixed_seq_reverse)
Shape{1, 2, 3},
{0.52497941f, 0.54983425f, 0.5744428f, 1.34960834f, 1.54772296f, 1.65633056f});
test_case.set_tolerance(DEFAULT_FLOAT_TOLERANCE_BITS + 1);
test_case.run();
test_case.run(DEFAULT_FLOAT_TOLERANCE_BITS + 1);
}
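
A rough back-of-the-envelope reading of the tolerance-bit comments in the LSTM tests above, assuming the tolerance counts trailing bits of the 24-bit float32 significand that are allowed to differ (an assumption about the comparison helper, not something stated in this diff); the comments describe 3 bits as one more than the default:

# float32 has a 24-bit significand (23 stored bits plus the implicit leading 1).
# Letting the lowest 3 bits differ leaves 21 bits that must agree, i.e. a
# relative tolerance of roughly 2 ** -21.
print(2 ** -21)  # ~4.77e-07, which is why discrepancies show up around the 7th decimal place
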
......@@ -17,7 +17,7 @@
#include <cstring>
#include <gtest/gtest.h>
#include <onnxifi.h>
#include <onnx/onnxifi.h>
#include "ngraph/runtime/backend_manager.hpp"
......
......@@ -19,8 +19,9 @@
#include "gtest/gtest.h"
#include "ngraph/assertion.hpp"
void ngraph::test::NgraphTestCase::run()
void ngraph::test::NgraphTestCase::run(size_t tolerance_bits)
{
m_tolerance_bits = tolerance_bits;
const auto& function_results = m_function->get_results();
NGRAPH_CHECK(m_expected_outputs.size() == function_results.size(),
"Expected number of outputs is different from the function's number of results.");
......@@ -52,12 +53,6 @@ void ngraph::test::NgraphTestCase::run()
}
}
ngraph::test::NgraphTestCase& ngraph::test::NgraphTestCase::set_tolerance(int tolerance_bits)
{
m_tolerance_bits = tolerance_bits;
return *this;
}
ngraph::test::NgraphTestCase& ngraph::test::NgraphTestCase::dump_results(bool dump)
{
m_dump_results = dump;
......
......@@ -38,8 +38,6 @@ namespace ngraph
{
}
NgraphTestCase& set_tolerance(int tolerance_bits);
/// \brief Makes the test case print the expected and computed values to the console. This should only be used for debugging purposes.
///
/// Just before the assertion is done, the current test case will gather expected and computed values,
......@@ -130,7 +128,7 @@ namespace ngraph
add_expected_output(expected_shape, value);
}
void run();
void run(size_t tolerance_bits = DEFAULT_FLOAT_TOLERANCE_BITS);
private:
template <typename T>
......