Commit 86bf4762 authored by Jayaram Bobba's avatar Jayaram Bobba Committed by Scott Cyphers

[SPEC] Select op with implicit broadcast (#3974)

* Added v1::Select op with support for implicit broadcasting

* Addressed PR feedback

* Constant folding support for v1::Select op

* Remove commented-out code

* More shape inference tests
parent 04212bcb
......@@ -211,6 +211,7 @@ NGRAPH_OP(ScaleShift, ngraph::op::v0, 0)
NGRAPH_OP(ScatterAdd, ngraph::op::v0, 0)
NGRAPH_OP(ScatterNDAdd, ngraph::op::v0, 0)
NGRAPH_OP(Select, ngraph::op::v0, 0)
NGRAPH_OP(Select, ngraph::op::v1, 1)
NGRAPH_OP(Selu, ngraph::op::v0, 0)
NGRAPH_OP(Send, ngraph::op::v0, 0)
NGRAPH_OP(ShapeOf, ngraph::op::v0, 0)
......
......@@ -25,20 +25,109 @@
using namespace std;
using namespace ngraph;
constexpr NodeTypeInfo op::Select::type_info;
constexpr NodeTypeInfo op::v1::Select::type_info;
op::Select::Select(const Output<Node>& arg0, const Output<Node>& arg1, const Output<Node>& arg2)
op::v1::Select::Select(const Output<Node>& arg0,
const Output<Node>& arg1,
const Output<Node>& arg2,
const AutoBroadcastSpec& auto_broadcast)
: Op({arg0, arg1, arg2})
, m_auto_broadcast(auto_broadcast)
{
constructor_validate_and_infer_types();
}
void op::Select::validate_and_infer_types()
void op::v1::Select::validate_and_infer_types()
{
// Condition element type check
NODE_VALIDATION_CHECK(this,
get_input_element_type(0).is_dynamic() ||
get_input_element_type(0) == element::boolean,
"Argument 0 does not have boolean element type (element type: ",
"Argument 0 must have boolean element type (element type: ",
get_input_element_type(0),
").");
// Then/Else element type check
element::Type result_et;
NODE_VALIDATION_CHECK(
this,
element::Type::merge(result_et, get_input_element_type(1), get_input_element_type(2)),
"Argument 1 and 2 element types must match.");
PartialShape result_shape = get_input_partial_shape(2);
for (int i = 1; i >= 0; i--)
{
if (get_auto_broadcast().m_type == op::AutoBroadcastType::NONE)
{
NODE_VALIDATION_CHECK(
this,
PartialShape::merge_into(result_shape, get_input_partial_shape(i)),
"Argument shapes are inconsistent.");
}
else if (get_auto_broadcast().m_type == op::AutoBroadcastType::NUMPY ||
get_auto_broadcast().m_type == op::AutoBroadcastType::PDPD)
{
NODE_VALIDATION_CHECK(this,
PartialShape::broadcast_merge_into(result_shape,
get_input_partial_shape(i),
get_auto_broadcast()),
"Argument shapes are inconsistent.");
}
else
{
NODE_VALIDATION_CHECK(this, false, "Unsupported auto broadcast specification");
}
}
set_output_type(0, result_et, result_shape);
}
shared_ptr<Node> op::v1::Select::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<v1::Select>(
new_args.at(0), new_args.at(1), new_args.at(2), m_auto_broadcast);
}
bool op::v1::Select::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("auto_broadcast", m_auto_broadcast);
return true;
}
void op::v1::Select::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
{
if (get_auto_broadcast().m_type != op::AutoBroadcastType::NONE)
{
throw ngraph_error("Autodiff not supported with auto broadcasting");
}
auto delta = deltas.at(0);
auto p = input_value(0);
auto x = input_value(1);
auto y = input_value(2);
auto p_as_x_type = make_shared<op::Convert>(p, x.get_element_type());
auto not_p_as_y_type = make_shared<op::Convert>(make_shared<op::Not>(p), y.get_element_type());
adjoints.add_delta(x, delta * p_as_x_type);
adjoints.add_delta(y, delta * not_p_as_y_type);
}
constexpr NodeTypeInfo op::v0::Select::type_info;
op::v0::Select::Select(const Output<Node>& arg0, const Output<Node>& arg1, const Output<Node>& arg2)
: Op({arg0, arg1, arg2})
{
constructor_validate_and_infer_types();
}
void op::v0::Select::validate_and_infer_types()
{
NODE_VALIDATION_CHECK(this,
get_input_element_type(0).is_dynamic() ||
get_input_element_type(0) == element::boolean,
"Argument 0 must have boolean element type (element type: ",
get_input_element_type(0),
").");
......@@ -61,13 +150,13 @@ void op::Select::validate_and_infer_types()
set_output_type(0, result_et, result_shape);
}
shared_ptr<Node> op::Select::copy_with_new_args(const NodeVector& new_args) const
shared_ptr<Node> op::v0::Select::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<Select>(new_args.at(0), new_args.at(1), new_args.at(2));
return make_shared<v0::Select>(new_args.at(0), new_args.at(1), new_args.at(2));
}
void op::Select::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
void op::v0::Select::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
{
auto delta = deltas.at(0);
......
......@@ -66,6 +66,72 @@ namespace ngraph
const NodeVector& deltas) override;
};
}
namespace v1
{
// clang-format off
/// \brief Elementwise selection operation.
///
/// ## Inputs
///
/// | | Type | Description |
/// | ------ | --------------------------------------------- | ------------------------------------------------------------ |
/// | `arg0` | \f$\texttt{bool}[d_1,\dots,d_n]~(n \geq 0)\f$ | A tensor of any shape, with element `bool`. |
/// | `arg1` | \f$E[d_1,\dots,d_n]~(n \geq 0)\f$ | A tensor of a shape that is broadcast-compatible with `arg0`, with any element type. |
/// | `arg2` | \f$E[d_1,\dots,d_n]~(n \geq 0)\f$ | A tensor of a shape that is broadcast-compatible with `arg0`, and same element type as `arg1`. |
/// | `auto_broadcast`| AutoBroadcastSpec | Auto broadcast specification. |
///
/// ## Output
///
/// | Type | Description |
/// | ---------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
/// | \f$E[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \texttt{arg1}[i_1,\dots,i_n]\text{ if }\texttt{arg0}[i_1,\dots,i_n] \neq 0\text{, else }\texttt{arg2}[i_1,\dots,i_n]\f$ |
// clang-format on
class NGRAPH_API Select : public Op
{
public:
static constexpr NodeTypeInfo type_info{"Select", 1};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a selection operation.
Select()
: m_auto_broadcast(AutoBroadcastSpec(AutoBroadcastType::NUMPY))
{
}
/// \brief Constructs a selection operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param arg2 Node that produces the third input tensor.
/// \param auto_broadcast Auto broadcast specification. Default is Numpy-style
/// implicit broadcasting.
Select(const Output<Node>& arg0,
const Output<Node>& arg1,
const Output<Node>& arg2,
const AutoBroadcastSpec& auto_broadcast =
AutoBroadcastSpec(AutoBroadcastType::NUMPY));
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;
const AutoBroadcastSpec& get_auto_broadcast() const { return m_auto_broadcast; }
void set_auto_broadcast(const AutoBroadcastSpec& auto_broadcast)
{
m_auto_broadcast = auto_broadcast;
}
bool supports_auto_broadcast() const override { return true; }
// TODO: Move all uses of get_autob to get_auto_broadcast() and remove this.
const AutoBroadcastSpec& get_autob() const override { return m_auto_broadcast; }
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
private:
AutoBroadcastSpec m_auto_broadcast;
};
}
using v0::Select;
}
}
......@@ -135,7 +135,7 @@ NGRAPH_OP(Result, ngraph::op)
NGRAPH_OP(Reverse, ngraph::op::v1)
NGRAPH_OP(ReverseSequence, ngraph::op::v0)
NGRAPH_OP(RNNCell, ngraph::op::v0)
NGRAPH_OP(Select, ngraph::op::v0)
NGRAPH_OP(Select, ngraph::op::v1)
NGRAPH_OP(Selu, ngraph::op::v0)
NGRAPH_OP(ShapeOf, ngraph::op::v0)
NGRAPH_OP(ShuffleChannels, ngraph::op::v0)
......
......@@ -22,19 +22,33 @@ using namespace std;
using namespace ngraph;
template <class T>
shared_ptr<op::Constant> fold_constant_select(shared_ptr<op::Constant> selection,
shared_ptr<op::Constant> t,
shared_ptr<op::Constant> f,
shared_ptr<op::Select> select)
shared_ptr<op::Constant> fold_constant_select(const shared_ptr<op::Constant>& selection,
const shared_ptr<op::Constant>& t,
const shared_ptr<op::Constant>& f,
const shared_ptr<Node>& select)
{
auto out_shape = select->get_shape();
vector<T> out_vec(shape_size(out_shape));
if (auto select_v0 = as_type_ptr<op::v0::Select>(select))
{
runtime::reference::select<T>(selection->get_data_ptr<char>(),
t->get_data_ptr<T>(),
f->get_data_ptr<T>(),
out_vec.data(),
shape_size(out_shape));
}
else if (auto select_v1 = as_type_ptr<op::v1::Select>(select))
{
runtime::reference::select<T>(selection->get_data_ptr<char>(),
t->get_data_ptr<T>(),
f->get_data_ptr<T>(),
out_vec.data(),
selection->get_shape(),
t->get_shape(),
f->get_shape(),
select_v1->get_auto_broadcast());
}
return make_shared<op::Constant>(select->get_element_type(), out_shape, out_vec);
}
......@@ -47,7 +61,8 @@ void pass::ConstantFolding::construct_constant_select()
element::i64, Shape{2, 3, 4}, pattern::has_class<op::Constant>());
auto f_label = make_shared<pattern::op::Label>(
element::i64, Shape{2, 3, 4}, pattern::has_class<op::Constant>());
auto select_op = make_shared<op::Select>(selection_label, t_label, f_label);
auto select_v0_op = make_shared<op::v0::Select>(selection_label, t_label, f_label);
auto select_v1_op = make_shared<op::v1::Select>(selection_label, t_label, f_label);
auto constant_select_callback = [selection_label, t_label, f_label](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for constant_select_callback against node = "
......@@ -55,10 +70,11 @@ void pass::ConstantFolding::construct_constant_select()
auto pattern_map = m.get_pattern_map();
auto selection_node = static_pointer_cast<op::Constant>(pattern_map[selection_label]);
auto t_node = static_pointer_cast<op::Constant>(pattern_map[t_label]);
auto f_node = static_pointer_cast<op::Constant>(pattern_map[f_label]);
auto select = static_pointer_cast<op::Select>(m.get_match_root());
const auto& selection_node =
static_pointer_cast<op::Constant>(pattern_map[selection_label]);
const auto& t_node = static_pointer_cast<op::Constant>(pattern_map[t_label]);
const auto& f_node = static_pointer_cast<op::Constant>(pattern_map[f_label]);
const auto& select = m.get_match_root();
NGRAPH_CHECK(revalidate_and_ensure_static(select));
......@@ -120,7 +136,12 @@ void pass::ConstantFolding::construct_constant_select()
return true;
};
auto select_matcher =
make_shared<pattern::Matcher>(select_op, "ConstantFolding.ConstantSelect");
this->add_matcher(select_matcher, constant_select_callback, PassProperty::CHANGE_DYNAMIC_STATE);
this->add_matcher(
make_shared<pattern::Matcher>(select_v0_op, "ConstantFolding.ConstantSelectV0"),
constant_select_callback,
PassProperty::CHANGE_DYNAMIC_STATE);
this->add_matcher(
make_shared<pattern::Matcher>(select_v1_op, "ConstantFolding.ConstantSelectV1"),
constant_select_callback,
PassProperty::CHANGE_DYNAMIC_STATE);
}
......@@ -24,6 +24,7 @@
#include "ngraph/node.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/ops.hpp"
#include "ngraph/pass/implicit_broadcast_elimination.hpp"
#include "ngraph/pass/opset0_downgrade.hpp"
#include "ngraph/slice_plan.hpp"
#include "ngraph/type.hpp"
......@@ -490,6 +491,15 @@ namespace
return true;
}
bool op_cast(shared_ptr<op::v1::Select> node)
{
ngraph::pass::ImplicitBroadcastElimination().run_on_node(node);
auto replacement_node = make_shared<op::v0::Select>(
node->input_value(0), node->input_value(1), node->input_value(2));
replace_node(node, replacement_node);
return true;
}
bool op_cast(shared_ptr<op::v1::StridedSlice> node)
{
auto convert_mask_to_axes = [](const std::vector<int64_t>& mask) {
......
......@@ -497,6 +497,16 @@ namespace
return true;
}
bool op_cast(shared_ptr<op::Select> node)
{
auto replacement_node = make_shared<op::v1::Select>(node->input_value(0),
node->input_value(1),
node->input_value(2),
op::AutoBroadcastSpec());
replace_node(node, replacement_node);
return true;
}
bool op_cast(shared_ptr<op::Softmax> node)
{
NGRAPH_CHECK(node->input_value(1).get_node_shared_ptr()->is_constant(),
......
......@@ -221,6 +221,211 @@ namespace ngraph
}
}
}
/// \brief Helper function to implement autobroadcasting elementwise ternaryop
/// references.
///
/// \tparam U Element type of the selector tensor.
/// \tparam T Element type of the input tensors.
/// \tparam Functor Type of the functor for the elementwise operation. Must support
/// operator()(U,T,T), and operator()(U,T,T) must return a value of type
/// T.
///
/// \param arg0 Pointer to the buffer for selector tensor.
/// \param arg1 Pointer to the buffer for left operand input tensor.
/// \param arg2 Pointer to the buffer for right operand input tensor.
/// \param out Pointer to the buffer for output tensor. This must be pre-allocated by
/// the caller, and must be large enough to hold a tensor of the correct
/// shape.
/// \param broadcast_spec Specification of the auto-broadcasting scheme.
/// \param elementwise_functor Functor implementing the elementwise operation to be
/// applied across the input tensors. Must accept an argument
/// of
/// type U and two of type T, and return a value of type T.
template <typename T, typename U, typename Functor>
void autobroadcast_select(const U* arg0,
const T* arg1,
const T* arg2,
T* out,
const Shape& arg0_shape,
const Shape& arg1_shape,
const Shape& arg2_shape,
const op::AutoBroadcastSpec& broadcast_spec,
Functor elementwise_functor)
{
switch (broadcast_spec.m_type)
{
case op::AutoBroadcastType::NONE:
for (size_t i = 0; i < shape_size(arg0_shape); i++)
{
out[i] = elementwise_functor(arg0[i], arg1[i], arg2[i]);
}
break;
case op::AutoBroadcastType::NUMPY:
// Uses same approach as autobroadcast_binop.
{
Shape arg0_padded_shape = arg0_shape;
Shape arg1_padded_shape = arg1_shape;
Shape arg2_padded_shape = arg2_shape;
while (arg1_padded_shape.size() < arg2_padded_shape.size())
{
arg1_padded_shape.insert(arg1_padded_shape.begin(), 1);
}
while (arg2_padded_shape.size() < arg1_padded_shape.size())
{
arg2_padded_shape.insert(arg2_padded_shape.begin(), 1);
}
while (arg0_padded_shape.size() < arg1_padded_shape.size())
{
arg0_padded_shape.insert(arg0_padded_shape.begin(), 1);
}
Shape arg0_squeezed_shape;
Shape arg1_squeezed_shape;
Shape arg2_squeezed_shape;
AxisSet arg0_squeezed_axes;
AxisSet arg1_squeezed_axes;
AxisSet arg2_squeezed_axes;
Shape output_shape;
for (size_t i = 0; i < arg1_padded_shape.size(); i++)
{
if (arg1_padded_shape[i] == 1)
{
arg1_squeezed_axes.insert(i);
}
else
{
arg1_squeezed_shape.push_back(arg1_padded_shape[i]);
}
if (arg2_padded_shape[i] == 1)
{
arg2_squeezed_axes.insert(i);
}
else
{
arg2_squeezed_shape.push_back(arg2_padded_shape[i]);
}
if (arg0_padded_shape[i] == 1)
{
arg0_squeezed_axes.insert(i);
}
else
{
arg0_squeezed_shape.push_back(arg0_padded_shape[i]);
}
output_shape.push_back(arg1_padded_shape[i] == 1
? arg2_padded_shape[i]
: arg1_padded_shape[i]);
}
CoordinateTransform arg0_transform(arg0_squeezed_shape);
CoordinateTransform arg1_transform(arg1_squeezed_shape);
CoordinateTransform arg2_transform(arg2_squeezed_shape);
CoordinateTransform output_transform(output_shape);
for (const Coordinate& output_coord : output_transform)
{
Coordinate arg0_coord = reduce(output_coord, arg0_squeezed_axes);
Coordinate arg1_coord = reduce(output_coord, arg1_squeezed_axes);
Coordinate arg2_coord = reduce(output_coord, arg2_squeezed_axes);
out[output_transform.index(output_coord)] =
elementwise_functor(arg0[arg0_transform.index(arg0_coord)],
arg1[arg1_transform.index(arg1_coord)],
arg2[arg2_transform.index(arg2_coord)]);
}
}
break;
case op::AutoBroadcastType::PDPD:
{
// arg0 and arg2 are broadcast to arg1 shape
int64_t axis = broadcast_spec.m_axis;
if (axis == -1)
{
axis = arg1_shape.size() - arg2_shape.size();
}
Shape arg0_padded_shape = arg0_shape;
Shape arg2_padded_shape = arg2_shape;
// Trim trailing ones
while (arg0_padded_shape.size() > 0 && arg0_padded_shape.back() == 1)
{
arg0_padded_shape.pop_back();
}
for (int64_t i = 0; i < axis; ++i)
{
arg0_padded_shape.insert(arg0_padded_shape.begin(), 1);
}
while (arg0_padded_shape.size() < arg1_shape.size())
{
arg0_padded_shape.insert(arg0_padded_shape.end(), 1);
}
while (arg2_padded_shape.size() > 0 && arg2_padded_shape.back() == 1)
{
arg2_padded_shape.pop_back();
}
for (int64_t i = 0; i < axis; ++i)
{
arg2_padded_shape.insert(arg2_padded_shape.begin(), 1);
}
while (arg2_padded_shape.size() < arg1_shape.size())
{
arg2_padded_shape.insert(arg2_padded_shape.end(), 1);
}
Shape arg0_squeezed_shape;
AxisSet arg0_squeezed_axes;
Shape arg2_squeezed_shape;
AxisSet arg2_squeezed_axes;
for (size_t i = 0; i < arg1_shape.size(); i++)
{
if (arg0_padded_shape[i] == 1)
{
arg0_squeezed_axes.insert(i);
}
else
{
arg0_squeezed_shape.push_back(arg0_padded_shape[i]);
}
if (arg2_padded_shape[i] == 1)
{
arg2_squeezed_axes.insert(i);
}
else
{
arg2_squeezed_shape.push_back(arg2_padded_shape[i]);
}
}
CoordinateTransform arg0_transform(arg0_squeezed_shape);
CoordinateTransform arg1_transform(arg1_shape);
CoordinateTransform arg2_transform(arg2_squeezed_shape);
CoordinateTransform output_transform(arg1_shape);
for (const Coordinate& output_coord : output_transform)
{
Coordinate arg0_coord = reduce(output_coord, arg0_squeezed_axes);
Coordinate arg2_coord = reduce(output_coord, arg2_squeezed_axes);
out[output_transform.index(output_coord)] =
elementwise_functor(arg0[arg0_transform.index(arg0_coord)],
arg1[arg1_transform.index(output_coord)],
arg2[arg2_transform.index(arg2_coord)]);
}
}
}
}
}
}
}
......@@ -19,6 +19,8 @@
#include <cstddef>
#include <iostream>
#include "ngraph/runtime/reference/autobroadcast_binop.hpp"
namespace ngraph
{
namespace runtime
......@@ -37,6 +39,28 @@ namespace ngraph
out[i] = arg0[i] ? arg1[i] : arg2[i];
}
}
template <typename T>
void select(const char* arg0,
const T* arg1,
const T* arg2,
T* out,
const Shape& arg0_shape,
const Shape& arg1_shape,
const Shape& arg2_shape,
const op::AutoBroadcastSpec& broadcast_spec)
{
autobroadcast_select(
arg0,
arg1,
arg2,
out,
arg0_shape,
arg1_shape,
arg2_shape,
broadcast_spec,
[](char s, T x, T y) -> T { return static_cast<T>(s ? x : y); });
}
}
}
}
......@@ -2613,6 +2613,15 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
node = make_shared<op::Select>(args[0], args[1], args[2]);
break;
}
case OP_TYPEID::Select_v1:
{
node = make_shared<op::v1::Select>(
args[0],
args[1],
args[2],
read_auto_broadcast(node_js, "auto_broadcast", op::AutoBroadcastType::NUMPY));
break;
}
case OP_TYPEID::Selu:
{
node = make_shared<op::Selu>(args[0], args[1], args[2]);
......@@ -4345,6 +4354,12 @@ json JSONSerializer::serialize_node(const Node& n)
}
case OP_TYPEID::Select: { break;
}
case OP_TYPEID::Select_v1:
{
auto tmp = static_cast<const op::v1::Select*>(&n);
node["auto_broadcast"] = write_auto_broadcast(tmp->get_auto_broadcast());
break;
}
case OP_TYPEID::Selu: { break;
}
case OP_TYPEID::Send:
......
......@@ -87,6 +87,7 @@ set(SRC
opset_pass/poolings_opset_pass.cpp
opset_pass/product_opset_pass.cpp
opset_pass/reverse_opset_pass.cpp
opset_pass/select_opset_pass.cpp
opset_pass/slice_opset_pass.cpp
opset_pass/softmax_opset_pass.cpp
opset_pass/sum_opset_pass.cpp
......
......@@ -53,6 +53,30 @@ NGRAPH_TEST(${BACKEND_NAME}, select)
MIN_FLOAT_TOLERANCE_BITS));
}
NGRAPH_TEST(${BACKEND_NAME}, select_v1)
{
auto A = make_shared<op::Parameter>(element::boolean, Shape{4});
auto B = make_shared<op::Parameter>(element::f32, Shape{4});
auto C = make_shared<op::Parameter>(element::f32, Shape{2, 4});
auto f = make_shared<Function>(make_shared<op::v1::Select>(A, B, C), ParameterVector{A, B, C});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::boolean, Shape{4});
copy_data(a, vector<char>{0, 1, 1, 0});
auto b = backend->create_tensor(element::f32, Shape{4});
copy_data(b, vector<float>{1, 2, 3, 4});
auto c = backend->create_tensor(element::f32, Shape{2, 4});
copy_data(c, vector<float>{11, 12, 13, 14, 15, 16, 17, 18});
auto result = backend->create_tensor(element::f32, Shape{2, 4});
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a, b, c});
EXPECT_TRUE(
test::all_close_f((vector<float>{11, 2, 3, 14, 15, 2, 3, 18}), read_vector<float>(result)));
}
NGRAPH_TEST(${BACKEND_NAME}, select_double)
{
Shape shape{2, 2, 2};
......
......@@ -1837,6 +1837,35 @@ TEST(constant_folding, constant_select)
ASSERT_EQ(values_expected, values_out);
}
TEST(constant_folding, constant_v1_select)
{
Shape shape{2, 4};
vector<char> values_selection{0, 1, 1, 0};
vector<int64_t> values_t{1, 2, 3, 4};
vector<int64_t> values_f{11, 12, 13, 14, 15, 16, 17, 18};
auto constant_selection =
make_shared<op::Constant>(element::boolean, Shape{4}, values_selection);
auto constant_t = make_shared<op::Constant>(element::i64, Shape{4}, values_t);
auto constant_f = make_shared<op::Constant>(element::i64, Shape{2, 4}, values_f);
auto select = make_shared<op::v1::Select>(constant_selection, constant_t, constant_f);
auto f = make_shared<Function>(select, ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Select>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto new_const = as_type_ptr<op::Constant>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(new_const);
auto values_out = new_const->get_vector<int64_t>();
vector<int64_t> values_expected{11, 2, 3, 14, 15, 2, 3, 18};
ASSERT_EQ(values_expected, values_out);
}
TEST(constant_folding, pass_property)
{
auto pass = std::make_shared<ngraph::pass::ConstantFolding>();
......
......@@ -124,7 +124,7 @@ TEST(opset, check_opset1)
CHECK_OPSET(op::v1::Reverse, opset1::Reverse)
CHECK_OPSET(op::v0::ReverseSequence, opset1::ReverseSequence)
// CHECK_OPSET(op::v0::RNNCell, opset1::RNNCell)
CHECK_OPSET(op::v0::Select, opset1::Select)
CHECK_OPSET(op::v1::Select, opset1::Select)
CHECK_OPSET(op::v0::Selu, opset1::Selu)
CHECK_OPSET(op::v0::ShapeOf, opset1::ShapeOf)
CHECK_OPSET(op::v0::ShuffleChannels, opset1::ShuffleChannels)
......
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/opset0_downgrade.hpp"
#include "ngraph/pass/opset1_upgrade.hpp"
#include "util/test_control.hpp"
#include "util/type_prop.hpp"
using namespace std;
using namespace ngraph;
TEST(opset_transform, opset0_select_downgrade_pass)
{
auto cond = make_shared<op::Parameter>(element::boolean, Shape{2});
auto ptrue = make_shared<op::Parameter>(element::f32, Shape{4, 2});
auto pfalse = make_shared<op::Parameter>(element::f32, Shape{4, 2});
auto v1_node = make_shared<op::v1::Select>(cond, ptrue, pfalse);
auto result = make_shared<op::Result>(v1_node);
auto f = make_shared<Function>(ResultVector{result}, ParameterVector{cond, ptrue, pfalse});
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<pass::Opset0Downgrade>();
pass_manager.run_passes(f);
auto v0_result = f->get_results().at(0);
auto node = v0_result->input_value(0).get_node_shared_ptr();
auto v0_node = as_type_ptr<op::v0::Select>(node);
ASSERT_TRUE(v0_node);
EXPECT_EQ(v0_node->output(0).get_element_type(), element::f32);
EXPECT_EQ(v0_node->output(0).get_shape(), (Shape{4, 2}));
}
TEST(opset_transform, opset1_select_upgrade_pass)
{
auto cond = make_shared<op::Parameter>(element::boolean, Shape{4, 2});
auto ptrue = make_shared<op::Parameter>(element::f32, Shape{4, 2});
auto pfalse = make_shared<op::Parameter>(element::f32, Shape{4, 2});
auto v0_node = make_shared<op::v0::Select>(cond, ptrue, pfalse);
auto result = make_shared<op::Result>(v0_node);
auto f = make_shared<Function>(ResultVector{result}, ParameterVector{cond, ptrue, pfalse});
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<pass::Opset1Upgrade>();
pass_manager.run_passes(f);
auto v1_result = f->get_results().at(0);
auto node = v1_result->input_value(0).get_node_shared_ptr();
auto v1_node = as_type_ptr<op::v1::Select>(node);
ASSERT_TRUE(v1_node);
EXPECT_EQ(v1_node->get_auto_broadcast(), op::AutoBroadcastSpec());
EXPECT_EQ(v1_node->output(0).get_element_type(), element::f32);
EXPECT_EQ(v1_node->output(0).get_shape(), (Shape{4, 2}));
}
......@@ -108,7 +108,7 @@ TEST(type_prop, select_elem_mismatch_a)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Argument 0 does not have boolean element type"));
std::string("Argument 0 must have boolean element type"));
}
catch (...)
{
......@@ -290,3 +290,181 @@ TEST(type_prop, select_partial_all_rank_static_intransitive_incompatibility)
FAIL() << "Deduced type check failed for unexpected reason";
}
}
//------------------------------ v1::Select ---------------------------------//
//
//
struct SelectParams
{
std::vector<Shape> shapes;
std::vector<element::Type> ets;
op::AutoBroadcastSpec auto_broadcast;
SelectParams(const std::vector<Shape>& shape,
const std::vector<element::Type>& et,
const op::AutoBroadcastSpec& auto_broadcast)
: shapes(shape)
, ets(et)
, auto_broadcast(auto_broadcast)
{
}
};
struct DeduceV1SelectTest : ::testing::TestWithParam<SelectParams>
{
};
TEST_P(DeduceV1SelectTest, output_shape)
{
auto tp = GetParam();
auto cond = make_shared<op::Parameter>(tp.ets[0], tp.shapes[0]);
auto ptrue = make_shared<op::Parameter>(tp.ets[1], tp.shapes[1]);
auto pfalse = make_shared<op::Parameter>(tp.ets[2], tp.shapes[2]);
auto select = make_shared<op::v1::Select>(cond, ptrue, pfalse, tp.auto_broadcast);
ASSERT_EQ(select->get_shape(), tp.shapes[3]);
ASSERT_EQ(select->get_element_type(), tp.ets[3]);
}
INSTANTIATE_TEST_CASE_P(
type_prop,
DeduceV1SelectTest,
::testing::Values(SelectParams({{2, 4}, {2, 4}, {2, 4}, {2, 4}},
{element::boolean, element::f32, element::f32, element::f32},
op::AutoBroadcastType::NONE),
SelectParams({{2, 4}, {2, 4}, {2, 4}, {2, 4}},
{element::boolean, element::f32, element::f32, element::f32},
op::AutoBroadcastType::NUMPY),
SelectParams({{}, {2, 4}, {2, 4}, {2, 4}},
{element::boolean, element::f32, element::f32, element::f32},
op::AutoBroadcastType::NUMPY),
SelectParams({{}, {4}, {2, 4}, {2, 4}},
{element::boolean, element::f32, element::dynamic, element::f32},
op::AutoBroadcastType::NUMPY),
SelectParams({{}, {2, 4}, {4}, {2, 4}},
{element::boolean, element::f32, element::f32, element::f32},
op::AutoBroadcastType::NUMPY),
SelectParams({{4}, {2, 4}, {4}, {2, 4}},
{element::boolean, element::i8, element::dynamic, element::i8},
op::AutoBroadcastType::NUMPY),
SelectParams({{4}, {4}, {2, 4}, {2, 4}},
{element::dynamic, element::dynamic, element::i8, element::i8},
op::AutoBroadcastType::NUMPY),
SelectParams({{2}, {2}, {2, 4}, {2, 4}},
{element::boolean, element::f32, element::dynamic, element::f32},
{op::AutoBroadcastType::PDPD, 0}),
// TODO: Whats the right behavior here?
// SelectParams({{2}, {2, 4}, {2}, {2, 4}}, {element::boolean, element::f32,
// element::dynamic, element::f32}, {op::AutoBroadcastType::PDPD, 0}),
SelectParams({{4}, {4}, {2, 4}, {2, 4}},
{element::boolean, element::f32, element::dynamic, element::f32},
{op::AutoBroadcastType::PDPD, 1})),
PrintToDummyParamName());
TEST(type_prop, select_v1_partial_shape)
{
auto a = make_shared<op::Parameter>(element::boolean, PartialShape::dynamic());
auto b = make_shared<op::Parameter>(element::f32, Shape{2, 4});
auto c = make_shared<op::Parameter>(element::f32, Shape{2, 4});
auto select = make_shared<op::v1::Select>(a, b, c, op::AutoBroadcastType::NONE);
ASSERT_EQ(select->get_shape(), (Shape{2, 4}));
}
TEST(type_prop, select_v1_partial_shape_autob)
{
auto a = make_shared<op::Parameter>(element::boolean, PartialShape{Dimension::dynamic()});
auto b = make_shared<op::Parameter>(element::f32, PartialShape{Dimension::dynamic()});
auto c = make_shared<op::Parameter>(element::f32, PartialShape{2, Dimension::dynamic()});
auto select = make_shared<op::v1::Select>(a, b, c);
ASSERT_TRUE(
select->get_output_partial_shape(0).same_scheme(PartialShape{2, Dimension::dynamic()}));
}
TEST(type_prop, select_v1_wrong_et)
{
auto param0 = make_shared<op::Parameter>(element::i8, Shape{2, 4});
auto param1 = make_shared<op::Parameter>(element::f32, Shape{2, 4});
auto param2 = make_shared<op::Parameter>(element::f32, Shape{2, 4});
try
{
auto sel = make_shared<op::v1::Select>(param0, param1, param2);
FAIL() << "Did not detect wrong element type";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Argument 0 must have boolean element type"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, select_v1_et_mismatch)
{
auto param0 = make_shared<op::Parameter>(element::boolean, Shape{2, 4});
auto param1 = make_shared<op::Parameter>(element::f32, Shape{2, 4});
auto param2 = make_shared<op::Parameter>(element::i8, Shape{2, 4});
try
{
auto sel = make_shared<op::v1::Select>(param0, param1, param2);
FAIL() << "Did not detect element type mismatch";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Argument 1 and 2 element types must match."));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, select_v1_shape_mismatch)
{
auto param0 = make_shared<op::Parameter>(element::boolean, Shape{2, 4});
auto param1 = make_shared<op::Parameter>(element::f32, Shape{2, 3});
auto param2 = make_shared<op::Parameter>(element::f32, Shape{2, 4});
try
{
auto sel = make_shared<op::v1::Select>(param0, param1, param2);
FAIL() << "Did not detect shape mismatch";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Argument shapes are inconsistent."));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, select_v1_partial_shape_mismatch)
{
auto param0 =
make_shared<op::Parameter>(element::boolean, PartialShape{3, Dimension::dynamic()});
auto param1 = make_shared<op::Parameter>(element::f32, PartialShape{2, Dimension::dynamic()});
auto param2 = make_shared<op::Parameter>(element::f32, Shape{2, 4});
try
{
auto sel = make_shared<op::v1::Select>(param0, param1, param2);
FAIL() << "Did not detect shape mismatch";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Argument shapes are inconsistent."));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment