Commit b5f43973 authored by Nishant Patel, committed by Scott Cyphers

Add support for Quantized Convolution op via mkldnn for IA backend (codegen + DEX) (#1620)

* Add support for Quantized Convolution op via mkldnn for IA backend (codegen + DEX)

* Use call_with_validate

* Style fix

* Fix clang compile errors
parent 28228857
@@ -24,6 +24,7 @@ set(SRC
cpu_tensor_view_wrapper.cpp
cpu_tensor_view.cpp
cpu_tracing.cpp
quantization_util.cpp
builder/add.cpp
builder/allreduce.cpp
builder/avg_pool.cpp
@@ -35,6 +36,7 @@ set(SRC
builder/concat.cpp
builder/convert.cpp
builder/convert_layout.cpp
builder/quantized_conv.cpp
builder/convolution.cpp
builder/dequantize.cpp
builder/quantize.cpp
@@ -80,6 +82,7 @@ set(SRC
op/group_conv.cpp
op/conv_bias.cpp
op/conv_relu.cpp
op/quantized_conv.cpp
op/convert_layout.cpp
op/dequantize.cpp
op/quantize.cpp
......
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "ngraph/runtime/cpu/op/quantized_conv.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/runtime/cpu/cpu_builder.hpp"
#include "ngraph/runtime/cpu/mkldnn_invoke.hpp"
#include "ngraph/runtime/cpu/mkldnn_utils.hpp"
using namespace std;
using namespace ngraph;
namespace ngraph
{
namespace runtime
{
namespace cpu
{
template <>
void Builder::BUILDER_DECL(ngraph::op::QuantizedConvolution)
{
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{
auto qconvolution = static_cast<const ngraph::op::QuantizedConvolution*>(node);
auto& functors = external_function->get_functors();
auto& arg0_tensor = external_function->get_tensor_data(args[0].get_name());
auto& arg1_tensor = external_function->get_tensor_data(args[1].get_name());
auto& out_tensor = external_function->get_tensor_data(out[0].get_name());
auto& out1_tensor = external_function->get_tensor_data(out[1].get_name());
auto& out2_tensor = external_function->get_tensor_data(out[2].get_name());
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto conv_index =
mkldnn_emitter->build_convolution<ngraph::op::QuantizedConvolution>(
node, args, out);
auto& deps = mkldnn_emitter->get_primitive_deps(conv_index);
float min_freezed_output = qconvolution->get_freezed_output_min();
float max_freezed_output = qconvolution->get_freezed_output_max();
auto functor = [&, conv_index, min_freezed_output, max_freezed_output](
CPURuntimeContext* ctx) {
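// Rebind the mkldnn memory primitives to the current tensor buffers,
// write the pass-through output min/max scalars, then run the primitive.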
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[0], arg0_tensor);
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[1], arg1_tensor);
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[2], out_tensor);
*(static_cast<float*>(out1_tensor)) = min_freezed_output;
*(static_cast<float*>(out2_tensor)) = max_freezed_output;
cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, conv_index);
};
functors.emplace_back(functor);
}
else
{
throw ngraph_error("unsupported parameters for QuantizedConvolution via DEX");
}
}
REGISTER_OP_BUILDER(QuantizedConvolution);
}
}
}
@@ -2631,6 +2631,37 @@ namespace ngraph
}
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::QuantizedConvolution)
{
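// Codegen path: emit C++ source that performs the same primitive
// setup and invocation as the DEX builder above.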
auto qconvolution = static_cast<const ngraph::op::QuantizedConvolution*>(node);
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto conv_index =
mkldnn_emitter->build_convolution<ngraph::op::QuantizedConvolution>(
node, args, out);
auto& deps = mkldnn_emitter->get_primitive_deps(conv_index);
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[0])
<< ", " << args[0].get_name() << ");\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[1])
<< ", " << args[1].get_name() << ");\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[2])
<< ", " << out[0].get_name() << ");\n";
writer << "*(" << out[1].get_name()
<< ") = " << qconvolution->get_freezed_output_min() << ";\n";
writer << "*(" << out[2].get_name()
<< ") = " << qconvolution->get_freezed_output_max() << ";\n";
writer << "cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, "
<< to_string(conv_index) << ");\n";
}
else
{
throw ngraph_error("unsupported parameters for QuantizedConvolution");
}
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::GroupConvolution)
{
......
@@ -150,6 +150,7 @@
#include "ngraph/runtime/cpu/op/max_pool_with_indices.hpp"
#include "ngraph/runtime/cpu/op/quantize.hpp"
#include "ngraph/runtime/cpu/op/quantized_avg_pool.hpp"
#include "ngraph/runtime/cpu/op/quantized_conv.hpp"
#include "ngraph/runtime/cpu/op/quantized_max_pool.hpp"
#include "ngraph/runtime/cpu/op/rnn.hpp"
#include "ngraph/runtime/cpu/op/sigmoid.hpp"
@@ -305,6 +306,8 @@ static const runtime::cpu::OpMap dispatcher{
{TI(ngraph::op::GroupConvolution), &runtime::cpu::CPU_Emitter::emit<op::GroupConvolution>},
{TI(ngraph::op::ConvolutionBias), &runtime::cpu::CPU_Emitter::emit<op::ConvolutionBias>},
{TI(ngraph::op::ConvolutionRelu), &runtime::cpu::CPU_Emitter::emit<op::ConvolutionRelu>},
{TI(ngraph::op::QuantizedConvolution),
&runtime::cpu::CPU_Emitter::emit<op::QuantizedConvolution>},
{TI(ngraph::op::ConvolutionBiasAdd), &runtime::cpu::CPU_Emitter::emit<op::ConvolutionBiasAdd>},
// conv+bias backprop for data share the same implementation as ConvolutionBackpropData
{TI(ngraph::op::ConvolutionBiasBackpropFiltersBias),
......
@@ -297,6 +297,45 @@ size_t MKLDNNEmitter::build_convolution_forward(const mkldnn::memory::desc& inpu
return conv_index;
}
size_t MKLDNNEmitter::build_quantized_convolution(const mkldnn::memory::desc& input_data_desc,
const mkldnn::memory::desc& weights_desc,
const mkldnn::memory::desc& result_desc,
const ngraph::Strides& strides,
const ngraph::Strides& dilation_strides,
const ngraph::CoordinateDiff& padding_below,
const ngraph::CoordinateDiff& padding_above,
const float scale)
{
size_t input_data_index = build_memory_primitive(input_data_desc);
size_t weights_index = build_memory_primitive(weights_desc);
size_t result_index = build_memory_primitive(result_desc);
std::vector<float> output_scale;
output_scale.push_back(scale);
mkldnn::primitive_attr conv_attr;
/* Specify the rounding mode */
conv_attr.set_int_output_round_mode(mkldnn::round_mode::round_nearest);
/* Specify the scales array and the corresponding mask; mask 0 applies a single common scale to the whole output tensor */
conv_attr.set_output_scales(0, output_scale);
size_t conv_index = insert_primitive(new mkldnn::convolution_forward(
{{mkldnn::prop_kind::forward,
mkldnn::algorithm::convolution_direct,
input_data_desc,
weights_desc,
result_desc,
mkldnn::memory::dims(strides.begin(), strides.end()),
mkldnn::memory::dims(dilation_strides.begin(), dilation_strides.end()),
mkldnn::memory::dims(padding_below.begin(), padding_below.end()),
mkldnn::memory::dims(padding_above.begin(), padding_above.end()),
mkldnn::padding_kind::zero},
conv_attr,
mkldnn_utils::global_cpu_engine},
*m_mkldnn_primitives[input_data_index],
*m_mkldnn_primitives[weights_index],
*m_mkldnn_primitives[result_index]));
m_primitive_deps[conv_index] = {input_data_index, weights_index, result_index};
return conv_index;
}
size_t MKLDNNEmitter::build_convolution_forward(const mkldnn::memory::desc& input_data_desc,
const mkldnn::memory::desc& weights_desc,
const mkldnn::memory::desc& bias_desc,
......
@@ -32,6 +32,8 @@
#include "ngraph/runtime/cpu/op/conv_add.hpp"
#include "ngraph/runtime/cpu/op/conv_bias.hpp"
#include "ngraph/runtime/cpu/op/conv_relu.hpp"
#include "ngraph/runtime/cpu/op/quantized_conv.hpp"
#include "ngraph/runtime/cpu/quantization_util.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/strides.hpp"
#include "ngraph/type/element_type.hpp"
@@ -104,6 +106,15 @@ namespace ngraph
const ngraph::CoordinateDiff& padding_above,
const mkldnn::post_ops& pops = mkldnn::post_ops());
size_t build_quantized_convolution(const mkldnn::memory::desc& input_data_desc,
const mkldnn::memory::desc& weights_desc,
const mkldnn::memory::desc& result_desc,
const ngraph::Strides& strides,
const ngraph::Strides& dilation_strides,
const ngraph::CoordinateDiff& padding_below,
const ngraph::CoordinateDiff& padding_above,
const float scale);
template <typename OP>
size_t build_convolution(const ngraph::Node* node,
const std::vector<TensorViewWrapper>& args,
@@ -185,6 +196,19 @@ namespace ngraph
convolution->get_padding_above(),
ops);
}
else if (std::is_same<OP, ngraph::op::QuantizedConvolution>())
{
const float scale = quantization_util::get_scale(node);
return build_quantized_convolution(
data_desc,
weights_desc,
result_desc,
convolution->get_window_movement_strides(),
window_dilation_strides_adjusted,
convolution->get_padding_below(),
convolution->get_padding_above(),
scale);
}
else
{
return build_convolution_forward(data_desc,
......
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "quantized_conv.hpp"
#include <numeric>
#include "ngraph/op/constant.hpp"
#include "ngraph/op/convolution.hpp"
#include "ngraph/util.hpp"
using namespace std;
using namespace ngraph;
op::QuantizedConvolution::QuantizedConvolution(const shared_ptr<Node>& data_batch,
const shared_ptr<Node>& filters,
const Strides& window_movement_strides,
const Strides& window_dilation_strides,
const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above,
const Strides& data_dilation_strides,
const std::shared_ptr<Node> min_input,
const std::shared_ptr<Node> max_input,
const std::shared_ptr<Node> min_filter,
const std::shared_ptr<Node> max_filter,
const std::shared_ptr<Node> min_freezed_output,
const std::shared_ptr<Node> max_freezed_output)
: Op("QuantizedConvolution",
check_single_output_args({data_batch,
filters,
min_input,
max_input,
min_filter,
max_filter,
min_freezed_output,
max_freezed_output}))
, m_window_movement_strides(window_movement_strides)
, m_window_dilation_strides(window_dilation_strides)
, m_padding_below(padding_below)
, m_padding_above(padding_above)
, m_data_dilation_strides(data_dilation_strides)
{
constructor_validate_and_infer_types();
// TODO(nbpatel): Add checks.
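// Note: the casts below assume that all six min/max arguments are Constant nodes.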
auto& data_batch_shape = data_batch->get_shape();
auto& filters_shape = filters->get_shape();
auto min_input_const_op = std::static_pointer_cast<ngraph::op::Constant>(min_input);
auto max_input_const_op = std::static_pointer_cast<ngraph::op::Constant>(max_input);
auto min_filter_const_op = std::static_pointer_cast<ngraph::op::Constant>(min_filter);
auto max_filter_const_op = std::static_pointer_cast<ngraph::op::Constant>(max_filter);
auto min_freezed_output_const_op =
std::static_pointer_cast<ngraph::op::Constant>(min_freezed_output);
auto max_freezed_output_const_op =
std::static_pointer_cast<ngraph::op::Constant>(max_freezed_output);
float input_min = *(static_cast<float const*>(min_input_const_op->get_data_ptr()));
float input_max = *(static_cast<float const*>(max_input_const_op->get_data_ptr()));
float filter_min = *(static_cast<float const*>(min_filter_const_op->get_data_ptr()));
float filter_max = *(static_cast<float const*>(max_filter_const_op->get_data_ptr()));
float output_min = *(static_cast<float const*>(min_freezed_output_const_op->get_data_ptr()));
float output_max = *(static_cast<float const*>(max_freezed_output_const_op->get_data_ptr()));
this->m_input_min = input_min;
this->m_input_max = input_max;
this->m_filter_min = filter_min;
this->m_filter_max = filter_max;
this->m_freezed_output_min = output_min;
this->m_freezed_output_max = output_max;
set_output_size(3);
set_output_type(0,
element::i8,
util::infer_convolution_output_shape(this,
data_batch_shape,
filters_shape,
window_movement_strides,
window_dilation_strides,
padding_below,
padding_above,
data_dilation_strides,
0, /* batch_axis_data, */
1, /* input_channel_axis_data, */
1, /* input_channel_axis_filters, */
0, /* output_channel_axis_filters, */
0, /* batch_axis_result, */
1 /* output_channel_axis_result, */
));
set_output_type(1, element::f32, Shape{1});
set_output_type(2, element::f32, Shape{1});
}
shared_ptr<Node> op::QuantizedConvolution::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 8)
{
throw ngraph_error("Incorrect number of new arguments");
}
return shared_ptr<Node>(new QuantizedConvolution(new_args.at(0),
new_args.at(1),
get_window_movement_strides(),
get_window_dilation_strides(),
get_padding_below(),
get_padding_above(),
get_data_dilation_strides(),
new_args.at(2),
new_args.at(3),
new_args.at(4),
new_args.at(5),
new_args.at(6),
new_args.at(7)));
}
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/op.hpp"
namespace ngraph
{
namespace op
{
class QuantizedConvolution : public Op
{
public:
QuantizedConvolution(const std::shared_ptr<Node>& data_batch,
const std::shared_ptr<Node>& filters,
const Strides& window_movement_strides,
const Strides& window_dilation_strides,
const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above,
const Strides& data_dilation_strides,
const std::shared_ptr<Node> min_input,
const std::shared_ptr<Node> max_input,
const std::shared_ptr<Node> min_filter,
const std::shared_ptr<Node> max_filter,
const std::shared_ptr<Node> min_freezed_output,
const std::shared_ptr<Node> max_freezed_output);
const Strides& get_window_movement_strides() const { return m_window_movement_strides; }
const Strides& get_window_dilation_strides() const { return m_window_dilation_strides; }
const CoordinateDiff& get_padding_below() const { return m_padding_below; }
const CoordinateDiff& get_padding_above() const { return m_padding_above; }
const Strides& get_data_dilation_strides() const { return m_data_dilation_strides; }
std::shared_ptr<Node> get_filters() { return get_argument(1); }
std::shared_ptr<Node> get_data_batch() { return get_argument(0); }
float get_input_min() const { return m_input_min; }
float get_input_max() const { return m_input_max; }
float get_filter_min() const { return m_filter_min; }
float get_filter_max() const { return m_filter_max; }
float get_freezed_output_min() const { return m_freezed_output_min; }
float get_freezed_output_max() const { return m_freezed_output_max; }
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
protected:
Strides m_window_movement_strides;
Strides m_window_dilation_strides;
CoordinateDiff m_padding_below;
CoordinateDiff m_padding_above;
Strides m_data_dilation_strides;
float m_input_min;
float m_input_max;
float m_filter_min;
float m_filter_max;
float m_freezed_output_min;
float m_freezed_output_max;
};
}
}
@@ -48,6 +48,7 @@
#include "ngraph/runtime/cpu/op/max_pool_with_indices.hpp"
#include "ngraph/runtime/cpu/op/quantize.hpp"
#include "ngraph/runtime/cpu/op/quantized_avg_pool.hpp"
#include "ngraph/runtime/cpu/op/quantized_conv.hpp"
#include "ngraph/runtime/cpu/op/quantized_max_pool.hpp"
#include "ngraph/runtime/cpu/op/rnn.hpp"
#include "ngraph/runtime/cpu/op/sigmoid.hpp"
@@ -728,6 +729,20 @@ namespace ngraph
}
}
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::QuantizedConvolution)
{
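// mkldnn int8 convolution takes u8 activations and i8 weights, so the
// mkldnn kernel is assigned only for that signature.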
if (node->get_input_element_type(0) == element::u8 &&
node->get_input_element_type(1) == element::i8)
{
auto quantized_conv = static_cast<op::QuantizedConvolution*>(node);
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
quantized_conv->set_op_annotations(op_annotations);
}
}
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::Quantize)
{
@@ -781,6 +796,8 @@ static const runtime::cpu::pass::AssignOpMap s_dispatcher{
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::MaxPoolWithIndicesBackprop>},
{TI(ngraph::op::ConvolutionBias),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::ConvolutionBias>},
{TI(ngraph::op::QuantizedConvolution),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::QuantizedConvolution>},
{TI(ngraph::op::ConvolutionBiasBackpropFiltersBias),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::ConvolutionBiasBackpropFiltersBias>},
{TI(ngraph::op::LRN), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::LRN>},
......
@@ -56,6 +56,7 @@
#include "ngraph/runtime/cpu/op/max_pool_with_indices.hpp"
#include "ngraph/runtime/cpu/op/quantize.hpp"
#include "ngraph/runtime/cpu/op/quantized_avg_pool.hpp"
#include "ngraph/runtime/cpu/op/quantized_conv.hpp"
#include "ngraph/runtime/cpu/op/quantized_max_pool.hpp"
#include "ngraph/runtime/cpu/op/rnn.hpp"
@@ -297,6 +298,11 @@ namespace ngraph
memory::data_type et =
mkldnn_utils::get_mkldnn_data_type(node->get_input_element_type(0));
memory::data_type et_weights = runtime::cpu::mkldnn_utils::get_mkldnn_data_type(
node->get_input_element_type(1));
memory::data_type et_result = runtime::cpu::mkldnn_utils::get_mkldnn_data_type(
node->get_output_element_type(0));
engine cpu_engine(engine::cpu, 0);
memory::dims mkldnn_arg0_shape(arg0_shape.begin(), arg0_shape.end());
memory::dims mkldnn_arg1_shape(arg1_shape.begin(), arg1_shape.end());
@@ -308,8 +314,10 @@
memory::dims mkldnn_padding_below(padding_below.begin(), padding_below.end());
memory::dims mkldnn_padding_above(padding_above.begin(), padding_above.end());
const memory::desc input_data_desc(mkldnn_arg0_shape, et, memory::format::any);
-const memory::desc weights_desc(mkldnn_arg1_shape, et, memory::format::any);
-const memory::desc result_desc(mkldnn_result_shape, et, memory::format::any);
+const memory::desc weights_desc(
+    mkldnn_arg1_shape, et_weights, memory::format::any);
+const memory::desc result_desc(
+    mkldnn_result_shape, et_result, memory::format::any);
std::unique_ptr<convolution_forward::desc> fwd_desc{nullptr};
if (use_bias)
{
@@ -384,6 +392,48 @@ namespace ngraph
o_mds.push_back(prim_desc.dst_primitive_desc().desc());
}
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::QuantizedConvolution)
{
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node.get()))
{
vector<memory::desc> i_mds;
vector<memory::desc> o_mds;
ConvolutionLayout<ngraph::op::QuantizedConvolution, false, false>(
node, i_mds, o_mds);
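// The six min/max inputs (args 2-7) and the two min/max outputs are
// scalar f32 tensors; give them default one-dimensional ('x') layouts.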
auto min_input_md = mkldnn_utils::create_default_mkldnn_md(
node.get(), 2, false, memory::format::x);
auto max_input_md = mkldnn_utils::create_default_mkldnn_md(
node.get(), 3, false, memory::format::x);
auto min_filter_md = mkldnn_utils::create_default_mkldnn_md(
node.get(), 4, false, memory::format::x);
auto max_filter_md = mkldnn_utils::create_default_mkldnn_md(
node.get(), 5, false, memory::format::x);
auto min_freezed_output_md = mkldnn_utils::create_default_mkldnn_md(
node.get(), 6, false, memory::format::x);
auto max_freezed_output_md = mkldnn_utils::create_default_mkldnn_md(
node.get(), 7, false, memory::format::x);
auto min_output_md = mkldnn_utils::create_default_mkldnn_md(
node.get(), 1, true, memory::format::x);
auto max_output_md = mkldnn_utils::create_default_mkldnn_md(
node.get(), 2, true, memory::format::x);
i_mds.push_back(min_input_md);
i_mds.push_back(max_input_md);
i_mds.push_back(min_filter_md);
i_mds.push_back(max_filter_md);
i_mds.push_back(min_freezed_output_md);
i_mds.push_back(max_freezed_output_md);
o_mds.push_back(min_output_md);
o_mds.push_back(max_output_md);
node = insert_input_conversions(external_function, node, i_mds);
set_output_layouts(node, o_mds);
}
else
{
set_native_layouts(external_function, node);
}
}
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::Convolution)
{
@@ -1682,6 +1732,8 @@ static const runtime::cpu::pass::LayoutOpMap s_dispatcher{
{TI(ngraph::op::AvgPool), &runtime::cpu::pass::CPULayout::layout<ngraph::op::AvgPool>},
{TI(ngraph::op::AvgPoolBackprop),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::AvgPoolBackprop>},
{TI(ngraph::op::QuantizedConvolution),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::QuantizedConvolution>},
{TI(ngraph::op::Convolution), &runtime::cpu::pass::CPULayout::layout<ngraph::op::Convolution>},
{TI(ngraph::op::GroupConvolution),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::GroupConvolution>},
......
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "quantization_util.hpp"
#include "ngraph/runtime/cpu/op/quantized_conv.hpp"
namespace ngraph
{
namespace runtime
{
namespace cpu
{
namespace quantization_util
{
float get_scale(const ngraph::Node* node)
{
auto qconvolution = static_cast<const ngraph::op::QuantizedConvolution*>(node);
float min_out_value;
float max_out_value;
quantization_range_for_multiplication<uint8_t, int8_t, int32_t>(
qconvolution->get_input_min(),
qconvolution->get_input_max(),
qconvolution->get_filter_min(),
qconvolution->get_filter_max(),
&min_out_value,
&max_out_value);
const float max_abs32 =
std::max(std::abs(min_out_value), std::abs(max_out_value));
const float max_abs8 =
std::max(std::abs(qconvolution->get_freezed_output_min()),
std::abs(qconvolution->get_freezed_output_max()));
// Output is a signed integer.
// s32 = f32 * std::pow(2, 31) / max_abs32;
// s8  = f32 * std::pow(2, 7) / max_abs8;
// Hence: s8 = s32 * std::pow(2, -24) * max_abs32 / max_abs8;
const float scale = static_cast<float>(
(std::pow(2, -24) * static_cast<double>(max_abs32 / max_abs8)));
return scale;
}
}
}
}
}
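As a sanity check (mine, not part of the commit), here is get_scale instantiated with the ranges the quantizedConv2D_small test below feeds in: input [0, 255] as u8, filter [-127, 127] as i8, frozen output [22, 90].

#include <algorithm>
#include <cmath>
#include <iostream>

int main()
{
    // Mirrors quantization_range_for_multiplication<uint8_t, int8_t, int32_t>.
    const float a_one = (255.0f - 0.0f) / 255.0f;   // u8 quant level: 1.0
    const float b_one = (127.0f + 127.0f) / 255.0f; // i8 quant level: ~0.996
    const float c_one = a_one * b_one;
    const float max_abs32 = c_one * 2147483648.0f;  // |int32 min| dominates: ~2.139e9
    const float max_abs8 = std::max(std::abs(22.0f), std::abs(90.0f)); // 90
    const float scale = std::pow(2.0f, -24.0f) * max_abs32 / max_abs8;
    std::cout << scale << std::endl;                // prints ~1.4166
    return 0;
}

So every accumulated int32 value is multiplied by roughly 1.4166 and rounded to the nearest integer to produce the int8 output.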
@@ -24,12 +24,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#pragma once
#include <limits>
#include <vector>
#include "ngraph/node.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/runtime/cpu/op/quantize.hpp"
#include "ngraph/util.hpp"
namespace ngraph
{
namespace runtime
@@ -38,6 +42,23 @@ namespace ngraph
{
namespace quantization_util
{
template <class T1, class T2, class T3>
void quantization_range_for_multiplication(
float min_a, float max_a, float min_b, float max_b, float* min_c, float* max_c)
{
// begin code copied and pasted (and modified) from
// github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/quantization_utils.h
float a_one_quant_level = (max_a - min_a) / (std::numeric_limits<T1>::max() -
std::numeric_limits<T1>::min());
float b_one_quant_level = (max_b - min_b) / (std::numeric_limits<T2>::max() -
std::numeric_limits<T2>::min());
float c_one_quant_level = a_one_quant_level * b_one_quant_level;
*min_c = c_one_quant_level * std::numeric_limits<T3>::min();
*max_c = c_one_quant_level * std::numeric_limits<T3>::max();
// end code copied and pasted (and modified) from
// github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/quantization_utils.h
}
static inline void get_min_max_range(float input_min_range,
float input_max_range,
bool is_signed,
@@ -67,6 +88,8 @@ namespace ngraph
quant_util.push_back(max_range);
quant_util.push_back(scale);
}
float get_scale(const ngraph::Node* node);
}
}
}
......
@@ -26,6 +26,7 @@
#include "ngraph/runtime/cpu/op/dequantize.hpp"
#include "ngraph/runtime/cpu/op/quantize.hpp"
#include "ngraph/runtime/cpu/op/quantized_avg_pool.hpp"
#include "ngraph/runtime/cpu/op/quantized_conv.hpp"
#include "ngraph/runtime/cpu/op/quantized_max_pool.hpp"
#include "util/all_close.hpp"
#include "util/all_close_f.hpp"
@@ -204,6 +205,55 @@ TEST(quantize_cpu, dequantize_from_int8)
DequantizeTest<int8_t>(42, -1.0f, 300.0f, static_cast<float>(99.212601));
}
TEST(quantize_cpu, quantizedConv2D_small)
{
Shape shape_a{1, 1, 3, 4}; // input shape
Shape shape_b{1, 1, 3, 3}; // filter shape
Shape shape_r{1, 1, 3, 4}; // output shape
vector<uint8_t> a_data = {1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4};
vector<int8_t> b_data = {1, 2, 3, 4, 5, 0, 0, 1, 2};
auto A = make_shared<op::Parameter>(element::u8, shape_a);
auto B = make_shared<op::Parameter>(element::i8, shape_b);
auto C = op::Constant::create(element::f32, Shape{1}, {0.0f});
auto D = op::Constant::create(element::f32, Shape{1}, {255.0f});
auto E = op::Constant::create(element::f32, Shape{1}, {-127.0f});
auto F = op::Constant::create(element::f32, Shape{1}, {127.0f});
auto G = op::Constant::create(element::f32, Shape{1}, {22.0f});
auto H = op::Constant::create(element::f32, Shape{1}, {90.0f});
auto CV = make_shared<op::QuantizedConvolution>(A,
B,
Strides{1, 1}, // move_strides
Strides{1, 1}, // filter_dilation
CoordinateDiff{1, 1}, // below_pads
CoordinateDiff{1, 1}, // above_pads
Strides{1, 1}, // data_dilation
C,
D,
E,
F,
G,
H);
auto output_data = std::make_shared<op::GetOutputElement>(CV, 0);
auto output_min = std::make_shared<op::GetOutputElement>(CV, 1);
auto output_max = std::make_shared<op::GetOutputElement>(CV, 2);
auto f = make_shared<Function>(NodeVector{output_data, output_min, output_max},
op::ParameterVector{A, B});
auto backend = runtime::Backend::create("CPU");
// Create some tensors for input/output
auto a = backend->create_tensor(element::u8, shape_a);
copy_data(a, a_data);
auto b = backend->create_tensor(element::i8, shape_b);
copy_data(b, b_data);
auto result = backend->create_tensor(element::i8, shape_r);
auto result_min = backend->create_tensor(element::f32, Shape{1});
auto result_max = backend->create_tensor(element::f32, Shape{1});
backend->call_with_validate(f, {result, result_min, result_max}, {a, b});
EXPECT_EQ((vector<int8_t>{31, 48, 42, 45, 54, 102, 127, 61, 47, 74, 61, 55}),
read_vector<int8_t>(result));
EXPECT_EQ((vector<float>{22.0}), read_vector<float>(result_min));
EXPECT_EQ((vector<float>{90.0}), read_vector<float>(result_max));
}
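As a plausibility check on the expected values (my arithmetic, not part of the commit): the requantization scale for these ranges is approximately 2^-24 * 2.139e9 / 90 ≈ 1.4166 (see the sketch after quantization_util.cpp above), and the raw int32 sum at output position (0, 0), where the zero padding leaves only the bottom-right 2x2 of the filter over the input corner, is 5*1 + 0*2 + 1*5 + 2*6 = 22; 22 * 1.4166 ≈ 31.2, which rounds to 31, the first expected element.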
TEST(quantize_cpu, quantize_to_uint8_small)
{
vector<float> a_data = {-85.0, 0.0, 2.0, 10.0, 15.0};
......