Commit b5f43973 authored by Nishant Patel, committed by Scott Cyphers

Add support for Quantized Convolution op via mkldnn for IA backend (codegen + DEX) (#1620)

* Add support for Quantized Convolution op via mkldnn for IA backend (codegen + DEX)

* Use call_with_validate

* Style fix

* Fix clang compile errors
parent 28228857
@@ -24,6 +24,7 @@ set(SRC
cpu_tensor_view_wrapper.cpp
cpu_tensor_view.cpp
cpu_tracing.cpp
quantization_util.cpp
builder/add.cpp
builder/allreduce.cpp
builder/avg_pool.cpp
@@ -35,6 +36,7 @@ set(SRC
builder/concat.cpp
builder/convert.cpp
builder/convert_layout.cpp
builder/quantized_conv.cpp
builder/convolution.cpp
builder/dequantize.cpp
builder/quantize.cpp
@@ -80,6 +82,7 @@ set(SRC
op/group_conv.cpp
op/conv_bias.cpp
op/conv_relu.cpp
op/quantized_conv.cpp
op/convert_layout.cpp
op/dequantize.cpp
op/quantize.cpp
......
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "ngraph/runtime/cpu/op/quantized_conv.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/runtime/cpu/cpu_builder.hpp"
#include "ngraph/runtime/cpu/mkldnn_invoke.hpp"
#include "ngraph/runtime/cpu/mkldnn_utils.hpp"
using namespace std;
using namespace ngraph;
namespace ngraph
{
namespace runtime
{
namespace cpu
{
template <>
void Builder::BUILDER_DECL(ngraph::op::QuantizedConvolution)
{
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{
auto qconvolution = static_cast<const ngraph::op::QuantizedConvolution*>(node);
auto& functors = external_function->get_functors();
auto& arg0_tensor = external_function->get_tensor_data(args[0].get_name());
auto& arg1_tensor = external_function->get_tensor_data(args[1].get_name());
auto& out_tensor = external_function->get_tensor_data(out[0].get_name());
auto& out1_tensor = external_function->get_tensor_data(out[1].get_name());
auto& out2_tensor = external_function->get_tensor_data(out[2].get_name());
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto conv_index =
mkldnn_emitter->build_convolution<ngraph::op::QuantizedConvolution>(
node, args, out);
auto& deps = mkldnn_emitter->get_primitive_deps(conv_index);
float min_freezed_output = qconvolution->get_freezed_output_min();
float max_freezed_output = qconvolution->get_freezed_output_max();
auto functor = [&, conv_index, min_freezed_output, max_freezed_output](
CPURuntimeContext* ctx) {
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[0], arg0_tensor);
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[1], arg1_tensor);
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[2], out_tensor);
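// Outputs 1 and 2 carry the frozen output range; min_freezed_output and
// max_freezed_output were captured by value when this functor was built.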
*(static_cast<float*>(out1_tensor)) = min_freezed_output;
*(static_cast<float*>(out2_tensor)) = max_freezed_output;
cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, conv_index);
};
functors.emplace_back(functor);
}
else
{
throw ngraph_error("unsupported parameters for QuantizedConvolution via DEX");
}
}
REGISTER_OP_BUILDER(QuantizedConvolution);
}
}
}
@@ -2631,6 +2631,37 @@ namespace ngraph
}
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::QuantizedConvolution)
{
auto qconvolution = static_cast<const ngraph::op::QuantizedConvolution*>(node);
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto conv_index =
mkldnn_emitter->build_convolution<ngraph::op::QuantizedConvolution>(
node, args, out);
auto& deps = mkldnn_emitter->get_primitive_deps(conv_index);
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[0])
<< ", " << args[0].get_name() << ");\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[1])
<< ", " << args[1].get_name() << ");\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[2])
<< ", " << out[0].get_name() << ");\n";
writer << "*(" << out[1].get_name()
<< ") = " << qconvolution->get_freezed_output_min() << ";\n";
writer << "*(" << out[2].get_name()
<< ") = " << qconvolution->get_freezed_output_max() << ";\n";
writer << "cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, "
<< to_string(conv_index) << ");\n";
}
else
{
throw ngraph_error("unsupported parameters for QuantizedConvolution");
}
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::GroupConvolution)
{
......
@@ -150,6 +150,7 @@
#include "ngraph/runtime/cpu/op/max_pool_with_indices.hpp"
#include "ngraph/runtime/cpu/op/quantize.hpp"
#include "ngraph/runtime/cpu/op/quantized_avg_pool.hpp"
#include "ngraph/runtime/cpu/op/quantized_conv.hpp"
#include "ngraph/runtime/cpu/op/quantized_max_pool.hpp"
#include "ngraph/runtime/cpu/op/rnn.hpp"
#include "ngraph/runtime/cpu/op/sigmoid.hpp"
@@ -305,6 +306,8 @@ static const runtime::cpu::OpMap dispatcher{
{TI(ngraph::op::GroupConvolution), &runtime::cpu::CPU_Emitter::emit<op::GroupConvolution>},
{TI(ngraph::op::ConvolutionBias), &runtime::cpu::CPU_Emitter::emit<op::ConvolutionBias>},
{TI(ngraph::op::ConvolutionRelu), &runtime::cpu::CPU_Emitter::emit<op::ConvolutionRelu>},
{TI(ngraph::op::QuantizedConvolution),
&runtime::cpu::CPU_Emitter::emit<op::QuantizedConvolution>},
{TI(ngraph::op::ConvolutionBiasAdd), &runtime::cpu::CPU_Emitter::emit<op::ConvolutionBiasAdd>},
// conv+bias backprop for data shares the same implementation as ConvolutionBackpropData
{TI(ngraph::op::ConvolutionBiasBackpropFiltersBias),
......
@@ -297,6 +297,45 @@ size_t MKLDNNEmitter::build_convolution_forward(const mkldnn::memory::desc& input_data_desc,
return conv_index;
}
size_t MKLDNNEmitter::build_quantized_convolution(const mkldnn::memory::desc& input_data_desc,
const mkldnn::memory::desc& weights_desc,
const mkldnn::memory::desc& result_desc,
const ngraph::Strides& strides,
const ngraph::Strides& dilation_strides,
const ngraph::CoordinateDiff& padding_below,
const ngraph::CoordinateDiff& padding_above,
const float scale)
{
size_t input_data_index = build_memory_primitive(input_data_desc);
size_t weights_index = build_memory_primitive(weights_desc);
size_t result_index = build_memory_primitive(result_desc);
std::vector<float> output_scale;
output_scale.push_back(scale);
mkldnn::primitive_attr conv_attr;
/* Specify the rounding mode */
conv_attr.set_int_output_round_mode(mkldnn::round_mode::round_nearest);
/* Specify the scales array and corresponding mask */
conv_attr.set_output_scales(0, output_scale);
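// Mask 0 applies one common scale to the whole output tensor; a non-zero
// mask would select per-dimension (e.g. per-output-channel) scales.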
size_t conv_index = insert_primitive(new mkldnn::convolution_forward(
{{mkldnn::prop_kind::forward,
mkldnn::algorithm::convolution_direct,
input_data_desc,
weights_desc,
result_desc,
mkldnn::memory::dims(strides.begin(), strides.end()),
mkldnn::memory::dims(dilation_strides.begin(), dilation_strides.end()),
mkldnn::memory::dims(padding_below.begin(), padding_below.end()),
mkldnn::memory::dims(padding_above.begin(), padding_above.end()),
mkldnn::padding_kind::zero},
conv_attr,
mkldnn_utils::global_cpu_engine},
*m_mkldnn_primitives[input_data_index],
*m_mkldnn_primitives[weights_index],
*m_mkldnn_primitives[result_index]));
m_primitive_deps[conv_index] = {input_data_index, weights_index, result_index};
return conv_index;
}
size_t MKLDNNEmitter::build_convolution_forward(const mkldnn::memory::desc& input_data_desc,
const mkldnn::memory::desc& weights_desc,
const mkldnn::memory::desc& bias_desc,
......
@@ -32,6 +32,8 @@
#include "ngraph/runtime/cpu/op/conv_add.hpp"
#include "ngraph/runtime/cpu/op/conv_bias.hpp"
#include "ngraph/runtime/cpu/op/conv_relu.hpp"
#include "ngraph/runtime/cpu/op/quantized_conv.hpp"
#include "ngraph/runtime/cpu/quantization_util.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/strides.hpp"
#include "ngraph/type/element_type.hpp"
@@ -104,6 +106,15 @@ namespace ngraph
const ngraph::CoordinateDiff& padding_above,
const mkldnn::post_ops& pops = mkldnn::post_ops());
size_t build_quantized_convolution(const mkldnn::memory::desc& input_data_desc,
const mkldnn::memory::desc& weights_desc,
const mkldnn::memory::desc& result_desc,
const ngraph::Strides& strides,
const ngraph::Strides& dilation_strides,
const ngraph::CoordinateDiff& padding_below,
const ngraph::CoordinateDiff& padding_above,
const float scale);
template <typename OP>
size_t build_convolution(const ngraph::Node* node,
const std::vector<TensorViewWrapper>& args,
@@ -185,6 +196,19 @@ namespace ngraph
convolution->get_padding_above(),
ops);
}
else if (std::is_same<OP, ngraph::op::QuantizedConvolution>())
{
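// get_scale() folds the i32 accumulator range implied by the input and
// filter ranges into a single requantization factor for the frozen i8 output.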
const float scale = quantization_util::get_scale(node);
return build_quantized_convolution(
data_desc,
weights_desc,
result_desc,
convolution->get_window_movement_strides(),
window_dilation_strides_adjusted,
convolution->get_padding_below(),
convolution->get_padding_above(),
scale);
}
else
{
return build_convolution_forward(data_desc,
......
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "quantized_conv.hpp"
#include <numeric>
#include "ngraph/op/constant.hpp"
#include "ngraph/op/convolution.hpp"
#include "ngraph/util.hpp"
using namespace std;
using namespace ngraph;
op::QuantizedConvolution::QuantizedConvolution(const shared_ptr<Node>& data_batch,
const shared_ptr<Node>& filters,
const Strides& window_movement_strides,
const Strides& window_dilation_strides,
const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above,
const Strides& data_dilation_strides,
const std::shared_ptr<Node> min_input,
const std::shared_ptr<Node> max_input,
const std::shared_ptr<Node> min_filter,
const std::shared_ptr<Node> max_filter,
const std::shared_ptr<Node> min_freezed_output,
const std::shared_ptr<Node> max_freezed_output)
: Op("QuantizedConvolution",
check_single_output_args({data_batch,
filters,
min_input,
max_input,
min_filter,
max_filter,
min_freezed_output,
max_freezed_output}))
, m_window_movement_strides(window_movement_strides)
, m_window_dilation_strides(window_dilation_strides)
, m_padding_below(padding_below)
, m_padding_above(padding_above)
, m_data_dilation_strides(data_dilation_strides)
{
constructor_validate_and_infer_types();
//TODO(nbpatel): Add checks.
auto& data_batch_shape = data_batch->get_shape();
auto& filters_shape = filters->get_shape();
auto min_input_const_op = std::static_pointer_cast<ngraph::op::Constant>(min_input);
auto max_input_const_op = std::static_pointer_cast<ngraph::op::Constant>(max_input);
auto min_filter_const_op = std::static_pointer_cast<ngraph::op::Constant>(min_filter);
auto max_filter_const_op = std::static_pointer_cast<ngraph::op::Constant>(max_filter);
auto min_freezed_output_const_op =
std::static_pointer_cast<ngraph::op::Constant>(min_freezed_output);
auto max_freezed_output_const_op =
std::static_pointer_cast<ngraph::op::Constant>(max_freezed_output);
float input_min = *(static_cast<float const*>(min_input_const_op->get_data_ptr()));
float input_max = *(static_cast<float const*>(max_input_const_op->get_data_ptr()));
float filter_min = *(static_cast<float const*>(min_filter_const_op->get_data_ptr()));
float filter_max = *(static_cast<float const*>(max_filter_const_op->get_data_ptr()));
float output_min = *(static_cast<float const*>(min_freezed_output_const_op->get_data_ptr()));
float output_max = *(static_cast<float const*>(max_freezed_output_const_op->get_data_ptr()));
this->m_input_min = input_min;
this->m_input_max = input_max;
this->m_filter_min = filter_min;
this->m_filter_max = filter_max;
this->m_freezed_output_min = output_min;
this->m_freezed_output_max = output_max;
set_output_size(3);
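// Output 0 is the quantized i8 result; outputs 1 and 2 are one-element
// f32 tensors carrying the frozen output min and max.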
set_output_type(0,
element::i8,
util::infer_convolution_output_shape(this,
data_batch_shape,
filters_shape,
window_movement_strides,
window_dilation_strides,
padding_below,
padding_above,
data_dilation_strides,
0, /* batch_axis_data, */
1, /* input_channel_axis_data, */
1, /* input_channel_axis_filters, */
0, /* output_channel_axis_filters, */
0, /* batch_axis_result, */
1 /* output_channel_axis_result, */
));
set_output_type(1, element::f32, Shape{1});
set_output_type(2, element::f32, Shape{1});
}
shared_ptr<Node> op::QuantizedConvolution::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 8)
{
throw ngraph_error("Incorrect number of new arguments");
}
return shared_ptr<Node>(new QuantizedConvolution(new_args.at(0),
new_args.at(1),
get_window_movement_strides(),
get_window_dilation_strides(),
get_padding_below(),
get_padding_above(),
get_data_dilation_strides(),
new_args.at(2),
new_args.at(3),
new_args.at(4),
new_args.at(5),
new_args.at(6),
new_args.at(7)));
}
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/op.hpp"
namespace ngraph
{
namespace op
{
class QuantizedConvolution : public Op
{
public:
QuantizedConvolution(const std::shared_ptr<Node>& data_batch,
const std::shared_ptr<Node>& filters,
const Strides& window_movement_strides,
const Strides& window_dilation_strides,
const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above,
const Strides& data_dilation_strides,
const std::shared_ptr<Node> min_input,
const std::shared_ptr<Node> max_input,
const std::shared_ptr<Node> min_filter,
const std::shared_ptr<Node> max_filter,
const std::shared_ptr<Node> min_freezed_output,
const std::shared_ptr<Node> max_freezed_output);
const Strides& get_window_movement_strides() const { return m_window_movement_strides; }
const Strides& get_window_dilation_strides() const { return m_window_dilation_strides; }
const CoordinateDiff& get_padding_below() const { return m_padding_below; }
const CoordinateDiff& get_padding_above() const { return m_padding_above; }
const Strides& get_data_dilation_strides() const { return m_data_dilation_strides; }
std::shared_ptr<Node> get_filters() { return get_argument(1); }
std::shared_ptr<Node> get_data_batch() { return get_argument(0); }
float get_input_min() const { return m_input_min; }
float get_input_max() const { return m_input_max; }
float get_filter_min() const { return m_filter_min; }
float get_filter_max() const { return m_filter_max; }
float get_freezed_output_min() const { return m_freezed_output_min; }
float get_freezed_output_max() const { return m_freezed_output_max; }
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
protected:
Strides m_window_movement_strides;
Strides m_window_dilation_strides;
CoordinateDiff m_padding_below;
CoordinateDiff m_padding_above;
Strides m_data_dilation_strides;
float m_input_min;
float m_input_max;
float m_filter_min;
float m_filter_max;
float m_freezed_output_min;
float m_freezed_output_max;
};
}
}
@@ -48,6 +48,7 @@
#include "ngraph/runtime/cpu/op/max_pool_with_indices.hpp"
#include "ngraph/runtime/cpu/op/quantize.hpp"
#include "ngraph/runtime/cpu/op/quantized_avg_pool.hpp"
#include "ngraph/runtime/cpu/op/quantized_conv.hpp"
#include "ngraph/runtime/cpu/op/quantized_max_pool.hpp"
#include "ngraph/runtime/cpu/op/rnn.hpp"
#include "ngraph/runtime/cpu/op/sigmoid.hpp"
@@ -728,6 +729,20 @@ namespace ngraph
}
}
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::QuantizedConvolution)
{
if (node->get_input_element_type(0) == element::u8 &&
node->get_input_element_type(1) == element::i8)
{
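// Only the u8-activation / i8-weight combination is lowered to mkldnn here.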
auto quantized_conv = static_cast<op::QuantizedConvolution*>(node);
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
quantized_conv->set_op_annotations(op_annotations);
}
}
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::Quantize)
{
@@ -781,6 +796,8 @@ static const runtime::cpu::pass::AssignOpMap s_dispatcher{
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::MaxPoolWithIndicesBackprop>},
{TI(ngraph::op::ConvolutionBias),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::ConvolutionBias>},
{TI(ngraph::op::QuantizedConvolution),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::QuantizedConvolution>},
{TI(ngraph::op::ConvolutionBiasBackpropFiltersBias),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::ConvolutionBiasBackpropFiltersBias>},
{TI(ngraph::op::LRN), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::LRN>},
@@ -56,6 +56,7 @@
#include "ngraph/runtime/cpu/op/max_pool_with_indices.hpp"
#include "ngraph/runtime/cpu/op/quantize.hpp"
#include "ngraph/runtime/cpu/op/quantized_avg_pool.hpp"
#include "ngraph/runtime/cpu/op/quantized_conv.hpp"
#include "ngraph/runtime/cpu/op/quantized_max_pool.hpp"
#include "ngraph/runtime/cpu/op/rnn.hpp"
@@ -297,6 +298,11 @@ namespace ngraph
memory::data_type et =
mkldnn_utils::get_mkldnn_data_type(node->get_input_element_type(0));
memory::data_type et_weights = runtime::cpu::mkldnn_utils::get_mkldnn_data_type(
node->get_input_element_type(1));
memory::data_type et_result = runtime::cpu::mkldnn_utils::get_mkldnn_data_type(
node->get_output_element_type(0));
engine cpu_engine(engine::cpu, 0);
memory::dims mkldnn_arg0_shape(arg0_shape.begin(), arg0_shape.end());
memory::dims mkldnn_arg1_shape(arg1_shape.begin(), arg1_shape.end());
@@ -308,8 +314,10 @@ namespace ngraph
memory::dims mkldnn_padding_below(padding_below.begin(), padding_below.end());
memory::dims mkldnn_padding_above(padding_above.begin(), padding_above.end());
const memory::desc input_data_desc(mkldnn_arg0_shape, et, memory::format::any);
const memory::desc weights_desc(mkldnn_arg1_shape, et_weights, memory::format::any);
const memory::desc result_desc(mkldnn_result_shape, et_result, memory::format::any);
std::unique_ptr<convolution_forward::desc> fwd_desc{nullptr};
if (use_bias)
{
@@ -384,6 +392,48 @@ namespace ngraph
o_mds.push_back(prim_desc.dst_primitive_desc().desc());
}
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::QuantizedConvolution)
{
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node.get()))
{
vector<memory::desc> i_mds;
vector<memory::desc> o_mds;
ConvolutionLayout<ngraph::op::QuantizedConvolution, false, false>(
node, i_mds, o_mds);
auto min_input_md = mkldnn_utils::create_default_mkldnn_md(
node.get(), 2, false, memory::format::x);
auto max_input_md = mkldnn_utils::create_default_mkldnn_md(
node.get(), 3, false, memory::format::x);
auto min_filter_md = mkldnn_utils::create_default_mkldnn_md(
node.get(), 4, false, memory::format::x);
auto max_filter_md = mkldnn_utils::create_default_mkldnn_md(
node.get(), 5, false, memory::format::x);
auto min_freezed_output_md = mkldnn_utils::create_default_mkldnn_md(
node.get(), 6, false, memory::format::x);
auto max_freezed_output_md = mkldnn_utils::create_default_mkldnn_md(
node.get(), 7, false, memory::format::x);
auto min_output_md = mkldnn_utils::create_default_mkldnn_md(
node.get(), 1, true, memory::format::x);
auto max_output_md = mkldnn_utils::create_default_mkldnn_md(
node.get(), 2, true, memory::format::x);
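// The min/max scalars (inputs 2-7, outputs 1 and 2) are one-element f32
// tensors, so they keep the default one-dimensional 'x' layout.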
i_mds.push_back(min_input_md);
i_mds.push_back(max_input_md);
i_mds.push_back(min_filter_md);
i_mds.push_back(max_filter_md);
i_mds.push_back(min_freezed_output_md);
i_mds.push_back(max_freezed_output_md);
o_mds.push_back(min_output_md);
o_mds.push_back(max_output_md);
node = insert_input_conversions(external_function, node, i_mds);
set_output_layouts(node, o_mds);
}
else
{
set_native_layouts(external_function, node);
}
}
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::Convolution)
{
@@ -1682,6 +1732,8 @@ static const runtime::cpu::pass::LayoutOpMap s_dispatcher{
{TI(ngraph::op::AvgPool), &runtime::cpu::pass::CPULayout::layout<ngraph::op::AvgPool>},
{TI(ngraph::op::AvgPoolBackprop),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::AvgPoolBackprop>},
{TI(ngraph::op::QuantizedConvolution),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::QuantizedConvolution>},
{TI(ngraph::op::Convolution), &runtime::cpu::pass::CPULayout::layout<ngraph::op::Convolution>},
{TI(ngraph::op::GroupConvolution),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::GroupConvolution>},
......
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "quantization_util.hpp"
#include "ngraph/runtime/cpu/op/quantized_conv.hpp"
namespace ngraph
{
namespace runtime
{
namespace cpu
{
namespace quantization_util
{
float get_scale(const ngraph::Node* node)
{
auto qconvolution = static_cast<const ngraph::op::QuantizedConvolution*>(node);
float min_out_value;
float max_out_value;
quantization_range_for_multiplication<uint8_t, int8_t, int32_t>(
qconvolution->get_input_min(),
qconvolution->get_input_max(),
qconvolution->get_filter_min(),
qconvolution->get_filter_max(),
&min_out_value,
&max_out_value);
const float max_abs32 =
std::max(std::abs(min_out_value), std::abs(max_out_value));
const float max_abs8 =
std::max(std::abs(qconvolution->get_freezed_output_min()),
std::abs(qconvolution->get_freezed_output_max()));
// Output is signed int.
// s32 = f32 * std::pow(2, 31)/ max_abs32;
// s8 = f32 * std::pow(2, 7)/ max_abs8;
// s8 = s32 * std::pow(2, -24) * max_abs32 / max_abs8;
const float scale = static_cast<float>(
(std::pow(2, -24) * static_cast<double>(max_abs32 / max_abs8)));
return scale;
}
}
}
}
}
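For reference, here is a standalone sketch (not part of this commit) that reproduces the arithmetic in get_scale() using the ranges from the quantizedConv2D_small test below: u8 input [0, 255], i8 filter [-127, 127], frozen output [22, 90].

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <limits>

int main()
{
    // One quantization step per operand type, as in
    // quantization_range_for_multiplication<uint8_t, int8_t, int32_t>.
    float a_level = (255.0f - 0.0f) /
                    (std::numeric_limits<uint8_t>::max() - std::numeric_limits<uint8_t>::min());
    float b_level = (127.0f - (-127.0f)) /
                    (std::numeric_limits<int8_t>::max() - std::numeric_limits<int8_t>::min());
    float c_level = a_level * b_level;
    // Bounds of the i32 accumulator expressed as real values.
    float min32 = c_level * std::numeric_limits<int32_t>::min();
    float max32 = c_level * std::numeric_limits<int32_t>::max();
    float max_abs32 = std::max(std::abs(min32), std::abs(max32)); // ~0.996 * 2^31
    float max_abs8 = std::max(std::abs(22.0f), std::abs(90.0f));  // frozen range -> 90
    // s8 = s32 * 2^-24 * max_abs32 / max_abs8 (see the comments in get_scale above).
    float scale = static_cast<float>(std::pow(2, -24) * (max_abs32 / max_abs8));
    std::printf("requantization scale = %f\n", scale); // ~1.417
    return 0;
}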
@@ -24,12 +24,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#pragma once
#include <limits>
#include <vector>
#include "ngraph/node.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/runtime/cpu/op/quantize.hpp"
#include "ngraph/util.hpp"
namespace ngraph
{
namespace runtime
@@ -38,6 +42,23 @@ namespace ngraph
{
namespace quantization_util
{
template <class T1, class T2, class T3>
void quantization_range_for_multiplication(
float min_a, float max_a, float min_b, float max_b, float* min_c, float* max_c)
{
// begin code copied and pasted (and modified) from
// github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/quantization_utils.h
float a_one_quant_level = (max_a - min_a) / (std::numeric_limits<T1>::max() -
std::numeric_limits<T1>::min());
float b_one_quant_level = (max_b - min_b) / (std::numeric_limits<T2>::max() -
std::numeric_limits<T2>::min());
float c_one_quant_level = a_one_quant_level * b_one_quant_level;
*min_c = c_one_quant_level * std::numeric_limits<T3>::min();
*max_c = c_one_quant_level * std::numeric_limits<T3>::max();
// end code copied and pasted (and modified) from
// github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/quantization_utils.h
}
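// Example: for u8 data and i8 weights accumulating into i32, the ranges
// ([0, 255], [-127, 127]) give a c_one_quant_level of roughly 0.996, so
// *min_c and *max_c come out near +/-0.996 * 2^31.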
static inline void get_min_max_range(float input_min_range,
float input_max_range,
bool is_signed,
@@ -67,6 +88,8 @@ namespace ngraph
quant_util.push_back(max_range);
quant_util.push_back(scale);
}
float get_scale(const ngraph::Node* node);
}
}
}
......
@@ -26,6 +26,7 @@
#include "ngraph/runtime/cpu/op/dequantize.hpp"
#include "ngraph/runtime/cpu/op/quantize.hpp"
#include "ngraph/runtime/cpu/op/quantized_avg_pool.hpp"
#include "ngraph/runtime/cpu/op/quantized_conv.hpp"
#include "ngraph/runtime/cpu/op/quantized_max_pool.hpp"
#include "util/all_close.hpp"
#include "util/all_close_f.hpp"
@@ -204,6 +205,55 @@ TEST(quantize_cpu, dequantize_from_int8)
DequantizeTest<int8_t>(42, -1.0f, 300.0f, static_cast<float>(99.212601));
}
TEST(quantize_cpu, quantizedConv2D_small)
{
Shape shape_a{1, 1, 3, 4}; // input shape
Shape shape_b{1, 1, 3, 3}; // filter shape
Shape shape_r{1, 1, 3, 4}; // output shape
vector<uint8_t> a_data = {1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4};
vector<int8_t> b_data = {1, 2, 3, 4, 5, 0, 0, 1, 2};
auto A = make_shared<op::Parameter>(element::u8, shape_a);
auto B = make_shared<op::Parameter>(element::i8, shape_b);
auto C = op::Constant::create(element::f32, Shape{1}, {0.0f});
auto D = op::Constant::create(element::f32, Shape{1}, {255.0f});
auto E = op::Constant::create(element::f32, Shape{1}, {-127.0f});
auto F = op::Constant::create(element::f32, Shape{1}, {127.0f});
auto G = op::Constant::create(element::f32, Shape{1}, {22.0f});
auto H = op::Constant::create(element::f32, Shape{1}, {90.0f});
auto CV = make_shared<op::QuantizedConvolution>(A,
B,
Strides{1, 1}, // move_strides
Strides{1, 1}, // filter_dilation
CoordinateDiff{1, 1}, // below_pads
CoordinateDiff{1, 1}, // above_pads
Strides{1, 1}, // data_dilation
C,
D,
E,
F,
G,
H);
auto output_data = std::make_shared<op::GetOutputElement>(CV, 0);
auto output_min = std::make_shared<op::GetOutputElement>(CV, 1);
auto output_max = std::make_shared<op::GetOutputElement>(CV, 2);
auto f = make_shared<Function>(NodeVector{output_data, output_min, output_max},
op::ParameterVector{A, B});
auto backend = runtime::Backend::create("CPU");
// Create some tensors for input/output
auto a = backend->create_tensor(element::u8, shape_a);
copy_data(a, a_data);
auto b = backend->create_tensor(element::i8, shape_b);
copy_data(b, b_data);
auto result = backend->create_tensor(element::i8, shape_r);
auto result_min = backend->create_tensor(element::f32, Shape{1});
auto result_max = backend->create_tensor(element::f32, Shape{1});
backend->call_with_validate(f, {result, result_min, result_max}, {a, b});
EXPECT_EQ((vector<int8_t>{31, 48, 42, 45, 54, 102, 127, 61, 47, 74, 61, 55}),
read_vector<int8_t>(result));
EXPECT_EQ((vector<float>{22.0}), read_vector<float>(result_min));
EXPECT_EQ((vector<float>{90.0}), read_vector<float>(result_max));
}
TEST(quantize_cpu, quantize_to_uint8_small)
{
vector<float> a_data = {-85.0, 0.0, 2.0, 10.0, 15.0};
......