Commit 00afd349 authored by Nishant Patel, committed by Nick Korovaiko

Add support for Quantize op via mkldnn for IA backend (codegen + DEX) (#1576)

* Add support for Quantize op via mkldnn for IA backend (codegen + DEX)

* PR feedback addressed

* Use call_with_validate
parent d8587872
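For orientation, a minimal sketch of how the new op is exercised from a graph. It simply mirrors the tests added at the bottom of this commit (the "CPU" backend string, op::GetOutputElement, and the three-output convention all come from that test file), so treat it as an illustration rather than additional API:

// Quantize f32 data to u8 with the new op (sketch, copied from the tests below).
auto A  = std::make_shared<op::Parameter>(element::f32, Shape{8});
auto B  = op::Constant::create(element::f32, Shape{}, {-255.0f}); // input min
auto C  = op::Constant::create(element::f32, Shape{}, {127.0f});  // input max
auto QT = std::make_shared<op::Quantize>(A, B, C, element::u8);
// Quantize has three outputs: quantized data, output min and output max.
auto data = std::make_shared<op::GetOutputElement>(QT, 0);
auto qmin = std::make_shared<op::GetOutputElement>(QT, 1);
auto qmax = std::make_shared<op::GetOutputElement>(QT, 2);
auto f = std::make_shared<Function>(NodeVector{data, qmin, qmax},
                                    op::ParameterVector{A});
auto backend = runtime::Backend::create("CPU");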
...@@ -37,6 +37,7 @@ set(SRC
builder/convert_layout.cpp
builder/convolution.cpp
builder/dequantize.cpp
builder/quantize.cpp
builder/dot.cpp
builder/function_call.cpp
builder/lstm.cpp
...@@ -81,6 +82,7 @@ set(SRC
op/conv_relu.cpp
op/convert_layout.cpp
op/dequantize.cpp
op/quantize.cpp
op/loop_kernel.cpp
op/lstm.cpp
op/matmul_bias.cpp
...
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <vector>
#include "ngraph/runtime/cpu/cpu_builder.hpp"
#include "ngraph/runtime/cpu/mkldnn_invoke.hpp"
#include "ngraph/runtime/cpu/mkldnn_utils.hpp"
#include "ngraph/runtime/cpu/quantization_util.hpp"
using namespace std;
using namespace ngraph;
namespace ngraph
{
namespace runtime
{
namespace cpu
{
template <>
void Builder::BUILDER_DECL(ngraph::op::Quantize)
{
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{
auto quantize = static_cast<const ngraph::op::Quantize*>(node);
auto& functors = external_function->get_functors();
auto& arg_tensor = external_function->get_tensor_data(args[0].get_name());
auto& out_tensor = external_function->get_tensor_data(out[0].get_name());
auto& out1_tensor = external_function->get_tensor_data(out[1].get_name());
auto& out2_tensor = external_function->get_tensor_data(out[2].get_name());
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto input_desc = mkldnn_utils::get_input_mkldnn_md(node, 0);
auto result_desc = mkldnn_utils::get_output_mkldnn_md(node, 0);
vector<float> quant_util; // min_range, max_range & scale.
quantization_util::get_min_max_range(quantize->get_input_min(),
quantize->get_input_max(),
(quantize->get_quantize_et()).is_signed(),
quant_util);
std::vector<float> scales;
scales.push_back(quant_util[2]);
size_t quantize_index =
mkldnn_emitter->build_quantize_reorder(input_desc, result_desc, scales);
auto& deps = mkldnn_emitter->get_primitive_deps(quantize_index);
auto functor = [&, quantize_index, quant_util](CPURuntimeContext* ctx) {
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[0], arg_tensor);
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[1], out_tensor);
*(static_cast<float*>(out1_tensor)) = quant_util[0];
*(static_cast<float*>(out2_tensor)) = quant_util[1];
cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, quantize_index);
};
functors.emplace_back(functor);
}
else
{
throw ngraph_error("Unsupported parameters for QuantizeOp via DEX");
}
}
REGISTER_OP_BUILDER(Quantize);
}
}
}
...@@ -114,6 +114,7 @@
#include "ngraph/runtime/cpu/op/rnn.hpp"
#include "ngraph/runtime/cpu/op/sigmoid.hpp"
#include "ngraph/runtime/cpu/op/sigmoid_mul.hpp"
#include "ngraph/runtime/cpu/quantization_util.hpp"
#include "ngraph/type/element_type.hpp"
#include "ngraph/util.hpp"
...@@ -2792,6 +2793,43 @@ namespace ngraph
}
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Quantize)
{
auto quantize = static_cast<const ngraph::op::Quantize*>(node);
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto input_data_desc = mkldnn_utils::get_input_mkldnn_md(node, 0);
auto result_desc = mkldnn_utils::get_output_mkldnn_md(node, 0);
std::vector<float> quant_util; // min_range, max_range & scale.
quantization_util::get_min_max_range(quantize->get_input_min(),
quantize->get_input_max(),
(quantize->get_quantize_et()).is_signed(),
quant_util);
std::vector<float> scales;
scales.push_back(quant_util[2]);
size_t quantize_index = 0;
quantize_index = mkldnn_emitter->build_quantize_reorder(
input_data_desc, result_desc, scales);
auto& deps = mkldnn_emitter->get_primitive_deps(quantize_index);
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[0])
<< ", " << args[0].get_name() << ");\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[1])
<< ", " << out[0].get_name() << ");\n";
writer << "*(" << out[1].get_name() << ") = " << quant_util[0] << ";\n";
writer << "*(" << out[2].get_name() << ") = " << quant_util[1] << ";\n";
writer << "cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, "
<< to_string(quantize_index) << ");\n";
}
else
{
throw ngraph_error("Unsupported parameters for QuantizeOp");
}
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Dequantize)
{
...
...@@ -149,6 +149,7 @@
#include "ngraph/runtime/cpu/op/lstm.hpp"
#include "ngraph/runtime/cpu/op/matmul_bias.hpp"
#include "ngraph/runtime/cpu/op/max_pool_with_indices.hpp"
#include "ngraph/runtime/cpu/op/quantize.hpp"
#include "ngraph/runtime/cpu/op/quantized_avg_pool.hpp"
#include "ngraph/runtime/cpu/op/quantized_max_pool.hpp"
#include "ngraph/runtime/cpu/op/rnn.hpp"
...@@ -296,6 +297,7 @@ static const runtime::cpu::OpMap dispatcher{
{TI(ngraph::op::Ceiling), &runtime::cpu::CPU_Emitter::emit<op::Ceiling>},
{TI(ngraph::op::Sqrt), &runtime::cpu::CPU_Emitter::emit<op::Sqrt>},
{TI(ngraph::op::Convolution), &runtime::cpu::CPU_Emitter::emit<op::Convolution>},
{TI(ngraph::op::Quantize), &runtime::cpu::CPU_Emitter::emit<op::Quantize>},
{TI(ngraph::op::Dequantize), &runtime::cpu::CPU_Emitter::emit<op::Dequantize>},
{TI(ngraph::op::ConvolutionBackpropFilters),
&runtime::cpu::CPU_Emitter::emit<op::ConvolutionBackpropFilters>},
...
...@@ -140,12 +140,9 @@ size_t MKLDNNEmitter::build_dequantization(const ngraph::Node* node,
const float scale_factor = max_abs / target_range;
std::vector<float> scales;
scales.push_back(scale_factor);
- mkldnn::primitive_attr attr;
- attr.set_output_scales(0, scales);
- attr.set_int_output_round_mode(mkldnn::round_mode::round_nearest);
size_t dequantize_index = 0;
- dequantize_index = this->build_quantize_reorder(input_desc, result_desc, attr);
dequantize_index = this->build_quantize_reorder(input_desc, result_desc, scales);
return dequantize_index;
}
...@@ -717,10 +714,15 @@ size_t MKLDNNEmitter::build_reorder(const mkldnn::memory::desc& input_desc,
size_t MKLDNNEmitter::build_quantize_reorder(const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& result_desc,
- const mkldnn::primitive_attr attr)
const std::vector<float>& scales)
{
size_t input_index = build_memory_primitive(input_desc);
size_t result_index = build_memory_primitive(result_desc);
mkldnn::primitive_attr attr;
attr.set_output_scales(0, scales);
attr.set_int_output_round_mode(mkldnn::round_mode::round_nearest);
auto reorder_desc =
mkldnn::reorder::primitive_desc({input_desc, mkldnn_utils::global_cpu_engine},
{result_desc, mkldnn_utils::global_cpu_engine},
...
...@@ -513,7 +513,7 @@ namespace ngraph
size_t build_quantize_reorder(const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& result_desc,
- const mkldnn::primitive_attr attr);
const std::vector<float>& scales);
size_t build_dequantization(const ngraph::Node* node,
const mkldnn::memory::desc& input_desc,
...
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "ngraph/runtime/cpu/op/quantize.hpp"
#include "ngraph/op/constant.hpp"
ngraph::op::Quantize::Quantize(std::shared_ptr<Node> input,
std::shared_ptr<Node> min,
std::shared_ptr<Node> max,
const element::Type& type)
: Op("Quantize", check_single_output_args({input, min, max}))
, m_element_type(type)
{
constructor_validate_and_infer_types();
if (input->get_element_type() != element::f32)
{
throw ngraph_error("Quantization supported only from float32 --> i8/u8!");
}
if (min->get_element_type() != max->get_element_type())
{
throw ngraph_error("Min's element type isn't equal to max's!");
}
if (min->get_shape().size() != 0)
{
throw ngraph_error("Min is not a scalar!");
}
if (max->get_shape().size() != 0)
{
throw ngraph_error("Max is not a scalar!");
}
if (!(std::dynamic_pointer_cast<op::Constant>(min) &&
std::dynamic_pointer_cast<op::Constant>(max)))
{
throw ngraph_error("Min and max have to be constants!");
}
auto min_const_op = std::static_pointer_cast<ngraph::op::Constant>(min);
auto max_const_op = std::static_pointer_cast<ngraph::op::Constant>(max);
float input_min_range = *(static_cast<float const*>(min_const_op->get_data_ptr()));
float input_max_range = *(static_cast<float const*>(max_const_op->get_data_ptr()));
this->m_input_min = input_min_range;
this->m_input_max = input_max_range;
set_output_size(3);
set_output_type(0, type, input->get_shape());
set_output_type(1, element::f32, Shape{});
set_output_type(2, element::f32, Shape{});
}
std::shared_ptr<ngraph::Node>
ngraph::op::Quantize::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 3)
{
throw ngraph_error("Incorrect number of new arguments");
}
return std::make_shared<Quantize>(
new_args.at(0), new_args.at(1), new_args.at(2), m_element_type);
}
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include "ngraph/node.hpp"
#include "ngraph/node_vector.hpp"
#include "ngraph/op/op.hpp"
namespace ngraph
{
namespace op
{
class Quantize : public Op
{
public:
Quantize(std::shared_ptr<Node> input,
std::shared_ptr<Node> min,
std::shared_ptr<Node> max,
const element::Type& type);
const element::Type& get_quantize_et() const { return m_element_type; }
float get_input_min() const { return m_input_min; }
float get_input_max() const { return m_input_max; }
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
private:
const element::Type m_element_type;
float m_input_min;
float m_input_max;
};
}
}
...@@ -46,6 +46,7 @@
#include "ngraph/runtime/cpu/op/group_conv.hpp"
#include "ngraph/runtime/cpu/op/lstm.hpp"
#include "ngraph/runtime/cpu/op/max_pool_with_indices.hpp"
#include "ngraph/runtime/cpu/op/quantize.hpp"
#include "ngraph/runtime/cpu/op/quantized_avg_pool.hpp"
#include "ngraph/runtime/cpu/op/quantized_max_pool.hpp"
#include "ngraph/runtime/cpu/op/rnn.hpp"
...@@ -726,6 +727,19 @@ namespace ngraph
dequantize->set_op_annotations(op_annotations);
}
}
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::Quantize)
{
if (node->get_input_element_type(0) == element::f32)
{
auto quantize = static_cast<op::Quantize*>(node);
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
quantize->set_op_annotations(op_annotations);
}
}
}
}
}
...@@ -783,6 +797,7 @@ static const runtime::cpu::pass::AssignOpMap s_dispatcher{
{TI(ngraph::op::QuantizedAvgPool),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::QuantizedAvgPool>},
{TI(ngraph::op::Softmax), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::Softmax>},
{TI(ngraph::op::Quantize), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::Quantize>},
{TI(ngraph::op::ReplaceSlice),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::ReplaceSlice>},
{TI(ngraph::op::ConvolutionAdd),
...
...@@ -54,6 +54,7 @@
#include "ngraph/runtime/cpu/op/group_conv.hpp"
#include "ngraph/runtime/cpu/op/lstm.hpp"
#include "ngraph/runtime/cpu/op/max_pool_with_indices.hpp"
#include "ngraph/runtime/cpu/op/quantize.hpp"
#include "ngraph/runtime/cpu/op/quantized_avg_pool.hpp"
#include "ngraph/runtime/cpu/op/quantized_max_pool.hpp"
#include "ngraph/runtime/cpu/op/rnn.hpp"
...@@ -1645,6 +1646,20 @@ namespace ngraph
throw ngraph_error("Dequantized op is only supported in MKLDNN for now.");
}
}
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::Quantize)
{
if (mkldnn_utils::use_mkldnn_kernel(node.get()))
{
// TODO: Propagate layout
set_native_layouts(external_function, node);
}
else
{
throw ngraph_error("Quantized op is only supported in MKLDNN for now.");
}
}
}
}
}
...@@ -1707,6 +1722,7 @@ static const runtime::cpu::pass::LayoutOpMap s_dispatcher{
{TI(ngraph::op::ConvolutionAdd),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::ConvolutionAdd>},
{TI(ngraph::op::Dequantize), &runtime::cpu::pass::CPULayout::layout<ngraph::op::Dequantize>},
{TI(ngraph::op::Quantize), &runtime::cpu::pass::CPULayout::layout<ngraph::op::Quantize>},
};
bool runtime::cpu::pass::CPULayout::run_on_call_graph(const std::list<std::shared_ptr<Node>>& nodes)
...
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#pragma once
#include <vector>
#include "ngraph/node.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/runtime/cpu/op/quantize.hpp"
#include "ngraph/util.hpp"
namespace ngraph
{
namespace runtime
{
namespace cpu
{
namespace quantization_util
{
static inline void get_min_max_range(float input_min_range,
float input_max_range,
bool is_signed,
std::vector<float>& quant_util)
{
// begin code copied and pasted from
// github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/quantize_op.cc
float min_range;
float max_range;
// If input_min_range and input_max_range are close,
// introduce a slightly larger delta between them.
min_range = std::min(0.0f, input_min_range);
const float epsilon =
std::max(1.0f, std::max(fabsf(input_min_range), fabsf(input_max_range))) /
100.0f;
max_range = std::max(input_max_range, min_range + epsilon);
max_range = std::max(0.0f, max_range);
// end code copied and pasted from
// github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/quantize_op.cc
const float max_abs = std::max(std::abs(min_range), std::abs(max_range));
const float target_range =
static_cast<float>((is_signed ? std::pow(2, 7) : std::pow(2, 8)) - 1);
max_range = max_abs;
min_range = is_signed ? -max_abs : 0;
const float scale = target_range / max_abs;
quant_util.push_back(min_range);
quant_util.push_back(max_range);
quant_util.push_back(scale);
}
}
}
}
}
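As a sanity check on the helper above, here is a small standalone re-statement of the same arithmetic, using the values from the quantize_to_uint8_small test added later in this commit; it has no ngraph dependency and only repeats what get_min_max_range already computes:

#include <algorithm>
#include <cmath>
#include <cstdio>

int main()
{
    // input range [-85, 15], unsigned (u8) target, as in quantize_to_uint8_small
    float input_min = -85.0f, input_max = 15.0f;
    bool is_signed = false;
    float min_range = std::min(0.0f, input_min);
    const float epsilon =
        std::max(1.0f, std::max(std::fabs(input_min), std::fabs(input_max))) / 100.0f;
    float max_range = std::max(0.0f, std::max(input_max, min_range + epsilon));
    const float max_abs = std::max(std::abs(min_range), std::abs(max_range));
    const float target_range =
        static_cast<float>((is_signed ? std::pow(2, 7) : std::pow(2, 8)) - 1);
    max_range = max_abs;
    min_range = is_signed ? -max_abs : 0;
    const float scale = target_range / max_abs;
    // prints min_range=0 max_range=85 scale=3, so 2 -> 6, 10 -> 30, 15 -> 45
    std::printf("min_range=%g max_range=%g scale=%g\n", min_range, max_range, scale);
    return 0;
}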
...@@ -24,6 +24,7 @@
#include "ngraph/op/constant.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/runtime/cpu/op/dequantize.hpp"
#include "ngraph/runtime/cpu/op/quantize.hpp"
#include "ngraph/runtime/cpu/op/quantized_avg_pool.hpp"
#include "ngraph/runtime/cpu/op/quantized_max_pool.hpp"
#include "util/all_close.hpp"
...@@ -202,3 +203,81 @@ TEST(quantize_cpu, dequantize_from_int8)
{
DequantizeTest<int8_t>(42, -1.0f, 300.0f, static_cast<float>(99.212601));
}
TEST(quantize_cpu, quantize_to_uint8_small)
{
vector<float> a_data = {-85.0, 0.0, 2.0, 10.0, 15.0};
Shape shape_a{5};
auto A = make_shared<op::Parameter>(element::f32, shape_a);
auto B = op::Constant::create(element::f32, Shape{}, {-85.0f});
auto C = op::Constant::create(element::f32, Shape{}, {15.0f});
auto QT = make_shared<op::Quantize>(A, B, C, element::u8);
auto output_data = std::make_shared<op::GetOutputElement>(QT, 0);
auto output_min = std::make_shared<op::GetOutputElement>(QT, 1);
auto output_max = std::make_shared<op::GetOutputElement>(QT, 2);
auto f = make_shared<Function>(NodeVector{output_data, output_min, output_max},
op::ParameterVector{A});
auto backend = runtime::Backend::create("CPU");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape_a);
copy_data(a, a_data);
auto result = backend->create_tensor(element::u8, shape_a);
auto result_min = backend->create_tensor(element::f32, Shape{});
auto result_max = backend->create_tensor(element::f32, Shape{});
backend->call_with_validate(f, {result, result_min, result_max}, {a});
EXPECT_EQ((vector<uint8_t>{0, 0, 6, 30, 45}), read_vector<uint8_t>(result));
EXPECT_EQ((vector<float>{0.0}), read_vector<float>(result_min));
EXPECT_EQ((vector<float>{85.0}), read_vector<float>(result_max));
}
TEST(quantize_cpu, quantize_to_uint8)
{
vector<float> a_data = {-255.0, 0.0, 1.0, 1.25, 1.75, 64.0, 127.0, 500.0};
Shape shape_a{8};
auto A = make_shared<op::Parameter>(element::f32, shape_a);
auto B = op::Constant::create(element::f32, Shape{}, {-255.0f});
auto C = op::Constant::create(element::f32, Shape{}, {127.0f});
auto QT = make_shared<op::Quantize>(A, B, C, element::u8);
auto output_data = std::make_shared<op::GetOutputElement>(QT, 0);
auto output_min = std::make_shared<op::GetOutputElement>(QT, 1);
auto output_max = std::make_shared<op::GetOutputElement>(QT, 2);
auto f = make_shared<Function>(NodeVector{output_data, output_min, output_max},
op::ParameterVector{A});
auto backend = runtime::Backend::create("CPU");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape_a);
copy_data(a, a_data);
auto result = backend->create_tensor(element::u8, shape_a);
auto result_min = backend->create_tensor(element::f32, Shape{});
auto result_max = backend->create_tensor(element::f32, Shape{});
backend->call_with_validate(f, {result, result_min, result_max}, {a});
EXPECT_EQ((vector<uint8_t>{0, 0, 1, 1, 2, 64, 127, 255}), read_vector<uint8_t>(result));
EXPECT_EQ((vector<float>{0.0}), read_vector<float>(result_min));
EXPECT_EQ((vector<float>{255.0}), read_vector<float>(result_max));
}
TEST(quantize_cpu, quantize_to_int8)
{
vector<float> a_data = {-127.0, 0.0, 1.0, 3.0, 5.0, 64.0, 127.0, 500.0};
Shape shape_a{8};
auto A = make_shared<op::Parameter>(element::f32, shape_a);
auto B = op::Constant::create(element::f32, Shape{}, {-127.0f});
auto C = op::Constant::create(element::f32, Shape{}, {127.0f});
auto QT = make_shared<op::Quantize>(A, B, C, element::i8);
auto output_data = std::make_shared<op::GetOutputElement>(QT, 0);
auto output_min = std::make_shared<op::GetOutputElement>(QT, 1);
auto output_max = std::make_shared<op::GetOutputElement>(QT, 2);
auto f = make_shared<Function>(NodeVector{output_data, output_min, output_max},
op::ParameterVector{A});
auto backend = runtime::Backend::create("CPU");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape_a);
copy_data(a, a_data);
auto result = backend->create_tensor(element::i8, shape_a);
auto result_min = backend->create_tensor(element::f32, Shape{});
auto result_max = backend->create_tensor(element::f32, Shape{});
backend->call_with_validate(f, {result, result_min, result_max}, {a});
EXPECT_EQ((vector<int8_t>{-127, 0, 1, 3, 5, 64, 127, 127}), read_vector<int8_t>(result));
EXPECT_EQ((vector<float>{-127}), read_vector<float>(result_min));
EXPECT_EQ((vector<float>{127}), read_vector<float>(result_max));
}
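A note on the int8 expectations above: with the input range [-127, 127], max_abs is 127 and the signed target range is 2^7 - 1 = 127, so the scale is exactly 1.0; the quantized values are just the inputs rounded and saturated to the int8 range (500.0 clamps to 127), and the reported range is [-127, 127]. A tiny reference check, under the assumption that the mkldnn reorder rounds to nearest and saturates as the test expects:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

int main()
{
    const float scale = 127.0f / 127.0f; // target_range / max_abs
    const float vals[] = {-127.0f, 0.0f, 1.0f, 3.0f, 5.0f, 64.0f, 127.0f, 500.0f};
    for (float v : vals)
    {
        float q = std::round(v * scale);
        q = std::max(-128.0f, std::min(127.0f, q)); // saturate to int8 before casting
        std::printf("%g -> %d\n", v, static_cast<int>(static_cast<int8_t>(q)));
    }
    // prints -127, 0, 1, 3, 5, 64, 127, 127, matching quantize_to_int8
    return 0;
}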