Commit 36422810 authored by Adam Rogowiec's avatar Adam Rogowiec Committed by Scott Cyphers

[Fused] FakeQuantize operation. (#2928)

* Draft of FakeQuantize operation along with UTs.

* Add FakeQuantize to implemented operators on IGPU.

* Restore the FakeQuantize op case in the switch statement.

* Fix compilation errors.

* Skip test for INTERPRETER backend and disable type_prop tests.

* Initial implementation covering the most basic case

* Cleanup of fake_quantize_with_clip UT

* Reformat the cpu unit tests manifest and unlock another fake quant UT

* Handle the clipping case by subtracting input_low from quantization input

* Clip the input data before quantization to avoid Selects

* UT manifest fix

* Obsolete comment removed

* Code formatting

* Broadcast input data for non-scalar in/out params

* Code formatting

* Enable the type prop tests for FakeQuantize

* Dequant the data without using the Dequantize op (fixes an edge case)
parent 8707fba8
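For reference, the operation introduced below reduces, per element, to the following scalar computation (a minimal sketch assuming scalar in/out ranges; the op itself decomposes into graph nodes such as Quantize rather than running code like this):

#include <algorithm>
#include <cmath>
#include <cstddef>

// Scalar reference for FakeQuantize semantics (illustrative). std::round
// matches the ROUND_NEAREST_TOWARD_INFINITY mode used in the decomposition
// because the clipped, shifted input is always non-negative.
float fake_quantize_ref(float x, float in_lo, float in_hi,
                        float out_lo, float out_hi, std::size_t levels)
{
    x = std::min(std::max(x, in_lo), in_hi); // clip to <in_lo; in_hi>
    const float quant_scale = (in_hi - in_lo) / (levels - 1);
    const float dequant_scale = (out_hi - out_lo) / (levels - 1);
    return std::round((x - in_lo) / quant_scale) * dequant_scale + out_lo;
}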
@@ -300,6 +300,8 @@ set (SRC
op/fused/depth_to_space.hpp
op/fused/elu.cpp
op/fused/elu.hpp
op/fused/fake_quantize.cpp
op/fused/fake_quantize.hpp
op/fused/gemm.cpp
op/fused/gemm.hpp
op/fused/grn.cpp
@@ -99,6 +99,7 @@
#include "ngraph/op/fused/conv_fused.hpp"
#include "ngraph/op/fused/depth_to_space.hpp"
#include "ngraph/op/fused/elu.hpp"
#include "ngraph/op/fused/fake_quantize.hpp"
#include "ngraph/op/fused/gemm.hpp"
#include "ngraph/op/fused/grn.hpp"
#include "ngraph/op/fused/group_conv.hpp"
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <memory>
#include "fake_quantize.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/convert.hpp"
#include "ngraph/op/dequantize.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/greater.hpp"
#include "ngraph/op/less_eq.hpp"
#include "ngraph/op/maximum.hpp"
#include "ngraph/op/minimum.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/quantize.hpp"
#include "ngraph/op/select.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/shape.hpp"
using namespace std;
using namespace ngraph;
op::FakeQuantize::FakeQuantize(const shared_ptr<Node>& data,
const shared_ptr<Node>& input_low,
const shared_ptr<Node>& input_high,
const shared_ptr<Node>& output_low,
const shared_ptr<Node>& output_high,
size_t levels)
: FusedOp("FakeQuantize", {data, input_low, input_high, output_low, output_high})
, m_levels(levels)
{
constructor_validate_and_infer_types();
}
void op::FakeQuantize::pre_validate_and_infer_types()
{
const auto& data_pshape = get_input_partial_shape(0);
const auto& input_low_pshape = get_input_partial_shape(1);
const auto& input_high_pshape = get_input_partial_shape(2);
const auto& output_low_pshape = get_input_partial_shape(3);
const auto& output_high_pshape = get_input_partial_shape(4);
if (data_pshape.is_static() && input_low_pshape.is_static() && input_high_pshape.is_static() &&
output_low_pshape.is_static() && output_high_pshape.is_static())
{
const Shape data_shape{data_pshape.to_shape()};
const Shape input_low_shape{input_low_pshape.to_shape()};
const Shape input_high_shape{input_high_pshape.to_shape()};
const Shape output_low_shape{output_low_pshape.to_shape()};
const Shape output_high_shape{output_high_pshape.to_shape()};
NODE_VALIDATION_CHECK(
this,
(input_low_shape.size() == 0 ||
(input_low_shape.size() == 1 && input_low_shape.at(0) == data_shape.at(1))),
"Input low tensor shape: ",
input_low_shape,
", must either be a scalar or a vector of size equal to number of channels.");
NODE_VALIDATION_CHECK(
this,
(input_high_shape.size() == 0 ||
(input_high_shape.size() == 1 && input_high_shape.at(0) == data_shape.at(1))),
"Input high tensor shape: ",
input_high_shape,
", must either be a scalar or a vector of size equal to number of channels.");
NODE_VALIDATION_CHECK(
this,
(output_low_shape.size() == 0 ||
(output_low_shape.size() == 1 && output_low_shape.at(0) == data_shape.at(1))),
"Output low tensor shape: ",
output_low_shape,
", must either be a scalar or a vector of size equal to number of channels.");
NODE_VALIDATION_CHECK(
this,
(output_high_shape.size() == 0 ||
(output_high_shape.size() == 1 && output_high_shape.at(0) == data_shape.at(1))),
"Output high tensor shape: ",
output_high_shape,
", must either be a scalar or a vector of size equal to number of channels.");
}
}
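// Illustration of the checks above (values are examples): with NCHW data of
// shape {1, 2, 3, 4}, each range input may be Shape{} (a scalar shared by the
// whole tensor) or Shape{2} (one value per channel, C == 2); Shape{3} would
// fail validation because it does not match the channel count.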
NodeVector op::FakeQuantize::decompose_op() const
{
shared_ptr<Node> data{get_argument(0)};
shared_ptr<Node> input_low{get_argument(1)};
shared_ptr<Node> input_high{get_argument(2)};
shared_ptr<Node> output_low{get_argument(3)};
shared_ptr<Node> output_high{get_argument(4)};
if (input_low->get_shape().size() == 0)
{
NodeVector broadcasted_nodes =
numpy_style_broadcast(NodeVector{data, input_low, input_high, output_low, output_high});
data = broadcasted_nodes.at(0);
input_low = broadcasted_nodes.at(1);
input_high = broadcasted_nodes.at(2);
output_low = broadcasted_nodes.at(3);
output_high = broadcasted_nodes.at(4);
}
else
{
input_low = legacy_style_broadcast_for_binary_operation(data, input_low, 1).at(1);
input_high = legacy_style_broadcast_for_binary_operation(data, input_high, 1).at(1);
output_low = legacy_style_broadcast_for_binary_operation(data, output_low, 1).at(1);
output_high = legacy_style_broadcast_for_binary_operation(data, output_high, 1).at(1);
}
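// Illustration of the two branches above (shapes are examples): a scalar
// range such as input_low of Shape{} is numpy-style broadcast together with
// the data to one common shape, while a per-channel vector such as Shape{2}
// with data of shape {1, 2, 5, 5} is expanded along axis 1 (the channel axis).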
const auto input_data_shape = data->get_shape();
const auto input_data_type = data->get_element_type();
const auto levels_minus_one =
Constant::create(input_data_type,
input_data_shape,
vector<size_t>(shape_size(input_data_shape), m_levels - 1));
// map the number of quantization levels to nGraph's quantization and dequantization scales
const auto quant_scale = (input_high - input_low) / levels_minus_one;
const auto dequant_scale = (output_high - output_low) / levels_minus_one;
// zero_point type needs to match the quantization output type
const auto zero_point = Constant::create(element::i32, data->get_shape(), {0.0});
const auto axes = get_default_order(input_data_shape);
// clip the input data to the range <input_low;input_high>
data =
std::make_shared<op::Minimum>(input_high, std::make_shared<op::Maximum>(input_low, data));
// shift the input data so that it contains only positive values (and zeros)
data = data - input_low;
shared_ptr<Node> quantized_data =
make_shared<op::Quantize>(data,
quant_scale,
zero_point,
element::i32,
axes,
op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_INFINITY);
quantized_data = make_shared<op::Convert>(quantized_data, input_data_type);
// dequantization without using the Dequantize op (just a multiplication by the dequant_scale)
const auto dequantized_data = quantized_data * dequant_scale;
// shift the results so that they fall into the <output_low;output_high> range
return {dequantized_data + output_low};
}
shared_ptr<Node> op::FakeQuantize::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<FakeQuantize>(new_args.at(0), // X
new_args.at(1), // input_low
new_args.at(2), // input_high
new_args.at(3), // output_low
new_args.at(4), // output_high
m_levels);
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/autodiff/adjoints.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/util/fused_op.hpp"
namespace ngraph
{
namespace op
{
///
/// \brief Class performing element-wise linear quantization.
///
/// \note Input floating point values are quantized into a discrete
/// set of floating point values.
///
/// \paragraph Implementation This class creates a node which performs the following operation:
/// round((data - input_low) / (input_high - input_low) * (levels-1)) /
/// (levels-1) * (output_high - output_low) + output_low
///
///
class FakeQuantize : public ngraph::op::util::FusedOp
{
public:
///
/// \brief Constructs a FakeQuantize operation node.
///
/// \param[in] data The input data tensor.
/// \param[in] input_low The minimum limit for input values.
/// \param[in] input_high The maximum limit for input values.
/// \param[in] output_low The minimum quantized value.
/// \param[in] output_high The maximum quantized value.
/// \param[in] levels The number of quantization levels.
///
FakeQuantize(const std::shared_ptr<ngraph::Node>& data,
const std::shared_ptr<ngraph::Node>& input_low,
const std::shared_ptr<ngraph::Node>& input_high,
const std::shared_ptr<ngraph::Node>& output_low,
const std::shared_ptr<ngraph::Node>& output_high,
std::size_t levels);
virtual NodeVector decompose_op() const override;
virtual void pre_validate_and_infer_types() override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
std::size_t get_levels() const { return m_levels; }
private:
std::size_t m_levels;
};
}
}
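A minimal construction sketch for the class above (the shapes and the 256-level choice are illustrative):

auto data     = std::make_shared<op::Parameter>(element::f32, Shape{1, 3, 224, 224});
auto in_low   = std::make_shared<op::Parameter>(element::f32, Shape{});
auto in_high  = std::make_shared<op::Parameter>(element::f32, Shape{});
auto out_low  = std::make_shared<op::Parameter>(element::f32, Shape{});
auto out_high = std::make_shared<op::Parameter>(element::f32, Shape{});
auto fq = std::make_shared<op::FakeQuantize>(data, in_low, in_high,
                                             out_low, out_high, /*levels=*/256);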
@@ -23,6 +23,7 @@ NGRAPH_OP(ConvolutionBiasAdd, ngraph::op)
NGRAPH_OP(ConvolutionBiasBackpropFiltersBias, ngraph::op)
NGRAPH_OP(DepthToSpace, ngraph::op)
NGRAPH_OP(Elu, ngraph::op)
NGRAPH_OP(FakeQuantize, ngraph::op)
NGRAPH_OP(GRN, ngraph::op)
NGRAPH_OP(Gemm, ngraph::op)
NGRAPH_OP(GroupConvolution, ngraph::op)
@@ -10,3 +10,4 @@ max_3d_to_scalar_int32
# Not implemented
erf
zero_sized_erf
@@ -176,3 +176,6 @@ gather_no_axis_uint16
gather_no_axis_uint32
gather_no_axis_uint64
gather_no_axis_bool
fake_quantize
fake_quantize_with_clip
fake_quantize_with_clip_across_channels
@@ -81,6 +81,7 @@
#include "ngraph/op/fused/conv_fused.hpp"
#include "ngraph/op/fused/depth_to_space.hpp"
#include "ngraph/op/fused/elu.hpp"
#include "ngraph/op/fused/fake_quantize.hpp"
#include "ngraph/op/fused/gemm.hpp"
#include "ngraph/op/fused/grn.hpp"
#include "ngraph/op/fused/group_conv.hpp"
@@ -2055,6 +2056,7 @@ shared_ptr<runtime::Executable>
case OP_TYPEID::Elu:
case OP_TYPEID::EmbeddingLookup:
case OP_TYPEID::Erf:
case OP_TYPEID::FakeQuantize:
case OP_TYPEID::Gather:
case OP_TYPEID::GatherND:
case OP_TYPEID::GenerateMask:
@@ -2176,6 +2178,7 @@ bool runtime::intelgpu::IntelGPUBackend::is_supported_impl(const Node& node)
case OP_TYPEID::HardSigmoid:
case OP_TYPEID::DepthToSpace:
case OP_TYPEID::Elu:
case OP_TYPEID::FakeQuantize:
case OP_TYPEID::Gemm:
case OP_TYPEID::GRN:
case OP_TYPEID::LeakyRelu:
@@ -93,6 +93,9 @@ gather_no_axis_uint16
gather_no_axis_uint32
gather_no_axis_uint64
gather_no_axis_bool
fake_quantize
fake_quantize_with_clip
fake_quantize_with_clip_across_channels
# Not supported quant ops
model_dequantize_linear_1d_zero_scale_int8
@@ -2,3 +2,7 @@
model_quant_conv_linear
model_qlinear_matmul
model_qlinear_matmul_3d
fake_quantize
fake_quantize_with_clip
fake_quantize_with_clip_across_channels
@@ -70,6 +70,7 @@
#include "ngraph/op/fused/conv_fused.hpp"
#include "ngraph/op/fused/depth_to_space.hpp"
#include "ngraph/op/fused/elu.hpp"
#include "ngraph/op/fused/fake_quantize.hpp"
#include "ngraph/op/fused/gemm.hpp"
#include "ngraph/op/fused/grn.hpp"
#include "ngraph/op/fused/group_conv.hpp"
@@ -941,6 +942,13 @@ static shared_ptr<ngraph::Function>
node = make_shared<op::Exp>(args[0]);
break;
}
case OP_TYPEID::FakeQuantize:
{
size_t levels = node_js.at("levels").get<size_t>();
node = make_shared<op::FakeQuantize>(
args[0], args[1], args[2], args[3], args[4], levels);
break;
}
case OP_TYPEID::Floor:
{
node = make_shared<op::Floor>(args[0]);
@@ -1915,6 +1923,12 @@ static json write(const Node& n, bool binary_constant_data)
}
case OP_TYPEID::Exp: { break;
}
case OP_TYPEID::FakeQuantize:
{
auto tmp = dynamic_cast<const op::FakeQuantize*>(&n);
node["levels"] = tmp->get_levels();
break;
}
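// For a node serialized this way, the JSON fragment carries the attribute
// roughly as follows (abbreviated, illustrative):
//   {"name": "...", "op": "FakeQuantize", "levels": 5, ...}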
case OP_TYPEID::Floor: { break;
}
case OP_TYPEID::Gather:
@@ -1030,3 +1030,129 @@ NGRAPH_TEST(${BACKEND_NAME}, split_var_len_parts)
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, fake_quantize)
{
const Shape data_shape{1, 2, 3, 4};
const size_t levels = 4;
const auto data = make_shared<op::Parameter>(element::f32, data_shape);
const auto input_low = make_shared<op::Parameter>(element::f32, Shape{});
const auto input_high = make_shared<op::Parameter>(element::f32, Shape{});
const auto output_low = make_shared<op::Parameter>(element::f32, Shape{});
const auto output_high = make_shared<op::Parameter>(element::f32, Shape{});
const auto quantize =
make_shared<op::FakeQuantize>(data, input_low, input_high, output_low, output_high, levels);
const auto function = make_shared<Function>(
NodeVector{quantize},
ParameterVector{data, input_low, input_high, output_low, output_high});
auto test_case = ngraph::test::NgraphTestCase(function, "${BACKEND_NAME}");
const size_t n_elements = shape_size(data_shape);
vector<float> input_data(n_elements);
iota(begin(input_data), end(input_data), 0);
test_case.add_input<float>(input_data);
// input_low
test_case.add_input<float>({0.0f});
// input_high
test_case.add_input<float>({23.f});
// output_low
test_case.add_input<float>({2.f});
// output_high
test_case.add_input<float>({16.f});
// expected result
test_case.add_expected_output<float>(
data_shape,
vector<float>{2.f, 2.f, 2.f, 2.f, 6.6666669f,
6.6666669f, 6.6666669f, 6.6666669f, 6.6666669f, 6.6666669f,
6.6666669f, 6.6666669f, 11.33333301f, 11.33333301f, 11.33333301f,
11.33333301f, 11.33333301f, 11.33333301f, 11.33333301f, 11.33333301f,
16.f, 16.f, 16.f, 16.f});
test_case.run();
}
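// Where the expected values above come from (levels = 4, input range <0;23>,
// output range <2;16>):
//   quant_scale   = (23 - 0) / (4 - 1) ≈ 7.6667
//   dequant_scale = (16 - 2) / (4 - 1) ≈ 4.6667
//   e.g. x = 5: round(5 / 7.6667) = 1 -> 1 * 4.6667 + 2 ≈ 6.6666669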
NGRAPH_TEST(${BACKEND_NAME}, fake_quantize_with_clip)
{
const Shape data_shape{1, 2, 3, 4};
const size_t levels = 5;
const auto data = make_shared<op::Parameter>(element::f32, data_shape);
const auto input_low = make_shared<op::Parameter>(element::f32, Shape{});
const auto input_high = make_shared<op::Parameter>(element::f32, Shape{});
const auto output_low = make_shared<op::Parameter>(element::f32, Shape{});
const auto output_high = make_shared<op::Parameter>(element::f32, Shape{});
const auto quantize =
make_shared<op::FakeQuantize>(data, input_low, input_high, output_low, output_high, levels);
const auto function = make_shared<Function>(
NodeVector{quantize},
ParameterVector{data, input_low, input_high, output_low, output_high});
auto test_case = ngraph::test::NgraphTestCase(function, "${BACKEND_NAME}");
const size_t n_elements = shape_size(data_shape);
vector<float> input_data(n_elements);
iota(begin(input_data), end(input_data), 0);
test_case.add_input<float>(input_data);
// input_low
test_case.add_input<float>({3.f});
// input_high
test_case.add_input<float>({17.f});
// output_low
test_case.add_input<float>({2.f});
// output_high
test_case.add_input<float>({16.f});
// expected result
test_case.add_expected_output<float>(
data_shape,
vector<float>{2.f, 2.f, 2.f, 2.f, 2.f, 5.5f, 5.5f, 5.5f, 5.5f, 9.f, 9.f, 9.f,
12.5f, 12.5f, 12.5f, 12.5f, 16.f, 16.f, 16.f, 16.f, 16.f, 16.f, 16.f, 16.f});
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, fake_quantize_with_clip_across_channels)
{
Shape data_shape{1, 2, 5, 5};
size_t levels = 5;
auto data = make_shared<op::Parameter>(element::f32, data_shape);
auto input_low = make_shared<op::Parameter>(element::f32, Shape{2});
auto input_high = make_shared<op::Parameter>(element::f32, Shape{2});
auto output_low = make_shared<op::Parameter>(element::f32, Shape{2});
auto output_high = make_shared<op::Parameter>(element::f32, Shape{2});
auto quantize =
make_shared<op::FakeQuantize>(data, input_low, input_high, output_low, output_high, levels);
auto function = make_shared<Function>(
NodeVector{quantize},
ParameterVector{data, input_low, input_high, output_low, output_high});
auto test_case = ngraph::test::NgraphTestCase(function, "${BACKEND_NAME}");
size_t n_elements = shape_size(data_shape);
vector<float> input_data(n_elements);
iota(begin(input_data), end(input_data), 0);
test_case.add_input<float>(input_data);
// input_low
test_case.add_input<float>(vector<float>{5.f, 30.f});
// input_high
test_case.add_input<float>(vector<float>{10.f, 40.f});
// output_low
test_case.add_input<float>(vector<float>{0.f, 50.f});
// output_high
test_case.add_input<float>(vector<float>{20.f, 70.f});
// expected result
test_case.add_expected_output<float>(
data_shape,
vector<float>{0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 5.0f, 10.0f, 10.0f, 15.0f,
20.0f, 20.0f, 20.0f, 20.0f, 20.0f, 20.0f, 20.0f, 20.0f, 20.0f, 20.0f,
20.0f, 20.0f, 20.0f, 20.0f, 20.0f, 50.0f, 50.0f, 50.0f, 50.0f, 50.0f,
50.0f, 50.0f, 55.0f, 55.0f, 60.0f, 60.0f, 60.0f, 65.0f, 65.0f, 70.0f,
70.0f, 70.0f, 70.0f, 70.0f, 70.0f, 70.0f, 70.0f, 70.0f, 70.0f, 70.0f});
test_case.run();
}
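// Worked per-channel arithmetic for the expected values above; channel 0
// uses input range <5;10>, output range <0;20>, levels = 5:
//   quant_scale = (10 - 5) / 4 = 1.25, dequant_scale = (20 - 0) / 4 = 5
//   e.g. x = 6:  round((6 - 5) / 1.25) = 1 -> 1 * 5 + 0 = 5
//   e.g. x = 12: clipped to 10, round((10 - 5) / 1.25) = 4 -> 4 * 5 = 20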
@@ -14,12 +14,12 @@
// limitations under the License.
//*****************************************************************************
#include <memory>
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/op/embedding_lookup.hpp"
using namespace std;
using namespace ngraph;
@@ -14684,3 +14684,94 @@ TEST(type_prop, split)
EXPECT_EQ(split->output(0).get_element_type(), element::i32);
EXPECT_EQ(split->output(1).get_element_type(), element::i32);
}
TEST(type_prop, fake_quantize)
{
const auto data = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3, 4});
const auto input_low = make_shared<op::Parameter>(element::f32, Shape{});
const auto input_high = make_shared<op::Parameter>(element::f32, Shape{});
const auto output_low = make_shared<op::Parameter>(element::f32, Shape{});
const auto output_high = make_shared<op::Parameter>(element::f32, Shape{});
const int levels = 5;
const auto fake_quantize =
make_shared<op::FakeQuantize>(data, input_low, input_high, output_low, output_high, levels);
EXPECT_EQ(fake_quantize->get_element_type(), element::f32);
EXPECT_EQ(fake_quantize->get_shape(), (Shape{1, 2, 3, 4}));
}
TEST(type_prop, fake_quantize_invalid_rank)
{
const auto data = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3, 4});
auto input_low = make_shared<op::Parameter>(element::f32, Shape{3});
auto input_high = make_shared<op::Parameter>(element::f32, Shape{});
auto output_low = make_shared<op::Parameter>(element::f32, Shape{});
auto output_high = make_shared<op::Parameter>(element::f32, Shape{});
const int levels = 5;
// Invalid input_low dimension
try
{
const auto fake_quantize = make_shared<op::FakeQuantize>(
data, input_low, input_high, output_low, output_high, levels);
EXPECT_FALSE(fake_quantize.get())
<< "FakeQuantize validation did not work. Op node was created with incorrect params.";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("must either be a scalar or a vector of size equal "
"to number of channels."));
}
// Invalid input_high dimension
input_low = make_shared<op::Parameter>(element::f32, Shape{});
input_high = make_shared<op::Parameter>(element::f32, Shape{3});
try
{
const auto fake_quantize = make_shared<op::FakeQuantize>(
data, input_low, input_high, output_low, output_high, levels);
EXPECT_FALSE(fake_quantize.get())
<< "FakeQuantize validation did not work. Op node was created with incorrect params.";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("must either be a scalar or a vector of size equal "
"to number of channels."));
}
// Invalid output_low dimension
input_high = make_shared<op::Parameter>(element::f32, Shape{});
output_low = make_shared<op::Parameter>(element::f32, Shape{3});
try
{
const auto fake_quantize = make_shared<op::FakeQuantize>(
data, input_low, input_high, output_low, output_high, levels);
EXPECT_FALSE(fake_quantize.get())
<< "FakeQuantize validation did not work. Op node was created with incorrect params.";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("must either be a scalar or a vector of size equal "
"to number of channels."));
}
// Invalid output_high dimension
output_low = make_shared<op::Parameter>(element::f32, Shape{});
output_high = make_shared<op::Parameter>(element::f32, Shape{3});
try
{
const auto fake_quantize = make_shared<op::FakeQuantize>(
data, input_low, input_high, output_low, output_high, levels);
EXPECT_FALSE(fake_quantize.get())
<< "FakeQuantize validation did not work. Op node was created with incorrect params.";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("must either be a scalar or a vector of size equal "
"to number of channels."));
}
}