Commit 9e083bfe authored by Adam Procter's avatar Adam Procter Committed by Scott Cyphers

CF for Range (#3356)

parent d0b97e73
...@@ -32,6 +32,7 @@ ...@@ -32,6 +32,7 @@
#include "ngraph/op/equal.hpp" #include "ngraph/op/equal.hpp"
#include "ngraph/op/experimental/dyn_reshape.hpp" #include "ngraph/op/experimental/dyn_reshape.hpp"
#include "ngraph/op/experimental/dyn_slice.hpp" #include "ngraph/op/experimental/dyn_slice.hpp"
#include "ngraph/op/experimental/range.hpp"
#include "ngraph/op/experimental/shape_of.hpp" #include "ngraph/op/experimental/shape_of.hpp"
#include "ngraph/op/experimental/transpose.hpp" #include "ngraph/op/experimental/transpose.hpp"
#include "ngraph/op/floor.hpp" #include "ngraph/op/floor.hpp"
...@@ -87,6 +88,7 @@ ...@@ -87,6 +88,7 @@
#include "ngraph/runtime/reference/pad.hpp" #include "ngraph/runtime/reference/pad.hpp"
#include "ngraph/runtime/reference/product.hpp" #include "ngraph/runtime/reference/product.hpp"
#include "ngraph/runtime/reference/quantize.hpp" #include "ngraph/runtime/reference/quantize.hpp"
#include "ngraph/runtime/reference/range.hpp"
#include "ngraph/runtime/reference/relu.hpp" #include "ngraph/runtime/reference/relu.hpp"
#include "ngraph/runtime/reference/reshape.hpp" #include "ngraph/runtime/reference/reshape.hpp"
#include "ngraph/runtime/reference/reverse.hpp" #include "ngraph/runtime/reference/reverse.hpp"
...@@ -2250,6 +2252,100 @@ void pass::ConstantFolding::construct_constant_dyn_slice() ...@@ -2250,6 +2252,100 @@ void pass::ConstantFolding::construct_constant_dyn_slice()
this->add_matcher(dyn_slice_matcher, constant_dyn_slice_callback, all_pass_property_off); this->add_matcher(dyn_slice_matcher, constant_dyn_slice_callback, all_pass_property_off);
} }
template <class T>
shared_ptr<op::Constant> fold_constant_range(shared_ptr<op::Constant> start,
shared_ptr<op::Constant> step,
shared_ptr<op::Range> range)
{
vector<T> out_vec(shape_size(range->get_shape()));
runtime::reference::range<T>(start->get_vector<T>().data(),
step->get_vector<T>().data(),
range->get_shape(),
out_vec.data());
return make_shared<op::Constant>(range->get_element_type(), range->get_shape(), out_vec);
}
void pass::ConstantFolding::construct_constant_range()
{
auto start_label =
make_shared<pattern::op::Label>(element::i64, Shape{}, pattern::has_class<op::Constant>());
auto stop_label =
make_shared<pattern::op::Label>(element::i64, Shape{}, pattern::has_class<op::Constant>());
auto step_label =
make_shared<pattern::op::Label>(element::i64, Shape{}, pattern::has_class<op::Constant>());
auto range_op = make_shared<op::Range>(start_label, stop_label, step_label);
auto constant_range_callback = [start_label, stop_label, step_label](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for constant_range_callback against node = "
<< m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map();
auto start_node = static_pointer_cast<op::Constant>(pattern_map[start_label]);
auto stop_node = static_pointer_cast<op::Constant>(pattern_map[stop_label]);
auto step_node = static_pointer_cast<op::Constant>(pattern_map[step_label]);
auto range = static_pointer_cast<op::Range>(m.get_match_root());
std::shared_ptr<op::Constant> replacement;
switch (range->get_output_element_type(0).get_type_enum())
{
case element::Type_t::undefined:
NGRAPH_CHECK(false, "Encountered 'undefined' element type in constant_range_callback");
break;
case element::Type_t::dynamic:
NGRAPH_CHECK(false, "Encountered 'dynamic' element type in constant_range_callback");
break;
case element::Type_t::boolean:
replacement = fold_constant_range<char>(start_node, step_node, range);
break;
case element::Type_t::bf16:
replacement = fold_constant_range<bfloat16>(start_node, step_node, range);
break;
case element::Type_t::f16:
replacement = fold_constant_range<float16>(start_node, step_node, range);
break;
case element::Type_t::f32:
replacement = fold_constant_range<float>(start_node, step_node, range);
break;
case element::Type_t::f64:
replacement = fold_constant_range<double>(start_node, step_node, range);
break;
case element::Type_t::i8:
replacement = fold_constant_range<int8_t>(start_node, step_node, range);
break;
case element::Type_t::i16:
replacement = fold_constant_range<int16_t>(start_node, step_node, range);
break;
case element::Type_t::i32:
replacement = fold_constant_range<int32_t>(start_node, step_node, range);
break;
case element::Type_t::i64:
replacement = fold_constant_range<int64_t>(start_node, step_node, range);
break;
case element::Type_t::u8:
replacement = fold_constant_range<uint8_t>(start_node, step_node, range);
break;
case element::Type_t::u16:
replacement = fold_constant_range<uint16_t>(start_node, step_node, range);
break;
case element::Type_t::u32:
replacement = fold_constant_range<uint32_t>(start_node, step_node, range);
break;
case element::Type_t::u64:
replacement = fold_constant_range<uint64_t>(start_node, step_node, range);
break;
}
replace_node(m.get_match_root(), replacement);
return true;
};
auto range_matcher = make_shared<pattern::Matcher>(range_op, "ConstantFolding.ConstantRange");
this->add_matcher(range_matcher, constant_range_callback, all_pass_property_off);
}
template <class T> template <class T>
shared_ptr<op::Constant> fold_constant_select(shared_ptr<op::Constant> selection, shared_ptr<op::Constant> fold_constant_select(shared_ptr<op::Constant> selection,
shared_ptr<op::Constant> t, shared_ptr<op::Constant> t,
......
...@@ -50,6 +50,7 @@ public: ...@@ -50,6 +50,7 @@ public:
DYN_SLICE, DYN_SLICE,
DYN_RESHAPE, DYN_RESHAPE,
TRANSPOSE, TRANSPOSE,
RANGE,
SELECT SELECT
}; };
...@@ -75,6 +76,7 @@ public: ...@@ -75,6 +76,7 @@ public:
construct_constant_dyn_slice(); construct_constant_dyn_slice();
construct_constant_dyn_reshape(); construct_constant_dyn_reshape();
construct_constant_transpose(); construct_constant_transpose();
construct_constant_range();
construct_constant_select(); construct_constant_select();
} }
...@@ -107,6 +109,7 @@ public: ...@@ -107,6 +109,7 @@ public:
case CFTransformations::DYN_SLICE: construct_constant_dyn_slice(); break; case CFTransformations::DYN_SLICE: construct_constant_dyn_slice(); break;
case CFTransformations::DYN_RESHAPE: construct_constant_dyn_reshape(); break; case CFTransformations::DYN_RESHAPE: construct_constant_dyn_reshape(); break;
case CFTransformations::TRANSPOSE: construct_constant_transpose(); break; case CFTransformations::TRANSPOSE: construct_constant_transpose(); break;
case CFTransformations::RANGE: construct_constant_range(); break;
case CFTransformations::SELECT: construct_constant_select(); break; case CFTransformations::SELECT: construct_constant_select(); break;
} }
} }
...@@ -131,6 +134,7 @@ private: ...@@ -131,6 +134,7 @@ private:
void construct_constant_dyn_slice(); void construct_constant_dyn_slice();
void construct_constant_dyn_reshape(); void construct_constant_dyn_reshape();
void construct_constant_transpose(); void construct_constant_transpose();
void construct_constant_range();
void construct_constant_select(); void construct_constant_select();
ngraph::BuildNodeExecutorMap m_cfmap; ngraph::BuildNodeExecutorMap m_cfmap;
......
...@@ -30,6 +30,7 @@ ...@@ -30,6 +30,7 @@
#include "ngraph/op/slice.hpp" #include "ngraph/op/slice.hpp"
#include "ngraph/pattern/matcher.hpp" #include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/label.hpp" #include "ngraph/pattern/op/label.hpp"
#include "ngraph/runtime/reference/range.hpp"
#include "ngraph/slice_plan.hpp" #include "ngraph/slice_plan.hpp"
using namespace std; using namespace std;
...@@ -342,11 +343,10 @@ void pass::DynElimination::construct_dyn_reshape() ...@@ -342,11 +343,10 @@ void pass::DynElimination::construct_dyn_reshape()
} }
template <typename T> template <typename T>
std::shared_ptr<op::Constant> std::shared_ptr<op::Constant> make_range_replacement(const element::Type& et,
make_range_replacement_integral(const element::Type& et, const Shape& shape,
const Shape& shape, const std::shared_ptr<op::Constant>& start_arg,
const std::shared_ptr<op::Constant>& start_arg, const std::shared_ptr<op::Constant>& step_arg)
const std::shared_ptr<op::Constant>& step_arg)
{ {
std::vector<T> elements(shape_size(shape)); std::vector<T> elements(shape_size(shape));
std::vector<T> start_vec = start_arg->get_vector<T>(); std::vector<T> start_vec = start_arg->get_vector<T>();
...@@ -354,40 +354,7 @@ std::shared_ptr<op::Constant> ...@@ -354,40 +354,7 @@ std::shared_ptr<op::Constant>
NGRAPH_CHECK(start_vec.size() == 1 && step_vec.size() == 1); NGRAPH_CHECK(start_vec.size() == 1 && step_vec.size() == 1);
T start = start_vec[0]; runtime::reference::range<T>(start_vec.data(), step_vec.data(), shape, elements.data());
T step = step_vec[0];
T val = start;
for (size_t i = 0; i < elements.size(); i++)
{
elements[i] = val;
val = val + step;
}
return make_shared<op::Constant>(et, shape, elements);
}
template <typename T>
std::shared_ptr<op::Constant>
make_range_replacement_floating(const element::Type& et,
const Shape& shape,
const std::shared_ptr<op::Constant>& start_arg,
const std::shared_ptr<op::Constant>& step_arg)
{
std::vector<T> elements(shape_size(shape));
std::vector<T> start_vec = start_arg->get_vector<T>();
std::vector<T> step_vec = step_arg->get_vector<T>();
NGRAPH_CHECK(start_vec.size() == 1 && step_vec.size() == 1);
T start = start_vec[0];
T step = step_vec[0];
for (size_t i = 0; i < elements.size(); i++)
{
elements[i] = start + (static_cast<T>(i) * step);
}
return make_shared<op::Constant>(et, shape, elements); return make_shared<op::Constant>(et, shape, elements);
} }
...@@ -426,40 +393,40 @@ void pass::DynElimination::construct_range() ...@@ -426,40 +393,40 @@ void pass::DynElimination::construct_range()
switch (et.get_type_enum()) switch (et.get_type_enum())
{ {
case element::Type_t::bf16: case element::Type_t::bf16:
replacement = make_range_replacement_floating<bfloat16>(et, shape, start_arg, step_arg); replacement = make_range_replacement<bfloat16>(et, shape, start_arg, step_arg);
break; break;
case element::Type_t::f16: case element::Type_t::f16:
replacement = make_range_replacement_floating<float16>(et, shape, start_arg, step_arg); replacement = make_range_replacement<float16>(et, shape, start_arg, step_arg);
break; break;
case element::Type_t::f32: case element::Type_t::f32:
replacement = make_range_replacement_floating<float>(et, shape, start_arg, step_arg); replacement = make_range_replacement<float>(et, shape, start_arg, step_arg);
break; break;
case element::Type_t::f64: case element::Type_t::f64:
replacement = make_range_replacement_floating<double>(et, shape, start_arg, step_arg); replacement = make_range_replacement<double>(et, shape, start_arg, step_arg);
break; break;
case element::Type_t::i8: case element::Type_t::i8:
replacement = make_range_replacement_integral<int8_t>(et, shape, start_arg, step_arg); replacement = make_range_replacement<int8_t>(et, shape, start_arg, step_arg);
break; break;
case element::Type_t::i16: case element::Type_t::i16:
replacement = make_range_replacement_integral<int16_t>(et, shape, start_arg, step_arg); replacement = make_range_replacement<int16_t>(et, shape, start_arg, step_arg);
break; break;
case element::Type_t::i32: case element::Type_t::i32:
replacement = make_range_replacement_integral<int32_t>(et, shape, start_arg, step_arg); replacement = make_range_replacement<int32_t>(et, shape, start_arg, step_arg);
break; break;
case element::Type_t::i64: case element::Type_t::i64:
replacement = make_range_replacement_integral<int64_t>(et, shape, start_arg, step_arg); replacement = make_range_replacement<int64_t>(et, shape, start_arg, step_arg);
break; break;
case element::Type_t::u8: case element::Type_t::u8:
replacement = make_range_replacement_integral<uint8_t>(et, shape, start_arg, step_arg); replacement = make_range_replacement<uint8_t>(et, shape, start_arg, step_arg);
break; break;
case element::Type_t::u16: case element::Type_t::u16:
replacement = make_range_replacement_integral<uint16_t>(et, shape, start_arg, step_arg); replacement = make_range_replacement<uint16_t>(et, shape, start_arg, step_arg);
break; break;
case element::Type_t::u32: case element::Type_t::u32:
replacement = make_range_replacement_integral<uint32_t>(et, shape, start_arg, step_arg); replacement = make_range_replacement<uint32_t>(et, shape, start_arg, step_arg);
break; break;
case element::Type_t::u64: case element::Type_t::u64:
replacement = make_range_replacement_integral<uint64_t>(et, shape, start_arg, step_arg); replacement = make_range_replacement<uint64_t>(et, shape, start_arg, step_arg);
break; break;
case element::Type_t::undefined: case element::Type_t::undefined:
case element::Type_t::dynamic: case element::Type_t::dynamic:
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <cmath>
#include <type_traits>
#include "ngraph/axis_vector.hpp"
#include "ngraph/check.hpp"
#include "ngraph/coordinate_transform.hpp"
namespace ngraph
{
namespace runtime
{
namespace reference
{
// Return type is `void`, only enabled if `T` is a built-in FP
// type, or nGraph's `bfloat16` or `float16` type.
template <typename T>
typename std::enable_if<std::is_floating_point<T>::value ||
std::is_same<T, bfloat16>::value ||
std::is_same<T, float16>::value>::type
range(const T* start, const T* step, const Shape& out_shape, T* out)
{
for (size_t i = 0; i < shape_size(out_shape); i++)
{
out[i] = *start + (static_cast<T>(i) * (*step));
}
}
// Return type is `void`, only enabled if `T` is `is_integral`.
template <typename T>
typename std::enable_if<std::is_integral<T>::value>::type
range(const T* start, const T* step, const Shape& out_shape, T* out)
{
T val = *start;
for (size_t i = 0; i < shape_size(out_shape); i++)
{
out[i] = val;
val += *step;
}
}
}
}
}
...@@ -891,6 +891,65 @@ TEST(constant_folding, constant_transpose) ...@@ -891,6 +891,65 @@ TEST(constant_folding, constant_transpose)
ASSERT_TRUE(test::all_close_f(values_permute, values_out, MIN_FLOAT_TOLERANCE_BITS)); ASSERT_TRUE(test::all_close_f(values_permute, values_out, MIN_FLOAT_TOLERANCE_BITS));
} }
void range_test_check(const vector<double>& values_out, const vector<double>& values_expected)
{
ASSERT_TRUE(test::all_close_f(values_out, values_expected, MIN_FLOAT_TOLERANCE_BITS));
}
void range_test_check(const vector<float>& values_out, const vector<float>& values_expected)
{
ASSERT_TRUE(test::all_close_f(values_out, values_expected, MIN_FLOAT_TOLERANCE_BITS));
}
template <typename T>
typename std::enable_if<std::is_integral<T>::value>::type
range_test_check(const vector<T>& values_out, const vector<T>& values_expected)
{
ASSERT_EQ(values_out, values_expected);
}
template <typename T>
void range_test(T start, T stop, T step, const vector<T>& values_expected)
{
vector<T> values_start{start};
vector<T> values_stop{stop};
vector<T> values_step{step};
auto constant_start = make_shared<op::Constant>(element::from<T>(), Shape{}, values_start);
auto constant_stop = make_shared<op::Constant>(element::from<T>(), Shape{}, values_stop);
auto constant_step = make_shared<op::Constant>(element::from<T>(), Shape{}, values_step);
auto range = make_shared<op::Range>(constant_start, constant_stop, constant_step);
auto f = make_shared<Function>(range, ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Range>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto new_const =
std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(new_const);
auto values_out = new_const->template get_vector<T>();
range_test_check(values_out, values_expected);
}
TEST(constant_folding, constant_range)
{
range_test<int8_t>(5, 12, 2, {5, 7, 9, 11});
range_test<int32_t>(5, 12, 2, {5, 7, 9, 11});
range_test<int64_t>(5, 12, 2, {5, 7, 9, 11});
range_test<uint64_t>(5, 12, 2, {5, 7, 9, 11});
range_test<double>(5, 12, 2, {5, 7, 9, 11});
range_test<float>(5, 12, 2, {5, 7, 9, 11});
range_test<int32_t>(5, 12, -2, {});
range_test<float>(12, 4, -2, {12, 10, 8, 6});
}
TEST(constant_folding, constant_select) TEST(constant_folding, constant_select)
{ {
Shape shape{2, 4}; Shape shape{2, 4};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment