Commit b1f8cfa1 authored by Adam Procter's avatar Adam Procter Committed by Scott Cyphers

Add (dynamic) Range op (#3098)

* Add 'Range' op with type prop tests (no DynElimination yet)

* Implement DynElimination for Range

* Add bailouts for GPU and INTELGPU backends

* Add some execution tests

* Add missing include for GPU

* Add /bigobj flag for MSVS on unit-test
parent b8056257
......@@ -160,6 +160,8 @@ set (SRC
op/experimental/quantized_conv_relu.hpp
op/experimental/quantized_max_pool.cpp
op/experimental/quantized_max_pool.hpp
op/experimental/range.cpp
op/experimental/range.hpp
op/experimental/shape_of.cpp
op/experimental/shape_of.hpp
op/experimental/tile.cpp
......
......@@ -91,6 +91,7 @@
#include "ngraph/op/experimental/dyn_pad.hpp"
#include "ngraph/op/experimental/dyn_reshape.hpp"
#include "ngraph/op/experimental/dyn_slice.hpp"
#include "ngraph/op/experimental/range.hpp"
#include "ngraph/op/experimental/shape_of.hpp"
#include "ngraph/op/experimental/tile.hpp"
#include "ngraph/op/experimental/transpose.hpp"
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <algorithm>
#include "ngraph/op/constant.hpp"
#include "ngraph/op/experimental/range.hpp"
using namespace std;
using namespace ngraph;
const string op::Range::type_name = "Range";
op::Range::Range()
{
}
op::Range::Range(const Output<Node>& start, const Output<Node>& stop, const Output<Node>& step)
: Op({start, stop, step})
{
constructor_validate_and_infer_types();
}
template <typename T>
static typename std::enable_if<std::is_integral<T>::value, void>::type
check_start(const op::Range* node, T start)
{
// Nothing to check for integral types.
}
template <typename T>
static typename std::enable_if<std::is_integral<T>::value, void>::type
check_stop(const op::Range* node, T stop)
{
// Nothing to check for integral types.
}
template <typename T>
static typename std::enable_if<std::is_integral<T>::value, void>::type
check_step(const op::Range* node, T step)
{
NODE_VALIDATION_CHECK(node, step != 0, "'step' cannot be zero.");
}
//
// The code in the following three functions is a bit awkward, to work around some compiler
// warnings and the need to support our custom float16/bfloat16 type:
//
// (1) We can't use STL things like isnan, because our custom float16/bfloat16 types don't always
// support them.
// (2) We check whether (x - x) == (x - x) to check for "is_finite".
// (3) We have to break (x - x) out into a temporary because otherwise the compiler throws a
// warning about == on floats.
// (4) We check <0 || >0 to check for != 0, because otherwise the compiler throws a warning about
// == on floats.
//
template <typename T>
static
typename std::enable_if<std::is_floating_point<T>::value || std::is_same<T, float16>::value ||
std::is_same<T, bfloat16>::value,
void>::type
check_start(const op::Range* node, T start)
{
T start_minus_start = start - start;
NODE_VALIDATION_CHECK(node,
start == start && start_minus_start == start_minus_start,
"'start' cannot be nan or infinite.");
}
template <typename T>
static
typename std::enable_if<std::is_floating_point<T>::value || std::is_same<T, float16>::value ||
std::is_same<T, bfloat16>::value,
void>::type
check_stop(const op::Range* node, T stop)
{
T stop_minus_stop = stop - stop;
NODE_VALIDATION_CHECK(node,
stop == stop && stop_minus_stop == stop_minus_stop,
"'stop' cannot be nan or infinite.");
}
template <typename T>
static
typename std::enable_if<std::is_floating_point<T>::value || std::is_same<T, float16>::value ||
std::is_same<T, bfloat16>::value,
void>::type
check_step(const op::Range* node, T step)
{
T step_minus_step = step - step;
NODE_VALIDATION_CHECK(node,
step == step && step_minus_step == step_minus_step &&
(step > static_cast<T>(0) || step < static_cast<T>(0)),
"'step' cannot be zero, nan, or infinite.");
}
template <typename T>
static typename std::enable_if<std::is_integral<T>::value, T>::type adjust_for_step_and_sign(T span,
T step)
{
return ceil_div(span < 0 ? -span : span, step < 0 ? -step : step);
}
template <typename T>
static
typename std::enable_if<std::is_floating_point<T>::value || std::is_same<T, float16>::value ||
std::is_same<T, bfloat16>::value,
T>::type
adjust_for_step_and_sign(T span, T step)
{
return ceil(fabs(span) / fabs(step));
}
template <typename T>
static PartialShape infer_output_shape(const op::Range* node, const element::Type& et)
{
auto const_start = dynamic_pointer_cast<op::Constant>(node->get_argument(0));
auto const_stop = dynamic_pointer_cast<op::Constant>(node->get_argument(1));
auto const_step = dynamic_pointer_cast<op::Constant>(node->get_argument(2));
T start = static_cast<T>(0);
T stop = static_cast<T>(0);
T step = static_cast<T>(0);
if (const_start != nullptr)
{
std::vector<T> start_val = const_start->get_vector<T>();
NODE_VALIDATION_CHECK(node, start_val.size() == 1);
start = start_val[0];
check_start<T>(node, start);
}
if (const_stop != nullptr)
{
std::vector<T> stop_val = const_stop->get_vector<T>();
NODE_VALIDATION_CHECK(node, stop_val.size() == 1);
stop = stop_val[0];
check_stop<T>(node, stop);
}
if (const_step != nullptr)
{
std::vector<T> step_val = const_step->get_vector<T>();
NODE_VALIDATION_CHECK(node, step_val.size() == 1);
step = step_val[0];
check_step<T>(node, step);
}
PartialShape result{PartialShape::dynamic(1)};
if (const_start != nullptr && const_stop != nullptr && const_step != nullptr)
{
T span;
if (step > static_cast<T>(0) && start >= stop)
{
span = static_cast<T>(0);
}
else if (step < static_cast<T>(0) && start <= stop)
{
span = static_cast<T>(0);
}
else
{
span = stop - start;
}
T strided = adjust_for_step_and_sign<T>(span, step);
result = PartialShape{Dimension(static_cast<int64_t>(strided))};
}
return result;
}
void op::Range::validate_and_infer_types()
{
set_input_is_relevant_to_shape(0);
set_input_is_relevant_to_shape(1);
set_input_is_relevant_to_shape(2);
auto result_et = element::dynamic;
NODE_VALIDATION_CHECK(
this,
element::Type::merge(result_et, result_et, get_input_element_type(0)) &&
element::Type::merge(result_et, result_et, get_input_element_type(1)) &&
element::Type::merge(result_et, result_et, get_input_element_type(2)),
"Element types for start, stop, and step do not match.");
NODE_VALIDATION_CHECK(this,
result_et != element::boolean,
"Element type for start, stop, and step, must not be boolean.");
NODE_VALIDATION_CHECK(
this, get_input_partial_shape(0).compatible(Shape{}), "'start' input is not a scalar");
NODE_VALIDATION_CHECK(
this, get_input_partial_shape(0).compatible(Shape{}), "'stop' input is not a scalar");
NODE_VALIDATION_CHECK(
this, get_input_partial_shape(0).compatible(Shape{}), "'step' input is not a scalar");
PartialShape result_shape;
#if !(defined(__GNUC__) && (__GNUC__ == 4 && __GNUC_MINOR__ == 8))
#pragma GCC diagnostic push
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
#endif
switch (result_et.get_type_enum())
{
case element::Type_t::bf16: result_shape = infer_output_shape<bfloat16>(this, result_et); break;
case element::Type_t::f16: result_shape = infer_output_shape<float16>(this, result_et); break;
case element::Type_t::f32: result_shape = infer_output_shape<float>(this, result_et); break;
case element::Type_t::f64: result_shape = infer_output_shape<double>(this, result_et); break;
case element::Type_t::i8: result_shape = infer_output_shape<int8_t>(this, result_et); break;
case element::Type_t::i16: result_shape = infer_output_shape<int16_t>(this, result_et); break;
case element::Type_t::i32: result_shape = infer_output_shape<int32_t>(this, result_et); break;
case element::Type_t::i64: result_shape = infer_output_shape<int64_t>(this, result_et); break;
case element::Type_t::u8: result_shape = infer_output_shape<uint8_t>(this, result_et); break;
case element::Type_t::u16: result_shape = infer_output_shape<uint16_t>(this, result_et); break;
case element::Type_t::u32: result_shape = infer_output_shape<uint32_t>(this, result_et); break;
case element::Type_t::u64: result_shape = infer_output_shape<uint64_t>(this, result_et); break;
case element::Type_t::dynamic: result_shape = PartialShape::dynamic(1); break;
case element::Type_t::undefined:
case element::Type_t::boolean:
NODE_VALIDATION_CHECK(
this, false, "Internal nGraph error: unsupported element type: ", result_et);
break;
}
#if !(defined(__GNUC__) && __GNUC__ == 4 && __GNUC_MINOR__ == 8)
#pragma GCC diagnostic pop
#endif
set_output_type(0, result_et, result_shape);
}
shared_ptr<Node> op::Range::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<Range>(new_args.at(0), new_args.at(1), new_args.at(2));
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/node.hpp"
#include "ngraph/op/op.hpp"
namespace ngraph
{
namespace op
{
/// \brief Range operation, analogous to `range()` in Python.
class Range : public Op
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs an unitialized range operation.
Range();
/// \brief Constructs a range operation.
///
/// \param start The tensor producing the start value. Must be a scalar of integer
/// element type, and same element type as `stop` and `step`.
/// \param stop The tensor producing the stop value. Must be a scalar of integer
/// element type, and same element type as `start` and `step`.
/// \param step The tensor producing the step value. Must be a scalar of integer
/// element type, and same element type as `start` and `stop`.
Range(const Output<Node>& start, const Output<Node>& stop, const Output<Node>& step);
void validate_and_infer_types() override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
};
}
}
......@@ -127,6 +127,7 @@ NGRAPH_OP(QuantizedConvolution, ngraph::op)
NGRAPH_OP(QuantizedDotBias, ngraph::op)
NGRAPH_OP(QuantizedDot, ngraph::op)
NGRAPH_OP(QuantizedMaxPool, ngraph::op)
NGRAPH_OP(Range, ngraph::op)
NGRAPH_OP(Relu, ngraph::op)
NGRAPH_OP(ReluBackprop, ngraph::op)
NGRAPH_OP(ReplaceSlice, ngraph::op)
......
......@@ -18,6 +18,7 @@
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/experimental/dyn_broadcast.hpp"
#include "ngraph/op/experimental/dyn_slice.hpp"
#include "ngraph/op/experimental/range.hpp"
#include "ngraph/op/experimental/transpose.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/reverse.hpp"
......@@ -34,6 +35,7 @@ pass::DynElimination::DynElimination()
construct_transpose();
construct_broadcast();
construct_dyn_reshape();
construct_range();
}
void pass::DynElimination::construct_transpose()
......@@ -437,3 +439,141 @@ void pass::DynElimination::construct_dyn_reshape()
make_shared<pattern::Matcher>(dyn_slice_pat, "DynElimination.DynShape");
add_matcher(dyn_slice_matcher, dyn_slice_callback, all_pass_property_off);
}
template <typename T>
std::shared_ptr<op::Constant>
make_range_replacement_integral(const element::Type& et,
const Shape& shape,
const std::shared_ptr<op::Constant>& start_arg,
const std::shared_ptr<op::Constant>& step_arg)
{
std::vector<T> elements(shape_size(shape));
std::vector<T> start_vec = start_arg->get_vector<T>();
std::vector<T> step_vec = step_arg->get_vector<T>();
NGRAPH_CHECK(start_vec.size() == 1 && step_vec.size() == 1);
T start = start_vec[0];
T step = step_vec[0];
T val = start;
for (size_t i = 0; i < elements.size(); i++)
{
elements[i] = val;
val = val + step;
}
return make_shared<op::Constant>(et, shape, elements);
}
template <typename T>
std::shared_ptr<op::Constant>
make_range_replacement_floating(const element::Type& et,
const Shape& shape,
const std::shared_ptr<op::Constant>& start_arg,
const std::shared_ptr<op::Constant>& step_arg)
{
std::vector<T> elements(shape_size(shape));
std::vector<T> start_vec = start_arg->get_vector<T>();
std::vector<T> step_vec = step_arg->get_vector<T>();
NGRAPH_CHECK(start_vec.size() == 1 && step_vec.size() == 1);
T start = start_vec[0];
T step = step_vec[0];
for (size_t i = 0; i < elements.size(); i++)
{
elements[i] = start + (static_cast<T>(i) * step);
}
return make_shared<op::Constant>(et, shape, elements);
}
void pass::DynElimination::construct_range()
{
auto start_arg_label =
make_shared<pattern::op::Label>(element::f32, Shape{}, pattern::has_class<op::Constant>());
auto stop_arg_label =
make_shared<pattern::op::Label>(element::f32, Shape{}, pattern::has_class<op::Constant>());
auto step_arg_label =
make_shared<pattern::op::Label>(element::f32, Shape{}, pattern::has_class<op::Constant>());
auto range_pat = make_shared<op::Range>(start_arg_label, stop_arg_label, step_arg_label);
auto range_callback = [start_arg_label, stop_arg_label, step_arg_label](pattern::Matcher& m) {
auto pattern_map = m.get_pattern_map();
auto start_arg = static_pointer_cast<op::Constant>(pattern_map[start_arg_label]);
auto step_arg = static_pointer_cast<op::Constant>(pattern_map[step_arg_label]);
auto range_node = static_pointer_cast<op::Range>(m.get_match_root());
NGRAPH_CHECK(start_arg->get_output_partial_shape(0).rank().compatible(0) &&
step_arg->get_output_partial_shape(0).rank().compatible(0));
auto et = range_node->get_output_element_type(0);
auto shape = range_node->get_output_shape(0);
std::shared_ptr<op::Constant> replacement;
#if !(defined(__GNUC__) && (__GNUC__ == 4 && __GNUC_MINOR__ == 8))
#pragma GCC diagnostic push
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
#endif
switch (et.get_type_enum())
{
case element::Type_t::bf16:
replacement = make_range_replacement_floating<bfloat16>(et, shape, start_arg, step_arg);
break;
case element::Type_t::f16:
replacement = make_range_replacement_floating<float16>(et, shape, start_arg, step_arg);
break;
case element::Type_t::f32:
replacement = make_range_replacement_floating<float>(et, shape, start_arg, step_arg);
break;
case element::Type_t::f64:
replacement = make_range_replacement_floating<double>(et, shape, start_arg, step_arg);
break;
case element::Type_t::i8:
replacement = make_range_replacement_integral<int8_t>(et, shape, start_arg, step_arg);
break;
case element::Type_t::i16:
replacement = make_range_replacement_integral<int16_t>(et, shape, start_arg, step_arg);
break;
case element::Type_t::i32:
replacement = make_range_replacement_integral<int32_t>(et, shape, start_arg, step_arg);
break;
case element::Type_t::i64:
replacement = make_range_replacement_integral<int64_t>(et, shape, start_arg, step_arg);
break;
case element::Type_t::u8:
replacement = make_range_replacement_integral<uint8_t>(et, shape, start_arg, step_arg);
break;
case element::Type_t::u16:
replacement = make_range_replacement_integral<uint16_t>(et, shape, start_arg, step_arg);
break;
case element::Type_t::u32:
replacement = make_range_replacement_integral<uint32_t>(et, shape, start_arg, step_arg);
break;
case element::Type_t::u64:
replacement = make_range_replacement_integral<uint64_t>(et, shape, start_arg, step_arg);
break;
case element::Type_t::undefined:
case element::Type_t::dynamic:
case element::Type_t::boolean:
NGRAPH_CHECK(false, "Internal nGraph error: unsupported element type: ", et);
break;
}
#if !(defined(__GNUC__) && __GNUC__ == 4 && __GNUC_MINOR__ == 8)
#pragma GCC diagnostic pop
#endif
replace_node(range_node, replacement);
return true;
};
auto range_matcher = make_shared<pattern::Matcher>(range_pat, "DynElimination.Range");
add_matcher(range_matcher, range_callback, all_pass_property_off);
}
......@@ -32,6 +32,7 @@ namespace ngraph
void construct_transpose();
void construct_broadcast();
void construct_dyn_reshape();
void construct_range();
};
}
}
......@@ -231,7 +231,8 @@ bool runtime::gpu::GPU_Backend::is_supported(const Node& op) const
"EmbeddingLookup",
"GenerateMask",
"DynBroadcast",
"Transpose"};
"Transpose",
"Range"};
set<string> float_only = {"MaxPoolBackprop", "AvgPoolBackprop", "MaxPool", "Dot"};
......
......@@ -72,6 +72,7 @@
#include "ngraph/op/experimental/quantized_dot.hpp"
#include "ngraph/op/experimental/quantized_dot_bias.hpp"
#include "ngraph/op/experimental/quantized_max_pool.hpp"
#include "ngraph/op/experimental/range.hpp"
#include "ngraph/op/experimental/shape_of.hpp"
#include "ngraph/op/experimental/tile.hpp"
#include "ngraph/op/experimental/transpose.hpp"
......@@ -994,6 +995,11 @@ std::string runtime::gpu::GPU_Emitter::emit_QuantizedMaxPool(EMIT_ARGS)
throw unsupported_op("Unsupported op '" + node->description() + "'");
}
std::string runtime::gpu::GPU_Emitter::emit_Range(EMIT_ARGS)
{
throw unsupported_op("Unsupported op '" + node->description() + "'");
}
std::string runtime::gpu::GPU_Emitter::emit_Relu(EMIT_ARGS)
{
return emit_elementwise<ngraph::op::Relu>(compiled_function, function_name, node, args, out);
......
......@@ -2086,6 +2086,7 @@ shared_ptr<runtime::Executable>
case OP_TYPEID::QuantizedDot:
case OP_TYPEID::QuantizedDotBias:
case OP_TYPEID::QuantizedMaxPool:
case OP_TYPEID::Range:
case OP_TYPEID::ReplaceSlice:
case OP_TYPEID::ScalarConstantLike:
case OP_TYPEID::ScaleShift:
......
......@@ -1178,6 +1178,11 @@ private:
throw unsupported_op("Unsupported op '" + node.description() +
"' in Interpreter back end.");
}
case OP_TYPEID::Range:
{
throw unsupported_op("Unsupported op '" + node.description() + "'");
break;
}
case OP_TYPEID::Relu:
{
size_t element_count = shape_size(node.get_output_shape(0));
......
......@@ -64,6 +64,7 @@
#include "ngraph/op/experimental/quantized_dot.hpp"
#include "ngraph/op/experimental/quantized_dot_bias.hpp"
#include "ngraph/op/experimental/quantized_max_pool.hpp"
#include "ngraph/op/experimental/range.hpp"
#include "ngraph/op/experimental/shape_of.hpp"
#include "ngraph/op/experimental/tile.hpp"
#include "ngraph/op/experimental/transpose.hpp"
......@@ -1557,6 +1558,11 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
break;
}
case OP_TYPEID::Range:
{
node = make_shared<op::Range>(args[0], args[1], args[2]);
break;
}
case OP_TYPEID::Relu:
{
node = make_shared<op::Relu>(args[0]);
......@@ -2522,6 +2528,8 @@ json JSONSerializer::serialize_node(const Node& n)
node["padding_above"] = tmp->get_padding_above();
break;
}
case OP_TYPEID::Range: { break;
}
case OP_TYPEID::Relu: { break;
}
case OP_TYPEID::ReluBackprop: { break;
......
......@@ -240,6 +240,12 @@ if ("${CMAKE_CXX_COMPILER_ID}" MATCHES "^(Apple)?Clang$")
target_compile_options(unit-test PRIVATE -Wno-undef -Wno-reserved-id-macro)
endif()
# So many type_prop tests these days that we need to set /bigobj flag for MSVS.
# We should probably split up type_prop.cpp.
if (MSVS)
target_compile_options(unit-test PRIVATE "/bigobj")
endif()
if (NGRAPH_CPU_ENABLE)
# The INTERPRETER backend is required for convolution, and backwards unit tests
target_link_libraries(unit-test PRIVATE cpu_backend interpreter_backend)
......
......@@ -131,3 +131,66 @@ TEST(dyn_elimination, slice)
ASSERT_EQ(f->get_results().at(0)->get_element_type(), element::f32);
ASSERT_EQ(f->get_results().at(0)->get_shape(), (Shape{2, 4, 2, 2, 1, 2, 2}));
}
TEST(dyn_elimination, range)
{
auto constant_start = make_shared<op::Constant>(element::i64, Shape{}, vector<int64_t>{0});
auto constant_stop = make_shared<op::Constant>(element::i64, Shape{}, vector<int64_t>{5});
auto constant_step = make_shared<op::Constant>(element::i64, Shape{}, vector<int64_t>{2});
auto range = make_shared<op::Range>(constant_start, constant_stop, constant_step);
ASSERT_EQ(range->get_element_type(), element::i64);
ASSERT_EQ(range->get_shape(), (Shape{3}));
auto f = make_shared<Function>(range, ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::DynElimination>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Range>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto replacement = dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
ASSERT_NE(replacement, nullptr);
ASSERT_EQ(replacement->get_element_type(), element::i64);
ASSERT_EQ(replacement->get_shape(), (Shape{3}));
auto vals = replacement->get_vector<int64_t>();
ASSERT_EQ(vals, (vector<int64_t>{0, 2, 4}));
}
TEST(dyn_elimination, range_f64)
{
auto constant_start = make_shared<op::Constant>(element::f64, Shape{}, vector<double>{-0.5});
auto constant_stop = make_shared<op::Constant>(element::f64, Shape{}, vector<double>{2});
auto constant_step = make_shared<op::Constant>(element::f64, Shape{}, vector<double>{0.25});
auto range = make_shared<op::Range>(constant_start, constant_stop, constant_step);
ASSERT_EQ(range->get_element_type(), element::f64);
ASSERT_EQ(range->get_shape(), (Shape{10}));
auto f = make_shared<Function>(range, ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::DynElimination>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Range>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto replacement = dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
ASSERT_NE(replacement, nullptr);
ASSERT_EQ(replacement->get_element_type(), element::f64);
ASSERT_EQ(replacement->get_shape(), (Shape{10}));
auto vals = replacement->get_vector<double>();
ASSERT_TRUE(test::all_close_f(
vals, vector<double>{-0.5, -0.25, 0, 0.25, 0.5, 0.75, 1.0, 1.25, 1.5, 1.75}));
}
......@@ -308,3 +308,60 @@ NGRAPH_TEST(dynamic_${BACKEND_NAME}, all)
ASSERT_EQ(results, expected_results[i]);
}
}
template <typename T>
struct RangeTest
{
T start;
T stop;
T step;
Shape expected_result_shape;
std::vector<T> expected_result;
};
// TODO(amprocte): We should test this with more than just int32, but there is a bug in the
// handling of element type-changing that is currently blocking doing that easily.
NGRAPH_TEST(dynamic_${BACKEND_NAME}, range)
{
// Create a graph for f(start,stop,step) = Range(start,stop,step).
auto start = make_shared<op::Parameter>(element::i32, Shape{});
auto stop = make_shared<op::Parameter>(element::i32, Shape{});
auto step = make_shared<op::Parameter>(element::i32, Shape{});
auto range = make_shared<op::Range>(start, stop, step);
ASSERT_TRUE(range->get_output_partial_shape(0).same_scheme(PartialShape::dynamic(1)));
auto f = make_shared<Function>(NodeVector{range}, ParameterVector{start, stop, step});
auto backend = runtime::Backend::create("${BACKEND_NAME}", true);
auto ex = backend->compile(f);
auto t_r = backend->create_dynamic_tensor(element::i32, PartialShape::dynamic());
std::vector<RangeTest<int32_t>> int32_tests = {
RangeTest<int32_t>{0, 10, 1, Shape{10}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}},
RangeTest<int32_t>{-5, 6, 3, Shape{4}, {-5, -2, 1, 4}},
RangeTest<int32_t>{10, 0, 1, Shape{0}, {}},
RangeTest<int32_t>{10, 5, -3, Shape{2}, {10, 7}}};
for (auto& test : int32_tests)
{
auto t_start = backend->create_tensor(element::i32, Shape{});
auto t_stop = backend->create_tensor(element::i32, Shape{});
auto t_step = backend->create_tensor(element::i32, Shape{});
copy_data(t_start, std::vector<int32_t>{test.start});
copy_data(t_stop, std::vector<int32_t>{test.stop});
copy_data(t_step, std::vector<int32_t>{test.step});
ex->call_with_validate({t_r}, {t_start, t_stop, t_step});
ASSERT_EQ(t_r->get_element_type(), element::i32);
ASSERT_EQ(t_r->get_shape(), test.expected_result_shape);
auto results = read_vector<int32_t>(t_r);
ASSERT_EQ(results, test.expected_result);
}
}
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment