Commit 6b36a480 authored by Nishant Patel's avatar Nishant Patel Committed by Scott Cyphers

Add builder for {de}quantize to make API's consistent and support {de}quantize with mkldnn (#1839)

* Add builder for {de}quantize

* Add declaration in header

* Add mkldnn support for {de}quantize

* Add support for {de}quantize with mkldnn

* Add Dex support

* Generalizing some api's and adding a test case for DQ in backend_test.in.cpp

* Unify scale between ngraph and mkldnn

* Check for nullptrs

* PR feedback

* fix unit test failure

* Adding tests for builder and deleting the backend tests

* curly braces

* test rename
parent 1c53fd36
......@@ -18,12 +18,6 @@
#include "ngraph/builder/quantization.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/dequantize.hpp"
#include "ngraph/op/experimental/quantized_avg_pool.hpp"
#include "ngraph/op/experimental/quantized_conv.hpp"
#include "ngraph/op/experimental/quantized_conv_bias.hpp"
#include "ngraph/op/experimental/quantized_conv_relu.hpp"
#include "ngraph/op/experimental/quantized_max_pool.hpp"
#include "quantization_util.hpp"
using namespace std;
......@@ -33,6 +27,66 @@ namespace ngraph
{
namespace builder
{
std::shared_ptr<Node> ScaledQuantize(std::shared_ptr<Node> input,
std::shared_ptr<Node> min,
std::shared_ptr<Node> max,
const ngraph::element::Type& type,
const ngraph::AxisSet& axes,
op::Quantize::RoundMode round_mode)
{
auto offset = op::Constant::create(type, Shape{}, {0});
if (input->get_element_type() == element::f32)
{
float scale =
builder::quantization_util::get_quantization_scale<float>(min, max, type, true);
auto quantize_scale =
op::Constant::create(input->get_element_type(), Shape{}, {scale});
return make_shared<op::Quantize>(
input, quantize_scale, offset, type, axes, round_mode);
}
else if (input->get_element_type() == element::f64)
{
double scale = builder::quantization_util::get_quantization_scale<double>(
min, max, type, true);
auto quantize_scale =
op::Constant::create(input->get_element_type(), Shape{}, {scale});
return make_shared<op::Quantize>(
input, quantize_scale, offset, type, axes, round_mode);
}
else
{
throw ngraph_error("Unsupported quantization element type");
}
}
std::shared_ptr<Node> ScaledDequantize(std::shared_ptr<Node> input,
std::shared_ptr<Node> min,
std::shared_ptr<Node> max,
const ngraph::element::Type& type,
const ngraph::AxisSet& axes)
{
auto input_et = input->get_element_type();
auto offset = op::Constant::create(input_et, Shape{}, {0});
if (type == element::f32)
{
float scale =
builder::quantization_util::get_quantization_scale<float>(min, max, input_et);
auto dequantize_scale = op::Constant::create(type, Shape{}, {scale});
return make_shared<op::Dequantize>(input, dequantize_scale, offset, type, axes);
}
else if (type == element::f64)
{
double scale =
builder::quantization_util::get_quantization_scale<double>(min, max, input_et);
auto dequantize_scale = op::Constant::create(type, Shape{}, {scale});
return make_shared<op::Dequantize>(input, dequantize_scale, offset, type, axes);
}
else
{
throw ngraph_error("Unsupported dequantization element type");
}
}
std::shared_ptr<Node> ScaledQuantizedAvgPool(const std::shared_ptr<Node>& arg,
const Shape& window_shape,
const Strides& window_movement_strides,
......@@ -83,7 +137,8 @@ namespace ngraph
padding_below,
padding_above,
data_dilation_strides,
requantization_scale);
requantization_scale,
with_relu);
}
std::shared_ptr<Node>
......
......@@ -18,12 +18,31 @@
#include "ngraph/coordinate_diff.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/dequantize.hpp"
#include "ngraph/op/experimental/quantized_avg_pool.hpp"
#include "ngraph/op/experimental/quantized_conv.hpp"
#include "ngraph/op/experimental/quantized_conv_bias.hpp"
#include "ngraph/op/experimental/quantized_conv_relu.hpp"
#include "ngraph/op/experimental/quantized_max_pool.hpp"
#include "ngraph/op/quantize.hpp"
namespace ngraph
{
namespace builder
{
std::shared_ptr<Node> ScaledQuantize(std::shared_ptr<Node> input,
std::shared_ptr<Node> min,
std::shared_ptr<Node> max,
const ngraph::element::Type& type,
const ngraph::AxisSet& axes,
op::Quantize::RoundMode round_mode);
std::shared_ptr<Node> ScaledDequantize(std::shared_ptr<Node> input,
std::shared_ptr<Node> min,
std::shared_ptr<Node> max,
const ngraph::element::Type& type,
const ngraph::AxisSet& axes);
std::shared_ptr<Node> ScaledQuantizedAvgPool(const std::shared_ptr<Node>& arg,
const Shape& window_shape,
const Strides& window_movement_strides,
......
......@@ -74,23 +74,23 @@ namespace ngraph
std::static_pointer_cast<ngraph::op::Constant>(min_freezed_output);
auto max_freezed_output_const_op =
std::static_pointer_cast<ngraph::op::Constant>(max_freezed_output);
float input_min = *(static_cast<float const*>(min_input_const_op->get_data_ptr()));
float input_max = *(static_cast<float const*>(max_input_const_op->get_data_ptr()));
float filter_min =
*(static_cast<float const*>(min_filter_const_op->get_data_ptr()));
float filter_max =
*(static_cast<float const*>(max_filter_const_op->get_data_ptr()));
float output_min =
*(static_cast<float const*>(min_freezed_output_const_op->get_data_ptr()));
float output_max =
*(static_cast<float const*>(max_freezed_output_const_op->get_data_ptr()));
auto input_min = min_input_const_op->get_vector<float>();
auto input_max = max_input_const_op->get_vector<float>();
auto filter_min = min_filter_const_op->get_vector<float>();
auto filter_max = max_filter_const_op->get_vector<float>();
auto output_min = min_freezed_output_const_op->get_vector<float>();
auto output_max = max_freezed_output_const_op->get_vector<float>();
float min_out_value;
float max_out_value;
quantization_range_for_multiplication<uint8_t, int8_t, int32_t>(
input_min, input_max, filter_min, filter_max, &min_out_value, &max_out_value);
quantization_range_for_multiplication<uint8_t, int8_t, int32_t>(input_min[0],
input_max[0],
filter_min[0],
filter_max[0],
&min_out_value,
&max_out_value);
const float max_abs32 = std::max(std::abs(min_out_value), std::abs(max_out_value));
const float max_abs8 = std::max(std::abs(output_min), std::abs(output_max));
const float max_abs8 = std::max(std::abs(output_min[0]), std::abs(output_max[0]));
// Output is signed int.
// s32 = f32 * std::pow(2, 31)/ max_abs32;
// s8 = f32 * std::pow(2, 7)/ max_abs8;
......@@ -99,6 +99,59 @@ namespace ngraph
(std::pow(2, -24) * static_cast<double>(max_abs32 / max_abs8)));
return scale;
}
template <typename T>
static inline T get_quantization_scale(const std::shared_ptr<Node> min_input,
const std::shared_ptr<Node> max_input,
const ngraph::element::Type& type,
bool bump_by_eps = false)
{
auto min_input_const_op =
std::dynamic_pointer_cast<ngraph::op::Constant>(min_input);
auto max_input_const_op =
std::dynamic_pointer_cast<ngraph::op::Constant>(max_input);
if (min_input_const_op == nullptr)
{
throw ngraph_error("min input must be constant");
}
else if (max_input_const_op == nullptr)
{
throw ngraph_error("max input must be constant");
}
auto input_min_range = min_input_const_op->get_vector<T>();
auto input_max_range = max_input_const_op->get_vector<T>();
T min_range = std::numeric_limits<T>::min();
T max_range = std::numeric_limits<T>::max();
if (bump_by_eps)
{
// If input_min_range and input_max_range are close,
// introduce a slightly larger delta between them.
min_range = std::min(static_cast<T>(0.0f), input_min_range[0]);
const T epsilon = std::max(static_cast<T>(1.0f),
static_cast<T>(std::max(fabs(input_min_range[0]),
fabs(input_max_range[0])))) /
static_cast<T>(100.0f);
max_range = std::max(input_max_range[0], min_range + epsilon);
max_range = std::max(static_cast<T>(0.0f), max_range);
// end code copied and pasted from
// github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/quantize_op.cc
}
else
{
min_range = input_min_range[0];
max_range = input_max_range[0];
}
const T max_abs = std::max(std::abs(min_range), std::abs(max_range));
const T bitwidth = type.bitwidth();
const T target_range = static_cast<T>(
(type.is_signed() ? std::pow(2, (bitwidth - 1)) : std::pow(2, bitwidth)) - 1);
const T scale_factor = max_abs / target_range;
return scale_factor;
}
}
}
}
......@@ -46,9 +46,9 @@ op::QuantizedConvolution::QuantizedConvolution(const shared_ptr<Node>& data_batc
auto& filters_shape = filters->get_shape();
auto scale_const_op = std::static_pointer_cast<ngraph::op::Constant>(scale);
float scale_val = *(static_cast<float const*>(scale_const_op->get_data_ptr()));
auto scale_val = scale_const_op->get_vector<float>();
this->m_scale = scale_val;
this->m_scale = scale_val[0];
set_output_type(0,
element::i8,
......
......@@ -50,8 +50,8 @@ op::QuantizedConvolutionBias::QuantizedConvolutionBias(const shared_ptr<Node>& d
auto& filters_shape = filters->get_shape();
auto scale_const_op = std::static_pointer_cast<ngraph::op::Constant>(scale);
float scale_val = *(static_cast<float const*>(scale_const_op->get_data_ptr()));
this->m_scale = scale_val;
auto scale_val = scale_const_op->get_vector<float>();
this->m_scale = scale_val[0];
// TODO: call ngraph util
// util::validate_convbias_shapes(data_batch_shape, filters_shape, bias->get_shape());
......
......@@ -4670,31 +4670,81 @@ namespace ngraph
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Dequantize)
{
auto dequantize = static_cast<const ngraph::op::Dequantize*>(node);
writer << "reference::dequantize(";
writer << " " << args[0].get_name() << ",\n";
writer << " " << args[1].get_name() << ",\n";
writer << " " << args[2].get_name() << ",\n";
writer << " " << out[0].get_name() << ",\n";
writer << " {" << join(args[0].get_shape()) << "},\n";
writer << " {" << join(args[1].get_shape()) << "},\n";
writer << " {" << join(dequantize->get_axes()) << "});\n";
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto input_data_desc = mkldnn_utils::get_input_mkldnn_md(node, 0);
auto result_desc = mkldnn_utils::get_output_mkldnn_md(node, 0);
size_t dequantize_index =
mkldnn_emitter->build_dequantization(node, input_data_desc, result_desc);
auto& deps = mkldnn_emitter->get_primitive_deps(dequantize_index);
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[0])
<< ", " << args[0].get_name() << ");\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[1])
<< ", " << out[0].get_name() << ");\n";
writer << "cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, "
<< to_string(dequantize_index) << ");\n";
}
else
{
auto dequantize = static_cast<const ngraph::op::Dequantize*>(node);
writer << "reference::dequantize(";
writer << " " << args[0].get_name() << ",\n";
writer << " " << args[1].get_name() << ",\n";
writer << " " << args[2].get_name() << ",\n";
writer << " " << out[0].get_name() << ",\n";
writer << " {" << join(args[0].get_shape()) << "},\n";
writer << " {" << join(args[1].get_shape()) << "},\n";
writer << " {" << join(dequantize->get_axes()) << "});\n";
}
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Quantize)
{
auto quantize = static_cast<const ngraph::op::Quantize*>(node);
writer << "reference::quantize(";
writer << " " << args[0].get_name() << ",\n";
writer << " " << args[1].get_name() << ",\n";
writer << " " << args[2].get_name() << ",\n";
writer << " " << out[0].get_name() << ",\n";
writer << " {" << join(args[0].get_shape()) << "},\n";
writer << " {" << join(args[1].get_shape()) << "},\n";
writer << " {" << join(quantize->get_axes()) << "},\n";
writer << " static_cast<op::Quantize::RoundMode>("
<< static_cast<int>(quantize->get_round_mode()) << "));\n";
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto input_data_desc = mkldnn_utils::get_input_mkldnn_md(node, 0);
auto result_desc = mkldnn_utils::get_output_mkldnn_md(node, 0);
auto scale_const_op =
std::dynamic_pointer_cast<ngraph::op::Constant>(quantize->get_argument(1));
if (scale_const_op == nullptr)
{
throw ngraph_error("Quantize scale must be a constant");
}
auto scale = scale_const_op->get_vector<float>();
std::vector<float> scales;
scales.push_back(1.0 / scale[0]);
size_t quantize_index = 0;
quantize_index = mkldnn_emitter->build_quantize_reorder(
input_data_desc, result_desc, scales);
auto& deps = mkldnn_emitter->get_primitive_deps(quantize_index);
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[0])
<< ", " << args[0].get_name() << ");\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[1])
<< ", " << out[0].get_name() << ");\n";
writer << "cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, "
<< to_string(quantize_index) << ");\n";
}
else
{
writer << "reference::quantize(";
writer << " " << args[0].get_name() << ",\n";
writer << " " << args[1].get_name() << ",\n";
writer << " " << args[2].get_name() << ",\n";
writer << " " << out[0].get_name() << ",\n";
writer << " {" << join(args[0].get_shape()) << "},\n";
writer << " {" << join(args[1].get_shape()) << "},\n";
writer << " {" << join(quantize->get_axes()) << "},\n";
writer << " static_cast<op::Quantize::RoundMode>("
<< static_cast<int>(quantize->get_round_mode()) << "));\n";
}
}
#undef TI
......
......@@ -20,6 +20,7 @@
#include "mkldnn_emitter.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/dequantize.hpp"
#include "ngraph/op/experimental/quantized_avg_pool.hpp"
#include "ngraph/op/experimental/quantized_max_pool.hpp"
#include "ngraph/runtime/cpu/cpu_layout_descriptor.hpp"
......@@ -122,6 +123,40 @@ size_t MKLDNNEmitter::build_memory_primitive(const mkldnn::memory::desc& desc)
return index;
}
size_t MKLDNNEmitter::build_quantize_reorder(const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& result_desc,
const std::vector<float>& scales)
{
size_t input_index = build_memory_primitive(input_desc);
size_t result_index = build_memory_primitive(result_desc);
mkldnn::primitive_attr attr;
attr.set_output_scales(0, scales);
attr.set_int_output_round_mode(mkldnn::round_mode::round_nearest);
auto reorder_desc =
mkldnn::reorder::primitive_desc({input_desc, mkldnn_utils::global_cpu_engine},
{result_desc, mkldnn_utils::global_cpu_engine},
attr);
size_t primitive_index = insert_primitive(new mkldnn::reorder(
reorder_desc, *m_mkldnn_primitives[input_index], *m_mkldnn_primitives[result_index]));
m_primitive_deps[primitive_index] = {input_index, result_index};
return primitive_index;
}
size_t MKLDNNEmitter::build_dequantization(const ngraph::Node* node,
const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& result_desc)
{
auto dequantize = static_cast<const ngraph::op::Dequantize*>(node);
auto scale_const_op =
std::static_pointer_cast<ngraph::op::Constant>(dequantize->get_argument(1));
float scale = *(static_cast<float const*>(scale_const_op->get_data_ptr()));
std::vector<float> scales;
scales.push_back(scale);
size_t dequantize_index = 0;
dequantize_index = this->build_quantize_reorder(input_desc, result_desc, scales);
return dequantize_index;
}
size_t MKLDNNEmitter::build_quantized_max_pool(const ngraph::Node* node)
{
auto qmax_pool = static_cast<const ngraph::op::QuantizedMaxPool*>(node);
......
......@@ -593,6 +593,14 @@ namespace ngraph
size_t build_quantized_avg_pool(const ngraph::Node* node);
size_t build_dequantization(const ngraph::Node* node,
const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& result_desc);
size_t build_quantize_reorder(const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& result_desc,
const std::vector<float>& scales);
private:
std::vector<mkldnn::primitive*> m_mkldnn_primitives;
std::vector<mkldnn::stream> m_mkldnn_streams;
......
......@@ -29,7 +29,9 @@
#include "ngraph/op/avg_pool.hpp"
#include "ngraph/op/batch_norm.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/dequantize.hpp"
#include "ngraph/op/experimental/quantized_avg_pool.hpp"
#include "ngraph/op/experimental/quantized_conv.hpp"
#include "ngraph/op/experimental/quantized_conv_bias.hpp"
......@@ -37,6 +39,7 @@
#include "ngraph/op/experimental/quantized_max_pool.hpp"
#include "ngraph/op/lrn.hpp"
#include "ngraph/op/max_pool.hpp"
#include "ngraph/op/quantize.hpp"
#include "ngraph/op/relu.hpp"
#include "ngraph/op/replace_slice.hpp"
#include "ngraph/op/slice.hpp"
......@@ -787,7 +790,6 @@ namespace ngraph
void CPUAssignment::ASSIGN_DECL(ngraph::op::QuantizedConvolutionBias)
{
auto quantized_conv_bias = static_cast<op::QuantizedConvolutionBias*>(node);
if (node->get_input_element_type(0) == element::u8 &&
node->get_input_element_type(1) == element::i8)
{
......@@ -797,6 +799,63 @@ namespace ngraph
quantized_conv_bias->set_op_annotations(op_annotations);
}
}
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::Dequantize)
{
auto dequantize = static_cast<op::Dequantize*>(node);
auto offset_const_op =
std::static_pointer_cast<ngraph::op::Constant>(dequantize->get_argument(2));
if (node->get_input_element_type(0) == element::u8)
{
auto offset = offset_const_op->get_vector<uint8_t>();
if (offset[0] != 0)
return;
}
if (node->get_input_element_type(0) == element::i8)
{
auto offset = offset_const_op->get_vector<int8_t>();
if (offset[0] != 0)
return;
}
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
dequantize->set_op_annotations(op_annotations);
}
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::Quantize)
{
auto quantize = static_cast<op::Quantize*>(node);
auto offset_const_op =
std::static_pointer_cast<ngraph::op::Constant>(quantize->get_argument(2));
op::Quantize::RoundMode round_mode = quantize->get_round_mode();
if (round_mode != op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_EVEN)
{
return;
}
if (node->get_output_element_type(0) == element::u8)
{
auto offset = offset_const_op->get_vector<uint8_t>();
if (offset[0] != 0)
{
return;
}
}
if (node->get_output_element_type(0) == element::i8)
{
auto offset = offset_const_op->get_vector<int8_t>();
if (offset[0] != 0)
{
return;
}
}
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
quantize->set_op_annotations(op_annotations);
}
}
}
}
......@@ -870,6 +929,9 @@ static const runtime::cpu::pass::AssignOpMap s_dispatcher{
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::QuantizedConvolutionRelu>},
{TI(ngraph::op::QuantizedConvolutionBias),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::QuantizedConvolutionBias>},
{TI(ngraph::op::Quantize), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::Quantize>},
{TI(ngraph::op::Dequantize),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::Dequantize>},
};
bool runtime::cpu::pass::CPUAssignment::run_on_call_graph(
......
......@@ -322,11 +322,14 @@ namespace ngraph
std::unique_ptr<convolution_forward::desc> fwd_desc{nullptr};
if (use_bias)
{
memory::data_type et_bias =
mkldnn_utils::get_mkldnn_data_type(node->get_input_element_type(2));
auto arg2_shape = node->get_input_shape(2);
ngraph::op::util::validate_convbias_shapes(
arg0_shape, arg1_shape, arg2_shape);
memory::dims mkldnn_arg2_shape(arg2_shape.begin(), arg2_shape.end());
const memory::desc bias_desc(mkldnn_arg2_shape, et, memory::format::any);
const memory::desc bias_desc(
mkldnn_arg2_shape, et_bias, memory::format::any);
try
{
fwd_desc.reset(
......
......@@ -54,6 +54,7 @@ quantize_clamp
dequantize
dequantize_axes
dequantize_int8
dequantize_zero_offset
quantize_ROUND_NEAREST_TOWARD_ZERO
quantize_ROUND_NEAREST_UPWARD
quantize_ROUND_NEAREST_DOWNWARD
......@@ -61,4 +62,4 @@ quantize_ROUND_NEAREST_TOWARD_EVEN
quantize_ROUND_TOWARD_INFINITY
quantize_ROUND_TOWARD_ZERO
quantize_ROUND_UP
quantize_ROUND_DOWN
\ No newline at end of file
quantize_ROUND_DOWN
......@@ -20,6 +20,7 @@ batch_norm_three_outputs
dequantize
dequantize_axes
dequantize_int8
dequantize_zero_offset
divide_by_zero_int32
dot_3d_multi_axis
dot_4d_5d_multi_axis
......
......@@ -70,7 +70,7 @@ if (NGRAPH_HYBRID_ENABLE)
endif()
if (NGRAPH_CPU_ENABLE)
list(APPEND SRC core_fusion.cpp quantize_cpu.cpp)
list(APPEND SRC core_fusion.cpp builder_quantization.cpp)
list(APPEND SRC backend_performance.cpp cpu_fusion.cpp cpu_test.cpp cpu_reshape_sinking.cpp cpu_debugger.cpp)
if (NGRAPH_HALIDE)
list(APPEND SRC halide.cpp)
......
......@@ -5518,6 +5518,38 @@ NGRAPH_TEST(${BACKEND_NAME}, dequantize)
read_vector<output_c_type>(y));
}
NGRAPH_TEST(${BACKEND_NAME}, dequantize_zero_offset)
{
Shape input_shape{4, 3};
Shape scale_offset_shape;
AxisSet quantization_axes;
auto input_type = element::u8;
auto output_type = element::f32;
typedef uint8_t input_c_type;
typedef float output_c_type;
auto X = make_shared<op::Parameter>(input_type, input_shape);
auto scale = op::Constant::create(output_type, scale_offset_shape, {2});
auto offset = op::Constant::create(input_type, scale_offset_shape, {0});
auto dequantize = make_shared<op::Dequantize>(X, scale, offset, output_type, quantization_axes);
auto f = make_shared<Function>(dequantize, op::ParameterVector{X});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
auto x = backend->create_tensor(input_type, input_shape);
auto y = backend->create_tensor(output_type, input_shape);
copy_data(x, vector<input_c_type>{1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7});
// minus offset 0 0 0 0 0 0 0 0 0 0 0 0
// multiplied by scale 2 2 2 2 2 2 2 2 2 2 2 2
// equals 2 4 4 6 6 8 8 10 10 12 12 14
backend->call_with_validate(f, {y}, {x});
EXPECT_EQ((vector<output_c_type>{2, 4, 4, 6, 6, 8, 8, 10, 10, 12, 12, 14}),
read_vector<output_c_type>(y));
}
NGRAPH_TEST(${BACKEND_NAME}, quantize_axes)
{
Shape input_shape{4, 3};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment