Commit 296ee2cf authored by Nishant Patel, committed by Scott Cyphers

Quantize(reorder) bias to int32 (#1933)

* Quantize the bias to int32

* Bias scale fix

* mnist works

* Quantize Bias

* Introduce Quantize op in the graph to quantize bias & feedback

* Comments and some refactoring

* Add test case with float bias and enable int32 as quantized type in ngraph

* Change shape of scale from Shape{} to Shape{1} in the backend
parent 71cc8bbf
@@ -121,13 +121,27 @@ namespace ngraph
                     std::shared_ptr<Node> max_freezed_output,
                     const bool with_relu)
                 {
+                    auto output_et = with_relu ? element::u8 : element::i8;
                     auto requantization_scale = quantization_util::get_scale(min_input,
                                                                              max_input,
                                                                              min_filter,
                                                                              max_filter,
                                                                              min_freezed_output,
-                                                                             max_freezed_output);
+                                                                             max_freezed_output,
+                                                                             output_et);
+
+                    if (bias->get_element_type() != element::i32)
+                    {
+                        auto zero = make_constant(element::i32, min_input->get_shape(), 0);
+                        AxisSet quantization_axes;
+                        auto bias_scale =
+                            quantization_util::get_bias_scale(min_input, max_input, min_filter, max_filter);
+                        op::Quantize::RoundMode round_mode =
+                            op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_EVEN;
+                        bias = make_shared<op::Quantize>(
+                            bias, bias_scale, zero, element::i32, quantization_axes, round_mode);
+                    }
+
                     return make_shared<op::QuantizedConvolutionBias>(input,
                                                                      filters,
                                                                      bias,
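For intuition, the `op::Quantize` inserted above amounts to an elementwise scale-and-round of the float bias into the int32 accumulator's scale. Below is a standalone sketch of that arithmetic (a hypothetical helper, not part of the nGraph API; note that `op::Quantize`'s own scale argument may be defined as the reciprocal of the multiplier used here, depending on the op's convention):

```cpp
#include <cmath>
#include <cstdint>
#include <vector>

// Sketch: the intended effect of quantizing a float bias to int32, given the
// multiplier-style scale produced by quantization_util::get_bias_scale.
std::vector<int32_t> quantize_bias(const std::vector<float>& bias, float bias_scale)
{
    std::vector<int32_t> out;
    out.reserve(bias.size());
    for (float b : bias)
    {
        // std::nearbyint rounds half-to-even under the default FE_TONEAREST
        // mode, matching ROUND_NEAREST_TOWARD_EVEN chosen above.
        out.push_back(static_cast<int32_t>(std::nearbyint(b * bias_scale)));
    }
    return out;
}
```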
@@ -160,7 +174,8 @@ namespace ngraph
                                                                              min_filter,
                                                                              max_filter,
                                                                              min_freezed_output,
-                                                                             max_freezed_output);
+                                                                             max_freezed_output,
+                                                                             element::u8);
                     return make_shared<op::QuantizedConvolutionRelu>(input,
                                                                      filters,
@@ -191,7 +206,8 @@ namespace ngraph
                                                                              min_filter,
                                                                              max_filter,
                                                                              min_freezed_output,
-                                                                             max_freezed_output);
+                                                                             max_freezed_output,
+                                                                             element::i8);
                     return make_shared<op::QuantizedConvolution>(input,
                                                                  filters,
......
@@ -92,7 +92,8 @@ namespace ngraph
                     std::shared_ptr<Node> min_filter,
                     std::shared_ptr<Node> max_filter,
                     std::shared_ptr<Node> min_freezed_output,
-                    std::shared_ptr<Node> max_freezed_output)
+                    std::shared_ptr<Node> max_freezed_output,
+                    const ngraph::element::Type& output_type)
                 {
                     auto type = min_input->get_element_type();
                     if (type != max_input->get_element_type() ||
@@ -121,11 +122,59 @@ namespace ngraph
                     auto max_abs32 = max_abs(min_out_value, max_out_value);
                     auto max_abs8 = max_abs(min_freezed_output, max_freezed_output);
-                    // Output is signed int.
-                    // s32 = f32 * std::pow(2, 31)/ max_abs32;
-                    // s8 = f32 * std::pow(2, 7)/ max_abs8;
-                    // s8 = s32 * std::pow(2, -24) * max_abs32 / max_abs8;
-                    return make_constant(type, shape, std::pow(2, -24)) * (max_abs32 / max_abs8);
+                    // The output of int8 convolution is accumulated in int32.
+                    // Mkldnn needs a scale to requantize the output back to {u}int8,
+                    // depending on whether relu is fused or not.
+
+                    // Equation to go from f32 to s32; std::pow(2, 31) / max_abs32 can be
+                    // thought of as the scale used for the quantization:
+                    // 1. s32 = f32 * std::pow(2, 31) / max_abs32;
+                    // Equation to go from f32 to u8:
+                    // 2. u8 = f32 * std::pow(2, 8) / max_abs8;
+                    // Equation to go from f32 to s8:
+                    // 3. s8 = f32 * std::pow(2, 7) / max_abs8;
+                    // Substituting f32 from eq 1 into eq 2:
+                    // 4. u8 = s32 * std::pow(2, -23) * max_abs32 / max_abs8;
+                    // Substituting f32 from eq 1 into eq 3:
+                    // 5. s8 = s32 * std::pow(2, -24) * max_abs32 / max_abs8;
+                    return make_constant(
+                               type, shape, std::pow(2, (output_type == element::i8) ? -24 : -23)) *
+                           (max_abs32 / max_abs8);
                 }
+
+                std::shared_ptr<Node> get_bias_scale(std::shared_ptr<Node> min_input,
+                                                     std::shared_ptr<Node> max_input,
+                                                     std::shared_ptr<Node> min_filter,
+                                                     std::shared_ptr<Node> max_filter)
+                {
+                    auto type = min_input->get_element_type();
+                    if (type != max_input->get_element_type() ||
+                        type != min_filter->get_element_type() ||
+                        type != max_filter->get_element_type())
+                    {
+                        throw ngraph_error("get_bias_scale: min and max must have same type");
+                    }
+
+                    auto shape = min_input->get_shape();
+                    if (shape != max_input->get_shape() || shape != min_filter->get_shape() ||
+                        shape != max_filter->get_shape())
+                    {
+                        throw ngraph_error("get_bias_scale: min and max must have same shape");
+                    }
+
+                    auto max_abs_input_range = max_abs(min_input, max_input);
+                    auto max_abs_filter_range = max_abs(min_filter, max_filter);
+                    auto range = make_constant(type,
+                                               shape,
+                                               std::numeric_limits<uint8_t>::max() *
+                                                   std::numeric_limits<int8_t>::max());
+                    return range / (max_abs_input_range * max_abs_filter_range);
+                }
 
                 std::shared_ptr<Node> get_scale(std::shared_ptr<Node> input_min_range,
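A quick numeric sanity check of the two scale computations above, using illustrative values and plain C++ in place of nGraph constant nodes:

```cpp
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <limits>

int main()
{
    const float max_abs32 = 10.0f; // |f32 output| bound of the convolution
    const float max_abs8 = 2.0f;   // |output| bound of the requested {u}int8 result
    const float f = 0.5f;          // some f32 output value

    // Eq 1 and eq 3, then eq 5: the s8 requantization scale recovers the
    // directly quantized value from the int32 accumulator.
    const float s32 = f * std::pow(2.0f, 31) / max_abs32;
    const float s8_direct = f * std::pow(2.0f, 7) / max_abs8;
    const float requant_s8 = std::pow(2.0f, -24) * max_abs32 / max_abs8;
    std::printf("s8: %.1f vs %.1f\n", s8_direct, s32 * requant_s8); // 32.0 vs 32.0

    // get_bias_scale collapses to (255 * 127) / (max_abs_input * max_abs_filter):
    // the product of the u8 input scale and the s8 filter scale, i.e. the scale
    // already carried by the int32 accumulator that the bias is added into.
    const float max_abs_input = 255.0f, max_abs_filter = 127.0f;
    const float bias_scale = std::numeric_limits<uint8_t>::max() *
                             std::numeric_limits<int8_t>::max() /
                             (max_abs_input * max_abs_filter); // 1.0 here
    std::printf("bias scale: %.1f\n", bias_scale);
    return 0;
}
```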
......
@@ -108,7 +108,6 @@ namespace ngraph
                 auto& out0_tensor = external_function->get_tensor_data(out[0].get_name());
                 auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
-
                 auto conv_index =
                     mkldnn_emitter->build_convolution<ngraph::op::QuantizedConvolutionBias>(
                         node, args, out);
......
@@ -254,6 +254,7 @@ size_t MKLDNNEmitter::build_convolution_forward(const mkldnn::memory::desc& inpu
                              mkldnn::memory::dims(padding_below.begin(), padding_below.end()),
                              mkldnn::memory::dims(padding_above.begin(), padding_above.end()),
                              mkldnn::padding_kind::zero},
+                            conv_attr,
                             executor::global_cpu_engine},
        *m_mkldnn_primitives[input_data_index],
@@ -271,15 +272,16 @@ size_t MKLDNNEmitter::build_convolution_forward(const mkldnn::memory::desc& inpu
     return conv_index;
 }
 
-size_t MKLDNNEmitter::build_quantized_convolution(const mkldnn::memory::desc& input_data_desc,
-                                                  const mkldnn::memory::desc& weights_desc,
-                                                  const mkldnn::memory::desc& result_desc,
-                                                  const ngraph::Strides& strides,
-                                                  const ngraph::Strides& dilation_strides,
-                                                  const ngraph::CoordinateDiff& padding_below,
-                                                  const ngraph::CoordinateDiff& padding_above,
-                                                  const float scale,
-                                                  const mkldnn::post_ops& pops)
+size_t
+    MKLDNNEmitter::build_quantized_convolution_forward(const mkldnn::memory::desc& input_data_desc,
+                                                       const mkldnn::memory::desc& weights_desc,
+                                                       const mkldnn::memory::desc& result_desc,
+                                                       const ngraph::Strides& strides,
+                                                       const ngraph::Strides& dilation_strides,
+                                                       const ngraph::CoordinateDiff& padding_below,
+                                                       const ngraph::CoordinateDiff& padding_above,
+                                                       const float scale,
+                                                       const mkldnn::post_ops& pops)
 {
     size_t input_data_index = build_memory_primitive(input_data_desc);
     size_t weights_index = build_memory_primitive(weights_desc);
@@ -312,60 +314,49 @@ size_t MKLDNNEmitter::build_quantized_convolution(const mkldnn::memory::desc& in
     return conv_index;
 }
 
-size_t MKLDNNEmitter::build_quantized_convolution(const mkldnn::memory::desc& input_data_desc,
-                                                  const mkldnn::memory::desc& weights_desc,
-                                                  const mkldnn::memory::desc& bias_desc,
-                                                  const mkldnn::memory::desc& result_desc,
-                                                  const ngraph::Strides& strides,
-                                                  const ngraph::Strides& dilation_strides,
-                                                  const ngraph::CoordinateDiff& padding_below,
-                                                  const ngraph::CoordinateDiff& padding_above,
-                                                  const float scale,
-                                                  const mkldnn::post_ops& pops)
+size_t
+    MKLDNNEmitter::build_quantized_convolution_forward(const mkldnn::memory::desc& input_data_desc,
+                                                       const mkldnn::memory::desc& weights_desc,
+                                                       const mkldnn::memory::desc& bias_desc,
+                                                       const mkldnn::memory::desc& result_desc,
+                                                       const ngraph::Strides& strides,
+                                                       const ngraph::Strides& dilation_strides,
+                                                       const ngraph::CoordinateDiff& padding_below,
+                                                       const ngraph::CoordinateDiff& padding_above,
+                                                       const float scale,
+                                                       const mkldnn::post_ops& pops)
 {
     size_t input_data_index = build_memory_primitive(input_data_desc);
     size_t weights_index = build_memory_primitive(weights_desc);
     size_t bias_index = build_memory_primitive(bias_desc);
     size_t result_index = build_memory_primitive(result_desc);
 
     std::vector<float> output_scale;
     output_scale.push_back(scale);
     mkldnn::primitive_attr conv_attr;
     conv_attr.set_post_ops(pops);
     /* Specify the rounding mode */
     conv_attr.set_int_output_round_mode(mkldnn::round_mode::round_nearest);
     /* Specify the scales array and corresponding mask */
     conv_attr.set_output_scales(0, output_scale);
-    size_t conv_index = 0;
-    try
-    {
-        conv_index = insert_primitive(new mkldnn::convolution_forward(
-            {{mkldnn::prop_kind::forward,
-              mkldnn::algorithm::convolution_direct,
-              input_data_desc,
-              weights_desc,
-              bias_desc,
-              result_desc,
-              mkldnn::memory::dims(strides.begin(), strides.end()),
-              mkldnn::memory::dims(dilation_strides.begin(), dilation_strides.end()),
-              mkldnn::memory::dims(padding_below.begin(), padding_below.end()),
-              mkldnn::memory::dims(padding_above.begin(), padding_above.end()),
-              mkldnn::padding_kind::zero},
-             conv_attr,
-             executor::global_cpu_engine},
-            *m_mkldnn_primitives[input_data_index],
-            *m_mkldnn_primitives[weights_index],
-            *m_mkldnn_primitives[bias_index],
-            *m_mkldnn_primitives[result_index]));
-        m_primitive_deps[conv_index] = {input_data_index, weights_index, bias_index, result_index};
-    }
-    catch (const mkldnn::error& e)
-    {
-        throw ngraph_error("Could not create convolution " + e.message);
-    }
+    size_t conv_index = insert_primitive(new mkldnn::convolution_forward(
+        {{mkldnn::prop_kind::forward,
+          mkldnn::algorithm::convolution_direct,
+          input_data_desc,
+          weights_desc,
+          bias_desc,
+          result_desc,
+          mkldnn::memory::dims(strides.begin(), strides.end()),
+          mkldnn::memory::dims(dilation_strides.begin(), dilation_strides.end()),
+          mkldnn::memory::dims(padding_below.begin(), padding_below.end()),
+          mkldnn::memory::dims(padding_above.begin(), padding_above.end()),
+          mkldnn::padding_kind::zero},
+         conv_attr,
+         executor::global_cpu_engine},
+        *m_mkldnn_primitives[input_data_index],
+        *m_mkldnn_primitives[weights_index],
+        *m_mkldnn_primitives[bias_index],
+        *m_mkldnn_primitives[result_index]));
+    m_primitive_deps[conv_index] = {input_data_index, weights_index, bias_index, result_index};
 
     return conv_index;
 }
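Both overloads share the same attribute setup; isolated, it looks like the sketch below. This is a minimal example against the mkldnn 0.x API already used in this file (only the helper name is made up):

```cpp
#include <mkldnn.hpp>
#include <vector>

// Build a primitive_attr that fuses the given post-ops and requantizes the
// int32 accumulator to {u}int8 with a single output scale. Mask 0 means one
// scale applied uniformly across the output tensor; rounding is to nearest.
mkldnn::primitive_attr make_requantization_attr(float scale, const mkldnn::post_ops& pops)
{
    mkldnn::primitive_attr conv_attr;
    conv_attr.set_post_ops(pops);
    conv_attr.set_int_output_round_mode(mkldnn::round_mode::round_nearest);
    conv_attr.set_output_scales(0, std::vector<float>{scale});
    return conv_attr;
}
```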
......
@@ -108,31 +108,28 @@ namespace ngraph
                               const ngraph::CoordinateDiff& padding_above,
                               const mkldnn::post_ops& pops = mkldnn::post_ops());
 
-            size_t
-                build_quantized_convolution(const mkldnn::memory::desc& input_data_desc,
-                                            const mkldnn::memory::desc& weights_desc,
-                                            const mkldnn::memory::desc& result_desc,
-                                            const ngraph::Strides& strides,
-                                            const ngraph::Strides& dilation_strides,
-                                            const ngraph::CoordinateDiff& padding_below,
-                                            const ngraph::CoordinateDiff& padding_above,
-                                            const float scale,
-                                            const mkldnn::post_ops& pops = mkldnn::post_ops());
-
-            /**
-             * QuantizedConvolution + bias forward
-             */
-            size_t
-                build_quantized_convolution(const mkldnn::memory::desc& input_data_desc,
-                                            const mkldnn::memory::desc& weights_desc,
-                                            const mkldnn::memory::desc& bias_desc,
-                                            const mkldnn::memory::desc& result_desc,
-                                            const ngraph::Strides& strides,
-                                            const ngraph::Strides& dilation_strides,
-                                            const ngraph::CoordinateDiff& padding_below,
-                                            const ngraph::CoordinateDiff& padding_above,
-                                            const float scale,
-                                            const mkldnn::post_ops& pops = mkldnn::post_ops());
+            size_t build_quantized_convolution_forward(
+                const mkldnn::memory::desc& input_data_desc,
+                const mkldnn::memory::desc& weights_desc,
+                const mkldnn::memory::desc& result_desc,
+                const ngraph::Strides& strides,
+                const ngraph::Strides& dilation_strides,
+                const ngraph::CoordinateDiff& padding_below,
+                const ngraph::CoordinateDiff& padding_above,
+                const float scale,
+                const mkldnn::post_ops& pops = mkldnn::post_ops());
+
+            size_t build_quantized_convolution_forward(
+                const mkldnn::memory::desc& input_data_desc,
+                const mkldnn::memory::desc& weights_desc,
+                const mkldnn::memory::desc& bias_desc,
+                const mkldnn::memory::desc& result_desc,
+                const ngraph::Strides& strides,
+                const ngraph::Strides& dilation_strides,
+                const ngraph::CoordinateDiff& padding_below,
+                const ngraph::CoordinateDiff& padding_above,
+                const float scale,
+                const mkldnn::post_ops& pops = mkldnn::post_ops());
 
             template <typename OP>
             size_t build_convolution(const ngraph::Node* node,
@@ -237,7 +234,7 @@ namespace ngraph
                     auto scale_val = scale_const_op->get_vector<float>();
 
-                    return build_quantized_convolution(
+                    return build_quantized_convolution_forward(
                         data_desc,
                         weights_desc,
                         result_desc,
@@ -260,7 +257,7 @@ namespace ngraph
                     auto scale_val = scale_const_op->get_vector<float>();
 
-                    return build_quantized_convolution(
+                    return build_quantized_convolution_forward(
                         data_desc,
                         weights_desc,
                         result_desc,
@@ -285,7 +282,7 @@ namespace ngraph
                     // conv+bias = cvt_to_int8(scale*(dst + bias))
                     auto bias_desc = mkldnn_utils::get_input_mkldnn_md(node, 2);
-                    return build_quantized_convolution(
+                    return build_quantized_convolution_forward(
                         data_desc,
                         weights_desc,
                         bias_desc,
......
@@ -282,6 +282,10 @@ mkldnn::memory::desc runtime::cpu::mkldnn_utils::create_default_mkldnn_md(
         et = runtime::cpu::mkldnn_utils::get_mkldnn_data_type(node->get_input_element_type(index));
     }
 
+    if (shape == Shape{})
+    {
+        shape = Shape{1};
+    }
     return memory::desc(memory::dims(shape.begin(), shape.end()), et, format);
 }
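(The Shape{} to Shape{1} promotion above appears to exist so that scalar tensors, such as the Shape{1} scale inputs introduced by this commit, get a valid rank-1 mkldnn memory descriptor; an nGraph scalar would otherwise yield an empty mkldnn::memory::dims.)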
......
@@ -93,6 +93,10 @@ shared_ptr<Node> runtime::cpu::pass::CPULayout::insert_input_conversions(
         auto tv = output.get_tensor_ptr();
         auto tvl = dynamic_pointer_cast<runtime::cpu::LayoutDescriptor>(tv->get_tensor_layout());
+        if (input.get_shape() == Shape{})
+        {
+            tvl->set_mkldnn_md(required_mds[index]);
+        }
         if (!tvl)
         {
             throw ngraph_error(
......
@@ -28,7 +28,7 @@ const element::Type element::f32(32, true, true, false, "float");
 const element::Type element::f64(64, true, true, false, "double");
 const element::Type element::i8(8, false, true, true, "int8_t");
 const element::Type element::i16(16, false, true, false, "int16_t");
-const element::Type element::i32(32, false, true, false, "int32_t");
+const element::Type element::i32(32, false, true, true, "int32_t");
 const element::Type element::i64(64, false, true, false, "int64_t");
 const element::Type element::u8(8, false, false, true, "uint8_t");
 const element::Type element::u16(16, false, false, false, "uint16_t");
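(The single flipped flag above is the constructor argument that was previously true only for i8 and u8, i.e. the quantized-type marker; setting it for i32 is what "enable int32 as quantized type in ngraph" in the commit message refers to, allowing the bias to be quantized to int32.)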
......
@@ -208,7 +208,7 @@ TEST(builder, scaled_QC_with_relu)
     copy_data(b, b_data);
     auto result = backend->create_tensor(element::u8, shape_r);
     backend->call_with_validate(f, {result}, {a, b});
-    EXPECT_EQ((vector<uint8_t>{0, 0, 0, 0, 0, 0, 69, 106, 90}), read_vector<uint8_t>(result));
+    EXPECT_EQ((vector<uint8_t>{0, 0, 0, 0, 0, 0, 138, 212, 181}), read_vector<uint8_t>(result));
 }
 
 TEST(builder, scaled_QC_with_bias)
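(In the scaled_QC_with_relu hunk above, the expected u8 outputs roughly double, 69 to 138 and 106 to 212, because the with-relu path now requantizes with std::pow(2, -23) instead of std::pow(2, -24); see eq 4 in the get_scale comments. The 90 to 181 case, rather than exactly 180, presumably reflects rounding of the underlying float.)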
@@ -300,7 +300,53 @@ TEST(builder, scaled_QC_with_bias_and_relu)
     copy_data(c, c_data);
     auto result = backend->create_tensor(element::u8, shape_r);
     backend->call_with_validate(f, {result}, {a, b, c});
-    EXPECT_EQ((vector<uint8_t>{0, 0, 0, 0, 0, 0, 96, 133, 117}), read_vector<uint8_t>(result));
+    EXPECT_EQ((vector<uint8_t>{0, 0, 0, 0, 0, 0, 191, 255, 234}), read_vector<uint8_t>(result));
+}
+
+TEST(builder, scaled_QC_with_f32_bias_and_relu)
+{
+    Shape shape_a{1, 1, 3, 3}; // input shape
+    Shape shape_b{1, 1, 3, 3}; // filter shape
+    Shape shape_r{1, 1, 3, 3}; // output shape
+    vector<uint8_t> a_data = {1, 2, 3, 4, 5, 6, 7, 8, 9};
+    vector<int8_t> b_data = {1, 2, 1, 0, 0, 0, -1, -2, -1};
+    vector<float> c_data = {5};
+    auto A = make_shared<op::Parameter>(element::u8, shape_a);
+    auto B = make_shared<op::Parameter>(element::i8, shape_b);
+    auto Bias = make_shared<op::Parameter>(element::f32, Shape{1});
+    auto C = op::Constant::create(element::f32, Shape{}, {0.0f});
+    auto D = op::Constant::create(element::f32, Shape{}, {255.0f});
+    auto E = op::Constant::create(element::f32, Shape{}, {-127.0f});
+    auto F = op::Constant::create(element::f32, Shape{}, {127.0f});
+    auto G = op::Constant::create(element::f32, Shape{}, {20.0f});
+    auto H = op::Constant::create(element::f32, Shape{}, {-24.0f});
+    auto CV = ngraph::builder::ScaledQuantizedConvolutionBias(A,
+                                                              B,
+                                                              Bias,
+                                                              Strides{1, 1},        // move_strides
+                                                              Strides{1, 1},        // filter_dilation
+                                                              CoordinateDiff{1, 1}, // below_pads
+                                                              CoordinateDiff{1, 1}, // above_pads
+                                                              Strides{1, 1},        // data_dilation
+                                                              C,
+                                                              D,
+                                                              E,
+                                                              F,
+                                                              G,
+                                                              H,
+                                                              true);
+    auto f = make_shared<Function>(NodeVector{CV}, op::ParameterVector{A, B, Bias});
+    auto backend = runtime::Backend::create("CPU");
+    // Create some tensors for input/output
+    auto a = backend->create_tensor(element::u8, shape_a);
+    copy_data(a, a_data);
+    auto b = backend->create_tensor(element::i8, shape_b);
+    copy_data(b, b_data);
+    auto c = backend->create_tensor(element::f32, Shape{1});
+    copy_data(c, c_data);
+    auto result = backend->create_tensor(element::u8, shape_r);
+    backend->call_with_validate(f, {result}, {a, b, c});
+    EXPECT_EQ((vector<uint8_t>{0, 0, 0, 0, 0, 0, 191, 255, 234}), read_vector<uint8_t>(result));
 }
 
 TEST(builder, scaled_Q_unsigned)
......