Commit fc216f39 authored by Adam Procter's avatar Adam Procter Committed by Robert Kimball

"Any" and "All" ops (#2217)

* Skip --exclude-libs linker flag on macOS

* Change test to if(LINUX)

* Add "Any" op and AnyAllReplacement pass

* Add AnyAllReplacement to IGPU backend

* Stub (error-out) handlers for GPU and INTELGPU

* Add 'All' op

* Add AnyAllInsertion pass, deprecate deprecable ops, add stubs for INTELGPU

* Add failing unit tests to INTELGPU manifest

* Reduce boilerplate

* Reduce more boilerplate

* Add static keywords
parent 16d88a7f
.. all.rst:
###
All
###
.. code-block:: cpp
All // Boolean "all" reduction operation.
Description
===========
Reduces a tensor of booleans, eliminating the specified reduction axes by taking the logical conjunction (i.e., "AND-reduce").
Inputs
------
+-----------------+------------------------------+--------------------------------+
| Name | Element Type | Shape |
+=================+==============================+================================+
| ``arg`` | ``ngraph::element::boolean`` | Any |
+-----------------+------------------------------+--------------------------------+
Attributes
----------
+--------------------+--------------------------------------------------------------------+
| Name | Description |
+====================+====================================================================+
| ``reduction_axes`` | The axis positions (0-based) on which to calculate the conjunction |
+--------------------+--------------------------------------------------------------------+
Outputs
-------
+-----------------+-------------------------+---------------------------------------------------+
| Name | Element Type | Shape |
+=================+=========================+===================================================+
| ``output`` | Same as ``arg`` | Same as ``arg``, with ``reduction_axes`` removed. |
+-----------------+-------------------------+---------------------------------------------------+
C++ Interface
=============
.. doxygenclass:: ngraph::op::All
:project: ngraph
:members:
.. any.rst:
###
Any
###
.. code-block:: cpp
Any // Boolean "any" reduction operation.
Description
===========
Reduces a tensor of booleans, eliminating the specified reduction axes by taking the logical disjunction (i.e., "OR-reduce").
Inputs
------
+-----------------+------------------------------+--------------------------------+
| Name | Element Type | Shape |
+=================+==============================+================================+
| ``arg`` | ``ngraph::element::boolean`` | Any |
+-----------------+------------------------------+--------------------------------+
Attributes
----------
+--------------------+--------------------------------------------------------------------+
| Name | Description |
+====================+====================================================================+
| ``reduction_axes`` | The axis positions (0-based) on which to calculate the disjunction |
+--------------------+--------------------------------------------------------------------+
Outputs
-------
+-----------------+-------------------------+---------------------------------------------------+
| Name | Element Type | Shape |
+=================+=========================+===================================================+
| ``output`` | Same as ``arg`` | Same as ``arg``, with ``reduction_axes`` removed. |
+-----------------+-------------------------+---------------------------------------------------+
C++ Interface
=============
.. doxygenclass:: ngraph::op::Any
:project: ngraph
:members:
...@@ -50,8 +50,10 @@ Not currently a comprehensive list. ...@@ -50,8 +50,10 @@ Not currently a comprehensive list.
* :doc:`abs` * :doc:`abs`
* :doc:`acos` * :doc:`acos`
* :doc:`add` * :doc:`add`
* :doc:`all`
* :doc:`allreduce` * :doc:`allreduce`
* :doc:`and` * :doc:`and`
* :doc:`any`
* :doc:`asin` * :doc:`asin`
* :doc:`atan` * :doc:`atan`
* :doc:`avg_pool` * :doc:`avg_pool`
...@@ -119,8 +121,10 @@ Not currently a comprehensive list. ...@@ -119,8 +121,10 @@ Not currently a comprehensive list.
abs.rst abs.rst
acos.rst acos.rst
add.rst add.rst
all.rst
allreduce.rst allreduce.rst
and.rst and.rst
any.rst
asin.rst asin.rst
atan.rst atan.rst
avg_pool.rst avg_pool.rst
......
...@@ -39,8 +39,10 @@ set (SRC ...@@ -39,8 +39,10 @@ set (SRC
op/abs.cpp op/abs.cpp
op/acos.cpp op/acos.cpp
op/add.cpp op/add.cpp
op/all.cpp
op/allreduce.cpp op/allreduce.cpp
op/and.cpp op/and.cpp
op/any.cpp
op/argmin.cpp op/argmin.cpp
op/argmax.cpp op/argmax.cpp
op/asin.cpp op/asin.cpp
...@@ -122,8 +124,11 @@ set (SRC ...@@ -122,8 +124,11 @@ set (SRC
op/util/binary_elementwise_comparison.cpp op/util/binary_elementwise_comparison.cpp
op/util/binary_elementwise_logical.cpp op/util/binary_elementwise_logical.cpp
op/util/index_reduction.cpp op/util/index_reduction.cpp
op/util/logical_reduction.cpp
op/util/unary_elementwise_arithmetic.cpp op/util/unary_elementwise_arithmetic.cpp
partial_shape.cpp partial_shape.cpp
pass/any_all_insertion.cpp
pass/any_all_replacement.cpp
pass/assign_placement.cpp pass/assign_placement.cpp
pass/algebraic_simplification.cpp pass/algebraic_simplification.cpp
pass/common_function_collection.cpp pass/common_function_collection.cpp
......
...@@ -84,9 +84,11 @@ if (NGRAPH_GPU_ENABLE OR (NGRAPH_CPU_ENABLE AND NOT NGRAPH_DEX_ONLY)) ...@@ -84,9 +84,11 @@ if (NGRAPH_GPU_ENABLE OR (NGRAPH_CPU_ENABLE AND NOT NGRAPH_DEX_ONLY))
add_dependencies(codegen header_resource) add_dependencies(codegen header_resource)
if (NGRAPH_CPU_ENABLE) if (NGRAPH_CPU_ENABLE)
add_dependencies(codegen libmkldnn libeigen) add_dependencies(codegen libmkldnn libeigen)
# --exclude-libs=ALL prevents symbols from statically-linked libraries (LLVM, in this case) if(LINUX)
# from being automatically exported # --exclude-libs=ALL prevents symbols from statically-linked libraries (LLVM, in this case)
set_property(TARGET codegen APPEND PROPERTY LINK_FLAGS "-Wl,--exclude-libs=ALL") # from being automatically exported
set_property(TARGET codegen APPEND PROPERTY LINK_FLAGS "-Wl,--exclude-libs=ALL")
endif()
endif() endif()
target_include_directories(codegen SYSTEM PRIVATE ${CMAKE_BINARY_DIR}) target_include_directories(codegen SYSTEM PRIVATE ${CMAKE_BINARY_DIR})
target_link_libraries(codegen PRIVATE libllvm ngraph) target_link_libraries(codegen PRIVATE libllvm ngraph)
......
...@@ -60,8 +60,10 @@ ...@@ -60,8 +60,10 @@
#include "ngraph/op/abs.hpp" #include "ngraph/op/abs.hpp"
#include "ngraph/op/acos.hpp" #include "ngraph/op/acos.hpp"
#include "ngraph/op/add.hpp" #include "ngraph/op/add.hpp"
#include "ngraph/op/all.hpp"
#include "ngraph/op/allreduce.hpp" #include "ngraph/op/allreduce.hpp"
#include "ngraph/op/and.hpp" #include "ngraph/op/and.hpp"
#include "ngraph/op/any.hpp"
#include "ngraph/op/argmax.hpp" #include "ngraph/op/argmax.hpp"
#include "ngraph/op/argmin.hpp" #include "ngraph/op/argmin.hpp"
#include "ngraph/op/asin.hpp" #include "ngraph/op/asin.hpp"
......
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/op/all.hpp"
using namespace std;
using namespace ngraph;
op::All::All(const shared_ptr<Node>& arg, const AxisSet& reduction_axes)
: LogicalReduction("All", arg, reduction_axes)
{
constructor_validate_and_infer_types();
}
shared_ptr<Node> op::All::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<All>(new_args.at(0), m_reduction_axes);
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/axis_set.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/op/op.hpp"
#include "ngraph/op/util/logical_reduction.hpp"
namespace ngraph
{
namespace op
{
/// \brief Logical "all" reduction operation.
class All : public util::LogicalReduction
{
public:
/// \brief Constructs an "all" reduction operation.
///
/// \param arg The tensor view to be reduced.
/// \param reduction_axes The axis positions (0-based) to be eliminated.
All(const std::shared_ptr<Node>& arg, const AxisSet& reduction_axes);
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
/// \return The default value for All.
virtual std::shared_ptr<Node> get_default_value() const override
{
return ngraph::make_constant_from_string("1", get_element_type(), get_shape());
}
};
}
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/op/any.hpp"
using namespace std;
using namespace ngraph;
op::Any::Any(const shared_ptr<Node>& arg, const AxisSet& reduction_axes)
: LogicalReduction("Any", arg, reduction_axes)
{
constructor_validate_and_infer_types();
}
shared_ptr<Node> op::Any::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<Any>(new_args.at(0), m_reduction_axes);
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/axis_set.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/op/op.hpp"
#include "ngraph/op/util/logical_reduction.hpp"
namespace ngraph
{
namespace op
{
/// \brief Logical "any" reduction operation.
class Any : public util::LogicalReduction
{
public:
/// \brief Constructs an "any" reduction operation.
///
/// \param arg The tensor view to be reduced.
/// \param reduction_axes The axis positions (0-based) to be eliminated.
Any(const std::shared_ptr<Node>& arg, const AxisSet& reduction_axes);
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
/// \return The default value for Any.
virtual std::shared_ptr<Node> get_default_value() const override
{
return ngraph::make_constant_from_string("0", get_element_type(), get_shape());
}
};
}
}
...@@ -48,8 +48,10 @@ ...@@ -48,8 +48,10 @@
NGRAPH_OP(Abs, ngraph::op) NGRAPH_OP(Abs, ngraph::op)
NGRAPH_OP(Acos, ngraph::op) NGRAPH_OP(Acos, ngraph::op)
NGRAPH_OP(Add, ngraph::op) NGRAPH_OP(Add, ngraph::op)
NGRAPH_OP(All, ngraph::op)
NGRAPH_OP(AllReduce, ngraph::op) NGRAPH_OP(AllReduce, ngraph::op)
NGRAPH_OP(And, ngraph::op) NGRAPH_OP(And, ngraph::op)
NGRAPH_OP(Any, ngraph::op)
NGRAPH_OP(ArgMax, ngraph::op) NGRAPH_OP(ArgMax, ngraph::op)
NGRAPH_OP(ArgMin, ngraph::op) NGRAPH_OP(ArgMin, ngraph::op)
NGRAPH_OP(Asin, ngraph::op) NGRAPH_OP(Asin, ngraph::op)
......
...@@ -23,7 +23,9 @@ namespace ngraph ...@@ -23,7 +23,9 @@ namespace ngraph
{ {
namespace op namespace op
{ {
/// \brief Tensor reduction operation. /// \brief (DEPRECATED) Tensor reduction operation.
///
/// WARNING: This op is deprecated and will be removed in a future version of nGraph.
/// ///
/// Element-wise reduces the input tensor, eliminating the specified reduction axes, given a reduction function that maps two scalars to a scalar. /// Element-wise reduces the input tensor, eliminating the specified reduction axes, given a reduction function that maps two scalars to a scalar.
/// For example, if the reduction function \f$f(x,y) = x+y\f$: /// For example, if the reduction function \f$f(x,y) = x+y\f$:
......
...@@ -22,7 +22,9 @@ namespace ngraph ...@@ -22,7 +22,9 @@ namespace ngraph
{ {
namespace op namespace op
{ {
/// \brief Windowed reduction operation. /// \brief (DEPRECATED) Windowed reduction operation.
///
/// WARNING: This op is deprecated and will be removed in a future version of nGraph.
/// ///
/// Slides a window of user-defined shape, with user-defined strides, over the tensor and produces for each window position the result obtained by /// Slides a window of user-defined shape, with user-defined strides, over the tensor and produces for each window position the result obtained by
/// reducing the tensors in the window to a scalar, using the user-supplied reduction function. /// reducing the tensors in the window to a scalar, using the user-supplied reduction function.
......
...@@ -22,7 +22,9 @@ namespace ngraph ...@@ -22,7 +22,9 @@ namespace ngraph
{ {
namespace op namespace op
{ {
/// \brief Select-and-scatter operation. /// \brief (DEPRECATED) Select-and-scatter operation.
///
/// WARNING: This op is deprecated and will be removed in a future version of nGraph.
/// ///
/// Select-and-scatter takes three inputs, all of which must have the same element type \f$E\f$: /// Select-and-scatter takes three inputs, all of which must have the same element type \f$E\f$:
/// ///
......
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/op/util/logical_reduction.hpp"
using namespace std;
using namespace ngraph;
op::util::LogicalReduction::LogicalReduction(const std::string& node_type,
const std::shared_ptr<Node>& arg,
const AxisSet& reduction_axes)
: Op(node_type, check_single_output_args({arg}))
, m_reduction_axes(reduction_axes)
{
}
void op::util::LogicalReduction::validate_and_infer_types()
{
auto input_shape = get_input_partial_shape(0);
auto input_rank = input_shape.rank();
PartialShape result_shape{PartialShape::dynamic()};
if (input_rank.is_static())
{
std::vector<Dimension> dims;
for (auto axis : m_reduction_axes)
{
NODE_VALIDATION_ASSERT(this, axis < size_t(input_rank))
<< "Reduction axis (" << axis << ") is out of bounds "
<< "(argument shape: " << input_shape << ", reduction axes: " << m_reduction_axes
<< ")";
}
for (size_t i = 0; i < size_t(input_rank); i++)
{
if (m_reduction_axes.count(i) == 0)
{
dims.push_back(input_shape[i]);
}
}
result_shape = PartialShape(dims);
}
NODE_VALIDATION_ASSERT(this, get_input_element_type(0).compatible(element::boolean))
<< "Input element type must be boolean.";
set_output_type(0, element::boolean, result_shape);
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/op/op.hpp"
namespace ngraph
{
namespace op
{
namespace util
{
/// \brief Abstract base class for logical reduction operations, i.e., operations where chosen axes of the input tensors
/// are eliminated (reduced out) by repeated application of a particular binary logical operation.
class LogicalReduction : public Op
{
public:
/// \brief Constructs a logical reduction operation.
///
/// \param arg Node that produces the first input tensor.
/// \param reduction_axes The axis positions (0-based) to be eliminated.
LogicalReduction(const std::string& node_type,
const std::shared_ptr<Node>& arg,
const AxisSet& reduction_axes);
void validate_and_infer_types() override;
/// \return The axis positions (0-based) to be eliminated through reduction.
const AxisSet& get_reduction_axes() const { return m_reduction_axes; }
protected:
AxisSet m_reduction_axes;
};
}
}
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "any_all_insertion.hpp"
#include "ngraph/op/all.hpp"
#include "ngraph/op/and.hpp"
#include "ngraph/op/any.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/or.hpp"
#include "ngraph/op/reduce.hpp"
#include "ngraph/util.hpp"
using namespace ngraph;
static bool is_boolean_scalar_constant_with_val(std::shared_ptr<ngraph::Node> node, bool val)
{
auto k = std::dynamic_pointer_cast<op::Constant>(node);
if (k == nullptr)
{
return false;
}
if (k->get_element_type() != element::boolean)
{
return false;
}
if (k->get_shape() != Shape{})
{
return false;
}
const char* k_data = k->get_data_ptr<char>();
return (*k_data == static_cast<char>(val));
}
template <typename T>
static bool check_reduce_for_replacement(std::shared_ptr<ngraph::op::Reduce> reduce,
bool expected_k_val)
{
auto reductee = reduce->get_argument(0);
auto init_val = reduce->get_argument(1);
if (!is_boolean_scalar_constant_with_val(init_val, expected_k_val))
{
return false;
}
auto func = reduce->get_functions().at(0);
auto func_result_op = func->get_results().at(0)->get_argument(0);
if (std::dynamic_pointer_cast<T>(func_result_op) == nullptr)
{
return false;
}
auto func_params = func->get_parameters();
auto func_param_0 = func_params.at(0);
auto func_param_1 = func_params.at(1);
auto func_result_op_arg_0 = func_result_op->get_argument(0);
auto func_result_op_arg_1 = func_result_op->get_argument(1);
if (!((func_param_0 == func_result_op_arg_0 && func_param_1 == func_result_op_arg_1) ||
(func_param_0 == func_result_op_arg_1 && func_param_1 == func_result_op_arg_0)))
{
return false;
}
return true;
}
bool ngraph::pass::AnyAllInsertion::run_on_node(std::shared_ptr<ngraph::Node> node)
{
auto reduce = std::dynamic_pointer_cast<ngraph::op::Reduce>(node);
if (reduce == nullptr)
{
return false;
}
if (check_reduce_for_replacement<op::Or>(reduce, false))
{
ngraph::replace_node(
reduce,
std::make_shared<op::Any>(reduce->get_argument(0), reduce->get_reduction_axes()));
return true;
}
else if (check_reduce_for_replacement<op::And>(reduce, true))
{
ngraph::replace_node(
reduce,
std::make_shared<op::All>(reduce->get_argument(0), reduce->get_reduction_axes()));
return true;
}
return false;
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/pass/pass.hpp"
namespace ngraph
{
namespace pass
{
// Pass to convert Reduce ops into Any/All where possible. NOTE: this will disappear once
// the Reduce op is retired.
class AnyAllInsertion : public NodePass
{
public:
bool run_on_node(std::shared_ptr<ngraph::Node> node) override;
};
}
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "any_all_replacement.hpp"
#include "ngraph/op/all.hpp"
#include "ngraph/op/and.hpp"
#include "ngraph/op/any.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/or.hpp"
#include "ngraph/op/reduce.hpp"
#include "ngraph/util.hpp"
using namespace ngraph;
static std::shared_ptr<Node> make_any(std::shared_ptr<Node> arg, const AxisSet& reduction_axes)
{
auto f_arg0 = std::make_shared<op::Parameter>(element::boolean, Shape{});
auto f_arg1 = std::make_shared<op::Parameter>(element::boolean, Shape{});
auto f_or = std::make_shared<op::Or>(f_arg0, f_arg1);
auto f = std::make_shared<Function>(f_or, ParameterVector{f_arg0, f_arg1});
auto k_false = op::Constant::create(element::boolean, Shape{}, std::vector<char>{0});
return std::make_shared<op::Reduce>(arg, k_false, f, reduction_axes);
}
static std::shared_ptr<Node> make_all(std::shared_ptr<Node> arg, const AxisSet& reduction_axes)
{
auto f_arg0 = std::make_shared<op::Parameter>(element::boolean, Shape{});
auto f_arg1 = std::make_shared<op::Parameter>(element::boolean, Shape{});
auto f_or = std::make_shared<op::And>(f_arg0, f_arg1);
auto f = std::make_shared<Function>(f_or, ParameterVector{f_arg0, f_arg1});
auto k_true = op::Constant::create(element::boolean, Shape{}, std::vector<char>{1});
return std::make_shared<op::Reduce>(arg, k_true, f, reduction_axes);
}
bool ngraph::pass::AnyAllReplacement::run_on_node(std::shared_ptr<ngraph::Node> node)
{
bool clobbered = false;
if (auto any = std::dynamic_pointer_cast<ngraph::op::Any>(node))
{
ngraph::replace_node(any, make_any(any->get_argument(0), any->get_reduction_axes()));
clobbered = true;
}
else if (auto all = std::dynamic_pointer_cast<ngraph::op::All>(node))
{
ngraph::replace_node(all, make_all(all->get_argument(0), all->get_reduction_axes()));
clobbered = true;
}
return clobbered;
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/pass/pass.hpp"
namespace ngraph
{
namespace pass
{
class AnyAllReplacement : public NodePass
{
public:
bool run_on_node(std::shared_ptr<ngraph::Node> node) override;
};
}
}
...@@ -118,6 +118,7 @@ ...@@ -118,6 +118,7 @@
#include "ngraph/op/tanh.hpp" #include "ngraph/op/tanh.hpp"
#include "ngraph/op/topk.hpp" #include "ngraph/op/topk.hpp"
#include "ngraph/pass/algebraic_simplification.hpp" #include "ngraph/pass/algebraic_simplification.hpp"
#include "ngraph/pass/any_all_replacement.hpp"
#include "ngraph/pass/common_function_collection.hpp" #include "ngraph/pass/common_function_collection.hpp"
#include "ngraph/pass/constant_folding.hpp" #include "ngraph/pass/constant_folding.hpp"
#include "ngraph/pass/core_fusion.hpp" #include "ngraph/pass/core_fusion.hpp"
...@@ -1081,6 +1082,7 @@ void runtime::cpu::CPU_ExternalFunction::register_common_passes(ngraph::pass::Ma ...@@ -1081,6 +1082,7 @@ void runtime::cpu::CPU_ExternalFunction::register_common_passes(ngraph::pass::Ma
{ {
auto pass_map = pass_manager.get_pass_config().get_enables(); auto pass_map = pass_manager.get_pass_config().get_enables();
REGISTER_KNOBBED_PASS(AnyAllReplacement, true, ngraph::pass);
REGISTER_KNOBBED_PASS(LikeReplacement, true, ngraph::pass); REGISTER_KNOBBED_PASS(LikeReplacement, true, ngraph::pass);
REGISTER_KNOBBED_PASS(NopElimination, true, ngraph::pass); REGISTER_KNOBBED_PASS(NopElimination, true, ngraph::pass);
REGISTER_KNOBBED_PASS(ZeroDimTensorElimination, true, ngraph::pass); REGISTER_KNOBBED_PASS(ZeroDimTensorElimination, true, ngraph::pass);
......
...@@ -19,3 +19,7 @@ max_3d_to_scalar_int32 ...@@ -19,3 +19,7 @@ max_3d_to_scalar_int32
argmin_trivial_in_i32 argmin_trivial_in_i32
argmax_4D_axis_3_i64_in_i32 argmax_4D_axis_3_i64_in_i32
# Even after AnyAllReplacement, these trigger an "Unsupported Reduce" error.
any_2x2x3_eliminate_dims_0_1_2
all_2x2x3_eliminate_dims_0_1_2
...@@ -33,8 +33,10 @@ ...@@ -33,8 +33,10 @@
#include "ngraph/op/abs.hpp" #include "ngraph/op/abs.hpp"
#include "ngraph/op/acos.hpp" #include "ngraph/op/acos.hpp"
#include "ngraph/op/add.hpp" #include "ngraph/op/add.hpp"
#include "ngraph/op/all.hpp"
#include "ngraph/op/allreduce.hpp" #include "ngraph/op/allreduce.hpp"
#include "ngraph/op/and.hpp" #include "ngraph/op/and.hpp"
#include "ngraph/op/any.hpp"
#include "ngraph/op/argmax.hpp" #include "ngraph/op/argmax.hpp"
#include "ngraph/op/argmin.hpp" #include "ngraph/op/argmin.hpp"
#include "ngraph/op/asin.hpp" #include "ngraph/op/asin.hpp"
...@@ -157,6 +159,11 @@ void runtime::gpu::GPU_Emitter::emit_Add(EMIT_ARGS) ...@@ -157,6 +159,11 @@ void runtime::gpu::GPU_Emitter::emit_Add(EMIT_ARGS)
emit_elementwise<ngraph::op::Add>(external_function, writer, node, args, out); emit_elementwise<ngraph::op::Add>(external_function, writer, node, args, out);
} }
void runtime::gpu::GPU_Emitter::emit_All(EMIT_ARGS)
{
throw unsupported_op("Unsupported op '" + node->description() + "'");
}
void runtime::gpu::GPU_Emitter::emit_AllReduce(EMIT_ARGS) void runtime::gpu::GPU_Emitter::emit_AllReduce(EMIT_ARGS)
{ {
throw unsupported_op("Unsupported op '" + node->description() + "'"); throw unsupported_op("Unsupported op '" + node->description() + "'");
...@@ -167,6 +174,11 @@ void runtime::gpu::GPU_Emitter::emit_And(EMIT_ARGS) ...@@ -167,6 +174,11 @@ void runtime::gpu::GPU_Emitter::emit_And(EMIT_ARGS)
emit_elementwise<ngraph::op::And>(external_function, writer, node, args, out); emit_elementwise<ngraph::op::And>(external_function, writer, node, args, out);
} }
void runtime::gpu::GPU_Emitter::emit_Any(EMIT_ARGS)
{
throw unsupported_op("Unsupported op '" + node->description() + "'");
}
void runtime::gpu::GPU_Emitter::emit_ArgMax(EMIT_ARGS) void runtime::gpu::GPU_Emitter::emit_ArgMax(EMIT_ARGS)
{ {
cudnnReduceTensorOp_t reduce_op = CUDNN_REDUCE_TENSOR_MAX; cudnnReduceTensorOp_t reduce_op = CUDNN_REDUCE_TENSOR_MAX;
......
...@@ -104,6 +104,7 @@ ...@@ -104,6 +104,7 @@
#include "ngraph/op/tanh.hpp" #include "ngraph/op/tanh.hpp"
#include "ngraph/op/topk.hpp" #include "ngraph/op/topk.hpp"
#include "ngraph/pass/algebraic_simplification.hpp" #include "ngraph/pass/algebraic_simplification.hpp"
#include "ngraph/pass/any_all_replacement.hpp"
#include "ngraph/pass/common_function_collection.hpp" #include "ngraph/pass/common_function_collection.hpp"
#include "ngraph/pass/like_replacement.hpp" #include "ngraph/pass/like_replacement.hpp"
#include "ngraph/runtime/gpu/gpu_backend.hpp" #include "ngraph/runtime/gpu/gpu_backend.hpp"
...@@ -570,6 +571,7 @@ void runtime::gpu::GPU_ExternalFunction::compile() ...@@ -570,6 +571,7 @@ void runtime::gpu::GPU_ExternalFunction::compile()
pass_manager.register_pass<ngraph::pass::AlgebraicSimplification>(); pass_manager.register_pass<ngraph::pass::AlgebraicSimplification>();
#endif #endif
pass_manager.register_pass<runtime::gpu::pass::BatchNormCache>(); pass_manager.register_pass<runtime::gpu::pass::BatchNormCache>();
pass_manager.register_pass<ngraph::pass::AnyAllReplacement>();
pass_manager.register_pass<ngraph::pass::LikeReplacement>(); pass_manager.register_pass<ngraph::pass::LikeReplacement>();
pass_manager.register_pass<runtime::gpu::pass::GPULayout>(this); pass_manager.register_pass<runtime::gpu::pass::GPULayout>(this);
pass_manager.register_pass<ngraph::pass::AssignLayout<descriptor::layout::DenseTensorLayout>>(); pass_manager.register_pass<ngraph::pass::AssignLayout<descriptor::layout::DenseTensorLayout>>();
......
...@@ -45,6 +45,7 @@ ...@@ -45,6 +45,7 @@
#include <CPP/topology.hpp> #include <CPP/topology.hpp>
#include "ngraph/pass/algebraic_simplification.hpp" #include "ngraph/pass/algebraic_simplification.hpp"
#include "ngraph/pass/any_all_replacement.hpp"
#include "ngraph/pass/cse.hpp" #include "ngraph/pass/cse.hpp"
#include "ngraph/pass/get_output_element_elimination.hpp" #include "ngraph/pass/get_output_element_elimination.hpp"
#include "ngraph/pass/manager.hpp" #include "ngraph/pass/manager.hpp"
...@@ -417,6 +418,7 @@ runtime::Handle runtime::intelgpu::IntelGPUBackend::compile(shared_ptr<Function> ...@@ -417,6 +418,7 @@ runtime::Handle runtime::intelgpu::IntelGPUBackend::compile(shared_ptr<Function>
{ {
ngraph::pass::Manager pass_manager; ngraph::pass::Manager pass_manager;
pass_manager.register_pass<ngraph::pass::AnyAllReplacement>();
pass_manager.register_pass<ngraph::pass::NopElimination>(); pass_manager.register_pass<ngraph::pass::NopElimination>();
pass_manager.register_pass<ngraph::pass::AlgebraicSimplification>(); pass_manager.register_pass<ngraph::pass::AlgebraicSimplification>();
pass_manager.register_pass<ngraph::pass::CommonSubexpressionElimination>(); pass_manager.register_pass<ngraph::pass::CommonSubexpressionElimination>();
...@@ -1711,7 +1713,9 @@ runtime::Handle runtime::intelgpu::IntelGPUBackend::compile(shared_ptr<Function> ...@@ -1711,7 +1713,9 @@ runtime::Handle runtime::intelgpu::IntelGPUBackend::compile(shared_ptr<Function>
topology.add(lrn); topology.add(lrn);
break; break;
} }
case OP_TYPEID::All:
case OP_TYPEID::AllReduce: case OP_TYPEID::AllReduce:
case OP_TYPEID::Any:
case OP_TYPEID::FunctionCall: case OP_TYPEID::FunctionCall:
case OP_TYPEID::Dequantize: case OP_TYPEID::Dequantize:
case OP_TYPEID::Quantize: case OP_TYPEID::Quantize:
......
...@@ -151,3 +151,6 @@ max_3d_to_scalar_double ...@@ -151,3 +151,6 @@ max_3d_to_scalar_double
argmin_trivial_in_i32 argmin_trivial_in_i32
argmax_4D_axis_3_i64_in_i32 argmax_4D_axis_3_i64_in_i32
argmin_trivial_in_double argmin_trivial_in_double
all_2x2x3_eliminate_dim_1
all_2x2x3_eliminate_dim_2
all_2x2x3_eliminate_dims_0_1
...@@ -21,6 +21,8 @@ ...@@ -21,6 +21,8 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "ngraph/op/all.hpp"
#include "ngraph/op/any.hpp"
#include "ngraph/op/argmax.hpp" #include "ngraph/op/argmax.hpp"
#include "ngraph/op/argmin.hpp" #include "ngraph/op/argmin.hpp"
#include "ngraph/op/avg_pool.hpp" #include "ngraph/op/avg_pool.hpp"
...@@ -64,7 +66,9 @@ ...@@ -64,7 +66,9 @@
#include "ngraph/runtime/reference/abs.hpp" #include "ngraph/runtime/reference/abs.hpp"
#include "ngraph/runtime/reference/acos.hpp" #include "ngraph/runtime/reference/acos.hpp"
#include "ngraph/runtime/reference/add.hpp" #include "ngraph/runtime/reference/add.hpp"
#include "ngraph/runtime/reference/all.hpp"
#include "ngraph/runtime/reference/and.hpp" #include "ngraph/runtime/reference/and.hpp"
#include "ngraph/runtime/reference/any.hpp"
#include "ngraph/runtime/reference/argmax.hpp" #include "ngraph/runtime/reference/argmax.hpp"
#include "ngraph/runtime/reference/argmin.hpp" #include "ngraph/runtime/reference/argmin.hpp"
#include "ngraph/runtime/reference/asin.hpp" #include "ngraph/runtime/reference/asin.hpp"
...@@ -238,6 +242,16 @@ private: ...@@ -238,6 +242,16 @@ private:
element_count); element_count);
break; break;
} }
case OP_TYPEID::All:
{
const op::All* all = static_cast<const op::All*>(&node);
reference::all(static_cast<const char*>(args[0]),
static_cast<char*>(out[0]),
node.get_input_shape(0),
node.get_output_shape(0),
all->get_reduction_axes());
break;
}
case OP_TYPEID::AllReduce: { case OP_TYPEID::AllReduce: {
#ifdef NGRAPH_DISTRIBUTED #ifdef NGRAPH_DISTRIBUTED
reference::allreduce<T>(static_cast<const T*>(args[0]), reference::allreduce<T>(static_cast<const T*>(args[0]),
...@@ -256,6 +270,16 @@ private: ...@@ -256,6 +270,16 @@ private:
element_count); element_count);
break; break;
} }
case OP_TYPEID::Any:
{
const op::Any* any = static_cast<const op::Any*>(&node);
reference::any(static_cast<const char*>(args[0]),
static_cast<char*>(out[0]),
node.get_input_shape(0),
node.get_output_shape(0),
any->get_reduction_axes());
break;
}
case OP_TYPEID::ArgMin: case OP_TYPEID::ArgMin:
{ {
const op::ArgMin* argmin = static_cast<const op::ArgMin*>(&node); const op::ArgMin* argmin = static_cast<const op::ArgMin*>(&node);
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include "ngraph/runtime/plaidml/plaidml_compiler.hpp" #include "ngraph/runtime/plaidml/plaidml_compiler.hpp"
#include "ngraph/log.hpp" #include "ngraph/log.hpp"
#include "ngraph/pass/algebraic_simplification.hpp" #include "ngraph/pass/algebraic_simplification.hpp"
#include "ngraph/pass/any_all_replacement.hpp"
#include "ngraph/pass/core_fusion.hpp" #include "ngraph/pass/core_fusion.hpp"
#include "ngraph/pass/cse.hpp" #include "ngraph/pass/cse.hpp"
#include "ngraph/pass/get_output_element_elimination.hpp" #include "ngraph/pass/get_output_element_elimination.hpp"
...@@ -74,6 +75,7 @@ std::shared_ptr<ngraph::runtime::plaidml::CompiledFunction> ...@@ -74,6 +75,7 @@ std::shared_ptr<ngraph::runtime::plaidml::CompiledFunction>
ngraph::pass::Manager pass_manager; ngraph::pass::Manager pass_manager;
// We apply the same general-purposes passes as the CPU backend. // We apply the same general-purposes passes as the CPU backend.
pass_manager.register_pass<ngraph::pass::AnyAllReplacement>();
pass_manager.register_pass<ngraph::pass::LikeReplacement>(); pass_manager.register_pass<ngraph::pass::LikeReplacement>();
pass_manager.register_pass<ngraph::pass::NopElimination>(); pass_manager.register_pass<ngraph::pass::NopElimination>();
pass_manager.register_pass<ngraph::pass::AlgebraicSimplification>(); pass_manager.register_pass<ngraph::pass::AlgebraicSimplification>();
......
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <cmath>
#include "ngraph/coordinate_transform.hpp"
#include "ngraph/shape_util.hpp"
namespace ngraph
{
namespace runtime
{
namespace reference
{
static inline void all(const char* arg,
char* out,
const Shape& in_shape,
const Shape& out_shape,
const AxisSet& reduction_axes)
{
CoordinateTransform output_transform(out_shape);
for (const Coordinate& output_coord : output_transform)
{
out[output_transform.index(output_coord)] = 1;
}
CoordinateTransform input_transform(in_shape);
for (const Coordinate& input_coord : input_transform)
{
Coordinate output_coord = reduce(input_coord, reduction_axes);
out[output_transform.index(output_coord)] =
out[output_transform.index(output_coord)] &&
arg[input_transform.index(input_coord)];
}
}
}
}
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <cmath>
#include "ngraph/coordinate_transform.hpp"
#include "ngraph/shape_util.hpp"
namespace ngraph
{
namespace runtime
{
namespace reference
{
static inline void any(const char* arg,
char* out,
const Shape& in_shape,
const Shape& out_shape,
const AxisSet& reduction_axes)
{
CoordinateTransform output_transform(out_shape);
for (const Coordinate& output_coord : output_transform)
{
out[output_transform.index(output_coord)] = 0;
}
CoordinateTransform input_transform(in_shape);
for (const Coordinate& input_coord : input_transform)
{
Coordinate output_coord = reduce(input_coord, reduction_axes);
out[output_transform.index(output_coord)] =
out[output_transform.index(output_coord)] ||
arg[input_transform.index(input_coord)];
}
}
}
}
}
...@@ -23,8 +23,10 @@ ...@@ -23,8 +23,10 @@
#include "ngraph/op/abs.hpp" #include "ngraph/op/abs.hpp"
#include "ngraph/op/acos.hpp" #include "ngraph/op/acos.hpp"
#include "ngraph/op/add.hpp" #include "ngraph/op/add.hpp"
#include "ngraph/op/all.hpp"
#include "ngraph/op/allreduce.hpp" #include "ngraph/op/allreduce.hpp"
#include "ngraph/op/and.hpp" #include "ngraph/op/and.hpp"
#include "ngraph/op/any.hpp"
#include "ngraph/op/argmax.hpp" #include "ngraph/op/argmax.hpp"
#include "ngraph/op/argmin.hpp" #include "ngraph/op/argmin.hpp"
#include "ngraph/op/asin.hpp" #include "ngraph/op/asin.hpp"
...@@ -459,6 +461,12 @@ static shared_ptr<ngraph::Function> ...@@ -459,6 +461,12 @@ static shared_ptr<ngraph::Function>
node = make_shared<op::Add>(args[0], args[1]); node = make_shared<op::Add>(args[0], args[1]);
break; break;
} }
case OP_TYPEID::All:
{
auto reduction_axes = node_js.at("reduction_axes").get<set<size_t>>();
node = make_shared<op::All>(args[0], reduction_axes);
break;
}
case OP_TYPEID::AllReduce: case OP_TYPEID::AllReduce:
{ {
node = make_shared<op::AllReduce>(args[0]); node = make_shared<op::AllReduce>(args[0]);
...@@ -469,6 +477,12 @@ static shared_ptr<ngraph::Function> ...@@ -469,6 +477,12 @@ static shared_ptr<ngraph::Function>
node = make_shared<op::And>(args[0], args[1]); node = make_shared<op::And>(args[0], args[1]);
break; break;
} }
case OP_TYPEID::Any:
{
auto reduction_axes = node_js.at("reduction_axes").get<set<size_t>>();
node = make_shared<op::Any>(args[0], reduction_axes);
break;
}
case OP_TYPEID::ArgMin: case OP_TYPEID::ArgMin:
{ {
auto axis = node_js.at("axis").get<size_t>(); auto axis = node_js.at("axis").get<size_t>();
...@@ -1258,10 +1272,22 @@ static json write(const Node& n, bool binary_constant_data) ...@@ -1258,10 +1272,22 @@ static json write(const Node& n, bool binary_constant_data)
node["index_element_type"] = write_element_type(tmp->get_element_type()); node["index_element_type"] = write_element_type(tmp->get_element_type());
break; break;
} }
case OP_TYPEID::All:
{
auto tmp = dynamic_cast<const op::All*>(&n);
node["reduction_axes"] = tmp->get_reduction_axes();
break;
}
case OP_TYPEID::AllReduce: { break; case OP_TYPEID::AllReduce: { break;
} }
case OP_TYPEID::And: { break; case OP_TYPEID::And: { break;
} }
case OP_TYPEID::Any:
{
auto tmp = dynamic_cast<const op::Any*>(&n);
node["reduction_axes"] = tmp->get_reduction_axes();
break;
}
case OP_TYPEID::Asin: { break; case OP_TYPEID::Asin: { break;
} }
case OP_TYPEID::Atan: { break; case OP_TYPEID::Atan: { break;
......
...@@ -25,6 +25,8 @@ endif() ...@@ -25,6 +25,8 @@ endif()
set(SRC set(SRC
algebraic_simplification.cpp algebraic_simplification.cpp
all_close_f.cpp all_close_f.cpp
any_all_insertion.cpp
any_all_replacement.cpp
assertion.cpp assertion.cpp
build_graph.cpp build_graph.cpp
builder_autobroadcast.cpp builder_autobroadcast.cpp
...@@ -114,6 +116,8 @@ add_subdirectory(util) ...@@ -114,6 +116,8 @@ add_subdirectory(util)
# such as ${BACKEND_NAME} with their values, such as CPU, GPU, or INTERPRETER. # such as ${BACKEND_NAME} with their values, such as CPU, GPU, or INTERPRETER.
set(MULTI_TEST_SRC set(MULTI_TEST_SRC
autodiff.in.cpp autodiff.in.cpp
backend_all.in.cpp
backend_any.in.cpp
backend_binary_elementwise.in.cpp backend_binary_elementwise.in.cpp
backend_broadcast.in.cpp backend_broadcast.in.cpp
backend_comparison.in.cpp backend_comparison.in.cpp
......
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/pass/any_all_insertion.hpp"
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/manager.hpp"
#include "util/test_tools.hpp"
using namespace ngraph;
using namespace std;
// Ripped off of pass/any_all_replacement.cpp.
static std::shared_ptr<op::Reduce> make_any(std::shared_ptr<Node> arg,
const AxisSet& reduction_axes)
{
auto f_arg0 = std::make_shared<op::Parameter>(element::boolean, Shape{});
auto f_arg1 = std::make_shared<op::Parameter>(element::boolean, Shape{});
auto f_or = std::make_shared<op::Or>(f_arg0, f_arg1);
auto f = std::make_shared<Function>(f_or, ParameterVector{f_arg0, f_arg1});
auto k_false = op::Constant::create(element::boolean, Shape{}, std::vector<char>{0});
return std::make_shared<op::Reduce>(arg, k_false, f, reduction_axes);
}
// Ripped off of pass/any_all_replacement.cpp.
static std::shared_ptr<op::Reduce> make_all(std::shared_ptr<Node> arg,
const AxisSet& reduction_axes)
{
auto f_arg0 = std::make_shared<op::Parameter>(element::boolean, Shape{});
auto f_arg1 = std::make_shared<op::Parameter>(element::boolean, Shape{});
auto f_and = std::make_shared<op::And>(f_arg0, f_arg1);
auto f = std::make_shared<Function>(f_and, ParameterVector{f_arg0, f_arg1});
auto k_true = op::Constant::create(element::boolean, Shape{}, std::vector<char>{1});
return std::make_shared<op::Reduce>(arg, k_true, f, reduction_axes);
}
static void
check_any_replacement(std::shared_ptr<Node> n, std::shared_ptr<Node> arg, const AxisSet& axes)
{
auto any = std::dynamic_pointer_cast<op::Any>(n);
ASSERT_NE(any, nullptr);
ASSERT_EQ(any->get_reduction_axes(), axes);
ASSERT_EQ(any->get_argument(0), arg);
}
static void
check_all_replacement(std::shared_ptr<Node> n, std::shared_ptr<Node> arg, const AxisSet& axes)
{
auto all = std::dynamic_pointer_cast<op::All>(n);
ASSERT_NE(all, nullptr);
ASSERT_EQ(all->get_reduction_axes(), axes);
ASSERT_EQ(all->get_argument(0), arg);
}
TEST(any_all_insertion, any_simple)
{
auto param = make_shared<op::Parameter>(element::boolean, Shape{2, 3, 4});
auto any = make_any(param, AxisSet{1});
auto f = make_shared<Function>(any, ParameterVector{param});
pass::Manager pass_manager;
pass_manager.register_pass<pass::AnyAllInsertion>();
pass_manager.run_passes(f);
check_any_replacement(
f->get_results().at(0)->get_argument(0), param, any->get_reduction_axes());
}
TEST(any_all_insertion, any_chained)
{
auto param = make_shared<op::Parameter>(element::boolean, Shape{2, 3, 4});
auto any_0 = make_any(param, AxisSet{1});
auto any_1 = make_any(any_0, AxisSet{1});
auto f = make_shared<Function>(
ResultVector{make_shared<op::Result>(any_0), make_shared<op::Result>(any_1)},
ParameterVector{param});
pass::Manager pass_manager;
pass_manager.register_pass<pass::AnyAllInsertion>();
pass_manager.run_passes(f);
check_any_replacement(
f->get_results().at(0)->get_argument(0), param, any_0->get_reduction_axes());
check_any_replacement(f->get_results().at(1)->get_argument(0),
f->get_results().at(0)->get_argument(0),
any_1->get_reduction_axes());
}
TEST(any_all_insertion, all_simple)
{
auto param = make_shared<op::Parameter>(element::boolean, Shape{2, 3, 4});
auto all = make_all(param, AxisSet{1});
auto f = make_shared<Function>(all, ParameterVector{param});
pass::Manager pass_manager;
pass_manager.register_pass<pass::AnyAllInsertion>();
pass_manager.run_passes(f);
check_all_replacement(
f->get_results().at(0)->get_argument(0), param, all->get_reduction_axes());
}
TEST(any_all_insertion, all_chained)
{
auto param = make_shared<op::Parameter>(element::boolean, Shape{2, 3, 4});
auto all_0 = make_all(param, AxisSet{1});
auto all_1 = make_all(all_0, AxisSet{1});
auto f = make_shared<Function>(
ResultVector{make_shared<op::Result>(all_0), make_shared<op::Result>(all_1)},
ParameterVector{param});
pass::Manager pass_manager;
pass_manager.register_pass<pass::AnyAllInsertion>();
pass_manager.run_passes(f);
check_all_replacement(
f->get_results().at(0)->get_argument(0), param, all_0->get_reduction_axes());
check_all_replacement(f->get_results().at(1)->get_argument(0),
f->get_results().at(0)->get_argument(0),
all_1->get_reduction_axes());
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/pass/any_all_replacement.hpp"
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/manager.hpp"
#include "util/test_tools.hpp"
using namespace ngraph;
using namespace std;
template <typename T>
static void
check_replacement(std::shared_ptr<Node> n, std::shared_ptr<Node> arg, const AxisSet& axes)
{
// NB: We are not checking all properties we could check here. (In particular could check the
// constant value.)
auto reduce = std::dynamic_pointer_cast<op::Reduce>(n);
ASSERT_NE(reduce, nullptr);
ASSERT_EQ(reduce->get_reduction_axes(), axes);
ASSERT_EQ(reduce->get_argument(0), arg);
auto k = std::dynamic_pointer_cast<op::Constant>(reduce->get_argument(1));
ASSERT_NE(k, nullptr);
auto reduce_f = reduce->get_functions().at(0);
auto reduce_f_op = std::dynamic_pointer_cast<T>(reduce_f->get_results().at(0)->get_argument(0));
ASSERT_NE(reduce_f_op, nullptr);
ASSERT_EQ(reduce_f_op->get_argument(0), reduce_f->get_parameters().at(0));
ASSERT_EQ(reduce_f_op->get_argument(1), reduce_f->get_parameters().at(1));
ASSERT_EQ(reduce_f->get_parameters().at(0)->get_shape(), Shape{});
ASSERT_EQ(reduce_f->get_parameters().at(1)->get_shape(), Shape{});
}
TEST(any_all_replacement, any_simple)
{
auto param = make_shared<op::Parameter>(element::boolean, Shape{2, 3, 4});
auto any = make_shared<op::Any>(param, AxisSet{1});
auto f = make_shared<Function>(any, ParameterVector{param});
pass::Manager pass_manager;
pass_manager.register_pass<pass::AnyAllReplacement>();
pass_manager.run_passes(f);
check_replacement<op::Or>(
f->get_results().at(0)->get_argument(0), param, any->get_reduction_axes());
}
TEST(any_all_replacement, any_chained)
{
auto param = make_shared<op::Parameter>(element::boolean, Shape{2, 3, 4});
auto any_0 = make_shared<op::Any>(param, AxisSet{1});
auto any_1 = make_shared<op::Any>(any_0, AxisSet{1});
auto f = make_shared<Function>(
ResultVector{make_shared<op::Result>(any_0), make_shared<op::Result>(any_1)},
ParameterVector{param});
pass::Manager pass_manager;
pass_manager.register_pass<pass::AnyAllReplacement>();
pass_manager.run_passes(f);
check_replacement<op::Or>(
f->get_results().at(0)->get_argument(0), param, any_0->get_reduction_axes());
check_replacement<op::Or>(f->get_results().at(1)->get_argument(0),
f->get_results().at(0)->get_argument(0),
any_1->get_reduction_axes());
}
TEST(any_all_replacement, all_simple)
{
auto param = make_shared<op::Parameter>(element::boolean, Shape{2, 3, 4});
auto all = make_shared<op::All>(param, AxisSet{1});
auto f = make_shared<Function>(all, ParameterVector{param});
pass::Manager pass_manager;
pass_manager.register_pass<pass::AnyAllReplacement>();
pass_manager.run_passes(f);
check_replacement<op::And>(
f->get_results().at(0)->get_argument(0), param, all->get_reduction_axes());
}
TEST(any_all_replacement, all_chained)
{
auto param = make_shared<op::Parameter>(element::boolean, Shape{2, 3, 4});
auto all_0 = make_shared<op::All>(param, AxisSet{1});
auto all_1 = make_shared<op::All>(all_0, AxisSet{1});
auto f = make_shared<Function>(
ResultVector{make_shared<op::Result>(all_0), make_shared<op::Result>(all_1)},
ParameterVector{param});
pass::Manager pass_manager;
pass_manager.register_pass<pass::AnyAllReplacement>();
pass_manager.run_passes(f);
check_replacement<op::And>(
f->get_results().at(0)->get_argument(0), param, all_0->get_reduction_axes());
check_replacement<op::And>(f->get_results().at(1)->get_argument(0),
f->get_results().at(0)->get_argument(0),
all_1->get_reduction_axes());
}
This diff is collapsed.
This diff is collapsed.
...@@ -13657,3 +13657,277 @@ TEST(type_prop, shape_of_partial_rank_dynamic) ...@@ -13657,3 +13657,277 @@ TEST(type_prop, shape_of_partial_rank_dynamic)
ASSERT_EQ(so->get_output_element_type(0), element::u64); ASSERT_EQ(so->get_output_element_type(0), element::u64);
ASSERT_TRUE(so->get_output_partial_shape(0).same_scheme(PartialShape::dynamic(1))); ASSERT_TRUE(so->get_output_partial_shape(0).same_scheme(PartialShape::dynamic(1)));
} }
TEST(type_prop, any_deduce)
{
auto param_0 = make_shared<op::Parameter>(element::boolean, Shape{2, 4});
auto r0 = make_shared<op::Any>(param_0, AxisSet{0});
ASSERT_EQ(r0->get_element_type(), element::boolean);
ASSERT_EQ(r0->get_shape(), (Shape{4}));
auto r1 = make_shared<op::Any>(param_0, AxisSet{1});
ASSERT_EQ(r1->get_element_type(), element::boolean);
ASSERT_EQ(r1->get_shape(), (Shape{2}));
auto r01 = make_shared<op::Any>(param_0, AxisSet{0, 1});
ASSERT_EQ(r01->get_element_type(), element::boolean);
ASSERT_EQ(r01->get_shape(), (Shape{}));
auto r_none = make_shared<op::Any>(param_0, AxisSet{});
ASSERT_EQ(r_none->get_element_type(), element::boolean);
ASSERT_EQ(r_none->get_shape(), (Shape{2, 4}));
}
TEST(type_prop, any_deduce_et_dynamic)
{
auto param_0 = make_shared<op::Parameter>(element::dynamic, Shape{2, 4});
auto r0 = make_shared<op::Any>(param_0, AxisSet{0});
ASSERT_EQ(r0->get_element_type(), element::boolean);
ASSERT_EQ(r0->get_shape(), (Shape{4}));
auto r1 = make_shared<op::Any>(param_0, AxisSet{1});
ASSERT_EQ(r1->get_element_type(), element::boolean);
ASSERT_EQ(r1->get_shape(), (Shape{2}));
auto r01 = make_shared<op::Any>(param_0, AxisSet{0, 1});
ASSERT_EQ(r01->get_element_type(), element::boolean);
ASSERT_EQ(r01->get_shape(), (Shape{}));
auto r_none = make_shared<op::Any>(param_0, AxisSet{});
ASSERT_EQ(r_none->get_element_type(), element::boolean);
ASSERT_EQ(r_none->get_shape(), (Shape{2, 4}));
}
TEST(type_prop, any_et_non_boolean)
{
auto param_0 = make_shared<op::Parameter>(element::i32, Shape{2, 4});
try
{
auto r = make_shared<op::Any>(param_0, AxisSet{0, 1});
// Should have thrown, so fail if it didn't
FAIL() << "Did not detect invalid element type for Any";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Input element type must be boolean"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, any_axis_oob)
{
auto param_0 = make_shared<op::Parameter>(element::boolean, Shape{2, 4});
try
{
auto r = make_shared<op::Any>(param_0, AxisSet{0, 2, 1});
// Should have thrown, so fail if it didn't
FAIL() << "Did not detect out-of-bound axis for Any";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Reduction axis (2) is out of bounds"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, any_partial_rank_dynamic)
{
auto param = make_shared<op::Parameter>(element::boolean, PartialShape::dynamic());
auto axes = AxisSet{2385, 0, 4404}; // arbitrary
auto any = make_shared<op::Any>(param, axes);
EXPECT_EQ(any->get_output_element_type(0), element::boolean);
EXPECT_TRUE(any->get_output_partial_shape(0).is_dynamic());
}
TEST(type_prop, any_partial_rank_static_dynamic_ok_result_static)
{
auto param = make_shared<op::Parameter>(element::boolean,
PartialShape{1, 2, Dimension::dynamic(), 4, 5});
auto axes = AxisSet{2, 3};
auto any = make_shared<op::Any>(param, axes);
EXPECT_EQ(any->get_output_element_type(0), element::boolean);
EXPECT_EQ(any->get_shape(), (Shape{1, 2, 5}));
}
TEST(type_prop, any_partial_rank_static_dynamic_ok_result_dynamic)
{
auto param = make_shared<op::Parameter>(
element::boolean, PartialShape{1, 2, Dimension::dynamic(), 4, Dimension::dynamic()});
auto axes = AxisSet{2, 3};
auto any = make_shared<op::Any>(param, axes);
EXPECT_EQ(any->get_output_element_type(0), element::boolean);
EXPECT_TRUE(
any->get_output_partial_shape(0).same_scheme(PartialShape{1, 2, Dimension::dynamic()}));
}
TEST(type_prop, any_partial_rank_static_dynamic_axes_oob)
{
auto param = make_shared<op::Parameter>(
element::boolean, PartialShape{1, 2, Dimension::dynamic(), 4, Dimension::dynamic()});
auto axes = AxisSet{2, 5, 1};
try
{
auto any = make_shared<op::Any>(param, axes);
// Should have thrown, so fail if it didn't
FAIL() << "Did not detect out-of-bound axis for Any (rank-static dynamic input)";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Reduction axis (5) is out of bounds"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, all_deduce)
{
auto param_0 = make_shared<op::Parameter>(element::boolean, Shape{2, 4});
auto r0 = make_shared<op::All>(param_0, AxisSet{0});
ASSERT_EQ(r0->get_element_type(), element::boolean);
ASSERT_EQ(r0->get_shape(), (Shape{4}));
auto r1 = make_shared<op::All>(param_0, AxisSet{1});
ASSERT_EQ(r1->get_element_type(), element::boolean);
ASSERT_EQ(r1->get_shape(), (Shape{2}));
auto r01 = make_shared<op::All>(param_0, AxisSet{0, 1});
ASSERT_EQ(r01->get_element_type(), element::boolean);
ASSERT_EQ(r01->get_shape(), (Shape{}));
auto r_none = make_shared<op::All>(param_0, AxisSet{});
ASSERT_EQ(r_none->get_element_type(), element::boolean);
ASSERT_EQ(r_none->get_shape(), (Shape{2, 4}));
}
TEST(type_prop, all_deduce_et_dynamic)
{
auto param_0 = make_shared<op::Parameter>(element::dynamic, Shape{2, 4});
auto r0 = make_shared<op::All>(param_0, AxisSet{0});
ASSERT_EQ(r0->get_element_type(), element::boolean);
ASSERT_EQ(r0->get_shape(), (Shape{4}));
auto r1 = make_shared<op::All>(param_0, AxisSet{1});
ASSERT_EQ(r1->get_element_type(), element::boolean);
ASSERT_EQ(r1->get_shape(), (Shape{2}));
auto r01 = make_shared<op::All>(param_0, AxisSet{0, 1});
ASSERT_EQ(r01->get_element_type(), element::boolean);
ASSERT_EQ(r01->get_shape(), (Shape{}));
auto r_none = make_shared<op::All>(param_0, AxisSet{});
ASSERT_EQ(r_none->get_element_type(), element::boolean);
ASSERT_EQ(r_none->get_shape(), (Shape{2, 4}));
}
TEST(type_prop, all_et_non_boolean)
{
auto param_0 = make_shared<op::Parameter>(element::i32, Shape{2, 4});
try
{
auto r = make_shared<op::All>(param_0, AxisSet{0, 1});
// Should have thrown, so fail if it didn't
FAIL() << "Did not detect invalid element type for All";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Input element type must be boolean"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, all_axis_oob)
{
auto param_0 = make_shared<op::Parameter>(element::boolean, Shape{2, 4});
try
{
auto r = make_shared<op::All>(param_0, AxisSet{0, 2, 1});
// Should have thrown, so fail if it didn't
FAIL() << "Did not detect out-of-bound axis for All";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Reduction axis (2) is out of bounds"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, all_partial_rank_dynamic)
{
auto param = make_shared<op::Parameter>(element::boolean, PartialShape::dynamic());
auto axes = AxisSet{2385, 0, 4404}; // arbitrary
auto all = make_shared<op::All>(param, axes);
EXPECT_EQ(all->get_output_element_type(0), element::boolean);
EXPECT_TRUE(all->get_output_partial_shape(0).is_dynamic());
}
TEST(type_prop, all_partial_rank_static_dynamic_ok_result_static)
{
auto param = make_shared<op::Parameter>(element::boolean,
PartialShape{1, 2, Dimension::dynamic(), 4, 5});
auto axes = AxisSet{2, 3};
auto all = make_shared<op::All>(param, axes);
EXPECT_EQ(all->get_output_element_type(0), element::boolean);
EXPECT_EQ(all->get_shape(), (Shape{1, 2, 5}));
}
TEST(type_prop, all_partial_rank_static_dynamic_ok_result_dynamic)
{
auto param = make_shared<op::Parameter>(
element::boolean, PartialShape{1, 2, Dimension::dynamic(), 4, Dimension::dynamic()});
auto axes = AxisSet{2, 3};
auto all = make_shared<op::All>(param, axes);
EXPECT_EQ(all->get_output_element_type(0), element::boolean);
EXPECT_TRUE(
all->get_output_partial_shape(0).same_scheme(PartialShape{1, 2, Dimension::dynamic()}));
}
TEST(type_prop, all_partial_rank_static_dynamic_axes_oob)
{
auto param = make_shared<op::Parameter>(
element::boolean, PartialShape{1, 2, Dimension::dynamic(), 4, Dimension::dynamic()});
auto axes = AxisSet{2, 5, 1};
try
{
auto all = make_shared<op::All>(param, axes);
// Should have thrown, so fail if it didn't
FAIL() << "Did not detect out-of-bound axis for All (rank-static dynamic input)";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Reduction axis (5) is out of bounds"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment