Unverified Commit 02559274 authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Merge branch 'master' into bob/backend_api3

parents d607c6f4 13b4966b
......@@ -325,6 +325,12 @@ namespace ngraph
max_freezed_output_conv_1,
min_freezed_output_conv_2,
max_freezed_output_conv_2);
if (output_et == element::u8)
{
// Need to multiply by two to account for u8 requantization_scale
auto two = make_constant(element::f32, sum_scale->get_shape(), 2.0f);
sum_scale = two * sum_scale;
}
if (bias->get_element_type() != element::i32)
{
......
......@@ -548,3 +548,19 @@ bool ngraph::is_valid_rank(const std::shared_ptr<Node>& node, std::vector<size_t
}
return false;
}
bool ngraph::compare_constants(const std::shared_ptr<Node>& n1, const std::shared_ptr<Node>& n2)
{
if (!(n1->is_constant() && n2->is_constant()))
{
return false;
}
if (static_pointer_cast<op::Constant>(n1)->get_value_strings() !=
static_pointer_cast<op::Constant>(n2)->get_value_strings())
{
return false;
}
return true;
}
......@@ -309,6 +309,8 @@ namespace ngraph
bool is_one(std::shared_ptr<Node> reduce_constant);
bool compare_constants(const std::shared_ptr<Node>& n1, const std::shared_ptr<Node>& n2);
// Returns true if `node` is live in the graph i.e. a result op
// transitively uses this `node`
bool is_used(Node* node);
......
......@@ -334,7 +334,7 @@ namespace ngraph
}
if (old_pops.kind(i) == mkldnn::primitive::kind::sum)
{
new_pops.append_sum(2 * dyn_post_op_scales[0]);
new_pops.append_sum(dyn_post_op_scales[0]);
}
}
conv_attr.set_post_ops(new_pops);
......
......@@ -1102,6 +1102,7 @@ void runtime::cpu::CPU_ExternalFunction::register_common_passes(ngraph::pass::Ma
REGISTER_KNOBBED_PASS(ReshapeElimination, false, ngraph::pass);
REGISTER_KNOBBED_PASS(CoreFusion, true, ngraph::pass);
REGISTER_KNOBBED_PASS(CPUFusion, true, runtime::cpu::pass);
REGISTER_KNOBBED_PASS(CPUQuantFusion, true, runtime::cpu::pass);
REGISTER_KNOBBED_PASS(CPUHorizontalFusion, true, runtime::cpu::pass);
REGISTER_KNOBBED_PASS(CPUCollapseDims, true, runtime::cpu::pass);
#if defined(NGRAPH_HALIDE)
......
......@@ -186,21 +186,14 @@ namespace ngraph
ops.append_sum(1.f);
}
if (std::is_same<OP, ngraph::op::QuantizedConvolutionBiasAdd>())
if (std::is_same<OP, ngraph::op::QuantizedConvolutionBiasAdd>() ||
std::is_same<OP, ngraph::op::QuantizedConvolutionBiasSignedAdd>())
{
auto sum_scale_val =
extract_scale_value<ngraph::op::QuantizedConvolutionBiasAdd>(node, 5);
ops.append_sum(sum_scale_val[0]);
}
if (std::is_same<OP, ngraph::op::QuantizedConvolutionBiasSignedAdd>())
{
auto sum_scale_val =
extract_scale_value<ngraph::op::QuantizedConvolutionBiasSignedAdd>(node,
5);
ops.append_sum(2.0 * sum_scale_val[0]);
}
if (has_relu<OP>(node))
{
const float ops_scale = 1.f;
......@@ -740,21 +733,14 @@ namespace ngraph
ops.append_sum(1.f);
}
if (std::is_same<OP, ngraph::op::QuantizedConvolutionBiasAdd>())
if (std::is_same<OP, ngraph::op::QuantizedConvolutionBiasAdd>() ||
std::is_same<OP, ngraph::op::QuantizedConvolutionBiasSignedAdd>())
{
auto sum_scale_val =
extract_scale_value<ngraph::op::QuantizedConvolutionBiasAdd>(node, 5);
ops.append_sum(sum_scale_val[0]);
}
if (std::is_same<OP, ngraph::op::QuantizedConvolutionBiasSignedAdd>())
{
auto sum_scale_val =
extract_scale_value<ngraph::op::QuantizedConvolutionBiasSignedAdd>(node,
5);
ops.append_sum(2.0 * sum_scale_val[0]);
}
if (has_relu<OP>(node))
{
const float ops_scale = 1.f;
......
......@@ -99,11 +99,27 @@ namespace ngraph
{
return false;
}
if (node->get_input_element_type(0) != element::f32)
// Data
if (node->get_input_element_type(0) != element::f32 &&
node->get_input_element_type(0) != element::i8 &&
node->get_input_element_type(0) != element::u8)
{
return false;
}
// Weights
if (node->get_input_element_type(1) != element::f32 &&
node->get_input_element_type(1) != element::i8)
{
return false;
}
// Outputs
if (node->get_output_element_type(0) != element::f32 &&
node->get_output_element_type(0) != element::i8 &&
node->get_output_element_type(0) != element::u8 &&
node->get_output_element_type(0) != element::i32)
{
return false;
}
return true;
}
}
......
......@@ -29,10 +29,15 @@
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/convert.hpp"
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/dequantize.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/exp.hpp"
#include "ngraph/op/experimental/quantized_conv.hpp"
#include "ngraph/op/experimental/quantized_conv_bias.hpp"
#include "ngraph/op/experimental/quantized_conv_relu.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/max_pool.hpp"
#include "ngraph/op/maximum.hpp"
......@@ -41,6 +46,7 @@
#include "ngraph/op/negative.hpp"
#include "ngraph/op/pad.hpp"
#include "ngraph/op/parameter.hpp"
#include "ngraph/op/quantize.hpp"
#include "ngraph/op/relu.hpp"
#include "ngraph/op/replace_slice.hpp"
#include "ngraph/op/reshape.hpp"
......@@ -1856,3 +1862,345 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_update_slice()
replace_slice, callback, "CPUFusion.UpdateSlice");
this->add_matcher(m);
}
// Pattern : QuantizedConvolution[Bias] -> Dequantize -> Relu -> Quantize
// Rewrite : QuantizedConvolutionRelu when `with_bias` is false, otherwise a new
//           QuantizedConvolutionBias built with its final boolean argument set
//           to true (presumably the "with relu" flag -- confirm against the op).
//
// with_bias: selects which convolution flavour the registered matcher targets.
void ngraph::runtime::cpu::pass::CPUQuantFusion::construct_qconv_relu(bool with_bias)
{
    // Labels that stand in for the inputs of the pattern graph.
    Shape shape{2, 2, 1, 1};
    auto data_batch = std::make_shared<pattern::op::Label>(element::u8, shape);
    auto filters = std::make_shared<pattern::op::Label>(element::i8, shape);
    auto requantization_scale = std::make_shared<pattern::op::Label>(element::f32, Shape{});
    auto dq_scale = std::make_shared<pattern::op::Label>(element::f32, Shape{});
    auto dq_zp = std::make_shared<pattern::op::Label>(element::i8, Shape{});
    auto q_scale = std::make_shared<pattern::op::Label>(element::f32, Shape{});
    auto q_zp = std::make_shared<pattern::op::Label>(element::u8, Shape{});
    op::Quantize::RoundMode round_mode = op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_EVEN;

    // Convolution at the bottom of the pattern, with or without a bias input.
    std::shared_ptr<ngraph::op::Op> qconv;
    if (!with_bias)
    {
        qconv = std::make_shared<op::QuantizedConvolution>(data_batch,
                                                           filters,
                                                           Strides{1, 1},
                                                           Strides{1, 1},
                                                           CoordinateDiff{0, 0},
                                                           CoordinateDiff{0, 0},
                                                           Strides{1, 1},
                                                           requantization_scale);
    }
    else
    {
        auto bias = std::make_shared<pattern::op::Label>(element::i32, Shape{shape[0]});
        qconv = std::make_shared<op::QuantizedConvolutionBias>(data_batch,
                                                               filters,
                                                               bias,
                                                               Strides{1, 1},
                                                               Strides{1, 1},
                                                               CoordinateDiff{0, 0},
                                                               CoordinateDiff{0, 0},
                                                               Strides{1, 1},
                                                               requantization_scale,
                                                               false);
    }

    // Remainder of the pattern: Dequantize -> Relu -> Quantize (the match root).
    auto dq = std::make_shared<op::Dequantize>(qconv, dq_scale, dq_zp, element::f32, AxisSet{});
    auto relu = std::make_shared<op::Relu>(dq);
    auto q =
        std::make_shared<op::Quantize>(relu, q_scale, q_zp, element::u8, AxisSet{}, round_mode);

    pattern::graph_rewrite_callback callback = [with_bias](pattern::Matcher& m) {
        NGRAPH_DEBUG << "In a callback for construct_qconv_relu against "
                     << m.get_match_root()->get_name();

        // Walk the matched chain from the root: Quantize <- Relu <- Dequantize <- conv.
        auto quantize = std::static_pointer_cast<op::Quantize>(m.get_match_root());
        auto dequantize =
            std::static_pointer_cast<op::Dequantize>(quantize->get_argument(0)->get_argument(0));
        auto conv_node = dequantize->get_argument(0);

        // Both zero points (argument 2 of each op) must be zero.
        const bool zero_points_are_zero = ngraph::is_zero(quantize->get_argument(2)) &&
                                          ngraph::is_zero(dequantize->get_argument(2));
        if (!zero_points_are_zero)
        {
            NGRAPH_DEBUG << "Non-zero zero points";
            return false;
        }

        if (quantize->get_round_mode() != op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_EVEN)
        {
            NGRAPH_DEBUG << "Unsupported round mode for fused kernel";
            return false;
        }

        if (quantize->get_element_type() != element::u8)
        {
            NGRAPH_DEBUG << "Quantize op produces non uint8 output";
            return false;
        }

        // The convolution is replaced, so it must feed only the matched Dequantize.
        if (conv_node->get_users().size() > 1)
        {
            NGRAPH_DEBUG << "QuantizedConvolutionBias has more than one user";
            return false;
        }

        // Only the bias-free flavour is checked against the MKLDNN helper here.
        if (!with_bias &&
            !runtime::cpu::mkldnn_utils::can_use_mkldnn_conv<op::QuantizedConvolution>(
                conv_node.get()))
        {
            NGRAPH_DEBUG << "Quantized Convolution not supported by MKLDNN";
            return false;
        }

        std::shared_ptr<ngraph::op::Op> fused_conv;
        if (!with_bias)
        {
            auto conv_plain = std::static_pointer_cast<op::QuantizedConvolution>(conv_node);
            // New scale = old requantization scale * dequantize scale / quantize scale,
            // so the fused op writes values directly in the Quantize op's scale.
            auto requant_scale = conv_plain->get_argument(2) * dequantize->get_argument(1) /
                                 quantize->get_argument(1);
            fused_conv = std::make_shared<op::QuantizedConvolutionRelu>(
                conv_plain->get_argument(0),
                conv_plain->get_argument(1),
                conv_plain->get_window_movement_strides(),
                conv_plain->get_window_dilation_strides(),
                conv_plain->get_padding_below(),
                conv_plain->get_padding_above(),
                conv_plain->get_data_dilation_strides(),
                requant_scale);
        }
        else
        {
            auto conv_bias = std::static_pointer_cast<op::QuantizedConvolutionBias>(conv_node);
            // Same rescaling as above; with a bias input the scale is argument 3.
            auto requant_scale = conv_bias->get_argument(3) * dequantize->get_argument(1) /
                                 quantize->get_argument(1);
            fused_conv = std::make_shared<op::QuantizedConvolutionBias>(
                conv_bias->get_argument(0),
                conv_bias->get_argument(1),
                conv_bias->get_argument(2),
                conv_bias->get_window_movement_strides(),
                conv_bias->get_window_dilation_strides(),
                conv_bias->get_padding_below(),
                conv_bias->get_padding_above(),
                conv_bias->get_data_dilation_strides(),
                requant_scale,
                true);
        }

        ngraph::replace_node(m.get_match_root(), fused_conv);
        return true;
    };

    // One matcher per flavour, distinguished by name.
    const std::string matcher_name =
        with_bias ? "CPUQuantFusion.QConvBiasRelu" : "CPUQuantFusion.QConvRelu";
    auto m = std::make_shared<pattern::Matcher>(q, callback, matcher_name);
    this->add_matcher(m);
}
// Pattern : input -> Dequantize -> Quantize
// Rewrite : the Quantize is replaced by `input` itself, provided that both zero
//           points are zero, the Quantize output type equals the input type, and
//           the two scales are constants that compare equal.
void ngraph::runtime::cpu::pass::CPUQuantFusion::construct_dq_q()
{
    // Labels that stand in for the inputs of the pattern graph.
    Shape shape{2, 2, 1, 1};
    auto input = std::make_shared<pattern::op::Label>(element::i8, shape);
    auto dq_scale = std::make_shared<pattern::op::Label>(element::f32, Shape{});
    auto dq_zp = std::make_shared<pattern::op::Label>(element::i8, Shape{});
    auto q_scale = std::make_shared<pattern::op::Label>(element::f32, Shape{});
    auto q_zp = std::make_shared<pattern::op::Label>(element::i8, Shape{});

    auto dq = std::make_shared<op::Dequantize>(input, dq_scale, dq_zp, element::f32, AxisSet{});
    op::Quantize::RoundMode round_mode = op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_EVEN;
    auto q = std::make_shared<op::Quantize>(dq, q_scale, q_zp, element::i8, AxisSet{}, round_mode);

    pattern::graph_rewrite_callback callback = [input](pattern::Matcher& m) {
        NGRAPH_DEBUG << "In a callback for construct_dq_q against "
                     << m.get_match_root()->get_name();

        auto quantize = std::static_pointer_cast<op::Quantize>(m.get_match_root());
        auto dequantize = std::static_pointer_cast<op::Dequantize>(quantize->get_argument(0));

        // Both zero points (argument 2 of each op) must be zero.
        const bool zero_points_are_zero = ngraph::is_zero(quantize->get_argument(2)) &&
                                          ngraph::is_zero(dequantize->get_argument(2));
        if (!zero_points_are_zero)
        {
            NGRAPH_DEBUG << "Non-zero zero points";
            return false;
        }

        // Node bound to the `input` label; it takes the place of the Quantize.
        auto bound_input = m.get_pattern_map()[input];
        if (m.get_match_root()->get_element_type() != bound_input->get_element_type())
        {
            NGRAPH_DEBUG << "Type mismatch between input and quantize output";
            return false;
        }

        // Argument 1 of each op is its scale.
        const bool scales_match =
            ngraph::compare_constants(quantize->get_argument(1), dequantize->get_argument(1));
        if (!scales_match)
        {
            NGRAPH_DEBUG << "Scales dont match";
            return false;
        }

        ngraph::replace_node(m.get_match_root(), bound_input);
        return true;
    };

    this->add_matcher(std::make_shared<pattern::Matcher>(q, callback, "CPUQuantFusion.DQandQ"));
}
// Left Branch(LB): QCONVB + DQ + {Reshape/Broadcast}
// Right Branch(RB): DQ + {Reshape/Broadcast}
// Relu(LB + RB) -> QCB{S}A
//
// The matched Relu is replaced by Dequantize(fused convolution), where the fused
// convolution is QuantizedConvolutionBiasAdd, or QuantizedConvolutionBiasSignedAdd
// followed by a Convert to u8 when the summand is i8. The right-branch input (the
// summand) is handed to the fused op as its in-place add input.
void ngraph::runtime::cpu::pass::CPUQuantFusion::construct_qconvb_add()
{
    // Labels that stand in for the inputs of the pattern graph.
    Shape shape{2, 2, 1, 1};
    auto data_batch = std::make_shared<pattern::op::Label>(element::u8, shape);
    auto filters = std::make_shared<pattern::op::Label>(element::i8, shape);
    auto bias = std::make_shared<pattern::op::Label>(element::i32, Shape{shape[1]});
    auto requantization_scale = std::make_shared<pattern::op::Label>(element::f32, Shape{});
    auto dq_scale1 = std::make_shared<pattern::op::Label>(element::f32, Shape{});
    auto dq_zp1 = std::make_shared<pattern::op::Label>(element::i8, Shape{});
    auto dq_scale2 = std::make_shared<pattern::op::Label>(element::f32, Shape{});
    auto dq_zp2 = std::make_shared<pattern::op::Label>(element::i8, Shape{});

    // Left Graph: QuantizedConvolutionBias -> Dequantize -> [Reshape] -> [Broadcast]
    auto qconvb = std::make_shared<op::QuantizedConvolutionBias>(data_batch,
                                                                 filters,
                                                                 bias,
                                                                 Strides{1, 1},
                                                                 Strides{1, 1},
                                                                 CoordinateDiff{0, 0},
                                                                 CoordinateDiff{0, 0},
                                                                 Strides{1, 1},
                                                                 requantization_scale,
                                                                 false);
    auto qconvb_label = std::make_shared<pattern::op::Label>(qconvb, nullptr, NodeVector{qconvb});
    auto dq_l =
        std::make_shared<op::Dequantize>(qconvb_label, dq_scale1, dq_zp1, element::f32, AxisSet{});
    // The label lets the callback retrieve the matched Dequantize from the pattern map.
    auto dq_l_label = std::make_shared<pattern::op::Label>(dq_l, nullptr, NodeVector{dq_l});
    auto skipr_l = std::make_shared<pattern::op::Skip>(
        dq_l_label, [](std::shared_ptr<Node> n) { return n->description() == "Reshape"; });
    auto skipb_l = std::make_shared<pattern::op::Skip>(
        skipr_l, [](std::shared_ptr<Node> n) { return n->description() == "Broadcast"; });

    // Right Graph: summand -> Dequantize -> [Reshape] -> [Broadcast]
    auto summand = std::make_shared<pattern::op::Label>(element::i8, qconvb->get_shape());
    auto dq_r =
        std::make_shared<op::Dequantize>(summand, dq_scale2, dq_zp2, element::f32, AxisSet{});
    auto dq_r_label = std::make_shared<pattern::op::Label>(dq_r, nullptr, NodeVector{dq_r});
    auto skipr_r = std::make_shared<pattern::op::Skip>(
        dq_r_label, [](std::shared_ptr<Node> n) { return n->description() == "Reshape"; });
    auto skipb_r = std::make_shared<pattern::op::Skip>(
        skipr_r, [](std::shared_ptr<Node> n) { return n->description() == "Broadcast"; });

    // Add left + right, then Relu (the match root).
    auto add = skipb_l + skipb_r;
    auto prelu = std::make_shared<op::Relu>(add);

    pattern::graph_rewrite_callback callback = [dq_l_label, dq_r_label](pattern::Matcher& m) {
        NGRAPH_DEBUG << "In a callback for construct_qconvb_dq_add_relu against "
                     << m.get_match_root()->get_name();

        auto pattern_map = m.get_pattern_map();
        auto add_m = std::dynamic_pointer_cast<op::Add>(m.get_match_root()->get_argument(0));
        auto dq_l_m = std::dynamic_pointer_cast<op::Dequantize>(pattern_map[dq_l_label]);
        auto dq_r_m = std::dynamic_pointer_cast<op::Dequantize>(pattern_map[dq_r_label]);

        // A failed dynamic cast yields null; bail out instead of dereferencing it.
        if (!add_m || !dq_l_m || !dq_r_m)
        {
            NGRAPH_DEBUG << "Matched nodes do not have the expected op types";
            return false;
        }

        auto qconv =
            std::static_pointer_cast<op::QuantizedConvolutionBias>(dq_l_m->get_argument(0));
        auto inplace_input = dq_r_m->get_argument(0);

        // Both zero points (argument 2 of each Dequantize) must be zero.
        if (!(ngraph::is_zero(dq_l_m->get_argument(2)) && ngraph::is_zero(dq_r_m->get_argument(2))))
        {
            NGRAPH_DEBUG << "Non-zero zero points";
            return false;
        }

        // Element type of the summand, i.e. input 0 of the right-hand Dequantize.
        // The same value drives both the guard below and the signed/unsigned
        // choice further down, so the two can never disagree.
        const auto& summand_et = dq_r_m->get_input_element_type(0);
        if (summand_et != element::i8 && summand_et != element::u8)
        {
            NGRAPH_DEBUG << "Non int8/uint8 summand";
            return false;
        }

        if (get_user_count(qconv.get()) > 1)
        {
            NGRAPH_DEBUG << "QuantizedConvolutionBias has more than one user";
            return false;
        }

        // The next two checks are not required once we support fallbacks in dex/codegen
        // for non in-place input
        if (!is_post_dominated(inplace_input.get(), add_m.get()))
        {
            NGRAPH_DEBUG << "Unsafe to use in-place kernel since add's in-place input has "
                            "potential live users";
            return false;
        }
        if (inplace_input->is_parameter())
        {
            NGRAPH_DEBUG
                << "Unsafe to use in-place kernel since add's in-place input is a parameter";
            return false;
        }

        if (inplace_input->get_shape() != qconv->get_shape())
        {
            NGRAPH_DEBUG << "Summand shape doesn't match convolution shape";
            return false;
        }
        // NOTE(review): the match root's own shape is not compared with the
        // convolution's shape, although Reshape/Broadcast nodes may have been
        // skipped on both branches -- confirm that replace_node is safe then.

        auto requant_scale = qconv->get_argument(3);
        auto dq_l_scale = dq_l_m->get_argument(1);
        auto dq_r_scale = dq_r_m->get_argument(1);
        // Scale applied to the summand: right-branch scale over left-branch scale.
        auto sum_scale = (dq_r_scale / dq_l_scale);

        std::shared_ptr<ngraph::op::Op> qconvba;
        if (summand_et == element::i8)
        {
            // Signed summand: use the signed-add op and convert its result to u8.
            // TODO (jbobba): Investigate the need for Convert op
            qconvba = std::make_shared<op::Convert>(
                std::make_shared<op::QuantizedConvolutionBiasSignedAdd>(
                    qconv->get_argument(0),
                    qconv->get_argument(1),
                    qconv->get_argument(2),
                    inplace_input,
                    qconv->get_window_movement_strides(),
                    qconv->get_window_dilation_strides(),
                    qconv->get_padding_below(),
                    qconv->get_padding_above(),
                    qconv->get_data_dilation_strides(),
                    requant_scale,
                    sum_scale,
                    true),
                element::u8);
        }
        else
        {
            qconvba = std::make_shared<op::QuantizedConvolutionBiasAdd>(
                qconv->get_argument(0),
                qconv->get_argument(1),
                qconv->get_argument(2),
                inplace_input,
                qconv->get_window_movement_strides(),
                qconv->get_window_dilation_strides(),
                qconv->get_padding_below(),
                qconv->get_padding_above(),
                qconv->get_data_dilation_strides(),
                requant_scale,
                sum_scale,
                true);
        }

        // Dequantize the fused result with the left branch's scale and a zero
        // u8 zero point, and let it take the place of the matched Relu.
        auto zp = op::Constant::create(element::u8, Shape{}, {0});
        auto DQ =
            std::make_shared<op::Dequantize>(qconvba, dq_l_scale, zp, element::f32, AxisSet{});
        ngraph::replace_node(m.get_match_root(), DQ);
        return true;
    };

    auto m =
        std::make_shared<pattern::Matcher>(prelu, callback, "CPUQuantFusion.QConvBiasSignedAdd");
    this->add_matcher(m);
}
......@@ -28,6 +28,7 @@ namespace ngraph
namespace pass
{
class CPUFusion;
class CPUQuantFusion;
}
}
}
......@@ -101,3 +102,21 @@ private:
void construct_update_slice();
void construct_fuse_lstm_recurrent_state();
};
// Graph-rewrite pass that registers the quantized fusion matchers defined by
// the construct_* members below.
class CPU_BACKEND_API ngraph::runtime::cpu::pass::CPUQuantFusion : public ngraph::pass::GraphRewrite
{
public:
    // Registers all matchers. The calls run in this order: conv+relu with bias,
    // conv+relu without bias, conv-bias+add, and finally dequantize/quantize.
    CPUQuantFusion()
        : GraphRewrite()
    {
        construct_qconv_relu(true);
        construct_qconv_relu(false);
        construct_qconvb_add();
        construct_dq_q();
    }

private:
    // QuantizedConvolution[Bias] + Dequantize + Relu + Quantize.
    void construct_qconv_relu(bool with_bias);
    // QuantizedConvolutionBias branch added to a dequantized summand, then Relu.
    void construct_qconvb_add();
    // Dequantize immediately followed by Quantize.
    void construct_dq_q();
};
......@@ -21,6 +21,7 @@
#include <memory>
#include "gtest/gtest.h"
#include "misc.hpp"
#include "ngraph/autodiff/adjoints.hpp"
#include "ngraph/file_util.hpp"
#include "ngraph/graph_util.hpp"
......@@ -28,10 +29,14 @@
#include "ngraph/ngraph.hpp"
#include "ngraph/op/batch_norm.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/dequantize.hpp"
#include "ngraph/op/experimental/quantized_conv.hpp"
#include "ngraph/op/experimental/quantized_conv_bias.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/max_pool.hpp"
#include "ngraph/op/negative.hpp"
#include "ngraph/op/parameter.hpp"
#include "ngraph/op/quantize.hpp"
#include "ngraph/op/relu.hpp"
#include "ngraph/op/sigmoid.hpp"
#include "ngraph/op/sum.hpp"
......@@ -3473,3 +3478,332 @@ TEST(cpu_fusion, validate_fuse_gru_inputs)
EXPECT_TRUE(test::all_close(cpu_results.at(i), int_results.at(i), 1.0e-4f, 1.0e-4f));
}
}
// Runs the same QuantizedConvolution + Dequantize + Relu + Quantize graph on the
// CPU backend twice, first with NGRAPH_PASS_ENABLES set to "CPUQuantFusion:0" and
// then to "CPUQuantFusion:1", and expects both runs to produce close results.
TEST(cpu_quant_fusion, qconv_relu)
{
// Builds a fresh copy of the graph on each call so every run gets its own function.
auto make_function = []() {
Shape shape_input{1, 2, 2, 2};
Shape shape_weights{1, 2, 1, 1};
auto input = std::make_shared<op::Parameter>(element::f32, shape_input);
auto weights = std::make_shared<op::Parameter>(element::f32, shape_weights);
auto input_scale = op::Constant::create(element::f32, Shape{}, {2.0f});
auto weights_scale = op::Constant::create(element::f32, Shape{}, {2.0f});
auto output_scale = op::Constant::create(element::f32, Shape{}, {4.0f});
auto int8_zero = op::Constant::create(element::i8, Shape{}, {0});
auto uint8_zero = op::Constant::create(element::u8, Shape{}, {0});
op::Quantize::RoundMode round_mode = op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_EVEN;
// f32 parameters are quantized first: data to u8, weights to i8.
auto q_input = std::make_shared<op::Quantize>(
input, input_scale, uint8_zero, element::u8, AxisSet{}, round_mode);
auto q_weights = std::make_shared<op::Quantize>(
weights, weights_scale, int8_zero, element::i8, AxisSet{}, round_mode);
auto requant_scale = (input_scale * weights_scale) / output_scale;
auto conv = std::make_shared<op::QuantizedConvolution>(q_input,
q_weights,
Strides{1, 1},
Strides{1, 1},
CoordinateDiff{0, 0},
CoordinateDiff{0, 0},
Strides{1, 1},
requant_scale);
// Dequantize -> Relu -> Quantize on top of the convolution: the chain under test.
auto dq = std::make_shared<op::Dequantize>(
conv, output_scale, int8_zero, element::f32, AxisSet{});
auto relu = std::make_shared<op::Relu>(dq);
auto q = std::make_shared<op::Quantize>(
relu, output_scale, uint8_zero, element::u8, AxisSet{}, round_mode);
// Final Dequantize so the function returns f32 values.
auto q_f =
std::make_shared<op::Dequantize>(q, output_scale, uint8_zero, element::f32, AxisSet{});
return make_shared<Function>(NodeVector{q_f}, ParameterVector{input, weights});
};
auto cpu_f1 = make_function();
auto cpu_f2 = make_function();
// NOTE(review): both bounds are 2.0f, so presumably every generated value is 2.0 --
// confirm against test::Uniform.
test::Uniform<float> rng(2.0f, 2.0f);
vector<vector<float>> args;
// One input vector per parameter, sized from the parameter's shape.
for (shared_ptr<op::Parameter> param : cpu_f1->get_parameters())
{
vector<float> tensor_val(shape_size(param->get_shape()));
rng.initialize(tensor_val);
args.push_back(tensor_val);
}
// The environment must be set before each execute call it is meant to affect.
set_environment("NGRAPH_PASS_ENABLES", "CPUQuantFusion:0", 1);
auto cpu1_results = execute(cpu_f1, args, "CPU");
set_environment("NGRAPH_PASS_ENABLES", "CPUQuantFusion:1", 1);
auto cpu2_results = execute(cpu_f2, args, "CPU");
// Expected output - [2, 2, ...]
EXPECT_TRUE(test::all_close(cpu1_results.at(0), cpu2_results.at(0)));
}
// Same comparison as the bias-free test, but the convolution under the
// Dequantize + Relu + Quantize chain is a QuantizedConvolutionBias. The graph is
// run with NGRAPH_PASS_ENABLES set to "CPUQuantFusion:0" and then
// "CPUQuantFusion:1", and the two results must be close.
TEST(cpu_quant_fusion, qconvb_relu)
{
// Builds a fresh copy of the graph on each call so every run gets its own function.
auto make_function = []() {
Shape shape_input{1, 2, 2, 2};
Shape shape_weights{1, 2, 1, 1};
auto input = std::make_shared<op::Parameter>(element::f32, shape_input);
auto weights = std::make_shared<op::Parameter>(element::f32, shape_weights);
auto bias = std::make_shared<op::Parameter>(element::f32, Shape{shape_weights[0]});
auto input_scale = op::Constant::create(element::f32, Shape{}, {2.0f});
auto weights_scale = op::Constant::create(element::f32, Shape{}, {2.0f});
auto output_scale = op::Constant::create(element::f32, Shape{}, {4.0f});
auto int8_zero = op::Constant::create(element::i8, Shape{}, {0});
auto int32_zero = op::Constant::create(element::i32, Shape{}, {0});
auto uint8_zero = op::Constant::create(element::u8, Shape{}, {0});
op::Quantize::RoundMode round_mode = op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_EVEN;
// f32 parameters are quantized first: data to u8, weights to i8, bias to i32.
auto q_input = std::make_shared<op::Quantize>(
input, input_scale, uint8_zero, element::u8, AxisSet{}, round_mode);
auto q_weights = std::make_shared<op::Quantize>(
weights, weights_scale, int8_zero, element::i8, AxisSet{}, round_mode);
auto q_bias = std::make_shared<op::Quantize>(
bias, input_scale * weights_scale, int32_zero, element::i32, AxisSet{}, round_mode);
auto requant_scale = (input_scale * weights_scale) / output_scale;
// NOTE(review): q_bias is built above but never used; the convolution receives
// the f32 `bias` parameter instead -- confirm this is intended.
auto conv = std::make_shared<op::QuantizedConvolutionBias>(q_input,
q_weights,
bias,
Strides{1, 1},
Strides{1, 1},
CoordinateDiff{0, 0},
CoordinateDiff{0, 0},
Strides{1, 1},
requant_scale);
// Dequantize -> Relu -> Quantize on top of the convolution: the chain under test.
auto dq = std::make_shared<op::Dequantize>(
conv, output_scale, int8_zero, element::f32, AxisSet{});
auto relu = std::make_shared<op::Relu>(dq);
auto q = std::make_shared<op::Quantize>(
relu, output_scale, uint8_zero, element::u8, AxisSet{}, round_mode);
// Final Dequantize so the function returns f32 values.
auto q_f =
std::make_shared<op::Dequantize>(q, output_scale, uint8_zero, element::f32, AxisSet{});
return make_shared<Function>(NodeVector{q_f}, ParameterVector{input, weights, bias});
};
auto cpu_f1 = make_function();
auto cpu_f2 = make_function();
// NOTE(review): both bounds are 2.0f, so presumably every generated value is 2.0 --
// confirm against test::Uniform.
test::Uniform<float> rng(2.0f, 2.0f);
vector<vector<float>> args;
// One input vector per parameter, sized from the parameter's shape.
for (shared_ptr<op::Parameter> param : cpu_f1->get_parameters())
{
vector<float> tensor_val(shape_size(param->get_shape()));
rng.initialize(tensor_val);
args.push_back(tensor_val);
}
// The environment must be set before each execute call it is meant to affect.
set_environment("NGRAPH_PASS_ENABLES", "CPUQuantFusion:0", 1);
auto cpu1_results = execute(cpu_f1, args, "CPU");
set_environment("NGRAPH_PASS_ENABLES", "CPUQuantFusion:1", 1);
auto cpu2_results = execute(cpu_f2, args, "CPU");
EXPECT_TRUE(test::all_close(cpu1_results.at(0), cpu2_results.at(0)));
}
// Covers the Dequantize -> Quantize elimination in two ways:
//  1. numerically, by running the matching graph with NGRAPH_PASS_ENABLES set to
//     "CPUQuantFusion:0" and then "CPUQuantFusion:1" and comparing results;
//  2. structurally, by compiling three variants and counting the Quantize ops
//     left in each function afterwards.
TEST(cpu_quant_fusion, dq_q)
{
// match_scales: when false, the Quantize scale (1.0) differs from the Dequantize
//               scale (2.0).
// match_et:     when false, the Quantize produces u8 although the input is i8.
auto make_function = [](bool match_scales = true, bool match_et = true) {
Shape shape_input{1, 2, 2};
auto input = std::make_shared<op::Parameter>(element::i8, shape_input);
auto dq_scale = op::Constant::create(element::f32, Shape{}, {2.0f});
auto int8_zero = op::Constant::create(element::i8, Shape{}, {0});
auto dq =
std::make_shared<op::Dequantize>(input, dq_scale, int8_zero, element::f32, AxisSet{});
float q_scalev = 2.0f;
if (!match_scales)
{
q_scalev = 1.0f;
}
auto q_scale = op::Constant::create(element::f32, Shape{}, {q_scalev});
op::Quantize::RoundMode round_mode = op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_EVEN;
if (match_et)
{
auto q = std::make_shared<op::Quantize>(
dq, q_scale, int8_zero, element::i8, AxisSet{}, round_mode);
return make_shared<Function>(NodeVector{q}, ParameterVector{input});
}
else
{
auto uint8_zero = op::Constant::create(element::u8, Shape{}, {0});
auto q = std::make_shared<op::Quantize>(
dq, q_scale, uint8_zero, element::u8, AxisSet{}, round_mode);
return make_shared<Function>(NodeVector{q}, ParameterVector{input});
}
};
// Part 1: numerical comparison on the default (matching) graph.
auto cpu_f1 = make_function();
auto cpu_f2 = make_function();
vector<vector<int8_t>> args;
// Four values, matching the element count of shape_input {1, 2, 2}.
args.push_back({-1, 2, 3, 4});
// The environment must be set before each execute call it is meant to affect.
set_environment("NGRAPH_PASS_ENABLES", "CPUQuantFusion:0", 1);
auto cpu1_results = execute(cpu_f1, args, "CPU");
set_environment("NGRAPH_PASS_ENABLES", "CPUQuantFusion:1", 1);
auto cpu2_results = execute(cpu_f2, args, "CPU");
EXPECT_TRUE(test::all_close(cpu1_results.at(0), cpu2_results.at(0)));
// Part 2: structural check. NGRAPH_PASS_ENABLES is still "CPUQuantFusion:1" here.
auto backend = runtime::Backend::create("CPU");
auto fuse = make_function(true, true);
auto no_fuse1 = make_function(false, true);
auto no_fuse2 = make_function(true, false);
backend->compile(fuse);
backend->compile(no_fuse1);
backend->compile(no_fuse2);
// Only the variant with matching scales and matching element type loses its Quantize.
ASSERT_EQ(count_ops_of_type<op::Quantize>(fuse), 0);
ASSERT_EQ(count_ops_of_type<op::Quantize>(no_fuse1), 1);
ASSERT_EQ(count_ops_of_type<op::Quantize>(no_fuse2), 1);
}
// Exercises the conv-bias + add + relu pattern with an i8 (signed) summand.
// The same graph is run with CPUQuantFusion disabled and then enabled, and the
// two results must be close.
TEST(cpu_quant_fusion, qconvbsa)
{
// Builds a fresh copy of the graph on each call so every run gets its own function.
auto make_function = []() {
Shape shape_input{1, 2, 2, 2};
Shape shape_weights{1, 2, 1, 1};
Shape shape_summand{1, 1, 2, 2};
auto input = std::make_shared<op::Parameter>(element::f32, shape_input);
auto weights = std::make_shared<op::Parameter>(element::f32, shape_weights);
auto bias = std::make_shared<op::Parameter>(element::f32, Shape{shape_weights[0]});
auto summand = std::make_shared<op::Parameter>(element::f32, shape_summand);
auto input_scale = op::Constant::create(element::f32, Shape{}, {2.0f});
auto weights_scale = op::Constant::create(element::f32, Shape{}, {2.0f});
auto output_scale = op::Constant::create(element::f32, Shape{}, {4.0f});
auto summand_scale = op::Constant::create(element::f32, Shape{}, {2.0f});
auto int8_zero = op::Constant::create(element::i8, Shape{}, {0});
auto int32_zero = op::Constant::create(element::i32, Shape{}, {0});
auto uint8_zero = op::Constant::create(element::u8, Shape{}, {0});
op::Quantize::RoundMode round_mode = op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_EVEN;
// f32 parameters are quantized first; the summand is quantized to i8 here.
auto q_input = std::make_shared<op::Quantize>(
input, input_scale, uint8_zero, element::u8, AxisSet{}, round_mode);
auto q_weights = std::make_shared<op::Quantize>(
weights, weights_scale, int8_zero, element::i8, AxisSet{}, round_mode);
auto q_bias = std::make_shared<op::Quantize>(
bias, input_scale * weights_scale, int32_zero, element::i32, AxisSet{}, round_mode);
auto q_summand = std::make_shared<op::Quantize>(
summand, summand_scale, int8_zero, element::i8, AxisSet{}, round_mode);
// Left Graph
auto requant_scale = (input_scale * weights_scale) / output_scale;
// NOTE(review): q_bias is built above but never used; the convolution receives
// the f32 `bias` parameter instead -- confirm this is intended.
auto conv = std::make_shared<op::QuantizedConvolutionBias>(q_input,
q_weights,
bias,
Strides{1, 1},
Strides{1, 1},
CoordinateDiff{0, 0},
CoordinateDiff{0, 0},
Strides{1, 1},
requant_scale);
auto dq_l = std::make_shared<op::Dequantize>(
conv, output_scale, int8_zero, element::f32, AxisSet{});
// Reshape to {1, 2, 2} and broadcast back to {1, 1, 2, 2}.
auto r_l = std::make_shared<op::Reshape>(dq_l, AxisVector{0, 1, 2, 3}, Shape{1, 2, 2});
auto b_l = std::make_shared<op::Broadcast>(r_l, Shape{1, 1, 2, 2}, AxisSet{0});
// Right Graph
auto dq_r = std::make_shared<op::Dequantize>(
q_summand, summand_scale, int8_zero, element::f32, AxisSet{});
auto r_r = std::make_shared<op::Reshape>(dq_r, AxisVector{0, 1, 2, 3}, Shape{1, 2, 2});
auto b_r = std::make_shared<op::Broadcast>(r_r, Shape{1, 1, 2, 2}, AxisSet{0});
// Relu over the sum of both branches is the function's result.
auto add = b_l + b_r;
auto relu = std::make_shared<op::Relu>(add);
return make_shared<Function>(NodeVector{relu},
ParameterVector{input, weights, bias, summand});
};
auto cpu_f1 = make_function();
auto cpu_f2 = make_function();
// NOTE(review): both bounds are 4.0f, so presumably every generated value is 4.0 --
// confirm against test::Uniform.
test::Uniform<float> rng(4.0f, 4.0f);
vector<vector<float>> args;
// One input vector per parameter, sized from the parameter's shape.
for (shared_ptr<op::Parameter> param : cpu_f1->get_parameters())
{
vector<float> tensor_val(shape_size(param->get_shape()));
rng.initialize(tensor_val);
args.push_back(tensor_val);
}
// Disable CPUQuantFusion
set_environment("NGRAPH_PASS_ENABLES", "CPUQuantFusion:0", 1);
auto cpu1_results = execute(cpu_f1, args, "CPU");
// Enable CPUQuantFusion
set_environment("NGRAPH_PASS_ENABLES", "CPUQuantFusion:1", 1);
auto cpu2_results = execute(cpu_f2, args, "CPU");
EXPECT_TRUE(test::all_close(cpu1_results.at(0), cpu2_results.at(0)));
}
// Exercises the conv-bias + add + relu pattern with a u8 (unsigned) summand.
// The same graph is run with CPUQuantFusion disabled and then enabled, and the
// two results must be close.
TEST(cpu_quant_fusion, qconvba)
{
// Builds a fresh copy of the graph on each call so every run gets its own function.
auto make_function = []() {
Shape shape_input{1, 2, 2, 2};
Shape shape_weights{1, 2, 1, 1};
Shape shape_summand{1, 1, 2, 2};
auto input = std::make_shared<op::Parameter>(element::f32, shape_input);
auto weights = std::make_shared<op::Parameter>(element::f32, shape_weights);
auto bias = std::make_shared<op::Parameter>(element::f32, Shape{shape_weights[0]});
auto summand = std::make_shared<op::Parameter>(element::f32, shape_summand);
auto input_scale = op::Constant::create(element::f32, Shape{}, {2.0f});
auto weights_scale = op::Constant::create(element::f32, Shape{}, {2.0f});
auto output_scale = op::Constant::create(element::f32, Shape{}, {4.0f});
auto summand_scale = op::Constant::create(element::f32, Shape{}, {4.0f});
auto int8_zero = op::Constant::create(element::i8, Shape{}, {0});
auto int32_zero = op::Constant::create(element::i32, Shape{}, {0});
auto uint8_zero = op::Constant::create(element::u8, Shape{}, {0});
op::Quantize::RoundMode round_mode = op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_EVEN;
// f32 parameters are quantized first; the summand is quantized to u8 here.
auto q_input = std::make_shared<op::Quantize>(
input, input_scale, uint8_zero, element::u8, AxisSet{}, round_mode);
auto q_weights = std::make_shared<op::Quantize>(
weights, weights_scale, int8_zero, element::i8, AxisSet{}, round_mode);
auto q_bias = std::make_shared<op::Quantize>(
bias, input_scale * weights_scale, int32_zero, element::i32, AxisSet{}, round_mode);
auto q_summand = std::make_shared<op::Quantize>(
summand, summand_scale, uint8_zero, element::u8, AxisSet{}, round_mode);
// Left Graph
auto requant_scale = (input_scale * weights_scale) / output_scale;
// NOTE(review): q_bias is built above but never used; the convolution receives
// the f32 `bias` parameter instead -- confirm this is intended.
auto conv = std::make_shared<op::QuantizedConvolutionBias>(q_input,
q_weights,
bias,
Strides{1, 1},
Strides{1, 1},
CoordinateDiff{0, 0},
CoordinateDiff{0, 0},
Strides{1, 1},
requant_scale);
auto dq_l = std::make_shared<op::Dequantize>(
conv, output_scale, int8_zero, element::f32, AxisSet{});
// Reshape to {1, 2, 2} and broadcast back to {1, 1, 2, 2}.
auto r_l = std::make_shared<op::Reshape>(dq_l, AxisVector{0, 1, 2, 3}, Shape{1, 2, 2});
auto b_l = std::make_shared<op::Broadcast>(r_l, Shape{1, 1, 2, 2}, AxisSet{0});
// Right Graph
auto dq_r = std::make_shared<op::Dequantize>(
q_summand, summand_scale, uint8_zero, element::f32, AxisSet{});
auto r_r = std::make_shared<op::Reshape>(dq_r, AxisVector{0, 1, 2, 3}, Shape{1, 2, 2});
auto b_r = std::make_shared<op::Broadcast>(r_r, Shape{1, 1, 2, 2}, AxisSet{0});
// Relu over the sum of both branches is the function's result.
auto add = b_l + b_r;
auto relu = std::make_shared<op::Relu>(add);
return make_shared<Function>(NodeVector{relu},
ParameterVector{input, weights, bias, summand});
};
auto cpu_f1 = make_function();
auto cpu_f2 = make_function();
// NOTE(review): both bounds are 2.0f, so presumably every generated value is 2.0 --
// confirm against test::Uniform.
test::Uniform<float> rng(2.0f, 2.0f);
vector<vector<float>> args;
// One input vector per parameter, sized from the parameter's shape.
for (shared_ptr<op::Parameter> param : cpu_f1->get_parameters())
{
vector<float> tensor_val(shape_size(param->get_shape()));
rng.initialize(tensor_val);
args.push_back(tensor_val);
}
// Disable CPUQuantFusion
set_environment("NGRAPH_PASS_ENABLES", "CPUQuantFusion:0", 1);
auto cpu1_results = execute(cpu_f1, args, "CPU");
// Enable CPUQuantFusion
set_environment("NGRAPH_PASS_ENABLES", "CPUQuantFusion:1", 1);
auto cpu2_results = execute(cpu_f2, args, "CPU");
EXPECT_TRUE(test::all_close(cpu1_results.at(0), cpu2_results.at(0)));
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment