Commit e741f8f1 authored by Jayaram Bobba's avatar Jayaram Bobba Committed by Scott Cyphers

[SPEC] Added support for v1 Broadcast op specification (#3737)

* - Added support for v1 Broadcast op specification
- Added upgrade/downgrade conversions between v0 and v1

* Added unit test for pdpd broadcast

* Make numpy default autobroadcast type and some style fixes

* Added support in Dynamic wrapper for dyn elimination and copied over unit tests from DynBroadcast

* Addressed PR feedback

* Addressed PR feedback on documentation
parent 4ab4609f
This diff is collapsed.
......@@ -18,10 +18,13 @@
#include "ngraph/axis_set.hpp"
#include "ngraph/op/op.hpp"
#include "ngraph/op/util/attr_types.hpp"
namespace ngraph
{
namespace op
{
namespace v0
{
/// \brief Operation which "adds" axes to an input tensor, replicating elements from the
/// input as needed along the new axes.
......@@ -35,12 +38,14 @@ namespace ngraph
Broadcast() = default;
/// \brief Constructs a broadcast operation.
///
/// \param arg Node that produces the input tensor to be broadcast.
/// \param arg The input tensor to be broadcast.
/// \param shape The shape of the output tensor.
/// \param broadcast_axes The axis positions (0-based) in the result that are being
/// broadcast. The remaining axes in shape must be the same as
/// the shape of arg.
Broadcast(const Output<Node>& arg, const Shape& shape, const AxisSet& broadcast_axes);
Broadcast(const Output<Node>& arg,
const Shape& shape,
const AxisSet& broadcast_axes);
void validate_and_infer_types() override;
......@@ -56,7 +61,9 @@ namespace ngraph
const Shape& get_broadcast_shape() const { return m_shape; }
void set_broadcast_shape(const Shape& shape) { m_shape = shape; }
protected:
Broadcast(const OutputVector& args, const Shape& shape, const AxisSet& broadcast_axes);
Broadcast(const OutputVector& args,
const Shape& shape,
const AxisSet& broadcast_axes);
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
......@@ -67,7 +74,7 @@ namespace ngraph
};
/// \brief Broadcast arg to the same shape as like_arg.
class BroadcastLike : public Broadcast
class BroadcastLike : public v0::Broadcast
{
public:
NGRAPH_API
......@@ -92,7 +99,10 @@ namespace ngraph
copy_with_new_args(const NodeVector& new_args) const override;
void infer_shape() override;
const AxisSet& get_initial_broadcast_axes() const { return m_initial_broadcast_axes; }
const AxisSet& get_initial_broadcast_axes() const
{
return m_initial_broadcast_axes;
}
void set_initial_broadcast_axes(const AxisSet& initial_broadcast_axes)
{
m_initial_broadcast_axes = initial_broadcast_axes;
......@@ -101,5 +111,76 @@ namespace ngraph
protected:
AxisSet m_initial_broadcast_axes;
};
} // namespace v0
namespace v1
{
/// \brief Operation which "adds" axes to an input tensor, replicating elements from the
/// input as needed along the new axes.
class Broadcast : public Op
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"Broadcast", 1};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a broadcast operation.
Broadcast() = default;
/// \brief Constructs a broadcast operation.
///
/// \param arg The input tensor to be broadcast.
/// \param target_shape The shape of the output tensor.
/// \param axes_mapping The axis positions (0-based) in the result that correspond
/// to input axes. 'Arg' tensor is broadcast along the
/// remaining
/// axes.
/// E.g., Input Shape - [3, 4], Target Shape - [3, 5, 4, 4]
/// axes_mapping - [0, 2] => Broadcast along axes 1 and 3.
/// axes_mapping - [0, 3] => Broadcast along axes 1 and 2.
/// \param broadcast_spec Broadcast specification to use for determining broadcast
/// axes. 'axes_mapping' is ignored if broadcast_spec is not
/// NONE
Broadcast(const Output<Node>& arg,
const Output<Node>& target_shape,
const Output<Node>& axes_mapping,
const AutoBroadcastSpec& broadcast_spec = AutoBroadcastSpec());
/// \brief Constructs a broadcast operation.
///
/// \param arg The input tensor to be broadcast.
/// \param target_shape The shape of the output tensor.
/// \param broadcast_spec Broadcast specification to use for determining broadcast
/// axes
Broadcast(const Output<Node>& arg,
const Output<Node>& target_shape,
const AutoBroadcastSpec& broadcast_spec =
AutoBroadcastSpec(AutoBroadcastType::NUMPY));
size_t get_version() const override { return 1; }
void validate_and_infer_types() override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
/// \return Broadcast Specification.
const AutoBroadcastSpec& get_broadcast_spec() const { return m_broadcast_spec; }
void set_broadcast_spec(const AutoBroadcastSpec& broadcast_spec)
{
m_broadcast_spec = broadcast_spec;
}
/// \return true and the AxisSet if broadcast axes can be fully determined.
std::pair<bool, AxisSet> get_broadcast_axes() const;
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
private:
AutoBroadcastSpec m_broadcast_spec;
};
} // namespace v1
using v0::Broadcast;
using v0::BroadcastLike;
}
}
......@@ -102,6 +102,7 @@ namespace ngraph
enum class AutoBroadcastType
{
NONE = 0,
EXPLICIT = NONE,
NUMPY,
PDPD
};
......@@ -136,6 +137,11 @@ namespace ngraph
AutoBroadcastType m_type; // Implicit broadcasting algorithm
int64_t m_axis; // Axis to start alignment on
bool operator==(const AutoBroadcastSpec& a) const
{
return a.m_type == m_type && a.m_axis == m_axis;
}
};
std::ostream& operator<<(std::ostream& s, const AutoBroadcastType& type);
......
......@@ -16,6 +16,7 @@
#include "ngraph/pass/opset0_downgrade.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/pad.hpp"
......@@ -83,6 +84,22 @@ bool pass::Opset0Downgrade::run_on_node(shared_ptr<Node> node)
#endif
switch (get_typeid(node))
{
case OP_TYPEID::Broadcast:
{
auto tmp = dynamic_cast<const op::v1::Broadcast*>(node.get());
const auto arg = node->input(0).get_source_output();
NGRAPH_CHECK(node->input_value(1).get_node_shared_ptr()->is_constant());
auto target_shape =
static_pointer_cast<op::Constant>(node->input_value(1).get_node_shared_ptr())
->get_shape_val();
NGRAPH_CHECK(tmp->get_broadcast_axes().first);
auto replacement_node =
make_shared<op::v0::Broadcast>(arg, target_shape, tmp->get_broadcast_axes().second);
replace_node(node, replacement_node);
modified = true;
break;
}
case OP_TYPEID::Pad:
{
auto tmp = as_type_ptr<op::v1::Pad>(node);
......
......@@ -16,6 +16,7 @@
#include "ngraph/pass/opset1_upgrade.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/op/avg_pool.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/experimental/dyn_reshape.hpp"
......@@ -33,6 +34,7 @@
#include "ngraph/op/topk.hpp"
#include <limits>
#include <numeric>
using namespace std;
using namespace ngraph;
......@@ -137,6 +139,31 @@ bool pass::Opset1Upgrade::run_on_node(shared_ptr<Node> node)
modified = true;
break;
}
case OP_TYPEID::Broadcast:
{
auto tmp = dynamic_cast<const op::v0::Broadcast*>(node.get());
auto result_shape = tmp->get_broadcast_shape();
auto result_shape_node =
op::Constant::create(element::i64, Shape{result_shape.size()}, result_shape);
auto broadcast_axes = tmp->get_broadcast_axes();
// Flip broadcast_axes to get axes_mapping
std::vector<size_t> axes_mapping(result_shape.size());
std::iota(axes_mapping.begin(), axes_mapping.end(), 0);
for (auto i = broadcast_axes.rbegin(); i != broadcast_axes.rend(); i++)
{
axes_mapping.erase(axes_mapping.begin() + *i);
}
auto axes_mapping_node =
op::Constant::create(element::i64, Shape{axes_mapping.size()}, axes_mapping);
auto replacement_node = make_shared<op::v1::Broadcast>(node->input(0).get_source_output(),
result_shape_node->output(0),
axes_mapping_node->output(0));
replace_node(node, replacement_node);
modified = true;
break;
}
case OP_TYPEID::Convolution:
{
auto tmp = dynamic_cast<const op::v0::Convolution*>(node.get());
......
......@@ -22,6 +22,7 @@
#include "ngraph/code_writer.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/avg_pool.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/gather.hpp"
#include "ngraph/op/max.hpp"
......@@ -79,7 +80,6 @@ namespace ngraph
class NotEqual;
class Select;
class Subtract;
class Broadcast;
class Convert;
class Constant;
class Reshape;
......
......@@ -16,6 +16,7 @@
#include "ngraph/runtime/dynamic/dynamic_backend.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/experimental/dyn_broadcast.hpp"
#include "ngraph/op/experimental/dyn_replace_slice.hpp"
#include "ngraph/op/experimental/dyn_reshape.hpp"
......@@ -26,6 +27,7 @@
#include "ngraph/pass/constant_folding.hpp"
#include "ngraph/pass/dyn_elimination.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/opset0_downgrade.hpp"
#include "ngraph/pass/shape_relevance.hpp"
#include "ngraph/specialize_function.hpp"
#include "ngraph/util.hpp"
......@@ -85,7 +87,8 @@ bool is_dynamic_op(const std::shared_ptr<Node>& op)
{
return is_type<op::Transpose>(op) || is_type<op::DynBroadcast>(op) ||
is_type<op::DynReplaceSlice>(op) || is_type<op::DynSlice>(op) ||
is_type<op::v1::Reshape>(op) || is_type<op::DynReshape>(op) || is_type<op::Range>(op);
is_type<op::v1::Reshape>(op) || is_type<op::DynReshape>(op) || is_type<op::Range>(op) ||
is_type<op::v1::Broadcast>(op);
}
// Helper for a vile hack in DynamicExecutable::call. See body of that function for details.
......@@ -176,6 +179,7 @@ bool runtime::dynamic::DynamicExecutable::call(
pass::Manager passes;
passes.register_pass<pass::ConstantFolding>();
passes.register_pass<pass::DynElimination>();
passes.register_pass<pass::Opset0Downgrade>(); // Converts dynamic v1 variants to v0 ops
passes.set_per_pass_validation(false);
// FIXME(amprocte): Vile, temporary hack: we need to do repeated rounds of
......
......@@ -912,10 +912,18 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
break;
}
case OP_TYPEID::Broadcast:
{
if (op_version == 0)
{
auto shape = node_js.at("shape").get<vector<size_t>>();
auto axes = deserialize_axis_set(node_js.at("axes"));
node = make_shared<op::Broadcast>(args[0], shape, axes);
node = make_shared<op::v0::Broadcast>(args[0], shape, axes);
}
if (op_version == 1)
{
node = make_shared<op::v1::Broadcast>(
args[0], args[1], args[2], read_auto_broadcast(node_js, "auto_broadcast"));
}
break;
}
case OP_TYPEID::BroadcastDistributed:
......@@ -2500,9 +2508,20 @@ json JSONSerializer::serialize_node(const Node& n)
}
case OP_TYPEID::Broadcast:
{
auto tmp = static_cast<const op::Broadcast*>(&n);
if (op_version == 0)
{
auto tmp = dynamic_cast<const op::v0::Broadcast*>(&n);
node["axes"] = serialize_axis_set(tmp->get_broadcast_axes());
node["shape"] = tmp->get_broadcast_shape();
}
if (op_version == 1)
{
auto tmp = dynamic_cast<const op::v1::Broadcast*>(&n);
if (tmp->get_broadcast_spec().m_type != op::AutoBroadcastType::NONE)
{
node["auto_broadcast"] = write_auto_broadcast(tmp->get_broadcast_spec());
}
}
break;
}
case OP_TYPEID::BroadcastDistributed: { break;
......
......@@ -69,6 +69,7 @@ set(SRC
node_input_output.cpp
nop_elimination.cpp
op.cpp
opset_pass/broadcast_opset_pass.cpp
opset_pass/convolution_opset_pass.cpp
opset_pass/gather_opset_pass.cpp
opset_pass/pad_opset_pass.cpp
......
......@@ -73,3 +73,56 @@ NGRAPH_TEST(${BACKEND_NAME}, dyn_broadcast)
ASSERT_TRUE(test::all_close_f(results, expected_results[i], MIN_FLOAT_TOLERANCE_BITS));
}
}
NGRAPH_TEST(${BACKEND_NAME}, broadcast_v1)
{
// Create a graph for
// f(x,shape:i32,axes:32) = Broadcast(x,Convert<i64>(shape),Convert<i64>(axes)).
auto x = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto shape = make_shared<op::Parameter>(element::i32, PartialShape{Dimension::dynamic()});
auto axes = make_shared<op::Parameter>(element::i32, PartialShape{Dimension::dynamic()});
auto shape_i64 = make_shared<op::Convert>(shape, element::i64);
auto axes_i64 = make_shared<op::Convert>(axes, element::i64);
auto bc = make_shared<op::v1::Broadcast>(x, shape_i64, axes_i64);
auto f = make_shared<Function>(NodeVector{bc}, ParameterVector{x, shape, axes});
auto backend = runtime::Backend::create("${BACKEND_NAME}", true);
auto ex = backend->compile(f);
auto t_r = backend->create_dynamic_tensor(element::f32, PartialShape::dynamic());
std::vector<Shape> x_shapes{Shape{}, Shape{}, Shape{2}, Shape{2}};
std::vector<std::vector<int32_t>> shapes{{2, 2}, {2, 2, 2}, {3, 2}, {2, 3}, {2, 2}, {2, 2}};
std::vector<std::vector<int32_t>> axeses{{}, {}, {1}, {0}, {0}, {1}};
std::vector<std::vector<float>> inputs{{6}, {7}, {10, 11}, {10, 11}};
std::vector<Shape> expected_result_shapes{
Shape{2, 2}, Shape{2, 2, 2}, Shape{3, 2}, Shape{2, 3}, Shape{2, 2}, Shape{2, 2}};
std::vector<std::vector<float>> expected_results{{6, 6, 6, 6},
{7, 7, 7, 7, 7, 7, 7, 7},
{10, 11, 10, 11, 10, 11},
{10, 10, 10, 11, 11, 11},
{10, 10, 11, 11},
{10, 11, 10, 11}};
for (size_t i = 0; i < x_shapes.size(); i++)
{
auto t_x = backend->create_tensor(element::f32, x_shapes[i]);
auto t_shape = backend->create_tensor(element::i32, Shape{shapes[i].size()});
auto t_axes = backend->create_tensor(element::i32, Shape{axeses[i].size()});
copy_data(t_x, inputs[i]);
copy_data(t_shape, shapes[i]);
copy_data(t_axes, axeses[i]);
ex->call_with_validate({t_r}, {t_x, t_shape, t_axes});
ASSERT_EQ(t_r->get_shape(), expected_result_shapes[i]);
auto results = read_vector<float>(t_r);
ASSERT_TRUE(test::all_close_f(results, expected_results[i], MIN_FLOAT_TOLERANCE_BITS));
}
}
\ No newline at end of file
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/op/util/attr_types.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/opset0_downgrade.hpp"
#include "ngraph/pass/opset1_upgrade.hpp"
#include "util/type_prop.hpp"
using namespace std;
using namespace ngraph;
TEST(opset_transform, opset1_broadcast_upgrade_pass)
{
auto arg = make_shared<op::Parameter>(element::f32, Shape{5, 6});
auto bcast_v0 = make_shared<op::v0::Broadcast>(arg, Shape{3, 5, 4, 6}, AxisSet{0, 2});
auto f = make_shared<Function>(NodeVector{bcast_v0}, ParameterVector{arg});
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<pass::Opset1Upgrade>();
pass_manager.run_passes(f);
auto bcast_v1 = static_pointer_cast<op::v1::Broadcast>(
f->get_results().at(0)->input_value(0).get_node_shared_ptr());
EXPECT_EQ(bcast_v1->description(), "Broadcast");
EXPECT_EQ(bcast_v1->get_version(), 1);
EXPECT_EQ(bcast_v1->get_broadcast_spec(), op::AutoBroadcastSpec());
EXPECT_EQ(bcast_v1->get_broadcast_axes(), (std::make_pair<bool, AxisSet>(true, AxisSet{0, 2})));
EXPECT_EQ(bcast_v1->input_value(1).get_node()->description(), "Constant");
EXPECT_EQ(bcast_v1->input_value(2).get_node()->description(), "Constant");
EXPECT_EQ(static_pointer_cast<op::Constant>(bcast_v1->input_value(1).get_node_shared_ptr())
->get_shape_val(),
(Shape{3, 5, 4, 6}));
EXPECT_EQ(static_pointer_cast<op::Constant>(bcast_v1->input_value(2).get_node_shared_ptr())
->get_axis_set_val(),
(AxisSet{1, 3}));
}
TEST(opset_transform, opset1_broadcast_downgrade_pass)
{
auto arg = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3});
auto target_shape = op::Constant::create<int64_t>(element::i64, Shape{5}, {3, 1, 4, 2, 3});
auto axes_mapping = op::Constant::create<int64_t>(element::i64, Shape{3}, {1, 3, 4});
auto bcast_v1 = make_shared<op::v1::Broadcast>(arg, target_shape, axes_mapping);
auto f = make_shared<Function>(NodeVector{bcast_v1}, ParameterVector{arg});
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<pass::Opset0Downgrade>();
pass_manager.run_passes(f);
auto bcast_v0 = static_pointer_cast<op::v0::Broadcast>(
f->get_results().at(0)->input_value(0).get_node_shared_ptr());
EXPECT_EQ(bcast_v0->description(), "Broadcast");
EXPECT_EQ(bcast_v0->get_version(), 0);
EXPECT_EQ(bcast_v0->get_broadcast_shape(), (Shape{3, 1, 4, 2, 3}));
EXPECT_EQ(bcast_v0->get_broadcast_axes(), (AxisSet{0, 2}));
}
......@@ -198,3 +198,221 @@ TEST(type_prop, broadcast_partial_rank_static_dynamic_shape_mismatch_wrong_size)
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, broadcast_v1_numpy)
{
auto param = make_shared<op::Parameter>(element::f32, Shape{3, 1});
auto target_shape = op::Constant::create<int64_t>(element::i64, Shape{3}, {2, 3, 6});
auto bc = make_shared<op::v1::Broadcast>(param, target_shape);
ASSERT_EQ(bc->get_element_type(), element::f32);
ASSERT_EQ(bc->get_shape(), (Shape{2, 3, 6}));
}
TEST(type_prop, broadcast_v1_pdpd)
{
auto param = make_shared<op::Parameter>(element::f32, Shape{3, 1});
auto target_shape = op::Constant::create<int64_t>(element::i64, Shape{3}, {2, 3, 6});
auto bc = make_shared<op::v1::Broadcast>(
param, target_shape, op::AutoBroadcastSpec(op::AutoBroadcastType::PDPD, 1));
ASSERT_EQ(bc->get_element_type(), element::f32);
ASSERT_EQ(bc->get_shape(), (Shape{2, 3, 6}));
}
TEST(type_prop, broadcast_v1_axes_mapping)
{
auto param = make_shared<op::Parameter>(element::f32, Shape{3, 1});
auto target_shape = op::Constant::create<int64_t>(element::i64, Shape{3}, {2, 3, 1});
auto axes_mapping = op::Constant::create<int64_t>(element::i64, Shape{2}, {1, 2});
auto bc = make_shared<op::v1::Broadcast>(param, target_shape, axes_mapping);
ASSERT_EQ(bc->get_element_type(), element::f32);
ASSERT_EQ(bc->get_shape(), (Shape{2, 3, 1}));
}
TEST(type_prop, broadcast_v1_fail_rank)
{
auto param = make_shared<op::Parameter>(element::f32, Shape{3, 1});
auto target_shape = op::Constant::create<int64_t>(element::i64, Shape{3}, {2, 3, 1});
auto axes_mapping = op::Constant::create<int64_t>(element::i64, Shape{3}, {1, 2, 3});
try
{
auto bc = make_shared<op::v1::Broadcast>(param, target_shape, axes_mapping);
FAIL() << "Broadcast: target shape mismatch with input rank not detected";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
"Broadcast axes_mapping shape Shape{3} doesn't match rank of input tensor 2");
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, broadcast_v1_fail_transpose)
{
auto param = make_shared<op::Parameter>(element::f32, Shape{3, 1});
auto target_shape = op::Constant::create<int64_t>(element::i64, Shape{3}, {2, 1, 3});
auto axes_mapping = op::Constant::create<int64_t>(element::i64, Shape{2}, {2, 1});
try
{
auto bc = make_shared<op::v1::Broadcast>(param, target_shape, axes_mapping);
FAIL() << "Broadcast: transpose prohibition not detected";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
"Broadcast doesn't permit transposes. axes_mapping AxisVector{2, 1} "
"not in sorted order");
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, broadcast_v1_fail_axes_map)
{
auto param = make_shared<op::Parameter>(element::f32, Shape{3, 1});
auto target_shape = op::Constant::create<int64_t>(element::i64, Shape{3}, {2, 3, 1});
auto axes_mapping = op::Constant::create<int64_t>(element::i64, Shape{2}, {1, 3});
try
{
auto bc = make_shared<op::v1::Broadcast>(param, target_shape, axes_mapping);
FAIL() << "Broadcast: wrong axes_map not detected";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), "Broadcast axes_mapping[1]: 3 exceeds target rank 3");
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, broadcast_v1_fail_axes_map_shape)
{
auto param = make_shared<op::Parameter>(element::f32, Shape{3, 1});
auto target_shape = op::Constant::create<int64_t>(element::i64, Shape{3}, {2, 3, 3});
auto axes_mapping = op::Constant::create<int64_t>(element::i64, Shape{2}, {1, 2});
try
{
auto bc = make_shared<op::v1::Broadcast>(param, target_shape, axes_mapping);
FAIL() << "Broadcast: wrong target shape not detected";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), "Broadcast target[axes_mapping[1]] Expected 1. Got 3");
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, broadcast_v1_shape_wrong_rank)
{
auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4});
auto bc_shape = make_shared<op::Parameter>(element::i64, Shape{1, 1});
auto bc_axes = make_shared<op::Parameter>(element::i64, Shape{1});
try
{
auto bc = make_shared<op::v1::Broadcast>(arg, bc_shape, bc_axes);
FAIL() << "DynBroadcast: wrong shape rank not detected";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), "Broadcast shape rank must be 1");
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, broadcast_v1_axes_wrong_rank)
{
auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4});
auto bc_shape = make_shared<op::Parameter>(element::i64, Shape{1});
auto bc_axes = make_shared<op::Parameter>(element::i64, Shape{2, 2});
try
{
auto bc = make_shared<op::v1::Broadcast>(arg, bc_shape, bc_axes);
FAIL() << "Broadcast: axes shape rank not detected";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), "Broadcast axes rank must be 1");
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, broadcast_v1_output_partial_shape_dynamic)
{
auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4});
auto bc_shape = make_shared<op::Parameter>(element::i64, Shape{1});
auto bc_axes = make_shared<op::Parameter>(element::i64, Shape{2});
auto bc = make_shared<op::v1::Broadcast>(arg, bc_shape, bc_axes);
ASSERT_TRUE(bc->get_output_partial_shape(0).is_dynamic());
}
TEST(type_prop, broadcast_v1_broadcast_shape_et_wrong)
{
auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4});
// wrong element type
auto bc_shape = make_shared<op::Parameter>(element::boolean, Shape{1});
auto bc_axes = make_shared<op::Parameter>(element::i64, Shape{2});
try
{
auto bc = make_shared<op::v1::Broadcast>(arg, bc_shape, bc_axes);
FAIL() << "Broadcast: did not detect shape element type not i64";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Broadcast shape must have element type i64"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, broadcast_v1_axes_et_wrong)
{
auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4});
auto bc_shape = make_shared<op::Parameter>(element::i64, Shape{1});
// wrong element type
auto bc_axes = make_shared<op::Parameter>(element::f32, Shape{2});
try
{
auto bc = make_shared<op::v1::Broadcast>(arg, bc_shape, bc_axes);
FAIL() << "Broadcast: did not detect axes element type not i64";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Broadcast axes must have element type i64"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment