Commit e741f8f1 authored by Jayaram Bobba's avatar Jayaram Bobba Committed by Scott Cyphers

[SPEC] Added support for v1 Broadcast op specification (#3737)

* - Added support for v1 Broadcast op specification
- Added upgrade/downgrade conversions between v0 and v1

* Added unit test for pdpd broadcast

* Make numpy default autobroadcast type and some style fixes

* Added support in Dynamic wrapper for dyn elimination and copied over unit tests from DynBroadcast

* Addressed PR feedback

* Addressed PR feedback on documentation
parent 4ab4609f
......@@ -15,16 +15,263 @@
//*****************************************************************************
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/partial_shape.hpp"
#include <numeric>
using namespace std;
using namespace ngraph;
constexpr NodeTypeInfo op::Broadcast::type_info;
constexpr NodeTypeInfo op::v1::Broadcast::type_info;
op::v1::Broadcast::Broadcast(const Output<Node>& arg,
const Output<Node>& target_shape,
const Output<Node>& axes_mapping,
const AutoBroadcastSpec& broadcast_spec)
: Op({arg, target_shape, axes_mapping})
, m_broadcast_spec(broadcast_spec)
{
constructor_validate_and_infer_types();
}
op::v1::Broadcast::Broadcast(const Output<Node>& arg,
const Output<Node>& target_shape,
const AutoBroadcastSpec& broadcast_spec)
: Op({arg, target_shape, op::Constant::create(element::u8, Shape{}, {0})->output(0)})
, m_broadcast_spec(broadcast_spec)
{
constructor_validate_and_infer_types();
}
std::pair<bool, AxisSet> op::v1::Broadcast::get_broadcast_axes() const
{
AxisSet broadcast_axes;
bool axes_known = false;
if (m_broadcast_spec.m_type == AutoBroadcastType::NONE)
{
if (input(1).get_partial_shape().is_static() &&
input_value(2).get_node_shared_ptr()->is_constant())
{
auto target_shape = input(1).get_shape();
NGRAPH_CHECK(target_shape.size() == 1);
auto axes_mapping_val =
static_pointer_cast<op::Constant>(input_value(2).get_node_shared_ptr())
->get_axis_vector_val();
std::vector<size_t> axes(target_shape[0]);
std::iota(axes.begin(), axes.end(), 0);
for (auto i = axes_mapping_val.rbegin(); i != axes_mapping_val.rend(); ++i)
{
axes.erase(axes.begin() + *i);
}
broadcast_axes.insert(axes.begin(), axes.end());
axes_known = true;
}
}
else if (m_broadcast_spec.m_type == AutoBroadcastType::NUMPY ||
m_broadcast_spec.m_type == AutoBroadcastType::PDPD)
{
if (input(0).get_partial_shape().is_static() &&
input_value(1).get_node_shared_ptr()->is_constant())
{
auto arg_shape = input(0).get_shape();
auto target_shape =
static_pointer_cast<op::Constant>(input_value(1).get_node_shared_ptr())
->get_shape_val();
auto start_axis = (m_broadcast_spec.m_type == AutoBroadcastType::PDPD)
? m_broadcast_spec.m_axis
: target_shape.size() - arg_shape.size();
NGRAPH_CHECK(start_axis >= 0);
for (size_t i = 0; i < target_shape.size(); i++)
{
if (i < start_axis || target_shape[i] != arg_shape[i - start_axis])
{
broadcast_axes.insert(i);
}
}
axes_known = true;
}
}
else
{
throw ngraph_error("Unknown autobroadcast type");
}
return std::make_pair(axes_known, broadcast_axes);
}
void op::v1::Broadcast::validate_and_infer_types()
{
// shape node should have integer data type. For now we only allow i64
auto shape_et = get_input_element_type(1);
NODE_VALIDATION_CHECK(this,
shape_et.compatible(element::Type_t::i64),
"Broadcast shape must have element type i64, but has ",
shape_et);
// shape node should produce a one dimensional shape.
auto broadcast_shape_rank = get_input_partial_shape(1).rank();
NODE_VALIDATION_CHECK(this,
broadcast_shape_rank.compatible(1),
"Broadcast shape rank must be 1, but has ",
broadcast_shape_rank);
if (m_broadcast_spec.m_type == AutoBroadcastType::NONE)
{
// axes_mapping node should have integer data type. For now we only allow i64
auto axes_et = get_input_element_type(2);
NODE_VALIDATION_CHECK(this,
axes_et.compatible(element::Type_t::i64),
"Broadcast axes must have element type i64, but has ",
axes_et);
// axes_mapping node should produce a one dimensional shape.
auto axes_shape_rank = get_input_partial_shape(2).rank();
NODE_VALIDATION_CHECK(this,
axes_shape_rank.compatible(1),
"Broadcast axes rank must be 1, but has ",
axes_shape_rank);
}
PartialShape result_shape{PartialShape::dynamic()};
if (input_value(1).get_node_shared_ptr()->is_constant())
{
result_shape = static_pointer_cast<op::Constant>(input_value(1).get_node_shared_ptr())
->get_shape_val();
}
if (m_broadcast_spec.m_type == AutoBroadcastType::NONE)
{
// Validate axes_mapping
if (input(0).get_partial_shape().is_static() && input(1).get_partial_shape().is_static() &&
input(2).get_partial_shape().is_static())
{
auto arg_shape = input(0).get_shape();
auto axes_shape = input(2).get_shape();
// Rank(arg_shape) == shape_size(axes_mapping)
NODE_VALIDATION_CHECK(this,
shape_size(axes_shape) == arg_shape.size(),
"Broadcast axes_mapping shape ",
axes_shape,
" doesn't match rank of input tensor ",
arg_shape.size());
if (input_value(1).get_node_shared_ptr()->is_constant() &&
input_value(2).get_node_shared_ptr()->is_constant())
{
auto target_shape =
static_pointer_cast<op::Constant>(input_value(1).get_node_shared_ptr())
->get_shape_val();
auto axes_mapping_val =
static_pointer_cast<op::Constant>(input_value(2).get_node_shared_ptr())
->get_axis_vector_val();
// axes_mapping needs to be in sorted order
NODE_VALIDATION_CHECK(
this,
std::is_sorted(axes_mapping_val.begin(), axes_mapping_val.end()),
"Broadcast doesn't permit transposes. axes_mapping ",
axes_mapping_val,
" not in sorted order");
for (size_t i = 0; i < axes_mapping_val.size(); i++)
{
NODE_VALIDATION_CHECK(this,
axes_mapping_val[i] < target_shape.size(),
"Broadcast axes_mapping[",
i,
"]: ",
axes_mapping_val[i],
" exceeds target rank ",
target_shape.size());
NODE_VALIDATION_CHECK(this,
target_shape[axes_mapping_val[i]] == arg_shape[i],
"Broadcast target[axes_mapping[",
i,
"]]",
" Expected ",
arg_shape[i],
". Got ",
target_shape[axes_mapping_val[i]]);
}
}
}
}
else if (m_broadcast_spec.m_type == AutoBroadcastType::NUMPY ||
m_broadcast_spec.m_type == AutoBroadcastType::PDPD)
{
if (input(0).get_partial_shape().is_static() && input(1).get_partial_shape().is_static())
{
auto arg_shape = input(0).get_shape();
if (input_value(1).get_node_shared_ptr()->is_constant())
{
auto target_shape =
static_pointer_cast<op::Constant>(input_value(1).get_node_shared_ptr())
->get_shape_val();
auto start_axis = (m_broadcast_spec.m_type == AutoBroadcastType::PDPD)
? m_broadcast_spec.m_axis
: target_shape.size() - arg_shape.size();
NODE_VALIDATION_CHECK(this,
start_axis >= 0,
"Broadcast target_shape has smaller rank ",
target_shape.size(),
" than arg shape ",
arg_shape.size());
for (auto i = start_axis; i < target_shape.size(); i++)
{
NODE_VALIDATION_CHECK(this,
arg_shape[i - start_axis] == 1 ||
arg_shape[i - start_axis] == target_shape[i],
"Broadcast incorrect target shape. Expecting ",
arg_shape[i - start_axis],
" . Got ",
target_shape[i]);
}
}
}
}
set_input_is_relevant_to_shape(0); // arg - Result element type
set_input_is_relevant_to_shape(1); // target_shape - Result shape
set_input_is_relevant_to_shape(2); // axes_mapping - Broadcast type
set_output_type(0, get_input_element_type(0), result_shape);
}
shared_ptr<Node> op::v1::Broadcast::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<v1::Broadcast>(
new_args.at(0), new_args.at(1), new_args.at(2), m_broadcast_spec);
}
void op::v1::Broadcast::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
{
auto delta = deltas.at(0);
auto x = input_value(0);
auto broadcast_axes = get_broadcast_axes();
if (broadcast_axes.first)
{
adjoints.add_delta(x, make_shared<op::Sum>(delta, broadcast_axes.second));
}
else
{
throw ngraph_error("Autodiff not supported on dynamic op variants");
}
}
constexpr NodeTypeInfo op::v0::Broadcast::type_info;
op::Broadcast::Broadcast(const OutputVector& args,
const Shape& shape,
const AxisSet& broadcast_axes)
op::v0::Broadcast::Broadcast(const OutputVector& args,
const Shape& shape,
const AxisSet& broadcast_axes)
: Op(args)
, m_shape(shape)
, m_broadcast_axes(broadcast_axes)
......@@ -32,12 +279,14 @@ op::Broadcast::Broadcast(const OutputVector& args,
constructor_validate_and_infer_types();
}
op::Broadcast::Broadcast(const Output<Node>& arg, const Shape& shape, const AxisSet& broadcast_axes)
op::v0::Broadcast::Broadcast(const Output<Node>& arg,
const Shape& shape,
const AxisSet& broadcast_axes)
: Broadcast(OutputVector{arg}, shape, broadcast_axes)
{
}
void op::Broadcast::validate_and_infer_types()
void op::v0::Broadcast::validate_and_infer_types()
{
infer_shape();
......@@ -80,13 +329,13 @@ void op::Broadcast::validate_and_infer_types()
set_output_type(0, get_input_element_type(0), m_shape);
}
shared_ptr<Node> op::Broadcast::copy_with_new_args(const NodeVector& new_args) const
shared_ptr<Node> op::v0::Broadcast::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<Broadcast>(new_args.at(0), m_shape, m_broadcast_axes);
return make_shared<v0::Broadcast>(new_args.at(0), m_shape, m_broadcast_axes);
}
void op::Broadcast::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
void op::v0::Broadcast::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
{
auto delta = deltas.at(0);
......@@ -95,27 +344,27 @@ void op::Broadcast::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVe
adjoints.add_delta(x, make_shared<op::Sum>(delta, m_broadcast_axes));
}
constexpr NodeTypeInfo op::BroadcastLike::type_info;
constexpr NodeTypeInfo op::v0::BroadcastLike::type_info;
op::BroadcastLike::BroadcastLike(const Output<Node>& arg,
const Output<Node>& like_arg,
const AxisSet& initial_broadcast_axes)
: Broadcast({arg, like_arg}, {}, {})
op::v0::BroadcastLike::BroadcastLike(const Output<Node>& arg,
const Output<Node>& like_arg,
const AxisSet& initial_broadcast_axes)
: op::v0::Broadcast({arg, like_arg}, {}, {})
, m_initial_broadcast_axes(initial_broadcast_axes)
{
constructor_validate_and_infer_types();
}
shared_ptr<Node> op::BroadcastLike::copy_with_new_args(const NodeVector& new_args) const
shared_ptr<Node> op::v0::BroadcastLike::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 2)
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<BroadcastLike>(new_args.at(0), new_args.at(1), m_initial_broadcast_axes);
return make_shared<v0::BroadcastLike>(new_args.at(0), new_args.at(1), m_initial_broadcast_axes);
}
void op::BroadcastLike::infer_shape()
void op::v0::BroadcastLike::infer_shape()
{
const Shape& in_shape = get_input_shape(0);
m_shape = get_input_shape(1);
......
......@@ -18,88 +18,169 @@
#include "ngraph/axis_set.hpp"
#include "ngraph/op/op.hpp"
#include "ngraph/op/util/attr_types.hpp"
namespace ngraph
{
namespace op
{
/// \brief Operation which "adds" axes to an input tensor, replicating elements from the
/// input as needed along the new axes.
class Broadcast : public Op
namespace v0
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"Broadcast", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a broadcast operation.
Broadcast() = default;
/// \brief Constructs a broadcast operation.
///
/// \param arg Node that produces the input tensor to be broadcast.
/// \param shape The shape of the output tensor.
/// \param broadcast_axes The axis positions (0-based) in the result that are being
/// broadcast. The remaining axes in shape must be the same as
/// the shape of arg.
Broadcast(const Output<Node>& arg, const Shape& shape, const AxisSet& broadcast_axes);
void validate_and_infer_types() override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
/// \return A set containing the indices of the broadcast axes (0-based).
const AxisSet& get_broadcast_axes() const { return m_broadcast_axes; }
void set_broadcast_axes(const AxisSet& broadcast_axes)
/// \brief Operation which "adds" axes to an input tensor, replicating elements from the
/// input as needed along the new axes.
class Broadcast : public Op
{
m_broadcast_axes = broadcast_axes;
}
const Shape& get_broadcast_shape() const { return m_shape; }
void set_broadcast_shape(const Shape& shape) { m_shape = shape; }
protected:
Broadcast(const OutputVector& args, const Shape& shape, const AxisSet& broadcast_axes);
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
virtual void infer_shape() {}
Shape m_shape;
AxisSet m_broadcast_axes;
};
/// \brief Broadcast arg to the same shape as like_arg.
class BroadcastLike : public Broadcast
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"BroadcastLike", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Broadcast arg to the same shape as like_arg.
BroadcastLike() = default;
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"Broadcast", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a broadcast operation.
Broadcast() = default;
/// \brief Constructs a broadcast operation.
///
/// \param arg The input tensor to be broadcast.
/// \param shape The shape of the output tensor.
/// \param broadcast_axes The axis positions (0-based) in the result that are being
/// broadcast. The remaining axes in shape must be the same as
/// the shape of arg.
Broadcast(const Output<Node>& arg,
const Shape& shape,
const AxisSet& broadcast_axes);
void validate_and_infer_types() override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
/// \return A set containing the indices of the broadcast axes (0-based).
const AxisSet& get_broadcast_axes() const { return m_broadcast_axes; }
void set_broadcast_axes(const AxisSet& broadcast_axes)
{
m_broadcast_axes = broadcast_axes;
}
const Shape& get_broadcast_shape() const { return m_shape; }
void set_broadcast_shape(const Shape& shape) { m_shape = shape; }
protected:
Broadcast(const OutputVector& args,
const Shape& shape,
const AxisSet& broadcast_axes);
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
virtual void infer_shape() {}
Shape m_shape;
AxisSet m_broadcast_axes;
};
/// \brief Broadcast arg to the same shape as like_arg.
///
/// Once the shape of like_arg is known, this op will be replaced with an equivalent
/// Broadcast op.
///
/// \param arg The argument to be broadcast.
/// \param like_arg Provides the shape for the result.
/// \param initial_broadcast_axes indicates which axes will be broadcast. If empty,
/// arg must be scalar and all axes are broadcast.
BroadcastLike(const Output<Node>& arg,
const Output<Node>& like_arg,
const AxisSet& initial_broadcast_axes);
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
void infer_shape() override;
const AxisSet& get_initial_broadcast_axes() const { return m_initial_broadcast_axes; }
void set_initial_broadcast_axes(const AxisSet& initial_broadcast_axes)
class BroadcastLike : public v0::Broadcast
{
m_initial_broadcast_axes = initial_broadcast_axes;
}
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"BroadcastLike", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Broadcast arg to the same shape as like_arg.
BroadcastLike() = default;
/// \brief Broadcast arg to the same shape as like_arg.
///
/// Once the shape of like_arg is known, this op will be replaced with an equivalent
/// Broadcast op.
///
/// \param arg The argument to be broadcast.
/// \param like_arg Provides the shape for the result.
/// \param initial_broadcast_axes indicates which axes will be broadcast. If empty,
/// arg must be scalar and all axes are broadcast.
BroadcastLike(const Output<Node>& arg,
const Output<Node>& like_arg,
const AxisSet& initial_broadcast_axes);
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
void infer_shape() override;
const AxisSet& get_initial_broadcast_axes() const
{
return m_initial_broadcast_axes;
}
void set_initial_broadcast_axes(const AxisSet& initial_broadcast_axes)
{
m_initial_broadcast_axes = initial_broadcast_axes;
}
protected:
AxisSet m_initial_broadcast_axes;
};
} // namespace v0
namespace v1
{
/// \brief Operation which "adds" axes to an input tensor, replicating elements from the
/// input as needed along the new axes.
class Broadcast : public Op
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"Broadcast", 1};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a broadcast operation.
Broadcast() = default;
/// \brief Constructs a broadcast operation.
///
/// \param arg The input tensor to be broadcast.
/// \param target_shape The shape of the output tensor.
/// \param axes_mapping The axis positions (0-based) in the result that correspond
/// to input axes. 'Arg' tensor is broadcast along the
/// remaining
/// axes.
/// E.g., Input Shape - [3, 4], Target Shape - [3, 5, 4, 4]
/// axes_mapping - [0, 2] => Broadcast along axes 1 and 3.
/// axes_mapping - [0, 3] => Broadcast along axes 1 and 2.
/// \param broadcast_spec Broadcast specification to use for determining broadcast
/// axes. 'axes_mapping' is ignored if broadcast_spec is not
/// NONE
Broadcast(const Output<Node>& arg,
const Output<Node>& target_shape,
const Output<Node>& axes_mapping,
const AutoBroadcastSpec& broadcast_spec = AutoBroadcastSpec());
/// \brief Constructs a broadcast operation.
///
/// \param arg The input tensor to be broadcast.
/// \param target_shape The shape of the output tensor.
/// \param broadcast_spec Broadcast specification to use for determining broadcast
/// axes
Broadcast(const Output<Node>& arg,
const Output<Node>& target_shape,
const AutoBroadcastSpec& broadcast_spec =
AutoBroadcastSpec(AutoBroadcastType::NUMPY));
size_t get_version() const override { return 1; }
void validate_and_infer_types() override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
/// \return Broadcast Specification.
const AutoBroadcastSpec& get_broadcast_spec() const { return m_broadcast_spec; }
void set_broadcast_spec(const AutoBroadcastSpec& broadcast_spec)
{
m_broadcast_spec = broadcast_spec;
}
/// \return true and the AxisSet if broadcast axes can be fully determined.
std::pair<bool, AxisSet> get_broadcast_axes() const;
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
private:
AutoBroadcastSpec m_broadcast_spec;
};
} // namespace v1
protected:
AxisSet m_initial_broadcast_axes;
};
using v0::Broadcast;
using v0::BroadcastLike;
}
}
......@@ -102,6 +102,7 @@ namespace ngraph
enum class AutoBroadcastType
{
NONE = 0,
EXPLICIT = NONE,
NUMPY,
PDPD
};
......@@ -136,6 +137,11 @@ namespace ngraph
AutoBroadcastType m_type; // Implicit broadcasting algorithm
int64_t m_axis; // Axis to start alignment on
bool operator==(const AutoBroadcastSpec& a) const
{
return a.m_type == m_type && a.m_axis == m_axis;
}
};
std::ostream& operator<<(std::ostream& s, const AutoBroadcastType& type);
......
......@@ -16,6 +16,7 @@
#include "ngraph/pass/opset0_downgrade.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/pad.hpp"
......@@ -83,6 +84,22 @@ bool pass::Opset0Downgrade::run_on_node(shared_ptr<Node> node)
#endif
switch (get_typeid(node))
{
case OP_TYPEID::Broadcast:
{
auto tmp = dynamic_cast<const op::v1::Broadcast*>(node.get());
const auto arg = node->input(0).get_source_output();
NGRAPH_CHECK(node->input_value(1).get_node_shared_ptr()->is_constant());
auto target_shape =
static_pointer_cast<op::Constant>(node->input_value(1).get_node_shared_ptr())
->get_shape_val();
NGRAPH_CHECK(tmp->get_broadcast_axes().first);
auto replacement_node =
make_shared<op::v0::Broadcast>(arg, target_shape, tmp->get_broadcast_axes().second);
replace_node(node, replacement_node);
modified = true;
break;
}
case OP_TYPEID::Pad:
{
auto tmp = as_type_ptr<op::v1::Pad>(node);
......
......@@ -16,6 +16,7 @@
#include "ngraph/pass/opset1_upgrade.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/op/avg_pool.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/experimental/dyn_reshape.hpp"
......@@ -33,6 +34,7 @@
#include "ngraph/op/topk.hpp"
#include <limits>
#include <numeric>
using namespace std;
using namespace ngraph;
......@@ -137,6 +139,31 @@ bool pass::Opset1Upgrade::run_on_node(shared_ptr<Node> node)
modified = true;
break;
}
case OP_TYPEID::Broadcast:
{
auto tmp = dynamic_cast<const op::v0::Broadcast*>(node.get());
auto result_shape = tmp->get_broadcast_shape();
auto result_shape_node =
op::Constant::create(element::i64, Shape{result_shape.size()}, result_shape);
auto broadcast_axes = tmp->get_broadcast_axes();
// Flip broadcast_axes to get axes_mapping
std::vector<size_t> axes_mapping(result_shape.size());
std::iota(axes_mapping.begin(), axes_mapping.end(), 0);
for (auto i = broadcast_axes.rbegin(); i != broadcast_axes.rend(); i++)
{
axes_mapping.erase(axes_mapping.begin() + *i);
}
auto axes_mapping_node =
op::Constant::create(element::i64, Shape{axes_mapping.size()}, axes_mapping);
auto replacement_node = make_shared<op::v1::Broadcast>(node->input(0).get_source_output(),
result_shape_node->output(0),
axes_mapping_node->output(0));
replace_node(node, replacement_node);
modified = true;
break;
}
case OP_TYPEID::Convolution:
{
auto tmp = dynamic_cast<const op::v0::Convolution*>(node.get());
......
......@@ -22,6 +22,7 @@
#include "ngraph/code_writer.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/avg_pool.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/gather.hpp"
#include "ngraph/op/max.hpp"
......@@ -79,7 +80,6 @@ namespace ngraph
class NotEqual;
class Select;
class Subtract;
class Broadcast;
class Convert;
class Constant;
class Reshape;
......
......@@ -16,6 +16,7 @@
#include "ngraph/runtime/dynamic/dynamic_backend.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/experimental/dyn_broadcast.hpp"
#include "ngraph/op/experimental/dyn_replace_slice.hpp"
#include "ngraph/op/experimental/dyn_reshape.hpp"
......@@ -26,6 +27,7 @@
#include "ngraph/pass/constant_folding.hpp"
#include "ngraph/pass/dyn_elimination.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/opset0_downgrade.hpp"
#include "ngraph/pass/shape_relevance.hpp"
#include "ngraph/specialize_function.hpp"
#include "ngraph/util.hpp"
......@@ -85,7 +87,8 @@ bool is_dynamic_op(const std::shared_ptr<Node>& op)
{
return is_type<op::Transpose>(op) || is_type<op::DynBroadcast>(op) ||
is_type<op::DynReplaceSlice>(op) || is_type<op::DynSlice>(op) ||
is_type<op::v1::Reshape>(op) || is_type<op::DynReshape>(op) || is_type<op::Range>(op);
is_type<op::v1::Reshape>(op) || is_type<op::DynReshape>(op) || is_type<op::Range>(op) ||
is_type<op::v1::Broadcast>(op);
}
// Helper for a vile hack in DynamicExecutable::call. See body of that function for details.
......@@ -176,6 +179,7 @@ bool runtime::dynamic::DynamicExecutable::call(
pass::Manager passes;
passes.register_pass<pass::ConstantFolding>();
passes.register_pass<pass::DynElimination>();
passes.register_pass<pass::Opset0Downgrade>(); // Converts dynamic v1 variants to v0 ops
passes.set_per_pass_validation(false);
// FIXME(amprocte): Vile, temporary hack: we need to do repeated rounds of
......
......@@ -913,9 +913,17 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
}
case OP_TYPEID::Broadcast:
{
auto shape = node_js.at("shape").get<vector<size_t>>();
auto axes = deserialize_axis_set(node_js.at("axes"));
node = make_shared<op::Broadcast>(args[0], shape, axes);
if (op_version == 0)
{
auto shape = node_js.at("shape").get<vector<size_t>>();
auto axes = deserialize_axis_set(node_js.at("axes"));
node = make_shared<op::v0::Broadcast>(args[0], shape, axes);
}
if (op_version == 1)
{
node = make_shared<op::v1::Broadcast>(
args[0], args[1], args[2], read_auto_broadcast(node_js, "auto_broadcast"));
}
break;
}
case OP_TYPEID::BroadcastDistributed:
......@@ -2500,9 +2508,20 @@ json JSONSerializer::serialize_node(const Node& n)
}
case OP_TYPEID::Broadcast:
{
auto tmp = static_cast<const op::Broadcast*>(&n);
node["axes"] = serialize_axis_set(tmp->get_broadcast_axes());
node["shape"] = tmp->get_broadcast_shape();
if (op_version == 0)
{
auto tmp = dynamic_cast<const op::v0::Broadcast*>(&n);
node["axes"] = serialize_axis_set(tmp->get_broadcast_axes());
node["shape"] = tmp->get_broadcast_shape();
}
if (op_version == 1)
{
auto tmp = dynamic_cast<const op::v1::Broadcast*>(&n);
if (tmp->get_broadcast_spec().m_type != op::AutoBroadcastType::NONE)
{
node["auto_broadcast"] = write_auto_broadcast(tmp->get_broadcast_spec());
}
}
break;
}
case OP_TYPEID::BroadcastDistributed: { break;
......
......@@ -69,6 +69,7 @@ set(SRC
node_input_output.cpp
nop_elimination.cpp
op.cpp
opset_pass/broadcast_opset_pass.cpp
opset_pass/convolution_opset_pass.cpp
opset_pass/gather_opset_pass.cpp
opset_pass/pad_opset_pass.cpp
......
......@@ -73,3 +73,56 @@ NGRAPH_TEST(${BACKEND_NAME}, dyn_broadcast)
ASSERT_TRUE(test::all_close_f(results, expected_results[i], MIN_FLOAT_TOLERANCE_BITS));
}
}
NGRAPH_TEST(${BACKEND_NAME}, broadcast_v1)
{
// Create a graph for
// f(x,shape:i32,axes:32) = Broadcast(x,Convert<i64>(shape),Convert<i64>(axes)).
auto x = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto shape = make_shared<op::Parameter>(element::i32, PartialShape{Dimension::dynamic()});
auto axes = make_shared<op::Parameter>(element::i32, PartialShape{Dimension::dynamic()});
auto shape_i64 = make_shared<op::Convert>(shape, element::i64);
auto axes_i64 = make_shared<op::Convert>(axes, element::i64);
auto bc = make_shared<op::v1::Broadcast>(x, shape_i64, axes_i64);
auto f = make_shared<Function>(NodeVector{bc}, ParameterVector{x, shape, axes});
auto backend = runtime::Backend::create("${BACKEND_NAME}", true);
auto ex = backend->compile(f);
auto t_r = backend->create_dynamic_tensor(element::f32, PartialShape::dynamic());
std::vector<Shape> x_shapes{Shape{}, Shape{}, Shape{2}, Shape{2}};
std::vector<std::vector<int32_t>> shapes{{2, 2}, {2, 2, 2}, {3, 2}, {2, 3}, {2, 2}, {2, 2}};
std::vector<std::vector<int32_t>> axeses{{}, {}, {1}, {0}, {0}, {1}};
std::vector<std::vector<float>> inputs{{6}, {7}, {10, 11}, {10, 11}};
std::vector<Shape> expected_result_shapes{
Shape{2, 2}, Shape{2, 2, 2}, Shape{3, 2}, Shape{2, 3}, Shape{2, 2}, Shape{2, 2}};
std::vector<std::vector<float>> expected_results{{6, 6, 6, 6},
{7, 7, 7, 7, 7, 7, 7, 7},
{10, 11, 10, 11, 10, 11},
{10, 10, 10, 11, 11, 11},
{10, 10, 11, 11},
{10, 11, 10, 11}};
for (size_t i = 0; i < x_shapes.size(); i++)
{
auto t_x = backend->create_tensor(element::f32, x_shapes[i]);
auto t_shape = backend->create_tensor(element::i32, Shape{shapes[i].size()});
auto t_axes = backend->create_tensor(element::i32, Shape{axeses[i].size()});
copy_data(t_x, inputs[i]);
copy_data(t_shape, shapes[i]);
copy_data(t_axes, axeses[i]);
ex->call_with_validate({t_r}, {t_x, t_shape, t_axes});
ASSERT_EQ(t_r->get_shape(), expected_result_shapes[i]);
auto results = read_vector<float>(t_r);
ASSERT_TRUE(test::all_close_f(results, expected_results[i], MIN_FLOAT_TOLERANCE_BITS));
}
}
\ No newline at end of file
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/op/util/attr_types.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/opset0_downgrade.hpp"
#include "ngraph/pass/opset1_upgrade.hpp"
#include "util/type_prop.hpp"
using namespace std;
using namespace ngraph;
TEST(opset_transform, opset1_broadcast_upgrade_pass)
{
auto arg = make_shared<op::Parameter>(element::f32, Shape{5, 6});
auto bcast_v0 = make_shared<op::v0::Broadcast>(arg, Shape{3, 5, 4, 6}, AxisSet{0, 2});
auto f = make_shared<Function>(NodeVector{bcast_v0}, ParameterVector{arg});
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<pass::Opset1Upgrade>();
pass_manager.run_passes(f);
auto bcast_v1 = static_pointer_cast<op::v1::Broadcast>(
f->get_results().at(0)->input_value(0).get_node_shared_ptr());
EXPECT_EQ(bcast_v1->description(), "Broadcast");
EXPECT_EQ(bcast_v1->get_version(), 1);
EXPECT_EQ(bcast_v1->get_broadcast_spec(), op::AutoBroadcastSpec());
EXPECT_EQ(bcast_v1->get_broadcast_axes(), (std::make_pair<bool, AxisSet>(true, AxisSet{0, 2})));
EXPECT_EQ(bcast_v1->input_value(1).get_node()->description(), "Constant");
EXPECT_EQ(bcast_v1->input_value(2).get_node()->description(), "Constant");
EXPECT_EQ(static_pointer_cast<op::Constant>(bcast_v1->input_value(1).get_node_shared_ptr())
->get_shape_val(),
(Shape{3, 5, 4, 6}));
EXPECT_EQ(static_pointer_cast<op::Constant>(bcast_v1->input_value(2).get_node_shared_ptr())
->get_axis_set_val(),
(AxisSet{1, 3}));
}
TEST(opset_transform, opset1_broadcast_downgrade_pass)
{
auto arg = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3});
auto target_shape = op::Constant::create<int64_t>(element::i64, Shape{5}, {3, 1, 4, 2, 3});
auto axes_mapping = op::Constant::create<int64_t>(element::i64, Shape{3}, {1, 3, 4});
auto bcast_v1 = make_shared<op::v1::Broadcast>(arg, target_shape, axes_mapping);
auto f = make_shared<Function>(NodeVector{bcast_v1}, ParameterVector{arg});
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<pass::Opset0Downgrade>();
pass_manager.run_passes(f);
auto bcast_v0 = static_pointer_cast<op::v0::Broadcast>(
f->get_results().at(0)->input_value(0).get_node_shared_ptr());
EXPECT_EQ(bcast_v0->description(), "Broadcast");
EXPECT_EQ(bcast_v0->get_version(), 0);
EXPECT_EQ(bcast_v0->get_broadcast_shape(), (Shape{3, 1, 4, 2, 3}));
EXPECT_EQ(bcast_v0->get_broadcast_axes(), (AxisSet{0, 2}));
}
......@@ -198,3 +198,221 @@ TEST(type_prop, broadcast_partial_rank_static_dynamic_shape_mismatch_wrong_size)
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, broadcast_v1_numpy)
{
auto param = make_shared<op::Parameter>(element::f32, Shape{3, 1});
auto target_shape = op::Constant::create<int64_t>(element::i64, Shape{3}, {2, 3, 6});
auto bc = make_shared<op::v1::Broadcast>(param, target_shape);
ASSERT_EQ(bc->get_element_type(), element::f32);
ASSERT_EQ(bc->get_shape(), (Shape{2, 3, 6}));
}
TEST(type_prop, broadcast_v1_pdpd)
{
auto param = make_shared<op::Parameter>(element::f32, Shape{3, 1});
auto target_shape = op::Constant::create<int64_t>(element::i64, Shape{3}, {2, 3, 6});
auto bc = make_shared<op::v1::Broadcast>(
param, target_shape, op::AutoBroadcastSpec(op::AutoBroadcastType::PDPD, 1));
ASSERT_EQ(bc->get_element_type(), element::f32);
ASSERT_EQ(bc->get_shape(), (Shape{2, 3, 6}));
}
TEST(type_prop, broadcast_v1_axes_mapping)
{
auto param = make_shared<op::Parameter>(element::f32, Shape{3, 1});
auto target_shape = op::Constant::create<int64_t>(element::i64, Shape{3}, {2, 3, 1});
auto axes_mapping = op::Constant::create<int64_t>(element::i64, Shape{2}, {1, 2});
auto bc = make_shared<op::v1::Broadcast>(param, target_shape, axes_mapping);
ASSERT_EQ(bc->get_element_type(), element::f32);
ASSERT_EQ(bc->get_shape(), (Shape{2, 3, 1}));
}
TEST(type_prop, broadcast_v1_fail_rank)
{
auto param = make_shared<op::Parameter>(element::f32, Shape{3, 1});
auto target_shape = op::Constant::create<int64_t>(element::i64, Shape{3}, {2, 3, 1});
auto axes_mapping = op::Constant::create<int64_t>(element::i64, Shape{3}, {1, 2, 3});
try
{
auto bc = make_shared<op::v1::Broadcast>(param, target_shape, axes_mapping);
FAIL() << "Broadcast: target shape mismatch with input rank not detected";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
"Broadcast axes_mapping shape Shape{3} doesn't match rank of input tensor 2");
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, broadcast_v1_fail_transpose)
{
auto param = make_shared<op::Parameter>(element::f32, Shape{3, 1});
auto target_shape = op::Constant::create<int64_t>(element::i64, Shape{3}, {2, 1, 3});
auto axes_mapping = op::Constant::create<int64_t>(element::i64, Shape{2}, {2, 1});
try
{
auto bc = make_shared<op::v1::Broadcast>(param, target_shape, axes_mapping);
FAIL() << "Broadcast: transpose prohibition not detected";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
"Broadcast doesn't permit transposes. axes_mapping AxisVector{2, 1} "
"not in sorted order");
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, broadcast_v1_fail_axes_map)
{
auto param = make_shared<op::Parameter>(element::f32, Shape{3, 1});
auto target_shape = op::Constant::create<int64_t>(element::i64, Shape{3}, {2, 3, 1});
auto axes_mapping = op::Constant::create<int64_t>(element::i64, Shape{2}, {1, 3});
try
{
auto bc = make_shared<op::v1::Broadcast>(param, target_shape, axes_mapping);
FAIL() << "Broadcast: wrong axes_map not detected";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), "Broadcast axes_mapping[1]: 3 exceeds target rank 3");
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, broadcast_v1_fail_axes_map_shape)
{
auto param = make_shared<op::Parameter>(element::f32, Shape{3, 1});
auto target_shape = op::Constant::create<int64_t>(element::i64, Shape{3}, {2, 3, 3});
auto axes_mapping = op::Constant::create<int64_t>(element::i64, Shape{2}, {1, 2});
try
{
auto bc = make_shared<op::v1::Broadcast>(param, target_shape, axes_mapping);
FAIL() << "Broadcast: wrong target shape not detected";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), "Broadcast target[axes_mapping[1]] Expected 1. Got 3");
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, broadcast_v1_shape_wrong_rank)
{
auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4});
auto bc_shape = make_shared<op::Parameter>(element::i64, Shape{1, 1});
auto bc_axes = make_shared<op::Parameter>(element::i64, Shape{1});
try
{
auto bc = make_shared<op::v1::Broadcast>(arg, bc_shape, bc_axes);
FAIL() << "DynBroadcast: wrong shape rank not detected";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), "Broadcast shape rank must be 1");
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, broadcast_v1_axes_wrong_rank)
{
auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4});
auto bc_shape = make_shared<op::Parameter>(element::i64, Shape{1});
auto bc_axes = make_shared<op::Parameter>(element::i64, Shape{2, 2});
try
{
auto bc = make_shared<op::v1::Broadcast>(arg, bc_shape, bc_axes);
FAIL() << "Broadcast: axes shape rank not detected";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), "Broadcast axes rank must be 1");
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, broadcast_v1_output_partial_shape_dynamic)
{
auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4});
auto bc_shape = make_shared<op::Parameter>(element::i64, Shape{1});
auto bc_axes = make_shared<op::Parameter>(element::i64, Shape{2});
auto bc = make_shared<op::v1::Broadcast>(arg, bc_shape, bc_axes);
ASSERT_TRUE(bc->get_output_partial_shape(0).is_dynamic());
}
TEST(type_prop, broadcast_v1_broadcast_shape_et_wrong)
{
auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4});
// wrong element type
auto bc_shape = make_shared<op::Parameter>(element::boolean, Shape{1});
auto bc_axes = make_shared<op::Parameter>(element::i64, Shape{2});
try
{
auto bc = make_shared<op::v1::Broadcast>(arg, bc_shape, bc_axes);
FAIL() << "Broadcast: did not detect shape element type not i64";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Broadcast shape must have element type i64"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, broadcast_v1_axes_et_wrong)
{
auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4});
auto bc_shape = make_shared<op::Parameter>(element::i64, Shape{1});
// wrong element type
auto bc_axes = make_shared<op::Parameter>(element::f32, Shape{2});
try
{
auto bc = make_shared<op::v1::Broadcast>(arg, bc_shape, bc_axes);
FAIL() << "Broadcast: did not detect axes element type not i64";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Broadcast axes must have element type i64"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment