Commit 9ba4a78a authored by Jayaram Bobba, committed by Scott Cyphers

Change reduction operations to 2-input dynamic variants (#2972)

* Change reduction operations to 2-input dynamic variants, with convenience constructors for cases where the reduction AxisSet is known at op construction time

* Modify the rest of the arithmetic and logical reduction ops to 2-input dynamic variants. Includes fixes to existing passes to keep constant reduction-axes inputs intact

* Add the new All tests to the GPU manifest
parent ff5d79ca
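
To make the two construction styles concrete, here is a minimal usage sketch. It is illustrative only (variable names are hypothetical; it assumes ngraph/ngraph.hpp and the using-directives seen in the sources below), mirroring the constructors and tests in this diff:

// Convenience form: reduction axes known at construction time. The base
// class wraps them in an i64 Constant and wires that up as input 1.
auto data = make_shared<op::Parameter>(element::f32, Shape{2, 3});
auto sum_static = make_shared<op::Sum>(data, AxisSet{1});

// Dynamic form: the axes arrive as a second graph input. Here it is a
// Constant, but any Output<Node> yielding an i64 vector is accepted.
auto axes = op::Constant::create(element::i64, Shape{1}, {1});
auto sum_dynamic = make_shared<op::Sum>(data, axes);
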
@@ -59,6 +59,11 @@ namespace ngraph
static_cast<std::set<size_t>*>(this)->operator=(v);
return *this;
}
std::vector<int64_t> to_vector() const
{
return std::vector<int64_t>(this->begin(), this->end());
}
};
std::ostream& operator<<(std::ostream& s, const AxisSet& axis_set);
......
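
The new to_vector() helper is what lets the reduction base classes below turn an AxisSet directly into the payload of an i64 Constant. A small sketch of the intended use (hypothetical names; same pattern as ArithmeticReduction and LogicalReduction further down):

AxisSet axes{0, 2};
std::vector<int64_t> v = axes.to_vector(); // yields {0, 2}
auto axes_const = op::Constant::create(element::i64, Shape{v.size()}, v);
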
@@ -31,8 +31,14 @@ op::All::All(const Output<Node>& arg, const AxisSet& reduction_axes)
constructor_validate_and_infer_types();
}
op::All::All(const Output<Node>& arg, const Output<Node>& reduction_axes)
: LogicalReduction(arg, reduction_axes)
{
constructor_validate_and_infer_types();
}
shared_ptr<Node> op::All::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
- return make_shared<All>(new_args.at(0), m_reduction_axes);
+ return make_shared<All>(new_args.at(0), new_args.at(1));
}
@@ -39,6 +39,11 @@ namespace ngraph
/// \param arg The tensor to be reduced.
/// \param reduction_axes The axis positions (0-based) to be eliminated.
All(const Output<Node>& arg, const AxisSet& reduction_axes);
/// \brief Constructs an "all" reduction operation.
///
/// \param arg The tensor to be reduced.
/// \param reduction_axes The axis positions (0-based) to be eliminated.
All(const Output<Node>& arg, const Output<Node>& reduction_axes);
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
......
@@ -31,8 +31,14 @@ op::Any::Any(const Output<Node>& arg, const AxisSet& reduction_axes)
constructor_validate_and_infer_types();
}
op::Any::Any(const Output<Node>& arg, const Output<Node>& reduction_axes)
: LogicalReduction(arg, reduction_axes)
{
constructor_validate_and_infer_types();
}
shared_ptr<Node> op::Any::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
- return make_shared<Any>(new_args.at(0), m_reduction_axes);
+ return make_shared<Any>(new_args.at(0), new_args.at(1));
}
@@ -39,6 +39,11 @@ namespace ngraph
/// \param arg The tensor to be reduced.
/// \param reduction_axes The axis positions (0-based) to be eliminated.
Any(const Output<Node>& arg, const AxisSet& reduction_axes);
/// \brief Constructs an "any" reduction operation.
///
/// \param arg The tensor to be reduced.
/// \param reduction_axes The axis positions (0-based) to be eliminated.
Any(const Output<Node>& arg, const Output<Node>& reduction_axes);
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
......
@@ -20,8 +20,20 @@
using namespace std;
using namespace ngraph;
- op::Max::Max(const shared_ptr<Node>& arg, const AxisSet& reduction_axes)
-     : ArithmeticReduction("Max", arg, reduction_axes)
+ const string op::Max::type_name{"Max"};
op::Max::Max()
{
}
op::Max::Max(const Output<Node>& arg, const AxisSet& reduction_axes)
: ArithmeticReduction(arg, reduction_axes)
{
constructor_validate_and_infer_types();
}
op::Max::Max(const Output<Node>& arg, const Output<Node>& reduction_axes)
: ArithmeticReduction(arg, reduction_axes)
{
constructor_validate_and_infer_types();
}
@@ -29,7 +41,7 @@ op::Max::Max(const shared_ptr<Node>& arg, const AxisSet& reduction_axes)
shared_ptr<Node> op::Max::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
- return make_shared<Max>(new_args.at(0), m_reduction_axes);
+ return make_shared<Max>(new_args.at(0), new_args.at(1));
}
shared_ptr<Node> op::Max::get_default_value() const
......
@@ -26,11 +26,21 @@ namespace ngraph
class Max : public util::ArithmeticReduction
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a "max" reduction operation.
Max();
/// \brief Constructs a max-reduction operation.
///
/// \param arg The tensor to be reduced.
/// \param reduction_axes The axis positions (0-based) to be eliminated.
- Max(const std::shared_ptr<Node>& arg, const AxisSet& reduction_axes);
+ Max(const Output<Node>& arg, const AxisSet& reduction_axes);
/// \brief Constructs a "max" reduction operation.
///
/// \param arg The tensor to be reduced.
/// \param reduction_axes The axis positions (0-based) to be eliminated.
Max(const Output<Node>& arg, const Output<Node>& reduction_axes);
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
......
@@ -20,8 +20,20 @@
using namespace std;
using namespace ngraph;
- op::Min::Min(const shared_ptr<Node>& arg, const AxisSet& reduction_axes)
-     : ArithmeticReduction("Min", arg, reduction_axes)
+ const string op::Min::type_name{"Min"};
op::Min::Min()
{
}
op::Min::Min(const Output<Node>& arg, const AxisSet& reduction_axes)
: ArithmeticReduction(arg, reduction_axes)
{
constructor_validate_and_infer_types();
}
op::Min::Min(const Output<Node>& arg, const Output<Node>& reduction_axes)
: ArithmeticReduction(arg, reduction_axes)
{
constructor_validate_and_infer_types();
}
@@ -29,7 +41,7 @@ op::Min::Min(const shared_ptr<Node>& arg, const AxisSet& reduction_axes)
shared_ptr<Node> op::Min::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
- return make_shared<Min>(new_args.at(0), m_reduction_axes);
+ return make_shared<Min>(new_args.at(0), get_reduction_axes());
}
shared_ptr<Node> op::Min::get_default_value() const
......
@@ -26,11 +26,21 @@ namespace ngraph
class Min : public util::ArithmeticReduction
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a "min" reduction operation.
Min();
/// \brief Constructs a min-reduction operation.
///
/// \param arg The tensor to be reduced.
/// \param reduction_axes The axis positions (0-based) to be eliminated.
- Min(const std::shared_ptr<Node>& arg, const AxisSet& reduction_axes);
+ Min(const Output<Node>& arg, const AxisSet& reduction_axes);
/// \brief Constructs a "min" reduction operation.
///
/// \param arg The tensor to be reduced.
/// \param reduction_axes The axis positions (0-based) to be eliminated.
Min(const Output<Node>& arg, const Output<Node>& reduction_axes);
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
......
@@ -19,8 +19,20 @@
using namespace std;
using namespace ngraph;
- op::Product::Product(const shared_ptr<Node>& arg, const AxisSet& reduction_axes)
-     : ArithmeticReduction("Product", arg, reduction_axes)
+ const string op::Product::type_name{"Product"};
op::Product::Product()
{
}
op::Product::Product(const Output<Node>& arg, const AxisSet& reduction_axes)
: ArithmeticReduction(arg, reduction_axes)
{
constructor_validate_and_infer_types();
}
op::Product::Product(const Output<Node>& arg, const Output<Node>& reduction_axes)
: ArithmeticReduction(arg, reduction_axes)
{
constructor_validate_and_infer_types();
}
@@ -28,5 +40,5 @@ op::Product::Product(const shared_ptr<Node>& arg, const AxisSet& reduction_axes)
shared_ptr<Node> op::Product::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
- return make_shared<Product>(new_args.at(0), m_reduction_axes);
+ return make_shared<Product>(new_args.at(0), get_reduction_axes());
}
@@ -29,11 +29,21 @@ namespace ngraph
class Product : public util::ArithmeticReduction
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a product reduction operation.
Product();
/// \brief Constructs a product reduction operation.
///
/// \param arg The tensor to be reduced.
/// \param reduction_axes The axis positions (0-based) to be eliminated.
Product(const Output<Node>& arg, const AxisSet& reduction_axes);
/// \brief Constructs a product reduction operation.
///
/// \param arg The tensor to be reduced.
/// \param reduction_axes The axis positions (0-based) to be eliminated.
- Product(const std::shared_ptr<Node>& arg, const AxisSet& reduction_axes);
+ Product(const Output<Node>& arg, const Output<Node>& reduction_axes);
/// \return The default value for Product.
virtual std::shared_ptr<Node> get_default_value() const override
......
@@ -20,8 +20,20 @@
using namespace std;
using namespace ngraph;
- op::Sum::Sum(const shared_ptr<Node>& arg, const AxisSet& reduction_axes)
-     : ArithmeticReduction("Sum", arg, reduction_axes)
+ const string op::Sum::type_name{"Sum"};
op::Sum::Sum()
{
}
op::Sum::Sum(const Output<Node>& arg, const AxisSet& reduction_axes)
: ArithmeticReduction(arg, reduction_axes)
{
constructor_validate_and_infer_types();
}
op::Sum::Sum(const Output<Node>& arg, const Output<Node>& reduction_axes)
: ArithmeticReduction(arg, reduction_axes)
{
constructor_validate_and_infer_types();
}
@@ -29,7 +41,7 @@ op::Sum::Sum(const shared_ptr<Node>& arg, const AxisSet& reduction_axes)
shared_ptr<Node> op::Sum::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
- return make_shared<Sum>(new_args.at(0), m_reduction_axes);
+ return make_shared<Sum>(new_args.at(0), new_args.at(1));
}
void op::Sum::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
@@ -39,5 +51,5 @@ void op::Sum::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector&
auto x = get_argument(0);
auto& x_shape = input(0).get_shape();
- adjoints.add_delta(x, make_shared<op::Broadcast>(delta, x_shape, m_reduction_axes));
+ adjoints.add_delta(x, make_shared<op::Broadcast>(delta, x_shape, get_reduction_axes()));
}
@@ -74,11 +74,21 @@ namespace ngraph
class Sum : public util::ArithmeticReduction
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a summation operation.
Sum();
/// \brief Constructs a summation operation.
///
/// \param arg The tensor to be summed.
/// \param reduction_axes The axis positions (0-based) to be eliminated.
Sum(const Output<Node>& arg, const AxisSet& reduction_axes);
/// \brief Constructs a summation operation.
///
/// \param arg The tensor to be summed.
/// \param reduction_axes The axis positions (0-based) to be eliminated.
- Sum(const std::shared_ptr<Node>& arg, const AxisSet& reduction_axes);
+ Sum(const Output<Node>& arg, const Output<Node>& reduction_axes);
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
......
@@ -17,26 +17,65 @@
#include <memory>
#include "ngraph/axis_vector.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/topk.hpp"
#include "ngraph/shape.hpp"
using namespace std;
using namespace ngraph;
- op::TopK::TopK(const shared_ptr<Node>& arg,
+ const string op::TopK::type_name{"TopK"};
op::TopK::TopK()
{
}
op::TopK::TopK(const Output<Node>& arg,
size_t top_k_axis,
const element::Type& index_element_type,
size_t k,
bool compute_max)
: Op("TopK", check_single_output_args({arg}))
: Op({arg, op::Constant::create(element::i64, Shape{1}, {k})->output(0)})
, m_top_k_axis(top_k_axis)
, m_index_element_type(index_element_type)
, m_compute_max(compute_max)
{
constructor_validate_and_infer_types();
}
op::TopK::TopK(const Output<Node>& arg,
const Output<Node>& k,
size_t top_k_axis,
const element::Type& index_element_type,
bool compute_max)
: Op({arg, k})
, m_top_k_axis(top_k_axis)
, m_index_element_type(index_element_type)
-     , m_k(k)
, m_compute_max(compute_max)
{
constructor_validate_and_infer_types();
}
size_t op::TopK::get_k() const
{
size_t k = 0;
if (auto const_op = dynamic_pointer_cast<op::Constant>(get_argument(1)))
{
k = const_op->get_vector<int64_t>()[0];
}
if (k == 0 && get_input_partial_shape(0).is_static())
{
k = get_input_partial_shape(0).to_shape()[m_top_k_axis];
}
return k;
}
void op::TopK::set_k(size_t k)
{
this->input(1).replace_source_output(
op::Constant::create(element::i64, Shape{1}, {k})->output(0));
}
void op::TopK::validate_and_infer_types()
{
const PartialShape& input_shape = get_input_partial_shape(0);
@@ -63,11 +102,12 @@ void op::TopK::validate_and_infer_types()
m_top_k_axis,
") is out of bounds.");
+ size_t k = get_k();
NODE_VALIDATION_CHECK(this,
input_rank.is_dynamic() || input_shape[m_top_k_axis].is_dynamic() ||
- m_k <= static_cast<size_t>(input_shape[m_top_k_axis]),
+ k <= static_cast<size_t>(input_shape[m_top_k_axis]),
"K (",
- m_k,
+ k,
") exceeds the dimension (",
(input_rank.is_static() ? input_shape[m_top_k_axis] : 0),
") of the TopK axis (axis ",
@@ -76,16 +116,9 @@ void op::TopK::validate_and_infer_types()
PartialShape output_shape{input_shape};
- if (input_rank.is_static())
+ if (input_rank.is_static() && k != 0)
{
- if (m_k != 0)
- {
- output_shape[m_top_k_axis] = m_k;
- }
- else if (input_shape[m_top_k_axis].is_static())
- {
- m_k = static_cast<size_t>(input_shape[m_top_k_axis]);
- }
+ output_shape[m_top_k_axis] = k;
}
set_output_size(2);
@@ -97,7 +130,7 @@ shared_ptr<Node> op::TopK::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<TopK>(
- new_args.at(0), m_top_k_axis, m_index_element_type, m_k, m_compute_max);
+ new_args.at(0), new_args.at(1), m_top_k_axis, m_index_element_type, m_compute_max);
}
void op::TopK::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
......
@@ -30,6 +30,11 @@ namespace ngraph
class TopK : public Op
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a TopK operation
TopK();
/// \brief Constructs a TopK operation.
///
/// \param arg The input tensor
@@ -37,25 +42,38 @@ namespace ngraph
/// \param index_element_type produce indices. Currently, only int64 or int32 are supported
/// \param k Number of top indices to compute. Compute all indices if k = 0
/// \param compute_max Compute top k max or top k min?
- TopK(const std::shared_ptr<Node>& arg,
+ TopK(const Output<Node>& arg,
size_t top_k_axis,
const element::Type& index_element_type,
size_t k = 0,
bool compute_max = true);
/// \brief Constructs a TopK operation.
///
/// \param arg The input tensor
/// \param k Number of top indices to compute. Compute all indices if k = 0
/// \param top_k_axis The axis along which to compute top k indices
/// \param index_element_type produce indices. Currently, only int64 or int32 are supported
/// \param compute_max Compute top k max or top k min?
TopK(const Output<Node>& arg,
const Output<Node>& k,
size_t top_k_axis,
const element::Type& index_element_type,
bool compute_max = true);
void validate_and_infer_types() override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
+ size_t get_k() const;
+ void set_k(size_t k);
size_t get_top_k_axis() const { return m_top_k_axis; }
element::Type get_index_element_type() const { return m_index_element_type; }
- size_t get_k() const { return m_k; }
bool get_compute_max() const { return m_compute_max; }
protected:
size_t m_top_k_axis;
element::Type m_index_element_type;
- size_t m_k;
bool m_compute_max;
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
......
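
Taken together, these TopK changes move k from a baked-in attribute to input 1 of the op. A hedged usage sketch (hypothetical names; signatures as declared above; assumes ngraph/ngraph.hpp):

auto data = make_shared<op::Parameter>(element::f32, Shape{2, 5});
auto k = op::Constant::create(element::i64, Shape{1}, {3});
// Top 3 values along axis 1, producing i32 indices; compute_max defaults to true.
auto topk = make_shared<op::TopK>(data, k, 1, element::i32);
size_t current_k = topk->get_k(); // reads the Constant back: 3
topk->set_k(4);                   // splices a new Constant into input 1
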
@@ -15,6 +15,7 @@
//*****************************************************************************
#include "ngraph/op/util/arithmetic_reduction.hpp"
#include "ngraph/op/constant.hpp"
using namespace std;
using namespace ngraph;
@@ -23,28 +24,41 @@ op::util::ArithmeticReduction::ArithmeticReduction()
{
}
- op::util::ArithmeticReduction::ArithmeticReduction(const std::shared_ptr<Node>& arg,
+ op::util::ArithmeticReduction::ArithmeticReduction(const Output<Node>& arg,
const AxisSet& reduction_axes)
-     : Op(check_single_output_args({arg}))
+     : Op({arg,
+           op::Constant::create(
+               element::i64, Shape{reduction_axes.size()}, reduction_axes.to_vector())
+               ->output(0)})
{
set_reduction_axes(reduction_axes);
}
- op::util::ArithmeticReduction::ArithmeticReduction(const std::string& node_type,
-                                                    const std::shared_ptr<Node>& arg,
-                                                    const AxisSet& reduction_axes)
-     : Op(node_type, check_single_output_args({arg}))
- {
-     set_reduction_axes(reduction_axes);
- }
+ op::util::ArithmeticReduction::ArithmeticReduction(const Output<Node>& arg,
+                                                    const Output<Node>& reduction_axes)
+     : Op({arg, reduction_axes})
+ {
+ }

+ const AxisSet op::util::ArithmeticReduction::get_reduction_axes() const
+ {
+     AxisSet axes;
+     if (auto const_op = dynamic_pointer_cast<op::Constant>(get_argument(1)))
+     {
+         axes = const_op->get_axis_set_val();
+     }
+     return axes;
+ }
void op::util::ArithmeticReduction::set_reduction_axes(const AxisSet& reduction_axes)
{
- m_reduction_axes = reduction_axes;
+ this->input(1).replace_source_output(
+     op::Constant::create(element::i64, Shape{reduction_axes.size()}, reduction_axes.to_vector())
+         ->output(0));
}
void op::util::ArithmeticReduction::validate_and_infer_types()
{
+ auto reduction_axes = get_reduction_axes();
auto input_shape = get_input_partial_shape(0);
auto input_rank = input_shape.rank();
@@ -54,7 +68,7 @@ void op::util::ArithmeticReduction::validate_and_infer_types()
{
std::vector<Dimension> dims;
- for (auto axis : m_reduction_axes)
+ for (auto axis : reduction_axes)
{
NODE_VALIDATION_CHECK(this,
axis < size_t(input_rank),
@@ -64,13 +78,13 @@ void op::util::ArithmeticReduction::validate_and_infer_types()
"(argument shape: ",
input_shape,
", reduction axes: ",
- m_reduction_axes,
+ reduction_axes,
")");
}
for (size_t i = 0; i < size_t(input_rank); i++)
{
- if (m_reduction_axes.count(i) == 0)
+ if (reduction_axes.count(i) == 0)
{
dims.push_back(input_shape[i]);
}
......
@@ -37,32 +37,19 @@ namespace ngraph
/// \param arg Output that produces the first input tensor.
/// \param reduction_axes The axis positions (0-based) to be eliminated.
ArithmeticReduction(const Output<Node>& arg, const AxisSet& reduction_axes);
- /// \brief Constructs an arithmetic reduction operation.
- ///
- /// \param arg Node that produces the first input tensor.
- /// \param reduction_axes The axis positions (0-based) to be eliminated.
- ArithmeticReduction(const std::shared_ptr<Node>& arg,
-                     const AxisSet& reduction_axes);
/// \brief Constructs an arithmetic reduction operation.
///
- /// \param arg Node that produces the first input tensor.
+ /// \param arg Output that produces the first input tensor.
/// \param reduction_axes The axis positions (0-based) to be eliminated.
- ArithmeticReduction(const std::string& node_type,
-                     const std::shared_ptr<Node>& arg,
-                     const AxisSet& reduction_axes);
+ ArithmeticReduction(const Output<Node>& arg, const Output<Node>& reduction_axes);
public:
void validate_and_infer_types() override;
/// \return The axis positions (0-based) to be eliminated through reduction.
- const AxisSet& get_reduction_axes() const { return m_reduction_axes; }
+ const AxisSet get_reduction_axes() const;
/// \brief Change the reduction axes
void set_reduction_axes(const AxisSet& reduction_axes);
- protected:
-     AxisSet m_reduction_axes;
};
}
}
......
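
With m_reduction_axes gone, the axes round-trip through input 1 for every derived op. A brief sketch of the resulting accessor behaviour (hypothetical names, assuming a Sum built with the convenience constructor as above):

auto data = make_shared<op::Parameter>(element::f32, Shape{4, 3});
auto node = make_shared<op::Sum>(data, AxisSet{1});
AxisSet axes = node->get_reduction_axes(); // re-read from the Constant: {1}
node->set_reduction_axes(AxisSet{0});      // replaces input 1's source output
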
@@ -15,6 +15,7 @@
//*****************************************************************************
#include "ngraph/op/util/logical_reduction.hpp"
#include "ngraph/op/constant.hpp"
using namespace std;
using namespace ngraph;
@@ -24,38 +25,39 @@ op::util::LogicalReduction::LogicalReduction()
}
op::util::LogicalReduction::LogicalReduction(const Output<Node>& arg, const AxisSet& reduction_axes)
-     : Op({arg})
+     : Op({arg,
+           op::Constant::create(
+               element::i64, Shape{reduction_axes.size()}, reduction_axes.to_vector())
+               ->output(0)})
{
set_reduction_axes(reduction_axes);
}
- op::util::LogicalReduction::LogicalReduction(const std::shared_ptr<Node>& arg,
-                                              const AxisSet& reduction_axes)
-     : Op(check_single_output_args({arg}))
- {
-     set_reduction_axes(reduction_axes);
- }
+ op::util::LogicalReduction::LogicalReduction(const Output<Node>& arg,
+                                              const Output<Node>& reduction_axes)
+     : Op({arg, reduction_axes})
+ {
+ }

- op::util::LogicalReduction::LogicalReduction(const std::string& node_type,
-                                              const std::shared_ptr<Node>& arg,
-                                              const AxisSet& reduction_axes)
-     : Op(node_type, check_single_output_args({arg}))
- {
-     set_reduction_axes(reduction_axes);
- }

- const AxisSet& op::util::LogicalReduction::get_reduction_axes() const
- {
-     return m_reduction_axes;
- }
+ const AxisSet op::util::LogicalReduction::get_reduction_axes() const
+ {
+     AxisSet axes;
+     if (auto const_op = dynamic_pointer_cast<op::Constant>(get_argument(1)))
+     {
+         axes = const_op->get_axis_set_val();
+     }
+     return axes;
+ }
void op::util::LogicalReduction::set_reduction_axes(const AxisSet& reduction_axes)
{
- m_reduction_axes = reduction_axes;
+ this->input(1).replace_source_output(
+     op::Constant::create(element::i64, Shape{reduction_axes.size()}, reduction_axes.to_vector())
+         ->output(0));
}
void op::util::LogicalReduction::validate_and_infer_types()
{
+ auto reduction_axes = get_reduction_axes();
auto input_shape = get_input_partial_shape(0);
auto input_rank = input_shape.rank();
@@ -65,7 +67,7 @@ void op::util::LogicalReduction::validate_and_infer_types()
{
std::vector<Dimension> dims;
- for (auto axis : m_reduction_axes)
+ for (auto axis : reduction_axes)
{
NODE_VALIDATION_CHECK(this,
axis < size_t(input_rank),
@@ -75,13 +77,13 @@ void op::util::LogicalReduction::validate_and_infer_types()
"(argument shape: ",
input_shape,
", reduction axes: ",
- m_reduction_axes,
+ reduction_axes,
")");
}
for (size_t i = 0; i < size_t(input_rank); i++)
{
- if (m_reduction_axes.count(i) == 0)
+ if (reduction_axes.count(i) == 0)
{
dims.push_back(input_shape[i]);
}
......
@@ -36,28 +36,18 @@ namespace ngraph
/// \param arg Output that produces the first input tensor.
/// \param reduction_axes The axis positions (0-based) to be eliminated.
LogicalReduction(const Output<Node>& arg, const AxisSet& reduction_axes);
- /// \brief Constructs a logical reduction operation.
+ /// \brief Constructs a 'dynamic' logical reduction operation.
///
/// \param arg Node that produces the first input tensor.
/// \param reduction_axes The axis positions (0-based) to be eliminated.
- LogicalReduction(const std::shared_ptr<Node>& arg, const AxisSet& reduction_axes);
- /// \brief Constructs a logical reduction operation.
- ///
- /// \param arg Node that produces the first input tensor.
- /// \param reduction_axes The axis positions (0-based) to be eliminated.
- LogicalReduction(const std::string& node_type,
-                  const std::shared_ptr<Node>& arg,
-                  const AxisSet& reduction_axes);
+ LogicalReduction(const Output<Node>& arg, const Output<Node>& reduction_axes);
public:
void validate_and_infer_types() override;
/// \return The axis positions (0-based) to be eliminated through reduction.
- const AxisSet& get_reduction_axes() const;
+ const AxisSet get_reduction_axes() const;
void set_reduction_axes(const AxisSet& reduction_axes);
- protected:
-     AxisSet m_reduction_axes;
};
}
}
......
@@ -250,7 +250,11 @@ static void materialize_shapes(shared_ptr<Node> n,
NGRAPH_DEBUG << "Materializing " << describe_reshape(reorders.at(arg)) << " for "
<< arg->get_name();
mark_reshape_for_deletion(reorders.at(arg), reshapes_to_delete);
- insert_reshape(n, reorders.at(arg), i);
+ if (reorders.at(arg)->get_input_order() != get_default_order(arg->get_shape()))
+ {
+     // Insert if arg needs to be transposed.
+     insert_reshape(n, reorders.at(arg), i);
+ }
//no swimming up
}
}
......
@@ -44,7 +44,7 @@ static bool verify_no_internal_zero_length_ops(shared_ptr<Function> f)
set<Output<Node>> zero_length_source_outputs;
for (auto n : f->get_ordered_ops())
{
- if (n->is_output() || n->is_parameter() || n->get_output_size() > 1)
+ if (n->is_output() || n->is_parameter() || n->is_constant() || n->get_output_size() > 1)
{
continue;
}
......
@@ -78,6 +78,8 @@ any_2x2x3_eliminate_dims_0_1
any_2x2x3_eliminate_dims_0_2
any_2x2x3_eliminate_dims_1_2
any_2x2x3_eliminate_dims_0_1_2
+ all_dynamic_axis
+ all_change_axis
all_trivial
all_2x2_to_scalar_false
all_2x2_to_scalar_true
......
@@ -178,11 +178,12 @@ void runtime::hybrid::rewrite_function(const shared_ptr<Function>& f,
{
for (auto input : node->get_arguments())
{
- if (input->get_placement_index() == 0)
+ if (input->get_placement_index() == 0 && !input->is_constant())
{
// Since this input is from outside the cluster we need to create
// a new Parameter node placed in the cluster instead of this external
- // node
+ // node. Constant nodes are ignored here since the values are available
+ // in the graph.
std::vector<Output<Node>> source_outputs =
get_outputs_to(*input, *node);
NGRAPH_CHECK(
......
@@ -1053,7 +1053,7 @@ shared_ptr<runtime::Executable>
}
case OP_TYPEID::Sum:
{
- arguments_check(op, 1, 1);
+ arguments_check(op, 2, 1);
const shared_ptr<op::Sum> sum = static_pointer_cast<op::Sum>(op);
const AxisSet& axis = sum->get_reduction_axes();
@@ -1071,7 +1071,7 @@ shared_ptr<runtime::Executable>
}
case OP_TYPEID::Product:
{
- arguments_check(op, 1, 1);
+ arguments_check(op, 2, 1);
const shared_ptr<op::Product> prod = static_pointer_cast<op::Product>(op);
const AxisSet& axis = prod->get_reduction_axes();
@@ -1142,7 +1142,7 @@ shared_ptr<runtime::Executable>
}
case OP_TYPEID::All:
{
- arguments_check(op, 1, 1);
+ arguments_check(op, 2, 1);
// Empty axis is not a case for do_equal_propagation()
kern.emit<op::All>(static_pointer_cast<op::All>(op));
@@ -1150,7 +1150,7 @@ shared_ptr<runtime::Executable>
}
case OP_TYPEID::Any:
{
- arguments_check(op, 1, 1);
+ arguments_check(op, 2, 1);
// Empty axis is not a case for do_equal_propagation()
kern.emit<op::Any>(static_pointer_cast<op::Any>(op));
@@ -1850,14 +1850,14 @@ shared_ptr<runtime::Executable>
}
case OP_TYPEID::Min:
{
- arguments_check(op, 1, 1);
+ arguments_check(op, 2, 1);
kern.emit<op::Min>(static_pointer_cast<op::Min>(op));
break;
}
case OP_TYPEID::Max:
{
- arguments_check(op, 1, 1);
+ arguments_check(op, 2, 1);
kern.emit<op::Max>(static_pointer_cast<op::Max>(op));
break;
@@ -2009,7 +2009,7 @@ shared_ptr<runtime::Executable>
}
case OP_TYPEID::TopK:
{
- arguments_check(op, 1, 2);
+ arguments_check(op, 2, 2);
const shared_ptr<op::TopK> topk_op = static_pointer_cast<op::TopK>(op);
......
@@ -274,3 +274,45 @@ NGRAPH_TEST(${BACKEND_NAME}, all_2x2x3_eliminate_dims_0_1_2)
handle->call_with_validate({result}, {a});
EXPECT_EQ((vector<char>{0}), read_vector<char>(result));
}
NGRAPH_TEST(${BACKEND_NAME}, all_dynamic_axis)
{
Shape shape{2, 3};
auto A = make_shared<op::Parameter>(element::boolean, shape);
auto B = op::Constant::create(element::i64, Shape{1}, {1});
auto f = make_shared<Function>(make_shared<op::All>(A, B), ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::boolean, shape);
copy_data(a, test::NDArray<char, 2>({{1, 0, 1}, {1, 1, 1}}).get_vector());
auto result = backend->create_tensor(element::boolean, Shape{2});
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a});
EXPECT_EQ((vector<char>{0, 1}), read_vector<char>(result));
}
NGRAPH_TEST(${BACKEND_NAME}, all_change_axis)
{
Shape shape{2, 3};
auto A = make_shared<op::Parameter>(element::boolean, shape);
auto B = op::Constant::create(element::i64, Shape{1}, {1});
auto all = make_shared<op::All>(A, B);
ASSERT_EQ(all->get_reduction_axes(), AxisSet{1});
auto f = make_shared<Function>(all, ParameterVector{A});
all->set_reduction_axes(AxisSet{0});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::boolean, shape);
copy_data(a, test::NDArray<char, 2>({{1, 0, 1}, {1, 1, 1}}).get_vector());
auto result = backend->create_tensor(element::boolean, Shape{3});
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a});
EXPECT_EQ((vector<char>{1, 0, 1}), read_vector<char>(result));
}
@@ -344,9 +344,9 @@ TEST(copy, sum)
Shape shape{4, 3};
AxisSet axes{1};
auto arg0 = make_shared<op::Parameter>(element::f32, shape);
- NodeVector new_args{make_shared<op::Parameter>(element::f32, shape)};
auto node = make_shared<op::Sum>(arg0, axes);
+ NodeVector new_args{make_shared<op::Parameter>(element::f32, shape), node->get_argument(1)};
auto new_node = node->copy_with_new_args(new_args);
auto node_cast = dynamic_pointer_cast<op::Sum>(new_node);
ASSERT_NE(node_cast, nullptr);
......