Commit c16d65c4, authored by VINOD KUMAR DEVARAMPATI, committed by Robert Kimball

added constant folding for dequantize (#1762)

* added constant folding for dequantize

* modified as per review comments
parent fe06f325
......@@ -20,11 +20,13 @@
#include "ngraph/graph_util.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/dequantize.hpp"
#include "ngraph/op/pad.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/label.hpp"
#include "ngraph/runtime/reference/broadcast.hpp"
#include "ngraph/runtime/reference/dequantize.hpp"
#include "ngraph/runtime/reference/pad.hpp"
#include "ngraph/runtime/reference/reshape.hpp"
......@@ -231,3 +233,75 @@ void ngraph::pass::ConstantFolding::construct_constant_broadcast()
auto broadcast_matcher = make_shared<pattern::Matcher>(broadcast, constant_broadcast_callback);
this->add_matcher(broadcast_matcher);
}
// Evaluates a Dequantize op over constant inputs and returns the result as a new Constant.
//
// QUANT is the C++ type used to read the quantized data and the offset; REAL is the C++ type
// used to read the scale and to hold the dequantized values.
//
// constant - quantized input values; its shape is also the shape of the result
// dequant  - the Dequantize node being folded; supplies the axes and the result element type
// scale    - constant holding the scale, read as REAL
// offset   - constant holding the offset, read as QUANT
template <class QUANT, class REAL>
shared_ptr<op::Constant> make_constant_dequantize(shared_ptr<op::Constant> constant,
                                                  shared_ptr<op::Dequantize> dequant,
                                                  shared_ptr<op::Constant> scale,
                                                  shared_ptr<op::Constant> offset)
{
    // Materialize the three inputs as typed vectors so the reference kernel can read raw data.
    const auto quantized_values = constant->get_vector<QUANT>();
    const auto scale_values = scale->get_vector<REAL>();
    const auto offset_values = offset->get_vector<QUANT>();

    const auto result_shape = constant->get_shape();
    vector<REAL> folded_values(shape_size(result_shape));

    runtime::reference::dequantize<QUANT, REAL>(quantized_values.data(),
                                                scale_values.data(),
                                                offset_values.data(),
                                                folded_values.data(),
                                                result_shape,
                                                scale->get_shape(),
                                                dequant->get_axes());

    return make_shared<op::Constant>(dequant->get_element_type(), result_shape, folded_values);
}
void ngraph::pass::ConstantFolding::construct_constant_dequantize()
{
auto constant_label =
make_shared<pattern::op::Label>(element::u8, Shape{2}, pattern::has_class<op::Constant>());
auto dq_scale = op::Constant::create(element::f32, Shape{}, {1});
auto dq_offset = op::Constant::create(element::u8, Shape{}, {1});
auto dequant_op =
make_shared<op::Dequantize>(constant_label, dq_scale, dq_offset, element::f32, AxisSet{});
auto dequant = make_shared<pattern::op::Label>(dequant_op, nullptr, NodeVector{dequant_op});
auto constant_dequantize_callback = [constant_label, dequant](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for constant_dequantize_callback against node = "
<< m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map();
auto constant_match = dynamic_pointer_cast<op::Constant>(pattern_map[constant_label]);
auto dequant_match = pattern_map[dequant];
auto dequantize_op = dynamic_pointer_cast<op::Dequantize>(dequant_match);
auto args = dequant_match->get_arguments();
auto scale = dynamic_pointer_cast<op::Constant>(args[1]);
auto offset = dynamic_pointer_cast<op::Constant>(args[2]);
auto type = constant_match->get_element_type();
if (dequant_match->get_element_type() != element::f32)
{
return false;
}
if (type == element::u8)
{
replace_node(m.get_match_root(),
make_constant_dequantize<uint8_t, float>(
constant_match, dequantize_op, scale, offset));
return true;
}
else if (type == element::i8)
{
replace_node(m.get_match_root(),
make_constant_dequantize<int8_t, float>(
constant_match, dequantize_op, scale, offset));
return true;
}
return false;
};
auto dequantize_matcher = make_shared<pattern::Matcher>(dequant, constant_dequantize_callback);
this->add_matcher(dequantize_matcher);
}
......@@ -32,7 +32,8 @@ class ngraph::pass::ConstantFolding : public ngraph::pass::GraphRewrite
{
RESHAPE,
BROADCAST,
PAD
PAD,
DEQUANTIZE
};
public:
......@@ -42,6 +43,7 @@ public:
construct_constant_reshape();
construct_constant_broadcast();
construct_constant_pad();
construct_constant_dequantize();
}
//this allows to specify the order in which matchers will be run
......@@ -56,6 +58,7 @@ public:
case CFTransformations::RESHAPE: construct_constant_reshape(); break;
case CFTransformations::BROADCAST: construct_constant_broadcast(); break;
case CFTransformations::PAD: construct_constant_pad(); break;
case CFTransformations::DEQUANTIZE: construct_constant_dequantize(); break;
}
}
}
......@@ -64,4 +67,5 @@ private:
void construct_constant_reshape();
void construct_constant_broadcast();
void construct_constant_pad();
void construct_constant_dequantize();
};
......@@ -163,3 +163,37 @@ TEST(constant_folding, constant_pad_interior)
vector<int> padded_values{777, 111, 111, 111, 888};
ASSERT_EQ(padded_values, values_out);
}
// Verifies that the ConstantFolding pass replaces a Dequantize applied to a Constant with a
// single folded Constant holding the dequantized values.
TEST(constant_folding, const_dequantize)
{
    using output_c_type = float;

    auto quant_type = element::u8;
    auto output_type = element::f32;
    Shape input_shape{12};
    Shape scale_offset_shape;
    AxisSet quantization_axes;

    // Graph under test: Dequantize(u8 constant, scale = 2, offset = 1) -> f32.
    vector<uint8_t> quantized_input{1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7};
    auto data_node = op::Constant::create(quant_type, input_shape, quantized_input);
    auto scale_node = op::Constant::create(output_type, scale_offset_shape, {2});
    auto offset_node = op::Constant::create(quant_type, scale_offset_shape, {1});
    auto dequant_node = make_shared<op::Dequantize>(
        data_node, scale_node, offset_node, output_type, quantization_axes);
    auto func = make_shared<Function>(dequant_node, op::ParameterVector{});

    pass::Manager manager;
    manager.register_pass<pass::ConstantFolding>();
    manager.run_passes(func);

    // After folding, the Dequantize is gone and exactly one Constant remains.
    ASSERT_EQ(count_ops_of_type<op::Dequantize>(func), 0);
    ASSERT_EQ(count_ops_of_type<op::Constant>(func), 1);

    // The function result must now be fed directly by the folded Constant.
    auto folded_const =
        std::dynamic_pointer_cast<op::Constant>(func->get_results().at(0)->get_argument(0));
    ASSERT_TRUE(folded_const);

    // Each expected value equals (input - 1) * 2 for the inputs above.
    vector<output_c_type> expected_values{0, 2, 2, 4, 4, 6, 6, 8, 8, 10, 10, 12};
    auto actual_values = folded_const->get_vector<output_c_type>();
    ASSERT_EQ(expected_values, actual_values);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment