Unverified commit e9e410aa authored by Scott Cyphers, committed by GitHub

Merge branch 'master' into aprocter/cf-dyn-broadcast

parents b0ba5d72 23f838e5
@@ -37,6 +37,7 @@ ngraph.ops
 equal
 exp
 floor
+gelu
 get_output_element
 greater
 greater_eq
...
@@ -50,6 +50,7 @@ from ngraph.ops import elu
 from ngraph.ops import equal
 from ngraph.ops import exp
 from ngraph.ops import floor
+from ngraph.ops import gelu
 from ngraph.ops import get_output_element
 from ngraph.ops import greater
 from ngraph.ops import greater_eq
...
@@ -74,6 +74,7 @@ from _pyngraph.op import Elu
 from _pyngraph.op import Equal
 from _pyngraph.op import Exp
 from _pyngraph.op import Floor
+from _pyngraph.op import Gelu
 from _pyngraph.op import GetOutputElement
 from _pyngraph.op import Greater
 from _pyngraph.op import GreaterEq
...
@@ -23,7 +23,7 @@ from ngraph.impl import AxisSet, AxisVector, Coordinate, CoordinateDiff, Functio
 from ngraph.impl.op import Abs, Acos, Add, And, Asin, ArgMax, ArgMin, Atan, AvgPool, \
     BatchNormTraining, BatchNormInference, Broadcast, Ceiling, Clamp, Concat, Constant, Convert, \
     Convolution, ConvolutionBackpropData, Cos, Cosh, Divide, Dot, Elu, Equal, Exp, Floor, \
-    GetOutputElement, Greater, GreaterEq, Less, LessEq, Log, LRN, Max, Maximum, MaxPool, \
+    Gelu, GetOutputElement, Greater, GreaterEq, Less, LessEq, Log, LRN, Max, Maximum, MaxPool, \
     Min, Minimum, Multiply, Negative, Not, NotEqual, OneHot, Or, Pad, Parameter, Product, \
     Power, Relu, ReplaceSlice, Reshape, Reverse, Select, Sign, Sin, Sinh, Slice, Softmax, \
     Sqrt, Subtract, Sum, Tan, Tanh, TopK
@@ -527,6 +527,24 @@ def convert(node, new_type, name=None):  # type: (Node, NumericType, str) -> Nod
     return Convert(node, new_element_type)


+@nameable_op
+def gelu(node, name=None):  # type: (NodeInput, str) -> Node
+    r"""Perform Gaussian Error Linear Unit operation element-wise on data from input node.
+
+    Computes GELU function:
+
+    .. math:: f(x) = 0.5\cdot x\cdot(1 + erf( \dfrac{x}{\sqrt{2}}))
+
+    For more information refer to:
+    `Gaussian Error Linear Unit (GELU) <https://arxiv.org/pdf/1606.08415.pdf>`_
+
+    :param node: Input tensor. One of: input node, array or scalar.
+    :param name: Optional output node name.
+    :return: The new node performing a GELU operation on its input data element-wise.
+    """
+    return Gelu(as_node(node))
+
+
 @nameable_op
 def select(selection_node, input_node1, input_node2, name=None):
     # type: (Node, Node, Node, str) -> Node
...
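As a quick sanity check of the docstring formula (not part of this change), the NumPy/SciPy sketch below evaluates f(x) = 0.5·x·(1 + erf(x/√2)) directly; it assumes SciPy is available and reproduces the expected values used in the new tests further down.

```python
# Reference GELU matching the formula in the gelu() docstring above.
# Assumes SciPy for erf; illustration only.
import numpy as np
from scipy.special import erf

def gelu_reference(x):
    x = np.asarray(x, dtype=np.float32)
    return 0.5 * x * (1.0 + erf(x / np.sqrt(2.0)))

print(gelu_reference([[-5, 1], [-2, 3]]))
# approx. [[-1.49e-06  8.41e-01]
#          [-4.55e-02  2.996e+00]]
```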
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "ngraph/op/fused/gelu.hpp"
#include "pyngraph/ops/fused/gelu.hpp"

namespace py = pybind11;

void regclass_pyngraph_op_Gelu(py::module m)
{
    py::class_<ngraph::op::Gelu, std::shared_ptr<ngraph::op::Gelu>, ngraph::op::Op> gelu(m, "Gelu");
    gelu.doc() = "ngraph.impl.op.Gelu wraps ngraph::op::Gelu";
    gelu.def(py::init<const std::shared_ptr<ngraph::Node>&>());
}
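For illustration, the new binding can also be used directly from Python without the `ng.gelu` wrapper; the sketch below assumes the usual low-level pyngraph constructors (`Parameter(Type, Shape)` and `Gelu(Node)`, mirroring the `py::init` signature registered above) rather than anything introduced by this diff.

```python
# Hypothetical low-level usage sketch; constructor signatures are assumed from
# the existing pyngraph bindings, not taken from this change.
from ngraph.impl import Shape, Type
from ngraph.impl.op import Gelu, Parameter

data = Parameter(Type.f32, Shape([2, 2]))  # 2x2 float32 placeholder input
node = Gelu(data)                          # wraps ngraph::op::Gelu via the binding above
```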
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once

#include <pybind11/pybind11.h>

namespace py = pybind11;

void regclass_pyngraph_op_Gelu(py::module m);
@@ -54,6 +54,7 @@ void regmodule_pyngraph_op(py::module m_op)
     regclass_pyngraph_op_Equal(m_op);
     regclass_pyngraph_op_Exp(m_op);
     regclass_pyngraph_op_Floor(m_op);
+    regclass_pyngraph_op_Gelu(m_op);
     regclass_pyngraph_op_GetOutputElement(m_op);
     regclass_pyngraph_op_Greater(m_op);
     regclass_pyngraph_op_GreaterEq(m_op);
...
@@ -44,6 +44,7 @@
 #include "pyngraph/ops/floor.hpp"
 #include "pyngraph/ops/fused/clamp.hpp"
 #include "pyngraph/ops/fused/elu.hpp"
+#include "pyngraph/ops/fused/gelu.hpp"
 #include "pyngraph/ops/get_output_element.hpp"
 #include "pyngraph/ops/greater.hpp"
 #include "pyngraph/ops/greater_eq.hpp"
...
@@ -184,6 +184,7 @@ sources = [
     'pyngraph/ops/equal.cpp',
     'pyngraph/ops/exp.cpp',
     'pyngraph/ops/floor.cpp',
+    'pyngraph/ops/fused/gelu.cpp',
     'pyngraph/ops/greater.cpp',
     'pyngraph/ops/greater_eq.cpp',
     'pyngraph/ops/less.cpp',
...
@@ -19,7 +19,7 @@ import ngraph as ng
 from test.ngraph.util import get_runtime


-def test_elu_operator():
+def test_elu_operator_with_parameters():
     runtime = get_runtime()

     data_shape = [2, 2]
@@ -69,6 +69,38 @@ def test_elu_operator_with_scalar():
     assert np.allclose(result, expected)


+def test_gelu_operator_with_parameters():
+    runtime = get_runtime()
+
+    data_value = np.array([[-5, 1], [-2, 3]], dtype=np.float32)
+    data_shape = [2, 2]
+
+    parameter_data = ng.parameter(data_shape, name='Data', dtype=np.float32)
+
+    model = ng.gelu(parameter_data)
+    computation = runtime.computation(model, parameter_data)
+
+    result = computation(data_value)
+    expected = np.array([[-1.4901161e-06, 8.4134471e-01], [-4.5500278e-02, 2.9959502]],
+                        dtype=np.float32)
+    assert np.allclose(result, expected)
+
+
+def test_gelu_operator_with_array():
+    runtime = get_runtime()
+
+    data_value = np.array([[-5, 1], [-2, 3]], dtype=np.float32)
+
+    model = ng.gelu(data_value)
+    computation = runtime.computation(model)
+
+    result = computation()
+    expected = np.array([[-1.4901161e-06, 8.4134471e-01], [-4.5500278e-02, 2.9959502]],
+                        dtype=np.float32)
+    assert np.allclose(result, expected)
+
+
 def test_clamp_operator():
     runtime = get_runtime()
@@ -99,4 +131,5 @@ def test_clamp_operator_with_array():
     result = computation()
     expected = np.clip(data_value, min_value, max_value)
     assert np.allclose(result, expected)
@@ -21,7 +21,9 @@
 #include "ngraph/graph_util.hpp"
 #include "ngraph/op/abs.hpp"
 #include "ngraph/op/add.hpp"
+#include "ngraph/op/all.hpp"
 #include "ngraph/op/and.hpp"
+#include "ngraph/op/any.hpp"
 #include "ngraph/op/broadcast.hpp"
 #include "ngraph/op/ceiling.hpp"
 #include "ngraph/op/concat.hpp"
@@ -42,7 +44,9 @@
 #include "ngraph/op/greater_eq.hpp"
 #include "ngraph/op/less.hpp"
 #include "ngraph/op/less_eq.hpp"
+#include "ngraph/op/max.hpp"
 #include "ngraph/op/maximum.hpp"
+#include "ngraph/op/min.hpp"
 #include "ngraph/op/minimum.hpp"
 #include "ngraph/op/multiply.hpp"
 #include "ngraph/op/negative.hpp"
@@ -65,7 +69,9 @@
 #include "ngraph/pattern/op/label.hpp"
 #include "ngraph/runtime/reference/abs.hpp"
 #include "ngraph/runtime/reference/add.hpp"
+#include "ngraph/runtime/reference/all.hpp"
 #include "ngraph/runtime/reference/and.hpp"
+#include "ngraph/runtime/reference/any.hpp"
 #include "ngraph/runtime/reference/broadcast.hpp"
 #include "ngraph/runtime/reference/ceiling.hpp"
 #include "ngraph/runtime/reference/concat.hpp"
@@ -79,7 +85,9 @@
 #include "ngraph/runtime/reference/greater_eq.hpp"
 #include "ngraph/runtime/reference/less.hpp"
 #include "ngraph/runtime/reference/less_eq.hpp"
+#include "ngraph/runtime/reference/max.hpp"
 #include "ngraph/runtime/reference/maximum.hpp"
+#include "ngraph/runtime/reference/min.hpp"
 #include "ngraph/runtime/reference/minimum.hpp"
 #include "ngraph/runtime/reference/multiply.hpp"
 #include "ngraph/runtime/reference/negate.hpp"
...
@@ -1710,180 +1718,207 @@ void pass::ConstantFolding::construct_constant_reverse()
 }

 template <typename T>
-static shared_ptr<op::Constant> fold_constant_product_helper(shared_ptr<op::Constant> constant,
-                                                             const AxisSet& reduction_axes,
-                                                             const Shape& result_shape)
-{
-    vector<T> out_vec(shape_size(result_shape));
-
-    runtime::reference::product<T>(constant->get_vector<T>().data(),
-                                   out_vec.data(),
-                                   constant->get_output_shape(0),
-                                   result_shape,
-                                   reduction_axes);
-
-    return make_shared<op::Constant>(constant->get_output_element_type(0), result_shape, out_vec);
-}
+static shared_ptr<op::Constant>
+    fold_constant_arithmetic_reduction_helper(shared_ptr<op::Constant> constant,
+                                              shared_ptr<Node> reduction_node)
+{
+    vector<T> out_vec(shape_size(reduction_node->get_shape()));
+
+    if (auto max = dynamic_pointer_cast<op::Max>(reduction_node))
+    {
+        runtime::reference::max<T>(constant->get_vector<T>().data(),
+                                   out_vec.data(),
+                                   constant->get_output_shape(0),
+                                   reduction_node->get_shape(),
+                                   max->get_reduction_axes());
+    }
+    else if (auto min = dynamic_pointer_cast<op::Min>(reduction_node))
+    {
+        runtime::reference::min<T>(constant->get_vector<T>().data(),
+                                   out_vec.data(),
+                                   constant->get_output_shape(0),
+                                   reduction_node->get_shape(),
+                                   min->get_reduction_axes());
+    }
+    else if (auto prod = dynamic_pointer_cast<op::Product>(reduction_node))
+    {
+        runtime::reference::product<T>(constant->get_vector<T>().data(),
+                                       out_vec.data(),
+                                       constant->get_output_shape(0),
+                                       reduction_node->get_shape(),
+                                       prod->get_reduction_axes());
+    }
+    else if (auto sum = dynamic_pointer_cast<op::Sum>(reduction_node))
+    {
+        runtime::reference::sum<T>(constant->get_vector<T>().data(),
+                                   out_vec.data(),
+                                   constant->get_output_shape(0),
+                                   reduction_node->get_shape(),
+                                   sum->get_reduction_axes());
+    }
+    else
+    {
+        NGRAPH_CHECK(false,
+                     "Internal nGraph error: Ops handled in "
+                     "fold_constant_arithmetic_reduction_helper must be consistent with those "
+                     "matched in construct_constant_arithmetic_reduction");
+    }
+
+    return make_shared<op::Constant>(
+        reduction_node->get_output_element_type(0), reduction_node->get_shape(), out_vec);
+}

-static shared_ptr<op::Constant> fold_constant_product(shared_ptr<op::Constant> constant,
-                                                       const AxisSet& reduction_axes,
-                                                       const Shape& result_shape)
+static shared_ptr<op::Constant>
+    fold_constant_arithmetic_reduction(shared_ptr<op::Constant> constant,
+                                       shared_ptr<Node> reduction_node)
 {
     auto& input_element_type = constant->get_output_element_type(0);

     switch (input_element_type.get_type_enum())
     {
     case element::Type_t::undefined:
-        NGRAPH_CHECK(false, "Encountered 'undefined' element type in fold_constant_product");
+        NGRAPH_CHECK(false,
+                     "Encountered 'undefined' element type in fold_constant_arithmetic_reduction");
         break;
     case element::Type_t::dynamic:
-        NGRAPH_CHECK(false, "Encountered 'dynamic' element type in fold_constant_product");
+        NGRAPH_CHECK(false,
+                     "Encountered 'dynamic' element type in fold_constant_arithmetic_reduction");
        break;
     case element::Type_t::boolean:
-        return fold_constant_product_helper<char>(constant, reduction_axes, result_shape);
+        return fold_constant_arithmetic_reduction_helper<char>(constant, reduction_node);
     case element::Type_t::bf16:
-        return fold_constant_product_helper<bfloat16>(constant, reduction_axes, result_shape);
+        return fold_constant_arithmetic_reduction_helper<bfloat16>(constant, reduction_node);
     case element::Type_t::f16:
-        return fold_constant_product_helper<float16>(constant, reduction_axes, result_shape);
+        return fold_constant_arithmetic_reduction_helper<float16>(constant, reduction_node);
     case element::Type_t::f32:
-        return fold_constant_product_helper<float>(constant, reduction_axes, result_shape);
+        return fold_constant_arithmetic_reduction_helper<float>(constant, reduction_node);
     case element::Type_t::f64:
-        return fold_constant_product_helper<double>(constant, reduction_axes, result_shape);
+        return fold_constant_arithmetic_reduction_helper<double>(constant, reduction_node);
     case element::Type_t::i8:
-        return fold_constant_product_helper<int8_t>(constant, reduction_axes, result_shape);
+        return fold_constant_arithmetic_reduction_helper<int8_t>(constant, reduction_node);
     case element::Type_t::i16:
-        return fold_constant_product_helper<int16_t>(constant, reduction_axes, result_shape);
+        return fold_constant_arithmetic_reduction_helper<int16_t>(constant, reduction_node);
     case element::Type_t::i32:
-        return fold_constant_product_helper<int32_t>(constant, reduction_axes, result_shape);
+        return fold_constant_arithmetic_reduction_helper<int32_t>(constant, reduction_node);
     case element::Type_t::i64:
-        return fold_constant_product_helper<int64_t>(constant, reduction_axes, result_shape);
+        return fold_constant_arithmetic_reduction_helper<int64_t>(constant, reduction_node);
     case element::Type_t::u8:
-        return fold_constant_product_helper<uint8_t>(constant, reduction_axes, result_shape);
+        return fold_constant_arithmetic_reduction_helper<uint8_t>(constant, reduction_node);
     case element::Type_t::u16:
-        return fold_constant_product_helper<uint16_t>(constant, reduction_axes, result_shape);
+        return fold_constant_arithmetic_reduction_helper<uint16_t>(constant, reduction_node);
     case element::Type_t::u32:
-        return fold_constant_product_helper<uint32_t>(constant, reduction_axes, result_shape);
+        return fold_constant_arithmetic_reduction_helper<uint32_t>(constant, reduction_node);
     case element::Type_t::u64:
-        return fold_constant_product_helper<uint64_t>(constant, reduction_axes, result_shape);
+        return fold_constant_arithmetic_reduction_helper<uint64_t>(constant, reduction_node);
     }

     NGRAPH_UNREACHABLE("Unexpected switch case");
 }

-void pass::ConstantFolding::construct_constant_product()
+void pass::ConstantFolding::construct_constant_arithmetic_reduction()
 {
-    auto constant_label = make_shared<pattern::op::Label>(
+    auto constant_data_label = make_shared<pattern::op::Label>(
         element::i32, Shape{2, 3, 4}, pattern::has_class<op::Constant>());
-    auto convert_op = make_shared<op::Product>(constant_label, AxisSet{0, 1, 2});
+    auto constant_axes_label =
+        make_shared<pattern::op::Label>(element::i64, Shape{2}, pattern::has_class<op::Constant>());
+    auto is_supported_reduction = [](std::shared_ptr<Node> n) {
+        return (pattern::has_class<op::Max>()(n) || pattern::has_class<op::Min>()(n) ||
+                pattern::has_class<op::Product>()(n) || pattern::has_class<op::Sum>()(n));
+    };
+    auto reduction =
+        std::make_shared<pattern::op::Any>(element::i32,
+                                           Shape{2},
+                                           is_supported_reduction,
+                                           NodeVector{constant_data_label, constant_axes_label});

-    auto constant_product_callback = [constant_label](pattern::Matcher& m) {
-        NGRAPH_DEBUG << "In callback for constant_product_callback against node = "
+    auto constant_arithmetic_reduction_callback = [constant_data_label](pattern::Matcher& m) {
+        NGRAPH_DEBUG << "In callback for constant_arithmetic_reduction_callback against node = "
                      << m.get_match_root()->get_name();

         auto pattern_map = m.get_pattern_map();

-        auto constant_match = static_pointer_cast<op::Constant>(pattern_map[constant_label]);
-        auto product_match = static_pointer_cast<op::Product>(m.get_match_root());
+        auto constant_match = static_pointer_cast<op::Constant>(pattern_map[constant_data_label]);
+        auto reduction_match = m.get_match_root();

-        replace_node(m.get_match_root(),
-                     fold_constant_product(constant_match,
-                                           product_match->get_reduction_axes(),
-                                           product_match->get_output_shape(0)));
+        replace_node(reduction_match,
+                     fold_constant_arithmetic_reduction(constant_match, reduction_match));
         return true;
     };

-    auto convert_matcher =
-        make_shared<pattern::Matcher>(convert_op, "ConstantFolding.ConstantProduct");
-    this->add_matcher(convert_matcher, constant_product_callback, all_pass_property_off);
+    auto arithmetic_reduction_matcher =
+        make_shared<pattern::Matcher>(reduction, "ConstantFolding.ConstantArithmeticReduction");
+    this->add_matcher(arithmetic_reduction_matcher,
+                      constant_arithmetic_reduction_callback,
+                      all_pass_property_off);
 }

-// TODO(amprocte): Find a way to reduce duplication with Product. (The fact
-// that we bottom out in a reference call makes it a bit tricky.)
-template <typename T>
-static shared_ptr<op::Constant> fold_constant_sum_helper(shared_ptr<op::Constant> constant,
-                                                         const AxisSet& reduction_axes,
-                                                         const Shape& result_shape)
-{
-    vector<T> out_vec(shape_size(result_shape));
-
-    runtime::reference::sum<T>(constant->get_vector<T>().data(),
-                               out_vec.data(),
-                               constant->get_output_shape(0),
-                               result_shape,
-                               reduction_axes);
-
-    return make_shared<op::Constant>(constant->get_output_element_type(0), result_shape, out_vec);
-}
-
-static shared_ptr<op::Constant> fold_constant_sum(shared_ptr<op::Constant> constant,
-                                                  const AxisSet& reduction_axes,
-                                                  const Shape& result_shape)
-{
-    auto& input_element_type = constant->get_output_element_type(0);
-
-    switch (input_element_type.get_type_enum())
-    {
-    case element::Type_t::undefined:
-        NGRAPH_CHECK(false, "Encountered 'undefined' element type in fold_constant_sum");
-        break;
-    case element::Type_t::dynamic:
-        NGRAPH_CHECK(false, "Encountered 'dynamic' element type in fold_constant_sum");
-        break;
-    case element::Type_t::boolean:
-        return fold_constant_sum_helper<char>(constant, reduction_axes, result_shape);
-    case element::Type_t::bf16:
-        return fold_constant_sum_helper<bfloat16>(constant, reduction_axes, result_shape);
-    case element::Type_t::f16:
-        return fold_constant_sum_helper<float16>(constant, reduction_axes, result_shape);
-    case element::Type_t::f32:
-        return fold_constant_sum_helper<float>(constant, reduction_axes, result_shape);
-    case element::Type_t::f64:
-        return fold_constant_sum_helper<double>(constant, reduction_axes, result_shape);
-    case element::Type_t::i8:
-        return fold_constant_sum_helper<int8_t>(constant, reduction_axes, result_shape);
-    case element::Type_t::i16:
-        return fold_constant_sum_helper<int16_t>(constant, reduction_axes, result_shape);
-    case element::Type_t::i32:
-        return fold_constant_sum_helper<int32_t>(constant, reduction_axes, result_shape);
-    case element::Type_t::i64:
-        return fold_constant_sum_helper<int64_t>(constant, reduction_axes, result_shape);
-    case element::Type_t::u8:
-        return fold_constant_sum_helper<uint8_t>(constant, reduction_axes, result_shape);
-    case element::Type_t::u16:
-        return fold_constant_sum_helper<uint16_t>(constant, reduction_axes, result_shape);
-    case element::Type_t::u32:
-        return fold_constant_sum_helper<uint32_t>(constant, reduction_axes, result_shape);
-    case element::Type_t::u64:
-        return fold_constant_sum_helper<uint64_t>(constant, reduction_axes, result_shape);
-    }
-
-    NGRAPH_UNREACHABLE("Unexpected switch case");
-}
+static shared_ptr<op::Constant> fold_constant_logical_reduction(shared_ptr<op::Constant> constant,
+                                                                shared_ptr<Node> reduction_node)
+{
+    vector<char> out_vec(shape_size(reduction_node->get_shape()));
+
+    if (auto all = dynamic_pointer_cast<::ngraph::op::All>(reduction_node))
+    {
+        runtime::reference::all(constant->get_vector<char>().data(),
+                                out_vec.data(),
+                                constant->get_output_shape(0),
+                                reduction_node->get_shape(),
+                                all->get_reduction_axes());
+    }
+    else if (auto any = dynamic_pointer_cast<::ngraph::op::Any>(reduction_node))
+    {
+        runtime::reference::any(constant->get_vector<char>().data(),
+                                out_vec.data(),
+                                constant->get_output_shape(0),
+                                reduction_node->get_shape(),
+                                any->get_reduction_axes());
+    }
+    else
+    {
+        NGRAPH_CHECK(false,
+                     "Internal nGraph error: Ops handled in "
+                     "fold_constant_logical_reduction must be consistent with those "
+                     "matched in construct_constant_logical_reduction");
+    }
+
+    return make_shared<op::Constant>(
+        reduction_node->get_output_element_type(0), reduction_node->get_shape(), out_vec);
+}

-void pass::ConstantFolding::construct_constant_sum()
+void pass::ConstantFolding::construct_constant_logical_reduction()
 {
-    auto constant_label = make_shared<pattern::op::Label>(
-        element::i32, Shape{2, 3, 4}, pattern::has_class<op::Constant>());
-    auto convert_op = make_shared<op::Sum>(constant_label, AxisSet{0, 1, 2});
+    auto constant_data_label = make_shared<pattern::op::Label>(
+        element::boolean, Shape{2, 3, 4}, pattern::has_class<op::Constant>());
+    auto constant_axes_label =
+        make_shared<pattern::op::Label>(element::i64, Shape{2}, pattern::has_class<op::Constant>());
+    auto is_supported_reduction = [](std::shared_ptr<Node> n) {
+        return (pattern::has_class<::ngraph::op::All>()(n) ||
+                pattern::has_class<::ngraph::op::Any>()(n));
+    };
+    auto reduction =
+        std::make_shared<pattern::op::Any>(element::i32,
+                                           Shape{2},
+                                           is_supported_reduction,
+                                           NodeVector{constant_data_label, constant_axes_label});

-    auto constant_sum_callback = [constant_label](pattern::Matcher& m) {
-        NGRAPH_DEBUG << "In callback for constant_sum_callback against node = "
+    auto constant_logical_reduction_callback = [constant_data_label](pattern::Matcher& m) {
+        NGRAPH_DEBUG << "In callback for constant_logical_reduction_callback against node = "
                      << m.get_match_root()->get_name();

         auto pattern_map = m.get_pattern_map();

-        auto constant_match = static_pointer_cast<op::Constant>(pattern_map[constant_label]);
-        auto sum_match = static_pointer_cast<op::Sum>(m.get_match_root());
+        auto constant_match = static_pointer_cast<op::Constant>(pattern_map[constant_data_label]);
+        auto reduction_match = m.get_match_root();

-        replace_node(m.get_match_root(),
-                     fold_constant_sum(constant_match,
-                                       sum_match->get_reduction_axes(),
-                                       sum_match->get_output_shape(0)));
+        replace_node(reduction_match,
+                     fold_constant_logical_reduction(constant_match, reduction_match));
         return true;
     };

-    auto convert_matcher = make_shared<pattern::Matcher>(convert_op, "ConstantFolding.ConstantSum");
-    this->add_matcher(convert_matcher, constant_sum_callback, all_pass_property_off);
+    auto logical_reduction_matcher =
+        make_shared<pattern::Matcher>(reduction, "ConstantFolding.ConstantLogicalReduction");
+    this->add_matcher(
+        logical_reduction_matcher, constant_logical_reduction_callback, all_pass_property_off);
 }

 template <typename T>
...
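Conceptually, the rewritten pass folds a supported reduction over a constant input by evaluating it ahead of time and substituting a new constant with the reduction node's shape. A rough NumPy analogue of the dispatch (illustration only, not the nGraph API):

```python
# Sketch of what fold_constant_arithmetic_reduction / fold_constant_logical_reduction
# compute for each supported op kind; names and structure here are illustrative.
import numpy as np

def fold_reduction(data, axes, kind):
    axes = tuple(axes)
    if kind == 'Max':
        return np.max(data, axis=axes)
    if kind == 'Min':
        return np.min(data, axis=axes)
    if kind == 'Product':
        return np.prod(data, axis=axes)
    if kind == 'Sum':
        return np.sum(data, axis=axes)
    if kind == 'All':  # logical reduction over boolean data
        return np.all(data, axis=axes)
    if kind == 'Any':
        return np.any(data, axis=axes)
    raise ValueError('unsupported reduction: ' + kind)

print(fold_reduction(np.arange(1, 10).reshape(3, 3), (1,), 'Product'))  # [  6 120 504]
```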
@@ -43,8 +43,8 @@ public:
         CONVERT,
         SHAPE_OF,
         REVERSE,
-        PRODUCT,
-        SUM,
+        ARITHMETIC_REDUCTION,
+        LOGICAL_REDUCTION,
         CONCAT,
         GATHER,
         SLICE,
@@ -70,8 +70,8 @@ public:
         construct_constant_convert();
         construct_constant_shape_of();
         construct_constant_reverse();
-        construct_constant_product();
-        construct_constant_sum();
+        construct_constant_arithmetic_reduction();
+        construct_constant_logical_reduction();
         construct_constant_concat();
         construct_constant_gather();
         construct_constant_slice();
@@ -104,8 +104,12 @@ public:
            case CFTransformations::CONVERT: construct_constant_convert(); break;
            case CFTransformations::SHAPE_OF: construct_constant_shape_of(); break;
            case CFTransformations::REVERSE: construct_constant_reverse(); break;
-           case CFTransformations::PRODUCT: construct_constant_product(); break;
-           case CFTransformations::SUM: construct_constant_sum(); break;
+           case CFTransformations::ARITHMETIC_REDUCTION:
+               construct_constant_arithmetic_reduction();
+               break;
+           case CFTransformations::LOGICAL_REDUCTION:
+               construct_constant_logical_reduction();
+               break;
            case CFTransformations::CONCAT: construct_constant_concat(); break;
            case CFTransformations::GATHER: construct_constant_gather(); break;
            case CFTransformations::SLICE: construct_constant_slice(); break;
@@ -130,8 +134,8 @@ private:
     void construct_constant_convert();
     void construct_constant_shape_of();
     void construct_constant_reverse();
-    void construct_constant_product();
-    void construct_constant_sum();
+    void construct_constant_arithmetic_reduction();
+    void construct_constant_logical_reduction();
     void construct_constant_concat();
     void construct_constant_gather();
     void construct_constant_slice();
...
@@ -36,7 +36,7 @@ namespace ngraph
                    const AxisSet& reduction_axes)
                {
                    T minval = std::numeric_limits<T>::has_infinity
-                                  ? -std::numeric_limits<T>::infinity()
+                                  ? T(-std::numeric_limits<T>::infinity())
                                  : std::numeric_limits<T>::min();

                    CoordinateTransform output_transform(out_shape);
...
@@ -1540,8 +1540,7 @@ NGRAPH_TEST(${BACKEND_NAME}, group_conv_transpose)
         -0.0270785f, -0.00680824f, -0.06650258f, 0.08004665f, 0.07918708f, -0.0724144f,
         0.06256775f, -0.17838378f, -0.18863615f, 0.20064656f, 0.133717f, -0.06876295f,
         -0.06398046f, -0.00864975f, 0.19289537f, -0.01490572f, -0.13673618f, 0.01949645f});
-    test_case.set_tolerance(3);
-    test_case.run();
+    test_case.run(DEFAULT_FLOAT_TOLERANCE_BITS + 1);
 }

 NGRAPH_TEST(${BACKEND_NAME}, group_conv_transpose_output_shape)
...
@@ -461,6 +461,110 @@ TEST(constant_folding, const_sum)
     ASSERT_EQ(values_expected, values_out);
 }

+TEST(constant_folding, const_max)
+{
+    Shape input_shape{3, 3};
+
+    vector<int32_t> values_in{1, 2, 3, 4, 5, 6, 7, 8, 9};
+    auto constant = op::Constant::create(element::i32, input_shape, values_in);
+    auto convert = make_shared<op::Max>(constant, AxisSet{1});
+    auto f = make_shared<Function>(convert, ParameterVector{});
+
+    pass::Manager pass_manager;
+    pass_manager.register_pass<pass::ConstantFolding>();
+    pass_manager.run_passes(f);
+
+    ASSERT_EQ(count_ops_of_type<op::Max>(f), 0);
+    ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
+
+    auto new_const =
+        std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
+    ASSERT_TRUE(new_const);
+    auto values_out = new_const->get_vector<int32_t>();
+
+    vector<int32_t> values_expected{3, 6, 9};
+
+    ASSERT_EQ(values_expected, values_out);
+}
+
+TEST(constant_folding, const_min)
+{
+    Shape input_shape{3, 3};
+
+    vector<int32_t> values_in{1, 2, 3, 4, 5, 6, 7, 8, 9};
+    auto constant = op::Constant::create(element::i32, input_shape, values_in);
+    auto convert = make_shared<op::Min>(constant, AxisSet{1});
+    auto f = make_shared<Function>(convert, ParameterVector{});
+
+    pass::Manager pass_manager;
+    pass_manager.register_pass<pass::ConstantFolding>();
+    pass_manager.run_passes(f);
+
+    ASSERT_EQ(count_ops_of_type<op::Min>(f), 0);
+    ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
+
+    auto new_const =
+        std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
+    ASSERT_TRUE(new_const);
+    auto values_out = new_const->get_vector<int32_t>();
+
+    vector<int32_t> values_expected{1, 4, 7};
+
+    ASSERT_EQ(values_expected, values_out);
+}
+
+TEST(constant_folding, const_all)
+{
+    Shape input_shape{3, 3};
+
+    vector<char> values_in{0, 1, 1, 0, 1, 0, 1, 1, 1};
+    auto constant = op::Constant::create(element::boolean, input_shape, values_in);
+    auto convert = make_shared<op::All>(constant, AxisSet{1});
+    auto f = make_shared<Function>(convert, ParameterVector{});
+
+    pass::Manager pass_manager;
+    pass_manager.register_pass<pass::ConstantFolding>();
+    pass_manager.run_passes(f);
+
+    ASSERT_EQ(count_ops_of_type<op::All>(f), 0);
+    ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
+
+    auto new_const =
+        std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
+    ASSERT_TRUE(new_const);
+    auto values_out = new_const->get_vector<char>();
+
+    vector<char> values_expected{0, 0, 1};
+
+    ASSERT_EQ(values_expected, values_out);
+}
+
+TEST(constant_folding, const_any)
+{
+    Shape input_shape{3, 3};
+
+    vector<char> values_in{1, 0, 0, 1, 0, 1, 0, 0, 0};
+    auto constant = op::Constant::create(element::boolean, input_shape, values_in);
+    auto convert = make_shared<op::Any>(constant, AxisSet{1});
+    auto f = make_shared<Function>(convert, ParameterVector{});
+
+    pass::Manager pass_manager;
+    pass_manager.register_pass<pass::ConstantFolding>();
+    pass_manager.run_passes(f);
+
+    ASSERT_EQ(count_ops_of_type<op::Any>(f), 0);
+    ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
+
+    auto new_const =
+        std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
+    ASSERT_TRUE(new_const);
+    auto values_out = new_const->get_vector<char>();
+
+    vector<char> values_expected{1, 1, 0};
+
+    ASSERT_EQ(values_expected, values_out);
+}
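The expected vectors in the four new tests can be reproduced informally with NumPy reductions over axis 1 of the 3x3 inputs:

```python
import numpy as np

values = np.arange(1, 10).reshape(3, 3)
print(np.max(values, axis=1))  # [3 6 9]  -> const_max
print(np.min(values, axis=1))  # [1 4 7]  -> const_min

all_in = np.array([0, 1, 1, 0, 1, 0, 1, 1, 1], dtype=bool).reshape(3, 3)
any_in = np.array([1, 0, 0, 1, 0, 1, 0, 0, 0], dtype=bool).reshape(3, 3)
print(np.all(all_in, axis=1).astype(np.int8))  # [0 0 1] -> const_all
print(np.any(any_in, axis=1).astype(np.int8))  # [1 1 0] -> const_any
```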
 TEST(constant_folding, const_concat)
 {
     auto constant0 =
...
@@ -104,8 +104,7 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, model_lstm_fwd_with_clip)
     // We have to enlarge tolerance bits to 3 - it's only one bit more than default value.
     // The discrepancies may occur at most on 7th decimal position.
-    test_case.set_tolerance(3);
-    test_case.run();
+    test_case.run(DEFAULT_FLOAT_TOLERANCE_BITS + 1);
 }

 NGRAPH_TEST(onnx_${BACKEND_NAME}, model_lstm_fwd_mixed_seq)
@@ -144,8 +143,7 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, model_lstm_fwd_mixed_seq)
     // We have to enlarge tolerance bits to 3 - it's only one bit more than default value.
     // The discrepancies may occur at most on 7th decimal position.
-    test_case.set_tolerance(3);
-    test_case.run();
+    test_case.run(DEFAULT_FLOAT_TOLERANCE_BITS + 1);
 }

 NGRAPH_TEST(onnx_${BACKEND_NAME}, model_lstm_fwd_hardsigmoid_activation)
@@ -201,8 +199,7 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, model_lstm_fwd_hardsigmoid_activation)
     test_case.add_expected_output<float>(Shape{1, 1, 2}, {0.19017234f, 0.00356848f});

     // The discrepancies occur at most at 18th mantissa bit - 8th decimal position.
-    test_case.set_tolerance(6);
-    test_case.run();
+    test_case.run(DEFAULT_FLOAT_TOLERANCE_BITS + 4);
 }

 NGRAPH_TEST(onnx_${BACKEND_NAME}, model_lstm_fwd_large_batch_no_clip)
@@ -307,8 +304,7 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, model_lstm_bdir_short_input_seq)
     test_case.add_expected_output<float>(Shape{2, 1, 2},
                                          {-0.0251062f, 0.0561262f, -0.0318928f, 0.0762679f});

-    test_case.set_tolerance(DEFAULT_FLOAT_TOLERANCE_BITS + 3);
-    test_case.run();
+    test_case.run(DEFAULT_FLOAT_TOLERANCE_BITS + 3);
 }

 NGRAPH_TEST(onnx_${BACKEND_NAME}, model_lstm_mixed_seq_reverse)
@@ -353,6 +349,5 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, model_lstm_mixed_seq_reverse)
         Shape{1, 2, 3},
         {0.52497941f, 0.54983425f, 0.5744428f, 1.34960834f, 1.54772296f, 1.65633056f});

-    test_case.set_tolerance(DEFAULT_FLOAT_TOLERANCE_BITS + 1);
-    test_case.run();
+    test_case.run(DEFAULT_FLOAT_TOLERANCE_BITS + 1);
 }
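For context on the tolerance argument now passed to `run()`: it counts low-order mantissa bits that are allowed to differ between expected and computed values. A rough conversion to relative error for float32 (a ballpark illustration, not the actual nGraph comparison routine) is sketched below.

```python
# Ballpark only: tolerating N of a float32's 23 stored mantissa bits corresponds
# to a relative error of roughly 2**(N - 23). This is an assumption for intuition,
# not a reimplementation of nGraph's float comparison.
def approx_relative_error(tolerance_bits, stored_mantissa_bits=23):
    return 2.0 ** (tolerance_bits - stored_mantissa_bits)

for bits in (2, 3, 6):
    print(bits, approx_relative_error(bits))
# 3 bits -> ~9.5e-07, roughly the "7th decimal position" mentioned in the comments above
```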
@@ -19,8 +19,9 @@
 #include "gtest/gtest.h"
 #include "ngraph/assertion.hpp"

-void ngraph::test::NgraphTestCase::run()
+void ngraph::test::NgraphTestCase::run(size_t tolerance_bits)
 {
+    m_tolerance_bits = tolerance_bits;
     const auto& function_results = m_function->get_results();
     NGRAPH_CHECK(m_expected_outputs.size() == function_results.size(),
                  "Expected number of outputs is different from the function's number of results.");
@@ -52,12 +53,6 @@ void ngraph::test::NgraphTestCase::run()
     }
 }

-ngraph::test::NgraphTestCase& ngraph::test::NgraphTestCase::set_tolerance(int tolerance_bits)
-{
-    m_tolerance_bits = tolerance_bits;
-    return *this;
-}
-
 ngraph::test::NgraphTestCase& ngraph::test::NgraphTestCase::dump_results(bool dump)
 {
     m_dump_results = dump;
...
@@ -38,8 +38,6 @@ namespace ngraph
            {
            }

-           NgraphTestCase& set_tolerance(int tolerance_bits);
-
            /// \brief Makes the test case print the expected and computed values to the console. This should only be used for debugging purposes.
            ///
            /// Just before the assertion is done, the current test case will gather expected and computed values,
@@ -130,7 +128,7 @@ namespace ngraph
                add_expected_output(expected_shape, value);
            }

-           void run();
+           void run(size_t tolerance_bits = DEFAULT_FLOAT_TOLERANCE_BITS);

        private:
            template <typename T>
...