Commit ee1a5ca8 authored by tsocha's avatar tsocha Committed by Scott Cyphers

[ONNX] Quantization operators (#2495)

* Lin quant ONNX operators

* Add EOL at EOF

* Update licence year

* Cleanup namespaces

* Add missing headers

* Check if group is correct

* Remove Check in dequantize

* Remove TODO

* Add converting zero point to data element type

* Remove unused scale variable

* Removing zero point

* Remove make_ng_conv_bias helper function

* Add test for Quantize Linear

* Remove dead code

* Add test for Dequantize Linear

* Add Quant Conv test

* Convert models to prototxt

* Remove artifact test

* Skip test on INTERPRETER

* Style check

* Remove stupid skipping

* Enable test skipping in onnx_backend tests

* Skip GPU tests on quantized operators

* Review fix
parent 6ca8ba97
......@@ -74,6 +74,8 @@ add_library(onnx_import STATIC
op/conv_transpose.hpp
op/depth_to_space.cpp
op/depth_to_space.hpp
op/dequantize_linear.cpp
op/dequantize_linear.hpp
op/div.hpp
op/dropout.hpp
op/elu.cpp
......@@ -123,6 +125,10 @@ add_library(onnx_import STATIC
op/pow.hpp
op/prelu.cpp
op/prelu.hpp
op/quant_conv.cpp
op/quant_conv.hpp
op/quantize_linear.cpp
op/quantize_linear.hpp
op/reciprocal.cpp
op/reciprocal.hpp
op/reduce.cpp
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <memory>
#include "exceptions.hpp"
#include "ngraph/axis_set.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/convert.hpp"
#include "ngraph/op/dequantize.hpp"
#include "ngraph/shape.hpp"
#include "quantize_linear.hpp"
#include "utils/common.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector dequantize_linear(const Node& node)
{
NodeVector inputs{node.get_ng_inputs()};
std::shared_ptr<ngraph::Node> x = inputs.at(0);
std::shared_ptr<ngraph::Node> x_scale = inputs.at(1);
std::shared_ptr<ngraph::Node> zero_point;
if (inputs.size() == 3 && !inputs.at(2)->is_null())
{
zero_point = inputs.at(2);
}
else
{
zero_point = common::make_constant_node(
x->get_element_type(), Shape{}, std::vector<std::uint8_t>{0});
}
Shape y_scale_shape = x_scale->get_shape();
Shape y_zero_point_shape = zero_point->get_shape();
ASSERT_VALID_ARGUMENT(node, y_scale_shape.size() == 0)
<< "x_scale must be a scalar.";
ASSERT_VALID_ARGUMENT(node, y_zero_point_shape.size() == 0)
<< "zero_point must be a scalar.";
if (x->get_element_type() != zero_point->get_element_type())
{
zero_point = std::make_shared<ngraph::op::Convert>(zero_point,
x->get_element_type());
}
return {std::make_shared<ngraph::op::Dequantize>(
x, x_scale, zero_point, x_scale->get_element_type(), AxisSet{})};
}
} // namespace set_1
} //namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "core/node.hpp"
#include "ngraph/node_vector.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector dequantize_linear(const Node& node);
} // namespace set_1
} //namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <cstddef>
#include <memory>
#include <vector>
#include "ngraph/builder/quantization/quantized_linear_convolution.hpp"
#include "ngraph/coordinate_diff.hpp"
#include "ngraph/frontend/onnx_import/exceptions.hpp"
#include "ngraph/frontend/onnx_import/op/conv.hpp"
#include "ngraph/frontend/onnx_import/utils/broadcasting.hpp"
#include "ngraph/frontend/onnx_import/utils/convpool.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/experimental/quantized_conv.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/slice.hpp"
#include "ngraph/strides.hpp"
#include "quant_conv.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace set_1
{
namespace
{
struct OpScale
{
std::shared_ptr<ngraph::Node> data_scale;
std::shared_ptr<ngraph::Node> filter_scale;
std::shared_ptr<ngraph::Node> output_scale;
};
std::shared_ptr<ngraph::Node>
make_ng_quant_conv(const std::shared_ptr<ngraph::Node>& data,
const std::shared_ptr<ngraph::Node>& filters,
const Strides& strides,
const Strides& filter_dilations,
const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above,
const Strides& data_dilations,
int groups,
const OpScale& op_scale,
const std::shared_ptr<ngraph::Node>& bias = nullptr)
{
if (groups > 1)
{
// Split one convolution op to N ops where N is the number of groups
// and concat results after computation.
// reference: https://github.com/NervanaSystems/ngraph-mxnet/blob/fdd692/src/ngraph/ngraph_emitter.cc#L822-L856
std::size_t n_data_channels{data->get_shape().at(1)};
std::size_t n_filters_channels{filters->get_shape().at(0)};
std::size_t data_group_size{n_data_channels / groups};
std::size_t filters_group_size{n_filters_channels / groups};
NodeVector convolution_nodes;
// initial bounds for splice
std::vector<std::size_t> data_lower_bounds(data->get_shape().size());
std::vector<std::size_t> data_upper_bounds{data->get_shape()};
std::vector<std::size_t> filters_lower_bounds(
filters->get_shape().size());
std::vector<std::size_t> filters_upper_bounds{filters->get_shape()};
for (std::size_t group{0}; group < groups; ++group)
{
// slice data
data_lower_bounds[1] = group * data_group_size;
data_upper_bounds[1] = (group + 1) * data_group_size;
auto sliced_data = std::make_shared<ngraph::op::Slice>(
data, data_lower_bounds, data_upper_bounds);
// slice filters
filters_lower_bounds[0] = group * filters_group_size;
filters_upper_bounds[0] = (group + 1) * filters_group_size;
auto sliced_filters = std::make_shared<ngraph::op::Slice>(
filters, filters_lower_bounds, filters_upper_bounds);
if (bias)
{
throw error::NotSupported(
"Groups != 1 not supported for Quantized Convolution with "
"bias.");
}
else
{
convolution_nodes.push_back(
ngraph::builder::quantization::QuantizedLinearConvolution(
sliced_data,
sliced_filters,
strides,
filter_dilations,
padding_below,
padding_above,
data_dilations,
op_scale.data_scale,
op_scale.filter_scale,
op_scale.output_scale));
}
}
std::size_t concatenation_axis = 1;
return std::make_shared<ngraph::op::Concat>(convolution_nodes,
concatenation_axis);
}
else
{
if (bias)
{
return ngraph::builder::quantization::
QuantizedLinearConvolutionBias(data,
filters,
bias,
strides,
filter_dilations,
padding_below,
padding_above,
data_dilations,
op_scale.data_scale,
op_scale.filter_scale,
op_scale.output_scale);
}
else
{
return ngraph::builder::quantization::QuantizedLinearConvolution(
data,
filters,
strides,
filter_dilations,
padding_below,
padding_above,
data_dilations,
op_scale.data_scale,
op_scale.filter_scale,
op_scale.output_scale);
}
}
}
} // namespace
NodeVector quant_conv(const Node& node)
{
NGRAPH_WARN << "[" << node.get_name()
<< "] Zero point different from 0 is not supported. Assuming Zero "
"point is 0";
const NodeVector& inputs = node.get_ng_inputs();
auto data = inputs.at(0);
auto filters = inputs.at(3);
int64_t groups{node.get_attribute_value<int64_t>("group", 1)};
auto data_scale = inputs.at(1);
auto filters_scale = inputs.at(4);
auto output_scale = inputs.at(6);
ASSERT_VALID_ARGUMENT(node,
((groups >= 0) && (groups <= data->get_shape().at(1)) &&
(groups <= filters->get_shape().at(0))))
<< "incorrect value of 'group' attribute: " << groups;
std::size_t n_data_channels{data->get_shape().at(1)};
std::size_t n_filters_channels{filters->get_shape().at(0)};
ASSERT_VALID_ARGUMENT(node, n_data_channels % groups == 0)
<< "provided group attribute value must be a multiple of data channels "
"count.";
ASSERT_VALID_ARGUMENT(node, n_filters_channels % groups == 0)
<< "provided group attribute value must be a multiple of filter channels "
"count.";
Strides strides = convpool::get_strides(node);
Strides filter_dilations = convpool::get_dilations(node);
Strides data_dilations = Strides(convpool::get_kernel_shape(node).size(), 1UL);
auto paddings = convpool::get_pads(node);
const CoordinateDiff& padding_below = paddings.first;
const CoordinateDiff& padding_above = paddings.second;
std::shared_ptr<ngraph::Node> conv_node = nullptr;
// no bias param
if (inputs.size() == 9 && !inputs.at(8)->is_null())
{
auto bias = inputs.at(8);
conv_node =
make_ng_quant_conv(data,
filters,
strides,
filter_dilations,
padding_below,
padding_above,
data_dilations,
groups,
OpScale{data_scale, filters_scale, output_scale},
bias);
}
else
{
conv_node =
make_ng_quant_conv(data,
filters,
strides,
filter_dilations,
padding_below,
padding_above,
data_dilations,
groups,
OpScale{data_scale, filters_scale, output_scale});
}
return {conv_node};
}
} // namespace set_1
} //namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "core/node.hpp"
#include "ngraph/node_vector.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace set_1
{
/// \brief Performs ONNX Quant Conv operation.
///
/// \param node The ONNX node object representing this operation.
///
/// \return The vector containing Ngraph nodes producing output of ONNX quantizied convolution
/// operation.
NodeVector quant_conv(const Node& node);
} // namespace set_1
} //namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <memory>
#include "exceptions.hpp"
#include "ngraph/axis_set.hpp"
#include "ngraph/op/quantize.hpp"
#include "ngraph/shape.hpp"
#include "quantize_linear.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector quantize_linear(const Node& node)
{
NodeVector inputs{node.get_ng_inputs()};
std::shared_ptr<ngraph::Node> x = inputs.at(0);
std::shared_ptr<ngraph::Node> y_scale = inputs.at(1);
std::shared_ptr<ngraph::Node> y_zero_point = inputs.at(2);
Shape y_scale_shape = y_scale->get_shape();
Shape y_zero_point_shape = y_zero_point->get_shape();
ASSERT_VALID_ARGUMENT(node, y_scale_shape.size() == 0)
<< "y_scale must be a scalar.";
ASSERT_VALID_ARGUMENT(node, y_zero_point_shape.size() == 0)
<< "y_zero_point must be a scalar.";
return {std::make_shared<ngraph::op::Quantize>(
x,
y_scale,
y_zero_point,
y_zero_point->get_element_type(),
AxisSet{},
ngraph::op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_EVEN)};
}
} // namespace set_1
} //namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "core/node.hpp"
#include "ngraph/node_vector.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector quantize_linear(const Node& node);
} // namespace set_1
} //namespace op
} // namespace onnx_import
} // namespace ngraph
......@@ -45,6 +45,7 @@
#include "op/cos.hpp"
#include "op/cosh.hpp"
#include "op/depth_to_space.hpp"
#include "op/dequantize_linear.hpp"
#include "op/div.hpp"
#include "op/dropout.hpp"
#include "op/elu.hpp"
......@@ -79,6 +80,8 @@
#include "op/pad.hpp"
#include "op/pow.hpp"
#include "op/prelu.hpp"
#include "op/quant_conv.hpp"
#include "op/quantize_linear.hpp"
#include "op/reciprocal.hpp"
#include "op/reduce.hpp"
#include "op/relu.hpp"
......@@ -240,6 +243,7 @@ namespace ngraph
REGISTER_OPERATOR("Cos", 1, cos);
REGISTER_OPERATOR("Cosh", 1, cosh);
REGISTER_OPERATOR("DepthToSpace", 1, depth_to_space);
REGISTER_OPERATOR("DequantizeLinear", 1, dequantize_linear);
REGISTER_OPERATOR("Div", 1, div);
REGISTER_OPERATOR("Div", 7, div);
REGISTER_OPERATOR("Dropout", 1, dropout);
......@@ -278,6 +282,8 @@ namespace ngraph
REGISTER_OPERATOR("Pad", 1, pad);
REGISTER_OPERATOR("Pow", 1, pow);
REGISTER_OPERATOR("PRelu", 1, prelu);
REGISTER_OPERATOR("QLinearConv", 1, quant_conv);
REGISTER_OPERATOR("QuantizeLinear", 1, quantize_linear);
REGISTER_OPERATOR("Reciprocal", 1, reciprocal);
REGISTER_OPERATOR("ReduceLogSum", 1, reduce_log_sum);
REGISTER_OPERATOR("ReduceLogSumExp", 1, reduce_log_sum_exp);
......
......@@ -96,3 +96,8 @@ all_2x2x3_eliminate_dims_0_1_2
floor_int32
divide_int32
one_hot_scalar_oob_in_3
# Quantized operators are not supported on gpu backend
model_dequantize_linear
model_quantize_linear
model_quant_conv_linear
# Quantized convolution is not supported on interpreter
model_quant_conv_linear
ir_version: 3
producer_name: "ngraph ONNXImporter"
graph {
node {
input: "X"
input: "x_scale"
input: "zero_point"
output: "Y"
name: "DequantizeLinear"
op_type: "DequantizeLinear"
}
name: "test_graph"
initializer {
data_type: 2
name: "zero_point"
raw_data: "\000"
}
initializer {
data_type: 1
float_data: 4
name: "x_scale"
}
input {
name: "X"
type {
tensor_type {
elem_type: 2
shape {
dim {
dim_value: 2
}
dim {
dim_value: 2
}
}
}
}
}
input {
name: "x_scale"
type {
tensor_type {
elem_type: 1
shape {
}
}
}
}
input {
name: "zero_point"
type {
tensor_type {
elem_type: 2
shape {
}
}
}
}
output {
name: "Y"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
dim {
dim_value: 2
}
}
}
}
}
}
opset_import {
version: 4
}
ir_version: 3
producer_name: "backend-test"
graph {
node {
input: "x"
input: "x_scale"
input: "x_zero_point"
input: "w"
input: "w_scale"
input: "w_zero_point"
input: "y_scale"
input: "y_zero_point"
output: "y"
op_type: "QLinearConv"
attribute {
name: "pads"
ints: 1
ints: 1
ints: 1
ints: 1
type: INTS
}
}
name: "test_conv_with_strides_padding"
initializer {
data_type: 2
name: "x_zero_point"
raw_data: "\000"
}
initializer {
data_type: 1
float_data: 1
name: "x_scale"
}
initializer {
data_type: 2
name: "w_zero_point"
raw_data: "\000"
}
initializer {
data_type: 1
float_data: 1
name: "w_scale"
}
initializer {
data_type: 2
name: "y_zero_point"
raw_data: "\000"
}
initializer {
data_type: 1
float_data: 16
name: "y_scale"
}
initializer {
dims: 1
dims: 1
dims: 3
dims: 3
data_type: 3
int32_data: 1
int32_data: 0
int32_data: 2
int32_data: 1
int32_data: 0
int32_data: 2
int32_data: 1
int32_data: 0
int32_data: 2
name: "w"
}
input {
name: "x"
type {
tensor_type {
elem_type: 2
shape {
dim {
dim_value: 1
}
dim {
dim_value: 1
}
dim {
dim_value: 9
}
dim {
dim_value: 9
}
}
}
}
}
input {
name: "x_scale"
type {
tensor_type {
elem_type: 1
shape {
}
}
}
}
input {
name: "x_zero_point"
type {
tensor_type {
elem_type: 2
shape {
}
}
}
}
input {
name: "w"
type {
tensor_type {
elem_type: 3
shape {
dim {
dim_value: 1
}
dim {
dim_value: 1
}
dim {
dim_value: 3
}
dim {
dim_value: 3
}
}
}
}
}
input {
name: "w_scale"
type {
tensor_type {
elem_type: 1
shape {
}
}
}
}
input {
name: "w_zero_point"
type {
tensor_type {
elem_type: 2
shape {
}
}
}
}
input {
name: "y_scale"
type {
tensor_type {
elem_type: 1
shape {
}
}
}
}
input {
name: "y_zero_point"
type {
tensor_type {
elem_type: 2
shape {
}
}
}
}
output {
name: "y"
type {
tensor_type {
elem_type: 2
shape {
dim {
dim_value: 1
}
dim {
dim_value: 1
}
dim {
dim_value: 9
}
dim {
dim_value: 9
}
}
}
}
}
}
opset_import {
version: 8
}
ir_version: 3
producer_name: "ngraph ONNXImporter"
graph {
node {
input: "X"
input: "y_scale"
input: "y_zero_point"
output: "Y"
name: "QuantizeLinear"
op_type: "QuantizeLinear"
}
name: "test_graph"
initializer {
data_type: 2
name: "y_zero_point"
raw_data: "\000"
}
input {
name: "X"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
dim {
dim_value: 2
}
}
}
}
}
input {
name: "y_scale"
type {
tensor_type {
elem_type: 1
shape {
}
}
}
}
input {
name: "y_zero_point"
type {
tensor_type {
elem_type: 2
shape {
}
}
}
}
output {
name: "Y"
type {
tensor_type {
elem_type: 2
shape {
dim {
dim_value: 2
}
dim {
dim_value: 2
}
}
}
}
}
}
opset_import {
version: 4
}
......@@ -2226,3 +2226,57 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, import_non_existing_file)
EXPECT_TRUE(msg.find("i.dont.exist") != std::string::npos);
}
}
NGRAPH_TEST(onnx_${BACKEND_NAME}, model_quantize_linear)
{
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/quant_lin.prototxt"));
Inputs inputs;
inputs.emplace_back(std::vector<float>{32.25f, 48.34f, 50.f, 83.f});
inputs.emplace_back(std::vector<float>{0.5f});
std::vector<std::vector<std::uint8_t>> expected_output{
std::vector<std::uint8_t>{64, 97, 100, 166}};
std::vector<std::vector<std::uint8_t>> outputs{
execute<float, std::uint8_t>(function, inputs, "${BACKEND_NAME}")};
EXPECT_TRUE(test::all_close(expected_output.front(), outputs.front()));
}
NGRAPH_TEST(onnx_${BACKEND_NAME}, model_dequantize_linear)
{
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/dequant_lin.prototxt"));
std::vector<std::vector<std::uint8_t>> inputs;
inputs.emplace_back(std::vector<std::uint8_t>{19, 210, 21, 10});
Outputs expected_output{std::vector<float>{76.f, 840.f, 84.f, 40.f}};
Outputs outputs{execute<std::uint8_t, float>(function, inputs, "${BACKEND_NAME}")};
EXPECT_TRUE(test::all_close_f(expected_output.front(), outputs.front()));
}
NGRAPH_TEST(onnx_${BACKEND_NAME}, model_quant_conv_linear)
{
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/quant_conv_lin.prototxt"));
std::vector<std::vector<std::uint8_t>> inputs;
inputs.emplace_back(std::vector<std::uint8_t>{
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21,
22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42,
43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81});
std::vector<std::vector<std::int8_t>> expected_output{std::vector<std::int8_t>{
2, 3, 3, 3, 4, 4, 4, 5, 2, 4, 6, 7, 8, 8, 9, 9, 10, 3, 8, 11, 12,
13, 13, 14, 14, 15, 5, 11, 16, 17, 18, 18, 19, 19, 20, 7, 14, 22, 22, 23, 23, 24,
24, 25, 8, 18, 27, 27, 28, 28, 29, 29, 30, 10, 21, 32, 32, 33, 33, 34, 34, 35, 12,
24, 37, 37, 38, 38, 39, 40, 40, 13, 17, 26, 27, 27, 27, 28, 28, 28, 9}};
std::vector<std::vector<std::int8_t>> outputs{
execute<std::uint8_t, std::int8_t>(function, inputs, "${BACKEND_NAME}")};
EXPECT_TRUE(test::all_close(expected_output.front(), outputs.front()));
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment