Commit b8106133 authored by Jayaram Bobba's avatar Jayaram Bobba Committed by Scott Cyphers

More quantized fusion patterns (#2480)

* Add QuantizedConcat

* Remove unused variables and add check for size of mins and maxes vector

* Resolve conflicts

* Merged with master and addressed some PR feedback

* Maxpool and Avgpool fusions. Exclude Q from conv+relu fusion

* Remove single-user check from fusions

* Quantized concat fusion

* workaround: do reshape sinking by default

* style fix

* check scales for QuantizedConcat

* use compare_constants

* remove stale comment

* Handle all concat cases from arg size 2 to 6

* addressed feedback
parent 3863180d
...@@ -21,9 +21,11 @@ ...@@ -21,9 +21,11 @@
#include <unordered_set> #include <unordered_set>
#include "cpu_fusion.hpp" #include "cpu_fusion.hpp"
#include "ngraph/builder/make_constant.hpp"
#include "ngraph/graph_util.hpp" #include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp" #include "ngraph/log.hpp"
#include "ngraph/op/add.hpp" #include "ngraph/op/add.hpp"
#include "ngraph/op/avg_pool.hpp"
#include "ngraph/op/batch_norm.hpp" #include "ngraph/op/batch_norm.hpp"
#include "ngraph/op/broadcast.hpp" #include "ngraph/op/broadcast.hpp"
#include "ngraph/op/broadcast.hpp" #include "ngraph/op/broadcast.hpp"
...@@ -35,9 +37,12 @@ ...@@ -35,9 +37,12 @@
#include "ngraph/op/divide.hpp" #include "ngraph/op/divide.hpp"
#include "ngraph/op/dot.hpp" #include "ngraph/op/dot.hpp"
#include "ngraph/op/exp.hpp" #include "ngraph/op/exp.hpp"
#include "ngraph/op/experimental/quantized_avg_pool.hpp"
#include "ngraph/op/experimental/quantized_concat.hpp"
#include "ngraph/op/experimental/quantized_conv.hpp" #include "ngraph/op/experimental/quantized_conv.hpp"
#include "ngraph/op/experimental/quantized_conv_bias.hpp" #include "ngraph/op/experimental/quantized_conv_bias.hpp"
#include "ngraph/op/experimental/quantized_conv_relu.hpp" #include "ngraph/op/experimental/quantized_conv_relu.hpp"
#include "ngraph/op/experimental/quantized_max_pool.hpp"
#include "ngraph/op/get_output_element.hpp" #include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/max_pool.hpp" #include "ngraph/op/max_pool.hpp"
#include "ngraph/op/maximum.hpp" #include "ngraph/op/maximum.hpp"
...@@ -1870,7 +1875,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_update_slice() ...@@ -1870,7 +1875,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_update_slice()
this->add_matcher(m); this->add_matcher(m);
} }
// QuantizedConvolution + Dequantize + Relu + Quantize -> QuantizedConvolutionRelu // QuantizedConvolution + Dequantize + Relu -> QuantizedConvolutionRelu + Dequantize
void ngraph::runtime::cpu::pass::CPUQuantFusion::construct_qconv_relu(bool with_bias) void ngraph::runtime::cpu::pass::CPUQuantFusion::construct_qconv_relu(bool with_bias)
{ {
Shape shape{2, 2, 1, 1}; Shape shape{2, 2, 1, 1};
...@@ -1879,9 +1884,6 @@ void ngraph::runtime::cpu::pass::CPUQuantFusion::construct_qconv_relu(bool with_ ...@@ -1879,9 +1884,6 @@ void ngraph::runtime::cpu::pass::CPUQuantFusion::construct_qconv_relu(bool with_
auto requantization_scale = std::make_shared<pattern::op::Label>(element::f32, Shape{}); auto requantization_scale = std::make_shared<pattern::op::Label>(element::f32, Shape{});
auto dq_scale = std::make_shared<pattern::op::Label>(element::f32, Shape{}); auto dq_scale = std::make_shared<pattern::op::Label>(element::f32, Shape{});
auto dq_zp = std::make_shared<pattern::op::Label>(element::i8, Shape{}); auto dq_zp = std::make_shared<pattern::op::Label>(element::i8, Shape{});
auto q_scale = std::make_shared<pattern::op::Label>(element::f32, Shape{});
auto q_zp = std::make_shared<pattern::op::Label>(element::u8, Shape{});
op::Quantize::RoundMode round_mode = op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_EVEN;
std::shared_ptr<ngraph::op::Op> qconv; std::shared_ptr<ngraph::op::Op> qconv;
if (with_bias) if (with_bias)
...@@ -1911,31 +1913,16 @@ void ngraph::runtime::cpu::pass::CPUQuantFusion::construct_qconv_relu(bool with_ ...@@ -1911,31 +1913,16 @@ void ngraph::runtime::cpu::pass::CPUQuantFusion::construct_qconv_relu(bool with_
} }
auto dq = std::make_shared<op::Dequantize>(qconv, dq_scale, dq_zp, element::f32, AxisSet{}); auto dq = std::make_shared<op::Dequantize>(qconv, dq_scale, dq_zp, element::f32, AxisSet{});
auto relu = std::make_shared<op::Relu>(dq); auto relu = std::make_shared<op::Relu>(dq);
auto q =
std::make_shared<op::Quantize>(relu, q_scale, q_zp, element::u8, AxisSet{}, round_mode);
pattern::graph_rewrite_callback callback = [with_bias](pattern::Matcher& m) { pattern::graph_rewrite_callback callback = [with_bias](pattern::Matcher& m) {
NGRAPH_DEBUG << "In a callback for construct_qconv_relu against " NGRAPH_DEBUG << "In a callback for construct_qconv_relu against "
<< m.get_match_root()->get_name(); << m.get_match_root()->get_name();
auto q_m = std::static_pointer_cast<op::Quantize>(m.get_match_root()); auto dq_m = std::static_pointer_cast<op::Dequantize>(m.get_match_root()->get_argument(0));
auto dq_m = std::static_pointer_cast<op::Dequantize>(q_m->get_argument(0)->get_argument(0));
if (!(ngraph::is_zero(q_m->get_argument(2)) && ngraph::is_zero(dq_m->get_argument(2))))
{
NGRAPH_DEBUG << "Non-zero zero points";
return false;
}
if (q_m->get_round_mode() != op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_EVEN)
{
NGRAPH_DEBUG << "Unsupported round mode for fused kernel";
return false;
}
if (q_m->get_element_type() != element::u8) if (!(ngraph::is_zero(dq_m->get_argument(2))))
{ {
NGRAPH_DEBUG << "Quantize op produces non uint8 output"; NGRAPH_DEBUG << "Non-zero zero point";
return false; return false;
} }
...@@ -1960,9 +1947,6 @@ void ngraph::runtime::cpu::pass::CPUQuantFusion::construct_qconv_relu(bool with_ ...@@ -1960,9 +1947,6 @@ void ngraph::runtime::cpu::pass::CPUQuantFusion::construct_qconv_relu(bool with_
{ {
auto qconv_m = auto qconv_m =
std::static_pointer_cast<op::QuantizedConvolutionBias>(dq_m->get_argument(0)); std::static_pointer_cast<op::QuantizedConvolutionBias>(dq_m->get_argument(0));
// Rescale to q_m's scales directly
auto requant_scale =
qconv_m->get_argument(3) * dq_m->get_argument(1) / q_m->get_argument(1);
qconv_n = std::make_shared<op::QuantizedConvolutionBias>( qconv_n = std::make_shared<op::QuantizedConvolutionBias>(
qconv_m->get_argument(0), qconv_m->get_argument(0),
qconv_m->get_argument(1), qconv_m->get_argument(1),
...@@ -1972,16 +1956,13 @@ void ngraph::runtime::cpu::pass::CPUQuantFusion::construct_qconv_relu(bool with_ ...@@ -1972,16 +1956,13 @@ void ngraph::runtime::cpu::pass::CPUQuantFusion::construct_qconv_relu(bool with_
qconv_m->get_padding_below(), qconv_m->get_padding_below(),
qconv_m->get_padding_above(), qconv_m->get_padding_above(),
qconv_m->get_data_dilation_strides(), qconv_m->get_data_dilation_strides(),
requant_scale, qconv_m->get_argument(3),
true); true);
} }
else else
{ {
auto qconv_m = auto qconv_m =
std::static_pointer_cast<op::QuantizedConvolution>(dq_m->get_argument(0)); std::static_pointer_cast<op::QuantizedConvolution>(dq_m->get_argument(0));
// Rescale to q_m's scales directly
auto requant_scale =
qconv_m->get_argument(2) * dq_m->get_argument(1) / q_m->get_argument(1);
qconv_n = std::make_shared<op::QuantizedConvolutionRelu>( qconv_n = std::make_shared<op::QuantizedConvolutionRelu>(
qconv_m->get_argument(0), qconv_m->get_argument(0),
qconv_m->get_argument(1), qconv_m->get_argument(1),
...@@ -1990,24 +1971,158 @@ void ngraph::runtime::cpu::pass::CPUQuantFusion::construct_qconv_relu(bool with_ ...@@ -1990,24 +1971,158 @@ void ngraph::runtime::cpu::pass::CPUQuantFusion::construct_qconv_relu(bool with_
qconv_m->get_padding_below(), qconv_m->get_padding_below(),
qconv_m->get_padding_above(), qconv_m->get_padding_above(),
qconv_m->get_data_dilation_strides(), qconv_m->get_data_dilation_strides(),
requant_scale); qconv_m->get_argument(2));
} }
ngraph::replace_node(m.get_match_root(), qconv_n); auto zp =
builder::make_constant<uint8_t>(element::u8, dq_m->get_argument(1)->get_shape(), 0);
auto dq_n = std::make_shared<op::Dequantize>(
qconv_n, dq_m->get_argument(1), zp, dq_m->get_output_element_type(0), dq_m->get_axes());
ngraph::replace_node(m.get_match_root(), dq_n);
return true; return true;
}; };
std::shared_ptr<pattern::Matcher> m; std::shared_ptr<pattern::Matcher> m;
if (with_bias) if (with_bias)
{ {
m = std::make_shared<pattern::Matcher>(q, callback, "CPUQuantFusion.QConvBiasRelu"); m = std::make_shared<pattern::Matcher>(relu, callback, "CPUQuantFusion.QConvBiasRelu");
} }
else else
{ {
m = std::make_shared<pattern::Matcher>(q, callback, "CPUQuantFusion.QConvRelu"); m = std::make_shared<pattern::Matcher>(relu, callback, "CPUQuantFusion.QConvRelu");
} }
this->add_matcher(m); this->add_matcher(m);
} }
// Dequantize + AvgPool -> QuantizedAvgPool + Dequantize
void ngraph::runtime::cpu::pass::CPUQuantFusion::construct_qavg_pool()
{
Shape shape{2, 2, 1, 1};
auto input = std::make_shared<pattern::op::Label>(element::i8, shape);
auto dq_scale = std::make_shared<pattern::op::Label>(element::f32, Shape{});
auto dq_zp = std::make_shared<pattern::op::Label>(element::i8, Shape{});
auto dq = std::make_shared<op::Dequantize>(input, dq_scale, dq_zp, element::f32, AxisSet{});
auto avg_pool = std::make_shared<op::AvgPool>(dq, Shape{1, 1});
pattern::graph_rewrite_callback callback = [](pattern::Matcher& m) {
NGRAPH_DEBUG << "In a callback for construct_qavg_pool against "
<< m.get_match_root()->get_name();
auto avg_pool_m = std::static_pointer_cast<op::AvgPool>(m.get_match_root());
auto dq_m = std::static_pointer_cast<op::Dequantize>(avg_pool_m->get_argument(0));
auto qavg_pool_n = std::make_shared<op::QuantizedAvgPool>(
dq_m->get_argument(0),
avg_pool_m->get_window_shape(),
avg_pool_m->get_window_movement_strides(),
avg_pool_m->get_padding_below(),
avg_pool_m->get_padding_above(),
avg_pool_m->get_include_padding_in_avg_computation());
auto dq_n = std::make_shared<op::Dequantize>(qavg_pool_n,
dq_m->get_argument(1),
dq_m->get_argument(2),
dq_m->get_output_element_type(0),
dq_m->get_axes());
ngraph::replace_node(m.get_match_root(), dq_n);
return true;
};
this->add_matcher(
std::make_shared<pattern::Matcher>(avg_pool, callback, "CPUQuantFusion.QAvgPool"));
}
// Dequantize + Maxpool -> QuantizedMaxpool + Dequantize
void ngraph::runtime::cpu::pass::CPUQuantFusion::construct_qmax_pool()
{
Shape shape{2, 2, 1, 1};
auto input = std::make_shared<pattern::op::Label>(element::i8, shape);
auto dq_scale = std::make_shared<pattern::op::Label>(element::f32, Shape{});
auto dq_zp = std::make_shared<pattern::op::Label>(element::i8, Shape{});
auto dq = std::make_shared<op::Dequantize>(input, dq_scale, dq_zp, element::f32, AxisSet{});
auto max_pool = std::make_shared<op::MaxPool>(dq, Shape{1, 1});
pattern::graph_rewrite_callback callback = [](pattern::Matcher& m) {
NGRAPH_DEBUG << "In a callback for construct_qmax_pool against "
<< m.get_match_root()->get_name();
auto max_pool_m = std::static_pointer_cast<op::MaxPool>(m.get_match_root());
auto dq_m = std::static_pointer_cast<op::Dequantize>(max_pool_m->get_argument(0));
auto qmax_pool_n =
std::make_shared<op::QuantizedMaxPool>(dq_m->get_argument(0),
max_pool_m->get_window_shape(),
max_pool_m->get_window_movement_strides(),
max_pool_m->get_padding_below(),
max_pool_m->get_padding_above());
auto dq_n = std::make_shared<op::Dequantize>(qmax_pool_n,
dq_m->get_argument(1),
dq_m->get_argument(2),
dq_m->get_output_element_type(0),
dq_m->get_axes());
ngraph::replace_node(m.get_match_root(), dq_n);
return true;
};
this->add_matcher(
std::make_shared<pattern::Matcher>(max_pool, callback, "CPUQuantFusion.QMaxPool"));
}
// {Dequantize}* + Concat -> QuantizedConcat + Dequantize
void ngraph::runtime::cpu::pass::CPUQuantFusion::construct_qconcat()
{
Shape shape{2, 2, 1, 1};
NodeVector inputs;
NodeVector concats;
// Pattern matcher looks for concats with exact number of inputs
inputs.push_back(std::make_shared<pattern::op::Label>(element::f32, shape));
// Concat2, Concat3, ... Concat6
for (size_t i = 0; i < 5; i++)
{
inputs.push_back(std::make_shared<pattern::op::Label>(element::f32, shape));
concats.push_back(std::make_shared<op::Concat>(inputs, 0));
}
pattern::graph_rewrite_callback callback = [](pattern::Matcher& m) {
NGRAPH_DEBUG << "In a callback for construct_qconcat against "
<< m.get_match_root()->get_name();
auto concat_m = std::static_pointer_cast<op::Concat>(m.get_match_root());
auto dq_m = std::static_pointer_cast<op::Dequantize>(concat_m->get_argument(0));
NodeVector new_args;
for (auto arg : concat_m->get_arguments())
{
if (arg->description() != "Dequantize")
{
return false;
}
// ensure dequant scales are same
if (!ngraph::compare_constants(arg->get_argument(1), dq_m->get_argument(1)))
{
NGRAPH_DEBUG << "QuantizedConcat: Dequantize scale must be same";
return false;
}
new_args.push_back(arg->get_argument(0));
}
auto concat_n =
std::make_shared<op::QuantizedConcat>(new_args, concat_m->get_concatenation_axis());
auto dq_n = std::make_shared<op::Dequantize>(concat_n,
dq_m->get_argument(1),
dq_m->get_argument(2),
dq_m->get_element_type(),
dq_m->get_axes());
ngraph::replace_node(m.get_match_root(), dq_n);
return true;
};
for (size_t i = 0; i < 5; i++)
{
this->add_matcher(std::make_shared<pattern::Matcher>(
concats[i], callback, "CPUQuantFusion.QConcat" + std::to_string(i + 2)));
}
}
void ngraph::runtime::cpu::pass::CPUQuantFusion::construct_dq_q() void ngraph::runtime::cpu::pass::CPUQuantFusion::construct_dq_q()
{ {
Shape shape{2, 2, 1, 1}; Shape shape{2, 2, 1, 1};
...@@ -2103,7 +2218,6 @@ void ngraph::runtime::cpu::pass::CPUQuantFusion::construct_qconvb_add() ...@@ -2103,7 +2218,6 @@ void ngraph::runtime::cpu::pass::CPUQuantFusion::construct_qconvb_add()
//Add left + right //Add left + right
auto add = skipb_l + skipb_r; auto add = skipb_l + skipb_r;
;
auto prelu = std::make_shared<op::Relu>(add); auto prelu = std::make_shared<op::Relu>(add);
pattern::graph_rewrite_callback callback = [dq_l_label, dq_r_label](pattern::Matcher& m) { pattern::graph_rewrite_callback callback = [dq_l_label, dq_r_label](pattern::Matcher& m) {
......
...@@ -111,12 +111,18 @@ public: ...@@ -111,12 +111,18 @@ public:
{ {
construct_qconv_relu(true); construct_qconv_relu(true);
construct_qconv_relu(false); construct_qconv_relu(false);
construct_qavg_pool();
construct_qmax_pool();
construct_qconcat();
construct_qconvb_add(); construct_qconvb_add();
construct_dq_q(); construct_dq_q();
} }
private: private:
void construct_qconv_relu(bool with_bias); void construct_qconv_relu(bool with_bias);
void construct_qavg_pool();
void construct_qmax_pool();
void construct_qconcat();
void construct_dq_q(); void construct_dq_q();
void construct_qconvb_add(); void construct_qconvb_add();
}; };
...@@ -30,6 +30,7 @@ ...@@ -30,6 +30,7 @@
#include "ngraph/op/batch_norm.hpp" #include "ngraph/op/batch_norm.hpp"
#include "ngraph/op/concat.hpp" #include "ngraph/op/concat.hpp"
#include "ngraph/op/dequantize.hpp" #include "ngraph/op/dequantize.hpp"
#include "ngraph/op/experimental/quantized_concat.hpp"
#include "ngraph/op/experimental/quantized_conv.hpp" #include "ngraph/op/experimental/quantized_conv.hpp"
#include "ngraph/op/experimental/quantized_conv_bias.hpp" #include "ngraph/op/experimental/quantized_conv_bias.hpp"
#include "ngraph/op/get_output_element.hpp" #include "ngraph/op/get_output_element.hpp"
...@@ -3604,6 +3605,134 @@ TEST(cpu_quant_fusion, qconvb_relu) ...@@ -3604,6 +3605,134 @@ TEST(cpu_quant_fusion, qconvb_relu)
EXPECT_TRUE(test::all_close(cpu1_results.at(0), cpu2_results.at(0))); EXPECT_TRUE(test::all_close(cpu1_results.at(0), cpu2_results.at(0)));
} }
TEST(cpu_quant_fusion, qavg_pool)
{
auto make_function = []() {
Shape shape_input{1, 2, 4, 4};
auto input = std::make_shared<op::Parameter>(element::f32, shape_input);
auto input_scale = op::Constant::create(element::f32, Shape{}, {2.0f});
auto weights_scale = op::Constant::create(element::f32, Shape{}, {2.0f});
auto int8_zero = op::Constant::create(element::i8, Shape{}, {0});
auto uint8_zero = op::Constant::create(element::u8, Shape{}, {0});
op::Quantize::RoundMode round_mode = op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_EVEN;
auto q_input = std::make_shared<op::Quantize>(
input, input_scale, uint8_zero, element::u8, AxisSet{}, round_mode);
auto dq = std::make_shared<op::Dequantize>(
q_input, input_scale, uint8_zero, element::f32, AxisSet{});
auto avg_pool = std::make_shared<op::AvgPool>(dq, Shape{2, 2});
return make_shared<Function>(NodeVector{avg_pool}, ParameterVector{input});
};
auto cpu_f1 = make_function();
auto cpu_f2 = make_function();
test::Uniform<float> rng(4.0f, 4.0f);
vector<vector<float>> args;
for (shared_ptr<op::Parameter> param : cpu_f1->get_parameters())
{
vector<float> tensor_val(shape_size(param->get_shape()));
rng.initialize(tensor_val);
args.push_back(tensor_val);
}
set_environment("NGRAPH_PASS_ENABLES", "CPUQuantFusion:0", 1);
auto cpu1_results = execute(cpu_f1, args, "CPU");
set_environment("NGRAPH_PASS_ENABLES", "CPUQuantFusion:1", 1);
auto cpu2_results = execute(cpu_f2, args, "CPU");
EXPECT_TRUE(test::all_close(cpu1_results.at(0), cpu2_results.at(0)));
}
TEST(cpu_quant_fusion, qmax_pool)
{
auto make_function = []() {
Shape shape_input{1, 2, 4, 4};
auto input = std::make_shared<op::Parameter>(element::f32, shape_input);
auto input_scale = op::Constant::create(element::f32, Shape{}, {2.0f});
auto weights_scale = op::Constant::create(element::f32, Shape{}, {2.0f});
auto int8_zero = op::Constant::create(element::i8, Shape{}, {0});
auto uint8_zero = op::Constant::create(element::u8, Shape{}, {0});
op::Quantize::RoundMode round_mode = op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_EVEN;
auto q_input = std::make_shared<op::Quantize>(
input, input_scale, uint8_zero, element::u8, AxisSet{}, round_mode);
auto dq = std::make_shared<op::Dequantize>(
q_input, input_scale, uint8_zero, element::f32, AxisSet{});
auto maxpool = std::make_shared<op::MaxPool>(dq, Shape{2, 2});
return make_shared<Function>(NodeVector{maxpool}, ParameterVector{input});
};
auto cpu_f1 = make_function();
auto cpu_f2 = make_function();
test::Uniform<float> rng(1.0f, 10.0f);
vector<vector<float>> args;
for (shared_ptr<op::Parameter> param : cpu_f1->get_parameters())
{
vector<float> tensor_val(shape_size(param->get_shape()));
rng.initialize(tensor_val);
args.push_back(tensor_val);
}
set_environment("NGRAPH_PASS_ENABLES", "CPUQuantFusion:0", 1);
auto cpu1_results = execute(cpu_f1, args, "CPU");
set_environment("NGRAPH_PASS_ENABLES", "CPUQuantFusion:1", 1);
auto cpu2_results = execute(cpu_f2, args, "CPU");
EXPECT_TRUE(test::all_close(cpu1_results.at(0), cpu2_results.at(0)));
}
TEST(cpu_quant_fusion, qconcat)
{
auto make_function = []() {
auto get_input_slice = [](std::shared_ptr<op::Parameter>& input) {
auto input_scale = op::Constant::create(element::f32, Shape{}, {2.0f});
auto int8_zero = op::Constant::create(element::i8, Shape{}, {0});
auto uint8_zero = op::Constant::create(element::u8, Shape{}, {0});
op::Quantize::RoundMode round_mode = op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_EVEN;
auto q_input = std::make_shared<op::Quantize>(
input, input_scale, uint8_zero, element::u8, AxisSet{}, round_mode);
auto dq = std::make_shared<op::Dequantize>(
q_input, input_scale, uint8_zero, element::f32, AxisSet{});
return dq;
};
NodeVector concat_inputs, concats;
ParameterVector inputs;
Shape shape_input{1, 2, 4, 4};
inputs.push_back(std::make_shared<op::Parameter>(element::f32, shape_input));
concat_inputs.push_back(get_input_slice(inputs.back()));
// Concat2 -- Concat7
for (size_t i = 0; i < 6; i++)
{
inputs.push_back(std::make_shared<op::Parameter>(element::f32, shape_input));
concat_inputs.push_back(get_input_slice(inputs.back()));
concats.push_back(std::make_shared<op::Concat>(concat_inputs, 0));
}
return make_shared<Function>(concats, inputs);
};
auto cpu_f1 = make_function();
auto cpu_f2 = make_function();
test::Uniform<float> rng(2.0f, 2.0f);
vector<vector<float>> args;
for (shared_ptr<op::Parameter> param : cpu_f1->get_parameters())
{
vector<float> tensor_val(shape_size(param->get_shape()));
rng.initialize(tensor_val);
args.push_back(tensor_val);
}
set_environment("NGRAPH_PASS_ENABLES", "CPUQuantFusion:0", 1);
auto cpu1_results = execute(cpu_f1, args, "CPU");
set_environment("NGRAPH_PASS_ENABLES", "CPUQuantFusion:1", 1);
auto cpu2_results = execute(cpu_f2, args, "CPU");
// Expect Concat2 -- Concat6 to be fused and not Concat7
ASSERT_EQ(count_ops_of_type<op::QuantizedConcat>(cpu_f2), 5);
EXPECT_TRUE(test::all_close(cpu1_results.at(0), cpu2_results.at(0)));
}
TEST(cpu_quant_fusion, dq_q) TEST(cpu_quant_fusion, dq_q)
{ {
auto make_function = [](bool match_scales = true, bool match_et = true) { auto make_function = [](bool match_scales = true, bool match_et = true) {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment