Unverified Commit d3747036 authored by Amy Zhuang, committed by GitHub

Add constant folding for Tile. (#4328)

Co-authored-by: Chris Sullivan <chris.sullivan@intel.com>
Co-authored-by: Sang Ik Lee <sang.ik.lee@intel.com>
parent a1ee816e
@@ -497,6 +497,7 @@ set (SRC
pass/constant_folding_split.cpp
pass/constant_folding_variadic_split.cpp
pass/constant_folding_strided_slice.cpp
pass/constant_folding_tile.cpp
pass/constant_folding_transpose.cpp
pass/constant_folding_unary.cpp
pass/constant_folding.cpp
......
@@ -61,7 +61,8 @@ public:
UNSQUEEZE,
SPLIT,
VARIADIC_SPLIT,
ONE_HOT
ONE_HOT,
TILE
};
ConstantFolding(const ngraph::BuildNodeExecutorMap& cfmap = ngraph::BuildNodeExecutorMap())
@@ -97,6 +98,7 @@ public:
construct_constant_squeeze();
construct_constant_unsqueeze();
construct_constant_one_hot();
construct_constant_tile();
}
// this allows to specify the order in which matchers will be run
@@ -141,6 +143,7 @@ public:
case CFTransformations::SPLIT: construct_constant_split(); break;
case CFTransformations::VARIADIC_SPLIT: construct_constant_variadic_split(); break;
case CFTransformations::ONE_HOT: construct_constant_one_hot(); break;
case CFTransformations::TILE: construct_constant_tile(); break;
}
}
}
@@ -173,6 +176,7 @@ private:
void construct_constant_split();
void construct_constant_variadic_split();
void construct_constant_one_hot();
void construct_constant_tile();
ngraph::BuildNodeExecutorMap m_cfmap;
};
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "constant_folding.hpp"
#include "ngraph/op/experimental/tile.hpp"
#include "ngraph/runtime/reference/tile.hpp"
using namespace std;
using namespace ngraph;
template <typename T>
static shared_ptr<op::Constant> fold_constant_tile(const shared_ptr<op::Constant>& data,
const shared_ptr<Node>& tile)
{
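// Allocate a buffer large enough to hold the folded output. For a zero-sized
// output the buffer is never read, so the early return below can hand the
// (uninitialized) pointer straight to the Constant constructor.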
runtime::AlignedBuffer buffer(shape_size(tile->get_shape()) * sizeof(T));
T* data_ptr = buffer.get_ptr<T>();
// No need to call the reference kernel.
if (shape_size(tile->get_shape()) == 0)
{
return make_shared<op::Constant>(
tile->get_output_element_type(0), tile->get_output_shape(0), data_ptr);
}
if (auto tile_v0 = as_type_ptr<op::v0::Tile>(tile))
{
runtime::reference::tile<T>(
data->get_data_ptr<T>(), data_ptr, data->get_shape(), tile_v0->get_shape());
}
else
{
throw ngraph_error("Unsupported op in tile constant folding.");
}
return make_shared<op::Constant>(
tile->get_output_element_type(0), tile->get_output_shape(0), data_ptr);
}
void pass::ConstantFolding::construct_constant_tile()
{
auto data_label = make_shared<pattern::op::Label>(
element::f32, Shape{2, 2, 3}, pattern::has_class<op::Constant>());
auto repeats_label =
make_shared<pattern::op::Label>(element::i64, Shape{3}, pattern::has_class<op::Constant>());
auto tile_v0 = make_shared<op::v0::Tile>(data_label, repeats_label);
auto constant_tile_callback = [data_label](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for constant_tile_callback against node = "
<< m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map();
auto data = static_pointer_cast<op::Constant>(pattern_map[data_label]);
auto tile = m.get_match_root();
NGRAPH_CHECK(revalidate_and_ensure_static(tile));
std::shared_ptr<Node> replacement;
auto data_type = data->get_output_element_type(0);
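// Dispatch on the element type of the constant data; element types that
// cannot be folded (undefined, dynamic, u1) fail loudly below.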
switch (data_type)
{
case element::Type_t::undefined:
NGRAPH_CHECK(false, "Encountered 'undefined' element type in constant_tile_callback");
break;
case element::Type_t::dynamic:
NGRAPH_CHECK(false, "Encountered 'dynamic' element type in constant_tile_callback");
break;
case element::Type_t::u1:
NGRAPH_CHECK(false, "Encountered 'u1' element type in constant_tile_callback");
break;
case element::Type_t::boolean: replacement = fold_constant_tile<char>(data, tile); break;
case element::Type_t::bf16: replacement = fold_constant_tile<bfloat16>(data, tile); break;
case element::Type_t::f16: replacement = fold_constant_tile<float16>(data, tile); break;
case element::Type_t::f32: replacement = fold_constant_tile<float>(data, tile); break;
case element::Type_t::f64: replacement = fold_constant_tile<double>(data, tile); break;
case element::Type_t::i8: replacement = fold_constant_tile<int8_t>(data, tile); break;
case element::Type_t::i16: replacement = fold_constant_tile<int16_t>(data, tile); break;
case element::Type_t::i32: replacement = fold_constant_tile<int32_t>(data, tile); break;
case element::Type_t::i64: replacement = fold_constant_tile<int64_t>(data, tile); break;
case element::Type_t::u8: replacement = fold_constant_tile<uint8_t>(data, tile); break;
case element::Type_t::u16: replacement = fold_constant_tile<uint16_t>(data, tile); break;
case element::Type_t::u32: replacement = fold_constant_tile<uint32_t>(data, tile); break;
case element::Type_t::u64: replacement = fold_constant_tile<uint64_t>(data, tile); break;
}
replace_node(m.get_match_root(), replacement);
return true;
};
auto tile_matcher_v0 = make_shared<pattern::Matcher>(tile_v0, "ConstantFolding.ConstantTileV0");
this->add_matcher(tile_matcher_v0, constant_tile_callback, PassProperty::CHANGE_DYNAMIC_STATE);
}
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <cmath>
#include "ngraph/coordinate_transform.hpp"
namespace ngraph
{
namespace runtime
{
namespace reference
{
template <typename T>
void tile(const T* arg, T* out, const Shape& in_shape, const Shape& out_shape)
{
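// Left-pad the input shape with 1s so it has the same rank as the output,
// then map every output coordinate back to an input coordinate by taking it
// modulo the padded input dimensions (wrap-around tiling).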
Shape in_shape_expanded(in_shape);
in_shape_expanded.insert(
in_shape_expanded.begin(), out_shape.size() - in_shape.size(), 1);
CoordinateTransform input_transform(in_shape_expanded);
CoordinateTransform output_transform(out_shape);
for (const Coordinate& output_coord : output_transform)
{
std::vector<size_t> coord;
for (size_t i = 0; i < output_coord.size(); i++)
{
auto val = output_coord[i] % in_shape_expanded[i];
coord.push_back(val);
}
Coordinate input_coord(coord);
out[output_transform.index(output_coord)] =
arg[input_transform.index(input_coord)];
}
}
}
}
}
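For reference, a minimal standalone sketch (not part of this commit) of driving the new reference kernel above directly. The include path for ngraph::Shape and shape_size is an assumption, and the shapes mirror the constant_tile_3d_few_repeats test below.
#include <iostream>
#include <vector>
#include "ngraph/runtime/reference/tile.hpp" // kernel added by this commit
#include "ngraph/shape.hpp"                  // assumed header for ngraph::Shape and shape_size
int main()
{
    using namespace ngraph;
    // Data of shape {2, 1, 3} tiled to {2, 2, 3}. Note the kernel takes the
    // already-computed output shape, not the repeats vector.
    std::vector<int> in{1, 2, 3, 4, 5, 6};
    Shape in_shape{2, 1, 3};
    Shape out_shape{2, 2, 3};
    std::vector<int> out(shape_size(out_shape));
    runtime::reference::tile<int>(in.data(), out.data(), in_shape, out_shape);
    // Expected output: 1 2 3 1 2 3 4 5 6 4 5 6
    for (int v : out)
        std::cout << v << ' ';
    std::cout << '\n';
    return 0;
}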
@@ -2195,6 +2195,146 @@ TEST(constant_folding, constant_v1_one_hot_negative_axes_2)
res->get_vector<bool>());
}
TEST(constant_folding, constant_tile_1d)
{
Shape shape_in{2};
Shape shape_repeats{1};
Shape shape_out{4};
vector<int> values_in{0, 1};
auto data = make_shared<op::Constant>(element::i32, shape_in, values_in);
vector<int> values_repeats{2};
auto repeats = make_shared<op::Constant>(element::i64, shape_repeats, values_repeats);
auto tile = make_shared<op::v0::Tile>(data, repeats);
auto f = make_shared<Function>(tile, ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Tile>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto new_const = as_type_ptr<op::Constant>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(new_const);
auto values_out = new_const->get_vector<int>();
vector<int> values_expected{0, 1, 0, 1};
ASSERT_EQ(values_expected, values_out);
}
TEST(constant_folding, constant_tile_3d_small_data_rank)
{
Shape shape_in{2};
Shape shape_repeats{3};
Shape shape_out{2, 2, 4};
vector<int> values_in{0, 1};
auto data = make_shared<op::Constant>(element::i32, shape_in, values_in);
vector<int> values_repeats{2, 2, 2};
auto repeats = make_shared<op::Constant>(element::i64, shape_repeats, values_repeats);
auto tile = make_shared<op::v0::Tile>(data, repeats);
auto f = make_shared<Function>(tile, ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Tile>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto new_const = as_type_ptr<op::Constant>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(new_const);
auto values_out = new_const->get_vector<int>();
vector<int> values_expected{0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1};
ASSERT_EQ(values_expected, values_out);
}
TEST(constant_folding, constant_tile_3d_few_repeats)
{
Shape shape_in{2, 1, 3};
Shape shape_repeats{2};
Shape shape_out{2, 2, 3};
vector<int> values_in{1, 2, 3, 4, 5, 6};
auto data = make_shared<op::Constant>(element::i32, shape_in, values_in);
vector<int> values_repeats{2, 1};
auto repeats = make_shared<op::Constant>(element::i64, shape_repeats, values_repeats);
auto tile = make_shared<op::v0::Tile>(data, repeats);
auto f = make_shared<Function>(tile, ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Tile>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto new_const = as_type_ptr<op::Constant>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(new_const);
auto values_out = new_const->get_vector<int>();
vector<int> values_expected{1, 2, 3, 1, 2, 3, 4, 5, 6, 4, 5, 6};
ASSERT_EQ(values_expected, values_out);
}
TEST(constant_folding, constant_tile_1d_0_repeats)
{
Shape shape_in{2};
Shape shape_repeats{1};
Shape shape_out{0};
vector<int> values_in{0, 1};
auto data = make_shared<op::Constant>(element::i32, shape_in, values_in);
vector<int> values_repeats{0};
auto repeats = make_shared<op::Constant>(element::i64, shape_repeats, values_repeats);
auto tile = make_shared<op::v0::Tile>(data, repeats);
auto f = make_shared<Function>(tile, ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Tile>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto new_const = as_type_ptr<op::Constant>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(new_const);
auto values_out = new_const->get_vector<int>();
vector<int> values_expected{};
ASSERT_EQ(values_expected, values_out);
}
TEST(constant_folding, constant_tile_0_rank_data)
{
Shape shape_in{};
Shape shape_repeats{1};
Shape shape_out{4};
vector<int> values_in{1};
auto data = make_shared<op::Constant>(element::i32, shape_in, values_in);
vector<int> values_repeats{4};
auto repeats = make_shared<op::Constant>(element::i64, shape_repeats, values_repeats);
auto tile = make_shared<op::v0::Tile>(data, repeats);
auto f = make_shared<Function>(tile, ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Tile>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto new_const = as_type_ptr<op::Constant>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(new_const);
auto values_out = new_const->get_vector<int>();
vector<int> values_expected{1, 1, 1, 1};
ASSERT_EQ(values_expected, values_out);
}
TEST(constant_folding, pass_property)
{
auto pass = std::make_shared<ngraph::pass::ConstantFolding>();
......