Commit 066f5e47 authored by Mateusz Bencer, committed by Scott Cyphers

[SPEC] Add constant folding for OneHot:v1 (#4087)

* first version

* Added constant cast

* Updated ref implementation

* Fixed out element type

* Added UT

* Code refactor
parent 6d4bf115
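
In effect, after this commit a v1::OneHot whose four inputs (indices, depth, on_value, off_value) are all Constants is folded into a single Constant by the ConstantFolding pass. A minimal end-to-end sketch, mirroring the unit tests added below (names and values are illustrative only):

    #include "ngraph/ngraph.hpp"
    #include "ngraph/pass/constant_folding.hpp"
    #include "ngraph/pass/manager.hpp"

    using namespace ngraph;

    std::shared_ptr<Function> folded_one_hot()
    {
        // All four OneHot inputs are Constants, so the new matcher fires.
        const auto indices = op::Constant::create(element::i64, Shape{3}, {0, 1, 2});
        const auto depth = op::Constant::create(element::i64, Shape{}, {3});
        const auto on_value = op::Constant::create(element::f32, Shape{}, {1.0f});
        const auto off_value = op::Constant::create(element::f32, Shape{}, {0.0f});
        const int64_t axis = 1;
        auto one_hot =
            std::make_shared<op::v1::OneHot>(indices, depth, on_value, off_value, axis);
        auto f = std::make_shared<Function>(one_hot, ParameterVector{});

        pass::Manager pass_manager;
        pass_manager.register_pass<pass::ConstantFolding>();
        pass_manager.run_passes(f);
        // f now holds a single op::Constant of shape {3, 3}.
        return f;
    }
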
......@@ -479,6 +479,7 @@ set (SRC
pass/constant_folding_dyn_slice.cpp
pass/constant_folding_gather.cpp
pass/constant_folding_logical_reduction.cpp
pass/constant_folding_one_hot.cpp
pass/constant_folding_pad.cpp
pass/constant_folding_quantize.cpp
pass/constant_folding_range.cpp
......
......@@ -58,7 +58,8 @@ public:
SQUEEZE,
UNSQUEEZE,
SPLIT,
VARIADIC_SPLIT
VARIADIC_SPLIT,
ONE_HOT
};
ConstantFolding(const ngraph::BuildNodeExecutorMap& cfmap = ngraph::BuildNodeExecutorMap())
......@@ -93,6 +94,7 @@ public:
construct_constant_select();
construct_constant_squeeze();
construct_constant_unsqueeze();
construct_constant_one_hot();
}
// this allows specifying the order in which matchers will be run
......@@ -136,6 +138,7 @@ public:
case CFTransformations::UNSQUEEZE: construct_constant_unsqueeze(); break;
case CFTransformations::SPLIT: construct_constant_split(); break;
case CFTransformations::VARIADIC_SPLIT: construct_constant_variadic_split(); break;
case CFTransformations::ONE_HOT: construct_constant_one_hot(); break;
}
}
}
......@@ -167,6 +170,7 @@ private:
void construct_constant_unsqueeze();
void construct_constant_split();
void construct_constant_variadic_split();
void construct_constant_one_hot();
ngraph::BuildNodeExecutorMap m_cfmap;
};
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "constant_folding.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/one_hot.hpp"
#include "ngraph/runtime/reference/broadcast.hpp"
#include "ngraph/runtime/reference/one_hot.hpp"
using namespace std;
using namespace ngraph;
template <class INDICES_TYPE, class OUTPUT_TYPE>
shared_ptr<op::Constant> fold_constant_one_hot_ref(const shared_ptr<op::Constant>& indices,
const shared_ptr<op::Constant>& on_value,
const shared_ptr<op::Constant>& off_value,
const Shape& output_shape,
size_t axis)
{
std::vector<OUTPUT_TYPE> out_vec(shape_size(output_shape));
runtime::reference::one_hot<INDICES_TYPE, OUTPUT_TYPE>(indices->get_data_ptr<INDICES_TYPE>(),
out_vec.data(),
indices->get_shape(),
output_shape,
axis,
on_value->get_vector<OUTPUT_TYPE>()[0],
off_value->get_vector<OUTPUT_TYPE>()[0]);
return make_shared<op::Constant>(on_value->get_element_type(), output_shape, out_vec);
}
template <class OUTPUT_TYPE>
shared_ptr<op::Constant> fold_constant_one_hot(const shared_ptr<op::Constant>& indices,
const shared_ptr<op::Constant>& on_value,
const shared_ptr<op::Constant>& off_value,
const Shape& output_shape,
size_t axis)
{
switch (indices->get_element_type())
{
case element::Type_t::undefined:
case element::Type_t::dynamic:
case element::Type_t::u1:
case element::Type_t::boolean:
case element::Type_t::bf16:
case element::Type_t::f16:
case element::Type_t::f32:
case element::Type_t::f64:
NGRAPH_CHECK(false, "Indices input element type must be integer");
break;
case element::Type_t::i8:
return fold_constant_one_hot_ref<int8_t, OUTPUT_TYPE>(
indices, on_value, off_value, output_shape, axis);
break;
case element::Type_t::i16:
return fold_constant_one_hot_ref<int16_t, OUTPUT_TYPE>(
indices, on_value, off_value, output_shape, axis);
break;
case element::Type_t::i32:
return fold_constant_one_hot_ref<int32_t, OUTPUT_TYPE>(
indices, on_value, off_value, output_shape, axis);
break;
case element::Type_t::i64:
return fold_constant_one_hot_ref<int64_t, OUTPUT_TYPE>(
indices, on_value, off_value, output_shape, axis);
break;
case element::Type_t::u8:
return fold_constant_one_hot_ref<uint8_t, OUTPUT_TYPE>(
indices, on_value, off_value, output_shape, axis);
break;
case element::Type_t::u16:
return fold_constant_one_hot_ref<uint16_t, OUTPUT_TYPE>(
indices, on_value, off_value, output_shape, axis);
break;
case element::Type_t::u32:
return fold_constant_one_hot_ref<uint32_t, OUTPUT_TYPE>(
indices, on_value, off_value, output_shape, axis);
break;
case element::Type_t::u64:
return fold_constant_one_hot_ref<uint64_t, OUTPUT_TYPE>(
indices, on_value, off_value, output_shape, axis);
break;
}
NGRAPH_CHECK(false, "Unhandled indices element type in fold_constant_one_hot");
}
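For reference, a hypothetical standalone invocation of the helper above (in the pass it is only reached through the output-type switch in the callback below). The output type is the explicit template argument, while the indices element type is dispatched by the switch:

    const auto indices = op::Constant::create(element::i64, Shape{3}, {0, 1, 2});
    const auto on = op::Constant::create(element::f32, Shape{}, {1.0f});
    const auto off = op::Constant::create(element::f32, Shape{}, {0.0f});
    // Folds to a {3, 3} f32 constant; axis 1 is where the depth dimension is injected.
    shared_ptr<op::Constant> folded =
        fold_constant_one_hot<float>(indices, on, off, Shape{3, 3}, 1);
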
void pass::ConstantFolding::construct_constant_one_hot()
{
auto indices_label =
make_shared<pattern::op::Label>(element::i64, Shape{3}, pattern::has_class<op::Constant>());
auto depth_label =
make_shared<pattern::op::Label>(element::i64, Shape{}, pattern::has_class<op::Constant>());
auto on_label =
make_shared<pattern::op::Label>(element::i64, Shape{}, pattern::has_class<op::Constant>());
auto off_label =
make_shared<pattern::op::Label>(element::i64, Shape{}, pattern::has_class<op::Constant>());
int64_t axis = 0;
auto one_hot_pattern =
make_shared<op::v1::OneHot>(indices_label, depth_label, on_label, off_label, axis);
auto one_hot_callback = [indices_label, depth_label, on_label, off_label](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for one_hot_callback against node = "
<< m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map();
auto indices_node = static_pointer_cast<op::Constant>(pattern_map[indices_label]);
const auto depth_node = static_pointer_cast<op::Constant>(pattern_map[depth_label]);
const auto on_node = static_pointer_cast<op::Constant>(pattern_map[on_label]);
const auto off_node = static_pointer_cast<op::Constant>(pattern_map[off_label]);
auto one_hot = static_pointer_cast<op::v1::OneHot>(m.get_match_root());
const size_t axis = one_hot->get_axis();
const auto output_shape = one_hot->get_output_shape(0);
auto output_type = on_node->get_element_type();
std::shared_ptr<op::Constant> replacement;
switch (output_type)
{
case element::Type_t::undefined:
NGRAPH_CHECK(false, "Encountered 'undefined' element type in one_hot_callback");
break;
case element::Type_t::dynamic:
NGRAPH_CHECK(false, "Encountered 'dynamic' element type in one_hot_callback");
break;
case element::Type_t::u1:
NGRAPH_CHECK(false, "Encountered 'u1' element type in one_hot_callback");
break;
case element::Type_t::boolean:
replacement =
fold_constant_one_hot<char>(indices_node, on_node, off_node, output_shape, axis);
break;
case element::Type_t::bf16:
replacement = fold_constant_one_hot<bfloat16>(
indices_node, on_node, off_node, output_shape, axis);
break;
case element::Type_t::f16:
replacement =
fold_constant_one_hot<float16>(indices_node, on_node, off_node, output_shape, axis);
break;
case element::Type_t::f32:
replacement =
fold_constant_one_hot<float>(indices_node, on_node, off_node, output_shape, axis);
break;
case element::Type_t::f64:
replacement =
fold_constant_one_hot<double>(indices_node, on_node, off_node, output_shape, axis);
break;
case element::Type_t::i8:
replacement =
fold_constant_one_hot<int8_t>(indices_node, on_node, off_node, output_shape, axis);
break;
case element::Type_t::i16:
replacement =
fold_constant_one_hot<int16_t>(indices_node, on_node, off_node, output_shape, axis);
break;
case element::Type_t::i32:
replacement =
fold_constant_one_hot<int32_t>(indices_node, on_node, off_node, output_shape, axis);
break;
case element::Type_t::i64:
replacement =
fold_constant_one_hot<int64_t>(indices_node, on_node, off_node, output_shape, axis);
break;
case element::Type_t::u8:
replacement =
fold_constant_one_hot<uint8_t>(indices_node, on_node, off_node, output_shape, axis);
break;
case element::Type_t::u16:
replacement = fold_constant_one_hot<uint16_t>(
indices_node, on_node, off_node, output_shape, axis);
break;
case element::Type_t::u32:
replacement = fold_constant_one_hot<uint32_t>(
indices_node, on_node, off_node, output_shape, axis);
break;
case element::Type_t::u64:
replacement = fold_constant_one_hot<uint64_t>(
indices_node, on_node, off_node, output_shape, axis);
break;
}
replace_node(m.get_match_root(), replacement);
return true;
};
auto one_hot_matcher =
make_shared<pattern::Matcher>(one_hot_pattern, "ConstantFolding.ConstantOneHot");
this->add_matcher(one_hot_matcher, one_hot_callback, PassProperty::CHANGE_DYNAMIC_STATE);
}
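Given the CFTransformations::ONE_HOT enum value added above, the new matcher can presumably also be run in isolation through the transformation-list constructor whose dispatch switch appears earlier in this diff. A sketch under that assumption (the exact constructor signature is not part of this diff):

    pass::Manager pass_manager;
    pass_manager.register_pass<pass::ConstantFolding>(
        std::vector<pass::ConstantFolding::CFTransformations>{
            pass::ConstantFolding::CFTransformations::ONE_HOT});
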
......@@ -514,13 +514,9 @@ namespace
const auto axis = node->get_axis();
NGRAPH_CHECK(depth->is_constant(), "depth input must be constant", *node);
const auto const_depth = as_type_ptr<op::Constant>(depth);
std::int64_t depth_value = const_depth->get_vector<std::int64_t>()[0];
const auto indices_shape = node->get_input_partial_shape(0);
NGRAPH_CHECK(indices_shape.is_static(), "indices shape must be static", *node);
auto output_shape = indices_shape.to_shape();
output_shape.insert(output_shape.begin() + axis, depth_value);
const auto output_pshape = node->get_output_partial_shape(0);
NGRAPH_CHECK(output_pshape.is_static(), "output shape must be static", *node);
const auto output_shape = output_pshape.to_shape();
auto one_hot = std::make_shared<ngraph::op::Convert>(
std::make_shared<ngraph::op::OneHot>(indices, output_shape, axis),
......
......@@ -27,28 +27,30 @@ namespace ngraph
{
namespace reference
{
template <typename T>
void one_hot(const T* arg,
T* out,
template <typename INDICES_TYPE, typename OUTPUT_TYPE>
void one_hot(const INDICES_TYPE* arg,
OUTPUT_TYPE* out,
const Shape& in_shape,
const Shape& out_shape,
size_t one_hot_axis)
size_t one_hot_axis,
const OUTPUT_TYPE on_value,
const OUTPUT_TYPE off_value)
{
// Step 1: Zero out the output.
// Step 1: Fill the output with off_value.
CoordinateTransform output_transform(out_shape);
for (const Coordinate& output_coord : output_transform)
{
out[output_transform.index(output_coord)] = 0;
out[output_transform.index(output_coord)] = off_value;
}
// Step 2: Write ones at needed positions, throwing exceptions when invalid
// Step 2: Write on_value at needed positions, throwing exceptions when invalid
// conditions are encountered.
CoordinateTransform input_transform(in_shape);
for (const Coordinate& input_coord : input_transform)
{
T val = arg[input_transform.index(input_coord)];
INDICES_TYPE val = arg[input_transform.index(input_coord)];
if (std::floor(val) < val || std::floor(val) > val)
{
......@@ -64,9 +66,22 @@ namespace ngraph
Coordinate one_hot_coord = inject(input_coord, one_hot_axis, one_hot_pos);
out[output_transform.index(one_hot_coord)] = 1;
out[output_transform.index(one_hot_coord)] = on_value;
}
}
template <typename T>
void one_hot(const T* arg,
T* out,
const Shape& in_shape,
const Shape& out_shape,
size_t one_hot_axis)
{
const T on_value = 1;
const T off_value = 0;
return one_hot<T, T>(
arg, out, in_shape, out_shape, one_hot_axis, on_value, off_value);
}
}
}
}
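A short sketch exercising the extended kernel directly with explicit on/off values (standalone usage, outside the constant-folding pass):

    #include <vector>
    #include "ngraph/runtime/reference/one_hot.hpp"
    #include "ngraph/shape.hpp"

    void one_hot_example()
    {
        std::vector<int64_t> indices{0, 2, 1};
        const ngraph::Shape in_shape{3};
        const ngraph::Shape out_shape{3, 3}; // depth 3 injected at axis 1
        std::vector<float> out(ngraph::shape_size(out_shape));
        ngraph::runtime::reference::one_hot<int64_t, float>(
            indices.data(), out.data(), in_shape, out_shape, 1, 2.5f, -1.0f);
        // out == {2.5, -1, -1,   -1, -1, 2.5,   -1, 2.5, -1}
    }
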
......@@ -2072,6 +2072,129 @@ TEST(constant_folding, constant_v1_variadic_split_axis_1_3_splits_neg_length)
res3_values);
}
TEST(constant_folding, constant_v1_one_hot)
{
vector<int64_t> indices{0, 1, 2};
float16 on_value = 1.123f;
float16 off_value = 0.321f;
const auto indices_const = op::Constant::create(element::i64, Shape{3}, indices);
const auto depth_const = op::Constant::create(element::i64, Shape{}, {3});
const auto on_const = op::Constant::create(element::f16, Shape{}, {on_value});
const auto off_const = op::Constant::create(element::f16, Shape{}, {off_value});
int64_t axis = 1;
auto one_hot_v1 =
make_shared<op::v1::OneHot>(indices_const, depth_const, on_const, off_const, axis);
auto f = make_shared<Function>(one_hot_v1, ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::v1::OneHot>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto res = as_type_ptr<op::Constant>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(res);
ASSERT_EQ((Shape{3, 3}), res->get_output_shape(0));
ASSERT_EQ(vector<float16>({on_value,
off_value,
off_value,
off_value,
on_value,
off_value,
off_value,
off_value,
on_value}),
res->get_vector<float16>());
}
TEST(constant_folding, constant_v1_one_hot_negative_axes)
{
vector<int64_t> indices{0, 2, -1, 1};
int16_t on_value = 4;
int16_t off_value = 1;
const auto indices_const = op::Constant::create(element::i64, Shape{4}, indices);
const auto depth_const = op::Constant::create(element::i64, Shape{}, {3});
const auto on_const = op::Constant::create(element::i16, Shape{}, {on_value});
const auto off_const = op::Constant::create(element::i16, Shape{}, {off_value});
int64_t axis = -1;
auto one_hot_v1 =
make_shared<op::v1::OneHot>(indices_const, depth_const, on_const, off_const, axis);
auto f = make_shared<Function>(one_hot_v1, ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::v1::OneHot>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto res = as_type_ptr<op::Constant>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(res);
ASSERT_EQ((Shape{4, 3}), res->get_output_shape(0));
ASSERT_EQ(vector<int16_t>({on_value,
off_value,
off_value,
off_value,
off_value,
on_value,
off_value,
off_value,
off_value,
off_value,
on_value,
off_value}),
res->get_vector<int16_t>());
}
TEST(constant_folding, constant_v1_one_hot_negative_axes_2)
{
vector<int64_t> indices{0, 2, 1, -1};
auto on_value = true;
auto off_value = false;
const auto indices_const = op::Constant::create(element::i64, Shape{2, 2}, indices);
const auto depth_const = op::Constant::create(element::i64, Shape{}, {3});
const auto on_const = op::Constant::create(element::boolean, Shape{}, {on_value});
const auto off_const = op::Constant::create(element::boolean, Shape{}, {off_value});
int64_t axis = -1;
auto one_hot_v1 =
make_shared<op::v1::OneHot>(indices_const, depth_const, on_const, off_const, axis);
auto f = make_shared<Function>(one_hot_v1, ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::v1::OneHot>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto res = as_type_ptr<op::Constant>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(res);
ASSERT_EQ((Shape{2, 2, 3}), res->get_output_shape(0));
ASSERT_EQ(vector<bool>({on_value,
off_value,
off_value,
off_value,
off_value,
on_value,
off_value,
on_value,
off_value,
off_value,
off_value,
off_value}),
res->get_vector<bool>());
}
TEST(constant_folding, pass_property)
{
auto pass = std::make_shared<ngraph::pass::ConstantFolding>();
......
......@@ -97,7 +97,7 @@ TEST(opset_transform, opset1_one_hot_downgrade_pass_depth_not_constant)
}
}
TEST(opset_transform, opset1_one_hot_downgrade_pass_indices_shape_not_static)
TEST(opset_transform, opset1_one_hot_downgrade_pass_output_shape_not_static)
{
auto indices = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
auto depth = op::Constant::create(element::i64, Shape{}, {4});
......@@ -116,11 +116,11 @@ TEST(opset_transform, opset1_one_hot_downgrade_pass_indices_shape_not_static)
{
pass_manager.run_passes(f);
// Should have thrown, so fail if it didn't
FAIL() << "Not static indices shape not detected";
FAIL() << "Not static output shape not detected";
}
catch (const ngraph_error& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("indices shape must be static"));
EXPECT_HAS_SUBSTRING(error.what(), std::string("output shape must be static"));
}
catch (...)
{
......