Commit c6550bc0 authored by VINOD KUMAR DEVARAMPATI's avatar VINOD KUMAR DEVARAMPATI Committed by Scott Cyphers

Constant folding with Quantize (#1833)

* Constant folding with Quantize

* updated with review comments
parent eaa85e1c
...@@ -29,6 +29,7 @@ ...@@ -29,6 +29,7 @@
#include "ngraph/op/multiply.hpp" #include "ngraph/op/multiply.hpp"
#include "ngraph/op/negative.hpp" #include "ngraph/op/negative.hpp"
#include "ngraph/op/pad.hpp" #include "ngraph/op/pad.hpp"
#include "ngraph/op/quantize.hpp"
#include "ngraph/op/reshape.hpp" #include "ngraph/op/reshape.hpp"
#include "ngraph/op/subtract.hpp" #include "ngraph/op/subtract.hpp"
#include "ngraph/pattern/matcher.hpp" #include "ngraph/pattern/matcher.hpp"
...@@ -43,6 +44,7 @@ ...@@ -43,6 +44,7 @@
#include "ngraph/runtime/reference/multiply.hpp" #include "ngraph/runtime/reference/multiply.hpp"
#include "ngraph/runtime/reference/negate.hpp" #include "ngraph/runtime/reference/negate.hpp"
#include "ngraph/runtime/reference/pad.hpp" #include "ngraph/runtime/reference/pad.hpp"
#include "ngraph/runtime/reference/quantize.hpp"
#include "ngraph/runtime/reference/reshape.hpp" #include "ngraph/runtime/reference/reshape.hpp"
#include "ngraph/runtime/reference/subtract.hpp" #include "ngraph/runtime/reference/subtract.hpp"
...@@ -529,3 +531,81 @@ void ngraph::pass::ConstantFolding::construct_constant_dequantize() ...@@ -529,3 +531,81 @@ void ngraph::pass::ConstantFolding::construct_constant_dequantize()
auto dequantize_matcher = make_shared<pattern::Matcher>(dequant, constant_dequantize_callback); auto dequantize_matcher = make_shared<pattern::Matcher>(dequant, constant_dequantize_callback);
this->add_matcher(dequantize_matcher); this->add_matcher(dequantize_matcher);
} }
template <class REAL, class QUANT>
shared_ptr<op::Constant> make_constant_quantize(shared_ptr<op::Constant> constant,
shared_ptr<op::Quantize> quant,
shared_ptr<op::Constant> scale,
shared_ptr<op::Constant> offset)
{
auto out_shape = constant->get_shape();
vector<QUANT> out_vec(shape_size(out_shape));
runtime::reference::quantize<REAL, QUANT>(constant->get_vector<REAL>().data(),
scale->get_vector<REAL>().data(),
offset->get_vector<QUANT>().data(),
out_vec.data(),
constant->get_shape(),
scale->get_shape(),
quant->get_axes());
return make_shared<op::Constant>(quant->get_element_type(), out_shape, out_vec);
}
void ngraph::pass::ConstantFolding::construct_constant_quantize()
{
auto constant_label =
make_shared<pattern::op::Label>(element::f32, Shape{2}, pattern::has_class<op::Constant>());
auto q_scale = op::Constant::create(element::f32, Shape{}, {1});
auto q_offset = op::Constant::create(element::i8, Shape{}, {0});
auto mode = op::Quantize::RoundMode::HALF_AWAY_FROM_ZERO;
auto quant_op =
make_shared<op::Quantize>(constant_label, q_scale, q_offset, element::i8, AxisSet{}, mode);
auto quant = make_shared<pattern::op::Label>(quant_op, nullptr, NodeVector{quant_op});
auto constant_quantize_callback = [constant_label, quant](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for constant_quantize_callback against node = "
<< m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map();
auto constant_match = dynamic_pointer_cast<op::Constant>(pattern_map[constant_label]);
auto quant_match = pattern_map[quant];
auto quantize_op = dynamic_pointer_cast<op::Quantize>(quant_match);
auto args = quant_match->get_arguments();
auto scale = static_pointer_cast<op::Constant>(args[1]);
auto offset = static_pointer_cast<op::Constant>(args[2]);
auto type = quant_match->get_element_type();
if (constant_match->get_element_type() != element::f32)
{
return false;
}
if (quantize_op->get_round_mode() != op::Quantize::RoundMode::HALF_AWAY_FROM_ZERO)
{
return false;
}
if (type == element::u8)
{
replace_node(
m.get_match_root(),
make_constant_quantize<float, uint8_t>(constant_match, quantize_op, scale, offset));
return true;
}
else if (type == element::i8)
{
replace_node(
m.get_match_root(),
make_constant_quantize<float, int8_t>(constant_match, quantize_op, scale, offset));
return true;
}
return false;
};
auto quantize_matcher = make_shared<pattern::Matcher>(quant, constant_quantize_callback);
this->add_matcher(quantize_matcher);
}
...@@ -35,7 +35,8 @@ class ngraph::pass::ConstantFolding : public ngraph::pass::GraphRewrite ...@@ -35,7 +35,8 @@ class ngraph::pass::ConstantFolding : public ngraph::pass::GraphRewrite
PAD, PAD,
DEQUANTIZE, DEQUANTIZE,
UNARY, UNARY,
BINARY BINARY,
QUANTIZE
}; };
public: public:
...@@ -47,6 +48,7 @@ public: ...@@ -47,6 +48,7 @@ public:
construct_constant_pad(); construct_constant_pad();
construct_constant_unary(); construct_constant_unary();
construct_constant_binary(); construct_constant_binary();
construct_constant_quantize();
construct_constant_dequantize(); construct_constant_dequantize();
} }
...@@ -65,6 +67,7 @@ public: ...@@ -65,6 +67,7 @@ public:
case CFTransformations::UNARY: construct_constant_unary(); break; case CFTransformations::UNARY: construct_constant_unary(); break;
case CFTransformations::BINARY: construct_constant_binary(); break; case CFTransformations::BINARY: construct_constant_binary(); break;
case CFTransformations::DEQUANTIZE: construct_constant_dequantize(); break; case CFTransformations::DEQUANTIZE: construct_constant_dequantize(); break;
case CFTransformations::QUANTIZE: construct_constant_quantize(); break;
} }
} }
} }
...@@ -75,5 +78,6 @@ private: ...@@ -75,5 +78,6 @@ private:
void construct_constant_pad(); void construct_constant_pad();
void construct_constant_unary(); void construct_constant_unary();
void construct_constant_binary(); void construct_constant_binary();
void construct_constant_quantize();
void construct_constant_dequantize(); void construct_constant_dequantize();
}; };
...@@ -250,3 +250,38 @@ TEST(constant_folding, const_dequantize) ...@@ -250,3 +250,38 @@ TEST(constant_folding, const_dequantize)
vector<output_c_type> values_dequantize{0, 2, 2, 4, 4, 6, 6, 8, 8, 10, 10, 12}; vector<output_c_type> values_dequantize{0, 2, 2, 4, 4, 6, 6, 8, 8, 10, 10, 12};
ASSERT_EQ(values_dequantize, values_out); ASSERT_EQ(values_dequantize, values_out);
} }
TEST(constant_folding, const_quantize)
{
Shape input_shape{12};
Shape scale_offset_shape;
AxisSet quantization_axes;
auto quant_type = element::u8;
auto output_type = element::u8;
typedef uint8_t output_c_type;
vector<float> values_in{1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0, 5.0, 5.0, 6.0, 6.0, 7.0};
auto constant = op::Constant::create(element::f32, input_shape, values_in);
auto scale = op::Constant::create(element::f32, scale_offset_shape, {2});
auto offset = op::Constant::create(quant_type, scale_offset_shape, {1});
auto mode = op::Quantize::RoundMode::HALF_AWAY_FROM_ZERO;
auto quantize =
make_shared<op::Quantize>(constant, scale, offset, output_type, quantization_axes, mode);
auto f = make_shared<Function>(quantize, op::ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Quantize>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto new_const =
std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(new_const);
auto values_out = new_const->get_vector<output_c_type>();
vector<output_c_type> values_quantize{2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5};
ASSERT_EQ(values_quantize, values_out);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment