Commit 9ba4a78a authored by Jayaram Bobba's avatar Jayaram Bobba Committed by Scott Cyphers

Change reduction operations to 2-input dynamic variants (#2972)

* Change reduction operations to 2-input dynamic variants with convenience constructors for cases where reduction AxisSet is known at op construction time

* Modify rest of arithmetic and logical reduction ops to 2-input dynamic variants. Some fixes to existing passes to keep constant reduction axes inputs intact

* add new All tests to GPU manifest
parent ff5d79ca
...@@ -59,6 +59,11 @@ namespace ngraph ...@@ -59,6 +59,11 @@ namespace ngraph
static_cast<std::set<size_t>*>(this)->operator=(v); static_cast<std::set<size_t>*>(this)->operator=(v);
return *this; return *this;
} }
std::vector<int64_t> to_vector() const
{
return std::vector<int64_t>(this->begin(), this->end());
}
}; };
std::ostream& operator<<(std::ostream& s, const AxisSet& axis_set); std::ostream& operator<<(std::ostream& s, const AxisSet& axis_set);
......
...@@ -31,8 +31,14 @@ op::All::All(const Output<Node>& arg, const AxisSet& reduction_axes) ...@@ -31,8 +31,14 @@ op::All::All(const Output<Node>& arg, const AxisSet& reduction_axes)
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
op::All::All(const Output<Node>& arg, const Output<Node>& reduction_axes)
: LogicalReduction(arg, reduction_axes)
{
constructor_validate_and_infer_types();
}
shared_ptr<Node> op::All::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::All::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
return make_shared<All>(new_args.at(0), m_reduction_axes); return make_shared<All>(new_args.at(0), new_args.at(1));
} }
...@@ -39,6 +39,11 @@ namespace ngraph ...@@ -39,6 +39,11 @@ namespace ngraph
/// \param arg The tensor to be reduced. /// \param arg The tensor to be reduced.
/// \param reduction_axes The axis positions (0-based) to be eliminated. /// \param reduction_axes The axis positions (0-based) to be eliminated.
All(const Output<Node>& arg, const AxisSet& reduction_axes); All(const Output<Node>& arg, const AxisSet& reduction_axes);
/// \brief Constructs an "all" reduction operation.
///
/// \param arg The tensor to be reduced.
/// \param reduction_axes The axis positions (0-based) to be eliminated.
All(const Output<Node>& arg, const Output<Node>& reduction_axes);
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override; std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
......
...@@ -31,8 +31,14 @@ op::Any::Any(const Output<Node>& arg, const AxisSet& reduction_axes) ...@@ -31,8 +31,14 @@ op::Any::Any(const Output<Node>& arg, const AxisSet& reduction_axes)
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
op::Any::Any(const Output<Node>& arg, const Output<Node>& reduction_axes)
: LogicalReduction(arg, reduction_axes)
{
constructor_validate_and_infer_types();
}
shared_ptr<Node> op::Any::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Any::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
return make_shared<Any>(new_args.at(0), m_reduction_axes); return make_shared<Any>(new_args.at(0), new_args.at(1));
} }
...@@ -39,6 +39,11 @@ namespace ngraph ...@@ -39,6 +39,11 @@ namespace ngraph
/// \param arg The tensor to be reduced. /// \param arg The tensor to be reduced.
/// \param reduction_axes The axis positions (0-based) to be eliminated. /// \param reduction_axes The axis positions (0-based) to be eliminated.
Any(const Output<Node>& arg, const AxisSet& reduction_axes); Any(const Output<Node>& arg, const AxisSet& reduction_axes);
/// \brief Constructs an "any" reduction operation.
///
/// \param arg The tensor to be reduced.
/// \param reduction_axes The axis positions (0-based) to be eliminated.
Any(const Output<Node>& arg, const Output<Node>& reduction_axes);
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
......
...@@ -20,8 +20,20 @@ ...@@ -20,8 +20,20 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::Max::Max(const shared_ptr<Node>& arg, const AxisSet& reduction_axes) const string op::Max::type_name{"Max"};
: ArithmeticReduction("Max", arg, reduction_axes)
op::Max::Max()
{
}
op::Max::Max(const Output<Node>& arg, const AxisSet& reduction_axes)
: ArithmeticReduction(arg, reduction_axes)
{
constructor_validate_and_infer_types();
}
op::Max::Max(const Output<Node>& arg, const Output<Node>& reduction_axes)
: ArithmeticReduction(arg, reduction_axes)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
...@@ -29,7 +41,7 @@ op::Max::Max(const shared_ptr<Node>& arg, const AxisSet& reduction_axes) ...@@ -29,7 +41,7 @@ op::Max::Max(const shared_ptr<Node>& arg, const AxisSet& reduction_axes)
shared_ptr<Node> op::Max::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Max::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
return make_shared<Max>(new_args.at(0), m_reduction_axes); return make_shared<Max>(new_args.at(0), new_args.at(1));
} }
shared_ptr<Node> op::Max::get_default_value() const shared_ptr<Node> op::Max::get_default_value() const
......
...@@ -26,11 +26,21 @@ namespace ngraph ...@@ -26,11 +26,21 @@ namespace ngraph
class Max : public util::ArithmeticReduction class Max : public util::ArithmeticReduction
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a "max" reduction operation.
Max();
/// \brief Constructs a max-reduction operation. /// \brief Constructs a max-reduction operation.
/// ///
/// \param arg The tensor to be reduced. /// \param arg The tensor to be reduced.
/// \param reduction_axes The axis positions (0-based) to be eliminated. /// \param reduction_axes The axis positions (0-based) to be elimaxated.
Max(const std::shared_ptr<Node>& arg, const AxisSet& reduction_axes); Max(const Output<Node>& arg, const AxisSet& reduction_axes);
/// \brief Constructs a "max" reduction operation.
///
/// \param arg The tensor to be reduced.
/// \param reduction_axes The axis positions (0-based) to be elimaxated.
Max(const Output<Node>& arg, const Output<Node>& reduction_axes);
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
......
...@@ -20,8 +20,20 @@ ...@@ -20,8 +20,20 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::Min::Min(const shared_ptr<Node>& arg, const AxisSet& reduction_axes) const string op::Min::type_name{"Min"};
: ArithmeticReduction("Min", arg, reduction_axes)
op::Min::Min()
{
}
op::Min::Min(const Output<Node>& arg, const AxisSet& reduction_axes)
: ArithmeticReduction(arg, reduction_axes)
{
constructor_validate_and_infer_types();
}
op::Min::Min(const Output<Node>& arg, const Output<Node>& reduction_axes)
: ArithmeticReduction(arg, reduction_axes)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
...@@ -29,7 +41,7 @@ op::Min::Min(const shared_ptr<Node>& arg, const AxisSet& reduction_axes) ...@@ -29,7 +41,7 @@ op::Min::Min(const shared_ptr<Node>& arg, const AxisSet& reduction_axes)
shared_ptr<Node> op::Min::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Min::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
return make_shared<Min>(new_args.at(0), m_reduction_axes); return make_shared<Min>(new_args.at(0), get_reduction_axes());
} }
shared_ptr<Node> op::Min::get_default_value() const shared_ptr<Node> op::Min::get_default_value() const
......
...@@ -26,11 +26,21 @@ namespace ngraph ...@@ -26,11 +26,21 @@ namespace ngraph
class Min : public util::ArithmeticReduction class Min : public util::ArithmeticReduction
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a "min" reduction operation.
Min();
/// \brief Constructs a min-reduction operation. /// \brief Constructs a min-reduction operation.
/// ///
/// \param arg The tensor to be reduced. /// \param arg The tensor to be reduced.
/// \param reduction_axes The axis positions (0-based) to be eliminated. /// \param reduction_axes The axis positions (0-based) to be eliminated.
Min(const std::shared_ptr<Node>& arg, const AxisSet& reduction_axes); Min(const Output<Node>& arg, const AxisSet& reduction_axes);
/// \brief Constructs a "min" reduction operation.
///
/// \param arg The tensor to be reduced.
/// \param reduction_axes The axis positions (0-based) to be eliminated.
Min(const Output<Node>& arg, const Output<Node>& reduction_axes);
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
......
...@@ -19,8 +19,20 @@ ...@@ -19,8 +19,20 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::Product::Product(const shared_ptr<Node>& arg, const AxisSet& reduction_axes) const string op::Product::type_name{"Product"};
: ArithmeticReduction("Product", arg, reduction_axes)
op::Product::Product()
{
}
op::Product::Product(const Output<Node>& arg, const AxisSet& reduction_axes)
: ArithmeticReduction(arg, reduction_axes)
{
constructor_validate_and_infer_types();
}
op::Product::Product(const Output<Node>& arg, const Output<Node>& reduction_axes)
: ArithmeticReduction(arg, reduction_axes)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
...@@ -28,5 +40,5 @@ op::Product::Product(const shared_ptr<Node>& arg, const AxisSet& reduction_axes) ...@@ -28,5 +40,5 @@ op::Product::Product(const shared_ptr<Node>& arg, const AxisSet& reduction_axes)
shared_ptr<Node> op::Product::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Product::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
return make_shared<Product>(new_args.at(0), m_reduction_axes); return make_shared<Product>(new_args.at(0), get_reduction_axes());
} }
...@@ -29,11 +29,21 @@ namespace ngraph ...@@ -29,11 +29,21 @@ namespace ngraph
class Product : public util::ArithmeticReduction class Product : public util::ArithmeticReduction
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a product reduction operation.
Product();
/// \brief Constructs a product reduction operation.
///
/// \param arg The tensor to be reduced.
/// \param reduction_axes The axis positions (0-based) to be eliminated.
Product(const Output<Node>& arg, const AxisSet& reduction_axes);
/// \brief Constructs a product reduction operation. /// \brief Constructs a product reduction operation.
/// ///
/// \param arg The tensor to be reduced. /// \param arg The tensor to be reduced.
/// \param reduction_axes The axis positions (0-based) to be eliminated. /// \param reduction_axes The axis positions (0-based) to be eliminated.
Product(const std::shared_ptr<Node>& arg, const AxisSet& reduction_axes); Product(const Output<Node>& arg, const Output<Node>& reduction_axes);
/// \return The default value for Product. /// \return The default value for Product.
virtual std::shared_ptr<Node> get_default_value() const override virtual std::shared_ptr<Node> get_default_value() const override
......
...@@ -20,8 +20,20 @@ ...@@ -20,8 +20,20 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::Sum::Sum(const shared_ptr<Node>& arg, const AxisSet& reduction_axes) const string op::Sum::type_name{"Sum"};
: ArithmeticReduction("Sum", arg, reduction_axes)
op::Sum::Sum()
{
}
op::Sum::Sum(const Output<Node>& arg, const AxisSet& reduction_axes)
: ArithmeticReduction(arg, reduction_axes)
{
constructor_validate_and_infer_types();
}
op::Sum::Sum(const Output<Node>& arg, const Output<Node>& reduction_axes)
: ArithmeticReduction(arg, reduction_axes)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
...@@ -29,7 +41,7 @@ op::Sum::Sum(const shared_ptr<Node>& arg, const AxisSet& reduction_axes) ...@@ -29,7 +41,7 @@ op::Sum::Sum(const shared_ptr<Node>& arg, const AxisSet& reduction_axes)
shared_ptr<Node> op::Sum::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Sum::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
return make_shared<Sum>(new_args.at(0), m_reduction_axes); return make_shared<Sum>(new_args.at(0), new_args.at(1));
} }
void op::Sum::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas) void op::Sum::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
...@@ -39,5 +51,5 @@ void op::Sum::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& ...@@ -39,5 +51,5 @@ void op::Sum::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector&
auto x = get_argument(0); auto x = get_argument(0);
auto& x_shape = input(0).get_shape(); auto& x_shape = input(0).get_shape();
adjoints.add_delta(x, make_shared<op::Broadcast>(delta, x_shape, m_reduction_axes)); adjoints.add_delta(x, make_shared<op::Broadcast>(delta, x_shape, get_reduction_axes()));
} }
...@@ -74,11 +74,21 @@ namespace ngraph ...@@ -74,11 +74,21 @@ namespace ngraph
class Sum : public util::ArithmeticReduction class Sum : public util::ArithmeticReduction
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a summation operation.
Sum();
/// \brief Constructs a summation operation.
///
/// \param arg The tensor to be summed.
/// \param reduction_axes The axis positions (0-based) to be eliminated.
Sum(const Output<Node>& arg, const AxisSet& reduction_axes);
/// \brief Constructs a summation operation. /// \brief Constructs a summation operation.
/// ///
/// \param arg The tensor to be summed. /// \param arg The tensor to be summed.
/// \param reduction_axes The axis positions (0-based) to be eliminated. /// \param reduction_axes The axis positions (0-based) to be eliminated.
Sum(const std::shared_ptr<Node>& arg, const AxisSet& reduction_axes); Sum(const Output<Node>& arg, const Output<Node>& reduction_axes);
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
......
...@@ -17,26 +17,65 @@ ...@@ -17,26 +17,65 @@
#include <memory> #include <memory>
#include "ngraph/axis_vector.hpp" #include "ngraph/axis_vector.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/topk.hpp" #include "ngraph/op/topk.hpp"
#include "ngraph/shape.hpp" #include "ngraph/shape.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::TopK::TopK(const shared_ptr<Node>& arg, const string op::TopK::type_name{"TopK"};
op::TopK::TopK()
{
}
op::TopK::TopK(const Output<Node>& arg,
size_t top_k_axis, size_t top_k_axis,
const element::Type& index_element_type, const element::Type& index_element_type,
size_t k, size_t k,
bool compute_max) bool compute_max)
: Op("TopK", check_single_output_args({arg})) : Op({arg, op::Constant::create(element::i64, Shape{1}, {k})->output(0)})
, m_top_k_axis(top_k_axis)
, m_index_element_type(index_element_type)
, m_compute_max(compute_max)
{
constructor_validate_and_infer_types();
}
op::TopK::TopK(const Output<Node>& arg,
const Output<Node>& k,
size_t top_k_axis,
const element::Type& index_element_type,
bool compute_max)
: Op({arg, k})
, m_top_k_axis(top_k_axis) , m_top_k_axis(top_k_axis)
, m_index_element_type(index_element_type) , m_index_element_type(index_element_type)
, m_k(k)
, m_compute_max(compute_max) , m_compute_max(compute_max)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
size_t op::TopK::get_k() const
{
size_t k = 0;
if (auto const_op = dynamic_pointer_cast<op::Constant>(get_argument(1)))
{
k = const_op->get_vector<int64_t>()[0];
}
if (k == 0 && get_input_partial_shape(0).is_static())
{
k = get_input_partial_shape(0).to_shape()[m_top_k_axis];
}
return k;
}
void op::TopK::set_k(size_t k)
{
this->input(1).replace_source_output(
op::Constant::create(element::i64, Shape{1}, {k})->output(0));
}
void op::TopK::validate_and_infer_types() void op::TopK::validate_and_infer_types()
{ {
const PartialShape& input_shape = get_input_partial_shape(0); const PartialShape& input_shape = get_input_partial_shape(0);
...@@ -63,11 +102,12 @@ void op::TopK::validate_and_infer_types() ...@@ -63,11 +102,12 @@ void op::TopK::validate_and_infer_types()
m_top_k_axis, m_top_k_axis,
") is out of bounds."); ") is out of bounds.");
size_t k = get_k();
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
input_rank.is_dynamic() || input_shape[m_top_k_axis].is_dynamic() || input_rank.is_dynamic() || input_shape[m_top_k_axis].is_dynamic() ||
m_k <= static_cast<size_t>(input_shape[m_top_k_axis]), k <= static_cast<size_t>(input_shape[m_top_k_axis]),
"K (", "K (",
m_k, k,
") exceeds the dimension (", ") exceeds the dimension (",
(input_rank.is_static() ? input_shape[m_top_k_axis] : 0), (input_rank.is_static() ? input_shape[m_top_k_axis] : 0),
") of the TopK axis (axis ", ") of the TopK axis (axis ",
...@@ -76,16 +116,9 @@ void op::TopK::validate_and_infer_types() ...@@ -76,16 +116,9 @@ void op::TopK::validate_and_infer_types()
PartialShape output_shape{input_shape}; PartialShape output_shape{input_shape};
if (input_rank.is_static()) if (input_rank.is_static() && k != 0)
{ {
if (m_k != 0) output_shape[m_top_k_axis] = k;
{
output_shape[m_top_k_axis] = m_k;
}
else if (input_shape[m_top_k_axis].is_static())
{
m_k = static_cast<size_t>(input_shape[m_top_k_axis]);
}
} }
set_output_size(2); set_output_size(2);
...@@ -97,7 +130,7 @@ shared_ptr<Node> op::TopK::copy_with_new_args(const NodeVector& new_args) const ...@@ -97,7 +130,7 @@ shared_ptr<Node> op::TopK::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
return make_shared<TopK>( return make_shared<TopK>(
new_args.at(0), m_top_k_axis, m_index_element_type, m_k, m_compute_max); new_args.at(0), new_args.at(1), m_top_k_axis, m_index_element_type, m_compute_max);
} }
void op::TopK::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas) void op::TopK::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
......
...@@ -30,6 +30,11 @@ namespace ngraph ...@@ -30,6 +30,11 @@ namespace ngraph
class TopK : public Op class TopK : public Op
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a TopK operation
TopK();
/// \brief Constructs a TopK operation. /// \brief Constructs a TopK operation.
/// ///
/// \param arg The input tensor /// \param arg The input tensor
...@@ -37,25 +42,38 @@ namespace ngraph ...@@ -37,25 +42,38 @@ namespace ngraph
/// \param index_element_type produce indices. Currently, only int64 or int32 are supported /// \param index_element_type produce indices. Currently, only int64 or int32 are supported
/// \param k Number of top indices to compute. Compute all indices if k = 0 /// \param k Number of top indices to compute. Compute all indices if k = 0
/// \param compute_max Compute top k max or top k min? /// \param compute_max Compute top k max or top k min?
TopK(const std::shared_ptr<Node>& arg, TopK(const Output<Node>& arg,
size_t top_k_axis, size_t top_k_axis,
const element::Type& index_element_type, const element::Type& index_element_type,
size_t k = 0, size_t k = 0,
bool compute_max = true); bool compute_max = true);
/// \brief Constructs a TopK operation.
///
/// \param arg The input tensor
/// \param k Number of top indices to compute. Compute all indices if k = 0
/// \param top_k_axis The axis along which to compute top k indices
/// \param index_element_type produce indices. Currently, only int64 or int32 are supported
/// \param compute_max Compute top k max or top k min?
TopK(const Output<Node>& arg,
const Output<Node>& k,
size_t top_k_axis,
const element::Type& index_element_type,
bool compute_max = true);
void validate_and_infer_types() override; void validate_and_infer_types() override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
size_t get_k() const;
void set_k(size_t k);
size_t get_top_k_axis() const { return m_top_k_axis; } size_t get_top_k_axis() const { return m_top_k_axis; }
element::Type get_index_element_type() const { return m_index_element_type; } element::Type get_index_element_type() const { return m_index_element_type; }
size_t get_k() const { return m_k; }
bool get_compute_max() const { return m_compute_max; } bool get_compute_max() const { return m_compute_max; }
protected: protected:
size_t m_top_k_axis; size_t m_top_k_axis;
element::Type m_index_element_type; element::Type m_index_element_type;
size_t m_k;
bool m_compute_max; bool m_compute_max;
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override; const NodeVector& deltas) override;
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
//***************************************************************************** //*****************************************************************************
#include "ngraph/op/util/arithmetic_reduction.hpp" #include "ngraph/op/util/arithmetic_reduction.hpp"
#include "ngraph/op/constant.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
...@@ -23,28 +24,41 @@ op::util::ArithmeticReduction::ArithmeticReduction() ...@@ -23,28 +24,41 @@ op::util::ArithmeticReduction::ArithmeticReduction()
{ {
} }
op::util::ArithmeticReduction::ArithmeticReduction(const std::shared_ptr<Node>& arg, op::util::ArithmeticReduction::ArithmeticReduction(const Output<Node>& arg,
const AxisSet& reduction_axes) const AxisSet& reduction_axes)
: Op(check_single_output_args({arg})) : Op({arg,
op::Constant::create(
element::i64, Shape{reduction_axes.size()}, reduction_axes.to_vector())
->output(0)})
{ {
set_reduction_axes(reduction_axes);
} }
op::util::ArithmeticReduction::ArithmeticReduction(const std::string& node_type, op::util::ArithmeticReduction::ArithmeticReduction(const Output<Node>& arg,
const std::shared_ptr<Node>& arg, const Output<Node>& reduction_axes)
const AxisSet& reduction_axes) : Op({arg, reduction_axes})
: Op(node_type, check_single_output_args({arg})) {
}
const AxisSet op::util::ArithmeticReduction::get_reduction_axes() const
{ {
set_reduction_axes(reduction_axes); AxisSet axes;
if (auto const_op = dynamic_pointer_cast<op::Constant>(get_argument(1)))
{
axes = const_op->get_axis_set_val();
}
return axes;
} }
void op::util::ArithmeticReduction::set_reduction_axes(const AxisSet& reduction_axes) void op::util::ArithmeticReduction::set_reduction_axes(const AxisSet& reduction_axes)
{ {
m_reduction_axes = reduction_axes; this->input(1).replace_source_output(
op::Constant::create(element::i64, Shape{reduction_axes.size()}, reduction_axes.to_vector())
->output(0));
} }
void op::util::ArithmeticReduction::validate_and_infer_types() void op::util::ArithmeticReduction::validate_and_infer_types()
{ {
auto reduction_axes = get_reduction_axes();
auto input_shape = get_input_partial_shape(0); auto input_shape = get_input_partial_shape(0);
auto input_rank = input_shape.rank(); auto input_rank = input_shape.rank();
...@@ -54,7 +68,7 @@ void op::util::ArithmeticReduction::validate_and_infer_types() ...@@ -54,7 +68,7 @@ void op::util::ArithmeticReduction::validate_and_infer_types()
{ {
std::vector<Dimension> dims; std::vector<Dimension> dims;
for (auto axis : m_reduction_axes) for (auto axis : reduction_axes)
{ {
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
axis < size_t(input_rank), axis < size_t(input_rank),
...@@ -64,13 +78,13 @@ void op::util::ArithmeticReduction::validate_and_infer_types() ...@@ -64,13 +78,13 @@ void op::util::ArithmeticReduction::validate_and_infer_types()
"(argument shape: ", "(argument shape: ",
input_shape, input_shape,
", reduction axes: ", ", reduction axes: ",
m_reduction_axes, reduction_axes,
")"); ")");
} }
for (size_t i = 0; i < size_t(input_rank); i++) for (size_t i = 0; i < size_t(input_rank); i++)
{ {
if (m_reduction_axes.count(i) == 0) if (reduction_axes.count(i) == 0)
{ {
dims.push_back(input_shape[i]); dims.push_back(input_shape[i]);
} }
......
...@@ -37,32 +37,19 @@ namespace ngraph ...@@ -37,32 +37,19 @@ namespace ngraph
/// \param arg Output that produces the first input tensor. /// \param arg Output that produces the first input tensor.
/// \param reduction_axes The axis positions (0-based) to be eliminated. /// \param reduction_axes The axis positions (0-based) to be eliminated.
ArithmeticReduction(const Output<Node>& arg, const AxisSet& reduction_axes); ArithmeticReduction(const Output<Node>& arg, const AxisSet& reduction_axes);
/// \brief Constructs an arithmetic reduction operation.
///
/// \param arg Node that produces the first input tensor.
/// \param reduction_axes The axis positions (0-based) to be eliminated.
ArithmeticReduction(const std::shared_ptr<Node>& arg,
const AxisSet& reduction_axes);
/// \brief Constructs an arithmetic reduction operation. /// \brief Constructs an arithmetic reduction operation.
/// ///
/// \param arg Node that produces the first input tensor. /// \param arg Output that produces the first input tensor.
/// \param reduction_axes The axis positions (0-based) to be eliminated. /// \param reduction_axes The axis positions (0-based) to be eliminated.
ArithmeticReduction(const std::string& node_type, ArithmeticReduction(const Output<Node>& arg, const Output<Node>& reduction_axes);
const std::shared_ptr<Node>& arg,
const AxisSet& reduction_axes);
public: public:
void validate_and_infer_types() override; void validate_and_infer_types() override;
/// \return The axis positions (0-based) to be eliminated through reduction. /// \return The axis positions (0-based) to be eliminated through reduction.
const AxisSet& get_reduction_axes() const { return m_reduction_axes; } const AxisSet get_reduction_axes() const;
/// \brief Change the reduction axes /// \brief Change the reduction axes
void set_reduction_axes(const AxisSet& reduction_axes); void set_reduction_axes(const AxisSet& reduction_axes);
protected:
AxisSet m_reduction_axes;
}; };
} }
} }
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
//***************************************************************************** //*****************************************************************************
#include "ngraph/op/util/logical_reduction.hpp" #include "ngraph/op/util/logical_reduction.hpp"
#include "ngraph/op/constant.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
...@@ -24,38 +25,39 @@ op::util::LogicalReduction::LogicalReduction() ...@@ -24,38 +25,39 @@ op::util::LogicalReduction::LogicalReduction()
} }
op::util::LogicalReduction::LogicalReduction(const Output<Node>& arg, const AxisSet& reduction_axes) op::util::LogicalReduction::LogicalReduction(const Output<Node>& arg, const AxisSet& reduction_axes)
: Op({arg}) : Op({arg,
op::Constant::create(
element::i64, Shape{reduction_axes.size()}, reduction_axes.to_vector())
->output(0)})
{ {
set_reduction_axes(reduction_axes);
} }
op::util::LogicalReduction::LogicalReduction(const std::shared_ptr<Node>& arg, op::util::LogicalReduction::LogicalReduction(const Output<Node>& arg,
const AxisSet& reduction_axes) const Output<Node>& reduction_axes)
: Op(check_single_output_args({arg})) : Op({arg, reduction_axes})
{ {
set_reduction_axes(reduction_axes);
} }
op::util::LogicalReduction::LogicalReduction(const std::string& node_type, const AxisSet op::util::LogicalReduction::get_reduction_axes() const
const std::shared_ptr<Node>& arg,
const AxisSet& reduction_axes)
: Op(node_type, check_single_output_args({arg}))
{ {
set_reduction_axes(reduction_axes); AxisSet axes;
} if (auto const_op = dynamic_pointer_cast<op::Constant>(get_argument(1)))
{
const AxisSet& op::util::LogicalReduction::get_reduction_axes() const axes = const_op->get_axis_set_val();
{ }
return m_reduction_axes; return axes;
} }
void op::util::LogicalReduction::set_reduction_axes(const AxisSet& reduction_axes) void op::util::LogicalReduction::set_reduction_axes(const AxisSet& reduction_axes)
{ {
m_reduction_axes = reduction_axes; this->input(1).replace_source_output(
op::Constant::create(element::i64, Shape{reduction_axes.size()}, reduction_axes.to_vector())
->output(0));
} }
void op::util::LogicalReduction::validate_and_infer_types() void op::util::LogicalReduction::validate_and_infer_types()
{ {
auto reduction_axes = get_reduction_axes();
auto input_shape = get_input_partial_shape(0); auto input_shape = get_input_partial_shape(0);
auto input_rank = input_shape.rank(); auto input_rank = input_shape.rank();
...@@ -65,7 +67,7 @@ void op::util::LogicalReduction::validate_and_infer_types() ...@@ -65,7 +67,7 @@ void op::util::LogicalReduction::validate_and_infer_types()
{ {
std::vector<Dimension> dims; std::vector<Dimension> dims;
for (auto axis : m_reduction_axes) for (auto axis : reduction_axes)
{ {
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
axis < size_t(input_rank), axis < size_t(input_rank),
...@@ -75,13 +77,13 @@ void op::util::LogicalReduction::validate_and_infer_types() ...@@ -75,13 +77,13 @@ void op::util::LogicalReduction::validate_and_infer_types()
"(argument shape: ", "(argument shape: ",
input_shape, input_shape,
", reduction axes: ", ", reduction axes: ",
m_reduction_axes, reduction_axes,
")"); ")");
} }
for (size_t i = 0; i < size_t(input_rank); i++) for (size_t i = 0; i < size_t(input_rank); i++)
{ {
if (m_reduction_axes.count(i) == 0) if (reduction_axes.count(i) == 0)
{ {
dims.push_back(input_shape[i]); dims.push_back(input_shape[i]);
} }
......
...@@ -36,28 +36,18 @@ namespace ngraph ...@@ -36,28 +36,18 @@ namespace ngraph
/// \param arg Output that produces the first input tensor. /// \param arg Output that produces the first input tensor.
/// \param reduction_axes The axis positions (0-based) to be eliminated. /// \param reduction_axes The axis positions (0-based) to be eliminated.
LogicalReduction(const Output<Node>& arg, const AxisSet& reduction_axes); LogicalReduction(const Output<Node>& arg, const AxisSet& reduction_axes);
/// \brief Constructs a logical reduction operation. /// \brief Constructs a 'dynamic' logical reduction operation.
/// ///
/// \param arg Node that produces the first input tensor. /// \param arg Node that produces the first input tensor.
/// \param reduction_axes The axis positions (0-based) to be eliminated. /// \param reduction_axes The axis positions (0-based) to be eliminated.
LogicalReduction(const std::shared_ptr<Node>& arg, const AxisSet& reduction_axes); LogicalReduction(const Output<Node>& arg, const Output<Node>& reduction_axes);
/// \brief Constructs a logical reduction operation.
///
/// \param arg Node that produces the first input tensor.
/// \param reduction_axes The axis positions (0-based) to be eliminated.
LogicalReduction(const std::string& node_type,
const std::shared_ptr<Node>& arg,
const AxisSet& reduction_axes);
public: public:
void validate_and_infer_types() override; void validate_and_infer_types() override;
/// \return The axis positions (0-based) to be eliminated through reduction. /// \return The axis positions (0-based) to be eliminated through reduction.
const AxisSet& get_reduction_axes() const; const AxisSet get_reduction_axes() const;
void set_reduction_axes(const AxisSet& reduction_axes); void set_reduction_axes(const AxisSet& reduction_axes);
protected:
AxisSet m_reduction_axes;
}; };
} }
} }
......
...@@ -250,7 +250,11 @@ static void materialize_shapes(shared_ptr<Node> n, ...@@ -250,7 +250,11 @@ static void materialize_shapes(shared_ptr<Node> n,
NGRAPH_DEBUG << "Materializing " << describe_reshape(reorders.at(arg)) << " for " NGRAPH_DEBUG << "Materializing " << describe_reshape(reorders.at(arg)) << " for "
<< arg->get_name(); << arg->get_name();
mark_reshape_for_deletion(reorders.at(arg), reshapes_to_delete); mark_reshape_for_deletion(reorders.at(arg), reshapes_to_delete);
insert_reshape(n, reorders.at(arg), i); if (reorders.at(arg)->get_input_order() != get_default_order(arg->get_shape()))
{
// Insert if arg needs to be transposed.
insert_reshape(n, reorders.at(arg), i);
}
//no swimming up //no swimming up
} }
} }
......
...@@ -44,7 +44,7 @@ static bool verify_no_internal_zero_length_ops(shared_ptr<Function> f) ...@@ -44,7 +44,7 @@ static bool verify_no_internal_zero_length_ops(shared_ptr<Function> f)
set<Output<Node>> zero_length_source_outputs; set<Output<Node>> zero_length_source_outputs;
for (auto n : f->get_ordered_ops()) for (auto n : f->get_ordered_ops())
{ {
if (n->is_output() || n->is_parameter() || n->get_output_size() > 1) if (n->is_output() || n->is_parameter() || n->is_constant() || n->get_output_size() > 1)
{ {
continue; continue;
} }
......
...@@ -78,6 +78,8 @@ any_2x2x3_eliminate_dims_0_1 ...@@ -78,6 +78,8 @@ any_2x2x3_eliminate_dims_0_1
any_2x2x3_eliminate_dims_0_2 any_2x2x3_eliminate_dims_0_2
any_2x2x3_eliminate_dims_1_2 any_2x2x3_eliminate_dims_1_2
any_2x2x3_eliminate_dims_0_1_2 any_2x2x3_eliminate_dims_0_1_2
all_dynamic_axis
all_change_axis
all_trivial all_trivial
all_2x2_to_scalar_false all_2x2_to_scalar_false
all_2x2_to_scalar_true all_2x2_to_scalar_true
......
...@@ -178,11 +178,12 @@ void runtime::hybrid::rewrite_function(const shared_ptr<Function>& f, ...@@ -178,11 +178,12 @@ void runtime::hybrid::rewrite_function(const shared_ptr<Function>& f,
{ {
for (auto input : node->get_arguments()) for (auto input : node->get_arguments())
{ {
if (input->get_placement_index() == 0) if (input->get_placement_index() == 0 && !input->is_constant())
{ {
// Since this input is from outside the cluster we need to create // Since this input is from outside the cluster we need to create
// a new Parameter node placed in the cluster instead of this external // a new Parameter node placed in the cluster instead of this external
// node // node. Constant nodes are ignored here since the values are available
// in the graph.
std::vector<Output<Node>> source_outputs = std::vector<Output<Node>> source_outputs =
get_outputs_to(*input, *node); get_outputs_to(*input, *node);
NGRAPH_CHECK( NGRAPH_CHECK(
......
...@@ -1053,7 +1053,7 @@ shared_ptr<runtime::Executable> ...@@ -1053,7 +1053,7 @@ shared_ptr<runtime::Executable>
} }
case OP_TYPEID::Sum: case OP_TYPEID::Sum:
{ {
arguments_check(op, 1, 1); arguments_check(op, 2, 1);
const shared_ptr<op::Sum> sum = static_pointer_cast<op::Sum>(op); const shared_ptr<op::Sum> sum = static_pointer_cast<op::Sum>(op);
const AxisSet& axis = sum->get_reduction_axes(); const AxisSet& axis = sum->get_reduction_axes();
...@@ -1071,7 +1071,7 @@ shared_ptr<runtime::Executable> ...@@ -1071,7 +1071,7 @@ shared_ptr<runtime::Executable>
} }
case OP_TYPEID::Product: case OP_TYPEID::Product:
{ {
arguments_check(op, 1, 1); arguments_check(op, 2, 1);
const shared_ptr<op::Product> prod = static_pointer_cast<op::Product>(op); const shared_ptr<op::Product> prod = static_pointer_cast<op::Product>(op);
const AxisSet& axis = prod->get_reduction_axes(); const AxisSet& axis = prod->get_reduction_axes();
...@@ -1142,7 +1142,7 @@ shared_ptr<runtime::Executable> ...@@ -1142,7 +1142,7 @@ shared_ptr<runtime::Executable>
} }
case OP_TYPEID::All: case OP_TYPEID::All:
{ {
arguments_check(op, 1, 1); arguments_check(op, 2, 1);
// Empty axis is not a case for do_equal_propagation() // Empty axis is not a case for do_equal_propagation()
kern.emit<op::All>(static_pointer_cast<op::All>(op)); kern.emit<op::All>(static_pointer_cast<op::All>(op));
...@@ -1150,7 +1150,7 @@ shared_ptr<runtime::Executable> ...@@ -1150,7 +1150,7 @@ shared_ptr<runtime::Executable>
} }
case OP_TYPEID::Any: case OP_TYPEID::Any:
{ {
arguments_check(op, 1, 1); arguments_check(op, 2, 1);
// Empty axis is not a case for do_equal_propagation() // Empty axis is not a case for do_equal_propagation()
kern.emit<op::Any>(static_pointer_cast<op::Any>(op)); kern.emit<op::Any>(static_pointer_cast<op::Any>(op));
...@@ -1850,14 +1850,14 @@ shared_ptr<runtime::Executable> ...@@ -1850,14 +1850,14 @@ shared_ptr<runtime::Executable>
} }
case OP_TYPEID::Min: case OP_TYPEID::Min:
{ {
arguments_check(op, 1, 1); arguments_check(op, 2, 1);
kern.emit<op::Min>(static_pointer_cast<op::Min>(op)); kern.emit<op::Min>(static_pointer_cast<op::Min>(op));
break; break;
} }
case OP_TYPEID::Max: case OP_TYPEID::Max:
{ {
arguments_check(op, 1, 1); arguments_check(op, 2, 1);
kern.emit<op::Max>(static_pointer_cast<op::Max>(op)); kern.emit<op::Max>(static_pointer_cast<op::Max>(op));
break; break;
...@@ -2009,7 +2009,7 @@ shared_ptr<runtime::Executable> ...@@ -2009,7 +2009,7 @@ shared_ptr<runtime::Executable>
} }
case OP_TYPEID::TopK: case OP_TYPEID::TopK:
{ {
arguments_check(op, 1, 2); arguments_check(op, 2, 2);
const shared_ptr<op::TopK> topk_op = static_pointer_cast<op::TopK>(op); const shared_ptr<op::TopK> topk_op = static_pointer_cast<op::TopK>(op);
......
...@@ -274,3 +274,45 @@ NGRAPH_TEST(${BACKEND_NAME}, all_2x2x3_eliminate_dims_0_1_2) ...@@ -274,3 +274,45 @@ NGRAPH_TEST(${BACKEND_NAME}, all_2x2x3_eliminate_dims_0_1_2)
handle->call_with_validate({result}, {a}); handle->call_with_validate({result}, {a});
EXPECT_EQ((vector<char>{0}), read_vector<char>(result)); EXPECT_EQ((vector<char>{0}), read_vector<char>(result));
} }
NGRAPH_TEST(${BACKEND_NAME}, all_dynamic_axis)
{
Shape shape{2, 3};
auto A = make_shared<op::Parameter>(element::boolean, shape);
auto B = op::Constant::create(element::i64, Shape{1}, {1});
auto f = make_shared<Function>(make_shared<op::All>(A, B), ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::boolean, shape);
copy_data(a, test::NDArray<char, 2>({{1, 0, 1}, {1, 1, 1}}).get_vector());
auto result = backend->create_tensor(element::boolean, Shape{2});
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a});
EXPECT_EQ((vector<char>{0, 1}), read_vector<char>(result));
}
NGRAPH_TEST(${BACKEND_NAME}, all_change_axis)
{
Shape shape{2, 3};
auto A = make_shared<op::Parameter>(element::boolean, shape);
auto B = op::Constant::create(element::i64, Shape{1}, {1});
auto all = make_shared<op::All>(A, B);
ASSERT_EQ(all->get_reduction_axes(), AxisSet{1});
auto f = make_shared<Function>(all, ParameterVector{A});
all->set_reduction_axes(AxisSet{0});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::boolean, shape);
copy_data(a, test::NDArray<char, 2>({{1, 0, 1}, {1, 1, 1}}).get_vector());
auto result = backend->create_tensor(element::boolean, Shape{3});
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a});
EXPECT_EQ((vector<char>{1, 0, 1}), read_vector<char>(result));
}
...@@ -344,9 +344,9 @@ TEST(copy, sum) ...@@ -344,9 +344,9 @@ TEST(copy, sum)
Shape shape{4, 3}; Shape shape{4, 3};
AxisSet axes{1}; AxisSet axes{1};
auto arg0 = make_shared<op::Parameter>(element::f32, shape); auto arg0 = make_shared<op::Parameter>(element::f32, shape);
NodeVector new_args{make_shared<op::Parameter>(element::f32, shape)};
auto node = make_shared<op::Sum>(arg0, axes); auto node = make_shared<op::Sum>(arg0, axes);
NodeVector new_args{make_shared<op::Parameter>(element::f32, shape), node->get_argument(1)};
auto new_node = node->copy_with_new_args(new_args); auto new_node = node->copy_with_new_args(new_args);
auto node_cast = dynamic_pointer_cast<op::Sum>(new_node); auto node_cast = dynamic_pointer_cast<op::Sum>(new_node);
ASSERT_NE(node_cast, nullptr); ASSERT_NE(node_cast, nullptr);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment