Commit e3c1ff7a authored by Gleb Kazantaev's avatar Gleb Kazantaev Committed by Scott Cyphers

Added ReduceMean, ReduceMin, ReduceMax; Updated ReduceProd, ReduceSum (#3693)

* Added ReduceMean, ReduceMin, ReduceMax; Updated ReduceProd, ReduceSum

* Removed redundant method from arithmetic_reductions_keep_dims.cpp

* Fixed code style issues

* Removed generate_adjoints

* Fixed cpu_emitter
parent d4dd143f
...@@ -254,6 +254,8 @@ set (SRC ...@@ -254,6 +254,8 @@ set (SRC
op/product.hpp op/product.hpp
op/reduce_prod.cpp op/reduce_prod.cpp
op/reduce_prod.hpp op/reduce_prod.hpp
op/reduce_mean.cpp
op/reduce_mean.hpp
op/reduce_sum.cpp op/reduce_sum.cpp
op/reduce_sum.hpp op/reduce_sum.hpp
op/quantize.cpp op/quantize.cpp
...@@ -370,6 +372,8 @@ set (SRC ...@@ -370,6 +372,8 @@ set (SRC
op/util/activation_functions.hpp op/util/activation_functions.hpp
op/util/arithmetic_reduction.cpp op/util/arithmetic_reduction.cpp
op/util/arithmetic_reduction.hpp op/util/arithmetic_reduction.hpp
op/util/arithmetic_reductions_keep_dims.hpp
op/util/arithmetic_reductions_keep_dims.cpp
op/util/binary_elementwise_arithmetic.cpp op/util/binary_elementwise_arithmetic.cpp
op/util/binary_elementwise_arithmetic.hpp op/util/binary_elementwise_arithmetic.hpp
op/util/binary_elementwise_comparison.cpp op/util/binary_elementwise_comparison.cpp
......
...@@ -20,27 +20,27 @@ ...@@ -20,27 +20,27 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
constexpr NodeTypeInfo op::Max::type_info; constexpr NodeTypeInfo op::v0::Max::type_info;
op::Max::Max(const Output<Node>& arg, const AxisSet& reduction_axes) op::v0::Max::Max(const Output<Node>& arg, const AxisSet& reduction_axes)
: ArithmeticReduction(arg, reduction_axes) : ArithmeticReduction(arg, reduction_axes)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
op::Max::Max(const Output<Node>& arg, const Output<Node>& reduction_axes) op::v0::Max::Max(const Output<Node>& arg, const Output<Node>& reduction_axes)
: ArithmeticReduction(arg, reduction_axes) : ArithmeticReduction(arg, reduction_axes)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
shared_ptr<Node> op::Max::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::v0::Max::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
return make_shared<Max>(new_args.at(0), new_args.at(1)); return make_shared<op::v0::Max>(new_args.at(0), new_args.at(1));
} }
shared_ptr<Node> op::Max::get_default_value() const shared_ptr<Node> op::v0::Max::get_default_value() const
{ {
switch (get_element_type()) switch (get_element_type())
{ {
...@@ -80,3 +80,19 @@ shared_ptr<Node> op::Max::get_default_value() const ...@@ -80,3 +80,19 @@ shared_ptr<Node> op::Max::get_default_value() const
default: throw runtime_error("Max default value not defined for type"); default: throw runtime_error("Max default value not defined for type");
} }
} }
constexpr NodeTypeInfo op::v1::ReduceMax::type_info;
op::v1::ReduceMax::ReduceMax(const Output<Node>& arg,
const Output<Node>& reduction_axes,
bool keep_dims)
: ArithmeticReductionKeepDims(arg, reduction_axes, keep_dims)
{
constructor_validate_and_infer_types();
}
shared_ptr<Node> op::v1::ReduceMax::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<op::v1::ReduceMax>(new_args.at(0), new_args.at(1), get_keep_dims());
}
...@@ -17,36 +17,70 @@ ...@@ -17,36 +17,70 @@
#pragma once #pragma once
#include "ngraph/op/util/arithmetic_reduction.hpp" #include "ngraph/op/util/arithmetic_reduction.hpp"
#include "ngraph/op/util/arithmetic_reductions_keep_dims.hpp"
namespace ngraph namespace ngraph
{ {
namespace op namespace op
{ {
/// \brief Max-reduction operation. namespace v0
class Max : public util::ArithmeticReduction
{ {
public: /// \brief Max-reduction operation.
NGRAPH_API class Max : public util::ArithmeticReduction
static constexpr NodeTypeInfo type_info{"Max", 0}; {
const NodeTypeInfo& get_type_info() const override { return type_info; } public:
/// \brief Constructs a "max" reduction operation. NGRAPH_API
Max() = default; static constexpr NodeTypeInfo type_info{"Max", 0};
/// \brief Constructs a max-reduction operation.
/// const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \param arg The tensor to be reduced. /// \brief Constructs a "max" reduction operation.
/// \param reduction_axes The axis positions (0-based) to be elimaxated. Max() = default;
Max(const Output<Node>& arg, const AxisSet& reduction_axes);
/// \brief Constructs a "max" reduction operation. /// \brief Constructs a max-reduction operation.
/// ///
/// \param arg The tensor to be reduced. /// \param arg The tensor to be reduced.
/// \param reduction_axes The axis positions (0-based) to be elimaxated. /// \param reduction_axes The axis positions (0-based) to be elimaxated.
Max(const Output<Node>& arg, const Output<Node>& reduction_axes); Max(const Output<Node>& arg, const AxisSet& reduction_axes);
virtual std::shared_ptr<Node> /// \brief Constructs a "max" reduction operation.
copy_with_new_args(const NodeVector& new_args) const override; ///
/// \param arg The tensor to be reduced.
/// \return The default value for Max. /// \param reduction_axes The axis positions (0-based) to be elimaxated.
virtual std::shared_ptr<Node> get_default_value() const override; Max(const Output<Node>& arg, const Output<Node>& reduction_axes);
};
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
/// \return The default value for Max.
virtual std::shared_ptr<Node> get_default_value() const override;
};
}
namespace v1
{
class ReduceMax : public util::ArithmeticReductionKeepDims
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"ReduceMax", 1};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a summation operation.
ReduceMax() = default;
/// \brief Constructs a summation operation.
///
/// \param arg The tensor to be summed.
/// \param reduction_axes The axis positions (0-based) to be eliminated.
/// \param keep_dims If set to 1 it holds axes that are used for reduction.
ReduceMax(const Output<Node>& arg,
const Output<Node>& reduction_axes,
bool keep_dims = false);
size_t get_version() const override { return 1; }
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
};
}
using v0::Max;
} }
} }
...@@ -20,27 +20,27 @@ ...@@ -20,27 +20,27 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
constexpr NodeTypeInfo op::Min::type_info; constexpr NodeTypeInfo op::v0::Min::type_info;
op::Min::Min(const Output<Node>& arg, const AxisSet& reduction_axes) op::v0::Min::Min(const Output<Node>& arg, const AxisSet& reduction_axes)
: ArithmeticReduction(arg, reduction_axes) : ArithmeticReduction(arg, reduction_axes)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
op::Min::Min(const Output<Node>& arg, const Output<Node>& reduction_axes) op::v0::Min::Min(const Output<Node>& arg, const Output<Node>& reduction_axes)
: ArithmeticReduction(arg, reduction_axes) : ArithmeticReduction(arg, reduction_axes)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
shared_ptr<Node> op::Min::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::v0::Min::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
return make_shared<Min>(new_args.at(0), get_reduction_axes()); return make_shared<op::v0::Min>(new_args.at(0), get_reduction_axes());
} }
shared_ptr<Node> op::Min::get_default_value() const shared_ptr<Node> op::v0::Min::get_default_value() const
{ {
switch (get_element_type()) switch (get_element_type())
{ {
...@@ -80,3 +80,19 @@ shared_ptr<Node> op::Min::get_default_value() const ...@@ -80,3 +80,19 @@ shared_ptr<Node> op::Min::get_default_value() const
default: throw runtime_error("Min default value not defined for type"); default: throw runtime_error("Min default value not defined for type");
} }
} }
constexpr NodeTypeInfo op::v1::ReduceMin::type_info;
op::v1::ReduceMin::ReduceMin(const Output<Node>& arg,
const Output<Node>& reduction_axes,
bool keep_dims)
: ArithmeticReductionKeepDims(arg, reduction_axes, keep_dims)
{
constructor_validate_and_infer_types();
}
shared_ptr<Node> op::v1::ReduceMin::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<op::v1::ReduceMin>(new_args.at(0), new_args.at(1), get_keep_dims());
}
...@@ -17,36 +17,70 @@ ...@@ -17,36 +17,70 @@
#pragma once #pragma once
#include "ngraph/op/util/arithmetic_reduction.hpp" #include "ngraph/op/util/arithmetic_reduction.hpp"
#include "ngraph/op/util/arithmetic_reductions_keep_dims.hpp"
namespace ngraph namespace ngraph
{ {
namespace op namespace op
{ {
/// \brief Min-reduction operation. namespace v0
class Min : public util::ArithmeticReduction
{ {
public: /// \brief Min-reduction operation.
NGRAPH_API class Min : public util::ArithmeticReduction
static constexpr NodeTypeInfo type_info{"Min", 0}; {
const NodeTypeInfo& get_type_info() const override { return type_info; } public:
/// \brief Constructs a "min" reduction operation. NGRAPH_API
Min() = default; static constexpr NodeTypeInfo type_info{"Min", 0};
/// \brief Constructs a min-reduction operation.
/// const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \param arg The tensor to be reduced. /// \brief Constructs a "min" reduction operation.
/// \param reduction_axes The axis positions (0-based) to be eliminated. Min() = default;
Min(const Output<Node>& arg, const AxisSet& reduction_axes);
/// \brief Constructs a "min" reduction operation. /// \brief Constructs a min-reduction operation.
/// ///
/// \param arg The tensor to be reduced. /// \param arg The tensor to be reduced.
/// \param reduction_axes The axis positions (0-based) to be eliminated. /// \param reduction_axes The axis positions (0-based) to be eliminated.
Min(const Output<Node>& arg, const Output<Node>& reduction_axes); Min(const Output<Node>& arg, const AxisSet& reduction_axes);
virtual std::shared_ptr<Node> /// \brief Constructs a "min" reduction operation.
copy_with_new_args(const NodeVector& new_args) const override; ///
/// \param arg The tensor to be reduced.
/// \return The default value for Min. /// \param reduction_axes The axis positions (0-based) to be eliminated.
virtual std::shared_ptr<Node> get_default_value() const override; Min(const Output<Node>& arg, const Output<Node>& reduction_axes);
};
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
/// \return The default value for Min.
virtual std::shared_ptr<Node> get_default_value() const override;
};
}
namespace v1
{
class ReduceMin : public util::ArithmeticReductionKeepDims
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"ReduceMin", 1};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a summation operation.
ReduceMin() = default;
/// \brief Constructs a summation operation.
///
/// \param arg The tensor to be summed.
/// \param reduction_axes The axis positions (0-based) to be eliminated.
/// \param keep_dims If set to 1 it holds axes that are used for reduction.
ReduceMin(const Output<Node>& arg,
const Output<Node>& reduction_axes,
bool keep_dims = false);
size_t get_version() const override { return 1; }
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
};
}
using v0::Min;
} }
} }
//*****************************************************************************
// Copyright 2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/op/reduce_mean.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/op/broadcast.hpp"
using namespace std;
using namespace ngraph;
constexpr NodeTypeInfo op::v1::ReduceMean::type_info;
op::v1::ReduceMean::ReduceMean(const Output<Node>& arg,
const Output<Node>& reduction_axes,
bool keep_dims)
: ArithmeticReductionKeepDims(arg, reduction_axes, keep_dims)
{
constructor_validate_and_infer_types();
}
shared_ptr<Node> op::v1::ReduceMean::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<op::v1::ReduceMean>(new_args.at(0), new_args.at(1), get_keep_dims());
}
//*****************************************************************************
// Copyright 2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/axis_set.hpp"
#include "ngraph/op/util/arithmetic_reductions_keep_dims.hpp"
namespace ngraph
{
namespace op
{
namespace v1
{
class ReduceMean : public util::ArithmeticReductionKeepDims
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"ReduceMean", 1};
const NodeTypeInfo& get_type_info() const override { return type_info; }
ReduceMean() = default;
/// \param arg The tensor to be summed.
/// \param reduction_axes The axis positions (0-based) to be eliminated.
/// \param keep_dims If set to 1 it holds axes that are used for reduction.
ReduceMean(const Output<Node>& arg,
const Output<Node>& reduction_axes,
bool keep_dims = false);
size_t get_version() const override { return 1; }
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
};
}
}
}
...@@ -25,8 +25,7 @@ constexpr NodeTypeInfo op::v1::ReduceProd::type_info; ...@@ -25,8 +25,7 @@ constexpr NodeTypeInfo op::v1::ReduceProd::type_info;
op::v1::ReduceProd::ReduceProd(const Output<Node>& arg, op::v1::ReduceProd::ReduceProd(const Output<Node>& arg,
const Output<Node>& reduction_axes, const Output<Node>& reduction_axes,
bool keep_dims) bool keep_dims)
: ArithmeticReduction(arg, reduction_axes) : ArithmeticReductionKeepDims(arg, reduction_axes, keep_dims)
, m_keep_dims{keep_dims}
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
...@@ -39,52 +38,5 @@ shared_ptr<Node> op::v1::ReduceProd::get_default_value() const ...@@ -39,52 +38,5 @@ shared_ptr<Node> op::v1::ReduceProd::get_default_value() const
shared_ptr<Node> op::v1::ReduceProd::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::v1::ReduceProd::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
return make_shared<ReduceProd>(new_args.at(0), new_args.at(1), m_keep_dims); return make_shared<ReduceProd>(new_args.at(0), new_args.at(1), get_keep_dims());
}
void op::v1::ReduceProd::validate_and_infer_types()
{
if (m_keep_dims)
{
auto reduction_axes = get_reduction_axes();
auto input_shape = get_input_partial_shape(0);
auto input_rank = input_shape.rank();
PartialShape result_shape{PartialShape::dynamic()};
if (input_rank.is_static() && reduction_axes_constant())
{
std::vector<Dimension> dims;
for (auto axis : reduction_axes)
{
NODE_VALIDATION_CHECK(this,
axis < size_t(input_rank),
"Reduction axis (",
axis,
") is out of bounds ",
"(argument shape: ",
input_shape,
", reduction axes: ",
reduction_axes,
")");
}
for (size_t i = 0; i < size_t(input_rank); i++)
{
if (reduction_axes.count(i) == 0)
{
dims.push_back(input_shape[i]);
}
else
{
dims.push_back(Dimension{1});
}
}
result_shape = PartialShape(dims);
}
set_input_is_relevant_to_shape(1);
set_output_type(0, get_input_element_type(0), result_shape);
}
else
{
ArithmeticReduction::validate_and_infer_types();
}
} }
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#pragma once #pragma once
#include "ngraph/op/util/arithmetic_reduction.hpp" #include "ngraph/op/util/arithmetic_reductions_keep_dims.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -27,7 +27,7 @@ namespace ngraph ...@@ -27,7 +27,7 @@ namespace ngraph
/// \brief Product reduction operation. /// \brief Product reduction operation.
/// ///
/// Reduces the tensor, eliminating the specified reduction axes by taking the product. /// Reduces the tensor, eliminating the specified reduction axes by taking the product.
class ReduceProd : public util::ArithmeticReduction class ReduceProd : public util::ArithmeticReductionKeepDims
{ {
public: public:
NGRAPH_API NGRAPH_API
...@@ -44,20 +44,12 @@ namespace ngraph ...@@ -44,20 +44,12 @@ namespace ngraph
const Output<Node>& reduction_axes, const Output<Node>& reduction_axes,
bool keep_dims = false); bool keep_dims = false);
void validate_and_infer_types() override;
size_t get_version() const override { return 1; } size_t get_version() const override { return 1; }
/// \return If set to 1 it holds axes that are used for reduction.
/// For each such axis, output dimension is equal to 1.
bool get_keep_dims() const { return m_keep_dims; }
/// \return The default value for Product. /// \return The default value for Product.
virtual std::shared_ptr<Node> get_default_value() const override; virtual std::shared_ptr<Node> get_default_value() const override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
private:
bool m_keep_dims;
}; };
} }
} }
......
...@@ -26,8 +26,7 @@ constexpr NodeTypeInfo op::v1::ReduceSum::type_info; ...@@ -26,8 +26,7 @@ constexpr NodeTypeInfo op::v1::ReduceSum::type_info;
op::v1::ReduceSum::ReduceSum(const Output<Node>& arg, op::v1::ReduceSum::ReduceSum(const Output<Node>& arg,
const Output<Node>& reduction_axes, const Output<Node>& reduction_axes,
bool keep_dims) bool keep_dims)
: ArithmeticReduction(arg, reduction_axes) : ArithmeticReductionKeepDims(arg, reduction_axes, keep_dims)
, m_keep_dims{keep_dims}
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
...@@ -40,7 +39,7 @@ shared_ptr<Node> op::v1::ReduceSum::get_default_value() const ...@@ -40,7 +39,7 @@ shared_ptr<Node> op::v1::ReduceSum::get_default_value() const
shared_ptr<Node> op::v1::ReduceSum::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::v1::ReduceSum::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
return make_shared<ReduceSum>(new_args.at(0), new_args.at(1), m_keep_dims); return make_shared<ReduceSum>(new_args.at(0), new_args.at(1), get_keep_dims());
} }
void op::v1::ReduceSum::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas) void op::v1::ReduceSum::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
...@@ -52,50 +51,3 @@ void op::v1::ReduceSum::generate_adjoints(autodiff::Adjoints& adjoints, const No ...@@ -52,50 +51,3 @@ void op::v1::ReduceSum::generate_adjoints(autodiff::Adjoints& adjoints, const No
adjoints.add_delta(x, make_shared<op::Broadcast>(delta, x_shape, get_reduction_axes())); adjoints.add_delta(x, make_shared<op::Broadcast>(delta, x_shape, get_reduction_axes()));
} }
void op::v1::ReduceSum::validate_and_infer_types()
{
if (m_keep_dims)
{
auto reduction_axes = get_reduction_axes();
auto input_shape = get_input_partial_shape(0);
auto input_rank = input_shape.rank();
PartialShape result_shape{PartialShape::dynamic()};
if (input_rank.is_static() && reduction_axes_constant())
{
std::vector<Dimension> dims;
for (auto axis : reduction_axes)
{
NODE_VALIDATION_CHECK(this,
axis < size_t(input_rank),
"Reduction axis (",
axis,
") is out of bounds ",
"(argument shape: ",
input_shape,
", reduction axes: ",
reduction_axes,
")");
}
for (size_t i = 0; i < size_t(input_rank); i++)
{
if (reduction_axes.count(i) == 0)
{
dims.push_back(input_shape[i]);
}
else
{
dims.push_back(Dimension{1});
}
}
result_shape = PartialShape(dims);
}
set_input_is_relevant_to_shape(1);
set_output_type(0, get_input_element_type(0), result_shape);
}
else
{
ArithmeticReduction::validate_and_infer_types();
}
}
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
#pragma once #pragma once
#include "ngraph/axis_set.hpp" #include "ngraph/axis_set.hpp"
#include "ngraph/op/util/arithmetic_reduction.hpp" #include "ngraph/op/util/arithmetic_reductions_keep_dims.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -74,7 +74,7 @@ namespace ngraph ...@@ -74,7 +74,7 @@ namespace ngraph
/// | ----------------------------------------- | ---------------------------------------------------------------------------------------------------------------- | /// | ----------------------------------------- | ---------------------------------------------------------------------------------------------------------------- |
/// | \f$N[\textit{delete}(A,d_1,\dots,d_n)]\f$ | The tensor \f$T\f$, where \f$T\f$ is the input tensor with the `reduction_axes` \f$A\f$ eliminated by summation. | /// | \f$N[\textit{delete}(A,d_1,\dots,d_n)]\f$ | The tensor \f$T\f$, where \f$T\f$ is the input tensor with the `reduction_axes` \f$A\f$ eliminated by summation. |
// clang-format off // clang-format off
class ReduceSum : public util::ArithmeticReduction class ReduceSum : public util::ArithmeticReductionKeepDims
{ {
public: public:
NGRAPH_API NGRAPH_API
...@@ -91,12 +91,8 @@ namespace ngraph ...@@ -91,12 +91,8 @@ namespace ngraph
const Output<Node>& reduction_axes, const Output<Node>& reduction_axes,
bool keep_dims = false); bool keep_dims = false);
void validate_and_infer_types() override;
size_t get_version() const override { return 1; } size_t get_version() const override { return 1; }
/// \return If set to 1 it holds axes that are used for reduction.
/// For each such axis, output dimension is equal to 1.
bool get_keep_dims() const { return m_keep_dims; }
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
...@@ -106,9 +102,6 @@ namespace ngraph ...@@ -106,9 +102,6 @@ namespace ngraph
protected: protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override; const NodeVector& deltas) override;
private:
bool m_keep_dims;
}; };
} }
} }
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/op/util/arithmetic_reductions_keep_dims.hpp"
#include "ngraph/op/constant.hpp"
using namespace std;
using namespace ngraph;
op::util::ArithmeticReductionKeepDims::ArithmeticReductionKeepDims(
const ngraph::Output<ngraph::Node>& arg,
const ngraph::Output<ngraph::Node>& reduction_axes,
bool keep_dims)
: ArithmeticReduction(arg, reduction_axes)
, m_keep_dims{keep_dims}
{
}
void op::util::ArithmeticReductionKeepDims::validate_and_infer_types()
{
if (m_keep_dims)
{
auto reduction_axes = get_reduction_axes();
auto input_shape = get_input_partial_shape(0);
auto input_rank = input_shape.rank();
PartialShape result_shape{PartialShape::dynamic()};
if (input_rank.is_static() && reduction_axes_constant())
{
std::vector<Dimension> dims;
for (auto axis : reduction_axes)
{
NODE_VALIDATION_CHECK(this,
axis < size_t(input_rank),
"Reduction axis (",
axis,
") is out of bounds ",
"(argument shape: ",
input_shape,
", reduction axes: ",
reduction_axes,
")");
}
for (size_t i = 0; i < size_t(input_rank); i++)
{
if (reduction_axes.count(i) == 0)
{
dims.push_back(input_shape[i]);
}
else
{
dims.push_back(Dimension{1});
}
}
result_shape = PartialShape(dims);
}
set_input_is_relevant_to_shape(1);
set_output_type(0, get_input_element_type(0), result_shape);
}
else
{
ArithmeticReduction::validate_and_infer_types();
}
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/op/op.hpp"
#include "ngraph/op/util/arithmetic_reduction.hpp"
namespace ngraph
{
namespace op
{
namespace util
{
class ArithmeticReductionKeepDims : public util::ArithmeticReduction
{
protected:
ArithmeticReductionKeepDims() = default;
/// \param arg The tensor to be summed.
/// \param reduction_axes The axis positions (0-based) to be eliminated.
/// \param keep_dims If set to 1 it holds axes that are used for reduction.
ArithmeticReductionKeepDims(const Output<Node>& arg,
const Output<Node>& reduction_axes,
bool keep_dims = false);
public:
void validate_and_infer_types() override;
/// \return If set to 1 it holds axes that are used for reduction.
/// For each such axis, output dimension is equal to 1.
bool get_keep_dims() const { return m_keep_dims; }
void set_keep_dims(bool keep_dims) { m_keep_dims = keep_dims; }
private:
bool m_keep_dims = false;
};
}
}
}
...@@ -24,7 +24,9 @@ ...@@ -24,7 +24,9 @@
#include "ngraph/op/avg_pool.hpp" #include "ngraph/op/avg_pool.hpp"
#include "ngraph/op/convolution.hpp" #include "ngraph/op/convolution.hpp"
#include "ngraph/op/gather.hpp" #include "ngraph/op/gather.hpp"
#include "ngraph/op/max.hpp"
#include "ngraph/op/max_pool.hpp" #include "ngraph/op/max_pool.hpp"
#include "ngraph/op/min.hpp"
#include "ngraph/op/pad.hpp" #include "ngraph/op/pad.hpp"
#include "ngraph/op/product.hpp" #include "ngraph/op/product.hpp"
#include "ngraph/op/sum.hpp" #include "ngraph/op/sum.hpp"
...@@ -128,9 +130,7 @@ namespace ngraph ...@@ -128,9 +130,7 @@ namespace ngraph
class Reverse; class Reverse;
class ReverseSequence; class ReverseSequence;
class MaxPoolWithIndicesBackprop; class MaxPoolWithIndicesBackprop;
class Max;
class Erf; class Erf;
class Min;
class ReluBackprop; class ReluBackprop;
class Relu; class Relu;
class CPULeakyRelu; class CPULeakyRelu;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment